Skip to main content

xet_runtime/core/
common.rs

1use std::sync::atomic::AtomicU64;
2use std::sync::{Arc, Mutex};
3
4use reqwest::Client;
5use tokio::sync::Semaphore;
6
7use crate::config::XetConfig;
8use crate::utils::adjustable_semaphore::AdjustableSemaphore;
9
10/// Holds global values that are shared across the entire runtime.
11///
12/// Accessible via `XetRuntime::current().common()`.
#[derive(Debug)]
pub struct XetCommon {
    /// A cached reqwest `Client` to be shared by all high-level clients.
    ///
    /// The `String` tag identifies the client type (e.g. `"tcp"` for regular connections, or the
    /// socket path for UDS). At most one `(tag, client)` pair is held at a time; it is populated
    /// and replaced by `get_or_create_reqwest_client`, and starts out as `None`.
    global_reqwest_client: Mutex<Option<(String, Client)>>,

    /// Limits the number of files being ingested (cleaned/uploaded) concurrently.
    ///
    /// Created with `config.data.max_concurrent_file_ingestion` permits.
    pub file_ingestion_semaphore: Arc<Semaphore>,

    /// Limits the number of files being downloaded concurrently.
    ///
    /// Created with `config.data.max_concurrent_file_downloads` permits.
    pub file_download_semaphore: Arc<Semaphore>,

    /// Limits total memory used for buffering data during reconstruction downloads.
    ///
    /// Built from `config.reconstruction.download_buffer_size` and
    /// `config.reconstruction.download_buffer_limit`.
    pub reconstruction_download_buffer: Arc<AdjustableSemaphore>,

    /// Tracks the number of currently active file downloads for dynamic buffer scaling.
    ///
    /// Starts at zero.
    pub active_downloads: Arc<AtomicU64>,
}
31
32impl XetCommon {
33    /// Creates a new `XetCommon` instance with the given configuration.
34    pub fn new(config: &XetConfig) -> Self {
35        Self {
36            global_reqwest_client: Mutex::new(None),
37            file_ingestion_semaphore: Arc::new(Semaphore::new(config.data.max_concurrent_file_ingestion)),
38            file_download_semaphore: Arc::new(Semaphore::new(config.data.max_concurrent_file_downloads)),
39            reconstruction_download_buffer: {
40                let base = config.reconstruction.download_buffer_size.as_u64();
41                let limit = config.reconstruction.download_buffer_limit.as_u64();
42                AdjustableSemaphore::new(base, (base, limit))
43            },
44            active_downloads: Arc::new(AtomicU64::new(0)),
45        }
46    }
47
48    /// Gets or creates a reqwest client, using a tag to identify the client type.
49    ///
50    /// # Arguments
51    /// * `tag` - A string identifier for the client (e.g., "tcp" for regular, socket path for UDS)
52    /// * `create_client_fn` - A function that creates the client if needed
53    ///
54    /// # Returns
55    /// Returns a clone of the cached client if the tag matches, or creates a new client if the tag differs.
56    pub fn get_or_create_reqwest_client<F>(&self, tag: String, create_client_fn: F) -> crate::error::Result<Client>
57    where
58        F: FnOnce() -> std::result::Result<Client, reqwest::Error>,
59    {
60        let mut guard = self.global_reqwest_client.lock()?;
61
62        match guard.as_ref() {
63            Some((cached_tag, cached_client)) if cached_tag == &tag => {
64                // Tag matches, return a clone of the existing client
65                Ok(cached_client.clone())
66            },
67            _ => {
68                // Tag doesn't match or no client exists, create a new one
69                let new_client = create_client_fn()?;
70                *guard = Some((tag, new_client.clone()));
71                Ok(new_client)
72            },
73        }
74    }
75}
76
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;

    #[test]
    fn test_get_or_create_reqwest_client_caches_by_tag() {
        let common = XetCommon::new(&XetConfig::new());
        let call_count = AtomicUsize::new(0);

        let _client1 = common
            .get_or_create_reqwest_client("test-tag".to_string(), || {
                call_count.fetch_add(1, Ordering::SeqCst);
                reqwest::Client::builder().build()
            })
            .unwrap();

        let _client2 = common
            .get_or_create_reqwest_client("test-tag".to_string(), || {
                call_count.fetch_add(1, Ordering::SeqCst);
                reqwest::Client::builder().build()
            })
            .unwrap();

        // The second call hits the cache, so the factory only ran once.
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_get_or_create_reqwest_client_creates_new_for_different_tag() {
        let common = XetCommon::new(&XetConfig::new());
        let call_count = AtomicUsize::new(0);

        let _client1 = common
            .get_or_create_reqwest_client("tag1".to_string(), || {
                call_count.fetch_add(1, Ordering::SeqCst);
                reqwest::Client::builder().user_agent("client1").build()
            })
            .unwrap();

        let _client2 = common
            .get_or_create_reqwest_client("tag2".to_string(), || {
                call_count.fetch_add(1, Ordering::SeqCst);
                reqwest::Client::builder().user_agent("client2").build()
            })
            .unwrap();

        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_only_most_recent_client_is_cached() {
        let common = XetCommon::new(&XetConfig::new());
        let call_count = AtomicUsize::new(0);

        for tag in ["tcp", "/tmp/socket.sock", "tcp"] {
            let _client = common
                .get_or_create_reqwest_client(tag.to_string(), || {
                    call_count.fetch_add(1, Ordering::SeqCst);
                    reqwest::Client::builder().build()
                })
                .unwrap();
        }

        // Switching back to "tcp" rebuilds the client, because the UDS client replaced it.
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_initializes_with_empty_client_cache() {
        let common = XetCommon::new(&XetConfig::new());

        let guard = common.global_reqwest_client.lock().unwrap();
        assert!(guard.is_none());
    }

    #[test]
    fn test_replaces_client_when_tag_changes() {
        let common = XetCommon::new(&XetConfig::new());

        let _client1 = common
            .get_or_create_reqwest_client("tcp".to_string(), || {
                reqwest::Client::builder().user_agent("tcp-client").build()
            })
            .unwrap();

        {
            let guard = common.global_reqwest_client.lock().unwrap();
            let (tag, _) = guard.as_ref().unwrap();
            assert_eq!(tag, "tcp");
        }

        let _client2 = common
            .get_or_create_reqwest_client("/tmp/socket.sock".to_string(), || {
                reqwest::Client::builder().user_agent("uds-client").build()
            })
            .unwrap();

        {
            let guard = common.global_reqwest_client.lock().unwrap();
            let (tag, _) = guard.as_ref().unwrap();
            assert_eq!(tag, "/tmp/socket.sock");
        }
    }

    #[test]
    fn test_semaphores_initialized_from_config() {
        let config = XetConfig::new();
        let common = XetCommon::new(&config);

        assert_eq!(common.file_ingestion_semaphore.available_permits(), config.data.max_concurrent_file_ingestion);
        assert_eq!(common.file_download_semaphore.available_permits(), config.data.max_concurrent_file_downloads);

        // Total permits is at least the configured `download_buffer_size`. It may be slightly
        // larger due to rounding up to a whole number of internal permits.
        assert!(
            common.reconstruction_download_buffer.total_permits()
                >= config.reconstruction.download_buffer_size.as_u64()
        );

        assert_eq!(common.active_downloads.load(Ordering::Relaxed), 0);
    }
}