1use std::collections::HashMap;
4use std::hash::BuildHasher;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use futures_util::{Stream, StreamExt};
9use rust_genai_types::content::{Content, FunctionCall, Role};
10use rust_genai_types::converters;
11use rust_genai_types::models::{
12 ComputeTokensConfig, ComputeTokensRequest, ComputeTokensResponse, CountTokensConfig,
13 CountTokensRequest, CountTokensResponse, DeleteModelConfig, DeleteModelResponse,
14 EditImageConfig, EditImageResponse, EmbedContentConfig, EmbedContentResponse,
15 GenerateContentConfig, GenerateContentRequest, GenerateImagesConfig, GenerateImagesResponse,
16 GenerateVideosConfig, GenerateVideosOperation, GenerateVideosSource, Image, ListModelsConfig,
17 ListModelsResponse, Model, RecontextImageConfig, RecontextImageResponse, RecontextImageSource,
18 ReferenceImage, SegmentImageConfig, SegmentImageResponse, SegmentImageSource,
19 UpdateModelConfig,
20};
21use rust_genai_types::response::GenerateContentResponse;
22
23use crate::afc::{
24 call_callable_tools, max_remote_calls, resolve_callable_tools, should_append_history,
25 should_disable_afc, validate_afc_config, validate_afc_tools, CallableTool,
26};
27use crate::client::{Backend, ClientInner};
28use crate::error::{Error, Result};
29use crate::http_response::{
30 sdk_http_response_from_headers, sdk_http_response_from_headers_and_body,
31};
32use crate::model_capabilities::{
33 validate_code_execution_image_inputs, validate_function_response_media,
34};
35use crate::sse::parse_sse_stream;
36use crate::thinking::{validate_temperature, ThoughtSignatureValidator};
37use crate::tokenizer::TokenEstimator;
38use serde_json::Value;
39
40mod builders;
41mod http;
42mod media;
43pub(crate) mod parsers;
44
45use builders::{
46 build_edit_image_body, build_embed_body_gemini, build_embed_body_vertex,
47 build_function_call_content, build_generate_images_body, build_generate_videos_body,
48 build_recontext_image_body, build_segment_image_body, build_upscale_image_body,
49};
50use http::{
51 apply_http_options, build_model_get_url, build_model_get_url_with_options,
52 build_model_method_url, build_models_list_url, merge_extra_body,
53};
54use parsers::{
55 convert_vertex_embed_response, parse_edit_image_response, parse_generate_images_response,
56 parse_generate_videos_operation, parse_recontext_image_response, parse_segment_image_response,
57 parse_upscale_image_response,
58};
59
60#[derive(Clone)]
61pub struct Models {
62 pub(crate) inner: Arc<ClientInner>,
63}
64
65struct CallableStreamContext<S> {
66 models: Models,
67 model: String,
68 contents: Vec<Content>,
69 request_config: GenerateContentConfig,
70 callable_tools: Vec<Box<dyn CallableTool>>,
71 function_map: HashMap<String, usize, S>,
72 max_calls: usize,
73 append_history: bool,
74}
75
76fn build_synthetic_afc_response(
77 response_content: Content,
78 history: &[Content],
79) -> GenerateContentResponse {
80 let mut response = GenerateContentResponse {
81 sdk_http_response: None,
82 candidates: vec![rust_genai_types::response::Candidate {
83 content: Some(response_content),
84 citation_metadata: None,
85 finish_message: None,
86 token_count: None,
87 finish_reason: None,
88 avg_logprobs: None,
89 grounding_metadata: None,
90 index: None,
91 logprobs_result: None,
92 safety_ratings: Vec::new(),
93 url_context_metadata: None,
94 }],
95 create_time: None,
96 automatic_function_calling_history: None,
97 prompt_feedback: None,
98 usage_metadata: None,
99 model_version: None,
100 response_id: None,
101 };
102
103 if !history.is_empty() {
104 response.automatic_function_calling_history = Some(history.to_vec());
105 }
106
107 response
108}
109
110async fn forward_stream_items(
111 mut stream: Pin<Box<dyn Stream<Item = Result<GenerateContentResponse>> + Send>>,
112 tx: &tokio::sync::mpsc::Sender<Result<GenerateContentResponse>>,
113) -> Option<(Vec<FunctionCall>, Vec<Content>)> {
114 let mut function_calls: Vec<FunctionCall> = Vec::new();
115 let mut response_contents: Vec<Content> = Vec::new();
116
117 while let Some(item) = stream.next().await {
118 if let Ok(response) = &item {
119 if let Some(content) = response.candidates.first().and_then(|c| c.content.clone()) {
120 for part in &content.parts {
121 if let Some(call) = part.function_call_ref() {
122 function_calls.push(call.clone());
123 }
124 }
125 response_contents.push(content);
126 }
127 }
128
129 if tx.send(item).await.is_err() {
130 return None;
131 }
132 }
133
134 Some((function_calls, response_contents))
135}
136
137fn spawn_callable_stream_loop<S>(
138 ctx: CallableStreamContext<S>,
139 tx: tokio::sync::mpsc::Sender<Result<GenerateContentResponse>>,
140) where
141 S: BuildHasher + Sync + Send + 'static,
142{
143 let CallableStreamContext {
144 models,
145 model,
146 contents,
147 request_config,
148 mut callable_tools,
149 function_map,
150 max_calls,
151 append_history,
152 } = ctx;
153 tokio::spawn(async move {
154 let mut conversation = contents;
155 let mut history: Vec<Content> = Vec::new();
156 let mut remaining_calls = max_calls;
157
158 loop {
159 if remaining_calls == 0 {
160 break;
161 }
162
163 let stream = match models
164 .generate_content_stream(&model, conversation.clone(), request_config.clone())
165 .await
166 {
167 Ok(stream) => stream,
168 Err(err) => {
169 let _ = tx.send(Err(err)).await;
170 break;
171 }
172 };
173
174 let Some((function_calls, response_contents)) = forward_stream_items(stream, &tx).await
175 else {
176 return;
177 };
178
179 if function_calls.is_empty() {
180 break;
181 }
182
183 let response_parts = match call_callable_tools(
184 &mut callable_tools,
185 &function_map,
186 &function_calls,
187 )
188 .await
189 {
190 Ok(parts) => parts,
191 Err(err) => {
192 let _ = tx.send(Err(err)).await;
193 break;
194 }
195 };
196
197 if response_parts.is_empty() {
198 break;
199 }
200
201 let call_content = build_function_call_content(&function_calls);
202 let response_content = Content::from_parts(response_parts.clone(), Role::Function);
203
204 if append_history {
205 if history.is_empty() {
206 history.extend(conversation.clone());
207 }
208 history.push(call_content.clone());
209 history.push(response_content.clone());
210 }
211
212 conversation.extend(response_contents);
213 conversation.push(call_content);
214 conversation.push(response_content.clone());
215 remaining_calls = remaining_calls.saturating_sub(1);
216
217 let synthetic = build_synthetic_afc_response(response_content, &history);
218 if tx.send(Ok(synthetic)).await.is_err() {
219 return;
220 }
221 }
222 });
223}
224
225impl Models {
226 pub(crate) const fn new(inner: Arc<ClientInner>) -> Self {
227 Self { inner }
228 }
229
230 pub async fn generate_content(
236 &self,
237 model: impl Into<String>,
238 contents: Vec<Content>,
239 ) -> Result<GenerateContentResponse> {
240 self.generate_content_with_config(model, contents, GenerateContentConfig::default())
241 .await
242 }
243
244 pub async fn generate_content_with_config(
250 &self,
251 model: impl Into<String>,
252 contents: Vec<Content>,
253 config: GenerateContentConfig,
254 ) -> Result<GenerateContentResponse> {
255 let should_return_http_response = config.should_return_http_response.unwrap_or(false);
256 let model = model.into();
257 validate_temperature(&model, &config)?;
258 ThoughtSignatureValidator::new(&model).validate(&contents)?;
259 validate_function_response_media(&model, &contents)?;
260 validate_code_execution_image_inputs(&model, &contents, config.tools.as_deref())?;
261
262 let backend = self.inner.config.backend;
263 if backend == Backend::GeminiApi && config.model_armor_config.is_some() {
264 return Err(Error::InvalidConfig {
265 message: "model_armor_config is not supported in Gemini API".into(),
266 });
267 }
268 if config.model_armor_config.is_some() && config.safety_settings.is_some() {
269 return Err(Error::InvalidConfig {
270 message: "model_armor_config cannot be combined with safety_settings".into(),
271 });
272 }
273
274 let request = GenerateContentRequest {
275 contents,
276 system_instruction: config.system_instruction,
277 generation_config: config.generation_config,
278 safety_settings: config.safety_settings,
279 model_armor_config: config.model_armor_config,
280 tools: config.tools,
281 tool_config: config.tool_config,
282 cached_content: config.cached_content,
283 labels: config.labels,
284 };
285
286 let url = build_model_method_url(&self.inner, &model, "generateContent")?;
287 let body = match backend {
288 Backend::GeminiApi => converters::generate_content_request_to_mldev(&request)?,
289 Backend::VertexAi => converters::generate_content_request_to_vertex(&request)?,
290 };
291
292 let request = self.inner.http.post(url).json(&body);
293 let response = self.inner.send(request).await?;
294 if !response.status().is_success() {
295 return Err(Error::ApiError {
296 status: response.status().as_u16(),
297 message: response.text().await.unwrap_or_default(),
298 });
299 }
300 let headers = response.headers().clone();
301 if should_return_http_response {
302 let body = response.text().await.unwrap_or_default();
303 return Ok(GenerateContentResponse {
304 sdk_http_response: Some(sdk_http_response_from_headers_and_body(&headers, body)),
305 candidates: Vec::new(),
306 create_time: None,
307 automatic_function_calling_history: None,
308 prompt_feedback: None,
309 usage_metadata: None,
310 model_version: None,
311 response_id: None,
312 });
313 }
314 let value = response.json::<Value>().await?;
315 let mut result = match backend {
316 Backend::GeminiApi => converters::generate_content_response_from_mldev(value)?,
317 Backend::VertexAi => converters::generate_content_response_from_vertex(value)?,
318 };
319 result.sdk_http_response = Some(sdk_http_response_from_headers(&headers));
320 Ok(result)
321 }
322
323 pub async fn generate_content_with_callable_tools(
329 &self,
330 model: impl Into<String>,
331 contents: Vec<Content>,
332 config: GenerateContentConfig,
333 mut callable_tools: Vec<Box<dyn CallableTool>>,
334 ) -> Result<GenerateContentResponse> {
335 if config.should_return_http_response.unwrap_or(false) {
336 return Err(Error::InvalidConfig {
337 message: "should_return_http_response is not supported in callable tools methods"
338 .into(),
339 });
340 }
341 let model = model.into();
342 if callable_tools.is_empty() {
343 return self
344 .generate_content_with_config(model, contents, config)
345 .await;
346 }
347
348 validate_afc_config(&config)?;
349
350 let mut callable_info = resolve_callable_tools(&mut callable_tools).await?;
351 let has_callable = !callable_info.function_map.is_empty();
352 let mut merged_tools = config.tools.clone().unwrap_or_default();
353 merged_tools.append(&mut callable_info.tools);
354
355 let mut request_config = config.clone();
356 request_config.tools = Some(merged_tools);
357
358 if should_disable_afc(&config, has_callable) {
359 return self
360 .generate_content_with_config(model, contents, request_config)
361 .await;
362 }
363
364 validate_afc_tools(&callable_info.function_map, config.tools.as_deref())?;
365
366 let max_calls = max_remote_calls(&config);
367 let append_history = should_append_history(&config);
368 let mut history: Vec<Content> = Vec::new();
369 let mut conversation = contents.clone();
370 let mut remaining_calls = max_calls;
371 let mut response = self
372 .generate_content_with_config(&model, conversation.clone(), request_config.clone())
373 .await?;
374
375 loop {
376 let function_calls: Vec<FunctionCall> =
377 response.function_calls().into_iter().cloned().collect();
378
379 if function_calls.is_empty() {
380 if append_history && !history.is_empty() {
381 response.automatic_function_calling_history = Some(history);
382 }
383 return Ok(response);
384 }
385
386 if remaining_calls == 0 {
387 break;
388 }
389
390 let response_parts = call_callable_tools(
391 &mut callable_tools,
392 &callable_info.function_map,
393 &function_calls,
394 )
395 .await?;
396 if response_parts.is_empty() {
397 break;
398 }
399
400 let call_content = build_function_call_content(&function_calls);
401 let response_content = Content::from_parts(response_parts.clone(), Role::Function);
402
403 if append_history {
404 if history.is_empty() {
405 history.extend(conversation.clone());
406 }
407 history.push(call_content.clone());
408 history.push(response_content.clone());
409 }
410
411 conversation.push(call_content);
412 conversation.push(response_content);
413 remaining_calls = remaining_calls.saturating_sub(1);
414
415 response = self
416 .generate_content_with_config(&model, conversation.clone(), request_config.clone())
417 .await?;
418 }
419
420 if append_history && !history.is_empty() {
421 response.automatic_function_calling_history = Some(history);
422 }
423 Ok(response)
424 }
425
426 pub async fn generate_content_stream_with_callable_tools(
432 &self,
433 model: impl Into<String>,
434 contents: Vec<Content>,
435 config: GenerateContentConfig,
436 mut callable_tools: Vec<Box<dyn CallableTool>>,
437 ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerateContentResponse>> + Send>>> {
438 if config.should_return_http_response.unwrap_or(false) {
439 return Err(Error::InvalidConfig {
440 message: "should_return_http_response is not supported in callable tools methods"
441 .into(),
442 });
443 }
444 let model = model.into();
445 if callable_tools.is_empty() {
446 return self.generate_content_stream(model, contents, config).await;
447 }
448
449 validate_afc_config(&config)?;
450
451 let callable_info = resolve_callable_tools(&mut callable_tools).await?;
452 let function_map = callable_info.function_map;
453 let has_callable = !function_map.is_empty();
454 let mut merged_tools = config.tools.clone().unwrap_or_default();
455 merged_tools.extend(callable_info.tools);
456
457 let mut request_config = config.clone();
458 request_config.tools = Some(merged_tools);
459
460 if should_disable_afc(&config, has_callable) {
461 return self
462 .generate_content_stream(model, contents, request_config)
463 .await;
464 }
465
466 validate_afc_tools(&function_map, config.tools.as_deref())?;
467
468 let max_calls = max_remote_calls(&config);
469 let append_history = should_append_history(&config);
470 let (tx, rx) = tokio::sync::mpsc::channel(8);
471 let models = self.clone();
472 let ctx = CallableStreamContext {
473 models,
474 model,
475 contents,
476 request_config,
477 callable_tools,
478 function_map,
479 max_calls,
480 append_history,
481 };
482 spawn_callable_stream_loop(ctx, tx);
483
484 let output = futures_util::stream::unfold(rx, |mut rx| async {
485 rx.recv().await.map(|item| (item, rx))
486 });
487
488 Ok(Box::pin(output))
489 }
490
491 pub async fn generate_content_stream(
497 &self,
498 model: impl Into<String>,
499 contents: Vec<Content>,
500 config: GenerateContentConfig,
501 ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerateContentResponse>> + Send>>> {
502 if config.should_return_http_response.unwrap_or(false) {
503 return Err(Error::InvalidConfig {
504 message: "should_return_http_response is not supported in streaming methods".into(),
505 });
506 }
507 let model = model.into();
508 validate_temperature(&model, &config)?;
509 ThoughtSignatureValidator::new(&model).validate(&contents)?;
510 validate_function_response_media(&model, &contents)?;
511 validate_code_execution_image_inputs(&model, &contents, config.tools.as_deref())?;
512
513 let backend = self.inner.config.backend;
514 if backend == Backend::GeminiApi && config.model_armor_config.is_some() {
515 return Err(Error::InvalidConfig {
516 message: "model_armor_config is not supported in Gemini API".into(),
517 });
518 }
519 if config.model_armor_config.is_some() && config.safety_settings.is_some() {
520 return Err(Error::InvalidConfig {
521 message: "model_armor_config cannot be combined with safety_settings".into(),
522 });
523 }
524
525 let request = GenerateContentRequest {
526 contents,
527 system_instruction: config.system_instruction,
528 generation_config: config.generation_config,
529 safety_settings: config.safety_settings,
530 model_armor_config: config.model_armor_config,
531 tools: config.tools,
532 tool_config: config.tool_config,
533 cached_content: config.cached_content,
534 labels: config.labels,
535 };
536
537 let mut url = build_model_method_url(&self.inner, &model, "streamGenerateContent")?;
538 url.push_str("?alt=sse");
539
540 let request = self.inner.http.post(url).json(&request);
541 let response = self.inner.send(request).await?;
542 if !response.status().is_success() {
543 return Err(Error::ApiError {
544 status: response.status().as_u16(),
545 message: response.text().await.unwrap_or_default(),
546 });
547 }
548
549 let headers = response.headers().clone();
550 let sdk_http_response = sdk_http_response_from_headers(&headers);
551 let stream = parse_sse_stream(response).map(move |item| {
552 item.map(|mut resp| {
553 resp.sdk_http_response = Some(sdk_http_response.clone());
554 resp
555 })
556 });
557 Ok(Box::pin(stream))
558 }
559
560 pub async fn embed_content(
566 &self,
567 model: impl Into<String>,
568 contents: Vec<Content>,
569 ) -> Result<EmbedContentResponse> {
570 self.embed_content_with_config(model, contents, EmbedContentConfig::default())
571 .await
572 }
573
574 pub async fn embed_content_with_config(
580 &self,
581 model: impl Into<String>,
582 contents: Vec<Content>,
583 config: EmbedContentConfig,
584 ) -> Result<EmbedContentResponse> {
585 let model = model.into();
586 let url = match self.inner.config.backend {
587 Backend::GeminiApi => {
588 build_model_method_url(&self.inner, &model, "batchEmbedContents")?
589 }
590 Backend::VertexAi => build_model_method_url(&self.inner, &model, "predict")?,
591 };
592
593 let body = match self.inner.config.backend {
594 Backend::GeminiApi => build_embed_body_gemini(&model, &contents, &config)?,
595 Backend::VertexAi => build_embed_body_vertex(&contents, &config)?,
596 };
597
598 let request = self.inner.http.post(url).json(&body);
599 let response = self.inner.send(request).await?;
600 if !response.status().is_success() {
601 return Err(Error::ApiError {
602 status: response.status().as_u16(),
603 message: response.text().await.unwrap_or_default(),
604 });
605 }
606
607 let headers = response.headers().clone();
608 match self.inner.config.backend {
609 Backend::GeminiApi => {
610 let mut result = response.json::<EmbedContentResponse>().await?;
611 result.sdk_http_response = Some(sdk_http_response_from_headers(&headers));
612 Ok(result)
613 }
614 Backend::VertexAi => {
615 let value = response.json::<Value>().await?;
616 let mut result = convert_vertex_embed_response(&value)?;
617 result.sdk_http_response = Some(sdk_http_response_from_headers(&headers));
618 Ok(result)
619 }
620 }
621 }
622
623 pub async fn count_tokens(
629 &self,
630 model: impl Into<String>,
631 contents: Vec<Content>,
632 ) -> Result<CountTokensResponse> {
633 self.count_tokens_with_config(model, contents, CountTokensConfig::default())
634 .await
635 }
636
637 pub async fn count_tokens_with_config(
643 &self,
644 model: impl Into<String>,
645 contents: Vec<Content>,
646 config: CountTokensConfig,
647 ) -> Result<CountTokensResponse> {
648 let request = CountTokensRequest {
649 contents,
650 system_instruction: config.system_instruction,
651 tools: config.tools,
652 generation_config: config.generation_config,
653 };
654
655 let backend = self.inner.config.backend;
656 let url = build_model_method_url(&self.inner, &model.into(), "countTokens")?;
657 let body = match backend {
658 Backend::GeminiApi => converters::count_tokens_request_to_mldev(&request)?,
659 Backend::VertexAi => converters::count_tokens_request_to_vertex(&request)?,
660 };
661 let request = self.inner.http.post(url).json(&body);
662 let response = self.inner.send(request).await?;
663 if !response.status().is_success() {
664 return Err(Error::ApiError {
665 status: response.status().as_u16(),
666 message: response.text().await.unwrap_or_default(),
667 });
668 }
669 let headers = response.headers().clone();
670 let value = response.json::<Value>().await?;
671 let mut result = match backend {
672 Backend::GeminiApi => converters::count_tokens_response_from_mldev(value)?,
673 Backend::VertexAi => converters::count_tokens_response_from_vertex(value)?,
674 };
675 result.sdk_http_response = Some(sdk_http_response_from_headers(&headers));
676 Ok(result)
677 }
678
679 pub async fn compute_tokens(
685 &self,
686 model: impl Into<String>,
687 contents: Vec<Content>,
688 ) -> Result<ComputeTokensResponse> {
689 self.compute_tokens_with_config(model, contents, ComputeTokensConfig::default())
690 .await
691 }
692
693 pub async fn compute_tokens_with_config(
699 &self,
700 model: impl Into<String>,
701 contents: Vec<Content>,
702 config: ComputeTokensConfig,
703 ) -> Result<ComputeTokensResponse> {
704 if self.inner.config.backend != Backend::VertexAi {
705 return Err(Error::InvalidConfig {
706 message: "Compute tokens is only supported in Vertex AI backend".into(),
707 });
708 }
709
710 let request = ComputeTokensRequest { contents };
711 let url = build_model_method_url(&self.inner, &model.into(), "computeTokens")?;
712 let mut body = converters::compute_tokens_request_to_vertex(&request)?;
713 if let Some(options) = config.http_options.as_ref() {
714 merge_extra_body(&mut body, options)?;
715 }
716
717 let mut request = self.inner.http.post(url).json(&body);
718 request = apply_http_options(request, config.http_options.as_ref())?;
719
720 let response = self
721 .inner
722 .send_with_http_options(request, config.http_options.as_ref())
723 .await?;
724 if !response.status().is_success() {
725 return Err(Error::ApiError {
726 status: response.status().as_u16(),
727 message: response.text().await.unwrap_or_default(),
728 });
729 }
730 let headers = response.headers().clone();
731 let value = response.json::<Value>().await?;
732 let mut result = converters::compute_tokens_response_from_vertex(value)?;
733 result.sdk_http_response = Some(sdk_http_response_from_headers(&headers));
734 Ok(result)
735 }
736
737 pub fn estimate_tokens_local(
739 &self,
740 contents: &[Content],
741 estimator: &dyn TokenEstimator,
742 ) -> CountTokensResponse {
743 let total = i32::try_from(estimator.estimate_tokens(contents)).unwrap_or(i32::MAX);
744 CountTokensResponse {
745 sdk_http_response: None,
746 total_tokens: Some(total),
747 cached_content_token_count: None,
748 }
749 }
750
751 pub fn estimate_tokens_local_with_config(
753 &self,
754 contents: &[Content],
755 config: &CountTokensConfig,
756 estimator: &dyn TokenEstimator,
757 ) -> CountTokensResponse {
758 let estimation_contents = crate::tokenizer::build_estimation_contents(contents, config);
759 let total =
760 i32::try_from(estimator.estimate_tokens(&estimation_contents)).unwrap_or(i32::MAX);
761 CountTokensResponse {
762 sdk_http_response: None,
763 total_tokens: Some(total),
764 cached_content_token_count: None,
765 }
766 }
767
768 pub async fn count_tokens_or_estimate(
774 &self,
775 model: impl Into<String> + Send,
776 contents: Vec<Content>,
777 config: CountTokensConfig,
778 estimator: Option<&(dyn TokenEstimator + Sync)>,
779 ) -> Result<CountTokensResponse> {
780 if let Some(estimator) = estimator {
781 return Ok(self.estimate_tokens_local_with_config(&contents, &config, estimator));
782 }
783 self.count_tokens_with_config(model, contents, config).await
784 }
785
786 pub async fn generate_images(
792 &self,
793 model: impl Into<String>,
794 prompt: impl Into<String>,
795 mut config: GenerateImagesConfig,
796 ) -> Result<GenerateImagesResponse> {
797 let http_options = config.http_options.take();
798 let model = model.into();
799 let prompt = prompt.into();
800 let mut body = build_generate_images_body(self.inner.config.backend, &prompt, &config)?;
801 if let Some(options) = http_options.as_ref() {
802 merge_extra_body(&mut body, options)?;
803 }
804 let url = build_model_method_url(&self.inner, &model, "predict")?;
805
806 let mut request = self.inner.http.post(url).json(&body);
807 request = apply_http_options(request, http_options.as_ref())?;
808
809 let response = self
810 .inner
811 .send_with_http_options(request, http_options.as_ref())
812 .await?;
813 if !response.status().is_success() {
814 return Err(Error::ApiError {
815 status: response.status().as_u16(),
816 message: response.text().await.unwrap_or_default(),
817 });
818 }
819
820 let headers = response.headers().clone();
821 let value = response.json::<Value>().await?;
822 let mut result = parse_generate_images_response(&value);
823 result.sdk_http_response = Some(sdk_http_response_from_headers(&headers));
824 Ok(result)
825 }
826
827 pub async fn edit_image(
833 &self,
834 model: impl Into<String>,
835 prompt: impl Into<String>,
836 reference_images: Vec<ReferenceImage>,
837 mut config: EditImageConfig,
838 ) -> Result<EditImageResponse> {
839 if self.inner.config.backend != Backend::VertexAi {
840 return Err(Error::InvalidConfig {
841 message: "Edit image is only supported in Vertex AI backend".into(),
842 });
843 }
844
845 let http_options = config.http_options.take();
846 let model = model.into();
847 let prompt = prompt.into();
848 let mut body = build_edit_image_body(&prompt, &reference_images, &config)?;
849 if let Some(options) = http_options.as_ref() {
850 merge_extra_body(&mut body, options)?;
851 }
852 let url = build_model_method_url(&self.inner, &model, "predict")?;
853
854 let mut request = self.inner.http.post(url).json(&body);
855 request = apply_http_options(request, http_options.as_ref())?;
856
857 let response = self
858 .inner
859 .send_with_http_options(request, http_options.as_ref())
860 .await?;
861 if !response.status().is_success() {
862 return Err(Error::ApiError {
863 status: response.status().as_u16(),
864 message: response.text().await.unwrap_or_default(),
865 });
866 }
867
868 let headers = response.headers().clone();
869 let value = response.json::<Value>().await?;
870 let mut result = parse_edit_image_response(&value);
871 result.sdk_http_response = Some(sdk_http_response_from_headers(&headers));
872 Ok(result)
873 }
874
875 pub async fn upscale_image(
881 &self,
882 model: impl Into<String>,
883 image: Image,
884 upscale_factor: impl Into<String>,
885 mut config: rust_genai_types::models::UpscaleImageConfig,
886 ) -> Result<rust_genai_types::models::UpscaleImageResponse> {
887 if self.inner.config.backend != Backend::VertexAi {
888 return Err(Error::InvalidConfig {
889 message: "Upscale image is only supported in Vertex AI backend".into(),
890 });
891 }
892
893 let http_options = config.http_options.take();
894 let model = model.into();
895 let upscale_factor = upscale_factor.into();
896 let mut body = build_upscale_image_body(&image, &upscale_factor, &config)?;
897 if let Some(options) = http_options.as_ref() {
898 merge_extra_body(&mut body, options)?;
899 }
900 let url = build_model_method_url(&self.inner, &model, "predict")?;
901
902 let mut request = self.inner.http.post(url).json(&body);
903 request = apply_http_options(request, http_options.as_ref())?;
904
905 let response = self
906 .inner
907 .send_with_http_options(request, http_options.as_ref())
908 .await?;
909 if !response.status().is_success() {
910 return Err(Error::ApiError {
911 status: response.status().as_u16(),
912 message: response.text().await.unwrap_or_default(),
913 });
914 }
915
916 let headers = response.headers().clone();
917 let value = response.json::<Value>().await?;
918 let mut result = parse_upscale_image_response(&value);
919 result.sdk_http_response = Some(sdk_http_response_from_headers(&headers));
920 Ok(result)
921 }
922
923 pub async fn recontext_image(
929 &self,
930 model: impl Into<String>,
931 source: RecontextImageSource,
932 mut config: RecontextImageConfig,
933 ) -> Result<RecontextImageResponse> {
934 if self.inner.config.backend != Backend::VertexAi {
935 return Err(Error::InvalidConfig {
936 message: "Recontext image is only supported in Vertex AI backend".into(),
937 });
938 }
939
940 let http_options = config.http_options.take();
941 let model = model.into();
942 let mut body = build_recontext_image_body(&source, &config)?;
943 if let Some(options) = http_options.as_ref() {
944 merge_extra_body(&mut body, options)?;
945 }
946 let url = build_model_method_url(&self.inner, &model, "predict")?;
947
948 let mut request = self.inner.http.post(url).json(&body);
949 request = apply_http_options(request, http_options.as_ref())?;
950
951 let response = self
952 .inner
953 .send_with_http_options(request, http_options.as_ref())
954 .await?;
955 if !response.status().is_success() {
956 return Err(Error::ApiError {
957 status: response.status().as_u16(),
958 message: response.text().await.unwrap_or_default(),
959 });
960 }
961
962 let value = response.json::<Value>().await?;
963 Ok(parse_recontext_image_response(&value))
964 }
965
966 pub async fn segment_image(
972 &self,
973 model: impl Into<String>,
974 source: SegmentImageSource,
975 mut config: SegmentImageConfig,
976 ) -> Result<SegmentImageResponse> {
977 if self.inner.config.backend != Backend::VertexAi {
978 return Err(Error::InvalidConfig {
979 message: "Segment image is only supported in Vertex AI backend".into(),
980 });
981 }
982
983 let http_options = config.http_options.take();
984 let model = model.into();
985 let mut body = build_segment_image_body(&source, &config)?;
986 if let Some(options) = http_options.as_ref() {
987 merge_extra_body(&mut body, options)?;
988 }
989 let url = build_model_method_url(&self.inner, &model, "predict")?;
990
991 let mut request = self.inner.http.post(url).json(&body);
992 request = apply_http_options(request, http_options.as_ref())?;
993
994 let response = self
995 .inner
996 .send_with_http_options(request, http_options.as_ref())
997 .await?;
998 if !response.status().is_success() {
999 return Err(Error::ApiError {
1000 status: response.status().as_u16(),
1001 message: response.text().await.unwrap_or_default(),
1002 });
1003 }
1004
1005 let value = response.json::<Value>().await?;
1006 Ok(parse_segment_image_response(&value))
1007 }
1008
1009 pub async fn generate_videos(
1015 &self,
1016 model: impl Into<String>,
1017 source: GenerateVideosSource,
1018 mut config: GenerateVideosConfig,
1019 ) -> Result<GenerateVideosOperation> {
1020 let http_options = config.http_options.take();
1021 let model = model.into();
1022 let mut body = build_generate_videos_body(self.inner.config.backend, &source, &config)?;
1023 if let Some(options) = http_options.as_ref() {
1024 merge_extra_body(&mut body, options)?;
1025 }
1026 let url = build_model_method_url(&self.inner, &model, "predictLongRunning")?;
1027
1028 let mut request = self.inner.http.post(url).json(&body);
1029 request = apply_http_options(request, http_options.as_ref())?;
1030
1031 let response = self
1032 .inner
1033 .send_with_http_options(request, http_options.as_ref())
1034 .await?;
1035 if !response.status().is_success() {
1036 return Err(Error::ApiError {
1037 status: response.status().as_u16(),
1038 message: response.text().await.unwrap_or_default(),
1039 });
1040 }
1041
1042 let value = response.json::<Value>().await?;
1043 parse_generate_videos_operation(value, self.inner.config.backend)
1044 }
1045
1046 pub async fn generate_videos_with_prompt(
1052 &self,
1053 model: impl Into<String>,
1054 prompt: impl Into<String>,
1055 config: GenerateVideosConfig,
1056 ) -> Result<GenerateVideosOperation> {
1057 let source = GenerateVideosSource {
1058 prompt: Some(prompt.into()),
1059 ..GenerateVideosSource::default()
1060 };
1061 self.generate_videos(model, source, config).await
1062 }
1063
1064 pub async fn list(&self) -> Result<ListModelsResponse> {
1070 self.list_with_config(ListModelsConfig::default()).await
1071 }
1072
1073 pub async fn list_with_config(&self, config: ListModelsConfig) -> Result<ListModelsResponse> {
1079 let url = build_models_list_url(&self.inner, &config)?;
1080 let request = self.inner.http.get(url);
1081 let response = self.inner.send(request).await?;
1082 if !response.status().is_success() {
1083 return Err(Error::ApiError {
1084 status: response.status().as_u16(),
1085 message: response.text().await.unwrap_or_default(),
1086 });
1087 }
1088 let headers = response.headers().clone();
1089 let mut result = response.json::<ListModelsResponse>().await?;
1090 result.sdk_http_response = Some(sdk_http_response_from_headers(&headers));
1091 Ok(result)
1092 }
1093
1094 pub async fn all(&self) -> Result<Vec<Model>> {
1100 self.all_with_config(ListModelsConfig::default()).await
1101 }
1102
1103 pub async fn all_with_config(&self, mut config: ListModelsConfig) -> Result<Vec<Model>> {
1109 let mut models = Vec::new();
1110 loop {
1111 let response = self.list_with_config(config.clone()).await?;
1112 if let Some(items) = response.models {
1113 models.extend(items);
1114 }
1115 match response.next_page_token {
1116 Some(token) if !token.is_empty() => {
1117 config.page_token = Some(token);
1118 }
1119 _ => break,
1120 }
1121 }
1122 Ok(models)
1123 }
1124
1125 pub async fn get(&self, model: impl Into<String>) -> Result<Model> {
1131 let url = build_model_get_url(&self.inner, &model.into())?;
1132 let request = self.inner.http.get(url);
1133 let response = self.inner.send(request).await?;
1134 if !response.status().is_success() {
1135 return Err(Error::ApiError {
1136 status: response.status().as_u16(),
1137 message: response.text().await.unwrap_or_default(),
1138 });
1139 }
1140 let result = response.json::<Model>().await?;
1141 Ok(result)
1142 }
1143
1144 pub async fn update(
1150 &self,
1151 model: impl Into<String>,
1152 mut config: UpdateModelConfig,
1153 ) -> Result<Model> {
1154 let http_options = config.http_options.take();
1155 let url =
1156 build_model_get_url_with_options(&self.inner, &model.into(), http_options.as_ref())?;
1157
1158 let mut body = serde_json::to_value(&config)?;
1159 if let Some(options) = http_options.as_ref() {
1160 merge_extra_body(&mut body, options)?;
1161 }
1162 let mut request = self.inner.http.patch(url).json(&body);
1163 request = apply_http_options(request, http_options.as_ref())?;
1164
1165 let response = self
1166 .inner
1167 .send_with_http_options(request, http_options.as_ref())
1168 .await?;
1169 if !response.status().is_success() {
1170 return Err(Error::ApiError {
1171 status: response.status().as_u16(),
1172 message: response.text().await.unwrap_or_default(),
1173 });
1174 }
1175 Ok(response.json::<Model>().await?)
1176 }
1177
1178 pub async fn delete(
1184 &self,
1185 model: impl Into<String>,
1186 mut config: DeleteModelConfig,
1187 ) -> Result<DeleteModelResponse> {
1188 let http_options = config.http_options.take();
1189 let url =
1190 build_model_get_url_with_options(&self.inner, &model.into(), http_options.as_ref())?;
1191
1192 let mut request = self.inner.http.delete(url);
1193 request = apply_http_options(request, http_options.as_ref())?;
1194
1195 let response = self
1196 .inner
1197 .send_with_http_options(request, http_options.as_ref())
1198 .await?;
1199 if !response.status().is_success() {
1200 return Err(Error::ApiError {
1201 status: response.status().as_u16(),
1202 message: response.text().await.unwrap_or_default(),
1203 });
1204 }
1205 let headers = response.headers().clone();
1206 if response.content_length().unwrap_or(0) == 0 {
1207 let resp = DeleteModelResponse {
1208 sdk_http_response: Some(sdk_http_response_from_headers(&headers)),
1209 };
1210 return Ok(resp);
1211 }
1212 let mut resp = response
1213 .json::<DeleteModelResponse>()
1214 .await
1215 .unwrap_or_default();
1216 resp.sdk_http_response = Some(sdk_http_response_from_headers(&headers));
1217 Ok(resp)
1218 }
1219}
1220
1221#[cfg(test)]
1222mod tests;