1use std::path::PathBuf;
7use std::process::Stdio;
8use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
9use std::sync::Arc;
10use std::collections::{HashMap, HashSet};
11
12use async_trait::async_trait;
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
16use tokio::process::{Child, ChildStdin, ChildStdout, Command};
17use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
18use tokio::task::JoinHandle;
19
20use super::{ExtensionHandler, ExtensionHealth, RestartPolicy};
21use crate::extensions::hooks::events::{HookEvent, HookResult};
22use crate::extensions::manifest::CURRENT_EXTENSION_PROTOCOL_VERSION;
23
/// Outgoing JSON-RPC request frame sent to an extension.
///
/// Responses are matched back to the caller by the numeric `id`
/// (see the pending-request map consulted when frames are dispatched).
#[derive(Serialize)]
struct JsonRpcRequest {
    /// Protocol marker; NOTE(review): presumably always "2.0" (the value used for
    /// inbound replies) — the constructor is outside this view.
    jsonrpc: &'static str,
    /// Name of the method to invoke on the extension.
    method: String,
    /// Method parameters, serialized as-is.
    params: Value,
    /// Request id used to correlate the extension's response.
    id: u64,
}
31
32
/// Parameters of the `initialize` request sent to a freshly spawned extension.
#[derive(Serialize)]
struct InitializeParams {
    /// Host version string (filled from `CARGO_PKG_VERSION`).
    synaps_version: &'static str,
    /// Extension protocol version the host speaks; the response must echo the same value.
    extension_protocol_version: u32,
    /// Id of the extension being initialized.
    plugin_id: String,
    /// Plugin root directory, falling back to the process working directory when unset.
    plugin_root: Option<String>,
    /// Plugin configuration forwarded verbatim.
    config: Value,
}
41
/// A tool declared by an extension in its `initialize` response.
///
/// Validated on receipt: the trimmed name must be a valid, unique id segment,
/// `description` must be non-blank and `input_schema` must be a JSON object.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct RegisteredExtensionToolSpec {
    /// Tool name as exposed by the extension.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// JSON Schema describing the tool's input object.
    pub input_schema: Value,
}
48
/// A model provider declared by an extension in its `initialize` response.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct RegisteredProviderSpec {
    /// Provider id; validated as an id segment on receipt.
    pub id: String,
    /// Name shown to users; must be non-blank.
    pub display_name: String,
    /// Provider description; must be non-blank.
    pub description: String,
    /// Models offered by the provider. Defaults to empty when omitted, although
    /// validation requires at least one model.
    #[serde(default)]
    pub models: Vec<RegisteredProviderModelSpec>,
    /// Optional schema describing the provider's configuration.
    #[serde(default)]
    pub config_schema: Option<Value>,
}
59
/// One model offered by an extension provider.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct RegisteredProviderModelSpec {
    /// Model id; must be a valid id segment, unique within its provider.
    pub id: String,
    /// Optional human-readable name.
    #[serde(default)]
    pub display_name: Option<String>,
    /// Free-form capability description (JSON `null` when omitted).
    #[serde(default)]
    pub capabilities: Value,
    /// Optional context-window size; NOTE(review): unit presumably tokens — confirm.
    #[serde(default)]
    pub context_window: Option<u64>,
}
70
/// Request sent to an extension provider for one completion round.
///
/// `complete_provider_with_tools` appends assistant/tool-result messages to
/// `messages` between rounds.
#[derive(Debug, Clone, Serialize)]
pub struct ProviderCompleteParams {
    /// Id of the provider being called.
    pub provider_id: String,
    /// Id of the model within the provider.
    pub model_id: String,
    /// NOTE(review): how `model` differs from `model_id` is not visible here — confirm
    /// (possibly a fully-qualified model name).
    pub model: String,
    /// Conversation so far, as role/content JSON objects.
    pub messages: Vec<Value>,
    /// Optional system prompt.
    pub system_prompt: Option<String>,
    /// Tool definitions offered to the model.
    pub tools: Vec<Value>,
    /// Optional sampling temperature.
    pub temperature: Option<f32>,
    /// Optional cap on generated tokens.
    pub max_tokens: Option<u32>,
    /// Thinking budget forwarded to the provider.
    pub thinking_budget: u32,
}
83
/// Result of one provider completion round.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ProviderCompleteResult {
    /// Content blocks; blocks with `"type": "tool_use"` trigger tool execution.
    pub content: Vec<Value>,
    /// Optional stop reason reported by the provider.
    #[serde(default)]
    pub stop_reason: Option<String>,
    /// Optional usage statistics, passed through untouched.
    #[serde(default)]
    pub usage: Option<Value>,
}
92
/// A tool invocation requested by a provider, extracted from its content blocks.
#[derive(Debug, Clone, PartialEq)]
pub struct ProviderToolUse {
    /// Provider-assigned id, echoed back as `tool_use_id` in the tool result.
    pub id: String,
    /// API name of the tool to run.
    pub name: String,
    /// Tool input; always a JSON object.
    pub input: Value,
}
99
/// One event of a streamed provider response, as produced by
/// `parse_provider_stream_event`.
#[derive(Debug, Clone, PartialEq)]
pub enum ProviderStreamEvent {
    /// Incremental assistant text (wire type `"text"`).
    TextDelta { text: String },
    /// Incremental thinking text (wire type `"thinking"`).
    ThinkingDelta { text: String },
    /// A complete tool call (wire type `"tool_use"`).
    ToolUse {
        /// Provider-assigned call id.
        id: String,
        /// Tool name.
        name: String,
        /// Tool input object.
        input: Value,
    },
    /// Usage payload: the event's fields minus `"type"`.
    Usage { usage: Value },
    /// Provider-reported error (wire type `"error"`).
    Error { message: String },
    /// End of stream (wire type `"done"`).
    Done,
}
120
121pub fn parse_provider_stream_event(params: &Value) -> Result<ProviderStreamEvent, String> {
127 let inner = match params.get("event") {
128 Some(ev) => ev,
129 None => params,
130 };
131 let obj = inner
132 .as_object()
133 .ok_or_else(|| "provider stream event must be a JSON object".to_string())?;
134
135 let ty = obj
136 .get("type")
137 .and_then(Value::as_str)
138 .ok_or_else(|| "provider stream event missing type".to_string())?;
139
140 match ty {
141 "text" => {
142 let text = obj
143 .get("delta")
144 .or_else(|| obj.get("text"))
145 .and_then(Value::as_str)
146 .ok_or_else(|| {
147 "provider stream text event missing 'delta' or 'text'".to_string()
148 })?;
149 Ok(ProviderStreamEvent::TextDelta {
150 text: text.to_string(),
151 })
152 }
153 "thinking" => {
154 let text = obj
155 .get("delta")
156 .or_else(|| obj.get("text"))
157 .and_then(Value::as_str)
158 .ok_or_else(|| {
159 "provider stream thinking event missing 'delta' or 'text'".to_string()
160 })?;
161 Ok(ProviderStreamEvent::ThinkingDelta {
162 text: text.to_string(),
163 })
164 }
165 "tool_use" => {
166 let id = obj
167 .get("id")
168 .and_then(Value::as_str)
169 .ok_or_else(|| "provider stream tool_use missing id".to_string())?;
170 if id.is_empty() {
171 return Err("provider stream tool_use id must be non-empty".to_string());
172 }
173 let name = obj
174 .get("name")
175 .and_then(Value::as_str)
176 .ok_or_else(|| "provider stream tool_use missing name".to_string())?;
177 if name.is_empty() {
178 return Err("provider stream tool_use name must be non-empty".to_string());
179 }
180 let input = match obj.get("input") {
181 None => Value::Object(Default::default()),
182 Some(v) if v.is_object() => v.clone(),
183 Some(_) => {
184 return Err(
185 "provider stream tool_use input must be a JSON object".to_string()
186 );
187 }
188 };
189 Ok(ProviderStreamEvent::ToolUse {
190 id: id.to_string(),
191 name: name.to_string(),
192 input,
193 })
194 }
195 "usage" => {
196 let mut clone = obj.clone();
197 clone.remove("type");
198 Ok(ProviderStreamEvent::Usage {
199 usage: Value::Object(clone),
200 })
201 }
202 "error" => {
203 let message = obj
204 .get("message")
205 .and_then(Value::as_str)
206 .ok_or_else(|| "provider stream error missing message".to_string())?;
207 if message.is_empty() {
208 return Err("provider stream error message must be non-empty".to_string());
209 }
210 Ok(ProviderStreamEvent::Error {
211 message: message.to_string(),
212 })
213 }
214 "done" => Ok(ProviderStreamEvent::Done),
215 other => Err(format!("unknown provider stream event type: {other}")),
216 }
217}
218
/// Run one provider-requested tool call and build the matching `tool_result` block.
///
/// Steps:
/// 1. Look the tool up by name in `registry`. An unknown name yields an error result.
/// 2. Resolve the runtime name and translate `input` through the registry.
/// 3. Emit the before-tool-call hook and resolve its decision. A `Block` decision
///    yields an error result carrying the reason.
/// 4. Execute the tool with `ctx`, then emit the after-tool-call hook with the full,
///    untruncated output. The hook's return value is ignored.
/// 5. Return the output passed through `crate::truncate_str(.., max_tool_output)`.
///
/// Every returned object has `"type": "tool_result"` and `"tool_use_id"` set to the
/// request id. Each failure path (unknown tool, blocked call, execution error) also
/// sets `"is_error": true`.
pub async fn execute_provider_tool_use(
    registry: &crate::ToolRegistry,
    hook_bus: &Arc<crate::extensions::hooks::HookBus>,
    tool_use: ProviderToolUse,
    ctx: crate::ToolContext,
    max_tool_output: usize,
) -> Value {
    let tool_id = tool_use.id;
    let tool_name = tool_use.name;
    let input = tool_use.input;

    let Some(tool) = registry.get(&tool_name).cloned() else {
        return serde_json::json!({
            "type": "tool_result",
            "tool_use_id": tool_id,
            "content": format!("Unknown tool: {}", tool_name),
            "is_error": true,
        });
    };

    let runtime_name = registry.runtime_name_for_api(&tool_name).to_string();
    let input = registry.translate_input_for_api_tool(&tool_name, input);
    // Hooks see the translated input; the resolved decision may replace it.
    let decision = crate::runtime::resolve_before_tool_call_decision(
        input.clone(),
        crate::runtime::emit_before_tool_call(
            hook_bus,
            &tool_name,
            Some(&runtime_name),
            input.clone(),
        ).await,
        ctx.capabilities.secret_prompt.as_ref(),
        false,
    ).await;

    // Anything other than `Continue` is expected to be `Block`; the inner
    // let-else panics if the decision enum ever grows another variant.
    let crate::runtime::BeforeToolCallDecision::Continue { input } = decision else {
        let crate::runtime::BeforeToolCallDecision::Block { reason } = decision else { unreachable!() };
        return serde_json::json!({
            "type": "tool_result",
            "tool_use_id": tool_id,
            "content": format!("Tool call blocked by extension: {}", reason),
            "is_error": true,
        });
    };

    let input_for_hook = input.clone();
    let (result, is_error) = match tool.execute(input, ctx).await {
        Ok(output) => (output, false),
        Err(error) => (format!("Tool execution failed: {}", error), true),
    };
    let _ = crate::runtime::emit_after_tool_call(
        hook_bus,
        &tool_name,
        Some(&runtime_name),
        input_for_hook,
        result.clone(),
    ).await;

    let mut response = serde_json::json!({
        "type": "tool_result",
        "tool_use_id": tool_id,
        "content": crate::truncate_str(&result, max_tool_output).to_string(),
    });
    if is_error {
        response["is_error"] = serde_json::json!(true);
    }
    response
}
286
287pub async fn complete_provider_with_tools<F>(
288 handler: Arc<dyn ExtensionHandler>,
289 mut params: ProviderCompleteParams,
290 registry: &crate::ToolRegistry,
291 hook_bus: &Arc<crate::extensions::hooks::HookBus>,
292 mut context_factory: F,
293 max_tool_output: usize,
294 max_iterations: usize,
295) -> Result<ProviderCompleteResult, String>
296where
297 F: FnMut() -> crate::ToolContext,
298{
299 let max_iterations = max_iterations.max(1);
300 for iteration in 0..max_iterations {
301 let result = handler.provider_complete(params.clone()).await?;
302 let tool_uses = extract_provider_tool_uses(&result.content)?;
303 if tool_uses.is_empty() {
304 return Ok(result);
305 }
306 if iteration + 1 == max_iterations {
307 return Err(format!(
308 "extension provider '{}' exceeded provider tool-use iteration limit ({})",
309 handler.id(),
310 max_iterations,
311 ));
312 }
313
314 let assistant_content = result.content.clone();
315 params.messages.push(serde_json::json!({
316 "role": "assistant",
317 "content": assistant_content,
318 }));
319
320 let mut tool_results = Vec::with_capacity(tool_uses.len());
321 for tool_use in tool_uses {
322 tool_results.push(execute_provider_tool_use(
323 registry,
324 hook_bus,
325 tool_use,
326 context_factory(),
327 max_tool_output,
328 ).await);
329 }
330 params.messages.push(serde_json::json!({
331 "role": "user",
332 "content": tool_results,
333 }));
334 }
335 Err(format!(
336 "extension provider '{}' exceeded provider tool-use iteration limit ({})",
337 handler.id(),
338 max_iterations,
339 ))
340}
341
342pub fn extract_provider_tool_uses(content: &[Value]) -> Result<Vec<ProviderToolUse>, String> {
343 let mut tool_uses = Vec::new();
344 for block in content {
345 if block.get("type").and_then(Value::as_str) != Some("tool_use") {
346 continue;
347 }
348 let id = block
349 .get("id")
350 .and_then(Value::as_str)
351 .ok_or_else(|| "provider tool_use missing id".to_string())?;
352 let name = block
353 .get("name")
354 .and_then(Value::as_str)
355 .ok_or_else(|| "provider tool_use missing name".to_string())?;
356 if id.trim().is_empty() {
357 return Err("provider tool_use id is empty".to_string());
358 }
359 if name.trim().is_empty() {
360 return Err("provider tool_use name is empty".to_string());
361 }
362 let input = block
363 .get("input")
364 .cloned()
365 .unwrap_or_else(|| serde_json::json!({}));
366 if !input.is_object() {
367 return Err(format!(
368 "provider tool_use '{}' input must be a JSON object",
369 id
370 ));
371 }
372 tool_uses.push(ProviderToolUse {
373 id: id.to_string(),
374 name: name.to_string(),
375 input,
376 });
377 }
378 Ok(tool_uses)
379}
380
/// Validated capabilities an extension reported from `initialize`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct InitializeCapabilitiesResult {
    /// Tools the extension registers (names and schemas already validated).
    pub tools: Vec<RegisteredExtensionToolSpec>,
    /// Providers the extension registers (ids and models already validated).
    pub providers: Vec<RegisteredProviderSpec>,
    /// Generic capability declarations; these are not validated during `initialize`
    /// parsing (see `validate_capability`).
    pub capabilities: Vec<CapabilityDeclaration>,
}
392
/// A generic capability declared by an extension.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct CapabilityDeclaration {
    /// Capability kind; must be non-blank.
    pub kind: String,
    /// Capability name; must be non-blank.
    pub name: String,
    /// Permission names the capability needs. Each must parse as a known permission
    /// and be granted to the extension.
    #[serde(default)]
    pub permissions: Vec<String>,
    /// Free-form parameters; JSON `null` when absent and omitted from output when null.
    #[serde(default, skip_serializing_if = "is_null_value")]
    pub params: serde_json::Value,
}
422
423fn is_null_value(v: &serde_json::Value) -> bool {
424 v.is_null()
425}
426
427pub fn validate_capability(
437 decl: &CapabilityDeclaration,
438 granted: &crate::extensions::permissions::PermissionSet,
439) -> Result<(), String> {
440 use crate::extensions::permissions::Permission;
441 if decl.kind.trim().is_empty() {
442 return Err("capability 'kind' must be non-empty".to_string());
443 }
444 if decl.name.trim().is_empty() {
445 return Err("capability 'name' must be non-empty".to_string());
446 }
447 for perm_name in &decl.permissions {
448 let parsed = Permission::parse(perm_name).ok_or_else(|| {
449 format!(
450 "capability '{}' declares unknown permission '{}'",
451 decl.kind, perm_name
452 )
453 })?;
454 if !granted.has(parsed) {
455 return Err(format!(
456 "capability '{}' requires permission '{}' but it is not granted",
457 decl.kind, perm_name
458 ));
459 }
460 }
461 Ok(())
462}
463
/// Raw wire shape of an extension's `initialize` response.
#[derive(Deserialize)]
struct InitializeResult {
    /// Must equal `CURRENT_EXTENSION_PROTOCOL_VERSION`, otherwise initialization fails.
    protocol_version: u32,
    /// Declared capabilities; all lists default to empty when omitted.
    #[serde(default)]
    capabilities: InitializeCapabilities,
}
470
/// Raw (unvalidated) capability lists from the `initialize` response.
#[derive(Default, Deserialize)]
struct InitializeCapabilities {
    /// Declared tools.
    #[serde(default)]
    tools: Vec<RegisteredExtensionToolSpec>,
    /// Declared providers.
    #[serde(default)]
    providers: Vec<RegisteredProviderSpec>,
    /// Declared generic capabilities.
    #[serde(default)]
    capabilities: Vec<CapabilityDeclaration>,
}
481
/// A JSON-RPC notification (a frame with `method` and no id) received from an
/// extension and delivered to every active notification sink.
#[doc(hidden)]
#[derive(Debug, Clone)]
pub struct NotificationFrame {
    /// Notification method name.
    pub method: String,
    /// Notification params, or JSON `null` when absent.
    pub params: Value,
}
493
/// State shared between a process extension and its stdout reader task.
struct Inbox {
    /// Outstanding requests keyed by JSON-RPC id. Each is resolved when the matching
    /// response frame arrives, or failed when the transport closes.
    pending: Mutex<HashMap<u64, oneshot::Sender<Result<Value, String>>>>,
    /// Notification subscribers as `(sink id, channel)`. Sinks whose receiver is gone
    /// are pruned on the next notification, and all sinks are cleared when the
    /// transport closes.
    notification_sinks: Mutex<Vec<(usize, mpsc::UnboundedSender<NotificationFrame>)>>,
    /// NOTE(review): presumably the allocator for sink ids — its users are outside this view.
    next_sink_id: std::sync::atomic::AtomicUsize,
    /// Set to `true` once pending requests have been failed because the transport ended.
    closed: std::sync::atomic::AtomicBool,
    /// Permissions granted to the extension. `None` denies every permission-gated
    /// inbound request.
    permissions: RwLock<Option<crate::extensions::permissions::PermissionSet>>,
    /// Child stdin used to write replies to inbound requests; set when a process is spawned.
    inbound_stdin: Mutex<Option<Arc<Mutex<ChildStdin>>>>,
    /// Extension id; the only memory namespace the extension may touch and the key
    /// used for its config store.
    extension_id: String,
}
525
526impl Inbox {
527 fn new(extension_id: String) -> Self {
528 Self {
529 pending: Mutex::new(HashMap::new()),
530 notification_sinks: Mutex::new(Vec::new()),
531 next_sink_id: std::sync::atomic::AtomicUsize::new(0),
532 closed: std::sync::atomic::AtomicBool::new(false),
533 permissions: RwLock::new(None),
534 inbound_stdin: Mutex::new(None),
535 extension_id,
536 }
537 }
538
539 async fn fail_all_pending(&self, reason: &str) {
542 self.closed.store(true, std::sync::atomic::Ordering::Release);
543 let drained: Vec<_> = {
544 let mut pending = self.pending.lock().await;
545 pending.drain().collect()
546 };
547 for (_, tx) in drained {
548 let _ = tx.send(Err(reason.to_string()));
549 }
550 }
551}
552
/// Handles belonging to one running extension child process.
struct ProcessState {
    /// The spawned child (configured with `kill_on_drop(true)`).
    child: Child,
    /// Child stdin; the same `Arc` is stored in `Inbox::inbound_stdin` for inbound replies.
    stdin: Arc<Mutex<ChildStdin>>,
    /// Task that reads frames from the child's stdout and dispatches them.
    reader_handle: JoinHandle<()>,
}
558
/// An extension running as a child process that communicates over stdin/stdout
/// using `Content-Length`-framed JSON-RPC messages.
pub struct ProcessExtension {
    /// Extension id (the same string is stored as `Inbox::extension_id`).
    id: String,
    /// Program used to launch the extension.
    command: String,
    /// Arguments passed to `command`.
    args: Vec<String>,
    /// Working directory for the child. It is also the fallback `plugin_root` sent
    /// in `initialize`.
    cwd: Option<PathBuf>,
    /// Handles of the running process. NOTE(review): `None` presumably means "not
    /// running" — the code that clears it is outside this view.
    state: Arc<Mutex<Option<ProcessState>>>,
    /// NOTE(review): presumably serializes outgoing calls — its users are outside this view.
    call_lock: Arc<Mutex<()>>,
    /// Initialized to 1. NOTE(review): presumably the next outgoing request id — confirm.
    next_id: AtomicU64,
    /// NOTE(review): restart counter whose update sites are outside this view.
    restart_count: AtomicUsize,
    /// Cumulative restart count; this is what `restart_count()` returns.
    total_restarts: AtomicUsize,
    /// Restart behaviour. Defaults to `RestartPolicy::default()`; override it with
    /// `with_restart_policy`.
    pub(crate) restart_policy: RestartPolicy,
    /// State shared with the stdout reader task.
    inbox: Arc<Inbox>,
}
583
584impl ProcessExtension {
585 pub async fn spawn(id: &str, command: &str, args: &[String]) -> Result<Self, String> {
586 Self::spawn_with_cwd(id, command, args, None).await
587 }
588
589 pub async fn spawn_with_cwd(
594 id: &str,
595 command: &str,
596 args: &[String],
597 cwd: Option<PathBuf>,
598 ) -> Result<Self, String> {
599 let inbox = Arc::new(Inbox::new(id.to_string()));
600 let state = Self::spawn_state(id, command, args, cwd.as_ref(), inbox.clone()).await?;
601 Ok(Self {
602 id: id.to_string(),
603 command: command.to_string(),
604 args: args.to_vec(),
605 cwd,
606 state: Arc::new(Mutex::new(Some(state))),
607 call_lock: Arc::new(Mutex::new(())),
608 next_id: AtomicU64::new(1),
609 restart_count: AtomicUsize::new(0),
610 total_restarts: AtomicUsize::new(0),
611 restart_policy: RestartPolicy::default(),
612 inbox,
613 })
614 }
615
616 pub fn with_restart_policy(mut self, policy: RestartPolicy) -> Self {
618 self.restart_policy = policy;
619 self
620 }
621
622 async fn spawn_state(
623 id: &str,
624 command: &str,
625 args: &[String],
626 cwd: Option<&PathBuf>,
627 inbox: Arc<Inbox>,
628 ) -> Result<ProcessState, String> {
629 let mut cmd = Command::new(command);
630 cmd.args(args)
631 .stdin(Stdio::piped())
632 .stdout(Stdio::piped())
633 .stderr(Stdio::piped());
634 if let Some(cwd) = cwd {
635 cmd.current_dir(cwd);
636 }
637
638 cmd.env_clear();
644 for var in &["PATH", "HOME", "LANG", "TERM", "XDG_RUNTIME_DIR"] {
645 if let Ok(val) = std::env::var(var) {
646 cmd.env(var, val);
647 }
648 }
649
650 cmd.kill_on_drop(true);
651
652 let mut child = cmd
653 .spawn()
654 .map_err(|e| format!("Failed to spawn extension '{}': {}", id, e))?;
655
656 let stdin = child
657 .stdin
658 .take()
659 .ok_or_else(|| format!("No stdin for extension '{}'", id))?;
660 let stdout = child
661 .stdout
662 .take()
663 .ok_or_else(|| format!("No stdout for extension '{}'", id))?;
664 if let Some(stderr) = child.stderr.take() {
665 let extension_id = id.to_string();
666 tokio::spawn(async move {
667 let mut lines = BufReader::new(stderr).lines();
668 loop {
669 match lines.next_line().await {
670 Ok(Some(line)) => {
671 tracing::debug!(extension = %extension_id, stderr = %line);
672 }
673 Ok(None) => break,
674 Err(error) => {
675 tracing::debug!(
676 extension = %extension_id,
677 error = %error,
678 "Failed to read extension stderr",
679 );
680 break;
681 }
682 }
683 }
684 });
685 }
686
687 let reader_handle = Self::spawn_reader(stdout, inbox.clone(), id.to_string());
688
689 let stdin_arc = Arc::new(Mutex::new(stdin));
690 *inbox.inbound_stdin.lock().await = Some(stdin_arc.clone());
693
694 Ok(ProcessState {
695 child,
696 stdin: stdin_arc,
697 reader_handle,
698 })
699 }
700
701 fn spawn_reader(
706 stdout: ChildStdout,
707 inbox: Arc<Inbox>,
708 extension_id: String,
709 ) -> JoinHandle<()> {
710 tokio::spawn(async move {
711 let mut reader = BufReader::new(stdout);
712 loop {
713 match Self::read_one_frame(&mut reader, &extension_id).await {
714 Ok(Some(value)) => {
715 Self::dispatch_frame(value, &inbox, &extension_id).await;
716 }
717 Ok(None) => {
718 tracing::debug!(
719 extension = %extension_id,
720 "Extension stdout closed (EOF); failing pending requests",
721 );
722 inbox.fail_all_pending("transport closed: EOF").await;
723 inbox.notification_sinks.lock().await.clear();
725 return;
726 }
727 Err(error) => {
728 tracing::debug!(
729 extension = %extension_id,
730 error = %error,
731 "Extension transport read error",
732 );
733 inbox
734 .fail_all_pending(&format!("transport error: {}", error))
735 .await;
736 inbox.notification_sinks.lock().await.clear();
737 return;
738 }
739 }
740 }
741 })
742 }
743
744 async fn read_one_frame(
748 reader: &mut BufReader<ChildStdout>,
749 extension_id: &str,
750 ) -> Result<Option<Value>, String> {
751 let mut content_length: Option<usize> = None;
752 let mut saw_any_header = false;
753 loop {
754 let mut header_line = String::new();
755 let n = reader
756 .read_line(&mut header_line)
757 .await
758 .map_err(|e| format!("Read header error: {}", e))?;
759 if n == 0 {
760 if saw_any_header {
761 return Err("Unexpected EOF while reading response headers".into());
762 }
763 return Ok(None);
764 }
765 saw_any_header = true;
766 if header_line.len() > 1024 {
767 return Err(format!(
768 "Extension '{}' header line too long ({} bytes)",
769 extension_id,
770 header_line.len()
771 ));
772 }
773 let trimmed = header_line.trim();
774 if trimmed.is_empty() {
775 break;
776 }
777 if let Some((name, value)) = trimmed.split_once(':') {
778 if name.trim().eq_ignore_ascii_case("Content-Length") {
779 content_length = Some(value.trim().parse().map_err(|_| {
780 format!("Invalid Content-Length value: {:?}", value.trim())
781 })?);
782 }
783 }
784 }
785 let content_length = content_length.ok_or_else(|| {
786 format!(
787 "Extension '{}' frame missing Content-Length header",
788 extension_id
789 )
790 })?;
791 const MAX_RESPONSE_SIZE: usize = 4 * 1024 * 1024;
792 if content_length > MAX_RESPONSE_SIZE {
793 return Err(format!(
794 "Extension '{}' frame too large: {} bytes (max {})",
795 extension_id, content_length, MAX_RESPONSE_SIZE
796 ));
797 }
798 let mut buf = vec![0u8; content_length];
799 tokio::io::AsyncReadExt::read_exact(reader, &mut buf)
800 .await
801 .map_err(|e| format!("Read body error: {}", e))?;
802 let value: Value = serde_json::from_slice(&buf)
803 .map_err(|e| format!("Parse frame error: {}", e))?;
804 Ok(Some(value))
805 }
806
    /// Route one decoded frame read from the extension's stdout.
    ///
    /// Frame shapes:
    /// - Non-null `id` **and** `method`: a request initiated by the extension. It is
    ///   handled on a spawned task via `Self::handle_inbound_request`. The JSON-RPC
    ///   reply (`result`, or `error` with `code`/`message`) is written back through
    ///   `inbox.inbound_stdin`.
    /// - Non-null `id` without `method`: a response to one of our requests. The
    ///   matching sender is removed from `inbox.pending` and completed with the
    ///   `error.message` (prefixed `"Extension error: "`) or with `result`
    ///   (`null` if absent).
    /// - `method` without an id: a notification, delivered to every notification sink.
    ///
    /// Everything else is logged at trace level and dropped, including any frame whose
    /// id is not a `u64`. Write failures on inbound replies are logged as warnings.
    async fn dispatch_frame(value: Value, inbox: &Arc<Inbox>, extension_id: &str) {
        let id_field = value.get("id");
        // A JSON `null` id is treated the same as a missing id.
        let id_is_present = !matches!(id_field, None | Some(Value::Null));
        let method_field = value.get("method").and_then(Value::as_str).map(str::to_string);

        if id_is_present && method_field.is_some() {
            // Inbound request: only numeric ids can be answered.
            let id = match id_field.and_then(Value::as_u64) {
                Some(id) => id,
                None => {
                    tracing::trace!(
                        extension = %extension_id,
                        frame = %value,
                        "Discarding inbound request with non-numeric id",
                    );
                    return;
                }
            };
            let Some(method) = method_field else { return };
            let params = value.get("params").cloned().unwrap_or(Value::Null);
            let inbox = inbox.clone();
            let extension_id = extension_id.to_string();
            // Handle on a separate task so the reader loop keeps consuming frames
            // while the request is processed.
            tokio::spawn(async move {
                let outcome = Self::handle_inbound_request(&inbox, &method, params).await;
                let payload = match outcome {
                    Ok(result) => serde_json::json!({
                        "jsonrpc": "2.0",
                        "id": id,
                        "result": result,
                    }),
                    Err((code, message)) => serde_json::json!({
                        "jsonrpc": "2.0",
                        "id": id,
                        "error": {"code": code, "message": message},
                    }),
                };
                // Clone the handle out so the inbox lock is not held while writing.
                let stdin_handle = inbox.inbound_stdin.lock().await.clone();
                if let Some(stdin) = stdin_handle {
                    let body = match serde_json::to_string(&payload) {
                        Ok(s) => s,
                        Err(error) => {
                            tracing::warn!(
                                extension = %extension_id,
                                error = %error,
                                "Failed to serialize inbound response",
                            );
                            return;
                        }
                    };
                    // Same framing as incoming frames; `body.len()` is the byte length.
                    let frame = format!("Content-Length: {}\r\n\r\n{}", body.len(), body);
                    let mut stdin = stdin.lock().await;
                    if let Err(error) = stdin.write_all(frame.as_bytes()).await {
                        tracing::warn!(
                            extension = %extension_id,
                            error = %error,
                            "Failed to write inbound response",
                        );
                        return;
                    }
                    if let Err(error) = stdin.flush().await {
                        tracing::warn!(
                            extension = %extension_id,
                            error = %error,
                            "Failed to flush inbound response",
                        );
                    }
                } else {
                    tracing::warn!(
                        extension = %extension_id,
                        "No stdin available to reply to inbound request",
                    );
                }
            });
            return;
        }

        if id_is_present {
            // Response to one of our requests.
            let id = match id_field.and_then(Value::as_u64) {
                Some(id) => id,
                None => {
                    tracing::trace!(
                        extension = %extension_id,
                        frame = %value,
                        "Discarding frame with non-numeric id",
                    );
                    return;
                }
            };
            // Take the sender out first; the pending lock is released before sending.
            let sender = inbox.pending.lock().await.remove(&id);
            match sender {
                Some(tx) => {
                    // An `error` member wins over `result`.
                    let payload = if let Some(err) = value.get("error") {
                        let message = err
                            .get("message")
                            .and_then(Value::as_str)
                            .unwrap_or("unknown extension error")
                            .to_string();
                        Err(format!("Extension error: {}", message))
                    } else {
                        Ok(value
                            .get("result")
                            .cloned()
                            .unwrap_or(Value::Null))
                    };
                    // The requester may have given up; a closed receiver is ignored.
                    let _ = tx.send(payload);
                }
                None => {
                    tracing::trace!(
                        extension = %extension_id,
                        id = id,
                        "Response with unknown id (no pending request); dropping",
                    );
                }
            }
        } else if let Some(method) = value.get("method").and_then(Value::as_str) {
            // Notification: fan out to all subscribers.
            let params = value.get("params").cloned().unwrap_or(Value::Null);
            let frame = NotificationFrame {
                method: method.to_string(),
                params,
            };
            let mut sinks = inbox.notification_sinks.lock().await;
            if sinks.is_empty() {
                tracing::trace!(
                    extension = %extension_id,
                    method = %method,
                    "Notification with no active subscribers; dropping",
                );
            } else {
                // Sinks whose receiver has been dropped are removed here.
                sinks.retain(|(_, tx)| tx.send(frame.clone()).is_ok());
            }
        } else {
            tracing::trace!(
                extension = %extension_id,
                frame = %value,
                "Unrecognized frame; dropping",
            );
        }
    }
954
955 pub fn restart_count(&self) -> usize {
956 self.total_restarts.load(Ordering::Relaxed)
959 }
960
961 pub async fn set_permissions(&self, perms: crate::extensions::permissions::PermissionSet) {
964 *self.inbox.permissions.write().await = Some(perms);
965 }
966
967 #[allow(clippy::doc_lazy_continuation)]
975 async fn handle_inbound_request(
976 inbox: &Arc<Inbox>,
977 method: &str,
978 params: Value,
979 ) -> Result<Value, (i32, String)> {
980 use crate::extensions::permissions::Permission;
981 use crate::memory::store::{self, MemoryQuery};
982
983 match method {
984 "memory.append" => {
985 Self::require_permission(inbox, Permission::MemoryWrite, "memory.write").await?;
986 let namespace = Self::param_str(¶ms, "namespace")?;
987 Self::require_namespace_matches(inbox, &namespace).await?;
988 let content = Self::param_str(¶ms, "content")?;
989 let tags = match params.get("tags") {
990 None | Some(Value::Null) => Vec::new(),
991 Some(Value::Array(arr)) => {
992 let mut out = Vec::with_capacity(arr.len());
993 for v in arr {
994 match v.as_str() {
995 Some(s) => out.push(s.to_string()),
996 None => {
997 return Err((
998 -32602,
999 "tags must be an array of strings".to_string(),
1000 ))
1001 }
1002 }
1003 }
1004 out
1005 }
1006 _ => {
1007 return Err((
1008 -32602,
1009 "tags must be an array of strings".to_string(),
1010 ))
1011 }
1012 };
1013 let meta = match params.get("meta") {
1014 None | Some(Value::Null) => None,
1015 Some(v) => Some(v.clone()),
1016 };
1017 let record = store::new_record(namespace, content, tags, meta);
1018 let timestamp_ms = record.timestamp_ms;
1019 store::append(&record).map_err(|e| (-32000, e.to_string()))?;
1020 Ok(serde_json::json!({"ok": true, "timestamp_ms": timestamp_ms}))
1021 }
1022 "memory.query" => {
1023 Self::require_permission(inbox, Permission::MemoryRead, "memory.read").await?;
1024 let namespace = Self::param_str(¶ms, "namespace")?;
1025 Self::require_namespace_matches(inbox, &namespace).await?;
1026 let q = MemoryQuery {
1027 content_contains: params
1028 .get("content_contains")
1029 .and_then(Value::as_str)
1030 .map(str::to_string),
1031 tag_prefix: params
1032 .get("tag_prefix")
1033 .and_then(Value::as_str)
1034 .map(str::to_string),
1035 since_ms: params.get("since_ms").and_then(Value::as_u64),
1036 until_ms: params.get("until_ms").and_then(Value::as_u64),
1037 limit: params
1038 .get("limit")
1039 .and_then(Value::as_u64)
1040 .map(|n| n as usize),
1041 };
1042 let records = store::query(&namespace, &q).map_err(|e| (-32000, e.to_string()))?;
1043 Ok(serde_json::json!({"records": records}))
1044 }
1045 "config.get" => {
1046 let key = Self::param_str(¶ms, "key")?;
1047 Self::validate_config_key(&key)?;
1048 let value = crate::extensions::config_store::read_plugin_config(&inbox.extension_id, &key);
1049 Ok(serde_json::json!({"value": value}))
1050 }
1051 "config.set" => {
1052 Self::require_permission(inbox, Permission::ConfigWrite, "config.write").await?;
1053 let key = Self::param_str(¶ms, "key")?;
1054 Self::validate_config_key(&key)?;
1055 let value = Self::param_str(¶ms, "value")?;
1056 crate::extensions::config_store::write_plugin_config(&inbox.extension_id, &key, &value)
1057 .map_err(|e| (-32000, e.to_string()))?;
1058 Ok(serde_json::json!({"ok": true}))
1059 }
1060 "config.subscribe" => {
1061 Self::require_permission(inbox, Permission::ConfigSubscribe, "config.subscribe").await?;
1062 Ok(serde_json::json!({"ok": true}))
1066 }
1067 other => Err((-32601, format!("method not found: {other}"))),
1068 }
1069 }
1070
1071 async fn require_permission(
1072 inbox: &Arc<Inbox>,
1073 perm: crate::extensions::permissions::Permission,
1074 wire: &str,
1075 ) -> Result<(), (i32, String)> {
1076 let guard = inbox.permissions.read().await;
1077 match guard.as_ref() {
1078 Some(set) if set.has(perm) => Ok(()),
1079 _ => Err((
1080 -32602,
1081 format!("permission denied: {wire} required"),
1082 )),
1083 }
1084 }
1085
1086 async fn require_namespace_matches(
1087 inbox: &Arc<Inbox>,
1088 namespace: &str,
1089 ) -> Result<(), (i32, String)> {
1090 if namespace == inbox.extension_id {
1091 Ok(())
1092 } else {
1093 Err((
1094 -32602,
1095 format!(
1096 "namespace must equal extension id '{}' (got '{}')",
1097 inbox.extension_id, namespace
1098 ),
1099 ))
1100 }
1101 }
1102
1103 fn param_str(params: &Value, name: &str) -> Result<String, (i32, String)> {
1104 params
1105 .get(name)
1106 .and_then(Value::as_str)
1107 .map(str::to_string)
1108 .ok_or_else(|| (-32602, format!("missing or invalid '{name}' parameter")))
1109 }
1110
1111 fn validate_config_key(key: &str) -> Result<(), (i32, String)> {
1112 let trimmed = key.trim();
1113 if trimmed.is_empty() {
1114 return Err((-32602, "config key must be non-empty".to_string()));
1115 }
1116 if trimmed.contains('.') || trimmed.contains('/') || trimmed.contains(' ') {
1117 return Err((
1118 -32602,
1119 "config key must not contain dots, slashes, or spaces".to_string(),
1120 ));
1121 }
1122 Ok(())
1123 }
1124
1125 pub async fn initialize(&self, plugin_root: Option<PathBuf>, config: Value) -> Result<InitializeCapabilitiesResult, String> {
1126 let params = InitializeParams {
1127 synaps_version: env!("CARGO_PKG_VERSION"),
1128 extension_protocol_version: CURRENT_EXTENSION_PROTOCOL_VERSION,
1129 plugin_id: self.id.clone(),
1130 plugin_root: plugin_root
1131 .or_else(|| self.cwd.clone())
1132 .map(|path| path.to_string_lossy().to_string()),
1133 config,
1134 };
1135 let value = self.call_no_restart("initialize", serde_json::to_value(params).map_err(|e| e.to_string())?).await?;
1136 Self::parse_initialize_result(&self.id, value)
1137 }
1138
1139 fn parse_initialize_result(id: &str, value: Value) -> Result<InitializeCapabilitiesResult, String> {
1140 let result: InitializeResult = serde_json::from_value(value)
1141 .map_err(|e| format!("Invalid initialize response from extension '{}': {}", id, e))?;
1142 if result.protocol_version != CURRENT_EXTENSION_PROTOCOL_VERSION {
1143 return Err(format!(
1144 "Extension '{}' initialize returned unsupported protocol_version {} (supported: {})",
1145 id, result.protocol_version, CURRENT_EXTENSION_PROTOCOL_VERSION,
1146 ));
1147 }
1148 Self::validate_registered_tool_specs(id, &result.capabilities.tools)?;
1149 Self::validate_registered_provider_specs(id, &result.capabilities.providers)?;
1150 Ok(InitializeCapabilitiesResult {
1151 tools: result.capabilities.tools,
1152 providers: result.capabilities.providers,
1153 capabilities: result.capabilities.capabilities,
1154 })
1155 }
1156
1157 fn validate_registered_tool_specs(id: &str, tools: &[RegisteredExtensionToolSpec]) -> Result<(), String> {
1158 use crate::extensions::validation::{validate_id_segment, IdValidationError};
1159 let mut names = HashSet::new();
1160 for tool in tools {
1161 let name = tool.name.trim();
1162 if let Err(err) = validate_id_segment(name) {
1163 return Err(match err {
1164 IdValidationError::Empty => format!(
1165 "Extension '{}' registered a tool with an empty tool name",
1166 id
1167 ),
1168 IdValidationError::ContainsReserved { ch } => format!(
1169 "Extension '{}' registered tool '{}' with invalid tool name: '{}' is reserved",
1170 id, name, ch
1171 ),
1172 IdValidationError::TooLong { len, max } => format!(
1173 "Extension '{}' registered tool '{}' with invalid tool name: must be at most {} chars (got {})",
1174 id, name, max, len
1175 ),
1176 IdValidationError::ContainsWhitespace => format!(
1177 "Extension '{}' registered tool '{}' with invalid tool name: must not contain whitespace",
1178 id, name
1179 ),
1180 IdValidationError::ContainsControl { ch } => format!(
1181 "Extension '{}' registered tool '{}' with invalid tool name: contains control character U+{:04X}",
1182 id, name, ch as u32
1183 ),
1184 });
1185 }
1186 if !names.insert(name.to_string()) {
1187 return Err(format!("Extension '{}' registered duplicate tool name '{}'", id, name));
1188 }
1189 if tool.description.trim().is_empty() {
1190 return Err(format!(
1191 "Extension '{}' registered tool '{}' with an empty description",
1192 id, name,
1193 ));
1194 }
1195 if !tool.input_schema.is_object() {
1196 return Err(format!(
1197 "Extension '{}' registered tool '{}' with invalid input_schema: input_schema must be a JSON object",
1198 id, name,
1199 ));
1200 }
1201 }
1202 Ok(())
1203 }
1204
1205 fn validate_registered_provider_specs(id: &str, providers: &[RegisteredProviderSpec]) -> Result<(), String> {
1206 use crate::extensions::validation::{validate_id_segment, IdValidationError};
1207 for provider in providers {
1208 let provider_id = provider.id.trim();
1209 match validate_id_segment(provider_id) {
1210 Ok(()) => {
1211 if !Self::is_safe_provider_id(provider_id) {
1212 return Err(format!(
1213 "Extension '{}' registered provider '{}' with invalid provider id",
1214 id, provider_id
1215 ));
1216 }
1217 }
1218 Err(IdValidationError::Empty) => {
1219 return Err(format!(
1220 "Extension '{}' registered provider with empty provider id",
1221 id
1222 ));
1223 }
1224 Err(err) => {
1225 return Err(format!(
1226 "Extension '{}' registered provider '{}' with invalid provider id: {}",
1227 id, provider_id, err
1228 ));
1229 }
1230 }
1231 if provider.display_name.trim().is_empty() {
1232 return Err(format!(
1233 "Extension '{}' registered provider '{}' with empty display_name",
1234 id, provider_id,
1235 ));
1236 }
1237 if provider.description.trim().is_empty() {
1238 return Err(format!(
1239 "Extension '{}' registered provider '{}' with empty description",
1240 id, provider_id,
1241 ));
1242 }
1243 if provider.models.is_empty() {
1244 return Err(format!(
1245 "Extension '{}' registered provider '{}' must declare at least one model",
1246 id, provider_id,
1247 ));
1248 }
1249 let mut model_ids = HashSet::new();
1250 for model in &provider.models {
1251 let model_id = model.id.trim();
1252 if let Err(err) = validate_id_segment(model_id) {
1253 return Err(match err {
1254 IdValidationError::Empty => format!(
1255 "Extension '{}' registered provider '{}' with empty model id",
1256 id, provider_id
1257 ),
1258 IdValidationError::ContainsReserved { ch } => format!(
1259 "Extension '{}' registered provider '{}' with invalid model id '{}': '{}' is reserved",
1260 id, provider_id, model_id, ch
1261 ),
1262 IdValidationError::TooLong { len, max } => format!(
1263 "Extension '{}' registered provider '{}' with invalid model id '{}': must be at most {} chars (got {})",
1264 id, provider_id, model_id, max, len
1265 ),
1266 IdValidationError::ContainsWhitespace => format!(
1267 "Extension '{}' registered provider '{}' with invalid model id '{}': must not contain whitespace",
1268 id, provider_id, model_id
1269 ),
1270 IdValidationError::ContainsControl { ch } => format!(
1271 "Extension '{}' registered provider '{}' with invalid model id '{}': contains control character U+{:04X}",
1272 id, provider_id, model_id, ch as u32
1273 ),
1274 });
1275 }
1276 if !model_ids.insert(model_id.to_string()) {
1277 return Err(format!(
1278 "Extension '{}' registered provider '{}' with duplicate model id '{}'",
1279 id, provider_id, model_id,
1280 ));
1281 }
1282 }
1283 if let Some(config_schema) = &provider.config_schema {
1284 if !config_schema.is_object() {
1285 return Err(format!(
1286 "Extension '{}' registered provider '{}' with invalid config_schema: config_schema must be a JSON object",
1287 id, provider_id,
1288 ));
1289 }
1290 }
1291 }
1292 Ok(())
1293 }
1294
1295 fn is_safe_provider_id(id: &str) -> bool {
1296 !id.is_empty()
1297 && !id.contains(':')
1298 && id.chars().all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-' || c == '_')
1299 }
1300
1301 #[doc(hidden)]
1302 pub async fn initialize_for_test(&self, plugin_root: Option<PathBuf>) -> Result<(), String> {
1303 self.initialize(plugin_root, Value::Object(Default::default())).await.map(|_| ())
1304 }
1305
1306 async fn restart_locked(&self, state: &mut Option<ProcessState>) -> Result<(), String> {
1307 let attempted = self.restart_count.fetch_add(1, Ordering::Relaxed) + 1;
1308 self.total_restarts.fetch_add(1, Ordering::Relaxed);
1309 let max_attempts = self.restart_policy.max_attempts;
1310 if attempted > max_attempts as usize {
1311 *state = None;
1312 return Err(format!(
1313 "Extension '{}' exceeded restart limit ({})",
1314 self.id, max_attempts,
1315 ));
1316 }
1317
1318 if let Some(old) = state.take() {
1319 old.reader_handle.abort();
1320 let mut child = old.child;
1321 let _ = child.kill().await;
1322 }
1323 self.inbox
1325 .fail_all_pending("transport closed: process restarting")
1326 .await;
1327
1328 let delay = self
1329 .restart_policy
1330 .delay_for_attempt(attempted as u32)
1331 .unwrap_or_default();
1332
1333 tracing::warn!(
1334 extension = %self.id,
1335 attempt = attempted,
1336 max_attempts = max_attempts,
1337 delay_ms = delay.as_millis() as u64,
1338 "Restarting extension process after transport failure",
1339 );
1340
1341 if !delay.is_zero() {
1342 tokio::time::sleep(delay).await;
1343 }
1344
1345 *state = Some(Self::spawn_state(
1346 &self.id,
1347 &self.command,
1348 &self.args,
1349 self.cwd.as_ref(),
1350 self.inbox.clone(),
1351 ).await?);
1352 self.inbox.closed.store(false, std::sync::atomic::Ordering::Release);
1354 self.initialize_locked(state).await?;
1355 Ok(())
1362 }
1363
1364
1365 async fn initialize_locked(&self, state: &mut Option<ProcessState>) -> Result<(), String> {
1366 let params = InitializeParams {
1367 synaps_version: env!("CARGO_PKG_VERSION"),
1368 extension_protocol_version: CURRENT_EXTENSION_PROTOCOL_VERSION,
1369 plugin_id: self.id.clone(),
1370 plugin_root: self.cwd
1371 .clone()
1372 .map(|path| path.to_string_lossy().to_string()),
1373 config: Value::Object(Default::default()),
1374 };
1375 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
1376 let value = tokio::time::timeout(
1377 std::time::Duration::from_secs(10),
1378 self.call_once_locked(
1379 state.as_mut().expect("state should exist for initialize"),
1380 "initialize",
1381 serde_json::to_value(params).map_err(|e| e.to_string())?,
1382 id,
1383 ),
1384 )
1385 .await
1386 .map_err(|_| format!("Extension '{}' initialize timed out after 10s", self.id))?
1387 ?;
1388 Self::parse_initialize_result(&self.id, value).map(|_| ())
1389 }
1390
    /// Sends one JSON-RPC request over the process's stdin and waits for its
    /// reply.
    ///
    /// The reply is received on a oneshot channel whose sender is registered in
    /// `self.inbox.pending` under `id`. The request is framed as
    /// `Content-Length: <bytes>\r\n\r\n<json>`. `id` should be unique among
    /// in-flight requests; callers take it from `next_id`.
    ///
    /// # Errors
    /// - the inbox is marked closed before or right after the reply slot is
    ///   registered;
    /// - writing or flushing the frame fails;
    /// - the reply sender is dropped without sending;
    /// - the payload delivered on the channel is itself an `Err` (returned
    ///   unchanged).
    ///
    /// NOTE(review): if this future is cancelled while awaiting `rx` (e.g. by
    /// an outer `tokio::time::timeout`), the `pending` entry for `id` is not
    /// removed here. Confirm whether something else cleans it up.
    async fn call_once_locked(
        &self,
        state: &mut ProcessState,
        method: &str,
        params: Value,
        id: u64,
    ) -> Result<Value, String> {
        let body = serde_json::to_string(&JsonRpcRequest {
            jsonrpc: "2.0",
            method: method.to_string(),
            params,
            id,
        })
        .map_err(|e| format!("Serialize error: {}", e))?;

        let (tx, rx) = oneshot::channel::<Result<Value, String>>();
        // Fast path: never register a reply slot on an inbox already shut down.
        if self.inbox.closed.load(std::sync::atomic::Ordering::Acquire) {
            return Err("transport closed: inbox is shut down".to_string());
        }

        self.inbox.pending.lock().await.insert(id, tx);

        // Re-check after registering. If the inbox closed between the first
        // check and the insert, pending slots may already have been drained
        // (presumably by `fail_all_pending`), so this one would never be
        // answered. Withdraw it instead.
        if self.inbox.closed.load(std::sync::atomic::Ordering::Acquire) {
            self.inbox.pending.lock().await.remove(&id);
            return Err("transport closed: inbox shut down during registration".to_string());
        }

        // `body.len()` is a byte count, which is what Content-Length requires.
        let frame = format!("Content-Length: {}\r\n\r\n{}", body.len(), body);
        let write_result = {
            // Scoped so the stdin lock is released before awaiting the reply.
            let mut stdin = state.stdin.lock().await;
            match stdin.write_all(frame.as_bytes()).await {
                Ok(()) => stdin.flush().await,
                Err(e) => Err(e),
            }
        };
        if let Err(e) = write_result {
            // The request never reached the extension; drop its reply slot.
            self.inbox.pending.lock().await.remove(&id);
            return Err(format!("Write error: {}", e));
        }

        match rx.await {
            Ok(payload) => payload,
            Err(_) => {
                // The sender was dropped without a reply; make sure the slot is gone.
                self.inbox.pending.lock().await.remove(&id);
                Err("transport closed: response channel dropped".to_string())
            }
        }
    }
1453
1454 async fn call_no_restart(&self, method: &str, params: Value) -> Result<Value, String> {
1455 let _call_guard = self.call_lock.lock().await;
1456 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
1457 let mut state_guard = self.state.lock().await;
1458 if state_guard.is_none() {
1459 *state_guard = Some(Self::spawn_state(
1460 &self.id,
1461 &self.command,
1462 &self.args,
1463 self.cwd.as_ref(),
1464 self.inbox.clone(),
1465 ).await?);
1466 }
1467 self.call_once_locked(
1468 state_guard.as_mut().expect("state should exist"),
1469 method,
1470 params,
1471 id,
1472 ).await
1473 }
1474
1475 async fn call(&self, method: &str, params: Value) -> Result<Value, String> {
1476 let timeout_secs = if method == "tool.call" { 120 } else { 30 };
1477 let id_str = self.id.clone();
1478 let method_str = method.to_string();
1479
1480 let result = tokio::time::timeout(
1481 std::time::Duration::from_secs(timeout_secs),
1482 self.call_inner(method, params),
1483 )
1484 .await;
1485
1486 match result {
1487 Ok(inner) => inner,
1488 Err(_) => Err(format!(
1489 "Extension '{}' method '{}' timed out after {}s",
1490 id_str, method_str, timeout_secs
1491 )),
1492 }
1493 }
1494
1495 async fn call_inner(&self, method: &str, params: Value) -> Result<Value, String> {
1496 let _call_guard = self.call_lock.lock().await;
1497 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
1498 let mut state_guard = self.state.lock().await;
1499 if state_guard.is_none() {
1500 self.restart_locked(&mut state_guard).await?;
1501 }
1502
1503 let result = self
1504 .call_once_locked(
1505 state_guard.as_mut().expect("state should exist after restart"),
1506 method,
1507 params.clone(),
1508 id,
1509 )
1510 .await;
1511
1512 match result {
1513 Ok(value) => {
1514 self.restart_count.store(0, Ordering::Relaxed);
1518 Ok(value)
1519 }
1520 Err(first_error) => {
1521 self.restart_locked(&mut state_guard).await?;
1522 let retry_id = self.next_id.fetch_add(1, Ordering::Relaxed);
1523 self.call_once_locked(
1524 state_guard.as_mut().expect("state should exist after restart"),
1525 method,
1526 params,
1527 retry_id,
1528 )
1529 .await
1530 .map(|value| {
1531 self.restart_count.store(0, Ordering::Relaxed);
1533 value
1534 })
1535 .map_err(|retry_error| {
1536 format!("{}; retry after restart failed: {}", first_error, retry_error)
1537 })
1538 }
1539 }
1540 }
1541
1542 #[doc(hidden)]
1560 pub async fn subscribe_notifications(
1561 &self,
1562 ) -> (usize, mpsc::UnboundedReceiver<NotificationFrame>) {
1563 let (tx, rx) = mpsc::unbounded_channel();
1564 let id = self
1565 .inbox
1566 .next_sink_id
1567 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1568 let mut sinks = self.inbox.notification_sinks.lock().await;
1569 sinks.retain(|(_, tx)| !tx.is_closed());
1573 sinks.push((id, tx));
1574 (id, rx)
1575 }
1576
1577 #[doc(hidden)]
1582 pub async fn unsubscribe_notifications(&self, id: usize) {
1583 let mut sinks = self.inbox.notification_sinks.lock().await;
1584 sinks.retain(|(sub_id, tx)| *sub_id != id && !tx.is_closed());
1585 }
1586
1587 pub(crate) fn forward_invoke_command_frame(
1597 extension_id: &str,
1598 request_id: &str,
1599 sink: &mpsc::UnboundedSender<crate::extensions::runtime::InvokeCommandEvent>,
1600 sink_open: &mut bool,
1601 frame: NotificationFrame,
1602 ) -> bool {
1603 use crate::extensions::commands::parse_command_output;
1604 use crate::extensions::tasks::{is_task_method, parse_task_event};
1605 use crate::extensions::runtime::InvokeCommandEvent;
1606
1607 let mut saw_done = false;
1608 if frame.method == "command.output" {
1609 match parse_command_output(&frame.params) {
1610 Ok(parsed) if parsed.request_id == request_id => {
1611 if matches!(parsed.event, crate::extensions::commands::CommandOutputEvent::Done) {
1612 saw_done = true;
1613 }
1614 if *sink_open && sink.send(InvokeCommandEvent::Output(parsed.event)).is_err() {
1615 *sink_open = false;
1616 }
1617 }
1618 Ok(_) => {
1619 tracing::trace!(
1621 extension = %extension_id,
1622 "Ignoring command.output for unrelated request_id",
1623 );
1624 }
1625 Err(error) => {
1626 tracing::warn!(
1627 extension = %extension_id,
1628 error = %error,
1629 params = %frame.params,
1630 "Skipping malformed command.output notification",
1631 );
1632 }
1633 }
1634 } else if is_task_method(&frame.method) {
1635 match parse_task_event(&frame.method, &frame.params) {
1636 Ok(event) => {
1637 if *sink_open && sink.send(InvokeCommandEvent::Task(event)).is_err() {
1638 *sink_open = false;
1639 }
1640 }
1641 Err(error) => {
1642 tracing::warn!(
1643 extension = %extension_id,
1644 method = %frame.method,
1645 error = %error,
1646 params = %frame.params,
1647 "Skipping malformed task notification",
1648 );
1649 }
1650 }
1651 } else {
1652 tracing::trace!(
1653 extension = %extension_id,
1654 method = %frame.method,
1655 "Ignoring non-command/task notification during command.invoke",
1656 );
1657 }
1658 saw_done
1659 }
1660
1661 fn forward_provider_stream_frame(
1668 extension_id: &str,
1669 sink: &mpsc::UnboundedSender<ProviderStreamEvent>,
1670 sink_open: &mut bool,
1671 frame: NotificationFrame,
1672 ) {
1673 if frame.method != "provider.stream.event" {
1674 tracing::trace!(
1675 extension = %extension_id,
1676 method = %frame.method,
1677 "Ignoring non-stream notification during provider.stream",
1678 );
1679 return;
1680 }
1681 match parse_provider_stream_event(&frame.params) {
1682 Ok(event) => {
1683 if *sink_open && sink.send(event).is_err() {
1684 *sink_open = false;
1685 }
1686 }
1687 Err(error) => {
1688 tracing::warn!(
1689 extension = %extension_id,
1690 error = %error,
1691 params = %frame.params,
1692 "Skipping malformed provider.stream.event notification",
1693 );
1694 }
1695 }
1696 }
1697}
1698
1699#[async_trait]
1700impl ExtensionHandler for ProcessExtension {
1701 fn id(&self) -> &str {
1702 &self.id
1703 }
1704
1705 async fn call_tool(&self, name: &str, input: Value) -> Result<Value, String> {
1706 self.call("tool.call", serde_json::json!({
1707 "name": name,
1708 "input": input,
1709 })).await
1710 }
1711
1712 async fn provider_complete(&self, params: ProviderCompleteParams) -> Result<ProviderCompleteResult, String> {
1713 let value = tokio::time::timeout(
1714 std::time::Duration::from_secs(60),
1715 self.call("provider.complete", serde_json::to_value(params).map_err(|e| e.to_string())?),
1716 )
1717 .await
1718 .map_err(|_| format!("Extension '{}' provider.complete timed out", self.id))??;
1719 let result: ProviderCompleteResult = serde_json::from_value(value)
1720 .map_err(|e| format!("Invalid provider.complete response from extension '{}': {}", self.id, e))?;
1721 if result.content.is_empty() {
1722 return Err(format!("Extension '{}' provider.complete returned empty content", self.id));
1723 }
1724 Ok(result)
1725 }
1726
1727 async fn provider_stream(
1728 &self,
1729 params: ProviderCompleteParams,
1730 sink: tokio::sync::mpsc::UnboundedSender<ProviderStreamEvent>,
1731 ) -> Result<ProviderCompleteResult, String> {
1732 let (sub_id, mut rx) = self.subscribe_notifications().await;
1735 let params_value =
1736 serde_json::to_value(params).map_err(|e| e.to_string())?;
1737
1738 let extension_id = self.id.clone();
1739 let stream_future = async {
1740 let mut call_fut = Box::pin(self.call("provider.stream", params_value));
1741 let mut sink_open = true;
1742 let response = loop {
1743 tokio::select! {
1744 response = &mut call_fut => break response,
1745 Some(frame) = rx.recv() => {
1746 Self::forward_provider_stream_frame(
1747 &extension_id, &sink, &mut sink_open, frame,
1748 );
1749 }
1750 }
1751 };
1752 self.unsubscribe_notifications(sub_id).await;
1757 while let Some(frame) = rx.recv().await {
1758 Self::forward_provider_stream_frame(
1759 &extension_id, &sink, &mut sink_open, frame,
1760 );
1761 }
1762 response
1763 };
1764
1765 let outcome = tokio::time::timeout(
1766 std::time::Duration::from_secs(60),
1767 stream_future,
1768 )
1769 .await;
1770
1771 self.unsubscribe_notifications(sub_id).await;
1774
1775 let value = outcome
1776 .map_err(|_| format!("Extension '{}' provider.stream timed out", self.id))??;
1777
1778 let result: ProviderCompleteResult = serde_json::from_value(value)
1779 .map_err(|e| {
1780 format!("Invalid provider.stream response from extension '{}': {}", self.id, e)
1781 })?;
1782 Ok(result)
1785 }
1786
1787 async fn invoke_command(
1788 &self,
1789 command: &str,
1790 args: Vec<String>,
1791 request_id: &str,
1792 sink: tokio::sync::mpsc::UnboundedSender<crate::extensions::runtime::InvokeCommandEvent>,
1793 ) -> Result<Value, String> {
1794 let (sub_id, mut rx) = self.subscribe_notifications().await;
1796 let params = serde_json::json!({
1797 "command": command,
1798 "args": args,
1799 "request_id": request_id,
1800 });
1801
1802 let extension_id = self.id.clone();
1803 let request_id_owned = request_id.to_string();
1804 let invoke_future = async {
1805 let mut call_fut = Box::pin(self.call("command.invoke", params));
1806 let mut sink_open = true;
1807 let response = loop {
1808 tokio::select! {
1809 response = &mut call_fut => break response,
1810 Some(frame) = rx.recv() => {
1811 let _ = Self::forward_invoke_command_frame(
1812 &extension_id, &request_id_owned, &sink, &mut sink_open, frame,
1813 );
1814 }
1815 }
1816 };
1817 self.unsubscribe_notifications(sub_id).await;
1821 while let Ok(frame) = rx.try_recv() {
1822 let _ = Self::forward_invoke_command_frame(
1823 &extension_id, &request_id_owned, &sink, &mut sink_open, frame,
1824 );
1825 }
1826 response
1827 };
1828
1829 let outcome = tokio::time::timeout(
1830 std::time::Duration::from_secs(120),
1831 invoke_future,
1832 )
1833 .await;
1834
1835 self.unsubscribe_notifications(sub_id).await;
1838
1839 outcome
1840 .map_err(|_| format!("Extension '{}' command.invoke timed out", self.id))?
1841 }
1842
1843 async fn handle(&self, event: &HookEvent) -> HookResult {
1844 let params = serde_json::to_value(event).unwrap_or(Value::Null);
1845 match tokio::time::timeout(std::time::Duration::from_secs(5), self.call("hook.handle", params)).await {
1846 Ok(Ok(value)) => match serde_json::from_value(value.clone()) {
1847 Ok(result) => result,
1848 Err(error) => {
1849 tracing::warn!(
1850 extension = %self.id,
1851 error = %error,
1852 response = %value,
1853 "Extension hook handler returned invalid result",
1854 );
1855 if value.get("action").and_then(Value::as_str) == Some("modify") {
1856 HookResult::Block {
1857 reason: "Extension returned malformed modify result".to_string(),
1858 }
1859 } else {
1860 HookResult::Continue
1861 }
1862 }
1863 },
1864 Ok(Err(e)) => {
1865 tracing::warn!(
1866 extension = %self.id,
1867 error = %e,
1868 "Extension hook handler failed — continuing",
1869 );
1870 HookResult::Continue
1871 }
1872 Err(_) => {
1873 tracing::warn!(
1874 extension = %self.id,
1875 timeout_secs = 5,
1876 "Extension hook handler timed out — continuing",
1877 );
1878 HookResult::Continue
1879 }
1880 }
1881 }
1882
1883 async fn get_info(&self) -> Result<crate::extensions::info::PluginInfo, String> {
1884 let value = tokio::time::timeout(
1885 std::time::Duration::from_secs(5),
1886 self.call("info.get", Value::Null),
1887 )
1888 .await
1889 .map_err(|_| format!("Extension '{}' info.get timed out", self.id))??;
1890 serde_json::from_value(value)
1891 .map_err(|e| format!("Invalid info.get response from extension '{}': {}", self.id, e))
1892 }
1893
1894 async fn sidecar_spawn_args(
1895 &self,
1896 ) -> Result<crate::sidecar::spawn::SidecarSpawnArgs, String> {
1897 let value = tokio::time::timeout(
1898 std::time::Duration::from_secs(5),
1899 self.call("sidecar.spawn_args", Value::Null),
1900 )
1901 .await
1902 .map_err(|_| format!("Extension '{}' sidecar.spawn_args timed out", self.id))??;
1903 serde_json::from_value(value).map_err(|e| {
1904 format!(
1905 "Invalid sidecar.spawn_args response from extension '{}': {}",
1906 self.id, e
1907 )
1908 })
1909 }
1910
1911 async fn settings_editor_open(&self, category: &str, field: &str) -> Result<Value, String> {
1912 let params = crate::extensions::settings_editor::SettingsEditorOpenParams {
1913 category: category.to_string(),
1914 field: field.to_string(),
1915 };
1916 tokio::time::timeout(
1917 std::time::Duration::from_secs(5),
1918 self.call("settings.editor.open", serde_json::to_value(params).map_err(|e| e.to_string())?),
1919 )
1920 .await
1921 .map_err(|_| format!("Extension '{}' settings.editor.open timed out", self.id))?
1922 }
1923
1924 async fn settings_editor_key(&self, category: &str, field: &str, key: &str) -> Result<Value, String> {
1925 let mut params = serde_json::to_value(crate::extensions::settings_editor::SettingsEditorKeyParams {
1926 key: key.to_string(),
1927 }).map_err(|e| e.to_string())?;
1928 if let Some(obj) = params.as_object_mut() {
1929 obj.insert("category".to_string(), Value::String(category.to_string()));
1930 obj.insert("field".to_string(), Value::String(field.to_string()));
1931 }
1932 tokio::time::timeout(
1933 std::time::Duration::from_secs(5),
1934 self.call("settings.editor.key", params),
1935 )
1936 .await
1937 .map_err(|_| format!("Extension '{}' settings.editor.key timed out", self.id))?
1938 }
1939
1940 async fn settings_editor_commit(&self, category: &str, field: &str, value: Value) -> Result<Value, String> {
1941 let params = serde_json::json!({
1942 "category": category,
1943 "field": field,
1944 "value": value,
1945 });
1946 tokio::time::timeout(
1947 std::time::Duration::from_secs(5),
1948 self.call("settings.editor.commit", params),
1949 )
1950 .await
1951 .map_err(|_| format!("Extension '{}' settings.editor.commit timed out", self.id))?
1952 }
1953
1954 async fn shutdown(&self) {
1955 let _ = tokio::time::timeout(
1956 std::time::Duration::from_millis(500),
1957 self.call("shutdown", Value::Null),
1958 )
1959 .await;
1960
1961 tokio::time::sleep(std::time::Duration::from_millis(500)).await;
1962 let mut state_guard = self.state.lock().await;
1963 if let Some(state) = state_guard.take() {
1964 state.reader_handle.abort();
1965 let mut child = state.child;
1966 let _ = child.kill().await;
1967 }
1968 self.inbox.notification_sinks.lock().await.clear();
1970 self.inbox
1971 .fail_all_pending("transport closed: extension shutdown")
1972 .await;
1973 }
1974
1975 async fn subscribe_notifications(&self) -> (usize, tokio::sync::mpsc::UnboundedReceiver<NotificationFrame>) {
1976 ProcessExtension::subscribe_notifications(self).await
1977 }
1978
1979 async fn restart_count(&self) -> usize {
1980 self.restart_count()
1981 }
1982
1983 async fn health(&self) -> ExtensionHealth {
1984 let consecutive = self.restart_count.load(Ordering::Relaxed);
1988 let lifetime = self.total_restarts.load(Ordering::Relaxed);
1989 let max = self.restart_policy.max_attempts as usize;
1990 if consecutive >= max {
1991 ExtensionHealth::Failed
1992 } else if lifetime > 0 {
1993 let state_alive = self.state.try_lock().map(|g| g.is_some()).unwrap_or(true);
1997 if state_alive {
1998 ExtensionHealth::Degraded
1999 } else {
2000 ExtensionHealth::Restarting
2001 }
2002 } else {
2003 ExtensionHealth::Running
2004 }
2005 }
2006}
2007
#[cfg(test)]
mod stream_event_tests {
    // Unit tests for `parse_provider_stream_event`: accepted event shapes
    // first, then inputs that must be rejected.
    use super::*;
    use serde_json::json;

