1use anyhow::{Context, Result};
4use serde::{Deserialize, Deserializer, Serialize};
5use std::collections::HashMap;
6use std::fs;
7use std::path::Path;
8
9fn deserialize_args<'de, D>(deserializer: D) -> Result<Option<Vec<String>>, D::Error>
11where
12 D: Deserializer<'de>,
13{
14 use serde::de::{self, Visitor};
15 use std::fmt;
16
17 struct ArgsVisitor;
18
19 impl<'de> Visitor<'de> for ArgsVisitor {
20 type Value = Option<Vec<String>>;
21
22 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
23 formatter.write_str("a string or array of strings")
24 }
25
26 fn visit_none<E>(self) -> Result<Self::Value, E>
27 where
28 E: de::Error,
29 {
30 Ok(None)
31 }
32
33 fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
34 where
35 D: Deserializer<'de>,
36 {
37 deserializer.deserialize_any(ArgsValueVisitor).map(Some)
38 }
39 }
40
41 struct ArgsValueVisitor;
42
43 impl<'de> Visitor<'de> for ArgsValueVisitor {
44 type Value = Vec<String>;
45
46 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
47 formatter.write_str("a string or array of strings")
48 }
49
50 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
51 where
52 E: de::Error,
53 {
54 shell_words::split(value).map_err(|_| de::Error::custom("invalid shell arguments"))
56 }
57
58 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
59 where
60 A: de::SeqAccess<'de>,
61 {
62 let mut vec = Vec::new();
63 while let Some(element) = seq.next_element::<String>()? {
64 vec.push(element);
65 }
66 Ok(vec)
67 }
68 }
69
70 deserializer.deserialize_option(ArgsVisitor)
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ShimConfig {
76 pub path: String,
78
79 #[serde(deserialize_with = "deserialize_args", default)]
82 pub args: Option<Vec<String>>,
83
84 pub working_dir: Option<String>,
86
87 pub env: Option<HashMap<String, String>>,
89
90 pub hide_console: Option<bool>,
92
93 pub run_as_admin: Option<bool>,
95
96 pub signal_handling: Option<SignalHandling>,
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct SignalHandling {
103 pub ignore_sigint: Option<bool>,
105
106 pub forward_signals: Option<bool>,
108
109 pub kill_on_exit: Option<bool>,
111}
112
113impl Default for SignalHandling {
114 fn default() -> Self {
115 Self {
116 ignore_sigint: Some(true),
117 forward_signals: Some(true),
118 kill_on_exit: Some(true),
119 }
120 }
121}
122
123impl ShimConfig {
124 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
126 let content = fs::read_to_string(&path)
127 .with_context(|| format!("Failed to read shim file: {}", path.as_ref().display()))?;
128
129 Self::parse(&content)
130 }
131
132 pub fn parse(content: &str) -> Result<Self> {
134 toml::from_str::<ShimConfig>(content)
135 .with_context(|| "Failed to parse shim configuration as TOML")
136 }
137
138 pub fn resolved_path(&self) -> String {
140 expand_env_vars(&self.path)
141 }
142
143 pub fn resolved_args(&self) -> Vec<String> {
145 if let Some(ref args) = self.args {
146 args.iter().map(|arg| expand_env_vars(arg)).collect()
147 } else {
148 Vec::new()
149 }
150 }
151
152 pub fn resolved_working_dir(&self) -> Option<String> {
154 self.working_dir.as_ref().map(|dir| expand_env_vars(dir))
155 }
156
157 pub fn resolved_env(&self) -> HashMap<String, String> {
159 self.env
160 .as_ref()
161 .map(|env| {
162 env.iter()
163 .map(|(k, v)| (k.clone(), expand_env_vars(v)))
164 .collect()
165 })
166 .unwrap_or_default()
167 }
168}
169
/// Expands environment-variable references in `input`.
///
/// Supported forms:
///
/// * `${NAME}` — replaced by the value of `NAME`, or by nothing if it is unset. If the
///   closing `}` is missing, everything from the `${` onward is copied through unchanged.
/// * `$NAME` — `NAME` is the run of alphanumeric characters and `_` that follows the `$`.
///   If that whole run is not a set variable, the longest `_`-separated prefix that *is*
///   set is used instead, and the rest of the run is kept as literal text. For example,
///   with only `HOME` set, `$HOME_backup` becomes `<value of HOME>_backup`. If no prefix
///   is set either, the whole `$NAME` is replaced by nothing.
/// * Any other `$` (at the end of input, or followed by a character that cannot start a
///   name) is kept as a literal `$`.
///
/// Variables whose value is not valid Unicode are treated as unset.
///
/// The input is scanned once, left to right. Substituted values are copied verbatim and
/// never re-scanned, so a value that itself contains `$` is not expanded again. This is
/// also what keeps self-referential values (e.g. `A=$A`) from looping forever.
fn expand_env_vars(input: &str) -> String {
    let mut out = String::with_capacity(input.len());
    // Unprocessed tail of `input`; it always starts on a char boundary.
    let mut rest = input;

    while let Some(dollar) = rest.find('$') {
        out.push_str(&rest[..dollar]);
        // `$` is a single byte, so `dollar + 1` is a char boundary.
        let after = &rest[dollar + 1..];

        if let Some(braced) = after.strip_prefix('{') {
            match braced.find('}') {
                Some(close) => {
                    out.push_str(&std::env::var(&braced[..close]).unwrap_or_default());
                    rest = &braced[close + 1..];
                    continue;
                }
                None => {
                    // Unterminated `${`: leave the remainder untouched.
                    out.push_str(&rest[dollar..]);
                    return out;
                }
            }
        }

        // Measure the name in bytes (not chars) so the slices below stay on char
        // boundaries even for non-ASCII alphanumerics.
        let name_len: usize = after
            .chars()
            .take_while(|c| c.is_alphanumeric() || *c == '_')
            .map(char::len_utf8)
            .sum();

        if name_len == 0 {
            // Not a variable reference: keep the `$` literally.
            out.push('$');
            rest = after;
            continue;
        }

        let name = &after[..name_len];
        // Candidates are the full name, then each prefix ending just before a `_`,
        // longest first. If none is set, the whole name is consumed and expands to "".
        let used = std::iter::once(name.len())
            .chain(name.rmatch_indices('_').map(|(idx, _)| idx))
            .find(|&end| std::env::var(&name[..end]).is_ok())
            .unwrap_or(name.len());

        out.push_str(&std::env::var(&name[..used]).unwrap_or_default());
        rest = &after[used..];
    }

    out.push_str(rest);
    out
}
243
#[cfg(test)]
mod tests {
    use super::*;

