tauri_plugin_conduit/lib.rs
1#![forbid(unsafe_code)]
2#![deny(missing_docs)]
3//! # tauri-plugin-conduit
4//!
5//! Tauri v2 plugin for conduit — binary IPC over the `conduit://` custom
6//! protocol.
7//!
8//! Registers a `conduit://` custom protocol for zero-overhead in-process
9//! binary dispatch. Supports both sync and async handlers via
10//! [`ConduitHandler`](conduit_core::ConduitHandler). No network surface.
11//!
12//! ## Usage
13//!
14//! ```rust,ignore
15//! use tauri_conduit::{command, handler};
16//!
17//! #[command]
18//! fn greet(name: String) -> String {
19//! format!("Hello, {name}!")
20//! }
21//!
22//! #[command]
23//! async fn fetch_user(state: State<'_, Db>, id: u64) -> Result<User, String> {
24//! state.get_user(id).await.map_err(|e| e.to_string())
25//! }
26//!
27//! tauri::Builder::default()
28//! .plugin(
29//! tauri_plugin_conduit::init()
30//! .handler("greet", handler!(greet))
31//! .handler("fetch_user", handler!(fetch_user))
32//! .channel("telemetry")
33//! .build()
34//! )
35//! .run(tauri::generate_context!())
36//! .unwrap();
37//! ```
38
39/// Re-export the `#[command]` attribute macro from `conduit-derive`.
40///
41/// This is conduit's equivalent of `#[tauri::command]`. Use it for
42/// named-parameter handlers:
43///
44/// ```rust,ignore
45/// use tauri_conduit::{command, handler};
46///
47/// #[command]
48/// fn greet(name: String, greeting: String) -> String {
49/// format!("{greeting}, {name}!")
50/// }
51/// ```
52pub use conduit_derive::command;
53
54/// Re-export the `handler!()` macro from `conduit-derive`.
55///
56/// Resolves a `#[command]` function name to its conduit handler struct
57/// for registration:
58///
59/// ```rust,ignore
60/// tauri_plugin_conduit::init()
61/// .handler("greet", handler!(greet))
62/// .build()
63/// ```
64pub use conduit_derive::handler;
65
66use std::collections::HashMap;
67use std::sync::Arc;
68
69use conduit_core::{
70 ChannelBuffer, ConduitHandler, Decode, Encode, HandlerResponse, Queue, RingBuffer, Router,
71};
72use futures_util::FutureExt;
73use subtle::ConstantTimeEq;
74use tauri::plugin::{Builder as TauriPluginBuilder, TauriPlugin};
75use tauri::{AppHandle, Emitter, Manager, Runtime};
76
77// ---------------------------------------------------------------------------
78// Helper: safe HTTP response builder
79// ---------------------------------------------------------------------------
80
81/// Build an HTTP response, falling back to a minimal 500 if construction fails.
82fn make_response(status: u16, content_type: &str, body: Vec<u8>) -> http::Response<Vec<u8>> {
83 http::Response::builder()
84 .status(status)
85 .header("Content-Type", content_type)
86 .header("Access-Control-Allow-Origin", "*")
87 .body(body)
88 .unwrap_or_else(|_| {
89 http::Response::builder()
90 .status(500)
91 .body(b"internal error".to_vec())
92 .expect("fallback response must not fail")
93 })
94}
95
96/// Build a JSON error response: `{"error": "message"}`.
97///
98/// Uses `sonic_rs` for proper RFC 8259 escaping of all control characters,
99/// newlines, quotes, and backslashes — not just `\` and `"`.
100fn make_error_response(status: u16, message: &str) -> http::Response<Vec<u8>> {
101 #[derive(serde::Serialize)]
102 struct ErrorBody<'a> {
103 error: &'a str,
104 }
105 let body = conduit_core::sonic_rs::to_vec(&ErrorBody { error: message })
106 .unwrap_or_else(|_| br#"{"error":"internal error"}"#.to_vec());
107 make_response(status, "application/json", body)
108}
109
110// ---------------------------------------------------------------------------
111// BootstrapInfo — returned to JS via `conduit_bootstrap` command
112// ---------------------------------------------------------------------------
113
114/// Connection info returned to the frontend during bootstrap.
115#[derive(Clone, serde::Serialize, serde::Deserialize)]
116#[serde(rename_all = "camelCase")]
117pub struct BootstrapInfo {
118 /// Protocol version (currently `1`). Allows the TS client to verify
119 /// protocol compatibility.
120 #[serde(default = "default_protocol_version")]
121 pub protocol_version: u8,
122 /// Base URL for the custom protocol (e.g., `"conduit://localhost"`).
123 pub protocol_base: String,
124 /// Per-launch invoke key for custom protocol authentication (hex-encoded).
125 ///
126 /// **Security**: This key authenticates custom protocol requests. It is
127 /// generated fresh each launch from 32 bytes of OS randomness and validated
128 /// using constant-time comparison. The JS client includes it as the
129 /// `X-Conduit-Key` header on every `conduit://` request.
130 pub invoke_key: String,
131 /// Available channel names.
132 pub channels: Vec<String>,
133}
134
fn default_protocol_version() -> u8 {
    // Serde default for `BootstrapInfo::protocol_version`; `1` is the only
    // protocol version this plugin reports.
    1_u8
}
138
139impl std::fmt::Debug for BootstrapInfo {
140 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141 f.debug_struct("BootstrapInfo")
142 .field("protocol_version", &self.protocol_version)
143 .field("protocol_base", &self.protocol_base)
144 .field("invoke_key", &"[REDACTED]")
145 .field("channels", &self.channels)
146 .finish()
147 }
148}
149
150// ---------------------------------------------------------------------------
151// PluginState — managed Tauri state
152// ---------------------------------------------------------------------------
153
154/// Shared state for the conduit Tauri plugin.
155///
156/// Holds the router, named streaming channels, the per-launch invoke key,
157/// and the app handle for emitting push notifications.
158pub struct PluginState<R: Runtime> {
159 dispatch: Arc<Router>,
160 /// `#[command]`-generated handlers (sync and async via [`ConduitHandler`]).
161 handlers: Arc<HashMap<String, Arc<dyn ConduitHandler>>>,
162 /// Named channels for server→client streaming (lossy or ordered).
163 channels: HashMap<String, Arc<ChannelBuffer>>,
164 /// Tauri app handle for emitting events to the frontend.
165 app_handle: AppHandle<R>,
166 /// Pre-cached `Arc` of the app handle — avoids a heap allocation per request.
167 app_handle_arc: Arc<AppHandle<R>>,
168 /// Per-launch invoke key (hex-encoded, 64 hex chars = 32 bytes).
169 invoke_key: String,
170 /// Raw invoke key bytes for constant-time comparison.
171 invoke_key_bytes: [u8; 32],
172}
173
174impl<R: Runtime> std::fmt::Debug for PluginState<R> {
175 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
176 f.debug_struct("PluginState")
177 .field("channels", &self.channels.keys().collect::<Vec<_>>())
178 .field("invoke_key", &"[REDACTED]")
179 .finish()
180 }
181}
182
183impl<R: Runtime> PluginState<R> {
184 /// Get a channel by name (for pushing data from Rust handlers).
185 pub fn channel(&self, name: &str) -> Option<&Arc<ChannelBuffer>> {
186 self.channels.get(name)
187 }
188
189 /// Push binary data to a named channel and notify JS listeners.
190 ///
191 /// After writing to the channel, emits both a global
192 /// `conduit:data-available` event (payload = channel name) and a
193 /// per-channel `conduit:data-available:{channel}` event. JS subscribers
194 /// can listen on either.
195 ///
196 /// For lossy channels, oldest frames are silently dropped when the buffer
197 /// is full. For reliable channels, returns an error if the buffer is full
198 /// (backpressure).
199 ///
200 /// Returns an error string if the named channel was not registered via
201 /// the builder or if a reliable channel is full.
202 pub fn push(&self, channel: &str, data: &[u8]) -> Result<(), String> {
203 let ch = self
204 .channels
205 .get(channel)
206 .ok_or_else(|| format!("unknown channel: {channel}"))?;
207 ch.push(data).map(|_| ()).map_err(|e| e.to_string())?;
208 // Emit global event (backward-compatible with old JS code).
209 if self
210 .app_handle
211 .emit("conduit:data-available", channel)
212 .is_err()
213 {
214 #[cfg(debug_assertions)]
215 eprintln!(
216 "conduit: failed to emit global data-available event for channel '{channel}'"
217 );
218 }
219 // Emit per-channel event.
220 if self
221 .app_handle
222 .emit(&format!("conduit:data-available:{channel}"), channel)
223 .is_err()
224 {
225 #[cfg(debug_assertions)]
226 eprintln!(
227 "conduit: failed to emit per-channel data-available event for channel '{channel}'"
228 );
229 }
230 Ok(())
231 }
232
233 /// Return the list of registered channel names.
234 pub fn channel_names(&self) -> Vec<String> {
235 self.channels.keys().cloned().collect()
236 }
237
238 /// Validate an invoke key candidate using constant-time operations.
239 fn validate_invoke_key(&self, candidate: &str) -> bool {
240 validate_invoke_key_ct(&self.invoke_key_bytes, candidate)
241 }
242}
243
244// ---------------------------------------------------------------------------
245// Tauri commands
246// ---------------------------------------------------------------------------
247
248/// Return bootstrap info so the JS client knows how to reach the conduit
249/// custom protocol.
250///
251/// May be called multiple times (e.g., after page reloads during development).
252/// The invoke key is generated once at plugin setup and remains constant for
253/// the lifetime of the app process. Repeated calls return the same key.
254#[tauri::command]
255fn conduit_bootstrap(
256 state: tauri::State<'_, PluginState<tauri::Wry>>,
257) -> Result<BootstrapInfo, String> {
258 Ok(BootstrapInfo {
259 protocol_version: 1,
260 protocol_base: "conduit://localhost".to_string(),
261 invoke_key: state.invoke_key.clone(),
262 channels: state.channel_names(),
263 })
264}
265
266/// Validate channel names and return those that exist.
267///
268/// This is a validation-only endpoint — no server-side subscription state is
269/// tracked. The JS client uses the returned list to know which channels are
270/// available. Actual data delivery happens via `conduit:data-available` events
271/// and `conduit://localhost/drain/<channel>` protocol requests.
272///
273/// Unknown channel names are silently filtered out — only channels that
274/// exist are returned.
275#[tauri::command]
276fn conduit_subscribe(
277 state: tauri::State<'_, PluginState<tauri::Wry>>,
278 channels: Vec<String>,
279) -> Result<Vec<String>, String> {
280 // Silently filter to only channels that exist.
281 let valid: Vec<String> = channels
282 .into_iter()
283 .filter(|c| state.channels.contains_key(c.as_str()))
284 .collect();
285 Ok(valid)
286}
287
288// ---------------------------------------------------------------------------
289// Channel kind (internal)
290// ---------------------------------------------------------------------------
291
/// Channel definition recorded by the builder and turned into a real buffer
/// during plugin setup.
enum ChannelKind {
    /// Lossy ring buffer; the payload is its capacity in bytes.
    Lossy(usize),
    /// Reliable (ordered) queue; the payload is its byte limit.
    Reliable(usize),
}
299
300// ---------------------------------------------------------------------------
301// Plugin builder
302// ---------------------------------------------------------------------------
303
304/// A deferred command registration closure.
305type CommandRegistration = Box<dyn FnOnce(&Router) + Send>;
306
307/// Builder for the conduit Tauri v2 plugin.
308///
309/// Collects command registrations and configuration, then produces a
310/// [`TauriPlugin`] via [`build`](Self::build).
311pub struct PluginBuilder {
312 /// Deferred command registrations: (name, handler factory).
313 commands: Vec<CommandRegistration>,
314 /// `#[command]`-generated handlers (sync and async).
315 handler_defs: Vec<(String, Arc<dyn ConduitHandler>)>,
316 /// Named channels: (name, kind).
317 channel_defs: Vec<(String, ChannelKind)>,
318}
319
320impl std::fmt::Debug for PluginBuilder {
321 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
322 f.debug_struct("PluginBuilder")
323 .field("commands", &self.commands.len())
324 .field("handlers", &self.handler_defs.len())
325 .field("channel_defs_count", &self.channel_defs.len())
326 .finish()
327 }
328}
329
/// Panic unless `name` is non-empty and consists only of `[a-zA-Z0-9_-]`.
///
/// Channel names are embedded in event names and URL paths, so anything
/// outside this set is rejected at registration time.
fn validate_channel_name(name: &str) {
    let allowed = |b: u8| b.is_ascii_alphanumeric() || matches!(b, b'_' | b'-');
    let well_formed = !name.is_empty() && name.bytes().all(allowed);
    assert!(
        well_formed,
        "conduit: invalid channel name '{}' — must match [a-zA-Z0-9_-]+",
        name
    );
}
341
/// Byte capacity used by `channel` / `channel_ordered` when none is given (64 KiB).
const DEFAULT_CHANNEL_CAPACITY: usize = 1 << 16;
344
345impl PluginBuilder {
346 /// Panic if a channel with the given name is already registered.
347 fn assert_no_duplicate_channel(&self, name: &str) {
348 if self.channel_defs.iter().any(|(n, _)| n == name) {
349 panic!(
350 "conduit: duplicate channel name '{}' — each channel must have a unique name",
351 name
352 );
353 }
354 }
355
356 /// Create a new, empty plugin builder.
357 pub fn new() -> Self {
358 Self {
359 commands: Vec::new(),
360 handler_defs: Vec::new(),
361 channel_defs: Vec::new(),
362 }
363 }
364
365 // -- Raw handlers -------------------------------------------------------
366
367 /// Register a raw command handler (`Vec<u8>` in, `Vec<u8>` out).
368 ///
369 /// Command names correspond to the path segment in the
370 /// `conduit://localhost/invoke/<cmd_name>` URL.
371 pub fn command<F>(mut self, name: impl Into<String>, handler: F) -> Self
372 where
373 F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync + 'static,
374 {
375 let name = name.into();
376 self.commands.push(Box::new(move |table: &Router| {
377 table.register(name, handler);
378 }));
379 self
380 }
381
382 // -- ConduitHandler-based (#[command]-generated, sync or async) ----------
383
384 /// Register a `#[tauri_conduit::command]`-generated handler.
385 ///
386 /// Works with both sync and async handlers. Sync handlers are dispatched
387 /// inline. Async handlers are spawned on the tokio runtime — truly async,
388 /// exactly like `#[tauri::command]`.
389 ///
390 /// ```rust,ignore
391 /// use tauri_conduit::{command, handler};
392 ///
393 /// #[command]
394 /// fn greet(name: String) -> String {
395 /// format!("Hello, {name}!")
396 /// }
397 ///
398 /// #[command]
399 /// async fn fetch_user(state: State<'_, Db>, id: u64) -> Result<User, String> {
400 /// state.get_user(id).await.map_err(|e| e.to_string())
401 /// }
402 ///
403 /// tauri_plugin_conduit::init()
404 /// .handler("greet", handler!(greet))
405 /// .handler("fetch_user", handler!(fetch_user))
406 /// .build()
407 /// ```
408 pub fn handler(mut self, name: impl Into<String>, handler: impl ConduitHandler) -> Self {
409 self.handler_defs.push((name.into(), Arc::new(handler)));
410 self
411 }
412
413 /// Register a raw closure handler (legacy API).
414 ///
415 /// Accepts the same closure signature as the pre-`ConduitHandler` `.handler()`:
416 /// `Fn(Vec<u8>, &dyn Any) -> Result<Vec<u8>, Error>`. This is a synchronous
417 /// handler dispatched via `Router::register_with_context`.
418 ///
419 /// Use this for backward compatibility when migrating from closure-based
420 /// registration. For new code, prefer [`handler`](Self::handler) with
421 /// `#[tauri_conduit::command]` + `handler!()`.
422 pub fn handler_raw<F>(mut self, name: impl Into<String>, handler: F) -> Self
423 where
424 F: Fn(Vec<u8>, &dyn std::any::Any) -> Result<Vec<u8>, conduit_core::Error>
425 + Send
426 + Sync
427 + 'static,
428 {
429 let name = name.into();
430 self.commands.push(Box::new(move |table: &Router| {
431 table.register_with_context(name, handler);
432 }));
433 self
434 }
435
436 // -- JSON handlers (Level 1) --------------------------------------------
437
438 /// Typed JSON handler. Deserializes the request payload as `A` and
439 /// serializes the response as `R`.
440 ///
441 /// Unlike Tauri's `#[tauri::command]`, this takes a single argument type
442 /// (not named parameters) and does not support async or State injection.
443 ///
444 /// ```rust,ignore
445 /// .command_json("greet", |name: String| format!("Hello, {name}!"))
446 /// ```
447 pub fn command_json<F, A, R>(mut self, name: impl Into<String>, handler: F) -> Self
448 where
449 F: Fn(A) -> R + Send + Sync + 'static,
450 A: serde::de::DeserializeOwned + 'static,
451 R: serde::Serialize + 'static,
452 {
453 let name = name.into();
454 self.commands.push(Box::new(move |table: &Router| {
455 table.register_json(name, handler);
456 }));
457 self
458 }
459
460 /// Typed JSON handler that returns `Result<R, E>`.
461 ///
462 /// Like [`command_json`](Self::command_json), but the handler returns
463 /// `Result<R, E>` where `E: Display`. On success, `R` is serialized to
464 /// JSON. On error, the error's `Display` text is returned to the caller.
465 ///
466 /// For Tauri-style named parameters with `Result` returns, prefer
467 /// [`handler`](Self::handler) with `#[tauri_conduit::command]` instead:
468 ///
469 /// ```rust,ignore
470 /// use tauri_conduit::command;
471 ///
472 /// #[command]
473 /// fn divide(a: f64, b: f64) -> Result<f64, String> {
474 /// if b == 0.0 { Err("division by zero".into()) }
475 /// else { Ok(a / b) }
476 /// }
477 ///
478 /// // Preferred:
479 /// .handler("divide", divide)
480 /// ```
481 pub fn command_json_result<F, A, R, E>(mut self, name: impl Into<String>, handler: F) -> Self
482 where
483 F: Fn(A) -> Result<R, E> + Send + Sync + 'static,
484 A: serde::de::DeserializeOwned + 'static,
485 R: serde::Serialize + 'static,
486 E: std::fmt::Display + 'static,
487 {
488 let name = name.into();
489 self.commands.push(Box::new(move |table: &Router| {
490 table.register_json_result(name, handler);
491 }));
492 self
493 }
494
495 // -- Binary handlers (Level 2) ------------------------------------------
496
497 /// Register a typed binary command handler.
498 ///
499 /// The request payload is decoded via the [`Decode`] trait and the response
500 /// is encoded via [`Encode`]. No JSON involved — raw bytes in, raw bytes
501 /// out.
502 ///
503 /// ```rust,ignore
504 /// .command_binary("process", |tick: MarketTick| tick)
505 /// ```
506 pub fn command_binary<F, A, Ret>(mut self, name: impl Into<String>, handler: F) -> Self
507 where
508 F: Fn(A) -> Ret + Send + Sync + 'static,
509 A: Decode + 'static,
510 Ret: Encode + 'static,
511 {
512 let name = name.into();
513 self.commands.push(Box::new(move |table: &Router| {
514 table.register_binary(name, handler);
515 }));
516 self
517 }
518
519 // -- Lossy channels (default) -------------------------------------------
520
521 /// Register a lossy channel with the default capacity (64 KB).
522 ///
523 /// Oldest frames are silently dropped when the buffer is full. Best for
524 /// telemetry, game state, and real-time data where freshness matters more
525 /// than completeness.
526 ///
527 /// # Panics
528 ///
529 /// Panics if the name is empty, contains characters outside `[a-zA-Z0-9_-]`,
530 /// or duplicates an already-registered channel name.
531 pub fn channel(mut self, name: impl Into<String>) -> Self {
532 let name = name.into();
533 validate_channel_name(&name);
534 self.assert_no_duplicate_channel(&name);
535 self.channel_defs
536 .push((name, ChannelKind::Lossy(DEFAULT_CHANNEL_CAPACITY)));
537 self
538 }
539
540 /// Register a lossy channel with a custom byte capacity.
541 ///
542 /// # Panics
543 ///
544 /// Panics if the name is empty, contains characters outside `[a-zA-Z0-9_-]`,
545 /// or duplicates an already-registered channel name.
546 pub fn channel_with_capacity(mut self, name: impl Into<String>, capacity: usize) -> Self {
547 let name = name.into();
548 validate_channel_name(&name);
549 self.assert_no_duplicate_channel(&name);
550 self.channel_defs.push((name, ChannelKind::Lossy(capacity)));
551 self
552 }
553
554 // -- Reliable channels (guaranteed delivery) ----------------------------
555
556 /// Register an ordered channel with the default capacity (64 KB).
557 ///
558 /// No frames are ever dropped. When the buffer is full,
559 /// [`PluginState::push`] returns an error (backpressure). Best for
560 /// transaction logs, control messages, and any data that must arrive
561 /// intact and in order.
562 ///
563 /// # Panics
564 ///
565 /// Panics if the name is empty, contains characters outside `[a-zA-Z0-9_-]`,
566 /// or duplicates an already-registered channel name.
567 pub fn channel_ordered(mut self, name: impl Into<String>) -> Self {
568 let name = name.into();
569 validate_channel_name(&name);
570 self.assert_no_duplicate_channel(&name);
571 self.channel_defs
572 .push((name, ChannelKind::Reliable(DEFAULT_CHANNEL_CAPACITY)));
573 self
574 }
575
576 /// Register an ordered channel with a custom byte limit.
577 ///
578 /// A `max_bytes` of `0` means unbounded — the buffer grows without limit.
579 ///
580 /// # Panics
581 ///
582 /// Panics if the name is empty, contains characters outside `[a-zA-Z0-9_-]`,
583 /// or duplicates an already-registered channel name.
584 pub fn channel_ordered_with_capacity(
585 mut self,
586 name: impl Into<String>,
587 max_bytes: usize,
588 ) -> Self {
589 let name = name.into();
590 validate_channel_name(&name);
591 self.assert_no_duplicate_channel(&name);
592 self.channel_defs
593 .push((name, ChannelKind::Reliable(max_bytes)));
594 self
595 }
596
597 // -- Build --------------------------------------------------------------
598
599 /// Build the Tauri v2 plugin.
600 ///
601 /// This consumes the builder and returns a [`TauriPlugin`] that can be
602 /// passed to `tauri::Builder::plugin`.
603 ///
604 /// # Dispatch model
605 ///
606 /// Commands are dispatched through a two-tier system:
607 ///
608 /// 1. **`#[command]` handlers** (registered via [`.handler()`](Self::handler))
609 /// are checked first. These support named parameters, `State<T>` injection,
610 /// `Result` returns, and async — full parity with `#[tauri::command]`.
611 ///
612 /// 2. **Raw Router handlers** (registered via [`.command()`](Self::command),
613 /// [`.command_json()`](Self::command_json), [`.command_binary()`](Self::command_binary))
614 /// are the fallback. These are simpler `Vec<u8> -> Vec<u8>` functions
615 /// with no injection or async support.
616 ///
617 /// If a command name exists in both tiers, the `#[command]` handler takes
618 /// priority and a debug warning is printed.
619 pub fn build<R: Runtime>(self) -> TauriPlugin<R> {
620 let commands = self.commands;
621 let handler_defs = self.handler_defs;
622 let channel_defs = self.channel_defs;
623
624 TauriPluginBuilder::<R>::new("conduit")
625 // --- Custom protocol: conduit://localhost/invoke/<cmd> ---
626 // Uses the asynchronous variant so async #[command] handlers
627 // are spawned on tokio (truly async, like #[tauri::command]).
628 .register_asynchronous_uri_scheme_protocol("conduit", move |ctx, request, responder| {
629 // Handle CORS preflight requests.
630 if request.method() == "OPTIONS" {
631 let resp = http::Response::builder()
632 .status(204)
633 .header("Access-Control-Allow-Origin", "*")
634 .header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
635 .header(
636 "Access-Control-Allow-Headers",
637 "Content-Type, X-Conduit-Key, X-Conduit-Webview",
638 )
639 .header("Access-Control-Max-Age", "86400")
640 .body(Vec::new())
641 .expect("preflight response must not fail");
642 responder.respond(resp);
643 return;
644 }
645
646 // Extract the managed PluginState from the app handle.
647 let state: tauri::State<'_, PluginState<R>> = ctx.app_handle().state();
648
649 // Extract path directly from the URI — zero allocation.
650 let path = request.uri().path();
651 let segments: Vec<&str> = path.trim_start_matches('/').splitn(2, '/').collect();
652
653 if segments.len() != 2 {
654 responder.respond(make_error_response(
655 404,
656 "not found: expected /invoke/<cmd> or /drain/<channel>",
657 ));
658 return;
659 }
660
661 // Validate the invoke key from the X-Conduit-Key header.
662 // Borrow the header value directly — no allocation needed.
663 let key = match request.headers().get("X-Conduit-Key") {
664 Some(v) => match v.to_str() {
665 Ok(s) => s,
666 Err(_) => {
667 responder
668 .respond(make_error_response(401, "invalid invoke key header"));
669 return;
670 }
671 },
672 None => {
673 responder.respond(make_error_response(401, "missing invoke key"));
674 return;
675 }
676 };
677
678 if !state.validate_invoke_key(key) {
679 responder.respond(make_error_response(403, "invalid invoke key"));
680 return;
681 }
682
683 let action = segments[0];
684 let raw_target = segments[1];
685
686 // H6: Percent-decode the target and reject path traversal.
687 let target = percent_decode(raw_target);
688 if target.contains('/') {
689 responder.respond(make_error_response(400, "invalid command name"));
690 return;
691 }
692
693 match action {
694 "invoke" => {
695 let body = request.body().to_vec();
696
697 // 1) Check #[command]-generated handlers first (sync or async)
698 if let Some(handler) = state.handlers.get(&*target) {
699 let handler = Arc::clone(handler);
700 // Extract webview label from X-Conduit-Webview header (sent by JS client).
701 // NOTE: This header is client-provided and could be spoofed by JS
702 // running in the same webview. We validate the format to prevent
703 // injection attacks, but in a multi-webview app, code in one
704 // webview could impersonate another. This matches Tauri's own
705 // trust model where all JS in the webview is equally trusted.
706 let webview_label = request
707 .headers()
708 .get("X-Conduit-Webview")
709 .and_then(|v| v.to_str().ok())
710 .filter(|s| {
711 !s.is_empty()
712 && s.len() <= 128
713 && s.bytes().all(|b| {
714 b.is_ascii_alphanumeric() || b == b'_' || b == b'-'
715 })
716 })
717 .map(|s| s.to_string());
718 // Clone the pre-cached Arc and coerce to trait object —
719 // one atomic increment, no heap allocation.
720 let app_handle_arc: Arc<dyn std::any::Any + Send + Sync> =
721 state.app_handle_arc.clone();
722 let handler_ctx = conduit_core::HandlerContext::new(
723 app_handle_arc,
724 webview_label,
725 );
726 let ctx_any: Arc<dyn std::any::Any + Send + Sync> =
727 Arc::new(handler_ctx);
728
729 // SAFETY: AssertUnwindSafe is used here because:
730 // - `body` is a Vec<u8> (unwind-safe by itself)
731 // - `ctx_any` is an Arc (unwind-safe)
732 // - conduit's own locks use poison-recovery helpers (lock_or_recover)
733 // - User-defined handler state may be left inconsistent after panic,
734 // but this is inherent to catch_unwind and documented as a limitation.
735 let result =
736 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
737 handler.call(body, ctx_any)
738 }));
739
740 match result {
741 Ok(HandlerResponse::Sync(Ok(bytes))) => {
742 responder.respond(make_response(
743 200,
744 "application/octet-stream",
745 bytes,
746 ));
747 }
748 Ok(HandlerResponse::Sync(Err(e))) => {
749 let status = error_to_status(&e);
750 responder
751 .respond(make_error_response(status, &sanitize_error(&e)));
752 }
753 Ok(HandlerResponse::Async(future)) => {
754 // Truly async — spawned on tokio, just like #[tauri::command].
755 // Single spawn with catch_unwind for panic isolation.
756 tauri::async_runtime::spawn(async move {
757 let result = std::panic::AssertUnwindSafe(future)
758 .catch_unwind()
759 .await;
760 match result {
761 Ok(Ok(bytes)) => {
762 responder.respond(make_response(
763 200,
764 "application/octet-stream",
765 bytes,
766 ));
767 }
768 Ok(Err(e)) => {
769 let status = error_to_status(&e);
770 responder.respond(make_error_response(
771 status,
772 &sanitize_error(&e),
773 ));
774 }
775 Err(_) => {
776 // Panic during async handler execution
777 responder.respond(make_error_response(
778 500,
779 "handler panicked",
780 ));
781 }
782 }
783 });
784 }
785 Err(_) => {
786 // Panic caught by catch_unwind — keep as 500.
787 responder.respond(make_error_response(500, "handler panicked"));
788 }
789 }
790 } else {
791 // 2) Fall back to legacy sync Router
792 let dispatch = Arc::clone(&state.dispatch);
793 // Use the app_handle reference from state — no clone needed.
794 let app_handle_ref = &state.app_handle;
795 // SAFETY: AssertUnwindSafe is used here because:
796 // - `body` is a Vec<u8> (unwind-safe by itself)
797 // - `dispatch` is an Arc<Router> (unwind-safe)
798 // - `app_handle_ref` borrows from Tauri state (unwind-safe)
799 // - conduit's own locks use poison-recovery helpers (lock_or_recover)
800 // - User-defined handler state may be left inconsistent after panic,
801 // but this is inherent to catch_unwind and documented as a limitation.
802 let result =
803 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
804 dispatch.call_with_context(&target, body, app_handle_ref)
805 }));
806 match result {
807 Ok(Ok(bytes)) => {
808 responder.respond(make_response(
809 200,
810 "application/octet-stream",
811 bytes,
812 ));
813 }
814 Ok(Err(e)) => {
815 let status = error_to_status(&e);
816 responder
817 .respond(make_error_response(status, &sanitize_error(&e)));
818 }
819 Err(_) => {
820 // Panic caught by catch_unwind — keep as 500.
821 responder.respond(make_error_response(500, "handler panicked"));
822 }
823 }
824 }
825 }
826 "drain" => match state.channel(&target) {
827 Some(ch) => {
828 let blob = ch.drain_all();
829 responder.respond(make_response(200, "application/octet-stream", blob));
830 }
831 None => {
832 responder.respond(make_error_response(
833 404,
834 &format!("unknown channel: {}", sanitize_name(&target)),
835 ));
836 }
837 },
838 _ => {
839 responder.respond(make_error_response(
840 404,
841 "not found: expected /invoke/<cmd> or /drain/<channel>",
842 ));
843 }
844 }
845 })
846 // --- Register Tauri IPC commands ---
847 .invoke_handler(tauri::generate_handler![
848 conduit_bootstrap,
849 conduit_subscribe,
850 ])
851 // --- Plugin setup: create state, register commands ---
852 .setup(move |app, _api| {
853 let dispatch = Arc::new(Router::new());
854
855 // Register all old-style commands that were added via the builder.
856 for register_fn in commands {
857 register_fn(&dispatch);
858 }
859
860 // Build the #[command] handler map, checking for collisions
861 // with Router commands.
862 let mut handler_map = HashMap::new();
863 for (name, handler) in handler_defs {
864 if dispatch.has(&name) {
865 #[cfg(debug_assertions)]
866 eprintln!(
867 "conduit: warning: handler '{name}' shadows a Router command \
868 with the same name — the #[command] handler takes priority"
869 );
870 }
871 handler_map.insert(name, handler);
872 }
873 let handlers = Arc::new(handler_map);
874
875 // Create named channels.
876 let mut channels = HashMap::new();
877 for (name, kind) in channel_defs {
878 let buf = match kind {
879 ChannelKind::Lossy(cap) => ChannelBuffer::Lossy(RingBuffer::new(cap)),
880 ChannelKind::Reliable(max_bytes) => {
881 ChannelBuffer::Reliable(Queue::new(max_bytes))
882 }
883 };
884 channels.insert(name, Arc::new(buf));
885 }
886
887 // Generate the per-launch invoke key.
888 let invoke_key_bytes = generate_invoke_key_bytes();
889 let invoke_key = hex_encode(&invoke_key_bytes);
890
891 // Obtain the app handle for emitting events.
892 let app_handle = app.app_handle().clone();
893 let app_handle_arc = Arc::new(app_handle.clone());
894
895 let state = PluginState {
896 dispatch,
897 handlers,
898 channels,
899 app_handle,
900 app_handle_arc,
901 invoke_key,
902 invoke_key_bytes,
903 };
904
905 app.manage(state);
906
907 Ok(())
908 })
909 .build()
910 }
911}
912
913impl Default for PluginBuilder {
914 fn default() -> Self {
915 Self::new()
916 }
917}
918
919// ---------------------------------------------------------------------------
920// Public init function
921// ---------------------------------------------------------------------------
922
923/// Create a new conduit plugin builder.
924///
925/// This is the main entry point for using the conduit Tauri plugin:
926///
927/// ```rust,ignore
928/// use tauri_conduit::command;
929///
930/// #[command]
931/// fn greet(name: String) -> String {
932/// format!("Hello, {name}!")
933/// }
934///
935/// #[command]
936/// async fn fetch_data(url: String) -> Result<Vec<u8>, String> {
937/// reqwest::get(&url).await.map_err(|e| e.to_string())?
938/// .bytes().await.map(|b| b.to_vec()).map_err(|e| e.to_string())
939/// }
940///
941/// tauri::Builder::default()
942/// .plugin(
943/// tauri_plugin_conduit::init()
944/// .handler("greet", handler!(greet))
945/// .handler("fetch_data", handler!(fetch_data))
946/// .channel("telemetry")
947/// .build()
948/// )
949/// .run(tauri::generate_context!())
950/// .unwrap();
951/// ```
952pub fn init() -> PluginBuilder {
953 PluginBuilder::new()
954}
955
956// ---------------------------------------------------------------------------
957// Helpers
958// ---------------------------------------------------------------------------
959
960/// Map a [`conduit_core::Error`] to the appropriate HTTP status code.
961fn error_to_status(e: &conduit_core::Error) -> u16 {
962 match e {
963 conduit_core::Error::UnknownCommand(_) => 404,
964 conduit_core::Error::UnknownChannel(_) => 404,
965 conduit_core::Error::AuthFailed => 403,
966 conduit_core::Error::DecodeFailed => 400,
967 conduit_core::Error::PayloadTooLarge(_) => 413,
968 conduit_core::Error::Handler(_) => 500,
969 conduit_core::Error::Serialize(_) => 500,
970 conduit_core::Error::ChannelFull => 500,
971 }
972}
973
/// Make a user-supplied name safe to echo back in logs and error messages.
///
/// The input is first cut to at most 64 bytes. The cut point moves backwards
/// as needed so it never splits a UTF-8 character. Every character for which
/// [`char::is_control`] holds is then removed. Truncation happens *before*
/// stripping, so control characters count toward the 64-byte budget.
fn sanitize_name(name: &str) -> String {
    const MAX_BYTES: usize = 64;

    let cut = if name.len() <= MAX_BYTES {
        name.len()
    } else {
        // Index 0 is always a char boundary, so `find` cannot come up empty.
        (0..=MAX_BYTES)
            .rev()
            .find(|&idx| name.is_char_boundary(idx))
            .unwrap_or(0)
    };

    name[..cut].chars().filter(|ch| !ch.is_control()).collect()
}
992
993/// Format a [`conduit_core::Error`] for inclusion in HTTP error responses,
994/// sanitizing any embedded user-supplied names (command or channel names).
995fn sanitize_error(e: &conduit_core::Error) -> String {
996 match e {
997 conduit_core::Error::UnknownCommand(name) => {
998 format!("unknown command: {}", sanitize_name(name))
999 }
1000 conduit_core::Error::UnknownChannel(name) => {
1001 format!("unknown channel: {}", sanitize_name(name))
1002 }
1003 other => other.to_string(),
1004 }
1005}
1006
1007/// Percent-decode a URL path segment (e.g., `hello%20world` → `hello world`).
1008///
1009/// Returns `Cow::Borrowed` when no percent-encoding is present (the common
1010/// case), avoiding a heap allocation entirely.
1011fn percent_decode(input: &str) -> std::borrow::Cow<'_, str> {
1012 // Fast path: no percent-encoded characters — return the input as-is.
1013 if !input.as_bytes().contains(&b'%') {
1014 return std::borrow::Cow::Borrowed(input);
1015 }
1016 let mut result = Vec::with_capacity(input.len());
1017 let bytes = input.as_bytes();
1018 let mut i = 0;
1019 while i < bytes.len() {
1020 if bytes[i] == b'%' && i + 2 < bytes.len() {
1021 if let (Some(hi), Some(lo)) = (hex_val(bytes[i + 1]), hex_val(bytes[i + 2])) {
1022 result.push(hi << 4 | lo);
1023 i += 3;
1024 continue;
1025 }
1026 }
1027 result.push(bytes[i]);
1028 i += 1;
1029 }
1030 std::borrow::Cow::Owned(String::from_utf8_lossy(&result).into_owned())
1031}
1032
/// Convert a single ASCII hex character (`0-9`, `a-f`, `A-F`) to its 4-bit
/// numeric value, or `None` for any other byte.
///
/// Unlike [`hex_digit_ct`], this does NOT need to be constant-time — it is
/// used for URL percent-decoding, not security-critical key validation.
fn hex_val(b: u8) -> Option<u8> {
    // `char::to_digit` only recognises ASCII digits and letters, so with
    // radix 16 it accepts exactly `0-9`, `a-f`, `A-F`. Bytes >= 0x80 map to
    // non-ASCII chars and yield `None`. The result is <= 15, so the cast is
    // lossless.
    char::from(b).to_digit(16).map(|digit| digit as u8)
}
1045
1046/// Generate 32 random bytes for the per-launch invoke key.
1047fn generate_invoke_key_bytes() -> [u8; 32] {
1048 let mut bytes = [0u8; 32];
1049 getrandom::fill(&mut bytes).expect("conduit: failed to generate invoke key");
1050 bytes
1051}
1052
/// Encode `bytes` as lowercase hexadecimal, two characters per byte.
///
/// The output `String` is allocated once, with exact capacity, up front.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut out = String::with_capacity(bytes.len() * 2);
    out.extend(bytes.iter().flat_map(|&byte| {
        [
            char::from(DIGITS[usize::from(byte >> 4)]),
            char::from(DIGITS[usize::from(byte & 0x0f)]),
        ]
    }));
    out
}
1063
/// Hex-decode a string into bytes.
///
/// Returns `None` if the input has an odd length or contains any character
/// other than `0-9`, `a-f`, `A-F`.
///
/// Test-only, non-constant-time helper. Invoke-key validation uses the
/// branch-free [`hex_digit_ct`] inside [`validate_invoke_key_ct`] instead.
#[cfg(test)]
fn hex_decode(hex: &str) -> Option<Vec<u8>> {
    let digits = hex.as_bytes();
    if digits.len() % 2 == 1 {
        return None;
    }
    // Collecting into `Option<Vec<_>>` stops at the first invalid pair.
    digits
        .chunks_exact(2)
        .map(|pair| Some((hex_digit(pair[0])? << 4) | hex_digit(pair[1])?))
        .collect()
}
1082
/// Convert a single ASCII hex character (`0-9`, `a-f`, `A-F`) to its 4-bit
/// numeric value, or `None` for any other byte.
///
/// Test-only. It uses ordinary branching; the tests compare it byte-for-byte
/// against [`hex_digit_ct`].
#[cfg(test)]
fn hex_digit(b: u8) -> Option<u8> {
    // Pick the ASCII offset that maps the matched range onto its value.
    let base = match b {
        b'0'..=b'9' => b'0',
        b'a'..=b'f' => b'a' - 10,
        b'A'..=b'F' => b'A' - 10,
        _ => return None,
    };
    Some(b - base)
}
1093
/// Validate an invoke key candidate using constant-time operations.
///
/// Returns `true` only when `candidate` is exactly 64 ASCII hex characters
/// (either case) that decode to the same 32 bytes as `expected`.
///
/// The length check (must be exactly 64 hex chars) is not constant-time
/// because the expected length is public knowledge. The hex decode and the
/// byte comparison have no data-dependent early exits: invalid characters do
/// not stop the loop, and the comparison always runs even if decoding failed.
///
/// NOTE(review): Rust/LLVM give no formal constant-time guarantee. This relies
/// on the branch-free [`hex_digit_ct`] plus `ct_eq`, so re-check the generated
/// code if the threat model demands it.
fn validate_invoke_key_ct(expected: &[u8; 32], candidate: &str) -> bool {
    let candidate_bytes = candidate.as_bytes();

    // Length is not secret — always 64 hex chars for 32 bytes.
    if candidate_bytes.len() != 64 {
        return false;
    }

    // Constant-time hex decode: always process all 32 byte pairs.
    let mut decoded = [0u8; 32];
    // Stays 1 only if every digit was valid. It is accumulated with `&=` so
    // validity never drives control flow inside the loop.
    let mut all_valid = 1u8;

    for i in 0..32 {
        let (hi_val, hi_ok) = hex_digit_ct(candidate_bytes[i * 2]);
        let (lo_val, lo_ok) = hex_digit_ct(candidate_bytes[i * 2 + 1]);
        // Invalid digits decode to 0; such keys are still rejected below
        // through `all_valid`.
        decoded[i] = (hi_val << 4) | lo_val;
        all_valid &= hi_ok & lo_ok;
    }

    // Always compare, even if some hex digits were invalid.
    let cmp_ok: bool = expected.ct_eq(&decoded).into();

    // Combine with bitwise AND — no short-circuit.
    (all_valid == 1) & cmp_ok
}
1125
/// Constant-time hex digit decode for security-critical paths.
///
/// Returns `(value, valid)`. `valid` is `1` if the byte is a hex character
/// (`0-9`, `a-f`, `A-F`) and `0` otherwise. When invalid, `value` is `0`.
/// The source uses only bitwise and wrapping arithmetic: no comparison
/// operators and no branches.
///
/// Uses subtraction + sign-bit masking on `i16` to produce range-check
/// masks without any comparison operators that could compile to branches.
///
/// NOTE(review): "branch-free" describes the source only. Rust/LLVM give no
/// formal constant-time guarantee, so inspect generated code if that matters.
fn hex_digit_ct(b: u8) -> (u8, u8) {
    // Promote to i16 so wrapping_sub produces a sign bit we can extract.
    let b = b as i16;

    // Check if b is in '0'..='9' (0x30..=0x39)
    let d = b.wrapping_sub(0x30); // b - '0'
    // d >= 0 && d < 10: (!d) is negative iff d >= 0; (d-10) is negative iff d < 10.
    // Combining via AND and extracting the sign bit gives us a mask.
    // `>> 15` on a signed i16 is arithmetic, giving -1 (all ones) or 0;
    // `& 1` then reduces that to 1 or 0.
    let digit_mask = ((!d) & (d.wrapping_sub(10))) >> 15;
    let digit_mask = (digit_mask & 1) as u8;

    // Check if b is in 'a'..='f' (0x61..=0x66)
    let l = b.wrapping_sub(0x61); // b - 'a'
    let lower_mask = ((!l) & (l.wrapping_sub(6))) >> 15;
    let lower_mask = (lower_mask & 1) as u8;

    // Check if b is in 'A'..='F' (0x41..=0x46)
    let u = b.wrapping_sub(0x41); // b - 'A'
    let upper_mask = ((!u) & (u.wrapping_sub(6))) >> 15;
    let upper_mask = (upper_mask & 1) as u8;

    // `mask.wrapping_neg()` widens a 0/1 mask to 0x00/0xFF so `&` can select
    // the candidate value for that range. The three ranges are disjoint, so
    // at most one term is non-zero and the wrapping adds behave like OR.
    let val = ((d as u8 & 0x0f) & digit_mask.wrapping_neg())
        .wrapping_add((l as u8).wrapping_add(10) & lower_mask.wrapping_neg())
        .wrapping_add((u as u8).wrapping_add(10) & upper_mask.wrapping_neg());
    // At most one mask is 1, so this is 1 for any hex digit and 0 otherwise.
    let valid = digit_mask | lower_mask | upper_mask;

    (val, valid)
}
1162
#[cfg(test)]
mod tests {
    //! Unit tests for the helper functions, the `#[command]` state-injection
    //! path, and `PluginBuilder` channel validation.

