// safe_chains/registry/dispatch.rs

1use crate::parse::Token;
2use crate::policy::FlagSet;
3use crate::verdict::{SafetyLevel, Verdict};
4
5use super::policy::check_owned;
6use super::types::*;
7use super::{CMD_HANDLERS, SUB_HANDLERS};
8
9type HandlerMap = std::collections::HashMap<&'static str, super::HandlerFn>;
10
/// Returns the letter of a lone short flag such as `-n`.
///
/// Only exactly-two-byte strings of the form `-X` with `X != '-'` qualify;
/// `--`, clusters like `-nv`, bare words, `-`, and the empty string all give
/// `None`.
fn short_flag_char(s: &str) -> Option<char> {
    match s.as_bytes() {
        // The string is valid UTF-8 and two bytes long with a one-byte dash,
        // so the second byte is a complete (hence ASCII) character.
        [b'-', letter] if *letter != b'-' => Some(char::from(*letter)),
        _ => None,
    }
}

/// Reports whether `s` looks like a cluster of short flags (`-nv`, `-xzf`).
///
/// True for any string of at least three bytes that starts with a single
/// dash; long options (`--foo`) and lone shorts (`-n`) are excluded.
fn is_combined_short(s: &str) -> bool {
    matches!(s.as_bytes(), [b'-', second, _, ..] if *second != b'-')
}

