1use core::str;
2use std::{borrow::Cow, fmt::Display};
3
4use futures_util::{Stream, StreamExt};
5
6use crate::{
7 api::{GeminiGenericErrorResponse, GenerationConfig, SafetySetting, Tool},
8 chat::ChatSession,
9 content::Content,
10 error::{GeminiError, GeminiErrorKind},
11 EmbedContentConfig, EmbedContentRequest, EmbedContentResponse, GeminiRequest, GeminiResponse,
12};
13
/// Root of the Gemini REST API (`v1beta`); request paths such as
/// `/models/{model}:generateContent` are appended to it, so it has no trailing slash.
pub static BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";
16
17#[derive(Debug, Clone)]
19pub struct GenerativeModel {
20 pub api_key: String,
22 pub model: GeminiModel,
24 pub generation_config: Option<GenerationConfig>,
26 pub system_instruction: Option<Content>,
28 pub safety_settings: Option<Vec<SafetySetting>>,
30 pub tools: Option<Vec<Tool>>,
32}
33
34#[derive(Debug, Clone)]
36pub struct GenerativeModelBuilder {
37 pub api_key: Option<String>,
38 pub model: Option<GeminiModel>,
39 pub system_instruction: Option<Content>,
40 pub safety_settings: Option<Vec<SafetySetting>>,
41 pub generation_config: Option<GenerationConfig>,
42 pub tools: Option<Vec<Tool>>,
43}
44
45impl GenerativeModelBuilder {
46 pub fn new() -> Self {
48 Self {
49 api_key: None,
50 model: None,
51 system_instruction: None,
52 safety_settings: None,
53 generation_config: None,
54 tools: None,
55 }
56 }
57
58 pub fn api_key(&mut self, api_key: &str) -> &mut Self {
60 self.api_key = Some(api_key.to_string());
61 self
62 }
63
64 pub fn model(&mut self, model: GeminiModel) -> &mut Self {
66 self.model = Some(model);
67 self
68 }
69
70 pub fn system_instruction(&mut self, system_instruction: impl Into<Content>) -> &mut Self {
72 self.system_instruction = Some(system_instruction.into());
73 self
74 }
75
76 pub fn generation_config(&mut self, config: GenerationConfig) -> &mut Self {
78 self.generation_config = Some(config);
79 self
80 }
81
82 pub fn safety_setting(&mut self, setting: SafetySetting) -> &mut Self {
84 if let Some(ref mut x) = self.safety_settings {
85 x.push(setting);
86 } else {
87 self.safety_settings = Some(vec![setting]);
88 }
89 self
90 }
91
92 pub fn tool(&mut self, tool: Tool) -> &mut Self {
94 if let Some(ref mut x) = self.tools {
95 x.push(tool);
96 } else {
97 self.tools = Some(vec![tool]);
98 }
99 self
100 }
101
102 pub fn build(&mut self) -> GenerativeModel {
108 GenerativeModel {
109 api_key: self.api_key.take().expect("API key must be set"),
110 model: self.model.take().unwrap_or_default(),
111 generation_config: self.generation_config.take(),
112 system_instruction: self.system_instruction.take(),
113 safety_settings: self.safety_settings.take(),
114 tools: self.tools.take(),
115 }
116 }
117}
118
119impl GenerativeModel {
120 pub fn start_chat(&self, history: Vec<Content>) -> ChatSession {
122 ChatSession {
123 model: self.clone(),
124 history,
125 }
126 }
127
128 pub async fn generate_content(
130 &self,
131 prompt: Vec<Content>,
132 ) -> Result<GeminiResponse, GeminiError> {
133 self.generate_content_with(prompt, GenerativeModelBuilder::new())
134 .await
135 }
136
137 pub async fn generate_content_streamed(
139 &self,
140 prompt: Vec<Content>,
141 ) -> Result<impl Stream<Item = Result<GeminiResponse, GeminiError>>, GeminiError> {
142 self.generate_content_streamed_with(prompt, GenerativeModelBuilder::new())
143 .await
144 }
145
146 pub async fn generate_content_with(
148 &self,
149 prompt: Vec<Content>,
150 config: GenerativeModelBuilder,
151 ) -> Result<GeminiResponse, GeminiError> {
152 let response = self.send_request(prompt, config, false).await?;
153
154 let text = response.text().await.map_err(|err| GeminiError {
155 kind: GeminiErrorKind::Other,
156 message: err.to_string(),
157 })?;
158
159 if let Ok(response) = serde_json::from_str::<GeminiResponse>(&text) {
160 Ok(response)
161 } else {
162 Err(serde_json::from_str::<GeminiGenericErrorResponse>(&text)
163 .map(|x| GeminiError::from(x.error))
164 .unwrap_or_else(|x| GeminiError::message(&x.to_string())))
165 }
166 }
167
168 pub async fn generate_content_streamed_with(
170 &self,
171 prompt: Vec<Content>,
172 config: GenerativeModelBuilder,
173 ) -> Result<impl Stream<Item = Result<GeminiResponse, GeminiError>>, GeminiError> {
174 let response = self.send_request(prompt, config, true).await?;
175
176 let stream = response.bytes_stream().filter_map(|chunk| async move {
177 match chunk {
178 Ok(chunk) => {
179 let str = &str::from_utf8(&chunk)
181 .expect("Unexpected: this should not happen. Please report this bug to rusty-gemini repo.")[1..];
182
183 if str.is_empty() {
185 None
186 } else if let Ok(response) = serde_json::from_str::<GeminiResponse>(&str) {
187 Some(Ok(response))
188 } else {
189 Some(Err(serde_json::from_str::<GeminiGenericErrorResponse>(
190 &str,
191 )
192 .map(|x| GeminiError::from(x.error))
193 .unwrap_or_else(|err| GeminiError::message(&err.to_string()))))
194 }
195 }
196 Err(err) => Some(Err(GeminiError::message(&err.to_string()))),
197 }
198 });
199 Ok(stream)
200 }
201
202 pub async fn embed_content(
204 &self,
205 content: impl Into<Content>,
206 config: EmbedContentConfig,
207 ) -> Result<EmbedContentResponse, GeminiError> {
208 let content = content.into();
209 let request = EmbedContentRequest { content, config };
210
211 let client = reqwest::Client::new();
212 let response = client
213 .post(format!(
214 "{BASE_URL}/models/{}:embedContent?key={}",
215 self.model, self.api_key
216 ))
217 .json(&request)
218 .send()
219 .await
220 .map_err(|err| GeminiError::message(&err.to_string()))?;
221
222 let text = response
223 .text()
224 .await
225 .map_err(|err| GeminiError::message(&err.to_string()))?;
226 if let Ok(response) = serde_json::from_str::<EmbedContentResponse>(&text) {
227 Ok(response)
228 } else {
229 Err(serde_json::from_str::<GeminiGenericErrorResponse>(&text)
230 .map(|x| GeminiError::from(x.error))
231 .unwrap_or_else(|x| GeminiError::message(&x.to_string())))
232 }
233 }
234
235 async fn send_request(
236 &self,
237 prompt: Vec<Content>,
238 config: GenerativeModelBuilder,
239 stream: bool,
240 ) -> Result<reqwest::Response, GeminiError> {
241 let request = GeminiRequest {
242 contents: prompt,
243 tools: config.tools.or_else(|| self.tools.clone()),
244 safety_settings: config
245 .safety_settings
246 .or_else(|| self.safety_settings.clone()),
247 system_instruction: config
248 .system_instruction
249 .or_else(|| self.system_instruction.clone()),
250 generation_config: config
251 .generation_config
252 .or_else(|| self.generation_config.clone()),
253 };
254 let client = reqwest::Client::new();
255 let suffix = if stream {
256 "streamGenerateContent"
257 } else {
258 "generateContent"
259 };
260 let response = client
261 .post(format!(
262 "{BASE_URL}/models/{}:{}?key={}",
263 config.model.as_ref().unwrap_or(&self.model),
264 suffix,
265 self.api_key
266 ))
267 .json(&request)
268 .send()
269 .await
270 .map_err(|err| GeminiError {
271 kind: GeminiErrorKind::Other,
272 message: err.to_string(),
273 })?;
274 Ok(response)
275 }
276}
277
/// Gemini model identifiers understood by the API.
///
/// The [`Display`] implementation yields the model name used in request paths
/// (e.g. `gemini-2.5-flash`). [`GeminiModel::Custom`] passes any other name
/// through verbatim.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
// Variant names mirror the dotted version numbers of the API model names
// (`Flash_2_5` <-> `gemini-2.5-flash`), which UpperCamelCase cannot express readably.
#[allow(non_camel_case_types)]
pub enum GeminiModel {
    /// `gemini-2.5-flash` (the default).
    #[default]
    Flash_2_5,
    /// `gemini-2.5-pro`.
    Pro_2_5,
    /// `gemini-2.5-flash-lite`.
    Flash_2_5_Lite,
    /// `gemini-1.5-pro`.
    Pro_1_5,
    /// `gemini-1.5-flash`.
    Flash_1_5,
    /// `gemini-1.5-flash-8b`.
    Flash_1_5_8B,
    /// `text-embedding-004`.
    TextEmbedding004,
    /// Any other model name, used exactly as given.
    Custom(Cow<'static, str>),
}

impl Display for GeminiModel {
    /// Writes the API model name.
    ///
    /// Uses `Formatter::pad`, so width, alignment and precision flags
    /// (`{:>20}`, `{:.6}`) apply just as they do when formatting a `str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name: &str = match self {
            GeminiModel::Flash_2_5 => "gemini-2.5-flash",
            GeminiModel::Pro_2_5 => "gemini-2.5-pro",
            GeminiModel::Flash_2_5_Lite => "gemini-2.5-flash-lite",
            GeminiModel::Pro_1_5 => "gemini-1.5-pro",
            GeminiModel::Flash_1_5 => "gemini-1.5-flash",
            GeminiModel::Flash_1_5_8B => "gemini-1.5-flash-8b",
            GeminiModel::TextEmbedding004 => "text-embedding-004",
            GeminiModel::Custom(custom) => custom,
        };
        f.pad(name)
    }
}