1use std::collections::HashMap;
18use std::path::PathBuf;
19use std::time::Instant;
20
21use rig::client::{CompletionClient, ProviderClient};
22use rig::completion::Message as RigMessage;
23use rig::completion::Prompt;
24use rig::completion::message::{AssistantContent, UserContent};
25use rig::one_or_many::OneOrMany;
26use rig::providers::{anthropic, openai};
27use syncable_ag_ui_core::{Role, RunId, ThreadId};
28use tokio::sync::mpsc;
29use tracing::{debug, error, info, warn};
30
31use super::{AgentMessage, EventBridge};
32use crate::agent::prompts;
33use crate::agent::tools::{
34 AnalyzeTool,
36 DclintTool,
37 HadolintTool,
39 HelmlintTool,
40 K8sCostsTool,
41 K8sDriftTool,
42 K8sOptimizeTool,
44 KubelintTool,
45 ListDirectoryTool,
46 ListOutputsTool,
47 ReadFileTool,
48 RetrieveOutputTool,
49 SecurityScanTool,
50 ShellTool,
51 TerraformFmtTool,
53 TerraformInstallTool,
54 TerraformValidateTool,
55 VulnerabilitiesTool,
56 WebFetchTool,
58 WriteFileTool,
60 WriteFilesTool,
61};
62
63use rig::agent::CancelSignal;
64use rig::completion::{CompletionModel, CompletionResponse, Message as RigPromptMessage};
65use serde::{Deserialize, Serialize};
66use std::sync::Arc;
67use syncable_ag_ui_core::ToolCallId;
68use tokio::sync::Mutex;
69
70#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
72#[serde(rename_all = "lowercase")]
73pub enum StepStatus {
74 Pending,
75 Completed,
76}
77
/// One progress step displayed in the frontend's step list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentStep {
    /// Human-readable description, e.g. "Reading file: src/main.rs".
    pub description: String,
    /// Current lifecycle status of the step.
    pub status: StepStatus,
}
84
/// Captured outcome of a single tool invocation, kept for UI display.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
    /// Name of the tool that was invoked (e.g. "read_file").
    pub tool_name: String,
    /// JSON arguments the tool was called with.
    pub args: serde_json::Value,
    /// JSON result returned by the tool, or `{"raw": …}` when the tool
    /// produced output that was not valid JSON.
    pub result: serde_json::Value,
    /// Heuristic error flag derived from the result payload; defaults to
    /// `false` when absent in deserialized data.
    #[serde(default)]
    pub is_error: bool,
}
98
/// Aggregate UI state broadcast to the frontend as state snapshots.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AgentUiState {
    /// Ordered list of progress steps (oldest first).
    pub steps: Vec<AgentStep>,
    /// Name of the tool currently running, if any; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub current_tool: Option<String>,
    /// Results of completed tool calls; defaults to empty on deserialization.
    #[serde(default)]
    pub tool_results: Vec<ToolResult>,
    /// Optional free-form metadata; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
