Skip to main content

windjammer_runtime/
mock_function.rs

1//! Function mocking utilities
2//!
3//! Provides runtime support for mocking global functions.
4//!
5//! Note: Full function mocking requires unsafe code for runtime function
6//! replacement. This module provides a safe framework using function pointers.
7
8use std::cell::RefCell;
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex};
11
/// Shared handle to one registered mock, type-erased behind `dyn Any`.
type MockSlot = Arc<Mutex<Box<dyn std::any::Any + Send>>>;

thread_local! {
    /// Table of active mocks, keyed by the mocked function's name.
    ///
    /// Because this lives in `thread_local!`, every thread owns its own,
    /// initially empty table: a mock registered on one thread is not visible
    /// from any other thread.
    static FUNCTION_MOCKS: RefCell<HashMap<String, MockSlot>> = RefCell::new(HashMap::new());
}
16
17/// Mock a function with a closure
18///
19/// # Safety
20/// This is a safe version that requires explicit checking in the function being mocked.
21/// For true runtime function replacement, unsafe code would be needed.
22///
23/// # Example
24/// ```
25/// use windjammer_runtime::mock_function::{mock_function, is_mocked, clear_mock};
26///
27/// // Simple example showing mock registration
28/// mock_function("get_time", || 12345i64);
29///
30/// // Verify mock is registered
31/// assert!(is_mocked("get_time"));
32///
33/// // Clear the mock
34/// clear_mock("get_time");
35/// assert!(!is_mocked("get_time"));
36/// ```
37pub fn mock_function<F: 'static + Send>(name: &str, mock_fn: F) {
38    FUNCTION_MOCKS.with(|mocks| {
39        mocks
40            .borrow_mut()
41            .insert(name.to_string(), Arc::new(Mutex::new(Box::new(mock_fn))));
42    });
43}
44
45/// Get a mocked function
46pub fn get_mock<F>(name: &str) -> Option<F>
47where
48    F: Clone + 'static,
49{
50    FUNCTION_MOCKS.with(|mocks| {
51        mocks.borrow().get(name).and_then(|mock| {
52            let guard = mock.lock().unwrap();
53            let any_ref = &**guard;
54            // Try to downcast to the function type
55            (any_ref as &dyn std::any::Any).downcast_ref::<F>().cloned()
56        })
57    })
58}
59
60/// Check if a function is mocked
61pub fn is_mocked(name: &str) -> bool {
62    FUNCTION_MOCKS.with(|mocks| mocks.borrow().contains_key(name))
63}
64
65/// Clear a specific mock
66pub fn clear_mock(name: &str) {
67    FUNCTION_MOCKS.with(|mocks| {
68        mocks.borrow_mut().remove(name);
69    });
70}
71
72/// Clear all mocks
73pub fn clear_all_mocks() {
74    FUNCTION_MOCKS.with(|mocks| {
75        mocks.borrow_mut().clear();
76    });
77}
78
79/// Run a function with a temporary mock
80///
81/// Automatically restores the original function after execution.
82///
83/// # Example
84/// ```
85/// use windjammer_runtime::mock_function::with_mock;
86///
87/// fn get_time() -> i64 { 0 } // Simplified
88///
89/// with_mock("get_time", || 12345i64, || {
90///     // Mock is active here
91///     // let time = get_time();
92///     // assert_eq!(time, 12345);
93/// });
94/// // Mock is automatically cleared
95/// ```
96pub fn with_mock<F, R, T>(function_name: &str, mock_fn: F, test_fn: T) -> R
97where
98    F: 'static + Send,
99    T: FnOnce() -> R,
100{
101    mock_function(function_name, mock_fn);
102    let result = test_fn();
103    clear_mock(function_name);
104    result
105}
106
/// Call counter for mocked functions.
///
/// The registry only counts calls that are explicitly reported through
/// [`MockRegistry::record_call`]. It is not connected to the thread-local
/// mock table, so a mock implementation (or the code under test) has to
/// report its own calls.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct MockRegistry {
    mocks: HashMap<String, usize>, // Function name -> call count
}

impl MockRegistry {
    /// Create an empty registry with no recorded calls.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Record one call to `function_name`.
    pub fn record_call(&mut self, function_name: &str) {
        // Look up by `&str` first so repeated calls to an already-known name
        // do not allocate a fresh `String` just to find the existing entry.
        if let Some(count) = self.mocks.get_mut(function_name) {
            *count += 1;
        } else {
            self.mocks.insert(function_name.to_owned(), 1);
        }
    }

    /// Number of calls recorded for `function_name` (0 if it was never recorded).
    #[must_use]
    pub fn call_count(&self, function_name: &str) -> usize {
        self.mocks.get(function_name).copied().unwrap_or(0)
    }

