usage/spec/
flag.rs

1use itertools::Itertools;
2use kdl::{KdlDocument, KdlEntry, KdlNode};
3use serde::Serialize;
4use std::fmt::Display;
5use std::hash::Hash;
6use std::str::FromStr;
7
8use crate::error::UsageErr::InvalidFlag;
9use crate::error::{Result, UsageErr};
10use crate::spec::context::ParsingContext;
11use crate::spec::helpers::NodeHelper;
12use crate::spec::is_false;
13use crate::{string, SpecArg, SpecChoices};
14
/// A flag (option) declared in a usage spec, such as `-f --force` or `--level <n>`.
///
/// Built from a KDL `flag` node by [`SpecFlag::parse`], from a usage string via
/// [`FromStr`], or, with the `clap` feature, from a `clap::Arg`. Equality and hashing
/// consider only [`SpecFlag::name`].
#[derive(Debug, Default, Clone, Serialize)]
pub struct SpecFlag {
    /// Identifier of the flag: an explicit `name:` prefix from the usage string if given,
    /// otherwise the first long name, falling back to the first short name.
    pub name: String,
    /// Rendered usage string as produced by [`SpecFlag::usage`] (e.g. `-f --flag <arg>`).
    pub usage: String,
    /// Help text, from the `help` property or child node.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub help: Option<String>,
    /// Longer help text, from either the `long_help` or the `help_long` key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub help_long: Option<String>,
    /// Help text from the `help_md` key (presumably Markdown — TODO confirm with renderers).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub help_md: Option<String>,
    /// First line of `help`, derived when the flag is built.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub help_first_line: Option<String>,
    /// Short names, stored without the leading `-`.
    pub short: Vec<char>,
    /// Long names, stored without the leading `--`.
    pub long: Vec<String>,
    /// Whether the flag is marked required. [`SpecFlag::parse`] clears this when a
    /// `default` property is present.
    #[serde(skip_serializing_if = "is_false")]
    pub required: bool,
    /// Deprecation message: `"deprecated"` when given as boolean `true`, otherwise the
    /// custom string supplied.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub deprecated: Option<String>,
    /// Variadic marker, rendered as a trailing `…` after the flag names in usage.
    /// NOTE(review): presumably means the flag may be repeated — confirm.
    #[serde(skip_serializing_if = "is_false")]
    pub var: bool,
    /// Whether the flag is marked hidden (presumably omitted from generated help — confirm).
    /// Unlike `required`/`var`/`count`, this is always serialized.
    pub hide: bool,
    /// Whether the flag is marked global (presumably inherited by subcommands — confirm).
    /// Always serialized.
    pub global: bool,
    /// Counting flag; set from clap's `ArgAction::Count` when converted from clap.
    #[serde(skip_serializing_if = "is_false")]
    pub count: bool,
    /// The value argument the flag takes, if any (e.g. `<arg>` in `--flag <arg>`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arg: Option<SpecArg>,
    /// Default value for the flag.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default: Option<String>,
    /// Value of the `negate` key (presumably the name of a flag that negates this one —
    /// confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub negate: Option<String>,
    /// Name of an environment variable associated with the flag, e.g. `MYCLI_COLOR`
    /// (presumably consulted for the flag's value — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub env: Option<String>,
}
48
49impl SpecFlag {
50    pub(crate) fn parse(ctx: &ParsingContext, node: &NodeHelper) -> Result<Self> {
51        let mut flag: Self = node.arg(0)?.ensure_string()?.parse()?;
52        for (k, v) in node.props() {
53            match k {
54                "help" => flag.help = Some(v.ensure_string()?),
55                "long_help" => flag.help_long = Some(v.ensure_string()?),
56                "help_long" => flag.help_long = Some(v.ensure_string()?),
57                "help_md" => flag.help_md = Some(v.ensure_string()?),
58                "required" => flag.required = v.ensure_bool()?,
59                "var" => flag.var = v.ensure_bool()?,
60                "hide" => flag.hide = v.ensure_bool()?,
61                "deprecated" => {
62                    flag.deprecated = match v.value.as_bool() {
63                        Some(true) => Some("deprecated".into()),
64                        Some(false) => None,
65                        None => Some(v.ensure_string()?),
66                    }
67                }
68                "global" => flag.global = v.ensure_bool()?,
69                "count" => flag.count = v.ensure_bool()?,
70                "default" => flag.default = v.ensure_string().map(Some)?,
71                "negate" => flag.negate = v.ensure_string().map(Some)?,
72                "env" => flag.env = v.ensure_string().map(Some)?,
73                k => bail_parse!(ctx, v.entry.span(), "unsupported flag key {k}"),
74            }
75        }
76        if flag.default.is_some() {
77            flag.required = false;
78        }
79        for child in node.children() {
80            match child.name() {
81                "arg" => flag.arg = Some(SpecArg::parse(ctx, &child)?),
82                "help" => flag.help = Some(child.arg(0)?.ensure_string()?),
83                "long_help" => flag.help_long = Some(child.arg(0)?.ensure_string()?),
84                "help_long" => flag.help_long = Some(child.arg(0)?.ensure_string()?),
85                "help_md" => flag.help_md = Some(child.arg(0)?.ensure_string()?),
86                "required" => flag.required = child.arg(0)?.ensure_bool()?,
87                "var" => flag.var = child.arg(0)?.ensure_bool()?,
88                "hide" => flag.hide = child.arg(0)?.ensure_bool()?,
89                "deprecated" => {
90                    flag.deprecated = match child.arg(0)?.ensure_bool() {
91                        Ok(true) => Some("deprecated".into()),
92                        Ok(false) => None,
93                        _ => Some(child.arg(0)?.ensure_string()?),
94                    }
95                }
96                "global" => flag.global = child.arg(0)?.ensure_bool()?,
97                "count" => flag.count = child.arg(0)?.ensure_bool()?,
98                "default" => flag.default = child.arg(0)?.ensure_string().map(Some)?,
99                "env" => flag.env = child.arg(0)?.ensure_string().map(Some)?,
100                "choices" => {
101                    if let Some(arg) = &mut flag.arg {
102                        arg.choices = Some(SpecChoices::parse(ctx, &child)?);
103                    } else {
104                        bail_parse!(
105                            ctx,
106                            child.node.name().span(),
107                            "flag must have value to have choices"
108                        )
109                    }
110                }
111                k => bail_parse!(ctx, child.node.name().span(), "unsupported flag child {k}"),
112            }
113        }
114        flag.usage = flag.usage();
115        flag.help_first_line = flag.help.as_ref().map(|s| string::first_line(s));
116        Ok(flag)
117    }
118    pub fn usage(&self) -> String {
119        let mut parts = vec![];
120        let name = get_name_from_short_and_long(&self.short, &self.long).unwrap_or_default();
121        if name != self.name {
122            parts.push(format!("{}:", self.name));
123        }
124        if let Some(short) = self.short.first() {
125            parts.push(format!("-{short}"));
126        }
127        if let Some(long) = self.long.first() {
128            parts.push(format!("--{long}"));
129        }
130        let mut out = parts.join(" ");
131        if self.var {
132            out = format!("{out}…");
133        }
134        if let Some(arg) = &self.arg {
135            out = format!("{} {}", out, arg.usage());
136        }
137        out
138    }
139}
140
141impl From<&SpecFlag> for KdlNode {
142    fn from(flag: &SpecFlag) -> KdlNode {
143        let mut node = KdlNode::new("flag");
144        let name = flag
145            .short
146            .iter()
147            .map(|c| format!("-{c}"))
148            .chain(flag.long.iter().map(|s| format!("--{s}")))
149            .collect_vec()
150            .join(" ");
151        node.push(KdlEntry::new(name));
152        if let Some(desc) = &flag.help {
153            node.push(KdlEntry::new_prop("help", desc.clone()));
154        }
155        if let Some(desc) = &flag.help_long {
156            let children = node.children_mut().get_or_insert_with(KdlDocument::new);
157            let mut node = KdlNode::new("long_help");
158            node.entries_mut().push(KdlEntry::new(desc.clone()));
159            children.nodes_mut().push(node);
160        }
161        if let Some(desc) = &flag.help_md {
162            let children = node.children_mut().get_or_insert_with(KdlDocument::new);
163            let mut node = KdlNode::new("help_md");
164            node.entries_mut().push(KdlEntry::new(desc.clone()));
165            children.nodes_mut().push(node);
166        }
167        if flag.required {
168            node.push(KdlEntry::new_prop("required", true));
169        }
170        if flag.var {
171            node.push(KdlEntry::new_prop("var", true));
172        }
173        if flag.hide {
174            node.push(KdlEntry::new_prop("hide", true));
175        }
176        if flag.global {
177            node.push(KdlEntry::new_prop("global", true));
178        }
179        if flag.count {
180            node.push(KdlEntry::new_prop("count", true));
181        }
182        if let Some(negate) = &flag.negate {
183            node.push(KdlEntry::new_prop("negate", negate.clone()));
184        }
185        if let Some(env) = &flag.env {
186            node.push(KdlEntry::new_prop("env", env.clone()));
187        }
188        if let Some(deprecated) = &flag.deprecated {
189            node.push(KdlEntry::new_prop("deprecated", deprecated.clone()));
190        }
191        if let Some(arg) = &flag.arg {
192            let children = node.children_mut().get_or_insert_with(KdlDocument::new);
193            children.nodes_mut().push(arg.into());
194        }
195        node
196    }
197}
198
199impl FromStr for SpecFlag {
200    type Err = UsageErr;
201    fn from_str(input: &str) -> Result<Self> {
202        let mut flag = Self::default();
203        let input = input.replace("...", "…").replace("…", " … ");
204        for part in input.split_whitespace() {
205            if let Some(name) = part.strip_suffix(':') {
206                flag.name = name.to_string();
207            } else if let Some(long) = part.strip_prefix("--") {
208                flag.long.push(long.to_string());
209            } else if let Some(short) = part.strip_prefix('-') {
210                if short.len() != 1 {
211                    return Err(InvalidFlag(
212                        short.to_string(),
213                        (0, input.len()).into(),
214                        input.to_string(),
215                    ));
216                }
217                flag.short.push(short.chars().next().unwrap());
218            } else if part == "…" {
219                if let Some(arg) = &mut flag.arg {
220                    arg.var = true;
221                } else {
222                    flag.var = true;
223                }
224            } else if part.starts_with('<') && part.ends_with('>')
225                || part.starts_with('[') && part.ends_with(']')
226            {
227                flag.arg = Some(part.to_string().parse()?);
228            } else {
229                return Err(InvalidFlag(
230                    part.to_string(),
231                    (0, input.len()).into(),
232                    input.to_string(),
233                ));
234            }
235        }
236        if flag.name.is_empty() {
237            flag.name = get_name_from_short_and_long(&flag.short, &flag.long).unwrap_or_default();
238        }
239        flag.usage = flag.usage();
240        Ok(flag)
241    }
242}
243
#[cfg(feature = "clap")]
impl From<&clap::Arg> for SpecFlag {
    /// Builds a [`SpecFlag`] from a clap argument definition.
    ///
    /// Short and long names include clap's visible aliases. The name is derived from the
    /// first long name, falling back to the first short name. The `Count` and `Append`
    /// actions make the flag variadic. The `Set` and `Append` actions give it a value
    /// argument named after clap's value names (or the flag name), with any possible
    /// values and their aliases as choices. `usage` is computed exactly as for flags
    /// parsed from a spec.
    fn from(c: &clap::Arg) -> Self {
        let action = c.get_action();
        let short = c.get_short_and_visible_aliases().unwrap_or_default();
        let long = c
            .get_long_and_visible_aliases()
            .unwrap_or_default()
            .into_iter()
            .map(|s| s.to_string())
            .collect::<Vec<_>>();
        let name = get_name_from_short_and_long(&short, &long).unwrap_or_default();
        let arg = if let clap::ArgAction::Set | clap::ArgAction::Append = action {
            let mut arg = SpecArg::from(
                c.get_value_names()
                    .map(|s| s.iter().map(|s| s.to_string()).join(" "))
                    .unwrap_or_else(|| name.clone())
                    .as_str(),
            );

            let choices = c
                .get_possible_values()
                .iter()
                .flat_map(|v| v.get_name_and_aliases().map(|s| s.to_string()))
                .collect::<Vec<_>>();
            if !choices.is_empty() {
                arg.choices = Some(SpecChoices { choices });
            }

            Some(arg)
        } else {
            None
        };
        let help = c.get_help().map(|s| s.to_string());
        let help_first_line = help.as_ref().map(|s| string::first_line(s));
        let mut flag = Self {
            name,
            usage: String::new(),
            short,
            long,
            required: c.is_required_set(),
            help,
            help_long: c.get_long_help().map(|s| s.to_string()),
            help_md: None,
            help_first_line,
            var: matches!(action, clap::ArgAction::Count | clap::ArgAction::Append),
            hide: c.is_hide_set(),
            global: c.is_global_set(),
            arg,
            count: matches!(action, clap::ArgAction::Count),
            default: c
                .get_default_values()
                .first()
                .map(|s| s.to_string_lossy().to_string()),
            deprecated: None,
            negate: None,
            env: None,
        };
        // Fill in `usage` like `SpecFlag::parse` and `FromStr` do, rather than leaving it empty.
        flag.usage = flag.usage();
        flag
    }
}
311
312// #[cfg(feature = "clap")]
313// impl From<&SpecFlag> for clap::Arg {
314//     fn from(flag: &SpecFlag) -> Self {
315//         let mut a = clap::Arg::new(&flag.name);
316//         if let Some(desc) = &flag.help {
317//             a = a.help(desc);
318//         }
319//         if flag.required {
320//             a = a.required(true);
321//         }
322//         if let Some(arg) = &flag.arg {
323//             a = a.value_name(&arg.name);
324//             if arg.var {
325//                 a = a.action(clap::ArgAction::Append)
326//             } else {
327//                 a = a.action(clap::ArgAction::Set)
328//             }
329//         } else {
330//             a = a.action(clap::ArgAction::SetTrue)
331//         }
332//         // let mut a = clap::Arg::new(&flag.name)
333//         //     .required(flag.required)
334//         //     .action(clap::ArgAction::SetTrue);
335//         if let Some(short) = flag.short.first() {
336//             a = a.short(*short);
337//         }
338//         if let Some(long) = flag.long.first() {
339//             a = a.long(long);
340//         }
341//         for short in flag.short.iter().skip(1) {
342//             a = a.visible_short_alias(*short);
343//         }
344//         for long in flag.long.iter().skip(1) {
345//             a = a.visible_alias(long);
346//         }
347//         // cmd = cmd.arg(a);
348//         // if flag.multiple {
349//         //     a = a.multiple(true);
350//         // }
351//         // if flag.hide {
352//         //     a = a.hide_possible_values(true);
353//         // }
354//         a
355//     }
356// }
357
358impl Display for SpecFlag {
359    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
360        write!(f, "{}", self.usage())
361    }
362}
363impl PartialEq for SpecFlag {
364    fn eq(&self, other: &Self) -> bool {
365        self.name == other.name
366    }
367}
368impl Eq for SpecFlag {}
369impl Hash for SpecFlag {
370    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
371        self.name.hash(state);
372    }
373}
374
/// Derives a flag's default name: its first long name if it has one, otherwise its first
/// short name. Returns `None` when the flag has neither.
fn get_name_from_short_and_long(short: &[char], long: &[String]) -> Option<String> {
    match (long.first(), short.first()) {
        (Some(long_name), _) => Some(long_name.clone()),
        (None, Some(short_name)) => Some(short_name.to_string()),
        (None, None) => None,
    }
}
380
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Spec;
    use insta::assert_snapshot;

    /// Parses a flag usage string, panicking if it is invalid.
    fn parse_flag(input: &str) -> SpecFlag {
        input.parse().unwrap()
    }

    /// Returns the `env` setting of the flag named `name`, or `None` if the flag is missing
    /// or has no `env`.
    fn env_of(spec: &Spec, name: &str) -> Option<String> {
        spec.cmd
            .flags
            .iter()
            .find(|f| f.name == name)
            .and_then(|f| f.env.clone())
    }

    #[test]
    fn from_str() {
        assert_snapshot!(parse_flag("-f"), @"-f");
        assert_snapshot!(parse_flag("--flag"), @"--flag");
        assert_snapshot!(parse_flag("-f --flag"), @"-f --flag");
        assert_snapshot!(parse_flag("-f --flag…"), @"-f --flag…");
        assert_snapshot!(parse_flag("-f --flag …"), @"-f --flag…");
        assert_snapshot!(parse_flag("--flag <arg>"), @"--flag <arg>");
        assert_snapshot!(parse_flag("-f --flag <arg>"), @"-f --flag <arg>");
        assert_snapshot!(parse_flag("-f --flag… <arg>"), @"-f --flag… <arg>");
        assert_snapshot!(parse_flag("-f --flag <arg>…"), @"-f --flag <arg>…");
        assert_snapshot!(parse_flag("myflag: -f"), @"myflag: -f");
        assert_snapshot!(parse_flag("myflag: -f --flag <arg>"), @"myflag: -f --flag <arg>");
    }

    #[test]
    fn test_flag_with_env() {
        let input = r#"
flag "--color" env="MYCLI_COLOR" help="Enable color output"
flag "--verbose" env="MYCLI_VERBOSE"
            "#;
        let spec = Spec::parse(&Default::default(), input).unwrap();

        assert_snapshot!(spec, @r#"
        flag --color help="Enable color output" env=MYCLI_COLOR
        flag --verbose env=MYCLI_VERBOSE
        "#);

        assert_eq!(env_of(&spec, "color"), Some("MYCLI_COLOR".to_string()));
        assert_eq!(env_of(&spec, "verbose"), Some("MYCLI_VERBOSE".to_string()));
    }

    #[test]
    fn test_flag_with_env_child_node() {
        let input = r#"
flag "--color" help="Enable color output" {
    env "MYCLI_COLOR"
}
flag "--verbose" {
    env "MYCLI_VERBOSE"
}
            "#;
        let spec = Spec::parse(&Default::default(), input).unwrap();

        assert_snapshot!(spec, @r#"
        flag --color help="Enable color output" env=MYCLI_COLOR
        flag --verbose env=MYCLI_VERBOSE
        "#);

        assert_eq!(env_of(&spec, "color"), Some("MYCLI_COLOR".to_string()));
        assert_eq!(env_of(&spec, "verbose"), Some("MYCLI_VERBOSE".to_string()));
    }
}