scyllax_macros_core/queries/
upsert.rs

1use darling::{ast::NestedMeta, FromDeriveInput, FromMeta};
2use proc_macro2::TokenStream;
3use quote::{quote, ToTokens};
4use syn::{DeriveInput, ItemStruct};
5
6use crate::entity::{EntityDerive, EntityDeriveColumn};
7
8#[derive(FromMeta)]
9pub(crate) struct UpsertQueryOptions {
10    pub name: syn::Ident,
11    pub table: String,
12    pub ttl: Option<bool>,
13}
14
15/// Attribute expand
16/// Just adds the dervie macro to the struct.
17pub fn expand(args: TokenStream, input: TokenStream) -> TokenStream {
18    let attr_args = match NestedMeta::parse_meta_list(args.clone()) {
19        Ok(args) => args,
20        Err(e) => return darling::Error::from(e).write_errors(),
21    };
22
23    let args = match UpsertQueryOptions::from_list(&attr_args) {
24        Ok(o) => o,
25        Err(e) => return e.write_errors(),
26    };
27
28    let input: DeriveInput = match syn::parse2(input.clone()) {
29        Ok(it) => it,
30        Err(e) => return e.to_compile_error(),
31    };
32
33    let entity = match EntityDerive::from_derive_input(&input) {
34        Ok(e) => e,
35        Err(e) => return e.write_errors(),
36    };
37
38    upsert_impl(&input, &args, &entity)
39}
40
41/// Create the implementation for the upsert query
42pub(crate) fn upsert_impl(
43    input: &DeriveInput,
44    opt: &UpsertQueryOptions,
45    entity: &EntityDerive,
46) -> TokenStream {
47    let upsert_struct = &opt.name;
48    let upsert_table = &opt.table;
49    let struct_ident = &input.ident;
50    let keys = entity
51        .data
52        .as_ref()
53        .take_struct()
54        .expect("Should never be enum")
55        .fields;
56    let primary_keys: Vec<&&EntityDeriveColumn> = keys.iter().filter(|f| f.primary_key).collect();
57    let counters: Vec<&&EntityDeriveColumn> = keys.iter().filter(|f| f.counter).collect();
58
59    let input: ItemStruct = match syn::parse2(input.to_token_stream()) {
60        Ok(it) => it,
61        Err(e) => return e.to_compile_error(),
62    };
63
64    let expanded_pks = primary_keys
65        .iter()
66        .map(|f| {
67            let ident = &f.ident;
68            let ty = &f.ty;
69            let comment = format!("The {} of the {}", ident.as_ref().unwrap(), struct_ident);
70
71            quote! {
72                #[doc = #comment]
73                pub #ident: #ty
74            }
75        })
76        .collect::<Vec<_>>();
77
78    let maybe_unset_fields = keys
79        .iter()
80        .filter(|f| !primary_keys.contains(f))
81        .map(|f| {
82            let ident = &f.ident;
83            let ty = &f.ty;
84            let comment = format!("The {} of the {}", ident.as_ref().unwrap(), struct_ident);
85
86            quote! {
87                #[doc = #comment]
88                pub #ident: scyllax::prelude::MaybeUnset<#ty>
89            }
90        })
91        .collect::<Vec<_>>();
92
93    let ttl = if opt.ttl.unwrap_or(false) {
94        quote! {
95            #[doc = "The ttl of the row in seconds"]
96            pub set_ttl: i32,
97        }
98    } else {
99        quote! {}
100    };
101
102    let docs = format!(
103        "Upserts a {} into the `{}` table",
104        struct_ident, upsert_table
105    );
106    let expanded_upsert_struct = quote! {
107        #[doc = #docs]
108        #[derive(Debug, Clone, scylla::SerializeRow)]
109        pub struct #upsert_struct {
110            #(#expanded_pks,)*
111            #(#maybe_unset_fields,)*
112            #ttl
113        }
114    };
115
116    // SET clauses
117    // expanded variables will loop over every field that isn't Pk
118    let set_clauses = keys
119        .iter()
120        // filter out pks
121        .filter(|f| !primary_keys.contains(f))
122        .map(|f| {
123            let ident = &f.ident.clone().unwrap();
124            let col = f.name.as_ref().unwrap();
125            let ident_string = ident.to_string();
126
127            if counters.contains(&f) {
128                format!("{col} = {col} + :{ident_string}")
129            } else {
130                format!("{col} = :{ident_string}")
131            }
132        })
133        .collect::<Vec<_>>();
134
135    // WHERE clauses
136    let where_clauses = keys
137        .iter()
138        // filter out pks
139        .filter(|f| primary_keys.contains(f))
140        .map(|f| {
141            let ident = &f.ident.clone().unwrap();
142            let col = f.name.as_ref().unwrap();
143            let named_var = ident.to_string();
144
145            (col.clone(), named_var.clone())
146        })
147        .collect::<Vec<_>>();
148
149    // if there are no set clauses, then we need to do an insert
150    // because we can't do an update with no set clauses
151    let query = build_query(opt, upsert_table, set_clauses, where_clauses);
152
153    quote! {
154        #input
155
156        #expanded_upsert_struct
157
158        #[scyllax::prelude::async_trait]
159        impl scyllax::prelude::Query for #upsert_struct {
160            fn query() -> String {
161                #query.to_string()
162            }
163        }
164
165        impl scyllax::prelude::WriteQuery for #upsert_struct {}
166    }
167}
168
169fn build_query(
170    args: &UpsertQueryOptions,
171    table: &String,
172    set_clauses: Vec<String>,
173    where_clauses: Vec<(String, String)>,
174) -> String {
175    let ttl = match args.ttl.unwrap_or(false) {
176        true => " using ttl :set_ttl",
177        _ => "",
178    };
179
180    if set_clauses.is_empty() {
181        let mut query = format!("insert into {table}{ttl}");
182        let (cols, named_var) = where_clauses.into_iter().unzip::<_, _, Vec<_>, Vec<_>>();
183        let cols = cols.join(", ");
184        let named_var = named_var
185            .into_iter()
186            .map(|var| format!(":{var}"))
187            .collect::<Vec<_>>()
188            .join(", ");
189
190        query.push_str(&format!(" ({cols}) values ({named_var});"));
191
192        query
193    } else {
194        let mut query = format!("update {table}{ttl} set ");
195        let query_set = set_clauses.join(", ");
196        query.push_str(&query_set);
197
198        query.push_str(" where ");
199        let query_where = where_clauses
200            .into_iter()
201            .map(|(col, ident_string)| format!("{col} = :{ident_string}"))
202            .collect::<Vec<_>>()
203            .join(" and ");
204        query.push_str(&query_where);
205
206        query.push(';');
207
208        query
209    }
210}
211
212#[cfg(test)]
213mod tests {
214    use super::*;
215
216    fn get_set_clauses() -> Vec<String> {
217        vec![
218            "name = :name",
219            "email = :email",
220            "\"createdAt\" = :created_at",
221        ]
222        .into_iter()
223        .map(|x| x.to_string())
224        .collect::<Vec<_>>()
225    }
226
227    fn get_where_clauses() -> Vec<(String, String)> {
228        vec![("id", "id"), (r#""orgId""#, "org_id")]
229            .into_iter()
230            .map(|(x, y)| (x.to_string(), y.to_string()))
231            .collect::<Vec<_>>()
232    }
233
234    #[test]
235    fn test_update() {
236        let query = build_query(
237            &UpsertQueryOptions {
238                name: syn::parse_str::<syn::Ident>("UpdatePerson").unwrap(),
239                table: "person".to_string(),
240                ttl: None,
241            },
242            &"person".to_string(),
243            get_set_clauses(),
244            get_where_clauses(),
245        );
246
247        assert_eq!(
248            query,
249            "update person set name = :name, email = :email, \"createdAt\" = :created_at where id = :id and \"orgId\" = :org_id;",
250        );
251    }
252
253    #[test]
254    fn test_update_ttl() {
255        let query = build_query(
256            &UpsertQueryOptions {
257                name: syn::parse_str::<syn::Ident>("UpdatePerson").unwrap(),
258                table: "person".to_string(),
259                ttl: Some(true),
260            },
261            &"person".to_string(),
262            get_set_clauses(),
263            get_where_clauses(),
264        );
265
266        assert_eq!(
267            query,
268            "update person using ttl :set_ttl set name = :name, email = :email, \"createdAt\" = :created_at where id = :id and \"orgId\" = :org_id;",
269        );
270    }
271
272    #[test]
273    fn test_insert() {
274        let query = build_query(
275            &UpsertQueryOptions {
276                name: syn::parse_str::<syn::Ident>("UpdatePerson").unwrap(),
277                table: "person".to_string(),
278                ttl: Default::default(),
279            },
280            &"person".to_string(),
281            vec![],
282            get_where_clauses(),
283        );
284
285        assert_eq!(
286            query,
287            "insert into person (id, \"orgId\") values (:id, :org_id);",
288        );
289    }
290
291    #[test]
292    fn test_insert_ttl() {
293        let query = build_query(
294            &UpsertQueryOptions {
295                name: syn::parse_str::<syn::Ident>("UpdatePerson").unwrap(),
296                table: "person".to_string(),
297                ttl: Some(true),
298            },
299            &"person".to_string(),
300            vec![],
301            get_where_clauses(),
302        );
303
304        assert_eq!(
305            query,
306            "insert into person using ttl :set_ttl (id, \"orgId\") values (:id, :org_id);",
307        );
308    }
309}