1use std::fs::File;
2use std::io::Read;
3use std::path::Path;
4use std::sync::Arc;
5
6use bytes::Bytes;
7use http::header::HeaderMap;
8use tracing::{Instrument, Span, info_span, instrument};
9use ulid::Ulid;
10use xet_client::cas_client::auth::{AuthConfig, TokenRefresher};
11use xet_core_structures::merklehash::MerkleHash;
12use xet_runtime::core::par_utils::run_constrained_with_semaphore;
13use xet_runtime::core::{XetRuntime, check_sigint_shutdown, xet_config};
14
15use super::configurations::{SessionContext, TranslatorConfig};
16use super::file_cleaner::Sha256Policy;
17use super::{FileUploadSession, XetFileInfo};
18use crate::deduplication::{Chunker, DeduplicationMetrics};
19use crate::error::Result;
20
21pub fn default_config(
22 endpoint: String,
23 token_info: Option<(String, u64)>,
24 token_refresher: Option<Arc<dyn TokenRefresher>>,
25 custom_headers: Option<Arc<HeaderMap>>,
26) -> Result<TranslatorConfig> {
27 let (token, token_expiration) = token_info.unzip();
28 let auth_cfg = AuthConfig::maybe_new(token, token_expiration, token_refresher);
29
30 let session = SessionContext {
31 endpoint,
32 auth: auth_cfg,
33 custom_headers,
34 repo_paths: vec!["".into()],
35 session_id: Some(Ulid::new().to_string()),
36 };
37
38 TranslatorConfig::new(session)
39}
40
41#[instrument(skip_all, name = "clean_bytes", fields(bytes.len = bytes.len()))]
42pub async fn clean_bytes(
43 processor: Arc<FileUploadSession>,
44 bytes: Vec<u8>,
45 sha256_policy: Sha256Policy,
46) -> Result<(XetFileInfo, DeduplicationMetrics)> {
47 let (_id, mut handle) = processor.start_clean(None, Some(bytes.len() as u64), sha256_policy)?;
48 handle.add_data(&bytes).await?;
49 handle.finish().await
50}
51
52#[instrument(skip_all, name = "clean_file", fields(file.name = tracing::field::Empty, file.len = tracing::field::Empty))]
53pub async fn clean_file(
54 processor: Arc<FileUploadSession>,
55 filename: impl AsRef<Path>,
56 sha256_policy: Sha256Policy,
57) -> Result<(XetFileInfo, DeduplicationMetrics)> {
58 let mut reader = File::open(&filename)?;
59
60 let filesize = reader.metadata()?.len();
61 let span = Span::current();
62 span.record("file.name", filename.as_ref().to_str());
63 span.record("file.len", filesize);
64 let mut buffer = vec![0u8; u64::min(filesize, *xet_config().data.ingestion_block_size) as usize];
65
66 let (_id, mut handle) =
67 processor.start_clean(Some(filename.as_ref().to_string_lossy().into()), Some(filesize), sha256_policy)?;
68
69 loop {
70 let bytes = reader.read(&mut buffer)?;
71 if bytes == 0 {
72 break;
73 }
74
75 handle.add_data(&buffer[0..bytes]).await?;
76 }
77
78 handle.finish().await
79}
80
81fn hash_single_file(filename: String, buffer_size: usize) -> Result<XetFileInfo> {
103 let mut reader = File::open(&filename)?;
104 let filesize = reader.metadata()?.len();
105
106 let mut buffer = vec![0u8; buffer_size];
107 let mut chunker = Chunker::default();
108 let mut chunk_hashes: Vec<(MerkleHash, u64)> = Vec::new();
109
110 loop {
111 check_sigint_shutdown()?;
112
113 let bytes_read = reader.read(&mut buffer)?;
114 if bytes_read == 0 {
115 break;
116 }
117
118 let data = Bytes::copy_from_slice(&buffer[0..bytes_read]);
119 let chunks = chunker.next_block_bytes(&data, false);
120
121 for chunk in chunks {
122 chunk_hashes.push((chunk.hash, chunk.data.len() as u64));
123 }
124 }
125
126 if let Some(final_chunk) = chunker.finish() {
128 chunk_hashes.push((final_chunk.hash, final_chunk.data.len() as u64));
129 }
130
131 let file_hash = xet_core_structures::merklehash::file_hash(&chunk_hashes);
132 Ok(XetFileInfo::new(file_hash.hex(), filesize))
133}
134
135#[instrument(skip_all, name = "data_client::hash_files", fields(num_files=file_paths.len()))]
161pub async fn hash_files_async(file_paths: Vec<String>) -> Result<Vec<XetFileInfo>> {
162 let rt = XetRuntime::current();
163 let semaphore = rt.common().file_ingestion_semaphore.clone();
164 let buffer_size = *xet_config().data.ingestion_block_size as usize;
165
166 let hash_futures = file_paths.into_iter().map(|file_path| {
167 let rt = rt.clone();
168 async move {
169 rt.spawn_blocking(move || hash_single_file(file_path, buffer_size))
170 .await
171 .map_err(|e| std::io::Error::other(e.to_string()))?
172 }
173 .instrument(info_span!("hash_file"))
174 });
175
176 let files = run_constrained_with_semaphore(hash_futures, semaphore).await?;
177
178 Ok(files)
179}
180
#[cfg(test)]
mod tests {
    use dirs::home_dir;
    use serial_test::serial;
    use tempfile::tempdir;
    use xet_runtime::utils::EnvVarGuard;

