Skip to main content

xet_data/processing/
data_client.rs

1use std::fs::File;
2use std::io::Read;
3use std::path::Path;
4use std::sync::Arc;
5
6use bytes::Bytes;
7use http::header::HeaderMap;
8use tracing::{Instrument, Span, info_span, instrument};
9use ulid::Ulid;
10use xet_client::cas_client::auth::{AuthConfig, TokenRefresher};
11use xet_core_structures::merklehash::MerkleHash;
12use xet_runtime::core::par_utils::run_constrained_with_semaphore;
13use xet_runtime::core::{XetRuntime, check_sigint_shutdown, xet_config};
14
15use super::configurations::{SessionContext, TranslatorConfig};
16use super::file_cleaner::Sha256Policy;
17use super::{FileUploadSession, XetFileInfo};
18use crate::deduplication::{Chunker, DeduplicationMetrics};
19use crate::error::Result;
20
21pub fn default_config(
22    endpoint: String,
23    token_info: Option<(String, u64)>,
24    token_refresher: Option<Arc<dyn TokenRefresher>>,
25    custom_headers: Option<Arc<HeaderMap>>,
26) -> Result<TranslatorConfig> {
27    let (token, token_expiration) = token_info.unzip();
28    let auth_cfg = AuthConfig::maybe_new(token, token_expiration, token_refresher);
29
30    let session = SessionContext {
31        endpoint,
32        auth: auth_cfg,
33        custom_headers,
34        repo_paths: vec!["".into()],
35        session_id: Some(Ulid::new().to_string()),
36    };
37
38    TranslatorConfig::new(session)
39}
40
41#[instrument(skip_all, name = "clean_bytes", fields(bytes.len = bytes.len()))]
42pub async fn clean_bytes(
43    processor: Arc<FileUploadSession>,
44    bytes: Vec<u8>,
45    sha256_policy: Sha256Policy,
46) -> Result<(XetFileInfo, DeduplicationMetrics)> {
47    let (_id, mut handle) = processor.start_clean(None, Some(bytes.len() as u64), sha256_policy)?;
48    handle.add_data(&bytes).await?;
49    handle.finish().await
50}
51
52#[instrument(skip_all, name = "clean_file", fields(file.name = tracing::field::Empty, file.len = tracing::field::Empty))]
53pub async fn clean_file(
54    processor: Arc<FileUploadSession>,
55    filename: impl AsRef<Path>,
56    sha256_policy: Sha256Policy,
57) -> Result<(XetFileInfo, DeduplicationMetrics)> {
58    let mut reader = File::open(&filename)?;
59
60    let filesize = reader.metadata()?.len();
61    let span = Span::current();
62    span.record("file.name", filename.as_ref().to_str());
63    span.record("file.len", filesize);
64    let mut buffer = vec![0u8; u64::min(filesize, *xet_config().data.ingestion_block_size) as usize];
65
66    let (_id, mut handle) =
67        processor.start_clean(Some(filename.as_ref().to_string_lossy().into()), Some(filesize), sha256_policy)?;
68
69    loop {
70        let bytes = reader.read(&mut buffer)?;
71        if bytes == 0 {
72            break;
73        }
74
75        handle.add_data(&buffer[0..bytes]).await?;
76    }
77
78    handle.finish().await
79}
80
81/// Computes the xet hash for a single file without uploading.
82///
83/// This function performs local-only hash computation by reading the file,
84/// chunking it using content-defined chunking, and computing the aggregated
85/// hash from the chunk hashes. The resulting hash is identical to what would
86/// be returned by upload operations, enabling verification of downloaded files.
87///
88/// # Arguments
89/// * `filename` - Path to the file to hash
90/// * `buffer_size` - Size of the read buffer in bytes
91///
92/// # Returns
93/// * `XetFileInfo` containing the hex-encoded hash and file size
94///
95/// # Errors
96/// * `IoError` if the file cannot be opened or read
97///
98/// # Use Cases
99/// - Verify that downloaded files are correctly reassembled
100/// - Check if a file needs to be uploaded (by comparing hashes)
101/// - Generate cache keys for local file operations
102fn hash_single_file(filename: String, buffer_size: usize) -> Result<XetFileInfo> {
103    let mut reader = File::open(&filename)?;
104    let filesize = reader.metadata()?.len();
105
106    let mut buffer = vec![0u8; buffer_size];
107    let mut chunker = Chunker::default();
108    let mut chunk_hashes: Vec<(MerkleHash, u64)> = Vec::new();
109
110    loop {
111        check_sigint_shutdown()?;
112
113        let bytes_read = reader.read(&mut buffer)?;
114        if bytes_read == 0 {
115            break;
116        }
117
118        let data = Bytes::copy_from_slice(&buffer[0..bytes_read]);
119        let chunks = chunker.next_block_bytes(&data, false);
120
121        for chunk in chunks {
122            chunk_hashes.push((chunk.hash, chunk.data.len() as u64));
123        }
124    }
125
126    // Get the final chunk if any data remains in the chunker
127    if let Some(final_chunk) = chunker.finish() {
128        chunk_hashes.push((final_chunk.hash, final_chunk.data.len() as u64));
129    }
130
131    let file_hash = xet_core_structures::merklehash::file_hash(&chunk_hashes);
132    Ok(XetFileInfo::new(file_hash.hex(), filesize))
133}
134
135/// Computes xet hashes for multiple files in parallel without uploading.
136///
137/// This function processes multiple files concurrently using a semaphore to limit
138/// parallelism. Each file is hashed independently using `hash_single_file()`.
139/// The resulting hashes are identical to those from upload operations,
140/// enabling validation and verification of file transfers.
141///
142/// # Arguments
143/// * `file_paths` - Vector of file paths to hash
144///
145/// # Returns
146/// * Vector of `XetFileInfo` in the same order as input file paths
147///
148/// # Errors
149/// * Returns error if any file cannot be read or hashed
150///
151/// # Use Cases
152/// - Verify integrity of downloaded files by comparing computed hashes
153/// - Batch validation of multiple files after transfer
154/// - Determine which files need to be uploaded by comparing with server hashes
155///
156/// # Performance
157/// - Uses `file_ingestion_semaphore` to control parallelism
158/// - No authentication or server connection required
159/// - Pure local computation
160#[instrument(skip_all, name = "data_client::hash_files", fields(num_files=file_paths.len()))]
161pub async fn hash_files_async(file_paths: Vec<String>) -> Result<Vec<XetFileInfo>> {
162    let rt = XetRuntime::current();
163    let semaphore = rt.common().file_ingestion_semaphore.clone();
164    let buffer_size = *xet_config().data.ingestion_block_size as usize;
165
166    let hash_futures = file_paths.into_iter().map(|file_path| {
167        let rt = rt.clone();
168        async move {
169            rt.spawn_blocking(move || hash_single_file(file_path, buffer_size))
170                .await
171                .map_err(|e| std::io::Error::other(e.to_string()))?
172        }
173        .instrument(info_span!("hash_file"))
174    });
175
176    let files = run_constrained_with_semaphore(hash_futures, semaphore).await?;
177
178    Ok(files)
179}
180
#[cfg(test)]
mod tests {
    use dirs::home_dir;
    use serial_test::serial;
    use tempfile::tempdir;
    use xet_runtime::utils::EnvVarGuard;

