1use std::path::PathBuf;
7use std::process::Stdio;
8use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
9use std::sync::Arc;
10use std::collections::{HashMap, HashSet};
11
12use async_trait::async_trait;
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
16use tokio::process::{Child, ChildStdin, ChildStdout, Command};
17use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
18use tokio::task::JoinHandle;
19
20use super::{ExtensionHandler, ExtensionHealth, RestartPolicy};
21use crate::extensions::hooks::events::{HookEvent, HookResult};
22use crate::extensions::manifest::CURRENT_EXTENSION_PROTOCOL_VERSION;
23
24#[derive(Serialize)]
25struct JsonRpcRequest {
26 jsonrpc: &'static str,
27 method: String,
28 params: Value,
29 id: u64,
30}
31
32
33#[derive(Serialize)]
34struct InitializeParams {
35 synaps_version: &'static str,
36 extension_protocol_version: u32,
37 plugin_id: String,
38 plugin_root: Option<String>,
39 config: Value,
40}
41
42#[derive(Debug, Clone, PartialEq, Deserialize)]
43pub struct RegisteredExtensionToolSpec {
44 pub name: String,
45 pub description: String,
46 pub input_schema: Value,
47}
48
49#[derive(Debug, Clone, PartialEq, Deserialize)]
50pub struct RegisteredProviderSpec {
51 pub id: String,
52 pub display_name: String,
53 pub description: String,
54 #[serde(default)]
55 pub models: Vec<RegisteredProviderModelSpec>,
56 #[serde(default)]
57 pub config_schema: Option<Value>,
58}
59
60#[derive(Debug, Clone, PartialEq, Deserialize)]
61pub struct RegisteredProviderModelSpec {
62 pub id: String,
63 #[serde(default)]
64 pub display_name: Option<String>,
65 #[serde(default)]
66 pub capabilities: Value,
67 #[serde(default)]
68 pub context_window: Option<u64>,
69}
70
71#[derive(Debug, Clone, Serialize)]
72pub struct ProviderCompleteParams {
73 pub provider_id: String,
74 pub model_id: String,
75 pub model: String,
76 pub messages: Vec<Value>,
77 pub system_prompt: Option<String>,
78 pub tools: Vec<Value>,
79 pub temperature: Option<f32>,
80 pub max_tokens: Option<u32>,
81 pub thinking_budget: u32,
82}
83
84#[derive(Debug, Clone, PartialEq, Deserialize)]
85pub struct ProviderCompleteResult {
86 pub content: Vec<Value>,
87 #[serde(default)]
88 pub stop_reason: Option<String>,
89 #[serde(default)]
90 pub usage: Option<Value>,
91}
92
93#[derive(Debug, Clone, PartialEq)]
94pub struct ProviderToolUse {
95 pub id: String,
96 pub name: String,
97 pub input: Value,
98}
99
100#[derive(Debug, Clone, PartialEq)]
102pub enum ProviderStreamEvent {
103 TextDelta { text: String },
105 ThinkingDelta { text: String },
107 ToolUse {
109 id: String,
110 name: String,
111 input: Value,
112 },
113 Usage { usage: Value },
115 Error { message: String },
117 Done,
119}
120
121pub fn parse_provider_stream_event(params: &Value) -> Result<ProviderStreamEvent, String> {
127 let inner = match params.get("event") {
128 Some(ev) => ev,
129 None => params,
130 };
131 let obj = inner
132 .as_object()
133 .ok_or_else(|| "provider stream event must be a JSON object".to_string())?;
134
135 let ty = obj
136 .get("type")
137 .and_then(Value::as_str)
138 .ok_or_else(|| "provider stream event missing type".to_string())?;
139
140 match ty {
141 "text" => {
142 let text = obj
143 .get("delta")
144 .or_else(|| obj.get("text"))
145 .and_then(Value::as_str)
146 .ok_or_else(|| {
147 "provider stream text event missing 'delta' or 'text'".to_string()
148 })?;
149 Ok(ProviderStreamEvent::TextDelta {
150 text: text.to_string(),
151 })
152 }
153 "thinking" => {
154 let text = obj
155 .get("delta")
156 .or_else(|| obj.get("text"))
157 .and_then(Value::as_str)
158 .ok_or_else(|| {
159 "provider stream thinking event missing 'delta' or 'text'".to_string()
160 })?;
161 Ok(ProviderStreamEvent::ThinkingDelta {
162 text: text.to_string(),
163 })
164 }
165 "tool_use" => {
166 let id = obj
167 .get("id")
168 .and_then(Value::as_str)
169 .ok_or_else(|| "provider stream tool_use missing id".to_string())?;
170 if id.is_empty() {
171 return Err("provider stream tool_use id must be non-empty".to_string());
172 }
173 let name = obj
174 .get("name")
175 .and_then(Value::as_str)
176 .ok_or_else(|| "provider stream tool_use missing name".to_string())?;
177 if name.is_empty() {
178 return Err("provider stream tool_use name must be non-empty".to_string());
179 }
180 let input = match obj.get("input") {
181 None => Value::Object(Default::default()),
182 Some(v) if v.is_object() => v.clone(),
183 Some(_) => {
184 return Err(
185 "provider stream tool_use input must be a JSON object".to_string()
186 );
187 }
188 };
189 Ok(ProviderStreamEvent::ToolUse {
190 id: id.to_string(),
191 name: name.to_string(),
192 input,
193 })
194 }
195 "usage" => {
196 let mut clone = obj.clone();
197 clone.remove("type");
198 Ok(ProviderStreamEvent::Usage {
199 usage: Value::Object(clone),
200 })
201 }
202 "error" => {
203 let message = obj
204 .get("message")
205 .and_then(Value::as_str)
206 .ok_or_else(|| "provider stream error missing message".to_string())?;
207 if message.is_empty() {
208 return Err("provider stream error message must be non-empty".to_string());
209 }
210 Ok(ProviderStreamEvent::Error {
211 message: message.to_string(),
212 })
213 }
214 "done" => Ok(ProviderStreamEvent::Done),
215 other => Err(format!("unknown provider stream event type: {other}")),
216 }
217}
218
219pub async fn execute_provider_tool_use(
220 registry: &crate::ToolRegistry,
221 hook_bus: &Arc<crate::extensions::hooks::HookBus>,
222 tool_use: ProviderToolUse,
223 ctx: crate::ToolContext,
224 max_tool_output: usize,
225) -> Value {
226 let tool_id = tool_use.id;
227 let tool_name = tool_use.name;
228 let input = tool_use.input;
229
230 let Some(tool) = registry.get(&tool_name).cloned() else {
231 return serde_json::json!({
232 "type": "tool_result",
233 "tool_use_id": tool_id,
234 "content": format!("Unknown tool: {}", tool_name),
235 "is_error": true,
236 });
237 };
238
239 let runtime_name = registry.runtime_name_for_api(&tool_name).to_string();
240 let input = registry.translate_input_for_api_tool(&tool_name, input);
241 let decision = crate::runtime::resolve_before_tool_call_decision(
242 input.clone(),
243 crate::runtime::emit_before_tool_call(
244 hook_bus,
245 &tool_name,
246 Some(&runtime_name),
247 input.clone(),
248 ).await,
249 ctx.capabilities.secret_prompt.as_ref(),
250 false,
251 ).await;
252
253 let crate::runtime::BeforeToolCallDecision::Continue { input } = decision else {
254 let crate::runtime::BeforeToolCallDecision::Block { reason } = decision else { unreachable!() };
255 return serde_json::json!({
256 "type": "tool_result",
257 "tool_use_id": tool_id,
258 "content": format!("Tool call blocked by extension: {}", reason),
259 "is_error": true,
260 });
261 };
262
263 let input_for_hook = input.clone();
264 let (result, is_error) = match tool.execute(input, ctx).await {
265 Ok(output) => (output, false),
266 Err(error) => (format!("Tool execution failed: {}", error), true),
267 };
268 let _ = crate::runtime::emit_after_tool_call(
269 hook_bus,
270 &tool_name,
271 Some(&runtime_name),
272 input_for_hook,
273 result.clone(),
274 ).await;
275
276 let mut response = serde_json::json!({
277 "type": "tool_result",
278 "tool_use_id": tool_id,
279 "content": crate::truncate_str(&result, max_tool_output).to_string(),
280 });
281 if is_error {
282 response["is_error"] = serde_json::json!(true);
283 }
284 response
285}
286
287pub async fn complete_provider_with_tools<F>(
288 handler: Arc<dyn ExtensionHandler>,
289 mut params: ProviderCompleteParams,
290 registry: &crate::ToolRegistry,
291 hook_bus: &Arc<crate::extensions::hooks::HookBus>,
292 mut context_factory: F,
293 max_tool_output: usize,
294 max_iterations: usize,
295) -> Result<ProviderCompleteResult, String>
296where
297 F: FnMut() -> crate::ToolContext,
298{
299 let max_iterations = max_iterations.max(1);
300 for iteration in 0..max_iterations {
301 let result = handler.provider_complete(params.clone()).await?;
302 let tool_uses = extract_provider_tool_uses(&result.content)?;
303 if tool_uses.is_empty() {
304 return Ok(result);
305 }
306 if iteration + 1 == max_iterations {
307 return Err(format!(
308 "extension provider '{}' exceeded provider tool-use iteration limit ({})",
309 handler.id(),
310 max_iterations,
311 ));
312 }
313
314 let assistant_content = result.content.clone();
315 params.messages.push(serde_json::json!({
316 "role": "assistant",
317 "content": assistant_content,
318 }));
319
320 let mut tool_results = Vec::with_capacity(tool_uses.len());
321 for tool_use in tool_uses {
322 tool_results.push(execute_provider_tool_use(
323 registry,
324 hook_bus,
325 tool_use,
326 context_factory(),
327 max_tool_output,
328 ).await);
329 }
330 params.messages.push(serde_json::json!({
331 "role": "user",
332 "content": tool_results,
333 }));
334 }
335 Err(format!(
336 "extension provider '{}' exceeded provider tool-use iteration limit ({})",
337 handler.id(),
338 max_iterations,
339 ))
340}
341
342pub fn extract_provider_tool_uses(content: &[Value]) -> Result<Vec<ProviderToolUse>, String> {
343 let mut tool_uses = Vec::new();
344 for block in content {
345 if block.get("type").and_then(Value::as_str) != Some("tool_use") {
346 continue;
347 }
348 let id = block
349 .get("id")
350 .and_then(Value::as_str)
351 .ok_or_else(|| "provider tool_use missing id".to_string())?;
352 let name = block
353 .get("name")
354 .and_then(Value::as_str)
355 .ok_or_else(|| "provider tool_use missing name".to_string())?;
356 if id.trim().is_empty() {
357 return Err("provider tool_use id is empty".to_string());
358 }
359 if name.trim().is_empty() {
360 return Err("provider tool_use name is empty".to_string());
361 }
362 let input = block
363 .get("input")
364 .cloned()
365 .unwrap_or_else(|| serde_json::json!({}));
366 if !input.is_object() {
367 return Err(format!(
368 "provider tool_use '{}' input must be a JSON object",
369 id
370 ));
371 }
372 tool_uses.push(ProviderToolUse {
373 id: id.to_string(),
374 name: name.to_string(),
375 input,
376 });
377 }
378 Ok(tool_uses)
379}
380
381#[derive(Debug, Clone, Default, PartialEq)]
382pub struct InitializeCapabilitiesResult {
383 pub tools: Vec<RegisteredExtensionToolSpec>,
384 pub providers: Vec<RegisteredProviderSpec>,
385 pub capabilities: Vec<CapabilityDeclaration>,
391}
392
393#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
402pub struct CapabilityDeclaration {
403 pub kind: String,
407 pub name: String,
410 #[serde(default)]
415 pub permissions: Vec<String>,
416 #[serde(default, skip_serializing_if = "is_null_value")]
420 pub params: serde_json::Value,
421}
422
423fn is_null_value(v: &serde_json::Value) -> bool {
424 v.is_null()
425}
426
427pub fn validate_capability(
437 decl: &CapabilityDeclaration,
438 granted: &crate::extensions::permissions::PermissionSet,
439) -> Result<(), String> {
440 use crate::extensions::permissions::Permission;
441 if decl.kind.trim().is_empty() {
442 return Err("capability 'kind' must be non-empty".to_string());
443 }
444 if decl.name.trim().is_empty() {
445 return Err("capability 'name' must be non-empty".to_string());
446 }
447 for perm_name in &decl.permissions {
448 let parsed = Permission::parse(perm_name).ok_or_else(|| {
449 format!(
450 "capability '{}' declares unknown permission '{}'",
451 decl.kind, perm_name
452 )
453 })?;
454 if !granted.has(parsed) {
455 return Err(format!(
456 "capability '{}' requires permission '{}' but it is not granted",
457 decl.kind, perm_name
458 ));
459 }
460 }
461 Ok(())
462}
463
464#[derive(Deserialize)]
465struct InitializeResult {
466 protocol_version: u32,
467 #[serde(default)]
468 capabilities: InitializeCapabilities,
469}
470
471#[derive(Default, Deserialize)]
472struct InitializeCapabilities {
473 #[serde(default)]
474 tools: Vec<RegisteredExtensionToolSpec>,
475 #[serde(default)]
476 providers: Vec<RegisteredProviderSpec>,
477 #[serde(default)]
479 capabilities: Vec<CapabilityDeclaration>,
480}
481
482#[doc(hidden)]
488#[derive(Debug, Clone)]
489pub struct NotificationFrame {
490 pub method: String,
491 pub params: Value,
492}
493
494struct Inbox {
501 pending: Mutex<HashMap<u64, oneshot::Sender<Result<Value, String>>>>,
502 notification_sink: Mutex<Option<mpsc::UnboundedSender<NotificationFrame>>>,
503 closed: std::sync::atomic::AtomicBool,
506 permissions: RwLock<Option<crate::extensions::permissions::PermissionSet>>,
509 inbound_stdin: Mutex<Option<Arc<Mutex<ChildStdin>>>>,
513 extension_id: String,
515}
516
517impl Inbox {
518 fn new(extension_id: String) -> Self {
519 Self {
520 pending: Mutex::new(HashMap::new()),
521 notification_sink: Mutex::new(None),
522 closed: std::sync::atomic::AtomicBool::new(false),
523 permissions: RwLock::new(None),
524 inbound_stdin: Mutex::new(None),
525 extension_id,
526 }
527 }
528
529 async fn fail_all_pending(&self, reason: &str) {
532 self.closed.store(true, std::sync::atomic::Ordering::Release);
533 let drained: Vec<_> = {
534 let mut pending = self.pending.lock().await;
535 pending.drain().collect()
536 };
537 for (_, tx) in drained {
538 let _ = tx.send(Err(reason.to_string()));
539 }
540 }
541}
542
543struct ProcessState {
544 child: Child,
545 stdin: Arc<Mutex<ChildStdin>>,
546 reader_handle: JoinHandle<()>,
547}
548
549pub struct ProcessExtension {
551 id: String,
552 command: String,
553 args: Vec<String>,
554 cwd: Option<PathBuf>,
555 state: Arc<Mutex<Option<ProcessState>>>,
556 call_lock: Arc<Mutex<()>>,
558 next_id: AtomicU64,
559 restart_count: AtomicUsize,
560 pub(crate) restart_policy: RestartPolicy,
562 inbox: Arc<Inbox>,
566}
567
568impl ProcessExtension {
569 pub async fn spawn(id: &str, command: &str, args: &[String]) -> Result<Self, String> {
570 Self::spawn_with_cwd(id, command, args, None).await
571 }
572
573 pub async fn spawn_with_cwd(
578 id: &str,
579 command: &str,
580 args: &[String],
581 cwd: Option<PathBuf>,
582 ) -> Result<Self, String> {
583 let inbox = Arc::new(Inbox::new(id.to_string()));
584 let state = Self::spawn_state(id, command, args, cwd.as_ref(), inbox.clone()).await?;
585 Ok(Self {
586 id: id.to_string(),
587 command: command.to_string(),
588 args: args.to_vec(),
589 cwd,
590 state: Arc::new(Mutex::new(Some(state))),
591 call_lock: Arc::new(Mutex::new(())),
592 next_id: AtomicU64::new(1),
593 restart_count: AtomicUsize::new(0),
594 restart_policy: RestartPolicy::default(),
595 inbox,
596 })
597 }
598
599 pub fn with_restart_policy(mut self, policy: RestartPolicy) -> Self {
601 self.restart_policy = policy;
602 self
603 }
604
605 async fn spawn_state(
606 id: &str,
607 command: &str,
608 args: &[String],
609 cwd: Option<&PathBuf>,
610 inbox: Arc<Inbox>,
611 ) -> Result<ProcessState, String> {
612 let mut cmd = Command::new(command);
613 cmd.args(args)
614 .stdin(Stdio::piped())
615 .stdout(Stdio::piped())
616 .stderr(Stdio::piped());
617 if let Some(cwd) = cwd {
618 cmd.current_dir(cwd);
619 }
620
621 cmd.env_clear();
627 for var in &["PATH", "HOME", "LANG", "TERM", "XDG_RUNTIME_DIR"] {
628 if let Ok(val) = std::env::var(var) {
629 cmd.env(var, val);
630 }
631 }
632
633 cmd.kill_on_drop(true);
634
635 let mut child = cmd
636 .spawn()
637 .map_err(|e| format!("Failed to spawn extension '{}': {}", id, e))?;
638
639 let stdin = child
640 .stdin
641 .take()
642 .ok_or_else(|| format!("No stdin for extension '{}'", id))?;
643 let stdout = child
644 .stdout
645 .take()
646 .ok_or_else(|| format!("No stdout for extension '{}'", id))?;
647 if let Some(stderr) = child.stderr.take() {
648 let extension_id = id.to_string();
649 tokio::spawn(async move {
650 let mut lines = BufReader::new(stderr).lines();
651 loop {
652 match lines.next_line().await {
653 Ok(Some(line)) => {
654 tracing::debug!(extension = %extension_id, stderr = %line);
655 }
656 Ok(None) => break,
657 Err(error) => {
658 tracing::debug!(
659 extension = %extension_id,
660 error = %error,
661 "Failed to read extension stderr",
662 );
663 break;
664 }
665 }
666 }
667 });
668 }
669
670 let reader_handle = Self::spawn_reader(stdout, inbox.clone(), id.to_string());
671
672 let stdin_arc = Arc::new(Mutex::new(stdin));
673 *inbox.inbound_stdin.lock().await = Some(stdin_arc.clone());
676
677 Ok(ProcessState {
678 child,
679 stdin: stdin_arc,
680 reader_handle,
681 })
682 }
683
684 fn spawn_reader(
689 stdout: ChildStdout,
690 inbox: Arc<Inbox>,
691 extension_id: String,
692 ) -> JoinHandle<()> {
693 tokio::spawn(async move {
694 let mut reader = BufReader::new(stdout);
695 loop {
696 match Self::read_one_frame(&mut reader, &extension_id).await {
697 Ok(Some(value)) => {
698 Self::dispatch_frame(value, &inbox, &extension_id).await;
699 }
700 Ok(None) => {
701 tracing::debug!(
702 extension = %extension_id,
703 "Extension stdout closed (EOF); failing pending requests",
704 );
705 inbox.fail_all_pending("transport closed: EOF").await;
706 inbox.notification_sink.lock().await.take();
708 return;
709 }
710 Err(error) => {
711 tracing::debug!(
712 extension = %extension_id,
713 error = %error,
714 "Extension transport read error",
715 );
716 inbox
717 .fail_all_pending(&format!("transport error: {}", error))
718 .await;
719 inbox.notification_sink.lock().await.take();
720 return;
721 }
722 }
723 }
724 })
725 }
726
727 async fn read_one_frame(
731 reader: &mut BufReader<ChildStdout>,
732 extension_id: &str,
733 ) -> Result<Option<Value>, String> {
734 let mut content_length: Option<usize> = None;
735 let mut saw_any_header = false;
736 loop {
737 let mut header_line = String::new();
738 let n = reader
739 .read_line(&mut header_line)
740 .await
741 .map_err(|e| format!("Read header error: {}", e))?;
742 if n == 0 {
743 if saw_any_header {
744 return Err("Unexpected EOF while reading response headers".into());
745 }
746 return Ok(None);
747 }
748 saw_any_header = true;
749 if header_line.len() > 1024 {
750 return Err(format!(
751 "Extension '{}' header line too long ({} bytes)",
752 extension_id,
753 header_line.len()
754 ));
755 }
756 let trimmed = header_line.trim();
757 if trimmed.is_empty() {
758 break;
759 }
760 if let Some((name, value)) = trimmed.split_once(':') {
761 if name.trim().eq_ignore_ascii_case("Content-Length") {
762 content_length = Some(value.trim().parse().map_err(|_| {
763 format!("Invalid Content-Length value: {:?}", value.trim())
764 })?);
765 }
766 }
767 }
768 let content_length = content_length.ok_or_else(|| {
769 format!(
770 "Extension '{}' frame missing Content-Length header",
771 extension_id
772 )
773 })?;
774 const MAX_RESPONSE_SIZE: usize = 4 * 1024 * 1024;
775 if content_length > MAX_RESPONSE_SIZE {
776 return Err(format!(
777 "Extension '{}' frame too large: {} bytes (max {})",
778 extension_id, content_length, MAX_RESPONSE_SIZE
779 ));
780 }
781 let mut buf = vec![0u8; content_length];
782 tokio::io::AsyncReadExt::read_exact(reader, &mut buf)
783 .await
784 .map_err(|e| format!("Read body error: {}", e))?;
785 let value: Value = serde_json::from_slice(&buf)
786 .map_err(|e| format!("Parse frame error: {}", e))?;
787 Ok(Some(value))
788 }
789
790 async fn dispatch_frame(value: Value, inbox: &Arc<Inbox>, extension_id: &str) {
796 let id_field = value.get("id");
797 let id_is_present = !matches!(id_field, None | Some(Value::Null));
798 let method_field = value.get("method").and_then(Value::as_str).map(str::to_string);
799
800 if id_is_present && method_field.is_some() {
801 let id = match id_field.and_then(Value::as_u64) {
804 Some(id) => id,
805 None => {
806 tracing::trace!(
807 extension = %extension_id,
808 frame = %value,
809 "Discarding inbound request with non-numeric id",
810 );
811 return;
812 }
813 };
814 let method = method_field.unwrap();
815 let params = value.get("params").cloned().unwrap_or(Value::Null);
816 let inbox = inbox.clone();
817 let extension_id = extension_id.to_string();
818 tokio::spawn(async move {
819 let outcome = Self::handle_inbound_request(&inbox, &method, params).await;
820 let payload = match outcome {
821 Ok(result) => serde_json::json!({
822 "jsonrpc": "2.0",
823 "id": id,
824 "result": result,
825 }),
826 Err((code, message)) => serde_json::json!({
827 "jsonrpc": "2.0",
828 "id": id,
829 "error": {"code": code, "message": message},
830 }),
831 };
832 let stdin_handle = inbox.inbound_stdin.lock().await.clone();
833 if let Some(stdin) = stdin_handle {
834 let body = match serde_json::to_string(&payload) {
835 Ok(s) => s,
836 Err(error) => {
837 tracing::warn!(
838 extension = %extension_id,
839 error = %error,
840 "Failed to serialize inbound response",
841 );
842 return;
843 }
844 };
845 let frame = format!("Content-Length: {}\r\n\r\n{}", body.len(), body);
846 let mut stdin = stdin.lock().await;
847 if let Err(error) = stdin.write_all(frame.as_bytes()).await {
848 tracing::warn!(
849 extension = %extension_id,
850 error = %error,
851 "Failed to write inbound response",
852 );
853 return;
854 }
855 if let Err(error) = stdin.flush().await {
856 tracing::warn!(
857 extension = %extension_id,
858 error = %error,
859 "Failed to flush inbound response",
860 );
861 }
862 } else {
863 tracing::warn!(
864 extension = %extension_id,
865 "No stdin available to reply to inbound request",
866 );
867 }
868 });
869 return;
870 }
871
872 if id_is_present {
873 let id = match id_field.and_then(Value::as_u64) {
874 Some(id) => id,
875 None => {
876 tracing::trace!(
877 extension = %extension_id,
878 frame = %value,
879 "Discarding frame with non-numeric id",
880 );
881 return;
882 }
883 };
884 let sender = inbox.pending.lock().await.remove(&id);
885 match sender {
886 Some(tx) => {
887 let payload = if let Some(err) = value.get("error") {
888 let message = err
889 .get("message")
890 .and_then(Value::as_str)
891 .unwrap_or("unknown extension error")
892 .to_string();
893 Err(format!("Extension error: {}", message))
894 } else {
895 Ok(value
896 .get("result")
897 .cloned()
898 .unwrap_or(Value::Null))
899 };
900 let _ = tx.send(payload);
901 }
902 None => {
903 tracing::trace!(
904 extension = %extension_id,
905 id = id,
906 "Response with unknown id (no pending request); dropping",
907 );
908 }
909 }
910 } else if let Some(method) = value.get("method").and_then(Value::as_str) {
911 let params = value.get("params").cloned().unwrap_or(Value::Null);
912 let frame = NotificationFrame {
913 method: method.to_string(),
914 params,
915 };
916 let mut sink_guard = inbox.notification_sink.lock().await;
917 if let Some(sink) = sink_guard.as_ref() {
918 if sink.send(frame).is_err() {
919 sink_guard.take();
921 }
922 } else {
923 tracing::trace!(
924 extension = %extension_id,
925 method = %method,
926 "Notification with no active subscriber; dropping",
927 );
928 }
929 } else {
930 tracing::trace!(
931 extension = %extension_id,
932 frame = %value,
933 "Unrecognized frame; dropping",
934 );
935 }
936 }
937
938 pub fn restart_count(&self) -> usize {
939 self.restart_count.load(Ordering::Relaxed)
940 }
941
942 pub async fn set_permissions(&self, perms: crate::extensions::permissions::PermissionSet) {
945 *self.inbox.permissions.write().await = Some(perms);
946 }
947
948 async fn handle_inbound_request(
956 inbox: &Arc<Inbox>,
957 method: &str,
958 params: Value,
959 ) -> Result<Value, (i32, String)> {
960 use crate::extensions::permissions::Permission;
961 use crate::memory::store::{self, MemoryQuery};
962
963 match method {
964 "memory.append" => {
965 Self::require_permission(inbox, Permission::MemoryWrite, "memory.write").await?;
966 let namespace = Self::param_str(¶ms, "namespace")?;
967 Self::require_namespace_matches(inbox, &namespace).await?;
968 let content = Self::param_str(¶ms, "content")?;
969 let tags = match params.get("tags") {
970 None | Some(Value::Null) => Vec::new(),
971 Some(Value::Array(arr)) => {
972 let mut out = Vec::with_capacity(arr.len());
973 for v in arr {
974 match v.as_str() {
975 Some(s) => out.push(s.to_string()),
976 None => {
977 return Err((
978 -32602,
979 "tags must be an array of strings".to_string(),
980 ))
981 }
982 }
983 }
984 out
985 }
986 _ => {
987 return Err((
988 -32602,
989 "tags must be an array of strings".to_string(),
990 ))
991 }
992 };
993 let meta = match params.get("meta") {
994 None | Some(Value::Null) => None,
995 Some(v) => Some(v.clone()),
996 };
997 let record = store::new_record(namespace, content, tags, meta);
998 let timestamp_ms = record.timestamp_ms;
999 store::append(&record).map_err(|e| (-32000, e.to_string()))?;
1000 Ok(serde_json::json!({"ok": true, "timestamp_ms": timestamp_ms}))
1001 }
1002 "memory.query" => {
1003 Self::require_permission(inbox, Permission::MemoryRead, "memory.read").await?;
1004 let namespace = Self::param_str(¶ms, "namespace")?;
1005 Self::require_namespace_matches(inbox, &namespace).await?;
1006 let q = MemoryQuery {
1007 content_contains: params
1008 .get("content_contains")
1009 .and_then(Value::as_str)
1010 .map(str::to_string),
1011 tag_prefix: params
1012 .get("tag_prefix")
1013 .and_then(Value::as_str)
1014 .map(str::to_string),
1015 since_ms: params.get("since_ms").and_then(Value::as_u64),
1016 until_ms: params.get("until_ms").and_then(Value::as_u64),
1017 limit: params
1018 .get("limit")
1019 .and_then(Value::as_u64)
1020 .map(|n| n as usize),
1021 };
1022 let records = store::query(&namespace, &q).map_err(|e| (-32000, e.to_string()))?;
1023 Ok(serde_json::json!({"records": records}))
1024 }
1025 "config.get" => {
1026 let key = Self::param_str(¶ms, "key")?;
1027 Self::validate_config_key(&key)?;
1028 let value = crate::extensions::config_store::read_plugin_config(&inbox.extension_id, &key);
1029 Ok(serde_json::json!({"value": value}))
1030 }
1031 "config.set" => {
1032 Self::require_permission(inbox, Permission::ConfigWrite, "config.write").await?;
1033 let key = Self::param_str(¶ms, "key")?;
1034 Self::validate_config_key(&key)?;
1035 let value = Self::param_str(¶ms, "value")?;
1036 crate::extensions::config_store::write_plugin_config(&inbox.extension_id, &key, &value)
1037 .map_err(|e| (-32000, e.to_string()))?;
1038 Ok(serde_json::json!({"ok": true}))
1039 }
1040 "config.subscribe" => {
1041 Self::require_permission(inbox, Permission::ConfigSubscribe, "config.subscribe").await?;
1042 Ok(serde_json::json!({"ok": true}))
1046 }
1047 other => Err((-32601, format!("method not found: {other}"))),
1048 }
1049 }
1050
1051 async fn require_permission(
1052 inbox: &Arc<Inbox>,
1053 perm: crate::extensions::permissions::Permission,
1054 wire: &str,
1055 ) -> Result<(), (i32, String)> {
1056 let guard = inbox.permissions.read().await;
1057 match guard.as_ref() {
1058 Some(set) if set.has(perm) => Ok(()),
1059 _ => Err((
1060 -32602,
1061 format!("permission denied: {wire} required"),
1062 )),
1063 }
1064 }
1065
1066 async fn require_namespace_matches(
1067 inbox: &Arc<Inbox>,
1068 namespace: &str,
1069 ) -> Result<(), (i32, String)> {
1070 if namespace == inbox.extension_id {
1071 Ok(())
1072 } else {
1073 Err((
1074 -32602,
1075 format!(
1076 "namespace must equal extension id '{}' (got '{}')",
1077 inbox.extension_id, namespace
1078 ),
1079 ))
1080 }
1081 }
1082
1083 fn param_str(params: &Value, name: &str) -> Result<String, (i32, String)> {
1084 params
1085 .get(name)
1086 .and_then(Value::as_str)
1087 .map(str::to_string)
1088 .ok_or_else(|| (-32602, format!("missing or invalid '{name}' parameter")))
1089 }
1090
1091 fn validate_config_key(key: &str) -> Result<(), (i32, String)> {
1092 let trimmed = key.trim();
1093 if trimmed.is_empty() {
1094 return Err((-32602, "config key must be non-empty".to_string()));
1095 }
1096 if trimmed.contains('.') || trimmed.contains('/') || trimmed.contains(' ') {
1097 return Err((
1098 -32602,
1099 "config key must not contain dots, slashes, or spaces".to_string(),
1100 ));
1101 }
1102 Ok(())
1103 }
1104
1105 pub async fn initialize(&self, plugin_root: Option<PathBuf>, config: Value) -> Result<InitializeCapabilitiesResult, String> {
1106 let params = InitializeParams {
1107 synaps_version: env!("CARGO_PKG_VERSION"),
1108 extension_protocol_version: CURRENT_EXTENSION_PROTOCOL_VERSION,
1109 plugin_id: self.id.clone(),
1110 plugin_root: plugin_root
1111 .or_else(|| self.cwd.clone())
1112 .map(|path| path.to_string_lossy().to_string()),
1113 config,
1114 };
1115 let value = self.call_no_restart("initialize", serde_json::to_value(params).map_err(|e| e.to_string())?).await?;
1116 Self::parse_initialize_result(&self.id, value)
1117 }
1118
1119 fn parse_initialize_result(id: &str, value: Value) -> Result<InitializeCapabilitiesResult, String> {
1120 let result: InitializeResult = serde_json::from_value(value)
1121 .map_err(|e| format!("Invalid initialize response from extension '{}': {}", id, e))?;
1122 if result.protocol_version != CURRENT_EXTENSION_PROTOCOL_VERSION {
1123 return Err(format!(
1124 "Extension '{}' initialize returned unsupported protocol_version {} (supported: {})",
1125 id, result.protocol_version, CURRENT_EXTENSION_PROTOCOL_VERSION,
1126 ));
1127 }
1128 Self::validate_registered_tool_specs(id, &result.capabilities.tools)?;
1129 Self::validate_registered_provider_specs(id, &result.capabilities.providers)?;
1130 Ok(InitializeCapabilitiesResult {
1131 tools: result.capabilities.tools,
1132 providers: result.capabilities.providers,
1133 capabilities: result.capabilities.capabilities,
1134 })
1135 }
1136
1137 fn validate_registered_tool_specs(id: &str, tools: &[RegisteredExtensionToolSpec]) -> Result<(), String> {
1138 use crate::extensions::validation::{validate_id_segment, IdValidationError};
1139 let mut names = HashSet::new();
1140 for tool in tools {
1141 let name = tool.name.trim();
1142 if let Err(err) = validate_id_segment(name) {
1143 return Err(match err {
1144 IdValidationError::Empty => format!(
1145 "Extension '{}' registered a tool with an empty tool name",
1146 id
1147 ),
1148 IdValidationError::ContainsReserved { ch } => format!(
1149 "Extension '{}' registered tool '{}' with invalid tool name: '{}' is reserved",
1150 id, name, ch
1151 ),
1152 IdValidationError::TooLong { len, max } => format!(
1153 "Extension '{}' registered tool '{}' with invalid tool name: must be at most {} chars (got {})",
1154 id, name, max, len
1155 ),
1156 IdValidationError::ContainsWhitespace => format!(
1157 "Extension '{}' registered tool '{}' with invalid tool name: must not contain whitespace",
1158 id, name
1159 ),
1160 IdValidationError::ContainsControl { ch } => format!(
1161 "Extension '{}' registered tool '{}' with invalid tool name: contains control character U+{:04X}",
1162 id, name, ch as u32
1163 ),
1164 });
1165 }
1166 if !names.insert(name.to_string()) {
1167 return Err(format!("Extension '{}' registered duplicate tool name '{}'", id, name));
1168 }
1169 if tool.description.trim().is_empty() {
1170 return Err(format!(
1171 "Extension '{}' registered tool '{}' with an empty description",
1172 id, name,
1173 ));
1174 }
1175 if !tool.input_schema.is_object() {
1176 return Err(format!(
1177 "Extension '{}' registered tool '{}' with invalid input_schema: input_schema must be a JSON object",
1178 id, name,
1179 ));
1180 }
1181 }
1182 Ok(())
1183 }
1184
1185 fn validate_registered_provider_specs(id: &str, providers: &[RegisteredProviderSpec]) -> Result<(), String> {
1186 use crate::extensions::validation::{validate_id_segment, IdValidationError};
1187 for provider in providers {
1188 let provider_id = provider.id.trim();
1189 match validate_id_segment(provider_id) {
1190 Ok(()) => {
1191 if !Self::is_safe_provider_id(provider_id) {
1192 return Err(format!(
1193 "Extension '{}' registered provider '{}' with invalid provider id",
1194 id, provider_id
1195 ));
1196 }
1197 }
1198 Err(IdValidationError::Empty) => {
1199 return Err(format!(
1200 "Extension '{}' registered provider with empty provider id",
1201 id
1202 ));
1203 }
1204 Err(err) => {
1205 return Err(format!(
1206 "Extension '{}' registered provider '{}' with invalid provider id: {}",
1207 id, provider_id, err
1208 ));
1209 }
1210 }
1211 if provider.display_name.trim().is_empty() {
1212 return Err(format!(
1213 "Extension '{}' registered provider '{}' with empty display_name",
1214 id, provider_id,
1215 ));
1216 }
1217 if provider.description.trim().is_empty() {
1218 return Err(format!(
1219 "Extension '{}' registered provider '{}' with empty description",
1220 id, provider_id,
1221 ));
1222 }
1223 if provider.models.is_empty() {
1224 return Err(format!(
1225 "Extension '{}' registered provider '{}' must declare at least one model",
1226 id, provider_id,
1227 ));
1228 }
1229 let mut model_ids = HashSet::new();
1230 for model in &provider.models {
1231 let model_id = model.id.trim();
1232 if let Err(err) = validate_id_segment(model_id) {
1233 return Err(match err {
1234 IdValidationError::Empty => format!(
1235 "Extension '{}' registered provider '{}' with empty model id",
1236 id, provider_id
1237 ),
1238 IdValidationError::ContainsReserved { ch } => format!(
1239 "Extension '{}' registered provider '{}' with invalid model id '{}': '{}' is reserved",
1240 id, provider_id, model_id, ch
1241 ),
1242 IdValidationError::TooLong { len, max } => format!(
1243 "Extension '{}' registered provider '{}' with invalid model id '{}': must be at most {} chars (got {})",
1244 id, provider_id, model_id, max, len
1245 ),
1246 IdValidationError::ContainsWhitespace => format!(
1247 "Extension '{}' registered provider '{}' with invalid model id '{}': must not contain whitespace",
1248 id, provider_id, model_id
1249 ),
1250 IdValidationError::ContainsControl { ch } => format!(
1251 "Extension '{}' registered provider '{}' with invalid model id '{}': contains control character U+{:04X}",
1252 id, provider_id, model_id, ch as u32
1253 ),
1254 });
1255 }
1256 if !model_ids.insert(model_id.to_string()) {
1257 return Err(format!(
1258 "Extension '{}' registered provider '{}' with duplicate model id '{}'",
1259 id, provider_id, model_id,
1260 ));
1261 }
1262 }
1263 if let Some(config_schema) = &provider.config_schema {
1264 if !config_schema.is_object() {
1265 return Err(format!(
1266 "Extension '{}' registered provider '{}' with invalid config_schema: config_schema must be a JSON object",
1267 id, provider_id,
1268 ));
1269 }
1270 }
1271 }
1272 Ok(())
1273 }
1274
1275 fn is_safe_provider_id(id: &str) -> bool {
1276 !id.is_empty()
1277 && !id.contains(':')
1278 && id.chars().all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-' || c == '_')
1279 }
1280
1281 #[doc(hidden)]
1282 pub async fn initialize_for_test(&self, plugin_root: Option<PathBuf>) -> Result<(), String> {
1283 self.initialize(plugin_root, Value::Object(Default::default())).await.map(|_| ())
1284 }
1285
    /// Tear down the current extension process (if any) and start a fresh one.
    ///
    /// `state` is the slot guarded by the `state` mutex, which the caller holds.
    /// Steps, in order:
    /// 1. Bump `restart_count`. If this attempt exceeds
    ///    `restart_policy.max_attempts`, clear `state` and fail without
    ///    respawning. The counter is not reset on this path.
    /// 2. Abort the old reader task and kill the old child process.
    /// 3. Fail every pending request with a "process restarting" error.
    /// 4. Sleep for the policy's back-off delay for this attempt (if non-zero).
    /// 5. Spawn a new process, clear the inbox `closed` flag, and run the
    ///    `initialize` handshake on it.
    /// 6. On success, reset `restart_count` to 0.
    ///
    /// # Errors
    /// Returns an error when the restart limit is exceeded, or when spawning
    /// or initializing the new process fails. `restart_count` stays
    /// incremented in those cases. After a spawn failure `state` is `None`;
    /// after an initialize failure `state` holds the new, uninitialized
    /// process.
    ///
    /// NOTE(review): `initialize_locked` always sends an empty config object.
    /// If the original `initialize()` call was given a non-empty config, it is
    /// not replayed after a restart. Confirm this is intended.
    async fn restart_locked(&self, state: &mut Option<ProcessState>) -> Result<(), String> {
        let attempted = self.restart_count.fetch_add(1, Ordering::Relaxed) + 1;
        let max_attempts = self.restart_policy.max_attempts;
        if attempted > max_attempts as usize {
            *state = None;
            return Err(format!(
                "Extension '{}' exceeded restart limit ({})",
                self.id, max_attempts,
            ));
        }

        // Stop the old process before failing its pending requests.
        if let Some(old) = state.take() {
            old.reader_handle.abort();
            let mut child = old.child;
            let _ = child.kill().await;
        }
        self.inbox
            .fail_all_pending("transport closed: process restarting")
            .await;

        // `attempted <= max_attempts` here, so the narrowing cast cannot lose
        // information as long as `max_attempts` itself fits in a u32.
        let delay = self
            .restart_policy
            .delay_for_attempt(attempted as u32)
            .unwrap_or_default();

        tracing::warn!(
            extension = %self.id,
            attempt = attempted,
            max_attempts = max_attempts,
            delay_ms = delay.as_millis() as u64,
            "Restarting extension process after transport failure",
        );

        if !delay.is_zero() {
            tokio::time::sleep(delay).await;
        }

        *state = Some(Self::spawn_state(
            &self.id,
            &self.command,
            &self.args,
            self.cwd.as_ref(),
            self.inbox.clone(),
        ).await?);
        // Re-open the inbox for the new process before the handshake is sent.
        self.inbox.closed.store(false, std::sync::atomic::Ordering::Release);
        self.initialize_locked(state).await?;
        // A successful restart clears the failure streak.
        self.restart_count.store(0, Ordering::Relaxed);
        Ok(())
    }
1339
1340
1341 async fn initialize_locked(&self, state: &mut Option<ProcessState>) -> Result<(), String> {
1342 let params = InitializeParams {
1343 synaps_version: env!("CARGO_PKG_VERSION"),
1344 extension_protocol_version: CURRENT_EXTENSION_PROTOCOL_VERSION,
1345 plugin_id: self.id.clone(),
1346 plugin_root: self.cwd
1347 .clone()
1348 .map(|path| path.to_string_lossy().to_string()),
1349 config: Value::Object(Default::default()),
1350 };
1351 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
1352 let value = tokio::time::timeout(
1353 std::time::Duration::from_secs(10),
1354 self.call_once_locked(
1355 state.as_mut().expect("state should exist for initialize"),
1356 "initialize",
1357 serde_json::to_value(params).map_err(|e| e.to_string())?,
1358 id,
1359 ),
1360 )
1361 .await
1362 .map_err(|_| format!("Extension '{}' initialize timed out after 10s", self.id))?
1363 ?;
1364 Self::parse_initialize_result(&self.id, value).map(|_| ())
1365 }
1366
    /// Send a single JSON-RPC request to the process and wait for its response.
    ///
    /// No restart or retry happens here; callers (`call_inner`,
    /// `call_no_restart`, `initialize_locked`) supply the process state and
    /// the request id.
    ///
    /// Sequence:
    /// 1. Serialize the request body.
    /// 2. Refuse to start if the inbox is already marked closed.
    /// 3. Register a oneshot sender under `id` in `inbox.pending` *before*
    ///    writing, so a fast reply cannot arrive unregistered.
    /// 4. Re-check `closed` after registering. If the inbox shut down in
    ///    between, unregister and fail.
    /// 5. Write a `Content-Length: N\r\n\r\n<body>` frame (N is the body's
    ///    byte length) and flush it.
    /// 6. Await the oneshot. If its sender was dropped, unregister and fail.
    ///
    /// # Errors
    /// Serialization, write/flush, and closed-transport failures are returned
    /// as strings. Otherwise the payload delivered on the oneshot (which may
    /// itself be an `Err`) is returned unchanged.
    ///
    /// NOTE(review): if this future is dropped after step 3 (e.g. by an outer
    /// `tokio::time::timeout`), nothing here removes the `pending` entry for
    /// `id`, and a write cut off mid-frame could leave a partial frame on
    /// stdin. Confirm the inbox/reader side tolerates this.
    async fn call_once_locked(
        &self,
        state: &mut ProcessState,
        method: &str,
        params: Value,
        id: u64,
    ) -> Result<Value, String> {
        let body = serde_json::to_string(&JsonRpcRequest {
            jsonrpc: "2.0",
            method: method.to_string(),
            params,
            id,
        })
        .map_err(|e| format!("Serialize error: {}", e))?;

        let (tx, rx) = oneshot::channel::<Result<Value, String>>();
        if self.inbox.closed.load(std::sync::atomic::Ordering::Acquire) {
            return Err("transport closed: inbox is shut down".to_string());
        }

        // Register before writing so the reader can always route the reply.
        self.inbox.pending.lock().await.insert(id, tx);

        // The inbox may have been closed while we were registering.
        if self.inbox.closed.load(std::sync::atomic::Ordering::Acquire) {
            self.inbox.pending.lock().await.remove(&id);
            return Err("transport closed: inbox shut down during registration".to_string());
        }

        // `body.len()` is the UTF-8 byte length, as Content-Length requires.
        let frame = format!("Content-Length: {}\r\n\r\n{}", body.len(), body);
        let write_result = {
            let mut stdin = state.stdin.lock().await;
            match stdin.write_all(frame.as_bytes()).await {
                Ok(()) => stdin.flush().await,
                Err(e) => Err(e),
            }
        };
        if let Err(e) = write_result {
            self.inbox.pending.lock().await.remove(&id);
            return Err(format!("Write error: {}", e));
        }

        match rx.await {
            Ok(payload) => payload,
            Err(_) => {
                // Sender dropped without replying; make sure the slot is gone.
                self.inbox.pending.lock().await.remove(&id);
                Err("transport closed: response channel dropped".to_string())
            }
        }
    }
1429
1430 async fn call_no_restart(&self, method: &str, params: Value) -> Result<Value, String> {
1431 let _call_guard = self.call_lock.lock().await;
1432 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
1433 let mut state_guard = self.state.lock().await;
1434 if state_guard.is_none() {
1435 *state_guard = Some(Self::spawn_state(
1436 &self.id,
1437 &self.command,
1438 &self.args,
1439 self.cwd.as_ref(),
1440 self.inbox.clone(),
1441 ).await?);
1442 }
1443 self.call_once_locked(
1444 state_guard.as_mut().expect("state should exist"),
1445 method,
1446 params,
1447 id,
1448 ).await
1449 }
1450
1451 async fn call(&self, method: &str, params: Value) -> Result<Value, String> {
1452 let timeout_secs = if method == "tool.call" { 120 } else { 30 };
1453 let id_str = self.id.clone();
1454 let method_str = method.to_string();
1455
1456 let result = tokio::time::timeout(
1457 std::time::Duration::from_secs(timeout_secs),
1458 self.call_inner(method, params),
1459 )
1460 .await;
1461
1462 match result {
1463 Ok(inner) => inner,
1464 Err(_) => Err(format!(
1465 "Extension '{}' method '{}' timed out after {}s",
1466 id_str, method_str, timeout_secs
1467 )),
1468 }
1469 }
1470
1471 async fn call_inner(&self, method: &str, params: Value) -> Result<Value, String> {
1472 let _call_guard = self.call_lock.lock().await;
1473 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
1474 let mut state_guard = self.state.lock().await;
1475 if state_guard.is_none() {
1476 self.restart_locked(&mut state_guard).await?;
1477 }
1478
1479 let result = self
1480 .call_once_locked(
1481 state_guard.as_mut().expect("state should exist after restart"),
1482 method,
1483 params.clone(),
1484 id,
1485 )
1486 .await;
1487
1488 match result {
1489 Ok(value) => Ok(value),
1490 Err(first_error) => {
1491 self.restart_locked(&mut state_guard).await?;
1492 let retry_id = self.next_id.fetch_add(1, Ordering::Relaxed);
1493 self.call_once_locked(
1494 state_guard.as_mut().expect("state should exist after restart"),
1495 method,
1496 params,
1497 retry_id,
1498 )
1499 .await
1500 .map_err(|retry_error| {
1501 format!("{}; retry after restart failed: {}", first_error, retry_error)
1502 })
1503 }
1504 }
1505 }
1506
1507 #[doc(hidden)]
1520 pub async fn subscribe_notifications(&self) -> mpsc::UnboundedReceiver<NotificationFrame> {
1521 let (tx, rx) = mpsc::unbounded_channel();
1522 let mut sink = self.inbox.notification_sink.lock().await;
1523 *sink = Some(tx);
1524 rx
1525 }
1526
1527 #[doc(hidden)]
1529 pub async fn unsubscribe_notifications(&self) {
1530 self.inbox.notification_sink.lock().await.take();
1531 }
1532
1533 pub(crate) fn forward_invoke_command_frame(
1543 extension_id: &str,
1544 request_id: &str,
1545 sink: &mpsc::UnboundedSender<crate::extensions::runtime::InvokeCommandEvent>,
1546 sink_open: &mut bool,
1547 frame: NotificationFrame,
1548 ) -> bool {
1549 use crate::extensions::commands::parse_command_output;
1550 use crate::extensions::tasks::{is_task_method, parse_task_event};
1551 use crate::extensions::runtime::InvokeCommandEvent;
1552
1553 let mut saw_done = false;
1554 if frame.method == "command.output" {
1555 match parse_command_output(&frame.params) {
1556 Ok(parsed) if parsed.request_id == request_id => {
1557 if matches!(parsed.event, crate::extensions::commands::CommandOutputEvent::Done) {
1558 saw_done = true;
1559 }
1560 if *sink_open && sink.send(InvokeCommandEvent::Output(parsed.event)).is_err() {
1561 *sink_open = false;
1562 }
1563 }
1564 Ok(_) => {
1565 tracing::trace!(
1567 extension = %extension_id,
1568 "Ignoring command.output for unrelated request_id",
1569 );
1570 }
1571 Err(error) => {
1572 tracing::warn!(
1573 extension = %extension_id,
1574 error = %error,
1575 params = %frame.params,
1576 "Skipping malformed command.output notification",
1577 );
1578 }
1579 }
1580 } else if is_task_method(&frame.method) {
1581 match parse_task_event(&frame.method, &frame.params) {
1582 Ok(event) => {
1583 if *sink_open && sink.send(InvokeCommandEvent::Task(event)).is_err() {
1584 *sink_open = false;
1585 }
1586 }
1587 Err(error) => {
1588 tracing::warn!(
1589 extension = %extension_id,
1590 method = %frame.method,
1591 error = %error,
1592 params = %frame.params,
1593 "Skipping malformed task notification",
1594 );
1595 }
1596 }
1597 } else {
1598 tracing::trace!(
1599 extension = %extension_id,
1600 method = %frame.method,
1601 "Ignoring non-command/task notification during command.invoke",
1602 );
1603 }
1604 saw_done
1605 }
1606
1607 fn forward_provider_stream_frame(
1614 extension_id: &str,
1615 sink: &mpsc::UnboundedSender<ProviderStreamEvent>,
1616 sink_open: &mut bool,
1617 frame: NotificationFrame,
1618 ) {
1619 if frame.method != "provider.stream.event" {
1620 tracing::trace!(
1621 extension = %extension_id,
1622 method = %frame.method,
1623 "Ignoring non-stream notification during provider.stream",
1624 );
1625 return;
1626 }
1627 match parse_provider_stream_event(&frame.params) {
1628 Ok(event) => {
1629 if *sink_open && sink.send(event).is_err() {
1630 *sink_open = false;
1631 }
1632 }
1633 Err(error) => {
1634 tracing::warn!(
1635 extension = %extension_id,
1636 error = %error,
1637 params = %frame.params,
1638 "Skipping malformed provider.stream.event notification",
1639 );
1640 }
1641 }
1642 }
1643}
1644
1645#[async_trait]
1646impl ExtensionHandler for ProcessExtension {
1647 fn id(&self) -> &str {
1648 &self.id
1649 }
1650
1651 async fn call_tool(&self, name: &str, input: Value) -> Result<Value, String> {
1652 self.call("tool.call", serde_json::json!({
1653 "name": name,
1654 "input": input,
1655 })).await
1656 }
1657
1658 async fn provider_complete(&self, params: ProviderCompleteParams) -> Result<ProviderCompleteResult, String> {
1659 let value = tokio::time::timeout(
1660 std::time::Duration::from_secs(60),
1661 self.call("provider.complete", serde_json::to_value(params).map_err(|e| e.to_string())?),
1662 )
1663 .await
1664 .map_err(|_| format!("Extension '{}' provider.complete timed out", self.id))??;
1665 let result: ProviderCompleteResult = serde_json::from_value(value)
1666 .map_err(|e| format!("Invalid provider.complete response from extension '{}': {}", self.id, e))?;
1667 if result.content.is_empty() {
1668 return Err(format!("Extension '{}' provider.complete returned empty content", self.id));
1669 }
1670 Ok(result)
1671 }
1672
1673 async fn provider_stream(
1674 &self,
1675 params: ProviderCompleteParams,
1676 sink: tokio::sync::mpsc::UnboundedSender<ProviderStreamEvent>,
1677 ) -> Result<ProviderCompleteResult, String> {
1678 let mut rx = self.subscribe_notifications().await;
1681 let params_value =
1682 serde_json::to_value(params).map_err(|e| e.to_string())?;
1683
1684 let extension_id = self.id.clone();
1685 let stream_future = async {
1686 let mut call_fut = Box::pin(self.call("provider.stream", params_value));
1687 let mut sink_open = true;
1688 let response = loop {
1689 tokio::select! {
1690 response = &mut call_fut => break response,
1691 Some(frame) = rx.recv() => {
1692 Self::forward_provider_stream_frame(
1693 &extension_id, &sink, &mut sink_open, frame,
1694 );
1695 }
1696 }
1697 };
1698 self.unsubscribe_notifications().await;
1702 while let Some(frame) = rx.recv().await {
1703 Self::forward_provider_stream_frame(
1704 &extension_id, &sink, &mut sink_open, frame,
1705 );
1706 }
1707 response
1708 };
1709
1710 let outcome = tokio::time::timeout(
1711 std::time::Duration::from_secs(60),
1712 stream_future,
1713 )
1714 .await;
1715
1716 self.unsubscribe_notifications().await;
1718
1719 let value = outcome
1720 .map_err(|_| format!("Extension '{}' provider.stream timed out", self.id))??;
1721
1722 let result: ProviderCompleteResult = serde_json::from_value(value)
1723 .map_err(|e| {
1724 format!("Invalid provider.stream response from extension '{}': {}", self.id, e)
1725 })?;
1726 Ok(result)
1729 }
1730
1731 async fn invoke_command(
1732 &self,
1733 command: &str,
1734 args: Vec<String>,
1735 request_id: &str,
1736 sink: tokio::sync::mpsc::UnboundedSender<crate::extensions::runtime::InvokeCommandEvent>,
1737 ) -> Result<Value, String> {
1738 let mut rx = self.subscribe_notifications().await;
1740 let params = serde_json::json!({
1741 "command": command,
1742 "args": args,
1743 "request_id": request_id,
1744 });
1745
1746 let extension_id = self.id.clone();
1747 let request_id_owned = request_id.to_string();
1748 let invoke_future = async {
1749 let mut call_fut = Box::pin(self.call("command.invoke", params));
1750 let mut sink_open = true;
1751 let response = loop {
1752 tokio::select! {
1753 response = &mut call_fut => break response,
1754 Some(frame) = rx.recv() => {
1755 let _ = Self::forward_invoke_command_frame(
1756 &extension_id, &request_id_owned, &sink, &mut sink_open, frame,
1757 );
1758 }
1759 }
1760 };
1761 self.unsubscribe_notifications().await;
1765 while let Ok(frame) = rx.try_recv() {
1766 let _ = Self::forward_invoke_command_frame(
1767 &extension_id, &request_id_owned, &sink, &mut sink_open, frame,
1768 );
1769 }
1770 response
1771 };
1772
1773 let outcome = tokio::time::timeout(
1774 std::time::Duration::from_secs(120),
1775 invoke_future,
1776 )
1777 .await;
1778
1779 self.unsubscribe_notifications().await;
1781
1782 outcome
1783 .map_err(|_| format!("Extension '{}' command.invoke timed out", self.id))?
1784 }
1785
1786 async fn handle(&self, event: &HookEvent) -> HookResult {
1787 let params = serde_json::to_value(event).unwrap_or(Value::Null);
1788 match tokio::time::timeout(std::time::Duration::from_secs(5), self.call("hook.handle", params)).await {
1789 Ok(Ok(value)) => match serde_json::from_value(value.clone()) {
1790 Ok(result) => result,
1791 Err(error) => {
1792 tracing::warn!(
1793 extension = %self.id,
1794 error = %error,
1795 response = %value,
1796 "Extension hook handler returned invalid result",
1797 );
1798 if value.get("action").and_then(Value::as_str) == Some("modify") {
1799 HookResult::Block {
1800 reason: "Extension returned malformed modify result".to_string(),
1801 }
1802 } else {
1803 HookResult::Continue
1804 }
1805 }
1806 },
1807 Ok(Err(e)) => {
1808 tracing::warn!(
1809 extension = %self.id,
1810 error = %e,
1811 "Extension hook handler failed — continuing",
1812 );
1813 HookResult::Continue
1814 }
1815 Err(_) => {
1816 tracing::warn!(
1817 extension = %self.id,
1818 timeout_secs = 5,
1819 "Extension hook handler timed out — continuing",
1820 );
1821 HookResult::Continue
1822 }
1823 }
1824 }
1825
1826 async fn get_info(&self) -> Result<crate::extensions::info::PluginInfo, String> {
1827 let value = tokio::time::timeout(
1828 std::time::Duration::from_secs(5),
1829 self.call("info.get", Value::Null),
1830 )
1831 .await
1832 .map_err(|_| format!("Extension '{}' info.get timed out", self.id))??;
1833 serde_json::from_value(value)
1834 .map_err(|e| format!("Invalid info.get response from extension '{}': {}", self.id, e))
1835 }
1836
1837 async fn sidecar_spawn_args(
1838 &self,
1839 ) -> Result<crate::sidecar::spawn::SidecarSpawnArgs, String> {
1840 let value = tokio::time::timeout(
1841 std::time::Duration::from_secs(5),
1842 self.call("sidecar.spawn_args", Value::Null),
1843 )
1844 .await
1845 .map_err(|_| format!("Extension '{}' sidecar.spawn_args timed out", self.id))??;
1846 serde_json::from_value(value).map_err(|e| {
1847 format!(
1848 "Invalid sidecar.spawn_args response from extension '{}': {}",
1849 self.id, e
1850 )
1851 })
1852 }
1853
1854 async fn settings_editor_open(&self, category: &str, field: &str) -> Result<Value, String> {
1855 let params = crate::extensions::settings_editor::SettingsEditorOpenParams {
1856 category: category.to_string(),
1857 field: field.to_string(),
1858 };
1859 tokio::time::timeout(
1860 std::time::Duration::from_secs(5),
1861 self.call("settings.editor.open", serde_json::to_value(params).map_err(|e| e.to_string())?),
1862 )
1863 .await
1864 .map_err(|_| format!("Extension '{}' settings.editor.open timed out", self.id))?
1865 }
1866
1867 async fn settings_editor_key(&self, category: &str, field: &str, key: &str) -> Result<Value, String> {
1868 let mut params = serde_json::to_value(crate::extensions::settings_editor::SettingsEditorKeyParams {
1869 key: key.to_string(),
1870 }).map_err(|e| e.to_string())?;
1871 if let Some(obj) = params.as_object_mut() {
1872 obj.insert("category".to_string(), Value::String(category.to_string()));
1873 obj.insert("field".to_string(), Value::String(field.to_string()));
1874 }
1875 tokio::time::timeout(
1876 std::time::Duration::from_secs(5),
1877 self.call("settings.editor.key", params),
1878 )
1879 .await
1880 .map_err(|_| format!("Extension '{}' settings.editor.key timed out", self.id))?
1881 }
1882
1883 async fn settings_editor_commit(&self, category: &str, field: &str, value: Value) -> Result<Value, String> {
1884 let params = serde_json::json!({
1885 "category": category,
1886 "field": field,
1887 "value": value,
1888 });
1889 tokio::time::timeout(
1890 std::time::Duration::from_secs(5),
1891 self.call("settings.editor.commit", params),
1892 )
1893 .await
1894 .map_err(|_| format!("Extension '{}' settings.editor.commit timed out", self.id))?
1895 }
1896
1897 async fn shutdown(&self) {
1898 let _ = tokio::time::timeout(
1899 std::time::Duration::from_millis(500),
1900 self.call("shutdown", Value::Null),
1901 )
1902 .await;
1903
1904 tokio::time::sleep(std::time::Duration::from_millis(500)).await;
1905 let mut state_guard = self.state.lock().await;
1906 if let Some(state) = state_guard.take() {
1907 state.reader_handle.abort();
1908 let mut child = state.child;
1909 let _ = child.kill().await;
1910 }
1911 self.inbox.notification_sink.lock().await.take();
1913 self.inbox
1914 .fail_all_pending("transport closed: extension shutdown")
1915 .await;
1916 }
1917
1918 async fn restart_count(&self) -> usize {
1919 self.restart_count()
1920 }
1921
1922 async fn health(&self) -> ExtensionHealth {
1923 let count = self.restart_count.load(Ordering::Relaxed);
1924 let max = self.restart_policy.max_attempts as usize;
1925 if count >= max {
1926 ExtensionHealth::Failed
1927 } else if count > 0 {
1928 let state_alive = self.state.try_lock().map(|g| g.is_some()).unwrap_or(true);
1932 if state_alive {
1933 ExtensionHealth::Degraded
1934 } else {
1935 ExtensionHealth::Restarting
1936 }
1937 } else {
1938 ExtensionHealth::Running
1939 }
1940 }
1941}
1942
// Unit tests for `parse_provider_stream_event`.
//
// Covers each accepted `type` (`text`, `thinking`, `tool_use`, `usage`,
// `error`, `done`), the `delta`/`text` key aliases for deltas, the nested
// `{"event": {...}}` wrapper, and rejection of missing, empty, or
// wrongly-typed fields.
#[cfg(test)]
mod stream_event_tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn parses_text_delta_with_delta_key() {
        let v = json!({"type": "text", "delta": "hi"});
        assert_eq!(
            parse_provider_stream_event(&v).unwrap(),
            ProviderStreamEvent::TextDelta { text: "hi".into() }
        );
    }

    #[test]
    fn parses_text_delta_with_text_key() {
        let v = json!({"type": "text", "text": "hi"});
        assert_eq!(
            parse_provider_stream_event(&v).unwrap(),
            ProviderStreamEvent::TextDelta { text: "hi".into() }
        );
    }

    // Thinking deltas accept both the `delta` and the `text` key.
    #[test]
    fn parses_thinking_delta() {
        let v = json!({"type": "thinking", "delta": "hmm"});
        assert_eq!(
            parse_provider_stream_event(&v).unwrap(),
            ProviderStreamEvent::ThinkingDelta { text: "hmm".into() }
        );
        let v2 = json!({"type": "thinking", "text": "hmm"});
        assert_eq!(
            parse_provider_stream_event(&v2).unwrap(),
            ProviderStreamEvent::ThinkingDelta { text: "hmm".into() }
        );
    }

    #[test]
    fn parses_tool_use() {
        let v = json!({
            "type": "tool_use",
            "id": "t1",
            "name": "echo",
            "input": {"x": 1}
        });
        assert_eq!(
            parse_provider_stream_event(&v).unwrap(),
            ProviderStreamEvent::ToolUse {
                id: "t1".into(),
                name: "echo".into(),
                input: json!({"x": 1}),
            }
        );
    }

    // An omitted `input` is normalized to an empty object.
    #[test]
    fn tool_use_input_defaults_to_empty_object() {
        let v = json!({"type": "tool_use", "id": "t1", "name": "echo"});
        assert_eq!(
            parse_provider_stream_event(&v).unwrap(),
            ProviderStreamEvent::ToolUse {
                id: "t1".into(),
                name: "echo".into(),
                input: json!({}),
            }
        );
    }

    // The usage payload is the event object minus its `type` discriminator.
    #[test]
    fn parses_usage_strips_type() {
        let v = json!({"type": "usage", "input_tokens": 5, "output_tokens": 7});
        assert_eq!(
            parse_provider_stream_event(&v).unwrap(),
            ProviderStreamEvent::Usage {
                usage: json!({"input_tokens": 5, "output_tokens": 7})
            }
        );
    }

    #[test]
    fn parses_error() {
        let v = json!({"type": "error", "message": "boom"});
        assert_eq!(
            parse_provider_stream_event(&v).unwrap(),
            ProviderStreamEvent::Error { message: "boom".into() }
        );
    }

    #[test]
    fn parses_done() {
        let v = json!({"type": "done"});
        assert_eq!(
            parse_provider_stream_event(&v).unwrap(),
            ProviderStreamEvent::Done
        );
    }

    // `{"event": {...}}` parses to the same event as the flat object.
    #[test]
    fn nested_event_shape_matches_flat() {
        let flat = json!({"type": "text", "delta": "hi"});
        let nested = json!({"event": {"type": "text", "delta": "hi"}});
        assert_eq!(
            parse_provider_stream_event(&flat).unwrap(),
            parse_provider_stream_event(&nested).unwrap()
        );
    }

    // Error cases below: the message should name the offending type or field.
    #[test]
    fn missing_type_errors() {
        let v = json!({"delta": "hi"});
        let err = parse_provider_stream_event(&v).unwrap_err();
        assert!(err.contains("missing type"), "got: {err}");
    }

    #[test]
    fn unknown_type_errors_with_type() {
        let v = json!({"type": "wat"});
        let err = parse_provider_stream_event(&v).unwrap_err();
        assert!(err.contains("wat"), "got: {err}");
    }

    #[test]
    fn tool_use_missing_id_errors() {
        let v = json!({"type": "tool_use", "name": "echo"});
        let err = parse_provider_stream_event(&v).unwrap_err();
        assert!(err.contains("id"), "got: {err}");
    }

    #[test]
    fn tool_use_missing_name_errors() {
        let v = json!({"type": "tool_use", "id": "t1"});
        let err = parse_provider_stream_event(&v).unwrap_err();
        assert!(err.contains("name"), "got: {err}");
    }

    #[test]
    fn tool_use_empty_id_errors() {
        let v = json!({"type": "tool_use", "id": "", "name": "echo"});
        assert!(parse_provider_stream_event(&v).is_err());
    }

    #[test]
    fn tool_use_empty_name_errors() {
        let v = json!({"type": "tool_use", "id": "t1", "name": ""});
        assert!(parse_provider_stream_event(&v).is_err());
    }

    #[test]
    fn tool_use_non_object_input_errors() {
        let v = json!({"type": "tool_use", "id": "t1", "name": "echo", "input": "nope"});
        let err = parse_provider_stream_event(&v).unwrap_err();
        assert!(err.contains("input"), "got: {err}");
    }

    #[test]
    fn text_missing_delta_and_text_errors() {
        let v = json!({"type": "text"});
        let err = parse_provider_stream_event(&v).unwrap_err();
        assert!(err.contains("delta") || err.contains("text"), "got: {err}");
    }

    #[test]
    fn error_missing_message_errors() {
        let v = json!({"type": "error"});
        assert!(parse_provider_stream_event(&v).is_err());
    }

    #[test]
    fn error_empty_message_errors() {
        let v = json!({"type": "error", "message": ""});
        assert!(parse_provider_stream_event(&v).is_err());
    }
}
2116
// Tests for how `ProcessExtension` picks up its `RestartPolicy`.
//
// These spawn `/bin/cat` as a stand-in extension process, so they assume a
// Unix-like host with that binary available.
#[cfg(test)]
mod restart_policy_tests {
    use super::*;

