Skip to main content

s3rm_rs/callback/
filter_manager.rs

1//! Filter callback manager.
2//!
3//! Adapted from s3sync's `callback/filter_manager.rs`.
4//! Wraps a single `FilterCallback` trait object for use by the
5//! `UserDefinedFilter` pipeline stage.
6
7use std::fmt;
8use std::sync::Arc;
9
10use anyhow::Result;
11use tokio::sync::Mutex;
12
13use crate::types::S3Object;
14use crate::types::filter_callback::FilterCallback;
15
16/// Manages a registered filter callback.
17///
18/// Holds an optional `FilterCallback` behind `Arc<Mutex<...>>` so it can
19/// be shared across pipeline stages and called asynchronously.
20#[derive(Clone)]
21pub struct FilterManager {
22    callback: Option<Arc<Mutex<Box<dyn FilterCallback + Send + Sync>>>>,
23}
24
25impl Default for FilterManager {
26    fn default() -> Self {
27        Self::new()
28    }
29}
30
31impl FilterManager {
32    pub fn new() -> Self {
33        Self { callback: None }
34    }
35
36    /// Register a filter callback implementation.
37    pub fn register_callback<T: FilterCallback + Send + Sync + 'static>(&mut self, callback: T) {
38        self.callback = Some(Arc::new(Mutex::new(Box::new(callback))));
39    }
40
41    /// Returns true if a filter callback has been registered.
42    pub fn is_callback_registered(&self) -> bool {
43        self.callback.is_some()
44    }
45
46    /// Execute the registered filter callback on the given object.
47    ///
48    /// # Panics
49    /// Panics if no callback has been registered. Check `is_callback_registered()` first.
50    pub async fn execute_filter(&self, object: &S3Object) -> Result<bool> {
51        if let Some(callback) = &self.callback {
52            callback.lock().await.filter(object).await
53        } else {
54            panic!("Filter callback is not registered");
55        }
56    }
57}
58
59impl fmt::Debug for FilterManager {
60    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61        f.debug_struct("FilterManager")
62            .field("callback_registered", &self.callback.is_some())
63            .finish()
64    }
65}
66
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Stub callback that keeps every object.
    struct AlwaysTrueFilter;

    #[async_trait]
    impl FilterCallback for AlwaysTrueFilter {
        async fn filter(&mut self, _object: &S3Object) -> Result<bool> {
            Ok(true)
        }
    }

    /// Stub callback that drops every object.
    struct AlwaysFalseFilter;

    #[async_trait]
    impl FilterCallback for AlwaysFalseFilter {
        async fn filter(&mut self, _object: &S3Object) -> Result<bool> {
            Ok(false)
        }
    }

    /// Minimal unversioned object with key `"test"`, shared by the tests below.
    fn sample_object() -> S3Object {
        let inner = aws_sdk_s3::types::Object::builder().key("test").build();
        S3Object::NotVersioning(inner)
    }

    #[tokio::test]
    async fn new_manager_has_no_callback() {
        assert!(!FilterManager::new().is_callback_registered());
    }

    #[tokio::test]
    async fn default_manager_has_no_callback() {
        assert!(!FilterManager::default().is_callback_registered());
    }

    #[tokio::test]
    async fn register_and_execute_true_filter() {
        let mut manager = FilterManager::new();
        manager.register_callback(AlwaysTrueFilter);
        assert!(manager.is_callback_registered());

        let object = sample_object();
        let accepted = manager.execute_filter(&object).await.unwrap();
        assert!(accepted);
    }

    #[tokio::test]
    async fn register_and_execute_false_filter() {
        let mut manager = FilterManager::new();
        manager.register_callback(AlwaysFalseFilter);

        let object = sample_object();
        let accepted = manager.execute_filter(&object).await.unwrap();
        assert!(!accepted);
    }

    #[tokio::test]
    #[should_panic(expected = "Filter callback is not registered")]
    async fn execute_without_registration_panics() {
        let manager = FilterManager::new();
        let object = sample_object();
        let _ = manager.execute_filter(&object).await;
    }

    #[test]
    fn debug_format() {
        let rendered = format!("{:?}", FilterManager::new());
        assert!(rendered.contains("callback_registered: false"));
    }
}