usage/spec/
flag.rs

1use itertools::Itertools;
2use kdl::{KdlDocument, KdlEntry, KdlNode};
3use serde::Serialize;
4use std::fmt::Display;
5use std::hash::Hash;
6use std::str::FromStr;
7
8use crate::error::UsageErr::InvalidFlag;
9use crate::error::{Result, UsageErr};
10use crate::spec::context::ParsingContext;
11use crate::spec::helpers::NodeHelper;
12use crate::spec::is_false;
13use crate::{string, SpecArg, SpecChoices};
14
/// A single command-line flag (option) in a usage spec.
///
/// Instances are produced by [`SpecFlag::parse`] (from a `flag` KDL node), by the
/// [`FromStr`] impl (from a usage string such as `"-f --force <arg>"`), or, with the
/// `clap` feature, from a `clap::Arg`.
///
/// Equality and hashing consider only [`SpecFlag::name`].
#[derive(Debug, Default, Clone, Serialize)]
pub struct SpecFlag {
    /// Identifier of the flag: taken from a `name:` prefix in the usage string when present,
    /// otherwise the first long name or, failing that, the first short name.
    pub name: String,
    /// Rendered usage string, as produced by [`SpecFlag::usage`].
    pub usage: String,
    /// Help text (`help` property or child node).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub help: Option<String>,
    /// Long help text (`long_help` / `help_long` property or child node).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub help_long: Option<String>,
    /// Help text from the `help_md` property or child node (presumably markdown — confirm
    /// against the renderer).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub help_md: Option<String>,
    /// First line of `help`, computed with `string::first_line`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub help_first_line: Option<String>,
    /// Short names, stored without the leading `-`.
    pub short: Vec<char>,
    /// Long names, stored without the leading `--`.
    pub long: Vec<String>,
    /// Whether the flag is marked required. KDL parsing may reset this to `false` when a
    /// default is supplied.
    #[serde(skip_serializing_if = "is_false")]
    pub required: bool,
    /// Deprecation message. `deprecated=#true` is stored as the literal string `"deprecated"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub deprecated: Option<String>,
    /// Set by the `var` key or by a `…` / `...` that appears before any `<arg>` in the usage
    /// string; rendered as a `…` suffix after the flag names.
    #[serde(skip_serializing_if = "is_false")]
    pub var: bool,
    /// Value of the `hide` key (or clap's "hide" setting).
    pub hide: bool,
    /// Value of the `global` key (or clap's "global" setting).
    pub global: bool,
    /// Value of the `count` key; `true` for clap's `Count` action.
    #[serde(skip_serializing_if = "is_false")]
    pub count: bool,
    /// Value placeholder of the flag, from `<arg>` / `[arg]` in the usage string or an `arg`
    /// child node. `None` when the flag takes no value.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arg: Option<SpecArg>,
    /// Default values as strings; boolean defaults are stored as `"true"` / `"false"`.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub default: Vec<String>,
    /// Value of the `negate` property. NOTE(review): presumably the spelling of the negating
    /// flag (e.g. `--no-color`); its consumer is not visible in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub negate: Option<String>,
    /// Value of the `env` property or child node (e.g. `MYCLI_COLOR` in the tests);
    /// presumably an environment variable name — how it is consumed is not shown here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub env: Option<String>,
}
48
49impl SpecFlag {
50    pub(crate) fn parse(ctx: &ParsingContext, node: &NodeHelper) -> Result<Self> {
51        let mut flag: Self = node.arg(0)?.ensure_string()?.parse()?;
52        for (k, v) in node.props() {
53            match k {
54                "help" => flag.help = Some(v.ensure_string()?),
55                "long_help" => flag.help_long = Some(v.ensure_string()?),
56                "help_long" => flag.help_long = Some(v.ensure_string()?),
57                "help_md" => flag.help_md = Some(v.ensure_string()?),
58                "required" => flag.required = v.ensure_bool()?,
59                "var" => flag.var = v.ensure_bool()?,
60                "hide" => flag.hide = v.ensure_bool()?,
61                "deprecated" => {
62                    flag.deprecated = match v.value.as_bool() {
63                        Some(true) => Some("deprecated".into()),
64                        Some(false) => None,
65                        None => Some(v.ensure_string()?),
66                    }
67                }
68                "global" => flag.global = v.ensure_bool()?,
69                "count" => flag.count = v.ensure_bool()?,
70                "default" => {
71                    // Support both string and boolean defaults
72                    let default_value = match v.value.as_bool() {
73                        Some(b) => b.to_string(),
74                        None => v.ensure_string()?,
75                    };
76                    flag.default = vec![default_value];
77                }
78                "negate" => flag.negate = v.ensure_string().map(Some)?,
79                "env" => flag.env = v.ensure_string().map(Some)?,
80                k => bail_parse!(ctx, v.entry.span(), "unsupported flag key {k}"),
81            }
82        }
83        if !flag.default.is_empty() {
84            flag.required = false;
85        }
86        for child in node.children() {
87            match child.name() {
88                "arg" => flag.arg = Some(SpecArg::parse(ctx, &child)?),
89                "help" => flag.help = Some(child.arg(0)?.ensure_string()?),
90                "long_help" => flag.help_long = Some(child.arg(0)?.ensure_string()?),
91                "help_long" => flag.help_long = Some(child.arg(0)?.ensure_string()?),
92                "help_md" => flag.help_md = Some(child.arg(0)?.ensure_string()?),
93                "required" => flag.required = child.arg(0)?.ensure_bool()?,
94                "var" => flag.var = child.arg(0)?.ensure_bool()?,
95                "hide" => flag.hide = child.arg(0)?.ensure_bool()?,
96                "deprecated" => {
97                    flag.deprecated = match child.arg(0)?.ensure_bool() {
98                        Ok(true) => Some("deprecated".into()),
99                        Ok(false) => None,
100                        _ => Some(child.arg(0)?.ensure_string()?),
101                    }
102                }
103                "global" => flag.global = child.arg(0)?.ensure_bool()?,
104                "count" => flag.count = child.arg(0)?.ensure_bool()?,
105                "default" => {
106                    // Support both single value and multiple values
107                    // default "bar"            -> vec!["bar"]
108                    // default #true            -> vec!["true"]
109                    // default { "xyz"; "bar" } -> vec!["xyz", "bar"]
110                    let children = child.children();
111                    if children.is_empty() {
112                        // Single value: default "bar" or default #true
113                        let arg = child.arg(0)?;
114                        let default_value = match arg.value.as_bool() {
115                            Some(b) => b.to_string(),
116                            None => arg.ensure_string()?,
117                        };
118                        flag.default = vec![default_value];
119                    } else {
120                        // Multiple values from children: default { "xyz"; "bar" }
121                        // In KDL, these are child nodes where the string is the node name
122                        flag.default = children.iter().map(|c| c.name().to_string()).collect();
123                    }
124                }
125                "env" => flag.env = child.arg(0)?.ensure_string().map(Some)?,
126                "choices" => {
127                    if let Some(arg) = &mut flag.arg {
128                        arg.choices = Some(SpecChoices::parse(ctx, &child)?);
129                    } else {
130                        bail_parse!(
131                            ctx,
132                            child.node.name().span(),
133                            "flag must have value to have choices"
134                        )
135                    }
136                }
137                k => bail_parse!(ctx, child.node.name().span(), "unsupported flag child {k}"),
138            }
139        }
140        flag.usage = flag.usage();
141        flag.help_first_line = flag.help.as_ref().map(|s| string::first_line(s));
142        Ok(flag)
143    }
144    pub fn usage(&self) -> String {
145        let mut parts = vec![];
146        let name = get_name_from_short_and_long(&self.short, &self.long).unwrap_or_default();
147        if name != self.name {
148            parts.push(format!("{}:", self.name));
149        }
150        if let Some(short) = self.short.first() {
151            parts.push(format!("-{short}"));
152        }
153        if let Some(long) = self.long.first() {
154            parts.push(format!("--{long}"));
155        }
156        let mut out = parts.join(" ");
157        if self.var {
158            out = format!("{out}…");
159        }
160        if let Some(arg) = &self.arg {
161            out = format!("{} {}", out, arg.usage());
162        }
163        out
164    }
165}
166
167impl From<&SpecFlag> for KdlNode {
168    fn from(flag: &SpecFlag) -> KdlNode {
169        let mut node = KdlNode::new("flag");
170        let name = flag
171            .short
172            .iter()
173            .map(|c| format!("-{c}"))
174            .chain(flag.long.iter().map(|s| format!("--{s}")))
175            .collect_vec()
176            .join(" ");
177        node.push(KdlEntry::new(name));
178        if let Some(desc) = &flag.help {
179            node.push(KdlEntry::new_prop("help", desc.clone()));
180        }
181        if let Some(desc) = &flag.help_long {
182            let children = node.children_mut().get_or_insert_with(KdlDocument::new);
183            let mut node = KdlNode::new("long_help");
184            node.entries_mut().push(KdlEntry::new(desc.clone()));
185            children.nodes_mut().push(node);
186        }
187        if let Some(desc) = &flag.help_md {
188            let children = node.children_mut().get_or_insert_with(KdlDocument::new);
189            let mut node = KdlNode::new("help_md");
190            node.entries_mut().push(KdlEntry::new(desc.clone()));
191            children.nodes_mut().push(node);
192        }
193        if flag.required {
194            node.push(KdlEntry::new_prop("required", true));
195        }
196        if flag.var {
197            node.push(KdlEntry::new_prop("var", true));
198        }
199        if flag.hide {
200            node.push(KdlEntry::new_prop("hide", true));
201        }
202        if flag.global {
203            node.push(KdlEntry::new_prop("global", true));
204        }
205        if flag.count {
206            node.push(KdlEntry::new_prop("count", true));
207        }
208        if let Some(negate) = &flag.negate {
209            node.push(KdlEntry::new_prop("negate", negate.clone()));
210        }
211        if let Some(env) = &flag.env {
212            node.push(KdlEntry::new_prop("env", env.clone()));
213        }
214        if let Some(deprecated) = &flag.deprecated {
215            node.push(KdlEntry::new_prop("deprecated", deprecated.clone()));
216        }
217        // Serialize default values
218        if !flag.default.is_empty() {
219            if flag.default.len() == 1 {
220                // Single value: use property default="bar"
221                node.push(KdlEntry::new_prop("default", flag.default[0].clone()));
222            } else {
223                // Multiple values: use child node default { "xyz"; "bar" }
224                let children = node.children_mut().get_or_insert_with(KdlDocument::new);
225                let mut default_node = KdlNode::new("default");
226                let default_children = default_node
227                    .children_mut()
228                    .get_or_insert_with(KdlDocument::new);
229                for val in &flag.default {
230                    default_children
231                        .nodes_mut()
232                        .push(KdlNode::new(val.as_str()));
233                }
234                children.nodes_mut().push(default_node);
235            }
236        }
237        if let Some(arg) = &flag.arg {
238            let children = node.children_mut().get_or_insert_with(KdlDocument::new);
239            children.nodes_mut().push(arg.into());
240        }
241        node
242    }
243}
244
245impl FromStr for SpecFlag {
246    type Err = UsageErr;
247    fn from_str(input: &str) -> Result<Self> {
248        let mut flag = Self::default();
249        let input = input.replace("...", "…").replace("…", " … ");
250        for part in input.split_whitespace() {
251            if let Some(name) = part.strip_suffix(':') {
252                flag.name = name.to_string();
253            } else if let Some(long) = part.strip_prefix("--") {
254                flag.long.push(long.to_string());
255            } else if let Some(short) = part.strip_prefix('-') {
256                if short.len() != 1 {
257                    return Err(InvalidFlag(
258                        short.to_string(),
259                        (0, input.len()).into(),
260                        input.to_string(),
261                    ));
262                }
263                flag.short.push(short.chars().next().unwrap());
264            } else if part == "…" {
265                if let Some(arg) = &mut flag.arg {
266                    arg.var = true;
267                } else {
268                    flag.var = true;
269                }
270            } else if part.starts_with('<') && part.ends_with('>')
271                || part.starts_with('[') && part.ends_with(']')
272            {
273                flag.arg = Some(part.to_string().parse()?);
274            } else {
275                return Err(InvalidFlag(
276                    part.to_string(),
277                    (0, input.len()).into(),
278                    input.to_string(),
279                ));
280            }
281        }
282        if flag.name.is_empty() {
283            flag.name = get_name_from_short_and_long(&flag.short, &flag.long).unwrap_or_default();
284        }
285        flag.usage = flag.usage();
286        Ok(flag)
287    }
288}
289
#[cfg(feature = "clap")]
impl From<&clap::Arg> for SpecFlag {
    /// Converts a clap argument definition into a [`SpecFlag`].
    ///
    /// - `short` / `long` include visible aliases; `name` is derived from the first long
    ///   name or, failing that, the first short name.
    /// - `var` is set for the `Count` and `Append` actions, `count` only for `Count`.
    /// - An `arg` is created for the `Set` and `Append` actions, named after clap's value
    ///   names (joined by spaces) or the flag name when none are set. Possible values and
    ///   their aliases become the arg's choices.
    /// - `usage` is computed from the resulting flag, as the other constructors do.
    /// - `help_md`, `deprecated`, `negate` and `env` are not populated from clap.
    fn from(c: &clap::Arg) -> Self {
        let action = c.get_action();
        let help = c.get_help().map(|s| s.to_string());
        let help_first_line = help.as_ref().map(|s| string::first_line(s));
        let short = c.get_short_and_visible_aliases().unwrap_or_default();
        let long = c
            .get_long_and_visible_aliases()
            .unwrap_or_default()
            .into_iter()
            .map(|s| s.to_string())
            .collect::<Vec<_>>();
        let name = get_name_from_short_and_long(&short, &long).unwrap_or_default();
        let takes_value = matches!(action, clap::ArgAction::Set | clap::ArgAction::Append);
        let arg = if takes_value {
            let mut arg = SpecArg::from(
                c.get_value_names()
                    .map(|s| s.iter().map(|s| s.to_string()).join(" "))
                    // Lazily fall back to the flag name; avoids cloning when names exist.
                    .unwrap_or_else(|| name.clone())
                    .as_str(),
            );

            let choices = c
                .get_possible_values()
                .iter()
                .flat_map(|v| v.get_name_and_aliases().map(|s| s.to_string()))
                .collect::<Vec<_>>();
            if !choices.is_empty() {
                arg.choices = Some(SpecChoices { choices });
            }

            Some(arg)
        } else {
            None
        };
        let mut flag = Self {
            name,
            usage: String::new(),
            short,
            long,
            required: c.is_required_set(),
            help,
            help_long: c.get_long_help().map(|s| s.to_string()),
            help_md: None,
            help_first_line,
            var: matches!(action, clap::ArgAction::Count | clap::ArgAction::Append),
            hide: c.is_hide_set(),
            global: c.is_global_set(),
            arg,
            count: matches!(action, clap::ArgAction::Count),
            default: c
                .get_default_values()
                .iter()
                .map(|s| s.to_string_lossy().to_string())
                .collect(),
            deprecated: None,
            negate: None,
            env: None,
        };
        // Keep the cached usage string in sync with the fields, like `parse`/`from_str` do.
        flag.usage = flag.usage();
        flag
    }
}
358
359// #[cfg(feature = "clap")]
360// impl From<&SpecFlag> for clap::Arg {
361//     fn from(flag: &SpecFlag) -> Self {
362//         let mut a = clap::Arg::new(&flag.name);
363//         if let Some(desc) = &flag.help {
364//             a = a.help(desc);
365//         }
366//         if flag.required {
367//             a = a.required(true);
368//         }
369//         if let Some(arg) = &flag.arg {
370//             a = a.value_name(&arg.name);
371//             if arg.var {
372//                 a = a.action(clap::ArgAction::Append)
373//             } else {
374//                 a = a.action(clap::ArgAction::Set)
375//             }
376//         } else {
377//             a = a.action(clap::ArgAction::SetTrue)
378//         }
379//         // let mut a = clap::Arg::new(&flag.name)
380//         //     .required(flag.required)
381//         //     .action(clap::ArgAction::SetTrue);
382//         if let Some(short) = flag.short.first() {
383//             a = a.short(*short);
384//         }
385//         if let Some(long) = flag.long.first() {
386//             a = a.long(long);
387//         }
388//         for short in flag.short.iter().skip(1) {
389//             a = a.visible_short_alias(*short);
390//         }
391//         for long in flag.long.iter().skip(1) {
392//             a = a.visible_alias(long);
393//         }
394//         // cmd = cmd.arg(a);
395//         // if flag.multiple {
396//         //     a = a.multiple(true);
397//         // }
398//         // if flag.hide {
399//         //     a = a.hide_possible_values(true);
400//         // }
401//         a
402//     }
403// }
404
405impl Display for SpecFlag {
406    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
407        write!(f, "{}", self.usage())
408    }
409}
410impl PartialEq for SpecFlag {
411    fn eq(&self, other: &Self) -> bool {
412        self.name == other.name
413    }
414}
// Marker impl: equality is the name-only comparison defined by `PartialEq`.
impl Eq for SpecFlag {}
416impl Hash for SpecFlag {
417    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
418        self.name.hash(state);
419    }
420}
421
/// Picks a flag's default name: the first long name if there is one, otherwise the first
/// short name as a one-character string, otherwise `None`.
fn get_name_from_short_and_long(short: &[char], long: &[String]) -> Option<String> {
    match (long.first(), short.first()) {
        (Some(first_long), _) => Some(first_long.to_string()),
        (None, Some(first_short)) => Some(first_short.to_string()),
        (None, None) => None,
    }
}
427
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Spec;
    use insta::assert_snapshot;

    /// Usage strings round-trip through `FromStr` and `Display`, with `...`/`…` normalized.
    #[test]
    fn from_str() {
        assert_snapshot!("-f".parse::<SpecFlag>().unwrap(), @"-f");
        assert_snapshot!("--flag".parse::<SpecFlag>().unwrap(), @"--flag");
        assert_snapshot!("-f --flag".parse::<SpecFlag>().unwrap(), @"-f --flag");
        assert_snapshot!("-f --flag…".parse::<SpecFlag>().unwrap(), @"-f --flag…");
        assert_snapshot!("-f --flag …".parse::<SpecFlag>().unwrap(), @"-f --flag…");
        assert_snapshot!("--flag <arg>".parse::<SpecFlag>().unwrap(), @"--flag <arg>");
        assert_snapshot!("-f --flag <arg>".parse::<SpecFlag>().unwrap(), @"-f --flag <arg>");
        assert_snapshot!("-f --flag… <arg>".parse::<SpecFlag>().unwrap(), @"-f --flag… <arg>");
        assert_snapshot!("-f --flag <arg>…".parse::<SpecFlag>().unwrap(), @"-f --flag <arg>…");
        assert_snapshot!("myflag: -f".parse::<SpecFlag>().unwrap(), @"myflag: -f");
        assert_snapshot!("myflag: -f --flag <arg>".parse::<SpecFlag>().unwrap(), @"myflag: -f --flag <arg>");
    }

    /// `env` given as a property is parsed and serialized back as a property.
    #[test]
    fn test_flag_with_env() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--color" env="MYCLI_COLOR" help="Enable color output"
