1use crate::utils::{DirectoryEntry, FileSystemProvider};
2use anyhow::{Result, anyhow};
3use async_trait::async_trait;
4use russh::client::{self, Handler};
5use russh_sftp::client::SftpSession;
6use serde::{Deserialize, Serialize};
7use std::{
8 collections::HashMap,
9 fmt::{self, Display},
10 fs,
11 path::PathBuf,
12 sync::Arc,
13 time::Duration,
14};
15use tokio::io::AsyncWriteExt;
16use tokio::sync::RwLock;
17use tracing::debug;
18use uuid;
19
/// Components of a connection string, as produced by
/// `RemoteConnectionInfo::parse_connection_string`.
#[derive(Debug)]
struct ParsedConnection {
    /// Login name used for SSH authentication.
    username: String,
    /// Host name or address passed to `client::connect`.
    hostname: String,
    /// TCP port; 22 when the connection string does not specify one.
    port: u16,
}
26
/// Settings that control how `RemoteConnection::execute_command_unified`
/// runs a remote command.
///
/// `Default` yields: no timeout, no progress notifications, and the
/// PID-capturing wrapper enabled (`simple == false`).
#[derive(Debug, Clone, Copy, Default)]
pub struct CommandOptions {
    /// Maximum wall-clock time before the command is cancelled; `None` waits indefinitely.
    pub timeout: Option<Duration>,
    /// Emit MCP progress notifications for non-blank output chunks (only when a
    /// request context is supplied).
    pub with_progress: bool,
    /// Run the command verbatim instead of wrapping it so its remote PID can be
    /// captured and killed on cancellation.
    pub simple: bool,
}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct RemoteConnectionInfo {
35 pub connection_string: String, pub password: Option<String>,
37 pub private_key_path: Option<String>,
38}
39
40impl RemoteConnectionInfo {
41 fn parse_connection_string(&self) -> Result<ParsedConnection> {
42 let (username, host_port) = self.connection_string.split_once('@').ok_or_else(|| {
43 anyhow!("Invalid connection string format. Expected: user@host or user@host:port")
44 })?;
45
46 let (hostname, port) = if let Some((host, port_str)) = host_port.split_once(':') {
47 let port = port_str
48 .parse::<u16>()
49 .map_err(|_| anyhow!("Invalid port number: {}", port_str))?;
50 (host, port)
51 } else {
52 (host_port, 22)
53 };
54
55 Ok(ParsedConnection {
56 username: username.to_string(),
57 hostname: hostname.to_string(),
58 port,
59 })
60 }
61}
62
63pub struct SSHClient;
64
65impl Handler for SSHClient {
66 type Error = russh::Error;
67
68 async fn check_server_key(
69 &mut self,
70 _server_public_key: &russh::keys::PublicKey,
71 ) -> Result<bool, Self::Error> {
72 Ok(true)
75 }
76}
77
/// An authenticated SFTP session plus the parameters used to open it.
///
/// The parameters are retained because command execution opens a fresh,
/// separately authenticated SSH session for every command.
pub struct RemoteConnection {
    /// SFTP client used for all file operations.
    sftp: SftpSession,
    /// Parameters this connection was created from.
    connection_info: RemoteConnectionInfo,
}
82
83impl RemoteConnection {
84 fn map_ssh_error(error: russh::Error, context: &str) -> anyhow::Error {
85 anyhow!("SSH {}: {}", context, error)
86 }
87
88 fn map_auth_error(result: russh::client::AuthResult, method: &str) -> Result<()> {
89 match result {
90 russh::client::AuthResult::Success => Ok(()),
91 _ => Err(anyhow!("{} authentication failed", method)),
92 }
93 }
94
95 async fn create_authenticated_session_static(
96 connection_info: &RemoteConnectionInfo,
97 ) -> Result<client::Handle<SSHClient>> {
98 let parsed = connection_info.parse_connection_string()?;
99
100 debug!(
101 "Connecting to {}@{}:{}",
102 parsed.username, parsed.hostname, parsed.port
103 );
104
105 let config = client::Config::default();
106 let mut session = client::connect(
107 config.into(),
108 (parsed.hostname.as_str(), parsed.port),
109 SSHClient {},
110 )
111 .await
112 .map_err(|e| Self::map_ssh_error(e, "connection failed"))?;
113
114 Self::authenticate_session_static(&mut session, &parsed.username, connection_info).await?;
115 Ok(session)
116 }
117
118 async fn authenticate_session_static(
119 session: &mut client::Handle<SSHClient>,
120 username: &str,
121 connection_info: &RemoteConnectionInfo,
122 ) -> Result<()> {
123 if let Some(password) = &connection_info.password {
124 debug!("Authenticating with password");
125 let auth_result = session
126 .authenticate_password(username, password)
127 .await
128 .map_err(|e| Self::map_ssh_error(e, "password authentication"))?;
129 Self::map_auth_error(auth_result, "Password")?;
130 } else {
131 debug!("Authenticating with public key");
132 let private_key_path = if let Some(path) = &connection_info.private_key_path {
133 Self::canonicalize_key_path(path)?
134 } else {
135 Self::get_default_key_files()?.0
136 };
137
138 let keypair = russh::keys::load_secret_key(&private_key_path, None).map_err(|e| {
139 anyhow!(
140 "Failed to load private key from {}: {}",
141 private_key_path.display(),
142 e
143 )
144 })?;
145
146 let auth_result = session
147 .authenticate_publickey(
148 username,
149 russh::keys::PrivateKeyWithHashAlg::new(
150 Arc::new(keypair),
151 Some(russh::keys::HashAlg::Sha256),
152 ),
153 )
154 .await
155 .map_err(|e| Self::map_ssh_error(e, "public key authentication"))?;
156 Self::map_auth_error(auth_result, "Public key")?;
157 }
158 Ok(())
159 }
160
161 pub fn get_default_key_files() -> Result<(PathBuf, PathBuf)> {
162 let home_dir = dirs::home_dir().ok_or_else(|| anyhow!("Home directory not found"))?;
163 let ssh_dir = home_dir.join(".ssh");
164
165 if !ssh_dir.is_dir() {
166 return Err(anyhow!("SSH directory not found: {}", ssh_dir.display()));
167 }
168
169 let key_names = ["id_ed25519", "id_rsa", "id_ecdsa", "id_dsa"];
171
172 for key_name in &key_names {
173 let private_key = ssh_dir.join(key_name);
174 let public_key = ssh_dir.join(format!("{}.pub", key_name));
175
176 if private_key.is_file() {
177 return Ok((private_key, public_key));
178 }
179 }
180
181 Err(anyhow!("No SSH private key found in {}", ssh_dir.display()))
182 }
183
184 pub fn canonicalize_key_path(path: &str) -> Result<PathBuf> {
186 let path_buf = PathBuf::from(path);
187
188 if path_buf.is_absolute() {
190 return fs::canonicalize(&path_buf)
191 .map_err(|e| anyhow!("Failed to access private key at {}: {}", path, e));
192 }
193
194 if let Ok(canonical) = fs::canonicalize(&path_buf) {
196 return Ok(canonical);
197 }
198
199 let home_dir = dirs::home_dir()
201 .ok_or_else(|| anyhow!("Home directory not found for relative key path"))?;
202 let ssh_relative_path = home_dir.join(".ssh").join(&path_buf);
203
204 if ssh_relative_path.exists() {
205 return fs::canonicalize(ssh_relative_path)
206 .map_err(|e| anyhow!("Failed to access private key at ~/.ssh/{}: {}", path, e));
207 }
208
209 if let Some(stripped) = path.strip_prefix("~/") {
211 let expanded_path = home_dir.join(stripped);
212 return fs::canonicalize(expanded_path)
213 .map_err(|e| anyhow!("Failed to access private key at {}: {}", path, e));
214 }
215
216 Err(anyhow!(
217 "Private key not found at {} (tried current directory and ~/.ssh/)",
218 path
219 ))
220 }
221
222 pub async fn new(connection_info: RemoteConnectionInfo) -> Result<Self> {
223 let session = Self::create_authenticated_session_static(&connection_info).await?;
224
225 let channel = session
227 .channel_open_session()
228 .await
229 .map_err(|e| Self::map_ssh_error(e, "failed to open SSH channel"))?;
230
231 channel
232 .request_subsystem(true, "sftp")
233 .await
234 .map_err(|e| Self::map_ssh_error(e, "failed to request SFTP subsystem"))?;
235
236 let sftp = SftpSession::new(channel.into_stream())
237 .await
238 .map_err(|e| anyhow!("Failed to create SFTP session: {}", e))?;
239
240 debug!("SFTP connection established successfully");
241
242 Ok(Self {
243 sftp,
244 connection_info,
245 })
246 }
247
248 pub async fn separator(&self) -> Result<char> {
249 let canonical_path = self.sftp.canonicalize("/").await?;
251 Ok(if canonical_path.contains('\\') {
252 '\\'
253 } else {
254 '/'
255 })
256 }
257
258 pub async fn canonicalize(&self, path: &str) -> Result<String> {
259 self.sftp
260 .canonicalize(path)
261 .await
262 .map_err(|e| anyhow!("Failed to canonicalize path {}: {}", path, e))
263 }
264
265 pub fn get_ssh_prefix(&self) -> Result<String> {
268 let parsed = self.connection_info.parse_connection_string()?;
269 if parsed.port == 22 {
270 Ok(format!("{}@{}:", parsed.username, parsed.hostname))
271 } else {
272 Ok(format!(
273 "{}@{}#{}:",
274 parsed.username, parsed.hostname, parsed.port
275 ))
276 }
277 }
278
279 pub async fn read_file(&self, path: &str) -> Result<Vec<u8>> {
280 self.sftp
281 .read(path)
282 .await
283 .map_err(|e| anyhow!("Failed to read file {}: {}", path, e))
284 }
285
286 pub async fn read_file_to_string(&self, path: &str) -> Result<String> {
287 let content = self.read_file(path).await?;
288 String::from_utf8(content)
289 .map_err(|e| anyhow!("File {} contains invalid UTF-8: {}", path, e))
290 }
291
292 pub async fn write_file(&self, path: &str, data: &[u8]) -> Result<()> {
293 self.sftp
294 .write(path, data)
295 .await
296 .map_err(|e| anyhow!("Failed to write file {}: {}", path, e))
297 }
298
299 pub async fn create_file(&self, path: &str, data: &[u8]) -> Result<()> {
300 let mut file_handle = self
302 .sftp
303 .create(path)
304 .await
305 .map_err(|e| anyhow!("Failed to create file {}: {}", path, e))?;
306
307 file_handle
309 .write_all(data)
310 .await
311 .map_err(|e| anyhow!("Failed to write data to file {}: {}", path, e))?;
312
313 Ok(())
315 }
316
317 pub async fn create_directories(&self, path: &str) -> Result<()> {
318 let path_buf = PathBuf::from(path);
319 let mut current_path = PathBuf::new();
320
321 for component in path_buf.components() {
322 current_path.push(component);
323 let path_str = current_path.to_string_lossy().to_string();
324
325 if self.sftp.read_dir(&path_str).await.is_err() {
326 self.sftp
327 .create_dir(&path_str)
328 .await
329 .map_err(|e| anyhow!("Failed to create directory {}: {}", path_str, e))?;
330 }
331 }
332
333 Ok(())
334 }
335
336 pub async fn list_directory(&self, path: &str) -> Result<Vec<String>> {
337 let entries = self
338 .sftp
339 .read_dir(path)
340 .await
341 .map_err(|e| anyhow!("Failed to read directory {}: {}", path, e))?;
342
343 let separator = self.separator().await?;
344 let mut result = Vec::new();
345
346 for entry in entries {
347 let entry_path = if path.ends_with(separator) {
348 format!("{}{}", path, entry.file_name())
349 } else {
350 format!("{}{}{}", path, separator, entry.file_name())
351 };
352 result.push(entry_path);
353 }
354
355 result.sort();
356 Ok(result)
357 }
358
359 pub async fn list_directory_with_types(&self, path: &str) -> Result<Vec<(String, bool)>> {
361 let entries = self
362 .sftp
363 .read_dir(path)
364 .await
365 .map_err(|e| anyhow!("Failed to read directory {}: {}", path, e))?;
366
367 let separator = self.separator().await?;
368 let mut result = Vec::new();
369
370 for entry in entries {
371 let entry_path = if path.ends_with(separator) {
372 format!("{}{}", path, entry.file_name())
373 } else {
374 format!("{}{}{}", path, separator, entry.file_name())
375 };
376 let is_directory = entry.metadata().is_dir();
377 result.push((entry_path, is_directory));
378 }
379
380 result.sort_by(|a, b| a.0.cmp(&b.0));
381 Ok(result)
382 }
383
384 pub async fn is_file(&self, path: &str) -> bool {
385 self.sftp
386 .metadata(path)
387 .await
388 .map(|metadata| !metadata.is_dir())
389 .unwrap_or(false)
390 }
391
392 pub async fn is_directory(&self, path: &str) -> bool {
393 self.sftp
394 .metadata(path)
395 .await
396 .map(|metadata| metadata.is_dir())
397 .unwrap_or(false)
398 }
399
400 pub async fn exists(&self, path: &str) -> bool {
401 self.sftp.metadata(path).await.is_ok()
402 }
403
404 pub async fn file_size(&self, path: &str) -> Result<u64> {
405 let metadata = self
406 .sftp
407 .metadata(path)
408 .await
409 .map_err(|e| anyhow!("Failed to get metadata for {}: {}", path, e))?;
410
411 Ok(metadata.len())
412 }
413
414 pub async fn rename(&self, old_path: &str, new_path: &str) -> Result<()> {
415 self.sftp
416 .rename(old_path, new_path)
417 .await
418 .map_err(|e| anyhow!("Failed to rename '{}' to '{}': {}", old_path, new_path, e))
419 }
420
421 pub async fn execute_command_unified(
422 &self,
423 command: &str,
424 options: CommandOptions,
425 cancel_rx: &mut tokio::sync::oneshot::Receiver<()>,
426 progress_callback: Option<impl Fn(String) + Send + Sync + 'static>,
427 ctx: Option<&rmcp::service::RequestContext<rmcp::RoleServer>>,
428 ) -> Result<(String, i32)> {
429 use regex::Regex;
430
431 let session = Self::create_authenticated_session_static(&self.connection_info).await?;
432
433 let mut channel = session
435 .channel_open_session()
436 .await
437 .map_err(|e| Self::map_ssh_error(e, "failed to open channel"))?;
438
439 let wrapped_command = if options.simple {
441 command.to_string()
442 } else {
443 format!(
444 "bash -c 'echo \"PID:$$\"; exec bash -c \"{}\"'",
445 command.replace('\\', "\\\\").replace('"', "\\\"")
446 )
447 };
448
449 channel
450 .exec(true, wrapped_command.as_str())
451 .await
452 .map_err(|e| Self::map_ssh_error(e, "failed to execute command"))?;
453
454 let mut output = String::new();
455 let mut exit_code = 0i32;
456 let mut remote_pid: Option<String> = None;
457 let progress_id = uuid::Uuid::new_v4();
458
459 let pid_regex = if !options.simple {
461 Some(Regex::new(r"PID:(\d+)").expect("Invalid PID regex"))
462 } else {
463 None
464 };
465
466 let command_execution = async {
468 while let Some(msg) = channel.wait().await {
469 match msg {
470 russh::ChannelMsg::Data { data } => {
471 let text = String::from_utf8_lossy(&data).to_string();
472
473 if let Some(ref regex) = pid_regex
475 && remote_pid.is_none()
476 && let Some(captures) = regex.captures(&text)
477 && let Some(pid_match) = captures.get(1)
478 {
479 remote_pid = Some(pid_match.as_str().to_string());
480 let cleaned_text = regex.replace_all(&text, "").to_string();
482 if !cleaned_text.trim().is_empty() {
483 output.push_str(&cleaned_text);
484 if let Some(ref callback) = progress_callback {
485 callback(cleaned_text);
486 }
487 }
488 continue;
489 }
490
491 output.push_str(&text);
493 if let Some(ref callback) = progress_callback {
494 callback(text.clone());
495 }
496
497 if let Some(ctx) = &ctx
499 && options.with_progress
500 && !text.trim().is_empty()
501 {
502 let _ = ctx.peer.notify_progress(rmcp::model::ProgressNotificationParam {
503 progress_token: rmcp::model::ProgressToken(rmcp::model::NumberOrString::Number(0)),
504 progress: 50.0,
505 total: Some(100.0),
506 message: Some(serde_json::to_string(&crate::models::integrations::openai::ToolCallResultProgress {
507 id: progress_id,
508 message: text,
509 }).unwrap_or_default()),
510 }).await;
511 }
512 }
513 russh::ChannelMsg::ExtendedData { data, ext: _ } => {
514 let text = String::from_utf8_lossy(&data).to_string();
515 output.push_str(&text);
516 if let Some(ref callback) = progress_callback {
517 callback(text.clone());
518 }
519
520 if let Some(ctx) = &ctx
522 && options.with_progress
523 && !text.trim().is_empty()
524 {
525 let _ = ctx.peer.notify_progress(rmcp::model::ProgressNotificationParam {
526 progress_token: rmcp::model::ProgressToken(rmcp::model::NumberOrString::Number(0)),
527 progress: 50.0,
528 total: Some(100.0),
529 message: Some(serde_json::to_string(&crate::models::integrations::openai::ToolCallResultProgress {
530 id: progress_id,
531 message: text,
532 }).unwrap_or_default()),
533 }).await;
534 }
535 }
536 russh::ChannelMsg::ExitStatus { exit_status } => {
537 exit_code = exit_status as i32;
538 }
539 russh::ChannelMsg::Eof => {
540 break;
541 }
542 _ => {}
543 }
544 }
545 };
546
547 macro_rules! handle_cancellation {
549 ($error_msg:expr) => {{
550 if let Some(pid) = &remote_pid {
552 let kill_cmd = format!("kill -9 {}", pid);
553 if let Ok(kill_channel) = session.channel_open_session().await {
554 let _ = kill_channel.exec(true, kill_cmd.as_str()).await;
555 let _ = kill_channel.close().await;
556 }
557 }
558 let _ = channel.close().await;
559 Err(anyhow!($error_msg))
560 }};
561 }
562
563 tokio::select! {
565 _ = command_execution => Ok((output, exit_code)),
567
568 _ = async {
570 if let Some(timeout_duration) = options.timeout {
571 tokio::time::sleep(timeout_duration).await;
572 } else {
573 std::future::pending::<()>().await;
575 }
576 } => {
577 handle_cancellation!(format!("Command timed out after {:?}", options.timeout))
578 },
579
580 _ = async {
582 if let Some(ctx) = &ctx {
583 ctx.ct.cancelled().await;
584 } else {
585 std::future::pending::<()>().await;
587 }
588 } => {
589 handle_cancellation!("Command was cancelled")
590 },
591
592 _ = cancel_rx => {
594 handle_cancellation!("Command was cancelled")
595 }
596 }
597 }
598
599 pub async fn execute_command(
600 &self,
601 command: &str,
602 timeout: Option<Duration>,
603 ctx: Option<&rmcp::service::RequestContext<rmcp::RoleServer>>,
604 ) -> Result<(String, i32)> {
605 let options = CommandOptions {
606 timeout,
607 with_progress: true,
608 simple: false,
609 };
610
611 let (_cancel_tx, mut cancel_rx) = tokio::sync::oneshot::channel();
612
613 self.execute_command_unified(command, options, &mut cancel_rx, None::<fn(String)>, ctx)
614 .await
615 }
616
617 pub async fn execute_command_with_streaming<F>(
618 &self,
619 command: &str,
620 timeout: Option<Duration>,
621 cancel_rx: &mut tokio::sync::oneshot::Receiver<()>,
622 progress_callback: F,
623 ) -> Result<(String, i32)>
624 where
625 F: Fn(String) + Send + Sync + 'static,
626 {
627 let options = CommandOptions {
628 timeout,
629 with_progress: false,
630 simple: false,
631 };
632
633 self.execute_command_unified(command, options, cancel_rx, Some(progress_callback), None)
634 .await
635 }
636
637 pub fn connection_string(&self) -> &str {
638 &self.connection_info.connection_string
639 }
640}
641
642pub struct RemoteFileSystemProvider {
644 connection: Arc<RemoteConnection>,
645}
646
647impl RemoteFileSystemProvider {
648 pub fn new(connection: Arc<RemoteConnection>) -> Self {
649 Self { connection }
650 }
651}
652
653#[async_trait]
654impl FileSystemProvider for RemoteFileSystemProvider {
655 type Error = String;
656
657 async fn list_directory(&self, path: &str) -> Result<Vec<DirectoryEntry>, Self::Error> {
658 let timeout_duration = std::time::Duration::from_secs(10);
660
661 let entries = tokio::time::timeout(
662 timeout_duration,
663 self.connection.list_directory_with_types(path),
664 )
665 .await
666 .map_err(|_| format!("Timeout listing remote directory: {}", path))?
667 .map_err(|e| format!("Failed to list remote directory: {}", e))?;
668
669 let mut result = Vec::new();
670 for (entry_path, is_directory) in entries {
671 let name = entry_path
672 .split('/')
673 .next_back()
674 .unwrap_or(&entry_path)
675 .to_string();
676
677 result.push(DirectoryEntry {
678 name,
679 path: entry_path,
680 is_directory,
681 });
682 }
683
684 Ok(result)
685 }
686}
687
688impl Display for RemoteConnection {
689 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
690 write!(f, "SSH:{}", self.connection_info.connection_string)
691 }
692}
693
694pub struct RemoteConnectionManager {
696 connections: RwLock<HashMap<String, Arc<RemoteConnection>>>,
697}
698
699impl RemoteConnectionManager {
700 pub fn new() -> Self {
701 Self {
702 connections: RwLock::new(HashMap::new()),
703 }
704 }
705
706 pub async fn get_connection(
707 &self,
708 connection_info: &RemoteConnectionInfo,
709 ) -> Result<Arc<RemoteConnection>> {
710 let key = connection_info.connection_string.clone();
711
712 {
714 let connections = self.connections.read().await;
715 if let Some(conn) = connections.get(&key) {
716 return Ok(conn.clone());
717 }
718 }
719
720 let connection = RemoteConnection::new(connection_info.clone()).await?;
722 let arc_connection = Arc::new(connection);
723
724 {
726 let mut connections = self.connections.write().await;
727 connections.insert(key, arc_connection.clone());
728 }
729
730 Ok(arc_connection)
731 }
732
733 pub async fn remove_connection(&self, connection_string: &str) {
734 let mut connections = self.connections.write().await;
735 connections.remove(connection_string);
736 }
737
738 pub async fn list_connections(&self) -> Vec<String> {
739 let connections = self.connections.read().await;
740 connections.keys().cloned().collect()
741 }
742}
743
744impl Default for RemoteConnectionManager {
745 fn default() -> Self {
746 Self::new()
747 }
748}
749
/// A path that lives either on the local filesystem or on a remote SSH host.
#[derive(Debug, Clone)]
pub enum PathLocation {
    /// A local filesystem path, stored verbatim.
    Local(String),
    /// A path on a remote host.
    Remote {
        /// Connection parameters for the host; `PathLocation::parse` attaches
        /// no password or key path.
        connection: RemoteConnectionInfo,
        /// Path on the remote host; `PathLocation::parse` always produces one
        /// starting with `/`.
        path: String,
    },
}
758
759impl PathLocation {
760 pub fn parse(path_str: &str) -> Result<Self> {
763 if let Some(without_scheme) = path_str.strip_prefix("ssh://") {
764 if let Some((connection_part, path_part)) = without_scheme.split_once('/') {
767 let connection_info = RemoteConnectionInfo {
768 connection_string: connection_part.to_string(),
769 password: None,
770 private_key_path: None,
771 };
772
773 return Ok(PathLocation::Remote {
774 connection: connection_info,
775 path: format!("/{}", path_part),
776 });
777 }
778 } else if path_str.contains('@') && path_str.contains(':') {
779 if let Some((connection_part, path_part)) = path_str.split_once(':')
781 && path_part.starts_with('/')
782 {
783 let connection_info = RemoteConnectionInfo {
784 connection_string: connection_part.to_string(),
785 password: None,
786 private_key_path: None,
787 };
788
789 return Ok(PathLocation::Remote {
790 connection: connection_info,
791 path: path_part.to_string(),
792 });
793 }
794 }
795
796 Ok(PathLocation::Local(path_str.to_string()))
798 }
799
800 pub fn is_remote(&self) -> bool {
801 matches!(self, PathLocation::Remote { .. })
802 }
803
804 pub fn as_local_path(&self) -> Option<&str> {
805 match self {
806 PathLocation::Local(path) => Some(path),
807 PathLocation::Remote { .. } => None,
808 }
809 }
810
811 pub fn as_remote_info(&self) -> Option<(&RemoteConnectionInfo, &str)> {
812 match self {
813 PathLocation::Local(_) => None,
814 PathLocation::Remote { connection, path } => Some((connection, path)),
815 }
816 }
817}