tame_webpurify/
client.rs

1use http::header::CONTENT_TYPE;
2use http::{Request, Response, Uri};
3use url::form_urlencoded;
4
5#[derive(thiserror::Error, Debug)]
6#[allow(clippy::upper_case_acronyms)]
7pub enum RequestError {
8    #[error("The provided Uri was invalid")]
9    InvalidUri,
10
11    #[error(transparent)]
12    HTTP(#[from] http::Error),
13}
14
15#[derive(thiserror::Error, Debug)]
16#[allow(clippy::upper_case_acronyms)]
17pub enum ResponseError {
18    #[error("The response status was invalid: {0}")]
19    HttpStatus(http::StatusCode),
20    #[error(transparent)]
21    Deserialize(#[from] serde_json::Error),
22    #[error("Missing field {0} in response")]
23    MissingField(String),
24    #[error("Invalid field {0} in response")]
25    InvalidField(String),
26    #[error("The API key passed was not valid: {0}")]
27    InvalidApiKey(String),
28    #[error("The API key passed is inactive or has been revoked: {0}")]
29    InactiveApiKey(String),
30    #[error("API Key was not included in request: {0}")]
31    MissingApiKey(String),
32    #[error("The requested service is temporarily unavailable: {0}")]
33    ServiceUnavailable(String),
34    #[error("Unknown error code returned: {0} {1}")]
35    UnknownErr(String, String),
36    #[error("Non ok stat returned: {0}")]
37    NonOkStat(String),
38    #[error("Got mis matched method in response, got: {0} expected: {1}")]
39    MisMatchedMethod(String, String),
40}
41
/// Selects which regional `WebPurify` endpoint a request is sent to.
///
/// The concrete host for each variant is chosen by `api_url_by_region`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Region {
    /// The `api1-eu.webpurify.com` endpoint.
    Europe,
    /// The `api1.webpurify.com` endpoint.
    Us,
    /// The `api1-ap.webpurify.com` endpoint.
    Asia,
    /// The `es-api.webpurify.net` endpoint.
    Es,
}
49
/// The `WebPurify` API method a request invokes.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Method {
    /// `webpurify.live.check`
    Check,
    /// `webpurify.live.replace`; the payload is the symbol sent as the
    /// `replacesymbol` request parameter.
    Replace(String),
}
57
58impl Method {
59    fn method_str(&self) -> &'static str {
60        match self {
61            Method::Check => "webpurify.live.check",
62            Method::Replace(_) => "webpurify.live.replace",
63        }
64    }
65}
66
67fn api_url_by_region(region: Region) -> String {
68    match region {
69        Region::Us => "https://api1.webpurify.com/services/rest/",
70        Region::Europe => "https://api1-eu.webpurify.com/services/rest/",
71        Region::Asia => "https://api1-ap.webpurify.com/services/rest/",
72        Region::Es => "https://es-api.webpurify.net/services/rest/",
73    }
74    .to_string()
75}
76
77/// method: Which method should we use on matched strings?
78///     Check - returns 1 if profanity is found, otherwise 0
79///     Replace - returns 1 if profanity if found and replaces
80pub fn query_string(api_key: &str, text: &str, method: Method) -> String {
81    let method_str = method.method_str();
82
83    let mut serializer = form_urlencoded::Serializer::new(String::new());
84    let qs = serializer
85        .append_pair("format", "json")
86        .append_pair("api_key", api_key)
87        .append_pair("text", text)
88        .append_pair("method", &method_str)
89        .append_pair("semail", "1")
90        .append_pair("slink", "1")
91        .append_pair("rsp", "1")
92        .append_pair("sphone", "1");
93
94    if let Method::Replace(replace_with) = method {
95        qs.append_pair("replacesymbol", &replace_with);
96    }
97
98    qs.finish()
99}
100
101pub(crate) fn into_uri<U: TryInto<Uri>>(uri: U) -> Result<Uri, RequestError> {
102    uri.try_into().map_err(|_err| RequestError::InvalidUri)
103}
104
105fn request_builder(api_uri: String) -> Result<Request<Vec<u8>>, RequestError> {
106    let request_builder = Request::builder()
107        .method("POST")
108        .uri(into_uri(api_uri)?)
109        .header(CONTENT_TYPE, "application/json");
110
111    let req = request_builder.body(vec![])?;
112    Ok(req)
113}
114
115/// `WebPurify` returns the number of matched profanities, PII etc.
116/// This function only returns a request object, you need to do the actual HTTP request yourself.
117///
118/// Extend the function when more languages are required.
119/// Documentation: <https://www.webpurify.com/documentation/additional/language/>
120///
121/// # Arguments
122///
123/// * `api_key` - a string slice that holds your `WebPurify` API Key
124///
125/// * `region` - the regional `WebPurify` API to use
126///
127/// * `text` - a string slice to be checked by `WebPurify`
128///
129/// # Examples
130/// ```
131/// use tame_webpurify::client;
132/// let res = client::profanity_check_request("some-api-key", client::Region::Europe, "my filthy user-input string");
133/// ```
134pub fn profanity_check_request(
135    api_key: &str,
136    region: Region,
137    text: &str,
138) -> Result<Request<Vec<u8>>, RequestError> {
139    let qs = query_string(api_key, text, Method::Check);
140    let api_uri = format!("{}?{}", api_url_by_region(region), qs);
141
142    let req = request_builder(api_uri)?;
143    Ok(req)
144}
145
146/// `WebPurify` replaces matched profanities, PII etc with a given symbol.
147/// This function only returns a request object, you need to do the actual HTTP request yourself.
148///
149/// Extend the function when more languages are required.
150/// Documentation: <https://www.webpurify.com/documentation/additional/language/>
151///
152/// # Arguments
153///
154/// * `api_key` - a string slice that holds your `WebPurify` API Key
155///
156/// * `region` - the regional `WebPurify` API to use
157///
158/// * `text` - a string slice you want to be moderated by `WebPurify`
159///
160/// * `replace_text` - a string slice to replace profanities in `text` with
161///
162/// # Examples
163/// ```
164/// use tame_webpurify::client;
165/// let res = client::profanity_replace_request("some-api-key", client::Region::Europe, "my filthy user-input string", "*");
166/// ```
167pub fn profanity_replace_request(
168    api_key: &str,
169    region: Region,
170    text: &str,
171    replace_text: &str,
172) -> Result<Request<Vec<u8>>, RequestError> {
173    let qs = query_string(api_key, text, Method::Replace(replace_text.to_string()));
174    let api_uri = format!("{}?{}", api_url_by_region(region), qs);
175
176    let req = request_builder(api_uri)?;
177    Ok(req)
178}
179
180#[derive(serde::Deserialize)]
181struct ApiResponse {
182    rsp: ApiResponseRsp,
183}
184
185#[derive(serde::Deserialize)]
186struct ApiResponseRsp {
187    #[serde(rename = "@attributes")]
188    attributes: ApiResponseRspAttributes,
189    err: Option<ApiResponseErr>,
190    method: Option<String>,
191    found: Option<String>,
192    text: Option<String>,
193}
194
195#[derive(serde::Deserialize)]
196struct ApiResponseRspAttributes {
197    stat: String,
198}
199
200#[derive(serde::Deserialize)]
201struct ApiResponseErr {
202    #[serde(rename = "@attributes")]
203    attributes: ApiResponseErrAttributes,
204}
205
206#[derive(serde::Deserialize)]
207struct ApiResponseErrAttributes {
208    code: String,
209    msg: String,
210}
211
212fn parse_response<T>(response: Response<T>, method: Method) -> Result<ApiResponse, ResponseError>
213where
214    T: AsRef<[u8]>,
215{
216    if !response.status().is_success() {
217        return Err(ResponseError::HttpStatus(response.status()));
218    }
219
220    let body = response.body();
221    let api_response: ApiResponse = serde_json::from_slice(body.as_ref())?;
222
223    if let Some(ApiResponseErr {
224        attributes: ApiResponseErrAttributes { code, msg },
225    }) = api_response.rsp.err
226    {
227        let err = match code.as_str() {
228            "100" => ResponseError::InvalidApiKey(msg),
229            "101" => ResponseError::InactiveApiKey(msg),
230            "102" => ResponseError::MissingApiKey(msg),
231            "103" => ResponseError::ServiceUnavailable(msg),
232            _ => ResponseError::UnknownErr(code, msg),
233        };
234        return Err(err);
235    }
236
237    if !api_response.rsp.attributes.stat.eq("ok") {
238        return Err(ResponseError::NonOkStat(api_response.rsp.attributes.stat));
239    }
240
241    if !api_response
242        .rsp
243        .method
244        .as_ref()
245        .map(|s| s.as_str())
246        .eq(&Some(method.method_str()))
247    {
248        return Err(ResponseError::MisMatchedMethod(
249            api_response.rsp.method.unwrap_or_default(),
250            method.method_str().to_owned(),
251        ));
252    }
253
254    Ok(api_response)
255}
256
257/// Returns true if `WebPurify` flagged a request to contain profanities, PII, etc
258///
259/// # Arguments
260///
261/// * `response` - a response object from the `WebPurify` `check` API call
262///
263pub fn profanity_check_result<T>(response: Response<T>) -> Result<bool, ResponseError>
264where
265    T: AsRef<[u8]>,
266{
267    let response = parse_response(response, Method::Check)?;
268
269    let check: u32 = response
270        .rsp
271        .found
272        .ok_or_else(|| ResponseError::MissingField("found".to_owned()))
273        .and_then(|found| {
274            found
275                .parse()
276                .map_err(|_err| ResponseError::InvalidField("found".to_owned()))
277        })?;
278
279    Ok(check > 0)
280}
281
282/// Returns the sanitized string from a response object.
283///
284/// # Arguments
285///
286/// * `response` - a response object from the `WebPurify` `replace` API call
287///
288pub fn profanity_replace_result<T>(response: Response<T>) -> Result<String, ResponseError>
289where
290    T: AsRef<[u8]>,
291{
292    let response = parse_response(response, Method::Replace("".to_owned()))?; // TODO It is inconvenient to pass in the replace char here
293
294    match response.rsp.text {
295        Some(text) => Ok(text),
296        None => Err(ResponseError::MissingField("text".to_owned())),
297    }
298}
299
#[cfg(test)]
mod test {
    use std::error::Error;

