1use crate::error::Result;
2use chrono::{DateTime, Utc};
3use schemars::JsonSchema;
4use serde::{Deserialize, Serialize};
5use std::collections::{HashMap, HashSet};
6use std::path::PathBuf;
7use std::sync::Arc;
8
9use crate::api::Model;
10use crate::app::{Message, MessageData};
11use crate::config::LlmConfigProvider;
12use crate::tools::{BackendRegistry, LocalBackend, McpTransport, ToolBackend};
13use steer_tools::tools::read_only_workspace_tools;
14use steer_tools::{ToolCall, result::ToolResult};
15
/// Container engine selected for a [`WorkspaceConfig::Container`] workspace.
///
/// Variant names are serialized in `snake_case` (`"docker"`, `"podman"`).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum ContainerRuntime {
    /// Use Docker.
    Docker,
    /// Use Podman.
    Podman,
}
23
/// Describes where a session's workspace lives.
///
/// Serialized as an internally tagged object: the `type` field carries the
/// `snake_case` variant name (`"local"`, `"remote"`, `"container"`) alongside
/// the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum WorkspaceConfig {
    /// A directory on the local filesystem.
    Local {
        /// Path of the workspace directory.
        path: PathBuf,
    },
    /// A workspace reached through a remote agent.
    Remote {
        /// Address of the remote agent.
        agent_address: String,
        /// Optional credentials for the remote agent.
        auth: Option<RemoteAuth>,
    },
    /// A container-based workspace.
    Container {
        /// Container image to use.
        image: String,
        /// Container engine to use.
        runtime: ContainerRuntime,
    },
}
40
41impl WorkspaceConfig {
42 pub fn get_path(&self) -> Option<String> {
43 match self {
44 WorkspaceConfig::Local { path } => Some(path.to_string_lossy().to_string()),
45 WorkspaceConfig::Remote { agent_address, .. } => Some(agent_address.clone()),
46 WorkspaceConfig::Container { .. } => None,
47 }
48 }
49
50 pub fn to_workspace_config(&self) -> steer_workspace::WorkspaceConfig {
52 match self {
53 WorkspaceConfig::Local { path } => {
54 steer_workspace::WorkspaceConfig::Local { path: path.clone() }
55 }
56 WorkspaceConfig::Remote {
57 agent_address,
58 auth,
59 } => steer_workspace::WorkspaceConfig::Remote {
60 address: agent_address.clone(),
61 auth: auth.as_ref().map(|a| a.to_workspace_auth()),
62 },
63 WorkspaceConfig::Container { .. } => {
64 steer_workspace::WorkspaceConfig::Local {
66 path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
67 }
68 }
69 }
70 }
71}
72
73impl Default for WorkspaceConfig {
74 fn default() -> Self {
75 Self::Local {
76 path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
77 }
78 }
79}
80
/// A session: its identity, timestamps, configuration, and accumulated state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Identifier supplied by the caller of [`Session::new`].
    pub id: String,
    /// When the session was created.
    pub created_at: DateTime<Utc>,
    /// When the session was last touched; refreshed by [`Session::update_timestamp`].
    pub updated_at: DateTime<Utc>,
    /// Workspace, tool, and prompt settings.
    pub config: SessionConfig,
    /// Messages, tool calls, and approvals recorded so far.
    pub state: SessionState,
}
90
91impl Session {
92 pub fn new(id: String, config: SessionConfig) -> Self {
93 let now = Utc::now();
94 Self {
95 id,
96 created_at: now,
97 updated_at: now,
98 config,
99 state: SessionState::default(),
100 }
101 }
102
103 pub fn update_timestamp(&mut self) {
104 self.updated_at = Utc::now();
105 }
106
107 pub fn is_recently_active(&self, threshold: chrono::Duration) -> bool {
109 let cutoff = Utc::now() - threshold;
110 self.updated_at > cutoff
111 }
112
113 pub async fn build_workspace(&self) -> Result<Arc<dyn crate::workspace::Workspace>> {
115 crate::workspace::create_workspace(&self.config.workspace.to_workspace_config()).await
116 }
117}
118
/// Settings a session is created with.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SessionConfig {
    /// Where the session's workspace lives.
    pub workspace: WorkspaceConfig,
    /// Tool backends, visibility, and approval settings.
    pub tool_config: SessionToolConfig,
    /// Optional system prompt for the session.
    pub system_prompt: Option<String>,
    /// Free-form string key/value metadata.
    pub metadata: HashMap<String, String>,
}
129
130impl SessionConfig {
131 pub async fn build_registry(
134 &self,
135 llm_config_provider: Arc<LlmConfigProvider>,
136 workspace: Arc<dyn crate::workspace::Workspace>,
137 ) -> Result<BackendRegistry> {
138 let mut registry = BackendRegistry::new();
139
140 for (idx, backend_config) in self.tool_config.backends.iter().enumerate() {
143 match backend_config {
144 BackendConfig::Local { tool_filter } => {
145 let backend = match tool_filter {
146 ToolFilter::All => {
147 LocalBackend::full(llm_config_provider.clone(), workspace.clone())
148 }
149 ToolFilter::Include(tools) => LocalBackend::with_tools(
150 tools.clone(),
151 llm_config_provider.clone(),
152 workspace.clone(),
153 ),
154 ToolFilter::Exclude(excluded) => LocalBackend::without_tools(
155 excluded.clone(),
156 llm_config_provider.clone(),
157 workspace.clone(),
158 ),
159 };
160 registry
161 .register(format!("user_local_{idx}"), Arc::new(backend))
162 .await;
163 }
164 BackendConfig::Mcp {
165 server_name,
166 transport,
167 tool_filter,
168 } => {
169 tracing::info!(
170 "Attempting to initialize MCP backend '{}' with transport: {:?}",
171 server_name,
172 transport
173 );
174 match crate::tools::McpBackend::new(
175 server_name.clone(),
176 transport.clone(),
177 tool_filter.clone(),
178 )
179 .await
180 {
181 Ok(mcp_backend) => {
182 tracing::info!(
183 "Successfully initialized MCP backend '{}'",
184 server_name
185 );
186 registry
187 .register(format!("mcp_{server_name}"), Arc::new(mcp_backend))
188 .await;
189 }
190 Err(e) => {
191 tracing::error!(
192 "Failed to initialize MCP backend '{}': {}",
193 server_name,
194 e
195 );
196 }
197 }
198 }
199 }
200 }
201
202 let server_backend = LocalBackend::server_only(llm_config_provider.clone(), workspace);
205 if !server_backend.supported_tools().await.is_empty() {
206 registry
207 .register("server".to_string(), Arc::new(server_backend))
208 .await;
209 }
210
211 Ok(registry)
214 }
215
216 pub fn filter_tools_by_visibility(
218 &self,
219 tools: Vec<steer_tools::ToolSchema>,
220 ) -> Vec<steer_tools::ToolSchema> {
221 match &self.tool_config.visibility {
222 ToolVisibility::All => tools,
223 ToolVisibility::ReadOnly => {
224 let read_only_names: HashSet<String> = read_only_workspace_tools()
225 .iter()
226 .map(|t| t.name().to_string())
227 .collect();
228
229 tools
230 .into_iter()
231 .filter(|schema| read_only_names.contains(&schema.name))
232 .collect()
233 }
234 ToolVisibility::Whitelist(allowed) => tools
235 .into_iter()
236 .filter(|schema| allowed.contains(&schema.name))
237 .collect(),
238 ToolVisibility::Blacklist(blocked) => tools
239 .into_iter()
240 .filter(|schema| !blocked.contains(&schema.name))
241 .collect(),
242 }
243 }
244
245 pub fn read_only() -> Self {
247 Self {
248 workspace: WorkspaceConfig::Local {
249 path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
250 },
251 tool_config: SessionToolConfig::read_only(),
252 system_prompt: None,
253 metadata: HashMap::new(),
254 }
255 }
256}
257
258#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
260#[serde(tag = "type", rename_all = "snake_case")]
261pub enum ToolVisibility {
262 All,
264
265 ReadOnly,
267
268 Whitelist(HashSet<String>),
270
271 Blacklist(HashSet<String>),
273}
274
275impl Default for ToolVisibility {
276 fn default() -> Self {
277 Self::All
278 }
279}
280
/// Configuration specific to the bash tool.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonSchema, Default)]
pub struct BashToolConfig {
    /// Command patterns approved in advance; an absent field deserializes to
    /// an empty list.
    /// NOTE(review): the pattern syntax and matching rules are not visible in
    /// this file — confirm against the consumer.
    #[serde(default)]
    pub approved_patterns: Vec<String>,
}
288
/// Policy deciding which tools are approved up front and which require asking.
///
/// Serialized as an internally tagged object: `type` holds the `snake_case`
/// variant name.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolApprovalPolicy {
    /// No tool is pre-approved; approval is always requested.
    AlwaysAsk,

    /// The listed tools are pre-approved; approval is requested for every
    /// other tool.
    PreApproved { tools: HashSet<String> },

    /// The tools in `pre_approved` are approved up front. For any other tool,
    /// approval is requested only when `ask_for_others` is `true`.
    Mixed {
        /// Tools approved without asking.
        pre_approved: HashSet<String>,
        /// Whether to request approval for tools outside `pre_approved`.
        ask_for_others: bool,
    },
}
305
306impl ToolApprovalPolicy {
307 pub fn is_tool_approved(&self, tool_name: &str) -> bool {
308 match self {
309 ToolApprovalPolicy::AlwaysAsk => false,
310 ToolApprovalPolicy::PreApproved { tools } => tools.contains(tool_name),
311 ToolApprovalPolicy::Mixed {
312 pre_approved,
313 ask_for_others: _,
314 } => pre_approved.contains(tool_name),
315 }
316 }
317
318 pub fn should_ask_for_approval(&self, tool_name: &str) -> bool {
319 match self {
320 ToolApprovalPolicy::AlwaysAsk => true,
321 ToolApprovalPolicy::PreApproved { tools } => !tools.contains(tool_name),
322 ToolApprovalPolicy::Mixed {
323 pre_approved,
324 ask_for_others,
325 } => !pre_approved.contains(tool_name) && *ask_for_others,
326 }
327 }
328}
329
/// Credentials for a [`WorkspaceConfig::Remote`] workspace.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub enum RemoteAuth {
    /// Authenticate with a bearer token.
    Bearer { token: String },
    /// Authenticate with an API key.
    ApiKey { key: String },
}
336
337impl RemoteAuth {
338 pub fn to_workspace_auth(&self) -> steer_workspace::RemoteAuth {
340 match self {
341 RemoteAuth::Bearer { token } => steer_workspace::RemoteAuth::BearerToken(token.clone()),
342 RemoteAuth::ApiKey { key } => steer_workspace::RemoteAuth::ApiKey(key.clone()),
343 }
344 }
345}
346
/// Selects which of a backend's tools are enabled.
///
/// Variant names are serialized in `snake_case`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum ToolFilter {
    /// Enable every tool the backend offers.
    All,
    /// Enable only the named tools.
    Include(Vec<String>),
    /// Enable every tool except the named ones.
    Exclude(Vec<String>),
}
358
359impl Default for ToolFilter {
360 fn default() -> Self {
361 Self::All
362 }
363}
364
/// Configuration of a single tool backend.
///
/// Serialized as an internally tagged object: `type` holds the `snake_case`
/// variant name (`"local"`, `"mcp"`).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum BackendConfig {
    /// A local backend.
    Local {
        /// Which of the backend's tools to enable.
        tool_filter: ToolFilter,
    },
    /// A backend provided by an MCP server.
    Mcp {
        /// Name of the server; used in log messages and in the backend's
        /// registry name (`mcp_{server_name}`).
        server_name: String,
        /// How to connect to the server.
        transport: McpTransport,
        /// Which of the server's tools to enable.
        tool_filter: ToolFilter,
    },
}
379
/// Tool-related settings of a session.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SessionToolConfig {
    /// Backends to register, in order.
    pub backends: Vec<BackendConfig>,
    /// Which tool schemas are exposed.
    pub visibility: ToolVisibility,
    /// Which tools are pre-approved and which require asking.
    pub approval_policy: ToolApprovalPolicy,
    /// Free-form string key/value metadata.
    pub metadata: HashMap<String, String>,
    /// Per-tool settings; an absent field deserializes to an empty map.
    /// NOTE(review): keys are presumably tool names — confirm against the
    /// consumer.
    #[serde(default)]
    pub tools: HashMap<String, ToolSpecificConfig>,
}
395
/// Settings for one specific tool.
///
/// Serialized untagged: the value is written as the inner configuration with
/// no variant marker.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
pub enum ToolSpecificConfig {
    /// Settings for the bash tool.
    Bash(BashToolConfig),
}
403
404impl Default for SessionToolConfig {
405 fn default() -> Self {
406 Self {
407 backends: Vec::new(),
408 visibility: ToolVisibility::All,
409 approval_policy: ToolApprovalPolicy::AlwaysAsk,
410 metadata: HashMap::new(),
411 tools: HashMap::new(),
412 }
413 }
414}
415
416impl SessionToolConfig {
417 pub fn read_only() -> Self {
419 Self {
420 backends: Vec::new(), visibility: ToolVisibility::ReadOnly,
422 approval_policy: ToolApprovalPolicy::AlwaysAsk,
423 metadata: HashMap::new(),
424 tools: HashMap::new(),
425 }
426 }
427}
428
/// Mutable state accumulated over the life of a session.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SessionState {
    /// Messages in the order they were added.
    pub messages: Vec<Message>,

    /// Tracked tool calls, keyed by tool call id.
    pub tool_calls: HashMap<String, ToolCallState>,

    /// Names of tools approved via [`SessionState::approve_tool`].
    pub approved_tools: HashSet<String>,

    /// Approved bash patterns; an absent field deserializes to an empty set.
    #[serde(default)]
    pub approved_bash_patterns: HashSet<String>,

    /// NOTE(review): presumably the sequence number of the last applied
    /// event; nothing in this file updates it — confirm against the caller.
    pub last_event_sequence: u64,

    /// Free-form string key/value metadata.
    pub metadata: HashMap<String, String>,

    /// Id of the active message, if any; omitted from serialized output when
    /// `None` and defaulted to `None` when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub active_message_id: Option<String>,
}
456
457impl SessionState {
458 pub fn add_message(&mut self, message: Message) {
460 self.messages.push(message);
461 }
462
463 pub fn message_count(&self) -> usize {
465 self.messages.len()
466 }
467
468 pub fn last_message(&self) -> Option<&Message> {
470 self.messages.last()
471 }
472
473 pub fn add_tool_call(&mut self, tool_call: ToolCall) {
475 let state = ToolCallState {
476 tool_call: tool_call.clone(),
477 status: ToolCallStatus::PendingApproval,
478 started_at: None,
479 completed_at: None,
480 result: None,
481 };
482 self.tool_calls.insert(tool_call.id, state);
483 }
484
485 pub fn update_tool_call_status(
487 &mut self,
488 tool_call_id: &str,
489 status: ToolCallStatus,
490 ) -> std::result::Result<(), String> {
491 let tool_call = self
492 .tool_calls
493 .get_mut(tool_call_id)
494 .ok_or_else(|| format!("Tool call not found: {tool_call_id}"))?;
495
496 match (&tool_call.status, &status) {
498 (_, ToolCallStatus::Executing) => {
499 tool_call.started_at = Some(Utc::now());
500 }
501 (_, ToolCallStatus::Completed) | (_, ToolCallStatus::Failed { .. }) => {
502 tool_call.completed_at = Some(Utc::now());
503 }
504 _ => {}
505 }
506
507 tool_call.status = status;
508 Ok(())
509 }
510
511 pub fn approve_tool(&mut self, tool_name: String) {
513 self.approved_tools.insert(tool_name);
514 }
515
516 pub fn is_tool_approved(&self, tool_name: &str) -> bool {
518 self.approved_tools.contains(tool_name)
519 }
520
521 pub fn validate(&self) -> std::result::Result<(), String> {
523 for message in &self.messages {
525 let tool_calls = self.extract_tool_calls_from_message(message);
526 if !tool_calls.is_empty() {
527 for tool_call_id in tool_calls {
528 if !self.tool_calls.contains_key(&tool_call_id) {
529 return Err(format!(
530 "Message references unknown tool call: {tool_call_id}"
531 ));
532 }
533 }
534 }
535 }
536
537 Ok(())
538 }
539
540 fn extract_tool_calls_from_message(&self, message: &Message) -> Vec<String> {
542 let mut tool_call_ids = Vec::new();
543
544 match &message.data {
545 MessageData::Assistant { content, .. } => {
546 for c in content {
547 if let crate::app::conversation::AssistantContent::ToolCall { tool_call } = c {
548 tool_call_ids.push(tool_call.id.clone());
549 }
550 }
551 }
552 MessageData::Tool { tool_use_id, .. } => {
553 tool_call_ids.push(tool_use_id.clone());
554 }
555 _ => {}
556 }
557
558 tool_call_ids
559 }
560
561 pub fn apply_event(
563 &mut self,
564 event: &crate::events::StreamEvent,
565 ) -> std::result::Result<(), String> {
566 use crate::events::StreamEvent;
567
568 match event {
569 StreamEvent::MessageComplete { message, .. } => {
570 self.add_message(message.clone());
571 }
572 StreamEvent::ToolCallStarted { tool_call, .. } => {
573 self.add_tool_call(tool_call.clone());
574 }
575 StreamEvent::ToolCallCompleted {
576 tool_call_id,
577 result,
578 ..
579 } => {
580 self.update_tool_call_status(tool_call_id, ToolCallStatus::Completed)?;
581 if let Some(tool_call_state) = self.tool_calls.get_mut(tool_call_id) {
582 tool_call_state.result = Some(result.clone());
583 }
584 }
585 StreamEvent::ToolCallFailed {
586 tool_call_id,
587 error,
588 ..
589 } => {
590 self.update_tool_call_status(
591 tool_call_id,
592 ToolCallStatus::Failed {
593 error: error.clone(),
594 },
595 )?;
596 }
597 StreamEvent::ToolApprovalRequired { tool_call, .. } => {
598 if !self.tool_calls.contains_key(&tool_call.id) {
600 self.add_tool_call(tool_call.clone());
601 }
602 }
603 _ => {}
605 }
606
607 Ok(())
608 }
609}
610
/// Tracking record for a single tool call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallState {
    /// The tool call being tracked.
    pub tool_call: ToolCall,
    /// Current lifecycle status.
    pub status: ToolCallStatus,
    /// When the call entered `Executing`, if it has.
    pub started_at: Option<DateTime<Utc>>,
    /// When the call entered `Completed` or `Failed`, if it has.
    pub completed_at: Option<DateTime<Utc>>,
    /// Result of the call, once one has been recorded.
    pub result: Option<ToolResult>,
}
620
621impl ToolCallState {
622 pub fn is_pending(&self) -> bool {
623 matches!(self.status, ToolCallStatus::PendingApproval)
624 }
625
626 pub fn is_complete(&self) -> bool {
627 matches!(
628 self.status,
629 ToolCallStatus::Completed | ToolCallStatus::Failed { .. }
630 )
631 }
632
633 pub fn duration(&self) -> Option<chrono::Duration> {
634 match (self.started_at, self.completed_at) {
635 (Some(start), Some(end)) => Some(end - start),
636 _ => None,
637 }
638 }
639}
640
/// Lifecycle status of a tool call.
///
/// Serialized as an internally tagged object: `status` holds the
/// `snake_case` variant name.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum ToolCallStatus {
    /// Waiting for approval.
    PendingApproval,
    /// Approved.
    Approved,
    /// Denied; treated as terminal.
    Denied,
    /// Currently executing.
    Executing,
    /// Finished successfully; terminal.
    Completed,
    /// Finished with an error; terminal.
    Failed { error: String },
}
652
653impl ToolCallStatus {
654 pub fn is_terminal(&self) -> bool {
655 matches!(
656 self,
657 ToolCallStatus::Completed | ToolCallStatus::Failed { .. } | ToolCallStatus::Denied
658 )
659 }
660}
661
/// Outcome and timing of a single tool execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolExecutionStats {
    /// Text output on success, or the error text on failure; omitted from
    /// serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output: Option<String>,
    /// Structured output; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json_output: Option<serde_json::Value>,
    /// Name of the result type accompanying `json_output`, if any.
    pub result_type: Option<String>,
    /// Whether the execution succeeded.
    pub success: bool,
    /// Execution time in milliseconds.
    pub execution_time_ms: u64,
    /// Free-form string key/value metadata.
    pub metadata: HashMap<String, String>,
}
674
675impl ToolExecutionStats {
676 pub fn success(output: String, execution_time_ms: u64) -> Self {
677 Self {
678 output: Some(output),
679 json_output: None,
680 result_type: None,
681 success: true,
682 execution_time_ms,
683 metadata: HashMap::new(),
684 }
685 }
686
687 pub fn success_typed(
688 json_output: serde_json::Value,
689 result_type: String,
690 execution_time_ms: u64,
691 ) -> Self {
692 Self {
693 output: None,
694 json_output: Some(json_output),
695 result_type: Some(result_type),
696 success: true,
697 execution_time_ms,
698 metadata: HashMap::new(),
699 }
700 }
701
702 pub fn failure(error: String, execution_time_ms: u64) -> Self {
703 Self {
704 output: Some(error),
705 json_output: None,
706 result_type: None,
707 success: false,
708 execution_time_ms,
709 metadata: HashMap::new(),
710 }
711 }
712
713 pub fn with_metadata(mut self, key: String, value: String) -> Self {
714 self.metadata.insert(key, value);
715 self
716 }
717}
718
/// Lightweight summary of a [`Session`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionInfo {
    /// Session identifier.
    pub id: String,
    /// When the session was created.
    pub created_at: DateTime<Utc>,
    /// When the session was last updated.
    pub updated_at: DateTime<Utc>,
    /// Last model used, if known. The `From<&Session>` conversion in this
    /// file always leaves it `None`.
    pub last_model: Option<Model>,
    /// Number of messages in the session.
    pub message_count: usize,
    /// Copy of the session configuration's metadata.
    pub metadata: HashMap<String, String>,
}
730
731impl From<&Session> for SessionInfo {
732 fn from(session: &Session) -> Self {
733 Self {
734 id: session.id.clone(),
735 created_at: session.created_at,
736 updated_at: session.updated_at,
737 last_model: None, message_count: session.state.message_count(),
739 metadata: session.config.metadata.clone(),
740 }
741 }
742}
743
#[cfg(test)]
mod tests {
    use super::*;
    use crate::app::conversation::{Message, MessageData, UserContent};
    use steer_tools::tools::{BASH_TOOL_NAME, EDIT_TOOL_NAME, LS_TOOL_NAME, VIEW_TOOL_NAME};