    // Without an explicit policy, up to 3 restart attempts are allowed.
    #[tokio::test]
    async fn restart_policy_default_max_attempts_is_3() {
        let ext = ProcessExtension::spawn("policy-test", "/bin/cat", &[])
            .await
            .expect("spawn /bin/cat");
        assert_eq!(ext.restart_policy.max_attempts, 3);
        ext.shutdown().await;
    }

    // `with_restart_policy` replaces the default policy.
    #[tokio::test]
    async fn with_restart_policy_overrides_default() {
        let ext = ProcessExtension::spawn("policy-test-override", "/bin/cat", &[])
            .await
            .expect("spawn /bin/cat");
        let custom = RestartPolicy {
            max_attempts: 7,
            ..RestartPolicy::default()
        };
        let ext = ext.with_restart_policy(custom);
        assert_eq!(ext.restart_policy.max_attempts, 7);
        ext.shutdown().await;
    }
}
2148
// Unit tests for `validate_capability`.
//
// Covers rejection of blank `kind`/`name`, unknown permission strings, and
// declared-but-ungranted permissions, plus acceptance when everything is
// granted, when nothing is declared, and independently of the `kind` value.
#[cfg(test)]
mod capture_validator_tests {
    use super::*;
    use crate::extensions::permissions::{Permission, PermissionSet};