117
118impl AgentUiState {
119 pub fn new() -> Self {
121 Self::default()
122 }
123
124 pub fn add_step(&mut self, description: impl Into<String>) {
126 self.steps.push(AgentStep {
127 description: description.into(),
128 status: StepStatus::Pending,
129 });
130 }
131
132 pub fn complete_step(&mut self, index: usize) {
134 if let Some(step) = self.steps.get_mut(index) {
135 step.status = StepStatus::Completed;
136 }
137 }
138
139 pub fn complete_current_step(&mut self) {
141 for step in &mut self.steps {
142 if step.status == StepStatus::Pending {
143 step.status = StepStatus::Completed;
144 break;
145 }
146 }
147 }
148
149 pub fn set_current_tool(&mut self, tool: Option<String>) {
151 self.current_tool = tool;
152 }
153
154 pub fn add_tool_result(
156 &mut self,
157 tool_name: String,
158 args: serde_json::Value,
159 result: serde_json::Value,
160 is_error: bool,
161 ) {
162 self.tool_results.push(ToolResult {
163 tool_name,
164 args,
165 result,
166 is_error,
167 });
168 }
169
170 pub fn to_json(&self) -> serde_json::Value {
172 serde_json::to_value(self).unwrap_or_default()
173 }
174}
175
/// Bookkeeping for an in-flight tool call, captured in `on_tool_call` and
/// consumed in `on_tool_result`.
#[derive(Clone)]
struct ToolCallInfo {
    /// AG-UI identifier returned by `EventBridge::start_tool_call`.
    id: ToolCallId,
    /// Tool name as reported by rig.
    name: String,
    /// Parsed JSON arguments of the call.
    args: serde_json::Value,
}
183
/// rig prompt hook that mirrors tool activity into AG-UI events and a shared
/// [`AgentUiState`].
#[derive(Clone)]
pub struct AgUiHook {
    /// Bridge used to emit AG-UI events to the frontend.
    event_bridge: EventBridge,
    /// The tool call currently in flight: set in `on_tool_call`, taken in
    /// `on_tool_result`. `Arc`-shared so all clones of the hook see it.
    current_tool_call: Arc<Mutex<Option<ToolCallInfo>>>,
    /// UI state that is serialized and broadcast on every change.
    state: Arc<Mutex<AgentUiState>>,
}
197
198impl AgUiHook {
199 pub fn new(event_bridge: EventBridge) -> Self {
201 Self {
202 event_bridge,
203 current_tool_call: Arc::new(Mutex::new(None)),
204 state: Arc::new(Mutex::new(AgentUiState::new())),
205 }
206 }
207
208 async fn emit_state(&self) {
210 let state = self.state.lock().await;
211 self.event_bridge.emit_state_snapshot(state.to_json()).await;
212 }
213
214 pub async fn add_step(&self, description: impl Into<String>) {
216 {
217 let mut state = self.state.lock().await;
218 state.add_step(description);
219 }
220 self.emit_state().await;
221 }
222
223 pub async fn complete_current_step(&self) {
225 {
226 let mut state = self.state.lock().await;
227 state.complete_current_step();
228 }
229 self.emit_state().await;
230 }
231}
232
233impl<M> rig::agent::PromptHook<M> for AgUiHook
234where
235 M: CompletionModel,
236{
237 fn on_tool_call(
238 &self,
239 tool_name: &str,
240 _tool_call_id: Option<String>,
241 args: &str,
242 _cancel: CancelSignal,
243 ) -> impl std::future::Future<Output = ()> + Send {
244 let bridge = self.event_bridge.clone();
245 let name = tool_name.to_string();
246 let args_str = args.to_string();
247 let current_call = Arc::clone(&self.current_tool_call);
248 let state = Arc::clone(&self.state);
249
250 async move {
251 debug!(tool = %name, "AgUiHook: on_tool_call triggered");
252
253 let args_json: serde_json::Value = serde_json::from_str(&args_str)
255 .unwrap_or_else(|_| serde_json::json!({"raw": args_str}));
256
257 {
259 let mut s = state.lock().await;
260 let description = match name.as_str() {
262 "analyze_project" => "Analyzing project structure...".to_string(),
264 "read_file" => format!(
265 "Reading file: {}",
266 args_json
267 .get("path")
268 .and_then(|v| v.as_str())
269 .unwrap_or("...")
270 ),
271 "list_directory" => format!(
272 "Listing directory: {}",
273 args_json
274 .get("path")
275 .and_then(|v| v.as_str())
276 .unwrap_or("...")
277 ),
278 "security_scan" => "Running security scan...".to_string(),
280 "check_vulnerabilities" => "Checking for vulnerabilities...".to_string(),
281 "hadolint" => "Linting Dockerfiles...".to_string(),
283 "dclint" => "Linting docker-compose files...".to_string(),
284 "kubelint" => "Linting Kubernetes manifests...".to_string(),
285 "helmlint" => "Linting Helm charts...".to_string(),
286 "k8s_optimize" => "Analyzing Kubernetes resource optimization...".to_string(),
288 "k8s_costs" => "Calculating Kubernetes costs...".to_string(),
289 "k8s_drift" => "Detecting configuration drift...".to_string(),
290 "terraform_fmt" => "Formatting Terraform files...".to_string(),
292 "terraform_validate" => "Validating Terraform configuration...".to_string(),
293 "terraform_install" => "Installing Terraform...".to_string(),
294 "web_fetch" => format!(
296 "Fetching: {}",
297 args_json
298 .get("url")
299 .and_then(|v| v.as_str())
300 .unwrap_or("...")
301 ),
302 "retrieve_output" => "Retrieving stored output...".to_string(),
304 "list_outputs" => "Listing available outputs...".to_string(),
305 "write_file" => format!(
307 "Writing file: {}",
308 args_json
309 .get("path")
310 .and_then(|v| v.as_str())
311 .unwrap_or("...")
312 ),
313 "write_files" => "Writing multiple files...".to_string(),
314 "shell" => format!(
316 "Running command: {}",
317 args_json
318 .get("command")
319 .and_then(|v| v.as_str())
320 .map(|s| if s.len() > 50 {
321 format!("{}...", &s[..50])
322 } else {
323 s.to_string()
324 })
325 .unwrap_or("...".to_string())
326 ),
327 _ => format!("Running {}...", name.replace('_', " ")),
328 };
329 s.add_step(description);
330 s.set_current_tool(Some(name.clone()));
331 }
332
333 let s = state.lock().await;
335 bridge.emit_state_snapshot(s.to_json()).await;
336 drop(s);
337
338 let tool_call_id = bridge.start_tool_call(&name, &args_json).await;
340
341 let mut call_guard = current_call.lock().await;
343 *call_guard = Some(ToolCallInfo {
344 id: tool_call_id,
345 name: name.clone(),
346 args: args_json.clone(),
347 });
348 }
349 }
350
351 fn on_tool_result(
352 &self,
353 _tool_name: &str,
354 _tool_call_id: Option<String>,
355 _args: &str,
356 result: &str,
357 _cancel: CancelSignal,
358 ) -> impl std::future::Future<Output = ()> + Send {
359 let bridge = self.event_bridge.clone();
360 let current_call = Arc::clone(&self.current_tool_call);
361 let state = Arc::clone(&self.state);
362 let result_str = result.to_string();
363
364 async move {
365 let tool_call_info = {
367 let mut call_guard = current_call.lock().await;
368 call_guard.take()
369 };
370
371 let result_json: serde_json::Value = serde_json::from_str(&result_str)
373 .unwrap_or_else(|_| serde_json::json!({"raw": result_str}));
374
375 let is_error = result_json.get("error").is_some()
378 || result_json.get("success").and_then(|v| v.as_bool()) == Some(false)
379 || result_json.get("status").and_then(|v| v.as_str()) == Some("error")
380 || result_json.get("status").and_then(|v| v.as_str()) == Some("ERROR");
381
382 {
384 let mut s = state.lock().await;
385 s.complete_current_step();
386 s.set_current_tool(None);
387
388 if let Some(ref info) = tool_call_info {
390 debug!(
391 tool = %info.name,
392 result_size = result_str.len(),
393 "AgUiHook: capturing tool result for UI"
394 );
395 s.add_tool_result(
396 info.name.clone(),
397 info.args.clone(),
398 result_json.clone(),
399 is_error,
400 );
401 }
402 }
403
404 let s = state.lock().await;
406 bridge.emit_state_snapshot(s.to_json()).await;
407 drop(s);
408
409 if let Some(info) = tool_call_info {
411 bridge.end_tool_call(&info.id).await;
412 }
413 }
414 }
415
416 fn on_completion_response(
417 &self,
418 _prompt: &RigPromptMessage,
419 _response: &CompletionResponse<M::Response>,
420 _cancel: CancelSignal,
421 ) -> impl std::future::Future<Output = ()> + Send {
422 async {}
424 }
425}
426
/// Errors produced while routing a message to an LLM provider.
#[derive(Debug, thiserror::Error)]
pub enum ProcessorError {
    /// The configured provider string is not one of the supported backends.
    #[error("Unsupported provider: {0}")]
    UnsupportedProvider(String),
    /// The underlying completion call returned an error.
    #[error("LLM completion failed: {0}")]
    CompletionFailed(String),
    /// The environment variable holding the provider's API key is unset.
    #[error("Missing API key for provider: {0}")]
    MissingApiKey(String),
}
437
/// Configuration for [`AgentProcessor`]: which provider/model to call and how.
#[derive(Debug, Clone)]
pub struct ProcessorConfig {
    /// Provider name, matched case-insensitively: "openai", "anthropic",
    /// or "bedrock"/"aws"/"aws-bedrock".
    pub provider: String,
    /// Model identifier passed to the provider client.
    pub model: String,
    /// Maximum number of multi-turn tool iterations per prompt.
    pub max_turns: usize,
    /// Explicit system prompt; when `None`, one is derived from the query.
    pub system_prompt: Option<String>,
    /// Root directory the file/analysis tools operate on.
    pub project_path: PathBuf,
}
452
453impl Default for ProcessorConfig {
454 fn default() -> Self {
455 Self {
456 provider: "openai".to_string(),
457 model: "gpt-4o".to_string(),
458 max_turns: 50,
459 system_prompt: None,
460 project_path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
461 }
462 }
463}
464
465impl ProcessorConfig {
466 pub fn new() -> Self {
468 Self::default()
469 }
470
471 pub fn with_provider(mut self, provider: impl Into<String>) -> Self {
473 self.provider = provider.into();
474 self
475 }
476
477 pub fn with_model(mut self, model: impl Into<String>) -> Self {
479 self.model = model.into();
480 self
481 }
482
483 pub fn with_max_turns(mut self, max_turns: usize) -> Self {
485 self.max_turns = max_turns;
486 self
487 }
488
489 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
491 self.system_prompt = Some(prompt.into());
492 self
493 }
494
495 pub fn with_project_path(mut self, path: impl Into<PathBuf>) -> Self {
497 self.project_path = path.into();
498 self
499 }
500
501 pub fn effective_system_prompt(&self, query: Option<&str>) -> String {
505 if let Some(ref prompt) = self.system_prompt {
506 return prompt.clone();
507 }
508 if let Some(q) = query {
511 if prompts::is_code_development_query(q) {
512 return prompts::get_code_development_prompt(&self.project_path);
513 }
514 if prompts::is_generation_query(q) {
515 return prompts::get_devops_prompt(&self.project_path, Some(q));
516 }
517 }
518 prompts::get_analysis_prompt(&self.project_path)
519 }
520}
521
/// Per-thread conversation state: message history plus bookkeeping.
#[derive(Debug)]
pub struct ThreadSession {
    /// Thread this session belongs to.
    pub thread_id: ThreadId,
    /// Full rig message history (user and assistant), oldest first.
    pub history: Vec<RigMessage>,
    /// When the session was created (monotonic clock).
    pub created_at: Instant,
    /// Number of completed assistant turns.
    pub turn_count: usize,
}
534
535impl ThreadSession {
536 pub fn new(thread_id: ThreadId) -> Self {
538 Self {
539 thread_id,
540 history: Vec::new(),
541 created_at: Instant::now(),
542 turn_count: 0,
543 }
544 }
545
546 pub fn add_user_message(&mut self, content: &str) {
548 self.history.push(RigMessage::User {
549 content: OneOrMany::one(UserContent::text(content)),
550 });
551 }
552
553 pub fn add_assistant_message(&mut self, content: &str) {
555 self.history.push(RigMessage::Assistant {
556 id: None,
557 content: OneOrMany::one(AssistantContent::text(content)),
558 });
559 self.turn_count += 1;
560 }
561
562 pub fn inject_context(&mut self, context: &str) {
565 let context_msg = RigMessage::User {
568 content: OneOrMany::one(UserContent::text(format!("[Context]: {}", context))),
569 };
570 self.history.insert(0, context_msg);
571 }
572}
573
/// Core message loop: receives frontend messages, calls the configured LLM
/// with the DevOps tool set, and streams results back through the bridge.
pub struct AgentProcessor {
    /// Incoming messages from the frontend transport.
    message_rx: mpsc::Receiver<AgentMessage>,
    /// Outgoing AG-UI events to the frontend.
    event_bridge: EventBridge,
    /// One session (conversation history) per thread id.
    sessions: HashMap<ThreadId, ThreadSession>,
    /// Provider/model configuration; may be overridden per message via
    /// forwarded props.
    config: ProcessorConfig,
}
588
impl AgentProcessor {
    /// Creates a processor wired to `message_rx` / `event_bridge` with an
    /// explicit configuration.
    pub fn new(
        message_rx: mpsc::Receiver<AgentMessage>,
        event_bridge: EventBridge,
        config: ProcessorConfig,
    ) -> Self {
        Self {
            message_rx,
            event_bridge,
            sessions: HashMap::new(),
            config,
        }
    }

