// File: syncable_cli/bedrock/image.rs

use super::client::Client;
use super::types::errors::AwsSdkInvokeModelError;
use super::types::text_to_image::{TextToImageGeneration, TextToImageResponse};
use aws_smithy_types::Blob;
use rig::image_generation::{
    self, ImageGenerationError, ImageGenerationRequest, ImageGenerationResponse,
};

/// Bedrock model id for Amazon Titan Image Generator, version 1
/// (`amazon.titan-image-generator-v1`).
pub const AMAZON_TITAN_IMAGE_GENERATOR_V1: &str = "amazon.titan-image-generator-v1";

/// Bedrock model id for Amazon Titan Image Generator, version 2.0
/// (`amazon.titan-image-generator-v2:0`).
pub const AMAZON_TITAN_IMAGE_GENERATOR_V2_0: &str = "amazon.titan-image-generator-v2:0";

/// Bedrock model id for Amazon Nova Canvas, version 1.0
/// (`amazon.nova-canvas-v1:0`).
pub const AMAZON_NOVA_CANVAS: &str = "amazon.nova-canvas-v1:0";

16#[derive(Clone)]
17pub struct ImageGenerationModel {
18    pub(crate) client: Client,
19    pub model: String,
20}
21
22impl ImageGenerationModel {
23    pub fn new(client: Client, model: impl Into<String>) -> Self {
24        Self {
25            client,
26            model: model.into(),
27        }
28    }
29}
30
31impl image_generation::ImageGenerationModel for ImageGenerationModel {
32    type Response = TextToImageResponse;
33
34    type Client = Client;
35
36    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
37        Self::new(client.clone(), model)
38    }
39
40    async fn image_generation(
41        &self,
42        generation_request: ImageGenerationRequest,
43    ) -> Result<ImageGenerationResponse<Self::Response>, ImageGenerationError> {
44        let mut request = TextToImageGeneration::new(generation_request.prompt);
45        request.width(generation_request.width);
46        request.height(generation_request.height);
47
48        let body = serde_json::to_string(&request)?;
49        let model_response = self
50            .client
51            .get_inner()
52            .await
53            .invoke_model()
54            .model_id(self.model.as_str())
55            .content_type("application/json")
56            .accept("application/json")
57            .body(Blob::new(body))
58            .send()
59            .await
60            .map_err(|sdk_error| {
61                Into::<ImageGenerationError>::into(AwsSdkInvokeModelError(sdk_error))
62            })?;
63
64        let response_str = String::from_utf8(model_response.body.into_inner())
65            .map_err(|e| ImageGenerationError::ResponseError(e.to_string()))?;
66
67        let result: TextToImageResponse = serde_json::from_str(&response_str)
68            .map_err(|e| ImageGenerationError::ResponseError(e.to_string()))?;
69
70        result.try_into()
71    }
72}