Skip to main content

steer_tui/tui/
custom_commands.rs

1use crate::tui::commands::{CoreCommandType, TuiCommandType};
2use directories::ProjectDirs;
3use serde::{Deserialize, Serialize};
4use std::path::PathBuf;
5use strum::IntoEnumIterator;
6use thiserror::Error;
7use tracing::{debug, warn};
8
9/// Errors that can occur with custom commands
10#[derive(Debug, Error)]
11pub enum CustomCommandError {
12    #[error("Failed to load custom commands config: {0}")]
13    ConfigLoadError(String),
14    #[error("Failed to parse custom commands config: {0}")]
15    ParseError(#[from] toml::de::Error),
16    #[error("IO error: {0}")]
17    IoError(#[from] std::io::Error),
18    #[error("Invalid command name '{0}': {1}")]
19    InvalidCommandName(String, String),
20    #[error("Command name '{0}' conflicts with built-in command")]
21    ConflictingCommandName(String),
22}
23
24/// Represents a custom command that can be dynamically loaded
25#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
26#[serde(tag = "type", rename_all = "snake_case")]
27pub enum CustomCommand {
28    /// A simple prompt command that sends predefined text
29    Prompt {
30        name: String,
31        description: String,
32        prompt: String,
33    },
34    // Future command types can be added here:
35    // Shell { name: String, description: String, command: String },
36    // Macro { name: String, description: String, steps: Vec<String> },
37}
38
39impl CustomCommand {
40    /// Get the name of the command
41    pub fn name(&self) -> &str {
42        match self {
43            CustomCommand::Prompt { name, .. } => name,
44        }
45    }
46
47    /// Get the description of the command
48    pub fn description(&self) -> &str {
49        match self {
50            CustomCommand::Prompt { description, .. } => description,
51        }
52    }
53
54    /// Validate the command configuration
55    pub fn validate(&self) -> Result<(), CustomCommandError> {
56        let name = self.name();
57
58        // Check for empty name
59        if name.is_empty() {
60            return Err(CustomCommandError::InvalidCommandName(
61                name.to_string(),
62                "Command name cannot be empty".to_string(),
63            ));
64        }
65
66        // Check for invalid characters
67        if name.contains('/') || name.contains(' ') {
68            return Err(CustomCommandError::InvalidCommandName(
69                name.to_string(),
70                "Command name cannot contain '/' or spaces".to_string(),
71            ));
72        }
73
74        // Check for conflicts with built-in commands
75        for cmd in TuiCommandType::iter() {
76            if cmd.command_name() == name {
77                return Err(CustomCommandError::ConflictingCommandName(name.to_string()));
78            }
79        }
80
81        for cmd in CoreCommandType::iter() {
82            if cmd.command_name() == name {
83                return Err(CustomCommandError::ConflictingCommandName(name.to_string()));
84            }
85        }
86
87        Ok(())
88    }
89}
90
91/// Configuration file structure for custom commands
92#[derive(Debug, Serialize, Deserialize, Default)]
93pub struct CustomCommandsConfig {
94    #[serde(default)]
95    pub commands: Vec<CustomCommand>,
96}
97
98/// Get all paths where custom commands can be defined, in order of precedence
99pub fn get_config_paths() -> Vec<PathBuf> {
100    let mut paths = Vec::new();
101
102    // 1. Project-specific config (highest precedence)
103    paths.push(PathBuf::from(".steer").join("commands.toml"));
104
105    // 2. User config directory (platform-specific)
106    if let Some(proj_dirs) = ProjectDirs::from("", "", "steer") {
107        paths.push(proj_dirs.config_dir().join("commands.toml"));
108    }
109
110    paths
111}
112
113/// Get the path to the custom commands configuration file
114pub fn get_config_path() -> PathBuf {
115    // For backwards compatibility, return the first writable path
116    get_config_paths()
117        .into_iter()
118        .next()
119        .unwrap_or_else(|| PathBuf::from(".steer").join("commands.toml"))
120}
121
122/// Load custom commands from the configuration file
123pub fn load_custom_commands() -> Result<Vec<CustomCommand>, CustomCommandError> {
124    let mut all_commands = Vec::new();
125    let mut seen_names = std::collections::HashSet::new();
126
127    // Load from all paths in order of precedence
128    for config_path in get_config_paths() {
129        debug!("Checking for custom commands at: {}", config_path.display());
130
131        if !config_path.exists() {
132            debug!("Config not found at {}", config_path.display());
133            continue;
134        }
135
136        let config_content = match std::fs::read_to_string(&config_path) {
137            Ok(content) => content,
138            Err(e) => {
139                warn!("Failed to read {}: {}", config_path.display(), e);
140                continue;
141            }
142        };
143
144        let config: CustomCommandsConfig = match toml::from_str(&config_content) {
145            Ok(config) => config,
146            Err(e) => {
147                warn!(
148                    "Failed to parse {}: {}. Skipping this config file.",
149                    config_path.display(),
150                    e
151                );
152                continue;
153            }
154        };
155
156        debug!(
157            "Found {} commands in {}",
158            config.commands.len(),
159            config_path.display()
160        );
161
162        // Check for duplicates within the same file
163        let mut file_names = std::collections::HashSet::new();
164        for cmd in &config.commands {
165            if !file_names.insert(cmd.name().to_string()) {
166                warn!(
167                    "Duplicate command '{}' found within {}. Skipping duplicates.",
168                    cmd.name(),
169                    config_path.display()
170                );
171            }
172        }
173
174        // Add commands, skipping duplicates (earlier paths take precedence)
175        for cmd in config.commands {
176            // Validate the command
177            match cmd.validate() {
178                Ok(()) => {
179                    if seen_names.insert(cmd.name().to_string()) {
180                        all_commands.push(cmd);
181                    } else {
182                        debug!(
183                            "Skipping duplicate command '{}' from {}",
184                            cmd.name(),
185                            config_path.display()
186                        );
187                    }
188                }
189                Err(e) => {
190                    warn!(
191                        "Skipping invalid command from {}: {}",
192                        config_path.display(),
193                        e
194                    );
195                }
196            }
197        }
198    }
199
200    debug!("Total custom commands loaded: {}", all_commands.len());
201    Ok(all_commands)
202}
203
204/// Save custom commands to the configuration file
205pub fn save_custom_commands(commands: &[CustomCommand]) -> Result<(), CustomCommandError> {
206    let config_path = get_config_path();
207
208    // Ensure parent directory exists
209    if let Some(parent) = config_path.parent() {
210        std::fs::create_dir_all(parent)?;
211    }
212
213    let config = CustomCommandsConfig {
214        commands: commands.to_vec(),
215    };
216
217    let config_content = toml::to_string_pretty(&config).map_err(|e| {
218        CustomCommandError::ConfigLoadError(format!("Failed to serialize config: {e}"))
219    })?;
220
221    std::fs::write(&config_path, config_content)?;
222
223    Ok(())
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes a TOML fixture, failing the test on malformed input.
    fn parse_config(source: &str) -> CustomCommandsConfig {
        toml::from_str(source).expect("fixture should be a valid commands config")
    }

    #[test]
    fn test_parse_prompt_command() {
        let config = parse_config(
            r#"
[[commands]]
type = "prompt"
name = "standup"
description = "Generate a standup report"
prompt = "What did I work on today? Check git log and recent file changes."
"#,
        );

        assert_eq!(config.commands.len(), 1);

        // Irrefutable while `Prompt` is the only variant; adding a variant
        // forces this test to be revisited.
        let CustomCommand::Prompt {
            name,
            description,
            prompt,
        } = &config.commands[0];
        assert_eq!(name, "standup");
        assert_eq!(description, "Generate a standup report");
        assert_eq!(
            prompt,
            "What did I work on today? Check git log and recent file changes."
        );
    }

    #[test]
    fn test_multiple_commands() {
        let config = parse_config(
            r#"
[[commands]]
type = "prompt"
name = "test"
description = "Run tests"
prompt = "Run the test suite and show me any failures"

[[commands]]
type = "prompt"
name = "review"
description = "Code review helper"
prompt = "Review the recent changes and suggest improvements"
"#,
        );

        // Checks both the count and the declaration order.
        let names: Vec<&str> = config.commands.iter().map(CustomCommand::name).collect();
        assert_eq!(names, ["test", "review"]);
    }
}