Skip to main content

steer_core/session/
state.rs

1use crate::config::model::ModelId;
2use crate::error::Result;
3use crate::tools::static_tools::READ_ONLY_TOOL_NAMES;
4use chrono::{DateTime, Utc};
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, HashSet};
8use std::path::PathBuf;
9use std::sync::Arc;
10
11use crate::app::{Message, MessageData};
12use crate::tools::{BackendRegistry, McpTransport, ToolBackend};
13use steer_tools::{ToolCall, result::ToolResult};
14
/// Lifecycle state of a single MCP server connection.
#[derive(Debug, Clone)]
pub enum McpConnectionState {
    /// A connection attempt is in flight.
    Connecting,
    /// The server is connected and serving tools.
    Connected {
        /// Tool names the server reported as available.
        tool_names: Vec<String>,
    },
    /// The connection was closed gracefully.
    Disconnected {
        /// Why the connection was closed, when known.
        reason: Option<String>,
    },
    /// Connecting to the server did not succeed.
    Failed {
        /// Human-readable description of what went wrong.
        error: String,
    },
}
36
37/// Information about an MCP server
38#[derive(Debug, Clone)]
39pub struct McpServerInfo {
40    /// The configured server name
41    pub server_name: String,
42    /// The transport configuration
43    pub transport: McpTransport,
44    /// Current connection state
45    pub state: McpConnectionState,
46    /// Timestamp when this state was last updated
47    pub last_updated: DateTime<Utc>,
48}
49
50/// Defines the primary execution environment for a session's workspace
51#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
52#[serde(tag = "type", rename_all = "snake_case")]
53pub enum WorkspaceConfig {
54    Local {
55        path: PathBuf,
56    },
57    Remote {
58        agent_address: String,
59        auth: Option<RemoteAuth>,
60    },
61}
62
63impl WorkspaceConfig {
64    pub fn get_path(&self) -> Option<String> {
65        match self {
66            WorkspaceConfig::Local { path } => Some(path.to_string_lossy().to_string()),
67            WorkspaceConfig::Remote { agent_address, .. } => Some(agent_address.clone()),
68        }
69    }
70
71    /// Convert to steer_workspace::WorkspaceConfig
72    pub fn to_workspace_config(&self) -> steer_workspace::WorkspaceConfig {
73        match self {
74            WorkspaceConfig::Local { path } => {
75                steer_workspace::WorkspaceConfig::Local { path: path.clone() }
76            }
77            WorkspaceConfig::Remote {
78                agent_address,
79                auth,
80            } => steer_workspace::WorkspaceConfig::Remote {
81                address: agent_address.clone(),
82                auth: auth.as_ref().map(|a| a.to_workspace_auth()),
83            },
84        }
85    }
86}
87
88impl Default for WorkspaceConfig {
89    fn default() -> Self {
90        Self::Local {
91            path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
92        }
93    }
94}
95
96/// Complete session representation
97#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct Session {
99    pub id: String,
100    pub created_at: DateTime<Utc>,
101    pub updated_at: DateTime<Utc>,
102    pub config: SessionConfig,
103    pub state: SessionState,
104}
105
106impl Session {
107    pub fn new(id: String, config: SessionConfig) -> Self {
108        let now = Utc::now();
109        Self {
110            id,
111            created_at: now,
112            updated_at: now,
113            config,
114            state: SessionState::default(),
115        }
116    }
117
118    pub fn update_timestamp(&mut self) {
119        self.updated_at = Utc::now();
120    }
121
122    /// Check if session has any recent activity
123    pub fn is_recently_active(&self, threshold: chrono::Duration) -> bool {
124        let cutoff = Utc::now() - threshold;
125        self.updated_at > cutoff
126    }
127
128    /// Build a workspace from this session's configuration
129    pub async fn build_workspace(&self) -> Result<Arc<dyn crate::workspace::Workspace>> {
130        crate::workspace::create_workspace(&self.config.workspace.to_workspace_config()).await
131    }
132}
133
134/// Session configuration - immutable once created
135#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
136pub struct SessionConfig {
137    pub workspace: WorkspaceConfig,
138    #[serde(default)]
139    pub workspace_ref: Option<crate::workspace::WorkspaceRef>,
140    #[serde(default)]
141    pub workspace_id: Option<crate::workspace::WorkspaceId>,
142    #[serde(default)]
143    pub repo_ref: Option<crate::workspace::RepoRef>,
144    #[serde(default)]
145    pub parent_session_id: Option<crate::app::domain::types::SessionId>,
146    #[serde(default)]
147    pub workspace_name: Option<String>,
148    pub tool_config: SessionToolConfig,
149    /// Optional custom system prompt to use for the session. If `None`, Steer will
150    /// fall back to its built-in default prompt.
151    pub system_prompt: Option<String>,
152    /// Primary agent mode for this session. Defaults to "normal" if unset.
153    #[serde(default)]
154    pub primary_agent_id: Option<String>,
155    /// User-controlled policy overrides that apply on top of the primary agent base policy.
156    #[serde(default = "SessionPolicyOverrides::empty")]
157    pub policy_overrides: SessionPolicyOverrides,
158    pub metadata: HashMap<String, String>,
159    pub default_model: ModelId,
160    #[serde(default)]
161    pub auto_compaction: AutoCompactionConfig,
162}
163
164impl SessionConfig {
165    /// Build a BackendRegistry from MCP server configurations.
166    /// Returns the registry and a map of MCP server connection states.
167    pub async fn build_registry(
168        &self,
169    ) -> Result<(BackendRegistry, HashMap<String, McpServerInfo>)> {
170        let mut registry = BackendRegistry::new();
171        let mut mcp_servers = HashMap::new();
172
173        for backend_config in &self.tool_config.backends {
174            let BackendConfig::Mcp {
175                server_name,
176                transport,
177                tool_filter,
178            } = backend_config;
179
180            tracing::info!(
181                "Attempting to initialize MCP backend '{}' with transport: {:?}",
182                server_name,
183                transport
184            );
185
186            let mut server_info = McpServerInfo {
187                server_name: server_name.clone(),
188                transport: transport.clone(),
189                state: McpConnectionState::Connecting,
190                last_updated: Utc::now(),
191            };
192
193            match crate::tools::McpBackend::new(
194                server_name.clone(),
195                transport.clone(),
196                tool_filter.clone(),
197            )
198            .await
199            {
200                Ok(mcp_backend) => {
201                    let tool_names = mcp_backend.supported_tools().await;
202                    let tool_count = tool_names.len();
203                    tracing::info!(
204                        "Successfully initialized MCP backend '{}' with {} tools",
205                        server_name,
206                        tool_count
207                    );
208                    server_info.state = McpConnectionState::Connected { tool_names };
209                    server_info.last_updated = Utc::now();
210                    registry
211                        .register(format!("mcp_{server_name}"), Arc::new(mcp_backend))
212                        .await;
213                }
214                Err(e) => {
215                    tracing::error!("Failed to initialize MCP backend '{}': {}", server_name, e);
216                    server_info.state = McpConnectionState::Failed {
217                        error: e.to_string(),
218                    };
219                    server_info.last_updated = Utc::now();
220                }
221            }
222
223            mcp_servers.insert(server_name.clone(), server_info);
224        }
225
226        Ok((registry, mcp_servers))
227    }
228
229    /// Filter tools based on visibility settings
230    pub fn filter_tools_by_visibility(
231        &self,
232        tools: Vec<steer_tools::ToolSchema>,
233    ) -> Vec<steer_tools::ToolSchema> {
234        match &self.tool_config.visibility {
235            ToolVisibility::All => tools,
236            ToolVisibility::ReadOnly => {
237                let read_only_names: HashSet<String> = READ_ONLY_TOOL_NAMES
238                    .iter()
239                    .map(|name| (*name).to_string())
240                    .collect();
241
242                tools
243                    .into_iter()
244                    .filter(|schema| read_only_names.contains(&schema.name))
245                    .collect()
246            }
247            ToolVisibility::Whitelist(allowed) => tools
248                .into_iter()
249                .filter(|schema| allowed.contains(&schema.name))
250                .collect(),
251            ToolVisibility::Blacklist(blocked) => tools
252                .into_iter()
253                .filter(|schema| !blocked.contains(&schema.name))
254                .collect(),
255        }
256    }
257
258    /// Minimal read-only configuration
259    #[cfg(test)]
260    pub fn read_only(default_model: ModelId) -> Self {
261        Self {
262            workspace: WorkspaceConfig::Local {
263                path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
264            },
265            workspace_ref: None,
266            workspace_id: None,
267            repo_ref: None,
268            parent_session_id: None,
269            workspace_name: None,
270            tool_config: SessionToolConfig::read_only(),
271            system_prompt: None,
272            primary_agent_id: None,
273            policy_overrides: SessionPolicyOverrides::empty(),
274            metadata: HashMap::new(),
275            default_model,
276            auto_compaction: AutoCompactionConfig::default(),
277        }
278    }
279}
280
281/// Configuration for automatic context-window compaction.
282#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
283pub struct AutoCompactionConfig {
284    pub enabled: bool,
285    pub threshold_percent: u32,
286}
287
288impl Default for AutoCompactionConfig {
289    fn default() -> Self {
290        Self {
291            enabled: true,
292            threshold_percent: 90,
293        }
294    }
295}
296
297impl AutoCompactionConfig {
298    pub fn threshold_ratio(&self) -> f64 {
299        f64::from(self.threshold_percent) / 100.0
300    }
301}
302
303/// User-controlled policy overrides applied on top of a primary agent base policy.
304#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
305pub struct SessionPolicyOverrides {
306    #[serde(default, skip_serializing_if = "Option::is_none")]
307    pub default_model: Option<ModelId>,
308    #[serde(default, skip_serializing_if = "Option::is_none")]
309    pub tool_visibility: Option<ToolVisibility>,
310    #[serde(default = "ToolApprovalPolicyOverrides::empty")]
311    pub approval_policy: ToolApprovalPolicyOverrides,
312}
313
314impl SessionPolicyOverrides {
315    pub fn empty() -> Self {
316        Self {
317            default_model: None,
318            tool_visibility: None,
319            approval_policy: ToolApprovalPolicyOverrides::empty(),
320        }
321    }
322
323    pub fn is_empty(&self) -> bool {
324        self.default_model.is_none()
325            && self.tool_visibility.is_none()
326            && self.approval_policy.is_empty()
327    }
328}
329
330#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
331pub struct ToolApprovalPolicyOverrides {
332    #[serde(default = "ApprovalRulesOverrides::empty")]
333    pub preapproved: ApprovalRulesOverrides,
334}
335
336impl ToolApprovalPolicyOverrides {
337    pub fn empty() -> Self {
338        Self {
339            preapproved: ApprovalRulesOverrides::empty(),
340        }
341    }
342
343    pub fn is_empty(&self) -> bool {
344        self.preapproved.is_empty()
345    }
346}
347
348#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
349pub struct ApprovalRulesOverrides {
350    #[serde(default)]
351    pub tools: HashSet<String>,
352    #[serde(default)]
353    pub per_tool: HashMap<String, ToolRuleOverrides>,
354}
355
356impl ApprovalRulesOverrides {
357    pub fn empty() -> Self {
358        Self {
359            tools: HashSet::new(),
360            per_tool: HashMap::new(),
361        }
362    }
363
364    pub fn is_empty(&self) -> bool {
365        self.tools.is_empty() && self.per_tool.is_empty()
366    }
367}
368
369#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
370#[serde(tag = "type", rename_all = "snake_case")]
371pub enum ToolRuleOverrides {
372    Bash { patterns: Vec<String> },
373    DispatchAgent { agent_patterns: Vec<String> },
374}
375
376/// Tool visibility configuration - controls which tools are shown to the AI agent
377#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
378#[serde(tag = "type", content = "tools", rename_all = "snake_case")]
379pub enum ToolVisibility {
380    /// Show all registered tools to the AI
381    #[default]
382    All,
383
384    /// Only show read-only tools to the AI
385    ReadOnly,
386
387    /// Show only specific tools to the AI (whitelist)
388    Whitelist(HashSet<String>),
389
390    /// Hide specific tools from the AI (blacklist)
391    Blacklist(HashSet<String>),
392}
393
/// Outcome of evaluating whether a tool may run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolDecision {
    /// Run without asking.
    Allow,
    /// Ask the user before running.
    Ask,
    /// Refuse to run.
    Deny,
}
400
401#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Default)]
402#[serde(rename_all = "snake_case")]
403pub enum UnapprovedBehavior {
404    #[default]
405    Prompt,
406    Deny,
407    Allow,
408}
409
410#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
411#[serde(tag = "type", rename_all = "snake_case")]
412pub enum ToolRule {
413    Bash { patterns: Vec<String> },
414    DispatchAgent { agent_patterns: Vec<String> },
415}
416
417#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Default)]
418pub struct ApprovalRules {
419    #[serde(default)]
420    pub tools: HashSet<String>,
421    #[serde(default)]
422    pub per_tool: HashMap<String, ToolRule>,
423}
424
425impl ApprovalRules {
426    pub fn is_empty(&self) -> bool {
427        self.tools.is_empty() && self.per_tool.is_empty()
428    }
429
430    pub fn bash_patterns(&self) -> Option<&[String]> {
431        self.per_tool.get("bash").and_then(|rule| match rule {
432            ToolRule::Bash { patterns } => Some(patterns.as_slice()),
433            ToolRule::DispatchAgent { .. } => None,
434        })
435    }
436
437    pub fn dispatch_agent_rule(&self) -> Option<&[String]> {
438        self.per_tool
439            .get("dispatch_agent")
440            .and_then(|rule| match rule {
441                ToolRule::DispatchAgent { agent_patterns } => Some(agent_patterns.as_slice()),
442                ToolRule::Bash { .. } => None,
443            })
444    }
445}
446
447#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
448pub struct ToolApprovalPolicy {
449    pub default_behavior: UnapprovedBehavior,
450    #[serde(default)]
451    pub preapproved: ApprovalRules,
452}
453
454impl Default for ToolApprovalPolicy {
455    fn default() -> Self {
456        Self {
457            default_behavior: UnapprovedBehavior::Prompt,
458            preapproved: ApprovalRules {
459                tools: READ_ONLY_TOOL_NAMES
460                    .iter()
461                    .map(|name| (*name).to_string())
462                    .collect(),
463                per_tool: HashMap::new(),
464            },
465        }
466    }
467}
468
469impl ToolApprovalPolicy {
470    pub fn tool_decision(&self, tool_name: &str) -> ToolDecision {
471        if self.preapproved.tools.contains(tool_name) {
472            ToolDecision::Allow
473        } else {
474            match self.default_behavior {
475                UnapprovedBehavior::Prompt => ToolDecision::Ask,
476                UnapprovedBehavior::Deny => ToolDecision::Deny,
477                UnapprovedBehavior::Allow => ToolDecision::Allow,
478            }
479        }
480    }
481
482    pub fn is_bash_pattern_preapproved(&self, command: &str) -> bool {
483        let Some(patterns) = self.preapproved.bash_patterns() else {
484            return false;
485        };
486        patterns.iter().any(|pattern| {
487            if pattern == command {
488                return true;
489            }
490            glob::Pattern::new(pattern)
491                .map(|glob| glob.matches(command))
492                .unwrap_or(false)
493        })
494    }
495
496    pub fn is_dispatch_agent_pattern_preapproved(&self, agent_id: &str) -> bool {
497        let Some(patterns) = self.preapproved.dispatch_agent_rule() else {
498            return false;
499        };
500        patterns.iter().any(|pattern| {
501            if pattern == agent_id {
502                return true;
503            }
504            glob::Pattern::new(pattern)
505                .map(|glob| glob.matches(agent_id))
506                .unwrap_or(false)
507        })
508    }
509
510    pub fn pre_approved_tools(&self) -> &HashSet<String> {
511        &self.preapproved.tools
512    }
513}
514
515impl ToolApprovalPolicyOverrides {
516    pub fn apply_to(&self, base: &ToolApprovalPolicy) -> ToolApprovalPolicy {
517        let mut merged = base.clone();
518
519        if !self.preapproved.tools.is_empty() {
520            merged
521                .preapproved
522                .tools
523                .extend(self.preapproved.tools.iter().cloned());
524        }
525
526        for (tool_name, override_rule) in &self.preapproved.per_tool {
527            let base_rule = merged.preapproved.per_tool.get(tool_name);
528            let merged_rule = merge_tool_rule_override(base_rule, override_rule);
529            merged
530                .preapproved
531                .per_tool
532                .insert(tool_name.clone(), merged_rule);
533        }
534
535        merged
536    }
537}
538
539fn merge_tool_rule_override(
540    base: Option<&ToolRule>,
541    override_rule: &ToolRuleOverrides,
542) -> ToolRule {
543    match (base, override_rule) {
544        (Some(ToolRule::Bash { patterns }), ToolRuleOverrides::Bash { patterns: extra }) => {
545            ToolRule::Bash {
546                patterns: merge_patterns(patterns, extra),
547            }
548        }
549        (
550            Some(ToolRule::DispatchAgent { agent_patterns }),
551            ToolRuleOverrides::DispatchAgent {
552                agent_patterns: extra,
553            },
554        ) => ToolRule::DispatchAgent {
555            agent_patterns: merge_patterns(agent_patterns, extra),
556        },
557        (_, ToolRuleOverrides::Bash { patterns }) => ToolRule::Bash {
558            patterns: patterns.clone(),
559        },
560        (_, ToolRuleOverrides::DispatchAgent { agent_patterns }) => ToolRule::DispatchAgent {
561            agent_patterns: agent_patterns.clone(),
562        },
563    }
564}
565
/// Appends each pattern of `extra` to a copy of `base` unless it is already
/// present. Duplicates already in `base` are kept; duplicates within `extra`
/// are added only once.
fn merge_patterns(base: &[String], extra: &[String]) -> Vec<String> {
    extra.iter().fold(base.to_vec(), |mut acc, candidate| {
        if !acc.contains(candidate) {
            acc.push(candidate.clone());
        }
        acc
    })
}
575
576/// Authentication configuration for remote backends
577#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
578pub enum RemoteAuth {
579    Bearer { token: String },
580    ApiKey { key: String },
581}
582
583impl RemoteAuth {
584    /// Convert to steer_workspace RemoteAuth type
585    pub fn to_workspace_auth(&self) -> steer_workspace::RemoteAuth {
586        match self {
587            RemoteAuth::Bearer { token } => steer_workspace::RemoteAuth::BearerToken(token.clone()),
588            RemoteAuth::ApiKey { key } => steer_workspace::RemoteAuth::ApiKey(key.clone()),
589        }
590    }
591}
592
593/// Tool filtering configuration for backends
594#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema)]
595#[serde(rename_all = "snake_case")]
596#[derive(Default)]
597pub enum ToolFilter {
598    /// Include all available tools
599    #[default]
600    All,
601    /// Include only the specified tools
602    Include(Vec<String>),
603    /// Include all tools except the specified ones
604    Exclude(Vec<String>),
605}
606
607/// Configuration for MCP server backends
608#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
609#[serde(tag = "type", rename_all = "snake_case")]
610pub enum BackendConfig {
611    Mcp {
612        server_name: String,
613        transport: McpTransport,
614        tool_filter: ToolFilter,
615    },
616}
617
618#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
619pub struct SessionToolConfig {
620    pub backends: Vec<BackendConfig>,
621    pub visibility: ToolVisibility,
622    pub approval_policy: ToolApprovalPolicy,
623    pub metadata: HashMap<String, String>,
624}
625
626impl Default for SessionToolConfig {
627    fn default() -> Self {
628        Self {
629            backends: Vec::new(),
630            visibility: ToolVisibility::All,
631            approval_policy: ToolApprovalPolicy::default(),
632            metadata: HashMap::new(),
633        }
634    }
635}
636
637impl SessionToolConfig {
638    pub fn read_only() -> Self {
639        Self {
640            backends: Vec::new(),
641            visibility: ToolVisibility::ReadOnly,
642            approval_policy: ToolApprovalPolicy::default(),
643            metadata: HashMap::new(),
644        }
645    }
646}
647
648/// Mutable session state that changes during execution
649#[derive(Debug, Clone, Serialize, Deserialize, Default)]
650pub struct SessionState {
651    /// Conversation messages
652    pub messages: Vec<Message>,
653
654    /// Tool call tracking
655    pub tool_calls: HashMap<String, ToolCallState>,
656
657    /// Tools that have been approved for this session
658    pub approved_tools: HashSet<String>,
659
660    /// Bash commands that have been approved for this session (dynamically added)
661    #[serde(default)]
662    pub approved_bash_patterns: HashSet<String>,
663
664    /// Last processed event sequence number for replay
665    pub last_event_sequence: u64,
666
667    /// Additional runtime metadata
668    pub metadata: HashMap<String, String>,
669
670    /// The ID of the currently active message (head of selected branch)
671    /// None means use last message semantics for backward compatibility
672    #[serde(default, skip_serializing_if = "Option::is_none")]
673    pub active_message_id: Option<String>,
674
675    /// Status of MCP server connections
676    /// This is a transient field that is rebuilt on session activation
677    #[serde(default, skip_serializing, skip_deserializing)]
678    pub mcp_servers: HashMap<String, McpServerInfo>,
679}
680
681impl SessionState {
682    /// Add a message to the conversation
683    pub fn add_message(&mut self, message: Message) {
684        self.messages.push(message);
685    }
686
687    /// Get the number of messages in the conversation
688    pub fn message_count(&self) -> usize {
689        self.messages.len()
690    }
691
692    /// Get the last message in the conversation
693    pub fn last_message(&self) -> Option<&Message> {
694        self.messages.last()
695    }
696
697    /// Add a tool call to tracking
698    pub fn add_tool_call(&mut self, tool_call: ToolCall) {
699        let state = ToolCallState {
700            tool_call: tool_call.clone(),
701            status: ToolCallStatus::PendingApproval,
702            started_at: None,
703            completed_at: None,
704            result: None,
705        };
706        self.tool_calls.insert(tool_call.id, state);
707    }
708
709    /// Update tool call status
710    pub fn update_tool_call_status(
711        &mut self,
712        tool_call_id: &str,
713        status: ToolCallStatus,
714    ) -> std::result::Result<(), String> {
715        let tool_call = self
716            .tool_calls
717            .get_mut(tool_call_id)
718            .ok_or_else(|| format!("Tool call not found: {tool_call_id}"))?;
719
720        // Update timestamps based on status changes
721        match (&tool_call.status, &status) {
722            (_, ToolCallStatus::Executing) => {
723                tool_call.started_at = Some(Utc::now());
724            }
725            (_, ToolCallStatus::Completed | ToolCallStatus::Failed { .. }) => {
726                tool_call.completed_at = Some(Utc::now());
727            }
728            _ => {}
729        }
730
731        tool_call.status = status;
732        Ok(())
733    }
734
735    /// Approve a tool for future use
736    pub fn approve_tool(&mut self, tool_name: String) {
737        self.approved_tools.insert(tool_name);
738    }
739
740    /// Check if a tool is approved
741    pub fn is_tool_approved(&self, tool_name: &str) -> bool {
742        self.approved_tools.contains(tool_name)
743    }
744
745    /// Validate internal consistency
746    pub fn validate(&self) -> std::result::Result<(), String> {
747        // Check that all tool calls referenced in messages exist
748        for message in &self.messages {
749            let tool_calls = Self::extract_tool_calls_from_message(message);
750            if !tool_calls.is_empty() {
751                for tool_call_id in tool_calls {
752                    if !self.tool_calls.contains_key(&tool_call_id) {
753                        return Err(format!(
754                            "Message references unknown tool call: {tool_call_id}"
755                        ));
756                    }
757                }
758            }
759        }
760
761        Ok(())
762    }
763
764    /// Extract tool call IDs from a message
765    fn extract_tool_calls_from_message(message: &Message) -> Vec<String> {
766        let mut tool_call_ids = Vec::new();
767
768        match &message.data {
769            MessageData::Assistant { content, .. } => {
770                for c in content {
771                    if let crate::app::conversation::AssistantContent::ToolCall {
772                        tool_call, ..
773                    } = c
774                    {
775                        tool_call_ids.push(tool_call.id.clone());
776                    }
777                }
778            }
779            MessageData::Tool { tool_use_id, .. } => {
780                tool_call_ids.push(tool_use_id.clone());
781            }
782            MessageData::User { .. } => {}
783        }
784
785        tool_call_ids
786    }
787}
788
789/// Tool call state tracking
790#[derive(Debug, Clone, Serialize, Deserialize)]
791pub struct ToolCallState {
792    pub tool_call: ToolCall,
793    pub status: ToolCallStatus,
794    pub started_at: Option<DateTime<Utc>>,
795    pub completed_at: Option<DateTime<Utc>>,
796    pub result: Option<ToolResult>,
797}
798
799impl ToolCallState {
800    pub fn is_pending(&self) -> bool {
801        matches!(self.status, ToolCallStatus::PendingApproval)
802    }
803
804    pub fn is_complete(&self) -> bool {
805        matches!(
806            self.status,
807            ToolCallStatus::Completed | ToolCallStatus::Failed { .. }
808        )
809    }
810
811    pub fn duration(&self) -> Option<chrono::Duration> {
812        match (self.started_at, self.completed_at) {
813            (Some(start), Some(end)) => Some(end - start),
814            _ => None,
815        }
816    }
817}
818
819/// Tool call execution status
820#[derive(Debug, Clone, Serialize, Deserialize)]
821#[serde(tag = "status", rename_all = "snake_case")]
822pub enum ToolCallStatus {
823    PendingApproval,
824    Approved,
825    Denied,
826    Executing,
827    Completed,
828    Failed { error: String },
829}
830
831impl ToolCallStatus {
832    pub fn is_terminal(&self) -> bool {
833        matches!(
834            self,
835            ToolCallStatus::Completed | ToolCallStatus::Failed { .. } | ToolCallStatus::Denied
836        )
837    }
838}
839
840/// Tool execution statistics
841#[derive(Debug, Clone, Serialize, Deserialize)]
842pub struct ToolExecutionStats {
843    #[serde(skip_serializing_if = "Option::is_none")]
844    pub output: Option<String>, // Legacy string output
845    #[serde(skip_serializing_if = "Option::is_none")]
846    pub json_output: Option<serde_json::Value>, // Typed JSON output
847    pub result_type: Option<String>, // Type name (e.g., "SearchResult")
848    pub success: bool,
849    pub execution_time_ms: u64,
850    pub metadata: HashMap<String, String>,
851}
852
853impl ToolExecutionStats {
854    pub fn success(output: String, execution_time_ms: u64) -> Self {
855        Self {
856            output: Some(output),
857            json_output: None,
858            result_type: None,
859            success: true,
860            execution_time_ms,
861            metadata: HashMap::new(),
862        }
863    }
864
865    pub fn success_typed(
866        json_output: serde_json::Value,
867        result_type: String,
868        execution_time_ms: u64,
869    ) -> Self {
870        Self {
871            output: None,
872            json_output: Some(json_output),
873            result_type: Some(result_type),
874            success: true,
875            execution_time_ms,
876            metadata: HashMap::new(),
877        }
878    }
879
880    pub fn failure(error: String, execution_time_ms: u64) -> Self {
881        Self {
882            output: Some(error),
883            json_output: None,
884            result_type: None,
885            success: false,
886            execution_time_ms,
887            metadata: HashMap::new(),
888        }
889    }
890
891    pub fn with_metadata(mut self, key: String, value: String) -> Self {
892        self.metadata.insert(key, value);
893        self
894    }
895}
896
897/// Session metadata for listing and filtering
898#[derive(Debug, Clone, Serialize, Deserialize)]
899pub struct SessionInfo {
900    pub id: String,
901    pub created_at: DateTime<Utc>,
902    pub updated_at: DateTime<Utc>,
903    /// The last known model used in this session
904    pub last_model: Option<ModelId>,
905    pub message_count: usize,
906    pub metadata: HashMap<String, String>,
907}
908
909impl From<&Session> for SessionInfo {
910    fn from(session: &Session) -> Self {
911        Self {
912            id: session.id.clone(),
913            created_at: session.created_at,
914            updated_at: session.updated_at,
915            last_model: None, // TODO: Track last model used from events
916            message_count: session.state.message_count(),
917            metadata: session.config.metadata.clone(),
918        }
919    }
920}
921
#[cfg(test)]
mod tests {
    use super::*;
    use crate::app::conversation::{Message, MessageData, UserContent};
    use crate::config::model::builtin::claude_sonnet_4_5 as test_model;
    use crate::tools::DISPATCH_AGENT_TOOL_NAME;
    use crate::tools::static_tools::READ_ONLY_TOOL_NAMES;
    use steer_tools::tools::{BASH_TOOL_NAME, EDIT_TOOL_NAME};

