1use std::path::PathBuf;
7use std::process::Stdio;
8use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
9use std::sync::Arc;
10use std::collections::{HashMap, HashSet};
11
12use async_trait::async_trait;
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
16use tokio::process::{Child, ChildStdin, ChildStdout, Command};
17use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
18use tokio::task::JoinHandle;
19
20use super::{ExtensionHandler, ExtensionHealth, RestartPolicy};
21use crate::extensions::hooks::events::{HookEvent, HookResult};
22use crate::extensions::manifest::CURRENT_EXTENSION_PROTOCOL_VERSION;
23
/// Outbound JSON-RPC request frame sent to the extension process.
///
/// Serialized with serde, so the field order below is the key order on the wire.
#[derive(Serialize)]
struct JsonRpcRequest {
    /// JSON-RPC protocol marker (presumably always `"2.0"` — TODO confirm at call sites).
    jsonrpc: &'static str,
    /// Name of the method being invoked on the extension.
    method: String,
    /// Method parameters, passed through as arbitrary JSON.
    params: Value,
    /// Numeric request id; responses are matched back to the waiting caller
    /// through `Inbox::pending`, which is keyed by this `u64`.
    id: u64,
}
31
32
/// Parameters of the `initialize` request built by `ProcessExtension::initialize`.
#[derive(Serialize)]
struct InitializeParams {
    /// Host version string (filled from `CARGO_PKG_VERSION`).
    synaps_version: &'static str,
    /// Protocol version spoken by the host; the extension's `initialize`
    /// response must report the same value or it is rejected.
    extension_protocol_version: u32,
    /// Id of the extension being initialized.
    plugin_id: String,
    /// Plugin root directory as a lossily converted UTF-8 string; serialized
    /// as `null` when unknown.
    plugin_root: Option<String>,
    /// Plugin configuration, forwarded verbatim.
    config: Value,
}
41
/// A tool declared by an extension in its `initialize` response.
///
/// Checked by `ProcessExtension::validate_registered_tool_specs`. The name must be a
/// valid, unique id segment, the description non-blank, and `input_schema`
/// a JSON object.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct RegisteredExtensionToolSpec {
    /// Tool name as exposed to the model.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// JSON schema describing the tool's input object.
    pub input_schema: Value,
}
48
/// A model provider declared by an extension in its `initialize` response.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct RegisteredProviderSpec {
    /// Provider id; must be a valid id segment.
    pub id: String,
    /// Name shown to users; must be non-blank.
    pub display_name: String,
    /// Description shown to users; must be non-blank.
    pub description: String,
    /// Models offered by the provider (empty when omitted, which validation rejects).
    #[serde(default)]
    pub models: Vec<RegisteredProviderModelSpec>,
    /// Optional JSON schema for provider configuration; must be an object when present.
    #[serde(default)]
    pub config_schema: Option<Value>,
}
59
/// One model offered by an extension provider.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct RegisteredProviderModelSpec {
    /// Model id; must be a valid id segment, unique within its provider.
    pub id: String,
    /// Optional human-readable name.
    #[serde(default)]
    pub display_name: Option<String>,
    /// Free-form capability description; `null` when omitted.
    #[serde(default)]
    pub capabilities: Value,
    /// Context window size if the extension reports one (units presumably tokens — TODO confirm).
    #[serde(default)]
    pub context_window: Option<u64>,
}
70
/// Request payload for a provider completion call handled by an extension.
///
/// Cloned for every round of `complete_provider_with_tools`, which appends
/// assistant/tool-result messages to `messages` between rounds.
#[derive(Debug, Clone, Serialize)]
pub struct ProviderCompleteParams {
    /// Id of the extension provider being called.
    pub provider_id: String,
    /// Id of the selected model within that provider.
    pub model_id: String,
    /// Model identifier as sent to the extension (relationship to `model_id`
    /// is not visible here — TODO confirm).
    pub model: String,
    /// Conversation so far, as JSON message objects with `role` and `content`.
    pub messages: Vec<Value>,
    /// Optional system prompt.
    pub system_prompt: Option<String>,
    /// Tool definitions offered to the model, as JSON.
    pub tools: Vec<Value>,
    /// Optional sampling temperature.
    pub temperature: Option<f32>,
    /// Optional output token cap.
    pub max_tokens: Option<u32>,
    /// Thinking budget forwarded to the extension (semantics of 0 not visible here).
    pub thinking_budget: u32,
}
83
/// Result of one provider completion call.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ProviderCompleteResult {
    /// Content blocks; any with `"type": "tool_use"` trigger tool execution in
    /// `complete_provider_with_tools`.
    pub content: Vec<Value>,
    /// Optional stop reason reported by the extension.
    #[serde(default)]
    pub stop_reason: Option<String>,
    /// Optional usage information, passed through as JSON.
    #[serde(default)]
    pub usage: Option<Value>,
}
92
/// A validated tool call requested by a provider.
#[derive(Debug, Clone, PartialEq)]
pub struct ProviderToolUse {
    /// Provider-assigned call id, echoed back as `tool_use_id` in the result.
    pub id: String,
    /// Name of the tool to run.
    pub name: String,
    /// Tool input; always a JSON object after validation.
    pub input: Value,
}
99
/// One event of a streamed provider response, as produced by
/// `parse_provider_stream_event`.
#[derive(Debug, Clone, PartialEq)]
pub enum ProviderStreamEvent {
    /// Incremental assistant text.
    TextDelta { text: String },
    /// Incremental thinking text.
    ThinkingDelta { text: String },
    /// A complete tool call; `input` is always a JSON object.
    ToolUse {
        /// Provider-assigned call id.
        id: String,
        name: String,
        input: Value,
    },
    /// Usage payload: the event object without its `type` key.
    Usage { usage: Value },
    /// Provider-reported error with a non-empty message.
    Error { message: String },
    /// End of stream.
    Done,
}
120
121pub fn parse_provider_stream_event(params: &Value) -> Result<ProviderStreamEvent, String> {
127 let inner = match params.get("event") {
128 Some(ev) => ev,
129 None => params,
130 };
131 let obj = inner
132 .as_object()
133 .ok_or_else(|| "provider stream event must be a JSON object".to_string())?;
134
135 let ty = obj
136 .get("type")
137 .and_then(Value::as_str)
138 .ok_or_else(|| "provider stream event missing type".to_string())?;
139
140 match ty {
141 "text" => {
142 let text = obj
143 .get("delta")
144 .or_else(|| obj.get("text"))
145 .and_then(Value::as_str)
146 .ok_or_else(|| {
147 "provider stream text event missing 'delta' or 'text'".to_string()
148 })?;
149 Ok(ProviderStreamEvent::TextDelta {
150 text: text.to_string(),
151 })
152 }
153 "thinking" => {
154 let text = obj
155 .get("delta")
156 .or_else(|| obj.get("text"))
157 .and_then(Value::as_str)
158 .ok_or_else(|| {
159 "provider stream thinking event missing 'delta' or 'text'".to_string()
160 })?;
161 Ok(ProviderStreamEvent::ThinkingDelta {
162 text: text.to_string(),
163 })
164 }
165 "tool_use" => {
166 let id = obj
167 .get("id")
168 .and_then(Value::as_str)
169 .ok_or_else(|| "provider stream tool_use missing id".to_string())?;
170 if id.is_empty() {
171 return Err("provider stream tool_use id must be non-empty".to_string());
172 }
173 let name = obj
174 .get("name")
175 .and_then(Value::as_str)
176 .ok_or_else(|| "provider stream tool_use missing name".to_string())?;
177 if name.is_empty() {
178 return Err("provider stream tool_use name must be non-empty".to_string());
179 }
180 let input = match obj.get("input") {
181 None => Value::Object(Default::default()),
182 Some(v) if v.is_object() => v.clone(),
183 Some(_) => {
184 return Err(
185 "provider stream tool_use input must be a JSON object".to_string()
186 );
187 }
188 };
189 Ok(ProviderStreamEvent::ToolUse {
190 id: id.to_string(),
191 name: name.to_string(),
192 input,
193 })
194 }
195 "usage" => {
196 let mut clone = obj.clone();
197 clone.remove("type");
198 Ok(ProviderStreamEvent::Usage {
199 usage: Value::Object(clone),
200 })
201 }
202 "error" => {
203 let message = obj
204 .get("message")
205 .and_then(Value::as_str)
206 .ok_or_else(|| "provider stream error missing message".to_string())?;
207 if message.is_empty() {
208 return Err("provider stream error message must be non-empty".to_string());
209 }
210 Ok(ProviderStreamEvent::Error {
211 message: message.to_string(),
212 })
213 }
214 "done" => Ok(ProviderStreamEvent::Done),
215 other => Err(format!("unknown provider stream event type: {other}")),
216 }
217}
218
219pub async fn execute_provider_tool_use(
220 registry: &crate::ToolRegistry,
221 hook_bus: &Arc<crate::extensions::hooks::HookBus>,
222 tool_use: ProviderToolUse,
223 ctx: crate::ToolContext,
224 max_tool_output: usize,
225) -> Value {
226 let tool_id = tool_use.id;
227 let tool_name = tool_use.name;
228 let input = tool_use.input;
229
230 let Some(tool) = registry.get(&tool_name).cloned() else {
231 return serde_json::json!({
232 "type": "tool_result",
233 "tool_use_id": tool_id,
234 "content": format!("Unknown tool: {}", tool_name),
235 "is_error": true,
236 });
237 };
238
239 let runtime_name = registry.runtime_name_for_api(&tool_name).to_string();
240 let input = registry.translate_input_for_api_tool(&tool_name, input);
241 let decision = crate::runtime::resolve_before_tool_call_decision(
242 input.clone(),
243 crate::runtime::emit_before_tool_call(
244 hook_bus,
245 &tool_name,
246 Some(&runtime_name),
247 input.clone(),
248 ).await,
249 ctx.capabilities.secret_prompt.as_ref(),
250 false,
251 ).await;
252
253 let crate::runtime::BeforeToolCallDecision::Continue { input } = decision else {
254 let crate::runtime::BeforeToolCallDecision::Block { reason } = decision else { unreachable!() };
255 return serde_json::json!({
256 "type": "tool_result",
257 "tool_use_id": tool_id,
258 "content": format!("Tool call blocked by extension: {}", reason),
259 "is_error": true,
260 });
261 };
262
263 let input_for_hook = input.clone();
264 let (result, is_error) = match tool.execute(input, ctx).await {
265 Ok(output) => (output, false),
266 Err(error) => (format!("Tool execution failed: {}", error), true),
267 };
268 let _ = crate::runtime::emit_after_tool_call(
269 hook_bus,
270 &tool_name,
271 Some(&runtime_name),
272 input_for_hook,
273 result.clone(),
274 ).await;
275
276 let mut response = serde_json::json!({
277 "type": "tool_result",
278 "tool_use_id": tool_id,
279 "content": crate::truncate_str(&result, max_tool_output).to_string(),
280 });
281 if is_error {
282 response["is_error"] = serde_json::json!(true);
283 }
284 response
285}
286
287pub async fn complete_provider_with_tools<F>(
288 handler: Arc<dyn ExtensionHandler>,
289 mut params: ProviderCompleteParams,
290 registry: &crate::ToolRegistry,
291 hook_bus: &Arc<crate::extensions::hooks::HookBus>,
292 mut context_factory: F,
293 max_tool_output: usize,
294 max_iterations: usize,
295) -> Result<ProviderCompleteResult, String>
296where
297 F: FnMut() -> crate::ToolContext,
298{
299 let max_iterations = max_iterations.max(1);
300 for iteration in 0..max_iterations {
301 let result = handler.provider_complete(params.clone()).await?;
302 let tool_uses = extract_provider_tool_uses(&result.content)?;
303 if tool_uses.is_empty() {
304 return Ok(result);
305 }
306 if iteration + 1 == max_iterations {
307 return Err(format!(
308 "extension provider '{}' exceeded provider tool-use iteration limit ({})",
309 handler.id(),
310 max_iterations,
311 ));
312 }
313
314 let assistant_content = result.content.clone();
315 params.messages.push(serde_json::json!({
316 "role": "assistant",
317 "content": assistant_content,
318 }));
319
320 let mut tool_results = Vec::with_capacity(tool_uses.len());
321 for tool_use in tool_uses {
322 tool_results.push(execute_provider_tool_use(
323 registry,
324 hook_bus,
325 tool_use,
326 context_factory(),
327 max_tool_output,
328 ).await);
329 }
330 params.messages.push(serde_json::json!({
331 "role": "user",
332 "content": tool_results,
333 }));
334 }
335 Err(format!(
336 "extension provider '{}' exceeded provider tool-use iteration limit ({})",
337 handler.id(),
338 max_iterations,
339 ))
340}
341
342pub fn extract_provider_tool_uses(content: &[Value]) -> Result<Vec<ProviderToolUse>, String> {
343 let mut tool_uses = Vec::new();
344 for block in content {
345 if block.get("type").and_then(Value::as_str) != Some("tool_use") {
346 continue;
347 }
348 let id = block
349 .get("id")
350 .and_then(Value::as_str)
351 .ok_or_else(|| "provider tool_use missing id".to_string())?;
352 let name = block
353 .get("name")
354 .and_then(Value::as_str)
355 .ok_or_else(|| "provider tool_use missing name".to_string())?;
356 if id.trim().is_empty() {
357 return Err("provider tool_use id is empty".to_string());
358 }
359 if name.trim().is_empty() {
360 return Err("provider tool_use name is empty".to_string());
361 }
362 let input = block
363 .get("input")
364 .cloned()
365 .unwrap_or_else(|| serde_json::json!({}));
366 if !input.is_object() {
367 return Err(format!(
368 "provider tool_use '{}' input must be a JSON object",
369 id
370 ));
371 }
372 tool_uses.push(ProviderToolUse {
373 id: id.to_string(),
374 name: name.to_string(),
375 input,
376 });
377 }
378 Ok(tool_uses)
379}
380
/// Validated capabilities an extension declared during `initialize`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct InitializeCapabilitiesResult {
    /// Tools the extension registers (already validated).
    pub tools: Vec<RegisteredExtensionToolSpec>,
    /// Model providers the extension registers (already validated).
    pub providers: Vec<RegisteredProviderSpec>,
    /// Generic capability declarations, passed through unvalidated here;
    /// see `validate_capability`.
    pub capabilities: Vec<CapabilityDeclaration>,
}
392
/// A generic capability declared by an extension during `initialize`.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct CapabilityDeclaration {
    /// Capability kind; must be non-blank.
    pub kind: String,
    /// Capability name; must be non-blank.
    pub name: String,
    /// Permission names the capability requires; each must parse as a known
    /// permission and be granted (see `validate_capability`).
    #[serde(default)]
    pub permissions: Vec<String>,
    /// Kind-specific parameters; `null` when omitted and skipped on
    /// serialization when `null`.
    #[serde(default, skip_serializing_if = "is_null_value")]
    pub params: serde_json::Value,
}
422
423fn is_null_value(v: &serde_json::Value) -> bool {
424 v.is_null()
425}
426
427pub fn validate_capability(
437 decl: &CapabilityDeclaration,
438 granted: &crate::extensions::permissions::PermissionSet,
439) -> Result<(), String> {
440 use crate::extensions::permissions::Permission;
441 if decl.kind.trim().is_empty() {
442 return Err("capability 'kind' must be non-empty".to_string());
443 }
444 if decl.name.trim().is_empty() {
445 return Err("capability 'name' must be non-empty".to_string());
446 }
447 for perm_name in &decl.permissions {
448 let parsed = Permission::parse(perm_name).ok_or_else(|| {
449 format!(
450 "capability '{}' declares unknown permission '{}'",
451 decl.kind, perm_name
452 )
453 })?;
454 if !granted.has(parsed) {
455 return Err(format!(
456 "capability '{}' requires permission '{}' but it is not granted",
457 decl.kind, perm_name
458 ));
459 }
460 }
461 Ok(())
462}
463
/// Raw (unvalidated) `initialize` response from an extension.
#[derive(Deserialize)]
struct InitializeResult {
    /// Protocol version the extension speaks; must equal
    /// `CURRENT_EXTENSION_PROTOCOL_VERSION`.
    protocol_version: u32,
    /// Declared capabilities; empty when omitted.
    #[serde(default)]
    capabilities: InitializeCapabilities,
}
470
/// Raw capability lists from an `initialize` response; every list defaults to empty.
#[derive(Default, Deserialize)]
struct InitializeCapabilities {
    #[serde(default)]
    tools: Vec<RegisteredExtensionToolSpec>,
    #[serde(default)]
    providers: Vec<RegisteredProviderSpec>,
    /// Generic capability declarations.
    #[serde(default)]
    capabilities: Vec<CapabilityDeclaration>,
}
481
/// A JSON-RPC notification (a frame with `method` but no `id`) received from
/// an extension and forwarded to the inbox's notification sink.
#[doc(hidden)]
#[derive(Debug, Clone)]
pub struct NotificationFrame {
    /// Notification method name.
    pub method: String,
    /// Notification params; `null` when the frame had none.
    pub params: Value,
}
493
/// State shared between a `ProcessExtension` and its stdout reader task.
struct Inbox {
    /// Waiters for responses to our outbound requests, keyed by request id.
    /// Resolved by `dispatch_frame`; drained with an error by `fail_all_pending`.
    pending: Mutex<HashMap<u64, oneshot::Sender<Result<Value, String>>>>,
    /// Current notification subscriber, if any. Cleared when a send fails or
    /// when the transport ends.
    notification_sink: Mutex<Option<mpsc::UnboundedSender<NotificationFrame>>>,
    /// Set to `true` by `fail_all_pending` (transport closed or failed).
    closed: std::sync::atomic::AtomicBool,
    /// Permissions granted to the extension; `None` makes every
    /// permission check in `require_permission` fail.
    permissions: RwLock<Option<crate::extensions::permissions::PermissionSet>>,
    /// Extension stdin used to answer requests the extension sends to the
    /// host; set by `spawn_state`.
    inbound_stdin: Mutex<Option<Arc<Mutex<ChildStdin>>>>,
    /// Extension id; used for memory-namespace checks and config storage.
    extension_id: String,
}
516
517impl Inbox {
518 fn new(extension_id: String) -> Self {
519 Self {
520 pending: Mutex::new(HashMap::new()),
521 notification_sink: Mutex::new(None),
522 closed: std::sync::atomic::AtomicBool::new(false),
523 permissions: RwLock::new(None),
524 inbound_stdin: Mutex::new(None),
525 extension_id,
526 }
527 }
528
529 async fn fail_all_pending(&self, reason: &str) {
532 self.closed.store(true, std::sync::atomic::Ordering::Release);
533 let drained: Vec<_> = {
534 let mut pending = self.pending.lock().await;
535 pending.drain().collect()
536 };
537 for (_, tx) in drained {
538 let _ = tx.send(Err(reason.to_string()));
539 }
540 }
541}
542
/// Handles for one running extension process.
struct ProcessState {
    /// The child process (spawned with `kill_on_drop(true)`).
    child: Child,
    /// Child stdin; the same `Arc` is also stored in `Inbox::inbound_stdin`.
    stdin: Arc<Mutex<ChildStdin>>,
    /// Task reading and dispatching frames from the child's stdout.
    reader_handle: JoinHandle<()>,
}
548
/// An extension running as a child process, spoken to over
/// `Content-Length`-framed JSON-RPC on stdin/stdout.
pub struct ProcessExtension {
    /// Extension id.
    id: String,
    /// Launch command, arguments and working directory (kept presumably so
    /// the process can be respawned — TODO confirm).
    command: String,
    args: Vec<String>,
    cwd: Option<PathBuf>,
    /// Current process handles (`None` presumably while not running — TODO confirm).
    state: Arc<Mutex<Option<ProcessState>>>,
    /// Lock presumably serializing outbound calls (users not visible here).
    call_lock: Arc<Mutex<()>>,
    /// Next outbound request id; starts at 1.
    next_id: AtomicU64,
    /// Restart counter; starts at 0, read by `restart_count()`.
    restart_count: AtomicUsize,
    /// Restart policy; defaults to `RestartPolicy::default()`.
    pub(crate) restart_policy: RestartPolicy,
    /// State shared with the stdout reader task.
    inbox: Arc<Inbox>,
}
567
568impl ProcessExtension {
569 pub async fn spawn(id: &str, command: &str, args: &[String]) -> Result<Self, String> {
570 Self::spawn_with_cwd(id, command, args, None).await
571 }
572
573 pub async fn spawn_with_cwd(
578 id: &str,
579 command: &str,
580 args: &[String],
581 cwd: Option<PathBuf>,
582 ) -> Result<Self, String> {
583 let inbox = Arc::new(Inbox::new(id.to_string()));
584 let state = Self::spawn_state(id, command, args, cwd.as_ref(), inbox.clone()).await?;
585 Ok(Self {
586 id: id.to_string(),
587 command: command.to_string(),
588 args: args.to_vec(),
589 cwd,
590 state: Arc::new(Mutex::new(Some(state))),
591 call_lock: Arc::new(Mutex::new(())),
592 next_id: AtomicU64::new(1),
593 restart_count: AtomicUsize::new(0),
594 restart_policy: RestartPolicy::default(),
595 inbox,
596 })
597 }
598
    /// Builder-style setter: replaces the restart policy (which
    /// `spawn_with_cwd` initializes to `RestartPolicy::default()`) and
    /// returns the extension.
    pub fn with_restart_policy(mut self, policy: RestartPolicy) -> Self {
        self.restart_policy = policy;
        self
    }
604
605 async fn spawn_state(
606 id: &str,
607 command: &str,
608 args: &[String],
609 cwd: Option<&PathBuf>,
610 inbox: Arc<Inbox>,
611 ) -> Result<ProcessState, String> {
612 let mut cmd = Command::new(command);
613 cmd.args(args)
614 .stdin(Stdio::piped())
615 .stdout(Stdio::piped())
616 .stderr(Stdio::piped());
617 if let Some(cwd) = cwd {
618 cmd.current_dir(cwd);
619 }
620
621 cmd.env_clear();
627 for var in &["PATH", "HOME", "LANG", "TERM", "XDG_RUNTIME_DIR"] {
628 if let Ok(val) = std::env::var(var) {
629 cmd.env(var, val);
630 }
631 }
632
633 cmd.kill_on_drop(true);
634
635 let mut child = cmd
636 .spawn()
637 .map_err(|e| format!("Failed to spawn extension '{}': {}", id, e))?;
638
639 let stdin = child
640 .stdin
641 .take()
642 .ok_or_else(|| format!("No stdin for extension '{}'", id))?;
643 let stdout = child
644 .stdout
645 .take()
646 .ok_or_else(|| format!("No stdout for extension '{}'", id))?;
647 if let Some(stderr) = child.stderr.take() {
648 let extension_id = id.to_string();
649 tokio::spawn(async move {
650 let mut lines = BufReader::new(stderr).lines();
651 loop {
652 match lines.next_line().await {
653 Ok(Some(line)) => {
654 tracing::debug!(extension = %extension_id, stderr = %line);
655 }
656 Ok(None) => break,
657 Err(error) => {
658 tracing::debug!(
659 extension = %extension_id,
660 error = %error,
661 "Failed to read extension stderr",
662 );
663 break;
664 }
665 }
666 }
667 });
668 }
669
670 let reader_handle = Self::spawn_reader(stdout, inbox.clone(), id.to_string());
671
672 let stdin_arc = Arc::new(Mutex::new(stdin));
673 *inbox.inbound_stdin.lock().await = Some(stdin_arc.clone());
676
677 Ok(ProcessState {
678 child,
679 stdin: stdin_arc,
680 reader_handle,
681 })
682 }
683
684 fn spawn_reader(
689 stdout: ChildStdout,
690 inbox: Arc<Inbox>,
691 extension_id: String,
692 ) -> JoinHandle<()> {
693 tokio::spawn(async move {
694 let mut reader = BufReader::new(stdout);
695 loop {
696 match Self::read_one_frame(&mut reader, &extension_id).await {
697 Ok(Some(value)) => {
698 Self::dispatch_frame(value, &inbox, &extension_id).await;
699 }
700 Ok(None) => {
701 tracing::debug!(
702 extension = %extension_id,
703 "Extension stdout closed (EOF); failing pending requests",
704 );
705 inbox.fail_all_pending("transport closed: EOF").await;
706 inbox.notification_sink.lock().await.take();
708 return;
709 }
710 Err(error) => {
711 tracing::debug!(
712 extension = %extension_id,
713 error = %error,
714 "Extension transport read error",
715 );
716 inbox
717 .fail_all_pending(&format!("transport error: {}", error))
718 .await;
719 inbox.notification_sink.lock().await.take();
720 return;
721 }
722 }
723 }
724 })
725 }
726
727 async fn read_one_frame(
731 reader: &mut BufReader<ChildStdout>,
732 extension_id: &str,
733 ) -> Result<Option<Value>, String> {
734 let mut content_length: Option<usize> = None;
735 let mut saw_any_header = false;
736 loop {
737 let mut header_line = String::new();
738 let n = reader
739 .read_line(&mut header_line)
740 .await
741 .map_err(|e| format!("Read header error: {}", e))?;
742 if n == 0 {
743 if saw_any_header {
744 return Err("Unexpected EOF while reading response headers".into());
745 }
746 return Ok(None);
747 }
748 saw_any_header = true;
749 if header_line.len() > 1024 {
750 return Err(format!(
751 "Extension '{}' header line too long ({} bytes)",
752 extension_id,
753 header_line.len()
754 ));
755 }
756 let trimmed = header_line.trim();
757 if trimmed.is_empty() {
758 break;
759 }
760 if let Some((name, value)) = trimmed.split_once(':') {
761 if name.trim().eq_ignore_ascii_case("Content-Length") {
762 content_length = Some(value.trim().parse().map_err(|_| {
763 format!("Invalid Content-Length value: {:?}", value.trim())
764 })?);
765 }
766 }
767 }
768 let content_length = content_length.ok_or_else(|| {
769 format!(
770 "Extension '{}' frame missing Content-Length header",
771 extension_id
772 )
773 })?;
774 const MAX_RESPONSE_SIZE: usize = 4 * 1024 * 1024;
775 if content_length > MAX_RESPONSE_SIZE {
776 return Err(format!(
777 "Extension '{}' frame too large: {} bytes (max {})",
778 extension_id, content_length, MAX_RESPONSE_SIZE
779 ));
780 }
781 let mut buf = vec![0u8; content_length];
782 tokio::io::AsyncReadExt::read_exact(reader, &mut buf)
783 .await
784 .map_err(|e| format!("Read body error: {}", e))?;
785 let value: Value = serde_json::from_slice(&buf)
786 .map_err(|e| format!("Parse frame error: {}", e))?;
787 Ok(Some(value))
788 }
789
790 async fn dispatch_frame(value: Value, inbox: &Arc<Inbox>, extension_id: &str) {
796 let id_field = value.get("id");
797 let id_is_present = !matches!(id_field, None | Some(Value::Null));
798 let method_field = value.get("method").and_then(Value::as_str).map(str::to_string);
799
800 if id_is_present && method_field.is_some() {
801 let id = match id_field.and_then(Value::as_u64) {
804 Some(id) => id,
805 None => {
806 tracing::trace!(
807 extension = %extension_id,
808 frame = %value,
809 "Discarding inbound request with non-numeric id",
810 );
811 return;
812 }
813 };
814 let Some(method) = method_field else { return };
815 let params = value.get("params").cloned().unwrap_or(Value::Null);
816 let inbox = inbox.clone();
817 let extension_id = extension_id.to_string();
818 tokio::spawn(async move {
819 let outcome = Self::handle_inbound_request(&inbox, &method, params).await;
820 let payload = match outcome {
821 Ok(result) => serde_json::json!({
822 "jsonrpc": "2.0",
823 "id": id,
824 "result": result,
825 }),
826 Err((code, message)) => serde_json::json!({
827 "jsonrpc": "2.0",
828 "id": id,
829 "error": {"code": code, "message": message},
830 }),
831 };
832 let stdin_handle = inbox.inbound_stdin.lock().await.clone();
833 if let Some(stdin) = stdin_handle {
834 let body = match serde_json::to_string(&payload) {
835 Ok(s) => s,
836 Err(error) => {
837 tracing::warn!(
838 extension = %extension_id,
839 error = %error,
840 "Failed to serialize inbound response",
841 );
842 return;
843 }
844 };
845 let frame = format!("Content-Length: {}\r\n\r\n{}", body.len(), body);
846 let mut stdin = stdin.lock().await;
847 if let Err(error) = stdin.write_all(frame.as_bytes()).await {
848 tracing::warn!(
849 extension = %extension_id,
850 error = %error,
851 "Failed to write inbound response",
852 );
853 return;
854 }
855 if let Err(error) = stdin.flush().await {
856 tracing::warn!(
857 extension = %extension_id,
858 error = %error,
859 "Failed to flush inbound response",
860 );
861 }
862 } else {
863 tracing::warn!(
864 extension = %extension_id,
865 "No stdin available to reply to inbound request",
866 );
867 }
868 });
869 return;
870 }
871
872 if id_is_present {
873 let id = match id_field.and_then(Value::as_u64) {
874 Some(id) => id,
875 None => {
876 tracing::trace!(
877 extension = %extension_id,
878 frame = %value,
879 "Discarding frame with non-numeric id",
880 );
881 return;
882 }
883 };
884 let sender = inbox.pending.lock().await.remove(&id);
885 match sender {
886 Some(tx) => {
887 let payload = if let Some(err) = value.get("error") {
888 let message = err
889 .get("message")
890 .and_then(Value::as_str)
891 .unwrap_or("unknown extension error")
892 .to_string();
893 Err(format!("Extension error: {}", message))
894 } else {
895 Ok(value
896 .get("result")
897 .cloned()
898 .unwrap_or(Value::Null))
899 };
900 let _ = tx.send(payload);
901 }
902 None => {
903 tracing::trace!(
904 extension = %extension_id,
905 id = id,
906 "Response with unknown id (no pending request); dropping",
907 );
908 }
909 }
910 } else if let Some(method) = value.get("method").and_then(Value::as_str) {
911 let params = value.get("params").cloned().unwrap_or(Value::Null);
912 let frame = NotificationFrame {
913 method: method.to_string(),
914 params,
915 };
916 let mut sink_guard = inbox.notification_sink.lock().await;
917 if let Some(sink) = sink_guard.as_ref() {
918 if sink.send(frame).is_err() {
919 sink_guard.take();
921 }
922 } else {
923 tracing::trace!(
924 extension = %extension_id,
925 method = %method,
926 "Notification with no active subscriber; dropping",
927 );
928 }
929 } else {
930 tracing::trace!(
931 extension = %extension_id,
932 frame = %value,
933 "Unrecognized frame; dropping",
934 );
935 }
936 }
937
938 pub fn restart_count(&self) -> usize {
939 self.restart_count.load(Ordering::Relaxed)
940 }
941
942 pub async fn set_permissions(&self, perms: crate::extensions::permissions::PermissionSet) {
945 *self.inbox.permissions.write().await = Some(perms);
946 }
947
948 #[allow(clippy::doc_lazy_continuation)]
956 async fn handle_inbound_request(
957 inbox: &Arc<Inbox>,
958 method: &str,
959 params: Value,
960 ) -> Result<Value, (i32, String)> {
961 use crate::extensions::permissions::Permission;
962 use crate::memory::store::{self, MemoryQuery};
963
964 match method {
965 "memory.append" => {
966 Self::require_permission(inbox, Permission::MemoryWrite, "memory.write").await?;
967 let namespace = Self::param_str(¶ms, "namespace")?;
968 Self::require_namespace_matches(inbox, &namespace).await?;
969 let content = Self::param_str(¶ms, "content")?;
970 let tags = match params.get("tags") {
971 None | Some(Value::Null) => Vec::new(),
972 Some(Value::Array(arr)) => {
973 let mut out = Vec::with_capacity(arr.len());
974 for v in arr {
975 match v.as_str() {
976 Some(s) => out.push(s.to_string()),
977 None => {
978 return Err((
979 -32602,
980 "tags must be an array of strings".to_string(),
981 ))
982 }
983 }
984 }
985 out
986 }
987 _ => {
988 return Err((
989 -32602,
990 "tags must be an array of strings".to_string(),
991 ))
992 }
993 };
994 let meta = match params.get("meta") {
995 None | Some(Value::Null) => None,
996 Some(v) => Some(v.clone()),
997 };
998 let record = store::new_record(namespace, content, tags, meta);
999 let timestamp_ms = record.timestamp_ms;
1000 store::append(&record).map_err(|e| (-32000, e.to_string()))?;
1001 Ok(serde_json::json!({"ok": true, "timestamp_ms": timestamp_ms}))
1002 }
1003 "memory.query" => {
1004 Self::require_permission(inbox, Permission::MemoryRead, "memory.read").await?;
1005 let namespace = Self::param_str(¶ms, "namespace")?;
1006 Self::require_namespace_matches(inbox, &namespace).await?;
1007 let q = MemoryQuery {
1008 content_contains: params
1009 .get("content_contains")
1010 .and_then(Value::as_str)
1011 .map(str::to_string),
1012 tag_prefix: params
1013 .get("tag_prefix")
1014 .and_then(Value::as_str)
1015 .map(str::to_string),
1016 since_ms: params.get("since_ms").and_then(Value::as_u64),
1017 until_ms: params.get("until_ms").and_then(Value::as_u64),
1018 limit: params
1019 .get("limit")
1020 .and_then(Value::as_u64)
1021 .map(|n| n as usize),
1022 };
1023 let records = store::query(&namespace, &q).map_err(|e| (-32000, e.to_string()))?;
1024 Ok(serde_json::json!({"records": records}))
1025 }
1026 "config.get" => {
1027 let key = Self::param_str(¶ms, "key")?;
1028 Self::validate_config_key(&key)?;
1029 let value = crate::extensions::config_store::read_plugin_config(&inbox.extension_id, &key);
1030 Ok(serde_json::json!({"value": value}))
1031 }
1032 "config.set" => {
1033 Self::require_permission(inbox, Permission::ConfigWrite, "config.write").await?;
1034 let key = Self::param_str(¶ms, "key")?;
1035 Self::validate_config_key(&key)?;
1036 let value = Self::param_str(¶ms, "value")?;
1037 crate::extensions::config_store::write_plugin_config(&inbox.extension_id, &key, &value)
1038 .map_err(|e| (-32000, e.to_string()))?;
1039 Ok(serde_json::json!({"ok": true}))
1040 }
1041 "config.subscribe" => {
1042 Self::require_permission(inbox, Permission::ConfigSubscribe, "config.subscribe").await?;
1043 Ok(serde_json::json!({"ok": true}))
1047 }
1048 other => Err((-32601, format!("method not found: {other}"))),
1049 }
1050 }
1051
1052 async fn require_permission(
1053 inbox: &Arc<Inbox>,
1054 perm: crate::extensions::permissions::Permission,
1055 wire: &str,
1056 ) -> Result<(), (i32, String)> {
1057 let guard = inbox.permissions.read().await;
1058 match guard.as_ref() {
1059 Some(set) if set.has(perm) => Ok(()),
1060 _ => Err((
1061 -32602,
1062 format!("permission denied: {wire} required"),
1063 )),
1064 }
1065 }
1066
1067 async fn require_namespace_matches(
1068 inbox: &Arc<Inbox>,
1069 namespace: &str,
1070 ) -> Result<(), (i32, String)> {
1071 if namespace == inbox.extension_id {
1072 Ok(())
1073 } else {
1074 Err((
1075 -32602,
1076 format!(
1077 "namespace must equal extension id '{}' (got '{}')",
1078 inbox.extension_id, namespace
1079 ),
1080 ))
1081 }
1082 }
1083
1084 fn param_str(params: &Value, name: &str) -> Result<String, (i32, String)> {
1085 params
1086 .get(name)
1087 .and_then(Value::as_str)
1088 .map(str::to_string)
1089 .ok_or_else(|| (-32602, format!("missing or invalid '{name}' parameter")))
1090 }
1091
1092 fn validate_config_key(key: &str) -> Result<(), (i32, String)> {
1093 let trimmed = key.trim();
1094 if trimmed.is_empty() {
1095 return Err((-32602, "config key must be non-empty".to_string()));
1096 }
1097 if trimmed.contains('.') || trimmed.contains('/') || trimmed.contains(' ') {
1098 return Err((
1099 -32602,
1100 "config key must not contain dots, slashes, or spaces".to_string(),
1101 ));
1102 }
1103 Ok(())
1104 }
1105
1106 pub async fn initialize(&self, plugin_root: Option<PathBuf>, config: Value) -> Result<InitializeCapabilitiesResult, String> {
1107 let params = InitializeParams {
1108 synaps_version: env!("CARGO_PKG_VERSION"),
1109 extension_protocol_version: CURRENT_EXTENSION_PROTOCOL_VERSION,
1110 plugin_id: self.id.clone(),
1111 plugin_root: plugin_root
1112 .or_else(|| self.cwd.clone())
1113 .map(|path| path.to_string_lossy().to_string()),
1114 config,
1115 };
1116 let value = self.call_no_restart("initialize", serde_json::to_value(params).map_err(|e| e.to_string())?).await?;
1117 Self::parse_initialize_result(&self.id, value)
1118 }
1119
1120 fn parse_initialize_result(id: &str, value: Value) -> Result<InitializeCapabilitiesResult, String> {
1121 let result: InitializeResult = serde_json::from_value(value)
1122 .map_err(|e| format!("Invalid initialize response from extension '{}': {}", id, e))?;
1123 if result.protocol_version != CURRENT_EXTENSION_PROTOCOL_VERSION {
1124 return Err(format!(
1125 "Extension '{}' initialize returned unsupported protocol_version {} (supported: {})",
1126 id, result.protocol_version, CURRENT_EXTENSION_PROTOCOL_VERSION,
1127 ));
1128 }
1129 Self::validate_registered_tool_specs(id, &result.capabilities.tools)?;
1130 Self::validate_registered_provider_specs(id, &result.capabilities.providers)?;
1131 Ok(InitializeCapabilitiesResult {
1132 tools: result.capabilities.tools,
1133 providers: result.capabilities.providers,
1134 capabilities: result.capabilities.capabilities,
1135 })
1136 }
1137
1138 fn validate_registered_tool_specs(id: &str, tools: &[RegisteredExtensionToolSpec]) -> Result<(), String> {
1139 use crate::extensions::validation::{validate_id_segment, IdValidationError};
1140 let mut names = HashSet::new();
1141 for tool in tools {
1142 let name = tool.name.trim();
1143 if let Err(err) = validate_id_segment(name) {
1144 return Err(match err {
1145 IdValidationError::Empty => format!(
1146 "Extension '{}' registered a tool with an empty tool name",
1147 id
1148 ),
1149 IdValidationError::ContainsReserved { ch } => format!(
1150 "Extension '{}' registered tool '{}' with invalid tool name: '{}' is reserved",
1151 id, name, ch
1152 ),
1153 IdValidationError::TooLong { len, max } => format!(
1154 "Extension '{}' registered tool '{}' with invalid tool name: must be at most {} chars (got {})",
1155 id, name, max, len
1156 ),
1157 IdValidationError::ContainsWhitespace => format!(
1158 "Extension '{}' registered tool '{}' with invalid tool name: must not contain whitespace",
1159 id, name
1160 ),
1161 IdValidationError::ContainsControl { ch } => format!(
1162 "Extension '{}' registered tool '{}' with invalid tool name: contains control character U+{:04X}",
1163 id, name, ch as u32
1164 ),
1165 });
1166 }
1167 if !names.insert(name.to_string()) {
1168 return Err(format!("Extension '{}' registered duplicate tool name '{}'", id, name));
1169 }
1170 if tool.description.trim().is_empty() {
1171 return Err(format!(
1172 "Extension '{}' registered tool '{}' with an empty description",
1173 id, name,
1174 ));
1175 }
1176 if !tool.input_schema.is_object() {
1177 return Err(format!(
1178 "Extension '{}' registered tool '{}' with invalid input_schema: input_schema must be a JSON object",
1179 id, name,
1180 ));
1181 }
1182 }
1183 Ok(())
1184 }
1185
1186 fn validate_registered_provider_specs(id: &str, providers: &[RegisteredProviderSpec]) -> Result<(), String> {
1187 use crate::extensions::validation::{validate_id_segment, IdValidationError};
1188 for provider in providers {
1189 let provider_id = provider.id.trim();
1190 match validate_id_segment(provider_id) {
1191 Ok(()) => {
1192 if !Self::is_safe_provider_id(provider_id) {
1193 return Err(format!(
1194 "Extension '{}' registered provider '{}' with invalid provider id",
1195 id, provider_id
1196 ));
1197 }
1198 }
1199 Err(IdValidationError::Empty) => {
1200 return Err(format!(
1201 "Extension '{}' registered provider with empty provider id",
1202 id
1203 ));
1204 }
1205 Err(err) => {
1206 return Err(format!(
1207 "Extension '{}' registered provider '{}' with invalid provider id: {}",
1208 id, provider_id, err
1209 ));
1210 }
1211 }
1212 if provider.display_name.trim().is_empty() {
1213 return Err(format!(
1214 "Extension '{}' registered provider '{}' with empty display_name",
1215 id, provider_id,
1216 ));
1217 }
1218 if provider.description.trim().is_empty() {
1219 return Err(format!(
1220 "Extension '{}' registered provider '{}' with empty description",
1221 id, provider_id,
1222 ));
1223 }
1224 if provider.models.is_empty() {
1225 return Err(format!(
1226 "Extension '{}' registered provider '{}' must declare at least one model",
1227 id, provider_id,
1228 ));
1229 }
1230 let mut model_ids = HashSet::new();
1231 for model in &provider.models {
1232 let model_id = model.id.trim();
1233 if let Err(err) = validate_id_segment(model_id) {
1234 return Err(match err {
1235 IdValidationError::Empty => format!(
1236 "Extension '{}' registered provider '{}' with empty model id",
1237 id, provider_id
1238 ),
1239 IdValidationError::ContainsReserved { ch } => format!(
1240 "Extension '{}' registered provider '{}' with invalid model id '{}': '{}' is reserved",
1241 id, provider_id, model_id, ch
1242 ),
1243 IdValidationError::TooLong { len, max } => format!(
1244 "Extension '{}' registered provider '{}' with invalid model id '{}': must be at most {} chars (got {})",
1245 id, provider_id, model_id, max, len
1246 ),
1247 IdValidationError::ContainsWhitespace => format!(
1248 "Extension '{}' registered provider '{}' with invalid model id '{}': must not contain whitespace",
1249 id, provider_id, model_id
1250 ),
1251 IdValidationError::ContainsControl { ch } => format!(
1252 "Extension '{}' registered provider '{}' with invalid model id '{}': contains control character U+{:04X}",
1253 id, provider_id, model_id, ch as u32
1254 ),
1255 });
1256 }
1257 if !model_ids.insert(model_id.to_string()) {
1258 return Err(format!(
1259 "Extension '{}' registered provider '{}' with duplicate model id '{}'",
1260 id, provider_id, model_id,
1261 ));
1262 }
1263 }
1264 if let Some(config_schema) = &provider.config_schema {
1265 if !config_schema.is_object() {
1266 return Err(format!(
1267 "Extension '{}' registered provider '{}' with invalid config_schema: config_schema must be a JSON object",
1268 id, provider_id,
1269 ));
1270 }
1271 }
1272 }
1273 Ok(())
1274 }
1275
1276 fn is_safe_provider_id(id: &str) -> bool {
1277 !id.is_empty()
1278 && !id.contains(':')
1279 && id.chars().all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-' || c == '_')
1280 }
1281
1282 #[doc(hidden)]
1283 pub async fn initialize_for_test(&self, plugin_root: Option<PathBuf>) -> Result<(), String> {
1284 self.initialize(plugin_root, Value::Object(Default::default())).await.map(|_| ())
1285 }
1286
1287 async fn restart_locked(&self, state: &mut Option<ProcessState>) -> Result<(), String> {
1288 let attempted = self.restart_count.fetch_add(1, Ordering::Relaxed) + 1;
1289 let max_attempts = self.restart_policy.max_attempts;
1290 if attempted > max_attempts as usize {
1291 *state = None;
1292 return Err(format!(
1293 "Extension '{}' exceeded restart limit ({})",
1294 self.id, max_attempts,
1295 ));
1296 }
1297
1298 if let Some(old) = state.take() {
1299 old.reader_handle.abort();
1300 let mut child = old.child;
1301 let _ = child.kill().await;
1302 }
1303 self.inbox
1305 .fail_all_pending("transport closed: process restarting")
1306 .await;
1307
1308 let delay = self
1309 .restart_policy
1310 .delay_for_attempt(attempted as u32)
1311 .unwrap_or_default();
1312
1313 tracing::warn!(
1314 extension = %self.id,
1315 attempt = attempted,
1316 max_attempts = max_attempts,
1317 delay_ms = delay.as_millis() as u64,
1318 "Restarting extension process after transport failure",
1319 );
1320
1321 if !delay.is_zero() {
1322 tokio::time::sleep(delay).await;
1323 }
1324
1325 *state = Some(Self::spawn_state(
1326 &self.id,
1327 &self.command,
1328 &self.args,
1329 self.cwd.as_ref(),
1330 self.inbox.clone(),
1331 ).await?);
1332 self.inbox.closed.store(false, std::sync::atomic::Ordering::Release);
1334 self.initialize_locked(state).await?;
1335 self.restart_count.store(0, Ordering::Relaxed);
1338 Ok(())
1339 }
1340
1341
1342 async fn initialize_locked(&self, state: &mut Option<ProcessState>) -> Result<(), String> {
1343 let params = InitializeParams {
1344 synaps_version: env!("CARGO_PKG_VERSION"),
1345 extension_protocol_version: CURRENT_EXTENSION_PROTOCOL_VERSION,
1346 plugin_id: self.id.clone(),
1347 plugin_root: self.cwd
1348 .clone()
1349 .map(|path| path.to_string_lossy().to_string()),
1350 config: Value::Object(Default::default()),
1351 };
1352 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
1353 let value = tokio::time::timeout(
1354 std::time::Duration::from_secs(10),
1355 self.call_once_locked(
1356 state.as_mut().expect("state should exist for initialize"),
1357 "initialize",
1358 serde_json::to_value(params).map_err(|e| e.to_string())?,
1359 id,
1360 ),
1361 )
1362 .await
1363 .map_err(|_| format!("Extension '{}' initialize timed out after 10s", self.id))?
1364 ?;
1365 Self::parse_initialize_result(&self.id, value).map(|_| ())
1366 }
1367
    /// Sends a single JSON-RPC request with `id` over `state`'s stdin and waits for
    /// the matching reply delivered through `self.inbox.pending`.
    ///
    /// The request is framed as `Content-Length: <n>\r\n\r\n<json>`, where `<n>`
    /// is the body's byte length (`String::len`). No restart or retry happens here.
    ///
    /// # Errors
    /// - `"Serialize error: ..."` if the request cannot be serialized.
    /// - `"transport closed: ..."` if the inbox is closed before or during
    ///   registration, or if the reply sender is dropped without a value.
    /// - `"Write error: ..."` if writing or flushing stdin fails.
    /// - Otherwise, the `Err` payload delivered for this id (presumably by the
    ///   reader task; its format is not visible here).
    ///
    /// NOTE(review): if the future returned by this method is dropped while
    /// awaiting `rx` (e.g. by the timeout in `call`), the `id` entry is not
    /// removed from `pending` here. Confirm the reader or `fail_all_pending`
    /// cleans it up.
    async fn call_once_locked(
        &self,
        state: &mut ProcessState,
        method: &str,
        params: Value,
        id: u64,
    ) -> Result<Value, String> {
        let body = serde_json::to_string(&JsonRpcRequest {
            jsonrpc: "2.0",
            method: method.to_string(),
            params,
            id,
        })
        .map_err(|e| format!("Serialize error: {}", e))?;

        let (tx, rx) = oneshot::channel::<Result<Value, String>>();
        if self.inbox.closed.load(std::sync::atomic::Ordering::Acquire) {
            return Err("transport closed: inbox is shut down".to_string());
        }

        self.inbox.pending.lock().await.insert(id, tx);

        // Re-check after registering: the inbox may have been closed between the
        // first check and the insert. Without this, the entry could be left
        // waiting on a closed transport.
        if self.inbox.closed.load(std::sync::atomic::Ordering::Acquire) {
            self.inbox.pending.lock().await.remove(&id);
            return Err("transport closed: inbox shut down during registration".to_string());
        }

        let frame = format!("Content-Length: {}\r\n\r\n{}", body.len(), body);
        // Scope the stdin guard so it is released before awaiting the reply.
        let write_result = {
            let mut stdin = state.stdin.lock().await;
            match stdin.write_all(frame.as_bytes()).await {
                Ok(()) => stdin.flush().await,
                Err(e) => Err(e),
            }
        };
        if let Err(e) = write_result {
            // The request never reached the process, so drop its pending slot.
            self.inbox.pending.lock().await.remove(&id);
            return Err(format!("Write error: {}", e));
        }

        match rx.await {
            Ok(payload) => payload,
            Err(_) => {
                // The sender was dropped without a value; clear any leftover entry.
                self.inbox.pending.lock().await.remove(&id);
                Err("transport closed: response channel dropped".to_string())
            }
        }
    }
1430
1431 async fn call_no_restart(&self, method: &str, params: Value) -> Result<Value, String> {
1432 let _call_guard = self.call_lock.lock().await;
1433 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
1434 let mut state_guard = self.state.lock().await;
1435 if state_guard.is_none() {
1436 *state_guard = Some(Self::spawn_state(
1437 &self.id,
1438 &self.command,
1439 &self.args,
1440 self.cwd.as_ref(),
1441 self.inbox.clone(),
1442 ).await?);
1443 }
1444 self.call_once_locked(
1445 state_guard.as_mut().expect("state should exist"),
1446 method,
1447 params,
1448 id,
1449 ).await
1450 }
1451
1452 async fn call(&self, method: &str, params: Value) -> Result<Value, String> {
1453 let timeout_secs = if method == "tool.call" { 120 } else { 30 };
1454 let id_str = self.id.clone();
1455 let method_str = method.to_string();
1456
1457 let result = tokio::time::timeout(
1458 std::time::Duration::from_secs(timeout_secs),
1459 self.call_inner(method, params),
1460 )
1461 .await;
1462
1463 match result {
1464 Ok(inner) => inner,
1465 Err(_) => Err(format!(
1466 "Extension '{}' method '{}' timed out after {}s",
1467 id_str, method_str, timeout_secs
1468 )),
1469 }
1470 }
1471
1472 async fn call_inner(&self, method: &str, params: Value) -> Result<Value, String> {
1473 let _call_guard = self.call_lock.lock().await;
1474 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
1475 let mut state_guard = self.state.lock().await;
1476 if state_guard.is_none() {
1477 self.restart_locked(&mut state_guard).await?;
1478 }
1479
1480 let result = self
1481 .call_once_locked(
1482 state_guard.as_mut().expect("state should exist after restart"),
1483 method,
1484 params.clone(),
1485 id,
1486 )
1487 .await;
1488
1489 match result {
1490 Ok(value) => Ok(value),
1491 Err(first_error) => {
1492 self.restart_locked(&mut state_guard).await?;
1493 let retry_id = self.next_id.fetch_add(1, Ordering::Relaxed);
1494 self.call_once_locked(
1495 state_guard.as_mut().expect("state should exist after restart"),
1496 method,
1497 params,
1498 retry_id,
1499 )
1500 .await
1501 .map_err(|retry_error| {
1502 format!("{}; retry after restart failed: {}", first_error, retry_error)
1503 })
1504 }
1505 }
1506 }
1507
1508 #[doc(hidden)]
1521 pub async fn subscribe_notifications(&self) -> mpsc::UnboundedReceiver<NotificationFrame> {
1522 let (tx, rx) = mpsc::unbounded_channel();
1523 let mut sink = self.inbox.notification_sink.lock().await;
1524 *sink = Some(tx);
1525 rx
1526 }
1527
1528 #[doc(hidden)]
1530 pub async fn unsubscribe_notifications(&self) {
1531 self.inbox.notification_sink.lock().await.take();
1532 }
1533
1534 pub(crate) fn forward_invoke_command_frame(
1544 extension_id: &str,
1545 request_id: &str,
1546 sink: &mpsc::UnboundedSender<crate::extensions::runtime::InvokeCommandEvent>,
1547 sink_open: &mut bool,
1548 frame: NotificationFrame,
1549 ) -> bool {
1550 use crate::extensions::commands::parse_command_output;
1551 use crate::extensions::tasks::{is_task_method, parse_task_event};
1552 use crate::extensions::runtime::InvokeCommandEvent;
1553
1554 let mut saw_done = false;
1555 if frame.method == "command.output" {
1556 match parse_command_output(&frame.params) {
1557 Ok(parsed) if parsed.request_id == request_id => {
1558 if matches!(parsed.event, crate::extensions::commands::CommandOutputEvent::Done) {
1559 saw_done = true;
1560 }
1561 if *sink_open && sink.send(InvokeCommandEvent::Output(parsed.event)).is_err() {
1562 *sink_open = false;
1563 }
1564 }
1565 Ok(_) => {
1566 tracing::trace!(
1568 extension = %extension_id,
1569 "Ignoring command.output for unrelated request_id",
1570 );
1571 }
1572 Err(error) => {
1573 tracing::warn!(
1574 extension = %extension_id,
1575 error = %error,
1576 params = %frame.params,
1577 "Skipping malformed command.output notification",
1578 );
1579 }
1580 }
1581 } else if is_task_method(&frame.method) {
1582 match parse_task_event(&frame.method, &frame.params) {
1583 Ok(event) => {
1584 if *sink_open && sink.send(InvokeCommandEvent::Task(event)).is_err() {
1585 *sink_open = false;
1586 }
1587 }
1588 Err(error) => {
1589 tracing::warn!(
1590 extension = %extension_id,
1591 method = %frame.method,
1592 error = %error,
1593 params = %frame.params,
1594 "Skipping malformed task notification",
1595 );
1596 }
1597 }
1598 } else {
1599 tracing::trace!(
1600 extension = %extension_id,
1601 method = %frame.method,
1602 "Ignoring non-command/task notification during command.invoke",
1603 );
1604 }
1605 saw_done
1606 }
1607
1608 fn forward_provider_stream_frame(
1615 extension_id: &str,
1616 sink: &mpsc::UnboundedSender<ProviderStreamEvent>,
1617 sink_open: &mut bool,
1618 frame: NotificationFrame,
1619 ) {
1620 if frame.method != "provider.stream.event" {
1621 tracing::trace!(
1622 extension = %extension_id,
1623 method = %frame.method,
1624 "Ignoring non-stream notification during provider.stream",
1625 );
1626 return;
1627 }
1628 match parse_provider_stream_event(&frame.params) {
1629 Ok(event) => {
1630 if *sink_open && sink.send(event).is_err() {
1631 *sink_open = false;
1632 }
1633 }
1634 Err(error) => {
1635 tracing::warn!(
1636 extension = %extension_id,
1637 error = %error,
1638 params = %frame.params,
1639 "Skipping malformed provider.stream.event notification",
1640 );
1641 }
1642 }
1643 }
1644}
1645
1646#[async_trait]
1647impl ExtensionHandler for ProcessExtension {
1648 fn id(&self) -> &str {
1649 &self.id
1650 }
1651
1652 async fn call_tool(&self, name: &str, input: Value) -> Result<Value, String> {
1653 self.call("tool.call", serde_json::json!({
1654 "name": name,
1655 "input": input,
1656 })).await
1657 }
1658
1659 async fn provider_complete(&self, params: ProviderCompleteParams) -> Result<ProviderCompleteResult, String> {
1660 let value = tokio::time::timeout(
1661 std::time::Duration::from_secs(60),
1662 self.call("provider.complete", serde_json::to_value(params).map_err(|e| e.to_string())?),
1663 )
1664 .await
1665 .map_err(|_| format!("Extension '{}' provider.complete timed out", self.id))??;
1666 let result: ProviderCompleteResult = serde_json::from_value(value)
1667 .map_err(|e| format!("Invalid provider.complete response from extension '{}': {}", self.id, e))?;
1668 if result.content.is_empty() {
1669 return Err(format!("Extension '{}' provider.complete returned empty content", self.id));
1670 }
1671 Ok(result)
1672 }
1673
1674 async fn provider_stream(
1675 &self,
1676 params: ProviderCompleteParams,
1677 sink: tokio::sync::mpsc::UnboundedSender<ProviderStreamEvent>,
1678 ) -> Result<ProviderCompleteResult, String> {
1679 let mut rx = self.subscribe_notifications().await;
1682 let params_value =
1683 serde_json::to_value(params).map_err(|e| e.to_string())?;
1684
1685 let extension_id = self.id.clone();
1686 let stream_future = async {
1687 let mut call_fut = Box::pin(self.call("provider.stream", params_value));
1688 let mut sink_open = true;
1689 let response = loop {
1690 tokio::select! {
1691 response = &mut call_fut => break response,
1692 Some(frame) = rx.recv() => {
1693 Self::forward_provider_stream_frame(
1694 &extension_id, &sink, &mut sink_open, frame,
1695 );
1696 }
1697 }
1698 };
1699 self.unsubscribe_notifications().await;
1703 while let Some(frame) = rx.recv().await {
1704 Self::forward_provider_stream_frame(
1705 &extension_id, &sink, &mut sink_open, frame,
1706 );
1707 }
1708 response
1709 };
1710
1711 let outcome = tokio::time::timeout(
1712 std::time::Duration::from_secs(60),
1713 stream_future,
1714 )
1715 .await;
1716
1717 self.unsubscribe_notifications().await;
1719
1720 let value = outcome
1721 .map_err(|_| format!("Extension '{}' provider.stream timed out", self.id))??;
1722
1723 let result: ProviderCompleteResult = serde_json::from_value(value)
1724 .map_err(|e| {
1725 format!("Invalid provider.stream response from extension '{}': {}", self.id, e)
1726 })?;
1727 Ok(result)
1730 }
1731
1732 async fn invoke_command(
1733 &self,
1734 command: &str,
1735 args: Vec<String>,
1736 request_id: &str,
1737 sink: tokio::sync::mpsc::UnboundedSender<crate::extensions::runtime::InvokeCommandEvent>,
1738 ) -> Result<Value, String> {
1739 let mut rx = self.subscribe_notifications().await;
1741 let params = serde_json::json!({
1742 "command": command,
1743 "args": args,
1744 "request_id": request_id,
1745 });
1746
1747 let extension_id = self.id.clone();
1748 let request_id_owned = request_id.to_string();
1749 let invoke_future = async {
1750 let mut call_fut = Box::pin(self.call("command.invoke", params));
1751 let mut sink_open = true;
1752 let response = loop {
1753 tokio::select! {
1754 response = &mut call_fut => break response,
1755 Some(frame) = rx.recv() => {
1756 let _ = Self::forward_invoke_command_frame(
1757 &extension_id, &request_id_owned, &sink, &mut sink_open, frame,
1758 );
1759 }
1760 }
1761 };
1762 self.unsubscribe_notifications().await;
1766 while let Ok(frame) = rx.try_recv() {
1767 let _ = Self::forward_invoke_command_frame(
1768 &extension_id, &request_id_owned, &sink, &mut sink_open, frame,
1769 );
1770 }
1771 response
1772 };
1773
1774 let outcome = tokio::time::timeout(
1775 std::time::Duration::from_secs(120),
1776 invoke_future,
1777 )
1778 .await;
1779
1780 self.unsubscribe_notifications().await;
1782
1783 outcome
1784 .map_err(|_| format!("Extension '{}' command.invoke timed out", self.id))?
1785 }
1786
1787 async fn handle(&self, event: &HookEvent) -> HookResult {
1788 let params = serde_json::to_value(event).unwrap_or(Value::Null);
1789 match tokio::time::timeout(std::time::Duration::from_secs(5), self.call("hook.handle", params)).await {
1790 Ok(Ok(value)) => match serde_json::from_value(value.clone()) {
1791 Ok(result) => result,
1792 Err(error) => {
1793 tracing::warn!(
1794 extension = %self.id,
1795 error = %error,
1796 response = %value,
1797 "Extension hook handler returned invalid result",
1798 );
1799 if value.get("action").and_then(Value::as_str) == Some("modify") {
1800 HookResult::Block {
1801 reason: "Extension returned malformed modify result".to_string(),
1802 }
1803 } else {
1804 HookResult::Continue
1805 }
1806 }
1807 },
1808 Ok(Err(e)) => {
1809 tracing::warn!(
1810 extension = %self.id,
1811 error = %e,
1812 "Extension hook handler failed — continuing",
1813 );
1814 HookResult::Continue
1815 }
1816 Err(_) => {
1817 tracing::warn!(
1818 extension = %self.id,
1819 timeout_secs = 5,
1820 "Extension hook handler timed out — continuing",
1821 );
1822 HookResult::Continue
1823 }
1824 }
1825 }
1826
1827 async fn get_info(&self) -> Result<crate::extensions::info::PluginInfo, String> {
1828 let value = tokio::time::timeout(
1829 std::time::Duration::from_secs(5),
1830 self.call("info.get", Value::Null),
1831 )
1832 .await
1833 .map_err(|_| format!("Extension '{}' info.get timed out", self.id))??;
1834 serde_json::from_value(value)
1835 .map_err(|e| format!("Invalid info.get response from extension '{}': {}", self.id, e))
1836 }
1837
1838 async fn sidecar_spawn_args(
1839 &self,
1840 ) -> Result<crate::sidecar::spawn::SidecarSpawnArgs, String> {
1841 let value = tokio::time::timeout(
1842 std::time::Duration::from_secs(5),
1843 self.call("sidecar.spawn_args", Value::Null),
1844 )
1845 .await
1846 .map_err(|_| format!("Extension '{}' sidecar.spawn_args timed out", self.id))??;
1847 serde_json::from_value(value).map_err(|e| {
1848 format!(
1849 "Invalid sidecar.spawn_args response from extension '{}': {}",
1850 self.id, e
1851 )
1852 })
1853 }
1854
1855 async fn settings_editor_open(&self, category: &str, field: &str) -> Result<Value, String> {
1856 let params = crate::extensions::settings_editor::SettingsEditorOpenParams {
1857 category: category.to_string(),
1858 field: field.to_string(),
1859 };
1860 tokio::time::timeout(
1861 std::time::Duration::from_secs(5),
1862 self.call("settings.editor.open", serde_json::to_value(params).map_err(|e| e.to_string())?),
1863 )
1864 .await
1865 .map_err(|_| format!("Extension '{}' settings.editor.open timed out", self.id))?
1866 }
1867
1868 async fn settings_editor_key(&self, category: &str, field: &str, key: &str) -> Result<Value, String> {
1869 let mut params = serde_json::to_value(crate::extensions::settings_editor::SettingsEditorKeyParams {
1870 key: key.to_string(),
1871 }).map_err(|e| e.to_string())?;
1872 if let Some(obj) = params.as_object_mut() {
1873 obj.insert("category".to_string(), Value::String(category.to_string()));
1874 obj.insert("field".to_string(), Value::String(field.to_string()));
1875 }
1876 tokio::time::timeout(
1877 std::time::Duration::from_secs(5),
1878 self.call("settings.editor.key", params),
1879 )
1880 .await
1881 .map_err(|_| format!("Extension '{}' settings.editor.key timed out", self.id))?
1882 }
1883
1884 async fn settings_editor_commit(&self, category: &str, field: &str, value: Value) -> Result<Value, String> {
1885 let params = serde_json::json!({
1886 "category": category,
1887 "field": field,
1888 "value": value,
1889 });
1890 tokio::time::timeout(
1891 std::time::Duration::from_secs(5),
1892 self.call("settings.editor.commit", params),
1893 )
1894 .await
1895 .map_err(|_| format!("Extension '{}' settings.editor.commit timed out", self.id))?
1896 }
1897
1898 async fn shutdown(&self) {
1899 let _ = tokio::time::timeout(
1900 std::time::Duration::from_millis(500),
1901 self.call("shutdown", Value::Null),
1902 )
1903 .await;
1904
1905 tokio::time::sleep(std::time::Duration::from_millis(500)).await;
1906 let mut state_guard = self.state.lock().await;
1907 if let Some(state) = state_guard.take() {
1908 state.reader_handle.abort();
1909 let mut child = state.child;
1910 let _ = child.kill().await;
1911 }
1912 self.inbox.notification_sink.lock().await.take();
1914 self.inbox
1915 .fail_all_pending("transport closed: extension shutdown")
1916 .await;
1917 }
1918
1919 async fn restart_count(&self) -> usize {
1920 self.restart_count()
1921 }
1922
1923 async fn health(&self) -> ExtensionHealth {
1924 let count = self.restart_count.load(Ordering::Relaxed);
1925 let max = self.restart_policy.max_attempts as usize;
1926 if count >= max {
1927 ExtensionHealth::Failed
1928 } else if count > 0 {
1929 let state_alive = self.state.try_lock().map(|g| g.is_some()).unwrap_or(true);
1933 if state_alive {
1934 ExtensionHealth::Degraded
1935 } else {
1936 ExtensionHealth::Restarting
1937 }
1938 } else {
1939 ExtensionHealth::Running
1940 }
1941 }
1942}
1943
// Unit tests for `parse_provider_stream_event`. Covered:
// - every supported event `type` (text, thinking, tool_use, usage, error, done);
// - the nested `{"event": {...}}` envelope;
// - rejection of missing, empty or mistyped fields.
#[cfg(test)]
mod stream_event_tests {
    use super::*;
    use serde_json::json;

