// rusty_gemini/model.rs

1use core::str;
2use std::{borrow::Cow, fmt::Display};
3
4use futures_util::{Stream, StreamExt};
5
6use crate::{
7    api::{GeminiGenericErrorResponse, GenerationConfig, SafetySetting, Tool},
8    chat::ChatSession,
9    content::Content,
10    error::{GeminiError, GeminiErrorKind},
11    EmbedContentConfig, EmbedContentRequest, EmbedContentResponse, GeminiRequest, GeminiResponse,
12};
13
/// The base URL for the Gemini API (v1beta REST endpoint).
///
/// Request URLs in this module are built by appending
/// `/models/{model}:{method}?key={api_key}` to this prefix.
pub static BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";
16
/// Represents a Generative Model instance.
///
/// Holds the API key plus the default configuration (system instruction,
/// safety settings, generation config, tools) applied to every request.
/// Each of these defaults can be overridden per call via
/// `generate_content_with` / `generate_content_streamed_with`.
#[derive(Debug, Clone)]
pub struct GenerativeModel {
    /// The API key used to authenticate requests.
    pub api_key: String,
    /// The specific Gemini model to use (e.g., Pro_1_5, Flash_1_5).
    pub model: GeminiModel,
    /// Optional configuration for content generation.
    pub generation_config: Option<GenerationConfig>,
    /// Optional instructions given to the model before the prompt.
    pub system_instruction: Option<Content>,
    /// Optional safety settings to control the content generated by the model.
    pub safety_settings: Option<Vec<SafetySetting>>,
    /// Optional tools that the model can use.
    pub tools: Option<Vec<Tool>>,
}
33
/// A builder for creating a `GenerativeModel`.
///
/// All fields start as `None`; `build` moves each set value out of the
/// builder (via `Option::take`) and panics if no API key was provided.
#[derive(Debug, Clone)]
pub struct GenerativeModelBuilder {
    /// API key; required — `build` panics when this is unset.
    pub api_key: Option<String>,
    /// Target model; `build` falls back to `GeminiModel::default()` when unset.
    pub model: Option<GeminiModel>,
    /// Optional system instruction sent ahead of the prompt.
    pub system_instruction: Option<Content>,
    /// Optional safety settings, appended one at a time via `safety_setting`.
    pub safety_settings: Option<Vec<SafetySetting>>,
    /// Optional generation configuration.
    pub generation_config: Option<GenerationConfig>,
    /// Optional tools, appended one at a time via `tool`.
    pub tools: Option<Vec<Tool>>,
}
44
45impl GenerativeModelBuilder {
46    /// Creates a new `GenerativeModelBuilder` with default values.
47    pub fn new() -> Self {
48        Self {
49            api_key: None,
50            model: None,
51            system_instruction: None,
52            safety_settings: None,
53            generation_config: None,
54            tools: None,
55        }
56    }
57
58    /// Sets the API key for the `GenerativeModel`.
59    pub fn api_key(&mut self, api_key: &str) -> &mut Self {
60        self.api_key = Some(api_key.to_string());
61        self
62    }
63
64    /// Sets the specific `GeminiModel` to be used.
65    pub fn model(&mut self, model: GeminiModel) -> &mut Self {
66        self.model = Some(model);
67        self
68    }
69
70    /// Sets the system instruction for the `GenerativeModel`.
71    pub fn system_instruction(&mut self, system_instruction: impl Into<Content>) -> &mut Self {
72        self.system_instruction = Some(system_instruction.into());
73        self
74    }
75
76    /// Sets the generation configuration for the `GenerativeModel`.
77    pub fn generation_config(&mut self, config: GenerationConfig) -> &mut Self {
78        self.generation_config = Some(config);
79        self
80    }
81
82    /// Adds a safety setting to the `GenerativeModel`.
83    pub fn safety_setting(&mut self, setting: SafetySetting) -> &mut Self {
84        if let Some(ref mut x) = self.safety_settings {
85            x.push(setting);
86        } else {
87            self.safety_settings = Some(vec![setting]);
88        }
89        self
90    }
91
92    /// Adds a tool to the `GenerativeModel`.
93    pub fn tool(&mut self, tool: Tool) -> &mut Self {
94        if let Some(ref mut x) = self.tools {
95            x.push(tool);
96        } else {
97            self.tools = Some(vec![tool]);
98        }
99        self
100    }
101
102    /// Builds the `GenerativeModel` with the configured values.
103    ///
104    /// # Panics
105    ///
106    /// Panics if the `api_key` is not set.
107    pub fn build(&mut self) -> GenerativeModel {
108        GenerativeModel {
109            api_key: self.api_key.take().expect("API key must be set"),
110            model: self.model.take().unwrap_or_default(),
111            generation_config: self.generation_config.take(),
112            system_instruction: self.system_instruction.take(),
113            safety_settings: self.safety_settings.take(),
114            tools: self.tools.take(),
115        }
116    }
117}
118
impl GenerativeModel {
    /// Starts a new chat session with the given history.
    ///
    /// The session stores a clone of this model, so later changes to `self`
    /// do not affect an already-started chat.
    pub fn start_chat(&self, history: Vec<Content>) -> ChatSession {
        ChatSession {
            model: self.clone(),
            history,
        }
    }