    /// Convenience constructor using `ProcessorConfig::default()`.
    pub fn with_defaults(
        message_rx: mpsc::Receiver<AgentMessage>,
        event_bridge: EventBridge,
    ) -> Self {
        Self::new(message_rx, event_bridge, ProcessorConfig::default())
    }

    /// Returns the session for `thread_id`, creating an empty one on first use.
    fn get_or_create_session(&mut self, thread_id: &ThreadId) -> &mut ThreadSession {
        self.sessions
            .entry(thread_id.clone())
            .or_insert_with(|| ThreadSession::new(thread_id.clone()))
    }

    /// Number of active thread sessions.
    pub fn session_count(&self) -> usize {
        self.sessions.len()
    }

    /// Read access to the current configuration.
    pub fn config(&self) -> &ProcessorConfig {
        &self.config
    }

    /// Extracts the text of the most recent user-role message, if any.
    fn extract_user_input(
        &self,
        messages: &[syncable_ag_ui_core::types::Message],
    ) -> Option<String> {
        messages
            .iter()
            .rev()
            .find(|m| m.role() == Role::User)
            .and_then(|m| m.content().map(|s| s.to_string()))
    }

    /// Main loop: consumes messages until the channel closes. Each message
    /// either triggers LLM processing or — when it contains no user text —
    /// emits a start/error event pair so the frontend's run lifecycle stays
    /// consistent.
    pub async fn run(&mut self) {
        info!("AgentProcessor starting message processing loop");

        while let Some(msg) = self.message_rx.recv().await {
            let input = msg.input;
            let thread_id = input.thread_id.clone();
            let run_id = input.run_id.clone();

            debug!(
                thread_id = %thread_id,
                run_id = %run_id,
                message_count = input.messages.len(),
                "Received message from frontend"
            );

            // forwardedProps may override provider/model/API keys per message.
            self.apply_forwarded_props(&input.forwarded_props);

            match self.extract_user_input(&input.messages) {
                Some(user_input) => {
                    self.process_message(thread_id, run_id, user_input).await;
                }
                None => {
                    debug!(
                        thread_id = %thread_id,
                        "No user message found in input, skipping"
                    );
                    self.event_bridge.start_run().await;
                    self.event_bridge
                        .finish_run_with_error("No user message found in input")
                        .await;
                }
            }
        }

        info!("AgentProcessor message channel closed, shutting down");
    }

