use std::collections::{HashMap, HashSet};
use std::fmt;
use std::path::PathBuf;
use std::sync::Arc;

use chrono::{DateTime, Utc};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use steer_tools::tools::read_only_workspace_tools;
use steer_tools::{ToolCall, result::ToolResult};

use crate::app::{Message, MessageData};
use crate::config::LlmConfigProvider;
use crate::config::model::ModelId;
use crate::error::Result;
use crate::tools::{BackendRegistry, LocalBackend, McpTransport, ToolBackend};

/// Connection status of one configured MCP (Model Context Protocol) server.
///
/// A server is either still connecting, connected (with the tools it exposes),
/// or failed (with the reason).
#[derive(Debug, Clone)]
pub enum McpConnectionState {
    /// Initialization has started but not finished.
    Connecting,
    /// The backend initialized successfully.
    Connected {
        /// Names of the tools the server reported.
        tool_names: Vec<String>,
    },
    /// Initialization failed.
    Failed {
        /// Human-readable description of the failure.
        error: String,
    },
}

33#[derive(Debug, Clone)]
35pub struct McpServerInfo {
36 pub server_name: String,
38 pub transport: McpTransport,
40 pub state: McpConnectionState,
42 pub last_updated: DateTime<Utc>,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
48#[serde(tag = "type", rename_all = "snake_case")]
49pub enum WorkspaceConfig {
50 Local {
51 path: PathBuf,
52 },
53 Remote {
54 agent_address: String,
55 auth: Option<RemoteAuth>,
56 },
57}
58
59impl WorkspaceConfig {
60 pub fn get_path(&self) -> Option<String> {
61 match self {
62 WorkspaceConfig::Local { path } => Some(path.to_string_lossy().to_string()),
63 WorkspaceConfig::Remote { agent_address, .. } => Some(agent_address.clone()),
64 }
65 }
66
67 pub fn to_workspace_config(&self) -> steer_workspace::WorkspaceConfig {
69 match self {
70 WorkspaceConfig::Local { path } => {
71 steer_workspace::WorkspaceConfig::Local { path: path.clone() }
72 }
73 WorkspaceConfig::Remote {
74 agent_address,
75 auth,
76 } => steer_workspace::WorkspaceConfig::Remote {
77 address: agent_address.clone(),
78 auth: auth.as_ref().map(|a| a.to_workspace_auth()),
79 },
80 }
81 }
82}
83
84impl Default for WorkspaceConfig {
85 fn default() -> Self {
86 Self::Local {
87 path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
88 }
89 }
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct Session {
95 pub id: String,
96 pub created_at: DateTime<Utc>,
97 pub updated_at: DateTime<Utc>,
98 pub config: SessionConfig,
99 pub state: SessionState,
100}
101
102impl Session {
103 pub fn new(id: String, config: SessionConfig) -> Self {
104 let now = Utc::now();
105 Self {
106 id,
107 created_at: now,
108 updated_at: now,
109 config,
110 state: SessionState::default(),
111 }
112 }
113
114 pub fn update_timestamp(&mut self) {
115 self.updated_at = Utc::now();
116 }
117
118 pub fn is_recently_active(&self, threshold: chrono::Duration) -> bool {
120 let cutoff = Utc::now() - threshold;
121 self.updated_at > cutoff
122 }
123
124 pub async fn build_workspace(&self) -> Result<Arc<dyn crate::workspace::Workspace>> {
126 crate::workspace::create_workspace(&self.config.workspace.to_workspace_config()).await
127 }
128}
129
130#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
132pub struct SessionConfig {
133 pub workspace: WorkspaceConfig,
134 pub tool_config: SessionToolConfig,
135 pub system_prompt: Option<String>,
138 pub metadata: HashMap<String, String>,
139}
140
141impl SessionConfig {
142 pub async fn build_registry(
146 &self,
147 llm_config_provider: Arc<LlmConfigProvider>,
148 workspace: Arc<dyn crate::workspace::Workspace>,
149 ) -> Result<(BackendRegistry, HashMap<String, McpServerInfo>)> {
150 let mut registry = BackendRegistry::new();
151 let mut mcp_servers = HashMap::new();
152
153 for (idx, backend_config) in self.tool_config.backends.iter().enumerate() {
156 match backend_config {
157 BackendConfig::Local { tool_filter } => {
158 let backend = match tool_filter {
159 ToolFilter::All => {
160 LocalBackend::full(llm_config_provider.clone(), workspace.clone())
161 }
162 ToolFilter::Include(tools) => LocalBackend::with_tools(
163 tools.clone(),
164 llm_config_provider.clone(),
165 workspace.clone(),
166 ),
167 ToolFilter::Exclude(excluded) => LocalBackend::without_tools(
168 excluded.clone(),
169 llm_config_provider.clone(),
170 workspace.clone(),
171 ),
172 };
173 registry
174 .register(format!("user_local_{idx}"), Arc::new(backend))
175 .await;
176 }
177 BackendConfig::Mcp {
178 server_name,
179 transport,
180 tool_filter,
181 } => {
182 tracing::info!(
183 "Attempting to initialize MCP backend '{}' with transport: {:?}",
184 server_name,
185 transport
186 );
187
188 let mut server_info = McpServerInfo {
190 server_name: server_name.clone(),
191 transport: transport.clone(),
192 state: McpConnectionState::Connecting,
193 last_updated: Utc::now(),
194 };
195
196 match crate::tools::McpBackend::new(
197 server_name.clone(),
198 transport.clone(),
199 tool_filter.clone(),
200 )
201 .await
202 {
203 Ok(mcp_backend) => {
204 let tool_names = mcp_backend.supported_tools().await;
205 let tool_count = tool_names.len();
206 tracing::info!(
207 "Successfully initialized MCP backend '{}' with {} tools",
208 server_name,
209 tool_count
210 );
211 server_info.state = McpConnectionState::Connected { tool_names };
212 server_info.last_updated = Utc::now();
213 registry
214 .register(format!("mcp_{server_name}"), Arc::new(mcp_backend))
215 .await;
216 }
217 Err(e) => {
218 tracing::error!(
219 "Failed to initialize MCP backend '{}': {}",
220 server_name,
221 e
222 );
223 server_info.state = McpConnectionState::Failed {
224 error: e.to_string(),
225 };
226 server_info.last_updated = Utc::now();
227 }
228 }
229
230 mcp_servers.insert(server_name.clone(), server_info);
231 }
232 }
233 }
234
235 let server_backend = LocalBackend::server_only(llm_config_provider.clone(), workspace);
238 if !server_backend.supported_tools().await.is_empty() {
239 registry
240 .register("server".to_string(), Arc::new(server_backend))
241 .await;
242 }
243
244 Ok((registry, mcp_servers))
247 }
248
249 pub fn filter_tools_by_visibility(
251 &self,
252 tools: Vec<steer_tools::ToolSchema>,
253 ) -> Vec<steer_tools::ToolSchema> {
254 match &self.tool_config.visibility {
255 ToolVisibility::All => tools,
256 ToolVisibility::ReadOnly => {
257 let read_only_names: HashSet<String> = read_only_workspace_tools()
258 .iter()
259 .map(|t| t.name().to_string())
260 .collect();
261
262 tools
263 .into_iter()
264 .filter(|schema| read_only_names.contains(&schema.name))
265 .collect()
266 }
267 ToolVisibility::Whitelist(allowed) => tools
268 .into_iter()
269 .filter(|schema| allowed.contains(&schema.name))
270 .collect(),
271 ToolVisibility::Blacklist(blocked) => tools
272 .into_iter()
273 .filter(|schema| !blocked.contains(&schema.name))
274 .collect(),
275 }
276 }
277
278 pub fn read_only() -> Self {
280 Self {
281 workspace: WorkspaceConfig::Local {
282 path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
283 },
284 tool_config: SessionToolConfig::read_only(),
285 system_prompt: None,
286 metadata: HashMap::new(),
287 }
288 }
289}
290
291#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
293#[serde(tag = "type", rename_all = "snake_case")]
294pub enum ToolVisibility {
295 All,
297
298 ReadOnly,
300
301 Whitelist(HashSet<String>),
303
304 Blacklist(HashSet<String>),
306}
307
308impl Default for ToolVisibility {
309 fn default() -> Self {
310 Self::All
311 }
312}
313
314#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema, Default)]
316pub struct BashToolConfig {
317 #[serde(default)]
319 pub approved_patterns: Vec<String>,
320}
321
322#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
324#[serde(tag = "type", rename_all = "snake_case")]
325pub enum ToolApprovalPolicy {
326 AlwaysAsk,
328
329 PreApproved { tools: HashSet<String> },
331
332 Mixed {
334 pre_approved: HashSet<String>,
335 ask_for_others: bool,
336 },
337}
338
339impl ToolApprovalPolicy {
340 pub fn is_tool_approved(&self, tool_name: &str) -> bool {
341 match self {
342 ToolApprovalPolicy::AlwaysAsk => false,
343 ToolApprovalPolicy::PreApproved { tools } => tools.contains(tool_name),
344 ToolApprovalPolicy::Mixed {
345 pre_approved,
346 ask_for_others: _,
347 } => pre_approved.contains(tool_name),
348 }
349 }
350
351 pub fn should_ask_for_approval(&self, tool_name: &str) -> bool {
352 match self {
353 ToolApprovalPolicy::AlwaysAsk => true,
354 ToolApprovalPolicy::PreApproved { tools } => !tools.contains(tool_name),
355 ToolApprovalPolicy::Mixed {
356 pre_approved,
357 ask_for_others,
358 } => !pre_approved.contains(tool_name) && *ask_for_others,
359 }
360 }
361}
362
363#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
365pub enum RemoteAuth {
366 Bearer { token: String },
367 ApiKey { key: String },
368}
369
370impl RemoteAuth {
371 pub fn to_workspace_auth(&self) -> steer_workspace::RemoteAuth {
373 match self {
374 RemoteAuth::Bearer { token } => steer_workspace::RemoteAuth::BearerToken(token.clone()),
375 RemoteAuth::ApiKey { key } => steer_workspace::RemoteAuth::ApiKey(key.clone()),
376 }
377 }
378}
379
380#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema)]
382#[serde(rename_all = "snake_case")]
383pub enum ToolFilter {
384 All,
386 Include(Vec<String>),
388 Exclude(Vec<String>),
390}
391
392impl Default for ToolFilter {
393 fn default() -> Self {
394 Self::All
395 }
396}
397
398#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
400#[serde(tag = "type", rename_all = "snake_case")]
401pub enum BackendConfig {
402 Local {
403 tool_filter: ToolFilter,
405 },
406 Mcp {
407 server_name: String,
408 transport: McpTransport,
409 tool_filter: ToolFilter,
410 },
411}
412
413#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
415pub struct SessionToolConfig {
416 pub backends: Vec<BackendConfig>,
418 pub visibility: ToolVisibility,
420 pub approval_policy: ToolApprovalPolicy,
422 pub metadata: HashMap<String, String>,
424 #[serde(default)]
426 pub tools: HashMap<String, ToolSpecificConfig>,
427}
428
429#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
431#[serde(untagged)]
432pub enum ToolSpecificConfig {
433 Bash(BashToolConfig),
435}
436
437impl Default for SessionToolConfig {
438 fn default() -> Self {
439 Self {
440 backends: Vec::new(),
441 visibility: ToolVisibility::All,
442 approval_policy: ToolApprovalPolicy::AlwaysAsk,
443 metadata: HashMap::new(),
444 tools: HashMap::new(),
445 }
446 }
447}
448
449impl SessionToolConfig {
450 pub fn read_only() -> Self {
452 Self {
453 backends: Vec::new(), visibility: ToolVisibility::ReadOnly,
455 approval_policy: ToolApprovalPolicy::AlwaysAsk,
456 metadata: HashMap::new(),
457 tools: HashMap::new(),
458 }
459 }
460}
461
462#[derive(Debug, Clone, Serialize, Deserialize, Default)]
464pub struct SessionState {
465 pub messages: Vec<Message>,
467
468 pub tool_calls: HashMap<String, ToolCallState>,
470
471 pub approved_tools: HashSet<String>,
473
474 #[serde(default)]
476 pub approved_bash_patterns: HashSet<String>,
477
478 pub last_event_sequence: u64,
480
481 pub metadata: HashMap<String, String>,
483
484 #[serde(default, skip_serializing_if = "Option::is_none")]
487 pub active_message_id: Option<String>,
488
489 #[serde(default, skip_serializing, skip_deserializing)]
492 pub mcp_servers: HashMap<String, McpServerInfo>,
493}
494
495impl SessionState {
496 pub fn add_message(&mut self, message: Message) {
498 self.messages.push(message);
499 }
500
501 pub fn message_count(&self) -> usize {
503 self.messages.len()
504 }
505
506 pub fn last_message(&self) -> Option<&Message> {
508 self.messages.last()
509 }
510
511 pub fn add_tool_call(&mut self, tool_call: ToolCall) {
513 let state = ToolCallState {
514 tool_call: tool_call.clone(),
515 status: ToolCallStatus::PendingApproval,
516 started_at: None,
517 completed_at: None,
518 result: None,
519 };
520 self.tool_calls.insert(tool_call.id, state);
521 }
522
523 pub fn update_tool_call_status(
525 &mut self,
526 tool_call_id: &str,
527 status: ToolCallStatus,
528 ) -> std::result::Result<(), String> {
529 let tool_call = self
530 .tool_calls
531 .get_mut(tool_call_id)
532 .ok_or_else(|| format!("Tool call not found: {tool_call_id}"))?;
533
534 match (&tool_call.status, &status) {
536 (_, ToolCallStatus::Executing) => {
537 tool_call.started_at = Some(Utc::now());
538 }
539 (_, ToolCallStatus::Completed) | (_, ToolCallStatus::Failed { .. }) => {
540 tool_call.completed_at = Some(Utc::now());
541 }
542 _ => {}
543 }
544
545 tool_call.status = status;
546 Ok(())
547 }
548
549 pub fn approve_tool(&mut self, tool_name: String) {
551 self.approved_tools.insert(tool_name);
552 }
553
554 pub fn is_tool_approved(&self, tool_name: &str) -> bool {
556 self.approved_tools.contains(tool_name)
557 }
558
559 pub fn validate(&self) -> std::result::Result<(), String> {
561 for message in &self.messages {
563 let tool_calls = self.extract_tool_calls_from_message(message);
564 if !tool_calls.is_empty() {
565 for tool_call_id in tool_calls {
566 if !self.tool_calls.contains_key(&tool_call_id) {
567 return Err(format!(
568 "Message references unknown tool call: {tool_call_id}"
569 ));
570 }
571 }
572 }
573 }
574
575 Ok(())
576 }
577
578 fn extract_tool_calls_from_message(&self, message: &Message) -> Vec<String> {
580 let mut tool_call_ids = Vec::new();
581
582 match &message.data {
583 MessageData::Assistant { content, .. } => {
584 for c in content {
585 if let crate::app::conversation::AssistantContent::ToolCall { tool_call } = c {
586 tool_call_ids.push(tool_call.id.clone());
587 }
588 }
589 }
590 MessageData::Tool { tool_use_id, .. } => {
591 tool_call_ids.push(tool_use_id.clone());
592 }
593 _ => {}
594 }
595
596 tool_call_ids
597 }
598
599 pub fn apply_event(
601 &mut self,
602 event: &crate::events::StreamEvent,
603 ) -> std::result::Result<(), String> {
604 use crate::events::StreamEvent;
605
606 match event {
607 StreamEvent::MessageComplete { message, .. } => {
608 self.add_message(message.clone());
609 }
610 StreamEvent::ToolCallStarted { tool_call, .. } => {
611 self.add_tool_call(tool_call.clone());
612 }
613 StreamEvent::ToolCallCompleted {
614 tool_call_id,
615 result,
616 ..
617 } => {
618 self.update_tool_call_status(tool_call_id, ToolCallStatus::Completed)?;
619 if let Some(tool_call_state) = self.tool_calls.get_mut(tool_call_id) {
620 tool_call_state.result = Some(result.clone());
621 }
622 }
623 StreamEvent::ToolCallFailed {
624 tool_call_id,
625 error,
626 ..
627 } => {
628 self.update_tool_call_status(
629 tool_call_id,
630 ToolCallStatus::Failed {
631 error: error.clone(),
632 },
633 )?;
634 }
635 StreamEvent::ToolApprovalRequired { tool_call, .. } => {
636 if !self.tool_calls.contains_key(&tool_call.id) {
638 self.add_tool_call(tool_call.clone());
639 }
640 }
641 _ => {}
643 }
644
645 Ok(())
646 }
647}
648
649#[derive(Debug, Clone, Serialize, Deserialize)]
651pub struct ToolCallState {
652 pub tool_call: ToolCall,
653 pub status: ToolCallStatus,
654 pub started_at: Option<DateTime<Utc>>,
655 pub completed_at: Option<DateTime<Utc>>,
656 pub result: Option<ToolResult>,
657}
658
659impl ToolCallState {
660 pub fn is_pending(&self) -> bool {
661 matches!(self.status, ToolCallStatus::PendingApproval)
662 }
663
664 pub fn is_complete(&self) -> bool {
665 matches!(
666 self.status,
667 ToolCallStatus::Completed | ToolCallStatus::Failed { .. }
668 )
669 }
670
671 pub fn duration(&self) -> Option<chrono::Duration> {
672 match (self.started_at, self.completed_at) {
673 (Some(start), Some(end)) => Some(end - start),
674 _ => None,
675 }
676 }
677}
678
679#[derive(Debug, Clone, Serialize, Deserialize)]
681#[serde(tag = "status", rename_all = "snake_case")]
682pub enum ToolCallStatus {
683 PendingApproval,
684 Approved,
685 Denied,
686 Executing,
687 Completed,
688 Failed { error: String },
689}
690
691impl ToolCallStatus {
692 pub fn is_terminal(&self) -> bool {
693 matches!(
694 self,
695 ToolCallStatus::Completed | ToolCallStatus::Failed { .. } | ToolCallStatus::Denied
696 )
697 }
698}
699
700#[derive(Debug, Clone, Serialize, Deserialize)]
702pub struct ToolExecutionStats {
703 #[serde(skip_serializing_if = "Option::is_none")]
704 pub output: Option<String>, #[serde(skip_serializing_if = "Option::is_none")]
706 pub json_output: Option<serde_json::Value>, pub result_type: Option<String>, pub success: bool,
709 pub execution_time_ms: u64,
710 pub metadata: HashMap<String, String>,
711}
712
713impl ToolExecutionStats {
714 pub fn success(output: String, execution_time_ms: u64) -> Self {
715 Self {
716 output: Some(output),
717 json_output: None,
718 result_type: None,
719 success: true,
720 execution_time_ms,
721 metadata: HashMap::new(),
722 }
723 }
724
725 pub fn success_typed(
726 json_output: serde_json::Value,
727 result_type: String,
728 execution_time_ms: u64,
729 ) -> Self {
730 Self {
731 output: None,
732 json_output: Some(json_output),
733 result_type: Some(result_type),
734 success: true,
735 execution_time_ms,
736 metadata: HashMap::new(),
737 }
738 }
739
740 pub fn failure(error: String, execution_time_ms: u64) -> Self {
741 Self {
742 output: Some(error),
743 json_output: None,
744 result_type: None,
745 success: false,
746 execution_time_ms,
747 metadata: HashMap::new(),
748 }
749 }
750
751 pub fn with_metadata(mut self, key: String, value: String) -> Self {
752 self.metadata.insert(key, value);
753 self
754 }
755}
756
757#[derive(Debug, Clone, Serialize, Deserialize)]
759pub struct SessionInfo {
760 pub id: String,
761 pub created_at: DateTime<Utc>,
762 pub updated_at: DateTime<Utc>,
763 pub last_model: Option<ModelId>,
765 pub message_count: usize,
766 pub metadata: HashMap<String, String>,
767}
768
769impl From<&Session> for SessionInfo {
770 fn from(session: &Session) -> Self {
771 Self {
772 id: session.id.clone(),
773 created_at: session.created_at,
774 updated_at: session.updated_at,
775 last_model: None, message_count: session.state.message_count(),
777 metadata: session.config.metadata.clone(),
778 }
779 }
780}
781
#[cfg(test)]
mod tests {
    use super::*;
    use crate::app::conversation::{Message, MessageData, UserContent};
    use crate::auth::DefaultAuthStorage;
    use crate::config::LlmConfigProvider;
    use steer_tools::tools::{BASH_TOOL_NAME, EDIT_TOOL_NAME, LS_TOOL_NAME, VIEW_TOOL_NAME};

    /// Session config for a local workspace at `/test/path` with default tool settings.
    fn local_test_config() -> SessionConfig {
        SessionConfig {
            workspace: WorkspaceConfig::Local {
                path: PathBuf::from("/test/path"),
            },
            tool_config: SessionToolConfig::default(),
            system_prompt: None,
            metadata: HashMap::new(),
        }
    }

    /// LLM config provider backed by the default auth storage.
    fn test_llm_config_provider() -> Arc<LlmConfigProvider> {
        let auth_storage =
            DefaultAuthStorage::new().expect("Failed to create auth storage for test");
        Arc::new(LlmConfigProvider::new(Arc::new(auth_storage)))
    }

    /// Converts string literals into owned strings.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| (*item).to_owned()).collect()
    }

    #[test]
    fn test_session_creation() {
        let session = Session::new("test-session".to_string(), local_test_config());

        assert_eq!(session.id, "test-session");
        let policy = &session.config.tool_config.approval_policy;
        assert!(policy.should_ask_for_approval("any_tool"));
        assert_eq!(session.state.message_count(), 0);
    }

    #[test]
    fn test_tool_approval_policy() {
        let policy = ToolApprovalPolicy::PreApproved {
            tools: owned(&["read_file", "list_files"]).into_iter().collect(),
        };

        assert!(policy.is_tool_approved("read_file"));
        assert!(!policy.is_tool_approved("write_file"));
        assert!(!policy.should_ask_for_approval("read_file"));
        assert!(policy.should_ask_for_approval("write_file"));
    }

    #[test]
    fn test_session_state_validation() {
        let mut state = SessionState::default();
        assert!(state.validate().is_ok());

        state.add_message(Message {
            data: MessageData::User {
                content: vec![UserContent::Text {
                    text: "Hello".to_string(),
                }],
            },
            timestamp: 123456789,
            id: "msg1".to_string(),
            parent_message_id: None,
        });

        assert!(state.validate().is_ok());
        assert_eq!(state.message_count(), 1);
    }

    #[test]
    fn test_tool_call_state_tracking() {
        let mut state = SessionState::default();
        state.add_tool_call(ToolCall {
            id: "tool1".to_string(),
            name: "read_file".to_string(),
            parameters: serde_json::json!({"path": "/test.txt"}),
        });
        assert!(state.tool_calls["tool1"].is_pending());

        state
            .update_tool_call_status("tool1", ToolCallStatus::Executing)
            .unwrap();
        {
            let executing = &state.tool_calls["tool1"];
            assert!(executing.started_at.is_some());
            assert!(!executing.is_complete());
        }

        state
            .update_tool_call_status("tool1", ToolCallStatus::Completed)
            .unwrap();
        let completed = &state.tool_calls["tool1"];
        assert!(completed.completed_at.is_some());
        assert!(completed.is_complete());
    }

    #[test]
    fn test_session_tool_config_default() {
        assert!(SessionToolConfig::default().backends.is_empty());
    }

    #[test]
    fn test_tool_filter_exclude() {
        let config = SessionToolConfig {
            backends: vec![BackendConfig::Local {
                tool_filter: ToolFilter::Exclude(vec![
                    BASH_TOOL_NAME.to_string(),
                    EDIT_TOOL_NAME.to_string(),
                ]),
            }],
            visibility: ToolVisibility::All,
            approval_policy: ToolApprovalPolicy::AlwaysAsk,
            metadata: HashMap::new(),
            tools: HashMap::new(),
        };

        match &config.backends[0] {
            BackendConfig::Local {
                tool_filter: ToolFilter::Exclude(excluded_tools),
            } => {
                assert_eq!(excluded_tools.len(), 2);
                assert!(excluded_tools.contains(&BASH_TOOL_NAME.to_string()));
                assert!(excluded_tools.contains(&EDIT_TOOL_NAME.to_string()));
            }
            other => panic!("expected a local backend with an exclude filter, got {other:?}"),
        }
    }

    #[test]
    fn test_session_tool_config_read_only() {
        let config = SessionToolConfig::read_only();
        assert!(config.backends.is_empty());
        assert!(matches!(config.visibility, ToolVisibility::ReadOnly));
        assert!(matches!(
            config.approval_policy,
            ToolApprovalPolicy::AlwaysAsk
        ));
    }

    #[tokio::test]
    async fn test_session_config_build_registry_server_tools() {
        let config = local_test_config();
        let llm_config_provider = test_llm_config_provider();
        let workspace = crate::workspace::create_workspace(&config.workspace.to_workspace_config())
            .await
            .unwrap();

        let (registry, _mcp_servers) = config
            .build_registry(llm_config_provider, workspace)
            .await
            .unwrap();
        let schemas = registry.get_tool_schemas().await;
        let tool_names: HashSet<String> = schemas.iter().map(|s| s.name.clone()).collect();

        // Server-side tools must be present...
        for server_tool in ["dispatch_agent", "web_fetch"] {
            assert!(tool_names.contains(server_tool));
        }

        // ...while workspace tools must not be, since no local backend was configured.
        for tool_name in ["bash", "grep", "glob", "ls", "read", "write", "edit"] {
            assert!(
                !tool_names.contains(tool_name),
                "Workspace tool {tool_name} should not be in registry"
            );
        }
    }

    #[test]
    fn test_mcp_status_tracking() {
        let mut session_state = SessionState::default();

        let connected = McpServerInfo {
            server_name: "test-server".to_string(),
            transport: crate::tools::McpTransport::Stdio {
                command: "python".to_string(),
                args: owned(&["-m", "test_server"]),
            },
            state: McpConnectionState::Connected {
                tool_names: (1..=5).map(|i| format!("tool{i}")).collect(),
            },
            last_updated: Utc::now(),
        };
        session_state
            .mcp_servers
            .insert("test-server".to_string(), connected);
        assert_eq!(session_state.mcp_servers.len(), 1);

        let stored = &session_state.mcp_servers["test-server"];
        assert_eq!(stored.server_name, "test-server");
        match &stored.state {
            McpConnectionState::Connected { tool_names } => assert_eq!(tool_names.len(), 5),
            other => panic!("expected a connected server, got {other:?}"),
        }

        let failed = McpServerInfo {
            server_name: "failed-server".to_string(),
            transport: crate::tools::McpTransport::Tcp {
                host: "localhost".to_string(),
                port: 9999,
            },
            state: McpConnectionState::Failed {
                error: "Connection refused".to_string(),
            },
            last_updated: Utc::now(),
        };
        session_state
            .mcp_servers
            .insert("failed-server".to_string(), failed);
        assert_eq!(session_state.mcp_servers.len(), 2);
    }

    #[tokio::test]
    async fn test_mcp_server_tracking_in_build_registry() {
        let mut config = SessionConfig::read_only();
        config.tool_config.backends.extend([
            // Unresolvable host: initialization is expected to fail.
            BackendConfig::Mcp {
                server_name: "bad-server".to_string(),
                transport: crate::tools::McpTransport::Tcp {
                    host: "nonexistent.invalid".to_string(),
                    port: 12345,
                },
                tool_filter: ToolFilter::All,
            },
            // Launches `echo`, which presumably does not speak MCP, so this one is
            // expected to fail as well despite its name.
            BackendConfig::Mcp {
                server_name: "good-server".to_string(),
                transport: crate::tools::McpTransport::Stdio {
                    command: "echo".to_string(),
                    args: owned(&["test"]),
                },
                tool_filter: ToolFilter::All,
            },
        ]);

        let llm_config_provider = test_llm_config_provider();
        let workspace = crate::workspace::create_workspace(&config.workspace.to_workspace_config())
            .await
            .unwrap();

        let (_registry, mcp_servers) = config
            .build_registry(llm_config_provider, workspace)
            .await
            .unwrap();

        // Both servers are tracked, whether or not they connected.
        assert_eq!(mcp_servers.len(), 2);
        for name in ["bad-server", "good-server"] {
            let info = mcp_servers.get(name).unwrap();
            assert_eq!(info.server_name, name);
            assert!(matches!(info.state, McpConnectionState::Failed { .. }));
        }
    }

    #[test]
    fn test_backend_config_variants() {
        let local_config = BackendConfig::Local {
            tool_filter: ToolFilter::Include(vec![
                VIEW_TOOL_NAME.to_string(),
                LS_TOOL_NAME.to_string(),
            ]),
        };
        match local_config {
            BackendConfig::Local {
                tool_filter: ToolFilter::Include(tools),
            } => assert_eq!(tools.len(), 2),
            other => panic!("expected a local backend with an include filter, got {other:?}"),
        }

        let mcp_config = BackendConfig::Mcp {
            server_name: "test-mcp".to_string(),
            transport: crate::tools::McpTransport::Stdio {
                command: "python".to_string(),
                args: owned(&["-m", "test_server"]),
            },
            tool_filter: ToolFilter::All,
        };
        match mcp_config {
            BackendConfig::Mcp {
                server_name,
                transport: crate::tools::McpTransport::Stdio { command, args },
                ..
            } => {
                assert_eq!(server_name, "test-mcp");
                assert_eq!(command, "python");
                assert_eq!(args.len(), 2);
            }
            other => panic!("expected an MCP backend over stdio, got {other:?}"),
        }
    }
}