1use std::path::PathBuf;
7use std::process::Stdio;
8use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
9use std::sync::Arc;
10use std::collections::{HashMap, HashSet};
11
12use async_trait::async_trait;
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
16use tokio::process::{Child, ChildStdin, ChildStdout, Command};
17use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
18use tokio::task::JoinHandle;
19
20use super::{ExtensionHandler, ExtensionHealth, RestartPolicy};
21use crate::extensions::hooks::events::{HookEvent, HookResult};
22use crate::extensions::manifest::CURRENT_EXTENSION_PROTOCOL_VERSION;
23
/// JSON-RPC 2.0 request envelope addressed to an extension.
///
/// NOTE(review): not constructed in the visible part of this file;
/// presumably built by the outbound call path — confirm.
#[derive(Serialize)]
struct JsonRpcRequest {
    /// Protocol marker; presumably always `"2.0"` — confirm at the construction site.
    jsonrpc: &'static str,
    /// Method name to invoke on the extension.
    method: String,
    /// Method parameters, serialized verbatim.
    params: Value,
    /// Numeric request id; the stdout reader correlates responses by this
    /// id (`Value::as_u64`) against the inbox's pending map.
    id: u64,
}
31
32
/// Parameters of the `initialize` request sent by `ProcessExtension::initialize`.
#[derive(Serialize)]
struct InitializeParams {
    /// Host version, taken from `CARGO_PKG_VERSION` at compile time.
    synaps_version: &'static str,
    /// Protocol version the host speaks (`CURRENT_EXTENSION_PROTOCOL_VERSION`).
    extension_protocol_version: u32,
    /// Id of the extension being initialized.
    plugin_id: String,
    /// Plugin root directory (explicit root, else the process cwd), lossily
    /// converted to a string; `None` when neither is known.
    plugin_root: Option<String>,
    /// Opaque plugin configuration forwarded to the extension.
    config: Value,
}
41
/// A tool declared by an extension in the `capabilities.tools` list of its
/// `initialize` response.
///
/// `ProcessExtension::validate_registered_tool_specs` requires a valid,
/// unique `name`, a non-empty `description` and an object `input_schema`.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct RegisteredExtensionToolSpec {
    /// Tool name (validated after trimming).
    pub name: String,
    /// Human-readable description; must not be blank.
    pub description: String,
    /// JSON Schema for the tool input; must be a JSON object.
    pub input_schema: Value,
}
48
/// A model provider declared by an extension in the `capabilities.providers`
/// list of its `initialize` response.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct RegisteredProviderSpec {
    /// Provider id (validated as an id segment after trimming).
    pub id: String,
    /// Display name; must not be blank.
    pub display_name: String,
    /// Description; must not be blank.
    pub description: String,
    /// Models offered. Defaults to empty when omitted on the wire, but
    /// validation rejects providers with no models.
    #[serde(default)]
    pub models: Vec<RegisteredProviderModelSpec>,
    /// Optional schema describing provider configuration.
    #[serde(default)]
    pub config_schema: Option<Value>,
}
59
/// One model offered by a [`RegisteredProviderSpec`].
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct RegisteredProviderModelSpec {
    /// Model id; must be a valid, provider-unique id segment after trimming.
    pub id: String,
    /// Optional display name.
    #[serde(default)]
    pub display_name: Option<String>,
    /// Free-form capability description; `null` when omitted.
    #[serde(default)]
    pub capabilities: Value,
    /// Optional context window size. NOTE(review): unit (tokens?) not stated here — confirm.
    #[serde(default)]
    pub context_window: Option<u64>,
}
70
/// Request payload for an extension provider completion
/// (`ExtensionHandler::provider_complete`).
///
/// `complete_provider_with_tools` appends assistant and tool-result turns to
/// `messages` between rounds.
#[derive(Debug, Clone, Serialize)]
pub struct ProviderCompleteParams {
    /// Id of the provider registered by the extension.
    pub provider_id: String,
    /// Model id within that provider.
    pub model_id: String,
    /// NOTE(review): distinct from `model_id`; presumably the fully
    /// qualified model name — confirm with the caller.
    pub model: String,
    /// Conversation messages as JSON objects (`role` / `content`).
    pub messages: Vec<Value>,
    /// Optional system prompt.
    pub system_prompt: Option<String>,
    /// Tool definitions made available to the model.
    pub tools: Vec<Value>,
    /// Optional sampling temperature.
    pub temperature: Option<f32>,
    /// Optional output token cap.
    pub max_tokens: Option<u32>,
    /// Thinking budget. NOTE(review): unit and meaning of 0 not visible here — confirm.
    pub thinking_budget: u32,
}
83
/// Response of an extension provider completion.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ProviderCompleteResult {
    /// Content blocks; blocks with `"type": "tool_use"` are extracted by
    /// `extract_provider_tool_uses`.
    pub content: Vec<Value>,
    /// Optional stop reason reported by the provider.
    #[serde(default)]
    pub stop_reason: Option<String>,
    /// Optional free-form usage/accounting data.
    #[serde(default)]
    pub usage: Option<Value>,
}
92
/// A validated tool call requested by an extension provider.
#[derive(Debug, Clone, PartialEq)]
pub struct ProviderToolUse {
    /// Provider-assigned call id, echoed back as `tool_use_id`.
    pub id: String,
    /// Tool name as exposed to the provider.
    pub name: String,
    /// Tool arguments; always a JSON object.
    pub input: Value,
}
99
/// One incremental event from a streaming extension provider, as produced
/// by `parse_provider_stream_event`.
#[derive(Debug, Clone, PartialEq)]
pub enum ProviderStreamEvent {
    /// Assistant text chunk (`type: "text"`).
    TextDelta { text: String },
    /// Thinking/reasoning text chunk (`type: "thinking"`).
    ThinkingDelta { text: String },
    /// Tool call request (`type: "tool_use"`).
    ToolUse {
        /// Provider-assigned call id.
        id: String,
        /// Tool name.
        name: String,
        /// Tool arguments; always a JSON object.
        input: Value,
    },
    /// Usage payload: the event object with its `type` key removed.
    Usage { usage: Value },
    /// Provider-reported error (`type: "error"`).
    Error { message: String },
    /// End of stream (`type: "done"`).
    Done,
}
120
121pub fn parse_provider_stream_event(params: &Value) -> Result<ProviderStreamEvent, String> {
127 let inner = match params.get("event") {
128 Some(ev) => ev,
129 None => params,
130 };
131 let obj = inner
132 .as_object()
133 .ok_or_else(|| "provider stream event must be a JSON object".to_string())?;
134
135 let ty = obj
136 .get("type")
137 .and_then(Value::as_str)
138 .ok_or_else(|| "provider stream event missing type".to_string())?;
139
140 match ty {
141 "text" => {
142 let text = obj
143 .get("delta")
144 .or_else(|| obj.get("text"))
145 .and_then(Value::as_str)
146 .ok_or_else(|| {
147 "provider stream text event missing 'delta' or 'text'".to_string()
148 })?;
149 Ok(ProviderStreamEvent::TextDelta {
150 text: text.to_string(),
151 })
152 }
153 "thinking" => {
154 let text = obj
155 .get("delta")
156 .or_else(|| obj.get("text"))
157 .and_then(Value::as_str)
158 .ok_or_else(|| {
159 "provider stream thinking event missing 'delta' or 'text'".to_string()
160 })?;
161 Ok(ProviderStreamEvent::ThinkingDelta {
162 text: text.to_string(),
163 })
164 }
165 "tool_use" => {
166 let id = obj
167 .get("id")
168 .and_then(Value::as_str)
169 .ok_or_else(|| "provider stream tool_use missing id".to_string())?;
170 if id.is_empty() {
171 return Err("provider stream tool_use id must be non-empty".to_string());
172 }
173 let name = obj
174 .get("name")
175 .and_then(Value::as_str)
176 .ok_or_else(|| "provider stream tool_use missing name".to_string())?;
177 if name.is_empty() {
178 return Err("provider stream tool_use name must be non-empty".to_string());
179 }
180 let input = match obj.get("input") {
181 None => Value::Object(Default::default()),
182 Some(v) if v.is_object() => v.clone(),
183 Some(_) => {
184 return Err(
185 "provider stream tool_use input must be a JSON object".to_string()
186 );
187 }
188 };
189 Ok(ProviderStreamEvent::ToolUse {
190 id: id.to_string(),
191 name: name.to_string(),
192 input,
193 })
194 }
195 "usage" => {
196 let mut clone = obj.clone();
197 clone.remove("type");
198 Ok(ProviderStreamEvent::Usage {
199 usage: Value::Object(clone),
200 })
201 }
202 "error" => {
203 let message = obj
204 .get("message")
205 .and_then(Value::as_str)
206 .ok_or_else(|| "provider stream error missing message".to_string())?;
207 if message.is_empty() {
208 return Err("provider stream error message must be non-empty".to_string());
209 }
210 Ok(ProviderStreamEvent::Error {
211 message: message.to_string(),
212 })
213 }
214 "done" => Ok(ProviderStreamEvent::Done),
215 other => Err(format!("unknown provider stream event type: {other}")),
216 }
217}
218
219pub async fn execute_provider_tool_use(
220 registry: &crate::ToolRegistry,
221 hook_bus: &Arc<crate::extensions::hooks::HookBus>,
222 tool_use: ProviderToolUse,
223 ctx: crate::ToolContext,
224 max_tool_output: usize,
225) -> Value {
226 let tool_id = tool_use.id;
227 let tool_name = tool_use.name;
228 let input = tool_use.input;
229
230 let Some(tool) = registry.get(&tool_name).cloned() else {
231 return serde_json::json!({
232 "type": "tool_result",
233 "tool_use_id": tool_id,
234 "content": format!("Unknown tool: {}", tool_name),
235 "is_error": true,
236 });
237 };
238
239 let runtime_name = registry.runtime_name_for_api(&tool_name).to_string();
240 let input = registry.translate_input_for_api_tool(&tool_name, input);
241 let decision = crate::runtime::resolve_before_tool_call_decision(
242 input.clone(),
243 crate::runtime::emit_before_tool_call(
244 hook_bus,
245 &tool_name,
246 Some(&runtime_name),
247 input.clone(),
248 ).await,
249 ctx.capabilities.secret_prompt.as_ref(),
250 false,
251 ).await;
252
253 let crate::runtime::BeforeToolCallDecision::Continue { input } = decision else {
254 let crate::runtime::BeforeToolCallDecision::Block { reason } = decision else { unreachable!() };
255 return serde_json::json!({
256 "type": "tool_result",
257 "tool_use_id": tool_id,
258 "content": format!("Tool call blocked by extension: {}", reason),
259 "is_error": true,
260 });
261 };
262
263 let input_for_hook = input.clone();
264 let (result, is_error) = match tool.execute(input, ctx).await {
265 Ok(output) => (output, false),
266 Err(error) => (format!("Tool execution failed: {}", error), true),
267 };
268 let _ = crate::runtime::emit_after_tool_call(
269 hook_bus,
270 &tool_name,
271 Some(&runtime_name),
272 input_for_hook,
273 result.clone(),
274 ).await;
275
276 let mut response = serde_json::json!({
277 "type": "tool_result",
278 "tool_use_id": tool_id,
279 "content": crate::truncate_str(&result, max_tool_output).to_string(),
280 });
281 if is_error {
282 response["is_error"] = serde_json::json!(true);
283 }
284 response
285}
286
287pub async fn complete_provider_with_tools<F>(
288 handler: Arc<dyn ExtensionHandler>,
289 mut params: ProviderCompleteParams,
290 registry: &crate::ToolRegistry,
291 hook_bus: &Arc<crate::extensions::hooks::HookBus>,
292 mut context_factory: F,
293 max_tool_output: usize,
294 max_iterations: usize,
295) -> Result<ProviderCompleteResult, String>
296where
297 F: FnMut() -> crate::ToolContext,
298{
299 let max_iterations = max_iterations.max(1);
300 for iteration in 0..max_iterations {
301 let result = handler.provider_complete(params.clone()).await?;
302 let tool_uses = extract_provider_tool_uses(&result.content)?;
303 if tool_uses.is_empty() {
304 return Ok(result);
305 }
306 if iteration + 1 == max_iterations {
307 return Err(format!(
308 "extension provider '{}' exceeded provider tool-use iteration limit ({})",
309 handler.id(),
310 max_iterations,
311 ));
312 }
313
314 let assistant_content = result.content.clone();
315 params.messages.push(serde_json::json!({
316 "role": "assistant",
317 "content": assistant_content,
318 }));
319
320 let mut tool_results = Vec::with_capacity(tool_uses.len());
321 for tool_use in tool_uses {
322 tool_results.push(execute_provider_tool_use(
323 registry,
324 hook_bus,
325 tool_use,
326 context_factory(),
327 max_tool_output,
328 ).await);
329 }
330 params.messages.push(serde_json::json!({
331 "role": "user",
332 "content": tool_results,
333 }));
334 }
335 Err(format!(
336 "extension provider '{}' exceeded provider tool-use iteration limit ({})",
337 handler.id(),
338 max_iterations,
339 ))
340}
341
342pub fn extract_provider_tool_uses(content: &[Value]) -> Result<Vec<ProviderToolUse>, String> {
343 let mut tool_uses = Vec::new();
344 for block in content {
345 if block.get("type").and_then(Value::as_str) != Some("tool_use") {
346 continue;
347 }
348 let id = block
349 .get("id")
350 .and_then(Value::as_str)
351 .ok_or_else(|| "provider tool_use missing id".to_string())?;
352 let name = block
353 .get("name")
354 .and_then(Value::as_str)
355 .ok_or_else(|| "provider tool_use missing name".to_string())?;
356 if id.trim().is_empty() {
357 return Err("provider tool_use id is empty".to_string());
358 }
359 if name.trim().is_empty() {
360 return Err("provider tool_use name is empty".to_string());
361 }
362 let input = block
363 .get("input")
364 .cloned()
365 .unwrap_or_else(|| serde_json::json!({}));
366 if !input.is_object() {
367 return Err(format!(
368 "provider tool_use '{}' input must be a JSON object",
369 id
370 ));
371 }
372 tool_uses.push(ProviderToolUse {
373 id: id.to_string(),
374 name: name.to_string(),
375 input,
376 });
377 }
378 Ok(tool_uses)
379}
380
/// Validated capabilities returned by `ProcessExtension::initialize`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct InitializeCapabilitiesResult {
    /// Tools the extension registered (validated).
    pub tools: Vec<RegisteredExtensionToolSpec>,
    /// Providers the extension registered (validated as far as visible here).
    pub providers: Vec<RegisteredProviderSpec>,
    /// Generic capability declarations. These are passed through unvalidated
    /// by `parse_initialize_result`; see `validate_capability`.
    pub capabilities: Vec<CapabilityDeclaration>,
}
392
/// A generic capability declared by an extension at initialize time.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct CapabilityDeclaration {
    /// Capability kind; must not be blank (see `validate_capability`).
    pub kind: String,
    /// Capability name; must not be blank.
    pub name: String,
    /// Permission names required by the capability. Each must parse as a
    /// known `Permission` and be granted.
    #[serde(default)]
    pub permissions: Vec<String>,
    /// Kind-specific parameters; `null` when absent and omitted on
    /// serialization when `null`.
    #[serde(default, skip_serializing_if = "is_null_value")]
    pub params: serde_json::Value,
}
422
423fn is_null_value(v: &serde_json::Value) -> bool {
424 v.is_null()
425}
426
427pub fn validate_capability(
437 decl: &CapabilityDeclaration,
438 granted: &crate::extensions::permissions::PermissionSet,
439) -> Result<(), String> {
440 use crate::extensions::permissions::Permission;
441 if decl.kind.trim().is_empty() {
442 return Err("capability 'kind' must be non-empty".to_string());
443 }
444 if decl.name.trim().is_empty() {
445 return Err("capability 'name' must be non-empty".to_string());
446 }
447 for perm_name in &decl.permissions {
448 let parsed = Permission::parse(perm_name).ok_or_else(|| {
449 format!(
450 "capability '{}' declares unknown permission '{}'",
451 decl.kind, perm_name
452 )
453 })?;
454 if !granted.has(parsed) {
455 return Err(format!(
456 "capability '{}' requires permission '{}' but it is not granted",
457 decl.kind, perm_name
458 ));
459 }
460 }
461 Ok(())
462}
463
/// Raw wire shape of the extension's `initialize` response.
#[derive(Deserialize)]
struct InitializeResult {
    /// Protocol version the extension speaks. Must equal
    /// `CURRENT_EXTENSION_PROTOCOL_VERSION`.
    protocol_version: u32,
    /// Declared capabilities; all lists default to empty when omitted.
    #[serde(default)]
    capabilities: InitializeCapabilities,
}
470
/// Raw `capabilities` object of the `initialize` response (pre-validation).
#[derive(Default, Deserialize)]
struct InitializeCapabilities {
    /// Declared tools.
    #[serde(default)]
    tools: Vec<RegisteredExtensionToolSpec>,
    /// Declared model providers.
    #[serde(default)]
    providers: Vec<RegisteredProviderSpec>,
    /// Generic capability declarations.
    #[serde(default)]
    capabilities: Vec<CapabilityDeclaration>,
}
481
/// A JSON-RPC notification (a frame with `method` and no `id`) received from
/// an extension and broadcast to every registered notification sink.
///
/// NOTE(review): `#[doc(hidden)]` suggests it is public only for internal or
/// test use — confirm.
#[doc(hidden)]
#[derive(Debug, Clone)]
pub struct NotificationFrame {
    /// Notification method name.
    pub method: String,
    /// Notification params, or `null` when absent.
    pub params: Value,
}
493
/// State shared between a `ProcessExtension` handle and the tasks that
/// service its stdio: the stdout reader and the inbound-request responders.
struct Inbox {
    /// Outstanding host-to-extension requests keyed by numeric id. The reader
    /// completes the matching sender when a response arrives, and
    /// `fail_all_pending` fails every entry when the transport closes.
    pending: Mutex<HashMap<u64, oneshot::Sender<Result<Value, String>>>>,
    /// Notification subscribers as `(sink id, sender)` pairs. Sinks whose
    /// receiver is gone are pruned on the next broadcast, and the list is
    /// cleared when the reader stops.
    notification_sinks: Mutex<Vec<(usize, mpsc::UnboundedSender<NotificationFrame>)>>,
    /// Counter for sink ids. NOTE(review): ids are allocated outside the
    /// visible code — presumably when a subscriber registers; confirm.
    next_sink_id: std::sync::atomic::AtomicUsize,
    /// Set (Release ordering) by `fail_all_pending` once the transport has
    /// failed. NOTE(review): readers of this flag are not visible here.
    closed: std::sync::atomic::AtomicBool,
    /// Permissions granted to the extension. `None` means every
    /// permission-gated inbound request is denied.
    permissions: RwLock<Option<crate::extensions::permissions::PermissionSet>>,
    /// Stdin of the current child process, used to answer inbound requests.
    /// Installed by `ProcessExtension::spawn_state`.
    inbound_stdin: Mutex<Option<Arc<Mutex<ChildStdin>>>>,
    /// Extension id. Inbound memory requests must use it as their namespace,
    /// and it keys the extension's config store.
    extension_id: String,
}
525
526impl Inbox {
527 fn new(extension_id: String) -> Self {
528 Self {
529 pending: Mutex::new(HashMap::new()),
530 notification_sinks: Mutex::new(Vec::new()),
531 next_sink_id: std::sync::atomic::AtomicUsize::new(0),
532 closed: std::sync::atomic::AtomicBool::new(false),
533 permissions: RwLock::new(None),
534 inbound_stdin: Mutex::new(None),
535 extension_id,
536 }
537 }
538
539 async fn fail_all_pending(&self, reason: &str) {
542 self.closed.store(true, std::sync::atomic::Ordering::Release);
543 let drained: Vec<_> = {
544 let mut pending = self.pending.lock().await;
545 pending.drain().collect()
546 };
547 for (_, tx) in drained {
548 let _ = tx.send(Err(reason.to_string()));
549 }
550 }
551}
552
/// Resources belonging to one running extension child process.
struct ProcessState {
    /// Child handle; spawned with `kill_on_drop(true)`.
    child: Child,
    /// Child stdin. The same `Arc` is also installed in
    /// `Inbox::inbound_stdin` so that inbound-request replies share the lock.
    stdin: Arc<Mutex<ChildStdin>>,
    /// Task that reads and dispatches frames from the child's stdout.
    reader_handle: JoinHandle<()>,
}
558
/// An extension running as a child process. It communicates with
/// `Content-Length`-framed JSON-RPC messages over stdin/stdout.
pub struct ProcessExtension {
    /// Extension id; also used in error messages.
    id: String,
    /// Executable used to spawn the process.
    command: String,
    /// Arguments passed to `command`.
    args: Vec<String>,
    /// Optional working directory; also the fallback `plugin_root` in `initialize`.
    cwd: Option<PathBuf>,
    /// Current process state. NOTE(review): `None` presumably means no live
    /// process (e.g. during restart/shutdown) — confirm with the code that takes it.
    state: Arc<Mutex<Option<ProcessState>>>,
    /// NOTE(review): not used in the visible code; presumably serializes
    /// outbound calls — confirm.
    call_lock: Arc<Mutex<()>>,
    /// Counter for outbound request ids, starting at 1. Presumably feeds
    /// `JsonRpcRequest::id` — confirm.
    next_id: AtomicU64,
    /// Restart counter exposed via `restart_count()`; incremented elsewhere.
    restart_count: AtomicUsize,
    /// Restart behaviour; `RestartPolicy::default()` unless overridden with
    /// `with_restart_policy`.
    pub(crate) restart_policy: RestartPolicy,
    /// State shared with the stdio tasks.
    inbox: Arc<Inbox>,
}
577
578impl ProcessExtension {
579 pub async fn spawn(id: &str, command: &str, args: &[String]) -> Result<Self, String> {
580 Self::spawn_with_cwd(id, command, args, None).await
581 }
582
583 pub async fn spawn_with_cwd(
588 id: &str,
589 command: &str,
590 args: &[String],
591 cwd: Option<PathBuf>,
592 ) -> Result<Self, String> {
593 let inbox = Arc::new(Inbox::new(id.to_string()));
594 let state = Self::spawn_state(id, command, args, cwd.as_ref(), inbox.clone()).await?;
595 Ok(Self {
596 id: id.to_string(),
597 command: command.to_string(),
598 args: args.to_vec(),
599 cwd,
600 state: Arc::new(Mutex::new(Some(state))),
601 call_lock: Arc::new(Mutex::new(())),
602 next_id: AtomicU64::new(1),
603 restart_count: AtomicUsize::new(0),
604 restart_policy: RestartPolicy::default(),
605 inbox,
606 })
607 }
608
609 pub fn with_restart_policy(mut self, policy: RestartPolicy) -> Self {
611 self.restart_policy = policy;
612 self
613 }
614
615 async fn spawn_state(
616 id: &str,
617 command: &str,
618 args: &[String],
619 cwd: Option<&PathBuf>,
620 inbox: Arc<Inbox>,
621 ) -> Result<ProcessState, String> {
622 let mut cmd = Command::new(command);
623 cmd.args(args)
624 .stdin(Stdio::piped())
625 .stdout(Stdio::piped())
626 .stderr(Stdio::piped());
627 if let Some(cwd) = cwd {
628 cmd.current_dir(cwd);
629 }
630
631 cmd.env_clear();
637 for var in &["PATH", "HOME", "LANG", "TERM", "XDG_RUNTIME_DIR"] {
638 if let Ok(val) = std::env::var(var) {
639 cmd.env(var, val);
640 }
641 }
642
643 cmd.kill_on_drop(true);
644
645 let mut child = cmd
646 .spawn()
647 .map_err(|e| format!("Failed to spawn extension '{}': {}", id, e))?;
648
649 let stdin = child
650 .stdin
651 .take()
652 .ok_or_else(|| format!("No stdin for extension '{}'", id))?;
653 let stdout = child
654 .stdout
655 .take()
656 .ok_or_else(|| format!("No stdout for extension '{}'", id))?;
657 if let Some(stderr) = child.stderr.take() {
658 let extension_id = id.to_string();
659 tokio::spawn(async move {
660 let mut lines = BufReader::new(stderr).lines();
661 loop {
662 match lines.next_line().await {
663 Ok(Some(line)) => {
664 tracing::debug!(extension = %extension_id, stderr = %line);
665 }
666 Ok(None) => break,
667 Err(error) => {
668 tracing::debug!(
669 extension = %extension_id,
670 error = %error,
671 "Failed to read extension stderr",
672 );
673 break;
674 }
675 }
676 }
677 });
678 }
679
680 let reader_handle = Self::spawn_reader(stdout, inbox.clone(), id.to_string());
681
682 let stdin_arc = Arc::new(Mutex::new(stdin));
683 *inbox.inbound_stdin.lock().await = Some(stdin_arc.clone());
686
687 Ok(ProcessState {
688 child,
689 stdin: stdin_arc,
690 reader_handle,
691 })
692 }
693
694 fn spawn_reader(
699 stdout: ChildStdout,
700 inbox: Arc<Inbox>,
701 extension_id: String,
702 ) -> JoinHandle<()> {
703 tokio::spawn(async move {
704 let mut reader = BufReader::new(stdout);
705 loop {
706 match Self::read_one_frame(&mut reader, &extension_id).await {
707 Ok(Some(value)) => {
708 Self::dispatch_frame(value, &inbox, &extension_id).await;
709 }
710 Ok(None) => {
711 tracing::debug!(
712 extension = %extension_id,
713 "Extension stdout closed (EOF); failing pending requests",
714 );
715 inbox.fail_all_pending("transport closed: EOF").await;
716 inbox.notification_sinks.lock().await.clear();
718 return;
719 }
720 Err(error) => {
721 tracing::debug!(
722 extension = %extension_id,
723 error = %error,
724 "Extension transport read error",
725 );
726 inbox
727 .fail_all_pending(&format!("transport error: {}", error))
728 .await;
729 inbox.notification_sinks.lock().await.clear();
730 return;
731 }
732 }
733 }
734 })
735 }
736
737 async fn read_one_frame(
741 reader: &mut BufReader<ChildStdout>,
742 extension_id: &str,
743 ) -> Result<Option<Value>, String> {
744 let mut content_length: Option<usize> = None;
745 let mut saw_any_header = false;
746 loop {
747 let mut header_line = String::new();
748 let n = reader
749 .read_line(&mut header_line)
750 .await
751 .map_err(|e| format!("Read header error: {}", e))?;
752 if n == 0 {
753 if saw_any_header {
754 return Err("Unexpected EOF while reading response headers".into());
755 }
756 return Ok(None);
757 }
758 saw_any_header = true;
759 if header_line.len() > 1024 {
760 return Err(format!(
761 "Extension '{}' header line too long ({} bytes)",
762 extension_id,
763 header_line.len()
764 ));
765 }
766 let trimmed = header_line.trim();
767 if trimmed.is_empty() {
768 break;
769 }
770 if let Some((name, value)) = trimmed.split_once(':') {
771 if name.trim().eq_ignore_ascii_case("Content-Length") {
772 content_length = Some(value.trim().parse().map_err(|_| {
773 format!("Invalid Content-Length value: {:?}", value.trim())
774 })?);
775 }
776 }
777 }
778 let content_length = content_length.ok_or_else(|| {
779 format!(
780 "Extension '{}' frame missing Content-Length header",
781 extension_id
782 )
783 })?;
784 const MAX_RESPONSE_SIZE: usize = 4 * 1024 * 1024;
785 if content_length > MAX_RESPONSE_SIZE {
786 return Err(format!(
787 "Extension '{}' frame too large: {} bytes (max {})",
788 extension_id, content_length, MAX_RESPONSE_SIZE
789 ));
790 }
791 let mut buf = vec![0u8; content_length];
792 tokio::io::AsyncReadExt::read_exact(reader, &mut buf)
793 .await
794 .map_err(|e| format!("Read body error: {}", e))?;
795 let value: Value = serde_json::from_slice(&buf)
796 .map_err(|e| format!("Parse frame error: {}", e))?;
797 Ok(Some(value))
798 }
799
800 async fn dispatch_frame(value: Value, inbox: &Arc<Inbox>, extension_id: &str) {
806 let id_field = value.get("id");
807 let id_is_present = !matches!(id_field, None | Some(Value::Null));
808 let method_field = value.get("method").and_then(Value::as_str).map(str::to_string);
809
810 if id_is_present && method_field.is_some() {
811 let id = match id_field.and_then(Value::as_u64) {
814 Some(id) => id,
815 None => {
816 tracing::trace!(
817 extension = %extension_id,
818 frame = %value,
819 "Discarding inbound request with non-numeric id",
820 );
821 return;
822 }
823 };
824 let Some(method) = method_field else { return };
825 let params = value.get("params").cloned().unwrap_or(Value::Null);
826 let inbox = inbox.clone();
827 let extension_id = extension_id.to_string();
828 tokio::spawn(async move {
829 let outcome = Self::handle_inbound_request(&inbox, &method, params).await;
830 let payload = match outcome {
831 Ok(result) => serde_json::json!({
832 "jsonrpc": "2.0",
833 "id": id,
834 "result": result,
835 }),
836 Err((code, message)) => serde_json::json!({
837 "jsonrpc": "2.0",
838 "id": id,
839 "error": {"code": code, "message": message},
840 }),
841 };
842 let stdin_handle = inbox.inbound_stdin.lock().await.clone();
843 if let Some(stdin) = stdin_handle {
844 let body = match serde_json::to_string(&payload) {
845 Ok(s) => s,
846 Err(error) => {
847 tracing::warn!(
848 extension = %extension_id,
849 error = %error,
850 "Failed to serialize inbound response",
851 );
852 return;
853 }
854 };
855 let frame = format!("Content-Length: {}\r\n\r\n{}", body.len(), body);
856 let mut stdin = stdin.lock().await;
857 if let Err(error) = stdin.write_all(frame.as_bytes()).await {
858 tracing::warn!(
859 extension = %extension_id,
860 error = %error,
861 "Failed to write inbound response",
862 );
863 return;
864 }
865 if let Err(error) = stdin.flush().await {
866 tracing::warn!(
867 extension = %extension_id,
868 error = %error,
869 "Failed to flush inbound response",
870 );
871 }
872 } else {
873 tracing::warn!(
874 extension = %extension_id,
875 "No stdin available to reply to inbound request",
876 );
877 }
878 });
879 return;
880 }
881
882 if id_is_present {
883 let id = match id_field.and_then(Value::as_u64) {
884 Some(id) => id,
885 None => {
886 tracing::trace!(
887 extension = %extension_id,
888 frame = %value,
889 "Discarding frame with non-numeric id",
890 );
891 return;
892 }
893 };
894 let sender = inbox.pending.lock().await.remove(&id);
895 match sender {
896 Some(tx) => {
897 let payload = if let Some(err) = value.get("error") {
898 let message = err
899 .get("message")
900 .and_then(Value::as_str)
901 .unwrap_or("unknown extension error")
902 .to_string();
903 Err(format!("Extension error: {}", message))
904 } else {
905 Ok(value
906 .get("result")
907 .cloned()
908 .unwrap_or(Value::Null))
909 };
910 let _ = tx.send(payload);
911 }
912 None => {
913 tracing::trace!(
914 extension = %extension_id,
915 id = id,
916 "Response with unknown id (no pending request); dropping",
917 );
918 }
919 }
920 } else if let Some(method) = value.get("method").and_then(Value::as_str) {
921 let params = value.get("params").cloned().unwrap_or(Value::Null);
922 let frame = NotificationFrame {
923 method: method.to_string(),
924 params,
925 };
926 let mut sinks = inbox.notification_sinks.lock().await;
927 if sinks.is_empty() {
928 tracing::trace!(
929 extension = %extension_id,
930 method = %method,
931 "Notification with no active subscribers; dropping",
932 );
933 } else {
934 sinks.retain(|(_, tx)| tx.send(frame.clone()).is_ok());
938 }
939 } else {
940 tracing::trace!(
941 extension = %extension_id,
942 frame = %value,
943 "Unrecognized frame; dropping",
944 );
945 }
946 }
947
948 pub fn restart_count(&self) -> usize {
949 self.restart_count.load(Ordering::Relaxed)
950 }
951
952 pub async fn set_permissions(&self, perms: crate::extensions::permissions::PermissionSet) {
955 *self.inbox.permissions.write().await = Some(perms);
956 }
957
958 #[allow(clippy::doc_lazy_continuation)]
966 async fn handle_inbound_request(
967 inbox: &Arc<Inbox>,
968 method: &str,
969 params: Value,
970 ) -> Result<Value, (i32, String)> {
971 use crate::extensions::permissions::Permission;
972 use crate::memory::store::{self, MemoryQuery};
973
974 match method {
975 "memory.append" => {
976 Self::require_permission(inbox, Permission::MemoryWrite, "memory.write").await?;
977 let namespace = Self::param_str(¶ms, "namespace")?;
978 Self::require_namespace_matches(inbox, &namespace).await?;
979 let content = Self::param_str(¶ms, "content")?;
980 let tags = match params.get("tags") {
981 None | Some(Value::Null) => Vec::new(),
982 Some(Value::Array(arr)) => {
983 let mut out = Vec::with_capacity(arr.len());
984 for v in arr {
985 match v.as_str() {
986 Some(s) => out.push(s.to_string()),
987 None => {
988 return Err((
989 -32602,
990 "tags must be an array of strings".to_string(),
991 ))
992 }
993 }
994 }
995 out
996 }
997 _ => {
998 return Err((
999 -32602,
1000 "tags must be an array of strings".to_string(),
1001 ))
1002 }
1003 };
1004 let meta = match params.get("meta") {
1005 None | Some(Value::Null) => None,
1006 Some(v) => Some(v.clone()),
1007 };
1008 let record = store::new_record(namespace, content, tags, meta);
1009 let timestamp_ms = record.timestamp_ms;
1010 store::append(&record).map_err(|e| (-32000, e.to_string()))?;
1011 Ok(serde_json::json!({"ok": true, "timestamp_ms": timestamp_ms}))
1012 }
1013 "memory.query" => {
1014 Self::require_permission(inbox, Permission::MemoryRead, "memory.read").await?;
1015 let namespace = Self::param_str(¶ms, "namespace")?;
1016 Self::require_namespace_matches(inbox, &namespace).await?;
1017 let q = MemoryQuery {
1018 content_contains: params
1019 .get("content_contains")
1020 .and_then(Value::as_str)
1021 .map(str::to_string),
1022 tag_prefix: params
1023 .get("tag_prefix")
1024 .and_then(Value::as_str)
1025 .map(str::to_string),
1026 since_ms: params.get("since_ms").and_then(Value::as_u64),
1027 until_ms: params.get("until_ms").and_then(Value::as_u64),
1028 limit: params
1029 .get("limit")
1030 .and_then(Value::as_u64)
1031 .map(|n| n as usize),
1032 };
1033 let records = store::query(&namespace, &q).map_err(|e| (-32000, e.to_string()))?;
1034 Ok(serde_json::json!({"records": records}))
1035 }
1036 "config.get" => {
1037 let key = Self::param_str(¶ms, "key")?;
1038 Self::validate_config_key(&key)?;
1039 let value = crate::extensions::config_store::read_plugin_config(&inbox.extension_id, &key);
1040 Ok(serde_json::json!({"value": value}))
1041 }
1042 "config.set" => {
1043 Self::require_permission(inbox, Permission::ConfigWrite, "config.write").await?;
1044 let key = Self::param_str(¶ms, "key")?;
1045 Self::validate_config_key(&key)?;
1046 let value = Self::param_str(¶ms, "value")?;
1047 crate::extensions::config_store::write_plugin_config(&inbox.extension_id, &key, &value)
1048 .map_err(|e| (-32000, e.to_string()))?;
1049 Ok(serde_json::json!({"ok": true}))
1050 }
1051 "config.subscribe" => {
1052 Self::require_permission(inbox, Permission::ConfigSubscribe, "config.subscribe").await?;
1053 Ok(serde_json::json!({"ok": true}))
1057 }
1058 other => Err((-32601, format!("method not found: {other}"))),
1059 }
1060 }
1061
1062 async fn require_permission(
1063 inbox: &Arc<Inbox>,
1064 perm: crate::extensions::permissions::Permission,
1065 wire: &str,
1066 ) -> Result<(), (i32, String)> {
1067 let guard = inbox.permissions.read().await;
1068 match guard.as_ref() {
1069 Some(set) if set.has(perm) => Ok(()),
1070 _ => Err((
1071 -32602,
1072 format!("permission denied: {wire} required"),
1073 )),
1074 }
1075 }
1076
1077 async fn require_namespace_matches(
1078 inbox: &Arc<Inbox>,
1079 namespace: &str,
1080 ) -> Result<(), (i32, String)> {
1081 if namespace == inbox.extension_id {
1082 Ok(())
1083 } else {
1084 Err((
1085 -32602,
1086 format!(
1087 "namespace must equal extension id '{}' (got '{}')",
1088 inbox.extension_id, namespace
1089 ),
1090 ))
1091 }
1092 }
1093
1094 fn param_str(params: &Value, name: &str) -> Result<String, (i32, String)> {
1095 params
1096 .get(name)
1097 .and_then(Value::as_str)
1098 .map(str::to_string)
1099 .ok_or_else(|| (-32602, format!("missing or invalid '{name}' parameter")))
1100 }
1101
/// Validates a plugin config key received from an extension.
///
/// A valid key contains at least one non-whitespace character. It must not
/// contain any of the following, anywhere in the key (including leading or
/// trailing positions, because callers pass the key to the config store
/// untrimmed):
/// - dots;
/// - forward or back slashes;
/// - whitespace;
/// - control characters.
///
/// # Errors
///
/// Returns a JSON-RPC invalid-params error `(-32602, message)`.
fn validate_config_key(key: &str) -> Result<(), (i32, String)> {
    if key.trim().is_empty() {
        return Err((-32602, "config key must be non-empty".to_string()));
    }
    // Check the raw key rather than a trimmed copy. Otherwise " foo" would
    // pass validation yet be stored with its leading space.
    if key
        .chars()
        .any(|c| matches!(c, '.' | '/' | '\\') || c.is_whitespace())
    {
        return Err((
            -32602,
            "config key must not contain dots, slashes, or spaces".to_string(),
        ));
    }
    if key.chars().any(char::is_control) {
        return Err((
            -32602,
            "config key must not contain control characters".to_string(),
        ));
    }
    Ok(())
}
1115
1116 pub async fn initialize(&self, plugin_root: Option<PathBuf>, config: Value) -> Result<InitializeCapabilitiesResult, String> {
1117 let params = InitializeParams {
1118 synaps_version: env!("CARGO_PKG_VERSION"),
1119 extension_protocol_version: CURRENT_EXTENSION_PROTOCOL_VERSION,
1120 plugin_id: self.id.clone(),
1121 plugin_root: plugin_root
1122 .or_else(|| self.cwd.clone())
1123 .map(|path| path.to_string_lossy().to_string()),
1124 config,
1125 };
1126 let value = self.call_no_restart("initialize", serde_json::to_value(params).map_err(|e| e.to_string())?).await?;
1127 Self::parse_initialize_result(&self.id, value)
1128 }
1129
1130 fn parse_initialize_result(id: &str, value: Value) -> Result<InitializeCapabilitiesResult, String> {
1131 let result: InitializeResult = serde_json::from_value(value)
1132 .map_err(|e| format!("Invalid initialize response from extension '{}': {}", id, e))?;
1133 if result.protocol_version != CURRENT_EXTENSION_PROTOCOL_VERSION {
1134 return Err(format!(
1135 "Extension '{}' initialize returned unsupported protocol_version {} (supported: {})",
1136 id, result.protocol_version, CURRENT_EXTENSION_PROTOCOL_VERSION,
1137 ));
1138 }
1139 Self::validate_registered_tool_specs(id, &result.capabilities.tools)?;
1140 Self::validate_registered_provider_specs(id, &result.capabilities.providers)?;
1141 Ok(InitializeCapabilitiesResult {
1142 tools: result.capabilities.tools,
1143 providers: result.capabilities.providers,
1144 capabilities: result.capabilities.capabilities,
1145 })
1146 }
1147
1148 fn validate_registered_tool_specs(id: &str, tools: &[RegisteredExtensionToolSpec]) -> Result<(), String> {
1149 use crate::extensions::validation::{validate_id_segment, IdValidationError};
1150 let mut names = HashSet::new();
1151 for tool in tools {
1152 let name = tool.name.trim();
1153 if let Err(err) = validate_id_segment(name) {
1154 return Err(match err {
1155 IdValidationError::Empty => format!(
1156 "Extension '{}' registered a tool with an empty tool name",
1157 id
1158 ),
1159 IdValidationError::ContainsReserved { ch } => format!(
1160 "Extension '{}' registered tool '{}' with invalid tool name: '{}' is reserved",
1161 id, name, ch
1162 ),
1163 IdValidationError::TooLong { len, max } => format!(
1164 "Extension '{}' registered tool '{}' with invalid tool name: must be at most {} chars (got {})",
1165 id, name, max, len
1166 ),
1167 IdValidationError::ContainsWhitespace => format!(
1168 "Extension '{}' registered tool '{}' with invalid tool name: must not contain whitespace",
1169 id, name
1170 ),
1171 IdValidationError::ContainsControl { ch } => format!(
1172 "Extension '{}' registered tool '{}' with invalid tool name: contains control character U+{:04X}",
1173 id, name, ch as u32
1174 ),
1175 });
1176 }
1177 if !names.insert(name.to_string()) {
1178 return Err(format!("Extension '{}' registered duplicate tool name '{}'", id, name));
1179 }
1180 if tool.description.trim().is_empty() {
1181 return Err(format!(
1182 "Extension '{}' registered tool '{}' with an empty description",
1183 id, name,
1184 ));
1185 }
1186 if !tool.input_schema.is_object() {
1187 return Err(format!(
1188 "Extension '{}' registered tool '{}' with invalid input_schema: input_schema must be a JSON object",
1189 id, name,
1190 ));
1191 }
1192 }
1193 Ok(())
1194 }
1195
1196 fn validate_registered_provider_specs(id: &str, providers: &[RegisteredProviderSpec]) -> Result<(), String> {
1197 use crate::extensions::validation::{validate_id_segment, IdValidationError};
1198 for provider in providers {
1199 let provider_id = provider.id.trim();
1200 match validate_id_segment(provider_id) {
1201 Ok(()) => {
1202 if !Self::is_safe_provider_id(provider_id) {
1203 return Err(format!(
1204 "Extension '{}' registered provider '{}' with invalid provider id",
1205 id, provider_id
1206 ));
1207 }
1208 }
1209 Err(IdValidationError::Empty) => {
1210 return Err(format!(
1211 "Extension '{}' registered provider with empty provider id",
1212 id
1213 ));
1214 }
1215 Err(err) => {
1216 return Err(format!(
1217 "Extension '{}' registered provider '{}' with invalid provider id: {}",
1218 id, provider_id, err
1219 ));
1220 }
1221 }
1222 if provider.display_name.trim().is_empty() {
1223 return Err(format!(
1224 "Extension '{}' registered provider '{}' with empty display_name",
1225 id, provider_id,
1226 ));
1227 }
1228 if provider.description.trim().is_empty() {
1229 return Err(format!(
1230 "Extension '{}' registered provider '{}' with empty description",
1231 id, provider_id,
1232 ));
1233 }
1234 if provider.models.is_empty() {
1235 return Err(format!(
1236 "Extension '{}' registered provider '{}' must declare at least one model",
1237 id, provider_id,
1238 ));
1239 }
1240 let mut model_ids = HashSet::new();
1241 for model in &provider.models {
1242 let model_id = model.id.trim();
1243 if let Err(err) = validate_id_segment(model_id) {
1244 return Err(match err {
1245 IdValidationError::Empty => format!(
1246 "Extension '{}' registered provider '{}' with empty model id",
1247 id, provider_id
1248 ),
1249 IdValidationError::ContainsReserved { ch } => format!(
1250 "Extension '{}' registered provider '{}' with invalid model id '{}': '{}' is reserved",
1251 id, provider_id, model_id, ch
1252 ),
1253 IdValidationError::TooLong { len, max } => format!(
1254 "Extension '{}' registered provider '{}' with invalid model id '{}': must be at most {} chars (got {})",
1255 id, provider_id, model_id, max, len
1256 ),
1257 IdValidationError::ContainsWhitespace => format!(
1258 "Extension '{}' registered provider '{}' with invalid model id '{}': must not contain whitespace",
1259 id, provider_id, model_id
1260 ),
1261 IdValidationError::ContainsControl { ch } => format!(
1262 "Extension '{}' registered provider '{}' with invalid model id '{}': contains control character U+{:04X}",
1263 id, provider_id, model_id, ch as u32
1264 ),
1265 });
1266 }
1267 if !model_ids.insert(model_id.to_string()) {
1268 return Err(format!(
1269 "Extension '{}' registered provider '{}' with duplicate model id '{}'",
1270 id, provider_id, model_id,
1271 ));
1272 }
1273 }
1274 if let Some(config_schema) = &provider.config_schema {
1275 if !config_schema.is_object() {
1276 return Err(format!(
1277 "Extension '{}' registered provider '{}' with invalid config_schema: config_schema must be a JSON object",
1278 id, provider_id,
1279 ));
1280 }
1281 }
1282 }
1283 Ok(())
1284 }
1285
1286 fn is_safe_provider_id(id: &str) -> bool {
1287 !id.is_empty()
1288 && !id.contains(':')
1289 && id.chars().all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-' || c == '_')
1290 }
1291
1292 #[doc(hidden)]
1293 pub async fn initialize_for_test(&self, plugin_root: Option<PathBuf>) -> Result<(), String> {
1294 self.initialize(plugin_root, Value::Object(Default::default())).await.map(|_| ())
1295 }
1296
1297 async fn restart_locked(&self, state: &mut Option<ProcessState>) -> Result<(), String> {
1298 let attempted = self.restart_count.fetch_add(1, Ordering::Relaxed) + 1;
1299 let max_attempts = self.restart_policy.max_attempts;
1300 if attempted > max_attempts as usize {
1301 *state = None;
1302 return Err(format!(
1303 "Extension '{}' exceeded restart limit ({})",
1304 self.id, max_attempts,
1305 ));
1306 }
1307
1308 if let Some(old) = state.take() {
1309 old.reader_handle.abort();
1310 let mut child = old.child;
1311 let _ = child.kill().await;
1312 }
1313 self.inbox
1315 .fail_all_pending("transport closed: process restarting")
1316 .await;
1317
1318 let delay = self
1319 .restart_policy
1320 .delay_for_attempt(attempted as u32)
1321 .unwrap_or_default();
1322
1323 tracing::warn!(
1324 extension = %self.id,
1325 attempt = attempted,
1326 max_attempts = max_attempts,
1327 delay_ms = delay.as_millis() as u64,
1328 "Restarting extension process after transport failure",
1329 );
1330
1331 if !delay.is_zero() {
1332 tokio::time::sleep(delay).await;
1333 }
1334
1335 *state = Some(Self::spawn_state(
1336 &self.id,
1337 &self.command,
1338 &self.args,
1339 self.cwd.as_ref(),
1340 self.inbox.clone(),
1341 ).await?);
1342 self.inbox.closed.store(false, std::sync::atomic::Ordering::Release);
1344 self.initialize_locked(state).await?;
1345 self.restart_count.store(0, Ordering::Relaxed);
1348 Ok(())
1349 }
1350
1351
1352 async fn initialize_locked(&self, state: &mut Option<ProcessState>) -> Result<(), String> {
1353 let params = InitializeParams {
1354 synaps_version: env!("CARGO_PKG_VERSION"),
1355 extension_protocol_version: CURRENT_EXTENSION_PROTOCOL_VERSION,
1356 plugin_id: self.id.clone(),
1357 plugin_root: self.cwd
1358 .clone()
1359 .map(|path| path.to_string_lossy().to_string()),
1360 config: Value::Object(Default::default()),
1361 };
1362 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
1363 let value = tokio::time::timeout(
1364 std::time::Duration::from_secs(10),
1365 self.call_once_locked(
1366 state.as_mut().expect("state should exist for initialize"),
1367 "initialize",
1368 serde_json::to_value(params).map_err(|e| e.to_string())?,
1369 id,
1370 ),
1371 )
1372 .await
1373 .map_err(|_| format!("Extension '{}' initialize timed out after 10s", self.id))?
1374 ?;
1375 Self::parse_initialize_result(&self.id, value).map(|_| ())
1376 }
1377
1378 async fn call_once_locked(
1382 &self,
1383 state: &mut ProcessState,
1384 method: &str,
1385 params: Value,
1386 id: u64,
1387 ) -> Result<Value, String> {
1388 let body = serde_json::to_string(&JsonRpcRequest {
1389 jsonrpc: "2.0",
1390 method: method.to_string(),
1391 params,
1392 id,
1393 })
1394 .map_err(|e| format!("Serialize error: {}", e))?;
1395
1396 let (tx, rx) = oneshot::channel::<Result<Value, String>>();
1397 if self.inbox.closed.load(std::sync::atomic::Ordering::Acquire) {
1400 return Err("transport closed: inbox is shut down".to_string());
1401 }
1402
1403 self.inbox.pending.lock().await.insert(id, tx);
1406
1407 if self.inbox.closed.load(std::sync::atomic::Ordering::Acquire) {
1410 self.inbox.pending.lock().await.remove(&id);
1411 return Err("transport closed: inbox shut down during registration".to_string());
1412 }
1413
1414 let frame = format!("Content-Length: {}\r\n\r\n{}", body.len(), body);
1415 let write_result = {
1416 let mut stdin = state.stdin.lock().await;
1417 match stdin.write_all(frame.as_bytes()).await {
1418 Ok(()) => stdin.flush().await,
1419 Err(e) => Err(e),
1420 }
1421 };
1422 if let Err(e) = write_result {
1423 self.inbox.pending.lock().await.remove(&id);
1425 return Err(format!("Write error: {}", e));
1426 }
1427
1428 match rx.await {
1429 Ok(payload) => payload,
1430 Err(_) => {
1431 self.inbox.pending.lock().await.remove(&id);
1436 Err("transport closed: response channel dropped".to_string())
1437 }
1438 }
1439 }
1440
1441 async fn call_no_restart(&self, method: &str, params: Value) -> Result<Value, String> {
1442 let _call_guard = self.call_lock.lock().await;
1443 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
1444 let mut state_guard = self.state.lock().await;
1445 if state_guard.is_none() {
1446 *state_guard = Some(Self::spawn_state(
1447 &self.id,
1448 &self.command,
1449 &self.args,
1450 self.cwd.as_ref(),
1451 self.inbox.clone(),
1452 ).await?);
1453 }
1454 self.call_once_locked(
1455 state_guard.as_mut().expect("state should exist"),
1456 method,
1457 params,
1458 id,
1459 ).await
1460 }
1461
1462 async fn call(&self, method: &str, params: Value) -> Result<Value, String> {
1463 let timeout_secs = if method == "tool.call" { 120 } else { 30 };
1464 let id_str = self.id.clone();
1465 let method_str = method.to_string();
1466
1467 let result = tokio::time::timeout(
1468 std::time::Duration::from_secs(timeout_secs),
1469 self.call_inner(method, params),
1470 )
1471 .await;
1472
1473 match result {
1474 Ok(inner) => inner,
1475 Err(_) => Err(format!(
1476 "Extension '{}' method '{}' timed out after {}s",
1477 id_str, method_str, timeout_secs
1478 )),
1479 }
1480 }
1481
1482 async fn call_inner(&self, method: &str, params: Value) -> Result<Value, String> {
1483 let _call_guard = self.call_lock.lock().await;
1484 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
1485 let mut state_guard = self.state.lock().await;
1486 if state_guard.is_none() {
1487 self.restart_locked(&mut state_guard).await?;
1488 }
1489
1490 let result = self
1491 .call_once_locked(
1492 state_guard.as_mut().expect("state should exist after restart"),
1493 method,
1494 params.clone(),
1495 id,
1496 )
1497 .await;
1498
1499 match result {
1500 Ok(value) => Ok(value),
1501 Err(first_error) => {
1502 self.restart_locked(&mut state_guard).await?;
1503 let retry_id = self.next_id.fetch_add(1, Ordering::Relaxed);
1504 self.call_once_locked(
1505 state_guard.as_mut().expect("state should exist after restart"),
1506 method,
1507 params,
1508 retry_id,
1509 )
1510 .await
1511 .map_err(|retry_error| {
1512 format!("{}; retry after restart failed: {}", first_error, retry_error)
1513 })
1514 }
1515 }
1516 }
1517
1518 #[doc(hidden)]
1536 pub async fn subscribe_notifications(
1537 &self,
1538 ) -> (usize, mpsc::UnboundedReceiver<NotificationFrame>) {
1539 let (tx, rx) = mpsc::unbounded_channel();
1540 let id = self
1541 .inbox
1542 .next_sink_id
1543 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1544 let mut sinks = self.inbox.notification_sinks.lock().await;
1545 sinks.retain(|(_, tx)| !tx.is_closed());
1549 sinks.push((id, tx));
1550 (id, rx)
1551 }
1552
1553 #[doc(hidden)]
1558 pub async fn unsubscribe_notifications(&self, id: usize) {
1559 let mut sinks = self.inbox.notification_sinks.lock().await;
1560 sinks.retain(|(sub_id, tx)| *sub_id != id && !tx.is_closed());
1561 }
1562
1563 pub(crate) fn forward_invoke_command_frame(
1573 extension_id: &str,
1574 request_id: &str,
1575 sink: &mpsc::UnboundedSender<crate::extensions::runtime::InvokeCommandEvent>,
1576 sink_open: &mut bool,
1577 frame: NotificationFrame,
1578 ) -> bool {
1579 use crate::extensions::commands::parse_command_output;
1580 use crate::extensions::tasks::{is_task_method, parse_task_event};
1581 use crate::extensions::runtime::InvokeCommandEvent;
1582
1583 let mut saw_done = false;
1584 if frame.method == "command.output" {
1585 match parse_command_output(&frame.params) {
1586 Ok(parsed) if parsed.request_id == request_id => {
1587 if matches!(parsed.event, crate::extensions::commands::CommandOutputEvent::Done) {
1588 saw_done = true;
1589 }
1590 if *sink_open && sink.send(InvokeCommandEvent::Output(parsed.event)).is_err() {
1591 *sink_open = false;
1592 }
1593 }
1594 Ok(_) => {
1595 tracing::trace!(
1597 extension = %extension_id,
1598 "Ignoring command.output for unrelated request_id",
1599 );
1600 }
1601 Err(error) => {
1602 tracing::warn!(
1603 extension = %extension_id,
1604 error = %error,
1605 params = %frame.params,
1606 "Skipping malformed command.output notification",
1607 );
1608 }
1609 }
1610 } else if is_task_method(&frame.method) {
1611 match parse_task_event(&frame.method, &frame.params) {
1612 Ok(event) => {
1613 if *sink_open && sink.send(InvokeCommandEvent::Task(event)).is_err() {
1614 *sink_open = false;
1615 }
1616 }
1617 Err(error) => {
1618 tracing::warn!(
1619 extension = %extension_id,
1620 method = %frame.method,
1621 error = %error,
1622 params = %frame.params,
1623 "Skipping malformed task notification",
1624 );
1625 }
1626 }
1627 } else {
1628 tracing::trace!(
1629 extension = %extension_id,
1630 method = %frame.method,
1631 "Ignoring non-command/task notification during command.invoke",
1632 );
1633 }
1634 saw_done
1635 }
1636
1637 fn forward_provider_stream_frame(
1644 extension_id: &str,
1645 sink: &mpsc::UnboundedSender<ProviderStreamEvent>,
1646 sink_open: &mut bool,
1647 frame: NotificationFrame,
1648 ) {
1649 if frame.method != "provider.stream.event" {
1650 tracing::trace!(
1651 extension = %extension_id,
1652 method = %frame.method,
1653 "Ignoring non-stream notification during provider.stream",
1654 );
1655 return;
1656 }
1657 match parse_provider_stream_event(&frame.params) {
1658 Ok(event) => {
1659 if *sink_open && sink.send(event).is_err() {
1660 *sink_open = false;
1661 }
1662 }
1663 Err(error) => {
1664 tracing::warn!(
1665 extension = %extension_id,
1666 error = %error,
1667 params = %frame.params,
1668 "Skipping malformed provider.stream.event notification",
1669 );
1670 }
1671 }
1672 }
1673}
1674
1675#[async_trait]
1676impl ExtensionHandler for ProcessExtension {
1677 fn id(&self) -> &str {
1678 &self.id
1679 }
1680
1681 async fn call_tool(&self, name: &str, input: Value) -> Result<Value, String> {
1682 self.call("tool.call", serde_json::json!({
1683 "name": name,
1684 "input": input,
1685 })).await
1686 }
1687
1688 async fn provider_complete(&self, params: ProviderCompleteParams) -> Result<ProviderCompleteResult, String> {
1689 let value = tokio::time::timeout(
1690 std::time::Duration::from_secs(60),
1691 self.call("provider.complete", serde_json::to_value(params).map_err(|e| e.to_string())?),
1692 )
1693 .await
1694 .map_err(|_| format!("Extension '{}' provider.complete timed out", self.id))??;
1695 let result: ProviderCompleteResult = serde_json::from_value(value)
1696 .map_err(|e| format!("Invalid provider.complete response from extension '{}': {}", self.id, e))?;
1697 if result.content.is_empty() {
1698 return Err(format!("Extension '{}' provider.complete returned empty content", self.id));
1699 }
1700 Ok(result)
1701 }
1702
1703 async fn provider_stream(
1704 &self,
1705 params: ProviderCompleteParams,
1706 sink: tokio::sync::mpsc::UnboundedSender<ProviderStreamEvent>,
1707 ) -> Result<ProviderCompleteResult, String> {
1708 let (sub_id, mut rx) = self.subscribe_notifications().await;
1711 let params_value =
1712 serde_json::to_value(params).map_err(|e| e.to_string())?;
1713
1714 let extension_id = self.id.clone();
1715 let stream_future = async {
1716 let mut call_fut = Box::pin(self.call("provider.stream", params_value));
1717 let mut sink_open = true;
1718 let response = loop {
1719 tokio::select! {
1720 response = &mut call_fut => break response,
1721 Some(frame) = rx.recv() => {
1722 Self::forward_provider_stream_frame(
1723 &extension_id, &sink, &mut sink_open, frame,
1724 );
1725 }
1726 }
1727 };
1728 self.unsubscribe_notifications(sub_id).await;
1733 while let Some(frame) = rx.recv().await {
1734 Self::forward_provider_stream_frame(
1735 &extension_id, &sink, &mut sink_open, frame,
1736 );
1737 }
1738 response
1739 };
1740
1741 let outcome = tokio::time::timeout(
1742 std::time::Duration::from_secs(60),
1743 stream_future,
1744 )
1745 .await;
1746
1747 self.unsubscribe_notifications(sub_id).await;
1750
1751 let value = outcome
1752 .map_err(|_| format!("Extension '{}' provider.stream timed out", self.id))??;
1753
1754 let result: ProviderCompleteResult = serde_json::from_value(value)
1755 .map_err(|e| {
1756 format!("Invalid provider.stream response from extension '{}': {}", self.id, e)
1757 })?;
1758 Ok(result)
1761 }
1762
1763 async fn invoke_command(
1764 &self,
1765 command: &str,
1766 args: Vec<String>,
1767 request_id: &str,
1768 sink: tokio::sync::mpsc::UnboundedSender<crate::extensions::runtime::InvokeCommandEvent>,
1769 ) -> Result<Value, String> {
1770 let (sub_id, mut rx) = self.subscribe_notifications().await;
1772 let params = serde_json::json!({
1773 "command": command,
1774 "args": args,
1775 "request_id": request_id,
1776 });
1777
1778 let extension_id = self.id.clone();
1779 let request_id_owned = request_id.to_string();
1780 let invoke_future = async {
1781 let mut call_fut = Box::pin(self.call("command.invoke", params));
1782 let mut sink_open = true;
1783 let response = loop {
1784 tokio::select! {
1785 response = &mut call_fut => break response,
1786 Some(frame) = rx.recv() => {
1787 let _ = Self::forward_invoke_command_frame(
1788 &extension_id, &request_id_owned, &sink, &mut sink_open, frame,
1789 );
1790 }
1791 }
1792 };
1793 self.unsubscribe_notifications(sub_id).await;
1797 while let Ok(frame) = rx.try_recv() {
1798 let _ = Self::forward_invoke_command_frame(
1799 &extension_id, &request_id_owned, &sink, &mut sink_open, frame,
1800 );
1801 }
1802 response
1803 };
1804
1805 let outcome = tokio::time::timeout(
1806 std::time::Duration::from_secs(120),
1807 invoke_future,
1808 )
1809 .await;
1810
1811 self.unsubscribe_notifications(sub_id).await;
1814
1815 outcome
1816 .map_err(|_| format!("Extension '{}' command.invoke timed out", self.id))?
1817 }
1818
1819 async fn handle(&self, event: &HookEvent) -> HookResult {
1820 let params = serde_json::to_value(event).unwrap_or(Value::Null);
1821 match tokio::time::timeout(std::time::Duration::from_secs(5), self.call("hook.handle", params)).await {
1822 Ok(Ok(value)) => match serde_json::from_value(value.clone()) {
1823 Ok(result) => result,
1824 Err(error) => {
1825 tracing::warn!(
1826 extension = %self.id,
1827 error = %error,
1828 response = %value,
1829 "Extension hook handler returned invalid result",
1830 );
1831 if value.get("action").and_then(Value::as_str) == Some("modify") {
1832 HookResult::Block {
1833 reason: "Extension returned malformed modify result".to_string(),
1834 }
1835 } else {
1836 HookResult::Continue
1837 }
1838 }
1839 },
1840 Ok(Err(e)) => {
1841 tracing::warn!(
1842 extension = %self.id,
1843 error = %e,
1844 "Extension hook handler failed — continuing",
1845 );
1846 HookResult::Continue
1847 }
1848 Err(_) => {
1849 tracing::warn!(
1850 extension = %self.id,
1851 timeout_secs = 5,
1852 "Extension hook handler timed out — continuing",
1853 );
1854 HookResult::Continue
1855 }
1856 }
1857 }
1858
1859 async fn get_info(&self) -> Result<crate::extensions::info::PluginInfo, String> {
1860 let value = tokio::time::timeout(
1861 std::time::Duration::from_secs(5),
1862 self.call("info.get", Value::Null),
1863 )
1864 .await
1865 .map_err(|_| format!("Extension '{}' info.get timed out", self.id))??;
1866 serde_json::from_value(value)
1867 .map_err(|e| format!("Invalid info.get response from extension '{}': {}", self.id, e))
1868 }
1869
1870 async fn sidecar_spawn_args(
1871 &self,
1872 ) -> Result<crate::sidecar::spawn::SidecarSpawnArgs, String> {
1873 let value = tokio::time::timeout(
1874 std::time::Duration::from_secs(5),
1875 self.call("sidecar.spawn_args", Value::Null),
1876 )
1877 .await
1878 .map_err(|_| format!("Extension '{}' sidecar.spawn_args timed out", self.id))??;
1879 serde_json::from_value(value).map_err(|e| {
1880 format!(
1881 "Invalid sidecar.spawn_args response from extension '{}': {}",
1882 self.id, e
1883 )
1884 })
1885 }
1886
1887 async fn settings_editor_open(&self, category: &str, field: &str) -> Result<Value, String> {
1888 let params = crate::extensions::settings_editor::SettingsEditorOpenParams {
1889 category: category.to_string(),
1890 field: field.to_string(),
1891 };
1892 tokio::time::timeout(
1893 std::time::Duration::from_secs(5),
1894 self.call("settings.editor.open", serde_json::to_value(params).map_err(|e| e.to_string())?),
1895 )
1896 .await
1897 .map_err(|_| format!("Extension '{}' settings.editor.open timed out", self.id))?
1898 }
1899
1900 async fn settings_editor_key(&self, category: &str, field: &str, key: &str) -> Result<Value, String> {
1901 let mut params = serde_json::to_value(crate::extensions::settings_editor::SettingsEditorKeyParams {
1902 key: key.to_string(),
1903 }).map_err(|e| e.to_string())?;
1904 if let Some(obj) = params.as_object_mut() {
1905 obj.insert("category".to_string(), Value::String(category.to_string()));
1906 obj.insert("field".to_string(), Value::String(field.to_string()));
1907 }
1908 tokio::time::timeout(
1909 std::time::Duration::from_secs(5),
1910 self.call("settings.editor.key", params),
1911 )
1912 .await
1913 .map_err(|_| format!("Extension '{}' settings.editor.key timed out", self.id))?
1914 }
1915
1916 async fn settings_editor_commit(&self, category: &str, field: &str, value: Value) -> Result<Value, String> {
1917 let params = serde_json::json!({
1918 "category": category,
1919 "field": field,
1920 "value": value,
1921 });
1922 tokio::time::timeout(
1923 std::time::Duration::from_secs(5),
1924 self.call("settings.editor.commit", params),
1925 )
1926 .await
1927 .map_err(|_| format!("Extension '{}' settings.editor.commit timed out", self.id))?
1928 }
1929
1930 async fn shutdown(&self) {
1931 let _ = tokio::time::timeout(
1932 std::time::Duration::from_millis(500),
1933 self.call("shutdown", Value::Null),
1934 )
1935 .await;
1936
1937 tokio::time::sleep(std::time::Duration::from_millis(500)).await;
1938 let mut state_guard = self.state.lock().await;
1939 if let Some(state) = state_guard.take() {
1940 state.reader_handle.abort();
1941 let mut child = state.child;
1942 let _ = child.kill().await;
1943 }
1944 self.inbox.notification_sinks.lock().await.clear();
1946 self.inbox
1947 .fail_all_pending("transport closed: extension shutdown")
1948 .await;
1949 }
1950
1951 async fn subscribe_notifications(&self) -> (usize, tokio::sync::mpsc::UnboundedReceiver<NotificationFrame>) {
1952 ProcessExtension::subscribe_notifications(self).await
1953 }
1954
1955 async fn restart_count(&self) -> usize {
1956 self.restart_count()
1957 }
1958
1959 async fn health(&self) -> ExtensionHealth {
1960 let count = self.restart_count.load(Ordering::Relaxed);
1961 let max = self.restart_policy.max_attempts as usize;
1962 if count >= max {
1963 ExtensionHealth::Failed
1964 } else if count > 0 {
1965 let state_alive = self.state.try_lock().map(|g| g.is_some()).unwrap_or(true);
1969 if state_alive {
1970 ExtensionHealth::Degraded
1971 } else {
1972 ExtensionHealth::Restarting
1973 }
1974 } else {
1975 ExtensionHealth::Running
1976 }
1977 }
1978}
1979
#[cfg(test)]
mod stream_event_tests {
    //! Unit tests for `parse_provider_stream_event`: accepted payload shapes
    //! first, then payloads that must be rejected.
    use super::*;
    use serde_json::json;

