Skip to main content

safe_chains/registry/
dispatch.rs

1use crate::parse::Token;
2use crate::verdict::{SafetyLevel, Verdict};
3
4use super::policy::check_owned;
5use super::types::*;
6use super::{CMD_HANDLERS, SUB_HANDLERS};
7
/// Name → handler lookup table; `dispatch_kind` consults it for
/// `DispatchKind::Custom` rules and denies names that are not present.
type HandlerMap = std::collections::HashMap<&'static str, super::HandlerFn>;
9
/// Returns the flag letter when `s` is exactly one short flag such as `-n`.
///
/// Long flags (`--x`), bundled shorts (`-abc`), a lone `-`, the empty string
/// and non-flag text all yield `None`.
fn short_flag_char(s: &str) -> Option<char> {
    match s.as_bytes() {
        [b'-', second] if *second != b'-' => s.chars().nth(1),
        _ => None,
    }
}
18
/// Reports whether `s` looks like a bundle of short flags such as `-abc`:
/// a single leading `-` not followed by another `-`, and more than two bytes
/// in total.
fn is_combined_short(s: &str) -> bool {
    match s.as_bytes() {
        [b'-', next, _, ..] => *next != b'-',
        _ => false,
    }
}
23
24fn dispatch_first_arg(tokens: &[Token], patterns: &[String], level: SafetyLevel) -> Verdict {
25    if tokens.len() == 2 && (tokens[1] == "--help" || tokens[1] == "-h") {
26        return Verdict::Allowed(SafetyLevel::Inert);
27    }
28    let Some(arg) = tokens.get(1) else {
29        return Verdict::Denied;
30    };
31    let arg_str = arg.as_str();
32    let matches = patterns.iter().any(|p| {
33        if let Some(prefix) = p.strip_suffix('*') {
34            arg_str.starts_with(prefix)
35        } else {
36            arg_str == p
37        }
38    });
39    if matches { Verdict::Allowed(level) } else { Verdict::Denied }
40}
41
42fn dispatch_require_any(
43    tokens: &[Token],
44    require_any: &[String],
45    policy: &OwnedPolicy,
46    level: SafetyLevel,
47    accept_bare_help: bool,
48) -> Verdict {
49    if tokens.len() == 2 {
50        let t = tokens[1].as_str();
51        if t == "--help" || t == "-h" || (accept_bare_help && t == "help") {
52            return Verdict::Allowed(SafetyLevel::Inert);
53        }
54    }
55    let has_required = tokens[1..].iter().any(|t| {
56        require_any.iter().any(|r| {
57            let t_str = t.as_str();
58            if t_str == r.as_str() {
59                return true;
60            }
61            if r.starts_with("--") && t_str.starts_with(&format!("{r}=")) {
62                return true;
63            }
64            if let Some(short_char) = short_flag_char(r)
65                && is_combined_short(t_str)
66                && t_str[1..].contains(short_char)
67            {
68                return true;
69            }
70            false
71        })
72    });
73    if has_required && check_owned(tokens, policy) {
74        Verdict::Allowed(level)
75    } else {
76        Verdict::Denied
77    }
78}
79
80fn skip_pre_flags(
81    tokens: &[Token],
82    pre_standalone: &[String],
83    pre_valued: &[String],
84) -> usize {
85    let mut i = 1;
86    while i < tokens.len() {
87        let t = &tokens[i];
88        if !t.starts_with('-') {
89            break;
90        }
91        if pre_valued.iter().any(|f| t == f.as_str()) {
92            i += 2;
93            continue;
94        }
95        if pre_valued.iter().any(|f| t.as_str().starts_with(&format!("{f}="))) {
96            i += 1;
97            continue;
98        }
99        if pre_standalone.iter().any(|f| t == f.as_str()) {
100            i += 1;
101            continue;
102        }
103        break;
104    }
105    i
106}
107
108fn dispatch_branching(
109    tokens: &[Token],
110    subs: &[SubSpec],
111    bare_flags: &[String],
112    bare_ok: bool,
113    pre_standalone: &[String],
114    pre_valued: &[String],
115    first_arg: &[String],
116    first_arg_level: SafetyLevel,
117) -> Verdict {
118    let start = skip_pre_flags(tokens, pre_standalone, pre_valued);
119    if start >= tokens.len() {
120        return if bare_ok { Verdict::Allowed(SafetyLevel::Inert) } else { Verdict::Denied };
121    }
122    let arg = tokens[start].as_str();
123    if bare_flags.is_empty() && matches!(arg, "--help" | "-h") {
124        if tokens.len() == start + 1 {
125            return Verdict::Allowed(SafetyLevel::Inert);
126        }
127        return Verdict::Denied;
128    }
129    if start + 1 == tokens.len() && bare_flags.iter().any(|f| f == arg) {
130        return Verdict::Allowed(SafetyLevel::Inert);
131    }
132    if let Some(sub) = subs.iter().find(|s| s.name == arg) {
133        return dispatch_kind(&tokens[start..], &sub.kind, &SUB_HANDLERS);
134    }
135    if !first_arg.is_empty() {
136        let matches = first_arg.iter().any(|p| {
137            if let Some(prefix) = p.strip_suffix('*') {
138                arg.starts_with(prefix)
139            } else {
140                arg == p
141            }
142        });
143        if matches {
144            return Verdict::Allowed(first_arg_level);
145        }
146    }
147    Verdict::Denied
148}
149
150fn dispatch_wrapper(
151    tokens: &[Token],
152    standalone: &[String],
153    valued: &[String],
154    positional_skip: usize,
155    separator: Option<&str>,
156    bare_ok: bool,
157) -> Verdict {
158    let mut i = 1;
159    while i < tokens.len() {
160        let t = &tokens[i];
161        if let Some(sep) = separator
162            && t == sep
163        {
164            i += 1;
165            break;
166        }
167        if !t.starts_with('-') {
168            break;
169        }
170        if valued.iter().any(|f| t == f.as_str()) {
171            i += 2;
172            continue;
173        }
174        if valued.iter().any(|f| t.as_str().starts_with(&format!("{f}="))) {
175            i += 1;
176            continue;
177        }
178        if standalone.iter().any(|f| t == f.as_str()) {
179            i += 1;
180            continue;
181        }
182        return Verdict::Denied;
183    }
184    for _ in 0..positional_skip {
185        if i >= tokens.len() {
186            return if bare_ok {
187                Verdict::Allowed(SafetyLevel::Inert)
188            } else {
189                Verdict::Denied
190            };
191        }
192        i += 1;
193    }
194    if i >= tokens.len() {
195        return if bare_ok {
196            Verdict::Allowed(SafetyLevel::Inert)
197        } else {
198            Verdict::Denied
199        };
200    }
201    let inner = shell_words::join(tokens[i..].iter().map(|t| t.as_str()));
202    crate::command_verdict(&inner)
203}
204
205fn dispatch_kind(tokens: &[Token], kind: &DispatchKind, handlers: &HandlerMap) -> Verdict {
206    match kind {
207        DispatchKind::Policy { policy, level } => {
208            if check_owned(tokens, policy) {
209                Verdict::Allowed(*level)
210            } else {
211                Verdict::Denied
212            }
213        }
214        DispatchKind::FirstArg { patterns, level } => {
215            dispatch_first_arg(tokens, patterns, *level)
216        }
217        DispatchKind::RequireAny { require_any, policy, level, accept_bare_help } => {
218            dispatch_require_any(tokens, require_any, policy, *level, *accept_bare_help)
219        }
220        DispatchKind::Branching {
221            subs, bare_flags, bare_ok, pre_standalone, pre_valued, first_arg, first_arg_level,
222        } => {
223            dispatch_branching(
224                tokens, subs, bare_flags, *bare_ok, pre_standalone, pre_valued,
225                first_arg, *first_arg_level,
226            )
227        }
228        DispatchKind::WriteFlagged { policy, base_level, write_flags } => {
229            if !check_owned(tokens, policy) {
230                return Verdict::Denied;
231            }
232            let has_write = tokens[1..].iter().any(|t| {
233                write_flags.iter().any(|f| t == f.as_str() || t.as_str().starts_with(&format!("{f}=")))
234            });
235            if has_write {
236                Verdict::Allowed(SafetyLevel::SafeWrite)
237            } else {
238                Verdict::Allowed(*base_level)
239            }
240        }
241        DispatchKind::DelegateAfterSeparator { separator } => {
242            let sep_pos = tokens[1..].iter().position(|t| t == separator.as_str());
243            let Some(pos) = sep_pos else {
244                return Verdict::Denied;
245            };
246            let inner_start = pos + 2;
247            if inner_start >= tokens.len() {
248                return Verdict::Denied;
249            }
250            let inner = shell_words::join(tokens[inner_start..].iter().map(|t| t.as_str()));
251            crate::command_verdict(&inner)
252        }
253        DispatchKind::DelegateSkip { skip } => {
254            if tokens.len() <= *skip {
255                return Verdict::Denied;
256            }
257            let inner = shell_words::join(tokens[*skip..].iter().map(|t| t.as_str()));
258            crate::command_verdict(&inner)
259        }
260        DispatchKind::Wrapper {
261            standalone, valued, positional_skip, separator, bare_ok,
262        } => {
263            dispatch_wrapper(tokens, standalone, valued, *positional_skip, separator.as_deref(), *bare_ok)
264        }
265        DispatchKind::Custom { handler_name } => {
266            handlers
267                .get(handler_name.as_str())
268                .map(|f| f(tokens))
269                .unwrap_or(Verdict::Denied)
270        }
271    }
272}
273
274pub fn dispatch_spec(tokens: &[Token], spec: &CommandSpec) -> Verdict {
275    dispatch_kind(tokens, &spec.kind, &CMD_HANDLERS)
276}