1use std::collections::HashMap;
46
47use proc_macro::TokenStream;
48use proc_macro2::Span;
49use quote::quote;
50use syn::{DeriveInput, LitStr, parse_macro_input};
51
/// Every permission name accepted by `#[derive(TypesecRole)]` and `policy!`.
///
/// The order only matters for diagnostics: it is the order in which the names
/// are listed in the "unknown permission" error produced by `check_permission`.
const KNOWN_PERMISSIONS: &[&str] = &[
    // General operations.
    "read",
    "write",
    "delete",
    "execute",
    "delegate",
    // Sensitivity-scoped access.
    "read_internal",
    "read_sensitive",
    "write_sensitive",
    "declassify",
    // Namespaced AI capabilities.
    "ai:infer",
    "ai:train",
    "ai:exfiltrate",
];
70
71fn check_permission(name: &str, span: Span) -> Result<(), syn::Error> {
72 if KNOWN_PERMISSIONS.contains(&name) {
73 Ok(())
74 } else {
75 Err(syn::Error::new(
76 span,
77 format!(
78 "unknown permission '{name}' (expected one of: {})",
79 KNOWN_PERMISSIONS.join(", ")
80 ),
81 ))
82 }
83}
84
/// Converts a PascalCase identifier (e.g. a role name) into snake_case.
///
/// An ASCII uppercase letter is lowercased, and a `_` is inserted before it when
/// it follows a lowercase letter or digit, or when it closes an acronym run (the
/// previous char is uppercase and the next one is lowercase, so `HTTPAuditLog`
/// becomes `http_audit_log`). A `-` is turned into `_`. An inserted `_` is skipped
/// when the output already ends with `_`. Every other character, including `_`
/// from the input, is copied unchanged.
fn pascal_to_snake(name: &str) -> String {
    let chars: Vec<char> = name.chars().collect();
    let mut snake = String::with_capacity(name.len() + 4);

    for (idx, &current) in chars.iter().enumerate() {
        if current == '-' {
            if !snake.ends_with('_') {
                snake.push('_');
            }
            continue;
        }

        if !current.is_ascii_uppercase() {
            snake.push(current);
            continue;
        }

        let previous = if idx == 0 { None } else { chars.get(idx - 1).copied() };
        let following = chars.get(idx + 1).copied();
        let word_boundary = match previous {
            Some(p) if p.is_ascii_lowercase() || p.is_ascii_digit() => true,
            // Inside an acronym, only the last capital before a lowercase letter
            // starts a new word.
            Some(p) if p.is_ascii_uppercase() => following.is_some_and(|f| f.is_ascii_lowercase()),
            _ => false,
        };

        if word_boundary && !snake.ends_with('_') {
            snake.push('_');
        }
        snake.push(current.to_ascii_lowercase());
    }

    snake
}
115
116#[proc_macro_derive(TypesecRole, attributes(role))]
120pub fn derive_typesec_role(input: TokenStream) -> TokenStream {
121 let input = parse_macro_input!(input as DeriveInput);
122 match derive_typesec_role_impl(input) {
123 Ok(ts) => ts.into(),
124 Err(e) => e.to_compile_error().into(),
125 }
126}
127
128fn derive_typesec_role_impl(input: DeriveInput) -> Result<proc_macro2::TokenStream, syn::Error> {
129 let struct_name = &input.ident;
130 let struct_name_str = struct_name.to_string().to_lowercase();
131
132 let role_attr = input
134 .attrs
135 .iter()
136 .find(|a| a.path().is_ident("role"))
137 .ok_or_else(|| {
138 syn::Error::new(
139 Span::call_site(),
140 "TypesecRole requires a #[role(permissions = \"...\", resources = \"...\")] attribute",
141 )
142 })?;
143
144 let mut permissions: Vec<String> = Vec::new();
146 let mut resources: Vec<String> = Vec::new();
147
148 role_attr.parse_nested_meta(|meta| {
149 if meta.path.is_ident("permissions") {
150 let value: LitStr = meta.value()?.parse()?;
151 permissions = value
152 .value()
153 .split(',')
154 .map(|s| s.trim().to_owned())
155 .filter(|s| !s.is_empty())
156 .collect();
157 for permission in &permissions {
158 check_permission(permission, value.span())?;
159 }
160 Ok(())
161 } else if meta.path.is_ident("resources") {
162 let value: LitStr = meta.value()?.parse()?;
163 resources = value
164 .value()
165 .split(',')
166 .map(|s| s.trim().to_owned())
167 .filter(|s| !s.is_empty())
168 .collect();
169 Ok(())
170 } else {
171 Err(meta.error("unknown role attribute key (expected 'permissions' or 'resources')"))
172 }
173 })?;
174
175 let perm_lits: Vec<LitStr> = permissions
176 .iter()
177 .map(|p| LitStr::new(p, Span::call_site()))
178 .collect();
179
180 let resource_lits: Vec<LitStr> = resources
181 .iter()
182 .map(|r| LitStr::new(r, Span::call_site()))
183 .collect();
184
185 let name_lit = LitStr::new(&struct_name_str, Span::call_site());
186
187 Ok(quote! {
188 impl typesec_core::role::Role for #struct_name {
189 fn name() -> &'static str {
190 #name_lit
191 }
192 fn permission_names() -> &'static [&'static str] {
193 &[#(#perm_lits),*]
194 }
195 fn resource_patterns() -> &'static [&'static str] {
196 &[#(#resource_lits),*]
197 }
198 }
199 })
200}
201
202#[proc_macro]
217pub fn policy(input: TokenStream) -> TokenStream {
218 match policy_impl(input.into()) {
219 Ok(ts) => ts.into(),
220 Err(e) => e.to_compile_error().into(),
221 }
222}
223
224fn policy_impl(input: proc_macro2::TokenStream) -> Result<proc_macro2::TokenStream, syn::Error> {
225 use syn::{
226 Ident, Token, braced,
227 parse::{Parse, ParseStream},
228 punctuated::Punctuated,
229 };
230
231 struct RoleDef {
233 name: Ident,
234 parent: Option<Ident>,
235 perms: Vec<Ident>,
236 resources: Vec<LitStr>,
237 }
238
239 struct PolicyParser(Vec<RoleDef>);
240
241 impl Parse for PolicyParser {
242 fn parse(input: ParseStream) -> syn::Result<Self> {
243 let mut roles = Vec::new();
244
245 while !input.is_empty() {
246 let kw: Ident = input.parse()?;
248 if kw != "role" {
249 return Err(syn::Error::new(kw.span(), "expected `role`"));
250 }
251
252 let name: Ident = input.parse()?;
254
255 let parent = if input.peek(Ident) {
256 let maybe_extends: Ident = input.parse()?;
257 if maybe_extends != "extends" {
258 return Err(syn::Error::new(
259 maybe_extends.span(),
260 "expected `extends` or `{`",
261 ));
262 }
263 Some(input.parse()?)
264 } else {
265 None
266 };
267
268 let content;
270 braced!(content in input);
271
272 let can_kw: Ident = content.parse()?;
274 if can_kw != "can" {
275 return Err(syn::Error::new(can_kw.span(), "expected `can`"));
276 }
277
278 let perm_content;
280 syn::bracketed!(perm_content in content);
281 let perms: Punctuated<Ident, Token![,]> =
282 perm_content.parse_terminated(Ident::parse, Token![,])?;
283
284 let on_kw: Ident = content.parse()?;
286 if on_kw != "on" {
287 return Err(syn::Error::new(on_kw.span(), "expected `on`"));
288 }
289
290 let res_content;
292 syn::bracketed!(res_content in content);
293 let resources: Punctuated<LitStr, Token![,]> =
294 res_content.parse_terminated(Parse::parse, Token![,])?;
295
296 let _ = content.parse::<Token![;]>();
298
299 roles.push(RoleDef {
300 name,
301 parent,
302 perms: perms.into_iter().collect(),
303 resources: resources.into_iter().collect(),
304 });
305 }
306
307 Ok(PolicyParser(roles))
308 }
309 }
310
311 let parsed: PolicyParser = syn::parse2(input)?;
312 let role_index: HashMap<String, usize> = parsed
313 .0
314 .iter()
315 .enumerate()
316 .map(|(idx, role)| (role.name.to_string(), idx))
317 .collect();
318 let mut output = proc_macro2::TokenStream::new();
319
320 fn flatten_role(
321 idx: usize,
322 roles: &[RoleDef],
323 role_index: &HashMap<String, usize>,
324 visiting: &mut Vec<String>,
325 ) -> Result<(Vec<String>, Vec<LitStr>), syn::Error> {
326 let role = &roles[idx];
327 let role_name = role.name.to_string();
328 if visiting.contains(&role_name) {
329 return Err(syn::Error::new(
330 role.name.span(),
331 format!("circular role inheritance detected for `{role_name}`"),
332 ));
333 }
334
335 visiting.push(role_name);
336
337 let mut permissions = Vec::new();
338 let mut resources = Vec::new();
339
340 if let Some(parent) = &role.parent {
341 let parent_name = parent.to_string();
342 let parent_idx = role_index.get(&parent_name).ok_or_else(|| {
343 syn::Error::new(
344 parent.span(),
345 format!("role `{}` extends unknown role `{parent_name}`", role.name),
346 )
347 })?;
348 let (parent_permissions, parent_resources) =
349 flatten_role(*parent_idx, roles, role_index, visiting)?;
350 permissions.extend(parent_permissions);
351 resources.extend(parent_resources);
352 }
353
354 for perm in &role.perms {
355 let perm_name = perm.to_string();
356 check_permission(&perm_name, perm.span())?;
357 if !permissions.contains(&perm_name) {
358 permissions.push(perm_name);
359 }
360 }
361
362 for resource in &role.resources {
363 if !resources
364 .iter()
365 .any(|existing: &LitStr| existing.value() == resource.value())
366 {
367 resources.push(resource.clone());
368 }
369 }
370
371 visiting.pop();
372 Ok((permissions, resources))
373 }
374
375 for (idx, role) in parsed.0.iter().enumerate() {
376 let name = &role.name;
377 let name_str = pascal_to_snake(&name.to_string());
378 let (permissions, resources) = flatten_role(idx, &parsed.0, &role_index, &mut Vec::new())?;
379 let perm_lits: Vec<LitStr> = permissions
380 .iter()
381 .map(|s| LitStr::new(s, Span::call_site()))
382 .collect();
383
384 let name_lit = LitStr::new(&name_str, Span::call_site());
385
386 output.extend(quote! {
387 #[derive(Debug, Clone, Copy)]
388 pub struct #name;
389
390 impl typesec_core::role::Role for #name {
391 fn name() -> &'static str { #name_lit }
392 fn permission_names() -> &'static [&'static str] { &[#(#perm_lits),*] }
393 fn resource_patterns() -> &'static [&'static str] { &[#(#resources),*] }
394 }
395 });
396 }
397
398 Ok(output)
399}
400
#[cfg(test)]
mod tests {
    use quote::quote;

    use super::{pascal_to_snake, policy_impl};

    // Expands `policy!` input that must be rejected and returns the rendered
    // error message; panics with `reason` if the expansion unexpectedly succeeds.
    fn policy_error(input: proc_macro2::TokenStream, reason: &str) -> String {
        policy_impl(input).expect_err(reason).to_string()
    }

    #[test]
    fn converts_pascal_case_role_names_to_snake_case() {
        let cases = [
            ("AnalystReadOnly", "analyst_read_only"),
            ("AITrainer", "ai_trainer"),
            ("HTTPAuditLog", "http_audit_log"),
            ("Reader", "reader"),
        ];

        for (pascal, snake) in cases {
            assert_eq!(pascal_to_snake(pascal), snake, "converting `{pascal}`");
        }
    }

    #[test]
    fn policy_macro_rejects_unknown_parent_role() {
        let message = policy_error(
            quote! {
                role Writer extends Reader {
                    can [write] on ["docs/*"];
                }
            },
            "unknown parent should fail",
        );

        assert!(
            message.contains("unknown role `Reader`"),
            "unexpected error: {message}"
        );
    }

    #[test]
    fn policy_macro_rejects_cyclic_inheritance() {
        let message = policy_error(
            quote! {
                role Reader extends Writer {
                    can [read] on ["docs/*"];
                }
                role Writer extends Reader {
                    can [write] on ["docs/*"];
                }
            },
            "inheritance cycle should fail",
        );

        assert!(
            message.contains("circular role inheritance"),
            "unexpected error: {message}"
        );
    }
}