// sqlglot_rust/dialects/plugin.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex, OnceLock};
3
4use crate::ast::{DataType, Expr, QuoteStyle, Statement};
5
6/// Trait that external code can implement to define a custom SQL dialect.
7///
8/// All methods have default implementations that return `None`, meaning
9/// "no custom behaviour — fall through to the built-in logic". Implementors
10/// only need to override the methods they care about.
11///
12/// # Thread Safety
13///
14/// Implementations must be `Send + Sync` because the global registry is
15/// shared across threads.
16///
17/// # Example
18///
19/// ```rust
20/// use sqlglot_rust::dialects::plugin::{DialectPlugin, DialectRegistry};
21/// use sqlglot_rust::ast::{DataType, Expr, QuoteStyle, Statement};
22///
23/// struct MyDialect;
24///
25/// impl DialectPlugin for MyDialect {
26///     fn name(&self) -> &str { "mydialect" }
27///
28///     fn map_function_name(&self, name: &str) -> Option<String> {
29///         match name.to_uppercase().as_str() {
30///             "MY_FUNC" => Some("BUILTIN_FUNC".to_string()),
31///             _ => None,
32///         }
33///     }
34///
35///     fn quote_style(&self) -> Option<QuoteStyle> {
36///         Some(QuoteStyle::Backtick)
37///     }
38/// }
39///
40/// // Register once, then use via DialectRef::Custom("mydialect")
41/// DialectRegistry::global().register(MyDialect);
42/// ```
43pub trait DialectPlugin: Send + Sync {
44    /// Canonical lower-case name for this dialect (e.g. `"mydialect"`).
45    fn name(&self) -> &str;
46
47    // ── Tokenizer rules ──────────────────────────────────────────────
48
49    /// Preferred quoting style for identifiers.
50    fn quote_style(&self) -> Option<QuoteStyle> {
51        None
52    }
53
54    // ── Parser rules ─────────────────────────────────────────────────
55
56    /// Whether this dialect natively supports `ILIKE`.
57    fn supports_ilike(&self) -> Option<bool> {
58        None
59    }
60
61    // ── Generator / transform rules ──────────────────────────────────
62
63    /// Map a function name for this dialect.
64    ///
65    /// Return `Some(new_name)` to override, or `None` to keep the original.
66    fn map_function_name(&self, name: &str) -> Option<String> {
67        let _ = name;
68        None
69    }
70
71    /// Map a data type for this dialect.
72    ///
73    /// Return `Some(new_type)` to override, or `None` to keep the original.
74    fn map_data_type(&self, data_type: &DataType) -> Option<DataType> {
75        let _ = data_type;
76        None
77    }
78
79    /// Transform an entire expression for this dialect.
80    ///
81    /// Return `Some(new_expr)` to replace the expression, or `None` to
82    /// fall through to the default transformation logic.
83    fn transform_expr(&self, expr: &Expr) -> Option<Expr> {
84        let _ = expr;
85        None
86    }
87
88    /// Transform a complete statement for this dialect.
89    ///
90    /// Return `Some(new_stmt)` to replace the statement, or `None` to
91    /// fall through to the default transformation logic.
92    fn transform_statement(&self, statement: &Statement) -> Option<Statement> {
93        let _ = statement;
94        None
95    }
96}
97
98// ═══════════════════════════════════════════════════════════════════════════
99// Global registry
100// ═══════════════════════════════════════════════════════════════════════════
101
102/// Thread-safe registry for custom dialect plugins.
103///
104/// Access the singleton with [`DialectRegistry::global()`].
105pub struct DialectRegistry {
106    dialects: Mutex<HashMap<String, Arc<dyn DialectPlugin>>>,
107}
108
109impl DialectRegistry {
110    /// Returns the global registry singleton.
111    pub fn global() -> &'static DialectRegistry {
112        static INSTANCE: OnceLock<DialectRegistry> = OnceLock::new();
113        INSTANCE.get_or_init(|| DialectRegistry {
114            dialects: Mutex::new(HashMap::new()),
115        })
116    }
117
118    /// Register a custom dialect plugin.
119    ///
120    /// If a plugin with the same name already exists it is replaced.
121    pub fn register<P: DialectPlugin + 'static>(&self, plugin: P) {
122        let name = plugin.name().to_lowercase();
123        let mut map = self.dialects.lock().expect("dialect registry lock poisoned");
124        map.insert(name, Arc::new(plugin));
125    }
126
127    /// Look up a custom dialect by name (case-insensitive).
128    #[must_use]
129    pub fn get(&self, name: &str) -> Option<Arc<dyn DialectPlugin>> {
130        let map = self.dialects.lock().expect("dialect registry lock poisoned");
131        map.get(&name.to_lowercase()).cloned()
132    }
133
134    /// Remove a custom dialect plugin by name.
135    ///
136    /// Returns `true` if the dialect was found and removed.
137    pub fn unregister(&self, name: &str) -> bool {
138        let mut map = self.dialects.lock().expect("dialect registry lock poisoned");
139        map.remove(&name.to_lowercase()).is_some()
140    }
141
142    /// Returns the names of all registered custom dialects.
143    #[must_use]
144    pub fn registered_names(&self) -> Vec<String> {
145        let map = self.dialects.lock().expect("dialect registry lock poisoned");
146        map.keys().cloned().collect()
147    }
148}
149
150// ═══════════════════════════════════════════════════════════════════════════
151// DialectRef — unified built-in + custom dialect handle
152// ═══════════════════════════════════════════════════════════════════════════
153
154use crate::dialects::Dialect;
155
/// A reference to either a built-in [`Dialect`] or a custom dialect
/// registered via the plugin system.
///
/// This is the primary handle that plugin-aware API functions accept.
/// A `Custom` reference holds only a name. The plugin is looked up in the
/// global registry each time it is needed, so the reference stays valid even
/// if the plugin is registered later (or becomes dangling if unregistered).
///
/// # Example
///
/// ```rust
/// use sqlglot_rust::dialects::plugin::DialectRef;
/// use sqlglot_rust::Dialect;
///
/// let builtin = DialectRef::from(Dialect::Postgres);
/// let custom  = DialectRef::custom("mydialect");
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DialectRef {
    /// A built-in dialect variant.
    BuiltIn(Dialect),
    /// A custom dialect identified by its registered name.
    ///
    /// [`DialectRef::custom`] lower-cases the name. Constructing this variant
    /// directly does not, so two references that differ only in case compare
    /// unequal even though they resolve to the same plugin.
    Custom(String),
}
177
178impl DialectRef {
179    /// Create a `DialectRef` for a custom dialect by name.
180    #[must_use]
181    pub fn custom(name: &str) -> Self {
182        DialectRef::Custom(name.to_lowercase())
183    }
184
185    /// Try to resolve this reference to a built-in dialect.
186    #[must_use]
187    pub fn as_builtin(&self) -> Option<Dialect> {
188        match self {
189            DialectRef::BuiltIn(d) => Some(*d),
190            DialectRef::Custom(_) => None,
191        }
192    }
193
194    /// Try to resolve this reference to a custom plugin.
195    #[must_use]
196    pub fn as_plugin(&self) -> Option<Arc<dyn DialectPlugin>> {
197        match self {
198            DialectRef::Custom(name) => DialectRegistry::global().get(name),
199            DialectRef::BuiltIn(_) => None,
200        }
201    }
202
203    /// Get the quote style for this dialect reference.
204    #[must_use]
205    pub fn quote_style(&self) -> QuoteStyle {
206        match self {
207            DialectRef::BuiltIn(d) => QuoteStyle::for_dialect(*d),
208            DialectRef::Custom(name) => DialectRegistry::global()
209                .get(name)
210                .and_then(|p| p.quote_style())
211                .unwrap_or(QuoteStyle::DoubleQuote),
212        }
213    }
214
215    /// Check if this dialect supports ILIKE natively.
216    #[must_use]
217    pub fn supports_ilike(&self) -> bool {
218        match self {
219            DialectRef::BuiltIn(d) => super::supports_ilike_builtin(*d),
220            DialectRef::Custom(name) => DialectRegistry::global()
221                .get(name)
222                .and_then(|p| p.supports_ilike())
223                .unwrap_or(false),
224        }
225    }
226
227    /// Map a function name using this dialect's rules.
228    #[must_use]
229    pub fn map_function_name(&self, name: &str) -> String {
230        match self {
231            DialectRef::BuiltIn(d) => super::map_function_name(name, *d),
232            DialectRef::Custom(cname) => DialectRegistry::global()
233                .get(cname)
234                .and_then(|p| p.map_function_name(name))
235                .unwrap_or_else(|| name.to_string()),
236        }
237    }
238
239    /// Map a data type using this dialect's rules.
240    #[must_use]
241    pub fn map_data_type(&self, dt: &DataType) -> DataType {
242        match self {
243            DialectRef::BuiltIn(d) => super::map_data_type(dt.clone(), *d),
244            DialectRef::Custom(name) => DialectRegistry::global()
245                .get(name)
246                .and_then(|p| p.map_data_type(dt))
247                .unwrap_or_else(|| dt.clone()),
248        }
249    }
250}
251
252impl From<Dialect> for DialectRef {
253    fn from(d: Dialect) -> Self {
254        DialectRef::BuiltIn(d)
255    }
256}
257
258impl std::fmt::Display for DialectRef {
259    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260        match self {
261            DialectRef::BuiltIn(d) => write!(f, "{d}"),
262            DialectRef::Custom(name) => write!(f, "Custom({name})"),
263        }
264    }
265}
266
// ═══════════════════════════════════════════════════════════════════════════
// Plugin-aware transform
// ═══════════════════════════════════════════════════════════════════════════
274
275use crate::ast::TypedFunction;
276
/// Return the canonical SQL function name for a TypedFunction variant.
///
/// This is the name handed to [`DialectRef::map_function_name`] when deciding
/// whether a custom dialect wants to rename a typed function, so plugins must
/// match on exactly these (upper-case) strings.
///
/// NOTE(review): a few entries differ from the SQL-standard spelling.
/// `CurrentTimestamp` maps to `"NOW"` rather than `"CURRENT_TIMESTAMP"`, and
/// `Pow` to `"POW"` rather than `"POWER"`. A plugin overriding those names
/// must use these spellings — confirm this is intended.
fn typed_function_canonical_name(func: &TypedFunction) -> &'static str {
    match func {
        // ── Date / time ──
        TypedFunction::DateAdd { .. } => "DATE_ADD",
        TypedFunction::DateDiff { .. } => "DATE_DIFF",
        TypedFunction::DateTrunc { .. } => "DATE_TRUNC",
        TypedFunction::DateSub { .. } => "DATE_SUB",
        TypedFunction::CurrentDate => "CURRENT_DATE",
        TypedFunction::CurrentTimestamp => "NOW",
        TypedFunction::StrToTime { .. } => "STR_TO_TIME",
        TypedFunction::TimeToStr { .. } => "TIME_TO_STR",
        TypedFunction::TsOrDsToDate { .. } => "TS_OR_DS_TO_DATE",
        TypedFunction::Year { .. } => "YEAR",
        TypedFunction::Month { .. } => "MONTH",
        TypedFunction::Day { .. } => "DAY",
        // ── String ──
        TypedFunction::Trim { .. } => "TRIM",
        TypedFunction::Substring { .. } => "SUBSTRING",
        TypedFunction::Upper { .. } => "UPPER",
        TypedFunction::Lower { .. } => "LOWER",
        TypedFunction::RegexpLike { .. } => "REGEXP_LIKE",
        TypedFunction::RegexpExtract { .. } => "REGEXP_EXTRACT",
        TypedFunction::RegexpReplace { .. } => "REGEXP_REPLACE",
        TypedFunction::ConcatWs { .. } => "CONCAT_WS",
        TypedFunction::Split { .. } => "SPLIT",
        TypedFunction::Initcap { .. } => "INITCAP",
        TypedFunction::Length { .. } => "LENGTH",
        TypedFunction::Replace { .. } => "REPLACE",
        TypedFunction::Reverse { .. } => "REVERSE",
        TypedFunction::Left { .. } => "LEFT",
        TypedFunction::Right { .. } => "RIGHT",
        TypedFunction::Lpad { .. } => "LPAD",
        TypedFunction::Rpad { .. } => "RPAD",
        // ── Aggregates ──
        TypedFunction::Count { .. } => "COUNT",
        TypedFunction::Sum { .. } => "SUM",
        TypedFunction::Avg { .. } => "AVG",
        TypedFunction::Min { .. } => "MIN",
        TypedFunction::Max { .. } => "MAX",
        TypedFunction::ArrayAgg { .. } => "ARRAY_AGG",
        TypedFunction::ApproxDistinct { .. } => "APPROX_DISTINCT",
        TypedFunction::Variance { .. } => "VARIANCE",
        TypedFunction::Stddev { .. } => "STDDEV",
        // ── Arrays ──
        TypedFunction::ArrayConcat { .. } => "ARRAY_CONCAT",
        TypedFunction::ArrayContains { .. } => "ARRAY_CONTAINS",
        TypedFunction::ArraySize { .. } => "ARRAY_SIZE",
        TypedFunction::Explode { .. } => "EXPLODE",
        TypedFunction::GenerateSeries { .. } => "GENERATE_SERIES",
        TypedFunction::Flatten { .. } => "FLATTEN",
        // ── JSON ──
        TypedFunction::JSONExtract { .. } => "JSON_EXTRACT",
        TypedFunction::JSONExtractScalar { .. } => "JSON_EXTRACT_SCALAR",
        TypedFunction::ParseJSON { .. } => "PARSE_JSON",
        TypedFunction::JSONFormat { .. } => "JSON_FORMAT",
        // ── Window ──
        TypedFunction::RowNumber => "ROW_NUMBER",
        TypedFunction::Rank => "RANK",
        TypedFunction::DenseRank => "DENSE_RANK",
        TypedFunction::NTile { .. } => "NTILE",
        TypedFunction::Lead { .. } => "LEAD",
        TypedFunction::Lag { .. } => "LAG",
        TypedFunction::FirstValue { .. } => "FIRST_VALUE",
        TypedFunction::LastValue { .. } => "LAST_VALUE",
        // ── Math ──
        TypedFunction::Abs { .. } => "ABS",
        TypedFunction::Ceil { .. } => "CEIL",
        TypedFunction::Floor { .. } => "FLOOR",
        TypedFunction::Round { .. } => "ROUND",
        TypedFunction::Log { .. } => "LOG",
        TypedFunction::Ln { .. } => "LN",
        TypedFunction::Pow { .. } => "POW",
        TypedFunction::Sqrt { .. } => "SQRT",
        TypedFunction::Greatest { .. } => "GREATEST",
        TypedFunction::Least { .. } => "LEAST",
        TypedFunction::Mod { .. } => "MOD",
        // ── Hashing / encoding ──
        TypedFunction::Hex { .. } => "HEX",
        TypedFunction::Unhex { .. } => "UNHEX",
        TypedFunction::Md5 { .. } => "MD5",
        TypedFunction::Sha { .. } => "SHA",
        TypedFunction::Sha2 { .. } => "SHA2",
    }
}
354
/// Extract the argument expressions from a TypedFunction (in positional order).
///
/// Used when a custom dialect renames a typed function and it must be
/// downgraded to a plain `Expr::Function { name, args, .. }` call. All
/// returned expressions are clones; `func` is not modified.
///
/// NOTE(review): the conversion is lossy. Variants matched with `..` (DateAdd,
/// DateSub, DateDiff, Trim, Count, Sum, Avg, ArrayAgg) discard their remaining
/// fields. Those presumably carry things like a date unit, trim side/characters
/// or a DISTINCT flag — confirm against the `TypedFunction` definition before
/// relying on renamed output for those functions.
fn typed_function_args(func: &TypedFunction) -> Vec<Expr> {
    match func {
        // Zero-argument functions.
        TypedFunction::CurrentDate | TypedFunction::CurrentTimestamp => vec![],
        TypedFunction::RowNumber | TypedFunction::Rank | TypedFunction::DenseRank => vec![],
        // Single-argument functions.
        TypedFunction::Length { expr }
        | TypedFunction::Upper { expr }
        | TypedFunction::Lower { expr }
        | TypedFunction::Initcap { expr }
        | TypedFunction::Reverse { expr }
        | TypedFunction::Abs { expr }
        | TypedFunction::Ceil { expr }
        | TypedFunction::Floor { expr }
        | TypedFunction::Ln { expr }
        | TypedFunction::Sqrt { expr }
        | TypedFunction::Explode { expr }
        | TypedFunction::Flatten { expr }
        | TypedFunction::ArraySize { expr }
        | TypedFunction::ParseJSON { expr }
        | TypedFunction::JSONFormat { expr }
        | TypedFunction::Hex { expr }
        | TypedFunction::Unhex { expr }
        | TypedFunction::Md5 { expr }
        | TypedFunction::Sha { expr }
        | TypedFunction::TsOrDsToDate { expr }
        | TypedFunction::Year { expr }
        | TypedFunction::Month { expr }
        | TypedFunction::Day { expr }
        | TypedFunction::ApproxDistinct { expr }
        | TypedFunction::Variance { expr }
        | TypedFunction::Stddev { expr }
        | TypedFunction::FirstValue { expr }
        | TypedFunction::LastValue { expr } => vec![*expr.clone()],
        TypedFunction::DateTrunc { unit, expr } => {
            // NOTE(review): the unit is rendered with `Debug` formatting, so the
            // literal is whatever `{unit:?}` prints (possibly not upper-case SQL) —
            // confirm this matches what target dialects expect.
            vec![Expr::StringLiteral(format!("{unit:?}")), *expr.clone()]
        }
        TypedFunction::DateAdd { expr, interval, .. }
        | TypedFunction::DateSub { expr, interval, .. } => {
            vec![*expr.clone(), *interval.clone()]
        }
        TypedFunction::DateDiff { start, end, .. } => vec![*start.clone(), *end.clone()],
        TypedFunction::StrToTime { expr, format }
        | TypedFunction::TimeToStr { expr, format } => {
            vec![*expr.clone(), *format.clone()]
        }
        TypedFunction::Trim { expr, .. } => vec![*expr.clone()],
        // Optional trailing arguments are appended only when present.
        TypedFunction::Substring { expr, start, length } => {
            let mut args = vec![*expr.clone(), *start.clone()];
            if let Some(len) = length {
                args.push(*len.clone());
            }
            args
        }
        TypedFunction::RegexpLike { expr, pattern, flags } => {
            let mut args = vec![*expr.clone(), *pattern.clone()];
            if let Some(f) = flags {
                args.push(*f.clone());
            }
            args
        }
        TypedFunction::RegexpExtract { expr, pattern, group_index } => {
            let mut args = vec![*expr.clone(), *pattern.clone()];
            if let Some(g) = group_index {
                args.push(*g.clone());
            }
            args
        }
        TypedFunction::RegexpReplace { expr, pattern, replacement, flags } => {
            let mut args = vec![*expr.clone(), *pattern.clone(), *replacement.clone()];
            if let Some(f) = flags {
                args.push(*f.clone());
            }
            args
        }
        // Separator first, then every concatenated operand.
        TypedFunction::ConcatWs { separator, exprs } => {
            let mut args = vec![*separator.clone()];
            args.extend(exprs.iter().cloned());
            args
        }
        TypedFunction::Split { expr, delimiter } => vec![*expr.clone(), *delimiter.clone()],
        TypedFunction::Replace { expr, from, to } => {
            vec![*expr.clone(), *from.clone(), *to.clone()]
        }
        TypedFunction::Left { expr, n } | TypedFunction::Right { expr, n } => {
            vec![*expr.clone(), *n.clone()]
        }
        TypedFunction::Lpad { expr, length, pad }
        | TypedFunction::Rpad { expr, length, pad } => {
            let mut args = vec![*expr.clone(), *length.clone()];
            if let Some(p) = pad {
                args.push(*p.clone());
            }
            args
        }
        TypedFunction::Count { expr, .. }
        | TypedFunction::Sum { expr, .. }
        | TypedFunction::Avg { expr, .. }
        | TypedFunction::Min { expr }
        | TypedFunction::Max { expr }
        | TypedFunction::ArrayAgg { expr, .. } => vec![*expr.clone()],
        TypedFunction::ArrayConcat { arrays } => arrays.clone(),
        TypedFunction::ArrayContains { array, element } => {
            vec![*array.clone(), *element.clone()]
        }
        TypedFunction::GenerateSeries { start, stop, step } => {
            let mut args = vec![*start.clone(), *stop.clone()];
            if let Some(s) = step {
                args.push(*s.clone());
            }
            args
        }
        TypedFunction::JSONExtract { expr, path }
        | TypedFunction::JSONExtractScalar { expr, path } => {
            vec![*expr.clone(), *path.clone()]
        }
        TypedFunction::NTile { n } => vec![*n.clone()],
        // NOTE(review): if `offset` is absent but `default` is present, the
        // default lands in the second (offset) position — confirm the parser
        // never produces that combination.
        TypedFunction::Lead { expr, offset, default }
        | TypedFunction::Lag { expr, offset, default } => {
            let mut args = vec![*expr.clone()];
            if let Some(o) = offset {
                args.push(*o.clone());
            }
            if let Some(d) = default {
                args.push(*d.clone());
            }
            args
        }
        TypedFunction::Round { expr, decimals } => {
            let mut args = vec![*expr.clone()];
            if let Some(d) = decimals {
                args.push(*d.clone());
            }
            args
        }
        // Value first, then the optional base.
        TypedFunction::Log { expr, base } => {
            let mut args = vec![*expr.clone()];
            if let Some(b) = base {
                args.push(*b.clone());
            }
            args
        }
        TypedFunction::Pow { base, exponent } => vec![*base.clone(), *exponent.clone()],
        TypedFunction::Greatest { exprs } | TypedFunction::Least { exprs } => exprs.clone(),
        TypedFunction::Mod { left, right } => vec![*left.clone(), *right.clone()],
        TypedFunction::Sha2 { expr, bit_length } => vec![*expr.clone(), *bit_length.clone()],
    }
}
502
503/// Transform a statement from one dialect to another, supporting custom
504/// dialect plugins.
505///
506/// For built-in → built-in transforms this delegates to the existing
507/// [`super::transform`]. When either side is a custom dialect the plugin's
508/// transform hooks are applied.
509#[must_use]
510pub fn transform(statement: &Statement, from: &DialectRef, to: &DialectRef) -> Statement {
511    // Fast path: both built-in → use existing logic
512    if let (DialectRef::BuiltIn(f), DialectRef::BuiltIn(t)) = (from, to) {
513        return super::transform(statement, *f, *t);
514    }
515
516    // If the target is a custom dialect with a full statement transform, try that first.
517    if let Some(plugin) = to.as_plugin()
518        && let Some(transformed) = plugin.transform_statement(statement)
519    {
520        return transformed;
521    }
522
523    // Otherwise apply expression-level transforms via DialectRef helpers.
524    let mut stmt = statement.clone();
525    transform_statement_plugin(&mut stmt, to);
526    stmt
527}
528
529/// Recursively transform a statement using plugin-aware rules.
530fn transform_statement_plugin(statement: &mut Statement, target: &DialectRef) {
531    match statement {
532        Statement::Select(sel) => {
533            for item in &mut sel.columns {
534                if let crate::ast::SelectItem::Expr { expr, .. } = item {
535                    *expr = transform_expr_plugin(expr.clone(), target);
536                }
537            }
538            if let Some(wh) = &mut sel.where_clause {
539                *wh = transform_expr_plugin(wh.clone(), target);
540            }
541            for gb in &mut sel.group_by {
542                *gb = transform_expr_plugin(gb.clone(), target);
543            }
544            if let Some(having) = &mut sel.having {
545                *having = transform_expr_plugin(having.clone(), target);
546            }
547        }
548        Statement::CreateTable(ct) => {
549            for col in &mut ct.columns {
550                col.data_type = target.map_data_type(&col.data_type);
551                if let Some(default) = &mut col.default {
552                    *default = transform_expr_plugin(default.clone(), target);
553                }
554            }
555        }
556        _ => {}
557    }
558}
559
560/// Transform an expression using plugin-aware rules.
561fn transform_expr_plugin(expr: Expr, target: &DialectRef) -> Expr {
562    // Let the plugin have first shot at transforming the whole expression
563    if let Some(plugin) = target.as_plugin()
564        && let Some(transformed) = plugin.transform_expr(&expr)
565    {
566        return transformed;
567    }
568
569    match expr {
570        // For TypedFunction, check if the plugin wants to rename the function.
571        // If so, convert it to a plain Expr::Function with the new name.
572        Expr::TypedFunction { func, filter, over } => {
573            if let DialectRef::Custom(_) = target {
574                let canonical = typed_function_canonical_name(&func);
575                let new_name = target.map_function_name(canonical);
576                if new_name != canonical {
577                    // Plugin wants to rename this function — convert to Expr::Function
578                    let args = typed_function_args(&func)
579                        .into_iter()
580                        .map(|a| transform_expr_plugin(a, target))
581                        .collect();
582                    return Expr::Function {
583                        name: new_name,
584                        args,
585                        distinct: false,
586                        filter: filter.map(|f| Box::new(transform_expr_plugin(*f, target))),
587                        over,
588                    };
589                }
590            }
591            // No rename — recurse into children
592            Expr::TypedFunction {
593                func: func.transform_children(&|e| transform_expr_plugin(e, target)),
594                filter: filter.map(|f| Box::new(transform_expr_plugin(*f, target))),
595                over,
596            }
597        }
598        Expr::Function {
599            name,
600            args,
601            distinct,
602            filter,
603            over,
604        } => {
605            let new_name = target.map_function_name(&name);
606            let new_args: Vec<Expr> = args
607                .into_iter()
608                .map(|a| transform_expr_plugin(a, target))
609                .collect();
610            Expr::Function {
611                name: new_name,
612                args: new_args,
613                distinct,
614                filter: filter.map(|f| Box::new(transform_expr_plugin(*f, target))),
615                over,
616            }
617        }
618        Expr::Cast { expr, data_type } => Expr::Cast {
619            expr: Box::new(transform_expr_plugin(*expr, target)),
620            data_type: target.map_data_type(&data_type),
621        },
622        Expr::ILike {
623            expr,
624            pattern,
625            negated,
626            escape,
627        } if !target.supports_ilike() => Expr::Like {
628            expr: Box::new(Expr::TypedFunction {
629                func: crate::ast::TypedFunction::Lower {
630                    expr: Box::new(transform_expr_plugin(*expr, target)),
631                },
632                filter: None,
633                over: None,
634            }),
635            pattern: Box::new(Expr::TypedFunction {
636                func: crate::ast::TypedFunction::Lower {
637                    expr: Box::new(transform_expr_plugin(*pattern, target)),
638                },
639                filter: None,
640                over: None,
641            }),
642            negated,
643            escape,
644        },
645        Expr::BinaryOp { left, op, right } => Expr::BinaryOp {
646            left: Box::new(transform_expr_plugin(*left, target)),
647            op,
648            right: Box::new(transform_expr_plugin(*right, target)),
649        },
650        Expr::UnaryOp { op, expr } => Expr::UnaryOp {
651            op,
652            expr: Box::new(transform_expr_plugin(*expr, target)),
653        },
654        Expr::Nested(inner) => Expr::Nested(Box::new(transform_expr_plugin(*inner, target))),
655        Expr::Column {
656            table,
657            name,
658            quote_style,
659            table_quote_style,
660        } => {
661            let new_qs = if quote_style.is_quoted() {
662                target.quote_style()
663            } else {
664                QuoteStyle::None
665            };
666            let new_tqs = if table_quote_style.is_quoted() {
667                target.quote_style()
668            } else {
669                QuoteStyle::None
670            };
671            Expr::Column {
672                table,
673                name,
674                quote_style: new_qs,
675                table_quote_style: new_tqs,
676            }
677        }
678        other => other,
679    }
680}
681
682// ═══════════════════════════════════════════════════════════════════════════
683// Plugin-aware top-level API
684// ═══════════════════════════════════════════════════════════════════════════
685
686use crate::errors;
687
688/// Transpile a SQL string using [`DialectRef`], supporting custom plugins.
689///
690/// # Example
691///
692/// ```rust
693/// use sqlglot_rust::dialects::plugin::{DialectRef, transpile_ext};
694/// use sqlglot_rust::Dialect;
695///
696/// let result = transpile_ext(
697///     "SELECT NOW()",
698///     &DialectRef::from(Dialect::Postgres),
699///     &DialectRef::from(Dialect::Mysql),
700/// ).unwrap();
701/// ```
702///
703/// # Errors
704///
705/// Returns a [`SqlglotError`](crate::errors::SqlglotError) if parsing fails.
706pub fn transpile_ext(
707    sql: &str,
708    read_dialect: &DialectRef,
709    write_dialect: &DialectRef,
710) -> errors::Result<String> {
711    // Parse using the read dialect (fall back to Ansi for custom dialects)
712    let parse_dialect = read_dialect
713        .as_builtin()
714        .unwrap_or(Dialect::Ansi);
715    let ast = crate::parser::parse(sql, parse_dialect)?;
716    let transformed = transform(&ast, read_dialect, write_dialect);
717    let gen_dialect = write_dialect
718        .as_builtin()
719        .unwrap_or(Dialect::Ansi);
720    Ok(crate::generator::generate(&transformed, gen_dialect))
721}
722
723/// Transpile multiple statements using [`DialectRef`], supporting custom plugins.
724///
725/// # Errors
726///
727/// Returns a [`SqlglotError`](crate::errors::SqlglotError) if parsing fails.
728pub fn transpile_statements_ext(
729    sql: &str,
730    read_dialect: &DialectRef,
731    write_dialect: &DialectRef,
732) -> errors::Result<Vec<String>> {
733    let parse_dialect = read_dialect
734        .as_builtin()
735        .unwrap_or(Dialect::Ansi);
736    let stmts = crate::parser::parse_statements(sql, parse_dialect)?;
737    let gen_dialect = write_dialect
738        .as_builtin()
739        .unwrap_or(Dialect::Ansi);
740    let mut results = Vec::with_capacity(stmts.len());
741    for stmt in &stmts {
742        let transformed = transform(stmt, read_dialect, write_dialect);
743        results.push(crate::generator::generate(&transformed, gen_dialect));
744    }
745    Ok(results)
746}
747
748// ═══════════════════════════════════════════════════════════════════════════
749// Convenience registration functions
750// ═══════════════════════════════════════════════════════════════════════════
751
752/// Register a custom dialect plugin in the global registry.
753///
754/// This is a convenience wrapper around [`DialectRegistry::global().register()`].
755pub fn register_dialect<P: DialectPlugin + 'static>(plugin: P) {
756    DialectRegistry::global().register(plugin);
757}
758
759/// Look up a dialect by name, returning either a built-in or custom [`DialectRef`].
760///
761/// Checks built-in dialects first, then the custom plugin registry.
762#[must_use]
763pub fn resolve_dialect(name: &str) -> Option<DialectRef> {
764    // Try built-in first
765    if let Some(d) = Dialect::from_str(name) {
766        return Some(DialectRef::BuiltIn(d));
767    }
768    // Then try plugin registry
769    if DialectRegistry::global().get(name).is_some() {
770        return Some(DialectRef::Custom(name.to_lowercase()));
771    }
772    None
773}