zai_rs/model/gen_image/
data.rs1use super::super::traits::*;
2use super::image_request::{ImageGenBody, ImageQuality, ImageSize};
3use crate::client::http::HttpClient;
4use serde::Serialize;
5use validator::Validate;
6
7pub struct ImageGenRequest<N>
10where
11 N: ModelName + ImageGen + Serialize,
12{
13 pub key: String,
15 body: ImageGenBody<N>,
17}
18
19impl<N> ImageGenRequest<N>
20where
21 N: ModelName + ImageGen + Serialize,
22{
23 pub fn new(model: N, key: String) -> Self {
25 let body = ImageGenBody {
26 model,
27 prompt: None,
28 quality: None,
29 size: None,
30 watermark_enabled: None,
31 user_id: None,
32 };
33 Self { key, body }
34 }
35
36 pub fn body_mut(&mut self) -> &mut ImageGenBody<N> {
38 &mut self.body
39 }
40
41 pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
43 self.body.prompt = Some(prompt.into());
44 self
45 }
46
47 pub fn with_quality(mut self, quality: ImageQuality) -> Self {
49 self.body.quality = Some(quality);
50 self
51 }
52
53 pub fn with_size(mut self, size: ImageSize) -> Self {
55 self.body.size = Some(size);
56 self
57 }
58
59 pub fn with_watermark_enabled(mut self, watermark_enabled: bool) -> Self {
61 self.body.watermark_enabled = Some(watermark_enabled);
62 self
63 }
64
65 pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
67 self.body.user_id = Some(user_id.into());
68 self
69 }
70
71 pub fn validate(&self) -> anyhow::Result<()> {
72 self.body.validate().map_err(|e| anyhow::anyhow!(e))?;
74 if self
76 .body
77 .prompt
78 .as_deref()
79 .map(|s| s.trim().is_empty())
80 .unwrap_or(true)
81 {
82 return Err(anyhow::anyhow!("prompt is required"));
83 }
84 if let Some(size) = &self.body.size {
86 if let super::image_request::ImageSize::Custom { .. } = size {
87 if !size.is_valid() {
88 return Err(anyhow::anyhow!(
89 "invalid custom image size: must be 512..=2048, divisible by 16, and <= 2^21 pixels"
90 ));
91 }
92 }
93 }
94 Ok(())
95 }
96
97 pub async fn send(&self) -> anyhow::Result<super::image_response::ImageResponse> {
98 self.validate()?;
99 let resp = self.post().await?;
100 let parsed = resp.json::<super::image_response::ImageResponse>().await?;
101 Ok(parsed)
102 }
103}
104
105impl<N> HttpClient for ImageGenRequest<N>
106where
107 N: ModelName + ImageGen + Serialize,
108{
109 type Body = ImageGenBody<N>;
110 type ApiUrl = &'static str;
111 type ApiKey = String;
112
113 fn api_url(&self) -> &Self::ApiUrl {
114 &"https://open.bigmodel.cn/api/paas/v4/images/generations"
115 }
116
117 fn api_key(&self) -> &Self::ApiKey {
118 &self.key
119 }
120
121 fn body(&self) -> &Self::Body {
122 &self.body
123 }
124}