1use std::path::PathBuf;
7use std::process::Stdio;
8use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
9use std::sync::Arc;
10use std::collections::{HashMap, HashSet};
11
12use async_trait::async_trait;
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
16use tokio::process::{Child, ChildStdin, ChildStdout, Command};
17use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
18use tokio::task::JoinHandle;
19
20use super::{ExtensionHandler, ExtensionHealth, RestartPolicy};
21use crate::extensions::hooks::events::{HookEvent, HookResult};
22use crate::extensions::manifest::CURRENT_EXTENSION_PROTOCOL_VERSION;
23
24#[derive(Serialize)]
25struct JsonRpcRequest {
26 jsonrpc: &'static str,
27 method: String,
28 params: Value,
29 id: u64,
30}
31
32
33#[derive(Serialize)]
34struct InitializeParams {
35 synaps_version: &'static str,
36 extension_protocol_version: u32,
37 plugin_id: String,
38 plugin_root: Option<String>,
39 config: Value,
40}
41
42#[derive(Debug, Clone, PartialEq, Deserialize)]
43pub struct RegisteredExtensionToolSpec {
44 pub name: String,
45 pub description: String,
46 pub input_schema: Value,
47}
48
49#[derive(Debug, Clone, PartialEq, Deserialize)]
50pub struct RegisteredProviderSpec {
51 pub id: String,
52 pub display_name: String,
53 pub description: String,
54 #[serde(default)]
55 pub models: Vec<RegisteredProviderModelSpec>,
56 #[serde(default)]
57 pub config_schema: Option<Value>,
58}
59
60#[derive(Debug, Clone, PartialEq, Deserialize)]
61pub struct RegisteredProviderModelSpec {
62 pub id: String,
63 #[serde(default)]
64 pub display_name: Option<String>,
65 #[serde(default)]
66 pub capabilities: Value,
67 #[serde(default)]
68 pub context_window: Option<u64>,
69}
70
71#[derive(Debug, Clone, Serialize)]
72pub struct ProviderCompleteParams {
73 pub provider_id: String,
74 pub model_id: String,
75 pub model: String,
76 pub messages: Vec<Value>,
77 pub system_prompt: Option<String>,
78 pub tools: Vec<Value>,
79 pub temperature: Option<f32>,
80 pub max_tokens: Option<u32>,
81 pub thinking_budget: u32,
82}
83
84#[derive(Debug, Clone, PartialEq, Deserialize)]
85pub struct ProviderCompleteResult {
86 pub content: Vec<Value>,
87 #[serde(default)]
88 pub stop_reason: Option<String>,
89 #[serde(default)]
90 pub usage: Option<Value>,
91}
92
93#[derive(Debug, Clone, PartialEq)]
94pub struct ProviderToolUse {
95 pub id: String,
96 pub name: String,
97 pub input: Value,
98}
99
100#[derive(Debug, Clone, PartialEq)]
102pub enum ProviderStreamEvent {
103 TextDelta { text: String },
105 ThinkingDelta { text: String },
107 ToolUse {
109 id: String,
110 name: String,
111 input: Value,
112 },
113 Usage { usage: Value },
115 Error { message: String },
117 Done,
119}
120
121pub fn parse_provider_stream_event(params: &Value) -> Result<ProviderStreamEvent, String> {
127 let inner = match params.get("event") {
128 Some(ev) => ev,
129 None => params,
130 };
131 let obj = inner
132 .as_object()
133 .ok_or_else(|| "provider stream event must be a JSON object".to_string())?;
134
135 let ty = obj
136 .get("type")
137 .and_then(Value::as_str)
138 .ok_or_else(|| "provider stream event missing type".to_string())?;
139
140 match ty {
141 "text" => {
142 let text = obj
143 .get("delta")
144 .or_else(|| obj.get("text"))
145 .and_then(Value::as_str)
146 .ok_or_else(|| {
147 "provider stream text event missing 'delta' or 'text'".to_string()
148 })?;
149 Ok(ProviderStreamEvent::TextDelta {
150 text: text.to_string(),
151 })
152 }
153 "thinking" => {
154 let text = obj
155 .get("delta")
156 .or_else(|| obj.get("text"))
157 .and_then(Value::as_str)
158 .ok_or_else(|| {
159 "provider stream thinking event missing 'delta' or 'text'".to_string()
160 })?;
161 Ok(ProviderStreamEvent::ThinkingDelta {
162 text: text.to_string(),
163 })
164 }
165 "tool_use" => {
166 let id = obj
167 .get("id")
168 .and_then(Value::as_str)
169 .ok_or_else(|| "provider stream tool_use missing id".to_string())?;
170 if id.is_empty() {
171 return Err("provider stream tool_use id must be non-empty".to_string());
172 }
173 let name = obj
174 .get("name")
175 .and_then(Value::as_str)
176 .ok_or_else(|| "provider stream tool_use missing name".to_string())?;
177 if name.is_empty() {
178 return Err("provider stream tool_use name must be non-empty".to_string());
179 }
180 let input = match obj.get("input") {
181 None => Value::Object(Default::default()),
182 Some(v) if v.is_object() => v.clone(),
183 Some(_) => {
184 return Err(
185 "provider stream tool_use input must be a JSON object".to_string()
186 );
187 }
188 };
189 Ok(ProviderStreamEvent::ToolUse {
190 id: id.to_string(),
191 name: name.to_string(),
192 input,
193 })
194 }
195 "usage" => {
196 let mut clone = obj.clone();
197 clone.remove("type");
198 Ok(ProviderStreamEvent::Usage {
199 usage: Value::Object(clone),
200 })
201 }
202 "error" => {
203 let message = obj
204 .get("message")
205 .and_then(Value::as_str)
206 .ok_or_else(|| "provider stream error missing message".to_string())?;
207 if message.is_empty() {
208 return Err("provider stream error message must be non-empty".to_string());
209 }
210 Ok(ProviderStreamEvent::Error {
211 message: message.to_string(),
212 })
213 }
214 "done" => Ok(ProviderStreamEvent::Done),
215 other => Err(format!("unknown provider stream event type: {other}")),
216 }
217}
218
219pub async fn execute_provider_tool_use(
220 registry: &crate::ToolRegistry,
221 hook_bus: &Arc<crate::extensions::hooks::HookBus>,
222 tool_use: ProviderToolUse,
223 ctx: crate::ToolContext,
224 max_tool_output: usize,
225) -> Value {
226 let tool_id = tool_use.id;
227 let tool_name = tool_use.name;
228 let input = tool_use.input;
229
230 let Some(tool) = registry.get(&tool_name).cloned() else {
231 return serde_json::json!({
232 "type": "tool_result",
233 "tool_use_id": tool_id,
234 "content": format!("Unknown tool: {}", tool_name),
235 "is_error": true,
236 });
237 };
238
239 let runtime_name = registry.runtime_name_for_api(&tool_name).to_string();
240 let input = registry.translate_input_for_api_tool(&tool_name, input);
241 let decision = crate::runtime::resolve_before_tool_call_decision(
242 input.clone(),
243 crate::runtime::emit_before_tool_call(
244 hook_bus,
245 &tool_name,
246 Some(&runtime_name),
247 input.clone(),
248 ).await,
249 ctx.capabilities.secret_prompt.as_ref(),
250 false,
251 ).await;
252
253 let crate::runtime::BeforeToolCallDecision::Continue { input } = decision else {
254 let crate::runtime::BeforeToolCallDecision::Block { reason } = decision else { unreachable!() };
255 return serde_json::json!({
256 "type": "tool_result",
257 "tool_use_id": tool_id,
258 "content": format!("Tool call blocked by extension: {}", reason),
259 "is_error": true,
260 });
261 };
262
263 let input_for_hook = input.clone();
264 let (result, is_error) = match tool.execute(input, ctx).await {
265 Ok(output) => (output, false),
266 Err(error) => (format!("Tool execution failed: {}", error), true),
267 };
268 let _ = crate::runtime::emit_after_tool_call(
269 hook_bus,
270 &tool_name,
271 Some(&runtime_name),
272 input_for_hook,
273 result.clone(),
274 ).await;
275
276 let mut response = serde_json::json!({
277 "type": "tool_result",
278 "tool_use_id": tool_id,
279 "content": crate::truncate_str(&result, max_tool_output).to_string(),
280 });
281 if is_error {
282 response["is_error"] = serde_json::json!(true);
283 }
284 response
285}
286
287pub async fn complete_provider_with_tools<F>(
288 handler: Arc<dyn ExtensionHandler>,
289 mut params: ProviderCompleteParams,
290 registry: &crate::ToolRegistry,
291 hook_bus: &Arc<crate::extensions::hooks::HookBus>,
292 mut context_factory: F,
293 max_tool_output: usize,
294 max_iterations: usize,
295) -> Result<ProviderCompleteResult, String>
296where
297 F: FnMut() -> crate::ToolContext,
298{
299 let max_iterations = max_iterations.max(1);
300 for iteration in 0..max_iterations {
301 let result = handler.provider_complete(params.clone()).await?;
302 let tool_uses = extract_provider_tool_uses(&result.content)?;
303 if tool_uses.is_empty() {
304 return Ok(result);
305 }
306 if iteration + 1 == max_iterations {
307 return Err(format!(
308 "extension provider '{}' exceeded provider tool-use iteration limit ({})",
309 handler.id(),
310 max_iterations,
311 ));
312 }
313
314 let assistant_content = result.content.clone();
315 params.messages.push(serde_json::json!({
316 "role": "assistant",
317 "content": assistant_content,
318 }));
319
320 let mut tool_results = Vec::with_capacity(tool_uses.len());
321 for tool_use in tool_uses {
322 tool_results.push(execute_provider_tool_use(
323 registry,
324 hook_bus,
325 tool_use,
326 context_factory(),
327 max_tool_output,
328 ).await);
329 }
330 params.messages.push(serde_json::json!({
331 "role": "user",
332 "content": tool_results,
333 }));
334 }
335 Err(format!(
336 "extension provider '{}' exceeded provider tool-use iteration limit ({})",
337 handler.id(),
338 max_iterations,
339 ))
340}
341
342pub fn extract_provider_tool_uses(content: &[Value]) -> Result<Vec<ProviderToolUse>, String> {
343 let mut tool_uses = Vec::new();
344 for block in content {
345 if block.get("type").and_then(Value::as_str) != Some("tool_use") {
346 continue;
347 }
348 let id = block
349 .get("id")
350 .and_then(Value::as_str)
351 .ok_or_else(|| "provider tool_use missing id".to_string())?;
352 let name = block
353 .get("name")
354 .and_then(Value::as_str)
355 .ok_or_else(|| "provider tool_use missing name".to_string())?;
356 if id.trim().is_empty() {
357 return Err("provider tool_use id is empty".to_string());
358 }
359 if name.trim().is_empty() {
360 return Err("provider tool_use name is empty".to_string());
361 }
362 let input = block
363 .get("input")
364 .cloned()
365 .unwrap_or_else(|| serde_json::json!({}));
366 if !input.is_object() {
367 return Err(format!(
368 "provider tool_use '{}' input must be a JSON object",
369 id
370 ));
371 }
372 tool_uses.push(ProviderToolUse {
373 id: id.to_string(),
374 name: name.to_string(),
375 input,
376 });
377 }
378 Ok(tool_uses)
379}
380
381#[derive(Debug, Clone, Default, PartialEq)]
382pub struct InitializeCapabilitiesResult {
383 pub tools: Vec<RegisteredExtensionToolSpec>,
384 pub providers: Vec<RegisteredProviderSpec>,
385 pub capabilities: Vec<CapabilityDeclaration>,
391}
392
393#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
402pub struct CapabilityDeclaration {
403 pub kind: String,
407 pub name: String,
410 #[serde(default)]
415 pub permissions: Vec<String>,
416 #[serde(default, skip_serializing_if = "is_null_value")]
420 pub params: serde_json::Value,
421}
422
423fn is_null_value(v: &serde_json::Value) -> bool {
424 v.is_null()
425}
426
427pub fn validate_capability(
437 decl: &CapabilityDeclaration,
438 granted: &crate::extensions::permissions::PermissionSet,
439) -> Result<(), String> {
440 use crate::extensions::permissions::Permission;
441 if decl.kind.trim().is_empty() {
442 return Err("capability 'kind' must be non-empty".to_string());
443 }
444 if decl.name.trim().is_empty() {
445 return Err("capability 'name' must be non-empty".to_string());
446 }
447 for perm_name in &decl.permissions {
448 let parsed = Permission::parse(perm_name).ok_or_else(|| {
449 format!(
450 "capability '{}' declares unknown permission '{}'",
451 decl.kind, perm_name
452 )
453 })?;
454 if !granted.has(parsed) {
455 return Err(format!(
456 "capability '{}' requires permission '{}' but it is not granted",
457 decl.kind, perm_name
458 ));
459 }
460 }
461 Ok(())
462}
463
464#[derive(Deserialize)]
465struct InitializeResult {
466 protocol_version: u32,
467 #[serde(default)]
468 capabilities: InitializeCapabilities,
469}
470
471#[derive(Default, Deserialize)]
472struct InitializeCapabilities {
473 #[serde(default)]
474 tools: Vec<RegisteredExtensionToolSpec>,
475 #[serde(default)]
476 providers: Vec<RegisteredProviderSpec>,
477 #[serde(default)]
479 capabilities: Vec<CapabilityDeclaration>,
480}
481
482#[doc(hidden)]
488#[derive(Debug, Clone)]
489pub struct NotificationFrame {
490 pub method: String,
491 pub params: Value,
492}
493
494struct Inbox {
501 pending: Mutex<HashMap<u64, oneshot::Sender<Result<Value, String>>>>,
502 notification_sinks: Mutex<Vec<(usize, mpsc::UnboundedSender<NotificationFrame>)>>,
508 next_sink_id: std::sync::atomic::AtomicUsize,
512 closed: std::sync::atomic::AtomicBool,
515 permissions: RwLock<Option<crate::extensions::permissions::PermissionSet>>,
518 inbound_stdin: Mutex<Option<Arc<Mutex<ChildStdin>>>>,
522 extension_id: String,
524}
525
526impl Inbox {
527 fn new(extension_id: String) -> Self {
528 Self {
529 pending: Mutex::new(HashMap::new()),
530 notification_sinks: Mutex::new(Vec::new()),
531 next_sink_id: std::sync::atomic::AtomicUsize::new(0),
532 closed: std::sync::atomic::AtomicBool::new(false),
533 permissions: RwLock::new(None),
534 inbound_stdin: Mutex::new(None),
535 extension_id,
536 }
537 }
538
539 async fn fail_all_pending(&self, reason: &str) {
542 self.closed.store(true, std::sync::atomic::Ordering::Release);
543 let drained: Vec<_> = {
544 let mut pending = self.pending.lock().await;
545 pending.drain().collect()
546 };
547 for (_, tx) in drained {
548 let _ = tx.send(Err(reason.to_string()));
549 }
550 }
551}
552
553struct ProcessState {
554 child: Child,
555 stdin: Arc<Mutex<ChildStdin>>,
556 reader_handle: JoinHandle<()>,
557}
558
559pub struct ProcessExtension {
561 id: String,
562 command: String,
563 args: Vec<String>,
564 cwd: Option<PathBuf>,
565 state: Arc<Mutex<Option<ProcessState>>>,
566 call_lock: Arc<Mutex<()>>,
568 next_id: AtomicU64,
569 restart_count: AtomicUsize,
570 total_restarts: AtomicUsize,
576 pub(crate) restart_policy: RestartPolicy,
578 inbox: Arc<Inbox>,
582}
583
584impl ProcessExtension {
585 pub async fn spawn(id: &str, command: &str, args: &[String]) -> Result<Self, String> {
586 Self::spawn_with_cwd(id, command, args, None).await
587 }
588
589 pub async fn spawn_with_cwd(
594 id: &str,
595 command: &str,
596 args: &[String],
597 cwd: Option<PathBuf>,
598 ) -> Result<Self, String> {
599 let inbox = Arc::new(Inbox::new(id.to_string()));
600 let state = Self::spawn_state(id, command, args, cwd.as_ref(), inbox.clone()).await?;
601 Ok(Self {
602 id: id.to_string(),
603 command: command.to_string(),
604 args: args.to_vec(),
605 cwd,
606 state: Arc::new(Mutex::new(Some(state))),
607 call_lock: Arc::new(Mutex::new(())),
608 next_id: AtomicU64::new(1),
609 restart_count: AtomicUsize::new(0),
610 total_restarts: AtomicUsize::new(0),
611 restart_policy: RestartPolicy::default(),
612 inbox,
613 })
614 }
615
616 pub fn with_restart_policy(mut self, policy: RestartPolicy) -> Self {
618 self.restart_policy = policy;
619 self
620 }
621
622 async fn spawn_state(
623 id: &str,
624 command: &str,
625 args: &[String],
626 cwd: Option<&PathBuf>,
627 inbox: Arc<Inbox>,
628 ) -> Result<ProcessState, String> {
629 let mut cmd = Command::new(command);
630 cmd.args(args)
631 .stdin(Stdio::piped())
632 .stdout(Stdio::piped())
633 .stderr(Stdio::piped());
634 if let Some(cwd) = cwd {
635 cmd.current_dir(cwd);
636 }
637
638 cmd.env_clear();
644 for var in &["PATH", "HOME", "LANG", "TERM", "XDG_RUNTIME_DIR"] {
645 if let Ok(val) = std::env::var(var) {
646 cmd.env(var, val);
647 }
648 }
649
650 cmd.kill_on_drop(true);
651
652 let mut child = cmd
653 .spawn()
654 .map_err(|e| format!("Failed to spawn extension '{}': {}", id, e))?;
655
656 let stdin = child
657 .stdin
658 .take()
659 .ok_or_else(|| format!("No stdin for extension '{}'", id))?;
660 let stdout = child
661 .stdout
662 .take()
663 .ok_or_else(|| format!("No stdout for extension '{}'", id))?;
664 if let Some(stderr) = child.stderr.take() {
665 let extension_id = id.to_string();
666 tokio::spawn(async move {
667 let mut lines = BufReader::new(stderr).lines();
668 loop {
669 match lines.next_line().await {
670 Ok(Some(line)) => {
671 tracing::debug!(extension = %extension_id, stderr = %line);
672 }
673 Ok(None) => break,
674 Err(error) => {
675 tracing::debug!(
676 extension = %extension_id,
677 error = %error,
678 "Failed to read extension stderr",
679 );
680 break;
681 }
682 }
683 }
684 });
685 }
686
687 let reader_handle = Self::spawn_reader(stdout, inbox.clone(), id.to_string());
688
689 let stdin_arc = Arc::new(Mutex::new(stdin));
690 *inbox.inbound_stdin.lock().await = Some(stdin_arc.clone());
693
694 Ok(ProcessState {
695 child,
696 stdin: stdin_arc,
697 reader_handle,
698 })
699 }
700
701 fn spawn_reader(
706 stdout: ChildStdout,
707 inbox: Arc<Inbox>,
708 extension_id: String,
709 ) -> JoinHandle<()> {
710 tokio::spawn(async move {
711 let mut reader = BufReader::new(stdout);
712 loop {
713 match Self::read_one_frame(&mut reader, &extension_id).await {
714 Ok(Some(value)) => {
715 Self::dispatch_frame(value, &inbox, &extension_id).await;
716 }
717 Ok(None) => {
718 tracing::debug!(
719 extension = %extension_id,
720 "Extension stdout closed (EOF); failing pending requests",
721 );
722 inbox.fail_all_pending("transport closed: EOF").await;
723 inbox.notification_sinks.lock().await.clear();
725 return;
726 }
727 Err(error) => {
728 tracing::debug!(
729 extension = %extension_id,
730 error = %error,
731 "Extension transport read error",
732 );
733 inbox
734 .fail_all_pending(&format!("transport error: {}", error))
735 .await;
736 inbox.notification_sinks.lock().await.clear();
737 return;
738 }
739 }
740 }
741 })
742 }
743
744 async fn read_one_frame(
748 reader: &mut BufReader<ChildStdout>,
749 extension_id: &str,
750 ) -> Result<Option<Value>, String> {
751 let mut content_length: Option<usize> = None;
752 let mut saw_any_header = false;
753 loop {
754 let mut header_line = String::new();
755 let n = reader
756 .read_line(&mut header_line)
757 .await
758 .map_err(|e| format!("Read header error: {}", e))?;
759 if n == 0 {
760 if saw_any_header {
761 return Err("Unexpected EOF while reading response headers".into());
762 }
763 return Ok(None);
764 }
765 saw_any_header = true;
766 if header_line.len() > 1024 {
767 return Err(format!(
768 "Extension '{}' header line too long ({} bytes)",
769 extension_id,
770 header_line.len()
771 ));
772 }
773 let trimmed = header_line.trim();
774 if trimmed.is_empty() {
775 break;
776 }
777 if let Some((name, value)) = trimmed.split_once(':') {
778 if name.trim().eq_ignore_ascii_case("Content-Length") {
779 content_length = Some(value.trim().parse().map_err(|_| {
780 format!("Invalid Content-Length value: {:?}", value.trim())
781 })?);
782 }
783 }
784 }
785 let content_length = content_length.ok_or_else(|| {
786 format!(
787 "Extension '{}' frame missing Content-Length header",
788 extension_id
789 )
790 })?;
791 const MAX_RESPONSE_SIZE: usize = 4 * 1024 * 1024;
792 if content_length > MAX_RESPONSE_SIZE {
793 return Err(format!(
794 "Extension '{}' frame too large: {} bytes (max {})",
795 extension_id, content_length, MAX_RESPONSE_SIZE
796 ));
797 }
798 let mut buf = vec![0u8; content_length];
799 tokio::io::AsyncReadExt::read_exact(reader, &mut buf)
800 .await
801 .map_err(|e| format!("Read body error: {}", e))?;
802 let value: Value = serde_json::from_slice(&buf)
803 .map_err(|e| format!("Parse frame error: {}", e))?;
804 Ok(Some(value))
805 }
806
807 async fn dispatch_frame(value: Value, inbox: &Arc<Inbox>, extension_id: &str) {
813 let id_field = value.get("id");
814 let id_is_present = !matches!(id_field, None | Some(Value::Null));
815 let method_field = value.get("method").and_then(Value::as_str).map(str::to_string);
816
817 if id_is_present && method_field.is_some() {
818 let id = match id_field.and_then(Value::as_u64) {
821 Some(id) => id,
822 None => {
823 tracing::trace!(
824 extension = %extension_id,
825 frame = %value,
826 "Discarding inbound request with non-numeric id",
827 );
828 return;
829 }
830 };
831 let Some(method) = method_field else { return };
832 let params = value.get("params").cloned().unwrap_or(Value::Null);
833 let inbox = inbox.clone();
834 let extension_id = extension_id.to_string();
835 tokio::spawn(async move {
836 let outcome = Self::handle_inbound_request(&inbox, &method, params).await;
837 let payload = match outcome {
838 Ok(result) => serde_json::json!({
839 "jsonrpc": "2.0",
840 "id": id,
841 "result": result,
842 }),
843 Err((code, message)) => serde_json::json!({
844 "jsonrpc": "2.0",
845 "id": id,
846 "error": {"code": code, "message": message},
847 }),
848 };
849 let stdin_handle = inbox.inbound_stdin.lock().await.clone();
850 if let Some(stdin) = stdin_handle {
851 let body = match serde_json::to_string(&payload) {
852 Ok(s) => s,
853 Err(error) => {
854 tracing::warn!(
855 extension = %extension_id,
856 error = %error,
857 "Failed to serialize inbound response",
858 );
859 return;
860 }
861 };
862 let frame = format!("Content-Length: {}\r\n\r\n{}", body.len(), body);
863 let mut stdin = stdin.lock().await;
864 if let Err(error) = stdin.write_all(frame.as_bytes()).await {
865 tracing::warn!(
866 extension = %extension_id,
867 error = %error,
868 "Failed to write inbound response",
869 );
870 return;
871 }
872 if let Err(error) = stdin.flush().await {
873 tracing::warn!(
874 extension = %extension_id,
875 error = %error,
876 "Failed to flush inbound response",
877 );
878 }
879 } else {
880 tracing::warn!(
881 extension = %extension_id,
882 "No stdin available to reply to inbound request",
883 );
884 }
885 });
886 return;
887 }
888
889 if id_is_present {
890 let id = match id_field.and_then(Value::as_u64) {
891 Some(id) => id,
892 None => {
893 tracing::trace!(
894 extension = %extension_id,
895 frame = %value,
896 "Discarding frame with non-numeric id",
897 );
898 return;
899 }
900 };
901 let sender = inbox.pending.lock().await.remove(&id);
902 match sender {
903 Some(tx) => {
904 let payload = if let Some(err) = value.get("error") {
905 let message = err
906 .get("message")
907 .and_then(Value::as_str)
908 .unwrap_or("unknown extension error")
909 .to_string();
910 Err(format!("Extension error: {}", message))
911 } else {
912 Ok(value
913 .get("result")
914 .cloned()
915 .unwrap_or(Value::Null))
916 };
917 let _ = tx.send(payload);
918 }
919 None => {
920 tracing::trace!(
921 extension = %extension_id,
922 id = id,
923 "Response with unknown id (no pending request); dropping",
924 );
925 }
926 }
927 } else if let Some(method) = value.get("method").and_then(Value::as_str) {
928 let params = value.get("params").cloned().unwrap_or(Value::Null);
929 let frame = NotificationFrame {
930 method: method.to_string(),
931 params,
932 };
933 let mut sinks = inbox.notification_sinks.lock().await;
934 if sinks.is_empty() {
935 tracing::trace!(
936 extension = %extension_id,
937 method = %method,
938 "Notification with no active subscribers; dropping",
939 );
940 } else {
941 sinks.retain(|(_, tx)| tx.send(frame.clone()).is_ok());
945 }
946 } else {
947 tracing::trace!(
948 extension = %extension_id,
949 frame = %value,
950 "Unrecognized frame; dropping",
951 );
952 }
953 }
954
955 pub fn restart_count(&self) -> usize {
956 self.total_restarts.load(Ordering::Relaxed)
959 }
960
961 pub async fn set_permissions(&self, perms: crate::extensions::permissions::PermissionSet) {
964 *self.inbox.permissions.write().await = Some(perms);
965 }
966
967 #[allow(clippy::doc_lazy_continuation)]
975 async fn handle_inbound_request(
976 inbox: &Arc<Inbox>,
977 method: &str,
978 params: Value,
979 ) -> Result<Value, (i32, String)> {
980 use crate::extensions::permissions::Permission;
981 use crate::memory::store::{self, MemoryQuery};
982
983 match method {
984 "memory.append" => {
985 Self::require_permission(inbox, Permission::MemoryWrite, "memory.write").await?;
986 let namespace = Self::param_str(¶ms, "namespace")?;
987 Self::require_namespace_matches(inbox, &namespace).await?;
988 let content = Self::param_str(¶ms, "content")?;
989 let tags = match params.get("tags") {
990 None | Some(Value::Null) => Vec::new(),
991 Some(Value::Array(arr)) => {
992 let mut out = Vec::with_capacity(arr.len());
993 for v in arr {
994 match v.as_str() {
995 Some(s) => out.push(s.to_string()),
996 None => {
997 return Err((
998 -32602,
999 "tags must be an array of strings".to_string(),
1000 ))
1001 }
1002 }
1003 }
1004 out
1005 }
1006 _ => {
1007 return Err((
1008 -32602,
1009 "tags must be an array of strings".to_string(),
1010 ))
1011 }
1012 };
1013 let meta = match params.get("meta") {
1014 None | Some(Value::Null) => None,
1015 Some(v) => Some(v.clone()),
1016 };
1017 let record = store::new_record(namespace, content, tags, meta);
1018 let timestamp_ms = record.timestamp_ms;
1019 store::append(&record).map_err(|e| (-32000, e.to_string()))?;
1020 Ok(serde_json::json!({"ok": true, "timestamp_ms": timestamp_ms}))
1021 }
1022 "memory.query" => {
1023 Self::require_permission(inbox, Permission::MemoryRead, "memory.read").await?;
1024 let namespace = Self::param_str(¶ms, "namespace")?;
1025 Self::require_namespace_matches(inbox, &namespace).await?;
1026 let q = MemoryQuery {
1027 content_contains: params
1028 .get("content_contains")
1029 .and_then(Value::as_str)
1030 .map(str::to_string),
1031 tag_prefix: params
1032 .get("tag_prefix")
1033 .and_then(Value::as_str)
1034 .map(str::to_string),
1035 since_ms: params.get("since_ms").and_then(Value::as_u64),
1036 until_ms: params.get("until_ms").and_then(Value::as_u64),
1037 limit: params
1038 .get("limit")
1039 .and_then(Value::as_u64)
1040 .map(|n| n as usize),
1041 };
1042 let records = store::query(&namespace, &q).map_err(|e| (-32000, e.to_string()))?;
1043 Ok(serde_json::json!({"records": records}))
1044 }
1045 "config.get" => {
1046 let key = Self::param_str(¶ms, "key")?;
1047 Self::validate_config_key(&key)?;
1048 let value = crate::extensions::config_store::read_plugin_config(&inbox.extension_id, &key);
1049 Ok(serde_json::json!({"value": value}))
1050 }
1051 "config.set" => {
1052 Self::require_permission(inbox, Permission::ConfigWrite, "config.write").await?;
1053 let key = Self::param_str(¶ms, "key")?;
1054 Self::validate_config_key(&key)?;
1055 let value = Self::param_str(¶ms, "value")?;
1056 crate::extensions::config_store::write_plugin_config(&inbox.extension_id, &key, &value)
1057 .map_err(|e| (-32000, e.to_string()))?;
1058 Ok(serde_json::json!({"ok": true}))
1059 }
1060 "config.subscribe" => {
1061 Self::require_permission(inbox, Permission::ConfigSubscribe, "config.subscribe").await?;
1062 Ok(serde_json::json!({"ok": true}))
1066 }
1067 other => Err((-32601, format!("method not found: {other}"))),
1068 }
1069 }
1070
1071 async fn require_permission(
1072 inbox: &Arc<Inbox>,
1073 perm: crate::extensions::permissions::Permission,
1074 wire: &str,
1075 ) -> Result<(), (i32, String)> {
1076 let guard = inbox.permissions.read().await;
1077 match guard.as_ref() {
1078 Some(set) if set.has(perm) => Ok(()),
1079 _ => Err((
1080 -32602,
1081 format!("permission denied: {wire} required"),
1082 )),
1083 }
1084 }
1085
1086 async fn require_namespace_matches(
1087 inbox: &Arc<Inbox>,
1088 namespace: &str,
1089 ) -> Result<(), (i32, String)> {
1090 if namespace == inbox.extension_id {
1091 Ok(())
1092 } else {
1093 Err((
1094 -32602,
1095 format!(
1096 "namespace must equal extension id '{}' (got '{}')",
1097 inbox.extension_id, namespace
1098 ),
1099 ))
1100 }
1101 }
1102
1103 fn param_str(params: &Value, name: &str) -> Result<String, (i32, String)> {
1104 params
1105 .get(name)
1106 .and_then(Value::as_str)
1107 .map(str::to_string)
1108 .ok_or_else(|| (-32602, format!("missing or invalid '{name}' parameter")))
1109 }
1110
/// Validates a plugin config key received from an extension.
///
/// The RAW key is checked, because the raw key (not a trimmed copy) is what
/// callers pass on to the config store. Previously only the trimmed key was
/// checked, so `" key"` slipped past the no-spaces rule.
///
/// # Errors
/// Returns `-32602` in three cases:
/// - the key is empty or whitespace-only;
/// - it contains a dot, a forward or back slash, or any whitespace;
/// - it contains a control character.
fn validate_config_key(key: &str) -> Result<(), (i32, String)> {
    if key.trim().is_empty() {
        return Err((-32602, "config key must be non-empty".to_string()));
    }
    if key
        .chars()
        .any(|c| c == '.' || c == '/' || c == '\\' || c.is_whitespace())
    {
        return Err((
            -32602,
            "config key must not contain dots, slashes, or spaces".to_string(),
        ));
    }
    if key.chars().any(char::is_control) {
        return Err((
            -32602,
            "config key must not contain control characters".to_string(),
        ));
    }
    Ok(())
}
1124
1125 pub async fn initialize(&self, plugin_root: Option<PathBuf>, config: Value) -> Result<InitializeCapabilitiesResult, String> {
1126 let params = InitializeParams {
1127 synaps_version: env!("CARGO_PKG_VERSION"),
1128 extension_protocol_version: CURRENT_EXTENSION_PROTOCOL_VERSION,
1129 plugin_id: self.id.clone(),
1130 plugin_root: plugin_root
1131 .or_else(|| self.cwd.clone())
1132 .map(|path| path.to_string_lossy().to_string()),
1133 config,
1134 };
1135 let value = self.call_no_restart("initialize", serde_json::to_value(params).map_err(|e| e.to_string())?).await?;
1136 Self::parse_initialize_result(&self.id, value)
1137 }
1138
1139 fn parse_initialize_result(id: &str, value: Value) -> Result<InitializeCapabilitiesResult, String> {
1140 let result: InitializeResult = serde_json::from_value(value)
1141 .map_err(|e| format!("Invalid initialize response from extension '{}': {}", id, e))?;
1142 if result.protocol_version != CURRENT_EXTENSION_PROTOCOL_VERSION {
1143 return Err(format!(
1144 "Extension '{}' initialize returned unsupported protocol_version {} (supported: {})",
1145 id, result.protocol_version, CURRENT_EXTENSION_PROTOCOL_VERSION,
1146 ));
1147 }
1148 Self::validate_registered_tool_specs(id, &result.capabilities.tools)?;
1149 Self::validate_registered_provider_specs(id, &result.capabilities.providers)?;
1150 Ok(InitializeCapabilitiesResult {
1151 tools: result.capabilities.tools,
1152 providers: result.capabilities.providers,
1153 capabilities: result.capabilities.capabilities,
1154 })
1155 }
1156
1157 fn validate_registered_tool_specs(id: &str, tools: &[RegisteredExtensionToolSpec]) -> Result<(), String> {
1158 use crate::extensions::validation::{validate_id_segment, IdValidationError};
1159 let mut names = HashSet::new();
1160 for tool in tools {
1161 let name = tool.name.trim();
1162 if let Err(err) = validate_id_segment(name) {
1163 return Err(match err {
1164 IdValidationError::Empty => format!(
1165 "Extension '{}' registered a tool with an empty tool name",
1166 id
1167 ),
1168 IdValidationError::ContainsReserved { ch } => format!(
1169 "Extension '{}' registered tool '{}' with invalid tool name: '{}' is reserved",
1170 id, name, ch
1171 ),
1172 IdValidationError::TooLong { len, max } => format!(
1173 "Extension '{}' registered tool '{}' with invalid tool name: must be at most {} chars (got {})",
1174 id, name, max, len
1175 ),
1176 IdValidationError::ContainsWhitespace => format!(
1177 "Extension '{}' registered tool '{}' with invalid tool name: must not contain whitespace",
1178 id, name
1179 ),
1180 IdValidationError::ContainsControl { ch } => format!(
1181 "Extension '{}' registered tool '{}' with invalid tool name: contains control character U+{:04X}",
1182 id, name, ch as u32
1183 ),
1184 });
1185 }
1186 if !names.insert(name.to_string()) {
1187 return Err(format!("Extension '{}' registered duplicate tool name '{}'", id, name));
1188 }
1189 if tool.description.trim().is_empty() {
1190 return Err(format!(
1191 "Extension '{}' registered tool '{}' with an empty description",
1192 id, name,
1193 ));
1194 }
1195 if !tool.input_schema.is_object() {
1196 return Err(format!(
1197 "Extension '{}' registered tool '{}' with invalid input_schema: input_schema must be a JSON object",
1198 id, name,
1199 ));
1200 }
1201 }
1202 Ok(())
1203 }
1204
1205 fn validate_registered_provider_specs(id: &str, providers: &[RegisteredProviderSpec]) -> Result<(), String> {
1206 use crate::extensions::validation::{validate_id_segment, IdValidationError};
1207 for provider in providers {
1208 let provider_id = provider.id.trim();
1209 match validate_id_segment(provider_id) {
1210 Ok(()) => {
1211 if !Self::is_safe_provider_id(provider_id) {
1212 return Err(format!(
1213 "Extension '{}' registered provider '{}' with invalid provider id",
1214 id, provider_id
1215 ));
1216 }
1217 }
1218 Err(IdValidationError::Empty) => {
1219 return Err(format!(
1220 "Extension '{}' registered provider with empty provider id",
1221 id
1222 ));
1223 }
1224 Err(err) => {
1225 return Err(format!(
1226 "Extension '{}' registered provider '{}' with invalid provider id: {}",
1227 id, provider_id, err
1228 ));
1229 }
1230 }
1231 if provider.display_name.trim().is_empty() {
1232 return Err(format!(
1233 "Extension '{}' registered provider '{}' with empty display_name",
1234 id, provider_id,
1235 ));
1236 }
1237 if provider.description.trim().is_empty() {
1238 return Err(format!(
1239 "Extension '{}' registered provider '{}' with empty description",
1240 id, provider_id,
1241 ));
1242 }
1243 if provider.models.is_empty() {
1244 return Err(format!(
1245 "Extension '{}' registered provider '{}' must declare at least one model",
1246 id, provider_id,
1247 ));
1248 }
1249 let mut model_ids = HashSet::new();
1250 for model in &provider.models {
1251 let model_id = model.id.trim();
1252 if let Err(err) = validate_id_segment(model_id) {
1253 return Err(match err {
1254 IdValidationError::Empty => format!(
1255 "Extension '{}' registered provider '{}' with empty model id",
1256 id, provider_id
1257 ),
1258 IdValidationError::ContainsReserved { ch } => format!(
1259 "Extension '{}' registered provider '{}' with invalid model id '{}': '{}' is reserved",
1260 id, provider_id, model_id, ch
1261 ),
1262 IdValidationError::TooLong { len, max } => format!(
1263 "Extension '{}' registered provider '{}' with invalid model id '{}': must be at most {} chars (got {})",
1264 id, provider_id, model_id, max, len
1265 ),
1266 IdValidationError::ContainsWhitespace => format!(
1267 "Extension '{}' registered provider '{}' with invalid model id '{}': must not contain whitespace",
1268 id, provider_id, model_id
1269 ),
1270 IdValidationError::ContainsControl { ch } => format!(
1271 "Extension '{}' registered provider '{}' with invalid model id '{}': contains control character U+{:04X}",
1272 id, provider_id, model_id, ch as u32
1273 ),
1274 });
1275 }
1276 if !model_ids.insert(model_id.to_string()) {
1277 return Err(format!(
1278 "Extension '{}' registered provider '{}' with duplicate model id '{}'",
1279 id, provider_id, model_id,
1280 ));
1281 }
1282 }
1283 if let Some(config_schema) = &provider.config_schema {
1284 if !config_schema.is_object() {
1285 return Err(format!(
1286 "Extension '{}' registered provider '{}' with invalid config_schema: config_schema must be a JSON object",
1287 id, provider_id,
1288 ));
1289 }
1290 }
1291 }
1292 Ok(())
1293 }
1294
1295 fn is_safe_provider_id(id: &str) -> bool {
1296 !id.is_empty()
1297 && !id.contains(':')
1298 && id.chars().all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-' || c == '_')
1299 }
1300
1301 #[doc(hidden)]
1302 pub async fn initialize_for_test(&self, plugin_root: Option<PathBuf>) -> Result<(), String> {
1303 self.initialize(plugin_root, Value::Object(Default::default())).await.map(|_| ())
1304 }
1305
1306 async fn restart_locked(&self, state: &mut Option<ProcessState>) -> Result<(), String> {
1307 let attempted = self.restart_count.fetch_add(1, Ordering::Relaxed) + 1;
1308 self.total_restarts.fetch_add(1, Ordering::Relaxed);
1309 let max_attempts = self.restart_policy.max_attempts;
1310 if attempted > max_attempts as usize {
1311 *state = None;
1312 return Err(format!(
1313 "Extension '{}' exceeded restart limit ({})",
1314 self.id, max_attempts,
1315 ));
1316 }
1317
1318 if let Some(old) = state.take() {
1319 old.reader_handle.abort();
1320 let mut child = old.child;
1321 let _ = child.kill().await;
1322 }
1323 self.inbox
1325 .fail_all_pending("transport closed: process restarting")
1326 .await;
1327
1328 let delay = self
1329 .restart_policy
1330 .delay_for_attempt(attempted as u32)
1331 .unwrap_or_default();
1332
1333 tracing::warn!(
1334 extension = %self.id,
1335 attempt = attempted,
1336 max_attempts = max_attempts,
1337 delay_ms = delay.as_millis() as u64,
1338 "Restarting extension process after transport failure",
1339 );
1340
1341 if !delay.is_zero() {
1342 tokio::time::sleep(delay).await;
1343 }
1344
1345 *state = Some(Self::spawn_state(
1346 &self.id,
1347 &self.command,
1348 &self.args,
1349 self.cwd.as_ref(),
1350 self.inbox.clone(),
1351 ).await?);
1352 self.inbox.closed.store(false, std::sync::atomic::Ordering::Release);
1354 self.initialize_locked(state).await?;
1355 Ok(())
1362 }
1363
1364
1365 async fn initialize_locked(&self, state: &mut Option<ProcessState>) -> Result<(), String> {
1366 let params = InitializeParams {
1367 synaps_version: env!("CARGO_PKG_VERSION"),
1368 extension_protocol_version: CURRENT_EXTENSION_PROTOCOL_VERSION,
1369 plugin_id: self.id.clone(),
1370 plugin_root: self.cwd
1371 .clone()
1372 .map(|path| path.to_string_lossy().to_string()),
1373 config: Value::Object(Default::default()),
1374 };
1375 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
1376 let value = tokio::time::timeout(
1377 std::time::Duration::from_secs(10),
1378 self.call_once_locked(
1379 state.as_mut().expect("state should exist for initialize"),
1380 "initialize",
1381 serde_json::to_value(params).map_err(|e| e.to_string())?,
1382 id,
1383 ),
1384 )
1385 .await
1386 .map_err(|_| format!("Extension '{}' initialize timed out after 10s", self.id))?
1387 ?;
1388 Self::parse_initialize_result(&self.id, value).map(|_| ())
1389 }
1390
 /// Sends a single JSON-RPC 2.0 request to the process in `state` and waits for
 /// the matching response.
 ///
 /// The request is serialized, framed with a `Content-Length` header, written to
 /// the process's stdin and flushed. Before writing, a oneshot sender is registered
 /// in `inbox.pending` under `id`. The reply (or an error) arrives through that
 /// channel, presumably resolved by the stdout reader task, which is not visible
 /// here (confirm).
 ///
 /// No restart or retry happens here; callers decide how to react to errors.
 ///
 /// NOTE(review): if this future is cancelled while awaiting `rx` (for example by
 /// a caller's `tokio::time::timeout`), the `pending` entry for `id` is not removed
 /// here. Confirm that something else cleans it up.
 ///
 /// # Errors
 /// - `"Serialize error: ..."` if the request cannot be encoded.
 /// - `"transport closed: ..."` if the inbox is closed before or during
 ///   registration, or if the response sender is dropped without a reply.
 /// - `"Write error: ..."` if writing or flushing stdin fails.
 /// - Any error payload delivered through the pending channel.
 async fn call_once_locked(
 &self,
 state: &mut ProcessState,
 method: &str,
 params: Value,
 id: u64,
 ) -> Result<Value, String> {
 let body = serde_json::to_string(&JsonRpcRequest {
 jsonrpc: "2.0",
 method: method.to_string(),
 params,
 id,
 })
 .map_err(|e| format!("Serialize error: {}", e))?;

 let (tx, rx) = oneshot::channel::<Result<Value, String>>();
 // Refuse early if the inbox has already been shut down.
 if self.inbox.closed.load(std::sync::atomic::Ordering::Acquire) {
 return Err("transport closed: inbox is shut down".to_string());
 }

 self.inbox.pending.lock().await.insert(id, tx);

 // Re-check after registering. If the inbox closed between the first check and
 // the insert, a concurrent drain of `pending` may already have run, so this
 // entry might never be resolved. Remove it and bail out.
 if self.inbox.closed.load(std::sync::atomic::Ordering::Acquire) {
 self.inbox.pending.lock().await.remove(&id);
 return Err("transport closed: inbox shut down during registration".to_string());
 }

 // `body.len()` is the UTF-8 byte length, which is what `Content-Length` counts.
 let frame = format!("Content-Length: {}\r\n\r\n{}", body.len(), body);
 // The stdin lock is held only for the write + flush of this one frame.
 let write_result = {
 let mut stdin = state.stdin.lock().await;
 match stdin.write_all(frame.as_bytes()).await {
 Ok(()) => stdin.flush().await,
 Err(e) => Err(e),
 }
 };
 if let Err(e) = write_result {
 // The request never reached the process, so no reply will arrive for `id`.
 self.inbox.pending.lock().await.remove(&id);
 return Err(format!("Write error: {}", e));
 }

 match rx.await {
 Ok(payload) => payload,
 Err(_) => {
 // Sender dropped without a reply; make sure no stale entry remains.
 self.inbox.pending.lock().await.remove(&id);
 Err("transport closed: response channel dropped".to_string())
 }
 }
 }
1453
1454 async fn call_no_restart(&self, method: &str, params: Value) -> Result<Value, String> {
1455 let _call_guard = self.call_lock.lock().await;
1456 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
1457 let mut state_guard = self.state.lock().await;
1458 if state_guard.is_none() {
1459 *state_guard = Some(Self::spawn_state(
1460 &self.id,
1461 &self.command,
1462 &self.args,
1463 self.cwd.as_ref(),
1464 self.inbox.clone(),
1465 ).await?);
1466 }
1467 self.call_once_locked(
1468 state_guard.as_mut().expect("state should exist"),
1469 method,
1470 params,
1471 id,
1472 ).await
1473 }
1474
1475 async fn call(&self, method: &str, params: Value) -> Result<Value, String> {
1476 let timeout_secs = if method == "tool.call" { 120 } else { 30 };
1477 let id_str = self.id.clone();
1478 let method_str = method.to_string();
1479
1480 let result = tokio::time::timeout(
1481 std::time::Duration::from_secs(timeout_secs),
1482 self.call_inner(method, params),
1483 )
1484 .await;
1485
1486 match result {
1487 Ok(inner) => inner,
1488 Err(_) => Err(format!(
1489 "Extension '{}' method '{}' timed out after {}s",
1490 id_str, method_str, timeout_secs
1491 )),
1492 }
1493 }
1494
1495 async fn call_inner(&self, method: &str, params: Value) -> Result<Value, String> {
1496 let _call_guard = self.call_lock.lock().await;
1497 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
1498 let mut state_guard = self.state.lock().await;
1499 if state_guard.is_none() {
1500 self.restart_locked(&mut state_guard).await?;
1501 }
1502
1503 let result = self
1504 .call_once_locked(
1505 state_guard.as_mut().expect("state should exist after restart"),
1506 method,
1507 params.clone(),
1508 id,
1509 )
1510 .await;
1511
1512 match result {
1513 Ok(value) => {
1514 self.restart_count.store(0, Ordering::Relaxed);
1518 Ok(value)
1519 }
1520 Err(first_error) => {
1521 self.restart_locked(&mut state_guard).await?;
1522 let retry_id = self.next_id.fetch_add(1, Ordering::Relaxed);
1523 self.call_once_locked(
1524 state_guard.as_mut().expect("state should exist after restart"),
1525 method,
1526 params,
1527 retry_id,
1528 )
1529 .await
1530 .inspect(|_| {
1531 self.restart_count.store(0, Ordering::Relaxed);
1533 })
1534 .map_err(|retry_error| {
1535 format!("{}; retry after restart failed: {}", first_error, retry_error)
1536 })
1537 }
1538 }
1539 }
1540
1541 #[doc(hidden)]
1559 pub async fn subscribe_notifications(
1560 &self,
1561 ) -> (usize, mpsc::UnboundedReceiver<NotificationFrame>) {
1562 let (tx, rx) = mpsc::unbounded_channel();
1563 let id = self
1564 .inbox
1565 .next_sink_id
1566 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1567 let mut sinks = self.inbox.notification_sinks.lock().await;
1568 sinks.retain(|(_, tx)| !tx.is_closed());
1572 sinks.push((id, tx));
1573 (id, rx)
1574 }
1575
1576 #[doc(hidden)]
1581 pub async fn unsubscribe_notifications(&self, id: usize) {
1582 let mut sinks = self.inbox.notification_sinks.lock().await;
1583 sinks.retain(|(sub_id, tx)| *sub_id != id && !tx.is_closed());
1584 }
1585
1586 pub(crate) fn forward_invoke_command_frame(
1596 extension_id: &str,
1597 request_id: &str,
1598 sink: &mpsc::UnboundedSender<crate::extensions::runtime::InvokeCommandEvent>,
1599 sink_open: &mut bool,
1600 frame: NotificationFrame,
1601 ) -> bool {
1602 use crate::extensions::commands::parse_command_output;
1603 use crate::extensions::tasks::{is_task_method, parse_task_event};
1604 use crate::extensions::runtime::InvokeCommandEvent;
1605
1606 let mut saw_done = false;
1607 if frame.method == "command.output" {
1608 match parse_command_output(&frame.params) {
1609 Ok(parsed) if parsed.request_id == request_id => {
1610 if matches!(parsed.event, crate::extensions::commands::CommandOutputEvent::Done) {
1611 saw_done = true;
1612 }
1613 if *sink_open && sink.send(InvokeCommandEvent::Output(parsed.event)).is_err() {
1614 *sink_open = false;
1615 }
1616 }
1617 Ok(_) => {
1618 tracing::trace!(
1620 extension = %extension_id,
1621 "Ignoring command.output for unrelated request_id",
1622 );
1623 }
1624 Err(error) => {
1625 tracing::warn!(
1626 extension = %extension_id,
1627 error = %error,
1628 params = %frame.params,
1629 "Skipping malformed command.output notification",
1630 );
1631 }
1632 }
1633 } else if is_task_method(&frame.method) {
1634 match parse_task_event(&frame.method, &frame.params) {
1635 Ok(event) => {
1636 if *sink_open && sink.send(InvokeCommandEvent::Task(event)).is_err() {
1637 *sink_open = false;
1638 }
1639 }
1640 Err(error) => {
1641 tracing::warn!(
1642 extension = %extension_id,
1643 method = %frame.method,
1644 error = %error,
1645 params = %frame.params,
1646 "Skipping malformed task notification",
1647 );
1648 }
1649 }
1650 } else {
1651 tracing::trace!(
1652 extension = %extension_id,
1653 method = %frame.method,
1654 "Ignoring non-command/task notification during command.invoke",
1655 );
1656 }
1657 saw_done
1658 }
1659
1660 fn forward_provider_stream_frame(
1667 extension_id: &str,
1668 sink: &mpsc::UnboundedSender<ProviderStreamEvent>,
1669 sink_open: &mut bool,
1670 frame: NotificationFrame,
1671 ) {
1672 if frame.method != "provider.stream.event" {
1673 tracing::trace!(
1674 extension = %extension_id,
1675 method = %frame.method,
1676 "Ignoring non-stream notification during provider.stream",
1677 );
1678 return;
1679 }
1680 match parse_provider_stream_event(&frame.params) {
1681 Ok(event) => {
1682 if *sink_open && sink.send(event).is_err() {
1683 *sink_open = false;
1684 }
1685 }
1686 Err(error) => {
1687 tracing::warn!(
1688 extension = %extension_id,
1689 error = %error,
1690 params = %frame.params,
1691 "Skipping malformed provider.stream.event notification",
1692 );
1693 }
1694 }
1695 }
1696}
1697
1698#[async_trait]
1699impl ExtensionHandler for ProcessExtension {
1700 fn id(&self) -> &str {
1701 &self.id
1702 }
1703
1704 async fn call_tool(&self, name: &str, input: Value) -> Result<Value, String> {
1705 self.call("tool.call", serde_json::json!({
1706 "name": name,
1707 "input": input,
1708 })).await
1709 }
1710
1711 async fn provider_complete(&self, params: ProviderCompleteParams) -> Result<ProviderCompleteResult, String> {
1712 let value = tokio::time::timeout(
1713 std::time::Duration::from_secs(60),
1714 self.call("provider.complete", serde_json::to_value(params).map_err(|e| e.to_string())?),
1715 )
1716 .await
1717 .map_err(|_| format!("Extension '{}' provider.complete timed out", self.id))??;
1718 let result: ProviderCompleteResult = serde_json::from_value(value)
1719 .map_err(|e| format!("Invalid provider.complete response from extension '{}': {}", self.id, e))?;
1720 if result.content.is_empty() {
1721 return Err(format!("Extension '{}' provider.complete returned empty content", self.id));
1722 }
1723 Ok(result)
1724 }
1725
1726 async fn provider_stream(
1727 &self,
1728 params: ProviderCompleteParams,
1729 sink: tokio::sync::mpsc::UnboundedSender<ProviderStreamEvent>,
1730 ) -> Result<ProviderCompleteResult, String> {
1731 let (sub_id, mut rx) = self.subscribe_notifications().await;
1734 let params_value =
1735 serde_json::to_value(params).map_err(|e| e.to_string())?;
1736
1737 let extension_id = self.id.clone();
1738 let stream_future = async {
1739 let mut call_fut = Box::pin(self.call("provider.stream", params_value));
1740 let mut sink_open = true;
1741 let response = loop {
1742 tokio::select! {
1743 response = &mut call_fut => break response,
1744 Some(frame) = rx.recv() => {
1745 Self::forward_provider_stream_frame(
1746 &extension_id, &sink, &mut sink_open, frame,
1747 );
1748 }
1749 }
1750 };
1751 self.unsubscribe_notifications(sub_id).await;
1756 while let Some(frame) = rx.recv().await {
1757 Self::forward_provider_stream_frame(
1758 &extension_id, &sink, &mut sink_open, frame,
1759 );
1760 }
1761 response
1762 };
1763
1764 let outcome = tokio::time::timeout(
1765 std::time::Duration::from_secs(60),
1766 stream_future,
1767 )
1768 .await;
1769
1770 self.unsubscribe_notifications(sub_id).await;
1773
1774 let value = outcome
1775 .map_err(|_| format!("Extension '{}' provider.stream timed out", self.id))??;
1776
1777 let result: ProviderCompleteResult = serde_json::from_value(value)
1778 .map_err(|e| {
1779 format!("Invalid provider.stream response from extension '{}': {}", self.id, e)
1780 })?;
1781 Ok(result)
1784 }
1785
1786 async fn invoke_command(
1787 &self,
1788 command: &str,
1789 args: Vec<String>,
1790 request_id: &str,
1791 sink: tokio::sync::mpsc::UnboundedSender<crate::extensions::runtime::InvokeCommandEvent>,
1792 ) -> Result<Value, String> {
1793 let (sub_id, mut rx) = self.subscribe_notifications().await;
1795 let params = serde_json::json!({
1796 "command": command,
1797 "args": args,
1798 "request_id": request_id,
1799 });
1800
1801 let extension_id = self.id.clone();
1802 let request_id_owned = request_id.to_string();
1803 let invoke_future = async {
1804 let mut call_fut = Box::pin(self.call("command.invoke", params));
1805 let mut sink_open = true;
1806 let response = loop {
1807 tokio::select! {
1808 response = &mut call_fut => break response,
1809 Some(frame) = rx.recv() => {
1810 let _ = Self::forward_invoke_command_frame(
1811 &extension_id, &request_id_owned, &sink, &mut sink_open, frame,
1812 );
1813 }
1814 }
1815 };
1816 self.unsubscribe_notifications(sub_id).await;
1820 while let Ok(frame) = rx.try_recv() {
1821 let _ = Self::forward_invoke_command_frame(
1822 &extension_id, &request_id_owned, &sink, &mut sink_open, frame,
1823 );
1824 }
1825 response
1826 };
1827
1828 let outcome = tokio::time::timeout(
1829 std::time::Duration::from_secs(120),
1830 invoke_future,
1831 )
1832 .await;
1833
1834 self.unsubscribe_notifications(sub_id).await;
1837
1838 outcome
1839 .map_err(|_| format!("Extension '{}' command.invoke timed out", self.id))?
1840 }
1841
1842 async fn handle(&self, event: &HookEvent) -> HookResult {
1843 let params = serde_json::to_value(event).unwrap_or(Value::Null);
1844 match tokio::time::timeout(std::time::Duration::from_secs(5), self.call("hook.handle", params)).await {
1845 Ok(Ok(value)) => match serde_json::from_value(value.clone()) {
1846 Ok(result) => result,
1847 Err(error) => {
1848 tracing::warn!(
1849 extension = %self.id,
1850 error = %error,
1851 response = %value,
1852 "Extension hook handler returned invalid result",
1853 );
1854 if value.get("action").and_then(Value::as_str) == Some("modify") {
1855 HookResult::Block {
1856 reason: "Extension returned malformed modify result".to_string(),
1857 }
1858 } else {
1859 HookResult::Continue
1860 }
1861 }
1862 },
1863 Ok(Err(e)) => {
1864 tracing::warn!(
1865 extension = %self.id,
1866 error = %e,
1867 "Extension hook handler failed — continuing",
1868 );
1869 HookResult::Continue
1870 }
1871 Err(_) => {
1872 tracing::warn!(
1873 extension = %self.id,
1874 timeout_secs = 5,
1875 "Extension hook handler timed out — continuing",
1876 );
1877 HookResult::Continue
1878 }
1879 }
1880 }
1881
1882 async fn get_info(&self) -> Result<crate::extensions::info::PluginInfo, String> {
1883 let value = tokio::time::timeout(
1884 std::time::Duration::from_secs(5),
1885 self.call("info.get", Value::Null),
1886 )
1887 .await
1888 .map_err(|_| format!("Extension '{}' info.get timed out", self.id))??;
1889 serde_json::from_value(value)
1890 .map_err(|e| format!("Invalid info.get response from extension '{}': {}", self.id, e))
1891 }
1892
1893 async fn sidecar_spawn_args(
1894 &self,
1895 ) -> Result<crate::sidecar::spawn::SidecarSpawnArgs, String> {
1896 let value = tokio::time::timeout(
1897 std::time::Duration::from_secs(5),
1898 self.call("sidecar.spawn_args", Value::Null),
1899 )
1900 .await
1901 .map_err(|_| format!("Extension '{}' sidecar.spawn_args timed out", self.id))??;
1902 serde_json::from_value(value).map_err(|e| {
1903 format!(
1904 "Invalid sidecar.spawn_args response from extension '{}': {}",
1905 self.id, e
1906 )
1907 })
1908 }
1909
1910 async fn settings_editor_open(&self, category: &str, field: &str) -> Result<Value, String> {
1911 let params = crate::extensions::settings_editor::SettingsEditorOpenParams {
1912 category: category.to_string(),
1913 field: field.to_string(),
1914 };
1915 tokio::time::timeout(
1916 std::time::Duration::from_secs(5),
1917 self.call("settings.editor.open", serde_json::to_value(params).map_err(|e| e.to_string())?),
1918 )
1919 .await
1920 .map_err(|_| format!("Extension '{}' settings.editor.open timed out", self.id))?
1921 }
1922
1923 async fn settings_editor_key(&self, category: &str, field: &str, key: &str) -> Result<Value, String> {
1924 let mut params = serde_json::to_value(crate::extensions::settings_editor::SettingsEditorKeyParams {
1925 key: key.to_string(),
1926 }).map_err(|e| e.to_string())?;
1927 if let Some(obj) = params.as_object_mut() {
1928 obj.insert("category".to_string(), Value::String(category.to_string()));
1929 obj.insert("field".to_string(), Value::String(field.to_string()));
1930 }
1931 tokio::time::timeout(
1932 std::time::Duration::from_secs(5),
1933 self.call("settings.editor.key", params),
1934 )
1935 .await
1936 .map_err(|_| format!("Extension '{}' settings.editor.key timed out", self.id))?
1937 }
1938
1939 async fn settings_editor_commit(&self, category: &str, field: &str, value: Value) -> Result<Value, String> {
1940 let params = serde_json::json!({
1941 "category": category,
1942 "field": field,
1943 "value": value,
1944 });
1945 tokio::time::timeout(
1946 std::time::Duration::from_secs(5),
1947 self.call("settings.editor.commit", params),
1948 )
1949 .await
1950 .map_err(|_| format!("Extension '{}' settings.editor.commit timed out", self.id))?
1951 }
1952
1953 async fn shutdown(&self) {
1954 let _ = tokio::time::timeout(
1955 std::time::Duration::from_millis(500),
1956 self.call("shutdown", Value::Null),
1957 )
1958 .await;
1959
1960 tokio::time::sleep(std::time::Duration::from_millis(500)).await;
1961 let mut state_guard = self.state.lock().await;
1962 if let Some(state) = state_guard.take() {
1963 state.reader_handle.abort();
1964 let mut child = state.child;
1965 let _ = child.kill().await;
1966 }
1967 self.inbox.notification_sinks.lock().await.clear();
1969 self.inbox
1970 .fail_all_pending("transport closed: extension shutdown")
1971 .await;
1972 }
1973
1974 async fn subscribe_notifications(&self) -> (usize, tokio::sync::mpsc::UnboundedReceiver<NotificationFrame>) {
1975 ProcessExtension::subscribe_notifications(self).await
1976 }
1977
1978 async fn restart_count(&self) -> usize {
1979 self.restart_count()
1980 }
1981
1982 async fn health(&self) -> ExtensionHealth {
1983 let consecutive = self.restart_count.load(Ordering::Relaxed);
1987 let lifetime = self.total_restarts.load(Ordering::Relaxed);
1988 let max = self.restart_policy.max_attempts as usize;
1989 if consecutive >= max {
1990 ExtensionHealth::Failed
1991 } else if lifetime > 0 {
1992 let state_alive = self.state.try_lock().map(|g| g.is_some()).unwrap_or(true);
1996 if state_alive {
1997 ExtensionHealth::Degraded
1998 } else {
1999 ExtensionHealth::Restarting
2000 }
2001 } else {
2002 ExtensionHealth::Running
2003 }
2004 }
2005}
2006
#[cfg(test)]
mod stream_event_tests {
    //! Unit tests for `parse_provider_stream_event`: the accepted payload shapes and
    //! the rejections for malformed payloads.

