1use std::collections::HashMap;
2use std::sync::{Arc, Mutex, OnceLock};
3
4use crate::ast::{DataType, Expr, QuoteStyle, Statement};
5
6pub trait DialectPlugin: Send + Sync {
44 fn name(&self) -> &str;
46
47 fn quote_style(&self) -> Option<QuoteStyle> {
51 None
52 }
53
54 fn supports_ilike(&self) -> Option<bool> {
58 None
59 }
60
61 fn map_function_name(&self, name: &str) -> Option<String> {
67 let _ = name;
68 None
69 }
70
71 fn map_data_type(&self, data_type: &DataType) -> Option<DataType> {
75 let _ = data_type;
76 None
77 }
78
79 fn transform_expr(&self, expr: &Expr) -> Option<Expr> {
84 let _ = expr;
85 None
86 }
87
88 fn transform_statement(&self, statement: &Statement) -> Option<Statement> {
93 let _ = statement;
94 None
95 }
96}
97
98pub struct DialectRegistry {
106 dialects: Mutex<HashMap<String, Arc<dyn DialectPlugin>>>,
107}
108
109impl DialectRegistry {
110 pub fn global() -> &'static DialectRegistry {
112 static INSTANCE: OnceLock<DialectRegistry> = OnceLock::new();
113 INSTANCE.get_or_init(|| DialectRegistry {
114 dialects: Mutex::new(HashMap::new()),
115 })
116 }
117
118 pub fn register<P: DialectPlugin + 'static>(&self, plugin: P) {
122 let name = plugin.name().to_lowercase();
123 let mut map = self.dialects.lock().expect("dialect registry lock poisoned");
124 map.insert(name, Arc::new(plugin));
125 }
126
127 #[must_use]
129 pub fn get(&self, name: &str) -> Option<Arc<dyn DialectPlugin>> {
130 let map = self.dialects.lock().expect("dialect registry lock poisoned");
131 map.get(&name.to_lowercase()).cloned()
132 }
133
134 pub fn unregister(&self, name: &str) -> bool {
138 let mut map = self.dialects.lock().expect("dialect registry lock poisoned");
139 map.remove(&name.to_lowercase()).is_some()
140 }
141
142 #[must_use]
144 pub fn registered_names(&self) -> Vec<String> {
145 let map = self.dialects.lock().expect("dialect registry lock poisoned");
146 map.keys().cloned().collect()
147 }
148}
149
150use crate::dialects::Dialect;
155
156#[derive(Debug, Clone, PartialEq, Eq, Hash)]
171pub enum DialectRef {
172 BuiltIn(Dialect),
174 Custom(String),
176}
177
178impl DialectRef {
179 #[must_use]
181 pub fn custom(name: &str) -> Self {
182 DialectRef::Custom(name.to_lowercase())
183 }
184
185 #[must_use]
187 pub fn as_builtin(&self) -> Option<Dialect> {
188 match self {
189 DialectRef::BuiltIn(d) => Some(*d),
190 DialectRef::Custom(_) => None,
191 }
192 }
193
194 #[must_use]
196 pub fn as_plugin(&self) -> Option<Arc<dyn DialectPlugin>> {
197 match self {
198 DialectRef::Custom(name) => DialectRegistry::global().get(name),
199 DialectRef::BuiltIn(_) => None,
200 }
201 }
202
203 #[must_use]
205 pub fn quote_style(&self) -> QuoteStyle {
206 match self {
207 DialectRef::BuiltIn(d) => QuoteStyle::for_dialect(*d),
208 DialectRef::Custom(name) => DialectRegistry::global()
209 .get(name)
210 .and_then(|p| p.quote_style())
211 .unwrap_or(QuoteStyle::DoubleQuote),
212 }
213 }
214
215 #[must_use]
217 pub fn supports_ilike(&self) -> bool {
218 match self {
219 DialectRef::BuiltIn(d) => super::supports_ilike_builtin(*d),
220 DialectRef::Custom(name) => DialectRegistry::global()
221 .get(name)
222 .and_then(|p| p.supports_ilike())
223 .unwrap_or(false),
224 }
225 }
226
227 #[must_use]
229 pub fn map_function_name(&self, name: &str) -> String {
230 match self {
231 DialectRef::BuiltIn(d) => super::map_function_name(name, *d),
232 DialectRef::Custom(cname) => DialectRegistry::global()
233 .get(cname)
234 .and_then(|p| p.map_function_name(name))
235 .unwrap_or_else(|| name.to_string()),
236 }
237 }
238
239 #[must_use]
241 pub fn map_data_type(&self, dt: &DataType) -> DataType {
242 match self {
243 DialectRef::BuiltIn(d) => super::map_data_type(dt.clone(), *d),
244 DialectRef::Custom(name) => DialectRegistry::global()
245 .get(name)
246 .and_then(|p| p.map_data_type(dt))
247 .unwrap_or_else(|| dt.clone()),
248 }
249 }
250}
251
252impl From<Dialect> for DialectRef {
253 fn from(d: Dialect) -> Self {
254 DialectRef::BuiltIn(d)
255 }
256}
257
258impl std::fmt::Display for DialectRef {
259 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260 match self {
261 DialectRef::BuiltIn(d) => write!(f, "{d}"),
262 DialectRef::Custom(name) => write!(f, "Custom({name})"),
263 }
264 }
265}
266
267use crate::ast::TypedFunction;
276
277fn typed_function_canonical_name(func: &TypedFunction) -> &'static str {
279 match func {
280 TypedFunction::DateAdd { .. } => "DATE_ADD",
281 TypedFunction::DateDiff { .. } => "DATE_DIFF",
282 TypedFunction::DateTrunc { .. } => "DATE_TRUNC",
283 TypedFunction::DateSub { .. } => "DATE_SUB",
284 TypedFunction::CurrentDate => "CURRENT_DATE",
285 TypedFunction::CurrentTimestamp => "NOW",
286 TypedFunction::StrToTime { .. } => "STR_TO_TIME",
287 TypedFunction::TimeToStr { .. } => "TIME_TO_STR",
288 TypedFunction::TsOrDsToDate { .. } => "TS_OR_DS_TO_DATE",
289 TypedFunction::Year { .. } => "YEAR",
290 TypedFunction::Month { .. } => "MONTH",
291 TypedFunction::Day { .. } => "DAY",
292 TypedFunction::Trim { .. } => "TRIM",
293 TypedFunction::Substring { .. } => "SUBSTRING",
294 TypedFunction::Upper { .. } => "UPPER",
295 TypedFunction::Lower { .. } => "LOWER",
296 TypedFunction::RegexpLike { .. } => "REGEXP_LIKE",
297 TypedFunction::RegexpExtract { .. } => "REGEXP_EXTRACT",
298 TypedFunction::RegexpReplace { .. } => "REGEXP_REPLACE",
299 TypedFunction::ConcatWs { .. } => "CONCAT_WS",
300 TypedFunction::Split { .. } => "SPLIT",
301 TypedFunction::Initcap { .. } => "INITCAP",
302 TypedFunction::Length { .. } => "LENGTH",
303 TypedFunction::Replace { .. } => "REPLACE",
304 TypedFunction::Reverse { .. } => "REVERSE",
305 TypedFunction::Left { .. } => "LEFT",
306 TypedFunction::Right { .. } => "RIGHT",
307 TypedFunction::Lpad { .. } => "LPAD",
308 TypedFunction::Rpad { .. } => "RPAD",
309 TypedFunction::Count { .. } => "COUNT",
310 TypedFunction::Sum { .. } => "SUM",
311 TypedFunction::Avg { .. } => "AVG",
312 TypedFunction::Min { .. } => "MIN",
313 TypedFunction::Max { .. } => "MAX",
314 TypedFunction::ArrayAgg { .. } => "ARRAY_AGG",
315 TypedFunction::ApproxDistinct { .. } => "APPROX_DISTINCT",
316 TypedFunction::Variance { .. } => "VARIANCE",
317 TypedFunction::Stddev { .. } => "STDDEV",
318 TypedFunction::ArrayConcat { .. } => "ARRAY_CONCAT",
319 TypedFunction::ArrayContains { .. } => "ARRAY_CONTAINS",
320 TypedFunction::ArraySize { .. } => "ARRAY_SIZE",
321 TypedFunction::Explode { .. } => "EXPLODE",
322 TypedFunction::GenerateSeries { .. } => "GENERATE_SERIES",
323 TypedFunction::Flatten { .. } => "FLATTEN",
324 TypedFunction::JSONExtract { .. } => "JSON_EXTRACT",
325 TypedFunction::JSONExtractScalar { .. } => "JSON_EXTRACT_SCALAR",
326 TypedFunction::ParseJSON { .. } => "PARSE_JSON",
327 TypedFunction::JSONFormat { .. } => "JSON_FORMAT",
328 TypedFunction::RowNumber => "ROW_NUMBER",
329 TypedFunction::Rank => "RANK",
330 TypedFunction::DenseRank => "DENSE_RANK",
331 TypedFunction::NTile { .. } => "NTILE",
332 TypedFunction::Lead { .. } => "LEAD",
333 TypedFunction::Lag { .. } => "LAG",
334 TypedFunction::FirstValue { .. } => "FIRST_VALUE",
335 TypedFunction::LastValue { .. } => "LAST_VALUE",
336 TypedFunction::Abs { .. } => "ABS",
337 TypedFunction::Ceil { .. } => "CEIL",
338 TypedFunction::Floor { .. } => "FLOOR",
339 TypedFunction::Round { .. } => "ROUND",
340 TypedFunction::Log { .. } => "LOG",
341 TypedFunction::Ln { .. } => "LN",
342 TypedFunction::Pow { .. } => "POW",
343 TypedFunction::Sqrt { .. } => "SQRT",
344 TypedFunction::Greatest { .. } => "GREATEST",
345 TypedFunction::Least { .. } => "LEAST",
346 TypedFunction::Mod { .. } => "MOD",
347 TypedFunction::Hex { .. } => "HEX",
348 TypedFunction::Unhex { .. } => "UNHEX",
349 TypedFunction::Md5 { .. } => "MD5",
350 TypedFunction::Sha { .. } => "SHA",
351 TypedFunction::Sha2 { .. } => "SHA2",
352 }
353}
354
355fn typed_function_args(func: &TypedFunction) -> Vec<Expr> {
357 match func {
358 TypedFunction::CurrentDate | TypedFunction::CurrentTimestamp => vec![],
359 TypedFunction::RowNumber | TypedFunction::Rank | TypedFunction::DenseRank => vec![],
360 TypedFunction::Length { expr }
361 | TypedFunction::Upper { expr }
362 | TypedFunction::Lower { expr }
363 | TypedFunction::Initcap { expr }
364 | TypedFunction::Reverse { expr }
365 | TypedFunction::Abs { expr }
366 | TypedFunction::Ceil { expr }
367 | TypedFunction::Floor { expr }
368 | TypedFunction::Ln { expr }
369 | TypedFunction::Sqrt { expr }
370 | TypedFunction::Explode { expr }
371 | TypedFunction::Flatten { expr }
372 | TypedFunction::ArraySize { expr }
373 | TypedFunction::ParseJSON { expr }
374 | TypedFunction::JSONFormat { expr }
375 | TypedFunction::Hex { expr }
376 | TypedFunction::Unhex { expr }
377 | TypedFunction::Md5 { expr }
378 | TypedFunction::Sha { expr }
379 | TypedFunction::TsOrDsToDate { expr }
380 | TypedFunction::Year { expr }
381 | TypedFunction::Month { expr }
382 | TypedFunction::Day { expr }
383 | TypedFunction::ApproxDistinct { expr }
384 | TypedFunction::Variance { expr }
385 | TypedFunction::Stddev { expr }
386 | TypedFunction::FirstValue { expr }
387 | TypedFunction::LastValue { expr } => vec![*expr.clone()],
388 TypedFunction::DateTrunc { unit, expr } => {
389 vec![Expr::StringLiteral(format!("{unit:?}")), *expr.clone()]
390 }
391 TypedFunction::DateAdd { expr, interval, .. }
392 | TypedFunction::DateSub { expr, interval, .. } => {
393 vec![*expr.clone(), *interval.clone()]
394 }
395 TypedFunction::DateDiff { start, end, .. } => vec![*start.clone(), *end.clone()],
396 TypedFunction::StrToTime { expr, format }
397 | TypedFunction::TimeToStr { expr, format } => {
398 vec![*expr.clone(), *format.clone()]
399 }
400 TypedFunction::Trim { expr, .. } => vec![*expr.clone()],
401 TypedFunction::Substring { expr, start, length } => {
402 let mut args = vec![*expr.clone(), *start.clone()];
403 if let Some(len) = length {
404 args.push(*len.clone());
405 }
406 args
407 }
408 TypedFunction::RegexpLike { expr, pattern, flags } => {
409 let mut args = vec![*expr.clone(), *pattern.clone()];
410 if let Some(f) = flags {
411 args.push(*f.clone());
412 }
413 args
414 }
415 TypedFunction::RegexpExtract { expr, pattern, group_index } => {
416 let mut args = vec![*expr.clone(), *pattern.clone()];
417 if let Some(g) = group_index {
418 args.push(*g.clone());
419 }
420 args
421 }
422 TypedFunction::RegexpReplace { expr, pattern, replacement, flags } => {
423 let mut args = vec![*expr.clone(), *pattern.clone(), *replacement.clone()];
424 if let Some(f) = flags {
425 args.push(*f.clone());
426 }
427 args
428 }
429 TypedFunction::ConcatWs { separator, exprs } => {
430 let mut args = vec![*separator.clone()];
431 args.extend(exprs.iter().cloned());
432 args
433 }
434 TypedFunction::Split { expr, delimiter } => vec![*expr.clone(), *delimiter.clone()],
435 TypedFunction::Replace { expr, from, to } => {
436 vec![*expr.clone(), *from.clone(), *to.clone()]
437 }
438 TypedFunction::Left { expr, n } | TypedFunction::Right { expr, n } => {
439 vec![*expr.clone(), *n.clone()]
440 }
441 TypedFunction::Lpad { expr, length, pad }
442 | TypedFunction::Rpad { expr, length, pad } => {
443 let mut args = vec![*expr.clone(), *length.clone()];
444 if let Some(p) = pad {
445 args.push(*p.clone());
446 }
447 args
448 }
449 TypedFunction::Count { expr, .. }
450 | TypedFunction::Sum { expr, .. }
451 | TypedFunction::Avg { expr, .. }
452 | TypedFunction::Min { expr }
453 | TypedFunction::Max { expr }
454 | TypedFunction::ArrayAgg { expr, .. } => vec![*expr.clone()],
455 TypedFunction::ArrayConcat { arrays } => arrays.clone(),
456 TypedFunction::ArrayContains { array, element } => {
457 vec![*array.clone(), *element.clone()]
458 }
459 TypedFunction::GenerateSeries { start, stop, step } => {
460 let mut args = vec![*start.clone(), *stop.clone()];
461 if let Some(s) = step {
462 args.push(*s.clone());
463 }
464 args
465 }
466 TypedFunction::JSONExtract { expr, path }
467 | TypedFunction::JSONExtractScalar { expr, path } => {
468 vec![*expr.clone(), *path.clone()]
469 }
470 TypedFunction::NTile { n } => vec![*n.clone()],
471 TypedFunction::Lead { expr, offset, default }
472 | TypedFunction::Lag { expr, offset, default } => {
473 let mut args = vec![*expr.clone()];
474 if let Some(o) = offset {
475 args.push(*o.clone());
476 }
477 if let Some(d) = default {
478 args.push(*d.clone());
479 }
480 args
481 }
482 TypedFunction::Round { expr, decimals } => {
483 let mut args = vec![*expr.clone()];
484 if let Some(d) = decimals {
485 args.push(*d.clone());
486 }
487 args
488 }
489 TypedFunction::Log { expr, base } => {
490 let mut args = vec![*expr.clone()];
491 if let Some(b) = base {
492 args.push(*b.clone());
493 }
494 args
495 }
496 TypedFunction::Pow { base, exponent } => vec![*base.clone(), *exponent.clone()],
497 TypedFunction::Greatest { exprs } | TypedFunction::Least { exprs } => exprs.clone(),
498 TypedFunction::Mod { left, right } => vec![*left.clone(), *right.clone()],
499 TypedFunction::Sha2 { expr, bit_length } => vec![*expr.clone(), *bit_length.clone()],
500 }
501}
502
503#[must_use]
510pub fn transform(statement: &Statement, from: &DialectRef, to: &DialectRef) -> Statement {
511 if let (DialectRef::BuiltIn(f), DialectRef::BuiltIn(t)) = (from, to) {
513 return super::transform(statement, *f, *t);
514 }
515
516 if let Some(plugin) = to.as_plugin()
518 && let Some(transformed) = plugin.transform_statement(statement)
519 {
520 return transformed;
521 }
522
523 let mut stmt = statement.clone();
525 transform_statement_plugin(&mut stmt, to);
526 stmt
527}
528
529fn transform_statement_plugin(statement: &mut Statement, target: &DialectRef) {
531 match statement {
532 Statement::Select(sel) => {
533 for item in &mut sel.columns {
534 if let crate::ast::SelectItem::Expr { expr, .. } = item {
535 *expr = transform_expr_plugin(expr.clone(), target);
536 }
537 }
538 if let Some(wh) = &mut sel.where_clause {
539 *wh = transform_expr_plugin(wh.clone(), target);
540 }
541 for gb in &mut sel.group_by {
542 *gb = transform_expr_plugin(gb.clone(), target);
543 }
544 if let Some(having) = &mut sel.having {
545 *having = transform_expr_plugin(having.clone(), target);
546 }
547 }
548 Statement::CreateTable(ct) => {
549 for col in &mut ct.columns {
550 col.data_type = target.map_data_type(&col.data_type);
551 if let Some(default) = &mut col.default {
552 *default = transform_expr_plugin(default.clone(), target);
553 }
554 }
555 }
556 _ => {}
557 }
558}
559
560fn transform_expr_plugin(expr: Expr, target: &DialectRef) -> Expr {
562 if let Some(plugin) = target.as_plugin()
564 && let Some(transformed) = plugin.transform_expr(&expr)
565 {
566 return transformed;
567 }
568
569 match expr {
570 Expr::TypedFunction { func, filter, over } => {
573 if let DialectRef::Custom(_) = target {
574 let canonical = typed_function_canonical_name(&func);
575 let new_name = target.map_function_name(canonical);
576 if new_name != canonical {
577 let args = typed_function_args(&func)
579 .into_iter()
580 .map(|a| transform_expr_plugin(a, target))
581 .collect();
582 return Expr::Function {
583 name: new_name,
584 args,
585 distinct: false,
586 filter: filter.map(|f| Box::new(transform_expr_plugin(*f, target))),
587 over,
588 };
589 }
590 }
591 Expr::TypedFunction {
593 func: func.transform_children(&|e| transform_expr_plugin(e, target)),
594 filter: filter.map(|f| Box::new(transform_expr_plugin(*f, target))),
595 over,
596 }
597 }
598 Expr::Function {
599 name,
600 args,
601 distinct,
602 filter,
603 over,
604 } => {
605 let new_name = target.map_function_name(&name);
606 let new_args: Vec<Expr> = args
607 .into_iter()
608 .map(|a| transform_expr_plugin(a, target))
609 .collect();
610 Expr::Function {
611 name: new_name,
612 args: new_args,
613 distinct,
614 filter: filter.map(|f| Box::new(transform_expr_plugin(*f, target))),
615 over,
616 }
617 }
618 Expr::Cast { expr, data_type } => Expr::Cast {
619 expr: Box::new(transform_expr_plugin(*expr, target)),
620 data_type: target.map_data_type(&data_type),
621 },
622 Expr::ILike {
623 expr,
624 pattern,
625 negated,
626 escape,
627 } if !target.supports_ilike() => Expr::Like {
628 expr: Box::new(Expr::TypedFunction {
629 func: crate::ast::TypedFunction::Lower {
630 expr: Box::new(transform_expr_plugin(*expr, target)),
631 },
632 filter: None,
633 over: None,
634 }),
635 pattern: Box::new(Expr::TypedFunction {
636 func: crate::ast::TypedFunction::Lower {
637 expr: Box::new(transform_expr_plugin(*pattern, target)),
638 },
639 filter: None,
640 over: None,
641 }),
642 negated,
643 escape,
644 },
645 Expr::BinaryOp { left, op, right } => Expr::BinaryOp {
646 left: Box::new(transform_expr_plugin(*left, target)),
647 op,
648 right: Box::new(transform_expr_plugin(*right, target)),
649 },
650 Expr::UnaryOp { op, expr } => Expr::UnaryOp {
651 op,
652 expr: Box::new(transform_expr_plugin(*expr, target)),
653 },
654 Expr::Nested(inner) => Expr::Nested(Box::new(transform_expr_plugin(*inner, target))),
655 Expr::Column {
656 table,
657 name,
658 quote_style,
659 table_quote_style,
660 } => {
661 let new_qs = if quote_style.is_quoted() {
662 target.quote_style()
663 } else {
664 QuoteStyle::None
665 };
666 let new_tqs = if table_quote_style.is_quoted() {
667 target.quote_style()
668 } else {
669 QuoteStyle::None
670 };
671 Expr::Column {
672 table,
673 name,
674 quote_style: new_qs,
675 table_quote_style: new_tqs,
676 }
677 }
678 other => other,
679 }
680}
681
682use crate::errors;
687
688pub fn transpile_ext(
707 sql: &str,
708 read_dialect: &DialectRef,
709 write_dialect: &DialectRef,
710) -> errors::Result<String> {
711 let parse_dialect = read_dialect
713 .as_builtin()
714 .unwrap_or(Dialect::Ansi);
715 let ast = crate::parser::parse(sql, parse_dialect)?;
716 let transformed = transform(&ast, read_dialect, write_dialect);
717 let gen_dialect = write_dialect
718 .as_builtin()
719 .unwrap_or(Dialect::Ansi);
720 Ok(crate::generator::generate(&transformed, gen_dialect))
721}
722
723pub fn transpile_statements_ext(
729 sql: &str,
730 read_dialect: &DialectRef,
731 write_dialect: &DialectRef,
732) -> errors::Result<Vec<String>> {
733 let parse_dialect = read_dialect
734 .as_builtin()
735 .unwrap_or(Dialect::Ansi);
736 let stmts = crate::parser::parse_statements(sql, parse_dialect)?;
737 let gen_dialect = write_dialect
738 .as_builtin()
739 .unwrap_or(Dialect::Ansi);
740 let mut results = Vec::with_capacity(stmts.len());
741 for stmt in &stmts {
742 let transformed = transform(stmt, read_dialect, write_dialect);
743 results.push(crate::generator::generate(&transformed, gen_dialect));
744 }
745 Ok(results)
746}
747
748pub fn register_dialect<P: DialectPlugin + 'static>(plugin: P) {
756 DialectRegistry::global().register(plugin);
757}
758
759#[must_use]
763pub fn resolve_dialect(name: &str) -> Option<DialectRef> {
764 if let Some(d) = Dialect::from_str(name) {
766 return Some(DialectRef::BuiltIn(d));
767 }
768 if DialectRegistry::global().get(name).is_some() {
770 return Some(DialectRef::Custom(name.to_lowercase()));
771 }
772 None
773}