1use crate::config::model::ModelId;
2use crate::error::Result;
3use crate::tools::static_tools::READ_ONLY_TOOL_NAMES;
4use chrono::{DateTime, Utc};
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, HashSet};
8use std::path::PathBuf;
9use std::sync::Arc;
10
11use crate::app::{Message, MessageData};
12use crate::tools::{BackendRegistry, McpTransport, ToolBackend};
13use steer_tools::{ToolCall, result::ToolResult};
14
/// Connection status of one configured MCP server backend.
///
/// `SessionConfig::build_registry` starts each server at `Connecting` and then
/// moves it to `Connected` or `Failed` depending on whether the backend
/// initialized.
#[derive(Debug, Clone)]
pub enum McpConnectionState {
    /// A connection attempt has been started but has not resolved yet.
    Connecting,
    /// The backend initialized successfully.
    Connected {
        /// Tool names the backend reported as supported.
        tool_names: Vec<String>,
    },
    /// The server is not connected.
    // NOTE(review): not constructed in this file; presumably set elsewhere when
    // a live connection drops — confirm against callers.
    Disconnected {
        /// Optional explanation for the disconnect.
        reason: Option<String>,
    },
    /// Backend initialization failed.
    Failed {
        /// Display text of the initialization error.
        error: String,
    },
}
36
/// Status record for one MCP server, as stored in `SessionState::mcp_servers`.
#[derive(Debug, Clone)]
pub struct McpServerInfo {
    /// Server name from the backend configuration; also used as the map key.
    pub server_name: String,
    /// Transport the backend was configured with.
    pub transport: McpTransport,
    /// Current connection state.
    pub state: McpConnectionState,
    /// When `state` was last set.
    pub last_updated: DateTime<Utc>,
}
49
/// Where a session's workspace lives.
///
/// Serialized as an internally tagged object, e.g.
/// `{"type": "local", "path": "..."}` or
/// `{"type": "remote", "agent_address": "...", "auth": ...}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum WorkspaceConfig {
    /// A workspace on the local filesystem.
    Local {
        /// Root directory of the workspace.
        path: PathBuf,
    },
    /// A workspace served by a remote agent.
    Remote {
        /// Address of the remote agent.
        agent_address: String,
        /// Optional credentials for the remote agent.
        auth: Option<RemoteAuth>,
    },
}
62
63impl WorkspaceConfig {
64 pub fn get_path(&self) -> Option<String> {
65 match self {
66 WorkspaceConfig::Local { path } => Some(path.to_string_lossy().to_string()),
67 WorkspaceConfig::Remote { agent_address, .. } => Some(agent_address.clone()),
68 }
69 }
70
71 pub fn to_workspace_config(&self) -> steer_workspace::WorkspaceConfig {
73 match self {
74 WorkspaceConfig::Local { path } => {
75 steer_workspace::WorkspaceConfig::Local { path: path.clone() }
76 }
77 WorkspaceConfig::Remote {
78 agent_address,
79 auth,
80 } => steer_workspace::WorkspaceConfig::Remote {
81 address: agent_address.clone(),
82 auth: auth.as_ref().map(|a| a.to_workspace_auth()),
83 },
84 }
85 }
86}
87
88impl Default for WorkspaceConfig {
89 fn default() -> Self {
90 Self::Local {
91 path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
92 }
93 }
94}
95
/// A serializable session: identity, timestamps, configuration and state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Session identifier.
    pub id: String,
    /// Creation time.
    pub created_at: DateTime<Utc>,
    /// Last time the session was touched (see `update_timestamp`).
    pub updated_at: DateTime<Utc>,
    /// Configuration the session was created with.
    pub config: SessionConfig,
    /// Conversation and tool-call state.
    pub state: SessionState,
}
105
106impl Session {
107 pub fn new(id: String, config: SessionConfig) -> Self {
108 let now = Utc::now();
109 Self {
110 id,
111 created_at: now,
112 updated_at: now,
113 config,
114 state: SessionState::default(),
115 }
116 }
117
118 pub fn update_timestamp(&mut self) {
119 self.updated_at = Utc::now();
120 }
121
122 pub fn is_recently_active(&self, threshold: chrono::Duration) -> bool {
124 let cutoff = Utc::now() - threshold;
125 self.updated_at > cutoff
126 }
127
128 pub async fn build_workspace(&self) -> Result<Arc<dyn crate::workspace::Workspace>> {
130 crate::workspace::create_workspace(&self.config.workspace.to_workspace_config()).await
131 }
132}
133
/// Configuration for a session: workspace, tools, prompt, policies and model.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SessionConfig {
    /// Workspace the session operates on.
    pub workspace: WorkspaceConfig,
    /// Optional workspace reference; `None` when missing from serialized input.
    #[serde(default)]
    pub workspace_ref: Option<crate::workspace::WorkspaceRef>,
    /// Optional workspace identifier; `None` when missing from serialized input.
    #[serde(default)]
    pub workspace_id: Option<crate::workspace::WorkspaceId>,
    /// Optional repository reference; `None` when missing from serialized input.
    #[serde(default)]
    pub repo_ref: Option<crate::workspace::RepoRef>,
    /// Optional parent session id (presumably set for sessions spawned from
    /// another session — confirm with callers).
    #[serde(default)]
    pub parent_session_id: Option<crate::app::domain::types::SessionId>,
    /// Optional workspace display name.
    #[serde(default)]
    pub workspace_name: Option<String>,
    /// Tool backends, visibility and approval policy.
    pub tool_config: SessionToolConfig,
    /// Optional system prompt.
    pub system_prompt: Option<String>,
    /// Optional id of the session's primary agent.
    #[serde(default)]
    pub primary_agent_id: Option<String>,
    /// Session-level policy overrides; empty when missing from serialized input.
    #[serde(default = "SessionPolicyOverrides::empty")]
    pub policy_overrides: SessionPolicyOverrides,
    /// Free-form metadata, copied into `SessionInfo` summaries.
    pub metadata: HashMap<String, String>,
    /// Model used by default.
    pub default_model: ModelId,
    /// Auto-compaction settings; defaults when missing from serialized input.
    #[serde(default)]
    pub auto_compaction: AutoCompactionConfig,
}
163
164impl SessionConfig {
165 pub async fn build_registry(
168 &self,
169 ) -> Result<(BackendRegistry, HashMap<String, McpServerInfo>)> {
170 let mut registry = BackendRegistry::new();
171 let mut mcp_servers = HashMap::new();
172
173 for backend_config in &self.tool_config.backends {
174 let BackendConfig::Mcp {
175 server_name,
176 transport,
177 tool_filter,
178 } = backend_config;
179
180 tracing::info!(
181 "Attempting to initialize MCP backend '{}' with transport: {:?}",
182 server_name,
183 transport
184 );
185
186 let mut server_info = McpServerInfo {
187 server_name: server_name.clone(),
188 transport: transport.clone(),
189 state: McpConnectionState::Connecting,
190 last_updated: Utc::now(),
191 };
192
193 match crate::tools::McpBackend::new(
194 server_name.clone(),
195 transport.clone(),
196 tool_filter.clone(),
197 )
198 .await
199 {
200 Ok(mcp_backend) => {
201 let tool_names = mcp_backend.supported_tools().await;
202 let tool_count = tool_names.len();
203 tracing::info!(
204 "Successfully initialized MCP backend '{}' with {} tools",
205 server_name,
206 tool_count
207 );
208 server_info.state = McpConnectionState::Connected { tool_names };
209 server_info.last_updated = Utc::now();
210 registry
211 .register(format!("mcp_{server_name}"), Arc::new(mcp_backend))
212 .await;
213 }
214 Err(e) => {
215 tracing::error!("Failed to initialize MCP backend '{}': {}", server_name, e);
216 server_info.state = McpConnectionState::Failed {
217 error: e.to_string(),
218 };
219 server_info.last_updated = Utc::now();
220 }
221 }
222
223 mcp_servers.insert(server_name.clone(), server_info);
224 }
225
226 Ok((registry, mcp_servers))
227 }
228
229 pub fn filter_tools_by_visibility(
231 &self,
232 tools: Vec<steer_tools::ToolSchema>,
233 ) -> Vec<steer_tools::ToolSchema> {
234 match &self.tool_config.visibility {
235 ToolVisibility::All => tools,
236 ToolVisibility::ReadOnly => {
237 let read_only_names: HashSet<String> = READ_ONLY_TOOL_NAMES
238 .iter()
239 .map(|name| (*name).to_string())
240 .collect();
241
242 tools
243 .into_iter()
244 .filter(|schema| read_only_names.contains(&schema.name))
245 .collect()
246 }
247 ToolVisibility::Whitelist(allowed) => tools
248 .into_iter()
249 .filter(|schema| allowed.contains(&schema.name))
250 .collect(),
251 ToolVisibility::Blacklist(blocked) => tools
252 .into_iter()
253 .filter(|schema| !blocked.contains(&schema.name))
254 .collect(),
255 }
256 }
257
258 #[cfg(test)]
260 pub fn read_only(default_model: ModelId) -> Self {
261 Self {
262 workspace: WorkspaceConfig::Local {
263 path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
264 },
265 workspace_ref: None,
266 workspace_id: None,
267 repo_ref: None,
268 parent_session_id: None,
269 workspace_name: None,
270 tool_config: SessionToolConfig::read_only(),
271 system_prompt: None,
272 primary_agent_id: None,
273 policy_overrides: SessionPolicyOverrides::empty(),
274 metadata: HashMap::new(),
275 default_model,
276 auto_compaction: AutoCompactionConfig::default(),
277 }
278 }
279}
280
/// Settings for automatic conversation compaction.
// NOTE(review): there is no container-level `#[serde(default)]`, so when the
// object is present both fields are required.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub struct AutoCompactionConfig {
    /// Whether automatic compaction is enabled.
    pub enabled: bool,
    /// Trigger threshold in percent; `threshold_ratio` converts it to a fraction.
    // NOTE(review): what the percentage is measured against (presumably context
    // usage) is not visible here; values above 100 are not rejected.
    pub threshold_percent: u32,
}
287
288impl Default for AutoCompactionConfig {
289 fn default() -> Self {
290 Self {
291 enabled: true,
292 threshold_percent: 90,
293 }
294 }
295}
296
297impl AutoCompactionConfig {
298 pub fn threshold_ratio(&self) -> f64 {
299 f64::from(self.threshold_percent) / 100.0
300 }
301}
302
/// Optional per-session policy overrides; `SessionPolicyOverrides::empty()`
/// overrides nothing.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SessionPolicyOverrides {
    /// Model override; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub default_model: Option<ModelId>,
    /// Tool visibility override; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_visibility: Option<ToolVisibility>,
    /// Approval-policy overrides; empty when missing from serialized input.
    #[serde(default = "ToolApprovalPolicyOverrides::empty")]
    pub approval_policy: ToolApprovalPolicyOverrides,
}
313
314impl SessionPolicyOverrides {
315 pub fn empty() -> Self {
316 Self {
317 default_model: None,
318 tool_visibility: None,
319 approval_policy: ToolApprovalPolicyOverrides::empty(),
320 }
321 }
322
323 pub fn is_empty(&self) -> bool {
324 self.default_model.is_none()
325 && self.tool_visibility.is_none()
326 && self.approval_policy.is_empty()
327 }
328}
329
/// Overrides for a [`ToolApprovalPolicy`], merged into a base policy by `apply_to`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ToolApprovalPolicyOverrides {
    /// Replaces the base policy's `default_behavior` when set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub default_behavior: Option<UnapprovedBehavior>,
    /// Pre-approval rules merged into the base policy's rules; empty when
    /// missing from serialized input.
    #[serde(default = "ApprovalRulesOverrides::empty")]
    pub preapproved: ApprovalRulesOverrides,
}
337
338impl ToolApprovalPolicyOverrides {
339 pub fn empty() -> Self {
340 Self {
341 default_behavior: None,
342 preapproved: ApprovalRulesOverrides::empty(),
343 }
344 }
345
346 pub fn is_empty(&self) -> bool {
347 self.default_behavior.is_none() && self.preapproved.is_empty()
348 }
349}
350
/// Additional pre-approval rules carried by a [`ToolApprovalPolicyOverrides`].
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ApprovalRulesOverrides {
    /// Tool names to pre-approve outright; unioned with the base set.
    #[serde(default)]
    pub tools: HashSet<String>,
    /// Per-tool pattern rules keyed by tool name; merged with base rules.
    #[serde(default)]
    pub per_tool: HashMap<String, ToolRuleOverrides>,
}
358
359impl ApprovalRulesOverrides {
360 pub fn empty() -> Self {
361 Self {
362 tools: HashSet::new(),
363 per_tool: HashMap::new(),
364 }
365 }
366
367 pub fn is_empty(&self) -> bool {
368 self.tools.is_empty() && self.per_tool.is_empty()
369 }
370}
371
/// Override counterpart of [`ToolRule`], serialized with a `"type"` tag
/// (`"bash"` or `"dispatch_agent"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolRuleOverrides {
    /// Extra bash command patterns.
    Bash { patterns: Vec<String> },
    /// Extra dispatch-agent id patterns.
    DispatchAgent { agent_patterns: Vec<String> },
}
378
/// Which tool schemas a session exposes
/// (see `SessionConfig::filter_tools_by_visibility`).
///
/// Serialized adjacently tagged, e.g. `{"type": "all"}` or
/// `{"type": "whitelist", "tools": ["..."]}`.
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", content = "tools", rename_all = "snake_case")]
pub enum ToolVisibility {
    /// Every tool is visible (the default).
    #[default]
    All,

