// stakpak_api/local/mod.rs

1// use crate::local::hooks::file_scratchpad_context::{
2//     FileScratchpadContextHook, FileScratchpadContextHookOptions,
3// };
4use crate::local::hooks::inline_scratchpad_context::{
5    InlineScratchpadContextHook, InlineScratchpadContextHookOptions,
6};
7use crate::{AgentProvider, ApiStreamError, GetMyAccountResponse};
8use crate::{ListRuleBook, models::*};
9use async_trait::async_trait;
10use futures_util::Stream;
11use libsql::{Builder, Connection};
12use reqwest::Error as ReqwestError;
13use reqwest::header::HeaderMap;
14use rmcp::model::Content;
15use stakpak_shared::hooks::{HookContext, HookRegistry, LifecycleEvent};
16use stakpak_shared::models::integrations::anthropic::{AnthropicConfig, AnthropicModel};
17use stakpak_shared::models::integrations::gemini::{GeminiConfig, GeminiModel};
18use stakpak_shared::models::integrations::openai::{
19    AgentModel, ChatCompletionChoice, ChatCompletionResponse, ChatCompletionStreamChoice,
20    ChatCompletionStreamResponse, ChatMessage, FinishReason, MessageContent, OpenAIConfig,
21    OpenAIModel, Role, Tool,
22};
23use stakpak_shared::models::integrations::search_service::*;
24use stakpak_shared::models::llm::{
25    GenerationDelta, LLMInput, LLMMessage, LLMMessageContent, LLMModel, LLMProviderConfig,
26    LLMStreamInput,
27};
28use stakpak_shared::models::stakai_adapter::StakAIClient;
29use stakpak_shared::tls_client::{TlsClientConfig, create_tls_client};
30use std::pin::Pin;
31use std::sync::Arc;
32use tokio::sync::mpsc;
33use uuid::Uuid;
34
// Submodules: context/prompt managers, the local persistence layer, and
// lifecycle hooks used by the agent loop.
mod context_managers;
mod db;
mod hooks;

