sqlx_derive_with/
lib.rs

1#![doc = include_str!("../README.md")]
2
3/// Derive `sqlx::FromRow` specific to the given database.
4///
/// The original derive macro of `sqlx::FromRow` is database-agnostic, which makes it too generic
/// to support custom decoders for specific columns. For example, `sqlx::FromRow` cannot support
/// a custom decoder like the one below.
8///
9/// ```ignore
10/// #[derive(sqlx::FromRow)]
11/// struct Row {
12///     #[sqlx(decode = "split_x")]
13///     x: (i64, i64),
14/// }
15///
16/// fn split_x<'r, R>(index: &'r str, row: &'r R) -> sqlx::Result<(i64, i64)>
17/// where
18///     R: sqlx::Row,
19///     &'r str: sqlx::ColumnIndex<R>,
20///     i64: sqlx::Type<R::Database> + sqlx::Decode<'r, R::Database>,
21/// {
22///     let n: i64 = row.try_get(index)?;
23///     Ok((n, n + 2))
24/// }
25/// ```
26///
/// This is because `sqlx::FromRow` cannot add the bound
/// `i64: sqlx::Type<R::Database> + sqlx::Decode<'r, R::Database>` to the derived implementation:
/// it cannot see the `row.try_get()` call from the struct definition alone.
30///
/// sqlx-derive-with resolves the problem by requiring the database to be specified explicitly.
32///
33/// # Usage
34/// Basic usage is similar to `sqlx::FromRow`.
35///
36/// ```
37/// #[derive(sqlx_derive_with::FromRow)]
38/// #[sqlx_with(db = "sqlx::Sqlite")]
39/// struct Row {
40///     x: i64,
41///     y: String,
42/// }
43/// ```
44///
45/// You have to specify `db`.
46///
47/// ```compile_fail
48/// #[derive(sqlx_derive_with::FromRow)]
49/// struct Row {
50///     x: i64,
51///     y: String,
52/// }
53/// ```
54///
/// You cannot use sqlx-derive-with for tuple structs. Use the original `sqlx::FromRow`
/// instead.
57///
58/// ```compile_fail
59/// #[derive(sqlx_derive_with::FromRow)]
60/// #[sqlx_with(db = "sqlx::Sqlite")]
61/// struct Row(i64, String);
62/// ```
63///
64/// # Container attributes
65/// ## rename_all
/// Specify the conversion from field names to column names.
67///
68/// ```
69/// #[derive(sqlx_derive_with::FromRow)]
70/// #[sqlx_with(db = "sqlx::Sqlite", rename_all = "camelCase")]
71/// struct Row {
72///     foo_bar: i64,   // deserialized from column "fooBar"
73/// }
74/// ```
75///
76/// # Field attributes
77/// ## rename
/// Configure the column name explicitly. `rename` takes precedence over `rename_all`.
79///
80/// ```
81/// #[derive(sqlx_derive_with::FromRow)]
82/// #[sqlx_with(db = "sqlx::Sqlite")]
83/// struct Row {
84///     #[sqlx_with(rename = "z")]
85///     x: i64, // deserialized from column "z"
86///     y: String,  // deserialized from column "y"
87/// }
88/// ```
89///
90/// ## default
/// Use the `Default::default()` value when the column doesn't exist.
92///
93/// ```
94/// #[derive(sqlx_derive_with::FromRow)]
95/// #[sqlx_with(db = "sqlx::Sqlite")]
96/// struct Row {
97///     #[sqlx_with(default)]
98///     x: i64, // i64::default() value is set when column "x" doesn't exist.
99///     y: String,
100/// }
101/// ```
102///
/// ## flatten
/// Decode the field with its own type's `FromRow` implementation, reading from the same row.
///
105/// ```
106/// #[derive(sqlx_derive_with::FromRow)]
107/// #[sqlx_with(db = "sqlx::Sqlite")]
108/// struct Row {
109///     x: i64,
110///     #[sqlx_with(flatten)]
111///     y: Y,
112/// }
113/// #[derive(sqlx_derive_with::FromRow)]
114/// #[sqlx_with(db = "sqlx::Sqlite")]
115/// struct Y {
116///     z: i64,
117///     w: i64,
118/// }
119/// ```
120///
121/// ## decode
/// Configure a custom decode function for specific columns. The function is called with the
/// column name and the row.
123///
124/// ```
125/// #[derive(sqlx_derive_with::FromRow)]
126/// #[sqlx_with(db = "sqlx::Sqlite")]
127/// struct Row {
128///     #[sqlx_with(decode = "split_x")]
129///     x: (i64, i64),
130///     y: String,
131/// }
132///
133/// fn split_x(index: &str, row: &sqlx::sqlite::SqliteRow) -> sqlx::Result<(i64, i64)> {
134///     use sqlx::Row as _;
135///     let n: i64 = row.try_get(index)?;
136///     Ok((n, n + 2))
137/// }
138/// ```
139#[proc_macro_derive(FromRow, attributes(sqlx_with))]
140pub fn derive_sqlx_with(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
141    let input = syn::parse_macro_input!(input as syn::DeriveInput);
142
143    match expand_derive(input) {
144        Ok(ts) => ts.into(),
145        Err(e) => e.to_compile_error().into(),
146    }
147}
148
149#[derive(Debug, darling::FromMeta)]
150enum RenameAll {
151    #[darling(rename = "snake_case")]
152    Snake,
153    #[darling(rename = "lowercase")]
154    Lower,
155    #[darling(rename = "UPPERCASE")]
156    Upper,
157    #[darling(rename = "camelCase")]
158    Camel,
159    #[darling(rename = "PascalCase")]
160    Pascal,
161    #[darling(rename = "SCREAMING_SNAKE_CASE")]
162    ScreamingSnake,
163    #[darling(rename = "kebab-case")]
164    Kebab,
165}
166
167#[derive(Debug, darling::FromDeriveInput)]
168#[darling(attributes(sqlx_with), supports(struct_named))]
169struct DeriveInput {
170    ident: syn::Ident,
171    generics: syn::Generics,
172    data: darling::ast::Data<(), Field>,
173    db: syn::Path,
174    rename_all: Option<RenameAll>,
175}
176
177#[derive(Debug, darling::FromField)]
178#[darling(attributes(sqlx_with))]
179struct Field {
180    ident: Option<syn::Ident>,
181    ty: syn::Type,
182    rename: Option<String>,
183    default: darling::util::Flag,
184    decode: Option<syn::Path>,
185    flatten: darling::util::Flag,
186}
187
188fn expand_derive(input: syn::DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
189    use darling::FromDeriveInput as _;
190
191    let input = DeriveInput::from_derive_input(&input)?;
192
193    let mut struct_expr: syn::ExprStruct = syn::parse_quote!(Self {});
194    for field in input.data.take_struct().unwrap().fields {
195        let id = field.ident.unwrap();
196        let column_val_expr: syn::Expr = if field.flatten.is_present() {
197            let ty = field.ty;
198            syn::parse_quote!(#ty::from_row(row)?)
199        } else {
200            let column_name = if let Some(rename) = field.rename {
201                rename
202            } else if let Some(ref rename_all) = input.rename_all {
203                use heck::*;
204
205                match rename_all {
206                    RenameAll::Snake => id.to_string().to_snake_case(),
207                    RenameAll::Lower => id.to_string().to_lowercase(),
208                    RenameAll::Upper => id.to_string().to_uppercase(),
209                    RenameAll::Camel => id.to_string().to_lower_camel_case(),
210                    RenameAll::Pascal => id.to_string().to_upper_camel_case(),
211                    RenameAll::ScreamingSnake => id.to_string().to_shouty_snake_case(),
212                    RenameAll::Kebab => id.to_string().to_kebab_case(),
213                }
214            } else {
215                id.to_string()
216            };
217            let column_get_expr: syn::Expr = if let Some(decode) = field.decode {
218                syn::parse_quote!(#decode(#column_name, row))
219            } else {
220                syn::parse_quote!(row.try_get(#column_name))
221            };
222            if field.default.is_present() {
223                syn::parse_quote! {
224                    match #column_get_expr {
225                        ::std::result::Result::Err(::sqlx::Error::ColumnNotFound(_)) => ::std::result::Result::Ok(::std::default::Default::default()),
226                        val => val,
227                    }?
228                }
229            } else {
230                syn::parse_quote!(#column_get_expr?)
231            }
232        };
233        struct_expr
234            .fields
235            .push(syn::parse_quote!(#id: #column_val_expr));
236    }
237
238    let struct_ident = input.ident;
239    let db = input.db;
240    let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
241    Ok(quote::quote! {
242        impl #impl_generics ::sqlx::FromRow<'_, <#db as ::sqlx::Database>::Row> for #struct_ident #type_generics #where_clause {
243            fn from_row(row: &<#db as ::sqlx::Database>::Row) -> ::sqlx::Result<Self> {
244                use ::sqlx::Row;
245                Ok(#struct_expr)
246            }
247        }
248    })
249}