    /// Parses `raw`, panicking if the parser rejects it.
    fn parse_ok(raw: serde_json::Value) -> ProviderStreamEvent {
        parse_provider_stream_event(&raw).unwrap()
    }

    #[test]
    fn parses_text_delta_with_delta_key() {
        let expected = ProviderStreamEvent::TextDelta { text: "hi".into() };
        assert_eq!(parse_ok(json!({"type": "text", "delta": "hi"})), expected);
    }

    #[test]
    fn parses_text_delta_with_text_key() {
        let expected = ProviderStreamEvent::TextDelta { text: "hi".into() };
        assert_eq!(parse_ok(json!({"type": "text", "text": "hi"})), expected);
    }

    #[test]
    fn parses_thinking_delta() {
        let expected = ProviderStreamEvent::ThinkingDelta { text: "hmm".into() };
        // Both payload keys are accepted for thinking deltas.
        for key in ["delta", "text"] {
            let mut raw = json!({"type": "thinking"});
            raw[key] = json!("hmm");
            assert_eq!(parse_ok(raw), expected);
        }
    }

    #[test]
    fn parses_tool_use() {
        let raw = json!({
            "type": "tool_use",
            "id": "t1",
            "name": "echo",
            "input": {"x": 1}
        });
        let expected = ProviderStreamEvent::ToolUse {
            id: "t1".into(),
            name: "echo".into(),
            input: json!({"x": 1}),
        };
        assert_eq!(parse_ok(raw), expected);
    }

