Skip to main content

sphereql_layout/
force.rs

1use std::f64::consts::PI;
2
3use sphereql_core::{
4    CartesianPoint, SphericalPoint, angular_distance, cartesian_to_spherical,
5    spherical_to_cartesian,
6};
7
8use crate::quality::{MAX_QUALITY_N, OVERLAP_THRESHOLD};
9use crate::traits::{DimensionMapper, LayoutStrategy};
10use crate::types::{LayoutEntry, LayoutQuality, LayoutResult};
11
/// Small softening / cutoff term. It is added to the squared angular distance
/// in the repulsion denominator, and chord lengths below it are treated as
/// zero (coincident or already-at-anchor points).
const EPSILON: f64 = 1.0e-6;

/// Multiplied by the current temperature to obtain each round's step size.
const STEP_SIZE_FACTOR: f64 = 0.1;
14
/// Force-directed layout on a sphere.
///
/// Items start at the positions a [`DimensionMapper`] assigns them. Every
/// point then repels every other point and is pulled back toward its mapped
/// position. The step size shrinks geometrically ("cooling") from round to
/// round.
#[derive(Debug, Clone, PartialEq)]
pub struct ForceDirectedLayout {
    /// Number of force/update rounds to run.
    pub iterations: usize,
    /// Scale of the pairwise repulsive force.
    pub repulsion_strength: f64,
    /// Scale of the pull back toward each item's mapped position.
    pub attraction_strength: f64,
    /// Factor the step-size temperature is multiplied by after each round.
    pub cooling_rate: f64,
    /// Radius of the sphere the final positions are placed on.
    pub radius: f64,
}
22
23impl ForceDirectedLayout {
24    pub fn new() -> Self {
25        Self {
26            iterations: 100,
27            repulsion_strength: 1.0,
28            attraction_strength: 0.1,
29            cooling_rate: 0.95,
30            radius: 1.0,
31        }
32    }
33
34    pub fn with_iterations(mut self, n: usize) -> Self {
35        self.iterations = n;
36        self
37    }
38
39    pub fn with_repulsion(mut self, f: f64) -> Self {
40        self.repulsion_strength = f;
41        self
42    }
43
44    pub fn with_attraction(mut self, f: f64) -> Self {
45        self.attraction_strength = f;
46        self
47    }
48
49    pub fn with_cooling(mut self, f: f64) -> Self {
50        self.cooling_rate = f;
51        self
52    }
53
54    pub fn with_radius(mut self, r: f64) -> Self {
55        self.radius = r;
56        self
57    }
58
59    fn project_to_unit_sphere(p: &SphericalPoint) -> CartesianPoint {
60        let unit = SphericalPoint::new_unchecked(1.0, p.theta, p.phi);
61        spherical_to_cartesian(&unit)
62    }
63
64    fn compute_quality(positions: &[SphericalPoint], n: usize) -> LayoutQuality {
65        if n <= 1 {
66            return LayoutQuality {
67                dispersion_score: 1.0,
68                overlap_score: 0.0,
69                silhouette_score: 0.0,
70            };
71        }
72
73        let (positions, n) = if n > MAX_QUALITY_N {
74            let step = n / MAX_QUALITY_N;
75            let sampled: Vec<_> = positions
76                .iter()
77                .step_by(step)
78                .take(MAX_QUALITY_N)
79                .copied()
80                .collect();
81            let len = sampled.len();
82            (sampled, len)
83        } else {
84            (positions.to_vec(), n)
85        };
86
87        let ideal_spacing = (4.0 * PI / n as f64).sqrt();
88        let total_pairs = (n * (n - 1) / 2) as u64;
89
90        // Parallel pair-scan: each `i`-worker tracks its own (min, count)
91        // pair and reduces at the end. f64::min takes NaN as "the other",
92        // so we route through `.total_cmp` for predictable results.
93        use rayon::prelude::*;
94        const SERIAL_THRESHOLD: usize = 128;
95        let per_i = |i: usize| -> (f64, u64) {
96            let mut min_local = f64::MAX;
97            let mut count_local = 0u64;
98            for j in (i + 1)..n {
99                let d = angular_distance(&positions[i], &positions[j]);
100                if d < min_local {
101                    min_local = d;
102                }
103                if d < OVERLAP_THRESHOLD {
104                    count_local += 1;
105                }
106            }
107            (min_local, count_local)
108        };
109        let (min_dist, overlap_count) = if n < SERIAL_THRESHOLD {
110            (0..n)
111                .map(per_i)
112                .fold((f64::MAX, 0u64), |(ma, ca), (mb, cb)| {
113                    (if mb < ma { mb } else { ma }, ca + cb)
114                })
115        } else {
116            (0..n).into_par_iter().map(per_i).reduce(
117                || (f64::MAX, 0u64),
118                |(ma, ca), (mb, cb)| (if mb < ma { mb } else { ma }, ca + cb),
119            )
120        };
121
122        let dispersion = (min_dist / ideal_spacing).clamp(0.0, 1.0);
123        let overlap = overlap_count as f64 / total_pairs as f64;
124
125        LayoutQuality {
126            dispersion_score: dispersion,
127            overlap_score: overlap,
128            silhouette_score: 0.0,
129        }
130    }
131}
132
133impl Default for ForceDirectedLayout {
134    fn default() -> Self {
135        Self::new()
136    }
137}
138
139impl<T: Clone> LayoutStrategy<T> for ForceDirectedLayout {
140    fn layout(&self, items: &[T], mapper: &dyn DimensionMapper<Item = T>) -> LayoutResult<T> {
141        let n = items.len();
142
143        if n == 0 {
144            return LayoutResult {
145                entries: Vec::new(),
146                quality: LayoutQuality::default(),
147            };
148        }
149
150        let original_positions: Vec<SphericalPoint> =
151            items.iter().map(|item| mapper.map(item)).collect();
152
153        let original_cartesian: Vec<CartesianPoint> = original_positions
154            .iter()
155            .map(Self::project_to_unit_sphere)
156            .collect();
157
158        let mut positions: Vec<CartesianPoint> = original_cartesian.clone();
159
160        let mut temperature = 1.0;
161
162        // The inner loop runs once per (point, point) pair per iteration,
163        // so this is the dominant cost of the layout. Two wins applied
164        // here:
165        //
166        // 1. Both `pi` and `pj` come out of `project_to_unit_sphere` and
167        //    are kept on the unit sphere by `.normalize()` after every
168        //    step, so `angular_distance` reduces to a single `acos` of
169        //    the dot product. The previous code paid two
170        //    `cartesian_to_spherical` calls plus a Vincenty `atan2`
171        //    inside `angular_distance` per pair, all of which can be
172        //    skipped.
173        // 2. Each `i` builds its force independently from read-only
174        //    snapshots of `positions` / `original_cartesian`, so the
175        //    outer loop is trivially parallel. Stays serial under a
176        //    small threshold to avoid thread-pool churn for tiny inputs.
177        use rayon::prelude::*;
178        const FORCE_PARALLEL_THRESHOLD: usize = 128;
179
180        for _ in 0..self.iterations {
181            let compute_force = |i: usize| -> CartesianPoint {
182                let pi = positions[i];
183                let mut fx = 0.0;
184                let mut fy = 0.0;
185                let mut fz = 0.0;
186
187                for (j, &pj) in positions.iter().enumerate() {
188                    if i == j {
189                        continue;
190                    }
191
192                    let dot = pi.x * pj.x + pi.y * pj.y + pi.z * pj.z;
193                    let dist = dot.clamp(-1.0, 1.0).acos();
194
195                    let dx = pi.x - pj.x;
196                    let dy = pi.y - pj.y;
197                    let dz = pi.z - pj.z;
198                    let cart_dist = (dx * dx + dy * dy + dz * dz).sqrt();
199                    if cart_dist < EPSILON {
200                        continue;
201                    }
202
203                    let magnitude = self.repulsion_strength / (dist * dist + EPSILON);
204                    let inv = magnitude / cart_dist;
205                    fx += inv * dx;
206                    fy += inv * dy;
207                    fz += inv * dz;
208                }
209
210                let oi = original_cartesian[i];
211                let dot_oi = pi.x * oi.x + pi.y * oi.y + pi.z * oi.z;
212                let dist_to_original = dot_oi.clamp(-1.0, 1.0).acos();
213
214                let dx = oi.x - pi.x;
215                let dy = oi.y - pi.y;
216                let dz = oi.z - pi.z;
217                let cart_dist = (dx * dx + dy * dy + dz * dz).sqrt();
218                if cart_dist > EPSILON {
219                    let magnitude = self.attraction_strength * dist_to_original;
220                    let inv = magnitude / cart_dist;
221                    fx += inv * dx;
222                    fy += inv * dy;
223                    fz += inv * dz;
224                }
225
226                CartesianPoint::new(fx, fy, fz)
227            };
228
229            let forces: Vec<CartesianPoint> = if n < FORCE_PARALLEL_THRESHOLD {
230                (0..n).map(compute_force).collect()
231            } else {
232                (0..n).into_par_iter().map(compute_force).collect()
233            };
234
235            // Apply forces: project onto tangent plane, then normalize back to sphere
236            let step_size = temperature * STEP_SIZE_FACTOR;
237            for i in 0..n {
238                let p = positions[i];
239                let f = forces[i];
240
241                // Project force onto tangent plane at p: f_tangent = f - dot(f, p) * p
242                let dot = f.x * p.x + f.y * p.y + f.z * p.z;
243                let ft = CartesianPoint::new(f.x - dot * p.x, f.y - dot * p.y, f.z - dot * p.z);
244
245                let new_pos = CartesianPoint::new(
246                    p.x + step_size * ft.x,
247                    p.y + step_size * ft.y,
248                    p.z + step_size * ft.z,
249                );
250
251                positions[i] = new_pos.normalize();
252            }
253
254            temperature *= self.cooling_rate;
255        }
256
257        let final_positions: Vec<SphericalPoint> = positions
258            .iter()
259            .map(|c| {
260                let sp = cartesian_to_spherical(c);
261                SphericalPoint::new_unchecked(self.radius, sp.theta, sp.phi)
262            })
263            .collect();
264
265        let entries: Vec<LayoutEntry<T>> = items
266            .iter()
267            .zip(final_positions.iter())
268            .map(|(item, pos)| LayoutEntry {
269                item: item.clone(),
270                position: *pos,
271            })
272            .collect();
273
274        let quality = Self::compute_quality(&final_positions, n);
275
276        LayoutResult { entries, quality }
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::FRAC_PI_2;

    /// Mapper that returns a pre-baked position for each item index.
    struct FixedMapper {
        positions: Vec<SphericalPoint>,
    }

    impl DimensionMapper for FixedMapper {
        type Item = usize;

        fn map(&self, item: &usize) -> SphericalPoint {
            self.positions[*item]
        }
    }

    /// Mapper that sends every item to the same point on the equator.
    struct OriginMapper;

    impl DimensionMapper for OriginMapper {
        type Item = usize;

        fn map(&self, _item: &usize) -> SphericalPoint {
            on_equator(0.0)
        }
    }

    /// Unit-radius point on the equator (`phi = π/2`) at azimuth `theta`.
    fn on_equator(theta: f64) -> SphericalPoint {
        SphericalPoint::new_unchecked(1.0, theta, FRAC_PI_2)
    }

    /// A `FixedMapper` whose items sit on the equator at the given azimuths.
    fn equator_mapper(thetas: &[f64]) -> FixedMapper {
        FixedMapper {
            positions: thetas.iter().copied().map(on_equator).collect(),
        }
    }

    /// Sum of angular distances between each entry and its origin, in order.
    fn total_displacement(entries: &[LayoutEntry<usize>], origins: &[SphericalPoint]) -> f64 {
        origins
            .iter()
            .zip(entries)
            .fold(0.0, |acc, (orig, entry)| {
                acc + angular_distance(&entry.position, orig)
            })
    }

    #[test]
    fn empty_items_returns_empty() {
        let items: Vec<usize> = Vec::new();
        let result = ForceDirectedLayout::new().layout(&items, &OriginMapper);
        assert!(result.entries.is_empty());
    }

    #[test]
    fn single_item_stays_near_mapper_position() {
        let target = SphericalPoint::new_unchecked(1.0, 1.0, 1.0);
        let mapper = FixedMapper {
            positions: vec![target],
        };

        let result = ForceDirectedLayout::new()
            .with_iterations(50)
            .layout(&[0usize], &mapper);
        assert_eq!(result.entries.len(), 1);

        let dist = angular_distance(&result.entries[0].position, &target);
        assert!(
            dist < 0.1,
            "single item should stay near mapper position, but angular distance was {dist}"
        );
    }

    #[test]
    fn two_items_pushed_apart_by_repulsion() {
        let mapper = equator_mapper(&[0.0, 0.1]);

        let result = ForceDirectedLayout::new()
            .with_iterations(200)
            .with_repulsion(2.0)
            .with_attraction(0.01)
            .layout(&[0usize, 1], &mapper);
        assert_eq!(result.entries.len(), 2);

        let (first, second) = (&result.entries[0].position, &result.entries[1].position);
        let dist = angular_distance(first, second);
        assert!(
            dist > PI * 0.5,
            "two items should be pushed far apart by repulsion, but angular distance was {dist}"
        );
    }

    #[test]
    fn all_positions_have_correct_radius() {
        let r = 3.5;
        let mapper = FixedMapper {
            positions: vec![
                on_equator(0.0),
                SphericalPoint::new_unchecked(1.0, 1.0, 1.0),
                SphericalPoint::new_unchecked(1.0, 2.0, 0.5),
                SphericalPoint::new_unchecked(1.0, 3.0, 2.5),
            ],
        };

        let result = ForceDirectedLayout::new()
            .with_radius(r)
            .layout(&[0usize, 1, 2, 3], &mapper);

        for (i, entry) in result.entries.iter().enumerate() {
            assert!(
                (entry.position.r - r).abs() < 1e-12,
                "entry {i} has radius {}, expected {r}",
                entry.position.r
            );
        }
    }

    #[test]
    fn more_iterations_produce_better_or_equal_dispersion() {
        let mapper = equator_mapper(&[0.0, 0.1, 0.2, 0.3, 0.4]);
        let configured = |iterations: usize| {
            ForceDirectedLayout::new()
                .with_iterations(iterations)
                .with_repulsion(1.0)
                .with_attraction(0.01)
        };
        let items: Vec<usize> = (0..5).collect();

        let result_few = configured(5).layout(&items, &mapper);
        let result_many = configured(200).layout(&items, &mapper);

        assert!(
            result_many.quality.dispersion_score >= result_few.quality.dispersion_score - 1e-6,
            "more iterations ({}) should produce >= dispersion than fewer ({})",
            result_many.quality.dispersion_score,
            result_few.quality.dispersion_score,
        );
    }

    #[test]
    fn cooling_reduces_movement_over_time() {
        let mapper = equator_mapper(&[0.0, 0.1, 0.2]);
        let items: Vec<usize> = (0..3).collect();

        let result_cooled = ForceDirectedLayout::new()
            .with_iterations(100)
            .with_cooling(0.5)
            .layout(&items, &mapper);
        let result_uncooled = ForceDirectedLayout::new()
            .with_iterations(100)
            .with_cooling(1.0)
            .layout(&items, &mapper);

        for entry in &result_cooled.entries {
            assert!(!entry.position.theta.is_nan());
            assert!(!entry.position.phi.is_nan());
        }

        let total_dist_cooled = total_displacement(&result_cooled.entries, &mapper.positions);
        let total_dist_uncooled = total_displacement(&result_uncooled.entries, &mapper.positions);

        assert!(
            total_dist_uncooled >= total_dist_cooled - 1e-6,
            "uncooled ({total_dist_uncooled}) should move points at least as far as \
             aggressively cooled ({total_dist_cooled})"
        );
    }

    #[test]
    fn default_builder_matches_new() {
        let from_new = ForceDirectedLayout::new();
        let from_default: ForceDirectedLayout = Default::default();
        assert_eq!(from_new.iterations, from_default.iterations);
        assert!((from_new.repulsion_strength - from_default.repulsion_strength).abs() < 1e-15);
        assert!((from_new.attraction_strength - from_default.attraction_strength).abs() < 1e-15);
        assert!((from_new.cooling_rate - from_default.cooling_rate).abs() < 1e-15);
        assert!((from_new.radius - from_default.radius).abs() < 1e-15);
    }
}
463}