1use std::collections::HashMap;
16
17use sphereql_core::cosine_similarity;
18
19use crate::config::LaplacianConfig;
20
/// Summary statistics describing a labelled embedding corpus.
///
/// Instances are produced by [`CorpusFeatures::extract`] and
/// [`CorpusFeatures::extract_with_threshold`] from one category label and one
/// embedding row per item.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CorpusFeatures {
    /// Number of items (embedding rows) in the corpus.
    pub n_items: usize,
    /// Number of distinct category labels.
    pub n_categories: usize,
    /// Length of every embedding row.
    pub dim: usize,
    /// `n_items / n_categories`.
    pub mean_members_per_category: f64,
    /// Shannon entropy of the category-size distribution, divided by
    /// `ln(n_categories)`; `0.0` when there is only one category.
    pub category_size_entropy: f64,
    /// Fraction of all `(item, axis)` entries whose absolute value exceeds the
    /// active threshold. Note that, despite the name, a larger value means
    /// more active (denser) embeddings.
    pub mean_sparsity: f64,
    /// Shannon entropy of the per-axis active-entry counts, divided by
    /// `ln(dim)`; `0.0` when no entry is active or `dim <= 1`.
    pub axis_utilization_entropy: f64,
    /// Mean, over items that have at least one inactive entry, of the
    /// middle element (index `len / 2` after sorting) of that item's inactive
    /// magnitudes; `0.0` when no item has an inactive entry.
    pub noise_estimate: f64,
    /// Mean cosine similarity over unordered item pairs that share a category
    /// label; `0.0` when there are no such pairs.
    pub mean_intra_category_similarity: f64,
    /// Mean cosine similarity over unordered item pairs whose category labels
    /// differ; `0.0` when there are no such pairs.
    pub mean_inter_category_similarity: f64,
    /// `mean_intra_category_similarity` divided by
    /// `max(|mean_inter_category_similarity|, 1e-12)`.
    pub category_separation_ratio: f64,
}
60
/// Length of the arrays returned by [`CorpusFeatures::feature_names`] and
/// [`CorpusFeatures::to_vec`].
pub const CORPUS_FEATURE_COUNT: usize = 10;
63
64impl CorpusFeatures {
65 pub fn feature_names() -> [&'static str; CORPUS_FEATURE_COUNT] {
73 [
74 "n_items",
75 "n_categories",
76 "dim",
77 "mean_members_per_category",
78 "category_size_entropy",
79 "mean_sparsity",
80 "axis_utilization_entropy",
81 "noise_estimate",
82 "mean_intra_category_similarity",
83 "mean_inter_category_similarity",
84 ]
85 }
86
87 pub fn to_vec(&self) -> [f64; CORPUS_FEATURE_COUNT] {
93 [
94 self.n_items as f64,
95 self.n_categories as f64,
96 self.dim as f64,
97 self.mean_members_per_category,
98 self.category_size_entropy,
99 self.mean_sparsity,
100 self.axis_utilization_entropy,
101 self.noise_estimate,
102 self.mean_intra_category_similarity,
103 self.mean_inter_category_similarity,
104 ]
105 }
106
107 pub fn extract(categories: &[String], embeddings: &[Vec<f64>]) -> Result<Self, String> {
113 Self::extract_with_threshold(
114 categories,
115 embeddings,
116 LaplacianConfig::default().active_threshold,
117 )
118 }
119
120 pub fn extract_with_threshold(
127 categories: &[String],
128 embeddings: &[Vec<f64>],
129 active_threshold: f64,
130 ) -> Result<Self, String> {
131 if categories.len() != embeddings.len() {
132 return Err(format!(
133 "categories length {} does not match embeddings length {}",
134 categories.len(),
135 embeddings.len()
136 ));
137 }
138 let n = embeddings.len();
139 if n == 0 {
140 return Err("cannot extract features from an empty corpus".into());
141 }
142 let dim = embeddings[0].len();
143 if dim == 0 {
144 return Err("embeddings must have positive dimensionality".into());
145 }
146 for (i, e) in embeddings.iter().enumerate() {
147 if e.len() != dim {
148 return Err(format!(
149 "ragged embeddings: row {i} length {} != dim {dim}",
150 e.len()
151 ));
152 }
153 }
154
155 let mut cat_counts: HashMap<&str, usize> = HashMap::new();
157 for c in categories {
158 *cat_counts.entry(c.as_str()).or_insert(0) += 1;
159 }
160 let n_categories = cat_counts.len();
161 let mean_members_per_category = n as f64 / n_categories.max(1) as f64;
162
163 let category_size_entropy = if n_categories > 1 {
164 let h: f64 = cat_counts
165 .values()
166 .map(|&c| {
167 let p = c as f64 / n as f64;
168 if p > 0.0 { -p * p.ln() } else { 0.0 }
169 })
170 .sum();
171 h / (n_categories as f64).ln().max(f64::EPSILON)
173 } else {
174 0.0
175 };
176
177 let mut axis_usage = vec![0usize; dim];
179 let mut active_per_item = vec![0usize; n];
180 let mut noise_sum = 0.0f64;
181 let mut noise_count = 0usize;
182
183 for (i, e) in embeddings.iter().enumerate() {
184 let mut inactive_magnitudes: Vec<f64> = Vec::with_capacity(dim);
185 for (d, &v) in e.iter().enumerate() {
186 if v.abs() > active_threshold {
187 axis_usage[d] += 1;
188 active_per_item[i] += 1;
189 } else {
190 inactive_magnitudes.push(v.abs());
191 }
192 }
193 if !inactive_magnitudes.is_empty() {
194 inactive_magnitudes.sort_by(|a, b| a.total_cmp(b));
195 let median = inactive_magnitudes[inactive_magnitudes.len() / 2];
196 noise_sum += median;
197 noise_count += 1;
198 }
199 }
200
201 let mean_sparsity: f64 =
202 active_per_item.iter().map(|&a| a as f64).sum::<f64>() / (n * dim) as f64;
203
204 let axis_utilization_entropy = {
205 let total: f64 = axis_usage.iter().map(|&c| c as f64).sum();
206 if total > 0.0 && dim > 1 {
207 let h: f64 = axis_usage
208 .iter()
209 .map(|&c| {
210 let p = c as f64 / total;
211 if p > 0.0 { -p * p.ln() } else { 0.0 }
212 })
213 .sum();
214 h / (dim as f64).ln().max(f64::EPSILON)
215 } else {
216 0.0
217 }
218 };
219
220 let noise_estimate = if noise_count > 0 {
221 noise_sum / noise_count as f64
222 } else {
223 0.0
224 };
225
226 let mean_intra_category_similarity =
228 pairwise_similarity(embeddings, categories, SimilarityMode::IntraCategory);
229 let mean_inter_category_similarity =
230 pairwise_similarity(embeddings, categories, SimilarityMode::InterCategory);
231 let category_separation_ratio =
232 mean_intra_category_similarity / mean_inter_category_similarity.abs().max(1e-12);
233
234 Ok(Self {
235 n_items: n,
236 n_categories,
237 dim,
238 mean_members_per_category,
239 category_size_entropy,
240 mean_sparsity,
241 axis_utilization_entropy,
242 noise_estimate,
243 mean_intra_category_similarity,
244 mean_inter_category_similarity,
245 category_separation_ratio,
246 })
247 }
248}
249
/// Selects which item pairs `pairwise_similarity` averages over.
#[derive(Copy, Clone)]
enum SimilarityMode {
    /// Pairs whose two category labels are equal.
    IntraCategory,
    /// Pairs whose two category labels differ.
    InterCategory,
}
257
258fn pairwise_similarity(
259 embeddings: &[Vec<f64>],
260 categories: &[String],
261 mode: SimilarityMode,
262) -> f64 {
263 use rayon::prelude::*;
264
265 let n = embeddings.len();
266 if n < 2 {
267 return 0.0;
268 }
269
270 const SERIAL_THRESHOLD: usize = 256;
278 if n < SERIAL_THRESHOLD {
279 let mut sum = 0.0;
280 let mut count: usize = 0;
281 for i in 0..n {
282 for j in (i + 1)..n {
283 if pair_matches(mode, &categories[i], &categories[j]) {
284 sum += cosine_similarity(&embeddings[i], &embeddings[j])
288 .expect("corpus embeddings share fixed dimensionality");
289 count += 1;
290 }
291 }
292 }
293 return if count == 0 { 0.0 } else { sum / count as f64 };
294 }
295
296 let (sum, count) = (0..n)
297 .into_par_iter()
298 .map(|i| {
299 let mut s = 0.0;
300 let mut c = 0usize;
301 for j in (i + 1)..n {
302 if pair_matches(mode, &categories[i], &categories[j]) {
303 s += cosine_similarity(&embeddings[i], &embeddings[j])
305 .expect("corpus embeddings share fixed dimensionality");
306 c += 1;
307 }
308 }
309 (s, c)
310 })
311 .reduce(|| (0.0, 0), |(sa, ca), (sb, cb)| (sa + sb, ca + cb));
312
313 if count == 0 { 0.0 } else { sum / count as f64 }
314}
315
316#[inline]
319fn pair_matches(mode: SimilarityMode, a: &str, b: &str) -> bool {
320 let same = a == b;
321 match mode {
322 SimilarityMode::IntraCategory => same,
323 SimilarityMode::InterCategory => !same,
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into owned category labels.
    fn labels(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Two balanced categories of three items each in five dimensions:
    /// "a" is dominated by axis 0, "b" by axis 2.
    fn toy_corpus() -> (Vec<String>, Vec<Vec<f64>>) {
        let categories = labels(&["a", "a", "a", "b", "b", "b"]);
        let embeddings = vec![
            vec![1.0, 0.1, 0.0, 0.0, 0.02],
            vec![0.9, 0.15, 0.0, 0.0, 0.01],
            vec![0.95, 0.05, 0.0, 0.0, 0.03],
            vec![0.1, 0.0, 1.0, 0.0, 0.02],
            vec![0.15, 0.0, 0.9, 0.0, 0.01],
            vec![0.05, 0.0, 0.95, 0.0, 0.03],
        ];
        (categories, embeddings)
    }

    /// Features of the toy corpus under the default threshold.
    fn toy_features() -> CorpusFeatures {
        let (cats, embs) = toy_corpus();
        CorpusFeatures::extract(&cats, &embs).unwrap()
    }

    #[test]
    fn extract_rejects_empty_corpus() {
        let message = CorpusFeatures::extract(&[], &[]).unwrap_err();
        assert!(message.contains("empty corpus"));
    }

    #[test]
    fn extract_rejects_mismatched_lengths() {
        let cats = labels(&["a"]);
        let embs: Vec<Vec<f64>> = Vec::new();
        assert!(CorpusFeatures::extract(&cats, &embs).is_err());
    }

    #[test]
    fn extract_rejects_ragged_embeddings() {
        let cats = labels(&["a", "b"]);
        let embs = vec![vec![1.0, 2.0], vec![1.0]];
        let message = CorpusFeatures::extract(&cats, &embs).unwrap_err();
        assert!(message.contains("ragged"));
    }

    #[test]
    fn extract_basic_shape() {
        let cf = toy_features();
        assert_eq!(cf.n_items, 6);
        assert_eq!(cf.n_categories, 2);
        assert_eq!(cf.dim, 5);
        assert!((cf.mean_members_per_category - 3.0).abs() < 1e-12);
    }

    #[test]
    fn category_size_entropy_balanced() {
        // A 3/3 split is the maximum-entropy case for two categories.
        let cf = toy_features();
        assert!(
            (cf.category_size_entropy - 1.0).abs() < 1e-10,
            "balanced split should give entropy = 1.0, got {}",
            cf.category_size_entropy
        );
    }

    #[test]
    fn category_size_entropy_skewed() {
        // A 5/1 split.
        let cats = labels(&["a", "a", "a", "a", "a", "b"]);
        let embs = vec![vec![1.0, 0.0, 0.0]; 6];
        let cf = CorpusFeatures::extract(&cats, &embs).unwrap();
        assert!(
            cf.category_size_entropy < 0.9,
            "skewed split should give entropy < 0.9, got {}",
            cf.category_size_entropy
        );
    }

    #[test]
    fn sparsity_matches_threshold() {
        let (cats, embs) = toy_corpus();
        let cf = CorpusFeatures::extract_with_threshold(&cats, &embs, 0.05).unwrap();
        // 10 of the 30 entries are strictly above 0.05 (the 0.05 entries
        // themselves do not count), i.e. 1/3, which is inside the tolerance.
        assert!(
            (cf.mean_sparsity - 0.4).abs() < 0.11,
            "expected ~0.4, got {}",
            cf.mean_sparsity
        );
    }

    #[test]
    fn intra_higher_than_inter_for_well_separated() {
        let cf = toy_features();
        assert!(
            cf.mean_intra_category_similarity > cf.mean_inter_category_similarity,
            "expected intra > inter on well-separated corpus"
        );
        assert!(cf.category_separation_ratio > 1.0);
    }

    #[test]
    fn to_vec_length_matches_feature_names() {
        let cf = toy_features();
        let values = cf.to_vec();
        assert_eq!(values.len(), CorpusFeatures::feature_names().len());
        assert_eq!(values.len(), CORPUS_FEATURE_COUNT);
    }

    #[test]
    fn features_serialize_json_roundtrip() {
        let cf = toy_features();
        let json = serde_json::to_string(&cf).unwrap();
        let back: CorpusFeatures = serde_json::from_str(&json).unwrap();
        assert_eq!(cf.n_items, back.n_items);
        assert_eq!(cf.n_categories, back.n_categories);
        assert!(
            (cf.mean_intra_category_similarity - back.mean_intra_category_similarity).abs() < 1e-12
        );
    }

    #[test]
    fn empty_inactive_sets_produce_zero_noise() {
        // Every entry is above the threshold, so no row contributes a median.
        let cats = labels(&["a", "a"]);
        let embs = vec![vec![1.0, 1.0], vec![0.9, 0.9]];
        let cf = CorpusFeatures::extract_with_threshold(&cats, &embs, 0.05).unwrap();
        assert_eq!(cf.noise_estimate, 0.0);
    }
}