1use crate::AgentProvider;
9use crate::models::*;
10use crate::storage::{
11 CreateCheckpointRequest as StorageCreateCheckpointRequest,
12 CreateSessionRequest as StorageCreateSessionRequest,
13 UpdateSessionRequest as StorageUpdateSessionRequest,
14};
15use async_trait::async_trait;
16use futures_util::Stream;
17use reqwest::header::HeaderMap;
18use rmcp::model::Content;
19use stakai::Model;
20use stakpak_shared::hooks::{HookContext, LifecycleEvent};
21use stakpak_shared::models::integrations::openai::{
22 ChatCompletionChoice, ChatCompletionResponse, ChatCompletionStreamChoice,
23 ChatCompletionStreamResponse, ChatMessage, FinishReason, MessageContent, Role, Tool,
24};
25use stakpak_shared::models::llm::{
26 GenerationDelta, LLMInput, LLMMessage, LLMMessageContent, LLMStreamInput,
27};
28use std::pin::Pin;
29use tokio::sync::mpsc;
30use uuid::Uuid;
31
/// Identifies the session a request belongs to and the checkpoint it builds on.
#[derive(Debug, Clone)]
pub(crate) struct SessionInfo {
    /// Id of the storage session.
    session_id: Uuid,
    /// Checkpoint id: the session's active checkpoint when produced by
    /// `initialize_session`, or the freshly created one when produced by
    /// `save_checkpoint`.
    checkpoint_id: Uuid,
    /// Creation time of the checkpoint identified by `checkpoint_id`.
    checkpoint_created_at: chrono::DateTime<chrono::Utc>,
}
39
40use super::AgentClient;
41
/// Message sent from the background completion task to the stream consumer
/// in `chat_completion_stream`.
#[derive(Debug)]
pub(crate) enum StreamMessage {
    /// An incremental generation delta forwarded from the LLM stream.
    Delta(GenerationDelta),
    /// A snapshot of the hook context, sent once the assistant message has been
    /// appended and again after the checkpoint is saved. Boxed, presumably to
    /// keep the enum small next to `Delta` — TODO confirm.
    Ctx(Box<HookContext<AgentState>>),
}
51
52#[async_trait]
57impl AgentProvider for AgentClient {
58 async fn get_my_account(&self) -> Result<GetMyAccountResponse, String> {
63 if let Some(api) = &self.stakpak_api {
64 api.get_account().await
65 } else {
66 Ok(GetMyAccountResponse {
68 username: "local".to_string(),
69 id: "local".to_string(),
70 first_name: "local".to_string(),
71 last_name: "local".to_string(),
72 email: "local@stakpak.dev".to_string(),
73 scope: None,
74 })
75 }
76 }
77
78 async fn get_billing_info(
79 &self,
80 account_username: &str,
81 ) -> Result<stakpak_shared::models::billing::BillingResponse, String> {
82 if let Some(api) = &self.stakpak_api {
83 api.get_billing(account_username).await
84 } else {
85 Err("Billing info not available without Stakpak API key".to_string())
86 }
87 }
88
89 async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
94 if let Some(api) = &self.stakpak_api {
95 api.list_rulebooks().await
96 } else {
97 let client = stakpak_shared::tls_client::create_tls_client(
99 stakpak_shared::tls_client::TlsClientConfig::default()
100 .with_timeout(std::time::Duration::from_secs(30)),
101 )?;
102
103 let url = format!("{}/v1/rules", self.get_stakpak_api_endpoint());
104 let response = client.get(&url).send().await.map_err(|e| e.to_string())?;
105
106 if response.status().is_success() {
107 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
108 match serde_json::from_value::<ListRulebooksResponse>(value) {
109 Ok(resp) => Ok(resp.results),
110 Err(_) => Ok(vec![]),
111 }
112 } else {
113 Ok(vec![])
114 }
115 }
116 }
117
118 async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
119 if let Some(api) = &self.stakpak_api {
120 api.get_rulebook_by_uri(uri).await
121 } else {
122 let client = stakpak_shared::tls_client::create_tls_client(
124 stakpak_shared::tls_client::TlsClientConfig::default()
125 .with_timeout(std::time::Duration::from_secs(30)),
126 )?;
127
128 let encoded_uri = urlencoding::encode(uri);
129 let url = format!(
130 "{}/v1/rules/{}",
131 self.get_stakpak_api_endpoint(),
132 encoded_uri
133 );
134 let response = client.get(&url).send().await.map_err(|e| e.to_string())?;
135
136 if response.status().is_success() {
137 response.json().await.map_err(|e| e.to_string())
138 } else {
139 Err("Rulebook not found".to_string())
140 }
141 }
142 }
143
144 async fn create_rulebook(
145 &self,
146 uri: &str,
147 description: &str,
148 content: &str,
149 tags: Vec<String>,
150 visibility: Option<RuleBookVisibility>,
151 ) -> Result<CreateRuleBookResponse, String> {
152 if let Some(api) = &self.stakpak_api {
153 api.create_rulebook(&CreateRuleBookInput {
154 uri: uri.to_string(),
155 description: description.to_string(),
156 content: content.to_string(),
157 tags,
158 visibility,
159 })
160 .await
161 } else {
162 Err("Creating rulebooks requires Stakpak API key".to_string())
163 }
164 }
165
166 async fn delete_rulebook(&self, uri: &str) -> Result<(), String> {
167 if let Some(api) = &self.stakpak_api {
168 api.delete_rulebook(uri).await
169 } else {
170 Err("Deleting rulebooks requires Stakpak API key".to_string())
171 }
172 }
173
174 async fn chat_completion(
179 &self,
180 model: Model,
181 messages: Vec<ChatMessage>,
182 tools: Option<Vec<Tool>>,
183 session_id: Option<Uuid>,
184 metadata: Option<serde_json::Value>,
185 ) -> Result<ChatCompletionResponse, String> {
186 let mut ctx = HookContext::new(session_id, AgentState::new(model, messages, tools));
187 ctx.state.metadata = metadata;
188
189 self.hook_registry
191 .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
192 .await
193 .map_err(|e| e.to_string())?
194 .ok()?;
195
196 let current_session = self.initialize_session(&ctx).await?;
198 ctx.set_session_id(current_session.session_id);
199
200 let new_message = self.run_agent_completion(&mut ctx, None).await?;
202 ctx.state.append_new_message(new_message.clone());
203
204 let result = self
206 .save_checkpoint(
207 ¤t_session,
208 ctx.state.messages.clone(),
209 ctx.state.metadata.clone(),
210 )
211 .await?;
212 let checkpoint_created_at = result.checkpoint_created_at.timestamp() as u64;
213 ctx.set_new_checkpoint_id(result.checkpoint_id);
214
215 self.hook_registry
217 .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
218 .await
219 .map_err(|e| e.to_string())?
220 .ok()?;
221
222 let mut meta = serde_json::Map::new();
223 if let Some(session_id) = ctx.session_id {
224 meta.insert(
225 "session_id".to_string(),
226 serde_json::Value::String(session_id.to_string()),
227 );
228 }
229 if let Some(checkpoint_id) = ctx.new_checkpoint_id {
230 meta.insert(
231 "checkpoint_id".to_string(),
232 serde_json::Value::String(checkpoint_id.to_string()),
233 );
234 }
235
236 Ok(ChatCompletionResponse {
237 id: ctx.new_checkpoint_id.unwrap().to_string(),
238 object: "chat.completion".to_string(),
239 created: checkpoint_created_at,
240 model: ctx
241 .state
242 .llm_input
243 .as_ref()
244 .map(|llm_input| llm_input.model.id.clone())
245 .unwrap_or_default(),
246 choices: vec![ChatCompletionChoice {
247 index: 0,
248 message: ctx.state.messages.last().cloned().unwrap(),
249 logprobs: None,
250 finish_reason: FinishReason::Stop,
251 }],
252 usage: ctx
253 .state
254 .llm_output
255 .as_ref()
256 .map(|u| u.usage.clone())
257 .unwrap_or_default(),
258 system_fingerprint: None,
259 metadata: if meta.is_empty() {
260 None
261 } else {
262 Some(serde_json::Value::Object(meta))
263 },
264 })
265 }
266
267 async fn chat_completion_stream(
268 &self,
269 model: Model,
270 messages: Vec<ChatMessage>,
271 tools: Option<Vec<Tool>>,
272 _headers: Option<HeaderMap>,
273 session_id: Option<Uuid>,
274 metadata: Option<serde_json::Value>,
275 ) -> Result<
276 (
277 Pin<
278 Box<dyn Stream<Item = Result<ChatCompletionStreamResponse, ApiStreamError>> + Send>,
279 >,
280 Option<String>,
281 ),
282 String,
283 > {
284 let mut ctx = HookContext::new(session_id, AgentState::new(model, messages, tools));
285 ctx.state.metadata = metadata;
286
287 self.hook_registry
289 .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
290 .await
291 .map_err(|e| e.to_string())?
292 .ok()?;
293
294 let current_session = self.initialize_session(&ctx).await?;
296 ctx.set_session_id(current_session.session_id);
297
298 let (tx, mut rx) = mpsc::channel::<Result<StreamMessage, String>>(100);
299
300 let client = self.clone();
302 let mut ctx_clone = ctx.clone();
303
304 tokio::spawn(async move {
308 if tx.is_closed() {
310 return;
311 }
312
313 let result = client
314 .run_agent_completion(&mut ctx_clone, Some(tx.clone()))
315 .await;
316
317 match result {
318 Err(e) => {
319 let _ = tx.send(Err(e)).await;
320 }
321 Ok(new_message) => {
322 if tx.is_closed() {
324 return;
325 }
326
327 ctx_clone.state.append_new_message(new_message.clone());
328 if tx
329 .send(Ok(StreamMessage::Ctx(Box::new(ctx_clone.clone()))))
330 .await
331 .is_err()
332 {
333 return;
335 }
336
337 if tx.is_closed() {
339 return;
340 }
341
342 let result = client
343 .save_checkpoint(
344 ¤t_session,
345 ctx_clone.state.messages.clone(),
346 ctx_clone.state.metadata.clone(),
347 )
348 .await;
349
350 match result {
351 Err(e) => {
352 let _ = tx.send(Err(e)).await;
353 }
354 Ok(updated) => {
355 ctx_clone.set_new_checkpoint_id(updated.checkpoint_id);
356 let _ = tx.send(Ok(StreamMessage::Ctx(Box::new(ctx_clone)))).await;
357 }
358 }
359 }
360 }
361 });
362
363 let hook_registry = self.hook_registry.clone();
364 let stream = async_stream::stream! {
365 while let Some(delta_result) = rx.recv().await {
366 match delta_result {
367 Ok(delta) => match delta {
368 StreamMessage::Ctx(updated_ctx) => {
369 ctx = *updated_ctx;
370 if let Some(session_id) = ctx.session_id {
372 let mut meta = serde_json::Map::new();
373 meta.insert("session_id".to_string(), serde_json::Value::String(session_id.to_string()));
374 if let Some(checkpoint_id) = ctx.new_checkpoint_id {
375 meta.insert("checkpoint_id".to_string(), serde_json::Value::String(checkpoint_id.to_string()));
376 }
377 yield Ok(ChatCompletionStreamResponse {
378 id: ctx.request_id.to_string(),
379 object: "chat.completion.chunk".to_string(),
380 created: chrono::Utc::now().timestamp() as u64,
381 model: String::new(),
382 choices: vec![],
383 usage: None,
384 metadata: Some(serde_json::Value::Object(meta)),
385 });
386 }
387 }
388 StreamMessage::Delta(delta) => {
389 let usage = if let GenerationDelta::Usage { usage } = &delta {
391 Some(usage.clone())
392 } else {
393 None
394 };
395
396 yield Ok(ChatCompletionStreamResponse {
397 id: ctx.request_id.to_string(),
398 object: "chat.completion.chunk".to_string(),
399 created: chrono::Utc::now().timestamp() as u64,
400 model: ctx.state.llm_input.as_ref().map(|llm_input| llm_input.model.clone().to_string()).unwrap_or_default(),
401 choices: vec![ChatCompletionStreamChoice {
402 index: 0,
403 delta: delta.into(),
404 finish_reason: None,
405 }],
406 usage,
407 metadata: None,
408 })
409 }
410 }
411 Err(e) => yield Err(ApiStreamError::Unknown(e)),
412 }
413 }
414
415 hook_registry
417 .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
418 .await
419 .map_err(|e| e.to_string())?
420 .ok()?;
421 };
422
423 Ok((Box::pin(stream), None))
424 }
425
426 async fn cancel_stream(&self, request_id: String) -> Result<(), String> {
427 if let Some(api) = &self.stakpak_api {
428 api.cancel_request(&request_id).await
429 } else {
430 Ok(())
432 }
433 }
434
435 async fn search_docs(&self, input: &SearchDocsRequest) -> Result<Vec<Content>, String> {
440 if let Some(api) = &self.stakpak_api {
441 api.search_docs(&crate::stakpak::SearchDocsRequest {
442 keywords: input.keywords.clone(),
443 exclude_keywords: input.exclude_keywords.clone(),
444 limit: input.limit,
445 })
446 .await
447 } else {
448 use stakpak_shared::models::integrations::search_service::*;
450
451 let config = SearchServicesOrchestrator::start()
452 .await
453 .map_err(|e| e.to_string())?;
454
455 let api_url = format!("http://localhost:{}", config.api_port);
456 let search_client = SearchClient::new(api_url);
457
458 let search_results = search_client
459 .search_and_scrape(input.keywords.clone(), None)
460 .await
461 .map_err(|e| e.to_string())?;
462
463 if search_results.is_empty() {
464 return Ok(vec![Content::text("No results found".to_string())]);
465 }
466
467 Ok(search_results
468 .into_iter()
469 .map(|result| {
470 let content = result.content.unwrap_or_default();
471 Content::text(format!("URL: {}\nContent: {}", result.url, content))
472 })
473 .collect())
474 }
475 }
476
477 async fn memorize_session(&self, checkpoint_id: Uuid) -> Result<(), String> {
482 if let Some(api) = &self.stakpak_api {
483 api.memorize_session(checkpoint_id).await
484 } else {
485 Ok(())
487 }
488 }
489
490 async fn search_memory(&self, input: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
491 if let Some(api) = &self.stakpak_api {
492 api.search_memory(&crate::stakpak::SearchMemoryRequest {
493 keywords: input.keywords.clone(),
494 start_time: input.start_time,
495 end_time: input.end_time,
496 })
497 .await
498 } else {
499 Ok(vec![])
501 }
502 }
503
504 async fn slack_read_messages(
509 &self,
510 input: &SlackReadMessagesRequest,
511 ) -> Result<Vec<Content>, String> {
512 if let Some(api) = &self.stakpak_api {
513 api.slack_read_messages(&crate::stakpak::SlackReadMessagesRequest {
514 channel: input.channel.clone(),
515 limit: input.limit,
516 })
517 .await
518 } else {
519 Err("Slack integration requires Stakpak API key".to_string())
520 }
521 }
522
523 async fn slack_read_replies(
524 &self,
525 input: &SlackReadRepliesRequest,
526 ) -> Result<Vec<Content>, String> {
527 if let Some(api) = &self.stakpak_api {
528 api.slack_read_replies(&crate::stakpak::SlackReadRepliesRequest {
529 channel: input.channel.clone(),
530 ts: input.ts.clone(),
531 })
532 .await
533 } else {
534 Err("Slack integration requires Stakpak API key".to_string())
535 }
536 }
537
538 async fn slack_send_message(
539 &self,
540 input: &SlackSendMessageRequest,
541 ) -> Result<Vec<Content>, String> {
542 if let Some(api) = &self.stakpak_api {
543 api.slack_send_message(&crate::stakpak::SlackSendMessageRequest {
544 channel: input.channel.clone(),
545 markdown_text: input.markdown_text.clone(),
546 thread_ts: input.thread_ts.clone(),
547 })
548 .await
549 } else {
550 Err("Slack integration requires Stakpak API key".to_string())
551 }
552 }
553
554 async fn list_models(&self) -> Vec<stakai::Model> {
559 let registry = self.stakai.registry();
563 let mut all_models = Vec::new();
564
565 for provider_id in registry.list_providers() {
566 if let Ok(mut models) = registry.models_for_provider(&provider_id).await {
567 all_models.append(&mut models);
568 }
569 }
570
571 sort_models_by_recency(&mut all_models);
572 all_models
573 }
574}
575
576fn sort_models_by_recency(models: &mut [stakai::Model]) {
578 models.sort_by(|a, b| {
579 match (&b.release_date, &a.release_date) {
580 (Some(b_date), Some(a_date)) => b_date.cmp(a_date),
581 (Some(_), None) => std::cmp::Ordering::Less,
582 (None, Some(_)) => std::cmp::Ordering::Greater,
583 (None, None) => b.id.cmp(&a.id), }
585 });
586}
587
588#[async_trait]
593impl crate::storage::SessionStorage for super::AgentClient {
594 async fn list_sessions(
595 &self,
596 query: &crate::storage::ListSessionsQuery,
597 ) -> Result<crate::storage::ListSessionsResult, crate::storage::StorageError> {
598 self.session_storage.list_sessions(query).await
599 }
600
601 async fn get_session(
602 &self,
603 session_id: Uuid,
604 ) -> Result<crate::storage::Session, crate::storage::StorageError> {
605 self.session_storage.get_session(session_id).await
606 }
607
608 async fn create_session(
609 &self,
610 request: &crate::storage::CreateSessionRequest,
611 ) -> Result<crate::storage::CreateSessionResult, crate::storage::StorageError> {
612 self.session_storage.create_session(request).await
613 }
614
615 async fn update_session(
616 &self,
617 session_id: Uuid,
618 request: &crate::storage::UpdateSessionRequest,
619 ) -> Result<crate::storage::Session, crate::storage::StorageError> {
620 self.session_storage
621 .update_session(session_id, request)
622 .await
623 }
624
625 async fn delete_session(&self, session_id: Uuid) -> Result<(), crate::storage::StorageError> {
626 self.session_storage.delete_session(session_id).await
627 }
628
629 async fn list_checkpoints(
630 &self,
631 session_id: Uuid,
632 query: &crate::storage::ListCheckpointsQuery,
633 ) -> Result<crate::storage::ListCheckpointsResult, crate::storage::StorageError> {
634 self.session_storage
635 .list_checkpoints(session_id, query)
636 .await
637 }
638
639 async fn get_checkpoint(
640 &self,
641 checkpoint_id: Uuid,
642 ) -> Result<crate::storage::Checkpoint, crate::storage::StorageError> {
643 self.session_storage.get_checkpoint(checkpoint_id).await
644 }
645
646 async fn create_checkpoint(
647 &self,
648 session_id: Uuid,
649 request: &crate::storage::CreateCheckpointRequest,
650 ) -> Result<crate::storage::Checkpoint, crate::storage::StorageError> {
651 self.session_storage
652 .create_checkpoint(session_id, request)
653 .await
654 }
655
656 async fn get_active_checkpoint(
657 &self,
658 session_id: Uuid,
659 ) -> Result<crate::storage::Checkpoint, crate::storage::StorageError> {
660 self.session_storage.get_active_checkpoint(session_id).await
661 }
662
663 async fn get_session_stats(
664 &self,
665 session_id: Uuid,
666 ) -> Result<crate::storage::SessionStats, crate::storage::StorageError> {
667 self.session_storage.get_session_stats(session_id).await
668 }
669}
670
/// System prompt used by `generate_session_title`, embedded at compile time
/// from the prompts directory.
const TITLE_GENERATOR_PROMPT: &str =
    include_str!("../local/prompts/session_title_generator.v1.txt");
