Skip to main content

rust_genai/
models.rs

1//! Models API surface.
2
3use std::collections::HashMap;
4use std::hash::BuildHasher;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use futures_util::{Stream, StreamExt};
9use rust_genai_types::content::{Content, FunctionCall, Role};
10use rust_genai_types::converters;
11use rust_genai_types::models::{
12    ComputeTokensConfig, ComputeTokensRequest, ComputeTokensResponse, CountTokensConfig,
13    CountTokensRequest, CountTokensResponse, DeleteModelConfig, DeleteModelResponse,
14    EditImageConfig, EditImageResponse, EmbedContentConfig, EmbedContentResponse,
15    GenerateContentConfig, GenerateContentRequest, GenerateImagesConfig, GenerateImagesResponse,
16    GenerateVideosConfig, GenerateVideosOperation, GenerateVideosSource, Image, ListModelsConfig,
17    ListModelsResponse, Model, RecontextImageConfig, RecontextImageResponse, RecontextImageSource,
18    ReferenceImage, SegmentImageConfig, SegmentImageResponse, SegmentImageSource,
19    UpdateModelConfig,
20};
21use rust_genai_types::response::GenerateContentResponse;
22
23use crate::afc::{
24    call_callable_tools, max_remote_calls, resolve_callable_tools, should_append_history,
25    should_disable_afc, validate_afc_config, validate_afc_tools, CallableTool,
26};
27use crate::client::{Backend, ClientInner};
28use crate::error::{Error, Result};
29use crate::http_response::{
30    sdk_http_response_from_headers, sdk_http_response_from_headers_and_body,
31};
32use crate::model_capabilities::{
33    validate_code_execution_image_inputs, validate_function_response_media,
34};
35use crate::sse::parse_sse_stream;
36use crate::thinking::{validate_temperature, ThoughtSignatureValidator};
37use crate::tokenizer::TokenEstimator;
38use serde_json::Value;
39
40mod builders;
41mod http;
42mod media;
43pub(crate) mod parsers;
44
45use builders::{
46    build_edit_image_body, build_embed_body_gemini, build_embed_body_vertex,
47    build_function_call_content, build_generate_images_body, build_generate_videos_body,
48    build_recontext_image_body, build_segment_image_body, build_upscale_image_body,
49};
50use http::{
51    apply_http_options, build_model_get_url, build_model_get_url_with_options,
52    build_model_method_url, build_models_list_url, merge_extra_body,
53};
54use parsers::{
55    convert_vertex_embed_response, parse_edit_image_response, parse_generate_images_response,
56    parse_generate_videos_operation, parse_recontext_image_response, parse_segment_image_response,
57    parse_upscale_image_response,
58};
59
60#[derive(Clone)]
61pub struct Models {
62    pub(crate) inner: Arc<ClientInner>,
63}
64
65struct CallableStreamContext<S> {
66    models: Models,
67    model: String,
68    contents: Vec<Content>,
69    request_config: GenerateContentConfig,
70    callable_tools: Vec<Box<dyn CallableTool>>,
71    function_map: HashMap<String, usize, S>,
72    max_calls: usize,
73    append_history: bool,
74}
75
76fn build_synthetic_afc_response(
77    response_content: Content,
78    history: &[Content],
79) -> GenerateContentResponse {
80    let mut response = GenerateContentResponse {
81        sdk_http_response: None,
82        candidates: vec![rust_genai_types::response::Candidate {
83            content: Some(response_content),
84            citation_metadata: None,
85            finish_message: None,
86            token_count: None,
87            finish_reason: None,
88            avg_logprobs: None,
89            grounding_metadata: None,
90            index: None,
91            logprobs_result: None,
92            safety_ratings: Vec::new(),
93            url_context_metadata: None,
94        }],
95        create_time: None,
96        automatic_function_calling_history: None,
97        prompt_feedback: None,
98        usage_metadata: None,
99        model_version: None,
100        response_id: None,
101    };
102
103    if !history.is_empty() {
104        response.automatic_function_calling_history = Some(history.to_vec());
105    }
106
107    response
108}
109
110async fn forward_stream_items(
111    mut stream: Pin<Box<dyn Stream<Item = Result<GenerateContentResponse>> + Send>>,
112    tx: &tokio::sync::mpsc::Sender<Result<GenerateContentResponse>>,
113) -> Option<(Vec<FunctionCall>, Vec<Content>)> {
114    let mut function_calls: Vec<FunctionCall> = Vec::new();
115    let mut response_contents: Vec<Content> = Vec::new();
116
117    while let Some(item) = stream.next().await {
118        if let Ok(response) = &item {
119            if let Some(content) = response.candidates.first().and_then(|c| c.content.clone()) {
120                for part in &content.parts {
121                    if let Some(call) = part.function_call_ref() {
122                        function_calls.push(call.clone());
123                    }
124                }
125                response_contents.push(content);
126            }
127        }
128
129        if tx.send(item).await.is_err() {
130            return None;
131        }
132    }
133
134    Some((function_calls, response_contents))
135}
136
137fn spawn_callable_stream_loop<S>(
138    ctx: CallableStreamContext<S>,
139    tx: tokio::sync::mpsc::Sender<Result<GenerateContentResponse>>,
140) where
141    S: BuildHasher + Sync + Send + 'static,
142{
143    let CallableStreamContext {
144        models,
145        model,
146        contents,
147        request_config,
148        mut callable_tools,
149        function_map,
150        max_calls,
151        append_history,
152    } = ctx;
153    tokio::spawn(async move {
154        let mut conversation = contents;
155        let mut history: Vec<Content> = Vec::new();
156        let mut remaining_calls = max_calls;
157
158        loop {
159            if remaining_calls == 0 {
160                break;
161            }
162
163            let stream = match models
164                .generate_content_stream(&model, conversation.clone(), request_config.clone())
165                .await
166            {
167                Ok(stream) => stream,
168                Err(err) => {
169                    let _ = tx.send(Err(err)).await;
170                    break;
171                }
172            };
173
174            let Some((function_calls, response_contents)) = forward_stream_items(stream, &tx).await
175            else {
176                return;
177            };
178
179            if function_calls.is_empty() {
180                break;
181            }
182
183            let response_parts = match call_callable_tools(
184                &mut callable_tools,
185                &function_map,
186                &function_calls,
187            )
188            .await
189            {
190                Ok(parts) => parts,
191                Err(err) => {
192                    let _ = tx.send(Err(err)).await;
193                    break;
194                }
195            };
196
197            if response_parts.is_empty() {
198                break;
199            }
200
201            let call_content = build_function_call_content(&function_calls);
202            let response_content = Content::from_parts(response_parts.clone(), Role::Function);
203
204            if append_history {
205                if history.is_empty() {
206                    history.extend(conversation.clone());
207                }
208                history.push(call_content.clone());
209                history.push(response_content.clone());
210            }
211
212            conversation.extend(response_contents);
213            conversation.push(call_content);
214            conversation.push(response_content.clone());
215            remaining_calls = remaining_calls.saturating_sub(1);
216
217            let synthetic = build_synthetic_afc_response(response_content, &history);
218            if tx.send(Ok(synthetic)).await.is_err() {
219                return;
220            }
221        }
222    });
223}
224
225impl Models {
226    pub(crate) const fn new(inner: Arc<ClientInner>) -> Self {
227        Self { inner }
228    }
229
230    /// 生成内容(默认配置)。
231    ///
232    /// # Errors
233    ///
234    /// 当请求失败、配置校验失败或响应解析失败时返回错误。
235    pub async fn generate_content(
236        &self,
237        model: impl Into<String>,
238        contents: Vec<Content>,
239    ) -> Result<GenerateContentResponse> {
240        self.generate_content_with_config(model, contents, GenerateContentConfig::default())
241            .await
242    }
243
244    /// 生成内容(自定义配置)。
245    ///
246    /// # Errors
247    ///
248    /// 当请求失败、配置校验失败或响应解析失败时返回错误。
249    pub async fn generate_content_with_config(
250        &self,
251        model: impl Into<String>,
252        contents: Vec<Content>,
253        config: GenerateContentConfig,
254    ) -> Result<GenerateContentResponse> {
255        let should_return_http_response = config.should_return_http_response.unwrap_or(false);
256        let model = model.into();
257        validate_temperature(&model, &config)?;
258        ThoughtSignatureValidator::new(&model).validate(&contents)?;
259        validate_function_response_media(&model, &contents)?;
260        validate_code_execution_image_inputs(&model, &contents, config.tools.as_deref())?;
261
262        let backend = self.inner.config.backend;
263        if backend == Backend::GeminiApi && config.model_armor_config.is_some() {
264            return Err(Error::InvalidConfig {
265                message: "model_armor_config is not supported in Gemini API".into(),
266            });
267        }
268        if config.model_armor_config.is_some() && config.safety_settings.is_some() {
269            return Err(Error::InvalidConfig {
270                message: "model_armor_config cannot be combined with safety_settings".into(),
271            });
272        }
273
274        let request = GenerateContentRequest {
275            contents,
276            system_instruction: config.system_instruction,
277            generation_config: config.generation_config,
278            safety_settings: config.safety_settings,
279            model_armor_config: config.model_armor_config,
280            tools: config.tools,
281            tool_config: config.tool_config,
282            cached_content: config.cached_content,
283            labels: config.labels,
284        };
285
286        let url = build_model_method_url(&self.inner, &model, "generateContent")?;
287        let body = match backend {
288            Backend::GeminiApi => converters::generate_content_request_to_mldev(&request)?,
289            Backend::VertexAi => converters::generate_content_request_to_vertex(&request)?,
290        };
291
292        let request = self.inner.http.post(url).json(&body);
293        let response = self.inner.send(request).await?;
294        if !response.status().is_success() {
295            return Err(Error::ApiError {
296                status: response.status().as_u16(),
297                message: response.text().await.unwrap_or_default(),
298            });
299        }
300        let headers = response.headers().clone();
301        if should_return_http_response {
302            let body = response.text().await.unwrap_or_default();
303            return Ok(GenerateContentResponse {
304                sdk_http_response: Some(sdk_http_response_from_headers_and_body(&headers, body)),
305                candidates: Vec::new(),
306                create_time: None,
307                automatic_function_calling_history: None,
308                prompt_feedback: None,
309                usage_metadata: None,
310                model_version: None,
311                response_id: None,
312            });
313        }
314        let value = response.json::<Value>().await?;
315        let mut result = match backend {
316            Backend::GeminiApi => converters::generate_content_response_from_mldev(value)?,
317            Backend::VertexAi => converters::generate_content_response_from_vertex(value)?,
318        };
319        result.sdk_http_response = Some(sdk_http_response_from_headers(&headers));
320        Ok(result)
321    }
322
323    /// 生成内容(自动函数调用 + callable tools)。
324    ///
325    /// # Errors
326    ///
327    /// 当配置校验失败、自动函数调用执行失败或请求失败时返回错误。
328    pub async fn generate_content_with_callable_tools(
329        &self,
330        model: impl Into<String>,
331        contents: Vec<Content>,
332        config: GenerateContentConfig,
333        mut callable_tools: Vec<Box<dyn CallableTool>>,
334    ) -> Result<GenerateContentResponse> {
335        if config.should_return_http_response.unwrap_or(false) {
336            return Err(Error::InvalidConfig {
337                message: "should_return_http_response is not supported in callable tools methods"
338                    .into(),
339            });
340        }
341        let model = model.into();
342        if callable_tools.is_empty() {
343            return self
344                .generate_content_with_config(model, contents, config)
345                .await;
346        }
347
348        validate_afc_config(&config)?;
349
350        let mut callable_info = resolve_callable_tools(&mut callable_tools).await?;
351        let has_callable = !callable_info.function_map.is_empty();
352        let mut merged_tools = config.tools.clone().unwrap_or_default();
353        merged_tools.append(&mut callable_info.tools);
354
355        let mut request_config = config.clone();
356        request_config.tools = Some(merged_tools);
357
358        if should_disable_afc(&config, has_callable) {
359            return self
360                .generate_content_with_config(model, contents, request_config)
361                .await;
362        }
363
364        validate_afc_tools(&callable_info.function_map, config.tools.as_deref())?;
365
366        let max_calls = max_remote_calls(&config);
367        let append_history = should_append_history(&config);
368        let mut history: Vec<Content> = Vec::new();
369        let mut conversation = contents.clone();
370        let mut remaining_calls = max_calls;
371        let mut response = self
372            .generate_content_with_config(&model, conversation.clone(), request_config.clone())
373            .await?;
374
375        loop {
376            let function_calls: Vec<FunctionCall> =
377                response.function_calls().into_iter().cloned().collect();
378
379            if function_calls.is_empty() {
380                if append_history && !history.is_empty() {
381                    response.automatic_function_calling_history = Some(history);
382                }
383                return Ok(response);
384            }
385
386            if remaining_calls == 0 {
387                break;
388            }
389
390            let response_parts = call_callable_tools(
391                &mut callable_tools,
392                &callable_info.function_map,
393                &function_calls,
394            )
395            .await?;
396            if response_parts.is_empty() {
397                break;
398            }
399
400            let call_content = build_function_call_content(&function_calls);
401            let response_content = Content::from_parts(response_parts.clone(), Role::Function);
402
403            if append_history {
404                if history.is_empty() {
405                    history.extend(conversation.clone());
406                }
407                history.push(call_content.clone());
408                history.push(response_content.clone());
409            }
410
411            conversation.push(call_content);
412            conversation.push(response_content);
413            remaining_calls = remaining_calls.saturating_sub(1);
414
415            response = self
416                .generate_content_with_config(&model, conversation.clone(), request_config.clone())
417                .await?;
418        }
419
420        if append_history && !history.is_empty() {
421            response.automatic_function_calling_history = Some(history);
422        }
423        Ok(response)
424    }
425
426    /// 生成内容(流式 + 自动函数调用)。
427    ///
428    /// # Errors
429    ///
430    /// 当配置校验失败、自动函数调用执行失败或请求失败时返回错误。
431    pub async fn generate_content_stream_with_callable_tools(
432        &self,
433        model: impl Into<String>,
434        contents: Vec<Content>,
435        config: GenerateContentConfig,
436        mut callable_tools: Vec<Box<dyn CallableTool>>,
437    ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerateContentResponse>> + Send>>> {
438        if config.should_return_http_response.unwrap_or(false) {
439            return Err(Error::InvalidConfig {
440                message: "should_return_http_response is not supported in callable tools methods"
441                    .into(),
442            });
443        }
444        let model = model.into();
445        if callable_tools.is_empty() {
446            return self.generate_content_stream(model, contents, config).await;
447        }
448
449        validate_afc_config(&config)?;
450
451        let callable_info = resolve_callable_tools(&mut callable_tools).await?;
452        let function_map = callable_info.function_map;
453        let has_callable = !function_map.is_empty();
454        let mut merged_tools = config.tools.clone().unwrap_or_default();
455        merged_tools.extend(callable_info.tools);
456
457        let mut request_config = config.clone();
458        request_config.tools = Some(merged_tools);
459
460        if should_disable_afc(&config, has_callable) {
461            return self
462                .generate_content_stream(model, contents, request_config)
463                .await;
464        }
465
466        validate_afc_tools(&function_map, config.tools.as_deref())?;
467
468        let max_calls = max_remote_calls(&config);
469        let append_history = should_append_history(&config);
470        let (tx, rx) = tokio::sync::mpsc::channel(8);
471        let models = self.clone();
472        let ctx = CallableStreamContext {
473            models,
474            model,
475            contents,
476            request_config,
477            callable_tools,
478            function_map,
479            max_calls,
480            append_history,
481        };
482        spawn_callable_stream_loop(ctx, tx);
483
484        let output = futures_util::stream::unfold(rx, |mut rx| async {
485            rx.recv().await.map(|item| (item, rx))
486        });
487
488        Ok(Box::pin(output))
489    }
490
491    /// 生成内容(流式)。
492    ///
493    /// # Errors
494    ///
495    /// 当请求失败、配置校验失败或响应解析失败时返回错误。
496    pub async fn generate_content_stream(
497        &self,
498        model: impl Into<String>,
499        contents: Vec<Content>,
500        config: GenerateContentConfig,
501    ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerateContentResponse>> + Send>>> {
502        if config.should_return_http_response.unwrap_or(false) {
503            return Err(Error::InvalidConfig {
504                message: "should_return_http_response is not supported in streaming methods".into(),
505            });
506        }
507        let model = model.into();
508        validate_temperature(&model, &config)?;
509        ThoughtSignatureValidator::new(&model).validate(&contents)?;
510        validate_function_response_media(&model, &contents)?;
511        validate_code_execution_image_inputs(&model, &contents, config.tools.as_deref())?;
512
513        let backend = self.inner.config.backend;
514        if backend == Backend::GeminiApi && config.model_armor_config.is_some() {
515            return Err(Error::InvalidConfig {
516                message: "model_armor_config is not supported in Gemini API".into(),
517            });
518        }
519        if config.model_armor_config.is_some() && config.safety_settings.is_some() {
520            return Err(Error::InvalidConfig {
521                message: "model_armor_config cannot be combined with safety_settings".into(),
522            });
523        }
524
525        let request = GenerateContentRequest {
526            contents,
527            system_instruction: config.system_instruction,
528            generation_config: config.generation_config,
529            safety_settings: config.safety_settings,
530            model_armor_config: config.model_armor_config,
531            tools: config.tools,
532            tool_config: config.tool_config,
533            cached_content: config.cached_content,
534            labels: config.labels,
535        };
536
537        let mut url = build_model_method_url(&self.inner, &model, "streamGenerateContent")?;
538        url.push_str("?alt=sse");
539
540        let request = self.inner.http.post(url).json(&request);
541        let response = self.inner.send(request).await?;
542        if !response.status().is_success() {
543            return Err(Error::ApiError {
544                status: response.status().as_u16(),
545                message: response.text().await.unwrap_or_default(),
546            });
547        }
548
549        let headers = response.headers().clone();
550        let sdk_http_response = sdk_http_response_from_headers(&headers);
551        let stream = parse_sse_stream(response).map(move |item| {
552            item.map(|mut resp| {
553                resp.sdk_http_response = Some(sdk_http_response.clone());
554                resp
555            })
556        });
557        Ok(Box::pin(stream))
558    }
559
560    /// 生成嵌入向量(默认配置)。
561    ///
562    /// # Errors
563    ///
564    /// 当请求失败或响应解析失败时返回错误。
565    pub async fn embed_content(
566        &self,
567        model: impl Into<String>,
568        contents: Vec<Content>,
569    ) -> Result<EmbedContentResponse> {
570        self.embed_content_with_config(model, contents, EmbedContentConfig::default())
571            .await
572    }
573
574    /// 生成嵌入向量(自定义配置)。
575    ///
576    /// # Errors
577    ///
578    /// 当请求失败、配置不合法或响应解析失败时返回错误。
579    pub async fn embed_content_with_config(
580        &self,
581        model: impl Into<String>,
582        contents: Vec<Content>,
583        config: EmbedContentConfig,
584    ) -> Result<EmbedContentResponse> {
585        let model = model.into();
586        let url = match self.inner.config.backend {
587            Backend::GeminiApi => {
588                build_model_method_url(&self.inner, &model, "batchEmbedContents")?
589            }
590            Backend::VertexAi => build_model_method_url(&self.inner, &model, "predict")?,
591        };
592
593        let body = match self.inner.config.backend {
594            Backend::GeminiApi => build_embed_body_gemini(&model, &contents, &config)?,
595            Backend::VertexAi => build_embed_body_vertex(&contents, &config)?,
596        };
597
598        let request = self.inner.http.post(url).json(&body);
599        let response = self.inner.send(request).await?;
600        if !response.status().is_success() {
601            return Err(Error::ApiError {
602                status: response.status().as_u16(),
603                message: response.text().await.unwrap_or_default(),
604            });
605        }
606
607        let headers = response.headers().clone();
608        match self.inner.config.backend {
609            Backend::GeminiApi => {
610                let mut result = response.json::<EmbedContentResponse>().await?;
611                result.sdk_http_response = Some(sdk_http_response_from_headers(&headers));
612                Ok(result)
613            }
614            Backend::VertexAi => {
615                let value = response.json::<Value>().await?;
616                let mut result = convert_vertex_embed_response(&value)?;
617                result.sdk_http_response = Some(sdk_http_response_from_headers(&headers));
618                Ok(result)
619            }
620        }
621    }
622
623    /// 计数 tokens(默认配置)。
624    ///
625    /// # Errors
626    ///
627    /// 当请求失败或响应解析失败时返回错误。
628    pub async fn count_tokens(
629        &self,
630        model: impl Into<String>,
631        contents: Vec<Content>,
632    ) -> Result<CountTokensResponse> {
633        self.count_tokens_with_config(model, contents, CountTokensConfig::default())
634            .await
635    }
636
637    /// 计数 tokens(自定义配置)。
638    ///
639    /// # Errors
640    ///
641    /// 当请求失败、配置不合法或响应解析失败时返回错误。
642    pub async fn count_tokens_with_config(
643        &self,
644        model: impl Into<String>,
645        contents: Vec<Content>,
646        config: CountTokensConfig,
647    ) -> Result<CountTokensResponse> {
648        let request = CountTokensRequest {
649            contents,
650            system_instruction: config.system_instruction,
651            tools: config.tools,
652            generation_config: config.generation_config,
653        };
654
655        let backend = self.inner.config.backend;
656        let url = build_model_method_url(&self.inner, &model.into(), "countTokens")?;
657        let body = match backend {
658            Backend::GeminiApi => converters::count_tokens_request_to_mldev(&request)?,
659            Backend::VertexAi => converters::count_tokens_request_to_vertex(&request)?,
660        };
661        let request = self.inner.http.post(url).json(&body);
662        let response = self.inner.send(request).await?;
663        if !response.status().is_success() {
664            return Err(Error::ApiError {
665                status: response.status().as_u16(),
666                message: response.text().await.unwrap_or_default(),
667            });
668        }
669        let headers = response.headers().clone();
670        let value = response.json::<Value>().await?;
671        let mut result = match backend {
672            Backend::GeminiApi => converters::count_tokens_response_from_mldev(value)?,
673            Backend::VertexAi => converters::count_tokens_response_from_vertex(value)?,
674        };
675        result.sdk_http_response = Some(sdk_http_response_from_headers(&headers));
676        Ok(result)
677    }
678
679    /// 计算 tokens(默认配置,仅 Vertex AI)。
680    ///
681    /// # Errors
682    ///
683    /// 当后端不支持或请求失败时返回错误。
684    pub async fn compute_tokens(
685        &self,
686        model: impl Into<String>,
687        contents: Vec<Content>,
688    ) -> Result<ComputeTokensResponse> {
689        self.compute_tokens_with_config(model, contents, ComputeTokensConfig::default())
690            .await
691    }
692
693    /// 计算 tokens(自定义配置,仅 Vertex AI)。
694    ///
695    /// # Errors
696    ///
697    /// 当后端不支持、配置不合法或请求失败时返回错误。
698    pub async fn compute_tokens_with_config(
699        &self,
700        model: impl Into<String>,
701        contents: Vec<Content>,
702        config: ComputeTokensConfig,
703    ) -> Result<ComputeTokensResponse> {
704        if self.inner.config.backend != Backend::VertexAi {
705            return Err(Error::InvalidConfig {
706                message: "Compute tokens is only supported in Vertex AI backend".into(),
707            });
708        }
709
710        let request = ComputeTokensRequest { contents };
711        let url = build_model_method_url(&self.inner, &model.into(), "computeTokens")?;
712        let mut body = converters::compute_tokens_request_to_vertex(&request)?;
713        if let Some(options) = config.http_options.as_ref() {
714            merge_extra_body(&mut body, options)?;
715        }
716
717        let mut request = self.inner.http.post(url).json(&body);
718        request = apply_http_options(request, config.http_options.as_ref())?;
719
720        let response = self
721            .inner
722            .send_with_http_options(request, config.http_options.as_ref())
723            .await?;
724        if !response.status().is_success() {
725            return Err(Error::ApiError {
726                status: response.status().as_u16(),
727                message: response.text().await.unwrap_or_default(),
728            });
729        }
730        let headers = response.headers().clone();
731        let value = response.json::<Value>().await?;
732        let mut result = converters::compute_tokens_response_from_vertex(value)?;
733        result.sdk_http_response = Some(sdk_http_response_from_headers(&headers));
734        Ok(result)
735    }
736
737    /// 本地估算 tokens(离线估算器)。
738    pub fn estimate_tokens_local(
739        &self,
740        contents: &[Content],
741        estimator: &dyn TokenEstimator,
742    ) -> CountTokensResponse {
743        let total = i32::try_from(estimator.estimate_tokens(contents)).unwrap_or(i32::MAX);
744        CountTokensResponse {
745            sdk_http_response: None,
746            total_tokens: Some(total),
747            cached_content_token_count: None,
748        }
749    }
750
751    /// 本地估算 tokens(包含 tools / system instruction / response schema)。
752    pub fn estimate_tokens_local_with_config(
753        &self,
754        contents: &[Content],
755        config: &CountTokensConfig,
756        estimator: &dyn TokenEstimator,
757    ) -> CountTokensResponse {
758        let estimation_contents = crate::tokenizer::build_estimation_contents(contents, config);
759        let total =
760            i32::try_from(estimator.estimate_tokens(&estimation_contents)).unwrap_or(i32::MAX);
761        CountTokensResponse {
762            sdk_http_response: None,
763            total_tokens: Some(total),
764            cached_content_token_count: None,
765        }
766    }
767
768    /// 计数 tokens(优先使用本地估算器)。
769    ///
770    /// # Errors
771    ///
772    /// 当请求失败或响应解析失败时返回错误。
773    pub async fn count_tokens_or_estimate(
774        &self,
775        model: impl Into<String> + Send,
776        contents: Vec<Content>,
777        config: CountTokensConfig,
778        estimator: Option<&(dyn TokenEstimator + Sync)>,
779    ) -> Result<CountTokensResponse> {
780        if let Some(estimator) = estimator {
781            return Ok(self.estimate_tokens_local_with_config(&contents, &config, estimator));
782        }
783        self.count_tokens_with_config(model, contents, config).await
784    }
785
786    /// 生成图像(Imagen)。
787    ///
788    /// # Errors
789    ///
790    /// 当请求失败、配置不合法或响应解析失败时返回错误。
791    pub async fn generate_images(
792        &self,
793        model: impl Into<String>,
794        prompt: impl Into<String>,
795        mut config: GenerateImagesConfig,
796    ) -> Result<GenerateImagesResponse> {
797        let http_options = config.http_options.take();
798        let model = model.into();
799        let prompt = prompt.into();
800        let mut body = build_generate_images_body(self.inner.config.backend, &prompt, &config)?;
801        if let Some(options) = http_options.as_ref() {
802            merge_extra_body(&mut body, options)?;
803        }
804        let url = build_model_method_url(&self.inner, &model, "predict")?;
805
806        let mut request = self.inner.http.post(url).json(&body);
807        request = apply_http_options(request, http_options.as_ref())?;
808
809        let response = self
810            .inner
811            .send_with_http_options(request, http_options.as_ref())
812            .await?;
813        if !response.status().is_success() {
814            return Err(Error::ApiError {
815                status: response.status().as_u16(),
816                message: response.text().await.unwrap_or_default(),
817            });
818        }
819
820        let headers = response.headers().clone();
821        let value = response.json::<Value>().await?;
822        let mut result = parse_generate_images_response(&value);
823        result.sdk_http_response = Some(sdk_http_response_from_headers(&headers));
824        Ok(result)
825    }
826
827    /// 编辑图像(仅 Vertex AI)。
828    ///
829    /// # Errors
830    ///
831    /// 当后端不支持、请求失败或响应解析失败时返回错误。
832    pub async fn edit_image(
833        &self,
834        model: impl Into<String>,
835        prompt: impl Into<String>,
836        reference_images: Vec<ReferenceImage>,
837        mut config: EditImageConfig,
838    ) -> Result<EditImageResponse> {
839        if self.inner.config.backend != Backend::VertexAi {
840            return Err(Error::InvalidConfig {
841                message: "Edit image is only supported in Vertex AI backend".into(),
842            });
843        }
844
845        let http_options = config.http_options.take();
846        let model = model.into();
847        let prompt = prompt.into();
848        let mut body = build_edit_image_body(&prompt, &reference_images, &config)?;
849        if let Some(options) = http_options.as_ref() {
850            merge_extra_body(&mut body, options)?;
851        }
852        let url = build_model_method_url(&self.inner, &model, "predict")?;
853
854        let mut request = self.inner.http.post(url).json(&body);
855        request = apply_http_options(request, http_options.as_ref())?;
856
857        let response = self
858            .inner
859            .send_with_http_options(request, http_options.as_ref())
860            .await?;
861        if !response.status().is_success() {
862            return Err(Error::ApiError {
863                status: response.status().as_u16(),
864                message: response.text().await.unwrap_or_default(),
865            });
866        }
867
868        let headers = response.headers().clone();
869        let value = response.json::<Value>().await?;
870        let mut result = parse_edit_image_response(&value);
871        result.sdk_http_response = Some(sdk_http_response_from_headers(&headers));
872        Ok(result)
873    }
874
875    /// 放大图像(仅 Vertex AI)。
876    ///
877    /// # Errors
878    ///
879    /// 当后端不支持、请求失败或响应解析失败时返回错误。
880    pub async fn upscale_image(
881        &self,
882        model: impl Into<String>,
883        image: Image,
884        upscale_factor: impl Into<String>,
885        mut config: rust_genai_types::models::UpscaleImageConfig,
886    ) -> Result<rust_genai_types::models::UpscaleImageResponse> {
887        if self.inner.config.backend != Backend::VertexAi {
888            return Err(Error::InvalidConfig {
889                message: "Upscale image is only supported in Vertex AI backend".into(),
890            });
891        }
892
893        let http_options = config.http_options.take();
894        let model = model.into();
895        let upscale_factor = upscale_factor.into();
896        let mut body = build_upscale_image_body(&image, &upscale_factor, &config)?;
897        if let Some(options) = http_options.as_ref() {
898            merge_extra_body(&mut body, options)?;
899        }
900        let url = build_model_method_url(&self.inner, &model, "predict")?;
901
902        let mut request = self.inner.http.post(url).json(&body);
903        request = apply_http_options(request, http_options.as_ref())?;
904
905        let response = self
906            .inner
907            .send_with_http_options(request, http_options.as_ref())
908            .await?;
909        if !response.status().is_success() {
910            return Err(Error::ApiError {
911                status: response.status().as_u16(),
912                message: response.text().await.unwrap_or_default(),
913            });
914        }
915
916        let headers = response.headers().clone();
917        let value = response.json::<Value>().await?;
918        let mut result = parse_upscale_image_response(&value);
919        result.sdk_http_response = Some(sdk_http_response_from_headers(&headers));
920        Ok(result)
921    }
922
923    /// Recontext 图像(Vertex AI)。
924    ///
925    /// # Errors
926    ///
927    /// 当后端不支持、请求失败或响应解析失败时返回错误。
928    pub async fn recontext_image(
929        &self,
930        model: impl Into<String>,
931        source: RecontextImageSource,
932        mut config: RecontextImageConfig,
933    ) -> Result<RecontextImageResponse> {
934        if self.inner.config.backend != Backend::VertexAi {
935            return Err(Error::InvalidConfig {
936                message: "Recontext image is only supported in Vertex AI backend".into(),
937            });
938        }
939
940        let http_options = config.http_options.take();
941        let model = model.into();
942        let mut body = build_recontext_image_body(&source, &config)?;
943        if let Some(options) = http_options.as_ref() {
944            merge_extra_body(&mut body, options)?;
945        }
946        let url = build_model_method_url(&self.inner, &model, "predict")?;
947
948        let mut request = self.inner.http.post(url).json(&body);
949        request = apply_http_options(request, http_options.as_ref())?;
950
951        let response = self
952            .inner
953            .send_with_http_options(request, http_options.as_ref())
954            .await?;
955        if !response.status().is_success() {
956            return Err(Error::ApiError {
957                status: response.status().as_u16(),
958                message: response.text().await.unwrap_or_default(),
959            });
960        }
961
962        let value = response.json::<Value>().await?;
963        Ok(parse_recontext_image_response(&value))
964    }
965
966    /// Segment 图像(Vertex AI)。
967    ///
968    /// # Errors
969    ///
970    /// 当后端不支持、请求失败或响应解析失败时返回错误。
971    pub async fn segment_image(
972        &self,
973        model: impl Into<String>,
974        source: SegmentImageSource,
975        mut config: SegmentImageConfig,
976    ) -> Result<SegmentImageResponse> {
977        if self.inner.config.backend != Backend::VertexAi {
978            return Err(Error::InvalidConfig {
979                message: "Segment image is only supported in Vertex AI backend".into(),
980            });
981        }
982
983        let http_options = config.http_options.take();
984        let model = model.into();
985        let mut body = build_segment_image_body(&source, &config)?;
986        if let Some(options) = http_options.as_ref() {
987            merge_extra_body(&mut body, options)?;
988        }
989        let url = build_model_method_url(&self.inner, &model, "predict")?;
990
991        let mut request = self.inner.http.post(url).json(&body);
992        request = apply_http_options(request, http_options.as_ref())?;
993
994        let response = self
995            .inner
996            .send_with_http_options(request, http_options.as_ref())
997            .await?;
998        if !response.status().is_success() {
999            return Err(Error::ApiError {
1000                status: response.status().as_u16(),
1001                message: response.text().await.unwrap_or_default(),
1002            });
1003        }
1004
1005        let value = response.json::<Value>().await?;
1006        Ok(parse_segment_image_response(&value))
1007    }
1008
1009    /// 生成视频(Veo)。
1010    ///
1011    /// # Errors
1012    ///
1013    /// 当请求失败、配置不合法或响应解析失败时返回错误。
1014    pub async fn generate_videos(
1015        &self,
1016        model: impl Into<String>,
1017        source: GenerateVideosSource,
1018        mut config: GenerateVideosConfig,
1019    ) -> Result<GenerateVideosOperation> {
1020        let http_options = config.http_options.take();
1021        let model = model.into();
1022        let mut body = build_generate_videos_body(self.inner.config.backend, &source, &config)?;
1023        if let Some(options) = http_options.as_ref() {
1024            merge_extra_body(&mut body, options)?;
1025        }
1026        let url = build_model_method_url(&self.inner, &model, "predictLongRunning")?;
1027
1028        let mut request = self.inner.http.post(url).json(&body);
1029        request = apply_http_options(request, http_options.as_ref())?;
1030
1031        let response = self
1032            .inner
1033            .send_with_http_options(request, http_options.as_ref())
1034            .await?;
1035        if !response.status().is_success() {
1036            return Err(Error::ApiError {
1037                status: response.status().as_u16(),
1038                message: response.text().await.unwrap_or_default(),
1039            });
1040        }
1041
1042        let value = response.json::<Value>().await?;
1043        parse_generate_videos_operation(value, self.inner.config.backend)
1044    }
1045
1046    /// 生成视频(仅文本提示)。
1047    ///
1048    /// # Errors
1049    ///
1050    /// 当请求失败或配置不合法时返回错误。
1051    pub async fn generate_videos_with_prompt(
1052        &self,
1053        model: impl Into<String>,
1054        prompt: impl Into<String>,
1055        config: GenerateVideosConfig,
1056    ) -> Result<GenerateVideosOperation> {
1057        let source = GenerateVideosSource {
1058            prompt: Some(prompt.into()),
1059            ..GenerateVideosSource::default()
1060        };
1061        self.generate_videos(model, source, config).await
1062    }
1063
1064    /// 列出模型(基础列表)。
1065    ///
1066    /// # Errors
1067    ///
1068    /// 当请求失败或响应解析失败时返回错误。
1069    pub async fn list(&self) -> Result<ListModelsResponse> {
1070        self.list_with_config(ListModelsConfig::default()).await
1071    }
1072
1073    /// 列出模型(带配置)。
1074    ///
1075    /// # Errors
1076    ///
1077    /// 当请求失败、配置不合法或响应解析失败时返回错误。
1078    pub async fn list_with_config(&self, config: ListModelsConfig) -> Result<ListModelsResponse> {
1079        let url = build_models_list_url(&self.inner, &config)?;
1080        let request = self.inner.http.get(url);
1081        let response = self.inner.send(request).await?;
1082        if !response.status().is_success() {
1083            return Err(Error::ApiError {
1084                status: response.status().as_u16(),
1085                message: response.text().await.unwrap_or_default(),
1086            });
1087        }
1088        let headers = response.headers().clone();
1089        let mut result = response.json::<ListModelsResponse>().await?;
1090        result.sdk_http_response = Some(sdk_http_response_from_headers(&headers));
1091        Ok(result)
1092    }
1093
1094    /// 列出所有模型(自动翻页)。
1095    ///
1096    /// # Errors
1097    ///
1098    /// 当请求失败或响应解析失败时返回错误。
1099    pub async fn all(&self) -> Result<Vec<Model>> {
1100        self.all_with_config(ListModelsConfig::default()).await
1101    }
1102
1103    /// 列出所有模型(带配置,自动翻页)。
1104    ///
1105    /// # Errors
1106    ///
1107    /// 当请求失败、配置不合法或响应解析失败时返回错误。
1108    pub async fn all_with_config(&self, mut config: ListModelsConfig) -> Result<Vec<Model>> {
1109        let mut models = Vec::new();
1110        loop {
1111            let response = self.list_with_config(config.clone()).await?;
1112            if let Some(items) = response.models {
1113                models.extend(items);
1114            }
1115            match response.next_page_token {
1116                Some(token) if !token.is_empty() => {
1117                    config.page_token = Some(token);
1118                }
1119                _ => break,
1120            }
1121        }
1122        Ok(models)
1123    }
1124
1125    /// 获取单个模型信息。
1126    ///
1127    /// # Errors
1128    ///
1129    /// 当请求失败或响应解析失败时返回错误。
1130    pub async fn get(&self, model: impl Into<String>) -> Result<Model> {
1131        let url = build_model_get_url(&self.inner, &model.into())?;
1132        let request = self.inner.http.get(url);
1133        let response = self.inner.send(request).await?;
1134        if !response.status().is_success() {
1135            return Err(Error::ApiError {
1136                status: response.status().as_u16(),
1137                message: response.text().await.unwrap_or_default(),
1138            });
1139        }
1140        let result = response.json::<Model>().await?;
1141        Ok(result)
1142    }
1143
1144    /// 更新模型信息。
1145    ///
1146    /// # Errors
1147    ///
1148    /// 当请求失败、配置不合法或响应解析失败时返回错误。
1149    pub async fn update(
1150        &self,
1151        model: impl Into<String>,
1152        mut config: UpdateModelConfig,
1153    ) -> Result<Model> {
1154        let http_options = config.http_options.take();
1155        let url =
1156            build_model_get_url_with_options(&self.inner, &model.into(), http_options.as_ref())?;
1157
1158        let mut body = serde_json::to_value(&config)?;
1159        if let Some(options) = http_options.as_ref() {
1160            merge_extra_body(&mut body, options)?;
1161        }
1162        let mut request = self.inner.http.patch(url).json(&body);
1163        request = apply_http_options(request, http_options.as_ref())?;
1164
1165        let response = self
1166            .inner
1167            .send_with_http_options(request, http_options.as_ref())
1168            .await?;
1169        if !response.status().is_success() {
1170            return Err(Error::ApiError {
1171                status: response.status().as_u16(),
1172                message: response.text().await.unwrap_or_default(),
1173            });
1174        }
1175        Ok(response.json::<Model>().await?)
1176    }
1177
1178    /// 删除模型。
1179    ///
1180    /// # Errors
1181    ///
1182    /// 当请求失败或响应解析失败时返回错误。
1183    pub async fn delete(
1184        &self,
1185        model: impl Into<String>,
1186        mut config: DeleteModelConfig,
1187    ) -> Result<DeleteModelResponse> {
1188        let http_options = config.http_options.take();
1189        let url =
1190            build_model_get_url_with_options(&self.inner, &model.into(), http_options.as_ref())?;
1191
1192        let mut request = self.inner.http.delete(url);
1193        request = apply_http_options(request, http_options.as_ref())?;
1194
1195        let response = self
1196            .inner
1197            .send_with_http_options(request, http_options.as_ref())
1198            .await?;
1199        if !response.status().is_success() {
1200            return Err(Error::ApiError {
1201                status: response.status().as_u16(),
1202                message: response.text().await.unwrap_or_default(),
1203            });
1204        }
1205        let headers = response.headers().clone();
1206        if response.content_length().unwrap_or(0) == 0 {
1207            let resp = DeleteModelResponse {
1208                sdk_http_response: Some(sdk_http_response_from_headers(&headers)),
1209            };
1210            return Ok(resp);
1211        }
1212        let mut resp = response
1213            .json::<DeleteModelResponse>()
1214            .await
1215            .unwrap_or_default();
1216        resp.sdk_http_response = Some(sdk_http_response_from_headers(&headers));
1217        Ok(resp)
1218    }
1219}
1220
// Unit tests for this module live in a separate `tests` file and are compiled
// only in test builds.
#[cfg(test)]
mod tests;