1#![allow(non_local_definitions)]
2
3pub mod api;
4pub mod r#async;
5mod dispatch;
6pub mod ffi;
7
8use crate::api::PluginStateBackendFactory;
9pub use crate::api::{
10 CheckpointEpoch, PluginError, PluginStateBackend, PreprocessorPlugin, SideOutputPlugin,
11 SinkPlugin, SourcePlugin, TransformPlugin,
12};
13use crate::r#async::PluginAsyncRuntimeObj;
14pub use crate::dispatch::{
15 PreprocessorPluginDispatcher, SinkPluginDispatcher, SourcePluginDispatcher,
16 TransformPluginDispatcher,
17};
18use crate::ffi::PluginMetricsRecorder;
19pub use crate::ffi::SafeArrowSchema;
20pub use crate::ffi::{
21 PluginChannel, PluginChannels, PluginCheckpointEpoch, PluginLogging, PluginMsg, PluginOptions,
22 SafeArrowColumn, SafeUdfArg,
23};
24use abi_stable::std_types::{RHashMap, RNone, ROption, RResult, RSome, RString, RVec};
25use abi_stable::traits::IntoReprC;
26use abi_stable::{
27 StableAbi, declare_root_module_statics,
28 library::{LibraryError, RootModule},
29 package_version_strings,
30 sabi_types::VersionStrings,
31};
32use arrow::array::ArrayRef;
33use arrow::datatypes::{Field, SchemaRef};
34use async_ffi::{FfiFuture, FutureExt};
35use datafusion::common::ScalarValue;
36use datafusion::logical_expr::{
37 ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, TypeSignature,
38};
39use std::collections::HashMap;
40use std::sync::Arc;
41pub use streamling_plugin_derive::*;
42pub use streamling_state::{StateKey, StateOperatorBackend};
43use tracing::{error, info};
44
/// A key/value label describing a created plugin instance.
///
/// Labels are collected from the plugin after creation and returned to the host inside
/// [`PluginResult::labels`]. The type is `#[repr(C)]` + `StableAbi` so it can cross the
/// dynamic-library boundary, which is why the fields are [`RString`] rather than `String`.
#[repr(C)]
#[derive(StableAbi, Debug, Clone)]
pub struct PluginLabel {
    /// Label name.
    pub key: RString,
    /// Label value.
    pub value: RString,
}
57
58impl PluginLabel {
59 pub fn new(key: impl Into<RString>, value: impl Into<RString>) -> Self {
67 PluginLabel {
68 key: key.into(),
69 value: value.into(),
70 }
71 }
72}
73
/// Everything a plugin generator hands back to the host after creating a plugin instance.
#[repr(C)]
#[derive(StableAbi)]
pub struct PluginResult {
    /// Future representing the plugin's running work. Its output type allows a failure to be
    /// reported as `RErr(message)`; `ROk(())` signals completion.
    pub execution_future: FfiFuture<RResult<(), RString>>,
    /// Schema of the data the plugin emits. The source and transform generators fill this in
    /// from the plugin's `output_schema()`; the sink and preprocessor generators pass `RNone`.
    pub output_schema: ROption<SafeArrowSchema>,
    /// Labels reported by the plugin instance (empty unless set via `with_labels`).
    pub labels: RVec<PluginLabel>,
}
89
/// Per-channel sizing for a plugin's input, output and metrics channels.
///
/// NOTE(review): the values look like channel capacities (bounded-queue sizes), judging by
/// the name; confirm against the host code that builds `PluginChannels`.
#[repr(C)]
#[derive(StableAbi, Debug, Clone, Copy)]
pub struct PluginChannelCaps {
    /// Size setting for the input channel.
    pub input: u32,
    /// Size setting for the output channel.
    pub output: u32,
    /// Size setting for the metrics channel.
    pub metrics: u32,
}
97
/// Configuration a plugin library reports to the host from `PluginModule::init`.
#[repr(C)]
#[derive(StableAbi, Debug)]
pub struct PluginRuntimeConfiguration {
    /// Plugin identifiers declared by the library.
    /// NOTE(review): presumably the ids later passed as `plugin_id` to `create` — confirm.
    pub plugin_ids: RVec<RString>,
    /// Default channel sizing, keyed by string.
    /// NOTE(review): the key is presumably a plugin id from `plugin_ids` — confirm with host.
    pub default_channel_caps: RHashMap<RString, PluginChannelCaps>,
}
105
106impl PluginResult {
107 pub fn new(
108 execution_future: FfiFuture<RResult<(), RString>>,
109 output_schema: ROption<SafeArrowSchema>,
110 ) -> Self {
111 PluginResult {
112 execution_future,
113 output_schema,
114 labels: RVec::new(),
115 }
116 }
117
118 pub fn with_labels(mut self, labels: Vec<PluginLabel>) -> Self {
129 self.labels = labels.into();
130 self
131 }
132}
133
/// Error returned across the FFI boundary when a plugin, or one of its UDF descriptors,
/// cannot be set up.
#[repr(u8)]
#[derive(StableAbi, Debug)]
pub enum PluginInitializationError {
    /// The requested functionality is not provided.
    /// NOTE(review): meaning inferred from the name; no producer is visible in this file.
    NotImplemented,
    /// Invalid setup, with a message. In this file it also carries panics raised by plugin
    /// factories, `output_schema()` failures, and unsupported UDF signatures/return types.
    Configuration(RString),
    /// Failure while executing setup work, with a message.
    /// NOTE(review): no producer is visible in this file — confirm intended use.
    Execution(RString),
}
141
/// Settings from which a [`PluginStateBackendFactory`] is built for one plugin instance.
#[repr(C)]
#[derive(StableAbi, Debug)]
pub struct PluginStateBackendConfig {
    /// Namespace of the application the plugin belongs to.
    pub application_namespace: RString,
    /// Reference name of the plugin within the application.
    /// NOTE(review): presumably used to scope the plugin's state — confirm.
    pub plugin_reference_name: RString,
    /// Backend configuration in serialized text form; the format is not visible here.
    pub serialized_config: RString,
}
155
156impl PluginStateBackendConfig {
157 pub fn new(
158 application_namespace: String,
159 plugin_reference_name: String,
160 serialized_config: String,
161 ) -> Self {
162 PluginStateBackendConfig {
163 application_namespace: application_namespace.into_c(),
164 plugin_reference_name: plugin_reference_name.into_c(),
165 serialized_config: serialized_config.into_c(),
166 }
167 }
168}
169
170pub trait IntoSourcePluginResult {
174 fn into_source_result(self) -> Result<Arc<dyn SourcePlugin>, PluginInitializationError>;
175}
176
177impl<T: SourcePlugin + 'static> IntoSourcePluginResult for T {
178 fn into_source_result(self) -> Result<Arc<dyn SourcePlugin>, PluginInitializationError> {
179 Ok(Arc::new(self))
180 }
181}
182
183impl<T: SourcePlugin + 'static> IntoSourcePluginResult for Result<T, PluginInitializationError> {
184 fn into_source_result(self) -> Result<Arc<dyn SourcePlugin>, PluginInitializationError> {
185 self.map(|s| Arc::new(s) as Arc<dyn SourcePlugin>)
186 }
187}
188
189pub trait IntoTransformPluginResult {
193 fn into_transform_result(self) -> Result<Arc<dyn TransformPlugin>, PluginInitializationError>;
194}
195
196impl<T: TransformPlugin + 'static> IntoTransformPluginResult for T {
197 fn into_transform_result(self) -> Result<Arc<dyn TransformPlugin>, PluginInitializationError> {
198 Ok(Arc::new(self))
199 }
200}
201
202impl<T: TransformPlugin + 'static> IntoTransformPluginResult
203 for Result<T, PluginInitializationError>
204{
205 fn into_transform_result(self) -> Result<Arc<dyn TransformPlugin>, PluginInitializationError> {
206 self.map(|t| Arc::new(t) as Arc<dyn TransformPlugin>)
207 }
208}
209
210pub trait IntoSinkPluginResult {
214 fn into_sink_result(self) -> Result<Arc<dyn SinkPlugin>, PluginInitializationError>;
215}
216
217impl<T: SinkPlugin + 'static> IntoSinkPluginResult for T {
218 fn into_sink_result(self) -> Result<Arc<dyn SinkPlugin>, PluginInitializationError> {
219 Ok(Arc::new(self))
220 }
221}
222
223impl<T: SinkPlugin + 'static> IntoSinkPluginResult for Result<T, PluginInitializationError> {
224 fn into_sink_result(self) -> Result<Arc<dyn SinkPlugin>, PluginInitializationError> {
225 self.map(|s| Arc::new(s) as Arc<dyn SinkPlugin>)
226 }
227}
228
/// Turns a payload caught by `std::panic::catch_unwind` into a readable message.
///
/// A formatted `panic!` carries a `String` and a literal `panic!` carries a `&'static str`;
/// any other payload type yields a generic fallback message.
fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
    match payload.downcast::<String>() {
        Ok(message) => *message,
        Err(payload) => match payload.downcast_ref::<&str>() {
            Some(message) => (*message).to_owned(),
            None => String::from("unknown panic during plugin creation"),
        },
    }
}
238
239pub fn source_generator<F>(
240 id: RString,
241 factory: F,
242 options: PluginOptions,
243 runtime: PluginAsyncRuntimeObj,
244 state_backend_config: PluginStateBackendConfig,
245 message_channels: PluginChannels,
246) -> RResult<PluginResult, PluginInitializationError>
247where
248 F: FnOnce(
249 PluginAsyncRuntimeObj,
250 PluginStateBackendFactory,
251 PluginMetricsRecorder,
252 HashMap<String, String>,
253 ) -> Result<Arc<dyn SourcePlugin>, PluginInitializationError>,
254{
255 info!("Creating {} with options: {:?}", id, options);
256
257 let state_backend_factory = PluginStateBackendFactory::new(state_backend_config);
258 let metrics_recorder = PluginMetricsRecorder::new(message_channels.metrics.sender.clone());
259 let source = match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
262 factory(
263 runtime.clone(),
264 state_backend_factory,
265 metrics_recorder,
266 options.as_rust(),
267 )
268 })) {
269 Ok(Ok(source)) => source,
270 Ok(Err(e)) => return Err(e).into_c(),
271 Err(panic_payload) => {
272 return Err(PluginInitializationError::Configuration(RString::from(
273 panic_payload_to_string(panic_payload),
274 )))
275 .into_c();
276 }
277 };
278 let labels = source.labels();
279 let output_schema = match source.output_schema() {
280 Ok(schema) => schema,
281 Err(e) => {
282 return RResult::RErr(PluginInitializationError::Configuration(RString::from(
283 e.to_string(),
284 )));
285 }
286 };
287 let dispatcher = SourcePluginDispatcher::new(message_channels, source);
288
289 let rt = runtime.clone();
290 let worker = async move {
291 match dispatcher.start(rt).await {
292 Ok(()) => (),
293 Err(e) => {
294 error!("Plugin error {}: {:?}", id, e);
295 panic!("Plugin error {}: {:?}", id, e);
296 }
297 }
298 }
299 .into_ffi();
300
301 let spawned = runtime.spawn(worker);
302
303 let dispatcher_future = async move {
304 spawned.await;
305 RResult::ROk(())
306 }
307 .into_ffi();
308
309 Ok(PluginResult::new(dispatcher_future, RSome(output_schema.into())).with_labels(labels))
310 .into_c()
311}
312
313pub fn transform_generator<F>(
314 id: RString,
315 factory: F,
316 input_schema: SafeArrowSchema,
317 options: PluginOptions,
318 runtime: PluginAsyncRuntimeObj,
319 state_backend_config: PluginStateBackendConfig,
320 message_channels: PluginChannels,
321) -> RResult<PluginResult, PluginInitializationError>
322where
323 F: FnOnce(
324 SchemaRef,
325 PluginAsyncRuntimeObj,
326 PluginStateBackendFactory,
327 PluginMetricsRecorder,
328 HashMap<String, String>,
329 ) -> Result<Arc<dyn TransformPlugin>, PluginInitializationError>,
330{
331 info!("Creating {} with options: {:?}", id, options);
332
333 let state_backend_factory = PluginStateBackendFactory::new(state_backend_config);
334 let metrics_recorder = PluginMetricsRecorder::new(message_channels.metrics.sender.clone());
335
336 let transform = match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
338 factory(
339 input_schema.into(),
340 runtime.clone(),
341 state_backend_factory,
342 metrics_recorder,
343 options.as_rust(),
344 )
345 })) {
346 Ok(Ok(transform)) => transform,
347 Ok(Err(e)) => return Err(e).into_c(),
348 Err(panic_payload) => {
349 return Err(PluginInitializationError::Configuration(RString::from(
350 panic_payload_to_string(panic_payload),
351 )))
352 .into_c();
353 }
354 };
355 let labels = transform.labels();
356 let output_schema = match transform.output_schema() {
357 Ok(schema) => schema,
358 Err(e) => {
359 return RResult::RErr(PluginInitializationError::Configuration(RString::from(
360 e.to_string(),
361 )));
362 }
363 };
364
365 let dispatcher = TransformPluginDispatcher::new(message_channels, transform);
366
367 let rt = runtime.clone();
368 let worker = async move {
369 match dispatcher.start(rt).await {
370 Ok(()) => (),
371 Err(e) => {
372 error!("Plugin error {}: {:?}", id, e);
373 panic!("Plugin error {}: {:?}", id, e);
374 }
375 }
376 }
377 .into_ffi();
378
379 let spawned = runtime.spawn(worker);
380
381 let dispatcher_future = async move {
382 spawned.await;
383 RResult::ROk(())
384 }
385 .into_ffi();
386
387 Ok(PluginResult::new(dispatcher_future, RSome(output_schema.into())).with_labels(labels))
388 .into_c()
389}
390
391pub fn sink_generator<F>(
392 id: RString,
393 factory: F,
394 input_schema: SafeArrowSchema,
395 options: PluginOptions,
396 runtime: PluginAsyncRuntimeObj,
397 state_backend_config: PluginStateBackendConfig,
398 message_channels: PluginChannels,
399) -> RResult<PluginResult, PluginInitializationError>
400where
401 F: FnOnce(
402 SchemaRef,
403 PluginAsyncRuntimeObj,
404 PluginStateBackendFactory,
405 PluginMetricsRecorder,
406 HashMap<String, String>,
407 ) -> Result<Arc<dyn SinkPlugin>, PluginInitializationError>,
408{
409 info!("Creating {} with options: {:?}", id, options);
410
411 let state_backend_factory = PluginStateBackendFactory::new(state_backend_config);
412 let metrics_recorder = PluginMetricsRecorder::new(message_channels.metrics.sender.clone());
413 let sink = match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
415 factory(
416 input_schema.into(),
417 runtime.clone(),
418 state_backend_factory,
419 metrics_recorder,
420 options.as_rust(),
421 )
422 })) {
423 Ok(Ok(sink)) => sink,
424 Ok(Err(e)) => return Err(e).into_c(),
425 Err(panic_payload) => {
426 return Err(PluginInitializationError::Configuration(RString::from(
427 panic_payload_to_string(panic_payload),
428 )))
429 .into_c();
430 }
431 };
432 let labels = sink.labels();
433
434 let rt = runtime.clone();
435 let worker = async move {
436 let dispatcher = SinkPluginDispatcher::new(message_channels, sink);
437 match dispatcher.start(rt).await {
438 Ok(()) => (),
439 Err(e) => {
440 error!("Plugin error {}: {:?}", id, e);
441 panic!("Plugin error {}: {:?}", id, e);
442 }
443 }
444 }
445 .into_ffi();
446
447 let spawned = runtime.spawn(worker);
448
449 let dispatcher_future = async move {
450 spawned.await;
451 RResult::ROk(())
452 }
453 .into_ffi();
454
455 Ok(PluginResult::new(dispatcher_future, RNone).with_labels(labels)).into_c()
456}
457
458pub fn preprocessor_generator<F>(
459 id: RString,
460 factory: F,
461 options: PluginOptions,
462 runtime: PluginAsyncRuntimeObj,
463 message_channels: PluginChannels,
464) -> RResult<PluginResult, PluginInitializationError>
465where
466 F: FnOnce(
467 HashMap<String, String>,
468 ) -> Result<Arc<dyn PreprocessorPlugin>, PluginInitializationError>,
469{
470 info!("Creating preprocessor {} with options: {:?}", id, options);
471
472 let preprocessor = match factory(options.as_rust()) {
473 Ok(p) => p,
474 Err(e) => return Err(e).into_c(),
475 };
476 let dispatcher = PreprocessorPluginDispatcher::new(message_channels, preprocessor);
477
478 let worker_error: Arc<std::sync::OnceLock<String>> = Arc::new(std::sync::OnceLock::new());
479 let worker_error_writer = worker_error.clone();
480
481 let worker = async move {
482 if let Err(e) = dispatcher.start().await {
483 error!("Preprocessor plugin error {}: {:?}", id, e);
484 let _ = worker_error_writer.set(format!("{e}"));
485 }
486 }
487 .into_ffi();
488
489 let spawned = runtime.spawn(worker);
490
491 let dispatcher_future = async move {
492 spawned.await;
493 match worker_error.get() {
494 Some(err) => RResult::RErr(RString::from(err.clone())),
495 None => RResult::ROk(()),
496 }
497 }
498 .into_ffi();
499
500 Ok(PluginResult::new(dispatcher_future, RNone)).into_c()
501}
502
/// FFI-safe description of a scalar UDF exported by a plugin library.
///
/// Built from a DataFusion `ScalarUDFImpl` by [`build_plugin_udf_descriptor`].
#[repr(C)]
#[derive(StableAbi)]
pub struct PluginUdfDescriptor {
    /// Primary function name.
    pub name: RString,
    /// Additional names the function is registered under.
    pub aliases: RVec<RString>,
    /// Accepted argument lists. Each inner vector is one exact sequence of argument types,
    /// with one `SafeArrowSchema` per argument.
    pub type_signatures: RVec<RVec<SafeArrowSchema>>,
    /// The function's result type.
    pub return_type: SafeArrowSchema,
    /// `true` when the UDF's volatility is `Immutable`.
    pub deterministic: bool,
    /// Entry point that evaluates the UDF for `number_rows` rows of FFI-safe arguments,
    /// returning the result column or an error message.
    pub invoke: extern "C" fn(
        args: RVec<SafeUdfArg>,
        number_rows: usize,
    ) -> RResult<SafeArrowColumn, RString>,
}
517
518pub fn invoke_plugin_udf(
524 instance: &dyn ScalarUDFImpl,
525 args: RVec<SafeUdfArg>,
526 number_rows: usize,
527) -> RResult<SafeArrowColumn, RString> {
528 let columnar_args: Vec<ColumnarValue> = args
529 .into_iter()
530 .map(|arg| {
531 let array = ArrayRef::from(arg.column);
532 if arg.is_scalar {
533 match ScalarValue::try_from_array(array.as_ref(), 0) {
534 Ok(s) => ColumnarValue::Scalar(s),
535 Err(_) => ColumnarValue::Array(array),
536 }
537 } else {
538 ColumnarValue::Array(array)
539 }
540 })
541 .collect();
542
543 let arg_fields: Vec<Arc<Field>> = columnar_args
544 .iter()
545 .map(|cv| match cv {
546 ColumnarValue::Array(a) => Arc::new(Field::new("_", a.data_type().clone(), true)),
547 ColumnarValue::Scalar(s) => Arc::new(Field::new("_", s.data_type(), true)),
548 })
549 .collect();
550
551 let scalar_storage: Vec<Option<ScalarValue>> = columnar_args
552 .iter()
553 .map(|cv| match cv {
554 ColumnarValue::Scalar(s) => Some(s.clone()),
555 ColumnarValue::Array(_) => None,
556 })
557 .collect();
558 let scalar_argument_refs: Vec<Option<&ScalarValue>> =
559 scalar_storage.iter().map(|opt| opt.as_ref()).collect();
560
561 let return_field = match instance.return_type(&[]) {
562 Ok(dt) => Arc::new(Field::new("result", dt, true)),
563 Err(_) => {
564 let fallback_args = ReturnFieldArgs {
565 arg_fields: &arg_fields,
566 scalar_arguments: &scalar_argument_refs,
567 };
568 match instance.return_field_from_args(fallback_args) {
569 Ok(field) => field,
570 Err(e) => return RResult::RErr(RString::from(e.to_string())),
571 }
572 }
573 };
574
575 let scalar_args = ScalarFunctionArgs {
576 args: columnar_args,
577 arg_fields,
578 number_rows,
579 return_field,
580 };
581
582 match instance.invoke_with_args(scalar_args) {
583 Ok(ColumnarValue::Array(arr)) => RResult::ROk(SafeArrowColumn::from(arr)),
584 Ok(ColumnarValue::Scalar(s)) => match s.to_array_of_size(number_rows.max(1)) {
585 Ok(arr) => RResult::ROk(SafeArrowColumn::from(arr)),
586 Err(e) => RResult::RErr(RString::from(e.to_string())),
587 },
588 Err(e) => RResult::RErr(RString::from(e.to_string())),
589 }
590}
591
/// FFI-safe description of a side output exported by a plugin library.
///
/// Every entry point reports failure as `RErr(message)`.
/// NOTE(review): the call order (initialize, then process_batch repeatedly, then shutdown)
/// is inferred from the field names; confirm against the host implementation.
#[repr(C)]
#[derive(StableAbi, Clone)]
pub struct PluginSideOutputDescriptor {
    /// Identifier of the side output.
    pub id: RString,
    /// Sets up the side output for data of `schema` coming from `source_name`. It receives
    /// the side output's options and a metrics recorder.
    pub initialize: extern "C" fn(
        source_name: RString,
        schema: SafeArrowSchema,
        options: PluginOptions,
        metrics_recorder: PluginMetricsRecorder,
    ) -> RResult<(), RString>,
    /// Handles one batch of Arrow data originating from `source_name`.
    pub process_batch:
        extern "C" fn(source_name: RString, data: ffi::SafeArrowArray) -> RResult<(), RString>,
    /// Shuts the side output down.
    pub shutdown: extern "C" fn() -> RResult<(), RString>,
}
609
610pub fn build_plugin_udf_descriptor(
613 instance: &dyn ScalarUDFImpl,
614 invoke: extern "C" fn(
615 args: RVec<SafeUdfArg>,
616 number_rows: usize,
617 ) -> RResult<SafeArrowColumn, RString>,
618) -> Result<PluginUdfDescriptor, PluginInitializationError> {
619 let sig = instance.signature();
620 let type_signatures: RVec<RVec<SafeArrowSchema>> = match &sig.type_signature {
621 TypeSignature::Exact(types) => {
622 let converted: RVec<SafeArrowSchema> = types
623 .iter()
624 .map(|dt| SafeArrowSchema::from(dt.clone()))
625 .collect();
626 RVec::from(vec![converted])
627 }
628 TypeSignature::OneOf(variants) => {
629 let mut converted = Vec::with_capacity(variants.len());
630 for variant in variants {
631 match variant {
632 TypeSignature::Exact(types) => {
633 converted.push(
634 types
635 .iter()
636 .map(|dt| SafeArrowSchema::from(dt.clone()))
637 .collect(),
638 );
639 }
640 other => {
641 return Err(PluginInitializationError::Configuration(RString::from(
642 format!(
643 "Plugin UDFs only support Exact type signatures within OneOf, got: {:?}",
644 other
645 ),
646 )));
647 }
648 }
649 }
650 RVec::from(converted)
651 }
652 other => {
653 return Err(PluginInitializationError::Configuration(RString::from(
654 format!(
655 "Plugin UDFs only support Exact and OneOf type signatures, got: {:?}",
656 other
657 ),
658 )));
659 }
660 };
661 let return_type = match instance.return_type(&[]) {
662 Ok(dt) => dt,
663 Err(_) => {
664 let fallback_args = ReturnFieldArgs {
665 arg_fields: &[],
666 scalar_arguments: &[],
667 };
668 instance
669 .return_field_from_args(fallback_args)
670 .map_err(|e| {
671 PluginInitializationError::Configuration(RString::from(format!(
672 "UDF must implement either return_type or return_field_from_args: {e}"
673 )))
674 })?
675 .data_type()
676 .clone()
677 }
678 };
679 let deterministic = sig.volatility == datafusion::logical_expr::Volatility::Immutable;
680 let aliases: RVec<RString> = instance
681 .aliases()
682 .iter()
683 .map(|a| RString::from(a.as_str()))
684 .collect();
685 Ok(PluginUdfDescriptor {
686 name: RString::from(instance.name()),
687 aliases,
688 type_signatures,
689 return_type: SafeArrowSchema::from(return_type),
690 deterministic,
691 invoke,
692 })
693}
694
#[cfg(test)]
mod safe_udf_arg_tests {
    use super::*;
    use arrow::array::StringArray;
    use std::sync::Arc;

