Skip to main content

sphereql_embed/
category.rs

1use std::collections::HashMap;
2
3use sphereql_core::{SphericalPoint, angular_distance};
4
5use crate::kernel_pca::KernelPcaProjection;
6use crate::projection::{PcaProjection, Projection};
7use crate::types::{Embedding, RadialStrategy};
8
9// ── Thresholds ─────────────────────────────────────────────────────────
10
/// Minimum category size to consider fitting an inner sphere.
///
/// Categories whose `member_count` is below this value are skipped by
/// `CategoryLayer::build_inner_spheres` before any projection is fitted.
const MIN_INNER_SPHERE_SIZE: usize = 20;

/// Minimum EVR improvement (inner − global_subset) to justify an inner sphere.
///
/// The comparison is made against the *linear* inner PCA: a category is
/// skipped when `inner_linear_evr - global_subset_evr` is below this value.
const MIN_EVR_IMPROVEMENT: f64 = 0.10;

/// Minimum category size to consider kernel PCA for the inner sphere.
///
/// Kernel PCA is only fitted for categories with at least this many members.
const KERNEL_PCA_MIN_SIZE: usize = 80;

/// Minimum EVR improvement of kernel PCA over linear PCA to choose it.
///
/// Kernel PCA is selected only when its EVR is strictly greater than
/// `inner_linear_evr + MIN_KERNEL_IMPROVEMENT`.
const MIN_KERNEL_IMPROVEMENT: f64 = 0.05;
22
23// ── Category summary ───────────────────────────────────────────────────
24
25/// Aggregate statistics for a single category on the outer sphere.
26///
27/// Computed from the projected positions of all items in that category.
28/// Every category gets a summary regardless of size — this is the
29/// foundation of the Category Enrichment Layer.
30#[derive(Debug, Clone)]
31pub struct CategorySummary {
32    /// Category name (as provided by the user).
33    pub name: String,
34    /// Indices of member items in the pipeline's item list.
35    pub member_indices: Vec<usize>,
36    /// Mean embedding in high-dimensional space (pre-projection).
37    /// Length = embedding dimensionality.
38    pub centroid_embedding: Vec<f64>,
39    /// The centroid projected onto the outer sphere.
40    pub centroid_position: SphericalPoint,
41    /// Mean angular distance (radians) of members from the centroid
42    /// on the projected sphere. Measures how "spread out" the category is.
43    pub angular_spread: f64,
44    /// 1.0 / (1.0 + angular_spread). Higher = tighter cluster.
45    /// Normalized to (0, 1].
46    pub cohesion: f64,
47    /// Number of member items.
48    pub member_count: usize,
49}
50
51// ── Bridge items ───────────────────────────────────────────────────────
52
/// An item that semantically spans two categories.
///
/// During graph construction an item is recorded as a bridge from its own
/// category toward a foreign one when its cosine similarity to the foreign
/// centroid is positive and greater than half of its cosine similarity to
/// its own centroid. Both similarities are measured in the original
/// high-dimensional embedding space.
#[derive(Clone, Debug)]
pub struct BridgeItem {
    /// Position of the item in the pipeline's item list.
    pub item_index: usize,
    /// Index of the category the item belongs to.
    pub source_category: usize,
    /// Index of the foreign category the item leans toward.
    pub target_category: usize,
    /// Cosine similarity between the item and its own category centroid.
    pub affinity_to_source: f64,
    /// Cosine similarity between the item and the foreign category centroid.
    pub affinity_to_target: f64,
    /// Harmonic-mean style score `2ab / (a + b)` of the two affinities.
    /// Higher values indicate a comparably strong tie to both categories.
    pub bridge_strength: f64,
}
74
75// ── Category graph ─────────────────────────────────────────────────────
76
/// A directed edge of the category adjacency graph.
#[derive(Clone, Debug)]
pub struct CategoryEdge {
    /// Index of the category this edge points to.
    pub target: usize,
    /// Angular distance between the two category centroids on the sphere.
    pub centroid_distance: f64,
    /// Number of bridge items recorded from the edge's source category
    /// toward `target`.
    pub bridge_count: usize,
    /// `centroid_distance / (1 + bridge_count)`; a lower weight means the
    /// two categories are more strongly connected.
    pub weight: f64,
}
90
91/// The full category adjacency graph.
92#[derive(Debug, Clone)]
93pub struct CategoryGraph {
94    /// Adjacency list: `adjacency[i]` contains edges from category i.
95    pub adjacency: Vec<Vec<CategoryEdge>>,
96    /// Bridge items keyed by (source_category, target_category).
97    /// Sorted by descending bridge_strength within each pair.
98    pub bridges: HashMap<(usize, usize), Vec<BridgeItem>>,
99}
100
101// ── Category-level concept path ────────────────────────────────────────
102
103/// A step in a category-level concept path.
104#[derive(Debug, Clone)]
105pub struct CategoryPathStep {
106    /// Category index.
107    pub category_index: usize,
108    /// Category name.
109    pub category_name: String,
110    /// Cumulative distance from the start.
111    pub cumulative_distance: f64,
112    /// Bridge items connecting this step to the next (empty for the last step).
113    pub bridges_to_next: Vec<BridgeItem>,
114}
115
116/// Result of a category-level concept path query.
117#[derive(Debug, Clone)]
118pub struct CategoryPath {
119    /// Ordered steps from source to target category.
120    pub steps: Vec<CategoryPathStep>,
121    /// Total path distance.
122    pub total_distance: f64,
123}
124
125// ── Inner sphere (Phase 2) ─────────────────────────────────────────────
126
127/// The projection type used for a category's inner sphere.
128///
129/// Wraps either a linear PCA or kernel PCA projection, chosen
130/// automatically based on the category's size and measured EVR
131/// improvement over the global projection.
132#[derive(Clone)]
133pub enum InnerProjection {
134    /// Standard linear PCA — used for categories with 20–79 members,
135    /// or when kernel PCA doesn't improve over linear.
136    LinearPca(PcaProjection),
137    /// Gaussian kernel PCA — used for categories with ≥80 members
138    /// where kernel PCA measurably outperforms linear PCA.
139    KernelPca(KernelPcaProjection),
140}
141
142impl std::fmt::Debug for InnerProjection {
143    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144        match self {
145            Self::LinearPca(_) => write!(f, "LinearPca"),
146            Self::KernelPca(_) => write!(f, "KernelPca"),
147        }
148    }
149}
150
151impl Projection for InnerProjection {
152    fn project(&self, embedding: &Embedding) -> SphericalPoint {
153        match self {
154            Self::LinearPca(p) => p.project(embedding),
155            Self::KernelPca(p) => p.project(embedding),
156        }
157    }
158    fn project_rich(&self, embedding: &Embedding) -> crate::types::ProjectedPoint {
159        match self {
160            Self::LinearPca(p) => p.project_rich(embedding),
161            Self::KernelPca(p) => p.project_rich(embedding),
162        }
163    }
164    fn dimensionality(&self) -> usize {
165        match self {
166            Self::LinearPca(p) => p.dimensionality(),
167            Self::KernelPca(p) => p.dimensionality(),
168        }
169    }
170}
171
172/// A category-specific inner sphere with its own optimized projection.
173///
174/// Only created for categories that meet all of:
175/// 1. At least `MIN_INNER_SPHERE_SIZE` members
176/// 2. Inner EVR improves over global subset EVR by ≥ `MIN_EVR_IMPROVEMENT`
177///
178/// The inner sphere gives higher-resolution angular discrimination
179/// within the category than the global outer projection can provide.
180#[derive(Debug, Clone)]
181pub struct InnerSphere {
182    /// The category-specific projection (linear PCA or kernel PCA).
183    pub projection: InnerProjection,
184    /// Positions of member items in the inner sphere's coordinate system.
185    /// `inner_positions[i]` corresponds to `member_indices[i]`.
186    pub inner_positions: Vec<SphericalPoint>,
187    /// Global item indices of the members (same order as inner_positions).
188    pub member_indices: Vec<usize>,
189    /// Explained variance ratio of the inner projection.
190    pub explained_variance_ratio: f64,
191    /// Mean certainty of these items under the global (outer) projection.
192    /// Baseline for measuring improvement.
193    pub global_subset_evr: f64,
194    /// `explained_variance_ratio - global_subset_evr`.
195    pub evr_improvement: f64,
196}
197
/// One ranked item returned by a [`drill_down`](CategoryLayer::drill_down)
/// query.
#[derive(Clone, Debug)]
pub struct DrillDownResult {
    /// Position of the item in the pipeline's global item list.
    pub item_index: usize,
    /// Angular distance used for ranking, measured in whichever coordinate
    /// system produced this result (inner or outer sphere).
    pub distance: f64,
    /// `true` when the category's inner sphere was used, `false` when the
    /// outer sphere served as the fallback.
    pub used_inner_sphere: bool,
}
209
/// Summary of one inner sphere, as produced by
/// [`inner_sphere_stats`](CategoryLayer::inner_sphere_stats).
#[derive(Clone, Debug)]
pub struct InnerSphereReport {
    /// Name of the category owning the inner sphere.
    pub category_name: String,
    /// Index of that category.
    pub category_index: usize,
    /// How many members the inner sphere contains.
    pub member_count: usize,
    /// Either `"LinearPca"` or `"KernelPca"`.
    pub projection_type: &'static str,
    /// Explained variance ratio of the inner projection.
    pub inner_evr: f64,
    /// Mean certainty of the members under the global projection.
    pub global_subset_evr: f64,
    /// Gain of the inner EVR over the global baseline.
    pub evr_improvement: f64,
}
229
230// ── The enrichment layer ───────────────────────────────────────────────
231
232/// Category Enrichment Layer: aggregate statistics, inter-category graph,
233/// bridge item detection, and automatic inner spheres over a projected
234/// SphereQL corpus.
235///
236/// This is a read-only structure computed from an existing pipeline's
237/// data. It adds category-level reasoning without modifying the
238/// underlying projection or spatial index.
239#[derive(Debug, Clone)]
240pub struct CategoryLayer {
241    /// One summary per unique category, in insertion order.
242    pub summaries: Vec<CategorySummary>,
243    /// Map from category name to index in `summaries`.
244    pub name_to_index: HashMap<String, usize>,
245    /// The inter-category adjacency graph.
246    pub graph: CategoryGraph,
247    /// Outer-sphere positions for all items (same indexing as embeddings).
248    outer_positions: Vec<SphericalPoint>,
249    /// Inner spheres keyed by category index. Only present for categories
250    /// that meet the size and EVR-improvement thresholds.
251    pub inner_spheres: HashMap<usize, InnerSphere>,
252}
253
254impl CategoryLayer {
255    /// Build the category enrichment layer from pipeline data.
256    ///
257    /// - `categories[i]` is the category name for item i.
258    /// - `embeddings[i]` is the raw embedding for item i.
259    /// - `projected_positions[i]` is the spherical position on the outer sphere.
260    /// - `projection` is used to project category centroids and measure
261    ///   per-point certainty for inner sphere threshold decisions.
262    ///
263    /// Inner spheres are automatically constructed for categories that:
264    /// 1. Have ≥ 20 members
265    /// 2. Show ≥ 0.10 EVR improvement over the global projection
266    ///
267    /// Categories with ≥ 80 members additionally try kernel PCA and
268    /// select it if it improves EVR by ≥ 0.05 over linear PCA.
269    ///
270    /// O(N·C + C²) for the base layer, plus O(n_c²·d) per inner sphere.
271    pub fn build<P: Projection>(
272        categories: &[String],
273        embeddings: &[Embedding],
274        projected_positions: &[SphericalPoint],
275        projection: &P,
276    ) -> Self {
277        let n = categories.len();
278        assert_eq!(n, embeddings.len());
279        assert_eq!(n, projected_positions.len());
280
281        // 1. Discover unique categories and group member indices
282        let mut name_to_index: HashMap<String, usize> = HashMap::new();
283        let mut cat_names: Vec<String> = Vec::new();
284        let mut cat_members: Vec<Vec<usize>> = Vec::new();
285
286        for (i, cat) in categories.iter().enumerate() {
287            let idx = if let Some(&idx) = name_to_index.get(cat) {
288                idx
289            } else {
290                let idx = cat_names.len();
291                name_to_index.insert(cat.clone(), idx);
292                cat_names.push(cat.clone());
293                cat_members.push(Vec::new());
294                idx
295            };
296            cat_members[idx].push(i);
297        }
298
299        let num_cats = cat_names.len();
300        let dim = if n > 0 { embeddings[0].dimension() } else { 0 };
301
302        // 2. Compute category summaries
303        let mut summaries: Vec<CategorySummary> = Vec::with_capacity(num_cats);
304
305        for (ci, name) in cat_names.iter().enumerate() {
306            let members = &cat_members[ci];
307            let count = members.len();
308
309            // Centroid in high-D space
310            let mut centroid_emb = vec![0.0; dim];
311            for &mi in members {
312                for (j, &v) in embeddings[mi].values.iter().enumerate() {
313                    centroid_emb[j] += v;
314                }
315            }
316            if count > 0 {
317                for v in &mut centroid_emb {
318                    *v /= count as f64;
319                }
320            }
321
322            // Project the centroid
323            let centroid_embedding_obj = Embedding::new(centroid_emb.clone());
324            let centroid_position = projection.project(&centroid_embedding_obj);
325
326            // Angular spread: mean angular distance of members from centroid
327            let angular_spread = if count > 1 {
328                let total: f64 = members
329                    .iter()
330                    .map(|&mi| angular_distance(&projected_positions[mi], &centroid_position))
331                    .sum();
332                total / count as f64
333            } else {
334                0.0
335            };
336
337            let cohesion = 1.0 / (1.0 + angular_spread);
338
339            summaries.push(CategorySummary {
340                name: name.clone(),
341                member_indices: members.clone(),
342                centroid_embedding: centroid_emb,
343                centroid_position,
344                angular_spread,
345                cohesion,
346                member_count: count,
347            });
348        }
349
350        // 3. Build category graph + detect bridges
351        let graph = Self::build_graph(&summaries, embeddings, num_cats);
352
353        // 4. Build inner spheres for qualifying categories (Phase 2)
354        let inner_spheres = Self::build_inner_spheres(&summaries, embeddings, projection);
355
356        CategoryLayer {
357            summaries,
358            name_to_index,
359            graph,
360            outer_positions: projected_positions.to_vec(),
361            inner_spheres,
362        }
363    }
364
365    /// Build the inter-category adjacency graph and detect bridge items.
366    fn build_graph(
367        summaries: &[CategorySummary],
368        embeddings: &[Embedding],
369        num_cats: usize,
370    ) -> CategoryGraph {
371        // Precompute centroid pairwise distances
372        let mut centroid_dists = vec![vec![0.0; num_cats]; num_cats];
373        for i in 0..num_cats {
374            for j in (i + 1)..num_cats {
375                let d = angular_distance(
376                    &summaries[i].centroid_position,
377                    &summaries[j].centroid_position,
378                );
379                centroid_dists[i][j] = d;
380                centroid_dists[j][i] = d;
381            }
382        }
383
384        // Detect bridge items
385        let mut bridges: HashMap<(usize, usize), Vec<BridgeItem>> = HashMap::new();
386
387        for (ci, summary) in summaries.iter().enumerate() {
388            let centroid_a = &summary.centroid_embedding;
389
390            for &mi in &summary.member_indices {
391                let item_emb = &embeddings[mi];
392                let sim_to_own = cosine_similarity(&item_emb.values, centroid_a);
393
394                for (cj, other_summary) in summaries.iter().enumerate() {
395                    if ci == cj {
396                        continue;
397                    }
398
399                    let sim_to_other =
400                        cosine_similarity(&item_emb.values, &other_summary.centroid_embedding);
401
402                    if sim_to_other > 0.0 && sim_to_other > sim_to_own * 0.5 {
403                        let bridge_strength = if sim_to_own + sim_to_other > f64::EPSILON {
404                            2.0 * sim_to_own * sim_to_other / (sim_to_own + sim_to_other)
405                        } else {
406                            0.0
407                        };
408
409                        bridges.entry((ci, cj)).or_default().push(BridgeItem {
410                            item_index: mi,
411                            source_category: ci,
412                            target_category: cj,
413                            affinity_to_source: sim_to_own,
414                            affinity_to_target: sim_to_other,
415                            bridge_strength,
416                        });
417                    }
418                }
419            }
420        }
421
422        for list in bridges.values_mut() {
423            list.sort_by(|a, b| {
424                b.bridge_strength
425                    .partial_cmp(&a.bridge_strength)
426                    .unwrap_or(std::cmp::Ordering::Equal)
427            });
428        }
429
430        let mut adjacency: Vec<Vec<CategoryEdge>> = vec![Vec::new(); num_cats];
431        for i in 0..num_cats {
432            for (j, &cd) in centroid_dists[i].iter().enumerate() {
433                if i == j {
434                    continue;
435                }
436                let bridge_count = bridges.get(&(i, j)).map_or(0, |b| b.len());
437                let weight = cd / (1.0 + bridge_count as f64);
438
439                adjacency[i].push(CategoryEdge {
440                    target: j,
441                    centroid_distance: cd,
442                    bridge_count,
443                    weight,
444                });
445            }
446            adjacency[i].sort_by(|a, b| {
447                a.weight
448                    .partial_cmp(&b.weight)
449                    .unwrap_or(std::cmp::Ordering::Equal)
450            });
451        }
452
453        CategoryGraph { adjacency, bridges }
454    }
455
456    /// Evaluate each category and build inner spheres where they help.
457    fn build_inner_spheres<P: Projection>(
458        summaries: &[CategorySummary],
459        embeddings: &[Embedding],
460        projection: &P,
461    ) -> HashMap<usize, InnerSphere> {
462        let mut result = HashMap::new();
463
464        for (ci, summary) in summaries.iter().enumerate() {
465            if summary.member_count < MIN_INNER_SPHERE_SIZE {
466                continue;
467            }
468
469            let member_embs: Vec<Embedding> = summary
470                .member_indices
471                .iter()
472                .map(|&i| embeddings[i].clone())
473                .collect();
474
475            // Global subset EVR: mean certainty under global projection
476            let global_subset_evr: f64 = member_embs
477                .iter()
478                .map(|e| projection.project_rich(e).certainty)
479                .sum::<f64>()
480                / member_embs.len() as f64;
481
482            // Fit inner linear PCA
483            let inner_pca = PcaProjection::fit(&member_embs, RadialStrategy::Fixed(1.0));
484            let inner_linear_evr = inner_pca.explained_variance_ratio();
485
486            if inner_linear_evr - global_subset_evr < MIN_EVR_IMPROVEMENT {
487                continue;
488            }
489
490            let (inner_proj, inner_evr) = if summary.member_count >= KERNEL_PCA_MIN_SIZE {
491                let inner_kpca = KernelPcaProjection::fit(&member_embs, RadialStrategy::Fixed(1.0));
492                let kernel_evr = inner_kpca.explained_variance_ratio();
493
494                if kernel_evr > inner_linear_evr + MIN_KERNEL_IMPROVEMENT {
495                    (InnerProjection::KernelPca(inner_kpca), kernel_evr)
496                } else {
497                    (InnerProjection::LinearPca(inner_pca), inner_linear_evr)
498                }
499            } else {
500                (InnerProjection::LinearPca(inner_pca), inner_linear_evr)
501            };
502
503            let inner_positions: Vec<SphericalPoint> =
504                member_embs.iter().map(|e| inner_proj.project(e)).collect();
505
506            result.insert(
507                ci,
508                InnerSphere {
509                    projection: inner_proj,
510                    inner_positions,
511                    member_indices: summary.member_indices.clone(),
512                    explained_variance_ratio: inner_evr,
513                    global_subset_evr,
514                    evr_improvement: inner_evr - global_subset_evr,
515                },
516            );
517        }
518
519        result
520    }
521
522    // ── Phase 1 query methods (unchanged) ──────────────────────────────
523
524    /// Number of categories.
525    pub fn num_categories(&self) -> usize {
526        self.summaries.len()
527    }
528
529    /// Look up a category by name.
530    pub fn get_category(&self, name: &str) -> Option<&CategorySummary> {
531        self.name_to_index
532            .get(name)
533            .map(|&idx| &self.summaries[idx])
534    }
535
536    /// Get the k nearest neighbor categories to the given category.
537    pub fn category_neighbors(&self, category_name: &str, k: usize) -> Vec<&CategorySummary> {
538        let Some(&ci) = self.name_to_index.get(category_name) else {
539            return Vec::new();
540        };
541        self.graph.adjacency[ci]
542            .iter()
543            .take(k)
544            .map(|edge| &self.summaries[edge.target])
545            .collect()
546    }
547
548    /// Get bridge items between two categories.
549    pub fn bridge_items(
550        &self,
551        source_category: &str,
552        target_category: &str,
553        max_bridges: usize,
554    ) -> Vec<&BridgeItem> {
555        let Some(&si) = self.name_to_index.get(source_category) else {
556            return Vec::new();
557        };
558        let Some(&ti) = self.name_to_index.get(target_category) else {
559            return Vec::new();
560        };
561        self.graph
562            .bridges
563            .get(&(si, ti))
564            .map(|list| list.iter().take(max_bridges).collect())
565            .unwrap_or_default()
566    }
567
568    /// Find the shortest path between two categories through the category graph.
569    pub fn category_path(
570        &self,
571        source_category: &str,
572        target_category: &str,
573    ) -> Option<CategoryPath> {
574        let &si = self.name_to_index.get(source_category)?;
575        let &ti = self.name_to_index.get(target_category)?;
576        if si == ti {
577            return Some(CategoryPath {
578                steps: vec![CategoryPathStep {
579                    category_index: si,
580                    category_name: self.summaries[si].name.clone(),
581                    cumulative_distance: 0.0,
582                    bridges_to_next: Vec::new(),
583                }],
584                total_distance: 0.0,
585            });
586        }
587
588        let n = self.summaries.len();
589        let mut dist = vec![f64::INFINITY; n];
590        let mut prev: Vec<Option<usize>> = vec![None; n];
591        let mut visited = vec![false; n];
592
593        dist[si] = 0.0;
594
595        for _ in 0..n {
596            let mut u = None;
597            let mut best = f64::INFINITY;
598            for (i, (&d, &v)) in dist.iter().zip(visited.iter()).enumerate() {
599                if !v && d < best {
600                    best = d;
601                    u = Some(i);
602                }
603            }
604            let Some(u) = u else { break };
605            if u == ti {
606                break;
607            }
608            visited[u] = true;
609
610            for edge in &self.graph.adjacency[u] {
611                let nd = dist[u] + edge.weight;
612                if nd < dist[edge.target] {
613                    dist[edge.target] = nd;
614                    prev[edge.target] = Some(u);
615                }
616            }
617        }
618
619        if dist[ti].is_infinite() {
620            return None;
621        }
622
623        let mut path_indices = Vec::new();
624        let mut cur = ti;
625        loop {
626            path_indices.push(cur);
627            match prev[cur] {
628                Some(p) => cur = p,
629                None => break,
630            }
631        }
632        path_indices.reverse();
633
634        let mut steps = Vec::with_capacity(path_indices.len());
635        for (step_idx, &ci) in path_indices.iter().enumerate() {
636            let bridges_to_next = if step_idx + 1 < path_indices.len() {
637                let next_ci = path_indices[step_idx + 1];
638                self.graph
639                    .bridges
640                    .get(&(ci, next_ci))
641                    .map(|list| list.iter().take(3).cloned().collect())
642                    .unwrap_or_default()
643            } else {
644                Vec::new()
645            };
646
647            steps.push(CategoryPathStep {
648                category_index: ci,
649                category_name: self.summaries[ci].name.clone(),
650                cumulative_distance: dist[ci],
651                bridges_to_next,
652            });
653        }
654
655        Some(CategoryPath {
656            total_distance: dist[ti],
657            steps,
658        })
659    }
660
661    /// Find all categories whose centroid is within `max_angle` radians
662    /// of the given embedding's projected position.
663    pub fn categories_near_embedding<P: Projection>(
664        &self,
665        embedding: &Embedding,
666        projection: &P,
667        max_angle: f64,
668    ) -> Vec<(usize, f64)> {
669        let pos = projection.project(embedding);
670        let mut results: Vec<(usize, f64)> = self
671            .summaries
672            .iter()
673            .enumerate()
674            .map(|(i, s)| (i, angular_distance(&pos, &s.centroid_position)))
675            .filter(|&(_, d)| d <= max_angle)
676            .collect();
677        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
678        results
679    }
680
681    // ── Phase 2 query methods (inner spheres) ──────────────────────────
682
683    /// Whether a given category has an inner sphere.
684    pub fn has_inner_sphere(&self, category_name: &str) -> bool {
685        self.name_to_index
686            .get(category_name)
687            .is_some_and(|&ci| self.inner_spheres.contains_key(&ci))
688    }
689
690    /// Get the inner sphere for a category, if one exists.
691    pub fn get_inner_sphere(&self, category_name: &str) -> Option<&InnerSphere> {
692        self.name_to_index
693            .get(category_name)
694            .and_then(|&ci| self.inner_spheres.get(&ci))
695    }
696
697    /// Number of categories that have inner spheres.
698    pub fn num_inner_spheres(&self) -> usize {
699        self.inner_spheres.len()
700    }
701
702    /// Drill down into a category: find the k nearest members to a query
703    /// embedding, using the inner sphere's projection if available.
704    ///
705    /// Falls back to angular distance from the category centroid on the
706    /// outer sphere when no inner sphere exists.
707    pub fn drill_down(
708        &self,
709        category_name: &str,
710        embedding: &Embedding,
711        k: usize,
712    ) -> Vec<DrillDownResult> {
713        let Some(&ci) = self.name_to_index.get(category_name) else {
714            return Vec::new();
715        };
716        let summary = &self.summaries[ci];
717
718        if let Some(inner) = self.inner_spheres.get(&ci) {
719            let query_pos = inner.projection.project(embedding);
720            let mut results: Vec<DrillDownResult> = inner
721                .inner_positions
722                .iter()
723                .enumerate()
724                .map(|(local_idx, pos)| DrillDownResult {
725                    item_index: inner.member_indices[local_idx],
726                    distance: angular_distance(&query_pos, pos),
727                    used_inner_sphere: true,
728                })
729                .collect();
730            results.sort_by(|a, b| {
731                a.distance
732                    .partial_cmp(&b.distance)
733                    .unwrap_or(std::cmp::Ordering::Equal)
734            });
735            results.truncate(k);
736            results
737        } else {
738            // Fallback: rank by distance from category centroid on outer sphere
739            let centroid = &summary.centroid_position;
740            let mut results: Vec<DrillDownResult> = summary
741                .member_indices
742                .iter()
743                .map(|&mi| DrillDownResult {
744                    item_index: mi,
745                    distance: angular_distance(&self.outer_positions[mi], centroid),
746                    used_inner_sphere: false,
747                })
748                .collect();
749            results.sort_by(|a, b| {
750                a.distance
751                    .partial_cmp(&b.distance)
752                    .unwrap_or(std::cmp::Ordering::Equal)
753            });
754            results.truncate(k);
755            results
756        }
757    }
758
759    /// Drill down with an explicit outer projection for the fallback case.
760    ///
761    /// When no inner sphere exists, the query is projected using the
762    /// provided projection and compared against stored outer positions.
763    pub fn drill_down_with_projection<P: Projection>(
764        &self,
765        category_name: &str,
766        embedding: &Embedding,
767        projection: &P,
768        k: usize,
769    ) -> Vec<DrillDownResult> {
770        let Some(&ci) = self.name_to_index.get(category_name) else {
771            return Vec::new();
772        };
773        let summary = &self.summaries[ci];
774
775        if let Some(inner) = self.inner_spheres.get(&ci) {
776            let query_pos = inner.projection.project(embedding);
777            let mut results: Vec<DrillDownResult> = inner
778                .inner_positions
779                .iter()
780                .enumerate()
781                .map(|(local_idx, pos)| DrillDownResult {
782                    item_index: inner.member_indices[local_idx],
783                    distance: angular_distance(&query_pos, pos),
784                    used_inner_sphere: true,
785                })
786                .collect();
787            results.sort_by(|a, b| {
788                a.distance
789                    .partial_cmp(&b.distance)
790                    .unwrap_or(std::cmp::Ordering::Equal)
791            });
792            results.truncate(k);
793            results
794        } else {
795            let query_pos = projection.project(embedding);
796            let mut results: Vec<DrillDownResult> = summary
797                .member_indices
798                .iter()
799                .map(|&mi| DrillDownResult {
800                    item_index: mi,
801                    distance: angular_distance(&self.outer_positions[mi], &query_pos),
802                    used_inner_sphere: false,
803                })
804                .collect();
805            results.sort_by(|a, b| {
806                a.distance
807                    .partial_cmp(&b.distance)
808                    .unwrap_or(std::cmp::Ordering::Equal)
809            });
810            results.truncate(k);
811            results
812        }
813    }
814
815    /// Report which categories have inner spheres, their projection type,
816    /// and EVR metrics.
817    pub fn inner_sphere_stats(&self) -> Vec<InnerSphereReport> {
818        let mut reports: Vec<InnerSphereReport> = self
819            .inner_spheres
820            .iter()
821            .map(|(&ci, inner)| {
822                let proj_type = match &inner.projection {
823                    InnerProjection::LinearPca(_) => "LinearPca",
824                    InnerProjection::KernelPca(_) => "KernelPca",
825                };
826                InnerSphereReport {
827                    category_name: self.summaries[ci].name.clone(),
828                    category_index: ci,
829                    member_count: inner.member_indices.len(),
830                    projection_type: proj_type,
831                    inner_evr: inner.explained_variance_ratio,
832                    global_subset_evr: inner.global_subset_evr,
833                    evr_improvement: inner.evr_improvement,
834                }
835            })
836            .collect();
837        reports.sort_by_key(|r| r.category_index);
838        reports
839    }
840}
841
842// ── Helpers ────────────────────────────────────────────────────────────
843
/// Cosine of the angle between `a` and `b`, clamped to `[-1.0, 1.0]`.
///
/// Returns `0.0` when the product of the two Euclidean norms is below
/// `f64::EPSILON`, which covers zero-length and (near-)zero vectors.
/// The dot product only pairs up elements while both slices have one, whereas
/// each norm is taken over its whole slice.
fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
    /// Euclidean (L2) length of `v`.
    fn l2_norm(v: &[f64]) -> f64 {
        v.iter().map(|c| c * c).sum::<f64>().sqrt()
    }

    let inner_product = a.iter().zip(b.iter()).map(|(&p, &q)| p * q).sum::<f64>();
    let norm_product = l2_norm(a) * l2_norm(b);

    if norm_product < f64::EPSILON {
        0.0
    } else {
        // Rounding can push the ratio a hair outside the valid cosine range.
        (inner_product / norm_product).clamp(-1.0, 1.0)
    }
}
854
855// ── Tests ──────────────────────────────────────────────────────────────
856
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building an [`Embedding`] from a slice of components.
    fn emb(vals: &[f64]) -> Embedding {
        Embedding::new(vals.to_vec())
    }

    /// Fit a global PCA on `embeddings`, project every item, and build the
    /// category layer on top. Returns the layer together with the embeddings
    /// and the fitted projection so tests can issue queries against it.
    fn build_layer(
        categories: Vec<String>,
        embeddings: Vec<Embedding>,
    ) -> (CategoryLayer, Vec<Embedding>, PcaProjection) {
        let pca = PcaProjection::fit(&embeddings, RadialStrategy::Fixed(1.0));
        let projected: Vec<SphericalPoint> = embeddings.iter().map(|e| pca.project(e)).collect();
        let layer = CategoryLayer::build(&categories, &embeddings, &projected, &pca);
        (layer, embeddings, pca)
    }

    // --- Phase 1 test helpers ---

    /// Twelve 5-dimensional embeddings in three categories of four items:
    /// "science" (items 0..4, dominant dim 0), "cooking" (items 4..8,
    /// dominant dim 1) and "music" (items 8..12, dominant dim 2).
    fn test_corpus() -> (Vec<String>, Vec<Embedding>) {
        let mut categories: Vec<String> = Vec::with_capacity(12);
        for name in ["science", "cooking", "music"] {
            categories.extend(std::iter::repeat(name.to_string()).take(4));
        }

        let rows: [[f64; 5]; 12] = [
            // science
            [1.0, 0.1, 0.0, 0.05, 0.02],
            [0.9, 0.15, 0.05, 0.03, 0.01],
            [0.95, 0.05, 0.1, 0.04, 0.03],
            [0.85, 0.2, 0.0, 0.06, 0.01],
            // cooking
            [0.1, 1.0, 0.0, 0.02, 0.05],
            [0.15, 0.9, 0.05, 0.03, 0.04],
            [0.05, 0.95, 0.1, 0.01, 0.06],
            [0.2, 0.85, 0.0, 0.04, 0.03],
            // music
            [0.0, 0.1, 1.0, 0.05, 0.02],
            [0.05, 0.15, 0.9, 0.03, 0.01],
            [0.1, 0.05, 0.95, 0.04, 0.03],
            [0.0, 0.2, 0.85, 0.06, 0.01],
        ];
        let embeddings = rows.iter().map(|row| emb(row)).collect();

        (categories, embeddings)
    }

    /// Category layer over [`test_corpus`]; every category has 4 members.
    fn build_test_layer() -> (CategoryLayer, Vec<Embedding>, PcaProjection) {
        let (categories, embeddings) = test_corpus();
        build_layer(categories, embeddings)
    }

    // --- Phase 2 test helpers ---

    /// A 10-dimensional corpus with one large and two small categories:
    ///
    /// * "big" — 25 items tracing a loop in dims 0–1, drifting along dim 2,
    ///   with small deterministic jitter in dims 3–9;
    /// * "small_a" — 4 items along dim 5;
    /// * "small_b" — 4 items along dim 8.
    fn large_category_corpus() -> (Vec<String>, Vec<Embedding>) {
        let mut categories: Vec<String> = Vec::new();
        let mut embeddings = Vec::new();

        for i in 0..25 {
            categories.push("big".into());
            let t = i as f64 / 25.0;
            let angle = t * std::f64::consts::TAU;
            let mut v = vec![0.0; 10];
            v[0] = 1.0 + 0.3 * angle.sin();
            v[1] = 0.5 + 0.3 * angle.cos();
            v[2] = 0.2 * t;
            for (d, slot) in v.iter_mut().enumerate().skip(3) {
                // Deterministic jitter in [0.0, 0.009]. Reduce modulo 10 in
                // integer space *before* converting: an integer-valued f64
                // taken `% 1.0` is always 0.0 and would leave these
                // dimensions completely flat.
                *slot = 0.01 * (((i * 7 + d) % 10) as f64 / 10.0);
            }
            embeddings.push(emb(&v));
        }

        // The two small categories differ only in which pair of dimensions
        // they occupy: (main axis, minor axis).
        for (name, main_dim, minor_dim) in [("small_a", 5, 6), ("small_b", 8, 9)] {
            for i in 0..4 {
                categories.push(name.into());
                let mut v = vec![0.0; 10];
                v[main_dim] = 1.0 + 0.1 * i as f64;
                v[minor_dim] = 0.05;
                embeddings.push(emb(&v));
            }
        }

        (categories, embeddings)
    }

    /// Category layer over [`large_category_corpus`].
    fn build_large_test_layer() -> (CategoryLayer, Vec<Embedding>, PcaProjection) {
        let (categories, embeddings) = large_category_corpus();
        build_layer(categories, embeddings)
    }

    /// Query embedding lying inside the "big" category of the large corpus.
    fn big_query() -> Embedding {
        emb(&[1.0, 0.5, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
    }

    /// PCA fitted on five 5-dimensional points spread along the first axis.
    fn line_pca() -> PcaProjection {
        let corpus: Vec<Embedding> = (0..5)
            .map(|i| emb(&[i as f64, 0.0, 0.0, 0.0, 0.0]))
            .collect();
        PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0))
    }

    // ======== Phase 1 tests (unchanged) ========

    #[test]
    fn builds_correct_number_of_categories() {
        let (layer, _, _) = build_test_layer();
        assert_eq!(layer.num_categories(), 3);
    }

    #[test]
    fn category_names_correct() {
        let (layer, _, _) = build_test_layer();
        let names: Vec<&str> = layer.summaries.iter().map(|s| s.name.as_str()).collect();
        for expected in ["science", "cooking", "music"] {
            assert!(names.contains(&expected), "missing category {expected}");
        }
    }

    #[test]
    fn member_counts_correct() {
        let (layer, _, _) = build_test_layer();
        for summary in &layer.summaries {
            assert_eq!(summary.member_count, 4);
            assert_eq!(summary.member_indices.len(), 4);
        }
    }

    #[test]
    fn centroid_embedding_is_mean() {
        let (layer, embeddings, _) = build_test_layer();
        let science = layer.get_category("science").unwrap();

        // "science" owns the first four items of the corpus.
        let mut expected = vec![0.0; 5];
        for member in embeddings.iter().take(4) {
            for (j, &v) in member.values.iter().enumerate() {
                expected[j] += v;
            }
        }
        for v in &mut expected {
            *v /= 4.0;
        }

        for (j, (&actual, &exp)) in science
            .centroid_embedding
            .iter()
            .zip(expected.iter())
            .enumerate()
        {
            assert!(
                (actual - exp).abs() < 1e-10,
                "centroid dim {j}: {actual} != {exp}"
            );
        }
    }

    #[test]
    fn angular_spread_is_nonnegative() {
        let (layer, _, _) = build_test_layer();
        for s in &layer.summaries {
            assert!(s.angular_spread >= 0.0);
        }
    }

    #[test]
    fn cohesion_in_range() {
        let (layer, _, _) = build_test_layer();
        for s in &layer.summaries {
            assert!(s.cohesion > 0.0 && s.cohesion <= 1.0);
        }
    }

    #[test]
    fn graph_has_edges_for_all_pairs() {
        let (layer, _, _) = build_test_layer();
        let expected_degree = layer.num_categories() - 1;
        for (i, edges) in layer.graph.adjacency.iter().enumerate() {
            assert_eq!(edges.len(), expected_degree, "cat {i}");
        }
    }

    #[test]
    fn edge_weights_positive() {
        let (layer, _, _) = build_test_layer();
        for e in layer.graph.adjacency.iter().flatten() {
            assert!(e.weight > 0.0);
            assert!(e.centroid_distance > 0.0);
        }
    }

    #[test]
    fn edges_sorted_by_weight() {
        let (layer, _, _) = build_test_layer();
        for edges in &layer.graph.adjacency {
            for w in edges.windows(2) {
                assert!(w[0].weight <= w[1].weight);
            }
        }
    }

    #[test]
    fn get_category_by_name() {
        let (layer, _, _) = build_test_layer();
        assert!(layer.get_category("science").is_some());
        assert!(layer.get_category("astrology").is_none());
    }

    // NOTE(review): despite its name this only checks the result count, not
    // the ordering of the returned neighbors.
    #[test]
    fn category_neighbors_returns_sorted() {
        let (layer, _, _) = build_test_layer();
        assert_eq!(layer.category_neighbors("science", 2).len(), 2);
    }

    #[test]
    fn category_neighbors_k_larger_than_available() {
        let (layer, _, _) = build_test_layer();
        assert_eq!(layer.category_neighbors("science", 100).len(), 2);
    }

    #[test]
    fn category_neighbors_unknown_returns_empty() {
        let (layer, _, _) = build_test_layer();
        assert!(layer.category_neighbors("nonexistent", 5).is_empty());
    }

    // Smoke test: only verifies that the call completes without panicking.
    #[test]
    fn bridge_items_detected() {
        let (layer, _, _) = build_test_layer();
        let _ = layer.bridge_items("science", "cooking", 10);
    }

    #[test]
    fn bridge_items_unknown_category_returns_empty() {
        let (layer, _, _) = build_test_layer();
        assert!(layer.bridge_items("science", "nonexistent", 10).is_empty());
    }

    #[test]
    fn bridge_strength_in_valid_range() {
        let (layer, _, _) = build_test_layer();
        for list in layer.graph.bridges.values() {
            for b in list {
                assert!(b.bridge_strength >= 0.0 && b.bridge_strength <= 1.0);
            }
        }
    }

    #[test]
    fn bridges_sorted_by_strength() {
        let (layer, _, _) = build_test_layer();
        for list in layer.graph.bridges.values() {
            // Strongest bridge first.
            for w in list.windows(2) {
                assert!(w[0].bridge_strength >= w[1].bridge_strength);
            }
        }
    }

    #[test]
    fn category_path_same_category() {
        let (layer, _, _) = build_test_layer();
        let path = layer.category_path("science", "science").unwrap();
        assert_eq!(path.steps.len(), 1);
        assert!(path.total_distance.abs() < 1e-12);
    }

    #[test]
    fn category_path_adjacent() {
        let (layer, _, _) = build_test_layer();
        let path = layer.category_path("science", "cooking").unwrap();
        assert!(path.steps.len() >= 2);
        assert_eq!(path.steps.first().unwrap().category_name, "science");
        assert_eq!(path.steps.last().unwrap().category_name, "cooking");
        assert!(path.total_distance > 0.0);
    }

    #[test]
    fn category_path_unknown_returns_none() {
        let (layer, _, _) = build_test_layer();
        assert!(layer.category_path("science", "nonexistent").is_none());
    }

    #[test]
    fn category_path_distances_monotonic() {
        let (layer, _, _) = build_test_layer();
        let path = layer.category_path("science", "music").unwrap();
        for w in path.steps.windows(2) {
            assert!(w[1].cumulative_distance >= w[0].cumulative_distance);
        }
    }

    #[test]
    fn categories_near_embedding_finds_correct() {
        let (layer, _, pca) = build_test_layer();
        // A threshold of PI admits every category; the nearest must be the
        // one whose members dominate dim 0.
        let near = layer.categories_near_embedding(
            &emb(&[1.0, 0.0, 0.0, 0.0, 0.0]),
            &pca,
            std::f64::consts::PI,
        );
        assert!(!near.is_empty());
        assert_eq!(layer.summaries[near[0].0].name, "science");
    }

    #[test]
    fn categories_near_embedding_sorted_by_distance() {
        let (layer, _, pca) = build_test_layer();
        let near = layer.categories_near_embedding(
            &emb(&[0.5, 0.5, 0.5, 0.0, 0.0]),
            &pca,
            std::f64::consts::PI,
        );
        for w in near.windows(2) {
            assert!(w[0].1 <= w[1].1);
        }
    }

    #[test]
    fn categories_near_embedding_respects_threshold() {
        let (layer, _, pca) = build_test_layer();
        let near = layer.categories_near_embedding(&emb(&[1.0, 0.0, 0.0, 0.0, 0.0]), &pca, 0.01);
        for &(_, d) in &near {
            assert!(d <= 0.01);
        }
    }

    #[test]
    fn cosine_similarity_identical() {
        assert!((cosine_similarity(&[1.0, 0.0, 0.0], &[1.0, 0.0, 0.0]) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn cosine_similarity_orthogonal() {
        assert!(cosine_similarity(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]).abs() < 1e-12);
    }

    #[test]
    fn cosine_similarity_opposite() {
        assert!((cosine_similarity(&[1.0, 0.0, 0.0], &[-1.0, 0.0, 0.0]) + 1.0).abs() < 1e-12);
    }

    #[test]
    fn cosine_similarity_zero_vector() {
        assert!(cosine_similarity(&[0.0, 0.0, 0.0], &[1.0, 0.0, 0.0]).abs() < 1e-12);
    }

    // ======== Phase 2 tests (inner spheres) ========

    #[test]
    fn small_categories_get_no_inner_sphere() {
        let (layer, _, _) = build_test_layer();
        assert_eq!(layer.num_inner_spheres(), 0);
        assert!(!layer.has_inner_sphere("science"));
    }

    #[test]
    fn large_category_may_get_inner_sphere() {
        let (layer, _, _) = build_large_test_layer();
        assert!(!layer.has_inner_sphere("small_a"));
        assert!(!layer.has_inner_sphere("small_b"));
        // Whether "big" qualifies depends on the EVR thresholds, so this is
        // only a smoke check that the lookup itself works.
        let _ = layer.has_inner_sphere("big");
    }

    #[test]
    fn inner_sphere_stats_count_matches() {
        let (layer, _, _) = build_large_test_layer();
        assert_eq!(layer.inner_sphere_stats().len(), layer.num_inner_spheres());
    }

    #[test]
    fn inner_sphere_stats_sorted_by_index() {
        let (layer, _, _) = build_large_test_layer();
        let stats = layer.inner_sphere_stats();
        for w in stats.windows(2) {
            assert!(w[0].category_index <= w[1].category_index);
        }
    }

    #[test]
    fn inner_sphere_evr_improvement_positive() {
        let (layer, _, _) = build_large_test_layer();
        for inner in layer.inner_spheres.values() {
            assert!(inner.evr_improvement >= MIN_EVR_IMPROVEMENT);
        }
    }

    #[test]
    fn inner_sphere_positions_match_member_count() {
        let (layer, _, _) = build_large_test_layer();
        for (&ci, inner) in &layer.inner_spheres {
            assert_eq!(inner.inner_positions.len(), inner.member_indices.len());
            assert_eq!(inner.member_indices.len(), layer.summaries[ci].member_count);
        }
    }

    #[test]
    fn inner_sphere_member_indices_valid() {
        let (layer, _, _) = build_large_test_layer();
        let total = layer.outer_positions.len();
        for inner in layer.inner_spheres.values() {
            assert!(inner.member_indices.iter().all(|&mi| mi < total));
        }
    }

    #[test]
    fn inner_sphere_report_projection_type_valid() {
        let (layer, _, _) = build_large_test_layer();
        for r in layer.inner_sphere_stats() {
            assert!(r.projection_type == "LinearPca" || r.projection_type == "KernelPca");
        }
    }

    #[test]
    fn inner_sphere_evr_in_range() {
        let (layer, _, _) = build_large_test_layer();
        for inner in layer.inner_spheres.values() {
            assert!(inner.explained_variance_ratio >= 0.0 && inner.explained_variance_ratio <= 1.0);
            assert!(inner.global_subset_evr >= 0.0 && inner.global_subset_evr <= 1.0);
        }
    }

    #[test]
    fn has_inner_sphere_unknown_category() {
        let (layer, _, _) = build_test_layer();
        assert!(!layer.has_inner_sphere("nonexistent"));
    }

    #[test]
    fn get_inner_sphere_returns_none_for_small() {
        let (layer, _, _) = build_test_layer();
        assert!(layer.get_inner_sphere("science").is_none());
    }

    #[test]
    fn drill_down_returns_results() {
        let (layer, _, pca) = build_large_test_layer();
        let results = layer.drill_down_with_projection("big", &big_query(), &pca, 5);
        assert!(!results.is_empty());
        assert!(results.len() <= 5);
    }

    #[test]
    fn drill_down_sorted_by_distance() {
        let (layer, _, pca) = build_large_test_layer();
        let results = layer.drill_down_with_projection("big", &big_query(), &pca, 10);
        for w in results.windows(2) {
            assert!(w[0].distance <= w[1].distance);
        }
    }

    #[test]
    fn drill_down_unknown_category_empty() {
        let (layer, _, pca) = build_large_test_layer();
        assert!(
            layer
                .drill_down_with_projection("nonexistent", &emb(&[1.0; 10]), &pca, 5)
                .is_empty()
        );
    }

    #[test]
    fn drill_down_item_indices_valid() {
        let (layer, _, pca) = build_large_test_layer();
        let total = layer.outer_positions.len();
        for r in layer.drill_down_with_projection("big", &big_query(), &pca, 25) {
            assert!(r.item_index < total);
        }
    }

    #[test]
    fn drill_down_small_category_uses_outer() {
        let (layer, _, pca) = build_large_test_layer();
        let q = emb(&[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]);
        for r in layer.drill_down_with_projection("small_a", &q, &pca, 4) {
            assert!(!r.used_inner_sphere);
        }
    }

    #[test]
    fn drill_down_distances_nonnegative() {
        let (layer, _, pca) = build_large_test_layer();
        for r in layer.drill_down_with_projection("big", &emb(&[1.0; 10]), &pca, 10) {
            assert!(r.distance >= 0.0);
        }
    }

    #[test]
    fn drill_down_without_projection_works() {
        let (layer, _, _) = build_large_test_layer();
        assert!(layer.drill_down("big", &big_query(), 5).len() <= 5);
    }

    #[test]
    fn inner_projection_enum_debug() {
        assert_eq!(
            format!("{:?}", InnerProjection::LinearPca(line_pca())),
            "LinearPca"
        );
    }

    #[test]
    fn inner_projection_projects_correctly() {
        let pca = line_pca();
        let proj = InnerProjection::LinearPca(pca.clone());
        let e = emb(&[1.0, 0.0, 0.0, 0.0, 0.0]);
        // The enum wrapper must delegate to the wrapped projection.
        let sp_enum = proj.project(&e);
        let sp_direct = pca.project(&e);
        assert!((sp_enum.theta - sp_direct.theta).abs() < 1e-12);
        assert!((sp_enum.phi - sp_direct.phi).abs() < 1e-12);
    }

    #[test]
    fn inner_projection_dimensionality() {
        assert_eq!(InnerProjection::LinearPca(line_pca()).dimensionality(), 5);
    }
}