1use sphereql_core::{CartesianPoint, SphericalPoint, angular_distance, cartesian_to_spherical};
10
11use crate::category::CategoryLayer;
12
13#[derive(Debug, Clone)]
15pub struct DomainGroup {
16 pub member_categories: Vec<usize>,
18 pub category_names: Vec<String>,
20 pub centroid: SphericalPoint,
22 pub angular_spread: f64,
24 pub cohesion: f64,
26 pub total_items: usize,
28}
29
30pub fn detect_domain_groups(layer: &CategoryLayer, target_groups: usize) -> Vec<DomainGroup> {
37 let n = layer.summaries.len();
38 if n == 0 {
39 return Vec::new();
40 }
41 let effective = target_groups.max(1).min(n);
42 if effective < target_groups {
43 tracing::warn!(
44 requested = target_groups,
45 effective,
46 n_categories = n,
47 "num_domain_groups clamped to {} (corpus has only {} distinct categories); \
48 TuneReport.best_config.routing.num_domain_groups will show the requested value, \
49 not the realized one — compare against pipeline.domain_groups().len()",
50 effective,
51 n
52 );
53 }
54 let target_groups = effective;
55 let sq = &layer.spatial_quality;
56
57 let mut similarity = vec![vec![0.0f64; n]; n];
59 #[allow(clippy::needless_range_loop)] for i in 0..n {
61 for j in (i + 1)..n {
62 let mut s = 0.0;
63 if sq.are_voronoi_neighbors(i, j) {
64 s += 0.5;
65 }
66 let overlap = sq.intersection_area(i, j);
67 let min_cap = sq.cap_areas[i].min(sq.cap_areas[j]);
68 if min_cap > 1e-15 {
69 s += 0.5 * (overlap / min_cap).min(1.0);
70 }
71 similarity[i][j] = s;
72 similarity[j][i] = s;
73 }
74 }
75
76 let mut clusters: Vec<Vec<usize>> = (0..n).map(|i| vec![i]).collect();
78
79 while clusters.len() > target_groups {
80 let mut best_sim = f64::NEG_INFINITY;
81 let mut best_i = 0;
82 let mut best_j = 1;
83
84 for i in 0..clusters.len() {
85 for j in (i + 1)..clusters.len() {
86 let mut total = 0.0;
87 let mut count = 0usize;
88 for &ci in &clusters[i] {
89 for &cj in &clusters[j] {
90 total += similarity[ci][cj];
91 count += 1;
92 }
93 }
94 let avg = if count > 0 { total / count as f64 } else { 0.0 };
95 if avg > best_sim {
96 best_sim = avg;
97 best_i = i;
98 best_j = j;
99 }
100 }
101 }
102
103 let merged = clusters.remove(best_j);
104 clusters[best_i].extend(merged);
105 }
106
107 clusters
109 .into_iter()
110 .map(|members| build_group(layer, members))
111 .collect()
112}
113
114fn build_group(layer: &CategoryLayer, members: Vec<usize>) -> DomainGroup {
115 let category_names: Vec<String> = members
116 .iter()
117 .map(|&i| layer.summaries[i].name.clone())
118 .collect();
119
120 let total_items: usize = members
121 .iter()
122 .map(|&i| layer.summaries[i].member_count)
123 .sum();
124
125 let (mut sx, mut sy, mut sz) = (0.0, 0.0, 0.0);
127 for &i in &members {
128 let c = layer.summaries[i].centroid_position.unit_cartesian();
129 sx += c[0];
130 sy += c[1];
131 sz += c[2];
132 }
133 let mag = (sx * sx + sy * sy + sz * sz).sqrt();
134 let centroid = if mag > 1e-15 {
135 cartesian_to_spherical(&CartesianPoint::new(sx / mag, sy / mag, sz / mag))
136 } else {
137 layer.summaries[members[0]].centroid_position
138 };
139
140 let angular_spread = if members.len() > 1 {
141 members
142 .iter()
143 .map(|&i| angular_distance(&layer.summaries[i].centroid_position, ¢roid))
144 .sum::<f64>()
145 / members.len() as f64
146 } else {
147 0.0
148 };
149
150 DomainGroup {
151 member_categories: members,
152 category_names,
153 centroid,
154 angular_spread,
155 cohesion: 1.0 / (1.0 + angular_spread),
156 total_items,
157 }
158}
159
160#[cfg(test)]
161mod tests {
162 use super::*;
163 use crate::projection::{PcaProjection, Projection};
164 use crate::types::{Embedding, RadialStrategy};
165
166 fn emb(vals: &[f64]) -> Embedding {
167 Embedding::new(vals.to_vec())
168 }
169
170 fn build_layer() -> CategoryLayer {
171 let categories = vec![
172 "science".into(),
173 "science".into(),
174 "science".into(),
175 "cooking".into(),
176 "cooking".into(),
177 "cooking".into(),
178 "music".into(),
179 "music".into(),
180 "music".into(),
181 ];
182 let embeddings = vec![
183 emb(&[1.0, 0.1, 0.0, 0.05, 0.02]),
184 emb(&[0.9, 0.15, 0.05, 0.03, 0.01]),
185 emb(&[0.95, 0.05, 0.1, 0.04, 0.03]),
186 emb(&[0.1, 1.0, 0.0, 0.02, 0.05]),
187 emb(&[0.15, 0.9, 0.05, 0.03, 0.04]),
188 emb(&[0.05, 0.95, 0.1, 0.01, 0.06]),
189 emb(&[0.0, 0.1, 1.0, 0.05, 0.02]),
190 emb(&[0.05, 0.15, 0.9, 0.03, 0.01]),
191 emb(&[0.1, 0.05, 0.95, 0.04, 0.03]),
192 ];
193 let pca = PcaProjection::fit(&embeddings, RadialStrategy::Fixed(1.0)).unwrap();
194 let projected: Vec<SphericalPoint> = embeddings.iter().map(|e| pca.project(e)).collect();
195 let evr = pca.explained_variance_ratio();
196 CategoryLayer::build(&categories, &embeddings, &projected, &pca, evr)
197 }
198
199 #[test]
200 fn target_clamped_to_category_count() {
201 let layer = build_layer();
202 let groups = detect_domain_groups(&layer, 99);
203 assert_eq!(groups.len(), layer.num_categories());
204 }
205
206 #[test]
207 fn target_one_merges_everything() {
208 let layer = build_layer();
209 let groups = detect_domain_groups(&layer, 1);
210 assert_eq!(groups.len(), 1);
211 assert_eq!(groups[0].member_categories.len(), layer.num_categories());
212 }
213
214 #[test]
215 fn total_items_preserved() {
216 let layer = build_layer();
217 let groups = detect_domain_groups(&layer, 2);
218 let total_in_groups: usize = groups.iter().map(|g| g.total_items).sum();
219 let total_in_layer: usize = layer.summaries.iter().map(|s| s.member_count).sum();
220 assert_eq!(total_in_groups, total_in_layer);
221 }
222
223 #[test]
224 fn every_category_assigned_once() {
225 let layer = build_layer();
226 let groups = detect_domain_groups(&layer, 2);
227 let mut all: Vec<usize> = groups
228 .iter()
229 .flat_map(|g| g.member_categories.iter().copied())
230 .collect();
231 all.sort();
232 let before_dedup = all.len();
233 all.dedup();
234 assert_eq!(
235 before_dedup,
236 all.len(),
237 "category assigned to multiple groups"
238 );
239 assert_eq!(all.len(), layer.num_categories());
240 }
241
242 #[test]
243 fn cohesion_in_range() {
244 let layer = build_layer();
245 for g in detect_domain_groups(&layer, 2) {
246 assert!(g.cohesion > 0.0 && g.cohesion <= 1.0);
247 assert!(g.angular_spread >= 0.0);
248 }
249 }
250}