Skip to main content

rustables_macros/
lib.rs

1#![allow(rustdoc::broken_intra_doc_links)]
2
3use std::fs::File;
4use std::io::Read;
5use std::path::PathBuf;
6
7use proc_macro::TokenStream;
8use proc_macro2::Span;
9use proc_macro2_diagnostics::{Diagnostic, Level, SpanDiagnosticExt};
10use quote::{quote, quote_spanned};
11
12use syn::parse::Parser;
13use syn::punctuated::Punctuated;
14use syn::spanned::Spanned;
15use syn::{
16    Attribute, Expr, ExprCast, ExprLit, Ident, Item, ItemEnum, ItemStruct, Lit, Meta, Path, Token,
17    Type, TypePath, Visibility, parse,
18};
19
20use once_cell::sync::OnceCell;
21
22struct GlobalState {
23    declared_identifiers: Vec<String>,
24}
25
26static STATE: OnceCell<GlobalState> = OnceCell::new();
27
28fn get_state() -> &'static GlobalState {
29    STATE.get_or_init(|| {
30        let sys_file = {
31            // Load the header file and extract the constants defined inside.
32            // This is what determines whether optional attributes (or enum variants)
33            // will be supported or not in the resulting binary.
34            let out_path = PathBuf::from(std::env::var("OUT_DIR").unwrap()).join("sys.rs");
35            let mut sys_file = String::new();
36            File::open(out_path)
37                .expect("Error: could not open the output header file")
38                .read_to_string(&mut sys_file)
39                .expect("Could not read the header file");
40            syn::parse_file(&sys_file).expect("Could not parse the header file")
41        };
42
43        let mut declared_identifiers = Vec::new();
44        for item in sys_file.items {
45            if let Item::Const(v) = item {
46                declared_identifiers.push(v.ident.to_string());
47            }
48        }
49
50        GlobalState {
51            declared_identifiers,
52        }
53    })
54}
55
56struct Field<'a> {
57    name: &'a Ident,
58    ty: &'a Type,
59    args: FieldArgs,
60    netlink_type: Path,
61    vis: &'a Visibility,
62    attrs: Vec<&'a Attribute>,
63}
64
65#[derive(Default)]
66struct FieldArgs {
67    netlink_type: Option<Path>,
68    override_function_name: Option<String>,
69    optional: bool,
70}
71
72fn parse_field_args(input: proc_macro2::TokenStream) -> Result<FieldArgs, Diagnostic> {
73    let mut args = FieldArgs::default();
74    let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
75    let attribute_args = parser
76        .parse2(input)
77        .map_err(|e| Diagnostic::new(Level::Error, e.to_string()))?;
78    for arg in attribute_args.iter() {
79        match arg {
80            Meta::Path(path) => {
81                if args.netlink_type.is_none() {
82                    args.netlink_type = Some(path.clone());
83                } else {
84                    return Err(arg
85                        .span()
86                        .error("Only a single netlink value can exist for a given field"));
87                }
88            }
89            Meta::NameValue(namevalue) => {
90                let key = namevalue
91                    .path
92                    .get_ident()
93                    .expect("the macro parameter is not an ident?")
94                    .to_string();
95                match key.as_str() {
96                    "name_in_functions" => {
97                        if let Expr::Lit(ExprLit {
98                            lit: Lit::Str(val), ..
99                        }) = &namevalue.value
100                        {
101                            args.override_function_name = Some(val.value());
102                        } else {
103                            return Err(namevalue.value.span().error("Expected a string literal"));
104                        }
105                    }
106                    "optional" => {
107                        if let Expr::Lit(ExprLit {
108                            lit: Lit::Bool(boolean),
109                            ..
110                        }) = &namevalue.value
111                        {
112                            args.optional = boolean.value;
113                        } else {
114                            return Err(namevalue.value.span().error("Expected a boolean"));
115                        }
116                    }
117                    _ => return Err(arg.span().error("Unsupported macro parameter")),
118                }
119            }
120            _ => return Err(arg.span().error("Unrecognized argument")),
121        }
122    }
123    Ok(args)
124}
125
/// Arguments accepted by `#[nfnetlink_struct(...)]`.
struct StructArgs {
    /// `nested = <bool>`: the value returned by the generated `is_nested()`.
    nested: bool,
    /// `derive_decoder = <bool>`: whether to emit an `AttributeDecoder` implementation.
    derive_decoder: bool,
    /// `derive_deserialize = <bool>`: whether to emit an `NfNetlinkDeserializable` implementation.
    derive_deserialize: bool,
}

