1use async_trait::async_trait;
12use chrono::{DateTime, Utc};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::sync::Arc;
16use tokio::sync::{Mutex, RwLock};
17use tracing::{debug, error, info};
18
19use crate::errors::{ComputerError, ComputerResult};
20use crate::inputs::handler::InputHandler;
21use crate::inputs::model::InputValue;
22use crate::inputs::utils::run_command;
23use crate::mcp_clients::{
24 manager::MCPServerManager,
25 model::{
26 content_as_text, is_call_tool_error, CallToolResult, Content, MCPServerConfig,
27 MCPServerInput, Tool,
28 },
29 ConfigRender, RenderError,
30};
31use crate::socketio_client::SmcpComputerClient;
32
/// Callback asked to confirm a tool call before it runs.
///
/// Arguments are `(req_id, server_name, tool_name, parameters)`; returning `true`
/// lets the call proceed, returning `false` rejects it.
type ConfirmCallbackType = Arc<dyn Fn(&str, &str, &str, &serde_json::Value) -> bool + Send + Sync>;
35
36fn input_value_to_json(value: InputValue) -> serde_json::Value {
38 match value {
39 InputValue::String(s) => serde_json::Value::String(s),
40 InputValue::Number(n) => serde_json::Value::Number(serde_json::Number::from(n)),
41 InputValue::Float(f) => serde_json::Value::Number(
42 serde_json::Number::from_f64(f).unwrap_or(serde_json::Number::from(0)),
43 ),
44 InputValue::Bool(b) => serde_json::Value::Bool(b),
45 }
46}
47
48fn json_to_input_value(value: serde_json::Value) -> ComputerResult<InputValue> {
50 match value {
51 serde_json::Value::String(s) => Ok(InputValue::String(s)),
52 serde_json::Value::Number(n) => {
53 if let Some(i) = n.as_i64() {
54 Ok(InputValue::Number(i))
55 } else if let Some(u) = n.as_u64() {
56 Ok(InputValue::Number(u as i64))
57 } else if let Some(f) = n.as_f64() {
58 Ok(InputValue::Float(f))
59 } else {
60 Err(ComputerError::ValidationError(
61 "Invalid number value".to_string(),
62 ))
63 }
64 }
65 serde_json::Value::Bool(b) => Ok(InputValue::Bool(b)),
66 serde_json::Value::Null => Err(ComputerError::ValidationError(
67 "Null value not supported".to_string(),
68 )),
69 _ => Err(ComputerError::ValidationError(
70 "Unsupported value type".to_string(),
71 )),
72 }
73}
74
/// One entry of the tool-call history recorded by `Computer::execute_tool`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRecord {
    /// Time (UTC) captured after tool validation, just before confirmation/invocation.
    pub timestamp: DateTime<Utc>,
    /// Request id supplied by the caller of `execute_tool`.
    pub req_id: String,
    /// Name of the MCP server the tool was resolved to.
    pub server: String,
    /// Tool name as returned by tool-call validation.
    pub tool: String,
    /// Arguments the tool was (or would have been) called with.
    pub parameters: serde_json::Value,
    /// Requested timeout in seconds, if any.
    pub timeout: Option<f64>,
    /// `true` only when the tool actually ran and did not return an error result.
    pub success: bool,
    /// Error description taken from an error result; may be `None` even when
    /// `success` is `false` (e.g. when the confirmation callback rejected the call).
    pub error: Option<String>,
}
95
/// Source of concrete values for MCP server inputs (prompt, pick-list and command inputs).
#[async_trait]
pub trait Session: Send + Sync {
    /// Produces the value for `input`.
    ///
    /// When this fails while rendering a server config, `Computer` falls back to the
    /// input's default (if it has one).
    async fn resolve_input(&self, input: &MCPServerInput) -> ComputerResult<serde_json::Value>;

    /// Identifier of this session.
    fn session_id(&self) -> &str;
}
106
/// A non-interactive [`Session`]: inputs resolve to their defaults, or for command
/// inputs to the command's output, without ever prompting anyone.
pub struct SilentSession {
    /// Identifier reported by `session_id`.
    id: String,
}

