Skip to main content

sea_schema/postgres/writer/
column.rs

1use crate::postgres::def::{ColumnDefault, ColumnInfo, Type};
2use sea_query::{
3    Alias, ColumnDef, ColumnType, DynIden, Expr, IntoIden, Keyword, PgInterval, RcOrArc,
4    SimpleExpr, StringLen,
5};
6use std::convert::TryFrom;
7
8impl ColumnInfo {
9    pub fn write(&self) -> ColumnDef {
10        let mut col_info = self.clone();
11        if let Some(ColumnDefault::AutoIncrement(_)) = &self.default {
12            col_info = Self::convert_to_serial(col_info);
13        }
14        let col_type = col_info.write_col_type();
15        let mut col_def = ColumnDef::new_with_type(Alias::new(self.name.as_str()), col_type);
16        if self.is_identity {
17            col_info = Self::convert_to_serial(col_info);
18        }
19        if matches!(
20            col_info.col_type,
21            Type::SmallSerial | Type::Serial | Type::BigSerial
22        ) {
23            col_def.auto_increment();
24        }
25        if self.not_null.is_some() {
26            col_def.not_null();
27        }
28        if let Some(default) = &self.default {
29            if let Some(default_expr) = default.write() {
30                col_def.default(default_expr);
31            }
32        }
33        col_def
34    }
35
36    fn convert_to_serial(mut col_info: ColumnInfo) -> ColumnInfo {
37        match col_info.col_type {
38            Type::SmallInt => {
39                col_info.col_type = Type::SmallSerial;
40            }
41            Type::Integer => {
42                col_info.col_type = Type::Serial;
43            }
44            Type::BigInt => {
45                col_info.col_type = Type::BigSerial;
46            }
47            _ => {}
48        };
49        col_info
50    }
51
52    pub fn write_col_type(&self) -> ColumnType {
53        fn write_type(col_type: &Type) -> ColumnType {
54            match col_type {
55                Type::SmallInt => ColumnType::SmallInteger,
56                Type::Integer => ColumnType::Integer,
57                Type::BigInt => ColumnType::BigInteger,
58                Type::Decimal(num_attr) | Type::Numeric(num_attr) => {
59                    match (num_attr.precision, num_attr.scale) {
60                        (None, None) => ColumnType::Decimal(None),
61                        (precision, scale) => ColumnType::Decimal(Some((
62                            precision.unwrap_or(0).into(),
63                            scale.unwrap_or(0).into(),
64                        ))),
65                    }
66                }
67                Type::Real => ColumnType::Float,
68                Type::DoublePrecision => ColumnType::Double,
69                Type::SmallSerial => ColumnType::SmallInteger,
70                Type::Serial => ColumnType::Integer,
71                Type::BigSerial => ColumnType::BigInteger,
72                Type::Money => ColumnType::Money(None),
73                Type::Varchar(string_attr) => match string_attr.length {
74                    Some(length) => ColumnType::String(StringLen::N(length.into())),
75                    None => ColumnType::String(StringLen::None),
76                },
77                Type::Char(string_attr) => ColumnType::Char(string_attr.length.map(Into::into)),
78                Type::Text => ColumnType::Text,
79                Type::Bytea => ColumnType::VarBinary(StringLen::None),
80                // The SQL standard requires that writing just timestamp be equivalent to timestamp without time zone,
81                // and PostgreSQL honors that behavior. (https://www.postgresql.org/docs/current/datatype-datetime.html)
82                Type::Timestamp(_) => ColumnType::DateTime,
83                Type::TimestampWithTimeZone(_) => ColumnType::TimestampWithTimeZone,
84                Type::Date => ColumnType::Date,
85                Type::Time(_) => ColumnType::Time,
86                Type::TimeWithTimeZone(_) => ColumnType::Time,
87                Type::Interval(interval_attr) => {
88                    let field = match &interval_attr.field {
89                        Some(field) => PgInterval::try_from(field).ok(),
90                        None => None,
91                    };
92                    let precision = interval_attr.precision.map(Into::into);
93                    ColumnType::Interval(field, precision)
94                }
95                Type::Boolean => ColumnType::Boolean,
96                Type::Point => ColumnType::Custom("point".into_iden()),
97                Type::Line => ColumnType::Custom("line".into_iden()),
98                Type::Lseg => ColumnType::Custom("lseg".into_iden()),
99                Type::Box => ColumnType::Custom("box".into_iden()),
100                Type::Path => ColumnType::Custom("path".into_iden()),
101                Type::Polygon => ColumnType::Custom("polygon".into_iden()),
102                Type::Circle => ColumnType::Custom("circle".into_iden()),
103                Type::Cidr => ColumnType::Cidr,
104                Type::Inet => ColumnType::Inet,
105                Type::MacAddr => ColumnType::Custom("macaddr".into_iden()),
106                Type::MacAddr8 => ColumnType::Custom("macaddr8".into_iden()),
107                Type::Bit(bit_attr) => ColumnType::Bit(bit_attr.length.map(Into::into)),
108                Type::VarBit(bit_attr) => ColumnType::VarBit(bit_attr.length.unwrap_or(1).into()),
109                Type::TsVector => ColumnType::Custom("tsvector".into_iden()),
110                Type::TsQuery => ColumnType::Custom("tsquery".into_iden()),
111                Type::Uuid => ColumnType::Uuid,
112                Type::Xml => ColumnType::Custom("xml".into_iden()),
113                Type::Json => ColumnType::Json,
114                Type::JsonBinary => ColumnType::JsonBinary,
115                Type::Int4Range => ColumnType::Custom("int4range".into_iden()),
116                Type::Int8Range => ColumnType::Custom("int8range".into_iden()),
117                Type::NumRange => ColumnType::Custom("numrange".into_iden()),
118                Type::TsRange => ColumnType::Custom("tsrange".into_iden()),
119                Type::TsTzRange => ColumnType::Custom("tstzrange".into_iden()),
120                Type::DateRange => ColumnType::Custom("daterange".into_iden()),
121                Type::PgLsn => ColumnType::Custom("pg_lsn".into_iden()),
122                #[cfg(feature = "postgres-vector")]
123                Type::Vector(vector_attr) => match vector_attr.length {
124                    Some(length) => ColumnType::Vector(Some(length)),
125                    None => ColumnType::Vector(None),
126                },
127                Type::Unknown(s) => ColumnType::Custom(Alias::new(s).into_iden()),
128                Type::Enum(enum_def) => {
129                    let name = Alias::new(&enum_def.typename).into_iden();
130                    let variants: Vec<DynIden> = enum_def
131                        .values
132                        .iter()
133                        .map(|variant| Alias::new(variant).into_iden())
134                        .collect();
135                    ColumnType::Enum { name, variants }
136                }
137                Type::Array(array_def) => ColumnType::Array(RcOrArc::new(write_type(
138                    array_def.col_type.as_ref().expect("Array type not defined"),
139                ))),
140            }
141        }
142        write_type(&self.col_type)
143    }
144}
145
146impl ColumnDefault {
147    /// Convert to a [SimpleExpr] for use with `col_def.default()`.
148    /// Returns `None` for [ColumnDefault::AutoIncrement] since those are handled
149    /// via SERIAL type conversion instead.
150    pub fn write(&self) -> Option<SimpleExpr> {
151        match self {
152            ColumnDefault::Int(int) => Some((*int).into()),
153            ColumnDefault::Real(real) => Some((*real).into()),
154            ColumnDefault::String(string) => Some(string.into()),
155            ColumnDefault::Bool(val) => Some(Expr::val(*val)),
156            ColumnDefault::CurrentTimestamp => Some(Keyword::CurrentTimestamp.into()),
157            ColumnDefault::AutoIncrement(_) => None,
158            ColumnDefault::Expression(expr) => Some(Expr::cust(expr.to_owned())),
159        }
160    }
161}