twitter_tool/twitter_client/
mod.rs

1pub mod api;
2
3use anyhow::{anyhow, Result};
4use hyper::body::Bytes;
5use hyper::client::HttpConnector;
6use hyper::server::conn::Http;
7use hyper::{Body, Client, Method, Request, Uri};
8use hyper_tls::HttpsConnector;
9use oauth2::basic::{BasicClient, BasicTokenResponse};
10use oauth2::reqwest::async_http_client;
11use oauth2::{
12    AccessToken, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge,
13    RedirectUrl, RefreshToken, Scope, TokenResponse, TokenUrl,
14};
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use std::net::SocketAddr;
18use std::{fs, process};
19use tokio::net::TcpListener;
20use tokio::sync::Mutex;
21use url::Url;
22
23pub type PagedResult<T> = Result<(T, Option<String>)>;
24
25#[derive(Debug, Clone)]
26pub struct TwitterClient {
27    https_client: Client<HttpsConnector<HttpConnector>>,
28    twitter_client_id: String,
29    twitter_client_secret: String,
30    twitter_auth: TwitterAuth,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
34struct TwitterAuth {
35    access_token: Option<AccessToken>,
36    refresh_token: Option<RefreshToken>,
37}
38
39impl TwitterClient {
40    pub fn new(twitter_client_id: &str, twitter_client_secret: &str) -> Self {
41        let https = HttpsConnector::new();
42        let https_client = Client::builder().build::<_, hyper::Body>(https);
43        Self {
44            https_client,
45            twitter_client_id: twitter_client_id.to_string(),
46            twitter_client_secret: twitter_client_secret.to_string(),
47            twitter_auth: TwitterAuth {
48                access_token: None,
49                refresh_token: None,
50            },
51        }
52    }
53
54    pub fn save_auth(&self) -> Result<()> {
55        let str = serde_json::to_string(&self.twitter_auth)?;
56        fs::write("./var/.oauth", str)?;
57        Ok(())
58    }
59
60    pub fn load_auth(&mut self) -> Result<()> {
61        let str = fs::read_to_string("./var/.oauth")?;
62        self.twitter_auth = serde_json::from_str(&str)?;
63        Ok(())
64    }
65
66    pub async fn authorize(&mut self, use_refresh_token: bool) -> Result<()> {
67        let oauth_client = BasicClient::new(
68            ClientId::new(self.twitter_client_id.clone()),
69            Some(ClientSecret::new(self.twitter_client_secret.clone())),
70            AuthUrl::new("https://twitter.com/i/oauth2/authorize".to_string())?,
71            Some(TokenUrl::new(
72                "https://api.twitter.com/2/oauth2/token".to_string(),
73            )?),
74        )
75        .set_redirect_uri(RedirectUrl::new("https://localhost:8080".to_string())?);
76        let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
77        let (auth_url, _csrf_token) = oauth_client
78            .authorize_url(CsrfToken::new_random)
79            .add_scope(Scope::new("tweet.read".to_string()))
80            .add_scope(Scope::new("users.read".to_string()))
81            .add_scope(Scope::new("offline.access".to_string()))
82            .set_pkce_challenge(pkce_challenge)
83            .url();
84
85        match &self.twitter_auth.refresh_token {
86            Some(refresh_token) if use_refresh_token => {
87                let token = oauth_client
88                    .exchange_refresh_token(refresh_token)
89                    .request_async(async_http_client)
90                    .await?;
91                self.twitter_auth.access_token = Some(token.access_token().clone());
92                self.twitter_auth.refresh_token = token.refresh_token().cloned();
93                self.save_auth()?;
94            }
95            _ => {
96                // User browses here to complete OAuth flow
97                process::Command::new("open")
98                    .arg(auth_url.to_string())
99                    .output()
100                    .expect(&format!("Failed to open url in browser: {auth_url}"));
101
102                let mut callback_url = String::new();
103                println!("Enter callback url:");
104                std::io::stdin().read_line(&mut callback_url)?;
105                let callback_url = Url::parse(&callback_url)?;
106
107                // let (set_authorization_code, mut authorization_code) =
108                //     tokio::sync::mpsc::channel::<String>(1);
109                // let callback_addr = SocketAddr::from(([127, 0, 0, 1], 8080));
110                // let callback_listener = TcpListener::bind(callback_addr).await?;
111
112                fn parse_authorization_code(url: &Url) -> Result<String> {
113                    let mut expected_csrf_state = None;
114                    let mut authorization_code = None;
115                    for (key, value) in url.query_pairs() {
116                        if key == "state" {
117                            expected_csrf_state = Some(String::from(value));
118                        } else if key == "code" {
119                            authorization_code = Some(String::from(value));
120                        }
121                    }
122                    let _expected_csrf_state = expected_csrf_state
123                        .ok_or(anyhow!("Missing `state` param from callback"))?;
124                    let authorization_code =
125                        authorization_code.ok_or(anyhow!("Missing `code` param from callback"))?;
126
127                    // Once the user has been redirected to the redirect URL, you'll have access to the
128                    // authorization code. For security reasons, your code should verify that the `state`
129                    // parameter returned by the server matches `csrf_state`.
130                    Ok(authorization_code)
131                }
132
133                // tokio::task::spawn(async move {
134                //     loop {
135                //         let (stream, _) = callback_listener.accept().await.unwrap();
136                //         if let Err(err) = Http::new()
137                //             .serve_connection(
138                //                 stream,
139                //                 hyper::service::service_fn(|req| {
140                //                     let set_authorization_code = set_authorization_code.clone();
141                //                     async move {
142                //                         let authorization_code = parse_authorization_code(req.uri())?;
143                //                         set_authorization_code.send(authorization_code).await?;
144                //                         Ok::<_, anyhow::Error>(hyper::Response::new(hyper::Body::from(
145                //                             "You can close this window now",
146                //                         )))
147                //                     }
148                //                 }),
149                //             )
150                //             .await
151                //         {
152                //             eprintln!("Error serving callback: {}", err);
153                //         }
154                //     }
155                // });
156
157                let authorization_code = parse_authorization_code(&callback_url)?;
158                let token_result = oauth_client
159                    .exchange_code(AuthorizationCode::new(authorization_code))
160                    .set_pkce_verifier(pkce_verifier)
161                    .request_async(async_http_client)
162                    .await?;
163
164                self.twitter_auth.access_token = Some(token_result.access_token().clone());
165                self.twitter_auth.refresh_token = token_result.refresh_token().cloned();
166            }
167        }
168        Ok(())
169    }
170
171    async fn authenticated_get(&self, uri: &Url) -> Result<Bytes> {
172        let access_token = self
173            .twitter_auth
174            .access_token
175            .as_ref()
176            .ok_or(anyhow!("Unauthorized"))?;
177        let req = Request::builder()
178            .method(Method::GET)
179            .uri(uri.to_string())
180            .header("Authorization", format!("Bearer {}", access_token.secret()))
181            .body(Body::empty())?;
182        let resp = self.https_client.request(req).await?;
183        let resp = hyper::body::to_bytes(resp.into_body()).await?;
184        Ok(resp)
185    }
186
187    pub async fn me(&self) -> Result<api::User> {
188        let uri = Url::parse("https://api.twitter.com/2/users/me")?;
189        let bytes = self.authenticated_get(&uri).await?;
190        let resp: api::Response<api::User, ()> = serde_json::from_slice(&bytes)?;
191        Ok(resp.data)
192    }
193
194    pub async fn user_by_username(&self, username: &str) -> Result<api::User> {
195        let mut uri = Url::parse(&format!(
196            "https://api.twitter.com/2/users/by/username/{username}"
197        ))?;
198        uri.query_pairs_mut().append_pair("user.fields", "username");
199        let bytes = self.authenticated_get(&uri).await?;
200        let resp: api::Response<api::User, ()> = serde_json::from_slice(&bytes)?;
201        Ok(resp.data)
202    }
203
204    async fn get_tweets_with_users(
205        &self,
206        uri: &mut Url,
207        pagination_token: Option<String>,
208    ) -> PagedResult<Vec<api::Tweet>> {
209        uri.query_pairs_mut()
210            .append_pair(
211                "tweet.fields",
212                "created_at,attachments,referenced_tweets,public_metrics,conversation_id",
213            )
214            .append_pair("user.fields", "username")
215            .append_pair("expansions", "author_id")
216            .append_pair("max_results", "100");
217        if let Some(pagination_token) = pagination_token {
218            uri.query_pairs_mut()
219                .append_pair("pagination_token", &pagination_token);
220        }
221        let bytes = self.authenticated_get(&uri).await?;
222
223        #[derive(Debug, Serialize, Deserialize)]
224        struct Includes {
225            users: Vec<api::User>,
226        }
227
228        let resp: api::Response<Vec<api::Tweet>, Includes> = serde_json::from_slice(&bytes)?;
229        let next_pagination_token = resp.meta.and_then(|meta| meta.next_token);
230        let includes = resp.includes.ok_or(anyhow!("Expected `includes`"))?;
231        let users: HashMap<String, &api::User> = includes
232            .users
233            .iter()
234            .map(|user| (user.id.clone(), user))
235            .collect();
236        let tweets: Vec<api::Tweet> = resp
237            .data
238            .iter()
239            .map(|tweet| api::Tweet {
240                author_username: users
241                    .get(&tweet.author_id)
242                    .map(|user| user.username.clone()),
243                author_name: users.get(&tweet.author_id).map(|user| user.name.clone()),
244                ..tweet.clone()
245            })
246            .collect();
247        Ok((tweets, next_pagination_token))
248    }
249
250    pub async fn user_tweets(
251        &self,
252        user_id: &str,
253        pagination_token: Option<String>,
254    ) -> PagedResult<Vec<api::Tweet>> {
255        let mut uri = Url::parse(&format!("https://api.twitter.com/2/users/{user_id}/tweets"))?;
256        self.get_tweets_with_users(&mut uri, pagination_token).await
257    }
258
259    pub async fn timeline_reverse_chronological(
260        &self,
261        user_id: &str,
262        pagination_token: Option<String>,
263    ) -> PagedResult<Vec<api::Tweet>> {
264        let mut uri = Url::parse(&format!(
265            "https://api.twitter.com/2/users/{user_id}/timelines/reverse_chronological"
266        ))?;
267        self.get_tweets_with_users(&mut uri, pagination_token).await
268    }
269
270    pub async fn search_tweets(&self, query: &str) -> PagedResult<Vec<api::Tweet>> {
271        let mut uri = Url::parse("https://api.twitter.com/2/tweets/search/recent")?;
272        uri.query_pairs_mut().append_pair("query", query);
273        self.get_tweets_with_users(&mut uri, None).await
274    }
275}