1use http::header::CONTENT_TYPE;
2use http::{Request, Response, Uri};
3use url::form_urlencoded;
4
5#[derive(thiserror::Error, Debug)]
6#[allow(clippy::upper_case_acronyms)]
7pub enum RequestError {
8 #[error("The provided Uri was invalid")]
9 InvalidUri,
10
11 #[error(transparent)]
12 HTTP(#[from] http::Error),
13}
14
15#[derive(thiserror::Error, Debug)]
16#[allow(clippy::upper_case_acronyms)]
17pub enum ResponseError {
18 #[error("The response status was invalid: {0}")]
19 HttpStatus(http::StatusCode),
20 #[error(transparent)]
21 Deserialize(#[from] serde_json::Error),
22 #[error("Missing field {0} in response")]
23 MissingField(String),
24 #[error("Invalid field {0} in response")]
25 InvalidField(String),
26 #[error("The API key passed was not valid: {0}")]
27 InvalidApiKey(String),
28 #[error("The API key passed is inactive or has been revoked: {0}")]
29 InactiveApiKey(String),
30 #[error("API Key was not included in request: {0}")]
31 MissingApiKey(String),
32 #[error("The requested service is temporarily unavailable: {0}")]
33 ServiceUnavailable(String),
34 #[error("Unknown error code returned: {0} {1}")]
35 UnknownErr(String, String),
36 #[error("Non ok stat returned: {0}")]
37 NonOkStat(String),
38 #[error("Got mis matched method in response, got: {0} expected: {1}")]
39 MisMatchedMethod(String, String),
40}
41
/// Regional WebPurify endpoint that requests are sent to.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Region {
    /// Served by `https://api1-eu.webpurify.com/services/rest/`.
    Europe,
    /// Served by `https://api1.webpurify.com/services/rest/`.
    Us,
    /// Served by `https://api1-ap.webpurify.com/services/rest/`.
    Asia,
    /// Served by `https://es-api.webpurify.net/services/rest/`.
    Es,
}
49
/// WebPurify "live" API call to perform.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Method {
    /// `webpurify.live.check`: report whether the text contains profanity.
    Check,
    /// `webpurify.live.replace`: replace profanity in the text; the payload is
    /// sent as the `replacesymbol` query parameter.
    Replace(String),
}

