xet_runtime/utils/
guards.rs1use std::env;
2use std::ffi::OsStr;
3use std::path::{Path, PathBuf};
4
/// RAII guard that sets an environment variable and restores its previous
/// state when dropped.
///
/// If the variable was unset when the guard was created, it is removed again
/// on drop; otherwise its previous value is written back. The previous value
/// is captured as an [`OsString`](std::ffi::OsString), so values that are not
/// valid UTF-8 are restored faithfully instead of being treated as absent.
///
/// Several guards for the same key nest correctly as long as they are dropped
/// in reverse order of creation, which ordinary lexical scoping guarantees.
///
/// The process environment is global state: this guard does not make
/// concurrent access safe. Callers must ensure no other thread reads or writes
/// the environment while a guard is alive (this file's tests use `#[serial]`
/// for that).
#[must_use = "the variable is restored as soon as the guard is dropped; bind it with `let _guard = ...`"]
pub struct EnvVarGuard {
    /// Name of the variable managed by this guard.
    key: &'static str,
    /// Value the variable had before `EnvVarGuard::set`, or `None` if it was unset.
    prev: Option<std::ffi::OsString>,
}

impl EnvVarGuard {
    /// Sets `key` to `value` and returns a guard that restores the previous
    /// state of `key` when dropped.
    ///
    /// # Panics
    ///
    /// As documented for [`std::env::set_var`], this may panic if `key` is
    /// empty or contains `=` or NUL, or if `value` contains NUL.
    pub fn set(key: &'static str, value: impl AsRef<OsStr>) -> Self {
        // `var_os` rather than `var`: `var` yields `Err` for non-UTF-8 values,
        // which would make the guard delete the variable on drop instead of
        // restoring it.
        let prev = env::var_os(key);
        // SAFETY: mutating the environment is only sound when no other thread
        // accesses it concurrently; this type relies on its callers to uphold
        // that (see the type-level docs).
        unsafe {
            env::set_var(key, value);
        }
        Self { key, prev }
    }
}

impl Drop for EnvVarGuard {
    fn drop(&mut self) {
        // SAFETY: same single-threaded-access requirement as in `EnvVarGuard::set`.
        unsafe {
            match &self.prev {
                Some(prev) => env::set_var(self.key, prev),
                None => env::remove_var(self.key),
            }
        }
    }
}
56
/// RAII guard that changes the process's current working directory and
/// restores the previous one when dropped.
///
/// The working directory is process-wide state, so callers must ensure no
/// other thread depends on it while a guard is alive (this file's tests use
/// `#[serial]` for that). Nested guards restore correctly when dropped in
/// reverse order of creation.
///
/// Restoration on drop is best-effort: if the previous directory can no
/// longer be entered (for example because it was deleted), the error is
/// ignored. `Drop` has no way to report it, and panicking there would abort
/// the process if the drop happens during unwinding.
#[must_use = "the previous directory is restored as soon as the guard is dropped; bind it with `let _guard = ...`"]
pub struct CwdGuard {
    /// Working directory captured (via `env::current_dir`) before the change.
    prev: PathBuf,
}

impl CwdGuard {
    /// Changes the current working directory to `new_dir` and returns a guard
    /// that switches back to the previous directory when dropped.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if the current directory cannot be
    /// determined or if `new_dir` cannot be entered. In either case no guard
    /// is returned, so nothing will be restored later.
    pub fn set(new_dir: impl AsRef<Path>) -> std::io::Result<Self> {
        let prev = env::current_dir()?;
        env::set_current_dir(new_dir)?;
        Ok(Self { prev })
    }
}

impl Drop for CwdGuard {
    fn drop(&mut self) {
        // Best-effort restore; see the type-level docs for why errors are ignored.
        let _ = env::set_current_dir(&self.prev);
    }
}
99
/// Scope guard that runs a closure exactly once when it is dropped.
///
/// `Drop` also runs while a panic unwinds the stack, so the closure runs on
/// both normal scope exit and unwinding. A closure that itself panics while
/// the thread is already unwinding aborts the process, so cleanup closures
/// should avoid panicking.
#[must_use = "dropping the guard runs the closure immediately; bind it with `let _guard = ...`"]
pub struct ClosureGuard<F: FnOnce()> {
    /// The pending closure; `take`n (leaving `None`) when it runs.
    closure: Option<F>,
}

impl<F: FnOnce()> ClosureGuard<F> {
    /// Creates a guard that will call `f` when the guard is dropped.
    pub fn new(f: F) -> Self {
        Self { closure: Some(f) }
    }
}

