use crate::local::hooks::inline_scratchpad_context::{
    InlineScratchpadContextHook, InlineScratchpadContextHookOptions,
};
use crate::{AgentProvider, ApiStreamError, GetMyAccountResponse};
use crate::{ListRuleBook, models::*};
use async_trait::async_trait;
use futures_util::Stream;
use libsql::{Builder, Connection};
use reqwest::Error as ReqwestError;
use reqwest::header::HeaderMap;
use rmcp::model::Content;
use stakpak_shared::hooks::{HookContext, HookRegistry, LifecycleEvent};
use stakpak_shared::models::integrations::anthropic::AnthropicModel;
use stakpak_shared::models::integrations::gemini::GeminiModel;
use stakpak_shared::models::integrations::openai::{
    AgentModel, ChatCompletionChoice, ChatCompletionResponse, ChatCompletionStreamChoice,
    ChatCompletionStreamResponse, ChatMessage, FinishReason, MessageContent, OpenAIModel, Role,
    Tool,
};
use stakpak_shared::models::integrations::search_service::*;
use stakpak_shared::models::llm::{
    GenerationDelta, LLMInput, LLMMessage, LLMMessageContent, LLMModel, LLMProviderConfig,
    LLMStreamInput,
};
use stakpak_shared::models::stakai_adapter::StakAIClient;
use stakpak_shared::tls_client::{TlsClientConfig, create_tls_client};
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::mpsc;
use uuid::Uuid;

mod context_managers;
mod db;
mod hooks;

#[cfg(test)]
mod tests;

#[derive(Clone, Debug)]
pub struct LocalClient {
    pub db: Connection,
    pub stakpak_base_url: Option<String>,
    pub providers: LLMProviderConfig,
    pub model_options: ModelOptions,
    pub hook_registry: Option<Arc<HookRegistry<AgentState>>>,
    _search_services_orchestrator: Option<Arc<SearchServicesOrchestrator>>,
}

#[derive(Clone, Debug)]
pub struct ModelOptions {
    pub smart_model: Option<LLMModel>,
    pub eco_model: Option<LLMModel>,
    pub recovery_model: Option<LLMModel>,
}

#[derive(Clone, Debug)]
pub struct ModelSet {
    pub smart_model: LLMModel,
    pub eco_model: LLMModel,
    pub recovery_model: LLMModel,
    pub hook_registry: Option<Arc<HookRegistry<AgentState>>>,
    pub _search_services_orchestrator: Option<Arc<SearchServicesOrchestrator>>,
}

impl ModelSet {
    fn get_model(&self, agent_model: &AgentModel) -> LLMModel {
        match agent_model {
            AgentModel::Smart => self.smart_model.clone(),
            AgentModel::Eco => self.eco_model.clone(),
            AgentModel::Recovery => self.recovery_model.clone(),
        }
    }
}

impl From<ModelOptions> for ModelSet {
    fn from(value: ModelOptions) -> Self {
        let smart_model = value
            .smart_model
            .unwrap_or(LLMModel::Anthropic(AnthropicModel::Claude45Sonnet));
        let eco_model = value
            .eco_model
            .unwrap_or(LLMModel::Anthropic(AnthropicModel::Claude45Haiku));
        let recovery_model = value
            .recovery_model
            .unwrap_or(LLMModel::OpenAI(OpenAIModel::GPT5));

        Self {
            smart_model,
            eco_model,
            recovery_model,
            hook_registry: None,
            _search_services_orchestrator: None,
        }
    }
}
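
// Conversion fills any unset tier with a built-in default (Claude 4.5 Sonnet
// for smart, Claude 4.5 Haiku for eco, GPT-5 for recovery). A minimal sketch
// of the resolution, using only the variants referenced above:
//
//     let set: ModelSet = ModelOptions {
//         smart_model: None,
//         eco_model: Some(LLMModel::OpenAI(OpenAIModel::GPT5)),
//         recovery_model: None,
//     }
//     .into();
//     // set.smart_model    == LLMModel::Anthropic(AnthropicModel::Claude45Sonnet)
//     // set.eco_model      == LLMModel::OpenAI(OpenAIModel::GPT5)
//     // set.recovery_model == LLMModel::OpenAI(OpenAIModel::GPT5)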

pub struct LocalClientConfig {
    pub stakpak_base_url: Option<String>,
    pub store_path: Option<String>,
    pub providers: LLMProviderConfig,
    pub smart_model: Option<String>,
    pub eco_model: Option<String>,
    pub recovery_model: Option<String>,
    pub hook_registry: Option<HookRegistry<AgentState>>,
}

/// Internal message passed over the streaming channel: either a generation
/// delta to forward to the consumer, or an updated hook context so the
/// stream can observe session/checkpoint state set by the worker task.
#[derive(Debug)]
enum StreamMessage {
    Delta(GenerationDelta),
    Ctx(Box<HookContext<AgentState>>),
}

const DEFAULT_STORE_PATH: &str = ".stakpak/data/local.db";
const TITLE_GENERATOR_PROMPT: &str = include_str!("./prompts/session_title_generator.v1.txt");

impl LocalClient {
    pub async fn new(config: LocalClientConfig) -> Result<Self, String> {
        let store_path = config
            .store_path
            .map(std::path::PathBuf::from)
            .unwrap_or_else(|| {
                std::env::home_dir()
                    .unwrap_or_default()
                    .join(DEFAULT_STORE_PATH)
            });

        if let Some(parent) = store_path.parent() {
            std::fs::create_dir_all(parent)
                .map_err(|e| format!("Failed to create database directory: {}", e))?;
        }

        let db = Builder::new_local(store_path.display().to_string())
            .build()
            .await
            .map_err(|e| e.to_string())?;

        let conn = db.connect().map_err(|e| e.to_string())?;

        db::init_schema(&conn).await?;

        let model_options = ModelOptions {
            smart_model: config.smart_model.map(LLMModel::from),
            eco_model: config.eco_model.map(LLMModel::from),
            recovery_model: config.recovery_model.map(LLMModel::from),
        };

        let mut hook_registry = config.hook_registry.unwrap_or_default();
        hook_registry.register(
            LifecycleEvent::BeforeInference,
            Box::new(InlineScratchpadContextHook::new(
                InlineScratchpadContextHookOptions {
                    model_options: model_options.clone(),
                    history_action_message_size_limit: Some(100),
                    history_action_message_keep_last_n: Some(1),
                    history_action_result_keep_last_n: Some(50),
                },
            )),
        );
        Ok(Self {
            db: conn,
            stakpak_base_url: config.stakpak_base_url.map(|url| url + "/v1"),
            providers: config.providers,
            model_options,
            hook_registry: Some(Arc::new(hook_registry)),
            _search_services_orchestrator: Some(Arc::new(SearchServicesOrchestrator)),
        })
    }
}
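
// A minimal construction sketch (hypothetical values; `providers` is an
// LLMProviderConfig built elsewhere in the application):
//
//     let client = LocalClient::new(LocalClientConfig {
//         stakpak_base_url: None,   // offline: rulebook listings come back empty
//         store_path: None,         // defaults to ~/.stakpak/data/local.db
//         providers,
//         smart_model: None,        // falls back per ModelOptions -> ModelSet
//         eco_model: None,
//         recovery_model: None,
//         hook_registry: None,      // a default registry is created
//     })
//     .await?;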

#[async_trait]
impl AgentProvider for LocalClient {
    async fn get_my_account(&self) -> Result<GetMyAccountResponse, String> {
        Ok(GetMyAccountResponse {
            username: "local".to_string(),
            id: "local".to_string(),
            first_name: "local".to_string(),
            last_name: "local".to_string(),
            email: "local@stakpak.dev".to_string(),
            scope: None,
        })
    }

    async fn get_billing_info(
        &self,
        _account_username: &str,
    ) -> Result<stakpak_shared::models::billing::BillingResponse, String> {
        Err("Billing info not supported in local mode".to_string())
    }

    async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
        let Some(stakpak_base_url) = self.stakpak_base_url.as_ref() else {
            return Ok(vec![]);
        };

        let url = format!("{}/rules", stakpak_base_url);

        let client = create_tls_client(
            TlsClientConfig::default().with_timeout(std::time::Duration::from_secs(300)),
        )?;

        let response = client
            .get(url)
            .send()
            .await
            .map_err(|e: ReqwestError| e.to_string())?;

        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;

        match serde_json::from_value::<ListRulebooksResponse>(value.clone()) {
            Ok(response) => Ok(response.results),
            Err(e) => {
                eprintln!("Failed to deserialize response: {}", e);
                eprintln!("Raw response: {}", value);
                Err(format!("Failed to deserialize response: {}", e))
            }
        }
    }

    async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
        let stakpak_base_url = self
            .stakpak_base_url
            .as_ref()
            .ok_or("Stakpak base URL not set")?;

        let encoded_uri = urlencoding::encode(uri);

        let url = format!("{}/rules/{}", stakpak_base_url, encoded_uri);

        let client = create_tls_client(
            TlsClientConfig::default().with_timeout(std::time::Duration::from_secs(300)),
        )?;

        let response = client
            .get(&url)
            .send()
            .await
            .map_err(|e: ReqwestError| e.to_string())?;

        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;

        match serde_json::from_value::<RuleBook>(value) {
            Ok(response) => Ok(response),
            Err(e) => Err(format!("Failed to deserialize response: {}", e)),
        }
    }

    async fn create_rulebook(
        &self,
        _uri: &str,
        _description: &str,
        _content: &str,
        _tags: Vec<String>,
        _visibility: Option<RuleBookVisibility>,
    ) -> Result<CreateRuleBookResponse, String> {
        Err("Local provider does not support rulebooks yet".to_string())
    }

    async fn delete_rulebook(&self, _uri: &str) -> Result<(), String> {
        Err("Local provider does not support rulebooks yet".to_string())
    }

    async fn list_agent_sessions(&self) -> Result<Vec<AgentSession>, String> {
        db::list_sessions(&self.db).await
    }

    async fn get_agent_session(&self, session_id: Uuid) -> Result<AgentSession, String> {
        db::get_session(&self.db, session_id).await
    }

    async fn get_agent_session_stats(
        &self,
        _session_id: Uuid,
    ) -> Result<AgentSessionStats, String> {
        Ok(AgentSessionStats::default())
    }

    async fn get_agent_checkpoint(&self, checkpoint_id: Uuid) -> Result<RunAgentOutput, String> {
        db::get_checkpoint(&self.db, checkpoint_id).await
    }

    async fn get_agent_session_latest_checkpoint(
        &self,
        session_id: Uuid,
    ) -> Result<RunAgentOutput, String> {
        db::get_latest_checkpoint(&self.db, session_id).await
    }

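    // Non-streaming completion. The full lifecycle is:
    //   1. BeforeRequest hooks run against the incoming messages.
    //   2. The session is loaded or created from the last embedded checkpoint id.
    //   3. run_agent_completion performs inference (wrapped by BeforeInference /
    //      AfterInference hooks).
    //   4. A new checkpoint is persisted and AfterRequest hooks run.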
    async fn chat_completion(
        &self,
        model: AgentModel,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
    ) -> Result<ChatCompletionResponse, String> {
        let mut ctx = HookContext::new(None, AgentState::new(model, messages, tools));

        if let Some(hook_registry) = &self.hook_registry {
            hook_registry
                .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        let current_checkpoint = self.initialize_session(&ctx.state.messages).await?;
        ctx.set_session_id(current_checkpoint.session.id);

        let new_message = self.run_agent_completion(&mut ctx, None).await?;
        ctx.state.append_new_message(new_message.clone());

        let result = self
            .update_session(&current_checkpoint, ctx.state.messages.clone())
            .await?;
        let checkpoint_created_at = result.checkpoint.created_at.timestamp() as u64;
        ctx.set_new_checkpoint_id(result.checkpoint.id);

        if let Some(hook_registry) = &self.hook_registry {
            hook_registry
                .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        Ok(ChatCompletionResponse {
            id: ctx.new_checkpoint_id.unwrap().to_string(),
            object: "chat.completion".to_string(),
            created: checkpoint_created_at,
            model: ctx
                .state
                .llm_input
                .as_ref()
                .map(|llm_input| llm_input.model.to_string())
                .unwrap_or_default(),
            choices: vec![ChatCompletionChoice {
                index: 0,
                message: ctx.state.messages.last().cloned().unwrap(),
                logprobs: None,
                finish_reason: FinishReason::Stop,
            }],
            usage: ctx
                .state
                .llm_output
                .as_ref()
                .map(|u| u.usage.clone())
                .unwrap_or_default(),
            system_fingerprint: None,
            metadata: None,
        })
    }

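    // Streaming completion. Inference runs in a spawned task that sends
    // StreamMessage values over an mpsc channel: Delta items are forwarded to
    // the consumer as chat.completion.chunk responses, while Ctx items carry
    // the updated hook context (session id, checkpoint id) back into the
    // stream so AfterRequest hooks see the final state. Checkpoint ids are
    // emitted inline as <checkpoint_id> markers framing the generation.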
    async fn chat_completion_stream(
        &self,
        model: AgentModel,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
        _headers: Option<HeaderMap>,
    ) -> Result<
        (
            Pin<
                Box<dyn Stream<Item = Result<ChatCompletionStreamResponse, ApiStreamError>> + Send>,
            >,
            Option<String>,
        ),
        String,
    > {
        let mut ctx = HookContext::new(None, AgentState::new(model, messages, tools));

        if let Some(hook_registry) = &self.hook_registry {
            hook_registry
                .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        let current_checkpoint = self.initialize_session(&ctx.state.messages).await?;
        ctx.set_session_id(current_checkpoint.session.id);

        let (tx, mut rx) = mpsc::channel::<Result<StreamMessage, String>>(100);

        let _ = tx
            .send(Ok(StreamMessage::Delta(GenerationDelta::Content {
                content: format!(
                    "\n<checkpoint_id>{}</checkpoint_id>\n",
                    current_checkpoint.checkpoint.id
                ),
            })))
            .await;

        let client = self.clone();
        let self_clone = self.clone();
        let mut ctx_clone = ctx.clone();
        tokio::spawn(async move {
            let result = client
                .run_agent_completion(&mut ctx_clone, Some(tx.clone()))
                .await;

            match result {
                Err(e) => {
                    let _ = tx.send(Err(e)).await;
                }
                Ok(new_message) => {
                    ctx_clone.state.append_new_message(new_message.clone());
                    let _ = tx
                        .send(Ok(StreamMessage::Ctx(Box::new(ctx_clone.clone()))))
                        .await;

                    let output = self_clone
                        .update_session(&current_checkpoint, ctx_clone.state.messages.clone())
                        .await;

                    match output {
                        Err(e) => {
                            let _ = tx.send(Err(e)).await;
                        }
                        Ok(output) => {
                            ctx_clone.set_new_checkpoint_id(output.checkpoint.id);
                            let _ = tx.send(Ok(StreamMessage::Ctx(Box::new(ctx_clone)))).await;
                            let _ = tx
                                .send(Ok(StreamMessage::Delta(GenerationDelta::Content {
                                    content: format!(
                                        "\n<checkpoint_id>{}</checkpoint_id>\n",
                                        output.checkpoint.id
                                    ),
                                })))
                                .await;
                        }
                    }
                }
            }
        });

        let hook_registry = self.hook_registry.clone();
        let stream = async_stream::stream! {
            while let Some(delta_result) = rx.recv().await {
                match delta_result {
                    Ok(delta) => match delta {
                        StreamMessage::Ctx(updated_ctx) => {
                            ctx = *updated_ctx;
                        }
                        StreamMessage::Delta(delta) => {
                            yield Ok(ChatCompletionStreamResponse {
                                id: ctx.request_id.to_string(),
                                object: "chat.completion.chunk".to_string(),
                                created: chrono::Utc::now().timestamp() as u64,
                                model: ctx.state.llm_input.as_ref().map(|llm_input| llm_input.model.to_string()).unwrap_or_default(),
                                choices: vec![ChatCompletionStreamChoice {
                                    index: 0,
                                    delta: delta.into(),
                                    finish_reason: None,
                                }],
                                usage: ctx.state.llm_output.as_ref().map(|u| u.usage.clone()),
                                metadata: None,
                            })
                        }
                    },
                    Err(e) => yield Err(ApiStreamError::Unknown(e)),
                }
            }

            if let Some(hook_registry) = hook_registry {
                // `?` cannot propagate out of a `stream!` block, so hook
                // failures are surfaced to the consumer as a final error item.
                if let Err(e) = hook_registry
                    .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
                    .await
                    .map_err(|e| e.to_string())
                    .and_then(|outcome| outcome.ok())
                {
                    yield Err(ApiStreamError::Unknown(e));
                }
            }
        };

        Ok((Box::pin(stream), None))
    }
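
    // A minimal sketch of consuming chat_completion_stream (hypothetical
    // caller; `futures_util::StreamExt` provides `next`):
    //
    //     let (mut stream, _) = client
    //         .chat_completion_stream(AgentModel::Smart, messages, None, None)
    //         .await?;
    //     while let Some(chunk) = stream.next().await {
    //         if let Ok(resp) = chunk { /* render resp.choices[0].delta */ }
    //     }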

    async fn cancel_stream(&self, _request_id: String) -> Result<(), String> {
        Ok(())
    }

    async fn search_docs(&self, input: &SearchDocsRequest) -> Result<Vec<Content>, String> {
        let config = SearchServicesOrchestrator::start()
            .await
            .map_err(|e| e.to_string())?;

        let api_url = format!("http://localhost:{}", config.api_port);
        let search_client = SearchClient::new(api_url);

        let initial_query = if let Some(exclude) = &input.exclude_keywords {
            format!("{} -{}", input.keywords, exclude)
        } else {
            input.keywords.clone()
        };

        let llm_config = self.get_llm_config();
        let search_model = get_search_model(
            &llm_config,
            self.model_options.eco_model.clone(),
            self.model_options.smart_model.clone(),
        );

        let analysis = analyze_search_query(&llm_config, &search_model, &initial_query).await?;
        let required_documentation = analysis.required_documentation;
        let mut current_query = analysis.reformulated_query;
        let mut previous_queries = Vec::new();
        let mut final_valid_docs = Vec::new();
        let mut accumulated_needed_urls = Vec::new();

        const MAX_ITERATIONS: usize = 3;

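        // Iterative refinement: search, validate the results with an LLM, and
        // either stop (satisfied, empty results, or a repeated/absent follow-up
        // query) or retry with the reformulated query, up to MAX_ITERATIONS.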
        for _iteration in 0..MAX_ITERATIONS {
            previous_queries.push(current_query.clone());

            let search_results = search_client
                .search_and_scrape(current_query.clone(), None)
                .await
                .map_err(|e| e.to_string())?;

            if search_results.is_empty() {
                break;
            }

            let validation_result = validate_search_docs(
                &llm_config,
                &search_model,
                &search_results,
                &current_query,
                &required_documentation,
                &previous_queries,
                &accumulated_needed_urls,
            )
            .await?;

            for url in &validation_result.needed_urls {
                if !accumulated_needed_urls.contains(url) {
                    accumulated_needed_urls.push(url.clone());
                }
            }

            for doc in validation_result.valid_docs.into_iter() {
                let is_duplicate = final_valid_docs
                    .iter()
                    .any(|existing_doc: &ScrapedContent| existing_doc.url == doc.url);

                if !is_duplicate {
                    final_valid_docs.push(doc);
                }
            }

            if validation_result.is_satisfied {
                break;
            }

            if let Some(new_query) = validation_result.new_query {
                if new_query != current_query && !previous_queries.contains(&new_query) {
                    current_query = new_query;
                } else {
                    break;
                }
            } else {
                break;
            }
        }

        if final_valid_docs.is_empty() {
            return Ok(vec![Content::text("No results found".to_string())]);
        }

        let contents: Vec<Content> = final_valid_docs
            .into_iter()
            .map(|result| {
                let content = result.content.unwrap_or_default();
                Content::text(format!("URL: {}\nContent: {}", result.url, content))
            })
            .collect();

        Ok(contents)
    }

    async fn search_memory(&self, _input: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
        Ok(Vec::new())
    }

    async fn slack_read_messages(
        &self,
        _input: &SlackReadMessagesRequest,
    ) -> Result<Vec<Content>, String> {
        Ok(Vec::new())
    }

    async fn slack_read_replies(
        &self,
        _input: &SlackReadRepliesRequest,
    ) -> Result<Vec<Content>, String> {
        Ok(Vec::new())
    }

    async fn slack_send_message(
        &self,
        _input: &SlackSendMessageRequest,
    ) -> Result<Vec<Content>, String> {
        Ok(Vec::new())
    }

    async fn memorize_session(&self, _checkpoint_id: Uuid) -> Result<(), String> {
        Ok(())
    }
}

impl LocalClient {
    fn get_llm_config(&self) -> LLMProviderConfig {
        self.providers.clone()
    }

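    // Core inference step shared by both completion paths. When a stream
    // channel is supplied, deltas from StakAI are forwarded over it while the
    // full response is awaited concurrently via tokio::join!; otherwise a
    // single blocking chat call is made.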
    async fn run_agent_completion(
        &self,
        ctx: &mut HookContext<AgentState>,
        stream_channel_tx: Option<mpsc::Sender<Result<StreamMessage, String>>>,
    ) -> Result<ChatMessage, String> {
        if let Some(hook_registry) = &self.hook_registry {
            hook_registry
                .execute_hooks(ctx, &LifecycleEvent::BeforeInference)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        let input = if let Some(llm_input) = ctx.state.llm_input.clone() {
            llm_input
        } else {
            return Err(
                "Run agent completion: LLM input not found, make sure to register a context hook before inference"
                    .to_string(),
            );
        };

        let llm_config = self.get_llm_config();
        let stakai_client = StakAIClient::new(&llm_config)
            .map_err(|e| format!("Failed to create StakAI client: {}", e))?;

        let (response_message, usage) = if let Some(tx) = stream_channel_tx {
            let (internal_tx, mut internal_rx) = mpsc::channel::<GenerationDelta>(100);
            let input = LLMStreamInput {
                model: input.model,
                messages: input.messages,
                max_tokens: input.max_tokens,
                tools: input.tools,
                stream_channel_tx: internal_tx,
                provider_options: input.provider_options,
            };

            let chat_future = async move {
                stakai_client
                    .chat_stream(input)
                    .await
                    .map_err(|e| e.to_string())
            };

            let receive_future = async move {
                while let Some(delta) = internal_rx.recv().await {
                    if tx.send(Ok(StreamMessage::Delta(delta))).await.is_err() {
                        break;
                    }
                }
            };

            let (chat_result, _) = tokio::join!(chat_future, receive_future);
            let response = chat_result?;
            (response.choices[0].message.clone(), response.usage)
        } else {
            let response = stakai_client.chat(input).await.map_err(|e| e.to_string())?;
            (response.choices[0].message.clone(), response.usage)
        };

        ctx.state.set_llm_output(response_message, usage);

        if let Some(hook_registry) = &self.hook_registry {
            hook_registry
                .execute_hooks(ctx, &LifecycleEvent::AfterInference)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        let llm_output = ctx
            .state
            .llm_output
            .as_ref()
            .ok_or_else(|| "LLM output is missing from state".to_string())?;

        Ok(ChatMessage::from(llm_output))
    }

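    // Resolves the session for this request: if the last server message embeds
    // a <checkpoint_id> marker, the matching checkpoint is loaded; otherwise a
    // new session is created with an LLM-generated title and an initial
    // checkpoint holding the incoming messages.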
    async fn initialize_session(&self, messages: &[ChatMessage]) -> Result<RunAgentOutput, String> {
        if messages.is_empty() {
            return Err("At least one message is required".to_string());
        }

        let checkpoint_id = ChatMessage::last_server_message(messages).and_then(|message| {
            message
                .content
                .as_ref()
                .and_then(|content| content.extract_checkpoint_id())
        });

        let current_checkpoint = if let Some(checkpoint_id) = checkpoint_id {
            db::get_checkpoint(&self.db, checkpoint_id).await?
        } else {
            let title = self.generate_session_title(messages).await?;

            let session_id = Uuid::new_v4();
            let now = chrono::Utc::now();
            let session = AgentSession {
                id: session_id,
                title,
                agent_id: AgentID::PabloV1,
                visibility: AgentSessionVisibility::Private,
                created_at: now,
                updated_at: now,
                checkpoints: vec![],
            };
            db::create_session(&self.db, &session).await?;

            let checkpoint_id = Uuid::new_v4();
            let checkpoint = AgentCheckpointListItem {
                id: checkpoint_id,
                status: AgentStatus::Complete,
                execution_depth: 0,
                parent: None,
                created_at: now,
                updated_at: now,
            };
            let initial_state = AgentOutput::PabloV1 {
                messages: messages.to_vec(),
                node_states: serde_json::json!({}),
            };
            db::create_checkpoint(&self.db, session_id, &checkpoint, &initial_state).await?;

            db::get_checkpoint(&self.db, checkpoint_id).await?
        };

        Ok(current_checkpoint)
    }

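    // Persists a new checkpoint that extends the chain: execution depth is
    // incremented and the parent pointer links back to the checkpoint this
    // run started from, so session history stays replayable.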
    async fn update_session(
        &self,
        checkpoint_info: &RunAgentOutput,
        new_messages: Vec<ChatMessage>,
    ) -> Result<RunAgentOutput, String> {
        let now = chrono::Utc::now();
        let complete_checkpoint = AgentCheckpointListItem {
            id: Uuid::new_v4(),
            status: AgentStatus::Complete,
            execution_depth: checkpoint_info.checkpoint.execution_depth + 1,
            parent: Some(AgentParentCheckpoint {
                id: checkpoint_info.checkpoint.id,
            }),
            created_at: now,
            updated_at: now,
        };

        let mut new_state = checkpoint_info.output.clone();
        new_state.set_messages(new_messages);

        db::create_checkpoint(
            &self.db,
            checkpoint_info.session.id,
            &complete_checkpoint,
            &new_state,
        )
        .await?;

        Ok(RunAgentOutput {
            checkpoint: complete_checkpoint,
            session: checkpoint_info.session.clone(),
            output: new_state,
        })
    }

    async fn generate_session_title(&self, messages: &[ChatMessage]) -> Result<String, String> {
        let llm_config = self.get_llm_config();

        let llm_model = if let Some(eco_model) = &self.model_options.eco_model {
            eco_model.clone()
        } else if llm_config.get_provider("openai").is_some() {
            LLMModel::OpenAI(OpenAIModel::GPT5Mini)
        } else if llm_config.get_provider("anthropic").is_some() {
            LLMModel::Anthropic(AnthropicModel::Claude45Haiku)
        } else if llm_config.get_provider("gemini").is_some() {
            LLMModel::Gemini(GeminiModel::Gemini25Flash)
        } else {
            return Err("No LLM config found".to_string());
        };

        let messages = vec![
            LLMMessage {
                role: "system".to_string(),
                content: LLMMessageContent::String(TITLE_GENERATOR_PROMPT.into()),
            },
            LLMMessage {
                role: "user".to_string(),
                content: LLMMessageContent::String(
                    messages
                        .iter()
                        .map(|msg| {
                            msg.content
                                .as_ref()
                                .unwrap_or(&MessageContent::String("".to_string()))
                                .to_string()
                        })
                        .collect(),
                ),
            },
        ];

        let input = LLMInput {
            model: llm_model,
            messages,
            max_tokens: 100,
            tools: None,
            provider_options: None,
        };

        let stakai_client = StakAIClient::new(&llm_config)
            .map_err(|e| format!("Failed to create StakAI client: {}", e))?;
        let response = stakai_client.chat(input).await.map_err(|e| e.to_string())?;

        Ok(response.choices[0].message.content.to_string())
    }
}

async fn analyze_search_query(
    llm_config: &LLMProviderConfig,
    model: &LLMModel,
    query: &str,
) -> Result<AnalysisResult, String> {
    let system_prompt = r#"You are an expert search query analyzer specializing in technical documentation retrieval.

## Your Task

Analyze the user's search query to:
1. Identify the specific types of documentation needed
2. Reformulate the query for optimal search engine results

## Guidelines for Required Documentation

Identify specific documentation types such as:
- API references and specifications
- Installation/setup guides
- Configuration documentation
- Tutorials and getting started guides
- Troubleshooting guides
- Architecture/design documents
- CLI/command references
- SDK/library documentation

## Guidelines for Query Reformulation

Create an optimized search query that:
- Uses specific technical terminology
- Includes relevant keywords (e.g., "documentation", "guide", "API")
- Removes ambiguous or filler words
- Targets authoritative sources when possible
- Is concise but comprehensive (5-10 words ideal)

## Response Format

Respond ONLY with valid XML in this exact structure:

<analysis>
    <required_documentation>
        <item>specific documentation type needed</item>
    </required_documentation>
    <reformulated_query>optimized search query string</reformulated_query>
</analysis>"#;

    let user_prompt = format!(
        r#"<user_query>{}</user_query>

Analyze this query and provide the required documentation types and an optimized search query."#,
        query
    );

    let input = LLMInput {
        model: model.clone(),
        messages: vec![
            LLMMessage {
                role: Role::System.to_string(),
                content: LLMMessageContent::String(system_prompt.to_string()),
            },
            LLMMessage {
                role: Role::User.to_string(),
                content: LLMMessageContent::String(user_prompt.to_string()),
            },
        ],
        max_tokens: 2000,
        tools: None,
        provider_options: None,
    };

    let stakai_client = StakAIClient::new(llm_config)
        .map_err(|e| format!("Failed to create StakAI client: {}", e))?;
    let response = stakai_client.chat(input).await.map_err(|e| e.to_string())?;

    let content = response.choices[0].message.content.to_string();

    parse_analysis_xml(&content)
}

fn parse_analysis_xml(xml: &str) -> Result<AnalysisResult, String> {
    let extract_tag = |tag: &str| -> Option<String> {
        let start_tag = format!("<{}>", tag);
        let end_tag = format!("</{}>", tag);
        xml.find(&start_tag).and_then(|start| {
            let content_start = start + start_tag.len();
            xml[content_start..]
                .find(&end_tag)
                .map(|end| xml[content_start..content_start + end].trim().to_string())
        })
    };

    let extract_all_tags = |tag: &str| -> Vec<String> {
        let start_tag = format!("<{}>", tag);
        let end_tag = format!("</{}>", tag);
        let mut results = Vec::new();
        let mut search_start = 0;

        while let Some(start) = xml[search_start..].find(&start_tag) {
            let abs_start = search_start + start + start_tag.len();
            if let Some(end) = xml[abs_start..].find(&end_tag) {
                results.push(xml[abs_start..abs_start + end].trim().to_string());
                search_start = abs_start + end + end_tag.len();
            } else {
                break;
            }
        }
        results
    };

    let required_documentation = extract_all_tags("item");
    let reformulated_query =
        extract_tag("reformulated_query").ok_or("Failed to extract reformulated_query from XML")?;

    Ok(AnalysisResult {
        required_documentation,
        reformulated_query,
    })
}
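
// A quick sketch of the expected round-trip (hypothetical input mirroring the
// <analysis> format requested in the system prompt above):
//
//     let xml = "<analysis>\
//         <required_documentation><item>API references</item></required_documentation>\
//         <reformulated_query>tokio mpsc channel documentation</reformulated_query>\
//     </analysis>";
//     let analysis = parse_analysis_xml(xml)?;
//     // analysis.required_documentation == vec!["API references".to_string()]
//     // analysis.reformulated_query == "tokio mpsc channel documentation"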

async fn validate_search_docs(
    llm_config: &LLMProviderConfig,
    model: &LLMModel,
    docs: &[ScrapedContent],
    query: &str,
    required_documentation: &[String],
    previous_queries: &[String],
    accumulated_needed_urls: &[String],
) -> Result<ValidationResult, String> {
    let docs_preview = docs
        .iter()
        .enumerate()
        .take(10)
        .map(|(i, r)| {
            format!(
                "<doc index=\"{}\">\n    <title>{}</title>\n    <url>{}</url>\n</doc>",
                i + 1,
                r.title.clone().unwrap_or_else(|| "Untitled".to_string()),
                r.url
            )
        })
        .collect::<Vec<_>>()
        .join("\n");

    let required_docs_formatted = required_documentation
        .iter()
        .map(|d| format!("    <item>{}</item>", d))
        .collect::<Vec<_>>()
        .join("\n");

    let previous_queries_formatted = previous_queries
        .iter()
        .map(|q| format!("    <query>{}</query>", q))
        .collect::<Vec<_>>()
        .join("\n");

    let accumulated_urls_formatted = accumulated_needed_urls
        .iter()
        .map(|u| format!("    <url>{}</url>", u))
        .collect::<Vec<_>>()
        .join("\n");

    let system_prompt = r#"You are an expert search result validator. Your task is to evaluate whether search results adequately satisfy a documentation query.

## Evaluation Criteria

For each search result, assess:
1. **Relevance**: Does the document directly address the required documentation topics?
2. **Authority**: Is this an official source, documentation site, or authoritative reference?
3. **Completeness**: Does it provide comprehensive information, not just passing mentions?
4. **Freshness**: For technical docs, prefer current/maintained sources over outdated ones.

## Decision Guidelines

Mark results as SATISFIED when:
- All required documentation topics have at least one authoritative source
- The sources provide actionable, detailed information
- No critical gaps remain in coverage

Suggest a NEW QUERY when:
- Key topics are missing from results
- Results are too general or tangential
- A more specific query would yield better results
- Previous queries haven't addressed certain requirements

## Response Format

Respond ONLY with valid XML in this exact structure:

<validation>
    <is_satisfied>true or false</is_satisfied>
    <valid_docs>
        <doc><url>exact URL from results</url></doc>
    </valid_docs>
    <needed_urls>
        <url>specific URL pattern or domain still needed</url>
    </needed_urls>
    <new_query>refined search query if not satisfied, omit if satisfied</new_query>
    <reasoning>brief explanation of your assessment</reasoning>
</validation>"#;

    let user_prompt = format!(
        r#"<search_context>
    <original_query>{}</original_query>
    <required_documentation>
{}
    </required_documentation>
    <previous_queries>
{}
    </previous_queries>
    <accumulated_needed_urls>
{}
    </accumulated_needed_urls>
</search_context>

<current_results>
{}
</current_results>

Evaluate these search results against the requirements. Which documents are valid and relevant? Is the documentation requirement satisfied? If not, what specific query would help find missing information?"#,
        query,
        if required_docs_formatted.is_empty() {
            "    <item>None specified</item>".to_string()
        } else {
            required_docs_formatted
        },
        if previous_queries_formatted.is_empty() {
            "    <query>None</query>".to_string()
        } else {
            previous_queries_formatted
        },
        if accumulated_urls_formatted.is_empty() {
            "    <url>None</url>".to_string()
        } else {
            accumulated_urls_formatted
        },
        docs_preview
    );

    let input = LLMInput {
        model: model.clone(),
        messages: vec![
            LLMMessage {
                role: Role::System.to_string(),
                content: LLMMessageContent::String(system_prompt.to_string()),
            },
            LLMMessage {
                role: Role::User.to_string(),
                content: LLMMessageContent::String(user_prompt.to_string()),
            },
        ],
        max_tokens: 4000,
        tools: None,
        provider_options: None,
    };

    let stakai_client = StakAIClient::new(llm_config)
        .map_err(|e| format!("Failed to create StakAI client: {}", e))?;
    let response = stakai_client.chat(input).await.map_err(|e| e.to_string())?;

    let content = response.choices[0].message.content.to_string();

    let validation = parse_validation_xml(&content, docs)?;

    Ok(validation)
}

fn parse_validation_xml(xml: &str, docs: &[ScrapedContent]) -> Result<ValidationResult, String> {
    let extract_tag = |tag: &str| -> Option<String> {
        let start_tag = format!("<{}>", tag);
        let end_tag = format!("</{}>", tag);
        xml.find(&start_tag).and_then(|start| {
            let content_start = start + start_tag.len();
            xml[content_start..]
                .find(&end_tag)
                .map(|end| xml[content_start..content_start + end].trim().to_string())
        })
    };

    let extract_all_tags = |tag: &str| -> Vec<String> {
        let start_tag = format!("<{}>", tag);
        let end_tag = format!("</{}>", tag);
        let mut results = Vec::new();
        let mut search_start = 0;

        while let Some(start) = xml[search_start..].find(&start_tag) {
            let abs_start = search_start + start + start_tag.len();
            if let Some(end) = xml[abs_start..].find(&end_tag) {
                results.push(xml[abs_start..abs_start + end].trim().to_string());
                search_start = abs_start + end + end_tag.len();
            } else {
                break;
            }
        }
        results
    };

    let is_satisfied = extract_tag("is_satisfied")
        .map(|s| s.to_lowercase() == "true")
        .unwrap_or(false);

    let valid_urls: Vec<String> = extract_all_tags("url")
        .into_iter()
        .filter(|url| docs.iter().any(|d| d.url == *url))
        .collect();

    let valid_docs: Vec<ScrapedContent> = valid_urls
        .iter()
        .filter_map(|url| docs.iter().find(|d| d.url == *url).cloned())
        .collect();

    let needed_urls: Vec<String> = extract_all_tags("url")
        .into_iter()
        .filter(|url| !docs.iter().any(|d| d.url == *url))
        .collect();

    let new_query = extract_tag("new_query").filter(|q| !q.is_empty() && q != "omit if satisfied");

    Ok(ValidationResult {
        is_satisfied,
        valid_docs,
        needed_urls,
        new_query,
    })
}
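
// Note on the <url> handling above: extract_all_tags("url") collects URLs from
// both <valid_docs> and <needed_urls> indiscriminately, so the two lists are
// recovered by membership instead: URLs matching a scraped doc are treated as
// validated docs, and the rest as still-needed URLs.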

fn get_search_model(
    llm_config: &LLMProviderConfig,
    eco_model: Option<LLMModel>,
    smart_model: Option<LLMModel>,
) -> LLMModel {
    let base_model = eco_model.or(smart_model);

    match base_model {
        Some(LLMModel::OpenAI(_)) => LLMModel::OpenAI(OpenAIModel::O4Mini),
        Some(LLMModel::Anthropic(_)) => LLMModel::Anthropic(AnthropicModel::Claude45Haiku),
        Some(LLMModel::Gemini(_)) => LLMModel::Gemini(GeminiModel::Gemini3Flash),
        Some(LLMModel::Custom { provider, model }) => LLMModel::Custom { provider, model },
        None => {
            if llm_config.get_provider("openai").is_some() {
                LLMModel::OpenAI(OpenAIModel::O4Mini)
            } else if llm_config.get_provider("anthropic").is_some() {
                LLMModel::Anthropic(AnthropicModel::Claude45Haiku)
            } else if llm_config.get_provider("gemini").is_some() {
                LLMModel::Gemini(GeminiModel::Gemini3Flash)
            } else {
                LLMModel::OpenAI(OpenAIModel::O4Mini)
            }
        }
    }
}
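
// Selection sketch: the provider family of the configured eco model (falling
// back to the smart model, then to whatever provider is configured) picks a
// cheap search-tier model from that family. For example:
//
//     let m = get_search_model(&llm_config, None, Some(LLMModel::OpenAI(OpenAIModel::GPT5)));
//     // m == LLMModel::OpenAI(OpenAIModel::O4Mini)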