// safe_chains/registry/dispatch.rs

1use crate::parse::Token;
2use crate::verdict::{SafetyLevel, Verdict};
3
4use super::policy::check_owned;
5use super::types::*;
6use super::{CMD_HANDLERS, SUB_HANDLERS};
7
8fn has_flag_owned(tokens: &[Token], short: Option<&str>, long: &str) -> bool {
9    tokens[1..].iter().any(|t| {
10        t == long
11            || short.is_some_and(|s| t == s)
12            || t.as_str().starts_with(&format!("{long}="))
13    })
14}
15
16fn dispatch_first_arg(tokens: &[Token], patterns: &[String], level: SafetyLevel) -> Verdict {
17    if tokens.len() == 2 && (tokens[1] == "--help" || tokens[1] == "-h") {
18        return Verdict::Allowed(SafetyLevel::Inert);
19    }
20    let Some(arg) = tokens.get(1) else {
21        return Verdict::Denied;
22    };
23    let arg_str = arg.as_str();
24    let matches = patterns.iter().any(|p| {
25        if let Some(prefix) = p.strip_suffix('*') {
26            arg_str.starts_with(prefix)
27        } else {
28            arg_str == p
29        }
30    });
31    if matches { Verdict::Allowed(level) } else { Verdict::Denied }
32}
33
34fn dispatch_require_any(
35    tokens: &[Token],
36    require_any: &[String],
37    policy: &OwnedPolicy,
38    level: SafetyLevel,
39) -> Verdict {
40    if tokens.len() == 2 && (tokens[1] == "--help" || tokens[1] == "-h") {
41        return Verdict::Allowed(SafetyLevel::Inert);
42    }
43    let has_required = tokens[1..].iter().any(|t| {
44        require_any.iter().any(|r| {
45            t == r.as_str() || t.as_str().starts_with(&format!("{r}="))
46        })
47    });
48    if has_required && check_owned(tokens, policy) {
49        Verdict::Allowed(level)
50    } else {
51        Verdict::Denied
52    }
53}
54
55fn dispatch_sub(tokens: &[Token], sub: &SubSpec) -> Verdict {
56    match &sub.kind {
57        SubKind::Policy { policy, level } => {
58            if check_owned(tokens, policy) {
59                Verdict::Allowed(*level)
60            } else {
61                Verdict::Denied
62            }
63        }
64        SubKind::Guarded {
65            guard_long,
66            guard_short,
67            policy,
68            level,
69        } => {
70            if has_flag_owned(tokens, guard_short.as_deref(), guard_long)
71                && check_owned(tokens, policy)
72            {
73                Verdict::Allowed(*level)
74            } else {
75                Verdict::Denied
76            }
77        }
78        SubKind::Nested { subs, allow_bare } => {
79            if tokens.len() < 2 {
80                if *allow_bare {
81                    return Verdict::Allowed(SafetyLevel::Inert);
82                }
83                return Verdict::Denied;
84            }
85            let arg = tokens[1].as_str();
86            if *allow_bare && (arg == "--help" || arg == "-h") {
87                return Verdict::Allowed(SafetyLevel::Inert);
88            }
89            subs.iter()
90                .find(|s| s.name == arg)
91                .map(|s| dispatch_sub(&tokens[1..], s))
92                .unwrap_or(Verdict::Denied)
93        }
94        SubKind::AllowAll { level } => Verdict::Allowed(*level),
95        SubKind::WriteFlagged {
96            policy,
97            base_level,
98            write_flags,
99        } => {
100            if !check_owned(tokens, policy) {
101                return Verdict::Denied;
102            }
103            let has_write = tokens[1..].iter().any(|t| {
104                write_flags.iter().any(|f| t == f.as_str() || t.as_str().starts_with(&format!("{f}=")))
105            });
106            if has_write {
107                Verdict::Allowed(SafetyLevel::SafeWrite)
108            } else {
109                Verdict::Allowed(*base_level)
110            }
111        }
112        SubKind::FirstArgFilter { patterns, level } => {
113            dispatch_first_arg(tokens, patterns, *level)
114        }
115        SubKind::RequireAny {
116            require_any,
117            policy,
118            level,
119        } => dispatch_require_any(tokens, require_any, policy, *level),
120        SubKind::DelegateAfterSeparator { separator } => {
121            let sep_pos = tokens[1..].iter().position(|t| t == separator.as_str());
122            let Some(pos) = sep_pos else {
123                return Verdict::Denied;
124            };
125            let inner_start = pos + 2;
126            if inner_start >= tokens.len() {
127                return Verdict::Denied;
128            }
129            let inner = shell_words::join(tokens[inner_start..].iter().map(|t| t.as_str()));
130            crate::command_verdict(&inner)
131        }
132        SubKind::DelegateSkip { skip, .. } => {
133            if tokens.len() <= *skip {
134                return Verdict::Denied;
135            }
136            let inner = shell_words::join(tokens[*skip..].iter().map(|t| t.as_str()));
137            crate::command_verdict(&inner)
138        }
139        SubKind::Custom { handler_name } => {
140            SUB_HANDLERS
141                .get(handler_name.as_str())
142                .map(|f| f(tokens))
143                .unwrap_or(Verdict::Denied)
144        }
145    }
146}
147
148fn dispatch_structured(
149    tokens: &[Token],
150    bare_flags: &[String],
151    subs: &[SubSpec],
152    pre_standalone: &[String],
153    pre_valued: &[String],
154) -> Verdict {
155    let mut start = 1;
156    while start < tokens.len() {
157        let t = &tokens[start];
158        if !t.starts_with('-') {
159            break;
160        }
161        if pre_valued.iter().any(|f| t == f.as_str()) {
162            start += 2;
163            continue;
164        }
165        if pre_valued.iter().any(|f| t.as_str().starts_with(&format!("{f}="))) {
166            start += 1;
167            continue;
168        }
169        if pre_standalone.iter().any(|f| t == f.as_str()) {
170            start += 1;
171            continue;
172        }
173        break;
174    }
175    if start >= tokens.len() {
176        return Verdict::Denied;
177    }
178    let arg = tokens[start].as_str();
179    if start + 1 == tokens.len() && bare_flags.iter().any(|f| f == arg) {
180        return Verdict::Allowed(SafetyLevel::Inert);
181    }
182    subs.iter()
183        .find(|s| s.name == arg)
184        .map(|s| dispatch_sub(&tokens[start..], s))
185        .unwrap_or(Verdict::Denied)
186}
187
188pub fn dispatch_spec(tokens: &[Token], spec: &CommandSpec) -> Verdict {
189    match &spec.kind {
190        CommandKind::Flat { policy, level } => {
191            if check_owned(tokens, policy) {
192                Verdict::Allowed(*level)
193            } else {
194                Verdict::Denied
195            }
196        }
197        CommandKind::FlatFirstArg { patterns, level } => {
198            dispatch_first_arg(tokens, patterns, *level)
199        }
200        CommandKind::FlatRequireAny {
201            require_any,
202            policy,
203            level,
204        } => dispatch_require_any(tokens, require_any, policy, *level),
205        CommandKind::Structured { bare_flags, subs, pre_standalone, pre_valued } => {
206            dispatch_structured(tokens, bare_flags, subs, pre_standalone, pre_valued)
207        }
208        CommandKind::Wrapper {
209            standalone,
210            valued,
211            positional_skip,
212            separator,
213            bare_ok,
214        } => {
215            let mut i = 1;
216            while i < tokens.len() {
217                let t = &tokens[i];
218                if let Some(sep) = separator
219                    && t == sep.as_str()
220                {
221                    i += 1;
222                    break;
223                }
224                if !t.starts_with('-') {
225                    break;
226                }
227                if valued.iter().any(|f| t == f.as_str()) {
228                    i += 2;
229                    continue;
230                }
231                if standalone.iter().any(|f| t == f.as_str()) {
232                    i += 1;
233                    continue;
234                }
235                i += 1;
236            }
237            for _ in 0..*positional_skip {
238                if i >= tokens.len() {
239                    return if *bare_ok {
240                        Verdict::Allowed(SafetyLevel::Inert)
241                    } else {
242                        Verdict::Denied
243                    };
244                }
245                i += 1;
246            }
247            if i >= tokens.len() {
248                return if *bare_ok {
249                    Verdict::Allowed(SafetyLevel::Inert)
250                } else {
251                    Verdict::Denied
252                };
253            }
254            let inner = shell_words::join(tokens[i..].iter().map(|t| t.as_str()));
255            crate::command_verdict(&inner)
256        }
257        CommandKind::Custom { handler_name } => {
258            CMD_HANDLERS
259                .get(handler_name.as_str())
260                .map(|f| f(tokens))
261                .unwrap_or(Verdict::Denied)
262        }
263    }
264}