ssh_mcp/mcp/ssh_commands.rs

1use std::io::{Read, Write};
2use std::net::{SocketAddr, TcpStream};
3use std::path::Path;
4use std::sync::Arc;
5
6use once_cell::sync::Lazy;
7use poem_mcpserver::{Tools, content::Json};
8use serde::{Deserialize, Serialize};
9use ssh2::Session;
10use tokio::sync::Mutex;
11use tracing::{debug, error, info};
12use uuid::Uuid;
13
14// Global storage for active SSH sessions
15static SSH_SESSIONS: Lazy<Mutex<std::collections::HashMap<String, Arc<Mutex<Session>>>>> =
16    Lazy::new(|| Mutex::new(std::collections::HashMap::new()));
17
18#[derive(Debug, Serialize, Deserialize)]
19pub struct SshConnectResponse {
20    session_id: String,
21    message: String,
22    authenticated: bool,
23}
24
25#[derive(Debug, Serialize, Deserialize)]
26pub struct SshCommandResponse {
27    stdout: String,
28    stderr: String,
29    exit_code: i32,
30}
31
32#[derive(Debug, Serialize, Deserialize)]
33pub struct PortForwardingResponse {
34    local_address: String,
35    remote_address: String,
36    active: bool,
37}
38
39#[derive(Debug, Serialize, Deserialize)]
40pub struct ErrorResponse {
41    error: String,
42}
43
44pub struct McpSSHCommands;
45
46#[Tools]
47impl McpSSHCommands {
48    /// Connect to an SSH server and store the session
49    async fn ssh_connect(
50        &self,
51        address: String,
52        username: String,
53        password: Option<String>,
54        key_path: Option<String>,
55    ) -> Json<Result<SshConnectResponse, ErrorResponse>> {
56        info!("Attempting SSH connection to {}@{}", username, address);
57
58        match connect_to_ssh(
59            &address,
60            &username,
61            password.as_deref(),
62            key_path.as_deref(),
63        )
64        .await
65        {
66            Ok(session) => {
67                // Generate a unique session ID
68                let session_id = Uuid::new_v4().to_string();
69
70                // Store the session
71                let mut sessions = SSH_SESSIONS.lock().await;
72                sessions.insert(session_id.clone(), Arc::new(Mutex::new(session)));
73
74                Json(Ok(SshConnectResponse {
75                    session_id,
76                    message: format!("Successfully connected to {}@{}", username, address),
77                    authenticated: true,
78                }))
79            }
80            Err(e) => {
81                error!("SSH connection failed: {}", e);
82                Json(Err(ErrorResponse {
83                    error: e.to_string(),
84                }))
85            }
86        }
87    }
88
89    /// Execute a command on a connected SSH session
90    async fn ssh_execute(
91        &self,
92        session_id: String,
93        command: String,
94    ) -> Json<Result<SshCommandResponse, ErrorResponse>> {
95        info!(
96            "Executing command on SSH session {}: {}",
97            session_id, command
98        );
99
100        let sessions = SSH_SESSIONS.lock().await;
101        let Some(session_arc) = sessions.get(&session_id) else {
102            return Json(Err(ErrorResponse {
103                error: format!("No active SSH session with ID: {}", session_id),
104            }));
105        };
106        let session = session_arc.lock().await;
107        let res = execute_ssh_command(&session, &command).await.map_err(|e| {
108            error!("Command execution failed: {}", e);
109            ErrorResponse {
110                error: e.to_string(),
111            }
112        });
113        Json(res)
114    }
115
116    /// Setup port forwarding on an existing SSH session
117    #[cfg(feature = "port_forward")]
118    async fn ssh_forward(
119        &self,
120        session_id: String,
121        local_port: u16,
122        remote_address: String,
123        remote_port: u16,
124    ) -> Json<Result<PortForwardingResponse, ErrorResponse>> {
125        info!(
126            "Setting up port forwarding from local port {} to {}:{} using session {}",
127            local_port, remote_address, remote_port, session_id
128        );
129
130        let sessions = SSH_SESSIONS.lock().await;
131        let Some(session_arc) = sessions.get(&session_id) else {
132            return Json(Err(ErrorResponse {
133                error: format!("No active SSH session with ID: {}", session_id),
134            }));
135        };
136        let session = session_arc.lock().await;
137        match setup_port_forwarding(&session, local_port, &remote_address, remote_port).await {
138            Ok(local_addr) => Json(Ok(PortForwardingResponse {
139                local_address: local_addr.to_string(),
140                remote_address: format!("{}:{}", remote_address, remote_port),
141                active: true,
142            })),
143            Err(e) => {
144                error!("Port forwarding setup failed: {}", e);
145                Json(Err(ErrorResponse {
146                    error: e.to_string(),
147                }))
148            }
149        }
150    }
151
152    /// Disconnect an SSH session
153    async fn ssh_disconnect(&self, session_id: String) -> Json<Result<String, ErrorResponse>> {
154        info!("Disconnecting SSH session: {}", session_id);
155
156        let mut sessions = SSH_SESSIONS.lock().await;
157        if sessions.remove(&session_id).is_some() {
158            Json(Ok(format!(
159                "Session {} disconnected successfully",
160                session_id
161            )))
162        } else {
163            Json(Err(ErrorResponse {
164                error: format!("No active SSH session with ID: {}", session_id),
165            }))
166        }
167    }
168
169    /// List all active SSH sessions
170    async fn ssh_list_sessions(&self) -> Json<Result<Vec<String>, ErrorResponse>> {
171        let sessions = SSH_SESSIONS.lock().await;
172        let session_ids: Vec<String> = sessions.keys().cloned().collect();
173
174        Json(Ok(session_ids))
175    }
176}
177
178// Implementation functions for SSH operations
179
180async fn connect_to_ssh(
181    address: &str,
182    username: &str,
183    password: Option<&str>,
184    key_path: Option<&str>,
185) -> Result<Session, String> {
186    let tcp = TcpStream::connect(address).map_err(|e| format!("Failed to connect: {}", e))?;
187    let mut sess = Session::new().map_err(|e| format!("Failed to create SSH session: {}", e))?;
188
189    sess.set_tcp_stream(tcp);
190    sess.handshake()
191        .map_err(|e| format!("SSH handshake failed: {}", e))?;
192
193    // Authenticate with either password or key
194    if let Some(password) = password {
195        sess.userauth_password(username, password)
196            .map_err(|e| format!("Password authentication failed: {}", e))?;
197    } else if let Some(key_path) = key_path {
198        sess.userauth_pubkey_file(username, None, Path::new(key_path), None)
199            .map_err(|e| format!("Key authentication failed: {}", e))?;
200    } else {
201        // Try agent authentication
202        sess.userauth_agent(username)
203            .map_err(|e| format!("Agent authentication failed: {}", e))?;
204    }
205
206    if !sess.authenticated() {
207        return Err("Authentication failed".to_string());
208    }
209
210    Ok(sess)
211}
212
213async fn execute_ssh_command(sess: &Session, command: &str) -> Result<SshCommandResponse, String> {
214    let mut channel = sess
215        .channel_session()
216        .map_err(|e| format!("Failed to open channel: {}", e))?;
217
218    channel
219        .exec(command)
220        .map_err(|e| format!("Failed to execute command: {}", e))?;
221
222    let mut stdout = String::new();
223    channel
224        .read_to_string(&mut stdout)
225        .map_err(|e| format!("Failed to read stdout: {}", e))?;
226
227    let mut stderr = String::new();
228    channel
229        .stderr()
230        .read_to_string(&mut stderr)
231        .map_err(|e| format!("Failed to read stderr: {}", e))?;
232
233    let exit_code = channel
234        .exit_status()
235        .map_err(|e| format!("Failed to get exit status: {}", e))?;
236
237    channel
238        .wait_close()
239        .map_err(|e| format!("Failed to close channel: {}", e))?;
240
241    Ok(SshCommandResponse {
242        stdout,
243        stderr,
244        exit_code,
245    })
246}
247
/// Forwards connections accepted on `127.0.0.1:local_port` to
/// `remote_address:remote_port` through `direct-tcpip` channels on `sess`.
///
/// Returns the address the local listener actually bound to, for example the
/// OS-chosen port when `local_port` is 0.
///
/// Connections are accepted on a detached background thread. That thread:
/// - owns its own clone of the session;
/// - stops only on an accept error;
/// - serves each accepted connection on a separate thread.
///
/// NOTE(review): only the client -> remote direction is copied. Nothing reads
/// `remote_channel` and writes back to the local client, so protocols that
/// expect a reply will see no response. A bidirectional pump on top of ssh2's
/// blocking API presumably needs a non-blocking session, which would also
/// affect command execution on the same session. TODO: confirm and redesign.
///
/// # Errors
/// Returns a descriptive `String` if the local port cannot be bound or its
/// address cannot be read.
#[cfg(feature = "port_forward")]
async fn setup_port_forwarding(
    sess: &Session,
    local_port: u16,
    remote_address: &str,
    remote_port: u16,
) -> Result<SocketAddr, String> {
    let listener_addr = format!("127.0.0.1:{}", local_port);
    let listener = std::net::TcpListener::bind(&listener_addr)
        .map_err(|e| format!("Failed to bind to local port {}: {}", local_port, e))?;

    let local_addr = listener
        .local_addr()
        .map_err(|e| format!("Failed to get local address: {}", e))?;

    // Owned values for the background thread, which must be `'static`.
    let forward_sess = sess.clone();
    let target_host = remote_address.to_string();

    std::thread::spawn(move || {
        debug!("Port forwarding active on {}", local_addr);

        for stream in listener.incoming() {
            let local_stream = match stream {
                Ok(local_stream) => local_stream,
                Err(e) => {
                    error!("Error accepting connection: {}", e);
                    break;
                }
            };

            // A client that has already gone away has nothing to forward.
            let Ok(client_addr) = local_stream.peer_addr() else {
                continue;
            };
            debug!("New connection from {} to forwarded port", client_addr);

            match forward_sess.channel_direct_tcpip(&target_host, remote_port, None) {
                Ok(mut remote_channel) => {
                    std::thread::spawn(move || {
                        let mut local_stream = local_stream;
                        let mut buffer = [0u8; 8192];

                        loop {
                            match local_stream.read(&mut buffer) {
                                Ok(0) => break, // client closed its side
                                Ok(n) => {
                                    // `Write::write` may accept only part of the
                                    // buffer; `write_all` keeps going until all
                                    // `n` bytes are sent instead of dropping the rest.
                                    if remote_channel.write_all(&buffer[..n]).is_err() {
                                        break;
                                    }
                                }
                                // A signal interrupted the read; just retry it.
                                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
                                Err(_) => break,
                            }
                        }

                        // Tell the remote end no more data is coming, then close.
                        let _ = remote_channel.send_eof();
                        let _ = remote_channel.close();
                        debug!("Port forwarding connection closed");
                    });
                }
                Err(e) => {
                    error!("Failed to create direct channel: {}", e);
                }
            }
        }
    });

    Ok(local_addr)
}