Skip to main content

sphereql_embed/
domain_groups.rs

//! Hierarchical domain groups: coarse routing for noisy projections.
//!
//! When the explained variance ratio (EVR) is low (< 0.35), routing through
//! N individual categories on a distorted outer sphere is unreliable — random
//! angular noise drowns the signal. Collapsing the N categories into a handful
//! of super-groups (derived from Voronoi adjacency + cap overlap) reduces the
//! routing problem's dimensionality and restores usable coarse structure.
8
9use sphereql_core::{CartesianPoint, SphericalPoint, angular_distance, cartesian_to_spherical};
10
11use crate::category::CategoryLayer;
12
13/// A cluster of related categories detected from sphere geometry.
14#[derive(Debug, Clone)]
15pub struct DomainGroup {
16    /// Indices of member categories in the [`CategoryLayer`] summaries vec.
17    pub member_categories: Vec<usize>,
18    /// Category names for convenience.
19    pub category_names: Vec<String>,
20    /// Centroid of the group on S² (mean of member centroids).
21    pub centroid: SphericalPoint,
22    /// Angular spread of the group (mean distance of members from group centroid).
23    pub angular_spread: f64,
24    /// Cohesion: `1 / (1 + angular_spread)`.
25    pub cohesion: f64,
26    /// Total items across all member categories.
27    pub total_items: usize,
28}
29
30/// Detect up to `target_groups` domain groups from the category layer.
31///
32/// Greedy agglomerative clustering over a Voronoi-adjacency + cap-overlap
33/// similarity matrix. Pairs of spatially adjacent or heavily overlapping
34/// categories are merged first; the merge stops when `target_groups`
35/// clusters remain (or earlier if fewer categories exist).
36pub fn detect_domain_groups(layer: &CategoryLayer, target_groups: usize) -> Vec<DomainGroup> {
37    let n = layer.summaries.len();
38    if n == 0 {
39        return Vec::new();
40    }
41    let effective = target_groups.max(1).min(n);
42    if effective < target_groups {
43        tracing::warn!(
44            requested = target_groups,
45            effective,
46            n_categories = n,
47            "num_domain_groups clamped to {} (corpus has only {} distinct categories); \
48             TuneReport.best_config.routing.num_domain_groups will show the requested value, \
49             not the realized one — compare against pipeline.domain_groups().len()",
50            effective,
51            n
52        );
53    }
54    let target_groups = effective;
55    let sq = &layer.spatial_quality;
56
57    // 1. Similarity matrix from Voronoi adjacency + normalized cap overlap.
58    let mut similarity = vec![vec![0.0f64; n]; n];
59    #[allow(clippy::needless_range_loop)] // symmetric 2D fill needs both indices
60    for i in 0..n {
61        for j in (i + 1)..n {
62            let mut s = 0.0;
63            if sq.are_voronoi_neighbors(i, j) {
64                s += 0.5;
65            }
66            let overlap = sq.intersection_area(i, j);
67            let min_cap = sq.cap_areas[i].min(sq.cap_areas[j]);
68            if min_cap > 1e-15 {
69                s += 0.5 * (overlap / min_cap).min(1.0);
70            }
71            similarity[i][j] = s;
72            similarity[j][i] = s;
73        }
74    }
75
76    // 2. Greedy agglomerative clustering (average linkage).
77    let mut clusters: Vec<Vec<usize>> = (0..n).map(|i| vec![i]).collect();
78
79    while clusters.len() > target_groups {
80        let mut best_sim = f64::NEG_INFINITY;
81        let mut best_i = 0;
82        let mut best_j = 1;
83
84        for i in 0..clusters.len() {
85            for j in (i + 1)..clusters.len() {
86                let mut total = 0.0;
87                let mut count = 0usize;
88                for &ci in &clusters[i] {
89                    for &cj in &clusters[j] {
90                        total += similarity[ci][cj];
91                        count += 1;
92                    }
93                }
94                let avg = if count > 0 { total / count as f64 } else { 0.0 };
95                if avg > best_sim {
96                    best_sim = avg;
97                    best_i = i;
98                    best_j = j;
99                }
100            }
101        }
102
103        let merged = clusters.remove(best_j);
104        clusters[best_i].extend(merged);
105    }
106
107    // 3. Build DomainGroup records.
108    clusters
109        .into_iter()
110        .map(|members| build_group(layer, members))
111        .collect()
112}
113
114fn build_group(layer: &CategoryLayer, members: Vec<usize>) -> DomainGroup {
115    let category_names: Vec<String> = members
116        .iter()
117        .map(|&i| layer.summaries[i].name.clone())
118        .collect();
119
120    let total_items: usize = members
121        .iter()
122        .map(|&i| layer.summaries[i].member_count)
123        .sum();
124
125    // Group centroid: normalized mean of member unit vectors, then back to spherical.
126    let (mut sx, mut sy, mut sz) = (0.0, 0.0, 0.0);
127    for &i in &members {
128        let c = layer.summaries[i].centroid_position.unit_cartesian();
129        sx += c[0];
130        sy += c[1];
131        sz += c[2];
132    }
133    let mag = (sx * sx + sy * sy + sz * sz).sqrt();
134    let centroid = if mag > 1e-15 {
135        cartesian_to_spherical(&CartesianPoint::new(sx / mag, sy / mag, sz / mag))
136    } else {
137        layer.summaries[members[0]].centroid_position
138    };
139
140    let angular_spread = if members.len() > 1 {
141        members
142            .iter()
143            .map(|&i| angular_distance(&layer.summaries[i].centroid_position, &centroid))
144            .sum::<f64>()
145            / members.len() as f64
146    } else {
147        0.0
148    };
149
150    DomainGroup {
151        member_categories: members,
152        category_names,
153        centroid,
154        angular_spread,
155        cohesion: 1.0 / (1.0 + angular_spread),
156        total_items,
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use crate::projection::{PcaProjection, Projection};
    use crate::types::{Embedding, RadialStrategy};

    /// Fixture rows: a category label paired with its raw 5-dimensional embedding.
    /// Three categories of three items each, every category dominated by a
    /// different one of the first three axes.
    const FIXTURE: [(&str, [f64; 5]); 9] = [
        ("science", [1.0, 0.1, 0.0, 0.05, 0.02]),
        ("science", [0.9, 0.15, 0.05, 0.03, 0.01]),
        ("science", [0.95, 0.05, 0.1, 0.04, 0.03]),
        ("cooking", [0.1, 1.0, 0.0, 0.02, 0.05]),
        ("cooking", [0.15, 0.9, 0.05, 0.03, 0.04]),
        ("cooking", [0.05, 0.95, 0.1, 0.01, 0.06]),
        ("music", [0.0, 0.1, 1.0, 0.05, 0.02]),
        ("music", [0.05, 0.15, 0.9, 0.03, 0.01]),
        ("music", [0.1, 0.05, 0.95, 0.04, 0.03]),
    ];

    /// Wrap a slice of raw values in an [`Embedding`].
    fn emb(vals: &[f64]) -> Embedding {
        Embedding::new(vals.to_vec())
    }

    /// Fit a PCA projection on the fixture and build the category layer from it.
    fn build_layer() -> CategoryLayer {
        let categories: Vec<_> = FIXTURE.iter().map(|(label, _)| (*label).into()).collect();
        let embeddings: Vec<Embedding> = FIXTURE.iter().map(|(_, values)| emb(values)).collect();

        let pca = PcaProjection::fit(&embeddings, RadialStrategy::Fixed(1.0)).unwrap();
        let projected: Vec<SphericalPoint> = embeddings
            .iter()
            .map(|embedding| pca.project(embedding))
            .collect();
        let evr = pca.explained_variance_ratio();
        CategoryLayer::build(&categories, &embeddings, &projected, &pca, evr)
    }

    /// Asking for more groups than categories yields one group per category.
    #[test]
    fn target_clamped_to_category_count() {
        let layer = build_layer();
        let group_count = detect_domain_groups(&layer, 99).len();
        assert_eq!(group_count, layer.num_categories());
    }

    /// A target of one collapses every category into a single group.
    #[test]
    fn target_one_merges_everything() {
        let layer = build_layer();
        let groups = detect_domain_groups(&layer, 1);
        assert_eq!(groups.len(), 1);
        assert_eq!(groups[0].member_categories.len(), layer.num_categories());
    }

    /// Grouping neither loses nor duplicates items.
    #[test]
    fn total_items_preserved() {
        let layer = build_layer();
        let total_in_groups: usize = detect_domain_groups(&layer, 2)
            .iter()
            .map(|group| group.total_items)
            .sum();
        let total_in_layer: usize = layer
            .summaries
            .iter()
            .map(|summary| summary.member_count)
            .sum();
        assert_eq!(total_in_groups, total_in_layer);
    }

    /// Each category index appears in exactly one group.
    #[test]
    fn every_category_assigned_once() {
        let layer = build_layer();
        let mut assigned: Vec<usize> = Vec::new();
        for group in detect_domain_groups(&layer, 2) {
            assigned.extend(group.member_categories);
        }
        assigned.sort_unstable();
        let count_with_duplicates = assigned.len();
        assigned.dedup();
        assert_eq!(
            count_with_duplicates,
            assigned.len(),
            "category assigned to multiple groups"
        );
        assert_eq!(assigned.len(), layer.num_categories());
    }

    /// Cohesion stays within (0, 1] and the spread is never negative.
    #[test]
    fn cohesion_in_range() {
        let layer = build_layer();
        let groups = detect_domain_groups(&layer, 2);
        for group in &groups {
            assert!(group.cohesion > 0.0 && group.cohesion <= 1.0);
            assert!(group.angular_spread >= 0.0);
        }
    }
}