    use super::*;
    use serde_json::json;

    /// Parses `payload`, failing the test if the parser rejects it.
    fn parse_ok(payload: serde_json::Value) -> ProviderStreamEvent {
        parse_provider_stream_event(&payload).unwrap()
    }

    /// Parses `payload`, failing the test if the parser accepts it; returns the
    /// error text.
    fn parse_err(payload: serde_json::Value) -> String {
        parse_provider_stream_event(&payload).unwrap_err().to_string()
    }

    /// Returns `true` when the parser rejects `payload`.
    fn rejects(payload: serde_json::Value) -> bool {
        parse_provider_stream_event(&payload).is_err()
    }

    #[test]
    fn parses_text_delta_with_delta_key() {
        let expected = ProviderStreamEvent::TextDelta { text: "hi".into() };
        assert_eq!(parse_ok(json!({"type": "text", "delta": "hi"})), expected);
    }

    #[test]
    fn parses_text_delta_with_text_key() {
        let expected = ProviderStreamEvent::TextDelta { text: "hi".into() };
        assert_eq!(parse_ok(json!({"type": "text", "text": "hi"})), expected);
    }

    #[test]
    fn parses_thinking_delta() {
        // Both the `delta` and the `text` spelling must be accepted.
        for key in ["delta", "text"] {
            let mut payload = json!({"type": "thinking"});
            payload[key] = json!("hmm");
            assert_eq!(
                parse_ok(payload),
                ProviderStreamEvent::ThinkingDelta { text: "hmm".into() }
            );
        }
    }

