Skip to main content

stakpak_api/client/
provider.rs

1//! AgentProvider trait implementation for AgentClient
2//!
3//! Implements the unified provider interface with:
4//! - Stakpak-first routing when API key is present
5//! - Local fallback when Stakpak is unavailable
6//! - Hook registry integration for lifecycle events
7
8use crate::AgentProvider;
9use crate::models::*;
10use crate::storage::{
11    CreateCheckpointRequest as StorageCreateCheckpointRequest,
12    CreateSessionRequest as StorageCreateSessionRequest,
13    UpdateSessionRequest as StorageUpdateSessionRequest,
14};
15use async_trait::async_trait;
16use futures_util::Stream;
17use reqwest::header::HeaderMap;
18use rmcp::model::Content;
19use stakai::Model;
20use stakpak_shared::hooks::{HookContext, LifecycleEvent};
21use stakpak_shared::models::integrations::openai::{
22    ChatCompletionChoice, ChatCompletionResponse, ChatCompletionStreamChoice,
23    ChatCompletionStreamResponse, ChatMessage, FinishReason, MessageContent, Role, Tool,
24};
25use stakpak_shared::models::llm::{
26    GenerationDelta, LLMInput, LLMMessage, LLMMessageContent, LLMStreamInput,
27};
28use std::pin::Pin;
29use tokio::sync::mpsc;
30use uuid::Uuid;
31
/// Lightweight session info returned by `initialize_session` / `save_checkpoint`.
///
/// Pairs a session with one of its checkpoints so callers can report both ids
/// and chain the next checkpoint onto this one.
#[derive(Debug, Clone)]
pub(crate) struct SessionInfo {
    /// Identifier of the session in session storage.
    session_id: Uuid,
    /// Identifier of the checkpoint this info refers to; `save_checkpoint`
    /// passes it as the parent of the checkpoint it creates.
    checkpoint_id: Uuid,
    /// Creation time of that checkpoint as reported by session storage.
    checkpoint_created_at: chrono::DateTime<chrono::Utc>,
}
39
40use super::AgentClient;
41
42// =============================================================================
43// Internal Message Types
44// =============================================================================
45
/// Message carried over the internal channel that feeds the stream built by
/// `chat_completion_stream`.
#[derive(Debug)]
pub(crate) enum StreamMessage {
    /// An incremental generation delta, forwarded to the caller as a
    /// `chat.completion.chunk`.
    // NOTE(review): presumably produced by `run_agent_completion` through the
    // sender it is handed — confirm, the sending side is outside this view.
    Delta(GenerationDelta),
    /// A snapshot of the hook context; the stream adopts it as its current
    /// context and emits a metadata-only chunk from it. Boxed to keep the
    /// enum small.
    Ctx(Box<HookContext<AgentState>>),
}
51
52// =============================================================================
53// AgentProvider Implementation
54// =============================================================================
55
56#[async_trait]
57impl AgentProvider for AgentClient {
58    // =========================================================================
59    // Account
60    // =========================================================================
61
62    async fn get_my_account(&self) -> Result<GetMyAccountResponse, String> {
63        if let Some(api) = &self.stakpak_api {
64            api.get_account().await
65        } else {
66            // Local stub
67            Ok(GetMyAccountResponse {
68                username: "local".to_string(),
69                id: "local".to_string(),
70                first_name: "local".to_string(),
71                last_name: "local".to_string(),
72                email: "local@stakpak.dev".to_string(),
73                scope: None,
74            })
75        }
76    }
77
78    async fn get_billing_info(
79        &self,
80        account_username: &str,
81    ) -> Result<stakpak_shared::models::billing::BillingResponse, String> {
82        if let Some(api) = &self.stakpak_api {
83            api.get_billing(account_username).await
84        } else {
85            Err("Billing info not available without Stakpak API key".to_string())
86        }
87    }
88
89    // =========================================================================
90    // Rulebooks
91    // =========================================================================
92
93    async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
94        if let Some(api) = &self.stakpak_api {
95            api.list_rulebooks().await
96        } else {
97            // Try to fetch public rulebooks via unauthenticated request
98            let client = stakpak_shared::tls_client::create_tls_client(
99                stakpak_shared::tls_client::TlsClientConfig::default()
100                    .with_timeout(std::time::Duration::from_secs(30)),
101            )?;
102
103            let url = format!("{}/v1/rules", self.get_stakpak_api_endpoint());
104            let response = client.get(&url).send().await.map_err(|e| e.to_string())?;
105
106            if response.status().is_success() {
107                let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
108                match serde_json::from_value::<ListRulebooksResponse>(value) {
109                    Ok(resp) => Ok(resp.results),
110                    Err(_) => Ok(vec![]),
111                }
112            } else {
113                Ok(vec![])
114            }
115        }
116    }
117
118    async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
119        if let Some(api) = &self.stakpak_api {
120            api.get_rulebook_by_uri(uri).await
121        } else {
122            // Try to fetch public rulebook via unauthenticated request
123            let client = stakpak_shared::tls_client::create_tls_client(
124                stakpak_shared::tls_client::TlsClientConfig::default()
125                    .with_timeout(std::time::Duration::from_secs(30)),
126            )?;
127
128            let encoded_uri = urlencoding::encode(uri);
129            let url = format!(
130                "{}/v1/rules/{}",
131                self.get_stakpak_api_endpoint(),
132                encoded_uri
133            );
134            let response = client.get(&url).send().await.map_err(|e| e.to_string())?;
135
136            if response.status().is_success() {
137                response.json().await.map_err(|e| e.to_string())
138            } else {
139                Err("Rulebook not found".to_string())
140            }
141        }
142    }
143
144    async fn create_rulebook(
145        &self,
146        uri: &str,
147        description: &str,
148        content: &str,
149        tags: Vec<String>,
150        visibility: Option<RuleBookVisibility>,
151    ) -> Result<CreateRuleBookResponse, String> {
152        if let Some(api) = &self.stakpak_api {
153            api.create_rulebook(&CreateRuleBookInput {
154                uri: uri.to_string(),
155                description: description.to_string(),
156                content: content.to_string(),
157                tags,
158                visibility,
159            })
160            .await
161        } else {
162            Err("Creating rulebooks requires Stakpak API key".to_string())
163        }
164    }
165
166    async fn delete_rulebook(&self, uri: &str) -> Result<(), String> {
167        if let Some(api) = &self.stakpak_api {
168            api.delete_rulebook(uri).await
169        } else {
170            Err("Deleting rulebooks requires Stakpak API key".to_string())
171        }
172    }
173
174    // =========================================================================
175    // Chat Completion
176    // =========================================================================
177
178    async fn chat_completion(
179        &self,
180        model: Model,
181        messages: Vec<ChatMessage>,
182        tools: Option<Vec<Tool>>,
183        session_id: Option<Uuid>,
184        metadata: Option<serde_json::Value>,
185    ) -> Result<ChatCompletionResponse, String> {
186        let mut ctx = HookContext::new(
187            session_id,
188            AgentState::new(model, messages, tools, metadata),
189        );
190
191        // Execute before request hooks
192        self.hook_registry
193            .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
194            .await
195            .map_err(|e| e.to_string())?
196            .ok()?;
197
198        // Initialize or resume session
199        let current_session = self.initialize_session(&ctx).await?;
200        ctx.set_session_id(current_session.session_id);
201
202        // Run completion
203        let new_message = self.run_agent_completion(&mut ctx, None).await?;
204        ctx.state.append_new_message(new_message.clone());
205
206        // Save checkpoint
207        let result = self
208            .save_checkpoint(
209                &current_session,
210                ctx.state.messages.clone(),
211                ctx.state.metadata.clone(),
212            )
213            .await?;
214        let checkpoint_created_at = result.checkpoint_created_at.timestamp() as u64;
215        ctx.set_new_checkpoint_id(result.checkpoint_id);
216
217        // Execute after request hooks
218        self.hook_registry
219            .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
220            .await
221            .map_err(|e| e.to_string())?
222            .ok()?;
223
224        let mut meta = serde_json::Map::new();
225        if let Some(session_id) = ctx.session_id {
226            meta.insert(
227                "session_id".to_string(),
228                serde_json::Value::String(session_id.to_string()),
229            );
230        }
231        if let Some(checkpoint_id) = ctx.new_checkpoint_id {
232            meta.insert(
233                "checkpoint_id".to_string(),
234                serde_json::Value::String(checkpoint_id.to_string()),
235            );
236        }
237        if let Some(state_metadata) = &ctx.state.metadata {
238            meta.insert("state_metadata".to_string(), state_metadata.clone());
239        }
240
241        Ok(ChatCompletionResponse {
242            id: ctx.new_checkpoint_id.unwrap().to_string(),
243            object: "chat.completion".to_string(),
244            created: checkpoint_created_at,
245            model: ctx
246                .state
247                .llm_input
248                .as_ref()
249                .map(|llm_input| llm_input.model.id.clone())
250                .unwrap_or_default(),
251            choices: vec![ChatCompletionChoice {
252                index: 0,
253                message: ctx.state.messages.last().cloned().unwrap(),
254                logprobs: None,
255                finish_reason: FinishReason::Stop,
256            }],
257            usage: ctx
258                .state
259                .llm_output
260                .as_ref()
261                .map(|u| u.usage.clone())
262                .unwrap_or_default(),
263            system_fingerprint: None,
264            metadata: if meta.is_empty() {
265                None
266            } else {
267                Some(serde_json::Value::Object(meta))
268            },
269        })
270    }
271
    /// Run a chat completion and return its output as a stream of
    /// `chat.completion.chunk` responses.
    ///
    /// `BeforeRequest` hooks and session initialization run before this
    /// function returns; their errors are returned directly. Inference and
    /// checkpoint saving run in a spawned task that reports to the returned
    /// stream over a bounded channel:
    ///
    /// - each `StreamMessage::Delta` becomes a chunk with one choice,
    /// - each `StreamMessage::Ctx` replaces the stream's context and, when a
    ///   session id is set, becomes a metadata-only chunk (`session_id`, plus
    ///   `checkpoint_id` and `state_metadata` when present),
    /// - each task error becomes `ApiStreamError::Unknown`.
    ///
    /// `AfterRequest` hooks run inside the stream once the channel is drained.
    ///
    /// The `_headers` argument is not used. The second tuple element is
    /// always `None`.
    async fn chat_completion_stream(
        &self,
        model: Model,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
        _headers: Option<HeaderMap>,
        session_id: Option<Uuid>,
        metadata: Option<serde_json::Value>,
    ) -> Result<
        (
            Pin<
                Box<dyn Stream<Item = Result<ChatCompletionStreamResponse, ApiStreamError>> + Send>,
            >,
            Option<String>,
        ),
        String,
    > {
        let mut ctx = HookContext::new(
            session_id,
            AgentState::new(model, messages, tools, metadata),
        );

        // Execute before request hooks
        self.hook_registry
            .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
            .await
            .map_err(|e| e.to_string())?
            .ok()?;

        // Initialize session
        let current_session = self.initialize_session(&ctx).await?;
        ctx.set_session_id(current_session.session_id);

        // Bounded channel between the spawned completion task and the stream below.
        let (tx, mut rx) = mpsc::channel::<Result<StreamMessage, String>>(100);

        // Clone what we need for the spawned task
        let client = self.clone();
        let mut ctx_clone = ctx.clone();

        // Spawn the completion task with proper shutdown handling
        // The task checks if the channel is closed before each expensive operation
        // to support graceful shutdown when the stream consumer is dropped
        tokio::spawn(async move {
            // Check if consumer is still listening before starting
            if tx.is_closed() {
                return;
            }

            // The task keeps its own sender so it can still report the final
            // context or an error after inference finishes.
            let result = client
                .run_agent_completion(&mut ctx_clone, Some(tx.clone()))
                .await;

            match result {
                Err(e) => {
                    let _ = tx.send(Err(e)).await;
                }
                Ok(new_message) => {
                    // Check if consumer is still listening before continuing
                    if tx.is_closed() {
                        return;
                    }

                    // First context snapshot: includes the new message but no
                    // checkpoint id yet.
                    ctx_clone.state.append_new_message(new_message.clone());
                    if tx
                        .send(Ok(StreamMessage::Ctx(Box::new(ctx_clone.clone()))))
                        .await
                        .is_err()
                    {
                        // Consumer dropped, exit gracefully
                        return;
                    }

                    // Check again before expensive session update
                    // (if the consumer is gone, no checkpoint is saved).
                    if tx.is_closed() {
                        return;
                    }

                    let result = client
                        .save_checkpoint(
                            &current_session,
                            ctx_clone.state.messages.clone(),
                            ctx_clone.state.metadata.clone(),
                        )
                        .await;

                    match result {
                        Err(e) => {
                            let _ = tx.send(Err(e)).await;
                        }
                        Ok(updated) => {
                            // Second context snapshot: now carries the saved checkpoint id.
                            ctx_clone.set_new_checkpoint_id(updated.checkpoint_id);
                            let _ = tx.send(Ok(StreamMessage::Ctx(Box::new(ctx_clone)))).await;
                        }
                    }
                }
            }
        });

        let hook_registry = self.hook_registry.clone();
        let stream = async_stream::stream! {
            // Ends once every sender is dropped, i.e. when the spawned task has finished.
            while let Some(delta_result) = rx.recv().await {
                match delta_result {
                    Ok(delta) => match delta {
                        StreamMessage::Ctx(updated_ctx) => {
                            ctx = *updated_ctx;
                            // Emit session metadata so callers can track session_id
                            if let Some(session_id) = ctx.session_id {
                                let mut meta = serde_json::Map::new();
                                meta.insert("session_id".to_string(), serde_json::Value::String(session_id.to_string()));
                                if let Some(checkpoint_id) = ctx.new_checkpoint_id {
                                    meta.insert("checkpoint_id".to_string(), serde_json::Value::String(checkpoint_id.to_string()));
                                }
                                if let Some(state_metadata) = &ctx.state.metadata {
                                    meta.insert("state_metadata".to_string(), state_metadata.clone());
                                }
                                // Metadata-only chunk: no choices, no usage, empty model name.
                                yield Ok(ChatCompletionStreamResponse {
                                    id: ctx.request_id.to_string(),
                                    object: "chat.completion.chunk".to_string(),
                                    created: chrono::Utc::now().timestamp() as u64,
                                    model: String::new(),
                                    choices: vec![],
                                    usage: None,
                                    metadata: Some(serde_json::Value::Object(meta)),
                                });
                            }
                        }
                        StreamMessage::Delta(delta) => {
                            // Extract usage from Usage delta variant
                            let usage = if let GenerationDelta::Usage { usage } = &delta {
                                Some(usage.clone())
                            } else {
                                None
                            };

                            yield Ok(ChatCompletionStreamResponse {
                                id: ctx.request_id.to_string(),
                                object: "chat.completion.chunk".to_string(),
                                created: chrono::Utc::now().timestamp() as u64,
                                model: ctx.state.llm_input.as_ref().map(|llm_input| llm_input.model.clone().to_string()).unwrap_or_default(),
                                choices: vec![ChatCompletionStreamChoice {
                                    index: 0,
                                    delta: delta.into(),
                                    finish_reason: None,
                                }],
                                usage,
                                metadata: None,
                            })
                        }
                    }
                    // Errors are forwarded as items; the loop keeps draining the channel.
                    Err(e) => yield Err(ApiStreamError::Unknown(e)),
                }
            }

            // Execute after request hooks
            // NOTE(review): this runs even when the task reported an error above — confirm intended.
            hook_registry
                .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        };

        Ok((Box::pin(stream), None))
    }
