Skip to main content

sea_orm_codegen/entity/writer/
frontend.rs

1use sea_query::ColumnType;
2use std::collections::HashSet;
3
4use super::*;
5
/// Key identifying an external type whose import generated frontend code may need.
///
/// This is a separate enum so that `ColumnType` doesn't need to derive `Hash` or `Eq`:
/// values of this type are stored in a `HashSet` to remember which imports have already
/// been generated. Column types that share an import share a variant.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
enum ExternalTypes {
    /// `Json` and `JsonBinary` columns.
    JsonOrJsonBinary,
    /// `Date` columns.
    Date,
    /// `Time` columns.
    Time,
    /// `DateTime` columns.
    DateTime,
    /// `Timestamp` columns.
    Timestamp,
    /// `TimestampWithTimeZone` columns.
    TimestampWithTimeZone,
    /// `Decimal` and `Money` columns.
    DecimalOrMoney,
    /// `Uuid` columns.
    Uuid,
    /// `Vector` columns.
    Vector,
    /// `Cidr` and `Inet` columns.
    CidrOrInet,
}
20
21impl ExternalTypes {
22    fn from_column_type(col_type: &ColumnType) -> Option<Self> {
23        Some(match col_type {
24            ColumnType::Json | ColumnType::JsonBinary => Self::JsonOrJsonBinary,
25            ColumnType::Date => Self::Date,
26            ColumnType::Time => Self::Time,
27            ColumnType::DateTime => Self::DateTime,
28            ColumnType::Timestamp => Self::Timestamp,
29            ColumnType::TimestampWithTimeZone => Self::TimestampWithTimeZone,
30            ColumnType::Decimal(..) | ColumnType::Money(..) => Self::DecimalOrMoney,
31            ColumnType::Uuid => Self::Uuid,
32            ColumnType::Vector(..) => Self::Vector,
33            ColumnType::Cidr | ColumnType::Inet => Self::CidrOrInet,
34            _ => return None,
35        })
36    }
37}
38
39impl EntityWriter {
40    #[allow(clippy::too_many_arguments)]
41    pub fn gen_frontend_code_blocks(
42        entity: &Entity,
43        with_serde: &WithSerde,
44        column_option: &ColumnOption,
45        schema_name: &Option<String>,
46        serde_skip_deserializing_primary_key: bool,
47        serde_skip_hidden_column: bool,
48        model_extra_derives: &TokenStream,
49        model_extra_attributes: &TokenStream,
50        _column_extra_derives: &TokenStream,
51        _seaography: bool,
52        _impl_active_model_behavior: bool,
53    ) -> Vec<TokenStream> {
54        let mut imports = Self::gen_import_serde(with_serde);
55        imports.extend(Self::gen_import_active_enum(entity));
56        imports.extend(Self::gen_import_frontend(entity, column_option));
57        let code_blocks = vec![
58            imports,
59            Self::gen_frontend_model_struct(
60                entity,
61                with_serde,
62                column_option,
63                schema_name,
64                serde_skip_deserializing_primary_key,
65                serde_skip_hidden_column,
66                model_extra_derives,
67                model_extra_attributes,
68            ),
69        ];
70        code_blocks
71    }
72
73    #[allow(clippy::too_many_arguments)]
74    pub fn gen_frontend_model_struct(
75        entity: &Entity,
76        with_serde: &WithSerde,
77        column_option: &ColumnOption,
78        _schema_name: &Option<String>,
79        serde_skip_deserializing_primary_key: bool,
80        serde_skip_hidden_column: bool,
81        model_extra_derives: &TokenStream,
82        model_extra_attributes: &TokenStream,
83    ) -> TokenStream {
84        let column_names_snake_case = entity.get_column_names_snake_case();
85        let column_rs_types = entity.get_column_rs_types(column_option);
86        let if_eq_needed = entity.get_eq_needed();
87        let primary_keys: Vec<String> = entity
88            .primary_keys
89            .iter()
90            .map(|pk| pk.name.clone())
91            .collect();
92        let attrs: Vec<TokenStream> = entity
93            .columns
94            .iter()
95            .map(|col| {
96                let is_primary_key = primary_keys.contains(&col.name);
97                col.get_serde_attribute(
98                    is_primary_key,
99                    serde_skip_deserializing_primary_key,
100                    serde_skip_hidden_column,
101                )
102            })
103            .collect();
104        let extra_derive = with_serde.extra_derive();
105
106        quote! {
107            #[derive(Clone, Debug, PartialEq #if_eq_needed #extra_derive #model_extra_derives)]
108            #model_extra_attributes
109            pub struct Model {
110                #(
111                    #attrs
112                    pub #column_names_snake_case: #column_rs_types,
113                )*
114            }
115        }
116    }
117
118    pub fn gen_import_frontend(entity: &Entity, opt: &ColumnOption) -> TokenStream {
119        fn collect(
120            col_type: &ColumnType,
121            opt: &ColumnOption,
122            date_time: &mut Vec<TokenStream>,
123            aliases: &mut Vec<TokenStream>,
124            plain_uses: &mut Vec<TokenStream>,
125            encountered: &mut HashSet<ExternalTypes>,
126        ) {
127            // skip column types we have already generated imports for
128            if let Some(ty) = ExternalTypes::from_column_type(col_type)
129                && !encountered.insert(ty)
130            {
131                return;
132            }
133
134            match col_type {
135                ColumnType::Json | ColumnType::JsonBinary => {
136                    plain_uses.push(quote! { use serde_json::Value as Json; });
137                }
138                ColumnType::Date => match opt.date_time_crate {
139                    DateTimeCrate::Chrono => {
140                        date_time.push(quote! { NaiveDate as Date });
141                    }
142                    DateTimeCrate::Time => {
143                        date_time.push(quote! { Date as TimeDate });
144                    }
145                },
146                ColumnType::Time => match opt.date_time_crate {
147                    DateTimeCrate::Chrono => {
148                        date_time.push(quote! { NaiveTime as Time });
149                    }
150                    DateTimeCrate::Time => {
151                        date_time.push(quote! { Time as TimeTime });
152                    }
153                },
154                ColumnType::DateTime => match opt.date_time_crate {
155                    DateTimeCrate::Chrono => {
156                        date_time.push(quote! { NaiveDateTime as DateTime });
157                    }
158                    DateTimeCrate::Time => {
159                        date_time.push(quote! { PrimitiveDateTime as TimeDateTime });
160                    }
161                },
162                ColumnType::Timestamp => match opt.date_time_crate {
163                    DateTimeCrate::Chrono => {
164                        aliases.push(quote! {
165                            type DateTimeUtc = chrono::DateTime<chrono::Utc>;
166                        });
167                    }
168                    DateTimeCrate::Time => {
169                        date_time.push(quote! { PrimitiveDateTime as TimeDateTime });
170                    }
171                },
172                ColumnType::TimestampWithTimeZone => match opt.date_time_crate {
173                    DateTimeCrate::Chrono => {
174                        aliases.push(quote! {
175                            type DateTimeWithTimeZone = chrono::DateTime<chrono::FixedOffset>;
176                        });
177                    }
178                    DateTimeCrate::Time => {
179                        date_time.push(quote! { OffsetDateTime as TimeDateTimeWithTimeZone });
180                    }
181                },
182                ColumnType::Decimal(_) | ColumnType::Money(_) => {
183                    plain_uses.push(quote! { use rust_decimal::Decimal; })
184                }
185                ColumnType::Uuid => {
186                    plain_uses.push(quote! { use uuid::Uuid; });
187                }
188                ColumnType::Vector(_) => {
189                    plain_uses.push(quote! { use pgvector::Vector as PgVector; });
190                }
191                ColumnType::Cidr | ColumnType::Inet => {
192                    plain_uses.push(quote! { use ipnetwork::IpNetwork; });
193                }
194                ColumnType::Array(inner) => {
195                    collect(
196                        inner.as_ref(),
197                        opt,
198                        date_time,
199                        aliases,
200                        plain_uses,
201                        encountered,
202                    );
203                }
204                _ => {}
205            }
206        }
207
208        let mut date_time_uses = Vec::new();
209        let mut aliases = Vec::new();
210        let mut plain_uses = Vec::new();
211        let mut encountered = HashSet::new();
212
213        for col in &entity.columns {
214            collect(
215                &col.col_type,
216                opt,
217                &mut date_time_uses,
218                &mut aliases,
219                &mut plain_uses,
220                &mut encountered,
221            );
222        }
223
224        let time_use = if date_time_uses.is_empty() {
225            quote! {}
226        } else {
227            match opt.date_time_crate {
228                DateTimeCrate::Chrono => quote! { use chrono::{ #(#date_time_uses),* }; },
229                DateTimeCrate::Time => quote! { use time::{ #(#date_time_uses),* }; },
230            }
231        };
232
233        quote! {
234            #time_use
235            #(#plain_uses)*
236            #(#aliases)*
237        }
238    }
239}