1use crate::AgentProvider;
9use crate::models::*;
10use crate::storage::{
11 CreateCheckpointRequest as StorageCreateCheckpointRequest,
12 CreateSessionRequest as StorageCreateSessionRequest,
13 UpdateSessionRequest as StorageUpdateSessionRequest,
14};
15use async_trait::async_trait;
16use futures_util::Stream;
17use reqwest::header::HeaderMap;
18use rmcp::model::Content;
19use stakai::Model;
20use stakpak_shared::hooks::{HookContext, LifecycleEvent};
21use stakpak_shared::models::integrations::openai::{
22 ChatCompletionChoice, ChatCompletionResponse, ChatCompletionStreamChoice,
23 ChatCompletionStreamResponse, ChatMessage, FinishReason, MessageContent, Role, Tool,
24};
25use stakpak_shared::models::llm::{
26 GenerationDelta, LLMInput, LLMMessage, LLMMessageContent, LLMStreamInput,
27};
28use std::pin::Pin;
29use tokio::sync::mpsc;
30use uuid::Uuid;
31
/// Identifies a session together with one of its checkpoints.
///
/// Produced by `initialize_session` and `save_checkpoint` in this file.
#[derive(Debug, Clone)]
pub(crate) struct SessionInfo {
    /// Id of the session.
    session_id: Uuid,
    /// Id of the checkpoint this value refers to.
    checkpoint_id: Uuid,
    /// Creation time (UTC) of that checkpoint.
    checkpoint_created_at: chrono::DateTime<chrono::Utc>,
}
39
40use super::AgentClient;
41
/// Messages carried by the internal channel that feeds the response stream
/// built in `chat_completion_stream`.
#[derive(Debug)]
pub(crate) enum StreamMessage {
    /// An incremental generation delta forwarded from the model stream.
    Delta(GenerationDelta),
    /// A snapshot of the hook context (boxed; presumably to keep the enum
    /// small — TODO confirm).
    Ctx(Box<HookContext<AgentState>>),
}
51
52#[async_trait]
57impl AgentProvider for AgentClient {
58 async fn get_my_account(&self) -> Result<GetMyAccountResponse, String> {
63 if let Some(api) = &self.stakpak_api {
64 api.get_account().await
65 } else {
66 Ok(GetMyAccountResponse {
68 username: "local".to_string(),
69 id: "local".to_string(),
70 first_name: "local".to_string(),
71 last_name: "local".to_string(),
72 email: "local@stakpak.dev".to_string(),
73 scope: None,
74 })
75 }
76 }
77
78 async fn get_billing_info(
79 &self,
80 account_username: &str,
81 ) -> Result<stakpak_shared::models::billing::BillingResponse, String> {
82 if let Some(api) = &self.stakpak_api {
83 api.get_billing(account_username).await
84 } else {
85 Err("Billing info not available without Stakpak API key".to_string())
86 }
87 }
88
89 async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
94 if let Some(api) = &self.stakpak_api {
95 api.list_rulebooks().await
96 } else {
97 let client = stakpak_shared::tls_client::create_tls_client(
99 stakpak_shared::tls_client::TlsClientConfig::default()
100 .with_timeout(std::time::Duration::from_secs(30)),
101 )?;
102
103 let url = format!("{}/v1/rules", self.get_stakpak_api_endpoint());
104 let response = client.get(&url).send().await.map_err(|e| e.to_string())?;
105
106 if response.status().is_success() {
107 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
108 match serde_json::from_value::<ListRulebooksResponse>(value) {
109 Ok(resp) => Ok(resp.results),
110 Err(_) => Ok(vec![]),
111 }
112 } else {
113 Ok(vec![])
114 }
115 }
116 }
117
118 async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
119 if let Some(api) = &self.stakpak_api {
120 api.get_rulebook_by_uri(uri).await
121 } else {
122 let client = stakpak_shared::tls_client::create_tls_client(
124 stakpak_shared::tls_client::TlsClientConfig::default()
125 .with_timeout(std::time::Duration::from_secs(30)),
126 )?;
127
128 let encoded_uri = urlencoding::encode(uri);
129 let url = format!(
130 "{}/v1/rules/{}",
131 self.get_stakpak_api_endpoint(),
132 encoded_uri
133 );
134 let response = client.get(&url).send().await.map_err(|e| e.to_string())?;
135
136 if response.status().is_success() {
137 response.json().await.map_err(|e| e.to_string())
138 } else {
139 Err("Rulebook not found".to_string())
140 }
141 }
142 }
143
144 async fn create_rulebook(
145 &self,
146 uri: &str,
147 description: &str,
148 content: &str,
149 tags: Vec<String>,
150 visibility: Option<RuleBookVisibility>,
151 ) -> Result<CreateRuleBookResponse, String> {
152 if let Some(api) = &self.stakpak_api {
153 api.create_rulebook(&CreateRuleBookInput {
154 uri: uri.to_string(),
155 description: description.to_string(),
156 content: content.to_string(),
157 tags,
158 visibility,
159 })
160 .await
161 } else {
162 Err("Creating rulebooks requires Stakpak API key".to_string())
163 }
164 }
165
166 async fn delete_rulebook(&self, uri: &str) -> Result<(), String> {
167 if let Some(api) = &self.stakpak_api {
168 api.delete_rulebook(uri).await
169 } else {
170 Err("Deleting rulebooks requires Stakpak API key".to_string())
171 }
172 }
173
174 async fn chat_completion(
179 &self,
180 model: Model,
181 messages: Vec<ChatMessage>,
182 tools: Option<Vec<Tool>>,
183 session_id: Option<Uuid>,
184 metadata: Option<serde_json::Value>,
185 ) -> Result<ChatCompletionResponse, String> {
186 let mut ctx = HookContext::new(
187 session_id,
188 AgentState::new(model, messages, tools, metadata),
189 );
190
191 self.hook_registry
193 .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
194 .await
195 .map_err(|e| e.to_string())?
196 .ok()?;
197
198 let current_session = self.initialize_session(&ctx).await?;
200 ctx.set_session_id(current_session.session_id);
201
202 let new_message = self.run_agent_completion(&mut ctx, None).await?;
204 ctx.state.append_new_message(new_message.clone());
205
206 let result = self
208 .save_checkpoint(
209 ¤t_session,
210 ctx.state.messages.clone(),
211 ctx.state.metadata.clone(),
212 )
213 .await?;
214 let checkpoint_created_at = result.checkpoint_created_at.timestamp() as u64;
215 ctx.set_new_checkpoint_id(result.checkpoint_id);
216
217 self.hook_registry
219 .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
220 .await
221 .map_err(|e| e.to_string())?
222 .ok()?;
223
224 let mut meta = serde_json::Map::new();
225 if let Some(session_id) = ctx.session_id {
226 meta.insert(
227 "session_id".to_string(),
228 serde_json::Value::String(session_id.to_string()),
229 );
230 }
231 if let Some(checkpoint_id) = ctx.new_checkpoint_id {
232 meta.insert(
233 "checkpoint_id".to_string(),
234 serde_json::Value::String(checkpoint_id.to_string()),
235 );
236 }
237 if let Some(state_metadata) = &ctx.state.metadata {
238 meta.insert("state_metadata".to_string(), state_metadata.clone());
239 }
240
241 Ok(ChatCompletionResponse {
242 id: ctx.new_checkpoint_id.unwrap().to_string(),
243 object: "chat.completion".to_string(),
244 created: checkpoint_created_at,
245 model: ctx
246 .state
247 .llm_input
248 .as_ref()
249 .map(|llm_input| llm_input.model.id.clone())
250 .unwrap_or_default(),
251 choices: vec![ChatCompletionChoice {
252 index: 0,
253 message: ctx.state.messages.last().cloned().unwrap(),
254 logprobs: None,
255 finish_reason: FinishReason::Stop,
256 }],
257 usage: ctx
258 .state
259 .llm_output
260 .as_ref()
261 .map(|u| u.usage.clone())
262 .unwrap_or_default(),
263 system_fingerprint: None,
264 metadata: if meta.is_empty() {
265 None
266 } else {
267 Some(serde_json::Value::Object(meta))
268 },
269 })
270 }
271
272 async fn chat_completion_stream(
273 &self,
274 model: Model,
275 messages: Vec<ChatMessage>,
276 tools: Option<Vec<Tool>>,
277 _headers: Option<HeaderMap>,
278 session_id: Option<Uuid>,
279 metadata: Option<serde_json::Value>,
280 ) -> Result<
281 (
282 Pin<
283 Box<dyn Stream<Item = Result<ChatCompletionStreamResponse, ApiStreamError>> + Send>,
284 >,
285 Option<String>,
286 ),
287 String,
288 > {
289 let mut ctx = HookContext::new(
290 session_id,
291 AgentState::new(model, messages, tools, metadata),
292 );
293
294 self.hook_registry
296 .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
297 .await
298 .map_err(|e| e.to_string())?
299 .ok()?;
300
301 let current_session = self.initialize_session(&ctx).await?;
303 ctx.set_session_id(current_session.session_id);
304
305 let (tx, mut rx) = mpsc::channel::<Result<StreamMessage, String>>(100);
306
307 let client = self.clone();
309 let mut ctx_clone = ctx.clone();
310
311 tokio::spawn(async move {
315 if tx.is_closed() {
317 return;
318 }
319
320 let result = client
321 .run_agent_completion(&mut ctx_clone, Some(tx.clone()))
322 .await;
323
324 match result {
325 Err(e) => {
326 let _ = tx.send(Err(e)).await;
327 }
328 Ok(new_message) => {
329 if tx.is_closed() {
331 return;
332 }
333
334 ctx_clone.state.append_new_message(new_message.clone());
335 if tx
336 .send(Ok(StreamMessage::Ctx(Box::new(ctx_clone.clone()))))
337 .await
338 .is_err()
339 {
340 return;
342 }
343
344 if tx.is_closed() {
346 return;
347 }
348
349 let result = client
350 .save_checkpoint(
351 ¤t_session,
352 ctx_clone.state.messages.clone(),
353 ctx_clone.state.metadata.clone(),
354 )
355 .await;
356
357 match result {
358 Err(e) => {
359 let _ = tx.send(Err(e)).await;
360 }
361 Ok(updated) => {
362 ctx_clone.set_new_checkpoint_id(updated.checkpoint_id);
363 let _ = tx.send(Ok(StreamMessage::Ctx(Box::new(ctx_clone)))).await;
364 }
365 }
366 }
367 }
368 });
369
370 let hook_registry = self.hook_registry.clone();
371 let stream = async_stream::stream! {
372 while let Some(delta_result) = rx.recv().await {
373 match delta_result {
374 Ok(delta) => match delta {
375 StreamMessage::Ctx(updated_ctx) => {
376 ctx = *updated_ctx;
377 if let Some(session_id) = ctx.session_id {
379 let mut meta = serde_json::Map::new();
380 meta.insert("session_id".to_string(), serde_json::Value::String(session_id.to_string()));
381 if let Some(checkpoint_id) = ctx.new_checkpoint_id {
382 meta.insert("checkpoint_id".to_string(), serde_json::Value::String(checkpoint_id.to_string()));
383 }
384 if let Some(state_metadata) = &ctx.state.metadata {
385 meta.insert("state_metadata".to_string(), state_metadata.clone());
386 }
387 yield Ok(ChatCompletionStreamResponse {
388 id: ctx.request_id.to_string(),
389 object: "chat.completion.chunk".to_string(),
390 created: chrono::Utc::now().timestamp() as u64,
391 model: String::new(),
392 choices: vec![],
393 usage: None,
394 metadata: Some(serde_json::Value::Object(meta)),
395 });
396 }
397 }
398 StreamMessage::Delta(delta) => {
399 let usage = if let GenerationDelta::Usage { usage } = &delta {
401 Some(usage.clone())
402 } else {
403 None
404 };
405
406 yield Ok(ChatCompletionStreamResponse {
407 id: ctx.request_id.to_string(),
408 object: "chat.completion.chunk".to_string(),
409 created: chrono::Utc::now().timestamp() as u64,
410 model: ctx.state.llm_input.as_ref().map(|llm_input| llm_input.model.clone().to_string()).unwrap_or_default(),
411 choices: vec![ChatCompletionStreamChoice {
412 index: 0,
413 delta: delta.into(),
414 finish_reason: None,
415 }],
416 usage,
417 metadata: None,
418 })
419 }
420 }
421 Err(e) => yield Err(ApiStreamError::Unknown(e)),
422 }
423 }
424
425 hook_registry
427 .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
428 .await
429 .map_err(|e| e.to_string())?
430 .ok()?;
431 };
432
433 Ok((Box::pin(stream), None))
434 }
435
436 async fn cancel_stream(&self, request_id: String) -> Result<(), String> {
437 if let Some(api) = &self.stakpak_api {
438 api.cancel_request(&request_id).await
439 } else {
440 Ok(())
442 }
443 }
444
445 async fn search_docs(&self, input: &SearchDocsRequest) -> Result<Vec<Content>, String> {
450 if let Some(api) = &self.stakpak_api {
451 api.search_docs(&crate::stakpak::SearchDocsRequest {
452 keywords: input.keywords.clone(),
453 exclude_keywords: input.exclude_keywords.clone(),
454 limit: input.limit,
455 })
456 .await
457 } else {
458 use stakpak_shared::models::integrations::search_service::*;
460
461 let config = SearchServicesOrchestrator::start()
462 .await
463 .map_err(|e| e.to_string())?;
464
465 let api_url = format!("http://localhost:{}", config.api_port);
466 let search_client = SearchClient::new(api_url);
467
468 let search_results = search_client
469 .search_and_scrape(input.keywords.clone(), None)
470 .await
471 .map_err(|e| e.to_string())?;
472
473 if search_results.is_empty() {
474 return Ok(vec![Content::text("No results found".to_string())]);
475 }
476
477 Ok(search_results
478 .into_iter()
479 .map(|result| {
480 let content = result.content.unwrap_or_default();
481 Content::text(format!("URL: {}\nContent: {}", result.url, content))
482 })
483 .collect())
484 }
485 }
486
487 async fn memorize_session(&self, checkpoint_id: Uuid) -> Result<(), String> {
492 if let Some(api) = &self.stakpak_api {
493 api.memorize_session(checkpoint_id).await
494 } else {
495 Ok(())
497 }
498 }
499
500 async fn search_memory(&self, input: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
501 if let Some(api) = &self.stakpak_api {
502 api.search_memory(&crate::stakpak::SearchMemoryRequest {
503 keywords: input.keywords.clone(),
504 start_time: input.start_time,
505 end_time: input.end_time,
506 })
507 .await
508 } else {
509 Ok(vec![])
511 }
512 }
513
514 async fn slack_read_messages(
519 &self,
520 input: &SlackReadMessagesRequest,
521 ) -> Result<Vec<Content>, String> {
522 if let Some(api) = &self.stakpak_api {
523 api.slack_read_messages(&crate::stakpak::SlackReadMessagesRequest {
524 channel: input.channel.clone(),
525 limit: input.limit,
526 })
527 .await
528 } else {
529 Err("Slack integration requires Stakpak API key".to_string())
530 }
531 }
532
533 async fn slack_read_replies(
534 &self,
535 input: &SlackReadRepliesRequest,
536 ) -> Result<Vec<Content>, String> {
537 if let Some(api) = &self.stakpak_api {
538 api.slack_read_replies(&crate::stakpak::SlackReadRepliesRequest {
539 channel: input.channel.clone(),
540 ts: input.ts.clone(),
541 })
542 .await
543 } else {
544 Err("Slack integration requires Stakpak API key".to_string())
545 }
546 }
547
548 async fn slack_send_message(
549 &self,
550 input: &SlackSendMessageRequest,
551 ) -> Result<Vec<Content>, String> {
552 if let Some(api) = &self.stakpak_api {
553 api.slack_send_message(&crate::stakpak::SlackSendMessageRequest {
554 channel: input.channel.clone(),
555 markdown_text: input.markdown_text.clone(),
556 thread_ts: input.thread_ts.clone(),
557 })
558 .await
559 } else {
560 Err("Slack integration requires Stakpak API key".to_string())
561 }
562 }
563
564 async fn list_models(&self) -> Vec<stakai::Model> {
569 let registry = self.stakai.registry();
573 let mut all_models = Vec::new();
574
575 for provider_id in registry.list_providers() {
576 if let Ok(mut models) = registry.models_for_provider(&provider_id).await {
577 all_models.append(&mut models);
578 }
579 }
580
581 sort_models_by_recency(&mut all_models);
582 all_models
583 }
584}
585
586fn sort_models_by_recency(models: &mut [stakai::Model]) {
588 models.sort_by(|a, b| {
589 match (&b.release_date, &a.release_date) {
590 (Some(b_date), Some(a_date)) => b_date.cmp(a_date),
591 (Some(_), None) => std::cmp::Ordering::Less,
592 (None, Some(_)) => std::cmp::Ordering::Greater,
593 (None, None) => b.id.cmp(&a.id), }
595 });
596}
597
598#[async_trait]
603impl crate::storage::SessionStorage for super::AgentClient {
604 fn backend_info(&self) -> crate::storage::BackendInfo {
605 self.session_storage.backend_info()
606 }
607
608 async fn list_sessions(
609 &self,
610 query: &crate::storage::ListSessionsQuery,
611 ) -> Result<crate::storage::ListSessionsResult, crate::storage::StorageError> {
612 self.session_storage.list_sessions(query).await
613 }
614
615 async fn get_session(
616 &self,
617 session_id: Uuid,
618 ) -> Result<crate::storage::Session, crate::storage::StorageError> {
619 self.session_storage.get_session(session_id).await
620 }
621
622 async fn create_session(
623 &self,
624 request: &crate::storage::CreateSessionRequest,
625 ) -> Result<crate::storage::CreateSessionResult, crate::storage::StorageError> {
626 self.session_storage.create_session(request).await
627 }
628
629 async fn update_session(
630 &self,
631 session_id: Uuid,
632 request: &crate::storage::UpdateSessionRequest,
633 ) -> Result<crate::storage::Session, crate::storage::StorageError> {
634 self.session_storage
635 .update_session(session_id, request)
636 .await
637 }
638
639 async fn delete_session(&self, session_id: Uuid) -> Result<(), crate::storage::StorageError> {
640 self.session_storage.delete_session(session_id).await
641 }
642
643 async fn list_checkpoints(
644 &self,
645 session_id: Uuid,
646 query: &crate::storage::ListCheckpointsQuery,
647 ) -> Result<crate::storage::ListCheckpointsResult, crate::storage::StorageError> {
648 self.session_storage
649 .list_checkpoints(session_id, query)
650 .await
651 }
652
653 async fn get_checkpoint(
654 &self,
655 checkpoint_id: Uuid,
656 ) -> Result<crate::storage::Checkpoint, crate::storage::StorageError> {
657 self.session_storage.get_checkpoint(checkpoint_id).await
658 }
659
660 async fn create_checkpoint(
661 &self,
662 session_id: Uuid,
663 request: &crate::storage::CreateCheckpointRequest,
664 ) -> Result<crate::storage::Checkpoint, crate::storage::StorageError> {
665 self.session_storage
666 .create_checkpoint(session_id, request)
667 .await
668 }
669
670 async fn get_active_checkpoint(
671 &self,
672 session_id: Uuid,
673 ) -> Result<crate::storage::Checkpoint, crate::storage::StorageError> {
674 self.session_storage.get_active_checkpoint(session_id).await
675 }
676
677 async fn get_session_stats(
678 &self,
679 session_id: Uuid,
680 ) -> Result<crate::storage::SessionStats, crate::storage::StorageError> {
681 self.session_storage.get_session_stats(session_id).await
682 }
683}
684
/// System prompt used by `generate_session_title`, embedded at compile time
/// from `../prompts/session_title_generator.v1.txt`.
const TITLE_GENERATOR_PROMPT: &str = include_str!("../prompts/session_title_generator.v1.txt");
690
691impl AgentClient {
692 pub(crate) async fn initialize_session(
697 &self,
698 ctx: &HookContext<AgentState>,
699 ) -> Result<SessionInfo, String> {
700 let messages = &ctx.state.messages;
701
702 if messages.is_empty() {
703 return Err("At least one message is required".to_string());
704 }
705
706 if let Some(session_id) = ctx.session_id {
708 let session = self
709 .session_storage
710 .get_session(session_id)
711 .await
712 .map_err(|e| e.to_string())?;
713
714 let checkpoint = session
715 .active_checkpoint
716 .ok_or_else(|| format!("Session {} has no active checkpoint", session_id))?;
717
718 if session.title.trim().is_empty() || session.title == "New Session" {
720 let client = self.clone();
721 let messages_for_title = messages.to_vec();
722 let session_id = session.id;
723 let existing_title = session.title.clone();
724 tokio::spawn(async move {
725 if let Ok(title) = client.generate_session_title(&messages_for_title).await {
726 let trimmed = title.trim();
727 if !trimmed.is_empty() && trimmed != existing_title {
728 let request =
729 StorageUpdateSessionRequest::new().with_title(trimmed.to_string());
730 let _ = client
731 .session_storage
732 .update_session(session_id, &request)
733 .await;
734 }
735 }
736 });
737 }
738
739 return Ok(SessionInfo {
740 session_id: session.id,
741 checkpoint_id: checkpoint.id,
742 checkpoint_created_at: checkpoint.created_at,
743 });
744 }
745
746 let fallback_title = Self::fallback_session_title(messages);
748
749 let cwd = std::env::current_dir()
751 .ok()
752 .map(|p| p.to_string_lossy().to_string());
753
754 let mut session_request =
756 StorageCreateSessionRequest::new(fallback_title.clone(), messages.to_vec());
757 if let Some(cwd) = cwd {
758 session_request = session_request.with_cwd(cwd);
759 }
760
761 let result = self
762 .session_storage
763 .create_session(&session_request)
764 .await
765 .map_err(|e| e.to_string())?;
766
767 let client = self.clone();
769 let messages_for_title = messages.to_vec();
770 let session_id = result.session_id;
771 tokio::spawn(async move {
772 if let Ok(title) = client.generate_session_title(&messages_for_title).await {
773 let trimmed = title.trim();
774 if !trimmed.is_empty() && trimmed != fallback_title {
775 let request =
776 StorageUpdateSessionRequest::new().with_title(trimmed.to_string());
777 let _ = client
778 .session_storage
779 .update_session(session_id, &request)
780 .await;
781 }
782 }
783 });
784
785 Ok(SessionInfo {
786 session_id: result.session_id,
787 checkpoint_id: result.checkpoint.id,
788 checkpoint_created_at: result.checkpoint.created_at,
789 })
790 }
791
792 fn fallback_session_title(messages: &[ChatMessage]) -> String {
793 messages
794 .iter()
795 .find(|m| m.role == Role::User)
796 .and_then(|m| m.content.as_ref())
797 .map(|c| {
798 let text = c.to_string();
799 text.split_whitespace()
800 .take(5)
801 .collect::<Vec<_>>()
802 .join(" ")
803 })
804 .unwrap_or_else(|| "New Session".to_string())
805 }
806
807 pub(crate) async fn save_checkpoint(
809 &self,
810 current: &SessionInfo,
811 messages: Vec<ChatMessage>,
812 metadata: Option<serde_json::Value>,
813 ) -> Result<SessionInfo, String> {
814 let mut checkpoint_request =
815 StorageCreateCheckpointRequest::new(messages).with_parent(current.checkpoint_id);
816
817 if let Some(meta) = metadata {
818 checkpoint_request = checkpoint_request.with_metadata(meta);
819 }
820
821 let checkpoint = self
822 .session_storage
823 .create_checkpoint(current.session_id, &checkpoint_request)
824 .await
825 .map_err(|e| e.to_string())?;
826
827 Ok(SessionInfo {
828 session_id: current.session_id,
829 checkpoint_id: checkpoint.id,
830 checkpoint_created_at: checkpoint.created_at,
831 })
832 }
833
834 pub(crate) async fn run_agent_completion(
836 &self,
837 ctx: &mut HookContext<AgentState>,
838 stream_channel_tx: Option<mpsc::Sender<Result<StreamMessage, String>>>,
839 ) -> Result<ChatMessage, String> {
840 self.hook_registry
842 .execute_hooks(ctx, &LifecycleEvent::BeforeInference)
843 .await
844 .map_err(|e| e.to_string())?
845 .ok()?;
846
847 let mut input = if let Some(llm_input) = ctx.state.llm_input.clone() {
848 llm_input
849 } else {
850 return Err(
851 "LLM input not found, make sure to register a context hook before inference"
852 .to_string(),
853 );
854 };
855
856 if let Some(session_id) = ctx.session_id {
858 let headers = input
859 .headers
860 .get_or_insert_with(std::collections::HashMap::new);
861 headers.insert("X-Session-Id".to_string(), session_id.to_string());
862 }
863
864 let (response_message, usage) = if let Some(tx) = stream_channel_tx {
865 let (internal_tx, mut internal_rx) = mpsc::channel::<GenerationDelta>(100);
867 let stream_input = LLMStreamInput {
868 model: input.model,
869 messages: input.messages,
870 max_tokens: input.max_tokens,
871 tools: input.tools,
872 stream_channel_tx: internal_tx,
873 provider_options: input.provider_options,
874 headers: input.headers,
875 };
876
877 let stakai = self.stakai.clone();
878 let chat_future = async move {
879 stakai
880 .chat_stream(stream_input)
881 .await
882 .map_err(|e| e.to_string())
883 };
884
885 let receive_future = async move {
886 while let Some(delta) = internal_rx.recv().await {
887 if tx.send(Ok(StreamMessage::Delta(delta))).await.is_err() {
888 break;
889 }
890 }
891 };
892
893 let (chat_result, _) = tokio::join!(chat_future, receive_future);
894 let response = chat_result?;
895 (response.choices[0].message.clone(), response.usage)
896 } else {
897 let response = self.stakai.chat(input).await.map_err(|e| e.to_string())?;
899 (response.choices[0].message.clone(), response.usage)
900 };
901
902 ctx.state.set_llm_output(response_message, usage);
903
904 self.hook_registry
906 .execute_hooks(ctx, &LifecycleEvent::AfterInference)
907 .await
908 .map_err(|e| e.to_string())?
909 .ok()?;
910
911 let llm_output = ctx
912 .state
913 .llm_output
914 .as_ref()
915 .ok_or_else(|| "LLM output is missing from state".to_string())?;
916
917 Ok(ChatMessage::from(llm_output))
918 }
919
920 async fn generate_session_title(&self, messages: &[ChatMessage]) -> Result<String, String> {
922 let use_stakpak = self.stakpak.is_some();
924 let providers = self.stakai.registry().list_providers();
925 let cheap_models: &[(&str, &str)] = &[
926 ("stakpak", "claude-haiku-4-5"),
927 ("anthropic", "claude-haiku-4-5"),
928 ("amazon-bedrock", "claude-haiku-4-5"),
929 ("openai", "gpt-4.1-mini"),
930 ("google", "gemini-2.5-flash"),
931 ];
932 let model = cheap_models
933 .iter()
934 .find_map(|(provider, model_id)| {
935 if providers.contains(&provider.to_string()) {
936 crate::find_model(model_id, use_stakpak)
937 } else {
938 None
939 }
940 })
941 .ok_or_else(|| "No model available for title generation".to_string())?;
942
943 let llm_messages = vec![
944 LLMMessage {
945 role: Role::System.to_string(),
946 content: LLMMessageContent::String(TITLE_GENERATOR_PROMPT.to_string()),
947 },
948 LLMMessage {
949 role: Role::User.to_string(),
950 content: LLMMessageContent::String(
951 messages
952 .iter()
953 .map(|msg| {
954 msg.content
955 .as_ref()
956 .unwrap_or(&MessageContent::String("".to_string()))
957 .to_string()
958 })
959 .collect(),
960 ),
961 },
962 ];
963
964 let input = LLMInput {
965 model,
966 messages: llm_messages,
967 max_tokens: 100,
968 tools: None,
969 provider_options: None,
970 headers: None,
971 };
972
973 let response = self.stakai.chat(input).await.map_err(|e| e.to_string())?;
974
975 Ok(response.choices[0].message.content.to_string())
976 }
977}