1use crate::{BM25Index, Result};
9use serde::{de::DeserializeOwned, Serialize};
10
11pub use batuta_common::compression::Compression;
15
16pub fn serialize_compressed<T: Serialize>(index: &T, compression: Compression) -> Result<Vec<u8>> {
21 let bytes = bincode::serialize(index).map_err(|e| {
22 crate::Error::SerializationError(format!("Bincode serialization failed: {e}"))
23 })?;
24 Ok(compression.compress(&bytes)?)
25}
26
27pub fn deserialize_compressed<T: DeserializeOwned>(
32 data: &[u8],
33 compression: Compression,
34) -> Result<T> {
35 let decompressed = compression.decompress(data)?;
36 bincode::deserialize(&decompressed).map_err(|e| {
37 crate::Error::SerializationError(format!("Bincode deserialization failed: {e}"))
38 })
39}
40
41impl BM25Index {
42 pub fn to_compressed_bytes(&self, compression: Compression) -> Result<Vec<u8>> {
47 serialize_compressed(self, compression)
48 }
49
50 pub fn from_compressed_bytes(data: &[u8], compression: Compression) -> Result<Self> {
55 deserialize_compressed(data, compression)
56 }
57}
58
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{index::SparseIndex, Chunk, DocumentId};

    /// Builds a chunk for a fresh `DocumentId` holding a copy of `content`.
    ///
    /// The trailing arguments are `0` and `content.len()` (presumably the
    /// start/end offsets of the chunk -- confirm against `Chunk::new`).
    fn create_test_chunk(content: &str) -> Chunk {
        let end = content.len();
        Chunk::new(DocumentId::new(), content.to_string(), 0, end)
    }

    /// Each algorithm reports its expected lowercase name.
    #[test]
    fn test_compression_as_str() {
        let expected = [(Compression::Lz4, "lz4"), (Compression::Zstd, "zstd")];
        for (algo, name) in expected {
            assert_eq!(algo.as_str(), name);
        }
    }

    /// The default algorithm is LZ4.
    #[test]
    fn test_compression_default() {
        let default_algo = Compression::default();
        assert_eq!(default_algo, Compression::Lz4);
    }

    /// LZ4 compress followed by decompress yields the original bytes.
    #[test]
    fn test_lz4_compress_decompress() {
        let payload = b"hello world hello world hello world".to_vec();
        let packed = Compression::Lz4.compress(&payload).unwrap();
        let unpacked = Compression::Lz4.decompress(&packed).unwrap();
        assert_eq!(unpacked, payload);
    }

    /// Zstd compress followed by decompress yields the original bytes.
    #[test]
    fn test_zstd_compress_decompress() {
        let payload = b"hello world hello world hello world".to_vec();
        let packed = Compression::Zstd.compress(&payload).unwrap();
        let unpacked = Compression::Zstd.decompress(&packed).unwrap();
        assert_eq!(unpacked, payload);
    }

    /// Empty input compresses to empty output and decompresses back to empty,
    /// for both algorithms.
    #[test]
    fn test_empty_data_compression() {
        let nothing: Vec<u8> = Vec::new();

        // LZ4
        let lz4_packed = Compression::Lz4.compress(&nothing).unwrap();
        assert!(lz4_packed.is_empty());
        let lz4_unpacked = Compression::Lz4.decompress(&lz4_packed).unwrap();
        assert!(lz4_unpacked.is_empty());

        // Zstd
        let zstd_packed = Compression::Zstd.compress(&nothing).unwrap();
        assert!(zstd_packed.is_empty());
        let zstd_unpacked = Compression::Zstd.decompress(&zstd_packed).unwrap();
        assert!(zstd_unpacked.is_empty());
    }

    /// 10 000 zero bytes shrink to under a tenth of their size with LZ4.
    #[test]
    fn test_lz4_compresses_repeated_data() {
        let zeros = vec![0u8; 10000];
        let packed = Compression::Lz4.compress(&zeros).unwrap();
        let limit = zeros.len() / 10;
        assert!(packed.len() < limit);
    }

    /// 10 000 zero bytes shrink to under a tenth of their size with Zstd.
    #[test]
    fn test_zstd_compresses_repeated_data() {
        let zeros = vec![0u8; 10000];
        let packed = Compression::Zstd.compress(&zeros).unwrap();
        let limit = zeros.len() / 10;
        assert!(packed.len() < limit);
    }

    /// An index survives an LZ4 round trip with the same length and the same
    /// number of search hits.
    #[test]
    fn test_bm25_lz4_roundtrip() {
        let mut index = BM25Index::new();
        let texts = [
            "machine learning is great",
            "deep learning neural networks",
            "natural language processing",
        ];
        for text in texts {
            index.add(&create_test_chunk(text));
        }

        let packed = index.to_compressed_bytes(Compression::Lz4).unwrap();
        let restored = BM25Index::from_compressed_bytes(&packed, Compression::Lz4).unwrap();

        assert_eq!(index.len(), restored.len());

        let query = "machine learning";
        let hits_before = index.search(query, 10);
        let hits_after = restored.search(query, 10);
        assert_eq!(hits_before.len(), hits_after.len());
    }

    /// An index survives a Zstd round trip with the same length.
    #[test]
    fn test_bm25_zstd_roundtrip() {
        let mut index = BM25Index::new();
        for text in ["rust programming language", "systems programming with rust"] {
            index.add(&create_test_chunk(text));
        }

        let packed = index.to_compressed_bytes(Compression::Zstd).unwrap();
        let restored = BM25Index::from_compressed_bytes(&packed, Compression::Zstd).unwrap();

        assert_eq!(index.len(), restored.len());
    }

    /// Both algorithms beat raw bincode on a 100-document index, and Zstd is
    /// no larger than LZ4.
    #[test]
    fn test_bm25_compression_reduces_size() {
        let mut index = BM25Index::new();
        for i in 0..100 {
            let text = format!(
                "document number {i} about machine learning and artificial intelligence"
            );
            index.add(&create_test_chunk(&text));
        }

        let raw_len = bincode::serialize(&index).unwrap().len();
        let lz4_len = index.to_compressed_bytes(Compression::Lz4).unwrap().len();
        let zstd_len = index.to_compressed_bytes(Compression::Zstd).unwrap().len();

        assert!(lz4_len < raw_len);
        assert!(zstd_len < raw_len);

        assert!(zstd_len <= lz4_len);
    }

    /// An empty index is still empty after an LZ4 round trip.
    #[test]
    fn test_bm25_empty_index_compression() {
        let empty_index = BM25Index::new();

        let packed = empty_index.to_compressed_bytes(Compression::Lz4).unwrap();
        let restored = BM25Index::from_compressed_bytes(&packed, Compression::Lz4).unwrap();

        assert!(restored.is_empty());
    }

    /// Search results (ids in order, scores within 1e-5) are unchanged by an
    /// LZ4 round trip.
    #[test]
    fn test_bm25_preserved_search_behavior() {
        let mut index = BM25Index::new();
        let texts = [
            "python programming language scripting",
            "javascript web development frontend",
            "rust systems programming performance",
        ];
        for text in texts {
            index.add(&create_test_chunk(text));
        }

        let packed = index.to_compressed_bytes(Compression::Lz4).unwrap();
        let restored = BM25Index::from_compressed_bytes(&packed, Compression::Lz4).unwrap();

        let query = "programming language";
        let hits_before = index.search(query, 3);
        let hits_after = restored.search(query, 3);

        assert_eq!(hits_before.len(), hits_after.len());
        for (before, after) in hits_before.iter().zip(hits_after.iter()) {
            // Each hit is an (id, score) pair.
            assert_eq!(before.0, after.0);
            assert!((before.1 - after.1).abs() < 1e-5);
        }
    }
}