    /// Parse `payload`, panicking with the parser's error if it is rejected.
    fn parse_ok(payload: &serde_json::Value) -> ProviderStreamEvent {
        parse_provider_stream_event(payload).unwrap()
    }

    /// True when the parser refuses `payload`.
    fn rejects(payload: &serde_json::Value) -> bool {
        parse_provider_stream_event(payload).is_err()
    }

    #[test]
    fn parses_text_delta_with_delta_key() {
        let expected = ProviderStreamEvent::TextDelta { text: "hi".into() };
        assert_eq!(parse_ok(&json!({"type": "text", "delta": "hi"})), expected);
    }

    #[test]
    fn parses_text_delta_with_text_key() {
        let expected = ProviderStreamEvent::TextDelta { text: "hi".into() };
        assert_eq!(parse_ok(&json!({"type": "text", "text": "hi"})), expected);
    }

    #[test]
    fn parses_thinking_delta() {
        for payload in [
            json!({"type": "thinking", "delta": "hmm"}),
            json!({"type": "thinking", "text": "hmm"}),
        ] {
            assert_eq!(
                parse_ok(&payload),
                ProviderStreamEvent::ThinkingDelta { text: "hmm".into() }
            );
        }
    }

    #[test]
    fn parses_tool_use() {
        let payload = json!({
            "type": "tool_use",
            "id": "t1",
            "name": "echo",
            "input": {"x": 1}
        });
        let expected = ProviderStreamEvent::ToolUse {
            id: "t1".into(),
            name: "echo".into(),
            input: json!({"x": 1}),
        };
        assert_eq!(parse_ok(&payload), expected);
    }

