structure_macro_impl/
lib.rs

1#![recursion_limit = "128"]
2
3#[macro_use]
4extern crate proc_macro_hack;
5#[macro_use]
6extern crate quote;
7
8use std::mem;
9use std::os::raw::c_void;
10use std::string::String;
11use quote::{Tokens, Ident};
12
13proc_macro_expr_impl! {
14    pub fn structure_impl(input: &str) -> String {
15        let format = trim_quotes(input);
16        let struct_name = Ident::from(format_to_struct_name(format));
17        let (values, endianness) = format_to_values(&format);
18        let (args, fn_decl_args, args_types) = build_args_list(&values);
19        let endianness = match endianness {
20            Endianness::Native => {
21                if cfg!(target_endian = "little") {
22                    quote!(LittleEndian)
23                } else {
24                    quote!(BigEndian)
25                }
26            }
27            Endianness::LittleEndian => quote!(LittleEndian),
28            Endianness::BigEndian => quote!(BigEndian),
29        };
30        let size = calc_size(&values);
31        let pack_fn = build_pack_fn(&args, &fn_decl_args, size);
32        let pack_into_fn = build_pack_into_fn(&values, &fn_decl_args, &endianness);
33        let unpack_fn = build_unpack_fn(&args_types, size);
34        let unpack_from_fn = build_unpack_from_fn(&values, &args, &args_types, &endianness);
35        let size_fn = build_size_fn(size);
36        let output = quote! {{
37            #[derive(Debug)]
38            #[allow(non_camel_case_types)]
39            struct #struct_name;
40            #[allow(unused_imports)]
41            use std::io::{Result, Write, Read, Error, ErrorKind, Cursor};
42            #[allow(unused_imports)]
43            use std::os::raw::c_void;
44            #[allow(unused_imports)]
45            use structure::byteorder::{WriteBytesExt, ReadBytesExt, BigEndian, LittleEndian};
46
47            #[allow(unused)] static TRUE_BUF: &[u8] = &[1];
48            #[allow(unused)] static FALSE_BUF: &[u8] = &[0];
49
50            impl #struct_name {
51                #pack_fn
52                #pack_into_fn
53                #unpack_fn
54                #unpack_from_fn
55                #size_fn
56            }
57
58            #struct_name // Create structure instance
59        }};
60
61        output.into_string()
62    }
63}
64
/// Byte order requested by the first character of a format string.
#[derive(PartialEq)]
enum Endianness {
    /// Selected by a leading `=`: use the platform's own byte order.
    Native,
    /// Selected by a leading `<`.
    LittleEndian,
    /// Selected by a leading `>` or `!`, and also the default when no prefix is given.
    BigEndian,
}
71
72fn build_pack_fn(args: &Tokens, fn_decl_args: &Tokens, size: usize) -> Tokens {
73    quote! {
74        #[allow(unused)]
75        fn pack(&self, #fn_decl_args) -> Result<Vec<u8>> {
76            let mut wtr = Vec::with_capacity(#size);
77            self.pack_into(&mut wtr, #args)?;
78            Ok(wtr)
79        }
80    }
81}
82
83fn build_pack_into_fn(values: &[StructValue], fn_decl_args: &Tokens, endianness: &Tokens) -> Tokens {
84    // Pack each argument
85    let mut writings = Tokens::new();
86    let mut arg_index = 0;
87    for value in values {
88        let writing = match *value.kind() {
89            ValueKind::Number | ValueKind::Boolean | ValueKind::Pointer => {
90                let mut tokens = Tokens::new();
91                for _ in 0..value.repeat() {
92                    arg_index += 1;
93                    let current_arg = Ident::from(format!("_{}", arg_index));
94                    if *value.kind() == ValueKind::Number {
95                        let byteorder_fn = Ident::from(format!("write_{}", value.type_name()));
96                        match value.type_name().as_str() {
97                            "u8" | "i8" => {
98                                tokens.append(quote! {wtr.#byteorder_fn(#current_arg)?;});
99                            }
100                            _ => {
101                                tokens.append(quote! {wtr.#byteorder_fn::<#endianness>(#current_arg)?;});
102                            }
103                        }
104                    } else if *value.kind() == ValueKind::Boolean {
105                        tokens.append(quote! {
106                            let buf = if #current_arg { TRUE_BUF } else { FALSE_BUF };
107                            wtr.write(buf)?;
108                        });
109                    } else {
110                        let size = mem::size_of::<usize>();
111                        let integer_type = Ident::from(format!("u{}", size * 8));
112                        let byteorder_fn = Ident::from(format!("write_u{}", size * 8));
113                        tokens.append(quote! {
114                            let v = #current_arg as #integer_type;
115                            wtr.#byteorder_fn::<#endianness>(v)?;
116                        });
117                    }
118                }
119                tokens
120            }
121            ValueKind::Buffer | ValueKind::FixedBuffer => {
122                arg_index += 1;
123                let current_arg = Ident::from(format!("_{}", arg_index));
124                let buffer_length = value.repeat();
125                let length_check = if *value.kind() == ValueKind::Buffer {
126                    // If the type is `ValueKind::Buffer`, and the given buffer is smaller than the
127                    // size determined in the format, the rest will be filled with zeros.
128                    quote! { #current_arg.len() <= #buffer_length }
129                } else {
130                    quote! { #current_arg.len() == #buffer_length }
131                };
132                let mut tokens = quote! {
133                    if !(#length_check) {
134                        let msg = format!("Buffer length does not match the format \
135                            (buffer size in format: {}, actual size: {}", #current_arg.len(), #buffer_length);
136                        return Err(Error::new(ErrorKind::InvalidInput, msg));
137                    }
138                    wtr.write_all(#current_arg)?;
139                };
140                if *value.kind() == ValueKind::Buffer {
141                    tokens.append(quote! {
142                        if #current_arg.len() != #buffer_length {
143                            wtr.write_all(&vec![0; (#buffer_length - #current_arg.len())])?;
144                        }
145                    });
146                }
147                tokens
148            }
149            ValueKind::Padding => {
150                let number = value.repeat();
151                quote! {
152                    wtr.write_all(&[0; #number])?;
153                }
154            }
155        };
156        writings.append(writing);
157    }
158
159    quote! {
160        #[allow(unused)]
161        fn pack_into<T: Write>(&self, wtr: &mut T, #fn_decl_args) -> Result<()> {
162            #writings
163            Ok(())
164        }
165    }
166}
167
168fn build_unpack_fn(args_types: &Tokens, size: usize) -> Tokens {
169    quote! {
170        #[allow(unused)]
171        fn unpack<T: AsRef<[u8]>>(&self, buf: T) -> Result<(#args_types,)> {
172            if buf.as_ref().len() != #size {
173                let msg = format!("Buffer length does not match the format \
174                    (format size: {}, actual size: {}", #size, buf.as_ref().len());
175                return Err(Error::new(ErrorKind::InvalidInput, msg))
176            }
177            let mut rdr = Cursor::new(buf);
178            self.unpack_from(&mut rdr)
179        }
180    }
181}
182
183fn build_unpack_from_fn(values: &[StructValue], args: &Tokens, args_types: &Tokens, endianness: &Tokens) -> Tokens {
184    let mut readings = Tokens::new();
185    let mut arg_index = 0;
186    for value in values {
187        let reading = match *value.kind() {
188            ValueKind::Number | ValueKind::Boolean | ValueKind::Pointer => {
189                let mut tokens = Tokens::new();
190                for _ in 0..value.repeat() {
191                    arg_index += 1;
192                    let current_arg = Ident::from(format!("_{}", arg_index));
193                    if *value.kind() == ValueKind::Number {
194                        let byteorder_fn = Ident::from(format!("read_{}", value.type_name()));
195                        match value.type_name().as_str() {
196                            "u8" | "i8" => {
197                                tokens.append(quote! { let #current_arg = rdr.#byteorder_fn()?;});
198                            }
199                            _ => {
200                                tokens.append(quote! { let #current_arg = rdr.#byteorder_fn::<#endianness>()?;});
201                            }
202                        }
203                    } else if *value.kind() == ValueKind::Boolean {
204                        tokens.append(quote! {
205                            let #current_arg = rdr.read_u8()?;
206                            let #current_arg = #current_arg != 0; // 0 is false
207                        });
208                    } else {
209                        let pointer_type = Ident::from(value.type_name().as_str());
210                        let size = mem::size_of::<usize>();
211                        let byteorder_fn = Ident::from(format!("read_u{}", size * 8));
212                        tokens.append(quote! {
213                            let #current_arg = {
214                                let v = rdr.#byteorder_fn::<#endianness>()?;
215                                v as #pointer_type
216                            };
217                        });
218                    }
219                }
220                tokens
221            }
222            ValueKind::Buffer | ValueKind::FixedBuffer => {
223                arg_index += 1;
224                let current_arg = Ident::from(format!("_{}", arg_index));
225                let buffer_length = value.repeat();
226                quote! {
227                    let mut #current_arg = vec![0; #buffer_length];
228                    rdr.read_exact(&mut #current_arg)?;
229                }
230            }
231            ValueKind::Padding => {
232                let number = value.repeat();
233                quote! {
234                    rdr.read_exact(&mut [0; #number])?;
235                }
236            }
237        };
238        readings.append(reading);
239    }
240
241    quote! {
242        #[allow(unused)]
243        fn unpack_from<T: Read>(&self, rdr: &mut T) -> Result<(#args_types,)> {
244            #readings
245            Ok((#args,))
246        }
247    }
248}
249
250/// Build the args list, the function declaration args list and the type list
251fn build_args_list(values: &[StructValue]) -> (Tokens, Tokens, Tokens) {
252    let mut args = vec![];
253    let mut fn_decl_args = vec![];
254    let mut args_types = vec![];
255    let mut arg_index = 0;
256    for v in values {
257        match *v.kind() {
258            ValueKind::Padding => continue,
259            ValueKind::Buffer | ValueKind::FixedBuffer => {
260                arg_index += 1;
261                args.push(Ident::from(format!("_{}", arg_index)));
262                fn_decl_args.push(Ident::from(format!("_{}: {}", arg_index, v.type_name())));
263                args_types.push(Ident::from("Vec<u8>".to_owned()));
264            }
265            _ => {
266                for _ in 0..v.repeat() {
267                    arg_index += 1;
268                    args.push(Ident::from(format!("_{}", arg_index)));
269                    fn_decl_args.push(Ident::from(format!("_{}: {}", arg_index, v.type_name())));
270                    args_types.push(Ident::from(v.type_name().as_str()));
271                }
272            }
273        }
274    }
275    (quote!(#(#args),*), quote!(#(#fn_decl_args),*), quote!(#(#args_types),*))
276}
277
278fn build_size_fn(size: usize) -> Tokens {
279    quote! {
280        #[allow(unused)]
281        fn size(&self) -> usize {
282            #size
283        }
284    }
285}
286
287fn calc_size(values: &[StructValue]) -> usize {
288    let mut size = 0;
289    for v in values {
290        if v.type_name().starts_with("*") {
291            mem::size_of::<*const c_void>();
292        }
293        let type_size = match v.type_name().as_str() {
294            "i8" => mem::size_of::<i8>(),
295            "&[u8]" | "u8" => mem::size_of::<u8>(),
296            "bool" => 1,
297            "i16" => mem::size_of::<i16>(),
298            "u16" => mem::size_of::<u16>(),
299            "i32" => mem::size_of::<i32>(),
300            "u32" => mem::size_of::<u32>(),
301            "i64" => mem::size_of::<i64>(),
302            "u64" => mem::size_of::<u64>(),
303            "f32" => mem::size_of::<f32>(),
304            "f64" => mem::size_of::<f64>(),
305            t if t.starts_with("*") => mem::size_of::<usize>(),
306            _ => panic!("Unknown type: '{}'", v.type_name()),
307        };
308        size += type_size * v.repeat();
309    }
310    size
311}
312
/// Derives the name of the generated struct from the format string.
///
/// The name is `Struct_` followed by the format, with the characters that cannot appear
/// in an identifier rewritten:
///
/// * `?` becomes `Bool`, `=` becomes `Native`, `<` becomes `LittleEndian`;
/// * `>` and `!` are dropped;
/// * any other character that is neither alphanumeric nor `_` (for example the `:` of a
///   pointer type such as `P<std::os::raw::c_void>`) becomes `_`, so the result is always
///   usable as an identifier.
fn format_to_struct_name(format: &str) -> String {
    let prefix = "Struct_";
    let mut name = String::with_capacity(prefix.len() + format.len());
    name.push_str(prefix);
    for c in format.chars() {
        match c {
            '?' => name.push_str("Bool"),
            '=' => name.push_str("Native"),
            '<' => name.push_str("LittleEndian"),
            '>' | '!' => {}
            c if c.is_alphanumeric() || c == '_' => name.push(c),
            _ => name.push('_'),
        }
    }
    name
}
320
321/// Return the format string without the endianness, and the endianness
322fn format_endianness(format: &str) -> (&str, Endianness) {
323    let first_char = format.chars().nth(0);
324    let endianness = match first_char {
325        Some('=') => Endianness::Native,
326        Some('<') => Endianness::LittleEndian,
327        _ => Endianness::BigEndian,
328    };
329    let mut chars = format.chars();
330    match chars.next() {
331        Some('=') | Some('<') | Some('>') | Some('!') => (chars.as_str(), endianness),
332        _ => (format, endianness),
333    }
334}
335
336fn char_to_type(c: char) -> (&'static str, ValueKind) {
337    match c {
338        'b' => ("i8", ValueKind::Number),
339        'B' => ("u8", ValueKind::Number),
340        '?' => ("bool", ValueKind::Boolean),
341        'h' => ("i16", ValueKind::Number),
342        'H' => ("u16", ValueKind::Number),
343        'i' => ("i32", ValueKind::Number),
344        'I' => ("u32", ValueKind::Number),
345        'q' => ("i64", ValueKind::Number),
346        'Q' => ("u64", ValueKind::Number),
347        'f' => ("f32", ValueKind::Number),
348        'd' => ("f64", ValueKind::Number),
349        's' => ("&[u8]", ValueKind::Buffer),
350        'S' => ("&[u8]", ValueKind::FixedBuffer),
351        'P' => ("*const c_void", ValueKind::Pointer),
352        'x' => ("u8", ValueKind::Padding),
353        _ => panic!("Unknown format: '{}'", c),
354    }
355}
356
357fn format_to_values(format: &str) -> (Vec<StructValue>, Endianness) {
358    let (format, endianness) = format_endianness(format);
359    let mut values = vec![];
360    let mut chars = format.chars().peekable();
361    let mut repeat_str = String::new();
362    while let Some(c) = chars.next() {
363        if c.is_digit(10) {
364            repeat_str.push(c);
365        } else {
366            let (type_name, kind) = char_to_type(c);
367            let mut type_name = type_name.to_owned();
368            if kind == ValueKind::Pointer {
369                // Parse pointer type
370                if endianness != Endianness::Native {
371                    panic!("Pointer can be used only if the endianness is native. \
372                            To change the endianness to native, start the format with '='");
373                }
374                if let Some(&'<') = chars.peek() {
375                    chars.next();
376                    let mut pointer_type_name = String::new();
377                    loop {
378                        let c = chars.next();
379                        if c == None {
380                            panic!("Pointer type must end with '>'");
381                        } else if c == Some('>') {
382                            if pointer_type_name.is_empty() {
383                                panic!("Pointer type cannot be empty");
384                            }
385                            type_name = format!("*const {}", pointer_type_name);
386                            break;
387                        } else {
388                            pointer_type_name.push(c.unwrap());
389                        }
390                    }
391                }
392            }
393            let mut repeat = 1;
394            if !repeat_str.is_empty() {
395                repeat = repeat_str.parse().expect("not a number");
396                repeat_str.clear();
397            }
398            values.push(StructValue::new(type_name, repeat, kind));
399        }
400    }
401    if !repeat_str.is_empty() {
402        panic!("No format character is followed by the number {}", repeat_str);
403    }
404    (values, endianness)
405}
406
/// What a format element describes, which decides how it is packed and unpacked.
#[derive(PartialEq)]
enum ValueKind {
    /// An integer or floating-point number.
    Number,
    /// A `bool`, stored as one byte.
    Boolean,
    /// A byte buffer (`s`) that may be shorter than the declared length when packed.
    Buffer,
    /// A byte buffer (`S`) that must have exactly the declared length when packed.
    FixedBuffer,
    /// A raw pointer (`P`).
    Pointer,
    /// Padding bytes (`x`); they take no argument and yield no value.
    Padding,
}
416
417struct StructValue {
418    type_name: String,
419    repeat: usize,
420    kind: ValueKind
421}
422
423impl StructValue {
424    fn new(type_name: String, repeat: usize, kind: ValueKind) -> StructValue {
425        StructValue { type_name: type_name, repeat: repeat, kind: kind }
426    }
427    fn type_name(&self) -> &String {
428        &self.type_name
429    }
430    fn repeat(&self) -> usize {
431        self.repeat
432    }
433    fn kind(&self) -> &ValueKind {
434        &self.kind
435    }
436}
437
/// Strips the surrounding double quotes from the macro argument.
///
/// `input` is the argument as source text, so a string literal arrives with its quotes.
/// Returns the text between the first and the last character.
///
/// # Panics
///
/// Panics unless `input` both starts and ends with `"` and is at least two bytes long.
/// Requiring both quotes matters: accepting input with only one of them would silently
/// drop a character of the format.
fn trim_quotes(input: &str) -> &str {
    let quoted = input.len() >= 2 && input.starts_with('"') && input.ends_with('"');
    if !quoted {
        panic!("structure!() macro takes a literal string as an argument");
    }
    // Both delimiters are the one-byte `"`, so these offsets are on char boundaries.
    &input[1..(input.len() - 1)]
}