    #[test]
    fn parses_tool_use() {
        let payload = json!({
            "type": "tool_use",
            "id": "t1",
            "name": "echo",
            "input": {"x": 1}
        });
        let expected = ProviderStreamEvent::ToolUse {
            id: "t1".into(),
            name: "echo".into(),
            input: json!({"x": 1}),
        };
        assert_eq!(parse_ok(payload), expected);
    }

    #[test]
    fn tool_use_input_defaults_to_empty_object() {
        let expected = ProviderStreamEvent::ToolUse {
            id: "t1".into(),
            name: "echo".into(),
            input: json!({}),
        };
        assert_eq!(
            parse_ok(json!({"type": "tool_use", "id": "t1", "name": "echo"})),
            expected
        );
    }

    #[test]
    fn parses_usage_strips_type() {
        let expected = ProviderStreamEvent::Usage {
            usage: json!({"input_tokens": 5, "output_tokens": 7}),
        };
        assert_eq!(
            parse_ok(json!({"type": "usage", "input_tokens": 5, "output_tokens": 7})),
            expected
        );
    }

    #[test]
    fn parses_error() {
        let expected = ProviderStreamEvent::Error { message: "boom".into() };
        assert_eq!(parse_ok(json!({"type": "error", "message": "boom"})), expected);
    }