    #[test]
    fn tool_use_input_defaults_to_empty_object() {
        let payload = json!({"type": "tool_use", "id": "t1", "name": "echo"});
        let expected = ProviderStreamEvent::ToolUse {
            id: "t1".into(),
            name: "echo".into(),
            input: json!({}),
        };
        assert_eq!(parse_ok(&payload), expected);
    }

    #[test]
    fn parses_usage_strips_type() {
        let payload = json!({"type": "usage", "input_tokens": 5, "output_tokens": 7});
        let expected = ProviderStreamEvent::Usage {
            usage: json!({"input_tokens": 5, "output_tokens": 7}),
        };
        assert_eq!(parse_ok(&payload), expected);
    }

    #[test]
    fn parses_error() {
        let expected = ProviderStreamEvent::Error { message: "boom".into() };
        assert_eq!(parse_ok(&json!({"type": "error", "message": "boom"})), expected);
    }

    #[test]
    fn parses_done() {
        assert_eq!(parse_ok(&json!({"type": "done"})), ProviderStreamEvent::Done);
    }

    #[test]
    fn nested_event_shape_matches_flat() {
        let flat = parse_ok(&json!({"type": "text", "delta": "hi"}));
        let nested = parse_ok(&json!({"event": {"type": "text", "delta": "hi"}}));
        assert_eq!(flat, nested);
    }