    use super::*;

    // -- hex encode/decode tests --

    #[test]
    fn hex_encode_roundtrip() {
        assert_eq!(hex_encode(&[0xde, 0xad, 0xbe, 0xef]), "deadbeef");
        assert_eq!(hex_encode(&[]), "");
        assert_eq!(hex_encode(&[0x00, 0xff]), "00ff");
    }

    #[test]
    fn hex_decode_valid() {
        assert_eq!(hex_decode("deadbeef"), Some(vec![0xde, 0xad, 0xbe, 0xef]));
        assert_eq!(hex_decode(""), Some(vec![]));
        assert_eq!(hex_decode("00ff"), Some(vec![0x00, 0xff]));
    }

    #[test]
    fn hex_decode_uppercase() {
        assert_eq!(hex_decode("DEADBEEF"), Some(vec![0xde, 0xad, 0xbe, 0xef]));
        assert_eq!(hex_decode("DeAdBeEf"), Some(vec![0xde, 0xad, 0xbe, 0xef]));
    }

    #[test]
    fn hex_decode_odd_length() {
        assert_eq!(hex_decode("abc"), None);
        assert_eq!(hex_decode("a"), None);
    }

    #[test]
    fn hex_decode_invalid_chars() {
        assert_eq!(hex_decode("zz"), None);
        assert_eq!(hex_decode("gg"), None);
        assert_eq!(hex_decode("0x"), None);
    }

