1use crate::AgentProvider;
9use crate::models::*;
10use crate::storage::{
11 CreateCheckpointRequest as StorageCreateCheckpointRequest,
12 CreateSessionRequest as StorageCreateSessionRequest,
13 UpdateSessionRequest as StorageUpdateSessionRequest,
14};
15use async_trait::async_trait;
16use futures_util::Stream;
17use reqwest::header::HeaderMap;
18use rmcp::model::Content;
19use stakai::Model;
20use stakpak_shared::hooks::{HookContext, LifecycleEvent};
21use stakpak_shared::models::integrations::openai::{
22 ChatCompletionChoice, ChatCompletionResponse, ChatCompletionStreamChoice,
23 ChatCompletionStreamResponse, ChatMessage, FinishReason, MessageContent, Role, Tool,
24};
25use stakpak_shared::models::llm::{
26 GenerationDelta, LLMInput, LLMMessage, LLMMessageContent, LLMStreamInput,
27};
28use std::pin::Pin;
29use tokio::sync::mpsc;
30use uuid::Uuid;
31
/// Identifies a session together with its currently-active checkpoint.
/// Returned by `initialize_session` / `save_checkpoint` and used to chain
/// new checkpoints onto the right parent.
#[derive(Debug, Clone)]
pub(crate) struct SessionInfo {
    // Owning session.
    session_id: Uuid,
    // Active checkpoint in that session; used as the parent of the next
    // checkpoint created by `save_checkpoint`.
    checkpoint_id: Uuid,
    // Creation time of the checkpoint; surfaced as `created` in completion responses.
    checkpoint_created_at: chrono::DateTime<chrono::Utc>,
}
39
40use super::AgentClient;
41
/// Internal message flowing from the completion producer task to the SSE
/// consumer stream: either a raw generation delta or an updated hook context
/// (boxed, presumably to keep the enum's variants close in size — confirm).
#[derive(Debug)]
pub(crate) enum StreamMessage {
    Delta(GenerationDelta),
    Ctx(Box<HookContext<AgentState>>),
}
51
52#[async_trait]
57impl AgentProvider for AgentClient {
58 async fn get_my_account(&self) -> Result<GetMyAccountResponse, String> {
63 if let Some(api) = &self.stakpak_api {
64 api.get_account().await
65 } else {
66 Ok(GetMyAccountResponse {
68 username: "local".to_string(),
69 id: "local".to_string(),
70 first_name: "local".to_string(),
71 last_name: "local".to_string(),
72 email: "local@stakpak.dev".to_string(),
73 scope: None,
74 })
75 }
76 }
77
78 async fn get_billing_info(
79 &self,
80 account_username: &str,
81 ) -> Result<stakpak_shared::models::billing::BillingResponse, String> {
82 if let Some(api) = &self.stakpak_api {
83 api.get_billing(account_username).await
84 } else {
85 Err("Billing info not available without Stakpak API key".to_string())
86 }
87 }
88
89 async fn list_rulebooks(&self) -> Result<Vec<ListRuleBook>, String> {
94 if let Some(api) = &self.stakpak_api {
95 api.list_rulebooks().await
96 } else {
97 let client = stakpak_shared::tls_client::create_tls_client(
99 stakpak_shared::tls_client::TlsClientConfig::default()
100 .with_timeout(std::time::Duration::from_secs(30)),
101 )?;
102
103 let url = format!("{}/v1/rules", self.get_stakpak_api_endpoint());
104 let response = client.get(&url).send().await.map_err(|e| e.to_string())?;
105
106 if response.status().is_success() {
107 let value: serde_json::Value = response.json().await.map_err(|e| e.to_string())?;
108 match serde_json::from_value::<ListRulebooksResponse>(value) {
109 Ok(resp) => Ok(resp.results),
110 Err(_) => Ok(vec![]),
111 }
112 } else {
113 Ok(vec![])
114 }
115 }
116 }
117
118 async fn get_rulebook_by_uri(&self, uri: &str) -> Result<RuleBook, String> {
119 if let Some(api) = &self.stakpak_api {
120 api.get_rulebook_by_uri(uri).await
121 } else {
122 let client = stakpak_shared::tls_client::create_tls_client(
124 stakpak_shared::tls_client::TlsClientConfig::default()
125 .with_timeout(std::time::Duration::from_secs(30)),
126 )?;
127
128 let encoded_uri = urlencoding::encode(uri);
129 let url = format!(
130 "{}/v1/rules/{}",
131 self.get_stakpak_api_endpoint(),
132 encoded_uri
133 );
134 let response = client.get(&url).send().await.map_err(|e| e.to_string())?;
135
136 if response.status().is_success() {
137 response.json().await.map_err(|e| e.to_string())
138 } else {
139 Err("Rulebook not found".to_string())
140 }
141 }
142 }
143
144 async fn create_rulebook(
145 &self,
146 uri: &str,
147 description: &str,
148 content: &str,
149 tags: Vec<String>,
150 visibility: Option<RuleBookVisibility>,
151 ) -> Result<CreateRuleBookResponse, String> {
152 if let Some(api) = &self.stakpak_api {
153 api.create_rulebook(&CreateRuleBookInput {
154 uri: uri.to_string(),
155 description: description.to_string(),
156 content: content.to_string(),
157 tags,
158 visibility,
159 })
160 .await
161 } else {
162 Err("Creating rulebooks requires Stakpak API key".to_string())
163 }
164 }
165
166 async fn delete_rulebook(&self, uri: &str) -> Result<(), String> {
167 if let Some(api) = &self.stakpak_api {
168 api.delete_rulebook(uri).await
169 } else {
170 Err("Deleting rulebooks requires Stakpak API key".to_string())
171 }
172 }
173
174 async fn chat_completion(
179 &self,
180 model: Model,
181 messages: Vec<ChatMessage>,
182 tools: Option<Vec<Tool>>,
183 session_id: Option<Uuid>,
184 ) -> Result<ChatCompletionResponse, String> {
185 let mut ctx = HookContext::new(session_id, AgentState::new(model, messages, tools));
186
187 self.hook_registry
189 .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
190 .await
191 .map_err(|e| e.to_string())?
192 .ok()?;
193
194 let current_session = self.initialize_session(&ctx).await?;
196 ctx.set_session_id(current_session.session_id);
197
198 let new_message = self.run_agent_completion(&mut ctx, None).await?;
200 ctx.state.append_new_message(new_message.clone());
201
202 let result = self
204 .save_checkpoint(¤t_session, ctx.state.messages.clone())
205 .await?;
206 let checkpoint_created_at = result.checkpoint_created_at.timestamp() as u64;
207 ctx.set_new_checkpoint_id(result.checkpoint_id);
208
209 self.hook_registry
211 .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
212 .await
213 .map_err(|e| e.to_string())?
214 .ok()?;
215
216 let mut meta = serde_json::Map::new();
217 if let Some(session_id) = ctx.session_id {
218 meta.insert(
219 "session_id".to_string(),
220 serde_json::Value::String(session_id.to_string()),
221 );
222 }
223 if let Some(checkpoint_id) = ctx.new_checkpoint_id {
224 meta.insert(
225 "checkpoint_id".to_string(),
226 serde_json::Value::String(checkpoint_id.to_string()),
227 );
228 }
229
230 Ok(ChatCompletionResponse {
231 id: ctx.new_checkpoint_id.unwrap().to_string(),
232 object: "chat.completion".to_string(),
233 created: checkpoint_created_at,
234 model: ctx
235 .state
236 .llm_input
237 .as_ref()
238 .map(|llm_input| llm_input.model.id.clone())
239 .unwrap_or_default(),
240 choices: vec![ChatCompletionChoice {
241 index: 0,
242 message: ctx.state.messages.last().cloned().unwrap(),
243 logprobs: None,
244 finish_reason: FinishReason::Stop,
245 }],
246 usage: ctx
247 .state
248 .llm_output
249 .as_ref()
250 .map(|u| u.usage.clone())
251 .unwrap_or_default(),
252 system_fingerprint: None,
253 metadata: if meta.is_empty() {
254 None
255 } else {
256 Some(serde_json::Value::Object(meta))
257 },
258 })
259 }
260
261 async fn chat_completion_stream(
262 &self,
263 model: Model,
264 messages: Vec<ChatMessage>,
265 tools: Option<Vec<Tool>>,
266 _headers: Option<HeaderMap>,
267 session_id: Option<Uuid>,
268 ) -> Result<
269 (
270 Pin<
271 Box<dyn Stream<Item = Result<ChatCompletionStreamResponse, ApiStreamError>> + Send>,
272 >,
273 Option<String>,
274 ),
275 String,
276 > {
277 let mut ctx = HookContext::new(session_id, AgentState::new(model, messages, tools));
278
279 self.hook_registry
281 .execute_hooks(&mut ctx, &LifecycleEvent::BeforeRequest)
282 .await
283 .map_err(|e| e.to_string())?
284 .ok()?;
285
286 let current_session = self.initialize_session(&ctx).await?;
288 ctx.set_session_id(current_session.session_id);
289
290 let (tx, mut rx) = mpsc::channel::<Result<StreamMessage, String>>(100);
291
292 let client = self.clone();
294 let mut ctx_clone = ctx.clone();
295
296 tokio::spawn(async move {
300 if tx.is_closed() {
302 return;
303 }
304
305 let result = client
306 .run_agent_completion(&mut ctx_clone, Some(tx.clone()))
307 .await;
308
309 match result {
310 Err(e) => {
311 let _ = tx.send(Err(e)).await;
312 }
313 Ok(new_message) => {
314 if tx.is_closed() {
316 return;
317 }
318
319 ctx_clone.state.append_new_message(new_message.clone());
320 if tx
321 .send(Ok(StreamMessage::Ctx(Box::new(ctx_clone.clone()))))
322 .await
323 .is_err()
324 {
325 return;
327 }
328
329 if tx.is_closed() {
331 return;
332 }
333
334 let result = client
335 .save_checkpoint(¤t_session, ctx_clone.state.messages.clone())
336 .await;
337
338 match result {
339 Err(e) => {
340 let _ = tx.send(Err(e)).await;
341 }
342 Ok(updated) => {
343 ctx_clone.set_new_checkpoint_id(updated.checkpoint_id);
344 let _ = tx.send(Ok(StreamMessage::Ctx(Box::new(ctx_clone)))).await;
345 }
346 }
347 }
348 }
349 });
350
351 let hook_registry = self.hook_registry.clone();
352 let stream = async_stream::stream! {
353 while let Some(delta_result) = rx.recv().await {
354 match delta_result {
355 Ok(delta) => match delta {
356 StreamMessage::Ctx(updated_ctx) => {
357 ctx = *updated_ctx;
358 if let Some(session_id) = ctx.session_id {
360 let mut meta = serde_json::Map::new();
361 meta.insert("session_id".to_string(), serde_json::Value::String(session_id.to_string()));
362 if let Some(checkpoint_id) = ctx.new_checkpoint_id {
363 meta.insert("checkpoint_id".to_string(), serde_json::Value::String(checkpoint_id.to_string()));
364 }
365 yield Ok(ChatCompletionStreamResponse {
366 id: ctx.request_id.to_string(),
367 object: "chat.completion.chunk".to_string(),
368 created: chrono::Utc::now().timestamp() as u64,
369 model: String::new(),
370 choices: vec![],
371 usage: None,
372 metadata: Some(serde_json::Value::Object(meta)),
373 });
374 }
375 }
376 StreamMessage::Delta(delta) => {
377 let usage = if let GenerationDelta::Usage { usage } = &delta {
379 Some(usage.clone())
380 } else {
381 None
382 };
383
384 yield Ok(ChatCompletionStreamResponse {
385 id: ctx.request_id.to_string(),
386 object: "chat.completion.chunk".to_string(),
387 created: chrono::Utc::now().timestamp() as u64,
388 model: ctx.state.llm_input.as_ref().map(|llm_input| llm_input.model.clone().to_string()).unwrap_or_default(),
389 choices: vec![ChatCompletionStreamChoice {
390 index: 0,
391 delta: delta.into(),
392 finish_reason: None,
393 }],
394 usage,
395 metadata: None,
396 })
397 }
398 }
399 Err(e) => yield Err(ApiStreamError::Unknown(e)),
400 }
401 }
402
403 hook_registry
405 .execute_hooks(&mut ctx, &LifecycleEvent::AfterRequest)
406 .await
407 .map_err(|e| e.to_string())?
408 .ok()?;
409 };
410
411 Ok((Box::pin(stream), None))
412 }
413
414 async fn cancel_stream(&self, request_id: String) -> Result<(), String> {
415 if let Some(api) = &self.stakpak_api {
416 api.cancel_request(&request_id).await
417 } else {
418 Ok(())
420 }
421 }
422
423 async fn search_docs(&self, input: &SearchDocsRequest) -> Result<Vec<Content>, String> {
428 if let Some(api) = &self.stakpak_api {
429 api.search_docs(&crate::stakpak::SearchDocsRequest {
430 keywords: input.keywords.clone(),
431 exclude_keywords: input.exclude_keywords.clone(),
432 limit: input.limit,
433 })
434 .await
435 } else {
436 use stakpak_shared::models::integrations::search_service::*;
438
439 let config = SearchServicesOrchestrator::start()
440 .await
441 .map_err(|e| e.to_string())?;
442
443 let api_url = format!("http://localhost:{}", config.api_port);
444 let search_client = SearchClient::new(api_url);
445
446 let search_results = search_client
447 .search_and_scrape(input.keywords.clone(), None)
448 .await
449 .map_err(|e| e.to_string())?;
450
451 if search_results.is_empty() {
452 return Ok(vec![Content::text("No results found".to_string())]);
453 }
454
455 Ok(search_results
456 .into_iter()
457 .map(|result| {
458 let content = result.content.unwrap_or_default();
459 Content::text(format!("URL: {}\nContent: {}", result.url, content))
460 })
461 .collect())
462 }
463 }
464
465 async fn memorize_session(&self, checkpoint_id: Uuid) -> Result<(), String> {
470 if let Some(api) = &self.stakpak_api {
471 api.memorize_session(checkpoint_id).await
472 } else {
473 Ok(())
475 }
476 }
477
478 async fn search_memory(&self, input: &SearchMemoryRequest) -> Result<Vec<Content>, String> {
479 if let Some(api) = &self.stakpak_api {
480 api.search_memory(&crate::stakpak::SearchMemoryRequest {
481 keywords: input.keywords.clone(),
482 start_time: input.start_time,
483 end_time: input.end_time,
484 })
485 .await
486 } else {
487 Ok(vec![])
489 }
490 }
491
492 async fn slack_read_messages(
497 &self,
498 input: &SlackReadMessagesRequest,
499 ) -> Result<Vec<Content>, String> {
500 if let Some(api) = &self.stakpak_api {
501 api.slack_read_messages(&crate::stakpak::SlackReadMessagesRequest {
502 channel: input.channel.clone(),
503 limit: input.limit,
504 })
505 .await
506 } else {
507 Err("Slack integration requires Stakpak API key".to_string())
508 }
509 }
510
511 async fn slack_read_replies(
512 &self,
513 input: &SlackReadRepliesRequest,
514 ) -> Result<Vec<Content>, String> {
515 if let Some(api) = &self.stakpak_api {
516 api.slack_read_replies(&crate::stakpak::SlackReadRepliesRequest {
517 channel: input.channel.clone(),
518 ts: input.ts.clone(),
519 })
520 .await
521 } else {
522 Err("Slack integration requires Stakpak API key".to_string())
523 }
524 }
525
526 async fn slack_send_message(
527 &self,
528 input: &SlackSendMessageRequest,
529 ) -> Result<Vec<Content>, String> {
530 if let Some(api) = &self.stakpak_api {
531 api.slack_send_message(&crate::stakpak::SlackSendMessageRequest {
532 channel: input.channel.clone(),
533 markdown_text: input.markdown_text.clone(),
534 thread_ts: input.thread_ts.clone(),
535 })
536 .await
537 } else {
538 Err("Slack integration requires Stakpak API key".to_string())
539 }
540 }
541
542 async fn list_models(&self) -> Vec<stakai::Model> {
547 const PROVIDERS: &[&str] = &["anthropic", "openai", "google"];
548
549 let use_stakpak = self.has_stakpak();
550 let mut all_models = Vec::new();
551
552 for &provider_id in PROVIDERS {
553 let mut models = load_and_transform_models(provider_id, use_stakpak);
554 sort_models_by_recency(&mut models);
555 all_models.extend(models);
556 }
557
558 all_models
559 }
560}
561
562fn load_and_transform_models(provider_id: &str, use_stakpak: bool) -> Vec<stakai::Model> {
564 let models = stakai::load_models_for_provider(provider_id).unwrap_or_default();
565
566 if use_stakpak {
567 models
568 .into_iter()
569 .map(|m| stakai::Model {
570 id: format!("{}/{}", provider_id, m.id),
571 provider: "stakpak".into(),
572 name: m.name,
573 reasoning: m.reasoning,
574 cost: m.cost,
575 limit: m.limit,
576 release_date: m.release_date,
577 })
578 .collect()
579 } else {
580 models
581 }
582}
583
584fn sort_models_by_recency(models: &mut [stakai::Model]) {
586 models.sort_by(|a, b| {
587 match (&b.release_date, &a.release_date) {
588 (Some(b_date), Some(a_date)) => b_date.cmp(a_date),
589 (Some(_), None) => std::cmp::Ordering::Less,
590 (None, Some(_)) => std::cmp::Ordering::Greater,
591 (None, None) => b.id.cmp(&a.id), }
593 });
594}
595
/// Storage facade: `AgentClient` implements `SessionStorage` by forwarding
/// every operation verbatim to its inner `session_storage` backend. No
/// caching, validation, or translation happens at this layer.
#[async_trait]
impl crate::storage::SessionStorage for super::AgentClient {
    async fn list_sessions(
        &self,
        query: &crate::storage::ListSessionsQuery,
    ) -> Result<crate::storage::ListSessionsResult, crate::storage::StorageError> {
        self.session_storage.list_sessions(query).await
    }