    // Build a `PermissionSet` with exactly `grants` granted.
    fn perms_with(grants: &[Permission]) -> PermissionSet {
        let mut p = PermissionSet::new();
        for g in grants {
            p.grant(*g);
        }
        p
    }

    // Build a `CapabilityDeclaration` with the given fields and null params.
    fn cap(kind: &str, name: &str, perms: &[&str]) -> CapabilityDeclaration {
        CapabilityDeclaration {
            kind: kind.to_string(),
            name: name.to_string(),
            permissions: perms.iter().map(|p| p.to_string()).collect(),
            params: serde_json::Value::Null,
        }
    }

    // Whitespace-only counts as empty.
    #[test]
    fn capability_validator_rejects_empty_kind() {
        let d = cap(" ", "Sample", &["audio.input"]);
        let perms = perms_with(&[Permission::AudioInput]);
        let err = validate_capability(&d, &perms).unwrap_err();
        assert!(err.contains("kind"), "got: {}", err);
    }

    #[test]
    fn capability_validator_rejects_empty_name() {
        let d = cap("capture", " ", &["audio.input"]);
        let perms = perms_with(&[Permission::AudioInput]);
        let err = validate_capability(&d, &perms).unwrap_err();
        assert!(err.contains("name"), "got: {}", err);
    }

