1use std::collections::HashMap;
2use std::sync::{Arc, Mutex, OnceLock};
3
4use crate::ast::{DataType, Expr, QuoteStyle, Statement};
5
6pub trait DialectPlugin: Send + Sync {
44 fn name(&self) -> &str;
46
47 fn quote_style(&self) -> Option<QuoteStyle> {
51 None
52 }
53
54 fn supports_ilike(&self) -> Option<bool> {
58 None
59 }
60
61 fn map_function_name(&self, name: &str) -> Option<String> {
67 let _ = name;
68 None
69 }
70
71 fn map_data_type(&self, data_type: &DataType) -> Option<DataType> {
75 let _ = data_type;
76 None
77 }
78
79 fn transform_expr(&self, expr: &Expr) -> Option<Expr> {
84 let _ = expr;
85 None
86 }
87
88 fn transform_statement(&self, statement: &Statement) -> Option<Statement> {
93 let _ = statement;
94 None
95 }
96}
97
98pub struct DialectRegistry {
106 dialects: Mutex<HashMap<String, Arc<dyn DialectPlugin>>>,
107}
108
109impl DialectRegistry {
110 pub fn global() -> &'static DialectRegistry {
112 static INSTANCE: OnceLock<DialectRegistry> = OnceLock::new();
113 INSTANCE.get_or_init(|| DialectRegistry {
114 dialects: Mutex::new(HashMap::new()),
115 })
116 }
117
118 pub fn register<P: DialectPlugin + 'static>(&self, plugin: P) {
122 let name = plugin.name().to_lowercase();
123 let mut map = self
124 .dialects
125 .lock()
126 .expect("dialect registry lock poisoned");
127 map.insert(name, Arc::new(plugin));
128 }
129
130 #[must_use]
132 pub fn get(&self, name: &str) -> Option<Arc<dyn DialectPlugin>> {
133 let map = self
134 .dialects
135 .lock()
136 .expect("dialect registry lock poisoned");
137 map.get(&name.to_lowercase()).cloned()
138 }
139
140 pub fn unregister(&self, name: &str) -> bool {
144 let mut map = self
145 .dialects
146 .lock()
147 .expect("dialect registry lock poisoned");
148 map.remove(&name.to_lowercase()).is_some()
149 }
150
151 #[must_use]
153 pub fn registered_names(&self) -> Vec<String> {
154 let map = self
155 .dialects
156 .lock()
157 .expect("dialect registry lock poisoned");
158 map.keys().cloned().collect()
159 }
160}
161
162use crate::dialects::Dialect;
167
/// Identifies a SQL dialect: either one of the crate's built-in [`Dialect`]s or
/// a custom dialect resolved through the global [`DialectRegistry`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DialectRef {
    /// A built-in dialect.
    BuiltIn(Dialect),
    /// A custom dialect, referenced by name.
    ///
    /// Registry lookups lowercase the name, and [`DialectRef::custom`] stores it
    /// lowercased. Constructing this variant directly keeps the name exactly as
    /// given, and the derived equality/hashing compare it case-sensitively.
    Custom(String),
}
189
190impl DialectRef {
191 #[must_use]
193 pub fn custom(name: &str) -> Self {
194 DialectRef::Custom(name.to_lowercase())
195 }
196
197 #[must_use]
199 pub fn as_builtin(&self) -> Option<Dialect> {
200 match self {
201 DialectRef::BuiltIn(d) => Some(*d),
202 DialectRef::Custom(_) => None,
203 }
204 }
205
206 #[must_use]
208 pub fn as_plugin(&self) -> Option<Arc<dyn DialectPlugin>> {
209 match self {
210 DialectRef::Custom(name) => DialectRegistry::global().get(name),
211 DialectRef::BuiltIn(_) => None,
212 }
213 }
214
215 #[must_use]
217 pub fn quote_style(&self) -> QuoteStyle {
218 match self {
219 DialectRef::BuiltIn(d) => QuoteStyle::for_dialect(*d),
220 DialectRef::Custom(name) => DialectRegistry::global()
221 .get(name)
222 .and_then(|p| p.quote_style())
223 .unwrap_or(QuoteStyle::DoubleQuote),
224 }
225 }
226
227 #[must_use]
229 pub fn supports_ilike(&self) -> bool {
230 match self {
231 DialectRef::BuiltIn(d) => super::supports_ilike_builtin(*d),
232 DialectRef::Custom(name) => DialectRegistry::global()
233 .get(name)
234 .and_then(|p| p.supports_ilike())
235 .unwrap_or(false),
236 }
237 }
238
239 #[must_use]
241 pub fn map_function_name(&self, name: &str) -> String {
242 match self {
243 DialectRef::BuiltIn(d) => super::map_function_name(name, *d),
244 DialectRef::Custom(cname) => DialectRegistry::global()
245 .get(cname)
246 .and_then(|p| p.map_function_name(name))
247 .unwrap_or_else(|| name.to_string()),
248 }
249 }
250
251 #[must_use]
253 pub fn map_data_type(&self, dt: &DataType) -> DataType {
254 match self {
255 DialectRef::BuiltIn(d) => super::map_data_type(dt.clone(), *d),
256 DialectRef::Custom(name) => DialectRegistry::global()
257 .get(name)
258 .and_then(|p| p.map_data_type(dt))
259 .unwrap_or_else(|| dt.clone()),
260 }
261 }
262}
263
264impl From<Dialect> for DialectRef {
265 fn from(d: Dialect) -> Self {
266 DialectRef::BuiltIn(d)
267 }
268}
269
270impl std::fmt::Display for DialectRef {
271 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272 match self {
273 DialectRef::BuiltIn(d) => write!(f, "{d}"),
274 DialectRef::Custom(name) => write!(f, "Custom({name})"),
275 }
276 }
277}
278
279use crate::ast::TypedFunction;
288
289fn typed_function_canonical_name(func: &TypedFunction) -> &'static str {
291 match func {
292 TypedFunction::DateAdd { .. } => "DATE_ADD",
293 TypedFunction::DateDiff { .. } => "DATE_DIFF",
294 TypedFunction::DateTrunc { .. } => "DATE_TRUNC",
295 TypedFunction::DateSub { .. } => "DATE_SUB",
296 TypedFunction::CurrentDate => "CURRENT_DATE",
297 TypedFunction::CurrentTimestamp => "NOW",
298 TypedFunction::StrToTime { .. } => "STR_TO_TIME",
299 TypedFunction::TimeToStr { .. } => "TIME_TO_STR",
300 TypedFunction::TsOrDsToDate { .. } => "TS_OR_DS_TO_DATE",
301 TypedFunction::Year { .. } => "YEAR",
302 TypedFunction::Month { .. } => "MONTH",
303 TypedFunction::Day { .. } => "DAY",
304 TypedFunction::Trim { .. } => "TRIM",
305 TypedFunction::Substring { .. } => "SUBSTRING",
306 TypedFunction::Upper { .. } => "UPPER",
307 TypedFunction::Lower { .. } => "LOWER",
308 TypedFunction::RegexpLike { .. } => "REGEXP_LIKE",
309 TypedFunction::RegexpExtract { .. } => "REGEXP_EXTRACT",
310 TypedFunction::RegexpReplace { .. } => "REGEXP_REPLACE",
311 TypedFunction::ConcatWs { .. } => "CONCAT_WS",
312 TypedFunction::Split { .. } => "SPLIT",
313 TypedFunction::Initcap { .. } => "INITCAP",
314 TypedFunction::Length { .. } => "LENGTH",
315 TypedFunction::Replace { .. } => "REPLACE",
316 TypedFunction::Reverse { .. } => "REVERSE",
317 TypedFunction::Left { .. } => "LEFT",
318 TypedFunction::Right { .. } => "RIGHT",
319 TypedFunction::Lpad { .. } => "LPAD",
320 TypedFunction::Rpad { .. } => "RPAD",
321 TypedFunction::Count { .. } => "COUNT",
322 TypedFunction::Sum { .. } => "SUM",
323 TypedFunction::Avg { .. } => "AVG",
324 TypedFunction::Min { .. } => "MIN",
325 TypedFunction::Max { .. } => "MAX",
326 TypedFunction::ArrayAgg { .. } => "ARRAY_AGG",
327 TypedFunction::ApproxDistinct { .. } => "APPROX_DISTINCT",
328 TypedFunction::Variance { .. } => "VARIANCE",
329 TypedFunction::Stddev { .. } => "STDDEV",
330 TypedFunction::GroupConcat { .. } => "GROUP_CONCAT",
331 TypedFunction::ArrayConcat { .. } => "ARRAY_CONCAT",
332 TypedFunction::ArrayContains { .. } => "ARRAY_CONTAINS",
333 TypedFunction::ArraySize { .. } => "ARRAY_SIZE",
334 TypedFunction::Explode { .. } => "EXPLODE",
335 TypedFunction::GenerateSeries { .. } => "GENERATE_SERIES",
336 TypedFunction::Flatten { .. } => "FLATTEN",
337 TypedFunction::JSONExtract { .. } => "JSON_EXTRACT",
338 TypedFunction::JSONExtractScalar { .. } => "JSON_EXTRACT_SCALAR",
339 TypedFunction::ParseJSON { .. } => "PARSE_JSON",
340 TypedFunction::JSONFormat { .. } => "JSON_FORMAT",
341 TypedFunction::RowNumber => "ROW_NUMBER",
342 TypedFunction::Rank => "RANK",
343 TypedFunction::DenseRank => "DENSE_RANK",
344 TypedFunction::NTile { .. } => "NTILE",
345 TypedFunction::Lead { .. } => "LEAD",
346 TypedFunction::Lag { .. } => "LAG",
347 TypedFunction::FirstValue { .. } => "FIRST_VALUE",
348 TypedFunction::LastValue { .. } => "LAST_VALUE",
349 TypedFunction::Abs { .. } => "ABS",
350 TypedFunction::Ceil { .. } => "CEIL",
351 TypedFunction::Floor { .. } => "FLOOR",
352 TypedFunction::Round { .. } => "ROUND",
353 TypedFunction::Log { .. } => "LOG",
354 TypedFunction::Ln { .. } => "LN",
355 TypedFunction::Pow { .. } => "POW",
356 TypedFunction::Sqrt { .. } => "SQRT",
357 TypedFunction::Greatest { .. } => "GREATEST",
358 TypedFunction::Least { .. } => "LEAST",
359 TypedFunction::Mod { .. } => "MOD",
360 TypedFunction::Hex { .. } => "HEX",
361 TypedFunction::Unhex { .. } => "UNHEX",
362 TypedFunction::Md5 { .. } => "MD5",
363 TypedFunction::Sha { .. } => "SHA",
364 TypedFunction::Sha2 { .. } => "SHA2",
365 }
366}
367
368fn typed_function_args(func: &TypedFunction) -> Vec<Expr> {
370 match func {
371 TypedFunction::CurrentDate | TypedFunction::CurrentTimestamp => vec![],
372 TypedFunction::RowNumber | TypedFunction::Rank | TypedFunction::DenseRank => vec![],
373 TypedFunction::Length { expr }
374 | TypedFunction::Upper { expr }
375 | TypedFunction::Lower { expr }
376 | TypedFunction::Initcap { expr }
377 | TypedFunction::Reverse { expr }
378 | TypedFunction::Abs { expr }
379 | TypedFunction::Ceil { expr }
380 | TypedFunction::Floor { expr }
381 | TypedFunction::Ln { expr }
382 | TypedFunction::Sqrt { expr }
383 | TypedFunction::Explode { expr }
384 | TypedFunction::Flatten { expr }
385 | TypedFunction::ArraySize { expr }
386 | TypedFunction::ParseJSON { expr }
387 | TypedFunction::JSONFormat { expr }
388 | TypedFunction::Hex { expr }
389 | TypedFunction::Unhex { expr }
390 | TypedFunction::Md5 { expr }
391 | TypedFunction::Sha { expr }
392 | TypedFunction::TsOrDsToDate { expr }
393 | TypedFunction::Year { expr }
394 | TypedFunction::Month { expr }
395 | TypedFunction::Day { expr }
396 | TypedFunction::ApproxDistinct { expr }
397 | TypedFunction::Variance { expr }
398 | TypedFunction::Stddev { expr }
399 | TypedFunction::FirstValue { expr }
400 | TypedFunction::LastValue { expr } => vec![*expr.clone()],
401 TypedFunction::DateTrunc { unit, expr } => {
402 vec![Expr::StringLiteral(format!("{unit:?}")), *expr.clone()]
403 }
404 TypedFunction::DateAdd { expr, interval, .. }
405 | TypedFunction::DateSub { expr, interval, .. } => {
406 vec![*expr.clone(), *interval.clone()]
407 }
408 TypedFunction::DateDiff { start, end, .. } => vec![*start.clone(), *end.clone()],
409 TypedFunction::StrToTime { expr, format } | TypedFunction::TimeToStr { expr, format } => {
410 vec![*expr.clone(), *format.clone()]
411 }
412 TypedFunction::Trim { expr, .. } => vec![*expr.clone()],
413 TypedFunction::Substring {
414 expr,
415 start,
416 length,
417 } => {
418 let mut args = vec![*expr.clone(), *start.clone()];
419 if let Some(len) = length {
420 args.push(*len.clone());
421 }
422 args
423 }
424 TypedFunction::RegexpLike {
425 expr,
426 pattern,
427 flags,
428 } => {
429 let mut args = vec![*expr.clone(), *pattern.clone()];
430 if let Some(f) = flags {
431 args.push(*f.clone());
432 }
433 args
434 }
435 TypedFunction::RegexpExtract {
436 expr,
437 pattern,
438 group_index,
439 } => {
440 let mut args = vec![*expr.clone(), *pattern.clone()];
441 if let Some(g) = group_index {
442 args.push(*g.clone());
443 }
444 args
445 }
446 TypedFunction::RegexpReplace {
447 expr,
448 pattern,
449 replacement,
450 flags,
451 } => {
452 let mut args = vec![*expr.clone(), *pattern.clone(), *replacement.clone()];
453 if let Some(f) = flags {
454 args.push(*f.clone());
455 }
456 args
457 }
458 TypedFunction::ConcatWs { separator, exprs } => {
459 let mut args = vec![*separator.clone()];
460 args.extend(exprs.iter().cloned());
461 args
462 }
463 TypedFunction::Split { expr, delimiter } => vec![*expr.clone(), *delimiter.clone()],
464 TypedFunction::Replace { expr, from, to } => {
465 vec![*expr.clone(), *from.clone(), *to.clone()]
466 }
467 TypedFunction::Left { expr, n } | TypedFunction::Right { expr, n } => {
468 vec![*expr.clone(), *n.clone()]
469 }
470 TypedFunction::Lpad { expr, length, pad } | TypedFunction::Rpad { expr, length, pad } => {
471 let mut args = vec![*expr.clone(), *length.clone()];
472 if let Some(p) = pad {
473 args.push(*p.clone());
474 }
475 args
476 }
477 TypedFunction::Count { expr, .. }
478 | TypedFunction::Sum { expr, .. }
479 | TypedFunction::Avg { expr, .. }
480 | TypedFunction::Min { expr }
481 | TypedFunction::Max { expr }
482 | TypedFunction::ArrayAgg { expr, .. } => vec![*expr.clone()],
483 TypedFunction::ArrayConcat { arrays } => arrays.clone(),
484 TypedFunction::ArrayContains { array, element } => {
485 vec![*array.clone(), *element.clone()]
486 }
487 TypedFunction::GenerateSeries { start, stop, step } => {
488 let mut args = vec![*start.clone(), *stop.clone()];
489 if let Some(s) = step {
490 args.push(*s.clone());
491 }
492 args
493 }
494 TypedFunction::JSONExtract { expr, path }
495 | TypedFunction::JSONExtractScalar { expr, path } => {
496 vec![*expr.clone(), *path.clone()]
497 }
498 TypedFunction::NTile { n } => vec![*n.clone()],
499 TypedFunction::Lead {
500 expr,
501 offset,
502 default,
503 }
504 | TypedFunction::Lag {
505 expr,
506 offset,
507 default,
508 } => {
509 let mut args = vec![*expr.clone()];
510 if let Some(o) = offset {
511 args.push(*o.clone());
512 }
513 if let Some(d) = default {
514 args.push(*d.clone());
515 }
516 args
517 }
518 TypedFunction::Round { expr, decimals } => {
519 let mut args = vec![*expr.clone()];
520 if let Some(d) = decimals {
521 args.push(*d.clone());
522 }
523 args
524 }
525 TypedFunction::Log { expr, base } => {
526 let mut args = vec![*expr.clone()];
527 if let Some(b) = base {
528 args.push(*b.clone());
529 }
530 args
531 }
532 TypedFunction::Pow { base, exponent } => vec![*base.clone(), *exponent.clone()],
533 TypedFunction::Greatest { exprs } | TypedFunction::Least { exprs } => exprs.clone(),
534 TypedFunction::Mod { left, right } => vec![*left.clone(), *right.clone()],
535 TypedFunction::Sha2 { expr, bit_length } => vec![*expr.clone(), *bit_length.clone()],
536 TypedFunction::GroupConcat {
537 exprs, separator, ..
538 } => {
539 let mut args = exprs.clone();
540 if let Some(s) = separator {
541 args.push(*s.clone());
542 }
543 args
544 }
545 }
546}
547
548#[must_use]
555pub fn transform(statement: &Statement, from: &DialectRef, to: &DialectRef) -> Statement {
556 if let (DialectRef::BuiltIn(f), DialectRef::BuiltIn(t)) = (from, to) {
558 return super::transform(statement, *f, *t);
559 }
560
561 if let Some(plugin) = to.as_plugin()
563 && let Some(transformed) = plugin.transform_statement(statement)
564 {
565 return transformed;
566 }
567
568 let mut stmt = statement.clone();
570 transform_statement_plugin(&mut stmt, to);
571 stmt
572}
573
574fn transform_statement_plugin(statement: &mut Statement, target: &DialectRef) {
576 match statement {
577 Statement::Select(sel) => {
578 for item in &mut sel.columns {
579 if let crate::ast::SelectItem::Expr { expr, .. } = item {
580 *expr = transform_expr_plugin(expr.clone(), target);
581 }
582 }
583 if let Some(wh) = &mut sel.where_clause {
584 *wh = transform_expr_plugin(wh.clone(), target);
585 }
586 for gb in &mut sel.group_by {
587 *gb = transform_expr_plugin(gb.clone(), target);
588 }
589 if let Some(having) = &mut sel.having {
590 *having = transform_expr_plugin(having.clone(), target);
591 }
592 }
593 Statement::CreateTable(ct) => {
594 for col in &mut ct.columns {
595 col.data_type = target.map_data_type(&col.data_type);
596 if let Some(default) = &mut col.default {
597 *default = transform_expr_plugin(default.clone(), target);
598 }
599 }
600 }
601 _ => {}
602 }
603}
604
605fn transform_expr_plugin(expr: Expr, target: &DialectRef) -> Expr {
607 if let Some(plugin) = target.as_plugin()
609 && let Some(transformed) = plugin.transform_expr(&expr)
610 {
611 return transformed;
612 }
613
614 match expr {
615 Expr::TypedFunction { func, filter, over } => {
618 if let DialectRef::Custom(_) = target {
619 let canonical = typed_function_canonical_name(&func);
620 let new_name = target.map_function_name(canonical);
621 if new_name != canonical {
622 let args = typed_function_args(&func)
624 .into_iter()
625 .map(|a| transform_expr_plugin(a, target))
626 .collect();
627 return Expr::Function {
628 name: new_name,
629 args,
630 distinct: false,
631 filter: filter.map(|f| Box::new(transform_expr_plugin(*f, target))),
632 over,
633 };
634 }
635 }
636 Expr::TypedFunction {
638 func: func.transform_children(&|e| transform_expr_plugin(e, target)),
639 filter: filter.map(|f| Box::new(transform_expr_plugin(*f, target))),
640 over,
641 }
642 }
643 Expr::Function {
644 name,
645 args,
646 distinct,
647 filter,
648 over,
649 } => {
650 let new_name = target.map_function_name(&name);
651 let new_args: Vec<Expr> = args
652 .into_iter()
653 .map(|a| transform_expr_plugin(a, target))
654 .collect();
655 Expr::Function {
656 name: new_name,
657 args: new_args,
658 distinct,
659 filter: filter.map(|f| Box::new(transform_expr_plugin(*f, target))),
660 over,
661 }
662 }
663 Expr::Cast { expr, data_type } => Expr::Cast {
664 expr: Box::new(transform_expr_plugin(*expr, target)),
665 data_type: target.map_data_type(&data_type),
666 },
667 Expr::ILike {
668 expr,
669 pattern,
670 negated,
671 escape,
672 } if !target.supports_ilike() => Expr::Like {
673 expr: Box::new(Expr::TypedFunction {
674 func: crate::ast::TypedFunction::Lower {
675 expr: Box::new(transform_expr_plugin(*expr, target)),
676 },
677 filter: None,
678 over: None,
679 }),
680 pattern: Box::new(Expr::TypedFunction {
681 func: crate::ast::TypedFunction::Lower {
682 expr: Box::new(transform_expr_plugin(*pattern, target)),
683 },
684 filter: None,
685 over: None,
686 }),
687 negated,
688 escape,
689 },
690 Expr::BinaryOp { left, op, right } => Expr::BinaryOp {
691 left: Box::new(transform_expr_plugin(*left, target)),
692 op,
693 right: Box::new(transform_expr_plugin(*right, target)),
694 },
695 Expr::UnaryOp { op, expr } => Expr::UnaryOp {
696 op,
697 expr: Box::new(transform_expr_plugin(*expr, target)),
698 },
699 Expr::Nested(inner) => Expr::Nested(Box::new(transform_expr_plugin(*inner, target))),
700 Expr::Column {
701 table,
702 name,
703 quote_style,
704 table_quote_style,
705 } => {
706 let new_qs = if quote_style.is_quoted() {
707 target.quote_style()
708 } else {
709 QuoteStyle::None
710 };
711 let new_tqs = if table_quote_style.is_quoted() {
712 target.quote_style()
713 } else {
714 QuoteStyle::None
715 };
716 Expr::Column {
717 table,
718 name,
719 quote_style: new_qs,
720 table_quote_style: new_tqs,
721 }
722 }
723 other => other,
724 }
725}
726
727use crate::errors;
732
733pub fn transpile_ext(
752 sql: &str,
753 read_dialect: &DialectRef,
754 write_dialect: &DialectRef,
755) -> errors::Result<String> {
756 let parse_dialect = read_dialect.as_builtin().unwrap_or(Dialect::Ansi);
758 let ast = crate::parser::parse(sql, parse_dialect)?;
759 let transformed = transform(&ast, read_dialect, write_dialect);
760 let gen_dialect = write_dialect.as_builtin().unwrap_or(Dialect::Ansi);
761 Ok(crate::generator::generate(&transformed, gen_dialect))
762}
763
764pub fn transpile_statements_ext(
770 sql: &str,
771 read_dialect: &DialectRef,
772 write_dialect: &DialectRef,
773) -> errors::Result<Vec<String>> {
774 let parse_dialect = read_dialect.as_builtin().unwrap_or(Dialect::Ansi);
775 let stmts = crate::parser::parse_statements(sql, parse_dialect)?;
776 let gen_dialect = write_dialect.as_builtin().unwrap_or(Dialect::Ansi);
777 let mut results = Vec::with_capacity(stmts.len());
778 for stmt in &stmts {
779 let transformed = transform(stmt, read_dialect, write_dialect);
780 results.push(crate::generator::generate(&transformed, gen_dialect));
781 }
782 Ok(results)
783}
784
785pub fn register_dialect<P: DialectPlugin + 'static>(plugin: P) {
793 DialectRegistry::global().register(plugin);
794}
795
796#[must_use]
800pub fn resolve_dialect(name: &str) -> Option<DialectRef> {
801 if let Some(d) = Dialect::from_str(name) {
803 return Some(DialectRef::BuiltIn(d));
804 }
805 if DialectRegistry::global().get(name).is_some() {
807 return Some(DialectRef::Custom(name.to_lowercase()));
808 }
809 None
810}