1use std::io::{Read, Write};
2use std::net::{SocketAddr, TcpStream};
3use std::path::Path;
4use std::sync::Arc;
5
6use once_cell::sync::Lazy;
7use poem_mcpserver::{Tools, content::Json};
8use serde::{Deserialize, Serialize};
9use ssh2::Session;
10use tokio::sync::Mutex;
11use tracing::{debug, error, info};
12use uuid::Uuid;
13
14static SSH_SESSIONS: Lazy<Mutex<std::collections::HashMap<String, Arc<Mutex<Session>>>>> =
16 Lazy::new(|| Mutex::new(std::collections::HashMap::new()));
17
18#[derive(Debug, Serialize, Deserialize)]
19pub struct SshConnectResponse {
20 session_id: String,
21 message: String,
22 authenticated: bool,
23}
24
25#[derive(Debug, Serialize, Deserialize)]
26pub struct SshCommandResponse {
27 stdout: String,
28 stderr: String,
29 exit_code: i32,
30}
31
32#[derive(Debug, Serialize, Deserialize)]
33pub struct PortForwardingResponse {
34 local_address: String,
35 remote_address: String,
36 active: bool,
37}
38
39#[derive(Debug, Serialize, Deserialize)]
40pub struct ErrorResponse {
41 error: String,
42}
43
/// Stateless handle whose `#[Tools]` impl exposes the SSH operations.
///
/// All session state lives in the module-level `SSH_SESSIONS` registry, so this
/// type carries no fields.
pub struct McpSSHCommands;
45
46#[Tools]
47impl McpSSHCommands {
48 async fn ssh_connect(
50 &self,
51 address: String,
52 username: String,
53 password: Option<String>,
54 key_path: Option<String>,
55 ) -> Json<Result<SshConnectResponse, ErrorResponse>> {
56 info!("Attempting SSH connection to {}@{}", username, address);
57
58 match connect_to_ssh(
59 &address,
60 &username,
61 password.as_deref(),
62 key_path.as_deref(),
63 )
64 .await
65 {
66 Ok(session) => {
67 let session_id = Uuid::new_v4().to_string();
69
70 let mut sessions = SSH_SESSIONS.lock().await;
72 sessions.insert(session_id.clone(), Arc::new(Mutex::new(session)));
73
74 Json(Ok(SshConnectResponse {
75 session_id,
76 message: format!("Successfully connected to {}@{}", username, address),
77 authenticated: true,
78 }))
79 }
80 Err(e) => {
81 error!("SSH connection failed: {}", e);
82 Json(Err(ErrorResponse {
83 error: e.to_string(),
84 }))
85 }
86 }
87 }
88
89 async fn ssh_execute(
91 &self,
92 session_id: String,
93 command: String,
94 ) -> Json<Result<SshCommandResponse, ErrorResponse>> {
95 info!(
96 "Executing command on SSH session {}: {}",
97 session_id, command
98 );
99
100 let sessions = SSH_SESSIONS.lock().await;
101 let Some(session_arc) = sessions.get(&session_id) else {
102 return Json(Err(ErrorResponse {
103 error: format!("No active SSH session with ID: {}", session_id),
104 }));
105 };
106 let session = session_arc.lock().await;
107 let res = execute_ssh_command(&session, &command).await.map_err(|e| {
108 error!("Command execution failed: {}", e);
109 ErrorResponse {
110 error: e.to_string(),
111 }
112 });
113 Json(res)
114 }
115
116 #[cfg(feature = "port_forward")]
118 async fn ssh_forward(
119 &self,
120 session_id: String,
121 local_port: u16,
122 remote_address: String,
123 remote_port: u16,
124 ) -> Json<Result<PortForwardingResponse, ErrorResponse>> {
125 info!(
126 "Setting up port forwarding from local port {} to {}:{} using session {}",
127 local_port, remote_address, remote_port, session_id
128 );
129
130 let sessions = SSH_SESSIONS.lock().await;
131 let Some(session_arc) = sessions.get(&session_id) else {
132 return Json(Err(ErrorResponse {
133 error: format!("No active SSH session with ID: {}", session_id),
134 }));
135 };
136 let session = session_arc.lock().await;
137 match setup_port_forwarding(&session, local_port, &remote_address, remote_port).await {
138 Ok(local_addr) => Json(Ok(PortForwardingResponse {
139 local_address: local_addr.to_string(),
140 remote_address: format!("{}:{}", remote_address, remote_port),
141 active: true,
142 })),
143 Err(e) => {
144 error!("Port forwarding setup failed: {}", e);
145 Json(Err(ErrorResponse {
146 error: e.to_string(),
147 }))
148 }
149 }
150 }
151
152 async fn ssh_disconnect(&self, session_id: String) -> Json<Result<String, ErrorResponse>> {
154 info!("Disconnecting SSH session: {}", session_id);
155
156 let mut sessions = SSH_SESSIONS.lock().await;
157 if sessions.remove(&session_id).is_some() {
158 Json(Ok(format!(
159 "Session {} disconnected successfully",
160 session_id
161 )))
162 } else {
163 Json(Err(ErrorResponse {
164 error: format!("No active SSH session with ID: {}", session_id),
165 }))
166 }
167 }
168
169 async fn ssh_list_sessions(&self) -> Json<Result<Vec<String>, ErrorResponse>> {
171 let sessions = SSH_SESSIONS.lock().await;
172 let session_ids: Vec<String> = sessions.keys().cloned().collect();
173
174 Json(Ok(session_ids))
175 }
176}
177
178async fn connect_to_ssh(
181 address: &str,
182 username: &str,
183 password: Option<&str>,
184 key_path: Option<&str>,
185) -> Result<Session, String> {
186 let tcp = TcpStream::connect(address).map_err(|e| format!("Failed to connect: {}", e))?;
187 let mut sess = Session::new().map_err(|e| format!("Failed to create SSH session: {}", e))?;
188
189 sess.set_tcp_stream(tcp);
190 sess.handshake()
191 .map_err(|e| format!("SSH handshake failed: {}", e))?;
192
193 if let Some(password) = password {
195 sess.userauth_password(username, password)
196 .map_err(|e| format!("Password authentication failed: {}", e))?;
197 } else if let Some(key_path) = key_path {
198 sess.userauth_pubkey_file(username, None, Path::new(key_path), None)
199 .map_err(|e| format!("Key authentication failed: {}", e))?;
200 } else {
201 sess.userauth_agent(username)
203 .map_err(|e| format!("Agent authentication failed: {}", e))?;
204 }
205
206 if !sess.authenticated() {
207 return Err("Authentication failed".to_string());
208 }
209
210 Ok(sess)
211}
212
213async fn execute_ssh_command(sess: &Session, command: &str) -> Result<SshCommandResponse, String> {
214 let mut channel = sess
215 .channel_session()
216 .map_err(|e| format!("Failed to open channel: {}", e))?;
217
218 channel
219 .exec(command)
220 .map_err(|e| format!("Failed to execute command: {}", e))?;
221
222 let mut stdout = String::new();
223 channel
224 .read_to_string(&mut stdout)
225 .map_err(|e| format!("Failed to read stdout: {}", e))?;
226
227 let mut stderr = String::new();
228 channel
229 .stderr()
230 .read_to_string(&mut stderr)
231 .map_err(|e| format!("Failed to read stderr: {}", e))?;
232
233 let exit_code = channel
234 .exit_status()
235 .map_err(|e| format!("Failed to get exit status: {}", e))?;
236
237 channel
238 .wait_close()
239 .map_err(|e| format!("Failed to close channel: {}", e))?;
240
241 Ok(SshCommandResponse {
242 stdout,
243 stderr,
244 exit_code,
245 })
246}
247
/// Start forwarding connections made to `127.0.0.1:local_port` through `sess`
/// to `remote_address:remote_port`.
///
/// Binds the local listener synchronously, so bind errors are reported to the
/// caller, then hands the listener to a background thread. That thread opens
/// one direct-tcpip channel per accepted connection and spawns a further
/// thread to pump data for it.
///
/// Returns the address the listener is actually bound to.
///
/// NOTE(review): data is only copied from the local client to the remote
/// channel; nothing is read back from the channel to the client. Making this
/// bidirectional needs a different I/O strategy and is left unchanged here.
///
/// NOTE(review): the accept loop only ends when accepting fails; nothing in
/// this function stops it when the SSH session is closed.
#[cfg(feature = "port_forward")]
async fn setup_port_forwarding(
    sess: &Session,
    local_port: u16,
    remote_address: &str,
    remote_port: u16,
) -> Result<SocketAddr, String> {
    let listener = std::net::TcpListener::bind(("127.0.0.1", local_port))
        .map_err(|e| format!("Failed to bind to local port {}: {}", local_port, e))?;

    let local_addr = listener
        .local_addr()
        .map_err(|e| format!("Failed to get local address: {}", e))?;

    // Owned copies for the accept thread, which outlives this call.
    let session = sess.clone();
    let remote_host = remote_address.to_string();

    std::thread::spawn(move || {
        debug!("Port forwarding active on {}", local_addr);

        for incoming in listener.incoming() {
            let mut local_stream = match incoming {
                Ok(stream) => stream,
                Err(e) => {
                    error!("Error accepting connection: {}", e);
                    break;
                }
            };

            // A client that vanished before we could identify it is skipped.
            let Ok(client_addr) = local_stream.peer_addr() else {
                continue;
            };
            debug!("New connection from {} to forwarded port", client_addr);

            let mut remote_channel =
                match session.channel_direct_tcpip(&remote_host, remote_port, None) {
                    Ok(channel) => channel,
                    Err(e) => {
                        error!("Failed to create direct channel: {}", e);
                        continue;
                    }
                };

            std::thread::spawn(move || {
                let mut buffer = [0u8; 8192];

                loop {
                    match local_stream.read(&mut buffer) {
                        // End of stream or read failure both end the connection.
                        Ok(0) | Err(_) => break,
                        Ok(n) => {
                            // `write_all` retries short writes; a bare `write`
                            // may accept fewer than `n` bytes and silently drop
                            // the rest.
                            if remote_channel.write_all(&buffer[..n]).is_err() {
                                break;
                            }
                        }
                    }
                }

                let _ = remote_channel.close();
                debug!("Port forwarding connection closed");
            });
        }
    });

    Ok(local_addr)
}