    async fn get_session(
        &self,
        session_id: Uuid,
    ) -> Result<crate::storage::Session, crate::storage::StorageError> {
        self.session_storage.get_session(session_id).await
    }

    async fn create_session(
        &self,
        request: &crate::storage::CreateSessionRequest,
    ) -> Result<crate::storage::CreateSessionResult, crate::storage::StorageError> {
        self.session_storage.create_session(request).await
    }

    async fn update_session(
        &self,
        session_id: Uuid,
        request: &crate::storage::UpdateSessionRequest,
    ) -> Result<crate::storage::Session, crate::storage::StorageError> {
        self.session_storage
            .update_session(session_id, request)
            .await
    }

    async fn delete_session(&self, session_id: Uuid) -> Result<(), crate::storage::StorageError> {
        self.session_storage.delete_session(session_id).await
    }

    async fn list_checkpoints(
        &self,
        session_id: Uuid,
        query: &crate::storage::ListCheckpointsQuery,
    ) -> Result<crate::storage::ListCheckpointsResult, crate::storage::StorageError> {
        self.session_storage
            .list_checkpoints(session_id, query)
            .await
    }

    async fn get_checkpoint(
        &self,
        checkpoint_id: Uuid,
    ) -> Result<crate::storage::Checkpoint, crate::storage::StorageError> {
        self.session_storage.get_checkpoint(checkpoint_id).await
    }

