1use std::collections::HashMap;
4use std::hash::BuildHasher;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use futures_util::{Stream, StreamExt};
9use rust_genai_types::content::{Content, FunctionCall, Role};
10use rust_genai_types::converters;
11use rust_genai_types::models::{
12 ComputeTokensConfig, ComputeTokensRequest, ComputeTokensResponse, CountTokensConfig,
13 CountTokensRequest, CountTokensResponse, DeleteModelConfig, DeleteModelResponse,
14 EditImageConfig, EditImageResponse, EmbedContentConfig, EmbedContentResponse,
15 GenerateContentConfig, GenerateContentRequest, GenerateImagesConfig, GenerateImagesResponse,
16 GenerateVideosConfig, GenerateVideosSource, Image, ListModelsConfig, ListModelsResponse, Model,
17 RecontextImageConfig, RecontextImageResponse, RecontextImageSource, ReferenceImage,
18 SegmentImageConfig, SegmentImageResponse, SegmentImageSource, UpdateModelConfig,
19};
20use rust_genai_types::response::GenerateContentResponse;
21
22use crate::afc::{
23 call_callable_tools, max_remote_calls, resolve_callable_tools, should_append_history,
24 should_disable_afc, validate_afc_config, validate_afc_tools, CallableTool,
25};
26use crate::client::{Backend, ClientInner};
27use crate::error::{Error, Result};
28use crate::model_capabilities::{
29 validate_code_execution_image_inputs, validate_function_response_media,
30};
31use crate::sse::parse_sse_stream;
32use crate::thinking::{validate_temperature, ThoughtSignatureValidator};
33use crate::tokenizer::TokenEstimator;
34use serde_json::Value;
35
36mod builders;
37mod http;
38mod media;
39mod parsers;
40
41use builders::{
42 build_edit_image_body, build_embed_body_gemini, build_embed_body_vertex,
43 build_function_call_content, build_generate_images_body, build_generate_videos_body,
44 build_recontext_image_body, build_segment_image_body, build_upscale_image_body,
45};
46use http::{
47 apply_http_options, build_model_get_url, build_model_get_url_with_options,
48 build_model_method_url, build_models_list_url, merge_extra_body,
49};
50use parsers::{
51 convert_vertex_embed_response, parse_edit_image_response, parse_generate_images_response,
52 parse_generate_videos_operation, parse_recontext_image_response, parse_segment_image_response,
53 parse_upscale_image_response,
54};
55
/// Entry point for model-related API operations: content generation
/// (including streaming and automatic function calling), embeddings, token
/// counting, image/video generation, and model management.
///
/// Cheap to clone — holds only an `Arc` to the shared client state.
#[derive(Clone)]
pub struct Models {
    pub(crate) inner: Arc<ClientInner>,
}
60
/// State handed to the background task that drives the streaming automatic
/// function-calling (AFC) loop.
///
/// `S` is the hasher of the function-name map, so callers may supply maps
/// built with a non-default `BuildHasher`.
struct CallableStreamContext<S> {
    // Client handle used to issue each streamed generate call.
    models: Models,
    // Target model identifier.
    model: String,
    // Initial conversation contents.
    contents: Vec<Content>,
    // Config sent on every turn (callable tools already merged into `tools`).
    request_config: GenerateContentConfig,
    // Tools that can be invoked locally when the model requests them.
    callable_tools: Vec<Box<dyn CallableTool>>,
    // Maps a function name to its index in `callable_tools`.
    function_map: HashMap<String, usize, S>,
    // Upper bound on remote generate calls for the loop.
    max_calls: usize,
    // Whether to accumulate and attach the AFC exchange history.
    append_history: bool,
}
71
72fn build_synthetic_afc_response(
73 response_content: Content,
74 history: &[Content],
75) -> GenerateContentResponse {
76 let mut response = GenerateContentResponse {
77 candidates: vec![rust_genai_types::response::Candidate {
78 content: Some(response_content),
79 citation_metadata: None,
80 finish_message: None,
81 token_count: None,
82 finish_reason: None,
83 avg_logprobs: None,
84 grounding_metadata: None,
85 index: None,
86 logprobs_result: None,
87 safety_ratings: Vec::new(),
88 url_context_metadata: None,
89 }],
90 create_time: None,
91 automatic_function_calling_history: None,
92 prompt_feedback: None,
93 usage_metadata: None,
94 model_version: None,
95 response_id: None,
96 };
97
98 if !history.is_empty() {
99 response.automatic_function_calling_history = Some(history.to_vec());
100 }
101
102 response
103}
104
105async fn forward_stream_items(
106 mut stream: Pin<Box<dyn Stream<Item = Result<GenerateContentResponse>> + Send>>,
107 tx: &tokio::sync::mpsc::Sender<Result<GenerateContentResponse>>,
108) -> Option<(Vec<FunctionCall>, Vec<Content>)> {
109 let mut function_calls: Vec<FunctionCall> = Vec::new();
110 let mut response_contents: Vec<Content> = Vec::new();
111
112 while let Some(item) = stream.next().await {
113 if let Ok(response) = &item {
114 if let Some(content) = response.candidates.first().and_then(|c| c.content.clone()) {
115 for part in &content.parts {
116 if let Some(call) = part.function_call_ref() {
117 function_calls.push(call.clone());
118 }
119 }
120 response_contents.push(content);
121 }
122 }
123
124 if tx.send(item).await.is_err() {
125 return None;
126 }
127 }
128
129 Some((function_calls, response_contents))
130}
131
/// Spawns a background task driving the streaming automatic function-calling
/// loop: stream one model turn, execute any requested local tools, append the
/// exchange to the conversation, and repeat until no more function calls are
/// produced or the remote-call budget runs out.
///
/// All streamed chunks — plus a synthetic response carrying each round's tool
/// results — are forwarded through `tx`; errors are sent on the channel
/// rather than returned.
fn spawn_callable_stream_loop<S>(
    ctx: CallableStreamContext<S>,
    tx: tokio::sync::mpsc::Sender<Result<GenerateContentResponse>>,
) where
    S: BuildHasher + Sync + Send + 'static,
{
    let CallableStreamContext {
        models,
        model,
        contents,
        request_config,
        mut callable_tools,
        function_map,
        max_calls,
        append_history,
    } = ctx;
    tokio::spawn(async move {
        let mut conversation = contents;
        let mut history: Vec<Content> = Vec::new();
        let mut remaining_calls = max_calls;

        loop {
            // Stop once the remote-call budget is exhausted.
            if remaining_calls == 0 {
                break;
            }

            let stream = match models
                .generate_content_stream(&model, conversation.clone(), request_config.clone())
                .await
            {
                Ok(stream) => stream,
                Err(err) => {
                    // The receiver may already be gone; ignore the send result.
                    let _ = tx.send(Err(err)).await;
                    break;
                }
            };

            // `None` means the consumer dropped the receiver: stop silently.
            let Some((function_calls, response_contents)) = forward_stream_items(stream, &tx).await
            else {
                return;
            };

            // No function calls in this turn: the model is done.
            if function_calls.is_empty() {
                break;
            }

            // Execute the requested tools locally.
            let response_parts = match call_callable_tools(
                &mut callable_tools,
                &function_map,
                &function_calls,
            )
            .await
            {
                Ok(parts) => parts,
                Err(err) => {
                    let _ = tx.send(Err(err)).await;
                    break;
                }
            };

            if response_parts.is_empty() {
                break;
            }

            let call_content = build_function_call_content(&function_calls);
            let response_content = Content::from_parts(response_parts.clone(), Role::Function);

            if append_history {
                // Seed the history with the original conversation once.
                if history.is_empty() {
                    history.extend(conversation.clone());
                }
                history.push(call_content.clone());
                history.push(response_content.clone());
            }

            // NOTE(review): unlike the non-streaming AFC loop, the streamed
            // model contents are appended here in addition to the rebuilt
            // `call_content`, so function-call parts appear twice in the
            // conversation — confirm this duplication is intended.
            conversation.extend(response_contents);
            conversation.push(call_content);
            conversation.push(response_content.clone());
            remaining_calls = remaining_calls.saturating_sub(1);

            // Emit a synthetic response so consumers observe the tool results
            // (and, when tracked, the AFC history) inline in the stream.
            let synthetic = build_synthetic_afc_response(response_content, &history);
            if tx.send(Ok(synthetic)).await.is_err() {
                return;
            }
        }
    });
}
219
220impl Models {
    /// Creates a `Models` handle sharing the given client state.
    pub(crate) const fn new(inner: Arc<ClientInner>) -> Self {
        Self { inner }
    }
224
225 pub async fn generate_content(
231 &self,
232 model: impl Into<String>,
233 contents: Vec<Content>,
234 ) -> Result<GenerateContentResponse> {
235 self.generate_content_with_config(model, contents, GenerateContentConfig::default())
236 .await
237 }
238
    /// Generates content from `model` with an explicit configuration.
    ///
    /// The request is validated locally, converted to the wire format of the
    /// configured backend, posted to the `generateContent` endpoint, and the
    /// backend-specific response is converted back into the common
    /// [`GenerateContentResponse`].
    ///
    /// # Errors
    ///
    /// Returns validation errors from the pre-flight checks,
    /// `Error::ApiError` for non-success HTTP statuses, plus any conversion
    /// or transport errors.
    pub async fn generate_content_with_config(
        &self,
        model: impl Into<String>,
        contents: Vec<Content>,
        config: GenerateContentConfig,
    ) -> Result<GenerateContentResponse> {
        let model = model.into();
        // Pre-flight validation: reject unsupported settings before any
        // network traffic.
        validate_temperature(&model, &config)?;
        ThoughtSignatureValidator::new(&model).validate(&contents)?;
        validate_function_response_media(&model, &contents)?;
        validate_code_execution_image_inputs(&model, &contents, config.tools.as_deref())?;

        let request = GenerateContentRequest {
            contents,
            system_instruction: config.system_instruction,
            generation_config: config.generation_config,
            safety_settings: config.safety_settings,
            tools: config.tools,
            tool_config: config.tool_config,
            cached_content: config.cached_content,
            labels: config.labels,
        };

        let backend = self.inner.config.backend;
        let url = build_model_method_url(&self.inner, &model, "generateContent")?;
        // Request/response shapes differ between the Gemini API and Vertex
        // AI, so both directions go through backend-specific converters.
        let body = match backend {
            Backend::GeminiApi => converters::generate_content_request_to_mldev(&request)?,
            Backend::VertexAi => converters::generate_content_request_to_vertex(&request)?,
        };

        let request = self.inner.http.post(url).json(&body);
        let response = self.inner.send(request).await?;
        if !response.status().is_success() {
            return Err(Error::ApiError {
                status: response.status().as_u16(),
                message: response.text().await.unwrap_or_default(),
            });
        }
        let value = response.json::<Value>().await?;
        let result = match backend {
            Backend::GeminiApi => converters::generate_content_response_from_mldev(value)?,
            Backend::VertexAi => converters::generate_content_response_from_vertex(value)?,
        };
        Ok(result)
    }
289
    /// Generates content with automatic function calling (AFC).
    ///
    /// Tools from `callable_tools` are merged into `config.tools` and
    /// advertised to the model. Whenever the response contains function
    /// calls, the matching tools are executed locally and their results fed
    /// back, looping until the model stops calling functions or the
    /// remote-call budget is exhausted.
    ///
    /// With empty `callable_tools`, or when AFC is disabled by `config`,
    /// this behaves like [`Self::generate_content_with_config`].
    ///
    /// # Errors
    ///
    /// Returns AFC-configuration validation errors, tool-invocation errors,
    /// and everything `generate_content_with_config` can return.
    pub async fn generate_content_with_callable_tools(
        &self,
        model: impl Into<String>,
        contents: Vec<Content>,
        config: GenerateContentConfig,
        mut callable_tools: Vec<Box<dyn CallableTool>>,
    ) -> Result<GenerateContentResponse> {
        let model = model.into();
        if callable_tools.is_empty() {
            return self
                .generate_content_with_config(model, contents, config)
                .await;
        }

        validate_afc_config(&config)?;

        // Resolve tool declarations and merge them with the caller's tools.
        let mut callable_info = resolve_callable_tools(&mut callable_tools).await?;
        let has_callable = !callable_info.function_map.is_empty();
        let mut merged_tools = config.tools.clone().unwrap_or_default();
        merged_tools.append(&mut callable_info.tools);

        let mut request_config = config.clone();
        request_config.tools = Some(merged_tools);

        // AFC disabled: issue a single call with the merged tool set.
        if should_disable_afc(&config, has_callable) {
            return self
                .generate_content_with_config(model, contents, request_config)
                .await;
        }

        validate_afc_tools(&callable_info.function_map, config.tools.as_deref())?;

        let max_calls = max_remote_calls(&config);
        let append_history = should_append_history(&config);
        let mut history: Vec<Content> = Vec::new();
        let mut conversation = contents.clone();
        let mut remaining_calls = max_calls;
        let mut response = self
            .generate_content_with_config(&model, conversation.clone(), request_config.clone())
            .await?;

        loop {
            let function_calls: Vec<FunctionCall> =
                response.function_calls().into_iter().cloned().collect();

            // No function calls: the model is done; attach history and return.
            if function_calls.is_empty() {
                if append_history && !history.is_empty() {
                    response.automatic_function_calling_history = Some(history);
                }
                return Ok(response);
            }

            // Budget exhausted: return the last response as-is (it still
            // carries unanswered function calls).
            if remaining_calls == 0 {
                break;
            }

            // Execute the requested tools locally.
            let response_parts = call_callable_tools(
                &mut callable_tools,
                &callable_info.function_map,
                &function_calls,
            )
            .await?;
            if response_parts.is_empty() {
                break;
            }

            let call_content = build_function_call_content(&function_calls);
            let response_content = Content::from_parts(response_parts.clone(), Role::Function);

            if append_history {
                // Seed the history with the original conversation once.
                if history.is_empty() {
                    history.extend(conversation.clone());
                }
                history.push(call_content.clone());
                history.push(response_content.clone());
            }

            conversation.push(call_content);
            conversation.push(response_content);
            remaining_calls = remaining_calls.saturating_sub(1);

            response = self
                .generate_content_with_config(&model, conversation.clone(), request_config.clone())
                .await?;
        }

        if append_history && !history.is_empty() {
            response.automatic_function_calling_history = Some(history);
        }
        Ok(response)
    }
386
    /// Streaming variant of [`Self::generate_content_with_callable_tools`].
    ///
    /// Spawns a background task that runs the AFC loop; streamed chunks and
    /// synthetic tool-result responses are delivered through the returned
    /// stream. Falls back to [`Self::generate_content_stream`] when no
    /// callable tools are supplied or AFC is disabled.
    ///
    /// # Errors
    ///
    /// Returns AFC validation errors up front; errors occurring inside the
    /// loop are yielded as stream items instead.
    pub async fn generate_content_stream_with_callable_tools(
        &self,
        model: impl Into<String>,
        contents: Vec<Content>,
        config: GenerateContentConfig,
        mut callable_tools: Vec<Box<dyn CallableTool>>,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerateContentResponse>> + Send>>> {
        let model = model.into();
        if callable_tools.is_empty() {
            return self.generate_content_stream(model, contents, config).await;
        }

        validate_afc_config(&config)?;

        // Resolve tool declarations and merge them with the caller's tools.
        let callable_info = resolve_callable_tools(&mut callable_tools).await?;
        let function_map = callable_info.function_map;
        let has_callable = !function_map.is_empty();
        let mut merged_tools = config.tools.clone().unwrap_or_default();
        merged_tools.extend(callable_info.tools);

        let mut request_config = config.clone();
        request_config.tools = Some(merged_tools);

        if should_disable_afc(&config, has_callable) {
            return self
                .generate_content_stream(model, contents, request_config)
                .await;
        }

        validate_afc_tools(&function_map, config.tools.as_deref())?;

        let max_calls = max_remote_calls(&config);
        let append_history = should_append_history(&config);
        // Bounded channel: the background task waits once the consumer falls
        // 8 items behind.
        let (tx, rx) = tokio::sync::mpsc::channel(8);
        let models = self.clone();
        let ctx = CallableStreamContext {
            models,
            model,
            contents,
            request_config,
            callable_tools,
            function_map,
            max_calls,
            append_history,
        };
        spawn_callable_stream_loop(ctx, tx);

        // Adapt the channel receiver into a `Stream`.
        let output = futures_util::stream::unfold(rx, |mut rx| async {
            rx.recv().await.map(|item| (item, rx))
        });

        Ok(Box::pin(output))
    }
445
446 pub async fn generate_content_stream(
452 &self,
453 model: impl Into<String>,
454 contents: Vec<Content>,
455 config: GenerateContentConfig,
456 ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerateContentResponse>> + Send>>> {
457 let model = model.into();
458 validate_temperature(&model, &config)?;
459 ThoughtSignatureValidator::new(&model).validate(&contents)?;
460 validate_function_response_media(&model, &contents)?;
461 validate_code_execution_image_inputs(&model, &contents, config.tools.as_deref())?;
462
463 let request = GenerateContentRequest {
464 contents,
465 system_instruction: config.system_instruction,
466 generation_config: config.generation_config,
467 safety_settings: config.safety_settings,
468 tools: config.tools,
469 tool_config: config.tool_config,
470 cached_content: config.cached_content,
471 labels: config.labels,
472 };
473
474 let mut url = build_model_method_url(&self.inner, &model, "streamGenerateContent")?;
475 url.push_str("?alt=sse");
476
477 let request = self.inner.http.post(url).json(&request);
478 let response = self.inner.send(request).await?;
479 if !response.status().is_success() {
480 return Err(Error::ApiError {
481 status: response.status().as_u16(),
482 message: response.text().await.unwrap_or_default(),
483 });
484 }
485
486 Ok(Box::pin(parse_sse_stream(response)))
487 }
488
489 pub async fn embed_content(
495 &self,
496 model: impl Into<String>,
497 contents: Vec<Content>,
498 ) -> Result<EmbedContentResponse> {
499 self.embed_content_with_config(model, contents, EmbedContentConfig::default())
500 .await
501 }
502
503 pub async fn embed_content_with_config(
509 &self,
510 model: impl Into<String>,
511 contents: Vec<Content>,
512 config: EmbedContentConfig,
513 ) -> Result<EmbedContentResponse> {
514 let model = model.into();
515 let url = match self.inner.config.backend {
516 Backend::GeminiApi => {
517 build_model_method_url(&self.inner, &model, "batchEmbedContents")?
518 }
519 Backend::VertexAi => build_model_method_url(&self.inner, &model, "predict")?,
520 };
521
522 let body = match self.inner.config.backend {
523 Backend::GeminiApi => build_embed_body_gemini(&model, &contents, &config)?,
524 Backend::VertexAi => build_embed_body_vertex(&contents, &config)?,
525 };
526
527 let request = self.inner.http.post(url).json(&body);
528 let response = self.inner.send(request).await?;
529 if !response.status().is_success() {
530 return Err(Error::ApiError {
531 status: response.status().as_u16(),
532 message: response.text().await.unwrap_or_default(),
533 });
534 }
535
536 match self.inner.config.backend {
537 Backend::GeminiApi => Ok(response.json::<EmbedContentResponse>().await?),
538 Backend::VertexAi => {
539 let value = response.json::<Value>().await?;
540 Ok(convert_vertex_embed_response(&value)?)
541 }
542 }
543 }
544
545 pub async fn count_tokens(
551 &self,
552 model: impl Into<String>,
553 contents: Vec<Content>,
554 ) -> Result<CountTokensResponse> {
555 self.count_tokens_with_config(model, contents, CountTokensConfig::default())
556 .await
557 }
558
    /// Counts tokens for `contents` with an explicit configuration.
    ///
    /// Uses the backend's `countTokens` endpoint; both request and response
    /// go through backend-specific converters.
    ///
    /// # Errors
    ///
    /// Returns `Error::ApiError` for non-success HTTP statuses, plus any
    /// conversion or transport errors.
    pub async fn count_tokens_with_config(
        &self,
        model: impl Into<String>,
        contents: Vec<Content>,
        config: CountTokensConfig,
    ) -> Result<CountTokensResponse> {
        let request = CountTokensRequest {
            contents,
            system_instruction: config.system_instruction,
            tools: config.tools,
            generation_config: config.generation_config,
        };

        let backend = self.inner.config.backend;
        let url = build_model_method_url(&self.inner, &model.into(), "countTokens")?;
        let body = match backend {
            Backend::GeminiApi => converters::count_tokens_request_to_mldev(&request)?,
            Backend::VertexAi => converters::count_tokens_request_to_vertex(&request)?,
        };
        let request = self.inner.http.post(url).json(&body);
        let response = self.inner.send(request).await?;
        if !response.status().is_success() {
            return Err(Error::ApiError {
                status: response.status().as_u16(),
                message: response.text().await.unwrap_or_default(),
            });
        }
        let value = response.json::<Value>().await?;
        let result = match backend {
            Backend::GeminiApi => converters::count_tokens_response_from_mldev(value)?,
            Backend::VertexAi => converters::count_tokens_response_from_vertex(value)?,
        };
        Ok(result)
    }
598
599 pub async fn compute_tokens(
605 &self,
606 model: impl Into<String>,
607 contents: Vec<Content>,
608 ) -> Result<ComputeTokensResponse> {
609 self.compute_tokens_with_config(model, contents, ComputeTokensConfig::default())
610 .await
611 }
612
    /// Computes tokens for `contents` with an explicit configuration.
    ///
    /// Only supported on the Vertex AI backend; other backends get
    /// `Error::InvalidConfig`. `config.http_options`, when present,
    /// contributes extra body fields and per-request HTTP settings.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidConfig` on non-Vertex backends,
    /// `Error::ApiError` for non-success HTTP statuses, plus any conversion
    /// or transport errors.
    pub async fn compute_tokens_with_config(
        &self,
        model: impl Into<String>,
        contents: Vec<Content>,
        config: ComputeTokensConfig,
    ) -> Result<ComputeTokensResponse> {
        if self.inner.config.backend != Backend::VertexAi {
            return Err(Error::InvalidConfig {
                message: "Compute tokens is only supported in Vertex AI backend".into(),
            });
        }

        let request = ComputeTokensRequest { contents };
        let url = build_model_method_url(&self.inner, &model.into(), "computeTokens")?;
        let mut body = converters::compute_tokens_request_to_vertex(&request)?;
        // Merge caller-supplied extra body fields before sending.
        if let Some(options) = config.http_options.as_ref() {
            merge_extra_body(&mut body, options)?;
        }

        let mut request = self.inner.http.post(url).json(&body);
        request = apply_http_options(request, config.http_options.as_ref())?;

        let response = self.inner.send(request).await?;
        if !response.status().is_success() {
            return Err(Error::ApiError {
                status: response.status().as_u16(),
                message: response.text().await.unwrap_or_default(),
            });
        }
        let value = response.json::<Value>().await?;
        let result = converters::compute_tokens_response_from_vertex(value)?;
        Ok(result)
    }
651
652 pub fn estimate_tokens_local(
654 &self,
655 contents: &[Content],
656 estimator: &dyn TokenEstimator,
657 ) -> CountTokensResponse {
658 let total = i32::try_from(estimator.estimate_tokens(contents)).unwrap_or(i32::MAX);
659 CountTokensResponse {
660 total_tokens: Some(total),
661 cached_content_token_count: None,
662 }
663 }
664
665 pub fn estimate_tokens_local_with_config(
667 &self,
668 contents: &[Content],
669 config: &CountTokensConfig,
670 estimator: &dyn TokenEstimator,
671 ) -> CountTokensResponse {
672 let estimation_contents = crate::tokenizer::build_estimation_contents(contents, config);
673 let total =
674 i32::try_from(estimator.estimate_tokens(&estimation_contents)).unwrap_or(i32::MAX);
675 CountTokensResponse {
676 total_tokens: Some(total),
677 cached_content_token_count: None,
678 }
679 }
680
681 pub async fn count_tokens_or_estimate(
687 &self,
688 model: impl Into<String> + Send,
689 contents: Vec<Content>,
690 config: CountTokensConfig,
691 estimator: Option<&(dyn TokenEstimator + Sync)>,
692 ) -> Result<CountTokensResponse> {
693 if let Some(estimator) = estimator {
694 return Ok(self.estimate_tokens_local_with_config(&contents, &config, estimator));
695 }
696 self.count_tokens_with_config(model, contents, config).await
697 }
698
    /// Generates images for `prompt` via the model's `predict` endpoint.
    ///
    /// `config.http_options`, when present, contributes extra body fields and
    /// per-request HTTP settings.
    ///
    /// # Errors
    ///
    /// Returns `Error::ApiError` for non-success HTTP statuses, plus any
    /// body-building or transport errors.
    pub async fn generate_images(
        &self,
        model: impl Into<String>,
        prompt: impl Into<String>,
        mut config: GenerateImagesConfig,
    ) -> Result<GenerateImagesResponse> {
        // Pull the per-call HTTP options out of the config before building
        // the request body from it.
        let http_options = config.http_options.take();
        let model = model.into();
        let prompt = prompt.into();
        let mut body = build_generate_images_body(self.inner.config.backend, &prompt, &config)?;
        if let Some(options) = http_options.as_ref() {
            merge_extra_body(&mut body, options)?;
        }
        let url = build_model_method_url(&self.inner, &model, "predict")?;

        let mut request = self.inner.http.post(url).json(&body);
        request = apply_http_options(request, http_options.as_ref())?;

        let response = self.inner.send(request).await?;
        if !response.status().is_success() {
            return Err(Error::ApiError {
                status: response.status().as_u16(),
                message: response.text().await.unwrap_or_default(),
            });
        }

        let value = response.json::<Value>().await?;
        Ok(parse_generate_images_response(&value))
    }
733
    /// Edits an image guided by `prompt` and `reference_images` via the
    /// `predict` endpoint. Only supported on the Vertex AI backend.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidConfig` on non-Vertex backends,
    /// `Error::ApiError` for non-success HTTP statuses, plus any
    /// body-building or transport errors.
    pub async fn edit_image(
        &self,
        model: impl Into<String>,
        prompt: impl Into<String>,
        reference_images: Vec<ReferenceImage>,
        mut config: EditImageConfig,
    ) -> Result<EditImageResponse> {
        if self.inner.config.backend != Backend::VertexAi {
            return Err(Error::InvalidConfig {
                message: "Edit image is only supported in Vertex AI backend".into(),
            });
        }

        // Pull the per-call HTTP options out of the config before building
        // the request body from it.
        let http_options = config.http_options.take();
        let model = model.into();
        let prompt = prompt.into();
        let mut body = build_edit_image_body(&prompt, &reference_images, &config)?;
        if let Some(options) = http_options.as_ref() {
            merge_extra_body(&mut body, options)?;
        }
        let url = build_model_method_url(&self.inner, &model, "predict")?;

        let mut request = self.inner.http.post(url).json(&body);
        request = apply_http_options(request, http_options.as_ref())?;

        let response = self.inner.send(request).await?;
        if !response.status().is_success() {
            return Err(Error::ApiError {
                status: response.status().as_u16(),
                message: response.text().await.unwrap_or_default(),
            });
        }

        let value = response.json::<Value>().await?;
        Ok(parse_edit_image_response(&value))
    }
775
    /// Upscales `image` by `upscale_factor` via the `predict` endpoint.
    /// Only supported on the Vertex AI backend.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidConfig` on non-Vertex backends,
    /// `Error::ApiError` for non-success HTTP statuses, plus any
    /// body-building or transport errors.
    pub async fn upscale_image(
        &self,
        model: impl Into<String>,
        image: Image,
        upscale_factor: impl Into<String>,
        mut config: rust_genai_types::models::UpscaleImageConfig,
    ) -> Result<rust_genai_types::models::UpscaleImageResponse> {
        if self.inner.config.backend != Backend::VertexAi {
            return Err(Error::InvalidConfig {
                message: "Upscale image is only supported in Vertex AI backend".into(),
            });
        }

        // Pull the per-call HTTP options out of the config before building
        // the request body from it.
        let http_options = config.http_options.take();
        let model = model.into();
        let upscale_factor = upscale_factor.into();
        let mut body = build_upscale_image_body(&image, &upscale_factor, &config)?;
        if let Some(options) = http_options.as_ref() {
            merge_extra_body(&mut body, options)?;
        }
        let url = build_model_method_url(&self.inner, &model, "predict")?;

        let mut request = self.inner.http.post(url).json(&body);
        request = apply_http_options(request, http_options.as_ref())?;

        let response = self.inner.send(request).await?;
        if !response.status().is_success() {
            return Err(Error::ApiError {
                status: response.status().as_u16(),
                message: response.text().await.unwrap_or_default(),
            });
        }

        let value = response.json::<Value>().await?;
        Ok(parse_upscale_image_response(&value))
    }
817
    /// Recontextualizes an image from `source` via the `predict` endpoint.
    /// Only supported on the Vertex AI backend.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidConfig` on non-Vertex backends,
    /// `Error::ApiError` for non-success HTTP statuses, plus any
    /// body-building or transport errors.
    pub async fn recontext_image(
        &self,
        model: impl Into<String>,
        source: RecontextImageSource,
        mut config: RecontextImageConfig,
    ) -> Result<RecontextImageResponse> {
        if self.inner.config.backend != Backend::VertexAi {
            return Err(Error::InvalidConfig {
                message: "Recontext image is only supported in Vertex AI backend".into(),
            });
        }

        // Pull the per-call HTTP options out of the config before building
        // the request body from it.
        let http_options = config.http_options.take();
        let model = model.into();
        let mut body = build_recontext_image_body(&source, &config)?;
        if let Some(options) = http_options.as_ref() {
            merge_extra_body(&mut body, options)?;
        }
        let url = build_model_method_url(&self.inner, &model, "predict")?;

        let mut request = self.inner.http.post(url).json(&body);
        request = apply_http_options(request, http_options.as_ref())?;

        let response = self.inner.send(request).await?;
        if !response.status().is_success() {
            return Err(Error::ApiError {
                status: response.status().as_u16(),
                message: response.text().await.unwrap_or_default(),
            });
        }

        let value = response.json::<Value>().await?;
        Ok(parse_recontext_image_response(&value))
    }
857
    /// Segments an image from `source` via the `predict` endpoint.
    /// Only supported on the Vertex AI backend.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidConfig` on non-Vertex backends,
    /// `Error::ApiError` for non-success HTTP statuses, plus any
    /// body-building or transport errors.
    pub async fn segment_image(
        &self,
        model: impl Into<String>,
        source: SegmentImageSource,
        mut config: SegmentImageConfig,
    ) -> Result<SegmentImageResponse> {
        if self.inner.config.backend != Backend::VertexAi {
            return Err(Error::InvalidConfig {
                message: "Segment image is only supported in Vertex AI backend".into(),
            });
        }

        // Pull the per-call HTTP options out of the config before building
        // the request body from it.
        let http_options = config.http_options.take();
        let model = model.into();
        let mut body = build_segment_image_body(&source, &config)?;
        if let Some(options) = http_options.as_ref() {
            merge_extra_body(&mut body, options)?;
        }
        let url = build_model_method_url(&self.inner, &model, "predict")?;

        let mut request = self.inner.http.post(url).json(&body);
        request = apply_http_options(request, http_options.as_ref())?;

        let response = self.inner.send(request).await?;
        if !response.status().is_success() {
            return Err(Error::ApiError {
                status: response.status().as_u16(),
                message: response.text().await.unwrap_or_default(),
            });
        }

        let value = response.json::<Value>().await?;
        Ok(parse_segment_image_response(&value))
    }
897
    /// Starts a long-running video generation from `source` via the
    /// `predictLongRunning` endpoint, returning the operation handle to poll.
    ///
    /// # Errors
    ///
    /// Returns `Error::ApiError` for non-success HTTP statuses, plus any
    /// body-building, parsing, or transport errors.
    pub async fn generate_videos(
        &self,
        model: impl Into<String>,
        source: GenerateVideosSource,
        mut config: GenerateVideosConfig,
    ) -> Result<rust_genai_types::operations::Operation> {
        // Pull the per-call HTTP options out of the config before building
        // the request body from it.
        let http_options = config.http_options.take();
        let model = model.into();
        let mut body = build_generate_videos_body(self.inner.config.backend, &source, &config)?;
        if let Some(options) = http_options.as_ref() {
            merge_extra_body(&mut body, options)?;
        }
        let url = build_model_method_url(&self.inner, &model, "predictLongRunning")?;

        let mut request = self.inner.http.post(url).json(&body);
        request = apply_http_options(request, http_options.as_ref())?;

        let response = self.inner.send(request).await?;
        if !response.status().is_success() {
            return Err(Error::ApiError {
                status: response.status().as_u16(),
                message: response.text().await.unwrap_or_default(),
            });
        }

        let value = response.json::<Value>().await?;
        parse_generate_videos_operation(value, self.inner.config.backend)
    }
931
932 pub async fn generate_videos_with_prompt(
938 &self,
939 model: impl Into<String>,
940 prompt: impl Into<String>,
941 config: GenerateVideosConfig,
942 ) -> Result<rust_genai_types::operations::Operation> {
943 let source = GenerateVideosSource {
944 prompt: Some(prompt.into()),
945 ..GenerateVideosSource::default()
946 };
947 self.generate_videos(model, source, config).await
948 }
949
    /// Lists available models using default paging options.
    pub async fn list(&self) -> Result<ListModelsResponse> {
        self.list_with_config(ListModelsConfig::default()).await
    }
958
959 pub async fn list_with_config(&self, config: ListModelsConfig) -> Result<ListModelsResponse> {
965 let url = build_models_list_url(&self.inner, &config)?;
966 let request = self.inner.http.get(url);
967 let response = self.inner.send(request).await?;
968 if !response.status().is_success() {
969 return Err(Error::ApiError {
970 status: response.status().as_u16(),
971 message: response.text().await.unwrap_or_default(),
972 });
973 }
974 let result = response.json::<ListModelsResponse>().await?;
975 Ok(result)
976 }
977
    /// Fetches every model across all pages using default options.
    pub async fn all(&self) -> Result<Vec<Model>> {
        self.all_with_config(ListModelsConfig::default()).await
    }
986
987 pub async fn all_with_config(&self, mut config: ListModelsConfig) -> Result<Vec<Model>> {
993 let mut models = Vec::new();
994 loop {
995 let response = self.list_with_config(config.clone()).await?;
996 if let Some(items) = response.models {
997 models.extend(items);
998 }
999 match response.next_page_token {
1000 Some(token) if !token.is_empty() => {
1001 config.page_token = Some(token);
1002 }
1003 _ => break,
1004 }
1005 }
1006 Ok(models)
1007 }
1008
1009 pub async fn get(&self, model: impl Into<String>) -> Result<Model> {
1015 let url = build_model_get_url(&self.inner, &model.into())?;
1016 let request = self.inner.http.get(url);
1017 let response = self.inner.send(request).await?;
1018 if !response.status().is_success() {
1019 return Err(Error::ApiError {
1020 status: response.status().as_u16(),
1021 message: response.text().await.unwrap_or_default(),
1022 });
1023 }
1024 let result = response.json::<Model>().await?;
1025 Ok(result)
1026 }
1027
    /// Updates model metadata via HTTP PATCH with the fields in `config`.
    ///
    /// # Errors
    ///
    /// Returns `Error::ApiError` for non-success HTTP statuses, plus any
    /// serialization or transport errors.
    pub async fn update(
        &self,
        model: impl Into<String>,
        mut config: UpdateModelConfig,
    ) -> Result<Model> {
        // Pull the per-call HTTP options out of `config` before it is
        // serialized as the PATCH body below.
        let http_options = config.http_options.take();
        let url =
            build_model_get_url_with_options(&self.inner, &model.into(), http_options.as_ref())?;

        let mut body = serde_json::to_value(&config)?;
        if let Some(options) = http_options.as_ref() {
            merge_extra_body(&mut body, options)?;
        }
        let mut request = self.inner.http.patch(url).json(&body);
        request = apply_http_options(request, http_options.as_ref())?;

        let response = self.inner.send(request).await?;
        if !response.status().is_success() {
            return Err(Error::ApiError {
                status: response.status().as_u16(),
                message: response.text().await.unwrap_or_default(),
            });
        }
        Ok(response.json::<Model>().await?)
    }
1058
    /// Deletes a model.
    ///
    /// An empty response body is mapped to `DeleteModelResponse::default()`;
    /// a body that fails to parse is likewise mapped to the default
    /// (best-effort decode).
    ///
    /// # Errors
    ///
    /// Returns `Error::ApiError` for non-success HTTP statuses, plus any
    /// URL-building or transport errors.
    pub async fn delete(
        &self,
        model: impl Into<String>,
        mut config: DeleteModelConfig,
    ) -> Result<DeleteModelResponse> {
        let http_options = config.http_options.take();
        let url =
            build_model_get_url_with_options(&self.inner, &model.into(), http_options.as_ref())?;

        let mut request = self.inner.http.delete(url);
        request = apply_http_options(request, http_options.as_ref())?;

        let response = self.inner.send(request).await?;
        if !response.status().is_success() {
            return Err(Error::ApiError {
                status: response.status().as_u16(),
                message: response.text().await.unwrap_or_default(),
            });
        }
        // Some responses carry no body at all; short-circuit to the default.
        if response.content_length().unwrap_or(0) == 0 {
            return Ok(DeleteModelResponse::default());
        }
        Ok(response
            .json::<DeleteModelResponse>()
            .await
            .unwrap_or_default())
    }
1091}
1092
1093#[cfg(test)]
1094mod tests;