zed_util/
shell.rs

1use std::{fmt, path::Path, sync::LazyLock};
2
/// The family of shell a program belongs to.
///
/// The kind decides how variables are rewritten and which command-line flags
/// are passed when running a command through the shell.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum ShellKind {
    /// POSIX-style shell, displayed as `sh`. This is the default kind.
    #[default]
    Posix,
    /// The `csh` shell.
    Csh,
    /// The `tcsh` shell.
    Tcsh,
    /// The `rc` shell.
    Rc,
    /// The `fish` shell.
    Fish,
    /// PowerShell (`powershell` / `pwsh`).
    PowerShell,
    /// Nushell (`nu`).
    Nushell,
    /// The Windows command interpreter (`cmd`).
    Cmd,
}
15
16pub fn get_system_shell() -> String {
17    if cfg!(windows) {
18        get_windows_system_shell()
19    } else {
20        std::env::var("SHELL").unwrap_or("/bin/sh".to_string())
21    }
22}
23
24pub fn get_default_system_shell() -> String {
25    if cfg!(windows) {
26        get_windows_system_shell()
27    } else {
28        "/bin/sh".to_string()
29    }
30}
31
/// Returns the shell to use on Windows, preferring an installed `pwsh.exe`.
///
/// The search runs once per process; the result is cached in a `LazyLock` and
/// cloned on each call. Locations are probed in this order and the first hit
/// wins:
///
/// 1. stable release under `ProgramFiles`
/// 2. stable release under the alternate program-files directory
/// 3. stable MSIX package
/// 4. preview release under `ProgramFiles`
/// 5. preview MSIX package
/// 6. preview release under the alternate program-files directory
/// 7. the Scoop shim
///
/// If none of them exists, the bare program name `powershell.exe` is returned.
pub fn get_windows_system_shell() -> String {
    use std::path::PathBuf;

    /// Finds the highest-versioned `pwsh.exe` inside `<program files>\PowerShell`.
    ///
    /// Version directories are named by a plain number (`7`), or `<number>-preview`
    /// when `find_preview` is set. `find_alternate` selects the other
    /// program-files directory for the current pointer width.
    fn find_pwsh_in_programfiles(find_alternate: bool, find_preview: bool) -> Option<PathBuf> {
        #[cfg(target_pointer_width = "64")]
        let env_var = if find_alternate {
            "ProgramFiles(x86)"
        } else {
            "ProgramFiles"
        };

        #[cfg(target_pointer_width = "32")]
        let env_var = if find_alternate {
            "ProgramW6432"
        } else {
            "ProgramFiles"
        };

        let powershell_root = PathBuf::from(std::env::var_os(env_var)?).join("PowerShell");
        let mut newest: Option<(u32, PathBuf)> = None;

        for entry in powershell_root.read_dir().ok()? {
            let Ok(entry) = entry else { continue };
            if !matches!(entry.file_type(), Ok(kind) if kind.is_dir()) {
                continue;
            }

            let os_name = entry.file_name();
            let name = os_name.to_string_lossy();
            let version_text: &str = if find_preview {
                // Only `<number>-preview` directories qualify; split at the first dash.
                match name.split_once('-') {
                    Some((number, "preview")) => number,
                    _ => continue,
                }
            } else {
                &name
            };
            let Ok(version) = version_text.parse::<u32>() else {
                continue;
            };

            let candidate = entry.path().join("pwsh.exe");
            if !candidate.exists() {
                continue;
            }

            // `>=` rather than `>`: when two directories parse to the same
            // version, the one encountered later is the one that is kept.
            let is_newest = match &newest {
                Some((best, _)) => version >= *best,
                None => true,
            };
            if is_newest {
                newest = Some((version, candidate));
            }
        }

        newest.map(|(_, path)| path)
    }

    /// Finds `pwsh.exe` inside the first matching PowerShell package directory
    /// under `%LOCALAPPDATA%\Microsoft\WindowsApps`.
    fn find_pwsh_in_msix(find_preview: bool) -> Option<PathBuf> {
        let local_app_data = std::env::var_os("LOCALAPPDATA")?;
        let windows_apps = PathBuf::from(local_app_data).join("Microsoft\\WindowsApps");
        if !windows_apps.exists() {
            return None;
        }

        let package_prefix = if find_preview {
            "Microsoft.PowerShellPreview_"
        } else {
            "Microsoft.PowerShell_"
        };

        for entry in windows_apps.read_dir().ok()? {
            let Ok(entry) = entry else { continue };

            let is_dir = matches!(entry.file_type(), Ok(kind) if kind.is_dir());
            if !is_dir {
                continue;
            }
            let is_package = entry
                .file_name()
                .to_string_lossy()
                .starts_with(package_prefix);
            if !is_package {
                continue;
            }

            let candidate = entry.path().join("pwsh.exe");
            if candidate.exists() {
                return Some(candidate);
            }
        }

        None
    }

    /// Returns the Scoop shim for `pwsh.exe` under `%USERPROFILE%`, if present.
    fn find_pwsh_in_scoop() -> Option<PathBuf> {
        let user_profile = std::env::var_os("USERPROFILE")?;
        let shim = PathBuf::from(user_profile).join("scoop\\shims\\pwsh.exe");
        if shim.exists() {
            Some(shim)
        } else {
            None
        }
    }

    static SYSTEM_SHELL: LazyLock<String> = LazyLock::new(|| {
        let located = find_pwsh_in_programfiles(false, false)
            .or_else(|| find_pwsh_in_programfiles(true, false))
            .or_else(|| find_pwsh_in_msix(false))
            .or_else(|| find_pwsh_in_programfiles(false, true))
            .or_else(|| find_pwsh_in_msix(true))
            .or_else(|| find_pwsh_in_programfiles(true, true))
            .or_else(find_pwsh_in_scoop);

        match located {
            Some(path) => path.to_string_lossy().into_owned(),
            None => "powershell.exe".to_string(),
        }
    });

    SYSTEM_SHELL.as_str().to_owned()
}
132
133impl fmt::Display for ShellKind {
134    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
135        match self {
136            ShellKind::Posix => write!(f, "sh"),
137            ShellKind::Csh => write!(f, "csh"),
138            ShellKind::Tcsh => write!(f, "tcsh"),
139            ShellKind::Fish => write!(f, "fish"),
140            ShellKind::PowerShell => write!(f, "powershell"),
141            ShellKind::Nushell => write!(f, "nu"),
142            ShellKind::Cmd => write!(f, "cmd"),
143            ShellKind::Rc => write!(f, "rc"),
144        }
145    }
146}
147
148impl ShellKind {
149    pub fn system() -> Self {
150        Self::new(&get_system_shell())
151    }
152
153    pub fn new(program: impl AsRef<Path>) -> Self {
154        let program = program.as_ref();
155        let Some(program) = program.file_name().and_then(|s| s.to_str()) else {
156            return if cfg!(windows) {
157                ShellKind::PowerShell
158            } else {
159                ShellKind::Posix
160            };
161        };
162        if program == "powershell"
163            || program.ends_with("powershell.exe")
164            || program == "pwsh"
165            || program.ends_with("pwsh.exe")
166        {
167            ShellKind::PowerShell
168        } else if program == "cmd" || program.ends_with("cmd.exe") {
169            ShellKind::Cmd
170        } else if program == "nu" {
171            ShellKind::Nushell
172        } else if program == "fish" {
173            ShellKind::Fish
174        } else if program == "csh" {
175            ShellKind::Csh
176        } else if program == "tcsh" {
177            ShellKind::Tcsh
178        } else if program == "rc" {
179            ShellKind::Rc
180        } else {
181            if cfg!(windows) {
182                ShellKind::PowerShell
183            } else {
184                // Some other shell detected, the user might install and use a
185                // unix-like shell.
186                ShellKind::Posix
187            }
188        }
189    }
190
191    pub fn to_shell_variable(self, input: &str) -> String {
192        match self {
193            Self::PowerShell => Self::to_powershell_variable(input),
194            Self::Cmd => Self::to_cmd_variable(input),
195            Self::Posix => input.to_owned(),
196            Self::Fish => input.to_owned(),
197            Self::Csh => input.to_owned(),
198            Self::Tcsh => input.to_owned(),
199            Self::Rc => input.to_owned(),
200            Self::Nushell => Self::to_nushell_variable(input),
201        }
202    }
203
204    fn to_cmd_variable(input: &str) -> String {
205        if let Some(var_str) = input.strip_prefix("${") {
206            if var_str.find(':').is_none() {
207                // If the input starts with "${", remove the trailing "}"
208                format!("%{}%", &var_str[..var_str.len() - 1])
209            } else {
210                // `${SOME_VAR:-SOME_DEFAULT}`, we currently do not handle this situation,
211                // which will result in the task failing to run in such cases.
212                input.into()
213            }
214        } else if let Some(var_str) = input.strip_prefix('$') {
215            // If the input starts with "$", directly append to "$env:"
216            format!("%{}%", var_str)
217        } else {
218            // If no prefix is found, return the input as is
219            input.into()
220        }
221    }
222    fn to_powershell_variable(input: &str) -> String {
223        if let Some(var_str) = input.strip_prefix("${") {
224            if var_str.find(':').is_none() {
225                // If the input starts with "${", remove the trailing "}"
226                format!("$env:{}", &var_str[..var_str.len() - 1])
227            } else {
228                // `${SOME_VAR:-SOME_DEFAULT}`, we currently do not handle this situation,
229                // which will result in the task failing to run in such cases.
230                input.into()
231            }
232        } else if let Some(var_str) = input.strip_prefix('$') {
233            // If the input starts with "$", directly append to "$env:"
234            format!("$env:{}", var_str)
235        } else {
236            // If no prefix is found, return the input as is
237            input.into()
238        }
239    }
240
241    fn to_nushell_variable(input: &str) -> String {
242        let mut result = String::new();
243        let mut source = input;
244        let mut is_start = true;
245
246        loop {
247            match source.chars().next() {
248                None => return result,
249                Some('$') => {
250                    source = Self::parse_nushell_var(&source[1..], &mut result, is_start);
251                    is_start = false;
252                }
253                Some(_) => {
254                    is_start = false;
255                    let chunk_end = source.find('$').unwrap_or(source.len());
256                    let (chunk, rest) = source.split_at(chunk_end);
257                    result.push_str(chunk);
258                    source = rest;
259                }
260            }
261        }
262    }
263
264    fn parse_nushell_var<'a>(source: &'a str, text: &mut String, is_start: bool) -> &'a str {
265        if source.starts_with("env.") {
266            text.push('$');
267            return source;
268        }
269
270        match source.chars().next() {
271            Some('{') => {
272                let source = &source[1..];
273                if let Some(end) = source.find('}') {
274                    let var_name = &source[..end];
275                    if !var_name.is_empty() {
276                        if !is_start {
277                            text.push_str("(");
278                        }
279                        text.push_str("$env.");
280                        text.push_str(var_name);
281                        if !is_start {
282                            text.push_str(")");
283                        }
284                        &source[end + 1..]
285                    } else {
286                        text.push_str("${}");
287                        &source[end + 1..]
288                    }
289                } else {
290                    text.push_str("${");
291                    source
292                }
293            }
294            Some(c) if c.is_alphabetic() || c == '_' => {
295                let end = source
296                    .find(|c: char| !c.is_alphanumeric() && c != '_')
297                    .unwrap_or(source.len());
298                let var_name = &source[..end];
299                if !is_start {
300                    text.push_str("(");
301                }
302                text.push_str("$env.");
303                text.push_str(var_name);
304                if !is_start {
305                    text.push_str(")");
306                }
307                &source[end..]
308            }
309            _ => {
310                text.push('$');
311                source
312            }
313        }
314    }
315
316    pub fn args_for_shell(&self, interactive: bool, combined_command: String) -> Vec<String> {
317        match self {
318            ShellKind::PowerShell => vec!["-C".to_owned(), combined_command],
319            ShellKind::Cmd => vec!["/C".to_owned(), combined_command],
320            ShellKind::Posix
321            | ShellKind::Nushell
322            | ShellKind::Fish
323            | ShellKind::Csh
324            | ShellKind::Tcsh
325            | ShellKind::Rc => interactive
326                .then(|| "-i".to_owned())
327                .into_iter()
328                .chain(["-c".to_owned(), combined_command])
329                .collect(),
330        }
331    }
332
333    pub fn command_prefix(&self) -> Option<char> {
334        match self {
335            ShellKind::PowerShell => Some('&'),
336            ShellKind::Nushell => Some('^'),
337            _ => None,
338        }
339    }
340}