    #[test]
    fn parses_done() {
        assert_eq!(parse_ok(json!({"type": "done"})), ProviderStreamEvent::Done);
    }

    #[test]
    fn nested_event_shape_matches_flat() {
        let from_flat = parse_ok(json!({"type": "text", "delta": "hi"}));
        let from_nested = parse_ok(json!({"event": {"type": "text", "delta": "hi"}}));
        assert_eq!(from_flat, from_nested);
    }

    #[test]
    fn missing_type_errors() {
        let err = parse_err(json!({"delta": "hi"}));
        assert!(err.contains("missing type"), "got: {err}");
    }

    #[test]
    fn unknown_type_errors_with_type() {
        let err = parse_err(json!({"type": "wat"}));
        assert!(err.contains("wat"), "got: {err}");
    }

    #[test]
    fn tool_use_missing_id_errors() {
        let err = parse_err(json!({"type": "tool_use", "name": "echo"}));
        assert!(err.contains("id"), "got: {err}");
    }

    #[test]
    fn tool_use_missing_name_errors() {
        let err = parse_err(json!({"type": "tool_use", "id": "t1"}));
        assert!(err.contains("name"), "got: {err}");
    }

    #[test]
    fn tool_use_empty_id_errors() {
        assert!(rejects(json!({"type": "tool_use", "id": "", "name": "echo"})));
    }

