use crate::local::hooks::inline_scratchpad_context::{
    InlineScratchpadContextHook, InlineScratchpadContextHookOptions,
};
use crate::{AgentProvider, ApiStreamError, GetMyAccountResponse};
use crate::{ListRuleBook, models::*};
use async_trait::async_trait;
use futures_util::Stream;
use libsql::{Builder, Connection};
use reqwest::Error as ReqwestError;
use reqwest::header::HeaderMap;
use rmcp::model::Content;
use stakpak_shared::hooks::{HookContext, HookRegistry, LifecycleEvent};
use stakpak_shared::models::integrations::anthropic::{AnthropicConfig, AnthropicModel};
use stakpak_shared::models::integrations::gemini::{GeminiConfig, GeminiModel};
use stakpak_shared::models::integrations::openai::{
    AgentModel, ChatCompletionChoice, ChatCompletionResponse, ChatCompletionStreamChoice,
    ChatCompletionStreamResponse, ChatMessage, FinishReason, MessageContent, OpenAIConfig,
    OpenAIModel, Role, Tool,
};
use stakpak_shared::models::integrations::search_service::*;
use stakpak_shared::models::llm::{
    GenerationDelta, LLMInput, LLMMessage, LLMMessageContent, LLMModel, LLMProviderConfig,
    LLMStreamInput, chat, chat_stream,
};
use stakpak_shared::tls_client::{TlsClientConfig, create_tls_client};
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::mpsc;
use uuid::Uuid;

mod context_managers;
mod db;
mod hooks;

#[cfg(test)]
mod tests;

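/// A fully local [`AgentProvider`]: sessions and checkpoints are persisted in
/// a libsql database on disk, while inference is routed to whichever LLM
/// provider configs (Anthropic, OpenAI, Gemini) were supplied at construction.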
#[derive(Clone, Debug)]
pub struct LocalClient {
    pub db: Connection,
    pub stakpak_base_url: Option<String>,
    pub anthropic_config: Option<AnthropicConfig>,
    pub openai_config: Option<OpenAIConfig>,
    pub gemini_config: Option<GeminiConfig>,
    pub model_options: ModelOptions,
    pub hook_registry: Option<Arc<HookRegistry<AgentState>>>,
    _search_services_orchestrator: Option<Arc<SearchServicesOrchestrator>>,
}

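/// User-supplied model overrides; any field left as `None` falls back to the
/// defaults chosen in `From<ModelOptions> for ModelSet`.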
#[derive(Clone, Debug)]
pub struct ModelOptions {
    pub smart_model: Option<LLMModel>,
    pub eco_model: Option<LLMModel>,
    pub recovery_model: Option<LLMModel>,
}

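/// A fully resolved model set (no `None`s left), as consumed by the inference
/// hooks.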
#[derive(Clone, Debug)]
pub struct ModelSet {
    pub smart_model: LLMModel,
    pub eco_model: LLMModel,
    pub recovery_model: LLMModel,
    pub hook_registry: Option<Arc<HookRegistry<AgentState>>>,
    pub _search_services_orchestrator: Option<Arc<SearchServicesOrchestrator>>,
}

impl ModelSet {
    fn get_model(&self, agent_model: &AgentModel) -> LLMModel {
        match agent_model {
            AgentModel::Smart => self.smart_model.clone(),
            AgentModel::Eco => self.eco_model.clone(),
            AgentModel::Recovery => self.recovery_model.clone(),
        }
    }
}

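/// Resolves missing model options to the built-in defaults: Claude 4.5 Sonnet
/// (smart), Claude 4.5 Haiku (eco), and GPT-5 (recovery).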
impl From<ModelOptions> for ModelSet {
    fn from(value: ModelOptions) -> Self {
        let smart_model = value
            .smart_model
            .unwrap_or(LLMModel::Anthropic(AnthropicModel::Claude45Sonnet));
        let eco_model = value
            .eco_model
            .unwrap_or(LLMModel::Anthropic(AnthropicModel::Claude45Haiku));
        let recovery_model = value
            .recovery_model
            .unwrap_or(LLMModel::OpenAI(OpenAIModel::GPT5));

        Self {
            smart_model,
            eco_model,
            recovery_model,
            hook_registry: None,
            _search_services_orchestrator: None,
        }
    }
}

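/// Construction parameters for [`LocalClient::new`].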
pub struct LocalClientConfig {
    pub stakpak_base_url: Option<String>,
    pub store_path: Option<String>,
    pub anthropic_config: Option<AnthropicConfig>,
    pub openai_config: Option<OpenAIConfig>,
    pub gemini_config: Option<GeminiConfig>,
    pub smart_model: Option<String>,
    pub eco_model: Option<String>,
    pub recovery_model: Option<String>,
    pub hook_registry: Option<HookRegistry<AgentState>>,
}

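/// Internal message type for the streaming channel: either a token delta to
/// forward to the caller, or an updated hook context snapshot.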
#[derive(Debug)]
enum StreamMessage {
    Delta(GenerationDelta),
    Ctx(Box<HookContext<AgentState>>),
}

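// Default database location (relative to the user's home directory) and the
// bundled prompt used by `generate_session_title`.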
const DEFAULT_STORE_PATH: &str = ".stakpak/data/local.db";
const TITLE_GENERATOR_PROMPT: &str = include_str!("./prompts/session_title_generator.v1.txt");

impl LocalClient {
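    /// Opens (or creates) the local database, initializes the schema, and
    /// registers the default `InlineScratchpadContextHook` on the
    /// `BeforeInference` lifecycle event.
    ///
    /// A minimal usage sketch (fields come from [`LocalClientConfig`]; the
    /// values are illustrative):
    ///
    /// ```ignore
    /// let client = LocalClient::new(LocalClientConfig {
    ///     stakpak_base_url: None,
    ///     store_path: None, // falls back to ~/.stakpak/data/local.db
    ///     anthropic_config: None,
    ///     openai_config: None,
    ///     gemini_config: None,
    ///     smart_model: None,
    ///     eco_model: None,
    ///     recovery_model: None,
    ///     hook_registry: None,
    /// })
    /// .await?;
    /// ```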
    pub async fn new(config: LocalClientConfig) -> Result<Self, String> {
        let store_path = config
            .store_path
            .map(std::path::PathBuf::from)
            .unwrap_or_else(|| {
                std::env::home_dir()
                    .unwrap_or_default()
                    .join(DEFAULT_STORE_PATH)
            });

        if let Some(parent) = store_path.parent() {
            std::fs::create_dir_all(parent)
                .map_err(|e| format!("Failed to create database directory: {}", e))?;
        }

        let db = Builder::new_local(store_path.display().to_string())
            .build()
            .await
            .map_err(|e| e.to_string())?;

        let conn = db.connect().map_err(|e| e.to_string())?;

        db::init_schema(&conn).await?;

        let model_options = ModelOptions {
            smart_model: config.smart_model.map(LLMModel::from),
            eco_model: config.eco_model.map(LLMModel::from),
            recovery_model: config.recovery_model.map(LLMModel::from),
        };

        let mut hook_registry = config.hook_registry.unwrap_or_default();
        hook_registry.register(
            LifecycleEvent::BeforeInference,
            Box::new(InlineScratchpadContextHook::new(
                InlineScratchpadContextHookOptions {
                    model_options: model_options.clone(),
                    history_action_message_size_limit: Some(100),
                    history_action_message_keep_last_n: Some(1),
                    history_action_result_keep_last_n: Some(50),
                },
            )),
        );
        Ok(Self {
            db: conn,
            stakpak_base_url: config.stakpak_base_url.map(|url| url + "/v1"),
            anthropic_config: config.anthropic_config,
            gemini_config: config.gemini_config,
            openai_config: config.openai_config,
            model_options,
            hook_registry: Some(Arc::new(hook_registry)),
            _search_services_orchestrator: Some(Arc::new(SearchServicesOrchestrator)),
        })
    }
}

#[async_trait]
impl AgentProvider for LocalClient {
    async fn get_my_account(&self) -> Result<GetMyAccountResponse, String> {
        Ok(GetMyAccountResponse {
            username: "local".to_string(),
            id: "local".to_string(),
            first_name: "local".to_string(),
            last_name: "local".to_string(),
        })
    }

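    /// Fetches rulebooks from the remote Stakpak API when a base URL is
    /// configured; returns an empty list otherwise.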
    async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
        if self.stakpak_base_url.is_none() {
            return Ok(vec![]);
        }

        let stakpak_base_url = self
            .stakpak_base_url
            .as_ref()
            .ok_or("Stakpak base URL not set")?;

        let url = format!("{}/rules", stakpak_base_url);

        let client = create_tls_client(
            TlsClientConfig::default().with_timeout(std::time::Duration::from_secs(300)),
        )?;

        let response = client
            .get(url)
            .send()
            .await
            .map_err(|e: ReqwestError| e.to_string())?;

        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;

        match serde_json::from_value::<ListRulebooksResponse>(value.clone()) {
            Ok(response) => Ok(response.results),
            Err(e) => {
                eprintln!("Failed to deserialize response: {}", e);
                eprintln!("Raw response: {}", value);
                Err(format!("Failed to deserialize response: {}", e))
            }
        }
    }

    async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
        let stakpak_base_url = self
            .stakpak_base_url
            .as_ref()
            .ok_or("Stakpak base URL not set")?;

        let encoded_uri = urlencoding::encode(uri);

        let url = format!("{}/rules/{}", stakpak_base_url, encoded_uri);

        let client = create_tls_client(
            TlsClientConfig::default().with_timeout(std::time::Duration::from_secs(300)),
        )?;

        let response = client
            .get(&url)
            .send()
            .await
            .map_err(|e: ReqwestError| e.to_string())?;

        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;

        match serde_json::from_value::<RuleBook>(value.clone()) {
            Ok(response) => Ok(response),
            Err(e) => {
                eprintln!("Failed to deserialize response: {}", e);
                eprintln!("Raw response: {}", value);
                Err(format!("Failed to deserialize response: {}", e))
            }
        }
    }

    async fn create_rulebook(
        &self,
        _uri: &str,
        _description: &str,
        _content: &str,
        _tags: Vec<String>,
        _visibility: Option<RuleBookVisibility>,
    ) -> Result<CreateRuleBookResponse, String> {
        Err("Local provider does not support rulebooks yet".to_string())
    }

    async fn delete_rulebook(&self, _uri: &str) -> Result<(), String> {
        Err("Local provider does not support rulebooks yet".to_string())
    }

    async fn list_agent_sessions(&self) -> Result<Vec<AgentSession>, String> {
        db::list_sessions(&self.db).await
    }

    async fn get_agent_session(&self, session_id: Uuid) -> Result<AgentSession, String> {
        db::get_session(&self.db, session_id).await
    }

    async fn get_agent_session_stats(
        &self,
        _session_id: Uuid,
    ) -> Result<AgentSessionStats, String> {
        Ok(AgentSessionStats::default())
    }

    async fn get_agent_checkpoint(&self, checkpoint_id: Uuid) -> Result<RunAgentOutput, String> {
        db::get_checkpoint(&self.db, checkpoint_id).await
    }

    async fn get_agent_session_latest_checkpoint(
        &self,
        session_id: Uuid,
    ) -> Result<RunAgentOutput, String> {
        db::get_latest_checkpoint(&self.db, session_id).await
    }

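    /// Runs one non-streaming completion turn: `BeforeRequest` hooks, session
    /// initialization (resuming from an embedded checkpoint id if present),
    /// inference, a new checkpoint write, then `AfterRequest` hooks.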
    async fn chat_completion(
        &self,
        model: AgentModel,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
    ) -> Result<ChatCompletionResponse, String> {
        let mut ctx = HookContext::new(None, AgentState::new(model, messages, tools));

        if let Some(hook_registry) = &self.hook_registry {
            hook_registry
                .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        let current_checkpoint = self.initialize_session(&ctx.state.messages).await?;
        ctx.set_session_id(current_checkpoint.session.id);

        let new_message = self.run_agent_completion(&mut ctx, None).await?;
        ctx.state.append_new_message(new_message.clone());

        let result = self
            .update_session(&current_checkpoint, ctx.state.messages.clone())
            .await?;
        let checkpoint_created_at = result.checkpoint.created_at.timestamp() as u64;
        ctx.set_new_checkpoint_id(result.checkpoint.id);

        if let Some(hook_registry) = &self.hook_registry {
            hook_registry
                .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        Ok(ChatCompletionResponse {
            id: ctx.new_checkpoint_id.unwrap().to_string(),
            object: "chat.completion".to_string(),
            created: checkpoint_created_at,
            model: ctx
                .state
                .llm_input
                .as_ref()
                .map(|llm_input| llm_input.model.clone().to_string())
                .unwrap_or_default(),
            choices: vec![ChatCompletionChoice {
                index: 0,
                message: ctx.state.messages.last().cloned().unwrap(),
                logprobs: None,
                finish_reason: FinishReason::Stop,
            }],
            usage: ctx
                .state
                .llm_output
                .as_ref()
                .map(|u| u.usage.clone())
                .unwrap_or_default(),
            system_fingerprint: None,
            metadata: None,
        })
    }

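    /// Streaming variant of [`Self::chat_completion`]. Inference runs in a
    /// spawned task feeding an mpsc channel; checkpoint ids are surfaced to
    /// the client inline as `<checkpoint_id>...</checkpoint_id>` content
    /// deltas at the start and end of the stream.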
    async fn chat_completion_stream(
        &self,
        model: AgentModel,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
        _headers: Option<HeaderMap>,
    ) -> Result<
        (
            Pin<
                Box<dyn Stream<Item = Result<ChatCompletionStreamResponse, ApiStreamError>> + Send>,
            >,
            Option<String>,
        ),
        String,
    > {
        let mut ctx = HookContext::new(None, AgentState::new(model, messages, tools));

        if let Some(hook_registry) = &self.hook_registry {
            hook_registry
                .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        let current_checkpoint = self.initialize_session(&ctx.state.messages).await?;
        ctx.set_session_id(current_checkpoint.session.id);

        let (tx, mut rx) = mpsc::channel::<Result<StreamMessage, String>>(100);

        let _ = tx
            .send(Ok(StreamMessage::Delta(GenerationDelta::Content {
                content: format!(
                    "\n<checkpoint_id>{}</checkpoint_id>\n",
                    current_checkpoint.checkpoint.id
                ),
            })))
            .await;

        let client = self.clone();
        let self_clone = self.clone();
        let mut ctx_clone = ctx.clone();
        tokio::spawn(async move {
            let result = client
                .run_agent_completion(&mut ctx_clone, Some(tx.clone()))
                .await;

            match result {
                Err(e) => {
                    let _ = tx.send(Err(e)).await;
                }
                Ok(new_message) => {
                    ctx_clone.state.append_new_message(new_message.clone());
                    let _ = tx
                        .send(Ok(StreamMessage::Ctx(Box::new(ctx_clone.clone()))))
                        .await;

                    let output = self_clone
                        .update_session(&current_checkpoint, ctx_clone.state.messages.clone())
                        .await;

                    match output {
                        Err(e) => {
                            let _ = tx.send(Err(e)).await;
                        }
                        Ok(output) => {
                            ctx_clone.set_new_checkpoint_id(output.checkpoint.id);
                            let _ = tx.send(Ok(StreamMessage::Ctx(Box::new(ctx_clone)))).await;
                            let _ = tx
                                .send(Ok(StreamMessage::Delta(GenerationDelta::Content {
                                    content: format!(
                                        "\n<checkpoint_id>{}</checkpoint_id>\n",
                                        output.checkpoint.id
                                    ),
                                })))
                                .await;
                        }
                    }
                }
            }
        });

        let hook_registry = self.hook_registry.clone();
        let stream = async_stream::stream! {
            while let Some(delta_result) = rx.recv().await {
                match delta_result {
                    Ok(delta) => match delta {
                        StreamMessage::Ctx(updated_ctx) => {
                            ctx = *updated_ctx;
                        }
                        StreamMessage::Delta(delta) => {
                            yield Ok(ChatCompletionStreamResponse {
                                id: ctx.request_id.to_string(),
                                object: "chat.completion.chunk".to_string(),
                                created: chrono::Utc::now().timestamp() as u64,
                                model: ctx.state.llm_input.as_ref().map(|llm_input| llm_input.model.clone().to_string()).unwrap_or_default(),
                                choices: vec![ChatCompletionStreamChoice {
                                    index: 0,
                                    delta: delta.into(),
                                    finish_reason: None,
                                }],
                                usage: ctx.state.llm_output.as_ref().map(|u| u.usage.clone()),
                                metadata: None,
                            })
                        }
                    }
                    Err(e) => yield Err(ApiStreamError::Unknown(e)),
                }
            }

            if let Some(hook_registry) = hook_registry {
                hook_registry
                    .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
                    .await
                    .map_err(|e| e.to_string())?
                    .ok()?;
            }
        };

        Ok((Box::pin(stream), None))
    }

    async fn cancel_stream(&self, _request_id: String) -> Result<(), String> {
        Ok(())
    }

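    /// Agentic documentation search: an LLM first reformulates the query,
    /// then up to three search/scrape/validate rounds run until the results
    /// satisfy the identified documentation requirements.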
    async fn search_docs(&self, input: &SearchDocsRequest) -> Result<Vec<Content>, String> {
        let config = SearchServicesOrchestrator::start()
            .await
            .map_err(|e| e.to_string())?;

        let api_url = format!("http://localhost:{}", config.api_port);
        let search_client = SearchClient::new(api_url);

        let initial_query = if let Some(exclude) = &input.exclude_keywords {
            format!("{} -{}", input.keywords, exclude)
        } else {
            input.keywords.clone()
        };

        let llm_config = self.get_llm_config();
        let search_model = get_search_model(
            &llm_config,
            self.model_options.eco_model.clone(),
            self.model_options.smart_model.clone(),
        );

        let analysis = analyze_search_query(&llm_config, &search_model, &initial_query).await?;
        let required_documentation = analysis.required_documentation;
        let mut current_query = analysis.reformulated_query;
        let mut previous_queries = Vec::new();
        let mut final_valid_docs = Vec::new();
        let mut accumulated_needed_urls = Vec::new();

        const MAX_ITERATIONS: usize = 3;

        for _iteration in 0..MAX_ITERATIONS {
            previous_queries.push(current_query.clone());

            let search_results = search_client
                .search_and_scrape(current_query.clone(), None)
                .await
                .map_err(|e| e.to_string())?;

            if search_results.is_empty() {
                break;
            }

            let validation_result = validate_search_docs(
                &llm_config,
                &search_model,
                &search_results,
                &current_query,
                &required_documentation,
                &previous_queries,
                &accumulated_needed_urls,
            )
            .await?;

            for url in &validation_result.needed_urls {
                if !accumulated_needed_urls.contains(url) {
                    accumulated_needed_urls.push(url.clone());
                }
            }

            for doc in validation_result.valid_docs.into_iter() {
                let is_duplicate = final_valid_docs
                    .iter()
                    .any(|existing_doc: &ScrapedContent| existing_doc.url == doc.url);

                if !is_duplicate {
                    final_valid_docs.push(doc);
                }
            }

            if validation_result.is_satisfied {
                break;
            }

            if let Some(new_query) = validation_result.new_query {
                if new_query != current_query && !previous_queries.contains(&new_query) {
                    current_query = new_query;
                } else {
                    break;
                }
            } else {
                break;
            }
        }

        if final_valid_docs.is_empty() {
            return Ok(vec![Content::text("No results found".to_string())]);
        }

        let contents: Vec<Content> = final_valid_docs
            .into_iter()
            .map(|result| {
                Content::text(format!(
                    "Title: {}\nURL: {}\nContent: {}",
                    result.title.unwrap_or_default(),
                    result.url,
                    result.content.unwrap_or_default(),
                ))
            })
            .collect();

        Ok(contents)
    }

    async fn search_memory(&self, _input: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
        Ok(Vec::new())
    }

    async fn slack_read_messages(
        &self,
        _input: &SlackReadMessagesRequest,
    ) -> Result<Vec<Content>, String> {
        Ok(Vec::new())
    }

    async fn slack_read_replies(
        &self,
        _input: &SlackReadRepliesRequest,
    ) -> Result<Vec<Content>, String> {
        Ok(Vec::new())
    }

    async fn slack_send_message(
        &self,
        _input: &SlackSendMessageRequest,
    ) -> Result<Vec<Content>, String> {
        Ok(Vec::new())
    }

    async fn memorize_session(&self, _checkpoint_id: Uuid) -> Result<(), String> {
        Ok(())
    }
}

impl LocalClient {
    fn get_llm_config(&self) -> LLMProviderConfig {
        LLMProviderConfig {
            anthropic_config: self.anthropic_config.clone(),
            openai_config: self.openai_config.clone(),
            gemini_config: self.gemini_config.clone(),
        }
    }

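    /// Runs `BeforeInference` hooks, performs the actual LLM call (streaming
    /// when a channel is supplied, blocking otherwise), then runs
    /// `AfterInference` hooks and returns the resulting message.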
    async fn run_agent_completion(
        &self,
        ctx: &mut HookContext<AgentState>,
        stream_channel_tx: Option<mpsc::Sender<Result<StreamMessage, String>>>,
    ) -> Result<ChatMessage, String> {
        if let Some(hook_registry) = &self.hook_registry {
            hook_registry
                .execute_hooks(ctx, &LifecycleEvent::BeforeInference)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        let input = if let Some(llm_input) = ctx.state.llm_input.clone() {
            llm_input
        } else {
            return Err(
                "Run agent completion: LLM input not found, make sure to register a context hook before inference"
                    .to_string(),
            );
        };

        let llm_config = self.get_llm_config();

        let (response_message, usage) = if let Some(tx) = stream_channel_tx {
            let (internal_tx, mut internal_rx) = mpsc::channel::<GenerationDelta>(100);
            let input = LLMStreamInput {
                model: input.model,
                messages: input.messages,
                max_tokens: input.max_tokens,
                tools: input.tools,
                stream_channel_tx: internal_tx,
            };

            let chat_future = async move {
                chat_stream(&llm_config, input)
                    .await
                    .map_err(|e| e.to_string())
            };

            let receive_future = async move {
                while let Some(delta) = internal_rx.recv().await {
                    if tx.send(Ok(StreamMessage::Delta(delta))).await.is_err() {
                        break;
                    }
                }
            };

            let (chat_result, _) = tokio::join!(chat_future, receive_future);
            let response = chat_result?;
            (response.choices[0].message.clone(), response.usage)
        } else {
            let response = chat(&llm_config, input).await.map_err(|e| e.to_string())?;
            (response.choices[0].message.clone(), response.usage)
        };

        ctx.state.set_llm_output(response_message, usage);

        if let Some(hook_registry) = &self.hook_registry {
            hook_registry
                .execute_hooks(ctx, &LifecycleEvent::AfterInference)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        let llm_output = ctx
            .state
            .llm_output
            .as_ref()
            .ok_or_else(|| "LLM output is missing from state".to_string())?;

        Ok(ChatMessage::from(llm_output))
    }

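    /// Resolves the checkpoint to resume from: if the last server message
    /// embeds a `<checkpoint_id>`, that checkpoint is loaded; otherwise a new
    /// session is created with a generated title and an initial checkpoint.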
    async fn initialize_session(&self, messages: &[ChatMessage]) -> Result<RunAgentOutput, String> {
        if messages.is_empty() {
            return Err("At least one message is required".to_string());
        }

        let checkpoint_id = ChatMessage::last_server_message(messages).and_then(|message| {
            message
                .content
                .as_ref()
                .and_then(|content| content.extract_checkpoint_id())
        });

        let current_checkpoint = if let Some(checkpoint_id) = checkpoint_id {
            db::get_checkpoint(&self.db, checkpoint_id).await?
        } else {
            let title = self.generate_session_title(messages).await?;

            let session_id = Uuid::new_v4();
            let now = chrono::Utc::now();
            let session = AgentSession {
                id: session_id,
                title,
                agent_id: AgentID::PabloV1,
                visibility: AgentSessionVisibility::Private,
                created_at: now,
                updated_at: now,
                checkpoints: vec![],
            };
            db::create_session(&self.db, &session).await?;

            let checkpoint_id = Uuid::new_v4();
            let checkpoint = AgentCheckpointListItem {
                id: checkpoint_id,
                status: AgentStatus::Complete,
                execution_depth: 0,
                parent: None,
                created_at: now,
                updated_at: now,
            };
            let initial_state = AgentOutput::PabloV1 {
                messages: messages.to_vec(),
                node_states: serde_json::json!({}),
            };
            db::create_checkpoint(&self.db, session_id, &checkpoint, &initial_state).await?;

            db::get_checkpoint(&self.db, checkpoint_id).await?
        };

        Ok(current_checkpoint)
    }

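    /// Appends a new checkpoint (depth + 1, parented to the previous one)
    /// carrying the updated message history.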
    async fn update_session(
        &self,
        checkpoint_info: &RunAgentOutput,
        new_messages: Vec<ChatMessage>,
    ) -> Result<RunAgentOutput, String> {
        let now = chrono::Utc::now();
        let complete_checkpoint = AgentCheckpointListItem {
            id: Uuid::new_v4(),
            status: AgentStatus::Complete,
            execution_depth: checkpoint_info.checkpoint.execution_depth + 1,
            parent: Some(AgentParentCheckpoint {
                id: checkpoint_info.checkpoint.id,
            }),
            created_at: now,
            updated_at: now,
        };

        let mut new_state = checkpoint_info.output.clone();
        new_state.set_messages(new_messages);

        db::create_checkpoint(
            &self.db,
            checkpoint_info.session.id,
            &complete_checkpoint,
            &new_state,
        )
        .await?;

        Ok(RunAgentOutput {
            checkpoint: complete_checkpoint,
            session: checkpoint_info.session.clone(),
            output: new_state,
        })
    }

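    /// Generates a short session title from the conversation so far, using
    /// the eco model when configured, otherwise the cheapest model of the
    /// first available provider.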
    async fn generate_session_title(&self, messages: &[ChatMessage]) -> Result<String, String> {
        let llm_config = self.get_llm_config();

        let llm_model = if let Some(eco_model) = &self.model_options.eco_model {
            eco_model.clone()
        } else if llm_config.openai_config.is_some() {
            LLMModel::OpenAI(OpenAIModel::GPT5Mini)
        } else if llm_config.anthropic_config.is_some() {
            LLMModel::Anthropic(AnthropicModel::Claude45Haiku)
        } else if llm_config.gemini_config.is_some() {
            LLMModel::Gemini(GeminiModel::Gemini25Flash)
        } else {
            return Err("No LLM config found".to_string());
        };

        let messages = vec![
            LLMMessage {
                role: "system".to_string(),
                content: LLMMessageContent::String(TITLE_GENERATOR_PROMPT.into()),
            },
            LLMMessage {
                role: "user".to_string(),
                content: LLMMessageContent::String(
                    messages
                        .iter()
                        .map(|msg| {
                            msg.content
                                .as_ref()
                                .unwrap_or(&MessageContent::String("".to_string()))
                                .to_string()
                        })
                        // Join with newlines so message boundaries stay visible
                        // to the title generator.
                        .collect::<Vec<_>>()
                        .join("\n"),
                ),
            },
        ];

        let input = LLMInput {
            model: llm_model,
            messages,
            max_tokens: 100,
            tools: None,
        };

        let response = chat(&llm_config, input).await.map_err(|e| e.to_string())?;

        Ok(response.choices[0].message.content.to_string())
    }
}

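/// Asks the model to identify which documentation types the query needs and
/// to produce a reformulated, search-engine-friendly query, returned as XML
/// and parsed by [`parse_analysis_xml`].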
async fn analyze_search_query(
    llm_config: &LLMProviderConfig,
    model: &LLMModel,
    query: &str,
) -> Result<AnalysisResult, String> {
    let system_prompt = r#"You are an expert search query analyzer specializing in technical documentation retrieval.

## Your Task

Analyze the user's search query to:
1. Identify the specific types of documentation needed
2. Reformulate the query for optimal search engine results

## Guidelines for Required Documentation

Identify specific documentation types such as:
- API references and specifications
- Installation/setup guides
- Configuration documentation
- Tutorials and getting started guides
- Troubleshooting guides
- Architecture/design documents
- CLI/command references
- SDK/library documentation

## Guidelines for Query Reformulation

Create an optimized search query that:
- Uses specific technical terminology
- Includes relevant keywords (e.g., "documentation", "guide", "API")
- Removes ambiguous or filler words
- Targets authoritative sources when possible
- Is concise but comprehensive (5-10 words ideal)

## Response Format

Respond ONLY with valid XML in this exact structure:

<analysis>
  <required_documentation>
    <item>specific documentation type needed</item>
  </required_documentation>
  <reformulated_query>optimized search query string</reformulated_query>
</analysis>"#;

    let user_prompt = format!(
        r#"<user_query>{}</user_query>

Analyze this query and provide the required documentation types and an optimized search query."#,
        query
    );

    let response = chat(
        llm_config,
        LLMInput {
            model: model.clone(),
            messages: vec![
                LLMMessage {
                    role: Role::System.to_string(),
                    content: LLMMessageContent::String(system_prompt.to_string()),
                },
                LLMMessage {
                    role: Role::User.to_string(),
                    content: LLMMessageContent::String(user_prompt.to_string()),
                },
            ],
            max_tokens: 2000,
            tools: None,
        },
    )
    .await
    .map_err(|e| e.to_string())?;

    let content = response.choices[0].message.content.to_string();

    parse_analysis_xml(&content)
}

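/// Minimal tag-scanning parser for the `<analysis>` response; it does not use
/// a real XML parser, it just scans for `<tag>`/`</tag>` pairs.
///
/// A sketch of the expected input and output (hypothetical values):
///
/// ```ignore
/// let xml = "<analysis>\
///     <required_documentation><item>API reference</item></required_documentation>\
///     <reformulated_query>kubernetes ingress documentation</reformulated_query>\
/// </analysis>";
/// let result = parse_analysis_xml(xml)?;
/// assert_eq!(result.required_documentation, vec!["API reference"]);
/// assert_eq!(result.reformulated_query, "kubernetes ingress documentation");
/// ```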
fn parse_analysis_xml(xml: &str) -> Result<AnalysisResult, String> {
    let extract_tag = |tag: &str| -> Option<String> {
        let start_tag = format!("<{}>", tag);
        let end_tag = format!("</{}>", tag);
        xml.find(&start_tag).and_then(|start| {
            let content_start = start + start_tag.len();
            xml[content_start..]
                .find(&end_tag)
                .map(|end| xml[content_start..content_start + end].trim().to_string())
        })
    };

    let extract_all_tags = |tag: &str| -> Vec<String> {
        let start_tag = format!("<{}>", tag);
        let end_tag = format!("</{}>", tag);
        let mut results = Vec::new();
        let mut search_start = 0;

        while let Some(start) = xml[search_start..].find(&start_tag) {
            let abs_start = search_start + start + start_tag.len();
            if let Some(end) = xml[abs_start..].find(&end_tag) {
                results.push(xml[abs_start..abs_start + end].trim().to_string());
                search_start = abs_start + end + end_tag.len();
            } else {
                break;
            }
        }
        results
    };

    let required_documentation = extract_all_tags("item");
    let reformulated_query =
        extract_tag("reformulated_query").ok_or("Failed to extract reformulated_query from XML")?;

    Ok(AnalysisResult {
        required_documentation,
        reformulated_query,
    })
}

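/// Asks the model to judge whether the scraped results cover the required
/// documentation, returning the validated docs, any still-missing URLs, and
/// optionally a refined query for the next iteration.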
async fn validate_search_docs(
    llm_config: &LLMProviderConfig,
    model: &LLMModel,
    docs: &[ScrapedContent],
    query: &str,
    required_documentation: &[String],
    previous_queries: &[String],
    accumulated_needed_urls: &[String],
) -> Result<ValidationResult, String> {
    let docs_preview = docs
        .iter()
        .enumerate()
        .take(10)
        .map(|(i, r)| {
            format!(
                "<doc index=\"{}\">\n <title>{}</title>\n <url>{}</url>\n</doc>",
                i + 1,
                r.title.clone().unwrap_or_else(|| "Untitled".to_string()),
                r.url
            )
        })
        .collect::<Vec<_>>()
        .join("\n");

    let required_docs_formatted = required_documentation
        .iter()
        .map(|d| format!(" <item>{}</item>", d))
        .collect::<Vec<_>>()
        .join("\n");

    let previous_queries_formatted = previous_queries
        .iter()
        .map(|q| format!(" <query>{}</query>", q))
        .collect::<Vec<_>>()
        .join("\n");

    let accumulated_urls_formatted = accumulated_needed_urls
        .iter()
        .map(|u| format!(" <url>{}</url>", u))
        .collect::<Vec<_>>()
        .join("\n");

    let system_prompt = r#"You are an expert search result validator. Your task is to evaluate whether search results adequately satisfy a documentation query.

## Evaluation Criteria

For each search result, assess:
1. **Relevance**: Does the document directly address the required documentation topics?
2. **Authority**: Is this an official source, documentation site, or authoritative reference?
3. **Completeness**: Does it provide comprehensive information, not just passing mentions?
4. **Freshness**: For technical docs, prefer current/maintained sources over outdated ones.

## Decision Guidelines

Mark results as SATISFIED when:
- All required documentation topics have at least one authoritative source
- The sources provide actionable, detailed information
- No critical gaps remain in coverage

Suggest a NEW QUERY when:
- Key topics are missing from results
- Results are too general or tangential
- A more specific query would yield better results
- Previous queries haven't addressed certain requirements

## Response Format

Respond ONLY with valid XML in this exact structure:

<validation>
  <is_satisfied>true or false</is_satisfied>
  <valid_docs>
    <doc><url>exact URL from results</url></doc>
  </valid_docs>
  <needed_urls>
    <url>specific URL pattern or domain still needed</url>
  </needed_urls>
  <new_query>refined search query if not satisfied, omit if satisfied</new_query>
  <reasoning>brief explanation of your assessment</reasoning>
</validation>"#;

    let user_prompt = format!(
        r#"<search_context>
  <original_query>{}</original_query>
  <required_documentation>
{}
  </required_documentation>
  <previous_queries>
{}
  </previous_queries>
  <accumulated_needed_urls>
{}
  </accumulated_needed_urls>
</search_context>

<current_results>
{}
</current_results>

Evaluate these search results against the requirements. Which documents are valid and relevant? Is the documentation requirement satisfied? If not, what specific query would help find missing information?"#,
        query,
        if required_docs_formatted.is_empty() {
            " <item>None specified</item>".to_string()
        } else {
            required_docs_formatted
        },
        if previous_queries_formatted.is_empty() {
            " <query>None</query>".to_string()
        } else {
            previous_queries_formatted
        },
        if accumulated_urls_formatted.is_empty() {
            " <url>None</url>".to_string()
        } else {
            accumulated_urls_formatted
        },
        docs_preview
    );

    let response = chat(
        llm_config,
        LLMInput {
            model: model.clone(),
            messages: vec![
                LLMMessage {
                    role: Role::System.to_string(),
                    content: LLMMessageContent::String(system_prompt.to_string()),
                },
                LLMMessage {
                    role: Role::User.to_string(),
                    content: LLMMessageContent::String(user_prompt.to_string()),
                },
            ],
            max_tokens: 4000,
            tools: None,
        },
    )
    .await
    .map_err(|e| e.to_string())?;

    let content = response.choices[0].message.content.to_string();

    let validation = parse_validation_xml(&content, docs)?;

    Ok(validation)
}

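/// Parses the `<validation>` response with the same tag-scanning approach as
/// [`parse_analysis_xml`]. Every `<url>` that matches a scraped doc is
/// treated as a validated result; any other `<url>` is recorded as still
/// needed.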
fn parse_validation_xml(xml: &str, docs: &[ScrapedContent]) -> Result<ValidationResult, String> {
    let extract_tag = |tag: &str| -> Option<String> {
        let start_tag = format!("<{}>", tag);
        let end_tag = format!("</{}>", tag);
        xml.find(&start_tag).and_then(|start| {
            let content_start = start + start_tag.len();
            xml[content_start..]
                .find(&end_tag)
                .map(|end| xml[content_start..content_start + end].trim().to_string())
        })
    };

    let extract_all_tags = |tag: &str| -> Vec<String> {
        let start_tag = format!("<{}>", tag);
        let end_tag = format!("</{}>", tag);
        let mut results = Vec::new();
        let mut search_start = 0;

        while let Some(start) = xml[search_start..].find(&start_tag) {
            let abs_start = search_start + start + start_tag.len();
            if let Some(end) = xml[abs_start..].find(&end_tag) {
                results.push(xml[abs_start..abs_start + end].trim().to_string());
                search_start = abs_start + end + end_tag.len();
            } else {
                break;
            }
        }
        results
    };

    let is_satisfied = extract_tag("is_satisfied")
        .map(|s| s.to_lowercase() == "true")
        .unwrap_or(false);

    let valid_urls: Vec<String> = extract_all_tags("url")
        .into_iter()
        .filter(|url| docs.iter().any(|d| d.url == *url))
        .collect();

    let valid_docs: Vec<ScrapedContent> = valid_urls
        .iter()
        .filter_map(|url| docs.iter().find(|d| d.url == *url).cloned())
        .collect();

    let needed_urls: Vec<String> = extract_all_tags("url")
        .into_iter()
        .filter(|url| !docs.iter().any(|d| d.url == *url))
        .collect();

    let new_query = extract_tag("new_query").filter(|q| !q.is_empty() && q != "omit if satisfied");

    Ok(ValidationResult {
        is_satisfied,
        valid_docs,
        needed_urls,
        new_query,
    })
}

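/// Picks a cheap, fast model for the search pipeline: the provider is
/// inferred from the configured eco model (falling back to the smart model,
/// then to whichever provider config exists), defaulting to OpenAI's O4 Mini.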
fn get_search_model(
    llm_config: &LLMProviderConfig,
    eco_model: Option<LLMModel>,
    smart_model: Option<LLMModel>,
) -> LLMModel {
    let base_model = eco_model.or(smart_model);

    match base_model {
        Some(LLMModel::OpenAI(_)) => LLMModel::OpenAI(OpenAIModel::O4Mini),
        Some(LLMModel::Anthropic(_)) => LLMModel::Anthropic(AnthropicModel::Claude45Haiku),
        Some(LLMModel::Gemini(_)) => LLMModel::Gemini(GeminiModel::Gemini3Flash),
        Some(LLMModel::Custom(model)) => LLMModel::Custom(model),
        None => {
            if llm_config.openai_config.is_some() {
                LLMModel::OpenAI(OpenAIModel::O4Mini)
            } else if llm_config.anthropic_config.is_some() {
                LLMModel::Anthropic(AnthropicModel::Claude45Haiku)
            } else if llm_config.gemini_config.is_some() {
                LLMModel::Gemini(GeminiModel::Gemini3Flash)
            } else {
                LLMModel::OpenAI(OpenAIModel::O4Mini)
            }
        }
    }
}