Skip to main content

sphereql_layout/
clustered.rs

1use std::f64::consts::PI;
2
3use sphereql_core::{
4    CartesianPoint, SphericalPoint, angular_distance, cartesian_to_spherical,
5    spherical_to_cartesian,
6};
7
8use crate::quality::{MAX_QUALITY_N, OVERLAP_THRESHOLD};
9use crate::traits::{DimensionMapper, LayoutStrategy};
10use crate::types::{LayoutEntry, LayoutQuality, LayoutResult};
11
/// Upper bound on the number of assignment/update passes in `kmeans_spherical`.
/// The loop also stops early as soon as an assignment pass changes nothing.
const MAX_KMEANS_ITERATIONS: usize = 50;
13
/// Layout strategy that groups items into clusters on a sphere.
///
/// Items are mapped to directions by a [`DimensionMapper`], grouped with
/// spherical k-means into at most `num_clusters` clusters, and then each
/// cluster's members are re-spaced on a small spiral around the cluster
/// center so members of one cluster stay close without coinciding.
#[derive(Debug, Clone, PartialEq)]
pub struct ClusteredLayout {
    /// Requested number of clusters. When a layout is computed this is capped
    /// at the number of items and raised to at least 1.
    pub num_clusters: usize,
    /// Radius of the sphere every output position is placed on.
    pub radius: f64,
    /// Size of each cluster's spiral around its center, used as a tangent-plane
    /// offset on the unit sphere (roughly radians for small values).
    pub intra_cluster_spread: f64,
}

impl Default for ClusteredLayout {
    /// Four clusters on the unit sphere with a spread of `0.3`.
    fn default() -> Self {
        Self {
            num_clusters: 4,
            radius: 1.0,
            intra_cluster_spread: 0.3,
        }
    }
}

impl ClusteredLayout {
    /// Same as [`ClusteredLayout::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the layout with `num_clusters` set to `n`.
    #[must_use]
    pub fn with_clusters(mut self, n: usize) -> Self {
        self.num_clusters = n;
        self
    }

    /// Returns the layout with the output sphere radius set to `r`.
    #[must_use]
    pub fn with_radius(mut self, r: f64) -> Self {
        self.radius = r;
        self
    }

