1use async_trait::async_trait;
12use chrono::{DateTime, Utc};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::sync::Arc;
16use tokio::sync::{Mutex, RwLock};
17use tracing::{debug, error, info};
18
19use crate::errors::{ComputerError, ComputerResult};
20use crate::inputs::handler::InputHandler;
21use crate::inputs::model::InputValue;
22use crate::inputs::utils::run_command;
23use crate::mcp_clients::{
24 manager::MCPServerManager,
25 model::{CallToolResult, MCPServerConfig, MCPServerInput, Tool},
26 ConfigRender, RenderError,
27};
28use crate::socketio_client::SmcpComputerClient;
29
30type ConfirmCallbackType = Arc<dyn Fn(&str, &str, &str, &serde_json::Value) -> bool + Send + Sync>;
32
33fn input_value_to_json(value: InputValue) -> serde_json::Value {
35 match value {
36 InputValue::String(s) => serde_json::Value::String(s),
37 InputValue::Number(n) => serde_json::Value::Number(serde_json::Number::from(n)),
38 InputValue::Float(f) => serde_json::Value::Number(
39 serde_json::Number::from_f64(f).unwrap_or(serde_json::Number::from(0)),
40 ),
41 InputValue::Bool(b) => serde_json::Value::Bool(b),
42 }
43}
44
45fn json_to_input_value(value: serde_json::Value) -> ComputerResult<InputValue> {
47 match value {
48 serde_json::Value::String(s) => Ok(InputValue::String(s)),
49 serde_json::Value::Number(n) => {
50 if let Some(i) = n.as_i64() {
51 Ok(InputValue::Number(i))
52 } else if let Some(u) = n.as_u64() {
53 Ok(InputValue::Number(u as i64))
54 } else if let Some(f) = n.as_f64() {
55 Ok(InputValue::Float(f))
56 } else {
57 Err(ComputerError::ValidationError(
58 "Invalid number value".to_string(),
59 ))
60 }
61 }
62 serde_json::Value::Bool(b) => Ok(InputValue::Bool(b)),
63 serde_json::Value::Null => Err(ComputerError::ValidationError(
64 "Null value not supported".to_string(),
65 )),
66 _ => Err(ComputerError::ValidationError(
67 "Unsupported value type".to_string(),
68 )),
69 }
70}
71
/// One entry in a computer's tool-call history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRecord {
    /// UTC time captured before the call was confirmed and executed.
    pub timestamp: DateTime<Utc>,
    /// Request id supplied by the caller.
    pub req_id: String,
    /// Server name, as returned by tool-call validation.
    pub server: String,
    /// Tool name, as returned by tool-call validation.
    pub tool: String,
    /// Parameters the tool was (or would have been) called with.
    pub parameters: serde_json::Value,
    /// Requested timeout in seconds, if any.
    pub timeout: Option<f64>,
    /// `true` only when the tool was invoked and its result was not flagged as an error.
    pub success: bool,
    /// Error description for failed calls, when one was available.
    pub error: Option<String>,
}
92
/// Strategy for turning input definitions into concrete values while server
/// configurations are rendered.
#[async_trait]
pub trait Session: Send + Sync {
    /// Resolves `input` to a JSON value.
    ///
    /// # Errors
    /// Implementations return an error when no value can be produced.
    async fn resolve_input(&self, input: &MCPServerInput) -> ComputerResult<serde_json::Value>;

    /// Identifier of this session.
    fn session_id(&self) -> &str;
}
103
/// A non-interactive [`Session`] that resolves inputs from their defaults
/// (or by running the configured command) without prompting anyone.
///
/// Derives `Clone` so that `Computer<SilentSession>` satisfies the `S: Clone`
/// bound of `Computer`'s `Clone` impl, and `Debug` for diagnostics.
#[derive(Debug, Clone)]
pub struct SilentSession {
    /// Value returned by `session_id`.
    id: String,
}

