rustables_macros/
lib.rs

1#![allow(rustdoc::broken_intra_doc_links)]
2
3use std::fs::File;
4use std::io::Read;
5use std::path::PathBuf;
6
7use proc_macro::TokenStream;
8use proc_macro2::Span;
9use proc_macro2_diagnostics::{Diagnostic, Level, SpanDiagnosticExt};
10use quote::{quote, quote_spanned};
11
12use syn::parse::Parser;
13use syn::punctuated::Punctuated;
14use syn::spanned::Spanned;
15use syn::{
16    parse, Attribute, Expr, ExprCast, ExprLit, Ident, Item, ItemEnum, ItemStruct, Lit, Meta, Path,
17    Token, Type, TypePath, Visibility,
18};
19
20use once_cell::sync::OnceCell;
21
22struct GlobalState {
23    declared_identifiers: Vec<String>,
24}
25
26static STATE: OnceCell<GlobalState> = OnceCell::new();
27
28fn get_state() -> &'static GlobalState {
29    STATE.get_or_init(|| {
30        let sys_file = {
31            // Load the header file and extract the constants defined inside.
32            // This is what determines whether optional attributes (or enum variants)
33            // will be supported or not in the resulting binary.
34            let out_path = PathBuf::from(std::env::var("OUT_DIR").unwrap()).join("sys.rs");
35            let mut sys_file = String::new();
36            File::open(out_path)
37                .expect("Error: could not open the output header file")
38                .read_to_string(&mut sys_file)
39                .expect("Could not read the header file");
40            syn::parse_file(&sys_file).expect("Could not parse the header file")
41        };
42
43        let mut declared_identifiers = Vec::new();
44        for item in sys_file.items {
45            if let Item::Const(v) = item {
46                declared_identifiers.push(v.ident.to_string());
47            }
48        }
49
50        GlobalState {
51            declared_identifiers,
52        }
53    })
54}
55
56struct Field<'a> {
57    name: &'a Ident,
58    ty: &'a Type,
59    args: FieldArgs,
60    netlink_type: Path,
61    vis: &'a Visibility,
62    attrs: Vec<&'a Attribute>,
63}
64
65#[derive(Default)]
66struct FieldArgs {
67    netlink_type: Option<Path>,
68    override_function_name: Option<String>,
69    optional: bool,
70}
71
72fn parse_field_args(input: proc_macro2::TokenStream) -> Result<FieldArgs, Diagnostic> {
73    let mut args = FieldArgs::default();
74    let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
75    let attribute_args = parser
76        .parse2(input)
77        .map_err(|e| Diagnostic::new(Level::Error, e.to_string()))?;
78    for arg in attribute_args.iter() {
79        match arg {
80            Meta::Path(path) => {
81                if args.netlink_type.is_none() {
82                    args.netlink_type = Some(path.clone());
83                } else {
84                    return Err(arg
85                        .span()
86                        .error("Only a single netlink value can exist for a given field"));
87                }
88            }
89            Meta::NameValue(namevalue) => {
90                let key = namevalue
91                    .path
92                    .get_ident()
93                    .expect("the macro parameter is not an ident?")
94                    .to_string();
95                match key.as_str() {
96                    "name_in_functions" => {
97                        if let Expr::Lit(ExprLit {
98                            lit: Lit::Str(val), ..
99                        }) = &namevalue.value
100                        {
101                            args.override_function_name = Some(val.value());
102                        } else {
103                            return Err(namevalue.value.span().error("Expected a string literal"));
104                        }
105                    }
106                    "optional" => {
107                        if let Expr::Lit(ExprLit {
108                            lit: Lit::Bool(boolean),
109                            ..
110                        }) = &namevalue.value
111                        {
112                            args.optional = boolean.value;
113                        } else {
114                            return Err(namevalue.value.span().error("Expected a boolean"));
115                        }
116                    }
117                    _ => return Err(arg.span().error("Unsupported macro parameter")),
118                }
119            }
120            _ => return Err(arg.span().error("Unrecognized argument")),
121        }
122    }
123    Ok(args)
124}
125
/// Parameters accepted by `#[nfnetlink_struct(...)]`.
struct StructArgs {
    /// `nested = <bool>`: value returned by the generated `is_nested()`.
    nested: bool,
    /// `derive_decoder = <bool>`: whether an `AttributeDecoder` impl is generated.
    derive_decoder: bool,
    /// `derive_deserialize = <bool>`: whether an `NfNetlinkDeserializable` impl is generated.
    derive_deserialize: bool,
}

