steer_core/app/
context.rs

1use std::collections::HashMap;
2use std::time::Instant;
3use tokio::task::JoinSet;
4use tokio_util::sync::CancellationToken;
5use uuid;
6
7use crate::app::agent_executor::AgentExecutorError;
8use crate::app::command::AppCommand;
9use crate::app::conversation::Message;
10use once_cell::sync::OnceCell;
11use std::sync::Arc;
12use tokio::sync::mpsc;
13
// Global command sender for tool approval requests.
// Written by `OpContext::init_command_tx` and read by `OpContext::command_tx`.
// `OnceCell` is set-once: after the first successful `set`, the stored sender
// is never replaced.
static COMMAND_TX: OnceCell<Arc<mpsc::Sender<AppCommand>>> = OnceCell::new();
16
/// Value produced by a task stored in [`OpContext::tasks`].
///
/// Each variant carries a `Result`, so a task reports success or failure
/// through its payload rather than by a separate channel.
#[derive(Debug)]
pub enum TaskOutcome {
    /// An agent operation finished, yielding either a [`Message`] or an
    /// [`AgentExecutorError`].
    AgentOperationComplete {
        result: std::result::Result<Message, AgentExecutorError>,
    },
    /// A dispatch-agent call finished, yielding either a `String` or a
    /// `steer_tools::ToolError`.
    // NOTE(review): presumably the `String` is the sub-agent's textual output — confirm.
    DispatchAgentResult {
        result: std::result::Result<String, steer_tools::ToolError>,
    },
    /// A bash command finished.
    BashCommandComplete {
        // NOTE(review): presumably the ID of the operation that ran the command — confirm.
        op_id: uuid::Uuid,
        // The command string associated with this outcome.
        command: String,
        // NOTE(review): presumably captured when the command was started, so the
        // receiver can compute the elapsed time — confirm against the spawner.
        start_time: Instant,
        // Tool result on success, tool error on failure.
        result: std::result::Result<steer_tools::ToolResult, steer_tools::ToolError>,
    },
}
32
/// Holds the state for a single, cancellable user-initiated operation.
///
/// Groups the cancellation token, the spawned tasks and the bookkeeping for
/// in-flight tools so they can be torn down together (see
/// `cancel_and_shutdown`).
pub struct OpContext {
    /// Token cancelled by `cancel_and_shutdown`.
    pub cancel_token: CancellationToken,
    /// Tasks belonging to this operation; each one resolves to a `TaskOutcome`.
    pub tasks: JoinSet<TaskOutcome>,
    /// Active tools, keyed by tool_call_id -> (op_id, start_time, tool_name).
    pub active_tools: HashMap<String, (uuid::Uuid, Instant, String)>,
    /// The main operation ID if this context is for a Started/Finished
    /// operation; `None` when created via `new`.
    pub operation_id: Option<uuid::Uuid>,
}
45
46impl Default for OpContext {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl OpContext {
53    pub fn new() -> Self {
54        Self {
55            cancel_token: CancellationToken::new(),
56            tasks: JoinSet::new(),
57            active_tools: HashMap::new(),
58            operation_id: None,
59            // Removed: agent_event_receiver: None,
60        }
61    }
62
63    pub fn new_with_id(op_id: uuid::Uuid) -> Self {
64        Self {
65            cancel_token: CancellationToken::new(),
66            tasks: JoinSet::new(),
67            active_tools: HashMap::new(),
68            operation_id: Some(op_id),
69        }
70    }
71
72    // Removed: start_api_call, complete_api_call
73
74    pub fn add_active_tool(&mut self, id: String, op_id: uuid::Uuid, name: String) {
75        self.active_tools.insert(id, (op_id, Instant::now(), name));
76    }
77
78    pub fn remove_active_tool(&mut self, id: &str) -> Option<(uuid::Uuid, Instant, String)> {
79        self.active_tools.remove(id)
80    }
81
82    pub fn has_activity(&self) -> bool {
83        !self.tasks.is_empty() || !self.active_tools.is_empty()
84    }
85
86    pub async fn cancel_and_shutdown(&mut self) {
87        self.cancel_token.cancel();
88        self.tasks.shutdown().await;
89        self.active_tools.clear();
90        // Removed: self.agent_event_receiver = None;
91    }
92
93    /// Initialize the global command sender, should be called once during app setup
94    pub fn init_command_tx(tx: mpsc::Sender<AppCommand>) {
95        let _ = COMMAND_TX.set(Arc::new(tx));
96    }
97
98    /// Get the global command sender for tool approval requests
99    pub fn command_tx() -> Arc<mpsc::Sender<AppCommand>> {
100        COMMAND_TX
101            .get()
102            .expect("Command sender not initialized. Call OpContext::init_command_tx first")
103            .clone()
104    }
105}