flag "--verbose" env="MYCLI_VERBOSE"
            "#,
        )
        .unwrap();

        assert_snapshot!(spec, @r#"
        flag --color help="Enable color output" env=MYCLI_COLOR
        flag --verbose env=MYCLI_VERBOSE
        "#);

        let color_flag = spec.cmd.flags.iter().find(|f| f.name == "color").unwrap();
        assert_eq!(color_flag.env, Some("MYCLI_COLOR".to_string()));

        let verbose_flag = spec.cmd.flags.iter().find(|f| f.name == "verbose").unwrap();
        assert_eq!(verbose_flag.env, Some("MYCLI_VERBOSE".to_string()));
    }

    /// `env` given as a child node is parsed and serialized as a property.
    #[test]
    fn test_flag_with_env_child_node() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--color" help="Enable color output" {
    env "MYCLI_COLOR"
}
flag "--verbose" {
    env "MYCLI_VERBOSE"
}
            "#,
        )
        .unwrap();

        assert_snapshot!(spec, @r#"
        flag --color help="Enable color output" env=MYCLI_COLOR
        flag --verbose env=MYCLI_VERBOSE
        "#);

        let color_flag = spec.cmd.flags.iter().find(|f| f.name == "color").unwrap();
        assert_eq!(color_flag.env, Some("MYCLI_COLOR".to_string()));

        let verbose_flag = spec.cmd.flags.iter().find(|f| f.name == "verbose").unwrap();
        assert_eq!(verbose_flag.env, Some("MYCLI_VERBOSE".to_string()));
    }

    /// Boolean and string `default` properties are both stored as strings.
    #[test]
    fn test_flag_with_boolean_defaults() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--color" default=#true
