// streamling_plugin/ffi.rs

1#![allow(non_local_definitions)]
2//! This module defines the FFI interface for the plugin system, including types and traits.
3
4use crate::CheckpointEpoch;
5use abi_stable::StableAbi;
6use abi_stable::derive_macro_reexports::NonExhaustive;
7use abi_stable::external_types::crossbeam_channel::{RReceiver, RSender};
8use abi_stable::external_types::parking_lot::mutex::RMutex;
9use abi_stable::std_types::{RArc, RHashMap, RNone, RSome, RString};
10use arrow::array::{Array, ArrayRef, RecordBatch, StructArray, make_array};
11use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
12use arrow::ffi::{FFI_ArrowSchema, from_ffi, to_ffi};
13use arrow_data::ffi::FFI_ArrowArray;
14use crossbeam_channel::TrySendError;
15use datafusion::common::DataFusionError;
16use std::collections::HashMap;
17use std::fmt;
18use std::sync::Arc;
19use std::time::Duration;
20use tracing::{debug, warn};
21
22#[repr(C)]
23#[derive(StableAbi, Debug)]
24pub struct PluginOptions(RHashMap<RString, RString>);
25
26impl PluginOptions {
27    pub fn new(options: HashMap<String, String>) -> Self {
28        PluginOptions(
29            options
30                .into_iter()
31                .map(|(k, v)| (RString::from(k), RString::from(v)))
32                .collect(),
33        )
34    }
35
36    pub fn as_rust(&self) -> HashMap<String, String> {
37        self.0
38            .iter()
39            .map(|t| (t.0.to_string(), t.1.to_string()))
40            .collect()
41    }
42}
43
44/// Logging configuration for the plugin.
45#[repr(u8)]
46#[derive(StableAbi, Debug, Clone)]
47pub enum PluginLogging {
48    Plain,
49    Json,
50}
51
52impl PluginLogging {
53    pub fn initialize_logging(&self) {
54        if tracing::dispatcher::has_been_set() {
55            return;
56        }
57
58        use tracing_subscriber::layer::SubscriberExt;
59        use tracing_subscriber::util::SubscriberInitExt;
60
61        let env_filter = tracing_subscriber::EnvFilter::from_default_env();
62        let init_result = match self {
63            PluginLogging::Json => tracing_subscriber::registry()
64                .with(
65                    tracing_subscriber::fmt::layer()
66                        .with_writer(std::io::stderr)
67                        .fmt_fields(tracing_subscriber::fmt::format::JsonFields::new())
68                        .event_format(streamling_common::logging::FlatJsonFormat),
69                )
70                .with(env_filter)
71                .try_init(),
72            PluginLogging::Plain => tracing_subscriber::registry()
73                .with(
74                    tracing_subscriber::fmt::layer()
75                        .with_writer(std::io::stderr)
76                        .with_thread_ids(true)
77                        .with_thread_names(true),
78                )
79                .with(env_filter)
80                .try_init(),
81        };
82        if init_result.is_err() {
83            eprintln!("Logger already initialized; skipping plugin logging setup.");
84        }
85    }
86}
87
88/// Custom wrapper for FFI_ArrowSchema
89/// DataFusion's wrapper (WrappedSchema) looks almost the same, but since FFI_ArrowSchema
90/// doesn't implement Sync, we need to add a mutex to allow concurrent access
91#[repr(C)]
92#[derive(StableAbi)]
93pub struct SafeArrowSchema {
94    #[sabi(unsafe_opaque_field)]
95    pub schema: RArc<RMutex<FFI_ArrowSchema>>,
96}
97
98impl SafeArrowSchema {
99    pub fn new(schema: FFI_ArrowSchema) -> Self {
100        SafeArrowSchema {
101            schema: RArc::new(RMutex::new(schema)),
102        }
103    }
104}
105
106impl fmt::Debug for SafeArrowSchema {
107    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108        // Doing a *non-blocking* lock keeps Debug printing from hanging.
109        match self.schema.try_lock() {
110            RSome(guard) => f
111                .debug_struct("SafeArrowSchema")
112                // Delegate to FFI_ArrowSchema’s Debug
113                .field("schema", &*guard)
114                .finish(),
115            RNone => f
116                .debug_struct("SafeArrowSchema")
117                .field("schema", &"<locked>")
118                .finish(),
119        }
120    }
121}
122
123impl From<SchemaRef> for SafeArrowSchema {
124    fn from(value: SchemaRef) -> Self {
125        SafeArrowSchema::new(FFI_ArrowSchema::try_from(value.as_ref()).unwrap())
126    }
127}
128
129impl From<SafeArrowSchema> for SchemaRef {
130    fn from(value: SafeArrowSchema) -> Self {
131        let schema = value.schema.lock();
132        Arc::new(Schema::try_from(&*schema).unwrap())
133    }
134}
135
136impl From<DataType> for SafeArrowSchema {
137    fn from(value: DataType) -> Self {
138        let field = Field::new("_", value, true);
139        SafeArrowSchema::new(FFI_ArrowSchema::try_from(&field).unwrap())
140    }
141}
142
143impl From<SafeArrowSchema> for DataType {
144    fn from(value: SafeArrowSchema) -> Self {
145        let schema = value.schema.lock();
146        let field = Field::try_from(&*schema).unwrap();
147        field.data_type().clone()
148    }
149}
150
151/// A single Arrow array (column) transported across the FFI boundary.
152#[repr(C)]
153#[derive(StableAbi)]
154pub struct SafeArrowColumn {
155    #[sabi(unsafe_opaque_field)]
156    pub array: FFI_ArrowArray,
157    #[sabi(unsafe_opaque_field)]
158    pub field: RArc<RMutex<FFI_ArrowSchema>>,
159}
160
161/// A UDF argument transported across the FFI boundary, preserving scalar vs array semantics.
162///
163/// When `is_scalar` is true, `column` holds exactly one element representing a constant value.
164/// The host encodes `ColumnarValue::Scalar` as a length-1 array; the plugin reconstructs the
165/// scalar without needing to broadcast it to `N` rows first.
166#[repr(C)]
167#[derive(StableAbi)]
168pub struct SafeUdfArg {
169    pub column: SafeArrowColumn,
170    pub is_scalar: bool,
171}
172
173impl From<ArrayRef> for SafeArrowColumn {
174    fn from(value: ArrayRef) -> Self {
175        let field = Field::new("_", value.data_type().clone(), true);
176        let ffi_schema = FFI_ArrowSchema::try_from(&field).unwrap();
177        let (ffi_array, _) = to_ffi(&value.to_data()).unwrap();
178        SafeArrowColumn {
179            array: ffi_array,
180            field: RArc::new(RMutex::new(ffi_schema)),
181        }
182    }
183}
184
185impl From<SafeArrowColumn> for ArrayRef {
186    fn from(value: SafeArrowColumn) -> Self {
187        let schema = value.field.lock();
188        let array_data = unsafe { from_ffi(value.array, &schema).unwrap() };
189        make_array(array_data)
190    }
191}
192
193#[repr(C)]
194#[derive(StableAbi, Debug)]
195pub struct SafeArrowArray {
196    #[sabi(unsafe_opaque_field)]
197    pub array: FFI_ArrowArray,
198    pub schema: SafeArrowSchema,
199}
200
201impl From<SafeArrowArray> for RecordBatch {
202    fn from(value: SafeArrowArray) -> Self {
203        let schema = value.schema.schema.lock();
204        let array_data = unsafe {
205            from_ffi(value.array, &schema)
206                .map_err(DataFusionError::from)
207                .unwrap()
208        };
209        let array = make_array(array_data);
210        let struct_array = array
211            .as_any()
212            .downcast_ref::<StructArray>()
213            .ok_or(DataFusionError::Execution(
214                "Unexpected array type during record batch collection in FFI_RecordBatchStream"
215                    .to_string(),
216            ))
217            .unwrap();
218
219        struct_array.into()
220    }
221}
222
223impl From<RecordBatch> for SafeArrowArray {
224    fn from(value: RecordBatch) -> Self {
225        let schema: SafeArrowSchema = value.schema().into();
226
227        let struct_array = StructArray::from(value);
228        let (array, _) = to_ffi(&struct_array.into_data()).unwrap();
229
230        SafeArrowArray { array, schema }
231    }
232}
233
234#[repr(C)]
235#[derive(StableAbi, Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
236pub struct PluginCheckpointEpoch(pub u64);
237
238impl From<PluginCheckpointEpoch> for CheckpointEpoch {
239    fn from(value: PluginCheckpointEpoch) -> Self {
240        CheckpointEpoch(value.0)
241    }
242}
243
244#[repr(u8)]
245#[derive(StableAbi, Debug)]
246#[sabi(kind(WithNonExhaustive(
247    size = [usize;12],
248    traits(Debug),
249    assert_nonexhaustive(PluginMetric),
250)))]
251#[non_exhaustive]
252pub enum PluginMetric {
253    Count {
254        name: RString,
255        value: u64,
256        tags: RHashMap<RString, RString>,
257    },
258    Gauge {
259        name: RString,
260        value: u64,
261        tags: RHashMap<RString, RString>,
262    },
263    Time {
264        name: RString,
265        duration_ms: u64,
266        tags: RHashMap<RString, RString>,
267    },
268}
269
270#[repr(C)]
271#[derive(StableAbi, Clone, Debug)]
272pub struct PluginMetricsRecorder {
273    sender: RSender<PluginMetric_NE>,
274}
275
276impl PluginMetricsRecorder {
277    pub fn new(sender: RSender<PluginMetric_NE>) -> Self {
278        PluginMetricsRecorder { sender }
279    }
280
281    pub fn record_count(&self, name: &str, value: u64) {
282        self.dispatch_metric(PluginMetric::Count {
283            name: RString::from(name),
284            value,
285            tags: Default::default(),
286        });
287    }
288
289    pub fn record_count_w_tags(&self, name: &str, value: u64, tags: Vec<(&str, &str)>) {
290        let tags = tags
291            .into_iter()
292            .map(|(k, v)| (RString::from(k), RString::from(v)))
293            .collect();
294        self.dispatch_metric(PluginMetric::Count {
295            name: RString::from(name),
296            value,
297            tags,
298        });
299    }
300
301    pub fn record_latency(&self, name: &str, duration: Duration) {
302        self.dispatch_metric(PluginMetric::Time {
303            name: RString::from(name),
304            duration_ms: duration.as_millis() as u64,
305            tags: Default::default(),
306        });
307    }
308
309    pub fn record_latency_w_tags(&self, name: &str, duration: Duration, tags: Vec<(&str, &str)>) {
310        let tags = tags
311            .into_iter()
312            .map(|(k, v)| (RString::from(k), RString::from(v)))
313            .collect();
314        self.dispatch_metric(PluginMetric::Time {
315            name: RString::from(name),
316            duration_ms: duration.as_millis() as u64,
317            tags,
318        });
319    }
320
321    pub fn record_gauge(&self, name: &str, value: u64) {
322        self.dispatch_metric(PluginMetric::Gauge {
323            name: RString::from(name),
324            value,
325            tags: Default::default(),
326        });
327    }
328
329    pub fn record_gauge_w_tags(&self, name: &str, value: u64, tags: Vec<(&str, &str)>) {
330        let tags = tags
331            .into_iter()
332            .map(|(k, v)| (RString::from(k), RString::from(v)))
333            .collect();
334        self.dispatch_metric(PluginMetric::Gauge {
335            name: RString::from(name),
336            value,
337            tags,
338        });
339    }
340
341    pub fn dispatch_metric(&self, metric: PluginMetric) {
342        match self.sender.try_send(NonExhaustive::new(metric)) {
343            Ok(_) => {
344                debug!("Successfully dispatched plugin metrics")
345            }
346            Err(e) => {
347                warn!("Encountered error dispatching metrics. Error: {}", e);
348            }
349        }
350    }
351}
352
353#[repr(C)]
354#[derive(StableAbi, Clone, Debug)]
355pub struct PluginChannel {
356    pub sender: RSender<PluginMsg_NE>,
357    pub receiver: RReceiver<PluginMsg_NE>,
358}
359
360impl PluginChannel {
361    pub fn new(channels: (RSender<PluginMsg_NE>, RReceiver<PluginMsg_NE>)) -> Self {
362        let (sender, receiver) = channels;
363        PluginChannel { sender, receiver }
364    }
365
366    pub async fn send_with_retry<CreatePayloadFn>(
367        &self,
368        runtime: &crate::r#async::PluginAsyncRuntimeObj,
369        op_name: &str,
370        create_payload: CreatePayloadFn,
371    ) -> Result<(), crate::api::PluginError>
372    where
373        CreatePayloadFn: Fn() -> PluginMsg_NE,
374    {
375        self.send_with_retry_callback(
376            runtime,
377            op_name,
378            create_payload,
379            None::<fn() -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>>>,
380            Duration::from_millis(50),
381        )
382        .await
383    }
384
385    /// Send a message with retry logic if the channel is full.
386    /// Uses try_send to avoid blocking the thread, and retries with a delay.
387    ///
388    /// # Arguments
389    /// * `runtime` - The async runtime for sleeping between retries
390    /// * `op_name` - Name used in error messages and logging
391    /// * `create_payload` - The function to create the payload to send. This is called every send attempt.
392    /// * `on_retry` - Optional callback executed on each retry attempt. Return false to stop retrying.
393    pub async fn send_with_retry_callback<CreatePayloadFn, OnRetryFn>(
394        &self,
395        runtime: &crate::r#async::PluginAsyncRuntimeObj,
396        op_name: &str,
397        create_payload: CreatePayloadFn,
398        on_retry: Option<OnRetryFn>,
399        retry_delay: Duration,
400    ) -> Result<(), crate::api::PluginError>
401    where
402        CreatePayloadFn: Fn() -> PluginMsg_NE,
403        OnRetryFn: Fn() -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>>,
404    {
405        loop {
406            match self.sender.try_send(create_payload()) {
407                Ok(_) => return Ok(()),
408                Err(TrySendError::Full(_)) => {
409                    // Execute optional retry callback and check if we should continue
410                    if let Some(ref callback) = on_retry
411                        && !callback().await
412                    {
413                        return Err(crate::api::PluginError::Execution(format!(
414                            "{} retry callback returned false, stopping retries",
415                            op_name
416                        )));
417                    }
418
419                    runtime.sleep(retry_delay.into()).await;
420                }
421                Err(TrySendError::Disconnected(_)) => {
422                    return Err(crate::api::PluginError::Execution(format!(
423                        "{} output channel disconnected",
424                        op_name
425                    )));
426                }
427            }
428        }
429    }
430}
431
432#[repr(C)]
433#[derive(StableAbi, Clone, Debug)]
434pub struct PluginMetricsChannel {
435    pub sender: RSender<PluginMetric_NE>,
436    pub receiver: RReceiver<PluginMetric_NE>,
437}
438
439impl PluginMetricsChannel {
440    pub fn new(channels: (RSender<PluginMetric_NE>, RReceiver<PluginMetric_NE>)) -> Self {
441        let (sender, receiver) = channels;
442        PluginMetricsChannel { sender, receiver }
443    }
444}
445
446#[repr(C)]
447#[derive(StableAbi, Clone, Debug)]
448pub struct PluginChannels {
449    pub input: PluginChannel,
450    pub output: PluginChannel,
451    pub metrics: PluginMetricsChannel,
452}
453
454// Largest variant is NextBatch (13 words); 5 spare words in [usize; 18] for future growth
455#[repr(u8)]
456#[derive(StableAbi, Debug)]
457#[sabi(kind(WithNonExhaustive(
458    size = [usize;18],
459    traits(Debug),
460    assert_nonexhaustive(PluginMsg),
461)))]
462#[non_exhaustive]
463pub enum PluginMsg {
464    Init,
465    NextBatch { data: SafeArrowArray },
466    CheckpointMarker { epoch: PluginCheckpointEpoch },
467    CheckpointAck { epoch: PluginCheckpointEpoch },
468    CheckpointFinalizer { epoch: PluginCheckpointEpoch },
469    Terminate,
470    Topology { config: RString },
471    Error { message: RString },
472}