    /// Only tools listed in `READ_ONLY_TOOL_NAMES` are visible.
    ReadOnly,

    /// Only the named tools are visible.
    Whitelist(HashSet<String>),

    /// Every tool except the named ones is visible.
    Blacklist(HashSet<String>),
}
396
/// Outcome of `ToolApprovalPolicy::tool_decision` for a tool.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolDecision {
    /// The tool may run without approval.
    Allow,
    /// Approval must be requested before running.
    Ask,
    /// The tool must not run.
    Deny,
}
403
/// Handling for tools that are not pre-approved.
///
/// Serialized as `"prompt"`, `"deny"` or `"allow"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum UnapprovedBehavior {
    /// Maps to [`ToolDecision::Ask`] (the default).
    #[default]
    Prompt,
    /// Maps to [`ToolDecision::Deny`].
    Deny,
    /// Maps to [`ToolDecision::Allow`].
    Allow,
}
412
/// Pattern-based pre-approval rule for one tool.
///
/// Serialized with a `"type"` tag (`"bash"` or `"dispatch_agent"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolRule {
    /// Bash commands equal to, or glob-matching, any pattern are pre-approved.
    Bash { patterns: Vec<String> },
    /// Agent ids equal to, or glob-matching, any pattern are pre-approved.
    DispatchAgent { agent_patterns: Vec<String> },
}
419
/// Pre-approval rules of a [`ToolApprovalPolicy`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Default)]
pub struct ApprovalRules {
    /// Tool names that are always allowed.
    #[serde(default)]
    pub tools: HashSet<String>,
    /// Pattern rules keyed by tool name; this module looks up `"bash"` and
    /// `"dispatch_agent"`.
    #[serde(default)]
    pub per_tool: HashMap<String, ToolRule>,
}
427
428impl ApprovalRules {
429 pub fn is_empty(&self) -> bool {
430 self.tools.is_empty() && self.per_tool.is_empty()
431 }
432
433 pub fn bash_patterns(&self) -> Option<&[String]> {
434 self.per_tool.get("bash").and_then(|rule| match rule {
435 ToolRule::Bash { patterns } => Some(patterns.as_slice()),
436 ToolRule::DispatchAgent { .. } => None,
437 })
438 }
439
440 pub fn dispatch_agent_rule(&self) -> Option<&[String]> {
441 self.per_tool
442 .get("dispatch_agent")
443 .and_then(|rule| match rule {
444 ToolRule::DispatchAgent { agent_patterns } => Some(agent_patterns.as_slice()),
445 ToolRule::Bash { .. } => None,
446 })
447 }
448}
449
/// Policy deciding whether tool calls may run without asking.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub struct ToolApprovalPolicy {
    /// Decision for tools that are not pre-approved.
    pub default_behavior: UnapprovedBehavior,
    /// Pre-approved tools and pattern rules; empty when missing from
    /// serialized input.
    #[serde(default)]
    pub preapproved: ApprovalRules,
}
456
457impl Default for ToolApprovalPolicy {
458 fn default() -> Self {
459 Self {
460 default_behavior: UnapprovedBehavior::Prompt,
461 preapproved: ApprovalRules {
462 tools: READ_ONLY_TOOL_NAMES
463 .iter()
464 .map(|name| (*name).to_string())
465 .collect(),
466 per_tool: HashMap::new(),
467 },
468 }
469 }
470}
471
472impl ToolApprovalPolicy {
473 pub fn tool_decision(&self, tool_name: &str) -> ToolDecision {
474 if self.preapproved.tools.contains(tool_name) {
475 ToolDecision::Allow
476 } else {
477 match self.default_behavior {
478 UnapprovedBehavior::Prompt => ToolDecision::Ask,
479 UnapprovedBehavior::Deny => ToolDecision::Deny,
480 UnapprovedBehavior::Allow => ToolDecision::Allow,
481 }
482 }
483 }
484
485 pub fn is_bash_pattern_preapproved(&self, command: &str) -> bool {
486 let Some(patterns) = self.preapproved.bash_patterns() else {
487 return false;
488 };
489 patterns.iter().any(|pattern| {
490 if pattern == command {
491 return true;
492 }
493 glob::Pattern::new(pattern)
494 .map(|glob| glob.matches(command))
495 .unwrap_or(false)
496 })
497 }
498
499 pub fn is_dispatch_agent_pattern_preapproved(&self, agent_id: &str) -> bool {
500 let Some(patterns) = self.preapproved.dispatch_agent_rule() else {
501 return false;
502 };
503 patterns.iter().any(|pattern| {
504 if pattern == agent_id {
505 return true;
506 }
507 glob::Pattern::new(pattern)
508 .map(|glob| glob.matches(agent_id))
509 .unwrap_or(false)
510 })
511 }
512
513 pub fn pre_approved_tools(&self) -> &HashSet<String> {
514 &self.preapproved.tools
515 }
516}
517
518impl ToolApprovalPolicyOverrides {
519 pub fn apply_to(&self, base: &ToolApprovalPolicy) -> ToolApprovalPolicy {
520 let mut merged = base.clone();
521
522 if let Some(default_behavior) = self.default_behavior {
523 merged.default_behavior = default_behavior;
524 }
525
526 if !self.preapproved.tools.is_empty() {
527 merged
528 .preapproved
529 .tools
530 .extend(self.preapproved.tools.iter().cloned());
531 }
532
533 for (tool_name, override_rule) in &self.preapproved.per_tool {
534 let base_rule = merged.preapproved.per_tool.get(tool_name);
535 let merged_rule = merge_tool_rule_override(base_rule, override_rule);
536 merged
537 .preapproved
538 .per_tool
539 .insert(tool_name.clone(), merged_rule);
540 }
541
542 merged
543 }
544}
545
546fn merge_tool_rule_override(
547 base: Option<&ToolRule>,
548 override_rule: &ToolRuleOverrides,
549) -> ToolRule {
550 match (base, override_rule) {
551 (Some(ToolRule::Bash { patterns }), ToolRuleOverrides::Bash { patterns: extra }) => {
552 ToolRule::Bash {
553 patterns: merge_patterns(patterns, extra),
554 }
555 }
556 (
557 Some(ToolRule::DispatchAgent { agent_patterns }),
558 ToolRuleOverrides::DispatchAgent {
559 agent_patterns: extra,
560 },
561 ) => ToolRule::DispatchAgent {
562 agent_patterns: merge_patterns(agent_patterns, extra),
563 },
564 (_, ToolRuleOverrides::Bash { patterns }) => ToolRule::Bash {
565 patterns: patterns.clone(),
566 },
567 (_, ToolRuleOverrides::DispatchAgent { agent_patterns }) => ToolRule::DispatchAgent {
568 agent_patterns: agent_patterns.clone(),
569 },
570 }
571}
572
/// Appends each pattern from `extra` to a copy of `base`, unless an equal
/// pattern is already present.
///
/// Order is preserved. Duplicates that already exist inside `base` are left
/// untouched; duplicates within `extra` are added only once.
fn merge_patterns(base: &[String], extra: &[String]) -> Vec<String> {
    let mut combined = Vec::with_capacity(base.len() + extra.len());
    combined.extend_from_slice(base);
    for candidate in extra {
        if !combined.contains(candidate) {
            combined.push(candidate.clone());
        }
    }
    combined
}
582
/// Credentials for a remote workspace agent (serde's default external tagging).
// NOTE(review): `Debug` is derived, so `{:?}` formatting of this value (or of a
// containing `WorkspaceConfig`/`SessionConfig`) prints the secret.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub enum RemoteAuth {
    /// Bearer token credentials.
    Bearer { token: String },
    /// API key credentials.
    ApiKey { key: String },
}
589
590impl RemoteAuth {
591 pub fn to_workspace_auth(&self) -> steer_workspace::RemoteAuth {
593 match self {
594 RemoteAuth::Bearer { token } => steer_workspace::RemoteAuth::BearerToken(token.clone()),
595 RemoteAuth::ApiKey { key } => steer_workspace::RemoteAuth::ApiKey(key.clone()),
596 }
597 }
598}
599
/// Filter over an MCP backend's tools; passed to `McpBackend::new`.
///
/// Serialized as `"all"`, `{"include": [...]}` or `{"exclude": [...]}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(rename_all = "snake_case")]
#[derive(Default)]
pub enum ToolFilter {
    /// No filtering (the default).
    #[default]
    All,
    /// Presumably keeps only the listed tools — confirm in `McpBackend`.
    Include(Vec<String>),
    /// Presumably drops the listed tools — confirm in `McpBackend`.
    Exclude(Vec<String>),
}
613
/// A tool backend to attach to a session, serialized with a `"type"` tag.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum BackendConfig {
    /// An MCP server backend; `build_registry` registers it as `mcp_<server_name>`.
    Mcp {
        /// Name identifying the server.
        server_name: String,
        /// How to reach the server.
        transport: McpTransport,
        /// Which of the server's tools to use.
        tool_filter: ToolFilter,
    },
}
624
/// Tool-related settings of a session.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SessionToolConfig {
    /// Additional tool backends (MCP servers) to connect.
    pub backends: Vec<BackendConfig>,
    /// Which tool schemas are exposed.
    pub visibility: ToolVisibility,
    /// Approval policy for tool calls.
    pub approval_policy: ToolApprovalPolicy,
    /// Free-form metadata.
    pub metadata: HashMap<String, String>,
}
632
633impl Default for SessionToolConfig {
634 fn default() -> Self {
635 Self {
636 backends: Vec::new(),
637 visibility: ToolVisibility::All,
638 approval_policy: ToolApprovalPolicy::default(),
639 metadata: HashMap::new(),
640 }
641 }
642}
643
644impl SessionToolConfig {
645 pub fn read_only() -> Self {
646 Self {
647 backends: Vec::new(),
648 visibility: ToolVisibility::ReadOnly,
649 approval_policy: ToolApprovalPolicy::default(),
650 metadata: HashMap::new(),
651 }
652 }
653}
654
/// Mutable runtime state of a session.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SessionState {
    /// Conversation history, in insertion order.
    pub messages: Vec<Message>,

    /// Tracked tool calls keyed by tool call id.
    pub tool_calls: HashMap<String, ToolCallState>,

    /// Tool names approved via `approve_tool`.
    pub approved_tools: HashSet<String>,

    /// Approved bash patterns; empty when missing from serialized input.
    // NOTE(review): not read or written in this file — confirm usage elsewhere.
    #[serde(default)]
    pub approved_bash_patterns: HashSet<String>,

    /// Presumably the sequence number of the last applied event — confirm.
    pub last_event_sequence: u64,

    /// Free-form metadata.
    pub metadata: HashMap<String, String>,

    /// Optional id of the active message; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub active_message_id: Option<String>,

    /// Runtime-only MCP server status; never serialized, empty after deserialization.
    #[serde(default, skip_serializing, skip_deserializing)]
    pub mcp_servers: HashMap<String, McpServerInfo>,
}
687
688impl SessionState {
689 pub fn add_message(&mut self, message: Message) {
691 self.messages.push(message);
692 }
693
694 pub fn message_count(&self) -> usize {
696 self.messages.len()
697 }
698
699 pub fn last_message(&self) -> Option<&Message> {
701 self.messages.last()
702 }
703
704 pub fn add_tool_call(&mut self, tool_call: ToolCall) {
706 let state = ToolCallState {
707 tool_call: tool_call.clone(),
708 status: ToolCallStatus::PendingApproval,
709 started_at: None,
710 completed_at: None,
711 result: None,
712 };
713 self.tool_calls.insert(tool_call.id, state);
714 }
715
716 pub fn update_tool_call_status(
718 &mut self,
719 tool_call_id: &str,
720 status: ToolCallStatus,
721 ) -> std::result::Result<(), String> {
722 let tool_call = self
723 .tool_calls
724 .get_mut(tool_call_id)
725 .ok_or_else(|| format!("Tool call not found: {tool_call_id}"))?;
726
727 match (&tool_call.status, &status) {
729 (_, ToolCallStatus::Executing) => {
730 tool_call.started_at = Some(Utc::now());
731 }
732 (_, ToolCallStatus::Completed | ToolCallStatus::Failed { .. }) => {
733 tool_call.completed_at = Some(Utc::now());
734 }
735 _ => {}
736 }
737
738 tool_call.status = status;
739 Ok(())
740 }
741
742 pub fn approve_tool(&mut self, tool_name: String) {
744 self.approved_tools.insert(tool_name);
745 }
746
747 pub fn is_tool_approved(&self, tool_name: &str) -> bool {
749 self.approved_tools.contains(tool_name)
750 }
751
752 pub fn validate(&self) -> std::result::Result<(), String> {
754 for message in &self.messages {
756 let tool_calls = Self::extract_tool_calls_from_message(message);
757 if !tool_calls.is_empty() {
758 for tool_call_id in tool_calls {
759 if !self.tool_calls.contains_key(&tool_call_id) {
760 return Err(format!(
761 "Message references unknown tool call: {tool_call_id}"
762 ));
763 }
764 }
765 }
766 }
767
768 Ok(())
769 }
770
771 fn extract_tool_calls_from_message(message: &Message) -> Vec<String> {
773 let mut tool_call_ids = Vec::new();
774
775 match &message.data {
776 MessageData::Assistant { content, .. } => {
777 for c in content {
778 if let crate::app::conversation::AssistantContent::ToolCall {
779 tool_call, ..
780 } = c
781 {
782 tool_call_ids.push(tool_call.id.clone());
783 }
784 }
785 }
786 MessageData::Tool { tool_use_id, .. } => {
787 tool_call_ids.push(tool_use_id.clone());
788 }
789 MessageData::User { .. } => {}
790 }
791
792 tool_call_ids
793 }
794}
795
/// Tracking record for one tool call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallState {
    /// The call itself.
    pub tool_call: ToolCall,
    /// Current lifecycle status.
    pub status: ToolCallStatus,
    /// Set each time the status moves to `Executing`.
    pub started_at: Option<DateTime<Utc>>,
    /// Set each time the status moves to `Completed` or `Failed`.
    pub completed_at: Option<DateTime<Utc>>,
    /// Tool result; not assigned by the methods in this file.
    pub result: Option<ToolResult>,
}
805
806impl ToolCallState {
807 pub fn is_pending(&self) -> bool {
808 matches!(self.status, ToolCallStatus::PendingApproval)
809 }
810
811 pub fn is_complete(&self) -> bool {
812 matches!(
813 self.status,
814 ToolCallStatus::Completed | ToolCallStatus::Failed { .. }
815 )
816 }
817
818 pub fn duration(&self) -> Option<chrono::Duration> {
819 match (self.started_at, self.completed_at) {
820 (Some(start), Some(end)) => Some(end - start),
821 _ => None,
822 }
823 }
824}
825
/// Lifecycle status of a tool call.
///
/// Serialized internally tagged by `"status"`, e.g.
/// `{"status": "failed", "error": "..."}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum ToolCallStatus {
    /// Waiting for an approval decision (initial status from `add_tool_call`).
    PendingApproval,
    /// Approved but not yet running.
    Approved,
    /// Rejected; terminal.
    Denied,
    /// Currently running.
    Executing,
    /// Finished successfully; terminal.
    Completed,
    /// Finished with an error; terminal.
    Failed { error: String },
}
837
838impl ToolCallStatus {
839 pub fn is_terminal(&self) -> bool {
840 matches!(
841 self,
842 ToolCallStatus::Completed | ToolCallStatus::Failed { .. } | ToolCallStatus::Denied
843 )
844 }
845}
846
/// Outcome and timing of a single tool execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolExecutionStats {
    /// Plain-text output; on failure, holds the error text (see `failure`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output: Option<String>,
    /// Structured output, set by `success_typed`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json_output: Option<serde_json::Value>,
    /// Type name of the structured output, set by `success_typed`.
    // NOTE(review): unlike `output`/`json_output`, this is serialized even when `None`.
    pub result_type: Option<String>,
    /// Whether the execution succeeded.
    pub success: bool,
    /// Execution time; milliseconds per the field name.
    pub execution_time_ms: u64,
    /// Free-form annotations added via `with_metadata`.
    pub metadata: HashMap<String, String>,
}
859
860impl ToolExecutionStats {
861 pub fn success(output: String, execution_time_ms: u64) -> Self {
862 Self {
863 output: Some(output),
864 json_output: None,
865 result_type: None,
866 success: true,
867 execution_time_ms,
868 metadata: HashMap::new(),
869 }
870 }
871
872 pub fn success_typed(
873 json_output: serde_json::Value,
874 result_type: String,
875 execution_time_ms: u64,
876 ) -> Self {
877 Self {
878 output: None,
879 json_output: Some(json_output),
880 result_type: Some(result_type),
881 success: true,
882 execution_time_ms,
883 metadata: HashMap::new(),
884 }
885 }
886
887 pub fn failure(error: String, execution_time_ms: u64) -> Self {
888 Self {
889 output: Some(error),
890 json_output: None,
891 result_type: None,
892 success: false,
893 execution_time_ms,
894 metadata: HashMap::new(),
895 }
896 }
897
898 pub fn with_metadata(mut self, key: String, value: String) -> Self {
899 self.metadata.insert(key, value);
900 self
901 }
902}
903
/// Lightweight summary of a [`Session`], built by the `From<&Session>` conversion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionInfo {
    /// Session identifier.
    pub id: String,
    /// Creation time.
    pub created_at: DateTime<Utc>,
    /// Last update time.
    pub updated_at: DateTime<Utc>,
    /// Last model used; the `From<&Session>` conversion always leaves this `None`.
    pub last_model: Option<ModelId>,
    /// Number of messages in the session.
    pub message_count: usize,
    /// Copy of the session config's metadata.
    pub metadata: HashMap<String, String>,
}
915
916impl From<&Session> for SessionInfo {
917 fn from(session: &Session) -> Self {
918 Self {
919 id: session.id.clone(),
920 created_at: session.created_at,
921 updated_at: session.updated_at,
922 last_model: None, message_count: session.state.message_count(),
924 metadata: session.config.metadata.clone(),
925 }
926 }
927}
928
929#[cfg(test)]
930mod tests {
931 use super::*;
932 use crate::app::conversation::{Message, MessageData, UserContent};
933 use crate::config::model::builtin::claude_sonnet_4_5 as test_model;
934 use crate::tools::DISPATCH_AGENT_TOOL_NAME;
935 use crate::tools::static_tools::READ_ONLY_TOOL_NAMES;
936 use steer_tools::tools::{BASH_TOOL_NAME, EDIT_TOOL_NAME};
937
938 #[test]
939 fn test_session_creation() {
940 let config = SessionConfig {
941 workspace: WorkspaceConfig::Local {
942 path: PathBuf::from("/test/path"),
943 },
944 workspace_ref: None,
945 workspace_id: None,
946 repo_ref: None,
947 parent_session_id: None,
948 workspace_name: None,
949 tool_config: SessionToolConfig::default(),
950 system_prompt: None,
951 primary_agent_id: None,
952 policy_overrides: SessionPolicyOverrides::empty(),
953 metadata: HashMap::new(),
954 default_model: test_model(),
955 auto_compaction: AutoCompactionConfig::default(),
956 };
957 let session = Session::new("test-session".to_string(), config.clone());
958
959 assert_eq!(session.id, "test-session");
960 assert_eq!(
961 session
962 .config
963 .tool_config
964 .approval_policy
965 .tool_decision("any_tool"),
966 ToolDecision::Ask
967 );
968 assert_eq!(session.state.message_count(), 0);
969 }
970
971 #[test]
972 fn test_tool_approval_policy_prompt_unapproved() {
973 let policy = ToolApprovalPolicy {
974 default_behavior: UnapprovedBehavior::Prompt,
975 preapproved: ApprovalRules {
976 tools: ["read_file", "list_files"]
977 .iter()
978 .map(|s| (*s).to_string())
979 .collect(),
980 per_tool: HashMap::new(),
981 },
982 };
983
984 assert_eq!(policy.tool_decision("read_file"), ToolDecision::Allow);
985 assert_eq!(policy.tool_decision("write_file"), ToolDecision::Ask);
986 }
987
988 #[test]
989 fn test_tool_approval_policy_deny_unapproved() {
990 let policy = ToolApprovalPolicy {
991 default_behavior: UnapprovedBehavior::Deny,
992 preapproved: ApprovalRules {
993 tools: ["read_file", "list_files"]
994 .iter()
995 .map(|s| (*s).to_string())
996 .collect(),
997 per_tool: HashMap::new(),
998 },
999 };
1000
1001 assert_eq!(policy.tool_decision("read_file"), ToolDecision::Allow);
1002 assert_eq!(policy.tool_decision("write_file"), ToolDecision::Deny);
1003 }
1004
1005 #[test]
1006 fn test_tool_approval_policy_default() {
1007 let policy = ToolApprovalPolicy::default();
1008
1009 assert_eq!(
1010 policy.tool_decision(READ_ONLY_TOOL_NAMES[0]),
1011 ToolDecision::Allow
1012 );
1013 assert_eq!(policy.tool_decision(BASH_TOOL_NAME), ToolDecision::Ask);
1014 }
1015
1016 #[test]
1017 fn test_tool_approval_policy_allow_unapproved() {
1018 let policy = ToolApprovalPolicy {
1019 default_behavior: UnapprovedBehavior::Allow,
1020 preapproved: ApprovalRules {
1021 tools: ["read_file", "list_files"]
1022 .iter()
1023 .map(|s| (*s).to_string())
1024 .collect(),
1025 per_tool: HashMap::new(),
1026 },
1027 };
1028
1029 assert_eq!(policy.tool_decision("read_file"), ToolDecision::Allow);
1030 assert_eq!(policy.tool_decision("write_file"), ToolDecision::Allow);
1031 }
1032
1033 #[test]
1034 fn test_tool_approval_policy_overrides_union_rules() {
1035 let base_policy = ToolApprovalPolicy {
1036 default_behavior: UnapprovedBehavior::Prompt,
1037 preapproved: ApprovalRules {
1038 tools: ["read_file"].iter().map(|s| (*s).to_string()).collect(),
1039 per_tool: [
1040 (
1041 BASH_TOOL_NAME.to_string(),
1042 ToolRule::Bash {
1043 patterns: vec!["git status".to_string()],
1044 },
1045 ),
1046 (
1047 DISPATCH_AGENT_TOOL_NAME.to_string(),
1048 ToolRule::DispatchAgent {
1049 agent_patterns: vec!["explore".to_string()],
1050 },
1051 ),
1052 ]
1053 .into_iter()
1054 .collect(),
1055 },
1056 };
1057
1058 let overrides = ToolApprovalPolicyOverrides {
1059 default_behavior: Some(UnapprovedBehavior::Deny),
1060 preapproved: ApprovalRulesOverrides {
1061 tools: ["write_file"].iter().map(|s| (*s).to_string()).collect(),
1062 per_tool: [
1063 (
1064 BASH_TOOL_NAME.to_string(),
1065 ToolRuleOverrides::Bash {
1066 patterns: vec!["git log".to_string()],
1067 },
1068 ),
1069 (
1070 DISPATCH_AGENT_TOOL_NAME.to_string(),
1071 ToolRuleOverrides::DispatchAgent {
1072 agent_patterns: vec!["review".to_string()],
1073 },
1074 ),
1075 ]
1076 .into_iter()
1077 .collect(),
1078 },
1079 };
1080
1081 let merged = overrides.apply_to(&base_policy);
1082
1083 assert_eq!(merged.default_behavior, UnapprovedBehavior::Deny);
1084 assert!(merged.preapproved.tools.contains("read_file"));
1085 assert!(merged.preapproved.tools.contains("write_file"));
1086
1087 let bash_patterns = match merged
1088 .preapproved
1089 .per_tool
1090 .get(BASH_TOOL_NAME)
1091 .expect("bash rule")
1092 {
1093 ToolRule::Bash { patterns } => patterns,
1094 ToolRule::DispatchAgent { .. } => {
1095 panic!("Unexpected bash rule: dispatch agent")
1096 }
1097 };
1098 assert!(bash_patterns.contains(&"git status".to_string()));
1099 assert!(bash_patterns.contains(&"git log".to_string()));
1100 assert_eq!(bash_patterns.len(), 2);
1101
1102 let agent_patterns = match merged
1103 .preapproved
1104 .per_tool
1105 .get(DISPATCH_AGENT_TOOL_NAME)
1106 .expect("dispatch_agent rule")
1107 {
1108 ToolRule::DispatchAgent { agent_patterns } => agent_patterns,
1109 ToolRule::Bash { .. } => panic!("Unexpected dispatch_agent rule: bash"),
1110 };
1111 assert!(agent_patterns.contains(&"explore".to_string()));
1112 assert!(agent_patterns.contains(&"review".to_string()));
1113 assert_eq!(agent_patterns.len(), 2);
1114 }
1115
1116 #[test]
1117 fn test_bash_pattern_matching() {
1118 let policy = ToolApprovalPolicy {
1119 default_behavior: UnapprovedBehavior::Prompt,
1120 preapproved: ApprovalRules {
1121 tools: HashSet::new(),
1122 per_tool: [(
1123 "bash".to_string(),
1124 ToolRule::Bash {
1125 patterns: vec![
1126 "git status".to_string(),
1127 "git log*".to_string(),
1128 "git * --oneline".to_string(),
1129 "ls -?a*".to_string(),
1130 "cargo build*".to_string(),
1131 ],
1132 },
1133 )]
1134 .into_iter()
1135 .collect(),
1136 },
1137 };
1138
1139 assert!(policy.is_bash_pattern_preapproved("git status"));
1140 assert!(policy.is_bash_pattern_preapproved("git log --oneline"));
1141 assert!(policy.is_bash_pattern_preapproved("git show --oneline"));
1142 assert!(policy.is_bash_pattern_preapproved("ls -la"));
1143 assert!(policy.is_bash_pattern_preapproved("cargo build --release"));
1144 assert!(!policy.is_bash_pattern_preapproved("git commit"));
1145 assert!(!policy.is_bash_pattern_preapproved("ls -l"));
1146 assert!(!policy.is_bash_pattern_preapproved("rm -rf /"));
1147 }
1148
1149 #[test]
1150 fn test_dispatch_agent_pattern_matching() {
1151 let policy = ToolApprovalPolicy {
1152 default_behavior: UnapprovedBehavior::Prompt,
1153 preapproved: ApprovalRules {
1154 tools: HashSet::new(),
1155 per_tool: [(
1156 "dispatch_agent".to_string(),
1157 ToolRule::DispatchAgent {
1158 agent_patterns: vec!["explore".to_string(), "explore-*".to_string()],
1159 },
1160 )]
1161 .into_iter()
1162 .collect(),
1163 },
1164 };
1165
1166 assert!(policy.is_dispatch_agent_pattern_preapproved("explore"));
1167 assert!(policy.is_dispatch_agent_pattern_preapproved("explore-fast"));
1168 assert!(!policy.is_dispatch_agent_pattern_preapproved("build"));
1169 }
1170
1171 #[test]
1172 fn test_session_state_validation() {
1173 let mut state = SessionState::default();
1174
1175 assert!(state.validate().is_ok());
1177
1178 let message = Message {
1180 data: MessageData::User {
1181 content: vec![UserContent::Text {
1182 text: "Hello".to_string(),
1183 }],
1184 },
1185 timestamp: 123_456_789,
1186 id: "msg1".to_string(),
1187 parent_message_id: None,
1188 };
1189 state.add_message(message);
1190
1191 assert!(state.validate().is_ok());
1192 assert_eq!(state.message_count(), 1);
1193 }
1194
1195 #[test]
1196 fn test_tool_call_state_tracking() {
1197 let mut state = SessionState::default();
1198
1199 let tool_call = ToolCall {
1200 id: "tool1".to_string(),
1201 name: "read_file".to_string(),
1202 parameters: serde_json::json!({"path": "/test.txt"}),
1203 };
1204
1205 state.add_tool_call(tool_call.clone());
1206 assert!(state.tool_calls.get("tool1").unwrap().is_pending());
1207
1208 state
1209 .update_tool_call_status("tool1", ToolCallStatus::Executing)
1210 .unwrap();
1211 let tool_state = state.tool_calls.get("tool1").unwrap();
1212 assert!(tool_state.started_at.is_some());
1213 assert!(!tool_state.is_complete());
1214
1215 state
1216 .update_tool_call_status("tool1", ToolCallStatus::Completed)
1217 .unwrap();
1218 let tool_state = state.tool_calls.get("tool1").unwrap();
1219 assert!(tool_state.completed_at.is_some());
1220 assert!(tool_state.is_complete());
1221 }
1222
1223 #[test]
1224 fn test_session_tool_config_default() {
1225 let config = SessionToolConfig::default();
1226 assert!(config.backends.is_empty());
1227 }
1228
1229 #[test]
1230 fn test_tool_filter_exclude() {
1231 let excluded =
1232 ToolFilter::Exclude(vec![BASH_TOOL_NAME.to_string(), EDIT_TOOL_NAME.to_string()]);
1233
1234 if let ToolFilter::Exclude(tools) = &excluded {
1235 assert_eq!(tools.len(), 2);
1236 assert!(tools.contains(&BASH_TOOL_NAME.to_string()));
1237 assert!(tools.contains(&EDIT_TOOL_NAME.to_string()));
1238 } else {
1239 panic!("Expected ToolFilter::Exclude");
1240 }
1241 }
1242
1243 #[test]
1244 fn test_session_tool_config_read_only() {
1245 let config = SessionToolConfig::read_only();
1246 assert_eq!(config.backends.len(), 0);
1247 assert!(matches!(config.visibility, ToolVisibility::ReadOnly));
1248 assert_eq!(
1249 config.approval_policy.default_behavior,
1250 UnapprovedBehavior::Prompt
1251 );
1252 }
1253
1254 #[tokio::test]
1255 async fn test_session_config_build_registry_no_default_backends() {
1256 let config = SessionConfig {
1260 workspace: WorkspaceConfig::Local {
1261 path: PathBuf::from("/test/path"),
1262 },
1263 workspace_ref: None,
1264 workspace_id: None,
1265 repo_ref: None,
1266 parent_session_id: None,
1267 workspace_name: None,
1268 tool_config: SessionToolConfig::default(), system_prompt: None,
1270 primary_agent_id: None,
1271 policy_overrides: SessionPolicyOverrides::empty(),
1272 metadata: HashMap::new(),
1273 default_model: test_model(),
1274 auto_compaction: AutoCompactionConfig::default(),
1275 };
1276
1277 let (registry, _mcp_servers) = config.build_registry().await.unwrap();
1278 let schemas = registry.get_tool_schemas().await;
1279
1280 assert!(
1281 schemas.is_empty(),
1282 "BackendRegistry should be empty with default config; got: {:?}",
1283 schemas.iter().map(|s| &s.name).collect::<Vec<_>>()
1284 );
1285 }
1286
1287 #[test]
1294 fn test_mcp_status_tracking() {
1295 let mut session_state = SessionState::default();
1297
1298 let mcp_info = McpServerInfo {
1300 server_name: "test-server".to_string(),
1301 transport: crate::tools::McpTransport::Stdio {
1302 command: "python".to_string(),
1303 args: vec!["-m".to_string(), "test_server".to_string()],
1304 },
1305 state: McpConnectionState::Connected {
1306 tool_names: vec![
1307 "tool1".to_string(),
1308 "tool2".to_string(),
1309 "tool3".to_string(),
1310 "tool4".to_string(),
1311 "tool5".to_string(),
1312 ],
1313 },
1314 last_updated: Utc::now(),
1315 };
1316
1317 session_state
1318 .mcp_servers
1319 .insert("test-server".to_string(), mcp_info.clone());
1320
1321 assert_eq!(session_state.mcp_servers.len(), 1);
1323 let stored = session_state.mcp_servers.get("test-server").unwrap();
1324 assert_eq!(stored.server_name, "test-server");
1325 assert!(matches!(
1326 stored.state,
1327 McpConnectionState::Connected { ref tool_names } if tool_names.len() == 5
1328 ));
1329
1330 let failed_info = McpServerInfo {
1332 server_name: "failed-server".to_string(),
1333 transport: crate::tools::McpTransport::Tcp {
1334 host: "localhost".to_string(),
1335 port: 9999,
1336 },
1337 state: McpConnectionState::Failed {
1338 error: "Connection refused".to_string(),
1339 },
1340 last_updated: Utc::now(),
1341 };
1342
1343 session_state
1344 .mcp_servers
1345 .insert("failed-server".to_string(), failed_info);
1346 assert_eq!(session_state.mcp_servers.len(), 2);
1347 }
1348
1349 #[tokio::test]
1350 async fn test_mcp_server_tracking_in_build_registry() {
1351 let mut config = SessionConfig::read_only(test_model());
1353
1354 config.tool_config.backends.push(BackendConfig::Mcp {
1356 server_name: "bad-server".to_string(),
1357 transport: crate::tools::McpTransport::Tcp {
1358 host: "nonexistent.invalid".to_string(),
1359 port: 12345,
1360 },
1361 tool_filter: ToolFilter::All,
1362 });
1363
1364 config.tool_config.backends.push(BackendConfig::Mcp {
1366 server_name: "good-server".to_string(),
1367 transport: crate::tools::McpTransport::Stdio {
1368 command: "echo".to_string(),
1369 args: vec!["test".to_string()],
1370 },
1371 tool_filter: ToolFilter::All,
1372 });
1373
1374 let (_registry, mcp_servers) = config.build_registry().await.unwrap();
1375
1376 assert_eq!(mcp_servers.len(), 2);
1378
1379 let bad_server = mcp_servers.get("bad-server").unwrap();
1381 assert_eq!(bad_server.server_name, "bad-server");
1382 assert!(matches!(
1383 bad_server.state,
1384 McpConnectionState::Failed { .. }
1385 ));
1386
1387 let good_server = mcp_servers.get("good-server").unwrap();
1389 assert_eq!(good_server.server_name, "good-server");
1390 assert!(matches!(
1391 good_server.state,
1392 McpConnectionState::Failed { .. }
1393 ));
1394 }
1395
1396 #[test]
1397 fn test_backend_config_mcp_variant() {
1398 let mcp_config = BackendConfig::Mcp {
1399 server_name: "test-mcp".to_string(),
1400 transport: crate::tools::McpTransport::Stdio {
1401 command: "python".to_string(),
1402 args: vec!["-m".to_string(), "test_server".to_string()],
1403 },
1404 tool_filter: ToolFilter::All,
1405 };
1406
1407 let BackendConfig::Mcp {
1408 server_name,
1409 transport,
1410 ..
1411 } = mcp_config;
1412
1413 assert_eq!(server_name, "test-mcp");
1414 if let crate::tools::McpTransport::Stdio { command, args } = transport {
1415 assert_eq!(command, "python");
1416 assert_eq!(args.len(), 2);
1417 } else {
1418 panic!("Expected Stdio transport");
1419 }
1420 }
1421}