impl Default for StructArgs {
    /// Not nested, and both the decoder and the deserializer are derived.
    fn default() -> Self {
        let derive_everything = true;
        StructArgs {
            derive_decoder: derive_everything,
            derive_deserialize: derive_everything,
            nested: false,
        }
    }
}
141
142fn parse_struct_args(input: TokenStream) -> Result<StructArgs, Diagnostic> {
143    let mut args = StructArgs::default();
144    let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
145    let attribute_args = parser
146        .parse(input.clone())
147        .map_err(|e| Diagnostic::new(Level::Error, e.to_string()))?;
148    for arg in attribute_args.iter() {
149        if let Meta::NameValue(namevalue) = arg {
150            let key = namevalue
151                .path
152                .get_ident()
153                .expect("the macro parameter is not an ident?")
154                .to_string();
155            if let Expr::Lit(ExprLit {
156                lit: Lit::Bool(boolean),
157                ..
158            }) = &namevalue.value
159            {
160                match key.as_str() {
161                    "derive_decoder" => {
162                        args.derive_decoder = boolean.value;
163                    }
164                    "nested" => {
165                        args.nested = boolean.value;
166                    }
167                    "derive_deserialize" => {
168                        args.derive_deserialize = boolean.value;
169                    }
170                    _ => return Err(arg.span().error("Unsupported macro parameter")),
171                }
172            } else {
173                return Err(namevalue.value.span().error("Expected a boolean"));
174            }
175        } else {
176            return Err(arg.span().error("Unrecognized argument"));
177        }
178    }
179    Ok(args)
180}
181
182fn nfnetlink_struct_inner(
183    attrs: TokenStream,
184    item: TokenStream,
185) -> Result<TokenStream, Diagnostic> {
186    let ast: ItemStruct = parse(item).unwrap();
187    let name = ast.ident;
188
189    let args = parse_struct_args(attrs)?;
190
191    let state = get_state();
192
193    let mut fields = Vec::with_capacity(ast.fields.len());
194    let mut identical_fields = Vec::new();
195
196    'out: for field in ast.fields.iter() {
197        for attr in field.attrs.iter() {
198            if let Some(id) = attr.path().get_ident()
199                && id == "field"
200            {
201                let field_args = match &attr.meta {
202                    Meta::List(l) => l,
203                    _ => {
204                        return Err(attr.span().error("Invalid attributes"));
205                    }
206                };
207
208                let field_args = match parse_field_args(field_args.tokens.clone()) {
209                    Ok(x) => x,
210                    Err(_) => {
211                        return Err(attr.span().error("Could not parse the field attributes"));
212                    }
213                };
214                if let Some(netlink_type) = field_args.netlink_type.clone() {
215                    // optional fields are not generated when the kernel version you have on
216                    // the system does not support that field
217                    if field_args.optional {
218                        let netlink_type_ident = netlink_type
219                            .segments
220                            .last()
221                            .expect("empty path?")
222                            .ident
223                            .to_string();
224                        if !state.declared_identifiers.contains(&netlink_type_ident) {
225                            // reject the optional identifier
226                            continue 'out;
227                        }
228                    }
229
230                    fields.push(Field {
231                        name: field.ident.as_ref().expect("Should be a names struct"),
232                        ty: &field.ty,
233                        args: field_args,
234                        netlink_type,
235                        vis: &field.vis,
236                        // drop the "field" attribute
237                        attrs: field
238                            .attrs
239                            .iter()
240                            .filter(|x| x.path().get_ident() != attr.path().get_ident())
241                            .collect(),
242                    });
243                } else {
244                    return Err(attr.span().error("Missing Netlink Type in field"));
245                }
246                continue 'out;
247            }
248        }
249        identical_fields.push(field);
250    }
251
252    let getters_and_setters = fields.iter().map(|field| {
253        let field_name = field.name;
254        // use the name override if any
255        let field_str = field_name.to_string();
256        let field_str = field
257            .args
258            .override_function_name
259            .as_deref()
260            .unwrap_or(field_str.as_str());
261        let field_type = field.ty;
262
263        let getter_name = format!("get_{}", field_str);
264        let getter_name = Ident::new(&getter_name, field.name.span());
265
266        let muttable_getter_name = format!("get_mut_{}", field_str);
267        let muttable_getter_name = Ident::new(&muttable_getter_name, field.name.span());
268
269        let setter_name = format!("set_{}", field_str);
270        let setter_name = Ident::new(&setter_name, field.name.span());
271
272        let in_place_edit_name = format!("with_{}", field_str);
273        let in_place_edit_name = Ident::new(&in_place_edit_name, field.name.span());
274        quote!(
275            #[allow(dead_code)]
276            impl #name {
277            pub fn #getter_name(&self) -> Option<&#field_type> {
278                self.#field_name.as_ref()
279            }
280
281            pub fn #muttable_getter_name(&mut self) -> Option<&mut #field_type> {
282                self.#field_name.as_mut()
283            }
284
285            pub fn #setter_name(&mut self, val: impl Into<#field_type>) {
286                self.#field_name = Some(val.into());
287            }
288
289            pub fn #in_place_edit_name(mut self, val: impl Into<#field_type>) -> Self {
290                self.#field_name = Some(val.into());
291                self
292            }
293        })
294    });
295
296    let decoder = if args.derive_decoder {
297        let match_entries = fields.iter().map(|field| {
298            let field_name = field.name;
299            let field_type = field.ty;
300            let netlink_value = &field.netlink_type;
301            quote!(
302                x if x == #netlink_value => {
303                    debug!("Calling {}::deserialize()", std::any::type_name::<#field_type>());
304                    let (val, remaining) = <#field_type>::deserialize(buf)?;
305                    if remaining.len() != 0 {
306                        return Err(crate::error::DecodeError::InvalidDataSize);
307                    }
308                    self.#field_name = Some(val);
309                    Ok(())
310                }
311            )
312        });
313        quote!(
314            impl crate::nlmsg::AttributeDecoder for #name {
315                #[allow(dead_code)]
316                fn decode_attribute(&mut self, attr_type: u16, buf: &[u8]) -> Result<(), crate::error::DecodeError> {
317                    use crate::nlmsg::NfNetlinkDeserializable;
318                    debug!("Decoding attribute {} in type {}", attr_type, std::any::type_name::<#name>());
319                    match attr_type {
320                        #(#match_entries),*
321                        _ => Err(crate::error::DecodeError::UnsupportedAttributeType(attr_type)),
322                    }
323                }
324            }
325        )
326    } else {
327        proc_macro2::TokenStream::new()
328    };
329
330    let nfnetlinkattribute_impl = {
331        let size_entries = fields.iter().map(|field| {
332            let field_name = field.name;
333            quote!(
334                if let Some(val) = &self.#field_name {
335                    // Attribute header + attribute value
336                    size += crate::nlmsg::pad_netlink_object::<crate::sys::nlattr>()
337                        + crate::nlmsg::pad_netlink_object_with_variable_size(val.get_size());
338                }
339            )
340        });
341        let write_entries = fields.iter().map(|field| {
342            let field_name = field.name;
343            let field_str = field_name.to_string();
344            let netlink_value = &field.netlink_type;
345            quote!(
346                if let Some(val) = &self.#field_name {
347                    debug!("writing attribute {} - {:?}", #field_str, val);
348
349                    crate::parser::write_attribute(#netlink_value, val, addr);
350
351                    #[allow(unused)]
352                    {
353                        let size = crate::nlmsg::pad_netlink_object::<crate::sys::nlattr>()
354                            + crate::nlmsg::pad_netlink_object_with_variable_size(val.get_size());
355                        addr = &mut addr[size..];
356                    }
357                }
358            )
359        });
360        let nested = args.nested;
361        quote!(
362            impl crate::nlmsg::NfNetlinkAttribute for #name {
363                fn is_nested(&self) -> bool {
364                    #nested
365                }
366
367                fn get_size(&self) -> usize {
368                    use crate::nlmsg::NfNetlinkAttribute;
369
370                    let mut size = 0;
371                    #(#size_entries) *
372                    size
373                }
374
375                fn write_payload(&self, mut addr: &mut [u8]) {
376                    use crate::nlmsg::NfNetlinkAttribute;
377
378                    #(#write_entries) *
379                }
380            }
381        )
382    };
383
384    let vis = &ast.vis;
385    let attrs = ast.attrs;
386    let new_fields = fields.iter().map(|field| {
387        let name = field.name;
388        let ty = field.ty;
389        let attrs = &field.attrs;
390        let vis = &field.vis;
391        quote_spanned!(name.span() => #(#attrs) * #vis #name: Option<#ty>, )
392    });
393    let nfnetlinkdeserialize_impl = if args.derive_deserialize {
394        quote!(
395            impl crate::nlmsg::NfNetlinkDeserializable for #name {
396                fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), crate::error::DecodeError> {
397                    Ok((crate::parser::read_attributes(buf)?, &[]))
398                }
399            }
400        )
401    } else {
402        proc_macro2::TokenStream::new()
403    };
404    let res = quote! {
405        #(#attrs) * #vis struct #name {
406            #(#new_fields)*
407            #(#identical_fields),*
408        }
409
410        #(#getters_and_setters) *
411
412        #decoder
413
414        #nfnetlinkattribute_impl
415
416        #nfnetlinkdeserialize_impl
417    };
418
419    Ok(res.into())
420}
421
422/// `nfnetlink_struct` is a macro wrapping structures that describe nftables objects.
423/// It allows serializing and deserializing these objects to the corresponding nfnetlink
424/// attributes.
425///
426/// It automatically generates getter and setter functions for each netlink properties.
427///
428/// # Parameters
429/// The macro have multiple parameters:
430/// - `nested` (defaults to `false`): the structure is nested (in the netlink sense)
431///   inside its parent structure. This is the case of most structures outside
432///   of the main nftables objects (batches, sets, rules, chains and tables), which are
433///   the outermost structures, and as such cannot be nested.
434/// - `derive_decoder` (defaults to `true`): derive a [`rustables::nlmsg::AttributeDecoder`]
435///   implementation for the structure
436/// - `derive_deserialize` (defaults to `true`): derive a [`rustables::nlmsg::NfNetlinkDeserializable`]
437///   implementation for the structure
438///
439/// # Example use
440/// ```ignore
441/// #[nfnetlink_struct(derive_deserialize = false)]
442/// #[derive(PartialEq, Eq, Default, Debug)]
443/// pub struct Chain {
444///     family: ProtocolFamily,
445///     #[field(NFTA_CHAIN_TABLE)]
446///     table: String,
447///     #[field(NFTA_CHAIN_TYPE, name_in_functions = "type")]
448///     chain_type: ChainType,
449///     #[field(optional = true, crate::sys::NFTA_CHAIN_USERDATA)]
450///     userdata: Vec<u8>,
451///     ...
452/// }
453/// ```
454///
455/// # Type of fields
456/// This contrived example show the two possible type of fields:
457/// - A field that is not converted to a netlink attribute (`family`) because it is not
458///   annotated in `#[field]` attribute.
459///   When deserialized, this field will take the value it is given in the Default implementation
460///   of the struct.
461/// - A field that is annotated with the `#[field]` attribute.
462///   That attribute takes parameters (there are none here), and the netlink attribute type.
463///   When annotated with that attribute, the macro will generate `get_<name>`, `set_<name>` and
464///   `with_<name>` methods to manipulate the attribute (e.g. `get_table`, `set_table` and
465///   `with_table`).
466///   It will also replace the field type (here `String`) with an Option (`Option<String>`)
467///   so the struct may represent objects where that attribute is not set.
468///
469/// # `#[field]` parameters
470/// The `#[field]` attribute can be parametrized through two options:
471/// - `optional` (defaults to `false`): if the netlink attribute type (here `NFTA_CHAIN_USERDATA`)
472///   does not exist, do not generate methods and ignore this attribute if encountered
473///   while deserializing a nftables object.
474///   This is useful for attributes added recently to the kernel, which may not be supported on
475///   older kernels.
476///   Support for an attribute is detected according to the existence of that attribute in the kernel
477///   headers.
478/// - `name_in_functions` (not defined by default): overwrite the `<name`> in the name of the methods
479///   `get_<name>`, `set_<name>` and `with_<name>`.
480///   Here, this means that even though the field is called `chain_type`, users can query it with
481///   the method `get_type` instead of `get_chain_type`.
482#[proc_macro_attribute]
483pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream {
484    match nfnetlink_struct_inner(attrs, item) {
485        Ok(tokens) => tokens,
486        Err(diag) => diag.emit_as_item_tokens().into(),
487    }
488}
489
490struct Variant<'a> {
491    inner: &'a syn::Variant,
492    name: &'a Ident,
493    value: &'a Path,
494}
495
496#[derive(Default)]
497struct EnumArgs {
498    nested: bool,
499    ty: Option<Path>,
500}
501
502fn parse_enum_args(input: TokenStream) -> Result<EnumArgs, Diagnostic> {
503    let mut args = EnumArgs::default();
504    let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
505    let attribute_args = parser
506        .parse(input)
507        .map_err(|e| Diagnostic::new(Level::Error, e.to_string()))?;
508    for arg in attribute_args.iter() {
509        match arg {
510            Meta::Path(path) => {
511                if args.ty.is_none() {
512                    args.ty = Some(path.clone());
513                } else {
514                    return Err(arg
515                        .span()
516                        .error("A value can only have a single representation"));
517                }
518            }
519            Meta::NameValue(namevalue) => {
520                let key = namevalue
521                    .path
522                    .get_ident()
523                    .expect("the macro parameter is not an ident?")
524                    .to_string();
525                match key.as_str() {
526                    "nested" => {
527                        if let Expr::Lit(ExprLit {
528                            lit: Lit::Bool(boolean),
529                            ..
530                        }) = &namevalue.value
531                        {
532                            args.nested = boolean.value;
533                        } else {
534                            return Err(namevalue.value.span().error("Expected a boolean"));
535                        }
536                    }
537                    _ => return Err(arg.span().error("Unsupported macro parameter")),
538                }
539            }
540            _ => return Err(arg.span().error("Unrecognized argument")),
541        }
542    }
543    Ok(args)
544}
545
546fn nfnetlink_enum_inner(attrs: TokenStream, item: TokenStream) -> Result<TokenStream, Diagnostic> {
547    let ast: ItemEnum = parse(item).unwrap();
548    let name = ast.ident;
549
550    let args = match parse_enum_args(attrs) {
551        Ok(x) => x,
552        Err(_) => return Err(Span::call_site().error("Could not parse the macro arguments")),
553    };
554
555    if args.ty.is_none() {
556        return Err(Span::call_site().error("The target type representation is unspecified"));
557    }
558
559    let mut variants = Vec::with_capacity(ast.variants.len());
560
561    for variant in ast.variants.iter() {
562        if variant.discriminant.is_none() {
563            return Err(variant.ident.span().error("Missing value"));
564        }
565        let discriminant = variant.discriminant.as_ref().unwrap();
566        if let syn::Expr::Path(path) = &discriminant.1 {
567            variants.push(Variant {
568                inner: variant,
569                name: &variant.ident,
570                value: &path.path,
571            });
572        } else {
573            return Err(discriminant.1.span().error("Expected a path"));
574        }
575    }
576
577    let repr_type = args.ty.unwrap();
578    let match_entries = variants.iter().map(|variant| {
579        let variant_name = variant.name;
580        let variant_value = &variant.value;
581        quote!( x if x == (#variant_value as #repr_type) => Ok(Self::#variant_name), )
582    });
583    #[allow(clippy::to_string_in_format_args)]
584    let unknown_type_ident = Ident::new(&format!("Unknown{}", name.to_string()), name.span());
585    let tryfrom_impl = quote!(
586        impl ::core::convert::TryFrom<#repr_type> for #name {
587            type Error = crate::error::DecodeError;
588
589            #[allow(clippy::unnecessary_cast)]
590            fn try_from(val: #repr_type) -> Result<Self, Self::Error> {
591                    match val {
592                        #(#match_entries) *
593                        value => Err(crate::error::DecodeError::#unknown_type_ident(value))
594                    }
595            }
596        }
597    );
598    let nfnetlinkdeserialize_impl = quote!(
599        impl crate::nlmsg::NfNetlinkDeserializable for #name {
600            fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), crate::error::DecodeError> {
601                let (v, remaining_data) = #repr_type::deserialize(buf)?;
602                <#name>::try_from(v).map(|x| (x, remaining_data))
603            }
604        }
605    );
606    let vis = &ast.vis;
607    let attrs = ast.attrs;
608    let original_variants = variants.into_iter().map(|x| {
609        let mut inner = x.inner.clone();
610        let discriminant = inner.discriminant.as_mut().unwrap();
611        let cur_value = discriminant.1.clone();
612        let cast_value = Expr::Cast(ExprCast {
613            attrs: vec![],
614            expr: Box::new(cur_value),
615            as_token: Token![as](name.span()),
616            ty: Box::new(Type::Path(TypePath {
617                qself: None,
618                path: repr_type.clone(),
619            })),
620        });
621        discriminant.1 = cast_value;
622        inner
623    });
624    let res = quote! {
625        #[repr(#repr_type)]
626        #[allow(clippy::unnecessary_cast)]
627        #(#attrs) * #vis enum #name {
628            #(#original_variants),*
629        }
630
631        impl crate::nlmsg::NfNetlinkAttribute for #name {
632            fn get_size(&self) -> usize {
633                (*self as #repr_type).get_size()
634            }
635
636            fn write_payload(&self, addr: &mut [u8]) {
637                (*self as #repr_type).write_payload(addr);
638            }
639        }
640
641        #tryfrom_impl
642
643        #nfnetlinkdeserialize_impl
644    };
645
646    Ok(res.into())
647}
648
649#[proc_macro_attribute]
650pub fn nfnetlink_enum(attrs: TokenStream, item: TokenStream) -> TokenStream {
651    match nfnetlink_enum_inner(attrs, item) {
652        Ok(tokens) => tokens,
653        Err(diag) => diag.emit_as_item_tokens().into(),
654    }
655}