sea_orm_newtype_id_macros/
lib.rs

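//! ID generation.
//!
//! Provides the `def_id!` procedural macro, which generates a prefixed, string-backed ID
//! newtype together with the sea-orm trait impls it needs (plus serde / async-graphql impls
//! when the corresponding features are enabled).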
4use proc_macro::TokenStream;
5use quote::{quote, ToTokens};
6use syn::{parse::Parse, ExprLit, Ident, Token};
7
8extern crate proc_macro;
9
10struct DefId {
11    struct_name: Ident,
12    prefix: ExprLit,
13}
14
15impl Parse for DefId {
16    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
17        let struct_name: Ident = input.parse()?;
18        let _comma: Token![,] = input.parse()?;
19        let prefix: ExprLit = input.parse()?;
20
21        Ok(DefId {
22            struct_name,
23            prefix,
24        })
25    }
26}
27
28#[proc_macro]
29pub fn def_id(tokens: TokenStream) -> TokenStream {
30    let DefId {
31        struct_name,
32        prefix,
33    } = syn::parse_macro_input!(tokens as DefId);
34
35    let async_graphql_impl =
36        FeatureAsyncGraphQL::new(cfg!(feature = "with-async-graphql"), struct_name.clone());
37    let serde_impl = FeatureSerde::new(cfg!(feature = "with-serde"), struct_name.clone());
38
39    let tokens = quote! {
40      ////////////////////////////////////////////////
41      // Main Struct
42      ////////////////////////////////////////////////
43
44      #[derive(Clone, Debug, Eq, PartialEq, Hash)]
45      pub struct #struct_name(sea_orm_newtype_id::smol_str::SmolStr);
46
47      impl #struct_name {
48        /// Create a new ID
49        pub fn new() -> Self {
50          #struct_name(<Self as sea_orm_newtype_id::PrefixedId>::new_id().into())
51        }
52
53        /// Extracts a string slice containing the entire id.
54        #[inline(always)]
55        pub fn as_str(&self) -> &str {
56          self.0.as_str()
57        }
58      }
59
60      impl Default for #struct_name {
61         fn default() -> Self {
62           Self::new()
63         }
64      }
65
66      impl PartialEq<str> for #struct_name {
67        fn eq(&self, other: &str) -> bool {
68            self.as_str() == other
69        }
70      }
71
72      impl PartialEq<&str> for #struct_name {
73        fn eq(&self, other: &&str) -> bool {
74            self.as_str() == *other
75        }
76      }
77
78      impl PartialEq<String> for #struct_name {
79        fn eq(&self, other: &String) -> bool {
80            self.as_str() == other
81        }
82      }
83
84      impl std::fmt::Display for #struct_name {
85        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86          self.0.fmt(f)
87        }
88      }
89
90     impl std::str::FromStr for #struct_name {
91      type Err = sea_orm_newtype_id::ParseIdError;
92
93      fn from_str(s: &str) -> Result<Self, Self::Err> {
94          if !s.starts_with(#prefix) {
95            // N.B. For debugging
96            eprintln!("bad id is: {} (expected: {:?})", s, #prefix);
97
98            Err(sea_orm_newtype_id::ParseIdError {
99                typename: stringify!(#struct_name),
100                expected: stringify!(id to start with #prefix)
101            })
102          } else {
103            Ok(#struct_name(s.into()))
104          }
105        }
106      }
107
108      impl sea_orm_newtype_id::PrefixedId for #struct_name {
109        const PREFIX: &'static str = #prefix;
110      }
111
112      impl From<#struct_name> for sea_orm::Value {
113        fn from(v: #struct_name) -> Self {
114          sea_orm::Value::String(Some(Box::new(v.as_str().to_string())))
115        }
116      }
117
118      impl sea_orm::sea_query::ValueType for #struct_name {
119        fn try_from(v: sea_orm::Value) -> Result<Self, sea_orm::sea_query::ValueTypeErr> {
120          match v {
121            sea_orm::Value::String(Some(x)) => {
122              let v: String = *x;
123              Ok(Self(v.into()))
124            }
125            _ => Err(sea_orm::sea_query::ValueTypeErr),
126          }
127        }
128
129        fn array_type() -> sea_orm::sea_query::ArrayType {
130          sea_orm::sea_query::ArrayType::String
131        }
132
133        fn type_name() -> String {
134          stringify!($type).to_owned()
135        }
136
137        fn column_type() -> sea_orm::sea_query::ColumnType {
138          sea_orm::sea_query::ColumnType::String(Some(26))
139        }
140      }
141
142      impl sea_orm::TryFromU64 for #struct_name {
143        fn try_from_u64(_n: u64) -> Result<Self, sea_orm::error::DbErr> {
144          Err(sea_orm::error::DbErr::ConvertFromU64(stringify!(#struct_name)))
145        }
146      }
147
148      impl sea_orm::TryGetable for #struct_name {
149        fn try_get(
150          res: &sea_orm::QueryResult,
151          pre: &str,
152          col: &str,
153        ) -> Result<Self, sea_orm::TryGetError> {
154          let val: std::option::Option<String> = res.try_get(pre, col).ok();
155          if let Some(value) = val {
156            Ok(<#struct_name as std::str::FromStr>::from_str(&value).map_err(|e| {
157              sea_orm::TryGetError::DbErr(sea_orm::error::DbErr::Custom(String::from("failed to convert")))
158            })?)
159          } else {
160            Err(sea_orm::TryGetError::Null(col.to_string()))
161          }
162        }
163      }
164
165      impl sea_orm::sea_query::Nullable for #struct_name {
166        fn null() -> sea_orm::Value {
167          sea_orm::Value::String(None)
168        }
169      }
170
171      impl sea_orm::IntoActiveValue<Self> for #struct_name {
172        fn into_active_value(self) -> sea_orm::ActiveValue<Self> {
173          sea_orm::Set(self)
174        }
175      }
176
177      #async_graphql_impl
178      #serde_impl
179    }
180    .into();
181
182    tokens
183}
184
185struct FeatureSerde {
186    enabled: bool,
187    struct_name: Ident,
188}
189
190impl FeatureSerde {
191    fn new(enabled: bool, struct_name: Ident) -> Self {
192        Self {
193            enabled,
194            struct_name,
195        }
196    }
197}
198
199impl ToTokens for FeatureSerde {
200    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
201        let struct_name = self.struct_name.clone();
202        if self.enabled {
203            tokens.extend(quote! {
204                impl serde::Serialize for #struct_name {
205                  fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
206                  where
207                    S: serde::ser::Serializer,
208                  {
209                    self.as_str().serialize(serializer)
210                  }
211                }
212
213                impl<'de> serde::Deserialize<'de> for #struct_name {
214                  fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
215                  where
216                    D: serde::de::Deserializer<'de>,
217                  {
218                    let s: String = serde::Deserialize::deserialize(deserializer)?;
219                    s.parse::<Self>().map_err(::serde::de::Error::custom)
220                  }
221                }
222            });
223        }
224    }
225}
226
227struct FeatureAsyncGraphQL {
228    enabled: bool,
229    struct_name: Ident,
230}
231
232impl FeatureAsyncGraphQL {
233    fn new(enabled: bool, struct_name: Ident) -> Self {
234        Self {
235            enabled,
236            struct_name,
237        }
238    }
239}
240
241impl ToTokens for FeatureAsyncGraphQL {
242    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
243        let struct_name = self.struct_name.clone();
244        if self.enabled {
245            tokens.extend(quote! {
246                #[async_graphql::Scalar]
247                impl async_graphql::ScalarType for #struct_name {
248                  fn parse(value: async_graphql::Value) -> async_graphql::InputValueResult<Self> {
249                    if let async_graphql::Value::String(value) = &value {
250                      Ok(#struct_name(value.into()))
251                    } else {
252                      Err(async_graphql::InputValueError::expected_type(value))
253                    }
254                  }
255
256                  fn to_value(&self) -> async_graphql::Value {
257                    async_graphql::Value::String(self.to_string())
258                  }
259                }
260            })
261        }
262    }
263}