    /// Collects string slices into any owned collection of `String`s; the
    /// destination field's type (set, vector, ...) drives inference.
    fn owned_names<C: FromIterator<String>>(items: &[&str]) -> C {
        items.iter().map(|item| (*item).to_string()).collect()
    }

    /// A `SessionConfig` for a local workspace at `/test/path` with the default
    /// tool configuration, empty policy overrides, empty metadata, and every
    /// optional reference left as `None`.
    fn local_test_config() -> SessionConfig {
        SessionConfig {
            workspace: WorkspaceConfig::Local {
                path: PathBuf::from("/test/path"),
            },
            workspace_ref: None,
            workspace_id: None,
            repo_ref: None,
            parent_session_id: None,
            workspace_name: None,
            tool_config: SessionToolConfig::default(),
            system_prompt: None,
            primary_agent_id: None,
            policy_overrides: SessionPolicyOverrides::empty(),
            metadata: HashMap::new(),
            default_model: test_model(),
            auto_compaction: AutoCompactionConfig::default(),
        }
    }

    /// A policy that preapproves `read_file` and `list_files` (with no
    /// per-tool rules) and applies `default_behavior` to every other tool.
    fn read_and_list_policy(default_behavior: UnapprovedBehavior) -> ToolApprovalPolicy {
        ToolApprovalPolicy {
            default_behavior,
            preapproved: ApprovalRules {
                tools: owned_names(&["read_file", "list_files"]),
                per_tool: HashMap::new(),
            },
        }
    }

