1use crate::config::model::ModelId;
2use crate::error::Result;
3use crate::tools::static_tools::READ_ONLY_TOOL_NAMES;
4use chrono::{DateTime, Utc};
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, HashSet};
8use std::path::PathBuf;
9use std::sync::Arc;
10
11use crate::app::{Message, MessageData};
12use crate::tools::{BackendRegistry, McpTransport, ToolBackend};
13use steer_tools::{ToolCall, result::ToolResult};
14
/// Connection state recorded for a single MCP server.
#[derive(Debug, Clone)]
pub enum McpConnectionState {
    /// Initialization has started but has not finished yet.
    /// `SessionConfig::build_registry` creates every server entry in this state.
    Connecting,
    /// The backend was initialized successfully.
    Connected {
        /// Tool names the backend reported as supported.
        tool_names: Vec<String>,
    },
    /// The server is not connected.
    /// NOTE(review): nothing in this file sets this state — confirm who produces it.
    Disconnected {
        /// Optional explanation for the disconnect.
        reason: Option<String>,
    },
    /// Backend initialization returned an error.
    Failed {
        /// The initialization error rendered with `to_string()`.
        error: String,
    },
}
36
/// Runtime bookkeeping for one configured MCP server.
#[derive(Debug, Clone)]
pub struct McpServerInfo {
    /// Name of the server as given in the backend configuration.
    pub server_name: String,
    /// Transport used to reach the server.
    pub transport: McpTransport,
    /// Most recently recorded connection state.
    pub state: McpConnectionState,
    /// UTC time at which `state` was last written.
    pub last_updated: DateTime<Utc>,
}
49
/// Location of the workspace a session operates on.
///
/// Serialized as an internally tagged enum: the variant name is stored in a
/// `type` field using `snake_case` (`"local"` / `"remote"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum WorkspaceConfig {
    /// A workspace on the local filesystem.
    Local {
        /// Filesystem path of the workspace.
        path: PathBuf,
    },
    /// A workspace reached through a remote agent.
    Remote {
        /// Address of the remote agent.
        agent_address: String,
        /// Optional credentials for the remote agent.
        auth: Option<RemoteAuth>,
    },
}
62
63impl WorkspaceConfig {
64 pub fn get_path(&self) -> Option<String> {
65 match self {
66 WorkspaceConfig::Local { path } => Some(path.to_string_lossy().to_string()),
67 WorkspaceConfig::Remote { agent_address, .. } => Some(agent_address.clone()),
68 }
69 }
70
71 pub fn to_workspace_config(&self) -> steer_workspace::WorkspaceConfig {
73 match self {
74 WorkspaceConfig::Local { path } => {
75 steer_workspace::WorkspaceConfig::Local { path: path.clone() }
76 }
77 WorkspaceConfig::Remote {
78 agent_address,
79 auth,
80 } => steer_workspace::WorkspaceConfig::Remote {
81 address: agent_address.clone(),
82 auth: auth.as_ref().map(|a| a.to_workspace_auth()),
83 },
84 }
85 }
86}
87
88impl Default for WorkspaceConfig {
89 fn default() -> Self {
90 Self::Local {
91 path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
92 }
93 }
94}
95
/// A session: its identity, timestamps, configuration and mutable state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Caller-supplied session identifier.
    pub id: String,
    /// UTC time at which the session was created.
    pub created_at: DateTime<Utc>,
    /// UTC time of the last `update_timestamp` call (equal to `created_at`
    /// for a freshly created session).
    pub updated_at: DateTime<Utc>,
    /// Configuration the session was created with.
    pub config: SessionConfig,
    /// Mutable conversation and tool state.
    pub state: SessionState,
}
105
106impl Session {
107 pub fn new(id: String, config: SessionConfig) -> Self {
108 let now = Utc::now();
109 Self {
110 id,
111 created_at: now,
112 updated_at: now,
113 config,
114 state: SessionState::default(),
115 }
116 }
117
118 pub fn update_timestamp(&mut self) {
119 self.updated_at = Utc::now();
120 }
121
122 pub fn is_recently_active(&self, threshold: chrono::Duration) -> bool {
124 let cutoff = Utc::now() - threshold;
125 self.updated_at > cutoff
126 }
127
128 pub async fn build_workspace(&self) -> Result<Arc<dyn crate::workspace::Workspace>> {
130 crate::workspace::create_workspace(&self.config.workspace.to_workspace_config()).await
131 }
132}
133
/// Configuration a session is created with.
///
/// Fields marked `#[serde(default)]` may be absent from serialized data and
/// then take their default value.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SessionConfig {
    /// Where the session's workspace lives (local path or remote agent).
    pub workspace: WorkspaceConfig,
    /// Optional workspace reference; `None` when absent from serialized data.
    #[serde(default)]
    pub workspace_ref: Option<crate::workspace::WorkspaceRef>,
    /// Optional workspace identifier; `None` when absent from serialized data.
    #[serde(default)]
    pub workspace_id: Option<crate::workspace::WorkspaceId>,
    /// Optional repository reference; `None` when absent from serialized data.
    #[serde(default)]
    pub repo_ref: Option<crate::workspace::RepoRef>,
    /// Optional id of a parent session.
    /// NOTE(review): presumably set for sessions spawned from another session — confirm.
    #[serde(default)]
    pub parent_session_id: Option<crate::app::domain::types::SessionId>,
    /// Optional workspace name; `None` when absent from serialized data.
    #[serde(default)]
    pub workspace_name: Option<String>,
    /// Tool backends, visibility and approval policy for this session.
    pub tool_config: SessionToolConfig,
    /// Optional system prompt text.
    pub system_prompt: Option<String>,
    /// Optional id of the primary agent; `None` when absent from serialized data.
    #[serde(default)]
    pub primary_agent_id: Option<String>,
    /// Session-level policy overrides; deserializes to
    /// `SessionPolicyOverrides::empty()` when absent.
    #[serde(default = "SessionPolicyOverrides::empty")]
    pub policy_overrides: SessionPolicyOverrides,
    /// Free-form string metadata; copied into `SessionInfo` summaries.
    pub metadata: HashMap<String, String>,
    /// Model configured as the session default.
    pub default_model: ModelId,
    /// Auto-compaction settings; `AutoCompactionConfig::default()` when absent.
    #[serde(default)]
    pub auto_compaction: AutoCompactionConfig,
}
163
164impl SessionConfig {
165 pub async fn build_registry(
168 &self,
169 ) -> Result<(BackendRegistry, HashMap<String, McpServerInfo>)> {
170 let mut registry = BackendRegistry::new();
171 let mut mcp_servers = HashMap::new();
172
173 for backend_config in &self.tool_config.backends {
174 let BackendConfig::Mcp {
175 server_name,
176 transport,
177 tool_filter,
178 } = backend_config;
179
180 tracing::info!(
181 "Attempting to initialize MCP backend '{}' with transport: {:?}",
182 server_name,
183 transport
184 );
185
186 let mut server_info = McpServerInfo {
187 server_name: server_name.clone(),
188 transport: transport.clone(),
189 state: McpConnectionState::Connecting,
190 last_updated: Utc::now(),
191 };
192
193 match crate::tools::McpBackend::new(
194 server_name.clone(),
195 transport.clone(),
196 tool_filter.clone(),
197 )
198 .await
199 {
200 Ok(mcp_backend) => {
201 let tool_names = mcp_backend.supported_tools().await;
202 let tool_count = tool_names.len();
203 tracing::info!(
204 "Successfully initialized MCP backend '{}' with {} tools",
205 server_name,
206 tool_count
207 );
208 server_info.state = McpConnectionState::Connected { tool_names };
209 server_info.last_updated = Utc::now();
210 registry
211 .register(format!("mcp_{server_name}"), Arc::new(mcp_backend))
212 .await;
213 }
214 Err(e) => {
215 tracing::error!("Failed to initialize MCP backend '{}': {}", server_name, e);
216 server_info.state = McpConnectionState::Failed {
217 error: e.to_string(),
218 };
219 server_info.last_updated = Utc::now();
220 }
221 }
222
223 mcp_servers.insert(server_name.clone(), server_info);
224 }
225
226 Ok((registry, mcp_servers))
227 }
228
229 pub fn filter_tools_by_visibility(
231 &self,
232 tools: Vec<steer_tools::ToolSchema>,
233 ) -> Vec<steer_tools::ToolSchema> {
234 match &self.tool_config.visibility {
235 ToolVisibility::All => tools,
236 ToolVisibility::ReadOnly => {
237 let read_only_names: HashSet<String> = READ_ONLY_TOOL_NAMES
238 .iter()
239 .map(|name| (*name).to_string())
240 .collect();
241
242 tools
243 .into_iter()
244 .filter(|schema| read_only_names.contains(&schema.name))
245 .collect()
246 }
247 ToolVisibility::Whitelist(allowed) => tools
248 .into_iter()
249 .filter(|schema| allowed.contains(&schema.name))
250 .collect(),
251 ToolVisibility::Blacklist(blocked) => tools
252 .into_iter()
253 .filter(|schema| !blocked.contains(&schema.name))
254 .collect(),
255 }
256 }
257
258 #[cfg(test)]
260 pub fn read_only(default_model: ModelId) -> Self {
261 Self {
262 workspace: WorkspaceConfig::Local {
263 path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
264 },
265 workspace_ref: None,
266 workspace_id: None,
267 repo_ref: None,
268 parent_session_id: None,
269 workspace_name: None,
270 tool_config: SessionToolConfig::read_only(),
271 system_prompt: None,
272 primary_agent_id: None,
273 policy_overrides: SessionPolicyOverrides::empty(),
274 metadata: HashMap::new(),
275 default_model,
276 auto_compaction: AutoCompactionConfig::default(),
277 }
278 }
279}
280
/// Settings for automatic compaction.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub struct AutoCompactionConfig {
    /// Whether automatic compaction is turned on.
    pub enabled: bool,
    /// Threshold expressed as a percentage; see `threshold_ratio` for the
    /// fractional form.
    /// NOTE(review): presumably the share of the context window at which
    /// compaction triggers — confirm with the consumer. Values above 100 are
    /// not rejected here.
    pub threshold_percent: u32,
}
287
288impl Default for AutoCompactionConfig {
289 fn default() -> Self {
290 Self {
291 enabled: true,
292 threshold_percent: 90,
293 }
294 }
295}
296
297impl AutoCompactionConfig {
298 pub fn threshold_ratio(&self) -> f64 {
299 f64::from(self.threshold_percent) / 100.0
300 }
301}
302
/// Optional session-level overrides layered on top of a base policy.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SessionPolicyOverrides {
    /// Replacement default model; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub default_model: Option<ModelId>,
    /// Replacement tool visibility; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_visibility: Option<ToolVisibility>,
    /// Additions to the tool approval policy; deserializes to
    /// `ToolApprovalPolicyOverrides::empty()` when absent.
    #[serde(default = "ToolApprovalPolicyOverrides::empty")]
    pub approval_policy: ToolApprovalPolicyOverrides,
}
313
314impl SessionPolicyOverrides {
315 pub fn empty() -> Self {
316 Self {
317 default_model: None,
318 tool_visibility: None,
319 approval_policy: ToolApprovalPolicyOverrides::empty(),
320 }
321 }
322
323 pub fn is_empty(&self) -> bool {
324 self.default_model.is_none()
325 && self.tool_visibility.is_none()
326 && self.approval_policy.is_empty()
327 }
328}
329
/// Additions to a `ToolApprovalPolicy`; see `apply_to` for how they are merged.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ToolApprovalPolicyOverrides {
    /// Extra pre-approval rules; deserializes to
    /// `ApprovalRulesOverrides::empty()` when absent.
    #[serde(default = "ApprovalRulesOverrides::empty")]
    pub preapproved: ApprovalRulesOverrides,
}
335
336impl ToolApprovalPolicyOverrides {
337 pub fn empty() -> Self {
338 Self {
339 preapproved: ApprovalRulesOverrides::empty(),
340 }
341 }
342
343 pub fn is_empty(&self) -> bool {
344 self.preapproved.is_empty()
345 }
346}
347
/// Pre-approval rules to add on top of a base `ApprovalRules`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ApprovalRulesOverrides {
    /// Additional tool names to pre-approve; empty when absent from
    /// serialized data.
    #[serde(default)]
    pub tools: HashSet<String>,
    /// Additional per-tool rules keyed by tool name; empty when absent from
    /// serialized data.
    #[serde(default)]
    pub per_tool: HashMap<String, ToolRuleOverrides>,
}
355
356impl ApprovalRulesOverrides {
357 pub fn empty() -> Self {
358 Self {
359 tools: HashSet::new(),
360 per_tool: HashMap::new(),
361 }
362 }
363
364 pub fn is_empty(&self) -> bool {
365 self.tools.is_empty() && self.per_tool.is_empty()
366 }
367}
368
/// Per-tool rule additions supplied by an override.
///
/// Serialized as an internally tagged enum with the variant name in a `type`
/// field using `snake_case`. Mirrors the variants of `ToolRule`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolRuleOverrides {
    /// Extra command patterns for the bash tool.
    Bash { patterns: Vec<String> },
    /// Extra agent-id patterns for the dispatch-agent tool.
    DispatchAgent { agent_patterns: Vec<String> },
}
375
/// Controls which tool schemas `SessionConfig::filter_tools_by_visibility`
/// lets through.
///
/// Serialized adjacently tagged: the variant name goes in `type`
/// (`snake_case`) and the name set, if any, in `tools`.
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", content = "tools", rename_all = "snake_case")]
pub enum ToolVisibility {
    /// Every tool is kept. This is the default.
    #[default]
    All,

