// stakpak_api/local/mod.rs
1use crate::local::context_managers::scratchpad_context_manager::{
2    ScratchpadContextManager, ScratchpadContextManagerOptions,
3};
4use crate::local::hooks::scratchpad_context_hook::{ContextHook, ContextHookOptions};
5use crate::{AgentProvider, ApiStreamError, GetMyAccountResponse};
6use crate::{ListRuleBook, models::*};
7use async_trait::async_trait;
8use futures_util::Stream;
9use libsql::{Builder, Connection};
10use reqwest::Error as ReqwestError;
11use reqwest::header::HeaderMap;
12use rmcp::model::Content;
13use stakpak_shared::hooks::{HookContext, HookRegistry, LifecycleEvent};
14use stakpak_shared::models::integrations::anthropic::{AnthropicConfig, AnthropicModel};
15use stakpak_shared::models::integrations::gemini::GeminiConfig;
16use stakpak_shared::models::integrations::openai::{
17    AgentModel, ChatCompletionChoice, ChatCompletionResponse, ChatCompletionStreamChoice,
18    ChatCompletionStreamResponse, ChatMessage, FinishReason, MessageContent, OpenAIConfig,
19    OpenAIModel, Tool,
20};
21use stakpak_shared::models::llm::{
22    GenerationDelta, LLMInput, LLMMessage, LLMMessageContent, LLMModel, LLMProviderConfig,
23    LLMStreamInput, chat, chat_stream,
24};
25use stakpak_shared::tls_client::{TlsClientConfig, create_tls_client};
26use std::pin::Pin;
27use std::sync::Arc;
28use tokio::sync::mpsc;
29use uuid::Uuid;
30
// Submodules: context window management, libsql persistence, and
// lifecycle hooks registered by `LocalClient::new`.
mod context_managers;
mod db;
mod hooks;

