Skip to main content

steer_core/session/
state.rs

1use crate::config::model::ModelId;
2use crate::error::Result;
3use crate::tools::static_tools::READ_ONLY_TOOL_NAMES;
4use chrono::{DateTime, Utc};
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, HashSet};
8use std::path::PathBuf;
9use std::sync::Arc;
10
11use crate::app::{Message, MessageData};
12use crate::tools::{BackendRegistry, McpTransport, ToolBackend};
13use steer_tools::{ToolCall, result::ToolResult};
14
/// State of an MCP server connection.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare connection states directly;
/// every payload is plain `String` data, so equality is purely structural.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum McpConnectionState {
    /// Currently attempting to connect.
    Connecting,
    /// Successfully connected.
    Connected {
        /// Names of tools available from this server.
        tool_names: Vec<String>,
    },
    /// Gracefully disconnected.
    Disconnected {
        /// Optional reason for disconnection.
        reason: Option<String>,
    },
    /// Failed to connect.
    Failed {
        /// Error message describing the failure.
        error: String,
    },
}
36
/// Information about a single configured MCP server and its latest connection state.
///
/// `SessionConfig::build_registry` produces one of these per configured backend and
/// stores it keyed by `server_name`.
#[derive(Debug, Clone)]
pub struct McpServerInfo {
    /// The configured server name
    pub server_name: String,
    /// The transport configuration
    pub transport: McpTransport,
    /// Current connection state
    pub state: McpConnectionState,
    /// Timestamp (UTC) when `state` was last updated
    pub last_updated: DateTime<Utc>,
}
49
/// Defines the primary execution environment for a session's workspace.
///
/// Serialized as an internally tagged object whose `"type"` field is `"local"` or
/// `"remote"` (snake_case variant names).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum WorkspaceConfig {
    /// A workspace identified by a local filesystem path.
    Local {
        path: PathBuf,
    },
    /// A workspace reached through a remote agent address.
    Remote {
        /// Agent address; passed through as `address` by `to_workspace_config`.
        agent_address: String,
        /// Optional credentials for the remote agent.
        auth: Option<RemoteAuth>,
    },
}
62
63impl WorkspaceConfig {
64    pub fn get_path(&self) -> Option<String> {
65        match self {
66            WorkspaceConfig::Local { path } => Some(path.to_string_lossy().to_string()),
67            WorkspaceConfig::Remote { agent_address, .. } => Some(agent_address.clone()),
68        }
69    }
70
71    /// Convert to steer_workspace::WorkspaceConfig
72    pub fn to_workspace_config(&self) -> steer_workspace::WorkspaceConfig {
73        match self {
74            WorkspaceConfig::Local { path } => {
75                steer_workspace::WorkspaceConfig::Local { path: path.clone() }
76            }
77            WorkspaceConfig::Remote {
78                agent_address,
79                auth,
80            } => steer_workspace::WorkspaceConfig::Remote {
81                address: agent_address.clone(),
82                auth: auth.as_ref().map(|a| a.to_workspace_auth()),
83            },
84        }
85    }
86}
87
88impl Default for WorkspaceConfig {
89    fn default() -> Self {
90        Self::Local {
91            path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
92        }
93    }
94}
95
/// Complete session representation: identity, timestamps, configuration and
/// mutable runtime state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Session identifier.
    pub id: String,
    /// When the session was created (UTC).
    pub created_at: DateTime<Utc>,
    /// Last time the session was touched; refreshed by `update_timestamp`.
    pub updated_at: DateTime<Utc>,
    /// Configuration fixed when the session was created.
    pub config: SessionConfig,
    /// Mutable state that changes during execution.
    pub state: SessionState,
}
105
106impl Session {
107    pub fn new(id: String, config: SessionConfig) -> Self {
108        let now = Utc::now();
109        Self {
110            id,
111            created_at: now,
112            updated_at: now,
113            config,
114            state: SessionState::default(),
115        }
116    }
117
118    pub fn update_timestamp(&mut self) {
119        self.updated_at = Utc::now();
120    }
121
122    /// Check if session has any recent activity
123    pub fn is_recently_active(&self, threshold: chrono::Duration) -> bool {
124        let cutoff = Utc::now() - threshold;
125        self.updated_at > cutoff
126    }
127
128    /// Build a workspace from this session's configuration
129    pub async fn build_workspace(&self) -> Result<Arc<dyn crate::workspace::Workspace>> {
130        crate::workspace::create_workspace(&self.config.workspace.to_workspace_config()).await
131    }
132}
133
/// Session configuration - immutable once created
///
/// Fields marked `#[serde(default)]` may be omitted when deserializing; the others
/// (`workspace`, `tool_config`, `metadata`, `default_model`, and the `Option`-typed
/// `system_prompt`) are read from the serialized form.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SessionConfig {
    /// Primary execution environment for the session's workspace.
    pub workspace: WorkspaceConfig,
    /// Optional workspace reference; `None` when absent from serialized data.
    #[serde(default)]
    pub workspace_ref: Option<crate::workspace::WorkspaceRef>,
    /// Optional workspace id; `None` when absent from serialized data.
    #[serde(default)]
    pub workspace_id: Option<crate::workspace::WorkspaceId>,
    /// Optional repository reference; `None` when absent from serialized data.
    #[serde(default)]
    pub repo_ref: Option<crate::workspace::RepoRef>,
    /// Optional parent session id; `None` when absent from serialized data.
    #[serde(default)]
    pub parent_session_id: Option<crate::app::domain::types::SessionId>,
    /// Optional workspace name; `None` when absent from serialized data.
    #[serde(default)]
    pub workspace_name: Option<String>,
    /// Tool backends, visibility and approval policy for this session.
    pub tool_config: SessionToolConfig,
    /// Optional custom system prompt to use for the session. If `None`, Steer will
    /// fall back to its built-in default prompt.
    pub system_prompt: Option<String>,
    /// Primary agent mode for this session. Defaults to "normal" if unset.
    #[serde(default)]
    pub primary_agent_id: Option<String>,
    /// User-controlled policy overrides that apply on top of the primary agent base policy.
    #[serde(default = "SessionPolicyOverrides::empty")]
    pub policy_overrides: SessionPolicyOverrides,
    /// Free-form string metadata (copied into `SessionInfo::metadata`).
    pub metadata: HashMap<String, String>,
    /// Default model identifier for the session.
    pub default_model: ModelId,
    /// Automatic context compaction settings; `AutoCompactionConfig::default()` when absent.
    #[serde(default)]
    pub auto_compaction: AutoCompactionConfig,
}
163
164impl SessionConfig {
165    /// Build a BackendRegistry from MCP server configurations.
166    /// Returns the registry and a map of MCP server connection states.
167    pub async fn build_registry(
168        &self,
169    ) -> Result<(BackendRegistry, HashMap<String, McpServerInfo>)> {
170        let mut registry = BackendRegistry::new();
171        let mut mcp_servers = HashMap::new();
172
173        for backend_config in &self.tool_config.backends {
174            let BackendConfig::Mcp {
175                server_name,
176                transport,
177                tool_filter,
178            } = backend_config;
179
180            tracing::info!(
181                "Attempting to initialize MCP backend '{}' with transport: {:?}",
182                server_name,
183                transport
184            );
185
186            let mut server_info = McpServerInfo {
187                server_name: server_name.clone(),
188                transport: transport.clone(),
189                state: McpConnectionState::Connecting,
190                last_updated: Utc::now(),
191            };
192
193            match crate::tools::McpBackend::new(
194                server_name.clone(),
195                transport.clone(),
196                tool_filter.clone(),
197            )
198            .await
199            {
200                Ok(mcp_backend) => {
201                    let tool_names = mcp_backend.supported_tools().await;
202                    let tool_count = tool_names.len();
203                    tracing::info!(
204                        "Successfully initialized MCP backend '{}' with {} tools",
205                        server_name,
206                        tool_count
207                    );
208                    server_info.state = McpConnectionState::Connected { tool_names };
209                    server_info.last_updated = Utc::now();
210                    registry
211                        .register(format!("mcp_{server_name}"), Arc::new(mcp_backend))
212                        .await;
213                }
214                Err(e) => {
215                    tracing::error!("Failed to initialize MCP backend '{}': {}", server_name, e);
216                    server_info.state = McpConnectionState::Failed {
217                        error: e.to_string(),
218                    };
219                    server_info.last_updated = Utc::now();
220                }
221            }
222
223            mcp_servers.insert(server_name.clone(), server_info);
224        }
225
226        Ok((registry, mcp_servers))
227    }
228
229    /// Filter tools based on visibility settings
230    pub fn filter_tools_by_visibility(
231        &self,
232        tools: Vec<steer_tools::ToolSchema>,
233    ) -> Vec<steer_tools::ToolSchema> {
234        match &self.tool_config.visibility {
235            ToolVisibility::All => tools,
236            ToolVisibility::ReadOnly => {
237                let read_only_names: HashSet<String> = READ_ONLY_TOOL_NAMES
238                    .iter()
239                    .map(|name| (*name).to_string())
240                    .collect();
241
242                tools
243                    .into_iter()
244                    .filter(|schema| read_only_names.contains(&schema.name))
245                    .collect()
246            }
247            ToolVisibility::Whitelist(allowed) => tools
248                .into_iter()
249                .filter(|schema| allowed.contains(&schema.name))
250                .collect(),
251            ToolVisibility::Blacklist(blocked) => tools
252                .into_iter()
253                .filter(|schema| !blocked.contains(&schema.name))
254                .collect(),
255        }
256    }
257
258    /// Minimal read-only configuration
259    #[cfg(test)]
260    pub fn read_only(default_model: ModelId) -> Self {
261        Self {
262            workspace: WorkspaceConfig::Local {
263                path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
264            },
265            workspace_ref: None,
266            workspace_id: None,
267            repo_ref: None,
268            parent_session_id: None,
269            workspace_name: None,
270            tool_config: SessionToolConfig::read_only(),
271            system_prompt: None,
272            primary_agent_id: None,
273            policy_overrides: SessionPolicyOverrides::empty(),
274            metadata: HashMap::new(),
275            default_model,
276            auto_compaction: AutoCompactionConfig::default(),
277        }
278    }
279}
280
/// Configuration for automatic context-window compaction.
///
/// NOTE(review): the struct has no container-level `#[serde(default)]`, so a serialized
/// object that is present but omits one of these fields will fail to deserialize.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub struct AutoCompactionConfig {
    /// Whether automatic compaction is on (`true` by default).
    pub enabled: bool,
    /// Threshold as a whole-number percentage (`90` by default); `threshold_ratio`
    /// converts it to a fraction. No range check is applied here.
    pub threshold_percent: u32,
}
287
288impl Default for AutoCompactionConfig {
289    fn default() -> Self {
290        Self {
291            enabled: true,
292            threshold_percent: 90,
293        }
294    }
295}
296
297impl AutoCompactionConfig {
298    pub fn threshold_ratio(&self) -> f64 {
299        f64::from(self.threshold_percent) / 100.0
300    }
301}
302
/// User-controlled policy overrides applied on top of a primary agent base policy.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SessionPolicyOverrides {
    /// Model override; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub default_model: Option<ModelId>,
    /// Tool visibility override; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_visibility: Option<ToolVisibility>,
    /// Approval policy overrides; `ToolApprovalPolicyOverrides::empty()` when absent.
    #[serde(default = "ToolApprovalPolicyOverrides::empty")]
    pub approval_policy: ToolApprovalPolicyOverrides,
}
313
314impl SessionPolicyOverrides {
315    pub fn empty() -> Self {
316        Self {
317            default_model: None,
318            tool_visibility: None,
319            approval_policy: ToolApprovalPolicyOverrides::empty(),
320        }
321    }
322
323    pub fn is_empty(&self) -> bool {
324        self.default_model.is_none()
325            && self.tool_visibility.is_none()
326            && self.approval_policy.is_empty()
327    }
328}
329
/// Overrides layered onto a `ToolApprovalPolicy` by `ToolApprovalPolicyOverrides::apply_to`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ToolApprovalPolicyOverrides {
    /// Replacement for the base `default_behavior`, if set; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub default_behavior: Option<UnapprovedBehavior>,
    /// Extra preapprovals added to the base rules; empty when absent.
    #[serde(default = "ApprovalRulesOverrides::empty")]
    pub preapproved: ApprovalRulesOverrides,
}
337
338impl ToolApprovalPolicyOverrides {
339    pub fn empty() -> Self {
340        Self {
341            default_behavior: None,
342            preapproved: ApprovalRulesOverrides::empty(),
343        }
344    }
345
346    pub fn is_empty(&self) -> bool {
347        self.default_behavior.is_none() && self.preapproved.is_empty()
348    }
349}
350
/// Additional preapprovals contributed by session overrides.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ApprovalRulesOverrides {
    /// Tool names to add to the base `ApprovalRules::tools`; empty when absent.
    #[serde(default)]
    pub tools: HashSet<String>,
    /// Per-tool pattern overrides keyed by tool name; merged with the base rule of the
    /// same name by `merge_tool_rule_override`.
    #[serde(default)]
    pub per_tool: HashMap<String, ToolRuleOverrides>,
}
358
359impl ApprovalRulesOverrides {
360    pub fn empty() -> Self {
361        Self {
362            tools: HashSet::new(),
363            per_tool: HashMap::new(),
364        }
365    }
366
367    pub fn is_empty(&self) -> bool {
368        self.tools.is_empty() && self.per_tool.is_empty()
369    }
370}
371
/// Extra per-tool patterns supplied by session overrides.
///
/// Internally tagged by `"type"` (`"bash"` / `"dispatch_agent"`). When the base rule has
/// the same variant, `merge_tool_rule_override` appends these patterns (skipping
/// duplicates); otherwise they replace the base rule.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolRuleOverrides {
    Bash { patterns: Vec<String> },
    DispatchAgent { agent_patterns: Vec<String> },
}
378
/// Tool visibility configuration - controls which tools are shown to the AI agent
///
/// Serialized adjacently tagged as `{"type": "<variant>", "tools": [...]}` with snake_case
/// variant names; `All` and `ReadOnly` carry no `tools` content. Applied by
/// `SessionConfig::filter_tools_by_visibility`.
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", content = "tools", rename_all = "snake_case")]
pub enum ToolVisibility {
    /// Show all registered tools to the AI
    #[default]
    All,

