1use std::collections::HashMap;
2use std::sync::{Arc, Mutex, OnceLock};
3
4use crate::ast::{DataType, Expr, QuoteStyle, Statement};
5
6pub trait DialectPlugin: Send + Sync {
44 fn name(&self) -> &str;
46
47 fn quote_style(&self) -> Option<QuoteStyle> {
51 None
52 }
53
54 fn supports_ilike(&self) -> Option<bool> {
58 None
59 }
60
61 fn map_function_name(&self, name: &str) -> Option<String> {
67 let _ = name;
68 None
69 }
70
71 fn map_data_type(&self, data_type: &DataType) -> Option<DataType> {
75 let _ = data_type;
76 None
77 }
78
79 fn transform_expr(&self, expr: &Expr) -> Option<Expr> {
84 let _ = expr;
85 None
86 }
87
88 fn transform_statement(&self, statement: &Statement) -> Option<Statement> {
93 let _ = statement;
94 None
95 }
96}
97
98pub struct DialectRegistry {
106 dialects: Mutex<HashMap<String, Arc<dyn DialectPlugin>>>,
107}
108
109impl DialectRegistry {
110 pub fn global() -> &'static DialectRegistry {
112 static INSTANCE: OnceLock<DialectRegistry> = OnceLock::new();
113 INSTANCE.get_or_init(|| DialectRegistry {
114 dialects: Mutex::new(HashMap::new()),
115 })
116 }
117
118 pub fn register<P: DialectPlugin + 'static>(&self, plugin: P) {
122 let name = plugin.name().to_lowercase();
123 let mut map = self
124 .dialects
125 .lock()
126 .expect("dialect registry lock poisoned");
127 map.insert(name, Arc::new(plugin));
128 }
129
130 #[must_use]
132 pub fn get(&self, name: &str) -> Option<Arc<dyn DialectPlugin>> {
133 let map = self
134 .dialects
135 .lock()
136 .expect("dialect registry lock poisoned");
137 map.get(&name.to_lowercase()).cloned()
138 }
139
140 pub fn unregister(&self, name: &str) -> bool {
144 let mut map = self
145 .dialects
146 .lock()
147 .expect("dialect registry lock poisoned");
148 map.remove(&name.to_lowercase()).is_some()
149 }
150
151 #[must_use]
153 pub fn registered_names(&self) -> Vec<String> {
154 let map = self
155 .dialects
156 .lock()
157 .expect("dialect registry lock poisoned");
158 map.keys().cloned().collect()
159 }
160}
161
162use crate::dialects::Dialect;
167
168#[derive(Debug, Clone, PartialEq, Eq, Hash)]
183pub enum DialectRef {
184 BuiltIn(Dialect),
186 Custom(String),
188}
189
190impl DialectRef {
191 #[must_use]
193 pub fn custom(name: &str) -> Self {
194 DialectRef::Custom(name.to_lowercase())
195 }
196
197 #[must_use]
199 pub fn as_builtin(&self) -> Option<Dialect> {
200 match self {
201 DialectRef::BuiltIn(d) => Some(*d),
202 DialectRef::Custom(_) => None,
203 }
204 }
205
206 #[must_use]
208 pub fn as_plugin(&self) -> Option<Arc<dyn DialectPlugin>> {
209 match self {
210 DialectRef::Custom(name) => DialectRegistry::global().get(name),
211 DialectRef::BuiltIn(_) => None,
212 }
213 }
214
215 #[must_use]
217 pub fn quote_style(&self) -> QuoteStyle {
218 match self {
219 DialectRef::BuiltIn(d) => QuoteStyle::for_dialect(*d),
220 DialectRef::Custom(name) => DialectRegistry::global()
221 .get(name)
222 .and_then(|p| p.quote_style())
223 .unwrap_or(QuoteStyle::DoubleQuote),
224 }
225 }
226
227 #[must_use]
229 pub fn supports_ilike(&self) -> bool {
230 match self {
231 DialectRef::BuiltIn(d) => super::supports_ilike_builtin(*d),
232 DialectRef::Custom(name) => DialectRegistry::global()
233 .get(name)
234 .and_then(|p| p.supports_ilike())
235 .unwrap_or(false),
236 }
237 }
238
239 #[must_use]
241 pub fn map_function_name(&self, name: &str) -> String {
242 match self {
243 DialectRef::BuiltIn(d) => super::map_function_name(name, *d),
244 DialectRef::Custom(cname) => DialectRegistry::global()
245 .get(cname)
246 .and_then(|p| p.map_function_name(name))
247 .unwrap_or_else(|| name.to_string()),
248 }
249 }
250
251 #[must_use]
253 pub fn map_data_type(&self, dt: &DataType) -> DataType {
254 match self {
255 DialectRef::BuiltIn(d) => super::map_data_type(dt.clone(), *d),
256 DialectRef::Custom(name) => DialectRegistry::global()
257 .get(name)
258 .and_then(|p| p.map_data_type(dt))
259 .unwrap_or_else(|| dt.clone()),
260 }
261 }
262}
263
264impl From<Dialect> for DialectRef {
265 fn from(d: Dialect) -> Self {
266 DialectRef::BuiltIn(d)
267 }
268}
269
270impl std::fmt::Display for DialectRef {
271 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272 match self {
273 DialectRef::BuiltIn(d) => write!(f, "{d}"),
274 DialectRef::Custom(name) => write!(f, "Custom({name})"),
275 }
276 }
277}
278
279use crate::ast::TypedFunction;
288
289fn typed_function_canonical_name(func: &TypedFunction) -> &'static str {
291 match func {
292 TypedFunction::DateAdd { .. } => "DATE_ADD",
293 TypedFunction::DateDiff { .. } => "DATE_DIFF",
294 TypedFunction::DateTrunc { .. } => "DATE_TRUNC",
295 TypedFunction::DateSub { .. } => "DATE_SUB",
296 TypedFunction::CurrentDate => "CURRENT_DATE",
297 TypedFunction::CurrentTime => "CURRENT_TIME",
298 TypedFunction::CurrentTimestamp => "NOW",
299 TypedFunction::StrToTime { .. } => "STR_TO_TIME",
300 TypedFunction::TimeToStr { .. } => "TIME_TO_STR",
301 TypedFunction::TsOrDsToDate { .. } => "TS_OR_DS_TO_DATE",
302 TypedFunction::Year { .. } => "YEAR",
303 TypedFunction::Month { .. } => "MONTH",
304 TypedFunction::Day { .. } => "DAY",
305 TypedFunction::Trim { .. } => "TRIM",
306 TypedFunction::Substring { .. } => "SUBSTRING",
307 TypedFunction::Upper { .. } => "UPPER",
308 TypedFunction::Lower { .. } => "LOWER",
309 TypedFunction::RegexpLike { .. } => "REGEXP_LIKE",
310 TypedFunction::RegexpExtract { .. } => "REGEXP_EXTRACT",
311 TypedFunction::RegexpReplace { .. } => "REGEXP_REPLACE",
312 TypedFunction::ConcatWs { .. } => "CONCAT_WS",
313 TypedFunction::Split { .. } => "SPLIT",
314 TypedFunction::Initcap { .. } => "INITCAP",
315 TypedFunction::Length { .. } => "LENGTH",
316 TypedFunction::Replace { .. } => "REPLACE",
317 TypedFunction::Reverse { .. } => "REVERSE",
318 TypedFunction::Left { .. } => "LEFT",
319 TypedFunction::Right { .. } => "RIGHT",
320 TypedFunction::Lpad { .. } => "LPAD",
321 TypedFunction::Rpad { .. } => "RPAD",
322 TypedFunction::Count { .. } => "COUNT",
323 TypedFunction::Sum { .. } => "SUM",
324 TypedFunction::Avg { .. } => "AVG",
325 TypedFunction::Min { .. } => "MIN",
326 TypedFunction::Max { .. } => "MAX",
327 TypedFunction::ArrayAgg { .. } => "ARRAY_AGG",
328 TypedFunction::ApproxDistinct { .. } => "APPROX_DISTINCT",
329 TypedFunction::Variance { .. } => "VARIANCE",
330 TypedFunction::Stddev { .. } => "STDDEV",
331 TypedFunction::GroupConcat { .. } => "GROUP_CONCAT",
332 TypedFunction::ArrayConcat { .. } => "ARRAY_CONCAT",
333 TypedFunction::ArrayContains { .. } => "ARRAY_CONTAINS",
334 TypedFunction::ArraySize { .. } => "ARRAY_SIZE",
335 TypedFunction::Explode { .. } => "EXPLODE",
336 TypedFunction::GenerateSeries { .. } => "GENERATE_SERIES",
337 TypedFunction::Flatten { .. } => "FLATTEN",
338 TypedFunction::JSONExtract { .. } => "JSON_EXTRACT",
339 TypedFunction::JSONExtractScalar { .. } => "JSON_EXTRACT_SCALAR",
340 TypedFunction::ParseJSON { .. } => "PARSE_JSON",
341 TypedFunction::JSONFormat { .. } => "JSON_FORMAT",
342 TypedFunction::RowNumber => "ROW_NUMBER",
343 TypedFunction::Rank => "RANK",
344 TypedFunction::DenseRank => "DENSE_RANK",
345 TypedFunction::NTile { .. } => "NTILE",
346 TypedFunction::Lead { .. } => "LEAD",
347 TypedFunction::Lag { .. } => "LAG",
348 TypedFunction::FirstValue { .. } => "FIRST_VALUE",
349 TypedFunction::LastValue { .. } => "LAST_VALUE",
350 TypedFunction::Abs { .. } => "ABS",
351 TypedFunction::Ceil { .. } => "CEIL",
352 TypedFunction::Floor { .. } => "FLOOR",
353 TypedFunction::Round { .. } => "ROUND",
354 TypedFunction::Log { .. } => "LOG",
355 TypedFunction::Ln { .. } => "LN",
356 TypedFunction::Pow { .. } => "POW",
357 TypedFunction::Sqrt { .. } => "SQRT",
358 TypedFunction::Greatest { .. } => "GREATEST",
359 TypedFunction::Least { .. } => "LEAST",
360 TypedFunction::Mod { .. } => "MOD",
361 TypedFunction::Hex { .. } => "HEX",
362 TypedFunction::Unhex { .. } => "UNHEX",
363 TypedFunction::Md5 { .. } => "MD5",
364 TypedFunction::Sha { .. } => "SHA",
365 TypedFunction::Sha2 { .. } => "SHA2",
366 }
367}
368
369fn typed_function_args(func: &TypedFunction) -> Vec<Expr> {
371 match func {
372 TypedFunction::CurrentDate
373 | TypedFunction::CurrentTime
374 | TypedFunction::CurrentTimestamp => vec![],
375 TypedFunction::RowNumber | TypedFunction::Rank | TypedFunction::DenseRank => vec![],
376 TypedFunction::Length { expr }
377 | TypedFunction::Upper { expr }
378 | TypedFunction::Lower { expr }
379 | TypedFunction::Initcap { expr }
380 | TypedFunction::Reverse { expr }
381 | TypedFunction::Abs { expr }
382 | TypedFunction::Ceil { expr }
383 | TypedFunction::Floor { expr }
384 | TypedFunction::Ln { expr }
385 | TypedFunction::Sqrt { expr }
386 | TypedFunction::Explode { expr }
387 | TypedFunction::Flatten { expr }
388 | TypedFunction::ArraySize { expr }
389 | TypedFunction::ParseJSON { expr }
390 | TypedFunction::JSONFormat { expr }
391 | TypedFunction::Hex { expr }
392 | TypedFunction::Unhex { expr }
393 | TypedFunction::Md5 { expr }
394 | TypedFunction::Sha { expr }
395 | TypedFunction::TsOrDsToDate { expr }
396 | TypedFunction::Year { expr }
397 | TypedFunction::Month { expr }
398 | TypedFunction::Day { expr }
399 | TypedFunction::ApproxDistinct { expr }
400 | TypedFunction::Variance { expr }
401 | TypedFunction::Stddev { expr }
402 | TypedFunction::FirstValue { expr }
403 | TypedFunction::LastValue { expr } => vec![*expr.clone()],
404 TypedFunction::DateTrunc { unit, expr } => {
405 vec![Expr::StringLiteral(format!("{unit:?}")), *expr.clone()]
406 }
407 TypedFunction::DateAdd { expr, interval, .. }
408 | TypedFunction::DateSub { expr, interval, .. } => {
409 vec![*expr.clone(), *interval.clone()]
410 }
411 TypedFunction::DateDiff { start, end, .. } => vec![*start.clone(), *end.clone()],
412 TypedFunction::StrToTime { expr, format } | TypedFunction::TimeToStr { expr, format } => {
413 vec![*expr.clone(), *format.clone()]
414 }
415 TypedFunction::Trim { expr, .. } => vec![*expr.clone()],
416 TypedFunction::Substring {
417 expr,
418 start,
419 length,
420 } => {
421 let mut args = vec![*expr.clone(), *start.clone()];
422 if let Some(len) = length {
423 args.push(*len.clone());
424 }
425 args
426 }
427 TypedFunction::RegexpLike {
428 expr,
429 pattern,
430 flags,
431 } => {
432 let mut args = vec![*expr.clone(), *pattern.clone()];
433 if let Some(f) = flags {
434 args.push(*f.clone());
435 }
436 args
437 }
438 TypedFunction::RegexpExtract {
439 expr,
440 pattern,
441 group_index,
442 } => {
443 let mut args = vec![*expr.clone(), *pattern.clone()];
444 if let Some(g) = group_index {
445 args.push(*g.clone());
446 }
447 args
448 }
449 TypedFunction::RegexpReplace {
450 expr,
451 pattern,
452 replacement,
453 flags,
454 } => {
455 let mut args = vec![*expr.clone(), *pattern.clone(), *replacement.clone()];
456 if let Some(f) = flags {
457 args.push(*f.clone());
458 }
459 args
460 }
461 TypedFunction::ConcatWs { separator, exprs } => {
462 let mut args = vec![*separator.clone()];
463 args.extend(exprs.iter().cloned());
464 args
465 }
466 TypedFunction::Split { expr, delimiter } => vec![*expr.clone(), *delimiter.clone()],
467 TypedFunction::Replace { expr, from, to } => {
468 vec![*expr.clone(), *from.clone(), *to.clone()]
469 }
470 TypedFunction::Left { expr, n } | TypedFunction::Right { expr, n } => {
471 vec![*expr.clone(), *n.clone()]
472 }
473 TypedFunction::Lpad { expr, length, pad } | TypedFunction::Rpad { expr, length, pad } => {
474 let mut args = vec![*expr.clone(), *length.clone()];
475 if let Some(p) = pad {
476 args.push(*p.clone());
477 }
478 args
479 }
480 TypedFunction::Count { expr, .. }
481 | TypedFunction::Sum { expr, .. }
482 | TypedFunction::Avg { expr, .. }
483 | TypedFunction::Min { expr }
484 | TypedFunction::Max { expr }
485 | TypedFunction::ArrayAgg { expr, .. } => vec![*expr.clone()],
486 TypedFunction::ArrayConcat { arrays } => arrays.clone(),
487 TypedFunction::ArrayContains { array, element } => {
488 vec![*array.clone(), *element.clone()]
489 }
490 TypedFunction::GenerateSeries { start, stop, step } => {
491 let mut args = vec![*start.clone(), *stop.clone()];
492 if let Some(s) = step {
493 args.push(*s.clone());
494 }
495 args
496 }
497 TypedFunction::JSONExtract { expr, path }
498 | TypedFunction::JSONExtractScalar { expr, path } => {
499 vec![*expr.clone(), *path.clone()]
500 }
501 TypedFunction::NTile { n } => vec![*n.clone()],
502 TypedFunction::Lead {
503 expr,
504 offset,
505 default,
506 }
507 | TypedFunction::Lag {
508 expr,
509 offset,
510 default,
511 } => {
512 let mut args = vec![*expr.clone()];
513 if let Some(o) = offset {
514 args.push(*o.clone());
515 }
516 if let Some(d) = default {
517 args.push(*d.clone());
518 }
519 args
520 }
521 TypedFunction::Round { expr, decimals } => {
522 let mut args = vec![*expr.clone()];
523 if let Some(d) = decimals {
524 args.push(*d.clone());
525 }
526 args
527 }
528 TypedFunction::Log { expr, base } => {
529 let mut args = vec![*expr.clone()];
530 if let Some(b) = base {
531 args.push(*b.clone());
532 }
533 args
534 }
535 TypedFunction::Pow { base, exponent } => vec![*base.clone(), *exponent.clone()],
536 TypedFunction::Greatest { exprs } | TypedFunction::Least { exprs } => exprs.clone(),
537 TypedFunction::Mod { left, right } => vec![*left.clone(), *right.clone()],
538 TypedFunction::Sha2 { expr, bit_length } => vec![*expr.clone(), *bit_length.clone()],
539 TypedFunction::GroupConcat {
540 exprs, separator, ..
541 } => {
542 let mut args = exprs.clone();
543 if let Some(s) = separator {
544 args.push(*s.clone());
545 }
546 args
547 }
548 }
549}
550
551#[must_use]
558pub fn transform(statement: &Statement, from: &DialectRef, to: &DialectRef) -> Statement {
559 if let (DialectRef::BuiltIn(f), DialectRef::BuiltIn(t)) = (from, to) {
561 return super::transform(statement, *f, *t);
562 }
563
564 if let Some(plugin) = to.as_plugin()
566 && let Some(transformed) = plugin.transform_statement(statement)
567 {
568 return transformed;
569 }
570
571 let mut stmt = statement.clone();
573 transform_statement_plugin(&mut stmt, to);
574 stmt
575}
576
577fn transform_statement_plugin(statement: &mut Statement, target: &DialectRef) {
579 match statement {
580 Statement::Select(sel) => {
581 for item in &mut sel.columns {
582 if let crate::ast::SelectItem::Expr { expr, .. } = item {
583 *expr = transform_expr_plugin(expr.clone(), target);
584 }
585 }
586 if let Some(wh) = &mut sel.where_clause {
587 *wh = transform_expr_plugin(wh.clone(), target);
588 }
589 for gb in &mut sel.group_by {
590 *gb = transform_expr_plugin(gb.clone(), target);
591 }
592 if let Some(having) = &mut sel.having {
593 *having = transform_expr_plugin(having.clone(), target);
594 }
595 }
596 Statement::CreateTable(ct) => {
597 for col in &mut ct.columns {
598 col.data_type = target.map_data_type(&col.data_type);
599 if let Some(default) = &mut col.default {
600 *default = transform_expr_plugin(default.clone(), target);
601 }
602 }
603 }
604 _ => {}
605 }
606}
607
608fn transform_expr_plugin(expr: Expr, target: &DialectRef) -> Expr {
610 if let Some(plugin) = target.as_plugin()
612 && let Some(transformed) = plugin.transform_expr(&expr)
613 {
614 return transformed;
615 }
616
617 match expr {
618 Expr::TypedFunction { func, filter, over } => {
621 if let DialectRef::Custom(_) = target {
622 let canonical = typed_function_canonical_name(&func);
623 let new_name = target.map_function_name(canonical);
624 if new_name != canonical {
625 let args = typed_function_args(&func)
627 .into_iter()
628 .map(|a| transform_expr_plugin(a, target))
629 .collect();
630 return Expr::Function {
631 name: new_name,
632 args,
633 distinct: false,
634 filter: filter.map(|f| Box::new(transform_expr_plugin(*f, target))),
635 over,
636 order_by: vec![],
637 within_group: false,
638 };
639 }
640 }
641 Expr::TypedFunction {
643 func: func.transform_children(&|e| transform_expr_plugin(e, target)),
644 filter: filter.map(|f| Box::new(transform_expr_plugin(*f, target))),
645 over,
646 }
647 }
648 Expr::Function {
649 name,
650 args,
651 distinct,
652 filter,
653 over,
654 order_by,
655 within_group,
656 } => {
657 let new_name = target.map_function_name(&name);
658 let new_args: Vec<Expr> = args
659 .into_iter()
660 .map(|a| transform_expr_plugin(a, target))
661 .collect();
662 Expr::Function {
663 name: new_name,
664 args: new_args,
665 distinct,
666 filter: filter.map(|f| Box::new(transform_expr_plugin(*f, target))),
667 over,
668 order_by,
669 within_group,
670 }
671 }
672 Expr::Cast { expr, data_type } => Expr::Cast {
673 expr: Box::new(transform_expr_plugin(*expr, target)),
674 data_type: target.map_data_type(&data_type),
675 },
676 Expr::ILike {
677 expr,
678 pattern,
679 negated,
680 escape,
681 } if !target.supports_ilike() => Expr::Like {
682 expr: Box::new(Expr::TypedFunction {
683 func: crate::ast::TypedFunction::Lower {
684 expr: Box::new(transform_expr_plugin(*expr, target)),
685 },
686 filter: None,
687 over: None,
688 }),
689 pattern: Box::new(Expr::TypedFunction {
690 func: crate::ast::TypedFunction::Lower {
691 expr: Box::new(transform_expr_plugin(*pattern, target)),
692 },
693 filter: None,
694 over: None,
695 }),
696 negated,
697 escape,
698 },
699 Expr::BinaryOp { left, op, right } => Expr::BinaryOp {
700 left: Box::new(transform_expr_plugin(*left, target)),
701 op,
702 right: Box::new(transform_expr_plugin(*right, target)),
703 },
704 Expr::UnaryOp { op, expr } => Expr::UnaryOp {
705 op,
706 expr: Box::new(transform_expr_plugin(*expr, target)),
707 },
708 Expr::Nested(inner) => Expr::Nested(Box::new(transform_expr_plugin(*inner, target))),
709 Expr::Column {
710 table,
711 name,
712 quote_style,
713 table_quote_style,
714 } => {
715 let new_qs = if quote_style.is_quoted() {
716 target.quote_style()
717 } else {
718 QuoteStyle::None
719 };
720 let new_tqs = if table_quote_style.is_quoted() {
721 target.quote_style()
722 } else {
723 QuoteStyle::None
724 };
725 Expr::Column {
726 table,
727 name,
728 quote_style: new_qs,
729 table_quote_style: new_tqs,
730 }
731 }
732 other => other,
733 }
734}
735
736use crate::errors;
741
742pub fn transpile_ext(
761 sql: &str,
762 read_dialect: &DialectRef,
763 write_dialect: &DialectRef,
764) -> errors::Result<String> {
765 let parse_dialect = read_dialect.as_builtin().unwrap_or(Dialect::Ansi);
767 let ast = crate::parser::parse(sql, parse_dialect)?;
768 let transformed = transform(&ast, read_dialect, write_dialect);
769 let gen_dialect = write_dialect.as_builtin().unwrap_or(Dialect::Ansi);
770 Ok(crate::generator::generate(&transformed, gen_dialect))
771}
772
773pub fn transpile_statements_ext(
779 sql: &str,
780 read_dialect: &DialectRef,
781 write_dialect: &DialectRef,
782) -> errors::Result<Vec<String>> {
783 let parse_dialect = read_dialect.as_builtin().unwrap_or(Dialect::Ansi);
784 let stmts = crate::parser::parse_statements(sql, parse_dialect)?;
785 let gen_dialect = write_dialect.as_builtin().unwrap_or(Dialect::Ansi);
786 let mut results = Vec::with_capacity(stmts.len());
787 for stmt in &stmts {
788 let transformed = transform(stmt, read_dialect, write_dialect);
789 results.push(crate::generator::generate(&transformed, gen_dialect));
790 }
791 Ok(results)
792}
793
794pub fn register_dialect<P: DialectPlugin + 'static>(plugin: P) {
802 DialectRegistry::global().register(plugin);
803}
804
805#[must_use]
809pub fn resolve_dialect(name: &str) -> Option<DialectRef> {
810 if let Some(d) = Dialect::from_str(name) {
812 return Some(DialectRef::BuiltIn(d));
813 }
814 if DialectRegistry::global().get(name).is_some() {
816 return Some(DialectRef::Custom(name.to_lowercase()));
817 }
818 None
819}