1use serde::de::DeserializeOwned;
2use swift_rs::SRString;
3use tauri::{ipc::Channel, plugin::PluginApi, AppHandle, Manager, Runtime};
4
5use serde::Serialize;
6use serde_json::Value as JsonValue;
7
8use memoffset::offset_of;
9
10use std::{
11 collections::HashMap,
12 fmt,
13 sync::{mpsc::channel, Mutex, OnceLock},
14};
15
16use std::sync::atomic::{AtomicI32, Ordering};
17
18use std::sync::Arc;
19
20type PluginResponse = Result<serde_json::Value, serde_json::Value>;
21
22type PendingPluginCallHandler = Box<dyn FnOnce(PluginResponse) + Send + 'static>;
23
24static PENDING_PLUGIN_CALLS_ID: AtomicI32 = AtomicI32::new(0);
25static PENDING_PLUGIN_CALLS: OnceLock<Mutex<HashMap<i32, PendingPluginCallHandler>>> =
26 OnceLock::new();
27static CHANNELS: OnceLock<Mutex<HashMap<u32, Channel<serde_json::Value>>>> = OnceLock::new();
28
/// Error payload describing a rejected native plugin command.
///
/// `run_swift_plugin` deserializes the JSON value of a failed [`PluginResponse`] into this
/// type. The `Display` implementation renders it as `[code] - message`, omitting whichever
/// part is absent.
#[derive(Debug, thiserror::Error, Clone, serde::Deserialize)]
pub struct ErrorResponse<T = ()> {
    /// Optional error code.
    pub code: Option<String>,
    /// Optional human-readable error message.
    pub message: Option<String>,
    /// Additional error data, read from the same JSON object as `code` and `message`
    /// (the field is flattened rather than nested under a `data` key).
    #[serde(flatten)]
    pub data: T,
}
40
41impl<T> fmt::Display for ErrorResponse<T> {
42 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43 if let Some(code) = &self.code {
44 write!(f, "[{code}]")?;
45 if self.message.is_some() {
46 write!(f, " - ")?;
47 }
48 }
49 if let Some(message) = &self.message {
50 write!(f, "{message}")?;
51 }
52 Ok(())
53 }
54}
55
/// Errors that can occur while registering a Swift plugin or invoking one of its commands.
#[derive(Debug, thiserror::Error)]
pub enum PluginInvokeError {
    /// Scheduling work on the webview failed (returned when `with_webview` reports an error).
    #[error("the webview is unreachable")]
    UnreachableWebview,
    /// The native side rejected the command; the rejection payload is carried as an
    /// [`ErrorResponse`] and displayed verbatim.
    #[error(transparent)]
    InvokeRejected(#[from] ErrorResponse),
    /// The response (or the rejection payload) could not be deserialized into the expected type.
    #[error("failed to deserialize response: {0}")]
    CannotDeserializeResponse(serde_json::Error),
    /// The command payload could not be converted into a JSON value.
    #[error("failed to serialize payload: {0}")]
    CannotSerializePayload(serde_json::Error),
}
72
/// Layout stand-in used to compute field offsets inside a `PluginApi<R, C>`.
///
/// It is never constructed in this section: `PluginApiExt::name` and
/// `PluginApiExt::raw_config` only pass it to `offset_of!` to locate the `name` and
/// `raw_config` fields. The field order must therefore not be changed.
// NOTE(review): presumably this mirrors the private fields of tauri's `PluginApi`. The
// computed offsets are only meaningful if that type really has this exact in-memory
// layout — nothing here enforces it; confirm against the tauri version in use.
#[repr(C)]
pub struct PluginApiRef<R: Runtime, C: DeserializeOwned> {
    handle: AppHandle<R>,
    name: &'static str,
    raw_config: Arc<JsonValue>,
    config: C,
}
80
81#[repr(C)]
82pub struct PluginApiExt<R: Runtime, C: DeserializeOwned>(PluginApi<R, C>);
83
84impl<R: Runtime, C: DeserializeOwned> From<PluginApi<R, C>> for PluginApiExt<R, C> {
85 fn from(api: PluginApi<R, C>) -> Self {
86 PluginApiExt(api)
87 }
88}
89
90impl<R: Runtime, C: DeserializeOwned> PluginApiExt<R, C> {
91 pub fn app(&self) -> &AppHandle<R> {
93 self.0.app()
94 }
95
96 pub fn name(&self) -> &str {
98 let self_ptr = &self.0 as *const PluginApi<R, C> as *const u8;
99 let offset = offset_of!(PluginApiRef<R, C>, name);
100
101 let name: &'static str = unsafe {
102 let field_ptr = self_ptr.add(offset) as *const &'static str;
103 *field_ptr
104 };
105 name
106 }
107
108 pub fn raw_config(&self) -> Arc<JsonValue> {
110 let self_ptr = self as *const PluginApiExt<R, C> as *const u8;
111 let offset = offset_of!(PluginApiRef<R, C>, raw_config);
112
113 let rc_ptr = unsafe { self_ptr.add(offset) as *const Arc<JsonValue> };
114 let rc_ref = unsafe { &*rc_ptr };
115 rc_ref.clone()
116 }
117}
118
#[cfg(any(target_os = "macos", target_os = "ios"))]
impl<R: Runtime, C: DeserializeOwned> PluginApiExt<R, C> {
    /// Registers a Swift plugin with the native side and returns a handle for invoking its
    /// commands.
    ///
    /// `init_fn` is called to obtain the plugin instance pointer that is handed to
    /// `swift_register_plugin`, together with the plugin name and the plugin configuration
    /// serialized as a JSON document.
    ///
    /// When the app has at least one webview, registration is scheduled through
    /// `with_webview` on the first one returned by `webviews()` and this call blocks until
    /// the closure has run; the platform webview pointer is passed along. Without any webview
    /// the plugin is registered immediately with a null webview pointer.
    ///
    /// # Errors
    ///
    /// Returns [`PluginInvokeError::UnreachableWebview`] when `with_webview` fails.
    ///
    /// # Panics
    ///
    /// Panics if the scheduled closure is dropped without signalling completion.
    pub fn register_swift_plugin(
        &self,
        init_fn: unsafe fn() -> *const std::ffi::c_void,
    ) -> Result<PluginHandleExt<R>, PluginInvokeError> {
        let name = self.name().to_string();
        // Serialize the configuration exactly once so both branches hand the same JSON
        // document to Swift. Serializing the resulting `String` a second time would send a
        // JSON string literal (quoted and escaped) instead of the configuration itself.
        let config = serde_json::to_string(&*self.raw_config())
            .expect("a `serde_json::Value` always serializes to a string");

        if let Some(webview) = self.app().webviews().values().next() {
            let (tx, rx) = channel();
            let plugin_name = name.clone();
            webview
                .with_webview(move |w| {
                    // SAFETY: NOTE(review) — the contracts of `init_fn` and
                    // `swift_register_plugin` are defined outside this file; the string
                    // arguments stay alive for the duration of the call.
                    unsafe {
                        crate::macos::swift_register_plugin(
                            &SRString::from(plugin_name.as_str()),
                            init_fn(),
                            &config.as_str().into(),
                            w.inner() as _,
                        )
                    };
                    // Wake up the caller blocked on `rx.recv()` below.
                    tx.send(()).unwrap();
                })
                .map_err(|_| PluginInvokeError::UnreachableWebview)?;
            rx.recv().unwrap();
        } else {
            // SAFETY: NOTE(review) — same foreign contracts as above; no webview exists, so
            // a null webview pointer is passed.
            unsafe {
                crate::macos::swift_register_plugin(
                    &SRString::from(name.as_str()),
                    init_fn(),
                    &config.as_str().into(),
                    std::ptr::null(),
                )
            };
        }

        Ok(PluginHandleExt {
            name,
            handle: self.app().clone(),
        })
    }
}
166
167pub struct PluginHandleExt<R: Runtime> {
168 name: String,
169 handle: AppHandle<R>,
170}
171
172impl<R: Runtime> PluginHandleExt<R> {
173 pub fn run_swift_plugin<T: DeserializeOwned>(
175 &self,
176 command: impl AsRef<str>,
177 payload: impl Serialize,
178 ) -> Result<T, PluginInvokeError> {
179 let (tx, rx) = channel();
180
181 run_command(
182 &self.name,
183 &self.handle,
184 command,
185 serde_json::to_value(payload).map_err(PluginInvokeError::CannotSerializePayload)?,
186 move |response| {
187 tx.send(response).unwrap();
188 },
189 )?;
190
191 let response = rx.recv().unwrap();
192 match response {
193 Ok(r) => serde_json::from_value(r).map_err(PluginInvokeError::CannotDeserializeResponse),
194 Err(r) => Err(
195 serde_json::from_value::<ErrorResponse>(r)
196 .map(Into::into)
197 .map_err(PluginInvokeError::CannotDeserializeResponse)?,
198 ),
199 }
200 }
201}
202
203pub(crate) fn run_command<R: Runtime, C: AsRef<str>, F: FnOnce(PluginResponse) + Send + 'static>(
204 name: &str,
205 _handle: &AppHandle<R>,
206 command: C,
207 payload: serde_json::Value,
208 handler: F,
209) -> Result<(), PluginInvokeError> {
210 use std::{
211 ffi::CStr,
212 os::raw::{c_char, c_int, c_ulonglong},
213 };
214
215 let id: i32 = PENDING_PLUGIN_CALLS_ID.fetch_add(1, Ordering::Relaxed);
216 PENDING_PLUGIN_CALLS
217 .get_or_init(Default::default)
218 .lock()
219 .unwrap()
220 .insert(id, Box::new(handler));
221
222 unsafe {
223 extern "C" fn plugin_command_response_handler(
224 id: c_int,
225 success: c_int,
226 payload: *const c_char,
227 ) {
228 let payload = unsafe {
229 assert!(!payload.is_null());
230 CStr::from_ptr(payload)
231 };
232
233 if let Some(handler) = PENDING_PLUGIN_CALLS
234 .get_or_init(Default::default)
235 .lock()
236 .unwrap()
237 .remove(&id)
238 {
239 let json = payload.to_str().unwrap();
240 match serde_json::from_str(json) {
241 Ok(payload) => {
242 handler(if success == 1 {
243 Ok(payload)
244 } else {
245 Err(payload)
246 });
247 }
248 Err(err) => {
249 handler(Err(format!("{err}, data: {json}").into()));
250 }
251 }
252 }
253 }
254
255 extern "C" fn send_channel_data_handler(id: c_ulonglong, payload: *const c_char) {
256 let payload = unsafe {
257 assert!(!payload.is_null());
258 CStr::from_ptr(payload)
259 };
260
261 if let Some(channel) = CHANNELS
262 .get_or_init(Default::default)
263 .lock()
264 .unwrap()
265 .get(&(id as u32))
266 {
267 let payload: serde_json::Value = serde_json::from_str(payload.to_str().unwrap()).unwrap();
268 let _ = channel.send(payload);
269 }
270 }
271
272 crate::macos::swift_run_plugin_command(
273 id,
274 &name.into(),
275 &command.as_ref().into(),
276 &serde_json::to_string(&payload).unwrap().as_str().into(),
277 crate::macos::PluginMessageCallback(plugin_command_response_handler),
278 crate::macos::ChannelSendDataCallback(send_channel_data_handler),
279 );
280 }
281
282 Ok(())
283}