    use super::*;

    // The `default_config` tests below mutate process-wide environment
    // variables (HF_HOME / HF_XET_CACHE), so they are serialized under the
    // shared `default_config_env` key to avoid cross-test interference.

    /// With only HF_HOME set, the shard cache should live under that directory.
    #[test]
    #[serial(default_config_env)]
    fn test_default_config_with_hf_home() {
        let temp_dir = tempdir().unwrap();
        // Guard restores the previous HF_HOME value when dropped at end of test.
        let _hf_home_guard = EnvVarGuard::set("HF_HOME", temp_dir.path().to_str().unwrap());

        let endpoint = "http://localhost:8080".to_string();
        let result = default_config(endpoint, None, None, None);

        assert!(result.is_ok());
        let config = result.unwrap();
        assert!(config.shard_cache_directory.starts_with(temp_dir.path()));
    }

    /// HF_XET_CACHE should take precedence over HF_HOME; once it is unset,
    /// the config should fall back to HF_HOME.
    #[test]
    #[serial(default_config_env)]
    fn test_default_config_with_hf_xet_cache_and_hf_home() {
        let temp_dir_xet_cache = tempdir().unwrap();
        let temp_dir_hf_home = tempdir().unwrap();

        // Both variables set: HF_XET_CACHE is expected to win.
        let hf_xet_cache_guard = EnvVarGuard::set("HF_XET_CACHE", temp_dir_xet_cache.path().to_str().unwrap());
        let hf_home_guard = EnvVarGuard::set("HF_HOME", temp_dir_hf_home.path().to_str().unwrap());

        let endpoint = "http://localhost:8080".to_string();
        let result = default_config(endpoint, None, None, None);

        assert!(result.is_ok());
        let config = result.unwrap();
        assert!(config.shard_cache_directory.starts_with(temp_dir_xet_cache.path()));

        // Explicitly drop both guards mid-test to restore the environment
        // before re-checking the HF_HOME-only fallback path below.
        drop(hf_xet_cache_guard);
        drop(hf_home_guard);

        let temp_dir = tempdir().unwrap();
        let _hf_home_guard = EnvVarGuard::set("HF_HOME", temp_dir.path().to_str().unwrap());

        let endpoint = "http://localhost:8080".to_string();
        let result = default_config(endpoint, None, None, None);

        assert!(result.is_ok());
        let config = result.unwrap();
        assert!(config.shard_cache_directory.starts_with(temp_dir.path()));
    }