    /// Generates content based on the provided prompt.
    ///
    /// Convenience wrapper around [`Self::generate_content_with`] using an
    /// empty builder (no per-call overrides).
    pub async fn generate_content(
        &self,
        prompt: Vec<Content>,
    ) -> Result<GeminiResponse, GeminiError> {
        self.generate_content_with(prompt, GenerativeModelBuilder::new())
            .await
    }

    /// Generates a stream of content responses based on the provided prompt.
    ///
    /// Convenience wrapper around [`Self::generate_content_streamed_with`]
    /// using an empty builder (no per-call overrides).
    pub async fn generate_content_streamed(
        &self,
        prompt: Vec<Content>,
    ) -> Result<impl Stream<Item = Result<GeminiResponse, GeminiError>>, GeminiError> {
        self.generate_content_streamed_with(prompt, GenerativeModelBuilder::new())
            .await
    }

    /// Generates content based on the provided prompt, overriding some of the model's configurations using the provided builder.
    ///
    /// # Errors
    ///
    /// Returns a [`GeminiError`] when the HTTP request fails, the body cannot
    /// be read as text, or the API returned an error payload.
    pub async fn generate_content_with(
        &self,
        prompt: Vec<Content>,
        config: GenerativeModelBuilder,
    ) -> Result<GeminiResponse, GeminiError> {
        let response = self.send_request(prompt, config, false).await?;

        // Buffer the whole body as text so both the success and the error
        // shapes can be tried against the same string.
        let text = response.text().await.map_err(|err| GeminiError {
            kind: GeminiErrorKind::Other,
            message: err.to_string(),
        })?;

        if let Ok(response) = serde_json::from_str::<GeminiResponse>(&text) {
            Ok(response)
        } else {
            // Not a success payload: try the API's error envelope; if that
            // also fails to parse, report the serde error text instead.
            Err(serde_json::from_str::<GeminiGenericErrorResponse>(&text)
                .map(|x| GeminiError::from(x.error))
                .unwrap_or_else(|x| GeminiError::message(&x.to_string())))
        }
    }

    /// Generates a stream of content responses based on the provided prompt, overriding some of the model's configurations using the provided builder.
    ///
    /// # Errors
    ///
    /// The outer `Result` is `Err` when the HTTP request itself fails; items
    /// within the stream are `Err` when a chunk cannot be parsed.
    pub async fn generate_content_streamed_with(
        &self,
        prompt: Vec<Content>,
        config: GenerativeModelBuilder,
    ) -> Result<impl Stream<Item = Result<GeminiResponse, GeminiError>>, GeminiError> {
        let response = self.send_request(prompt, config, true).await?;

        let stream = response.bytes_stream().filter_map(|chunk| async move {
            match chunk {
                Ok(chunk) => {
                    // we skip either '[' (which happens in the first chunk) or ',' in the subsequent chunks
                    //
                    // NOTE(review): this assumes each network chunk is valid
                    // UTF-8 on its own and starts exactly at an array
                    // delimiter with one complete JSON element following —
                    // i.e. that reqwest chunk boundaries line up with the
                    // API's array elements. TODO confirm; a chunk split
                    // mid-element would yield a parse error, and one split
                    // mid-code-point would trip the `expect` below.
                    let str = &str::from_utf8(&chunk)
                        .expect("Unexpected: this should not happen. Please report this bug to rusty-gemini repo.")[1..];

                    // in the last chunk, str should be empty
                    if str.is_empty() {
                        None
                    } else if let Ok(response) = serde_json::from_str::<GeminiResponse>(&str) {
                        Some(Ok(response))
                    } else {
                        // Chunk was not a success element: try the error
                        // envelope, falling back to the serde error message.
                        Some(Err(serde_json::from_str::<GeminiGenericErrorResponse>(
                            &str,
                        )
                        .map(|x| GeminiError::from(x.error))
                        .unwrap_or_else(|err| GeminiError::message(&err.to_string()))))
                    }
                }
                // Transport-level failure while streaming.
                Err(err) => Some(Err(GeminiError::message(&err.to_string()))),
            }
        });
        Ok(stream)
    }