    /// Only show read-only tools to the AI (names listed in `READ_ONLY_TOOL_NAMES`)
    ReadOnly,

    /// Show only specific tools to the AI (whitelist)
    Whitelist(HashSet<String>),

    /// Hide specific tools from the AI (blacklist)
    Blacklist(HashSet<String>),
}
396
/// Outcome of an approval check, as returned by `ToolApprovalPolicy::tool_decision`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolDecision {
    /// The tool may run (preapproved, or `UnapprovedBehavior::Allow`).
    Allow,
    /// Approval should be requested (produced for `UnapprovedBehavior::Prompt`).
    Ask,
    /// The tool must not run (produced for `UnapprovedBehavior::Deny`).
    Deny,
}
403
/// What to do with a tool that is not preapproved; defaults to `Prompt`.
///
/// Serialized as a snake_case string (`"prompt"`, `"deny"`, `"allow"`).
/// `ToolApprovalPolicy::tool_decision` maps these to `ToolDecision::{Ask, Deny, Allow}`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum UnapprovedBehavior {
    #[default]
    Prompt,
    Deny,
    Allow,
}
412
/// A per-tool approval rule, stored in `ApprovalRules::per_tool`.
///
/// Internally tagged by `"type"` (`"bash"` / `"dispatch_agent"`). Each pattern matches
/// either by exact string equality or as a glob (see `ToolApprovalPolicy`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolRule {
    /// Preapproved bash command patterns.
    Bash { patterns: Vec<String> },
    /// Preapproved agent-id patterns for agent dispatch.
    DispatchAgent { agent_patterns: Vec<String> },
}
419
/// Preapproval rules: whole tools plus optional per-tool pattern rules.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Default)]
pub struct ApprovalRules {
    /// Tool names that `ToolApprovalPolicy::tool_decision` always allows.
    #[serde(default)]
    pub tools: HashSet<String>,
    /// Pattern rules keyed by tool name (looked up as `"bash"` and `"dispatch_agent"`).
    #[serde(default)]
    pub per_tool: HashMap<String, ToolRule>,
}
427
428impl ApprovalRules {
429    pub fn is_empty(&self) -> bool {
430        self.tools.is_empty() && self.per_tool.is_empty()
431    }
432
433    pub fn bash_patterns(&self) -> Option<&[String]> {
434        self.per_tool.get("bash").and_then(|rule| match rule {
435            ToolRule::Bash { patterns } => Some(patterns.as_slice()),
436            ToolRule::DispatchAgent { .. } => None,
437        })
438    }
439
440    pub fn dispatch_agent_rule(&self) -> Option<&[String]> {
441        self.per_tool
442            .get("dispatch_agent")
443            .and_then(|rule| match rule {
444                ToolRule::DispatchAgent { agent_patterns } => Some(agent_patterns.as_slice()),
445                ToolRule::Bash { .. } => None,
446            })
447    }
448}
449
/// Approval policy for tool calls: a fallback behavior plus preapproval rules.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub struct ToolApprovalPolicy {
    /// Behavior for tools not listed in `preapproved.tools`; required when deserializing.
    pub default_behavior: UnapprovedBehavior,
    /// Preapproved tools and per-tool patterns; empty when absent from serialized data.
    #[serde(default)]
    pub preapproved: ApprovalRules,
}
456
457impl Default for ToolApprovalPolicy {
458    fn default() -> Self {
459        Self {
460            default_behavior: UnapprovedBehavior::Prompt,
461            preapproved: ApprovalRules {
462                tools: READ_ONLY_TOOL_NAMES
463                    .iter()
464                    .map(|name| (*name).to_string())
465                    .collect(),
466                per_tool: HashMap::new(),
467            },
468        }
469    }
470}
471
472impl ToolApprovalPolicy {
473    pub fn tool_decision(&self, tool_name: &str) -> ToolDecision {
474        if self.preapproved.tools.contains(tool_name) {
475            ToolDecision::Allow
476        } else {
477            match self.default_behavior {
478                UnapprovedBehavior::Prompt => ToolDecision::Ask,
479                UnapprovedBehavior::Deny => ToolDecision::Deny,
480                UnapprovedBehavior::Allow => ToolDecision::Allow,
481            }
482        }
483    }
484
485    pub fn is_bash_pattern_preapproved(&self, command: &str) -> bool {
486        let Some(patterns) = self.preapproved.bash_patterns() else {
487            return false;
488        };
489        patterns.iter().any(|pattern| {
490            if pattern == command {
491                return true;
492            }
493            glob::Pattern::new(pattern)
494                .map(|glob| glob.matches(command))
495                .unwrap_or(false)
496        })
497    }
498
499    pub fn is_dispatch_agent_pattern_preapproved(&self, agent_id: &str) -> bool {
500        let Some(patterns) = self.preapproved.dispatch_agent_rule() else {
501            return false;
502        };
503        patterns.iter().any(|pattern| {
504            if pattern == agent_id {
505                return true;
506            }
507            glob::Pattern::new(pattern)
508                .map(|glob| glob.matches(agent_id))
509                .unwrap_or(false)
510        })
511    }
512
513    pub fn pre_approved_tools(&self) -> &HashSet<String> {
514        &self.preapproved.tools
515    }
516}
517
518impl ToolApprovalPolicyOverrides {
519    pub fn apply_to(&self, base: &ToolApprovalPolicy) -> ToolApprovalPolicy {
520        let mut merged = base.clone();
521
522        if let Some(default_behavior) = self.default_behavior {
523            merged.default_behavior = default_behavior;
524        }
525
526        if !self.preapproved.tools.is_empty() {
527            merged
528                .preapproved
529                .tools
530                .extend(self.preapproved.tools.iter().cloned());
531        }
532
533        for (tool_name, override_rule) in &self.preapproved.per_tool {
534            let base_rule = merged.preapproved.per_tool.get(tool_name);
535            let merged_rule = merge_tool_rule_override(base_rule, override_rule);
536            merged
537                .preapproved
538                .per_tool
539                .insert(tool_name.clone(), merged_rule);
540        }
541
542        merged
543    }
544}
545
546fn merge_tool_rule_override(
547    base: Option<&ToolRule>,
548    override_rule: &ToolRuleOverrides,
549) -> ToolRule {
550    match (base, override_rule) {
551        (Some(ToolRule::Bash { patterns }), ToolRuleOverrides::Bash { patterns: extra }) => {
552            ToolRule::Bash {
553                patterns: merge_patterns(patterns, extra),
554            }
555        }
556        (
557            Some(ToolRule::DispatchAgent { agent_patterns }),
558            ToolRuleOverrides::DispatchAgent {
559                agent_patterns: extra,
560            },
561        ) => ToolRule::DispatchAgent {
562            agent_patterns: merge_patterns(agent_patterns, extra),
563        },
564        (_, ToolRuleOverrides::Bash { patterns }) => ToolRule::Bash {
565            patterns: patterns.clone(),
566        },
567        (_, ToolRuleOverrides::DispatchAgent { agent_patterns }) => ToolRule::DispatchAgent {
568            agent_patterns: agent_patterns.clone(),
569        },
570    }
571}
572
/// Returns `base` followed by each pattern from `extra` that has not been seen yet.
///
/// Entries of `base` are kept verbatim, including any duplicates already in it. A pattern
/// from `extra` is appended only the first time it appears, so repeats within `extra` and
/// patterns already present in `base` are skipped. Order is preserved throughout.
fn merge_patterns(base: &[String], extra: &[String]) -> Vec<String> {
    let mut seen: HashSet<&str> = base.iter().map(String::as_str).collect();
    let mut merged = Vec::with_capacity(base.len() + extra.len());
    merged.extend_from_slice(base);
    for candidate in extra {
        if seen.insert(candidate.as_str()) {
            merged.push(candidate.clone());
        }
    }
    merged
}
582
/// Authentication configuration for remote backends
///
/// Converted to `steer_workspace::RemoteAuth` by `to_workspace_auth`. With no serde
/// rename attributes, the variants serialize externally tagged as `"Bearer"` / `"ApiKey"`.
/// NOTE(review): the derived `Debug` prints the token/key verbatim, so avoid logging it.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub enum RemoteAuth {
    /// Bearer token credential.
    Bearer { token: String },
    /// API key credential.
    ApiKey { key: String },
}
589
590impl RemoteAuth {
591    /// Convert to steer_workspace RemoteAuth type
592    pub fn to_workspace_auth(&self) -> steer_workspace::RemoteAuth {
593        match self {
594            RemoteAuth::Bearer { token } => steer_workspace::RemoteAuth::BearerToken(token.clone()),
595            RemoteAuth::ApiKey { key } => steer_workspace::RemoteAuth::ApiKey(key.clone()),
596        }
597    }
598}
599
600/// Tool filtering configuration for backends
601#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema)]
602#[serde(rename_all = "snake_case")]
603#[derive(Default)]
604pub enum ToolFilter {
605    /// Include all available tools
606    #[default]
607    All,
608    /// Include only the specified tools
609    Include(Vec<String>),
610    /// Include all tools except the specified ones
611    Exclude(Vec<String>),
612}
613
/// Configuration for MCP server backends
///
/// Internally tagged by `"type"`; the only variant serializes as `"mcp"`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum BackendConfig {
    /// An MCP server. On successful connection `SessionConfig::build_registry` registers
    /// it as `mcp_<server_name>`.
    Mcp {
        /// Name used in logs, as the status-map key, and in the registry key.
        server_name: String,
        /// Transport used to reach the server.
        transport: McpTransport,
        /// Which of the server's tools to expose.
        tool_filter: ToolFilter,
    },
}
624
/// Tool setup for a session: backends to connect, visibility, and approval rules.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SessionToolConfig {
    /// MCP backends initialized by `SessionConfig::build_registry`.
    pub backends: Vec<BackendConfig>,
    /// Which tools are exposed; applied by `SessionConfig::filter_tools_by_visibility`.
    pub visibility: ToolVisibility,
    /// Approval policy for tool calls.
    pub approval_policy: ToolApprovalPolicy,
    /// Free-form string metadata.
    pub metadata: HashMap<String, String>,
}
632
633impl Default for SessionToolConfig {
634    fn default() -> Self {
635        Self {
636            backends: Vec::new(),
637            visibility: ToolVisibility::All,
638            approval_policy: ToolApprovalPolicy::default(),
639            metadata: HashMap::new(),
640        }
641    }
642}
643
644impl SessionToolConfig {
645    pub fn read_only() -> Self {
646        Self {
647            backends: Vec::new(),
648            visibility: ToolVisibility::ReadOnly,
649            approval_policy: ToolApprovalPolicy::default(),
650            metadata: HashMap::new(),
651        }
652    }
653}
654
/// Mutable session state that changes during execution
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SessionState {
    /// Conversation messages, in the order they were added
    pub messages: Vec<Message>,

    /// Tool call tracking, keyed by tool call id
    pub tool_calls: HashMap<String, ToolCallState>,

    /// Tools that have been approved for this session
    pub approved_tools: HashSet<String>,

    /// Bash commands that have been approved for this session (dynamically added)
    #[serde(default)]
    pub approved_bash_patterns: HashSet<String>,

    /// Last processed event sequence number for replay
    pub last_event_sequence: u64,

    /// Additional runtime metadata
    pub metadata: HashMap<String, String>,

    /// The ID of the currently active message (head of selected branch)
    /// None means use last message semantics for backward compatibility
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub active_message_id: Option<String>,

    /// Status of MCP server connections
    /// This is a transient field that is rebuilt on session activation; serde skips it
    /// in both directions, so it is always empty right after deserialization.
    #[serde(default, skip_serializing, skip_deserializing)]
    pub mcp_servers: HashMap<String, McpServerInfo>,
}
687
688impl SessionState {
689    /// Add a message to the conversation
690    pub fn add_message(&mut self, message: Message) {
691        self.messages.push(message);
692    }
693
694    /// Get the number of messages in the conversation
695    pub fn message_count(&self) -> usize {
696        self.messages.len()
697    }
698
699    /// Get the last message in the conversation
700    pub fn last_message(&self) -> Option<&Message> {
701        self.messages.last()
702    }
703
704    /// Add a tool call to tracking
705    pub fn add_tool_call(&mut self, tool_call: ToolCall) {
706        let state = ToolCallState {
707            tool_call: tool_call.clone(),
708            status: ToolCallStatus::PendingApproval,
709            started_at: None,
710            completed_at: None,
711            result: None,
712        };
713        self.tool_calls.insert(tool_call.id, state);
714    }
715
716    /// Update tool call status
717    pub fn update_tool_call_status(
718        &mut self,
719        tool_call_id: &str,
720        status: ToolCallStatus,
721    ) -> std::result::Result<(), String> {
722        let tool_call = self
723            .tool_calls
724            .get_mut(tool_call_id)
725            .ok_or_else(|| format!("Tool call not found: {tool_call_id}"))?;
726
727        // Update timestamps based on status changes
728        match (&tool_call.status, &status) {
729            (_, ToolCallStatus::Executing) => {
730                tool_call.started_at = Some(Utc::now());
731            }
732            (_, ToolCallStatus::Completed | ToolCallStatus::Failed { .. }) => {
733                tool_call.completed_at = Some(Utc::now());
734            }
735            _ => {}
736        }
737
738        tool_call.status = status;
739        Ok(())
740    }
741
742    /// Approve a tool for future use
743    pub fn approve_tool(&mut self, tool_name: String) {
744        self.approved_tools.insert(tool_name);
745    }
746
747    /// Check if a tool is approved
748    pub fn is_tool_approved(&self, tool_name: &str) -> bool {
749        self.approved_tools.contains(tool_name)
750    }
751
752    /// Validate internal consistency
753    pub fn validate(&self) -> std::result::Result<(), String> {
754        // Check that all tool calls referenced in messages exist
755        for message in &self.messages {
756            let tool_calls = Self::extract_tool_calls_from_message(message);
757            if !tool_calls.is_empty() {
758                for tool_call_id in tool_calls {
759                    if !self.tool_calls.contains_key(&tool_call_id) {
760                        return Err(format!(
761                            "Message references unknown tool call: {tool_call_id}"
762                        ));
763                    }
764                }
765            }
766        }
767
768        Ok(())
769    }
770
771    /// Extract tool call IDs from a message
772    fn extract_tool_calls_from_message(message: &Message) -> Vec<String> {
773        let mut tool_call_ids = Vec::new();
774
775        match &message.data {
776            MessageData::Assistant { content, .. } => {
777                for c in content {
778                    if let crate::app::conversation::AssistantContent::ToolCall {
779                        tool_call, ..
780                    } = c
781                    {
782                        tool_call_ids.push(tool_call.id.clone());
783                    }
784                }
785            }
786            MessageData::Tool { tool_use_id, .. } => {
787                tool_call_ids.push(tool_use_id.clone());
788            }
789            MessageData::User { .. } => {}
790        }
791
792        tool_call_ids
793    }
794}
795
/// Tool call state tracking
///
/// `SessionState::update_tool_call_status` stamps `started_at` on `Executing` and
/// `completed_at` on `Completed` / `Failed`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallState {
    /// The tracked call.
    pub tool_call: ToolCall,
    /// Current status.
    pub status: ToolCallStatus,
    /// When execution started, if it has.
    pub started_at: Option<DateTime<Utc>>,
    /// When the call completed or failed, if it has.
    pub completed_at: Option<DateTime<Utc>>,
    /// Result of the call, if one has been recorded.
    pub result: Option<ToolResult>,
}
805
806impl ToolCallState {
807    pub fn is_pending(&self) -> bool {
808        matches!(self.status, ToolCallStatus::PendingApproval)
809    }
810
811    pub fn is_complete(&self) -> bool {
812        matches!(
813            self.status,
814            ToolCallStatus::Completed | ToolCallStatus::Failed { .. }
815        )
816    }
817
818    pub fn duration(&self) -> Option<chrono::Duration> {
819        match (self.started_at, self.completed_at) {
820            (Some(start), Some(end)) => Some(end - start),
821            _ => None,
822        }
823    }
824}
825
/// Tool call execution status
///
/// Internally tagged by `"status"` with snake_case variant names. `Completed`, `Failed`
/// and `Denied` are terminal (see `is_terminal`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum ToolCallStatus {
    PendingApproval,
    Approved,
    Denied,
    Executing,
    Completed,
    /// Execution failed with the given error text.
    Failed { error: String },
}
837
838impl ToolCallStatus {
839    pub fn is_terminal(&self) -> bool {
840        matches!(
841            self,
842            ToolCallStatus::Completed | ToolCallStatus::Failed { .. } | ToolCallStatus::Denied
843        )
844    }
845}
846
/// Tool execution statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolExecutionStats {
    /// Legacy string output; `failure` also stores the error text here.
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output: Option<String>,
    /// Typed JSON output; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json_output: Option<serde_json::Value>,
    /// Type name of `json_output` (e.g., "SearchResult"). Unlike the outputs above,
    /// this field is serialized even when `None`.
    pub result_type: Option<String>,
    /// Whether the execution succeeded.
    pub success: bool,
    /// Execution time in milliseconds.
    pub execution_time_ms: u64,
    /// Extra key/value annotations (see `with_metadata`).
    pub metadata: HashMap<String, String>,
}
859
860impl ToolExecutionStats {
861    pub fn success(output: String, execution_time_ms: u64) -> Self {
862        Self {
863            output: Some(output),
864            json_output: None,
865            result_type: None,
866            success: true,
867            execution_time_ms,
868            metadata: HashMap::new(),
869        }
870    }
871
872    pub fn success_typed(
873        json_output: serde_json::Value,
874        result_type: String,
875        execution_time_ms: u64,
876    ) -> Self {
877        Self {
878            output: None,
879            json_output: Some(json_output),
880            result_type: Some(result_type),
881            success: true,
882            execution_time_ms,
883            metadata: HashMap::new(),
884        }
885    }
886
887    pub fn failure(error: String, execution_time_ms: u64) -> Self {
888        Self {
889            output: Some(error),
890            json_output: None,
891            result_type: None,
892            success: false,
893            execution_time_ms,
894            metadata: HashMap::new(),
895        }
896    }
897
898    pub fn with_metadata(mut self, key: String, value: String) -> Self {
899        self.metadata.insert(key, value);
900        self
901    }
902}
903
/// Session metadata for listing and filtering
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionInfo {
    /// Session identifier.
    pub id: String,
    /// When the session was created (UTC).
    pub created_at: DateTime<Utc>,
    /// When the session was last updated (UTC).
    pub updated_at: DateTime<Utc>,
    /// The last known model used in this session
    /// (always `None` when built via `From<&Session>`).
    pub last_model: Option<ModelId>,
    /// Number of messages in the session's conversation.
    pub message_count: usize,
    /// Copy of the session config's metadata.
    pub metadata: HashMap<String, String>,
}
915
916impl From<&Session> for SessionInfo {
917    fn from(session: &Session) -> Self {
918        Self {
919            id: session.id.clone(),
920            created_at: session.created_at,
921            updated_at: session.updated_at,
922            last_model: None, // TODO: Track last model used from events
923            message_count: session.state.message_count(),
924            metadata: session.config.metadata.clone(),
925        }
926    }
927}
928
#[cfg(test)]
mod tests {
    use super::*;
    use crate::app::conversation::{Message, MessageData, UserContent};
    use crate::config::model::builtin::claude_sonnet_4_5 as test_model;
    use crate::tools::DISPATCH_AGENT_TOOL_NAME;
    use crate::tools::static_tools::READ_ONLY_TOOL_NAMES;
    use steer_tools::tools::{BASH_TOOL_NAME, EDIT_TOOL_NAME};

