//! Compressed on-disk serialization of `ndarray` arrays: whole-array and
//! chunked formats, plus a helper for comparing compression algorithms.

use std::fs::File;
use std::io::{Read, Write};
use std::path::Path;

use ndarray::{ArrayBase, Data, Dimension, IxDyn, OwnedRepr};
use serde::{Deserialize, Serialize};

use super::{compress_data, decompress_data, CompressionAlgorithm};
use crate::error::{IoError, Result};
/// Metadata stored alongside a compressed array.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressedArrayMetadata {
    /// Shape of the original array.
    pub shape: Vec<usize>,
    /// Element type name, as reported by `std::any::type_name`.
    pub dtype: String,
    /// Size of one element in bytes.
    pub element_size: usize,
    /// Compression algorithm name (the `Debug` form of `CompressionAlgorithm`).
    pub algorithm: String,
    /// Serialized size before compression, in bytes.
    pub original_size: usize,
    /// Size after compression, in bytes.
    pub compressed_size: usize,
    /// `original_size / compressed_size`; higher is better.
    pub compression_ratio: f64,
    /// Compression level that was used.
    pub compression_level: u32,
    /// Free-form key/value metadata supplied by the caller.
    pub additional_metadata: std::collections::HashMap<String, String>,
}

/// A compressed array: metadata plus the compressed payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressedArray {
    /// Descriptive metadata for the payload.
    pub metadata: CompressedArrayMetadata,
    /// Compressed, bincode-serialized array data.
    pub data: Vec<u8>,
}

/// Compresses an array and writes it to `path` as a single
/// bincode-serialized [`CompressedArray`].
///
/// `level` defaults to 6 when `None`.
pub fn compress_array<P, A, S, D>(
    path: P,
    array: &ArrayBase<S, D>,
    algorithm: CompressionAlgorithm,
    level: Option<u32>,
    additional_metadata: Option<std::collections::HashMap<String, String>>,
) -> Result<()>
where
    P: AsRef<Path>,
    A: Serialize + Clone,
    S: Data<Elem = A>,
    D: Dimension + Serialize,
{
    // Serialize the whole array (shape included) with bincode, then compress.
    let flat_data: Vec<u8> =
        bincode::serialize(array).map_err(|e| IoError::SerializationError(e.to_string()))?;

    let level = level.unwrap_or(6);
    let compressed_data = compress_data(&flat_data, algorithm, Some(level))?;

    let metadata = CompressedArrayMetadata {
        shape: array.shape().to_vec(),
        dtype: std::any::type_name::<A>().to_string(),
        element_size: std::mem::size_of::<A>(),
        algorithm: format!("{:?}", algorithm),
        original_size: flat_data.len(),
        compressed_size: compressed_data.len(),
        compression_ratio: flat_data.len() as f64 / compressed_data.len() as f64,
        compression_level: level,
        additional_metadata: additional_metadata.unwrap_or_default(),
    };

    let compressed_array = CompressedArray {
        metadata,
        data: compressed_data,
    };

    let serialized = bincode::serialize(&compressed_array)
        .map_err(|e| IoError::SerializationError(e.to_string()))?;

    File::create(path)
        .map_err(|e| IoError::FileError(format!("Failed to create output file: {}", e)))?
        .write_all(&serialized)
        .map_err(|e| IoError::FileError(format!("Failed to write to output file: {}", e)))?;

    Ok(())
}

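// A minimal usage sketch for `compress_array` (hedged: the path and metadata
// values are illustrative, and ndarray must be built with its `serde` feature,
// as the `D: Serialize` bound above already implies).
#[cfg(test)]
mod compress_array_example {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn compress_with_custom_metadata() {
        let a = ndarray::array![[1.0_f64, 2.0], [3.0, 4.0]];
        let mut extra = HashMap::new();
        extra.insert("source".to_string(), "example".to_string());
        let path = std::env::temp_dir().join("compress_array_example.bin");
        compress_array(&path, &a, CompressionAlgorithm::Zstd, Some(6), Some(extra)).unwrap();
    }
}
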
/// Reads a file written by [`compress_array`] and reconstructs the array.
///
/// The element type `A` and dimensionality `D` must match the types used
/// when the file was written; bincode cannot check this for you.
pub fn decompress_array<P, A, D>(path: P) -> Result<ArrayBase<OwnedRepr<A>, D>>
where
    P: AsRef<Path>,
    A: for<'de> Deserialize<'de> + Clone,
    D: Dimension + for<'de> Deserialize<'de>,
{
    let mut file = File::open(path)
        .map_err(|e| IoError::FileError(format!("Failed to open input file: {}", e)))?;

    let mut serialized = Vec::new();
    file.read_to_end(&mut serialized)
        .map_err(|e| IoError::FileError(format!("Failed to read input file: {}", e)))?;

    let compressed_array: CompressedArray = bincode::deserialize(&serialized)
        .map_err(|e| IoError::DeserializationError(e.to_string()))?;

    // Recover the algorithm from its stored `Debug` name.
    let algorithm = match compressed_array.metadata.algorithm.as_str() {
        "Gzip" => CompressionAlgorithm::Gzip,
        "Zstd" => CompressionAlgorithm::Zstd,
        "Lz4" => CompressionAlgorithm::Lz4,
        "Bzip2" => CompressionAlgorithm::Bzip2,
        _ => {
            return Err(IoError::DecompressionError(format!(
                "Unknown compression algorithm: {}",
                compressed_array.metadata.algorithm
            )))
        }
    };

    let decompressed_data = decompress_data(&compressed_array.data, algorithm)?;

    let array: ArrayBase<OwnedRepr<A>, D> = bincode::deserialize(&decompressed_data)
        .map_err(|e| IoError::DeserializationError(e.to_string()))?;

    Ok(array)
}

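// Round-trip sketch for `decompress_array`: the element type and
// dimensionality passed to the reader must match those used by the writer,
// since bincode carries no type information (temp path is illustrative).
#[cfg(test)]
mod decompress_array_example {
    use super::*;
    use ndarray::Array2;

    #[test]
    fn round_trip_preserves_contents() {
        let a: Array2<f64> = ndarray::array![[1.0, 2.0], [3.0, 4.0]];
        let path = std::env::temp_dir().join("decompress_array_example.bin");
        compress_array(&path, &a, CompressionAlgorithm::Gzip, None, None).unwrap();
        let b: Array2<f64> = decompress_array(&path).unwrap();
        assert_eq!(a, b);
    }
}
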
/// Compresses a dynamic-dimension array in chunks of `chunk_size` elements.
///
/// File layout: metadata length (u64, little-endian), bincode metadata,
/// chunk count (u64 LE), then for each chunk its length (u64 LE) followed
/// by the compressed bytes. The writer records `chunked`, `num_chunks`, and
/// `chunk_size` in `additional_metadata` so readers can validate the file.
///
/// Returns an error if `chunk_size` is zero.
pub fn compress_array_chunked<P, A, S>(
    path: P,
    array: &ArrayBase<S, IxDyn>,
    algorithm: CompressionAlgorithm,
    level: Option<u32>,
    chunk_size: usize,
) -> Result<()>
where
    P: AsRef<Path>,
    A: Serialize + Clone,
    S: Data<Elem = A>,
{
    // A zero chunk size would divide by zero in the chunk-count computation
    // below; reject it up front.
    if chunk_size == 0 {
        return Err(IoError::FileError(
            "chunk_size must be greater than zero".to_string(),
        ));
    }

    let mut compressed_chunks = Vec::new();
    let mut total_original_size = 0;
    let mut total_compressed_size = 0;

    // Ceiling division: the last chunk may hold fewer than `chunk_size` elements.
    for chunk_idx in 0..((array.len() + chunk_size - 1) / chunk_size) {
        let start = chunk_idx * chunk_size;
        let end = (start + chunk_size).min(array.len());

        let chunk_data: Vec<A> = array
            .iter()
            .skip(start)
            .take(end - start)
            .cloned()
            .collect();

        let serialized_chunk = bincode::serialize(&chunk_data)
            .map_err(|e| IoError::SerializationError(e.to_string()))?;

        let compressed_chunk = compress_data(&serialized_chunk, algorithm, level)?;

        total_original_size += serialized_chunk.len();
        total_compressed_size += compressed_chunk.len();

        compressed_chunks.push(compressed_chunk);
    }

    let metadata = CompressedArrayMetadata {
        shape: array.shape().to_vec(),
        dtype: std::any::type_name::<A>().to_string(),
        element_size: std::mem::size_of::<A>(),
        algorithm: format!("{:?}", algorithm),
        original_size: total_original_size,
        compressed_size: total_compressed_size,
        compression_ratio: total_original_size as f64 / total_compressed_size as f64,
        compression_level: level.unwrap_or(6),
        additional_metadata: {
            let mut map = std::collections::HashMap::new();
            map.insert("chunked".to_string(), "true".to_string());
            map.insert(
                "num_chunks".to_string(),
                compressed_chunks.len().to_string(),
            );
            map.insert("chunk_size".to_string(), chunk_size.to_string());
            map
        },
    };

    let mut file = File::create(path)
        .map_err(|e| IoError::FileError(format!("Failed to create output file: {}", e)))?;

    // Length-prefix the metadata so the reader knows how many bytes to consume.
    let serialized_metadata =
        bincode::serialize(&metadata).map_err(|e| IoError::SerializationError(e.to_string()))?;

    let metadata_size = serialized_metadata.len() as u64;
    file.write_all(&metadata_size.to_le_bytes())
        .map_err(|e| IoError::FileError(format!("Failed to write metadata size: {}", e)))?;

    file.write_all(&serialized_metadata)
        .map_err(|e| IoError::FileError(format!("Failed to write metadata: {}", e)))?;

    let num_chunks = compressed_chunks.len() as u64;
    file.write_all(&num_chunks.to_le_bytes())
        .map_err(|e| IoError::FileError(format!("Failed to write chunk count: {}", e)))?;

    // Each chunk is likewise length-prefixed.
    for chunk in compressed_chunks {
        let chunk_size = chunk.len() as u64;
        file.write_all(&chunk_size.to_le_bytes())
            .map_err(|e| IoError::FileError(format!("Failed to write chunk size: {}", e)))?;

        file.write_all(&chunk)
            .map_err(|e| IoError::FileError(format!("Failed to write chunk data: {}", e)))?;
    }

    Ok(())
}

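// Sketch of the chunking arithmetic: 6 elements with chunk_size = 4 gives
// ceil(6 / 4) = 2 chunks, which the writer records in
// `additional_metadata["num_chunks"]` (temp path is illustrative).
#[cfg(test)]
mod compress_array_chunked_example {
    use super::*;
    use ndarray::{ArrayD, IxDyn};

    #[test]
    fn chunk_count_is_recorded() {
        let data: Vec<f64> = (0..6).map(|i| i as f64).collect();
        let a = ArrayD::from_shape_vec(IxDyn(&[6]), data).unwrap();
        let path = std::env::temp_dir().join("compress_array_chunked_example.bin");
        compress_array_chunked(&path, &a, CompressionAlgorithm::Gzip, None, 4).unwrap();
        let (_, meta) = decompress_array_chunked::<_, f64>(&path).unwrap();
        assert_eq!(meta.additional_metadata["num_chunks"], "2");
    }
}
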
/// Reads a file written by [`compress_array_chunked`], returning the
/// reconstructed array together with its stored metadata.
pub fn decompress_array_chunked<P, A>(
    path: P,
) -> Result<(ArrayBase<OwnedRepr<A>, IxDyn>, CompressedArrayMetadata)>
where
    P: AsRef<Path>,
    A: for<'de> Deserialize<'de> + Clone,
{
    let mut file = File::open(path)
        .map_err(|e| IoError::FileError(format!("Failed to open input file: {}", e)))?;

    // Read the length-prefixed metadata block.
    let mut metadata_size_bytes = [0u8; 8];
    file.read_exact(&mut metadata_size_bytes)
        .map_err(|e| IoError::FileError(format!("Failed to read metadata size: {}", e)))?;

    let metadata_size = u64::from_le_bytes(metadata_size_bytes) as usize;

    let mut metadata_bytes = vec![0u8; metadata_size];
    file.read_exact(&mut metadata_bytes)
        .map_err(|e| IoError::FileError(format!("Failed to read metadata: {}", e)))?;

    let metadata: CompressedArrayMetadata = bincode::deserialize(&metadata_bytes)
        .map_err(|e| IoError::DeserializationError(e.to_string()))?;

    // Recover the algorithm from its stored `Debug` name.
    let algorithm = match metadata.algorithm.as_str() {
        "Gzip" => CompressionAlgorithm::Gzip,
        "Zstd" => CompressionAlgorithm::Zstd,
        "Lz4" => CompressionAlgorithm::Lz4,
        "Bzip2" => CompressionAlgorithm::Bzip2,
        _ => {
            return Err(IoError::DecompressionError(format!(
                "Unknown compression algorithm: {}",
                metadata.algorithm
            )))
        }
    };

    let mut num_chunks_bytes = [0u8; 8];
    file.read_exact(&mut num_chunks_bytes)
        .map_err(|e| IoError::FileError(format!("Failed to read chunk count: {}", e)))?;

    let num_chunks = u64::from_le_bytes(num_chunks_bytes) as usize;

    let total_elements: usize = metadata.shape.iter().product();
    let mut all_elements = Vec::with_capacity(total_elements);

    // Decompress each length-prefixed chunk in order and concatenate.
    for _ in 0..num_chunks {
        let mut chunk_size_bytes = [0u8; 8];
        file.read_exact(&mut chunk_size_bytes)
            .map_err(|e| IoError::FileError(format!("Failed to read chunk size: {}", e)))?;

        let chunk_size = u64::from_le_bytes(chunk_size_bytes) as usize;

        let mut chunk_bytes = vec![0u8; chunk_size];
        file.read_exact(&mut chunk_bytes)
            .map_err(|e| IoError::FileError(format!("Failed to read chunk data: {}", e)))?;

        let decompressed_chunk = decompress_data(&chunk_bytes, algorithm)?;

        let chunk_elements: Vec<A> = bincode::deserialize(&decompressed_chunk)
            .map_err(|e| IoError::DeserializationError(e.to_string()))?;

        all_elements.extend(chunk_elements);
    }

    // Rebuild the array in its original shape.
    let array = ArrayBase::from_shape_vec(IxDyn(&metadata.shape), all_elements)
        .map_err(|e| IoError::DeserializationError(e.to_string()))?;

    Ok((array, metadata))
}

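// Chunked round-trip sketch: a deliberately tiny chunk size exercises the
// multi-chunk read path; shape and contents should survive intact
// (temp path is illustrative).
#[cfg(test)]
mod decompress_array_chunked_example {
    use super::*;
    use ndarray::{ArrayD, IxDyn};

    #[test]
    fn chunked_round_trip() {
        let data: Vec<f64> = (0..12).map(|i| i as f64).collect();
        let a = ArrayD::from_shape_vec(IxDyn(&[3, 4]), data).unwrap();
        let path = std::env::temp_dir().join("decompress_array_chunked_example.bin");
        compress_array_chunked(&path, &a, CompressionAlgorithm::Lz4, None, 5).unwrap();
        let (b, meta) = decompress_array_chunked::<_, f64>(&path).unwrap();
        assert_eq!(a, b);
        assert_eq!(meta.shape, vec![3, 4]);
    }
}
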
/// Compresses `array` in memory with each algorithm in `algorithms` and
/// returns `(algorithm, compression_ratio, compressed_size)` tuples, sorted
/// by ratio in descending order (best first). Nothing is written to disk.
pub fn compare_compression_algorithms<A, S, D>(
    array: &ArrayBase<S, D>,
    algorithms: &[CompressionAlgorithm],
    level: Option<u32>,
) -> Result<Vec<(CompressionAlgorithm, f64, usize)>>
where
    A: Serialize + Clone,
    S: Data<Elem = A>,
    D: Dimension + Serialize,
{
    let serialized =
        bincode::serialize(array).map_err(|e| IoError::SerializationError(e.to_string()))?;

    let original_size = serialized.len();

    let mut results = Vec::new();

    for &algorithm in algorithms {
        let compressed = compress_data(&serialized, algorithm, level)?;
        let compressed_size = compressed.len();
        let ratio = original_size as f64 / compressed_size as f64;

        results.push((algorithm, ratio, compressed_size));
    }

    // Sort best ratio first; NaN ratios compare as equal.
    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

    Ok(results)
}
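
// Sketch for `compare_compression_algorithms`: highly repetitive input
// compresses well under every algorithm, and results come back sorted
// best-ratio first (the element values here are arbitrary).
#[cfg(test)]
mod compare_algorithms_example {
    use super::*;
    use ndarray::Array1;

    #[test]
    fn results_are_sorted_best_first() {
        let a = Array1::from_elem(10_000, 0.5_f64);
        let results = compare_compression_algorithms(
            &a,
            &[CompressionAlgorithm::Gzip, CompressionAlgorithm::Zstd],
            None,
        )
        .unwrap();
        assert_eq!(results.len(), 2);
        assert!(results[0].1 >= results[1].1);
    }
}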