turnstile_actix_web/
lib.rs

use std::future::{ready, Ready};
use std::rc::Rc;

use actix_web::{
    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
    Error,
};

use error::TurnstileError;
use futures_util::future::LocalBoxFuture;
use turnstile::verify_cloudflare_turnstile;
11
12pub mod error;
13pub mod reqwest_client;
14pub mod turnstile;
15
/// Settings used by the Turnstile middleware when verifying tokens.
#[derive(Clone)]
pub struct TurnstileConfig {
    /// Cloudflare Turnstile secret key, handed to `verify_cloudflare_turnstile`.
    ///
    /// NOTE(review): presumably sent to Cloudflare's siteverify endpoint — confirm in the
    /// `turnstile` module.
    pub secret_key: String,
    /// Verification timeout in seconds; [`TurnstileConfig::new`] sets it to `Some(5)`.
    ///
    /// NOTE(review): interpreted by `verify_cloudflare_turnstile`; confirm how `None` is
    /// treated there.
    pub timeout_secs: Option<u64>,
}

impl TurnstileConfig {
    /// Timeout (in seconds) applied by [`TurnstileConfig::new`].
    const DEFAULT_TIMEOUT_SECS: u64 = 5;

    /// Builds a configuration for `secret_key` with the default 5-second timeout.
    pub fn new(secret_key: impl Into<String>) -> Self {
        let secret_key: String = secret_key.into();
        TurnstileConfig {
            timeout_secs: Some(Self::DEFAULT_TIMEOUT_SECS),
            secret_key,
        }
    }
}
30
31pub struct Turnstile {
32    config: TurnstileConfig,
33}
34impl Turnstile {
35    pub fn new(config: TurnstileConfig) -> Self {
36        Self { config }
37    }
38}
39
40impl<S, B> Transform<S, ServiceRequest> for Turnstile
41where
42    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
43    S::Future: 'static,
44    B: 'static,
45{
46    type Response = ServiceResponse<B>;
47    type Error = Error;
48    type InitError = ();
49    type Transform = TurnstileMiddleware<S>;
50    type Future = Ready<Result<Self::Transform, Self::InitError>>;
51
52    fn new_transform(&self, service: S) -> Self::Future {
53        let config = self.config.clone();
54        ready(Ok(TurnstileMiddleware { service, config }))
55    }
56}
57
58pub struct TurnstileMiddleware<S> {
59    service: S,
60    config: TurnstileConfig,
61}
62
63impl<S, B> Service<ServiceRequest> for TurnstileMiddleware<S>
64where
65    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
66    S::Future: 'static,
67    B: 'static,
68{
69    type Response = ServiceResponse<B>;
70    type Error = Error;
71    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
72
73    forward_ready!(service);
74
75    fn call(&self, req: ServiceRequest) -> Self::Future {
76        let connection_info = req.connection_info().to_owned();
77        let client_ip = match connection_info.realip_remote_addr() {
78            Some(ip) => ip.to_owned(),
79            None => {
80                return Box::pin(async { Err(Error::from(TurnstileError::ClientIPNotFound)) });
81            }
82        };
83
84        let headers = req.headers();
85        let cf_turnstile_response = match headers.get("cf-turnstile-response") {
86            Some(res) => match res.to_str() {
87                Ok(res) => res.to_owned(),
88                Err(_) => {
89                    return Box::pin(async {
90                        Err(Error::from(TurnstileError::InvalidTokenFormat))
91                    });
92                }
93            },
94            None => {
95                return Box::pin(async { Err(Error::from(TurnstileError::TokenNotFound)) });
96            }
97        };
98        // println!("{}: {}", client_ip, cf_turnstile_response);
99
100        let fut = self.service.call(req);
101
102        let config = self.config.clone();
103
104        Box::pin(async move {
105            match verify_cloudflare_turnstile(&cf_turnstile_response, &client_ip, &config).await {
106                Ok(true) => {
107                    // success
108                    let res = fut.await?;
109                    Ok(res)
110                }
111                Ok(false) => {
112                    // cloudflare returned failure
113                    Err(Error::from(TurnstileError::VerificationFailed(
114                        "Cloudflare rejected the token".to_string(),
115                    )))
116                }
117                Err(err) => {
118                    // network error
119                    Err(Error::from(TurnstileError::NetworkError(err)))
120                }
121            }
122        })
123    }
124}
125
#[cfg(test)]
mod tests {
    // End-to-end tests for the middleware. The app is built in-process with
    // `test::init_service`, but token verification goes through
    // `verify_cloudflare_turnstile`, which presumably calls Cloudflare's real siteverify
    // endpoint (confirm in the `turnstile` module), so these tests need network access.
    // They use Cloudflare's dummy secret keys: `1x…AA` is expected to always pass and
    // `2x…AA` to always fail.
    use actix_web::{http::header, test, web, App, HttpResponse};

    use super::*;

    /// Placeholder for the token a real client would obtain from the Turnstile widget.
    const TEST_TOKEN: &str = "valid_turnstile_token";

    /// Builds a `GET /` request that carries `token` in the `cf-turnstile-response` header
    /// and has a fixed peer address, so the middleware can resolve a client IP.
    fn turnstile_request(token: &str) -> test::TestRequest {
        test::TestRequest::get()
            .uri("/")
            .insert_header((
                header::HeaderName::from_static("cf-turnstile-response"),
                token,
            ))
            .peer_addr("192.168.1.1:12345".parse().unwrap())
    }

    #[actix_web::test]
    async fn test_turnstile_success() {
        // Dummy secret key that is expected to always pass verification.
        let turnstile_config = TurnstileConfig::new("1x0000000000000000000000000000000AA");

        // In-process app guarded by the middleware.
        let app = test::init_service(
            App::new().wrap(Turnstile::new(turnstile_config)).service(
                web::resource("/").to(|| async { HttpResponse::Ok().body("hello world") }),
            ),
        )
        .await;

        let req = turnstile_request(TEST_TOKEN).to_request();
        let resp = test::call_service(&app, req).await;

        // The handler must have run, i.e. verification succeeded.
        assert!(resp.status().is_success());
    }

    #[actix_web::test]
    async fn test_turnstile_failure() {
        // Dummy secret key that is expected to always fail verification.
        let turnstile_config = TurnstileConfig::new("2x0000000000000000000000000000000AA");

        // In-process app guarded by the middleware.
        let app = test::init_service(
            App::new().wrap(Turnstile::new(turnstile_config)).service(
                web::resource("/").to(|| async { HttpResponse::Ok().body("hello world") }),
            ),
        )
        .await;

        let req = turnstile_request(TEST_TOKEN).to_request();

        // Execute the request and check that it was rejected by verification.
        match test::try_call_service(&app, req).await {
            Ok(response) => {
                println!("{response:?}");
                assert!(response.status().is_client_error());
            }
            Err(e) => match e.as_error::<TurnstileError>() {
                Some(TurnstileError::VerificationFailed(_)) => println!("{e}"),
                Some(err) => panic!("Unexpected error type: {}", err),
                None => panic!("Unexpected error type: {:?}", e),
            },
        }
    }
}