//!
//! Given an input text, outputs whether the model classifies it as violating
//! OpenAI's content policy.
//!
//! Related guide: [Moderations](https://platform.openai.com/docs/guides/moderation)
//!
//! Source: OpenAI documentation

////////////////////////////////////////////////////////////////////////////////

use crate::openai::{
    endpoint::{endpoint_filter, request_endpoint, Endpoint, EndpointVariant},
    types::{common::Error, model::Model, moderation::ModerationResponse},
};
use log::{debug, warn};
use serde::{Deserialize, Serialize};
use serde_with::serde_as;

/// Given an input text, outputs whether the model classifies it as violating
/// OpenAI's content policy.
#[serde_as]
#[derive(Serialize, Deserialize, Debug)]
pub struct Moderation {
    pub model: Model,

    pub input: Vec<String>,
}

impl Default for Moderation {
    fn default() -> Self {
        Self {
            model: Model::TEXT_MODERATION_LATEST,
            input: vec![],
        }
    }
}

impl Moderation {
    /// The input text to classify; call this repeatedly to provide multiple inputs.
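    ///
    /// A minimal sketch of chaining multiple inputs (the method consumes and
    /// returns `Self`, so calls can be chained):
    ///
    /// ```ignore
    /// let request = Moderation::default()
    ///     .input("first text to classify")
    ///     .input("second text to classify");
    /// assert_eq!(request.input.len(), 2);
    /// ```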
    pub fn input(mut self, content: &str) -> Self {
        // Append to any inputs already collected on this request.
        self.input.push(String::from(content));
        self
    }

    /// Two content moderation models are available: `text-moderation-stable`
    /// and `text-moderation-latest`.
    ///
    /// The default is `text-moderation-latest`, which will be automatically
    /// upgraded over time. This ensures you are always using the most accurate
    /// model. If you use `text-moderation-stable`, OpenAI will provide advance
    /// notice before updating the model. Accuracy of `text-moderation-stable`
    /// may be slightly lower than that of `text-moderation-latest`.
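    ///
    /// For example, pinning the request to the stable model (assuming the
    /// crate exposes a `Model::TEXT_MODERATION_STABLE` constant alongside
    /// `TEXT_MODERATION_LATEST`):
    ///
    /// ```ignore
    /// let request = Moderation::default()
    ///     .model(Model::TEXT_MODERATION_STABLE)
    ///     .input("text to classify");
    /// ```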
    pub fn model(self, model: Model) -> Self {
        Self { model, ..self }
    }

    /// Sends the moderation request to OpenAI and deserializes the response.
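    ///
    /// A minimal end-to-end sketch, assuming the surrounding crate is already
    /// configured with a valid OpenAI API key for `request_endpoint`:
    ///
    /// ```ignore
    /// let response = Moderation::default()
    ///     .input("some user supplied text")
    ///     .moderate()
    ///     .await?;
    /// // `response` is a deserialized `ModerationResponse`; inspect its
    /// // per-category results to decide how to handle the input.
    /// ```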
    pub async fn moderate(&self) -> Result<ModerationResponse, Box<dyn std::error::Error>> {
        if !endpoint_filter(&self.model, &Endpoint::Moderation_v1) {
            return Err("Model not compatible with this endpoint".into());
        }

        let mut moderation_response: Option<ModerationResponse> = None;

        request_endpoint(&self, &Endpoint::Moderation_v1, EndpointVariant::None, |res| {
            if let Ok(text) = res {
                if let Ok(response_data) = serde_json::from_str::<ModerationResponse>(&text) {
                    debug!(target: "openai", "Response parsed, moderation response deserialized.");
                    moderation_response = Some(response_data);
                } else if let Ok(response_error) = serde_json::from_str::<Error>(&text) {
                    warn!(target: "openai",
                        "OpenAI error code {}: `{:?}`",
                        response_error.error.code.unwrap_or(0),
                        text
                    );
                } else {
                    warn!(target: "openai", "Moderation response not deserializable.");
                }
            }
        })
        .await?;

        if let Some(response_data) = moderation_response {
            Ok(response_data)
        } else {
            Err("No response or error parsing response".into())
        }
    }
}