1use crate::error::Result;
2use chrono::{DateTime, Utc};
3use schemars::JsonSchema;
4use serde::{Deserialize, Serialize};
5use std::collections::{HashMap, HashSet};
6use std::path::PathBuf;
7use std::sync::Arc;
8
9use crate::api::Model;
10use crate::app::{Message, MessageData};
11use crate::config::LlmConfigProvider;
12use crate::tools::{BackendRegistry, LocalBackend, McpTransport, ToolBackend};
13use steer_tools::tools::read_only_workspace_tools;
14use steer_tools::{ToolCall, result::ToolResult};
15
/// Connection status of a single configured MCP server.
#[derive(Clone, Debug)]
pub enum McpConnectionState {
    /// A connection attempt has been started but not yet resolved.
    Connecting,
    /// The backend was created; `tool_names` lists the tools it reported as supported.
    Connected {
        tool_names: Vec<String>,
    },
    /// Creating the backend failed; `error` holds the rendered error message.
    Failed {
        error: String,
    },
}
32
33#[derive(Debug, Clone)]
35pub struct McpServerInfo {
36 pub server_name: String,
38 pub transport: McpTransport,
40 pub state: McpConnectionState,
42 pub last_updated: DateTime<Utc>,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
48#[serde(rename_all = "snake_case")]
49pub enum ContainerRuntime {
50 Docker,
51 Podman,
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
56#[serde(tag = "type", rename_all = "snake_case")]
57pub enum WorkspaceConfig {
58 Local {
59 path: PathBuf,
60 },
61 Remote {
62 agent_address: String,
63 auth: Option<RemoteAuth>,
64 },
65 Container {
66 image: String,
67 runtime: ContainerRuntime,
68 },
69}
70
71impl WorkspaceConfig {
72 pub fn get_path(&self) -> Option<String> {
73 match self {
74 WorkspaceConfig::Local { path } => Some(path.to_string_lossy().to_string()),
75 WorkspaceConfig::Remote { agent_address, .. } => Some(agent_address.clone()),
76 WorkspaceConfig::Container { .. } => None,
77 }
78 }
79
80 pub fn to_workspace_config(&self) -> steer_workspace::WorkspaceConfig {
82 match self {
83 WorkspaceConfig::Local { path } => {
84 steer_workspace::WorkspaceConfig::Local { path: path.clone() }
85 }
86 WorkspaceConfig::Remote {
87 agent_address,
88 auth,
89 } => steer_workspace::WorkspaceConfig::Remote {
90 address: agent_address.clone(),
91 auth: auth.as_ref().map(|a| a.to_workspace_auth()),
92 },
93 WorkspaceConfig::Container { .. } => {
94 steer_workspace::WorkspaceConfig::Local {
96 path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
97 }
98 }
99 }
100 }
101}
102
103impl Default for WorkspaceConfig {
104 fn default() -> Self {
105 Self::Local {
106 path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
107 }
108 }
109}
110
111#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct Session {
114 pub id: String,
115 pub created_at: DateTime<Utc>,
116 pub updated_at: DateTime<Utc>,
117 pub config: SessionConfig,
118 pub state: SessionState,
119}
120
121impl Session {
122 pub fn new(id: String, config: SessionConfig) -> Self {
123 let now = Utc::now();
124 Self {
125 id,
126 created_at: now,
127 updated_at: now,
128 config,
129 state: SessionState::default(),
130 }
131 }
132
133 pub fn update_timestamp(&mut self) {
134 self.updated_at = Utc::now();
135 }
136
137 pub fn is_recently_active(&self, threshold: chrono::Duration) -> bool {
139 let cutoff = Utc::now() - threshold;
140 self.updated_at > cutoff
141 }
142
143 pub async fn build_workspace(&self) -> Result<Arc<dyn crate::workspace::Workspace>> {
145 crate::workspace::create_workspace(&self.config.workspace.to_workspace_config()).await
146 }
147}
148
149#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
151pub struct SessionConfig {
152 pub workspace: WorkspaceConfig,
153 pub tool_config: SessionToolConfig,
154 pub system_prompt: Option<String>,
157 pub metadata: HashMap<String, String>,
158}
159
160impl SessionConfig {
161 pub async fn build_registry(
165 &self,
166 llm_config_provider: Arc<LlmConfigProvider>,
167 workspace: Arc<dyn crate::workspace::Workspace>,
168 ) -> Result<(BackendRegistry, HashMap<String, McpServerInfo>)> {
169 let mut registry = BackendRegistry::new();
170 let mut mcp_servers = HashMap::new();
171
172 for (idx, backend_config) in self.tool_config.backends.iter().enumerate() {
175 match backend_config {
176 BackendConfig::Local { tool_filter } => {
177 let backend = match tool_filter {
178 ToolFilter::All => {
179 LocalBackend::full(llm_config_provider.clone(), workspace.clone())
180 }
181 ToolFilter::Include(tools) => LocalBackend::with_tools(
182 tools.clone(),
183 llm_config_provider.clone(),
184 workspace.clone(),
185 ),
186 ToolFilter::Exclude(excluded) => LocalBackend::without_tools(
187 excluded.clone(),
188 llm_config_provider.clone(),
189 workspace.clone(),
190 ),
191 };
192 registry
193 .register(format!("user_local_{idx}"), Arc::new(backend))
194 .await;
195 }
196 BackendConfig::Mcp {
197 server_name,
198 transport,
199 tool_filter,
200 } => {
201 tracing::info!(
202 "Attempting to initialize MCP backend '{}' with transport: {:?}",
203 server_name,
204 transport
205 );
206
207 let mut server_info = McpServerInfo {
209 server_name: server_name.clone(),
210 transport: transport.clone(),
211 state: McpConnectionState::Connecting,
212 last_updated: Utc::now(),
213 };
214
215 match crate::tools::McpBackend::new(
216 server_name.clone(),
217 transport.clone(),
218 tool_filter.clone(),
219 )
220 .await
221 {
222 Ok(mcp_backend) => {
223 let tool_names = mcp_backend.supported_tools().await;
224 let tool_count = tool_names.len();
225 tracing::info!(
226 "Successfully initialized MCP backend '{}' with {} tools",
227 server_name,
228 tool_count
229 );
230 server_info.state = McpConnectionState::Connected { tool_names };
231 server_info.last_updated = Utc::now();
232 registry
233 .register(format!("mcp_{server_name}"), Arc::new(mcp_backend))
234 .await;
235 }
236 Err(e) => {
237 tracing::error!(
238 "Failed to initialize MCP backend '{}': {}",
239 server_name,
240 e
241 );
242 server_info.state = McpConnectionState::Failed {
243 error: e.to_string(),
244 };
245 server_info.last_updated = Utc::now();
246 }
247 }
248
249 mcp_servers.insert(server_name.clone(), server_info);
250 }
251 }
252 }
253
254 let server_backend = LocalBackend::server_only(llm_config_provider.clone(), workspace);
257 if !server_backend.supported_tools().await.is_empty() {
258 registry
259 .register("server".to_string(), Arc::new(server_backend))
260 .await;
261 }
262
263 Ok((registry, mcp_servers))
266 }
267
268 pub fn filter_tools_by_visibility(
270 &self,
271 tools: Vec<steer_tools::ToolSchema>,
272 ) -> Vec<steer_tools::ToolSchema> {
273 match &self.tool_config.visibility {
274 ToolVisibility::All => tools,
275 ToolVisibility::ReadOnly => {
276 let read_only_names: HashSet<String> = read_only_workspace_tools()
277 .iter()
278 .map(|t| t.name().to_string())
279 .collect();
280
281 tools
282 .into_iter()
283 .filter(|schema| read_only_names.contains(&schema.name))
284 .collect()
285 }
286 ToolVisibility::Whitelist(allowed) => tools
287 .into_iter()
288 .filter(|schema| allowed.contains(&schema.name))
289 .collect(),
290 ToolVisibility::Blacklist(blocked) => tools
291 .into_iter()
292 .filter(|schema| !blocked.contains(&schema.name))
293 .collect(),
294 }
295 }
296
297 pub fn read_only() -> Self {
299 Self {
300 workspace: WorkspaceConfig::Local {
301 path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
302 },
303 tool_config: SessionToolConfig::read_only(),
304 system_prompt: None,
305 metadata: HashMap::new(),
306 }
307 }
308}
309
310#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
312#[serde(tag = "type", rename_all = "snake_case")]
313pub enum ToolVisibility {
314 All,
316
317 ReadOnly,
319
320 Whitelist(HashSet<String>),
322
323 Blacklist(HashSet<String>),
325}
326
327impl Default for ToolVisibility {
328 fn default() -> Self {
329 Self::All
330 }
331}
332
333#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema, Default)]
335pub struct BashToolConfig {
336 #[serde(default)]
338 pub approved_patterns: Vec<String>,
339}
340
341#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
343#[serde(tag = "type", rename_all = "snake_case")]
344pub enum ToolApprovalPolicy {
345 AlwaysAsk,
347
348 PreApproved { tools: HashSet<String> },
350
351 Mixed {
353 pre_approved: HashSet<String>,
354 ask_for_others: bool,
355 },
356}
357
358impl ToolApprovalPolicy {
359 pub fn is_tool_approved(&self, tool_name: &str) -> bool {
360 match self {
361 ToolApprovalPolicy::AlwaysAsk => false,
362 ToolApprovalPolicy::PreApproved { tools } => tools.contains(tool_name),
363 ToolApprovalPolicy::Mixed {
364 pre_approved,
365 ask_for_others: _,
366 } => pre_approved.contains(tool_name),
367 }
368 }
369
370 pub fn should_ask_for_approval(&self, tool_name: &str) -> bool {
371 match self {
372 ToolApprovalPolicy::AlwaysAsk => true,
373 ToolApprovalPolicy::PreApproved { tools } => !tools.contains(tool_name),
374 ToolApprovalPolicy::Mixed {
375 pre_approved,
376 ask_for_others,
377 } => !pre_approved.contains(tool_name) && *ask_for_others,
378 }
379 }
380}
381
382#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
384pub enum RemoteAuth {
385 Bearer { token: String },
386 ApiKey { key: String },
387}
388
389impl RemoteAuth {
390 pub fn to_workspace_auth(&self) -> steer_workspace::RemoteAuth {
392 match self {
393 RemoteAuth::Bearer { token } => steer_workspace::RemoteAuth::BearerToken(token.clone()),
394 RemoteAuth::ApiKey { key } => steer_workspace::RemoteAuth::ApiKey(key.clone()),
395 }
396 }
397}
398
399#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema)]
401#[serde(rename_all = "snake_case")]
402pub enum ToolFilter {
403 All,
405 Include(Vec<String>),
407 Exclude(Vec<String>),
409}
410
411impl Default for ToolFilter {
412 fn default() -> Self {
413 Self::All
414 }
415}
416
417#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
419#[serde(tag = "type", rename_all = "snake_case")]
420pub enum BackendConfig {
421 Local {
422 tool_filter: ToolFilter,
424 },
425 Mcp {
426 server_name: String,
427 transport: McpTransport,
428 tool_filter: ToolFilter,
429 },
430}
431
432#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
434pub struct SessionToolConfig {
435 pub backends: Vec<BackendConfig>,
437 pub visibility: ToolVisibility,
439 pub approval_policy: ToolApprovalPolicy,
441 pub metadata: HashMap<String, String>,
443 #[serde(default)]
445 pub tools: HashMap<String, ToolSpecificConfig>,
446}
447
448#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
450#[serde(untagged)]
451pub enum ToolSpecificConfig {
452 Bash(BashToolConfig),
454}
455
456impl Default for SessionToolConfig {
457 fn default() -> Self {
458 Self {
459 backends: Vec::new(),
460 visibility: ToolVisibility::All,
461 approval_policy: ToolApprovalPolicy::AlwaysAsk,
462 metadata: HashMap::new(),
463 tools: HashMap::new(),
464 }
465 }
466}
467
468impl SessionToolConfig {
469 pub fn read_only() -> Self {
471 Self {
472 backends: Vec::new(), visibility: ToolVisibility::ReadOnly,
474 approval_policy: ToolApprovalPolicy::AlwaysAsk,
475 metadata: HashMap::new(),
476 tools: HashMap::new(),
477 }
478 }
479}
480
481#[derive(Debug, Clone, Serialize, Deserialize, Default)]
483pub struct SessionState {
484 pub messages: Vec<Message>,
486
487 pub tool_calls: HashMap<String, ToolCallState>,
489
490 pub approved_tools: HashSet<String>,
492
493 #[serde(default)]
495 pub approved_bash_patterns: HashSet<String>,
496
497 pub last_event_sequence: u64,
499
500 pub metadata: HashMap<String, String>,
502
503 #[serde(default, skip_serializing_if = "Option::is_none")]
506 pub active_message_id: Option<String>,
507
508 #[serde(default, skip_serializing, skip_deserializing)]
511 pub mcp_servers: HashMap<String, McpServerInfo>,
512}
513
514impl SessionState {
515 pub fn add_message(&mut self, message: Message) {
517 self.messages.push(message);
518 }
519
520 pub fn message_count(&self) -> usize {
522 self.messages.len()
523 }
524
525 pub fn last_message(&self) -> Option<&Message> {
527 self.messages.last()
528 }
529
530 pub fn add_tool_call(&mut self, tool_call: ToolCall) {
532 let state = ToolCallState {
533 tool_call: tool_call.clone(),
534 status: ToolCallStatus::PendingApproval,
535 started_at: None,
536 completed_at: None,
537 result: None,
538 };
539 self.tool_calls.insert(tool_call.id, state);
540 }
541
542 pub fn update_tool_call_status(
544 &mut self,
545 tool_call_id: &str,
546 status: ToolCallStatus,
547 ) -> std::result::Result<(), String> {
548 let tool_call = self
549 .tool_calls
550 .get_mut(tool_call_id)
551 .ok_or_else(|| format!("Tool call not found: {tool_call_id}"))?;
552
553 match (&tool_call.status, &status) {
555 (_, ToolCallStatus::Executing) => {
556 tool_call.started_at = Some(Utc::now());
557 }
558 (_, ToolCallStatus::Completed) | (_, ToolCallStatus::Failed { .. }) => {
559 tool_call.completed_at = Some(Utc::now());
560 }
561 _ => {}
562 }
563
564 tool_call.status = status;
565 Ok(())
566 }
567
568 pub fn approve_tool(&mut self, tool_name: String) {
570 self.approved_tools.insert(tool_name);
571 }
572
573 pub fn is_tool_approved(&self, tool_name: &str) -> bool {
575 self.approved_tools.contains(tool_name)
576 }
577
578 pub fn validate(&self) -> std::result::Result<(), String> {
580 for message in &self.messages {
582 let tool_calls = self.extract_tool_calls_from_message(message);
583 if !tool_calls.is_empty() {
584 for tool_call_id in tool_calls {
585 if !self.tool_calls.contains_key(&tool_call_id) {
586 return Err(format!(
587 "Message references unknown tool call: {tool_call_id}"
588 ));
589 }
590 }
591 }
592 }
593
594 Ok(())
595 }
596
597 fn extract_tool_calls_from_message(&self, message: &Message) -> Vec<String> {
599 let mut tool_call_ids = Vec::new();
600
601 match &message.data {
602 MessageData::Assistant { content, .. } => {
603 for c in content {
604 if let crate::app::conversation::AssistantContent::ToolCall { tool_call } = c {
605 tool_call_ids.push(tool_call.id.clone());
606 }
607 }
608 }
609 MessageData::Tool { tool_use_id, .. } => {
610 tool_call_ids.push(tool_use_id.clone());
611 }
612 _ => {}
613 }
614
615 tool_call_ids
616 }
617
618 pub fn apply_event(
620 &mut self,
621 event: &crate::events::StreamEvent,
622 ) -> std::result::Result<(), String> {
623 use crate::events::StreamEvent;
624
625 match event {
626 StreamEvent::MessageComplete { message, .. } => {
627 self.add_message(message.clone());
628 }
629 StreamEvent::ToolCallStarted { tool_call, .. } => {
630 self.add_tool_call(tool_call.clone());
631 }
632 StreamEvent::ToolCallCompleted {
633 tool_call_id,
634 result,
635 ..
636 } => {
637 self.update_tool_call_status(tool_call_id, ToolCallStatus::Completed)?;
638 if let Some(tool_call_state) = self.tool_calls.get_mut(tool_call_id) {
639 tool_call_state.result = Some(result.clone());
640 }
641 }
642 StreamEvent::ToolCallFailed {
643 tool_call_id,
644 error,
645 ..
646 } => {
647 self.update_tool_call_status(
648 tool_call_id,
649 ToolCallStatus::Failed {
650 error: error.clone(),
651 },
652 )?;
653 }
654 StreamEvent::ToolApprovalRequired { tool_call, .. } => {
655 if !self.tool_calls.contains_key(&tool_call.id) {
657 self.add_tool_call(tool_call.clone());
658 }
659 }
660 _ => {}
662 }
663
664 Ok(())
665 }
666}
667
668#[derive(Debug, Clone, Serialize, Deserialize)]
670pub struct ToolCallState {
671 pub tool_call: ToolCall,
672 pub status: ToolCallStatus,
673 pub started_at: Option<DateTime<Utc>>,
674 pub completed_at: Option<DateTime<Utc>>,
675 pub result: Option<ToolResult>,
676}
677
678impl ToolCallState {
679 pub fn is_pending(&self) -> bool {
680 matches!(self.status, ToolCallStatus::PendingApproval)
681 }
682
683 pub fn is_complete(&self) -> bool {
684 matches!(
685 self.status,
686 ToolCallStatus::Completed | ToolCallStatus::Failed { .. }
687 )
688 }
689
690 pub fn duration(&self) -> Option<chrono::Duration> {
691 match (self.started_at, self.completed_at) {
692 (Some(start), Some(end)) => Some(end - start),
693 _ => None,
694 }
695 }
696}
697
698#[derive(Debug, Clone, Serialize, Deserialize)]
700#[serde(tag = "status", rename_all = "snake_case")]
701pub enum ToolCallStatus {
702 PendingApproval,
703 Approved,
704 Denied,
705 Executing,
706 Completed,
707 Failed { error: String },
708}
709
710impl ToolCallStatus {
711 pub fn is_terminal(&self) -> bool {
712 matches!(
713 self,
714 ToolCallStatus::Completed | ToolCallStatus::Failed { .. } | ToolCallStatus::Denied
715 )
716 }
717}
718
719#[derive(Debug, Clone, Serialize, Deserialize)]
721pub struct ToolExecutionStats {
722 #[serde(skip_serializing_if = "Option::is_none")]
723 pub output: Option<String>, #[serde(skip_serializing_if = "Option::is_none")]
725 pub json_output: Option<serde_json::Value>, pub result_type: Option<String>, pub success: bool,
728 pub execution_time_ms: u64,
729 pub metadata: HashMap<String, String>,
730}
731
732impl ToolExecutionStats {
733 pub fn success(output: String, execution_time_ms: u64) -> Self {
734 Self {
735 output: Some(output),
736 json_output: None,
737 result_type: None,
738 success: true,
739 execution_time_ms,
740 metadata: HashMap::new(),
741 }
742 }
743
744 pub fn success_typed(
745 json_output: serde_json::Value,
746 result_type: String,
747 execution_time_ms: u64,
748 ) -> Self {
749 Self {
750 output: None,
751 json_output: Some(json_output),
752 result_type: Some(result_type),
753 success: true,
754 execution_time_ms,
755 metadata: HashMap::new(),
756 }
757 }
758
759 pub fn failure(error: String, execution_time_ms: u64) -> Self {
760 Self {
761 output: Some(error),
762 json_output: None,
763 result_type: None,
764 success: false,
765 execution_time_ms,
766 metadata: HashMap::new(),
767 }
768 }
769
770 pub fn with_metadata(mut self, key: String, value: String) -> Self {
771 self.metadata.insert(key, value);
772 self
773 }
774}
775
776#[derive(Debug, Clone, Serialize, Deserialize)]
778pub struct SessionInfo {
779 pub id: String,
780 pub created_at: DateTime<Utc>,
781 pub updated_at: DateTime<Utc>,
782 pub last_model: Option<Model>,
784 pub message_count: usize,
785 pub metadata: HashMap<String, String>,
786}
787
788impl From<&Session> for SessionInfo {
789 fn from(session: &Session) -> Self {
790 Self {
791 id: session.id.clone(),
792 created_at: session.created_at,
793 updated_at: session.updated_at,
794 last_model: None, message_count: session.state.message_count(),
796 metadata: session.config.metadata.clone(),
797 }
798 }
799}
800
// Unit tests for session configuration, state tracking and registry building.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::app::conversation::{Message, MessageData, UserContent};
    use steer_tools::tools::{BASH_TOOL_NAME, EDIT_TOOL_NAME, LS_TOOL_NAME, VIEW_TOOL_NAME};

    // A new session keeps its id and starts with no messages. Because the
    // default tool config uses `AlwaysAsk`, any tool requires approval.
    #[test]
    fn test_session_creation() {
        let config = SessionConfig {
            workspace: WorkspaceConfig::Local {
                path: PathBuf::from("/test/path"),
            },
            tool_config: SessionToolConfig::default(),
            system_prompt: None,
            metadata: HashMap::new(),
        };
        let session = Session::new("test-session".to_string(), config.clone());

        assert_eq!(session.id, "test-session");
        assert!(
            session
                .config
                .tool_config
                .approval_policy
                .should_ask_for_approval("any_tool")
        );
        assert_eq!(session.state.message_count(), 0);
    }

    // `PreApproved` approves exactly its listed tools and asks for all others.
    #[test]
    fn test_tool_approval_policy() {
        let policy = ToolApprovalPolicy::PreApproved {
            tools: ["read_file", "list_files"]
                .iter()
                .map(|s| s.to_string())
                .collect(),
        };

        assert!(policy.is_tool_approved("read_file"));
        assert!(!policy.is_tool_approved("write_file"));
        assert!(!policy.should_ask_for_approval("read_file"));
        assert!(policy.should_ask_for_approval("write_file"));
    }

    // An empty state validates, and so does a state holding only a user text
    // message (which references no tool calls).
    #[test]
    fn test_session_state_validation() {
        let mut state = SessionState::default();

        assert!(state.validate().is_ok());

        let message = Message {
            data: MessageData::User {
                content: vec![UserContent::Text {
                    text: "Hello".to_string(),
                }],
            },
            timestamp: 123456789,
            id: "msg1".to_string(),
            parent_message_id: None,
        };
        state.add_message(message);

        assert!(state.validate().is_ok());
        assert_eq!(state.message_count(), 1);
    }

    // Walks a tool call through PendingApproval -> Executing -> Completed and
    // checks the timestamps stamped by `update_tool_call_status`.
    #[test]
    fn test_tool_call_state_tracking() {
        let mut state = SessionState::default();

        let tool_call = ToolCall {
            id: "tool1".to_string(),
            name: "read_file".to_string(),
            parameters: serde_json::json!({"path": "/test.txt"}),
        };

        state.add_tool_call(tool_call.clone());
        assert!(state.tool_calls.get("tool1").unwrap().is_pending());

        state
            .update_tool_call_status("tool1", ToolCallStatus::Executing)
            .unwrap();
        let tool_state = state.tool_calls.get("tool1").unwrap();
        assert!(tool_state.started_at.is_some());
        assert!(!tool_state.is_complete());

        state
            .update_tool_call_status("tool1", ToolCallStatus::Completed)
            .unwrap();
        let tool_state = state.tool_calls.get("tool1").unwrap();
        assert!(tool_state.completed_at.is_some());
        assert!(tool_state.is_complete());
    }

    // The default tool config has no backends.
    #[test]
    fn test_session_tool_config_default() {
        let config = SessionToolConfig::default();
        assert!(config.backends.is_empty());
    }

    // A local backend configured with `ToolFilter::Exclude` keeps the excluded names.
    #[test]
    fn test_tool_filter_exclude() {
        let config = SessionToolConfig {
            backends: vec![BackendConfig::Local {
                tool_filter: ToolFilter::Exclude(vec![
                    BASH_TOOL_NAME.to_string(),
                    EDIT_TOOL_NAME.to_string(),
                ]),
            }],
            visibility: ToolVisibility::All,
            approval_policy: ToolApprovalPolicy::AlwaysAsk,
            metadata: HashMap::new(),
            tools: HashMap::new(),
        };

        assert!(matches!(config.backends[0], BackendConfig::Local { .. }));
        if let BackendConfig::Local { tool_filter } = &config.backends[0] {
            assert!(matches!(tool_filter, ToolFilter::Exclude(_)));
            if let ToolFilter::Exclude(excluded_tools) = tool_filter {
                assert_eq!(excluded_tools.len(), 2);
                assert!(excluded_tools.contains(&BASH_TOOL_NAME.to_string()));
                assert!(excluded_tools.contains(&EDIT_TOOL_NAME.to_string()));
            }
        }
    }

    // `SessionToolConfig::read_only` has no backends, ReadOnly visibility and AlwaysAsk.
    #[test]
    fn test_session_tool_config_read_only() {
        let config = SessionToolConfig::read_only();
        assert_eq!(config.backends.len(), 0);
        assert!(matches!(config.visibility, ToolVisibility::ReadOnly));
        assert!(matches!(
            config.approval_policy,
            ToolApprovalPolicy::AlwaysAsk
        ));
    }

    // With no configured backends, the registry still contains the server-only
    // tools (`dispatch_agent`, `web_fetch`) and none of the listed workspace tools.
    #[tokio::test]
    async fn test_session_config_build_registry_server_tools() {
        use crate::auth::DefaultAuthStorage;
        use crate::config::LlmConfigProvider;

        let config = SessionConfig {
            workspace: WorkspaceConfig::Local {
                path: PathBuf::from("/test/path"),
            },
            tool_config: SessionToolConfig::default(),
            system_prompt: None,
            metadata: HashMap::new(),
        };

        let auth_storage =
            DefaultAuthStorage::new().expect("Failed to create auth storage for test");
        let llm_config_provider = Arc::new(LlmConfigProvider::new(Arc::new(auth_storage)));

        let workspace = crate::workspace::create_workspace(&config.workspace.to_workspace_config())
            .await
            .unwrap();

        let (registry, _mcp_servers) = config
            .build_registry(llm_config_provider, workspace)
            .await
            .unwrap();
        let schemas = registry.get_tool_schemas().await;
        let tool_names: Vec<String> = schemas.iter().map(|s| s.name.clone()).collect();

        assert!(tool_names.contains(&"dispatch_agent".to_string()));
        assert!(tool_names.contains(&"web_fetch".to_string()));

        let workspace_tool_names = vec!["bash", "grep", "glob", "ls", "read", "write", "edit"];
        for tool_name in workspace_tool_names {
            assert!(
                !tool_names.contains(&tool_name.to_string()),
                "Workspace tool {tool_name} should not be in registry"
            );
        }
    }

    // `SessionState::mcp_servers` stores connected and failed server records keyed by name.
    #[test]
    fn test_mcp_status_tracking() {
        let mut session_state = SessionState::default();

        let mcp_info = McpServerInfo {
            server_name: "test-server".to_string(),
            transport: crate::tools::McpTransport::Stdio {
                command: "python".to_string(),
                args: vec!["-m".to_string(), "test_server".to_string()],
            },
            state: McpConnectionState::Connected {
                tool_names: vec![
                    "tool1".to_string(),
                    "tool2".to_string(),
                    "tool3".to_string(),
                    "tool4".to_string(),
                    "tool5".to_string(),
                ],
            },
            last_updated: Utc::now(),
        };

        session_state
            .mcp_servers
            .insert("test-server".to_string(), mcp_info.clone());

        assert_eq!(session_state.mcp_servers.len(), 1);
        let stored = session_state.mcp_servers.get("test-server").unwrap();
        assert_eq!(stored.server_name, "test-server");
        assert!(matches!(
            stored.state,
            McpConnectionState::Connected { ref tool_names } if tool_names.len() == 5
        ));

        let failed_info = McpServerInfo {
            server_name: "failed-server".to_string(),
            transport: crate::tools::McpTransport::Tcp {
                host: "localhost".to_string(),
                port: 9999,
            },
            state: McpConnectionState::Failed {
                error: "Connection refused".to_string(),
            },
            last_updated: Utc::now(),
        };

        session_state
            .mcp_servers
            .insert("failed-server".to_string(), failed_info);
        assert_eq!(session_state.mcp_servers.len(), 2);
    }

    // MCP connection failures do not abort `build_registry`; every attempt is
    // still recorded. The TCP server targets an `.invalid` host, and the stdio
    // server runs `echo`; the test expects both to end up `Failed`.
    #[tokio::test]
    async fn test_mcp_server_tracking_in_build_registry() {
        use crate::auth::DefaultAuthStorage;
        use crate::config::LlmConfigProvider;

        let mut config = SessionConfig::read_only();

        config.tool_config.backends.push(BackendConfig::Mcp {
            server_name: "bad-server".to_string(),
            transport: crate::tools::McpTransport::Tcp {
                host: "nonexistent.invalid".to_string(),
                port: 12345,
            },
            tool_filter: ToolFilter::All,
        });

        config.tool_config.backends.push(BackendConfig::Mcp {
            server_name: "good-server".to_string(),
            transport: crate::tools::McpTransport::Stdio {
                command: "echo".to_string(),
                args: vec!["test".to_string()],
            },
            tool_filter: ToolFilter::All,
        });

        let auth_storage =
            DefaultAuthStorage::new().expect("Failed to create auth storage for test");
        let llm_config_provider = Arc::new(LlmConfigProvider::new(Arc::new(auth_storage)));
        let workspace = crate::workspace::create_workspace(&config.workspace.to_workspace_config())
            .await
            .unwrap();

        let (_registry, mcp_servers) = config
            .build_registry(llm_config_provider, workspace)
            .await
            .unwrap();

        assert_eq!(mcp_servers.len(), 2);

        let bad_server = mcp_servers.get("bad-server").unwrap();
        assert_eq!(bad_server.server_name, "bad-server");
        assert!(matches!(
            bad_server.state,
            McpConnectionState::Failed { .. }
        ));

        let good_server = mcp_servers.get("good-server").unwrap();
        assert_eq!(good_server.server_name, "good-server");
        assert!(matches!(
            good_server.state,
            McpConnectionState::Failed { .. }
        ));
    }

    // Both `BackendConfig` variants can be constructed and destructured with their fields intact.
    #[test]
    fn test_backend_config_variants() {
        let local_config = BackendConfig::Local {
            tool_filter: ToolFilter::Include(vec![
                VIEW_TOOL_NAME.to_string(),
                LS_TOOL_NAME.to_string(),
            ]),
        };

        assert!(matches!(local_config, BackendConfig::Local { .. }));
        if let BackendConfig::Local { tool_filter } = local_config {
            assert!(matches!(tool_filter, ToolFilter::Include(_)));
            if let ToolFilter::Include(tools) = tool_filter {
                assert_eq!(tools.len(), 2);
            }
        }

        let mcp_config = BackendConfig::Mcp {
            server_name: "test-mcp".to_string(),
            transport: crate::tools::McpTransport::Stdio {
                command: "python".to_string(),
                args: vec!["-m".to_string(), "test_server".to_string()],
            },
            tool_filter: ToolFilter::All,
        };

        assert!(matches!(mcp_config, BackendConfig::Mcp { .. }));
        if let BackendConfig::Mcp {
            server_name,
            transport,
            ..
        } = mcp_config
        {
            assert_eq!(server_name, "test-mcp");
            assert!(matches!(
                transport,
                crate::tools::McpTransport::Stdio { .. }
            ));
            if let crate::tools::McpTransport::Stdio { command, args } = transport {
                assert_eq!(command, "python");
                assert_eq!(args.len(), 2);
            }
        }
    }
}