impl Method {
    /// The API method name sent in (and expected back in) the `method` field.
    fn method_str(&self) -> &'static str {
        match self {
            Method::Check => "webpurify.live.check",
            Method::Replace(_) => "webpurify.live.replace",
        }
    }
}
66
67fn api_url_by_region(region: Region) -> String {
68 match region {
69 Region::Us => "https://api1.webpurify.com/services/rest/",
70 Region::Europe => "https://api1-eu.webpurify.com/services/rest/",
71 Region::Asia => "https://api1-ap.webpurify.com/services/rest/",
72 Region::Es => "https://es-api.webpurify.net/services/rest/",
73 }
74 .to_string()
75}
76
77pub fn query_string(api_key: &str, text: &str, method: Method) -> String {
81 let method_str = method.method_str();
82
83 let mut serializer = form_urlencoded::Serializer::new(String::new());
84 let qs = serializer
85 .append_pair("format", "json")
86 .append_pair("api_key", api_key)
87 .append_pair("text", text)
88 .append_pair("method", &method_str)
89 .append_pair("semail", "1")
90 .append_pair("slink", "1")
91 .append_pair("rsp", "1")
92 .append_pair("sphone", "1");
93
94 if let Method::Replace(replace_with) = method {
95 qs.append_pair("replacesymbol", &replace_with);
96 }
97
98 qs.finish()
99}
100
101pub(crate) fn into_uri<U: TryInto<Uri>>(uri: U) -> Result<Uri, RequestError> {
102 uri.try_into().map_err(|_err| RequestError::InvalidUri)
103}
104
105fn request_builder(api_uri: String) -> Result<Request<Vec<u8>>, RequestError> {
106 let request_builder = Request::builder()
107 .method("POST")
108 .uri(into_uri(api_uri)?)
109 .header(CONTENT_TYPE, "application/json");
110
111 let req = request_builder.body(vec![])?;
112 Ok(req)
113}
114
115pub fn profanity_check_request(
135 api_key: &str,
136 region: Region,
137 text: &str,
138) -> Result<Request<Vec<u8>>, RequestError> {
139 let qs = query_string(api_key, text, Method::Check);
140 let api_uri = format!("{}?{}", api_url_by_region(region), qs);
141
142 let req = request_builder(api_uri)?;
143 Ok(req)
144}
145
146pub fn profanity_replace_request(
168 api_key: &str,
169 region: Region,
170 text: &str,
171 replace_text: &str,
172) -> Result<Request<Vec<u8>>, RequestError> {
173 let qs = query_string(api_key, text, Method::Replace(replace_text.to_string()));
174 let api_uri = format!("{}?{}", api_url_by_region(region), qs);
175
176 let req = request_builder(api_uri)?;
177 Ok(req)
178}
179
/// Top-level envelope of a WebPurify JSON response; all data lives under `rsp`.
#[derive(serde::Deserialize)]
struct ApiResponse {
    rsp: ApiResponseRsp,
}
184
/// Contents of the `rsp` object, shared by the check and replace calls.
///
/// Keys not listed here (e.g. `format`, `api_key`) are ignored during
/// deserialization, and each `Option` field is `None` when its key is absent.
#[derive(serde::Deserialize)]
struct ApiResponseRsp {
    /// The `@attributes` object, carrying the `stat` status flag.
    #[serde(rename = "@attributes")]
    attributes: ApiResponseRspAttributes,
    /// Present when the API reports an error code.
    err: Option<ApiResponseErr>,
    /// Name of the API method that produced this response; compared against
    /// the method that was requested.
    method: Option<String>,
    /// Match count for the check call, sent as a string and parsed to `u32`.
    found: Option<String>,
    /// Resulting text returned by the replace call.
    text: Option<String>,
}
194
/// The `@attributes` object of `rsp`.
#[derive(serde::Deserialize)]
struct ApiResponseRspAttributes {
    /// Call status; any value other than `"ok"` is treated as a failure.
    stat: String,
}
199
/// The optional `err` object of `rsp`, describing an API-level failure.
#[derive(serde::Deserialize)]
struct ApiResponseErr {
    /// The `@attributes` object carrying the error code and message.
    #[serde(rename = "@attributes")]
    attributes: ApiResponseErrAttributes,
}
205
/// Error details reported by the API inside `err`.
#[derive(serde::Deserialize)]
struct ApiResponseErrAttributes {
    /// Error code, sent as a string; `"100"` to `"103"` map to dedicated
    /// `ResponseError` variants during response parsing.
    code: String,
    /// Human-readable error message, forwarded as the `ResponseError` payload.
    msg: String,
}
211
212fn parse_response<T>(response: Response<T>, method: Method) -> Result<ApiResponse, ResponseError>
213where
214 T: AsRef<[u8]>,
215{
216 if !response.status().is_success() {
217 return Err(ResponseError::HttpStatus(response.status()));
218 }
219
220 let body = response.body();
221 let api_response: ApiResponse = serde_json::from_slice(body.as_ref())?;
222
223 if let Some(ApiResponseErr {
224 attributes: ApiResponseErrAttributes { code, msg },
225 }) = api_response.rsp.err
226 {
227 let err = match code.as_str() {
228 "100" => ResponseError::InvalidApiKey(msg),
229 "101" => ResponseError::InactiveApiKey(msg),
230 "102" => ResponseError::MissingApiKey(msg),
231 "103" => ResponseError::ServiceUnavailable(msg),
232 _ => ResponseError::UnknownErr(code, msg),
233 };
234 return Err(err);
235 }
236
237 if !api_response.rsp.attributes.stat.eq("ok") {
238 return Err(ResponseError::NonOkStat(api_response.rsp.attributes.stat));
239 }
240
241 if !api_response
242 .rsp
243 .method
244 .as_ref()
245 .map(|s| s.as_str())
246 .eq(&Some(method.method_str()))
247 {
248 return Err(ResponseError::MisMatchedMethod(
249 api_response.rsp.method.unwrap_or_default(),
250 method.method_str().to_owned(),
251 ));
252 }
253
254 Ok(api_response)
255}
256
257pub fn profanity_check_result<T>(response: Response<T>) -> Result<bool, ResponseError>
264where
265 T: AsRef<[u8]>,
266{
267 let response = parse_response(response, Method::Check)?;
268
269 let check: u32 = response
270 .rsp
271 .found
272 .ok_or_else(|| ResponseError::MissingField("found".to_owned()))
273 .and_then(|found| {
274 found
275 .parse()
276 .map_err(|_err| ResponseError::InvalidField("found".to_owned()))
277 })?;
278
279 Ok(check > 0)
280}
281
282pub fn profanity_replace_result<T>(response: Response<T>) -> Result<String, ResponseError>
289where
290 T: AsRef<[u8]>,
291{
292 let response = parse_response(response, Method::Replace("".to_owned()))?; match response.rsp.text {
295 Some(text) => Ok(text),
296 None => Err(ResponseError::MissingField("text".to_owned())),
297 }
298}
299
#[cfg(test)]
mod test {
    use std::error::Error;

