1use crate::config::model::ModelId;
2use crate::error::Result;
3use crate::tools::static_tools::READ_ONLY_TOOL_NAMES;
4use chrono::{DateTime, Utc};
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, HashSet};
8use std::path::PathBuf;
9use std::sync::Arc;
10
11use crate::app::{Message, MessageData};
12use crate::tools::{BackendRegistry, McpTransport, ToolBackend};
13use steer_tools::{ToolCall, result::ToolResult};
14
/// Connection status of a single MCP server backend.
///
/// `SessionConfig::build_registry` starts each server in `Connecting` and then
/// replaces it with `Connected` or `Failed` depending on initialization.
#[derive(Debug, Clone)]
pub enum McpConnectionState {
    /// Backend initialization is in progress.
    Connecting,
    /// The backend initialized successfully.
    Connected {
        /// Tool names reported by the backend's `supported_tools()`.
        tool_names: Vec<String>,
    },
    /// The server is not connected.
    // NOTE(review): not constructed in this file; presumably set when a live
    // connection drops — confirm with callers.
    Disconnected {
        /// Optional human-readable reason for the disconnect.
        reason: Option<String>,
    },
    /// Backend initialization failed.
    Failed {
        /// Stringified initialization error.
        error: String,
    },
}
36
/// Runtime status record for one configured MCP server.
#[derive(Debug, Clone)]
pub struct McpServerInfo {
    /// Server name taken from the backend configuration.
    pub server_name: String,
    /// Transport used to reach the server.
    pub transport: McpTransport,
    /// Current connection state.
    pub state: McpConnectionState,
    /// When `state` was last updated.
    pub last_updated: DateTime<Utc>,
}
49
/// Where a session's workspace lives: a local directory or a remote agent.
// Serialized internally tagged, e.g. `{"type": "local", "path": ...}` or
// `{"type": "remote", "agent_address": ..., "auth": ...}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum WorkspaceConfig {
    /// A workspace rooted at a local filesystem path.
    Local {
        path: PathBuf,
    },
    /// A workspace served by a remote agent.
    Remote {
        /// Address of the remote agent.
        agent_address: String,
        /// Optional credentials for the remote agent.
        auth: Option<RemoteAuth>,
    },
}
62
63impl WorkspaceConfig {
64 pub fn get_path(&self) -> Option<String> {
65 match self {
66 WorkspaceConfig::Local { path } => Some(path.to_string_lossy().to_string()),
67 WorkspaceConfig::Remote { agent_address, .. } => Some(agent_address.clone()),
68 }
69 }
70
71 pub fn to_workspace_config(&self) -> steer_workspace::WorkspaceConfig {
73 match self {
74 WorkspaceConfig::Local { path } => {
75 steer_workspace::WorkspaceConfig::Local { path: path.clone() }
76 }
77 WorkspaceConfig::Remote {
78 agent_address,
79 auth,
80 } => steer_workspace::WorkspaceConfig::Remote {
81 address: agent_address.clone(),
82 auth: auth.as_ref().map(|a| a.to_workspace_auth()),
83 },
84 }
85 }
86}
87
88impl Default for WorkspaceConfig {
89 fn default() -> Self {
90 Self::Local {
91 path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
92 }
93 }
94}
95
/// A persisted session: identity, timestamps, configuration and mutable state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Session identifier.
    pub id: String,
    /// When the session was created.
    pub created_at: DateTime<Utc>,
    /// Last activity time; refreshed by `update_timestamp`.
    pub updated_at: DateTime<Utc>,
    /// Configuration the session was created with.
    pub config: SessionConfig,
    /// Conversation and tool-call state.
    pub state: SessionState,
}
105
106impl Session {
107 pub fn new(id: String, config: SessionConfig) -> Self {
108 let now = Utc::now();
109 Self {
110 id,
111 created_at: now,
112 updated_at: now,
113 config,
114 state: SessionState::default(),
115 }
116 }
117
118 pub fn update_timestamp(&mut self) {
119 self.updated_at = Utc::now();
120 }
121
122 pub fn is_recently_active(&self, threshold: chrono::Duration) -> bool {
124 let cutoff = Utc::now() - threshold;
125 self.updated_at > cutoff
126 }
127
128 pub async fn build_workspace(&self) -> Result<Arc<dyn crate::workspace::Workspace>> {
130 crate::workspace::create_workspace(&self.config.workspace.to_workspace_config()).await
131 }
132}
133
/// Static configuration a session is created with.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SessionConfig {
    /// Where the session's workspace lives.
    pub workspace: WorkspaceConfig,
    #[serde(default)]
    pub workspace_ref: Option<crate::workspace::WorkspaceRef>,
    #[serde(default)]
    pub workspace_id: Option<crate::workspace::WorkspaceId>,
    #[serde(default)]
    pub repo_ref: Option<crate::workspace::RepoRef>,
    /// Session this one was derived from, if any.
    #[serde(default)]
    pub parent_session_id: Option<crate::app::domain::types::SessionId>,
    #[serde(default)]
    pub workspace_name: Option<String>,
    /// Tool backends, visibility and approval policy.
    pub tool_config: SessionToolConfig,
    /// Optional system prompt.
    pub system_prompt: Option<String>,
    #[serde(default)]
    pub primary_agent_id: Option<String>,
    /// Per-session policy overrides; missing in serialized data means "no overrides".
    #[serde(default = "SessionPolicyOverrides::empty")]
    pub policy_overrides: SessionPolicyOverrides,
    /// Free-form key/value metadata.
    pub metadata: HashMap<String, String>,
    /// Model used by default for this session.
    pub default_model: ModelId,
}
161
162impl SessionConfig {
163 pub async fn build_registry(
166 &self,
167 ) -> Result<(BackendRegistry, HashMap<String, McpServerInfo>)> {
168 let mut registry = BackendRegistry::new();
169 let mut mcp_servers = HashMap::new();
170
171 for backend_config in &self.tool_config.backends {
172 let BackendConfig::Mcp {
173 server_name,
174 transport,
175 tool_filter,
176 } = backend_config;
177
178 tracing::info!(
179 "Attempting to initialize MCP backend '{}' with transport: {:?}",
180 server_name,
181 transport
182 );
183
184 let mut server_info = McpServerInfo {
185 server_name: server_name.clone(),
186 transport: transport.clone(),
187 state: McpConnectionState::Connecting,
188 last_updated: Utc::now(),
189 };
190
191 match crate::tools::McpBackend::new(
192 server_name.clone(),
193 transport.clone(),
194 tool_filter.clone(),
195 )
196 .await
197 {
198 Ok(mcp_backend) => {
199 let tool_names = mcp_backend.supported_tools().await;
200 let tool_count = tool_names.len();
201 tracing::info!(
202 "Successfully initialized MCP backend '{}' with {} tools",
203 server_name,
204 tool_count
205 );
206 server_info.state = McpConnectionState::Connected { tool_names };
207 server_info.last_updated = Utc::now();
208 registry
209 .register(format!("mcp_{server_name}"), Arc::new(mcp_backend))
210 .await;
211 }
212 Err(e) => {
213 tracing::error!("Failed to initialize MCP backend '{}': {}", server_name, e);
214 server_info.state = McpConnectionState::Failed {
215 error: e.to_string(),
216 };
217 server_info.last_updated = Utc::now();
218 }
219 }
220
221 mcp_servers.insert(server_name.clone(), server_info);
222 }
223
224 Ok((registry, mcp_servers))
225 }
226
227 pub fn filter_tools_by_visibility(
229 &self,
230 tools: Vec<steer_tools::ToolSchema>,
231 ) -> Vec<steer_tools::ToolSchema> {
232 match &self.tool_config.visibility {
233 ToolVisibility::All => tools,
234 ToolVisibility::ReadOnly => {
235 let read_only_names: HashSet<String> = READ_ONLY_TOOL_NAMES
236 .iter()
237 .map(|name| (*name).to_string())
238 .collect();
239
240 tools
241 .into_iter()
242 .filter(|schema| read_only_names.contains(&schema.name))
243 .collect()
244 }
245 ToolVisibility::Whitelist(allowed) => tools
246 .into_iter()
247 .filter(|schema| allowed.contains(&schema.name))
248 .collect(),
249 ToolVisibility::Blacklist(blocked) => tools
250 .into_iter()
251 .filter(|schema| !blocked.contains(&schema.name))
252 .collect(),
253 }
254 }
255
256 #[cfg(test)]
258 pub fn read_only(default_model: ModelId) -> Self {
259 Self {
260 workspace: WorkspaceConfig::Local {
261 path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
262 },
263 workspace_ref: None,
264 workspace_id: None,
265 repo_ref: None,
266 parent_session_id: None,
267 workspace_name: None,
268 tool_config: SessionToolConfig::read_only(),
269 system_prompt: None,
270 primary_agent_id: None,
271 policy_overrides: SessionPolicyOverrides::empty(),
272 metadata: HashMap::new(),
273 default_model,
274 }
275 }
276}
277
/// Optional per-session overrides layered on top of the base tool/model policy.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SessionPolicyOverrides {
    /// Replacement default model; omitted from serialized output when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub default_model: Option<ModelId>,
    /// Replacement tool visibility; omitted from serialized output when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_visibility: Option<ToolVisibility>,
    /// Approval-policy overrides; missing in serialized data means "none".
    #[serde(default = "ToolApprovalPolicyOverrides::empty")]
    pub approval_policy: ToolApprovalPolicyOverrides,
}
288
289impl SessionPolicyOverrides {
290 pub fn empty() -> Self {
291 Self {
292 default_model: None,
293 tool_visibility: None,
294 approval_policy: ToolApprovalPolicyOverrides::empty(),
295 }
296 }
297
298 pub fn is_empty(&self) -> bool {
299 self.default_model.is_none()
300 && self.tool_visibility.is_none()
301 && self.approval_policy.is_empty()
302 }
303}
304
/// Overrides merged into a base `ToolApprovalPolicy` by `apply_to`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ToolApprovalPolicyOverrides {
    /// Replacement behaviour for tools that are not preapproved, if set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub default_behavior: Option<UnapprovedBehavior>,
    /// Additional preapproval rules, unioned with the base rules.
    #[serde(default = "ApprovalRulesOverrides::empty")]
    pub preapproved: ApprovalRulesOverrides,
}
312
313impl ToolApprovalPolicyOverrides {
314 pub fn empty() -> Self {
315 Self {
316 default_behavior: None,
317 preapproved: ApprovalRulesOverrides::empty(),
318 }
319 }
320
321 pub fn is_empty(&self) -> bool {
322 self.default_behavior.is_none() && self.preapproved.is_empty()
323 }
324}
325
/// Extra preapproval rules contributed by an override.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ApprovalRulesOverrides {
    /// Tool names to add to the preapproved set.
    #[serde(default)]
    pub tools: HashSet<String>,
    /// Per-tool pattern rules, keyed by tool name.
    #[serde(default)]
    pub per_tool: HashMap<String, ToolRuleOverrides>,
}
333
334impl ApprovalRulesOverrides {
335 pub fn empty() -> Self {
336 Self {
337 tools: HashSet::new(),
338 per_tool: HashMap::new(),
339 }
340 }
341
342 pub fn is_empty(&self) -> bool {
343 self.tools.is_empty() && self.per_tool.is_empty()
344 }
345}
346
/// Override counterpart of `ToolRule`; patterns are unioned with any matching base rule.
// Serialized internally tagged, e.g. `{"type": "bash", "patterns": [...]}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolRuleOverrides {
    /// Extra bash command patterns.
    Bash { patterns: Vec<String> },
    /// Extra dispatch-agent id patterns.
    DispatchAgent { agent_patterns: Vec<String> },
}
353
/// Which tools are exposed to the model.
// Serialized adjacently tagged, e.g. `{"type": "all"}` or
// `{"type": "whitelist", "tools": ["..."]}`.
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", content = "tools", rename_all = "snake_case")]
pub enum ToolVisibility {
    /// Every tool is visible.
    #[default]
    All,

