// xet_data/processing/test_utils.rs

1use std::fs::{File, create_dir_all, read_dir};
2use std::io::{Read, Seek, SeekFrom, Write};
3use std::path::{Path, PathBuf};
4use std::sync::Arc;
5
6use itertools::multizip;
7use rand::prelude::*;
8use tempfile::TempDir;
9use xet_client::cas_client::{Client, LocalClient};
10#[cfg(feature = "simulation")]
11use xet_client::cas_client::{LocalTestServer, LocalTestServerBuilder};
12
13use super::configurations::TranslatorConfig;
14use super::data_client::clean_file;
15use super::file_cleaner::Sha256Policy;
16use super::{FileDownloadSession, FileUploadSession, XetFileInfo};
17
/// Selects which reconstruction path a test exercises when hydrating (downloading/smudging) files.
///
/// Variants:
/// - `DirectClient`: talks to a `LocalClient` directly, with no HTTP server involved.
/// - `ServerV2`: goes through a `LocalTestServer` using the default V2 reconstruction.
/// - `ServerV1Fallback`: goes through a `LocalTestServer` with V2 disabled, forcing the V1 fallback.
/// - `ServerMaxRanges2`: goes through a `LocalTestServer` with `max_ranges_per_fetch=2`, forcing
///   multi-range fetch splitting in V2 responses.
///
/// The server-backed variants are only compiled when the `simulation` feature is enabled.
#[derive(Debug, Clone, Copy)]
pub enum HydrationMode {
    DirectClient,
    #[cfg(feature = "simulation")]
    ServerV2,
    #[cfg(feature = "simulation")]
    ServerV1Fallback,
    #[cfg(feature = "simulation")]
    ServerMaxRanges2,
}

impl HydrationMode {
    /// Every mode compiled into this build, in declaration order.
    const ALL: &'static [HydrationMode] = &[
        HydrationMode::DirectClient,
        #[cfg(feature = "simulation")]
        HydrationMode::ServerV2,
        #[cfg(feature = "simulation")]
        HydrationMode::ServerV1Fallback,
        #[cfg(feature = "simulation")]
        HydrationMode::ServerMaxRanges2,
    ];

    /// Returns all hydration modes available in this build.
    pub fn all() -> &'static [HydrationMode] {
        Self::ALL
    }

    /// Returns `true` for every mode except `DirectClient`.
    pub fn uses_server(&self) -> bool {
        !matches!(self, HydrationMode::DirectClient)
    }
}

impl std::fmt::Display for HydrationMode {
    /// Writes the mode's snake_case label (e.g. `direct_client`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            HydrationMode::DirectClient => "direct_client",
            #[cfg(feature = "simulation")]
            HydrationMode::ServerV2 => "server_v2",
            #[cfg(feature = "simulation")]
            HydrationMode::ServerV1Fallback => "server_v1_fallback",
            #[cfg(feature = "simulation")]
            HydrationMode::ServerMaxRanges2 => "server_max_ranges_2",
        };
        f.write_str(label)
    }
}
72
73/// Creates or overwrites a single file in `dir` with `size` bytes of random data.
74/// Panics on any I/O error. Returns the total number of bytes written (=`size`).
75pub fn create_random_file(path: impl AsRef<Path>, size: usize, seed: u64) -> usize {
76    let path = path.as_ref();
77
78    let dir = path.parent().unwrap();
79
80    // Make sure the directory exists, or create it.
81    create_dir_all(dir).unwrap();
82
83    let mut rng = StdRng::seed_from_u64(seed);
84
85    // Build the path to the file, create the file, and write random data.
86    let mut file = File::create(path).unwrap();
87
88    let mut buffer = vec![0_u8; size];
89    rng.fill_bytes(&mut buffer);
90
91    file.write_all(&buffer).unwrap();
92
93    size
94}
95
96/// Creates a collection of random files, each with a deterministic seed.
97/// the total number of bytes written for all files combined.
98pub fn create_random_files(dir: impl AsRef<Path>, files: &[(impl AsRef<str>, usize)], seed: u64) -> usize {
99    let dir = dir.as_ref();
100
101    let mut total_bytes = 0;
102    let mut rng = SmallRng::seed_from_u64(seed);
103
104    for (file_name, size) in files {
105        total_bytes += create_random_file(dir.join(file_name.as_ref()), *size, rng.random());
106    }
107    total_bytes
108}
109
110/// Creates or overwrites a single file in `dir` with consecutive segments determined by the list of [(size, seed)].
111/// Panics on any I/O error. Returns the total number of bytes written (=`size`).
112pub fn create_random_multipart_file(path: impl AsRef<Path>, segments: &[(usize, u64)]) -> usize {
113    let path = path.as_ref();
114    let dir = path.parent().unwrap();
115
116    // Make sure the directory exists, or create it.
117    create_dir_all(dir).unwrap();
118
119    // Build the path to the file, create the file, and write random data.
120    let mut file = File::create(path).unwrap();
121
122    let mut total_size = 0;
123    for &(size, seed) in segments {
124        let mut rng = StdRng::seed_from_u64(seed);
125
126        let mut buffer = vec![0_u8; size];
127        rng.fill_bytes(&mut buffer);
128        file.write_all(&buffer).unwrap();
129        total_size += size;
130    }
131    total_size
132}
133
/// Returns the sorted file names found directly inside `dir`.
///
/// Panics if `dir` cannot be read or if any entry is not a regular file; the panic message names
/// the offending entry and directory.
fn sorted_file_names(dir: &Path) -> Vec<std::ffi::OsString> {
    let mut names: Vec<_> = read_dir(dir)
        .unwrap()
        .map(|entry| {
            let entry = entry.unwrap();
            assert!(
                entry.file_type().unwrap().is_file(),
                "Directory {:?} contains a non-file entry: {:?}",
                dir,
                entry.file_name()
            );
            entry.file_name()
        })
        .collect();
    names.sort();
    names
}

