simconnect_sdk_derive/
lib.rs

1//! This crate provides the [`crate::SimConnectObject`] derive macro of simconnect-sdk.
2
3extern crate proc_macro;
4
5use std::collections::HashMap;
6
7use fields::{extract_attribute_properties, parse_field_attributes, ALLOWED_CLASS_ATTRIBUTES};
8use helpers::{get_attribute, mk_err};
9use proc_macro::TokenStream;
10use quote::quote;
11use syn::{parse_macro_input, DeriveInput};
12
13mod fields;
14mod helpers;
15
16/// SimConnectObject derive macro.
17///
18/// # Struct Arguments
19/// * `period` - Required. One of `once`, `visual-frame`, `sim-frame`, `second`.
20/// * `condition` - Optional. Defaults to `none`. The condition of the data. Must be either `none` or `changed`. `changed` = Data will only be sent to the client when one or more values have changed. All the variables in a data definition will be returned if just one of the values changes.
21/// * `interval` - Optional. Defaults to `0`. The number of period events that should elapse between transmissions of the data. `0` means the data is transmitted every Period, `1` means that the data is transmitted every other Period, etc.
22///
23/// # Field Arguments
24/// * `name` - Required. The name of the field. One from <https://www.prepar3d.com/SDKv5/sdk/references/variables/simulation_variables.html>.
25/// * `unit` - Optional. The unit of the field. For `string`s and `bool`s it should be left out or be empty string. For numeric fields it should be one from <https://www.prepar3d.com/SDKv5/sdk/references/variables/simulation_variables.html>.
26///
27/// # Example
28///
29/// ```rust
30/// # use simconnect_sdk_derive::SimConnectObject;
31///
32/// #[derive(Debug, Clone, SimConnectObject)]
33/// #[simconnect(period = "second")]
34/// struct AirplaneData {
35///     #[simconnect(name = "TITLE")]
36///     title: String,
37///     #[simconnect(name = "CATEGORY")]
38///     category: String,
39///     #[simconnect(name = "PLANE LATITUDE", unit = "degrees")]
40///     lat: f64,
41///     #[simconnect(name = "PLANE LONGITUDE", unit = "degrees")]
42///     lon: f64,
43///     #[simconnect(name = "PLANE ALTITUDE", unit = "feet")]
44///     alt: f64,
45///     #[simconnect(name = "SIM ON GROUND", unit = "bool")]
46///     sim_on_ground: bool,
47/// }
48/// ```
49#[proc_macro_derive(SimConnectObject, attributes(simconnect))]
50pub fn derive(input: TokenStream) -> TokenStream {
51    let ast = parse_macro_input!(input as DeriveInput);
52
53    let name_ident = &ast.ident;
54    let packed_ident = syn::Ident::new(&format!("{name_ident}CPacked"), name_ident.span());
55
56    let fields = if let syn::Data::Struct(syn::DataStruct {
57        fields: syn::Fields::Named(syn::FieldsNamed { ref named, .. }),
58        ..
59    }) = ast.data
60    {
61        named
62    } else {
63        return mk_err(
64            ast,
65            "Unsupported field type. Only named fields are supported.",
66        )
67        .into();
68    };
69
70    // parse the fields and their attributes
71    let mut parsed_fields = Vec::with_capacity(fields.len());
72    for field in fields {
73        let result = parse_field_attributes(field);
74
75        match result {
76            Ok(field) => {
77                parsed_fields.push(field);
78            }
79            Err(e) => return e.into(),
80        }
81    }
82
83    // packed struct fields
84    let packed_fields = parsed_fields
85        .iter()
86        .map(|(ident, path, _)| build_packed_field(ident, path));
87    let packed_fields_assignments = parsed_fields
88        .iter()
89        .map(|(ident, path, _)| build_packed_field_assignment(ident, path));
90
91    // SC fields
92    let sc_definition = parsed_fields
93        .iter()
94        .map(|(_, path, properties)| build_sc_definition(path, properties));
95    let sc_request = build_sc_request(&ast);
96
97    // put everything together
98    let expanded = quote! {
99        #[repr(C, packed)]
100        struct #packed_ident {
101            #(#packed_fields,)*
102        }
103        impl simconnect_sdk::SimConnectObjectExt for #name_ident {
104            fn register(client: &mut simconnect_sdk::SimConnect, id: u32) -> Result<(), simconnect_sdk::SimConnectError> {
105                #(#sc_definition)*
106                #sc_request
107                Ok(())
108            }
109        }
110        impl TryFrom<&simconnect_sdk::Object> for #name_ident {
111            type Error = simconnect_sdk::SimConnectError;
112            fn try_from(value: &simconnect_sdk::Object) -> Result<Self, Self::Error> {
113                let raw = value.try_transmute::<#name_ident, #packed_ident>()?;
114                Ok(#name_ident {
115                    #(#packed_fields_assignments,)*
116                })
117            }
118        }
119    };
120
121    expanded.into()
122}
123
124fn build_packed_field(ident: &proc_macro2::Ident, path: &syn::Path) -> proc_macro2::TokenStream {
125    let path_segments = &path.segments;
126    let path_idents = path_segments.iter().map(|s| &s.ident);
127
128    match path_idents.last() {
129        Some(value) if value == "String" => {
130            quote! {
131                #ident: [std::primitive::i8; 256]
132            }
133        }
134        _ => {
135            quote! {
136                #ident: #path
137            }
138        }
139    }
140}
141
142fn build_packed_field_assignment(
143    ident: &proc_macro2::Ident,
144    path: &syn::Path,
145) -> proc_macro2::TokenStream {
146    let path_segments = &path.segments;
147    let path_idents = path_segments.iter().map(|s| &s.ident);
148
149    match path_idents.last() {
150        Some(value) if value == "String" => {
151            quote! {
152                #ident: simconnect_sdk::fixed_c_str_to_string(&raw.#ident)
153            }
154        }
155        _ => {
156            quote! {
157                #ident: raw.#ident
158            }
159        }
160    }
161}
162
163fn build_sc_definition(
164    path: &syn::Path,
165    properties: &HashMap<String, String>,
166) -> proc_macro2::TokenStream {
167    let error_message =
168        "expected attribute `#[simconnect(name = \"...\", unit = \"...\")]`. `unit` is optional.";
169
170    let path_segments = &path.segments;
171    let path_idents = path_segments.iter().map(|s| &s.ident);
172
173    let name = properties.get("name").expect("this should never happen");
174    let unit = match properties.get("unit") {
175        Some(unit) => unit,
176        None => "",
177    };
178
179    match path_idents.last() {
180        Some(value) if value == "f64" => {
181            quote! {
182                client.add_to_data_definition(id, #name, #unit, simconnect_sdk::DataType::Float64)?;
183            }
184        }
185        Some(value) if value == "bool" => {
186            quote! {
187                client.add_to_data_definition(id, #name, #unit, simconnect_sdk::DataType::Bool)?;
188            }
189        }
190        Some(value) if value == "String" => {
191            quote! {
192                client.add_to_data_definition(id, #name, #unit, simconnect_sdk::DataType::String)?;
193            }
194        }
195        _ => {
196            // this error is already caught in `parse_field_attributes`
197            mk_err(path, error_message)
198        }
199    }
200}
201
202fn build_sc_request(ast: &DeriveInput) -> proc_macro2::TokenStream {
203    let attr = get_attribute(&ast.attrs);
204    let error_message = "expected attribute `#[simconnect(period = \"...\", condition = \"...\", interval = ...)]`. `condition` and `interval` are optional.";
205
206    match attr {
207        Some(attr) => {
208            let properties =
209                extract_attribute_properties(attr, &ALLOWED_CLASS_ATTRIBUTES, error_message);
210
211            match properties {
212                Ok(properties) => {
213                    let period = match properties.get("period") {
214                        Some(p) if p == "once" => {
215                            quote! {
216                                simconnect_sdk::Period::Once
217                            }
218                        }
219                        Some(p) if p == "visual-frame" => {
220                            quote! {
221                                simconnect_sdk::Period::VisualFrame
222                            }
223                        }
224                        Some(p) if p == "sim-frame" => {
225                            quote! {
226                                simconnect_sdk::Period::SimFrame
227                            }
228                        }
229                        _ => {
230                            quote! {
231                                simconnect_sdk::Period::Second
232                            }
233                        }
234                    };
235
236                    let condition = match properties.get("condition") {
237                        Some(c) if c == "changed" => {
238                            quote! {
239                                simconnect_sdk::Condition::Changed
240                            }
241                        }
242                        _ => {
243                            quote! {
244                                simconnect_sdk::Condition::None
245                            }
246                        }
247                    };
248
249                    let interval = match properties.get("interval") {
250                        Some(i) => i.parse::<u32>().unwrap_or_default(),
251                        None => 0,
252                    };
253
254                    quote! {
255                        client.request_data_on_sim_object(id, #period, #condition, #interval)?;
256                    }
257                }
258                Err(e) => e,
259            }
260        }
261        None => mk_err(ast, error_message),
262    }
263}