// screeps_rust_api/http_client.rs

1use std::{
2    sync::Mutex,
3    thread,
4    time::{Duration, SystemTime, UNIX_EPOCH},
5};
6
7use reqwest::{Client, header::HeaderMap};
8use serde::{Serialize, de::DeserializeOwned};
9use serde_json::json;
10
11use crate::{
12    config::ScreepsConfig,
13    error::{ScreepsError, ScreepsResult},
14    model::TokenData,
15    rate_limit::RateLimits,
16};
/// HTTP verb accepted by `ScreepsHttpClient::request`.
pub enum Method {
    /// `GET` — `request` encodes the body as URL query parameters.
    Get,
    /// `POST` — `request` sends the body as JSON.
    Post,
}
// Re-export both variants so callers can write `Get` / `Post` without the `Method::` prefix.
pub use self::Method::{Get, Post};

/// Optional borrowed slice of `'static` `(key, value)` string pairs.
/// NOTE(review): not used in this file; presumably a fixed payload for callers — confirm.
pub type AnyPayload = Option<&'static [(&'static str, &'static str)]>;
27/// Screeps http 客户端
28pub struct ScreepsHttpClient {
29    /// 请求客户端
30    pub client: Client,
31    /// 配置
32    pub config: ScreepsConfig,
33    /// 限速信息
34    pub rate_limits: Mutex<RateLimits>,
35    /// 最新的 token
36    pub token: Mutex<Option<String>>,
37}
38
39impl ScreepsHttpClient {
40    pub fn new(config: ScreepsConfig) -> Self {
41        let client = Client::builder()
42            .timeout(Duration::from_secs(config.timeout))
43            .build()
44            .unwrap();
45
46        Self {
47            client,
48            token: Mutex::new(config.token.clone()),
49            config,
50            rate_limits: Mutex::new(RateLimits::default()),
51        }
52    }
53
54    /// 封装 get 请求和 post 请求
55    pub async fn request<T: Serialize, U: DeserializeOwned>(
56        &self,
57        method: Method,
58        path: &str,
59        body: Option<T>,
60    ) -> ScreepsResult<U> {
61        let url = self.build_url(path);
62        let request_builder = match method {
63            Method::Get => self.client.get(url).query(&body),
64            Method::Post => self.client.post(url).json(&body),
65        }
66        .headers(self.build_headers());
67
68        // 先检查速率限制
69        let rate_limit = self.rate_limits.lock().unwrap().get_limit(&method, path);
70        if rate_limit.remaining <= 0 {
71            let wait_time = rate_limit.reset * 1000
72                - SystemTime::now()
73                    .duration_since(UNIX_EPOCH)
74                    .unwrap()
75                    .as_millis() as u128;
76            if wait_time > 0 {
77                thread::sleep(Duration::from_millis(wait_time as u64));
78            }
79        }
80        let response = request_builder.send().await?;
81        if let Some(token) = response.headers().get("x-token") {
82            *self.token.lock().unwrap() = Some(token.to_str().unwrap().to_string());
83        }
84        self.rate_limits
85            .lock()
86            .unwrap()
87            .update_from_headers(&method, path, response.headers());
88        let result = response.json::<U>().await?;
89        Ok(result)
90    }
91
92    /// 构造请求头,添加 token
93    fn build_headers(&self) -> HeaderMap {
94        let mut headers = HeaderMap::new();
95        let token = self.token.lock().unwrap().as_ref().cloned();
96        if let Some(token) = token {
97            headers.insert("X-Token", token.parse().unwrap());
98            headers.insert("X-Username", token.parse().unwrap());
99        }
100        headers
101    }
102
103    /// 根据路径构造完整的 api
104    pub fn build_url(&self, path: &str) -> String {
105        format!("{}{}", self.config.build_base_url(), path)
106    }
107}
108
109impl ScreepsHttpClient {
110    /// 登录以获取 token
111    pub async fn auth(&self) -> ScreepsResult<TokenData> {
112        if self.config.email.is_none() || self.config.password.is_none() {
113            return Err(ScreepsError::Config(
114                "email or password is none".to_string(),
115            ));
116        }
117
118        let result: Result<TokenData, ScreepsError> = self
119            .request(
120                Method::Post,
121                "/auth/signin",
122                Some(json!({
123                    "email": self.config.email.clone().unwrap(),
124                    "password": self.config.password.clone().unwrap(),
125                })),
126            )
127            .await;
128        result
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// Builds a config for `screeps.com` with a fixed token and placeholder credentials.
    fn config_with_token() -> ScreepsConfig {
        ScreepsConfig::new(
            Some("token".to_string()),
            Some("email".to_string()),
            Some("password".to_string()),
            "screeps.com".to_string(),
            true,
            10000,
        )
    }

    #[test]
    fn test_build_url() {
        let client = ScreepsHttpClient::new(config_with_token());
        assert_eq!(
            client.build_url("/auth/signin"),
            "https://screeps.com/api/auth/signin"
        )
    }

    #[test]
    fn test_build_headers() {
        let client = ScreepsHttpClient::new(config_with_token());
        let headers = client.build_headers();
        assert_eq!(headers.get("X-Token").unwrap().to_str().unwrap(), "token");
        assert_eq!(headers.get("X-Username").unwrap().to_str().unwrap(), "token");
    }

    /// Live sign-in against screeps.com.
    ///
    /// Skipped unless `SCREEPS_EMAIL` and `SCREEPS_PASSWORD` are set (optionally via `.env`).
    #[tokio::test]
    async fn test_auth() {
        let _ = dotenvy::dotenv();

        let (Ok(email), Ok(password)) = (env::var("SCREEPS_EMAIL"), env::var("SCREEPS_PASSWORD"))
        else {
            println!(
                "Skipping authentication test because SCREEPS_EMAIL or SCREEPS_PASSWORD is not set"
            );
            return;
        };

        let config = ScreepsConfig::new(
            None,
            Some(email),
            Some(password),
            "screeps.com".to_string(),
            true,
            10000,
        );
        let client = ScreepsHttpClient::new(config);

        match client.auth().await {
            Ok(data) => {
                println!("Authentication successful: {:?}", data);
                assert_eq!(data.base_data.ok.unwrap(), 1);
            }
            Err(e) => {
                println!("Authentication failed: {:?}", e);
                panic!("Authentication failed");
            }
        }
    }
}