Skip to main content

tauri_plugin_background_service/desktop/
ipc.rs

1//! IPC protocol types and framing for desktop OS service mode.
2//!
3//! The desktop OS service uses length-prefixed JSON over Unix domain sockets
4//! (Linux/macOS) or named pipes (Windows) for IPC between the GUI process
5//! and the headless service process.
6
7use serde::{Deserialize, Serialize};
8
9use crate::error::ServiceError;
10
/// Upper bound, in bytes, on the JSON payload a single IPC frame may carry
/// (16 MiB). The 4-byte length prefix is not counted against this limit.
pub const MAX_FRAME_SIZE: usize = 16 << 20;
13
14/// IPC request sent from the GUI process to the headless service.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16#[serde(rename_all = "camelCase", tag = "type")]
17#[non_exhaustive]
18pub enum IpcRequest {
19    /// Start the background service with the given config.
20    Start {
21        /// Startup configuration forwarded from the plugin.
22        config: crate::models::StartConfig,
23    },
24    /// Stop the running background service.
25    Stop,
26    /// Query whether a background service is currently running.
27    IsRunning,
28}
29
30/// IPC response sent from the headless service to the GUI process.
31#[derive(Debug, Clone, Serialize, Deserialize)]
32#[serde(rename_all = "camelCase")]
33pub struct IpcResponse {
34    /// Whether the request succeeded.
35    pub ok: bool,
36    /// Optional data payload on success.
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub data: Option<serde_json::Value>,
39    /// Optional error message on failure.
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub error: Option<String>,
42}
43
44/// IPC event streamed from the headless service to the GUI process.
45#[derive(Debug, Clone, Serialize, Deserialize)]
46#[serde(rename_all = "camelCase", tag = "type")]
47#[non_exhaustive]
48pub enum IpcEvent {
49    /// Service started successfully.
50    Started,
51    /// Service stopped.
52    Stopped { reason: String },
53    /// Service encountered an error.
54    Error { message: String },
55}
56
57/// Encode a message into a length-prefixed JSON frame.
58///
59/// Frame format: `[4-byte big-endian u32 length][JSON payload]`
60pub fn encode_frame<T: Serialize>(msg: &T) -> Result<Vec<u8>, serde_json::Error> {
61    let json = serde_json::to_vec(msg)?;
62    let len = json.len() as u32;
63    let mut buf = Vec::with_capacity(4 + json.len());
64    buf.extend_from_slice(&len.to_be_bytes());
65    buf.extend_from_slice(&json);
66    Ok(buf)
67}
68
69/// Decode a length-prefixed JSON frame.
70///
71/// Returns an error if:
72/// - The frame is shorter than 4 bytes (missing length prefix)
73/// - The payload length exceeds `MAX_FRAME_SIZE`
74/// - The JSON payload cannot be deserialized
75pub fn decode_frame<T: serde::de::DeserializeOwned>(data: &[u8]) -> Result<T, FrameError> {
76    if data.len() < 4 {
77        return Err(FrameError::IncompleteLength);
78    }
79    let len = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
80    if len > MAX_FRAME_SIZE {
81        return Err(FrameError::TooLarge {
82            size: len,
83            max: MAX_FRAME_SIZE,
84        });
85    }
86    let payload = data.get(4..4 + len).ok_or(FrameError::IncompletePayload {
87        expected: len,
88        available: data.len().saturating_sub(4),
89    })?;
90    serde_json::from_slice(payload).map_err(FrameError::Json)
91}
92
93/// Errors that can occur during frame decoding.
94#[derive(Debug, thiserror::Error)]
95#[non_exhaustive]
96pub enum FrameError {
97    /// The frame is shorter than 4 bytes.
98    #[error("incomplete length prefix: need 4 bytes")]
99    IncompleteLength,
100    /// The payload exceeds the maximum frame size.
101    #[error("frame too large: {size} bytes (max {max})")]
102    TooLarge { size: usize, max: usize },
103    /// The available data is shorter than the declared payload length.
104    #[error("incomplete payload: expected {expected} bytes, got {available}")]
105    IncompletePayload { expected: usize, available: usize },
106    /// The JSON payload could not be deserialized.
107    #[error("JSON decode error: {0}")]
108    Json(#[from] serde_json::Error),
109}
110
111/// Derive the IPC socket path for a given service label.
112///
113/// - Linux: `$XDG_RUNTIME_DIR/{label}.sock` (fallback: `/run/user/{uid}/{label}.sock`)
114/// - macOS: `/tmp/{label}.sock`
115/// - Windows: `\\.\pipe\{label}`
116///
117/// # Errors
118///
119/// Returns `ServiceError::Init` if the label is empty, contains path
120/// separators, or contains `..` components.
121pub fn socket_path(label: &str) -> Result<std::path::PathBuf, ServiceError> {
122    sanitize_label(label)?;
123    #[cfg(target_os = "linux")]
124    {
125        let dir = std::env::var("XDG_RUNTIME_DIR")
126            .unwrap_or_else(|_| format!("/run/user/{}", unsafe { libc::getuid() }));
127        Ok(std::path::PathBuf::from(format!("{dir}/{label}.sock")))
128    }
129    #[cfg(target_os = "macos")]
130    {
131        Ok(std::path::PathBuf::from(format!("/tmp/{label}.sock")))
132    }
133    #[cfg(windows)]
134    {
135        Ok(std::path::PathBuf::from(format!(r"\\.\pipe\{label}")))
136    }
137    #[cfg(not(any(target_os = "linux", target_os = "macos", windows)))]
138    {
139        Ok(std::path::PathBuf::from(format!("/tmp/{label}.sock")))
140    }
141}
142
143/// Validate that a service label does not contain path traversal characters.
144fn sanitize_label(label: &str) -> Result<std::path::PathBuf, ServiceError> {
145    if label.is_empty() {
146        return Err(ServiceError::Init("service label must not be empty".into()));
147    }
148    if label.contains('/') || label.contains('\\') {
149        return Err(ServiceError::Init(format!(
150            "service label must not contain path separators: {label}"
151        )));
152    }
153    if label.contains("..") {
154        return Err(ServiceError::Init(format!(
155            "service label must not contain '..': {label}"
156        )));
157    }
158    Ok(std::path::PathBuf::from(label))
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    // --- IpcRequest serde roundtrip tests ---

    #[test]
    fn ipc_request_start_serde_roundtrip() {
        let req = IpcRequest::Start {
            config: crate::models::StartConfig {
                service_label: "Syncing".into(),
                foreground_service_type: "dataSync".into(),
            },
        };
        let json = serde_json::to_string(&req).unwrap();
        let de: IpcRequest = serde_json::from_str(&json).unwrap();
        match de {
            IpcRequest::Start { config } => {
                assert_eq!(config.service_label, "Syncing");
                assert_eq!(config.foreground_service_type, "dataSync");
            }
            other => panic!("Expected Start, got {other:?}"),
        }
    }

    #[test]
    fn ipc_request_start_json_tag() {
        let req = IpcRequest::Start {
            config: crate::models::StartConfig::default(),
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"type\":\"start\""), "Tagged JSON: {json}");
    }

    #[test]
    fn ipc_request_stop_serde_roundtrip() {
        let req = IpcRequest::Stop;
        let json = serde_json::to_string(&req).unwrap();
        let de: IpcRequest = serde_json::from_str(&json).unwrap();
        assert!(matches!(de, IpcRequest::Stop));
    }

    #[test]
    fn ipc_request_stop_json_tag() {
        let req = IpcRequest::Stop;
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"type\":\"stop\""), "Tagged JSON: {json}");
    }

    #[test]
    fn ipc_request_is_running_serde_roundtrip() {
        let req = IpcRequest::IsRunning;
        let json = serde_json::to_string(&req).unwrap();
        let de: IpcRequest = serde_json::from_str(&json).unwrap();
        assert!(matches!(de, IpcRequest::IsRunning));
    }

    #[test]
    fn ipc_request_is_running_json_tag() {
        let req = IpcRequest::IsRunning;
        let json = serde_json::to_string(&req).unwrap();
        assert!(
            json.contains("\"type\":\"isRunning\""),
            "Tagged JSON: {json}"
        );
    }

    // --- IpcResponse serde roundtrip tests ---

    #[test]
    fn ipc_response_success_roundtrip() {
        let resp = IpcResponse {
            ok: true,
            data: Some(serde_json::json!({"running": true})),
            error: None,
        };
        let json = serde_json::to_string(&resp).unwrap();
        let de: IpcResponse = serde_json::from_str(&json).unwrap();
        assert!(de.ok);
        assert_eq!(de.data.unwrap()["running"], true);
        assert!(de.error.is_none());
    }

    #[test]
    fn ipc_response_error_roundtrip() {
        let resp = IpcResponse {
            ok: false,
            data: None,
            error: Some("Service is already running".into()),
        };
        let json = serde_json::to_string(&resp).unwrap();
        let de: IpcResponse = serde_json::from_str(&json).unwrap();
        assert!(!de.ok);
        assert!(de.data.is_none());
        assert_eq!(de.error.unwrap(), "Service is already running");
    }

    #[test]
    fn ipc_response_skips_none_fields() {
        let resp = IpcResponse {
            ok: true,
            data: None,
            error: None,
        };
        let json = serde_json::to_string(&resp).unwrap();
        assert!(
            !json.contains("\"data\""),
            "Should skip null data: {json}"
        );
        assert!(
            !json.contains("\"error\""),
            "Should skip null error: {json}"
        );
    }

    #[test]
    fn ipc_response_camel_case_keys() {
        let resp = IpcResponse {
            ok: true,
            data: None,
            error: None,
        };
        let json = serde_json::to_string(&resp).unwrap();
        // With both optional fields skipped, `ok` is the only key emitted.
        assert_eq!(json, r#"{"ok":true}"#, "Unexpected JSON: {json}");
    }

    // --- IpcEvent serde roundtrip tests ---

    #[test]
    fn ipc_event_started_serde_roundtrip() {
        let event = IpcEvent::Started;
        let json = serde_json::to_string(&event).unwrap();
        let de: IpcEvent = serde_json::from_str(&json).unwrap();
        assert!(matches!(de, IpcEvent::Started));
    }

    #[test]
    fn ipc_event_started_json_tag() {
        let event = IpcEvent::Started;
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("\"type\":\"started\""), "Tagged JSON: {json}");
    }

    #[test]
    fn ipc_event_stopped_serde_roundtrip() {
        let event = IpcEvent::Stopped {
            reason: "cancelled".into(),
        };
        let json = serde_json::to_string(&event).unwrap();
        let de: IpcEvent = serde_json::from_str(&json).unwrap();
        match de {
            IpcEvent::Stopped { reason } => assert_eq!(reason, "cancelled"),
            other => panic!("Expected Stopped, got {other:?}"),
        }
    }

    #[test]
    fn ipc_event_stopped_json_keys() {
        let event = IpcEvent::Stopped {
            reason: "done".into(),
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("\"type\":\"stopped\""), "Tag: {json}");
        assert!(json.contains("\"reason\":\"done\""), "Reason: {json}");
    }

    #[test]
    fn ipc_event_error_serde_roundtrip() {
        let event = IpcEvent::Error {
            message: "init failed".into(),
        };
        let json = serde_json::to_string(&event).unwrap();
        let de: IpcEvent = serde_json::from_str(&json).unwrap();
        match de {
            IpcEvent::Error { message } => assert_eq!(message, "init failed"),
            other => panic!("Expected Error, got {other:?}"),
        }
    }

    #[test]
    fn ipc_event_error_json_keys() {
        let event = IpcEvent::Error {
            message: "oops".into(),
        };
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("\"type\":\"error\""), "Tag: {json}");
        assert!(json.contains("\"message\":\"oops\""), "Message: {json}");
    }

    // --- Frame encode/decode tests ---

    #[test]
    fn ipc_frame_encode_decode_request() {
        let req = IpcRequest::Start {
            config: crate::models::StartConfig::default(),
        };
        let encoded = encode_frame(&req).unwrap();
        let decoded: IpcRequest = decode_frame(&encoded).unwrap();
        match decoded {
            IpcRequest::Start { config } => {
                assert_eq!(config.service_label, "Service running");
            }
            other => panic!("Expected Start, got {other:?}"),
        }
    }

    #[test]
    fn ipc_frame_encode_decode_response() {
        let resp = IpcResponse {
            ok: true,
            data: Some(serde_json::json!(42)),
            error: None,
        };
        let encoded = encode_frame(&resp).unwrap();
        let decoded: IpcResponse = decode_frame(&encoded).unwrap();
        assert!(decoded.ok);
        assert_eq!(decoded.data.unwrap(), 42);
    }

    #[test]
    fn ipc_frame_encode_decode_event() {
        let event = IpcEvent::Stopped {
            reason: "done".into(),
        };
        let encoded = encode_frame(&event).unwrap();
        let decoded: IpcEvent = decode_frame(&encoded).unwrap();
        match decoded {
            IpcEvent::Stopped { reason } => assert_eq!(reason, "done"),
            other => panic!("Expected Stopped, got {other:?}"),
        }
    }

    #[test]
    fn ipc_frame_length_prefix_is_big_endian() {
        let req = IpcRequest::Stop;
        let encoded = encode_frame(&req).unwrap();
        // First 4 bytes are the length of the JSON payload
        let len = u32::from_be_bytes([encoded[0], encoded[1], encoded[2], encoded[3]]);
        assert_eq!(len as usize, encoded.len() - 4);
    }

    #[test]
    fn ipc_frame_too_large_rejected() {
        // Length prefix claims one byte more than the 16 MiB limit.
        let fake_len = (MAX_FRAME_SIZE + 1) as u32;
        let mut data = vec![0u8; 4 + 1];
        data[0..4].copy_from_slice(&fake_len.to_be_bytes());
        data[4] = b'{';
        let result: Result<IpcRequest, FrameError> = decode_frame(&data);
        match result {
            Err(FrameError::TooLarge { size, max }) => {
                assert_eq!(size, MAX_FRAME_SIZE + 1);
                assert_eq!(max, MAX_FRAME_SIZE);
            }
            other => panic!("Expected TooLarge, got {other:?}"),
        }
    }

    #[test]
    fn ipc_frame_length_at_limit_is_not_too_large() {
        // A prefix of exactly MAX_FRAME_SIZE is allowed; with only one payload
        // byte present the decoder must report a short payload instead.
        let mut data = vec![0u8; 4 + 1];
        data[0..4].copy_from_slice(&(MAX_FRAME_SIZE as u32).to_be_bytes());
        data[4] = b'{';
        let result: Result<IpcRequest, FrameError> = decode_frame(&data);
        match result {
            Err(FrameError::IncompletePayload {
                expected,
                available,
            }) => {
                assert_eq!(expected, MAX_FRAME_SIZE);
                assert_eq!(available, 1);
            }
            other => panic!("Expected IncompletePayload, got {other:?}"),
        }
    }

    #[test]
    fn ipc_frame_incomplete_length() {
        let data = [0u8; 3]; // Only 3 bytes, need 4
        let result: Result<IpcRequest, FrameError> = decode_frame(&data);
        assert!(
            matches!(result, Err(FrameError::IncompleteLength)),
            "Expected IncompleteLength, got {result:?}"
        );
    }

    #[test]
    fn ipc_frame_malformed_json() {
        // Length prefix matches the payload exactly (8 bytes), but the payload
        // itself is not valid JSON.
        let payload = b"{invalid";
        let mut data = Vec::with_capacity(4 + payload.len());
        data.extend_from_slice(&(payload.len() as u32).to_be_bytes());
        data.extend_from_slice(payload);
        let result: Result<IpcRequest, FrameError> = decode_frame(&data);
        match result {
            Err(FrameError::Json(_)) => {} // expected
            other => panic!("Expected Json error, got {other:?}"),
        }
    }

    #[test]
    fn ipc_frame_empty_payload_is_json_error() {
        // A zero-length payload is structurally complete but is not valid JSON.
        let data = 0u32.to_be_bytes();
        let result: Result<IpcRequest, FrameError> = decode_frame(&data);
        assert!(
            matches!(result, Err(FrameError::Json(_))),
            "Expected Json error, got {result:?}"
        );
    }

    #[test]
    fn ipc_frame_incomplete_payload() {
        // Length says 100 bytes but only 1 byte available
        let mut data = vec![0u8; 5];
        data[0..4].copy_from_slice(&100u32.to_be_bytes());
        data[4] = b'{';
        let result: Result<IpcRequest, FrameError> = decode_frame(&data);
        match result {
            Err(FrameError::IncompletePayload {
                expected,
                available,
            }) => {
                assert_eq!(expected, 100);
                assert_eq!(available, 1);
            }
            other => panic!("Expected IncompletePayload, got {other:?}"),
        }
    }

    // --- socket_path tests ---

    #[test]
    fn socket_path_unix_format() {
        let path = socket_path("com.example.svc").unwrap();
        #[cfg(target_os = "linux")]
        {
            // Should be under XDG_RUNTIME_DIR or /run/user/{uid}
            let path_str = path.to_str().unwrap();
            assert!(
                path_str.ends_with("/com.example.svc.sock"),
                "Expected path ending with /com.example.svc.sock, got: {path_str}"
            );
            if let Ok(xdg) = std::env::var("XDG_RUNTIME_DIR") {
                assert!(
                    path_str.starts_with(&xdg),
                    "Expected path under XDG_RUNTIME_DIR ({xdg}), got: {path_str}"
                );
            } else {
                assert!(
                    path_str.contains("/run/user/"),
                    "Expected fallback /run/user/ path, got: {path_str}"
                );
            }
        }
        #[cfg(target_os = "macos")]
        {
            assert_eq!(path.to_str().unwrap(), "/tmp/com.example.svc.sock");
        }
    }

    #[test]
    fn socket_path_nonempty_label() {
        let path = socket_path("my-app").unwrap();
        #[cfg(unix)]
        {
            assert!(
                path.to_str().unwrap().ends_with("my-app.sock"),
                "Expected path ending with my-app.sock, got: {:?}",
                path
            );
        }
    }

    #[test]
    fn socket_path_rejects_slash_in_label() {
        let result = socket_path("../etc/passwd");
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("path separators"), "Error: {err}");
    }

    #[test]
    fn socket_path_rejects_plain_slash_in_label() {
        let result = socket_path("foo/bar");
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("path separators"), "Error: {err}");
    }

    #[test]
    fn socket_path_rejects_dotdot_in_label() {
        let result = socket_path("..");
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("'..'"), "Error: {err}");
    }

    #[test]
    fn socket_path_rejects_embedded_dotdot_in_label() {
        let result = socket_path("com..example");
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("'..'"), "Error: {err}");
    }

    #[test]
    fn socket_path_rejects_backslash_in_label() {
        let result = socket_path("foo\\bar");
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("path separators"), "Error: {err}");
    }

    #[test]
    fn socket_path_rejects_empty_label() {
        let result = socket_path("");
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("empty"), "Error: {err}");
    }
}