Skip to main content

sphereql_embed/
types.rs

1use std::sync::Arc;
2
3use sphereql_core::SphericalPoint;
4
/// A high-dimensional embedding vector.
///
/// Wraps a `Vec<f64>` with projection-oriented helpers: `dimension`,
/// `magnitude`, `normalized` (L2 unit vector with a `[1, 0, …]` fallback for
/// the zero vector), and `From<Vec<f64>>` / `From<&[f64]>` for ergonomic
/// construction. Equality is component-wise, so an embedding containing
/// `NaN` never compares equal to anything (including itself).
///
/// All projection families normalize embeddings to the unit hypersphere
/// before extracting angular structure, so raw magnitude is preserved
/// separately as the radial coordinate via [`RadialStrategy`].
#[derive(Debug, Clone, PartialEq)]
pub struct Embedding {
    /// The raw, un-normalized components.
    pub values: Vec<f64>,
}
18
19/// A projected point on the sphere with rich attributes from the projection.
20///
21/// Extends the raw `SphericalPoint` with metadata that captures how much
22/// information was preserved (or lost) during dimensionality reduction.
23#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
24pub struct ProjectedPoint {
25    /// The spherical position (r, theta, phi).
26    pub position: SphericalPoint,
27    /// How well the 3D projection captures this point's original direction.
28    /// Computed as 1 - (residual / total variance). Range [0, 1]:
29    /// - 1.0: perfect reconstruction (all variance explained by 3 PCA components)
30    /// - 0.0: the projection lost everything
31    pub certainty: f64,
32    /// Semantic strength of the original embedding (pre-normalization magnitude).
33    /// Higher values indicate more specific/confident embeddings.
34    pub intensity: f64,
35    /// Magnitude of the 3-component PCA projection before normalization.
36    /// Points near the PCA centroid have low projection magnitude and are
37    /// ambiguous — they don't strongly align with any principal direction.
38    pub projection_magnitude: f64,
39}
40
41impl ProjectedPoint {
42    pub fn new(
43        position: SphericalPoint,
44        certainty: f64,
45        intensity: f64,
46        projection_magnitude: f64,
47    ) -> Self {
48        Self {
49            position,
50            certainty,
51            intensity,
52            projection_magnitude,
53        }
54    }
55
56    /// Create a basic projected point with no metadata (legacy compat).
57    pub fn from_position(position: SphericalPoint, intensity: f64) -> Self {
58        Self {
59            position,
60            certainty: 1.0,
61            intensity,
62            projection_magnitude: 1.0,
63        }
64    }
65}
66
67impl Embedding {
68    pub fn new(values: Vec<f64>) -> Self {
69        Self { values }
70    }
71
72    pub fn dimension(&self) -> usize {
73        self.values.len()
74    }
75
76    pub fn magnitude(&self) -> f64 {
77        self.values.iter().map(|v| v * v).sum::<f64>().sqrt()
78    }
79
80    pub fn normalized(&self) -> Vec<f64> {
81        let mag = self.magnitude();
82        if mag < f64::EPSILON {
83            let mut v = vec![0.0; self.values.len()];
84            if !v.is_empty() {
85                v[0] = 1.0;
86            }
87            return v;
88        }
89        self.values.iter().map(|v| v / mag).collect()
90    }
91}
92
93impl From<Vec<f64>> for Embedding {
94    fn from(values: Vec<f64>) -> Self {
95        Self { values }
96    }
97}
98
99impl From<&[f64]> for Embedding {
100    fn from(values: &[f64]) -> Self {
101        Self {
102            values: values.to_vec(),
103        }
104    }
105}
106
/// Per-point information available when resolving the radial coordinate.
///
/// Modern L2-normalized embeddings make `embedding_magnitude` a constant
/// 1.0, wasting the radial axis. This struct exposes the additional
/// signals the projection naturally produces (post-projection magnitude,
/// per-point certainty) so a [`RadialStrategy`] can encode something
/// useful in `r` instead. Equality compares all three fields.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RadialContext {
    /// L2 norm of the raw input embedding (pre-normalization).
    pub embedding_magnitude: f64,
    /// L2 norm of the (x, y, z) projected vector before re-scaling.
    /// High values mean the 3 components captured most of the input's
    /// variance; low values mean the input fell mostly into the residual.
    pub projection_magnitude: f64,
    /// Fraction of input variance retained by the projection, in `[0, 1]`.
    /// Source depends on the projection family — PCA uses captured-variance
    /// ratio; KPCA uses Hoffmann's reconstruction-error formula;
    /// Laplacian uses tanh(projection_magnitude); Random reports 1.0.
    pub certainty: f64,
}
128
129impl RadialContext {
130    /// Construct from just the embedding magnitude. Other fields default
131    /// to neutral values; use [`Self::full`] when projection-side
132    /// information is available.
133    pub fn from_magnitude(embedding_magnitude: f64) -> Self {
134        Self {
135            embedding_magnitude,
136            projection_magnitude: embedding_magnitude,
137            certainty: 1.0,
138        }
139    }
140
141    /// Construct with all three signals populated.
142    pub fn full(embedding_magnitude: f64, projection_magnitude: f64, certainty: f64) -> Self {
143        Self {
144            embedding_magnitude,
145            projection_magnitude,
146            certainty,
147        }
148    }
149}
150
151/// Controls how the radial coordinate r is computed from an embedding.
152///
153/// The angular coordinates (theta, phi) always encode semantic direction.
154/// The radial coordinate is free to encode magnitude, fidelity, or a
155/// caller-defined function of any per-point signal the projection exposes
156/// via [`RadialContext`].
157#[derive(Default)]
158pub enum RadialStrategy {
159    /// Constant radius for all projections.
160    Fixed(f64),
161    /// r = L2 magnitude of the raw (pre-normalization) embedding.
162    /// Encodes embedding "confidence" or specificity. Degenerates to a
163    /// constant when inputs are L2-normalized — pick one of the
164    /// projection-side variants below in that case.
165    #[default]
166    Magnitude,
167    /// r = f(embedding_magnitude). Apply a custom transform to the
168    /// pre-normalization magnitude (e.g. log scaling, clamping).
169    MagnitudeTransform(Arc<dyn Fn(f64) -> f64 + Send + Sync>),
170    /// r = ‖(x, y, z)‖ — how much of the input variance landed in the
171    /// projected 3-vector. Universal across all four projection families.
172    /// Recommended starting point for normalized embeddings.
173    ProjectionMagnitude,
174    /// r = `scale * certainty`, where `certainty ∈ [0, 1]` is the
175    /// projection-supplied per-point fidelity score. Higher r ⇒ this
176    /// point is well-explained by the 3D projection.
177    Certainty { scale: f64 },
178    /// r = f(&context). Escape hatch for arbitrary per-point logic over
179    /// any combination of the [`RadialContext`] signals.
180    Custom(Arc<dyn Fn(&RadialContext) -> f64 + Send + Sync>),
181}
182
183impl Clone for RadialStrategy {
184    fn clone(&self) -> Self {
185        match self {
186            Self::Fixed(r) => Self::Fixed(*r),
187            Self::Magnitude => Self::Magnitude,
188            Self::MagnitudeTransform(f) => Self::MagnitudeTransform(Arc::clone(f)),
189            Self::ProjectionMagnitude => Self::ProjectionMagnitude,
190            Self::Certainty { scale } => Self::Certainty { scale: *scale },
191            Self::Custom(f) => Self::Custom(Arc::clone(f)),
192        }
193    }
194}
195
196impl std::fmt::Debug for RadialStrategy {
197    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
198        match self {
199            Self::Fixed(r) => write!(f, "Fixed({r})"),
200            Self::Magnitude => write!(f, "Magnitude"),
201            Self::MagnitudeTransform(_) => write!(f, "MagnitudeTransform(<fn>)"),
202            Self::ProjectionMagnitude => write!(f, "ProjectionMagnitude"),
203            Self::Certainty { scale } => write!(f, "Certainty {{ scale: {scale} }}"),
204            Self::Custom(_) => write!(f, "Custom(<fn>)"),
205        }
206    }
207}
208
209impl RadialStrategy {
210    /// Resolve `r` against the full per-point context. All four projection
211    /// families construct a [`RadialContext`] inline and call this.
212    pub fn compute_rich(&self, ctx: &RadialContext) -> f64 {
213        match self {
214            Self::Fixed(r) => *r,
215            Self::Magnitude => ctx.embedding_magnitude,
216            Self::MagnitudeTransform(f) => f(ctx.embedding_magnitude),
217            Self::ProjectionMagnitude => ctx.projection_magnitude,
218            Self::Certainty { scale } => scale * ctx.certainty,
219            Self::Custom(f) => f(ctx),
220        }
221    }
222
223    /// Backward-compatible shim. Use [`Self::compute_rich`] when
224    /// projection-side context is available — the new `ProjectionMagnitude`,
225    /// `Certainty`, and `Custom` variants only do something useful there.
226    pub fn compute(&self, magnitude: f64) -> f64 {
227        self.compute_rich(&RadialContext::from_magnitude(magnitude))
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn embedding_magnitude() {
        let e = Embedding::new(vec![3.0, 4.0]);
        assert!((e.magnitude() - 5.0).abs() < 1e-12);
    }

    #[test]
    fn embedding_normalized() {
        let e = Embedding::new(vec![3.0, 4.0]);
        let n = e.normalized();
        assert!((n[0] - 0.6).abs() < 1e-12);
        assert!((n[1] - 0.8).abs() < 1e-12);
    }

    #[test]
    fn zero_embedding_normalized_fallback() {
        let e = Embedding::new(vec![0.0, 0.0, 0.0]);
        let n = e.normalized();
        assert!((n[0] - 1.0).abs() < 1e-12);
        assert!(n[1].abs() < 1e-12);
        assert!(n[2].abs() < 1e-12);
    }

    #[test]
    fn empty_embedding() {
        // An empty embedding has zero magnitude and normalizes to an empty vector
        // (there is no first component to set in the fallback).
        let e = Embedding::new(Vec::new());
        assert_eq!(e.dimension(), 0);
        assert!(e.magnitude().abs() < 1e-12);
        assert!(e.normalized().is_empty());
    }

    #[test]
    fn from_vec() {
        let e: Embedding = vec![1.0, 2.0, 3.0].into();
        assert_eq!(e.dimension(), 3);
    }

    #[test]
    fn from_slice() {
        let data = [1.0, 2.0, 3.0];
        let e: Embedding = data.as_slice().into();
        assert_eq!(e.dimension(), 3);
    }

    #[test]
    fn radial_default_is_magnitude() {
        assert!(matches!(RadialStrategy::default(), RadialStrategy::Magnitude));
    }

    #[test]
    fn radial_context_from_magnitude_mirrors_magnitude() {
        let ctx = RadialContext::from_magnitude(4.0);
        assert!((ctx.embedding_magnitude - 4.0).abs() < 1e-12);
        assert!((ctx.projection_magnitude - 4.0).abs() < 1e-12);
        assert!((ctx.certainty - 1.0).abs() < 1e-12);
    }

    #[test]
    fn radial_fixed() {
        let r = RadialStrategy::Fixed(2.5);
        assert!((r.compute(999.0) - 2.5).abs() < 1e-12);
    }

    #[test]
    fn radial_magnitude() {
        let r = RadialStrategy::Magnitude;
        assert!((r.compute(7.0) - 7.0).abs() < 1e-12);
    }

    #[test]
    fn radial_transform() {
        let r = RadialStrategy::MagnitudeTransform(Arc::new(|m| m.ln_1p()));
        let expected = 5.0_f64.ln_1p();
        assert!((r.compute(5.0) - expected).abs() < 1e-12);
    }

    #[test]
    fn radial_clone() {
        let r = RadialStrategy::MagnitudeTransform(Arc::new(|m| m * 2.0));
        let r2 = r.clone();
        assert!((r.compute(3.0) - r2.compute(3.0)).abs() < 1e-12);
    }

    #[test]
    fn radial_clone_custom_shares_closure() {
        let r = RadialStrategy::Custom(Arc::new(|c: &RadialContext| c.certainty * 10.0));
        let r2 = r.clone();
        let ctx = RadialContext::full(0.0, 0.0, 0.25);
        assert!((r.compute_rich(&ctx) - r2.compute_rich(&ctx)).abs() < 1e-12);
        assert!((r2.compute_rich(&ctx) - 2.5).abs() < 1e-12);
    }

    #[test]
    fn radial_projection_magnitude_uses_context() {
        let r = RadialStrategy::ProjectionMagnitude;
        let ctx = RadialContext::full(99.0, 0.42, 0.5);
        assert!((r.compute_rich(&ctx) - 0.42).abs() < 1e-12);
    }

    #[test]
    fn radial_certainty_scales() {
        let r = RadialStrategy::Certainty { scale: 2.0 };
        let ctx = RadialContext::full(99.0, 0.42, 0.3);
        assert!((r.compute_rich(&ctx) - 0.6).abs() < 1e-12);
    }

    #[test]
    fn radial_custom_sees_full_context() {
        let r = RadialStrategy::Custom(Arc::new(|c| {
            c.embedding_magnitude + c.projection_magnitude + c.certainty
        }));
        let ctx = RadialContext::full(1.0, 2.0, 0.5);
        assert!((r.compute_rich(&ctx) - 3.5).abs() < 1e-12);
    }

    #[test]
    fn radial_compute_shim_is_backward_compatible() {
        // The backward-compatible `compute(magnitude)` shim must still produce
        // the same value as before for the original three variants.
        assert!((RadialStrategy::Fixed(7.0).compute(123.0) - 7.0).abs() < 1e-12);
        assert!((RadialStrategy::Magnitude.compute(7.0) - 7.0).abs() < 1e-12);
        let xform = RadialStrategy::MagnitudeTransform(Arc::new(|m| m * m));
        assert!((xform.compute(3.0) - 9.0).abs() < 1e-12);
    }

    #[test]
    fn radial_compute_shim_for_projection_side_variants() {
        // With only a magnitude available, `from_magnitude` mirrors it into
        // `projection_magnitude` and pins certainty to 1.0.
        assert!((RadialStrategy::ProjectionMagnitude.compute(4.0) - 4.0).abs() < 1e-12);
        assert!((RadialStrategy::Certainty { scale: 3.0 }.compute(100.0) - 3.0).abs() < 1e-12);
    }

    #[test]
    fn radial_debug() {
        assert_eq!(format!("{:?}", RadialStrategy::Fixed(1.0)), "Fixed(1)");
        assert_eq!(format!("{:?}", RadialStrategy::Magnitude), "Magnitude");
        let t = RadialStrategy::MagnitudeTransform(Arc::new(|m| m));
        assert_eq!(format!("{t:?}"), "MagnitudeTransform(<fn>)");
        assert_eq!(
            format!("{:?}", RadialStrategy::ProjectionMagnitude),
            "ProjectionMagnitude"
        );
        assert_eq!(
            format!("{:?}", RadialStrategy::Certainty { scale: 2.0 }),
            "Certainty { scale: 2 }"
        );
        let c = RadialStrategy::Custom(Arc::new(|_: &RadialContext| 0.0));
        assert_eq!(format!("{c:?}"), "Custom(<fn>)");
    }
}