#[cfg(test)]
mod tests;
41
/// Agent provider backed by a local libsql database instead of the hosted
/// Stakpak API. Remote calls (rulebooks) are only made when
/// `stakpak_base_url` is configured.
#[derive(Clone, Debug)]
pub struct LocalClient {
    /// Open connection to the local libsql store (sessions/checkpoints).
    pub db: Connection,
    /// Optional remote API base URL (already suffixed with `/v1` in `new`).
    pub stakpak_base_url: Option<String>,
    /// Per-provider LLM credentials/configs; any subset may be present.
    pub anthropic_config: Option<AnthropicConfig>,
    pub openai_config: Option<OpenAIConfig>,
    pub gemini_config: Option<GeminiConfig>,
    /// User-selected model names per tier (unset tiers use defaults).
    pub model_options: ModelOptions,
    /// Lifecycle hooks executed around requests and inference.
    pub hook_registry: Option<Arc<HookRegistry<AgentState>>>,
    // Held to keep search services alive; currently unused directly.
    _search_services_orchestrator: Option<Arc<SearchServicesOrchestrator>>,
}
53
/// Optional per-tier model overrides; `None` means "use the default"
/// (resolved in `From<ModelOptions> for ModelSet`).
#[derive(Clone, Debug)]
pub struct ModelOptions {
    pub smart_model: Option<LLMModel>,
    pub eco_model: Option<LLMModel>,
    pub recovery_model: Option<LLMModel>,
}
60
/// Fully-resolved model configuration: every tier has a concrete model.
#[derive(Clone, Debug)]
pub struct ModelSet {
    pub smart_model: LLMModel,
    pub eco_model: LLMModel,
    pub recovery_model: LLMModel,
    pub hook_registry: Option<Arc<HookRegistry<AgentState>>>,
    pub _search_services_orchestrator: Option<Arc<SearchServicesOrchestrator>>,
}
69
70impl ModelSet {
71    fn get_model(&self, agent_model: &AgentModel) -> LLMModel {
72        match agent_model {
73            AgentModel::Smart => self.smart_model.clone(),
74            AgentModel::Eco => self.eco_model.clone(),
75            AgentModel::Recovery => self.recovery_model.clone(),
76        }
77    }
78}
79
80impl From<ModelOptions> for ModelSet {
81    fn from(value: ModelOptions) -> Self {
82        let smart_model = value
83            .smart_model
84            .unwrap_or(LLMModel::Anthropic(AnthropicModel::Claude45Sonnet));
85        let eco_model = value
86            .eco_model
87            .unwrap_or(LLMModel::Anthropic(AnthropicModel::Claude45Haiku));
88        let recovery_model = value
89            .recovery_model
90            .unwrap_or(LLMModel::OpenAI(OpenAIModel::GPT5));
91
92        Self {
93            smart_model,
94            eco_model,
95            recovery_model,
96            hook_registry: None,
97            _search_services_orchestrator: None,
98        }
99    }
100}
101
/// Construction parameters for [`LocalClient::new`].
pub struct LocalClientConfig {
    /// Remote API base URL; `/v1` is appended automatically.
    pub stakpak_base_url: Option<String>,
    /// Path for the libsql database file; defaults to
    /// `$HOME/.stakpak/data/local.db` when unset.
    pub store_path: Option<String>,
    pub anthropic_config: Option<AnthropicConfig>,
    pub openai_config: Option<OpenAIConfig>,
    pub gemini_config: Option<GeminiConfig>,
    // Model names as strings; parsed via `LLMModel::from` during `new`.
    pub smart_model: Option<String>,
    pub eco_model: Option<String>,
    pub recovery_model: Option<String>,
    /// Pre-populated hook registry; defaults are registered on top of it.
    pub hook_registry: Option<HookRegistry<AgentState>>,
}
113
/// Internal message type for the streaming channel: either a generation
/// delta to forward to the client, or an updated hook context snapshot
/// (boxed because `HookContext` is large relative to a delta).
#[derive(Debug)]
enum StreamMessage {
    Delta(GenerationDelta),
    Ctx(Box<HookContext<AgentState>>),
}
119
// Default on-disk location of the local database, relative to $HOME.
const DEFAULT_STORE_PATH: &str = ".stakpak/data/local.db";
// System prompt used to generate a short title for a new session.
const TITLE_GENERATOR_PROMPT: &str = include_str!("./prompts/session_title_generator.v1.txt");
122
impl LocalClient {
    /// Create a local client: opens (or creates) the on-disk libsql database,
    /// initializes its schema, parses the configured model names, and
    /// registers the default context-management hooks.
    ///
    /// # Errors
    /// Returns a human-readable string if the database directory cannot be
    /// created, the database cannot be opened/connected, or schema
    /// initialization fails.
    pub async fn new(config: LocalClientConfig) -> Result<Self, String> {
        // Resolve the database path; default to ~/.stakpak/data/local.db.
        let store_path = config
            .store_path
            .map(std::path::PathBuf::from)
            .unwrap_or_else(|| {
                std::env::home_dir()
                    .unwrap_or_default()
                    .join(DEFAULT_STORE_PATH)
            });

        // Ensure the parent directory exists before libsql opens the file.
        if let Some(parent) = store_path.parent() {
            std::fs::create_dir_all(parent)
                .map_err(|e| format!("Failed to create database directory: {}", e))?;
        }

        let db = Builder::new_local(store_path.display().to_string())
            .build()
            .await
            .map_err(|e| e.to_string())?;

        let conn = db.connect().map_err(|e| e.to_string())?;

        // Initialize database schema
        db::init_schema(&conn).await?;

        // Parse user-supplied model names; unset tiers are defaulted later
        // by `From<ModelOptions> for ModelSet`.
        let model_options = ModelOptions {
            smart_model: config.smart_model.map(LLMModel::from),
            eco_model: config.eco_model.map(LLMModel::from),
            recovery_model: config.recovery_model.map(LLMModel::from),
        };

        // Add hooks
        let mut hook_registry = config.hook_registry.unwrap_or_default();
        hook_registry.register(
            LifecycleEvent::BeforeInference,
            Box::new(InlineScratchpadContextHook::new(
                InlineScratchpadContextHookOptions {
                    model_options: model_options.clone(),
                    history_action_message_size_limit: Some(100),
                    history_action_message_keep_last_n: Some(1),
                    history_action_result_keep_last_n: Some(50),
                },
            )),
        );
        // hook_registry.register(
        //     LifecycleEvent::BeforeInference,
        //     Box::new(FileScratchpadContextHook::new(
        //         FileScratchpadContextHookOptions {
        //             history_action_message_size_limit: Some(100),
        //             history_action_message_keep_last_n: Some(1),
        //             history_action_result_keep_last_n: Some(50),
        //             scratchpad_path: None,
        //             todo_path: None,
        //             model_options: model_options.clone(),
        //             overwrite_if_different: Some(true),
        //         },
        //     )),
        // );

        Ok(Self {
            db: conn,
            // The remote API is versioned under /v1.
            stakpak_base_url: config.stakpak_base_url.map(|url| url + "/v1"),
            anthropic_config: config.anthropic_config,
            gemini_config: config.gemini_config,
            openai_config: config.openai_config,
            model_options,
            hook_registry: Some(Arc::new(hook_registry)),
            _search_services_orchestrator: Some(Arc::new(SearchServicesOrchestrator)),
        })
    }
}
195
196#[async_trait]
197impl AgentProvider for LocalClient {
198    async fn get_my_account(&self) -> Result<GetMyAccountResponse, String> {
199        Ok(GetMyAccountResponse {
200            username: "local".to_string(),
201            id: "local".to_string(),
202            first_name: "local".to_string(),
203            last_name: "local".to_string(),
204        })
205    }
206
207    async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
208        if self.stakpak_base_url.is_none() {
209            return Ok(vec![]);
210        }
211
212        let stakpak_base_url = self
213            .stakpak_base_url
214            .as_ref()
215            .ok_or("Stakpak base URL not set")?;
216
217        let url = format!("{}/rules", stakpak_base_url);
218
219        let client = create_tls_client(
220            TlsClientConfig::default().with_timeout(std::time::Duration::from_secs(300)),
221        )?;
222
223        let response = client
224            .get(url)
225            .send()
226            .await
227            .map_err(|e: ReqwestError| e.to_string())?;
228
229        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
230
231        match serde_json::from_value::<ListRulebooksResponse>(value.clone()) {
232            Ok(response) => Ok(response.results),
233            Err(e) => {
234                eprintln!("Failed to deserialize response: {}", e);
235                eprintln!("Raw response: {}", value);
236                Err("Failed to deserialize response:".into())
237            }
238        }
239    }
240
241    async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
242        let stakpak_base_url = self
243            .stakpak_base_url
244            .as_ref()
245            .ok_or("Stakpak base URL not set")?;
246
247        let encoded_uri = urlencoding::encode(uri);
248
249        let url = format!("{}/rules/{}", stakpak_base_url, encoded_uri);
250
251        let client = create_tls_client(
252            TlsClientConfig::default().with_timeout(std::time::Duration::from_secs(300)),
253        )?;
254
255        let response = client
256            .get(&url)
257            .send()
258            .await
259            .map_err(|e: ReqwestError| e.to_string())?;
260
261        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
262
263        match serde_json::from_value::<RuleBook>(value.clone()) {
264            Ok(response) => Ok(response),
265            Err(e) => {
266                eprintln!("Failed to deserialize response: {}", e);
267                eprintln!("Raw response: {}", value);
268                Err("Failed to deserialize response:".into())
269            }
270        }
271    }
272
273    async fn create_rulebook(
274        &self,
275        _uri: &str,
276        _description: &str,
277        _content: &str,
278        _tags: Vec<String>,
279        _visibility: Option<RuleBookVisibility>,
280    ) -> Result<CreateRuleBookResponse, String> {
281        // TODO: Implement create rulebook
282        Err("Local provider does not support rulebooks yet".to_string())
283    }
284
285    async fn delete_rulebook(&self, _uri: &str) -> Result<(), String> {
286        // TODO: Implement delete rulebook
287        Err("Local provider does not support rulebooks yet".to_string())
288    }
289
    /// List all locally stored agent sessions (delegates to the db module).
    async fn list_agent_sessions(&self) -> Result<Vec<AgentSession>, String> {
        db::list_sessions(&self.db).await
    }