    /// Embeds the content using the model's embedding capabilities.
    ///
    /// Posts the content to the `:embedContent` endpoint of `self.model`.
    ///
    /// # Errors
    ///
    /// Returns a [`GeminiError`] when the request fails, the body cannot be
    /// read, or the API returned an error payload.
    pub async fn embed_content(
        &self,
        content: impl Into<Content>,
        config: EmbedContentConfig,
    ) -> Result<EmbedContentResponse, GeminiError> {
        let content = content.into();
        let request = EmbedContentRequest { content, config };

        // A fresh client is constructed per call; no connection pool is
        // shared across requests.
        let client = reqwest::Client::new();
        let response = client
            .post(format!(
                "{BASE_URL}/models/{}:embedContent?key={}",
                self.model, self.api_key
            ))
            .json(&request)
            .send()
            .await
            .map_err(|err| GeminiError::message(&err.to_string()))?;

        let text = response
            .text()
            .await
            .map_err(|err| GeminiError::message(&err.to_string()))?;
        if let Ok(response) = serde_json::from_str::<EmbedContentResponse>(&text) {
            Ok(response)
        } else {
            // Same fallback ladder as `generate_content_with`.
            Err(serde_json::from_str::<GeminiGenericErrorResponse>(&text)
                .map(|x| GeminiError::from(x.error))
                .unwrap_or_else(|x| GeminiError::message(&x.to_string())))
        }
    }

    /// Builds and sends a `generateContent`/`streamGenerateContent` request.
    ///
    /// Per-call values from `config` take precedence; any field left unset
    /// in the builder falls back to this model's stored configuration
    /// (cloned via `or_else`).
    async fn send_request(
        &self,
        prompt: Vec<Content>,
        config: GenerativeModelBuilder,
        stream: bool,
    ) -> Result<reqwest::Response, GeminiError> {
        let request = GeminiRequest {
            contents: prompt,
            tools: config.tools.or_else(|| self.tools.clone()),
            safety_settings: config
                .safety_settings
                .or_else(|| self.safety_settings.clone()),
            system_instruction: config
                .system_instruction
                .or_else(|| self.system_instruction.clone()),
            generation_config: config
                .generation_config
                .or_else(|| self.generation_config.clone()),
        };
        let client = reqwest::Client::new();
        // The two endpoints differ only in the method suffix of the URL.
        let suffix = if stream {
            "streamGenerateContent"
        } else {
            "generateContent"
        };
        let response = client
            .post(format!(
                "{BASE_URL}/models/{}:{}?key={}",
                // A model set on the builder overrides the instance's model.
                config.model.as_ref().unwrap_or(&self.model),
                suffix,
                self.api_key
            ))
            .json(&request)
            .send()
            .await
            .map_err(|err| GeminiError {
                kind: GeminiErrorKind::Other,
                message: err.to_string(),
            })?;
        Ok(response)
    }
}
277
/// Represents the different Gemini models available.
#[derive(Debug, Default, Clone)]
#[allow(non_camel_case_types)]
pub enum GeminiModel {
    /// The Gemini 2.5 Flash model (the default).
    #[default]
    Flash_2_5,
    /// The Gemini 2.5 Pro model.
    Pro_2_5,
    /// The Gemini 2.5 Flash Lite model.
    Flash_2_5_Lite,
    /// The Gemini 1.5 Pro model.
    Pro_1_5,
    /// The Gemini 1.5 Flash model.
    Flash_1_5,
    /// The Gemini 1.5 Flash 8B model.
    Flash_1_5_8B,
    /// The Text Embedding 004 model.
    TextEmbedding004,
    /// A custom Gemini model specified by its name.
    Custom(Cow<'static, str>),
}

impl Display for GeminiModel {
    /// Writes the model identifier exactly as it appears in Gemini API URLs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name: &str = match self {
            Self::Flash_2_5 => "gemini-2.5-flash",
            Self::Pro_2_5 => "gemini-2.5-pro",
            Self::Flash_2_5_Lite => "gemini-2.5-flash-lite",
            Self::Pro_1_5 => "gemini-1.5-pro",
            Self::Flash_1_5 => "gemini-1.5-flash",
            Self::Flash_1_5_8B => "gemini-1.5-flash-8b",
            Self::TextEmbedding004 => "text-embedding-004",
            Self::Custom(custom) => custom,
        };
        f.write_str(name)
    }
}