vtcode_core/gemini/client/
mod.rs

1pub mod config;
2pub mod retry;
3
4pub use config::ClientConfig;
5pub use retry::RetryConfig;
6
7use crate::gemini::models::{GenerateContentRequest, GenerateContentResponse};
8use crate::gemini::streaming::{
9    StreamingError, StreamingMetrics, StreamingProcessor, StreamingResponse,
10};
11use anyhow::{Context, Result};
12use reqwest::Client as ReqwestClient;
13use std::time::Instant;
14
15#[derive(Clone)]
16pub struct Client {
17    api_key: String,
18    model: String,
19    http: ReqwestClient,
20    config: ClientConfig,
21    retry_config: RetryConfig,
22    metrics: StreamingMetrics,
23}
24
25impl Client {
26    pub fn new(api_key: String, model: String) -> Self {
27        Self::with_config(api_key, model, ClientConfig::default())
28    }
29
30    /// Create a client with custom configuration
31    pub fn with_config(api_key: String, model: String, config: ClientConfig) -> Self {
32        let http_client = ReqwestClient::builder()
33            .pool_max_idle_per_host(config.pool_max_idle_per_host)
34            .pool_idle_timeout(config.pool_idle_timeout)
35            .tcp_keepalive(config.tcp_keepalive)
36            .timeout(config.request_timeout)
37            .connect_timeout(config.connect_timeout)
38            .user_agent(&config.user_agent)
39            .build()
40            .expect("Failed to build HTTP client");
41
42        Self {
43            api_key,
44            model,
45            http: http_client,
46            config,
47            retry_config: RetryConfig::default(),
48            metrics: StreamingMetrics::default(),
49        }
50    }
51
52    /// Get current client configuration
53    pub fn config(&self) -> &ClientConfig {
54        &self.config
55    }
56
57    /// Set retry configuration
58    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
59        self.retry_config = retry_config;
60        self
61    }
62
63    /// Get current retry configuration
64    pub fn retry_config(&self) -> &RetryConfig {
65        &self.retry_config
66    }
67
68    /// Get streaming metrics
69    pub fn metrics(&self) -> &StreamingMetrics {
70        &self.metrics
71    }
72
73    /// Reset streaming metrics
74    pub fn reset_metrics(&mut self) {
75        self.metrics = StreamingMetrics::default();
76    }
77
78    /// Classify error to determine if it's retryable
79    fn classify_error(&self, error: &anyhow::Error) -> StreamingError {
80        let error_str = error.to_string().to_lowercase();
81
82        if error_str.contains("timeout")
83            || error_str.contains("connection")
84            || error_str.contains("network")
85        {
86            StreamingError::NetworkError {
87                message: error.to_string(),
88                is_retryable: true,
89            }
90        } else if error_str.contains("rate limit") || error_str.contains("429") {
91            StreamingError::ApiError {
92                status_code: 429,
93                message: "Rate limit exceeded".to_string(),
94                is_retryable: true,
95            }
96        } else {
97            StreamingError::NetworkError {
98                message: error.to_string(),
99                is_retryable: false,
100            }
101        }
102    }
103
104    /// Generate content with the Gemini API
105    pub async fn generate(
106        &mut self,
107        request: &GenerateContentRequest,
108    ) -> Result<GenerateContentResponse> {
109        let start_time = Instant::now();
110
111        let url = format!(
112            "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
113            self.model, self.api_key
114        );
115
116        let response = self
117            .http
118            .post(&url)
119            .json(request)
120            .send()
121            .await
122            .context("Failed to send request")?;
123
124        if !response.status().is_success() {
125            let status = response.status();
126            let error_text = response.text().await.unwrap_or_default();
127            return Err(anyhow::anyhow!("API error {}: {}", status, error_text));
128        }
129
130        let response_data: GenerateContentResponse =
131            response.json().await.context("Failed to parse response")?;
132
133        self.metrics.total_requests += 1;
134        self.metrics.total_response_time += start_time.elapsed();
135
136        Ok(response_data)
137    }
138
139    /// Generate content with the Gemini API using streaming
140    pub async fn generate_stream<F>(
141        &mut self,
142        request: &GenerateContentRequest,
143        on_chunk: F,
144    ) -> Result<StreamingResponse, StreamingError>
145    where
146        F: FnMut(&str) -> Result<(), StreamingError>,
147    {
148        let start_time = Instant::now();
149
150        let url = format!(
151            "https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent?key={}",
152            self.model, self.api_key
153        );
154
155        let response = self
156            .http
157            .post(&url)
158            .json(request)
159            .send()
160            .await
161            .map_err(|e| {
162                let error = anyhow::Error::new(e);
163                self.classify_error(&error)
164            })?;
165
166        if !response.status().is_success() {
167            let status = response.status();
168            let error_text = response.text().await.unwrap_or_default();
169
170            let is_retryable = match status.as_u16() {
171                429 | 500 | 502 | 503 | 504 => true,
172                _ => false,
173            };
174
175            return Err(StreamingError::ApiError {
176                status_code: status.as_u16(),
177                message: error_text,
178                is_retryable,
179            });
180        }
181
182        // Process the streaming response
183        let mut processor = StreamingProcessor::new();
184        let result = processor.process_stream(response, on_chunk).await;
185
186        self.metrics.total_requests += 1;
187        self.metrics.total_response_time += start_time.elapsed();
188
189        result
190    }
191}