1use std::fmt::{self, Display, Formatter};
2
3use sqlx_core::type_info::TypeInfo;
4
5use crate::protocol::type_info as protocol;
6
/// The logical kind of a SQL Server type, as reported by [`MssqlTypeInfo::kind`].
///
/// Several protocol-level data types collapse onto one kind; the per-variant notes list
/// the protocol types that `MssqlTypeInfo::from_protocol` maps to each of them.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MssqlType {
    /// `NULL`.
    Null,
    /// `BIT` (protocol `Bit` and `BitN`).
    Bit,
    /// `TINYINT` (protocol `TinyInt`, or `IntN` with size 1).
    TinyInt,
    /// `SMALLINT` (protocol `SmallInt`, or `IntN` with size 2).
    SmallInt,
    /// `INT` (protocol `Int`, or `IntN` with size 4).
    Int,
    /// `BIGINT` (protocol `BigInt`, or `IntN` with size 8).
    BigInt,
    /// `REAL` (protocol `Real`, or `FloatN` with size 4).
    Real,
    /// `FLOAT` (protocol `Float`, or `FloatN` with size 8).
    Float,
    /// `NVARCHAR` (protocol `NVarChar` and `NChar`).
    NVarChar,
    /// `VARCHAR` (protocol `VarChar`, `Char`, `BigVarChar` and `BigChar`).
    VarChar,
    /// `VARBINARY` (protocol `VarBinary`, `Binary`, `BigVarBinary` and `BigBinary`).
    VarBinary,
    /// `DECIMAL` (protocol `Decimal`, `DecimalN`, `Numeric` and `NumericN`).
    Decimal,
    /// `MONEY` (protocol `Money`, `MoneyN` and `SmallMoney`).
    Money,
    /// `DATE` (protocol `DateN`).
    Date,
    /// `TIME` (protocol `TimeN`).
    Time,
    /// `DATETIME` (protocol `DateTime`, `DateTimeN` and `SmallDateTime`).
    DateTime,
    /// `DATETIME2` (protocol `DateTime2N`).
    DateTime2,
    /// `DATETIMEOFFSET` (protocol `DateTimeOffsetN`).
    DateTimeOffset,
    /// `UNIQUEIDENTIFIER` (protocol `Guid`).
    UniqueIdentifier,
    /// Any type without a dedicated variant; holds the name reported by the protocol
    /// type info it was built from.
    Other(String),
}
51
52#[derive(Debug, Clone, PartialEq, Eq)]
54pub struct MssqlTypeInfo {
55 kind: MssqlType,
56 variable_length: bool,
57 size: Option<u16>,
58 protocol_type_info: Option<protocol::TypeInfo>,
59}
60
61impl MssqlTypeInfo {
62 pub const fn new(kind: MssqlType) -> Self {
64 Self {
65 kind,
66 variable_length: false,
67 size: None,
68 protocol_type_info: None,
69 }
70 }
71
72 pub(crate) const fn with_size(kind: MssqlType, size: u16) -> Self {
73 Self {
74 kind,
75 variable_length: true,
76 size: Some(size),
77 protocol_type_info: None,
78 }
79 }
80
81 pub(crate) const fn with_protocol(
82 kind: MssqlType,
83 protocol_type_info: protocol::TypeInfo,
84 ) -> Self {
85 Self {
86 kind,
87 variable_length: true,
88 size: Some(protocol_type_info.size as u16),
89 protocol_type_info: Some(protocol_type_info),
90 }
91 }
92
93 pub fn kind(&self) -> &MssqlType {
95 &self.kind
96 }
97
98 pub(crate) const fn size(&self) -> Option<u16> {
99 self.size
100 }
101
102 pub(crate) const fn protocol_type_info(&self) -> Option<&protocol::TypeInfo> {
103 self.protocol_type_info.as_ref()
104 }
105
106 pub(crate) const fn scale(&self) -> u8 {
107 match &self.protocol_type_info {
108 Some(protocol_type_info) => protocol_type_info.scale,
109 None => 0,
110 }
111 }
112
113 pub(crate) const fn precision(&self) -> u8 {
114 match &self.protocol_type_info {
115 Some(protocol_type_info) => protocol_type_info.precision,
116 None => 0,
117 }
118 }
119
120 pub const NULL: Self = Self::new(MssqlType::Null);
122 pub const BIT: Self = Self::new(MssqlType::Bit);
124 pub const TINYINT: Self = Self::new(MssqlType::TinyInt);
126 pub const SMALLINT: Self = Self::new(MssqlType::SmallInt);
128 pub const INT: Self = Self::new(MssqlType::Int);
130 pub const BIGINT: Self = Self::new(MssqlType::BigInt);
132 pub const REAL: Self = Self::new(MssqlType::Real);
134 pub const FLOAT: Self = Self::new(MssqlType::Float);
136 pub const NVARCHAR: Self = Self::new(MssqlType::NVarChar);
138 pub const VARCHAR: Self = Self::new(MssqlType::VarChar);
140 pub const VARBINARY: Self = Self::new(MssqlType::VarBinary);
142 pub const DECIMAL: Self = Self::with_protocol(
144 MssqlType::Decimal,
145 protocol::TypeInfo {
146 ty: protocol::DataType::NumericN,
147 size: 17,
148 scale: 0,
149 precision: 38,
150 collation: None,
151 },
152 );
153 pub const MONEY: Self = Self::new(MssqlType::Money);
155 pub const DATE: Self = Self::with_protocol(
157 MssqlType::Date,
158 protocol::TypeInfo {
159 ty: protocol::DataType::DateN,
160 size: 3,
161 scale: 0,
162 precision: 10,
163 collation: None,
164 },
165 );
166 pub const TIME: Self = Self::with_protocol(
168 MssqlType::Time,
169 protocol::TypeInfo {
170 ty: protocol::DataType::TimeN,
171 size: 5,
172 scale: 7,
173 precision: 0,
174 collation: None,
175 },
176 );
177 pub const DATETIME2: Self = Self::with_protocol(
179 MssqlType::DateTime2,
180 protocol::TypeInfo {
181 ty: protocol::DataType::DateTime2N,
182 size: 8,
183 scale: 7,
184 precision: 0,
185 collation: None,
186 },
187 );
188 pub const DATETIMEOFFSET: Self = Self::with_protocol(
190 MssqlType::DateTimeOffset,
191 protocol::TypeInfo {
192 ty: protocol::DataType::DateTimeOffsetN,
193 size: 10,
194 scale: 7,
195 precision: 34,
196 collation: None,
197 },
198 );
199 pub const UNIQUEIDENTIFIER: Self = Self::with_protocol(
201 MssqlType::UniqueIdentifier,
202 protocol::TypeInfo {
203 ty: protocol::DataType::Guid,
204 size: 16,
205 scale: 0,
206 precision: 0,
207 collation: None,
208 },
209 );
210
211 #[cfg(any(feature = "bigdecimal", feature = "decimal"))]
212 pub(crate) const fn decimal_with_scale(scale: u8) -> Self {
213 Self::with_protocol(
214 MssqlType::Decimal,
215 protocol::TypeInfo {
216 ty: protocol::DataType::NumericN,
217 size: 17,
218 scale,
219 precision: 38,
220 collation: None,
221 },
222 )
223 }
224
225 pub(crate) fn from_protocol(type_info: &protocol::TypeInfo) -> Self {
226 let kind = match type_info.ty {
227 protocol::DataType::Null => MssqlType::Null,
228 protocol::DataType::Bit | protocol::DataType::BitN => MssqlType::Bit,
229 protocol::DataType::TinyInt => MssqlType::TinyInt,
230 protocol::DataType::SmallInt => MssqlType::SmallInt,
231 protocol::DataType::Int => MssqlType::Int,
232 protocol::DataType::BigInt => MssqlType::BigInt,
233 protocol::DataType::Real => MssqlType::Real,
234 protocol::DataType::Float => MssqlType::Float,
235 protocol::DataType::IntN => match type_info.size {
236 1 => MssqlType::TinyInt,
237 2 => MssqlType::SmallInt,
238 4 => MssqlType::Int,
239 8 => MssqlType::BigInt,
240 _ => MssqlType::Other(type_info.name().to_owned()),
241 },
242 protocol::DataType::FloatN => match type_info.size {
243 4 => MssqlType::Real,
244 8 => MssqlType::Float,
245 _ => MssqlType::Other(type_info.name().to_owned()),
246 },
247 protocol::DataType::NVarChar | protocol::DataType::NChar => MssqlType::NVarChar,
248 protocol::DataType::VarChar
249 | protocol::DataType::Char
250 | protocol::DataType::BigVarChar
251 | protocol::DataType::BigChar => MssqlType::VarChar,
252 protocol::DataType::VarBinary
253 | protocol::DataType::Binary
254 | protocol::DataType::BigVarBinary
255 | protocol::DataType::BigBinary => MssqlType::VarBinary,
256 protocol::DataType::Decimal
257 | protocol::DataType::DecimalN
258 | protocol::DataType::Numeric
259 | protocol::DataType::NumericN => MssqlType::Decimal,
260 protocol::DataType::Money
261 | protocol::DataType::MoneyN
262 | protocol::DataType::SmallMoney => MssqlType::Money,
263 protocol::DataType::DateN => MssqlType::Date,
264 protocol::DataType::TimeN => MssqlType::Time,
265 protocol::DataType::DateTime
266 | protocol::DataType::DateTimeN
267 | protocol::DataType::SmallDateTime => MssqlType::DateTime,
268 protocol::DataType::DateTime2N => MssqlType::DateTime2,
269 protocol::DataType::DateTimeOffsetN => MssqlType::DateTimeOffset,
270 protocol::DataType::Guid => MssqlType::UniqueIdentifier,
271 _ => MssqlType::Other(type_info.name().to_owned()),
272 };
273
274 Self {
275 kind,
276 variable_length: type_info.is_nullable_or_variable_length(),
277 size: u16::try_from(type_info.size).ok(),
278 protocol_type_info: Some(type_info.clone()),
279 }
280 }
281}
282
283impl TypeInfo for MssqlTypeInfo {
284 fn is_null(&self) -> bool {
285 matches!(self.kind, MssqlType::Null)
286 }
287
288 fn name(&self) -> &str {
289 match &self.kind {
290 MssqlType::Null => "NULL",
291 MssqlType::Bit => "BIT",
292 MssqlType::TinyInt => "TINYINT",
293 MssqlType::SmallInt => "SMALLINT",
294 MssqlType::Int => "INT",
295 MssqlType::BigInt => "BIGINT",
296 MssqlType::Real => "REAL",
297 MssqlType::Float => "FLOAT",
298 MssqlType::NVarChar => "NVARCHAR",
299 MssqlType::VarChar => "VARCHAR",
300 MssqlType::VarBinary => "VARBINARY",
301 MssqlType::Decimal => "DECIMAL",
302 MssqlType::Money => "MONEY",
303 MssqlType::Date => "DATE",
304 MssqlType::Time => "TIME",
305 MssqlType::DateTime => "DATETIME",
306 MssqlType::DateTime2 => "DATETIME2",
307 MssqlType::DateTimeOffset => "DATETIMEOFFSET",
308 MssqlType::UniqueIdentifier => "UNIQUEIDENTIFIER",
309 MssqlType::Other(name) => name,
310 }
311 }
312
313 fn type_compatible(&self, other: &Self) -> bool {
314 self.kind == other.kind || self.is_null() || other.is_null()
315 }
316}
317
318impl Display for MssqlTypeInfo {
319 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
320 f.write_str(self.name())
321 }
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    /// Maps a protocol data type (with size 8) through `from_protocol` and returns the kind.
    fn kind_from_protocol(ty: protocol::DataType) -> MssqlType {
        let wire = protocol::TypeInfo::new(ty, 8);
        MssqlTypeInfo::from_protocol(&wire).kind
    }

    #[test]
    fn exposes_sql_server_type_names() {
        // `name()` and `Display` must agree on the SQL Server spelling.
        assert_eq!(MssqlTypeInfo::INT.name(), "INT");
        assert_eq!(MssqlTypeInfo::NVARCHAR.to_string(), "NVARCHAR");
        assert_eq!(MssqlTypeInfo::VARCHAR.to_string(), "VARCHAR");
    }

    #[test]
    fn null_is_compatible_with_known_types() {
        let null = MssqlTypeInfo::NULL;

        // NULL is compatible on either side of the comparison.
        assert!(null.type_compatible(&MssqlTypeInfo::INT));
        assert!(MssqlTypeInfo::NVARCHAR.type_compatible(&null));

        // Distinct non-null kinds are not compatible.
        assert!(!MssqlTypeInfo::INT.type_compatible(&MssqlTypeInfo::BIGINT));
    }

    #[test]
    fn maps_unicode_and_non_unicode_protocol_text_separately() {
        assert_eq!(
            kind_from_protocol(protocol::DataType::NVarChar),
            MssqlType::NVarChar
        );
        assert_eq!(
            kind_from_protocol(protocol::DataType::VarChar),
            MssqlType::VarChar
        );
        assert_eq!(
            kind_from_protocol(protocol::DataType::BigVarChar),
            MssqlType::VarChar
        );
    }
}