1use std::f64::consts::PI;
2
3use sphereql_core::{
4 CartesianPoint, SphericalPoint, angular_distance, cartesian_to_spherical,
5 spherical_to_cartesian,
6};
7
8use crate::quality::{MAX_QUALITY_N, OVERLAP_THRESHOLD};
9use crate::traits::{DimensionMapper, LayoutStrategy};
10use crate::types::{LayoutEntry, LayoutQuality, LayoutResult};
11
/// Smallest Cartesian separation treated as distinct: closer pairs count as coincident.
/// Also added to the squared angular distance in the repulsion denominator so that the
/// `1 / d²` force stays finite as `d` approaches zero.
const EPSILON: f64 = 0.000_001;
/// Multiplied by the current temperature to obtain each iteration's step size.
const STEP_SIZE_FACTOR: f64 = 0.1;
14
/// Force-directed layout of items on a sphere.
///
/// Each item starts at the direction produced by its [`DimensionMapper`]. On every iteration,
/// points repel each other with a force that grows as their angular distance shrinks. Each
/// point is also pulled back toward its own mapped direction by an attraction proportional to
/// how far it has drifted. The step size starts at `STEP_SIZE_FACTOR` and is multiplied by
/// `cooling_rate` after every iteration.
#[derive(Debug, Clone, PartialEq)]
pub struct ForceDirectedLayout {
    /// Number of simulation iterations to run.
    pub iterations: usize,
    /// Scale of the pairwise repulsion `repulsion_strength / (d² + ε)` for angular distance `d`.
    pub repulsion_strength: f64,
    /// Scale of the pull toward each item's mapped position (proportional to the angular
    /// distance from it).
    pub attraction_strength: f64,
    /// Factor applied to the temperature (and hence the step size) after each iteration.
    pub cooling_rate: f64,
    /// Radius of the sphere on which the final positions are placed.
    pub radius: f64,
}
22
23impl ForceDirectedLayout {
24 pub fn new() -> Self {
25 Self {
26 iterations: 100,
27 repulsion_strength: 1.0,
28 attraction_strength: 0.1,
29 cooling_rate: 0.95,
30 radius: 1.0,
31 }
32 }
33
34 pub fn with_iterations(mut self, n: usize) -> Self {
35 self.iterations = n;
36 self
37 }
38
39 pub fn with_repulsion(mut self, f: f64) -> Self {
40 self.repulsion_strength = f;
41 self
42 }
43
44 pub fn with_attraction(mut self, f: f64) -> Self {
45 self.attraction_strength = f;
46 self
47 }
48
49 pub fn with_cooling(mut self, f: f64) -> Self {
50 self.cooling_rate = f;
51 self
52 }
53
54 pub fn with_radius(mut self, r: f64) -> Self {
55 self.radius = r;
56 self
57 }
58
59 fn project_to_unit_sphere(p: &SphericalPoint) -> CartesianPoint {
60 let unit = SphericalPoint::new_unchecked(1.0, p.theta, p.phi);
61 spherical_to_cartesian(&unit)
62 }
63
64 fn compute_quality(positions: &[SphericalPoint], n: usize) -> LayoutQuality {
65 if n <= 1 {
66 return LayoutQuality {
67 dispersion_score: 1.0,
68 overlap_score: 0.0,
69 silhouette_score: 0.0,
70 };
71 }
72
73 let (positions, n) = if n > MAX_QUALITY_N {
74 let step = n / MAX_QUALITY_N;
75 let sampled: Vec<_> = positions
76 .iter()
77 .step_by(step)
78 .take(MAX_QUALITY_N)
79 .copied()
80 .collect();
81 let len = sampled.len();
82 (sampled, len)
83 } else {
84 (positions.to_vec(), n)
85 };
86
87 let ideal_spacing = (4.0 * PI / n as f64).sqrt();
88 let total_pairs = (n * (n - 1) / 2) as u64;
89
90 use rayon::prelude::*;
94 const SERIAL_THRESHOLD: usize = 128;
95 let per_i = |i: usize| -> (f64, u64) {
96 let mut min_local = f64::MAX;
97 let mut count_local = 0u64;
98 for j in (i + 1)..n {
99 let d = angular_distance(&positions[i], &positions[j]);
100 if d < min_local {
101 min_local = d;
102 }
103 if d < OVERLAP_THRESHOLD {
104 count_local += 1;
105 }
106 }
107 (min_local, count_local)
108 };
109 let (min_dist, overlap_count) = if n < SERIAL_THRESHOLD {
110 (0..n)
111 .map(per_i)
112 .fold((f64::MAX, 0u64), |(ma, ca), (mb, cb)| {
113 (if mb < ma { mb } else { ma }, ca + cb)
114 })
115 } else {
116 (0..n).into_par_iter().map(per_i).reduce(
117 || (f64::MAX, 0u64),
118 |(ma, ca), (mb, cb)| (if mb < ma { mb } else { ma }, ca + cb),
119 )
120 };
121
122 let dispersion = (min_dist / ideal_spacing).clamp(0.0, 1.0);
123 let overlap = overlap_count as f64 / total_pairs as f64;
124
125 LayoutQuality {
126 dispersion_score: dispersion,
127 overlap_score: overlap,
128 silhouette_score: 0.0,
129 }
130 }
131}
132
133impl Default for ForceDirectedLayout {
134 fn default() -> Self {
135 Self::new()
136 }
137}
138
139impl<T: Clone> LayoutStrategy<T> for ForceDirectedLayout {
140 fn layout(&self, items: &[T], mapper: &dyn DimensionMapper<Item = T>) -> LayoutResult<T> {
141 let n = items.len();
142
143 if n == 0 {
144 return LayoutResult {
145 entries: Vec::new(),
146 quality: LayoutQuality::default(),
147 };
148 }
149
150 let original_positions: Vec<SphericalPoint> =
151 items.iter().map(|item| mapper.map(item)).collect();
152
153 let original_cartesian: Vec<CartesianPoint> = original_positions
154 .iter()
155 .map(Self::project_to_unit_sphere)
156 .collect();
157
158 let mut positions: Vec<CartesianPoint> = original_cartesian.clone();
159
160 let mut temperature = 1.0;
161
162 use rayon::prelude::*;
178 const FORCE_PARALLEL_THRESHOLD: usize = 128;
179
180 for _ in 0..self.iterations {
181 let compute_force = |i: usize| -> CartesianPoint {
182 let pi = positions[i];
183 let mut fx = 0.0;
184 let mut fy = 0.0;
185 let mut fz = 0.0;
186
187 for (j, &pj) in positions.iter().enumerate() {
188 if i == j {
189 continue;
190 }
191
192 let dot = pi.x * pj.x + pi.y * pj.y + pi.z * pj.z;
193 let dist = dot.clamp(-1.0, 1.0).acos();
194
195 let dx = pi.x - pj.x;
196 let dy = pi.y - pj.y;
197 let dz = pi.z - pj.z;
198 let cart_dist = (dx * dx + dy * dy + dz * dz).sqrt();
199 if cart_dist < EPSILON {
200 continue;
201 }
202
203 let magnitude = self.repulsion_strength / (dist * dist + EPSILON);
204 let inv = magnitude / cart_dist;
205 fx += inv * dx;
206 fy += inv * dy;
207 fz += inv * dz;
208 }
209
210 let oi = original_cartesian[i];
211 let dot_oi = pi.x * oi.x + pi.y * oi.y + pi.z * oi.z;
212 let dist_to_original = dot_oi.clamp(-1.0, 1.0).acos();
213
214 let dx = oi.x - pi.x;
215 let dy = oi.y - pi.y;
216 let dz = oi.z - pi.z;
217 let cart_dist = (dx * dx + dy * dy + dz * dz).sqrt();
218 if cart_dist > EPSILON {
219 let magnitude = self.attraction_strength * dist_to_original;
220 let inv = magnitude / cart_dist;
221 fx += inv * dx;
222 fy += inv * dy;
223 fz += inv * dz;
224 }
225
226 CartesianPoint::new(fx, fy, fz)
227 };
228
229 let forces: Vec<CartesianPoint> = if n < FORCE_PARALLEL_THRESHOLD {
230 (0..n).map(compute_force).collect()
231 } else {
232 (0..n).into_par_iter().map(compute_force).collect()
233 };
234
235 let step_size = temperature * STEP_SIZE_FACTOR;
237 for i in 0..n {
238 let p = positions[i];
239 let f = forces[i];
240
241 let dot = f.x * p.x + f.y * p.y + f.z * p.z;
243 let ft = CartesianPoint::new(f.x - dot * p.x, f.y - dot * p.y, f.z - dot * p.z);
244
245 let new_pos = CartesianPoint::new(
246 p.x + step_size * ft.x,
247 p.y + step_size * ft.y,
248 p.z + step_size * ft.z,
249 );
250
251 positions[i] = new_pos.normalize();
252 }
253
254 temperature *= self.cooling_rate;
255 }
256
257 let final_positions: Vec<SphericalPoint> = positions
258 .iter()
259 .map(|c| {
260 let sp = cartesian_to_spherical(c);
261 SphericalPoint::new_unchecked(self.radius, sp.theta, sp.phi)
262 })
263 .collect();
264
265 let entries: Vec<LayoutEntry<T>> = items
266 .iter()
267 .zip(final_positions.iter())
268 .map(|(item, pos)| LayoutEntry {
269 item: item.clone(),
270 position: *pos,
271 })
272 .collect();
273
274 let quality = Self::compute_quality(&final_positions, n);
275
276 LayoutResult { entries, quality }
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::FRAC_PI_2;

    /// Mapper that returns a pre-chosen position for each item index.
    struct FixedMapper {
        positions: Vec<SphericalPoint>,
    }

    impl DimensionMapper for FixedMapper {
        type Item = usize;
        fn map(&self, item: &usize) -> SphericalPoint {
            self.positions[*item]
        }
    }

    /// Mapper that sends every item to the same point on the equator.
    struct OriginMapper;

    impl DimensionMapper for OriginMapper {
        type Item = usize;
        fn map(&self, _item: &usize) -> SphericalPoint {
            on_unit_sphere(0.0, FRAC_PI_2)
        }
    }

    /// Builds a radius-1 point with the given angles.
    fn on_unit_sphere(theta: f64, phi: f64) -> SphericalPoint {
        SphericalPoint::new_unchecked(1.0, theta, phi)
    }

    #[test]
    fn empty_items_returns_empty() {
        let nothing: Vec<usize> = Vec::new();
        let result = ForceDirectedLayout::new().layout(&nothing, &OriginMapper);
        assert!(result.entries.is_empty());
    }

    #[test]
    fn single_item_stays_near_mapper_position() {
        let target = on_unit_sphere(1.0, 1.0);
        let mapper = FixedMapper {
            positions: vec![target],
        };
        let result = ForceDirectedLayout::new()
            .with_iterations(50)
            .layout(&[0usize], &mapper);

        assert_eq!(result.entries.len(), 1);
        let dist = angular_distance(&result.entries[0].position, &target);
        assert!(
            dist < 0.1,
            "single item should stay near mapper position, but angular distance was {dist}"
        );
    }

    #[test]
    fn two_items_pushed_apart_by_repulsion() {
        let mapper = FixedMapper {
            positions: vec![on_unit_sphere(0.0, FRAC_PI_2), on_unit_sphere(0.1, FRAC_PI_2)],
        };
        let result = ForceDirectedLayout::new()
            .with_iterations(200)
            .with_repulsion(2.0)
            .with_attraction(0.01)
            .layout(&[0usize, 1], &mapper);
        assert_eq!(result.entries.len(), 2);

        let (first, second) = (&result.entries[0], &result.entries[1]);
        let dist = angular_distance(&first.position, &second.position);
        assert!(
            dist > PI * 0.5,
            "two items should be pushed far apart by repulsion, but angular distance was {dist}"
        );
    }

    #[test]
    fn all_positions_have_correct_radius() {
        let r = 3.5;
        let angles = [(0.0, FRAC_PI_2), (1.0, 1.0), (2.0, 0.5), (3.0, 2.5)];
        let mapper = FixedMapper {
            positions: angles
                .iter()
                .map(|&(theta, phi)| on_unit_sphere(theta, phi))
                .collect(),
        };
        let items: Vec<usize> = (0..angles.len()).collect();
        let result = ForceDirectedLayout::new()
            .with_radius(r)
            .layout(&items, &mapper);

        result.entries.iter().enumerate().for_each(|(i, entry)| {
            assert!(
                (entry.position.r - r).abs() < 1e-12,
                "entry {i} has radius {}, expected {r}",
                entry.position.r
            );
        });
    }

    #[test]
    fn more_iterations_produce_better_or_equal_dispersion() {
        let mapper = FixedMapper {
            positions: [0.0, 0.1, 0.2, 0.3, 0.4]
                .iter()
                .map(|&theta| on_unit_sphere(theta, FRAC_PI_2))
                .collect(),
        };
        // Identical settings except for the iteration count.
        let configured = |iterations: usize| {
            ForceDirectedLayout::new()
                .with_iterations(iterations)
                .with_repulsion(1.0)
                .with_attraction(0.01)
        };

        let items: Vec<usize> = (0..5).collect();
        let result_few = configured(5).layout(&items, &mapper);
        let result_many = configured(200).layout(&items, &mapper);

        assert!(
            result_many.quality.dispersion_score >= result_few.quality.dispersion_score - 1e-6,
            "more iterations ({}) should produce >= dispersion than fewer ({})",
            result_many.quality.dispersion_score,
            result_few.quality.dispersion_score,
        );
    }

    #[test]
    fn cooling_reduces_movement_over_time() {
        let mapper = FixedMapper {
            positions: [0.0, 0.1, 0.2]
                .iter()
                .map(|&theta| on_unit_sphere(theta, FRAC_PI_2))
                .collect(),
        };
        // Identical settings except for the cooling rate.
        let with_rate = |rate: f64| {
            ForceDirectedLayout::new()
                .with_iterations(100)
                .with_cooling(rate)
        };

        let items: Vec<usize> = (0..3).collect();
        let result_cooled = with_rate(0.5).layout(&items, &mapper);
        let result_uncooled = with_rate(1.0).layout(&items, &mapper);

        for entry in &result_cooled.entries {
            assert!(!entry.position.theta.is_nan());
            assert!(!entry.position.phi.is_nan());
        }

        // Total angular displacement of every item from its mapped position.
        let displacement = |result: &LayoutResult<usize>| -> f64 {
            mapper
                .positions
                .iter()
                .enumerate()
                .map(|(i, orig)| angular_distance(&result.entries[i].position, orig))
                .sum()
        };
        let total_dist_cooled = displacement(&result_cooled);
        let total_dist_uncooled = displacement(&result_uncooled);

        assert!(
            total_dist_uncooled >= total_dist_cooled - 1e-6,
            "uncooled ({total_dist_uncooled}) should move points at least as far as \
             aggressively cooled ({total_dist_cooled})"
        );
    }

    #[test]
    fn default_builder_matches_new() {
        let from_new = ForceDirectedLayout::new();
        let from_default = ForceDirectedLayout::default();
        assert_eq!(from_new.iterations, from_default.iterations);
        assert!((from_new.repulsion_strength - from_default.repulsion_strength).abs() < 1e-15);
        assert!((from_new.attraction_strength - from_default.attraction_strength).abs() < 1e-15);
        assert!((from_new.cooling_rate - from_default.cooling_rate).abs() < 1e-15);
        assert!((from_new.radius - from_default.radius).abs() < 1e-15);
    }
}