steer_core/utils/
session.rs1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use crate::error::{Error, Result};
6use crate::session::{
7 Session, SessionStoreConfig,
8 state::{SessionConfig, SessionToolConfig, ToolApprovalPolicy, WorkspaceConfig},
9 store::SessionStore,
10};
11
12pub fn create_session_store_path() -> Result<std::path::PathBuf> {
13 let home_dir = dirs::home_dir()
14 .ok_or_else(|| Error::Configuration("Could not determine home directory".to_string()))?;
15 let db_path = home_dir.join(".steer").join("sessions.db");
16 Ok(db_path)
17}
18
19pub fn resolve_session_store_config(
22 session_db_path: Option<PathBuf>,
23) -> Result<SessionStoreConfig> {
24 match session_db_path {
25 Some(path) => Ok(SessionStoreConfig::sqlite(path)),
26 None => SessionStoreConfig::default_sqlite()
27 .map_err(|e| Error::Configuration(format!("Failed to get default sqlite config: {e}"))),
28 }
29}
30
31pub async fn create_session_store() -> Result<Arc<dyn SessionStore>> {
32 let config = SessionStoreConfig::default();
33 create_session_store_with_config(config).await
34}
35
36pub async fn create_session_store_with_config(
37 config: SessionStoreConfig,
38) -> Result<Arc<dyn SessionStore>> {
39 use crate::session::stores::sqlite::SqliteSessionStore;
40
41 match config {
42 SessionStoreConfig::Sqlite { path } => {
43 if let Some(parent) = path.parent() {
45 std::fs::create_dir_all(parent)?;
46 }
47
48 let store = SqliteSessionStore::new(&path).await?;
49
50 Ok(Arc::new(store))
51 }
52 _ => Err(Error::Configuration(
53 "Unsupported session store type".to_string(),
54 )),
55 }
56}
57
58pub fn create_default_session_config() -> SessionConfig {
59 SessionConfig {
60 workspace: WorkspaceConfig::default(),
61 tool_config: SessionToolConfig::default(),
62 system_prompt: None,
63 metadata: HashMap::new(),
64 }
65}
66
67pub fn parse_tool_policy(
68 policy_str: &str,
69 pre_approved_tools: Option<&str>,
70) -> Result<ToolApprovalPolicy> {
71 match policy_str {
72 "always_ask" => Ok(ToolApprovalPolicy::AlwaysAsk),
73 "pre_approved" => {
74 let tools = if let Some(tools_str) = pre_approved_tools {
75 tools_str.split(',').map(|s| s.trim().to_string()).collect()
76 } else {
77 return Err(Error::Configuration(
78 "pre_approved_tools is required when using pre_approved policy".to_string(),
79 ));
80 };
81 Ok(ToolApprovalPolicy::PreApproved { tools })
82 }
83 "mixed" => {
84 let tools = if let Some(tools_str) = pre_approved_tools {
85 tools_str.split(',').map(|s| s.trim().to_string()).collect()
86 } else {
87 std::collections::HashSet::new()
88 };
89 Ok(ToolApprovalPolicy::Mixed {
90 pre_approved: tools,
91 ask_for_others: true,
92 })
93 }
94 _ => Err(Error::Configuration(format!(
95 "Invalid tool policy: {policy_str}. Valid options: always_ask, pre_approved, mixed"
96 ))),
97 }
98}
99
100pub fn parse_metadata(metadata_str: Option<&str>) -> Result<HashMap<String, String>> {
101 let mut metadata = HashMap::new();
102
103 if let Some(meta_str) = metadata_str {
104 for pair in meta_str.split(',') {
105 let parts: Vec<&str> = pair.split('=').collect();
106 if parts.len() != 2 {
107 return Err(Error::Configuration(
108 "Invalid metadata format. Expected key=value pairs separated by commas"
109 .to_string(),
110 ));
111 }
112 metadata.insert(parts[0].trim().to_string(), parts[1].trim().to_string());
113 }
114 }
115
116 Ok(metadata)
117}
118
119pub fn create_mock_session(id: &str, tool_policy: ToolApprovalPolicy) -> Session {
120 let tool_config = SessionToolConfig {
121 approval_policy: tool_policy,
122 ..Default::default()
123 };
124
125 let config = SessionConfig {
126 workspace: WorkspaceConfig::default(),
127 tool_config,
128 system_prompt: None,
129 metadata: std::collections::HashMap::new(),
130 };
131 Session::new(id.to_string(), config)
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_resolve_session_store_config_with_path() {
        let expected = PathBuf::from("/custom/path/sessions.db");
        let resolved = resolve_session_store_config(Some(expected.clone())).unwrap();

        match resolved {
            SessionStoreConfig::Sqlite { path } => assert_eq!(path, expected),
            _ => unreachable!("SQLite config"),
        }
    }

    #[test]
    fn test_resolve_session_store_config_without_path() {
        let resolved = resolve_session_store_config(None).unwrap();

        match resolved {
            SessionStoreConfig::Sqlite { path } => {
                let rendered = path.to_string_lossy();
                assert!(rendered.contains("sessions.db"));
            }
            _ => unreachable!("SQLite config"),
        }
    }

    #[test]
    fn test_parse_tool_policy() {
        assert!(matches!(
            parse_tool_policy("always_ask", None).unwrap(),
            ToolApprovalPolicy::AlwaysAsk
        ));

        match parse_tool_policy("pre_approved", Some("tool1,tool2")).unwrap() {
            ToolApprovalPolicy::PreApproved { tools } => {
                assert_eq!(tools.len(), 2);
                for name in ["tool1", "tool2"] {
                    assert!(tools.contains(name));
                }
            }
            _ => unreachable!("PreApproved policy"),
        }

        match parse_tool_policy("mixed", Some("tool3,tool4")).unwrap() {
            ToolApprovalPolicy::Mixed { pre_approved, .. } => {
                assert_eq!(pre_approved.len(), 2);
                for name in ["tool3", "tool4"] {
                    assert!(pre_approved.contains(name));
                }
            }
            _ => unreachable!("Mixed policy"),
        }

        assert!(parse_tool_policy("invalid", None).is_err());
    }

    #[test]
    fn test_parse_metadata() {
        let parsed = parse_metadata(Some("key1=value1,key2=value2")).unwrap();
        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed.get("key1").map(String::as_str), Some("value1"));
        assert_eq!(parsed.get("key2").map(String::as_str), Some("value2"));

        assert!(parse_metadata(None).unwrap().is_empty());

        assert!(parse_metadata(Some("invalid_format")).is_err());
    }
}