Skip to main content

supertable_core/
merge.rs

1//! # SuperTable MERGE Operation
2//!
3//! This module provides the `MergeBuilder` for performing UPSERT (MERGE) operations.
4
5use crate::table::Table;
6use crate::transaction::Transaction;
7use anyhow::Result;
8use arrow::record_batch::RecordBatch;
9use datafusion::logical_expr::Expr;
10use futures::stream::BoxStream;
11
12/// Action to take during a MERGE operation.
13#[derive(Debug, Clone)]
14pub enum MergeAction {
15    /// Update existing rows with assignments.
16    Update(Vec<(String, Expr)>),
17    /// Delete existing rows.
18    Delete,
19    /// Insert new rows with assignments.
20    Insert(Vec<(String, Expr)>),
21}
22
23/// A single clause in a MERGE operation.
24#[derive(Debug, Clone)]
25pub struct MergeClause {
26    /// Optional condition for this clause.
27    pub condition: Option<Expr>,
28    /// The action to take if the condition matches.
29    pub action: MergeAction,
30}
31
32/// Builder for MERGE (UPSERT) operations.
33#[allow(unused)]
34pub struct MergeBuilder {
35    table: Table,
36    source: BoxStream<'static, Result<RecordBatch>>,
37    on_condition: Expr,
38    matched_clauses: Vec<MergeClause>,
39    not_matched_clauses: Vec<MergeClause>,
40}
41
42impl MergeBuilder {
43    pub fn new(
44        table: Table,
45        source: BoxStream<'static, Result<RecordBatch>>,
46        on_condition: Expr,
47    ) -> Self {
48        Self {
49            table,
50            source,
51            on_condition,
52            matched_clauses: Vec::new(),
53            not_matched_clauses: Vec::new(),
54        }
55    }
56
57    /// Adds a clause to be applied when the ON condition matches.
58    pub fn when_matched(mut self, condition: Option<Expr>, action: MergeAction) -> Self {
59        self.matched_clauses.push(MergeClause { condition, action });
60        self
61    }
62
63    /// Adds a clause to be applied when the ON condition does NOT match.
64    pub fn when_not_matched(mut self, condition: Option<Expr>, action: MergeAction) -> Self {
65        self.not_matched_clauses
66            .push(MergeClause { condition, action });
67        self
68    }
69
70    /// Executes the merge operation as a Copy-on-Write process.
71    pub async fn execute(self) -> Result<Transaction> {
72        use datafusion::prelude::*;
73        use futures::StreamExt;
74
75        // 1. Setup DataFusion Context
76        let ctx = SessionContext::new();
77
78        // 2. Register Source Table
79        // We collect source into memory for now (Prototype limitation)
80        let batches: Vec<Result<RecordBatch>> = self.source.collect().await;
81        let source_batches: Vec<RecordBatch> =
82            batches.into_iter().collect::<Result<Vec<RecordBatch>>>()?;
83
84        // Check if source is empty
85        if source_batches.is_empty() {
86            return Ok(self.table.new_transaction());
87        }
88
89        let source_schema = source_batches[0].schema();
90        // Use MemTable
91        let source_provider = datafusion::datasource::MemTable::try_new(
92            source_schema,
93            vec![source_batches.clone()], // Partitions
94        )?;
95        ctx.register_table("source", std::sync::Arc::new(source_provider))?;
96        let source_df = ctx.table("source").await?;
97
98        // 3. Register Target Table
99        // Use TableReader to load target data
100        let storage = self.table.storage.clone();
101        let reader = crate::reader::TableReader::new(storage.clone());
102        let snapshot = self
103            .table
104            .metadata
105            .current_snapshot()
106            .ok_or_else(|| anyhow::anyhow!("No snapshot"))?;
107        let (data_files, _) = snapshot.all_files(&storage).await?;
108
109        let mut target_batches = Vec::new();
110        for file in data_files {
111            let batches = reader.read_file(&file.file_path).await?;
112            target_batches.extend(batches);
113        }
114
115        // Register Target MemTable
116        let target_schema = if !target_batches.is_empty() {
117            target_batches[0].schema()
118        } else {
119            self.table.metadata.current_schema().to_arrow_schema_ref()
120        };
121
122        let target_provider =
123            datafusion::datasource::MemTable::try_new(target_schema, vec![target_batches])?;
124        ctx.register_table("target", std::sync::Arc::new(target_provider))?;
125        let target_df = ctx.table("target").await?;
126
127        // 4. Perform Join to identify Matched vs Not Matched
128        // logic: source LEFT JOIN target ON condition
129        // We need to alias tables to distinguish columns
130
131        // Simplified Logic for Prototype:
132        // 1. Identify PK column (assume field_id=1, "id")
133        let schema = self.table.metadata.current_schema();
134        let id_field = schema
135            .fields
136            .iter()
137            .find(|f| f.id == 1)
138            .ok_or_else(|| anyhow::anyhow!("PK not found"))?;
139        let id_col = &id_field.name;
140
141        // 2. Find IDs to Delete (Matched Rows)
142        // Join Source and Target on ID (assuming on_condition is ID equality for now, or evaluating generic condition)
143        // Let's assume on_condition is `source.id = target.id`
144
145        let join_df = source_df.join(
146            target_df,
147            datafusion::logical_expr::JoinType::Inner,
148            &[id_col],
149            &[id_col],
150            None,
151        )?; // Simplified: assumes join on ID column name
152        let matched_ids_df = join_df.select(vec![col(id_col)])?;
153        let matched_batches = matched_ids_df.collect().await?;
154
155        // Write Equality Deletes for matched IDs
156        let mut ids_to_delete = Vec::new();
157        for batch in &matched_batches {
158            if batch.num_columns() > 0 {
159                ids_to_delete.push(batch.column(0).clone());
160            }
161        }
162
163        let mut tx = self.table.new_transaction();
164
165        if !ids_to_delete.is_empty() {
166            // Concatenate and Write Delete File
167            let total_ids: Vec<&dyn arrow::array::Array> =
168                ids_to_delete.iter().map(|a| a.as_ref()).collect();
169            let combined_ids = arrow::compute::concat(&total_ids)?;
170
171            let del_schema = std::sync::Arc::new(arrow::datatypes::Schema::new(vec![
172                arrow::datatypes::Field::new(
173                    id_field.name.clone(),
174                    id_field.field_type.to_arrow_datatype(),
175                    false,
176                ),
177            ]));
178
179            let batch =
180                arrow::record_batch::RecordBatch::try_new(del_schema.clone(), vec![combined_ids])?;
181
182            let writer = crate::writer::TableWriter::new(
183                self.table.storage.clone(),
184                self.table.metadata.location.clone(),
185                del_schema,
186            );
187
188            let file_id = uuid::Uuid::new_v4().to_string();
189            let mut data_file = writer
190                .write_batch(&batch, &format!("delete-merge-{}", file_id))
191                .await?;
192            data_file.content = crate::manifest::FileContent::EqualityDeletes;
193            tx.add_file(data_file);
194        }
195
196        // 3. Write New Data (Updates + Inserts)
197        // For simple Upsert (Merge), we just write the whole Source batch as new data
198        // (Assuming we deleted the old versions of matched rows above)
199        // This covers both "Update" (Insert new version) and "Insert" (Insert new row).
200
201        let writer = crate::writer::TableWriter::new(
202            self.table.storage.clone(),
203            self.table.metadata.location.clone(),
204            self.table.metadata.current_schema().to_arrow_schema_ref(),
205        );
206
207        // Concatenate all source batches
208        // Should validate against table schema first? Yes.
209        // Assuming source matches table schema for now.
210        for batch in source_batches {
211            let file_id = uuid::Uuid::new_v4().to_string();
212            let data_file = writer.write_batch(&batch, &file_id).await?;
213            tx.add_file(data_file);
214        }
215
216        Ok(tx)
217    }
218}