    /// Returns the layout with `intra_cluster_spread` set to `s`.
    #[must_use]
    pub fn with_spread(mut self, s: f64) -> Self {
        self.intra_cluster_spread = s;
        self
    }
}
50
51fn evenly_spaced_centers(k: usize) -> Vec<CartesianPoint> {
52    let golden_ratio = (1.0 + 5.0_f64.sqrt()) / 2.0;
53    (0..k)
54        .map(|i| {
55            let phi = (1.0 - 2.0 * (i as f64 + 0.5) / k as f64)
56                .clamp(-1.0, 1.0)
57                .acos();
58            let theta = (2.0 * PI * (i as f64) / golden_ratio).rem_euclid(2.0 * PI);
59            let sp = SphericalPoint::new_unchecked(1.0, theta, phi);
60            spherical_to_cartesian(&sp)
61        })
62        .collect()
63}
64
65fn normalized_mean(points: &[CartesianPoint]) -> CartesianPoint {
66    if points.is_empty() {
67        return CartesianPoint::new(0.0, 0.0, 1.0);
68    }
69    let mut sx = 0.0;
70    let mut sy = 0.0;
71    let mut sz = 0.0;
72    for p in points {
73        sx += p.x;
74        sy += p.y;
75        sz += p.z;
76    }
77    let mean = CartesianPoint::new(sx, sy, sz);
78    let n = mean.normalize();
79    // Antipodal points cancel; fall back to the first member that has a
80    // well-defined direction. If every member is degenerate (mag ≈ 0),
81    // fall back to the canonical north pole rather than returning an
82    // unnormalized zero vector.
83    if n.magnitude() < 1e-12 {
84        for p in points {
85            if p.magnitude() >= 1e-12 {
86                return p.normalize();
87            }
88        }
89        CartesianPoint::new(0.0, 0.0, 1.0)
90    } else {
91        n
92    }
93}
94
95struct KMeansResult {
96    assignments: Vec<usize>,
97    centers: Vec<CartesianPoint>,
98}
99
100fn kmeans_spherical(mapped_cartesian: &[CartesianPoint], k: usize) -> KMeansResult {
101    let n = mapped_cartesian.len();
102
103    let mut centers: Vec<CartesianPoint> = if n >= k {
104        mapped_cartesian[..k]
105            .iter()
106            .map(|c| c.normalize())
107            .collect()
108    } else {
109        evenly_spaced_centers(k)
110    };
111
112    let mut assignments = vec![0usize; n];
113
114    // k-means inner loop in Cartesian. Both points and centers live on
115    // the unit sphere, so angular distance `acos(dot)` and
116    // `maximize dot` pick the same cluster for every point (acos is
117    // monotonic). The previous implementation converted every
118    // Cartesian center back to spherical and called `angular_distance`
119    // per (point, center) pair — ~50·n·k redundant trig calls per
120    // k-means run, which dominated layout cost.
121    #[inline]
122    fn dot(a: &CartesianPoint, b: &CartesianPoint) -> f64 {
123        a.x * b.x + a.y * b.y + a.z * b.z
124    }
125
126    for _ in 0..MAX_KMEANS_ITERATIONS {
127        let mut changed = false;
128
129        for (i, point) in mapped_cartesian.iter().enumerate() {
130            let mut best = 0;
131            let mut best_dot = f64::MIN;
132            for (j, center) in centers.iter().enumerate() {
133                let d = dot(point, center);
134                if d > best_dot {
135                    best_dot = d;
136                    best = j;
137                }
138            }
139            if assignments[i] != best {
140                assignments[i] = best;
141                changed = true;
142            }
143        }
144
145        if !changed {
146            break;
147        }
148
149        let mut cluster_points: Vec<Vec<CartesianPoint>> = vec![vec![]; k];
150        for (i, &a) in assignments.iter().enumerate() {
151            cluster_points[a].push(mapped_cartesian[i]);
152        }
153
154        for (j, cp) in cluster_points.iter().enumerate() {
155            if cp.is_empty() {
156                // Reseed from whichever point is farthest from its
157                // current cluster center — Cartesian dot gives the
158                // ordering (smaller dot = larger angular distance).
159                let mut farthest_idx = 0;
160                let mut farthest_dot = f64::MAX;
161                for (i, point) in mapped_cartesian.iter().enumerate() {
162                    let d = dot(point, &centers[assignments[i]]);
163                    if d < farthest_dot {
164                        farthest_dot = d;
165                        farthest_idx = i;
166                    }
167                }
168                centers[j] = mapped_cartesian[farthest_idx].normalize();
169            } else {
170                centers[j] = normalized_mean(cp);
171            }
172        }
173    }
174
175    KMeansResult {
176        assignments,
177        centers,
178    }
179}
180
181fn fibonacci_sub_spiral(
182    center: &SphericalPoint,
183    count: usize,
184    spread: f64,
185    radius: f64,
186) -> Vec<SphericalPoint> {
187    if count == 0 {
188        return vec![];
189    }
190    if count == 1 {
191        return vec![SphericalPoint::new_unchecked(
192            radius,
193            center.theta,
194            center.phi,
195        )];
196    }
197
198    let golden_angle = PI * (3.0 - 5.0_f64.sqrt());
199    let center_cart = spherical_to_cartesian(&SphericalPoint::new_unchecked(
200        1.0,
201        center.theta,
202        center.phi,
203    ));
204
205    let (tangent_u, tangent_v) = local_frame(&center_cart);
206
207    (0..count)
208        .map(|i| {
209            let frac = i as f64 / count as f64;
210            let angular_r = spread * frac.sqrt();
211            let angle = golden_angle * i as f64;
212
213            let offset_u = angular_r * angle.cos();
214            let offset_v = angular_r * angle.sin();
215
216            let displaced = CartesianPoint::new(
217                center_cart.x + offset_u * tangent_u.x + offset_v * tangent_v.x,
218                center_cart.y + offset_u * tangent_u.y + offset_v * tangent_v.y,
219                center_cart.z + offset_u * tangent_u.z + offset_v * tangent_v.z,
220            )
221            .normalize();
222
223            let sp = cartesian_to_spherical(&displaced);
224            SphericalPoint::new_unchecked(radius, sp.theta, sp.phi)
225        })
226        .collect()
227}
228
229fn local_frame(center: &CartesianPoint) -> (CartesianPoint, CartesianPoint) {
230    let up = if center.z.abs() < 0.9 {
231        CartesianPoint::new(0.0, 0.0, 1.0)
232    } else {
233        CartesianPoint::new(1.0, 0.0, 0.0)
234    };
235
236    // u = normalize(up x center)
237    let ux = up.y * center.z - up.z * center.y;
238    let uy = up.z * center.x - up.x * center.z;
239    let uz = up.x * center.y - up.y * center.x;
240    let u = CartesianPoint::new(ux, uy, uz).normalize();
241
242    // v = center x u
243    let vx = center.y * u.z - center.z * u.y;
244    let vy = center.z * u.x - center.x * u.z;
245    let vz = center.x * u.y - center.y * u.x;
246    let v = CartesianPoint::new(vx, vy, vz).normalize();
247
248    (u, v)
249}
250
251fn compute_quality(
252    positions: &[SphericalPoint],
253    assignments: &[usize],
254    num_clusters: usize,
255) -> LayoutQuality {
256    let n = positions.len();
257
258    if n <= 1 {
259        return LayoutQuality {
260            dispersion_score: if n == 0 { 0.0 } else { 1.0 },
261            overlap_score: 0.0,
262            silhouette_score: 0.0,
263        };
264    }
265
266    let (positions, assignments, n) = if n > MAX_QUALITY_N {
267        let step = n / MAX_QUALITY_N;
268        let sampled_pos: Vec<_> = positions
269            .iter()
270            .step_by(step)
271            .take(MAX_QUALITY_N)
272            .copied()
273            .collect();
274        let sampled_asgn: Vec<_> = assignments
275            .iter()
276            .step_by(step)
277            .take(MAX_QUALITY_N)
278            .copied()
279            .collect();
280        let len = sampled_pos.len();
281        (sampled_pos, sampled_asgn, len)
282    } else {
283        (positions.to_vec(), assignments.to_vec(), n)
284    };
285
286    // Dispersion: average inter-cluster center distance / PI
287    let mut cluster_point_sets: Vec<Vec<CartesianPoint>> = vec![vec![]; num_clusters];
288    for (i, &a) in assignments.iter().enumerate() {
289        cluster_point_sets[a].push(spherical_to_cartesian(&positions[i]));
290    }
291    let active_centers: Vec<SphericalPoint> = cluster_point_sets
292        .iter()
293        .filter(|cp| !cp.is_empty())
294        .map(|cp| cartesian_to_spherical(&normalized_mean(cp)))
295        .collect();
296
297    // Every pair-scan below is embarrassingly parallel: pure reads
298    // over `positions` / `active_centers` / `assignments`, per-point
299    // scalar reductions. `rayon::par_iter` over the outer index
300    // keeps the reduce trivial and skips the thread-pool overhead
301    // for small corpora.
302    use rayon::prelude::*;
303    const SERIAL_THRESHOLD: usize = 128;
304
305    let dispersion_score = if active_centers.len() >= 2 {
306        let len = active_centers.len();
307        let (sum, count) = if len < SERIAL_THRESHOLD {
308            let mut s = 0.0;
309            let mut c = 0u64;
310            for i in 0..len {
311                for j in (i + 1)..len {
312                    s += angular_distance(&active_centers[i], &active_centers[j]);
313                    c += 1;
314                }
315            }
316            (s, c)
317        } else {
318            (0..len)
319                .into_par_iter()
320                .map(|i| {
321                    let mut s = 0.0;
322                    let mut c = 0u64;
323                    for j in (i + 1)..len {
324                        s += angular_distance(&active_centers[i], &active_centers[j]);
325                        c += 1;
326                    }
327                    (s, c)
328                })
329                .reduce(|| (0.0, 0), |(sa, ca), (sb, cb)| (sa + sb, ca + cb))
330        };
331        (sum / count as f64 / PI).clamp(0.0, 1.0)
332    } else {
333        0.0
334    };
335
336    // Overlap: fraction of pairs within threshold.
337    let total_pairs = (n * (n - 1)) / 2;
338    let overlap_count: u64 = if n < SERIAL_THRESHOLD {
339        let mut c = 0u64;
340        for i in 0..n {
341            for j in (i + 1)..n {
342                if angular_distance(&positions[i], &positions[j]) < OVERLAP_THRESHOLD {
343                    c += 1;
344                }
345            }
346        }
347        c
348    } else {
349        (0..n)
350            .into_par_iter()
351            .map(|i| {
352                let mut c = 0u64;
353                for j in (i + 1)..n {
354                    if angular_distance(&positions[i], &positions[j]) < OVERLAP_THRESHOLD {
355                        c += 1;
356                    }
357                }
358                c
359            })
360            .sum()
361    };
362    let overlap_score = if total_pairs > 0 {
363        overlap_count as f64 / total_pairs as f64
364    } else {
365        0.0
366    };
367
368    // Silhouette coefficient. Pre-bucket assignments by cluster so the
369    // per-point inner work visits each remote point exactly once across
370    // all `k`. Without bucketing the inner loop scans `n` points per
371    // cluster, making the function O(n² · num_clusters); with it the
372    // total stays at O(n²). Singleton clusters trivially have a(i) = 0
373    // and are short-circuited.
374    let silhouette_score = if num_clusters <= 1 || active_centers.len() <= 1 {
375        0.0
376    } else {
377        let cluster_members: Vec<Vec<usize>> = {
378            let mut buckets = vec![Vec::new(); num_clusters];
379            for (j, &cj) in assignments.iter().enumerate() {
380                if cj < num_clusters {
381                    buckets[cj].push(j);
382                }
383            }
384            buckets
385        };
386
387        let per_point = |i: usize| -> f64 {
388            let ci = assignments[i];
389            let same = &cluster_members[ci];
390
391            let a = if same.len() <= 1 {
392                0.0
393            } else {
394                let s: f64 = same
395                    .iter()
396                    .filter(|&&j| j != i)
397                    .map(|&j| angular_distance(&positions[i], &positions[j]))
398                    .sum();
399                s / (same.len() - 1) as f64
400            };
401
402            let mut b = f64::MAX;
403            for (k, members) in cluster_members.iter().enumerate() {
404                if k == ci || members.is_empty() {
405                    continue;
406                }
407                let s: f64 = members
408                    .iter()
409                    .map(|&j| angular_distance(&positions[i], &positions[j]))
410                    .sum();
411                let mean_dist = s / members.len() as f64;
412                if mean_dist < b {
413                    b = mean_dist;
414                }
415            }
416            if b == f64::MAX {
417                b = 0.0;
418            }
419
420            let denom = a.max(b);
421            if denom > 0.0 { (b - a) / denom } else { 0.0 }
422        };
423
424        let sil_sum: f64 = if n < SERIAL_THRESHOLD {
425            (0..n).map(per_point).sum()
426        } else {
427            (0..n).into_par_iter().map(per_point).sum()
428        };
429        sil_sum / n as f64
430    };
431
432    LayoutQuality {
433        dispersion_score,
434        overlap_score,
435        silhouette_score,
436    }
437}
438
439impl<T: Clone + Send + Sync> LayoutStrategy<T> for ClusteredLayout {
440    fn layout(&self, items: &[T], mapper: &dyn DimensionMapper<Item = T>) -> LayoutResult<T> {
441        if items.is_empty() {
442            return LayoutResult {
443                entries: vec![],
444                quality: LayoutQuality::default(),
445            };
446        }
447
448        let mapped: Vec<SphericalPoint> = items.iter().map(|item| mapper.map(item)).collect();
449        let mapped_cart: Vec<CartesianPoint> = mapped.iter().map(spherical_to_cartesian).collect();
450
451        let k = self.num_clusters.min(items.len()).max(1);
452        let km = kmeans_spherical(&mapped_cart, k);
453
454        let mut cluster_items: Vec<Vec<usize>> = vec![vec![]; k];
455        for (i, &a) in km.assignments.iter().enumerate() {
456            cluster_items[a].push(i);
457        }
458
459        let mut entries: Vec<(usize, LayoutEntry<T>)> = Vec::with_capacity(items.len());
460        let mut final_positions: Vec<(usize, SphericalPoint)> = Vec::with_capacity(items.len());
461        let mut final_assignments = vec![0usize; items.len()];
462
463        for (cluster_idx, member_indices) in cluster_items.iter().enumerate() {
464            let center_sp = cartesian_to_spherical(&km.centers[cluster_idx]);
465            let sub_positions = fibonacci_sub_spiral(
466                &center_sp,
467                member_indices.len(),
468                self.intra_cluster_spread,
469                self.radius,
470            );
471
472            for (sub_idx, &item_idx) in member_indices.iter().enumerate() {
473                let pos = sub_positions[sub_idx];
474                entries.push((
475                    item_idx,
476                    LayoutEntry {
477                        item: items[item_idx].clone(),
478                        position: pos,
479                    },
480                ));
481                final_positions.push((item_idx, pos));
482                final_assignments[item_idx] = cluster_idx;
483            }
484        }
485
486        entries.sort_by_key(|(idx, _)| *idx);
487        let entries: Vec<LayoutEntry<T>> = entries.into_iter().map(|(_, e)| e).collect();
488
489        final_positions.sort_by_key(|(idx, _)| *idx);
490        let positions: Vec<SphericalPoint> = final_positions.into_iter().map(|(_, p)| p).collect();
491
492        let quality = compute_quality(&positions, &final_assignments, k);
493
494        LayoutResult { entries, quality }
495    }
496}
497
#[cfg(test)]
mod tests {
    use super::*;

