Skip to main content

stakpak_api/client/
provider.rs

1//! AgentProvider trait implementation for AgentClient
2//!
3//! Implements the unified provider interface with:
4//! - Stakpak-first routing when API key is present
5//! - Local fallback when Stakpak is unavailable
6//! - Hook registry integration for lifecycle events
7
8use crate::AgentProvider;
9use crate::models::*;
10use crate::storage::{
11    CreateCheckpointRequest as StorageCreateCheckpointRequest,
12    CreateSessionRequest as StorageCreateSessionRequest,
13    UpdateSessionRequest as StorageUpdateSessionRequest,
14};
15use async_trait::async_trait;
16use futures_util::Stream;
17use reqwest::header::HeaderMap;
18use rmcp::model::Content;
19use stakpak_shared::hooks::{HookContext, LifecycleEvent};
20use stakpak_shared::models::integrations::anthropic::AnthropicModel;
21use stakpak_shared::models::integrations::openai::{
22    AgentModel, ChatCompletionChoice, ChatCompletionResponse, ChatCompletionStreamChoice,
23    ChatCompletionStreamResponse, ChatMessage, FinishReason, MessageContent, Role, Tool,
24};
25use stakpak_shared::models::llm::{
26    GenerationDelta, LLMInput, LLMMessage, LLMMessageContent, LLMModel, LLMStreamInput,
27};
28use stakpak_shared::models::stakai_adapter::get_stakai_model_string;
29use std::pin::Pin;
30use tokio::sync::mpsc;
31use uuid::Uuid;
32
/// Lightweight session info returned by initialize_session / save_checkpoint
#[derive(Debug, Clone)]
pub(crate) struct SessionInfo {
    // The session this checkpoint chain belongs to.
    session_id: Uuid,
    // The currently-active checkpoint for the session; new checkpoints are
    // created with this id as their parent (see `save_checkpoint`).
    checkpoint_id: Uuid,
    // Creation time of that checkpoint; surfaced as the `created` timestamp
    // in chat completion responses.
    checkpoint_created_at: chrono::DateTime<chrono::Utc>,
}
40
41use super::AgentClient;
42
43// =============================================================================
44// Internal Message Types
45// =============================================================================
46
/// Internal message passed from the spawned completion task to the output
/// stream assembled in `chat_completion_stream`.
#[derive(Debug)]
pub(crate) enum StreamMessage {
    /// An incremental generation delta from the underlying LLM stream.
    Delta(GenerationDelta),
    /// A snapshot of the updated hook context, sent after the assistant
    /// message is appended and again after the checkpoint is saved.
    /// Boxed to keep the enum small relative to the `Delta` variant.
    Ctx(Box<HookContext<AgentState>>),
}
52
53// =============================================================================
54// AgentProvider Implementation
55// =============================================================================
56
#[async_trait]
impl AgentProvider for AgentClient {
    // =========================================================================
    // Account
    // =========================================================================

    /// Return the authenticated Stakpak account, or a local stub account when
    /// no Stakpak API client is configured.
    async fn get_my_account(&self) -> Result<GetMyAccountResponse, String> {
        if let Some(api) = &self.stakpak_api {
            api.get_account().await
        } else {
            // Local stub
            Ok(GetMyAccountResponse {
                username: "local".to_string(),
                id: "local".to_string(),
                first_name: "local".to_string(),
                last_name: "local".to_string(),
                email: "local@stakpak.dev".to_string(),
                scope: None,
            })
        }
    }

    /// Fetch billing info for `account_username`.
    ///
    /// Requires the Stakpak API; there is no local fallback.
    async fn get_billing_info(
        &self,
        account_username: &str,
    ) -> Result<stakpak_shared::models::billing::BillingResponse, String> {
        if let Some(api) = &self.stakpak_api {
            api.get_billing(account_username).await
        } else {
            Err("Billing info not available without Stakpak API key".to_string())
        }
    }

    // =========================================================================
    // Rulebooks
    // =========================================================================

