smbcloud_cli/account/
lib.rs

1use anyhow::{anyhow, Result};
2use console::style;
3use log::debug;
4use regex::Regex;
5use reqwest::{Client, Response, StatusCode};
6use smbcloud_model::account::SmbAuthorization;
7use smbcloud_networking::{
8    constants::{
9        GH_OAUTH_CLIENT_ID, GH_OAUTH_REDIRECT_HOST, GH_OAUTH_REDIRECT_PORT, PATH_AUTHORIZE,
10    },
11    smb_base_url_builder,
12};
13use spinners::Spinner;
14use std::{
15    fs::{create_dir_all, OpenOptions},
16    io::{BufRead, BufReader, Write},
17    net::{TcpListener, TcpStream},
18    sync::mpsc::{self, Receiver, Sender},
19};
20use url_builder::URLBuilder;
21
22pub async fn authorize_github() -> Result<SmbAuthorization> {
23    // Spin up a simple localhost server to listen for the GitHub OAuth callback
24    // setup_oauth_callback_server();
25    // Open the GitHub OAuth URL in the user's browser
26    let mut spinner = Spinner::new(
27        spinners::Spinners::BouncingBall,
28        style("🚀 Getting your GitHub information...")
29            .green()
30            .bold()
31            .to_string(),
32    );
33
34    let rx = match open::that(build_github_oauth_url()) {
35        Ok(_) => {
36            let (tx, rx): (Sender<String>, Receiver<String>) = mpsc::channel();
37            debug!(
38                "Setting up OAuth callback server... (tx: {:#?}, rx: {:#?})",
39                &tx, &rx
40            );
41            tokio::spawn(async move {
42                setup_oauth_callback_server(tx);
43            });
44            rx
45        }
46        Err(_) => {
47            let error = anyhow!("Failed to open a browser.");
48            return Err(error);
49        }
50    };
51
52    spinner.stop_and_persist("⌛", "Waiting for the authorization.".into());
53
54    debug!("Waiting for code from channel...");
55
56    match rx.recv() {
57        Ok(code) => {
58            debug!("Got code from channel: {:#?}", &code);
59            //Err(anyhow!("Failed to get code from channel."))
60            process_connect_github(code).await
61        }
62        Err(e) => {
63            let error = anyhow!("Failed to get code from channel: {e}");
64            Err(error)
65        }
66    }
67}
68
69fn setup_oauth_callback_server(tx: Sender<String>) {
70    let listener = TcpListener::bind(format!("127.0.0.1:{}", GH_OAUTH_REDIRECT_PORT)).unwrap();
71    for stream in listener.incoming() {
72        let stream = stream.unwrap();
73        handle_connection(stream, tx.clone());
74    }
75}
76
77fn handle_connection(mut stream: TcpStream, tx: Sender<String>) {
78    let buf_reader = BufReader::new(&stream);
79    let request_line = &buf_reader.lines().next().unwrap().unwrap();
80
81    debug!("Request: {:#?}", request_line);
82
83    let code_regex = Regex::new(r"code=([^&]*)").unwrap();
84
85    let (status_line, contents) = match code_regex.captures(request_line) {
86        Some(group) => {
87            let code = group.get(1).unwrap().as_str();
88            debug!("Code: {:#?}", code);
89            debug!("Sending code to channel...");
90            debug!("Channel: {:#?}", &tx);
91            match tx.send(code.to_string()) {
92                Ok(_) => {
93                    debug!("Code sent to channel.");
94                }
95                Err(e) => {
96                    debug!("Failed to send code to channel: {e}");
97                }
98            }
99            (
100                "HTTP/1.1 200 OK",
101                "<!DOCTYPE html>
102
103                <head>
104                    <meta charset='utf-8'>
105                    <title>Hello!</title>
106                </head>
107                
108                <body>
109                    <h1>Authenticated!</h1>
110                    <p>Back to the terminal console to finish your registration.</p>
111                </body>",
112            )
113        }
114        None => {
115            debug!("Code not found.");
116            (
117                "HTTP/1.1 404 NOT FOUND",
118                "<!DOCTYPE html>
119                <html lang='en'>
120                
121                <head>
122                    <meta charset='utf-8'>
123                    <title>404 Not found</title>
124                </head>
125                
126                <body>
127                    <h1>Oops!</h1>
128                    <p>Sorry, I don't know what you're asking for.</p>
129                </body>
130                
131                </html>",
132            )
133        }
134    };
135
136    debug!("Contents: {:#?}", &contents);
137    let response = format!("{status_line}\r\n\r\n{contents}");
138    stream.write_all(response.as_bytes()).unwrap();
139    stream.flush().unwrap();
140}
141
142// Get access token
143pub async fn process_connect_github(code: String) -> Result<SmbAuthorization> {
144    let response = Client::new()
145        .post(build_authorize_smb_url())
146        .body(format!("gh_code={}", code))
147        .header("Accept", "application/json")
148        .header("Content-Type", "application/x-www-form-urlencoded")
149        .send()
150        .await?;
151    let mut spinner = Spinner::new(
152        spinners::Spinners::BouncingBall,
153        style("🚀 Authorizing your account...")
154            .green()
155            .bold()
156            .to_string(),
157    );
158    // println!("Response: {:#?}", &response);
159    match response.status() {
160        StatusCode::OK => {
161            // Account authorized and token received
162            spinner.stop_and_persist("✅", "You are logged in with your GitHub account!".into());
163            save_token(&response).await?;
164            let result = response.json().await?;
165            // println!("Result: {:#?}", &result);
166            Ok(result)
167        }
168        StatusCode::NOT_FOUND => {
169            // Account not found and we show signup option
170            spinner.stop_and_persist("🥲", "Account not found. Please signup!".into());
171            let result = response.json().await?;
172            // println!("Result: {:#?}", &result);
173            Ok(result)
174        }
175        StatusCode::UNPROCESSABLE_ENTITY => {
176            // Account found but email not verified
177            spinner.stop_and_persist("🥹", "Unverified email!".into());
178            let result = response.json().await?;
179            // println!("Result: {:#?}", &result);
180            Ok(result)
181        }
182        _ => {
183            // Other errors
184            let error = anyhow!("Error while authorizing with GitHub.");
185            Err(error)
186        }
187    }
188}
189
190fn build_authorize_smb_url() -> String {
191    let mut url_builder = smb_base_url_builder();
192    url_builder.add_route(PATH_AUTHORIZE);
193    url_builder.build()
194}
195
196fn build_github_oauth_url() -> String {
197    let mut url_builder = github_base_url_builder();
198    url_builder
199        .add_route("login/oauth/authorize")
200        .add_param("scope", "user")
201        .add_param("state", "smbcloud");
202    url_builder.build()
203}
204
205fn github_base_url_builder() -> URLBuilder {
206    let redirect_url = format!("{}:{}", GH_OAUTH_REDIRECT_HOST, GH_OAUTH_REDIRECT_PORT);
207
208    let mut url_builder = URLBuilder::new();
209    url_builder
210        .set_protocol("https")
211        .set_host("github.com")
212        .add_param("client_id", GH_OAUTH_CLIENT_ID)
213        .add_param("redirect_uri", &redirect_url);
214    url_builder
215}
216
217pub async fn save_token(response: &Response) -> Result<()> {
218    let headers = response.headers();
219    // println!("Headers: {:#?}", &headers);
220    match headers.get("Authorization") {
221        Some(token) => {
222            debug!("{}", token.to_str()?);
223            match home::home_dir() {
224                Some(path) => {
225                    debug!("{}", path.to_str().unwrap());
226                    create_dir_all(path.join(".smb"))?;
227                    let mut file = OpenOptions::new()
228                        .create(true)
229                        .write(true)
230                        .open([path.to_str().unwrap(), "/.smb/token"].join(""))?;
231                    file.write_all(token.to_str()?.as_bytes())?;
232                    Ok(())
233                }
234                None => Err(anyhow!("Failed to get home directory.")),
235            }
236        }
237        None => Err(anyhow!("Failed to get token. Probably a backend issue.")),
238    }
239}