    /// Mapper whose item `i` maps to `positions[i]`.
    struct FixedMapper {
        positions: Vec<SphericalPoint>,
    }

    impl DimensionMapper for FixedMapper {
        type Item = usize;
        fn map(&self, item: &usize) -> SphericalPoint {
            self.positions[*item]
        }
    }

    #[test]
    fn empty_items_returns_empty_result() {
        let layout = ClusteredLayout::new();
        let mapper = FixedMapper { positions: vec![] };
        let result = layout.layout(&[], &mapper);
        assert!(result.entries.is_empty());
    }

    #[test]
    fn single_item_gets_placed() {
        let layout = ClusteredLayout::new().with_clusters(1);
        let mapper = FixedMapper {
            positions: vec![SphericalPoint::new_unchecked(1.0, 0.5, 1.0)],
        };
        let result = layout.layout(&[0usize], &mapper);
        assert_eq!(result.entries.len(), 1);
        assert!((result.entries[0].position.r - 1.0).abs() < 1e-12);
    }

    #[test]
    fn correct_number_of_entries() {
        let layout = ClusteredLayout::new().with_clusters(3);
        let positions: Vec<SphericalPoint> = (0..20)
            .map(|i| {
                let theta = (i as f64 * 0.3) % (2.0 * PI);
                SphericalPoint::new_unchecked(1.0, theta, 1.0)
            })
            .collect();
        let mapper = FixedMapper { positions };
        let items: Vec<usize> = (0..20).collect();
        let result = layout.layout(&items, &mapper);
        assert_eq!(result.entries.len(), 20);
    }