    /// List rulebooks.
    ///
    /// With an API key this goes through the authenticated client; without
    /// one, public rulebooks are fetched anonymously. Fetch/parse failures in
    /// the anonymous path degrade to an empty list rather than an error.
    async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
        if let Some(api) = &self.stakpak_api {
            api.list_rulebooks().await
        } else {
            // Try to fetch public rulebooks via unauthenticated request
            let client = stakpak_shared::tls_client::create_tls_client(
                stakpak_shared::tls_client::TlsClientConfig::default()
                    .with_timeout(std::time::Duration::from_secs(30)),
            )?;

            let url = format!("{}/v1/rules", self.get_stakpak_api_endpoint());
            let response = client.get(&url).send().await.map_err(|e| e.to_string())?;

            if response.status().is_success() {
                let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
                // Deliberately swallow deserialization failures: an
                // incompatible payload yields an empty list, not an error.
                match serde_json::from_value::<ListRulebooksResponse>(value) {
                    Ok(resp) => Ok(resp.results),
                    Err(_) => Ok(vec![]),
                }
            } else {
                Ok(vec![])
            }
        }
    }

    /// Fetch a single rulebook by URI, authenticated when possible, otherwise
    /// via an anonymous request against the public endpoint.
    async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
        if let Some(api) = &self.stakpak_api {
            api.get_rulebook_by_uri(uri).await
        } else {
            // Try to fetch public rulebook via unauthenticated request
            let client = stakpak_shared::tls_client::create_tls_client(
                stakpak_shared::tls_client::TlsClientConfig::default()
                    .with_timeout(std::time::Duration::from_secs(30)),
            )?;

            // URIs may contain characters that are not path-safe.
            let encoded_uri = urlencoding::encode(uri);
            let url = format!(
                "{}/v1/rules/{}",
                self.get_stakpak_api_endpoint(),
                encoded_uri
            );
            let response = client.get(&url).send().await.map_err(|e| e.to_string())?;

            if response.status().is_success() {
                response.json().await.map_err(|e| e.to_string())
            } else {
                Err("Rulebook not found".to_string())
            }
        }
    }

    /// Create a rulebook. Requires the Stakpak API; no local fallback.
    async fn create_rulebook(
        &self,
        uri: &str,
        description: &str,
        content: &str,
        tags: Vec<String>,
        visibility: Option<RuleBookVisibility>,
    ) -> Result<CreateRuleBookResponse, String> {
        if let Some(api) = &self.stakpak_api {
            api.create_rulebook(&CreateRuleBookInput {
                uri: uri.to_string(),
                description: description.to_string(),
                content: content.to_string(),
                tags,
                visibility,
            })
            .await
        } else {
            Err("Creating rulebooks requires Stakpak API key".to_string())
        }
    }

    /// Delete a rulebook by URI. Requires the Stakpak API; no local fallback.
    async fn delete_rulebook(&self, uri: &str) -> Result<(), String> {
        if let Some(api) = &self.stakpak_api {
            api.delete_rulebook(uri).await
        } else {
            Err("Deleting rulebooks requires Stakpak API key".to_string())
        }
    }

    // =========================================================================
    // Chat Completion
    // =========================================================================

    /// Run a full (non-streaming) chat completion.
    ///
    /// Lifecycle: BeforeRequest hooks → session init/resume → inference →
    /// checkpoint save → AfterRequest hooks. Session/checkpoint ids are
    /// surfaced to the caller through the response `metadata` object.
    async fn chat_completion(
        &self,
        model: AgentModel,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
        session_id: Option<Uuid>,
    ) -> Result<ChatCompletionResponse, String> {
        let mut ctx = HookContext::new(session_id, AgentState::new(model, messages, tools));

        // Execute before request hooks
        // NOTE(review): `.ok()?` presumably converts a hook-requested abort
        // into an Err that short-circuits this call — confirm against the
        // HookContext contract.
        self.hook_registry
            .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
            .await
            .map_err(|e| e.to_string())?
            .ok()?;

        // Initialize or resume session
        let current_session = self.initialize_session(&ctx).await?;
        ctx.set_session_id(current_session.session_id);

        // Run completion
        let new_message = self.run_agent_completion(&mut ctx, None).await?;
        ctx.state.append_new_message(new_message.clone());

        // Save checkpoint
        let result = self
            .save_checkpoint(&current_session, ctx.state.messages.clone())
            .await?;
        // NOTE(review): `timestamp() as u64` would wrap for pre-epoch times;
        // fine for freshly created checkpoints.
        let checkpoint_created_at = result.checkpoint_created_at.timestamp() as u64;
        ctx.set_new_checkpoint_id(result.checkpoint_id);

        // Execute after request hooks
        self.hook_registry
            .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
            .await
            .map_err(|e| e.to_string())?
            .ok()?;

        // Build response metadata so callers can track session/checkpoint ids.
        let mut meta = serde_json::Map::new();
        if let Some(session_id) = ctx.session_id {
            meta.insert(
                "session_id".to_string(),
                serde_json::Value::String(session_id.to_string()),
            );
        }
        if let Some(checkpoint_id) = ctx.new_checkpoint_id {
            meta.insert(
                "checkpoint_id".to_string(),
                serde_json::Value::String(checkpoint_id.to_string()),
            );
        }

        Ok(ChatCompletionResponse {
            // `unwrap` is safe: `set_new_checkpoint_id` was called above.
            id: ctx.new_checkpoint_id.unwrap().to_string(),
            object: "chat.completion".to_string(),
            created: checkpoint_created_at,
            model: ctx
                .state
                .llm_input
                .as_ref()
                .map(|llm_input| llm_input.model.clone().to_string())
                .unwrap_or_default(),
            choices: vec![ChatCompletionChoice {
                index: 0,
                // `unwrap` is safe: messages are non-empty after
                // `append_new_message` above.
                message: ctx.state.messages.last().cloned().unwrap(),
                logprobs: None,
                finish_reason: FinishReason::Stop,
            }],
            usage: ctx
                .state
                .llm_output
                .as_ref()
                .map(|u| u.usage.clone())
                .unwrap_or_default(),
            system_fingerprint: None,
            metadata: if meta.is_empty() {
                None
            } else {
                Some(serde_json::Value::Object(meta))
            },
        })
    }

    /// Run a streaming chat completion.
    ///
    /// The completion itself runs in a spawned task that pushes
    /// `StreamMessage`s over an mpsc channel; this method returns a stream
    /// that converts those messages into `ChatCompletionStreamResponse`
    /// chunks. `Ctx` messages update the local context and emit a
    /// metadata-only chunk; `Delta` messages become content chunks.
    async fn chat_completion_stream(
        &self,
        model: AgentModel,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
        _headers: Option<HeaderMap>,
        session_id: Option<Uuid>,
    ) -> Result<
        (
            Pin<
                Box<dyn Stream<Item = Result<ChatCompletionStreamResponse, ApiStreamError>> + Send>,
            >,
            Option<String>,
        ),
        String,
    > {
        let mut ctx = HookContext::new(session_id, AgentState::new(model, messages, tools));

        // Execute before request hooks
        self.hook_registry
            .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
            .await
            .map_err(|e| e.to_string())?
            .ok()?;

        // Initialize session
        let current_session = self.initialize_session(&ctx).await?;
        ctx.set_session_id(current_session.session_id);

        let (tx, mut rx) = mpsc::channel::<Result<StreamMessage, String>>(100);

        // Clone what we need for the spawned task
        let client = self.clone();
        let mut ctx_clone = ctx.clone();

        // Spawn the completion task with proper shutdown handling
        // The task checks if the channel is closed before each expensive operation
        // to support graceful shutdown when the stream consumer is dropped
        tokio::spawn(async move {
            // Check if consumer is still listening before starting
            if tx.is_closed() {
                return;
            }

            let result = client
                .run_agent_completion(&mut ctx_clone, Some(tx.clone()))
                .await;

            match result {
                Err(e) => {
                    let _ = tx.send(Err(e)).await;
                }
                Ok(new_message) => {
                    // Check if consumer is still listening before continuing
                    if tx.is_closed() {
                        return;
                    }

                    ctx_clone.state.append_new_message(new_message.clone());
                    // First context snapshot: lets the stream emit session
                    // metadata before the checkpoint is persisted.
                    if tx
                        .send(Ok(StreamMessage::Ctx(Box::new(ctx_clone.clone()))))
                        .await
                        .is_err()
                    {
                        // Consumer dropped, exit gracefully
                        return;
                    }

                    // Check again before expensive session update
                    if tx.is_closed() {
                        return;
                    }

                    let result = client
                        .save_checkpoint(&current_session, ctx_clone.state.messages.clone())
                        .await;

                    match result {
                        Err(e) => {
                            let _ = tx.send(Err(e)).await;
                        }
                        Ok(updated) => {
                            // Second snapshot carries the new checkpoint id.
                            ctx_clone.set_new_checkpoint_id(updated.checkpoint_id);
                            let _ = tx.send(Ok(StreamMessage::Ctx(Box::new(ctx_clone)))).await;
                        }
                    }
                }
            }
        });

        let hook_registry = self.hook_registry.clone();
        let stream = async_stream::stream! {
            while let Some(delta_result) = rx.recv().await {
                match delta_result {
                    Ok(delta) => match delta {
                        StreamMessage::Ctx(updated_ctx) => {
                            ctx = *updated_ctx;
                            // Emit session metadata so callers can track session_id
                            if let Some(session_id) = ctx.session_id {
                                let mut meta = serde_json::Map::new();
                                meta.insert("session_id".to_string(), serde_json::Value::String(session_id.to_string()));
                                if let Some(checkpoint_id) = ctx.new_checkpoint_id {
                                    meta.insert("checkpoint_id".to_string(), serde_json::Value::String(checkpoint_id.to_string()));
                                }
                                // Metadata-only chunk: no choices, no usage.
                                yield Ok(ChatCompletionStreamResponse {
                                    id: ctx.request_id.to_string(),
                                    object: "chat.completion.chunk".to_string(),
                                    created: chrono::Utc::now().timestamp() as u64,
                                    model: String::new(),
                                    choices: vec![],
                                    usage: None,
                                    metadata: Some(serde_json::Value::Object(meta)),
                                });
                            }
                        }
                        StreamMessage::Delta(delta) => {
                            // Extract usage from Usage delta variant
                            let usage = if let GenerationDelta::Usage { usage } = &delta {
                                Some(usage.clone())
                            } else {
                                None
                            };

                            yield Ok(ChatCompletionStreamResponse {
                                id: ctx.request_id.to_string(),
                                object: "chat.completion.chunk".to_string(),
                                created: chrono::Utc::now().timestamp() as u64,
                                model: ctx.state.llm_input.as_ref().map(|llm_input| llm_input.model.clone().to_string()).unwrap_or_default(),
                                choices: vec![ChatCompletionStreamChoice {
                                    index: 0,
                                    delta: delta.into(),
                                    finish_reason: None,
                                }],
                                usage,
                                metadata: None,
                            })
                        }
                    }
                    Err(e) => yield Err(ApiStreamError::Unknown(e)),
                }
            }

            // Execute after request hooks
            // NOTE(review): `?` inside async_stream! ends the stream; a hook
            // failure here is surfaced as the final stream item.
            hook_registry
                .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
                .await
                .map_err(|e| e.to_string())?
                .ok()?;
        };

        Ok((Box::pin(stream), None))
    }

    /// Cancel an in-flight streaming request by id.
    ///
    /// Local mode has no cancellation support and reports success as a no-op.
    async fn cancel_stream(&self, request_id: String) -> Result<(), String> {
        if let Some(api) = &self.stakpak_api {
            api.cancel_request(&request_id).await
        } else {
            // Local mode doesn't support cancellation yet
            Ok(())
        }
    }

    // =========================================================================
    // Search Docs
    // =========================================================================

    /// Search documentation.
    ///
    /// Routes through the Stakpak API when available; otherwise spins up the
    /// local search-service orchestrator and performs search-and-scrape
    /// against it.
    async fn search_docs(&self, input: &SearchDocsRequest) -> Result<Vec<Content>, String> {
        if let Some(api) = &self.stakpak_api {
            api.search_docs(&crate::stakpak::SearchDocsRequest {
                keywords: input.keywords.clone(),
                exclude_keywords: input.exclude_keywords.clone(),
                limit: input.limit,
            })
            .await
        } else {
            // Fallback to local search service
            use stakpak_shared::models::integrations::search_service::*;

            let config = SearchServicesOrchestrator::start()
                .await
                .map_err(|e| e.to_string())?;

            let api_url = format!("http://localhost:{}", config.api_port);
            let search_client = SearchClient::new(api_url);

            let search_results = search_client
                .search_and_scrape(input.keywords.clone(), None)
                .await
                .map_err(|e| e.to_string())?;

            if search_results.is_empty() {
                return Ok(vec![Content::text("No results found".to_string())]);
            }

            Ok(search_results
                .into_iter()
                .map(|result| {
                    let content = result.content.unwrap_or_default();
                    Content::text(format!("URL: {}\nContent: {}", result.url, content))
                })
                .collect())
        }
    }

    // =========================================================================
    // Memory
    // =========================================================================

    /// Persist a checkpoint into long-term memory. No-op without the
    /// Stakpak API.
    async fn memorize_session(&self, checkpoint_id: Uuid) -> Result<(), String> {
        if let Some(api) = &self.stakpak_api {
            api.memorize_session(checkpoint_id).await
        } else {
            // No-op in local mode
            Ok(())
        }
    }

    /// Search long-term memory. Returns an empty result set in local mode.
    async fn search_memory(&self, input: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
        if let Some(api) = &self.stakpak_api {
            api.search_memory(&crate::stakpak::SearchMemoryRequest {
                keywords: input.keywords.clone(),
                start_time: input.start_time,
                end_time: input.end_time,
            })
            .await
        } else {
            // Empty results in local mode
            Ok(vec![])
        }
    }

    // =========================================================================
    // Slack
    // =========================================================================

    /// Read messages from a Slack channel. Requires the Stakpak API.
    async fn slack_read_messages(
        &self,
        input: &SlackReadMessagesRequest,
    ) -> Result<Vec<Content>, String> {
        if let Some(api) = &self.stakpak_api {
            api.slack_read_messages(&crate::stakpak::SlackReadMessagesRequest {
                channel: input.channel.clone(),
                limit: input.limit,
            })
            .await
        } else {
            Err("Slack integration requires Stakpak API key".to_string())
        }
    }

    /// Read thread replies for a Slack message. Requires the Stakpak API.
    async fn slack_read_replies(
        &self,
        input: &SlackReadRepliesRequest,
    ) -> Result<Vec<Content>, String> {
        if let Some(api) = &self.stakpak_api {
            api.slack_read_replies(&crate::stakpak::SlackReadRepliesRequest {
                channel: input.channel.clone(),
                ts: input.ts.clone(),
            })
            .await
        } else {
            Err("Slack integration requires Stakpak API key".to_string())
        }
    }

    /// Send a message (optionally threaded) to a Slack channel. Requires the
    /// Stakpak API.
    async fn slack_send_message(
        &self,
        input: &SlackSendMessageRequest,
    ) -> Result<Vec<Content>, String> {
        if let Some(api) = &self.stakpak_api {
            api.slack_send_message(&crate::stakpak::SlackSendMessageRequest {
                channel: input.channel.clone(),
                mrkdwn_text: input.mrkdwn_text.clone(),
                thread_ts: input.thread_ts.clone(),
            })
            .await
        } else {
            Err("Slack integration requires Stakpak API key".to_string())
        }
    }
}
543
544// =============================================================================
545// SessionStorage implementation (delegates to inner session_storage)
546// =============================================================================
547
548#[async_trait]
549impl crate::storage::SessionStorage for super::AgentClient {
550    async fn list_sessions(
551        &self,
552        query: &crate::storage::ListSessionsQuery,
553    ) -> Result<crate::storage::ListSessionsResult, crate::storage::StorageError> {
554        self.session_storage.list_sessions(query).await
555    }
556
557    async fn get_session(
558        &self,
559        session_id: Uuid,
560    ) -> Result<crate::storage::Session, crate::storage::StorageError> {
561        self.session_storage.get_session(session_id).await
562    }
563
564    async fn create_session(
565        &self,
566        request: &crate::storage::CreateSessionRequest,
567    ) -> Result<crate::storage::CreateSessionResult, crate::storage::StorageError> {
568        self.session_storage.create_session(request).await
569    }
570
571    async fn update_session(
572        &self,
573        session_id: Uuid,
574        request: &crate::storage::UpdateSessionRequest,
575    ) -> Result<crate::storage::Session, crate::storage::StorageError> {
576        self.session_storage
577            .update_session(session_id, request)
578            .await
579    }
580
581    async fn delete_session(&self, session_id: Uuid) -> Result<(), crate::storage::StorageError> {
582        self.session_storage.delete_session(session_id).await
583    }
584
585    async fn list_checkpoints(
586        &self,
587        session_id: Uuid,
588        query: &crate::storage::ListCheckpointsQuery,
589    ) -> Result<crate::storage::ListCheckpointsResult, crate::storage::StorageError> {
590        self.session_storage
591            .list_checkpoints(session_id, query)
592            .await
593    }
594
595    async fn get_checkpoint(
596        &self,
597        checkpoint_id: Uuid,
598    ) -> Result<crate::storage::Checkpoint, crate::storage::StorageError> {
599        self.session_storage.get_checkpoint(checkpoint_id).await
600    }
601
602    async fn create_checkpoint(
603        &self,
604        session_id: Uuid,
605        request: &crate::storage::CreateCheckpointRequest,
606    ) -> Result<crate::storage::Checkpoint, crate::storage::StorageError> {
607        self.session_storage
608            .create_checkpoint(session_id, request)
609            .await
610    }
611
612    async fn get_active_checkpoint(
613        &self,
614        session_id: Uuid,
615    ) -> Result<crate::storage::Checkpoint, crate::storage::StorageError> {
616        self.session_storage.get_active_checkpoint(session_id).await
617    }
618
619    async fn get_session_stats(
620        &self,
621        session_id: Uuid,
622    ) -> Result<crate::storage::SessionStats, crate::storage::StorageError> {
623        self.session_storage.get_session_stats(session_id).await
624    }
625}
626
627// =============================================================================
628// Helper Methods
629// =============================================================================
630
/// System prompt used by `generate_session_title` to derive a short session
/// title from the opening messages (bundled at compile time).
const TITLE_GENERATOR_PROMPT: &str =
    include_str!("../local/prompts/session_title_generator.v1.txt");