    #[test]
    fn hex_roundtrip_32_bytes() {
        let original = generate_invoke_key_bytes();
        let encoded = hex_encode(&original);
        assert_eq!(encoded.len(), 64);
        let decoded = hex_decode(&encoded).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn hex_val_matches_hex_digit() {
        // The percent-decoding helper and the test-only reference decoder
        // must agree on every possible byte.
        for b in 0..=255u8 {
            assert_eq!(hex_val(b), hex_digit(b), "mismatch at {b}");
        }
    }

    // -- constant-time hex tests --

    #[test]
    fn hex_digit_ct_valid_chars() {
        for b in b'0'..=b'9' {
            let (val, valid) = hex_digit_ct(b);
            assert_eq!(valid, 1, "digit {b} should be valid");
            assert_eq!(val, b - b'0');
        }
        for b in b'a'..=b'f' {
            let (val, valid) = hex_digit_ct(b);
            assert_eq!(valid, 1, "lower {b} should be valid");
            assert_eq!(val, b - b'a' + 10);
        }
        for b in b'A'..=b'F' {
            let (val, valid) = hex_digit_ct(b);
            assert_eq!(valid, 1, "upper {b} should be valid");
            assert_eq!(val, b - b'A' + 10);
        }
    }

    #[test]
    fn hex_digit_ct_invalid_chars() {
        for &b in &[b'g', b'z', b'G', b'Z', b' ', b'\0', b'/', b':', b'@', b'`'] {
            let (_val, valid) = hex_digit_ct(b);
            assert_eq!(valid, 0, "char {b} should be invalid");
        }
    }

    #[test]
    fn hex_digit_ct_matches_hex_digit() {
        for b in 0..=255u8 {
            let ct_result = hex_digit_ct(b);
            let std_result = hex_digit(b);
            match std_result {
                Some(v) => {
                    assert_eq!(ct_result.1, 1, "mismatch at {b}: ct says invalid");
                    assert_eq!(ct_result.0, v, "value mismatch at {b}");
                }
                None => {
                    assert_eq!(ct_result.1, 0, "mismatch at {b}: ct says valid");
                }
            }
        }
    }

    // -- make_response tests --

    #[test]
    fn make_response_200() {
        let resp = make_response(200, "application/octet-stream", b"hello".to_vec());
        assert_eq!(resp.status(), 200);
        assert_eq!(resp.body(), b"hello");
    }

    #[test]
    fn make_response_404() {
        let resp = make_response(404, "text/plain", b"not found".to_vec());
        assert_eq!(resp.status(), 404);
        assert_eq!(resp.body(), b"not found");
    }

    // -- State<T> injection tests --

    #[command]
    fn with_state(state: tauri::State<'_, String>, name: String) -> String {
        format!("{}: {name}", state.as_str())
    }

    #[test]
    fn state_injection_wrong_context_returns_error() {
        use conduit_core::ConduitHandler;
        use conduit_derive::handler;

        let payload = serde_json::to_vec(&serde_json::json!({ "name": "test" })).unwrap();
        let wrong_ctx: Arc<dyn std::any::Any + Send + Sync> = Arc::new(());

        match handler!(with_state).call(payload, wrong_ctx) {
            conduit_core::HandlerResponse::Sync(Err(conduit_core::Error::Handler(msg))) => {
                assert!(
                    msg.contains("handler context must be HandlerContext"),
                    "unexpected error message: {msg}"
                );
            }
            _ => panic!("expected Sync(Err(Handler))"),
        }
    }

    #[test]
    fn original_state_function_preserved() {
        // The original function with_state is preserved and callable directly.
        // We can't call it without an actual Tauri State, but we can verify
        // the function exists and has the right signature by taking a reference.
        let _fn_ref: fn(tauri::State<'_, String>, String) -> String = with_state;
    }

    // -- validate_invoke_key tests --

    #[test]
    fn validate_invoke_key_correct() {
        let key = [0xab_u8; 32];
        let hex = hex_encode(&key);
        assert!(validate_invoke_key_ct(&key, &hex));
    }

    #[test]
    fn validate_invoke_key_wrong_key() {
        let key = [0xab_u8; 32];
        let wrong = hex_encode(&[0x00_u8; 32]);
        assert!(!validate_invoke_key_ct(&key, &wrong));
    }

    #[test]
    fn validate_invoke_key_wrong_length() {
        let key = [0xab_u8; 32];
        assert!(!validate_invoke_key_ct(&key, "abcdef"));
        assert!(!validate_invoke_key_ct(&key, ""));
        assert!(!validate_invoke_key_ct(&key, &"a".repeat(63)));
        assert!(!validate_invoke_key_ct(&key, &"a".repeat(65)));
    }

    #[test]
    fn validate_invoke_key_invalid_hex() {
        let key = [0xab_u8; 32];
        // 64 chars but invalid hex
        assert!(!validate_invoke_key_ct(&key, &"zz".repeat(32)));
        assert!(!validate_invoke_key_ct(&key, &"gg".repeat(32)));
    }

    #[test]
    fn validate_invoke_key_uppercase_accepted() {
        let key = [0xab_u8; 32];
        let hex = hex_encode(&key);
        // hex_digit_ct handles uppercase, so uppercase of a valid key should match
        assert!(validate_invoke_key_ct(&key, &hex.to_uppercase()));
    }

    #[test]
    fn validate_invoke_key_random_roundtrip() {
        let key = generate_invoke_key_bytes();
        let hex = hex_encode(&key);
        assert!(validate_invoke_key_ct(&key, &hex));
    }

    // -- make_error_response tests --

    #[test]
    fn make_error_response_json_format() {
        let resp = make_error_response(500, "something failed");
        assert_eq!(resp.status(), 500);
        let body: serde_json::Value = serde_json::from_slice(resp.body()).unwrap();
        assert_eq!(body["error"], "something failed");
    }

    #[test]
    fn make_error_response_escapes_special_chars() {
        let resp = make_error_response(400, r#"bad "input" with \ slash"#);
        let body: serde_json::Value = serde_json::from_slice(resp.body()).unwrap();
        assert_eq!(body["error"], r#"bad "input" with \ slash"#);
    }

    // -- percent_decode tests --

    #[test]
    fn percent_decode_no_encoding() {
        assert_eq!(percent_decode("hello"), "hello");
        assert_eq!(percent_decode("foo-bar_baz"), "foo-bar_baz");
    }

    #[test]
    fn percent_decode_borrows_when_unencoded() {
        // The fast path must not allocate when there is nothing to decode.
        assert!(matches!(
            percent_decode("hello"),
            std::borrow::Cow::Borrowed("hello")
        ));
    }

    #[test]
    fn percent_decode_basic() {
        assert_eq!(percent_decode("hello%20world"), "hello world");
        assert_eq!(percent_decode("%2F"), "/");
        assert_eq!(percent_decode("%2f"), "/");
    }

    #[test]
    fn percent_decode_multiple() {
        assert_eq!(percent_decode("a%20b%20c"), "a b c");
        assert_eq!(percent_decode("%41%42%43"), "ABC");
    }

    #[test]
    fn percent_decode_multibyte_utf8() {
        // Escapes that together form a valid multi-byte UTF-8 sequence.
        assert_eq!(percent_decode("caf%C3%A9"), "café");
    }

    #[test]
    fn percent_decode_invalid_utf8_is_lossy() {
        // Decoded bytes that are not valid UTF-8 become U+FFFD.
        assert_eq!(percent_decode("a%FFb"), "a\u{FFFD}b");
        assert_eq!(percent_decode("x%C3"), "x\u{FFFD}");
    }

    #[test]
    fn percent_decode_incomplete_sequence() {
        // Incomplete %XX at end — pass through unchanged.
        assert_eq!(percent_decode("hello%2"), "hello%2");
        assert_eq!(percent_decode("hello%"), "hello%");
    }

    #[test]
    fn percent_decode_invalid_hex() {
        // Invalid hex chars after % — pass through unchanged.
        assert_eq!(percent_decode("hello%GG"), "hello%GG");
        assert_eq!(percent_decode("%ZZ"), "%ZZ");
    }

    #[test]
    fn percent_decode_empty() {
        assert_eq!(percent_decode(""), "");
    }

    // -- sanitize_name tests --

    #[test]
    fn sanitize_name_short() {
        assert_eq!(sanitize_name("hello"), "hello");
    }

    #[test]
    fn sanitize_name_truncates_long() {
        let long = "a".repeat(100);
        assert_eq!(sanitize_name(&long).len(), 64);
    }

    #[test]
    fn sanitize_name_strips_control_chars() {
        assert_eq!(sanitize_name("hello\x00world"), "helloworld");
        assert_eq!(sanitize_name("foo\nbar\rbaz"), "foobarbaz");
    }

    #[test]
    fn sanitize_name_multibyte_utf8() {
        // "a" repeated 63 times + "é" (2 bytes: 0xC3 0xA9) = 65 bytes total.
        // Byte 64 is the second byte of "é", not a char boundary.
        // Must not panic — should truncate to the last valid boundary (63 'a's).
        let name = format!("{}{}", "a".repeat(63), "é");
        assert_eq!(name.len(), 65);
        let sanitized = sanitize_name(&name);
        assert_eq!(sanitized, "a".repeat(63));

        // 4-byte character crossing the 64-byte boundary.
        let name = format!("{}🦀", "a".repeat(62)); // 62 + 4 = 66 bytes
        assert_eq!(name.len(), 66);
        let sanitized = sanitize_name(&name);
        assert_eq!(sanitized, "a".repeat(62));

        // Exactly 64 bytes of ASCII — no truncation needed.
        let name = "a".repeat(64);
        assert_eq!(sanitize_name(&name), "a".repeat(64));
    }

    // -- sanitize_error tests --

    #[test]
    fn sanitize_error_sanitizes_embedded_names() {
        use conduit_core::Error;
        assert_eq!(
            sanitize_error(&Error::UnknownCommand("evil\nname".into())),
            "unknown command: evilname"
        );
        // ESC (0x1b) and CR are control characters and must be stripped.
        assert_eq!(
            sanitize_error(&Error::UnknownChannel("bad\r\x1b[31mchan".into())),
            "unknown channel: bad[31mchan"
        );
    }

    #[test]
    fn sanitize_error_passes_through_other_variants() {
        use conduit_core::Error;
        assert_eq!(sanitize_error(&Error::AuthFailed), Error::AuthFailed.to_string());
        assert_eq!(sanitize_error(&Error::DecodeFailed), Error::DecodeFailed.to_string());
    }

    // -- error_to_status tests --

    #[test]
    fn error_to_status_mapping() {
        use conduit_core::Error;
        assert_eq!(error_to_status(&Error::UnknownCommand("x".into())), 404);
        assert_eq!(error_to_status(&Error::UnknownChannel("x".into())), 404);
        assert_eq!(error_to_status(&Error::AuthFailed), 403);
        assert_eq!(error_to_status(&Error::DecodeFailed), 400);
        assert_eq!(error_to_status(&Error::PayloadTooLarge(999)), 413);
        assert_eq!(error_to_status(&Error::Handler("x".into())), 500);
        assert_eq!(error_to_status(&Error::ChannelFull), 500);
    }

    // -- channel validation tests --

    #[test]
    fn validate_channel_name_valid() {
        validate_channel_name("telemetry");
        validate_channel_name("my-channel");
        validate_channel_name("my_channel");
        validate_channel_name("Channel123");
        validate_channel_name("a");
    }

    #[test]
    #[should_panic(expected = "invalid channel name")]
    fn validate_channel_name_empty() {
        validate_channel_name("");
    }

    #[test]
    #[should_panic(expected = "invalid channel name")]
    fn validate_channel_name_spaces() {
        validate_channel_name("my channel");
    }

    #[test]
    #[should_panic(expected = "invalid channel name")]
    fn validate_channel_name_special_chars() {
        validate_channel_name("my.channel");
    }

    #[test]
    #[should_panic(expected = "duplicate channel name")]
    fn duplicate_channel_panics() {
        PluginBuilder::new()
            .channel("telemetry")
            .channel("telemetry");
    }

    #[test]
    #[should_panic(expected = "duplicate channel name")]
    fn duplicate_channel_different_kinds_panics() {
        PluginBuilder::new().channel("data").channel_ordered("data");
    }
}