677
678impl AgentClient {
679 pub(crate) async fn initialize_session(
684 &self,
685 ctx: &HookContext<AgentState>,
686 ) -> Result<SessionInfo, String> {
687 let messages = &ctx.state.messages;
688
689 if messages.is_empty() {
690 return Err("At least one message is required".to_string());
691 }
692
693 if let Some(session_id) = ctx.session_id {
695 let session = self
696 .session_storage
697 .get_session(session_id)
698 .await
699 .map_err(|e| e.to_string())?;
700
701 let checkpoint = session
702 .active_checkpoint
703 .ok_or_else(|| format!("Session {} has no active checkpoint", session_id))?;
704
705 if session.title.trim().is_empty() || session.title == "New Session" {
707 let client = self.clone();
708 let messages_for_title = messages.to_vec();
709 let session_id = session.id;
710 let existing_title = session.title.clone();
711 tokio::spawn(async move {
712 if let Ok(title) = client.generate_session_title(&messages_for_title).await {
713 let trimmed = title.trim();
714 if !trimmed.is_empty() && trimmed != existing_title {
715 let request =
716 StorageUpdateSessionRequest::new().with_title(trimmed.to_string());
717 let _ = client
718 .session_storage
719 .update_session(session_id, &request)
720 .await;
721 }
722 }
723 });
724 }
725
726 return Ok(SessionInfo {
727 session_id: session.id,
728 checkpoint_id: checkpoint.id,
729 checkpoint_created_at: checkpoint.created_at,
730 });
731 }
732
733 let fallback_title = Self::fallback_session_title(messages);
735
736 let cwd = std::env::current_dir()
738 .ok()
739 .map(|p| p.to_string_lossy().to_string());
740
741 let mut session_request =
743 StorageCreateSessionRequest::new(fallback_title.clone(), messages.to_vec());
744 if let Some(cwd) = cwd {
745 session_request = session_request.with_cwd(cwd);
746 }
747
748 let result = self
749 .session_storage
750 .create_session(&session_request)
751 .await
752 .map_err(|e| e.to_string())?;
753
754 let client = self.clone();
756 let messages_for_title = messages.to_vec();
757 let session_id = result.session_id;
758 tokio::spawn(async move {
759 if let Ok(title) = client.generate_session_title(&messages_for_title).await {
760 let trimmed = title.trim();
761 if !trimmed.is_empty() && trimmed != fallback_title {
762 let request =
763 StorageUpdateSessionRequest::new().with_title(trimmed.to_string());
764 let _ = client
765 .session_storage
766 .update_session(session_id, &request)
767 .await;
768 }
769 }
770 });
771
772 Ok(SessionInfo {
773 session_id: result.session_id,
774 checkpoint_id: result.checkpoint.id,
775 checkpoint_created_at: result.checkpoint.created_at,
776 })
777 }
778
779 fn fallback_session_title(messages: &[ChatMessage]) -> String {
780 messages
781 .iter()
782 .find(|m| m.role == Role::User)
783 .and_then(|m| m.content.as_ref())
784 .map(|c| {
785 let text = c.to_string();
786 text.split_whitespace()
787 .take(5)
788 .collect::<Vec<_>>()
789 .join(" ")
790 })
791 .unwrap_or_else(|| "New Session".to_string())
792 }
793
794 pub(crate) async fn save_checkpoint(
796 &self,
797 current: &SessionInfo,
798 messages: Vec<ChatMessage>,
799 metadata: Option<serde_json::Value>,
800 ) -> Result<SessionInfo, String> {
801 let mut checkpoint_request =
802 StorageCreateCheckpointRequest::new(messages).with_parent(current.checkpoint_id);
803
804 if let Some(meta) = metadata {
805 checkpoint_request = checkpoint_request.with_metadata(meta);
806 }
807
808 let checkpoint = self
809 .session_storage
810 .create_checkpoint(current.session_id, &checkpoint_request)
811 .await
812 .map_err(|e| e.to_string())?;
813
814 Ok(SessionInfo {
815 session_id: current.session_id,
816 checkpoint_id: checkpoint.id,
817 checkpoint_created_at: checkpoint.created_at,
818 })
819 }
820
821 pub(crate) async fn run_agent_completion(
823 &self,
824 ctx: &mut HookContext<AgentState>,
825 stream_channel_tx: Option<mpsc::Sender<Result<StreamMessage, String>>>,
826 ) -> Result<ChatMessage, String> {
827 self.hook_registry
829 .execute_hooks(ctx, &LifecycleEvent::BeforeInference)
830 .await
831 .map_err(|e| e.to_string())?
832 .ok()?;
833
834 let mut input = if let Some(llm_input) = ctx.state.llm_input.clone() {
835 llm_input
836 } else {
837 return Err(
838 "LLM input not found, make sure to register a context hook before inference"
839 .to_string(),
840 );
841 };
842
843 if let Some(session_id) = ctx.session_id {
845 let headers = input
846 .headers
847 .get_or_insert_with(std::collections::HashMap::new);
848 headers.insert("X-Session-Id".to_string(), session_id.to_string());
849 }
850
851 let (response_message, usage) = if let Some(tx) = stream_channel_tx {
852 let (internal_tx, mut internal_rx) = mpsc::channel::<GenerationDelta>(100);
854 let stream_input = LLMStreamInput {
855 model: input.model,
856 messages: input.messages,
857 max_tokens: input.max_tokens,
858 tools: input.tools,
859 stream_channel_tx: internal_tx,
860 provider_options: input.provider_options,
861 headers: input.headers,
862 };
863
864 let stakai = self.stakai.clone();
865 let chat_future = async move {
866 stakai
867 .chat_stream(stream_input)
868 .await
869 .map_err(|e| e.to_string())
870 };
871
872 let receive_future = async move {
873 while let Some(delta) = internal_rx.recv().await {
874 if tx.send(Ok(StreamMessage::Delta(delta))).await.is_err() {
875 break;
876 }
877 }
878 };
879
880 let (chat_result, _) = tokio::join!(chat_future, receive_future);
881 let response = chat_result?;
882 (response.choices[0].message.clone(), response.usage)
883 } else {
884 let response = self.stakai.chat(input).await.map_err(|e| e.to_string())?;
886 (response.choices[0].message.clone(), response.usage)
887 };
888
889 ctx.state.set_llm_output(response_message, usage);
890
891 self.hook_registry
893 .execute_hooks(ctx, &LifecycleEvent::AfterInference)
894 .await
895 .map_err(|e| e.to_string())?
896 .ok()?;
897
898 let llm_output = ctx
899 .state
900 .llm_output
901 .as_ref()
902 .ok_or_else(|| "LLM output is missing from state".to_string())?;
903
904 Ok(ChatMessage::from(llm_output))
905 }
906
907 async fn generate_session_title(&self, messages: &[ChatMessage]) -> Result<String, String> {
909 let use_stakpak = self.stakpak.is_some();
911 let providers = self.stakai.registry().list_providers();
912 let cheap_models: &[(&str, &str)] = &[
913 ("stakpak", "claude-haiku-4-5"),
914 ("anthropic", "claude-haiku-4-5"),
915 ("openai", "gpt-4.1-mini"),
916 ("google", "gemini-2.5-flash"),
917 ];
918 let model = cheap_models
919 .iter()
920 .find_map(|(provider, model_id)| {
921 if providers.contains(&provider.to_string()) {
922 crate::find_model(model_id, use_stakpak)
923 } else {
924 None
925 }
926 })
927 .ok_or_else(|| "No model available for title generation".to_string())?;
928
929 let llm_messages = vec![
930 LLMMessage {
931 role: Role::System.to_string(),
932 content: LLMMessageContent::String(TITLE_GENERATOR_PROMPT.to_string()),
933 },
934 LLMMessage {
935 role: Role::User.to_string(),
936 content: LLMMessageContent::String(
937 messages
938 .iter()
939 .map(|msg| {
940 msg.content
941 .as_ref()
942 .unwrap_or(&MessageContent::String("".to_string()))
943 .to_string()
944 })
945 .collect(),
946 ),
947 },
948 ];
949
950 let input = LLMInput {
951 model,
952 messages: llm_messages,
953 max_tokens: 100,
954 tools: None,
955 provider_options: None,
956 headers: None,
957 };
958
959 let response = self.stakai.chat(input).await.map_err(|e| e.to_string())?;
960
961 Ok(response.choices[0].message.content.to_string())
962 }
963}