Skip to main content

safe_chains/registry/
dispatch.rs

1use crate::parse::Token;
2use crate::verdict::{SafetyLevel, Verdict};
3
4use super::policy::check_owned;
5use super::types::*;
6use super::{CMD_HANDLERS, SUB_HANDLERS};
7
8fn has_flag_owned(tokens: &[Token], short: Option<&str>, long: &str) -> bool {
9    tokens[1..].iter().any(|t| {
10        t == long
11            || short.is_some_and(|s| t == s)
12            || t.as_str().starts_with(&format!("{long}="))
13    })
14}
15
16fn dispatch_first_arg(tokens: &[Token], patterns: &[String], level: SafetyLevel) -> Verdict {
17    if tokens.len() == 2 && (tokens[1] == "--help" || tokens[1] == "-h") {
18        return Verdict::Allowed(SafetyLevel::Inert);
19    }
20    let Some(arg) = tokens.get(1) else {
21        return Verdict::Denied;
22    };
23    let arg_str = arg.as_str();
24    let matches = patterns.iter().any(|p| {
25        if let Some(prefix) = p.strip_suffix('*') {
26            arg_str.starts_with(prefix)
27        } else {
28            arg_str == p
29        }
30    });
31    if matches { Verdict::Allowed(level) } else { Verdict::Denied }
32}
33
34fn dispatch_require_any(
35    tokens: &[Token],
36    require_any: &[String],
37    policy: &OwnedPolicy,
38    level: SafetyLevel,
39) -> Verdict {
40    if tokens.len() == 2 && (tokens[1] == "--help" || tokens[1] == "-h") {
41        return Verdict::Allowed(SafetyLevel::Inert);
42    }
43    let has_required = tokens[1..].iter().any(|t| {
44        require_any.iter().any(|r| {
45            t == r.as_str() || t.as_str().starts_with(&format!("{r}="))
46        })
47    });
48    if has_required && check_owned(tokens, policy) {
49        Verdict::Allowed(level)
50    } else {
51        Verdict::Denied
52    }
53}
54
55fn dispatch_sub(tokens: &[Token], sub: &SubSpec) -> Verdict {
56    match &sub.kind {
57        SubKind::Policy { policy, level } => {
58            if check_owned(tokens, policy) {
59                Verdict::Allowed(*level)
60            } else {
61                Verdict::Denied
62            }
63        }
64        SubKind::Guarded {
65            guard_long,
66            guard_short,
67            policy,
68            level,
69        } => {
70            if tokens.len() == 2 && matches!(tokens[1].as_str(), "--help" | "-h" | "help") {
71                return Verdict::Allowed(SafetyLevel::Inert);
72            }
73            if has_flag_owned(tokens, guard_short.as_deref(), guard_long)
74                && check_owned(tokens, policy)
75            {
76                Verdict::Allowed(*level)
77            } else {
78                Verdict::Denied
79            }
80        }
81        SubKind::Nested { subs, allow_bare } => {
82            if tokens.len() < 2 {
83                if *allow_bare {
84                    return Verdict::Allowed(SafetyLevel::Inert);
85                }
86                return Verdict::Denied;
87            }
88            let arg = tokens[1].as_str();
89            if matches!(arg, "--help" | "-h") {
90                if tokens.len() == 2 {
91                    return Verdict::Allowed(SafetyLevel::Inert);
92                }
93                return Verdict::Denied;
94            }
95            subs.iter()
96                .find(|s| s.name == arg)
97                .map(|s| dispatch_sub(&tokens[1..], s))
98                .unwrap_or(Verdict::Denied)
99        }
100        SubKind::AllowAll { level } => Verdict::Allowed(*level),
101        SubKind::WriteFlagged {
102            policy,
103            base_level,
104            write_flags,
105        } => {
106            if !check_owned(tokens, policy) {
107                return Verdict::Denied;
108            }
109            let has_write = tokens[1..].iter().any(|t| {
110                write_flags.iter().any(|f| t == f.as_str() || t.as_str().starts_with(&format!("{f}=")))
111            });
112            if has_write {
113                Verdict::Allowed(SafetyLevel::SafeWrite)
114            } else {
115                Verdict::Allowed(*base_level)
116            }
117        }
118        SubKind::FirstArgFilter { patterns, level } => {
119            dispatch_first_arg(tokens, patterns, *level)
120        }
121        SubKind::RequireAny {
122            require_any,
123            policy,
124            level,
125        } => dispatch_require_any(tokens, require_any, policy, *level),
126        SubKind::DelegateAfterSeparator { separator } => {
127            let sep_pos = tokens[1..].iter().position(|t| t == separator.as_str());
128            let Some(pos) = sep_pos else {
129                return Verdict::Denied;
130            };
131            let inner_start = pos + 2;
132            if inner_start >= tokens.len() {
133                return Verdict::Denied;
134            }
135            let inner = shell_words::join(tokens[inner_start..].iter().map(|t| t.as_str()));
136            crate::command_verdict(&inner)
137        }
138        SubKind::DelegateSkip { skip, .. } => {
139            if tokens.len() <= *skip {
140                return Verdict::Denied;
141            }
142            let inner = shell_words::join(tokens[*skip..].iter().map(|t| t.as_str()));
143            crate::command_verdict(&inner)
144        }
145        SubKind::Custom { handler_name } => {
146            SUB_HANDLERS
147                .get(handler_name.as_str())
148                .map(|f| f(tokens))
149                .unwrap_or(Verdict::Denied)
150        }
151    }
152}
153
154fn dispatch_structured(
155    tokens: &[Token],
156    bare_flags: &[String],
157    subs: &[SubSpec],
158    pre_standalone: &[String],
159    pre_valued: &[String],
160    bare_ok: bool,
161    first_arg: &[String],
162    first_arg_level: SafetyLevel,
163) -> Verdict {
164    let mut start = 1;
165    while start < tokens.len() {
166        let t = &tokens[start];
167        if !t.starts_with('-') {
168            break;
169        }
170        if pre_valued.iter().any(|f| t == f.as_str()) {
171            start += 2;
172            continue;
173        }
174        if pre_valued.iter().any(|f| t.as_str().starts_with(&format!("{f}="))) {
175            start += 1;
176            continue;
177        }
178        if pre_standalone.iter().any(|f| t == f.as_str()) {
179            start += 1;
180            continue;
181        }
182        break;
183    }
184    if start >= tokens.len() {
185        return if bare_ok { Verdict::Allowed(SafetyLevel::Inert) } else { Verdict::Denied };
186    }
187    let arg = tokens[start].as_str();
188    if start + 1 == tokens.len() && bare_flags.iter().any(|f| f == arg) {
189        return Verdict::Allowed(SafetyLevel::Inert);
190    }
191    if let Some(sub) = subs.iter().find(|s| s.name == arg) {
192        return dispatch_sub(&tokens[start..], sub);
193    }
194    if !first_arg.is_empty() {
195        let matches = first_arg.iter().any(|p| {
196            if let Some(prefix) = p.strip_suffix('*') {
197                arg.starts_with(prefix)
198            } else {
199                arg == p
200            }
201        });
202        if matches {
203            return Verdict::Allowed(first_arg_level);
204        }
205    }
206    Verdict::Denied
207}
208
209pub fn dispatch_spec(tokens: &[Token], spec: &CommandSpec) -> Verdict {
210    match &spec.kind {
211        CommandKind::Flat { policy, level } => {
212            if check_owned(tokens, policy) {
213                Verdict::Allowed(*level)
214            } else {
215                Verdict::Denied
216            }
217        }
218        CommandKind::FlatFirstArg { patterns, level } => {
219            dispatch_first_arg(tokens, patterns, *level)
220        }
221        CommandKind::FlatRequireAny {
222            require_any,
223            policy,
224            level,
225        } => dispatch_require_any(tokens, require_any, policy, *level),
226        CommandKind::Structured { bare_flags, subs, pre_standalone, pre_valued, bare_ok, first_arg, first_arg_level } => {
227            dispatch_structured(tokens, bare_flags, subs, pre_standalone, pre_valued, *bare_ok, first_arg, *first_arg_level)
228        }
229        CommandKind::Wrapper {
230            standalone,
231            valued,
232            positional_skip,
233            separator,
234            bare_ok,
235        } => {
236            let mut i = 1;
237            while i < tokens.len() {
238                let t = &tokens[i];
239                if let Some(sep) = separator
240                    && t == sep.as_str()
241                {
242                    i += 1;
243                    break;
244                }
245                if !t.starts_with('-') {
246                    break;
247                }
248                if valued.iter().any(|f| t == f.as_str()) {
249                    i += 2;
250                    continue;
251                }
252                if standalone.iter().any(|f| t == f.as_str()) {
253                    i += 1;
254                    continue;
255                }
256                i += 1;
257            }
258            for _ in 0..*positional_skip {
259                if i >= tokens.len() {
260                    return if *bare_ok {
261                        Verdict::Allowed(SafetyLevel::Inert)
262                    } else {
263                        Verdict::Denied
264                    };
265                }
266                i += 1;
267            }
268            if i >= tokens.len() {
269                return if *bare_ok {
270                    Verdict::Allowed(SafetyLevel::Inert)
271                } else {
272                    Verdict::Denied
273                };
274            }
275            let inner = shell_words::join(tokens[i..].iter().map(|t| t.as_str()));
276            crate::command_verdict(&inner)
277        }
278        CommandKind::Custom { handler_name } => {
279            CMD_HANDLERS
280                .get(handler_name.as_str())
281                .map(|f| f(tokens))
282                .unwrap_or(Verdict::Denied)
283        }
284    }
285}