    // Each test that touches the process environment uses its own crate-specific variable
    // name, so tests running in parallel cannot race on it. It also avoids clobbering a
    // generically named variable (like `TEST_VAR`) that might exist in the developer's
    // environment.

    #[test]
    fn test_parse_toml_format() {
        let content = r#"
path = "/usr/bin/git"
args = ["status", "-u"]

[signal_handling]
ignore_sigint = true
kill_on_exit = true
"#;

        let config = ShimConfig::parse(content).unwrap();
        assert_eq!(config.path, "/usr/bin/git");
        assert_eq!(
            config.args,
            Some(vec!["status".to_string(), "-u".to_string()])
        );

        let signal_handling = config.signal_handling.unwrap();
        assert_eq!(signal_handling.ignore_sigint, Some(true));
        assert_eq!(signal_handling.kill_on_exit, Some(true));
        // Keys missing from the table are `None`, not the `Default` values.
        assert_eq!(signal_handling.forward_signals, None);
    }

    #[test]
    fn test_parse_string_args_are_shell_split() {
        let content = r#"
path = "git"
args = "commit -m 'hello world'"
"#;

        let config = ShimConfig::parse(content).unwrap();
        assert_eq!(
            config.args,
            Some(vec![
                "commit".to_string(),
                "-m".to_string(),
                "hello world".to_string()
            ])
        );
    }

    #[test]
    fn test_parse_minimal_config() {
        let config = ShimConfig::parse("path = \"git\"\n").unwrap();
        assert_eq!(config.path, "git");
        assert!(config.args.is_none());
        assert!(config.working_dir.is_none());
        assert!(config.env.is_none());
        assert!(config.signal_handling.is_none());
        assert!(config.resolved_args().is_empty());
        assert!(config.resolved_env().is_empty());
    }

    #[test]
    fn test_expand_env_vars() {
        std::env::set_var("SHIM_CONFIG_TEST_EXPAND_VAR", "test_value");

        assert_eq!(expand_env_vars("$SHIM_CONFIG_TEST_EXPAND_VAR"), "test_value");
        assert_eq!(expand_env_vars("${SHIM_CONFIG_TEST_EXPAND_VAR}"), "test_value");
        assert_eq!(
            expand_env_vars("prefix_$SHIM_CONFIG_TEST_EXPAND_VAR_suffix"),
            "prefix_test_value_suffix"
        );
        assert_eq!(
            expand_env_vars("${SHIM_CONFIG_TEST_EXPAND_VAR}/path"),
            "test_value/path"
        );

        std::env::remove_var("SHIM_CONFIG_TEST_EXPAND_VAR");
    }

    #[test]
    fn test_resolved_fields_expand_env_vars() {
        std::env::set_var("SHIM_CONFIG_TEST_RESOLVE_ROOT", "/opt/tool");

        let content = r#"
path = "${SHIM_CONFIG_TEST_RESOLVE_ROOT}/bin/tool"
args = ["--root", "$SHIM_CONFIG_TEST_RESOLVE_ROOT"]
working_dir = "$SHIM_CONFIG_TEST_RESOLVE_ROOT/work"

[env]
TOOL_HOME = "${SHIM_CONFIG_TEST_RESOLVE_ROOT}"
"#;

        let config = ShimConfig::parse(content).unwrap();
        assert_eq!(config.resolved_path(), "/opt/tool/bin/tool");
        assert_eq!(
            config.resolved_args(),
            vec!["--root".to_string(), "/opt/tool".to_string()]
        );
        assert_eq!(
            config.resolved_working_dir(),
            Some("/opt/tool/work".to_string())
        );
        assert_eq!(
            config.resolved_env().get("TOOL_HOME").map(String::as_str),
            Some("/opt/tool")
        );

        std::env::remove_var("SHIM_CONFIG_TEST_RESOLVE_ROOT");
    }
}