    /// Builds a `SessionConfig` for a local workspace at `/test/path`.
    ///
    /// It uses the default tool configuration (no backends), no policy
    /// overrides, empty metadata and the builtin test model.
    fn local_test_config() -> SessionConfig {
        SessionConfig {
            workspace: WorkspaceConfig::Local {
                path: PathBuf::from("/test/path"),
            },
            workspace_ref: None,
            workspace_id: None,
            repo_ref: None,
            parent_session_id: None,
            workspace_name: None,
            tool_config: SessionToolConfig::default(),
            system_prompt: None,
            primary_agent_id: None,
            policy_overrides: SessionPolicyOverrides::empty(),
            metadata: HashMap::new(),
            default_model: test_model(),
            auto_compaction: AutoCompactionConfig::default(),
        }
    }

    /// Builds an approval policy that preapproves `read_file` and `list_files`.
    ///
    /// Every other tool falls under `default_behavior`. There are no
    /// per-tool rules.
    fn read_tools_policy(default_behavior: UnapprovedBehavior) -> ToolApprovalPolicy {
        ToolApprovalPolicy {
            default_behavior,
            preapproved: ApprovalRules {
                tools: ["read_file", "list_files"]
                    .iter()
                    .map(|s| (*s).to_string())
                    .collect(),
                per_tool: HashMap::new(),
            },
        }
    }