    use super::*;

    /// Endpoint string passed to `default_config` by every configuration test.
    const TEST_ENDPOINT: &str = "http://localhost:8080";

    /// One mebibyte; the buffer and fixture sizes below are multiples of it.
    const MIB: usize = 1024 * 1024;

    /// Calls `default_config` for [`TEST_ENDPOINT`] with no token, refresher, or headers.
    fn build_default_config() -> Result<TranslatorConfig> {
        default_config(TEST_ENDPOINT.to_string(), None, None, None)
    }

    /// Converts a test path into the owned `String` that `hash_single_file` takes.
    fn path_string(path: &Path) -> String {
        path.to_str().unwrap().to_string()
    }

    /// Writes `len` bytes following the repeating pattern 0, 1, ..., 255 to `path`.
    fn write_cyclic_bytes(path: &Path, len: usize) {
        let content: Vec<u8> = (0..len).map(|i| (i % 256) as u8).collect();
        std::fs::write(path, &content).unwrap();
    }

    #[test]
    #[serial(default_config_env)]
    fn test_default_config_with_hf_home() {
        let hf_home = tempdir().unwrap();
        let _hf_home_guard = EnvVarGuard::set("HF_HOME", hf_home.path().to_str().unwrap());

        let outcome = build_default_config();

        assert!(outcome.is_ok());
        let config = outcome.unwrap();
        assert!(config.shard_cache_directory.starts_with(hf_home.path()));
    }

    #[test]
    #[serial(default_config_env)]
    fn test_default_config_with_hf_xet_cache_and_hf_home() {
        let xet_cache_dir = tempdir().unwrap();
        let first_hf_home = tempdir().unwrap();

        let xet_cache_guard = EnvVarGuard::set("HF_XET_CACHE", xet_cache_dir.path().to_str().unwrap());
        let first_hf_home_guard = EnvVarGuard::set("HF_HOME", first_hf_home.path().to_str().unwrap());

        // With both variables set, the shard cache must land under HF_XET_CACHE.
        let outcome = build_default_config();

        assert!(outcome.is_ok());
        let config = outcome.unwrap();
        assert!(config.shard_cache_directory.starts_with(xet_cache_dir.path()));

        drop(xet_cache_guard);
        drop(first_hf_home_guard);

        // With only HF_HOME set, the shard cache must land under HF_HOME.
        let second_hf_home = tempdir().unwrap();
        let _hf_home_guard = EnvVarGuard::set("HF_HOME", second_hf_home.path().to_str().unwrap());

        let outcome = build_default_config();

        assert!(outcome.is_ok());
        let config = outcome.unwrap();
        assert!(config.shard_cache_directory.starts_with(second_hf_home.path()));
    }

    #[test]
    #[serial(default_config_env)]
    fn test_default_config_with_hf_xet_cache() {
        let xet_cache_dir = tempdir().unwrap();
        let _hf_xet_cache_guard = EnvVarGuard::set("HF_XET_CACHE", xet_cache_dir.path().to_str().unwrap());

        let outcome = build_default_config();

        assert!(outcome.is_ok());
        let config = outcome.unwrap();
        assert!(config.shard_cache_directory.starts_with(xet_cache_dir.path()));
    }

    #[test]
    #[serial(default_config_env)]
    fn test_default_config_without_env_vars() {
        let outcome = build_default_config();

        let expected = home_dir().unwrap().join(".cache").join("huggingface").join("xet");

        assert!(outcome.is_ok());
        let config = outcome.unwrap();
        let test_cache_dir = &config.shard_cache_directory;
        assert!(
            test_cache_dir.starts_with(&expected),
            "cache dir = {test_cache_dir:?}; does not start with {expected:?}",
        );
    }