    /// Applies per-request overrides sent by the frontend in forwardedProps:
    /// provider, model, API key (exported as the matching provider env var),
    /// and AWS region. Empty strings are ignored.
    fn apply_forwarded_props(&mut self, forwarded_props: &serde_json::Value) {
        if let Some(obj) = forwarded_props.as_object() {
            if let Some(provider) = obj.get("provider").and_then(|v| v.as_str()) {
                if !provider.is_empty() {
                    debug!(provider = %provider, "Applying provider from forwardedProps");
                    self.config.provider = provider.to_string();
                }
            }

            if let Some(model) = obj.get("model").and_then(|v| v.as_str()) {
                if !model.is_empty() {
                    debug!(model = %model, "Applying model from forwardedProps");
                    self.config.model = model.to_string();
                }
            }

            if let Some(api_key) = obj.get("apiKey").and_then(|v| v.as_str()) {
                if !api_key.is_empty() {
                    // The key is routed to the env var matching the currently
                    // configured provider (which may itself have just been
                    // overridden above).
                    let provider = self.config.provider.to_lowercase();
                    match provider.as_str() {
                        "openai" => {
                            debug!("Setting OPENAI_API_KEY from forwardedProps");
                            // SAFETY: `set_var` is unsafe because mutating the
                            // process environment is not thread-safe.
                            // NOTE(review): confirm no other threads read env
                            // vars concurrently with message processing.
                            unsafe {
                                std::env::set_var("OPENAI_API_KEY", api_key);
                            }
                        }
                        "anthropic" => {
                            debug!("Setting ANTHROPIC_API_KEY from forwardedProps");
                            // SAFETY: same thread-safety caveat as above.
                            unsafe {
                                std::env::set_var("ANTHROPIC_API_KEY", api_key);
                            }
                        }
                        _ => {}
                    }
                }
            }

            if let Some(region) = obj.get("awsRegion").and_then(|v| v.as_str()) {
                if !region.is_empty() {
                    debug!(region = %region, "Setting AWS_REGION from forwardedProps");
                    // SAFETY: same thread-safety caveat as above.
                    unsafe {
                        std::env::set_var("AWS_REGION", region);
                    }
                }
            }
        }
    }