    #[test]
    fn test_session_creation() {
        let session = Session::new("test-session".to_string(), local_test_config());

        assert_eq!(session.id, "test-session");
        assert_eq!(
            session
                .config
                .tool_config
                .approval_policy
                .tool_decision("any_tool"),
            ToolDecision::Ask
        );
        assert_eq!(session.state.message_count(), 0);
    }

    #[test]
    fn test_tool_approval_policy_prompt_unapproved() {
        let policy = read_tools_policy(UnapprovedBehavior::Prompt);

        assert_eq!(policy.tool_decision("read_file"), ToolDecision::Allow);
        assert_eq!(policy.tool_decision("list_files"), ToolDecision::Allow);
        assert_eq!(policy.tool_decision("write_file"), ToolDecision::Ask);
    }

    #[test]
    fn test_tool_approval_policy_deny_unapproved() {
        let policy = read_tools_policy(UnapprovedBehavior::Deny);

        assert_eq!(policy.tool_decision("read_file"), ToolDecision::Allow);
        assert_eq!(policy.tool_decision("list_files"), ToolDecision::Allow);
        assert_eq!(policy.tool_decision("write_file"), ToolDecision::Deny);
    }

    #[test]
    fn test_tool_approval_policy_default() {
        let policy = ToolApprovalPolicy::default();

        assert_eq!(
            policy.tool_decision(READ_ONLY_TOOL_NAMES[0]),
            ToolDecision::Allow
        );
        assert_eq!(policy.tool_decision(BASH_TOOL_NAME), ToolDecision::Ask);
    }