    #[test]
    fn missing_type_errors() {
        let payload = json!({"delta": "hi"});
        let err = parse_provider_stream_event(&payload).unwrap_err();
        assert!(err.contains("missing type"), "got: {err}");
    }

    #[test]
    fn unknown_type_errors_with_type() {
        let payload = json!({"type": "wat"});
        let err = parse_provider_stream_event(&payload).unwrap_err();
        assert!(err.contains("wat"), "got: {err}");
    }

    #[test]
    fn tool_use_missing_id_errors() {
        let payload = json!({"type": "tool_use", "name": "echo"});
        let err = parse_provider_stream_event(&payload).unwrap_err();
        assert!(err.contains("id"), "got: {err}");
    }

    #[test]
    fn tool_use_missing_name_errors() {
        let payload = json!({"type": "tool_use", "id": "t1"});
        let err = parse_provider_stream_event(&payload).unwrap_err();
        assert!(err.contains("name"), "got: {err}");
    }

    #[test]
    fn tool_use_empty_id_errors() {
        assert!(rejects(&json!({"type": "tool_use", "id": "", "name": "echo"})));
    }

    #[test]
    fn tool_use_empty_name_errors() {
        assert!(rejects(&json!({"type": "tool_use", "id": "t1", "name": ""})));
    }

