use crate::proxy::{LlmMessage, LlmProvider, LlmRequest, LlmResponse, LlmRole, LlmUsage};
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};

pub struct GoogleProvider {
    client: Client,
    api_key: String,
    base_url: String,
}

impl GoogleProvider {
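    /// Builds a provider that talks to the public Gemini `v1beta` REST API.
    ///
    /// Minimal usage sketch (the `request` value and the surrounding async
    /// context are assumed; the key is a placeholder):
    ///
    /// ```ignore
    /// let provider = GoogleProvider::new("my-api-key".to_string());
    /// let response = provider.complete(request).await?;
    /// println!("{}", response.content);
    /// ```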
    pub fn new(api_key: String) -> Self {
        Self {
            client: Client::new(),
            api_key,
            base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
        }
    }
}

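// `Default` reads the key from the `GOOGLE_API_KEY` environment variable; if it
// is unset the provider is built with an empty key and requests will fail at the API.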
impl Default for GoogleProvider {
    fn default() -> Self {
        let api_key = std::env::var("GOOGLE_API_KEY").unwrap_or_default();
        Self::new(api_key)
    }
}

#[async_trait]
impl LlmProvider for GoogleProvider {
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        let url = format!(
            "{}/models/{}:generateContent?key={}",
            self.base_url, request.model, self.api_key
        );

        let google_request = GoogleChatRequest {
            contents: request.messages.into_iter().map(Into::into).collect(),
            generation_config: Some(GoogleGenerationConfig {
                temperature: request.temperature,
                max_output_tokens: request.max_tokens,
            }),
        };
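
        // The request above serializes to Gemini's generateContent shape (sketch;
        // the camelCase names come from the serde renames on the structs below):
        //   {
        //     "contents": [{"role": "user", "parts": [{"text": "..."}]}],
        //     "generationConfig": {"temperature": 0.7, "maxOutputTokens": 256}
        //   }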

        let response = self
            .client
            .post(&url)
            .json(&google_request)
            .send()
            .await
            .context("Failed to send request to Google Gemini")?;

        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await?;
            return Err(anyhow::anyhow!(
                "Google Gemini API error ({}): {}",
                status,
                error_text
            ));
        }

        let google_response: GoogleChatResponse = response.json().await?;

        let content = google_response
            .candidates
            .first()
            .and_then(|c| c.content.parts.first())
            .map(|p| p.text.clone())
            .unwrap_or_default();

        Ok(LlmResponse {
            content,
            model: request.model,
            usage: google_response.usage_metadata.map(Into::into),
        })
    }

    fn name(&self) -> &'static str {
        "Google"
    }
}

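// Wire-format types for the Gemini `generateContent` endpoint. The JSON field
// names are camelCase on the wire; the serde attributes below handle the mapping.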
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleChatRequest {
    contents: Vec<GoogleContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_config: Option<GoogleGenerationConfig>,
}

#[derive(Debug, Serialize, Deserialize)]
struct GoogleContent {
    role: String,
    parts: Vec<GooglePart>,
}

#[derive(Debug, Serialize, Deserialize)]
struct GooglePart {
    text: String,
}

impl From<LlmMessage> for GoogleContent {
    fn from(msg: LlmMessage) -> Self {
        Self {
            role: match msg.role {
                // Gemini's `contents` roles are limited to "user" and "model",
                // so system messages are mapped to a "user" turn.
                LlmRole::System => "user".to_string(),
                LlmRole::User => "user".to_string(),
                LlmRole::Assistant => "model".to_string(),
            },
            parts: vec![GooglePart { text: msg.content }],
        }
    }
}

#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleGenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<usize>,
}

#[derive(Debug, Deserialize)]
struct GoogleChatResponse {
    candidates: Vec<GoogleCandidate>,
    #[serde(rename = "usageMetadata")]
    usage_metadata: Option<GoogleUsageMetadata>,
}

#[derive(Debug, Deserialize)]
struct GoogleCandidate {
    content: GoogleContent,
}

#[derive(Debug, Deserialize)]
struct GoogleUsageMetadata {
    #[serde(rename = "promptTokenCount")]
    prompt_token_count: usize,
    #[serde(rename = "candidatesTokenCount")]
    candidates_token_count: usize,
    #[serde(rename = "totalTokenCount")]
    total_token_count: usize,
}

impl From<GoogleUsageMetadata> for LlmUsage {
    fn from(usage: GoogleUsageMetadata) -> Self {
        Self {
            prompt_tokens: usage.prompt_token_count,
            completion_tokens: usage.candidates_token_count,
            total_tokens: usage.total_token_count,
        }
    }
}
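
// A small test sketch for the pieces that need no network access. It assumes
// `serde_json` is declared as a dev-dependency (the crate already depends on it
// transitively through `reqwest`'s `json` feature).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn request_serializes_with_camel_case_field_names() {
        let request = GoogleChatRequest {
            contents: vec![GoogleContent {
                role: "user".to_string(),
                parts: vec![GooglePart {
                    text: "Hello".to_string(),
                }],
            }],
            generation_config: Some(GoogleGenerationConfig {
                temperature: Some(0.7),
                max_output_tokens: Some(256),
            }),
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"generationConfig\""));
        assert!(json.contains("\"maxOutputTokens\":256"));
    }

    #[test]
    fn response_deserializes_candidate_text_and_usage() {
        let body = r#"{
            "candidates": [
                {"content": {"role": "model", "parts": [{"text": "Hi there"}]}}
            ],
            "usageMetadata": {
                "promptTokenCount": 3,
                "candidatesTokenCount": 5,
                "totalTokenCount": 8
            }
        }"#;

        let response: GoogleChatResponse = serde_json::from_str(body).unwrap();
        assert_eq!(response.candidates[0].content.parts[0].text, "Hi there");

        let usage: LlmUsage = response.usage_metadata.unwrap().into();
        assert_eq!(usage.prompt_tokens, 3);
        assert_eq!(usage.completion_tokens, 5);
        assert_eq!(usage.total_tokens, 8);
    }
}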