1use std::pin::Pin;
55use std::sync::Arc;
56use std::task::{Context, Poll};
57
58use arrow_array::RecordBatch;
59use arrow_schema::SchemaRef;
60use datafusion::error::DataFusionError;
61use datafusion::execution::SendableRecordBatchStream;
62use datafusion::logical_expr::{BinaryExpr, Expr, Operator};
63use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
64use futures::stream::{self, Stream, StreamExt, TryStreamExt};
65
66type BatchStream = Pin<Box<dyn Stream<Item = Result<RecordBatch, DataFusionError>> + Send>>;
71
72use crate::errors::FnError;
73use crate::traits::catalog::CatalogTable;
74use crate::traits::storage::Storage;
75
/// Error code with which `Storage::read_batch` signals that it could not encode the
/// predicate it was given.
///
/// `StorageCatalogTable::scan` reacts to this code by retrying the read without a predicate.
pub const STORAGE_FILTER_UNENCODABLE: u32 = 0x711;
83
84pub struct StorageCatalogTable {
91 storage: Arc<dyn Storage>,
92 table: String,
93 schema: SchemaRef,
94}
95
96impl std::fmt::Debug for StorageCatalogTable {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 f.debug_struct("StorageCatalogTable")
99 .field("table", &self.table)
100 .field("schema", &self.schema)
101 .field("storage", &"<dyn Storage>")
102 .finish()
103 }
104}
105
106impl StorageCatalogTable {
107 #[must_use]
110 pub fn new(storage: Arc<dyn Storage>, table: String, schema: SchemaRef) -> Self {
111 Self {
112 storage,
113 table,
114 schema,
115 }
116 }
117
118 #[must_use]
120 pub fn storage(&self) -> &Arc<dyn Storage> {
121 &self.storage
122 }
123
124 #[must_use]
126 pub fn table(&self) -> &str {
127 &self.table
128 }
129}
130
131impl CatalogTable for StorageCatalogTable {
132 fn schema(&self) -> SchemaRef {
133 Arc::clone(&self.schema)
134 }
135
136 fn scan(
137 &self,
138 projection: Option<&[usize]>,
139 filters: &[Expr],
140 limit: Option<usize>,
141 ) -> Result<SendableRecordBatchStream, FnError> {
142 let storage = Arc::clone(&self.storage);
143 let table = self.table.clone();
144 let predicate = and_combine(filters);
145 let projection_owned: Option<Vec<usize>> = projection.map(<[usize]>::to_vec);
146
147 let output_schema: SchemaRef = match projection_owned.as_deref() {
150 Some(p) => project_schema(&self.schema, p),
151 None => Arc::clone(&self.schema),
152 };
153
154 let inner = stream::once(async move {
155 let res = storage.read_batch(&table, predicate.as_ref()).await;
156 match res {
157 Ok(s) => Ok(s),
158 Err(e) if e.code == STORAGE_FILTER_UNENCODABLE => {
161 storage.read_batch(&table, None).await.map_err(fn_err_to_df)
162 }
163 Err(e) => Err(fn_err_to_df(e)),
164 }
165 })
166 .map(|res| match res {
167 Ok(stream) => Ok(stream),
168 Err(e) => Err(e),
169 })
170 .try_flatten();
171
172 let projected = ProjectionAndLimitStream::new(inner.boxed(), projection_owned, limit);
173
174 Ok(Box::pin(RecordBatchStreamAdapter::new(
175 output_schema,
176 projected,
177 )))
178 }
179}
180
181fn and_combine(filters: &[Expr]) -> Option<Expr> {
185 let mut iter = filters.iter().cloned();
186 let first = iter.next()?;
187 Some(iter.fold(first, |acc, next| {
188 Expr::BinaryExpr(BinaryExpr::new(
189 Box::new(acc),
190 Operator::And,
191 Box::new(next),
192 ))
193 }))
194}
195
196fn project_schema(schema: &SchemaRef, projection: &[usize]) -> SchemaRef {
198 let fields: Vec<arrow_schema::Field> = projection
199 .iter()
200 .filter_map(|i| schema.fields().get(*i).map(|f| f.as_ref().clone()))
201 .collect();
202 Arc::new(arrow_schema::Schema::new(fields))
203}
204
205fn fn_err_to_df(e: FnError) -> DataFusionError {
207 DataFusionError::Execution(format!(
208 "plugin Storage::read_batch failed (code 0x{:x}): {}",
209 e.code, e.message
210 ))
211}
212
213struct ProjectionAndLimitStream {
218 inner: BatchStream,
219 projection: Option<Vec<usize>>,
220 remaining: Option<usize>,
221 done: bool,
222}
223
224impl ProjectionAndLimitStream {
225 fn new(inner: BatchStream, projection: Option<Vec<usize>>, limit: Option<usize>) -> Self {
226 Self {
227 inner,
228 projection,
229 remaining: limit,
230 done: false,
231 }
232 }
233
234 fn apply(&self, batch: RecordBatch) -> Result<RecordBatch, DataFusionError> {
235 let projected = if let Some(p) = self.projection.as_deref() {
236 batch
237 .project(p)
238 .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?
239 } else {
240 batch
241 };
242 Ok(projected)
243 }
244}
245
246impl Stream for ProjectionAndLimitStream {
247 type Item = Result<RecordBatch, DataFusionError>;
248
249 fn poll_next(
250 mut self: std::pin::Pin<&mut Self>,
251 cx: &mut Context<'_>,
252 ) -> Poll<Option<Self::Item>> {
253 if self.done {
254 return Poll::Ready(None);
255 }
256 match self.inner.poll_next_unpin(cx) {
257 Poll::Pending => Poll::Pending,
258 Poll::Ready(None) => {
259 self.done = true;
260 Poll::Ready(None)
261 }
262 Poll::Ready(Some(Err(e))) => {
263 self.done = true;
264 Poll::Ready(Some(Err(e)))
265 }
266 Poll::Ready(Some(Ok(batch))) => {
267 let projected = match self.apply(batch) {
268 Ok(b) => b,
269 Err(e) => {
270 self.done = true;
271 return Poll::Ready(Some(Err(e)));
272 }
273 };
274 let take = match self.remaining {
275 Some(n) if n <= projected.num_rows() => {
276 self.done = true;
277 n
278 }
279 Some(n) => {
280 self.remaining = Some(n - projected.num_rows());
281 projected.num_rows()
282 }
283 None => projected.num_rows(),
284 };
285 if take == projected.num_rows() {
286 Poll::Ready(Some(Ok(projected)))
287 } else {
288 Poll::Ready(Some(Ok(projected.slice(0, take))))
290 }
291 }
292 }
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use async_trait::async_trait;
    use datafusion::execution::SendableRecordBatchStream;
    use datafusion::logical_expr::{col, lit};
    use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
    use futures::stream::{self, StreamExt};
    use std::sync::Mutex;

    use crate::traits::storage::WriteHandle;

    /// In-memory, read-only `Storage` fixture.
    ///
    /// Every successful `read_batch` replays `batches` and records the predicate it was
    /// called with in `last_predicate`. With `fail_on_filter` set, any read that carries a
    /// predicate fails with `STORAGE_FILTER_UNENCODABLE` (and records nothing).
    struct StaticStorage {
        batches: Mutex<Vec<RecordBatch>>,
        schema: SchemaRef,
        last_predicate: Mutex<Option<Expr>>,
        fail_on_filter: bool,
    }

    #[async_trait]
    impl Storage for StaticStorage {
        async fn read_batch(
            &self,
            _table: &str,
            predicate: Option<&Expr>,
        ) -> Result<SendableRecordBatchStream, FnError> {
            if self.fail_on_filter && predicate.is_some() {
                return Err(FnError::new(STORAGE_FILTER_UNENCODABLE, "unencodable"));
            }
            *self.last_predicate.lock().expect("predicate mutex") = predicate.cloned();
            let batches = self.batches.lock().expect("batches mutex").clone();
            let schema = Arc::clone(&self.schema);
            let s = stream::iter(batches.into_iter().map(Ok));
            Ok(Box::pin(RecordBatchStreamAdapter::new(schema, s)))
        }

        async fn write_batch(
            &self,
            _table: &str,
            _batch: &RecordBatch,
        ) -> Result<WriteHandle, FnError> {
            Err(FnError::new(1, "read-only fixture"))
        }

        async fn list_tables(&self) -> Result<Vec<String>, FnError> {
            Ok(vec!["t".to_owned()])
        }

        async fn delete(&self, _table: &str, _predicate: &Expr) -> Result<u64, FnError> {
            Err(FnError::new(1, "read-only fixture"))
        }
    }

    /// Two-column schema shared by all fixtures: `id: Int64 (non-null)`, `name: Utf8`.
    fn fixture_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]))
    }

    /// Builds one batch with the fixture schema from parallel `ids` / `names` slices.
    fn fixture_batch(ids: &[i64], names: &[&str]) -> RecordBatch {
        let id_arr = Arc::new(Int64Array::from(ids.to_vec()));
        let name_arr = Arc::new(StringArray::from_iter(names.iter().map(|s| Some(*s))));
        RecordBatch::try_new(fixture_schema(), vec![id_arr, name_arr]).expect("fixture batch")
    }

    /// Creates a storage fixture that replays `batches`; the concrete handle is returned
    /// so tests can inspect `last_predicate` afterwards.
    fn fixture_storage(batches: Vec<RecordBatch>, fail_on_filter: bool) -> Arc<StaticStorage> {
        Arc::new(StaticStorage {
            batches: Mutex::new(batches),
            schema: fixture_schema(),
            last_predicate: Mutex::new(None),
            fail_on_filter,
        })
    }

    /// Wraps `storage` in the table under test.
    fn fixture_table(storage: &Arc<StaticStorage>) -> StorageCatalogTable {
        let dyn_storage: Arc<dyn Storage> = Arc::<StaticStorage>::clone(storage);
        StorageCatalogTable::new(dyn_storage, "people".to_owned(), fixture_schema())
    }

    /// Drains `stream`, panicking on any error item, and returns the total row count.
    async fn count_rows(mut stream: SendableRecordBatchStream) -> usize {
        let mut total = 0usize;
        while let Some(batch) = stream.next().await {
            total += batch.expect("batch").num_rows();
        }
        total
    }

    #[tokio::test]
    async fn full_scan_streams_all_rows() {
        let storage = fixture_storage(vec![fixture_batch(&[1, 2, 3], &["a", "b", "c"])], false);
        let table = fixture_table(&storage);

        let stream = table.scan(None, &[], None).expect("scan starts");
        assert_eq!(count_rows(stream).await, 3);
    }

    #[tokio::test]
    async fn limit_is_applied_client_side() {
        let storage = fixture_storage(vec![fixture_batch(&[1, 2, 3], &["a", "b", "c"])], false);
        let table = fixture_table(&storage);

        let stream = table.scan(None, &[], Some(2)).expect("scan starts");
        assert_eq!(count_rows(stream).await, 2);
    }

    #[tokio::test]
    async fn limit_spans_multiple_batches() {
        // The first batch is emitted whole, the second is truncated to the one row left.
        let storage = fixture_storage(
            vec![
                fixture_batch(&[1, 2], &["a", "b"]),
                fixture_batch(&[3, 4], &["c", "d"]),
            ],
            false,
        );
        let table = fixture_table(&storage);

        let stream = table.scan(None, &[], Some(3)).expect("scan starts");
        assert_eq!(count_rows(stream).await, 3);
    }

    #[tokio::test]
    async fn zero_limit_yields_no_rows() {
        let storage = fixture_storage(vec![fixture_batch(&[1, 2, 3], &["a", "b", "c"])], false);
        let table = fixture_table(&storage);

        let stream = table.scan(None, &[], Some(0)).expect("scan starts");
        assert_eq!(count_rows(stream).await, 0);
    }

    #[tokio::test]
    async fn projection_drops_columns() {
        let storage = fixture_storage(vec![fixture_batch(&[1, 2], &["a", "b"])], false);
        let table = fixture_table(&storage);

        let mut stream = table.scan(Some(&[0]), &[], None).expect("scan starts");
        let mut total_cols = 0usize;
        while let Some(b) = stream.next().await {
            total_cols = b.expect("batch").num_columns();
        }
        assert_eq!(total_cols, 1, "projection should drop name column");
    }

    #[tokio::test]
    async fn filters_are_and_combined_and_pushed_down() {
        let storage = fixture_storage(vec![fixture_batch(&[1, 2, 3], &["a", "b", "c"])], false);
        let table = fixture_table(&storage);

        let lower = col("id").gt(lit(1_i64));
        let upper = col("id").lt(lit(3_i64));
        let expected = Expr::BinaryExpr(BinaryExpr::new(
            Box::new(lower.clone()),
            Operator::And,
            Box::new(upper.clone()),
        ));

        // The fixture does not evaluate predicates, so all rows come back; the read is
        // lazy, so the stream must be drained before the recorded predicate is checked.
        let stream = table.scan(None, &[lower, upper], None).expect("scan starts");
        assert_eq!(count_rows(stream).await, 3);

        let seen = storage.last_predicate.lock().expect("predicate mutex").clone();
        assert_eq!(seen, Some(expected));
    }

    #[tokio::test]
    async fn unencodable_filter_falls_back_to_unfiltered() {
        let storage = fixture_storage(vec![fixture_batch(&[1, 2, 3], &["a", "b", "c"])], true);
        let table = fixture_table(&storage);

        let filter = col("id").eq(lit(2_i64));
        // Pre-seed the recorder so that `None` afterwards proves a successful read
        // happened without a predicate.
        *storage.last_predicate.lock().expect("predicate mutex") = Some(filter.clone());

        let stream = table.scan(None, &[filter], None).expect("scan starts");
        assert_eq!(count_rows(stream).await, 3);

        let seen = storage.last_predicate.lock().expect("predicate mutex").clone();
        assert!(seen.is_none(), "fallback read must not carry a predicate");
    }
}