1use eyre::{Context, Result};
2use schemars::JsonSchema;
3use serde::{Deserialize, Serialize};
4use std::collections::{HashMap, HashSet};
5use std::path::PathBuf;
6use steer_core::session::{
7 BackendConfig, BashToolConfig, ContainerRuntime, RemoteAuth, SessionConfig, SessionToolConfig,
8 ToolApprovalPolicy, ToolSpecificConfig, ToolVisibility, WorkspaceConfig,
9};
10use thiserror::Error;
11use tokio::fs;
12use tracing::debug;
13
/// Errors describing an invalid or unreadable session configuration.
///
/// The `Empty*` / `Invalid*` variants are the ones returned by
/// `SessionConfigLoader::validate_config` when an MCP backend entry is
/// malformed.
#[derive(Debug, Error)]
pub enum SessionConfigError {
    /// An MCP backend was declared with an empty `server_name`.
    #[error("MCP backend server_name cannot be empty")]
    EmptyServerName,

    /// An MCP stdio transport was declared with an empty `command`.
    #[error("MCP stdio transport command cannot be empty")]
    EmptyStdioCommand,

    /// An MCP TCP transport was declared with an empty `host`.
    #[error("MCP TCP transport host cannot be empty")]
    EmptyTcpHost,

    /// An MCP TCP transport was declared with port `0`.
    #[error("MCP TCP transport port cannot be 0")]
    InvalidTcpPort,

    /// An MCP Unix-socket transport was declared with an empty `path`.
    #[error("MCP Unix transport path cannot be empty")]
    EmptyUnixPath,

    /// An MCP SSE transport was declared with an empty `url`.
    #[error("MCP SSE transport url cannot be empty")]
    EmptySseUrl,

    /// An MCP HTTP transport was declared with an empty `url`.
    #[error("MCP HTTP transport url cannot be empty")]
    EmptyHttpUrl,

