1use std::path::PathBuf;
7use std::process::Stdio;
8use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
9use std::sync::Arc;
10use std::collections::{HashMap, HashSet};
11
12use async_trait::async_trait;
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
16use tokio::process::{Child, ChildStdin, ChildStdout, Command};
17use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
18use tokio::task::JoinHandle;
19
20use super::{ExtensionHandler, ExtensionHealth, RestartPolicy};
21use crate::extensions::hooks::events::{HookEvent, HookResult};
22use crate::extensions::manifest::CURRENT_EXTENSION_PROTOCOL_VERSION;
23
/// A JSON-RPC request frame addressed to the extension process.
///
/// NOTE(review): not constructed in this section; presumably built by the
/// outgoing call path (e.g. `call_no_restart`) — confirm.
#[derive(Serialize)]
struct JsonRpcRequest {
    /// Protocol marker; presumably always `"2.0"` — confirm at the construction site.
    jsonrpc: &'static str,
    /// Name of the method to invoke on the extension.
    method: String,
    /// Method parameters.
    params: Value,
    /// Request id; presumably the key registered in `Inbox::pending` so the
    /// reader task can route the matching response back — confirm.
    id: u64,
}
31
32
/// Parameters of the `initialize` request sent by `ProcessExtension::initialize`.
#[derive(Serialize)]
struct InitializeParams {
    /// Host version, taken from `CARGO_PKG_VERSION` at compile time.
    synaps_version: &'static str,
    /// Protocol version the host speaks (`CURRENT_EXTENSION_PROTOCOL_VERSION`).
    extension_protocol_version: u32,
    /// The extension's id.
    plugin_id: String,
    /// Plugin root directory (falls back to the process cwd), lossily converted to UTF-8.
    plugin_root: Option<String>,
    /// Extension configuration passed through unchanged.
    config: Value,
}
41
/// A tool an extension registers in its `initialize` response.
///
/// Checked by `ProcessExtension::validate_registered_tool_specs`: the trimmed
/// name must be a valid id segment and unique, the description non-blank, and
/// `input_schema` a JSON object.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct RegisteredExtensionToolSpec {
    /// Tool name as exposed to the model.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// JSON Schema for the tool's input; must be an object.
    pub input_schema: Value,
}
48
/// A model provider an extension registers in its `initialize` response.
///
/// Checked by `ProcessExtension::validate_registered_provider_specs` (id,
/// non-blank display name and description, at least one uniquely-identified
/// model, object-typed `config_schema` when present).
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct RegisteredProviderSpec {
    /// Provider id; must be a valid id segment.
    pub id: String,
    /// Human-readable provider name.
    pub display_name: String,
    /// Human-readable provider description.
    pub description: String,
    /// Models served by this provider; defaults to empty when absent (which
    /// validation then rejects).
    #[serde(default)]
    pub models: Vec<RegisteredProviderModelSpec>,
    /// Optional JSON Schema for provider configuration; must be an object if set.
    #[serde(default)]
    pub config_schema: Option<Value>,
}
59
/// A single model offered by a [`RegisteredProviderSpec`].
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct RegisteredProviderModelSpec {
    /// Model id; must be a valid id segment and unique within its provider.
    pub id: String,
    /// Optional human-readable model name.
    #[serde(default)]
    pub display_name: Option<String>,
    /// Free-form capability description; `Value::Null` when absent.
    #[serde(default)]
    pub capabilities: Value,
    /// Optional context window size.
    /// NOTE(review): units presumably tokens — confirm.
    #[serde(default)]
    pub context_window: Option<u64>,
}
70
/// Request payload for a completion served by an extension-backed provider.
///
/// `complete_provider_with_tools` appends assistant/tool-result turns to
/// `messages` between rounds.
#[derive(Debug, Clone, Serialize)]
pub struct ProviderCompleteParams {
    /// Id of the provider handling the request.
    pub provider_id: String,
    /// Id of the model within the provider.
    pub model_id: String,
    /// NOTE(review): distinct from `model_id`; presumably the fully-qualified
    /// model name seen by the host — confirm with callers.
    pub model: String,
    /// Conversation messages (role/content objects).
    pub messages: Vec<Value>,
    /// Optional system prompt.
    pub system_prompt: Option<String>,
    /// Tool definitions offered to the model.
    pub tools: Vec<Value>,
    /// Optional sampling temperature.
    pub temperature: Option<f32>,
    /// Optional output token cap.
    pub max_tokens: Option<u32>,
    /// Thinking budget; semantics are defined by the provider.
    pub thinking_budget: u32,
}
83
/// Result of one provider completion round.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ProviderCompleteResult {
    /// Content blocks; `tool_use` blocks are picked out by
    /// `extract_provider_tool_uses`.
    pub content: Vec<Value>,
    /// Optional stop reason reported by the provider.
    #[serde(default)]
    pub stop_reason: Option<String>,
    /// Optional usage information, passed through unchanged.
    #[serde(default)]
    pub usage: Option<Value>,
}
92
/// A tool invocation requested by a provider, as produced by
/// `extract_provider_tool_uses` and consumed by `execute_provider_tool_use`.
#[derive(Debug, Clone, PartialEq)]
pub struct ProviderToolUse {
    /// Tool-use id echoed back as `tool_use_id` in the result block.
    pub id: String,
    /// API-facing tool name.
    pub name: String,
    /// Tool input; always a JSON object.
    pub input: Value,
}
99
/// One event of a provider's streamed response, produced by
/// `parse_provider_stream_event`.
#[derive(Debug, Clone, PartialEq)]
pub enum ProviderStreamEvent {
    /// Incremental assistant text (`type: "text"`).
    TextDelta { text: String },
    /// Incremental thinking text (`type: "thinking"`).
    ThinkingDelta { text: String },
    /// A complete tool invocation (`type: "tool_use"`).
    ToolUse {
        /// Tool-use id.
        id: String,
        /// Tool name.
        name: String,
        /// Tool input; always a JSON object.
        input: Value,
    },
    /// Usage report (`type: "usage"`): the event's members minus `type`.
    Usage { usage: Value },
    /// Provider-reported error (`type: "error"`).
    Error { message: String },
    /// End of stream (`type: "done"`).
    Done,
}
120
121pub fn parse_provider_stream_event(params: &Value) -> Result<ProviderStreamEvent, String> {
127 let inner = match params.get("event") {
128 Some(ev) => ev,
129 None => params,
130 };
131 let obj = inner
132 .as_object()
133 .ok_or_else(|| "provider stream event must be a JSON object".to_string())?;
134
135 let ty = obj
136 .get("type")
137 .and_then(Value::as_str)
138 .ok_or_else(|| "provider stream event missing type".to_string())?;
139
140 match ty {
141 "text" => {
142 let text = obj
143 .get("delta")
144 .or_else(|| obj.get("text"))
145 .and_then(Value::as_str)
146 .ok_or_else(|| {
147 "provider stream text event missing 'delta' or 'text'".to_string()
148 })?;
149 Ok(ProviderStreamEvent::TextDelta {
150 text: text.to_string(),
151 })
152 }
153 "thinking" => {
154 let text = obj
155 .get("delta")
156 .or_else(|| obj.get("text"))
157 .and_then(Value::as_str)
158 .ok_or_else(|| {
159 "provider stream thinking event missing 'delta' or 'text'".to_string()
160 })?;
161 Ok(ProviderStreamEvent::ThinkingDelta {
162 text: text.to_string(),
163 })
164 }
165 "tool_use" => {
166 let id = obj
167 .get("id")
168 .and_then(Value::as_str)
169 .ok_or_else(|| "provider stream tool_use missing id".to_string())?;
170 if id.is_empty() {
171 return Err("provider stream tool_use id must be non-empty".to_string());
172 }
173 let name = obj
174 .get("name")
175 .and_then(Value::as_str)
176 .ok_or_else(|| "provider stream tool_use missing name".to_string())?;
177 if name.is_empty() {
178 return Err("provider stream tool_use name must be non-empty".to_string());
179 }
180 let input = match obj.get("input") {
181 None => Value::Object(Default::default()),
182 Some(v) if v.is_object() => v.clone(),
183 Some(_) => {
184 return Err(
185 "provider stream tool_use input must be a JSON object".to_string()
186 );
187 }
188 };
189 Ok(ProviderStreamEvent::ToolUse {
190 id: id.to_string(),
191 name: name.to_string(),
192 input,
193 })
194 }
195 "usage" => {
196 let mut clone = obj.clone();
197 clone.remove("type");
198 Ok(ProviderStreamEvent::Usage {
199 usage: Value::Object(clone),
200 })
201 }
202 "error" => {
203 let message = obj
204 .get("message")
205 .and_then(Value::as_str)
206 .ok_or_else(|| "provider stream error missing message".to_string())?;
207 if message.is_empty() {
208 return Err("provider stream error message must be non-empty".to_string());
209 }
210 Ok(ProviderStreamEvent::Error {
211 message: message.to_string(),
212 })
213 }
214 "done" => Ok(ProviderStreamEvent::Done),
215 other => Err(format!("unknown provider stream event type: {other}")),
216 }
217}
218
219pub async fn execute_provider_tool_use(
220 registry: &crate::ToolRegistry,
221 hook_bus: &Arc<crate::extensions::hooks::HookBus>,
222 tool_use: ProviderToolUse,
223 ctx: crate::ToolContext,
224 max_tool_output: usize,
225) -> Value {
226 let tool_id = tool_use.id;
227 let tool_name = tool_use.name;
228 let input = tool_use.input;
229
230 let Some(tool) = registry.get(&tool_name).cloned() else {
231 return serde_json::json!({
232 "type": "tool_result",
233 "tool_use_id": tool_id,
234 "content": format!("Unknown tool: {}", tool_name),
235 "is_error": true,
236 });
237 };
238
239 let runtime_name = registry.runtime_name_for_api(&tool_name).to_string();
240 let input = registry.translate_input_for_api_tool(&tool_name, input);
241 let decision = crate::runtime::resolve_before_tool_call_decision(
242 input.clone(),
243 crate::runtime::emit_before_tool_call(
244 hook_bus,
245 &tool_name,
246 Some(&runtime_name),
247 input.clone(),
248 ).await,
249 ctx.capabilities.secret_prompt.as_ref(),
250 ).await;
251
252 let crate::runtime::BeforeToolCallDecision::Continue { input } = decision else {
253 let crate::runtime::BeforeToolCallDecision::Block { reason } = decision else { unreachable!() };
254 return serde_json::json!({
255 "type": "tool_result",
256 "tool_use_id": tool_id,
257 "content": format!("Tool call blocked by extension: {}", reason),
258 "is_error": true,
259 });
260 };
261
262 let input_for_hook = input.clone();
263 let (result, is_error) = match tool.execute(input, ctx).await {
264 Ok(output) => (output, false),
265 Err(error) => (format!("Tool execution failed: {}", error), true),
266 };
267 let _ = crate::runtime::emit_after_tool_call(
268 hook_bus,
269 &tool_name,
270 Some(&runtime_name),
271 input_for_hook,
272 result.clone(),
273 ).await;
274
275 let mut response = serde_json::json!({
276 "type": "tool_result",
277 "tool_use_id": tool_id,
278 "content": crate::truncate_str(&result, max_tool_output).to_string(),
279 });
280 if is_error {
281 response["is_error"] = serde_json::json!(true);
282 }
283 response
284}
285
286pub async fn complete_provider_with_tools<F>(
287 handler: Arc<dyn ExtensionHandler>,
288 mut params: ProviderCompleteParams,
289 registry: &crate::ToolRegistry,
290 hook_bus: &Arc<crate::extensions::hooks::HookBus>,
291 mut context_factory: F,
292 max_tool_output: usize,
293 max_iterations: usize,
294) -> Result<ProviderCompleteResult, String>
295where
296 F: FnMut() -> crate::ToolContext,
297{
298 let max_iterations = max_iterations.max(1);
299 for iteration in 0..max_iterations {
300 let result = handler.provider_complete(params.clone()).await?;
301 let tool_uses = extract_provider_tool_uses(&result.content)?;
302 if tool_uses.is_empty() {
303 return Ok(result);
304 }
305 if iteration + 1 == max_iterations {
306 return Err(format!(
307 "extension provider '{}' exceeded provider tool-use iteration limit ({})",
308 handler.id(),
309 max_iterations,
310 ));
311 }
312
313 let assistant_content = result.content.clone();
314 params.messages.push(serde_json::json!({
315 "role": "assistant",
316 "content": assistant_content,
317 }));
318
319 let mut tool_results = Vec::with_capacity(tool_uses.len());
320 for tool_use in tool_uses {
321 tool_results.push(execute_provider_tool_use(
322 registry,
323 hook_bus,
324 tool_use,
325 context_factory(),
326 max_tool_output,
327 ).await);
328 }
329 params.messages.push(serde_json::json!({
330 "role": "user",
331 "content": tool_results,
332 }));
333 }
334 Err(format!(
335 "extension provider '{}' exceeded provider tool-use iteration limit ({})",
336 handler.id(),
337 max_iterations,
338 ))
339}
340
341pub fn extract_provider_tool_uses(content: &[Value]) -> Result<Vec<ProviderToolUse>, String> {
342 let mut tool_uses = Vec::new();
343 for block in content {
344 if block.get("type").and_then(Value::as_str) != Some("tool_use") {
345 continue;
346 }
347 let id = block
348 .get("id")
349 .and_then(Value::as_str)
350 .ok_or_else(|| "provider tool_use missing id".to_string())?;
351 let name = block
352 .get("name")
353 .and_then(Value::as_str)
354 .ok_or_else(|| "provider tool_use missing name".to_string())?;
355 if id.trim().is_empty() {
356 return Err("provider tool_use id is empty".to_string());
357 }
358 if name.trim().is_empty() {
359 return Err("provider tool_use name is empty".to_string());
360 }
361 let input = block
362 .get("input")
363 .cloned()
364 .unwrap_or_else(|| serde_json::json!({}));
365 if !input.is_object() {
366 return Err(format!(
367 "provider tool_use '{}' input must be a JSON object",
368 id
369 ));
370 }
371 tool_uses.push(ProviderToolUse {
372 id: id.to_string(),
373 name: name.to_string(),
374 input,
375 });
376 }
377 Ok(tool_uses)
378}
379
/// Validated capabilities an extension reported from `initialize`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct InitializeCapabilitiesResult {
    /// Registered tools (already validated).
    pub tools: Vec<RegisteredExtensionToolSpec>,
    /// Registered providers (already validated).
    pub providers: Vec<RegisteredProviderSpec>,
    /// Generic capability declarations.
    /// NOTE(review): not validated in `parse_initialize_result`; presumably
    /// checked with `validate_capability` by the caller — confirm.
    pub capabilities: Vec<CapabilityDeclaration>,
}
391
/// A generic capability declared by an extension, checked by [`validate_capability`].
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct CapabilityDeclaration {
    /// Capability kind; must be non-blank.
    pub kind: String,
    /// Capability name; must be non-blank.
    pub name: String,
    /// Permission names the capability needs. Each must parse and be granted.
    #[serde(default)]
    pub permissions: Vec<String>,
    /// Free-form parameters. `Null` when absent, and omitted when serializing if null.
    #[serde(default, skip_serializing_if = "is_null_value")]
    pub params: serde_json::Value,
}
421
422fn is_null_value(v: &serde_json::Value) -> bool {
423 v.is_null()
424}
425
426pub fn validate_capability(
436 decl: &CapabilityDeclaration,
437 granted: &crate::extensions::permissions::PermissionSet,
438) -> Result<(), String> {
439 use crate::extensions::permissions::Permission;
440 if decl.kind.trim().is_empty() {
441 return Err("capability 'kind' must be non-empty".to_string());
442 }
443 if decl.name.trim().is_empty() {
444 return Err("capability 'name' must be non-empty".to_string());
445 }
446 for perm_name in &decl.permissions {
447 let parsed = Permission::parse(perm_name).ok_or_else(|| {
448 format!(
449 "capability '{}' declares unknown permission '{}'",
450 decl.kind, perm_name
451 )
452 })?;
453 if !granted.has(parsed) {
454 return Err(format!(
455 "capability '{}' requires permission '{}' but it is not granted",
456 decl.kind, perm_name
457 ));
458 }
459 }
460 Ok(())
461}
462
/// Raw `initialize` response body, before validation.
#[derive(Deserialize)]
struct InitializeResult {
    /// Protocol version the extension speaks. Must equal
    /// `CURRENT_EXTENSION_PROTOCOL_VERSION`.
    protocol_version: u32,
    /// Declared capabilities; empty when absent.
    #[serde(default)]
    capabilities: InitializeCapabilities,
}
469
/// Raw `capabilities` member of the `initialize` response. Every list defaults to empty.
#[derive(Default, Deserialize)]
struct InitializeCapabilities {
    /// Tools the extension registers.
    #[serde(default)]
    tools: Vec<RegisteredExtensionToolSpec>,
    /// Providers the extension registers.
    #[serde(default)]
    providers: Vec<RegisteredProviderSpec>,
    /// Generic capability declarations.
    #[serde(default)]
    capabilities: Vec<CapabilityDeclaration>,
}
480
/// A JSON-RPC notification (no `id`) received from the extension. It is
/// forwarded to the subscriber registered in `Inbox::notification_sink`.
#[doc(hidden)]
#[derive(Debug, Clone)]
pub struct NotificationFrame {
    /// Notification method name.
    pub method: String,
    /// Notification params; `Value::Null` when absent.
    pub params: Value,
}
492
/// State shared between a `ProcessExtension`, its stdout reader task, and the
/// tasks that answer extension-initiated requests.
struct Inbox {
    /// Waiters for responses to our outgoing requests, keyed by numeric id.
    /// Drained (each waiter receives `Err`) when the transport closes.
    pending: Mutex<HashMap<u64, oneshot::Sender<Result<Value, String>>>>,
    /// Sender used to forward extension notifications to a subscriber, if any.
    /// Cleared when a send fails or when the transport closes.
    notification_sink: Mutex<Option<mpsc::UnboundedSender<NotificationFrame>>>,
    /// Set by `fail_all_pending` once the transport has closed.
    /// NOTE(review): presumably checked by the request path before it registers
    /// a new waiter — confirm.
    closed: std::sync::atomic::AtomicBool,
    /// Permissions granted to the extension. `None` makes every permission check fail.
    permissions: RwLock<Option<crate::extensions::permissions::PermissionSet>>,
    /// Stdin of the current child, used to write replies to extension-initiated
    /// requests. Replaced each time a process is spawned.
    inbound_stdin: Mutex<Option<Arc<Mutex<ChildStdin>>>>,
    /// Extension id. It is also the only memory namespace the extension may use,
    /// and the id under which its plugin config is read and written.
    extension_id: String,
}
515
516impl Inbox {
517 fn new(extension_id: String) -> Self {
518 Self {
519 pending: Mutex::new(HashMap::new()),
520 notification_sink: Mutex::new(None),
521 closed: std::sync::atomic::AtomicBool::new(false),
522 permissions: RwLock::new(None),
523 inbound_stdin: Mutex::new(None),
524 extension_id,
525 }
526 }
527
528 async fn fail_all_pending(&self, reason: &str) {
531 self.closed.store(true, std::sync::atomic::Ordering::Release);
532 let drained: Vec<_> = {
533 let mut pending = self.pending.lock().await;
534 pending.drain().collect()
535 };
536 for (_, tx) in drained {
537 let _ = tx.send(Err(reason.to_string()));
538 }
539 }
540}
541
/// Handles for one running extension process, created by `spawn_state`.
struct ProcessState {
    /// The child process (spawned with `kill_on_drop(true)`).
    child: Child,
    /// The child's stdin. The same `Arc` is published to `Inbox::inbound_stdin`.
    stdin: Arc<Mutex<ChildStdin>>,
    /// Task reading framed messages from the child's stdout.
    reader_handle: JoinHandle<()>,
}
547
/// An extension hosted in a child process. It exchanges
/// `Content-Length`-framed JSON-RPC messages with the host over stdin/stdout.
pub struct ProcessExtension {
    /// Extension id.
    id: String,
    /// Program used to spawn the extension.
    command: String,
    /// Arguments passed to `command`.
    args: Vec<String>,
    /// Working directory for the child. Also the fallback `plugin_root` sent in `initialize`.
    cwd: Option<PathBuf>,
    /// The currently running process, if any.
    state: Arc<Mutex<Option<ProcessState>>>,
    /// NOTE(review): unused in this section; presumably serializes outgoing calls — confirm.
    call_lock: Arc<Mutex<()>>,
    /// Counter starting at 1; presumably the source of outgoing request ids — confirm.
    next_id: AtomicU64,
    /// Restart counter reported by `restart_count()`. Presumably incremented by
    /// the restart path outside this section.
    restart_count: AtomicUsize,
    /// Restart behaviour. `RestartPolicy::default()` unless set via `with_restart_policy`.
    pub(crate) restart_policy: RestartPolicy,
    /// State shared with the stdout reader and inbound-request tasks.
    inbox: Arc<Inbox>,
}
566
567impl ProcessExtension {
568 pub async fn spawn(id: &str, command: &str, args: &[String]) -> Result<Self, String> {
569 Self::spawn_with_cwd(id, command, args, None).await
570 }
571
572 pub async fn spawn_with_cwd(
577 id: &str,
578 command: &str,
579 args: &[String],
580 cwd: Option<PathBuf>,
581 ) -> Result<Self, String> {
582 let inbox = Arc::new(Inbox::new(id.to_string()));
583 let state = Self::spawn_state(id, command, args, cwd.as_ref(), inbox.clone()).await?;
584 Ok(Self {
585 id: id.to_string(),
586 command: command.to_string(),
587 args: args.to_vec(),
588 cwd,
589 state: Arc::new(Mutex::new(Some(state))),
590 call_lock: Arc::new(Mutex::new(())),
591 next_id: AtomicU64::new(1),
592 restart_count: AtomicUsize::new(0),
593 restart_policy: RestartPolicy::default(),
594 inbox,
595 })
596 }
597
598 pub fn with_restart_policy(mut self, policy: RestartPolicy) -> Self {
600 self.restart_policy = policy;
601 self
602 }
603
604 async fn spawn_state(
605 id: &str,
606 command: &str,
607 args: &[String],
608 cwd: Option<&PathBuf>,
609 inbox: Arc<Inbox>,
610 ) -> Result<ProcessState, String> {
611 let mut cmd = Command::new(command);
612 cmd.args(args)
613 .stdin(Stdio::piped())
614 .stdout(Stdio::piped())
615 .stderr(Stdio::piped());
616 if let Some(cwd) = cwd {
617 cmd.current_dir(cwd);
618 }
619
620 cmd.env_clear();
626 for var in &["PATH", "HOME", "LANG", "TERM", "XDG_RUNTIME_DIR"] {
627 if let Ok(val) = std::env::var(var) {
628 cmd.env(var, val);
629 }
630 }
631
632 cmd.kill_on_drop(true);
633
634 let mut child = cmd
635 .spawn()
636 .map_err(|e| format!("Failed to spawn extension '{}': {}", id, e))?;
637
638 let stdin = child
639 .stdin
640 .take()
641 .ok_or_else(|| format!("No stdin for extension '{}'", id))?;
642 let stdout = child
643 .stdout
644 .take()
645 .ok_or_else(|| format!("No stdout for extension '{}'", id))?;
646 if let Some(stderr) = child.stderr.take() {
647 let extension_id = id.to_string();
648 tokio::spawn(async move {
649 let mut lines = BufReader::new(stderr).lines();
650 loop {
651 match lines.next_line().await {
652 Ok(Some(line)) => {
653 tracing::debug!(extension = %extension_id, stderr = %line);
654 }
655 Ok(None) => break,
656 Err(error) => {
657 tracing::debug!(
658 extension = %extension_id,
659 error = %error,
660 "Failed to read extension stderr",
661 );
662 break;
663 }
664 }
665 }
666 });
667 }
668
669 let reader_handle = Self::spawn_reader(stdout, inbox.clone(), id.to_string());
670
671 let stdin_arc = Arc::new(Mutex::new(stdin));
672 *inbox.inbound_stdin.lock().await = Some(stdin_arc.clone());
675
676 Ok(ProcessState {
677 child,
678 stdin: stdin_arc,
679 reader_handle,
680 })
681 }
682
683 fn spawn_reader(
688 stdout: ChildStdout,
689 inbox: Arc<Inbox>,
690 extension_id: String,
691 ) -> JoinHandle<()> {
692 tokio::spawn(async move {
693 let mut reader = BufReader::new(stdout);
694 loop {
695 match Self::read_one_frame(&mut reader, &extension_id).await {
696 Ok(Some(value)) => {
697 Self::dispatch_frame(value, &inbox, &extension_id).await;
698 }
699 Ok(None) => {
700 tracing::debug!(
701 extension = %extension_id,
702 "Extension stdout closed (EOF); failing pending requests",
703 );
704 inbox.fail_all_pending("transport closed: EOF").await;
705 inbox.notification_sink.lock().await.take();
707 return;
708 }
709 Err(error) => {
710 tracing::debug!(
711 extension = %extension_id,
712 error = %error,
713 "Extension transport read error",
714 );
715 inbox
716 .fail_all_pending(&format!("transport error: {}", error))
717 .await;
718 inbox.notification_sink.lock().await.take();
719 return;
720 }
721 }
722 }
723 })
724 }
725
726 async fn read_one_frame(
730 reader: &mut BufReader<ChildStdout>,
731 extension_id: &str,
732 ) -> Result<Option<Value>, String> {
733 let mut content_length: Option<usize> = None;
734 let mut saw_any_header = false;
735 loop {
736 let mut header_line = String::new();
737 let n = reader
738 .read_line(&mut header_line)
739 .await
740 .map_err(|e| format!("Read header error: {}", e))?;
741 if n == 0 {
742 if saw_any_header {
743 return Err("Unexpected EOF while reading response headers".into());
744 }
745 return Ok(None);
746 }
747 saw_any_header = true;
748 if header_line.len() > 1024 {
749 return Err(format!(
750 "Extension '{}' header line too long ({} bytes)",
751 extension_id,
752 header_line.len()
753 ));
754 }
755 let trimmed = header_line.trim();
756 if trimmed.is_empty() {
757 break;
758 }
759 if let Some((name, value)) = trimmed.split_once(':') {
760 if name.trim().eq_ignore_ascii_case("Content-Length") {
761 content_length = Some(value.trim().parse().map_err(|_| {
762 format!("Invalid Content-Length value: {:?}", value.trim())
763 })?);
764 }
765 }
766 }
767 let content_length = content_length.ok_or_else(|| {
768 format!(
769 "Extension '{}' frame missing Content-Length header",
770 extension_id
771 )
772 })?;
773 const MAX_RESPONSE_SIZE: usize = 4 * 1024 * 1024;
774 if content_length > MAX_RESPONSE_SIZE {
775 return Err(format!(
776 "Extension '{}' frame too large: {} bytes (max {})",
777 extension_id, content_length, MAX_RESPONSE_SIZE
778 ));
779 }
780 let mut buf = vec![0u8; content_length];
781 tokio::io::AsyncReadExt::read_exact(reader, &mut buf)
782 .await
783 .map_err(|e| format!("Read body error: {}", e))?;
784 let value: Value = serde_json::from_slice(&buf)
785 .map_err(|e| format!("Parse frame error: {}", e))?;
786 Ok(Some(value))
787 }
788
789 async fn dispatch_frame(value: Value, inbox: &Arc<Inbox>, extension_id: &str) {
795 let id_field = value.get("id");
796 let id_is_present = !matches!(id_field, None | Some(Value::Null));
797 let method_field = value.get("method").and_then(Value::as_str).map(str::to_string);
798
799 if id_is_present && method_field.is_some() {
800 let id = match id_field.and_then(Value::as_u64) {
803 Some(id) => id,
804 None => {
805 tracing::trace!(
806 extension = %extension_id,
807 frame = %value,
808 "Discarding inbound request with non-numeric id",
809 );
810 return;
811 }
812 };
813 let method = method_field.unwrap();
814 let params = value.get("params").cloned().unwrap_or(Value::Null);
815 let inbox = inbox.clone();
816 let extension_id = extension_id.to_string();
817 tokio::spawn(async move {
818 let outcome = Self::handle_inbound_request(&inbox, &method, params).await;
819 let payload = match outcome {
820 Ok(result) => serde_json::json!({
821 "jsonrpc": "2.0",
822 "id": id,
823 "result": result,
824 }),
825 Err((code, message)) => serde_json::json!({
826 "jsonrpc": "2.0",
827 "id": id,
828 "error": {"code": code, "message": message},
829 }),
830 };
831 let stdin_handle = inbox.inbound_stdin.lock().await.clone();
832 if let Some(stdin) = stdin_handle {
833 let body = match serde_json::to_string(&payload) {
834 Ok(s) => s,
835 Err(error) => {
836 tracing::warn!(
837 extension = %extension_id,
838 error = %error,
839 "Failed to serialize inbound response",
840 );
841 return;
842 }
843 };
844 let frame = format!("Content-Length: {}\r\n\r\n{}", body.len(), body);
845 let mut stdin = stdin.lock().await;
846 if let Err(error) = stdin.write_all(frame.as_bytes()).await {
847 tracing::warn!(
848 extension = %extension_id,
849 error = %error,
850 "Failed to write inbound response",
851 );
852 return;
853 }
854 if let Err(error) = stdin.flush().await {
855 tracing::warn!(
856 extension = %extension_id,
857 error = %error,
858 "Failed to flush inbound response",
859 );
860 }
861 } else {
862 tracing::warn!(
863 extension = %extension_id,
864 "No stdin available to reply to inbound request",
865 );
866 }
867 });
868 return;
869 }
870
871 if id_is_present {
872 let id = match id_field.and_then(Value::as_u64) {
873 Some(id) => id,
874 None => {
875 tracing::trace!(
876 extension = %extension_id,
877 frame = %value,
878 "Discarding frame with non-numeric id",
879 );
880 return;
881 }
882 };
883 let sender = inbox.pending.lock().await.remove(&id);
884 match sender {
885 Some(tx) => {
886 let payload = if let Some(err) = value.get("error") {
887 let message = err
888 .get("message")
889 .and_then(Value::as_str)
890 .unwrap_or("unknown extension error")
891 .to_string();
892 Err(format!("Extension error: {}", message))
893 } else {
894 Ok(value
895 .get("result")
896 .cloned()
897 .unwrap_or(Value::Null))
898 };
899 let _ = tx.send(payload);
900 }
901 None => {
902 tracing::trace!(
903 extension = %extension_id,
904 id = id,
905 "Response with unknown id (no pending request); dropping",
906 );
907 }
908 }
909 } else if let Some(method) = value.get("method").and_then(Value::as_str) {
910 let params = value.get("params").cloned().unwrap_or(Value::Null);
911 let frame = NotificationFrame {
912 method: method.to_string(),
913 params,
914 };
915 let mut sink_guard = inbox.notification_sink.lock().await;
916 if let Some(sink) = sink_guard.as_ref() {
917 if sink.send(frame).is_err() {
918 sink_guard.take();
920 }
921 } else {
922 tracing::trace!(
923 extension = %extension_id,
924 method = %method,
925 "Notification with no active subscriber; dropping",
926 );
927 }
928 } else {
929 tracing::trace!(
930 extension = %extension_id,
931 frame = %value,
932 "Unrecognized frame; dropping",
933 );
934 }
935 }
936
937 pub fn restart_count(&self) -> usize {
938 self.restart_count.load(Ordering::Relaxed)
939 }
940
941 pub async fn set_permissions(&self, perms: crate::extensions::permissions::PermissionSet) {
944 *self.inbox.permissions.write().await = Some(perms);
945 }
946
947 async fn handle_inbound_request(
955 inbox: &Arc<Inbox>,
956 method: &str,
957 params: Value,
958 ) -> Result<Value, (i32, String)> {
959 use crate::extensions::permissions::Permission;
960 use crate::memory::store::{self, MemoryQuery};
961
962 match method {
963 "memory.append" => {
964 Self::require_permission(inbox, Permission::MemoryWrite, "memory.write").await?;
965 let namespace = Self::param_str(¶ms, "namespace")?;
966 Self::require_namespace_matches(inbox, &namespace).await?;
967 let content = Self::param_str(¶ms, "content")?;
968 let tags = match params.get("tags") {
969 None | Some(Value::Null) => Vec::new(),
970 Some(Value::Array(arr)) => {
971 let mut out = Vec::with_capacity(arr.len());
972 for v in arr {
973 match v.as_str() {
974 Some(s) => out.push(s.to_string()),
975 None => {
976 return Err((
977 -32602,
978 "tags must be an array of strings".to_string(),
979 ))
980 }
981 }
982 }
983 out
984 }
985 _ => {
986 return Err((
987 -32602,
988 "tags must be an array of strings".to_string(),
989 ))
990 }
991 };
992 let meta = match params.get("meta") {
993 None | Some(Value::Null) => None,
994 Some(v) => Some(v.clone()),
995 };
996 let record = store::new_record(namespace, content, tags, meta);
997 let timestamp_ms = record.timestamp_ms;
998 store::append(&record).map_err(|e| (-32000, e.to_string()))?;
999 Ok(serde_json::json!({"ok": true, "timestamp_ms": timestamp_ms}))
1000 }
1001 "memory.query" => {
1002 Self::require_permission(inbox, Permission::MemoryRead, "memory.read").await?;
1003 let namespace = Self::param_str(¶ms, "namespace")?;
1004 Self::require_namespace_matches(inbox, &namespace).await?;
1005 let q = MemoryQuery {
1006 content_contains: params
1007 .get("content_contains")
1008 .and_then(Value::as_str)
1009 .map(str::to_string),
1010 tag_prefix: params
1011 .get("tag_prefix")
1012 .and_then(Value::as_str)
1013 .map(str::to_string),
1014 since_ms: params.get("since_ms").and_then(Value::as_u64),
1015 until_ms: params.get("until_ms").and_then(Value::as_u64),
1016 limit: params
1017 .get("limit")
1018 .and_then(Value::as_u64)
1019 .map(|n| n as usize),
1020 };
1021 let records = store::query(&namespace, &q).map_err(|e| (-32000, e.to_string()))?;
1022 Ok(serde_json::json!({"records": records}))
1023 }
1024 "config.get" => {
1025 let key = Self::param_str(¶ms, "key")?;
1026 Self::validate_config_key(&key)?;
1027 let value = crate::extensions::config_store::read_plugin_config(&inbox.extension_id, &key);
1028 Ok(serde_json::json!({"value": value}))
1029 }
1030 "config.set" => {
1031 Self::require_permission(inbox, Permission::ConfigWrite, "config.write").await?;
1032 let key = Self::param_str(¶ms, "key")?;
1033 Self::validate_config_key(&key)?;
1034 let value = Self::param_str(¶ms, "value")?;
1035 crate::extensions::config_store::write_plugin_config(&inbox.extension_id, &key, &value)
1036 .map_err(|e| (-32000, e.to_string()))?;
1037 Ok(serde_json::json!({"ok": true}))
1038 }
1039 "config.subscribe" => {
1040 Self::require_permission(inbox, Permission::ConfigSubscribe, "config.subscribe").await?;
1041 Ok(serde_json::json!({"ok": true}))
1045 }
1046 other => Err((-32601, format!("method not found: {other}"))),
1047 }
1048 }
1049
1050 async fn require_permission(
1051 inbox: &Arc<Inbox>,
1052 perm: crate::extensions::permissions::Permission,
1053 wire: &str,
1054 ) -> Result<(), (i32, String)> {
1055 let guard = inbox.permissions.read().await;
1056 match guard.as_ref() {
1057 Some(set) if set.has(perm) => Ok(()),
1058 _ => Err((
1059 -32602,
1060 format!("permission denied: {wire} required"),
1061 )),
1062 }
1063 }
1064
1065 async fn require_namespace_matches(
1066 inbox: &Arc<Inbox>,
1067 namespace: &str,
1068 ) -> Result<(), (i32, String)> {
1069 if namespace == inbox.extension_id {
1070 Ok(())
1071 } else {
1072 Err((
1073 -32602,
1074 format!(
1075 "namespace must equal extension id '{}' (got '{}')",
1076 inbox.extension_id, namespace
1077 ),
1078 ))
1079 }
1080 }
1081
1082 fn param_str(params: &Value, name: &str) -> Result<String, (i32, String)> {
1083 params
1084 .get(name)
1085 .and_then(Value::as_str)
1086 .map(str::to_string)
1087 .ok_or_else(|| (-32602, format!("missing or invalid '{name}' parameter")))
1088 }
1089
1090 fn validate_config_key(key: &str) -> Result<(), (i32, String)> {
1091 let trimmed = key.trim();
1092 if trimmed.is_empty() {
1093 return Err((-32602, "config key must be non-empty".to_string()));
1094 }
1095 if trimmed.contains('.') || trimmed.contains('/') || trimmed.contains(' ') {
1096 return Err((
1097 -32602,
1098 "config key must not contain dots, slashes, or spaces".to_string(),
1099 ));
1100 }
1101 Ok(())
1102 }
1103
1104 pub async fn initialize(&self, plugin_root: Option<PathBuf>, config: Value) -> Result<InitializeCapabilitiesResult, String> {
1105 let params = InitializeParams {
1106 synaps_version: env!("CARGO_PKG_VERSION"),
1107 extension_protocol_version: CURRENT_EXTENSION_PROTOCOL_VERSION,
1108 plugin_id: self.id.clone(),
1109 plugin_root: plugin_root
1110 .or_else(|| self.cwd.clone())
1111 .map(|path| path.to_string_lossy().to_string()),
1112 config,
1113 };
1114 let value = self.call_no_restart("initialize", serde_json::to_value(params).map_err(|e| e.to_string())?).await?;
1115 Self::parse_initialize_result(&self.id, value)
1116 }
1117
1118 fn parse_initialize_result(id: &str, value: Value) -> Result<InitializeCapabilitiesResult, String> {
1119 let result: InitializeResult = serde_json::from_value(value)
1120 .map_err(|e| format!("Invalid initialize response from extension '{}': {}", id, e))?;
1121 if result.protocol_version != CURRENT_EXTENSION_PROTOCOL_VERSION {
1122 return Err(format!(
1123 "Extension '{}' initialize returned unsupported protocol_version {} (supported: {})",
1124 id, result.protocol_version, CURRENT_EXTENSION_PROTOCOL_VERSION,
1125 ));
1126 }
1127 Self::validate_registered_tool_specs(id, &result.capabilities.tools)?;
1128 Self::validate_registered_provider_specs(id, &result.capabilities.providers)?;
1129 Ok(InitializeCapabilitiesResult {
1130 tools: result.capabilities.tools,
1131 providers: result.capabilities.providers,
1132 capabilities: result.capabilities.capabilities,
1133 })
1134 }
1135
1136 fn validate_registered_tool_specs(id: &str, tools: &[RegisteredExtensionToolSpec]) -> Result<(), String> {
1137 use crate::extensions::validation::{validate_id_segment, IdValidationError};
1138 let mut names = HashSet::new();
1139 for tool in tools {
1140 let name = tool.name.trim();
1141 if let Err(err) = validate_id_segment(name) {
1142 return Err(match err {
1143 IdValidationError::Empty => format!(
1144 "Extension '{}' registered a tool with an empty tool name",
1145 id
1146 ),
1147 IdValidationError::ContainsReserved { ch } => format!(
1148 "Extension '{}' registered tool '{}' with invalid tool name: '{}' is reserved",
1149 id, name, ch
1150 ),
1151 IdValidationError::TooLong { len, max } => format!(
1152 "Extension '{}' registered tool '{}' with invalid tool name: must be at most {} chars (got {})",
1153 id, name, max, len
1154 ),
1155 IdValidationError::ContainsWhitespace => format!(
1156 "Extension '{}' registered tool '{}' with invalid tool name: must not contain whitespace",
1157 id, name
1158 ),
1159 IdValidationError::ContainsControl { ch } => format!(
1160 "Extension '{}' registered tool '{}' with invalid tool name: contains control character U+{:04X}",
1161 id, name, ch as u32
1162 ),
1163 });
1164 }
1165 if !names.insert(name.to_string()) {
1166 return Err(format!("Extension '{}' registered duplicate tool name '{}'", id, name));
1167 }
1168 if tool.description.trim().is_empty() {
1169 return Err(format!(
1170 "Extension '{}' registered tool '{}' with an empty description",
1171 id, name,
1172 ));
1173 }
1174 if !tool.input_schema.is_object() {
1175 return Err(format!(
1176 "Extension '{}' registered tool '{}' with invalid input_schema: input_schema must be a JSON object",
1177 id, name,
1178 ));
1179 }
1180 }
1181 Ok(())
1182 }
1183
1184 fn validate_registered_provider_specs(id: &str, providers: &[RegisteredProviderSpec]) -> Result<(), String> {
1185 use crate::extensions::validation::{validate_id_segment, IdValidationError};
1186 for provider in providers {
1187 let provider_id = provider.id.trim();
1188 match validate_id_segment(provider_id) {
1189 Ok(()) => {
1190 if !Self::is_safe_provider_id(provider_id) {
1191 return Err(format!(
1192 "Extension '{}' registered provider '{}' with invalid provider id",
1193 id, provider_id
1194 ));
1195 }
1196 }
1197 Err(IdValidationError::Empty) => {
1198 return Err(format!(
1199 "Extension '{}' registered provider with empty provider id",
1200 id
1201 ));
1202 }
1203 Err(err) => {
1204 return Err(format!(
1205 "Extension '{}' registered provider '{}' with invalid provider id: {}",
1206 id, provider_id, err
1207 ));
1208 }
1209 }
1210 if provider.display_name.trim().is_empty() {
1211 return Err(format!(
1212 "Extension '{}' registered provider '{}' with empty display_name",
1213 id, provider_id,
1214 ));
1215 }
1216 if provider.description.trim().is_empty() {
1217 return Err(format!(
1218 "Extension '{}' registered provider '{}' with empty description",
1219 id, provider_id,
1220 ));
1221 }
1222 if provider.models.is_empty() {
1223 return Err(format!(
1224 "Extension '{}' registered provider '{}' must declare at least one model",
1225 id, provider_id,
1226 ));
1227 }
1228 let mut model_ids = HashSet::new();
1229 for model in &provider.models {
1230 let model_id = model.id.trim();
1231 if let Err(err) = validate_id_segment(model_id) {
1232 return Err(match err {
1233 IdValidationError::Empty => format!(
1234 "Extension '{}' registered provider '{}' with empty model id",
1235 id, provider_id
1236 ),
1237 IdValidationError::ContainsReserved { ch } => format!(
1238 "Extension '{}' registered provider '{}' with invalid model id '{}': '{}' is reserved",
1239 id, provider_id, model_id, ch
1240 ),
1241 IdValidationError::TooLong { len, max } => format!(
1242 "Extension '{}' registered provider '{}' with invalid model id '{}': must be at most {} chars (got {})",
1243 id, provider_id, model_id, max, len
1244 ),
1245 IdValidationError::ContainsWhitespace => format!(
1246 "Extension '{}' registered provider '{}' with invalid model id '{}': must not contain whitespace",
1247 id, provider_id, model_id
1248 ),
1249 IdValidationError::ContainsControl { ch } => format!(
1250 "Extension '{}' registered provider '{}' with invalid model id '{}': contains control character U+{:04X}",
1251 id, provider_id, model_id, ch as u32
1252 ),
1253 });
1254 }
1255 if !model_ids.insert(model_id.to_string()) {
1256 return Err(format!(
1257 "Extension '{}' registered provider '{}' with duplicate model id '{}'",
1258 id, provider_id, model_id,
1259 ));
1260 }
1261 }
1262 if let Some(config_schema) = &provider.config_schema {
1263 if !config_schema.is_object() {
1264 return Err(format!(
1265 "Extension '{}' registered provider '{}' with invalid config_schema: config_schema must be a JSON object",
1266 id, provider_id,
1267 ));
1268 }
1269 }
1270 }
1271 Ok(())
1272 }
1273
1274 fn is_safe_provider_id(id: &str) -> bool {
1275 !id.is_empty()
1276 && !id.contains(':')
1277 && id.chars().all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-' || c == '_')
1278 }
1279
1280 #[doc(hidden)]
1281 pub async fn initialize_for_test(&self, plugin_root: Option<PathBuf>) -> Result<(), String> {
1282 self.initialize(plugin_root, Value::Object(Default::default())).await.map(|_| ())
1283 }
1284
1285 async fn restart_locked(&self, state: &mut Option<ProcessState>) -> Result<(), String> {
1286 let attempted = self.restart_count.fetch_add(1, Ordering::Relaxed) + 1;
1287 let max_attempts = self.restart_policy.max_attempts;
1288 if attempted > max_attempts as usize {
1289 *state = None;
1290 return Err(format!(
1291 "Extension '{}' exceeded restart limit ({})",
1292 self.id, max_attempts,
1293 ));
1294 }
1295
1296 if let Some(old) = state.take() {
1297 old.reader_handle.abort();
1298 let mut child = old.child;
1299 let _ = child.kill().await;
1300 }
1301 self.inbox
1303 .fail_all_pending("transport closed: process restarting")
1304 .await;
1305
1306 let delay = self
1307 .restart_policy
1308 .delay_for_attempt(attempted as u32)
1309 .unwrap_or_default();
1310
1311 tracing::warn!(
1312 extension = %self.id,
1313 attempt = attempted,
1314 max_attempts = max_attempts,
1315 delay_ms = delay.as_millis() as u64,
1316 "Restarting extension process after transport failure",
1317 );
1318
1319 if !delay.is_zero() {
1320 tokio::time::sleep(delay).await;
1321 }
1322
1323 *state = Some(Self::spawn_state(
1324 &self.id,
1325 &self.command,
1326 &self.args,
1327 self.cwd.as_ref(),
1328 self.inbox.clone(),
1329 ).await?);
1330 self.inbox.closed.store(false, std::sync::atomic::Ordering::Release);
1332 self.initialize_locked(state).await?;
1333 self.restart_count.store(0, Ordering::Relaxed);
1336 Ok(())
1337 }
1338
1339
1340 async fn initialize_locked(&self, state: &mut Option<ProcessState>) -> Result<(), String> {
1341 let params = InitializeParams {
1342 synaps_version: env!("CARGO_PKG_VERSION"),
1343 extension_protocol_version: CURRENT_EXTENSION_PROTOCOL_VERSION,
1344 plugin_id: self.id.clone(),
1345 plugin_root: self.cwd
1346 .clone()
1347 .map(|path| path.to_string_lossy().to_string()),
1348 config: Value::Object(Default::default()),
1349 };
1350 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
1351 let value = tokio::time::timeout(
1352 std::time::Duration::from_secs(10),
1353 self.call_once_locked(
1354 state.as_mut().expect("state should exist for initialize"),
1355 "initialize",
1356 serde_json::to_value(params).map_err(|e| e.to_string())?,
1357 id,
1358 ),
1359 )
1360 .await
1361 .map_err(|_| format!("Extension '{}' initialize timed out after 10s", self.id))?
1362 ?;
1363 Self::parse_initialize_result(&self.id, value).map(|_| ())
1364 }
1365
1366 async fn call_once_locked(
1370 &self,
1371 state: &mut ProcessState,
1372 method: &str,
1373 params: Value,
1374 id: u64,
1375 ) -> Result<Value, String> {
1376 let body = serde_json::to_string(&JsonRpcRequest {
1377 jsonrpc: "2.0",
1378 method: method.to_string(),
1379 params,
1380 id,
1381 })
1382 .map_err(|e| format!("Serialize error: {}", e))?;
1383
1384 let (tx, rx) = oneshot::channel::<Result<Value, String>>();
1385 if self.inbox.closed.load(std::sync::atomic::Ordering::Acquire) {
1388 return Err("transport closed: inbox is shut down".to_string());
1389 }
1390
1391 self.inbox.pending.lock().await.insert(id, tx);
1394
1395 if self.inbox.closed.load(std::sync::atomic::Ordering::Acquire) {
1398 self.inbox.pending.lock().await.remove(&id);
1399 return Err("transport closed: inbox shut down during registration".to_string());
1400 }
1401
1402 let frame = format!("Content-Length: {}\r\n\r\n{}", body.len(), body);
1403 let write_result = {
1404 let mut stdin = state.stdin.lock().await;
1405 match stdin.write_all(frame.as_bytes()).await {
1406 Ok(()) => stdin.flush().await,
1407 Err(e) => Err(e),
1408 }
1409 };
1410 if let Err(e) = write_result {
1411 self.inbox.pending.lock().await.remove(&id);
1413 return Err(format!("Write error: {}", e));
1414 }
1415
1416 match rx.await {
1417 Ok(payload) => payload,
1418 Err(_) => {
1419 self.inbox.pending.lock().await.remove(&id);
1424 Err("transport closed: response channel dropped".to_string())
1425 }
1426 }
1427 }
1428
1429 async fn call_no_restart(&self, method: &str, params: Value) -> Result<Value, String> {
1430 let _call_guard = self.call_lock.lock().await;
1431 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
1432 let mut state_guard = self.state.lock().await;
1433 if state_guard.is_none() {
1434 *state_guard = Some(Self::spawn_state(
1435 &self.id,
1436 &self.command,
1437 &self.args,
1438 self.cwd.as_ref(),
1439 self.inbox.clone(),
1440 ).await?);
1441 }
1442 self.call_once_locked(
1443 state_guard.as_mut().expect("state should exist"),
1444 method,
1445 params,
1446 id,
1447 ).await
1448 }
1449
1450 async fn call(&self, method: &str, params: Value) -> Result<Value, String> {
1451 let timeout_secs = if method == "tool.call" { 120 } else { 30 };
1452 let id_str = self.id.clone();
1453 let method_str = method.to_string();
1454
1455 let result = tokio::time::timeout(
1456 std::time::Duration::from_secs(timeout_secs),
1457 self.call_inner(method, params),
1458 )
1459 .await;
1460
1461 match result {
1462 Ok(inner) => inner,
1463 Err(_) => Err(format!(
1464 "Extension '{}' method '{}' timed out after {}s",
1465 id_str, method_str, timeout_secs
1466 )),
1467 }
1468 }
1469
1470 async fn call_inner(&self, method: &str, params: Value) -> Result<Value, String> {
1471 let _call_guard = self.call_lock.lock().await;
1472 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
1473 let mut state_guard = self.state.lock().await;
1474 if state_guard.is_none() {
1475 self.restart_locked(&mut state_guard).await?;
1476 }
1477
1478 let result = self
1479 .call_once_locked(
1480 state_guard.as_mut().expect("state should exist after restart"),
1481 method,
1482 params.clone(),
1483 id,
1484 )
1485 .await;
1486
1487 match result {
1488 Ok(value) => Ok(value),
1489 Err(first_error) => {
1490 self.restart_locked(&mut state_guard).await?;
1491 let retry_id = self.next_id.fetch_add(1, Ordering::Relaxed);
1492 self.call_once_locked(
1493 state_guard.as_mut().expect("state should exist after restart"),
1494 method,
1495 params,
1496 retry_id,
1497 )
1498 .await
1499 .map_err(|retry_error| {
1500 format!("{}; retry after restart failed: {}", first_error, retry_error)
1501 })
1502 }
1503 }
1504 }
1505
1506 #[doc(hidden)]
1519 pub async fn subscribe_notifications(&self) -> mpsc::UnboundedReceiver<NotificationFrame> {
1520 let (tx, rx) = mpsc::unbounded_channel();
1521 let mut sink = self.inbox.notification_sink.lock().await;
1522 *sink = Some(tx);
1523 rx
1524 }
1525
1526 #[doc(hidden)]
1528 pub async fn unsubscribe_notifications(&self) {
1529 self.inbox.notification_sink.lock().await.take();
1530 }
1531
1532 pub(crate) fn forward_invoke_command_frame(
1542 extension_id: &str,
1543 request_id: &str,
1544 sink: &mpsc::UnboundedSender<crate::extensions::runtime::InvokeCommandEvent>,
1545 sink_open: &mut bool,
1546 frame: NotificationFrame,
1547 ) -> bool {
1548 use crate::extensions::commands::parse_command_output;
1549 use crate::extensions::tasks::{is_task_method, parse_task_event};
1550 use crate::extensions::runtime::InvokeCommandEvent;
1551
1552 let mut saw_done = false;
1553 if frame.method == "command.output" {
1554 match parse_command_output(&frame.params) {
1555 Ok(parsed) if parsed.request_id == request_id => {
1556 if matches!(parsed.event, crate::extensions::commands::CommandOutputEvent::Done) {
1557 saw_done = true;
1558 }
1559 if *sink_open && sink.send(InvokeCommandEvent::Output(parsed.event)).is_err() {
1560 *sink_open = false;
1561 }
1562 }
1563 Ok(_) => {
1564 tracing::trace!(
1566 extension = %extension_id,
1567 "Ignoring command.output for unrelated request_id",
1568 );
1569 }
1570 Err(error) => {
1571 tracing::warn!(
1572 extension = %extension_id,
1573 error = %error,
1574 params = %frame.params,
1575 "Skipping malformed command.output notification",
1576 );
1577 }
1578 }
1579 } else if is_task_method(&frame.method) {
1580 match parse_task_event(&frame.method, &frame.params) {
1581 Ok(event) => {
1582 if *sink_open && sink.send(InvokeCommandEvent::Task(event)).is_err() {
1583 *sink_open = false;
1584 }
1585 }
1586 Err(error) => {
1587 tracing::warn!(
1588 extension = %extension_id,
1589 method = %frame.method,
1590 error = %error,
1591 params = %frame.params,
1592 "Skipping malformed task notification",
1593 );
1594 }
1595 }
1596 } else {
1597 tracing::trace!(
1598 extension = %extension_id,
1599 method = %frame.method,
1600 "Ignoring non-command/task notification during command.invoke",
1601 );
1602 }
1603 saw_done
1604 }
1605
1606 fn forward_provider_stream_frame(
1613 extension_id: &str,
1614 sink: &mpsc::UnboundedSender<ProviderStreamEvent>,
1615 sink_open: &mut bool,
1616 frame: NotificationFrame,
1617 ) {
1618 if frame.method != "provider.stream.event" {
1619 tracing::trace!(
1620 extension = %extension_id,
1621 method = %frame.method,
1622 "Ignoring non-stream notification during provider.stream",
1623 );
1624 return;
1625 }
1626 match parse_provider_stream_event(&frame.params) {
1627 Ok(event) => {
1628 if *sink_open && sink.send(event).is_err() {
1629 *sink_open = false;
1630 }
1631 }
1632 Err(error) => {
1633 tracing::warn!(
1634 extension = %extension_id,
1635 error = %error,
1636 params = %frame.params,
1637 "Skipping malformed provider.stream.event notification",
1638 );
1639 }
1640 }
1641 }
1642}
1643
1644#[async_trait]
1645impl ExtensionHandler for ProcessExtension {
1646 fn id(&self) -> &str {
1647 &self.id
1648 }
1649
1650 async fn call_tool(&self, name: &str, input: Value) -> Result<Value, String> {
1651 self.call("tool.call", serde_json::json!({
1652 "name": name,
1653 "input": input,
1654 })).await
1655 }
1656
1657 async fn provider_complete(&self, params: ProviderCompleteParams) -> Result<ProviderCompleteResult, String> {
1658 let value = tokio::time::timeout(
1659 std::time::Duration::from_secs(60),
1660 self.call("provider.complete", serde_json::to_value(params).map_err(|e| e.to_string())?),
1661 )
1662 .await
1663 .map_err(|_| format!("Extension '{}' provider.complete timed out", self.id))??;
1664 let result: ProviderCompleteResult = serde_json::from_value(value)
1665 .map_err(|e| format!("Invalid provider.complete response from extension '{}': {}", self.id, e))?;
1666 if result.content.is_empty() {
1667 return Err(format!("Extension '{}' provider.complete returned empty content", self.id));
1668 }
1669 Ok(result)
1670 }
1671
1672 async fn provider_stream(
1673 &self,
1674 params: ProviderCompleteParams,
1675 sink: tokio::sync::mpsc::UnboundedSender<ProviderStreamEvent>,
1676 ) -> Result<ProviderCompleteResult, String> {
1677 let mut rx = self.subscribe_notifications().await;
1680 let params_value =
1681 serde_json::to_value(params).map_err(|e| e.to_string())?;
1682
1683 let extension_id = self.id.clone();
1684 let stream_future = async {
1685 let mut call_fut = Box::pin(self.call("provider.stream", params_value));
1686 let mut sink_open = true;
1687 let response = loop {
1688 tokio::select! {
1689 response = &mut call_fut => break response,
1690 Some(frame) = rx.recv() => {
1691 Self::forward_provider_stream_frame(
1692 &extension_id, &sink, &mut sink_open, frame,
1693 );
1694 }
1695 }
1696 };
1697 self.unsubscribe_notifications().await;
1701 while let Some(frame) = rx.recv().await {
1702 Self::forward_provider_stream_frame(
1703 &extension_id, &sink, &mut sink_open, frame,
1704 );
1705 }
1706 response
1707 };
1708
1709 let outcome = tokio::time::timeout(
1710 std::time::Duration::from_secs(60),
1711 stream_future,
1712 )
1713 .await;
1714
1715 self.unsubscribe_notifications().await;
1717
1718 let value = outcome
1719 .map_err(|_| format!("Extension '{}' provider.stream timed out", self.id))??;
1720
1721 let result: ProviderCompleteResult = serde_json::from_value(value)
1722 .map_err(|e| {
1723 format!("Invalid provider.stream response from extension '{}': {}", self.id, e)
1724 })?;
1725 Ok(result)
1728 }
1729
1730 async fn invoke_command(
1731 &self,
1732 command: &str,
1733 args: Vec<String>,
1734 request_id: &str,
1735 sink: tokio::sync::mpsc::UnboundedSender<crate::extensions::runtime::InvokeCommandEvent>,
1736 ) -> Result<Value, String> {
1737 let mut rx = self.subscribe_notifications().await;
1739 let params = serde_json::json!({
1740 "command": command,
1741 "args": args,
1742 "request_id": request_id,
1743 });
1744
1745 let extension_id = self.id.clone();
1746 let request_id_owned = request_id.to_string();
1747 let invoke_future = async {
1748 let mut call_fut = Box::pin(self.call("command.invoke", params));
1749 let mut sink_open = true;
1750 let response = loop {
1751 tokio::select! {
1752 response = &mut call_fut => break response,
1753 Some(frame) = rx.recv() => {
1754 let _ = Self::forward_invoke_command_frame(
1755 &extension_id, &request_id_owned, &sink, &mut sink_open, frame,
1756 );
1757 }
1758 }
1759 };
1760 self.unsubscribe_notifications().await;
1764 while let Ok(frame) = rx.try_recv() {
1765 let _ = Self::forward_invoke_command_frame(
1766 &extension_id, &request_id_owned, &sink, &mut sink_open, frame,
1767 );
1768 }
1769 response
1770 };
1771
1772 let outcome = tokio::time::timeout(
1773 std::time::Duration::from_secs(120),
1774 invoke_future,
1775 )
1776 .await;
1777
1778 self.unsubscribe_notifications().await;
1780
1781 outcome
1782 .map_err(|_| format!("Extension '{}' command.invoke timed out", self.id))?
1783 }
1784
1785 async fn handle(&self, event: &HookEvent) -> HookResult {
1786 let params = serde_json::to_value(event).unwrap_or(Value::Null);
1787 match tokio::time::timeout(std::time::Duration::from_secs(5), self.call("hook.handle", params)).await {
1788 Ok(Ok(value)) => match serde_json::from_value(value.clone()) {
1789 Ok(result) => result,
1790 Err(error) => {
1791 tracing::warn!(
1792 extension = %self.id,
1793 error = %error,
1794 response = %value,
1795 "Extension hook handler returned invalid result",
1796 );
1797 if value.get("action").and_then(Value::as_str) == Some("modify") {
1798 HookResult::Block {
1799 reason: "Extension returned malformed modify result".to_string(),
1800 }
1801 } else {
1802 HookResult::Continue
1803 }
1804 }
1805 },
1806 Ok(Err(e)) => {
1807 tracing::warn!(
1808 extension = %self.id,
1809 error = %e,
1810 "Extension hook handler failed — continuing",
1811 );
1812 HookResult::Continue
1813 }
1814 Err(_) => {
1815 tracing::warn!(
1816 extension = %self.id,
1817 timeout_secs = 5,
1818 "Extension hook handler timed out — continuing",
1819 );
1820 HookResult::Continue
1821 }
1822 }
1823 }
1824
1825 async fn get_info(&self) -> Result<crate::extensions::info::PluginInfo, String> {
1826 let value = tokio::time::timeout(
1827 std::time::Duration::from_secs(5),
1828 self.call("info.get", Value::Null),
1829 )
1830 .await
1831 .map_err(|_| format!("Extension '{}' info.get timed out", self.id))??;
1832 serde_json::from_value(value)
1833 .map_err(|e| format!("Invalid info.get response from extension '{}': {}", self.id, e))
1834 }
1835
1836 async fn sidecar_spawn_args(
1837 &self,
1838 ) -> Result<crate::sidecar::spawn::SidecarSpawnArgs, String> {
1839 let value = tokio::time::timeout(
1840 std::time::Duration::from_secs(5),
1841 self.call("sidecar.spawn_args", Value::Null),
1842 )
1843 .await
1844 .map_err(|_| format!("Extension '{}' sidecar.spawn_args timed out", self.id))??;
1845 serde_json::from_value(value).map_err(|e| {
1846 format!(
1847 "Invalid sidecar.spawn_args response from extension '{}': {}",
1848 self.id, e
1849 )
1850 })
1851 }
1852
1853 async fn settings_editor_open(&self, category: &str, field: &str) -> Result<Value, String> {
1854 let params = crate::extensions::settings_editor::SettingsEditorOpenParams {
1855 category: category.to_string(),
1856 field: field.to_string(),
1857 };
1858 tokio::time::timeout(
1859 std::time::Duration::from_secs(5),
1860 self.call("settings.editor.open", serde_json::to_value(params).map_err(|e| e.to_string())?),
1861 )
1862 .await
1863 .map_err(|_| format!("Extension '{}' settings.editor.open timed out", self.id))?
1864 }
1865
1866 async fn settings_editor_key(&self, category: &str, field: &str, key: &str) -> Result<Value, String> {
1867 let mut params = serde_json::to_value(crate::extensions::settings_editor::SettingsEditorKeyParams {
1868 key: key.to_string(),
1869 }).map_err(|e| e.to_string())?;
1870 if let Some(obj) = params.as_object_mut() {
1871 obj.insert("category".to_string(), Value::String(category.to_string()));
1872 obj.insert("field".to_string(), Value::String(field.to_string()));
1873 }
1874 tokio::time::timeout(
1875 std::time::Duration::from_secs(5),
1876 self.call("settings.editor.key", params),
1877 )
1878 .await
1879 .map_err(|_| format!("Extension '{}' settings.editor.key timed out", self.id))?
1880 }
1881
1882 async fn settings_editor_commit(&self, category: &str, field: &str, value: Value) -> Result<Value, String> {
1883 let params = serde_json::json!({
1884 "category": category,
1885 "field": field,
1886 "value": value,
1887 });
1888 tokio::time::timeout(
1889 std::time::Duration::from_secs(5),
1890 self.call("settings.editor.commit", params),
1891 )
1892 .await
1893 .map_err(|_| format!("Extension '{}' settings.editor.commit timed out", self.id))?
1894 }
1895
1896 async fn shutdown(&self) {
1897 let _ = tokio::time::timeout(
1898 std::time::Duration::from_millis(500),
1899 self.call("shutdown", Value::Null),
1900 )
1901 .await;
1902
1903 tokio::time::sleep(std::time::Duration::from_millis(500)).await;
1904 let mut state_guard = self.state.lock().await;
1905 if let Some(state) = state_guard.take() {
1906 state.reader_handle.abort();
1907 let mut child = state.child;
1908 let _ = child.kill().await;
1909 }
1910 self.inbox.notification_sink.lock().await.take();
1912 self.inbox
1913 .fail_all_pending("transport closed: extension shutdown")
1914 .await;
1915 }
1916
1917 async fn restart_count(&self) -> usize {
1918 self.restart_count()
1919 }
1920
1921 async fn health(&self) -> ExtensionHealth {
1922 let count = self.restart_count.load(Ordering::Relaxed);
1923 let max = self.restart_policy.max_attempts as usize;
1924 if count >= max {
1925 ExtensionHealth::Failed
1926 } else if count > 0 {
1927 let state_alive = self.state.try_lock().map(|g| g.is_some()).unwrap_or(true);
1931 if state_alive {
1932 ExtensionHealth::Degraded
1933 } else {
1934 ExtensionHealth::Restarting
1935 }
1936 } else {
1937 ExtensionHealth::Running
1938 }
1939 }
1940}
1941
1942#[cfg(test)]
1943mod stream_event_tests {
1944 use super::*;
1945 use serde_json::json;
1946
1947 #[test]
1948 fn parses_text_delta_with_delta_key() {
1949 let v = json!({"type": "text", "delta": "hi"});
1950 assert_eq!(
1951 parse_provider_stream_event(&v).unwrap(),
1952 ProviderStreamEvent::TextDelta { text: "hi".into() }
1953 );
1954 }
1955
1956 #[test]
1957 fn parses_text_delta_with_text_key() {
1958 let v = json!({"type": "text", "text": "hi"});
1959 assert_eq!(
1960 parse_provider_stream_event(&v).unwrap(),
1961 ProviderStreamEvent::TextDelta { text: "hi".into() }
1962 );
1963 }
1964
1965 #[test]
1966 fn parses_thinking_delta() {
1967 let v = json!({"type": "thinking", "delta": "hmm"});
1968 assert_eq!(
1969 parse_provider_stream_event(&v).unwrap(),
1970 ProviderStreamEvent::ThinkingDelta { text: "hmm".into() }
1971 );
1972 let v2 = json!({"type": "thinking", "text": "hmm"});
1973 assert_eq!(
1974 parse_provider_stream_event(&v2).unwrap(),
1975 ProviderStreamEvent::ThinkingDelta { text: "hmm".into() }
1976 );
1977 }
1978
1979 #[test]
1980 fn parses_tool_use() {
1981 let v = json!({
1982 "type": "tool_use",
1983 "id": "t1",
1984 "name": "echo",
1985 "input": {"x": 1}
1986 });
1987 assert_eq!(
1988 parse_provider_stream_event(&v).unwrap(),
1989 ProviderStreamEvent::ToolUse {
1990 id: "t1".into(),
1991 name: "echo".into(),
1992 input: json!({"x": 1}),
1993 }
1994 );
1995 }
1996
1997 #[test]
1998 fn tool_use_input_defaults_to_empty_object() {
1999 let v = json!({"type": "tool_use", "id": "t1", "name": "echo"});
2000 assert_eq!(
2001 parse_provider_stream_event(&v).unwrap(),
2002 ProviderStreamEvent::ToolUse {
2003 id: "t1".into(),
2004 name: "echo".into(),
2005 input: json!({}),
2006 }
2007 );
2008 }
2009
2010 #[test]
2011 fn parses_usage_strips_type() {
2012 let v = json!({"type": "usage", "input_tokens": 5, "output_tokens": 7});
2013 assert_eq!(
2014 parse_provider_stream_event(&v).unwrap(),
2015 ProviderStreamEvent::Usage {
2016 usage: json!({"input_tokens": 5, "output_tokens": 7})
2017 }
2018 );
2019 }
2020
2021 #[test]
2022 fn parses_error() {
2023 let v = json!({"type": "error", "message": "boom"});
2024 assert_eq!(
2025 parse_provider_stream_event(&v).unwrap(),
2026 ProviderStreamEvent::Error { message: "boom".into() }
2027 );
2028 }
2029
2030 #[test]
2031 fn parses_done() {
2032 let v = json!({"type": "done"});
2033 assert_eq!(
2034 parse_provider_stream_event(&v).unwrap(),
2035 ProviderStreamEvent::Done
2036 );
2037 }
2038
2039 #[test]
2040 fn nested_event_shape_matches_flat() {
2041 let flat = json!({"type": "text", "delta": "hi"});
2042 let nested = json!({"event": {"type": "text", "delta": "hi"}});
2043 assert_eq!(
2044 parse_provider_stream_event(&flat).unwrap(),
2045 parse_provider_stream_event(&nested).unwrap()
2046 );
2047 }
2048
2049 #[test]
2050 fn missing_type_errors() {
2051 let v = json!({"delta": "hi"});
2052 let err = parse_provider_stream_event(&v).unwrap_err();
2053 assert!(err.contains("missing type"), "got: {err}");
2054 }
2055
2056 #[test]
2057 fn unknown_type_errors_with_type() {
2058 let v = json!({"type": "wat"});
2059 let err = parse_provider_stream_event(&v).unwrap_err();
2060 assert!(err.contains("wat"), "got: {err}");
2061 }
2062
2063 #[test]
2064 fn tool_use_missing_id_errors() {
2065 let v = json!({"type": "tool_use", "name": "echo"});
2066 let err = parse_provider_stream_event(&v).unwrap_err();
2067 assert!(err.contains("id"), "got: {err}");
2068 }
2069
2070 #[test]
2071 fn tool_use_missing_name_errors() {
2072 let v = json!({"type": "tool_use", "id": "t1"});
2073 let err = parse_provider_stream_event(&v).unwrap_err();
2074 assert!(err.contains("name"), "got: {err}");
2075 }
2076
2077 #[test]
2078 fn tool_use_empty_id_errors() {
2079 let v = json!({"type": "tool_use", "id": "", "name": "echo"});
2080 assert!(parse_provider_stream_event(&v).is_err());
2081 }
2082
2083 #[test]
2084 fn tool_use_empty_name_errors() {
2085 let v = json!({"type": "tool_use", "id": "t1", "name": ""});
2086 assert!(parse_provider_stream_event(&v).is_err());
2087 }
2088
2089 #[test]
2090 fn tool_use_non_object_input_errors() {
2091 let v = json!({"type": "tool_use", "id": "t1", "name": "echo", "input": "nope"});
2092 let err = parse_provider_stream_event(&v).unwrap_err();
2093 assert!(err.contains("input"), "got: {err}");
2094 }
2095
2096 #[test]
2097 fn text_missing_delta_and_text_errors() {
2098 let v = json!({"type": "text"});
2099 let err = parse_provider_stream_event(&v).unwrap_err();
2100 assert!(err.contains("delta") || err.contains("text"), "got: {err}");
2101 }
2102
2103 #[test]
2104 fn error_missing_message_errors() {
2105 let v = json!({"type": "error"});
2106 assert!(parse_provider_stream_event(&v).is_err());
2107 }
2108
2109 #[test]
2110 fn error_empty_message_errors() {
2111 let v = json!({"type": "error", "message": ""});
2112 assert!(parse_provider_stream_event(&v).is_err());
2113 }
2114}
2115
2116#[cfg(test)]
2117mod restart_policy_tests {
2118 use super::*;
2119
2120 #[tokio::test]
2121 async fn restart_policy_default_max_attempts_is_3() {
2122 let ext = ProcessExtension::spawn("policy-test", "/bin/cat", &[])
2127 .await
2128 .expect("spawn /bin/cat");
2129 assert_eq!(ext.restart_policy.max_attempts, 3);
2130 ext.shutdown().await;
2131 }
2132
2133 #[tokio::test]
2134 async fn with_restart_policy_overrides_default() {
2135 let ext = ProcessExtension::spawn("policy-test-override", "/bin/cat", &[])
2136 .await
2137 .expect("spawn /bin/cat");
2138 let custom = RestartPolicy {
2139 max_attempts: 7,
2140 ..RestartPolicy::default()
2141 };
2142 let ext = ext.with_restart_policy(custom);
2143 assert_eq!(ext.restart_policy.max_attempts, 7);
2144 ext.shutdown().await;
2145 }
2146}
2147
2148#[cfg(test)]
2149mod capture_validator_tests {
2150 use super::*;
2151 use crate::extensions::permissions::{Permission, PermissionSet};
2152
2153 fn perms_with(grants: &[Permission]) -> PermissionSet {
2154 let mut p = PermissionSet::new();
2155 for g in grants {
2156 p.grant(*g);
2157 }
2158 p
2159 }
2160
2161 fn cap(kind: &str, name: &str, perms: &[&str]) -> CapabilityDeclaration {
2162 CapabilityDeclaration {
2163 kind: kind.to_string(),
2164 name: name.to_string(),
2165 permissions: perms.iter().map(|p| p.to_string()).collect(),
2166 params: serde_json::Value::Null,
2167 }
2168 }
2169
2170 #[test]
2171 fn capability_validator_rejects_empty_kind() {
2172 let d = cap(" ", "Sample", &["audio.input"]);
2173 let perms = perms_with(&[Permission::AudioInput]);
2174 let err = validate_capability(&d, &perms).unwrap_err();
2175 assert!(err.contains("kind"), "got: {}", err);
2176 }
2177
2178 #[test]
2179 fn capability_validator_rejects_empty_name() {
2180 let d = cap("capture", " ", &["audio.input"]);
2181 let perms = perms_with(&[Permission::AudioInput]);
2182 let err = validate_capability(&d, &perms).unwrap_err();
2183 assert!(err.contains("name"), "got: {}", err);
2184 }
2185
2186 #[test]
2187 fn capability_validator_rejects_unknown_permission_string() {
2188 let d = cap("capture", "Sample", &["audio.telepathy"]);
2189 let perms = perms_with(&[Permission::AudioInput, Permission::AudioOutput]);
2190 let err = validate_capability(&d, &perms).unwrap_err();
2191 assert!(
2192 err.contains("unknown permission") && err.contains("audio.telepathy"),
2193 "got: {}",
2194 err,
2195 );
2196 }
2197
2198 #[test]
2199 fn capability_validator_requires_every_declared_permission() {
2200 let d = cap("capture", "Sample", &["audio.input"]);
2201 let perms = perms_with(&[]);
2202 let err = validate_capability(&d, &perms).unwrap_err();
2203 assert!(
2204 err.contains("audio.input") && err.contains("not granted"),
2205 "got: {}",
2206 err,
2207 );
2208 }
2209
2210 #[test]
2211 fn capability_validator_accepts_when_all_permissions_granted() {
2212 let d = cap("capture", "Sample", &["audio.input", "audio.output"]);
2213 let perms = perms_with(&[Permission::AudioInput, Permission::AudioOutput]);
2214 validate_capability(&d, &perms).expect("should validate");
2215 }
2216
2217 #[test]
2218 fn capability_validator_accepts_no_permissions() {
2219 let d = cap("ocr", "Tesseract", &[]);
2223 let perms = perms_with(&[]);
2224 validate_capability(&d, &perms).expect("should validate");
2225 }
2226
2227 #[test]
2228 fn capability_validator_does_not_branch_on_kind() {
2229 let perms = perms_with(&[Permission::AudioInput]);
2233 for kind in ["capture", "ocr", "agent", "foot_pedal", "eeg"] {
2234 let d = cap(kind, "Anything", &["audio.input"]);
2235 validate_capability(&d, &perms).expect("should validate");
2236 }
2237 }
2238
2239}
2240
2241#[cfg(test)]
2242mod invoke_command_dispatch_tests {
2243 use super::*;
2248 use crate::extensions::commands::CommandOutputEvent;
2249 use crate::extensions::runtime::InvokeCommandEvent;
2250 use crate::extensions::tasks::{TaskEvent, TaskKind};
2251 use serde_json::json;
2252 use tokio::sync::mpsc;
2253
2254 fn frame(method: &str, params: serde_json::Value) -> NotificationFrame {
2255 NotificationFrame {
2256 method: method.to_string(),
2257 params,
2258 }
2259 }
2260
2261 #[test]
2262 fn forwards_mixed_event_stream_in_order() {
2263 let (tx, mut rx) = mpsc::unbounded_channel::<InvokeCommandEvent>();
2264 let mut open = true;
2265 let frames = vec![
2266 frame(
2267 "command.output",
2268 json!({"request_id":"r1","event":{"kind":"text","content":"A"}}),
2269 ),
2270 frame(
2271 "task.start",
2272 json!({"id":"dl","label":"Downloading","kind":"download"}),
2273 ),
2274 frame(
2275 "task.update",
2276 json!({"id":"dl","current":50,"total":100}),
2277 ),
2278 frame(
2279 "command.output",
2280 json!({"request_id":"r1","event":{"kind":"system","content":"working"}}),
2281 ),
2282 frame("task.done", json!({"id":"dl"})),
2283 frame(
2284 "command.output",
2285 json!({"request_id":"r1","event":{"kind":"done"}}),
2286 ),
2287 ];
2288
2289 let mut saw_done = false;
2290 for f in frames {
2291 saw_done |= ProcessExtension::forward_invoke_command_frame(
2292 "ext-test", "r1", &tx, &mut open, f,
2293 );
2294 }
2295 drop(tx);
2296 assert!(saw_done, "should have observed the command Done marker");
2297
2298 let mut events = Vec::new();
2299 while let Ok(ev) = rx.try_recv() {
2300 events.push(ev);
2301 }
2302 assert_eq!(events.len(), 6);
2303 assert_eq!(
2304 events[0],
2305 InvokeCommandEvent::Output(CommandOutputEvent::Text { content: "A".into() })
2306 );
2307 assert!(matches!(
2308 events[1],
2309 InvokeCommandEvent::Task(TaskEvent::Start { kind: TaskKind::Download, .. })
2310 ));
2311 assert!(matches!(
2312 events[2],
2313 InvokeCommandEvent::Task(TaskEvent::Update { .. })
2314 ));
2315 assert!(matches!(
2316 events[3],
2317 InvokeCommandEvent::Output(CommandOutputEvent::System { .. })
2318 ));
2319 assert!(matches!(
2320 events[4],
2321 InvokeCommandEvent::Task(TaskEvent::Done { error: None, .. })
2322 ));
2323 assert_eq!(events[5], InvokeCommandEvent::Output(CommandOutputEvent::Done));
2324 }
2325
2326 #[test]
2327 fn ignores_command_output_for_unrelated_request_id() {
2328 let (tx, mut rx) = mpsc::unbounded_channel::<InvokeCommandEvent>();
2329 let mut open = true;
2330 ProcessExtension::forward_invoke_command_frame(
2331 "ext",
2332 "r1",
2333 &tx,
2334 &mut open,
2335 frame(
2336 "command.output",
2337 json!({"request_id":"other","event":{"kind":"text","content":"x"}}),
2338 ),
2339 );
2340 drop(tx);
2341 assert!(rx.try_recv().is_err());
2342 }
2343
2344 #[test]
2345 fn skips_malformed_command_output_without_aborting() {
2346 let (tx, mut rx) = mpsc::unbounded_channel::<InvokeCommandEvent>();
2347 let mut open = true;
2348 ProcessExtension::forward_invoke_command_frame(
2350 "ext",
2351 "r1",
2352 &tx,
2353 &mut open,
2354 frame("command.output", json!({"request_id":"r1","event":{}})),
2355 );
2356 ProcessExtension::forward_invoke_command_frame(
2358 "ext",
2359 "r1",
2360 &tx,
2361 &mut open,
2362 frame(
2363 "command.output",
2364 json!({"request_id":"r1","event":{"kind":"done"}}),
2365 ),
2366 );
2367 drop(tx);
2368 let ev = rx.try_recv().unwrap();
2369 assert_eq!(ev, InvokeCommandEvent::Output(CommandOutputEvent::Done));
2370 assert!(rx.try_recv().is_err());
2371 }
2372
2373 #[test]
2374 fn task_events_pass_through_regardless_of_request_id() {
2375 let (tx, mut rx) = mpsc::unbounded_channel::<InvokeCommandEvent>();
2376 let mut open = true;
2377 ProcessExtension::forward_invoke_command_frame(
2378 "ext",
2379 "r1",
2380 &tx,
2381 &mut open,
2382 frame("task.log", json!({"id":"abc","line":"..."})),
2383 );
2384 drop(tx);
2385 match rx.try_recv().unwrap() {
2386 InvokeCommandEvent::Task(TaskEvent::Log { id, line }) => {
2387 assert_eq!(id, "abc");
2388 assert_eq!(line, "...");
2389 }
2390 other => panic!("unexpected: {other:?}"),
2391 }
2392 }
2393
2394 #[test]
2395 fn unrelated_methods_are_dropped() {
2396 let (tx, mut rx) = mpsc::unbounded_channel::<InvokeCommandEvent>();
2397 let mut open = true;
2398 ProcessExtension::forward_invoke_command_frame(
2399 "ext",
2400 "r1",
2401 &tx,
2402 &mut open,
2403 frame("provider.stream.event", json!({"type":"text","delta":"x"})),
2404 );
2405 drop(tx);
2406 assert!(rx.try_recv().is_err());
2407 }
2408}