Skip to main content

sqlglot_rust/optimizer/
lineage.rs

1//! Column lineage tracking for SQL queries.
2//!
3//! Provides functionality to trace data flow from source columns through
4//! query transformations to output columns. This is the foundation for
5//! data governance tools and impact analysis.
6//!
7//! Inspired by Python sqlglot's `lineage.py`.
8//!
9//! # Example
10//!
11//! ```rust
12//! use sqlglot_rust::parser::parse;
13//! use sqlglot_rust::dialects::Dialect;
14//! use sqlglot_rust::optimizer::lineage::{lineage, LineageConfig};
15//! use sqlglot_rust::schema::MappingSchema;
16//!
17//! let sql = "SELECT a, b + 1 AS c FROM t";
18//! let ast = parse(sql, Dialect::Ansi).unwrap();
19//! let schema = MappingSchema::new(Dialect::Ansi);
20//! let config = LineageConfig::default();
21//!
22//! let graph = lineage("c", &ast, &schema, &config).unwrap();
23//! assert_eq!(graph.node.name, "c");
24//! ```
25
26use std::collections::{HashMap, HashSet};
27
28use crate::ast::*;
29use crate::dialects::Dialect;
30use crate::errors::SqlglotError;
31use crate::schema::{MappingSchema, Schema};
32
33// ═══════════════════════════════════════════════════════════════════════
34// Error types
35// ═══════════════════════════════════════════════════════════════════════
36
/// Failure modes of lineage analysis.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LineageError {
    /// The requested column does not appear in the query output.
    ColumnNotFound(String),
    /// A column reference could be satisfied by more than one source.
    AmbiguousColumn(String),
    /// The statement has a shape lineage analysis cannot handle.
    InvalidQuery(String),
    /// SQL text could not be parsed.
    ParseError(String),
}

impl std::fmt::Display for LineageError {
    /// Renders the error as `"<category>: <detail>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every variant carries exactly one string payload, so pick the
        // category text and the payload first, then format once.
        let (category, detail) = match self {
            Self::ColumnNotFound(column) => ("Column not found in output", column),
            Self::AmbiguousColumn(column) => ("Ambiguous column reference", column),
            Self::InvalidQuery(reason) => ("Invalid query for lineage", reason),
            Self::ParseError(reason) => ("Parse error", reason),
        };
        write!(f, "{category}: {detail}")
    }
}

