1#![allow(rustdoc::broken_intra_doc_links)]
2
3use std::fs::File;
4use std::io::Read;
5use std::path::PathBuf;
6
7use proc_macro::TokenStream;
8use proc_macro2::Span;
9use proc_macro2_diagnostics::{Diagnostic, Level, SpanDiagnosticExt};
10use quote::{quote, quote_spanned};
11
12use syn::parse::Parser;
13use syn::punctuated::Punctuated;
14use syn::spanned::Spanned;
15use syn::{
16 parse, Attribute, Expr, ExprCast, ExprLit, Ident, Item, ItemEnum, ItemStruct, Lit, Meta, Path,
17 Token, Type, TypePath, Visibility,
18};
19
20use once_cell::sync::OnceCell;
21
/// Information extracted once from the generated `sys.rs` bindings and shared by every
/// macro expansion of the current compilation.
struct GlobalState {
    /// Name of every top-level `const` item declared in `sys.rs`; consulted to decide whether
    /// an `optional` field's netlink constant exists in the bindings.
    declared_identifiers: Vec<String>,
}
25
26static STATE: OnceCell<GlobalState> = OnceCell::new();
27
28fn get_state() -> &'static GlobalState {
29 STATE.get_or_init(|| {
30 let sys_file = {
31 let out_path = PathBuf::from(std::env::var("OUT_DIR").unwrap()).join("sys.rs");
35 let mut sys_file = String::new();
36 File::open(out_path)
37 .expect("Error: could not open the output header file")
38 .read_to_string(&mut sys_file)
39 .expect("Could not read the header file");
40 syn::parse_file(&sys_file).expect("Could not parse the header file")
41 };
42
43 let mut declared_identifiers = Vec::new();
44 for item in sys_file.items {
45 if let Item::Const(v) = item {
46 declared_identifiers.push(v.ident.to_string());
47 }
48 }
49
50 GlobalState {
51 declared_identifiers,
52 }
53 })
54}
55
56struct Field<'a> {
57 name: &'a Ident,
58 ty: &'a Type,
59 args: FieldArgs,
60 netlink_type: Path,
61 vis: &'a Visibility,
62 attrs: Vec<&'a Attribute>,
63}
64
65#[derive(Default)]
66struct FieldArgs {
67 netlink_type: Option<Path>,
68 override_function_name: Option<String>,
69 optional: bool,
70}
71
72fn parse_field_args(input: proc_macro2::TokenStream) -> Result<FieldArgs, Diagnostic> {
73 let mut args = FieldArgs::default();
74 let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
75 let attribute_args = parser
76 .parse2(input)
77 .map_err(|e| Diagnostic::new(Level::Error, e.to_string()))?;
78 for arg in attribute_args.iter() {
79 match arg {
80 Meta::Path(path) => {
81 if args.netlink_type.is_none() {
82 args.netlink_type = Some(path.clone());
83 } else {
84 return Err(arg
85 .span()
86 .error("Only a single netlink value can exist for a given field"));
87 }
88 }
89 Meta::NameValue(namevalue) => {
90 let key = namevalue
91 .path
92 .get_ident()
93 .expect("the macro parameter is not an ident?")
94 .to_string();
95 match key.as_str() {
96 "name_in_functions" => {
97 if let Expr::Lit(ExprLit {
98 lit: Lit::Str(val), ..
99 }) = &namevalue.value
100 {
101 args.override_function_name = Some(val.value());
102 } else {
103 return Err(namevalue.value.span().error("Expected a string literal"));
104 }
105 }
106 "optional" => {
107 if let Expr::Lit(ExprLit {
108 lit: Lit::Bool(boolean),
109 ..
110 }) = &namevalue.value
111 {
112 args.optional = boolean.value;
113 } else {
114 return Err(namevalue.value.span().error("Expected a boolean"));
115 }
116 }
117 _ => return Err(arg.span().error("Unsupported macro parameter")),
118 }
119 }
120 _ => return Err(arg.span().error("Unrecognized argument")),
121 }
122 }
123 Ok(args)
124}
125
/// Arguments accepted by `#[nfnetlink_struct(...)]`, each written as `key = <bool>`.
struct StructArgs {
    /// `derive_decoder` (default `true`): emit an `AttributeDecoder` impl.
    derive_decoder: bool,
    /// `derive_deserialize` (default `true`): emit an `NfNetlinkDeserializable` impl.
    derive_deserialize: bool,
    /// `nested` (default `false`): the value returned by the generated `is_nested()`.
    nested: bool,
}