#[cfg(test)]
mod tests;
37
/// Fully local agent provider: persists sessions and checkpoints in a
/// libsql database and calls LLM providers directly. The optional
/// `stakpak_base_url` is used only for remote rulebook lookups.
#[derive(Clone, Debug)]
pub struct LocalClient {
    /// Open connection to the local libsql store.
    pub db: Connection,
    /// Remote API base URL with "/v1" appended (rulebooks only), if any.
    pub stakpak_base_url: Option<String>,
    pub anthropic_config: Option<AnthropicConfig>,
    pub openai_config: Option<OpenAIConfig>,
    pub gemini_config: Option<GeminiConfig>,
    /// Model used for the main agent loop.
    pub smart_model: LLMModel,
    /// Cheaper model, e.g. for session-title generation.
    pub eco_model: LLMModel,
    /// Fallback model used by the recovery path of the context hook.
    pub recovery_model: LLMModel,
    /// Lifecycle hooks executed around requests and inference.
    pub hook_registry: Option<Arc<HookRegistry<AgentState>>>,
}
50
/// Construction options for [`LocalClient`]. All fields are optional;
/// `LocalClient::new` supplies defaults for anything left as `None`.
pub struct LocalClientConfig {
    pub stakpak_base_url: Option<String>,
    /// Override for the on-disk database path (defaults to
    /// `$HOME/.stakpak/data/local.db`).
    pub store_path: Option<String>,
    pub anthropic_config: Option<AnthropicConfig>,
    pub openai_config: Option<OpenAIConfig>,
    pub gemini_config: Option<GeminiConfig>,
    /// Model names parsed via `LLMModel::from`; see `new` for defaults.
    pub smart_model: Option<String>,
    pub eco_model: Option<String>,
    pub recovery_model: Option<String>,
    /// Pre-seeded hook registry; `new` appends the scratchpad context hook.
    pub hook_registry: Option<HookRegistry<AgentState>>,
}
62
/// Internal message sent over the streaming channel by the background
/// completion task: either a generation delta to forward to the caller,
/// or an updated hook-context snapshot the stream adopts for later chunks.
/// The context is boxed to keep the enum small.
#[derive(Debug)]
enum StreamMessage {
    Delta(GenerationDelta),
    Ctx(Box<HookContext<AgentState>>),
}
68
/// Default database location, joined onto the user's home directory.
const DEFAULT_STORE_PATH: &str = ".stakpak/data/local.db";
/// System prompt for the main agent, compiled into the binary.
const SYSTEM_PROMPT: &str = include_str!("./prompts/agent.v1.txt");
/// Prompt used by `generate_session_title` with the eco model.
const TITLE_GENERATOR_PROMPT: &str = include_str!("./prompts/session_title_generator.v1.txt");
72
73impl LocalClient {
74    pub async fn new(config: LocalClientConfig) -> Result<Self, String> {
75        let default_store_path = std::env::home_dir()
76            .unwrap_or_default()
77            .join(DEFAULT_STORE_PATH);
78
79        if let Some(parent) = default_store_path.parent() {
80            std::fs::create_dir_all(parent)
81                .map_err(|e| format!("Failed to create database directory: {}", e))?;
82        }
83
84        let db = Builder::new_local(default_store_path.display().to_string())
85            .build()
86            .await
87            .map_err(|e| e.to_string())?;
88
89        let conn = db.connect().map_err(|e| e.to_string())?;
90
91        // Initialize database schema
92        db::init_schema(&conn).await?;
93
94        let smart_model = config
95            .smart_model
96            .map(LLMModel::from)
97            .unwrap_or(LLMModel::Anthropic(AnthropicModel::Claude45Sonnet));
98        let eco_model = config
99            .eco_model
100            .map(LLMModel::from)
101            .unwrap_or(LLMModel::Anthropic(AnthropicModel::Claude45Haiku));
102        let recovery_model = config
103            .recovery_model
104            .map(LLMModel::from)
105            .unwrap_or(LLMModel::OpenAI(OpenAIModel::GPT5));
106
107        // Add hooks
108        let mut hook_registry = config.hook_registry.unwrap_or_default();
109        hook_registry.register(
110            LifecycleEvent::BeforeInference,
111            Box::new(ContextHook::new(ContextHookOptions {
112                context_manager: Box::new(ScratchpadContextManager::new(
113                    ScratchpadContextManagerOptions {
114                        history_action_message_size_limit: 100,
115                        history_action_message_keep_last_n: 1,
116                        history_action_result_keep_last_n: 50,
117                    },
118                )),
119                smart_model: (smart_model.clone(), SYSTEM_PROMPT.to_string()),
120                eco_model: (eco_model.clone(), SYSTEM_PROMPT.to_string()),
121                recovery_model: (recovery_model.clone(), SYSTEM_PROMPT.to_string()),
122            })),
123        );
124
125        Ok(Self {
126            db: conn,
127            stakpak_base_url: config.stakpak_base_url.map(|url| url + "/v1"),
128            anthropic_config: config.anthropic_config,
129            gemini_config: config.gemini_config,
130            openai_config: config.openai_config,
131            smart_model,
132            eco_model,
133            recovery_model,
134            hook_registry: Some(Arc::new(hook_registry)),
135        })
136    }
137}
138
#[async_trait]
impl AgentProvider for LocalClient {
    /// Local mode has no real account; return a static "local" identity.
    async fn get_my_account(&self) -> Result<GetMyAccountResponse, String> {
        Ok(GetMyAccountResponse {
            username: "local".to_string(),
            id: "local".to_string(),
            first_name: "local".to_string(),
            last_name: "local".to_string(),
        })
    }

    /// List rulebooks from the remote Stakpak API when a base URL is
    /// configured; otherwise return an empty list (rulebooks are not
    /// stored locally).
    async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
        // No remote configured: nothing to list.
        if self.stakpak_base_url.is_none() {
            return Ok(vec![]);
        }

        let stakpak_base_url = self
            .stakpak_base_url
            .as_ref()
            .ok_or("Stakpak base URL not set")?;

