sphereql_embed/
configured_projection.rs1use sphereql_core::SphericalPoint;
15
16use crate::config::ProjectionKind;
17use crate::kernel_pca::KernelPcaProjection;
18use crate::laplacian::LaplacianEigenmapProjection;
19use crate::projection::{PcaProjection, Projection};
20use crate::types::{Embedding, ProjectedPoint};
21use crate::umap::UmapSphereProjection;
22
23#[derive(Clone)]
29pub enum ConfiguredProjection {
30 Pca(PcaProjection),
31 KernelPca(KernelPcaProjection),
32 Laplacian(LaplacianEigenmapProjection),
33 UmapSphere(UmapSphereProjection),
34}
35
36impl std::fmt::Debug for ConfiguredProjection {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 match self {
39 Self::Pca(_) => write!(f, "ConfiguredProjection::Pca"),
40 Self::KernelPca(_) => write!(f, "ConfiguredProjection::KernelPca"),
41 Self::Laplacian(_) => write!(f, "ConfiguredProjection::Laplacian"),
42 Self::UmapSphere(_) => write!(f, "ConfiguredProjection::UmapSphere"),
43 }
44 }
45}
46
47impl Projection for ConfiguredProjection {
48 fn project(&self, embedding: &Embedding) -> SphericalPoint {
49 match self {
50 Self::Pca(p) => p.project(embedding),
51 Self::KernelPca(p) => p.project(embedding),
52 Self::Laplacian(p) => p.project(embedding),
53 Self::UmapSphere(p) => p.project(embedding),
54 }
55 }
56
57 fn project_rich(&self, embedding: &Embedding) -> ProjectedPoint {
58 match self {
59 Self::Pca(p) => p.project_rich(embedding),
60 Self::KernelPca(p) => p.project_rich(embedding),
61 Self::Laplacian(p) => p.project_rich(embedding),
62 Self::UmapSphere(p) => p.project_rich(embedding),
63 }
64 }
65
66 fn dimensionality(&self) -> usize {
67 match self {
68 Self::Pca(p) => p.dimensionality(),
69 Self::KernelPca(p) => p.dimensionality(),
70 Self::Laplacian(p) => p.dimensionality(),
71 Self::UmapSphere(p) => p.dimensionality(),
72 }
73 }
74}
75
76impl ConfiguredProjection {
77 pub fn kind(&self) -> ProjectionKind {
79 match self {
80 Self::Pca(_) => ProjectionKind::Pca,
81 Self::KernelPca(_) => ProjectionKind::KernelPca,
82 Self::Laplacian(_) => ProjectionKind::LaplacianEigenmap,
83 Self::UmapSphere(_) => ProjectionKind::UmapSphere,
84 }
85 }
86
87 pub fn explained_variance_ratio(&self) -> f64 {
93 match self {
94 Self::Pca(p) => p.explained_variance_ratio(),
95 Self::KernelPca(p) => p.explained_variance_ratio(),
96 Self::Laplacian(p) => p.explained_variance_ratio(),
97 Self::UmapSphere(p) => p.explained_variance_ratio(),
98 }
99 }
100
101 pub fn as_umap_sphere(&self) -> Option<&UmapSphereProjection> {
102 match self {
103 Self::UmapSphere(p) => Some(p),
104 _ => None,
105 }
106 }
107
108 pub fn as_pca(&self) -> Option<&PcaProjection> {
109 match self {
110 Self::Pca(p) => Some(p),
111 _ => None,
112 }
113 }
114
115 pub fn as_kernel_pca(&self) -> Option<&KernelPcaProjection> {
116 match self {
117 Self::KernelPca(p) => Some(p),
118 _ => None,
119 }
120 }
121
122 pub fn as_laplacian(&self) -> Option<&LaplacianEigenmapProjection> {
123 match self {
124 Self::Laplacian(p) => Some(p),
125 _ => None,
126 }
127 }
128}
129
130impl From<PcaProjection> for ConfiguredProjection {
131 fn from(p: PcaProjection) -> Self {
132 Self::Pca(p)
133 }
134}
135
136impl From<KernelPcaProjection> for ConfiguredProjection {
137 fn from(p: KernelPcaProjection) -> Self {
138 Self::KernelPca(p)
139 }
140}
141
142impl From<LaplacianEigenmapProjection> for ConfiguredProjection {
143 fn from(p: LaplacianEigenmapProjection) -> Self {
144 Self::Laplacian(p)
145 }
146}
147
148impl From<UmapSphereProjection> for ConfiguredProjection {
149 fn from(p: UmapSphereProjection) -> Self {
150 Self::UmapSphere(p)
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::RadialStrategy;

    /// Builds an [`Embedding`] from a slice of components.
    fn emb(vals: &[f64]) -> Embedding {
        Embedding::new(vals.to_vec())
    }

    /// Eight 5-component embeddings: the first two components drift linearly
    /// with the index, the remaining three are constant.
    fn toy_corpus() -> Vec<Embedding> {
        (0..8)
            .map(|i| {
                let t = i as f64;
                emb(&[1.0 + t * 0.01, 0.5 - t * 0.01, 0.2, 0.05, 0.03])
            })
            .collect()
    }

    #[test]
    fn pca_variant_dispatches() {
        let corpus = toy_corpus();
        let pca = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
        let cp: ConfiguredProjection = pca.into();
        assert_eq!(cp.kind(), ProjectionKind::Pca);
        assert_eq!(cp.dimensionality(), 5);
        // The projection was fitted with `RadialStrategy::Fixed(1.0)`, so the
        // projected radius is expected to be 1.
        let sp = cp.project(&corpus[0]);
        assert!((sp.r - 1.0).abs() < 1e-9);
        // Only the matching accessor may return `Some`.
        assert!(cp.as_pca().is_some());
        assert!(cp.as_kernel_pca().is_none());
        assert!(cp.as_laplacian().is_none());
        assert!(cp.as_umap_sphere().is_none());
    }

    #[test]
    fn kernel_pca_variant_dispatches() {
        let corpus = toy_corpus();
        let kpca = KernelPcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
        let cp: ConfiguredProjection = kpca.into();
        assert_eq!(cp.kind(), ProjectionKind::KernelPca);
        assert_eq!(cp.dimensionality(), 5);
        // Only the matching accessor may return `Some`.
        assert!(cp.as_kernel_pca().is_some());
        assert!(cp.as_pca().is_none());
        assert!(cp.as_laplacian().is_none());
        assert!(cp.as_umap_sphere().is_none());
    }

    #[test]
    fn laplacian_variant_dispatches() {
        let corpus = toy_corpus();
        let lap = LaplacianEigenmapProjection::fit(&corpus, RadialStrategy::Fixed(1.0)).unwrap();
        let cp: ConfiguredProjection = lap.into();
        assert_eq!(cp.kind(), ProjectionKind::LaplacianEigenmap);
        assert_eq!(cp.dimensionality(), 5);
        // Only the matching accessor may return `Some`.
        assert!(cp.as_laplacian().is_some());
        assert!(cp.as_pca().is_none());
        assert!(cp.as_kernel_pca().is_none());
        assert!(cp.as_umap_sphere().is_none());
    }

    // NOTE(review): the `UmapSphere` variant is not fitted anywhere in this
    // module, so it is not part of the loop below despite the test's name.
    #[test]
    fn explained_variance_ratio_in_range_for_every_variant() {
        let corpus = toy_corpus();
        let pca: ConfiguredProjection = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0))
            .unwrap()
            .into();
        let kpca: ConfiguredProjection =
            KernelPcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0))
                .unwrap()
                .into();
        let lap: ConfiguredProjection =
            LaplacianEigenmapProjection::fit(&corpus, RadialStrategy::Fixed(1.0))
                .unwrap()
                .into();
        for cp in &[pca, kpca, lap] {
            let r = cp.explained_variance_ratio();
            assert!((0.0..=1.0).contains(&r), "{cp:?}: {r}");
        }
    }

    #[test]
    fn debug_formats_kind_not_inner() {
        let corpus = toy_corpus();
        let pca: ConfiguredProjection = PcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0))
            .unwrap()
            .into();
        assert_eq!(format!("{pca:?}"), "ConfiguredProjection::Pca");

        // Cover the other variants that can be fitted here, so a typo in any
        // one `Debug` arm is caught.
        let kpca: ConfiguredProjection =
            KernelPcaProjection::fit(&corpus, RadialStrategy::Fixed(1.0))
                .unwrap()
                .into();
        assert_eq!(format!("{kpca:?}"), "ConfiguredProjection::KernelPca");

        let lap: ConfiguredProjection =
            LaplacianEigenmapProjection::fit(&corpus, RadialStrategy::Fixed(1.0))
                .unwrap()
                .into();
        assert_eq!(format!("{lap:?}"), "ConfiguredProjection::Laplacian");
    }
}