633
impl AgentClient {
    /// Initialize or resume a session based on context
    ///
    /// If `ctx.session_id` is set, we resume that session directly.
    /// Otherwise, we create a new session.
    ///
    /// Returns the session id plus its active checkpoint. For new sessions,
    /// a quick local title is used immediately and a better LLM-generated
    /// title is filled in by a background task.
    ///
    /// Errors when the context carries no messages, when a resumed session
    /// cannot be loaded or has no active checkpoint, or when session
    /// creation fails.
    pub(crate) async fn initialize_session(
        &self,
        ctx: &HookContext<AgentState>,
    ) -> Result<SessionInfo, String> {
        let messages = &ctx.state.messages;

        if messages.is_empty() {
            return Err("At least one message is required".to_string());
        }

        // If session_id is set in context, resume that session directly
        if let Some(session_id) = ctx.session_id {
            let session = self
                .session_storage
                .get_session(session_id)
                .await
                .map_err(|e| e.to_string())?;

            let checkpoint = session
                .active_checkpoint
                .ok_or_else(|| format!("Session {} has no active checkpoint", session_id))?;

            return Ok(SessionInfo {
                session_id: session.id,
                checkpoint_id: checkpoint.id,
                checkpoint_created_at: checkpoint.created_at,
            });
        }

        // Create new session with a fast local title.
        let fallback_title = Self::fallback_session_title(messages);

        // Get current working directory
        let cwd = std::env::current_dir()
            .ok()
            .map(|p| p.to_string_lossy().to_string());

        // Create session via storage trait
        let mut session_request =
            StorageCreateSessionRequest::new(fallback_title.clone(), messages.to_vec());
        if let Some(cwd) = cwd {
            session_request = session_request.with_cwd(cwd);
        }

        let result = self
            .session_storage
            .create_session(&session_request)
            .await
            .map_err(|e| e.to_string())?;

        // Generate a better title asynchronously and update the session when ready.
        // Fire-and-forget: title refinement must never block or fail session
        // creation, so the task's result is intentionally ignored.
        let client = self.clone();
        let messages_for_title = messages.to_vec();
        let session_id = result.session_id;
        tokio::spawn(async move {
            if let Ok(title) = client.generate_session_title(&messages_for_title).await {
                let trimmed = title.trim();
                // Only update when the LLM produced something new and non-empty.
                if !trimmed.is_empty() && trimmed != fallback_title {
                    let request =
                        StorageUpdateSessionRequest::new().with_title(trimmed.to_string());
                    let _ = client
                        .session_storage
                        .update_session(session_id, &request)
                        .await;
                }
            }
        });

        Ok(SessionInfo {
            session_id: result.session_id,
            checkpoint_id: result.checkpoint.id,
            checkpoint_created_at: result.checkpoint.created_at,
        })
    }

