1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use serde_json::Value;
6use tokio::io::AsyncWrite;
7use tokio::sync::{Mutex, oneshot};
8
/// How long `BridgeDispatch::dispatch` waits for a reply before reporting a timeout.
const DISPATCH_TIMEOUT: Duration = Duration::from_secs(30);

/// Maximum number of commands that may await a reply at once; `dispatch`
/// refuses new commands while this many are outstanding.
const MAX_PENDING: usize = 1_024;
15
16pub struct BridgeDispatch {
21 pending: Arc<Mutex<HashMap<String, oneshot::Sender<DispatchResult>>>>,
22 writer: Arc<Mutex<Box<dyn AsyncWrite + Send + Unpin>>>,
29}
30
31#[derive(Debug)]
32pub struct DispatchResult {
33 pub data: Option<Value>,
34 pub error: Option<String>,
35}
36
37impl BridgeDispatch {
38 #[must_use]
39 pub fn new<W: AsyncWrite + Send + Unpin + 'static>(writer: W) -> Self {
40 Self {
41 pending: Arc::new(Mutex::new(HashMap::new())),
42 writer: Arc::new(Mutex::new(Box::new(writer))),
43 }
44 }
45
46 #[must_use]
51 pub fn new_sink() -> Self {
52 Self::new(tokio::io::sink())
53 }
54
55 pub async fn dispatch(
62 &self,
63 tab_id: Option<u32>,
64 method: &str,
65 args: Value,
66 ) -> Result<Value, String> {
67 let id = uuid::Uuid::new_v4().to_string();
68
69 let (tx, rx) = oneshot::channel();
70 {
71 let mut pending = self.pending.lock().await;
72 if pending.len() >= MAX_PENDING {
73 return Err(format!(
74 "too many in-flight commands ({MAX_PENDING}); extension unresponsive"
75 ));
76 }
77 pending.insert(id.clone(), tx);
78 }
79
80 let msg = serde_json::json!({
81 "id": id,
82 "type": "execute",
83 "tab_id": tab_id,
84 "method": method,
85 "args": args,
86 });
87
88 {
89 let mut writer = self.writer.lock().await;
90 crate::native_messaging::write_message(&mut *writer, &msg)
91 .await
92 .map_err(|e| format!("native messaging write failed: {e}"))?;
93 }
94
95 match tokio::time::timeout(DISPATCH_TIMEOUT, rx).await {
96 Ok(Ok(result)) => {
97 if let Some(err) = result.error {
98 Err(err)
99 } else {
100 Ok(result.data.unwrap_or(Value::Null))
101 }
102 }
103 Ok(Err(_)) => {
104 self.cleanup_pending(&id).await;
105 Err("extension disconnected while waiting for response".to_string())
106 }
107 Err(_) => {
108 self.cleanup_pending(&id).await;
109 Err(format!(
110 "timeout ({DISPATCH_TIMEOUT:?}) waiting for {method}"
111 ))
112 }
113 }
114 }
115
116 pub async fn on_response(&self, id: &str, data: Option<Value>, error: Option<String>) {
118 let mut pending = self.pending.lock().await;
119 if let Some(tx) = pending.remove(id) {
120 let _ = tx.send(DispatchResult { data, error });
121 }
122 }
123
124 pub async fn cancel_all(&self) {
126 let mut pending = self.pending.lock().await;
127 for (_, tx) in pending.drain() {
128 let _ = tx.send(DispatchResult {
129 data: None,
130 error: Some("extension disconnected".to_string()),
131 });
132 }
133 }
134
135 #[must_use]
136 #[allow(dead_code)]
137 pub async fn pending_count(&self) -> usize {
138 self.pending.lock().await.len()
139 }
140
141 async fn cleanup_pending(&self, id: &str) {
142 let mut pending = self.pending.lock().await;
143 pending.remove(id);
144 }
145
146 pub async fn pending_ids(&self) -> Vec<String> {
148 self.pending.lock().await.keys().cloned().collect()
149 }
150
151 pub async fn register_test_pending(&self, id: &str) -> oneshot::Receiver<DispatchResult> {
153 let (tx, rx) = oneshot::channel();
154 self.pending.lock().await.insert(id.to_string(), tx);
155 rx
156 }
157}
158
#[cfg(test)]
mod tests {
    // The glob also brings in the parent's private imports (`Arc`, ...) and
    // constants such as `MAX_PENDING`, so no separate `use std::sync::Arc` is needed.
    use super::*;
    use serde_json::json;