    #[test]
    fn test_tool_approval_policy_allow_unapproved() {
        let policy = read_tools_policy(UnapprovedBehavior::Allow);

        assert_eq!(policy.tool_decision("read_file"), ToolDecision::Allow);
        assert_eq!(policy.tool_decision("list_files"), ToolDecision::Allow);
        assert_eq!(policy.tool_decision("write_file"), ToolDecision::Allow);
    }

    #[test]
    fn test_tool_approval_policy_overrides_union_rules() {
        let base_policy = ToolApprovalPolicy {
            default_behavior: UnapprovedBehavior::Prompt,
            preapproved: ApprovalRules {
                tools: ["read_file"].iter().map(|s| (*s).to_string()).collect(),
                per_tool: [
                    (
                        BASH_TOOL_NAME.to_string(),
                        ToolRule::Bash {
                            patterns: vec!["git status".to_string()],
                        },
                    ),
                    (
                        DISPATCH_AGENT_TOOL_NAME.to_string(),
                        ToolRule::DispatchAgent {
                            agent_patterns: vec!["explore".to_string()],
                        },
                    ),
                ]
                .into_iter()
                .collect(),
            },
        };

        let overrides = ToolApprovalPolicyOverrides {
            default_behavior: Some(UnapprovedBehavior::Deny),
            preapproved: ApprovalRulesOverrides {
                tools: ["write_file"].iter().map(|s| (*s).to_string()).collect(),
                per_tool: [
                    (
                        BASH_TOOL_NAME.to_string(),
                        ToolRuleOverrides::Bash {
                            patterns: vec!["git log".to_string()],
                        },
                    ),
                    (
                        DISPATCH_AGENT_TOOL_NAME.to_string(),
                        ToolRuleOverrides::DispatchAgent {
                            agent_patterns: vec!["review".to_string()],
                        },
                    ),
                ]
                .into_iter()
                .collect(),
            },
        };

        let merged = overrides.apply_to(&base_policy);

        assert_eq!(merged.default_behavior, UnapprovedBehavior::Deny);
        assert!(merged.preapproved.tools.contains("read_file"));
        assert!(merged.preapproved.tools.contains("write_file"));

        let bash_patterns = match merged
            .preapproved
            .per_tool
            .get(BASH_TOOL_NAME)
            .expect("bash rule")
        {
            ToolRule::Bash { patterns } => patterns,
            ToolRule::DispatchAgent { .. } => {
                panic!("Unexpected bash rule: dispatch agent")
            }
        };
        assert!(bash_patterns.contains(&"git status".to_string()));
        assert!(bash_patterns.contains(&"git log".to_string()));
        assert_eq!(bash_patterns.len(), 2);

        let agent_patterns = match merged
            .preapproved
            .per_tool
            .get(DISPATCH_AGENT_TOOL_NAME)
            .expect("dispatch_agent rule")
        {
            ToolRule::DispatchAgent { agent_patterns } => agent_patterns,
            ToolRule::Bash { .. } => panic!("Unexpected dispatch_agent rule: bash"),
        };
        assert!(agent_patterns.contains(&"explore".to_string()));
        assert!(agent_patterns.contains(&"review".to_string()));
        assert_eq!(agent_patterns.len(), 2);
    }

