stakpak_api/local/
mod.rs

1// use crate::local::hooks::file_scratchpad_context::{
2//     FileScratchpadContextHook, FileScratchpadContextHookOptions,
3// };
4use crate::local::hooks::inline_scratchpad_context::{
5    InlineScratchpadContextHook, InlineScratchpadContextHookOptions,
6};
7use crate::{AgentProvider, ApiStreamError, GetMyAccountResponse};
8use crate::{ListRuleBook, models::*};
9use async_trait::async_trait;
10use futures_util::Stream;
11use libsql::{Builder, Connection};
12use reqwest::Error as ReqwestError;
13use reqwest::header::HeaderMap;
14use rmcp::model::Content;
15use stakpak_shared::hooks::{HookContext, HookRegistry, LifecycleEvent};
16use stakpak_shared::models::integrations::anthropic::{AnthropicConfig, AnthropicModel};
17use stakpak_shared::models::integrations::gemini::{GeminiConfig, GeminiModel};
18use stakpak_shared::models::integrations::openai::{
19    AgentModel, ChatCompletionChoice, ChatCompletionResponse, ChatCompletionStreamChoice,
20    ChatCompletionStreamResponse, ChatMessage, FinishReason, MessageContent, OpenAIConfig,
21    OpenAIModel, Role, Tool,
22};
23use stakpak_shared::models::integrations::search_service::*;
24use stakpak_shared::models::llm::{
25    GenerationDelta, LLMInput, LLMMessage, LLMMessageContent, LLMModel, LLMProviderConfig,
26    LLMStreamInput, chat, chat_stream,
27};
28use stakpak_shared::tls_client::{TlsClientConfig, create_tls_client};
29use std::pin::Pin;
30use std::sync::Arc;
31use tokio::sync::mpsc;
32use uuid::Uuid;
33
34mod context_managers;
35mod db;
36mod hooks;
37
38#[cfg(test)]
39mod tests;
40
/// Local, offline-capable agent provider: persists sessions and checkpoints in
/// a libsql database and talks to the LLM providers directly instead of going
/// through the hosted Stakpak API.
#[derive(Clone, Debug)]
pub struct LocalClient {
    /// Connection to the local libsql store (sessions + checkpoints).
    pub db: Connection,
    /// Optional remote Stakpak base URL (already suffixed with "/v1" by
    /// `new`); in this module it is only used by the rulebook endpoints.
    pub stakpak_base_url: Option<String>,
    /// Per-provider credentials/configs; any subset may be present.
    pub anthropic_config: Option<AnthropicConfig>,
    pub openai_config: Option<OpenAIConfig>,
    pub gemini_config: Option<GeminiConfig>,
    /// User-selected model overrides; `None` fields fall back to defaults
    /// (see `From<ModelOptions> for ModelSet`).
    pub model_options: ModelOptions,
    /// Lifecycle hooks run around requests/inference; shared across clones.
    pub hook_registry: Option<Arc<HookRegistry<AgentState>>>,
    // Kept alive for the lifetime of the client; currently unused directly.
    _search_services_orchestrator: Option<Arc<SearchServicesOrchestrator>>,
}
52
/// Optional per-tier model overrides. A `None` field means "use the default"
/// chosen in `From<ModelOptions> for ModelSet`.
#[derive(Clone, Debug)]
pub struct ModelOptions {
    pub smart_model: Option<LLMModel>,
    pub eco_model: Option<LLMModel>,
    pub recovery_model: Option<LLMModel>,
}
59
/// Fully-resolved model selection: every tier has a concrete `LLMModel`.
/// Produced from `ModelOptions` via the `From` impl below.
#[derive(Clone, Debug)]
pub struct ModelSet {
    pub smart_model: LLMModel,
    pub eco_model: LLMModel,
    pub recovery_model: LLMModel,
    pub hook_registry: Option<Arc<HookRegistry<AgentState>>>,
    pub _search_services_orchestrator: Option<Arc<SearchServicesOrchestrator>>,
}
68
69impl ModelSet {
70    fn get_model(&self, agent_model: &AgentModel) -> LLMModel {
71        match agent_model {
72            AgentModel::Smart => self.smart_model.clone(),
73            AgentModel::Eco => self.eco_model.clone(),
74            AgentModel::Recovery => self.recovery_model.clone(),
75        }
76    }
77}
78
79impl From<ModelOptions> for ModelSet {
80    fn from(value: ModelOptions) -> Self {
81        let smart_model = value
82            .smart_model
83            .unwrap_or(LLMModel::Anthropic(AnthropicModel::Claude45Sonnet));
84        let eco_model = value
85            .eco_model
86            .unwrap_or(LLMModel::Anthropic(AnthropicModel::Claude45Haiku));
87        let recovery_model = value
88            .recovery_model
89            .unwrap_or(LLMModel::OpenAI(OpenAIModel::GPT5));
90
91        Self {
92            smart_model,
93            eco_model,
94            recovery_model,
95            hook_registry: None,
96            _search_services_orchestrator: None,
97        }
98    }
99}
100
/// Configuration consumed by `LocalClient::new`.
pub struct LocalClientConfig {
    /// Remote Stakpak base URL (without the "/v1" suffix; `new` appends it).
    pub stakpak_base_url: Option<String>,
    /// Database file path; defaults to `~/.stakpak/data/local.db` when `None`.
    pub store_path: Option<String>,
    pub anthropic_config: Option<AnthropicConfig>,
    pub openai_config: Option<OpenAIConfig>,
    pub gemini_config: Option<GeminiConfig>,
    /// Model names, converted via `LLMModel::from`; `None` → built-in default.
    pub smart_model: Option<String>,
    pub eco_model: Option<String>,
    pub recovery_model: Option<String>,
    /// Caller-supplied hooks; `new` adds the default context hooks on top.
    pub hook_registry: Option<HookRegistry<AgentState>>,
}
112
/// Internal message for the streaming channel: either a generation delta to
/// forward to the consumer, or an updated hook context so the stream side can
/// stay in sync with the spawned inference task (boxed to keep the enum small).
#[derive(Debug)]
enum StreamMessage {
    Delta(GenerationDelta),
    Ctx(Box<HookContext<AgentState>>),
}
118
/// Default on-disk location of the local libsql database, relative to $HOME.
const DEFAULT_STORE_PATH: &str = ".stakpak/data/local.db";
/// System prompt used to generate short session titles.
const TITLE_GENERATOR_PROMPT: &str = include_str!("./prompts/session_title_generator.v1.txt");
121
impl LocalClient {
    /// Build a `LocalClient`: resolve the database path, ensure its parent
    /// directory exists, open and initialize the libsql store, derive model
    /// options, and register the default context hook.
    ///
    /// Returns an error string if the directory cannot be created, the
    /// database cannot be opened/connected, or schema initialization fails.
    pub async fn new(config: LocalClientConfig) -> Result<Self, String> {
        // Resolve the store path; defaults to ~/.stakpak/data/local.db. If the
        // home directory cannot be determined, `unwrap_or_default()` yields an
        // empty base path, i.e. a path relative to the current directory.
        let store_path = config
            .store_path
            .map(std::path::PathBuf::from)
            .unwrap_or_else(|| {
                std::env::home_dir()
                    .unwrap_or_default()
                    .join(DEFAULT_STORE_PATH)
            });

        if let Some(parent) = store_path.parent() {
            std::fs::create_dir_all(parent)
                .map_err(|e| format!("Failed to create database directory: {}", e))?;
        }

        let db = Builder::new_local(store_path.display().to_string())
            .build()
            .await
            .map_err(|e| e.to_string())?;

        let conn = db.connect().map_err(|e| e.to_string())?;

        // Initialize database schema
        db::init_schema(&conn).await?;

        let model_options = ModelOptions {
            smart_model: config.smart_model.map(LLMModel::from),
            eco_model: config.eco_model.map(LLMModel::from),
            recovery_model: config.recovery_model.map(LLMModel::from),
        };

        // Add hooks
        // The inline scratchpad hook prepares the LLM input before each
        // inference round; `run_agent_completion` errors out if no context
        // hook populated `llm_input`.
        let mut hook_registry = config.hook_registry.unwrap_or_default();
        hook_registry.register(
            LifecycleEvent::BeforeInference,
            Box::new(InlineScratchpadContextHook::new(
                InlineScratchpadContextHookOptions {
                    model_options: model_options.clone(),
                    history_action_message_size_limit: Some(100),
                    history_action_message_keep_last_n: Some(1),
                    history_action_result_keep_last_n: Some(50),
                },
            )),
        );
        // hook_registry.register(
        //     LifecycleEvent::BeforeInference,
        //     Box::new(FileScratchpadContextHook::new(
        //         FileScratchpadContextHookOptions {
        //             history_action_message_size_limit: Some(100),
        //             history_action_message_keep_last_n: Some(1),
        //             history_action_result_keep_last_n: Some(50),
        //             scratchpad_path: None,
        //             todo_path: None,
        //             model_options: model_options.clone(),
        //             overwrite_if_different: Some(true),
        //         },
        //     )),
        // );

        Ok(Self {
            db: conn,
            // All remote endpoints used by this module live under /v1.
            stakpak_base_url: config.stakpak_base_url.map(|url| url + "/v1"),
            anthropic_config: config.anthropic_config,
            gemini_config: config.gemini_config,
            openai_config: config.openai_config,
            model_options,
            hook_registry: Some(Arc::new(hook_registry)),
            _search_services_orchestrator: Some(Arc::new(SearchServicesOrchestrator)),
        })
    }
}
194
195#[async_trait]
196impl AgentProvider for LocalClient {
197    async fn get_my_account(&self) -> Result<GetMyAccountResponse, String> {
198        Ok(GetMyAccountResponse {
199            username: "local".to_string(),
200            id: "local".to_string(),
201            first_name: "local".to_string(),
202            last_name: "local".to_string(),
203        })
204    }
205
206    async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
207        if self.stakpak_base_url.is_none() {
208            return Ok(vec![]);
209        }
210
211        let stakpak_base_url = self
212            .stakpak_base_url
213            .as_ref()
214            .ok_or("Stakpak base URL not set")?;
215
216        let url = format!("{}/rules", stakpak_base_url);
217
218        let client = create_tls_client(
219            TlsClientConfig::default().with_timeout(std::time::Duration::from_secs(300)),
220        )?;
221
222        let response = client
223            .get(url)
224            .send()
225            .await
226            .map_err(|e: ReqwestError| e.to_string())?;
227
228        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
229
230        match serde_json::from_value::<ListRulebooksResponse>(value.clone()) {
231            Ok(response) => Ok(response.results),
232            Err(e) => {
233                eprintln!("Failed to deserialize response: {}", e);
234                eprintln!("Raw response: {}", value);
235                Err("Failed to deserialize response:".into())
236            }
237        }
238    }
239
240    async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
241        let stakpak_base_url = self
242            .stakpak_base_url
243            .as_ref()
244            .ok_or("Stakpak base URL not set")?;
245
246        let encoded_uri = urlencoding::encode(uri);
247
248        let url = format!("{}/rules/{}", stakpak_base_url, encoded_uri);
249
250        let client = create_tls_client(
251            TlsClientConfig::default().with_timeout(std::time::Duration::from_secs(300)),
252        )?;
253
254        let response = client
255            .get(&url)
256            .send()
257            .await
258            .map_err(|e: ReqwestError| e.to_string())?;
259
260        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
261
262        match serde_json::from_value::<RuleBook>(value.clone()) {
263            Ok(response) => Ok(response),
264            Err(e) => {
265                eprintln!("Failed to deserialize response: {}", e);
266                eprintln!("Raw response: {}", value);
267                Err("Failed to deserialize response:".into())
268            }
269        }
270    }
271
272    async fn create_rulebook(
273        &self,
274        _uri: &str,
275        _description: &str,
276        _content: &str,
277        _tags: Vec<String>,
278        _visibility: Option<RuleBookVisibility>,
279    ) -> Result<CreateRuleBookResponse, String> {
280        // TODO: Implement create rulebook
281        Err("Local provider does not support rulebooks yet".to_string())
282    }
283
284    async fn delete_rulebook(&self, _uri: &str) -> Result<(), String> {
285        // TODO: Implement delete rulebook
286        Err("Local provider does not support rulebooks yet".to_string())
287    }
288
    /// List all sessions stored in the local database.
    async fn list_agent_sessions(&self) -> Result<Vec<AgentSession>, String> {
        db::list_sessions(&self.db).await
    }