    /// Polls until a concurrently running `dispatch` call has registered its
    /// request, then returns that request's id.
    async fn wait_for_pending_id(dispatch: &BridgeDispatch) -> String {
        for _ in 0..1_000 {
            if let Some(id) = dispatch.pending_ids().await.pop() {
                return id;
            }
            tokio::task::yield_now().await;
        }
        panic!("dispatch never registered a pending request");
    }

    #[tokio::test]
    async fn on_response_resolves_pending() {
        let dispatch = BridgeDispatch::new_sink();
        let rx = dispatch.register_test_pending("test-123").await;

        dispatch
            .on_response("test-123", Some(json!({"ok": true})), None)
            .await;

        let result = rx.await.unwrap();
        assert!(result.error.is_none());
        assert_eq!(result.data.unwrap(), json!({"ok": true}));
    }

    #[tokio::test]
    async fn on_response_with_error() {
        let dispatch = BridgeDispatch::new_sink();
        let rx = dispatch.register_test_pending("test-456").await;

        dispatch
            .on_response("test-456", None, Some("bridge timeout".to_string()))
            .await;

        let result = rx.await.unwrap();
        assert_eq!(result.error.unwrap(), "bridge timeout");
    }

    #[tokio::test]
    async fn cancel_all_resolves_pending() {
        let dispatch = BridgeDispatch::new_sink();
        let rx = dispatch.register_test_pending("test-789").await;

        dispatch.cancel_all().await;

        let result = rx.await.unwrap();
        assert!(result.error.is_some());
        assert_eq!(dispatch.pending_count().await, 0);
    }

    #[tokio::test]
    async fn unknown_response_id_ignored() {
        let dispatch = BridgeDispatch::new_sink();

        dispatch.on_response("nonexistent", Some(json!({})), None).await;

        assert_eq!(dispatch.pending_count().await, 0);
    }

    #[tokio::test]
    async fn pending_count_tracks_insertions() {
        let dispatch = BridgeDispatch::new_sink();
        assert_eq!(dispatch.pending_count().await, 0);

        // Keep both receivers alive for the whole test.
        let _rx_a = dispatch.register_test_pending("a").await;
        let _rx_b = dispatch.register_test_pending("b").await;
        assert_eq!(dispatch.pending_count().await, 2);

        dispatch.on_response("a", Some(json!({"ok": true})), None).await;
        assert_eq!(dispatch.pending_count().await, 1);
    }

    #[tokio::test]
    async fn on_response_with_null_data_and_no_error() {
        let dispatch = BridgeDispatch::new_sink();
        let rx = dispatch.register_test_pending("test-null").await;

        dispatch.on_response("test-null", None, None).await;

        let result = rx.await.unwrap();
        assert!(result.data.is_none());
        assert!(result.error.is_none());
    }

    #[tokio::test]
    async fn cancel_all_with_multiple_pending() {
        let dispatch = BridgeDispatch::new_sink();
        let mut receivers = Vec::new();
        for id in ["a", "b", "c"] {
            receivers.push(dispatch.register_test_pending(id).await);
        }

        dispatch.cancel_all().await;
        assert_eq!(dispatch.pending_count().await, 0);

        for rx in receivers {
            let error = rx
                .await
                .unwrap()
                .error
                .expect("cancel_all must report an error");
            assert!(error.contains("disconnected"));
        }
    }

    #[tokio::test]
    async fn cancel_all_on_empty_is_noop() {
        let dispatch = BridgeDispatch::new_sink();
        dispatch.cancel_all().await;
        assert_eq!(dispatch.pending_count().await, 0);
    }

    #[tokio::test]
    async fn concurrent_100_pending_insertions_and_resolutions() {
        let dispatch = Arc::new(BridgeDispatch::new_sink());

        let mut receivers = Vec::with_capacity(100);
        for i in 0..100 {
            let rx = dispatch.register_test_pending(&format!("stress-{i}")).await;
            receivers.push((i, rx));
        }
        assert_eq!(dispatch.pending_count().await, 100);

        let handles: Vec<_> = (0..100)
            .map(|i| {
                let d = Arc::clone(&dispatch);
                tokio::spawn(async move {
                    d.on_response(&format!("stress-{i}"), Some(json!({"idx": i})), None)
                        .await;
                })
            })
            .collect();
        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(dispatch.pending_count().await, 0);
        for (i, rx) in receivers {
            let result = rx.await.unwrap();
            assert_eq!(result.data.unwrap()["idx"], i);
        }
    }

