scyllax_macros_core/queries/
upsert.rs1use darling::{ast::NestedMeta, FromDeriveInput, FromMeta};
2use proc_macro2::TokenStream;
3use quote::{quote, ToTokens};
4use syn::{DeriveInput, ItemStruct};
5
6use crate::entity::{EntityDerive, EntityDeriveColumn};
7
8#[derive(FromMeta)]
9pub(crate) struct UpsertQueryOptions {
10 pub name: syn::Ident,
11 pub table: String,
12 pub ttl: Option<bool>,
13}
14
15pub fn expand(args: TokenStream, input: TokenStream) -> TokenStream {
18 let attr_args = match NestedMeta::parse_meta_list(args.clone()) {
19 Ok(args) => args,
20 Err(e) => return darling::Error::from(e).write_errors(),
21 };
22
23 let args = match UpsertQueryOptions::from_list(&attr_args) {
24 Ok(o) => o,
25 Err(e) => return e.write_errors(),
26 };
27
28 let input: DeriveInput = match syn::parse2(input.clone()) {
29 Ok(it) => it,
30 Err(e) => return e.to_compile_error(),
31 };
32
33 let entity = match EntityDerive::from_derive_input(&input) {
34 Ok(e) => e,
35 Err(e) => return e.write_errors(),
36 };
37
38 upsert_impl(&input, &args, &entity)
39}
40
41pub(crate) fn upsert_impl(
43 input: &DeriveInput,
44 opt: &UpsertQueryOptions,
45 entity: &EntityDerive,
46) -> TokenStream {
47 let upsert_struct = &opt.name;
48 let upsert_table = &opt.table;
49 let struct_ident = &input.ident;
50 let keys = entity
51 .data
52 .as_ref()
53 .take_struct()
54 .expect("Should never be enum")
55 .fields;
56 let primary_keys: Vec<&&EntityDeriveColumn> = keys.iter().filter(|f| f.primary_key).collect();
57 let counters: Vec<&&EntityDeriveColumn> = keys.iter().filter(|f| f.counter).collect();
58
59 let input: ItemStruct = match syn::parse2(input.to_token_stream()) {
60 Ok(it) => it,
61 Err(e) => return e.to_compile_error(),
62 };
63
64 let expanded_pks = primary_keys
65 .iter()
66 .map(|f| {
67 let ident = &f.ident;
68 let ty = &f.ty;
69 let comment = format!("The {} of the {}", ident.as_ref().unwrap(), struct_ident);
70
71 quote! {
72 #[doc = #comment]
73 pub #ident: #ty
74 }
75 })
76 .collect::<Vec<_>>();
77
78 let maybe_unset_fields = keys
79 .iter()
80 .filter(|f| !primary_keys.contains(f))
81 .map(|f| {
82 let ident = &f.ident;
83 let ty = &f.ty;
84 let comment = format!("The {} of the {}", ident.as_ref().unwrap(), struct_ident);
85
86 quote! {
87 #[doc = #comment]
88 pub #ident: scyllax::prelude::MaybeUnset<#ty>
89 }
90 })
91 .collect::<Vec<_>>();
92
93 let ttl = if opt.ttl.unwrap_or(false) {
94 quote! {
95 #[doc = "The ttl of the row in seconds"]
96 pub set_ttl: i32,
97 }
98 } else {
99 quote! {}
100 };
101
102 let docs = format!(
103 "Upserts a {} into the `{}` table",
104 struct_ident, upsert_table
105 );
106 let expanded_upsert_struct = quote! {
107 #[doc = #docs]
108 #[derive(Debug, Clone, scylla::SerializeRow)]
109 pub struct #upsert_struct {
110 #(#expanded_pks,)*
111 #(#maybe_unset_fields,)*
112 #ttl
113 }
114 };
115
116 let set_clauses = keys
119 .iter()
120 .filter(|f| !primary_keys.contains(f))
122 .map(|f| {
123 let ident = &f.ident.clone().unwrap();
124 let col = f.name.as_ref().unwrap();
125 let ident_string = ident.to_string();
126
127 if counters.contains(&f) {
128 format!("{col} = {col} + :{ident_string}")
129 } else {
130 format!("{col} = :{ident_string}")
131 }
132 })
133 .collect::<Vec<_>>();
134
135 let where_clauses = keys
137 .iter()
138 .filter(|f| primary_keys.contains(f))
140 .map(|f| {
141 let ident = &f.ident.clone().unwrap();
142 let col = f.name.as_ref().unwrap();
143 let named_var = ident.to_string();
144
145 (col.clone(), named_var.clone())
146 })
147 .collect::<Vec<_>>();
148
149 let query = build_query(opt, upsert_table, set_clauses, where_clauses);
152
153 quote! {
154 #input
155
156 #expanded_upsert_struct
157
158 #[scyllax::prelude::async_trait]
159 impl scyllax::prelude::Query for #upsert_struct {
160 fn query() -> String {
161 #query.to_string()
162 }
163 }
164
165 impl scyllax::prelude::WriteQuery for #upsert_struct {}
166 }
167}
168
169fn build_query(
170 args: &UpsertQueryOptions,
171 table: &String,
172 set_clauses: Vec<String>,
173 where_clauses: Vec<(String, String)>,
174) -> String {
175 let ttl = match args.ttl.unwrap_or(false) {
176 true => " using ttl :set_ttl",
177 _ => "",
178 };
179
180 if set_clauses.is_empty() {
181 let mut query = format!("insert into {table}{ttl}");
182 let (cols, named_var) = where_clauses.into_iter().unzip::<_, _, Vec<_>, Vec<_>>();
183 let cols = cols.join(", ");
184 let named_var = named_var
185 .into_iter()
186 .map(|var| format!(":{var}"))
187 .collect::<Vec<_>>()
188 .join(", ");
189
190 query.push_str(&format!(" ({cols}) values ({named_var});"));
191
192 query
193 } else {
194 let mut query = format!("update {table}{ttl} set ");
195 let query_set = set_clauses.join(", ");
196 query.push_str(&query_set);
197
198 query.push_str(" where ");
199 let query_where = where_clauses
200 .into_iter()
201 .map(|(col, ident_string)| format!("{col} = :{ident_string}"))
202 .collect::<Vec<_>>()
203 .join(" and ");
204 query.push_str(&query_where);
205
206 query.push(';');
207
208 query
209 }
210}
211
212#[cfg(test)]
213mod tests {
214 use super::*;
215
216 fn get_set_clauses() -> Vec<String> {
217 vec![
218 "name = :name",
219 "email = :email",
220 "\"createdAt\" = :created_at",
221 ]
222 .into_iter()
223 .map(|x| x.to_string())
224 .collect::<Vec<_>>()
225 }
226
227 fn get_where_clauses() -> Vec<(String, String)> {
228 vec![("id", "id"), (r#""orgId""#, "org_id")]
229 .into_iter()
230 .map(|(x, y)| (x.to_string(), y.to_string()))
231 .collect::<Vec<_>>()
232 }
233
234 #[test]
235 fn test_update() {
236 let query = build_query(
237 &UpsertQueryOptions {
238 name: syn::parse_str::<syn::Ident>("UpdatePerson").unwrap(),
239 table: "person".to_string(),
240 ttl: None,
241 },
242 &"person".to_string(),
243 get_set_clauses(),
244 get_where_clauses(),
245 );
246
247 assert_eq!(
248 query,
249 "update person set name = :name, email = :email, \"createdAt\" = :created_at where id = :id and \"orgId\" = :org_id;",
250 );
251 }
252
253 #[test]
254 fn test_update_ttl() {
255 let query = build_query(
256 &UpsertQueryOptions {
257 name: syn::parse_str::<syn::Ident>("UpdatePerson").unwrap(),
258 table: "person".to_string(),
259 ttl: Some(true),
260 },
261 &"person".to_string(),
262 get_set_clauses(),
263 get_where_clauses(),
264 );
265
266 assert_eq!(
267 query,
268 "update person using ttl :set_ttl set name = :name, email = :email, \"createdAt\" = :created_at where id = :id and \"orgId\" = :org_id;",
269 );
270 }
271
272 #[test]
273 fn test_insert() {
274 let query = build_query(
275 &UpsertQueryOptions {
276 name: syn::parse_str::<syn::Ident>("UpdatePerson").unwrap(),
277 table: "person".to_string(),
278 ttl: Default::default(),
279 },
280 &"person".to_string(),
281 vec![],
282 get_where_clauses(),
283 );
284
285 assert_eq!(
286 query,
287 "insert into person (id, \"orgId\") values (:id, :org_id);",
288 );
289 }
290
291 #[test]
292 fn test_insert_ttl() {
293 let query = build_query(
294 &UpsertQueryOptions {
295 name: syn::parse_str::<syn::Ident>("UpdatePerson").unwrap(),
296 table: "person".to_string(),
297 ttl: Some(true),
298 },
299 &"person".to_string(),
300 vec![],
301 get_where_clauses(),
302 );
303
304 assert_eq!(
305 query,
306 "insert into person using ttl :set_ttl (id, \"orgId\") values (:id, :org_id);",
307 );
308 }
309}