    // `text` events accept their payload under the `delta` key...
    #[test]
    fn parses_text_delta_with_delta_key() {
        let v = json!({"type": "text", "delta": "hi"});
        assert_eq!(
            parse_provider_stream_event(&v).unwrap(),
            ProviderStreamEvent::TextDelta { text: "hi".into() }
        );
    }

    // ...or under the `text` key.
    #[test]
    fn parses_text_delta_with_text_key() {
        let v = json!({"type": "text", "text": "hi"});
        assert_eq!(
            parse_provider_stream_event(&v).unwrap(),
            ProviderStreamEvent::TextDelta { text: "hi".into() }
        );
    }

    // `thinking` events accept both payload keys as well.
    #[test]
    fn parses_thinking_delta() {
        let v = json!({"type": "thinking", "delta": "hmm"});
        assert_eq!(
            parse_provider_stream_event(&v).unwrap(),
            ProviderStreamEvent::ThinkingDelta { text: "hmm".into() }
        );
        let v2 = json!({"type": "thinking", "text": "hmm"});
        assert_eq!(
            parse_provider_stream_event(&v2).unwrap(),
            ProviderStreamEvent::ThinkingDelta { text: "hmm".into() }
        );
    }

    #[test]
    fn parses_tool_use() {
        let v = json!({
            "type": "tool_use",
            "id": "t1",
            "name": "echo",
            "input": {"x": 1}
        });
        assert_eq!(
            parse_provider_stream_event(&v).unwrap(),
            ProviderStreamEvent::ToolUse {
                id: "t1".into(),
                name: "echo".into(),
                input: json!({"x": 1}),
            }
        );
    }