    #[test]
    fn test_bash_pattern_matching() {
        let policy = ToolApprovalPolicy {
            default_behavior: UnapprovedBehavior::Prompt,
            preapproved: ApprovalRules {
                tools: HashSet::new(),
                per_tool: [(
                    BASH_TOOL_NAME.to_string(),
                    ToolRule::Bash {
                        patterns: vec![
                            "git status".to_string(),
                            "git log*".to_string(),
                            "git * --oneline".to_string(),
                            "ls -?a*".to_string(),
                            "cargo build*".to_string(),
                        ],
                    },
                )]
                .into_iter()
                .collect(),
            },
        };

        assert!(policy.is_bash_pattern_preapproved("git status"));
        assert!(policy.is_bash_pattern_preapproved("git log --oneline"));
        assert!(policy.is_bash_pattern_preapproved("git show --oneline"));
        assert!(policy.is_bash_pattern_preapproved("ls -la"));
        assert!(policy.is_bash_pattern_preapproved("cargo build --release"));
        assert!(!policy.is_bash_pattern_preapproved("git commit"));
        assert!(!policy.is_bash_pattern_preapproved("ls -l"));
        assert!(!policy.is_bash_pattern_preapproved("rm -rf /"));
    }

    #[test]
    fn test_dispatch_agent_pattern_matching() {
        let policy = ToolApprovalPolicy {
            default_behavior: UnapprovedBehavior::Prompt,
            preapproved: ApprovalRules {
                tools: HashSet::new(),
                per_tool: [(
                    DISPATCH_AGENT_TOOL_NAME.to_string(),
                    ToolRule::DispatchAgent {
                        agent_patterns: vec!["explore".to_string(), "explore-*".to_string()],
                    },
                )]
                .into_iter()
                .collect(),
            },
        };

        assert!(policy.is_dispatch_agent_pattern_preapproved("explore"));
        assert!(policy.is_dispatch_agent_pattern_preapproved("explore-fast"));
        assert!(!policy.is_dispatch_agent_pattern_preapproved("build"));
    }