    /// Build a cheap provisional session title: the first five whitespace-
    /// separated words of the first user message, or "New Session" when no
    /// user message has content.
    ///
    /// NOTE(review): a user message whose content is empty/whitespace yields
    /// an empty title rather than "New Session" — confirm this is intended.
    fn fallback_session_title(messages: &[ChatMessage]) -> String {
        messages
            .iter()
            .find(|m| m.role == Role::User)
            .and_then(|m| m.content.as_ref())
            .map(|c| {
                let text = c.to_string();
                text.split_whitespace()
                    .take(5)
                    .collect::<Vec<_>>()
                    .join(" ")
            })
            .unwrap_or_else(|| "New Session".to_string())
    }

    /// Save a new checkpoint for the current session
    ///
    /// The new checkpoint is parented on `current.checkpoint_id`; the
    /// returned `SessionInfo` points at the freshly created checkpoint.
    pub(crate) async fn save_checkpoint(
        &self,
        current: &SessionInfo,
        messages: Vec<ChatMessage>,
    ) -> Result<SessionInfo, String> {
        let checkpoint_request =
            StorageCreateCheckpointRequest::new(messages).with_parent(current.checkpoint_id);

        let checkpoint = self
            .session_storage
            .create_checkpoint(current.session_id, &checkpoint_request)
            .await
            .map_err(|e| e.to_string())?;

        Ok(SessionInfo {
            session_id: current.session_id,
            checkpoint_id: checkpoint.id,
            checkpoint_created_at: checkpoint.created_at,
        })
    }