    // An omitted `input` becomes an empty JSON object rather than an error.
    #[test]
    fn tool_use_input_defaults_to_empty_object() {
        let v = json!({"type": "tool_use", "id": "t1", "name": "echo"});
        assert_eq!(
            parse_provider_stream_event(&v).unwrap(),
            ProviderStreamEvent::ToolUse {
                id: "t1".into(),
                name: "echo".into(),
                input: json!({}),
            }
        );
    }

    // Usage payloads are passed through minus the `type` discriminator.
    #[test]
    fn parses_usage_strips_type() {
        let v = json!({"type": "usage", "input_tokens": 5, "output_tokens": 7});
        assert_eq!(
            parse_provider_stream_event(&v).unwrap(),
            ProviderStreamEvent::Usage {
                usage: json!({"input_tokens": 5, "output_tokens": 7})
            }
        );
    }

    #[test]
    fn parses_error() {
        let v = json!({"type": "error", "message": "boom"});
        assert_eq!(
            parse_provider_stream_event(&v).unwrap(),
            ProviderStreamEvent::Error { message: "boom".into() }
        );
    }

    #[test]
    fn parses_done() {
        let v = json!({"type": "done"});
        assert_eq!(
            parse_provider_stream_event(&v).unwrap(),
            ProviderStreamEvent::Done
        );
    }