    #[test]
    fn capability_validator_rejects_unknown_permission_string() {
        let d = cap("capture", "Sample", &["audio.telepathy"]);
        let perms = perms_with(&[Permission::AudioInput, Permission::AudioOutput]);
        let err = validate_capability(&d, &perms).unwrap_err();
        assert!(
            err.contains("unknown permission") && err.contains("audio.telepathy"),
            "got: {}",
            err,
        );
    }

    #[test]
    fn capability_validator_requires_every_declared_permission() {
        let d = cap("capture", "Sample", &["audio.input"]);
        let perms = perms_with(&[]);
        let err = validate_capability(&d, &perms).unwrap_err();
        assert!(
            err.contains("audio.input") && err.contains("not granted"),
            "got: {}",
            err,
        );
    }

    #[test]
    fn capability_validator_accepts_when_all_permissions_granted() {
        let d = cap("capture", "Sample", &["audio.input", "audio.output"]);
        let perms = perms_with(&[Permission::AudioInput, Permission::AudioOutput]);
        validate_capability(&d, &perms).expect("should validate");
    }

    // Declaring no permissions is valid even with nothing granted.
    #[test]
    fn capability_validator_accepts_no_permissions() {
        let d = cap("ocr", "Tesseract", &[]);
        let perms = perms_with(&[]);
        validate_capability(&d, &perms).expect("should validate");
    }