    /// Wraps `values` in a UTF-8 column and packages it as an FFI UDF argument.
    fn utf8_udf_arg(values: Vec<&str>, is_scalar: bool) -> SafeUdfArg {
        let column: ArrayRef = Arc::new(StringArray::from(values));
        SafeUdfArg {
            column: SafeArrowColumn::from(column),
            is_scalar,
        }
    }

    /// A scalar-flagged argument converts back to the `ScalarValue` held in row 0.
    #[test]
    fn scalar_arg_round_trips_to_columnar_scalar() {
        let ffi_arg = utf8_udf_arg(vec!["url"], true);
        assert!(ffi_arg.is_scalar);

        let restored = ArrayRef::from(ffi_arg.column);
        let scalar = ScalarValue::try_from_array(restored.as_ref(), 0).unwrap();
        assert_eq!(scalar, ScalarValue::Utf8(Some("url".to_string())));
    }

    /// An array argument keeps every row after going through the FFI representation.
    #[test]
    fn array_arg_round_trips_to_columnar_array() {
        let ffi_arg = utf8_udf_arg(vec!["a", "b", "c"], false);
        let restored = ArrayRef::from(ffi_arg.column);
        assert_eq!(restored.len(), 3);
    }
}
727
/// Root module exported by a plugin dynamic library and loaded through `abi_stable`.
///
/// This is an `abi_stable` prefix type that the host accesses through [`PluginModuleRef`].
/// Fields up to and including `create` (marked `#[sabi(last_prefix_field)]`) are guaranteed
/// to be present in every compatible library. Fields declared after it may be missing in
/// libraries built against an older layout, so the host must treat them as optional.
#[repr(C)]
#[derive(StableAbi)]
#[sabi(kind(Prefix(prefix_ref = PluginModuleRef)))]
pub struct PluginModule {
    /// Initializes the library with a `PluginLogging` value and returns its
    /// [`PluginRuntimeConfiguration`].
    pub init: extern "C" fn(
        logging: PluginLogging,
    ) -> RResult<PluginRuntimeConfiguration, PluginInitializationError>,

