Skip to main content

sphereql_embed/
types.rs

1use std::sync::Arc;
2
3use sphereql_core::SphericalPoint;
4
/// A dense vector of `f64` components.
#[derive(Clone, Debug)]
pub struct Embedding {
    /// The components of the vector; its length is the embedding's dimension.
    pub values: Vec<f64>,
}
9
10/// A projected point on the sphere with rich attributes from the projection.
11///
12/// Extends the raw `SphericalPoint` with metadata that captures how much
13/// information was preserved (or lost) during dimensionality reduction.
14#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
15pub struct ProjectedPoint {
16    /// The spherical position (r, theta, phi).
17    pub position: SphericalPoint,
18    /// How well the 3D projection captures this point's original direction.
19    /// Computed as 1 - (residual / total variance). Range [0, 1]:
20    /// - 1.0: perfect reconstruction (all variance explained by 3 PCA components)
21    /// - 0.0: the projection lost everything
22    pub certainty: f64,
23    /// Semantic strength of the original embedding (pre-normalization magnitude).
24    /// Higher values indicate more specific/confident embeddings.
25    pub intensity: f64,
26    /// Magnitude of the 3-component PCA projection before normalization.
27    /// Points near the PCA centroid have low projection magnitude and are
28    /// ambiguous — they don't strongly align with any principal direction.
29    pub projection_magnitude: f64,
30}
31
32impl ProjectedPoint {
33    pub fn new(
34        position: SphericalPoint,
35        certainty: f64,
36        intensity: f64,
37        projection_magnitude: f64,
38    ) -> Self {
39        Self {
40            position,
41            certainty,
42            intensity,
43            projection_magnitude,
44        }
45    }
46
47    /// Create a basic projected point with no metadata (legacy compat).
48    pub fn from_position(position: SphericalPoint, intensity: f64) -> Self {
49        Self {
50            position,
51            certainty: 1.0,
52            intensity,
53            projection_magnitude: 1.0,
54        }
55    }
56}
57
58impl Embedding {
59    pub fn new(values: Vec<f64>) -> Self {
60        Self { values }
61    }
62
63    pub fn dimension(&self) -> usize {
64        self.values.len()
65    }
66
67    pub fn magnitude(&self) -> f64 {
68        self.values.iter().map(|v| v * v).sum::<f64>().sqrt()
69    }
70
71    pub fn normalized(&self) -> Vec<f64> {
72        let mag = self.magnitude();
73        if mag < f64::EPSILON {
74            let mut v = vec![0.0; self.values.len()];
75            if !v.is_empty() {
76                v[0] = 1.0;
77            }
78            return v;
79        }
80        self.values.iter().map(|v| v / mag).collect()
81    }
82}
83
84impl From<Vec<f64>> for Embedding {
85    fn from(values: Vec<f64>) -> Self {
86        Self { values }
87    }
88}
89
90impl From<&[f64]> for Embedding {
91    fn from(values: &[f64]) -> Self {
92        Self {
93            values: values.to_vec(),
94        }
95    }
96}
97
/// Selects how the radial coordinate r is derived for an embedding.
///
/// Semantic direction is always carried by the angular coordinates
/// (theta, phi); the radius is left free to carry magnitude, metadata, or a
/// fixed value.
#[derive(Default)]
pub enum RadialStrategy {
    /// Every projection gets the same, constant radius.
    Fixed(f64),
    /// r = L2 magnitude of the raw (pre-normalization) embedding.
    /// Encodes embedding "confidence" or specificity. This is the default.
    #[default]
    Magnitude,
    /// r = f(magnitude): a custom transform applied to the pre-normalization
    /// magnitude. Useful for log-scaling, clamping, or mapping metadata that
    /// correlates with magnitude.
    MagnitudeTransform(Arc<dyn Fn(f64) -> f64 + Send + Sync>),
}
114
115impl Clone for RadialStrategy {
116    fn clone(&self) -> Self {
117        match self {
118            Self::Fixed(r) => Self::Fixed(*r),
119            Self::Magnitude => Self::Magnitude,
120            Self::MagnitudeTransform(f) => Self::MagnitudeTransform(Arc::clone(f)),
121        }
122    }
123}
124
125impl std::fmt::Debug for RadialStrategy {
126    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127        match self {
128            Self::Fixed(r) => write!(f, "Fixed({r})"),
129            Self::Magnitude => write!(f, "Magnitude"),
130            Self::MagnitudeTransform(_) => write!(f, "MagnitudeTransform(<fn>)"),
131        }
132    }
133}
134
135impl RadialStrategy {
136    pub fn compute(&self, magnitude: f64) -> f64 {
137        match self {
138            Self::Fixed(r) => *r,
139            Self::Magnitude => magnitude,
140            Self::MagnitudeTransform(f) => f(magnitude),
141        }
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by the floating-point comparisons below.
    const TOL: f64 = 1e-12;

    /// True when `actual` is within `TOL` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOL
    }

    #[test]
    fn embedding_magnitude() {
        let embedding = Embedding::new(vec![3.0, 4.0]);
        assert!(close(embedding.magnitude(), 5.0));
    }

    #[test]
    fn embedding_normalized() {
        let unit = Embedding::new(vec![3.0, 4.0]).normalized();
        assert!(close(unit[0], 0.6));
        assert!(close(unit[1], 0.8));
    }

    #[test]
    fn zero_embedding_normalized_fallback() {
        // A zero vector has no direction; the fallback points along axis 0.
        let unit = Embedding::new(vec![0.0, 0.0, 0.0]).normalized();
        assert!(close(unit[0], 1.0));
        assert!(close(unit[1], 0.0));
        assert!(close(unit[2], 0.0));
    }

    #[test]
    fn from_vec() {
        let embedding: Embedding = vec![1.0, 2.0, 3.0].into();
        assert_eq!(embedding.dimension(), 3);
    }

    #[test]
    fn from_slice() {
        let data = [1.0, 2.0, 3.0];
        let embedding: Embedding = data.as_slice().into();
        assert_eq!(embedding.dimension(), 3);
    }

    #[test]
    fn radial_fixed() {
        let strategy = RadialStrategy::Fixed(2.5);
        assert!(close(strategy.compute(999.0), 2.5));
    }

    #[test]
    fn radial_magnitude() {
        let strategy = RadialStrategy::Magnitude;
        assert!(close(strategy.compute(7.0), 7.0));
    }

    #[test]
    fn radial_transform() {
        let strategy = RadialStrategy::MagnitudeTransform(Arc::new(|m| m.ln_1p()));
        let expected = 5.0_f64.ln_1p();
        assert!(close(strategy.compute(5.0), expected));
    }

    #[test]
    fn radial_clone() {
        let strategy = RadialStrategy::MagnitudeTransform(Arc::new(|m| m * 2.0));
        let copy = strategy.clone();
        assert!(close(strategy.compute(3.0), copy.compute(3.0)));
    }

    #[test]
    fn radial_debug() {
        assert_eq!(format!("{:?}", RadialStrategy::Fixed(1.0)), "Fixed(1)");
        assert_eq!(format!("{:?}", RadialStrategy::Magnitude), "Magnitude");
        let transform = RadialStrategy::MagnitudeTransform(Arc::new(|m| m));
        assert_eq!(format!("{transform:?}"), "MagnitudeTransform(<fn>)");
    }
}