impl<F: FnOnce()> Drop for ClosureGuard<F> {
    fn drop(&mut self) {
        // `FnOnce` must be moved out to be called; `take` does that through `&mut self`.
        if let Some(f) = self.closure.take() {
            f();
        }
    }
}
131
#[cfg(test)]
mod tests {
    use serial_test::serial;
    use tempfile::tempdir;

    use super::*;

    #[test]
    #[serial(default_config_env)]
    fn env_var_guard_sets_and_restores() {
        let key = "TEST_ENV_VAR_GUARD";

        // SAFETY: this test is serialized with the other environment-mutating
        // tests via `#[serial(default_config_env)]`.
        unsafe {
            env::set_var(key, "initial");
        }

        {
            let _guard = EnvVarGuard::set(key, "temporary");
            assert_eq!(env::var(key).unwrap(), "temporary");
        }

        assert_eq!(env::var(key).unwrap(), "initial");

        // Don't leak the fixture value into the rest of the test process.
        // SAFETY: as above.
        unsafe {
            env::remove_var(key);
        }
    }

    #[test]
    #[serial(default_config_env)]
    fn env_var_guard_restores_none_when_var_did_not_exist() {
        let key = "TEST_ENV_VAR_GUARD_NEW";

        // SAFETY: serialized via `#[serial(default_config_env)]`.
        unsafe {
            env::remove_var(key);
        }
        assert!(env::var(key).is_err());

        {
            let _guard = EnvVarGuard::set(key, "temporary");
            assert_eq!(env::var(key).unwrap(), "temporary");
        }

        assert!(env::var(key).is_err());
    }

    #[test]
    #[serial(default_config_env)]
    fn env_var_guard_multiple_guards_same_key() {
        let key = "TEST_ENV_VAR_GUARD_MULTI";

        // SAFETY: serialized via `#[serial(default_config_env)]`.
        unsafe {
            env::set_var(key, "initial");
        }

        {
            let _guard1 = EnvVarGuard::set(key, "first");
            assert_eq!(env::var(key).unwrap(), "first");

            {
                let _guard2 = EnvVarGuard::set(key, "second");
                assert_eq!(env::var(key).unwrap(), "second");
            }

            assert_eq!(env::var(key).unwrap(), "first");
        }

        assert_eq!(env::var(key).unwrap(), "initial");

        // Don't leak the fixture value into the rest of the test process.
        // SAFETY: serialized via `#[serial(default_config_env)]`.
        unsafe {
            env::remove_var(key);
        }
    }

    #[test]
    #[serial(default_config_env)]
    fn cwd_guard_changes_and_restores() {
        // Canonicalize so platform symlinks (e.g. /tmp -> /private/tmp) don't break comparisons.
        let cur_dir = || env::current_dir().unwrap().canonicalize().unwrap();
        let tmp1 = tempdir().unwrap();
        let tmp2 = tempdir().unwrap();

        let original_dir = cur_dir();

        {
            let _guard = CwdGuard::set(tmp1.path()).unwrap();
            assert_eq!(cur_dir(), tmp1.path().canonicalize().unwrap());

            {
                let _guard2 = CwdGuard::set(tmp2.path()).unwrap();
                assert_eq!(cur_dir(), tmp2.path().canonicalize().unwrap());
            }

            assert_eq!(cur_dir(), tmp1.path().canonicalize().unwrap());
        }

        assert_eq!(cur_dir(), original_dir);
    }

    #[test]
    #[serial(default_config_env)]
    fn cwd_guard_handles_nonexistent_directory_error() {
        let before = env::current_dir().unwrap();
        let nonexistent = Path::new("/nonexistent/path/that/does/not/exist");

        let result = CwdGuard::set(nonexistent);
        assert!(result.is_err());
        // A failed `set` must not leave the process in a different directory.
        assert_eq!(env::current_dir().unwrap(), before);
    }

    #[test]
    fn closure_guard_runs_on_drop() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};

        let ran = Arc::new(AtomicBool::new(false));
        let ran_clone = Arc::clone(&ran);
        {
            let _guard = ClosureGuard::new(move || ran_clone.store(true, Ordering::SeqCst));
            assert!(!ran.load(Ordering::SeqCst));
        }
        assert!(ran.load(Ordering::SeqCst));
    }

    #[test]
    fn closure_guard_runs_during_unwind() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};

        let ran = Arc::new(AtomicBool::new(false));
        let ran_clone = Arc::clone(&ran);
        let result: std::thread::Result<()> = std::panic::catch_unwind(move || {
            let _guard = ClosureGuard::new(move || ran_clone.store(true, Ordering::SeqCst));
            panic!("intentional panic to exercise unwinding")
        });

        assert!(result.is_err());
        assert!(ran.load(Ordering::SeqCst));
    }
}
252}