    use crate::client;
    use http::Request;
    use http::Response;
    use http::StatusCode;

    /// Canned successful `webpurify.live.replace` payload with `found` and `text`.
    const REPLACE_BODY: &[u8] = br#"{"rsp":{"@attributes":{"stat":"ok","rsp":"0.018898963928223"},"method":"webpurify.live.replace","format":"rest","found":"3","text":"foo","api_key":"123"}}"#;

    /// Returns `true` when the request URI contains `needle` verbatim.
    fn uri_contains(req: &Request<Vec<u8>>, needle: &str) -> bool {
        req.uri().to_string().contains(needle)
    }

    /// Wraps `body` in a `200 OK` response, as an HTTP client would return it.
    fn ok_response(body: &[u8]) -> Result<Response<Vec<u8>>, http::Error> {
        Response::builder()
            .status(StatusCode::OK)
            .body(body.to_vec())
    }

    /// Canned successful `webpurify.live.check` payload reporting `found` matches.
    fn check_body(found: u32) -> String {
        format!(
            r#"{{"rsp":{{"@attributes":{{"stat":"ok","rsp":"0.0072040557861328"}},"method":"webpurify.live.check","format":"rest","found":"{found}","api_key":"123"}}}}"#
        )
    }

    #[test]
    fn qs_encoding() {
        assert_eq!(
            client::query_string("abcd", "hi there", client::Method::Check),
            "format=json&api_key=abcd&text=hi+there&method=webpurify.live.check&semail=1&slink=1&rsp=1&sphone=1"
        );
    }

    #[test]
    fn check_request() {
        let region = client::Region::Europe;
        let req = client::profanity_check_request("abcd", region, "hi there");
        assert_eq!(
            req.unwrap().uri(),
            "https://api1-eu.webpurify.com/services/rest/?format=json&api_key=abcd&text=hi+there&method=webpurify.live.check&semail=1&slink=1&rsp=1&sphone=1"
        );
    }

    #[test]
    fn check_result() -> Result<(), Box<dyn Error>> {
        assert!(client::profanity_check_result(ok_response(check_body(3).as_bytes())?)?);
        assert!(!client::profanity_check_result(ok_response(check_body(0).as_bytes())?)?);
        Ok(())
    }

    #[test]
    fn check_result_missing_found() -> Result<(), Box<dyn Error>> {
        let body = br#"{"rsp":{"@attributes":{"stat":"ok","rsp":"0.0072040557861328"},"method":"webpurify.live.check","format":"rest","api_key":"123"}}"#;
        let result = client::profanity_check_result(ok_response(body)?);
        assert!(
            matches!(&result, Err(client::ResponseError::MissingField(field)) if field == "found"),
            "unexpected result: {result:?}"
        );
        Ok(())
    }

    #[test]
    fn check_result_invalid_found() -> Result<(), Box<dyn Error>> {
        let body = br#"{"rsp":{"@attributes":{"stat":"ok"},"method":"webpurify.live.check","found":"many"}}"#;
        let result = client::profanity_check_result(ok_response(body)?);
        assert!(
            matches!(&result, Err(client::ResponseError::InvalidField(field)) if field == "found"),
            "unexpected result: {result:?}"
        );
        Ok(())
    }

    #[test]
    fn non_success_status() -> Result<(), Box<dyn Error>> {
        let response = Response::builder()
            .status(StatusCode::INTERNAL_SERVER_ERROR)
            .body(Vec::<u8>::new())?;
        let result = client::profanity_check_result(response);
        assert!(
            matches!(
                &result,
                Err(client::ResponseError::HttpStatus(status))
                    if *status == StatusCode::INTERNAL_SERVER_ERROR
            ),
            "unexpected result: {result:?}"
        );
        Ok(())
    }

