1use crate::local::context_managers::scratchpad_context_manager::{
2 ScratchpadContextManager, ScratchpadContextManagerOptions,
3};
4use crate::local::hooks::scratchpad_context_hook::{ContextHook, ContextHookOptions};
5use crate::{AgentProvider, ApiStreamError, GetMyAccountResponse};
6use crate::{ListRuleBook, models::*};
7use async_trait::async_trait;
8use futures_util::Stream;
9use libsql::{Builder, Connection};
10use reqwest::Error as ReqwestError;
11use reqwest::header::HeaderMap;
12use rmcp::model::Content;
13use stakpak_shared::hooks::{HookContext, HookRegistry, LifecycleEvent};
14use stakpak_shared::models::integrations::anthropic::{AnthropicConfig, AnthropicModel};
15use stakpak_shared::models::integrations::gemini::GeminiConfig;
16use stakpak_shared::models::integrations::openai::{
17 AgentModel, ChatCompletionChoice, ChatCompletionResponse, ChatCompletionStreamChoice,
18 ChatCompletionStreamResponse, ChatMessage, FinishReason, MessageContent, OpenAIConfig,
19 OpenAIModel, Tool,
20};
21use stakpak_shared::models::llm::{
22 GenerationDelta, LLMInput, LLMMessage, LLMMessageContent, LLMModel, LLMProviderConfig,
23 LLMStreamInput, chat, chat_stream,
24};
25use stakpak_shared::tls_client::{TlsClientConfig, create_tls_client};
26use std::pin::Pin;
27use std::sync::Arc;
28use tokio::sync::mpsc;
29use uuid::Uuid;
30
31mod context_managers;
32mod db;
33mod hooks;
34
35#[cfg(test)]
36mod tests;
37
/// Agent provider that runs on the local machine: sessions and checkpoints are persisted in a
/// libsql database and inference is sent directly to the configured LLM providers.
#[derive(Clone, Debug)]
pub struct LocalClient {
    /// Connection to the local libsql database that stores sessions and checkpoints.
    pub db: Connection,
    /// Stakpak API base URL (with `/v1` appended by `LocalClient::new`); read by the rulebook
    /// lookups. `None` means no remote API is configured.
    pub stakpak_base_url: Option<String>,
    /// Credentials/settings forwarded to the Anthropic provider, if configured.
    pub anthropic_config: Option<AnthropicConfig>,
    /// Credentials/settings forwarded to the OpenAI provider, if configured.
    pub openai_config: Option<OpenAIConfig>,
    /// Credentials/settings forwarded to the Gemini provider, if configured.
    pub gemini_config: Option<GeminiConfig>,
    /// Model handed to the context hook as the "smart" tier.
    pub smart_model: LLMModel,
    /// Model handed to the context hook as the "eco" tier; also used to generate session titles.
    pub eco_model: LLMModel,
    /// Model handed to the context hook as the "recovery" tier.
    pub recovery_model: LLMModel,
    /// Lifecycle hooks run around requests and inference; `None` skips hook execution.
    pub hook_registry: Option<Arc<HookRegistry<AgentState>>>,
}
50
/// Inputs used by `LocalClient::new` to build a [`LocalClient`].
pub struct LocalClientConfig {
    /// Stakpak API base URL without the `/v1` suffix.
    pub stakpak_base_url: Option<String>,
    /// Requested location of the local database file.
    pub store_path: Option<String>,
    /// Anthropic provider configuration, if available.
    pub anthropic_config: Option<AnthropicConfig>,
    /// OpenAI provider configuration, if available.
    pub openai_config: Option<OpenAIConfig>,
    /// Gemini provider configuration, if available.
    pub gemini_config: Option<GeminiConfig>,
    /// Smart-tier model name, converted with `LLMModel::from`; a default is chosen when `None`.
    pub smart_model: Option<String>,
    /// Eco-tier model name, converted with `LLMModel::from`; a default is chosen when `None`.
    pub eco_model: Option<String>,
    /// Recovery-tier model name, converted with `LLMModel::from`; a default is chosen when `None`.
    pub recovery_model: Option<String>,
    /// Caller-supplied hooks; `LocalClient::new` registers its own context hook on top of them.
    pub hook_registry: Option<HookRegistry<AgentState>>,
}
62
/// Messages sent from the background completion task to the response stream built in
/// `chat_completion_stream`.
#[derive(Debug)]
enum StreamMessage {
    /// A generation delta to forward to the caller as a stream chunk.
    Delta(GenerationDelta),
    /// An updated hook context that replaces the stream's copy (boxed, presumably to keep the
    /// enum small).
    Ctx(Box<HookContext<AgentState>>),
}
68
69const DEFAULT_STORE_PATH: &str = ".stakpak/data/local.db";
70const SYSTEM_PROMPT: &str = include_str!("./prompts/agent.v1.txt");
71const TITLE_GENERATOR_PROMPT: &str = include_str!("./prompts/session_title_generator.v1.txt");
72
73impl LocalClient {
74 pub async fn new(config: LocalClientConfig) -> Result<Self, String> {
75 let default_store_path = std::env::home_dir()
76 .unwrap_or_default()
77 .join(DEFAULT_STORE_PATH);
78
79 if let Some(parent) = default_store_path.parent() {
80 std::fs::create_dir_all(parent)
81 .map_err(|e| format!("Failed to create database directory: {}", e))?;
82 }
83
84 let db = Builder::new_local(default_store_path.display().to_string())
85 .build()
86 .await
87 .map_err(|e| e.to_string())?;
88
89 let conn = db.connect().map_err(|e| e.to_string())?;
90
91 db::init_schema(&conn).await?;
93
94 let smart_model = config
95 .smart_model
96 .map(LLMModel::from)
97 .unwrap_or(LLMModel::Anthropic(AnthropicModel::Claude45Sonnet));
98 let eco_model = config
99 .eco_model
100 .map(LLMModel::from)
101 .unwrap_or(LLMModel::Anthropic(AnthropicModel::Claude45Haiku));
102 let recovery_model = config
103 .recovery_model
104 .map(LLMModel::from)
105 .unwrap_or(LLMModel::OpenAI(OpenAIModel::GPT5));
106
107 let mut hook_registry = config.hook_registry.unwrap_or_default();
109 hook_registry.register(
110 LifecycleEvent::BeforeInference,
111 Box::new(ContextHook::new(ContextHookOptions {
112 context_manager: Box::new(ScratchpadContextManager::new(
113 ScratchpadContextManagerOptions {
114 history_action_message_size_limit: 100,
115 history_action_message_keep_last_n: 1,
116 history_action_result_keep_last_n: 50,
117 },
118 )),
119 smart_model: (smart_model.clone(), SYSTEM_PROMPT.to_string()),
120 eco_model: (eco_model.clone(), SYSTEM_PROMPT.to_string()),
121 recovery_model: (recovery_model.clone(), SYSTEM_PROMPT.to_string()),
122 })),
123 );
124
125 Ok(Self {
126 db: conn,
127 stakpak_base_url: config.stakpak_base_url.map(|url| url + "/v1"),
128 anthropic_config: config.anthropic_config,
129 gemini_config: config.gemini_config,
130 openai_config: config.openai_config,
131 smart_model,
132 eco_model,
133 recovery_model,
134 hook_registry: Some(Arc::new(hook_registry)),
135 })
136 }
137}
138
139#[async_trait]
140impl AgentProvider for LocalClient {
141 async fn get_my_account(&self) -> Result<GetMyAccountResponse, String> {
142 Ok(GetMyAccountResponse {
143 username: "local".to_string(),
144 id: "local".to_string(),
145 first_name: "local".to_string(),
146 last_name: "local".to_string(),
147 })
148 }
149
150 async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
151 if self.stakpak_base_url.is_none() {
152 return Ok(vec![]);
153 }
154
155 let stakpak_base_url = self
156 .stakpak_base_url
157 .as_ref()
158 .ok_or("Stakpak base URL not set")?;
159
160 let url = format!("{}/rules", stakpak_base_url);
161
162 let client = create_tls_client(
163 TlsClientConfig::default().with_timeout(std::time::Duration::from_secs(300)),
164 )?;
165
166 let response = client
167 .get(url)
168 .send()
169 .await
170 .map_err(|e: ReqwestError| e.to_string())?;
171
172 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
173
174 match serde_json::from_value::<ListRulebooksResponse>(value.clone()) {
175 Ok(response) => Ok(response.results),
176 Err(e) => {
177 eprintln!("Failed to deserialize response: {}", e);
178 eprintln!("Raw response: {}", value);
179 Err("Failed to deserialize response:".into())
180 }
181 }
182 }
183
184 async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
185 let stakpak_base_url = self
186 .stakpak_base_url
187 .as_ref()
188 .ok_or("Stakpak base URL not set")?;
189
190 let encoded_uri = urlencoding::encode(uri);
191
192 let url = format!("{}/rules/{}", stakpak_base_url, encoded_uri);
193
194 let client = create_tls_client(
195 TlsClientConfig::default().with_timeout(std::time::Duration::from_secs(300)),
196 )?;
197
198 let response = client
199 .get(&url)
200 .send()
201 .await
202 .map_err(|e: ReqwestError| e.to_string())?;
203
204 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
205
206 match serde_json::from_value::<RuleBook>(value.clone()) {
207 Ok(response) => Ok(response),
208 Err(e) => {
209 eprintln!("Failed to deserialize response: {}", e);
210 eprintln!("Raw response: {}", value);
211 Err("Failed to deserialize response:".into())
212 }
213 }
214 }
215
216 async fn create_rulebook(
217 &self,
218 _uri: &str,
219 _description: &str,
220 _content: &str,
221 _tags: Vec<String>,
222 _visibility: Option<RuleBookVisibility>,
223 ) -> Result<CreateRuleBookResponse, String> {
224 Err("Local provider does not support rulebooks yet".to_string())
226 }
227
228 async fn delete_rulebook(&self, _uri: &str) -> Result<(), String> {
229 Err("Local provider does not support rulebooks yet".to_string())
231 }
232
233 async fn list_agent_sessions(&self) -> Result<Vec<AgentSession>, String> {
234 db::list_sessions(&self.db).await
235 }
236
237 async fn get_agent_session(&self, session_id: Uuid) -> Result<AgentSession, String> {
238 db::get_session(&self.db, session_id).await
239 }
240
241 async fn get_agent_session_stats(
242 &self,
243 _session_id: Uuid,
244 ) -> Result<AgentSessionStats, String> {
245 Ok(AgentSessionStats::default())
247 }
248
249 async fn get_agent_checkpoint(&self, checkpoint_id: Uuid) -> Result<RunAgentOutput, String> {
250 db::get_checkpoint(&self.db, checkpoint_id).await
251 }
252
253 async fn get_agent_session_latest_checkpoint(
254 &self,
255 session_id: Uuid,
256 ) -> Result<RunAgentOutput, String> {
257 db::get_latest_checkpoint(&self.db, session_id).await
258 }
259
260 async fn chat_completion(
261 &self,
262 model: AgentModel,
263 messages: Vec<ChatMessage>,
264 tools: Option<Vec<Tool>>,
265 ) -> Result<ChatCompletionResponse, String> {
266 let mut ctx = HookContext::new(None, AgentState::new(model, messages, tools));
267
268 if let Some(hook_registry) = &self.hook_registry {
269 hook_registry
270 .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
271 .await
272 .map_err(|e| e.to_string())?
273 .ok()?;
274 }
275
276 let current_checkpoint = self.initialize_session(&ctx.state.messages).await?;
277 ctx.set_session_id(current_checkpoint.session.id);
278
279 let new_message = self.run_agent_completion(&mut ctx, None).await?;
280 ctx.state.append_new_message(new_message.clone());
281
282 let result = self
283 .update_session(¤t_checkpoint, ctx.state.messages.clone())
284 .await?;
285 let checkpoint_created_at = result.checkpoint.created_at.timestamp() as u64;
286 ctx.set_new_checkpoint_id(result.checkpoint.id);
287
288 if let Some(hook_registry) = &self.hook_registry {
289 hook_registry
290 .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
291 .await
292 .map_err(|e| e.to_string())?
293 .ok()?;
294 }
295
296 Ok(ChatCompletionResponse {
297 id: ctx.new_checkpoint_id.unwrap().to_string(),
298 object: "chat.completion".to_string(),
299 created: checkpoint_created_at,
300 model: ctx
301 .state
302 .llm_input
303 .as_ref()
304 .map(|llm_input| llm_input.model.clone().to_string())
305 .unwrap_or_default(),
306 choices: vec![ChatCompletionChoice {
307 index: 0,
308 message: ctx.state.messages.last().cloned().unwrap(),
309 logprobs: None,
310 finish_reason: FinishReason::Stop,
311 }],
312 usage: ctx
313 .state
314 .llm_output
315 .as_ref()
316 .map(|u| u.usage.clone())
317 .unwrap_or_default(),
318 system_fingerprint: None,
319 metadata: None,
320 })
321 }
322
323 async fn chat_completion_stream(
324 &self,
325 model: AgentModel,
326 messages: Vec<ChatMessage>,
327 tools: Option<Vec<Tool>>,
328 _headers: Option<HeaderMap>,
329 ) -> Result<
330 (
331 Pin<
332 Box<dyn Stream<Item = Result<ChatCompletionStreamResponse, ApiStreamError>> + Send>,
333 >,
334 Option<String>,
335 ),
336 String,
337 > {
338 let mut ctx = HookContext::new(None, AgentState::new(model, messages, tools));
339
340 if let Some(hook_registry) = &self.hook_registry {
341 hook_registry
342 .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
343 .await
344 .map_err(|e| e.to_string())?
345 .ok()?;
346 }
347
348 let current_checkpoint = self.initialize_session(&ctx.state.messages).await?;
349 ctx.set_session_id(current_checkpoint.session.id);
350
351 let (tx, mut rx) = mpsc::channel::<Result<StreamMessage, String>>(100);
352
353 let _ = tx
354 .send(Ok(StreamMessage::Delta(GenerationDelta::Content {
355 content: format!(
356 "\n<checkpoint_id>{}</checkpoint_id>\n",
357 current_checkpoint.checkpoint.id
358 ),
359 })))
360 .await;
361
362 let client = self.clone();
363 let self_clone = self.clone();
364 let mut ctx_clone = ctx.clone();
365 tokio::spawn(async move {
366 let result = client
367 .run_agent_completion(&mut ctx_clone, Some(tx.clone()))
368 .await;
369
370 match result {
371 Err(e) => {
372 let _ = tx.send(Err(e)).await;
373 }
374 Ok(new_message) => {
375 ctx_clone.state.append_new_message(new_message.clone());
376 let _ = tx
377 .send(Ok(StreamMessage::Ctx(Box::new(ctx_clone.clone()))))
378 .await;
379
380 let output = self_clone
381 .update_session(¤t_checkpoint, ctx_clone.state.messages.clone())
382 .await;
383
384 match output {
385 Err(e) => {
386 let _ = tx.send(Err(e)).await;
387 }
388 Ok(output) => {
389 ctx_clone.set_new_checkpoint_id(output.checkpoint.id);
390 let _ = tx.send(Ok(StreamMessage::Ctx(Box::new(ctx_clone)))).await;
391 let _ = tx
392 .send(Ok(StreamMessage::Delta(GenerationDelta::Content {
393 content: format!(
394 "\n<checkpoint_id>{}</checkpoint_id>\n",
395 output.checkpoint.id
396 ),
397 })))
398 .await;
399 }
400 }
401 }
402 }
403 });
404
405 let hook_registry = self.hook_registry.clone();
406 let stream = async_stream::stream! {
407 while let Some(delta_result) = rx.recv().await {
408 match delta_result {
409 Ok(delta) => match delta {
410 StreamMessage::Ctx(updated_ctx) => {
411 ctx = *updated_ctx;
412 }
413 StreamMessage::Delta(delta) => {
414 yield Ok(ChatCompletionStreamResponse {
415 id: ctx.request_id.to_string(),
416 object: "chat.completion.chunk".to_string(),
417 created: chrono::Utc::now().timestamp() as u64,
418 model: ctx.state.llm_input.as_ref().map(|llm_input| llm_input.model.clone().to_string()).unwrap_or_default(),
419 choices: vec![ChatCompletionStreamChoice {
420 index: 0,
421 delta: delta.into(),
422 finish_reason: None,
423 }],
424 usage: ctx.state.llm_output.as_ref().map(|u| u.usage.clone()),
425 metadata: None,
426 })
427 }
428 }
429 Err(e) => yield Err(ApiStreamError::Unknown(e)),
430 }
431 }
432
433 if let Some(hook_registry) = hook_registry {
434 hook_registry
435 .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
436 .await
437 .map_err(|e| e.to_string())?
438 .ok()?;
439 }
440 };
441
442 Ok((Box::pin(stream), None))
443 }
444
445 async fn cancel_stream(&self, _request_id: String) -> Result<(), String> {
446 Ok(())
447 }
448
449 async fn search_docs(&self, _input: &SearchDocsRequest) -> Result<Vec<Content>, String> {
450 Ok(Vec::new())
452 }
453
454 async fn search_memory(&self, _input: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
455 Ok(Vec::new())
457 }
458
459 async fn slack_read_messages(
460 &self,
461 _input: &SlackReadMessagesRequest,
462 ) -> Result<Vec<Content>, String> {
463 Ok(Vec::new())
465 }
466
467 async fn slack_read_replies(
468 &self,
469 _input: &SlackReadRepliesRequest,
470 ) -> Result<Vec<Content>, String> {
471 Ok(Vec::new())
473 }
474
475 async fn slack_send_message(
476 &self,
477 _input: &SlackSendMessageRequest,
478 ) -> Result<Vec<Content>, String> {
479 Ok(Vec::new())
481 }
482
483 async fn memorize_session(&self, _checkpoint_id: Uuid) -> Result<(), String> {
484 Ok(())
486 }
487}
488
489impl LocalClient {
490 fn get_llm_config(&self) -> LLMProviderConfig {
491 LLMProviderConfig {
492 anthropic_config: self.anthropic_config.clone(),
493 openai_config: self.openai_config.clone(),
494 gemini_config: self.gemini_config.clone(),
495 }
496 }
497
498 async fn run_agent_completion(
499 &self,
500 ctx: &mut HookContext<AgentState>,
501 stream_channel_tx: Option<mpsc::Sender<Result<StreamMessage, String>>>,
502 ) -> Result<ChatMessage, String> {
503 if let Some(hook_registry) = &self.hook_registry {
504 hook_registry
505 .execute_hooks(ctx, &LifecycleEvent::BeforeInference)
506 .await
507 .map_err(|e| e.to_string())?
508 .ok()?;
509 }
510
511 let input = if let Some(llm_input) = ctx.state.llm_input.clone() {
512 llm_input
513 } else {
514 return Err(
515 "Run agent completion: LLM input not found, make sure to register a context hook before inference"
516 .to_string(),
517 );
518 };
519
520 let llm_config = self.get_llm_config();
521
522 let (response_message, usage) = if let Some(tx) = stream_channel_tx {
523 let (internal_tx, mut internal_rx) = mpsc::channel::<GenerationDelta>(100);
524 let input = LLMStreamInput {
525 model: input.model,
526 messages: input.messages,
527 max_tokens: input.max_tokens,
528 tools: input.tools,
529 stream_channel_tx: internal_tx,
530 };
531
532 let chat_future = async move {
533 chat_stream(&llm_config, input)
534 .await
535 .map_err(|e| e.to_string())
536 };
537
538 let receive_future = async move {
539 while let Some(delta) = internal_rx.recv().await {
540 if tx.send(Ok(StreamMessage::Delta(delta))).await.is_err() {
541 break;
542 }
543 }
544 };
545
546 let (chat_result, _) = tokio::join!(chat_future, receive_future);
547 let response = chat_result?;
548 (response.choices[0].message.clone(), response.usage)
549 } else {
550 let response = chat(&llm_config, input).await.map_err(|e| e.to_string())?;
551 (response.choices[0].message.clone(), response.usage)
552 };
553
554 ctx.state.set_llm_output(response_message, usage);
555
556 if let Some(hook_registry) = &self.hook_registry {
557 hook_registry
558 .execute_hooks(ctx, &LifecycleEvent::AfterInference)
559 .await
560 .map_err(|e| e.to_string())?
561 .ok()?;
562 }
563
564 let llm_output = ctx
565 .state
566 .llm_output
567 .as_ref()
568 .ok_or_else(|| "LLM output is missing from state".to_string())?;
569
570 Ok(ChatMessage::from(llm_output))
571 }
572
573 async fn initialize_session(&self, messages: &[ChatMessage]) -> Result<RunAgentOutput, String> {
574 if messages.is_empty() {
576 return Err("At least one message is required".to_string());
577 }
578
579 let checkpoint_id = ChatMessage::last_server_message(messages).and_then(|message| {
581 message
582 .content
583 .as_ref()
584 .and_then(|content| content.extract_checkpoint_id())
585 });
586
587 let current_checkpoint = if let Some(checkpoint_id) = checkpoint_id {
588 db::get_checkpoint(&self.db, checkpoint_id).await?
589 } else {
590 let title = self.generate_session_title(messages).await?;
591
592 let session_id = Uuid::new_v4();
594 let now = chrono::Utc::now();
595 let session = AgentSession {
596 id: session_id,
597 title,
598 agent_id: AgentID::PabloV1,
599 visibility: AgentSessionVisibility::Private,
600 created_at: now,
601 updated_at: now,
602 checkpoints: vec![],
603 };
604 db::create_session(&self.db, &session).await?;
605
606 let checkpoint_id = Uuid::new_v4();
608 let checkpoint = AgentCheckpointListItem {
609 id: checkpoint_id,
610 status: AgentStatus::Complete,
611 execution_depth: 0,
612 parent: None,
613 created_at: now,
614 updated_at: now,
615 };
616 let initial_state = AgentOutput::PabloV1 {
617 messages: messages.to_vec(),
618 node_states: serde_json::json!({}),
619 };
620 db::create_checkpoint(&self.db, session_id, &checkpoint, &initial_state).await?;
621
622 db::get_checkpoint(&self.db, checkpoint_id).await?
623 };
624
625 Ok(current_checkpoint)
626 }
627
628 async fn update_session(
629 &self,
630 checkpoint_info: &RunAgentOutput,
631 new_messages: Vec<ChatMessage>,
632 ) -> Result<RunAgentOutput, String> {
633 let now = chrono::Utc::now();
634 let complete_checkpoint = AgentCheckpointListItem {
635 id: Uuid::new_v4(),
636 status: AgentStatus::Complete,
637 execution_depth: checkpoint_info.checkpoint.execution_depth + 1,
638 parent: Some(AgentParentCheckpoint {
639 id: checkpoint_info.checkpoint.id,
640 }),
641 created_at: now,
642 updated_at: now,
643 };
644
645 let mut new_state = checkpoint_info.output.clone();
646 new_state.set_messages(new_messages);
647
648 db::create_checkpoint(
649 &self.db,
650 checkpoint_info.session.id,
651 &complete_checkpoint,
652 &new_state,
653 )
654 .await?;
655
656 Ok(RunAgentOutput {
657 checkpoint: complete_checkpoint,
658 session: checkpoint_info.session.clone(),
659 output: new_state,
660 })
661 }
662
663 async fn generate_session_title(&self, messages: &[ChatMessage]) -> Result<String, String> {
664 let llm_config = self.get_llm_config();
665 let llm_model = self.eco_model.clone();
666
667 let messages = vec![
668 LLMMessage {
669 role: "system".to_string(),
670 content: LLMMessageContent::String(TITLE_GENERATOR_PROMPT.into()),
671 },
672 LLMMessage {
673 role: "user".to_string(),
674 content: LLMMessageContent::String(
675 messages
676 .iter()
677 .map(|msg| {
678 msg.content
679 .as_ref()
680 .unwrap_or(&MessageContent::String("".to_string()))
681 .to_string()
682 })
683 .collect(),
684 ),
685 },
686 ];
687
688 let input = LLMInput {
689 model: llm_model,
690 messages,
691 max_tokens: 100,
692 tools: None,
693 };
694
695 let response = chat(&llm_config, input).await.map_err(|e| e.to_string())?;
696
697 Ok(response.choices[0].message.content.to_string())
698 }
699}