impl SilentSession {
    /// Creates a silent session identified by `id`.
    pub fn new(id: impl Into<String>) -> Self {
        Self { id: id.into() }
    }
}
115
116#[async_trait]
117impl Session for SilentSession {
118 async fn resolve_input(&self, input: &MCPServerInput) -> ComputerResult<serde_json::Value> {
119 match input {
121 MCPServerInput::PromptString(input) => Ok(serde_json::Value::String(
122 input.default.clone().unwrap_or_default(),
123 )),
124 MCPServerInput::PickString(input) => Ok(serde_json::Value::String(
125 input
126 .default
127 .clone()
128 .unwrap_or_else(|| input.options.first().cloned().unwrap_or_default()),
129 )),
130 MCPServerInput::Command(input) => {
131 let args: Vec<String> = input
133 .args
134 .as_ref()
135 .map(|m| {
136 let mut sorted_pairs: Vec<_> = m.iter().collect();
137 sorted_pairs.sort_by_key(|(k, _)| *k);
138 sorted_pairs.into_iter().map(|(_, v)| v.clone()).collect()
139 })
140 .unwrap_or_default();
141 match run_command(&input.command, &args).await {
142 Ok(output) => Ok(serde_json::Value::String(output)),
143 Err(e) => Err(ComputerError::RuntimeError(format!(
144 "Failed to execute command '{}': {}",
145 input.command, e
146 ))),
147 }
148 }
149 }
150 }
151
152 fn session_id(&self) -> &str {
153 &self.id
154 }
155}
156
/// A local "computer": a set of MCP server configurations, the input
/// definitions used to render them, a cache of input values, and an optional
/// Socket.IO client connection.
pub struct Computer<S: Session> {
    /// Name of this computer; passed to the Socket.IO client when connecting.
    name: String,
    /// MCP server manager; `None` until one is created (e.g. by `boot_up`).
    mcp_manager: Arc<RwLock<Option<MCPServerManager>>>,
    /// Input definitions keyed by input id.
    inputs: Arc<RwLock<HashMap<String, MCPServerInput>>>,
    /// Server configurations keyed by server name, stored as given (unrendered).
    mcp_servers: RwLock<HashMap<String, MCPServerConfig>>,
    /// Cache of input values.
    input_handler: Arc<RwLock<InputHandler>>,
    /// Stored flag. NOTE(review): not read by any non-test code visible here.
    auto_connect: bool,
    /// Stored flag. NOTE(review): not read by any non-test code visible here.
    auto_reconnect: bool,
    /// Most recent tool calls; `execute_tool` keeps this bounded.
    tool_history: Arc<Mutex<Vec<ToolCallRecord>>>,
    /// Resolves input definitions while server configs are rendered.
    session: S,
    /// Socket.IO client; `None` while disconnected.
    socketio_client: Arc<RwLock<Option<Arc<SmcpComputerClient>>>>,
    /// Optional `(req_id, server, tool, parameters) -> approved` hook consulted
    /// before tool calls.
    confirm_callback: Option<ConfirmCallbackType>,
}
186
187impl<S: Session> Computer<S> {
188 pub fn new(
190 name: impl Into<String>,
191 session: S,
192 inputs: Option<HashMap<String, MCPServerInput>>,
193 mcp_servers: Option<HashMap<String, MCPServerConfig>>,
194 auto_connect: bool,
195 auto_reconnect: bool,
196 ) -> Self {
197 let name = name.into();
198 let inputs = inputs.unwrap_or_default();
199 let mcp_servers = mcp_servers.unwrap_or_default();
200
201 Self {
202 name,
203 mcp_manager: Arc::new(RwLock::new(None)),
204 inputs: Arc::new(RwLock::new(inputs)),
205 mcp_servers: RwLock::new(mcp_servers),
206 input_handler: Arc::new(RwLock::new(InputHandler::new())),
207 auto_connect,
208 auto_reconnect,
209 tool_history: Arc::new(Mutex::new(Vec::new())),
210 session,
211 socketio_client: Arc::new(RwLock::new(None)),
212 confirm_callback: None,
213 }
214 }
215
216 pub fn with_confirm_callback<F>(mut self, callback: F) -> Self
218 where
219 F: Fn(&str, &str, &str, &serde_json::Value) -> bool + Send + Sync + 'static,
220 {
221 self.confirm_callback = Some(Arc::new(callback));
222 self
223 }
224
225 pub fn name(&self) -> &str {
227 &self.name
228 }
229
230 pub fn get_socketio_client(&self) -> Arc<RwLock<Option<Arc<SmcpComputerClient>>>> {
234 self.socketio_client.clone()
235 }
236
237 pub async fn boot_up(&self) -> ComputerResult<()> {
239 info!("Starting Computer: {}", self.name);
240
241 let manager = MCPServerManager::new();
243
244 let servers = self.mcp_servers.read().await;
246 let mut validated_servers = Vec::new();
247
248 for (_name, server_config) in servers.iter() {
249 match self.render_server_config(server_config).await {
250 Ok(validated) => validated_servers.push(validated),
251 Err(e) => {
252 error!(
253 "Failed to render server config {}: {}",
254 server_config.name(),
255 e
256 );
257 validated_servers.push(server_config.clone());
259 }
260 }
261 }
262
263 manager.initialize(validated_servers).await?;
265
266 *self.mcp_manager.write().await = Some(manager);
268
269 info!("Computer {} started successfully", self.name);
270 Ok(())
271 }
272
273 async fn render_server_config(
277 &self,
278 config: &MCPServerConfig,
279 ) -> ComputerResult<MCPServerConfig> {
280 let config_json = serde_json::to_value(config)?;
282
283 let renderer = ConfigRender::default();
285
286 let inputs = self.inputs.read().await;
288 let inputs_clone: std::collections::HashMap<String, MCPServerInput> = inputs.clone();
289 drop(inputs); let mut resolved_values: std::collections::HashMap<String, serde_json::Value> =
295 std::collections::HashMap::new();
296 for (input_id, input) in inputs_clone.iter() {
297 match self.session.resolve_input(input).await {
298 Ok(value) => {
299 resolved_values.insert(input_id.clone(), value);
300 }
301 Err(e) => {
302 debug!(
303 "Failed to resolve input '{}': {}, will use default",
304 input_id, e
305 );
306 if let Some(default) = input.default() {
308 resolved_values.insert(input_id.clone(), default);
309 }
310 }
311 }
312 }
313
314 let resolver = |input_id: String| {
316 let values = resolved_values.clone();
317 async move {
318 if let Some(value) = values.get(&input_id) {
319 Ok(value.clone())
320 } else {
321 Err(RenderError::InputNotFound(input_id))
322 }
323 }
324 };
325
326 let rendered_json = renderer.render(config_json, resolver).await?;
328
329 let rendered_config: MCPServerConfig = serde_json::from_value(rendered_json)?;
331
332 Ok(rendered_config)
333 }
334
335 pub async fn add_or_update_server(&self, server: MCPServerConfig) -> ComputerResult<()> {
337 {
339 let mut manager_guard = self.mcp_manager.write().await;
340 if manager_guard.is_none() {
341 *manager_guard = Some(MCPServerManager::new());
342 }
343 }
344
345 let validated = self.render_server_config(&server).await?;
347
348 let manager = self.mcp_manager.read().await;
350 if let Some(ref manager) = *manager {
351 manager.add_or_update_server(validated).await?;
352 }
353
354 {
356 let mut servers = self.mcp_servers.write().await;
357 servers.insert(server.name().to_string(), server);
358 }
359
360 let _ = self.emit_update_config().await;
362
363 Ok(())
364 }
365
366 pub async fn remove_server(&self, server_name: &str) -> ComputerResult<()> {
368 let manager = self.mcp_manager.read().await;
369 if let Some(ref manager) = *manager {
370 manager.remove_server(server_name).await?;
371 }
372
373 {
375 let mut servers = self.mcp_servers.write().await;
376 servers.remove(server_name);
377 }
378
379 let _ = self.emit_update_config().await;
381
382 Ok(())
383 }
384
385 pub async fn update_inputs(
387 &self,
388 inputs: HashMap<String, MCPServerInput>,
389 ) -> ComputerResult<()> {
390 *self.inputs.write().await = inputs;
391
392 {
394 let mut input_handler = self.input_handler.write().await;
395 *input_handler = InputHandler::new();
396 }
397
398 let _ = self.emit_update_config().await;
400
401 Ok(())
402 }
403
404 pub async fn add_or_update_input(&self, input: MCPServerInput) -> ComputerResult<()> {
406 let input_id = input.id().to_string();
407 {
408 let mut inputs = self.inputs.write().await;
409 inputs.insert(input_id.clone(), input);
410 }
411
412 self.clear_input_values(Some(&input_id)).await?;
414
415 let _ = self.emit_update_config().await;
417
418 Ok(())
419 }
420
421 pub async fn remove_input(&self, input_id: &str) -> ComputerResult<bool> {
423 let removed = {
424 let mut inputs = self.inputs.write().await;
425 inputs.remove(input_id).is_some()
426 };
427
428 if removed {
429 self.clear_input_values(Some(input_id)).await?;
431
432 let _ = self.emit_update_config().await;
434 }
435
436 Ok(removed)
437 }
438
439 pub async fn get_input(&self, input_id: &str) -> ComputerResult<Option<MCPServerInput>> {
441 let inputs = self.inputs.read().await;
442 Ok(inputs.get(input_id).cloned())
443 }
444
445 pub async fn list_inputs(&self) -> ComputerResult<Vec<MCPServerInput>> {
447 let inputs = self.inputs.read().await;
448 Ok(inputs.values().cloned().collect())
449 }
450
451 pub async fn get_input_value(
453 &self,
454 input_id: &str,
455 ) -> ComputerResult<Option<serde_json::Value>> {
456 let handler = self.input_handler.read().await;
458 let cached_values = handler.get_all_cached_values().await;
459
460 for (key, value) in cached_values {
462 if key.starts_with(input_id) {
465 let parts: Vec<&str> = key.split(':').collect();
467 if !parts.is_empty() && parts[0] == input_id {
468 return Ok(Some(input_value_to_json(value)));
469 }
470 }
471 }
472
473 Ok(None)
474 }
475
476 pub async fn set_input_value(
478 &self,
479 input_id: &str,
480 value: serde_json::Value,
481 ) -> ComputerResult<bool> {
482 {
484 let inputs = self.inputs.read().await;
485 if !inputs.contains_key(input_id) {
486 return Ok(false);
487 }
488 }
489
490 let handler = self.input_handler.read().await;
492 let input_value = json_to_input_value(value)?;
493 handler
494 .set_cached_value(input_id.to_string(), input_value)
495 .await;
496
497 Ok(true)
498 }
499
500 pub async fn remove_input_value(&self, input_id: &str) -> ComputerResult<bool> {
502 let handler = self.input_handler.read().await;
503 let removed = handler.remove_cached_value(input_id).await.is_some();
504 Ok(removed)
505 }
506
507 pub async fn list_input_values(&self) -> ComputerResult<HashMap<String, serde_json::Value>> {
509 let handler = self.input_handler.read().await;
510 let cached_values = handler.get_all_cached_values().await;
511
512 let mut result = HashMap::new();
513 for (key, value) in cached_values {
514 let parts: Vec<&str> = key.split(':').collect();
517 if !parts.is_empty() {
518 result.insert(parts[0].to_string(), input_value_to_json(value));
519 }
520 }
521
522 Ok(result)
523 }
524
525 pub async fn clear_input_values(&self, input_id: Option<&str>) -> ComputerResult<()> {
527 let handler = self.input_handler.read().await;
528
529 if let Some(id) = input_id {
530 let cached_values = handler.get_all_cached_values().await;
532 let keys_to_remove: Vec<String> = cached_values
533 .keys()
534 .filter(|key| key.starts_with(id))
535 .cloned()
536 .collect();
537
538 for key in keys_to_remove {
539 handler.remove_cached_value(&key).await;
540 }
541 } else {
542 handler.clear_all_cache().await;
544 }
545
546 Ok(())
547 }
548
549 pub async fn get_available_tools(&self) -> ComputerResult<Vec<Tool>> {
551 let manager = self.mcp_manager.read().await;
552 if let Some(ref manager) = *manager {
553 let tools: Vec<Tool> = manager.list_available_tools().await;
554 Ok(tools)
558 } else {
559 Err(ComputerError::InvalidState(
560 "Computer not initialized".to_string(),
561 ))
562 }
563 }
564
565 pub async fn execute_tool(
567 &self,
568 req_id: &str,
569 tool_name: &str,
570 parameters: serde_json::Value,
571 timeout: Option<f64>,
572 ) -> ComputerResult<CallToolResult> {
573 let manager = self.mcp_manager.read().await;
574 if let Some(ref manager) = *manager {
575 let (server_name, tool_name) =
577 manager.validate_tool_call(tool_name, ¶meters).await?;
578 let server_name = server_name.to_string();
579 let tool_name = tool_name.to_string();
580
581 let timestamp = Utc::now();
582 let mut success = false;
583 let mut error_msg = None;
584 let result: CallToolResult;
585
586 let need_confirm = true; let parameters_for_call = parameters.clone();
592
593 if need_confirm {
594 if let Some(ref callback) = self.confirm_callback {
595 let confirmed = callback(req_id, &server_name, &tool_name, ¶meters);
596 if confirmed {
597 let timeout_duration = timeout.map(std::time::Duration::from_secs_f64);
598 result = manager
599 .call_tool(
600 &server_name,
601 &tool_name,
602 parameters_for_call,
603 timeout_duration,
604 )
605 .await?;
606 success = !result.is_error;
607 } else {
608 result = CallToolResult {
609 content: vec![crate::mcp_clients::model::Content::Text {
610 text: "工具调用二次确认被拒绝,请稍后再试".to_string(),
611 }],
612 is_error: false,
613 meta: None,
614 };
615 }
616 } else {
617 result = CallToolResult {
618 content: vec![crate::mcp_clients::model::Content::Text {
619 text: "当前工具需要调用前进行二次确认,但客户端目前没有实现二次确认回调方法".to_string(),
620 }],
621 is_error: true,
622 meta: None,
623 };
624 error_msg = Some("No confirmation callback".to_string());
625 }
626 } else {
627 let timeout_duration = timeout.map(std::time::Duration::from_secs_f64);
628 result = manager
629 .call_tool(
630 &server_name,
631 &tool_name,
632 parameters_for_call,
633 timeout_duration,
634 )
635 .await?;
636 success = !result.is_error;
637 }
638
639 if result.is_error {
640 error_msg = result.content.iter().find_map(|c| match c {
641 crate::mcp_clients::model::Content::Text { text } => Some(text.clone()),
642 _ => None,
643 });
644 }
645
646 let record = ToolCallRecord {
648 timestamp,
649 req_id: req_id.to_string(),
650 server: server_name,
651 tool: tool_name,
652 parameters,
653 timeout,
654 success,
655 error: error_msg,
656 };
657
658 {
659 let mut history = self.tool_history.lock().await;
660 history.push(record);
661 if history.len() > 10 {
663 history.remove(0);
664 }
665 }
666
667 Ok(result)
668 } else {
669 Err(ComputerError::InvalidState(
670 "Computer not initialized".to_string(),
671 ))
672 }
673 }
674
675 pub async fn get_tool_history(&self) -> ComputerResult<Vec<ToolCallRecord>> {
677 let history = self.tool_history.lock().await;
678 Ok(history.clone())
679 }
680
681 pub async fn get_server_status(&self) -> Vec<(String, bool, String)> {
683 let manager_guard = self.mcp_manager.read().await;
684 if let Some(ref manager) = *manager_guard {
685 manager.get_server_status().await
686 } else {
687 Vec::new()
688 }
689 }
690
691 pub async fn list_mcp_servers(&self) -> Vec<MCPServerConfig> {
693 let servers = self.mcp_servers.read().await;
694 servers.values().cloned().collect()
695 }
696
697 pub async fn start_mcp_client(&self, server_name: &str) -> ComputerResult<()> {
699 let manager_guard = self.mcp_manager.read().await;
700 if let Some(ref manager) = *manager_guard {
701 if server_name == "all" {
702 manager.start_all().await
703 } else {
704 manager.start_client(server_name).await
705 }
706 } else {
707 Err(ComputerError::InvalidState(
708 "MCP Manager not initialized".to_string(),
709 ))
710 }
711 }
712
713 pub async fn stop_mcp_client(&self, server_name: &str) -> ComputerResult<()> {
715 let manager_guard = self.mcp_manager.read().await;
716 if let Some(ref manager) = *manager_guard {
717 if server_name == "all" {
718 manager.stop_all().await
719 } else {
720 manager.stop_client(server_name).await
721 }
722 } else {
723 Err(ComputerError::InvalidState(
724 "MCP Manager not initialized".to_string(),
725 ))
726 }
727 }
728
729 pub async fn is_mcp_manager_initialized(&self) -> bool {
731 let manager_guard = self.mcp_manager.read().await;
732 manager_guard.is_some()
733 }
734
735 pub async fn set_socketio_client(&self, client: Arc<SmcpComputerClient>) {
739 let mut socketio_ref = self.socketio_client.write().await;
740 *socketio_ref = Some(client);
743 }
744
745 pub async fn connect_socketio(
747 &self,
748 url: &str,
749 _namespace: &str,
750 auth: &Option<String>,
751 _headers: &Option<String>,
752 ) -> ComputerResult<()> {
753 let _manager_check = {
755 let manager_guard = self.mcp_manager.read().await;
756 match manager_guard.as_ref() {
757 Some(_m) => {
758 true
761 }
762 None => {
763 return Err(ComputerError::InvalidState(
764 "MCP Manager not initialized. Please add and start servers first."
765 .to_string(),
766 ));
767 }
768 }
769 };
770
771 let new_manager = MCPServerManager::new();
776
777 let client = SmcpComputerClient::new(
780 url,
781 Arc::new(RwLock::new(Some(new_manager))),
782 self.name.clone(),
783 auth.clone(),
784 self.inputs.clone(),
785 )
786 .await?;
787
788 let client_arc = Arc::new(client);
790 self.set_socketio_client(client_arc.clone()).await;
791
792 info!(
793 "Connected to SMCP server at {} with computer name: {}",
794 url, self.name
795 );
796
797 Ok(())
798 }
799
800 pub async fn disconnect_socketio(&self) -> ComputerResult<()> {
802 let mut socketio_ref = self.socketio_client.write().await;
803 *socketio_ref = None;
804 info!("Disconnected from server");
805 Ok(())
806 }
807
808 pub async fn join_office(&self, office_id: &str, _computer_name: &str) -> ComputerResult<()> {
810 let socketio_ref = self.socketio_client.read().await;
811 if let Some(ref client) = *socketio_ref {
812 client.join_office(office_id).await?;
815 return Ok(());
816 }
817 Err(ComputerError::InvalidState(
818 "Socket.IO client not connected".to_string(),
819 ))
820 }
821
822 pub async fn leave_office(&self) -> ComputerResult<()> {
824 let socketio_ref = self.socketio_client.read().await;
825 if let Some(ref client) = *socketio_ref {
826 let current_office_id = client.get_current_office_id().await?;
829 client.leave_office(¤t_office_id).await?;
830 return Ok(());
831 }
832 Err(ComputerError::InvalidState(
833 "Socket.IO client not connected".to_string(),
834 ))
835 }
836
837 pub async fn emit_update_config(&self) -> ComputerResult<()> {
839 let socketio_ref = self.socketio_client.read().await;
840 if let Some(ref client) = *socketio_ref {
841 client.emit_update_config().await?;
844 return Ok(());
845 }
846 Err(ComputerError::InvalidState(
847 "Socket.IO client not connected".to_string(),
848 ))
849 }
850
851 pub async fn shutdown(&self) -> ComputerResult<()> {
853 info!("Shutting down Computer: {}", self.name);
854
855 let mut manager_guard = self.mcp_manager.write().await;
856 if let Some(manager) = manager_guard.take() {
857 manager.stop_all().await?;
858 }
859
860 {
862 let mut socketio_ref = self.socketio_client.write().await;
863 *socketio_ref = None;
864 }
865
866 info!("Computer {} shutdown successfully", self.name);
867 Ok(())
868 }
869}
870
871impl<S: Session + Clone> Clone for Computer<S> {
873 fn clone(&self) -> Self {
874 Self {
875 name: self.name.clone(),
876 mcp_manager: Arc::clone(&self.mcp_manager),
877 inputs: Arc::new(RwLock::new(HashMap::new())), mcp_servers: RwLock::new(HashMap::new()),
879 input_handler: Arc::clone(&self.input_handler),
880 auto_connect: self.auto_connect,
881 auto_reconnect: self.auto_reconnect,
882 tool_history: Arc::clone(&self.tool_history),
883 session: self.session.clone(),
884 socketio_client: Arc::clone(&self.socketio_client),
885 confirm_callback: self.confirm_callback.clone(),
886 }
887 }
888}
889
/// Receiver for manager change notifications (see [`ManagerChangeMessage`]).
#[async_trait]
pub trait ManagerChangeHandler: Send + Sync {
    /// Handles one change notification.
    async fn on_change(&self, message: ManagerChangeMessage) -> ComputerResult<()>;
}
896
/// A change notification delivered to [`ManagerChangeHandler::on_change`].
#[derive(Debug, Clone)]
pub enum ManagerChangeMessage {
    /// The set of available tools changed.
    ToolListChanged,
    /// The resource list changed. NOTE(review): `windows` presumably holds
    /// window resource URIs — confirm with the producer.
    ResourceListChanged { windows: Vec<String> },
    /// The resource identified by `uri` was updated.
    ResourceUpdated { uri: String },
}
907
908#[async_trait]
909impl<S: Session> ManagerChangeHandler for Computer<S> {
910 async fn on_change(&self, message: ManagerChangeMessage) -> ComputerResult<()> {
911 match message {
912 ManagerChangeMessage::ToolListChanged => {
913 debug!("Tool list changed, notifying Socket.IO client");
914 let socketio_ref = self.socketio_client.read().await;
915 if let Some(ref client) = *socketio_ref {
916 client.emit_update_tool_list().await?;
919 }
920 }
921 ManagerChangeMessage::ResourceListChanged { windows: _ } => {
922 debug!("Resource list changed, checking for window updates");
923 }
925 ManagerChangeMessage::ResourceUpdated { uri } => {
926 debug!("Resource updated: {}", uri);
927 }
929 }
930 Ok(())
931 }
932}
933
934#[cfg(test)]
935mod tests {
936 use super::*;
937 use crate::mcp_clients::model::{
938 CommandInput, MCPServerConfig, MCPServerInput, PickStringInput, PromptStringInput,
939 StdioServerConfig, StdioServerParameters,
940 };
941
942 #[tokio::test]
943 async fn test_computer_creation() {
944 let session = SilentSession::new("test");
945 let computer = Computer::new("test_computer", session, None, None, true, true);
946
947 assert_eq!(computer.name, "test_computer");
948 assert!(computer.auto_connect);
949 assert!(computer.auto_reconnect);
950 }
951
952 #[tokio::test]
953 async fn test_computer_with_initial_inputs_and_servers() {
954 let session = SilentSession::new("test");
955 let mut inputs = HashMap::new();
956 inputs.insert(
957 "input1".to_string(),
958 MCPServerInput::PromptString(PromptStringInput {
959 id: "input1".to_string(),
960 description: "Test input".to_string(),
961 default: Some("default".to_string()),
962 password: Some(false),
963 }),
964 );
965
966 let mut servers = HashMap::new();
967 servers.insert(
968 "server1".to_string(),
969 MCPServerConfig::Stdio(StdioServerConfig {
970 name: "server1".to_string(),
971 disabled: false,
972 forbidden_tools: vec![],
973 tool_meta: std::collections::HashMap::new(),
974 default_tool_meta: None,
975 vrl: None,
976 server_parameters: StdioServerParameters {
977 command: "echo".to_string(),
978 args: vec![],
979 env: std::collections::HashMap::new(),
980 cwd: None,
981 },
982 }),
983 );
984
985 let computer = Computer::new(
986 "test_computer",
987 session,
988 Some(inputs),
989 Some(servers),
990 false,
991 false,
992 );
993
994 let inputs = computer.list_inputs().await.unwrap();
996 assert_eq!(inputs.len(), 1);
997 match &inputs[0] {
998 MCPServerInput::PromptString(input) => {
999 assert_eq!(input.id, "input1");
1000 assert_eq!(input.description, "Test input");
1001 }
1002 _ => panic!("Expected PromptString input"),
1003 }
1004 }
1005
1006 #[tokio::test]
1007 async fn test_input_management() {
1008 let session = SilentSession::new("test");
1009 let computer = Computer::new("test_computer", session, None, None, true, true);
1010
1011 let input = MCPServerInput::PromptString(PromptStringInput {
1013 id: "test_input".to_string(),
1014 description: "Test input".to_string(),
1015 default: Some("default".to_string()),
1016 password: Some(false),
1017 });
1018
1019 computer.add_or_update_input(input.clone()).await.unwrap();
1020
1021 let retrieved = computer.get_input("test_input").await.unwrap();
1023 assert!(retrieved.is_some());
1024
1025 let inputs = computer.list_inputs().await.unwrap();
1027 assert_eq!(inputs.len(), 1);
1028
1029 let updated_input = MCPServerInput::PromptString(PromptStringInput {
1031 id: "test_input".to_string(),
1032 description: "Updated description".to_string(),
1033 default: Some("new_default".to_string()),
1034 password: Some(true),
1035 });
1036 computer.add_or_update_input(updated_input).await.unwrap();
1037
1038 let retrieved = computer.get_input("test_input").await.unwrap().unwrap();
1039 match retrieved {
1040 MCPServerInput::PromptString(input) => {
1041 assert_eq!(input.description, "Updated description");
1042 assert_eq!(input.default, Some("new_default".to_string()));
1043 assert_eq!(input.password, Some(true));
1044 }
1045 _ => panic!("Expected PromptString input"),
1046 }
1047
1048 let removed = computer.remove_input("test_input").await.unwrap();
1050 assert!(removed);
1051
1052 let retrieved = computer.get_input("test_input").await.unwrap();
1053 assert!(retrieved.is_none());
1054
1055 let removed = computer.remove_input("non_existent").await.unwrap();
1057 assert!(!removed);
1058 }
1059
1060 #[tokio::test]
1061 async fn test_multiple_input_types() {
1062 let session = SilentSession::new("test");
1063 let computer = Computer::new("test_computer", session, None, None, true, true);
1064
1065 let prompt_input = MCPServerInput::PromptString(PromptStringInput {
1067 id: "prompt".to_string(),
1068 description: "Prompt input".to_string(),
1069 default: None,
1070 password: Some(false),
1071 });
1072
1073 let pick_input = MCPServerInput::PickString(PickStringInput {
1074 id: "pick".to_string(),
1075 description: "Pick input".to_string(),
1076 options: vec!["option1".to_string(), "option2".to_string()],
1077 default: Some("option1".to_string()),
1078 });
1079
1080 let command_input = MCPServerInput::Command(CommandInput {
1081 id: "command".to_string(),
1082 description: "Command input".to_string(),
1083 command: "ls".to_string(),
1084 args: None,
1085 });
1086
1087 computer.add_or_update_input(prompt_input).await.unwrap();
1088 computer.add_or_update_input(pick_input).await.unwrap();
1089 computer.add_or_update_input(command_input).await.unwrap();
1090
1091 let inputs = computer.list_inputs().await.unwrap();
1092 assert_eq!(inputs.len(), 3);
1093
1094 let input_types: std::collections::HashSet<_> = inputs
1096 .iter()
1097 .map(|input| match input {
1098 MCPServerInput::PromptString(_) => "prompt",
1099 MCPServerInput::PickString(_) => "pick",
1100 MCPServerInput::Command(_) => "command",
1101 })
1102 .collect();
1103
1104 assert!(input_types.contains("prompt"));
1105 assert!(input_types.contains("pick"));
1106 assert!(input_types.contains("command"));
1107 }
1108
1109 #[tokio::test]
1110 async fn test_server_management() {
1111 let session = SilentSession::new("test");
1112 let computer = Computer::new("test_computer", session, None, None, true, true);
1113
1114 let server_config = MCPServerConfig::Stdio(StdioServerConfig {
1116 name: "test_server".to_string(),
1117 disabled: false,
1118 forbidden_tools: vec![],
1119 tool_meta: std::collections::HashMap::new(),
1120 default_tool_meta: None,
1121 vrl: None,
1122 server_parameters: StdioServerParameters {
1123 command: "echo".to_string(),
1124 args: vec!["hello".to_string()],
1125 env: std::collections::HashMap::new(),
1126 cwd: None,
1127 },
1128 });
1129
1130 computer
1131 .add_or_update_server(server_config.clone())
1132 .await
1133 .unwrap();
1134
1135 let updated_config = MCPServerConfig::Stdio(StdioServerConfig {
1138 name: "test_server".to_string(),
1139 disabled: true, forbidden_tools: vec!["tool1".to_string()],
1141 tool_meta: std::collections::HashMap::new(),
1142 default_tool_meta: None,
1143 vrl: None,
1144 server_parameters: StdioServerParameters {
1145 command: "echo".to_string(),
1146 args: vec!["updated".to_string()],
1147 env: std::collections::HashMap::new(),
1148 cwd: None,
1149 },
1150 });
1151
1152 computer.add_or_update_server(updated_config).await.unwrap();
1153
1154 computer.remove_server("test_server").await.unwrap();
1156 }
1157
1158 #[tokio::test]
1159 async fn test_session_trait() {
1160 let session = SilentSession::new("test_session");
1162 assert_eq!(session.session_id(), "test_session");
1163
1164 let prompt_input = MCPServerInput::PromptString(PromptStringInput {
1166 id: "test".to_string(),
1167 description: "Test".to_string(),
1168 default: Some("default_value".to_string()),
1169 password: Some(false),
1170 });
1171
1172 let result = session.resolve_input(&prompt_input).await.unwrap();
1173 assert_eq!(
1174 result,
1175 serde_json::Value::String("default_value".to_string())
1176 );
1177
1178 let no_default_input = MCPServerInput::PromptString(PromptStringInput {
1180 id: "test2".to_string(),
1181 description: "Test2".to_string(),
1182 default: None,
1183 password: Some(false),
1184 });
1185
1186 let result = session.resolve_input(&no_default_input).await.unwrap();
1187 assert_eq!(result, serde_json::Value::String("".to_string()));
1188
1189 let pick_input = MCPServerInput::PickString(PickStringInput {
1191 id: "pick".to_string(),
1192 description: "Pick".to_string(),
1193 options: vec!["opt1".to_string(), "opt2".to_string()],
1194 default: Some("opt2".to_string()),
1195 });
1196
1197 let result = session.resolve_input(&pick_input).await.unwrap();
1198 assert_eq!(result, serde_json::Value::String("opt2".to_string()));
1199
1200 let command_input = MCPServerInput::Command(CommandInput {
1202 id: "cmd".to_string(),
1203 description: "Command".to_string(),
1204 command: "echo hello world".to_string(),
1205 args: None,
1206 });
1207
1208 let result = session.resolve_input(&command_input).await.unwrap();
1209 assert_eq!(result, serde_json::Value::String("hello world".to_string()));
1210 }
1211
1212 #[tokio::test]
1213 async fn test_cache_operations() {
1214 let session = SilentSession::new("test");
1215 let computer = Computer::new("test_computer", session, None, None, true, true);
1216
1217 let input = MCPServerInput::PromptString(PromptStringInput {
1219 id: "test_input".to_string(),
1220 description: "Test input".to_string(),
1221 default: Some("default".to_string()),
1222 password: Some(false),
1223 });
1224 computer.add_or_update_input(input).await.unwrap();
1225
1226 let test_value = serde_json::Value::String("cached_value".to_string());
1228 let set_result = computer
1229 .set_input_value("test_input", test_value.clone())
1230 .await
1231 .unwrap();
1232 assert!(set_result);
1233
1234 let retrieved = computer.get_input_value("test_input").await.unwrap();
1235 assert_eq!(retrieved, Some(test_value));
1236
1237 let invalid_result = computer
1239 .set_input_value(
1240 "nonexistent",
1241 serde_json::Value::String("value".to_string()),
1242 )
1243 .await
1244 .unwrap();
1245 assert!(!invalid_result);
1246
1247 let not_found = computer.get_input_value("nonexistent").await.unwrap();
1249 assert!(not_found.is_none());
1250 }
1251
1252 #[tokio::test]
1253 async fn test_cache_remove_and_clear() {
1254 let session = SilentSession::new("test");
1255 let computer = Computer::new("test_computer", session, None, None, true, true);
1256
1257 let input1 = MCPServerInput::PromptString(PromptStringInput {
1259 id: "input1".to_string(),
1260 description: "Input 1".to_string(),
1261 default: None,
1262 password: Some(false),
1263 });
1264 let input2 = MCPServerInput::PromptString(PromptStringInput {
1265 id: "input2".to_string(),
1266 description: "Input 2".to_string(),
1267 default: None,
1268 password: Some(false),
1269 });
1270 computer.add_or_update_input(input1).await.unwrap();
1271 computer.add_or_update_input(input2).await.unwrap();
1272
1273 computer
1275 .set_input_value("input1", serde_json::Value::String("value1".to_string()))
1276 .await
1277 .unwrap();
1278 computer
1279 .set_input_value("input2", serde_json::Value::String("value2".to_string()))
1280 .await
1281 .unwrap();
1282
1283 let removed = computer.remove_input_value("input1").await.unwrap();
1285 assert!(removed);
1286
1287 let retrieved = computer.get_input_value("input1").await.unwrap();
1288 assert!(retrieved.is_none());
1289
1290 let still_exists = computer.get_input_value("input2").await.unwrap();
1291 assert!(still_exists.is_some());
1292
1293 computer.clear_input_values(None).await.unwrap();
1295 let cleared1 = computer.get_input_value("input1").await.unwrap();
1296 let cleared2 = computer.get_input_value("input2").await.unwrap();
1297 assert!(cleared1.is_none());
1298 assert!(cleared2.is_none());
1299 }
1300
1301 #[tokio::test]
1302 async fn test_cache_list_values() {
1303 let session = SilentSession::new("test");
1304 let computer = Computer::new("test_computer", session, None, None, true, true);
1305
1306 let input1 = MCPServerInput::PromptString(PromptStringInput {
1308 id: "input1".to_string(),
1309 description: "Input 1".to_string(),
1310 default: None,
1311 password: Some(false),
1312 });
1313 let input2 = MCPServerInput::PromptString(PromptStringInput {
1314 id: "input2".to_string(),
1315 description: "Input 2".to_string(),
1316 default: None,
1317 password: Some(false),
1318 });
1319 computer.add_or_update_input(input1).await.unwrap();
1320 computer.add_or_update_input(input2).await.unwrap();
1321
1322 computer
1324 .set_input_value(
1325 "input1",
1326 serde_json::Value::String("string_value".to_string()),
1327 )
1328 .await
1329 .unwrap();
1330 computer
1331 .set_input_value(
1332 "input2",
1333 serde_json::Value::Number(serde_json::Number::from(42)),
1334 )
1335 .await
1336 .unwrap();
1337
1338 let values = computer.list_input_values().await.unwrap();
1340 assert_eq!(values.len(), 2);
1341 assert_eq!(
1342 values.get("input1"),
1343 Some(&serde_json::Value::String("string_value".to_string()))
1344 );
1345 assert_eq!(
1346 values.get("input2"),
1347 Some(&serde_json::Value::Number(serde_json::Number::from(42)))
1348 );
1349 }
1350
1351 #[tokio::test]
1352 async fn test_cache_clear_on_input_update() {
1353 let session = SilentSession::new("test");
1354 let computer = Computer::new("test_computer", session, None, None, true, true);
1355
1356 let input = MCPServerInput::PromptString(PromptStringInput {
1358 id: "test_input".to_string(),
1359 description: "Test input".to_string(),
1360 default: None,
1361 password: Some(false),
1362 });
1363 computer.add_or_update_input(input).await.unwrap();
1364
1365 computer
1367 .set_input_value(
1368 "test_input",
1369 serde_json::Value::String("cached".to_string()),
1370 )
1371 .await
1372 .unwrap();
1373 assert!(computer
1374 .get_input_value("test_input")
1375 .await
1376 .unwrap()
1377 .is_some());
1378
1379 let updated_input = MCPServerInput::PromptString(PromptStringInput {
1381 id: "test_input".to_string(),
1382 description: "Updated input".to_string(),
1383 default: Some("new_default".to_string()),
1384 password: Some(true),
1385 });
1386 computer.add_or_update_input(updated_input).await.unwrap();
1387
1388 assert!(computer
1390 .get_input_value("test_input")
1391 .await
1392 .unwrap()
1393 .is_none());
1394 }
1395
1396 #[tokio::test]
1397 async fn test_cache_clear_on_input_remove() {
1398 let session = SilentSession::new("test");
1399 let computer = Computer::new("test_computer", session, None, None, true, true);
1400
1401 let input = MCPServerInput::PromptString(PromptStringInput {
1403 id: "test_input".to_string(),
1404 description: "Test input".to_string(),
1405 default: None,
1406 password: Some(false),
1407 });
1408 computer.add_or_update_input(input).await.unwrap();
1409
1410 computer
1412 .set_input_value(
1413 "test_input",
1414 serde_json::Value::String("cached".to_string()),
1415 )
1416 .await
1417 .unwrap();
1418 assert!(computer
1419 .get_input_value("test_input")
1420 .await
1421 .unwrap()
1422 .is_some());
1423
1424 let removed = computer.remove_input("test_input").await.unwrap();
1426 assert!(removed);
1427
1428 assert!(computer
1430 .get_input_value("test_input")
1431 .await
1432 .unwrap()
1433 .is_none());
1434 }
1435
1436 #[tokio::test]
1437 async fn test_tool_call_history() {
1438 let session = SilentSession::new("test");
1439 let computer = Computer::new("test_computer", session, None, None, true, true);
1440
1441 let history = computer.get_tool_history().await.unwrap();
1443 assert!(history.is_empty());
1444
1445 }
1448
1449 #[tokio::test]
1450 async fn test_confirmation_callback() {
1451 let session = SilentSession::new("test");
1452 let computer = Computer::new("test_computer", session, None, None, true, true);
1453
1454 let callback_called = Arc::new(Mutex::new(false));
1456 let callback_called_clone = callback_called.clone();
1457
1458 let _computer = computer.with_confirm_callback(move |_req_id, _server, _tool, _params| {
1459 let rt = tokio::runtime::Handle::current();
1462 rt.block_on(async {
1463 let mut called = callback_called_clone.lock().await;
1464 *called = true;
1465 });
1466 true });
1468
1469 }
1472
1473 #[tokio::test]
1474 async fn test_computer_shutdown() {
1475 let session = SilentSession::new("test");
1476 let computer = Computer::new("test_computer", session, None, None, true, true);
1477
1478 computer.shutdown().await.unwrap();
1480
1481 computer.boot_up().await.unwrap();
1483 computer.shutdown().await.unwrap();
1484 }
1485
1486 #[tokio::test]
1487 async fn test_config_render() {
1488 let session = SilentSession::new("test");
1489
1490 let mut inputs = HashMap::new();
1492 inputs.insert(
1493 "api_key".to_string(),
1494 MCPServerInput::PromptString(PromptStringInput {
1495 id: "api_key".to_string(),
1496 description: "API Key".to_string(),
1497 default: Some("test-api-key-12345".to_string()),
1498 password: Some(true),
1499 }),
1500 );
1501 inputs.insert(
1502 "server_url".to_string(),
1503 MCPServerInput::PromptString(PromptStringInput {
1504 id: "server_url".to_string(),
1505 description: "Server URL".to_string(),
1506 default: Some("https://api.example.com".to_string()),
1507 password: Some(false),
1508 }),
1509 );
1510
1511 let computer = Computer::new("test_computer", session, Some(inputs), None, true, true);
1512
1513 let server_config = MCPServerConfig::Stdio(StdioServerConfig {
1515 name: "test_server".to_string(),
1516 disabled: false,
1517 forbidden_tools: vec![],
1518 tool_meta: std::collections::HashMap::new(),
1519 default_tool_meta: None,
1520 vrl: None,
1521 server_parameters: StdioServerParameters {
1522 command: "echo".to_string(),
1523 args: vec!["${input:api_key}".to_string()],
1524 env: {
1525 let mut env = std::collections::HashMap::new();
1526 env.insert("API_URL".to_string(), "${input:server_url}".to_string());
1527 env
1528 },
1529 cwd: None,
1530 },
1531 });
1532
1533 let rendered = computer.render_server_config(&server_config).await.unwrap();
1535
1536 match rendered {
1538 MCPServerConfig::Stdio(config) => {
1539 assert_eq!(config.server_parameters.args[0], "test-api-key-12345");
1540 assert_eq!(
1541 config.server_parameters.env.get("API_URL"),
1542 Some(&"https://api.example.com".to_string())
1543 );
1544 }
1545 _ => panic!("Expected Stdio config"),
1546 }
1547 }
1548
1549 #[tokio::test]
1550 async fn test_config_render_missing_input() {
1551 let session = SilentSession::new("test");
1552 let computer = Computer::new("test_computer", session, None, None, true, true);
1553
1554 let server_config = MCPServerConfig::Stdio(StdioServerConfig {
1556 name: "test_server".to_string(),
1557 disabled: false,
1558 forbidden_tools: vec![],
1559 tool_meta: std::collections::HashMap::new(),
1560 default_tool_meta: None,
1561 vrl: None,
1562 server_parameters: StdioServerParameters {
1563 command: "echo".to_string(),
1564 args: vec!["${input:missing_input}".to_string()],
1565 env: std::collections::HashMap::new(),
1566 cwd: None,
1567 },
1568 });
1569
1570 let rendered = computer.render_server_config(&server_config).await.unwrap();
1572
1573 match rendered {
1574 MCPServerConfig::Stdio(config) => {
1575 assert_eq!(config.server_parameters.args[0], "${input:missing_input}");
1577 }
1578 _ => panic!("Expected Stdio config"),
1579 }
1580 }
1581}