    #[test]
    fn entries_preserve_input_order() {
        let positions: Vec<SphericalPoint> = (0..12)
            .map(|i| {
                let theta = (i as f64 * 0.5) % (2.0 * PI);
                let phi = 0.3 + 0.2 * (i % 3) as f64;
                SphericalPoint::new_unchecked(1.0, theta, phi)
            })
            .collect();
        let mapper = FixedMapper { positions };
        let items: Vec<usize> = (0..12).collect();
        let result = ClusteredLayout::new().with_clusters(3).layout(&items, &mapper);
        let order: Vec<usize> = result.entries.iter().map(|e| e.item).collect();
        assert_eq!(order, items);
    }

    #[test]
    fn more_clusters_than_items_places_every_item() {
        let positions = vec![
            SphericalPoint::new_unchecked(1.0, 0.0, 0.5),
            SphericalPoint::new_unchecked(1.0, 2.0, 1.5),
            SphericalPoint::new_unchecked(1.0, 4.0, 2.5),
        ];
        let mapper = FixedMapper { positions };
        let result = ClusteredLayout::new()
            .with_clusters(10)
            .layout(&[0usize, 1, 2], &mapper);
        assert_eq!(result.entries.len(), 3);
        for entry in &result.entries {
            assert!((entry.position.r - 1.0).abs() < 1e-12);
        }
    }

