Skip to main content

shirabe_core/
command.rs

1use crate::context::filter::ContextFilter;
2use crate::error::FrameworkResult;
3use crate::session::Session;
4use std::collections::HashMap;
5use std::fmt::Debug;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::{Arc, RwLock};
9
10/// 解析后的命令参数和选项。
11#[derive(Debug, Clone, Default)]
12pub struct ParsedArgs {
13    /// 位置参数。
14    pub arguments: Vec<String>,
15    /// 提供的选项,例如 --option value。
16    /// 对于标志(没有值的选项),其值可能为空字符串或占位符。
17    pub options: HashMap<String, String>, // 选项名称 -> 选项值
18}
19
20/// 命令执行的异步动作的类型别名。
21pub type CommandAction = Box<
22    dyn Fn(
23            Arc<Session>,
24            ParsedArgs,
25        ) -> Pin<Box<dyn Future<Output = FrameworkResult<()>> + Send + Sync>>
26        + Send
27        + Sync,
28>;
29
30/// 代表一个机器人指令。
31pub struct Command {
32    /// 指令的主要名称。
33    pub name: String,
34    /// 指令的别名列表。
35    pub aliases: Vec<String>,
36    /// 指令功能的简要描述。
37    pub description: Option<String>,
38    /// 指令注册时关联的上下文过滤器。
39    pub filter: ContextFilter,
40    /// 指令执行时的异步动作。
41    pub action: CommandAction,
42    // TODO: 为帮助信息生成和验证添加参数和选项定义的字段
43    // pub arg_defs: Vec<CommandArgumentDef>,
44    // pub opt_defs: Vec<CommandOptionDef>,
45}
46
47impl Debug for Command {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        f.debug_struct("Command")
50            .field("name", &self.name)
51            .field("aliases", &self.aliases)
52            .field("description", &self.description)
53            .field("filter", &self.filter)
54            .field("action", &"Box<dyn Fn(...)>") // 不打印闭包本身
55            .finish()
56    }
57}
58
59/// 管理并执行指令。
60#[derive(Default, Debug)]
61pub struct CommandRegistry {
62    /// 存储指令,将指令名称/别名映射到指令定义。
63    pub commands: HashMap<String, Arc<Command>>,
64}
65
66impl CommandRegistry {
67    pub fn new() -> Self {
68        Self::default()
69    }
70
71    /// 注册一个指令。
72    /// 指令的主要名称及其所有别名都将被注册。
73    pub fn register(&mut self, command: Command) -> FrameworkResult<()> {
74        let command_arc = Arc::new(command);
75        // 注册主要名称
76        if self.commands.contains_key(&command_arc.name) {
77            // 可以返回错误或记录警告
78            tracing::warn!("指令 {} 已注册,将被覆盖。", command_arc.name);
79        }
80        self.commands
81            .insert(command_arc.name.clone(), Arc::clone(&command_arc));
82
83        // 注册别名
84        for alias in &command_arc.aliases {
85            if self.commands.contains_key(alias) {
86                tracing::warn!(
87                    "指令 {} 的别名 {} 已注册或与另一指令冲突,将被覆盖。",
88                    command_arc.name,
89                    alias
90                );
91            }
92            self.commands
93                .insert(alias.clone(), Arc::clone(&command_arc));
94        }
95        Ok(())
96    }
97
98    /// 解析消息并执行相应的指令(如果找到)。
99    ///
100    /// # 参数
101    ///
102    /// * `session`: 当前会话,提供上下文和机器人访问。
103    /// * `message_content`: 消息的原始文本内容。
104    /// * `prefixes`: 命令前缀集合。
105    ///
106    /// # 返回
107    ///
108    /// * `Ok(true)` 如果找到了指令并尝试执行。
109    /// * `Ok(false)` 如果没有找到指令(例如,没有前缀匹配或未知指令)。
110    /// * `Err(FrameworkError)` 如果在指令执行期间发生错误。
111    pub async fn parse_and_execute(
112        &self,
113        session: Arc<Session>,
114        message_content: &str,
115        // TODO: 从配置中获取前缀
116        prefixes: &[&str], // 允许传入前缀
117    ) -> FrameworkResult<bool> {
118        let mut potential_command_text: Option<&str> = None;
119
120        for prefix in prefixes {
121            if message_content.starts_with(prefix) {
122                potential_command_text =
123                    Some(message_content.trim_start_matches(prefix).trim_start());
124                break;
125            }
126        }
127
128        if potential_command_text.is_none() {
129            return Ok(false); // 不是指令 (没有匹配的前缀)
130        }
131
132        let text = potential_command_text.unwrap();
133        if text.is_empty() {
134            return Ok(false); // 只有前缀,没有指令名称
135        }
136
137        let parts: Vec<&str> = text.split_whitespace().collect();
138        let command_name = parts[0];
139
140        if let Some(command_arc) = self.commands.get(command_name) {
141            // 检查指令自身注册时绑定的过滤器
142            if !command_arc.filter.matches_session(&session) {
143                tracing::trace!(
144                    "指令 {} 找到,但其上下文过滤器不匹配当前会话。",
145                    command_name
146                );
147                return Ok(false); // 指令的上下文过滤器不匹配
148            }
149
150            tracing::debug!("正在执行指令: {}", command_name);
151
152            let mut parsed_args = ParsedArgs::default();
153            let mut i = 1; // 从指令名称后的第一个部分开始
154            while i < parts.len() {
155                let part = parts[i];
156                if part.starts_with("--") {
157                    let option_name = part.trim_start_matches("--").to_string();
158                    if i + 1 < parts.len() {
159                        let next_part = parts[i + 1];
160                        // 值不能以 '-' 开头 (除非它是负数等特定情况,但这里简化)
161                        if !next_part.starts_with('-') && !next_part.is_empty() {
162                            // 主要修改在这里
163                            parsed_args
164                                .options
165                                .insert(option_name, next_part.to_string());
166                            i += 1; // 消耗选项值部分
167                        } else {
168                            // 下一个 token 是另一个选项,或没有值,当前长选项是标志
169                            parsed_args.options.insert(option_name, String::new());
170                        }
171                    } else {
172                        // 没有更多 token 了,当前长选项是标志
173                        parsed_args.options.insert(option_name, String::new());
174                    }
175                } else if part.starts_with('-') && part.len() > 1 && !part.starts_with("--") {
176                    // 短选项逻辑
177                    for (idx, char_val) in part.char_indices() {
178                        if idx == 0 {
179                            continue;
180                        }
181                        parsed_args
182                            .options
183                            .insert(char_val.to_string(), String::new());
184                    }
185                } else {
186                    // 参数
187                    parsed_args.arguments.push(part.to_string());
188                }
189                i += 1;
190            }
191
192            // 执行指令的动作
193            (command_arc.action)(session, parsed_args).await?;
194            Ok(true) // 指令找到并尝试执行
195        } else {
196            tracing::trace!("未知指令: {}", command_name);
197            Ok(false) // 未知指令
198        }
199    }
200}
201
202/// 用于链式构建和注册指令的构建器。
203pub struct CommandBuilder {
204    name: String,
205    aliases: Vec<String>,
206    description: Option<String>,
207    filter: ContextFilter, // 从调用 command() 的上下文中捕获
208    action: Option<CommandAction>,
209    registry: Arc<RwLock<CommandRegistry>>, // 指向共享的指令注册表
210}
211
212impl CommandBuilder {
213    /// 创建一个新的 CommandBuilder。
214    /// 通常由 `Context::command()` 调用。
215    pub fn new(
216        name: String,
217        filter: ContextFilter,
218        registry: Arc<RwLock<CommandRegistry>>,
219    ) -> Self {
220        CommandBuilder {
221            name,
222            aliases: Vec::new(),
223            description: None,
224            filter,
225            action: None,
226            registry,
227        }
228    }
229
230    /// 为指令添加一个别名。
231    pub fn alias(mut self, alias: &str) -> Self {
232        self.aliases.push(alias.to_string());
233        self
234    }
235
236    /// 为指令设置描述。
237    pub fn description(mut self, description: &str) -> Self {
238        self.description = Some(description.to_string());
239        self
240    }
241
242    /// 设置指令的执行动作。
243    pub fn action<F, Fut>(mut self, f: F) -> Self
244    where
245        F: Fn(Arc<Session>, ParsedArgs) -> Fut + Send + Sync + 'static,
246        Fut: Future<Output = FrameworkResult<()>> + Send + Sync + 'static,
247    {
248        self.action = Some(Box::new(move |session, args| Box::pin(f(session, args))));
249        self
250    }
251
252    /// 构建并注册指令。
253    pub fn register(self) -> FrameworkResult<()> {
254        let action = self.action.ok_or_else(|| {
255            crate::error::FrameworkError::Command(format!("指令 '{}' 没有定义 action", self.name))
256        })?;
257
258        let command = Command {
259            name: self.name.clone(),
260            aliases: self.aliases,
261            description: self.description,
262            filter: self.filter,
263            action,
264        };
265
266        let mut registry_guard = self.registry.write().map_err(|_| {
267            crate::error::FrameworkError::Internal("无法获取 CommandRegistry 的写锁".to_string())
268        })?;
269
270        registry_guard.register(command)?;
271        tracing::info!("指令 '{}' 已注册", self.name);
272        Ok(())
273    }
274}