    #[test]
    fn tool_use_input_defaults_to_empty_object() {
        let expected = ProviderStreamEvent::ToolUse {
            id: "t1".into(),
            name: "echo".into(),
            input: json!({}),
        };
        assert_eq!(parse_ok(json!({"type": "tool_use", "id": "t1", "name": "echo"})), expected);
    }

    #[test]
    fn parses_usage_strips_type() {
        let raw = json!({"type": "usage", "input_tokens": 5, "output_tokens": 7});
        let expected = ProviderStreamEvent::Usage {
            usage: json!({"input_tokens": 5, "output_tokens": 7}),
        };
        assert_eq!(parse_ok(raw), expected);
    }

    #[test]
    fn parses_error() {
        let expected = ProviderStreamEvent::Error { message: "boom".into() };
        assert_eq!(parse_ok(json!({"type": "error", "message": "boom"})), expected);
    }

    #[test]
    fn parses_done() {
        assert_eq!(parse_ok(json!({"type": "done"})), ProviderStreamEvent::Done);
    }

    #[test]
    fn nested_event_shape_matches_flat() {
        let flat = json!({"type": "text", "delta": "hi"});
        let nested = json!({ "event": flat.clone() });
        assert_eq!(parse_ok(flat), parse_ok(nested));
    }