    use crate::client;
    use http::Request;
    use http::Response;
    use http::StatusCode;

    /// Successful `webpurify.live.replace` payload shared by several tests.
    const REPLACE_OK_BODY: &str = r#"{"rsp":{"@attributes":{"stat":"ok","rsp":"0.018898963928223"},"method":"webpurify.live.replace","format":"rest","found":"3","text":"foo","api_key":"123"}}"#;

    /// Wraps `body` in a `200 OK` response.
    fn ok_response(body: &str) -> Result<Response<Vec<u8>>, http::Error> {
        Response::builder()
            .status(StatusCode::OK)
            .body(body.as_bytes().to_vec())
    }

    /// Tells whether the URI of `req` contains `needle`.
    fn uri_contains(req: &Request<Vec<u8>>, needle: &str) -> bool {
        req.uri().to_string().contains(needle)
    }

    #[test]
    fn qs_encoding() {
        let qs = client::query_string("abcd", "hi there", client::Method::Check);
        assert_eq!(
            qs,
            "format=json&api_key=abcd&text=hi+there&method=webpurify.live.check&semail=1&slink=1&rsp=1&sphone=1"
        );
    }

    #[test]
    fn check_request() {
        let req = client::profanity_check_request("abcd", client::Region::Europe, "hi there")
            .unwrap();
        assert_eq!(
            req.uri(),
            "https://api1-eu.webpurify.com/services/rest/?format=json&api_key=abcd&text=hi+there&method=webpurify.live.check&semail=1&slink=1&rsp=1&sphone=1"
        );
    }