    // The `{"event": {...}}` envelope must parse to the same event as the flat form.
    #[test]
    fn nested_event_shape_matches_flat() {
        let flat = json!({"type": "text", "delta": "hi"});
        let nested = json!({"event": {"type": "text", "delta": "hi"}});
        assert_eq!(
            parse_provider_stream_event(&flat).unwrap(),
            parse_provider_stream_event(&nested).unwrap()
        );
    }

    // --- Rejection paths: error messages should name the offending field/value. ---

    #[test]
    fn missing_type_errors() {
        let v = json!({"delta": "hi"});
        let err = parse_provider_stream_event(&v).unwrap_err();
        assert!(err.contains("missing type"), "got: {err}");
    }

    #[test]
    fn unknown_type_errors_with_type() {
        let v = json!({"type": "wat"});
        let err = parse_provider_stream_event(&v).unwrap_err();
        assert!(err.contains("wat"), "got: {err}");
    }

    #[test]
    fn tool_use_missing_id_errors() {
        let v = json!({"type": "tool_use", "name": "echo"});
        let err = parse_provider_stream_event(&v).unwrap_err();
        assert!(err.contains("id"), "got: {err}");
    }

    #[test]
    fn tool_use_missing_name_errors() {
        let v = json!({"type": "tool_use", "id": "t1"});
        let err = parse_provider_stream_event(&v).unwrap_err();
        assert!(err.contains("name"), "got: {err}");
    }