    /// Processes one user message end-to-end: records it in the session,
    /// emits the run/thinking lifecycle events, calls the LLM, streams the
    /// response text in 50-char chunks, and finishes the run (with an error
    /// event on failure).
    async fn process_message(&mut self, thread_id: ThreadId, _run_id: RunId, user_input: String) {
        info!(
            thread_id = %thread_id,
            input_len = user_input.len(),
            provider = %self.config.provider,
            model = %self.config.model,
            "Processing message through LLM"
        );

        let session = self.get_or_create_session(&thread_id);
        session.add_user_message(&user_input);

        self.event_bridge.start_run().await;

        self.event_bridge.start_thinking(Some("Thinking")).await;

        let response = self.call_llm(&thread_id, &user_input).await;

        self.event_bridge.end_thinking().await;

        match response {
            Ok(response_text) => {
                self.event_bridge.start_message().await;

                // Stream the full response in 50-character chunks.
                for chunk in response_text.chars().collect::<Vec<_>>().chunks(50) {
                    let chunk_str: String = chunk.iter().collect();
                    self.event_bridge.emit_text_chunk(&chunk_str).await;
                }

                self.event_bridge.end_message().await;

                let session = self.get_or_create_session(&thread_id);
                session.add_assistant_message(&response_text);

                debug!(
                    thread_id = %thread_id,
                    turn_count = session.turn_count,
                    response_len = response_text.len(),
                    "Message processed successfully"
                );

                self.event_bridge.finish_run().await;
            }
            Err(e) => {
                error!(
                    thread_id = %thread_id,
                    error = %e,
                    "LLM call failed"
                );
                self.event_bridge
                    .finish_run_with_error(&e.to_string())
                    .await;
            }
        }
    }