    #[test]
    fn missing_type_errors() {
        let err = parse_provider_stream_event(&json!({"delta": "hi"})).unwrap_err();
        assert!(err.contains("missing type"), "got: {err}");
    }

    #[test]
    fn unknown_type_errors_with_type() {
        let err = parse_provider_stream_event(&json!({"type": "wat"})).unwrap_err();
        assert!(err.contains("wat"), "got: {err}");
    }

    #[test]
    fn tool_use_missing_id_errors() {
        let err = parse_provider_stream_event(&json!({"type": "tool_use", "name": "echo"})).unwrap_err();
        assert!(err.contains("id"), "got: {err}");
    }

    #[test]
    fn tool_use_missing_name_errors() {
        let err = parse_provider_stream_event(&json!({"type": "tool_use", "id": "t1"})).unwrap_err();
        assert!(err.contains("name"), "got: {err}");
    }

    #[test]
    fn tool_use_empty_id_errors() {
        let raw = json!({"type": "tool_use", "id": "", "name": "echo"});
        assert!(parse_provider_stream_event(&raw).is_err());
    }

    #[test]
    fn tool_use_empty_name_errors() {
        let raw = json!({"type": "tool_use", "id": "t1", "name": ""});
        assert!(parse_provider_stream_event(&raw).is_err());
    }

