1use schemars::JsonSchema;
4use serde::{Deserialize, Serialize};
5
6use crate::error::{Result, TurboQuantError};
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
9pub enum RadiusCodecProfileV1 {
10 F32,
11 BlockLinearU16,
12 BlockLogU8,
13}
14
15#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
16pub struct CompressedRadiiV1 {
17 pub profile: RadiusCodecProfileV1,
18 pub count: usize,
19 pub min: f32,
20 pub max: f32,
21 pub payload: Vec<u8>,
22}
23
24impl CompressedRadiiV1 {
25 pub fn compress(radii: &[f32], profile: RadiusCodecProfileV1) -> Result<Self> {
26 validate_radii(radii)?;
27 match profile {
28 RadiusCodecProfileV1::F32 => {
29 let mut payload = Vec::with_capacity(radii.len() * 4);
30 for radius in radii {
31 payload.extend_from_slice(&radius.to_le_bytes());
32 }
33 Ok(Self {
34 profile,
35 count: radii.len(),
36 min: 0.0,
37 max: 0.0,
38 payload,
39 })
40 }
41 RadiusCodecProfileV1::BlockLinearU16 => {
42 let (min, max) = min_max(radii);
43 let mut payload = Vec::with_capacity(radii.len() * 2);
44 let span = (max - min).max(f32::EPSILON);
45 for radius in radii {
46 let normalized = ((*radius - min) / span).clamp(0.0, 1.0);
47 let quantized = (normalized * u16::MAX as f32).round() as u16;
48 payload.extend_from_slice(&quantized.to_le_bytes());
49 }
50 Ok(Self {
51 profile,
52 count: radii.len(),
53 min,
54 max,
55 payload,
56 })
57 }
58 RadiusCodecProfileV1::BlockLogU8 => {
59 let logged = radii
60 .iter()
61 .map(|value| value.max(f32::MIN_POSITIVE).ln())
62 .collect::<Vec<f32>>();
63 let (min, max) = min_max(&logged);
64 let mut payload = Vec::with_capacity(radii.len());
65 let span = (max - min).max(f32::EPSILON);
66 for value in &logged {
67 let normalized = ((*value - min) / span).clamp(0.0, 1.0);
68 payload.push((normalized * u8::MAX as f32).round() as u8);
69 }
70 Ok(Self {
71 profile,
72 count: radii.len(),
73 min,
74 max,
75 payload,
76 })
77 }
78 }
79 }
80
81 pub fn decompress(&self) -> Result<Vec<f32>> {
82 match self.profile {
83 RadiusCodecProfileV1::F32 => {
84 if self.payload.len() != self.count * 4 {
85 return Err(TurboQuantError::MalformedCode {
86 reason: format!(
87 "f32 radius payload has {} bytes, expected {}",
88 self.payload.len(),
89 self.count * 4
90 ),
91 });
92 }
93 self.payload
94 .chunks_exact(4)
95 .enumerate()
96 .map(|(index, chunk)| {
97 let value = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
98 if !value.is_finite() || value < 0.0 {
99 return Err(TurboQuantError::MalformedCode {
100 reason: format!("radius {index} is not finite and non-negative"),
101 });
102 }
103 Ok(value)
104 })
105 .collect()
106 }
107 RadiusCodecProfileV1::BlockLinearU16 => {
108 if self.payload.len() != self.count * 2 {
109 return Err(TurboQuantError::MalformedCode {
110 reason: format!(
111 "linear u16 radius payload has {} bytes, expected {}",
112 self.payload.len(),
113 self.count * 2
114 ),
115 });
116 }
117 validate_range(self.min, self.max)?;
118 let span = self.max - self.min;
119 Ok(self
120 .payload
121 .chunks_exact(2)
122 .map(|chunk| {
123 let value = u16::from_le_bytes([chunk[0], chunk[1]]) as f32;
124 self.min + span * (value / u16::MAX as f32)
125 })
126 .collect())
127 }
128 RadiusCodecProfileV1::BlockLogU8 => {
129 if self.payload.len() != self.count {
130 return Err(TurboQuantError::MalformedCode {
131 reason: format!(
132 "log u8 radius payload has {} bytes, expected {}",
133 self.payload.len(),
134 self.count
135 ),
136 });
137 }
138 validate_range(self.min, self.max)?;
139 let span = self.max - self.min;
140 Ok(self
141 .payload
142 .iter()
143 .map(|value| (self.min + span * (*value as f32 / u8::MAX as f32)).exp())
144 .collect())
145 }
146 }
147 }
148
149 pub fn encoded_bytes(&self) -> usize {
150 match self.profile {
151 RadiusCodecProfileV1::F32 => self.payload.len(),
152 RadiusCodecProfileV1::BlockLinearU16 | RadiusCodecProfileV1::BlockLogU8 => {
153 self.payload.len() + 8
154 }
155 }
156 }
157}
158
159fn validate_radii(radii: &[f32]) -> Result<()> {
160 for (index, radius) in radii.iter().enumerate() {
161 if !radius.is_finite() || *radius < 0.0 {
162 return Err(TurboQuantError::MalformedCode {
163 reason: format!("radius {index} is not finite and non-negative"),
164 });
165 }
166 }
167 Ok(())
168}
169
170fn validate_range(min: f32, max: f32) -> Result<()> {
171 if !min.is_finite() || !max.is_finite() || min > max {
172 return Err(TurboQuantError::MalformedCode {
173 reason: "radius codec range is malformed".into(),
174 });
175 }
176 Ok(())
177}
178
/// Returns `(minimum, maximum)` of `values`, or `(0.0, 0.0)` for an empty slice.
///
/// Uses `f32::min`/`f32::max`, which prefer the non-NaN operand.
fn min_max(values: &[f32]) -> (f32, f32) {
    if values.is_empty() {
        return (0.0, 0.0);
    }
    let mut lowest = f32::INFINITY;
    let mut highest = f32::NEG_INFINITY;
    for &value in values {
        lowest = lowest.min(value);
        highest = highest.max(value);
    }
    (lowest, highest)
}