1use std::sync::Arc;
2
3use sphereql_core::SphericalPoint;
4
/// A dense vector of `f64` components.
///
/// The components are stored exactly as supplied; no normalization or
/// validation happens on construction.
#[derive(Clone, Debug)]
pub struct Embedding {
    /// Raw component values, one entry per dimension.
    pub values: Vec<f64>,
}
9
10#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
15pub struct ProjectedPoint {
16 pub position: SphericalPoint,
18 pub certainty: f64,
23 pub intensity: f64,
26 pub projection_magnitude: f64,
30}
31
32impl ProjectedPoint {
33 pub fn new(
34 position: SphericalPoint,
35 certainty: f64,
36 intensity: f64,
37 projection_magnitude: f64,
38 ) -> Self {
39 Self {
40 position,
41 certainty,
42 intensity,
43 projection_magnitude,
44 }
45 }
46
47 pub fn from_position(position: SphericalPoint, intensity: f64) -> Self {
49 Self {
50 position,
51 certainty: 1.0,
52 intensity,
53 projection_magnitude: 1.0,
54 }
55 }
56}
57
58impl Embedding {
59 pub fn new(values: Vec<f64>) -> Self {
60 Self { values }
61 }
62
63 pub fn dimension(&self) -> usize {
64 self.values.len()
65 }
66
67 pub fn magnitude(&self) -> f64 {
68 self.values.iter().map(|v| v * v).sum::<f64>().sqrt()
69 }
70
71 pub fn normalized(&self) -> Vec<f64> {
72 let mag = self.magnitude();
73 if mag < f64::EPSILON {
74 let mut v = vec![0.0; self.values.len()];
75 if !v.is_empty() {
76 v[0] = 1.0;
77 }
78 return v;
79 }
80 self.values.iter().map(|v| v / mag).collect()
81 }
82}
83
84impl From<Vec<f64>> for Embedding {
85 fn from(values: Vec<f64>) -> Self {
86 Self { values }
87 }
88}
89
90impl From<&[f64]> for Embedding {
91 fn from(values: &[f64]) -> Self {
92 Self {
93 values: values.to_vec(),
94 }
95 }
96}
97
/// Rule that maps an input magnitude to a radius; see
/// [`RadialStrategy::compute`].
#[derive(Default)]
pub enum RadialStrategy {
    /// Ignore the input and always use this radius.
    Fixed(f64),
    /// Use the input magnitude unchanged. This is the default variant.
    #[default]
    Magnitude,
    /// Apply a caller-supplied function to the input magnitude.
    ///
    /// The function is held in an `Arc`, so clones share the same closure.
    MagnitudeTransform(Arc<dyn Fn(f64) -> f64 + Send + Sync>),
}

impl Clone for RadialStrategy {
    // Written by hand because `dyn Fn` is not `Clone`: the transform variant
    // shares its closure by bumping the `Arc` refcount instead of copying it.
    fn clone(&self) -> Self {
        match *self {
            RadialStrategy::Fixed(radius) => RadialStrategy::Fixed(radius),
            RadialStrategy::Magnitude => RadialStrategy::Magnitude,
            RadialStrategy::MagnitudeTransform(ref func) => {
                RadialStrategy::MagnitudeTransform(Arc::clone(func))
            }
        }
    }
}

impl std::fmt::Debug for RadialStrategy {
    // Closures have no `Debug` impl, so the transform variant prints a
    // fixed placeholder. `Fixed` uses the `Display` form of its radius.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            RadialStrategy::Fixed(radius) => write!(f, "Fixed({radius})"),
            RadialStrategy::Magnitude => f.write_str("Magnitude"),
            RadialStrategy::MagnitudeTransform(_) => f.write_str("MagnitudeTransform(<fn>)"),
        }
    }
}

impl RadialStrategy {
    /// Maps `magnitude` to a radius according to the selected variant.
    ///
    /// - `Fixed(r)` returns `r` regardless of the input.
    /// - `Magnitude` returns the input unchanged.
    /// - `MagnitudeTransform(f)` returns `f(magnitude)`.
    ///
    /// The result is not clamped or validated.
    pub fn compute(&self, magnitude: f64) -> f64 {
        match *self {
            RadialStrategy::Fixed(radius) => radius,
            RadialStrategy::Magnitude => magnitude,
            RadialStrategy::MagnitudeTransform(ref transform) => transform(magnitude),
        }
    }
}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance float comparison shared by the tests below.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-12
    }

    #[test]
    fn embedding_magnitude() {
        let embedding = Embedding::new(vec![3.0, 4.0]);
        assert!(close(embedding.magnitude(), 5.0));
    }

    #[test]
    fn embedding_normalized() {
        let unit = Embedding::new(vec![3.0, 4.0]).normalized();
        assert!(close(unit[0], 0.6));
        assert!(close(unit[1], 0.8));
    }

    #[test]
    fn zero_embedding_normalized_fallback() {
        let unit = Embedding::new(vec![0.0; 3]).normalized();
        assert!(close(unit[0], 1.0));
        assert!(close(unit[1], 0.0));
        assert!(close(unit[2], 0.0));
    }

    #[test]
    fn from_vec() {
        let embedding = Embedding::from(vec![1.0, 2.0, 3.0]);
        assert_eq!(embedding.dimension(), 3);
    }

    #[test]
    fn from_slice() {
        let raw = [1.0, 2.0, 3.0];
        let embedding = Embedding::from(&raw[..]);
        assert_eq!(embedding.dimension(), 3);
    }

    #[test]
    fn radial_fixed() {
        let strategy = RadialStrategy::Fixed(2.5);
        assert!(close(strategy.compute(999.0), 2.5));
    }

    #[test]
    fn radial_magnitude() {
        let strategy = RadialStrategy::Magnitude;
        assert!(close(strategy.compute(7.0), 7.0));
    }

    #[test]
    fn radial_transform() {
        let strategy = RadialStrategy::MagnitudeTransform(Arc::new(|m: f64| m.ln_1p()));
        assert!(close(strategy.compute(5.0), 5.0_f64.ln_1p()));
    }

    #[test]
    fn radial_clone() {
        let original = RadialStrategy::MagnitudeTransform(Arc::new(|m: f64| m * 2.0));
        let duplicate = original.clone();
        assert!(close(original.compute(3.0), duplicate.compute(3.0)));
    }

    #[test]
    fn radial_debug() {
        assert_eq!(format!("{:?}", RadialStrategy::Fixed(1.0)), "Fixed(1)");
        assert_eq!(format!("{:?}", RadialStrategy::Magnitude), "Magnitude");
        let transform = RadialStrategy::MagnitudeTransform(Arc::new(|m: f64| m));
        assert_eq!(format!("{transform:?}"), "MagnitudeTransform(<fn>)");
    }
}