25fn dispatch_first_arg(tokens: &[Token], patterns: &[String], level: SafetyLevel) -> Verdict {
26    if tokens.len() == 2 && (tokens[1] == "--help" || tokens[1] == "-h") {
27        return Verdict::Allowed(SafetyLevel::Inert);
28    }
29    let Some(arg) = tokens.get(1) else {
30        return Verdict::Denied;
31    };
32    let arg_str = arg.as_str();
33    let matches = patterns.iter().any(|p| {
34        if let Some(prefix) = p.strip_suffix('*') {
35            arg_str.starts_with(prefix)
36        } else {
37            arg_str == p
38        }
39    });
40    if matches { Verdict::Allowed(level) } else { Verdict::Denied }
41}
42
43fn dispatch_require_any(
44    tokens: &[Token],
45    require_any: &[String],
46    policy: &OwnedPolicy,
47    level: SafetyLevel,
48    accept_bare_help: bool,
49) -> Verdict {
50    if tokens.len() == 2 {
51        let t = tokens[1].as_str();
52        if t == "--help" || t == "-h" || (accept_bare_help && t == "help") {
53            return Verdict::Allowed(SafetyLevel::Inert);
54        }
55    }
56    let has_required = tokens[1..].iter().any(|t| {
57        require_any.iter().any(|r| {
58            let t_str = t.as_str();
59            if t_str == r.as_str() {
60                return true;
61            }
62            if r.starts_with("--") && t_str.starts_with(&format!("{r}=")) {
63                return true;
64            }
65            if let Some(short_char) = short_flag_char(r)
66                && is_combined_short(t_str)
67                && t_str[1..].contains(short_char)
68            {
69                return true;
70            }
71            false
72        })
73    });
74    if has_required && check_owned(tokens, policy) {
75        Verdict::Allowed(level)
76    } else {
77        Verdict::Denied
78    }
79}
80
81fn skip_pre_flags(
82    tokens: &[Token],
83    pre_standalone: &[String],
84    pre_valued: &[String],
85    start: usize,
86) -> usize {
87    let mut i = start;
88    while i < tokens.len() {
89        let t = &tokens[i];
90        let s = t.as_str();
91        if !s.starts_with('-') {
92            break;
93        }
94        if pre_valued.contains_flag(s) {
95            i += 2;
96            continue;
97        }
98        if let Some((flag, _)) = s.split_once('=')
99            && pre_valued.contains_flag(flag)
100        {
101            i += 1;
102            continue;
103        }
104        if pre_standalone.contains_flag(s) {
105            i += 1;
106            continue;
107        }
108        // POSIX-style short-flag cluster (`-vv`, `-vy`): every byte after
109        // the dash must be a known standalone short. Mirrors the same
110        // logic in policy::check_flags for non-wrapper subs.
111        let bytes = s.as_bytes();
112        if bytes.len() > 2
113            && bytes[1] != b'-'
114            && bytes[1..].iter().all(|&b| pre_standalone.contains_short(b))
115        {
116            i += 1;
117            continue;
118        }
119        break;
120    }
121    i
122}
123
124fn dispatch_branching(
125    tokens: &[Token],
126    subs: &[SubSpec],
127    bare_flags: &[String],
128    bare_ok: bool,
129    pre_standalone: &[String],
130    pre_valued: &[String],
131    first_arg: &[String],
132    first_arg_level: SafetyLevel,
133) -> Verdict {
134    let start = skip_pre_flags(tokens, pre_standalone, pre_valued, 1);
135    if start >= tokens.len() {
136        return if bare_ok { Verdict::Allowed(SafetyLevel::Inert) } else { Verdict::Denied };
137    }
138    let arg = tokens[start].as_str();
139    let is_bare_flag = bare_flags.iter().any(|f| f == arg)
140        || (bare_flags.is_empty() && matches!(arg, "--help" | "-h"));
141    if is_bare_flag {
142        let after = skip_pre_flags(tokens, pre_standalone, pre_valued, start + 1);
143        if after >= tokens.len() {
144            return Verdict::Allowed(SafetyLevel::Inert);
145        }
146        if bare_flags.is_empty() {
147            return Verdict::Denied;
148        }
149    }
150    if let Some(sub) = subs.iter().find(|s| s.name == arg) {
151        return dispatch_kind(&tokens[start..], &sub.kind, &SUB_HANDLERS);
152    }
153    if !first_arg.is_empty() {
154        let matches = first_arg.iter().any(|p| {
155            if let Some(prefix) = p.strip_suffix('*') {
156                arg.starts_with(prefix)
157            } else {
158                arg == p
159            }
160        });
161        if matches {
162            return Verdict::Allowed(first_arg_level);
163        }
164    }
165    Verdict::Denied
166}
167
168fn dispatch_wrapper(
169    tokens: &[Token],
170    standalone: &[String],
171    valued: &[String],
172    positional_skip: usize,
173    separator: Option<&str>,
174    bare_ok: bool,
175) -> Verdict {
176    let mut i = 1;
177    while i < tokens.len() {
178        let t = &tokens[i];
179        if let Some(sep) = separator
180            && t == sep
181        {
182            i += 1;
183            break;
184        }
185        if !t.starts_with('-') {
186            break;
187        }
188        if valued.iter().any(|f| t == f.as_str()) {
189            i += 2;
190            continue;
191        }
192        if valued.iter().any(|f| t.as_str().starts_with(&format!("{f}="))) {
193            i += 1;
194            continue;
195        }
196        if standalone.iter().any(|f| t == f.as_str()) {
197            i += 1;
198            continue;
199        }
200        return Verdict::Denied;
201    }
202    for _ in 0..positional_skip {
203        if i >= tokens.len() {
204            return if bare_ok {
205                Verdict::Allowed(SafetyLevel::Inert)
206            } else {
207                Verdict::Denied
208            };
209        }
210        i += 1;
211    }
212    if i >= tokens.len() {
213        return if bare_ok {
214            Verdict::Allowed(SafetyLevel::Inert)
215        } else {
216            Verdict::Denied
217        };
218    }
219    let inner = shell_words::join(tokens[i..].iter().map(|t| t.as_str()));
220    crate::command_verdict(&inner)
221}
222
223fn dispatch_kind(tokens: &[Token], kind: &DispatchKind, handlers: &HandlerMap) -> Verdict {
224    match kind {
225        DispatchKind::Policy { policy, level } => {
226            if check_owned(tokens, policy) {
227                Verdict::Allowed(*level)
228            } else {
229                Verdict::Denied
230            }
231        }
232        DispatchKind::FirstArg { patterns, level } => {
233            dispatch_first_arg(tokens, patterns, *level)
234        }
235        DispatchKind::RequireAny { require_any, policy, level, accept_bare_help } => {
236            dispatch_require_any(tokens, require_any, policy, *level, *accept_bare_help)
237        }
238        DispatchKind::Branching {
239            subs, bare_flags, bare_ok, pre_standalone, pre_valued, first_arg, first_arg_level,
240        } => {
241            dispatch_branching(
242                tokens, subs, bare_flags, *bare_ok, pre_standalone, pre_valued,
243                first_arg, *first_arg_level,
244            )
245        }
246        DispatchKind::WriteFlagged { policy, base_level, write_flags } => {
247            if !check_owned(tokens, policy) {
248                return Verdict::Denied;
249            }
250            let has_write = tokens[1..].iter().any(|t| {
251                write_flags.iter().any(|f| t == f.as_str() || t.as_str().starts_with(&format!("{f}=")))
252            });
253            if has_write {
254                Verdict::Allowed(SafetyLevel::SafeWrite)
255            } else {
256                Verdict::Allowed(*base_level)
257            }
258        }
259        DispatchKind::DelegateAfterSeparator { separator } => {
260            let sep_pos = tokens[1..].iter().position(|t| t == separator.as_str());
261            let Some(pos) = sep_pos else {
262                return Verdict::Denied;
263            };
264            let inner_start = pos + 2;
265            if inner_start >= tokens.len() {
266                return Verdict::Denied;
267            }
268            let inner = shell_words::join(tokens[inner_start..].iter().map(|t| t.as_str()));
269            crate::command_verdict(&inner)
270        }
271        DispatchKind::DelegateSkip { skip } => {
272            if tokens.len() <= *skip {
273                return Verdict::Denied;
274            }
275            let inner = shell_words::join(tokens[*skip..].iter().map(|t| t.as_str()));
276            crate::command_verdict(&inner)
277        }
278        DispatchKind::Wrapper {
279            standalone, valued, positional_skip, separator, bare_ok,
280        } => {
281            dispatch_wrapper(tokens, standalone, valued, *positional_skip, separator.as_deref(), *bare_ok)
282        }
283        DispatchKind::Custom { handler_name, .. } => {
284            handlers
285                .get(handler_name.as_str())
286                .map(|f| f(tokens))
287                .unwrap_or(Verdict::Denied)
288        }
289    }
290}
291
292pub fn dispatch_spec(tokens: &[Token], spec: &CommandSpec) -> Verdict {
293    dispatch_kind(tokens, &spec.kind, &CMD_HANDLERS)
294}
295
296/// Dispatches a sub's kind directly, used by `registry::try_sub_dispatch`
297/// when a handler-using command consults its TOML-declared subs.
298pub(super) fn dispatch_sub_kind(tokens: &[Token], kind: &DispatchKind) -> Verdict {
299    dispatch_kind(tokens, kind, &SUB_HANDLERS)
300}
301
302pub(super) fn check_handler_policy_owned(tokens: &[Token], policy: &OwnedPolicy) -> bool {
303    check_owned(tokens, policy)
304}
305
306pub(super) fn dispatch_matrix_action(
307    tokens: &[Token],
308    policy: &OwnedPolicy,
309    level: SafetyLevel,
310) -> Verdict {
311    if check_owned(tokens, policy) {
312        Verdict::Allowed(level)
313    } else {
314        Verdict::Denied
315    }
316}
317
318/// Applies a TOML-declared fallback grammar. Used by
319/// `registry::try_fallback_grammar()`.
320pub(super) fn dispatch_fallback(tokens: &[Token], spec: &FallbackSpec) -> Verdict {
321    if let Some(shape) = spec.positional_shape
322        && let Some(first) = super::policy::first_positional(tokens, &spec.policy)
323        && !shape.matches(first)
324    {
325        return Verdict::Denied;
326    }
327    if !check_owned(tokens, &spec.policy) {
328        return Verdict::Denied;
329    }
330    Verdict::Allowed(spec.level)
331}