1#![allow(rustdoc::broken_intra_doc_links)]
2
3use std::fs::File;
4use std::io::Read;
5use std::path::PathBuf;
6
7use proc_macro::TokenStream;
8use proc_macro2::Span;
9use proc_macro2_diagnostics::{Diagnostic, Level, SpanDiagnosticExt};
10use quote::{quote, quote_spanned};
11
12use syn::parse::Parser;
13use syn::punctuated::Punctuated;
14use syn::spanned::Spanned;
15use syn::{
16 Attribute, Expr, ExprCast, ExprLit, Ident, Item, ItemEnum, ItemStruct, Lit, Meta, Path, Token,
17 Type, TypePath, Visibility, parse,
18};
19
20use once_cell::sync::OnceCell;
21
22struct GlobalState {
23 declared_identifiers: Vec<String>,
24}
25
26static STATE: OnceCell<GlobalState> = OnceCell::new();
27
28fn get_state() -> &'static GlobalState {
29 STATE.get_or_init(|| {
30 let sys_file = {
31 let out_path = PathBuf::from(std::env::var("OUT_DIR").unwrap()).join("sys.rs");
35 let mut sys_file = String::new();
36 File::open(out_path)
37 .expect("Error: could not open the output header file")
38 .read_to_string(&mut sys_file)
39 .expect("Could not read the header file");
40 syn::parse_file(&sys_file).expect("Could not parse the header file")
41 };
42
43 let mut declared_identifiers = Vec::new();
44 for item in sys_file.items {
45 if let Item::Const(v) = item {
46 declared_identifiers.push(v.ident.to_string());
47 }
48 }
49
50 GlobalState {
51 declared_identifiers,
52 }
53 })
54}
55
56struct Field<'a> {
57 name: &'a Ident,
58 ty: &'a Type,
59 args: FieldArgs,
60 netlink_type: Path,
61 vis: &'a Visibility,
62 attrs: Vec<&'a Attribute>,
63}
64
65#[derive(Default)]
66struct FieldArgs {
67 netlink_type: Option<Path>,
68 override_function_name: Option<String>,
69 optional: bool,
70}
71
72fn parse_field_args(input: proc_macro2::TokenStream) -> Result<FieldArgs, Diagnostic> {
73 let mut args = FieldArgs::default();
74 let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
75 let attribute_args = parser
76 .parse2(input)
77 .map_err(|e| Diagnostic::new(Level::Error, e.to_string()))?;
78 for arg in attribute_args.iter() {
79 match arg {
80 Meta::Path(path) => {
81 if args.netlink_type.is_none() {
82 args.netlink_type = Some(path.clone());
83 } else {
84 return Err(arg
85 .span()
86 .error("Only a single netlink value can exist for a given field"));
87 }
88 }
89 Meta::NameValue(namevalue) => {
90 let key = namevalue
91 .path
92 .get_ident()
93 .expect("the macro parameter is not an ident?")
94 .to_string();
95 match key.as_str() {
96 "name_in_functions" => {
97 if let Expr::Lit(ExprLit {
98 lit: Lit::Str(val), ..
99 }) = &namevalue.value
100 {
101 args.override_function_name = Some(val.value());
102 } else {
103 return Err(namevalue.value.span().error("Expected a string literal"));
104 }
105 }
106 "optional" => {
107 if let Expr::Lit(ExprLit {
108 lit: Lit::Bool(boolean),
109 ..
110 }) = &namevalue.value
111 {
112 args.optional = boolean.value;
113 } else {
114 return Err(namevalue.value.span().error("Expected a boolean"));
115 }
116 }
117 _ => return Err(arg.span().error("Unsupported macro parameter")),
118 }
119 }
120 _ => return Err(arg.span().error("Unrecognized argument")),
121 }
122 }
123 Ok(args)
124}
125
/// Arguments of `#[nfnetlink_struct(...)]`, all given as `name = bool`.
struct StructArgs {
    /// Value returned by the generated `NfNetlinkAttribute::is_nested`.
    nested: bool,
    /// Whether to generate an `AttributeDecoder` impl.
    derive_decoder: bool,
    /// Whether to generate an `NfNetlinkDeserializable` impl.
    derive_deserialize: bool,
}

