1use std::f64::consts::PI;
2
3use sphereql_core::{
4 CartesianPoint, SphericalPoint, angular_distance, cartesian_to_spherical,
5 spherical_to_cartesian,
6};
7
8use crate::traits::{DimensionMapper, LayoutStrategy};
9use crate::types::{LayoutEntry, LayoutQuality, LayoutResult};
10
11const EPSILON: f64 = 1e-6;
12const STEP_SIZE_FACTOR: f64 = 0.1;
13const OVERLAP_THRESHOLD: f64 = 0.01;
14
/// Configuration for a force-directed layout on a sphere.
///
/// Points repel one another and are pulled back toward the position their
/// mapper originally assigned them; the step size shrinks every iteration.
#[derive(Debug, Clone, PartialEq)]
pub struct ForceDirectedLayout {
    /// Number of simulation steps performed by `layout`.
    pub iterations: usize,
    /// Numerator of the pairwise repulsion magnitude, which is divided by the
    /// squared angular distance between the two points.
    pub repulsion_strength: f64,
    /// Multiplier applied to the angular distance between a point and its
    /// originally mapped position to obtain the attraction magnitude.
    pub attraction_strength: f64,
    /// Factor the step temperature is multiplied by after each iteration.
    pub cooling_rate: f64,
    /// Radius assigned to every position in the final result.
    pub radius: f64,
}
22
23impl ForceDirectedLayout {
24 pub fn new() -> Self {
25 Self {
26 iterations: 100,
27 repulsion_strength: 1.0,
28 attraction_strength: 0.1,
29 cooling_rate: 0.95,
30 radius: 1.0,
31 }
32 }
33
34 pub fn with_iterations(mut self, n: usize) -> Self {
35 self.iterations = n;
36 self
37 }
38
39 pub fn with_repulsion(mut self, f: f64) -> Self {
40 self.repulsion_strength = f;
41 self
42 }
43
44 pub fn with_attraction(mut self, f: f64) -> Self {
45 self.attraction_strength = f;
46 self
47 }
48
49 pub fn with_cooling(mut self, f: f64) -> Self {
50 self.cooling_rate = f;
51 self
52 }
53
54 pub fn with_radius(mut self, r: f64) -> Self {
55 self.radius = r;
56 self
57 }
58
59 fn project_to_unit_sphere(p: &SphericalPoint) -> CartesianPoint {
60 let unit = SphericalPoint::new_unchecked(1.0, p.theta, p.phi);
61 spherical_to_cartesian(&unit)
62 }
63
64 const MAX_QUALITY_N: usize = 5000;
65
66 fn compute_quality(positions: &[SphericalPoint], n: usize) -> LayoutQuality {
67 if n <= 1 {
68 return LayoutQuality {
69 dispersion_score: 1.0,
70 overlap_score: 0.0,
71 silhouette_score: 0.0,
72 };
73 }
74
75 let (positions, n) = if n > Self::MAX_QUALITY_N {
76 let step = n / Self::MAX_QUALITY_N;
77 let sampled: Vec<_> = positions
78 .iter()
79 .step_by(step)
80 .take(Self::MAX_QUALITY_N)
81 .copied()
82 .collect();
83 let len = sampled.len();
84 (sampled, len)
85 } else {
86 (positions.to_vec(), n)
87 };
88
89 let ideal_spacing = (4.0 * PI / n as f64).sqrt();
90 let total_pairs = (n * (n - 1) / 2) as u64;
91
92 use rayon::prelude::*;
96 const SERIAL_THRESHOLD: usize = 128;
97 let per_i = |i: usize| -> (f64, u64) {
98 let mut min_local = f64::MAX;
99 let mut count_local = 0u64;
100 for j in (i + 1)..n {
101 let d = angular_distance(&positions[i], &positions[j]);
102 if d < min_local {
103 min_local = d;
104 }
105 if d < OVERLAP_THRESHOLD {
106 count_local += 1;
107 }
108 }
109 (min_local, count_local)
110 };
111 let (min_dist, overlap_count) = if n < SERIAL_THRESHOLD {
112 (0..n)
113 .map(per_i)
114 .fold((f64::MAX, 0u64), |(ma, ca), (mb, cb)| {
115 (if mb < ma { mb } else { ma }, ca + cb)
116 })
117 } else {
118 (0..n).into_par_iter().map(per_i).reduce(
119 || (f64::MAX, 0u64),
120 |(ma, ca), (mb, cb)| (if mb < ma { mb } else { ma }, ca + cb),
121 )
122 };
123
124 let dispersion = (min_dist / ideal_spacing).clamp(0.0, 1.0);
125 let overlap = overlap_count as f64 / total_pairs as f64;
126
127 LayoutQuality {
128 dispersion_score: dispersion,
129 overlap_score: overlap,
130 silhouette_score: 0.0,
131 }
132 }
133}
134
135impl Default for ForceDirectedLayout {
136 fn default() -> Self {
137 Self::new()
138 }
139}
140
141impl<T: Clone> LayoutStrategy<T> for ForceDirectedLayout {
142 fn layout(&self, items: &[T], mapper: &dyn DimensionMapper<Item = T>) -> LayoutResult<T> {
143 let n = items.len();
144
145 if n == 0 {
146 return LayoutResult {
147 entries: Vec::new(),
148 quality: LayoutQuality::default(),
149 };
150 }
151
152 let original_positions: Vec<SphericalPoint> =
153 items.iter().map(|item| mapper.map(item)).collect();
154
155 let original_cartesian: Vec<CartesianPoint> = original_positions
156 .iter()
157 .map(Self::project_to_unit_sphere)
158 .collect();
159
160 let mut positions: Vec<CartesianPoint> = original_cartesian.clone();
161
162 let mut temperature = 1.0;
163
164 for _ in 0..self.iterations {
165 let mut forces: Vec<CartesianPoint> = vec![CartesianPoint::new(0.0, 0.0, 0.0); n];
166
167 for i in 0..n {
168 let pi = positions[i];
169
170 for (j, &pj) in positions.iter().enumerate() {
172 if i == j {
173 continue;
174 }
175
176 let sp_i = cartesian_to_spherical(&pi);
177 let sp_j = cartesian_to_spherical(&pj);
178 let dist = angular_distance(&sp_i, &sp_j);
179
180 let dx = pi.x - pj.x;
181 let dy = pi.y - pj.y;
182 let dz = pi.z - pj.z;
183
184 let cart_dist = (dx * dx + dy * dy + dz * dz).sqrt();
185 if cart_dist < EPSILON {
186 continue;
187 }
188
189 let magnitude = self.repulsion_strength / (dist * dist + EPSILON);
190
191 forces[i] = CartesianPoint::new(
192 forces[i].x + magnitude * dx / cart_dist,
193 forces[i].y + magnitude * dy / cart_dist,
194 forces[i].z + magnitude * dz / cart_dist,
195 );
196 }
197
198 let oi = original_cartesian[i];
200 let sp_i = cartesian_to_spherical(&pi);
201 let sp_oi = cartesian_to_spherical(&oi);
202 let dist_to_original = angular_distance(&sp_i, &sp_oi);
203
204 let dx = oi.x - pi.x;
205 let dy = oi.y - pi.y;
206 let dz = oi.z - pi.z;
207 let cart_dist = (dx * dx + dy * dy + dz * dz).sqrt();
208
209 if cart_dist > EPSILON {
210 let magnitude = self.attraction_strength * dist_to_original;
211 forces[i] = CartesianPoint::new(
212 forces[i].x + magnitude * dx / cart_dist,
213 forces[i].y + magnitude * dy / cart_dist,
214 forces[i].z + magnitude * dz / cart_dist,
215 );
216 }
217 }
218
219 let step_size = temperature * STEP_SIZE_FACTOR;
221 for i in 0..n {
222 let p = positions[i];
223 let f = forces[i];
224
225 let dot = f.x * p.x + f.y * p.y + f.z * p.z;
227 let ft = CartesianPoint::new(f.x - dot * p.x, f.y - dot * p.y, f.z - dot * p.z);
228
229 let new_pos = CartesianPoint::new(
230 p.x + step_size * ft.x,
231 p.y + step_size * ft.y,
232 p.z + step_size * ft.z,
233 );
234
235 positions[i] = new_pos.normalize();
236 }
237
238 temperature *= self.cooling_rate;
239 }
240
241 let final_positions: Vec<SphericalPoint> = positions
242 .iter()
243 .map(|c| {
244 let sp = cartesian_to_spherical(c);
245 SphericalPoint::new_unchecked(self.radius, sp.theta, sp.phi)
246 })
247 .collect();
248
249 let entries: Vec<LayoutEntry<T>> = items
250 .iter()
251 .zip(final_positions.iter())
252 .map(|(item, pos)| LayoutEntry {
253 item: item.clone(),
254 position: *pos,
255 })
256 .collect();
257
258 let quality = Self::compute_quality(&final_positions, n);
259
260 LayoutResult { entries, quality }
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::FRAC_PI_2;

    /// Mapper that treats each item as an index into a fixed list of positions.
    struct FixedMapper {
        positions: Vec<SphericalPoint>,
    }

    impl DimensionMapper for FixedMapper {
        type Item = usize;
        fn map(&self, item: &usize) -> SphericalPoint {
            self.positions[*item]
        }
    }

    /// Mapper that sends every item to the same point.
    struct OriginMapper;

    impl DimensionMapper for OriginMapper {
        type Item = usize;
        fn map(&self, _item: &usize) -> SphericalPoint {
            SphericalPoint::new_unchecked(1.0, 0.0, FRAC_PI_2)
        }
    }

    /// Builds a `FixedMapper` with one point per entry of `thetas`, each with
    /// radius 1.0 and phi fixed at `FRAC_PI_2`.
    fn mapper_at_thetas(thetas: &[f64]) -> FixedMapper {
        FixedMapper {
            positions: thetas
                .iter()
                .map(|&theta| SphericalPoint::new_unchecked(1.0, theta, FRAC_PI_2))
                .collect(),
        }
    }

    #[test]
    fn empty_items_returns_empty() {
        let items: Vec<usize> = Vec::new();
        let result = ForceDirectedLayout::new().layout(&items, &OriginMapper);
        assert!(result.entries.is_empty());
    }

    #[test]
    fn single_item_stays_near_mapper_position() {
        let target = SphericalPoint::new_unchecked(1.0, 1.0, 1.0);
        let mapper = FixedMapper {
            positions: vec![target],
        };

        let result = ForceDirectedLayout::new()
            .with_iterations(50)
            .layout(&[0usize], &mapper);

        assert_eq!(result.entries.len(), 1);
        let dist = angular_distance(&result.entries[0].position, &target);
        assert!(
            dist < 0.1,
            "single item should stay near mapper position, but angular distance was {dist}"
        );
    }

    #[test]
    fn two_items_pushed_apart_by_repulsion() {
        let mapper = mapper_at_thetas(&[0.0, 0.1]);

        let result = ForceDirectedLayout::new()
            .with_iterations(200)
            .with_repulsion(2.0)
            .with_attraction(0.01)
            .layout(&[0usize, 1], &mapper);
        assert_eq!(result.entries.len(), 2);

        let first = &result.entries[0].position;
        let second = &result.entries[1].position;
        let dist = angular_distance(first, second);

        assert!(
            dist > PI * 0.5,
            "two items should be pushed far apart by repulsion, but angular distance was {dist}"
        );
    }

    #[test]
    fn all_positions_have_correct_radius() {
        let r = 3.5;
        let angles = [(0.0, FRAC_PI_2), (1.0, 1.0), (2.0, 0.5), (3.0, 2.5)];
        let mapper = FixedMapper {
            positions: angles
                .iter()
                .map(|&(theta, phi)| SphericalPoint::new_unchecked(1.0, theta, phi))
                .collect(),
        };

        let result = ForceDirectedLayout::new()
            .with_radius(r)
            .layout(&[0usize, 1, 2, 3], &mapper);

        for (i, entry) in result.entries.iter().enumerate() {
            assert!(
                (entry.position.r - r).abs() < 1e-12,
                "entry {i} has radius {}, expected {r}",
                entry.position.r
            );
        }
    }

    #[test]
    fn more_iterations_produce_better_or_equal_dispersion() {
        let mapper = mapper_at_thetas(&[0.0, 0.1, 0.2, 0.3, 0.4]);
        let items: Vec<usize> = (0..5).collect();

        // Same force settings for both runs; only the iteration count differs.
        let run = |iterations: usize| {
            ForceDirectedLayout::new()
                .with_iterations(iterations)
                .with_repulsion(1.0)
                .with_attraction(0.01)
                .layout(&items, &mapper)
        };

        let result_few = run(5);
        let result_many = run(200);

        assert!(
            result_many.quality.dispersion_score >= result_few.quality.dispersion_score - 1e-6,
            "more iterations ({}) should produce >= dispersion than fewer ({})",
            result_many.quality.dispersion_score,
            result_few.quality.dispersion_score,
        );
    }

    #[test]
    fn cooling_reduces_movement_over_time() {
        let mapper = mapper_at_thetas(&[0.0, 0.1, 0.2]);
        let items: Vec<usize> = (0..3).collect();

        // Identical runs apart from the cooling factor.
        let run = |cooling: f64| {
            ForceDirectedLayout::new()
                .with_iterations(100)
                .with_cooling(cooling)
                .layout(&items, &mapper)
        };

        let result_cooled = run(0.5);
        let result_uncooled = run(1.0);

        for entry in &result_cooled.entries {
            assert!(!entry.position.theta.is_nan());
            assert!(!entry.position.phi.is_nan());
        }

        // Sum of angular distances between each final position and the
        // position the mapper originally assigned.
        let total_displacement = |result: &LayoutResult<usize>| -> f64 {
            mapper
                .positions
                .iter()
                .enumerate()
                .fold(0.0, |acc, (i, orig)| {
                    acc + angular_distance(&result.entries[i].position, orig)
                })
        };

        let total_dist_cooled = total_displacement(&result_cooled);
        let total_dist_uncooled = total_displacement(&result_uncooled);

        assert!(
            total_dist_uncooled >= total_dist_cooled - 1e-6,
            "uncooled ({total_dist_uncooled}) should move points at least as far as \
             aggressively cooled ({total_dist_cooled})"
        );
    }

    #[test]
    fn default_builder_matches_new() {
        let from_new = ForceDirectedLayout::new();
        let from_default = ForceDirectedLayout::default();

        assert_eq!(from_new.iterations, from_default.iterations);

        let float_fields = [
            (from_new.repulsion_strength, from_default.repulsion_strength),
            (from_new.attraction_strength, from_default.attraction_strength),
            (from_new.cooling_rate, from_default.cooling_rate),
            (from_new.radius, from_default.radius),
        ];
        for (lhs, rhs) in float_fields {
            assert!((lhs - rhs).abs() < 1e-15);
        }
    }
}