    /// Run agent completion (inference)
    ///
    /// Runs BeforeInference hooks, performs the LLM call (streaming when
    /// `stream_channel_tx` is provided, forwarding each delta), records the
    /// output in `ctx.state`, runs AfterInference hooks, and returns the
    /// resulting assistant message.
    ///
    /// Requires `ctx.state.llm_input` to have been populated by a prior hook.
    pub(crate) async fn run_agent_completion(
        &self,
        ctx: &mut HookContext<AgentState>,
        stream_channel_tx: Option<mpsc::Sender<Result<StreamMessage, String>>>,
    ) -> Result<ChatMessage, String> {
        // Execute before inference hooks
        self.hook_registry
            .execute_hooks(ctx, &LifecycleEvent::BeforeInference)
            .await
            .map_err(|e| e.to_string())?
            .ok()?;

        let mut input = if let Some(llm_input) = ctx.state.llm_input.clone() {
            llm_input
        } else {
            return Err(
                "LLM input not found, make sure to register a context hook before inference"
                    .to_string(),
            );
        };

        // Inject session_id header if available
        if let Some(session_id) = ctx.session_id {
            let headers = input
                .headers
                .get_or_insert_with(std::collections::HashMap::new);
            headers.insert("X-Session-Id".to_string(), session_id.to_string());
        }

        let (response_message, usage) = if let Some(tx) = stream_channel_tx {
            // Streaming mode
            // Deltas arrive on an internal channel and are re-wrapped as
            // `StreamMessage::Delta` for the caller's channel.
            let (internal_tx, mut internal_rx) = mpsc::channel::<GenerationDelta>(100);
            let stream_input = LLMStreamInput {
                model: input.model,
                messages: input.messages,
                max_tokens: input.max_tokens,
                tools: input.tools,
                stream_channel_tx: internal_tx,
                provider_options: input.provider_options,
                headers: input.headers,
            };

            let stakai = self.stakai.clone();
            let chat_future = async move {
                stakai
                    .chat_stream(stream_input)
                    .await
                    .map_err(|e| e.to_string())
            };

            // Forwarding stops (but the chat call keeps running) if the
            // consumer drops its receiver.
            let receive_future = async move {
                while let Some(delta) = internal_rx.recv().await {
                    if tx.send(Ok(StreamMessage::Delta(delta))).await.is_err() {
                        break;
                    }
                }
            };

            // Drive the LLM call and the delta forwarder concurrently; the
            // forwarder ends naturally when the internal sender is dropped.
            let (chat_result, _) = tokio::join!(chat_future, receive_future);
            let response = chat_result?;
            // NOTE(review): indexing `choices[0]` panics on an empty choices
            // list — assumes the LLM client guarantees at least one choice.
            (response.choices[0].message.clone(), response.usage)
        } else {
            // Non-streaming mode
            let response = self.stakai.chat(input).await.map_err(|e| e.to_string())?;
            (response.choices[0].message.clone(), response.usage)
        };

        ctx.state.set_llm_output(response_message, usage);

        // Execute after inference hooks
        self.hook_registry
            .execute_hooks(ctx, &LifecycleEvent::AfterInference)
            .await
            .map_err(|e| e.to_string())?
            .ok()?;

        // Re-read the output from state: AfterInference hooks may have
        // modified it.
        let llm_output = ctx
            .state
            .llm_output
            .as_ref()
            .ok_or_else(|| "LLM output is missing from state".to_string())?;

        Ok(ChatMessage::from(llm_output))
    }

