Skip to main content

supplement/core/
mod.rs

1//! The module defining CLI objects, including [`Command`], [`Flag`] and [`Arg`].
2//!
//! Normally, everything here is a constant generated by code generation (see [`crate::generate`]).
//! Users can simply call [`Command::supplement`] to start CLI completion
//! without diving into the details of these objects.
6
7mod flag;
8pub use crate::id;
9pub use flag::{CompleteWithEqual, Flag, flag_type};
10
11use crate::arg_context::ArgsContext;
12use crate::completion::{CompletionGroup, Unready};
13use crate::error::Error;
14use crate::parsed_flag::ParsedFlag;
15use crate::{Completion, History, Result};
16use std::fmt::Debug;
17use std::iter::Peekable;
18
/// Static list of `(value, description)` pairs offered as completion candidates.
type PossibleValues = &'static [(&'static str, &'static str)];

/// A positional argument of a [`Command`].
pub struct Arg<ID> {
    /// Identifier under which this argument's value is pushed into the history.
    pub id: id::Valued<ID>,
    /// Maximum number of values for this argument.
    /// NOTE(review): not read in this module; presumably consumed by `ArgsContext` — confirm.
    pub max_values: usize,
    /// Fixed candidates offered when completing this argument's value.
    pub possible_values: PossibleValues,
}
26
27fn comp_with_possible<ID>(
28    mut unready: Unready,
29    values: PossibleValues,
30    value: String,
31    id: id::Valued<ID>,
32) -> CompletionGroup<ID> {
33    let values = values.iter().map(|(v, d)| Completion::new(v, d));
34    unready = unready.preexist(values);
35    match id.id() {
36        Some(id) => CompletionGroup::Unready { id, unready, value },
37        None => CompletionGroup::Ready(unready.to_ready(vec![])),
38    }
39}
40
/// The object to represent a command.
/// Users can simply call [`Command::supplement`] for CLI completion.
pub struct Command<ID: 'static> {
    /// Name of the (sub)command. It is matched against command-line words and offered as a
    /// completion candidate.
    pub name: &'static str,
    /// Description shown alongside `name` when it is offered as a completion candidate.
    pub description: &'static str,
    /// Every flag this command accepts. Flags marked `once` stop being offered or matched
    /// once they already exist in the history.
    pub all_flags: &'static [Flag<ID>],
    /// Positional arguments of this command.
    pub args: &'static [Arg<ID>],
    /// Subcommands. They are only recognized before any positional argument has been seen.
    pub commands: &'static [Command<ID>],
}
50
51fn supplement_arg<ID: PartialEq + Copy + Debug>(
52    history: &mut History<ID>,
53    ctx: &mut ArgsContext<ID>,
54    arg: String,
55) -> Result {
56    let Some(arg_obj) = ctx.next_arg() else {
57        return Err(Error::UnexpectedArg(arg));
58    };
59    history.push_valued(arg_obj.id, arg);
60    Ok(())
61}
62fn parse_flag(s: &str, disable_flag: bool) -> ParsedFlag<'_> {
63    if disable_flag {
64        log::info!("flag is disabled: {}", s);
65        ParsedFlag::NotFlag
66    } else {
67        ParsedFlag::new(s)
68    }
69}
70
71fn check_no_flag<ID>(arg: String, v: Vec<Completion>) -> Result<CompletionGroup<ID>> {
72    if v.is_empty() {
73        return Err(Error::UnexpectedFlag);
74    }
75    Ok(CompletionGroup::new_ready(v, arg))
76}
77
78impl<ID: 'static + Copy + PartialEq + Debug> Command<ID> {
79    /// The main entry point of CLI completion.
80    ///
81    /// ```
82    /// # use supplement::core::*;
83    /// # use supplement::*;
84    /// # use supplement::completion::CompletionGroup;
85    /// # type ID = u32;
86    /// const fn create_cmd(name: &'static str, subcmd: &'static [Command<ID>]) -> Command<ID> {
87    ///     Command {
88    ///         name,
89    ///         description: "",
90    ///         all_flags: &[],
91    ///         args: &[],
92    ///         commands: subcmd,
93    ///     }
94    /// }
95    ///
96    /// const CHECKOUT: Command<ID> = create_cmd("checkout", &[]);
97    /// const LOG: Command<ID> = create_cmd("log", &[]);
98    /// let root = create_cmd("qit", &[CHECKOUT, LOG]);
99    ///
100    /// let args = ["qit", ""].iter().map(|s| s.to_string());
101    /// let (_history, grp) = root.supplement(args).unwrap();
102    /// let ready = match grp {
103    ///     CompletionGroup::Ready(ready) => ready,
104    ///     CompletionGroup::Unready{ .. } => unreachable!(),
105    /// };
106    ///
107    /// let comps = ready.into_inner().0;
108    /// assert_eq!(comps[0], Completion::new("checkout", "").group("command"));
109    /// assert_eq!(comps[1], Completion::new("log", "").group("command"));
110    /// ```
111    pub fn supplement(
112        &self,
113        args: impl Iterator<Item = String>,
114    ) -> Result<(History<ID>, CompletionGroup<ID>)> {
115        let mut history = History::<ID>::new();
116        let grp = self.supplement_with_history(&mut history, args)?;
117        Ok((history, grp))
118    }
119
120    pub fn supplement_with_history(
121        &self,
122        history: &mut History<ID>,
123        mut args: impl Iterator<Item = String>,
124    ) -> Result<CompletionGroup<ID>> {
125        args.next(); // ignore the first arg which is the program's name
126
127        let mut args = args.peekable();
128        if args.peek().is_none() {
129            return Err(Error::ArgsTooShort);
130        }
131
132        self.supplement_recur(&mut None, history, &mut args)
133    }
134
135    fn doing_external(&self, ctx: &ArgsContext<'_, ID>) -> bool {
136        let has_subcmd = !self.commands.is_empty();
137        has_subcmd && ctx.has_seen_arg()
138    }
139    fn flags(&self, history: &History<ID>) -> impl Iterator<Item = &Flag<ID>> {
140        self.all_flags.iter().filter(|f| {
141            if !f.once {
142                true
143            } else {
144                let exists = f.exists_in_history(history);
145                if exists {
146                    log::debug!("flag {:?} already exists", f.name());
147                }
148                !exists
149            }
150        })
151    }
152
153    fn find_flag<F: FnMut(&Flag<ID>) -> bool>(
154        &self,
155        arg: &str,
156        history: &History<ID>,
157        mut filter: F,
158    ) -> Result<&Flag<ID>> {
159        match self.flags(history).find(|f| filter(f)) {
160            Some(flag) => Ok(flag),
161            None => Err(Error::FlagNotFound(arg.to_owned())),
162        }
163    }
164
165    fn find_long_flag(&self, flag: &str, history: &History<ID>) -> Result<&Flag<ID>> {
166        self.find_flag(flag, history, |f| f.long.iter().any(|l| *l == flag))
167    }
168    fn find_short_flag(&self, flag: char, history: &History<ID>) -> Result<&Flag<ID>> {
169        self.find_flag(&flag.to_string(), history, |f| {
170            f.short.iter().any(|s| *s == flag)
171        })
172    }
173
174    fn supplement_recur(
175        &self,
176        args_ctx_opt: &mut Option<ArgsContext<'_, ID>>,
177        history: &mut History<ID>,
178        args: &mut Peekable<impl Iterator<Item = String>>,
179    ) -> Result<CompletionGroup<ID>> {
180        let arg = args.next().unwrap();
181
182        let args_ctx = if let Some(ctx) = args_ctx_opt {
183            ctx
184        } else {
185            *args_ctx_opt = Some(ArgsContext::new(&self.args));
186            args_ctx_opt.as_mut().unwrap()
187        };
188
189        if args.peek().is_none() {
190            return self.supplement_last(args_ctx, history, arg);
191        }
192
193        macro_rules! handle_flag {
194            ($flag:expr, $equal:expr, $history:expr) => {
195                if let Some(equal) = $equal {
196                    match $flag.ty {
197                        flag_type::Type::Valued(flag) => flag.push($history, equal.to_string()),
198                        _ => return Err(Error::BoolFlagEqualsValue(arg)),
199                    }
200                } else {
201                    let res = $flag.supplement($history, args)?;
202                    if let Some(res) = res {
203                        return Ok(res);
204                    }
205                }
206            };
207        }
208
209        match parse_flag(&arg, self.doing_external(args_ctx)) {
210            ParsedFlag::SingleDash | ParsedFlag::DoubleDash | ParsedFlag::Empty => {
211                supplement_arg(history, args_ctx, arg)?;
212            }
213            ParsedFlag::NotFlag => {
214                let command = if args_ctx.has_seen_arg() {
215                    None
216                } else {
217                    self.commands.iter().find(|c| arg == c.name)
218                };
219                match command {
220                    Some(command) => {
221                        return command.supplement_recur(&mut None, history, args);
222                    }
223                    None => {
224                        log::info!("No subcommand. Try fallback args.");
225                        supplement_arg(history, args_ctx, arg)?;
226                    }
227                }
228            }
229            ParsedFlag::Long { body, equal } => {
230                let flag = self.find_long_flag(body, history)?;
231                handle_flag!(flag, equal, history);
232            }
233            ParsedFlag::Shorts => {
234                let resolved = self.resolve_shorts(history, &arg)?;
235                handle_flag!(resolved.last_flag, resolved.value, history);
236            }
237        }
238
239        self.supplement_recur(args_ctx_opt, history, args)
240    }
241
242    fn supplement_last(
243        &self,
244        args_ctx: &mut ArgsContext<'_, ID>,
245        history: &mut History<ID>,
246        arg: String,
247    ) -> Result<CompletionGroup<ID>> {
248        let ret: CompletionGroup<ID> = match parse_flag(&arg, self.doing_external(args_ctx)) {
249            ParsedFlag::Empty | ParsedFlag::NotFlag => {
250                let cmd_slice = if args_ctx.has_seen_arg() {
251                    log::info!("no completion for subcmd because we've already seen some args");
252                    &[]
253                } else {
254                    log::debug!("completion for {} subcommands", self.commands.len());
255                    self.commands
256                };
257                let cmd_comps = cmd_slice
258                    .iter()
259                    .map(|c| Completion::new(c.name, c.description).group("command"));
260
261                if let Some(arg_obj) = args_ctx.next_arg() {
262                    log::debug!("completion for args {:?}", arg_obj.id);
263                    let unready = Unready::new(String::new(), arg.clone()).preexist(cmd_comps);
264                    comp_with_possible(unready, arg_obj.possible_values, arg, arg_obj.id)
265                } else {
266                    if cmd_slice.is_empty() {
267                        return Err(Error::UnexpectedArg(arg));
268                    }
269                    CompletionGroup::new_ready(cmd_comps.collect(), arg)
270                }
271            }
272            ParsedFlag::DoubleDash | ParsedFlag::Long { equal: None, .. } => check_no_flag(
273                arg,
274                self.flags(history)
275                    .map(|f| f.gen_completion(Some(true)))
276                    .flatten()
277                    .collect(),
278            )?,
279            ParsedFlag::SingleDash => check_no_flag(
280                arg,
281                self.flags(history)
282                    .map(|f| f.gen_completion(None))
283                    .flatten()
284                    .collect(),
285            )?,
286            ParsedFlag::Long {
287                equal: Some(value),
288                body,
289            } => {
290                let flag = self.find_long_flag(body, history)?;
291                let valued = match flag.ty {
292                    flag_type::Type::Valued(valued) => valued,
293                    _ => return Err(Error::BoolFlagEqualsValue(arg)),
294                };
295                let prefix = format!("--{body}=");
296                let value = value.to_string();
297                let unready = Unready::new(prefix, arg);
298                comp_with_possible(unready, valued.possible_values, value, valued.id)
299            }
300            ParsedFlag::Shorts => self.supplement_last_short_flags(history, arg)?,
301        };
302        Ok(ret)
303    }
304
305    fn resolve_shorts<'a, 'b>(
306        &'b self,
307        history: &mut History<ID>,
308        shorts: &'a str,
309    ) -> Result<ResolvedMultiShort<'a, 'b, ID>> {
310        let mut chars = shorts.chars().peekable();
311        let mut len = 1; // ignore the first '-'
312        chars.next(); // ignore the first '-'
313        loop {
314            len += 1;
315            let ch = chars.next().unwrap();
316            let flag = self.find_short_flag(ch, history)?;
317            match chars.peek() {
318                None => {
319                    return Ok(ResolvedMultiShort {
320                        flag_part: shorts,
321                        last_flag: flag,
322                        value: None,
323                    });
324                }
325                Some('=') => {
326                    if matches!(flag.ty, flag_type::Type::Bool(_)) {
327                        return Err(Error::BoolFlagEqualsValue(shorts.to_owned()));
328                    };
329                    len += 1;
330                    return Ok(ResolvedMultiShort {
331                        flag_part: &shorts[..len],
332                        last_flag: flag,
333                        value: Some(&shorts[len..]),
334                    });
335                }
336                _ => {
337                    let valued = match flag.ty {
338                        flag_type::Type::Bool(inner) => {
339                            inner.push(history);
340                            continue;
341                        }
342                        flag_type::Type::Valued(valued) => valued,
343                    };
344
345                    match valued.complete_with_equal {
346                        CompleteWithEqual::Must => {
347                            return Err(Error::RequiresEqual(flag.name()));
348                        }
349                        CompleteWithEqual::Optional => {
350                            // TODO: Maybe one day clap will tell us.
351                            log::info!(
352                                "Optional flag {} doesn't have value. Push an empty string to history because we don't know its default value (clap wouldn't tell us).",
353                                flag.name(),
354                            );
355                            valued.push(history, String::new());
356                        }
357                        CompleteWithEqual::NoNeed => {
358                            return Ok(ResolvedMultiShort {
359                                flag_part: &shorts[..len],
360                                last_flag: flag,
361                                value: Some(&shorts[len..]),
362                            });
363                        }
364                    }
365                }
366            }
367        }
368    }
369
370    fn supplement_last_short_flags(
371        &self,
372        history: &mut History<ID>,
373        arg: String,
374    ) -> Result<CompletionGroup<ID>> {
375        let resolved = self.resolve_shorts(history, &arg)?;
376        let flag = resolved.last_flag;
377        let ret = match flag.ty {
378            flag_type::Type::Valued(valued) => {
379                let value = resolved.value.unwrap_or_default().to_string();
380                let mut eq = "";
381                if valued.complete_with_equal != CompleteWithEqual::NoNeed {
382                    if resolved.value.is_none() {
383                        // E.g. `cmd -af`
384                        // Want: `-af=opt1`, `-af=opt2`
385                        // NOTE: we don't want `-afx`, `-afy` where x and y are other flags. That's too much.
386                        eq = "=";
387                    } else {
388                        // E.g. `cmd -af=xyz` or `cmd -af=`.
389                        // Want: `-af=opt1`, `-af=opt2`
390                        // NOTE: It's impossible to be `cmd -afxyz`, either it throws error (Must) or `f` isn't the last flag (Optional).
391                    }
392                }
393                let prefix = format!("{}{}", resolved.flag_part, eq);
394                let unready = Unready::new(prefix, arg);
395                comp_with_possible(unready, valued.possible_values, value, valued.id)
396            }
397            flag_type::Type::Bool(inner) => {
398                log::debug!("list short flags with history {:?}", history);
399                inner.push(history);
400                let comps = self
401                    .flags(history)
402                    .map(|f| f.gen_completion(Some(false)))
403                    .flatten()
404                    .map(|c| {
405                        c.value(|v| {
406                            let flag = &v[1..]; // skip the first '-' character
407                            format!("{}{}", resolved.flag_part, flag)
408                        })
409                    })
410                    .collect();
411                check_no_flag(arg, comps)?
412            }
413        };
414        Ok(ret)
415    }
416}
417
/// Result of resolving a short-flag cluster such as `-abc`, `-af=value` or `-afvalue`.
#[derive(Clone, Copy)]
struct ResolvedMultiShort<'a, 'b, ID> {
    /// Prefix of the original word up to and including the last resolved flag, plus the `=`
    /// when one is present. This is the whole word when nothing follows the last flag.
    flag_part: &'a str,
    /// The last flag of the cluster. Earlier flags in the cluster have already been pushed
    /// into the history by `resolve_shorts`.
    last_flag: &'b Flag<ID>,
    /// Text following `flag_part`, if the cluster carries an inline value.
    value: Option<&'a str>,
}