Skip to main content

turbo_quant/
radius.rs

1//! Radius compression profiles for packed sidecar payloads.
2
3use schemars::JsonSchema;
4use serde::{Deserialize, Serialize};
5
6use crate::error::{Result, TurboQuantError};
7
/// Selects how [`CompressedRadiiV1`] encodes a block of radii.
///
/// NOTE(review): this type derives `Serialize`/`Deserialize`/`JsonSchema`, so variant names
/// (and, for index-based serde formats, variant order) are presumably part of the persisted
/// format — confirm before renaming or reordering variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub enum RadiusCodecProfileV1 {
    /// Lossless: each radius is stored as its 4 little-endian `f32` bytes.
    F32,
    /// Each radius is linearly quantized to a little-endian `u16` over the block's `[min, max]`.
    BlockLinearU16,
    /// `ln(radius)` is linearly quantized to a `u8` over the block's log-domain `[min, max]`.
    BlockLogU8,
}
14
/// A block of radii encoded with a single [`RadiusCodecProfileV1`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct CompressedRadiiV1 {
    /// Codec that produced `payload`.
    pub profile: RadiusCodecProfileV1,
    /// Number of radii encoded in `payload`.
    // NOTE(review): a `usize` count may serialize with a platform-dependent width/range in some
    // serde formats — confirm this is acceptable for persisted payloads.
    pub count: usize,
    /// Lower bound of the quantization range: the smallest radius for `BlockLinearU16`, the
    /// smallest `ln(radius)` for `BlockLogU8`, and `0.0` (unused) for `F32`.
    pub min: f32,
    /// Upper bound of the quantization range, in the same domain as `min` (`0.0` for `F32`).
    pub max: f32,
    /// Encoded radii: 4, 2 or 1 byte(s) per radius for `F32`, `BlockLinearU16` and
    /// `BlockLogU8` respectively.
    pub payload: Vec<u8>,
}
23
24impl CompressedRadiiV1 {
25    pub fn compress(radii: &[f32], profile: RadiusCodecProfileV1) -> Result<Self> {
26        validate_radii(radii)?;
27        match profile {
28            RadiusCodecProfileV1::F32 => {
29                let mut payload = Vec::with_capacity(radii.len() * 4);
30                for radius in radii {
31                    payload.extend_from_slice(&radius.to_le_bytes());
32                }
33                Ok(Self {
34                    profile,
35                    count: radii.len(),
36                    min: 0.0,
37                    max: 0.0,
38                    payload,
39                })
40            }
41            RadiusCodecProfileV1::BlockLinearU16 => {
42                let (min, max) = min_max(radii);
43                let mut payload = Vec::with_capacity(radii.len() * 2);
44                let span = (max - min).max(f32::EPSILON);
45                for radius in radii {
46                    let normalized = ((*radius - min) / span).clamp(0.0, 1.0);
47                    let quantized = (normalized * u16::MAX as f32).round() as u16;
48                    payload.extend_from_slice(&quantized.to_le_bytes());
49                }
50                Ok(Self {
51                    profile,
52                    count: radii.len(),
53                    min,
54                    max,
55                    payload,
56                })
57            }
58            RadiusCodecProfileV1::BlockLogU8 => {
59                let logged = radii
60                    .iter()
61                    .map(|value| value.max(f32::MIN_POSITIVE).ln())
62                    .collect::<Vec<f32>>();
63                let (min, max) = min_max(&logged);
64                let mut payload = Vec::with_capacity(radii.len());
65                let span = (max - min).max(f32::EPSILON);
66                for value in &logged {
67                    let normalized = ((*value - min) / span).clamp(0.0, 1.0);
68                    payload.push((normalized * u8::MAX as f32).round() as u8);
69                }
70                Ok(Self {
71                    profile,
72                    count: radii.len(),
73                    min,
74                    max,
75                    payload,
76                })
77            }
78        }
79    }
80
81    pub fn decompress(&self) -> Result<Vec<f32>> {
82        match self.profile {
83            RadiusCodecProfileV1::F32 => {
84                if self.payload.len() != self.count * 4 {
85                    return Err(TurboQuantError::MalformedCode {
86                        reason: format!(
87                            "f32 radius payload has {} bytes, expected {}",
88                            self.payload.len(),
89                            self.count * 4
90                        ),
91                    });
92                }
93                self.payload
94                    .chunks_exact(4)
95                    .enumerate()
96                    .map(|(index, chunk)| {
97                        let value = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
98                        if !value.is_finite() || value < 0.0 {
99                            return Err(TurboQuantError::MalformedCode {
100                                reason: format!("radius {index} is not finite and non-negative"),
101                            });
102                        }
103                        Ok(value)
104                    })
105                    .collect()
106            }
107            RadiusCodecProfileV1::BlockLinearU16 => {
108                if self.payload.len() != self.count * 2 {
109                    return Err(TurboQuantError::MalformedCode {
110                        reason: format!(
111                            "linear u16 radius payload has {} bytes, expected {}",
112                            self.payload.len(),
113                            self.count * 2
114                        ),
115                    });
116                }
117                validate_range(self.min, self.max)?;
118                let span = self.max - self.min;
119                Ok(self
120                    .payload
121                    .chunks_exact(2)
122                    .map(|chunk| {
123                        let value = u16::from_le_bytes([chunk[0], chunk[1]]) as f32;
124                        self.min + span * (value / u16::MAX as f32)
125                    })
126                    .collect())
127            }
128            RadiusCodecProfileV1::BlockLogU8 => {
129                if self.payload.len() != self.count {
130                    return Err(TurboQuantError::MalformedCode {
131                        reason: format!(
132                            "log u8 radius payload has {} bytes, expected {}",
133                            self.payload.len(),
134                            self.count
135                        ),
136                    });
137                }
138                validate_range(self.min, self.max)?;
139                let span = self.max - self.min;
140                Ok(self
141                    .payload
142                    .iter()
143                    .map(|value| (self.min + span * (*value as f32 / u8::MAX as f32)).exp())
144                    .collect())
145            }
146        }
147    }
148
149    pub fn encoded_bytes(&self) -> usize {
150        match self.profile {
151            RadiusCodecProfileV1::F32 => self.payload.len(),
152            RadiusCodecProfileV1::BlockLinearU16 | RadiusCodecProfileV1::BlockLogU8 => {
153                self.payload.len() + 8
154            }
155        }
156    }
157}
158
159fn validate_radii(radii: &[f32]) -> Result<()> {
160    for (index, radius) in radii.iter().enumerate() {
161        if !radius.is_finite() || *radius < 0.0 {
162            return Err(TurboQuantError::MalformedCode {
163                reason: format!("radius {index} is not finite and non-negative"),
164            });
165        }
166    }
167    Ok(())
168}
169
170fn validate_range(min: f32, max: f32) -> Result<()> {
171    if !min.is_finite() || !max.is_finite() || min > max {
172        return Err(TurboQuantError::MalformedCode {
173            reason: "radius codec range is malformed".into(),
174        });
175    }
176    Ok(())
177}
178
/// Returns `(smallest, largest)` over `values`, or `(0.0, 0.0)` when `values` is empty.
///
/// Uses `f32::min`/`f32::max`, which ignore a NaN operand, so NaN entries are skipped rather
/// than propagated. In `compress` the inputs have already passed `validate_radii` (or are
/// logarithms of such radii).
fn min_max(values: &[f32]) -> (f32, f32) {
    match values {
        [] => (0.0, 0.0),
        _ => {
            let mut lowest = f32::INFINITY;
            let mut highest = f32::NEG_INFINITY;
            for &value in values {
                lowest = lowest.min(value);
                highest = highest.max(value);
            }
            (lowest, highest)
        }
    }
}