    /// Dispatches the prompt to the configured provider with the full DevOps
    /// tool set attached, threading the session history through the agent.
    ///
    /// Returns the final assistant text, or an error if the provider is
    /// unsupported, its API key is missing, or the completion fails.
    ///
    /// NOTE(review): the three provider arms register an identical tool set;
    /// consider extracting a shared helper over the agent builder to remove
    /// the triplication.
    async fn call_llm(
        &mut self,
        thread_id: &ThreadId,
        user_input: &str,
    ) -> Result<String, ProcessorError> {
        let preamble = self.config.effective_system_prompt(Some(user_input));
        let provider = self.config.provider.to_lowercase();
        let model = self.config.model.clone();
        let max_turns = self.config.max_turns;
        let project_path = self.config.project_path.clone();
        let event_bridge = self.event_bridge.clone();

        let session = self.get_or_create_session(thread_id);
        let history = &mut session.history;

        debug!(
            provider = %provider,
            model = %model,
            project_path = %project_path.display(),
            history_len = history.len(),
            "Calling LLM with tools"
        );

        match provider.as_str() {
            "openai" => {
                // Fail fast with a clear error before constructing the client.
                if std::env::var("OPENAI_API_KEY").is_err() {
                    warn!("OPENAI_API_KEY not set");
                    return Err(ProcessorError::MissingApiKey("OPENAI_API_KEY".to_string()));
                }

                let hook = AgUiHook::new(event_bridge.clone());

                let client = openai::Client::from_env();
                let agent = client
                    .agent(model)
                    .preamble(&preamble)
                    .max_tokens(4096)
                    .tool(AnalyzeTool::new(project_path.clone()))
                    .tool(ReadFileTool::new(project_path.clone()))
                    .tool(ListDirectoryTool::new(project_path.clone()))
                    .tool(SecurityScanTool::new(project_path.clone()))
                    .tool(VulnerabilitiesTool::new(project_path.clone()))
                    .tool(HadolintTool::new(project_path.clone()))
                    .tool(DclintTool::new(project_path.clone()))
                    .tool(KubelintTool::new(project_path.clone()))
                    .tool(HelmlintTool::new(project_path.clone()))
                    .tool(K8sOptimizeTool::new(project_path.clone()))
                    .tool(K8sCostsTool::new(project_path.clone()))
                    .tool(K8sDriftTool::new(project_path.clone()))
                    .tool(TerraformFmtTool::new(project_path.clone()))
                    .tool(TerraformValidateTool::new(project_path.clone()))
                    .tool(TerraformInstallTool::new())
                    .tool(WebFetchTool::new())
                    .tool(RetrieveOutputTool::new())
                    .tool(ListOutputsTool::new())
                    .tool(WriteFileTool::new(project_path.clone()))
                    .tool(WriteFilesTool::new(project_path.clone()))
                    .tool(ShellTool::new(project_path.clone()))
                    .build();

                agent
                    .prompt(user_input)
                    .with_history(history)
                    .with_hook(hook) .multi_turn(max_turns)
                    .await
                    .map_err(|e| ProcessorError::CompletionFailed(e.to_string()))
            }
            "anthropic" => {
                if std::env::var("ANTHROPIC_API_KEY").is_err() {
                    warn!("ANTHROPIC_API_KEY not set");
                    return Err(ProcessorError::MissingApiKey(
                        "ANTHROPIC_API_KEY".to_string(),
                    ));
                }

                let hook = AgUiHook::new(event_bridge.clone());

                let client = anthropic::Client::from_env();
                let agent = client
                    .agent(model)
                    .preamble(&preamble)
                    .max_tokens(4096)
                    .tool(AnalyzeTool::new(project_path.clone()))
                    .tool(ReadFileTool::new(project_path.clone()))
                    .tool(ListDirectoryTool::new(project_path.clone()))
                    .tool(SecurityScanTool::new(project_path.clone()))
                    .tool(VulnerabilitiesTool::new(project_path.clone()))
                    .tool(HadolintTool::new(project_path.clone()))
                    .tool(DclintTool::new(project_path.clone()))
                    .tool(KubelintTool::new(project_path.clone()))
                    .tool(HelmlintTool::new(project_path.clone()))
                    .tool(K8sOptimizeTool::new(project_path.clone()))
                    .tool(K8sCostsTool::new(project_path.clone()))
                    .tool(K8sDriftTool::new(project_path.clone()))
                    .tool(TerraformFmtTool::new(project_path.clone()))
                    .tool(TerraformValidateTool::new(project_path.clone()))
                    .tool(TerraformInstallTool::new())
                    .tool(WebFetchTool::new())
                    .tool(RetrieveOutputTool::new())
                    .tool(ListOutputsTool::new())
                    .tool(WriteFileTool::new(project_path.clone()))
                    .tool(WriteFilesTool::new(project_path.clone()))
                    .tool(ShellTool::new(project_path.clone()))
                    .build();

                agent
                    .prompt(user_input)
                    .with_history(history)
                    .with_hook(hook) .multi_turn(max_turns)
                    .await
                    .map_err(|e| ProcessorError::CompletionFailed(e.to_string()))
            }
            "bedrock" | "aws" | "aws-bedrock" => {
                // Bedrock auth comes from the AWS SDK credential chain, so
                // there is no up-front env-var check here.
                let hook = AgUiHook::new(event_bridge.clone());

                let client = crate::bedrock::client::Client::from_env();
                let agent = client
                    .agent(model)
                    .preamble(&preamble)
                    .max_tokens(4096)
                    .tool(AnalyzeTool::new(project_path.clone()))
                    .tool(ReadFileTool::new(project_path.clone()))
                    .tool(ListDirectoryTool::new(project_path.clone()))
                    .tool(SecurityScanTool::new(project_path.clone()))
                    .tool(VulnerabilitiesTool::new(project_path.clone()))
                    .tool(HadolintTool::new(project_path.clone()))
                    .tool(DclintTool::new(project_path.clone()))
                    .tool(KubelintTool::new(project_path.clone()))
                    .tool(HelmlintTool::new(project_path.clone()))
                    .tool(K8sOptimizeTool::new(project_path.clone()))
                    .tool(K8sCostsTool::new(project_path.clone()))
                    .tool(K8sDriftTool::new(project_path.clone()))
                    .tool(TerraformFmtTool::new(project_path.clone()))
                    .tool(TerraformValidateTool::new(project_path.clone()))
                    .tool(TerraformInstallTool::new())
                    .tool(WebFetchTool::new())
                    .tool(RetrieveOutputTool::new())
                    .tool(ListOutputsTool::new())
                    .tool(WriteFileTool::new(project_path.clone()))
                    .tool(WriteFilesTool::new(project_path.clone()))
                    .tool(ShellTool::new(project_path))
                    .build();

                agent
                    .prompt(user_input)
                    .with_history(history)
                    .with_hook(hook) .multi_turn(max_turns)
                    .await
                    .map_err(|e| ProcessorError::CompletionFailed(e.to_string()))
            }
            _ => Err(ProcessorError::UnsupportedProvider(provider.to_string())),
        }
    }
}
1004
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::sync::RwLock;
    use tokio::sync::broadcast;

    // Builds a processor wired to in-memory channels; the returned sender
    // keeps the message channel open for the duration of a test.
    fn create_test_processor() -> (AgentProcessor, mpsc::Sender<AgentMessage>) {
        let (msg_tx, msg_rx) = mpsc::channel(100);
        let (event_tx, _) = broadcast::channel(100);
        let bridge = EventBridge::new(
            event_tx,
            Arc::new(RwLock::new(ThreadId::random())),
            Arc::new(RwLock::new(None)),
        );
        let processor = AgentProcessor::with_defaults(msg_rx, bridge);
        (processor, msg_tx)
    }

    #[test]
    fn test_processor_config_default() {
        let config = ProcessorConfig::default();
        assert_eq!(config.provider, "openai");
        assert_eq!(config.model, "gpt-4o");
        assert_eq!(config.max_turns, 50);
    }

    #[test]
    fn test_processor_config_builder() {
        let config = ProcessorConfig::new()
            .with_provider("anthropic")
            .with_model("claude-3-opus")
            .with_max_turns(100);
        assert_eq!(config.provider, "anthropic");
        assert_eq!(config.model, "claude-3-opus");
        assert_eq!(config.max_turns, 100);
    }

    #[test]
    fn test_processor_config_system_prompt() {
        // Without an explicit prompt, the analysis prompt is derived.
        let config = ProcessorConfig::default();
        assert!(config.system_prompt.is_none());
        assert!(
            config
                .effective_system_prompt(None)
                .contains("DevOps/Platform Engineer")
        );

        // An explicit prompt always wins.
        let config = ProcessorConfig::new().with_system_prompt("You are a DevOps expert.");
        assert_eq!(
            config.system_prompt,
            Some("You are a DevOps expert.".to_string())
        );
        assert_eq!(
            config.effective_system_prompt(None),
            "You are a DevOps expert."
        );
    }

    #[test]
    fn test_thread_session_inject_context() {
        let mut session = ThreadSession::new(ThreadId::random());

        session.add_user_message("Hello");
        session.add_assistant_message("Hi there!");
        assert_eq!(session.history.len(), 2);

        // Context is prepended, not appended.
        session.inject_context("Working on project: my-app");
        assert_eq!(session.history.len(), 3);

        if let RigMessage::User { content } = &session.history[0] {
            let content_str = format!("{:?}", content);
            assert!(content_str.contains("[Context]"));
            assert!(content_str.contains("my-app"));
        } else {
            panic!("Expected User message at index 0");
        }
    }

    #[test]
    fn test_thread_session_new() {
        let thread_id = ThreadId::random();
        let session = ThreadSession::new(thread_id.clone());
        assert_eq!(session.thread_id, thread_id);
        assert!(session.history.is_empty());
        assert_eq!(session.turn_count, 0);
    }

    #[test]
    fn test_thread_session_add_messages() {
        let mut session = ThreadSession::new(ThreadId::random());

        // Only assistant messages increment the turn count.
        session.add_user_message("Hello");
        assert_eq!(session.history.len(), 1);
        assert_eq!(session.turn_count, 0); session.add_assistant_message("Hi there!");
        assert_eq!(session.history.len(), 2);
        assert_eq!(session.turn_count, 1); }

    #[test]
    fn test_processor_creation() {
        let (processor, _tx) = create_test_processor();
        assert_eq!(processor.session_count(), 0);
        assert_eq!(processor.config().provider, "openai");
    }

    #[test]
    fn test_get_or_create_session() {
        let (mut processor, _tx) = create_test_processor();
        let thread_id = ThreadId::random();

        let session = processor.get_or_create_session(&thread_id);
        assert_eq!(session.turn_count, 0);

        session.add_user_message("test");

        // Second lookup must return the same session, not a fresh one.
        let session = processor.get_or_create_session(&thread_id);
        assert_eq!(session.history.len(), 1);
    }

    #[tokio::test]
    async fn test_process_message() {
        let (mut processor, _tx) = create_test_processor();
        let thread_id = ThreadId::random();
        let run_id = RunId::random();

        processor
            .process_message(thread_id.clone(), run_id, "Hello, agent!".to_string())
            .await;

        assert_eq!(processor.session_count(), 1);
        let session = processor.sessions.get(&thread_id).unwrap();

        assert!(
            session.history.len() >= 1,
            "User message should be in history"
        );

        // Outcome depends on whether a real API key is available in the
        // environment: with one, a full round trip happens; without, the
        // LLM call fails and only the user message is recorded.
        if std::env::var("OPENAI_API_KEY").is_ok() {
            assert_eq!(session.turn_count, 1);
            assert_eq!(session.history.len(), 2); } else {
            assert_eq!(session.turn_count, 0);
            assert_eq!(session.history.len(), 1); }
    }

    #[tokio::test]
    async fn test_run_processes_messages() {
        use syncable_ag_ui_core::Event;
        use syncable_ag_ui_core::types::{Message as AgUiProtocolMessage, RunAgentInput};
        use tokio::sync::broadcast;

        let (msg_tx, msg_rx) = mpsc::channel(100);
        let (event_tx, mut event_rx) = broadcast::channel(100);

        let bridge = EventBridge::new(
            event_tx,
            Arc::new(RwLock::new(ThreadId::random())),
            Arc::new(RwLock::new(None)),
        );

        let mut processor = AgentProcessor::with_defaults(msg_rx, bridge);

        // Run the processor loop in the background.
        let handle = tokio::spawn(async move {
            processor.run().await;
        });

        let thread_id = ThreadId::random();
        let run_id = RunId::random();
        let input = RunAgentInput::new(thread_id.clone(), run_id.clone())
            .with_messages(vec![AgUiProtocolMessage::new_user("Hello from test")]);

        let agent_msg = super::super::AgentMessage::new(input);
        msg_tx.send(agent_msg).await.expect("Should send");

        // The first emitted event for any run must be RunStarted.
        let event = tokio::time::timeout(std::time::Duration::from_millis(100), event_rx.recv())
            .await
            .expect("Should receive event in time")
            .expect("Should have event");

        assert!(matches!(event, Event::RunStarted(_)));

        // Closing the channel makes run() return.
        drop(msg_tx);

        tokio::time::timeout(std::time::Duration::from_millis(100), handle)
            .await
            .expect("Processor should finish")
            .expect("Should not panic");
    }

    #[tokio::test]
    async fn test_run_handles_empty_messages() {
        use syncable_ag_ui_core::Event;
        use syncable_ag_ui_core::types::RunAgentInput;
        use tokio::sync::broadcast;

        let (msg_tx, msg_rx) = mpsc::channel(100);
        let (event_tx, mut event_rx) = broadcast::channel(100);

        let bridge = EventBridge::new(
            event_tx,
            Arc::new(RwLock::new(ThreadId::random())),
            Arc::new(RwLock::new(None)),
        );

        let mut processor = AgentProcessor::with_defaults(msg_rx, bridge);

        let handle = tokio::spawn(async move {
            processor.run().await;
        });

        // Input with no messages at all: run() should still emit a
        // RunStarted followed by a RunError.
        let thread_id = ThreadId::random();
        let run_id = RunId::random();
        let input = RunAgentInput::new(thread_id.clone(), run_id.clone());
        let agent_msg = super::super::AgentMessage::new(input);
        msg_tx.send(agent_msg).await.expect("Should send");

        let event = tokio::time::timeout(std::time::Duration::from_millis(100), event_rx.recv())
            .await
            .expect("Should receive event")
            .expect("Should have event");

        assert!(matches!(event, Event::RunStarted(_)));

        let event = tokio::time::timeout(std::time::Duration::from_millis(100), event_rx.recv())
            .await
            .expect("Should receive event")
            .expect("Should have event");

        assert!(matches!(event, Event::RunError(_)));

        drop(msg_tx);
        let _ = tokio::time::timeout(std::time::Duration::from_millis(100), handle).await;
    }
}