1use std::sync::Arc;
4
5use futures_util::Stream;
6use futures_util::StreamExt;
7use tokio::sync::RwLock;
8
9use rust_genai_types::content::Content;
10use rust_genai_types::models::GenerateContentConfig;
11use rust_genai_types::response::GenerateContentResponse;
12
13use crate::afc::CallableTool;
14use crate::client::ClientInner;
15use crate::error::Result;
16use crate::models::Models;
17
18#[derive(Clone)]
19pub struct Chats {
20 pub(crate) inner: Arc<ClientInner>,
21}
22
23impl Chats {
24 pub(crate) const fn new(inner: Arc<ClientInner>) -> Self {
25 Self { inner }
26 }
27
28 pub fn create(&self, model: impl Into<String>) -> ChatSession {
30 ChatSession::new(self.inner.clone(), model.into())
31 }
32
33 pub fn create_with_config(
35 &self,
36 model: impl Into<String>,
37 config: GenerateContentConfig,
38 ) -> ChatSession {
39 ChatSession::with_config(self.inner.clone(), model.into(), config)
40 }
41}
42
43#[derive(Clone)]
45pub struct ChatSession {
46 client: Arc<ClientInner>,
47 model: String,
48 history: Arc<RwLock<Vec<Content>>>,
49 config: GenerateContentConfig,
50}
51
52impl ChatSession {
53 fn new(client: Arc<ClientInner>, model: String) -> Self {
54 Self {
55 client,
56 model,
57 history: Arc::new(RwLock::new(Vec::new())),
58 config: GenerateContentConfig::default(),
59 }
60 }
61
62 fn with_config(client: Arc<ClientInner>, model: String, config: GenerateContentConfig) -> Self {
63 Self {
64 client,
65 model,
66 history: Arc::new(RwLock::new(Vec::new())),
67 config,
68 }
69 }
70
71 pub async fn send_message(
76 &self,
77 message: impl Into<String>,
78 ) -> Result<GenerateContentResponse> {
79 let user_content = Content::text(message);
80
81 {
82 let mut history = self.history.write().await;
83 history.push(user_content.clone());
84 }
85
86 let models = Models::new(self.client.clone());
87 let history = self.history.read().await.clone();
88
89 let response = models
90 .generate_content_with_config(&self.model, history, self.config.clone())
91 .await?;
92
93 if let Some(candidate) = response.candidates.first() {
94 if let Some(content) = &candidate.content {
95 let mut history = self.history.write().await;
96 history.push(content.clone());
97 }
98 }
99
100 Ok(response)
101 }
102
103 pub async fn send(&self, message: impl Into<String>) -> Result<GenerateContentResponse> {
108 self.send_message(message).await
109 }
110
111 pub async fn send_message_stream(
116 &self,
117 message: impl Into<String>,
118 ) -> Result<impl Stream<Item = Result<GenerateContentResponse>>> {
119 let user_content = Content::text(message);
120
121 {
122 let mut history = self.history.write().await;
123 history.push(user_content.clone());
124 }
125
126 let models = Models::new(self.client.clone());
127 let history = self.history.read().await.clone();
128
129 let stream = models
130 .generate_content_stream(&self.model, history, self.config.clone())
131 .await?;
132
133 let history_ref = self.history.clone();
134 let (tx, rx) = tokio::sync::mpsc::channel(8);
135
136 tokio::spawn(async move {
137 let mut stream = stream;
138 let mut last_content: Option<Content> = None;
139
140 while let Some(item) = stream.next().await {
141 if let Ok(response) = &item {
142 if let Some(candidate) = response.candidates.first() {
143 if let Some(content) = &candidate.content {
144 last_content = Some(content.clone());
145 }
146 }
147 }
148
149 if tx.send(item).await.is_err() {
150 break;
151 }
152 }
153
154 if let Some(content) = last_content {
155 let mut history = history_ref.write().await;
156 history.push(content);
157 }
158 });
159
160 let output = futures_util::stream::unfold(rx, |mut rx| async {
161 rx.recv().await.map(|item| (item, rx))
162 });
163
164 Ok(output)
165 }
166
167 pub async fn send_stream(
172 &self,
173 message: impl Into<String>,
174 ) -> Result<impl Stream<Item = Result<GenerateContentResponse>>> {
175 self.send_message_stream(message).await
176 }
177
178 pub async fn send_message_with_callable_tools(
183 &self,
184 message: impl Into<String>,
185 callable_tools: Vec<Box<dyn CallableTool>>,
186 ) -> Result<GenerateContentResponse> {
187 let user_content = Content::text(message);
188
189 {
190 let mut history = self.history.write().await;
191 history.push(user_content.clone());
192 }
193
194 let models = Models::new(self.client.clone());
195 let history = self.history.read().await.clone();
196
197 let response = models
198 .generate_content_with_callable_tools(
199 &self.model,
200 history,
201 self.config.clone(),
202 callable_tools,
203 )
204 .await?;
205
206 if let Some(afc_history) = response.automatic_function_calling_history.clone() {
207 let mut history = self.history.write().await;
208 *history = afc_history;
209 }
210
211 if let Some(candidate) = response.candidates.first() {
212 if let Some(content) = &candidate.content {
213 let mut history = self.history.write().await;
214 history.push(content.clone());
215 }
216 }
217
218 Ok(response)
219 }
220
221 pub async fn send_message_stream_with_callable_tools(
226 &self,
227 message: impl Into<String>,
228 callable_tools: Vec<Box<dyn CallableTool>>,
229 ) -> Result<impl Stream<Item = Result<GenerateContentResponse>>> {
230 let user_content = Content::text(message);
231
232 {
233 let mut history = self.history.write().await;
234 history.push(user_content.clone());
235 }
236
237 let models = Models::new(self.client.clone());
238 let history = self.history.read().await.clone();
239
240 let stream = models
241 .generate_content_stream_with_callable_tools(
242 &self.model,
243 history,
244 self.config.clone(),
245 callable_tools,
246 )
247 .await?;
248
249 let history_ref = self.history.clone();
250 let (tx, rx) = tokio::sync::mpsc::channel(8);
251
252 tokio::spawn(async move {
253 let mut stream = stream;
254 let mut last_content: Option<Content> = None;
255 let mut last_afc_history: Option<Vec<Content>> = None;
256
257 while let Some(item) = stream.next().await {
258 if let Ok(response) = &item {
259 if let Some(content) = response
260 .candidates
261 .first()
262 .and_then(|candidate| candidate.content.clone())
263 {
264 last_content = Some(content);
265 }
266
267 if let Some(history) = response.automatic_function_calling_history.clone() {
268 last_afc_history = Some(history);
269 }
270 }
271
272 if tx.send(item).await.is_err() {
273 break;
274 }
275 }
276
277 if let Some(history) = last_afc_history {
278 let mut history_ref = history_ref.write().await;
279 *history_ref = history;
280 }
281
282 if let Some(content) = last_content {
283 let mut history = history_ref.write().await;
284 history.push(content);
285 }
286 });
287
288 let output = futures_util::stream::unfold(rx, |mut rx| async {
289 rx.recv().await.map(|item| (item, rx))
290 });
291
292 Ok(output)
293 }
294
295 pub async fn history(&self) -> Vec<Content> {
297 self.history.read().await.clone()
298 }
299
300 pub async fn clear_history(&self) {
302 self.history.write().await.clear();
303 }
304}