steer_core/utils/session.rs

1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use crate::error::{Error, Result};
6use crate::session::{
7    Session, SessionStoreConfig,
8    state::{SessionConfig, SessionToolConfig, ToolApprovalPolicy, WorkspaceConfig},
9    store::SessionStore,
10};
11
12pub fn create_session_store_path() -> Result<std::path::PathBuf> {
13    let home_dir = dirs::home_dir()
14        .ok_or_else(|| Error::Configuration("Could not determine home directory".to_string()))?;
15    let db_path = home_dir.join(".steer").join("sessions.db");
16    Ok(db_path)
17}
18
19/// Resolve session store configuration from an optional path
20/// If no path is provided, uses the default SQLite configuration
21pub fn resolve_session_store_config(
22    session_db_path: Option<PathBuf>,
23) -> Result<SessionStoreConfig> {
24    match session_db_path {
25        Some(path) => Ok(SessionStoreConfig::sqlite(path)),
26        None => SessionStoreConfig::default_sqlite()
27            .map_err(|e| Error::Configuration(format!("Failed to get default sqlite config: {e}"))),
28    }
29}
30
31pub async fn create_session_store() -> Result<Arc<dyn SessionStore>> {
32    let config = SessionStoreConfig::default();
33    create_session_store_with_config(config).await
34}
35
36pub async fn create_session_store_with_config(
37    config: SessionStoreConfig,
38) -> Result<Arc<dyn SessionStore>> {
39    use crate::session::stores::sqlite::SqliteSessionStore;
40
41    match config {
42        SessionStoreConfig::Sqlite { path } => {
43            // Create directory if it doesn't exist
44            if let Some(parent) = path.parent() {
45                std::fs::create_dir_all(parent)?;
46            }
47
48            let store = SqliteSessionStore::new(&path).await?;
49
50            Ok(Arc::new(store))
51        }
52        _ => Err(Error::Configuration(
53            "Unsupported session store type".to_string(),
54        )),
55    }
56}
57
58pub fn create_default_session_config() -> SessionConfig {
59    SessionConfig {
60        workspace: WorkspaceConfig::default(),
61        tool_config: SessionToolConfig::default(),
62        system_prompt: None,
63        metadata: HashMap::new(),
64    }
65}
66
67pub fn parse_tool_policy(
68    policy_str: &str,
69    pre_approved_tools: Option<&str>,
70) -> Result<ToolApprovalPolicy> {
71    match policy_str {
72        "always_ask" => Ok(ToolApprovalPolicy::AlwaysAsk),
73        "pre_approved" => {
74            let tools = if let Some(tools_str) = pre_approved_tools {
75                tools_str.split(',').map(|s| s.trim().to_string()).collect()
76            } else {
77                return Err(Error::Configuration(
78                    "pre_approved_tools is required when using pre_approved policy".to_string(),
79                ));
80            };
81            Ok(ToolApprovalPolicy::PreApproved { tools })
82        }
83        "mixed" => {
84            let tools = if let Some(tools_str) = pre_approved_tools {
85                tools_str.split(',').map(|s| s.trim().to_string()).collect()
86            } else {
87                std::collections::HashSet::new()
88            };
89            Ok(ToolApprovalPolicy::Mixed {
90                pre_approved: tools,
91                ask_for_others: true,
92            })
93        }
94        _ => Err(Error::Configuration(format!(
95            "Invalid tool policy: {policy_str}. Valid options: always_ask, pre_approved, mixed"
96        ))),
97    }
98}
99
100pub fn parse_metadata(metadata_str: Option<&str>) -> Result<HashMap<String, String>> {
101    let mut metadata = HashMap::new();
102
103    if let Some(meta_str) = metadata_str {
104        for pair in meta_str.split(',') {
105            let parts: Vec<&str> = pair.split('=').collect();
106            if parts.len() != 2 {
107                return Err(Error::Configuration(
108                    "Invalid metadata format. Expected key=value pairs separated by commas"
109                        .to_string(),
110                ));
111            }
112            metadata.insert(parts[0].trim().to_string(), parts[1].trim().to_string());
113        }
114    }
115
116    Ok(metadata)
117}
118
119pub fn create_mock_session(id: &str, tool_policy: ToolApprovalPolicy) -> Session {
120    let tool_config = SessionToolConfig {
121        approval_policy: tool_policy,
122        ..Default::default()
123    };
124
125    let config = SessionConfig {
126        workspace: WorkspaceConfig::default(),
127        tool_config,
128        system_prompt: None,
129        metadata: std::collections::HashMap::new(),
130    };
131    Session::new(id.to_string(), config)
132}
133
#[cfg(test)]
mod tests {
    // Unit tests for the session helpers in this module.
    use super::*;

    #[test]
    fn test_resolve_session_store_config_with_path() {
        let expected = PathBuf::from("/custom/path/sessions.db");

        match resolve_session_store_config(Some(expected.clone())).unwrap() {
            SessionStoreConfig::Sqlite { path } => assert_eq!(path, expected),
            _ => unreachable!("SQLite config"),
        }
    }

    #[test]
    fn test_resolve_session_store_config_without_path() {
        match resolve_session_store_config(None).unwrap() {
            SessionStoreConfig::Sqlite { path } => {
                assert!(path.to_string_lossy().contains("sessions.db"))
            }
            _ => unreachable!("SQLite config"),
        }
    }

    #[test]
    fn test_parse_tool_policy() {
        // always_ask needs no tool list.
        assert!(matches!(
            parse_tool_policy("always_ask", None).unwrap(),
            ToolApprovalPolicy::AlwaysAsk
        ));

        // pre_approved keeps every listed tool.
        match parse_tool_policy("pre_approved", Some("tool1,tool2")).unwrap() {
            ToolApprovalPolicy::PreApproved { tools } => {
                assert_eq!(tools.len(), 2);
                for name in ["tool1", "tool2"] {
                    assert!(tools.contains(name));
                }
            }
            _ => unreachable!("PreApproved policy"),
        }

        // mixed keeps every listed tool as pre-approved.
        match parse_tool_policy("mixed", Some("tool3,tool4")).unwrap() {
            ToolApprovalPolicy::Mixed { pre_approved, .. } => {
                assert_eq!(pre_approved.len(), 2);
                for name in ["tool3", "tool4"] {
                    assert!(pre_approved.contains(name));
                }
            }
            _ => unreachable!("Mixed policy"),
        }

        // Unknown policy names are rejected.
        assert!(parse_tool_policy("invalid", None).is_err());
    }

    #[test]
    fn test_parse_metadata() {
        // Well-formed key=value pairs.
        let parsed = parse_metadata(Some("key1=value1,key2=value2")).unwrap();
        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed.get("key1"), Some(&"value1".to_string()));
        assert_eq!(parsed.get("key2"), Some(&"value2".to_string()));

        // No input yields an empty map.
        assert!(parse_metadata(None).unwrap().is_empty());

        // A segment without '=' is rejected.
        assert!(parse_metadata(Some("invalid_format")).is_err());
    }
}