1use proc_macro2::TokenStream;
2use quote::{ToTokens, quote};
3use syn::{Data, DeriveInput, ExprAssign, ExprLit, parse::Parse, punctuated::Punctuated};
4
5#[cfg(test)]
6mod tests;
7
8mod surrql;
9pub use surrql::surrql_macro_impl;
10
11pub fn table_macro_impl(input: TokenStream) -> syn::Result<TokenStream> {
17 let input = syn::parse2::<DeriveInput>(input)?;
18
19 let struct_name = &input.ident;
20
21 let table_name = parse_table_name(&input)?;
22
23 let struct_fields = parse_struct_fields(&input)?;
24
25 let (table_field_queries, index_queries) = parse_attributes(struct_fields, &table_name)?;
26
27 let table_query = format!("DEFINE TABLE IF NOT EXISTS {table_name} SCHEMAFULL;");
28
29 let migration_up_query = format!(
30 "{table_query}\n{}",
31 table_field_queries
32 .into_iter()
33 .chain(index_queries)
34 .collect::<Vec<_>>()
35 .join("\n")
36 );
37
38 let expanded = quote! {
40 impl ::surrealqlx::traits::Table for #struct_name {
42 const TABLE_NAME: &'static str = #table_name;
43
44 fn migrations() -> Vec<::surrealqlx::migrations::M<'static>> {
45 vec![
46 ::surrealqlx::migrations::M::up(
47 #migration_up_query
48 )
49 .comment("Initial version"),
50 ]
51 }
52 }
53 };
54
55 Ok(expanded)
57}
58
59fn parse_table_name(input: &DeriveInput) -> syn::Result<String> {
60 let table_name = input
61 .attrs
62 .iter()
63 .find(|attr| attr.path().is_ident("Table"))
64 .ok_or_else(|| {
65 syn::Error::new_spanned(input, "Table attribute must be specified for the struct")
66 })
67 .and_then(|attr| attr.parse_args::<syn::LitStr>().map(|lit| lit.value()))?;
68 Ok(table_name)
69}
70
71fn parse_struct_fields(input: &DeriveInput) -> syn::Result<impl Iterator<Item = &syn::Field>> {
73 match input.data {
74 Data::Struct(ref data) => match data.fields {
75 syn::Fields::Named(ref fields) => {
76 let mut fields = fields.named.iter().peekable();
77 if fields.peek().is_none() {
78 return Err(syn::Error::new_spanned(
79 input,
80 "Struct must have at least one field",
81 ));
82 }
83 Ok(fields)
84 }
85 _ => Err(syn::Error::new_spanned(
86 input,
87 "Tuple structs not supported",
88 )),
89 },
90 _ => Err(syn::Error::new_spanned(input, "Only structs are supported")),
91 }
92}
93
94fn parse_attributes<'a>(
96 fields: impl Iterator<Item = &'a syn::Field>,
97 table_name: &str,
98) -> syn::Result<(Vec<String>, Vec<String>)> {
99 let mut table_field_queries = Vec::new();
100
101 let mut index_queries = Vec::new();
102
103 for field in fields {
104 let Some(field_name) = field.ident.as_ref() else {
105 return Err(syn::Error::new_spanned(
106 field,
107 "Field must have a name, tuple structs not allowed",
108 ));
109 };
110 let mut field_attrs = field
111 .attrs
112 .iter()
113 .filter(|attr| attr.path().is_ident("field"))
114 .map(|attr| {
115 let parsed = attr.parse_args::<FieldAnnotation>();
116 match parsed {
117 Ok(parsed) => Ok((attr, parsed)),
118 Err(err) => Err(err),
119 }
120 })
121 .peekable();
122
123 let extra = match field_attrs.next() {
127 Some(Ok((_, FieldAnnotation::Skip))) => {
128 continue;
129 }
130 Some(Ok((_, FieldAnnotation::Plain))) => String::new(),
131 Some(Ok((_, FieldAnnotation::Typed { type_ }))) => format!(" TYPE {}", type_.value()),
132 Some(Ok((_, FieldAnnotation::CustomQuery { query }))) => {
133 format!(" {}", query.value())
134 }
135 Some(Err(err)) => {
136 return Err(err);
137 }
138 None => {
139 return Err(syn::Error::new_spanned(
140 field,
141 "Field must have a #[field] attribute",
142 ));
143 }
144 };
145 if field_attrs.peek().is_some() {
147 return Err(syn::Error::new_spanned(
148 field,
149 "Field can have only one #[field] attribute",
150 ));
151 }
152
153 table_field_queries.push(format!(
154 "DEFINE FIELD IF NOT EXISTS {field_name} ON {table_name}{extra};",
155 ));
156
157 let index_attrs = field
159 .attrs
160 .iter()
161 .filter(|attr| attr.path().is_ident("index"))
162 .map(|attr| {
163 let parsed = attr.parse_args::<IndexAnnotation>();
164 match parsed {
165 Ok(parsed) => Ok(parsed),
166 Err(err) => Err(err),
167 }
168 })
169 .collect::<Result<Vec<_>, _>>()?;
170
171 for index in index_attrs {
172 for query in index.to_query_strings(table_name, &field_name.to_string()) {
173 index_queries.push(query);
174 }
175 }
176 }
177
178 Ok((table_field_queries, index_queries))
179}
180
181enum FieldAnnotation {
182 Skip,
183 Plain,
184 Typed { type_: syn::LitStr },
185 CustomQuery { query: syn::LitStr },
186}
187
188impl Parse for FieldAnnotation {
194 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
195 let args: Punctuated<syn::Expr, syn::token::Comma> =
196 input.parse_terminated(syn::Expr::parse, syn::token::Comma)?;
197
198 if args.is_empty() {
199 return Ok(Self::Plain);
200 }
201
202 if args.len() > 1 {
203 return Err(syn::Error::new_spanned(
204 args,
205 "Field attribute can have at most one argument",
206 ));
207 }
208
209 match args.first() {
210 None => Ok(Self::Plain),
211 Some(syn::Expr::Path(path)) if path.to_token_stream().to_string().eq("skip") => {
212 Ok(Self::Skip)
213 }
214 Some(syn::Expr::Lit(ExprLit {
215 lit: syn::Lit::Str(strlit),
216 ..
217 })) => Ok(Self::CustomQuery {
218 query: strlit.clone(),
219 }),
220 Some(syn::Expr::Assign(ExprAssign { left, right, .. })) => {
221 if left.to_token_stream().to_string().eq("dt") {
222 match *right.to_owned() {
223 syn::Expr::Lit(ExprLit {
224 lit: syn::Lit::Str(strlit),
225 ..
226 }) => Ok(Self::Typed { type_: strlit }),
227 _ => Err(syn::Error::new_spanned(
228 right,
229 "The `dt` attribute expects a string literal",
230 )),
231 }
232 } else {
233 Err(syn::Error::new_spanned(
234 left,
235 "Unknown field attribute, expected `dt`",
236 ))
237 }
238 }
239 Some(expr) => Err(syn::Error::new_spanned(
240 expr,
241 "Unsupported expression syntax, expected `skip`, `dt = \"type\"`, or a string literal representing a custom query",
242 )),
243 }
244 }
245}
246
247#[derive(Default, Debug, Clone)]
248struct IndexAnnotation {
249 indexes: Vec<IndexAnnotationInner>,
250}
251
252#[derive(Debug, Clone)]
253enum IndexAnnotationInner {
254 Compound(CompoundIndexAnnotation),
255 Single(IndexKind),
256}
257
258impl Parse for IndexAnnotation {
259 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
281 let args: Punctuated<syn::Expr, syn::token::Comma> =
283 input.parse_terminated(syn::Expr::parse, syn::token::Comma)?;
284
285 if args.is_empty() {
286 return Ok(Self {
287 indexes: vec![IndexAnnotationInner::Single(IndexKind::Normal)],
288 });
289 }
290
291 let mut indexes = Vec::new();
292 for arg in &args {
293 match arg {
294 syn::Expr::Call(call) if call.func.to_token_stream().to_string().eq("compound") => {
295 indexes.push(IndexAnnotationInner::Compound(
297 CompoundIndexAnnotation::parse(&call.args)?,
298 ));
299 }
300 _ => {
301 let index_type = IndexKind::parse(Some(arg))?;
303 indexes.push(IndexAnnotationInner::Single(index_type));
304 }
305 }
306 }
307
308 Ok(Self { indexes })
309 }
310}
311
312impl IndexAnnotation {
313 fn to_query_strings(&self, table_name: &str, field_name: &str) -> Vec<String> {
315 let mut output = Vec::new();
316 for index in &self.indexes {
317 let (compound, index_type) = match index {
318 IndexAnnotationInner::Compound(compound_index_annotation) => (
319 Some(&compound_index_annotation.fields),
320 &compound_index_annotation.index,
321 ),
322 IndexAnnotationInner::Single(index_kind) => (None, index_kind),
323 };
324
325 let (extra, index_type) = match index_type {
326 IndexKind::Vector(vector) => (format!(" MTREE DIMENSION {}", vector.dim), "vector"),
327 IndexKind::Text(text) => {
328 (format!(" SEARCH ANALYZER {} BM25", text.analyzer), "text")
329 }
330 IndexKind::Normal => (String::new(), "normal"),
331 IndexKind::Unique => (String::from(" UNIQUE"), "unique"),
332 };
333 let compound_fields = |sep: &str| match compound {
334 Some(compound) if !compound.is_empty() => {
335 format!("{sep}{}", compound.join(sep))
336 }
337 _ => String::new(),
338 };
339
340 let index_name = format!(
341 "{table_name}_{field_name}{extra_fields}_{index_type}_index",
342 extra_fields = compound_fields("_")
343 );
344
345 let query = format!(
346 "DEFINE INDEX IF NOT EXISTS {index_name} ON {table_name} FIELDS {field_name}{extra_fields}{extra};",
347 extra_fields = compound_fields(",")
348 );
349
350 output.push(query);
351 }
352
353 output
354 }
355}
356
357#[derive(Default, Debug, Clone)]
358struct CompoundIndexAnnotation {
360 index: IndexKind,
361 fields: Vec<String>,
362}
363
364impl CompoundIndexAnnotation {
365 fn parse(args: &Punctuated<syn::Expr, syn::token::Comma>) -> syn::Result<Self> {
366 let mut fields = Vec::new();
367
368 let mut args_iter = args.iter();
369
370 let index = match args_iter.next() {
372 Some(syn::Expr::Lit(ExprLit {
373 lit: syn::Lit::Str(strlit),
374 ..
375 })) => {
376 fields.push(strlit.value());
377 IndexKind::Normal
378 }
379 arg => match IndexKind::parse(arg) {
380 Ok(index_type) => index_type,
381 Err(mut err) => {
382 err.combine(syn::Error::new_spanned(
383 arg,
384 "Compound index attribute expects a valid index type or string literal representing the first field name as the first argument",
385 ));
386 return Err(err);
387 }
388 },
389 };
390
391 for arg in args_iter {
393 match arg {
394 syn::Expr::Lit(ExprLit {
395 lit: syn::Lit::Str(strlit),
396 ..
397 }) => fields.push(strlit.value()),
398 _ => {
399 return Err(syn::Error::new_spanned(
400 arg,
401 "Compound index attribute expects string literals representing the other field names",
402 ));
403 }
404 }
405 }
406
407 if fields.is_empty() {
408 Err(syn::Error::new_spanned(
409 args,
410 "Compound index attribute expects at least one string literal representing the other field names",
411 ))
412 } else {
413 Ok(Self { index, fields })
414 }
415 }
416}
417
418#[derive(Default, Debug, Clone)]
419enum IndexKind {
420 Vector(VectorIndexAnnotation),
421 Text(TextIndexAnnotation),
422 #[default]
423 Normal,
424 Unique,
425}
426
427impl IndexKind {
428 fn parse(arg: Option<&syn::Expr>) -> syn::Result<Self> {
429 match arg {
430 None => Ok(Self::Normal),
431 Some(syn::Expr::Path(path)) if path.to_token_stream().to_string().eq("unique") => {
432 Ok(Self::Unique)
433 }
434 Some(syn::Expr::Call(call)) if call.func.to_token_stream().to_string().eq("vector") => {
435 Ok(Self::Vector(VectorIndexAnnotation::parse(&call.args)?))
436 }
437 Some(syn::Expr::Call(call)) if call.func.to_token_stream().to_string().eq("text") => {
438 Ok(Self::Text(TextIndexAnnotation::parse(&call.args)?))
439 }
440 _ => Err(syn::Error::new_spanned(
441 arg,
442 "Unsupported expression syntax",
443 )),
444 }
445 }
446}
447
/// Parameters of a `vector(...)` index.
#[derive(Clone, Copy, Debug)]
struct VectorIndexAnnotation {
    /// Number of vector dimensions; the parser rejects 0.
    dim: usize,
}
452
453impl VectorIndexAnnotation {
454 fn parse(args: &Punctuated<syn::Expr, syn::token::Comma>) -> syn::Result<Self> {
455 let mut args_iter = args.iter();
456 let arg = args_iter.next();
457 if args_iter.next().is_some() {
458 return Err(syn::Error::new_spanned(
459 args,
460 "Vector index attribute only expects one argument, the dimension of the vector",
461 ));
462 }
463
464 let dim = match arg {
465 Some(syn::Expr::Assign(ExprAssign { left, right, .. }))
466 if left.to_token_stream().to_string().eq("dim") =>
467 {
468 match *right.to_owned() {
469 syn::Expr::Lit(ExprLit {
470 lit: syn::Lit::Int(int),
471 ..
472 }) => int.base10_parse()?,
473 _ => {
474 return Err(syn::Error::new_spanned(
475 right,
476 "`dim` expects an integer literal representing the number of dimensions in the vector",
477 ));
478 }
479 }
480 }
481 Some(syn::Expr::Lit(ExprLit {
482 lit: syn::Lit::Int(int),
483 ..
484 })) => int.base10_parse()?,
485 _ => {
486 return Err(syn::Error::new_spanned(
487 arg,
488 "Unsupported expression syntax",
489 ));
490 }
491 };
492
493 if dim < 1 {
494 return Err(syn::Error::new_spanned(
495 arg,
496 "Vector dimension must be greater than 0",
497 ));
498 }
499
500 Ok(Self { dim })
501 }
502}
503
/// Parameters of a `text(...)` full-text index.
#[derive(Clone, Debug)]
struct TextIndexAnnotation {
    /// Name of the analyzer placed after `SEARCH ANALYZER`.
    analyzer: String,
}
508
509impl TextIndexAnnotation {
510 fn parse(args: &Punctuated<syn::Expr, syn::token::Comma>) -> syn::Result<Self> {
511 let mut args_iter = args.iter();
513 let arg = args_iter.next();
514
515 if args_iter.next().is_some() {
516 return Err(syn::Error::new_spanned(
517 args,
518 "Text index attribute only expects one argument, the analyzer to use",
519 ));
520 }
521
522 let analyzer = match arg {
523 Some(syn::Expr::Lit(ExprLit {
524 lit: syn::Lit::Str(strlit),
525 ..
526 })) => strlit.value(),
527 _ => {
528 return Err(syn::Error::new_spanned(
529 arg,
530 "Text index attribute expects a string literal representing the analyzer to use",
531 ));
532 }
533 };
534
535 Ok(Self { analyzer })
536 }
537}