Skip to main content

rust_genai/
chats.rs

1//! Chats API surface.
2
3use std::sync::Arc;
4
5use futures_util::Stream;
6use futures_util::StreamExt;
7use tokio::sync::RwLock;
8
9use rust_genai_types::content::Content;
10use rust_genai_types::models::GenerateContentConfig;
11use rust_genai_types::response::GenerateContentResponse;
12
13use crate::afc::CallableTool;
14use crate::client::ClientInner;
15use crate::error::Result;
16use crate::models::Models;
17
18#[derive(Clone)]
19pub struct Chats {
20    pub(crate) inner: Arc<ClientInner>,
21}
22
23impl Chats {
24    pub(crate) const fn new(inner: Arc<ClientInner>) -> Self {
25        Self { inner }
26    }
27
28    /// 创建新会话。
29    pub fn create(&self, model: impl Into<String>) -> ChatSession {
30        ChatSession::new(self.inner.clone(), model.into())
31    }
32
33    /// 带配置创建会话。
34    pub fn create_with_config(
35        &self,
36        model: impl Into<String>,
37        config: GenerateContentConfig,
38    ) -> ChatSession {
39        ChatSession::with_config(self.inner.clone(), model.into(), config)
40    }
41}
42
43/// Chat 会话。
44#[derive(Clone)]
45pub struct ChatSession {
46    client: Arc<ClientInner>,
47    model: String,
48    history: Arc<RwLock<Vec<Content>>>,
49    config: GenerateContentConfig,
50}
51
52impl ChatSession {
53    fn new(client: Arc<ClientInner>, model: String) -> Self {
54        Self {
55            client,
56            model,
57            history: Arc::new(RwLock::new(Vec::new())),
58            config: GenerateContentConfig::default(),
59        }
60    }
61
62    fn with_config(client: Arc<ClientInner>, model: String, config: GenerateContentConfig) -> Self {
63        Self {
64            client,
65            model,
66            history: Arc::new(RwLock::new(Vec::new())),
67            config,
68        }
69    }
70
71    /// 发送消息。
72    ///
73    /// # Errors
74    /// 当请求失败或响应解析失败时返回错误。
75    pub async fn send_message(
76        &self,
77        message: impl Into<String>,
78    ) -> Result<GenerateContentResponse> {
79        let user_content = Content::text(message);
80
81        {
82            let mut history = self.history.write().await;
83            history.push(user_content.clone());
84        }
85
86        let models = Models::new(self.client.clone());
87        let history = self.history.read().await.clone();
88
89        let response = models
90            .generate_content_with_config(&self.model, history, self.config.clone())
91            .await?;
92
93        if let Some(candidate) = response.candidates.first() {
94            if let Some(content) = &candidate.content {
95                let mut history = self.history.write().await;
96                history.push(content.clone());
97            }
98        }
99
100        Ok(response)
101    }
102
103    /// 发送消息(兼容别名)。
104    ///
105    /// # Errors
106    /// 当请求失败或响应解析失败时返回错误。
107    pub async fn send(&self, message: impl Into<String>) -> Result<GenerateContentResponse> {
108        self.send_message(message).await
109    }
110
111    /// 流式发送消息。
112    ///
113    /// # Errors
114    /// 当请求失败或响应解析失败时返回错误。
115    pub async fn send_message_stream(
116        &self,
117        message: impl Into<String>,
118    ) -> Result<impl Stream<Item = Result<GenerateContentResponse>>> {
119        let user_content = Content::text(message);
120
121        {
122            let mut history = self.history.write().await;
123            history.push(user_content.clone());
124        }
125
126        let models = Models::new(self.client.clone());
127        let history = self.history.read().await.clone();
128
129        let stream = models
130            .generate_content_stream(&self.model, history, self.config.clone())
131            .await?;
132
133        let history_ref = self.history.clone();
134        let (tx, rx) = tokio::sync::mpsc::channel(8);
135
136        tokio::spawn(async move {
137            let mut stream = stream;
138            let mut last_content: Option<Content> = None;
139
140            while let Some(item) = stream.next().await {
141                if let Ok(response) = &item {
142                    if let Some(candidate) = response.candidates.first() {
143                        if let Some(content) = &candidate.content {
144                            last_content = Some(content.clone());
145                        }
146                    }
147                }
148
149                if tx.send(item).await.is_err() {
150                    break;
151                }
152            }
153
154            if let Some(content) = last_content {
155                let mut history = history_ref.write().await;
156                history.push(content);
157            }
158        });
159
160        let output = futures_util::stream::unfold(rx, |mut rx| async {
161            rx.recv().await.map(|item| (item, rx))
162        });
163
164        Ok(output)
165    }
166
167    /// 流式发送消息(兼容别名)。
168    ///
169    /// # Errors
170    /// 当请求失败或响应解析失败时返回错误。
171    pub async fn send_stream(
172        &self,
173        message: impl Into<String>,
174    ) -> Result<impl Stream<Item = Result<GenerateContentResponse>>> {
175        self.send_message_stream(message).await
176    }
177
178    /// 发送消息(自动函数调用 + callable tools)。
179    ///
180    /// # Errors
181    /// 当请求失败或响应解析失败时返回错误。
182    pub async fn send_message_with_callable_tools(
183        &self,
184        message: impl Into<String>,
185        callable_tools: Vec<Box<dyn CallableTool>>,
186    ) -> Result<GenerateContentResponse> {
187        let user_content = Content::text(message);
188
189        {
190            let mut history = self.history.write().await;
191            history.push(user_content.clone());
192        }
193
194        let models = Models::new(self.client.clone());
195        let history = self.history.read().await.clone();
196
197        let response = models
198            .generate_content_with_callable_tools(
199                &self.model,
200                history,
201                self.config.clone(),
202                callable_tools,
203            )
204            .await?;
205
206        if let Some(afc_history) = response.automatic_function_calling_history.clone() {
207            let mut history = self.history.write().await;
208            *history = afc_history;
209        }
210
211        if let Some(candidate) = response.candidates.first() {
212            if let Some(content) = &candidate.content {
213                let mut history = self.history.write().await;
214                history.push(content.clone());
215            }
216        }
217
218        Ok(response)
219    }
220
221    /// 流式发送消息(自动函数调用 + callable tools)。
222    ///
223    /// # Errors
224    /// 当请求失败或响应解析失败时返回错误。
225    pub async fn send_message_stream_with_callable_tools(
226        &self,
227        message: impl Into<String>,
228        callable_tools: Vec<Box<dyn CallableTool>>,
229    ) -> Result<impl Stream<Item = Result<GenerateContentResponse>>> {
230        let user_content = Content::text(message);
231
232        {
233            let mut history = self.history.write().await;
234            history.push(user_content.clone());
235        }
236
237        let models = Models::new(self.client.clone());
238        let history = self.history.read().await.clone();
239
240        let stream = models
241            .generate_content_stream_with_callable_tools(
242                &self.model,
243                history,
244                self.config.clone(),
245                callable_tools,
246            )
247            .await?;
248
249        let history_ref = self.history.clone();
250        let (tx, rx) = tokio::sync::mpsc::channel(8);
251
252        tokio::spawn(async move {
253            let mut stream = stream;
254            let mut last_content: Option<Content> = None;
255            let mut last_afc_history: Option<Vec<Content>> = None;
256
257            while let Some(item) = stream.next().await {
258                if let Ok(response) = &item {
259                    if let Some(content) = response
260                        .candidates
261                        .first()
262                        .and_then(|candidate| candidate.content.clone())
263                    {
264                        last_content = Some(content);
265                    }
266
267                    if let Some(history) = response.automatic_function_calling_history.clone() {
268                        last_afc_history = Some(history);
269                    }
270                }
271
272                if tx.send(item).await.is_err() {
273                    break;
274                }
275            }
276
277            if let Some(history) = last_afc_history {
278                let mut history_ref = history_ref.write().await;
279                *history_ref = history;
280            }
281
282            if let Some(content) = last_content {
283                let mut history = history_ref.write().await;
284                history.push(content);
285            }
286        });
287
288        let output = futures_util::stream::unfold(rx, |mut rx| async {
289            rx.recv().await.map(|item| (item, rx))
290        });
291
292        Ok(output)
293    }
294
295    /// 获取历史。
296    pub async fn history(&self) -> Vec<Content> {
297        self.history.read().await.clone()
298    }
299
300    /// 清空历史。
301    pub async fn clear_history(&self) {
302        self.history.write().await.clear();
303    }
304}