stargate_grpc_derive/
lib.rs

1//! # Derive macros for mapping between `Value` and Rust structs
2//!
3//! Converting structures from/to hash maps manually is tedious.
4//! This module defines a few derive macros that can generate the conversion code automatically
5//! for you.
6//!
7//! ## Converting a custom Rust struct to a `Value`
8//! ```
9//! use stargate_grpc::Value;
10//! use stargate_grpc_derive::IntoValue;
11//! #[derive(IntoValue)]
12//! struct User {
13//!     id: i64,
14//!     login: &'static str
15//! }
16//!
17//! let user = User { id: 1, login: "user" };
18//! let value = Value::from(user);
19//!
20//! assert_eq!(value, Value::udt(vec![("id", Value::bigint(1)), ("login", Value::string("user"))]))
21//! ```
22//!
23//! ## Converting a `Value` to a custom Rust struct
24//! ```
25//! use stargate_grpc::Value;
26//! use stargate_grpc_derive::TryFromValue;
27//!
28//! #[derive(TryFromValue)]
29//! struct User {
30//!     id: i64,
31//!     login: String
32//! }
33//!
34//! let value = Value::udt(vec![("id", Value::bigint(1)), ("login", Value::string("user"))]);
35//! let user: User = value.try_into().unwrap();
36//!
37//! assert_eq!(user.id, 1);
38//! assert_eq!(user.login, "user".to_string());
39//! ```
40//!
41//! ## Using custom structs as arguments in queries
42//! It is possible to unpack struct fields in such a way that each field value
43//! gets bound to a named argument of a query. For that to work, the struct must implement
//! the [`std::convert::Into<Values>`] trait. You can derive this trait automatically:
45//!
46//! ```
47//! use stargate_grpc::Query;
48//! use stargate_grpc_derive::IntoValues;
49//!
50//! #[derive(IntoValues)]
51//! struct User {
52//!     id: i64,
53//!     login: &'static str
54//! }
55//!
56//! let user = User { id: 1, login: "user" };
57//! let query = Query::builder()
58//!     .query("INSERT INTO users(id, login) VALUES (:id, :login)")
59//!     .bind(user)  // bind user.id to :id and user.login to :login
60//!     .build();
61//! ```
62//! ## Converting result set rows to custom struct values
63//! You can convert a `Row` to a value of your custom type by deriving
64//! [`TryFromRow`] and then passing the rows to a mapper:
65//!
66//! ```no_run
67//! use stargate_grpc::*;
68//! use stargate_grpc_derive::*;
69//!
70//! #[derive(Debug, TryFromRow)]
71//! struct User {
72//!     id: i64,
73//!     login: String,
74//! }
75//!
76//! let result_set: ResultSet = unimplemented!();  // replace with actual code to run a query
77//! let mapper = result_set.mapper().unwrap();
78//! for row in result_set.rows {
79//!     let user: User = mapper.try_unpack(row).unwrap();
80//!     println!("{:?}", user)
81//! }
82//!
83//! ```
84//!
85//! ## Options
86//! All macros defined in this module accept a `#[stargate]` attribute that you can set
87//! on struct fields to control the details of how the conversion should be made.
88//!
89//! ### `#[stargate(skip)]`
90//! Skips the field when doing the conversion to `Value`. This is useful when the structure
91//! needs to store some data that are not mapped to the database schema.
92//! However, the field is included in the conversion from `Value`, and the conversion would fail
93//! if it was missing, hence you likely need to set `#[stargate(default)]` as well.
94//!
95//! ### `#[stargate(default)]`
96//! Uses the default value for the field type provided by [`std::default::Default`],
97//! if the source `Value` doesn't contain the field, or if the field is set to `Value::null`
98//! or `Value::unset`.
99//!
100//! ### `#[stargate(default = "expression")]`
//! Obtains the default value by evaluating the given Rust expression, passed as a string.
102//!
103//! ```
104//! use stargate_grpc_derive::TryFromValue;
105//!
106//! fn default_file_name() -> String {
107//!     "file.txt".to_string()
108//! }
109//!
110//! #[derive(TryFromValue)]
111//! struct File {
112//!     #[stargate(default = "default_file_name()")]
113//!     path: String,
114//! }
115//! ```
116//!
117//! ### `#[stargate(cql_type = "type")]`
118//! Sets the target CQL type the field should be converted into, useful
119//! when there are multiple possibilities.
120//!
121//! ```
122//! use stargate_grpc::types;
123//! use stargate_grpc_derive::IntoValue;
124//!
125//! #[derive(IntoValue)]
126//! struct InetAndUuid {
127//!     #[stargate(cql_type = "types::Inet")]
128//!     inet: [u8; 16],
129//!     #[stargate(cql_type = "types::Uuid")]
130//!     uuid: [u8; 16],
131//! }
132//! ```
133//!
134//! ### `#[stargate(name = "column")]`
135//! Sets the CQL field, column or query argument name associated with the field.
136//! If not given, it is assumed to be the same as struct field name.
137//!
138use proc_macro::TokenStream;
139
140use darling::util::Override;
141use darling::{ast, util, FromDeriveInput, FromField};
142use quote::quote;
143use syn::__private::TokenStream2;
144
145#[derive(Debug, FromField)]
146#[darling(attributes(stargate))]
147struct UdtField {
148    ident: Option<syn::Ident>,
149    ty: syn::Type,
150    #[darling(default)]
151    default: Option<Override<String>>,
152    #[darling(default)]
153    cql_type: Option<String>,
154    #[darling(default)]
155    skip: bool,
156    #[darling(default)]
157    name: Option<String>,
158}
159
160#[derive(Debug, FromDeriveInput)]
161struct Udt {
162    ident: syn::Ident,
163    data: ast::Data<util::Ignored, UdtField>,
164}
165
166fn get_fields(udt: ast::Data<util::Ignored, UdtField>) -> Vec<UdtField> {
167    match udt {
168        ast::Data::Struct(s) => s.fields,
169        _ => panic!("Deriving IntoValue allowed only on structs"),
170    }
171}
172
173fn field_idents(fields: &[UdtField]) -> Vec<&syn::Ident> {
174    fields.iter().map(|f| f.ident.as_ref().unwrap()).collect()
175}
176
177/// Lists the field names of the associated Udt, Row or Values.
178fn field_names(fields: &[UdtField]) -> Vec<String> {
179    fields
180        .iter()
181        .map(|f| {
182            f.name
183                .clone()
184                .unwrap_or_else(|| f.ident.as_ref().unwrap().to_string())
185        })
186        .collect()
187}
188
189fn token_stream(s: &str) -> proc_macro2::TokenStream {
190    s.parse().unwrap()
191}
192
193/// Emits code for reading the field value and converting it to a `Value`.
194fn convert_to_value(obj: &syn::Ident, field: &UdtField) -> TokenStream2 {
195    let field_ident = field.ident.as_ref().unwrap();
196    match &field.cql_type {
197        Some(t) => {
198            let cql_type = token_stream(t.as_str());
199            quote! { stargate_grpc::Value::of_type(#cql_type, #obj.#field_ident) }
200        }
201        None => {
202            quote! { stargate_grpc::Value::from(#obj.#field_ident) }
203        }
204    }
205}
206
207/// For each field, returns an expression that converts that field's value to a `Value`.
208fn convert_to_values(obj: &syn::Ident, fields: &[UdtField]) -> Vec<TokenStream2> {
209    fields.iter().map(|f| convert_to_value(obj, f)).collect()
210}
211
212/// Derives the `IntoValue` and `DefaultCqlType` implementations for a struct.
213#[proc_macro_derive(IntoValue, attributes(stargate))]
214pub fn derive_into_value(tokens: TokenStream) -> TokenStream {
215    let parsed = syn::parse(tokens).unwrap();
216    let udt: Udt = Udt::from_derive_input(&parsed).unwrap();
217    let udt_type = udt.ident;
218
219    let obj = syn::Ident::new("obj", proc_macro2::Span::mixed_site());
220    let fields: Vec<_> = get_fields(udt.data)
221        .into_iter()
222        .filter(|f| !f.skip)
223        .collect();
224    let remote_field_names = field_names(&fields);
225    let field_values: Vec<_> = convert_to_values(&obj, &fields);
226
227    let result = quote! {
228        impl stargate_grpc::into_value::IntoValue<stargate_grpc::types::Udt> for #udt_type {
229            fn into_value(self) -> stargate_grpc::Value {
230                let #obj = self;
231                let mut fields = std::collections::HashMap::new();
232                #(fields.insert(#remote_field_names.to_string(), #field_values));*;
233                stargate_grpc::Value::raw_udt(fields)
234            }
235        }
236        impl stargate_grpc::into_value::DefaultCqlType for #udt_type {
237            type C = stargate_grpc::types::Udt;
238        }
239    };
240    result.into()
241}
242
243/// Derives the `IntoValues` impl that allows to use struct in `QueryBuilder::bind`
244#[proc_macro_derive(IntoValues, attributes(stargate))]
245pub fn derive_into_values(tokens: TokenStream) -> TokenStream {
246    let parsed = syn::parse(tokens).unwrap();
247    let udt: Udt = Udt::from_derive_input(&parsed).unwrap();
248    let udt_type = udt.ident;
249
250    let obj = syn::Ident::new("obj", proc_macro2::Span::mixed_site());
251    let fields: Vec<_> = get_fields(udt.data)
252        .into_iter()
253        .filter(|f| !f.skip)
254        .collect();
255    let field_names = field_names(&fields);
256    let field_values: Vec<_> = convert_to_values(&obj, &fields);
257
258    let result = quote! {
259        impl std::convert::From<#udt_type> for stargate_grpc::proto::Values {
260            fn from(#obj: #udt_type) -> Self {
261                stargate_grpc::proto::Values {
262                     value_names: vec![#(#field_names.to_string()),*],
263                     values: vec![#(#field_values),*]
264                }
265            }
266        }
267    };
268    result.into()
269}
270
271/// Emits code for reading the field from a hashmap and converting it to proper type.
272/// Applies default value if the key is missing in the hashmap or if the value
273/// under the key is null.
274fn convert_from_hashmap_value(hashmap: &syn::Ident, field: &UdtField) -> TokenStream2 {
275    let field_name = field
276        .name
277        .clone()
278        .unwrap_or_else(|| field.ident.as_ref().unwrap().to_string());
279    let field_type = &field.ty;
280
281    let default_expr = match &field.default {
282        None => quote! { Err(ConversionError::field_not_found::<_, Self>(&#hashmap, #field_name)) },
283        Some(Override::Inherit) => quote! { Ok(std::default::Default::default()) },
284        Some(Override::Explicit(s)) => {
285            let path = token_stream(s);
286            quote! { Ok(#path) }
287        }
288    };
289
290    quote! {
291        match #hashmap.remove(#field_name) {
292            Some(value) => {
293                let maybe_value: Option<#field_type> = value.try_into()?;
294                match maybe_value {
295                    Some(v) => Ok(v),
296                    None => #default_expr
297                }
298            }
299            None => #default_expr
300        }
301    }
302}
303
304/// Derives the `TryFromValue` implementation for a struct.
305#[proc_macro_derive(TryFromValue, attributes(stargate))]
306pub fn derive_try_from_value(tokens: TokenStream) -> TokenStream {
307    let parsed = syn::parse(tokens).unwrap();
308    let udt: Udt = Udt::from_derive_input(&parsed).unwrap();
309    let ident = udt.ident;
310    let fields = get_fields(udt.data);
311    let field_idents = field_idents(&fields);
312    let udt_hashmap = syn::Ident::new("fields", proc_macro2::Span::mixed_site());
313    let field_values = fields
314        .iter()
315        .map(|field| convert_from_hashmap_value(&udt_hashmap, field));
316
317    let result = quote! {
318
319        impl stargate_grpc::from_value::TryFromValue for #ident {
320            fn try_from(value: stargate_grpc::Value) ->
321                Result<Self, stargate_grpc::error::ConversionError>
322            {
323                use stargate_grpc::Value;
324                use stargate_grpc::error::ConversionError;
325                use stargate_grpc::proto::*;
326                match value.inner {
327                    Some(value::Inner::Udt(UdtValue { mut #udt_hashmap })) => {
328                        Ok(#ident {
329                            #(#field_idents: #field_values?),*
330                        })
331                    }
332                    other => Err(ConversionError::incompatible::<_, Self>(other))
333                }
334            }
335        }
336
337        impl std::convert::TryFrom<stargate_grpc::Value> for #ident {
338            type Error = stargate_grpc::error::ConversionError;
339            fn try_from(value: stargate_grpc::Value) ->
340                Result<Self, stargate_grpc::error::ConversionError>
341            {
342                <#ident as stargate_grpc::from_value::TryFromValue>::try_from(value)
343            }
344        }
345    };
346
347    result.into()
348}
349
350/// Derives the `TryFromRow` implementation for a struct.
351#[proc_macro_derive(TryFromRow, attributes(stargate))]
352pub fn derive_try_from_typed_row(tokens: TokenStream) -> TokenStream {
353    let parsed = syn::parse(tokens).unwrap();
354    let udt: Udt = Udt::from_derive_input(&parsed).unwrap();
355    let ident = udt.ident;
356    let fields = get_fields(udt.data);
357    let field_idents = field_idents(&fields);
358    let field_names = field_names(&fields);
359    let indexes = 0..field_idents.len();
360
361    let result = quote! {
362        impl stargate_grpc::result::ColumnPositions for #ident {
363            fn field_to_column_pos(
364                column_positions: std::collections::HashMap<String, usize>
365            ) -> Result<Vec<usize>, stargate_grpc::result::MapperError>
366            {
367                use stargate_grpc::result::MapperError;
368                let mut result = Vec::new();
369                #(
370                    result.push(
371                        *column_positions
372                            .get(#field_names)
373                            .ok_or_else(|| MapperError::ColumnNotFound(#field_names))?
374                    );
375                )*
376                Ok(result)
377            }
378        }
379
380        impl stargate_grpc::result::TryFromRow for #ident {
381            fn try_unpack(
382                mut row: stargate_grpc::Row,
383                column_positions: &[usize]
384            ) -> Result<Self, stargate_grpc::error::ConversionError>
385            {
386                Ok(#ident {
387                    #(#field_idents: row.values[column_positions[#indexes]].take().try_into()?),*
388                })
389            }
390        }
391    };
392
393    result.into()
394}