1use crate::error::{ConsumerError, ConsumerResult};
4use crate::ffi_bindings::{
5 FfiBuffer, FfiPluginHandle, PluginCallFn, PluginCallRawFn, PluginFreeBufferFn,
6 PluginGetRejectedCountFn, PluginGetStateFn, PluginSetLogLevelFn, PluginShutdownFn, RbResponse,
7 RbResponseFreeFn,
8};
9use libloading::Library;
10use rustbridge_core::{LifecycleState, LogLevel, PluginError};
11use rustbridge_transport::ResponseEnvelope;
12use serde::{Serialize, de::DeserializeOwned};
13use std::ffi::CString;
14use std::sync::atomic::{AtomicBool, Ordering};
15
/// A plugin living in a dynamically loaded native library, driven through the C ABI entry
/// points stored below.
pub struct NativePlugin {
    // Never read after construction (hence the `dead_code` allow). It is kept so the shared
    // library stays loaded for as long as the function pointers below can be called. Fields
    // drop in declaration order, after `Drop::drop` has already run.
    #[allow(dead_code)]
    library: Library,

    // Opaque handle passed as the first argument to every plugin entry point.
    handle: FfiPluginHandle,

    // Set once by `shutdown()`. Repeated shutdowns become no-ops, and `state()` reports
    // `Stopped` without calling into the plugin.
    shutdown: AtomicBool,

    // Plugin entry points, presumably resolved from `library` by the loader.
    call_fn: PluginCallFn,
    // Binary-transport entry point; `None` when the library does not export `plugin_call_raw`.
    call_raw_fn: Option<PluginCallRawFn>,
    shutdown_fn: PluginShutdownFn,
    get_state_fn: PluginGetStateFn,
    get_rejected_count_fn: PluginGetRejectedCountFn,
    set_log_level_fn: PluginSetLogLevelFn,
    // Releases buffers returned by `call_fn`.
    free_buffer_fn: PluginFreeBufferFn,
    // Releases responses returned by `call_raw_fn`; `None` when `rb_response_free` is absent.
    rb_response_free_fn: Option<RbResponseFreeFn>,
}
40
41impl NativePlugin {
42 #[allow(clippy::too_many_arguments)]
51 pub(crate) unsafe fn new(
52 library: Library,
53 handle: FfiPluginHandle,
54 call_fn: PluginCallFn,
55 call_raw_fn: Option<PluginCallRawFn>,
56 shutdown_fn: PluginShutdownFn,
57 get_state_fn: PluginGetStateFn,
58 get_rejected_count_fn: PluginGetRejectedCountFn,
59 set_log_level_fn: PluginSetLogLevelFn,
60 free_buffer_fn: PluginFreeBufferFn,
61 rb_response_free_fn: Option<RbResponseFreeFn>,
62 ) -> Self {
63 Self {
64 library,
65 handle,
66 shutdown: AtomicBool::new(false),
67 call_fn,
68 call_raw_fn,
69 shutdown_fn,
70 get_state_fn,
71 get_rejected_count_fn,
72 set_log_level_fn,
73 free_buffer_fn,
74 rb_response_free_fn,
75 }
76 }
77
78 pub fn call(&self, type_tag: &str, request: &str) -> ConsumerResult<String> {
95 self.ensure_active()?;
96
97 let type_tag_cstr =
98 CString::new(type_tag).map_err(|e| ConsumerError::InvalidResponse(e.to_string()))?;
99
100 let request_bytes = request.as_bytes();
101
102 let buffer: FfiBuffer = unsafe {
104 (self.call_fn)(
105 self.handle,
106 type_tag_cstr.as_ptr(),
107 request_bytes.as_ptr(),
108 request_bytes.len(),
109 )
110 };
111
112 let result = self.process_buffer(&buffer);
114
115 unsafe {
117 let mut buffer = buffer;
118 (self.free_buffer_fn)(&mut buffer);
119 }
120
121 result
122 }
123
124 pub fn call_typed<Req, Res>(&self, type_tag: &str, request: &Req) -> ConsumerResult<Res>
149 where
150 Req: Serialize,
151 Res: DeserializeOwned,
152 {
153 let request_json = serde_json::to_string(request)?;
154 let response_json = self.call(type_tag, &request_json)?;
155 let response: Res = serde_json::from_str(&response_json)?;
156 Ok(response)
157 }
158
159 pub fn call_raw(&self, message_id: u32, request: &[u8]) -> ConsumerResult<Vec<u8>> {
176 self.ensure_active()?;
177
178 let call_raw_fn = self.call_raw_fn.ok_or_else(|| {
179 ConsumerError::MissingSymbol("plugin_call_raw (binary transport not available)".into())
180 })?;
181
182 let rb_response_free_fn = self.rb_response_free_fn.ok_or_else(|| {
183 ConsumerError::MissingSymbol("rb_response_free (binary transport not available)".into())
184 })?;
185
186 let response: RbResponse = unsafe {
188 call_raw_fn(
189 self.handle,
190 message_id,
191 request.as_ptr() as *const std::ffi::c_void,
192 request.len(),
193 )
194 };
195
196 let result = if response.is_error() {
198 let error_msg = if response.data.is_null() {
200 "Unknown error".to_string()
201 } else {
202 let slice = unsafe { response.as_slice() };
203 String::from_utf8_lossy(slice).into_owned()
204 };
205 Err(ConsumerError::CallFailed(PluginError::from_code(
206 response.error_code,
207 error_msg,
208 )))
209 } else {
210 let data = unsafe { response.as_slice().to_vec() };
212 Ok(data)
213 };
214
215 unsafe {
217 let mut response = response;
218 rb_response_free_fn(&mut response);
219 }
220
221 result
222 }
223
224 pub fn state(&self) -> LifecycleState {
226 if self.shutdown.load(Ordering::SeqCst) {
230 return LifecycleState::Stopped;
231 }
232
233 let state_code = unsafe { (self.get_state_fn)(self.handle) };
235 state_from_u8(state_code)
236 }
237
238 pub fn rejected_request_count(&self) -> u64 {
240 unsafe { (self.get_rejected_count_fn)(self.handle) }
242 }
243
244 pub fn has_binary_transport(&self) -> bool {
246 self.call_raw_fn.is_some() && self.rb_response_free_fn.is_some()
247 }
248
249 pub fn set_log_level(&self, level: LogLevel) {
251 unsafe { (self.set_log_level_fn)(self.handle, level as u8) }
253 }
254
255 pub fn shutdown(&self) -> ConsumerResult<()> {
260 if self
262 .shutdown
263 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
264 .is_err()
265 {
266 return Ok(());
267 }
268
269 let success = unsafe { (self.shutdown_fn)(self.handle) };
271
272 if success {
273 Ok(())
274 } else {
275 Err(ConsumerError::InitFailed(
276 "plugin shutdown returned false".to_string(),
277 ))
278 }
279 }
280
281 fn ensure_active(&self) -> ConsumerResult<()> {
283 let state = self.state();
284 if state.can_handle_requests() {
285 Ok(())
286 } else {
287 Err(ConsumerError::NotActive(state))
288 }
289 }
290
291 fn process_buffer(&self, buffer: &FfiBuffer) -> ConsumerResult<String> {
293 if buffer.is_error() {
294 let error_msg = if buffer.is_empty() {
296 "Unknown error".to_string()
297 } else {
298 let slice = unsafe { buffer.as_slice() };
299 String::from_utf8_lossy(slice).into_owned()
300 };
301 return Err(ConsumerError::CallFailed(PluginError::from_code(
302 buffer.error_code,
303 error_msg,
304 )));
305 }
306
307 let data = unsafe { buffer.as_slice() };
309
310 let envelope: ResponseEnvelope = serde_json::from_slice(data)
312 .map_err(|e| ConsumerError::InvalidResponse(e.to_string()))?;
313
314 if envelope.is_success() {
315 match envelope.payload {
317 Some(payload) => Ok(serde_json::to_string(&payload)?),
318 None => Ok("null".to_string()),
319 }
320 } else {
321 let code = envelope.error_code.unwrap_or(11);
322 let message = envelope.error_message.unwrap_or_default();
323 Err(ConsumerError::CallFailed(PluginError::from_code(
324 code, message,
325 )))
326 }
327 }
328}
329
330impl Drop for NativePlugin {
331 fn drop(&mut self) {
332 let _ = self.shutdown();
334 }
335}
336
// SAFETY: NOTE(review) — `NativePlugin` stores the plugin handle (presumably a raw pointer) and
// C function pointers, so the compiler cannot derive `Send`. This impl asserts that the plugin's
// exported entry points may be called from whichever thread owns the value; confirm this
// against the plugin ABI contract.
unsafe impl Send for NativePlugin {}