    /// Creates the plugin identified by `plugin_id` and returns its [`PluginResult`].
    ///
    /// `input_schema` is optional. NOTE(review): presumably `RNone` for plugins without an
    /// input (sources, preprocessors) — confirm with the host.
    #[sabi(last_prefix_field)]
    pub create: extern "C" fn(
        plugin_id: RString,
        input_schema: ROption<SafeArrowSchema>,
        options: PluginOptions,
        runtime: PluginAsyncRuntimeObj,
        state_backend_config: PluginStateBackendConfig,
        message_channels: PluginChannels,
    ) -> RResult<PluginResult, PluginInitializationError>,

    /// Lists the scalar UDFs exported by this library.
    pub udf_descriptors:
        extern "C" fn() -> RResult<RVec<PluginUdfDescriptor>, PluginInitializationError>,

    /// Lists the side outputs exported by this library.
    pub side_output_descriptors:
        extern "C" fn() -> RResult<RVec<PluginSideOutputDescriptor>, PluginInitializationError>,
}
762
/// Lets `abi_stable` load [`PluginModuleRef`] as the root module of a plugin library.
impl RootModule for PluginModuleRef {
    // Generates the statics `abi_stable` needs to cache the loaded module.
    declare_root_module_statics! {PluginModuleRef}
    /// Base name of the plugin dynamic library; `abi_stable` derives the platform-specific
    /// file name from it.
    const BASE_NAME: &'static str = "streamling_plugin";
    /// Human-readable module name used by `abi_stable` (for example in load errors).
    const NAME: &'static str = "streamling_plugin";
    /// This crate's package version, used by `abi_stable` for version compatibility checks.
    const VERSION_STRINGS: VersionStrings = package_version_strings!();

    /// Post-load hook: no extra initialization is performed, so the module is returned as-is.
    fn initialization(self) -> Result<Self, LibraryError> {
        Ok(self)
    }
}
772}