impl std::error::Error for LineageError {}
62
63impl From<LineageError> for SqlglotError {
64    fn from(e: LineageError) -> Self {
65        SqlglotError::Internal(e.to_string())
66    }
67}
68
69/// Result type for lineage operations.
70pub type LineageResult<T> = std::result::Result<T, LineageError>;
71
72// ═══════════════════════════════════════════════════════════════════════
73// Configuration
74// ═══════════════════════════════════════════════════════════════════════
75
76/// Configuration for lineage analysis.
77#[derive(Debug, Clone)]
78pub struct LineageConfig {
79    /// SQL dialect for parsing and identifier normalization.
80    pub dialect: Dialect,
81    /// Whether to trim column qualifiers in output node names.
82    pub trim_qualifiers: bool,
83    /// External sources mapping for multi-query lineage.
84    /// Maps source names to their SQL definitions (e.g., views).
85    pub sources: HashMap<String, String>,
86}
87
88impl Default for LineageConfig {
89    fn default() -> Self {
90        Self {
91            dialect: Dialect::Ansi,
92            trim_qualifiers: true,
93            sources: HashMap::new(),
94        }
95    }
96}
97
98impl LineageConfig {
99    /// Create a new configuration with the specified dialect.
100    #[must_use]
101    pub fn new(dialect: Dialect) -> Self {
102        Self {
103            dialect,
104            ..Default::default()
105        }
106    }
107
108    /// Add external sources for multi-query lineage.
109    #[must_use]
110    pub fn with_sources(mut self, sources: HashMap<String, String>) -> Self {
111        self.sources = sources;
112        self
113    }
114
115    /// Set whether to trim table qualifiers from output names.
116    #[must_use]
117    pub fn with_trim_qualifiers(mut self, trim: bool) -> Self {
118        self.trim_qualifiers = trim;
119        self
120    }
121}
122
123// ═══════════════════════════════════════════════════════════════════════
124// Lineage Node
125// ═══════════════════════════════════════════════════════════════════════
126
127/// A node in the lineage graph representing a column or expression.
128#[derive(Debug, Clone)]
129pub struct LineageNode {
130    /// The name of this column/expression (e.g., "a", "SUM(b)", "t.col").
131    pub name: String,
132    /// The AST expression this node represents.
133    pub expression: Option<Expr>,
134    /// The source table/CTE/subquery name, if applicable.
135    pub source_name: Option<String>,
136    /// Reference to the source AST (for complex expressions).
137    pub source: Option<Expr>,
138    /// Child nodes (upstream lineage - where data comes from).
139    pub downstream: Vec<LineageNode>,
140    /// The alias, if this is an aliased expression.
141    pub alias: Option<String>,
142    /// Depth in the lineage graph (0 = root output column).
143    pub depth: usize,
144}
145
146impl LineageNode {
147    /// Create a new lineage node.
148    #[must_use]
149    pub fn new(name: String) -> Self {
150        Self {
151            name,
152            expression: None,
153            source_name: None,
154            source: None,
155            downstream: Vec::new(),
156            alias: None,
157            depth: 0,
158        }
159    }
160
161    /// Create a node with source information.
162    #[must_use]
163    pub fn with_source(mut self, source_name: String) -> Self {
164        self.source_name = Some(source_name);
165        self
166    }
167
168    /// Create a node with an expression.
169    #[must_use]
170    pub fn with_expression(mut self, expr: Expr) -> Self {
171        self.expression = Some(expr);
172        self
173    }
174
175    /// Create a node with an alias.
176    #[must_use]
177    #[allow(dead_code)]
178    pub fn with_alias(mut self, alias: String) -> Self {
179        self.alias = Some(alias);
180        self
181    }
182
183    /// Create a node with depth.
184    #[must_use]
185    pub fn with_depth(mut self, depth: usize) -> Self {
186        self.depth = depth;
187        self
188    }
189
190    /// Add a downstream (upstream lineage) node.
191    #[allow(dead_code)]
192    pub fn add_downstream(&mut self, node: LineageNode) {
193        self.downstream.push(node);
194    }
195
196    /// Walk through all nodes in the lineage graph (pre-order).
197    pub fn walk<F>(&self, visitor: &mut F)
198    where
199        F: FnMut(&LineageNode),
200    {
201        visitor(self);
202        for child in &self.downstream {
203            child.walk(visitor);
204        }
205    }
206
207    /// Iterate over all nodes in the lineage graph.
208    #[must_use]
209    pub fn iter(&self) -> LineageIterator<'_> {
210        LineageIterator { stack: vec![self] }
211    }
212
213    /// Get all source columns (leaf nodes) in this lineage.
214    #[must_use]
215    #[allow(dead_code)]
216    pub fn source_columns(&self) -> Vec<&LineageNode> {
217        self.iter().filter(|n| n.downstream.is_empty()).collect()
218    }
219
220    /// Get all source table names referenced in this lineage.
221    #[must_use]
222    pub fn source_tables(&self) -> Vec<String> {
223        let mut tables = HashSet::new();
224        for node in self.iter() {
225            if let Some(ref source) = node.source_name {
226                tables.insert(source.clone());
227            }
228        }
229        tables.into_iter().collect()
230    }
231
232    /// Generate DOT format representation for visualization.
233    #[must_use]
234    pub fn to_dot(&self) -> String {
235        let mut dot = String::from("digraph lineage {\n");
236        dot.push_str("  rankdir=BT;\n");
237        dot.push_str("  node [shape=box];\n");
238
239        let mut node_id = 0;
240        let mut node_ids = HashMap::new();
241
242        // First pass: assign IDs and create nodes
243        self.walk(&mut |node| {
244            let id = format!("n{}", node_id);
245            let label = if let Some(ref src) = node.source_name {
246                format!("{}.{}", src, node.name)
247            } else {
248                node.name.clone()
249            };
250            dot.push_str(&format!("  {} [label=\"{}\"];\n", id, escape_dot(&label)));
251            node_ids.insert(node as *const _ as usize, id);
252            node_id += 1;
253        });
254
255        // Second pass: create edges
256        self.walk(&mut |node| {
257            let parent_id = node_ids.get(&(node as *const _ as usize)).unwrap();
258            for child in &node.downstream {
259                let child_id = node_ids.get(&(child as *const _ as usize)).unwrap();
260                dot.push_str(&format!("  {} -> {};\n", child_id, parent_id));
261            }
262        });
263
264        dot.push_str("}\n");
265        dot
266    }
267
268    /// Generate Mermaid diagram representation.
269    #[must_use]
270    pub fn to_mermaid(&self) -> String {
271        let mut mermaid = String::from("flowchart BT\n");
272
273        let mut node_id = 0;
274        let mut node_ids = HashMap::new();
275
276        // First pass: assign IDs and create nodes
277        self.walk(&mut |node| {
278            let id = format!("n{}", node_id);
279            let label = if let Some(ref src) = node.source_name {
280                format!("{}.{}", src, node.name)
281            } else {
282                node.name.clone()
283            };
284            mermaid.push_str(&format!("  {}[\"{}\"]\n", id, escape_mermaid(&label)));
285            node_ids.insert(node as *const _ as usize, id);
286            node_id += 1;
287        });
288
289        // Second pass: create edges
290        self.walk(&mut |node| {
291            let parent_id = node_ids.get(&(node as *const _ as usize)).unwrap();
292            for child in &node.downstream {
293                let child_id = node_ids.get(&(child as *const _ as usize)).unwrap();
294                mermaid.push_str(&format!("  {} --> {}\n", child_id, parent_id));
295            }
296        });
297
298        mermaid
299    }
300}
301
302/// Iterator over lineage nodes (pre-order traversal).
303pub struct LineageIterator<'a> {
304    stack: Vec<&'a LineageNode>,
305}
306
307impl<'a> Iterator for LineageIterator<'a> {
308    type Item = &'a LineageNode;
309
310    fn next(&mut self) -> Option<Self::Item> {
311        self.stack.pop().map(|node| {
312            // Push children in reverse order for pre-order traversal
313            for child in node.downstream.iter().rev() {
314                self.stack.push(child);
315            }
316            node
317        })
318    }
319}
320
321// ═══════════════════════════════════════════════════════════════════════
322// Lineage Graph
323// ═══════════════════════════════════════════════════════════════════════
324
325/// A lineage graph rooted at a specific output column.
326#[derive(Debug, Clone)]
327pub struct LineageGraph {
328    /// The root node representing the target output column.
329    pub node: LineageNode,
330    /// The original SQL that was analyzed.
331    pub sql: Option<String>,
332    /// The dialect used for analysis.
333    pub dialect: Dialect,
334}
335
336impl LineageGraph {
337    /// Create a new lineage graph.
338    #[must_use]
339    pub fn new(node: LineageNode, dialect: Dialect) -> Self {
340        Self {
341            node,
342            sql: None,
343            dialect,
344        }
345    }
346
347    /// Set the original SQL string.
348    #[must_use]
349    #[allow(dead_code)]
350    pub fn with_sql(mut self, sql: String) -> Self {
351        self.sql = Some(sql);
352        self
353    }
354
355    /// Get all source tables in the lineage.
356    #[must_use]
357    pub fn source_tables(&self) -> Vec<String> {
358        self.node.source_tables()
359    }
360
361    /// Get all source columns (leaf nodes).
362    #[must_use]
363    #[allow(dead_code)]
364    pub fn source_columns(&self) -> Vec<&LineageNode> {
365        self.node.source_columns()
366    }
367
368    /// Walk through all nodes in the graph.
369    #[allow(dead_code)]
370    pub fn walk<F>(&self, visitor: &mut F)
371    where
372        F: FnMut(&LineageNode),
373    {
374        self.node.walk(visitor);
375    }
376
377    /// Generate DOT format visualization.
378    #[must_use]
379    pub fn to_dot(&self) -> String {
380        self.node.to_dot()
381    }
382
383    /// Generate Mermaid diagram visualization.
384    #[must_use]
385    pub fn to_mermaid(&self) -> String {
386        self.node.to_mermaid()
387    }
388}
389
390// ═══════════════════════════════════════════════════════════════════════
391// Context for lineage building
392// ═══════════════════════════════════════════════════════════════════════
393
394/// Internal context for building lineage graphs.
395struct LineageContext {
396    /// The schema for column resolution.
397    schema: MappingSchema,
398    /// Configuration options.
399    config: LineageConfig,
400    /// Current depth in the lineage graph.
401    depth: usize,
402    /// CTE definitions available in this scope (owned).
403    ctes: HashMap<String, Statement>,
404    /// Visible sources in current scope (alias/name → source info).
405    sources: HashMap<String, SourceInfo>,
406    /// External sources for multi-query lineage.
407    external_sources: HashMap<String, Statement>,
408    /// Sources currently being visited (to prevent infinite recursion).
409    visiting: HashSet<String>,
410}
411
412/// Information about a source (table, CTE, derived table).
413#[derive(Debug, Clone)]
414struct SourceInfo {
415    /// The source type.
416    kind: SourceKind,
417    /// For subqueries/CTEs, the SELECT columns.
418    columns: Option<Vec<SelectItem>>,
419    /// The underlying statement, if any (owned).
420    statement: Option<Statement>,
421}
422
423#[derive(Debug, Clone, Copy, PartialEq, Eq)]
424#[allow(dead_code)]
425enum SourceKind {
426    Table,
427    Cte,
428    DerivedTable,
429    External,
430}
431
432impl LineageContext {
433    fn new(schema: &MappingSchema, config: &LineageConfig) -> Self {
434        Self {
435            schema: schema.clone(),
436            config: config.clone(),
437            depth: 0,
438            ctes: HashMap::new(),
439            sources: HashMap::new(),
440            external_sources: HashMap::new(),
441            visiting: HashSet::new(),
442        }
443    }
444
445    fn with_depth(&self, depth: usize) -> Self {
446        Self {
447            schema: self.schema.clone(),
448            config: self.config.clone(),
449            depth,
450            ctes: self.ctes.clone(),
451            sources: self.sources.clone(),
452            external_sources: self.external_sources.clone(),
453            visiting: self.visiting.clone(),
454        }
455    }
456
457    #[allow(dead_code)]
458    fn resolve_source(&self, name: &str) -> Option<&SourceInfo> {
459        let normalized = normalize_name(name, self.config.dialect);
460        self.sources.get(&normalized)
461    }
462}
463
464// ═══════════════════════════════════════════════════════════════════════
465// Public API
466// ═══════════════════════════════════════════════════════════════════════
467
468/// Build lineage for a specific output column in a SQL statement.
469///
470/// # Arguments
471///
472/// * `column` - The name of the output column to trace (can include table qualifier).
473/// * `statement` - The parsed SQL statement.
474/// * `schema` - Schema information for table/column resolution.
475/// * `config` - Configuration options.
476///
477/// # Returns
478///
479/// A [`LineageGraph`] rooted at the target column, showing its upstream lineage.
480///
481/// # Errors
482///
483/// Returns [`LineageError::ColumnNotFound`] if the column is not in the output.
484///
485/// # Example
486///
487/// ```rust
488/// use sqlglot_rust::parser::parse;
489/// use sqlglot_rust::dialects::Dialect;
490/// use sqlglot_rust::optimizer::lineage::{lineage, LineageConfig};
491/// use sqlglot_rust::schema::MappingSchema;
492///
493/// let sql = "SELECT a, b AS c FROM t";
494/// let ast = parse(sql, Dialect::Ansi).unwrap();
495/// let schema = MappingSchema::new(Dialect::Ansi);
496/// let config = LineageConfig::default();
497///
498/// let graph = lineage("c", &ast, &schema, &config).unwrap();
499/// assert_eq!(graph.node.name, "c");
500/// ```
501pub fn lineage(
502    column: &str,
503    statement: &Statement,
504    schema: &MappingSchema,
505    config: &LineageConfig,
506) -> LineageResult<LineageGraph> {
507    // Parse external sources if provided
508    let mut ctx = LineageContext::new(schema, config);
509
510    for (name, sql) in &config.sources {
511        match crate::parser::parse(sql, config.dialect) {
512            Ok(stmt) => {
513                ctx.external_sources
514                    .insert(normalize_name(name, config.dialect), stmt);
515            }
516            Err(e) => return Err(LineageError::ParseError(e.to_string())),
517        }
518    }
519
520    // Build lineage for the target column
521    let node = build_lineage_for_column(column, statement, &mut ctx)?;
522
523    Ok(LineageGraph::new(node, config.dialect))
524}
525
526/// Build lineage from a SQL string.
527///
528/// Convenience function that parses the SQL and builds lineage.
529///
530/// # Example
531///
532/// ```rust
533/// use sqlglot_rust::dialects::Dialect;
534/// use sqlglot_rust::optimizer::lineage::{lineage_sql, LineageConfig};
535/// use sqlglot_rust::schema::MappingSchema;
536///
537/// let schema = MappingSchema::new(Dialect::Ansi);
538/// let config = LineageConfig::default();
539///
540/// let graph = lineage_sql("c", "SELECT a + b AS c FROM t", &schema, &config).unwrap();
541/// assert_eq!(graph.node.name, "c");
542/// ```
543pub fn lineage_sql(
544    column: &str,
545    sql: &str,
546    schema: &MappingSchema,
547    config: &LineageConfig,
548) -> LineageResult<LineageGraph> {
549    let statement = crate::parser::parse(sql, config.dialect)
550        .map_err(|e| LineageError::ParseError(e.to_string()))?;
551
552    let mut graph = lineage(column, &statement, schema, config)?;
553    graph.sql = Some(sql.to_string());
554    Ok(graph)
555}
556
557// ═══════════════════════════════════════════════════════════════════════
558// Internal lineage building
559// ═══════════════════════════════════════════════════════════════════════
560
561/// Build lineage for a specific column in a statement.
562fn build_lineage_for_column(
563    column: &str,
564    statement: &Statement,
565    ctx: &mut LineageContext,
566) -> LineageResult<LineageNode> {
567    match statement {
568        Statement::Select(sel) => build_lineage_for_select_column(column, sel, ctx),
569        Statement::SetOperation(set_op) => build_lineage_for_set_operation(column, set_op, ctx),
570        Statement::CreateView(cv) => build_lineage_for_column(column, &cv.query, ctx),
571        _ => Err(LineageError::InvalidQuery(
572            "Lineage analysis requires a SELECT or set operation statement".to_string(),
573        )),
574    }
575}
576
577/// Build lineage for a column in a SELECT statement.
578fn build_lineage_for_select_column(
579    column: &str,
580    sel: &SelectStatement,
581    ctx: &mut LineageContext,
582) -> LineageResult<LineageNode> {
583    // Register CTEs (cloning to avoid lifetime issues)
584    for cte in &sel.ctes {
585        let cte_name = normalize_name(&cte.name, ctx.config.dialect);
586        ctx.ctes.insert(cte_name.clone(), (*cte.query).clone());
587        ctx.sources.insert(
588            cte_name,
589            SourceInfo {
590                kind: SourceKind::Cte,
591                columns: extract_select_columns(&cte.query),
592                statement: Some((*cte.query).clone()),
593            },
594        );
595    }
596
597    // Register FROM source
598    if let Some(from) = &sel.from {
599        register_table_source(&from.source, ctx);
600    }
601
602    // Register JOINs
603    for join in &sel.joins {
604        register_table_source(&join.table, ctx);
605    }
606
607    // Find the target column in the SELECT list
608    let (col_name, table_qual) = parse_column_ref(column);
609
610    for item in &sel.columns {
611        match item {
612            SelectItem::Expr { expr, alias } => {
613                let item_name = alias
614                    .as_ref()
615                    .map(String::as_str)
616                    .unwrap_or_else(|| expr_output_name(expr));
617
618                if matches_column_name(item_name, &col_name) {
619                    return build_lineage_for_expr(expr, alias.clone(), ctx);
620                }
621            }
622            SelectItem::Wildcard => {
623                // Expand wildcard - check all sources
624                for (source_name, source_info) in ctx.sources.clone() {
625                    if let Some(cols) = &source_info.columns {
626                        for col_item in cols {
627                            if let SelectItem::Expr { expr, alias } = col_item {
628                                let item_name = alias
629                                    .as_ref()
630                                    .map(String::as_str)
631                                    .unwrap_or_else(|| expr_output_name(expr));
632                                if matches_column_name(item_name, &col_name) {
633                                    return build_lineage_for_expr(expr, alias.clone(), ctx);
634                                }
635                            }
636                        }
637                    } else if source_info.kind == SourceKind::Table {
638                        // Check schema for table columns
639                        if let Ok(schema_cols) = ctx.schema.column_names(&[&source_name]) {
640                            if schema_cols
641                                .iter()
642                                .any(|c| matches_column_name(c, &col_name))
643                            {
644                                // Found in schema
645                                let mut node = LineageNode::new(col_name.clone())
646                                    .with_source(source_name.clone())
647                                    .with_depth(ctx.depth);
648                                node.expression = Some(Expr::Column {
649                                    table: Some(source_name.clone()),
650                                    name: col_name.clone(),
651                                    quote_style: QuoteStyle::None,
652                                    table_quote_style: QuoteStyle::None,
653                                });
654                                return Ok(node);
655                            }
656                        }
657                    }
658                }
659            }
660            SelectItem::QualifiedWildcard { table } => {
661                if table_qual
662                    .as_ref()
663                    .is_some_and(|t| matches_column_name(t, table))
664                {
665                    // Check if column exists in this table
666                    if let Some(source_info) = ctx.sources.get(table).cloned() {
667                        if let Some(cols) = &source_info.columns {
668                            for col_item in cols {
669                                if let SelectItem::Expr { expr, alias } = col_item {
670                                    let item_name = alias
671                                        .as_ref()
672                                        .map(String::as_str)
673                                        .unwrap_or_else(|| expr_output_name(expr));
674                                    if matches_column_name(item_name, &col_name) {
675                                        return build_lineage_for_expr(expr, alias.clone(), ctx);
676                                    }
677                                }
678                            }
679                        }
680                    }
681                }
682            }
683        }
684    }
685
686    Err(LineageError::ColumnNotFound(column.to_string()))
687}
688
689/// Build lineage for a set operation (UNION, INTERSECT, EXCEPT).
690fn build_lineage_for_set_operation(
691    column: &str,
692    set_op: &SetOperationStatement,
693    ctx: &mut LineageContext,
694) -> LineageResult<LineageNode> {
695    let mut root = LineageNode::new(column.to_string()).with_depth(ctx.depth);
696
697    // Build lineage from both branches
698    let mut child_ctx = ctx.with_depth(ctx.depth + 1);
699
700    let left_lineage = build_lineage_for_column(column, &set_op.left, &mut child_ctx)?;
701    let right_lineage = build_lineage_for_column(column, &set_op.right, &mut child_ctx)?;
702
703    root.downstream.push(left_lineage);
704    root.downstream.push(right_lineage);
705
706    Ok(root)
707}
708
709/// Build lineage for an expression.
710fn build_lineage_for_expr(
711    expr: &Expr,
712    alias: Option<String>,
713    ctx: &mut LineageContext,
714) -> LineageResult<LineageNode> {
715    let name = alias
716        .clone()
717        .unwrap_or_else(|| expr_to_name(expr, ctx.config.trim_qualifiers));
718    let mut node = LineageNode::new(name.clone())
719        .with_expression(expr.clone())
720        .with_depth(ctx.depth);
721
722    if let Some(a) = alias {
723        node.alias = Some(a);
724    }
725
726    // Collect column references from the expression
727    let columns = collect_expr_columns(expr);
728
729    let mut child_ctx = ctx.with_depth(ctx.depth + 1);
730
731    for col_ref in columns {
732        let child_node = resolve_column_lineage(&col_ref, &mut child_ctx)?;
733        node.downstream.push(child_node);
734    }
735
736    Ok(node)
737}
738
739/// Resolve lineage for a column reference.
740fn resolve_column_lineage(
741    col: &ColumnReference,
742    ctx: &mut LineageContext,
743) -> LineageResult<LineageNode> {
744    let name = if ctx.config.trim_qualifiers {
745        col.name.clone()
746    } else {
747        col.qualified_name()
748    };
749
750    // If table qualifier is provided, look up in that source
751    if let Some(ref table) = col.table {
752        let normalized_table = normalize_name(table, ctx.config.dialect);
753
754        if let Some(source_info) = ctx.sources.get(&normalized_table).cloned() {
755            match source_info.kind {
756                SourceKind::Table => {
757                    // Base table - this is a leaf node
758                    let node = LineageNode::new(name)
759                        .with_source(normalized_table)
760                        .with_depth(ctx.depth);
761                    return Ok(node);
762                }
763                SourceKind::Cte | SourceKind::DerivedTable => {
764                    // Recurse into CTE/derived table (if not already visiting)
765                    if !ctx.visiting.contains(&normalized_table) {
766                        if let Some(stmt) = source_info.statement {
767                            ctx.visiting.insert(normalized_table.clone());
768                            let result = build_lineage_for_column(&col.name, &stmt, ctx);
769                            ctx.visiting.remove(&normalized_table);
770                            return result;
771                        }
772                    }
773                    // If already visiting, treat as leaf
774                    let node = LineageNode::new(name)
775                        .with_source(normalized_table)
776                        .with_depth(ctx.depth);
777                    return Ok(node);
778                }
779                SourceKind::External => {
780                    // Check external sources
781                    if let Some(stmt) = ctx.external_sources.get(&normalized_table).cloned() {
782                        return build_lineage_for_column(&col.name, &stmt, ctx);
783                    }
784                }
785            }
786        }
787    }
788
789    // No table qualifier - search all sources
790    for (source_name, source_info) in ctx.sources.clone() {
791        match source_info.kind {
792            SourceKind::Table => {
793                // Check if this table has the column
794                if ctx.schema.has_column(&[&source_name], &col.name) {
795                    let node = LineageNode::new(name)
796                        .with_source(source_name.clone())
797                        .with_depth(ctx.depth);
798                    return Ok(node);
799                }
800            }
801            SourceKind::Cte | SourceKind::DerivedTable => {
802                // Skip if already visiting this source
803                if ctx.visiting.contains(&source_name) {
804                    continue;
805                }
806                // Check if CTE/derived table outputs this column
807                if let Some(ref columns) = source_info.columns {
808                    if columns.iter().any(|c| select_item_has_column(c, &col.name)) {
809                        if let Some(stmt) = source_info.statement {
810                            ctx.visiting.insert(source_name.clone());
811                            let result = build_lineage_for_column(&col.name, &stmt, ctx);
812                            ctx.visiting.remove(&source_name);
813                            return result;
814                        }
815                    }
816                }
817            }
818            SourceKind::External => {}
819        }
820    }
821
822    // Column not found in any known source - treat as external/unknown
823    let node = LineageNode::new(name).with_depth(ctx.depth);
824    Ok(node)
825}
826
827/// Register a table source in the context.
828fn register_table_source(source: &TableSource, ctx: &mut LineageContext) {
829    match source {
830        TableSource::Table(table_ref) => {
831            let key = table_ref.alias.as_ref().unwrap_or(&table_ref.name).clone();
832            let normalized = normalize_name(&key, ctx.config.dialect);
833            // Don't overwrite CTEs or derived tables
834            if !ctx.sources.contains_key(&normalized) {
835                ctx.sources.insert(
836                    normalized,
837                    SourceInfo {
838                        kind: SourceKind::Table,
839                        columns: None,
840                        statement: None,
841                    },
842                );
843            }
844        }
845        TableSource::Subquery { query, alias } => {
846            if let Some(alias) = alias {
847                let normalized = normalize_name(alias, ctx.config.dialect);
848                ctx.sources.insert(
849                    normalized,
850                    SourceInfo {
851                        kind: SourceKind::DerivedTable,
852                        columns: extract_select_columns(query),
853                        statement: Some((**query).clone()),
854                    },
855                );
856            }
857        }
858        TableSource::Lateral { source } => {
859            register_table_source(source, ctx);
860        }
861        TableSource::Pivot { source, alias, .. } | TableSource::Unpivot { source, alias, .. } => {
862            register_table_source(source, ctx);
863            // TODO: Track pivot/unpivot column transformations
864            if let Some(alias) = alias {
865                let normalized = normalize_name(alias, ctx.config.dialect);
866                ctx.sources.insert(
867                    normalized,
868                    SourceInfo {
869                        kind: SourceKind::DerivedTable,
870                        columns: None,
871                        statement: None,
872                    },
873                );
874            }
875        }
876        TableSource::TableFunction { alias, .. } => {
877            if let Some(alias) = alias {
878                let normalized = normalize_name(alias, ctx.config.dialect);
879                ctx.sources.insert(
880                    normalized,
881                    SourceInfo {
882                        kind: SourceKind::Table,
883                        columns: None,
884                        statement: None,
885                    },
886                );
887            }
888        }
889        TableSource::Unnest { alias, .. } => {
890            if let Some(alias) = alias {
891                let normalized = normalize_name(alias, ctx.config.dialect);
892                ctx.sources.insert(
893                    normalized,
894                    SourceInfo {
895                        kind: SourceKind::Table,
896                        columns: None,
897                        statement: None,
898                    },
899                );
900            }
901        }
902    }
903}
904
905// ═══════════════════════════════════════════════════════════════════════
906// Helper types and functions
907// ═══════════════════════════════════════════════════════════════════════
908
/// A column reference found in an expression.
#[derive(Debug, Clone)]
struct ColumnReference {
    /// Table qualifier as written in the expression, if any.
    table: Option<String>,
    /// Column name.
    name: String,
}