    #[test]
    fn test_session_creation() {
        let session = Session::new("test-session".to_string(), local_test_config());

        assert_eq!(session.id, "test-session");
        // A tool that the default approval policy does not preapprove must prompt.
        assert_eq!(
            session
                .config
                .tool_config
                .approval_policy
                .tool_decision("any_tool"),
            ToolDecision::Ask
        );
        assert_eq!(session.state.message_count(), 0);
    }

    #[test]
    fn test_tool_approval_policy_prompt_unapproved() {
        let policy = read_and_list_policy(UnapprovedBehavior::Prompt);

        assert_eq!(policy.tool_decision("read_file"), ToolDecision::Allow);
        assert_eq!(policy.tool_decision("write_file"), ToolDecision::Ask);
    }

    #[test]
    fn test_tool_approval_policy_deny_unapproved() {
        let policy = read_and_list_policy(UnapprovedBehavior::Deny);

        assert_eq!(policy.tool_decision("read_file"), ToolDecision::Allow);
        assert_eq!(policy.tool_decision("write_file"), ToolDecision::Deny);
    }

    #[test]
    fn test_tool_approval_policy_default() {
        let policy = ToolApprovalPolicy::default();

        assert_eq!(
            policy.tool_decision(READ_ONLY_TOOL_NAMES[0]),
            ToolDecision::Allow
        );
        assert_eq!(policy.tool_decision(BASH_TOOL_NAME), ToolDecision::Ask);
    }

