1use std::sync::Arc;
2
3use sphereql_core::SphericalPoint;
4
/// A dense embedding vector of `f64` components.
///
/// Only the raw values are stored; the norm and the unit-length direction are
/// computed on demand by the methods on this type.
///
/// `PartialEq` is derived (element-wise `f64` comparison), so two embeddings
/// containing a NaN component never compare equal; `Eq`/`Hash` are therefore
/// intentionally not derived.
#[derive(Debug, Clone, PartialEq)]
pub struct Embedding {
    /// Components of the embedding, in order.
    pub values: Vec<f64>,
}
18
19#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
24pub struct ProjectedPoint {
25 pub position: SphericalPoint,
27 pub certainty: f64,
32 pub intensity: f64,
35 pub projection_magnitude: f64,
39}
40
41impl ProjectedPoint {
42 pub fn new(
43 position: SphericalPoint,
44 certainty: f64,
45 intensity: f64,
46 projection_magnitude: f64,
47 ) -> Self {
48 Self {
49 position,
50 certainty,
51 intensity,
52 projection_magnitude,
53 }
54 }
55
56 pub fn from_position(position: SphericalPoint, intensity: f64) -> Self {
58 Self {
59 position,
60 certainty: 1.0,
61 intensity,
62 projection_magnitude: 1.0,
63 }
64 }
65}
66
67impl Embedding {
68 pub fn new(values: Vec<f64>) -> Self {
69 Self { values }
70 }
71
72 pub fn dimension(&self) -> usize {
73 self.values.len()
74 }
75
76 pub fn magnitude(&self) -> f64 {
77 self.values.iter().map(|v| v * v).sum::<f64>().sqrt()
78 }
79
80 pub fn normalized(&self) -> Vec<f64> {
81 let mag = self.magnitude();
82 if mag < f64::EPSILON {
83 let mut v = vec![0.0; self.values.len()];
84 if !v.is_empty() {
85 v[0] = 1.0;
86 }
87 return v;
88 }
89 self.values.iter().map(|v| v / mag).collect()
90 }
91}
92
93impl From<Vec<f64>> for Embedding {
94 fn from(values: Vec<f64>) -> Self {
95 Self { values }
96 }
97}
98
99impl From<&[f64]> for Embedding {
100 fn from(values: &[f64]) -> Self {
101 Self {
102 values: values.to_vec(),
103 }
104 }
105}
106
/// The values a [`RadialStrategy`] may consult when choosing a radius.
#[derive(Debug, Clone, Copy)]
pub struct RadialContext {
    /// Magnitude of the source embedding, as supplied by the caller.
    pub embedding_magnitude: f64,
    /// Magnitude associated with the projection.
    pub projection_magnitude: f64,
    /// Certainty associated with the projection.
    pub certainty: f64,
}
128
129impl RadialContext {
130 pub fn from_magnitude(embedding_magnitude: f64) -> Self {
134 Self {
135 embedding_magnitude,
136 projection_magnitude: embedding_magnitude,
137 certainty: 1.0,
138 }
139 }
140
141 pub fn full(embedding_magnitude: f64, projection_magnitude: f64, certainty: f64) -> Self {
143 Self {
144 embedding_magnitude,
145 projection_magnitude,
146 certainty,
147 }
148 }
149}
150
151#[derive(Default)]
158pub enum RadialStrategy {
159 Fixed(f64),
161 #[default]
166 Magnitude,
167 MagnitudeTransform(Arc<dyn Fn(f64) -> f64 + Send + Sync>),
170 ProjectionMagnitude,
174 Certainty { scale: f64 },
178 Custom(Arc<dyn Fn(&RadialContext) -> f64 + Send + Sync>),
181}
182
183impl Clone for RadialStrategy {
184 fn clone(&self) -> Self {
185 match self {
186 Self::Fixed(r) => Self::Fixed(*r),
187 Self::Magnitude => Self::Magnitude,
188 Self::MagnitudeTransform(f) => Self::MagnitudeTransform(Arc::clone(f)),
189 Self::ProjectionMagnitude => Self::ProjectionMagnitude,
190 Self::Certainty { scale } => Self::Certainty { scale: *scale },
191 Self::Custom(f) => Self::Custom(Arc::clone(f)),
192 }
193 }
194}
195
196impl std::fmt::Debug for RadialStrategy {
197 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
198 match self {
199 Self::Fixed(r) => write!(f, "Fixed({r})"),
200 Self::Magnitude => write!(f, "Magnitude"),
201 Self::MagnitudeTransform(_) => write!(f, "MagnitudeTransform(<fn>)"),
202 Self::ProjectionMagnitude => write!(f, "ProjectionMagnitude"),
203 Self::Certainty { scale } => write!(f, "Certainty {{ scale: {scale} }}"),
204 Self::Custom(_) => write!(f, "Custom(<fn>)"),
205 }
206 }
207}
208
209impl RadialStrategy {
210 pub fn compute_rich(&self, ctx: &RadialContext) -> f64 {
213 match self {
214 Self::Fixed(r) => *r,
215 Self::Magnitude => ctx.embedding_magnitude,
216 Self::MagnitudeTransform(f) => f(ctx.embedding_magnitude),
217 Self::ProjectionMagnitude => ctx.projection_magnitude,
218 Self::Certainty { scale } => scale * ctx.certainty,
219 Self::Custom(f) => f(ctx),
220 }
221 }
222
223 pub fn compute(&self, magnitude: f64) -> f64 {
227 self.compute_rich(&RadialContext::from_magnitude(magnitude))
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by every floating-point comparison below.
    const TOL: f64 = 1e-12;

    /// Asserts that `actual` lies within `TOL` of `expected`.
    ///
    /// On failure it prints both values, which the bare `assert!` form did
    /// not. `#[track_caller]` makes the panic point at the calling test line.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOL,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn embedding_magnitude() {
        let e = Embedding::new(vec![3.0, 4.0]);
        assert_close(e.magnitude(), 5.0);
    }

    #[test]
    fn embedding_normalized() {
        let e = Embedding::new(vec![3.0, 4.0]);
        let n = e.normalized();
        assert_close(n[0], 0.6);
        assert_close(n[1], 0.8);
    }

    #[test]
    fn zero_embedding_normalized_fallback() {
        let e = Embedding::new(vec![0.0, 0.0, 0.0]);
        let n = e.normalized();
        assert_close(n[0], 1.0);
        assert_close(n[1], 0.0);
        assert_close(n[2], 0.0);
    }

    #[test]
    fn empty_embedding_normalizes_to_empty() {
        let e = Embedding::new(Vec::new());
        assert_eq!(e.dimension(), 0);
        assert_close(e.magnitude(), 0.0);
        assert!(e.normalized().is_empty());
    }

    #[test]
    fn from_vec() {
        let e: Embedding = vec![1.0, 2.0, 3.0].into();
        assert_eq!(e.dimension(), 3);
    }

    #[test]
    fn from_slice() {
        let data = [1.0, 2.0, 3.0];
        let e: Embedding = data.as_slice().into();
        assert_eq!(e.dimension(), 3);
    }

    #[test]
    fn radial_default_is_magnitude() {
        assert!(matches!(RadialStrategy::default(), RadialStrategy::Magnitude));
    }

    #[test]
    fn radial_fixed() {
        let r = RadialStrategy::Fixed(2.5);
        assert_close(r.compute(999.0), 2.5);
    }

    #[test]
    fn radial_magnitude() {
        let r = RadialStrategy::Magnitude;
        assert_close(r.compute(7.0), 7.0);
    }

    #[test]
    fn radial_transform() {
        let r = RadialStrategy::MagnitudeTransform(Arc::new(|m| m.ln_1p()));
        let expected = 5.0_f64.ln_1p();
        assert_close(r.compute(5.0), expected);
    }

    #[test]
    fn radial_clone() {
        let r = RadialStrategy::MagnitudeTransform(Arc::new(|m| m * 2.0));
        let r2 = r.clone();
        assert_close(r.compute(3.0), r2.compute(3.0));
    }

    #[test]
    fn radial_projection_magnitude_uses_context() {
        let r = RadialStrategy::ProjectionMagnitude;
        let ctx = RadialContext::full(99.0, 0.42, 0.5);
        assert_close(r.compute_rich(&ctx), 0.42);
    }

    #[test]
    fn radial_certainty_scales() {
        let r = RadialStrategy::Certainty { scale: 2.0 };
        let ctx = RadialContext::full(99.0, 0.42, 0.3);
        assert_close(r.compute_rich(&ctx), 0.6);
    }

    #[test]
    fn radial_custom_sees_full_context() {
        let r = RadialStrategy::Custom(Arc::new(|c| {
            c.embedding_magnitude + c.projection_magnitude + c.certainty
        }));
        let ctx = RadialContext::full(1.0, 2.0, 0.5);
        assert_close(r.compute_rich(&ctx), 3.5);
    }

    #[test]
    fn radial_compute_shim_is_backward_compatible() {
        assert_close(RadialStrategy::Fixed(7.0).compute(123.0), 7.0);
        assert_close(RadialStrategy::Magnitude.compute(7.0), 7.0);
        let xform = RadialStrategy::MagnitudeTransform(Arc::new(|m| m * m));
        assert_close(xform.compute(3.0), 9.0);
    }

    #[test]
    fn radial_compute_shim_fills_context() {
        // `compute` builds its context with `RadialContext::from_magnitude`:
        // projection magnitude mirrors the input and certainty is 1.0.
        assert_close(RadialStrategy::ProjectionMagnitude.compute(4.0), 4.0);
        assert_close(RadialStrategy::Certainty { scale: 3.0 }.compute(100.0), 3.0);
        let custom = RadialStrategy::Custom(Arc::new(|c: &RadialContext| c.certainty));
        assert_close(custom.compute(5.0), 1.0);
    }

    #[test]
    fn radial_debug() {
        assert_eq!(format!("{:?}", RadialStrategy::Fixed(1.0)), "Fixed(1)");
        assert_eq!(format!("{:?}", RadialStrategy::Magnitude), "Magnitude");
        let t = RadialStrategy::MagnitudeTransform(Arc::new(|m| m));
        assert_eq!(format!("{t:?}"), "MagnitudeTransform(<fn>)");
    }

    #[test]
    fn radial_debug_remaining_variants() {
        assert_eq!(format!("{:?}", RadialStrategy::Fixed(2.5)), "Fixed(2.5)");
        assert_eq!(
            format!("{:?}", RadialStrategy::ProjectionMagnitude),
            "ProjectionMagnitude"
        );
        assert_eq!(
            format!("{:?}", RadialStrategy::Certainty { scale: 2.0 }),
            "Certainty { scale: 2 }"
        );
        let c = RadialStrategy::Custom(Arc::new(|c: &RadialContext| c.certainty));
        assert_eq!(format!("{c:?}"), "Custom(<fn>)");
    }
}