/// Panics if `dir1` and `dir2` differ in terms of files or file contents.
///
/// Both directories must be flat: any entry that is not a regular file (for example a
/// subdirectory) triggers a panic. File names are compared as sorted sets, then each pair of
/// same-named files is compared byte-for-byte.
///
/// Uses `unwrap()` everywhere; intended for test-only use.
pub fn verify_directories_match(dir1: impl AsRef<Path>, dir2: impl AsRef<Path>) {
    let dir1 = dir1.as_ref();
    let dir2 = dir2.as_ref();

    let files_in_dir1 = sorted_file_names(dir1);
    let files_in_dir2 = sorted_file_names(dir2);

    if files_in_dir1 != files_in_dir2 {
        panic!(
            "Directories differ: file sets are not the same.\n \
             dir1: {files_in_dir1:?}\n dir2: {files_in_dir2:?}"
        );
    }

    // Reads a whole file into memory; the files compared here are test fixtures.
    let read_all = |path: &Path| {
        let mut contents = Vec::new();
        File::open(path).unwrap().read_to_end(&mut contents).unwrap();
        contents
    };

    // Compare file contents byte-for-byte
    for file_name in &files_in_dir1 {
        let path1 = dir1.join(file_name);
        let path2 = dir2.join(file_name);

        if read_all(&path1) != read_all(&path2) {
            panic!(
                "File contents differ for {file_name:?}\n \
                 dir1 path: {path1:?}\n dir2 path: {path2:?}"
            );
        }
    }
}
183
184pub struct HydrateDehydrateTest {
185    _temp_dir: TempDir,
186    pub cas_dir: PathBuf,
187    pub src_dir: PathBuf,
188    pub ptr_dir: PathBuf,
189    pub dest_dir: PathBuf,
190    use_test_server: bool,
191    /// Kept alive so the test server stays running for the duration of the test.
192    #[cfg(feature = "simulation")]
193    test_server: Option<LocalTestServer>,
194}
195
196impl Default for HydrateDehydrateTest {
197    fn default() -> Self {
198        Self::new(false)
199    }
200}
201
202impl HydrateDehydrateTest {
203    /// Creates a new test harness with the specified options.
204    ///
205    /// # Arguments
206    /// * `use_test_server` - If true, uses a LocalTestServer (RemoteClient over HTTP); otherwise uses LocalClient
207    ///   directly.
208    pub fn new(use_test_server: bool) -> Self {
209        let _temp_dir = TempDir::new().unwrap();
210        let temp_path = _temp_dir.path();
211
212        let cas_dir = temp_path.join("cas");
213        let src_dir = temp_path.join("src");
214        let ptr_dir = temp_path.join("pointers");
215        let dest_dir = temp_path.join("dest");
216
217        std::fs::create_dir_all(&cas_dir).unwrap();
218        std::fs::create_dir_all(&src_dir).unwrap();
219        std::fs::create_dir_all(&ptr_dir).unwrap();
220        std::fs::create_dir_all(&dest_dir).unwrap();
221
222        Self {
223            cas_dir,
224            src_dir,
225            ptr_dir,
226            dest_dir,
227            _temp_dir,
228            use_test_server,
229            #[cfg(feature = "simulation")]
230            test_server: None,
231        }
232    }
233
234    /// Creates a new test harness configured for a specific hydration mode.
235    pub fn for_mode(mode: HydrationMode) -> Self {
236        Self::new(mode.uses_server())
237    }
238
239    /// Applies hydration mode configuration to the test server.
240    /// Must be called after `dehydrate()` and before `hydrate()`.
241    pub async fn apply_hydration_mode(&mut self, mode: HydrationMode) {
242        match mode {
243            HydrationMode::DirectClient => {},
244            #[cfg(feature = "simulation")]
245            HydrationMode::ServerV2 => {
246                self.ensure_server_created().await;
247            },
248            #[cfg(feature = "simulation")]
249            HydrationMode::ServerV1Fallback => {
250                self.ensure_server_created().await;
251                self.test_server.as_ref().unwrap().client().disable_v2_reconstruction(404);
252            },
253            #[cfg(feature = "simulation")]
254            HydrationMode::ServerMaxRanges2 => {
255                self.ensure_server_created().await;
256                self.test_server.as_ref().unwrap().client().set_max_ranges_per_fetch(2);
257            },
258        }
259    }
260
261    /// Ensures the test server is running, creating it if necessary.
262    /// Call this before configuring the server (e.g., disabling V2 or setting max ranges).
263    #[cfg(feature = "simulation")]
264    pub async fn ensure_server_created(&mut self) {
265        if self.use_test_server && self.test_server.is_none() {
266            let local_client = LocalClient::new(self.cas_dir.join("xet/xorbs")).await.unwrap();
267            self.test_server = Some(LocalTestServerBuilder::new().with_client(local_client).start().await);
268        }
269    }
270
271    /// Returns a reference to the test server, if one has been created.
272    #[cfg(feature = "simulation")]
273    pub fn test_server(&self) -> Option<&LocalTestServer> {
274        self.test_server.as_ref()
275    }
276
277    /// Lazily initializes the test server (if needed) and returns a CAS client.
278    async fn get_or_create_client(&mut self) -> Arc<dyn Client> {
279        if self.use_test_server {
280            #[cfg(feature = "simulation")]
281            {
282                if self.test_server.is_none() {
283                    let local_client = LocalClient::new(self.cas_dir.join("xet/xorbs")).await.unwrap();
284                    self.test_server = Some(LocalTestServerBuilder::new().with_client(local_client).start().await);
285                }
286                self.test_server.as_ref().unwrap().remote_client().clone() as Arc<dyn Client>
287            }
288            #[cfg(not(feature = "simulation"))]
289            {
290                panic!("test server requires the 'simulation' feature");
291            }
292        } else {
293            LocalClient::new(self.cas_dir.join("xet/xorbs")).await.unwrap() as Arc<dyn Client>
294        }
295    }
296
297    pub async fn new_upload_session(&self) -> Arc<FileUploadSession> {
298        let config = Arc::new(TranslatorConfig::local_config(&self.cas_dir).unwrap());
299        FileUploadSession::new(config.clone()).await.unwrap()
300    }
301
302    pub async fn clean_all_files(&self, upload_session: &Arc<FileUploadSession>, sequential: bool) {
303        create_dir_all(&self.ptr_dir).unwrap();
304
305        if sequential {
306            for entry in read_dir(&self.src_dir).unwrap() {
307                let entry = entry.unwrap();
308                let out_file = self.ptr_dir.join(entry.file_name());
309                let upload_session = upload_session.clone();
310
311                if sequential {
312                    let (pf, metrics) = clean_file(upload_session.clone(), entry.path(), Sha256Policy::Compute)
313                        .await
314                        .unwrap();
315                    assert_eq!({ metrics.total_bytes }, entry.metadata().unwrap().len());
316                    std::fs::write(out_file, pf.as_pointer_file().unwrap().as_bytes()).unwrap();
317
318                    // Force a checkpoint after every file.
319                    upload_session.checkpoint().await.unwrap();
320                }
321            }
322        } else {
323            let files: Vec<PathBuf> = read_dir(&self.src_dir)
324                .unwrap()
325                .map(|entry| self.src_dir.join(entry.unwrap().file_name()))
326                .collect();
327
328            let files_and_sha256 = multizip((files.iter(), std::iter::repeat_with(|| Sha256Policy::Compute)));
329
330            let clean_results = upload_session.upload_files(files_and_sha256).await.unwrap();
331
332            for (i, xf) in clean_results.into_iter().enumerate() {
333                std::fs::write(self.ptr_dir.join(files[i].file_name().unwrap()), serde_json::to_string(&xf).unwrap())
334                    .unwrap();
335            }
336        }
337    }
338
339    pub async fn dehydrate(&mut self, sequential: bool) {
340        let upload_session = self.new_upload_session().await;
341        self.clean_all_files(&upload_session, sequential).await;
342
343        upload_session.finalize().await.unwrap();
344    }
345
346    pub async fn hydrate(&mut self) {
347        let client = self.get_or_create_client().await;
348        let session = FileDownloadSession::from_client(client, None);
349
350        for entry in read_dir(&self.ptr_dir).unwrap() {
351            let entry = entry.unwrap();
352            let out_filename = self.dest_dir.join(entry.file_name());
353
354            let xf: XetFileInfo = serde_json::from_reader(File::open(entry.path()).unwrap()).unwrap();
355            let (_id, _) = session.download_file(&xf, &out_filename).await.unwrap();
356        }
357    }
358
359    pub async fn hydrate_partitioned_writers(&mut self, partitions: usize) {
360        let client = self.get_or_create_client().await;
361        let session = FileDownloadSession::from_client(client, None);
362
363        for entry in read_dir(&self.ptr_dir).unwrap() {
364            let entry = entry.unwrap();
365            let out_filename = self.dest_dir.join(entry.file_name());
366            let xf: XetFileInfo = serde_json::from_reader(File::open(entry.path()).unwrap()).unwrap();
367            let file_size = xf.file_size().expect("file size required for partitioned hydration");
368
369            let out_file = File::create(&out_filename).unwrap();
370            out_file.set_len(file_size).unwrap();
371
372            if file_size == 0 {
373                continue;
374            }
375
376            let partition_count = partitions.max(1) as u64;
377            let mut tasks = Vec::new();
378
379            for idx in 0..partition_count {
380                let start = (idx * file_size) / partition_count;
381                let end = ((idx + 1) * file_size) / partition_count;
382
383                if start == end {
384                    continue;
385                }
386
387                let session = session.clone();
388                let xf = xf.clone();
389                let out_filename = out_filename.clone();
390                tasks.push(tokio::spawn(async move {
391                    let mut writer = std::fs::OpenOptions::new().write(true).open(out_filename).unwrap();
392                    writer.seek(SeekFrom::Start(start)).unwrap();
393                    session.download_to_writer(&xf, start..end, writer).await
394                }));
395            }
396
397            for task in tasks {
398                task.await.unwrap().unwrap();
399            }
400        }
401    }
402
403    pub async fn hydrate_stream(&mut self) {
404        let client = self.get_or_create_client().await;
405        let session = FileDownloadSession::from_client(client, None);
406
407        for entry in read_dir(&self.ptr_dir).unwrap() {
408            let entry = entry.unwrap();
409            let out_filename = self.dest_dir.join(entry.file_name());
410
411            let xf: XetFileInfo = serde_json::from_reader(File::open(entry.path()).unwrap()).unwrap();
412            let (_id, mut stream) = session.download_stream(&xf, None).await.unwrap();
413
414            let mut file = File::create(&out_filename).unwrap();
415            while let Some(chunk) = stream.next().await.unwrap() {
416                file.write_all(&chunk).unwrap();
417            }
418        }
419    }
420
421    pub fn verify_src_dest_match(&self) {
422        verify_directories_match(&self.src_dir, &self.dest_dir);
423    }
424}
425
426/// Provides a test environment with a config suitable for `FileUploadSession` / `FileDownloadSession`.
427///
428/// When the `simulation` feature is enabled the environment spins up a `LocalTestServer` and
429/// returns a server-backed config; otherwise it falls back to `LocalClient` via `local_config`.
430pub struct TestEnvironment {
431    _temp_dir: TempDir,
432    pub base_dir: PathBuf,
433    pub config: Arc<super::configurations::TranslatorConfig>,
434    #[cfg(feature = "simulation")]
435    _server: Option<LocalTestServer>,
436}
437
438impl TestEnvironment {
439    pub async fn new() -> Self {
440        let temp_dir = TempDir::new().unwrap();
441        let base_dir = temp_dir.path().to_path_buf();
442
443        #[cfg(feature = "simulation")]
444        let (config, server) = {
445            let server = LocalTestServerBuilder::new().start().await;
446            let config = Arc::new(
447                super::configurations::TranslatorConfig::test_server_config(server.http_endpoint(), &base_dir).unwrap(),
448            );
449            (config, Some(server))
450        };
451
452        #[cfg(not(feature = "simulation"))]
453        let config = Arc::new(super::configurations::TranslatorConfig::local_config(&base_dir).unwrap());
454
455        Self {
456            _temp_dir: temp_dir,
457            base_dir,
458            config,
459            #[cfg(feature = "simulation")]
460            _server: server,
461        }
462    }
463}