turbomcp_cli/
executor.rs

1//! Command execution using turbomcp-client
2
3use crate::cli::*;
4use crate::error::{CliError, CliResult};
5use crate::formatter::Formatter;
6use crate::path_security;
7use crate::transport::create_client;
8use std::collections::HashMap;
9
10/// Execute CLI commands
11pub struct CommandExecutor {
12    pub formatter: Formatter,
13    verbose: bool,
14}
15
16impl CommandExecutor {
17    #[must_use]
18    pub fn new(format: OutputFormat, colored: bool, verbose: bool) -> Self {
19        Self {
20            formatter: Formatter::new(format, colored),
21            verbose,
22        }
23    }
24
25    /// Display an error with rich formatting
26    pub fn display_error(&self, error: &CliError) {
27        self.formatter.display_error(error);
28    }
29
30    /// Execute a command
31    pub async fn execute(&self, command: Commands) -> CliResult<()> {
32        match command {
33            Commands::Tools(cmd) => self.execute_tool_command(cmd).await,
34            Commands::Resources(cmd) => self.execute_resource_command(cmd).await,
35            Commands::Prompts(cmd) => self.execute_prompt_command(cmd).await,
36            Commands::Complete(cmd) => self.execute_completion_command(cmd).await,
37            Commands::Server(cmd) => self.execute_server_command(cmd).await,
38            Commands::Sample(cmd) => self.execute_sampling_command(cmd).await,
39            Commands::Connect(conn) => self.execute_connect(conn).await,
40            Commands::Status(conn) => self.execute_status(conn).await,
41        }
42    }
43
44    // Tool commands
45
46    async fn execute_tool_command(&self, command: ToolCommands) -> CliResult<()> {
47        match command {
48            ToolCommands::List { conn } => {
49                let client = create_client(&conn).await?;
50                client.initialize().await?;
51                let tools = client.list_tools().await?;
52                self.formatter.display_tools(&tools)
53            }
54
55            ToolCommands::Call {
56                conn,
57                name,
58                arguments,
59            } => {
60                let args: HashMap<String, serde_json::Value> =
61                    if arguments.trim().is_empty() || arguments == "{}" {
62                        HashMap::new()
63                    } else {
64                        serde_json::from_str(&arguments).map_err(|e| {
65                            CliError::InvalidArguments(format!("Invalid JSON arguments: {}", e))
66                        })?
67                    };
68
69                let client = create_client(&conn).await?;
70                client.initialize().await?;
71                let result = client.call_tool(&name, Some(args)).await?;
72                self.formatter.display(&result)
73            }
74
75            ToolCommands::Schema { conn, name } => {
76                let client = create_client(&conn).await?;
77                client.initialize().await?;
78                let tools = client.list_tools().await?;
79
80                if let Some(tool_name) = name {
81                    let tool = tools.iter().find(|t| t.name == tool_name).ok_or_else(|| {
82                        CliError::Other(format!("Tool '{}' not found", tool_name))
83                    })?;
84
85                    self.formatter.display(&tool.input_schema)
86                } else {
87                    let schemas: Vec<_> = tools
88                        .iter()
89                        .map(|t| {
90                            serde_json::json!({
91                                "name": t.name,
92                                "schema": t.input_schema
93                            })
94                        })
95                        .collect();
96
97                    self.formatter.display(&schemas)
98                }
99            }
100
101            ToolCommands::Export { conn, output } => {
102                let client = create_client(&conn).await?;
103                client.initialize().await?;
104                let tools = client.list_tools().await?;
105
106                // Create output directory (must exist for path validation)
107                std::fs::create_dir_all(&output)?;
108
109                let mut exported_count = 0;
110                let mut skipped_count = 0;
111
112                for tool in tools {
113                    // Sanitize tool name and construct safe output path
114                    // This prevents path traversal attacks from malicious servers
115                    match path_security::safe_output_path(&output, &tool.name, "json") {
116                        Ok(filepath) => {
117                            let schema = serde_json::to_string_pretty(&tool.input_schema)?;
118                            std::fs::write(&filepath, schema)?;
119
120                            if self.verbose {
121                                println!("Exported: {}", filepath.display());
122                            }
123                            exported_count += 1;
124                        }
125                        Err(e) => {
126                            // Log security violation but continue processing other tools
127                            eprintln!("Warning: Skipped tool '{}': {}", tool.name, e);
128                            skipped_count += 1;
129                        }
130                    }
131                }
132
133                if exported_count > 0 {
134                    println!(
135                        "✓ Exported {} schema{} to: {}",
136                        exported_count,
137                        if exported_count == 1 { "" } else { "s" },
138                        output.display()
139                    );
140                }
141
142                if skipped_count > 0 {
143                    println!(
144                        "⚠ Skipped {} tool{} due to invalid names",
145                        skipped_count,
146                        if skipped_count == 1 { "" } else { "s" }
147                    );
148                }
149
150                Ok(())
151            }
152        }
153    }
154
155    // Resource commands
156
157    async fn execute_resource_command(&self, command: ResourceCommands) -> CliResult<()> {
158        match command {
159            ResourceCommands::List { conn } => {
160                let client = create_client(&conn).await?;
161                client.initialize().await?;
162                let resources = client.list_resources().await?;
163                self.formatter.display(&resources)
164            }
165
166            ResourceCommands::Read { conn, uri } => {
167                let client = create_client(&conn).await?;
168                client.initialize().await?;
169                let result = client.read_resource(&uri).await?;
170                self.formatter.display(&result)
171            }
172
173            ResourceCommands::Templates { conn } => {
174                let client = create_client(&conn).await?;
175                client.initialize().await?;
176                let templates = client.list_resource_templates().await?;
177                self.formatter.display(&templates)
178            }
179
180            ResourceCommands::Subscribe { conn, uri } => {
181                let client = create_client(&conn).await?;
182                client.initialize().await?;
183                client.subscribe(&uri).await?;
184                println!("✓ Subscribed to: {uri}");
185                Ok(())
186            }
187
188            ResourceCommands::Unsubscribe { conn, uri } => {
189                let client = create_client(&conn).await?;
190                client.initialize().await?;
191                client.unsubscribe(&uri).await?;
192                println!("✓ Unsubscribed from: {uri}");
193                Ok(())
194            }
195        }
196    }
197
198    // Prompt commands
199
200    async fn execute_prompt_command(&self, command: PromptCommands) -> CliResult<()> {
201        match command {
202            PromptCommands::List { conn } => {
203                let client = create_client(&conn).await?;
204                client.initialize().await?;
205                let prompts = client.list_prompts().await?;
206                self.formatter.display_prompts(&prompts)
207            }
208
209            PromptCommands::Get {
210                conn,
211                name,
212                arguments,
213            } => {
214                // Parse arguments as HashMap<String, Value>
215                let args: HashMap<String, serde_json::Value> =
216                    if arguments.trim().is_empty() || arguments == "{}" {
217                        HashMap::new()
218                    } else {
219                        serde_json::from_str(&arguments).map_err(|e| {
220                            CliError::InvalidArguments(format!("Invalid JSON arguments: {}", e))
221                        })?
222                    };
223
224                let args_option = if args.is_empty() { None } else { Some(args) };
225
226                let client = create_client(&conn).await?;
227                client.initialize().await?;
228                let result = client.get_prompt(&name, args_option).await?;
229                self.formatter.display(&result)
230            }
231
232            PromptCommands::Schema { conn, name } => {
233                let client = create_client(&conn).await?;
234                client.initialize().await?;
235                let prompts = client.list_prompts().await?;
236
237                let prompt = prompts
238                    .iter()
239                    .find(|p| p.name == name)
240                    .ok_or_else(|| CliError::Other(format!("Prompt '{}' not found", name)))?;
241
242                self.formatter.display(&prompt.arguments)
243            }
244        }
245    }
246
247    // Completion commands
248
249    async fn execute_completion_command(&self, command: CompletionCommands) -> CliResult<()> {
250        match command {
251            CompletionCommands::Get {
252                conn,
253                ref_type,
254                ref_value,
255                argument,
256            } => {
257                let client = create_client(&conn).await?;
258                client.initialize().await?;
259
260                // Use the appropriate completion method based on reference type
261                let result = match ref_type {
262                    RefType::Prompt => {
263                        let arg_name = argument.as_deref().unwrap_or("value");
264                        client
265                            .complete_prompt(&ref_value, arg_name, "", None)
266                            .await?
267                    }
268                    RefType::Resource => {
269                        let arg_name = argument.as_deref().unwrap_or("uri");
270                        client
271                            .complete_resource(&ref_value, arg_name, "", None)
272                            .await?
273                    }
274                };
275
276                self.formatter.display(&result)
277            }
278        }
279    }
280
281    // Server commands
282
283    async fn execute_server_command(&self, command: ServerCommands) -> CliResult<()> {
284        match command {
285            ServerCommands::Info { conn } => {
286                let client = create_client(&conn).await?;
287                let result = client.initialize().await?;
288                self.formatter.display_server_info(&result.server_info)
289            }
290
291            ServerCommands::Ping { conn } => {
292                let client = create_client(&conn).await?;
293                let start = std::time::Instant::now();
294
295                client.initialize().await?;
296                client.ping().await?;
297
298                let elapsed = start.elapsed();
299                println!("✓ Pong! ({:.2}ms)", elapsed.as_secs_f64() * 1000.0);
300                Ok(())
301            }
302
303            ServerCommands::LogLevel { conn, level } => {
304                // Convert level once before using
305                let protocol_level: turbomcp_protocol::types::LogLevel = level.clone().into();
306
307                let client = create_client(&conn).await?;
308                client.initialize().await?;
309                client.set_log_level(protocol_level).await?;
310                println!("✓ Log level set to: {:?}", level);
311                Ok(())
312            }
313
314            ServerCommands::Roots { conn } => {
315                // Roots are part of server capabilities returned during initialization
316                let client = create_client(&conn).await?;
317                let result = client.initialize().await?;
318
319                // Display server capabilities which includes roots info
320                self.formatter.display(&result.server_capabilities)
321            }
322        }
323    }
324
325    // Sampling commands
326
327    async fn execute_sampling_command(&self, _command: SamplingCommands) -> CliResult<()> {
328        Err(CliError::NotSupported(
329            "Sampling commands require LLM handler implementation".to_string(),
330        ))
331    }
332
333    // Connection commands
334
335    async fn execute_connect(&self, conn: Connection) -> CliResult<()> {
336        println!("Connecting to server...");
337        let client = create_client(&conn).await?;
338
339        let result = client.initialize().await?;
340
341        println!("✓ Connected successfully!");
342        self.formatter.display_server_info(&result.server_info)
343    }
344
345    async fn execute_status(&self, conn: Connection) -> CliResult<()> {
346        let client = create_client(&conn).await?;
347
348        let result = client.initialize().await?;
349
350        println!("Status: Connected");
351        self.formatter.display_server_info(&result.server_info)
352    }
353}