    /// Only tools named in `READ_ONLY_TOOL_NAMES` are kept.
    ReadOnly,

    /// Only the listed tool names are kept.
    Whitelist(HashSet<String>),

    /// Every tool except the listed names is kept.
    Blacklist(HashSet<String>),
}
393
/// Outcome of evaluating a tool against a `ToolApprovalPolicy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolDecision {
    /// The tool is pre-approved, or the policy allows unapproved tools.
    Allow,
    /// The policy's default behavior is `Prompt`.
    Ask,
    /// The policy's default behavior is `Deny`.
    Deny,
}
400
/// What a `ToolApprovalPolicy` does with a tool that is not pre-approved.
///
/// Serialized in `snake_case`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum UnapprovedBehavior {
    /// Maps to `ToolDecision::Ask`. This is the default.
    #[default]
    Prompt,
    /// Maps to `ToolDecision::Deny`.
    Deny,
    /// Maps to `ToolDecision::Allow`.
    Allow,
}
409
/// Fine-grained pre-approval rule for a specific tool.
///
/// Serialized as an internally tagged enum with the variant name in a `type`
/// field using `snake_case`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolRule {
    /// Command patterns (exact strings or globs) pre-approved for bash.
    Bash { patterns: Vec<String> },
    /// Agent-id patterns (exact strings or globs) pre-approved for dispatch.
    DispatchAgent { agent_patterns: Vec<String> },
}
416
/// The set of things a `ToolApprovalPolicy` approves without asking.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Default)]
pub struct ApprovalRules {
    /// Tool names that are approved outright; empty when absent from
    /// serialized data.
    #[serde(default)]
    pub tools: HashSet<String>,
    /// Per-tool rules keyed by tool name (the accessors in this file look up
    /// `"bash"` and `"dispatch_agent"`); empty when absent from serialized data.
    #[serde(default)]
    pub per_tool: HashMap<String, ToolRule>,
}
424
425impl ApprovalRules {
426 pub fn is_empty(&self) -> bool {
427 self.tools.is_empty() && self.per_tool.is_empty()
428 }
429
430 pub fn bash_patterns(&self) -> Option<&[String]> {
431 self.per_tool.get("bash").and_then(|rule| match rule {
432 ToolRule::Bash { patterns } => Some(patterns.as_slice()),
433 ToolRule::DispatchAgent { .. } => None,
434 })
435 }
436
437 pub fn dispatch_agent_rule(&self) -> Option<&[String]> {
438 self.per_tool
439 .get("dispatch_agent")
440 .and_then(|rule| match rule {
441 ToolRule::DispatchAgent { agent_patterns } => Some(agent_patterns.as_slice()),
442 ToolRule::Bash { .. } => None,
443 })
444 }
445}
446
/// Decides whether a tool may run without asking.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub struct ToolApprovalPolicy {
    /// What to do with tools that are not listed in `preapproved.tools`.
    pub default_behavior: UnapprovedBehavior,
    /// Tools and per-tool patterns that are approved up front; empty when
    /// absent from serialized data.
    #[serde(default)]
    pub preapproved: ApprovalRules,
}
453
454impl Default for ToolApprovalPolicy {
455 fn default() -> Self {
456 Self {
457 default_behavior: UnapprovedBehavior::Prompt,
458 preapproved: ApprovalRules {
459 tools: READ_ONLY_TOOL_NAMES
460 .iter()
461 .map(|name| (*name).to_string())
462 .collect(),
463 per_tool: HashMap::new(),
464 },
465 }
466 }
467}
468
469impl ToolApprovalPolicy {
470 pub fn tool_decision(&self, tool_name: &str) -> ToolDecision {
471 if self.preapproved.tools.contains(tool_name) {
472 ToolDecision::Allow
473 } else {
474 match self.default_behavior {
475 UnapprovedBehavior::Prompt => ToolDecision::Ask,
476 UnapprovedBehavior::Deny => ToolDecision::Deny,
477 UnapprovedBehavior::Allow => ToolDecision::Allow,
478 }
479 }
480 }
481
482 pub fn is_bash_pattern_preapproved(&self, command: &str) -> bool {
483 let Some(patterns) = self.preapproved.bash_patterns() else {
484 return false;
485 };
486 patterns.iter().any(|pattern| {
487 if pattern == command {
488 return true;
489 }
490 glob::Pattern::new(pattern)
491 .map(|glob| glob.matches(command))
492 .unwrap_or(false)
493 })
494 }
495
496 pub fn is_dispatch_agent_pattern_preapproved(&self, agent_id: &str) -> bool {
497 let Some(patterns) = self.preapproved.dispatch_agent_rule() else {
498 return false;
499 };
500 patterns.iter().any(|pattern| {
501 if pattern == agent_id {
502 return true;
503 }
504 glob::Pattern::new(pattern)
505 .map(|glob| glob.matches(agent_id))
506 .unwrap_or(false)
507 })
508 }
509
510 pub fn pre_approved_tools(&self) -> &HashSet<String> {
511 &self.preapproved.tools
512 }
513}
514
515impl ToolApprovalPolicyOverrides {
516 pub fn apply_to(&self, base: &ToolApprovalPolicy) -> ToolApprovalPolicy {
517 let mut merged = base.clone();
518
519 if !self.preapproved.tools.is_empty() {
520 merged
521 .preapproved
522 .tools
523 .extend(self.preapproved.tools.iter().cloned());
524 }
525
526 for (tool_name, override_rule) in &self.preapproved.per_tool {
527 let base_rule = merged.preapproved.per_tool.get(tool_name);
528 let merged_rule = merge_tool_rule_override(base_rule, override_rule);
529 merged
530 .preapproved
531 .per_tool
532 .insert(tool_name.clone(), merged_rule);
533 }
534
535 merged
536 }
537}
538
539fn merge_tool_rule_override(
540 base: Option<&ToolRule>,
541 override_rule: &ToolRuleOverrides,
542) -> ToolRule {
543 match (base, override_rule) {
544 (Some(ToolRule::Bash { patterns }), ToolRuleOverrides::Bash { patterns: extra }) => {
545 ToolRule::Bash {
546 patterns: merge_patterns(patterns, extra),
547 }
548 }
549 (
550 Some(ToolRule::DispatchAgent { agent_patterns }),
551 ToolRuleOverrides::DispatchAgent {
552 agent_patterns: extra,
553 },
554 ) => ToolRule::DispatchAgent {
555 agent_patterns: merge_patterns(agent_patterns, extra),
556 },
557 (_, ToolRuleOverrides::Bash { patterns }) => ToolRule::Bash {
558 patterns: patterns.clone(),
559 },
560 (_, ToolRuleOverrides::DispatchAgent { agent_patterns }) => ToolRule::DispatchAgent {
561 agent_patterns: agent_patterns.clone(),
562 },
563 }
564}
565
/// Appends to `base` every entry of `extra` that is not already present.
///
/// Order is preserved: all of `base` first (including any duplicates it
/// already contains), then the new entries in the order they appear in
/// `extra`, each at most once.
fn merge_patterns(base: &[String], extra: &[String]) -> Vec<String> {
    extra.iter().fold(base.to_vec(), |mut combined, candidate| {
        if !combined.contains(candidate) {
            combined.push(candidate.clone());
        }
        combined
    })
}
575
/// Credentials for a remote workspace agent.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub enum RemoteAuth {
    /// Bearer-token authentication.
    Bearer { token: String },
    /// API-key authentication.
    ApiKey { key: String },
}
582
583impl RemoteAuth {
584 pub fn to_workspace_auth(&self) -> steer_workspace::RemoteAuth {
586 match self {
587 RemoteAuth::Bearer { token } => steer_workspace::RemoteAuth::BearerToken(token.clone()),
588 RemoteAuth::ApiKey { key } => steer_workspace::RemoteAuth::ApiKey(key.clone()),
589 }
590 }
591}
592
593#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema)]
595#[serde(rename_all = "snake_case")]
596#[derive(Default)]
597pub enum ToolFilter {
598 #[default]
600 All,
601 Include(Vec<String>),
603 Exclude(Vec<String>),
605}
606
/// Configuration of one tool backend.
///
/// Serialized as an internally tagged enum with the variant name in a `type`
/// field using `snake_case`. `Mcp` is currently the only variant.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum BackendConfig {
    /// A backend served by an MCP server.
    Mcp {
        /// Server name; the backend is registered as `mcp_<server_name>`.
        server_name: String,
        /// Transport used to reach the server.
        transport: McpTransport,
        /// Tool selection handed to the backend on creation.
        tool_filter: ToolFilter,
    },
}
617
/// Tool-related configuration of a session.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SessionToolConfig {
    /// Backends to initialize when the registry is built.
    pub backends: Vec<BackendConfig>,
    /// Which tool schemas are exposed.
    pub visibility: ToolVisibility,
    /// Policy deciding which tools run without asking.
    pub approval_policy: ToolApprovalPolicy,
    /// Free-form string metadata.
    pub metadata: HashMap<String, String>,
}
625
626impl Default for SessionToolConfig {
627 fn default() -> Self {
628 Self {
629 backends: Vec::new(),
630 visibility: ToolVisibility::All,
631 approval_policy: ToolApprovalPolicy::default(),
632 metadata: HashMap::new(),
633 }
634 }
635}
636
637impl SessionToolConfig {
638 pub fn read_only() -> Self {
639 Self {
640 backends: Vec::new(),
641 visibility: ToolVisibility::ReadOnly,
642 approval_policy: ToolApprovalPolicy::default(),
643 metadata: HashMap::new(),
644 }
645 }
646}
647
/// Mutable state accumulated over the lifetime of a session.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SessionState {
    /// Messages in the order they were added.
    pub messages: Vec<Message>,

    /// Tracked tool calls, keyed by tool call id.
    pub tool_calls: HashMap<String, ToolCallState>,

    /// Names of tools approved via `approve_tool`.
    pub approved_tools: HashSet<String>,

    /// Approved bash patterns; empty when absent from serialized data.
    #[serde(default)]
    pub approved_bash_patterns: HashSet<String>,

    /// Sequence number of the last event.
    /// NOTE(review): not written anywhere in this file — confirm who maintains it.
    pub last_event_sequence: u64,

    /// Free-form string metadata.
    pub metadata: HashMap<String, String>,

    /// Id of the active message, if any; omitted from serialized output when
    /// `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub active_message_id: Option<String>,

    /// Runtime-only MCP server status. Never serialized or deserialized, so it
    /// is empty after loading a session.
    #[serde(default, skip_serializing, skip_deserializing)]
    pub mcp_servers: HashMap<String, McpServerInfo>,
}
680
681impl SessionState {
682 pub fn add_message(&mut self, message: Message) {
684 self.messages.push(message);
685 }
686
687 pub fn message_count(&self) -> usize {
689 self.messages.len()
690 }
691
692 pub fn last_message(&self) -> Option<&Message> {
694 self.messages.last()
695 }
696
697 pub fn add_tool_call(&mut self, tool_call: ToolCall) {
699 let state = ToolCallState {
700 tool_call: tool_call.clone(),
701 status: ToolCallStatus::PendingApproval,
702 started_at: None,
703 completed_at: None,
704 result: None,
705 };
706 self.tool_calls.insert(tool_call.id, state);
707 }
708
709 pub fn update_tool_call_status(
711 &mut self,
712 tool_call_id: &str,
713 status: ToolCallStatus,
714 ) -> std::result::Result<(), String> {
715 let tool_call = self
716 .tool_calls
717 .get_mut(tool_call_id)
718 .ok_or_else(|| format!("Tool call not found: {tool_call_id}"))?;
719
720 match (&tool_call.status, &status) {
722 (_, ToolCallStatus::Executing) => {
723 tool_call.started_at = Some(Utc::now());
724 }
725 (_, ToolCallStatus::Completed | ToolCallStatus::Failed { .. }) => {
726 tool_call.completed_at = Some(Utc::now());
727 }
728 _ => {}
729 }
730
731 tool_call.status = status;
732 Ok(())
733 }
734
735 pub fn approve_tool(&mut self, tool_name: String) {
737 self.approved_tools.insert(tool_name);
738 }
739
740 pub fn is_tool_approved(&self, tool_name: &str) -> bool {
742 self.approved_tools.contains(tool_name)
743 }
744
745 pub fn validate(&self) -> std::result::Result<(), String> {
747 for message in &self.messages {
749 let tool_calls = Self::extract_tool_calls_from_message(message);
750 if !tool_calls.is_empty() {
751 for tool_call_id in tool_calls {
752 if !self.tool_calls.contains_key(&tool_call_id) {
753 return Err(format!(
754 "Message references unknown tool call: {tool_call_id}"
755 ));
756 }
757 }
758 }
759 }
760
761 Ok(())
762 }
763
764 fn extract_tool_calls_from_message(message: &Message) -> Vec<String> {
766 let mut tool_call_ids = Vec::new();
767
768 match &message.data {
769 MessageData::Assistant { content, .. } => {
770 for c in content {
771 if let crate::app::conversation::AssistantContent::ToolCall {
772 tool_call, ..
773 } = c
774 {
775 tool_call_ids.push(tool_call.id.clone());
776 }
777 }
778 }
779 MessageData::Tool { tool_use_id, .. } => {
780 tool_call_ids.push(tool_use_id.clone());
781 }
782 MessageData::User { .. } => {}
783 }
784
785 tool_call_ids
786 }
787}
788
/// Tracking record for one tool call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallState {
    /// The tool call being tracked.
    pub tool_call: ToolCall,
    /// Current status of the call.
    pub status: ToolCallStatus,
    /// UTC time at which the status was set to `Executing`, if it ever was.
    pub started_at: Option<DateTime<Utc>>,
    /// UTC time at which the status was set to `Completed` or `Failed`, if it
    /// ever was.
    pub completed_at: Option<DateTime<Utc>>,
    /// Result of the call, if one has been recorded.
    pub result: Option<ToolResult>,
}
798
799impl ToolCallState {
800 pub fn is_pending(&self) -> bool {
801 matches!(self.status, ToolCallStatus::PendingApproval)
802 }
803
804 pub fn is_complete(&self) -> bool {
805 matches!(
806 self.status,
807 ToolCallStatus::Completed | ToolCallStatus::Failed { .. }
808 )
809 }
810
811 pub fn duration(&self) -> Option<chrono::Duration> {
812 match (self.started_at, self.completed_at) {
813 (Some(start), Some(end)) => Some(end - start),
814 _ => None,
815 }
816 }
817}
818
/// Lifecycle status of a tool call.
///
/// Serialized as an internally tagged enum with the variant name in a
/// `status` field using `snake_case`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum ToolCallStatus {
    /// Initial status assigned by `SessionState::add_tool_call`.
    PendingApproval,
    /// The call has been approved.
    Approved,
    /// The call has been denied; terminal.
    Denied,
    /// The call is running.
    Executing,
    /// The call finished successfully; terminal.
    Completed,
    /// The call finished with an error; terminal.
    Failed { error: String },
}
830
831impl ToolCallStatus {
832 pub fn is_terminal(&self) -> bool {
833 matches!(
834 self,
835 ToolCallStatus::Completed | ToolCallStatus::Failed { .. } | ToolCallStatus::Denied
836 )
837 }
838}
839
/// Outcome and timing of a single tool execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolExecutionStats {
    /// Text output; for failures this holds the error text. Omitted from
    /// serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output: Option<String>,
    /// Structured output; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json_output: Option<serde_json::Value>,
    /// Name of the result type accompanying `json_output`, if any.
    pub result_type: Option<String>,
    /// Whether the execution succeeded.
    pub success: bool,
    /// Execution time in milliseconds.
    pub execution_time_ms: u64,
    /// Free-form string metadata.
    pub metadata: HashMap<String, String>,
}
852
853impl ToolExecutionStats {
854 pub fn success(output: String, execution_time_ms: u64) -> Self {
855 Self {
856 output: Some(output),
857 json_output: None,
858 result_type: None,
859 success: true,
860 execution_time_ms,
861 metadata: HashMap::new(),
862 }
863 }
864
865 pub fn success_typed(
866 json_output: serde_json::Value,
867 result_type: String,
868 execution_time_ms: u64,
869 ) -> Self {
870 Self {
871 output: None,
872 json_output: Some(json_output),
873 result_type: Some(result_type),
874 success: true,
875 execution_time_ms,
876 metadata: HashMap::new(),
877 }
878 }
879
880 pub fn failure(error: String, execution_time_ms: u64) -> Self {
881 Self {
882 output: Some(error),
883 json_output: None,
884 result_type: None,
885 success: false,
886 execution_time_ms,
887 metadata: HashMap::new(),
888 }
889 }
890
891 pub fn with_metadata(mut self, key: String, value: String) -> Self {
892 self.metadata.insert(key, value);
893 self
894 }
895}
896
/// Lightweight summary of a session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionInfo {
    /// Session identifier.
    pub id: String,
    /// UTC time at which the session was created.
    pub created_at: DateTime<Utc>,
    /// UTC time at which the session was last updated.
    pub updated_at: DateTime<Utc>,
    /// Last model used, if known. The `From<&Session>` conversion in this
    /// file always leaves it as `None`.
    pub last_model: Option<ModelId>,
    /// Number of messages in the session.
    pub message_count: usize,
    /// Copy of the session configuration's metadata.
    pub metadata: HashMap<String, String>,
}
908
909impl From<&Session> for SessionInfo {
910 fn from(session: &Session) -> Self {
911 Self {
912 id: session.id.clone(),
913 created_at: session.created_at,
914 updated_at: session.updated_at,
915 last_model: None, message_count: session.state.message_count(),
917 metadata: session.config.metadata.clone(),
918 }
919 }
920}
921
922#[cfg(test)]
923mod tests {
924 use super::*;
925 use crate::app::conversation::{Message, MessageData, UserContent};
926 use crate::config::model::builtin::claude_sonnet_4_5 as test_model;
927 use crate::tools::DISPATCH_AGENT_TOOL_NAME;
928 use crate::tools::static_tools::READ_ONLY_TOOL_NAMES;
929 use steer_tools::tools::{BASH_TOOL_NAME, EDIT_TOOL_NAME};
930
931 #[test]
932 fn test_session_creation() {
933 let config = SessionConfig {
934 workspace: WorkspaceConfig::Local {
935 path: PathBuf::from("/test/path"),
936 },
937 workspace_ref: None,
938 workspace_id: None,
939 repo_ref: None,
940 parent_session_id: None,
941 workspace_name: None,
942 tool_config: SessionToolConfig::default(),
943 system_prompt: None,
944 primary_agent_id: None,
945 policy_overrides: SessionPolicyOverrides::empty(),
946 metadata: HashMap::new(),
947 default_model: test_model(),
948 auto_compaction: AutoCompactionConfig::default(),
949 };
950 let session = Session::new("test-session".to_string(), config.clone());
951
952 assert_eq!(session.id, "test-session");
953 assert_eq!(
954 session
955 .config
956 .tool_config
957 .approval_policy
958 .tool_decision("any_tool"),
959 ToolDecision::Ask
960 );
961 assert_eq!(session.state.message_count(), 0);
962 }
963
964 #[test]
965 fn test_tool_approval_policy_prompt_unapproved() {
966 let policy = ToolApprovalPolicy {
967 default_behavior: UnapprovedBehavior::Prompt,
968 preapproved: ApprovalRules {
969 tools: ["read_file", "list_files"]
970 .iter()
971 .map(|s| (*s).to_string())
972 .collect(),
973 per_tool: HashMap::new(),
974 },
975 };
976
977 assert_eq!(policy.tool_decision("read_file"), ToolDecision::Allow);
978 assert_eq!(policy.tool_decision("write_file"), ToolDecision::Ask);
979 }
980
981 #[test]
982 fn test_tool_approval_policy_deny_unapproved() {
983 let policy = ToolApprovalPolicy {
984 default_behavior: UnapprovedBehavior::Deny,
985 preapproved: ApprovalRules {
986 tools: ["read_file", "list_files"]
987 .iter()
988 .map(|s| (*s).to_string())
989 .collect(),
990 per_tool: HashMap::new(),
991 },
992 };
993
994 assert_eq!(policy.tool_decision("read_file"), ToolDecision::Allow);
995 assert_eq!(policy.tool_decision("write_file"), ToolDecision::Deny);
996 }
997
998 #[test]
999 fn test_tool_approval_policy_default() {
1000 let policy = ToolApprovalPolicy::default();
1001
1002 assert_eq!(
1003 policy.tool_decision(READ_ONLY_TOOL_NAMES[0]),
1004 ToolDecision::Allow
1005 );
1006 assert_eq!(policy.tool_decision(BASH_TOOL_NAME), ToolDecision::Ask);
1007 }
1008
1009 #[test]
1010 fn test_tool_approval_policy_allow_unapproved() {
1011 let policy = ToolApprovalPolicy {
1012 default_behavior: UnapprovedBehavior::Allow,
1013 preapproved: ApprovalRules {
1014 tools: ["read_file", "list_files"]
1015 .iter()
1016 .map(|s| (*s).to_string())
1017 .collect(),
1018 per_tool: HashMap::new(),
1019 },
1020 };
1021
1022 assert_eq!(policy.tool_decision("read_file"), ToolDecision::Allow);
1023 assert_eq!(policy.tool_decision("write_file"), ToolDecision::Allow);
1024 }
1025
1026 #[test]
1027 fn test_tool_approval_policy_overrides_union_rules() {
1028 let base_policy = ToolApprovalPolicy {
1029 default_behavior: UnapprovedBehavior::Prompt,
1030 preapproved: ApprovalRules {
1031 tools: ["read_file"].iter().map(|s| (*s).to_string()).collect(),
1032 per_tool: [
1033 (
1034 BASH_TOOL_NAME.to_string(),
1035 ToolRule::Bash {
1036 patterns: vec!["git status".to_string()],
1037 },
1038 ),
1039 (
1040 DISPATCH_AGENT_TOOL_NAME.to_string(),
1041 ToolRule::DispatchAgent {
1042 agent_patterns: vec!["explore".to_string()],
1043 },
1044 ),
1045 ]
1046 .into_iter()
1047 .collect(),
1048 },
1049 };
1050
1051 let overrides = ToolApprovalPolicyOverrides {
1052 preapproved: ApprovalRulesOverrides {
1053 tools: ["write_file"].iter().map(|s| (*s).to_string()).collect(),
1054 per_tool: [
1055 (
1056 BASH_TOOL_NAME.to_string(),
1057 ToolRuleOverrides::Bash {
1058 patterns: vec!["git log".to_string()],
1059 },
1060 ),
1061 (
1062 DISPATCH_AGENT_TOOL_NAME.to_string(),
1063 ToolRuleOverrides::DispatchAgent {
1064 agent_patterns: vec!["review".to_string()],
1065 },
1066 ),
1067 ]
1068 .into_iter()
1069 .collect(),
1070 },
1071 };
1072
1073 let merged = overrides.apply_to(&base_policy);
1074
1075 assert_eq!(merged.default_behavior, UnapprovedBehavior::Prompt);
1076 assert!(merged.preapproved.tools.contains("read_file"));
1077 assert!(merged.preapproved.tools.contains("write_file"));
1078
1079 let bash_patterns = match merged
1080 .preapproved
1081 .per_tool
1082 .get(BASH_TOOL_NAME)
1083 .expect("bash rule")
1084 {
1085 ToolRule::Bash { patterns } => patterns,
1086 ToolRule::DispatchAgent { .. } => {
1087 panic!("Unexpected bash rule: dispatch agent")
1088 }
1089 };
1090 assert!(bash_patterns.contains(&"git status".to_string()));
1091 assert!(bash_patterns.contains(&"git log".to_string()));
1092 assert_eq!(bash_patterns.len(), 2);
1093
1094 let agent_patterns = match merged
1095 .preapproved
1096 .per_tool
1097 .get(DISPATCH_AGENT_TOOL_NAME)
1098 .expect("dispatch_agent rule")
1099 {
1100 ToolRule::DispatchAgent { agent_patterns } => agent_patterns,
1101 ToolRule::Bash { .. } => panic!("Unexpected dispatch_agent rule: bash"),
1102 };
1103 assert!(agent_patterns.contains(&"explore".to_string()));
1104 assert!(agent_patterns.contains(&"review".to_string()));
1105 assert_eq!(agent_patterns.len(), 2);
1106 }
1107
1108 #[test]
1109 fn test_bash_pattern_matching() {
1110 let policy = ToolApprovalPolicy {
1111 default_behavior: UnapprovedBehavior::Prompt,
1112 preapproved: ApprovalRules {
1113 tools: HashSet::new(),
1114 per_tool: [(
1115 "bash".to_string(),
1116 ToolRule::Bash {
1117 patterns: vec![
1118 "git status".to_string(),
1119 "git log*".to_string(),
1120 "git * --oneline".to_string(),
1121 "ls -?a*".to_string(),
1122 "cargo build*".to_string(),
1123 ],
1124 },
1125 )]
1126 .into_iter()
1127 .collect(),
1128 },
1129 };
1130
1131 assert!(policy.is_bash_pattern_preapproved("git status"));
1132 assert!(policy.is_bash_pattern_preapproved("git log --oneline"));
1133 assert!(policy.is_bash_pattern_preapproved("git show --oneline"));
1134 assert!(policy.is_bash_pattern_preapproved("ls -la"));
1135 assert!(policy.is_bash_pattern_preapproved("cargo build --release"));
1136 assert!(!policy.is_bash_pattern_preapproved("git commit"));
1137 assert!(!policy.is_bash_pattern_preapproved("ls -l"));
1138 assert!(!policy.is_bash_pattern_preapproved("rm -rf /"));
1139 }
1140
1141 #[test]
1142 fn test_dispatch_agent_pattern_matching() {
1143 let policy = ToolApprovalPolicy {
1144 default_behavior: UnapprovedBehavior::Prompt,
1145 preapproved: ApprovalRules {
1146 tools: HashSet::new(),
1147 per_tool: [(
1148 "dispatch_agent".to_string(),
1149 ToolRule::DispatchAgent {
1150 agent_patterns: vec!["explore".to_string(), "explore-*".to_string()],
1151 },
1152 )]
1153 .into_iter()
1154 .collect(),
1155 },
1156 };
1157
1158 assert!(policy.is_dispatch_agent_pattern_preapproved("explore"));
1159 assert!(policy.is_dispatch_agent_pattern_preapproved("explore-fast"));
1160 assert!(!policy.is_dispatch_agent_pattern_preapproved("build"));
1161 }
1162
1163 #[test]
1164 fn test_session_state_validation() {
1165 let mut state = SessionState::default();
1166
1167 assert!(state.validate().is_ok());
1169
1170 let message = Message {
1172 data: MessageData::User {
1173 content: vec![UserContent::Text {
1174 text: "Hello".to_string(),
1175 }],
1176 },
1177 timestamp: 123_456_789,
1178 id: "msg1".to_string(),
1179 parent_message_id: None,
1180 };
1181 state.add_message(message);
1182
1183 assert!(state.validate().is_ok());
1184 assert_eq!(state.message_count(), 1);
1185 }
1186
1187 #[test]
1188 fn test_tool_call_state_tracking() {
1189 let mut state = SessionState::default();
1190
1191 let tool_call = ToolCall {
1192 id: "tool1".to_string(),
1193 name: "read_file".to_string(),
1194 parameters: serde_json::json!({"path": "/test.txt"}),
1195 };
1196
1197 state.add_tool_call(tool_call.clone());
1198 assert!(state.tool_calls.get("tool1").unwrap().is_pending());
1199
1200 state
1201 .update_tool_call_status("tool1", ToolCallStatus::Executing)
1202 .unwrap();
1203 let tool_state = state.tool_calls.get("tool1").unwrap();
1204 assert!(tool_state.started_at.is_some());
1205 assert!(!tool_state.is_complete());
1206
1207 state
1208 .update_tool_call_status("tool1", ToolCallStatus::Completed)
1209 .unwrap();
1210 let tool_state = state.tool_calls.get("tool1").unwrap();
1211 assert!(tool_state.completed_at.is_some());
1212 assert!(tool_state.is_complete());
1213 }
1214
1215 #[test]
1216 fn test_session_tool_config_default() {
1217 let config = SessionToolConfig::default();
1218 assert!(config.backends.is_empty());
1219 }
1220
1221 #[test]
1222 fn test_tool_filter_exclude() {
1223 let excluded =
1224 ToolFilter::Exclude(vec![BASH_TOOL_NAME.to_string(), EDIT_TOOL_NAME.to_string()]);
1225
1226 if let ToolFilter::Exclude(tools) = &excluded {
1227 assert_eq!(tools.len(), 2);
1228 assert!(tools.contains(&BASH_TOOL_NAME.to_string()));
1229 assert!(tools.contains(&EDIT_TOOL_NAME.to_string()));
1230 } else {
1231 panic!("Expected ToolFilter::Exclude");
1232 }
1233 }
1234
/// The `read_only` preset has zero backends, `ReadOnly` visibility, and an
/// approval policy whose default behavior is `Prompt`.
#[test]
fn test_session_tool_config_read_only() {
    let read_only = SessionToolConfig::read_only();

    let backend_count = read_only.backends.len();
    assert_eq!(backend_count, 0);

    let has_read_only_visibility = matches!(read_only.visibility, ToolVisibility::ReadOnly);
    assert!(has_read_only_visibility);

    let default_behavior = &read_only.approval_policy.default_behavior;
    assert_eq!(*default_behavior, UnapprovedBehavior::Prompt);
}
1245
1246 #[tokio::test]
1247 async fn test_session_config_build_registry_no_default_backends() {
1248 let config = SessionConfig {
1252 workspace: WorkspaceConfig::Local {
1253 path: PathBuf::from("/test/path"),
1254 },
1255 workspace_ref: None,
1256 workspace_id: None,
1257 repo_ref: None,
1258 parent_session_id: None,
1259 workspace_name: None,
1260 tool_config: SessionToolConfig::default(), system_prompt: None,
1262 primary_agent_id: None,
1263 policy_overrides: SessionPolicyOverrides::empty(),
1264 metadata: HashMap::new(),
1265 default_model: test_model(),
1266 auto_compaction: AutoCompactionConfig::default(),
1267 };
1268
1269 let (registry, _mcp_servers) = config.build_registry().await.unwrap();
1270 let schemas = registry.get_tool_schemas().await;
1271
1272 assert!(
1273 schemas.is_empty(),
1274 "BackendRegistry should be empty with default config; got: {:?}",
1275 schemas.iter().map(|s| &s.name).collect::<Vec<_>>()
1276 );
1277 }
1278
/// Inserts one connected and one failed MCP server record into a
/// `SessionState` and checks that each can be read back with the state it was
/// stored with.
#[test]
fn test_mcp_status_tracking() {
    let mut session_state = SessionState::default();

    let connected_info = McpServerInfo {
        server_name: "test-server".to_string(),
        transport: crate::tools::McpTransport::Stdio {
            command: "python".to_string(),
            args: vec!["-m".to_string(), "test_server".to_string()],
        },
        state: McpConnectionState::Connected {
            tool_names: vec![
                "tool1".to_string(),
                "tool2".to_string(),
                "tool3".to_string(),
                "tool4".to_string(),
                "tool5".to_string(),
            ],
        },
        last_updated: Utc::now(),
    };

    // The record is not used again after insertion, so it is moved rather
    // than cloned.
    session_state
        .mcp_servers
        .insert("test-server".to_string(), connected_info);

    assert_eq!(session_state.mcp_servers.len(), 1);
    let stored = session_state.mcp_servers.get("test-server").unwrap();
    assert_eq!(stored.server_name, "test-server");
    assert!(matches!(
        stored.state,
        McpConnectionState::Connected { ref tool_names } if tool_names.len() == 5
    ));

    let failed_info = McpServerInfo {
        server_name: "failed-server".to_string(),
        transport: crate::tools::McpTransport::Tcp {
            host: "localhost".to_string(),
            port: 9999,
        },
        state: McpConnectionState::Failed {
            error: "Connection refused".to_string(),
        },
        last_updated: Utc::now(),
    };

    session_state
        .mcp_servers
        .insert("failed-server".to_string(), failed_info);
    assert_eq!(session_state.mcp_servers.len(), 2);

    // Counting entries alone would not catch a record stored under the wrong
    // key or with the wrong state, so read the failed record back as well.
    let failed = session_state.mcp_servers.get("failed-server").unwrap();
    assert_eq!(failed.server_name, "failed-server");
    assert!(matches!(
        failed.state,
        McpConnectionState::Failed { ref error } if error == "Connection refused"
    ));
}
1340
1341 #[tokio::test]
1342 async fn test_mcp_server_tracking_in_build_registry() {
1343 let mut config = SessionConfig::read_only(test_model());
1345
1346 config.tool_config.backends.push(BackendConfig::Mcp {
1348 server_name: "bad-server".to_string(),
1349 transport: crate::tools::McpTransport::Tcp {
1350 host: "nonexistent.invalid".to_string(),
1351 port: 12345,
1352 },
1353 tool_filter: ToolFilter::All,
1354 });
1355
1356 config.tool_config.backends.push(BackendConfig::Mcp {
1358 server_name: "good-server".to_string(),
1359 transport: crate::tools::McpTransport::Stdio {
1360 command: "echo".to_string(),
1361 args: vec!["test".to_string()],
1362 },
1363 tool_filter: ToolFilter::All,
1364 });
1365
1366 let (_registry, mcp_servers) = config.build_registry().await.unwrap();
1367
1368 assert_eq!(mcp_servers.len(), 2);
1370
1371 let bad_server = mcp_servers.get("bad-server").unwrap();
1373 assert_eq!(bad_server.server_name, "bad-server");
1374 assert!(matches!(
1375 bad_server.state,
1376 McpConnectionState::Failed { .. }
1377 ));
1378
1379 let good_server = mcp_servers.get("good-server").unwrap();
1381 assert_eq!(good_server.server_name, "good-server");
1382 assert!(matches!(
1383 good_server.state,
1384 McpConnectionState::Failed { .. }
1385 ));
1386 }
1387
/// Constructs a `BackendConfig::Mcp` with a stdio transport, destructures it,
/// and checks the server name, command, and argument count.
#[test]
fn test_backend_config_mcp_variant() {
    let stdio = crate::tools::McpTransport::Stdio {
        command: "python".to_string(),
        args: vec!["-m".to_string(), "test_server".to_string()],
    };
    let mcp_config = BackendConfig::Mcp {
        server_name: "test-mcp".to_string(),
        transport: stdio,
        tool_filter: ToolFilter::All,
    };

    let BackendConfig::Mcp {
        server_name,
        transport,
        ..
    } = mcp_config;
    assert_eq!(server_name, "test-mcp");

    match transport {
        crate::tools::McpTransport::Stdio { command, args } => {
            assert_eq!(command, "python");
            assert_eq!(args.len(), 2);
        }
        _ => panic!("Expected Stdio transport"),
    }
}
1413}