    #[test]
    fn tool_use_empty_id_errors() {
        let v = json!({"type": "tool_use", "id": "", "name": "echo"});
        assert!(parse_provider_stream_event(&v).is_err());
    }

    #[test]
    fn tool_use_empty_name_errors() {
        let v = json!({"type": "tool_use", "id": "t1", "name": ""});
        assert!(parse_provider_stream_event(&v).is_err());
    }

    // A present-but-non-object `input` is rejected (contrast with the omitted case).
    #[test]
    fn tool_use_non_object_input_errors() {
        let v = json!({"type": "tool_use", "id": "t1", "name": "echo", "input": "nope"});
        let err = parse_provider_stream_event(&v).unwrap_err();
        assert!(err.contains("input"), "got: {err}");
    }

    #[test]
    fn text_missing_delta_and_text_errors() {
        let v = json!({"type": "text"});
        let err = parse_provider_stream_event(&v).unwrap_err();
        assert!(err.contains("delta") || err.contains("text"), "got: {err}");
    }

    #[test]
    fn error_missing_message_errors() {
        let v = json!({"type": "error"});
        assert!(parse_provider_stream_event(&v).is_err());
    }

    #[test]
    fn error_empty_message_errors() {
        let v = json!({"type": "error", "message": ""});
        assert!(parse_provider_stream_event(&v).is_err());
    }
}
2117
// Tests for the restart policy attached to a spawned `ProcessExtension`.
// These spawn `/bin/cat` as a stand-in process, so they assume a Unix-like host
// with that binary present.
#[cfg(test)]
mod restart_policy_tests {
    use super::*;