    /// Return `true` if at least one call to `function_name` has been recorded.
    #[must_use]
    pub fn was_called(&self, function_name: &str) -> bool {
        self.call_count(function_name) > 0
    }

    /// Assert that `function_name` was recorded exactly `expected_count` times.
    ///
    /// # Panics
    ///
    /// Panics when the recorded count differs from `expected_count`. The
    /// message names the function, the expected count and the actual count.
    /// Thanks to `#[track_caller]`, the panic reports the caller's location
    /// (typically the test) rather than a line inside this helper.
    #[track_caller]
    pub fn verify_called_times(&self, function_name: &str, expected_count: usize) {
        let actual = self.call_count(function_name);
        assert!(
            actual == expected_count,
            "Expected {} to be called {} times, but it was called {} times",
            function_name,
            expected_count,
            actual
        );
    }

    /// Forget all recorded calls.
    pub fn reset(&mut self) {
        self.mocks.clear();
    }
}
151
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mock_function() {
        mock_function("test_func", || 100);
        assert!(is_mocked("test_func"));
        clear_mock("test_func");
        assert!(!is_mocked("test_func"));
    }

    #[test]
    fn test_is_mocked() {
        assert!(!is_mocked("nonexistent"));
        mock_function("exists", || 1);
        assert!(is_mocked("exists"));
        clear_mock("exists");
        assert!(!is_mocked("exists"));
    }

    #[test]
    fn test_clear_all_mocks() {
        mock_function("func1", || 1);
        mock_function("func2", || 2);
        assert!(is_mocked("func1"));
        assert!(is_mocked("func2"));

        clear_all_mocks();

        assert!(!is_mocked("func1"));
        assert!(!is_mocked("func2"));
    }

    #[test]
    fn test_with_mock_cleanup() {
        with_mock(
            "test_function",
            || 999,
            || {
                // Test body
            },
        );

        assert!(!is_mocked("test_function")); // Auto-cleared
    }

    #[test]
    fn test_with_mock_active_during_body_and_returns_result() {
        let value = with_mock(
            "scoped",
            || 1,
            || {
                // The mock must be visible while the body runs.
                assert!(is_mocked("scoped"));
                "done"
            },
        );
        assert_eq!(value, "done");
        assert!(!is_mocked("scoped"));
    }

    #[test]
    fn test_get_mock_round_trips_fn_pointer() {
        // Closure types are unnameable, so store the mock as a fn pointer.
        mock_function("fn_ptr", (|| 7i32) as fn() -> i32);
        let mocked: Option<fn() -> i32> = get_mock("fn_ptr");
        assert_eq!(mocked.map(|f| f()), Some(7));
        clear_mock("fn_ptr");
    }

    #[test]
    fn test_get_mock_type_mismatch_missing_and_overwrite() {
        mock_function("typed", 1u32);
        mock_function("typed", 5u32); // later registration replaces earlier one
        assert_eq!(get_mock::<u32>("typed"), Some(5));
        assert_eq!(get_mock::<i64>("typed"), None);
        assert_eq!(get_mock::<u32>("absent"), None);
        clear_mock("typed");
    }

    #[test]
    fn test_mocks_are_thread_local() {
        mock_function("local_only", || 0);
        let seen_elsewhere = std::thread::spawn(|| is_mocked("local_only"))
            .join()
            .unwrap();
        assert!(!seen_elsewhere);
        assert!(is_mocked("local_only"));
        clear_mock("local_only");
    }

    #[test]
    fn test_mock_registry() {
        let mut registry = MockRegistry::new();

        registry.record_call("func1");
        registry.record_call("func1");
        registry.record_call("func2");

        assert_eq!(registry.call_count("func1"), 2);
        assert_eq!(registry.call_count("func2"), 1);
        assert!(registry.was_called("func1"));
        assert!(!registry.was_called("func3"));
    }

    #[test]
    fn test_verify_called_times_success() {
        let mut registry = MockRegistry::new();
        registry.record_call("func");
        registry.record_call("func");

        registry.verify_called_times("func", 2); // Should not panic
    }

    #[test]
    fn test_verify_called_times_zero_for_unknown() {
        let registry = MockRegistry::new();
        assert_eq!(registry.call_count("never"), 0);
        registry.verify_called_times("never", 0); // Should not panic
    }

    #[test]
    #[should_panic(expected = "Expected func to be called 3 times")]
    fn test_verify_called_times_failure() {
        let mut registry = MockRegistry::new();
        registry.record_call("func");
        registry.record_call("func");

        registry.verify_called_times("func", 3); // Should panic
    }

    #[test]
    fn test_registry_reset() {
        let mut registry = MockRegistry::new();
        registry.record_call("func");
        assert_eq!(registry.call_count("func"), 1);

        registry.reset();
        assert_eq!(registry.call_count("func"), 0);
    }
}