    #[test]
    fn tool_use_non_object_input_errors() {
        let payload = json!({"type": "tool_use", "id": "t1", "name": "echo", "input": "nope"});
        let err = parse_provider_stream_event(&payload).unwrap_err();
        assert!(err.contains("input"), "got: {err}");
    }

    #[test]
    fn text_missing_delta_and_text_errors() {
        let payload = json!({"type": "text"});
        let err = parse_provider_stream_event(&payload).unwrap_err();
        assert!(err.contains("delta") || err.contains("text"), "got: {err}");
    }

    #[test]
    fn error_missing_message_errors() {
        assert!(rejects(&json!({"type": "error"})));
    }

    #[test]
    fn error_empty_message_errors() {
        assert!(rejects(&json!({"type": "error", "message": ""})));
    }
}
2153
#[cfg(test)]
mod restart_policy_tests {
    //! Checks the default restart policy and `with_restart_policy` overrides,
    //! using `/bin/cat` as a throwaway extension process.
    use super::*;

    #[tokio::test]
    async fn restart_policy_default_max_attempts_is_3() {
        let extension = ProcessExtension::spawn("policy-test", "/bin/cat", &[])
            .await
            .expect("spawn /bin/cat");
        let default_max = extension.restart_policy.max_attempts;
        assert_eq!(default_max, 3);
        extension.shutdown().await;
    }