    #[test]
    fn identical_points_do_not_panic() {
        let p = SphericalPoint::new_unchecked(1.0, 1.0, 1.0);
        let mapper = FixedMapper {
            positions: vec![p; 6],
        };
        let items: Vec<usize> = (0..6).collect();
        let result = ClusteredLayout::new().with_clusters(3).layout(&items, &mapper);
        assert_eq!(result.entries.len(), 6);
        for entry in &result.entries {
            assert!((entry.position.r - 1.0).abs() < 1e-12);
            assert!(entry.position.theta.is_finite());
            assert!(entry.position.phi.is_finite());
        }
        // Every point lands in one cluster, so there is nothing to separate.
        assert_eq!(result.quality.silhouette_score, 0.0);
    }

    #[test]
    fn items_in_same_cluster_are_angularly_close() {
        let mut positions = Vec::new();
        for i in 0..5 {
            positions.push(SphericalPoint::new_unchecked(1.0, 0.01 * i as f64, 0.1));
        }
        for i in 0..5 {
            positions.push(SphericalPoint::new_unchecked(
                1.0,
                0.01 * i as f64,
                PI - 0.1,
            ));
        }
        let mapper = FixedMapper { positions };
        let items: Vec<usize> = (0..10).collect();
        let layout = ClusteredLayout::new().with_clusters(2).with_spread(0.2);
        let result = layout.layout(&items, &mapper);

        let group_a: Vec<&SphericalPoint> =
            result.entries[..5].iter().map(|e| &e.position).collect();
        for i in 0..group_a.len() {
            for j in (i + 1)..group_a.len() {
                let d = angular_distance(group_a[i], group_a[j]);
                assert!(d < 1.0, "Intra-cluster distance too large: {d}");
            }
        }
    }