    #[test]
    fn tool_use_non_object_input_errors() {
        let raw = json!({"type": "tool_use", "id": "t1", "name": "echo", "input": "nope"});
        let err = parse_provider_stream_event(&raw).unwrap_err();
        assert!(err.contains("input"), "got: {err}");
    }

    #[test]
    fn text_missing_delta_and_text_errors() {
        let err = parse_provider_stream_event(&json!({"type": "text"})).unwrap_err();
        assert!(err.contains("delta") || err.contains("text"), "got: {err}");
    }

    #[test]
    fn error_missing_message_errors() {
        assert!(parse_provider_stream_event(&json!({"type": "error"})).is_err());
    }

    #[test]
    fn error_empty_message_errors() {
        assert!(parse_provider_stream_event(&json!({"type": "error", "message": ""})).is_err());
    }
}
2181
// These tests spawn `/bin/cat` as a stand-in extension process, so they are
// only built and run on Unix-like systems, where that binary exists.
#[cfg(all(test, unix))]
mod restart_policy_tests {
    use super::*;

    /// A freshly spawned extension uses the default policy of 3 attempts.
    #[tokio::test]
    async fn restart_policy_default_max_attempts_is_3() {
        let ext = ProcessExtension::spawn("policy-test", "/bin/cat", &[])
            .await
            .expect("spawn /bin/cat");
        assert_eq!(ext.restart_policy.max_attempts, 3);
        ext.shutdown().await;
    }

