1use std::collections::HashSet;
2use std::fs;
3use std::path::{Path, PathBuf};
4
5use crate::error::{Result, YgrepError};
6
/// Resolves filesystem paths, optionally following symlinks, while remembering
/// every canonical path it has handed out so a traversal can skip repeats.
#[derive(Debug)]
pub struct SymlinkResolver {
    /// Canonical paths already recorded by `resolve` or `mark_visited`;
    /// emptied by `reset`.
    visited_canonical: HashSet<PathBuf>,

    /// Depth limit compared against the resolution depth in `resolve`;
    /// exceeding it yields `YgrepError::SymlinkDepthExceeded`.
    max_depth: usize,

    /// When `false`, symlinks are reported as `SkipReason::SymlinkNotFollowed`
    /// instead of being resolved.
    follow_symlinks: bool,
}
18
19impl SymlinkResolver {
20 pub fn new(follow_symlinks: bool, max_depth: usize) -> Self {
21 Self {
22 visited_canonical: HashSet::new(),
23 max_depth,
24 follow_symlinks,
25 }
26 }
27
28 pub fn resolve(&mut self, path: &Path) -> Result<ResolvedPath> {
30 self.resolve_inner(path, 0)
31 }
32
33 fn resolve_inner(&mut self, path: &Path, depth: usize) -> Result<ResolvedPath> {
34 if depth > self.max_depth {
35 return Err(YgrepError::SymlinkDepthExceeded(path.to_path_buf()));
36 }
37
38 let metadata = match fs::symlink_metadata(path) {
39 Ok(m) => m,
40 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
41 return Ok(ResolvedPath::Skipped(SkipReason::NotFound));
42 }
43 Err(e) => return Err(e.into()),
44 };
45
46 if metadata.is_symlink() {
47 if !self.follow_symlinks {
48 return Ok(ResolvedPath::Skipped(SkipReason::SymlinkNotFollowed));
49 }
50
51 let target = match fs::read_link(path) {
53 Ok(t) => t,
54 Err(_) => {
55 return Ok(ResolvedPath::Skipped(SkipReason::BrokenSymlink));
56 }
57 };
58
59 let resolved = if target.is_absolute() {
61 target
62 } else {
63 path.parent()
64 .ok_or_else(|| YgrepError::InvalidPath(path.to_path_buf()))?
65 .join(&target)
66 };
67
68 let canonical = match fs::canonicalize(&resolved) {
70 Ok(c) => c,
71 Err(_) => {
72 return Ok(ResolvedPath::Skipped(SkipReason::BrokenSymlink));
73 }
74 };
75
76 if self.visited_canonical.contains(&canonical) {
78 return Ok(ResolvedPath::Skipped(SkipReason::CircularSymlink));
79 }
80
81 self.visited_canonical.insert(canonical.clone());
82
83 return Ok(ResolvedPath::Resolved {
84 original: path.to_path_buf(),
85 canonical,
86 is_symlink: true,
87 });
88 }
89
90 let canonical = match fs::canonicalize(path) {
92 Ok(c) => c,
93 Err(_) => path.to_path_buf(),
94 };
95
96 if self.visited_canonical.contains(&canonical) {
98 return Ok(ResolvedPath::Skipped(SkipReason::Duplicate));
99 }
100
101 self.visited_canonical.insert(canonical.clone());
102
103 Ok(ResolvedPath::Resolved {
104 original: path.to_path_buf(),
105 canonical,
106 is_symlink: false,
107 })
108 }
109
110 pub fn is_visited(&self, canonical: &Path) -> bool {
112 self.visited_canonical.contains(canonical)
113 }
114
115 pub fn mark_visited(&mut self, canonical: PathBuf) {
117 self.visited_canonical.insert(canonical);
118 }
119
120 pub fn reset(&mut self) {
122 self.visited_canonical.clear();
123 }
124
125 pub fn visited_count(&self) -> usize {
127 self.visited_canonical.len()
128 }
129}
130
131#[derive(Debug, Clone)]
133pub enum ResolvedPath {
134 Resolved {
135 original: PathBuf,
136 canonical: PathBuf,
137 is_symlink: bool,
138 },
139 Skipped(SkipReason),
140}
141
142impl ResolvedPath {
143 pub fn canonical(&self) -> Option<&Path> {
144 match self {
145 ResolvedPath::Resolved { canonical, .. } => Some(canonical),
146 ResolvedPath::Skipped(_) => None,
147 }
148 }
149
150 pub fn is_skipped(&self) -> bool {
151 matches!(self, ResolvedPath::Skipped(_))
152 }
153}
154
/// Why `SymlinkResolver::resolve` declined to resolve a path.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SkipReason {
    /// A followed symlink whose canonical target had already been visited.
    CircularSymlink,
    /// The path is a symlink and the resolver does not follow symlinks.
    SymlinkNotFollowed,
    /// The symlink's target could not be read or canonicalized.
    BrokenSymlink,
    /// A non-symlink path whose canonical form had already been visited.
    Duplicate,
    /// The path itself does not exist.
    NotFound,
}
164
165impl std::fmt::Display for SkipReason {
166 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167 match self {
168 SkipReason::CircularSymlink => write!(f, "circular symlink"),
169 SkipReason::SymlinkNotFollowed => write!(f, "symlink not followed"),
170 SkipReason::BrokenSymlink => write!(f, "broken symlink"),
171 SkipReason::Duplicate => write!(f, "duplicate path"),
172 SkipReason::NotFound => write!(f, "not found"),
173 }
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_regular_file() {
        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("test.txt");
        fs::write(&file_path, "content").unwrap();

        let mut resolver = SymlinkResolver::new(true, 10);
        let result = resolver.resolve(&file_path).unwrap();

        match result {
            ResolvedPath::Resolved { is_symlink, .. } => {
                assert!(!is_symlink);
            }
            _ => panic!("Expected Resolved"),
        }
    }

    #[test]
    fn test_symlink_detection() {
        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("target.txt");
        let link_path = temp_dir.path().join("link.txt");

        fs::write(&file_path, "content").unwrap();

        #[cfg(unix)]
        {
            std::os::unix::fs::symlink(&file_path, &link_path).unwrap();

            let mut resolver = SymlinkResolver::new(true, 10);
            let result = resolver.resolve(&link_path).unwrap();

            match result {
                ResolvedPath::Resolved { is_symlink, .. } => {
                    assert!(is_symlink);
                }
                _ => panic!("Expected Resolved"),
            }
        }
    }

    #[test]
    fn test_duplicate_detection() {
        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("test.txt");
        fs::write(&file_path, "content").unwrap();

        let mut resolver = SymlinkResolver::new(true, 10);

        let result1 = resolver.resolve(&file_path).unwrap();
        assert!(!result1.is_skipped());

        let result2 = resolver.resolve(&file_path).unwrap();
        match result2 {
            ResolvedPath::Skipped(SkipReason::Duplicate) => {}
            _ => panic!("Expected Skipped(Duplicate)"),
        }
    }

    #[test]
    fn test_missing_path_is_not_found() {
        let temp_dir = tempdir().unwrap();
        let mut resolver = SymlinkResolver::new(true, 10);

        let result = resolver.resolve(&temp_dir.path().join("missing.txt")).unwrap();

        assert!(matches!(result, ResolvedPath::Skipped(SkipReason::NotFound)));
        assert_eq!(resolver.visited_count(), 0);
    }

    #[cfg(unix)]
    #[test]
    fn test_symlink_not_followed() {
        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("target.txt");
        let link_path = temp_dir.path().join("link.txt");
        fs::write(&file_path, "content").unwrap();
        std::os::unix::fs::symlink(&file_path, &link_path).unwrap();

        let mut resolver = SymlinkResolver::new(false, 10);
        let result = resolver.resolve(&link_path).unwrap();

        assert!(matches!(
            result,
            ResolvedPath::Skipped(SkipReason::SymlinkNotFollowed)
        ));
    }

    #[cfg(unix)]
    #[test]
    fn test_dangling_symlink_is_broken() {
        let temp_dir = tempdir().unwrap();
        let link_path = temp_dir.path().join("dangling.txt");
        std::os::unix::fs::symlink(temp_dir.path().join("nowhere"), &link_path).unwrap();

        let mut resolver = SymlinkResolver::new(true, 10);
        let result = resolver.resolve(&link_path).unwrap();

        assert!(matches!(result, ResolvedPath::Skipped(SkipReason::BrokenSymlink)));
    }

    #[cfg(unix)]
    #[test]
    fn test_symlink_to_visited_target_is_circular() {
        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("target.txt");
        let link_path = temp_dir.path().join("link.txt");
        fs::write(&file_path, "content").unwrap();
        std::os::unix::fs::symlink(&file_path, &link_path).unwrap();

        let mut resolver = SymlinkResolver::new(true, 10);
        assert!(!resolver.resolve(&file_path).unwrap().is_skipped());

        let result = resolver.resolve(&link_path).unwrap();
        assert!(matches!(
            result,
            ResolvedPath::Skipped(SkipReason::CircularSymlink)
        ));
    }

    #[test]
    fn test_reset_clears_visited() {
        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("test.txt");
        fs::write(&file_path, "content").unwrap();

        let mut resolver = SymlinkResolver::new(true, 10);
        let first = resolver.resolve(&file_path).unwrap();
        let canonical = first.canonical().unwrap().to_path_buf();
        assert!(resolver.is_visited(&canonical));
        assert_eq!(resolver.visited_count(), 1);

        resolver.reset();
        assert_eq!(resolver.visited_count(), 0);
        assert!(!resolver.is_visited(&canonical));
        assert!(!resolver.resolve(&file_path).unwrap().is_skipped());
    }
}