    #[test]
    fn tool_use_empty_name_errors() {
        assert!(rejects(json!({"type": "tool_use", "id": "t1", "name": ""})));
    }

    #[test]
    fn tool_use_non_object_input_errors() {
        let err = parse_err(json!({"type": "tool_use", "id": "t1", "name": "echo", "input": "nope"}));
        assert!(err.contains("input"), "got: {err}");
    }

    #[test]
    fn text_missing_delta_and_text_errors() {
        let err = parse_err(json!({"type": "text"}));
        assert!(err.contains("delta") || err.contains("text"), "got: {err}");
    }

    #[test]
    fn error_missing_message_errors() {
        assert!(rejects(json!({"type": "error"})));
    }

    #[test]
    fn error_empty_message_errors() {
        assert!(rejects(json!({"type": "error", "message": ""})));
    }
}
2180
#[cfg(test)]
mod restart_policy_tests {
    //! Checks the default restart policy and the `with_restart_policy` override on a
    //! real `/bin/cat` extension process.

    use super::*;

    #[tokio::test]
    async fn restart_policy_default_max_attempts_is_3() {
        let extension = ProcessExtension::spawn("policy-test", "/bin/cat", &[])
            .await
            .expect("spawn /bin/cat");
        let max_attempts = extension.restart_policy.max_attempts;
        assert_eq!(max_attempts, 3);
        extension.shutdown().await;
    }