        let url = format!("{}/rules", stakpak_base_url);

        let client = create_tls_client(
            TlsClientConfig::default().with_timeout(std::time::Duration::from_secs(300)),
        )?;

        let response = client
            .get(url)
            .send()
            .await
            .map_err(|e: ReqwestError| e.to_string())?;

        // Deserialize to a generic Value first so the raw payload can be
        // logged when it does not match the expected shape.
        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;

        match serde_json::from_value::<ListRulebooksResponse>(value.clone()) {
            Ok(response) => Ok(response.results),
            Err(e) => {
                eprintln!("Failed to deserialize response: {}", e);
                eprintln!("Raw response: {}", value);
                Err("Failed to deserialize response:".into())
            }
        }
    }

    /// Fetch a single rulebook by URI from the remote Stakpak API.
    /// Errors if no base URL is configured.
    async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
        let stakpak_base_url = self
            .stakpak_base_url
            .as_ref()
            .ok_or("Stakpak base URL not set")?;

        // The URI is percent-encoded since it is used as a path segment.
        let encoded_uri = urlencoding::encode(uri);

        let url = format!("{}/rules/{}", stakpak_base_url, encoded_uri);

        let client = create_tls_client(
            TlsClientConfig::default().with_timeout(std::time::Duration::from_secs(300)),
        )?;

        let response = client
            .get(&url)
            .send()
            .await
            .map_err(|e: ReqwestError| e.to_string())?;

        // As in list_rulebooks: parse loosely first to allow logging the
        // raw payload on shape mismatch.
        let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;

        match serde_json::from_value::<RuleBook>(value.clone()) {
            Ok(response) => Ok(response),
            Err(e) => {
                eprintln!("Failed to deserialize response: {}", e);
                eprintln!("Raw response: {}", value);
                Err("Failed to deserialize response:".into())
            }
        }
    }

    /// Not supported locally yet; always returns an error.
    async fn create_rulebook(
        &self,
        _uri: &str,
        _description: &str,
        _content: &str,
        _tags: Vec<String>,
        _visibility: Option<RuleBookVisibility>,
    ) -> Result<CreateRuleBookResponse, String> {
        // TODO: Implement create rulebook
        Err("Local provider does not support rulebooks yet".to_string())
    }

    /// Not supported locally yet; always returns an error.
    async fn delete_rulebook(&self, _uri: &str) -> Result<(), String> {
        // TODO: Implement delete rulebook
        Err("Local provider does not support rulebooks yet".to_string())
    }

    /// List all sessions from the local database.
    async fn list_agent_sessions(&self) -> Result<Vec<AgentSession>, String> {
        db::list_sessions(&self.db).await
    }

    /// Load a single session by id from the local database.
    async fn get_agent_session(&self, session_id: Uuid) -> Result<AgentSession, String> {
        db::get_session(&self.db, session_id).await
    }

    /// Stats are not tracked locally yet; returns defaults.
    async fn get_agent_session_stats(
        &self,
        _session_id: Uuid,
    ) -> Result<AgentSessionStats, String> {
        // TODO: Implement get agent session stats
        Ok(AgentSessionStats::default())
    }

    /// Load a specific checkpoint by id from the local database.
    async fn get_agent_checkpoint(&self, checkpoint_id: Uuid) -> Result<RunAgentOutput, String> {
        db::get_checkpoint(&self.db, checkpoint_id).await
    }

    /// Load a session's most recent checkpoint from the local database.
    async fn get_agent_session_latest_checkpoint(
        &self,
        session_id: Uuid,
    ) -> Result<RunAgentOutput, String> {
        db::get_latest_checkpoint(&self.db, session_id).await
    }

    /// Run one non-streaming completion turn: execute BeforeRequest hooks,
    /// resolve (or create) the session from the message history, run the
    /// agent, persist a new checkpoint, execute AfterRequest hooks, and
    /// return an OpenAI-style response whose `id` is the new checkpoint id.
    async fn chat_completion(
        &self,
        model: AgentModel,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
    ) -> Result<ChatCompletionResponse, String> {
        let mut ctx = HookContext::new(None, AgentState::new(model, messages, tools));

        if let Some(hook_registry) = &self.hook_registry {
            // `.ok()?` short-circuits when a hook rejects the request
            // (outcome semantics defined in stakpak_shared::hooks).
            hook_registry
                .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        // Resume an existing session (via embedded checkpoint id) or
        // create a new one from the incoming messages.
        let current_checkpoint = self.initialize_session(&ctx.state.messages).await?;
        ctx.set_session_id(current_checkpoint.session.id);

        let new_message = self.run_agent_completion(&mut ctx, None).await?;
        ctx.state.append_new_message(new_message.clone());

        // Persist the full post-turn message history as a new checkpoint.
        let result = self
            .update_session(&current_checkpoint, ctx.state.messages.clone())
            .await?;
        let checkpoint_created_at = result.checkpoint.created_at.timestamp() as u64;
        ctx.set_new_checkpoint_id(result.checkpoint.id);

        if let Some(hook_registry) = &self.hook_registry {
            hook_registry
                .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        Ok(ChatCompletionResponse {
            // NOTE(review): set via set_new_checkpoint_id above, so this
            // unwrap cannot fail on this path.
            id: ctx.new_checkpoint_id.unwrap().to_string(),
            object: "chat.completion".to_string(),
            created: checkpoint_created_at,
            model: ctx
                .state
                .llm_input
                .as_ref()
                .map(|llm_input| llm_input.model.clone().to_string())
                .unwrap_or_default(),
            choices: vec![ChatCompletionChoice {
                index: 0,
                // The assistant message appended above is always last.
                message: ctx.state.messages.last().cloned().unwrap(),
                logprobs: None,
                finish_reason: FinishReason::Stop,
            }],
            usage: ctx
                .state
                .llm_output
                .as_ref()
                .map(|u| u.usage.clone())
                .unwrap_or_default(),
            system_fingerprint: None,
            metadata: None,
        })
    }

    /// Streaming variant of `chat_completion`. A background task runs the
    /// agent and forwards `StreamMessage`s over an mpsc channel; the
    /// returned stream converts deltas into OpenAI-style chunks and adopts
    /// `Ctx` snapshots so later chunks carry updated model/usage info.
    /// Checkpoint ids are emitted inline as `<checkpoint_id>` markers so
    /// clients can resume the session on the next request.
    async fn chat_completion_stream(
        &self,
        model: AgentModel,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
        _headers: Option<HeaderMap>,
    ) -> Result<
        (
            Pin<
                Box<dyn Stream<Item = Result<ChatCompletionStreamResponse, ApiStreamError>> + Send>,
            >,
            Option<String>,
        ),
        String,
    > {
        let mut ctx = HookContext::new(None, AgentState::new(model, messages, tools));

        if let Some(hook_registry) = &self.hook_registry {
            hook_registry
                .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        let current_checkpoint = self.initialize_session(&ctx.state.messages).await?;
        ctx.set_session_id(current_checkpoint.session.id);

        let (tx, mut rx) = mpsc::channel::<Result<StreamMessage, String>>(100);

        // Announce the current checkpoint id before any model output.
        let _ = tx
            .send(Ok(StreamMessage::Delta(GenerationDelta::Content {
                content: format!(
                    "\n<checkpoint_id>{}</checkpoint_id>\n",
                    current_checkpoint.checkpoint.id
                ),
            })))
            .await;

        // The background task owns clones; send errors are ignored because
        // they only mean the consumer dropped the stream.
        let client = self.clone();
        let self_clone = self.clone();
        let mut ctx_clone = ctx.clone();
        tokio::spawn(async move {
            let result = client
                .run_agent_completion(&mut ctx_clone, Some(tx.clone()))
                .await;

            match result {
                Err(e) => {
                    let _ = tx.send(Err(e)).await;
                }
                Ok(new_message) => {
                    ctx_clone.state.append_new_message(new_message.clone());
                    // Push the updated context so the stream can report
                    // accurate model/usage on subsequent chunks.
                    let _ = tx
                        .send(Ok(StreamMessage::Ctx(Box::new(ctx_clone.clone()))))
                        .await;

                    let output = self_clone
                        .update_session(&current_checkpoint, ctx_clone.state.messages.clone())
                        .await;

                    match output {
                        Err(e) => {
                            let _ = tx.send(Err(e)).await;
                        }
                        Ok(output) => {
                            ctx_clone.set_new_checkpoint_id(output.checkpoint.id);
                            let _ = tx.send(Ok(StreamMessage::Ctx(Box::new(ctx_clone)))).await;
                            // Trailing marker with the NEW checkpoint id.
                            let _ = tx
                                .send(Ok(StreamMessage::Delta(GenerationDelta::Content {
                                    content: format!(
                                        "\n<checkpoint_id>{}</checkpoint_id>\n",
                                        output.checkpoint.id
                                    ),
                                })))
                                .await;
                        }
                    }
                }
            }
        });

        let hook_registry = self.hook_registry.clone();
        let stream = async_stream::stream! {
            while let Some(delta_result) = rx.recv().await {
                match delta_result {
                    Ok(delta) => match delta {
                            StreamMessage::Ctx(updated_ctx) => {
                                // Adopt the background task's context
                                // snapshot; not yielded to the caller.
                                ctx = *updated_ctx;
                            }
                            StreamMessage::Delta(delta) => {
                                yield Ok(ChatCompletionStreamResponse {
                                    id: ctx.request_id.to_string(),
                                    object: "chat.completion.chunk".to_string(),
                                    created: chrono::Utc::now().timestamp() as u64,
                                    model: ctx.state.llm_input.as_ref().map(|llm_input| llm_input.model.clone().to_string()).unwrap_or_default(),
                                    choices: vec![ChatCompletionStreamChoice {
                                        index: 0,
                                        delta: delta.into(),
                                        finish_reason: None,
                                    }],
                                    usage: ctx.state.llm_output.as_ref().map(|u| u.usage.clone()),
                                    metadata: None,
                                })
                            }
                        }
                    Err(e) => yield Err(ApiStreamError::Unknown(e)),
                }
            }

            // Channel closed: the turn is finished; run AfterRequest hooks.
            if let Some(hook_registry) = hook_registry {
                hook_registry
                    .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
                    .await
                    .map_err(|e| e.to_string())?
                    .ok()?;
            }
        };

        Ok((Box::pin(stream), None))
    }

    /// Cancellation is a no-op locally; streams end when dropped.
    async fn cancel_stream(&self, _request_id: String) -> Result<(), String> {
        Ok(())
    }

    /// Doc search is not implemented locally; returns no results.
    async fn search_docs(&self, _input: &SearchDocsRequest) -> Result<Vec<Content>, String> {
        // TODO: Implement search docs
        Ok(Vec::new())
    }

    /// Memory search is not implemented locally; returns no results.
    async fn search_memory(&self, _input: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
        // TODO: Implement search memory
        Ok(Vec::new())
    }

    /// Slack integration is not implemented locally; returns no results.
    async fn slack_read_messages(
        &self,
        _input: &SlackReadMessagesRequest,
    ) -> Result<Vec<Content>, String> {
        // TODO: Implement slack read messages
        Ok(Vec::new())
    }

    /// Slack integration is not implemented locally; returns no results.
    async fn slack_read_replies(
        &self,
        _input: &SlackReadRepliesRequest,
    ) -> Result<Vec<Content>, String> {
        // TODO: Implement slack read replies
        Ok(Vec::new())
    }

    /// Slack integration is not implemented locally; returns no results.
    async fn slack_send_message(
        &self,
        _input: &SlackSendMessageRequest,
    ) -> Result<Vec<Content>, String> {
        // TODO: Implement slack send message
        Ok(Vec::new())
    }

    /// Session memorization is not implemented locally; silently succeeds.
    async fn memorize_session(&self, _checkpoint_id: Uuid) -> Result<(), String> {
        // TODO: Implement memorize session
        Ok(())
    }
}
488
impl LocalClient {
    /// Bundle the per-provider configs into a single LLM provider config.
    fn get_llm_config(&self) -> LLMProviderConfig {
        LLMProviderConfig {
            anthropic_config: self.anthropic_config.clone(),
            openai_config: self.openai_config.clone(),
            gemini_config: self.gemini_config.clone(),
        }
    }

    /// Run one inference turn. Executes Before/AfterInference hooks around
    /// the LLM call; the BeforeInference context hook must have populated
    /// `ctx.state.llm_input` or this errors. When `stream_channel_tx` is
    /// given, deltas are forwarded to it while the full response is still
    /// collected and returned.
    async fn run_agent_completion(
        &self,
        ctx: &mut HookContext<AgentState>,
        stream_channel_tx: Option<mpsc::Sender<Result<StreamMessage, String>>>,
    ) -> Result<ChatMessage, String> {
        if let Some(hook_registry) = &self.hook_registry {
            hook_registry
                .execute_hooks(ctx, &LifecycleEvent::BeforeInference)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        // Hooks are expected to have prepared the LLM input.
        let input = if let Some(llm_input) = ctx.state.llm_input.clone() {
            llm_input
        } else {
            return Err(
                "Run agent completion: LLM input not found, make sure to register a context hook before inference"
                    .to_string(),
            );
        };

        let llm_config = self.get_llm_config();

        let (response_message, usage) = if let Some(tx) = stream_channel_tx {
            // Streaming path: bridge provider deltas through an internal
            // channel onto the caller's StreamMessage channel.
            let (internal_tx, mut internal_rx) = mpsc::channel::<GenerationDelta>(100);
            let input = LLMStreamInput {
                model: input.model,
                messages: input.messages,
                max_tokens: input.max_tokens,
                tools: input.tools,
                stream_channel_tx: internal_tx,
            };

            let chat_future = async move {
                chat_stream(&llm_config, input)
                    .await
                    .map_err(|e| e.to_string())
            };

            // Forward deltas until the provider side closes; stop early if
            // the consumer has dropped the receiving end.
            let receive_future = async move {
                while let Some(delta) = internal_rx.recv().await {
                    if tx.send(Ok(StreamMessage::Delta(delta))).await.is_err() {
                        break;
                    }
                }
            };

            // Drive generation and forwarding concurrently; join so the
            // forwarder drains every delta before we inspect the result.
            let (chat_result, _) = tokio::join!(chat_future, receive_future);
            let response = chat_result?;
            (response.choices[0].message.clone(), response.usage)
        } else {
            // Non-streaming path: a single blocking-style chat call.
            let response = chat(&llm_config, input).await.map_err(|e| e.to_string())?;
            (response.choices[0].message.clone(), response.usage)
        };

        ctx.state.set_llm_output(response_message, usage);

        if let Some(hook_registry) = &self.hook_registry {
            hook_registry
                .execute_hooks(ctx, &LifecycleEvent::AfterInference)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        }

        // Re-read from state: AfterInference hooks may have modified it.
        let llm_output = ctx
            .state
            .llm_output
            .as_ref()
            .ok_or_else(|| "LLM output is missing from state".to_string())?;

        Ok(ChatMessage::from(llm_output))
    }

    /// Resolve the session for this request. If the last server message
    /// embeds a checkpoint id, resume that checkpoint; otherwise create a
    /// fresh session (with a generated title) plus a root checkpoint
    /// holding the incoming messages.
    async fn initialize_session(&self, messages: &[ChatMessage]) -> Result<RunAgentOutput, String> {
        // 1. Validate input
        if messages.is_empty() {
            return Err("At least one message is required".to_string());
        }

        // 2. Extract session/checkpoint ID or create new session
        let checkpoint_id = ChatMessage::last_server_message(messages).and_then(|message| {
            message
                .content
                .as_ref()
                .and_then(|content| content.extract_checkpoint_id())
        });

        let current_checkpoint = if let Some(checkpoint_id) = checkpoint_id {
            db::get_checkpoint(&self.db, checkpoint_id).await?
        } else {
            // Title generation uses the eco model and the incoming messages.
            let title = self.generate_session_title(messages).await?;

            // Create new session
            let session_id = Uuid::new_v4();
            let now = chrono::Utc::now();
            let session = AgentSession {
                id: session_id,
                title,
                agent_id: AgentID::PabloV1,
                visibility: AgentSessionVisibility::Private,
                created_at: now,
                updated_at: now,
                checkpoints: vec![],
            };
            db::create_session(&self.db, &session).await?;

            // Create initial checkpoint (root)
            let checkpoint_id = Uuid::new_v4();
            let checkpoint = AgentCheckpointListItem {
                id: checkpoint_id,
                status: AgentStatus::Complete,
                execution_depth: 0,
                parent: None,
                created_at: now,
                updated_at: now,
            };
            let initial_state = AgentOutput::PabloV1 {
                messages: messages.to_vec(),
                node_states: serde_json::json!({}),
            };
            db::create_checkpoint(&self.db, session_id, &checkpoint, &initial_state).await?;

            // Re-read so the returned value matches the stored row.
            db::get_checkpoint(&self.db, checkpoint_id).await?
        };

        Ok(current_checkpoint)
    }

    /// Persist the post-turn message history as a new child checkpoint of
    /// `checkpoint_info` (execution depth + 1) and return the new
    /// checkpoint alongside the unchanged session.
    async fn update_session(
        &self,
        checkpoint_info: &RunAgentOutput,
        new_messages: Vec<ChatMessage>,
    ) -> Result<RunAgentOutput, String> {
        let now = chrono::Utc::now();
        let complete_checkpoint = AgentCheckpointListItem {
            id: Uuid::new_v4(),
            status: AgentStatus::Complete,
            execution_depth: checkpoint_info.checkpoint.execution_depth + 1,
            parent: Some(AgentParentCheckpoint {
                id: checkpoint_info.checkpoint.id,
            }),
            created_at: now,
            updated_at: now,
        };

        // Carry forward the previous output, replacing only the messages.
        let mut new_state = checkpoint_info.output.clone();
        new_state.set_messages(new_messages);

        db::create_checkpoint(
            &self.db,
            checkpoint_info.session.id,
            &complete_checkpoint,
            &new_state,
        )
        .await?;

        Ok(RunAgentOutput {
            checkpoint: complete_checkpoint,
            session: checkpoint_info.session.clone(),
            output: new_state,
        })
    }

    /// Generate a short session title with the eco model: the title
    /// prompt as system message plus the concatenated text of the
    /// incoming messages as the user message (capped at 100 tokens).
    async fn generate_session_title(&self, messages: &[ChatMessage]) -> Result<String, String> {
        let llm_config = self.get_llm_config();
        let llm_model = self.eco_model.clone();

        let messages = vec![
            LLMMessage {
                role: "system".to_string(),
                content: LLMMessageContent::String(TITLE_GENERATOR_PROMPT.into()),
            },
            LLMMessage {
                role: "user".to_string(),
                content: LLMMessageContent::String(
                    // Messages without content contribute an empty string.
                    messages
                        .iter()
                        .map(|msg| {
                            msg.content
                                .as_ref()
                                .unwrap_or(&MessageContent::String("".to_string()))
                                .to_string()
                        })
                        .collect(),
                ),
            },
        ];

        let input = LLMInput {
            model: llm_model,
            messages,
            max_tokens: 100,
            tools: None,
        };

        let response = chat(&llm_config, input).await.map_err(|e| e.to_string())?;

        Ok(response.choices[0].message.content.to_string())
    }
}