// SAFETY: NOTE(review) — all public methods take `&self`, and the only Rust-side mutable state
// is the `AtomicBool` shutdown flag. Sharing across threads therefore relies on the plugin's
// entry points being safe to call concurrently with the same handle; confirm against the
// plugin ABI contract.
unsafe impl Sync for NativePlugin {}
345
346fn state_from_u8(code: u8) -> LifecycleState {
348 match code {
349 0 => LifecycleState::Installed,
350 1 => LifecycleState::Starting,
351 2 => LifecycleState::Active,
352 3 => LifecycleState::Stopping,
353 4 => LifecycleState::Stopped,
354 5 => LifecycleState::Failed,
355 _ => LifecycleState::Failed, }
357}
358
#[cfg(test)]
mod tests {
    // Test names use `___` to separate subject, scenario and expectation,
    // which the `non_snake_case` lint would otherwise flag.
    #![allow(non_snake_case)]

    use super::*;

    #[test]
    fn state_from_u8___valid_codes___returns_correct_state() {
        let cases = [
            (0u8, LifecycleState::Installed),
            (1u8, LifecycleState::Starting),
            (2u8, LifecycleState::Active),
            (3u8, LifecycleState::Stopping),
            (4u8, LifecycleState::Stopped),
            (5u8, LifecycleState::Failed),
        ];

        for (code, expected) in &cases {
            assert_eq!(&state_from_u8(*code), expected);
        }
    }

    #[test]
    fn state_from_u8___invalid_code___returns_failed() {
        for code in [255u8, 100u8].iter().copied() {
            assert_eq!(state_from_u8(code), LifecycleState::Failed);
        }
    }
}
380}