    #[tokio::test]
    async fn with_restart_policy_overrides_default() {
        let policy = RestartPolicy {
            max_attempts: 7,
            ..RestartPolicy::default()
        };
        let extension = ProcessExtension::spawn("policy-test-override", "/bin/cat", &[])
            .await
            .expect("spawn /bin/cat")
            .with_restart_policy(policy);
        assert_eq!(extension.restart_policy.max_attempts, 7);
        extension.shutdown().await;
    }
}
2212
#[cfg(test)]
mod capture_validator_tests {
    //! Tests for `validate_capability`. They cover the basic shape checks on a
    //! declaration, and the rule that every declared permission must be known and
    //! granted, independent of the capability kind.

    use super::*;
    use crate::extensions::permissions::{Permission, PermissionSet};

    /// Builds a `PermissionSet` containing exactly `grants`.
    fn perms_with(grants: &[Permission]) -> PermissionSet {
        grants.iter().fold(PermissionSet::new(), |mut set, permission| {
            set.grant(*permission);
            set
        })
    }

    /// Builds a capability declaration with the given `kind`, `name` and permission
    /// strings, and `null` params.
    fn cap(kind: &str, name: &str, perms: &[&str]) -> CapabilityDeclaration {
        let permissions = perms.iter().map(|perm| perm.to_string()).collect();
        CapabilityDeclaration {
            kind: kind.into(),
            name: name.into(),
            permissions,
            params: serde_json::Value::Null,
        }
    }