    #[test]
    fn test_tool_approval_policy_allow_unapproved() {
        let policy = read_and_list_policy(UnapprovedBehavior::Allow);

        assert_eq!(policy.tool_decision("read_file"), ToolDecision::Allow);
        assert_eq!(policy.tool_decision("write_file"), ToolDecision::Allow);
    }

    #[test]
    fn test_tool_approval_policy_overrides_union_rules() {
        let base_policy = ToolApprovalPolicy {
            default_behavior: UnapprovedBehavior::Prompt,
            preapproved: ApprovalRules {
                tools: owned_names(&["read_file"]),
                per_tool: [
                    (
                        BASH_TOOL_NAME.to_string(),
                        ToolRule::Bash {
                            patterns: vec!["git status".to_string()],
                        },
                    ),
                    (
                        DISPATCH_AGENT_TOOL_NAME.to_string(),
                        ToolRule::DispatchAgent {
                            agent_patterns: vec!["explore".to_string()],
                        },
                    ),
                ]
                .into_iter()
                .collect(),
            },
        };

        let overrides = ToolApprovalPolicyOverrides {
            preapproved: ApprovalRulesOverrides {
                tools: owned_names(&["write_file"]),
                per_tool: [
                    (
                        BASH_TOOL_NAME.to_string(),
                        ToolRuleOverrides::Bash {
                            patterns: vec!["git log".to_string()],
                        },
                    ),
                    (
                        DISPATCH_AGENT_TOOL_NAME.to_string(),
                        ToolRuleOverrides::DispatchAgent {
                            agent_patterns: vec!["review".to_string()],
                        },
                    ),
                ]
                .into_iter()
                .collect(),
            },
        };

        let merged = overrides.apply_to(&base_policy);

        assert_eq!(merged.default_behavior, UnapprovedBehavior::Prompt);
        assert!(merged.preapproved.tools.contains("read_file"));
        assert!(merged.preapproved.tools.contains("write_file"));

        let bash_patterns = match merged
            .preapproved
            .per_tool
            .get(BASH_TOOL_NAME)
            .expect("bash rule")
        {
            ToolRule::Bash { patterns } => patterns,
            ToolRule::DispatchAgent { .. } => {
                panic!("Unexpected bash rule: dispatch agent")
            }
        };
        assert!(bash_patterns.contains(&"git status".to_string()));
        assert!(bash_patterns.contains(&"git log".to_string()));
        assert_eq!(bash_patterns.len(), 2);

        let agent_patterns = match merged
            .preapproved
            .per_tool
            .get(DISPATCH_AGENT_TOOL_NAME)
            .expect("dispatch_agent rule")
        {
            ToolRule::DispatchAgent { agent_patterns } => agent_patterns,
            ToolRule::Bash { .. } => panic!("Unexpected dispatch_agent rule: bash"),
        };
        assert!(agent_patterns.contains(&"explore".to_string()));
        assert!(agent_patterns.contains(&"review".to_string()));
        assert_eq!(agent_patterns.len(), 2);
    }