    #[tokio::test]
    async fn with_restart_policy_overrides_default() {
        let custom = RestartPolicy {
            max_attempts: 7,
            ..RestartPolicy::default()
        };
        let extension = ProcessExtension::spawn("policy-test-override", "/bin/cat", &[])
            .await
            .expect("spawn /bin/cat")
            .with_restart_policy(custom);
        assert_eq!(extension.restart_policy.max_attempts, 7);
        extension.shutdown().await;
    }
}
2185
#[cfg(test)]
mod capture_validator_tests {
    //! Unit tests for `validate_capability`: structural checks on the
    //! declaration, and the rule that every declared permission must be both
    //! known and granted.
    use super::*;
    use crate::extensions::permissions::{Permission, PermissionSet};

    /// Build a permission set containing exactly `grants`.
    fn perms_with(grants: &[Permission]) -> PermissionSet {
        let mut granted = PermissionSet::new();
        for &permission in grants {
            granted.grant(permission);
        }
        granted
    }

    /// Build a capability declaration with a null `params` payload.
    fn cap(kind: &str, name: &str, perms: &[&str]) -> CapabilityDeclaration {
        let permissions = perms.iter().map(|p| (*p).to_owned()).collect();
        CapabilityDeclaration {
            kind: kind.to_owned(),
            name: name.to_owned(),
            permissions,
            params: serde_json::Value::Null,
        }
    }