    /// `with_restart_policy` replaces the default policy.
    #[tokio::test]
    async fn with_restart_policy_overrides_default() {
        let ext = ProcessExtension::spawn("policy-test-override", "/bin/cat", &[])
            .await
            .expect("spawn /bin/cat");
        let custom = RestartPolicy {
            max_attempts: 7,
            ..RestartPolicy::default()
        };
        let ext = ext.with_restart_policy(custom);
        assert_eq!(ext.restart_policy.max_attempts, 7);
        ext.shutdown().await;
    }
}
2213
#[cfg(test)]
mod capture_validator_tests {
    // `validate_capability` contract: blank kind/name are rejected, every
    // declared permission must be known and granted, and the capability kind
    // itself is never special-cased.
    use super::*;
    use crate::extensions::permissions::{Permission, PermissionSet};

    /// Builds a permission set that grants exactly `grants`.
    fn perms_with(grants: &[Permission]) -> PermissionSet {
        let mut granted = PermissionSet::new();
        for &permission in grants {
            granted.grant(permission);
        }
        granted
    }

    /// Builds a capability declaration with null `params`.
    fn cap(kind: &str, name: &str, perms: &[&str]) -> CapabilityDeclaration {
        CapabilityDeclaration {
            kind: kind.to_owned(),
            name: name.to_owned(),
            permissions: perms.iter().map(|&p| p.to_owned()).collect(),
            params: serde_json::Value::Null,
        }
    }