    /// With only HF_XET_CACHE set, the shard cache should live under it.
    #[test]
    #[serial(default_config_env)]
    fn test_default_config_with_hf_xet_cache() {
        let temp_dir = tempdir().unwrap();
        let _hf_xet_cache_guard = EnvVarGuard::set("HF_XET_CACHE", temp_dir.path().to_str().unwrap());

        let endpoint = "http://localhost:8080".to_string();
        let result = default_config(endpoint, None, None, None);

        assert!(result.is_ok());
        let config = result.unwrap();
        assert!(config.shard_cache_directory.starts_with(temp_dir.path()));
    }

    /// With neither env var set, the default is ~/.cache/huggingface/xet.
    #[test]
    #[serial(default_config_env)]
    fn test_default_config_without_env_vars() {
        let endpoint = "http://localhost:8080".to_string();
        let result = default_config(endpoint, None, None, None);

        let expected = home_dir().unwrap().join(".cache").join("huggingface").join("xet");

        assert!(result.is_ok());
        let config = result.unwrap();
        let test_cache_dir = &config.shard_cache_directory;
        assert!(
            test_cache_dir.starts_with(&expected),
            "cache dir = {test_cache_dir:?}; does not start with {expected:?}",
        );
    }

    /// An empty file should hash successfully with a size of 0 and a
    /// non-empty hash string.
    #[tokio::test]
    async fn test_hash_empty_file() {
        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("empty.txt");
        std::fs::write(&file_path, b"").unwrap();

        let buffer_size = 8 * 1024 * 1024; // 8MB
        let result = hash_single_file(file_path.to_str().unwrap().to_string(), buffer_size);
        assert!(result.is_ok());

        let file_info = result.unwrap();
        assert_eq!(file_info.file_size(), Some(0));
        assert!(!file_info.hash().is_empty());
    }

    /// A file much smaller than the read buffer should still hash correctly.
    #[tokio::test]
    async fn test_hash_small_file() {
        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("small.txt");
        let content = b"Hello, World!";
        std::fs::write(&file_path, content).unwrap();

        let buffer_size = 8 * 1024 * 1024; // 8MB
        let result = hash_single_file(file_path.to_str().unwrap().to_string(), buffer_size);
        assert!(result.is_ok());

        let file_info = result.unwrap();
        assert_eq!(file_info.file_size(), Some(content.len() as u64));
        assert!(!file_info.hash().is_empty());
    }

