steer_tui/tui/
custom_commands.rs1use crate::tui::commands::{CoreCommandType, TuiCommandType};
2use directories::ProjectDirs;
3use serde::{Deserialize, Serialize};
4use std::path::PathBuf;
5use strum::IntoEnumIterator;
6use thiserror::Error;
7use tracing::{debug, warn};
8
9#[derive(Debug, Error)]
11pub enum CustomCommandError {
12 #[error("Failed to load custom commands config: {0}")]
13 ConfigLoadError(String),
14 #[error("Failed to parse custom commands config: {0}")]
15 ParseError(#[from] toml::de::Error),
16 #[error("IO error: {0}")]
17 IoError(#[from] std::io::Error),
18 #[error("Invalid command name '{0}': {1}")]
19 InvalidCommandName(String, String),
20 #[error("Command name '{0}' conflicts with built-in command")]
21 ConflictingCommandName(String),
22}
23
24#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
26#[serde(tag = "type", rename_all = "snake_case")]
27pub enum CustomCommand {
28 Prompt {
30 name: String,
31 description: String,
32 prompt: String,
33 },
34 }
38
39impl CustomCommand {
40 pub fn name(&self) -> &str {
42 match self {
43 CustomCommand::Prompt { name, .. } => name,
44 }
45 }
46
47 pub fn description(&self) -> &str {
49 match self {
50 CustomCommand::Prompt { description, .. } => description,
51 }
52 }
53
54 pub fn validate(&self) -> Result<(), CustomCommandError> {
56 let name = self.name();
57
58 if name.is_empty() {
60 return Err(CustomCommandError::InvalidCommandName(
61 name.to_string(),
62 "Command name cannot be empty".to_string(),
63 ));
64 }
65
66 if name.contains('/') || name.contains(' ') {
68 return Err(CustomCommandError::InvalidCommandName(
69 name.to_string(),
70 "Command name cannot contain '/' or spaces".to_string(),
71 ));
72 }
73
74 for cmd in TuiCommandType::iter() {
76 if cmd.command_name() == name {
77 return Err(CustomCommandError::ConflictingCommandName(name.to_string()));
78 }
79 }
80
81 for cmd in CoreCommandType::iter() {
82 if cmd.command_name() == name {
83 return Err(CustomCommandError::ConflictingCommandName(name.to_string()));
84 }
85 }
86
87 Ok(())
88 }
89}
90
91#[derive(Debug, Serialize, Deserialize, Default)]
93pub struct CustomCommandsConfig {
94 #[serde(default)]
95 pub commands: Vec<CustomCommand>,
96}
97
98pub fn get_config_paths() -> Vec<PathBuf> {
100 let mut paths = Vec::new();
101
102 paths.push(PathBuf::from(".steer").join("commands.toml"));
104
105 if let Some(proj_dirs) = ProjectDirs::from("", "", "steer") {
107 paths.push(proj_dirs.config_dir().join("commands.toml"));
108 }
109
110 paths
111}
112
113pub fn get_config_path() -> PathBuf {
115 get_config_paths()
117 .into_iter()
118 .next()
119 .unwrap_or_else(|| PathBuf::from(".steer").join("commands.toml"))
120}
121
122pub fn load_custom_commands() -> Result<Vec<CustomCommand>, CustomCommandError> {
124 let mut all_commands = Vec::new();
125 let mut seen_names = std::collections::HashSet::new();
126
127 for config_path in get_config_paths() {
129 debug!("Checking for custom commands at: {}", config_path.display());
130
131 if !config_path.exists() {
132 debug!("Config not found at {}", config_path.display());
133 continue;
134 }
135
136 let config_content = match std::fs::read_to_string(&config_path) {
137 Ok(content) => content,
138 Err(e) => {
139 warn!("Failed to read {}: {}", config_path.display(), e);
140 continue;
141 }
142 };
143
144 let config: CustomCommandsConfig = match toml::from_str(&config_content) {
145 Ok(config) => config,
146 Err(e) => {
147 warn!(
148 "Failed to parse {}: {}. Skipping this config file.",
149 config_path.display(),
150 e
151 );
152 continue;
153 }
154 };
155
156 debug!(
157 "Found {} commands in {}",
158 config.commands.len(),
159 config_path.display()
160 );
161
162 let mut file_names = std::collections::HashSet::new();
164 for cmd in &config.commands {
165 if !file_names.insert(cmd.name().to_string()) {
166 warn!(
167 "Duplicate command '{}' found within {}. Skipping duplicates.",
168 cmd.name(),
169 config_path.display()
170 );
171 }
172 }
173
174 for cmd in config.commands {
176 match cmd.validate() {
178 Ok(()) => {
179 if seen_names.insert(cmd.name().to_string()) {
180 all_commands.push(cmd);
181 } else {
182 debug!(
183 "Skipping duplicate command '{}' from {}",
184 cmd.name(),
185 config_path.display()
186 );
187 }
188 }
189 Err(e) => {
190 warn!(
191 "Skipping invalid command from {}: {}",
192 config_path.display(),
193 e
194 );
195 }
196 }
197 }
198 }
199
200 debug!("Total custom commands loaded: {}", all_commands.len());
201 Ok(all_commands)
202}
203
204pub fn save_custom_commands(commands: &[CustomCommand]) -> Result<(), CustomCommandError> {
206 let config_path = get_config_path();
207
208 if let Some(parent) = config_path.parent() {
210 std::fs::create_dir_all(parent)?;
211 }
212
213 let config = CustomCommandsConfig {
214 commands: commands.to_vec(),
215 };
216
217 let config_content = toml::to_string_pretty(&config).map_err(|e| {
218 CustomCommandError::ConfigLoadError(format!("Failed to serialize config: {e}"))
219 })?;
220
221 std::fs::write(&config_path, config_content)?;
222
223 Ok(())
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a prompt command with the given name and placeholder text.
    fn prompt_named(name: &str) -> CustomCommand {
        CustomCommand::Prompt {
            name: name.to_string(),
            description: "desc".to_string(),
            prompt: "prompt".to_string(),
        }
    }

    #[test]
    fn test_parse_prompt_command() {
        let config_content = r#"
[[commands]]
type = "prompt"
name = "standup"
description = "Generate a standup report"
prompt = "What did I work on today? Check git log and recent file changes."
"#;

        let config: CustomCommandsConfig = toml::from_str(config_content).unwrap();
        assert_eq!(config.commands.len(), 1);

        match &config.commands[0] {
            CustomCommand::Prompt {
                name,
                description,
                prompt,
            } => {
                assert_eq!(name, "standup");
                assert_eq!(description, "Generate a standup report");
                assert_eq!(
                    prompt,
                    "What did I work on today? Check git log and recent file changes."
                );
            }
        }
    }

    #[test]
    fn test_multiple_commands() {
        let config_content = r#"
[[commands]]
type = "prompt"
name = "test"
description = "Run tests"
prompt = "Run the test suite and show me any failures"

[[commands]]
type = "prompt"
name = "review"
description = "Code review helper"
prompt = "Review the recent changes and suggest improvements"
"#;

        let config: CustomCommandsConfig = toml::from_str(config_content).unwrap();
        assert_eq!(config.commands.len(), 2);
        assert_eq!(config.commands[0].name(), "test");
        assert_eq!(config.commands[1].name(), "review");
    }

    #[test]
    fn test_missing_commands_key_yields_empty_list() {
        let config: CustomCommandsConfig = toml::from_str("").unwrap();
        assert!(config.commands.is_empty());
    }

    #[test]
    fn test_unknown_command_type_is_rejected() {
        let config_content = r#"
[[commands]]
type = "shell"
name = "build"
description = "Build the project"
prompt = "cargo build"
"#;

        assert!(toml::from_str::<CustomCommandsConfig>(config_content).is_err());
    }

    #[test]
    fn test_validate_rejects_malformed_names() {
        // These checks run before the built-in conflict check, so the result
        // does not depend on which built-in commands exist.
        for bad in ["", "a/b", "a b"] {
            assert!(
                matches!(
                    prompt_named(bad).validate(),
                    Err(CustomCommandError::InvalidCommandName(..))
                ),
                "expected '{bad}' to be rejected as an invalid name"
            );
        }
    }

    #[test]
    fn test_roundtrip_through_toml() {
        let original = CustomCommandsConfig {
            commands: vec![prompt_named("standup"), prompt_named("review")],
        };

        let text = toml::to_string_pretty(&original).unwrap();
        let parsed: CustomCommandsConfig = toml::from_str(&text).unwrap();

        assert_eq!(parsed.commands, original.commands);
    }
}