1use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::Instant;
14
15use bytes::Bytes;
16
17use crate::{ChunkManifest, Codec, CodecError, CodecKind, CompressTelemetry, looks_like_oom};
18
19fn is_gpu_kind(kind: CodecKind) -> bool {
25 matches!(
26 kind,
27 CodecKind::NvcompZstd
28 | CodecKind::NvcompBitcomp
29 | CodecKind::NvcompGans
30 | CodecKind::NvcompGDeflate
31 | CodecKind::DietGpuAns
32 )
33}
34
35pub struct CodecRegistry {
37 codecs: HashMap<CodecKind, Arc<dyn Codec>>,
38 default: CodecKind,
39}
40
41impl CodecRegistry {
42 pub fn new(default: CodecKind) -> Self {
45 Self {
46 codecs: HashMap::new(),
47 default,
48 }
49 }
50
51 pub fn register(&mut self, codec: Arc<dyn Codec>) -> &mut Self {
53 self.codecs.insert(codec.kind(), codec);
54 self
55 }
56
57 #[must_use]
59 pub fn with(mut self, codec: Arc<dyn Codec>) -> Self {
60 self.register(codec);
61 self
62 }
63
64 pub fn kinds(&self) -> impl Iterator<Item = CodecKind> + '_ {
66 self.codecs.keys().copied()
67 }
68
69 pub fn default_kind(&self) -> CodecKind {
71 self.default
72 }
73
74 fn lookup(&self, kind: CodecKind) -> Result<&Arc<dyn Codec>, CodecError> {
75 self.codecs
76 .get(&kind)
77 .ok_or(CodecError::UnregisteredCodec(kind))
78 }
79
80 pub async fn compress(
82 &self,
83 input: Bytes,
84 kind: CodecKind,
85 ) -> Result<(Bytes, ChunkManifest), CodecError> {
86 let codec = self.lookup(kind)?;
87 codec.compress(input).await
88 }
89
90 pub async fn decompress(
92 &self,
93 input: Bytes,
94 manifest: &ChunkManifest,
95 ) -> Result<Bytes, CodecError> {
96 let codec = self.lookup(manifest.codec)?;
97 codec.decompress(input, manifest).await
98 }
99
100 pub async fn compress_with_telemetry(
115 &self,
116 input: Bytes,
117 kind: CodecKind,
118 ) -> (
119 Result<(Bytes, ChunkManifest), CodecError>,
120 CompressTelemetry,
121 ) {
122 let bytes_in = input.len() as u64;
123 let codec = match self.lookup(kind) {
124 Ok(c) => c,
125 Err(e) => {
126 let tel = CompressTelemetry {
127 codec: kind.as_str(),
128 bytes_in,
129 bytes_out: 0,
130 gpu_seconds: None,
131 oom: false,
132 };
133 return (Err(e), tel);
134 }
135 };
136 let is_gpu = is_gpu_kind(kind);
137 let started = Instant::now();
138 let result = codec.compress(input).await;
139 let elapsed = started.elapsed().as_secs_f64();
140 match &result {
141 Ok((out, _manifest)) => {
142 let bytes_out = out.len() as u64;
143 let tel = if is_gpu {
144 CompressTelemetry::gpu(kind.as_str(), bytes_in, bytes_out, elapsed)
145 } else {
146 CompressTelemetry::cpu(kind.as_str(), bytes_in, bytes_out)
147 };
148 (result, tel)
149 }
150 Err(e) => {
151 let mut tel = if is_gpu {
152 CompressTelemetry::gpu(kind.as_str(), bytes_in, 0, elapsed)
153 } else {
154 CompressTelemetry::cpu(kind.as_str(), bytes_in, 0)
155 };
156 if looks_like_oom(e) {
157 tel = tel.with_oom();
158 }
159 (result, tel)
160 }
161 }
162 }
163
164 pub async fn decompress_with_telemetry(
169 &self,
170 input: Bytes,
171 manifest: &ChunkManifest,
172 ) -> (Result<Bytes, CodecError>, CompressTelemetry) {
173 let bytes_in = input.len() as u64;
174 let kind = manifest.codec;
175 let codec = match self.lookup(kind) {
176 Ok(c) => c,
177 Err(e) => {
178 let tel = CompressTelemetry {
179 codec: kind.as_str(),
180 bytes_in,
181 bytes_out: 0,
182 gpu_seconds: None,
183 oom: false,
184 };
185 return (Err(e), tel);
186 }
187 };
188 let is_gpu = is_gpu_kind(kind);
189 let started = Instant::now();
190 let result = codec.decompress(input, manifest).await;
191 let elapsed = started.elapsed().as_secs_f64();
192 match &result {
193 Ok(out) => {
194 let bytes_out = out.len() as u64;
195 let tel = if is_gpu {
196 CompressTelemetry::gpu(kind.as_str(), bytes_in, bytes_out, elapsed)
197 } else {
198 CompressTelemetry::cpu(kind.as_str(), bytes_in, bytes_out)
199 };
200 (result, tel)
201 }
202 Err(e) => {
203 let mut tel = if is_gpu {
204 CompressTelemetry::gpu(kind.as_str(), bytes_in, 0, elapsed)
205 } else {
206 CompressTelemetry::cpu(kind.as_str(), bytes_in, 0)
207 };
208 if looks_like_oom(e) {
209 tel = tel.with_oom();
210 }
211 (result, tel)
212 }
213 }
214 }
215}
216
217impl std::fmt::Debug for CodecRegistry {
218 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219 let mut kinds: Vec<&CodecKind> = self.codecs.keys().collect();
220 kinds.sort_unstable_by_key(|k| k.as_str());
221 f.debug_struct("CodecRegistry")
222 .field("default", &self.default)
223 .field("registered", &kinds)
224 .finish()
225 }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cpu_zstd::CpuZstd;
    use crate::passthrough::Passthrough;

    /// Registry fixture: `Passthrough` and `CpuZstd` registered, with
    /// `CodecKind::CpuZstd` as the default kind.
    fn registry() -> CodecRegistry {
        let mut reg = CodecRegistry::new(CodecKind::CpuZstd);
        reg.register(Arc::new(Passthrough))
            .register(Arc::new(CpuZstd::default()));
        reg
    }

    /// `compress` must route to the codec named by the `kind` argument.
    #[tokio::test]
    async fn dispatches_compress_by_kind() {
        let reg = registry();
        let payload = Bytes::from(vec![b'a'; 1024]);

        // Passthrough: manifest names the codec and the size is unchanged.
        let (pt_bytes, pt_manifest) = reg
            .compress(payload.clone(), CodecKind::Passthrough)
            .await
            .unwrap();
        assert_eq!(pt_manifest.codec, CodecKind::Passthrough);
        assert_eq!(pt_bytes.len(), payload.len());

        // Zstd: a run of identical bytes must shrink to under a fifth.
        let (zstd_bytes, zstd_manifest) = reg
            .compress(payload.clone(), CodecKind::CpuZstd)
            .await
            .unwrap();
        assert_eq!(zstd_manifest.codec, CodecKind::CpuZstd);
        assert!(zstd_bytes.len() < payload.len() / 5);
    }

    /// `decompress` must pick the codec from the manifest and round-trip.
    #[tokio::test]
    async fn dispatches_decompress_by_manifest() {
        let reg = registry();
        let payload = Bytes::from(vec![b'a'; 1024]);
        let (packed, manifest) = reg
            .compress(payload.clone(), CodecKind::CpuZstd)
            .await
            .unwrap();
        let restored = reg.decompress(packed, &manifest).await.unwrap();
        assert_eq!(restored, payload);
    }

    /// A manifest naming a kind with no registered codec must be rejected.
    #[tokio::test]
    async fn unregistered_codec_yields_error() {
        let reg = registry();
        let manifest = ChunkManifest {
            codec: CodecKind::NvcompBitcomp,
            original_size: 10,
            compressed_size: 10,
            crc32c: 0,
        };
        let outcome = reg
            .decompress(Bytes::from_static(b"0123456789"), &manifest)
            .await;
        let err = outcome.unwrap_err();
        assert!(matches!(
            err,
            CodecError::UnregisteredCodec(CodecKind::NvcompBitcomp)
        ));
    }

    /// CPU codecs report byte counts but no GPU time and no OOM flag.
    #[tokio::test]
    async fn compress_with_telemetry_cpu_marks_gpu_seconds_none() {
        let reg = registry();
        let payload = Bytes::from(vec![b'a'; 1024]);
        let (outcome, tel) = reg
            .compress_with_telemetry(payload.clone(), CodecKind::CpuZstd)
            .await;
        let (packed, _manifest) = outcome.expect("compress ok");
        assert_eq!(tel.codec, "cpu-zstd");
        assert_eq!(tel.bytes_in, payload.len() as u64);
        assert_eq!(tel.bytes_out, packed.len() as u64);
        assert!(
            tel.gpu_seconds.is_none(),
            "CPU codec must report gpu_seconds=None, got {:?}",
            tel.gpu_seconds
        );
        assert!(!tel.oom);
    }

    /// Even when lookup fails, a telemetry record is still produced.
    #[tokio::test]
    async fn compress_with_telemetry_unregistered_returns_telemetry() {
        let reg = registry();
        let payload = Bytes::from(vec![b'b'; 32]);
        let (outcome, tel) = reg
            .compress_with_telemetry(payload.clone(), CodecKind::NvcompBitcomp)
            .await;
        assert!(matches!(
            outcome,
            Err(CodecError::UnregisteredCodec(CodecKind::NvcompBitcomp))
        ));
        assert_eq!(tel.codec, "nvcomp-bitcomp");
        assert_eq!(tel.bytes_in, payload.len() as u64);
        assert_eq!(tel.bytes_out, 0);
        assert!(tel.gpu_seconds.is_none());
        assert!(!tel.oom);
    }

    /// On decompress, `bytes_in` is the compressed size and `bytes_out` the
    /// restored size.
    #[tokio::test]
    async fn decompress_with_telemetry_cpu_reports_output_size() {
        let reg = registry();
        let payload = Bytes::from(vec![b'c'; 1024]);
        let (packed, manifest) = reg
            .compress(payload.clone(), CodecKind::CpuZstd)
            .await
            .unwrap();
        let (outcome, tel) = reg
            .decompress_with_telemetry(packed.clone(), &manifest)
            .await;
        let restored = outcome.expect("decompress ok");
        assert_eq!(restored, payload);
        assert_eq!(tel.codec, "cpu-zstd");
        assert_eq!(tel.bytes_in, packed.len() as u64);
        assert_eq!(tel.bytes_out, payload.len() as u64);
        assert!(tel.gpu_seconds.is_none());
    }
}