1mod boolean;
2mod data_set;
3mod date_time;
4mod decimal;
5mod fixed_char;
6mod float;
7mod integer;
8mod interval_day_to_second;
9mod interval_year_to_month;
10mod ip_address;
11pub mod json;
12mod map;
13mod option;
14mod row;
15mod seq;
16mod string;
17mod util;
18pub mod uuid;
19
20pub use self::uuid::*;
21pub use boolean::*;
22pub use data_set::*;
23pub use date_time::*;
24pub use decimal::*;
25pub use fixed_char::*;
26pub use float::*;
27pub use integer::*;
28pub use interval_day_to_second::*;
29pub use interval_year_to_month::*;
30pub use ip_address::*;
31pub use map::*;
32pub use option::*;
33pub use row::*;
34pub use seq::*;
35pub use string::*;
36
37use std::borrow::Cow;
41use std::collections::HashMap;
42use std::iter::FromIterator;
43use std::sync::Arc;
44
45use crate::{
46 ClientTypeSignatureParameter, Column, NamedTypeSignature, RawTrinoTy, RowFieldName,
47 TypeSignature,
48};
49use derive_more::Display;
50use iterable::*;
51use serde::de::DeserializeSeed;
52use serde::Serialize;
53
54#[derive(Display, Debug)]
56pub enum Error {
57 InvalidTrinoType,
58 InvalidColumn,
59 InvalidTypeSignature,
60 ParseDecimalFailed(String),
61 ParseIntervalMonthFailed,
62 ParseIntervalDayFailed,
63 EmptyInTrinoRow,
64 NoneTrinoRow,
65}
66
/// A Rust type that can be exchanged with Trino.
///
/// Implementors describe the Trino type they correspond to ([`Trino::ty`]), how a
/// value is serialized ([`Trino::value`]) and how one is deserialized from a
/// server response ([`Trino::seed`]).
pub trait Trino {
    /// Serializable view of `self`, returned by [`Trino::value`]; it may borrow from `self`.
    type ValueType<'a>: Serialize
    where
        Self: 'a;
    /// `serde` seed that deserializes a `Self`; built from a [`Context`] by [`Trino::seed`].
    type Seed<'a, 'de>: DeserializeSeed<'de, Value = Self>;

    /// Returns a serializable view of this value.
    fn value(&self) -> Self::ValueType<'_>;

    /// Returns the Trino type this Rust type expects.
    ///
    /// [`Context::new`] checks it for compatibility against the type reported by the server.
    fn ty() -> TrinoTy;

    /// Builds the deserialization seed for the (server-provided) type held by `ctx`.
    fn seed<'a, 'de>(ctx: &'a Context<'a>) -> Self::Seed<'a, 'de>;

