stakpak_api/local/mod.rs

1// use crate::local::hooks::file_scratchpad_context::{
2//     FileScratchpadContextHook, FileScratchpadContextHookOptions,
3// };
4use crate::local::hooks::inline_scratchpad_context::{
5    InlineScratchpadContextHook, InlineScratchpadContextHookOptions,
6};
7use crate::{AgentProvider, ApiStreamError, GetMyAccountResponse};
8use crate::{ListRuleBook, models::*};
9use async_trait::async_trait;
10use futures_util::Stream;
11use libsql::{Builder, Connection};
12use reqwest::Error as ReqwestError;
13use reqwest::header::HeaderMap;
14use rmcp::model::Content;
15use stakpak_shared::hooks::{HookContext, HookRegistry, LifecycleEvent};
16use stakpak_shared::models::integrations::anthropic::AnthropicModel;
17use stakpak_shared::models::integrations::gemini::GeminiModel;
18use stakpak_shared::models::integrations::openai::{
19    AgentModel, ChatCompletionChoice, ChatCompletionResponse, ChatCompletionStreamChoice,
20    ChatCompletionStreamResponse, ChatMessage, FinishReason, MessageContent, OpenAIModel, Role,
21    Tool,
22};
23use stakpak_shared::models::integrations::search_service::*;
24use stakpak_shared::models::llm::{
25    GenerationDelta, LLMInput, LLMMessage, LLMMessageContent, LLMModel, LLMProviderConfig,
26    LLMStreamInput,
27};
28use stakpak_shared::models::stakai_adapter::StakAIClient;
29use stakpak_shared::tls_client::{TlsClientConfig, create_tls_client};
30use std::pin::Pin;
31use std::sync::Arc;
32use tokio::sync::mpsc;
33use uuid::Uuid;
34
35mod context_managers;
36mod db;
37mod hooks;
38
39#[cfg(test)]
40mod tests;
41
/// Local (offline-first) implementation of [`AgentProvider`].
///
/// Persists sessions and checkpoints in a local libsql database and calls
/// LLM providers directly instead of going through the hosted Stakpak API.
#[derive(Clone, Debug)]
pub struct LocalClient {
    /// Open connection to the local libsql database.
    pub db: Connection,
    /// Optional remote Stakpak base URL (a "/v1" suffix is appended in
    /// `new`); only used by the rulebook endpoints.
    pub stakpak_base_url: Option<String>,
    /// Configuration/credentials for the available LLM providers.
    pub providers: LLMProviderConfig,
    /// User-supplied model overrides; resolved to defaults via `ModelSet`.
    pub model_options: ModelOptions,
    /// Lifecycle hooks executed around requests and inference.
    pub hook_registry: Option<Arc<HookRegistry<AgentState>>>,
    // Held to keep the orchestrator alive; not read directly.
    _search_services_orchestrator: Option<Arc<SearchServicesOrchestrator>>,
}
51
/// Optional per-tier model overrides; unset tiers fall back to the defaults
/// chosen in `From<ModelOptions> for ModelSet`.
#[derive(Clone, Debug)]
pub struct ModelOptions {
    pub smart_model: Option<LLMModel>,
    pub eco_model: Option<LLMModel>,
    pub recovery_model: Option<LLMModel>,
}
58
/// Fully-resolved set of concrete models, one per agent tier.
#[derive(Clone, Debug)]
pub struct ModelSet {
    pub smart_model: LLMModel,
    pub eco_model: LLMModel,
    pub recovery_model: LLMModel,
    pub hook_registry: Option<Arc<HookRegistry<AgentState>>>,
    pub _search_services_orchestrator: Option<Arc<SearchServicesOrchestrator>>,
}
67
68impl ModelSet {
69    fn get_model(&self, agent_model: &AgentModel) -> LLMModel {
70        match agent_model {
71            AgentModel::Smart => self.smart_model.clone(),
72            AgentModel::Eco => self.eco_model.clone(),
73            AgentModel::Recovery => self.recovery_model.clone(),
74        }
75    }
76}
77
78impl From<ModelOptions> for ModelSet {
79    fn from(value: ModelOptions) -> Self {
80        let smart_model = value
81            .smart_model
82            .unwrap_or(LLMModel::Anthropic(AnthropicModel::Claude45Sonnet));
83        let eco_model = value
84            .eco_model
85            .unwrap_or(LLMModel::Anthropic(AnthropicModel::Claude45Haiku));
86        let recovery_model = value
87            .recovery_model
88            .unwrap_or(LLMModel::OpenAI(OpenAIModel::GPT5));
89
90        Self {
91            smart_model,
92            eco_model,
93            recovery_model,
94            hook_registry: None,
95            _search_services_orchestrator: None,
96        }
97    }
98}
99
/// Construction parameters for [`LocalClient::new`].
pub struct LocalClientConfig {
    /// Optional remote Stakpak API base URL (without the "/v1" suffix).
    pub stakpak_base_url: Option<String>,
    /// Path to the local database file; defaults to
    /// `~/.stakpak/data/local.db` when unset.
    pub store_path: Option<String>,
    /// LLM provider configuration/credentials.
    pub providers: LLMProviderConfig,
    /// Optional model-name overrides, parsed via `LLMModel::from`.
    pub smart_model: Option<String>,
    pub eco_model: Option<String>,
    pub recovery_model: Option<String>,
    /// Pre-registered hooks; the client registers its own context hook on top.
    pub hook_registry: Option<HookRegistry<AgentState>>,
}
109
/// Internal message passed through the streaming pipeline: either a content
/// delta to forward to the caller, or an updated (boxed) hook context that
/// replaces the stream's local state.
#[derive(Debug)]
enum StreamMessage {
    Delta(GenerationDelta),
    Ctx(Box<HookContext<AgentState>>),
}
115
/// Default on-disk location (relative to the user's home directory) for the
/// local libsql database.
const DEFAULT_STORE_PATH: &str = ".stakpak/data/local.db";
/// System prompt used to generate a short session title from the first messages.
const TITLE_GENERATOR_PROMPT: &str = include_str!("./prompts/session_title_generator.v1.txt");
118
impl LocalClient {
    /// Create a local client: opens (creating if needed) the libsql database,
    /// initializes its schema, resolves model overrides, and registers the
    /// built-in inline scratchpad context hook.
    ///
    /// # Errors
    /// Returns a descriptive `String` when the database directory or file
    /// cannot be created, opened, or migrated.
    pub async fn new(config: LocalClientConfig) -> Result<Self, String> {
        // Resolve the database path; default to ~/.stakpak/data/local.db.
        // NOTE(review): if the home directory cannot be determined this
        // silently falls back to a relative path — confirm that's intended.
        let store_path = config
            .store_path
            .map(std::path::PathBuf::from)
            .unwrap_or_else(|| {
                std::env::home_dir()
                    .unwrap_or_default()
                    .join(DEFAULT_STORE_PATH)
            });

        if let Some(parent) = store_path.parent() {
            std::fs::create_dir_all(parent)
                .map_err(|e| format!("Failed to create database directory: {}", e))?;
        }

        let db = Builder::new_local(store_path.display().to_string())
            .build()
            .await
            .map_err(|e| e.to_string())?;

        let conn = db.connect().map_err(|e| e.to_string())?;

        // Initialize database schema
        db::init_schema(&conn).await?;

        // Parse optional model-name overrides into concrete LLM models.
        let model_options = ModelOptions {
            smart_model: config.smart_model.map(LLMModel::from),
            eco_model: config.eco_model.map(LLMModel::from),
            recovery_model: config.recovery_model.map(LLMModel::from),
        };

        // Add hooks
        let mut hook_registry = config.hook_registry.unwrap_or_default();
        hook_registry.register(
            LifecycleEvent::BeforeInference,
            Box::new(InlineScratchpadContextHook::new(
                InlineScratchpadContextHookOptions {
                    model_options: model_options.clone(),
                    history_action_message_size_limit: Some(100),
                    history_action_message_keep_last_n: Some(1),
                    history_action_result_keep_last_n: Some(50),
                },
            )),
        );
        // hook_registry.register(
        //     LifecycleEvent::BeforeInference,
        //     Box::new(FileScratchpadContextHook::new(
        //         FileScratchpadContextHookOptions {
        //             history_action_message_size_limit: Some(100),
        //             history_action_message_keep_last_n: Some(1),
        //             history_action_result_keep_last_n: Some(50),
        //             scratchpad_path: None,
        //             todo_path: None,
        //             model_options: model_options.clone(),
        //             overwrite_if_different: Some(true),
        //         },
        //     )),
        // );

        Ok(Self {
            db: conn,
            // The remote API is versioned under /v1.
            stakpak_base_url: config.stakpak_base_url.map(|url| url + "/v1"),
            providers: config.providers,
            model_options,
            hook_registry: Some(Arc::new(hook_registry)),
            _search_services_orchestrator: Some(Arc::new(SearchServicesOrchestrator)),
        })
    }
}
189
190#[async_trait]
191impl AgentProvider for LocalClient {
192    async fn get_my_account(&self) -> Result<GetMyAccountResponse, String> {
193        Ok(GetMyAccountResponse {
194            username: "local".to_string(),
195            id: "local".to_string(),
196            first_name: "local".to_string(),
197            last_name: "local".to_string(),
198            email: "local@stakpak.dev".to_string(),
199            scope: None,
200        })
201    }
202
203    async fn get_billing_info(
204        &self,
205        _account_username: &str,
206    ) -> Result<stakpak_shared::models::billing::BillingResponse, String> {
207        Err("Billing info not supported in local mode".to_string())
208    }
209
210    async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
211        if self.stakpak_base_url.is_none() {
212            return Ok(vec![]);
213        }
214
215        let stakpak_base_url = self
216            .stakpak_base_url
217            .as_ref()
218            .ok_or("Stakpak base URL not set")?;
219
220        let url = format!("{}/rules", stakpak_base_url);
221
222        let client = create_tls_client(
223            TlsClientConfig::default().with_timeout(std::time::Duration::from_secs(300)),
224        )?;
225
226        let response = client
227            .get(url)
228            .send()
229            .await
230            .map_err(|e: ReqwestError| e.to_string())?;
231
232        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
233
234        match serde_json::from_value::<ListRulebooksResponse>(value.clone()) {
235            Ok(response) => Ok(response.results),
236            Err(e) => {
237                eprintln!("Failed to deserialize response: {}", e);
238                eprintln!("Raw response: {}", value);
239                Err("Failed to deserialize response:".into())
240            }
241        }
242    }
243
244    async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
245        let stakpak_base_url = self
246            .stakpak_base_url
247            .as_ref()
248            .ok_or("Stakpak base URL not set")?;
249
250        let encoded_uri = urlencoding::encode(uri);
251
252        let url = format!("{}/rules/{}", stakpak_base_url, encoded_uri);
253
254        let client = create_tls_client(
255            TlsClientConfig::default().with_timeout(std::time::Duration::from_secs(300)),
256        )?;
257
258        let response = client
259            .get(&url)
260            .send()
261            .await
262            .map_err(|e: ReqwestError| e.to_string())?;
263
264        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
265
266        match serde_json::from_value::<RuleBook>(value.clone()) {
267            Ok(response) => Ok(response),
268            Err(_) => {
269                // eprintln!("Failed to deserialize response: {}", e);
270                // eprintln!("Raw response: {}", value);
271                Err("Failed to deserialize response:".into())
272            }
273        }
274    }
275
276    async fn create_rulebook(
277        &self,
278        _uri: &str,
279        _description: &str,
280        _content: &str,
281        _tags: Vec<String>,
282        _visibility: Option<RuleBookVisibility>,
283    ) -> Result<CreateRuleBookResponse, String> {
284        // TODO: Implement create rulebook
285        Err("Local provider does not support rulebooks yet".to_string())
286    }
287
288    async fn delete_rulebook(&self, _uri: &str) -> Result<(), String> {
289        // TODO: Implement delete rulebook
290        Err("Local provider does not support rulebooks yet".to_string())
291    }
292
    /// List all agent sessions stored in the local database.
    async fn list_agent_sessions(&self) -> Result<Vec<AgentSession>, String> {
        db::list_sessions(&self.db).await
    }
