1use std::collections::HashMap;
2use std::io::Read;
3use std::net::TcpStream;
4use std::time::{Duration, Instant};
5
6use ssh2::Session;
7
8use crate::url::WatchTarget;
9use crate::watcher::{
10 ConnectionState, PathWatcher, WatchError, WatchEvent, WatchEventKind, WatchOptions,
11};
12
13enum WatchMode {
14 InotifyPush {
15 channel: ssh2::Channel,
16 buf: Vec<u8>,
17 },
18 StatPoll {
19 known_mtimes: HashMap<String, i64>,
20 last_poll: Instant,
21 },
22}
23
24pub struct SshWatcher {
25 session: Session,
26 target: WatchTarget,
27 mode: WatchMode,
28 pending: Vec<WatchEvent>,
29 poll_interval: Duration,
30 loss_timeout: Duration,
31 last_success: Instant,
32}
33
34impl SshWatcher {
35 pub fn connect(target: WatchTarget, options: &WatchOptions) -> Result<Self, WatchError> {
36 let host = target
37 .host
38 .as_deref()
39 .ok_or_else(|| WatchError::InvalidUrl("SSH requires a host".to_string()))?;
40 let port = target.port.unwrap_or(22);
41
42 let tcp = TcpStream::connect(format!("{host}:{port}"))
43 .map_err(|e| WatchError::Connection(e.to_string()))?;
44
45 let mut session = Session::new().map_err(|e| WatchError::Ssh(e.to_string()))?;
46 session.set_tcp_stream(tcp);
47 session
48 .handshake()
49 .map_err(|e| WatchError::Ssh(e.to_string()))?;
50
51 let user = target.user.as_deref().unwrap_or("root");
52 authenticate(&session, user, options)?;
53
54 let mode = try_inotifywait(&session, &target.path).unwrap_or_else(|| WatchMode::StatPoll {
55 known_mtimes: HashMap::new(),
56 last_poll: Instant::now() - options.poll_interval,
57 });
58
59 Ok(Self {
60 session,
61 target,
62 mode,
63 pending: Vec::new(),
64 poll_interval: options.poll_interval,
65 loss_timeout: options.loss_timeout,
66 last_success: Instant::now(),
67 })
68 }
69}
70
71fn authenticate(session: &Session, user: &str, options: &WatchOptions) -> Result<(), WatchError> {
72 if let Some(key_path) = &options.key_path {
73 session
74 .userauth_pubkey_file(user, None, key_path, options.password.as_deref())
75 .map_err(|e| WatchError::Ssh(format!("key auth failed: {e}")))?;
76 } else if let Some(password) = &options.password {
77 session
78 .userauth_password(user, password)
79 .map_err(|e| WatchError::Ssh(format!("password auth failed: {e}")))?;
80 } else {
81 session
82 .userauth_agent(user)
83 .map_err(|e| WatchError::Ssh(format!("agent auth failed: {e}")))?;
84 }
85 Ok(())
86}
87
88fn try_inotifywait(session: &Session, path: &str) -> Option<WatchMode> {
89 let mut check = session.channel_session().ok()?;
90 check.exec("which inotifywait").ok()?;
91 let mut output = String::new();
92 check.read_to_string(&mut output).ok()?;
93 check.wait_close().ok()?;
94 if check.exit_status().ok()? != 0 {
95 return None;
96 }
97
98 let mut channel = session.channel_session().ok()?;
99 let quoted_path = shlex::try_quote(path).ok()?;
100 let cmd = format!("inotifywait -m -r --format '%w%f %e' {quoted_path}");
101 channel.exec(&cmd).ok()?;
102
103 Some(WatchMode::InotifyPush {
104 channel,
105 buf: Vec::new(),
106 })
107}
108
109impl PathWatcher for SshWatcher {
110 fn poll(&mut self) -> Result<Vec<WatchEvent>, WatchError> {
111 match &mut self.mode {
112 WatchMode::InotifyPush { channel, buf } => {
113 self.session.set_blocking(false);
114 let mut tmp = [0u8; 4096];
115 loop {
116 match channel.read(&mut tmp) {
117 Ok(0) => break,
118 Ok(n) => buf.extend_from_slice(&tmp[..n]),
119 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
120 Err(e) => {
121 self.session.set_blocking(true);
122 return Err(WatchError::Ssh(e.to_string()));
123 }
124 }
125 }
126 self.session.set_blocking(true);
127
128 while let Some(pos) = buf.iter().position(|&b| b == b'\n') {
129 let line = String::from_utf8_lossy(&buf[..pos]).to_string();
130 buf.drain(..=pos);
131 if let Some(event) = parse_inotify_line(&line) {
132 self.pending.push(event);
133 }
134 }
135
136 if !self.pending.is_empty() {
137 self.last_success = Instant::now();
138 }
139 }
140 WatchMode::StatPoll {
141 known_mtimes,
142 last_poll,
143 } => {
144 if last_poll.elapsed() < self.poll_interval {
145 return Ok(Vec::new());
146 }
147 *last_poll = Instant::now();
148
149 let path = self.target.path.clone();
150 let mut channel = self
151 .session
152 .channel_session()
153 .map_err(|e| WatchError::Ssh(e.to_string()))?;
154
155 let quoted_path = shlex::try_quote(&path).map_err(|_| {
156 WatchError::InvalidUrl(format!("path contains invalid characters: {path}"))
157 })?;
158 let cmd = format!("find {quoted_path} -type f -printf '%p %T@\\n'");
159 channel
160 .exec(&cmd)
161 .map_err(|e| WatchError::Ssh(e.to_string()))?;
162
163 let mut output = String::new();
164 channel
165 .read_to_string(&mut output)
166 .map_err(|e| WatchError::Ssh(e.to_string()))?;
167 let _ = channel.wait_close();
168
169 self.last_success = Instant::now();
170
171 let mut current_mtimes: HashMap<String, i64> = HashMap::new();
172 for line in output.lines() {
173 let parts: Vec<&str> = line.rsplitn(2, ' ').collect();
174 if parts.len() != 2 {
175 continue;
176 }
177 let mtime_str = parts[0];
178 let file_path = parts[1];
179 if let Ok(mtime) = mtime_str.parse::<f64>() {
180 current_mtimes.insert(file_path.to_string(), mtime as i64);
181 }
182 }
183
184 for (file_path, mtime) in ¤t_mtimes {
185 let changed = match known_mtimes.get(file_path) {
186 Some(old_mtime) => *mtime != *old_mtime,
187 None => true,
188 };
189 if changed {
190 let kind = if known_mtimes.contains_key(file_path) {
191 WatchEventKind::Modified
192 } else {
193 WatchEventKind::Created
194 };
195 self.pending.push(WatchEvent {
196 path: file_path.clone(),
197 kind,
198 });
199 }
200 }
201
202 *known_mtimes = current_mtimes;
203 }
204 }
205
206 Ok(std::mem::take(&mut self.pending))
207 }
208
209 fn read(&mut self, path: &str) -> Result<Vec<u8>, WatchError> {
210 let sftp = self
211 .session
212 .sftp()
213 .map_err(|e| WatchError::Ssh(e.to_string()))?;
214
215 let mut file = sftp
216 .open(std::path::Path::new(path))
217 .map_err(|e| WatchError::Ssh(e.to_string()))?;
218
219 let mut buf = Vec::new();
220 file.read_to_end(&mut buf)
221 .map_err(|e| WatchError::Ssh(e.to_string()))?;
222
223 self.last_success = Instant::now();
224 Ok(buf)
225 }
226
227 fn has_pending(&self) -> bool {
228 !self.pending.is_empty()
229 }
230
231 fn connection_state(&self) -> ConnectionState {
232 let elapsed = self.last_success.elapsed();
233 if elapsed < self.poll_interval * 2 {
234 ConnectionState::Connected
235 } else if elapsed < self.loss_timeout {
236 ConnectionState::Degraded
237 } else {
238 ConnectionState::Lost
239 }
240 }
241}
242
243fn parse_inotify_line(line: &str) -> Option<WatchEvent> {
244 let parts: Vec<&str> = line.splitn(2, ' ').collect();
245 if parts.len() != 2 {
246 return None;
247 }
248
249 let path = parts[0].to_string();
250 let events_str = parts[1];
251
252 let kind = if events_str.contains("CREATE") {
253 WatchEventKind::Created
254 } else if events_str.contains("MODIFY") || events_str.contains("CLOSE_WRITE") {
255 WatchEventKind::Modified
256 } else {
257 return None;
258 };
259
260 Some(WatchEvent { path, kind })
261}