    async fn create_checkpoint(
        &self,
        session_id: Uuid,
        request: &crate::storage::CreateCheckpointRequest,
    ) -> Result<crate::storage::Checkpoint, crate::storage::StorageError> {
        self.session_storage
            .create_checkpoint(session_id, request)
            .await
    }

    async fn get_active_checkpoint(
        &self,
        session_id: Uuid,
    ) -> Result<crate::storage::Checkpoint, crate::storage::StorageError> {
        self.session_storage.get_active_checkpoint(session_id).await
    }

    async fn get_session_stats(
        &self,
        session_id: Uuid,
    ) -> Result<crate::storage::SessionStats, crate::storage::StorageError> {
        self.session_storage.get_session_stats(session_id).await
    }
}
678
/// System prompt for `generate_session_title`, kept as a versioned text asset
/// so it can be edited without recompiling prompt content into this file.
const TITLE_GENERATOR_PROMPT: &str =
    include_str!("../local/prompts/session_title_generator.v1.txt");
685
686impl AgentClient {
687 pub(crate) async fn initialize_session(
692 &self,
693 ctx: &HookContext<AgentState>,
694 ) -> Result<SessionInfo, String> {
695 let messages = &ctx.state.messages;
696
697 if messages.is_empty() {
698 return Err("At least one message is required".to_string());
699 }
700
701 if let Some(session_id) = ctx.session_id {
703 let session = self
704 .session_storage
705 .get_session(session_id)
706 .await
707 .map_err(|e| e.to_string())?;
708
709 let checkpoint = session
710 .active_checkpoint
711 .ok_or_else(|| format!("Session {} has no active checkpoint", session_id))?;
712
713 if session.title.trim().is_empty() || session.title == "New Session" {
715 let client = self.clone();
716 let messages_for_title = messages.to_vec();
717 let session_id = session.id;
718 let existing_title = session.title.clone();
719 tokio::spawn(async move {
720 if let Ok(title) = client.generate_session_title(&messages_for_title).await {
721 let trimmed = title.trim();
722 if !trimmed.is_empty() && trimmed != existing_title {
723 let request =
724 StorageUpdateSessionRequest::new().with_title(trimmed.to_string());
725 let _ = client
726 .session_storage
727 .update_session(session_id, &request)
728 .await;
729 }
730 }
731 });
732 }
733
734 return Ok(SessionInfo {
735 session_id: session.id,
736 checkpoint_id: checkpoint.id,
737 checkpoint_created_at: checkpoint.created_at,
738 });
739 }
740
741 let fallback_title = Self::fallback_session_title(messages);
743
744 let cwd = std::env::current_dir()
746 .ok()
747 .map(|p| p.to_string_lossy().to_string());
748
749 let mut session_request =
751 StorageCreateSessionRequest::new(fallback_title.clone(), messages.to_vec());
752 if let Some(cwd) = cwd {
753 session_request = session_request.with_cwd(cwd);
754 }
755
756 let result = self
757 .session_storage
758 .create_session(&session_request)
759 .await
760 .map_err(|e| e.to_string())?;
761
762 let client = self.clone();
764 let messages_for_title = messages.to_vec();
765 let session_id = result.session_id;
766 tokio::spawn(async move {
767 if let Ok(title) = client.generate_session_title(&messages_for_title).await {
768 let trimmed = title.trim();
769 if !trimmed.is_empty() && trimmed != fallback_title {
770 let request =
771 StorageUpdateSessionRequest::new().with_title(trimmed.to_string());
772 let _ = client
773 .session_storage
774 .update_session(session_id, &request)
775 .await;
776 }
777 }
778 });
779
780 Ok(SessionInfo {
781 session_id: result.session_id,
782 checkpoint_id: result.checkpoint.id,
783 checkpoint_created_at: result.checkpoint.created_at,
784 })
785 }
786
787 fn fallback_session_title(messages: &[ChatMessage]) -> String {
788 messages
789 .iter()
790 .find(|m| m.role == Role::User)
791 .and_then(|m| m.content.as_ref())
792 .map(|c| {
793 let text = c.to_string();
794 text.split_whitespace()
795 .take(5)
796 .collect::<Vec<_>>()
797 .join(" ")
798 })
799 .unwrap_or_else(|| "New Session".to_string())
800 }
801
802 pub(crate) async fn save_checkpoint(
804 &self,
805 current: &SessionInfo,
806 messages: Vec<ChatMessage>,
807 ) -> Result<SessionInfo, String> {
808 let checkpoint_request =
809 StorageCreateCheckpointRequest::new(messages).with_parent(current.checkpoint_id);
810
811 let checkpoint = self
812 .session_storage
813 .create_checkpoint(current.session_id, &checkpoint_request)
814 .await
815 .map_err(|e| e.to_string())?;
816
817 Ok(SessionInfo {
818 session_id: current.session_id,
819 checkpoint_id: checkpoint.id,
820 checkpoint_created_at: checkpoint.created_at,
821 })
822 }
823
    /// Execute one inference round: BeforeInference hooks → LLM call
    /// (streaming when `stream_channel_tx` is provided, otherwise blocking)
    /// → AfterInference hooks → return the resulting chat message.
    ///
    /// Requires a prior hook to have populated `ctx.state.llm_input`;
    /// errors out otherwise.
    pub(crate) async fn run_agent_completion(
        &self,
        ctx: &mut HookContext<AgentState>,
        stream_channel_tx: Option<mpsc::Sender<Result<StreamMessage, String>>>,
    ) -> Result<ChatMessage, String> {
        self.hook_registry
            .execute_hooks(ctx, &LifecycleEvent::BeforeInference)
            .await
            .map_err(|e| e.to_string())?
            .ok()?;

        // The LLM input is built by a context hook, not by this method.
        let mut input = if let Some(llm_input) = ctx.state.llm_input.clone() {
            llm_input
        } else {
            return Err(
                "LLM input not found, make sure to register a context hook before inference"
                    .to_string(),
            );
        };

        // Tag the request with the session id so the backend can correlate it.
        if let Some(session_id) = ctx.session_id {
            let headers = input
                .headers
                .get_or_insert_with(std::collections::HashMap::new);
            headers.insert("X-Session-Id".to_string(), session_id.to_string());
        }

        let (response_message, usage) = if let Some(tx) = stream_channel_tx {
            // Streaming path: deltas arrive on an internal channel and are
            // forwarded to `tx` while the chat future runs concurrently.
            let (internal_tx, mut internal_rx) = mpsc::channel::<GenerationDelta>(100);
            let stream_input = LLMStreamInput {
                model: input.model,
                messages: input.messages,
                max_tokens: input.max_tokens,
                tools: input.tools,
                stream_channel_tx: internal_tx,
                provider_options: input.provider_options,
                headers: input.headers,
            };

            let stakai = self.stakai.clone();
            let chat_future = async move {
                stakai
                    .chat_stream(stream_input)
                    .await
                    .map_err(|e| e.to_string())
            };

            // Forward deltas until the internal channel closes; stop early if
            // the downstream consumer went away.
            let receive_future = async move {
                while let Some(delta) = internal_rx.recv().await {
                    if tx.send(Ok(StreamMessage::Delta(delta))).await.is_err() {
                        break;
                    }
                }
            };

            // Drive both concurrently so forwarding never blocks the chat call.
            let (chat_result, _) = tokio::join!(chat_future, receive_future);
            let response = chat_result?;
            (response.choices[0].message.clone(), response.usage)
        } else {
            // Non-streaming path: single blocking chat call.
            let response = self.stakai.chat(input).await.map_err(|e| e.to_string())?;
            (response.choices[0].message.clone(), response.usage)
        };

        ctx.state.set_llm_output(response_message, usage);

        self.hook_registry
            .execute_hooks(ctx, &LifecycleEvent::AfterInference)
            .await
            .map_err(|e| e.to_string())?
            .ok()?;

        // Re-read the output from state: AfterInference hooks may have
        // rewritten it.
        let llm_output = ctx
            .state
            .llm_output
            .as_ref()
            .ok_or_else(|| "LLM output is missing from state".to_string())?;

        Ok(ChatMessage::from(llm_output))
    }
909
910 async fn generate_session_title(&self, messages: &[ChatMessage]) -> Result<String, String> {
912 let model = Model::new(
914 "claude-haiku-4-5-20250929",
915 "Claude Haiku 4.5",
916 "anthropic",
917 false,
918 None,
919 stakai::ModelLimit::default(),
920 );
921
922 let llm_messages = vec![
923 LLMMessage {
924 role: Role::System.to_string(),
925 content: LLMMessageContent::String(TITLE_GENERATOR_PROMPT.to_string()),
926 },
927 LLMMessage {
928 role: Role::User.to_string(),
929 content: LLMMessageContent::String(
930 messages
931 .iter()
932 .map(|msg| {
933 msg.content
934 .as_ref()
935 .unwrap_or(&MessageContent::String("".to_string()))
936 .to_string()
937 })
938 .collect(),
939 ),
940 },
941 ];
942
943 let input = LLMInput {
944 model,
945 messages: llm_messages,
946 max_tokens: 100,
947 tools: None,
948 provider_options: None,
949 headers: None,
950 };
951
952 let response = self.stakai.chat(input).await.map_err(|e| e.to_string())?;
953
954 Ok(response.choices[0].message.content.to_string())
955 }
956}