296
    /// Fetch a single session by id from the local database.
    async fn get_agent_session(&self, session_id: Uuid) -> Result<AgentSession, String> {
        db::get_session(&self.db, session_id).await
    }
300
301    async fn get_agent_session_stats(
302        &self,
303        _session_id: Uuid,
304    ) -> Result<AgentSessionStats, String> {
305        // TODO: Implement get agent session stats
306        Ok(AgentSessionStats::default())
307    }
308
    /// Fetch a checkpoint (with its session and output) by checkpoint id.
    async fn get_agent_checkpoint(&self, checkpoint_id: Uuid) -> Result<RunAgentOutput, String> {
        db::get_checkpoint(&self.db, checkpoint_id).await
    }
312
    /// Fetch the most recent checkpoint of the given session.
    async fn get_agent_session_latest_checkpoint(
        &self,
        session_id: Uuid,
    ) -> Result<RunAgentOutput, String> {
        db::get_latest_checkpoint(&self.db, session_id).await
    }
319
    /// Non-streaming chat completion against the local agent pipeline.
    ///
    /// Flow: BeforeRequest hooks -> load-or-create session/checkpoint ->
    /// one inference -> persist a new checkpoint -> AfterRequest hooks; the
    /// result is shaped as an OpenAI-style `ChatCompletionResponse` whose
    /// `id` is the new checkpoint id.
    async fn chat_completion(
        &self,
        model: AgentModel,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
    ) -> Result<ChatCompletionResponse, String> {
        let mut ctx = HookContext::new(None, AgentState::new(model, messages, tools));

        if let Some(hook_registry) = &self.hook_registry {
            // NOTE(review): `.ok()?` presumably converts a hook-signaled
            // abort into an early error return — confirm against the
            // HookRegistry result type.
            hook_registry
                .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        // Resolve (or create) the session/checkpoint this request belongs to.
        let current_checkpoint = self.initialize_session(&ctx.state.messages).await?;
        ctx.set_session_id(current_checkpoint.session.id);

        let new_message = self.run_agent_completion(&mut ctx, None).await?;
        ctx.state.append_new_message(new_message.clone());

        // Persist the updated conversation as a new child checkpoint.
        let result = self
            .update_session(&current_checkpoint, ctx.state.messages.clone())
            .await?;
        let checkpoint_created_at = result.checkpoint.created_at.timestamp() as u64;
        ctx.set_new_checkpoint_id(result.checkpoint.id);

        if let Some(hook_registry) = &self.hook_registry {
            hook_registry
                .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        Ok(ChatCompletionResponse {
            // `new_checkpoint_id` was set just above, so this unwrap holds.
            id: ctx.new_checkpoint_id.unwrap().to_string(),
            object: "chat.completion".to_string(),
            created: checkpoint_created_at,
            model: ctx
                .state
                .llm_input
                .as_ref()
                .map(|llm_input| llm_input.model.clone().to_string())
                .unwrap_or_default(),
            choices: vec![ChatCompletionChoice {
                index: 0,
                // `append_new_message` above guarantees a last message exists.
                message: ctx.state.messages.last().cloned().unwrap(),
                logprobs: None,
                finish_reason: FinishReason::Stop,
            }],
            usage: ctx
                .state
                .llm_output
                .as_ref()
                .map(|u| u.usage.clone())
                .unwrap_or_default(),
            system_fingerprint: None,
            metadata: None,
        })
    }
382
    /// Streaming variant of `chat_completion`.
    ///
    /// Inference and checkpoint persistence run in a spawned task that
    /// pushes `StreamMessage`s over an mpsc channel; the returned stream
    /// converts content deltas into OpenAI-style chunks and absorbs context
    /// updates. Checkpoint ids are surfaced to the client inline as
    /// `<checkpoint_id>…</checkpoint_id>` content markers.
    async fn chat_completion_stream(
        &self,
        model: AgentModel,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
        _headers: Option<HeaderMap>,
    ) -> Result<
        (
            Pin<
                Box<dyn Stream<Item = Result<ChatCompletionStreamResponse, ApiStreamError>> + Send>,
            >,
            Option<String>,
        ),
        String,
    > {
        let mut ctx = HookContext::new(None, AgentState::new(model, messages, tools));

        if let Some(hook_registry) = &self.hook_registry {
            hook_registry
                .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        let current_checkpoint = self.initialize_session(&ctx.state.messages).await?;
        ctx.set_session_id(current_checkpoint.session.id);

        let (tx, mut rx) = mpsc::channel::<Result<StreamMessage, String>>(100);

        // Announce the starting checkpoint before any model output. Send
        // results are ignored throughout: failure only means the consumer
        // dropped the stream.
        let _ = tx
            .send(Ok(StreamMessage::Delta(GenerationDelta::Content {
                content: format!(
                    "\n<checkpoint_id>{}</checkpoint_id>\n",
                    current_checkpoint.checkpoint.id
                ),
            })))
            .await;

        let client = self.clone();
        let self_clone = self.clone();
        let mut ctx_clone = ctx.clone();
        tokio::spawn(async move {
            let result = client
                .run_agent_completion(&mut ctx_clone, Some(tx.clone()))
                .await;

            match result {
                Err(e) => {
                    let _ = tx.send(Err(e)).await;
                }
                Ok(new_message) => {
                    ctx_clone.state.append_new_message(new_message.clone());
                    // Push the updated context so the stream below reports
                    // accurate model/usage metadata.
                    let _ = tx
                        .send(Ok(StreamMessage::Ctx(Box::new(ctx_clone.clone()))))
                        .await;

                    let output = self_clone
                        .update_session(&current_checkpoint, ctx_clone.state.messages.clone())
                        .await;

                    match output {
                        Err(e) => {
                            let _ = tx.send(Err(e)).await;
                        }
                        Ok(output) => {
                            ctx_clone.set_new_checkpoint_id(output.checkpoint.id);
                            let _ = tx.send(Ok(StreamMessage::Ctx(Box::new(ctx_clone)))).await;
                            // The trailing marker carries the *new* checkpoint
                            // id so the client can resume from it.
                            let _ = tx
                                .send(Ok(StreamMessage::Delta(GenerationDelta::Content {
                                    content: format!(
                                        "\n<checkpoint_id>{}</checkpoint_id>\n",
                                        output.checkpoint.id
                                    ),
                                })))
                                .await;
                        }
                    }
                }
            }
        });

        let hook_registry = self.hook_registry.clone();
        let stream = async_stream::stream! {
            while let Some(delta_result) = rx.recv().await {
                match delta_result {
                    Ok(delta) => match delta {
                            // Context updates only mutate local state; they
                            // emit no chunk to the caller.
                            StreamMessage::Ctx(updated_ctx) => {
                                ctx = *updated_ctx;
                            }
                            StreamMessage::Delta(delta) => {
                                yield Ok(ChatCompletionStreamResponse {
                                    id: ctx.request_id.to_string(),
                                    object: "chat.completion.chunk".to_string(),
                                    created: chrono::Utc::now().timestamp() as u64,
                                    model: ctx.state.llm_input.as_ref().map(|llm_input| llm_input.model.clone().to_string()).unwrap_or_default(),
                                    choices: vec![ChatCompletionStreamChoice {
                                        index: 0,
                                        delta: delta.into(),
                                        finish_reason: None,
                                    }],
                                    usage: ctx.state.llm_output.as_ref().map(|u| u.usage.clone()),
                                    metadata: None,
                                })
                            }
                        }
                    Err(e) => yield Err(ApiStreamError::Unknown(e)),
                }
            }

            // Channel closed => the background task is done; run the
            // AfterRequest hooks against the final context.
            if let Some(hook_registry) = hook_registry {
                hook_registry
                    .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
                    .await
                    .map_err(|e| e.to_string())?
                    .ok()?;
            }
        };

        Ok((Box::pin(stream), None))
    }
504
    /// Cancellation is a no-op in local mode.
    async fn cancel_stream(&self, _request_id: String) -> Result<(), String> {
        Ok(())
    }
508
    /// Iterative, LLM-guided documentation search.
    ///
    /// Starts the local search-services sidecar, reformulates the query via
    /// an LLM, then loops up to `MAX_ITERATIONS` times: search + scrape,
    /// LLM-validate the hits, accumulate unique valid docs, and either stop
    /// when the requirement is satisfied or retry with a refined query.
    async fn search_docs(&self, input: &SearchDocsRequest) -> Result<Vec<Content>, String> {
        let config = SearchServicesOrchestrator::start()
            .await
            .map_err(|e| e.to_string())?;

        // SECURITY TODO:
        // This uses plain-text, unauthenticated HTTP over localhost.
        // While acceptable for local development, this is an injection risk.
        //
        // Mitigations to consider:
        // - Add mutual authentication (e.g., token)
        // - Validate the expected service identity

        let api_url = format!("http://localhost:{}", config.api_port);
        let search_client = SearchClient::new(api_url);

        // Exclusions use the conventional "-keyword" search syntax.
        let initial_query = if let Some(exclude) = &input.exclude_keywords {
            format!("{} -{}", input.keywords, exclude)
        } else {
            input.keywords.clone()
        };

        let llm_config = self.get_llm_config();
        let search_model = get_search_model(
            &llm_config,
            self.model_options.eco_model.clone(),
            self.model_options.smart_model.clone(),
        );

        let analysis = analyze_search_query(&llm_config, &search_model, &initial_query).await?;
        let required_documentation = analysis.required_documentation;
        let mut current_query = analysis.reformulated_query;
        let mut previous_queries = Vec::new();
        let mut final_valid_docs = Vec::new();
        let mut accumulated_needed_urls = Vec::new();

        // Bounds the search -> validate -> refine loop.
        const MAX_ITERATIONS: usize = 3;

        for _iteration in 0..MAX_ITERATIONS {
            previous_queries.push(current_query.clone());

            let search_results = search_client
                .search_and_scrape(current_query.clone(), None)
                .await
                .map_err(|e| e.to_string())?;

            // Nothing found: refining further is pointless.
            if search_results.is_empty() {
                break;
            }

            let validation_result = validate_search_docs(
                &llm_config,
                &search_model,
                &search_results,
                &current_query,
                &required_documentation,
                &previous_queries,
                &accumulated_needed_urls,
            )
            .await?;

            // Remember URLs the validator still wants, without duplicates.
            for url in &validation_result.needed_urls {
                if !accumulated_needed_urls.contains(url) {
                    accumulated_needed_urls.push(url.clone());
                }
            }

            // Keep validated docs, de-duplicated by URL.
            for doc in validation_result.valid_docs.into_iter() {
                let is_duplicate = final_valid_docs
                    .iter()
                    .any(|existing_doc: &ScrapedContent| existing_doc.url == doc.url);

                if !is_duplicate {
                    final_valid_docs.push(doc);
                }
            }

            if validation_result.is_satisfied {
                break;
            }

            // Only retry with a genuinely new query; otherwise stop to avoid
            // re-running an identical search.
            if let Some(new_query) = validation_result.new_query {
                if new_query != current_query && !previous_queries.contains(&new_query) {
                    current_query = new_query;
                } else {
                    break;
                }
            } else {
                break;
            }
        }

        if final_valid_docs.is_empty() {
            return Ok(vec![Content::text("No results found".to_string())]);
        }

        let contents: Vec<Content> = final_valid_docs
            .into_iter()
            .map(|result| {
                let content = result.content.unwrap_or_default();
                Content::text(format!("URL: {}\nContent: {}", result.url, content))
            })
            .collect();

        Ok(contents)
    }
615
616    async fn search_memory(&self, _input: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
617        // TODO: Implement search memory
618        Ok(Vec::new())
619    }
620
621    async fn slack_read_messages(
622        &self,
623        _input: &SlackReadMessagesRequest,
624    ) -> Result<Vec<Content>, String> {
625        // TODO: Implement slack read messages
626        Ok(Vec::new())
627    }
628
629    async fn slack_read_replies(
630        &self,
631        _input: &SlackReadRepliesRequest,
632    ) -> Result<Vec<Content>, String> {
633        // TODO: Implement slack read replies
634        Ok(Vec::new())
635    }
636
637    async fn slack_send_message(
638        &self,
639        _input: &SlackSendMessageRequest,
640    ) -> Result<Vec<Content>, String> {
641        // TODO: Implement slack send message
642        Ok(Vec::new())
643    }
644
    /// Session memorization is not implemented locally; succeeds as a no-op.
    async fn memorize_session(&self, _checkpoint_id: Uuid) -> Result<(), String> {
        // TODO: Implement memorize session
        Ok(())
    }
649}
650
651impl LocalClient {
    /// Snapshot (clone) of the configured LLM provider settings.
    fn get_llm_config(&self) -> LLMProviderConfig {
        self.providers.clone()
    }
655
656    async fn run_agent_completion(
657        &self,
658        ctx: &mut HookContext<AgentState>,
659        stream_channel_tx: Option<mpsc::Sender<Result<StreamMessage, String>>>,
660    ) -> Result<ChatMessage, String> {
661        if let Some(hook_registry) = &self.hook_registry {
662            hook_registry
663                .execute_hooks(ctx, &LifecycleEvent::BeforeInference)
664                .await
665                .map_err(|e| e.to_string())?
666                .ok()?;
667        }
668
669        let input = if let Some(llm_input) = ctx.state.llm_input.clone() {
670            llm_input
671        } else {
672            return Err(
673                "Run agent completion: LLM input not found, make sure to register a context hook before inference"
674                    .to_string(),
675            );
676        };
677
678        let llm_config = self.get_llm_config();
679        let stakai_client = StakAIClient::new(&llm_config)
680            .map_err(|e| format!("Failed to create StakAI client: {}", e))?;
681
682        let (response_message, usage) = if let Some(tx) = stream_channel_tx {
683            let (internal_tx, mut internal_rx) = mpsc::channel::<GenerationDelta>(100);
684            let input = LLMStreamInput {
685                model: input.model,
686                messages: input.messages,
687                max_tokens: input.max_tokens,
688                tools: input.tools,
689                stream_channel_tx: internal_tx,
690                provider_options: input.provider_options,
691            };
692
693            let chat_future = async move {
694                stakai_client
695                    .chat_stream(input)
696                    .await
697                    .map_err(|e| e.to_string())
698            };
699
700            let receive_future = async move {
701                while let Some(delta) = internal_rx.recv().await {
702                    if tx.send(Ok(StreamMessage::Delta(delta))).await.is_err() {
703                        break;
704                    }
705                }
706            };
707
708            let (chat_result, _) = tokio::join!(chat_future, receive_future);
709            let response = chat_result?;
710            (response.choices[0].message.clone(), response.usage)
711        } else {
712            let response = stakai_client.chat(input).await.map_err(|e| e.to_string())?;
713            (response.choices[0].message.clone(), response.usage)
714        };
715
716        ctx.state.set_llm_output(response_message, usage);
717
718        if let Some(hook_registry) = &self.hook_registry {
719            hook_registry
720                .execute_hooks(ctx, &LifecycleEvent::AfterInference)
721                .await
722                .map_err(|e| e.to_string())?
723                .ok()?;
724        }
725
726        let llm_output = ctx
727            .state
728            .llm_output
729            .as_ref()
730            .ok_or_else(|| "LLM output is missing from state".to_string())?;
731
732        Ok(ChatMessage::from(llm_output))
733    }
734
    /// Load the checkpoint referenced by the message history, or start a new
    /// session (with an LLM-generated title) plus a root checkpoint seeded
    /// with `messages`.
    async fn initialize_session(&self, messages: &[ChatMessage]) -> Result<RunAgentOutput, String> {
        // 1. Validate input
        if messages.is_empty() {
            return Err("At least one message is required".to_string());
        }

        // 2. Extract session/checkpoint ID or create new session
        // The last server-authored message may embed a checkpoint id (see
        // the <checkpoint_id> markers emitted by the streaming path).
        let checkpoint_id = ChatMessage::last_server_message(messages).and_then(|message| {
            message
                .content
                .as_ref()
                .and_then(|content| content.extract_checkpoint_id())
        });

        let current_checkpoint = if let Some(checkpoint_id) = checkpoint_id {
            // Resuming: continue from the referenced checkpoint.
            db::get_checkpoint(&self.db, checkpoint_id).await?
        } else {
            let title = self.generate_session_title(messages).await?;

            // Create new session
            let session_id = Uuid::new_v4();
            let now = chrono::Utc::now();
            let session = AgentSession {
                id: session_id,
                title,
                agent_id: AgentID::PabloV1,
                visibility: AgentSessionVisibility::Private,
                created_at: now,
                updated_at: now,
                checkpoints: vec![],
            };
            db::create_session(&self.db, &session).await?;

            // Create initial checkpoint (root)
            let checkpoint_id = Uuid::new_v4();
            let checkpoint = AgentCheckpointListItem {
                id: checkpoint_id,
                status: AgentStatus::Complete,
                execution_depth: 0,
                parent: None,
                created_at: now,
                updated_at: now,
            };
            let initial_state = AgentOutput::PabloV1 {
                messages: messages.to_vec(),
                node_states: serde_json::json!({}),
            };
            db::create_checkpoint(&self.db, session_id, &checkpoint, &initial_state).await?;

            // Re-read so the returned value matches the stored representation.
            db::get_checkpoint(&self.db, checkpoint_id).await?
        };

        Ok(current_checkpoint)
    }
789
790    async fn update_session(
791        &self,
792        checkpoint_info: &RunAgentOutput,
793        new_messages: Vec<ChatMessage>,
794    ) -> Result<RunAgentOutput, String> {
795        let now = chrono::Utc::now();
796        let complete_checkpoint = AgentCheckpointListItem {
797            id: Uuid::new_v4(),
798            status: AgentStatus::Complete,
799            execution_depth: checkpoint_info.checkpoint.execution_depth + 1,
800            parent: Some(AgentParentCheckpoint {
801                id: checkpoint_info.checkpoint.id,
802            }),
803            created_at: now,
804            updated_at: now,
805        };
806
807        let mut new_state = checkpoint_info.output.clone();
808        new_state.set_messages(new_messages);
809
810        db::create_checkpoint(
811            &self.db,
812            checkpoint_info.session.id,
813            &complete_checkpoint,
814            &new_state,
815        )
816        .await?;
817
818        Ok(RunAgentOutput {
819            checkpoint: complete_checkpoint,
820            session: checkpoint_info.session.clone(),
821            output: new_state,
822        })
823    }
824
825    async fn generate_session_title(&self, messages: &[ChatMessage]) -> Result<String, String> {
826        let llm_config = self.get_llm_config();
827
828        let llm_model = if let Some(eco_model) = &self.model_options.eco_model {
829            eco_model.clone()
830        } else if llm_config.get_provider("openai").is_some() {
831            LLMModel::OpenAI(OpenAIModel::GPT5Mini)
832        } else if llm_config.get_provider("anthropic").is_some() {
833            LLMModel::Anthropic(AnthropicModel::Claude45Haiku)
834        } else if llm_config.get_provider("gemini").is_some() {
835            LLMModel::Gemini(GeminiModel::Gemini25Flash)
836        } else {
837            return Err("No LLM config found".to_string());
838        };
839
840        let messages = vec![
841            LLMMessage {
842                role: "system".to_string(),
843                content: LLMMessageContent::String(TITLE_GENERATOR_PROMPT.into()),
844            },
845            LLMMessage {
846                role: "user".to_string(),
847                content: LLMMessageContent::String(
848                    messages
849                        .iter()
850                        .map(|msg| {
851                            msg.content
852                                .as_ref()
853                                .unwrap_or(&MessageContent::String("".to_string()))
854                                .to_string()
855                        })
856                        .collect(),
857                ),
858            },
859        ];
860
861        let input = LLMInput {
862            model: llm_model,
863            messages,
864            max_tokens: 100,
865            tools: None,
866            provider_options: None,
867        };
868
869        let stakai_client = StakAIClient::new(&llm_config)
870            .map_err(|e| format!("Failed to create StakAI client: {}", e))?;
871        let response = stakai_client.chat(input).await.map_err(|e| e.to_string())?;
872
873        Ok(response.choices[0].message.content.to_string())
874    }
875}
876
877async fn analyze_search_query(
878    llm_config: &LLMProviderConfig,
879    model: &LLMModel,
880    query: &str,
881) -> Result<AnalysisResult, String> {
882    let system_prompt = r#"You are an expert search query analyzer specializing in technical documentation retrieval.
883
884## Your Task
885
886Analyze the user's search query to:
8871. Identify the specific types of documentation needed
8882. Reformulate the query for optimal search engine results
889
890## Guidelines for Required Documentation
891
892Identify specific documentation types such as:
893- API references and specifications
894- Installation/setup guides
895- Configuration documentation
896- Tutorials and getting started guides
897- Troubleshooting guides
898- Architecture/design documents
899- CLI/command references
900- SDK/library documentation
901
902## Guidelines for Query Reformulation
903
904Create an optimized search query that:
905- Uses specific technical terminology
906- Includes relevant keywords (e.g., "documentation", "guide", "API")
907- Removes ambiguous or filler words
908- Targets authoritative sources when possible
909- Is concise but comprehensive (5-10 words ideal)
910
911## Response Format
912
913Respond ONLY with valid XML in this exact structure:
914
915<analysis>
916  <required_documentation>
917    <item>specific documentation type needed</item>
918  </required_documentation>
919  <reformulated_query>optimized search query string</reformulated_query>
920</analysis>"#;
921
922    let user_prompt = format!(
923        r#"<user_query>{}</user_query>
924
925Analyze this query and provide the required documentation types and an optimized search query."#,
926        query
927    );
928
929    let input = LLMInput {
930        model: model.clone(),
931        messages: vec![
932            LLMMessage {
933                role: Role::System.to_string(),
934                content: LLMMessageContent::String(system_prompt.to_string()),
935            },
936            LLMMessage {
937                role: Role::User.to_string(),
938                content: LLMMessageContent::String(user_prompt.to_string()),
939            },
940        ],
941        max_tokens: 2000,
942        tools: None,
943        provider_options: None,
944    };
945
946    let stakai_client = StakAIClient::new(llm_config)
947        .map_err(|e| format!("Failed to create StakAI client: {}", e))?;
948    let response = stakai_client.chat(input).await.map_err(|e| e.to_string())?;
949
950    let content = response.choices[0].message.content.to_string();
951
952    parse_analysis_xml(&content)
953}
954
955fn parse_analysis_xml(xml: &str) -> Result<AnalysisResult, String> {
956    let extract_tag = |tag: &str| -> Option<String> {
957        let start_tag = format!("<{}>", tag);
958        let end_tag = format!("</{}>", tag);
959        xml.find(&start_tag).and_then(|start| {
960            let content_start = start + start_tag.len();
961            xml[content_start..]
962                .find(&end_tag)
963                .map(|end| xml[content_start..content_start + end].trim().to_string())
964        })
965    };
966
967    let extract_all_tags = |tag: &str| -> Vec<String> {
968        let start_tag = format!("<{}>", tag);
969        let end_tag = format!("</{}>", tag);
970        let mut results = Vec::new();
971        let mut search_start = 0;
972
973        while let Some(start) = xml[search_start..].find(&start_tag) {
974            let abs_start = search_start + start + start_tag.len();
975            if let Some(end) = xml[abs_start..].find(&end_tag) {
976                results.push(xml[abs_start..abs_start + end].trim().to_string());
977                search_start = abs_start + end + end_tag.len();
978            } else {
979                break;
980            }
981        }
982        results
983    };
984
985    let required_documentation = extract_all_tags("item");
986    let reformulated_query =
987        extract_tag("reformulated_query").ok_or("Failed to extract reformulated_query from XML")?;
988
989    Ok(AnalysisResult {
990        required_documentation,
991        reformulated_query,
992    })
993}
994
995async fn validate_search_docs(
996    llm_config: &LLMProviderConfig,
997    model: &LLMModel,
998    docs: &[ScrapedContent],
999    query: &str,
1000    required_documentation: &[String],
1001    previous_queries: &[String],
1002    accumulated_needed_urls: &[String],
1003) -> Result<ValidationResult, String> {
1004    let docs_preview = docs
1005        .iter()
1006        .enumerate()
1007        .take(10)
1008        .map(|(i, r)| {
1009            format!(
1010                "<doc index=\"{}\">\n  <title>{}</title>\n  <url>{}</url>\n</doc>",
1011                i + 1,
1012                r.title.clone().unwrap_or_else(|| "Untitled".to_string()),
1013                r.url
1014            )
1015        })
1016        .collect::<Vec<_>>()
1017        .join("\n");
1018
1019    let required_docs_formatted = required_documentation
1020        .iter()
1021        .map(|d| format!("  <item>{}</item>", d))
1022        .collect::<Vec<_>>()
1023        .join("\n");
1024
1025    let previous_queries_formatted = previous_queries
1026        .iter()
1027        .map(|q| format!("  <query>{}</query>", q))
1028        .collect::<Vec<_>>()
1029        .join("\n");
1030
1031    let accumulated_urls_formatted = accumulated_needed_urls
1032        .iter()
1033        .map(|u| format!("  <url>{}</url>", u))
1034        .collect::<Vec<_>>()
1035        .join("\n");
1036
1037    let system_prompt = r#"You are an expert search result validator. Your task is to evaluate whether search results adequately satisfy a documentation query.
1038
1039## Evaluation Criteria
1040
1041For each search result, assess:
10421. **Relevance**: Does the document directly address the required documentation topics?
10432. **Authority**: Is this an official source, documentation site, or authoritative reference?
10443. **Completeness**: Does it provide comprehensive information, not just passing mentions?
10454. **Freshness**: For technical docs, prefer current/maintained sources over outdated ones.
1046
1047## Decision Guidelines
1048
1049Mark results as SATISFIED when:
1050- All required documentation topics have at least one authoritative source
1051- The sources provide actionable, detailed information
1052- No critical gaps remain in coverage
1053
1054Suggest a NEW QUERY when:
1055- Key topics are missing from results
1056- Results are too general or tangential
1057- A more specific query would yield better results
1058- Previous queries haven't addressed certain requirements
1059
1060## Response Format
1061
1062Respond ONLY with valid XML in this exact structure:
1063
1064<validation>
1065  <is_satisfied>true or false</is_satisfied>
1066  <valid_docs>
1067    <doc><url>exact URL from results</url></doc>
1068  </valid_docs>
1069  <needed_urls>
1070    <url>specific URL pattern or domain still needed</url>
1071  </needed_urls>
1072  <new_query>refined search query if not satisfied, omit if satisfied</new_query>
1073  <reasoning>brief explanation of your assessment</reasoning>
1074</validation>"#;
1075
1076    let user_prompt = format!(
1077        r#"<search_context>
1078  <original_query>{}</original_query>
1079  <required_documentation>
1080{}
1081  </required_documentation>
1082  <previous_queries>
1083{}
1084  </previous_queries>
1085  <accumulated_needed_urls>
1086{}
1087  </accumulated_needed_urls>
1088</search_context>
1089
1090<current_results>
1091{}
1092</current_results>
1093
1094Evaluate these search results against the requirements. Which documents are valid and relevant? Is the documentation requirement satisfied? If not, what specific query would help find missing information?"#,
1095        query,
1096        if required_docs_formatted.is_empty() {
1097            "    <item>None specified</item>".to_string()
1098        } else {
1099            required_docs_formatted
1100        },
1101        if previous_queries_formatted.is_empty() {
1102            "    <query>None</query>".to_string()
1103        } else {
1104            previous_queries_formatted
1105        },
1106        if accumulated_urls_formatted.is_empty() {
1107            "    <url>None</url>".to_string()
1108        } else {
1109            accumulated_urls_formatted
1110        },
1111        docs_preview
1112    );
1113
1114    let input = LLMInput {
1115        model: model.clone(),
1116        messages: vec![
1117            LLMMessage {
1118                role: Role::System.to_string(),
1119                content: LLMMessageContent::String(system_prompt.to_string()),
1120            },
1121            LLMMessage {
1122                role: Role::User.to_string(),
1123                content: LLMMessageContent::String(user_prompt.to_string()),
1124            },
1125        ],
1126        max_tokens: 4000,
1127        tools: None,
1128        provider_options: None,
1129    };
1130
1131    let stakai_client = StakAIClient::new(llm_config)
1132        .map_err(|e| format!("Failed to create StakAI client: {}", e))?;
1133    let response = stakai_client.chat(input).await.map_err(|e| e.to_string())?;
1134
1135    let content = response.choices[0].message.content.to_string();
1136
1137    let validation = parse_validation_xml(&content, docs)?;
1138
1139    Ok(validation)
1140}
1141
1142fn parse_validation_xml(xml: &str, docs: &[ScrapedContent]) -> Result<ValidationResult, String> {
1143    let extract_tag = |tag: &str| -> Option<String> {
1144        let start_tag = format!("<{}>", tag);
1145        let end_tag = format!("</{}>", tag);
1146        xml.find(&start_tag).and_then(|start| {
1147            let content_start = start + start_tag.len();
1148            xml[content_start..]
1149                .find(&end_tag)
1150                .map(|end| xml[content_start..content_start + end].trim().to_string())
1151        })
1152    };
1153
1154    let extract_all_tags = |tag: &str| -> Vec<String> {
1155        let start_tag = format!("<{}>", tag);
1156        let end_tag = format!("</{}>", tag);
1157        let mut results = Vec::new();
1158        let mut search_start = 0;
1159
1160        while let Some(start) = xml[search_start..].find(&start_tag) {
1161            let abs_start = search_start + start + start_tag.len();
1162            if let Some(end) = xml[abs_start..].find(&end_tag) {
1163                results.push(xml[abs_start..abs_start + end].trim().to_string());
1164                search_start = abs_start + end + end_tag.len();
1165            } else {
1166                break;
1167            }
1168        }
1169        results
1170    };
1171
1172    let is_satisfied = extract_tag("is_satisfied")
1173        .map(|s| s.to_lowercase() == "true")
1174        .unwrap_or(false);
1175
1176    let valid_urls: Vec<String> = extract_all_tags("url")
1177        .into_iter()
1178        .filter(|url| docs.iter().any(|d| d.url == *url))
1179        .collect();
1180
1181    let valid_docs: Vec<ScrapedContent> = valid_urls
1182        .iter()
1183        .filter_map(|url| docs.iter().find(|d| d.url == *url).cloned())
1184        .collect();
1185
1186    let needed_urls: Vec<String> = extract_all_tags("url")
1187        .into_iter()
1188        .filter(|url| !docs.iter().any(|d| d.url == *url))
1189        .collect();
1190
1191    let new_query = extract_tag("new_query").filter(|q| !q.is_empty() && q != "omit if satisfied");
1192
1193    Ok(ValidationResult {
1194        is_satisfied,
1195        valid_docs,
1196        needed_urls,
1197        new_query,
1198    })
1199}
1200
1201fn get_search_model(
1202    llm_config: &LLMProviderConfig,
1203    eco_model: Option<LLMModel>,
1204    smart_model: Option<LLMModel>,
1205) -> LLMModel {
1206    let base_model = eco_model.or(smart_model);
1207
1208    match base_model {
1209        Some(LLMModel::OpenAI(_)) => LLMModel::OpenAI(OpenAIModel::O4Mini),
1210        Some(LLMModel::Anthropic(_)) => LLMModel::Anthropic(AnthropicModel::Claude45Haiku),
1211        Some(LLMModel::Gemini(_)) => LLMModel::Gemini(GeminiModel::Gemini3Flash),
1212        Some(LLMModel::Custom { provider, model }) => LLMModel::Custom { provider, model },
1213        None => {
1214            if llm_config.get_provider("openai").is_some() {
1215                LLMModel::OpenAI(OpenAIModel::O4Mini)
1216            } else if llm_config.get_provider("anthropic").is_some() {
1217                LLMModel::Anthropic(AnthropicModel::Claude45Haiku)
1218            } else if llm_config.get_provider("gemini").is_some() {
1219                LLMModel::Gemini(GeminiModel::Gemini3Flash)
1220            } else {
1221                LLMModel::OpenAI(OpenAIModel::O4Mini)
1222            }
1223        }
1224    }
1225}