1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Data, DeriveInput, Fields, Type};
4
5#[proc_macro_derive(FromRow, attributes(sentinel))]
26pub fn derive_from_row(input: TokenStream) -> TokenStream {
27 let input = parse_macro_input!(input as DeriveInput);
28 match impl_from_row(&input) {
29 Ok(tokens) => tokens.into(),
30 Err(err) => err.to_compile_error().into(),
31 }
32}
33
34fn impl_from_row(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
35 let name = &input.ident;
36 let generics = &input.generics;
37 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
38
39 let fields = match &input.data {
40 Data::Struct(data) => match &data.fields {
41 Fields::Named(fields) => &fields.named,
42 _ => {
43 return Err(syn::Error::new_spanned(
44 input,
45 "FromRow can only be derived for structs with named fields",
46 ))
47 }
48 },
49 _ => {
50 return Err(syn::Error::new_spanned(
51 input,
52 "FromRow can only be derived for structs",
53 ))
54 }
55 };
56
57 let rename_all = get_struct_rename_all(input);
58
59 let field_extractions = fields.iter().map(|f| {
60 let field_name = f.ident.as_ref().unwrap();
61 let field_ty = &f.ty;
62 let column_name = field_name.to_string();
63
64 let attrs = parse_field_attrs(f).unwrap();
65
66 if attrs.skip {
68 return quote! {
69 #field_name: ::std::default::Default::default()
70 };
71 }
72
73 if attrs.flatten {
75 return quote! {
76 #field_name: #field_ty::from_row(row)?
77 };
78 }
79
80 let col = attrs.rename.unwrap_or_else(|| {
82 if let Some(ref strategy) = rename_all {
83 apply_rename_all(&column_name, strategy)
84 } else {
85 column_name
86 }
87 });
88
89 if attrs.json {
91 return quote! {
92 #field_name: {
93 let json_str: String = row.try_get_by_name(#col)?;
94 serde_json::from_str(&json_str)
95 .map_err(|e| sentinel_driver::Error::Decode(format!("json: {}", e)))?
96 }
97 };
98 }
99
100 if let Some(ref source_ty) = attrs.from {
102 return quote! {
103 #field_name: {
104 let v: #source_ty = row.try_get_by_name(#col)?;
105 <#field_ty as ::std::convert::From<#source_ty>>::from(v)
106 }
107 };
108 }
109
110 if let Some(ref source_ty) = attrs.try_from {
112 if attrs.default {
113 return quote! {
114 #field_name: match row.try_get_by_name::<#source_ty>(#col) {
115 Ok(v) => <#field_ty as ::std::convert::TryFrom<#source_ty>>::try_from(v)
116 .map_err(|e| sentinel_driver::Error::Decode(format!("{}", e)))?,
117 Err(sentinel_driver::Error::ColumnNotFound(_)) => ::std::default::Default::default(),
118 Err(e) => return Err(e),
119 }
120 };
121 }
122 return quote! {
123 #field_name: {
124 let v = row.try_get_by_name::<#source_ty>(#col)?;
125 <#field_ty as ::std::convert::TryFrom<#source_ty>>::try_from(v)
126 .map_err(|e| sentinel_driver::Error::Decode(format!("{}", e)))?
127 }
128 };
129 }
130
131 if attrs.default {
133 return quote! {
134 #field_name: match row.try_get_by_name(#col) {
135 Ok(v) => v,
136 Err(sentinel_driver::Error::ColumnNotFound(_)) => ::std::default::Default::default(),
137 Err(e) => return Err(e),
138 }
139 };
140 }
141
142 quote! {
144 #field_name: row.try_get_by_name(#col)?
145 }
146 });
147
148 Ok(quote! {
149 impl #impl_generics #name #ty_generics #where_clause {
150 pub fn from_row(row: &sentinel_driver::Row) -> sentinel_driver::Result<Self> {
152 Ok(Self {
153 #(#field_extractions,)*
154 })
155 }
156 }
157 })
158}
159
160#[proc_macro_derive(ToSql, attributes(sentinel))]
173pub fn derive_to_sql(input: TokenStream) -> TokenStream {
174 let input = parse_macro_input!(input as DeriveInput);
175 match impl_to_sql(&input) {
176 Ok(tokens) => tokens.into(),
177 Err(err) => err.to_compile_error().into(),
178 }
179}
180
181fn impl_to_sql(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
182 let name = &input.ident;
183 let generics = &input.generics;
184 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
185
186 match &input.data {
187 Data::Enum(data) => impl_to_sql_enum(name, generics, data, input),
188 Data::Struct(data) => match &data.fields {
189 Fields::Named(fields) if get_type_name(input).is_some() => {
190 impl_to_sql_composite(name, generics, fields, input)
191 }
192 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => Ok(quote! {
193 impl #impl_generics sentinel_driver::ToSql for #name #ty_generics #where_clause {
194 fn oid(&self) -> sentinel_driver::Oid {
195 self.0.oid()
196 }
197
198 fn to_sql(&self, buf: &mut bytes::BytesMut) -> sentinel_driver::Result<()> {
199 self.0.to_sql(buf)
200 }
201 }
202 }),
203 _ => Err(syn::Error::new_spanned(
204 input,
205 "ToSql requires a newtype struct or a named struct with #[sentinel(type_name = \"...\")]",
206 )),
207 },
208 _ => Err(syn::Error::new_spanned(
209 input,
210 "ToSql can only be derived for structs or enums",
211 )),
212 }
213}
214
215fn impl_to_sql_enum(
216 name: &syn::Ident,
217 generics: &syn::Generics,
218 data: &syn::DataEnum,
219 input: &DeriveInput,
220) -> syn::Result<proc_macro2::TokenStream> {
221 if let Some(repr_ty) = get_repr_type(input) {
223 return impl_to_sql_enum_repr(name, generics, data, &repr_ty);
224 }
225
226 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
227 let rename_all = get_struct_rename_all(input);
228
229 let match_arms = data.variants.iter().map(|v| {
230 let variant_name = &v.ident;
231 let label = get_variant_rename(v)
232 .or_else(|| {
233 rename_all
234 .as_ref()
235 .map(|s| apply_rename_all(&variant_name.to_string(), s))
236 })
237 .unwrap_or_else(|| variant_name.to_string());
238
239 quote! {
240 #name::#variant_name => {
241 buf.put_slice(#label.as_bytes());
242 Ok(())
243 }
244 }
245 });
246
247 Ok(quote! {
248 impl #impl_generics sentinel_driver::ToSql for #name #ty_generics #where_clause {
249 fn oid(&self) -> sentinel_driver::Oid {
250 sentinel_driver::Oid::TEXT
251 }
252
253 fn to_sql(&self, buf: &mut bytes::BytesMut) -> sentinel_driver::Result<()> {
254 use bytes::BufMut;
255 match self {
256 #(#match_arms)*
257 }
258 }
259 }
260 })
261}
262
263#[proc_macro_derive(FromSql, attributes(sentinel))]
276pub fn derive_from_sql(input: TokenStream) -> TokenStream {
277 let input = parse_macro_input!(input as DeriveInput);
278 match impl_from_sql(&input) {
279 Ok(tokens) => tokens.into(),
280 Err(err) => err.to_compile_error().into(),
281 }
282}
283
284fn impl_from_sql(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
285 let name = &input.ident;
286 let generics = &input.generics;
287 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
288
289 match &input.data {
290 Data::Enum(data) => impl_from_sql_enum(name, generics, data, input),
291 Data::Struct(data) => match &data.fields {
292 Fields::Named(fields) if get_type_name(input).is_some() => {
293 impl_from_sql_composite(name, generics, fields, input)
294 }
295 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
296 let inner_ty = fields.unnamed.first().unwrap().ty.clone();
297 Ok(quote! {
298 impl #impl_generics sentinel_driver::FromSql for #name #ty_generics #where_clause {
299 fn oid() -> sentinel_driver::Oid {
300 <#inner_ty as sentinel_driver::FromSql>::oid()
301 }
302
303 fn from_sql(buf: &[u8]) -> sentinel_driver::Result<Self> {
304 <#inner_ty as sentinel_driver::FromSql>::from_sql(buf).map(Self)
305 }
306 }
307 })
308 }
309 _ => Err(syn::Error::new_spanned(
310 input,
311 "FromSql requires a newtype struct or a named struct with #[sentinel(type_name = \"...\")]",
312 )),
313 },
314 _ => Err(syn::Error::new_spanned(
315 input,
316 "FromSql can only be derived for structs or enums",
317 )),
318 }
319}
320
321fn impl_from_sql_enum(
322 name: &syn::Ident,
323 generics: &syn::Generics,
324 data: &syn::DataEnum,
325 input: &DeriveInput,
326) -> syn::Result<proc_macro2::TokenStream> {
327 if let Some(repr_ty) = get_repr_type(input) {
329 return impl_from_sql_enum_repr(name, generics, data, &repr_ty);
330 }
331
332 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
333 let rename_all = get_struct_rename_all(input);
334
335 let match_arms = data.variants.iter().map(|v| {
336 let variant_name = &v.ident;
337 let label = get_variant_rename(v)
338 .or_else(|| {
339 rename_all
340 .as_ref()
341 .map(|s| apply_rename_all(&variant_name.to_string(), s))
342 })
343 .unwrap_or_else(|| variant_name.to_string());
344
345 quote! {
346 #label => Ok(#name::#variant_name),
347 }
348 });
349
350 let type_name_str = name.to_string();
351 let allow_mismatch = has_allow_mismatch(input);
352
353 let fallback = if allow_mismatch {
354 let first_variant = &data.variants.first().unwrap().ident;
355 quote! { _ => Ok(#name::#first_variant), }
356 } else {
357 quote! {
358 other => Err(sentinel_driver::Error::Decode(
359 format!("unknown {} variant: '{}'", #type_name_str, other)
360 )),
361 }
362 };
363
364 Ok(quote! {
365 impl #impl_generics sentinel_driver::FromSql for #name #ty_generics #where_clause {
366 fn oid() -> sentinel_driver::Oid {
367 sentinel_driver::Oid::TEXT
368 }
369
370 fn from_sql(buf: &[u8]) -> sentinel_driver::Result<Self> {
371 let s = ::std::str::from_utf8(buf)
372 .map_err(|e| sentinel_driver::Error::Decode(
373 format!("enum: invalid UTF-8: {}", e)
374 ))?;
375 match s {
376 #(#match_arms)*
377 #fallback
378 }
379 }
380 }
381 })
382}
383
384fn impl_to_sql_composite(
387 name: &syn::Ident,
388 generics: &syn::Generics,
389 fields: &syn::FieldsNamed,
390 _input: &DeriveInput,
391) -> syn::Result<proc_macro2::TokenStream> {
392 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
393 let field_count = fields.named.len() as i32;
394
395 let encode_fields = fields.named.iter().map(|f| {
396 let field_name = f.ident.as_ref().unwrap();
397 quote! {
398 {
399 let oid = sentinel_driver::ToSql::oid(&self.#field_name);
400 buf.put_u32(oid.0);
401 let len_pos = buf.len();
402 buf.put_i32(0); let data_start = buf.len();
404 sentinel_driver::ToSql::to_sql(&self.#field_name, buf)?;
405 let data_len = (buf.len() - data_start) as i32;
406 buf[len_pos..len_pos + 4].copy_from_slice(&data_len.to_be_bytes());
407 }
408 }
409 });
410
411 Ok(quote! {
412 impl #impl_generics sentinel_driver::ToSql for #name #ty_generics #where_clause {
413 fn oid(&self) -> sentinel_driver::Oid {
414 sentinel_driver::Oid::TEXT
415 }
416
417 fn to_sql(&self, buf: &mut bytes::BytesMut) -> sentinel_driver::Result<()> {
418 use bytes::BufMut;
419 buf.put_i32(#field_count);
420 #(#encode_fields)*
421 Ok(())
422 }
423 }
424 })
425}
426
427fn impl_from_sql_composite(
428 name: &syn::Ident,
429 generics: &syn::Generics,
430 fields: &syn::FieldsNamed,
431 _input: &DeriveInput,
432) -> syn::Result<proc_macro2::TokenStream> {
433 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
434
435 let decode_fields = fields.named.iter().enumerate().map(|(i, f)| {
436 let field_name = f.ident.as_ref().unwrap();
437 let field_ty = &f.ty;
438 let idx = i;
439
440 quote! {
441 #field_name: {
442 if offset + 8 > buf.len() {
443 return Err(sentinel_driver::Error::Decode(
444 format!("composite: field {} truncated at offset {}", #idx, offset)
445 ));
446 }
447 let _field_oid = u32::from_be_bytes([
448 buf[offset], buf[offset + 1], buf[offset + 2], buf[offset + 3],
449 ]);
450 offset += 4;
451 let field_len = i32::from_be_bytes([
452 buf[offset], buf[offset + 1], buf[offset + 2], buf[offset + 3],
453 ]);
454 offset += 4;
455 if field_len < 0 {
456 return Err(sentinel_driver::Error::Decode(
457 format!("composite: NULL not supported for field {}", #idx)
458 ));
459 }
460 let field_len = field_len as usize;
461 if offset + field_len > buf.len() {
462 return Err(sentinel_driver::Error::Decode(
463 format!("composite: field {} data truncated", #idx)
464 ));
465 }
466 let val = <#field_ty as sentinel_driver::FromSql>::from_sql(
467 &buf[offset..offset + field_len],
468 )?;
469 offset += field_len;
470 val
471 }
472 }
473 });
474
475 Ok(quote! {
476 impl #impl_generics sentinel_driver::FromSql for #name #ty_generics #where_clause {
477 fn oid() -> sentinel_driver::Oid {
478 sentinel_driver::Oid::TEXT
479 }
480
481 fn from_sql(buf: &[u8]) -> sentinel_driver::Result<Self> {
482 if buf.len() < 4 {
483 return Err(sentinel_driver::Error::Decode(
484 "composite: too short".into(),
485 ));
486 }
487 let _field_count = i32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]);
488 let mut offset = 4;
489
490 Ok(Self {
491 #(#decode_fields,)*
492 })
493 }
494 }
495 })
496}
497
498fn impl_to_sql_enum_repr(
501 name: &syn::Ident,
502 generics: &syn::Generics,
503 _data: &syn::DataEnum,
504 repr_ty: &syn::Ident,
505) -> syn::Result<proc_macro2::TokenStream> {
506 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
507
508 let oid_const = match repr_ty.to_string().as_str() {
509 "i8" | "u8" => quote! { sentinel_driver::Oid::CHAR },
510 "i16" | "u16" => quote! { sentinel_driver::Oid::INT2 },
511 "i32" | "u32" => quote! { sentinel_driver::Oid::INT4 },
512 "i64" | "u64" => quote! { sentinel_driver::Oid::INT8 },
513 _ => quote! { sentinel_driver::Oid::INT4 },
514 };
515
516 Ok(quote! {
517 impl #impl_generics sentinel_driver::ToSql for #name #ty_generics #where_clause {
518 fn oid(&self) -> sentinel_driver::Oid {
519 #oid_const
520 }
521
522 fn to_sql(&self, buf: &mut bytes::BytesMut) -> sentinel_driver::Result<()> {
523 (*self as #repr_ty).to_sql(buf)
524 }
525 }
526 })
527}
528
529fn impl_from_sql_enum_repr(
530 name: &syn::Ident,
531 generics: &syn::Generics,
532 data: &syn::DataEnum,
533 repr_ty: &syn::Ident,
534) -> syn::Result<proc_macro2::TokenStream> {
535 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
536
537 let oid_const = match repr_ty.to_string().as_str() {
538 "i8" | "u8" => quote! { sentinel_driver::Oid::CHAR },
539 "i16" | "u16" => quote! { sentinel_driver::Oid::INT2 },
540 "i32" | "u32" => quote! { sentinel_driver::Oid::INT4 },
541 "i64" | "u64" => quote! { sentinel_driver::Oid::INT8 },
542 _ => quote! { sentinel_driver::Oid::INT4 },
543 };
544
545 let match_arms = data.variants.iter().map(|v| {
546 let variant_name = &v.ident;
547 quote! {
548 x if x == #name::#variant_name as #repr_ty => Ok(#name::#variant_name),
549 }
550 });
551
552 let type_name_str = name.to_string();
553
554 Ok(quote! {
555 impl #impl_generics sentinel_driver::FromSql for #name #ty_generics #where_clause {
556 fn oid() -> sentinel_driver::Oid {
557 #oid_const
558 }
559
560 fn from_sql(buf: &[u8]) -> sentinel_driver::Result<Self> {
561 let val = <#repr_ty as sentinel_driver::FromSql>::from_sql(buf)?;
562 match val {
563 #(#match_arms)*
564 other => Err(sentinel_driver::Error::Decode(
565 format!("unknown {} discriminant: {}", #type_name_str, other)
566 )),
567 }
568 }
569 }
570 })
571}
572
573fn get_repr_type(input: &DeriveInput) -> Option<syn::Ident> {
577 for attr in &input.attrs {
578 if attr.path().is_ident("repr") {
579 let ty: syn::Result<syn::Ident> = attr.parse_args();
580 if let Ok(ident) = ty {
581 let s = ident.to_string();
582 if matches!(
583 s.as_str(),
584 "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64"
585 ) {
586 return Some(ident);
587 }
588 }
589 }
590 }
591 None
592}
593
/// Container-level options read from `#[sentinel(...)]` on a struct or enum.
struct StructAttrs {
    /// `type_name = "..."`: marks a named struct as a composite for ToSql/FromSql.
    type_name: Option<String>,
    /// `rename_all = "..."`: case strategy name handed to `apply_rename_all`.
    rename_all: Option<String>,
    /// `allow_mismatch`: text enums decode unknown labels as their first variant.
    allow_mismatch: bool,
}
600
601fn parse_struct_attrs(input: &DeriveInput) -> StructAttrs {
602 let mut attrs = StructAttrs {
603 rename_all: None,
604 type_name: None,
605 allow_mismatch: false,
606 };
607
608 for attr in &input.attrs {
609 if !attr.path().is_ident("sentinel") {
610 continue;
611 }
612
613 let _ = attr.parse_nested_meta(|meta| {
614 if meta.path.is_ident("rename_all") {
615 let value = meta.value()?;
616 let s: syn::LitStr = value.parse()?;
617 attrs.rename_all = Some(s.value());
618 } else if meta.path.is_ident("type_name") {
619 let value = meta.value()?;
620 let s: syn::LitStr = value.parse()?;
621 attrs.type_name = Some(s.value());
622 } else if meta.path.is_ident("allow_mismatch") {
623 attrs.allow_mismatch = true;
624 }
625 Ok(())
626 });
627 }
628
629 attrs
630}
631
/// Re-spells a Rust identifier according to a `rename_all` strategy.
///
/// Strategies: `lowercase`, `UPPERCASE`, `camelCase`, `PascalCase`, `snake_case`,
/// `SCREAMING_SNAKE_CASE`, `kebab-case`. An unrecognised strategy returns `name`
/// unchanged. For the separator-based cases, every uppercase char after the first
/// starts a new word, so acronyms split per letter (`HTTPStatus` → `h_t_t_p_status`).
fn apply_rename_all(name: &str, strategy: &str) -> String {
    /// Drops each `_` and upper-cases the char after it. The first char is upper-cased
    /// when `upper_first`, otherwise lower-cased; all other chars are copied as-is.
    fn capitalize_words(name: &str, upper_first: bool) -> String {
        let mut out = String::with_capacity(name.len());
        let mut after_underscore = false;
        for (pos, ch) in name.chars().enumerate() {
            if ch == '_' {
                after_underscore = true;
            } else if after_underscore || (pos == 0 && upper_first) {
                out.extend(ch.to_uppercase());
                after_underscore = false;
            } else if pos == 0 {
                out.extend(ch.to_lowercase());
            } else {
                out.push(ch);
            }
        }
        out
    }

    /// Emits `sep` for each `_` and before every non-leading uppercase char, and
    /// re-cases every other char to upper case (`upper`) or lower case.
    fn separate_words(name: &str, sep: char, upper: bool) -> String {
        let mut out = String::with_capacity(name.len());
        for (pos, ch) in name.chars().enumerate() {
            if ch == '_' {
                out.push(sep);
                continue;
            }
            if pos > 0 && ch.is_uppercase() {
                out.push(sep);
            }
            if upper {
                out.extend(ch.to_uppercase());
            } else {
                out.extend(ch.to_lowercase());
            }
        }
        out
    }

    match strategy {
        "lowercase" => name.to_lowercase(),
        "UPPERCASE" => name.to_uppercase(),
        "camelCase" => capitalize_words(name, false),
        "PascalCase" => capitalize_words(name, true),
        "snake_case" => separate_words(name, '_', false),
        "SCREAMING_SNAKE_CASE" => separate_words(name, '_', true),
        "kebab-case" => separate_words(name, '-', false),
        _ => name.to_owned(),
    }
}
706
707fn get_struct_rename_all(input: &DeriveInput) -> Option<String> {
709 parse_struct_attrs(input).rename_all
710}
711
712fn has_allow_mismatch(input: &DeriveInput) -> bool {
714 parse_struct_attrs(input).allow_mismatch
715}
716
717fn get_type_name(input: &DeriveInput) -> Option<String> {
719 parse_struct_attrs(input).type_name
720}
721
722fn get_variant_rename(variant: &syn::Variant) -> Option<String> {
724 for attr in &variant.attrs {
725 if !attr.path().is_ident("sentinel") {
726 continue;
727 }
728 let result: syn::Result<String> = attr.parse_args_with(|input: syn::parse::ParseStream| {
729 let ident: syn::Ident = input.parse()?;
730 if ident != "rename" {
731 return Err(syn::Error::new_spanned(&ident, "expected `rename`"));
732 }
733 let _: syn::Token![=] = input.parse()?;
734 let lit: syn::LitStr = input.parse()?;
735 Ok(lit.value())
736 });
737 if let Ok(name) = result {
738 return Some(name);
739 }
740 }
741 None
742}
743
744struct FieldAttrs {
746 rename: Option<String>,
747 skip: bool,
748 default: bool,
749 try_from: Option<Type>,
750 from: Option<Type>,
751 flatten: bool,
752 json: bool,
753}
754
755fn parse_field_attrs(field: &syn::Field) -> syn::Result<FieldAttrs> {
756 let mut attrs = FieldAttrs {
757 rename: None,
758 skip: false,
759 default: false,
760 try_from: None,
761 from: None,
762 flatten: false,
763 json: false,
764 };
765
766 for attr in &field.attrs {
767 if !attr.path().is_ident("sentinel") {
768 continue;
769 }
770
771 attr.parse_nested_meta(|meta| {
772 if meta.path.is_ident("rename") {
773 let value = meta.value()?;
774 let s: syn::LitStr = value.parse()?;
775 attrs.rename = Some(s.value());
776 } else if meta.path.is_ident("skip") {
777 attrs.skip = true;
778 } else if meta.path.is_ident("default") {
779 attrs.default = true;
780 } else if meta.path.is_ident("try_from") {
781 let value = meta.value()?;
782 let s: syn::LitStr = value.parse()?;
783 attrs.try_from = Some(syn::parse_str(&s.value())?);
784 } else if meta.path.is_ident("from") {
785 let value = meta.value()?;
786 let s: syn::LitStr = value.parse()?;
787 attrs.from = Some(syn::parse_str(&s.value())?);
788 } else if meta.path.is_ident("flatten") {
789 attrs.flatten = true;
790 } else if meta.path.is_ident("json") {
791 attrs.json = true;
792 } else {
793 return Err(meta.error("unknown sentinel attribute"));
794 }
795 Ok(())
796 })?;
797 }
798
799 Ok(attrs)
800}
801
/// Unit tests for the pure helpers; the proc-macro entry points themselves can only be
/// exercised from a dependent crate.
#[cfg(test)]
mod tests {
    use super::apply_rename_all;

    // Replaces the former `assert!(true)` placeholder, which tested nothing and trips
    // clippy's `assertions_on_constants`.
    #[test]
    fn rename_all_case_strategies() {
        assert_eq!(apply_rename_all("InProgress", "lowercase"), "inprogress");
        assert_eq!(apply_rename_all("user_id", "UPPERCASE"), "USER_ID");
        assert_eq!(apply_rename_all("user_id", "camelCase"), "userId");
        assert_eq!(apply_rename_all("InProgress", "camelCase"), "inProgress");
        assert_eq!(apply_rename_all("user_id", "PascalCase"), "UserId");
        assert_eq!(apply_rename_all("InProgress", "snake_case"), "in_progress");
        assert_eq!(
            apply_rename_all("InProgress", "SCREAMING_SNAKE_CASE"),
            "IN_PROGRESS"
        );
        assert_eq!(apply_rename_all("user_id", "kebab-case"), "user-id");
        assert_eq!(apply_rename_all("InProgress", "kebab-case"), "in-progress");
    }

    #[test]
    fn rename_all_splits_acronyms_per_letter() {
        assert_eq!(apply_rename_all("HTTPStatus", "snake_case"), "h_t_t_p_status");
    }

    #[test]
    fn rename_all_unknown_strategy_is_identity() {
        assert_eq!(apply_rename_all("InProgress", "Train-Case"), "InProgress");
        assert_eq!(apply_rename_all("", "snake_case"), "");
    }
}