impl SilentSession {
    /// Creates a silent session identified by `id`.
    pub fn new(id: impl Into<String>) -> Self {
        let id: String = id.into();
        SilentSession { id }
    }
}
118
119#[async_trait]
120impl Session for SilentSession {
121 async fn resolve_input(&self, input: &MCPServerInput) -> ComputerResult<serde_json::Value> {
122 match input {
124 MCPServerInput::PromptString(input) => Ok(serde_json::Value::String(
125 input.default.clone().unwrap_or_default(),
126 )),
127 MCPServerInput::PickString(input) => Ok(serde_json::Value::String(
128 input
129 .default
130 .clone()
131 .unwrap_or_else(|| input.options.first().cloned().unwrap_or_default()),
132 )),
133 MCPServerInput::Command(input) => {
134 let args: Vec<String> = input
136 .args
137 .as_ref()
138 .map(|m| {
139 let mut sorted_pairs: Vec<_> = m.iter().collect();
140 sorted_pairs.sort_by_key(|(k, _)| *k);
141 sorted_pairs.into_iter().map(|(_, v)| v.clone()).collect()
142 })
143 .unwrap_or_default();
144 match run_command(&input.command, &args).await {
145 Ok(output) => Ok(serde_json::Value::String(output)),
146 Err(e) => Err(ComputerError::RuntimeError(format!(
147 "Failed to execute command '{}': {}",
148 input.command, e
149 ))),
150 }
151 }
152 }
153 }
154
155 fn session_id(&self) -> &str {
156 &self.id
157 }
158}
159
/// Agent-side "computer" that owns MCP server configs, input definitions, cached input
/// values, a bounded tool-call history and an optional Socket.IO connection.
pub struct Computer<S: Session> {
    /// Display name; also passed to the Socket.IO client as the computer name.
    name: String,
    /// MCP server manager; `None` until `boot_up` or `add_or_update_server` creates one.
    mcp_manager: Arc<RwLock<Option<MCPServerManager>>>,
    /// Input definitions keyed by input id; the `Arc` is handed to the Socket.IO client.
    inputs: Arc<RwLock<HashMap<String, MCPServerInput>>>,
    /// Unrendered server configs keyed by server name (rendered on use).
    mcp_servers: RwLock<HashMap<String, MCPServerConfig>>,
    /// Holds the cache of input values (see the `*_input_value*` methods).
    input_handler: Arc<RwLock<InputHandler>>,
    /// Stored from the constructor; not consulted by any method visible here —
    /// presumably controls automatic client connection (TODO confirm).
    auto_connect: bool,
    /// Stored from the constructor; not consulted by any method visible here —
    /// presumably controls automatic reconnection (TODO confirm).
    auto_reconnect: bool,
    /// Most recent tool calls, trimmed by `execute_tool` to a small fixed size.
    tool_history: Arc<Mutex<Vec<ToolCallRecord>>>,
    /// Resolves input values when server configs are rendered.
    session: S,
    /// Connected Socket.IO client, if any.
    socketio_client: Arc<RwLock<Option<Arc<SmcpComputerClient>>>>,
    /// Optional callback asked to confirm each tool call before it runs.
    confirm_callback: Option<ConfirmCallbackType>,
}
189
190impl<S: Session> Computer<S> {
191 pub fn new(
193 name: impl Into<String>,
194 session: S,
195 inputs: Option<HashMap<String, MCPServerInput>>,
196 mcp_servers: Option<HashMap<String, MCPServerConfig>>,
197 auto_connect: bool,
198 auto_reconnect: bool,
199 ) -> Self {
200 let name = name.into();
201 let inputs = inputs.unwrap_or_default();
202 let mcp_servers = mcp_servers.unwrap_or_default();
203
204 Self {
205 name,
206 mcp_manager: Arc::new(RwLock::new(None)),
207 inputs: Arc::new(RwLock::new(inputs)),
208 mcp_servers: RwLock::new(mcp_servers),
209 input_handler: Arc::new(RwLock::new(InputHandler::new())),
210 auto_connect,
211 auto_reconnect,
212 tool_history: Arc::new(Mutex::new(Vec::new())),
213 session,
214 socketio_client: Arc::new(RwLock::new(None)),
215 confirm_callback: None,
216 }
217 }
218
219 pub fn with_confirm_callback<F>(mut self, callback: F) -> Self
221 where
222 F: Fn(&str, &str, &str, &serde_json::Value) -> bool + Send + Sync + 'static,
223 {
224 self.confirm_callback = Some(Arc::new(callback));
225 self
226 }
227
228 pub fn name(&self) -> &str {
230 &self.name
231 }
232
233 pub fn get_socketio_client(&self) -> Arc<RwLock<Option<Arc<SmcpComputerClient>>>> {
237 self.socketio_client.clone()
238 }
239
240 pub async fn boot_up(&self) -> ComputerResult<()> {
242 info!("Starting Computer: {}", self.name);
243
244 let manager = MCPServerManager::new();
246
247 let servers = self.mcp_servers.read().await;
249 let mut validated_servers = Vec::new();
250
251 for (_name, server_config) in servers.iter() {
252 match self.render_server_config(server_config).await {
253 Ok(validated) => validated_servers.push(validated),
254 Err(e) => {
255 error!(
256 "Failed to render server config {}: {}",
257 server_config.name(),
258 e
259 );
260 validated_servers.push(server_config.clone());
262 }
263 }
264 }
265
266 manager.initialize(validated_servers).await?;
268
269 *self.mcp_manager.write().await = Some(manager);
271
272 info!("Computer {} started successfully", self.name);
273 Ok(())
274 }
275
276 async fn render_server_config(
280 &self,
281 config: &MCPServerConfig,
282 ) -> ComputerResult<MCPServerConfig> {
283 let config_json = serde_json::to_value(config)?;
285
286 let renderer = ConfigRender::default();
288
289 let inputs = self.inputs.read().await;
291 let inputs_clone: std::collections::HashMap<String, MCPServerInput> = inputs.clone();
292 drop(inputs); let mut resolved_values: std::collections::HashMap<String, serde_json::Value> =
298 std::collections::HashMap::new();
299 for (input_id, input) in inputs_clone.iter() {
300 match self.session.resolve_input(input).await {
301 Ok(value) => {
302 resolved_values.insert(input_id.clone(), value);
303 }
304 Err(e) => {
305 debug!(
306 "Failed to resolve input '{}': {}, will use default",
307 input_id, e
308 );
309 if let Some(default) = input.default() {
311 resolved_values.insert(input_id.clone(), default);
312 }
313 }
314 }
315 }
316
317 let resolver = |input_id: String| {
319 let values = resolved_values.clone();
320 async move {
321 if let Some(value) = values.get(&input_id) {
322 Ok(value.clone())
323 } else {
324 Err(RenderError::InputNotFound(input_id))
325 }
326 }
327 };
328
329 let rendered_json = renderer.render(config_json, resolver).await?;
331
332 let rendered_config: MCPServerConfig = serde_json::from_value(rendered_json)?;
334
335 Ok(rendered_config)
336 }
337
338 pub async fn add_or_update_server(&self, server: MCPServerConfig) -> ComputerResult<()> {
340 {
342 let mut manager_guard = self.mcp_manager.write().await;
343 if manager_guard.is_none() {
344 *manager_guard = Some(MCPServerManager::new());
345 }
346 }
347
348 let validated = self.render_server_config(&server).await?;
350
351 let manager = self.mcp_manager.read().await;
353 if let Some(ref manager) = *manager {
354 manager.add_or_update_server(validated).await?;
355 }
356
357 {
359 let mut servers = self.mcp_servers.write().await;
360 servers.insert(server.name().to_string(), server);
361 }
362
363 let _ = self.emit_update_config().await;
365
366 Ok(())
367 }
368
369 pub async fn remove_server(&self, server_name: &str) -> ComputerResult<()> {
371 let manager = self.mcp_manager.read().await;
372 if let Some(ref manager) = *manager {
373 manager.remove_server(server_name).await?;
374 }
375
376 {
378 let mut servers = self.mcp_servers.write().await;
379 servers.remove(server_name);
380 }
381
382 let _ = self.emit_update_config().await;
384
385 Ok(())
386 }
387
388 pub async fn update_inputs(
390 &self,
391 inputs: HashMap<String, MCPServerInput>,
392 ) -> ComputerResult<()> {
393 *self.inputs.write().await = inputs;
394
395 {
397 let mut input_handler = self.input_handler.write().await;
398 *input_handler = InputHandler::new();
399 }
400
401 let _ = self.emit_update_config().await;
403
404 Ok(())
405 }
406
407 pub async fn add_or_update_input(&self, input: MCPServerInput) -> ComputerResult<()> {
409 let input_id = input.id().to_string();
410 {
411 let mut inputs = self.inputs.write().await;
412 inputs.insert(input_id.clone(), input);
413 }
414
415 self.clear_input_values(Some(&input_id)).await?;
417
418 let _ = self.emit_update_config().await;
420
421 Ok(())
422 }
423
424 pub async fn remove_input(&self, input_id: &str) -> ComputerResult<bool> {
426 let removed = {
427 let mut inputs = self.inputs.write().await;
428 inputs.remove(input_id).is_some()
429 };
430
431 if removed {
432 self.clear_input_values(Some(input_id)).await?;
434
435 let _ = self.emit_update_config().await;
437 }
438
439 Ok(removed)
440 }
441
442 pub async fn get_input(&self, input_id: &str) -> ComputerResult<Option<MCPServerInput>> {
444 let inputs = self.inputs.read().await;
445 Ok(inputs.get(input_id).cloned())
446 }
447
448 pub async fn list_inputs(&self) -> ComputerResult<Vec<MCPServerInput>> {
450 let inputs = self.inputs.read().await;
451 Ok(inputs.values().cloned().collect())
452 }
453
454 pub async fn get_input_value(
456 &self,
457 input_id: &str,
458 ) -> ComputerResult<Option<serde_json::Value>> {
459 let handler = self.input_handler.read().await;
461 let cached_values = handler.get_all_cached_values().await;
462
463 for (key, value) in cached_values {
465 if key.starts_with(input_id) {
468 let parts: Vec<&str> = key.split(':').collect();
470 if !parts.is_empty() && parts[0] == input_id {
471 return Ok(Some(input_value_to_json(value)));
472 }
473 }
474 }
475
476 Ok(None)
477 }
478
479 pub async fn set_input_value(
481 &self,
482 input_id: &str,
483 value: serde_json::Value,
484 ) -> ComputerResult<bool> {
485 {
487 let inputs = self.inputs.read().await;
488 if !inputs.contains_key(input_id) {
489 return Ok(false);
490 }
491 }
492
493 let handler = self.input_handler.read().await;
495 let input_value = json_to_input_value(value)?;
496 handler
497 .set_cached_value(input_id.to_string(), input_value)
498 .await;
499
500 Ok(true)
501 }
502
503 pub async fn remove_input_value(&self, input_id: &str) -> ComputerResult<bool> {
505 let handler = self.input_handler.read().await;
506 let removed = handler.remove_cached_value(input_id).await.is_some();
507 Ok(removed)
508 }
509
510 pub async fn list_input_values(&self) -> ComputerResult<HashMap<String, serde_json::Value>> {
512 let handler = self.input_handler.read().await;
513 let cached_values = handler.get_all_cached_values().await;
514
515 let mut result = HashMap::new();
516 for (key, value) in cached_values {
517 let parts: Vec<&str> = key.split(':').collect();
520 if !parts.is_empty() {
521 result.insert(parts[0].to_string(), input_value_to_json(value));
522 }
523 }
524
525 Ok(result)
526 }
527
528 pub async fn clear_input_values(&self, input_id: Option<&str>) -> ComputerResult<()> {
530 let handler = self.input_handler.read().await;
531
532 if let Some(id) = input_id {
533 let cached_values = handler.get_all_cached_values().await;
535 let keys_to_remove: Vec<String> = cached_values
536 .keys()
537 .filter(|key| key.starts_with(id))
538 .cloned()
539 .collect();
540
541 for key in keys_to_remove {
542 handler.remove_cached_value(&key).await;
543 }
544 } else {
545 handler.clear_all_cache().await;
547 }
548
549 Ok(())
550 }
551
552 pub async fn get_available_tools(&self) -> ComputerResult<Vec<Tool>> {
554 let manager = self.mcp_manager.read().await;
555 if let Some(ref manager) = *manager {
556 let tools: Vec<Tool> = manager.list_available_tools().await;
557 Ok(tools)
561 } else {
562 Err(ComputerError::InvalidState(
563 "Computer not initialized".to_string(),
564 ))
565 }
566 }
567
568 pub async fn execute_tool(
570 &self,
571 req_id: &str,
572 tool_name: &str,
573 parameters: serde_json::Value,
574 timeout: Option<f64>,
575 ) -> ComputerResult<CallToolResult> {
576 let manager = self.mcp_manager.read().await;
577 if let Some(ref manager) = *manager {
578 let (server_name, tool_name) =
580 manager.validate_tool_call(tool_name, ¶meters).await?;
581 let server_name = server_name.to_string();
582 let tool_name = tool_name.to_string();
583
584 let timestamp = Utc::now();
585 let mut success = false;
586 let mut error_msg = None;
587 let result: CallToolResult;
588
589 let need_confirm = true; let parameters_for_call = parameters.clone();
595
596 if need_confirm {
597 if let Some(ref callback) = self.confirm_callback {
598 let confirmed = callback(req_id, &server_name, &tool_name, ¶meters);
599 if confirmed {
600 let timeout_duration = timeout.map(std::time::Duration::from_secs_f64);
601 result = manager
602 .call_tool(
603 &server_name,
604 &tool_name,
605 parameters_for_call,
606 timeout_duration,
607 )
608 .await?;
609 success = !is_call_tool_error(&result);
610 } else {
611 result = CallToolResult::success(vec![Content::text(
612 "工具调用二次确认被拒绝,请稍后再试",
613 )]);
614 }
615 } else {
616 result = CallToolResult::error(vec![Content::text(
617 "当前工具需要调用前进行二次确认,但客户端目前没有实现二次确认回调方法",
618 )]);
619 error_msg = Some("No confirmation callback".to_string());
620 }
621 } else {
622 let timeout_duration = timeout.map(std::time::Duration::from_secs_f64);
623 result = manager
624 .call_tool(
625 &server_name,
626 &tool_name,
627 parameters_for_call,
628 timeout_duration,
629 )
630 .await?;
631 success = !is_call_tool_error(&result);
632 }
633
634 if is_call_tool_error(&result) {
635 error_msg = result
636 .content
637 .iter()
638 .find_map(|c| content_as_text(c).map(|t| t.to_string()));
639 }
640
641 let record = ToolCallRecord {
643 timestamp,
644 req_id: req_id.to_string(),
645 server: server_name,
646 tool: tool_name,
647 parameters,
648 timeout,
649 success,
650 error: error_msg,
651 };
652
653 {
654 let mut history = self.tool_history.lock().await;
655 history.push(record);
656 if history.len() > 10 {
658 history.remove(0);
659 }
660 }
661
662 Ok(result)
663 } else {
664 Err(ComputerError::InvalidState(
665 "Computer not initialized".to_string(),
666 ))
667 }
668 }
669
670 pub async fn get_tool_history(&self) -> ComputerResult<Vec<ToolCallRecord>> {
672 let history = self.tool_history.lock().await;
673 Ok(history.clone())
674 }
675
676 pub async fn get_server_status(&self) -> Vec<(String, bool, String)> {
678 let manager_guard = self.mcp_manager.read().await;
679 if let Some(ref manager) = *manager_guard {
680 manager.get_server_status().await
681 } else {
682 Vec::new()
683 }
684 }
685
686 pub async fn list_mcp_servers(&self) -> Vec<MCPServerConfig> {
688 let servers = self.mcp_servers.read().await;
689 servers.values().cloned().collect()
690 }
691
692 pub async fn start_mcp_client(&self, server_name: &str) -> ComputerResult<()> {
694 let manager_guard = self.mcp_manager.read().await;
695 if let Some(ref manager) = *manager_guard {
696 if server_name == "all" {
697 manager.start_all().await
698 } else {
699 manager.start_client(server_name).await
700 }
701 } else {
702 Err(ComputerError::InvalidState(
703 "MCP Manager not initialized".to_string(),
704 ))
705 }
706 }
707
708 pub async fn stop_mcp_client(&self, server_name: &str) -> ComputerResult<()> {
710 let manager_guard = self.mcp_manager.read().await;
711 if let Some(ref manager) = *manager_guard {
712 if server_name == "all" {
713 manager.stop_all().await
714 } else {
715 manager.stop_client(server_name).await
716 }
717 } else {
718 Err(ComputerError::InvalidState(
719 "MCP Manager not initialized".to_string(),
720 ))
721 }
722 }
723
724 pub async fn is_mcp_manager_initialized(&self) -> bool {
726 let manager_guard = self.mcp_manager.read().await;
727 manager_guard.is_some()
728 }
729
730 pub async fn set_socketio_client(&self, client: Arc<SmcpComputerClient>) {
734 let mut socketio_ref = self.socketio_client.write().await;
735 *socketio_ref = Some(client);
738 }
739
740 pub async fn connect_socketio(
742 &self,
743 url: &str,
744 _namespace: &str,
745 auth: &Option<String>,
746 _headers: &Option<String>,
747 ) -> ComputerResult<()> {
748 let _manager_check = {
750 let manager_guard = self.mcp_manager.read().await;
751 match manager_guard.as_ref() {
752 Some(_m) => {
753 true
756 }
757 None => {
758 return Err(ComputerError::InvalidState(
759 "MCP Manager not initialized. Please add and start servers first."
760 .to_string(),
761 ));
762 }
763 }
764 };
765
766 let new_manager = MCPServerManager::new();
771
772 let client = SmcpComputerClient::new(
775 url,
776 Arc::new(RwLock::new(Some(new_manager))),
777 self.name.clone(),
778 auth.clone(),
779 self.inputs.clone(),
780 )
781 .await?;
782
783 let client_arc = Arc::new(client);
785 self.set_socketio_client(client_arc.clone()).await;
786
787 info!(
788 "Connected to SMCP server at {} with computer name: {}",
789 url, self.name
790 );
791
792 Ok(())
793 }
794
795 pub async fn disconnect_socketio(&self) -> ComputerResult<()> {
797 let mut socketio_ref = self.socketio_client.write().await;
798 *socketio_ref = None;
799 info!("Disconnected from server");
800 Ok(())
801 }
802
803 pub async fn join_office(&self, office_id: &str, _computer_name: &str) -> ComputerResult<()> {
805 let socketio_ref = self.socketio_client.read().await;
806 if let Some(ref client) = *socketio_ref {
807 client.join_office(office_id).await?;
810 return Ok(());
811 }
812 Err(ComputerError::InvalidState(
813 "Socket.IO client not connected".to_string(),
814 ))
815 }
816
817 pub async fn leave_office(&self) -> ComputerResult<()> {
819 let socketio_ref = self.socketio_client.read().await;
820 if let Some(ref client) = *socketio_ref {
821 let current_office_id = client.get_current_office_id().await?;
824 client.leave_office(¤t_office_id).await?;
825 return Ok(());
826 }
827 Err(ComputerError::InvalidState(
828 "Socket.IO client not connected".to_string(),
829 ))
830 }
831
832 pub async fn emit_update_config(&self) -> ComputerResult<()> {
834 let socketio_ref = self.socketio_client.read().await;
835 if let Some(ref client) = *socketio_ref {
836 client.emit_update_config().await?;
839 return Ok(());
840 }
841 Err(ComputerError::InvalidState(
842 "Socket.IO client not connected".to_string(),
843 ))
844 }
845
846 pub async fn shutdown(&self) -> ComputerResult<()> {
848 info!("Shutting down Computer: {}", self.name);
849
850 let mut manager_guard = self.mcp_manager.write().await;
851 if let Some(manager) = manager_guard.take() {
852 manager.stop_all().await?;
853 }
854
855 {
857 let mut socketio_ref = self.socketio_client.write().await;
858 *socketio_ref = None;
859 }
860
861 info!("Computer {} shutdown successfully", self.name);
862 Ok(())
863 }
864}
865
866impl<S: Session + Clone> Clone for Computer<S> {
868 fn clone(&self) -> Self {
869 Self {
870 name: self.name.clone(),
871 mcp_manager: Arc::clone(&self.mcp_manager),
872 inputs: Arc::new(RwLock::new(HashMap::new())), mcp_servers: RwLock::new(HashMap::new()),
874 input_handler: Arc::clone(&self.input_handler),
875 auto_connect: self.auto_connect,
876 auto_reconnect: self.auto_reconnect,
877 tool_history: Arc::clone(&self.tool_history),
878 session: self.session.clone(),
879 socketio_client: Arc::clone(&self.socketio_client),
880 confirm_callback: self.confirm_callback.clone(),
881 }
882 }
883}
884
/// Handler for [`ManagerChangeMessage`] notifications (presumably emitted by the MCP
/// server manager — TODO confirm the producer).
#[async_trait]
pub trait ManagerChangeHandler: Send + Sync {
    /// Handles a single change notification.
    async fn on_change(&self, message: ManagerChangeMessage) -> ComputerResult<()>;
}
891
/// Change notifications delivered to a [`ManagerChangeHandler`].
#[derive(Debug, Clone)]
pub enum ManagerChangeMessage {
    /// The set of available tools changed.
    ToolListChanged,
    /// The resource list changed; `windows` presumably lists affected window
    /// resources (TODO confirm — not consumed by any code visible here).
    ResourceListChanged { windows: Vec<String> },
    /// The resource identified by `uri` was updated.
    ResourceUpdated { uri: String },
}
902
903#[async_trait]
904impl<S: Session> ManagerChangeHandler for Computer<S> {
905 async fn on_change(&self, message: ManagerChangeMessage) -> ComputerResult<()> {
906 match message {
907 ManagerChangeMessage::ToolListChanged => {
908 debug!("Tool list changed, notifying Socket.IO client");
909 let socketio_ref = self.socketio_client.read().await;
910 if let Some(ref client) = *socketio_ref {
911 client.emit_update_tool_list().await?;
914 }
915 }
916 ManagerChangeMessage::ResourceListChanged { windows: _ } => {
917 debug!("Resource list changed, checking for window updates");
918 }
920 ManagerChangeMessage::ResourceUpdated { uri } => {
921 debug!("Resource updated: {}", uri);
922 }
924 }
925 Ok(())
926 }
927}
928
929#[cfg(test)]
930mod tests {
931 use super::*;
932 use crate::mcp_clients::model::{
933 CommandInput, MCPServerConfig, MCPServerInput, PickStringInput, PromptStringInput,
934 StdioServerConfig, StdioServerParameters,
935 };
936
937 #[tokio::test]
938 async fn test_computer_creation() {
939 let session = SilentSession::new("test");
940 let computer = Computer::new("test_computer", session, None, None, true, true);
941
942 assert_eq!(computer.name, "test_computer");
943 assert!(computer.auto_connect);
944 assert!(computer.auto_reconnect);
945 }
946
947 #[tokio::test]
948 async fn test_computer_with_initial_inputs_and_servers() {
949 let session = SilentSession::new("test");
950 let mut inputs = HashMap::new();
951 inputs.insert(
952 "input1".to_string(),
953 MCPServerInput::PromptString(PromptStringInput {
954 id: "input1".to_string(),
955 description: "Test input".to_string(),
956 default: Some("default".to_string()),
957 password: Some(false),
958 }),
959 );
960
961 let mut servers = HashMap::new();
962 servers.insert(
963 "server1".to_string(),
964 MCPServerConfig::Stdio(StdioServerConfig {
965 name: "server1".to_string(),
966 disabled: false,
967 forbidden_tools: vec![],
968 tool_meta: std::collections::HashMap::new(),
969 default_tool_meta: None,
970 vrl: None,
971 server_parameters: StdioServerParameters {
972 command: "echo".to_string(),
973 args: vec![],
974 env: std::collections::HashMap::new(),
975 cwd: None,
976 },
977 }),
978 );
979
980 let computer = Computer::new(
981 "test_computer",
982 session,
983 Some(inputs),
984 Some(servers),
985 false,
986 false,
987 );
988
989 let inputs = computer.list_inputs().await.unwrap();
991 assert_eq!(inputs.len(), 1);
992 match &inputs[0] {
993 MCPServerInput::PromptString(input) => {
994 assert_eq!(input.id, "input1");
995 assert_eq!(input.description, "Test input");
996 }
997 _ => panic!("Expected PromptString input"),
998 }
999 }
1000
1001 #[tokio::test]
1002 async fn test_input_management() {
1003 let session = SilentSession::new("test");
1004 let computer = Computer::new("test_computer", session, None, None, true, true);
1005
1006 let input = MCPServerInput::PromptString(PromptStringInput {
1008 id: "test_input".to_string(),
1009 description: "Test input".to_string(),
1010 default: Some("default".to_string()),
1011 password: Some(false),
1012 });
1013
1014 computer.add_or_update_input(input.clone()).await.unwrap();
1015
1016 let retrieved = computer.get_input("test_input").await.unwrap();
1018 assert!(retrieved.is_some());
1019
1020 let inputs = computer.list_inputs().await.unwrap();
1022 assert_eq!(inputs.len(), 1);
1023
1024 let updated_input = MCPServerInput::PromptString(PromptStringInput {
1026 id: "test_input".to_string(),
1027 description: "Updated description".to_string(),
1028 default: Some("new_default".to_string()),
1029 password: Some(true),
1030 });
1031 computer.add_or_update_input(updated_input).await.unwrap();
1032
1033 let retrieved = computer.get_input("test_input").await.unwrap().unwrap();
1034 match retrieved {
1035 MCPServerInput::PromptString(input) => {
1036 assert_eq!(input.description, "Updated description");
1037 assert_eq!(input.default, Some("new_default".to_string()));
1038 assert_eq!(input.password, Some(true));
1039 }
1040 _ => panic!("Expected PromptString input"),
1041 }
1042
1043 let removed = computer.remove_input("test_input").await.unwrap();
1045 assert!(removed);
1046
1047 let retrieved = computer.get_input("test_input").await.unwrap();
1048 assert!(retrieved.is_none());
1049
1050 let removed = computer.remove_input("non_existent").await.unwrap();
1052 assert!(!removed);
1053 }
1054
1055 #[tokio::test]
1056 async fn test_multiple_input_types() {
1057 let session = SilentSession::new("test");
1058 let computer = Computer::new("test_computer", session, None, None, true, true);
1059
1060 let prompt_input = MCPServerInput::PromptString(PromptStringInput {
1062 id: "prompt".to_string(),
1063 description: "Prompt input".to_string(),
1064 default: None,
1065 password: Some(false),
1066 });
1067
1068 let pick_input = MCPServerInput::PickString(PickStringInput {
1069 id: "pick".to_string(),
1070 description: "Pick input".to_string(),
1071 options: vec!["option1".to_string(), "option2".to_string()],
1072 default: Some("option1".to_string()),
1073 });
1074
1075 let command_input = MCPServerInput::Command(CommandInput {
1076 id: "command".to_string(),
1077 description: "Command input".to_string(),
1078 command: "ls".to_string(),
1079 args: None,
1080 });
1081
1082 computer.add_or_update_input(prompt_input).await.unwrap();
1083 computer.add_or_update_input(pick_input).await.unwrap();
1084 computer.add_or_update_input(command_input).await.unwrap();
1085
1086 let inputs = computer.list_inputs().await.unwrap();
1087 assert_eq!(inputs.len(), 3);
1088
1089 let input_types: std::collections::HashSet<_> = inputs
1091 .iter()
1092 .map(|input| match input {
1093 MCPServerInput::PromptString(_) => "prompt",
1094 MCPServerInput::PickString(_) => "pick",
1095 MCPServerInput::Command(_) => "command",
1096 })
1097 .collect();
1098
1099 assert!(input_types.contains("prompt"));
1100 assert!(input_types.contains("pick"));
1101 assert!(input_types.contains("command"));
1102 }
1103
1104 #[tokio::test]
1105 async fn test_server_management() {
1106 let session = SilentSession::new("test");
1107 let computer = Computer::new("test_computer", session, None, None, true, true);
1108
1109 let server_config = MCPServerConfig::Stdio(StdioServerConfig {
1111 name: "test_server".to_string(),
1112 disabled: false,
1113 forbidden_tools: vec![],
1114 tool_meta: std::collections::HashMap::new(),
1115 default_tool_meta: None,
1116 vrl: None,
1117 server_parameters: StdioServerParameters {
1118 command: "echo".to_string(),
1119 args: vec!["hello".to_string()],
1120 env: std::collections::HashMap::new(),
1121 cwd: None,
1122 },
1123 });
1124
1125 computer
1126 .add_or_update_server(server_config.clone())
1127 .await
1128 .unwrap();
1129
1130 let updated_config = MCPServerConfig::Stdio(StdioServerConfig {
1133 name: "test_server".to_string(),
1134 disabled: true, forbidden_tools: vec!["tool1".to_string()],
1136 tool_meta: std::collections::HashMap::new(),
1137 default_tool_meta: None,
1138 vrl: None,
1139 server_parameters: StdioServerParameters {
1140 command: "echo".to_string(),
1141 args: vec!["updated".to_string()],
1142 env: std::collections::HashMap::new(),
1143 cwd: None,
1144 },
1145 });
1146
1147 computer.add_or_update_server(updated_config).await.unwrap();
1148
1149 computer.remove_server("test_server").await.unwrap();
1151 }
1152
1153 #[tokio::test]
1154 async fn test_session_trait() {
1155 let session = SilentSession::new("test_session");
1157 assert_eq!(session.session_id(), "test_session");
1158
1159 let prompt_input = MCPServerInput::PromptString(PromptStringInput {
1161 id: "test".to_string(),
1162 description: "Test".to_string(),
1163 default: Some("default_value".to_string()),
1164 password: Some(false),
1165 });
1166
1167 let result = session.resolve_input(&prompt_input).await.unwrap();
1168 assert_eq!(
1169 result,
1170 serde_json::Value::String("default_value".to_string())
1171 );
1172
1173 let no_default_input = MCPServerInput::PromptString(PromptStringInput {
1175 id: "test2".to_string(),
1176 description: "Test2".to_string(),
1177 default: None,
1178 password: Some(false),
1179 });
1180
1181 let result = session.resolve_input(&no_default_input).await.unwrap();
1182 assert_eq!(result, serde_json::Value::String("".to_string()));
1183
1184 let pick_input = MCPServerInput::PickString(PickStringInput {
1186 id: "pick".to_string(),
1187 description: "Pick".to_string(),
1188 options: vec!["opt1".to_string(), "opt2".to_string()],
1189 default: Some("opt2".to_string()),
1190 });
1191
1192 let result = session.resolve_input(&pick_input).await.unwrap();
1193 assert_eq!(result, serde_json::Value::String("opt2".to_string()));
1194
1195 let command_input = MCPServerInput::Command(CommandInput {
1197 id: "cmd".to_string(),
1198 description: "Command".to_string(),
1199 command: "echo hello world".to_string(),
1200 args: None,
1201 });
1202
1203 let result = session.resolve_input(&command_input).await.unwrap();
1204 assert_eq!(result, serde_json::Value::String("hello world".to_string()));
1205 }
1206
1207 #[tokio::test]
1208 async fn test_cache_operations() {
1209 let session = SilentSession::new("test");
1210 let computer = Computer::new("test_computer", session, None, None, true, true);
1211
1212 let input = MCPServerInput::PromptString(PromptStringInput {
1214 id: "test_input".to_string(),
1215 description: "Test input".to_string(),
1216 default: Some("default".to_string()),
1217 password: Some(false),
1218 });
1219 computer.add_or_update_input(input).await.unwrap();
1220
1221 let test_value = serde_json::Value::String("cached_value".to_string());
1223 let set_result = computer
1224 .set_input_value("test_input", test_value.clone())
1225 .await
1226 .unwrap();
1227 assert!(set_result);
1228
1229 let retrieved = computer.get_input_value("test_input").await.unwrap();
1230 assert_eq!(retrieved, Some(test_value));
1231
1232 let invalid_result = computer
1234 .set_input_value(
1235 "nonexistent",
1236 serde_json::Value::String("value".to_string()),
1237 )
1238 .await
1239 .unwrap();
1240 assert!(!invalid_result);
1241
1242 let not_found = computer.get_input_value("nonexistent").await.unwrap();
1244 assert!(not_found.is_none());
1245 }
1246
1247 #[tokio::test]
1248 async fn test_cache_remove_and_clear() {
1249 let session = SilentSession::new("test");
1250 let computer = Computer::new("test_computer", session, None, None, true, true);
1251
1252 let input1 = MCPServerInput::PromptString(PromptStringInput {
1254 id: "input1".to_string(),
1255 description: "Input 1".to_string(),
1256 default: None,
1257 password: Some(false),
1258 });
1259 let input2 = MCPServerInput::PromptString(PromptStringInput {
1260 id: "input2".to_string(),
1261 description: "Input 2".to_string(),
1262 default: None,
1263 password: Some(false),
1264 });
1265 computer.add_or_update_input(input1).await.unwrap();
1266 computer.add_or_update_input(input2).await.unwrap();
1267
1268 computer
1270 .set_input_value("input1", serde_json::Value::String("value1".to_string()))
1271 .await
1272 .unwrap();
1273 computer
1274 .set_input_value("input2", serde_json::Value::String("value2".to_string()))
1275 .await
1276 .unwrap();
1277
1278 let removed = computer.remove_input_value("input1").await.unwrap();
1280 assert!(removed);
1281
1282 let retrieved = computer.get_input_value("input1").await.unwrap();
1283 assert!(retrieved.is_none());
1284
1285 let still_exists = computer.get_input_value("input2").await.unwrap();
1286 assert!(still_exists.is_some());
1287
1288 computer.clear_input_values(None).await.unwrap();
1290 let cleared1 = computer.get_input_value("input1").await.unwrap();
1291 let cleared2 = computer.get_input_value("input2").await.unwrap();
1292 assert!(cleared1.is_none());
1293 assert!(cleared2.is_none());
1294 }
1295
1296 #[tokio::test]
1297 async fn test_cache_list_values() {
1298 let session = SilentSession::new("test");
1299 let computer = Computer::new("test_computer", session, None, None, true, true);
1300
1301 let input1 = MCPServerInput::PromptString(PromptStringInput {
1303 id: "input1".to_string(),
1304 description: "Input 1".to_string(),
1305 default: None,
1306 password: Some(false),
1307 });
1308 let input2 = MCPServerInput::PromptString(PromptStringInput {
1309 id: "input2".to_string(),
1310 description: "Input 2".to_string(),
1311 default: None,
1312 password: Some(false),
1313 });
1314 computer.add_or_update_input(input1).await.unwrap();
1315 computer.add_or_update_input(input2).await.unwrap();
1316
1317 computer
1319 .set_input_value(
1320 "input1",
1321 serde_json::Value::String("string_value".to_string()),
1322 )
1323 .await
1324 .unwrap();
1325 computer
1326 .set_input_value(
1327 "input2",
1328 serde_json::Value::Number(serde_json::Number::from(42)),
1329 )
1330 .await
1331 .unwrap();
1332
1333 let values = computer.list_input_values().await.unwrap();
1335 assert_eq!(values.len(), 2);
1336 assert_eq!(
1337 values.get("input1"),
1338 Some(&serde_json::Value::String("string_value".to_string()))
1339 );
1340 assert_eq!(
1341 values.get("input2"),
1342 Some(&serde_json::Value::Number(serde_json::Number::from(42)))
1343 );
1344 }
1345
1346 #[tokio::test]
1347 async fn test_cache_clear_on_input_update() {
1348 let session = SilentSession::new("test");
1349 let computer = Computer::new("test_computer", session, None, None, true, true);
1350
1351 let input = MCPServerInput::PromptString(PromptStringInput {
1353 id: "test_input".to_string(),
1354 description: "Test input".to_string(),
1355 default: None,
1356 password: Some(false),
1357 });
1358 computer.add_or_update_input(input).await.unwrap();
1359
1360 computer
1362 .set_input_value(
1363 "test_input",
1364 serde_json::Value::String("cached".to_string()),
1365 )
1366 .await
1367 .unwrap();
1368 assert!(computer
1369 .get_input_value("test_input")
1370 .await
1371 .unwrap()
1372 .is_some());
1373
1374 let updated_input = MCPServerInput::PromptString(PromptStringInput {
1376 id: "test_input".to_string(),
1377 description: "Updated input".to_string(),
1378 default: Some("new_default".to_string()),
1379 password: Some(true),
1380 });
1381 computer.add_or_update_input(updated_input).await.unwrap();
1382
1383 assert!(computer
1385 .get_input_value("test_input")
1386 .await
1387 .unwrap()
1388 .is_none());
1389 }
1390
1391 #[tokio::test]
1392 async fn test_cache_clear_on_input_remove() {
1393 let session = SilentSession::new("test");
1394 let computer = Computer::new("test_computer", session, None, None, true, true);
1395
1396 let input = MCPServerInput::PromptString(PromptStringInput {
1398 id: "test_input".to_string(),
1399 description: "Test input".to_string(),
1400 default: None,
1401 password: Some(false),
1402 });
1403 computer.add_or_update_input(input).await.unwrap();
1404
1405 computer
1407 .set_input_value(
1408 "test_input",
1409 serde_json::Value::String("cached".to_string()),
1410 )
1411 .await
1412 .unwrap();
1413 assert!(computer
1414 .get_input_value("test_input")
1415 .await
1416 .unwrap()
1417 .is_some());
1418
1419 let removed = computer.remove_input("test_input").await.unwrap();
1421 assert!(removed);
1422
1423 assert!(computer
1425 .get_input_value("test_input")
1426 .await
1427 .unwrap()
1428 .is_none());
1429 }
1430
1431 #[tokio::test]
1432 async fn test_tool_call_history() {
1433 let session = SilentSession::new("test");
1434 let computer = Computer::new("test_computer", session, None, None, true, true);
1435
1436 let history = computer.get_tool_history().await.unwrap();
1438 assert!(history.is_empty());
1439
1440 }
1443
1444 #[tokio::test]
1445 async fn test_confirmation_callback() {
1446 let session = SilentSession::new("test");
1447 let computer = Computer::new("test_computer", session, None, None, true, true);
1448
1449 let callback_called = Arc::new(Mutex::new(false));
1451 let callback_called_clone = callback_called.clone();
1452
1453 let _computer = computer.with_confirm_callback(move |_req_id, _server, _tool, _params| {
1454 let rt = tokio::runtime::Handle::current();
1457 rt.block_on(async {
1458 let mut called = callback_called_clone.lock().await;
1459 *called = true;
1460 });
1461 true });
1463
1464 }
1467
1468 #[tokio::test]
1469 async fn test_computer_shutdown() {
1470 let session = SilentSession::new("test");
1471 let computer = Computer::new("test_computer", session, None, None, true, true);
1472
1473 computer.shutdown().await.unwrap();
1475
1476 computer.boot_up().await.unwrap();
1478 computer.shutdown().await.unwrap();
1479 }
1480
1481 #[tokio::test]
1482 async fn test_config_render() {
1483 let session = SilentSession::new("test");
1484
1485 let mut inputs = HashMap::new();
1487 inputs.insert(
1488 "api_key".to_string(),
1489 MCPServerInput::PromptString(PromptStringInput {
1490 id: "api_key".to_string(),
1491 description: "API Key".to_string(),
1492 default: Some("test-api-key-12345".to_string()),
1493 password: Some(true),
1494 }),
1495 );
1496 inputs.insert(
1497 "server_url".to_string(),
1498 MCPServerInput::PromptString(PromptStringInput {
1499 id: "server_url".to_string(),
1500 description: "Server URL".to_string(),
1501 default: Some("https://api.example.com".to_string()),
1502 password: Some(false),
1503 }),
1504 );
1505
1506 let computer = Computer::new("test_computer", session, Some(inputs), None, true, true);
1507
1508 let server_config = MCPServerConfig::Stdio(StdioServerConfig {
1510 name: "test_server".to_string(),
1511 disabled: false,
1512 forbidden_tools: vec![],
1513 tool_meta: std::collections::HashMap::new(),
1514 default_tool_meta: None,
1515 vrl: None,
1516 server_parameters: StdioServerParameters {
1517 command: "echo".to_string(),
1518 args: vec!["${input:api_key}".to_string()],
1519 env: {
1520 let mut env = std::collections::HashMap::new();
1521 env.insert("API_URL".to_string(), "${input:server_url}".to_string());
1522 env
1523 },
1524 cwd: None,
1525 },
1526 });
1527
1528 let rendered = computer.render_server_config(&server_config).await.unwrap();
1530
1531 match rendered {
1533 MCPServerConfig::Stdio(config) => {
1534 assert_eq!(config.server_parameters.args[0], "test-api-key-12345");
1535 assert_eq!(
1536 config.server_parameters.env.get("API_URL"),
1537 Some(&"https://api.example.com".to_string())
1538 );
1539 }
1540 _ => panic!("Expected Stdio config"),
1541 }
1542 }
1543
1544 #[tokio::test]
1545 async fn test_config_render_missing_input() {
1546 let session = SilentSession::new("test");
1547 let computer = Computer::new("test_computer", session, None, None, true, true);
1548
1549 let server_config = MCPServerConfig::Stdio(StdioServerConfig {
1551 name: "test_server".to_string(),
1552 disabled: false,
1553 forbidden_tools: vec![],
1554 tool_meta: std::collections::HashMap::new(),
1555 default_tool_meta: None,
1556 vrl: None,
1557 server_parameters: StdioServerParameters {
1558 command: "echo".to_string(),
1559 args: vec!["${input:missing_input}".to_string()],
1560 env: std::collections::HashMap::new(),
1561 cwd: None,
1562 },
1563 });
1564
1565 let rendered = computer.render_server_config(&server_config).await.unwrap();
1567
1568 match rendered {
1569 MCPServerConfig::Stdio(config) => {
1570 assert_eq!(config.server_parameters.args[0], "${input:missing_input}");
1572 }
1573 _ => panic!("Expected Stdio config"),
1574 }
1575 }
1576}