Skip to main content

stynx_code_engine/application/query_engine.rs

1use std::sync::{Arc, atomic::AtomicU8};
2
3use stynx_code_config::HooksConfig;
4use stynx_code_errors::{AppError, AppResult};
5use stynx_code_types::{
6    ContentBlock, Conversation, Message, PermissionChecker, PermissionLevel, PermissionMode,
7    Provider, Role, StopReason,
8};
9use stynx_code_tools::ToolRegistry;
10
11use crate::application::undo::UndoStack;
12use crate::domain::EngineEvent;
13use super::compactor::compact;
14use super::hook_runner::{run_post_tool_use, run_pre_tool_use, run_stop_hooks};
15use super::stream_reader::read_stream;
16use super::tool_executor::{execute_tool, is_overloaded};
17
18pub struct QueryEngine {
19    provider: Arc<dyn Provider>,
20    registry: Arc<ToolRegistry>,
21    permission: Arc<dyn PermissionChecker>,
22    hooks: HooksConfig,
23    max_turns: usize,
24    context_limit: u64,
25    mode: Arc<AtomicU8>,
26    undo_stack: Arc<UndoStack>,
27}
28
29impl QueryEngine {
30    pub fn new(
31        provider: Arc<dyn Provider>,
32        registry: Arc<ToolRegistry>,
33        permission: Arc<dyn PermissionChecker>,
34        mode: Arc<AtomicU8>,
35        hooks: HooksConfig,
36    ) -> Self {
37        Self {
38            provider, registry, permission, hooks, mode,
39            max_turns: 200, context_limit: 80_000,
40            undo_stack: Arc::new(UndoStack::default()),
41        }
42    }
43
44    pub fn with_max_turns(mut self, n: usize) -> Self {
45        self.max_turns = n;
46        self
47    }
48
49    pub fn mode_flag(&self) -> Arc<AtomicU8> { self.mode.clone() }
50    pub fn undo_stack(&self) -> Arc<UndoStack> { self.undo_stack.clone() }
51
52    pub async fn run<F>(
53        &self,
54        mut conversation: Conversation,
55        mut on_event: F,
56    ) -> AppResult<Conversation>
57    where
58        F: FnMut(EngineEvent) + Send,
59    {
60        let mut last_input_tokens: u64 = 0;
61
62        for turn in 0..self.max_turns {
63            tracing::info!(turn, "starting provider turn");
64            let is_plan = PermissionMode::load(&self.mode) == PermissionMode::Plan;
65            let tools = if is_plan {
66                self.registry.tool_definitions_filtered(|t| {
67                    t.permission_level() == PermissionLevel::ReadOnly || t.name() == "exit_plan_mode"
68                })
69            } else {
70                self.registry.tool_definitions_filtered(|t| {
71                    t.name() != "enter_plan_mode" && t.name() != "exit_plan_mode"
72                })
73            };
74
75            if last_input_tokens > 0
76                && last_input_tokens > self.context_limit * 60 / 100
77                && conversation.messages.len() > 2
78            {
79                let original_turns = conversation.messages.len();
80                conversation = compact(&self.provider, conversation, &mut on_event).await?;
81                on_event(EngineEvent::Compacted { original_turns });
82            }
83
84            let mut attempts = 0u32;
85            let (assistant_blocks, stop_reason) = loop {
86                let mut stream = match self.provider.stream(&conversation, &tools).await {
87                    Ok(s) => s,
88                    Err(e) if attempts < 3 && is_overloaded(&e.to_string()) => {
89                        attempts += 1;
90                        let delay = std::time::Duration::from_secs(2u64.pow(attempts));
91                        tokio::time::sleep(delay).await;
92                        continue;
93                    }
94                    Err(e) => return Err(e),
95                };
96
97                let (blocks, stop_reason, input_tokens, stream_error) =
98                    read_stream(&mut stream, &mut on_event).await;
99
100                if input_tokens > 0 {
101                    last_input_tokens = input_tokens;
102                }
103
104                if let Some(err_msg) = stream_error {
105                    if attempts < 3 && is_overloaded(&err_msg) {
106                        attempts += 1;
107                        let delay = std::time::Duration::from_secs(2u64.pow(attempts));
108                        tokio::time::sleep(delay).await;
109                        continue;
110                    }
111                    return Err(AppError::Provider(err_msg));
112                }
113
114                break (blocks, stop_reason);
115            };
116
117            conversation.push(Message::assistant(assistant_blocks.clone()));
118
119            if !matches!(stop_reason, StopReason::ToolUse) {
120                on_event(EngineEvent::TurnComplete);
121                let stop_out = run_stop_hooks(&self.hooks).await;
122                if !stop_out.is_empty() {
123                    on_event(EngineEvent::HookOutput { source: "stop".into(), output: stop_out });
124                }
125                return Ok(conversation);
126            }
127
128            let tool_uses: Vec<(String, String, serde_json::Value)> = assistant_blocks
129                .iter()
130                .filter_map(|b| match b {
131                    ContentBlock::ToolUse { id, name, input } => Some((id.clone(), name.clone(), input.clone())),
132                    _ => None,
133                })
134                .collect();
135
136            let mut pre_outs = Vec::new();
137            for (_, name, input) in &tool_uses {
138                let pre = run_pre_tool_use(&self.hooks, name, &input.to_string()).await;
139                pre_outs.push(pre);
140            }
141
142            let registry = self.registry.clone();
143            let permission = self.permission.clone();
144            let undo = self.undo_stack.clone();
145
146            let mut exec_results: Vec<Result<Result<String, AppError>, tokio::task::JoinError>> =
147                Vec::with_capacity(tool_uses.len());
148
149            let mut parallel_handles: Vec<(usize, tokio::task::JoinHandle<Result<String, AppError>>)> = Vec::new();
150            for (i, ((_, name, input), pre)) in tool_uses.iter().zip(pre_outs.iter()).enumerate() {
151                if pre.blocked {
152                    continue;
153                }
154                let tool = registry.get(name);
155                let is_safe = tool.is_some_and(|t| t.is_concurrent_safe(input));
156                if is_safe {
157                    let reg = registry.clone();
158                    let perm = permission.clone();
159                    let ud = undo.clone();
160                    let n = name.clone();
161                    let inp = input.clone();
162                    parallel_handles.push((i, tokio::spawn(async move {
163                        execute_tool(&reg, &perm, &n, &inp, &ud).await
164                    })));
165                }
166            }
167
168            let parallel_results: Vec<_> = futures::future::join_all(
169                parallel_handles.into_iter().map(|(i, h)| async move { (i, h.await) })
170            ).await;
171            let mut result_map: std::collections::HashMap<usize, Result<Result<String, AppError>, tokio::task::JoinError>> =
172                parallel_results.into_iter().collect();
173
174            for (i, ((_, name, input), pre)) in tool_uses.iter().zip(pre_outs.iter()).enumerate() {
175                if pre.blocked {
176                    exec_results.push(Ok(Ok(String::new())));
177                } else if let Some(result) = result_map.remove(&i) {
178                    exec_results.push(result);
179                } else {
180
181                    let result = execute_tool(&registry, &permission, name, input, &undo).await;
182                    exec_results.push(Ok(result));
183                }
184            }
185
186            let mut tool_results = Vec::new();
187            let mut exit_plan_called = false;
188            let mut pre_iter = pre_outs.into_iter();
189            let mut exec_iter = exec_results.into_iter();
190            for (id, name, input) in &tool_uses {
191                let pre = pre_iter.next().unwrap();
192                let exec_result = exec_iter.next().unwrap();
193                let input_json = input.to_string();
194                if !pre.output.is_empty() {
195                    on_event(EngineEvent::HookOutput { source: "pre-tool".into(), output: pre.output });
196                }
197                if pre.blocked {
198                    on_event(EngineEvent::ToolResult { name: name.clone(), output: pre.reason.clone(), is_error: true });
199                    tool_results.push(ContentBlock::ToolResult { tool_use_id: id.clone(), content: pre.reason, is_error: Some(true) });
200                } else {
201                    let result: AppResult<String> = match exec_result {
202                        Ok(r) => r,
203                        Err(e) => Err(AppError::Tool(e.to_string())),
204                    };
205                    match result {
206                        Ok(output) => {
207                            let post = run_post_tool_use(&self.hooks, name, &input_json, &output).await;
208                            if !post.is_empty() {
209                                on_event(EngineEvent::HookOutput { source: "post-tool".into(), output: post });
210                            }
211                            on_event(EngineEvent::ToolResult { name: name.clone(), output: output.clone(), is_error: false });
212                            tool_results.push(ContentBlock::ToolResult { tool_use_id: id.clone(), content: output, is_error: None });
213                            if name == "exit_plan_mode" {
214                                PermissionMode::Normal.store(&self.mode);
215                                on_event(EngineEvent::ModeChanged { mode: PermissionMode::Normal });
216                                exit_plan_called = true;
217                            }
218                        }
219                        Err(ref e) if e.is_interrupted() => {
220                            return Err(AppError::Interrupted);
221                        }
222                        Err(e) => {
223                            let msg = e.to_string();
224                            on_event(EngineEvent::ToolResult { name: name.clone(), output: msg.clone(), is_error: true });
225                            tool_results.push(ContentBlock::ToolResult { tool_use_id: id.clone(), content: msg, is_error: Some(true) });
226                        }
227                    }
228                }
229            }
230
231            conversation.push(Message {
232                role: Role::User,
233                content: tool_results,
234            });
235
236            if exit_plan_called {
237                on_event(EngineEvent::TurnComplete);
238                return Ok(conversation);
239            }
240        }
241
242        Err(AppError::MaxTurnsExceeded(self.max_turns))
243    }
244}