292
    /// Load a single session by id from the local database.
    async fn get_agent_session(&self, session_id: Uuid) -> Result<AgentSession, String> {
        db::get_session(&self.db, session_id).await
    }
296
    /// Session statistics are not tracked locally yet; returns default
    /// (zeroed) stats rather than an error so callers keep working.
    async fn get_agent_session_stats(
        &self,
        _session_id: Uuid,
    ) -> Result<AgentSessionStats, String> {
        // TODO: Implement get agent session stats
        Ok(AgentSessionStats::default())
    }
304
    /// Load a single checkpoint (with its session and output) by id.
    async fn get_agent_checkpoint(&self, checkpoint_id: Uuid) -> Result<RunAgentOutput, String> {
        db::get_checkpoint(&self.db, checkpoint_id).await
    }
308
    /// Load the most recent checkpoint of the given session.
    async fn get_agent_session_latest_checkpoint(
        &self,
        session_id: Uuid,
    ) -> Result<RunAgentOutput, String> {
        db::get_latest_checkpoint(&self.db, session_id).await
    }
315
    /// Non-streaming chat completion against the local provider.
    ///
    /// Flow: BeforeRequest hooks → load-or-create session/checkpoint → one
    /// inference round → persist the extended transcript as a new checkpoint →
    /// AfterRequest hooks → shape the result as an OpenAI-style
    /// `ChatCompletionResponse` whose `id` is the new checkpoint id.
    async fn chat_completion(
        &self,
        model: AgentModel,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
    ) -> Result<ChatCompletionResponse, String> {
        let mut ctx = HookContext::new(None, AgentState::new(model, messages, tools));

        if let Some(hook_registry) = &self.hook_registry {
            hook_registry
                .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        // Resolve the session from a checkpoint id embedded in the messages,
        // or create a fresh session plus a root checkpoint.
        let current_checkpoint = self.initialize_session(&ctx.state.messages).await?;
        ctx.set_session_id(current_checkpoint.session.id);

        let new_message = self.run_agent_completion(&mut ctx, None).await?;
        ctx.state.append_new_message(new_message.clone());

        // Persist the updated transcript as a child checkpoint.
        let result = self
            .update_session(&current_checkpoint, ctx.state.messages.clone())
            .await?;
        let checkpoint_created_at = result.checkpoint.created_at.timestamp() as u64;
        ctx.set_new_checkpoint_id(result.checkpoint.id);

        if let Some(hook_registry) = &self.hook_registry {
            hook_registry
                .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        Ok(ChatCompletionResponse {
            // unwrap is safe: set_new_checkpoint_id was called just above.
            id: ctx.new_checkpoint_id.unwrap().to_string(),
            object: "chat.completion".to_string(),
            created: checkpoint_created_at,
            model: ctx
                .state
                .llm_input
                .as_ref()
                .map(|llm_input| llm_input.model.clone().to_string())
                .unwrap_or_default(),
            choices: vec![ChatCompletionChoice {
                index: 0,
                // NOTE(review): unwrap assumes append_new_message pushed onto
                // ctx.state.messages, so the list is non-empty here — confirm.
                message: ctx.state.messages.last().cloned().unwrap(),
                logprobs: None,
                finish_reason: FinishReason::Stop,
            }],
            usage: ctx
                .state
                .llm_output
                .as_ref()
                .map(|u| u.usage.clone())
                .unwrap_or_default(),
            system_fingerprint: None,
            metadata: None,
        })
    }
378
    /// Streaming chat completion.
    ///
    /// Spawns a background task that runs the inference and the session
    /// update, forwarding deltas over an mpsc channel; the returned stream
    /// converts those deltas into OpenAI-style chunks. Checkpoint ids are
    /// injected into the content stream as `<checkpoint_id>…</checkpoint_id>`
    /// markers before the run starts and after the new checkpoint is saved.
    /// The second tuple element (request id) is always `None` here.
    async fn chat_completion_stream(
        &self,
        model: AgentModel,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
        _headers: Option<HeaderMap>,
    ) -> Result<
        (
            Pin<
                Box<dyn Stream<Item = Result<ChatCompletionStreamResponse, ApiStreamError>> + Send>,
            >,
            Option<String>,
        ),
        String,
    > {
        let mut ctx = HookContext::new(None, AgentState::new(model, messages, tools));

        if let Some(hook_registry) = &self.hook_registry {
            hook_registry
                .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        let current_checkpoint = self.initialize_session(&ctx.state.messages).await?;
        ctx.set_session_id(current_checkpoint.session.id);

        // Bounded channel between the spawned inference task and the stream.
        let (tx, mut rx) = mpsc::channel::<Result<StreamMessage, String>>(100);

        // Emit the current checkpoint id up front so the consumer can resume.
        let _ = tx
            .send(Ok(StreamMessage::Delta(GenerationDelta::Content {
                content: format!(
                    "\n<checkpoint_id>{}</checkpoint_id>\n",
                    current_checkpoint.checkpoint.id
                ),
            })))
            .await;

        let client = self.clone();
        let self_clone = self.clone();
        // The task works on its own copy of the context; updates are synced
        // back to the stream side via StreamMessage::Ctx messages.
        let mut ctx_clone = ctx.clone();
        tokio::spawn(async move {
            let result = client
                .run_agent_completion(&mut ctx_clone, Some(tx.clone()))
                .await;

            match result {
                Err(e) => {
                    let _ = tx.send(Err(e)).await;
                }
                Ok(new_message) => {
                    ctx_clone.state.append_new_message(new_message.clone());
                    let _ = tx
                        .send(Ok(StreamMessage::Ctx(Box::new(ctx_clone.clone()))))
                        .await;

                    let output = self_clone
                        .update_session(&current_checkpoint, ctx_clone.state.messages.clone())
                        .await;

                    match output {
                        Err(e) => {
                            let _ = tx.send(Err(e)).await;
                        }
                        Ok(output) => {
                            ctx_clone.set_new_checkpoint_id(output.checkpoint.id);
                            let _ = tx.send(Ok(StreamMessage::Ctx(Box::new(ctx_clone)))).await;
                            // Trailing marker carries the NEW checkpoint id.
                            let _ = tx
                                .send(Ok(StreamMessage::Delta(GenerationDelta::Content {
                                    content: format!(
                                        "\n<checkpoint_id>{}</checkpoint_id>\n",
                                        output.checkpoint.id
                                    ),
                                })))
                                .await;
                        }
                    }
                }
            }
        });

        let hook_registry = self.hook_registry.clone();
        let stream = async_stream::stream! {
            while let Some(delta_result) = rx.recv().await {
                match delta_result {
                    Ok(delta) => match delta {
                            // Context updates are absorbed, not yielded.
                            StreamMessage::Ctx(updated_ctx) => {
                                ctx = *updated_ctx;
                            }
                            StreamMessage::Delta(delta) => {
                                yield Ok(ChatCompletionStreamResponse {
                                    id: ctx.request_id.to_string(),
                                    object: "chat.completion.chunk".to_string(),
                                    created: chrono::Utc::now().timestamp() as u64,
                                    model: ctx.state.llm_input.as_ref().map(|llm_input| llm_input.model.clone().to_string()).unwrap_or_default(),
                                    choices: vec![ChatCompletionStreamChoice {
                                        index: 0,
                                        delta: delta.into(),
                                        finish_reason: None,
                                    }],
                                    usage: ctx.state.llm_output.as_ref().map(|u| u.usage.clone()),
                                    metadata: None,
                                })
                            }
                        }
                    Err(e) => yield Err(ApiStreamError::Unknown(e)),
                }
            }

            // Channel closed (task finished): run AfterRequest hooks last.
            if let Some(hook_registry) = hook_registry {
                hook_registry
                    .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
                    .await
                    .map_err(|e| e.to_string())?
                    .ok()?;
            }
        };

        Ok((Box::pin(stream), None))
    }
500
    /// No-op: the local provider keeps no server-side stream state to cancel.
    async fn cancel_stream(&self, _request_id: String) -> Result<(), String> {
        Ok(())
    }
504
    /// Search technical documentation via the local search service.
    ///
    /// Pipeline: start the search orchestrator → analyze/reformulate the query
    /// with an LLM → up to `MAX_ITERATIONS` rounds of search-and-scrape plus
    /// LLM validation, refining the query each round → return deduplicated
    /// valid documents formatted as text `Content` items.
    async fn search_docs(&self, input: &SearchDocsRequest) -> Result<Vec<Content>, String> {
        let config = SearchServicesOrchestrator::start()
            .await
            .map_err(|e| e.to_string())?;

        // SECURITY TODO:
        // This uses plain-text, unauthenticated HTTP over localhost.
        // While acceptable for local development, this is an injection and
        // spoofing risk if another local process can bind or intercept the
        // port.
        //
        // Mitigations to consider:
        // - Add mutual authentication (e.g., token)
        // - Validate the expected service identity

        let api_url = format!("http://localhost:{}", config.api_port);
        let search_client = SearchClient::new(api_url);

        // Fold exclusion keywords into the query with "-term" syntax.
        let initial_query = if let Some(exclude) = &input.exclude_keywords {
            format!("{} -{}", input.keywords, exclude)
        } else {
            input.keywords.clone()
        };

        let llm_config = self.get_llm_config();
        let search_model = get_search_model(
            &llm_config,
            self.model_options.eco_model.clone(),
            self.model_options.smart_model.clone(),
        );

        let analysis = analyze_search_query(&llm_config, &search_model, &initial_query).await?;
        let required_documentation = analysis.required_documentation;
        let mut current_query = analysis.reformulated_query;
        let mut previous_queries = Vec::new();
        let mut final_valid_docs = Vec::new();
        let mut accumulated_needed_urls = Vec::new();

        const MAX_ITERATIONS: usize = 3;

        for _iteration in 0..MAX_ITERATIONS {
            previous_queries.push(current_query.clone());

            let search_results = search_client
                .search_and_scrape(current_query.clone(), None)
                .await
                .map_err(|e| e.to_string())?;

            if search_results.is_empty() {
                break;
            }

            let validation_result = validate_search_docs(
                &llm_config,
                &search_model,
                &search_results,
                &current_query,
                &required_documentation,
                &previous_queries,
                &accumulated_needed_urls,
            )
            .await?;

            // Track URLs the validator still wants, across iterations.
            for url in &validation_result.needed_urls {
                if !accumulated_needed_urls.contains(url) {
                    accumulated_needed_urls.push(url.clone());
                }
            }

            // Deduplicate accepted docs by URL.
            for doc in validation_result.valid_docs.into_iter() {
                let is_duplicate = final_valid_docs
                    .iter()
                    .any(|existing_doc: &ScrapedContent| existing_doc.url == doc.url);

                if !is_duplicate {
                    final_valid_docs.push(doc);
                }
            }

            if validation_result.is_satisfied {
                break;
            }

            // Only continue with a genuinely new query; otherwise stop to
            // avoid re-running the same search.
            if let Some(new_query) = validation_result.new_query {
                if new_query != current_query && !previous_queries.contains(&new_query) {
                    current_query = new_query;
                } else {
                    break;
                }
            } else {
                break;
            }
        }

        if final_valid_docs.is_empty() {
            return Ok(vec![Content::text("No results found".to_string())]);
        }

        let contents: Vec<Content> = final_valid_docs
            .into_iter()
            .map(|result| {
                Content::text(format!(
                    "Title: {}\nURL: {}\nContent: {}",
                    result.title.unwrap_or_default(),
                    result.url,
                    result.content.unwrap_or_default(),
                ))
            })
            .collect();

        Ok(contents)
    }
615
    /// Memory search is not implemented locally; returns no results.
    async fn search_memory(&self, _input: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
        // TODO: Implement search memory
        Ok(Vec::new())
    }
620
    /// Slack integration is not implemented locally; returns no messages.
    async fn slack_read_messages(
        &self,
        _input: &SlackReadMessagesRequest,
    ) -> Result<Vec<Content>, String> {
        // TODO: Implement slack read messages
        Ok(Vec::new())
    }
628
    /// Slack integration is not implemented locally; returns no replies.
    async fn slack_read_replies(
        &self,
        _input: &SlackReadRepliesRequest,
    ) -> Result<Vec<Content>, String> {
        // TODO: Implement slack read replies
        Ok(Vec::new())
    }
636
    /// Slack integration is not implemented locally; silently does nothing.
    async fn slack_send_message(
        &self,
        _input: &SlackSendMessageRequest,
    ) -> Result<Vec<Content>, String> {
        // TODO: Implement slack send message
        Ok(Vec::new())
    }
644
    /// Session memorization is not implemented locally; succeeds as a no-op.
    async fn memorize_session(&self, _checkpoint_id: Uuid) -> Result<(), String> {
        // TODO: Implement memorize session
        Ok(())
    }
649}
650
651impl LocalClient {
    /// Bundle the per-provider credentials into a single `LLMProviderConfig`.
    fn get_llm_config(&self) -> LLMProviderConfig {
        LLMProviderConfig {
            anthropic_config: self.anthropic_config.clone(),
            openai_config: self.openai_config.clone(),
            gemini_config: self.gemini_config.clone(),
        }
    }
659
660    async fn run_agent_completion(
661        &self,
662        ctx: &mut HookContext<AgentState>,
663        stream_channel_tx: Option<mpsc::Sender<Result<StreamMessage, String>>>,
664    ) -> Result<ChatMessage, String> {
665        if let Some(hook_registry) = &self.hook_registry {
666            hook_registry
667                .execute_hooks(ctx, &LifecycleEvent::BeforeInference)
668                .await
669                .map_err(|e| e.to_string())?
670                .ok()?;
671        }
672
673        let input = if let Some(llm_input) = ctx.state.llm_input.clone() {
674            llm_input
675        } else {
676            return Err(
677                "Run agent completion: LLM input not found, make sure to register a context hook before inference"
678                    .to_string(),
679            );
680        };
681
682        let llm_config = self.get_llm_config();
683
684        let (response_message, usage) = if let Some(tx) = stream_channel_tx {
685            let (internal_tx, mut internal_rx) = mpsc::channel::<GenerationDelta>(100);
686            let input = LLMStreamInput {
687                model: input.model,
688                messages: input.messages,
689                max_tokens: input.max_tokens,
690                tools: input.tools,
691                stream_channel_tx: internal_tx,
692            };
693
694            let chat_future = async move {
695                chat_stream(&llm_config, input)
696                    .await
697                    .map_err(|e| e.to_string())
698            };
699
700            let receive_future = async move {
701                while let Some(delta) = internal_rx.recv().await {
702                    if tx.send(Ok(StreamMessage::Delta(delta))).await.is_err() {
703                        break;
704                    }
705                }
706            };
707
708            let (chat_result, _) = tokio::join!(chat_future, receive_future);
709            let response = chat_result?;
710            (response.choices[0].message.clone(), response.usage)
711        } else {
712            let response = chat(&llm_config, input).await.map_err(|e| e.to_string())?;
713            (response.choices[0].message.clone(), response.usage)
714        };
715
716        ctx.state.set_llm_output(response_message, usage);
717
718        if let Some(hook_registry) = &self.hook_registry {
719            hook_registry
720                .execute_hooks(ctx, &LifecycleEvent::AfterInference)
721                .await
722                .map_err(|e| e.to_string())?
723                .ok()?;
724        }
725
726        let llm_output = ctx
727            .state
728            .llm_output
729            .as_ref()
730            .ok_or_else(|| "LLM output is missing from state".to_string())?;
731
732        Ok(ChatMessage::from(llm_output))
733    }
734
    /// Resolve the checkpoint to continue from, creating a new session when
    /// the messages carry no checkpoint id.
    ///
    /// If the last server message embeds a checkpoint id, that checkpoint is
    /// loaded; otherwise a new session (with an LLM-generated title) and a
    /// root checkpoint containing the initial messages are created.
    async fn initialize_session(&self, messages: &[ChatMessage]) -> Result<RunAgentOutput, String> {
        // 1. Validate input
        if messages.is_empty() {
            return Err("At least one message is required".to_string());
        }

        // 2. Extract session/checkpoint ID or create new session
        let checkpoint_id = ChatMessage::last_server_message(messages).and_then(|message| {
            message
                .content
                .as_ref()
                .and_then(|content| content.extract_checkpoint_id())
        });

        let current_checkpoint = if let Some(checkpoint_id) = checkpoint_id {
            // Continue an existing session.
            db::get_checkpoint(&self.db, checkpoint_id).await?
        } else {
            let title = self.generate_session_title(messages).await?;

            // Create new session
            let session_id = Uuid::new_v4();
            let now = chrono::Utc::now();
            let session = AgentSession {
                id: session_id,
                title,
                agent_id: AgentID::PabloV1,
                visibility: AgentSessionVisibility::Private,
                created_at: now,
                updated_at: now,
                checkpoints: vec![],
            };
            db::create_session(&self.db, &session).await?;

            // Create initial checkpoint (root)
            let checkpoint_id = Uuid::new_v4();
            let checkpoint = AgentCheckpointListItem {
                id: checkpoint_id,
                status: AgentStatus::Complete,
                execution_depth: 0,
                parent: None,
                created_at: now,
                updated_at: now,
            };
            let initial_state = AgentOutput::PabloV1 {
                messages: messages.to_vec(),
                node_states: serde_json::json!({}),
            };
            db::create_checkpoint(&self.db, session_id, &checkpoint, &initial_state).await?;

            // Re-read so the returned value has the same shape as the
            // continue-session path.
            db::get_checkpoint(&self.db, checkpoint_id).await?
        };

        Ok(current_checkpoint)
    }
789
790    async fn update_session(
791        &self,
792        checkpoint_info: &RunAgentOutput,
793        new_messages: Vec<ChatMessage>,
794    ) -> Result<RunAgentOutput, String> {
795        let now = chrono::Utc::now();
796        let complete_checkpoint = AgentCheckpointListItem {
797            id: Uuid::new_v4(),
798            status: AgentStatus::Complete,
799            execution_depth: checkpoint_info.checkpoint.execution_depth + 1,
800            parent: Some(AgentParentCheckpoint {
801                id: checkpoint_info.checkpoint.id,
802            }),
803            created_at: now,
804            updated_at: now,
805        };
806
807        let mut new_state = checkpoint_info.output.clone();
808        new_state.set_messages(new_messages);
809
810        db::create_checkpoint(
811            &self.db,
812            checkpoint_info.session.id,
813            &complete_checkpoint,
814            &new_state,
815        )
816        .await?;
817
818        Ok(RunAgentOutput {
819            checkpoint: complete_checkpoint,
820            session: checkpoint_info.session.clone(),
821            output: new_state,
822        })
823    }
824
825    async fn generate_session_title(&self, messages: &[ChatMessage]) -> Result<String, String> {
826        let llm_config = self.get_llm_config();
827
828        let llm_model = if let Some(eco_model) = &self.model_options.eco_model {
829            eco_model.clone()
830        } else if llm_config.openai_config.is_some() {
831            LLMModel::OpenAI(OpenAIModel::GPT5Mini)
832        } else if llm_config.anthropic_config.is_some() {
833            LLMModel::Anthropic(AnthropicModel::Claude45Haiku)
834        } else if llm_config.gemini_config.is_some() {
835            LLMModel::Gemini(GeminiModel::Gemini25Flash)
836        } else {
837            return Err("No LLM config found".to_string());
838        };
839
840        let messages = vec![
841            LLMMessage {
842                role: "system".to_string(),
843                content: LLMMessageContent::String(TITLE_GENERATOR_PROMPT.into()),
844            },
845            LLMMessage {
846                role: "user".to_string(),
847                content: LLMMessageContent::String(
848                    messages
849                        .iter()
850                        .map(|msg| {
851                            msg.content
852                                .as_ref()
853                                .unwrap_or(&MessageContent::String("".to_string()))
854                                .to_string()
855                        })
856                        .collect(),
857                ),
858            },
859        ];
860
861        let input = LLMInput {
862            model: llm_model,
863            messages,
864            max_tokens: 100,
865            tools: None,
866        };
867
868        let response = chat(&llm_config, input).await.map_err(|e| e.to_string())?;
869
870        Ok(response.choices[0].message.content.to_string())
871    }
872}
873
874async fn analyze_search_query(
875    llm_config: &LLMProviderConfig,
876    model: &LLMModel,
877    query: &str,
878) -> Result<AnalysisResult, String> {
879    let system_prompt = r#"You are an expert search query analyzer specializing in technical documentation retrieval.
880
881## Your Task
882
883Analyze the user's search query to:
8841. Identify the specific types of documentation needed
8852. Reformulate the query for optimal search engine results
886
887## Guidelines for Required Documentation
888
889Identify specific documentation types such as:
890- API references and specifications
891- Installation/setup guides
892- Configuration documentation
893- Tutorials and getting started guides
894- Troubleshooting guides
895- Architecture/design documents
896- CLI/command references
897- SDK/library documentation
898
899## Guidelines for Query Reformulation
900
901Create an optimized search query that:
902- Uses specific technical terminology
903- Includes relevant keywords (e.g., "documentation", "guide", "API")
904- Removes ambiguous or filler words
905- Targets authoritative sources when possible
906- Is concise but comprehensive (5-10 words ideal)
907
908## Response Format
909
910Respond ONLY with valid XML in this exact structure:
911
912<analysis>
913  <required_documentation>
914    <item>specific documentation type needed</item>
915  </required_documentation>
916  <reformulated_query>optimized search query string</reformulated_query>
917</analysis>"#;
918
919    let user_prompt = format!(
920        r#"<user_query>{}</user_query>
921
922Analyze this query and provide the required documentation types and an optimized search query."#,
923        query
924    );
925
926    let response = chat(
927        llm_config,
928        LLMInput {
929            model: model.clone(),
930            messages: vec![
931                LLMMessage {
932                    role: Role::System.to_string(),
933                    content: LLMMessageContent::String(system_prompt.to_string()),
934                },
935                LLMMessage {
936                    role: Role::User.to_string(),
937                    content: LLMMessageContent::String(user_prompt.to_string()),
938                },
939            ],
940            max_tokens: 2000,
941            tools: None,
942        },
943    )
944    .await
945    .map_err(|e| e.to_string())?;
946
947    let content = response.choices[0].message.content.to_string();
948
949    parse_analysis_xml(&content)
950}
951
952fn parse_analysis_xml(xml: &str) -> Result<AnalysisResult, String> {
953    let extract_tag = |tag: &str| -> Option<String> {
954        let start_tag = format!("<{}>", tag);
955        let end_tag = format!("</{}>", tag);
956        xml.find(&start_tag).and_then(|start| {
957            let content_start = start + start_tag.len();
958            xml[content_start..]
959                .find(&end_tag)
960                .map(|end| xml[content_start..content_start + end].trim().to_string())
961        })
962    };
963
964    let extract_all_tags = |tag: &str| -> Vec<String> {
965        let start_tag = format!("<{}>", tag);
966        let end_tag = format!("</{}>", tag);
967        let mut results = Vec::new();
968        let mut search_start = 0;
969
970        while let Some(start) = xml[search_start..].find(&start_tag) {
971            let abs_start = search_start + start + start_tag.len();
972            if let Some(end) = xml[abs_start..].find(&end_tag) {
973                results.push(xml[abs_start..abs_start + end].trim().to_string());
974                search_start = abs_start + end + end_tag.len();
975            } else {
976                break;
977            }
978        }
979        results
980    };
981
982    let required_documentation = extract_all_tags("item");
983    let reformulated_query =
984        extract_tag("reformulated_query").ok_or("Failed to extract reformulated_query from XML")?;
985
986    Ok(AnalysisResult {
987        required_documentation,
988        reformulated_query,
989    })
990}
991
992async fn validate_search_docs(
993    llm_config: &LLMProviderConfig,
994    model: &LLMModel,
995    docs: &[ScrapedContent],
996    query: &str,
997    required_documentation: &[String],
998    previous_queries: &[String],
999    accumulated_needed_urls: &[String],
1000) -> Result<ValidationResult, String> {
1001    let docs_preview = docs
1002        .iter()
1003        .enumerate()
1004        .take(10)
1005        .map(|(i, r)| {
1006            format!(
1007                "<doc index=\"{}\">\n  <title>{}</title>\n  <url>{}</url>\n</doc>",
1008                i + 1,
1009                r.title.clone().unwrap_or_else(|| "Untitled".to_string()),
1010                r.url
1011            )
1012        })
1013        .collect::<Vec<_>>()
1014        .join("\n");
1015
1016    let required_docs_formatted = required_documentation
1017        .iter()
1018        .map(|d| format!("  <item>{}</item>", d))
1019        .collect::<Vec<_>>()
1020        .join("\n");
1021
1022    let previous_queries_formatted = previous_queries
1023        .iter()
1024        .map(|q| format!("  <query>{}</query>", q))
1025        .collect::<Vec<_>>()
1026        .join("\n");
1027
1028    let accumulated_urls_formatted = accumulated_needed_urls
1029        .iter()
1030        .map(|u| format!("  <url>{}</url>", u))
1031        .collect::<Vec<_>>()
1032        .join("\n");
1033
1034    let system_prompt = r#"You are an expert search result validator. Your task is to evaluate whether search results adequately satisfy a documentation query.
1035
1036## Evaluation Criteria
1037
1038For each search result, assess:
10391. **Relevance**: Does the document directly address the required documentation topics?
10402. **Authority**: Is this an official source, documentation site, or authoritative reference?
10413. **Completeness**: Does it provide comprehensive information, not just passing mentions?
10424. **Freshness**: For technical docs, prefer current/maintained sources over outdated ones.
1043
1044## Decision Guidelines
1045
1046Mark results as SATISFIED when:
1047- All required documentation topics have at least one authoritative source
1048- The sources provide actionable, detailed information
1049- No critical gaps remain in coverage
1050
1051Suggest a NEW QUERY when:
1052- Key topics are missing from results
1053- Results are too general or tangential
1054- A more specific query would yield better results
1055- Previous queries haven't addressed certain requirements
1056
1057## Response Format
1058
1059Respond ONLY with valid XML in this exact structure:
1060
1061<validation>
1062  <is_satisfied>true or false</is_satisfied>
1063  <valid_docs>
1064    <doc><url>exact URL from results</url></doc>
1065  </valid_docs>
1066  <needed_urls>
1067    <url>specific URL pattern or domain still needed</url>
1068  </needed_urls>
1069  <new_query>refined search query if not satisfied, omit if satisfied</new_query>
1070  <reasoning>brief explanation of your assessment</reasoning>
1071</validation>"#;
1072
1073    let user_prompt = format!(
1074        r#"<search_context>
1075  <original_query>{}</original_query>
1076  <required_documentation>
1077{}
1078  </required_documentation>
1079  <previous_queries>
1080{}
1081  </previous_queries>
1082  <accumulated_needed_urls>
1083{}
1084  </accumulated_needed_urls>
1085</search_context>
1086
1087<current_results>
1088{}
1089</current_results>
1090
1091Evaluate these search results against the requirements. Which documents are valid and relevant? Is the documentation requirement satisfied? If not, what specific query would help find missing information?"#,
1092        query,
1093        if required_docs_formatted.is_empty() {
1094            "    <item>None specified</item>".to_string()
1095        } else {
1096            required_docs_formatted
1097        },
1098        if previous_queries_formatted.is_empty() {
1099            "    <query>None</query>".to_string()
1100        } else {
1101            previous_queries_formatted
1102        },
1103        if accumulated_urls_formatted.is_empty() {
1104            "    <url>None</url>".to_string()
1105        } else {
1106            accumulated_urls_formatted
1107        },
1108        docs_preview
1109    );
1110
1111    let response = chat(
1112        llm_config,
1113        LLMInput {
1114            model: model.clone(),
1115            messages: vec![
1116                LLMMessage {
1117                    role: Role::System.to_string(),
1118                    content: LLMMessageContent::String(system_prompt.to_string()),
1119                },
1120                LLMMessage {
1121                    role: Role::User.to_string(),
1122                    content: LLMMessageContent::String(user_prompt.to_string()),
1123                },
1124            ],
1125            max_tokens: 4000,
1126            tools: None,
1127        },
1128    )
1129    .await
1130    .map_err(|e| e.to_string())?;
1131
1132    let content = response.choices[0].message.content.to_string();
1133
1134    let validation = parse_validation_xml(&content, docs)?;
1135
1136    Ok(validation)
1137}
1138
1139fn parse_validation_xml(xml: &str, docs: &[ScrapedContent]) -> Result<ValidationResult, String> {
1140    let extract_tag = |tag: &str| -> Option<String> {
1141        let start_tag = format!("<{}>", tag);
1142        let end_tag = format!("</{}>", tag);
1143        xml.find(&start_tag).and_then(|start| {
1144            let content_start = start + start_tag.len();
1145            xml[content_start..]
1146                .find(&end_tag)
1147                .map(|end| xml[content_start..content_start + end].trim().to_string())
1148        })
1149    };
1150
1151    let extract_all_tags = |tag: &str| -> Vec<String> {
1152        let start_tag = format!("<{}>", tag);
1153        let end_tag = format!("</{}>", tag);
1154        let mut results = Vec::new();
1155        let mut search_start = 0;
1156
1157        while let Some(start) = xml[search_start..].find(&start_tag) {
1158            let abs_start = search_start + start + start_tag.len();
1159            if let Some(end) = xml[abs_start..].find(&end_tag) {
1160                results.push(xml[abs_start..abs_start + end].trim().to_string());
1161                search_start = abs_start + end + end_tag.len();
1162            } else {
1163                break;
1164            }
1165        }
1166        results
1167    };
1168
1169    let is_satisfied = extract_tag("is_satisfied")
1170        .map(|s| s.to_lowercase() == "true")
1171        .unwrap_or(false);
1172
1173    let valid_urls: Vec<String> = extract_all_tags("url")
1174        .into_iter()
1175        .filter(|url| docs.iter().any(|d| d.url == *url))
1176        .collect();
1177
1178    let valid_docs: Vec<ScrapedContent> = valid_urls
1179        .iter()
1180        .filter_map(|url| docs.iter().find(|d| d.url == *url).cloned())
1181        .collect();
1182
1183    let needed_urls: Vec<String> = extract_all_tags("url")
1184        .into_iter()
1185        .filter(|url| !docs.iter().any(|d| d.url == *url))
1186        .collect();
1187
1188    let new_query = extract_tag("new_query").filter(|q| !q.is_empty() && q != "omit if satisfied");
1189
1190    Ok(ValidationResult {
1191        is_satisfied,
1192        valid_docs,
1193        needed_urls,
1194        new_query,
1195    })
1196}
1197
1198fn get_search_model(
1199    llm_config: &LLMProviderConfig,
1200    eco_model: Option<LLMModel>,
1201    smart_model: Option<LLMModel>,
1202) -> LLMModel {
1203    let base_model = eco_model.or(smart_model);
1204
1205    match base_model {
1206        Some(LLMModel::OpenAI(_)) => LLMModel::OpenAI(OpenAIModel::O4Mini),
1207        Some(LLMModel::Anthropic(_)) => LLMModel::Anthropic(AnthropicModel::Claude45Haiku),
1208        Some(LLMModel::Gemini(_)) => LLMModel::Gemini(GeminiModel::Gemini3Flash),
1209        Some(LLMModel::Custom(model)) => LLMModel::Custom(model),
1210        None => {
1211            if llm_config.openai_config.is_some() {
1212                LLMModel::OpenAI(OpenAIModel::O4Mini)
1213            } else if llm_config.anthropic_config.is_some() {
1214                LLMModel::Anthropic(AnthropicModel::Claude45Haiku)
1215            } else if llm_config.gemini_config.is_some() {
1216                LLMModel::Gemini(GeminiModel::Gemini3Flash)
1217            } else {
1218                LLMModel::OpenAI(OpenAIModel::O4Mini)
1219            }
1220        }
1221    }
1222}