    #[test]
    fn test_bash_pattern_matching() {
        // Key the rule by the shared constant (as the other tests do) so this
        // test keeps working if the bash tool is ever renamed.
        let policy = ToolApprovalPolicy {
            default_behavior: UnapprovedBehavior::Prompt,
            preapproved: ApprovalRules {
                tools: HashSet::new(),
                per_tool: [(
                    BASH_TOOL_NAME.to_string(),
                    ToolRule::Bash {
                        patterns: owned_names(&[
                            "git status",
                            "git log*",
                            "git * --oneline",
                            "ls -?a*",
                            "cargo build*",
                        ]),
                    },
                )]
                .into_iter()
                .collect(),
            },
        };

        assert!(policy.is_bash_pattern_preapproved("git status"));
        assert!(policy.is_bash_pattern_preapproved("git log --oneline"));
        assert!(policy.is_bash_pattern_preapproved("git show --oneline"));
        assert!(policy.is_bash_pattern_preapproved("ls -la"));
        assert!(policy.is_bash_pattern_preapproved("cargo build --release"));
        assert!(!policy.is_bash_pattern_preapproved("git commit"));
        assert!(!policy.is_bash_pattern_preapproved("ls -l"));
        assert!(!policy.is_bash_pattern_preapproved("rm -rf /"));
    }