    #[test]
    fn non_ok_stat() -> Result<(), Box<dyn Error>> {
        let body = br#"{"rsp":{"@attributes":{"stat":"fail"},"method":"webpurify.live.check","found":"0"}}"#;
        let result = client::profanity_check_result(ok_response(body)?);
        assert!(
            matches!(&result, Err(client::ResponseError::NonOkStat(stat)) if stat == "fail"),
            "unexpected result: {result:?}"
        );
        Ok(())
    }

    #[test]
    fn replace_request() {
        let region = client::Region::Europe;
        let res_req = client::profanity_replace_request("abcd", region, "hi there", "*");
        let req = res_req.unwrap();
        assert!(uri_contains(&req, "method=webpurify.live.replace"));
        assert!(uri_contains(&req, "replacesymbol=*"));
        assert!(uri_contains(&req, "text=hi+there"));
    }

    #[test]
    fn replace_result() -> Result<(), Box<dyn Error>> {
        let result = client::profanity_replace_result(ok_response(REPLACE_BODY)?)?;
        assert_eq!(result, "foo");
        Ok(())
    }

    #[test]
    fn replace_result_missing_found() -> Result<(), Box<dyn Error>> {
        // `found` is not required for replace calls.
        let body = br#"{"rsp":{"@attributes":{"stat":"ok","rsp":"0.018898963928223"},"method":"webpurify.live.replace","format":"rest","text":"foo","api_key":"123"}}"#;
        let result = client::profanity_replace_result(ok_response(body)?)?;
        assert_eq!(result, "foo");
        Ok(())
    }

    #[test]
    fn replace_result_missing_text() -> Result<(), Box<dyn Error>> {
        let body = br#"{"rsp":{"@attributes":{"stat":"ok"},"method":"webpurify.live.replace","found":"0"}}"#;
        let result = client::profanity_replace_result(ok_response(body)?);
        assert!(
            matches!(&result, Err(client::ResponseError::MissingField(field)) if field == "text"),
            "unexpected result: {result:?}"
        );
        Ok(())
    }

    #[test]
    fn response_errors() -> Result<(), Box<dyn Error>> {
        let response = |code: u32| {
            let body = format!(
                r#"{{"rsp":{{"@attributes":{{"stat":"fail"}},"err":{{"@attributes":{{"code":"{code}","msg":"Msg"}}}},"text":"text"}}}}"#
            );
            ok_response(body.as_bytes())
        };

        for (code, err) in [
            (100, client::ResponseError::InvalidApiKey("Msg".to_owned())),
            (101, client::ResponseError::InactiveApiKey("Msg".to_owned())),
            (102, client::ResponseError::MissingApiKey("Msg".to_owned())),
            (
                103,
                client::ResponseError::ServiceUnavailable("Msg".to_owned()),
            ),
            (
                999,
                client::ResponseError::UnknownErr("999".to_owned(), "Msg".to_owned()),
            ),
        ] {
            let result = client::profanity_replace_result(response(code)?);
            let result_err = result.err().expect("Expected error");
            assert!(
                std::mem::discriminant(&result_err) == std::mem::discriminant(&err),
                "Expected error: {:?} but got: {:?}",
                err,
                result_err
            );
            // Same variant is not enough: the payload (code/message) must match too.
            assert_eq!(result_err.to_string(), err.to_string());
        }

        Ok(())
    }

    #[test]
    fn mismatched_response_methods() -> Result<(), Box<dyn Error>> {
        let result = client::profanity_replace_result(ok_response(check_body(1).as_bytes())?);
        assert!(
            matches!(
                &result,
                Err(client::ResponseError::MisMatchedMethod(got, expected))
                    if got == "webpurify.live.check" && expected == "webpurify.live.replace"
            ),
            "unexpected result: {result:?}"
        );

        let result = client::profanity_check_result(ok_response(REPLACE_BODY)?);
        assert!(
            matches!(
                &result,
                Err(client::ResponseError::MisMatchedMethod(got, expected))
                    if got == "webpurify.live.replace" && expected == "webpurify.live.check"
            ),
            "unexpected result: {result:?}"
        );

        Ok(())
    }
}
446}