    #[test]
    fn different_clusters_are_angularly_separated() {
        let mut positions = Vec::new();
        for i in 0..5 {
            positions.push(SphericalPoint::new_unchecked(
                1.0,
                0.01 * i as f64,
                PI / 2.0,
            ));
        }
        for i in 0..5 {
            positions.push(SphericalPoint::new_unchecked(
                1.0,
                PI + 0.01 * i as f64,
                PI / 2.0,
            ));
        }
        let mapper = FixedMapper { positions };
        let items: Vec<usize> = (0..10).collect();
        let layout = ClusteredLayout::new().with_clusters(2).with_spread(0.2);
        let result = layout.layout(&items, &mapper);

        let p_a = &result.entries[0].position;
        let p_b = &result.entries[5].position;
        let d = angular_distance(p_a, p_b);
        assert!(d > 1.0, "Inter-cluster distance too small: {d}");
    }

    #[test]
    fn silhouette_positive_for_well_separated_data() {
        let mut positions = Vec::new();
        for i in 0..10 {
            positions.push(SphericalPoint::new_unchecked(1.0, 0.01 * i as f64, 0.2));
        }
        for i in 0..10 {
            positions.push(SphericalPoint::new_unchecked(
                1.0,
                PI + 0.01 * i as f64,
                PI - 0.2,
            ));
        }
        let mapper = FixedMapper { positions };
        let items: Vec<usize> = (0..20).collect();
        let layout = ClusteredLayout::new().with_clusters(2).with_spread(0.15);
        let result = layout.layout(&items, &mapper);
        assert!(
            result.quality.silhouette_score > 0.0,
            "Silhouette should be positive for well-separated clusters, got {}",
            result.quality.silhouette_score
        );
    }

    #[test]
    fn builder_methods_apply() {
        let layout = ClusteredLayout::new()
            .with_clusters(8)
            .with_radius(2.5)
            .with_spread(0.5);
        assert_eq!(layout.num_clusters, 8);
        assert!((layout.radius - 2.5).abs() < 1e-12);
        assert!((layout.intra_cluster_spread - 0.5).abs() < 1e-12);
    }

    #[test]
    fn output_radius_matches_configured() {
        let layout = ClusteredLayout::new().with_radius(3.0).with_clusters(2);
        let positions = vec![
            SphericalPoint::new_unchecked(1.0, 0.0, 0.5),
            SphericalPoint::new_unchecked(1.0, PI, 2.0),
        ];
        let mapper = FixedMapper { positions };
        let result = layout.layout(&[0usize, 1], &mapper);
        for entry in &result.entries {
            assert!(
                (entry.position.r - 3.0).abs() < 1e-12,
                "Expected radius 3.0, got {}",
                entry.position.r
            );
        }
    }
}