flag "--verbose" default=#false
flag "--debug" default="true"
flag "--quiet" default="false"
            "#,
        )
        .unwrap();

        let color_flag = spec.cmd.flags.iter().find(|f| f.name == "color").unwrap();
        assert_eq!(color_flag.default, vec!["true".to_string()]);

        let verbose_flag = spec.cmd.flags.iter().find(|f| f.name == "verbose").unwrap();
        assert_eq!(verbose_flag.default, vec!["false".to_string()]);

        let debug_flag = spec.cmd.flags.iter().find(|f| f.name == "debug").unwrap();
        assert_eq!(debug_flag.default, vec!["true".to_string()]);

        let quiet_flag = spec.cmd.flags.iter().find(|f| f.name == "quiet").unwrap();
        assert_eq!(quiet_flag.default, vec!["false".to_string()]);
    }

    /// Boolean defaults given as child nodes are stored as strings.
    #[test]
    fn test_flag_with_boolean_defaults_child_node() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--color" {
    default #true
}
flag "--verbose" {
    default #false
}
            "#,
        )
        .unwrap();

        let color_flag = spec.cmd.flags.iter().find(|f| f.name == "color").unwrap();
        assert_eq!(color_flag.default, vec!["true".to_string()]);

        let verbose_flag = spec.cmd.flags.iter().find(|f| f.name == "verbose").unwrap();
        assert_eq!(verbose_flag.default, vec!["false".to_string()]);
    }

    /// A `default` property on a variadic flag yields a one-element default list.
    #[test]
    fn test_flag_with_single_default() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--foo <foo>" var=#true default="bar"
            "#,
        )
        .unwrap();

        let flag = spec.cmd.flags.iter().find(|f| f.name == "foo").unwrap();
        assert!(flag.var);
        assert_eq!(flag.default, vec!["bar".to_string()]);
    }

    /// A `default { ... }` block yields one default per child, in order.
    #[test]
    fn test_flag_with_multiple_defaults_child_node() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--foo <foo>" var=#true {
    default {
        "xyz"
        "bar"
    }
}
            "#,
        )
        .unwrap();

        let flag = spec.cmd.flags.iter().find(|f| f.name == "foo").unwrap();
        assert!(flag.var);
        assert_eq!(flag.default, vec!["xyz".to_string(), "bar".to_string()]);
    }

    /// A `default "value"` child node yields a one-element default list.
    #[test]
    fn test_flag_with_single_default_child_node() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--foo <foo>" var=#true {
    default "bar"
}
            "#,
        )
        .unwrap();

        let flag = spec.cmd.flags.iter().find(|f| f.name == "foo").unwrap();
        assert!(flag.var);
        assert_eq!(flag.default, vec!["bar".to_string()]);
    }

    /// A single default is serialized in property form.
    #[test]
    fn test_flag_default_serialization_single() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--foo <foo>" default="bar"
            "#,
        )
        .unwrap();

        // When serialized, single default should use property format
        let output = spec.to_string();
        assert!(output.contains("default=bar") || output.contains(r#"default="bar""#));
    }

    /// Multiple defaults are serialized as a `default { ... }` child block.
    #[test]
    fn test_flag_default_serialization_multiple() {
        let spec = Spec::parse(
            &Default::default(),
            r#"
flag "--foo <foo>" var=#true {
    default {
        "xyz"
        "bar"
    }
}
            "#,
        )
        .unwrap();

        // When serialized, multiple defaults should use child node format
        let output = spec.to_string();
        // The output should contain a default block with children
        assert!(output.contains("default {"));
    }
}