    #[test]
    fn capability_validator_rejects_empty_kind() {
        let declaration = cap(" ", "Sample", &["audio.input"]);
        let granted = perms_with(&[Permission::AudioInput]);
        let err = validate_capability(&declaration, &granted).unwrap_err();
        assert!(err.contains("kind"), "got: {}", err);
    }

    #[test]
    fn capability_validator_rejects_empty_name() {
        let declaration = cap("capture", " ", &["audio.input"]);
        let granted = perms_with(&[Permission::AudioInput]);
        let err = validate_capability(&declaration, &granted).unwrap_err();
        assert!(err.contains("name"), "got: {}", err);
    }

    #[test]
    fn capability_validator_rejects_unknown_permission_string() {
        let declaration = cap("capture", "Sample", &["audio.telepathy"]);
        let granted = perms_with(&[Permission::AudioInput, Permission::AudioOutput]);
        let err = validate_capability(&declaration, &granted).unwrap_err();
        let mentions_unknown = err.contains("unknown permission") && err.contains("audio.telepathy");
        assert!(mentions_unknown, "got: {}", err);
    }

    #[test]
    fn capability_validator_requires_every_declared_permission() {
        let declaration = cap("capture", "Sample", &["audio.input"]);
        let granted = perms_with(&[]);
        let err = validate_capability(&declaration, &granted).unwrap_err();
        let mentions_missing = err.contains("audio.input") && err.contains("not granted");
        assert!(mentions_missing, "got: {}", err);
    }