impl Default for StructArgs {
    /// Not nested; both the decoder and the deserializer are derived.
    fn default() -> Self {
        StructArgs {
            derive_decoder: true,
            derive_deserialize: true,
            nested: false,
        }
    }
}
141
142fn parse_struct_args(input: TokenStream) -> Result<StructArgs, Diagnostic> {
143    let mut args = StructArgs::default();
144    let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
145    let attribute_args = parser
146        .parse(input.clone())
147        .map_err(|e| Diagnostic::new(Level::Error, e.to_string()))?;
148    for arg in attribute_args.iter() {
149        if let Meta::NameValue(namevalue) = arg {
150            let key = namevalue
151                .path
152                .get_ident()
153                .expect("the macro parameter is not an ident?")
154                .to_string();
155            if let Expr::Lit(ExprLit {
156                lit: Lit::Bool(boolean),
157                ..
158            }) = &namevalue.value
159            {
160                match key.as_str() {
161                    "derive_decoder" => {
162                        args.derive_decoder = boolean.value;
163                    }
164                    "nested" => {
165                        args.nested = boolean.value;
166                    }
167                    "derive_deserialize" => {
168                        args.derive_deserialize = boolean.value;
169                    }
170                    _ => return Err(arg.span().error("Unsupported macro parameter")),
171                }
172            } else {
173                return Err(namevalue.value.span().error("Expected a boolean"));
174            }
175        } else {
176            return Err(arg.span().error("Unrecognized argument"));
177        }
178    }
179    Ok(args)
180}
181
182fn nfnetlink_struct_inner(
183    attrs: TokenStream,
184    item: TokenStream,
185) -> Result<TokenStream, Diagnostic> {
186    let ast: ItemStruct = parse(item).unwrap();
187    let name = ast.ident;
188
189    let args = match parse_struct_args(attrs) {
190        Ok(x) => x,
191        Err(e) => return Err(e),
192    };
193
194    let state = get_state();
195
196    let mut fields = Vec::with_capacity(ast.fields.len());
197    let mut identical_fields = Vec::new();
198
199    'out: for field in ast.fields.iter() {
200        for attr in field.attrs.iter() {
201            if let Some(id) = attr.path().get_ident() {
202                if id == "field" {
203                    let field_args = match &attr.meta {
204                        Meta::List(l) => l,
205                        _ => {
206                            return Err(attr.span().error("Invalid attributes"));
207                        }
208                    };
209
210                    let field_args = match parse_field_args(field_args.tokens.clone()) {
211                        Ok(x) => x,
212                        Err(_) => {
213                            return Err(attr.span().error("Could not parse the field attributes"));
214                        }
215                    };
216                    if let Some(netlink_type) = field_args.netlink_type.clone() {
217                        // optional fields are not generated when the kernel version you have on
218                        // the system does not support that field
219                        if field_args.optional {
220                            let netlink_type_ident = netlink_type
221                                .segments
222                                .last()
223                                .expect("empty path?")
224                                .ident
225                                .to_string();
226                            if !state.declared_identifiers.contains(&netlink_type_ident) {
227                                // reject the optional identifier
228                                continue 'out;
229                            }
230                        }
231
232                        fields.push(Field {
233                            name: field.ident.as_ref().expect("Should be a names struct"),
234                            ty: &field.ty,
235                            args: field_args,
236                            netlink_type,
237                            vis: &field.vis,
238                            // drop the "field" attribute
239                            attrs: field
240                                .attrs
241                                .iter()
242                                .filter(|x| x.path().get_ident() != attr.path().get_ident())
243                                .collect(),
244                        });
245                    } else {
246                        return Err(attr.span().error("Missing Netlink Type in field"));
247                    }
248                    continue 'out;
249                }
250            }
251        }
252        identical_fields.push(field);
253    }
254
255    let getters_and_setters = fields.iter().map(|field| {
256        let field_name = field.name;
257        // use the name override if any
258        let field_str = field_name.to_string();
259        let field_str = field
260            .args
261            .override_function_name
262            .as_ref()
263            .map(|x| x.as_str())
264            .unwrap_or(field_str.as_str());
265        let field_type = field.ty;
266
267        let getter_name = format!("get_{}", field_str);
268        let getter_name = Ident::new(&getter_name, field.name.span());
269
270        let muttable_getter_name = format!("get_mut_{}", field_str);
271        let muttable_getter_name = Ident::new(&muttable_getter_name, field.name.span());
272
273        let setter_name = format!("set_{}", field_str);
274        let setter_name = Ident::new(&setter_name, field.name.span());
275
276        let in_place_edit_name = format!("with_{}", field_str);
277        let in_place_edit_name = Ident::new(&in_place_edit_name, field.name.span());
278        quote!(
279            #[allow(dead_code)]
280            impl #name {
281            pub fn #getter_name(&self) -> Option<&#field_type> {
282                self.#field_name.as_ref()
283            }
284
285            pub fn #muttable_getter_name(&mut self) -> Option<&mut #field_type> {
286                self.#field_name.as_mut()
287            }
288
289            pub fn #setter_name(&mut self, val: impl Into<#field_type>) {
290                self.#field_name = Some(val.into());
291            }
292
293            pub fn #in_place_edit_name(mut self, val: impl Into<#field_type>) -> Self {
294                self.#field_name = Some(val.into());
295                self
296            }
297        })
298    });
299
300    let decoder = if args.derive_decoder {
301        let match_entries = fields.iter().map(|field| {
302            let field_name = field.name;
303            let field_type = field.ty;
304            let netlink_value = &field.netlink_type;
305            quote!(
306                x if x == #netlink_value => {
307                    debug!("Calling {}::deserialize()", std::any::type_name::<#field_type>());
308                    let (val, remaining) = <#field_type>::deserialize(buf)?;
309                    if remaining.len() != 0 {
310                        return Err(crate::error::DecodeError::InvalidDataSize);
311                    }
312                    self.#field_name = Some(val);
313                    Ok(())
314                }
315            )
316        });
317        quote!(
318            impl crate::nlmsg::AttributeDecoder for #name {
319                #[allow(dead_code)]
320                fn decode_attribute(&mut self, attr_type: u16, buf: &[u8]) -> Result<(), crate::error::DecodeError> {
321                    use crate::nlmsg::NfNetlinkDeserializable;
322                    debug!("Decoding attribute {} in type {}", attr_type, std::any::type_name::<#name>());
323                    match attr_type {
324                        #(#match_entries),*
325                        _ => Err(crate::error::DecodeError::UnsupportedAttributeType(attr_type)),
326                    }
327                }
328            }
329        )
330    } else {
331        proc_macro2::TokenStream::new()
332    };
333
334    let nfnetlinkattribute_impl = {
335        let size_entries = fields.iter().map(|field| {
336            let field_name = field.name;
337            quote!(
338                if let Some(val) = &self.#field_name {
339                    // Attribute header + attribute value
340                    size += crate::nlmsg::pad_netlink_object::<crate::sys::nlattr>()
341                        + crate::nlmsg::pad_netlink_object_with_variable_size(val.get_size());
342                }
343            )
344        });
345        let write_entries = fields.iter().map(|field| {
346            let field_name = field.name;
347            let field_str = field_name.to_string();
348            let netlink_value = &field.netlink_type;
349            quote!(
350                if let Some(val) = &self.#field_name {
351                    debug!("writing attribute {} - {:?}", #field_str, val);
352
353                    crate::parser::write_attribute(#netlink_value, val, addr);
354
355                    #[allow(unused)]
356                    {
357                        let size = crate::nlmsg::pad_netlink_object::<crate::sys::nlattr>()
358                            + crate::nlmsg::pad_netlink_object_with_variable_size(val.get_size());
359                        addr = &mut addr[size..];
360                    }
361                }
362            )
363        });
364        let nested = args.nested;
365        quote!(
366            impl crate::nlmsg::NfNetlinkAttribute for #name {
367                fn is_nested(&self) -> bool {
368                    #nested
369                }
370
371                fn get_size(&self) -> usize {
372                    use crate::nlmsg::NfNetlinkAttribute;
373
374                    let mut size = 0;
375                    #(#size_entries) *
376                    size
377                }
378
379                fn write_payload(&self, mut addr: &mut [u8]) {
380                    use crate::nlmsg::NfNetlinkAttribute;
381
382                    #(#write_entries) *
383                }
384            }
385        )
386    };
387
388    let vis = &ast.vis;
389    let attrs = ast.attrs;
390    let new_fields = fields.iter().map(|field| {
391        let name = field.name;
392        let ty = field.ty;
393        let attrs = &field.attrs;
394        let vis = &field.vis;
395        quote_spanned!(name.span() => #(#attrs) * #vis #name: Option<#ty>, )
396    });
397    let nfnetlinkdeserialize_impl = if args.derive_deserialize {
398        quote!(
399            impl crate::nlmsg::NfNetlinkDeserializable for #name {
400                fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), crate::error::DecodeError> {
401                    Ok((crate::parser::read_attributes(buf)?, &[]))
402                }
403            }
404        )
405    } else {
406        proc_macro2::TokenStream::new()
407    };
408    let res = quote! {
409        #(#attrs) * #vis struct #name {
410            #(#new_fields)*
411            #(#identical_fields),*
412        }
413
414        #(#getters_and_setters) *
415
416        #decoder
417
418        #nfnetlinkattribute_impl
419
420        #nfnetlinkdeserialize_impl
421    };
422
423    Ok(res.into())
424}
425
426/// `nfnetlink_struct` is a macro wrapping structures that describe nftables objects.
427/// It allows serializing and deserializing these objects to the corresponding nfnetlink
428/// attributes.
429///
430/// It automatically generates getter and setter functions for each netlink properties.
431///
432/// # Parameters
433/// The macro have multiple parameters:
434/// - `nested` (defaults to `false`): the structure is nested (in the netlink sense)
435///   inside its parent structure. This is the case of most structures outside
436///   of the main nftables objects (batches, sets, rules, chains and tables), which are
437///   the outermost structures, and as such cannot be nested.
438/// - `derive_decoder` (defaults to `true`): derive a [`rustables::nlmsg::AttributeDecoder`]
439///   implementation for the structure
440/// - `derive_deserialize` (defaults to `true`): derive a [`rustables::nlmsg::NfNetlinkDeserializable`]
441///   implementation for the structure
442///
443/// # Example use
444/// ```ignore
445/// #[nfnetlink_struct(derive_deserialize = false)]
446/// #[derive(PartialEq, Eq, Default, Debug)]
447/// pub struct Chain {
448///     family: ProtocolFamily,
449///     #[field(NFTA_CHAIN_TABLE)]
450///     table: String,
451///     #[field(NFTA_CHAIN_TYPE, name_in_functions = "type")]
452///     chain_type: ChainType,
453///     #[field(optional = true, crate::sys::NFTA_CHAIN_USERDATA)]
454///     userdata: Vec<u8>,
455///     ...
456/// }
457/// ```
458///
459/// # Type of fields
460/// This contrived example show the two possible type of fields:
461/// - A field that is not converted to a netlink attribute (`family`) because it is not
462///   annotated in `#[field]` attribute.
463///   When deserialized, this field will take the value it is given in the Default implementation
464///   of the struct.
465/// - A field that is annotated with the `#[field]` attribute.
466///   That attribute takes parameters (there are none here), and the netlink attribute type.
467///   When annotated with that attribute, the macro will generate `get_<name>`, `set_<name>` and
468///   `with_<name>` methods to manipulate the attribute (e.g. `get_table`, `set_table` and
469///   `with_table`).
470///   It will also replace the field type (here `String`) with an Option (`Option<String>`)
471///   so the struct may represent objects where that attribute is not set.
472///
473/// # `#[field]` parameters
474/// The `#[field]` attribute can be parametrized through two options:
475/// - `optional` (defaults to `false`): if the netlink attribute type (here `NFTA_CHAIN_USERDATA`)
476///   does not exist, do not generate methods and ignore this attribute if encountered
477///   while deserializing a nftables object.
478///   This is useful for attributes added recently to the kernel, which may not be supported on
479///   older kernels.
480///   Support for an attribute is detected according to the existence of that attribute in the kernel
481///   headers.
482/// - `name_in_functions` (not defined by default): overwrite the `<name`> in the name of the methods
483///   `get_<name>`, `set_<name>` and `with_<name>`.
484///   Here, this means that even though the field is called `chain_type`, users can query it with
485///   the method `get_type` instead of `get_chain_type`.
486#[proc_macro_attribute]
487pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream {
488    match nfnetlink_struct_inner(attrs, item) {
489        Ok(tokens) => tokens.into(),
490        Err(diag) => diag.emit_as_item_tokens().into(),
491    }
492}
493
494struct Variant<'a> {
495    inner: &'a syn::Variant,
496    name: &'a Ident,
497    value: &'a Path,
498}
499
500#[derive(Default)]
501struct EnumArgs {
502    nested: bool,
503    ty: Option<Path>,
504}
505
506fn parse_enum_args(input: TokenStream) -> Result<EnumArgs, Diagnostic> {
507    let mut args = EnumArgs::default();
508    let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
509    let attribute_args = parser
510        .parse(input)
511        .map_err(|e| Diagnostic::new(Level::Error, e.to_string()))?;
512    for arg in attribute_args.iter() {
513        match arg {
514            Meta::Path(path) => {
515                if args.ty.is_none() {
516                    args.ty = Some(path.clone());
517                } else {
518                    return Err(arg
519                        .span()
520                        .error("A value can only have a single representation"));
521                }
522            }
523            Meta::NameValue(namevalue) => {
524                let key = namevalue
525                    .path
526                    .get_ident()
527                    .expect("the macro parameter is not an ident?")
528                    .to_string();
529                match key.as_str() {
530                    "nested" => {
531                        if let Expr::Lit(ExprLit {
532                            lit: Lit::Bool(boolean),
533                            ..
534                        }) = &namevalue.value
535                        {
536                            args.nested = boolean.value;
537                        } else {
538                            return Err(namevalue.value.span().error("Expected a boolean"));
539                        }
540                    }
541                    _ => return Err(arg.span().error("Unsupported macro parameter")),
542                }
543            }
544            _ => return Err(arg.span().error("Unrecognized argument")),
545        }
546    }
547    Ok(args)
548}
549
550fn nfnetlink_enum_inner(attrs: TokenStream, item: TokenStream) -> Result<TokenStream, Diagnostic> {
551    let ast: ItemEnum = parse(item).unwrap();
552    let name = ast.ident;
553
554    let args = match parse_enum_args(attrs) {
555        Ok(x) => x,
556        Err(_) => return Err(Span::call_site().error("Could not parse the macro arguments")),
557    };
558
559    if args.ty.is_none() {
560        return Err(Span::call_site().error("The target type representation is unspecified"));
561    }
562
563    let mut variants = Vec::with_capacity(ast.variants.len());
564
565    for variant in ast.variants.iter() {
566        if variant.discriminant.is_none() {
567            return Err(variant.ident.span().error("Missing value"));
568        }
569        let discriminant = variant.discriminant.as_ref().unwrap();
570        if let syn::Expr::Path(path) = &discriminant.1 {
571            variants.push(Variant {
572                inner: variant,
573                name: &variant.ident,
574                value: &path.path,
575            });
576        } else {
577            return Err(discriminant.1.span().error("Expected a path"));
578        }
579    }
580
581    let repr_type = args.ty.unwrap();
582    let match_entries = variants.iter().map(|variant| {
583        let variant_name = variant.name;
584        let variant_value = &variant.value;
585        quote!( x if x == (#variant_value as #repr_type) => Ok(Self::#variant_name), )
586    });
587    let unknown_type_ident = Ident::new(&format!("Unknown{}", name.to_string()), name.span());
588    let tryfrom_impl = quote!(
589        impl ::core::convert::TryFrom<#repr_type> for #name {
590            type Error = crate::error::DecodeError;
591
592            fn try_from(val: #repr_type) -> Result<Self, Self::Error> {
593                    match val {
594                        #(#match_entries) *
595                        value => Err(crate::error::DecodeError::#unknown_type_ident(value))
596                    }
597            }
598        }
599    );
600    let nfnetlinkdeserialize_impl = quote!(
601        impl crate::nlmsg::NfNetlinkDeserializable for #name {
602            fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), crate::error::DecodeError> {
603                let (v, remaining_data) = #repr_type::deserialize(buf)?;
604                <#name>::try_from(v).map(|x| (x, remaining_data))
605            }
606        }
607    );
608    let vis = &ast.vis;
609    let attrs = ast.attrs;
610    let original_variants = variants.into_iter().map(|x| {
611        let mut inner = x.inner.clone();
612        let discriminant = inner.discriminant.as_mut().unwrap();
613        let cur_value = discriminant.1.clone();
614        let cast_value = Expr::Cast(ExprCast {
615            attrs: vec![],
616            expr: Box::new(cur_value),
617            as_token: Token![as](name.span()),
618            ty: Box::new(Type::Path(TypePath {
619                qself: None,
620                path: repr_type.clone(),
621            })),
622        });
623        discriminant.1 = cast_value;
624        inner
625    });
626    let res = quote! {
627        #[repr(#repr_type)]
628        #(#attrs) * #vis enum #name {
629            #(#original_variants),*
630        }
631
632        impl crate::nlmsg::NfNetlinkAttribute for #name {
633            fn get_size(&self) -> usize {
634                (*self as #repr_type).get_size()
635            }
636
637            fn write_payload(&self, addr: &mut [u8]) {
638                (*self as #repr_type).write_payload(addr);
639            }
640        }
641
642        #tryfrom_impl
643
644        #nfnetlinkdeserialize_impl
645    };
646
647    Ok(res.into())
648}
649
650#[proc_macro_attribute]
651pub fn nfnetlink_enum(attrs: TokenStream, item: TokenStream) -> TokenStream {
652    match nfnetlink_enum_inner(attrs, item) {
653        Ok(tokens) => tokens.into(),
654        Err(diag) => diag.emit_as_item_tokens().into(),
655    }
656}