    // The same declaration validates identically for arbitrary `kind` values.
    #[test]
    fn capability_validator_does_not_branch_on_kind() {
        let perms = perms_with(&[Permission::AudioInput]);
        for kind in ["capture", "ocr", "agent", "foot_pedal", "eeg"] {
            let d = cap(kind, "Anything", &["audio.input"]);
            validate_capability(&d, &perms).expect("should validate");
        }
    }

}
2241
// Unit tests for `ProcessExtension::forward_invoke_command_frame`.
//
// Each test feeds frames directly into the dispatcher and inspects what lands
// on the sink. Covered: ordering of mixed command/task events, filtering of
// `command.output` by request id, skipping of malformed frames, task events
// bypassing the request-id filter, and dropping of unrelated methods.
#[cfg(test)]
mod invoke_command_dispatch_tests {
    use super::*;
    use crate::extensions::commands::CommandOutputEvent;
    use crate::extensions::runtime::InvokeCommandEvent;
    use crate::extensions::tasks::{TaskEvent, TaskKind};
    use serde_json::json;
    use tokio::sync::mpsc;

    // Build a notification frame for the given method and params.
    fn frame(method: &str, params: serde_json::Value) -> NotificationFrame {
        NotificationFrame {
            method: method.to_string(),
            params,
        }
    }