    /// The hash must not depend on the read buffer size, i.e. chunk
    /// boundaries are determined by content, not by how the file is read.
    #[tokio::test]
    #[cfg_attr(feature = "smoke-test", ignore)]
    async fn test_hash_determinism() {
        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("test.txt");

        // Create a file that is large enough to span multiple buffer reads
        // Using 20MB to ensure it's larger than typical buffer sizes
        let file_size = 20 * 1024 * 1024;
        let content: Vec<u8> = (0..file_size).map(|i| (i % 256) as u8).collect();
        std::fs::write(&file_path, &content).unwrap();

        let file_path_str = file_path.to_str().unwrap().to_string();

        // Hash with 8MB buffer size
        let result1 = hash_single_file(file_path_str.clone(), 8 * 1024 * 1024);
        assert!(result1.is_ok());
        let file_info1 = result1.unwrap();

        // Hash with 4MB buffer size
        let result2 = hash_single_file(file_path_str, 4 * 1024 * 1024);
        assert!(result2.is_ok());
        let file_info2 = result2.unwrap();

        // Hashes should be identical regardless of buffer size
        // This verifies that chunker.finish() is called correctly
        assert_eq!(file_info1.hash(), file_info2.hash());
        assert_eq!(file_info1.file_size(), file_info2.file_size());
    }

    /// Hashing a nonexistent path must surface the I/O error, not panic.
    #[tokio::test]
    async fn test_hash_file_not_found() {
        let buffer_size = 8 * 1024 * 1024; // 8MB
        let result = hash_single_file("/nonexistent/file.txt".to_string(), buffer_size);
        assert!(result.is_err());
    }

    /// The async batch API should return results in input order with the
    /// correct sizes and distinct hashes for distinct contents.
    #[tokio::test]
    async fn test_hash_files_async() {
        let temp_dir = tempdir().unwrap();

        let file1_path = temp_dir.path().join("file1.txt");
        let file2_path = temp_dir.path().join("file2.txt");

        std::fs::write(&file1_path, b"First file content").unwrap();
        std::fs::write(&file2_path, b"Second file content").unwrap();

        let file_paths = vec![
            file1_path.to_str().unwrap().to_string(),
            file2_path.to_str().unwrap().to_string(),
        ];

        let result = hash_files_async(file_paths).await;
        assert!(result.is_ok());

        let file_infos = result.unwrap();
        assert_eq!(file_infos.len(), 2);
        assert_eq!(file_infos[0].file_size(), Some(18));
        assert_eq!(file_infos[1].file_size(), Some(19));
        assert_ne!(file_infos[0].hash(), file_infos[1].hash());
    }

    #[tokio::test]
    #[cfg_attr(feature = "smoke-test", ignore)]
    async fn test_hash_file_size_multiple_of_buffer() {
        // Regression test for bug where final chunk wasn't produced when file size
        // is exactly a multiple of buffer_size. This test verifies that
        // chunker.finish() is called to flush any remaining data.
        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("multiple_of_buffer.bin");

        // Create a file that is exactly 16MB
        let file_size = 16 * 1024 * 1024;
        let content: Vec<u8> = (0..file_size).map(|i| (i % 256) as u8).collect();
        std::fs::write(&file_path, &content).unwrap();

        let file_path_str = file_path.to_str().unwrap().to_string();

        // Hash with 8MB buffer size - file is exactly 2x buffer size
        let result1 = hash_single_file(file_path_str.clone(), 8 * 1024 * 1024);
        assert!(result1.is_ok());
        let file_info1 = result1.unwrap();
        assert_eq!(file_info1.file_size(), Some(file_size as u64));
        assert!(!file_info1.hash().is_empty());

        // Hash with 4MB buffer size - file is exactly 4x buffer size
        let result2 = hash_single_file(file_path_str.clone(), 4 * 1024 * 1024);
        assert!(result2.is_ok());
        let file_info2 = result2.unwrap();

        // Hash with 2MB buffer size - file is exactly 8x buffer size
        let result3 = hash_single_file(file_path_str, 2 * 1024 * 1024);
        assert!(result3.is_ok());
        let file_info3 = result3.unwrap();

        // All hashes should be identical regardless of buffer size
        // This verifies that chunker.finish() is properly called to flush remaining chunks
        // Without finish(), different buffer sizes would produce different (incomplete) hashes
        assert_eq!(file_info1.hash(), file_info2.hash(), "Hash mismatch between 8MB and 4MB buffer sizes");
        assert_eq!(file_info1.hash(), file_info3.hash(), "Hash mismatch between 8MB and 2MB buffer sizes");
        assert_eq!(file_info1.file_size(), file_info2.file_size());
        assert_eq!(file_info1.file_size(), file_info3.file_size());
    }
}
399}