    // A freshly spawned extension uses a default policy allowing 3 restarts.
    #[tokio::test]
    async fn restart_policy_default_max_attempts_is_3() {
        let ext = ProcessExtension::spawn("policy-test", "/bin/cat", &[])
            .await
            .expect("spawn /bin/cat");
        assert_eq!(ext.restart_policy.max_attempts, 3);
        ext.shutdown().await;
    }

    // `with_restart_policy` replaces the policy on an already-spawned extension.
    #[tokio::test]
    async fn with_restart_policy_overrides_default() {
        let ext = ProcessExtension::spawn("policy-test-override", "/bin/cat", &[])
            .await
            .expect("spawn /bin/cat");
        let custom = RestartPolicy {
            max_attempts: 7,
            ..RestartPolicy::default()
        };
        let ext = ext.with_restart_policy(custom);
        assert_eq!(ext.restart_policy.max_attempts, 7);
        ext.shutdown().await;
    }
}
2149
// Tests for `validate_capability`. Covered:
// - rejection of blank `kind`/`name`;
// - rejection of unknown permission strings;
// - the requirement that every declared permission is granted;
// - that validation behaves identically for any `kind`.
#[cfg(test)]
mod capture_validator_tests {
    use super::*;
    use crate::extensions::permissions::{Permission, PermissionSet};