    #[test]
    fn check_result() -> Result<(), Box<dyn Error>> {
        // A positive count must be reported as flagged, zero as clean.
        for (found, expected) in [(3_u32, true), (0, false)] {
            let body = format!(
                r#"{{"rsp":{{"@attributes":{{"stat":"ok","rsp":"0.0072040557861328"}},"method":"webpurify.live.check","format":"rest","found":"{found}","api_key":"123"}}}}"#
            );
            let flagged = client::profanity_check_result(ok_response(&body)?)?;
            assert_eq!(flagged, expected);
        }
        Ok(())
    }

    #[test]
    fn check_result_missing_found() -> Result<(), Box<dyn Error>> {
        let body = r#"{"rsp":{"@attributes":{"stat":"ok","rsp":"0.0072040557861328"},"method":"webpurify.live.check","format":"rest","api_key":"123"}}"#;
        let outcome = client::profanity_check_result(ok_response(body)?);
        assert!(outcome.is_err());
        Ok(())
    }

    #[test]
    fn replace_request() {
        let req =
            client::profanity_replace_request("abcd", client::Region::Europe, "hi there", "*")
                .unwrap();
        for needle in [
            "method=webpurify.live.replace",
            "replacesymbol=*",
            "text=hi+there",
        ] {
            assert!(uri_contains(&req, needle));
        }
    }