    #[test]
    fn test_dispatch_agent_pattern_matching() {
        let policy = ToolApprovalPolicy {
            default_behavior: UnapprovedBehavior::Prompt,
            preapproved: ApprovalRules {
                tools: HashSet::new(),
                per_tool: [(
                    DISPATCH_AGENT_TOOL_NAME.to_string(),
                    ToolRule::DispatchAgent {
                        agent_patterns: owned_names(&["explore", "explore-*"]),
                    },
                )]
                .into_iter()
                .collect(),
            },
        };

        assert!(policy.is_dispatch_agent_pattern_preapproved("explore"));
        assert!(policy.is_dispatch_agent_pattern_preapproved("explore-fast"));
        assert!(!policy.is_dispatch_agent_pattern_preapproved("build"));
    }

    #[test]
    fn test_session_state_validation() {
        let mut state = SessionState::default();

        // Valid empty state
        assert!(state.validate().is_ok());

        // Add a message
        let message = Message {
            data: MessageData::User {
                content: vec![UserContent::Text {
                    text: "Hello".to_string(),
                }],
            },
            timestamp: 123_456_789,
            id: "msg1".to_string(),
            parent_message_id: None,
        };
        state.add_message(message);

        assert!(state.validate().is_ok());
        assert_eq!(state.message_count(), 1);
    }