    /// Generate a title for a new session
    ///
    /// Uses the configured eco model (or a Haiku default), routed through
    /// Stakpak when available, with the bundled title-generator prompt and a
    /// 100-token cap.
    async fn generate_session_title(&self, messages: &[ChatMessage]) -> Result<String, String> {
        let llm_model = if let Some(eco_model) = &self.model_options.eco_model {
            eco_model.clone()
        } else {
            // Try to find a suitable model
            LLMModel::Anthropic(AnthropicModel::Claude45Haiku)
        };

        // If Stakpak is available, route through it
        let model = if self.has_stakpak() {
            // Get properly formatted model string with provider prefix (e.g., "anthropic/claude-haiku-4-5")
            let model_str = get_stakai_model_string(&llm_model);
            // Extract display name from the last segment for UI
            let display_name = model_str
                .rsplit('/')
                .next()
                .unwrap_or(&model_str)
                .to_string();
            LLMModel::Custom {
                provider: "stakpak".to_string(),
                model: model_str,
                name: Some(display_name),
            }
        } else {
            llm_model
        };

        let llm_messages = vec![
            LLMMessage {
                role: Role::System.to_string(),
                content: LLMMessageContent::String(TITLE_GENERATOR_PROMPT.to_string()),
            },
            LLMMessage {
                role: Role::User.to_string(),
                content: LLMMessageContent::String(
                    // Concatenate all message bodies (empty string for
                    // messages with no content) into one user turn.
                    messages
                        .iter()
                        .map(|msg| {
                            msg.content
                                .as_ref()
                                .unwrap_or(&MessageContent::String("".to_string()))
                                .to_string()
                        })
                        .collect(),
                ),
            },
        ];

        let input = LLMInput {
            model,
            messages: llm_messages,
            max_tokens: 100,
            tools: None,
            provider_options: None,
            headers: None,
        };

        let response = self.stakai.chat(input).await.map_err(|e| e.to_string())?;

        // NOTE(review): `choices[0]` panics if the provider returns no
        // choices — confirm the client guarantees a non-empty list.
        Ok(response.choices[0].message.content.to_string())
    }
}
899}