Skip to main content

sphereql_embed/
corpus_features.rs

1//! Low-dimensional profile of a corpus — features that meta-learning can
2//! map to an optimal [`PipelineConfig`](crate::config::PipelineConfig).
3//!
4//! Meta-learning across corpora needs a stable, compact characterization
5//! of "what kind of data is this?" so that past (corpus, best_config)
6//! pairs can be indexed and retrieved for prediction on new corpora.
7//! [`CorpusFeatures`] is that characterization. Extraction is O(N² · d)
8//! for pairwise-similarity features and O(N · d) for shape features;
9//! ~100ms at N = 775, d = 128.
10//!
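//! Every field is either a normalized scalar in `[0, 1]` (entropies, sparsity),
//! a mean cosine similarity in `[-1, 1]`, or an unbounded value (counts,
//! `mean_members_per_category`, the non-negative `noise_estimate`, and the signed
//! `category_separation_ratio`). [`CorpusFeatures::to_vec`] flattens to a
//! fixed-length `[f64; CORPUS_FEATURE_COUNT]` array in a stable order matching
//! [`CorpusFeatures::feature_names`].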
14
15use std::collections::HashMap;
16
17use sphereql_core::cosine_similarity;
18
19use crate::config::LaplacianConfig;
20
21/// Low-dimensional profile of a corpus. Computed once per corpus; fed
22/// into any [`MetaModel`](crate::meta_model::MetaModel) to predict the
23/// pipeline config that's likely to work best on it.
24#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
25pub struct CorpusFeatures {
26    /// Total item count.
27    pub n_items: usize,
28    /// Unique category count.
29    pub n_categories: usize,
30    /// Embedding dimensionality.
31    pub dim: usize,
32    /// `n_items / n_categories`.
33    pub mean_members_per_category: f64,
34    /// Shannon entropy of the category-size distribution, normalized to
35    /// `[0, 1]` by dividing by `log(n_categories)`. High = balanced
36    /// category sizes; low = heavily skewed.
37    pub category_size_entropy: f64,
38    /// Mean per-item active-axis fraction: `|active axes| / dim`, averaged
39    /// over items. An axis is active when `|v_i| > active_threshold`.
40    pub mean_sparsity: f64,
41    /// Entropy of how often each axis is active across the corpus,
42    /// normalized by `log(dim)`. High = all axes used similarly; low =
43    /// a few axes dominate.
44    pub axis_utilization_entropy: f64,
45    /// Median of `|v_i|` across inactive entries (|v_i| ≤ threshold),
46    /// averaged across items. A proxy for the noise floor.
47    pub noise_estimate: f64,
48    /// Mean intra-category cosine similarity in embedding space. High =
49    /// items within a category are tightly clustered.
50    pub mean_intra_category_similarity: f64,
51    /// Mean inter-category cosine similarity in embedding space. High =
52    /// categories overlap heavily; low = categories are semantically
53    /// distinct.
54    pub mean_inter_category_similarity: f64,
55    /// `mean_intra / max(mean_inter, eps)`. A ratio-based separation
56    /// signal: values > 1 mean categories separate well in embedding
57    /// space; values near 1 mean the corpus is difficult to partition.
58    pub category_separation_ratio: f64,
59}
60
61/// Length of the vector returned by [`CorpusFeatures::to_vec`].
62pub const CORPUS_FEATURE_COUNT: usize = 10;
63
64impl CorpusFeatures {
65    /// Stable feature names aligned with [`Self::to_vec`]. Useful for
66    /// logging, feature importance reports, and CSV headers.
67    ///
68    /// Note: `category_separation_ratio` is deliberately excluded — it's
69    /// a derived ratio of two features already named here, so including
70    /// it would double-count under any distance metric. See
71    /// [`Self::to_vec`].
72    pub fn feature_names() -> [&'static str; CORPUS_FEATURE_COUNT] {
73        [
74            "n_items",
75            "n_categories",
76            "dim",
77            "mean_members_per_category",
78            "category_size_entropy",
79            "mean_sparsity",
80            "axis_utilization_entropy",
81            "noise_estimate",
82            "mean_intra_category_similarity",
83            "mean_inter_category_similarity",
84        ]
85    }
86
87    /// Fixed-length flattened representation in the order declared by
88    /// [`Self::feature_names`]. Suitable as input to any nearest-neighbor
89    /// or regression meta-model. `category_separation_ratio` is
90    /// intentionally *excluded* because it's a derived ratio of two
91    /// features already in the vector — keeping it in would double-count.
92    pub fn to_vec(&self) -> [f64; CORPUS_FEATURE_COUNT] {
93        [
94            self.n_items as f64,
95            self.n_categories as f64,
96            self.dim as f64,
97            self.mean_members_per_category,
98            self.category_size_entropy,
99            self.mean_sparsity,
100            self.axis_utilization_entropy,
101            self.noise_estimate,
102            self.mean_intra_category_similarity,
103            self.mean_inter_category_similarity,
104        ]
105    }
106
107    /// Extract features from a corpus using default Laplacian config
108    /// (for the `active_threshold` used in sparsity/noise estimation).
109    ///
110    /// Returns an error if the inputs are invalid (empty corpus, mismatched
111    /// lengths, zero-dim embeddings, or ragged rows).
112    pub fn extract(categories: &[String], embeddings: &[Vec<f64>]) -> Result<Self, String> {
113        Self::extract_with_threshold(
114            categories,
115            embeddings,
116            LaplacianConfig::default().active_threshold,
117        )
118    }
119
120    /// Extract features with an explicit active-axis threshold. Use this
121    /// when you want feature values comparable across different Laplacian
122    /// configurations.
123    ///
124    /// Returns an error if the inputs are invalid (empty corpus, mismatched
125    /// lengths, zero-dim embeddings, or ragged rows).
126    pub fn extract_with_threshold(
127        categories: &[String],
128        embeddings: &[Vec<f64>],
129        active_threshold: f64,
130    ) -> Result<Self, String> {
131        if categories.len() != embeddings.len() {
132            return Err(format!(
133                "categories length {} does not match embeddings length {}",
134                categories.len(),
135                embeddings.len()
136            ));
137        }
138        let n = embeddings.len();
139        if n == 0 {
140            return Err("cannot extract features from an empty corpus".into());
141        }
142        let dim = embeddings[0].len();
143        if dim == 0 {
144            return Err("embeddings must have positive dimensionality".into());
145        }
146        for (i, e) in embeddings.iter().enumerate() {
147            if e.len() != dim {
148                return Err(format!(
149                    "ragged embeddings: row {i} length {} != dim {dim}",
150                    e.len()
151                ));
152            }
153        }
154
155        // 1. Category bookkeeping.
156        let mut cat_counts: HashMap<&str, usize> = HashMap::new();
157        for c in categories {
158            *cat_counts.entry(c.as_str()).or_insert(0) += 1;
159        }
160        let n_categories = cat_counts.len();
161        let mean_members_per_category = n as f64 / n_categories.max(1) as f64;
162
163        let category_size_entropy = if n_categories > 1 {
164            let h: f64 = cat_counts
165                .values()
166                .map(|&c| {
167                    let p = c as f64 / n as f64;
168                    if p > 0.0 { -p * p.ln() } else { 0.0 }
169                })
170                .sum();
171            // Normalize by ln(n_categories), the maximum for n_categories bins.
172            h / (n_categories as f64).ln().max(f64::EPSILON)
173        } else {
174            0.0
175        };
176
177        // 2. Per-axis usage + per-item active counts + noise estimate.
178        let mut axis_usage = vec![0usize; dim];
179        let mut active_per_item = vec![0usize; n];
180        let mut noise_sum = 0.0f64;
181        let mut noise_count = 0usize;
182
183        for (i, e) in embeddings.iter().enumerate() {
184            let mut inactive_magnitudes: Vec<f64> = Vec::with_capacity(dim);
185            for (d, &v) in e.iter().enumerate() {
186                if v.abs() > active_threshold {
187                    axis_usage[d] += 1;
188                    active_per_item[i] += 1;
189                } else {
190                    inactive_magnitudes.push(v.abs());
191                }
192            }
193            if !inactive_magnitudes.is_empty() {
194                inactive_magnitudes.sort_by(|a, b| a.total_cmp(b));
195                let median = inactive_magnitudes[inactive_magnitudes.len() / 2];
196                noise_sum += median;
197                noise_count += 1;
198            }
199        }
200
201        let mean_sparsity: f64 =
202            active_per_item.iter().map(|&a| a as f64).sum::<f64>() / (n * dim) as f64;
203
204        let axis_utilization_entropy = {
205            let total: f64 = axis_usage.iter().map(|&c| c as f64).sum();
206            if total > 0.0 && dim > 1 {
207                let h: f64 = axis_usage
208                    .iter()
209                    .map(|&c| {
210                        let p = c as f64 / total;
211                        if p > 0.0 { -p * p.ln() } else { 0.0 }
212                    })
213                    .sum();
214                h / (dim as f64).ln().max(f64::EPSILON)
215            } else {
216                0.0
217            }
218        };
219
220        let noise_estimate = if noise_count > 0 {
221            noise_sum / noise_count as f64
222        } else {
223            0.0
224        };
225
226        // 3. Pairwise intra/inter category similarity.
227        let mean_intra_category_similarity =
228            pairwise_similarity(embeddings, categories, SimilarityMode::IntraCategory);
229        let mean_inter_category_similarity =
230            pairwise_similarity(embeddings, categories, SimilarityMode::InterCategory);
231        let category_separation_ratio =
232            mean_intra_category_similarity / mean_inter_category_similarity.abs().max(1e-12);
233
234        Ok(Self {
235            n_items: n,
236            n_categories,
237            dim,
238            mean_members_per_category,
239            category_size_entropy,
240            mean_sparsity,
241            axis_utilization_entropy,
242            noise_estimate,
243            mean_intra_category_similarity,
244            mean_inter_category_similarity,
245            category_separation_ratio,
246        })
247    }
248}
249
250// ── Similarity helpers ─────────────────────────────────────────────────
251
/// Selects which unordered item pairs `pairwise_similarity` averages over.
#[derive(Clone, Copy)]
enum SimilarityMode {
    /// Pairs whose category labels are equal.
    IntraCategory,
    /// Pairs whose category labels differ.
    InterCategory,
}
257
258fn pairwise_similarity(
259    embeddings: &[Vec<f64>],
260    categories: &[String],
261    mode: SimilarityMode,
262) -> f64 {
263    use rayon::prelude::*;
264
265    let n = embeddings.len();
266    if n < 2 {
267        return 0.0;
268    }
269
270    // Outer loop parallelized over `i`; each thread scans `j > i`
271    // serially into a `(sum, count)` pair, reduced at the end.
272    // `cosine_similarity` is pure, embeddings are read-only — trivial
273    // rayon pattern. At N = 10k, d = 128 this goes from ~13B serial
274    // ops to ~13B ÷ thread-count, landing in the low seconds instead
275    // of tens of seconds. Below a small threshold we stay serial to
276    // avoid the thread-pool startup cost dominating.
277    const SERIAL_THRESHOLD: usize = 256;
278    if n < SERIAL_THRESHOLD {
279        let mut sum = 0.0;
280        let mut count: usize = 0;
281        for i in 0..n {
282            for j in (i + 1)..n {
283                if pair_matches(mode, &categories[i], &categories[j]) {
284                    // Invariant: extract_with_threshold validates that all rows
285                    // share the same dim before reaching here, so cosine_similarity
286                    // cannot fail on a dimension mismatch.
287                    sum += cosine_similarity(&embeddings[i], &embeddings[j])
288                        .expect("corpus embeddings share fixed dimensionality");
289                    count += 1;
290                }
291            }
292        }
293        return if count == 0 { 0.0 } else { sum / count as f64 };
294    }
295
296    let (sum, count) = (0..n)
297        .into_par_iter()
298        .map(|i| {
299            let mut s = 0.0;
300            let mut c = 0usize;
301            for j in (i + 1)..n {
302                if pair_matches(mode, &categories[i], &categories[j]) {
303                    // Same invariant as the serial branch above.
304                    s += cosine_similarity(&embeddings[i], &embeddings[j])
305                        .expect("corpus embeddings share fixed dimensionality");
306                    c += 1;
307                }
308            }
309            (s, c)
310        })
311        .reduce(|| (0.0, 0), |(sa, ca), (sb, cb)| (sa + sb, ca + cb));
312
313    if count == 0 { 0.0 } else { sum / count as f64 }
314}
315
316/// Predicate factored so the serial and parallel branches stay
317/// identical. Inlined by the optimizer; no cost at the call site.
318#[inline]
319fn pair_matches(mode: SimilarityMode, a: &str, b: &str) -> bool {
320    let same = a == b;
321    match mode {
322        SimilarityMode::IntraCategory => same,
323        SimilarityMode::InterCategory => !same,
324    }
325}
326
327// ── Tests ──────────────────────────────────────────────────────────────
328
#[cfg(test)]
mod tests {
    use super::*;

    /// Two well-separated categories of three 5-dim items each. Rows 3 and
    /// 6 carry an entry of exactly 0.05, which sits *on* (not above) a 0.05
    /// active threshold.
    fn toy_corpus() -> (Vec<String>, Vec<Vec<f64>>) {
        let categories: Vec<String> = vec![
            "a".into(),
            "a".into(),
            "a".into(),
            "b".into(),
            "b".into(),
            "b".into(),
        ];
        let embeddings = vec![
            vec![1.0, 0.1, 0.0, 0.0, 0.02],
            vec![0.9, 0.15, 0.0, 0.0, 0.01],
            vec![0.95, 0.05, 0.0, 0.0, 0.03],
            vec![0.1, 0.0, 1.0, 0.0, 0.02],
            vec![0.15, 0.0, 0.9, 0.0, 0.01],
            vec![0.05, 0.0, 0.95, 0.0, 0.03],
        ];
        (categories, embeddings)
    }

    #[test]
    fn extract_rejects_empty_corpus() {
        let result = CorpusFeatures::extract(&[], &[]);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("empty corpus"));
    }

    #[test]
    fn extract_rejects_mismatched_lengths() {
        let cats = vec!["a".to_string()];
        let embs: Vec<Vec<f64>> = vec![];
        let result = CorpusFeatures::extract(&cats, &embs);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("does not match"));
    }

    #[test]
    fn extract_rejects_zero_dim_embeddings() {
        let cats = vec!["a".to_string()];
        let embs: Vec<Vec<f64>> = vec![vec![]];
        let result = CorpusFeatures::extract(&cats, &embs);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("positive dimensionality"));
    }

    #[test]
    fn extract_rejects_ragged_embeddings() {
        let cats = vec!["a".to_string(), "b".to_string()];
        let embs = vec![vec![1.0, 2.0], vec![1.0]];
        let result = CorpusFeatures::extract(&cats, &embs);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("ragged"));
    }

    #[test]
    fn extract_basic_shape() {
        let (cats, embs) = toy_corpus();
        let cf = CorpusFeatures::extract(&cats, &embs).unwrap();
        assert_eq!(cf.n_items, 6);
        assert_eq!(cf.n_categories, 2);
        assert_eq!(cf.dim, 5);
        assert!((cf.mean_members_per_category - 3.0).abs() < 1e-12);
    }

    #[test]
    fn category_size_entropy_balanced() {
        // Balanced 3/3 split → maximum entropy = 1.0 (after log normalization).
        let (cats, embs) = toy_corpus();
        let cf = CorpusFeatures::extract(&cats, &embs).unwrap();
        assert!(
            (cf.category_size_entropy - 1.0).abs() < 1e-10,
            "balanced split should give entropy = 1.0, got {}",
            cf.category_size_entropy
        );
    }

    #[test]
    fn category_size_entropy_skewed() {
        // 5/1 split → H(5/6, 1/6) / ln 2 ≈ 0.65, well below the balanced 1.0.
        let cats: Vec<String> = vec!["a", "a", "a", "a", "a", "b"]
            .into_iter()
            .map(Into::into)
            .collect();
        let embs = vec![vec![1.0, 0.0, 0.0]; 6];
        let cf = CorpusFeatures::extract(&cats, &embs).unwrap();
        let (p_a, p_b) = (5.0f64 / 6.0, 1.0f64 / 6.0);
        let expected = -(p_a * p_a.ln() + p_b * p_b.ln()) / 2f64.ln();
        assert!(
            (cf.category_size_entropy - expected).abs() < 1e-12,
            "expected {expected}, got {}",
            cf.category_size_entropy
        );
        assert!(cf.category_size_entropy < 0.9);
    }

    #[test]
    fn sparsity_matches_threshold() {
        let (cats, embs) = toy_corpus();
        // The threshold is strict (`|v| > 0.05`): rows 3 and 6 have one
        // active axis, the other four rows have two → 10 / 30 = 1/3.
        let cf = CorpusFeatures::extract_with_threshold(&cats, &embs, 0.05).unwrap();
        assert!(
            (cf.mean_sparsity - 1.0 / 3.0).abs() < 1e-12,
            "expected 1/3, got {}",
            cf.mean_sparsity
        );
    }

    #[test]
    fn axis_utilization_entropy_matches_hand_computation() {
        let (cats, embs) = toy_corpus();
        // At threshold 0.05 axis usage is [5, 2, 3, 0, 0] out of 10.
        let cf = CorpusFeatures::extract_with_threshold(&cats, &embs, 0.05).unwrap();
        let expected =
            -(0.5f64 * 0.5f64.ln() + 0.2 * 0.2f64.ln() + 0.3 * 0.3f64.ln()) / 5f64.ln();
        assert!(
            (cf.axis_utilization_entropy - expected).abs() < 1e-12,
            "expected {expected}, got {}",
            cf.axis_utilization_entropy
        );
    }

    #[test]
    fn noise_estimate_is_mean_of_per_item_medians() {
        let (cats, embs) = toy_corpus();
        // Per-item (upper) medians of inactive magnitudes at threshold 0.05:
        // [0, 0, 0.03, 0, 0, 0.03] → mean 0.01.
        let cf = CorpusFeatures::extract_with_threshold(&cats, &embs, 0.05).unwrap();
        assert!(
            (cf.noise_estimate - 0.01).abs() < 1e-12,
            "expected 0.01, got {}",
            cf.noise_estimate
        );
    }

    #[test]
    fn intra_higher_than_inter_for_well_separated() {
        let (cats, embs) = toy_corpus();
        let cf = CorpusFeatures::extract(&cats, &embs).unwrap();
        assert!(
            cf.mean_intra_category_similarity > cf.mean_inter_category_similarity,
            "expected intra > inter on well-separated corpus"
        );
        assert!(cf.category_separation_ratio > 1.0);
    }

    #[test]
    fn single_category_reports_zero_inter_similarity() {
        // No inter-category pairs exist, so the inter mean falls back to 0.0.
        let cats: Vec<String> = vec!["a".into(), "a".into()];
        let embs = vec![vec![1.0, 0.0], vec![0.9, 0.1]];
        let cf = CorpusFeatures::extract(&cats, &embs).unwrap();
        assert_eq!(cf.n_categories, 1);
        assert_eq!(cf.category_size_entropy, 0.0);
        assert_eq!(cf.mean_inter_category_similarity, 0.0);
    }

    #[test]
    fn to_vec_length_matches_feature_names() {
        let (cats, embs) = toy_corpus();
        let cf = CorpusFeatures::extract(&cats, &embs).unwrap();
        assert_eq!(cf.to_vec().len(), CorpusFeatures::feature_names().len());
        assert_eq!(cf.to_vec().len(), CORPUS_FEATURE_COUNT);
    }

    #[test]
    fn features_serialize_json_roundtrip() {
        let (cats, embs) = toy_corpus();
        let cf = CorpusFeatures::extract(&cats, &embs).unwrap();
        let json = serde_json::to_string(&cf).unwrap();
        let back: CorpusFeatures = serde_json::from_str(&json).unwrap();
        assert_eq!(cf.n_items, back.n_items);
        assert_eq!(cf.n_categories, back.n_categories);
        assert!(
            (cf.mean_intra_category_similarity - back.mean_intra_category_similarity).abs() < 1e-12
        );
    }

    #[test]
    fn empty_inactive_sets_produce_zero_noise() {
        // All axes active — noise_estimate defaults to 0.
        let cats: Vec<String> = vec!["a".into(), "a".into()];
        let embs = vec![vec![1.0, 1.0], vec![0.9, 0.9]];
        let cf = CorpusFeatures::extract_with_threshold(&cats, &embs, 0.05).unwrap();
        assert_eq!(cf.noise_estimate, 0.0);
    }
}