xet_runtime/core/
common.rs1use std::sync::atomic::AtomicU64;
2use std::sync::{Arc, Mutex};
3
4use reqwest::Client;
5use tokio::sync::Semaphore;
6
7use crate::config::XetConfig;
8use crate::utils::adjustable_semaphore::AdjustableSemaphore;
9
10#[derive(Debug)]
14pub struct XetCommon {
15 global_reqwest_client: Mutex<Option<(String, Client)>>,
18
19 pub file_ingestion_semaphore: Arc<Semaphore>,
21
22 pub file_download_semaphore: Arc<Semaphore>,
24
25 pub reconstruction_download_buffer: Arc<AdjustableSemaphore>,
27
28 pub active_downloads: Arc<AtomicU64>,
30}
31
32impl XetCommon {
33 pub fn new(config: &XetConfig) -> Self {
35 Self {
36 global_reqwest_client: Mutex::new(None),
37 file_ingestion_semaphore: Arc::new(Semaphore::new(config.data.max_concurrent_file_ingestion)),
38 file_download_semaphore: Arc::new(Semaphore::new(config.data.max_concurrent_file_downloads)),
39 reconstruction_download_buffer: {
40 let base = config.reconstruction.download_buffer_size.as_u64();
41 let limit = config.reconstruction.download_buffer_limit.as_u64();
42 AdjustableSemaphore::new(base, (base, limit))
43 },
44 active_downloads: Arc::new(AtomicU64::new(0)),
45 }
46 }
47
48 pub fn get_or_create_reqwest_client<F>(&self, tag: String, create_client_fn: F) -> crate::error::Result<Client>
57 where
58 F: FnOnce() -> std::result::Result<Client, reqwest::Error>,
59 {
60 let mut guard = self.global_reqwest_client.lock()?;
61
62 match guard.as_ref() {
63 Some((cached_tag, cached_client)) if cached_tag == &tag => {
64 Ok(cached_client.clone())
66 },
67 _ => {
68 let new_client = create_client_fn()?;
70 *guard = Some((tag, new_client.clone()));
71 Ok(new_client)
72 },
73 }
74 }
75}
76
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;

    /// Reads the tag currently held in the client cache; panics if nothing is cached.
    fn cached_tag(common: &XetCommon) -> String {
        let guard = common.global_reqwest_client.lock().unwrap();
        guard.as_ref().map(|(tag, _)| tag.clone()).unwrap()
    }

    #[test]
    fn test_get_or_create_reqwest_client_caches_by_tag() {
        let common = XetCommon::new(&XetConfig::new());
        let builds = AtomicUsize::new(0);

        // Ask for the same tag twice; only the first request may invoke the builder.
        for _ in 0..2 {
            common
                .get_or_create_reqwest_client("test-tag".to_string(), || {
                    builds.fetch_add(1, Ordering::SeqCst);
                    reqwest::Client::builder().build()
                })
                .unwrap();
        }

        assert_eq!(builds.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_get_or_create_reqwest_client_creates_new_for_different_tag() {
        let common = XetCommon::new(&XetConfig::new());
        let builds = AtomicUsize::new(0);

        // Every distinct tag must trigger its own build.
        for &(tag, agent) in &[("tag1", "client1"), ("tag2", "client2")] {
            common
                .get_or_create_reqwest_client(tag.to_string(), || {
                    builds.fetch_add(1, Ordering::SeqCst);
                    reqwest::Client::builder().user_agent(agent).build()
                })
                .unwrap();
        }

        assert_eq!(builds.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_initializes_with_empty_client_cache() {
        let common = XetCommon::new(&XetConfig::new());
        assert!(common.global_reqwest_client.lock().unwrap().is_none());
    }

    #[test]
    fn test_replaces_client_when_tag_changes() {
        let common = XetCommon::new(&XetConfig::new());

        common
            .get_or_create_reqwest_client("tcp".to_string(), || {
                reqwest::Client::builder().user_agent("tcp-client").build()
            })
            .unwrap();
        assert_eq!(cached_tag(&common), "tcp");

        common
            .get_or_create_reqwest_client("/tmp/socket.sock".to_string(), || {
                reqwest::Client::builder().user_agent("uds-client").build()
            })
            .unwrap();
        assert_eq!(cached_tag(&common), "/tmp/socket.sock");
    }

    #[test]
    fn test_semaphores_initialized_from_config() {
        let config = XetConfig::new();
        let common = XetCommon::new(&config);

        let ingestion_permits = common.file_ingestion_semaphore.available_permits();
        let download_permits = common.file_download_semaphore.available_permits();
        assert_eq!(ingestion_permits, config.data.max_concurrent_file_ingestion);
        assert_eq!(download_permits, config.data.max_concurrent_file_downloads);

        let buffer_floor = config.reconstruction.download_buffer_size.as_u64();
        assert!(common.reconstruction_download_buffer.total_permits() >= buffer_floor);

        assert_eq!(common.active_downloads.load(Ordering::Relaxed), 0);
    }
}