script_wizard/
ask.rs

1use std::process::Command;
2
3use chrono::{NaiveDate, Weekday};
4use clap::ValueEnum;
5use inquire::{
6    autocompletion::Replacement, error::CustomUserError, Confirm, DateSelect, Editor, InquireError,
7    MultiSelect, Select, Text,
8};
9
10#[derive(Clone, ValueEnum)]
11pub enum Confirmation {
12    Yes,
13    No,
14}
15
16fn read_json_array(json: &str) -> Result<Vec<String>, CustomUserError> {
17    let a: Vec<String> = serde_json::from_str(json).expect("invalid json array");
18    Ok(a)
19}
20
/// Autocompleter used by [`ask_prompt`]: offers entries taken from a JSON array of strings.
#[derive(Clone, Default)]
pub struct AskAutoCompleter {
    /// JSON array of candidate strings; parsed whenever suggestions are requested.
    suggestions_json: String,
    /// Suggestions computed for the most recent input.
    suggestions: Vec<String>,
    /// Prompt text seen on the previous call, used to detect that the input changed.
    input: String,
    /// Position in `suggestions` used when cycling completions without a highlighted entry.
    suggestion_index: usize,
}
28
29impl AskAutoCompleter {
30    fn update_input(&mut self, input: &str) -> Result<(), CustomUserError> {
31        if input == self.input {
32            // No change:
33            return Ok(());
34        }
35        self.input = input.to_string();
36        self.suggestion_index = 0;
37        Ok(())
38    }
39}
40
41impl inquire::Autocomplete for AskAutoCompleter {
42    fn get_suggestions(&mut self, input: &str) -> Result<Vec<String>, CustomUserError> {
43        self.update_input(input)?;
44        self.suggestions = read_json_array(&self.suggestions_json)
45            .expect("Couldn't parse suggestions")
46            .iter()
47            .filter(|s| s.to_lowercase().contains(&input.to_lowercase()))
48            .map(|s| String::from(s.clone()))
49            .collect();
50        Ok(self.suggestions.clone())
51    }
52
53    fn get_completion(
54        &mut self,
55        input: &str,
56        highlighted_suggestion: Option<String>,
57    ) -> Result<Replacement, CustomUserError> {
58        self.update_input(input)?;
59        match highlighted_suggestion {
60            Some(suggestion) => Ok(Replacement::Some(suggestion)),
61            None => {
62                if self.suggestions.len() > 0 {
63                    self.suggestion_index = (self.suggestion_index + 1) % self.suggestions.len();
64                    Ok(Replacement::Some(
65                        self.suggestions
66                            .get(self.suggestion_index)
67                            .unwrap()
68                            .to_string(),
69                    ))
70                } else {
71                    Ok(Replacement::None)
72                }
73            }
74        }
75    }
76}
77
78pub fn ask_prompt(
79    question: &str,
80    default: &str,
81    allow_blank: bool,
82    suggestions_json: &str,
83) -> String {
84    if question == "" {
85        panic!("Blank question")
86    }
87    let mut auto_completer = AskAutoCompleter::default();
88    auto_completer.suggestions_json = suggestions_json.to_string();
89    match allow_blank {
90        true => {
91            let r: Result<String, InquireError>;
92            match default {
93                "" => {
94                    r = Text::new(question)
95                        .with_autocomplete(auto_completer.clone())
96                        .prompt();
97                }
98                _ => {
99                    r = Text::new(question)
100                        .with_autocomplete(auto_completer.clone())
101                        .with_default(default)
102                        .prompt();
103                }
104            }
105            if r.is_err() {
106                std::process::exit(1);
107            }
108            r.unwrap()
109        }
110        false => {
111            let mut a = String::from("");
112            while a == "" {
113                let r: Result<String, InquireError>;
114                match default {
115                    "" => {
116                        r = Text::new(question)
117                            .with_autocomplete(auto_completer.clone())
118                            .prompt();
119                    }
120                    _ => {
121                        r = Text::new(question)
122                            .with_default(default)
123                            .with_autocomplete(auto_completer.clone())
124                            .prompt();
125                    }
126                }
127                if r.is_err() {
128                    std::process::exit(1);
129                }
130                a = r.unwrap();
131            }
132            a
133        }
134    }
135}
136
137#[macro_export]
138macro_rules! ask {
139    ($question: expr, $default: expr, $allow_blank: expr, $suggestions_json: expr) => {
140        ask::ask_prompt($question, $default, $allow_blank, $suggestions_json)
141    };
142    ($question: expr, $default: expr, $allow_blank: expr) => {
143        ask::ask_prompt($question, $default, $allow_blank, "")
144    };
145    ($question: expr, $default: expr) => {
146        ask::ask_prompt($question, $default, false, "")
147    };
148    ($question: expr) => {
149        ask::ask_prompt($question, "", false, "")
150    };
151}
152pub use ask;
153
154pub fn confirm(question: &str, default_answer: Option<Confirmation>, cancel_code: u8) -> bool {
155    let mut c = Confirm::new(question);
156    match default_answer {
157        Some(Confirmation::Yes) => c = c.with_default(true),
158        Some(Confirmation::No) => c = c.with_default(false),
159        _ => (),
160    }
161    match c.prompt() {
162        Ok(true) => true,
163        Ok(false) => false,
164        Err(_) => std::process::exit(cancel_code.into()),
165    }
166}
167
168pub fn choose(
169    question: &str,
170    default: &str,
171    options: Vec<&str>,
172    numeric: &bool,
173    cancel_code: u8,
174) -> String {
175    let default_index: usize;
176    match default.trim().parse::<usize>() {
177        Ok(n) => {
178            default_index = n;
179        }
180        Err(_) => {
181            default_index = options.iter().position(|&r| r == default).unwrap_or(0);
182        }
183    }
184    let ans: Result<&str, InquireError> = Select::new(question, options.clone())
185        .with_starting_cursor(default_index)
186        .with_help_message("↑↓ to move, enter to select, type to filter, ESC to cancel")
187        .prompt();
188    match ans {
189        Ok(selection) => match numeric {
190            true => {
191                let index = options.iter().position(|&r| r == selection).unwrap();
192                format!("{}", index)
193            }
194            false => String::from(selection),
195        },
196        Err(_) => std::process::exit(cancel_code.into()),
197    }
198}
199
200pub fn select(question: &str, default: &str, options: Vec<&str>, cancel_code: u8) -> Vec<String> {
201    let defaults: Vec<&str> = serde_json::from_str(default).unwrap_or(vec![]);
202    let mut default_indices = vec![];
203    for (index, item) in options.iter().enumerate() {
204        match defaults.iter().find(|&r| r == item) {
205            Some(_) => default_indices.append(&mut vec![index]),
206            None => {}
207        };
208    }
209    let ans = MultiSelect::new(question, options)
210        .with_default(&default_indices)
211        .with_help_message("↑↓ to move, space to select one, → to all, ← to none, type to filter, ESC to cancel")
212        .prompt();
213    match ans {
214        Ok(selection) => selection.iter().map(|&x| x.into()).collect(),
215        Err(_) => std::process::exit(cancel_code.into()),
216    }
217}
218
219pub fn date(
220    question: &str,
221    default: &str,
222    min_date: &str,
223    max_date: &str,
224    starting_date: &str,
225    week_start: Weekday,
226    help_message: &str,
227    date_format: &str,
228) -> String {
229    let date = DateSelect::new(question)
230        .with_starting_date(
231            NaiveDate::parse_from_str(default, date_format)
232                .unwrap_or(chrono::Local::now().naive_local().into()),
233        )
234        .with_min_date(NaiveDate::parse_from_str(min_date, date_format).unwrap_or(NaiveDate::MIN))
235        .with_max_date(NaiveDate::parse_from_str(max_date, date_format).unwrap_or(NaiveDate::MAX))
236        .with_starting_date(
237            NaiveDate::parse_from_str(starting_date, date_format).unwrap_or(
238                NaiveDate::parse_from_str(min_date, date_format).unwrap_or(NaiveDate::MIN),
239            ),
240        )
241        .with_week_start(week_start)
242        .with_help_message(help_message)
243        .prompt()
244        .unwrap();
245    return date.format(date_format).to_string();
246}
247
248pub fn editor(message: &str, default: &str, help_message: &str, file_extension: &str) -> String {
249    let text = Editor::new(message)
250        .with_predefined_text(default)
251        .with_help_message(help_message)
252        .with_file_extension(file_extension)
253        .prompt()
254        .unwrap();
255    return text;
256}
257
258pub fn menu(
259    heading: &str,
260    entries: &Vec<String>,
261    default: &Option<String>,
262    once: &bool,
263    cancel_code: u8,
264) -> Result<usize, u8> {
265    let mut new_default: String = default.clone().unwrap_or("".to_string());
266    loop {
267        eprintln!("");
268        let titles: Vec<&str> = entries
269            .iter()
270            .map(|e| e.split(" = ").collect::<Vec<&str>>()[0])
271            .collect();
272        let commands: Vec<&str> = entries
273            .iter()
274            .map(|e| e.split(" = ").collect::<Vec<&str>>()[1])
275            .collect();
276        let command_index = choose(heading, new_default.as_str(), titles, &true, cancel_code)
277            .parse::<usize>()
278            .unwrap_or(1);
279
280        new_default = command_index.to_string();
281
282        // Run the command:
283        let cmd = commands[command_index];
284        let status = Command::new("/bin/bash")
285            .args(["-c", cmd])
286            .status()
287            .unwrap();
288
289        match status.code().unwrap_or(1) {
290            0 => {
291                //Keep looping unless --once is given:
292                if *once {
293                    return Ok(0);
294                }
295            }
296            2 => {
297                // Ok(2) signals to quit the loop:
298                return Ok(2);
299            }
300            _ => {
301                return Err(1);
302            }
303        }
304    }
305}