    #[tokio::test]
    async fn resolve_after_cancel_all_is_noop() {
        let dispatch = BridgeDispatch::new_sink();
        let _rx = dispatch.register_test_pending("doomed").await;

        dispatch.cancel_all().await;

        // A late reply for an id that cancel_all already resolved is ignored.
        dispatch
            .on_response("doomed", Some(json!({"late": true})), None)
            .await;
        assert_eq!(dispatch.pending_count().await, 0);
    }

    #[tokio::test]
    async fn duplicate_id_response_only_resolves_once() {
        let dispatch = BridgeDispatch::new_sink();
        let rx = dispatch.register_test_pending("dup").await;

        dispatch
            .on_response("dup", Some(json!({"first": true})), None)
            .await;
        dispatch
            .on_response("dup", Some(json!({"second": true})), None)
            .await;

        let result = rx.await.unwrap();
        assert_eq!(result.data.unwrap()["first"], true);
    }

    #[tokio::test]
    async fn cancel_all_then_insert_new() {
        let dispatch = BridgeDispatch::new_sink();
        let rx_before = dispatch.register_test_pending("before").await;

        dispatch.cancel_all().await;
        assert!(rx_before.await.unwrap().error.is_some());

        let rx_after = dispatch.register_test_pending("after").await;
        assert_eq!(dispatch.pending_count().await, 1);

        dispatch
            .on_response("after", Some(json!({"ok": true})), None)
            .await;
        assert_eq!(rx_after.await.unwrap().data.unwrap()["ok"], true);
    }

    #[tokio::test]
    async fn concurrent_cancel_and_resolve_race() {
        let dispatch = Arc::new(BridgeDispatch::new_sink());

        for i in 0..50 {
            // The receiver is dropped at the end of each iteration, so later
            // sends to it fail; both paths must tolerate that silently.
            let _rx = dispatch.register_test_pending(&format!("race-{i}")).await;
        }

        let d1 = Arc::clone(&dispatch);
        let cancel_task = tokio::spawn(async move {
            d1.cancel_all().await;
        });

        let d2 = Arc::clone(&dispatch);
        let resolve_task = tokio::spawn(async move {
            for i in 0..50 {
                d2.on_response(&format!("race-{i}"), Some(json!({})), None)
                    .await;
            }
        });

        cancel_task.await.unwrap();
        resolve_task.await.unwrap();

        assert_eq!(dispatch.pending_count().await, 0);
    }

    #[tokio::test]
    async fn on_response_with_both_data_and_error() {
        let dispatch = BridgeDispatch::new_sink();
        let rx = dispatch.register_test_pending("both").await;

        dispatch
            .on_response(
                "both",
                Some(json!({"partial": true})),
                Some("also an error".to_string()),
            )
            .await;

        let result = rx.await.unwrap();
        assert!(result.data.is_some());
        assert!(result.error.is_some());
    }

    #[tokio::test]
    async fn dispatch_rejects_when_at_capacity() {
        let dispatch = BridgeDispatch::new_sink();
        let mut receivers = Vec::with_capacity(MAX_PENDING);
        for i in 0..MAX_PENDING {
            receivers.push(dispatch.register_test_pending(&format!("fill-{i}")).await);
        }

        let err = dispatch
            .dispatch(None, "noop", json!({}))
            .await
            .expect_err("dispatch must refuse once MAX_PENDING is reached");
        assert!(err.contains("too many in-flight commands"));
        assert_eq!(dispatch.pending_count().await, MAX_PENDING);
    }

    #[tokio::test]
    async fn dispatch_returns_reply_delivered_via_on_response() {
        let dispatch = BridgeDispatch::new_sink();

        let call = dispatch.dispatch(Some(7), "ping", json!({"x": 1}));
        let reply = async {
            let id = wait_for_pending_id(&dispatch).await;
            dispatch
                .on_response(&id, Some(json!({"pong": true})), None)
                .await;
        };

        let (result, ()) = tokio::join!(call, reply);
        assert_eq!(result.unwrap(), json!({"pong": true}));
        assert_eq!(dispatch.pending_count().await, 0);
    }

    #[tokio::test]
    async fn dispatch_returns_error_reported_by_reply() {
        let dispatch = BridgeDispatch::new_sink();

        let call = dispatch.dispatch(None, "click", json!({}));
        let reply = async {
            let id = wait_for_pending_id(&dispatch).await;
            dispatch
                .on_response(&id, Some(json!({"ignored": true})), Some("boom".to_string()))
                .await;
        };

        let (result, ()) = tokio::join!(call, reply);
        assert_eq!(result.unwrap_err(), "boom");
        assert_eq!(dispatch.pending_count().await, 0);
    }
}