use crate::openai::{
endpoint::{
request_endpoint, request_endpoint_form_data, Endpoint, EndpointVariant,
ImageEndpointVariant,
},
types::{
common::Error,
image::{Format, ImageResponse, Size},
},
};
use log::{debug, warn};
use reqwest::multipart::{Form, Part};
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
/// Request builder for the OpenAI image endpoints.
///
/// Fields are filled in with the chainable setters on `Image` and the request
/// is sent with `generation`, `editing` or `variation`. Only the non-file
/// fields are serialized into the JSON body; the file fields are uploaded as
/// multipart parts by the edit/variation requests.
#[serde_as]
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct Image {
    /// Source image as `(file name, bytes)`. Never serialized; sent as the
    /// `image` multipart part by `editing` and `variation`.
    #[serde(skip)]
    pub image: Option<(String, Vec<u8>)>,
    /// Optional mask as `(file name, bytes)`. Never serialized; sent as the
    /// `mask` multipart part by `editing`.
    #[serde(skip)]
    pub mask: Option<(String, Vec<u8>)>,
    /// Text prompt; required by `generation` and `editing`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<String>,
    /// Sent as the `n` field when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    /// Sent as the `size` field when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub size: Option<Size>,
    /// Sent as the `response_format` field when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<Format>,
    /// Sent as the `user` field when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}
impl Image {
fn mime(&self, file_name: &str) -> Result<&'static str, Box<dyn std::error::Error>> {
Ok(
match *file_name.split('.').collect::<Vec<&str>>().last().unwrap() {
"png" => "image/png",
_ => return Err("Unsupported format!".into()),
},
)
}
pub fn image(self, filename: &str, bytes: Vec<u8>) -> Self {
Self {
image: Some((filename.into(), bytes.clone())),
..self
}
}
pub fn mask(self, filename: &str, bytes: Vec<u8>) -> Self {
Self {
mask: Some((filename.into(), bytes.clone())),
..self
}
}
pub fn prompt(self, content: &str) -> Self {
Self {
prompt: Some(content.into()),
..self
}
}
pub fn n(self, n: u32) -> Self {
Self { n: Some(n), ..self }
}
pub fn size(self, size: Size) -> Self {
Self {
size: Some(size),
..self
}
}
pub fn response_format(self, response_format: Format) -> Self {
Self {
response_format: Some(response_format),
..self
}
}
pub fn user(self, user: &str) -> Self {
Self {
user: Some(user.into()),
..self
}
}
pub async fn editing(&self) -> Result<ImageResponse, Box<dyn std::error::Error>> {
if let None = self.image {
return Err("`image` required, call `image()` first".into());
}
if let None = self.prompt {
return Err("`prompt` required, call `prompt()` first".into());
}
let mut image_response: Option<ImageResponse> = None;
let mut form = Form::new();
if let Some(image_tup) = self.image.as_ref() {
let image = Part::bytes(image_tup.1.clone())
.file_name(image_tup.0.clone())
.mime_str(self.mime(&image_tup.0).unwrap())
.unwrap();
form = form.part("image", image);
}
if let Some(mask_tup) = self.mask.as_ref() {
let mask = Part::bytes(mask_tup.1.clone())
.file_name(mask_tup.0.clone())
.mime_str(self.mime(&mask_tup.0).unwrap())
.unwrap();
form = form.part("mask", mask);
}
if let Some(prompt) = self.prompt.clone() {
form = form.part("prompt", Part::text(prompt));
}
if let Some(n) = self.n.clone() {
form = form.part("n", Part::text(n.to_string()));
}
if let Some(size) = self.size.clone() {
let size: &str = size.into();
form = form.part("size", Part::text(size));
}
if let Some(fmt) = self.response_format.clone() {
let fmt: &str = fmt.into();
form = form.part("response_format", Part::text(fmt));
}
if let Some(user) = self.user.clone() {
form = form.part("user", Part::text(user));
}
let variant: String = ImageEndpointVariant::Editing.into();
request_endpoint_form_data(form, &Endpoint::Image_v1, EndpointVariant::from(variant), |res| {
if let Ok(text) = res {
if let Ok(response_data) = serde_json::from_str::<ImageResponse>(&text) {
debug!(target: "openai", "Response parsed, image edit response deserialized.");
image_response = Some(response_data);
} else {
if let Ok(response_error) = serde_json::from_str::<Error>(&text) {
warn!(target: "openai",
"OpenAI error code {}: `{:?}`",
response_error.error.code.unwrap_or(0),
text
);
} else {
warn!(target: "openai", "Image response not deserializable.");
}
}
}
})
.await?;
if let Some(response_data) = image_response {
Ok(response_data)
} else {
Err("No response or error parsing response".into())
}
}
pub async fn generation(&mut self) -> Result<ImageResponse, Box<dyn std::error::Error>> {
self.image = None;
self.mask = None;
if let None = self.prompt {
return Err("`prompt` required, call `prompt()` first".into());
}
let mut image_response: Option<ImageResponse> = None;
let variant: String = ImageEndpointVariant::Generation.into();
request_endpoint(&self, &Endpoint::Image_v1, EndpointVariant::from(variant), |res| {
if let Ok(text) = res {
if let Ok(response_data) = serde_json::from_str::<ImageResponse>(&text) {
debug!(target: "openai", "Response parsed, image generation response deserialized.");
image_response = Some(response_data);
} else {
if let Ok(response_error) = serde_json::from_str::<Error>(&text) {
warn!(target: "openai",
"OpenAI error code {}: `{:?}`",
response_error.error.code.unwrap_or(0),
text
);
} else {
warn!(target: "openai", "Image response not deserializable.");
}
}
}
})
.await?;
if let Some(response_data) = image_response {
Ok(response_data)
} else {
Err("No response or error parsing response".into())
}
}
pub async fn variation(&mut self) -> Result<ImageResponse, Box<dyn std::error::Error>> {
if let None = self.image {
return Err("`image` required, call `image()` first".into());
}
self.prompt = None;
self.mask = None;
let mut image_response: Option<ImageResponse> = None;
let mut form = Form::new();
if let Some(image_tup) = self.image.as_ref() {
let image = Part::bytes(image_tup.1.clone())
.file_name(image_tup.0.clone())
.mime_str(self.mime(&image_tup.0).unwrap())
.unwrap();
form = form.part("image", image);
}
if let Some(n) = self.n.clone() {
form = form.part("n", Part::text(n.to_string()));
}
if let Some(size) = self.size.clone() {
let size: &str = size.into();
form = form.part("size", Part::text(size));
}
if let Some(fmt) = self.response_format.clone() {
let fmt: &str = fmt.into();
form = form.part("response_format", Part::text(fmt));
}
if let Some(user) = self.user.clone() {
form = form.part("user", Part::text(user));
}
let variant: String = ImageEndpointVariant::Variation.into();
request_endpoint_form_data(form, &Endpoint::Image_v1, EndpointVariant::from(variant), |res| {
if let Ok(text) = res {
if let Ok(response_data) = serde_json::from_str::<ImageResponse>(&text) {
debug!(target: "openai", "Response parsed, image variation response deserialized.");
image_response = Some(response_data);
} else {
if let Ok(response_error) = serde_json::from_str::<Error>(&text) {
warn!(target: "openai",
"OpenAI error code {}: `{:?}`",
response_error.error.code.unwrap_or(0),
text
);
} else {
warn!(target: "openai", "Image response not deserializable.");
}
}
}
})
.await?;
if let Some(response_data) = image_response {
Ok(response_data)
} else {
Err("No response or error parsing response".into())
}
}
}