    // Builds a permission set containing exactly `grants`.
    fn perms_with(grants: &[Permission]) -> PermissionSet {
        let mut p = PermissionSet::new();
        for g in grants {
            p.grant(*g);
        }
        p
    }

    // Builds a capability declaration with null `params`.
    fn cap(kind: &str, name: &str, perms: &[&str]) -> CapabilityDeclaration {
        CapabilityDeclaration {
            kind: kind.to_string(),
            name: name.to_string(),
            permissions: perms.iter().map(|p| p.to_string()).collect(),
            params: serde_json::Value::Null,
        }
    }

    // Whitespace-only `kind` counts as empty.
    #[test]
    fn capability_validator_rejects_empty_kind() {
        let d = cap(" ", "Sample", &["audio.input"]);
        let perms = perms_with(&[Permission::AudioInput]);
        let err = validate_capability(&d, &perms).unwrap_err();
        assert!(err.contains("kind"), "got: {}", err);
    }

    // Whitespace-only `name` counts as empty.
    #[test]
    fn capability_validator_rejects_empty_name() {
        let d = cap("capture", " ", &["audio.input"]);
        let perms = perms_with(&[Permission::AudioInput]);
        let err = validate_capability(&d, &perms).unwrap_err();
        assert!(err.contains("name"), "got: {}", err);
    }