    #[test]
    fn capability_validator_rejects_empty_kind() {
        let granted = perms_with(&[Permission::AudioInput]);
        let declaration = cap(" ", "Sample", &["audio.input"]);
        let err = validate_capability(&declaration, &granted).unwrap_err();
        assert!(err.contains("kind"), "got: {}", err);
    }

    #[test]
    fn capability_validator_rejects_empty_name() {
        let granted = perms_with(&[Permission::AudioInput]);
        let declaration = cap("capture", " ", &["audio.input"]);
        let err = validate_capability(&declaration, &granted).unwrap_err();
        assert!(err.contains("name"), "got: {}", err);
    }

    #[test]
    fn capability_validator_rejects_unknown_permission_string() {
        let granted = perms_with(&[Permission::AudioInput, Permission::AudioOutput]);
        let declaration = cap("capture", "Sample", &["audio.telepathy"]);
        let err = validate_capability(&declaration, &granted).unwrap_err();
        let mentions_unknown = err.contains("unknown permission") && err.contains("audio.telepathy");
        assert!(mentions_unknown, "got: {}", err);
    }

    #[test]
    fn capability_validator_requires_every_declared_permission() {
        let granted = perms_with(&[]);
        let declaration = cap("capture", "Sample", &["audio.input"]);
        let err = validate_capability(&declaration, &granted).unwrap_err();
        let mentions_missing = err.contains("audio.input") && err.contains("not granted");
        assert!(mentions_missing, "got: {}", err);
    }