impl ColumnReference {
    /// Render the reference as `table.name`, or just `name` when unqualified.
    fn qualified_name(&self) -> String {
        match self.table.as_deref() {
            Some(qualifier) => format!("{qualifier}.{}", self.name),
            None => self.name.clone(),
        }
    }
}
925
926/// Collect all column references from an expression.
927fn collect_expr_columns(expr: &Expr) -> Vec<ColumnReference> {
928    let mut columns = Vec::new();
929
930    expr.walk(&mut |e| {
931        if let Expr::Column { table, name, .. } = e {
932            columns.push(ColumnReference {
933                table: table.clone(),
934                name: name.clone(),
935            });
936            return false; // Don't recurse into column nodes
937        }
938        // Don't descend into subqueries
939        !matches!(
940            e,
941            Expr::Subquery(_) | Expr::Exists { .. } | Expr::InSubquery { .. }
942        )
943    });
944
945    columns
946}
947
948/// Extract SELECT columns from a statement.
949fn extract_select_columns(stmt: &Statement) -> Option<Vec<SelectItem>> {
950    match stmt {
951        Statement::Select(sel) => Some(sel.columns.clone()),
952        Statement::SetOperation(set_op) => extract_select_columns(&set_op.left),
953        Statement::CreateView(cv) => extract_select_columns(&cv.query),
954        _ => None,
955    }
956}
957
958/// Get the output name of an expression.
959fn expr_output_name(expr: &Expr) -> &str {
960    match expr {
961        Expr::Column { name, .. } => name,
962        Expr::Alias { name, .. } => name,
963        _ => "",
964    }
965}
966
967/// Convert an expression to a displayable name.
968fn expr_to_name(expr: &Expr, trim_qualifiers: bool) -> String {
969    match expr {
970        Expr::Column { table, name, .. } => {
971            if trim_qualifiers {
972                name.clone()
973            } else if let Some(t) = table {
974                format!("{}.{}", t, name)
975            } else {
976                name.clone()
977            }
978        }
979        Expr::Alias { name, .. } => name.clone(),
980        Expr::Function { name, .. } => format!("{}(...)", name),
981        Expr::BinaryOp { op, .. } => format!("({:?})", op),
982        Expr::Cast { data_type, .. } => format!("CAST AS {:?}", data_type),
983        _ => "expr".to_string(),
984    }
985}
986
/// Parse a column reference string into (name, optional_table_qualifier).
///
/// The split happens at the last `.`: everything before it is the qualifier,
/// everything after it is the column name. A string without a `.` is returned
/// unchanged as the name with no qualifier.
fn parse_column_ref(column: &str) -> (String, Option<String>) {
    match column.rsplit_once('.') {
        Some((qualifier, name)) => (name.to_owned(), Some(qualifier.to_owned())),
        None => (column.to_owned(), None),
    }
}
997
/// Check if a column name matches, ignoring ASCII case.
///
/// Only ASCII letters are case-folded; non-ASCII characters must match
/// exactly. No dialect is consulted here.
fn matches_column_name(item: &str, target: &str) -> bool {
    let candidate = item.as_bytes();
    let wanted = target.as_bytes();
    candidate.eq_ignore_ascii_case(wanted)
}
1002
1003/// Normalize an identifier name for the given dialect.
1004fn normalize_name(name: &str, dialect: Dialect) -> String {
1005    crate::schema::normalize_identifier(name, dialect)
1006}
1007
1008/// Check if a SELECT item outputs a column with the given name.
1009fn select_item_has_column(item: &SelectItem, name: &str) -> bool {
1010    match item {
1011        SelectItem::Expr { expr, alias } => {
1012            let item_name = alias
1013                .as_ref()
1014                .map(String::as_str)
1015                .unwrap_or_else(|| expr_output_name(expr));
1016            matches_column_name(item_name, name)
1017        }
1018        SelectItem::Wildcard => true, // Could match any column
1019        SelectItem::QualifiedWildcard { .. } => true,
1020    }
1021}
1022
/// Escape a string for DOT format.
///
/// Backslashes become `\\`, double quotes become `\"`, and newlines become
/// the two-character sequence `\n`. All other characters are copied as-is.
fn escape_dot(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '\n' => escaped.push_str("\\n"),
            other => escaped.push(other),
        }
    }
    escaped
}
1029
/// Escape a string for Mermaid format.
///
/// Double quotes become single quotes, newlines become spaces, and square
/// brackets become parentheses. All other characters are copied as-is.
fn escape_mermaid(s: &str) -> String {
    s.chars()
        .map(|ch| match ch {
            '"' => '\'',
            '\n' => ' ',
            '[' => '(',
            ']' => ')',
            other => other,
        })
        .collect()
}
1037
1038// ═══════════════════════════════════════════════════════════════════════
1039// Tests
1040// ═══════════════════════════════════════════════════════════════════════
1041
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse;

    /// ANSI-dialect lineage configuration used by the tests.
    fn test_config() -> LineageConfig {
        LineageConfig::new(Dialect::Ansi)
    }

    /// ANSI-dialect schema with three tables: `t(a, b, c)`,
    /// `users(id, name, email)` and `orders(id, user_id, amount)`.
    fn test_schema() -> MappingSchema {
        // Shorthand for the many integer columns below.
        let int_col = |column: &str| (column.to_string(), DataType::Int);

        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(&["t"], vec![int_col("a"), int_col("b"), int_col("c")])
            .unwrap();
        schema
            .add_table(
                &["users"],
                vec![
                    int_col("id"),
                    ("name".to_string(), DataType::Varchar(Some(255))),
                    ("email".to_string(), DataType::Text),
                ],
            )
            .unwrap();
        schema
            .add_table(
                &["orders"],
                vec![
                    int_col("id"),
                    int_col("user_id"),
                    (
                        "amount".to_string(),
                        DataType::Decimal {
                            precision: Some(10),
                            scale: Some(2),
                        },
                    ),
                ],
            )
            .unwrap();
        schema
    }

    #[test]
    fn test_simple_column_lineage() {
        let schema = test_schema();
        let config = test_config();
        let stmt = parse("SELECT a FROM t", Dialect::Ansi).unwrap();

        let graph = lineage("a", &stmt, &schema, &config).unwrap();
        let root = &graph.node;

        assert_eq!(root.name, "a");
        assert_eq!(root.depth, 0);
        // The root has a single dependency, and it comes from `t`.
        assert_eq!(root.downstream.len(), 1);
        assert_eq!(root.downstream[0].source_name, Some("t".to_string()));
    }

    #[test]
    fn test_aliased_column_lineage() {
        let schema = test_schema();
        let config = test_config();
        let stmt = parse("SELECT a AS col_a FROM t", Dialect::Ansi).unwrap();

        let graph = lineage("col_a", &stmt, &schema, &config).unwrap();
        let root = &graph.node;

        // The alias is both the node name and the recorded alias.
        assert_eq!(root.name, "col_a");
        assert_eq!(root.alias, Some("col_a".to_string()));
    }

    #[test]
    fn test_expression_lineage() {
        let schema = test_schema();
        let config = test_config();
        let stmt = parse("SELECT a + b AS sum FROM t", Dialect::Ansi).unwrap();

        let graph = lineage("sum", &stmt, &schema, &config).unwrap();

        assert_eq!(graph.node.name, "sum");
        // One dependency per operand: a and b.
        assert_eq!(graph.node.downstream.len(), 2);
    }

    #[test]
    fn test_cte_lineage() {
        let schema = test_schema();
        let config = test_config();
        let query = "WITH cte AS (SELECT a FROM t) SELECT a FROM cte";
        let stmt = parse(query, Dialect::Ansi).unwrap();

        let graph = lineage("a", &stmt, &schema, &config).unwrap();

        assert_eq!(graph.node.name, "a");
        // Lineage is followed through the CTE down to the base table.
        assert!(graph.source_tables().contains(&"t".to_string()));
    }

    #[test]
    fn test_join_lineage() {
        let schema = test_schema();
        let config = test_config();
        let query = "SELECT u.name, o.amount FROM users u JOIN orders o ON u.id = o.user_id";
        let stmt = parse(query, Dialect::Ansi).unwrap();

        // One output column from each side of the join.
        let name_graph = lineage("name", &stmt, &schema, &config).unwrap();
        assert_eq!(name_graph.node.name, "name");

        let amount_graph = lineage("amount", &stmt, &schema, &config).unwrap();
        assert_eq!(amount_graph.node.name, "amount");
    }

    #[test]
    fn test_union_lineage() {
        let schema = test_schema();
        let config = test_config();
        let query = "SELECT a FROM t UNION SELECT b AS a FROM t";
        let stmt = parse(query, Dialect::Ansi).unwrap();

        let graph = lineage("a", &stmt, &schema, &config).unwrap();

        assert_eq!(graph.node.name, "a");
        // One branch per side of the UNION.
        assert_eq!(graph.node.downstream.len(), 2);
    }

    #[test]
    fn test_column_not_found() {
        let schema = test_schema();
        let config = test_config();
        let stmt = parse("SELECT a FROM t", Dialect::Ansi).unwrap();

        let outcome = lineage("nonexistent", &stmt, &schema, &config);

        assert!(matches!(outcome, Err(LineageError::ColumnNotFound(_))));
    }

    #[test]
    fn test_derived_table_lineage() {
        let schema = test_schema();
        let config = test_config();
        let query = "SELECT x FROM (SELECT a AS x FROM t) sub";
        let stmt = parse(query, Dialect::Ansi).unwrap();

        let graph = lineage("x", &stmt, &schema, &config).unwrap();

        assert_eq!(graph.node.name, "x");
        // Lineage is followed through the derived table down to `t`.
        assert!(graph.source_tables().contains(&"t".to_string()));
    }

    #[test]
    fn test_function_lineage() {
        let schema = test_schema();
        let config = test_config();
        let stmt = parse("SELECT SUM(a) AS total FROM t", Dialect::Ansi).unwrap();

        let graph = lineage("total", &stmt, &schema, &config).unwrap();

        assert_eq!(graph.node.name, "total");
        // The aggregate has a single argument column.
        assert_eq!(graph.node.downstream.len(), 1);
    }

    #[test]
    fn test_lineage_sql_convenience() {
        let schema = test_schema();
        let config = test_config();
        let query = "SELECT a, b FROM t";

        let graph = lineage_sql("b", query, &schema, &config).unwrap();

        assert_eq!(graph.node.name, "b");
        // The original SQL text is kept on the graph.
        assert_eq!(graph.sql, Some(query.to_string()));
    }

    #[test]
    fn test_source_tables() {
        let schema = test_schema();
        let config = test_config();
        let query = "SELECT u.name, o.amount FROM users u JOIN orders o ON u.id = o.user_id";
        let stmt = parse(query, Dialect::Ansi).unwrap();

        let graph = lineage("name", &stmt, &schema, &config).unwrap();
        let source_names = graph.source_tables();

        // The source is reported under the alias used in the query.
        assert!(source_names.contains(&"u".to_string()));
    }

    #[test]
    fn test_to_dot() {
        let schema = test_schema();
        let config = test_config();
        let stmt = parse("SELECT a AS col FROM t", Dialect::Ansi).unwrap();

        let graph = lineage("col", &stmt, &schema, &config).unwrap();
        let rendered = graph.to_dot();

        assert!(rendered.contains("digraph lineage"));
        assert!(rendered.contains("rankdir=BT"));
    }

    #[test]
    fn test_to_mermaid() {
        let schema = test_schema();
        let config = test_config();
        let stmt = parse("SELECT a AS col FROM t", Dialect::Ansi).unwrap();

        let graph = lineage("col", &stmt, &schema, &config).unwrap();
        let rendered = graph.to_mermaid();

        assert!(rendered.contains("flowchart BT"));
    }

    #[test]
    fn test_case_expression_lineage() {
        let schema = test_schema();
        let config = test_config();
        let query = "SELECT CASE WHEN a > 0 THEN b ELSE c END AS result FROM t";
        let stmt = parse(query, Dialect::Ansi).unwrap();

        let graph = lineage("result", &stmt, &schema, &config).unwrap();

        assert_eq!(graph.node.name, "result");
        // The expression reads a, b and c; this only requires that at least
        // two dependencies are reported.
        assert!(graph.node.downstream.len() >= 2);
    }

    #[test]
    fn test_coalesce_lineage() {
        let schema = test_schema();
        let config = test_config();
        let stmt = parse("SELECT COALESCE(a, b, c) AS val FROM t", Dialect::Ansi).unwrap();

        let graph = lineage("val", &stmt, &schema, &config).unwrap();

        assert_eq!(graph.node.name, "val");
        // One dependency per argument: a, b and c.
        assert_eq!(graph.node.downstream.len(), 3);
    }

    #[test]
    fn test_nested_cte_lineage() {
        let schema = test_schema();
        let config = test_config();
        let query = r#"
            WITH cte1 AS (SELECT a, b FROM t),
                 cte2 AS (SELECT a + b AS sum FROM cte1)
            SELECT sum FROM cte2
        "#;
        let stmt = parse(query, Dialect::Ansi).unwrap();

        let graph = lineage("sum", &stmt, &schema, &config).unwrap();

        assert_eq!(graph.node.name, "sum");
        // Lineage is followed through both CTEs down to `t`.
        let source_names = graph.source_tables();
        assert!(source_names.contains(&"t".to_string()));
    }

    #[test]
    fn test_lineage_iterator() {
        let schema = test_schema();
        let config = test_config();
        let stmt = parse("SELECT a + b AS sum FROM t", Dialect::Ansi).unwrap();

        let graph = lineage("sum", &stmt, &schema, &config).unwrap();

        // Iteration yields at least one node, and the first one is the root.
        let visited: Vec<_> = graph.node.iter().collect();
        assert!(!visited.is_empty());
        assert_eq!(visited[0].name, "sum");
    }

    #[test]
    fn test_external_sources() {
        let schema = test_schema();

        let mut external = HashMap::new();
        external.insert("view1".to_string(), "SELECT a FROM t".to_string());
        let config = LineageConfig::new(Dialect::Ansi).with_sources(external);

        let outcome = lineage_sql("a", "SELECT a FROM view1", &schema, &config);

        // Accepts either success or a column-not-found error; any other error
        // fails the test.
        assert!(outcome.is_ok() || matches!(outcome, Err(LineageError::ColumnNotFound(_))));
    }

    #[test]
    fn test_qualified_column() {
        let schema = test_schema();
        let config = LineageConfig::new(Dialect::Ansi).with_trim_qualifiers(false);
        let stmt = parse("SELECT t.a FROM t", Dialect::Ansi).unwrap();

        let graph = lineage("a", &stmt, &schema, &config).unwrap();

        // With trim_qualifiers=false the name may carry a qualifier, so this
        // only checks that the column name is still part of it.
        assert!(graph.node.name.contains('a'));
    }
}