1use std::fs;
10use std::path::Path;
11
12use crate::constants::TMP_SUFFIX;
13use rustix::fd::OwnedFd;
14use rustix::fs::{openat, renameat, symlinkat, unlinkat, AtFlags, Mode, OFlags, CWD};
15use rustix::io::Errno;
16use std::sync::atomic::{AtomicU64, Ordering};
17use std::time::Instant;
18
19fn errno_to_io(e: Errno) -> std::io::Error {
20 std::io::Error::from_raw_os_error(e.raw_os_error())
21}
22
23static NEXT_TMP_COUNTER: AtomicU64 = AtomicU64::new(0);
25
26pub fn open_dir_nofollow(dir: &Path) -> std::io::Result<OwnedFd> {
32 use std::os::unix::ffi::OsStrExt;
33 let c = std::ffi::CString::new(dir.as_os_str().as_bytes())
34 .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid path"))?;
35 openat(
36 CWD,
37 c.as_c_str(),
38 OFlags::RDONLY | OFlags::DIRECTORY | OFlags::CLOEXEC | OFlags::NOFOLLOW,
39 Mode::empty(),
40 )
41 .map_err(errno_to_io)
42}
43
/// Fsync the directory that contains `path`, making a preceding create/rename/unlink of
/// `path` durable.
///
/// A bare relative name such as `"file"` has the empty path as its parent (that is what
/// [`Path::parent`] returns); this is treated as the current directory rather than being
/// passed to `open("")`, which would fail with `ENOENT`. A path with no parent at all
/// (e.g. `/`) is a no-op.
///
/// # Errors
/// Returns any error from opening the parent directory or from syncing it.
pub fn fsync_parent_dir(path: &Path) -> std::io::Result<()> {
    if let Some(parent) = path.parent() {
        // `Path::new("file").parent()` is `Some("")`, meaning "the current directory".
        let parent = if parent.as_os_str().is_empty() {
            Path::new(".")
        } else {
            parent
        };
        fs::File::open(parent)?.sync_all()?;
    }
    Ok(())
}
56
57fn fsync_dirfd(dirfd: &OwnedFd) -> std::io::Result<()> {
61 rustix::fs::fsync(dirfd).map_err(errno_to_io)
62}
63
64pub fn atomic_symlink_swap(
70 source: &Path,
71 target: &Path,
72 allow_degraded: bool,
73 force_exdev: Option<bool>,
74) -> std::io::Result<(bool, u64)> {
75 use std::os::unix::ffi::OsStrExt;
76
77 let parent = target.parent().unwrap_or_else(|| Path::new("."));
78 let fname = target.file_name().ok_or_else(|| {
79 std::io::Error::new(
80 std::io::ErrorKind::InvalidInput,
81 "target must not end with a slash",
82 )
83 })?;
84
85 let pid = std::process::id();
86 let ctr = NEXT_TMP_COUNTER.fetch_add(1, Ordering::Relaxed);
87 let tmp_name = format!(".{}.{}.{}{}", fname.to_string_lossy(), pid, ctr, TMP_SUFFIX);
88
89 let dirfd = open_dir_nofollow(parent)?;
90
91 let tmp_c = std::ffi::CString::new(tmp_name.as_str()).map_err(|_| {
93 std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid tmp cstring")
94 })?;
95 let new_c = std::ffi::CString::new(fname.as_bytes()).map_err(|_| {
96 std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid target name")
97 })?;
98 let src_c = std::ffi::CString::new(source.as_os_str().as_bytes()).map_err(|_| {
99 std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid source path")
100 })?;
101
102 match unlinkat(&dirfd, tmp_c.as_c_str(), AtFlags::empty()) {
104 Ok(()) | Err(Errno::NOENT) => {}
105 Err(e) => return Err(errno_to_io(e)),
106 }
107
108 symlinkat(src_c.as_c_str(), &dirfd, tmp_c.as_c_str()).map_err(errno_to_io)?;
110
111 let rename_res = renameat(&dirfd, tmp_c.as_c_str(), &dirfd, new_c.as_c_str());
113
114 let allow_env_overrides = std::env::var_os("SWITCHYARD_TEST_ALLOW_ENV_OVERRIDES")
116 == Some(std::ffi::OsString::from("1"));
117 let inject_exdev = match force_exdev {
118 Some(b) => b,
119 None => {
120 allow_env_overrides
121 && std::env::var_os("SWITCHYARD_FORCE_EXDEV") == Some(std::ffi::OsString::from("1"))
122 }
123 };
124
125 let rename_res = if inject_exdev {
126 match rename_res {
127 Ok(()) => Err(Errno::XDEV),
128 Err(e) => Err(e),
129 }
130 } else {
131 rename_res
132 };
133
134 match rename_res {
135 Ok(()) => {
136 let t_fsync = Instant::now();
138 let res = fsync_dirfd(&dirfd);
139 let fsync_ms = u64::try_from(t_fsync.elapsed().as_millis()).unwrap_or(u64::MAX);
140
141 if let Err(e) = res {
142 let _ = e;
145 }
146
147 Ok((false, fsync_ms))
148 }
149 Err(e) if e == Errno::XDEV && allow_degraded => {
150 match unlinkat(&dirfd, new_c.as_c_str(), AtFlags::empty()) {
152 Ok(()) | Err(Errno::NOENT) => {}
153 Err(e) => {
154 let _ = unlinkat(&dirfd, tmp_c.as_c_str(), AtFlags::empty());
156 return Err(errno_to_io(e));
157 }
158 }
159 if let Err(e) =
160 symlinkat(src_c.as_c_str(), &dirfd, new_c.as_c_str()).map_err(errno_to_io)
161 {
162 let _ = unlinkat(&dirfd, tmp_c.as_c_str(), AtFlags::empty());
163 return Err(e);
164 }
165
166 let _ = unlinkat(&dirfd, tmp_c.as_c_str(), AtFlags::empty());
168
169 let t_fsync = Instant::now();
170 let res = fsync_dirfd(&dirfd);
171 let fsync_ms = u64::try_from(t_fsync.elapsed().as_millis()).unwrap_or(u64::MAX);
172 if let Err(e) = res {
173 let _ = e;
174 }
175 Ok((true, fsync_ms))
176 }
177 Err(e) => {
178 let _ = unlinkat(&dirfd, tmp_c.as_c_str(), AtFlags::empty());
180 Err(errno_to_io(e))
181 }
182 }
183}