    /// Only tools named in `READ_ONLY_TOOL_NAMES` are visible.
    ReadOnly,

    /// Only the listed tools are visible.
    Whitelist(HashSet<String>),

    /// Every tool except the listed ones is visible.
    Blacklist(HashSet<String>),
}
371
/// Outcome of `ToolApprovalPolicy::tool_decision` for a tool call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolDecision {
    /// Run without asking.
    Allow,
    /// Ask the user for approval first.
    Ask,
    /// Refuse the call.
    Deny,
}
378
/// What to do with a tool that is not preapproved.
///
/// Serialized as `"prompt"`, `"deny"` or `"allow"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum UnapprovedBehavior {
    /// Ask the user (the default).
    #[default]
    Prompt,
    /// Refuse automatically.
    Deny,
    /// Allow automatically.
    Allow,
}
387
/// Argument-level preapproval rule for a specific tool.
// Serialized internally tagged, e.g. `{"type": "dispatch_agent", "agent_patterns": [...]}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolRule {
    /// Bash commands matching any pattern (exact or glob) are preapproved.
    Bash { patterns: Vec<String> },
    /// Sub-agent ids matching any pattern (exact or glob) are preapproved.
    DispatchAgent { agent_patterns: Vec<String> },
}
394
/// Preapproval rules: whole tools plus per-tool argument patterns.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Default)]
pub struct ApprovalRules {
    /// Tool names that are always allowed.
    #[serde(default)]
    pub tools: HashSet<String>,
    /// Pattern rules keyed by tool name (looked up as `"bash"` and `"dispatch_agent"`).
    #[serde(default)]
    pub per_tool: HashMap<String, ToolRule>,
}
402
403impl ApprovalRules {
404 pub fn is_empty(&self) -> bool {
405 self.tools.is_empty() && self.per_tool.is_empty()
406 }
407
408 pub fn bash_patterns(&self) -> Option<&[String]> {
409 self.per_tool.get("bash").and_then(|rule| match rule {
410 ToolRule::Bash { patterns } => Some(patterns.as_slice()),
411 ToolRule::DispatchAgent { .. } => None,
412 })
413 }
414
415 pub fn dispatch_agent_rule(&self) -> Option<&[String]> {
416 self.per_tool
417 .get("dispatch_agent")
418 .and_then(|rule| match rule {
419 ToolRule::DispatchAgent { agent_patterns } => Some(agent_patterns.as_slice()),
420 ToolRule::Bash { .. } => None,
421 })
422 }
423}
424
/// Decides whether tool calls run automatically, need approval, or are denied.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub struct ToolApprovalPolicy {
    /// Behaviour for tools not covered by `preapproved`.
    pub default_behavior: UnapprovedBehavior,
    /// Rules that allow tools without asking.
    // NOTE: when this field is missing from serialized data, serde fills it
    // with `ApprovalRules::default()` (empty), not the read-only set that
    // `ToolApprovalPolicy::default()` preapproves.
    #[serde(default)]
    pub preapproved: ApprovalRules,
}
431
432impl Default for ToolApprovalPolicy {
433 fn default() -> Self {
434 Self {
435 default_behavior: UnapprovedBehavior::Prompt,
436 preapproved: ApprovalRules {
437 tools: READ_ONLY_TOOL_NAMES
438 .iter()
439 .map(|name| (*name).to_string())
440 .collect(),
441 per_tool: HashMap::new(),
442 },
443 }
444 }
445}
446
447impl ToolApprovalPolicy {
448 pub fn tool_decision(&self, tool_name: &str) -> ToolDecision {
449 if self.preapproved.tools.contains(tool_name) {
450 ToolDecision::Allow
451 } else {
452 match self.default_behavior {
453 UnapprovedBehavior::Prompt => ToolDecision::Ask,
454 UnapprovedBehavior::Deny => ToolDecision::Deny,
455 UnapprovedBehavior::Allow => ToolDecision::Allow,
456 }
457 }
458 }
459
460 pub fn is_bash_pattern_preapproved(&self, command: &str) -> bool {
461 let Some(patterns) = self.preapproved.bash_patterns() else {
462 return false;
463 };
464 patterns.iter().any(|pattern| {
465 if pattern == command {
466 return true;
467 }
468 glob::Pattern::new(pattern)
469 .map(|glob| glob.matches(command))
470 .unwrap_or(false)
471 })
472 }
473
474 pub fn is_dispatch_agent_pattern_preapproved(&self, agent_id: &str) -> bool {
475 let Some(patterns) = self.preapproved.dispatch_agent_rule() else {
476 return false;
477 };
478 patterns.iter().any(|pattern| {
479 if pattern == agent_id {
480 return true;
481 }
482 glob::Pattern::new(pattern)
483 .map(|glob| glob.matches(agent_id))
484 .unwrap_or(false)
485 })
486 }
487
488 pub fn pre_approved_tools(&self) -> &HashSet<String> {
489 &self.preapproved.tools
490 }
491}
492
493impl ToolApprovalPolicyOverrides {
494 pub fn apply_to(&self, base: &ToolApprovalPolicy) -> ToolApprovalPolicy {
495 let mut merged = base.clone();
496
497 if let Some(default_behavior) = self.default_behavior {
498 merged.default_behavior = default_behavior;
499 }
500
501 if !self.preapproved.tools.is_empty() {
502 merged
503 .preapproved
504 .tools
505 .extend(self.preapproved.tools.iter().cloned());
506 }
507
508 for (tool_name, override_rule) in &self.preapproved.per_tool {
509 let base_rule = merged.preapproved.per_tool.get(tool_name);
510 let merged_rule = merge_tool_rule_override(base_rule, override_rule);
511 merged
512 .preapproved
513 .per_tool
514 .insert(tool_name.clone(), merged_rule);
515 }
516
517 merged
518 }
519}
520
521fn merge_tool_rule_override(
522 base: Option<&ToolRule>,
523 override_rule: &ToolRuleOverrides,
524) -> ToolRule {
525 match (base, override_rule) {
526 (Some(ToolRule::Bash { patterns }), ToolRuleOverrides::Bash { patterns: extra }) => {
527 ToolRule::Bash {
528 patterns: merge_patterns(patterns, extra),
529 }
530 }
531 (
532 Some(ToolRule::DispatchAgent { agent_patterns }),
533 ToolRuleOverrides::DispatchAgent {
534 agent_patterns: extra,
535 },
536 ) => ToolRule::DispatchAgent {
537 agent_patterns: merge_patterns(agent_patterns, extra),
538 },
539 (_, ToolRuleOverrides::Bash { patterns }) => ToolRule::Bash {
540 patterns: patterns.clone(),
541 },
542 (_, ToolRuleOverrides::DispatchAgent { agent_patterns }) => ToolRule::DispatchAgent {
543 agent_patterns: agent_patterns.clone(),
544 },
545 }
546}
547
/// Returns `base` followed by each pattern from `extra` not already present.
///
/// Order is preserved. Duplicates within `base` are kept as-is; duplicates
/// coming from `extra` are added only once.
fn merge_patterns(base: &[String], extra: &[String]) -> Vec<String> {
    extra.iter().fold(base.to_vec(), |mut acc, candidate| {
        if !acc.contains(candidate) {
            acc.push(candidate.clone());
        }
        acc
    })
}
557
/// Credentials for a remote workspace agent.
// Serialized with serde's default external tagging, e.g. `{"Bearer": {"token": "..."}}`.
// NOTE(review): the derived `Debug` prints the token/key verbatim; avoid
// `{:?}`-logging values that contain this type, or redact it.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub enum RemoteAuth {
    /// Bearer-token authentication.
    Bearer { token: String },
    /// API-key authentication.
    ApiKey { key: String },
}
564
565impl RemoteAuth {
566 pub fn to_workspace_auth(&self) -> steer_workspace::RemoteAuth {
568 match self {
569 RemoteAuth::Bearer { token } => steer_workspace::RemoteAuth::BearerToken(token.clone()),
570 RemoteAuth::ApiKey { key } => steer_workspace::RemoteAuth::ApiKey(key.clone()),
571 }
572 }
573}
574
/// Which tools of an MCP backend to expose.
// Serialized externally tagged in snake_case, e.g. `"all"` or `{"include": [...]}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(rename_all = "snake_case")]
#[derive(Default)]
pub enum ToolFilter {
    /// Expose every tool (the default).
    #[default]
    All,
    /// Expose only the listed tool names.
    Include(Vec<String>),
    /// Expose everything except the listed tool names.
    Exclude(Vec<String>),
}
588
/// Configuration for an external tool backend.
// Serialized internally tagged, e.g. `{"type": "mcp", "server_name": ..., ...}`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum BackendConfig {
    /// A Model Context Protocol server.
    Mcp {
        /// Name used for status tracking and the `mcp_<name>` registry key.
        server_name: String,
        /// How to reach the server.
        transport: McpTransport,
        /// Which of the server's tools to expose.
        tool_filter: ToolFilter,
    },
}
599
/// Tool-related configuration for a session.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SessionToolConfig {
    /// External backends initialized by `SessionConfig::build_registry`.
    pub backends: Vec<BackendConfig>,
    /// Which tools are exposed to the model.
    pub visibility: ToolVisibility,
    /// How tool calls are approved.
    pub approval_policy: ToolApprovalPolicy,
    /// Free-form key/value metadata.
    pub metadata: HashMap<String, String>,
}
607
608impl Default for SessionToolConfig {
609 fn default() -> Self {
610 Self {
611 backends: Vec::new(),
612 visibility: ToolVisibility::All,
613 approval_policy: ToolApprovalPolicy::default(),
614 metadata: HashMap::new(),
615 }
616 }
617}
618
619impl SessionToolConfig {
620 pub fn read_only() -> Self {
621 Self {
622 backends: Vec::new(),
623 visibility: ToolVisibility::ReadOnly,
624 approval_policy: ToolApprovalPolicy::default(),
625 metadata: HashMap::new(),
626 }
627 }
628}
629
/// Mutable per-session state: conversation, tool calls and approvals.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SessionState {
    /// Conversation messages in insertion order.
    pub messages: Vec<Message>,

    /// Tracked tool calls, keyed by tool-call id.
    pub tool_calls: HashMap<String, ToolCallState>,

    /// Tool names approved during this session (see `approve_tool`).
    pub approved_tools: HashSet<String>,

    /// Bash patterns approved during this session; empty if absent in serialized data.
    #[serde(default)]
    pub approved_bash_patterns: HashSet<String>,

    // NOTE(review): not updated in this file; presumably the sequence number
    // of the last applied event — confirm.
    pub last_event_sequence: u64,

    /// Free-form key/value metadata.
    pub metadata: HashMap<String, String>,

    /// Currently active message id, if any; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub active_message_id: Option<String>,

    /// Runtime-only MCP server status.
    ///
    /// Never serialized; always empty after deserialization.
    #[serde(default, skip_serializing, skip_deserializing)]
    pub mcp_servers: HashMap<String, McpServerInfo>,
}
662
663impl SessionState {
664 pub fn add_message(&mut self, message: Message) {
666 self.messages.push(message);
667 }
668
669 pub fn message_count(&self) -> usize {
671 self.messages.len()
672 }
673
674 pub fn last_message(&self) -> Option<&Message> {
676 self.messages.last()
677 }
678
679 pub fn add_tool_call(&mut self, tool_call: ToolCall) {
681 let state = ToolCallState {
682 tool_call: tool_call.clone(),
683 status: ToolCallStatus::PendingApproval,
684 started_at: None,
685 completed_at: None,
686 result: None,
687 };
688 self.tool_calls.insert(tool_call.id, state);
689 }
690
691 pub fn update_tool_call_status(
693 &mut self,
694 tool_call_id: &str,
695 status: ToolCallStatus,
696 ) -> std::result::Result<(), String> {
697 let tool_call = self
698 .tool_calls
699 .get_mut(tool_call_id)
700 .ok_or_else(|| format!("Tool call not found: {tool_call_id}"))?;
701
702 match (&tool_call.status, &status) {
704 (_, ToolCallStatus::Executing) => {
705 tool_call.started_at = Some(Utc::now());
706 }
707 (_, ToolCallStatus::Completed | ToolCallStatus::Failed { .. }) => {
708 tool_call.completed_at = Some(Utc::now());
709 }
710 _ => {}
711 }
712
713 tool_call.status = status;
714 Ok(())
715 }
716
717 pub fn approve_tool(&mut self, tool_name: String) {
719 self.approved_tools.insert(tool_name);
720 }
721
722 pub fn is_tool_approved(&self, tool_name: &str) -> bool {
724 self.approved_tools.contains(tool_name)
725 }
726
727 pub fn validate(&self) -> std::result::Result<(), String> {
729 for message in &self.messages {
731 let tool_calls = Self::extract_tool_calls_from_message(message);
732 if !tool_calls.is_empty() {
733 for tool_call_id in tool_calls {
734 if !self.tool_calls.contains_key(&tool_call_id) {
735 return Err(format!(
736 "Message references unknown tool call: {tool_call_id}"
737 ));
738 }
739 }
740 }
741 }
742
743 Ok(())
744 }
745
746 fn extract_tool_calls_from_message(message: &Message) -> Vec<String> {
748 let mut tool_call_ids = Vec::new();
749
750 match &message.data {
751 MessageData::Assistant { content, .. } => {
752 for c in content {
753 if let crate::app::conversation::AssistantContent::ToolCall {
754 tool_call, ..
755 } = c
756 {
757 tool_call_ids.push(tool_call.id.clone());
758 }
759 }
760 }
761 MessageData::Tool { tool_use_id, .. } => {
762 tool_call_ids.push(tool_use_id.clone());
763 }
764 MessageData::User { .. } => {}
765 }
766
767 tool_call_ids
768 }
769}
770
/// Tracking record for one tool call within a session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallState {
    /// The tool call as requested.
    pub tool_call: ToolCall,
    /// Current lifecycle status.
    pub status: ToolCallStatus,
    /// Set when the status moves to `Executing`.
    pub started_at: Option<DateTime<Utc>>,
    /// Set when the status moves to `Completed` or `Failed`.
    pub completed_at: Option<DateTime<Utc>>,
    /// Tool result, once available.
    pub result: Option<ToolResult>,
}
780
781impl ToolCallState {
782 pub fn is_pending(&self) -> bool {
783 matches!(self.status, ToolCallStatus::PendingApproval)
784 }
785
786 pub fn is_complete(&self) -> bool {
787 matches!(
788 self.status,
789 ToolCallStatus::Completed | ToolCallStatus::Failed { .. }
790 )
791 }
792
793 pub fn duration(&self) -> Option<chrono::Duration> {
794 match (self.started_at, self.completed_at) {
795 (Some(start), Some(end)) => Some(end - start),
796 _ => None,
797 }
798 }
799}
800
/// Lifecycle status of a tool call.
// Serialized internally tagged under "status", e.g. `{"status": "failed", "error": "..."}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum ToolCallStatus {
    /// Waiting for approval; the initial status set by `SessionState::add_tool_call`.
    PendingApproval,
    /// Approved, not yet executing.
    Approved,
    /// Rejected; terminal.
    Denied,
    /// Currently running.
    Executing,
    /// Finished successfully; terminal.
    Completed,
    /// Finished with an error; terminal.
    Failed { error: String },
}
812
813impl ToolCallStatus {
814 pub fn is_terminal(&self) -> bool {
815 matches!(
816 self,
817 ToolCallStatus::Completed | ToolCallStatus::Failed { .. } | ToolCallStatus::Denied
818 )
819 }
820}
821
/// Summary of a single tool execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolExecutionStats {
    /// Plain-text output, or the error text for failures; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output: Option<String>,
    /// Structured output; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json_output: Option<serde_json::Value>,
    /// Type label for `json_output`, if any.
    pub result_type: Option<String>,
    /// Whether the execution succeeded.
    pub success: bool,
    /// Wall-clock execution time in milliseconds.
    pub execution_time_ms: u64,
    /// Free-form key/value metadata.
    pub metadata: HashMap<String, String>,
}
834
835impl ToolExecutionStats {
836 pub fn success(output: String, execution_time_ms: u64) -> Self {
837 Self {
838 output: Some(output),
839 json_output: None,
840 result_type: None,
841 success: true,
842 execution_time_ms,
843 metadata: HashMap::new(),
844 }
845 }
846
847 pub fn success_typed(
848 json_output: serde_json::Value,
849 result_type: String,
850 execution_time_ms: u64,
851 ) -> Self {
852 Self {
853 output: None,
854 json_output: Some(json_output),
855 result_type: Some(result_type),
856 success: true,
857 execution_time_ms,
858 metadata: HashMap::new(),
859 }
860 }
861
862 pub fn failure(error: String, execution_time_ms: u64) -> Self {
863 Self {
864 output: Some(error),
865 json_output: None,
866 result_type: None,
867 success: false,
868 execution_time_ms,
869 metadata: HashMap::new(),
870 }
871 }
872
873 pub fn with_metadata(mut self, key: String, value: String) -> Self {
874 self.metadata.insert(key, value);
875 self
876 }
877}
878
/// Lightweight summary of a session for listings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionInfo {
    /// Session identifier.
    pub id: String,
    /// When the session was created.
    pub created_at: DateTime<Utc>,
    /// When the session was last updated.
    pub updated_at: DateTime<Utc>,
    /// Last model used, if known; the `From<&Session>` conversion in this file
    /// always sets it to `None`.
    pub last_model: Option<ModelId>,
    /// Number of messages in the session.
    pub message_count: usize,
    /// Session config metadata.
    pub metadata: HashMap<String, String>,
}
890
891impl From<&Session> for SessionInfo {
892 fn from(session: &Session) -> Self {
893 Self {
894 id: session.id.clone(),
895 created_at: session.created_at,
896 updated_at: session.updated_at,
897 last_model: None, message_count: session.state.message_count(),
899 metadata: session.config.metadata.clone(),
900 }
901 }
902}
903
904#[cfg(test)]
905mod tests {
906 use super::*;
907 use crate::app::conversation::{Message, MessageData, UserContent};
908 use crate::config::model::builtin::claude_sonnet_4_5 as test_model;
909 use crate::tools::DISPATCH_AGENT_TOOL_NAME;
910 use crate::tools::static_tools::READ_ONLY_TOOL_NAMES;
911 use steer_tools::tools::{BASH_TOOL_NAME, EDIT_TOOL_NAME};
912
913 #[test]
914 fn test_session_creation() {
915 let config = SessionConfig {
916 workspace: WorkspaceConfig::Local {
917 path: PathBuf::from("/test/path"),
918 },
919 workspace_ref: None,
920 workspace_id: None,
921 repo_ref: None,
922 parent_session_id: None,
923 workspace_name: None,
924 tool_config: SessionToolConfig::default(),
925 system_prompt: None,
926 primary_agent_id: None,
927 policy_overrides: SessionPolicyOverrides::empty(),
928 metadata: HashMap::new(),
929 default_model: test_model(),
930 };
931 let session = Session::new("test-session".to_string(), config.clone());
932
933 assert_eq!(session.id, "test-session");
934 assert_eq!(
935 session
936 .config
937 .tool_config
938 .approval_policy
939 .tool_decision("any_tool"),
940 ToolDecision::Ask
941 );
942 assert_eq!(session.state.message_count(), 0);
943 }
944
945 #[test]
946 fn test_tool_approval_policy_prompt_unapproved() {
947 let policy = ToolApprovalPolicy {
948 default_behavior: UnapprovedBehavior::Prompt,
949 preapproved: ApprovalRules {
950 tools: ["read_file", "list_files"]
951 .iter()
952 .map(|s| (*s).to_string())
953 .collect(),
954 per_tool: HashMap::new(),
955 },
956 };
957
958 assert_eq!(policy.tool_decision("read_file"), ToolDecision::Allow);
959 assert_eq!(policy.tool_decision("write_file"), ToolDecision::Ask);
960 }
961
962 #[test]
963 fn test_tool_approval_policy_deny_unapproved() {
964 let policy = ToolApprovalPolicy {
965 default_behavior: UnapprovedBehavior::Deny,
966 preapproved: ApprovalRules {
967 tools: ["read_file", "list_files"]
968 .iter()
969 .map(|s| (*s).to_string())
970 .collect(),
971 per_tool: HashMap::new(),
972 },
973 };
974
975 assert_eq!(policy.tool_decision("read_file"), ToolDecision::Allow);
976 assert_eq!(policy.tool_decision("write_file"), ToolDecision::Deny);
977 }
978
979 #[test]
980 fn test_tool_approval_policy_default() {
981 let policy = ToolApprovalPolicy::default();
982
983 assert_eq!(
984 policy.tool_decision(READ_ONLY_TOOL_NAMES[0]),
985 ToolDecision::Allow
986 );
987 assert_eq!(policy.tool_decision(BASH_TOOL_NAME), ToolDecision::Ask);
988 }
989
990 #[test]
991 fn test_tool_approval_policy_allow_unapproved() {
992 let policy = ToolApprovalPolicy {
993 default_behavior: UnapprovedBehavior::Allow,
994 preapproved: ApprovalRules {
995 tools: ["read_file", "list_files"]
996 .iter()
997 .map(|s| (*s).to_string())
998 .collect(),
999 per_tool: HashMap::new(),
1000 },
1001 };
1002
1003 assert_eq!(policy.tool_decision("read_file"), ToolDecision::Allow);
1004 assert_eq!(policy.tool_decision("write_file"), ToolDecision::Allow);
1005 }
1006
1007 #[test]
1008 fn test_tool_approval_policy_overrides_union_rules() {
1009 let base_policy = ToolApprovalPolicy {
1010 default_behavior: UnapprovedBehavior::Prompt,
1011 preapproved: ApprovalRules {
1012 tools: ["read_file"].iter().map(|s| (*s).to_string()).collect(),
1013 per_tool: [
1014 (
1015 BASH_TOOL_NAME.to_string(),
1016 ToolRule::Bash {
1017 patterns: vec!["git status".to_string()],
1018 },
1019 ),
1020 (
1021 DISPATCH_AGENT_TOOL_NAME.to_string(),
1022 ToolRule::DispatchAgent {
1023 agent_patterns: vec!["explore".to_string()],
1024 },
1025 ),
1026 ]
1027 .into_iter()
1028 .collect(),
1029 },
1030 };
1031
1032 let overrides = ToolApprovalPolicyOverrides {
1033 default_behavior: Some(UnapprovedBehavior::Deny),
1034 preapproved: ApprovalRulesOverrides {
1035 tools: ["write_file"].iter().map(|s| (*s).to_string()).collect(),
1036 per_tool: [
1037 (
1038 BASH_TOOL_NAME.to_string(),
1039 ToolRuleOverrides::Bash {
1040 patterns: vec!["git log".to_string()],
1041 },
1042 ),
1043 (
1044 DISPATCH_AGENT_TOOL_NAME.to_string(),
1045 ToolRuleOverrides::DispatchAgent {
1046 agent_patterns: vec!["review".to_string()],
1047 },
1048 ),
1049 ]
1050 .into_iter()
1051 .collect(),
1052 },
1053 };
1054
1055 let merged = overrides.apply_to(&base_policy);
1056
1057 assert_eq!(merged.default_behavior, UnapprovedBehavior::Deny);
1058 assert!(merged.preapproved.tools.contains("read_file"));
1059 assert!(merged.preapproved.tools.contains("write_file"));
1060
1061 let bash_patterns = match merged
1062 .preapproved
1063 .per_tool
1064 .get(BASH_TOOL_NAME)
1065 .expect("bash rule")
1066 {
1067 ToolRule::Bash { patterns } => patterns,
1068 ToolRule::DispatchAgent { .. } => {
1069 panic!("Unexpected bash rule: dispatch agent")
1070 }
1071 };
1072 assert!(bash_patterns.contains(&"git status".to_string()));
1073 assert!(bash_patterns.contains(&"git log".to_string()));
1074 assert_eq!(bash_patterns.len(), 2);
1075
1076 let agent_patterns = match merged
1077 .preapproved
1078 .per_tool
1079 .get(DISPATCH_AGENT_TOOL_NAME)
1080 .expect("dispatch_agent rule")
1081 {
1082 ToolRule::DispatchAgent { agent_patterns } => agent_patterns,
1083 ToolRule::Bash { .. } => panic!("Unexpected dispatch_agent rule: bash"),
1084 };
1085 assert!(agent_patterns.contains(&"explore".to_string()));
1086 assert!(agent_patterns.contains(&"review".to_string()));
1087 assert_eq!(agent_patterns.len(), 2);
1088 }
1089
1090 #[test]
1091 fn test_bash_pattern_matching() {
1092 let policy = ToolApprovalPolicy {
1093 default_behavior: UnapprovedBehavior::Prompt,
1094 preapproved: ApprovalRules {
1095 tools: HashSet::new(),
1096 per_tool: [(
1097 "bash".to_string(),
1098 ToolRule::Bash {
1099 patterns: vec![
1100 "git status".to_string(),
1101 "git log*".to_string(),
1102 "git * --oneline".to_string(),
1103 "ls -?a*".to_string(),
1104 "cargo build*".to_string(),
1105 ],
1106 },
1107 )]
1108 .into_iter()
1109 .collect(),
1110 },
1111 };
1112
1113 assert!(policy.is_bash_pattern_preapproved("git status"));
1114 assert!(policy.is_bash_pattern_preapproved("git log --oneline"));
1115 assert!(policy.is_bash_pattern_preapproved("git show --oneline"));
1116 assert!(policy.is_bash_pattern_preapproved("ls -la"));
1117 assert!(policy.is_bash_pattern_preapproved("cargo build --release"));
1118 assert!(!policy.is_bash_pattern_preapproved("git commit"));
1119 assert!(!policy.is_bash_pattern_preapproved("ls -l"));
1120 assert!(!policy.is_bash_pattern_preapproved("rm -rf /"));
1121 }
1122
1123 #[test]
1124 fn test_dispatch_agent_pattern_matching() {
1125 let policy = ToolApprovalPolicy {
1126 default_behavior: UnapprovedBehavior::Prompt,
1127 preapproved: ApprovalRules {
1128 tools: HashSet::new(),
1129 per_tool: [(
1130 "dispatch_agent".to_string(),
1131 ToolRule::DispatchAgent {
1132 agent_patterns: vec!["explore".to_string(), "explore-*".to_string()],
1133 },
1134 )]
1135 .into_iter()
1136 .collect(),
1137 },
1138 };
1139
1140 assert!(policy.is_dispatch_agent_pattern_preapproved("explore"));
1141 assert!(policy.is_dispatch_agent_pattern_preapproved("explore-fast"));
1142 assert!(!policy.is_dispatch_agent_pattern_preapproved("build"));
1143 }
1144
1145 #[test]
1146 fn test_session_state_validation() {
1147 let mut state = SessionState::default();
1148
1149 assert!(state.validate().is_ok());
1151
1152 let message = Message {
1154 data: MessageData::User {
1155 content: vec![UserContent::Text {
1156 text: "Hello".to_string(),
1157 }],
1158 },
1159 timestamp: 123_456_789,
1160 id: "msg1".to_string(),
1161 parent_message_id: None,
1162 };
1163 state.add_message(message);
1164
1165 assert!(state.validate().is_ok());
1166 assert_eq!(state.message_count(), 1);
1167 }
1168
1169 #[test]
1170 fn test_tool_call_state_tracking() {
1171 let mut state = SessionState::default();
1172
1173 let tool_call = ToolCall {
1174 id: "tool1".to_string(),
1175 name: "read_file".to_string(),
1176 parameters: serde_json::json!({"path": "/test.txt"}),
1177 };
1178
1179 state.add_tool_call(tool_call.clone());
1180 assert!(state.tool_calls.get("tool1").unwrap().is_pending());
1181
1182 state
1183 .update_tool_call_status("tool1", ToolCallStatus::Executing)
1184 .unwrap();
1185 let tool_state = state.tool_calls.get("tool1").unwrap();
1186 assert!(tool_state.started_at.is_some());
1187 assert!(!tool_state.is_complete());
1188
1189 state
1190 .update_tool_call_status("tool1", ToolCallStatus::Completed)
1191 .unwrap();
1192 let tool_state = state.tool_calls.get("tool1").unwrap();
1193 assert!(tool_state.completed_at.is_some());
1194 assert!(tool_state.is_complete());
1195 }
1196
1197 #[test]
1198 fn test_session_tool_config_default() {
1199 let config = SessionToolConfig::default();
1200 assert!(config.backends.is_empty());
1201 }
1202
1203 #[test]
1204 fn test_tool_filter_exclude() {
1205 let excluded =
1206 ToolFilter::Exclude(vec![BASH_TOOL_NAME.to_string(), EDIT_TOOL_NAME.to_string()]);
1207
1208 if let ToolFilter::Exclude(tools) = &excluded {
1209 assert_eq!(tools.len(), 2);
1210 assert!(tools.contains(&BASH_TOOL_NAME.to_string()));
1211 assert!(tools.contains(&EDIT_TOOL_NAME.to_string()));
1212 } else {
1213 panic!("Expected ToolFilter::Exclude");
1214 }
1215 }
1216
1217 #[test]
1218 fn test_session_tool_config_read_only() {
1219 let config = SessionToolConfig::read_only();
1220 assert_eq!(config.backends.len(), 0);
1221 assert!(matches!(config.visibility, ToolVisibility::ReadOnly));
1222 assert_eq!(
1223 config.approval_policy.default_behavior,
1224 UnapprovedBehavior::Prompt
1225 );
1226 }
1227
1228 #[tokio::test]
1229 async fn test_session_config_build_registry_no_default_backends() {
1230 let config = SessionConfig {
1234 workspace: WorkspaceConfig::Local {
1235 path: PathBuf::from("/test/path"),
1236 },
1237 workspace_ref: None,
1238 workspace_id: None,
1239 repo_ref: None,
1240 parent_session_id: None,
1241 workspace_name: None,
1242 tool_config: SessionToolConfig::default(), system_prompt: None,
1244 primary_agent_id: None,
1245 policy_overrides: SessionPolicyOverrides::empty(),
1246 metadata: HashMap::new(),
1247 default_model: test_model(),
1248 };
1249
1250 let (registry, _mcp_servers) = config.build_registry().await.unwrap();
1251 let schemas = registry.get_tool_schemas().await;
1252
1253 assert!(
1254 schemas.is_empty(),
1255 "BackendRegistry should be empty with default config; got: {:?}",
1256 schemas.iter().map(|s| &s.name).collect::<Vec<_>>()
1257 );
1258 }
1259
1260 #[test]
1267 fn test_mcp_status_tracking() {
1268 let mut session_state = SessionState::default();
1270
1271 let mcp_info = McpServerInfo {
1273 server_name: "test-server".to_string(),
1274 transport: crate::tools::McpTransport::Stdio {
1275 command: "python".to_string(),
1276 args: vec!["-m".to_string(), "test_server".to_string()],
1277 },
1278 state: McpConnectionState::Connected {
1279 tool_names: vec![
1280 "tool1".to_string(),
1281 "tool2".to_string(),
1282 "tool3".to_string(),
1283 "tool4".to_string(),
1284 "tool5".to_string(),
1285 ],
1286 },
1287 last_updated: Utc::now(),
1288 };
1289
1290 session_state
1291 .mcp_servers
1292 .insert("test-server".to_string(), mcp_info.clone());
1293
1294 assert_eq!(session_state.mcp_servers.len(), 1);
1296 let stored = session_state.mcp_servers.get("test-server").unwrap();
1297 assert_eq!(stored.server_name, "test-server");
1298 assert!(matches!(
1299 stored.state,
1300 McpConnectionState::Connected { ref tool_names } if tool_names.len() == 5
1301 ));
1302
1303 let failed_info = McpServerInfo {
1305 server_name: "failed-server".to_string(),
1306 transport: crate::tools::McpTransport::Tcp {
1307 host: "localhost".to_string(),
1308 port: 9999,
1309 },
1310 state: McpConnectionState::Failed {
1311 error: "Connection refused".to_string(),
1312 },
1313 last_updated: Utc::now(),
1314 };
1315
1316 session_state
1317 .mcp_servers
1318 .insert("failed-server".to_string(), failed_info);
1319 assert_eq!(session_state.mcp_servers.len(), 2);
1320 }
1321
1322 #[tokio::test]
1323 async fn test_mcp_server_tracking_in_build_registry() {
1324 let mut config = SessionConfig::read_only(test_model());
1326
1327 config.tool_config.backends.push(BackendConfig::Mcp {
1329 server_name: "bad-server".to_string(),
1330 transport: crate::tools::McpTransport::Tcp {
1331 host: "nonexistent.invalid".to_string(),
1332 port: 12345,
1333 },
1334 tool_filter: ToolFilter::All,
1335 });
1336
1337 config.tool_config.backends.push(BackendConfig::Mcp {
1339 server_name: "good-server".to_string(),
1340 transport: crate::tools::McpTransport::Stdio {
1341 command: "echo".to_string(),
1342 args: vec!["test".to_string()],
1343 },
1344 tool_filter: ToolFilter::All,
1345 });
1346
1347 let (_registry, mcp_servers) = config.build_registry().await.unwrap();
1348
1349 assert_eq!(mcp_servers.len(), 2);
1351
1352 let bad_server = mcp_servers.get("bad-server").unwrap();
1354 assert_eq!(bad_server.server_name, "bad-server");
1355 assert!(matches!(
1356 bad_server.state,
1357 McpConnectionState::Failed { .. }
1358 ));
1359
1360 let good_server = mcp_servers.get("good-server").unwrap();
1362 assert_eq!(good_server.server_name, "good-server");
1363 assert!(matches!(
1364 good_server.state,
1365 McpConnectionState::Failed { .. }
1366 ));
1367 }
1368
1369 #[test]
1370 fn test_backend_config_mcp_variant() {
1371 let mcp_config = BackendConfig::Mcp {
1372 server_name: "test-mcp".to_string(),
1373 transport: crate::tools::McpTransport::Stdio {
1374 command: "python".to_string(),
1375 args: vec!["-m".to_string(), "test_server".to_string()],
1376 },
1377 tool_filter: ToolFilter::All,
1378 };
1379
1380 let BackendConfig::Mcp {
1381 server_name,
1382 transport,
1383 ..
1384 } = mcp_config;
1385
1386 assert_eq!(server_name, "test-mcp");
1387 if let crate::tools::McpTransport::Stdio { command, args } = transport {
1388 assert_eq!(command, "python");
1389 assert_eq!(args.len(), 2);
1390 } else {
1391 panic!("Expected Stdio transport");
1392 }
1393 }
1394}