    #[test]
    fn replace_result() -> Result<(), Box<dyn Error>> {
        let sanitized = client::profanity_replace_result(ok_response(REPLACE_OK_BODY)?)?;

        assert_eq!(sanitized, "foo");
        Ok(())
    }

    #[test]
    fn replace_result_missing_found() -> Result<(), Box<dyn Error>> {
        // `found` is not needed to obtain the replaced text.
        let body = r#"{"rsp":{"@attributes":{"stat":"ok","rsp":"0.018898963928223"},"method":"webpurify.live.replace","format":"rest","text":"foo","api_key":"123"}}"#;
        let sanitized = client::profanity_replace_result(ok_response(body)?)?;

        assert_eq!(sanitized, "foo");
        Ok(())
    }

    #[test]
    fn response_errors() -> Result<(), Box<dyn Error>> {
        let cases = [
            (100_u32, client::ResponseError::InvalidApiKey("Msg".to_owned())),
            (101, client::ResponseError::InactiveApiKey("Msg".to_owned())),
            (102, client::ResponseError::MissingApiKey("Msg".to_owned())),
            (
                103,
                client::ResponseError::ServiceUnavailable("Msg".to_owned()),
            ),
            (
                999,
                client::ResponseError::UnknownErr("999".to_owned(), "Msg".to_owned()),
            ),
        ];

        for (code, expected) in cases {
            let body = format!(
                r#"{{"rsp":{{"@attributes":{{"stat":"fail"}},"err":{{"@attributes":{{"code":"{code}","msg":"Msg"}}}},"text":"text"}}}}"#
            );
            let actual = client::profanity_replace_result(ok_response(&body)?)
                .err()
                .expect("Expected error");
            // Only the variant is compared, not the payload.
            assert!(
                std::mem::discriminant(&actual) == std::mem::discriminant(&expected),
                "Expected error: {:?} but got: {:?}",
                expected,
                actual
            );
        }

        Ok(())
    }

    #[test]
    fn mismatched_response_methods() -> Result<(), Box<dyn Error>> {
        // Check treated as replace result
        let check_body = r#"{"rsp":{"@attributes":{"stat":"ok","rsp":"0.0072040557861328"},"method":"webpurify.live.check","format":"rest","found":"1","api_key":"123"}}"#;
        let outcome = client::profanity_replace_result(ok_response(check_body)?);
        assert!(outcome.is_err());

        // Replace treated as check result
        let outcome = client::profanity_check_result(ok_response(REPLACE_OK_BODY)?);
        assert!(outcome.is_err());

        Ok(())
    }
}
444        Ok(())
445    }
446}