Skip to main content

sea_orm_codegen/entity/writer/
frontend.rs

1use sea_query::ColumnType;
2use std::collections::HashSet;
3
4use super::*;
5
/// Buckets of column types whose generated frontend code needs an extra
/// `use` statement or type alias.
///
/// This is a separate enum so that `ColumnType` doesn't need to derive `Hash`
/// or `Eq`. Column types that share the same import share one variant.
#[derive(PartialEq, Eq, Hash)]
enum ExternalTypes {
    /// `ColumnType::Json` and `ColumnType::JsonBinary`.
    JsonOrJsonBinary,
    /// `ColumnType::Date`.
    Date,
    /// `ColumnType::Time`.
    Time,
    /// `ColumnType::DateTime`.
    DateTime,
    /// `ColumnType::Timestamp`.
    Timestamp,
    /// `ColumnType::TimestampWithTimeZone`.
    TimestampWithTimeZone,
    /// `ColumnType::Decimal` and `ColumnType::Money`.
    DecimalOrMoney,
    /// `ColumnType::Uuid`.
    Uuid,
    /// `ColumnType::Vector`.
    Vector,
    /// `ColumnType::Cidr` and `ColumnType::Inet`.
    CidrOrInet,
}
20
21impl ExternalTypes {
22    fn from_column_type(col_type: &ColumnType) -> Option<Self> {
23        Some(match col_type {
24            ColumnType::Json | ColumnType::JsonBinary => Self::JsonOrJsonBinary,
25            ColumnType::Date => Self::Date,
26            ColumnType::Time => Self::Time,
27            ColumnType::DateTime => Self::DateTime,
28            ColumnType::Timestamp => Self::Timestamp,
29            ColumnType::TimestampWithTimeZone => Self::TimestampWithTimeZone,
30            ColumnType::Decimal(..) | ColumnType::Money(..) => Self::DecimalOrMoney,
31            ColumnType::Uuid => Self::Uuid,
32            ColumnType::Vector(..) => Self::Vector,
33            ColumnType::Cidr | ColumnType::Inet => Self::CidrOrInet,
34            _ => return None,
35        })
36    }
37}
38
39impl EntityWriter {
40    #[allow(clippy::too_many_arguments)]
41    pub fn gen_frontend_code_blocks(
42        entity: &Entity,
43        with_serde: &WithSerde,
44        column_option: &ColumnOption,
45        schema_name: &Option<String>,
46        serde_skip_deserializing_primary_key: bool,
47        serde_skip_hidden_column: bool,
48        model_extra_derives: &TokenStream,
49        model_extra_attributes: &TokenStream,
50        _column_extra_derives: &TokenStream,
51        _seaography: bool,
52        _impl_active_model_behavior: bool,
53    ) -> Vec<TokenStream> {
54        let mut imports = Self::gen_import_serde(with_serde);
55        let active_enums = Self::gen_import_active_enum(entity);
56        imports.extend(active_enums.imports);
57        imports.extend(Self::gen_import_frontend(entity, column_option));
58        let code_blocks = vec![
59            imports,
60            Self::gen_frontend_model_struct(
61                entity,
62                with_serde,
63                column_option,
64                schema_name,
65                serde_skip_deserializing_primary_key,
66                serde_skip_hidden_column,
67                model_extra_derives,
68                model_extra_attributes,
69                &active_enums.type_idents,
70            ),
71        ];
72        code_blocks
73    }
74
75    #[allow(clippy::too_many_arguments)]
76    pub fn gen_frontend_model_struct(
77        entity: &Entity,
78        with_serde: &WithSerde,
79        column_option: &ColumnOption,
80        _schema_name: &Option<String>,
81        serde_skip_deserializing_primary_key: bool,
82        serde_skip_hidden_column: bool,
83        model_extra_derives: &TokenStream,
84        model_extra_attributes: &TokenStream,
85        active_enum_type_idents: &ActiveEnumTypeIdents,
86    ) -> TokenStream {
87        let column_names_snake_case = entity.get_column_names_snake_case();
88        let column_rs_types = Self::get_column_rs_types_with_enum_idents(
89            entity,
90            column_option,
91            active_enum_type_idents,
92        );
93        let if_eq_needed = entity.get_eq_needed();
94        let primary_keys: Vec<String> = entity
95            .primary_keys
96            .iter()
97            .map(|pk| pk.name.clone())
98            .collect();
99        let attrs: Vec<TokenStream> = entity
100            .columns
101            .iter()
102            .map(|col| {
103                let is_primary_key = primary_keys.contains(&col.name);
104                col.get_serde_attribute(
105                    is_primary_key,
106                    serde_skip_deserializing_primary_key,
107                    serde_skip_hidden_column,
108                )
109            })
110            .collect();
111        let extra_derive = with_serde.extra_derive();
112
113        quote! {
114            #[derive(Clone, Debug, PartialEq #if_eq_needed #extra_derive #model_extra_derives)]
115            #model_extra_attributes
116            pub struct Model {
117                #(
118                    #attrs
119                    pub #column_names_snake_case: #column_rs_types,
120                )*
121            }
122        }
123    }
124
125    pub fn gen_import_frontend(entity: &Entity, opt: &ColumnOption) -> TokenStream {
126        fn collect(
127            col_type: &ColumnType,
128            opt: &ColumnOption,
129            date_time: &mut Vec<TokenStream>,
130            aliases: &mut Vec<TokenStream>,
131            plain_uses: &mut Vec<TokenStream>,
132            encountered: &mut HashSet<ExternalTypes>,
133        ) {
134            // skip column types we have already generated imports for
135            if let Some(ty) = ExternalTypes::from_column_type(col_type)
136                && !encountered.insert(ty)
137            {
138                return;
139            }
140
141            match col_type {
142                ColumnType::Json | ColumnType::JsonBinary => {
143                    plain_uses.push(quote! { use serde_json::Value as Json; });
144                }
145                ColumnType::Date => match opt.date_time_crate {
146                    DateTimeCrate::Chrono => {
147                        date_time.push(quote! { NaiveDate as Date });
148                    }
149                    DateTimeCrate::Time => {
150                        date_time.push(quote! { Date as TimeDate });
151                    }
152                },
153                ColumnType::Time => match opt.date_time_crate {
154                    DateTimeCrate::Chrono => {
155                        date_time.push(quote! { NaiveTime as Time });
156                    }
157                    DateTimeCrate::Time => {
158                        date_time.push(quote! { Time as TimeTime });
159                    }
160                },
161                ColumnType::DateTime => match opt.date_time_crate {
162                    DateTimeCrate::Chrono => {
163                        date_time.push(quote! { NaiveDateTime as DateTime });
164                    }
165                    DateTimeCrate::Time => {
166                        date_time.push(quote! { PrimitiveDateTime as TimeDateTime });
167                    }
168                },
169                ColumnType::Timestamp => match opt.date_time_crate {
170                    DateTimeCrate::Chrono => {
171                        aliases.push(quote! {
172                            type DateTimeUtc = chrono::DateTime<chrono::Utc>;
173                        });
174                    }
175                    DateTimeCrate::Time => {
176                        date_time.push(quote! { PrimitiveDateTime as TimeDateTime });
177                    }
178                },
179                ColumnType::TimestampWithTimeZone => match opt.date_time_crate {
180                    DateTimeCrate::Chrono => {
181                        aliases.push(quote! {
182                            type DateTimeWithTimeZone = chrono::DateTime<chrono::FixedOffset>;
183                        });
184                    }
185                    DateTimeCrate::Time => {
186                        date_time.push(quote! { OffsetDateTime as TimeDateTimeWithTimeZone });
187                    }
188                },
189                ColumnType::Decimal(_) | ColumnType::Money(_) => {
190                    plain_uses.push(quote! { use rust_decimal::Decimal; })
191                }
192                ColumnType::Uuid => {
193                    plain_uses.push(quote! { use uuid::Uuid; });
194                }
195                ColumnType::Vector(_) => {
196                    plain_uses.push(quote! { use pgvector::Vector as PgVector; });
197                }
198                ColumnType::Cidr | ColumnType::Inet => {
199                    plain_uses.push(quote! { use ipnetwork::IpNetwork; });
200                }
201                ColumnType::Array(inner) => {
202                    collect(
203                        inner.as_ref(),
204                        opt,
205                        date_time,
206                        aliases,
207                        plain_uses,
208                        encountered,
209                    );
210                }
211                _ => {}
212            }
213        }
214
215        let mut date_time_uses = Vec::new();
216        let mut aliases = Vec::new();
217        let mut plain_uses = Vec::new();
218        let mut encountered = HashSet::new();
219
220        for col in &entity.columns {
221            collect(
222                &col.col_type,
223                opt,
224                &mut date_time_uses,
225                &mut aliases,
226                &mut plain_uses,
227                &mut encountered,
228            );
229        }
230
231        let time_use = if date_time_uses.is_empty() {
232            quote! {}
233        } else {
234            match opt.date_time_crate {
235                DateTimeCrate::Chrono => quote! { use chrono::{ #(#date_time_uses),* }; },
236                DateTimeCrate::Time => quote! { use time::{ #(#date_time_uses),* }; },
237            }
238        };
239
240        quote! {
241            #time_use
242            #(#plain_uses)*
243            #(#aliases)*
244        }
245    }
246}