// zai_rs/model/gen_image/data.rs

1use super::super::traits::*;
2use super::image_request::{ImageGenBody, ImageQuality, ImageSize};
3use crate::client::http::HttpClient;
4use serde::Serialize;
5use validator::Validate;
6
7/// Image generation request structure
8/// Provides a typed builder around the image generation API body
9pub struct ImageGenRequest<N>
10where
11    N: ModelName + ImageGen + Serialize,
12{
13    /// API key for authentication
14    pub key: String,
15    /// Request body
16    body: ImageGenBody<N>,
17}
18
19impl<N> ImageGenRequest<N>
20where
21    N: ModelName + ImageGen + Serialize,
22{
23    /// Create a new image generation request for the given model and API key
24    pub fn new(model: N, key: String) -> Self {
25        let body = ImageGenBody {
26            model,
27            prompt: None,
28            quality: None,
29            size: None,
30            watermark_enabled: None,
31            user_id: None,
32        };
33        Self { key, body }
34    }
35
36    /// Mutable access to inner body (for advanced customizations)
37    pub fn body_mut(&mut self) -> &mut ImageGenBody<N> {
38        &mut self.body
39    }
40
41    /// Set prompt text
42    pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
43        self.body.prompt = Some(prompt.into());
44        self
45    }
46
47    /// Set image quality
48    pub fn with_quality(mut self, quality: ImageQuality) -> Self {
49        self.body.quality = Some(quality);
50        self
51    }
52
53    /// Set image size
54    pub fn with_size(mut self, size: ImageSize) -> Self {
55        self.body.size = Some(size);
56        self
57    }
58
59    /// Enable/disable watermark
60    pub fn with_watermark_enabled(mut self, watermark_enabled: bool) -> Self {
61        self.body.watermark_enabled = Some(watermark_enabled);
62        self
63    }
64
65    /// Set user id
66    pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
67        self.body.user_id = Some(user_id.into());
68        self
69    }
70
71    pub fn validate(&self) -> anyhow::Result<()> {
72        // Body-level field validations
73        self.body.validate().map_err(|e| anyhow::anyhow!(e))?;
74        // Require prompt
75        if self
76            .body
77            .prompt
78            .as_deref()
79            .map(|s| s.trim().is_empty())
80            .unwrap_or(true)
81        {
82            return Err(anyhow::anyhow!("prompt is required"));
83        }
84        // Validate custom size when present
85        if let Some(size) = &self.body.size {
86            if let super::image_request::ImageSize::Custom { .. } = size {
87                if !size.is_valid() {
88                    return Err(anyhow::anyhow!(
89                        "invalid custom image size: must be 512..=2048, divisible by 16, and <= 2^21 pixels"
90                    ));
91                }
92            }
93        }
94        Ok(())
95    }
96
97    pub async fn send(&self) -> anyhow::Result<super::image_response::ImageResponse> {
98        self.validate()?;
99        let resp = self.post().await?;
100        let parsed = resp.json::<super::image_response::ImageResponse>().await?;
101        Ok(parsed)
102    }
103}
104
105impl<N> HttpClient for ImageGenRequest<N>
106where
107    N: ModelName + ImageGen + Serialize,
108{
109    type Body = ImageGenBody<N>;
110    type ApiUrl = &'static str;
111    type ApiKey = String;
112
113    fn api_url(&self) -> &Self::ApiUrl {
114        &"https://open.bigmodel.cn/api/paas/v4/images/generations"
115    }
116
117    fn api_key(&self) -> &Self::ApiKey {
118        &self.key
119    }
120
121    fn body(&self) -> &Self::Body {
122        &self.body
123    }
124}