1use crate::config::model::ModelId;
2use crate::error::Result;
3use crate::tools::builtin_tools::READ_ONLY_TOOL_NAMES;
4use chrono::{DateTime, Utc};
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, HashSet};
8use std::path::PathBuf;
9use std::sync::Arc;
10
11use crate::app::{Message, MessageData};
12use crate::tools::{BackendRegistry, McpTransport, ToolBackend};
13use steer_tools::{ToolCall, result::ToolResult};
14
/// Connection status of a single MCP server as tracked by a session.
///
/// `SessionConfig::build_registry` starts every server in `Connecting` and
/// then moves it to `Connected` or `Failed`.
#[derive(Debug, Clone)]
pub enum McpConnectionState {
    /// Initialization of the backend has started but not finished.
    Connecting,
    /// The backend initialized successfully.
    Connected {
        /// Tool names reported by the backend after it came up.
        tool_names: Vec<String>,
    },
    /// The connection is gone.
    Disconnected {
        /// Optional human-readable explanation.
        reason: Option<String>,
    },
    /// The backend could not be initialized.
    Failed {
        /// Stringified initialization error.
        error: String,
    },
}
36
37#[derive(Debug, Clone)]
39pub struct McpServerInfo {
40 pub server_name: String,
42 pub transport: McpTransport,
44 pub state: McpConnectionState,
46 pub last_updated: DateTime<Utc>,
48}
49
50#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
52#[serde(tag = "type", rename_all = "snake_case")]
53pub enum WorkspaceConfig {
54 Local {
55 path: PathBuf,
56 },
57 Remote {
58 agent_address: String,
59 auth: Option<RemoteAuth>,
60 },
61}
62
63impl WorkspaceConfig {
64 pub fn get_path(&self) -> Option<String> {
65 match self {
66 WorkspaceConfig::Local { path } => Some(path.to_string_lossy().to_string()),
67 WorkspaceConfig::Remote { agent_address, .. } => Some(agent_address.clone()),
68 }
69 }
70
71 pub fn to_workspace_config(&self) -> steer_workspace::WorkspaceConfig {
73 match self {
74 WorkspaceConfig::Local { path } => {
75 steer_workspace::WorkspaceConfig::Local { path: path.clone() }
76 }
77 WorkspaceConfig::Remote {
78 agent_address,
79 auth,
80 } => steer_workspace::WorkspaceConfig::Remote {
81 address: agent_address.clone(),
82 auth: auth.as_ref().map(|a| a.to_workspace_auth()),
83 },
84 }
85 }
86}
87
88impl Default for WorkspaceConfig {
89 fn default() -> Self {
90 Self::Local {
91 path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
92 }
93 }
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct Session {
99 pub id: String,
100 pub created_at: DateTime<Utc>,
101 pub updated_at: DateTime<Utc>,
102 pub config: SessionConfig,
103 pub state: SessionState,
104}
105
106impl Session {
107 pub fn new(id: String, config: SessionConfig) -> Self {
108 let now = Utc::now();
109 Self {
110 id,
111 created_at: now,
112 updated_at: now,
113 config,
114 state: SessionState::default(),
115 }
116 }
117
118 pub fn update_timestamp(&mut self) {
119 self.updated_at = Utc::now();
120 }
121
122 pub fn is_recently_active(&self, threshold: chrono::Duration) -> bool {
124 let cutoff = Utc::now() - threshold;
125 self.updated_at > cutoff
126 }
127
128 pub async fn build_workspace(&self) -> Result<Arc<dyn crate::workspace::Workspace>> {
130 crate::workspace::create_workspace(&self.config.workspace.to_workspace_config()).await
131 }
132}
133
134#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
136pub struct SessionConfig {
137 pub workspace: WorkspaceConfig,
138 #[serde(default)]
139 pub workspace_ref: Option<crate::workspace::WorkspaceRef>,
140 #[serde(default)]
141 pub workspace_id: Option<crate::workspace::WorkspaceId>,
142 #[serde(default)]
143 pub repo_ref: Option<crate::workspace::RepoRef>,
144 #[serde(default)]
145 pub parent_session_id: Option<crate::app::domain::types::SessionId>,
146 #[serde(default)]
147 pub workspace_name: Option<String>,
148 pub tool_config: SessionToolConfig,
149 pub system_prompt: Option<String>,
152 #[serde(default)]
154 pub primary_agent_id: Option<String>,
155 #[serde(default = "SessionPolicyOverrides::empty")]
157 pub policy_overrides: SessionPolicyOverrides,
158 #[serde(default, skip_serializing_if = "Option::is_none")]
159 pub title: Option<String>,
160 pub metadata: HashMap<String, String>,
161 pub default_model: ModelId,
162 #[serde(default)]
163 pub auto_compaction: AutoCompactionConfig,
164}
165
166impl SessionConfig {
167 pub async fn build_registry(
170 &self,
171 ) -> Result<(BackendRegistry, HashMap<String, McpServerInfo>)> {
172 let mut registry = BackendRegistry::new();
173 let mut mcp_servers = HashMap::new();
174
175 for backend_config in &self.tool_config.backends {
176 let BackendConfig::Mcp {
177 server_name,
178 transport,
179 tool_filter,
180 } = backend_config;
181
182 tracing::info!(
183 "Attempting to initialize MCP backend '{}' with transport: {:?}",
184 server_name,
185 transport
186 );
187
188 let mut server_info = McpServerInfo {
189 server_name: server_name.clone(),
190 transport: transport.clone(),
191 state: McpConnectionState::Connecting,
192 last_updated: Utc::now(),
193 };
194
195 match crate::tools::McpBackend::new(
196 server_name.clone(),
197 transport.clone(),
198 tool_filter.clone(),
199 )
200 .await
201 {
202 Ok(mcp_backend) => {
203 let tool_names = mcp_backend.supported_tools().await;
204 let tool_count = tool_names.len();
205 tracing::info!(
206 "Successfully initialized MCP backend '{}' with {} tools",
207 server_name,
208 tool_count
209 );
210 server_info.state = McpConnectionState::Connected { tool_names };
211 server_info.last_updated = Utc::now();
212 registry
213 .register(format!("mcp_{server_name}"), Arc::new(mcp_backend))
214 .await;
215 }
216 Err(e) => {
217 tracing::error!("Failed to initialize MCP backend '{}': {}", server_name, e);
218 server_info.state = McpConnectionState::Failed {
219 error: e.to_string(),
220 };
221 server_info.last_updated = Utc::now();
222 }
223 }
224
225 mcp_servers.insert(server_name.clone(), server_info);
226 }
227
228 Ok((registry, mcp_servers))
229 }
230
231 pub fn filter_tools_by_visibility(
233 &self,
234 tools: Vec<steer_tools::ToolSchema>,
235 ) -> Vec<steer_tools::ToolSchema> {
236 match &self.tool_config.visibility {
237 ToolVisibility::All => tools,
238 ToolVisibility::ReadOnly => {
239 let read_only_names: HashSet<String> = READ_ONLY_TOOL_NAMES
240 .iter()
241 .map(|name| (*name).to_string())
242 .collect();
243
244 tools
245 .into_iter()
246 .filter(|schema| read_only_names.contains(&schema.name))
247 .collect()
248 }
249 ToolVisibility::Whitelist(allowed) => tools
250 .into_iter()
251 .filter(|schema| allowed.contains(&schema.name))
252 .collect(),
253 ToolVisibility::Blacklist(blocked) => tools
254 .into_iter()
255 .filter(|schema| !blocked.contains(&schema.name))
256 .collect(),
257 }
258 }
259
260 #[cfg(test)]
262 pub fn read_only(default_model: ModelId) -> Self {
263 Self {
264 workspace: WorkspaceConfig::Local {
265 path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
266 },
267 workspace_ref: None,
268 workspace_id: None,
269 repo_ref: None,
270 parent_session_id: None,
271 workspace_name: None,
272 tool_config: SessionToolConfig::read_only(),
273 system_prompt: None,
274 primary_agent_id: None,
275 policy_overrides: SessionPolicyOverrides::empty(),
276 title: None,
277 metadata: HashMap::new(),
278 default_model,
279 auto_compaction: AutoCompactionConfig::default(),
280 }
281 }
282}
283
284#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
286pub struct AutoCompactionConfig {
287 pub enabled: bool,
288 pub threshold_percent: u32,
289}
290
291impl Default for AutoCompactionConfig {
292 fn default() -> Self {
293 Self {
294 enabled: true,
295 threshold_percent: 90,
296 }
297 }
298}
299
300impl AutoCompactionConfig {
301 pub fn threshold_ratio(&self) -> f64 {
302 f64::from(self.threshold_percent) / 100.0
303 }
304}
305
306#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
308pub struct SessionPolicyOverrides {
309 #[serde(default, skip_serializing_if = "Option::is_none")]
310 pub default_model: Option<ModelId>,
311 #[serde(default, skip_serializing_if = "Option::is_none")]
312 pub tool_visibility: Option<ToolVisibility>,
313 #[serde(default = "ToolApprovalPolicyOverrides::empty")]
314 pub approval_policy: ToolApprovalPolicyOverrides,
315}
316
317impl SessionPolicyOverrides {
318 pub fn empty() -> Self {
319 Self {
320 default_model: None,
321 tool_visibility: None,
322 approval_policy: ToolApprovalPolicyOverrides::empty(),
323 }
324 }
325
326 pub fn is_empty(&self) -> bool {
327 self.default_model.is_none()
328 && self.tool_visibility.is_none()
329 && self.approval_policy.is_empty()
330 }
331}
332
333#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
334pub struct ToolApprovalPolicyOverrides {
335 #[serde(default = "ApprovalRulesOverrides::empty")]
336 pub preapproved: ApprovalRulesOverrides,
337}
338
339impl ToolApprovalPolicyOverrides {
340 pub fn empty() -> Self {
341 Self {
342 preapproved: ApprovalRulesOverrides::empty(),
343 }
344 }
345
346 pub fn is_empty(&self) -> bool {
347 self.preapproved.is_empty()
348 }
349}
350
351#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
352pub struct ApprovalRulesOverrides {
353 #[serde(default)]
354 pub tools: HashSet<String>,
355 #[serde(default)]
356 pub per_tool: HashMap<String, ToolRuleOverrides>,
357}
358
359impl ApprovalRulesOverrides {
360 pub fn empty() -> Self {
361 Self {
362 tools: HashSet::new(),
363 per_tool: HashMap::new(),
364 }
365 }
366
367 pub fn is_empty(&self) -> bool {
368 self.tools.is_empty() && self.per_tool.is_empty()
369 }
370}
371
372#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
373#[serde(tag = "type", rename_all = "snake_case")]
374pub enum ToolRuleOverrides {
375 Bash { patterns: Vec<String> },
376 DispatchAgent { agent_patterns: Vec<String> },
377}
378
379#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
381#[serde(tag = "type", content = "tools", rename_all = "snake_case")]
382pub enum ToolVisibility {
383 #[default]
385 All,
386
387 ReadOnly,
389
390 Whitelist(HashSet<String>),
392
393 Blacklist(HashSet<String>),
395}
396
/// Outcome of evaluating a tool call against a `ToolApprovalPolicy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolDecision {
    /// Run without asking.
    Allow,
    /// Ask the user first.
    Ask,
    /// Refuse the call.
    Deny,
}
403
404#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Default)]
405#[serde(rename_all = "snake_case")]
406pub enum UnapprovedBehavior {
407 #[default]
408 Prompt,
409 Deny,
410 Allow,
411}
412
413#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
414#[serde(tag = "type", rename_all = "snake_case")]
415pub enum ToolRule {
416 Bash { patterns: Vec<String> },
417 DispatchAgent { agent_patterns: Vec<String> },
418}
419
420#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Default)]
421pub struct ApprovalRules {
422 #[serde(default)]
423 pub tools: HashSet<String>,
424 #[serde(default)]
425 pub per_tool: HashMap<String, ToolRule>,
426}
427
428impl ApprovalRules {
429 pub fn is_empty(&self) -> bool {
430 self.tools.is_empty() && self.per_tool.is_empty()
431 }
432
433 pub fn bash_patterns(&self) -> Option<&[String]> {
434 self.per_tool.get("bash").and_then(|rule| match rule {
435 ToolRule::Bash { patterns } => Some(patterns.as_slice()),
436 ToolRule::DispatchAgent { .. } => None,
437 })
438 }
439
440 pub fn dispatch_agent_rule(&self) -> Option<&[String]> {
441 self.per_tool
442 .get("dispatch_agent")
443 .and_then(|rule| match rule {
444 ToolRule::DispatchAgent { agent_patterns } => Some(agent_patterns.as_slice()),
445 ToolRule::Bash { .. } => None,
446 })
447 }
448}
449
450#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
451pub struct ToolApprovalPolicy {
452 pub default_behavior: UnapprovedBehavior,
453 #[serde(default)]
454 pub preapproved: ApprovalRules,
455}
456
457impl Default for ToolApprovalPolicy {
458 fn default() -> Self {
459 Self {
460 default_behavior: UnapprovedBehavior::Prompt,
461 preapproved: ApprovalRules {
462 tools: READ_ONLY_TOOL_NAMES
463 .iter()
464 .map(|name| (*name).to_string())
465 .collect(),
466 per_tool: HashMap::new(),
467 },
468 }
469 }
470}
471
472impl ToolApprovalPolicy {
473 pub fn tool_decision(&self, tool_name: &str) -> ToolDecision {
474 if self.preapproved.tools.contains(tool_name) {
475 ToolDecision::Allow
476 } else {
477 match self.default_behavior {
478 UnapprovedBehavior::Prompt => ToolDecision::Ask,
479 UnapprovedBehavior::Deny => ToolDecision::Deny,
480 UnapprovedBehavior::Allow => ToolDecision::Allow,
481 }
482 }
483 }
484
485 pub fn is_bash_pattern_preapproved(&self, command: &str) -> bool {
486 let Some(patterns) = self.preapproved.bash_patterns() else {
487 return false;
488 };
489 patterns.iter().any(|pattern| {
490 if pattern == command {
491 return true;
492 }
493 glob::Pattern::new(pattern)
494 .map(|glob| glob.matches(command))
495 .unwrap_or(false)
496 })
497 }
498
499 pub fn is_dispatch_agent_pattern_preapproved(&self, agent_id: &str) -> bool {
500 let Some(patterns) = self.preapproved.dispatch_agent_rule() else {
501 return false;
502 };
503 patterns.iter().any(|pattern| {
504 if pattern == agent_id {
505 return true;
506 }
507 glob::Pattern::new(pattern)
508 .map(|glob| glob.matches(agent_id))
509 .unwrap_or(false)
510 })
511 }
512
513 pub fn pre_approved_tools(&self) -> &HashSet<String> {
514 &self.preapproved.tools
515 }
516}
517
518impl ToolApprovalPolicyOverrides {
519 pub fn apply_to(&self, base: &ToolApprovalPolicy) -> ToolApprovalPolicy {
520 let mut merged = base.clone();
521
522 if !self.preapproved.tools.is_empty() {
523 merged
524 .preapproved
525 .tools
526 .extend(self.preapproved.tools.iter().cloned());
527 }
528
529 for (tool_name, override_rule) in &self.preapproved.per_tool {
530 let base_rule = merged.preapproved.per_tool.get(tool_name);
531 let merged_rule = merge_tool_rule_override(base_rule, override_rule);
532 merged
533 .preapproved
534 .per_tool
535 .insert(tool_name.clone(), merged_rule);
536 }
537
538 merged
539 }
540}
541
542fn merge_tool_rule_override(
543 base: Option<&ToolRule>,
544 override_rule: &ToolRuleOverrides,
545) -> ToolRule {
546 match (base, override_rule) {
547 (Some(ToolRule::Bash { patterns }), ToolRuleOverrides::Bash { patterns: extra }) => {
548 ToolRule::Bash {
549 patterns: merge_patterns(patterns, extra),
550 }
551 }
552 (
553 Some(ToolRule::DispatchAgent { agent_patterns }),
554 ToolRuleOverrides::DispatchAgent {
555 agent_patterns: extra,
556 },
557 ) => ToolRule::DispatchAgent {
558 agent_patterns: merge_patterns(agent_patterns, extra),
559 },
560 (_, ToolRuleOverrides::Bash { patterns }) => ToolRule::Bash {
561 patterns: patterns.clone(),
562 },
563 (_, ToolRuleOverrides::DispatchAgent { agent_patterns }) => ToolRule::DispatchAgent {
564 agent_patterns: agent_patterns.clone(),
565 },
566 }
567}
568
/// Returns `base` followed by every pattern from `extra` that is not already
/// present. Duplicates already in `base` are kept; duplicates coming from
/// `extra` are added only once.
fn merge_patterns(base: &[String], extra: &[String]) -> Vec<String> {
    let mut seen: HashSet<&str> = base.iter().map(String::as_str).collect();
    let mut merged = base.to_vec();
    merged.extend(
        extra
            .iter()
            .filter(|pattern| seen.insert(pattern.as_str()))
            .cloned(),
    );
    merged
}
578
579#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
581pub enum RemoteAuth {
582 Bearer { token: String },
583 ApiKey { key: String },
584}
585
586impl RemoteAuth {
587 pub fn to_workspace_auth(&self) -> steer_workspace::RemoteAuth {
589 match self {
590 RemoteAuth::Bearer { token } => steer_workspace::RemoteAuth::BearerToken(token.clone()),
591 RemoteAuth::ApiKey { key } => steer_workspace::RemoteAuth::ApiKey(key.clone()),
592 }
593 }
594}
595
596#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema)]
598#[serde(rename_all = "snake_case")]
599#[derive(Default)]
600pub enum ToolFilter {
601 #[default]
603 All,
604 Include(Vec<String>),
606 Exclude(Vec<String>),
608}
609
610#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
612#[serde(tag = "type", rename_all = "snake_case")]
613pub enum BackendConfig {
614 Mcp {
615 server_name: String,
616 transport: McpTransport,
617 tool_filter: ToolFilter,
618 },
619}
620
621#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
622pub struct SessionToolConfig {
623 pub backends: Vec<BackendConfig>,
624 pub visibility: ToolVisibility,
625 pub approval_policy: ToolApprovalPolicy,
626 pub metadata: HashMap<String, String>,
627}
628
629impl Default for SessionToolConfig {
630 fn default() -> Self {
631 Self {
632 backends: Vec::new(),
633 visibility: ToolVisibility::All,
634 approval_policy: ToolApprovalPolicy::default(),
635 metadata: HashMap::new(),
636 }
637 }
638}
639
640impl SessionToolConfig {
641 pub fn read_only() -> Self {
642 Self {
643 backends: Vec::new(),
644 visibility: ToolVisibility::ReadOnly,
645 approval_policy: ToolApprovalPolicy::default(),
646 metadata: HashMap::new(),
647 }
648 }
649}
650
651#[derive(Debug, Clone, Serialize, Deserialize, Default)]
653pub struct SessionState {
654 pub messages: Vec<Message>,
656
657 pub tool_calls: HashMap<String, ToolCallState>,
659
660 pub approved_tools: HashSet<String>,
662
663 #[serde(default)]
665 pub approved_bash_patterns: HashSet<String>,
666
667 pub last_event_sequence: u64,
669
670 pub metadata: HashMap<String, String>,
672
673 #[serde(default, skip_serializing_if = "Option::is_none")]
676 pub active_message_id: Option<String>,
677
678 #[serde(default, skip_serializing, skip_deserializing)]
681 pub mcp_servers: HashMap<String, McpServerInfo>,
682}
683
684impl SessionState {
685 pub fn add_message(&mut self, message: Message) {
687 self.messages.push(message);
688 }
689
690 pub fn message_count(&self) -> usize {
692 self.messages.len()
693 }
694
695 pub fn last_message(&self) -> Option<&Message> {
697 self.messages.last()
698 }
699
700 pub fn add_tool_call(&mut self, tool_call: ToolCall) {
702 let state = ToolCallState {
703 tool_call: tool_call.clone(),
704 status: ToolCallStatus::PendingApproval,
705 started_at: None,
706 completed_at: None,
707 result: None,
708 };
709 self.tool_calls.insert(tool_call.id, state);
710 }
711
712 pub fn update_tool_call_status(
714 &mut self,
715 tool_call_id: &str,
716 status: ToolCallStatus,
717 ) -> std::result::Result<(), String> {
718 let tool_call = self
719 .tool_calls
720 .get_mut(tool_call_id)
721 .ok_or_else(|| format!("Tool call not found: {tool_call_id}"))?;
722
723 match (&tool_call.status, &status) {
725 (_, ToolCallStatus::Executing) => {
726 tool_call.started_at = Some(Utc::now());
727 }
728 (_, ToolCallStatus::Completed | ToolCallStatus::Failed { .. }) => {
729 tool_call.completed_at = Some(Utc::now());
730 }
731 _ => {}
732 }
733
734 tool_call.status = status;
735 Ok(())
736 }
737
738 pub fn approve_tool(&mut self, tool_name: String) {
740 self.approved_tools.insert(tool_name);
741 }
742
743 pub fn is_tool_approved(&self, tool_name: &str) -> bool {
745 self.approved_tools.contains(tool_name)
746 }
747
748 pub fn validate(&self) -> std::result::Result<(), String> {
750 for message in &self.messages {
752 let tool_calls = Self::extract_tool_calls_from_message(message);
753 if !tool_calls.is_empty() {
754 for tool_call_id in tool_calls {
755 if !self.tool_calls.contains_key(&tool_call_id) {
756 return Err(format!(
757 "Message references unknown tool call: {tool_call_id}"
758 ));
759 }
760 }
761 }
762 }
763
764 Ok(())
765 }
766
767 fn extract_tool_calls_from_message(message: &Message) -> Vec<String> {
769 let mut tool_call_ids = Vec::new();
770
771 match &message.data {
772 MessageData::Assistant { content, .. } => {
773 for c in content {
774 if let crate::app::conversation::AssistantContent::ToolCall {
775 tool_call, ..
776 } = c
777 {
778 tool_call_ids.push(tool_call.id.clone());
779 }
780 }
781 }
782 MessageData::Tool { tool_use_id, .. } => {
783 tool_call_ids.push(tool_use_id.clone());
784 }
785 MessageData::User { .. } => {}
786 }
787
788 tool_call_ids
789 }
790}
791
792#[derive(Debug, Clone, Serialize, Deserialize)]
794pub struct ToolCallState {
795 pub tool_call: ToolCall,
796 pub status: ToolCallStatus,
797 pub started_at: Option<DateTime<Utc>>,
798 pub completed_at: Option<DateTime<Utc>>,
799 pub result: Option<ToolResult>,
800}
801
802impl ToolCallState {
803 pub fn is_pending(&self) -> bool {
804 matches!(self.status, ToolCallStatus::PendingApproval)
805 }
806
807 pub fn is_complete(&self) -> bool {
808 matches!(
809 self.status,
810 ToolCallStatus::Completed | ToolCallStatus::Failed { .. }
811 )
812 }
813
814 pub fn duration(&self) -> Option<chrono::Duration> {
815 match (self.started_at, self.completed_at) {
816 (Some(start), Some(end)) => Some(end - start),
817 _ => None,
818 }
819 }
820}
821
822#[derive(Debug, Clone, Serialize, Deserialize)]
824#[serde(tag = "status", rename_all = "snake_case")]
825pub enum ToolCallStatus {
826 PendingApproval,
827 Approved,
828 Denied,
829 Executing,
830 Completed,
831 Failed { error: String },
832}
833
834impl ToolCallStatus {
835 pub fn is_terminal(&self) -> bool {
836 matches!(
837 self,
838 ToolCallStatus::Completed | ToolCallStatus::Failed { .. } | ToolCallStatus::Denied
839 )
840 }
841}
842
843#[derive(Debug, Clone, Serialize, Deserialize)]
845pub struct ToolExecutionStats {
846 #[serde(skip_serializing_if = "Option::is_none")]
847 pub output: Option<String>, #[serde(skip_serializing_if = "Option::is_none")]
849 pub json_output: Option<serde_json::Value>, pub result_type: Option<String>, pub success: bool,
852 pub execution_time_ms: u64,
853 pub metadata: HashMap<String, String>,
854}
855
856impl ToolExecutionStats {
857 pub fn success(output: String, execution_time_ms: u64) -> Self {
858 Self {
859 output: Some(output),
860 json_output: None,
861 result_type: None,
862 success: true,
863 execution_time_ms,
864 metadata: HashMap::new(),
865 }
866 }
867
868 pub fn success_typed(
869 json_output: serde_json::Value,
870 result_type: String,
871 execution_time_ms: u64,
872 ) -> Self {
873 Self {
874 output: None,
875 json_output: Some(json_output),
876 result_type: Some(result_type),
877 success: true,
878 execution_time_ms,
879 metadata: HashMap::new(),
880 }
881 }
882
883 pub fn failure(error: String, execution_time_ms: u64) -> Self {
884 Self {
885 output: Some(error),
886 json_output: None,
887 result_type: None,
888 success: false,
889 execution_time_ms,
890 metadata: HashMap::new(),
891 }
892 }
893
894 pub fn with_metadata(mut self, key: String, value: String) -> Self {
895 self.metadata.insert(key, value);
896 self
897 }
898}
899
900#[derive(Debug, Clone, Serialize, Deserialize)]
902pub struct SessionInfo {
903 pub id: String,
904 pub created_at: DateTime<Utc>,
905 pub updated_at: DateTime<Utc>,
906 pub last_model: Option<ModelId>,
908 pub message_count: usize,
909 pub title: Option<String>,
910 pub metadata: HashMap<String, String>,
911}
912
913impl From<&Session> for SessionInfo {
914 fn from(session: &Session) -> Self {
915 Self {
916 id: session.id.clone(),
917 created_at: session.created_at,
918 updated_at: session.updated_at,
919 last_model: None, message_count: session.state.message_count(),
921 title: session.config.title.clone(),
922 metadata: session.config.metadata.clone(),
923 }
924 }
925}
926
927#[cfg(test)]
928mod tests {
929 use super::*;
930 use crate::app::conversation::{Message, MessageData, UserContent};
931 use crate::config::model::builtin::claude_sonnet_4_5 as test_model;
932 use crate::tools::DISPATCH_AGENT_TOOL_NAME;
933 use crate::tools::builtin_tools::READ_ONLY_TOOL_NAMES;
934 use steer_tools::tools::{BASH_TOOL_NAME, EDIT_TOOL_NAME};
935
936 #[test]
937 fn test_session_creation() {
938 let config = SessionConfig {
939 workspace: WorkspaceConfig::Local {
940 path: PathBuf::from("/test/path"),
941 },
942 workspace_ref: None,
943 workspace_id: None,
944 repo_ref: None,
945 parent_session_id: None,
946 workspace_name: None,
947 tool_config: SessionToolConfig::default(),
948 system_prompt: None,
949 primary_agent_id: None,
950 policy_overrides: SessionPolicyOverrides::empty(),
951 title: None,
952 metadata: HashMap::new(),
953 default_model: test_model(),
954 auto_compaction: AutoCompactionConfig::default(),
955 };
956 let session = Session::new("test-session".to_string(), config.clone());
957
958 assert_eq!(session.id, "test-session");
959 assert_eq!(
960 session
961 .config
962 .tool_config
963 .approval_policy
964 .tool_decision("any_tool"),
965 ToolDecision::Ask
966 );
967 assert_eq!(session.state.message_count(), 0);
968 }
969
970 #[test]
971 fn test_tool_approval_policy_prompt_unapproved() {
972 let policy = ToolApprovalPolicy {
973 default_behavior: UnapprovedBehavior::Prompt,
974 preapproved: ApprovalRules {
975 tools: ["read_file", "list_files"]
976 .iter()
977 .map(|s| (*s).to_string())
978 .collect(),
979 per_tool: HashMap::new(),
980 },
981 };
982
983 assert_eq!(policy.tool_decision("read_file"), ToolDecision::Allow);
984 assert_eq!(policy.tool_decision("write_file"), ToolDecision::Ask);
985 }
986
987 #[test]
988 fn test_tool_approval_policy_deny_unapproved() {
989 let policy = ToolApprovalPolicy {
990 default_behavior: UnapprovedBehavior::Deny,
991 preapproved: ApprovalRules {
992 tools: ["read_file", "list_files"]
993 .iter()
994 .map(|s| (*s).to_string())
995 .collect(),
996 per_tool: HashMap::new(),
997 },
998 };
999
1000 assert_eq!(policy.tool_decision("read_file"), ToolDecision::Allow);
1001 assert_eq!(policy.tool_decision("write_file"), ToolDecision::Deny);
1002 }
1003
1004 #[test]
1005 fn test_tool_approval_policy_default() {
1006 let policy = ToolApprovalPolicy::default();
1007
1008 assert_eq!(
1009 policy.tool_decision(READ_ONLY_TOOL_NAMES[0]),
1010 ToolDecision::Allow
1011 );
1012 assert_eq!(policy.tool_decision(BASH_TOOL_NAME), ToolDecision::Ask);
1013 }
1014
1015 #[test]
1016 fn test_tool_approval_policy_allow_unapproved() {
1017 let policy = ToolApprovalPolicy {
1018 default_behavior: UnapprovedBehavior::Allow,
1019 preapproved: ApprovalRules {
1020 tools: ["read_file", "list_files"]
1021 .iter()
1022 .map(|s| (*s).to_string())
1023 .collect(),
1024 per_tool: HashMap::new(),
1025 },
1026 };
1027
1028 assert_eq!(policy.tool_decision("read_file"), ToolDecision::Allow);
1029 assert_eq!(policy.tool_decision("write_file"), ToolDecision::Allow);
1030 }
1031
1032 #[test]
1033 fn test_tool_approval_policy_overrides_union_rules() {
1034 let base_policy = ToolApprovalPolicy {
1035 default_behavior: UnapprovedBehavior::Prompt,
1036 preapproved: ApprovalRules {
1037 tools: ["read_file"].iter().map(|s| (*s).to_string()).collect(),
1038 per_tool: [
1039 (
1040 BASH_TOOL_NAME.to_string(),
1041 ToolRule::Bash {
1042 patterns: vec!["git status".to_string()],
1043 },
1044 ),
1045 (
1046 DISPATCH_AGENT_TOOL_NAME.to_string(),
1047 ToolRule::DispatchAgent {
1048 agent_patterns: vec!["explore".to_string()],
1049 },
1050 ),
1051 ]
1052 .into_iter()
1053 .collect(),
1054 },
1055 };
1056
1057 let overrides = ToolApprovalPolicyOverrides {
1058 preapproved: ApprovalRulesOverrides {
1059 tools: ["write_file"].iter().map(|s| (*s).to_string()).collect(),
1060 per_tool: [
1061 (
1062 BASH_TOOL_NAME.to_string(),
1063 ToolRuleOverrides::Bash {
1064 patterns: vec!["git log".to_string()],
1065 },
1066 ),
1067 (
1068 DISPATCH_AGENT_TOOL_NAME.to_string(),
1069 ToolRuleOverrides::DispatchAgent {
1070 agent_patterns: vec!["review".to_string()],
1071 },
1072 ),
1073 ]
1074 .into_iter()
1075 .collect(),
1076 },
1077 };
1078
1079 let merged = overrides.apply_to(&base_policy);
1080
1081 assert_eq!(merged.default_behavior, UnapprovedBehavior::Prompt);
1082 assert!(merged.preapproved.tools.contains("read_file"));
1083 assert!(merged.preapproved.tools.contains("write_file"));
1084
1085 let bash_patterns = match merged
1086 .preapproved
1087 .per_tool
1088 .get(BASH_TOOL_NAME)
1089 .expect("bash rule")
1090 {
1091 ToolRule::Bash { patterns } => patterns,
1092 ToolRule::DispatchAgent { .. } => {
1093 panic!("Unexpected bash rule: dispatch agent")
1094 }
1095 };
1096 assert!(bash_patterns.contains(&"git status".to_string()));
1097 assert!(bash_patterns.contains(&"git log".to_string()));
1098 assert_eq!(bash_patterns.len(), 2);
1099
1100 let agent_patterns = match merged
1101 .preapproved
1102 .per_tool
1103 .get(DISPATCH_AGENT_TOOL_NAME)
1104 .expect("dispatch_agent rule")
1105 {
1106 ToolRule::DispatchAgent { agent_patterns } => agent_patterns,
1107 ToolRule::Bash { .. } => panic!("Unexpected dispatch_agent rule: bash"),
1108 };
1109 assert!(agent_patterns.contains(&"explore".to_string()));
1110 assert!(agent_patterns.contains(&"review".to_string()));
1111 assert_eq!(agent_patterns.len(), 2);
1112 }
1113
1114 #[test]
1115 fn test_bash_pattern_matching() {
1116 let policy = ToolApprovalPolicy {
1117 default_behavior: UnapprovedBehavior::Prompt,
1118 preapproved: ApprovalRules {
1119 tools: HashSet::new(),
1120 per_tool: [(
1121 "bash".to_string(),
1122 ToolRule::Bash {
1123 patterns: vec![
1124 "git status".to_string(),
1125 "git log*".to_string(),
1126 "git * --oneline".to_string(),
1127 "ls -?a*".to_string(),
1128 "cargo build*".to_string(),
1129 ],
1130 },
1131 )]
1132 .into_iter()
1133 .collect(),
1134 },
1135 };
1136
1137 assert!(policy.is_bash_pattern_preapproved("git status"));
1138 assert!(policy.is_bash_pattern_preapproved("git log --oneline"));
1139 assert!(policy.is_bash_pattern_preapproved("git show --oneline"));
1140 assert!(policy.is_bash_pattern_preapproved("ls -la"));
1141 assert!(policy.is_bash_pattern_preapproved("cargo build --release"));
1142 assert!(!policy.is_bash_pattern_preapproved("git commit"));
1143 assert!(!policy.is_bash_pattern_preapproved("ls -l"));
1144 assert!(!policy.is_bash_pattern_preapproved("rm -rf /"));
1145 }
1146
1147 #[test]
1148 fn test_dispatch_agent_pattern_matching() {
1149 let policy = ToolApprovalPolicy {
1150 default_behavior: UnapprovedBehavior::Prompt,
1151 preapproved: ApprovalRules {
1152 tools: HashSet::new(),
1153 per_tool: [(
1154 "dispatch_agent".to_string(),
1155 ToolRule::DispatchAgent {
1156 agent_patterns: vec!["explore".to_string(), "explore-*".to_string()],
1157 },
1158 )]
1159 .into_iter()
1160 .collect(),
1161 },
1162 };
1163
1164 assert!(policy.is_dispatch_agent_pattern_preapproved("explore"));
1165 assert!(policy.is_dispatch_agent_pattern_preapproved("explore-fast"));
1166 assert!(!policy.is_dispatch_agent_pattern_preapproved("build"));
1167 }
1168
1169 #[test]
1170 fn test_session_state_validation() {
1171 let mut state = SessionState::default();
1172
1173 assert!(state.validate().is_ok());
1175
1176 let message = Message {
1178 data: MessageData::User {
1179 content: vec![UserContent::Text {
1180 text: "Hello".to_string(),
1181 }],
1182 },
1183 timestamp: 123_456_789,
1184 id: "msg1".to_string(),
1185 parent_message_id: None,
1186 };
1187 state.add_message(message);
1188
1189 assert!(state.validate().is_ok());
1190 assert_eq!(state.message_count(), 1);
1191 }
1192
1193 #[test]
1194 fn test_tool_call_state_tracking() {
1195 let mut state = SessionState::default();
1196
1197 let tool_call = ToolCall {
1198 id: "tool1".to_string(),
1199 name: "read_file".to_string(),
1200 parameters: serde_json::json!({"path": "/test.txt"}),
1201 };
1202
1203 state.add_tool_call(tool_call.clone());
1204 assert!(state.tool_calls.get("tool1").unwrap().is_pending());
1205
1206 state
1207 .update_tool_call_status("tool1", ToolCallStatus::Executing)
1208 .unwrap();
1209 let tool_state = state.tool_calls.get("tool1").unwrap();
1210 assert!(tool_state.started_at.is_some());
1211 assert!(!tool_state.is_complete());
1212
1213 state
1214 .update_tool_call_status("tool1", ToolCallStatus::Completed)
1215 .unwrap();
1216 let tool_state = state.tool_calls.get("tool1").unwrap();
1217 assert!(tool_state.completed_at.is_some());
1218 assert!(tool_state.is_complete());
1219 }
1220
1221 #[test]
1222 fn test_session_tool_config_default() {
1223 let config = SessionToolConfig::default();
1224 assert!(config.backends.is_empty());
1225 }
1226
1227 #[test]
1228 fn test_tool_filter_exclude() {
1229 let excluded =
1230 ToolFilter::Exclude(vec![BASH_TOOL_NAME.to_string(), EDIT_TOOL_NAME.to_string()]);
1231
1232 if let ToolFilter::Exclude(tools) = &excluded {
1233 assert_eq!(tools.len(), 2);
1234 assert!(tools.contains(&BASH_TOOL_NAME.to_string()));
1235 assert!(tools.contains(&EDIT_TOOL_NAME.to_string()));
1236 } else {
1237 panic!("Expected ToolFilter::Exclude");
1238 }
1239 }
1240
1241 #[test]
1242 fn test_session_tool_config_read_only() {
1243 let config = SessionToolConfig::read_only();
1244 assert_eq!(config.backends.len(), 0);
1245 assert!(matches!(config.visibility, ToolVisibility::ReadOnly));
1246 assert_eq!(
1247 config.approval_policy.default_behavior,
1248 UnapprovedBehavior::Prompt
1249 );
1250 }
1251
1252 #[tokio::test]
1253 async fn test_session_config_build_registry_no_default_backends() {
1254 let config = SessionConfig {
1258 workspace: WorkspaceConfig::Local {
1259 path: PathBuf::from("/test/path"),
1260 },
1261 workspace_ref: None,
1262 workspace_id: None,
1263 repo_ref: None,
1264 parent_session_id: None,
1265 workspace_name: None,
1266 tool_config: SessionToolConfig::default(), system_prompt: None,
1268 primary_agent_id: None,
1269 policy_overrides: SessionPolicyOverrides::empty(),
1270 title: None,
1271 metadata: HashMap::new(),
1272 default_model: test_model(),
1273 auto_compaction: AutoCompactionConfig::default(),
1274 };
1275
1276 let (registry, _mcp_servers) = config.build_registry().await.unwrap();
1277 let schemas = registry.get_tool_schemas().await;
1278
1279 assert!(
1280 schemas.is_empty(),
1281 "BackendRegistry should be empty with default config; got: {:?}",
1282 schemas.iter().map(|s| &s.name).collect::<Vec<_>>()
1283 );
1284 }
1285
1286 #[test]
1293 fn test_mcp_status_tracking() {
1294 let mut session_state = SessionState::default();
1296
1297 let mcp_info = McpServerInfo {
1299 server_name: "test-server".to_string(),
1300 transport: crate::tools::McpTransport::Stdio {
1301 command: "python".to_string(),
1302 args: vec!["-m".to_string(), "test_server".to_string()],
1303 },
1304 state: McpConnectionState::Connected {
1305 tool_names: vec![
1306 "tool1".to_string(),
1307 "tool2".to_string(),
1308 "tool3".to_string(),
1309 "tool4".to_string(),
1310 "tool5".to_string(),
1311 ],
1312 },
1313 last_updated: Utc::now(),
1314 };
1315
1316 session_state
1317 .mcp_servers
1318 .insert("test-server".to_string(), mcp_info.clone());
1319
1320 assert_eq!(session_state.mcp_servers.len(), 1);
1322 let stored = session_state.mcp_servers.get("test-server").unwrap();
1323 assert_eq!(stored.server_name, "test-server");
1324 assert!(matches!(
1325 stored.state,
1326 McpConnectionState::Connected { ref tool_names } if tool_names.len() == 5
1327 ));
1328
1329 let failed_info = McpServerInfo {
1331 server_name: "failed-server".to_string(),
1332 transport: crate::tools::McpTransport::Tcp {
1333 host: "localhost".to_string(),
1334 port: 9999,
1335 },
1336 state: McpConnectionState::Failed {
1337 error: "Connection refused".to_string(),
1338 },
1339 last_updated: Utc::now(),
1340 };
1341
1342 session_state
1343 .mcp_servers
1344 .insert("failed-server".to_string(), failed_info);
1345 assert_eq!(session_state.mcp_servers.len(), 2);
1346 }
1347
1348 #[tokio::test]
1349 async fn test_mcp_server_tracking_in_build_registry() {
1350 let mut config = SessionConfig::read_only(test_model());
1352
1353 config.tool_config.backends.push(BackendConfig::Mcp {
1355 server_name: "bad-server".to_string(),
1356 transport: crate::tools::McpTransport::Tcp {
1357 host: "nonexistent.invalid".to_string(),
1358 port: 12345,
1359 },
1360 tool_filter: ToolFilter::All,
1361 });
1362
1363 config.tool_config.backends.push(BackendConfig::Mcp {
1365 server_name: "good-server".to_string(),
1366 transport: crate::tools::McpTransport::Stdio {
1367 command: "echo".to_string(),
1368 args: vec!["test".to_string()],
1369 },
1370 tool_filter: ToolFilter::All,
1371 });
1372
1373 let (_registry, mcp_servers) = config.build_registry().await.unwrap();
1374
1375 assert_eq!(mcp_servers.len(), 2);
1377
1378 let bad_server = mcp_servers.get("bad-server").unwrap();
1380 assert_eq!(bad_server.server_name, "bad-server");
1381 assert!(matches!(
1382 bad_server.state,
1383 McpConnectionState::Failed { .. }
1384 ));
1385
1386 let good_server = mcp_servers.get("good-server").unwrap();
1388 assert_eq!(good_server.server_name, "good-server");
1389 assert!(matches!(
1390 good_server.state,
1391 McpConnectionState::Failed { .. }
1392 ));
1393 }
1394
1395 #[test]
1396 fn test_backend_config_mcp_variant() {
1397 let mcp_config = BackendConfig::Mcp {
1398 server_name: "test-mcp".to_string(),
1399 transport: crate::tools::McpTransport::Stdio {
1400 command: "python".to_string(),
1401 args: vec!["-m".to_string(), "test_server".to_string()],
1402 },
1403 tool_filter: ToolFilter::All,
1404 };
1405
1406 let BackendConfig::Mcp {
1407 server_name,
1408 transport,
1409 ..
1410 } = mcp_config;
1411
1412 assert_eq!(server_name, "test-mcp");
1413 if let crate::tools::McpTransport::Stdio { command, args } = transport {
1414 assert_eq!(command, "python");
1415 assert_eq!(args.len(), 2);
1416 } else {
1417 panic!("Expected Stdio transport");
1418 }
1419 }
1420}