    #[test]
    fn test_tool_call_state_tracking() {
        let mut state = SessionState::default();

        let tool_call = ToolCall {
            id: "tool1".to_string(),
            name: "read_file".to_string(),
            parameters: serde_json::json!({"path": "/test.txt"}),
        };

        state.add_tool_call(tool_call);
        assert!(state.tool_calls.get("tool1").unwrap().is_pending());

        state
            .update_tool_call_status("tool1", ToolCallStatus::Executing)
            .unwrap();
        let tool_state = state.tool_calls.get("tool1").unwrap();
        assert!(tool_state.started_at.is_some());
        assert!(!tool_state.is_complete());

        state
            .update_tool_call_status("tool1", ToolCallStatus::Completed)
            .unwrap();
        let tool_state = state.tool_calls.get("tool1").unwrap();
        assert!(tool_state.completed_at.is_some());
        assert!(tool_state.is_complete());
    }

    #[test]
    fn test_session_tool_config_default() {
        let config = SessionToolConfig::default();
        assert!(config.backends.is_empty());
    }

    #[test]
    fn test_tool_filter_exclude() {
        let excluded =
            ToolFilter::Exclude(vec![BASH_TOOL_NAME.to_string(), EDIT_TOOL_NAME.to_string()]);

        if let ToolFilter::Exclude(tools) = &excluded {
            assert_eq!(tools.len(), 2);
            assert!(tools.contains(&BASH_TOOL_NAME.to_string()));
            assert!(tools.contains(&EDIT_TOOL_NAME.to_string()));
        } else {
            panic!("Expected ToolFilter::Exclude");
        }
    }