    #[test]
    fn test_session_state_validation() {
        let mut state = SessionState::default();

        // Valid empty state
        assert!(state.validate().is_ok());

        // Add a message
        let message = Message {
            data: MessageData::User {
                content: vec![UserContent::Text {
                    text: "Hello".to_string(),
                }],
            },
            timestamp: 123_456_789,
            id: "msg1".to_string(),
            parent_message_id: None,
        };
        state.add_message(message);

        assert!(state.validate().is_ok());
        assert_eq!(state.message_count(), 1);
    }

    #[test]
    fn test_tool_call_state_tracking() {
        let mut state = SessionState::default();

        let tool_call = ToolCall {
            id: "tool1".to_string(),
            name: "read_file".to_string(),
            parameters: serde_json::json!({"path": "/test.txt"}),
        };

        // `tool_call` is not used again, so it can be moved in directly.
        state.add_tool_call(tool_call);
        assert!(
            state
                .tool_calls
                .get("tool1")
                .expect("tool1 should be tracked")
                .is_pending()
        );

        state
            .update_tool_call_status("tool1", ToolCallStatus::Executing)
            .unwrap();
        let tool_state = state
            .tool_calls
            .get("tool1")
            .expect("tool1 should be tracked");
        assert!(tool_state.started_at.is_some());
        assert!(!tool_state.is_complete());

        state
            .update_tool_call_status("tool1", ToolCallStatus::Completed)
            .unwrap();
        let tool_state = state
            .tool_calls
            .get("tool1")
            .expect("tool1 should be tracked");
        assert!(tool_state.completed_at.is_some());
        assert!(tool_state.is_complete());
    }

