Skip to main content

rustyhdf5_derive/
lib.rs

1//! Proc macros for deriving HDF5 compound type mapping.
2//!
3//! Provides `#[derive(H5Type)]` which generates methods for mapping Rust structs
4//! to HDF5 compound datatypes, including serialization and deserialization.
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{Data, DeriveInput, Fields, Type, parse_macro_input};
9
10/// Derive macro that generates HDF5 compound type mapping for structs.
11///
12/// Generates three methods:
13/// - `hdf5_datatype()` — returns the HDF5 `Datatype::Compound` descriptor
14/// - `to_bytes(&self)` — serializes the struct to HDF5 compound raw bytes
15/// - `from_bytes(data: &[u8])` — deserializes from HDF5 compound raw bytes
16///
17/// # Supported field types
18/// - `f32`, `f64`
19/// - `i8`, `i16`, `i32`, `i64`
20/// - `u8`, `u16`, `u32`, `u64`
21/// - `bool` (stored as `u8`)
22/// - `[T; N]` fixed-size arrays of any supported numeric type
23#[proc_macro_derive(H5Type)]
24pub fn derive_h5type(input: TokenStream) -> TokenStream {
25    let input = parse_macro_input!(input as DeriveInput);
26    match impl_h5type(&input) {
27        Ok(ts) => ts.into(),
28        Err(e) => e.to_compile_error().into(),
29    }
30}
31
32fn impl_h5type(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
33    let name = &input.ident;
34
35    let fields = match &input.data {
36        Data::Struct(data) => match &data.fields {
37            Fields::Named(named) => &named.named,
38            _ => {
39                return Err(syn::Error::new_spanned(
40                    name,
41                    "H5Type can only be derived for structs with named fields",
42                ));
43            }
44        },
45        _ => {
46            return Err(syn::Error::new_spanned(
47                name,
48                "H5Type can only be derived for structs",
49            ));
50        }
51    };
52
53    let mut datatype_member_stmts = Vec::new();
54    let mut serialize_stmts = Vec::new();
55    let mut deserialize_stmts = Vec::new();
56    let mut field_names = Vec::new();
57    let mut size_increments = Vec::new();
58
59    for field in fields.iter() {
60        let field_name = field.ident.as_ref().unwrap();
61        let field_name_str = field_name.to_string();
62        let ty = &field.ty;
63
64        let (dt_expr, ser_expr, deser_expr, size_expr) = type_mapping(ty, field_name)?;
65
66        datatype_member_stmts.push(quote! {
67            _members.push(rustyhdf5_format::datatype::CompoundMember {
68                name: #field_name_str.into(),
69                byte_offset: _offset,
70                datatype: #dt_expr,
71            });
72            _offset += #size_expr as u64;
73        });
74
75        size_increments.push(quote! { + (#size_expr as usize) });
76        serialize_stmts.push(ser_expr);
77        deserialize_stmts.push(deser_expr);
78        field_names.push(field_name.clone());
79    }
80
81    let expanded = quote! {
82        impl #name {
83            /// Returns the HDF5 compound datatype descriptor for this struct.
84            pub fn hdf5_datatype() -> rustyhdf5_format::datatype::Datatype {
85                let mut _offset: u64 = 0;
86                let mut _members = Vec::new();
87                #(#datatype_member_stmts)*
88                rustyhdf5_format::datatype::Datatype::Compound {
89                    size: _offset as u32,
90                    members: _members,
91                }
92            }
93
94            /// Serializes this struct to HDF5 compound raw bytes (little-endian).
95            pub fn to_bytes(&self) -> Vec<u8> {
96                let mut _buf = Vec::with_capacity(Self::_h5_compound_size());
97                #(#serialize_stmts)*
98                _buf
99            }
100
101            /// Deserializes from HDF5 compound raw bytes (little-endian).
102            pub fn from_bytes(_data: &[u8]) -> Self {
103                let mut _pos = 0usize;
104                #(#deserialize_stmts)*
105                Self {
106                    #(#field_names),*
107                }
108            }
109
110            fn _h5_compound_size() -> usize {
111                0usize #(#size_increments)*
112            }
113        }
114    };
115
116    Ok(expanded)
117}
118
119fn type_mapping(
120    ty: &Type,
121    field_name: &syn::Ident,
122) -> syn::Result<(
123    proc_macro2::TokenStream, // datatype expression
124    proc_macro2::TokenStream, // serialize expression
125    proc_macro2::TokenStream, // deserialize expression
126    proc_macro2::TokenStream, // size expression
127)> {
128    match ty {
129        Type::Path(type_path) => {
130            let seg = type_path.path.segments.last().unwrap();
131            let type_name = seg.ident.to_string();
132            match type_name.as_str() {
133                "f64" => Ok(float_mapping(field_name, 8, 64, 52, 11, 52, 1023)),
134                "f32" => Ok(float_mapping(field_name, 4, 32, 23, 8, 23, 127)),
135                "i8" => Ok(int_mapping(field_name, 1, true)),
136                "i16" => Ok(int_mapping(field_name, 2, true)),
137                "i32" => Ok(int_mapping(field_name, 4, true)),
138                "i64" => Ok(int_mapping(field_name, 8, true)),
139                "u8" => Ok(int_mapping(field_name, 1, false)),
140                "u16" => Ok(int_mapping(field_name, 2, false)),
141                "u32" => Ok(int_mapping(field_name, 4, false)),
142                "u64" => Ok(int_mapping(field_name, 8, false)),
143                "bool" => Ok(bool_mapping(field_name)),
144                _ => Err(syn::Error::new_spanned(
145                    ty,
146                    format!("unsupported type `{type_name}` for H5Type derive"),
147                )),
148            }
149        }
150        Type::Array(arr) => {
151            let elem_ty = &*arr.elem;
152            let len_expr = &arr.len;
153            array_mapping(field_name, elem_ty, len_expr)
154        }
155        _ => Err(syn::Error::new_spanned(
156            ty,
157            "unsupported type for H5Type derive",
158        )),
159    }
160}
161
162fn float_mapping(
163    field_name: &syn::Ident,
164    size: u32,
165    precision: u16,
166    mant_loc: u8,
167    exp_size: u8,
168    mant_size: u8,
169    exp_bias: u32,
170) -> (
171    proc_macro2::TokenStream,
172    proc_macro2::TokenStream,
173    proc_macro2::TokenStream,
174    proc_macro2::TokenStream,
175) {
176    let size_lit = size;
177    let precision_lit = precision;
178    let exp_size_lit = exp_size;
179    let mant_size_lit = mant_size;
180    let exp_bias_lit = exp_bias;
181    let exp_loc: u8 = mant_loc;
182
183    let dt = quote! {
184        rustyhdf5_format::datatype::Datatype::FloatingPoint {
185            size: #size_lit,
186            byte_order: rustyhdf5_format::datatype::DatatypeByteOrder::LittleEndian,
187            bit_offset: 0,
188            bit_precision: #precision_lit,
189            exponent_location: #exp_loc,
190            exponent_size: #exp_size_lit,
191            mantissa_location: 0,
192            mantissa_size: #mant_size_lit,
193            exponent_bias: #exp_bias_lit,
194        }
195    };
196
197    let ser = quote! {
198        _buf.extend_from_slice(&self.#field_name.to_le_bytes());
199    };
200
201    let deser = if size == 8 {
202        quote! {
203            let #field_name = f64::from_le_bytes(
204                _data[_pos.._pos + 8].try_into().unwrap()
205            );
206            _pos += 8;
207        }
208    } else {
209        quote! {
210            let #field_name = f32::from_le_bytes(
211                _data[_pos.._pos + 4].try_into().unwrap()
212            );
213            _pos += 4;
214        }
215    };
216
217    let sz = size as usize;
218    let size_expr = quote! { #sz };
219    (dt, ser, deser, size_expr)
220}
221
222fn int_mapping(
223    field_name: &syn::Ident,
224    size: u32,
225    signed: bool,
226) -> (
227    proc_macro2::TokenStream,
228    proc_macro2::TokenStream,
229    proc_macro2::TokenStream,
230    proc_macro2::TokenStream,
231) {
232    let precision = (size * 8) as u16;
233
234    let dt = quote! {
235        rustyhdf5_format::datatype::Datatype::FixedPoint {
236            size: #size,
237            byte_order: rustyhdf5_format::datatype::DatatypeByteOrder::LittleEndian,
238            signed: #signed,
239            bit_offset: 0,
240            bit_precision: #precision,
241        }
242    };
243
244    let ser = quote! {
245        _buf.extend_from_slice(&self.#field_name.to_le_bytes());
246    };
247
248    let sz = size as usize;
249    let deser = match (size, signed) {
250        (1, true) => quote! {
251            let #field_name = _data[_pos] as i8;
252            _pos += 1;
253        },
254        (1, false) => quote! {
255            let #field_name = _data[_pos];
256            _pos += 1;
257        },
258        (2, true) => quote! {
259            let #field_name = i16::from_le_bytes(
260                _data[_pos.._pos + 2].try_into().unwrap()
261            );
262            _pos += 2;
263        },
264        (2, false) => quote! {
265            let #field_name = u16::from_le_bytes(
266                _data[_pos.._pos + 2].try_into().unwrap()
267            );
268            _pos += 2;
269        },
270        (4, true) => quote! {
271            let #field_name = i32::from_le_bytes(
272                _data[_pos.._pos + 4].try_into().unwrap()
273            );
274            _pos += 4;
275        },
276        (4, false) => quote! {
277            let #field_name = u32::from_le_bytes(
278                _data[_pos.._pos + 4].try_into().unwrap()
279            );
280            _pos += 4;
281        },
282        (8, true) => quote! {
283            let #field_name = i64::from_le_bytes(
284                _data[_pos.._pos + 8].try_into().unwrap()
285            );
286            _pos += 8;
287        },
288        (8, false) => quote! {
289            let #field_name = u64::from_le_bytes(
290                _data[_pos.._pos + 8].try_into().unwrap()
291            );
292            _pos += 8;
293        },
294        _ => quote! {
295            let mut _tmp = [0u8; #sz];
296            _tmp.copy_from_slice(&_data[_pos.._pos + #sz]);
297            let #field_name = _tmp;
298            _pos += #sz;
299        },
300    };
301
302    let sz = size as usize;
303    let size_expr = quote! { #sz };
304    (dt, ser, deser, size_expr)
305}
306
307fn bool_mapping(
308    field_name: &syn::Ident,
309) -> (
310    proc_macro2::TokenStream,
311    proc_macro2::TokenStream,
312    proc_macro2::TokenStream,
313    proc_macro2::TokenStream,
314) {
315    let dt = quote! {
316        rustyhdf5_format::datatype::Datatype::FixedPoint {
317            size: 1,
318            byte_order: rustyhdf5_format::datatype::DatatypeByteOrder::LittleEndian,
319            signed: false,
320            bit_offset: 0,
321            bit_precision: 8,
322        }
323    };
324
325    let ser = quote! {
326        _buf.push(if self.#field_name { 1u8 } else { 0u8 });
327    };
328
329    let deser = quote! {
330        let #field_name = _data[_pos] != 0;
331        _pos += 1;
332    };
333
334    let size_expr = quote! { 1usize };
335    (dt, ser, deser, size_expr)
336}
337
338fn array_mapping(
339    field_name: &syn::Ident,
340    elem_ty: &Type,
341    len_expr: &syn::Expr,
342) -> syn::Result<(
343    proc_macro2::TokenStream,
344    proc_macro2::TokenStream,
345    proc_macro2::TokenStream,
346    proc_macro2::TokenStream,
347)> {
348    let Type::Path(type_path) = elem_ty else {
349        return Err(syn::Error::new_spanned(
350            elem_ty,
351            "array element must be a primitive type for H5Type derive",
352        ));
353    };
354    let elem_name = type_path.path.segments.last().unwrap().ident.to_string();
355
356    let (base_dt, elem_size, deser_one) = match elem_name.as_str() {
357        "f64" => (
358            quote! {
359                rustyhdf5_format::datatype::Datatype::FloatingPoint {
360                    size: 8,
361                    byte_order: rustyhdf5_format::datatype::DatatypeByteOrder::LittleEndian,
362                    bit_offset: 0, bit_precision: 64,
363                    exponent_location: 52, exponent_size: 11,
364                    mantissa_location: 0, mantissa_size: 52,
365                    exponent_bias: 1023,
366                }
367            },
368            8usize,
369            quote! { f64::from_le_bytes(_data[_pos.._pos + 8].try_into().unwrap()) },
370        ),
371        "f32" => (
372            quote! {
373                rustyhdf5_format::datatype::Datatype::FloatingPoint {
374                    size: 4,
375                    byte_order: rustyhdf5_format::datatype::DatatypeByteOrder::LittleEndian,
376                    bit_offset: 0, bit_precision: 32,
377                    exponent_location: 23, exponent_size: 8,
378                    mantissa_location: 0, mantissa_size: 23,
379                    exponent_bias: 127,
380                }
381            },
382            4usize,
383            quote! { f32::from_le_bytes(_data[_pos.._pos + 4].try_into().unwrap()) },
384        ),
385        "i8" => (
386            int_dt_quote(1, true),
387            1usize,
388            quote! { _data[_pos] as i8 },
389        ),
390        "i16" => (
391            int_dt_quote(2, true),
392            2usize,
393            quote! { i16::from_le_bytes(_data[_pos.._pos + 2].try_into().unwrap()) },
394        ),
395        "i32" => (
396            int_dt_quote(4, true),
397            4usize,
398            quote! { i32::from_le_bytes(_data[_pos.._pos + 4].try_into().unwrap()) },
399        ),
400        "i64" => (
401            int_dt_quote(8, true),
402            8usize,
403            quote! { i64::from_le_bytes(_data[_pos.._pos + 8].try_into().unwrap()) },
404        ),
405        "u8" => (
406            int_dt_quote(1, false),
407            1usize,
408            quote! { _data[_pos] },
409        ),
410        "u16" => (
411            int_dt_quote(2, false),
412            2usize,
413            quote! { u16::from_le_bytes(_data[_pos.._pos + 2].try_into().unwrap()) },
414        ),
415        "u32" => (
416            int_dt_quote(4, false),
417            4usize,
418            quote! { u32::from_le_bytes(_data[_pos.._pos + 4].try_into().unwrap()) },
419        ),
420        "u64" => (
421            int_dt_quote(8, false),
422            8usize,
423            quote! { u64::from_le_bytes(_data[_pos.._pos + 8].try_into().unwrap()) },
424        ),
425        _ => {
426            return Err(syn::Error::new_spanned(
427                elem_ty,
428                format!("unsupported array element type `{elem_name}` for H5Type derive"),
429            ));
430        }
431    };
432
433    let dt = quote! {
434        rustyhdf5_format::datatype::Datatype::Array {
435            base_type: Box::new(#base_dt),
436            dimensions: vec![#len_expr as u32],
437        }
438    };
439
440    let ser = quote! {
441        for _elem in &self.#field_name {
442            _buf.extend_from_slice(&_elem.to_le_bytes());
443        }
444    };
445
446    let deser = quote! {
447        let #field_name = {
448            let mut _arr = [Default::default(); #len_expr];
449            for _i in 0..#len_expr {
450                _arr[_i] = #deser_one;
451                _pos += #elem_size;
452            }
453            _arr
454        };
455    };
456
457    let size_expr = quote! { (#len_expr * #elem_size) };
458    Ok((dt, ser, deser, size_expr))
459}
460
461fn int_dt_quote(size: u32, signed: bool) -> proc_macro2::TokenStream {
462    let precision = (size * 8) as u16;
463    quote! {
464        rustyhdf5_format::datatype::Datatype::FixedPoint {
465            size: #size,
466            byte_order: rustyhdf5_format::datatype::DatatypeByteOrder::LittleEndian,
467            signed: #signed,
468            bit_offset: 0,
469            bit_precision: #precision,
470        }
471    }
472}