impl Default for StructArgs {
    /// Both derives enabled, not nested.
    fn default() -> Self {
        StructArgs {
            derive_decoder: true,
            derive_deserialize: true,
            nested: false,
        }
    }
}
141
142fn parse_struct_args(input: TokenStream) -> Result<StructArgs, Diagnostic> {
143 let mut args = StructArgs::default();
144 let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
145 let attribute_args = parser
146 .parse(input.clone())
147 .map_err(|e| Diagnostic::new(Level::Error, e.to_string()))?;
148 for arg in attribute_args.iter() {
149 if let Meta::NameValue(namevalue) = arg {
150 let key = namevalue
151 .path
152 .get_ident()
153 .expect("the macro parameter is not an ident?")
154 .to_string();
155 if let Expr::Lit(ExprLit {
156 lit: Lit::Bool(boolean),
157 ..
158 }) = &namevalue.value
159 {
160 match key.as_str() {
161 "derive_decoder" => {
162 args.derive_decoder = boolean.value;
163 }
164 "nested" => {
165 args.nested = boolean.value;
166 }
167 "derive_deserialize" => {
168 args.derive_deserialize = boolean.value;
169 }
170 _ => return Err(arg.span().error("Unsupported macro parameter")),
171 }
172 } else {
173 return Err(namevalue.value.span().error("Expected a boolean"));
174 }
175 } else {
176 return Err(arg.span().error("Unrecognized argument"));
177 }
178 }
179 Ok(args)
180}
181
182fn nfnetlink_struct_inner(
183 attrs: TokenStream,
184 item: TokenStream,
185) -> Result<TokenStream, Diagnostic> {
186 let ast: ItemStruct = parse(item).unwrap();
187 let name = ast.ident;
188
189 let args = match parse_struct_args(attrs) {
190 Ok(x) => x,
191 Err(e) => return Err(e),
192 };
193
194 let state = get_state();
195
196 let mut fields = Vec::with_capacity(ast.fields.len());
197 let mut identical_fields = Vec::new();
198
199 'out: for field in ast.fields.iter() {
200 for attr in field.attrs.iter() {
201 if let Some(id) = attr.path().get_ident() {
202 if id == "field" {
203 let field_args = match &attr.meta {
204 Meta::List(l) => l,
205 _ => {
206 return Err(attr.span().error("Invalid attributes"));
207 }
208 };
209
210 let field_args = match parse_field_args(field_args.tokens.clone()) {
211 Ok(x) => x,
212 Err(_) => {
213 return Err(attr.span().error("Could not parse the field attributes"));
214 }
215 };
216 if let Some(netlink_type) = field_args.netlink_type.clone() {
217 if field_args.optional {
220 let netlink_type_ident = netlink_type
221 .segments
222 .last()
223 .expect("empty path?")
224 .ident
225 .to_string();
226 if !state.declared_identifiers.contains(&netlink_type_ident) {
227 continue 'out;
229 }
230 }
231
232 fields.push(Field {
233 name: field.ident.as_ref().expect("Should be a names struct"),
234 ty: &field.ty,
235 args: field_args,
236 netlink_type,
237 vis: &field.vis,
238 attrs: field
240 .attrs
241 .iter()
242 .filter(|x| x.path().get_ident() != attr.path().get_ident())
243 .collect(),
244 });
245 } else {
246 return Err(attr.span().error("Missing Netlink Type in field"));
247 }
248 continue 'out;
249 }
250 }
251 }
252 identical_fields.push(field);
253 }
254
255 let getters_and_setters = fields.iter().map(|field| {
256 let field_name = field.name;
257 let field_str = field_name.to_string();
259 let field_str = field
260 .args
261 .override_function_name
262 .as_ref()
263 .map(|x| x.as_str())
264 .unwrap_or(field_str.as_str());
265 let field_type = field.ty;
266
267 let getter_name = format!("get_{}", field_str);
268 let getter_name = Ident::new(&getter_name, field.name.span());
269
270 let muttable_getter_name = format!("get_mut_{}", field_str);
271 let muttable_getter_name = Ident::new(&muttable_getter_name, field.name.span());
272
273 let setter_name = format!("set_{}", field_str);
274 let setter_name = Ident::new(&setter_name, field.name.span());
275
276 let in_place_edit_name = format!("with_{}", field_str);
277 let in_place_edit_name = Ident::new(&in_place_edit_name, field.name.span());
278 quote!(
279 #[allow(dead_code)]
280 impl #name {
281 pub fn #getter_name(&self) -> Option<&#field_type> {
282 self.#field_name.as_ref()
283 }
284
285 pub fn #muttable_getter_name(&mut self) -> Option<&mut #field_type> {
286 self.#field_name.as_mut()
287 }
288
289 pub fn #setter_name(&mut self, val: impl Into<#field_type>) {
290 self.#field_name = Some(val.into());
291 }
292
293 pub fn #in_place_edit_name(mut self, val: impl Into<#field_type>) -> Self {
294 self.#field_name = Some(val.into());
295 self
296 }
297 })
298 });
299
300 let decoder = if args.derive_decoder {
301 let match_entries = fields.iter().map(|field| {
302 let field_name = field.name;
303 let field_type = field.ty;
304 let netlink_value = &field.netlink_type;
305 quote!(
306 x if x == #netlink_value => {
307 debug!("Calling {}::deserialize()", std::any::type_name::<#field_type>());
308 let (val, remaining) = <#field_type>::deserialize(buf)?;
309 if remaining.len() != 0 {
310 return Err(crate::error::DecodeError::InvalidDataSize);
311 }
312 self.#field_name = Some(val);
313 Ok(())
314 }
315 )
316 });
317 quote!(
318 impl crate::nlmsg::AttributeDecoder for #name {
319 #[allow(dead_code)]
320 fn decode_attribute(&mut self, attr_type: u16, buf: &[u8]) -> Result<(), crate::error::DecodeError> {
321 use crate::nlmsg::NfNetlinkDeserializable;
322 debug!("Decoding attribute {} in type {}", attr_type, std::any::type_name::<#name>());
323 match attr_type {
324 #(#match_entries),*
325 _ => Err(crate::error::DecodeError::UnsupportedAttributeType(attr_type)),
326 }
327 }
328 }
329 )
330 } else {
331 proc_macro2::TokenStream::new()
332 };
333
334 let nfnetlinkattribute_impl = {
335 let size_entries = fields.iter().map(|field| {
336 let field_name = field.name;
337 quote!(
338 if let Some(val) = &self.#field_name {
339 size += crate::nlmsg::pad_netlink_object::<crate::sys::nlattr>()
341 + crate::nlmsg::pad_netlink_object_with_variable_size(val.get_size());
342 }
343 )
344 });
345 let write_entries = fields.iter().map(|field| {
346 let field_name = field.name;
347 let field_str = field_name.to_string();
348 let netlink_value = &field.netlink_type;
349 quote!(
350 if let Some(val) = &self.#field_name {
351 debug!("writing attribute {} - {:?}", #field_str, val);
352
353 crate::parser::write_attribute(#netlink_value, val, addr);
354
355 #[allow(unused)]
356 {
357 let size = crate::nlmsg::pad_netlink_object::<crate::sys::nlattr>()
358 + crate::nlmsg::pad_netlink_object_with_variable_size(val.get_size());
359 addr = &mut addr[size..];
360 }
361 }
362 )
363 });
364 let nested = args.nested;
365 quote!(
366 impl crate::nlmsg::NfNetlinkAttribute for #name {
367 fn is_nested(&self) -> bool {
368 #nested
369 }
370
371 fn get_size(&self) -> usize {
372 use crate::nlmsg::NfNetlinkAttribute;
373
374 let mut size = 0;
375 #(#size_entries) *
376 size
377 }
378
379 fn write_payload(&self, mut addr: &mut [u8]) {
380 use crate::nlmsg::NfNetlinkAttribute;
381
382 #(#write_entries) *
383 }
384 }
385 )
386 };
387
388 let vis = &ast.vis;
389 let attrs = ast.attrs;
390 let new_fields = fields.iter().map(|field| {
391 let name = field.name;
392 let ty = field.ty;
393 let attrs = &field.attrs;
394 let vis = &field.vis;
395 quote_spanned!(name.span() => #(#attrs) * #vis #name: Option<#ty>, )
396 });
397 let nfnetlinkdeserialize_impl = if args.derive_deserialize {
398 quote!(
399 impl crate::nlmsg::NfNetlinkDeserializable for #name {
400 fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), crate::error::DecodeError> {
401 Ok((crate::parser::read_attributes(buf)?, &[]))
402 }
403 }
404 )
405 } else {
406 proc_macro2::TokenStream::new()
407 };
408 let res = quote! {
409 #(#attrs) * #vis struct #name {
410 #(#new_fields)*
411 #(#identical_fields),*
412 }
413
414 #(#getters_and_setters) *
415
416 #decoder
417
418 #nfnetlinkattribute_impl
419
420 #nfnetlinkdeserialize_impl
421 };
422
423 Ok(res.into())
424}
425
426#[proc_macro_attribute]
487pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream {
488 match nfnetlink_struct_inner(attrs, item) {
489 Ok(tokens) => tokens.into(),
490 Err(diag) => diag.emit_as_item_tokens().into(),
491 }
492}
493
494struct Variant<'a> {
495 inner: &'a syn::Variant,
496 name: &'a Ident,
497 value: &'a Path,
498}
499
500#[derive(Default)]
501struct EnumArgs {
502 nested: bool,
503 ty: Option<Path>,
504}
505
506fn parse_enum_args(input: TokenStream) -> Result<EnumArgs, Diagnostic> {
507 let mut args = EnumArgs::default();
508 let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
509 let attribute_args = parser
510 .parse(input)
511 .map_err(|e| Diagnostic::new(Level::Error, e.to_string()))?;
512 for arg in attribute_args.iter() {
513 match arg {
514 Meta::Path(path) => {
515 if args.ty.is_none() {
516 args.ty = Some(path.clone());
517 } else {
518 return Err(arg
519 .span()
520 .error("A value can only have a single representation"));
521 }
522 }
523 Meta::NameValue(namevalue) => {
524 let key = namevalue
525 .path
526 .get_ident()
527 .expect("the macro parameter is not an ident?")
528 .to_string();
529 match key.as_str() {
530 "nested" => {
531 if let Expr::Lit(ExprLit {
532 lit: Lit::Bool(boolean),
533 ..
534 }) = &namevalue.value
535 {
536 args.nested = boolean.value;
537 } else {
538 return Err(namevalue.value.span().error("Expected a boolean"));
539 }
540 }
541 _ => return Err(arg.span().error("Unsupported macro parameter")),
542 }
543 }
544 _ => return Err(arg.span().error("Unrecognized argument")),
545 }
546 }
547 Ok(args)
548}
549
550fn nfnetlink_enum_inner(attrs: TokenStream, item: TokenStream) -> Result<TokenStream, Diagnostic> {
551 let ast: ItemEnum = parse(item).unwrap();
552 let name = ast.ident;
553
554 let args = match parse_enum_args(attrs) {
555 Ok(x) => x,
556 Err(_) => return Err(Span::call_site().error("Could not parse the macro arguments")),
557 };
558
559 if args.ty.is_none() {
560 return Err(Span::call_site().error("The target type representation is unspecified"));
561 }
562
563 let mut variants = Vec::with_capacity(ast.variants.len());
564
565 for variant in ast.variants.iter() {
566 if variant.discriminant.is_none() {
567 return Err(variant.ident.span().error("Missing value"));
568 }
569 let discriminant = variant.discriminant.as_ref().unwrap();
570 if let syn::Expr::Path(path) = &discriminant.1 {
571 variants.push(Variant {
572 inner: variant,
573 name: &variant.ident,
574 value: &path.path,
575 });
576 } else {
577 return Err(discriminant.1.span().error("Expected a path"));
578 }
579 }
580
581 let repr_type = args.ty.unwrap();
582 let match_entries = variants.iter().map(|variant| {
583 let variant_name = variant.name;
584 let variant_value = &variant.value;
585 quote!( x if x == (#variant_value as #repr_type) => Ok(Self::#variant_name), )
586 });
587 let unknown_type_ident = Ident::new(&format!("Unknown{}", name.to_string()), name.span());
588 let tryfrom_impl = quote!(
589 impl ::core::convert::TryFrom<#repr_type> for #name {
590 type Error = crate::error::DecodeError;
591
592 fn try_from(val: #repr_type) -> Result<Self, Self::Error> {
593 match val {
594 #(#match_entries) *
595 value => Err(crate::error::DecodeError::#unknown_type_ident(value))
596 }
597 }
598 }
599 );
600 let nfnetlinkdeserialize_impl = quote!(
601 impl crate::nlmsg::NfNetlinkDeserializable for #name {
602 fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), crate::error::DecodeError> {
603 let (v, remaining_data) = #repr_type::deserialize(buf)?;
604 <#name>::try_from(v).map(|x| (x, remaining_data))
605 }
606 }
607 );
608 let vis = &ast.vis;
609 let attrs = ast.attrs;
610 let original_variants = variants.into_iter().map(|x| {
611 let mut inner = x.inner.clone();
612 let discriminant = inner.discriminant.as_mut().unwrap();
613 let cur_value = discriminant.1.clone();
614 let cast_value = Expr::Cast(ExprCast {
615 attrs: vec![],
616 expr: Box::new(cur_value),
617 as_token: Token),
618 ty: Box::new(Type::Path(TypePath {
619 qself: None,
620 path: repr_type.clone(),
621 })),
622 });
623 discriminant.1 = cast_value;
624 inner
625 });
626 let res = quote! {
627 #[repr(#repr_type)]
628 #(#attrs) * #vis enum #name {
629 #(#original_variants),*
630 }
631
632 impl crate::nlmsg::NfNetlinkAttribute for #name {
633 fn get_size(&self) -> usize {
634 (*self as #repr_type).get_size()
635 }
636
637 fn write_payload(&self, addr: &mut [u8]) {
638 (*self as #repr_type).write_payload(addr);
639 }
640 }
641
642 #tryfrom_impl
643
644 #nfnetlinkdeserialize_impl
645 };
646
647 Ok(res.into())
648}
649
650#[proc_macro_attribute]
651pub fn nfnetlink_enum(attrs: TokenStream, item: TokenStream) -> TokenStream {
652 match nfnetlink_enum_inner(attrs, item) {
653 Ok(tokens) => tokens.into(),
654 Err(diag) => diag.emit_as_item_tokens().into(),
655 }
656}