    #[test]
    fn capability_validator_accepts_when_all_permissions_granted() {
        let declaration = cap("capture", "Sample", &["audio.input", "audio.output"]);
        let granted = perms_with(&[Permission::AudioInput, Permission::AudioOutput]);
        validate_capability(&declaration, &granted).expect("should validate");
    }

    #[test]
    fn capability_validator_accepts_no_permissions() {
        let declaration = cap("ocr", "Tesseract", &[]);
        let granted = perms_with(&[]);
        validate_capability(&declaration, &granted).expect("should validate");
    }

    #[test]
    fn capability_validator_does_not_branch_on_kind() {
        let granted = perms_with(&[Permission::AudioInput]);
        let kinds = ["capture", "ocr", "agent", "foot_pedal", "eeg"];
        for kind in kinds {
            let declaration = cap(kind, "Anything", &["audio.input"]);
            validate_capability(&declaration, &granted).expect("should validate");
        }
    }
}
2306
#[cfg(test)]
mod invoke_command_dispatch_tests {
    //! Tests for `ProcessExtension::forward_invoke_command_frame`: which
    //! notification frames reach the invoke-command event channel, in what
    //! order, and which are filtered out.

    use super::*;
    use crate::extensions::commands::CommandOutputEvent;
    use crate::extensions::runtime::InvokeCommandEvent;
    use crate::extensions::tasks::{TaskEvent, TaskKind};
    use serde_json::json;
    use tokio::sync::mpsc;

    /// Builds a notification frame carrying `method` and `params`.
    fn frame(method: &str, params: serde_json::Value) -> NotificationFrame {
        let method = method.to_owned();
        NotificationFrame { method, params }
    }

    /// Feeds `frames`, in order, through `forward_invoke_command_frame` for
    /// `request_id`. All calls share one unbounded channel and one `open` flag.
    ///
    /// Returns `true` in the first slot if any call returned `true`. The second
    /// slot holds every event the receiver yields once the sender is dropped.
    fn dispatch(
        extension_id: &str,
        request_id: &str,
        frames: Vec<NotificationFrame>,
    ) -> (bool, Vec<InvokeCommandEvent>) {
        let (sender, mut receiver) = mpsc::unbounded_channel::<InvokeCommandEvent>();
        let mut channel_open = true;
        let mut done_seen = false;
        for notification in frames {
            // `|=` (not `||`) so every frame is still forwarded after a `true`.
            done_seen |= ProcessExtension::forward_invoke_command_frame(
                extension_id,
                request_id,
                &sender,
                &mut channel_open,
                notification,
            );
        }
        // Dropping the sender lets `try_recv` report disconnection once the
        // queue is empty, which ends the drain below.
        drop(sender);
        let delivered: Vec<InvokeCommandEvent> =
            std::iter::from_fn(|| receiver.try_recv().ok()).collect();
        (done_seen, delivered)
    }

    #[test]
    fn forwards_mixed_event_stream_in_order() {
        let stream = vec![
            frame(
                "command.output",
                json!({"request_id":"r1","event":{"kind":"text","content":"A"}}),
            ),
            frame(
                "task.start",
                json!({"id":"dl","label":"Downloading","kind":"download"}),
            ),
            frame(
                "task.update",
                json!({"id":"dl","current":50,"total":100}),
            ),
            frame(
                "command.output",
                json!({"request_id":"r1","event":{"kind":"system","content":"working"}}),
            ),
            frame("task.done", json!({"id":"dl"})),
            frame(
                "command.output",
                json!({"request_id":"r1","event":{"kind":"done"}}),
            ),
        ];

        let (saw_done, events) = dispatch("ext-test", "r1", stream);
        assert!(saw_done, "should have observed the command Done marker");
        assert_eq!(events.len(), 6);

        assert_eq!(
            events[0],
            InvokeCommandEvent::Output(CommandOutputEvent::Text { content: "A".into() })
        );
        assert!(matches!(
            events[1],
            InvokeCommandEvent::Task(TaskEvent::Start { kind: TaskKind::Download, .. })
        ));
        assert!(matches!(events[2], InvokeCommandEvent::Task(TaskEvent::Update { .. })));
        assert!(matches!(
            events[3],
            InvokeCommandEvent::Output(CommandOutputEvent::System { .. })
        ));
        assert!(matches!(
            events[4],
            InvokeCommandEvent::Task(TaskEvent::Done { error: None, .. })
        ));
        assert_eq!(events[5], InvokeCommandEvent::Output(CommandOutputEvent::Done));
    }

    #[test]
    fn ignores_command_output_for_unrelated_request_id() {
        let foreign = frame(
            "command.output",
            json!({"request_id":"other","event":{"kind":"text","content":"x"}}),
        );
        let (_, events) = dispatch("ext", "r1", vec![foreign]);
        assert!(events.is_empty());
    }

    #[test]
    fn skips_malformed_command_output_without_aborting() {
        let frames = vec![
            frame("command.output", json!({"request_id":"r1","event":{}})),
            frame(
                "command.output",
                json!({"request_id":"r1","event":{"kind":"done"}}),
            ),
        ];
        let (_, events) = dispatch("ext", "r1", frames);
        // Only the well-formed Done frame makes it through.
        assert_eq!(
            events,
            vec![InvokeCommandEvent::Output(CommandOutputEvent::Done)]
        );
    }

    #[test]
    fn task_events_pass_through_regardless_of_request_id() {
        let log_frame = frame("task.log", json!({"id":"abc","line":"..."}));
        let (_, events) = dispatch("ext", "r1", vec![log_frame]);
        let first = events.into_iter().next().unwrap();
        match first {
            InvokeCommandEvent::Task(TaskEvent::Log { id, line }) => {
                assert_eq!(id, "abc");
                assert_eq!(line, "...");
            }
            other => panic!("unexpected: {other:?}"),
        }
    }

    #[test]
    fn unrelated_methods_are_dropped() {
        let provider_frame = frame("provider.stream.event", json!({"type":"text","delta":"x"}));
        let (_, events) = dispatch("ext", "r1", vec![provider_frame]);
        assert!(events.is_empty());
    }
}