Skip to main content

steer_cli/
session_config.rs

1use eyre::{Context, Result};
2use schemars::JsonSchema;
3use serde::{Deserialize, Serialize};
4use std::collections::{HashMap, HashSet};
5use std::path::PathBuf;
6use steer_core::session::{
7    BackendConfig, BashToolConfig, ContainerRuntime, RemoteAuth, SessionConfig, SessionToolConfig,
8    ToolApprovalPolicy, ToolSpecificConfig, ToolVisibility, WorkspaceConfig,
9};
10use thiserror::Error;
11use tokio::fs;
12use tracing::debug;
13
14/// Session configuration validation errors
15#[derive(Debug, Error)]
16pub enum SessionConfigError {
17    #[error("MCP backend server_name cannot be empty")]
18    EmptyServerName,
19
20    #[error("MCP stdio transport command cannot be empty")]
21    EmptyStdioCommand,
22
23    #[error("MCP TCP transport host cannot be empty")]
24    EmptyTcpHost,
25
26    #[error("MCP TCP transport port cannot be 0")]
27    InvalidTcpPort,
28
29    #[error("MCP Unix transport path cannot be empty")]
30    EmptyUnixPath,
31
32    #[error("MCP SSE transport url cannot be empty")]
33    EmptySseUrl,
34
35    #[error("MCP HTTP transport url cannot be empty")]
36    EmptyHttpUrl,
37
38    #[error("IO error: {0}")]
39    Io(#[from] std::io::Error),
40
41    #[error("TOML parse error: {0}")]
42    TomlParse(#[from] toml::de::Error),
43}
44
45/// Partial session configuration that can be loaded from a TOML file.
46/// All fields are optional so users can specify only what they want to override.
47#[derive(Debug, Deserialize, Serialize, Default, JsonSchema)]
48pub struct PartialSessionConfig {
49    #[schemars(description = "URL to the JSON schema file")]
50    #[serde(rename = "$schema")]
51    pub schema: Option<String>,
52    pub workspace: Option<PartialWorkspaceConfig>,
53    pub tool_config: Option<PartialToolConfig>,
54    pub system_prompt: Option<String>,
55    pub metadata: Option<HashMap<String, String>>,
56}
57
58#[derive(Debug, Deserialize, Serialize, JsonSchema)]
59#[serde(tag = "type", rename_all = "snake_case")]
60pub enum PartialWorkspaceConfig {
61    Local {
62        #[serde(default)]
63        path: Option<PathBuf>,
64    },
65    Remote {
66        agent_address: String,
67        auth: Option<RemoteAuth>,
68    },
69    Container {
70        image: String,
71        runtime: ContainerRuntime,
72    },
73}
74
75#[derive(Debug, Deserialize, Serialize, Default, JsonSchema)]
76#[schemars(deny_unknown_fields)]
77pub struct PartialToolConfig {
78    pub backends: Option<Vec<BackendConfig>>,
79    pub visibility: Option<ToolVisibilityConfig>,
80    pub approval_policy: Option<ToolApprovalPolicyConfig>,
81    #[serde(default)]
82    pub tools: Option<HashMap<String, PartialToolSpecificConfig>>,
83}
84
85#[derive(Debug, Deserialize, Serialize, JsonSchema)]
86#[serde(untagged)]
87pub enum PartialToolSpecificConfig {
88    Bash(PartialBashConfig),
89}
90
91#[derive(Debug, Deserialize, Serialize, JsonSchema)]
92pub struct PartialBashConfig {
93    pub approved_patterns: Option<Vec<String>>,
94}
95
96#[derive(Debug, Deserialize, Serialize, JsonSchema)]
97#[serde(untagged)]
98pub enum ToolVisibilityConfig {
99    String(String), // "all" or "read_only"
100    Object(ToolVisibilityObject),
101}
102
103#[derive(Debug, Deserialize, Serialize, JsonSchema)]
104#[serde(rename_all = "snake_case")]
105pub enum ToolVisibilityObject {
106    Whitelist(HashSet<String>),
107    Blacklist(HashSet<String>),
108}
109
110#[derive(Debug, Deserialize, Serialize, JsonSchema)]
111#[serde(untagged)]
112pub enum ToolApprovalPolicyConfig {
113    String(String),             // "always_ask"
114    Tagged(ToolApprovalPolicy), // Direct deserialization for tagged enum format
115}
116
117/// Overrides that can be applied from CLI arguments
118#[derive(Debug, Default)]
119pub struct SessionConfigOverrides {
120    pub system_prompt: Option<String>,
121    pub metadata: Option<String>,
122}
123
124/// Loads session configuration from files and applies overrides
125pub struct SessionConfigLoader {
126    config_path: Option<PathBuf>,
127    overrides: SessionConfigOverrides,
128}
129
130impl SessionConfigLoader {
131    pub fn new(config_path: Option<PathBuf>) -> Self {
132        debug!("Loading session config from: {:?}", config_path);
133        Self {
134            config_path,
135            overrides: SessionConfigOverrides::default(),
136        }
137    }
138
139    pub fn with_overrides(mut self, overrides: SessionConfigOverrides) -> Self {
140        self.overrides = overrides;
141        self
142    }
143
144    pub async fn load(&self) -> Result<SessionConfig> {
145        let mut config = if let Some(path) = &self.config_path {
146            // Load from TOML file
147            let content = fs::read_to_string(path)
148                .await
149                .with_context(|| format!("Failed to read config file: {}", path.display()))?;
150
151            let partial: PartialSessionConfig = toml::from_str(&content)
152                .with_context(|| format!("Failed to parse TOML config from: {}", path.display()))?;
153
154            self.partial_to_full(partial)?
155        } else {
156            // Use defaults
157            SessionConfig {
158                workspace: WorkspaceConfig::default(),
159                tool_config: SessionToolConfig::default(),
160                system_prompt: None,
161                metadata: HashMap::new(),
162            }
163        };
164
165        self.apply_overrides(&mut config)?;
166        self.validate_config(&config)?;
167
168        Ok(config)
169    }
170
171    fn partial_to_full(&self, partial: PartialSessionConfig) -> Result<SessionConfig> {
172        let workspace = match partial.workspace {
173            Some(PartialWorkspaceConfig::Local { path }) => WorkspaceConfig::Local {
174                path: path.unwrap_or_else(|| {
175                    std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."))
176                }),
177            },
178            Some(PartialWorkspaceConfig::Remote {
179                agent_address,
180                auth,
181            }) => WorkspaceConfig::Remote {
182                agent_address,
183                auth,
184            },
185            Some(PartialWorkspaceConfig::Container { image, runtime }) => {
186                WorkspaceConfig::Container { image, runtime }
187            }
188            None => WorkspaceConfig::default(),
189        };
190
191        let tool_config = if let Some(partial_tool_config) = partial.tool_config {
192            let backends = partial_tool_config.backends.unwrap_or_default();
193
194            let visibility = match partial_tool_config.visibility {
195                Some(ToolVisibilityConfig::String(s)) => match s.as_str() {
196                    "all" => ToolVisibility::All,
197                    "read_only" => ToolVisibility::ReadOnly,
198                    _ => {
199                        return Err(eyre::eyre!(
200                            "Invalid visibility string: {}. Expected 'all' or 'read_only'",
201                            s
202                        ));
203                    }
204                },
205                Some(ToolVisibilityConfig::Object(obj)) => match obj {
206                    ToolVisibilityObject::Whitelist(tools) => ToolVisibility::Whitelist(tools),
207                    ToolVisibilityObject::Blacklist(tools) => ToolVisibility::Blacklist(tools),
208                },
209                None => ToolVisibility::default(),
210            };
211
212            let approval_policy = match partial_tool_config.approval_policy {
213                Some(ToolApprovalPolicyConfig::String(s)) => match s.as_str() {
214                    "always_ask" => ToolApprovalPolicy::AlwaysAsk,
215                    _ => {
216                        return Err(eyre::eyre!(
217                            "Invalid approval policy string: {}. Expected 'always_ask'",
218                            s
219                        ));
220                    }
221                },
222                Some(ToolApprovalPolicyConfig::Tagged(policy)) => policy,
223                None => ToolApprovalPolicy::AlwaysAsk,
224            };
225
226            // Process tool-specific configs
227            let mut tools = HashMap::new();
228            if let Some(partial_tools) = partial_tool_config.tools {
229                for (tool_name, tool_config) in partial_tools {
230                    match tool_config {
231                        PartialToolSpecificConfig::Bash(bash_config) => {
232                            if tool_name == "bash" {
233                                tools.insert(
234                                    "bash".to_string(),
235                                    ToolSpecificConfig::Bash(BashToolConfig {
236                                        approved_patterns: bash_config
237                                            .approved_patterns
238                                            .unwrap_or_default(),
239                                    }),
240                                );
241                            }
242                        }
243                    }
244                }
245            }
246
247            SessionToolConfig {
248                backends,
249                visibility,
250                approval_policy,
251                metadata: HashMap::new(),
252                tools,
253            }
254        } else {
255            SessionToolConfig::default()
256        };
257
258        debug!("Loaded tool config: {:?}", tool_config);
259
260        Ok(SessionConfig {
261            workspace,
262            tool_config,
263            system_prompt: partial.system_prompt,
264            metadata: partial.metadata.unwrap_or_default(),
265        })
266    }
267
268    fn apply_overrides(&self, config: &mut SessionConfig) -> Result<()> {
269        // Apply system prompt override
270        if let Some(system_prompt) = &self.overrides.system_prompt {
271            config.system_prompt = Some(system_prompt.clone());
272        }
273
274        // Apply metadata overrides
275        if let Some(metadata_str) = &self.overrides.metadata {
276            let metadata = steer_core::utils::session::parse_metadata(Some(metadata_str))?;
277            config.metadata.extend(metadata);
278        }
279
280        Ok(())
281    }
282
283    fn validate_config(&self, config: &SessionConfig) -> Result<(), SessionConfigError> {
284        // Validate MCP backends have required fields
285        for backend in &config.tool_config.backends {
286            if let BackendConfig::Mcp {
287                server_name,
288                transport,
289                ..
290            } = backend
291            {
292                if server_name.is_empty() {
293                    return Err(SessionConfigError::EmptyServerName);
294                }
295
296                // Validate transport-specific requirements
297                match transport {
298                    steer_core::tools::McpTransport::Stdio { command, .. } => {
299                        if command.is_empty() {
300                            return Err(SessionConfigError::EmptyStdioCommand);
301                        }
302                        // Check if command exists in PATH
303                        if which::which(command).is_err() {
304                            // Log warning but don't fail - the command might be a full path or available later
305                            tracing::warn!(
306                                "MCP command '{}' for server '{}' not found in PATH",
307                                command,
308                                server_name
309                            );
310                        }
311                    }
312                    steer_core::tools::McpTransport::Tcp { host, port } => {
313                        if host.is_empty() {
314                            return Err(SessionConfigError::EmptyTcpHost);
315                        }
316                        if *port == 0 {
317                            return Err(SessionConfigError::InvalidTcpPort);
318                        }
319                    }
320                    #[cfg(unix)]
321                    steer_core::tools::McpTransport::Unix { path } => {
322                        if path.is_empty() {
323                            return Err(SessionConfigError::EmptyUnixPath);
324                        }
325                    }
326                    steer_core::tools::McpTransport::Sse { url, .. } => {
327                        if url.is_empty() {
328                            return Err(SessionConfigError::EmptySseUrl);
329                        }
330                    }
331                    steer_core::tools::McpTransport::Http { url, .. } => {
332                        if url.is_empty() {
333                            return Err(SessionConfigError::EmptyHttpUrl);
334                        }
335                    }
336                }
337            }
338        }
339
340        Ok(())
341    }
342}
343
344#[cfg(test)]
345mod tests {
346    use super::*;
347    use steer_core::session::ToolFilter;
348
349    #[tokio::test]
350    async fn test_backend_serialization() {
351        // Test that we can serialize and deserialize BackendConfig
352        let backend = BackendConfig::Mcp {
353            server_name: "test".to_string(),
354            transport: steer_core::tools::McpTransport::Stdio {
355                command: "python".to_string(),
356                args: vec!["-m".to_string(), "test".to_string()],
357            },
358            tool_filter: ToolFilter::All,
359        };
360
361        let json = serde_json::to_string(&backend).unwrap();
362        println!("Backend JSON: {json}");
363
364        let backend2: BackendConfig = serde_json::from_str(&json).unwrap();
365        match backend2 {
366            BackendConfig::Mcp {
367                server_name,
368                transport,
369                ..
370            } => {
371                assert_eq!(server_name, "test");
372                match transport {
373                    steer_core::tools::McpTransport::Stdio { command, args } => {
374                        assert_eq!(command, "python");
375                        assert_eq!(args, vec!["-m", "test"]);
376                    }
377                    _ => unreachable!("Expected Stdio transport"),
378                }
379            }
380            _ => unreachable!("Expected correct variant"),
381        }
382    }
383
384    #[tokio::test]
385    async fn test_partial_config_parsing() {
386        // Test simple config without backends
387        let toml_content = r#"
388[tool_config]
389visibility = "all"
390approval_policy = "always_ask"
391"#;
392
393        let partial: PartialSessionConfig = toml::from_str(toml_content).unwrap();
394        assert!(partial.tool_config.is_some());
395    }
396
397    #[tokio::test]
398    async fn test_config_with_empty_backends() {
399        // Test config with empty backends array
400        let toml_content = r#"
401[tool_config]
402backends = []
403visibility = "all"
404"#;
405
406        let partial: PartialSessionConfig = toml::from_str(toml_content).unwrap();
407        assert!(partial.tool_config.is_some());
408
409        let tool_config = partial.tool_config.unwrap();
410        assert!(tool_config.backends.is_some());
411        assert_eq!(tool_config.backends.unwrap().len(), 0);
412    }
413
414    #[tokio::test]
415    async fn test_full_config_parsing() {
416        let toml_content = r#"
417system_prompt = "You are a helpful assistant."
418
419[workspace]
420type = "local"
421
422[tool_config]
423visibility = "read_only"
424
425[metadata]
426project = "my-project"
427"#;
428
429        let partial: PartialSessionConfig = toml::from_str(toml_content).unwrap();
430        assert!(partial.workspace.is_some());
431        assert!(partial.system_prompt.is_some());
432        assert_eq!(
433            partial.system_prompt.unwrap(),
434            "You are a helpful assistant."
435        );
436        assert!(partial.metadata.is_some());
437    }
438
439    #[tokio::test]
440    async fn test_config_loader() {
441        let loader = SessionConfigLoader::new(None);
442        let config = loader.load().await.unwrap();
443
444        // Should get defaults
445        assert!(matches!(config.workspace, WorkspaceConfig::Local { .. }));
446        assert!(matches!(
447            config.tool_config.approval_policy,
448            ToolApprovalPolicy::AlwaysAsk
449        ));
450    }
451
452    #[tokio::test]
453    async fn test_config_loader_with_overrides() {
454        let overrides = SessionConfigOverrides {
455            system_prompt: Some("Custom prompt".to_string()),
456            metadata: Some("key1=value1,key2=value2".to_string()),
457        };
458
459        let loader = SessionConfigLoader::new(None).with_overrides(overrides);
460        let config = loader.load().await.unwrap();
461
462        assert_eq!(config.system_prompt, Some("Custom prompt".to_string()));
463        assert_eq!(config.metadata.get("key1"), Some(&"value1".to_string()));
464    }
465
466    #[tokio::test]
467    async fn test_load_non_existent_file() {
468        let loader = SessionConfigLoader::new(Some(PathBuf::from("/tmp/non_existent_file.toml")));
469        let result = loader.load().await;
470
471        assert!(result.is_err());
472        let err = result.unwrap_err();
473        assert!(err.to_string().contains("Failed to read config file"));
474    }
475
476    #[tokio::test]
477    async fn test_load_invalid_toml() {
478        use std::io::Write;
479        use tempfile::NamedTempFile;
480
481        let mut temp_file = NamedTempFile::new().unwrap();
482        writeln!(temp_file, "invalid toml syntax {{").unwrap();
483
484        let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
485        let result = loader.load().await;
486
487        assert!(result.is_err());
488        let err = result.unwrap_err();
489        assert!(err.to_string().contains("Failed to parse TOML config"));
490    }
491
492    #[tokio::test]
493    async fn test_invalid_visibility_config() {
494        use std::io::Write;
495        use tempfile::NamedTempFile;
496
497        let mut temp_file = NamedTempFile::new().unwrap();
498        writeln!(
499            temp_file,
500            r#"
501[tool_config]
502visibility = "invalid_value"
503"#
504        )
505        .unwrap();
506
507        let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
508        let result = loader.load().await;
509
510        assert!(result.is_err());
511        let err = result.unwrap_err();
512        assert!(err.to_string().contains("Invalid visibility string"));
513        assert!(err.to_string().contains("Expected 'all' or 'read_only'"));
514    }
515
516    #[tokio::test]
517    async fn test_invalid_approval_policy_config() {
518        use std::io::Write;
519        use tempfile::NamedTempFile;
520
521        let mut temp_file = NamedTempFile::new().unwrap();
522        writeln!(
523            temp_file,
524            r#"
525[tool_config]
526approval_policy = "invalid_policy"
527"#
528        )
529        .unwrap();
530
531        let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
532        let result = loader.load().await;
533
534        assert!(result.is_err());
535        let err = result.unwrap_err();
536        assert!(err.to_string().contains("Invalid approval policy string"));
537        assert!(err.to_string().contains("Expected 'always_ask'"));
538    }
539
540    #[tokio::test]
541    async fn test_mcp_backend_validation_empty_server_name() {
542        use std::io::Write;
543        use tempfile::NamedTempFile;
544
545        let mut temp_file = NamedTempFile::new().unwrap();
546        writeln!(temp_file, r#"
547[tool_config]
548backends = [
549  {{ type = "mcp", server_name = "", transport = {{ type = "stdio", command = "python", args = ["-m", "test"] }}, tool_filter = "all" }}
550]
551"#).unwrap();
552
553        let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
554        let result = loader.load().await;
555
556        assert!(result.is_err());
557        let err = result.unwrap_err();
558        assert!(
559            err.to_string()
560                .contains("MCP backend server_name cannot be empty")
561        );
562    }
563
564    #[tokio::test]
565    async fn test_mcp_backend_validation_empty_command() {
566        use std::io::Write;
567        use tempfile::NamedTempFile;
568
569        let mut temp_file = NamedTempFile::new().unwrap();
570        writeln!(temp_file, r#"
571[tool_config]
572backends = [
573  {{ type = "mcp", server_name = "test", transport = {{ type = "stdio", command = "", args = ["-m", "test"] }}, tool_filter = "all" }}
574]
575"#).unwrap();
576
577        let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
578        let result = loader.load().await;
579
580        assert!(result.is_err());
581        let err = result.unwrap_err();
582        assert!(
583            err.to_string()
584                .contains("MCP stdio transport command cannot be empty")
585        );
586    }
587
588    #[tokio::test]
589    async fn test_file_config_with_cli_overrides() {
590        use std::io::Write;
591        use tempfile::NamedTempFile;
592
593        // Create a config file with initial values
594        let mut temp_file = NamedTempFile::new().unwrap();
595        writeln!(
596            temp_file,
597            r#"
598system_prompt = "Original prompt"
599
600[tool_config]
601visibility = "all"
602approval_policy = "always_ask"
603
604[metadata]
605key1 = "original1"
606key2 = "original2"
607"#
608        )
609        .unwrap();
610
611        // Apply CLI overrides
612        let overrides = SessionConfigOverrides {
613            system_prompt: Some("Overridden prompt".to_string()),
614            metadata: Some("key2=overridden2,key3=new3".to_string()),
615        };
616
617        let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()))
618            .with_overrides(overrides);
619        let config = loader.load().await.unwrap();
620
621        // Check that overrides were applied
622        assert_eq!(config.system_prompt, Some("Overridden prompt".to_string()));
623
624        // Check metadata was merged (key1 unchanged, key2 overridden, key3 added)
625        assert_eq!(config.metadata.get("key1"), Some(&"original1".to_string()));
626        assert_eq!(
627            config.metadata.get("key2"),
628            Some(&"overridden2".to_string())
629        );
630        assert_eq!(config.metadata.get("key3"), Some(&"new3".to_string()));
631
632        // Visibility and approval policy should remain from file
633        assert!(matches!(config.tool_config.visibility, ToolVisibility::All));
634        assert!(matches!(
635            config.tool_config.approval_policy,
636            ToolApprovalPolicy::AlwaysAsk
637        ));
638    }
639
640    #[tokio::test]
641    async fn test_complex_tool_visibility_whitelist() {
642        use std::io::Write;
643        use tempfile::NamedTempFile;
644
645        let mut temp_file = NamedTempFile::new().unwrap();
646        writeln!(
647            temp_file,
648            r#"
649[tool_config]
650visibility = {{ whitelist = ["grep", "ls", "view"] }}
651"#
652        )
653        .unwrap();
654
655        let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
656        let config = loader.load().await.unwrap();
657
658        match &config.tool_config.visibility {
659            ToolVisibility::Whitelist(tools) => {
660                assert_eq!(tools.len(), 3);
661                assert!(tools.contains("grep"));
662                assert!(tools.contains("ls"));
663                assert!(tools.contains("view"));
664            }
665            _ => unreachable!("Expected Whitelist visibility"),
666        }
667    }
668
669    #[tokio::test]
670    async fn test_complex_tool_visibility_blacklist() {
671        use std::io::Write;
672        use tempfile::NamedTempFile;
673
674        let mut temp_file = NamedTempFile::new().unwrap();
675        writeln!(
676            temp_file,
677            r#"
678[tool_config]
679visibility = {{ blacklist = ["bash", "edit_file"] }}
680"#
681        )
682        .unwrap();
683
684        let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
685        let config = loader.load().await.unwrap();
686
687        match &config.tool_config.visibility {
688            ToolVisibility::Blacklist(tools) => {
689                assert_eq!(tools.len(), 2);
690                assert!(tools.contains("bash"));
691                assert!(tools.contains("edit_file"));
692            }
693            _ => unreachable!("Expected Blacklist visibility"),
694        }
695    }
696
697    #[tokio::test]
698    async fn test_workspace_remote_config() {
699        use std::io::Write;
700        use tempfile::NamedTempFile;
701
702        let mut temp_file = NamedTempFile::new().unwrap();
703        writeln!(
704            temp_file,
705            r#"
706[workspace]
707type = "remote"
708agent_address = "192.168.1.100:50051"
709auth = {{ Bearer = {{ token = "secret-token" }} }}
710"#
711        )
712        .unwrap();
713
714        let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
715        let config = loader.load().await.unwrap();
716
717        match &config.workspace {
718            WorkspaceConfig::Remote {
719                agent_address,
720                auth,
721            } => {
722                assert_eq!(agent_address, "192.168.1.100:50051");
723                assert!(auth.is_some());
724                match auth.as_ref().unwrap() {
725                    RemoteAuth::Bearer { token } => {
726                        assert_eq!(token, "secret-token");
727                    }
728                    _ => unreachable!("Expected Bearer auth"),
729                }
730            }
731            _ => unreachable!("Expected Remote workspace"),
732        }
733    }
734
735    #[tokio::test]
736    async fn test_bash_tool_config_with_approved_patterns() {
737        use std::io::Write;
738        use tempfile::NamedTempFile;
739
740        let mut temp_file = NamedTempFile::new().unwrap();
741        writeln!(
742            temp_file,
743            r#"
744[tool_config.tools.bash]
745approved_patterns = [
746    "git status",
747    "git log*",
748    "npm run*",
749    "cargo build*"
750]
751"#
752        )
753        .unwrap();
754
755        let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
756        let config = loader.load().await.unwrap();
757
758        // Check that bash tool config was parsed correctly
759        let bash_config = config.tool_config.tools.get("bash");
760        assert!(bash_config.is_some(), "Bash config should be present");
761
762        match bash_config.unwrap() {
763            ToolSpecificConfig::Bash(bash) => {
764                assert_eq!(bash.approved_patterns.len(), 4);
765                assert_eq!(bash.approved_patterns[0], "git status");
766                assert_eq!(bash.approved_patterns[1], "git log*");
767                assert_eq!(bash.approved_patterns[2], "npm run*");
768                assert_eq!(bash.approved_patterns[3], "cargo build*");
769            }
770        }
771    }
772
773    #[tokio::test]
774    async fn test_bash_tool_config_empty_patterns() {
775        use std::io::Write;
776        use tempfile::NamedTempFile;
777
778        let mut temp_file = NamedTempFile::new().unwrap();
779        writeln!(
780            temp_file,
781            r#"
782[tool_config.tools.bash]
783approved_patterns = []
784"#
785        )
786        .unwrap();
787
788        let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
789        let config = loader.load().await.unwrap();
790
791        let bash_config = config.tool_config.tools.get("bash");
792        assert!(bash_config.is_some());
793
794        match bash_config.unwrap() {
795            ToolSpecificConfig::Bash(bash) => {
796                assert_eq!(bash.approved_patterns.len(), 0);
797            }
798        }
799    }
800
801    #[tokio::test]
802    async fn test_bash_tool_config_with_other_tools() {
803        use std::io::Write;
804        use tempfile::NamedTempFile;
805
806        let mut temp_file = NamedTempFile::new().unwrap();
807        writeln!(
808            temp_file,
809            r#"
810[tool_config]
811visibility = "all"
812approval_policy = "always_ask"
813
814[tool_config.tools.bash]
815approved_patterns = ["ls -la", "pwd"]
816"#
817        )
818        .unwrap();
819
820        let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
821        let config = loader.load().await.unwrap();
822
823        // Check visibility and approval policy
824        assert!(matches!(config.tool_config.visibility, ToolVisibility::All));
825        assert!(matches!(
826            config.tool_config.approval_policy,
827            ToolApprovalPolicy::AlwaysAsk
828        ));
829
830        // Check bash config
831        let bash_config = config.tool_config.tools.get("bash");
832        assert!(bash_config.is_some());
833
834        match bash_config.unwrap() {
835            ToolSpecificConfig::Bash(bash) => {
836                assert_eq!(bash.approved_patterns.len(), 2);
837                assert_eq!(bash.approved_patterns[0], "ls -la");
838                assert_eq!(bash.approved_patterns[1], "pwd");
839            }
840        }
841    }
842
843    #[tokio::test]
844    async fn test_bash_tool_config_without_approved_patterns() {
845        use std::io::Write;
846        use tempfile::NamedTempFile;
847
848        let mut temp_file = NamedTempFile::new().unwrap();
849        writeln!(
850            temp_file,
851            r#"
852[tool_config.tools.bash]
853# No approved_patterns field
854"#
855        )
856        .unwrap();
857
858        let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
859        let config = loader.load().await.unwrap();
860
861        let bash_config = config.tool_config.tools.get("bash");
862        assert!(bash_config.is_some());
863
864        match bash_config.unwrap() {
865            ToolSpecificConfig::Bash(bash) => {
866                // Should default to empty vec when approved_patterns is None
867                assert_eq!(bash.approved_patterns.len(), 0);
868            }
869        }
870    }
871
872    #[tokio::test]
873    async fn test_full_config_with_bash_patterns() {
874        use std::io::Write;
875        use tempfile::NamedTempFile;
876
877        let mut temp_file = NamedTempFile::new().unwrap();
878        writeln!(
879            temp_file,
880            r#"
881system_prompt = "You are a helpful assistant"
882
883[workspace]
884type = "local"
885
886[tool_config]
887visibility = "all"
888approval_policy = {{ type = "pre_approved", tools = ["grep", "ls", "view"] }}
889backends = []
890
891[tool_config.tools.bash]
892approved_patterns = [
893    "git status",
894    "git diff",
895    "git log --oneline",
896    "npm test",
897    "cargo check"
898]
899
900[metadata]
901project = "test-project"
902"#
903        )
904        .unwrap();
905
906        let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
907        let config = loader.load().await.unwrap();
908
909        // Check all parts of the config
910        assert_eq!(
911            config.system_prompt,
912            Some("You are a helpful assistant".to_string())
913        );
914        assert!(matches!(config.workspace, WorkspaceConfig::Local { .. }));
915        assert_eq!(
916            config.metadata.get("project"),
917            Some(&"test-project".to_string())
918        );
919
920        // Check tool approval policy
921        match &config.tool_config.approval_policy {
922            ToolApprovalPolicy::PreApproved { tools } => {
923                assert_eq!(tools.len(), 3);
924                assert!(tools.contains("grep"));
925                assert!(tools.contains("ls"));
926                assert!(tools.contains("view"));
927            }
928            _ => unreachable!("Expected PreApproved policy"),
929        }
930
931        // Check bash config
932        let bash_config = config.tool_config.tools.get("bash");
933        assert!(bash_config.is_some());
934
935        match bash_config.unwrap() {
936            ToolSpecificConfig::Bash(bash) => {
937                assert_eq!(bash.approved_patterns.len(), 5);
938                assert_eq!(bash.approved_patterns[0], "git status");
939                assert_eq!(bash.approved_patterns[1], "git diff");
940                assert_eq!(bash.approved_patterns[2], "git log --oneline");
941                assert_eq!(bash.approved_patterns[3], "npm test");
942                assert_eq!(bash.approved_patterns[4], "cargo check");
943            }
944        }
945    }
946}