    /// Session configuration shared by the tests: a local workspace at
    /// `/test/path` with default tool settings.
    fn local_test_config() -> SessionConfig {
        SessionConfig {
            workspace: WorkspaceConfig::Local {
                path: PathBuf::from("/test/path"),
            },
            tool_config: SessionToolConfig::default(),
            system_prompt: None,
            metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_session_creation() {
        let session = Session::new("test-session".to_string(), local_test_config());
        let policy = &session.config.tool_config.approval_policy;

        assert_eq!(session.id, "test-session");
        assert!(policy.should_ask_for_approval("any_tool"));
        assert_eq!(session.state.message_count(), 0);
    }

    #[test]
    fn test_tool_approval_policy() {
        let tools = HashSet::from(["read_file".to_string(), "list_files".to_string()]);
        let policy = ToolApprovalPolicy::PreApproved { tools };

        // A pre-approved tool is approved and never asked about.
        assert!(policy.is_tool_approved("read_file"));
        assert!(!policy.should_ask_for_approval("read_file"));

        // Any other tool is not approved and must be asked about.
        assert!(!policy.is_tool_approved("write_file"));
        assert!(policy.should_ask_for_approval("write_file"));
    }

    #[test]
    fn test_session_state_validation() {
        let mut state = SessionState::default();

        // An empty state is trivially valid.
        assert!(state.validate().is_ok());

        // A plain user message references no tool calls, so it stays valid.
        let text = UserContent::Text {
            text: "Hello".to_string(),
        };
        state.add_message(Message {
            data: MessageData::User {
                content: vec![text],
            },
            timestamp: 123456789,
            id: "msg1".to_string(),
            parent_message_id: None,
        });

        assert!(state.validate().is_ok());
        assert_eq!(state.message_count(), 1);
    }

    #[test]
    fn test_tool_call_state_tracking() {
        let mut state = SessionState::default();

        state.add_tool_call(ToolCall {
            id: "tool1".to_string(),
            name: "read_file".to_string(),
            parameters: serde_json::json!({"path": "/test.txt"}),
        });
        assert!(state.tool_calls["tool1"].is_pending());

        state
            .update_tool_call_status("tool1", ToolCallStatus::Executing)
            .unwrap();
        let executing = &state.tool_calls["tool1"];
        assert!(executing.started_at.is_some());
        assert!(!executing.is_complete());

        state
            .update_tool_call_status("tool1", ToolCallStatus::Completed)
            .unwrap();
        let completed = &state.tool_calls["tool1"];
        assert!(completed.completed_at.is_some());
        assert!(completed.is_complete());
    }

    #[test]
    fn test_session_tool_config_default() {
        let config = SessionToolConfig::default();
        assert!(config.backends.is_empty());
    }

    #[test]
    fn test_tool_filter_exclude() {
        let excluded = vec![BASH_TOOL_NAME.to_string(), EDIT_TOOL_NAME.to_string()];
        let config = SessionToolConfig {
            backends: vec![BackendConfig::Local {
                tool_filter: ToolFilter::Exclude(excluded),
            }],
            visibility: ToolVisibility::All,
            approval_policy: ToolApprovalPolicy::AlwaysAsk,
            metadata: HashMap::new(),
            tools: HashMap::new(),
        };

        match &config.backends[0] {
            BackendConfig::Local {
                tool_filter: ToolFilter::Exclude(excluded_tools),
            } => {
                assert_eq!(excluded_tools.len(), 2);
                assert!(excluded_tools.contains(&BASH_TOOL_NAME.to_string()));
                assert!(excluded_tools.contains(&EDIT_TOOL_NAME.to_string()));
            }
            other => panic!("expected a local backend with an exclude filter, got {other:?}"),
        }
    }

    #[test]
    fn test_session_tool_config_read_only() {
        let config = SessionToolConfig::read_only();

        assert!(config.backends.is_empty());
        assert!(matches!(config.visibility, ToolVisibility::ReadOnly));
        assert!(matches!(
            config.approval_policy,
            ToolApprovalPolicy::AlwaysAsk
        ));
    }

    #[tokio::test]
    async fn test_session_config_build_registry_server_tools() {
        use crate::auth::DefaultAuthStorage;
        use crate::config::LlmConfigProvider;

        let config = local_test_config();

        let auth_storage =
            DefaultAuthStorage::new().expect("Failed to create auth storage for test");
        let llm_config_provider = Arc::new(LlmConfigProvider::new(Arc::new(auth_storage)));

        let workspace_config = config.workspace.to_workspace_config();
        let workspace = crate::workspace::create_workspace(&workspace_config)
            .await
            .unwrap();

        let registry = config
            .build_registry(llm_config_provider, workspace)
            .await
            .unwrap();
        let schemas = registry.get_tool_schemas().await;
        let tool_names: HashSet<&str> = schemas.iter().map(|s| s.name.as_str()).collect();

        // With no backends configured, only the server-side tools are present.
        for server_tool in ["dispatch_agent", "web_fetch"] {
            assert!(tool_names.contains(server_tool));
        }

        // Workspace tools require an explicitly configured backend.
        for tool_name in ["bash", "grep", "glob", "ls", "read", "write", "edit"] {
            assert!(
                !tool_names.contains(tool_name),
                "Workspace tool {tool_name} should not be in registry"
            );
        }
    }

    #[test]
    fn test_backend_config_variants() {
        let local_config = BackendConfig::Local {
            tool_filter: ToolFilter::Include(vec![
                VIEW_TOOL_NAME.to_string(),
                LS_TOOL_NAME.to_string(),
            ]),
        };

        match local_config {
            BackendConfig::Local {
                tool_filter: ToolFilter::Include(tools),
            } => assert_eq!(tools.len(), 2),
            other => panic!("expected a local backend with an include filter, got {other:?}"),
        }

        let mcp_config = BackendConfig::Mcp {
            server_name: "test-mcp".to_string(),
            transport: crate::tools::McpTransport::Stdio {
                command: "python".to_string(),
                args: vec!["-m".to_string(), "test_server".to_string()],
            },
            tool_filter: ToolFilter::All,
        };

        match mcp_config {
            BackendConfig::Mcp {
                server_name,
                transport: crate::tools::McpTransport::Stdio { command, args },
                ..
            } => {
                assert_eq!(server_name, "test-mcp");
                assert_eq!(command, "python");
                assert_eq!(args.len(), 2);
            }
            other => panic!("expected an MCP backend with a stdio transport, got {other:?}"),
        }
    }
}