voice_echo/pipeline/
conversation.rs1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use pulse_system_types::llm::{LmProvider, Message, MessageContent, Role};
6use tokio::sync::Mutex;
7
8pub struct ConversationManager {
14 provider: Arc<dyn LmProvider>,
15 sessions: Arc<Mutex<HashMap<String, Session>>>,
16 session_timeout: Duration,
17 system_prompt: String,
18 max_response_tokens: u32,
19}
20
21struct Session {
22 messages: Vec<Message>,
23 last_used: Instant,
24}
25
26impl ConversationManager {
27 pub fn new(
28 provider: Arc<dyn LmProvider>,
29 system_prompt: String,
30 session_timeout_secs: u64,
31 max_response_tokens: u32,
32 ) -> Self {
33 Self {
34 provider,
35 sessions: Arc::new(Mutex::new(HashMap::new())),
36 session_timeout: Duration::from_secs(session_timeout_secs),
37 system_prompt,
38 max_response_tokens,
39 }
40 }
41
42 pub async fn send(&self, call_sid: &str, prompt: &str) -> Result<String, ConversationError> {
47 let mut sessions = self.sessions.lock().await;
48
49 sessions.retain(|_, s| s.last_used.elapsed() < self.session_timeout);
51
52 let session = sessions
53 .entry(call_sid.to_string())
54 .or_insert_with(|| Session {
55 messages: Vec::new(),
56 last_used: Instant::now(),
57 });
58
59 session.messages.push(Message {
61 role: Role::User,
62 content: MessageContent::Text(prompt.to_string()),
63 });
64 session.last_used = Instant::now();
65
66 let messages = session.messages.clone();
68 drop(sessions);
69
70 tracing::info!(call_sid, provider = self.provider.name(), "Invoking LLM");
71
72 let response = self
73 .provider
74 .invoke(
75 &self.system_prompt,
76 &messages,
77 self.max_response_tokens,
78 None, )
80 .await
81 .map_err(|e| ConversationError::Provider(e.to_string()))?;
82
83 let text = response.text();
84
85 let mut sessions = self.sessions.lock().await;
87 if let Some(session) = sessions.get_mut(call_sid) {
88 session.messages.push(Message {
89 role: Role::Assistant,
90 content: MessageContent::Text(text.clone()),
91 });
92 session.last_used = Instant::now();
93 }
94
95 tracing::info!(call_sid, response_len = text.len(), "LLM responded");
96
97 Ok(text)
98 }
99
100 pub async fn end_session(&self, call_sid: &str) {
102 self.sessions.lock().await.remove(call_sid);
103 }
104}
105
106#[derive(Debug, thiserror::Error)]
107pub enum ConversationError {
108 #[error("LLM provider error: {0}")]
109 Provider(String),
110}