    #[test]
    fn test_session_tool_config_read_only() {
        let config = SessionToolConfig::read_only();
        assert_eq!(config.backends.len(), 0);
        assert!(matches!(config.visibility, ToolVisibility::ReadOnly));
        assert_eq!(
            config.approval_policy.default_behavior,
            UnapprovedBehavior::Prompt
        );
    }

    /// `BackendRegistry` must only contain user-configured backends.
    ///
    /// Static tools (dispatch_agent, web_fetch) live in `ToolRegistry`, and
    /// workspace tools (including their visibility filtering) are handled at
    /// the Workspace level, so neither appears here.
    #[tokio::test]
    async fn test_session_config_build_registry_no_default_backends() {
        // Default tool config: no backends configured.
        let config = local_test_config();

        let (registry, _mcp_servers) = config.build_registry().await.unwrap();
        let schemas = registry.get_tool_schemas().await;

        assert!(
            schemas.is_empty(),
            "BackendRegistry should be empty with default config; got: {:?}",
            schemas.iter().map(|s| &s.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_mcp_status_tracking() {
        // Test that MCP server info is properly tracked in session state
        let mut session_state = SessionState::default();

        // Add some MCP server info
        let mcp_info = McpServerInfo {
            server_name: "test-server".to_string(),
            transport: crate::tools::McpTransport::Stdio {
                command: "python".to_string(),
                args: vec!["-m".to_string(), "test_server".to_string()],
            },
            state: McpConnectionState::Connected {
                tool_names: owned_names(&["tool1", "tool2", "tool3", "tool4", "tool5"]),
            },
            last_updated: Utc::now(),
        };

        session_state
            .mcp_servers
            .insert("test-server".to_string(), mcp_info);

        // Verify it's stored
        assert_eq!(session_state.mcp_servers.len(), 1);
        let stored = session_state.mcp_servers.get("test-server").unwrap();
        assert_eq!(stored.server_name, "test-server");
        assert!(matches!(
            stored.state,
            McpConnectionState::Connected { ref tool_names } if tool_names.len() == 5
        ));

        // Test failed connection
        let failed_info = McpServerInfo {
            server_name: "failed-server".to_string(),
            transport: crate::tools::McpTransport::Tcp {
                host: "localhost".to_string(),
                port: 9999,
            },
            state: McpConnectionState::Failed {
                error: "Connection refused".to_string(),
            },
            last_updated: Utc::now(),
        };

        session_state
            .mcp_servers
            .insert("failed-server".to_string(), failed_info);
        assert_eq!(session_state.mcp_servers.len(), 2);
    }

    /// Every configured MCP server must be reported by `build_registry`, even
    /// when connecting to it fails.
    #[tokio::test]
    async fn test_mcp_server_tracking_in_build_registry() {
        let mut config = SessionConfig::read_only(test_model());

        // TCP to a host under the reserved `.invalid` TLD, which never resolves.
        config.tool_config.backends.push(BackendConfig::Mcp {
            server_name: "bad-server".to_string(),
            transport: crate::tools::McpTransport::Tcp {
                host: "nonexistent.invalid".to_string(),
                port: 12345,
            },
            tool_filter: ToolFilter::All,
        });

        // A well-formed stdio transport, but `echo` does not speak MCP, so this
        // connection is expected to fail as well (exercising the stdio path).
        config.tool_config.backends.push(BackendConfig::Mcp {
            server_name: "good-server".to_string(),
            transport: crate::tools::McpTransport::Stdio {
                command: "echo".to_string(),
                args: vec!["test".to_string()],
            },
            tool_filter: ToolFilter::All,
        });

        let (_registry, mcp_servers) = config.build_registry().await.unwrap();

        // Should have tracked both servers
        assert_eq!(mcp_servers.len(), 2);

        let bad_server = mcp_servers.get("bad-server").unwrap();
        assert_eq!(bad_server.server_name, "bad-server");
        assert!(matches!(
            bad_server.state,
            McpConnectionState::Failed { .. }
        ));

        let good_server = mcp_servers.get("good-server").unwrap();
        assert_eq!(good_server.server_name, "good-server");
        assert!(matches!(
            good_server.state,
            McpConnectionState::Failed { .. }
        ));
    }

    #[test]
    fn test_backend_config_mcp_variant() {
        let mcp_config = BackendConfig::Mcp {
            server_name: "test-mcp".to_string(),
            transport: crate::tools::McpTransport::Stdio {
                command: "python".to_string(),
                args: vec!["-m".to_string(), "test_server".to_string()],
            },
            tool_filter: ToolFilter::All,
        };

        let BackendConfig::Mcp {
            server_name,
            transport,
            ..
        } = mcp_config;

        assert_eq!(server_name, "test-mcp");
        if let crate::tools::McpTransport::Stdio { command, args } = transport {
            assert_eq!(command, "python");
            assert_eq!(args.len(), 2);
        } else {
            panic!("Expected Stdio transport");
        }
    }
}