impl Default for StructArgs {
    /// Both derives are on by default; the struct is not nested by default.
    fn default() -> Self {
        StructArgs {
            derive_deserialize: true,
            derive_decoder: true,
            nested: false,
        }
    }
}
141
142fn parse_struct_args(input: TokenStream) -> Result<StructArgs, Diagnostic> {
143 let mut args = StructArgs::default();
144 let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
145 let attribute_args = parser
146 .parse(input.clone())
147 .map_err(|e| Diagnostic::new(Level::Error, e.to_string()))?;
148 for arg in attribute_args.iter() {
149 if let Meta::NameValue(namevalue) = arg {
150 let key = namevalue
151 .path
152 .get_ident()
153 .expect("the macro parameter is not an ident?")
154 .to_string();
155 if let Expr::Lit(ExprLit {
156 lit: Lit::Bool(boolean),
157 ..
158 }) = &namevalue.value
159 {
160 match key.as_str() {
161 "derive_decoder" => {
162 args.derive_decoder = boolean.value;
163 }
164 "nested" => {
165 args.nested = boolean.value;
166 }
167 "derive_deserialize" => {
168 args.derive_deserialize = boolean.value;
169 }
170 _ => return Err(arg.span().error("Unsupported macro parameter")),
171 }
172 } else {
173 return Err(namevalue.value.span().error("Expected a boolean"));
174 }
175 } else {
176 return Err(arg.span().error("Unrecognized argument"));
177 }
178 }
179 Ok(args)
180}
181
182fn nfnetlink_struct_inner(
183 attrs: TokenStream,
184 item: TokenStream,
185) -> Result<TokenStream, Diagnostic> {
186 let ast: ItemStruct = parse(item).unwrap();
187 let name = ast.ident;
188
189 let args = parse_struct_args(attrs)?;
190
191 let state = get_state();
192
193 let mut fields = Vec::with_capacity(ast.fields.len());
194 let mut identical_fields = Vec::new();
195
196 'out: for field in ast.fields.iter() {
197 for attr in field.attrs.iter() {
198 if let Some(id) = attr.path().get_ident()
199 && id == "field"
200 {
201 let field_args = match &attr.meta {
202 Meta::List(l) => l,
203 _ => {
204 return Err(attr.span().error("Invalid attributes"));
205 }
206 };
207
208 let field_args = match parse_field_args(field_args.tokens.clone()) {
209 Ok(x) => x,
210 Err(_) => {
211 return Err(attr.span().error("Could not parse the field attributes"));
212 }
213 };
214 if let Some(netlink_type) = field_args.netlink_type.clone() {
215 if field_args.optional {
218 let netlink_type_ident = netlink_type
219 .segments
220 .last()
221 .expect("empty path?")
222 .ident
223 .to_string();
224 if !state.declared_identifiers.contains(&netlink_type_ident) {
225 continue 'out;
227 }
228 }
229
230 fields.push(Field {
231 name: field.ident.as_ref().expect("Should be a names struct"),
232 ty: &field.ty,
233 args: field_args,
234 netlink_type,
235 vis: &field.vis,
236 attrs: field
238 .attrs
239 .iter()
240 .filter(|x| x.path().get_ident() != attr.path().get_ident())
241 .collect(),
242 });
243 } else {
244 return Err(attr.span().error("Missing Netlink Type in field"));
245 }
246 continue 'out;
247 }
248 }
249 identical_fields.push(field);
250 }
251
252 let getters_and_setters = fields.iter().map(|field| {
253 let field_name = field.name;
254 let field_str = field_name.to_string();
256 let field_str = field
257 .args
258 .override_function_name
259 .as_deref()
260 .unwrap_or(field_str.as_str());
261 let field_type = field.ty;
262
263 let getter_name = format!("get_{}", field_str);
264 let getter_name = Ident::new(&getter_name, field.name.span());
265
266 let muttable_getter_name = format!("get_mut_{}", field_str);
267 let muttable_getter_name = Ident::new(&muttable_getter_name, field.name.span());
268
269 let setter_name = format!("set_{}", field_str);
270 let setter_name = Ident::new(&setter_name, field.name.span());
271
272 let in_place_edit_name = format!("with_{}", field_str);
273 let in_place_edit_name = Ident::new(&in_place_edit_name, field.name.span());
274 quote!(
275 #[allow(dead_code)]
276 impl #name {
277 pub fn #getter_name(&self) -> Option<&#field_type> {
278 self.#field_name.as_ref()
279 }
280
281 pub fn #muttable_getter_name(&mut self) -> Option<&mut #field_type> {
282 self.#field_name.as_mut()
283 }
284
285 pub fn #setter_name(&mut self, val: impl Into<#field_type>) {
286 self.#field_name = Some(val.into());
287 }
288
289 pub fn #in_place_edit_name(mut self, val: impl Into<#field_type>) -> Self {
290 self.#field_name = Some(val.into());
291 self
292 }
293 })
294 });
295
296 let decoder = if args.derive_decoder {
297 let match_entries = fields.iter().map(|field| {
298 let field_name = field.name;
299 let field_type = field.ty;
300 let netlink_value = &field.netlink_type;
301 quote!(
302 x if x == #netlink_value => {
303 debug!("Calling {}::deserialize()", std::any::type_name::<#field_type>());
304 let (val, remaining) = <#field_type>::deserialize(buf)?;
305 if remaining.len() != 0 {
306 return Err(crate::error::DecodeError::InvalidDataSize);
307 }
308 self.#field_name = Some(val);
309 Ok(())
310 }
311 )
312 });
313 quote!(
314 impl crate::nlmsg::AttributeDecoder for #name {
315 #[allow(dead_code)]
316 fn decode_attribute(&mut self, attr_type: u16, buf: &[u8]) -> Result<(), crate::error::DecodeError> {
317 use crate::nlmsg::NfNetlinkDeserializable;
318 debug!("Decoding attribute {} in type {}", attr_type, std::any::type_name::<#name>());
319 match attr_type {
320 #(#match_entries),*
321 _ => Err(crate::error::DecodeError::UnsupportedAttributeType(attr_type)),
322 }
323 }
324 }
325 )
326 } else {
327 proc_macro2::TokenStream::new()
328 };
329
330 let nfnetlinkattribute_impl = {
331 let size_entries = fields.iter().map(|field| {
332 let field_name = field.name;
333 quote!(
334 if let Some(val) = &self.#field_name {
335 size += crate::nlmsg::pad_netlink_object::<crate::sys::nlattr>()
337 + crate::nlmsg::pad_netlink_object_with_variable_size(val.get_size());
338 }
339 )
340 });
341 let write_entries = fields.iter().map(|field| {
342 let field_name = field.name;
343 let field_str = field_name.to_string();
344 let netlink_value = &field.netlink_type;
345 quote!(
346 if let Some(val) = &self.#field_name {
347 debug!("writing attribute {} - {:?}", #field_str, val);
348
349 crate::parser::write_attribute(#netlink_value, val, addr);
350
351 #[allow(unused)]
352 {
353 let size = crate::nlmsg::pad_netlink_object::<crate::sys::nlattr>()
354 + crate::nlmsg::pad_netlink_object_with_variable_size(val.get_size());
355 addr = &mut addr[size..];
356 }
357 }
358 )
359 });
360 let nested = args.nested;
361 quote!(
362 impl crate::nlmsg::NfNetlinkAttribute for #name {
363 fn is_nested(&self) -> bool {
364 #nested
365 }
366
367 fn get_size(&self) -> usize {
368 use crate::nlmsg::NfNetlinkAttribute;
369
370 let mut size = 0;
371 #(#size_entries) *
372 size
373 }
374
375 fn write_payload(&self, mut addr: &mut [u8]) {
376 use crate::nlmsg::NfNetlinkAttribute;
377
378 #(#write_entries) *
379 }
380 }
381 )
382 };
383
384 let vis = &ast.vis;
385 let attrs = ast.attrs;
386 let new_fields = fields.iter().map(|field| {
387 let name = field.name;
388 let ty = field.ty;
389 let attrs = &field.attrs;
390 let vis = &field.vis;
391 quote_spanned!(name.span() => #(#attrs) * #vis #name: Option<#ty>, )
392 });
393 let nfnetlinkdeserialize_impl = if args.derive_deserialize {
394 quote!(
395 impl crate::nlmsg::NfNetlinkDeserializable for #name {
396 fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), crate::error::DecodeError> {
397 Ok((crate::parser::read_attributes(buf)?, &[]))
398 }
399 }
400 )
401 } else {
402 proc_macro2::TokenStream::new()
403 };
404 let res = quote! {
405 #(#attrs) * #vis struct #name {
406 #(#new_fields)*
407 #(#identical_fields),*
408 }
409
410 #(#getters_and_setters) *
411
412 #decoder
413
414 #nfnetlinkattribute_impl
415
416 #nfnetlinkdeserialize_impl
417 };
418
419 Ok(res.into())
420}
421
422#[proc_macro_attribute]
483pub fn nfnetlink_struct(attrs: TokenStream, item: TokenStream) -> TokenStream {
484 match nfnetlink_struct_inner(attrs, item) {
485 Ok(tokens) => tokens,
486 Err(diag) => diag.emit_as_item_tokens().into(),
487 }
488}
489
490struct Variant<'a> {
491 inner: &'a syn::Variant,
492 name: &'a Ident,
493 value: &'a Path,
494}
495
496#[derive(Default)]
497struct EnumArgs {
498 nested: bool,
499 ty: Option<Path>,
500}
501
502fn parse_enum_args(input: TokenStream) -> Result<EnumArgs, Diagnostic> {
503 let mut args = EnumArgs::default();
504 let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
505 let attribute_args = parser
506 .parse(input)
507 .map_err(|e| Diagnostic::new(Level::Error, e.to_string()))?;
508 for arg in attribute_args.iter() {
509 match arg {
510 Meta::Path(path) => {
511 if args.ty.is_none() {
512 args.ty = Some(path.clone());
513 } else {
514 return Err(arg
515 .span()
516 .error("A value can only have a single representation"));
517 }
518 }
519 Meta::NameValue(namevalue) => {
520 let key = namevalue
521 .path
522 .get_ident()
523 .expect("the macro parameter is not an ident?")
524 .to_string();
525 match key.as_str() {
526 "nested" => {
527 if let Expr::Lit(ExprLit {
528 lit: Lit::Bool(boolean),
529 ..
530 }) = &namevalue.value
531 {
532 args.nested = boolean.value;
533 } else {
534 return Err(namevalue.value.span().error("Expected a boolean"));
535 }
536 }
537 _ => return Err(arg.span().error("Unsupported macro parameter")),
538 }
539 }
540 _ => return Err(arg.span().error("Unrecognized argument")),
541 }
542 }
543 Ok(args)
544}
545
546fn nfnetlink_enum_inner(attrs: TokenStream, item: TokenStream) -> Result<TokenStream, Diagnostic> {
547 let ast: ItemEnum = parse(item).unwrap();
548 let name = ast.ident;
549
550 let args = match parse_enum_args(attrs) {
551 Ok(x) => x,
552 Err(_) => return Err(Span::call_site().error("Could not parse the macro arguments")),
553 };
554
555 if args.ty.is_none() {
556 return Err(Span::call_site().error("The target type representation is unspecified"));
557 }
558
559 let mut variants = Vec::with_capacity(ast.variants.len());
560
561 for variant in ast.variants.iter() {
562 if variant.discriminant.is_none() {
563 return Err(variant.ident.span().error("Missing value"));
564 }
565 let discriminant = variant.discriminant.as_ref().unwrap();
566 if let syn::Expr::Path(path) = &discriminant.1 {
567 variants.push(Variant {
568 inner: variant,
569 name: &variant.ident,
570 value: &path.path,
571 });
572 } else {
573 return Err(discriminant.1.span().error("Expected a path"));
574 }
575 }
576
577 let repr_type = args.ty.unwrap();
578 let match_entries = variants.iter().map(|variant| {
579 let variant_name = variant.name;
580 let variant_value = &variant.value;
581 quote!( x if x == (#variant_value as #repr_type) => Ok(Self::#variant_name), )
582 });
583 #[allow(clippy::to_string_in_format_args)]
584 let unknown_type_ident = Ident::new(&format!("Unknown{}", name.to_string()), name.span());
585 let tryfrom_impl = quote!(
586 impl ::core::convert::TryFrom<#repr_type> for #name {
587 type Error = crate::error::DecodeError;
588
589 #[allow(clippy::unnecessary_cast)]
590 fn try_from(val: #repr_type) -> Result<Self, Self::Error> {
591 match val {
592 #(#match_entries) *
593 value => Err(crate::error::DecodeError::#unknown_type_ident(value))
594 }
595 }
596 }
597 );
598 let nfnetlinkdeserialize_impl = quote!(
599 impl crate::nlmsg::NfNetlinkDeserializable for #name {
600 fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), crate::error::DecodeError> {
601 let (v, remaining_data) = #repr_type::deserialize(buf)?;
602 <#name>::try_from(v).map(|x| (x, remaining_data))
603 }
604 }
605 );
606 let vis = &ast.vis;
607 let attrs = ast.attrs;
608 let original_variants = variants.into_iter().map(|x| {
609 let mut inner = x.inner.clone();
610 let discriminant = inner.discriminant.as_mut().unwrap();
611 let cur_value = discriminant.1.clone();
612 let cast_value = Expr::Cast(ExprCast {
613 attrs: vec![],
614 expr: Box::new(cur_value),
615 as_token: Token),
616 ty: Box::new(Type::Path(TypePath {
617 qself: None,
618 path: repr_type.clone(),
619 })),
620 });
621 discriminant.1 = cast_value;
622 inner
623 });
624 let res = quote! {
625 #[repr(#repr_type)]
626 #[allow(clippy::unnecessary_cast)]
627 #(#attrs) * #vis enum #name {
628 #(#original_variants),*
629 }
630
631 impl crate::nlmsg::NfNetlinkAttribute for #name {
632 fn get_size(&self) -> usize {
633 (*self as #repr_type).get_size()
634 }
635
636 fn write_payload(&self, addr: &mut [u8]) {
637 (*self as #repr_type).write_payload(addr);
638 }
639 }
640
641 #tryfrom_impl
642
643 #nfnetlinkdeserialize_impl
644 };
645
646 Ok(res.into())
647}
648
649#[proc_macro_attribute]
650pub fn nfnetlink_enum(attrs: TokenStream, item: TokenStream) -> TokenStream {
651 match nfnetlink_enum_inner(attrs, item) {
652 Ok(tokens) => tokens,
653 Err(diag) => diag.emit_as_item_tokens().into(),
654 }
655}