    /// An I/O failure; `#[from]` provides the conversion from [`std::io::Error`].
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// A TOML deserialization failure; `#[from]` provides the conversion from
    /// [`toml::de::Error`].
    #[error("TOML parse error: {0}")]
    TomlParse(#[from] toml::de::Error),
}
44
// File-level shape of a session config as written by the user (TOML).
// Every section is optional; `SessionConfigLoader::partial_to_full` fills the
// gaps with defaults.
//
// NOTE: plain `//` comments are used on purpose. This type derives
// `JsonSchema`, and schemars copies `///` doc comments into the generated
// schema as descriptions, which would change the emitted schema.
#[derive(Debug, Deserialize, Serialize, Default, JsonSchema)]
pub struct PartialSessionConfig {
    // Serialized under the key `$schema`; not read by the loader itself.
    #[schemars(description = "URL to the JSON schema file")]
    #[serde(rename = "$schema")]
    pub schema: Option<String>,
    // Where the session runs; omitted means `WorkspaceConfig::default()`.
    pub workspace: Option<PartialWorkspaceConfig>,
    // Tool backends, visibility, approval policy and per-tool settings;
    // omitted means `SessionToolConfig::default()`.
    pub tool_config: Option<PartialToolConfig>,
    // Passed through unchanged to `SessionConfig::system_prompt`.
    pub system_prompt: Option<String>,
    // Free-form string key/value pairs; omitted means an empty map.
    pub metadata: Option<HashMap<String, String>>,
}
57
// User-facing workspace selection. Internally tagged: the `type` key picks
// the variant using its snake_case name (`local`, `remote`, `container`).
//
// NOTE: plain `//` comments are used on purpose — schemars would copy `///`
// doc comments into the generated JSON schema.
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum PartialWorkspaceConfig {
    Local {
        // Optional directory; when absent the loader substitutes the process's
        // current directory (or "." if that cannot be determined).
        #[serde(default)]
        path: Option<PathBuf>,
    },
    Remote {
        // Address of the remote agent, e.g. "192.168.1.100:50051" in the tests.
        agent_address: String,
        // Optional credentials forwarded unchanged to `WorkspaceConfig::Remote`.
        auth: Option<RemoteAuth>,
    },
    Container {
        // Both fields are forwarded unchanged to `WorkspaceConfig::Container`.
        image: String,
        runtime: ContainerRuntime,
    },
}
74
// The `[tool_config]` section of the config file; every key is optional.
//
// NOTE(review): `deny_unknown_fields` is applied through `schemars` only, so
// it affects the generated JSON schema. No matching serde attribute is set,
// so deserialization itself presumably still ignores unknown keys — confirm
// this asymmetry is intended.
//
// Plain `//` comments are used on purpose — schemars would copy `///` doc
// comments into the generated schema.
#[derive(Debug, Deserialize, Serialize, Default, JsonSchema)]
#[schemars(deny_unknown_fields)]
pub struct PartialToolConfig {
    // Tool backends; omitted means an empty list.
    pub backends: Option<Vec<BackendConfig>>,
    // Either a shorthand string or a whitelist/blacklist object.
    pub visibility: Option<ToolVisibilityConfig>,
    // Either a shorthand string or a full tagged `ToolApprovalPolicy`.
    pub approval_policy: Option<ToolApprovalPolicyConfig>,
    // Per-tool settings keyed by tool name (`[tool_config.tools.<name>]`).
    #[serde(default)]
    pub tools: Option<HashMap<String, PartialToolSpecificConfig>>,
}
84
// Settings for one tool under `[tool_config.tools.<name>]`.
//
// Untagged: the variant is chosen by shape alone, and `Bash` is the only
// shape. The loader only carries over the entry whose key is exactly `bash`.
//
// Plain `//` comments are used on purpose — schemars would copy `///` doc
// comments into the generated schema.
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
#[serde(untagged)]
pub enum PartialToolSpecificConfig {
    Bash(PartialBashConfig),
}
90
// User-facing settings for the bash tool.
//
// Plain `//` comments are used on purpose — schemars would copy `///` doc
// comments into the generated schema.
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
pub struct PartialBashConfig {
    // Strings such as "git status" or "git log*"; omitted means an empty list.
    // NOTE(review): presumably command patterns that skip approval — the
    // matching semantics live outside this file; confirm against the bash tool.
    pub approved_patterns: Option<Vec<String>>,
}
95
// The two accepted spellings of `tool_config.visibility`.
//
// Untagged, and variant order matters: a plain string is tried first, then
// the object form.
//
// Plain `//` comments are used on purpose — schemars would copy `///` doc
// comments into the generated schema.
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
#[serde(untagged)]
pub enum ToolVisibilityConfig {
    // Shorthand; the loader accepts only "all" and "read_only".
    String(String),
    // `{ whitelist = [...] }` or `{ blacklist = [...] }`.
    Object(ToolVisibilityObject),
}
102
// Object form of tool visibility, keyed by the snake_case variant name, e.g.
// `visibility = { whitelist = ["grep", "ls"] }`.
//
// Plain `//` comments are used on purpose — schemars would copy `///` doc
// comments into the generated schema.
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum ToolVisibilityObject {
    // Mapped to `ToolVisibility::Whitelist` with the same tool-name set.
    Whitelist(HashSet<String>),
    // Mapped to `ToolVisibility::Blacklist` with the same tool-name set.
    Blacklist(HashSet<String>),
}
109
// The two accepted spellings of `tool_config.approval_policy`.
//
// Untagged, and variant order matters: a plain string is tried first, then
// the full tagged policy.
//
// Plain `//` comments are used on purpose — schemars would copy `///` doc
// comments into the generated schema.
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
#[serde(untagged)]
pub enum ToolApprovalPolicyConfig {
    // Shorthand; the loader accepts only "always_ask".
    String(String),
    // A complete `ToolApprovalPolicy`, e.g.
    // `{ type = "pre_approved", tools = ["grep"] }`; used as-is.
    Tagged(ToolApprovalPolicy),
}
116
/// Values that take precedence over whatever the config file (or the built-in
/// defaults) supplied. Applied by `SessionConfigLoader::apply_overrides`.
#[derive(Debug, Default)]
pub struct SessionConfigOverrides {
    /// When set, replaces the configured system prompt.
    pub system_prompt: Option<String>,
    /// Raw metadata string handed to `steer_core::utils::session::parse_metadata`
    /// and merged over the configured metadata (override wins on key clashes).
    /// The tests in this file use the form `key1=value1,key2=value2`.
    pub metadata: Option<String>,
}
123
/// Builds a `SessionConfig` from an optional TOML file plus overrides.
pub struct SessionConfigLoader {
    /// TOML file to read; `None` means start from the built-in defaults.
    config_path: Option<PathBuf>,
    /// Overrides applied after the file (or defaults) have been resolved.
    overrides: SessionConfigOverrides,
}
129
130impl SessionConfigLoader {
131 pub fn new(config_path: Option<PathBuf>) -> Self {
132 debug!("Loading session config from: {:?}", config_path);
133 Self {
134 config_path,
135 overrides: SessionConfigOverrides::default(),
136 }
137 }
138
139 pub fn with_overrides(mut self, overrides: SessionConfigOverrides) -> Self {
140 self.overrides = overrides;
141 self
142 }
143
144 pub async fn load(&self) -> Result<SessionConfig> {
145 let mut config = if let Some(path) = &self.config_path {
146 let content = fs::read_to_string(path)
148 .await
149 .with_context(|| format!("Failed to read config file: {}", path.display()))?;
150
151 let partial: PartialSessionConfig = toml::from_str(&content)
152 .with_context(|| format!("Failed to parse TOML config from: {}", path.display()))?;
153
154 self.partial_to_full(partial)?
155 } else {
156 SessionConfig {
158 workspace: WorkspaceConfig::default(),
159 tool_config: SessionToolConfig::default(),
160 system_prompt: None,
161 metadata: HashMap::new(),
162 }
163 };
164
165 self.apply_overrides(&mut config)?;
166 self.validate_config(&config)?;
167
168 Ok(config)
169 }
170
171 fn partial_to_full(&self, partial: PartialSessionConfig) -> Result<SessionConfig> {
172 let workspace = match partial.workspace {
173 Some(PartialWorkspaceConfig::Local { path }) => WorkspaceConfig::Local {
174 path: path.unwrap_or_else(|| {
175 std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."))
176 }),
177 },
178 Some(PartialWorkspaceConfig::Remote {
179 agent_address,
180 auth,
181 }) => WorkspaceConfig::Remote {
182 agent_address,
183 auth,
184 },
185 Some(PartialWorkspaceConfig::Container { image, runtime }) => {
186 WorkspaceConfig::Container { image, runtime }
187 }
188 None => WorkspaceConfig::default(),
189 };
190
191 let tool_config = if let Some(partial_tool_config) = partial.tool_config {
192 let backends = partial_tool_config.backends.unwrap_or_default();
193
194 let visibility = match partial_tool_config.visibility {
195 Some(ToolVisibilityConfig::String(s)) => match s.as_str() {
196 "all" => ToolVisibility::All,
197 "read_only" => ToolVisibility::ReadOnly,
198 _ => {
199 return Err(eyre::eyre!(
200 "Invalid visibility string: {}. Expected 'all' or 'read_only'",
201 s
202 ));
203 }
204 },
205 Some(ToolVisibilityConfig::Object(obj)) => match obj {
206 ToolVisibilityObject::Whitelist(tools) => ToolVisibility::Whitelist(tools),
207 ToolVisibilityObject::Blacklist(tools) => ToolVisibility::Blacklist(tools),
208 },
209 None => ToolVisibility::default(),
210 };
211
212 let approval_policy = match partial_tool_config.approval_policy {
213 Some(ToolApprovalPolicyConfig::String(s)) => match s.as_str() {
214 "always_ask" => ToolApprovalPolicy::AlwaysAsk,
215 _ => {
216 return Err(eyre::eyre!(
217 "Invalid approval policy string: {}. Expected 'always_ask'",
218 s
219 ));
220 }
221 },
222 Some(ToolApprovalPolicyConfig::Tagged(policy)) => policy,
223 None => ToolApprovalPolicy::AlwaysAsk,
224 };
225
226 let mut tools = HashMap::new();
228 if let Some(partial_tools) = partial_tool_config.tools {
229 for (tool_name, tool_config) in partial_tools {
230 match tool_config {
231 PartialToolSpecificConfig::Bash(bash_config) => {
232 if tool_name == "bash" {
233 tools.insert(
234 "bash".to_string(),
235 ToolSpecificConfig::Bash(BashToolConfig {
236 approved_patterns: bash_config
237 .approved_patterns
238 .unwrap_or_default(),
239 }),
240 );
241 }
242 }
243 }
244 }
245 }
246
247 SessionToolConfig {
248 backends,
249 visibility,
250 approval_policy,
251 metadata: HashMap::new(),
252 tools,
253 }
254 } else {
255 SessionToolConfig::default()
256 };
257
258 debug!("Loaded tool config: {:?}", tool_config);
259
260 Ok(SessionConfig {
261 workspace,
262 tool_config,
263 system_prompt: partial.system_prompt,
264 metadata: partial.metadata.unwrap_or_default(),
265 })
266 }
267
268 fn apply_overrides(&self, config: &mut SessionConfig) -> Result<()> {
269 if let Some(system_prompt) = &self.overrides.system_prompt {
271 config.system_prompt = Some(system_prompt.clone());
272 }
273
274 if let Some(metadata_str) = &self.overrides.metadata {
276 let metadata = steer_core::utils::session::parse_metadata(Some(metadata_str))?;
277 config.metadata.extend(metadata);
278 }
279
280 Ok(())
281 }
282
283 fn validate_config(&self, config: &SessionConfig) -> Result<(), SessionConfigError> {
284 for backend in &config.tool_config.backends {
286 if let BackendConfig::Mcp {
287 server_name,
288 transport,
289 ..
290 } = backend
291 {
292 if server_name.is_empty() {
293 return Err(SessionConfigError::EmptyServerName);
294 }
295
296 match transport {
298 steer_core::tools::McpTransport::Stdio { command, .. } => {
299 if command.is_empty() {
300 return Err(SessionConfigError::EmptyStdioCommand);
301 }
302 if which::which(command).is_err() {
304 tracing::warn!(
306 "MCP command '{}' for server '{}' not found in PATH",
307 command,
308 server_name
309 );
310 }
311 }
312 steer_core::tools::McpTransport::Tcp { host, port } => {
313 if host.is_empty() {
314 return Err(SessionConfigError::EmptyTcpHost);
315 }
316 if *port == 0 {
317 return Err(SessionConfigError::InvalidTcpPort);
318 }
319 }
320 #[cfg(unix)]
321 steer_core::tools::McpTransport::Unix { path } => {
322 if path.is_empty() {
323 return Err(SessionConfigError::EmptyUnixPath);
324 }
325 }
326 steer_core::tools::McpTransport::Sse { url, .. } => {
327 if url.is_empty() {
328 return Err(SessionConfigError::EmptySseUrl);
329 }
330 }
331 steer_core::tools::McpTransport::Http { url, .. } => {
332 if url.is_empty() {
333 return Err(SessionConfigError::EmptyHttpUrl);
334 }
335 }
336 }
337 }
338 }
339
340 Ok(())
341 }
342}
343
344#[cfg(test)]
345mod tests {
346 use super::*;
347 use steer_core::session::ToolFilter;
348
349 #[tokio::test]
350 async fn test_backend_serialization() {
351 let backend = BackendConfig::Mcp {
353 server_name: "test".to_string(),
354 transport: steer_core::tools::McpTransport::Stdio {
355 command: "python".to_string(),
356 args: vec!["-m".to_string(), "test".to_string()],
357 },
358 tool_filter: ToolFilter::All,
359 };
360
361 let json = serde_json::to_string(&backend).unwrap();
362 println!("Backend JSON: {json}");
363
364 let backend2: BackendConfig = serde_json::from_str(&json).unwrap();
365 match backend2 {
366 BackendConfig::Mcp {
367 server_name,
368 transport,
369 ..
370 } => {
371 assert_eq!(server_name, "test");
372 match transport {
373 steer_core::tools::McpTransport::Stdio { command, args } => {
374 assert_eq!(command, "python");
375 assert_eq!(args, vec!["-m", "test"]);
376 }
377 _ => unreachable!("Expected Stdio transport"),
378 }
379 }
380 _ => unreachable!("Expected correct variant"),
381 }
382 }
383
384 #[tokio::test]
385 async fn test_partial_config_parsing() {
386 let toml_content = r#"
388[tool_config]
389visibility = "all"
390approval_policy = "always_ask"
391"#;
392
393 let partial: PartialSessionConfig = toml::from_str(toml_content).unwrap();
394 assert!(partial.tool_config.is_some());
395 }
396
397 #[tokio::test]
398 async fn test_config_with_empty_backends() {
399 let toml_content = r#"
401[tool_config]
402backends = []
403visibility = "all"
404"#;
405
406 let partial: PartialSessionConfig = toml::from_str(toml_content).unwrap();
407 assert!(partial.tool_config.is_some());
408
409 let tool_config = partial.tool_config.unwrap();
410 assert!(tool_config.backends.is_some());
411 assert_eq!(tool_config.backends.unwrap().len(), 0);
412 }
413
414 #[tokio::test]
415 async fn test_full_config_parsing() {
416 let toml_content = r#"
417system_prompt = "You are a helpful assistant."
418
419[workspace]
420type = "local"
421
422[tool_config]
423visibility = "read_only"
424
425[metadata]
426project = "my-project"
427"#;
428
429 let partial: PartialSessionConfig = toml::from_str(toml_content).unwrap();
430 assert!(partial.workspace.is_some());
431 assert!(partial.system_prompt.is_some());
432 assert_eq!(
433 partial.system_prompt.unwrap(),
434 "You are a helpful assistant."
435 );
436 assert!(partial.metadata.is_some());
437 }
438
439 #[tokio::test]
440 async fn test_config_loader() {
441 let loader = SessionConfigLoader::new(None);
442 let config = loader.load().await.unwrap();
443
444 assert!(matches!(config.workspace, WorkspaceConfig::Local { .. }));
446 assert!(matches!(
447 config.tool_config.approval_policy,
448 ToolApprovalPolicy::AlwaysAsk
449 ));
450 }
451
452 #[tokio::test]
453 async fn test_config_loader_with_overrides() {
454 let overrides = SessionConfigOverrides {
455 system_prompt: Some("Custom prompt".to_string()),
456 metadata: Some("key1=value1,key2=value2".to_string()),
457 };
458
459 let loader = SessionConfigLoader::new(None).with_overrides(overrides);
460 let config = loader.load().await.unwrap();
461
462 assert_eq!(config.system_prompt, Some("Custom prompt".to_string()));
463 assert_eq!(config.metadata.get("key1"), Some(&"value1".to_string()));
464 }
465
466 #[tokio::test]
467 async fn test_load_non_existent_file() {
468 let loader = SessionConfigLoader::new(Some(PathBuf::from("/tmp/non_existent_file.toml")));
469 let result = loader.load().await;
470
471 assert!(result.is_err());
472 let err = result.unwrap_err();
473 assert!(err.to_string().contains("Failed to read config file"));
474 }
475
476 #[tokio::test]
477 async fn test_load_invalid_toml() {
478 use std::io::Write;
479 use tempfile::NamedTempFile;
480
481 let mut temp_file = NamedTempFile::new().unwrap();
482 writeln!(temp_file, "invalid toml syntax {{").unwrap();
483
484 let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
485 let result = loader.load().await;
486
487 assert!(result.is_err());
488 let err = result.unwrap_err();
489 assert!(err.to_string().contains("Failed to parse TOML config"));
490 }
491
492 #[tokio::test]
493 async fn test_invalid_visibility_config() {
494 use std::io::Write;
495 use tempfile::NamedTempFile;
496
497 let mut temp_file = NamedTempFile::new().unwrap();
498 writeln!(
499 temp_file,
500 r#"
501[tool_config]
502visibility = "invalid_value"
503"#
504 )
505 .unwrap();
506
507 let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
508 let result = loader.load().await;
509
510 assert!(result.is_err());
511 let err = result.unwrap_err();
512 assert!(err.to_string().contains("Invalid visibility string"));
513 assert!(err.to_string().contains("Expected 'all' or 'read_only'"));
514 }
515
516 #[tokio::test]
517 async fn test_invalid_approval_policy_config() {
518 use std::io::Write;
519 use tempfile::NamedTempFile;
520
521 let mut temp_file = NamedTempFile::new().unwrap();
522 writeln!(
523 temp_file,
524 r#"
525[tool_config]
526approval_policy = "invalid_policy"
527"#
528 )
529 .unwrap();
530
531 let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
532 let result = loader.load().await;
533
534 assert!(result.is_err());
535 let err = result.unwrap_err();
536 assert!(err.to_string().contains("Invalid approval policy string"));
537 assert!(err.to_string().contains("Expected 'always_ask'"));
538 }
539
540 #[tokio::test]
541 async fn test_mcp_backend_validation_empty_server_name() {
542 use std::io::Write;
543 use tempfile::NamedTempFile;
544
545 let mut temp_file = NamedTempFile::new().unwrap();
546 writeln!(temp_file, r#"
547[tool_config]
548backends = [
549 {{ type = "mcp", server_name = "", transport = {{ type = "stdio", command = "python", args = ["-m", "test"] }}, tool_filter = "all" }}
550]
551"#).unwrap();
552
553 let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
554 let result = loader.load().await;
555
556 assert!(result.is_err());
557 let err = result.unwrap_err();
558 assert!(
559 err.to_string()
560 .contains("MCP backend server_name cannot be empty")
561 );
562 }
563
564 #[tokio::test]
565 async fn test_mcp_backend_validation_empty_command() {
566 use std::io::Write;
567 use tempfile::NamedTempFile;
568
569 let mut temp_file = NamedTempFile::new().unwrap();
570 writeln!(temp_file, r#"
571[tool_config]
572backends = [
573 {{ type = "mcp", server_name = "test", transport = {{ type = "stdio", command = "", args = ["-m", "test"] }}, tool_filter = "all" }}
574]
575"#).unwrap();
576
577 let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
578 let result = loader.load().await;
579
580 assert!(result.is_err());
581 let err = result.unwrap_err();
582 assert!(
583 err.to_string()
584 .contains("MCP stdio transport command cannot be empty")
585 );
586 }
587
588 #[tokio::test]
589 async fn test_file_config_with_cli_overrides() {
590 use std::io::Write;
591 use tempfile::NamedTempFile;
592
593 let mut temp_file = NamedTempFile::new().unwrap();
595 writeln!(
596 temp_file,
597 r#"
598system_prompt = "Original prompt"
599
600[tool_config]
601visibility = "all"
602approval_policy = "always_ask"
603
604[metadata]
605key1 = "original1"
606key2 = "original2"
607"#
608 )
609 .unwrap();
610
611 let overrides = SessionConfigOverrides {
613 system_prompt: Some("Overridden prompt".to_string()),
614 metadata: Some("key2=overridden2,key3=new3".to_string()),
615 };
616
617 let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()))
618 .with_overrides(overrides);
619 let config = loader.load().await.unwrap();
620
621 assert_eq!(config.system_prompt, Some("Overridden prompt".to_string()));
623
624 assert_eq!(config.metadata.get("key1"), Some(&"original1".to_string()));
626 assert_eq!(
627 config.metadata.get("key2"),
628 Some(&"overridden2".to_string())
629 );
630 assert_eq!(config.metadata.get("key3"), Some(&"new3".to_string()));
631
632 assert!(matches!(config.tool_config.visibility, ToolVisibility::All));
634 assert!(matches!(
635 config.tool_config.approval_policy,
636 ToolApprovalPolicy::AlwaysAsk
637 ));
638 }
639
640 #[tokio::test]
641 async fn test_complex_tool_visibility_whitelist() {
642 use std::io::Write;
643 use tempfile::NamedTempFile;
644
645 let mut temp_file = NamedTempFile::new().unwrap();
646 writeln!(
647 temp_file,
648 r#"
649[tool_config]
650visibility = {{ whitelist = ["grep", "ls", "view"] }}
651"#
652 )
653 .unwrap();
654
655 let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
656 let config = loader.load().await.unwrap();
657
658 match &config.tool_config.visibility {
659 ToolVisibility::Whitelist(tools) => {
660 assert_eq!(tools.len(), 3);
661 assert!(tools.contains("grep"));
662 assert!(tools.contains("ls"));
663 assert!(tools.contains("view"));
664 }
665 _ => unreachable!("Expected Whitelist visibility"),
666 }
667 }
668
669 #[tokio::test]
670 async fn test_complex_tool_visibility_blacklist() {
671 use std::io::Write;
672 use tempfile::NamedTempFile;
673
674 let mut temp_file = NamedTempFile::new().unwrap();
675 writeln!(
676 temp_file,
677 r#"
678[tool_config]
679visibility = {{ blacklist = ["bash", "edit_file"] }}
680"#
681 )
682 .unwrap();
683
684 let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
685 let config = loader.load().await.unwrap();
686
687 match &config.tool_config.visibility {
688 ToolVisibility::Blacklist(tools) => {
689 assert_eq!(tools.len(), 2);
690 assert!(tools.contains("bash"));
691 assert!(tools.contains("edit_file"));
692 }
693 _ => unreachable!("Expected Blacklist visibility"),
694 }
695 }
696
697 #[tokio::test]
698 async fn test_workspace_remote_config() {
699 use std::io::Write;
700 use tempfile::NamedTempFile;
701
702 let mut temp_file = NamedTempFile::new().unwrap();
703 writeln!(
704 temp_file,
705 r#"
706[workspace]
707type = "remote"
708agent_address = "192.168.1.100:50051"
709auth = {{ Bearer = {{ token = "secret-token" }} }}
710"#
711 )
712 .unwrap();
713
714 let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
715 let config = loader.load().await.unwrap();
716
717 match &config.workspace {
718 WorkspaceConfig::Remote {
719 agent_address,
720 auth,
721 } => {
722 assert_eq!(agent_address, "192.168.1.100:50051");
723 assert!(auth.is_some());
724 match auth.as_ref().unwrap() {
725 RemoteAuth::Bearer { token } => {
726 assert_eq!(token, "secret-token");
727 }
728 _ => unreachable!("Expected Bearer auth"),
729 }
730 }
731 _ => unreachable!("Expected Remote workspace"),
732 }
733 }
734
735 #[tokio::test]
736 async fn test_bash_tool_config_with_approved_patterns() {
737 use std::io::Write;
738 use tempfile::NamedTempFile;
739
740 let mut temp_file = NamedTempFile::new().unwrap();
741 writeln!(
742 temp_file,
743 r#"
744[tool_config.tools.bash]
745approved_patterns = [
746 "git status",
747 "git log*",
748 "npm run*",
749 "cargo build*"
750]
751"#
752 )
753 .unwrap();
754
755 let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
756 let config = loader.load().await.unwrap();
757
758 let bash_config = config.tool_config.tools.get("bash");
760 assert!(bash_config.is_some(), "Bash config should be present");
761
762 match bash_config.unwrap() {
763 ToolSpecificConfig::Bash(bash) => {
764 assert_eq!(bash.approved_patterns.len(), 4);
765 assert_eq!(bash.approved_patterns[0], "git status");
766 assert_eq!(bash.approved_patterns[1], "git log*");
767 assert_eq!(bash.approved_patterns[2], "npm run*");
768 assert_eq!(bash.approved_patterns[3], "cargo build*");
769 }
770 }
771 }
772
773 #[tokio::test]
774 async fn test_bash_tool_config_empty_patterns() {
775 use std::io::Write;
776 use tempfile::NamedTempFile;
777
778 let mut temp_file = NamedTempFile::new().unwrap();
779 writeln!(
780 temp_file,
781 r#"
782[tool_config.tools.bash]
783approved_patterns = []
784"#
785 )
786 .unwrap();
787
788 let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
789 let config = loader.load().await.unwrap();
790
791 let bash_config = config.tool_config.tools.get("bash");
792 assert!(bash_config.is_some());
793
794 match bash_config.unwrap() {
795 ToolSpecificConfig::Bash(bash) => {
796 assert_eq!(bash.approved_patterns.len(), 0);
797 }
798 }
799 }
800
801 #[tokio::test]
802 async fn test_bash_tool_config_with_other_tools() {
803 use std::io::Write;
804 use tempfile::NamedTempFile;
805
806 let mut temp_file = NamedTempFile::new().unwrap();
807 writeln!(
808 temp_file,
809 r#"
810[tool_config]
811visibility = "all"
812approval_policy = "always_ask"
813
814[tool_config.tools.bash]
815approved_patterns = ["ls -la", "pwd"]
816"#
817 )
818 .unwrap();
819
820 let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
821 let config = loader.load().await.unwrap();
822
823 assert!(matches!(config.tool_config.visibility, ToolVisibility::All));
825 assert!(matches!(
826 config.tool_config.approval_policy,
827 ToolApprovalPolicy::AlwaysAsk
828 ));
829
830 let bash_config = config.tool_config.tools.get("bash");
832 assert!(bash_config.is_some());
833
834 match bash_config.unwrap() {
835 ToolSpecificConfig::Bash(bash) => {
836 assert_eq!(bash.approved_patterns.len(), 2);
837 assert_eq!(bash.approved_patterns[0], "ls -la");
838 assert_eq!(bash.approved_patterns[1], "pwd");
839 }
840 }
841 }
842
843 #[tokio::test]
844 async fn test_bash_tool_config_without_approved_patterns() {
845 use std::io::Write;
846 use tempfile::NamedTempFile;
847
848 let mut temp_file = NamedTempFile::new().unwrap();
849 writeln!(
850 temp_file,
851 r#"
852[tool_config.tools.bash]
853# No approved_patterns field
854"#
855 )
856 .unwrap();
857
858 let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
859 let config = loader.load().await.unwrap();
860
861 let bash_config = config.tool_config.tools.get("bash");
862 assert!(bash_config.is_some());
863
864 match bash_config.unwrap() {
865 ToolSpecificConfig::Bash(bash) => {
866 assert_eq!(bash.approved_patterns.len(), 0);
868 }
869 }
870 }
871
872 #[tokio::test]
873 async fn test_full_config_with_bash_patterns() {
874 use std::io::Write;
875 use tempfile::NamedTempFile;
876
877 let mut temp_file = NamedTempFile::new().unwrap();
878 writeln!(
879 temp_file,
880 r#"
881system_prompt = "You are a helpful assistant"
882
883[workspace]
884type = "local"
885
886[tool_config]
887visibility = "all"
888approval_policy = {{ type = "pre_approved", tools = ["grep", "ls", "view"] }}
889backends = []
890
891[tool_config.tools.bash]
892approved_patterns = [
893 "git status",
894 "git diff",
895 "git log --oneline",
896 "npm test",
897 "cargo check"
898]
899
900[metadata]
901project = "test-project"
902"#
903 )
904 .unwrap();
905
906 let loader = SessionConfigLoader::new(Some(temp_file.path().to_path_buf()));
907 let config = loader.load().await.unwrap();
908
909 assert_eq!(
911 config.system_prompt,
912 Some("You are a helpful assistant".to_string())
913 );
914 assert!(matches!(config.workspace, WorkspaceConfig::Local { .. }));
915 assert_eq!(
916 config.metadata.get("project"),
917 Some(&"test-project".to_string())
918 );
919
920 match &config.tool_config.approval_policy {
922 ToolApprovalPolicy::PreApproved { tools } => {
923 assert_eq!(tools.len(), 3);
924 assert!(tools.contains("grep"));
925 assert!(tools.contains("ls"));
926 assert!(tools.contains("view"));
927 }
928 _ => unreachable!("Expected PreApproved policy"),
929 }
930
931 let bash_config = config.tool_config.tools.get("bash");
933 assert!(bash_config.is_some());
934
935 match bash_config.unwrap() {
936 ToolSpecificConfig::Bash(bash) => {
937 assert_eq!(bash.approved_patterns.len(), 5);
938 assert_eq!(bash.approved_patterns[0], "git status");
939 assert_eq!(bash.approved_patterns[1], "git diff");
940 assert_eq!(bash.approved_patterns[2], "git log --oneline");
941 assert_eq!(bash.approved_patterns[3], "npm test");
942 assert_eq!(bash.approved_patterns[4], "cargo check");
943 }
944 }
945 }
946}