//!
//! Given an input text, outputs whether the model classifies it as violating
//! OpenAI's content policy.
//!
//! Related guide: [Moderations](https://platform.openai.com/docs/guides/moderation)
//!
//! Source: OpenAI documentation
////////////////////////////////////////////////////////////////////////////////
use crate::openai::{
endpoint::{endpoint_filter, request_endpoint, Endpoint, EndpointVariant},
types::{common::Error, model::Model, moderation::ModerationResponse},
};
use log::{debug, warn};
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
/// Given an input text, outputs whether the model classifies it as violating
/// OpenAI's content policy.
#[serde_as]
#[derive(Serialize, Deserialize, Debug)]
pub struct Moderation {
pub model: Model,
pub input: Vec<String>,
}
impl Default for Moderation {
fn default() -> Self {
Self {
model: Model::TEXT_MODERATION_LATEST,
input: vec![],
}
}
}
impl Moderation {
/// Appends an input text to classify. Call repeatedly to classify multiple
/// texts in one request.
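/**
A minimal, self-contained sketch of the consuming-builder pattern used by
`input`, assuming a hypothetical `Request` stand-in that keeps only the
`input` field (plus a hypothetical `two_inputs` helper) so it compiles
without this crate:

```rust
/// Hypothetical stand-in for `Moderation`, reduced to the `input` field.
#[derive(Debug, Default)]
struct Request {
    input: Vec<String>,
}

impl Request {
    /// Takes the builder by value, appends one text, and returns the builder
    /// so that calls can be chained.
    fn input(mut self, content: &str) -> Self {
        self.input.push(content.to_string());
        self
    }
}

/// Hypothetical helper: builds a request with two inputs by chaining calls.
fn two_inputs() -> Request {
    Request::default().input("first text").input("second text")
}
```

Each call moves the builder in and back out, so chained calls accumulate
texts in call order without cloning the vector.
*/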
pub fn input(mut self, content: &str) -> Self {
// Append to the existing inputs so repeated calls accumulate.
self.input.push(String::from(content));
self
}
/// Two content moderation models are available: `text-moderation-stable`
/// and `text-moderation-latest`.
///
/// The default is `text-moderation-latest`, which is automatically
/// upgraded over time. This ensures you are always using our most accurate
/// model. If you use `text-moderation-stable`, we will provide advance
/// notice before updating the model. Accuracy of `text-moderation-stable`
/// may be slightly lower than that of `text-moderation-latest`.
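/**
A minimal sketch of how the default model and an explicit override
interact, assuming a hypothetical `ModerationModel` enum and `Request`
struct in place of this crate's `Model` and `Moderation`:

```rust
/// Hypothetical stand-in for the two moderation models named above.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModerationModel {
    Stable,
    Latest,
}

impl ModerationModel {
    /// Model identifier string as it would be sent to the API.
    fn id(self) -> &'static str {
        match self {
            ModerationModel::Stable => "text-moderation-stable",
            ModerationModel::Latest => "text-moderation-latest",
        }
    }
}

/// Hypothetical stand-in for `Moderation`.
#[derive(Debug)]
struct Request {
    model: ModerationModel,
    input: Vec<String>,
}

impl Default for Request {
    /// Mirrors the documented default: the auto-upgrading `latest` model.
    fn default() -> Self {
        Request { model: ModerationModel::Latest, input: Vec::new() }
    }
}

impl Request {
    /// Replaces only the model; struct update syntax keeps the other fields.
    fn model(self, model: ModerationModel) -> Self {
        Request { model, ..self }
    }
}
```

Because `..self` moves the remaining fields across unchanged, pinning
`text-moderation-stable` leaves any inputs already added intact.
*/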
pub fn model(self, model: Model) -> Self {
Self { model, ..self }
}
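/// Sends the moderation request to OpenAI.
///
/// Returns an error if the model is not compatible with the moderation
/// endpoint, if the request itself fails, or if no parsable moderation
/// response is received.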
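/**
A simplified sketch of the control flow in `moderate`: the callback stores
whatever it manages to parse in an `Option`, and a missing value becomes a
boxed error afterwards. `fake_request`, `Flagged`, and `moderate_sketch` are
hypothetical stand-ins, and parsing a bare `bool` replaces the real
`serde_json` deserialization of `ModerationResponse`:

```rust
use std::error::Error;

/// Hypothetical stand-in for `request_endpoint`: it hands the raw response
/// body to a callback instead of returning it.
fn fake_request<F: FnMut(Result<String, String>)>(body: &str, mut on_response: F) {
    on_response(Ok(body.to_string()));
}

/// Hypothetical parsed response type.
#[derive(Debug, PartialEq)]
struct Flagged(bool);

/// Captures the parsed value from inside the callback, then converts
/// "nothing parsed" into an error, mirroring the shape of `moderate`.
fn moderate_sketch(body: &str) -> Result<Flagged, Box<dyn Error>> {
    let mut parsed: Option<Flagged> = None;
    fake_request(body, |res| {
        if let Ok(text) = res {
            // Stand-in for `serde_json::from_str`: accept only "true"/"false".
            if let Ok(flag) = text.trim().parse::<bool>() {
                parsed = Some(Flagged(flag));
            }
        }
    });
    parsed.ok_or_else(|| "No response or error parsing response".into())
}
```

`Option::ok_or_else` is an equivalent, more compact way to write the final
`if let Some(...) ... else Err(...)` conversion.
*/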
pub async fn moderate(&self) -> Result<ModerationResponse, Box<dyn std::error::Error>> {
if !endpoint_filter(&self.model, &Endpoint::Moderation_v1) {
return Err("Model not compatible with this endpoint".into());
}
let mut moderation_response: Option<ModerationResponse> = None;
request_endpoint(&self, &Endpoint::Moderation_v1, EndpointVariant::None, |res| {
if let Ok(text) = res {
if let Ok(response_data) = serde_json::from_str::<ModerationResponse>(&text) {
debug!(target: "openai", "Response parsed, moderation response deserialized.");
moderation_response = Some(response_data);
} else if let Ok(response_error) = serde_json::from_str::<Error>(&text) {
warn!(target: "openai",
"OpenAI error code {}: `{:?}`",
response_error.error.code.unwrap_or(0),
text
);
} else {
warn!(target: "openai", "Moderation response not deserializable.");
}
}
})
.await?;
if let Some(response_data) = moderation_response {
Ok(response_data)
} else {
Err("No response or error parsing response".into())
}
}
}