    // Interleaved command and task frames come out in input order, and the
    // final `done` output is reported through the return value.
    #[test]
    fn forwards_mixed_event_stream_in_order() {
        let (tx, mut rx) = mpsc::unbounded_channel::<InvokeCommandEvent>();
        let mut open = true;
        let frames = vec![
            frame(
                "command.output",
                json!({"request_id":"r1","event":{"kind":"text","content":"A"}}),
            ),
            frame(
                "task.start",
                json!({"id":"dl","label":"Downloading","kind":"download"}),
            ),
            frame(
                "task.update",
                json!({"id":"dl","current":50,"total":100}),
            ),
            frame(
                "command.output",
                json!({"request_id":"r1","event":{"kind":"system","content":"working"}}),
            ),
            frame("task.done", json!({"id":"dl"})),
            frame(
                "command.output",
                json!({"request_id":"r1","event":{"kind":"done"}}),
            ),
        ];

        let mut saw_done = false;
        for f in frames {
            saw_done |= ProcessExtension::forward_invoke_command_frame(
                "ext-test", "r1", &tx, &mut open, f,
            );
        }
        drop(tx);
        assert!(saw_done, "should have observed the command Done marker");

        let mut events = Vec::new();
        while let Ok(ev) = rx.try_recv() {
            events.push(ev);
        }
        assert_eq!(events.len(), 6);
        assert_eq!(
            events[0],
            InvokeCommandEvent::Output(CommandOutputEvent::Text { content: "A".into() })
        );
        assert!(matches!(
            events[1],
            InvokeCommandEvent::Task(TaskEvent::Start { kind: TaskKind::Download, .. })
        ));
        assert!(matches!(
            events[2],
            InvokeCommandEvent::Task(TaskEvent::Update { .. })
        ));
        assert!(matches!(
            events[3],
            InvokeCommandEvent::Output(CommandOutputEvent::System { .. })
        ));
        assert!(matches!(
            events[4],
            InvokeCommandEvent::Task(TaskEvent::Done { error: None, .. })
        ));
        assert_eq!(events[5], InvokeCommandEvent::Output(CommandOutputEvent::Done));
    }

