// zagens_core/engine/handle.rs
use std::sync::{Arc, Mutex as StdMutex, MutexGuard, PoisonError};
use std::time::Duration;

use anyhow::Result;
use async_trait::async_trait;
use tokio::sync::{RwLock, mpsc, oneshot};
use tokio_util::sync::CancellationToken;

use crate::engine::approval::{ApprovalDecision, UserInputDecision};
use crate::engine::context_snapshot::ThreadContextSnapshot;
use crate::engine::op::Op;
use crate::engine::start_turn::StartTurnParams;
use crate::engine::turn_port::TurnEnginePort;
use crate::events::Event;
use crate::turn::TurnLoopMode;

31#[derive(Clone)]
33pub struct EngineHandle<P, R> {
34 pub tx_op: mpsc::Sender<Op>,
36 pub rx_event: Arc<RwLock<mpsc::Receiver<Event>>>,
38 cancel_token: Arc<StdMutex<CancellationToken>>,
40 tx_approval: mpsc::Sender<ApprovalDecision<P>>,
42 tx_user_input: mpsc::Sender<UserInputDecision<R>>,
44 tx_steer: mpsc::Sender<String>,
46}
47
48impl<P, R> EngineHandle<P, R>
49where
50 P: Send + Sync + 'static,
51 R: Send + Sync + 'static,
52{
53 #[must_use]
56 pub fn new(
57 tx_op: mpsc::Sender<Op>,
58 rx_event: Arc<RwLock<mpsc::Receiver<Event>>>,
59 cancel_token: Arc<StdMutex<CancellationToken>>,
60 tx_approval: mpsc::Sender<ApprovalDecision<P>>,
61 tx_user_input: mpsc::Sender<UserInputDecision<R>>,
62 tx_steer: mpsc::Sender<String>,
63 ) -> Self {
64 Self {
65 tx_op,
66 rx_event,
67 cancel_token,
68 tx_approval,
69 tx_user_input,
70 tx_steer,
71 }
72 }
73
74 pub async fn send(&self, op: Op) -> Result<()> {
76 self.tx_op.send(op).await?;
77 Ok(())
78 }
79
80 pub fn cancel(&self) {
82 match self.cancel_token.lock() {
83 Ok(token) => token.cancel(),
84 Err(poisoned) => poisoned.into_inner().cancel(),
85 }
86 }
87
88 #[must_use]
90 #[allow(dead_code)]
91 pub fn is_cancelled(&self) -> bool {
92 match self.cancel_token.lock() {
93 Ok(token) => token.is_cancelled(),
94 Err(poisoned) => poisoned.into_inner().is_cancelled(),
95 }
96 }
97
98 pub async fn approve_tool_call(&self, id: impl Into<String>) -> Result<()> {
100 self.approve_tool_call_with_options(id, None, false).await
101 }
102
103 pub async fn approve_tool_call_with_options(
105 &self,
106 id: impl Into<String>,
107 cache_key: Option<String>,
108 remember_for_session: bool,
109 ) -> Result<()> {
110 self.tx_approval
111 .send(ApprovalDecision::Approved {
112 id: id.into(),
113 cache_key,
114 remember_for_session,
115 })
116 .await?;
117 Ok(())
118 }
119
120 pub async fn deny_tool_call(&self, id: impl Into<String>) -> Result<()> {
122 self.tx_approval
123 .send(ApprovalDecision::Denied { id: id.into() })
124 .await?;
125 Ok(())
126 }
127
128 pub async fn retry_tool_with_policy(&self, id: impl Into<String>, policy: P) -> Result<()> {
130 self.tx_approval
131 .send(ApprovalDecision::RetryWithPolicy {
132 id: id.into(),
133 policy,
134 })
135 .await?;
136 Ok(())
137 }
138
139 pub async fn submit_user_input(&self, id: impl Into<String>, response: R) -> Result<()> {
141 self.tx_user_input
142 .send(UserInputDecision::Submitted {
143 id: id.into(),
144 response,
145 })
146 .await?;
147 Ok(())
148 }
149
150 pub async fn cancel_user_input(&self, id: impl Into<String>) -> Result<()> {
152 self.tx_user_input
153 .send(UserInputDecision::Cancelled { id: id.into() })
154 .await?;
155 Ok(())
156 }
157
158 pub async fn steer(&self, content: impl Into<String>) -> Result<()> {
160 self.tx_steer.send(content.into()).await?;
161 Ok(())
162 }
163
164 pub async fn query_context_snapshot(&self) -> Result<ThreadContextSnapshot> {
166 let (tx, rx) = oneshot::channel();
167 self.send(Op::QueryContext { reply: tx }).await?;
168 tokio::time::timeout(Duration::from_secs(5), rx)
169 .await
170 .map_err(|_| anyhow::anyhow!("context query timed out"))?
171 .map_err(|_| anyhow::anyhow!("engine dropped context query"))
172 }
173
174 pub async fn query_harness_task_graph(&self) -> Result<serde_json::Value> {
176 let (tx, rx) = oneshot::channel();
177 self.send(Op::QueryHarnessTaskGraph { reply: tx }).await?;
178 tokio::time::timeout(Duration::from_secs(5), rx)
179 .await
180 .map_err(|_| anyhow::anyhow!("harness task-graph query timed out"))?
181 .map_err(|_| anyhow::anyhow!("engine dropped harness task-graph query"))
182 }
183
184 pub async fn query_harness_cycles(&self) -> Result<serde_json::Value> {
186 let (tx, rx) = oneshot::channel();
187 self.send(Op::QueryHarnessCycles { reply: tx }).await?;
188 tokio::time::timeout(Duration::from_secs(5), rx)
189 .await
190 .map_err(|_| anyhow::anyhow!("harness cycles query timed out"))?
191 .map_err(|_| anyhow::anyhow!("engine dropped harness cycles query"))
192 }
193
194 pub async fn truncate_before_last_user_message(&self) -> Result<bool> {
196 let (tx, rx) = oneshot::channel();
197 self.send(Op::TruncateBeforeLastUserMessage { reply: tx })
198 .await?;
199 rx.await
200 .map_err(|_| anyhow::anyhow!("engine dropped truncate-before-last-user reply"))
201 }
202}
203
204#[async_trait]
205impl<P, R> TurnEnginePort for EngineHandle<P, R>
206where
207 P: Send + Sync + 'static,
208 R: Send + Sync + 'static,
209{
210 async fn start_turn(&self, params: StartTurnParams) -> Result<()> {
211 params.validate().map_err(anyhow::Error::msg)?;
212 self.send(Op::SendMessage {
213 content: params.prompt,
214 mode: TurnLoopMode::from_setting(¶ms.mode),
215 model: params.model,
216 goal_objective: None,
217 reasoning_effort: params.reasoning_effort,
218 reasoning_effort_auto: params.reasoning_effort_auto,
219 auto_model: params.auto_model,
220 allow_shell: params.allow_shell,
221 trust_mode: params.trust_mode,
222 auto_approve: params.auto_approve,
223 approval_mode: params.approval_mode,
224 temperature: params.temperature,
225 top_p: params.top_p,
226 max_output_tokens: params.max_output_tokens,
227 })
228 .await
229 }
230
231 fn cancel_active_turn(&self) {
232 self.cancel();
233 }
234}