    #[test]
    fn capability_validator_rejects_empty_kind() {
        let declaration = cap(" ", "Sample", &["audio.input"]);
        let granted = perms_with(&[Permission::AudioInput]);
        let message = validate_capability(&declaration, &granted).unwrap_err();
        assert!(message.contains("kind"), "got: {}", message);
    }

    #[test]
    fn capability_validator_rejects_empty_name() {
        let declaration = cap("capture", " ", &["audio.input"]);
        let granted = perms_with(&[Permission::AudioInput]);
        let message = validate_capability(&declaration, &granted).unwrap_err();
        assert!(message.contains("name"), "got: {}", message);
    }

    #[test]
    fn capability_validator_rejects_unknown_permission_string() {
        let declaration = cap("capture", "Sample", &["audio.telepathy"]);
        let granted = perms_with(&[Permission::AudioInput, Permission::AudioOutput]);
        let message = validate_capability(&declaration, &granted).unwrap_err();
        let mentions_unknown = message.contains("unknown permission");
        let mentions_name = message.contains("audio.telepathy");
        assert!(mentions_unknown && mentions_name, "got: {}", message);
    }

    #[test]
    fn capability_validator_requires_every_declared_permission() {
        let declaration = cap("capture", "Sample", &["audio.input"]);
        let granted = perms_with(&[]);
        let message = validate_capability(&declaration, &granted).unwrap_err();
        let mentions_permission = message.contains("audio.input");
        let mentions_missing = message.contains("not granted");
        assert!(mentions_permission && mentions_missing, "got: {}", message);
    }