    /// Returns an "empty" value of this type.
    ///
    /// NOTE(review): how callers use this (placeholder vs. default for absent data)
    /// is not visible in this file — confirm against the implementations.
    fn empty() -> Self;
}
82
/// Marker trait (no methods of its own) extending [`Trino`].
///
/// By its name it tags types usable as keys of a Trino `map`; the bound is applied
/// in the sibling modules, not in this file — confirm there.
pub trait TrinoMapKey: Trino {}
84
/// Deserialization context: the server-provided type currently being decoded,
/// plus the shared field-reordering tables for named rows.
#[derive(Debug)]
pub struct Context<'a> {
    /// Type being decoded; normally a node of the tree passed to `Context::new`.
    ty: &'a TrinoTy,
    /// Row-reordering tables keyed by the address of a `TrinoTy::Row` node in the
    /// provided type tree. Each value gives, per provided field, the index of that
    /// field in the target row. Shared (via `Arc`) by every context derived with
    /// `with_ty`.
    map: Arc<HashMap<usize, Vec<usize>>>,
}
90
91impl<'a> Context<'a> {
92 pub fn new<T: Trino>(provided: &'a TrinoTy) -> Result<Self, Error> {
93 let target = T::ty();
94 let ret = extract(&target, provided)?;
95 let map = HashMap::from_iter(ret);
96 Ok(Context {
97 ty: provided,
98 map: Arc::new(map),
99 })
100 }
101
102 pub fn with_ty(&'a self, ty: &'a TrinoTy) -> Context<'a> {
103 Context {
104 ty,
105 map: self.map.clone(),
106 }
107 }
108
109 pub fn ty(&self) -> &TrinoTy {
110 self.ty
111 }
112
113 pub fn row_map(&self) -> Option<&[usize]> {
114 let key = self.ty as *const TrinoTy as usize;
115 self.map.get(&key).map(|r| &**r)
116 }
117}
118
119fn extract(target: &TrinoTy, provided: &TrinoTy) -> Result<Vec<(usize, Vec<usize>)>, Error> {
120 use TrinoTy::*;
121
122 match (target, provided) {
123 (Unknown, _) => Ok(vec![]),
124 (Decimal(p1, s1), Decimal(p2, s2)) if p1 == p2 && s1 == s2 => Ok(vec![]),
125 (Option(ty), provided) => extract(ty, provided),
126 (Boolean, Boolean) => Ok(vec![]),
127 (Date, Date) => Ok(vec![]),
128 (Time, Time) => Ok(vec![]),
129 (TimeWithTimeZone, TimeWithTimeZone) => Ok(vec![]),
130 (Timestamp, Timestamp) => Ok(vec![]),
131 (TimestampWithTimeZone, TimestampWithTimeZone) => Ok(vec![]),
132 (IntervalYearToMonth, IntervalYearToMonth) => Ok(vec![]),
133 (IntervalDayToSecond, IntervalDayToSecond) => Ok(vec![]),
134 (TrinoInt(_), TrinoInt(_)) => Ok(vec![]),
135 (TrinoFloat(_), TrinoFloat(_)) => Ok(vec![]),
136 (Varchar, Varchar) => Ok(vec![]),
137 (Char(a), Char(b)) if a == b => Ok(vec![]),
138 (Tuple(t1), Tuple(t2)) => {
139 if t1.len() != t2.len() {
140 Err(Error::InvalidTrinoType)
141 } else {
142 t1.lazy_zip(t2).try_flat_map(|(l, r)| extract(l, r))
143 }
144 }
145 (Row(t1), Row(t2)) => {
146 if t1.len() != t2.len() {
147 Err(Error::InvalidTrinoType)
148 } else {
149 let t1k = t1.sorted_by(|t1, t2| Ord::cmp(&t1.0, &t2.0));
151 let t2k = t2.sorted_by(|t1, t2| Ord::cmp(&t1.0, &t2.0));
152
153 let ret = t1k.lazy_zip(t2k).try_flat_map(|(l, r)| {
154 if l.0 == r.0 {
155 extract(&l.1, &r.1)
156 } else {
157 Err(Error::InvalidTrinoType)
158 }
159 })?;
160
161 let map = t2.map(|provided| t1.position(|target| provided.0 == target.0).unwrap());
162 let key = provided as *const TrinoTy as usize;
163 Ok(ret.add_one((key, map)))
164 }
165 }
166 (Array(t1), Array(t2)) => extract(t1, t2),
167 (Map(t1k, t1v), Map(t2k, t2v)) => Ok(extract(t1k, t2k)?.chain(extract(t1v, t2v)?)),
168 (IpAddress, IpAddress) => Ok(vec![]),
169 (Uuid, Uuid) => Ok(vec![]),
170 (Json, Json) => Ok(vec![]),
171 _ => Err(Error::InvalidTrinoType),
172 }
173}
174
/// Rust-side model of a Trino data type.
///
/// Built from server metadata by [`TrinoTy::from_type_signature`] and
/// [`TrinoTy::from_columns`], and returned by [`Trino::ty`] for Rust types.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TrinoTy {
    /// Maps to [`RawTrinoTy::Date`].
    Date,
    /// Maps to [`RawTrinoTy::Time`].
    Time,
    /// Maps to [`RawTrinoTy::TimeWithTimeZone`].
    TimeWithTimeZone,
    /// Maps to [`RawTrinoTy::Timestamp`].
    Timestamp,
    /// Maps to [`RawTrinoTy::TimestampWithTimeZone`].
    TimestampWithTimeZone,
    /// Maps to [`RawTrinoTy::Uuid`].
    Uuid,
    /// Maps to [`RawTrinoTy::IntervalYearToMonth`].
    IntervalYearToMonth,
    /// Maps to [`RawTrinoTy::IntervalDayToSecond`].
    IntervalDayToSecond,
    /// Nullable wrapper. It is transparent: its raw type, full type name and type
    /// signature are those of the inner type.
    Option(Box<TrinoTy>),
    /// Maps to [`RawTrinoTy::Boolean`].
    Boolean,
    /// Integer type of the given width. Type matching in `extract` does not
    /// compare widths.
    TrinoInt(TrinoInt),
    /// Floating-point type (`real` or `double`). Type matching in `extract` does
    /// not compare widths.
    TrinoFloat(TrinoFloat),
    /// Maps to [`RawTrinoTy::VarChar`]. Any length argument is ignored when
    /// parsing; the type signature is emitted with length 2147483647.
    Varchar,
    /// Fixed-length character type with the given length argument.
    Char(usize),
    /// A `row` whose fields are all anonymous, in positional order.
    Tuple(Vec<TrinoTy>),
    /// A `row` whose fields are all named, as `(name, type)` in declaration order.
    Row(Vec<(String, TrinoTy)>),
    /// Array of the given element type.
    Array(Box<TrinoTy>),
    /// Map from the first (key) type to the second (value) type.
    Map(Box<TrinoTy>, Box<TrinoTy>),
    /// Decimal as `(precision, scale)`.
    Decimal(usize, usize),
    /// Maps to [`RawTrinoTy::IpAddress`].
    IpAddress,
    /// Maps to [`RawTrinoTy::Json`].
    Json,
    /// Maps to [`RawTrinoTy::Unknown`]. As a target type it is compatible with any
    /// provided type.
    Unknown,
}
205
/// Width of a Trino integer type (see [`TrinoInt::raw_type`] for the mapping).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TrinoInt {
    /// Maps to [`RawTrinoTy::TinyInt`].
    I8,
    /// Maps to [`RawTrinoTy::SmallInt`].
    I16,
    /// Maps to [`RawTrinoTy::Integer`].
    I32,
    /// Maps to [`RawTrinoTy::BigInt`].
    I64,
}
213
/// Width of a Trino floating-point type (see [`TrinoFloat::raw_type`] for the mapping).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TrinoFloat {
    /// Maps to [`RawTrinoTy::Real`].
    F32,
    /// Maps to [`RawTrinoTy::Double`].
    F64,
}
219
220impl TrinoTy {
221 pub fn from_type_signature(mut sig: TypeSignature) -> Result<Self, Error> {
222 use TrinoFloat::*;
223 use TrinoInt::*;
224
225 let ty = match sig.raw_type {
226 RawTrinoTy::Date => TrinoTy::Date,
227 RawTrinoTy::Time => TrinoTy::Time,
228 RawTrinoTy::TimeWithTimeZone => TrinoTy::TimeWithTimeZone,
229 RawTrinoTy::Timestamp => TrinoTy::Timestamp,
230 RawTrinoTy::TimestampWithTimeZone => TrinoTy::TimestampWithTimeZone,
231 RawTrinoTy::IntervalYearToMonth => TrinoTy::IntervalYearToMonth,
232 RawTrinoTy::IntervalDayToSecond => TrinoTy::IntervalDayToSecond,
233 RawTrinoTy::Unknown => TrinoTy::Unknown,
234 RawTrinoTy::Decimal if sig.arguments.len() == 2 => {
235 let s_sig = sig.arguments.pop().unwrap();
236 let p_sig = sig.arguments.pop().unwrap();
237 if let (
238 ClientTypeSignatureParameter::LongLiteral(p),
239 ClientTypeSignatureParameter::LongLiteral(s),
240 ) = (p_sig, s_sig)
241 {
242 TrinoTy::Decimal(p as usize, s as usize)
243 } else {
244 return Err(Error::InvalidTypeSignature);
245 }
246 }
247 RawTrinoTy::Boolean => TrinoTy::Boolean,
248 RawTrinoTy::TinyInt => TrinoTy::TrinoInt(I8),
249 RawTrinoTy::SmallInt => TrinoTy::TrinoInt(I16),
250 RawTrinoTy::Integer => TrinoTy::TrinoInt(I32),
251 RawTrinoTy::BigInt => TrinoTy::TrinoInt(I64),
252 RawTrinoTy::Real => TrinoTy::TrinoFloat(F32),
253 RawTrinoTy::Double => TrinoTy::TrinoFloat(F64),
254 RawTrinoTy::VarChar => TrinoTy::Varchar,
255 RawTrinoTy::Char if sig.arguments.len() == 1 => {
256 if let ClientTypeSignatureParameter::LongLiteral(p) = sig.arguments.pop().unwrap() {
257 TrinoTy::Char(p as usize)
258 } else {
259 return Err(Error::InvalidTypeSignature);
260 }
261 }
262 RawTrinoTy::Array if sig.arguments.len() == 1 => {
263 let sig = sig.arguments.pop().unwrap();
264 if let ClientTypeSignatureParameter::TypeSignature(sig) = sig {
265 let inner = Self::from_type_signature(sig)?;
266 TrinoTy::Array(Box::new(inner))
267 } else {
268 return Err(Error::InvalidTypeSignature);
269 }
270 }
271 RawTrinoTy::Map if sig.arguments.len() == 2 => {
272 let v_sig = sig.arguments.pop().unwrap();
273 let k_sig = sig.arguments.pop().unwrap();
274 if let (
275 ClientTypeSignatureParameter::TypeSignature(k_sig),
276 ClientTypeSignatureParameter::TypeSignature(v_sig),
277 ) = (k_sig, v_sig)
278 {
279 let k_inner = Self::from_type_signature(k_sig)?;
280 let v_inner = Self::from_type_signature(v_sig)?;
281 TrinoTy::Map(Box::new(k_inner), Box::new(v_inner))
282 } else {
283 return Err(Error::InvalidTypeSignature);
284 }
285 }
286 RawTrinoTy::Row if !sig.arguments.is_empty() => {
287 let ir = sig.arguments.try_map(|arg| match arg {
288 ClientTypeSignatureParameter::NamedTypeSignature(sig) => {
289 let name = sig.field_name.map(|n| n.name);
290 let ty = Self::from_type_signature(sig.type_signature)?;
291 Ok((name, ty))
292 }
293 _ => Err(Error::InvalidTypeSignature),
294 })?;
295
296 let is_named = ir[0].0.is_some();
297
298 if is_named {
299 let row = ir.try_map(|(name, ty)| match name {
300 Some(n) => Ok((n, ty)),
301 None => Err(Error::InvalidTypeSignature),
302 })?;
303 TrinoTy::Row(row)
304 } else {
305 let tuple = ir.try_map(|(name, ty)| match name {
306 Some(_) => Err(Error::InvalidTypeSignature),
307 None => Ok(ty),
308 })?;
309 TrinoTy::Tuple(tuple)
310 }
311 }
312 RawTrinoTy::IpAddress => TrinoTy::IpAddress,
313 RawTrinoTy::Uuid => TrinoTy::Uuid,
314 RawTrinoTy::Json => TrinoTy::Json,
315 _ => return Err(Error::InvalidTypeSignature),
316 };
317
318 Ok(ty)
319 }
320
321 pub fn from_column(column: Column) -> Result<(String, Self), Error> {
322 let name = column.name;
323 if let Some(sig) = column.type_signature {
324 let ty = Self::from_type_signature(sig)?;
325 Ok((name, ty))
326 } else {
327 Err(Error::InvalidColumn)
328 }
329 }
330
331 pub fn from_columns(columns: Vec<Column>) -> Result<Self, Error> {
332 let row = columns.try_map(Self::from_column)?;
333 Ok(TrinoTy::Row(row))
334 }
335
336 pub fn into_type_signature(self) -> TypeSignature {
337 use TrinoTy::*;
338
339 let raw_ty = self.raw_type();
340
341 let params = match self {
342 Unknown => vec![],
343 Decimal(p, s) => vec![
344 ClientTypeSignatureParameter::LongLiteral(p as u64),
345 ClientTypeSignatureParameter::LongLiteral(s as u64),
346 ],
347 Date => vec![],
348 Time => vec![],
349 TimeWithTimeZone => vec![],
350 Timestamp => vec![],
351 TimestampWithTimeZone => vec![],
352 IntervalYearToMonth => vec![],
353 IntervalDayToSecond => vec![],
354 Option(t) => return t.into_type_signature(),
355 Boolean => vec![],
356 TrinoInt(_) => vec![],
357 TrinoFloat(_) => vec![],
358 Varchar => vec![ClientTypeSignatureParameter::LongLiteral(2147483647)],
359 Char(a) => vec![ClientTypeSignatureParameter::LongLiteral(a as u64)],
360 Tuple(ts) => ts.map(|ty| {
361 ClientTypeSignatureParameter::NamedTypeSignature(NamedTypeSignature {
362 field_name: None,
363 type_signature: ty.into_type_signature(),
364 })
365 }),
366 Row(ts) => ts.map(|(name, ty)| {
367 ClientTypeSignatureParameter::NamedTypeSignature(NamedTypeSignature {
368 field_name: Some(RowFieldName::new(name)),
369 type_signature: ty.into_type_signature(),
370 })
371 }),
372 Array(t) => vec![ClientTypeSignatureParameter::TypeSignature(
373 t.into_type_signature(),
374 )],
375 Map(t1, t2) => vec![
376 ClientTypeSignatureParameter::TypeSignature(t1.into_type_signature()),
377 ClientTypeSignatureParameter::TypeSignature(t2.into_type_signature()),
378 ],
379 IpAddress => vec![],
380 Uuid => vec![],
381 Json => vec![],
382 };
383
384 TypeSignature::new(raw_ty, params)
385 }
386
387 pub fn full_type(&self) -> Cow<'static, str> {
388 use TrinoTy::*;
389
390 match self {
391 Unknown => RawTrinoTy::Unknown.to_str().into(),
392 Decimal(p, s) => format!("{}({},{})", RawTrinoTy::Decimal.to_str(), p, s).into(),
393 Option(t) => t.full_type(),
394 Date => RawTrinoTy::Date.to_str().into(),
395 Time => RawTrinoTy::Time.to_str().into(),
396 TimeWithTimeZone => RawTrinoTy::TimeWithTimeZone.to_str().into(),
397 Timestamp => RawTrinoTy::Timestamp.to_str().into(),
398 TimestampWithTimeZone => RawTrinoTy::TimestampWithTimeZone.to_str().into(),
399 IntervalYearToMonth => RawTrinoTy::IntervalYearToMonth.to_str().into(),
400 IntervalDayToSecond => RawTrinoTy::IntervalDayToSecond.to_str().into(),
401 Boolean => RawTrinoTy::Boolean.to_str().into(),
402 TrinoInt(ty) => ty.raw_type().to_str().into(),
403 TrinoFloat(ty) => ty.raw_type().to_str().into(),
404 Varchar => RawTrinoTy::VarChar.to_str().into(),
405 Char(a) => format!("{}({})", RawTrinoTy::Char.to_str(), a).into(),
406 Tuple(ts) => format!(
407 "{}({})",
408 RawTrinoTy::Row.to_str(),
409 ts.lazy_map(|ty| ty.full_type()).join(",")
410 )
411 .into(),
412 Row(ts) => format!(
413 "{}({})",
414 RawTrinoTy::Row.to_str(),
415 ts.lazy_map(|(name, ty)| format!("{} {}", name, ty.full_type()))
416 .join(",")
417 )
418 .into(),
419 Array(t) => format!("{}({})", RawTrinoTy::Array.to_str(), t.full_type()).into(),
420 Map(t1, t2) => format!(
421 "{}({},{})",
422 RawTrinoTy::Map.to_str(),
423 t1.full_type(),
424 t2.full_type()
425 )
426 .into(),
427 IpAddress => RawTrinoTy::IpAddress.to_str().into(),
428 Uuid => RawTrinoTy::Uuid.to_str().into(),
429 Json => RawTrinoTy::Json.to_str().into(),
430 }
431 }
432
433 pub fn raw_type(&self) -> RawTrinoTy {
434 use TrinoTy::*;
435
436 match self {
437 Unknown => RawTrinoTy::Unknown,
438 Date => RawTrinoTy::Date,
439 Time => RawTrinoTy::Time,
440 TimeWithTimeZone => RawTrinoTy::TimeWithTimeZone,
441 Timestamp => RawTrinoTy::Timestamp,
442 TimestampWithTimeZone => RawTrinoTy::TimestampWithTimeZone,
443 IntervalYearToMonth => RawTrinoTy::IntervalYearToMonth,
444 IntervalDayToSecond => RawTrinoTy::IntervalDayToSecond,
445 Decimal(_, _) => RawTrinoTy::Decimal,
446 Option(ty) => ty.raw_type(),
447 Boolean => RawTrinoTy::Boolean,
448 TrinoInt(ty) => ty.raw_type(),
449 TrinoFloat(ty) => ty.raw_type(),
450 Varchar => RawTrinoTy::VarChar,
451 Char(_) => RawTrinoTy::Char,
452 Tuple(_) => RawTrinoTy::Row,
453 Row(_) => RawTrinoTy::Row,
454 Array(_) => RawTrinoTy::Array,
455 Map(_, _) => RawTrinoTy::Map,
456 IpAddress => RawTrinoTy::IpAddress,
457 Uuid => RawTrinoTy::Uuid,
458 Json => RawTrinoTy::Json,
459 }
460 }
461}
462
463impl TrinoInt {
464 pub fn raw_type(&self) -> RawTrinoTy {
465 use TrinoInt::*;
466 match self {
467 I8 => RawTrinoTy::TinyInt,
468 I16 => RawTrinoTy::SmallInt,
469 I32 => RawTrinoTy::Integer,
470 I64 => RawTrinoTy::BigInt,
471 }
472 }
473}
474
475impl TrinoFloat {
476 pub fn raw_type(&self) -> RawTrinoTy {
477 use TrinoFloat::*;
478 match self {
479 F32 => RawTrinoTy::Real,
480 F64 => RawTrinoTy::Double,
481 }
482 }
483}