    #[test]
    fn capability_validator_rejects_unknown_permission_string() {
        let d = cap("capture", "Sample", &["audio.telepathy"]);
        let perms = perms_with(&[Permission::AudioInput, Permission::AudioOutput]);
        let err = validate_capability(&d, &perms).unwrap_err();
        assert!(
            err.contains("unknown permission") && err.contains("audio.telepathy"),
            "got: {}",
            err,
        );
    }

    #[test]
    fn capability_validator_requires_every_declared_permission() {
        let d = cap("capture", "Sample", &["audio.input"]);
        let perms = perms_with(&[]);
        let err = validate_capability(&d, &perms).unwrap_err();
        assert!(
            err.contains("audio.input") && err.contains("not granted"),
            "got: {}",
            err,
        );
    }

    #[test]
    fn capability_validator_accepts_when_all_permissions_granted() {
        let d = cap("capture", "Sample", &["audio.input", "audio.output"]);
        let perms = perms_with(&[Permission::AudioInput, Permission::AudioOutput]);
        validate_capability(&d, &perms).expect("should validate");
    }

    // A capability that declares no permissions needs no grants.
    #[test]
    fn capability_validator_accepts_no_permissions() {
        let d = cap("ocr", "Tesseract", &[]);
        let perms = perms_with(&[]);
        validate_capability(&d, &perms).expect("should validate");
    }

    // `kind` is free-form: the same permission check applies whatever its value.
    #[test]
    fn capability_validator_does_not_branch_on_kind() {
        let perms = perms_with(&[Permission::AudioInput]);
        for kind in ["capture", "ocr", "agent", "foot_pedal", "eeg"] {
            let d = cap(kind, "Anything", &["audio.input"]);
            validate_capability(&d, &perms).expect("should validate");
        }
    }

}
2242
// Tests for `ProcessExtension::forward_invoke_command_frame`, the router for
// notifications received during `command.invoke`. Covered:
// - ordering of mixed output and task events;
// - filtering by request id;
// - skipping malformed frames;
// - dropping unrelated methods.
#[cfg(test)]
mod invoke_command_dispatch_tests {
    use super::*;
    use crate::extensions::commands::CommandOutputEvent;
    use crate::extensions::runtime::InvokeCommandEvent;
    use crate::extensions::tasks::{TaskEvent, TaskKind};
    use serde_json::json;
    use tokio::sync::mpsc;

    // Builds a notification frame with the given method and params.
    fn frame(method: &str, params: serde_json::Value) -> NotificationFrame {
        NotificationFrame {
            method: method.to_string(),
            params,
        }
    }

    // Interleaved output and task frames must reach the sink in arrival order, and
    // the final `done` output must be reported via the boolean return value.
    #[test]
    fn forwards_mixed_event_stream_in_order() {
        let (tx, mut rx) = mpsc::unbounded_channel::<InvokeCommandEvent>();
        let mut open = true;
        let frames = vec![
            frame(
                "command.output",
                json!({"request_id":"r1","event":{"kind":"text","content":"A"}}),
            ),
            frame(
                "task.start",
                json!({"id":"dl","label":"Downloading","kind":"download"}),
            ),
            frame(
                "task.update",
                json!({"id":"dl","current":50,"total":100}),
            ),
            frame(
                "command.output",
                json!({"request_id":"r1","event":{"kind":"system","content":"working"}}),
            ),
            frame("task.done", json!({"id":"dl"})),
            frame(
                "command.output",
                json!({"request_id":"r1","event":{"kind":"done"}}),
            ),
        ];

        let mut saw_done = false;
        for f in frames {
            saw_done |= ProcessExtension::forward_invoke_command_frame(
                "ext-test", "r1", &tx, &mut open, f,
            );
        }
        drop(tx);
        assert!(saw_done, "should have observed the command Done marker");

        let mut events = Vec::new();
        while let Ok(ev) = rx.try_recv() {
            events.push(ev);
        }
        assert_eq!(events.len(), 6);
        assert_eq!(
            events[0],
            InvokeCommandEvent::Output(CommandOutputEvent::Text { content: "A".into() })
        );
        assert!(matches!(
            events[1],
            InvokeCommandEvent::Task(TaskEvent::Start { kind: TaskKind::Download, .. })
        ));
        assert!(matches!(
            events[2],
            InvokeCommandEvent::Task(TaskEvent::Update { .. })
        ));
        assert!(matches!(
            events[3],
            InvokeCommandEvent::Output(CommandOutputEvent::System { .. })
        ));
        assert!(matches!(
            events[4],
            InvokeCommandEvent::Task(TaskEvent::Done { error: None, .. })
        ));
        assert_eq!(events[5], InvokeCommandEvent::Output(CommandOutputEvent::Done));
    }

    // Output addressed to a different request id must not reach the sink.
    #[test]
    fn ignores_command_output_for_unrelated_request_id() {
        let (tx, mut rx) = mpsc::unbounded_channel::<InvokeCommandEvent>();
        let mut open = true;
        ProcessExtension::forward_invoke_command_frame(
            "ext",
            "r1",
            &tx,
            &mut open,
            frame(
                "command.output",
                json!({"request_id":"other","event":{"kind":"text","content":"x"}}),
            ),
        );
        drop(tx);
        assert!(rx.try_recv().is_err());
    }

    // A malformed frame (empty `event`) is skipped without blocking later frames.
    #[test]
    fn skips_malformed_command_output_without_aborting() {
        let (tx, mut rx) = mpsc::unbounded_channel::<InvokeCommandEvent>();
        let mut open = true;
        ProcessExtension::forward_invoke_command_frame(
            "ext",
            "r1",
            &tx,
            &mut open,
            frame("command.output", json!({"request_id":"r1","event":{}})),
        );
        ProcessExtension::forward_invoke_command_frame(
            "ext",
            "r1",
            &tx,
            &mut open,
            frame(
                "command.output",
                json!({"request_id":"r1","event":{"kind":"done"}}),
            ),
        );
        drop(tx);
        let ev = rx.try_recv().unwrap();
        assert_eq!(ev, InvokeCommandEvent::Output(CommandOutputEvent::Done));
        assert!(rx.try_recv().is_err());
    }

    // Task notifications carry no request id and are always forwarded.
    #[test]
    fn task_events_pass_through_regardless_of_request_id() {
        let (tx, mut rx) = mpsc::unbounded_channel::<InvokeCommandEvent>();
        let mut open = true;
        ProcessExtension::forward_invoke_command_frame(
            "ext",
            "r1",
            &tx,
            &mut open,
            frame("task.log", json!({"id":"abc","line":"..."})),
        );
        drop(tx);
        match rx.try_recv().unwrap() {
            InvokeCommandEvent::Task(TaskEvent::Log { id, line }) => {
                assert_eq!(id, "abc");
                assert_eq!(line, "...");
            }
            other => panic!("unexpected: {other:?}"),
        }
    }

    // Methods that are neither command output nor task events are dropped.
    #[test]
    fn unrelated_methods_are_dropped() {
        let (tx, mut rx) = mpsc::unbounded_channel::<InvokeCommandEvent>();
        let mut open = true;
        ProcessExtension::forward_invoke_command_frame(
            "ext",
            "r1",
            &tx,
            &mut open,
            frame("provider.stream.event", json!({"type":"text","delta":"x"})),
        );
        drop(tx);
        assert!(rx.try_recv().is_err());
    }
}