435
436    async fn cancel_stream(&self, request_id: String) -> Result<(), String> {
437        if let Some(api) = &self.stakpak_api {
438            api.cancel_request(&request_id).await
439        } else {
440            // Local mode doesn't support cancellation yet
441            Ok(())
442        }
443    }
444
445    // =========================================================================
446    // Search Docs
447    // =========================================================================
448
449    async fn search_docs(&self, input: &SearchDocsRequest) -> Result<Vec<Content>, String> {
450        if let Some(api) = &self.stakpak_api {
451            api.search_docs(&crate::stakpak::SearchDocsRequest {
452                keywords: input.keywords.clone(),
453                exclude_keywords: input.exclude_keywords.clone(),
454                limit: input.limit,
455            })
456            .await
457        } else {
458            // Fallback to local search service
459            use stakpak_shared::models::integrations::search_service::*;
460
461            let config = SearchServicesOrchestrator::start()
462                .await
463                .map_err(|e| e.to_string())?;
464
465            let api_url = format!("http://localhost:{}", config.api_port);
466            let search_client = SearchClient::new(api_url);
467
468            let search_results = search_client
469                .search_and_scrape(input.keywords.clone(), None)
470                .await
471                .map_err(|e| e.to_string())?;
472
473            if search_results.is_empty() {
474                return Ok(vec![Content::text("No results found".to_string())]);
475            }
476
477            Ok(search_results
478                .into_iter()
479                .map(|result| {
480                    let content = result.content.unwrap_or_default();
481                    Content::text(format!("URL: {}\nContent: {}", result.url, content))
482                })
483                .collect())
484        }
485    }
486
487    // =========================================================================
488    // Memory
489    // =========================================================================
490
491    async fn memorize_session(&self, checkpoint_id: Uuid) -> Result<(), String> {
492        if let Some(api) = &self.stakpak_api {
493            api.memorize_session(checkpoint_id).await
494        } else {
495            // No-op in local mode
496            Ok(())
497        }
498    }
499
500    async fn search_memory(&self, input: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
501        if let Some(api) = &self.stakpak_api {
502            api.search_memory(&crate::stakpak::SearchMemoryRequest {
503                keywords: input.keywords.clone(),
504                start_time: input.start_time,
505                end_time: input.end_time,
506            })
507            .await
508        } else {
509            // Empty results in local mode
510            Ok(vec![])
511        }
512    }
513
514    // =========================================================================
515    // Slack
516    // =========================================================================
517
518    async fn slack_read_messages(
519        &self,
520        input: &SlackReadMessagesRequest,
521    ) -> Result<Vec<Content>, String> {
522        if let Some(api) = &self.stakpak_api {
523            api.slack_read_messages(&crate::stakpak::SlackReadMessagesRequest {
524                channel: input.channel.clone(),
525                limit: input.limit,
526            })
527            .await
528        } else {
529            Err("Slack integration requires Stakpak API key".to_string())
530        }
531    }
532
533    async fn slack_read_replies(
534        &self,
535        input: &SlackReadRepliesRequest,
536    ) -> Result<Vec<Content>, String> {
537        if let Some(api) = &self.stakpak_api {
538            api.slack_read_replies(&crate::stakpak::SlackReadRepliesRequest {
539                channel: input.channel.clone(),
540                ts: input.ts.clone(),
541            })
542            .await
543        } else {
544            Err("Slack integration requires Stakpak API key".to_string())
545        }
546    }
547
548    async fn slack_send_message(
549        &self,
550        input: &SlackSendMessageRequest,
551    ) -> Result<Vec<Content>, String> {
552        if let Some(api) = &self.stakpak_api {
553            api.slack_send_message(&crate::stakpak::SlackSendMessageRequest {
554                channel: input.channel.clone(),
555                markdown_text: input.markdown_text.clone(),
556                thread_ts: input.thread_ts.clone(),
557            })
558            .await
559        } else {
560            Err("Slack integration requires Stakpak API key".to_string())
561        }
562    }
563
564    // =========================================================================
565    // Models
566    // =========================================================================
567
568    async fn list_models(&self) -> Vec<stakai::Model> {
569        // Use the provider registry which only contains providers with configured API keys.
570        // This ensures we only list models for providers the user actually has access to.
571        // Aggregate per provider so one failing provider does not hide all others.
572        let registry = self.stakai.registry();
573        let mut all_models = Vec::new();
574
575        for provider_id in registry.list_providers() {
576            if let Ok(mut models) = registry.models_for_provider(&provider_id).await {
577                all_models.append(&mut models);
578            }
579        }
580
581        sort_models_by_recency(&mut all_models);
582        all_models
583    }
584}
585
586/// Sort models by release_date descending (newest first)
587fn sort_models_by_recency(models: &mut [stakai::Model]) {
588    models.sort_by(|a, b| {
589        match (&b.release_date, &a.release_date) {
590            (Some(b_date), Some(a_date)) => b_date.cmp(a_date),
591            (Some(_), None) => std::cmp::Ordering::Less,
592            (None, Some(_)) => std::cmp::Ordering::Greater,
593            (None, None) => b.id.cmp(&a.id), // Fallback to ID descending
594        }
595    });
596}
597
598// =============================================================================
599// SessionStorage implementation (delegates to inner session_storage)
600// =============================================================================
601
602#[async_trait]
603impl crate::storage::SessionStorage for super::AgentClient {
604    fn backend_info(&self) -> crate::storage::BackendInfo {
605        self.session_storage.backend_info()
606    }
607
608    async fn list_sessions(
609        &self,
610        query: &crate::storage::ListSessionsQuery,
611    ) -> Result<crate::storage::ListSessionsResult, crate::storage::StorageError> {
612        self.session_storage.list_sessions(query).await
613    }
614
615    async fn get_session(
616        &self,
617        session_id: Uuid,
618    ) -> Result<crate::storage::Session, crate::storage::StorageError> {
619        self.session_storage.get_session(session_id).await
620    }
621
622    async fn create_session(
623        &self,
624        request: &crate::storage::CreateSessionRequest,
625    ) -> Result<crate::storage::CreateSessionResult, crate::storage::StorageError> {
626        self.session_storage.create_session(request).await
627    }
628
629    async fn update_session(
630        &self,
631        session_id: Uuid,
632        request: &crate::storage::UpdateSessionRequest,
633    ) -> Result<crate::storage::Session, crate::storage::StorageError> {
634        self.session_storage
635            .update_session(session_id, request)
636            .await
637    }
638
639    async fn delete_session(&self, session_id: Uuid) -> Result<(), crate::storage::StorageError> {
640        self.session_storage.delete_session(session_id).await
641    }
642
643    async fn list_checkpoints(
644        &self,
645        session_id: Uuid,
646        query: &crate::storage::ListCheckpointsQuery,
647    ) -> Result<crate::storage::ListCheckpointsResult, crate::storage::StorageError> {
648        self.session_storage
649            .list_checkpoints(session_id, query)
650            .await
651    }
652
653    async fn get_checkpoint(
654        &self,
655        checkpoint_id: Uuid,
656    ) -> Result<crate::storage::Checkpoint, crate::storage::StorageError> {
657        self.session_storage.get_checkpoint(checkpoint_id).await
658    }
659
660    async fn create_checkpoint(
661        &self,
662        session_id: Uuid,
663        request: &crate::storage::CreateCheckpointRequest,
664    ) -> Result<crate::storage::Checkpoint, crate::storage::StorageError> {
665        self.session_storage
666            .create_checkpoint(session_id, request)
667            .await
668    }
669
670    async fn get_active_checkpoint(
671        &self,
672        session_id: Uuid,
673    ) -> Result<crate::storage::Checkpoint, crate::storage::StorageError> {
674        self.session_storage.get_active_checkpoint(session_id).await
675    }
676
677    async fn get_session_stats(
678        &self,
679        session_id: Uuid,
680    ) -> Result<crate::storage::SessionStats, crate::storage::StorageError> {
681        self.session_storage.get_session_stats(session_id).await
682    }
683}
684
685// =============================================================================
686// Helper Methods
687// =============================================================================
688
689const TITLE_GENERATOR_PROMPT: &str = include_str!("../prompts/session_title_generator.v1.txt");
690
691impl AgentClient {
692    /// Initialize or resume a session based on context
693    ///
694    /// If `ctx.session_id` is set, we resume that session directly.
695    /// Otherwise, we create a new session.
696    pub(crate) async fn initialize_session(
697        &self,
698        ctx: &HookContext<AgentState>,
699    ) -> Result<SessionInfo, String> {
700        let messages = &ctx.state.messages;
701
702        if messages.is_empty() {
703            return Err("At least one message is required".to_string());
704        }
705
706        // If session_id is set in context, resume that session directly
707        if let Some(session_id) = ctx.session_id {
708            let session = self
709                .session_storage
710                .get_session(session_id)
711                .await
712                .map_err(|e| e.to_string())?;
713
714            let checkpoint = session
715                .active_checkpoint
716                .ok_or_else(|| format!("Session {} has no active checkpoint", session_id))?;
717
718            // If the session still has the default title, generate a better one in the background.
719            if session.title.trim().is_empty() || session.title == "New Session" {
720                let client = self.clone();
721                let messages_for_title = messages.to_vec();
722                let session_id = session.id;
723                let existing_title = session.title.clone();
724                tokio::spawn(async move {
725                    if let Ok(title) = client.generate_session_title(&messages_for_title).await {
726                        let trimmed = title.trim();
727                        if !trimmed.is_empty() && trimmed != existing_title {
728                            let request =
729                                StorageUpdateSessionRequest::new().with_title(trimmed.to_string());
730                            let _ = client
731                                .session_storage
732                                .update_session(session_id, &request)
733                                .await;
734                        }
735                    }
736                });
737            }
738
739            return Ok(SessionInfo {
740                session_id: session.id,
741                checkpoint_id: checkpoint.id,
742                checkpoint_created_at: checkpoint.created_at,
743            });
744        }
745
746        // Create new session with a fast local title.
747        let fallback_title = Self::fallback_session_title(messages);
748
749        // Get current working directory
750        let cwd = std::env::current_dir()
751            .ok()
752            .map(|p| p.to_string_lossy().to_string());
753
754        // Create session via storage trait
755        let mut session_request =
756            StorageCreateSessionRequest::new(fallback_title.clone(), messages.to_vec());
757        if let Some(cwd) = cwd {
758            session_request = session_request.with_cwd(cwd);
759        }
760
761        let result = self
762            .session_storage
763            .create_session(&session_request)
764            .await
765            .map_err(|e| e.to_string())?;
766
767        // Generate a better title asynchronously and update the session when ready.
768        let client = self.clone();
769        let messages_for_title = messages.to_vec();
770        let session_id = result.session_id;
771        tokio::spawn(async move {
772            if let Ok(title) = client.generate_session_title(&messages_for_title).await {
773                let trimmed = title.trim();
774                if !trimmed.is_empty() && trimmed != fallback_title {
775                    let request =
776                        StorageUpdateSessionRequest::new().with_title(trimmed.to_string());
777                    let _ = client
778                        .session_storage
779                        .update_session(session_id, &request)
780                        .await;
781                }
782            }
783        });
784
785        Ok(SessionInfo {
786            session_id: result.session_id,
787            checkpoint_id: result.checkpoint.id,
788            checkpoint_created_at: result.checkpoint.created_at,
789        })
790    }
791
792    fn fallback_session_title(messages: &[ChatMessage]) -> String {
793        messages
794            .iter()
795            .find(|m| m.role == Role::User)
796            .and_then(|m| m.content.as_ref())
797            .map(|c| {
798                let text = c.to_string();
799                text.split_whitespace()
800                    .take(5)
801                    .collect::<Vec<_>>()
802                    .join(" ")
803            })
804            .unwrap_or_else(|| "New Session".to_string())
805    }
806
807    /// Save a new checkpoint for the current session
808    pub(crate) async fn save_checkpoint(
809        &self,
810        current: &SessionInfo,
811        messages: Vec<ChatMessage>,
812        metadata: Option<serde_json::Value>,
813    ) -> Result<SessionInfo, String> {
814        let mut checkpoint_request =
815            StorageCreateCheckpointRequest::new(messages).with_parent(current.checkpoint_id);
816
817        if let Some(meta) = metadata {
818            checkpoint_request = checkpoint_request.with_metadata(meta);
819        }
820
821        let checkpoint = self
822            .session_storage
823            .create_checkpoint(current.session_id, &checkpoint_request)
824            .await
825            .map_err(|e| e.to_string())?;
826
827        Ok(SessionInfo {
828            session_id: current.session_id,
829            checkpoint_id: checkpoint.id,
830            checkpoint_created_at: checkpoint.created_at,
831        })
832    }
833
834    /// Run agent completion (inference)
835    pub(crate) async fn run_agent_completion(
836        &self,
837        ctx: &mut HookContext<AgentState>,
838        stream_channel_tx: Option<mpsc::Sender<Result<StreamMessage, String>>>,
839    ) -> Result<ChatMessage, String> {
840        // Execute before inference hooks
841        self.hook_registry
842            .execute_hooks(ctx, &LifecycleEvent::BeforeInference)
843            .await
844            .map_err(|e| e.to_string())?
845            .ok()?;
846
847        let mut input = if let Some(llm_input) = ctx.state.llm_input.clone() {
848            llm_input
849        } else {
850            return Err(
851                "LLM input not found, make sure to register a context hook before inference"
852                    .to_string(),
853            );
854        };
855
856        // Inject session_id header if available
857        if let Some(session_id) = ctx.session_id {
858            let headers = input
859                .headers
860                .get_or_insert_with(std::collections::HashMap::new);
861            headers.insert("X-Session-Id".to_string(), session_id.to_string());
862        }
863
864        let (response_message, usage) = if let Some(tx) = stream_channel_tx {
865            // Streaming mode
866            let (internal_tx, mut internal_rx) = mpsc::channel::<GenerationDelta>(100);
867            let stream_input = LLMStreamInput {
868                model: input.model,
869                messages: input.messages,
870                max_tokens: input.max_tokens,
871                tools: input.tools,
872                stream_channel_tx: internal_tx,
873                provider_options: input.provider_options,
874                headers: input.headers,
875            };
876
877            let stakai = self.stakai.clone();
878            let chat_future = async move {
879                stakai
880                    .chat_stream(stream_input)
881                    .await
882                    .map_err(|e| e.to_string())
883            };
884
885            let receive_future = async move {
886                while let Some(delta) = internal_rx.recv().await {
887                    if tx.send(Ok(StreamMessage::Delta(delta))).await.is_err() {
888                        break;
889                    }
890                }
891            };
892
893            let (chat_result, _) = tokio::join!(chat_future, receive_future);
894            let response = chat_result?;
895            (response.choices[0].message.clone(), response.usage)
896        } else {
897            // Non-streaming mode
898            let response = self.stakai.chat(input).await.map_err(|e| e.to_string())?;
899            (response.choices[0].message.clone(), response.usage)
900        };
901
902        ctx.state.set_llm_output(response_message, usage);
903
904        // Execute after inference hooks
905        self.hook_registry
906            .execute_hooks(ctx, &LifecycleEvent::AfterInference)
907            .await
908            .map_err(|e| e.to_string())?
909            .ok()?;
910
911        let llm_output = ctx
912            .state
913            .llm_output
914            .as_ref()
915            .ok_or_else(|| "LLM output is missing from state".to_string())?;
916
917        Ok(ChatMessage::from(llm_output))
918    }
919
920    /// Generate a title for a new session
921    async fn generate_session_title(&self, messages: &[ChatMessage]) -> Result<String, String> {
922        // Pick a cheap model from the user's configured providers
923        let use_stakpak = self.stakpak.is_some();
924        let providers = self.stakai.registry().list_providers();
925        let cheap_models: &[(&str, &str)] = &[
926            ("stakpak", "claude-haiku-4-5"),
927            ("anthropic", "claude-haiku-4-5"),
928            ("amazon-bedrock", "claude-haiku-4-5"),
929            ("openai", "gpt-4.1-mini"),
930            ("google", "gemini-2.5-flash"),
931        ];
932        let model = cheap_models
933            .iter()
934            .find_map(|(provider, model_id)| {
935                if providers.contains(&provider.to_string()) {
936                    crate::find_model(model_id, use_stakpak)
937                } else {
938                    None
939                }
940            })
941            .ok_or_else(|| "No model available for title generation".to_string())?;
942
943        let llm_messages = vec![
944            LLMMessage {
945                role: Role::System.to_string(),
946                content: LLMMessageContent::String(TITLE_GENERATOR_PROMPT.to_string()),
947            },
948            LLMMessage {
949                role: Role::User.to_string(),
950                content: LLMMessageContent::String(
951                    messages
952                        .iter()
953                        .map(|msg| {
954                            msg.content
955                                .as_ref()
956                                .unwrap_or(&MessageContent::String("".to_string()))
957                                .to_string()
958                        })
959                        .collect(),
960                ),
961            },
962        ];
963
964        let input = LLMInput {
965            model,
966            messages: llm_messages,
967            max_tokens: 100,
968            tools: None,
969            provider_options: None,
970            headers: None,
971        };
972
973        let response = self.stakai.chat(input).await.map_err(|e| e.to_string())?;
974
975        Ok(response.choices[0].message.content.to_string())
976    }
977}