    #[test]
    fn capability_validator_accepts_when_all_permissions_granted() {
        let granted = perms_with(&[Permission::AudioInput, Permission::AudioOutput]);
        let declaration = cap("capture", "Sample", &["audio.input", "audio.output"]);
        validate_capability(&declaration, &granted).expect("should validate");
    }

    #[test]
    fn capability_validator_accepts_no_permissions() {
        let granted = perms_with(&[]);
        let declaration = cap("ocr", "Tesseract", &[]);
        validate_capability(&declaration, &granted).expect("should validate");
    }

    #[test]
    fn capability_validator_does_not_branch_on_kind() {
        let granted = perms_with(&[Permission::AudioInput]);
        let kinds = ["capture", "ocr", "agent", "foot_pedal", "eeg"];
        for kind in kinds {
            let declaration = cap(kind, "Anything", &["audio.input"]);
            validate_capability(&declaration, &granted).expect("should validate");
        }
    }
}
2278
#[cfg(test)]
mod invoke_command_dispatch_tests {
    //! Unit tests for `ProcessExtension::forward_invoke_command_frame`, which routes
    //! extension notification frames onto an `InvokeCommandEvent` channel.
    //!
    //! Behaviour pinned by these tests:
    //! * `command.output` frames are forwarded only when their `request_id` matches;
    //! * `task.*` frames are forwarded regardless of the active request id;
    //! * a malformed `command.output` payload is skipped and later frames still flow;
    //! * frames with any other method are dropped.

    use super::*;
    use crate::extensions::commands::CommandOutputEvent;
    use crate::extensions::runtime::InvokeCommandEvent;
    use crate::extensions::tasks::{TaskEvent, TaskKind};
    use serde_json::json;
    use tokio::sync::mpsc;

    /// Builds a `NotificationFrame` for `method` carrying `params`.
    fn frame(method: &str, params: serde_json::Value) -> NotificationFrame {
        let method = method.to_owned();
        NotificationFrame { method, params }
    }

    /// Feeds `frames`, in order, through the dispatcher on behalf of `extension_id` /
    /// `request_id`, sharing one fresh unbounded channel and one `open` flag across the
    /// whole sequence.
    ///
    /// Returns every event that reached the channel (in arrival order), paired with
    /// `true` if any dispatch call reported the command Done marker. The sender is
    /// dropped before draining, so draining stops once the queue is empty.
    fn dispatch_all(
        extension_id: &str,
        request_id: &str,
        frames: impl IntoIterator<Item = NotificationFrame>,
    ) -> (Vec<InvokeCommandEvent>, bool) {
        let (sender, mut receiver) = mpsc::unbounded_channel::<InvokeCommandEvent>();
        let mut open = true;

        // `map` + `fold` (not `any`) so every frame is dispatched, even after Done is seen.
        let saw_done = frames
            .into_iter()
            .map(|f| {
                ProcessExtension::forward_invoke_command_frame(
                    extension_id,
                    request_id,
                    &sender,
                    &mut open,
                    f,
                )
            })
            .fold(false, |seen, done| seen || done);
        drop(sender);

        let forwarded = std::iter::from_fn(|| receiver.try_recv().ok()).collect();
        (forwarded, saw_done)
    }

    #[test]
    fn forwards_mixed_event_stream_in_order() {
        let frames = vec![
            frame(
                "command.output",
                json!({"request_id":"r1","event":{"kind":"text","content":"A"}}),
            ),
            frame(
                "task.start",
                json!({"id":"dl","label":"Downloading","kind":"download"}),
            ),
            frame(
                "task.update",
                json!({"id":"dl","current":50,"total":100}),
            ),
            frame(
                "command.output",
                json!({"request_id":"r1","event":{"kind":"system","content":"working"}}),
            ),
            frame("task.done", json!({"id":"dl"})),
            frame(
                "command.output",
                json!({"request_id":"r1","event":{"kind":"done"}}),
            ),
        ];

        let (events, saw_done) = dispatch_all("ext-test", "r1", frames);
        assert!(saw_done, "should have observed the command Done marker");
        assert_eq!(events.len(), 6);

        // Command output and task events must interleave exactly as the frames arrived.
        assert_eq!(
            events[0],
            InvokeCommandEvent::Output(CommandOutputEvent::Text { content: "A".into() })
        );
        assert!(matches!(
            events[1],
            InvokeCommandEvent::Task(TaskEvent::Start { kind: TaskKind::Download, .. })
        ));
        assert!(matches!(
            events[2],
            InvokeCommandEvent::Task(TaskEvent::Update { .. })
        ));
        assert!(matches!(
            events[3],
            InvokeCommandEvent::Output(CommandOutputEvent::System { .. })
        ));
        assert!(matches!(
            events[4],
            InvokeCommandEvent::Task(TaskEvent::Done { error: None, .. })
        ));
        assert_eq!(events[5], InvokeCommandEvent::Output(CommandOutputEvent::Done));
    }

    #[test]
    fn ignores_command_output_for_unrelated_request_id() {
        let unrelated = frame(
            "command.output",
            json!({"request_id":"other","event":{"kind":"text","content":"x"}}),
        );
        let (events, _) = dispatch_all("ext", "r1", [unrelated]);
        assert!(events.is_empty());
    }

    #[test]
    fn skips_malformed_command_output_without_aborting() {
        let malformed = frame("command.output", json!({"request_id":"r1","event":{}}));
        let done = frame(
            "command.output",
            json!({"request_id":"r1","event":{"kind":"done"}}),
        );

        // Only the well-formed Done frame should survive; the malformed one is skipped.
        let (events, _) = dispatch_all("ext", "r1", [malformed, done]);
        assert_eq!(
            events,
            vec![InvokeCommandEvent::Output(CommandOutputEvent::Done)]
        );
    }

    #[test]
    fn task_events_pass_through_regardless_of_request_id() {
        let log = frame("task.log", json!({"id":"abc","line":"..."}));
        let (events, _) = dispatch_all("ext", "r1", [log]);

        match events.into_iter().next().unwrap() {
            InvokeCommandEvent::Task(TaskEvent::Log { id, line }) => {
                assert_eq!(id, "abc");
                assert_eq!(line, "...");
            }
            other => panic!("unexpected: {other:?}"),
        }
    }

    #[test]
    fn unrelated_methods_are_dropped() {
        let provider_event = frame("provider.stream.event", json!({"type":"text","delta":"x"}));
        let (events, _) = dispatch_all("ext", "r1", [provider_event]);
        assert!(events.is_empty());
    }
}