    #[test]
    fn capability_validator_accepts_when_all_permissions_granted() {
        let declaration = cap("capture", "Sample", &["audio.input", "audio.output"]);
        let granted = perms_with(&[Permission::AudioInput, Permission::AudioOutput]);
        validate_capability(&declaration, &granted).expect("should validate");
    }

    #[test]
    fn capability_validator_accepts_no_permissions() {
        // A capability that declares no permissions needs no grants.
        let declaration = cap("ocr", "Tesseract", &[]);
        let granted = perms_with(&[]);
        validate_capability(&declaration, &granted).expect("should validate");
    }

    #[test]
    fn capability_validator_does_not_branch_on_kind() {
        // Validation depends only on permissions, never on the kind string.
        let granted = perms_with(&[Permission::AudioInput]);
        let kinds = ["capture", "ocr", "agent", "foot_pedal", "eeg"];
        for declaration in kinds.iter().map(|kind| cap(kind, "Anything", &["audio.input"])) {
            validate_capability(&declaration, &granted).expect("should validate");
        }
    }
}
2305
/// Unit tests for `ProcessExtension::forward_invoke_command_frame`, which
/// routes notification frames received while a command is being invoked onto
/// the caller's `InvokeCommandEvent` channel and reports (via its `bool`
/// return) whether the command's Done marker has been observed.
#[cfg(test)]
mod invoke_command_dispatch_tests {
    use super::*;
    use crate::extensions::commands::CommandOutputEvent;
    use crate::extensions::runtime::InvokeCommandEvent;
    use crate::extensions::tasks::{TaskEvent, TaskKind};
    use serde_json::json;
    use tokio::sync::mpsc;

    /// Request id that every test forwards frames on behalf of.
    const REQUEST_ID: &str = "r1";

    /// Builds a notification frame with the given method name and params.
    fn frame(method: &str, params: serde_json::Value) -> NotificationFrame {
        NotificationFrame {
            method: method.to_string(),
            params,
        }
    }

    /// Forwards a single frame for `REQUEST_ID` from extension `"ext"` and
    /// returns the forwarder's result (`true` once the command's Done marker
    /// has been seen).
    fn forward(
        tx: &mpsc::UnboundedSender<InvokeCommandEvent>,
        open: &mut bool,
        f: NotificationFrame,
    ) -> bool {
        ProcessExtension::forward_invoke_command_frame("ext", REQUEST_ID, tx, open, f)
    }

    /// Collects every event currently buffered in `rx` without blocking.
    fn drain(rx: &mut mpsc::UnboundedReceiver<InvokeCommandEvent>) -> Vec<InvokeCommandEvent> {
        std::iter::from_fn(|| rx.try_recv().ok()).collect()
    }

    #[test]
    fn forwards_mixed_event_stream_in_order() {
        let (tx, mut rx) = mpsc::unbounded_channel::<InvokeCommandEvent>();
        let mut open = true;
        // Interleaved command output and task progress; order must be kept.
        let frames = [
            frame(
                "command.output",
                json!({"request_id":"r1","event":{"kind":"text","content":"A"}}),
            ),
            frame(
                "task.start",
                json!({"id":"dl","label":"Downloading","kind":"download"}),
            ),
            frame(
                "task.update",
                json!({"id":"dl","current":50,"total":100}),
            ),
            frame(
                "command.output",
                json!({"request_id":"r1","event":{"kind":"system","content":"working"}}),
            ),
            frame("task.done", json!({"id":"dl"})),
            frame(
                "command.output",
                json!({"request_id":"r1","event":{"kind":"done"}}),
            ),
        ];

        let mut saw_done = false;
        for f in frames {
            saw_done |= ProcessExtension::forward_invoke_command_frame(
                "ext-test", REQUEST_ID, &tx, &mut open, f,
            );
        }
        drop(tx);
        assert!(saw_done, "should have observed the command Done marker");

        let events = drain(&mut rx);
        assert_eq!(events.len(), 6);
        assert_eq!(
            events[0],
            InvokeCommandEvent::Output(CommandOutputEvent::Text { content: "A".into() })
        );
        assert!(matches!(
            events[1],
            InvokeCommandEvent::Task(TaskEvent::Start { kind: TaskKind::Download, .. })
        ));
        assert!(matches!(
            events[2],
            InvokeCommandEvent::Task(TaskEvent::Update { .. })
        ));
        assert!(matches!(
            events[3],
            InvokeCommandEvent::Output(CommandOutputEvent::System { .. })
        ));
        assert!(matches!(
            events[4],
            InvokeCommandEvent::Task(TaskEvent::Done { error: None, .. })
        ));
        assert_eq!(events[5], InvokeCommandEvent::Output(CommandOutputEvent::Done));
    }

    #[test]
    fn ignores_command_output_for_unrelated_request_id() {
        let (tx, mut rx) = mpsc::unbounded_channel::<InvokeCommandEvent>();
        let mut open = true;

        let text = frame(
            "command.output",
            json!({"request_id":"other","event":{"kind":"text","content":"x"}}),
        );
        assert!(!forward(&tx, &mut open, text));

        // A Done marker addressed to another request must not end this one.
        let foreign_done = frame(
            "command.output",
            json!({"request_id":"other","event":{"kind":"done"}}),
        );
        assert!(
            !forward(&tx, &mut open, foreign_done),
            "Done for an unrelated request must not be reported"
        );

        drop(tx);
        assert!(drain(&mut rx).is_empty());
    }

    #[test]
    fn skips_malformed_command_output_without_aborting() {
        let (tx, mut rx) = mpsc::unbounded_channel::<InvokeCommandEvent>();
        let mut open = true;

        // `event` carries no `kind`, so nothing should be forwarded for it.
        let malformed = frame("command.output", json!({"request_id":"r1","event":{}}));
        assert!(
            !forward(&tx, &mut open, malformed),
            "a malformed frame must not be reported as Done"
        );

        // A well-formed Done after the bad frame must still be recognised.
        let done = frame(
            "command.output",
            json!({"request_id":"r1","event":{"kind":"done"}}),
        );
        assert!(forward(&tx, &mut open, done));

        drop(tx);
        assert_eq!(
            drain(&mut rx),
            vec![InvokeCommandEvent::Output(CommandOutputEvent::Done)]
        );
    }

    #[test]
    fn task_events_pass_through_regardless_of_request_id() {
        let (tx, mut rx) = mpsc::unbounded_channel::<InvokeCommandEvent>();
        let mut open = true;

        // `task.*` params carry no `request_id`; they are forwarded as-is.
        let log = frame("task.log", json!({"id":"abc","line":"..."}));
        assert!(!forward(&tx, &mut open, log));

        drop(tx);
        match rx.try_recv().unwrap() {
            InvokeCommandEvent::Task(TaskEvent::Log { id, line }) => {
                assert_eq!(id, "abc");
                assert_eq!(line, "...");
            }
            other => panic!("unexpected: {other:?}"),
        }
    }

    #[test]
    fn unrelated_methods_are_dropped() {
        let (tx, mut rx) = mpsc::unbounded_channel::<InvokeCommandEvent>();
        let mut open = true;

        let provider_event = frame("provider.stream.event", json!({"type":"text","delta":"x"}));
        assert!(!forward(&tx, &mut open, provider_event));

        drop(tx);
        assert!(drain(&mut rx).is_empty());
    }
}