    // `command.output` for a different request id is not forwarded.
    #[test]
    fn ignores_command_output_for_unrelated_request_id() {
        let (tx, mut rx) = mpsc::unbounded_channel::<InvokeCommandEvent>();
        let mut open = true;
        ProcessExtension::forward_invoke_command_frame(
            "ext",
            "r1",
            &tx,
            &mut open,
            frame(
                "command.output",
                json!({"request_id":"other","event":{"kind":"text","content":"x"}}),
            ),
        );
        drop(tx);
        assert!(rx.try_recv().is_err());
    }

    // A malformed frame is skipped; the next valid frame is still delivered.
    #[test]
    fn skips_malformed_command_output_without_aborting() {
        let (tx, mut rx) = mpsc::unbounded_channel::<InvokeCommandEvent>();
        let mut open = true;
        ProcessExtension::forward_invoke_command_frame(
            "ext",
            "r1",
            &tx,
            &mut open,
            frame("command.output", json!({"request_id":"r1","event":{}})),
        );
        ProcessExtension::forward_invoke_command_frame(
            "ext",
            "r1",
            &tx,
            &mut open,
            frame(
                "command.output",
                json!({"request_id":"r1","event":{"kind":"done"}}),
            ),
        );
        drop(tx);
        let ev = rx.try_recv().unwrap();
        assert_eq!(ev, InvokeCommandEvent::Output(CommandOutputEvent::Done));
        assert!(rx.try_recv().is_err());
    }

    // Task frames carry no request id and are always forwarded.
    #[test]
    fn task_events_pass_through_regardless_of_request_id() {
        let (tx, mut rx) = mpsc::unbounded_channel::<InvokeCommandEvent>();
        let mut open = true;
        ProcessExtension::forward_invoke_command_frame(
            "ext",
            "r1",
            &tx,
            &mut open,
            frame("task.log", json!({"id":"abc","line":"..."})),
        );
        drop(tx);
        match rx.try_recv().unwrap() {
            InvokeCommandEvent::Task(TaskEvent::Log { id, line }) => {
                assert_eq!(id, "abc");
                assert_eq!(line, "...");
            }
            other => panic!("unexpected: {other:?}"),
        }
    }

    // Methods outside command.output/task.* (e.g. provider stream events) are dropped.
    #[test]
    fn unrelated_methods_are_dropped() {
        let (tx, mut rx) = mpsc::unbounded_channel::<InvokeCommandEvent>();
        let mut open = true;
        ProcessExtension::forward_invoke_command_frame(
            "ext",
            "r1",
            &tx,
            &mut open,
            frame("provider.stream.event", json!({"type":"text","delta":"x"})),
        );
        drop(tx);
        assert!(rx.try_recv().is_err());
    }
}