    #[test]
    fn test_session_tool_config_default() {
        let config = SessionToolConfig::default();
        assert!(config.backends.is_empty());
    }

    #[test]
    fn test_tool_filter_exclude() {
        let excluded =
            ToolFilter::Exclude(vec![BASH_TOOL_NAME.to_string(), EDIT_TOOL_NAME.to_string()]);

        if let ToolFilter::Exclude(tools) = &excluded {
            assert_eq!(tools.len(), 2);
            assert!(tools.contains(&BASH_TOOL_NAME.to_string()));
            assert!(tools.contains(&EDIT_TOOL_NAME.to_string()));
        } else {
            panic!("Expected ToolFilter::Exclude");
        }
    }

    #[test]
    fn test_session_tool_config_read_only() {
        let config = SessionToolConfig::read_only();
        assert_eq!(config.backends.len(), 0);
        assert!(matches!(config.visibility, ToolVisibility::ReadOnly));
        assert_eq!(
            config.approval_policy.default_behavior,
            UnapprovedBehavior::Prompt
        );
    }

    #[tokio::test]
    async fn test_session_config_build_registry_no_default_backends() {
        // BackendRegistry only contains user-configured backends. Static
        // tools (dispatch_agent, web_fetch) live in ToolRegistry instead.
        // The default tool config used here has no backends configured.
        let config = local_test_config();

        let (registry, _mcp_servers) = config.build_registry().await.unwrap();
        let schemas = registry.get_tool_schemas().await;

        assert!(
            schemas.is_empty(),
            "BackendRegistry should be empty with default config; got: {:?}",
            schemas.iter().map(|s| &s.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_mcp_status_tracking() {
        // Test that MCP server info is properly tracked in session state
        let mut session_state = SessionState::default();

        // Add some MCP server info
        let mcp_info = McpServerInfo {
            server_name: "test-server".to_string(),
            transport: crate::tools::McpTransport::Stdio {
                command: "python".to_string(),
                args: vec!["-m".to_string(), "test_server".to_string()],
            },
            state: McpConnectionState::Connected {
                tool_names: vec![
                    "tool1".to_string(),
                    "tool2".to_string(),
                    "tool3".to_string(),
                    "tool4".to_string(),
                    "tool5".to_string(),
                ],
            },
            last_updated: Utc::now(),
        };

        session_state
            .mcp_servers
            .insert("test-server".to_string(), mcp_info);

        // Verify it's stored
        assert_eq!(session_state.mcp_servers.len(), 1);
        let stored = session_state
            .mcp_servers
            .get("test-server")
            .expect("test-server should be tracked");
        assert_eq!(stored.server_name, "test-server");
        assert!(matches!(
            stored.state,
            McpConnectionState::Connected { ref tool_names } if tool_names.len() == 5
        ));

        // Test failed connection
        let failed_info = McpServerInfo {
            server_name: "failed-server".to_string(),
            transport: crate::tools::McpTransport::Tcp {
                host: "localhost".to_string(),
                port: 9999,
            },
            state: McpConnectionState::Failed {
                error: "Connection refused".to_string(),
            },
            last_updated: Utc::now(),
        };

        session_state
            .mcp_servers
            .insert("failed-server".to_string(), failed_info);
        assert_eq!(session_state.mcp_servers.len(), 2);
    }

    #[tokio::test]
    async fn test_mcp_server_tracking_in_build_registry() {
        // Both servers below are expected to fail to connect. Each one must
        // still be recorded, with a `Failed` state.
        let mut config = SessionConfig::read_only(test_model());

        // `.invalid` is a reserved TLD (RFC 2606), so this host never resolves.
        config.tool_config.backends.push(BackendConfig::Mcp {
            server_name: "bad-server".to_string(),
            transport: crate::tools::McpTransport::Tcp {
                host: "nonexistent.invalid".to_string(),
                port: 12345,
            },
            tool_filter: ToolFilter::All,
        });

        // The stdio transport itself is valid, but `echo` does not speak MCP,
        // so the handshake fails too.
        config.tool_config.backends.push(BackendConfig::Mcp {
            server_name: "good-server".to_string(),
            transport: crate::tools::McpTransport::Stdio {
                command: "echo".to_string(),
                args: vec!["test".to_string()],
            },
            tool_filter: ToolFilter::All,
        });

        let (_registry, mcp_servers) = config.build_registry().await.unwrap();

        // Should have tracked both servers
        assert_eq!(mcp_servers.len(), 2);

        // Check the bad server
        let bad_server = mcp_servers
            .get("bad-server")
            .expect("bad-server should be tracked");
        assert_eq!(bad_server.server_name, "bad-server");
        assert!(matches!(
            bad_server.state,
            McpConnectionState::Failed { .. }
        ));

        // Check the "good" server (fails because echo isn't an MCP server)
        let good_server = mcp_servers
            .get("good-server")
            .expect("good-server should be tracked");
        assert_eq!(good_server.server_name, "good-server");
        assert!(matches!(
            good_server.state,
            McpConnectionState::Failed { .. }
        ));
    }

    #[test]
    fn test_backend_config_mcp_variant() {
        let mcp_config = BackendConfig::Mcp {
            server_name: "test-mcp".to_string(),
            transport: crate::tools::McpTransport::Stdio {
                command: "python".to_string(),
                args: vec!["-m".to_string(), "test_server".to_string()],
            },
            tool_filter: ToolFilter::All,
        };

        let BackendConfig::Mcp {
            server_name,
            transport,
            ..
        } = mcp_config;

        assert_eq!(server_name, "test-mcp");
        if let crate::tools::McpTransport::Stdio { command, args } = transport {
            assert_eq!(command, "python");
            assert_eq!(args.len(), 2);
        } else {
            panic!("Expected Stdio transport");
        }
    }
}