293
    /// Load a single session by id from the local store.
    async fn get_agent_session(&self, session_id: Uuid) -> Result<AgentSession, String> {
        db::get_session(&self.db, session_id).await
    }
297
    /// Session statistics are not tracked locally yet; returns defaults.
    async fn get_agent_session_stats(
        &self,
        _session_id: Uuid,
    ) -> Result<AgentSessionStats, String> {
        // TODO: Implement get agent session stats
        Ok(AgentSessionStats::default())
    }
305
    /// Load a single checkpoint (with its session and output) by id.
    async fn get_agent_checkpoint(&self, checkpoint_id: Uuid) -> Result<RunAgentOutput, String> {
        db::get_checkpoint(&self.db, checkpoint_id).await
    }
309
    /// Load the most recent checkpoint for a session from the local store.
    async fn get_agent_session_latest_checkpoint(
        &self,
        session_id: Uuid,
    ) -> Result<RunAgentOutput, String> {
        db::get_latest_checkpoint(&self.db, session_id).await
    }
316
317    async fn chat_completion(
318        &self,
319        model: AgentModel,
320        messages: Vec<ChatMessage>,
321        tools: Option<Vec<Tool>>,
322    ) -> Result<ChatCompletionResponse, String> {
323        let mut ctx = HookContext::new(None, AgentState::new(model, messages, tools));
324
325        if let Some(hook_registry) = &self.hook_registry {
326            hook_registry
327                .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
328                .await
329                .map_err(|e| e.to_string())?
330                .ok()?;
331        }
332
333        let current_checkpoint = self.initialize_session(&ctx.state.messages).await?;
334        ctx.set_session_id(current_checkpoint.session.id);
335
336        let new_message = self.run_agent_completion(&mut ctx, None).await?;
337        ctx.state.append_new_message(new_message.clone());
338
339        let result = self
340            .update_session(&current_checkpoint, ctx.state.messages.clone())
341            .await?;
342        let checkpoint_created_at = result.checkpoint.created_at.timestamp() as u64;
343        ctx.set_new_checkpoint_id(result.checkpoint.id);
344
345        if let Some(hook_registry) = &self.hook_registry {
346            hook_registry
347                .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
348                .await
349                .map_err(|e| e.to_string())?
350                .ok()?;
351        }
352
353        Ok(ChatCompletionResponse {
354            id: ctx.new_checkpoint_id.unwrap().to_string(),
355            object: "chat.completion".to_string(),
356            created: checkpoint_created_at,
357            model: ctx
358                .state
359                .llm_input
360                .as_ref()
361                .map(|llm_input| llm_input.model.clone().to_string())
362                .unwrap_or_default(),
363            choices: vec![ChatCompletionChoice {
364                index: 0,
365                message: ctx.state.messages.last().cloned().unwrap(),
366                logprobs: None,
367                finish_reason: FinishReason::Stop,
368            }],
369            usage: ctx
370                .state
371                .llm_output
372                .as_ref()
373                .map(|u| u.usage.clone())
374                .unwrap_or_default(),
375            system_fingerprint: None,
376            metadata: None,
377        })
378    }
379
    /// Run one streaming agent completion turn.
    ///
    /// Inference runs on a background task; deltas are re-emitted as
    /// OpenAI-style `chat.completion.chunk` items. Checkpoint ids are
    /// injected into the content stream as `<checkpoint_id>` tags (one at the
    /// start, one after the new checkpoint is persisted) — this is what
    /// `initialize_session` later parses to resume the session.
    async fn chat_completion_stream(
        &self,
        model: AgentModel,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
        _headers: Option<HeaderMap>,
    ) -> Result<
        (
            Pin<
                Box<dyn Stream<Item = Result<ChatCompletionStreamResponse, ApiStreamError>> + Send>,
            >,
            Option<String>,
        ),
        String,
    > {
        let mut ctx = HookContext::new(None, AgentState::new(model, messages, tools));

        if let Some(hook_registry) = &self.hook_registry {
            hook_registry
                .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        let current_checkpoint = self.initialize_session(&ctx.state.messages).await?;
        ctx.set_session_id(current_checkpoint.session.id);

        // Channel carrying both deltas and context snapshots from the worker.
        let (tx, mut rx) = mpsc::channel::<Result<StreamMessage, String>>(100);

        // Announce the starting checkpoint before any model output.
        let _ = tx
            .send(Ok(StreamMessage::Delta(GenerationDelta::Content {
                content: format!(
                    "\n<checkpoint_id>{}</checkpoint_id>\n",
                    current_checkpoint.checkpoint.id
                ),
            })))
            .await;

        let client = self.clone();
        let self_clone = self.clone();
        let mut ctx_clone = ctx.clone();
        // Worker task: run inference, persist the new checkpoint, and push
        // context snapshots so the stream below reports accurate ids/usage.
        tokio::spawn(async move {
            let result = client
                .run_agent_completion(&mut ctx_clone, Some(tx.clone()))
                .await;

            match result {
                Err(e) => {
                    let _ = tx.send(Err(e)).await;
                }
                Ok(new_message) => {
                    ctx_clone.state.append_new_message(new_message.clone());
                    let _ = tx
                        .send(Ok(StreamMessage::Ctx(Box::new(ctx_clone.clone()))))
                        .await;

                    let output = self_clone
                        .update_session(&current_checkpoint, ctx_clone.state.messages.clone())
                        .await;

                    match output {
                        Err(e) => {
                            let _ = tx.send(Err(e)).await;
                        }
                        Ok(output) => {
                            ctx_clone.set_new_checkpoint_id(output.checkpoint.id);
                            let _ = tx.send(Ok(StreamMessage::Ctx(Box::new(ctx_clone)))).await;
                            // Trailing tag lets the client resume from the
                            // newly persisted checkpoint next turn.
                            let _ = tx
                                .send(Ok(StreamMessage::Delta(GenerationDelta::Content {
                                    content: format!(
                                        "\n<checkpoint_id>{}</checkpoint_id>\n",
                                        output.checkpoint.id
                                    ),
                                })))
                                .await;
                        }
                    }
                }
            }
        });

        let hook_registry = self.hook_registry.clone();
        let stream = async_stream::stream! {
            while let Some(delta_result) = rx.recv().await {
                match delta_result {
                    Ok(delta) => match delta {
                            // Context snapshots update local state; they emit
                            // no chunk to the client.
                            StreamMessage::Ctx(updated_ctx) => {
                                ctx = *updated_ctx;
                            }
                            StreamMessage::Delta(delta) => {
                                yield Ok(ChatCompletionStreamResponse {
                                    id: ctx.request_id.to_string(),
                                    object: "chat.completion.chunk".to_string(),
                                    created: chrono::Utc::now().timestamp() as u64,
                                    model: ctx.state.llm_input.as_ref().map(|llm_input| llm_input.model.clone().to_string()).unwrap_or_default(),
                                    choices: vec![ChatCompletionStreamChoice {
                                        index: 0,
                                        delta: delta.into(),
                                        finish_reason: None,
                                    }],
                                    usage: ctx.state.llm_output.as_ref().map(|u| u.usage.clone()),
                                    metadata: None,
                                })
                            }
                        }
                    Err(e) => yield Err(ApiStreamError::Unknown(e)),
                }
            }

            // Worker finished (its tx dropped): run AfterRequest hooks once.
            if let Some(hook_registry) = hook_registry {
                hook_registry
                    .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
                    .await
                    .map_err(|e| e.to_string())?
                    .ok()?;
            }
        };

        Ok((Box::pin(stream), None))
    }
501
    /// No-op: local streams are not tracked by request id, so there is
    /// nothing to cancel.
    async fn cancel_stream(&self, _request_id: String) -> Result<(), String> {
        Ok(())
    }
505
    /// Iteratively search + scrape documentation for the given keywords.
    ///
    /// Uses an LLM to analyze/reformulate the query, validates scraped pages
    /// against the identified documentation needs, and refines the query up
    /// to `MAX_ITERATIONS` times until the need is satisfied. Returns one
    /// `Content` item per accepted page (or a single "No results found").
    async fn search_docs(&self, input: &SearchDocsRequest) -> Result<Vec<Content>, String> {
        let config = SearchServicesOrchestrator::start()
            .await
            .map_err(|e| e.to_string())?;

        // SECURITY TODO:
        // This uses plain-text, unauthenticated HTTP over localhost.
        // While acceptable for local development, this is a potential
        // spoofing/injection surface: any local process could bind the port.
        //
        // Mitigations to consider:
        // - Add mutual authentication (e.g., token)
        // - Validate the expected service identity

        let api_url = format!("http://localhost:{}", config.api_port);
        let search_client = SearchClient::new(api_url);

        // Fold exclusion keywords into the query using the `-term` convention.
        let initial_query = if let Some(exclude) = &input.exclude_keywords {
            format!("{} -{}", input.keywords, exclude)
        } else {
            input.keywords.clone()
        };

        let llm_config = self.get_llm_config();
        let search_model = get_search_model(
            &llm_config,
            self.model_options.eco_model.clone(),
            self.model_options.smart_model.clone(),
        );

        // First pass: identify what documentation is needed and rewrite the
        // query for better search-engine results.
        let analysis = analyze_search_query(&llm_config, &search_model, &initial_query).await?;
        let required_documentation = analysis.required_documentation;
        let mut current_query = analysis.reformulated_query;
        let mut previous_queries = Vec::new();
        let mut final_valid_docs = Vec::new();
        let mut accumulated_needed_urls = Vec::new();

        const MAX_ITERATIONS: usize = 3;

        for _iteration in 0..MAX_ITERATIONS {
            previous_queries.push(current_query.clone());

            let search_results = search_client
                .search_and_scrape(current_query.clone(), None)
                .await
                .map_err(|e| e.to_string())?;

            // No hits at all: further refinement cannot help.
            if search_results.is_empty() {
                break;
            }

            let validation_result = validate_search_docs(
                &llm_config,
                &search_model,
                &search_results,
                &current_query,
                &required_documentation,
                &previous_queries,
                &accumulated_needed_urls,
            )
            .await?;

            // Track URLs the validator still wants, without duplicates.
            for url in &validation_result.needed_urls {
                if !accumulated_needed_urls.contains(url) {
                    accumulated_needed_urls.push(url.clone());
                }
            }

            // Keep accepted documents, de-duplicated by URL across iterations.
            for doc in validation_result.valid_docs.into_iter() {
                let is_duplicate = final_valid_docs
                    .iter()
                    .any(|existing_doc: &ScrapedContent| existing_doc.url == doc.url);

                if !is_duplicate {
                    final_valid_docs.push(doc);
                }
            }

            if validation_result.is_satisfied {
                break;
            }

            // Only continue with a genuinely new query; otherwise stop to
            // avoid re-running an identical search.
            if let Some(new_query) = validation_result.new_query {
                if new_query != current_query && !previous_queries.contains(&new_query) {
                    current_query = new_query;
                } else {
                    break;
                }
            } else {
                break;
            }
        }

        if final_valid_docs.is_empty() {
            return Ok(vec![Content::text("No results found".to_string())]);
        }

        let contents: Vec<Content> = final_valid_docs
            .into_iter()
            .map(|result| {
                Content::text(format!(
                    "Title: {}\nURL: {}\nContent: {}",
                    result.title.unwrap_or_default(),
                    result.url,
                    result.content.unwrap_or_default(),
                ))
            })
            .collect();

        Ok(contents)
    }
616
    /// Memory search is not implemented locally yet; returns no results.
    async fn search_memory(&self, _input: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
        // TODO: Implement search memory
        Ok(Vec::new())
    }
621
    /// Slack integration is not implemented locally yet; returns no results.
    async fn slack_read_messages(
        &self,
        _input: &SlackReadMessagesRequest,
    ) -> Result<Vec<Content>, String> {
        // TODO: Implement slack read messages
        Ok(Vec::new())
    }
629
    /// Slack integration is not implemented locally yet; returns no results.
    async fn slack_read_replies(
        &self,
        _input: &SlackReadRepliesRequest,
    ) -> Result<Vec<Content>, String> {
        // TODO: Implement slack read replies
        Ok(Vec::new())
    }
637
    /// Slack integration is not implemented locally yet; silently succeeds.
    async fn slack_send_message(
        &self,
        _input: &SlackSendMessageRequest,
    ) -> Result<Vec<Content>, String> {
        // TODO: Implement slack send message
        Ok(Vec::new())
    }
645
    /// Session memorization is not implemented locally yet; no-op success.
    async fn memorize_session(&self, _checkpoint_id: Uuid) -> Result<(), String> {
        // TODO: Implement memorize session
        Ok(())
    }
650}
651
652impl LocalClient {
    /// Bundle whichever per-provider configs are present into a single
    /// `LLMProviderConfig` for the StakAI client.
    fn get_llm_config(&self) -> LLMProviderConfig {
        LLMProviderConfig {
            anthropic_config: self.anthropic_config.clone(),
            openai_config: self.openai_config.clone(),
            gemini_config: self.gemini_config.clone(),
        }
    }
660
661    async fn run_agent_completion(
662        &self,
663        ctx: &mut HookContext<AgentState>,
664        stream_channel_tx: Option<mpsc::Sender<Result<StreamMessage, String>>>,
665    ) -> Result<ChatMessage, String> {
666        if let Some(hook_registry) = &self.hook_registry {
667            hook_registry
668                .execute_hooks(ctx, &LifecycleEvent::BeforeInference)
669                .await
670                .map_err(|e| e.to_string())?
671                .ok()?;
672        }
673
674        let input = if let Some(llm_input) = ctx.state.llm_input.clone() {
675            llm_input
676        } else {
677            return Err(
678                "Run agent completion: LLM input not found, make sure to register a context hook before inference"
679                    .to_string(),
680            );
681        };
682
683        let llm_config = self.get_llm_config();
684        let stakai_client = StakAIClient::new(&llm_config)
685            .map_err(|e| format!("Failed to create StakAI client: {}", e))?;
686
687        let (response_message, usage) = if let Some(tx) = stream_channel_tx {
688            let (internal_tx, mut internal_rx) = mpsc::channel::<GenerationDelta>(100);
689            let input = LLMStreamInput {
690                model: input.model,
691                messages: input.messages,
692                max_tokens: input.max_tokens,
693                tools: input.tools,
694                stream_channel_tx: internal_tx,
695                provider_options: input.provider_options,
696            };
697
698            let chat_future = async move {
699                stakai_client
700                    .chat_stream(input)
701                    .await
702                    .map_err(|e| e.to_string())
703            };
704
705            let receive_future = async move {
706                while let Some(delta) = internal_rx.recv().await {
707                    if tx.send(Ok(StreamMessage::Delta(delta))).await.is_err() {
708                        break;
709                    }
710                }
711            };
712
713            let (chat_result, _) = tokio::join!(chat_future, receive_future);
714            let response = chat_result?;
715            (response.choices[0].message.clone(), response.usage)
716        } else {
717            let response = stakai_client.chat(input).await.map_err(|e| e.to_string())?;
718            (response.choices[0].message.clone(), response.usage)
719        };
720
721        ctx.state.set_llm_output(response_message, usage);
722
723        if let Some(hook_registry) = &self.hook_registry {
724            hook_registry
725                .execute_hooks(ctx, &LifecycleEvent::AfterInference)
726                .await
727                .map_err(|e| e.to_string())?
728                .ok()?;
729        }
730
731        let llm_output = ctx
732            .state
733            .llm_output
734            .as_ref()
735            .ok_or_else(|| "LLM output is missing from state".to_string())?;
736
737        Ok(ChatMessage::from(llm_output))
738    }
739
    /// Resolve the session/checkpoint this request operates on.
    ///
    /// If the last server message embeds a `<checkpoint_id>` tag, that
    /// checkpoint is loaded (resuming its session); otherwise a new session
    /// with an LLM-generated title and a root checkpoint are created.
    async fn initialize_session(&self, messages: &[ChatMessage]) -> Result<RunAgentOutput, String> {
        // 1. Validate input
        if messages.is_empty() {
            return Err("At least one message is required".to_string());
        }

        // 2. Extract session/checkpoint ID or create new session
        let checkpoint_id = ChatMessage::last_server_message(messages).and_then(|message| {
            message
                .content
                .as_ref()
                .and_then(|content| content.extract_checkpoint_id())
        });

        let current_checkpoint = if let Some(checkpoint_id) = checkpoint_id {
            // Resume: load the referenced checkpoint from the local store.
            db::get_checkpoint(&self.db, checkpoint_id).await?
        } else {
            let title = self.generate_session_title(messages).await?;

            // Create new session
            let session_id = Uuid::new_v4();
            let now = chrono::Utc::now();
            let session = AgentSession {
                id: session_id,
                title,
                agent_id: AgentID::PabloV1,
                visibility: AgentSessionVisibility::Private,
                created_at: now,
                updated_at: now,
                checkpoints: vec![],
            };
            db::create_session(&self.db, &session).await?;

            // Create initial checkpoint (root)
            let checkpoint_id = Uuid::new_v4();
            let checkpoint = AgentCheckpointListItem {
                id: checkpoint_id,
                status: AgentStatus::Complete,
                execution_depth: 0,
                parent: None,
                created_at: now,
                updated_at: now,
            };
            let initial_state = AgentOutput::PabloV1 {
                messages: messages.to_vec(),
                node_states: serde_json::json!({}),
            };
            db::create_checkpoint(&self.db, session_id, &checkpoint, &initial_state).await?;

            // Re-read so the returned value matches what was persisted.
            db::get_checkpoint(&self.db, checkpoint_id).await?
        };

        Ok(current_checkpoint)
    }
794
795    async fn update_session(
796        &self,
797        checkpoint_info: &RunAgentOutput,
798        new_messages: Vec<ChatMessage>,
799    ) -> Result<RunAgentOutput, String> {
800        let now = chrono::Utc::now();
801        let complete_checkpoint = AgentCheckpointListItem {
802            id: Uuid::new_v4(),
803            status: AgentStatus::Complete,
804            execution_depth: checkpoint_info.checkpoint.execution_depth + 1,
805            parent: Some(AgentParentCheckpoint {
806                id: checkpoint_info.checkpoint.id,
807            }),
808            created_at: now,
809            updated_at: now,
810        };
811
812        let mut new_state = checkpoint_info.output.clone();
813        new_state.set_messages(new_messages);
814
815        db::create_checkpoint(
816            &self.db,
817            checkpoint_info.session.id,
818            &complete_checkpoint,
819            &new_state,
820        )
821        .await?;
822
823        Ok(RunAgentOutput {
824            checkpoint: complete_checkpoint,
825            session: checkpoint_info.session.clone(),
826            output: new_state,
827        })
828    }
829
830    async fn generate_session_title(&self, messages: &[ChatMessage]) -> Result<String, String> {
831        let llm_config = self.get_llm_config();
832
833        let llm_model = if let Some(eco_model) = &self.model_options.eco_model {
834            eco_model.clone()
835        } else if llm_config.openai_config.is_some() {
836            LLMModel::OpenAI(OpenAIModel::GPT5Mini)
837        } else if llm_config.anthropic_config.is_some() {
838            LLMModel::Anthropic(AnthropicModel::Claude45Haiku)
839        } else if llm_config.gemini_config.is_some() {
840            LLMModel::Gemini(GeminiModel::Gemini25Flash)
841        } else {
842            return Err("No LLM config found".to_string());
843        };
844
845        let messages = vec![
846            LLMMessage {
847                role: "system".to_string(),
848                content: LLMMessageContent::String(TITLE_GENERATOR_PROMPT.into()),
849            },
850            LLMMessage {
851                role: "user".to_string(),
852                content: LLMMessageContent::String(
853                    messages
854                        .iter()
855                        .map(|msg| {
856                            msg.content
857                                .as_ref()
858                                .unwrap_or(&MessageContent::String("".to_string()))
859                                .to_string()
860                        })
861                        .collect(),
862                ),
863            },
864        ];
865
866        let input = LLMInput {
867            model: llm_model,
868            messages,
869            max_tokens: 100,
870            tools: None,
871            provider_options: None,
872        };
873
874        let stakai_client = StakAIClient::new(&llm_config)
875            .map_err(|e| format!("Failed to create StakAI client: {}", e))?;
876        let response = stakai_client.chat(input).await.map_err(|e| e.to_string())?;
877
878        Ok(response.choices[0].message.content.to_string())
879    }
880}
881
882async fn analyze_search_query(
883    llm_config: &LLMProviderConfig,
884    model: &LLMModel,
885    query: &str,
886) -> Result<AnalysisResult, String> {
887    let system_prompt = r#"You are an expert search query analyzer specializing in technical documentation retrieval.
888
889## Your Task
890
891Analyze the user's search query to:
8921. Identify the specific types of documentation needed
8932. Reformulate the query for optimal search engine results
894
895## Guidelines for Required Documentation
896
897Identify specific documentation types such as:
898- API references and specifications
899- Installation/setup guides
900- Configuration documentation
901- Tutorials and getting started guides
902- Troubleshooting guides
903- Architecture/design documents
904- CLI/command references
905- SDK/library documentation
906
907## Guidelines for Query Reformulation
908
909Create an optimized search query that:
910- Uses specific technical terminology
911- Includes relevant keywords (e.g., "documentation", "guide", "API")
912- Removes ambiguous or filler words
913- Targets authoritative sources when possible
914- Is concise but comprehensive (5-10 words ideal)
915
916## Response Format
917
918Respond ONLY with valid XML in this exact structure:
919
920<analysis>
921  <required_documentation>
922    <item>specific documentation type needed</item>
923  </required_documentation>
924  <reformulated_query>optimized search query string</reformulated_query>
925</analysis>"#;
926
927    let user_prompt = format!(
928        r#"<user_query>{}</user_query>
929
930Analyze this query and provide the required documentation types and an optimized search query."#,
931        query
932    );
933
934    let input = LLMInput {
935        model: model.clone(),
936        messages: vec![
937            LLMMessage {
938                role: Role::System.to_string(),
939                content: LLMMessageContent::String(system_prompt.to_string()),
940            },
941            LLMMessage {
942                role: Role::User.to_string(),
943                content: LLMMessageContent::String(user_prompt.to_string()),
944            },
945        ],
946        max_tokens: 2000,
947        tools: None,
948        provider_options: None,
949    };
950
951    let stakai_client = StakAIClient::new(llm_config)
952        .map_err(|e| format!("Failed to create StakAI client: {}", e))?;
953    let response = stakai_client.chat(input).await.map_err(|e| e.to_string())?;
954
955    let content = response.choices[0].message.content.to_string();
956
957    parse_analysis_xml(&content)
958}
959
960fn parse_analysis_xml(xml: &str) -> Result<AnalysisResult, String> {
961    let extract_tag = |tag: &str| -> Option<String> {
962        let start_tag = format!("<{}>", tag);
963        let end_tag = format!("</{}>", tag);
964        xml.find(&start_tag).and_then(|start| {
965            let content_start = start + start_tag.len();
966            xml[content_start..]
967                .find(&end_tag)
968                .map(|end| xml[content_start..content_start + end].trim().to_string())
969        })
970    };
971
972    let extract_all_tags = |tag: &str| -> Vec<String> {
973        let start_tag = format!("<{}>", tag);
974        let end_tag = format!("</{}>", tag);
975        let mut results = Vec::new();
976        let mut search_start = 0;
977
978        while let Some(start) = xml[search_start..].find(&start_tag) {
979            let abs_start = search_start + start + start_tag.len();
980            if let Some(end) = xml[abs_start..].find(&end_tag) {
981                results.push(xml[abs_start..abs_start + end].trim().to_string());
982                search_start = abs_start + end + end_tag.len();
983            } else {
984                break;
985            }
986        }
987        results
988    };
989
990    let required_documentation = extract_all_tags("item");
991    let reformulated_query =
992        extract_tag("reformulated_query").ok_or("Failed to extract reformulated_query from XML")?;
993
994    Ok(AnalysisResult {
995        required_documentation,
996        reformulated_query,
997    })
998}
999
1000async fn validate_search_docs(
1001    llm_config: &LLMProviderConfig,
1002    model: &LLMModel,
1003    docs: &[ScrapedContent],
1004    query: &str,
1005    required_documentation: &[String],
1006    previous_queries: &[String],
1007    accumulated_needed_urls: &[String],
1008) -> Result<ValidationResult, String> {
1009    let docs_preview = docs
1010        .iter()
1011        .enumerate()
1012        .take(10)
1013        .map(|(i, r)| {
1014            format!(
1015                "<doc index=\"{}\">\n  <title>{}</title>\n  <url>{}</url>\n</doc>",
1016                i + 1,
1017                r.title.clone().unwrap_or_else(|| "Untitled".to_string()),
1018                r.url
1019            )
1020        })
1021        .collect::<Vec<_>>()
1022        .join("\n");
1023
1024    let required_docs_formatted = required_documentation
1025        .iter()
1026        .map(|d| format!("  <item>{}</item>", d))
1027        .collect::<Vec<_>>()
1028        .join("\n");
1029
1030    let previous_queries_formatted = previous_queries
1031        .iter()
1032        .map(|q| format!("  <query>{}</query>", q))
1033        .collect::<Vec<_>>()
1034        .join("\n");
1035
1036    let accumulated_urls_formatted = accumulated_needed_urls
1037        .iter()
1038        .map(|u| format!("  <url>{}</url>", u))
1039        .collect::<Vec<_>>()
1040        .join("\n");
1041
1042    let system_prompt = r#"You are an expert search result validator. Your task is to evaluate whether search results adequately satisfy a documentation query.
1043
1044## Evaluation Criteria
1045
1046For each search result, assess:
10471. **Relevance**: Does the document directly address the required documentation topics?
10482. **Authority**: Is this an official source, documentation site, or authoritative reference?
10493. **Completeness**: Does it provide comprehensive information, not just passing mentions?
10504. **Freshness**: For technical docs, prefer current/maintained sources over outdated ones.
1051
1052## Decision Guidelines
1053
1054Mark results as SATISFIED when:
1055- All required documentation topics have at least one authoritative source
1056- The sources provide actionable, detailed information
1057- No critical gaps remain in coverage
1058
1059Suggest a NEW QUERY when:
1060- Key topics are missing from results
1061- Results are too general or tangential
1062- A more specific query would yield better results
1063- Previous queries haven't addressed certain requirements
1064
1065## Response Format
1066
1067Respond ONLY with valid XML in this exact structure:
1068
1069<validation>
1070  <is_satisfied>true or false</is_satisfied>
1071  <valid_docs>
1072    <doc><url>exact URL from results</url></doc>
1073  </valid_docs>
1074  <needed_urls>
1075    <url>specific URL pattern or domain still needed</url>
1076  </needed_urls>
1077  <new_query>refined search query if not satisfied, omit if satisfied</new_query>
1078  <reasoning>brief explanation of your assessment</reasoning>
1079</validation>"#;
1080
1081    let user_prompt = format!(
1082        r#"<search_context>
1083  <original_query>{}</original_query>
1084  <required_documentation>
1085{}
1086  </required_documentation>
1087  <previous_queries>
1088{}
1089  </previous_queries>
1090  <accumulated_needed_urls>
1091{}
1092  </accumulated_needed_urls>
1093</search_context>
1094
1095<current_results>
1096{}
1097</current_results>
1098
1099Evaluate these search results against the requirements. Which documents are valid and relevant? Is the documentation requirement satisfied? If not, what specific query would help find missing information?"#,
1100        query,
1101        if required_docs_formatted.is_empty() {
1102            "    <item>None specified</item>".to_string()
1103        } else {
1104            required_docs_formatted
1105        },
1106        if previous_queries_formatted.is_empty() {
1107            "    <query>None</query>".to_string()
1108        } else {
1109            previous_queries_formatted
1110        },
1111        if accumulated_urls_formatted.is_empty() {
1112            "    <url>None</url>".to_string()
1113        } else {
1114            accumulated_urls_formatted
1115        },
1116        docs_preview
1117    );
1118
1119    let input = LLMInput {
1120        model: model.clone(),
1121        messages: vec![
1122            LLMMessage {
1123                role: Role::System.to_string(),
1124                content: LLMMessageContent::String(system_prompt.to_string()),
1125            },
1126            LLMMessage {
1127                role: Role::User.to_string(),
1128                content: LLMMessageContent::String(user_prompt.to_string()),
1129            },
1130        ],
1131        max_tokens: 4000,
1132        tools: None,
1133        provider_options: None,
1134    };
1135
1136    let stakai_client = StakAIClient::new(llm_config)
1137        .map_err(|e| format!("Failed to create StakAI client: {}", e))?;
1138    let response = stakai_client.chat(input).await.map_err(|e| e.to_string())?;
1139
1140    let content = response.choices[0].message.content.to_string();
1141
1142    let validation = parse_validation_xml(&content, docs)?;
1143
1144    Ok(validation)
1145}
1146
1147fn parse_validation_xml(xml: &str, docs: &[ScrapedContent]) -> Result<ValidationResult, String> {
1148    let extract_tag = |tag: &str| -> Option<String> {
1149        let start_tag = format!("<{}>", tag);
1150        let end_tag = format!("</{}>", tag);
1151        xml.find(&start_tag).and_then(|start| {
1152            let content_start = start + start_tag.len();
1153            xml[content_start..]
1154                .find(&end_tag)
1155                .map(|end| xml[content_start..content_start + end].trim().to_string())
1156        })
1157    };
1158
1159    let extract_all_tags = |tag: &str| -> Vec<String> {
1160        let start_tag = format!("<{}>", tag);
1161        let end_tag = format!("</{}>", tag);
1162        let mut results = Vec::new();
1163        let mut search_start = 0;
1164
1165        while let Some(start) = xml[search_start..].find(&start_tag) {
1166            let abs_start = search_start + start + start_tag.len();
1167            if let Some(end) = xml[abs_start..].find(&end_tag) {
1168                results.push(xml[abs_start..abs_start + end].trim().to_string());
1169                search_start = abs_start + end + end_tag.len();
1170            } else {
1171                break;
1172            }
1173        }
1174        results
1175    };
1176
1177    let is_satisfied = extract_tag("is_satisfied")
1178        .map(|s| s.to_lowercase() == "true")
1179        .unwrap_or(false);
1180
1181    let valid_urls: Vec<String> = extract_all_tags("url")
1182        .into_iter()
1183        .filter(|url| docs.iter().any(|d| d.url == *url))
1184        .collect();
1185
1186    let valid_docs: Vec<ScrapedContent> = valid_urls
1187        .iter()
1188        .filter_map(|url| docs.iter().find(|d| d.url == *url).cloned())
1189        .collect();
1190
1191    let needed_urls: Vec<String> = extract_all_tags("url")
1192        .into_iter()
1193        .filter(|url| !docs.iter().any(|d| d.url == *url))
1194        .collect();
1195
1196    let new_query = extract_tag("new_query").filter(|q| !q.is_empty() && q != "omit if satisfied");
1197
1198    Ok(ValidationResult {
1199        is_satisfied,
1200        valid_docs,
1201        needed_urls,
1202        new_query,
1203    })
1204}
1205
1206fn get_search_model(
1207    llm_config: &LLMProviderConfig,
1208    eco_model: Option<LLMModel>,
1209    smart_model: Option<LLMModel>,
1210) -> LLMModel {
1211    let base_model = eco_model.or(smart_model);
1212
1213    match base_model {
1214        Some(LLMModel::OpenAI(_)) => LLMModel::OpenAI(OpenAIModel::O4Mini),
1215        Some(LLMModel::Anthropic(_)) => LLMModel::Anthropic(AnthropicModel::Claude45Haiku),
1216        Some(LLMModel::Gemini(_)) => LLMModel::Gemini(GeminiModel::Gemini3Flash),
1217        Some(LLMModel::Custom(model)) => LLMModel::Custom(model),
1218        None => {
1219            if llm_config.openai_config.is_some() {
1220                LLMModel::OpenAI(OpenAIModel::O4Mini)
1221            } else if llm_config.anthropic_config.is_some() {
1222                LLMModel::Anthropic(AnthropicModel::Claude45Haiku)
1223            } else if llm_config.gemini_config.is_some() {
1224                LLMModel::Gemini(GeminiModel::Gemini3Flash)
1225            } else {
1226                LLMModel::OpenAI(OpenAIModel::O4Mini)
1227            }
1228        }
1229    }
1230}