    #[tokio::test]
    async fn test_hash_empty_file() {
        let scratch = tempdir().unwrap();
        let file_path = scratch.path().join("empty.txt");
        std::fs::write(&file_path, b"").unwrap();

        let buffer_size = 8 * MIB;
        let outcome = hash_single_file(path_string(&file_path), buffer_size);
        assert!(outcome.is_ok());

        let file_info = outcome.unwrap();
        assert_eq!(file_info.file_size(), Some(0));
        assert!(!file_info.hash().is_empty());
    }

    #[tokio::test]
    async fn test_hash_small_file() {
        let scratch = tempdir().unwrap();
        let file_path = scratch.path().join("small.txt");
        let content = b"Hello, World!";
        std::fs::write(&file_path, content).unwrap();

        let buffer_size = 8 * MIB;
        let outcome = hash_single_file(path_string(&file_path), buffer_size);
        assert!(outcome.is_ok());

        let file_info = outcome.unwrap();
        assert_eq!(file_info.file_size(), Some(content.len() as u64));
        assert!(!file_info.hash().is_empty());
    }

    #[tokio::test]
    #[cfg_attr(feature = "smoke-test", ignore)]
    async fn test_hash_determinism() {
        let scratch = tempdir().unwrap();
        let file_path = scratch.path().join("test.txt");

        // 20 MiB is not a multiple of the 8 MiB buffer, so the last read is partial.
        write_cyclic_bytes(&file_path, 20 * MIB);

        let file_path_str = path_string(&file_path);

        let with_8mib = hash_single_file(file_path_str.clone(), 8 * MIB);
        assert!(with_8mib.is_ok());
        let file_info1 = with_8mib.unwrap();

        let with_4mib = hash_single_file(file_path_str, 4 * MIB);
        assert!(with_4mib.is_ok());
        let file_info2 = with_4mib.unwrap();

        assert_eq!(file_info1.hash(), file_info2.hash());
        assert_eq!(file_info1.file_size(), file_info2.file_size());
    }

    #[tokio::test]
    async fn test_hash_file_not_found() {
        let buffer_size = 8 * MIB;
        let outcome = hash_single_file("/nonexistent/file.txt".to_string(), buffer_size);
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_hash_files_async() {
        let scratch = tempdir().unwrap();

        let file1_path = scratch.path().join("file1.txt");
        let file2_path = scratch.path().join("file2.txt");

        std::fs::write(&file1_path, b"First file content").unwrap();
        std::fs::write(&file2_path, b"Second file content").unwrap();

        let file_paths = vec![path_string(&file1_path), path_string(&file2_path)];

        let outcome = hash_files_async(file_paths).await;
        assert!(outcome.is_ok());

        let file_infos = outcome.unwrap();
        assert_eq!(file_infos.len(), 2);
        assert_eq!(file_infos[0].file_size(), Some(18));
        assert_eq!(file_infos[1].file_size(), Some(19));
        assert_ne!(file_infos[0].hash(), file_infos[1].hash());
    }

    #[tokio::test]
    #[cfg_attr(feature = "smoke-test", ignore)]
    async fn test_hash_file_size_multiple_of_buffer() {
        let scratch = tempdir().unwrap();
        let file_path = scratch.path().join("multiple_of_buffer.bin");

        // 16 MiB is an exact multiple of each buffer size used below (8, 4 and 2 MiB).
        let file_size = 16 * MIB;
        write_cyclic_bytes(&file_path, file_size);

        let file_path_str = path_string(&file_path);

        let with_8mib = hash_single_file(file_path_str.clone(), 8 * MIB);
        assert!(with_8mib.is_ok());
        let file_info1 = with_8mib.unwrap();
        assert_eq!(file_info1.file_size(), Some(file_size as u64));
        assert!(!file_info1.hash().is_empty());

        let with_4mib = hash_single_file(file_path_str.clone(), 4 * MIB);
        assert!(with_4mib.is_ok());
        let file_info2 = with_4mib.unwrap();

        let with_2mib = hash_single_file(file_path_str, 2 * MIB);
        assert!(with_2mib.is_ok());
        let file_info3 = with_2mib.unwrap();

        assert_eq!(file_info1.hash(), file_info2.hash(), "Hash mismatch between 8MB and 4MB buffer sizes");
        assert_eq!(file_info1.hash(), file_info3.hash(), "Hash mismatch between 8MB and 2MB buffer sizes");
        assert_eq!(file_info1.file_size(), file_info2.file_size());
        assert_eq!(file_info1.file_size(), file_info3.file_size());
    }
}