Skip to main content

tideway_cli/commands/
add.rs

1//! Add command - enable Tideway features and scaffold modules.
2
3use anyhow::{Context, Result};
4use std::collections::BTreeSet;
5use std::fs;
6use std::path::{Path, PathBuf};
7
8use crate::cli::{AddArgs, AddFeature};
9use crate::templates::{BackendTemplateContext, BackendTemplateEngine};
10use crate::{
11    ensure_dir, print_info, print_success, print_warning, write_file, TIDEWAY_VERSION,
12};
13
14pub fn run(args: AddArgs) -> Result<()> {
15    let project_dir = PathBuf::from(&args.path);
16    let cargo_path = project_dir.join("Cargo.toml");
17
18    if !cargo_path.exists() {
19        return Err(anyhow::anyhow!(
20            "Cargo.toml not found in {}",
21            project_dir.display()
22        ));
23    }
24
25    let cargo_contents = fs::read_to_string(&cargo_path)
26        .with_context(|| format!("Failed to read {}", cargo_path.display()))?;
27
28    let project_name = project_name_from_cargo(&cargo_contents, &project_dir);
29    let project_name_pascal = to_pascal_case(&project_name);
30
31    update_cargo_toml(&cargo_path, &cargo_contents, args.feature)?;
32    update_env_example(&project_dir, args.feature, &project_name)?;
33
34    if args.feature == AddFeature::Auth {
35        scaffold_auth(&project_dir, &project_name, &project_name_pascal, args.force)?;
36        print_info("Auth scaffold created in src/auth/");
37        if args.wire {
38            wire_auth_in_main(&project_dir, &project_name)?;
39        } else {
40            print_info("Next steps: wire AuthModule + SimpleAuthProvider in main.rs");
41        }
42    }
43
44    if args.feature == AddFeature::Database && args.wire {
45        wire_database_in_main(&project_dir)?;
46    }
47
48    if args.feature == AddFeature::Openapi {
49        ensure_openapi_docs_file(&project_dir)?;
50        if args.wire {
51            wire_openapi_in_main(&project_dir)?;
52        } else {
53            print_info("Next steps: wire OpenAPI in main.rs");
54        }
55    }
56
57    print_success(&format!("Added {}", args.feature));
58    Ok(())
59}
60
61fn update_cargo_toml(path: &Path, contents: &str, feature: AddFeature) -> Result<()> {
62    let mut doc = contents.parse::<toml_edit::DocumentMut>()?;
63
64    let deps = doc["dependencies"].or_insert(toml_edit::Item::Table(toml_edit::Table::new()));
65
66    let tideway_item = deps
67        .as_table_mut()
68        .expect("dependencies should be a table")
69        .entry("tideway");
70
71    let feature_name = feature.to_string();
72
73    match tideway_item {
74        toml_edit::Entry::Vacant(entry) => {
75            let mut table = toml_edit::InlineTable::new();
76            table.get_or_insert("version", TIDEWAY_VERSION);
77            table.get_or_insert("features", array_value(&[feature_name.as_str()]));
78            entry.insert(toml_edit::Item::Value(toml_edit::Value::InlineTable(table)));
79        }
80        toml_edit::Entry::Occupied(mut entry) => {
81            if entry.get().is_str() {
82                let version = entry
83                    .get()
84                    .as_str()
85                    .unwrap_or(TIDEWAY_VERSION)
86                    .to_string();
87                let mut table = toml_edit::InlineTable::new();
88                table.get_or_insert("version", version);
89                table.get_or_insert("features", array_value(&[feature_name.as_str()]));
90                entry.insert(toml_edit::Item::Value(toml_edit::Value::InlineTable(table)));
91            } else {
92                let item = entry.get_mut();
93                let features = item["features"]
94                    .or_insert(toml_edit::Item::Value(toml_edit::Value::Array(toml_edit::Array::new())))
95                    .as_array_mut()
96                    .expect("features should be an array");
97
98                if !features.iter().any(|v| v.as_str() == Some(&feature_name)) {
99                    features.push(feature_name);
100                }
101            }
102        }
103    }
104
105    if feature == AddFeature::Database {
106        let deps_table = deps.as_table_mut().expect("dependencies should be a table");
107        deps_table
108            .entry("sea-orm")
109            .or_insert(toml_edit::Item::Value(toml_edit::Value::InlineTable(
110                {
111                    let mut table = toml_edit::InlineTable::new();
112                    table.get_or_insert("version", "1.1");
113                    table.get_or_insert(
114                        "features",
115                        array_value(&["sqlx-postgres", "runtime-tokio-rustls"]),
116                    );
117                    table
118                },
119            )));
120    }
121
122    if feature == AddFeature::Auth {
123        let deps_table = deps.as_table_mut().expect("dependencies should be a table");
124        deps_table
125            .entry("async-trait")
126            .or_insert(toml_edit::value("0.1"));
127        deps_table
128            .entry("serde")
129            .or_insert(toml_edit::Item::Value(toml_edit::Value::InlineTable(
130                {
131                    let mut table = toml_edit::InlineTable::new();
132                    table.get_or_insert("version", "1.0");
133                    table.get_or_insert("features", array_value(&["derive"]));
134                    table
135                },
136            )));
137        deps_table
138            .entry("serde_json")
139            .or_insert(toml_edit::value("1.0"));
140    }
141
142    write_file(path, &doc.to_string())
143        .with_context(|| format!("Failed to write {}", path.display()))?;
144    Ok(())
145}
146
147fn update_env_example(project_dir: &Path, feature: AddFeature, project_name: &str) -> Result<()> {
148    let env_path = project_dir.join(".env.example");
149    let mut lines = if env_path.exists() {
150        fs::read_to_string(&env_path)
151            .with_context(|| format!("Failed to read {}", env_path.display()))?
152            .lines()
153            .map(|line| line.to_string())
154            .collect::<Vec<_>>()
155    } else {
156        vec![
157            "# Server".to_string(),
158            "TIDEWAY_HOST=0.0.0.0".to_string(),
159            "TIDEWAY_PORT=8000".to_string(),
160            String::new(),
161        ]
162    };
163
164    let mut existing = BTreeSet::new();
165    for line in &lines {
166        if let Some((key, _)) = line.split_once('=') {
167            existing.insert(key.trim().to_string());
168        }
169    }
170
171    match feature {
172        AddFeature::Database => {
173            if !existing.contains("DATABASE_URL") {
174                lines.push("# Database".to_string());
175                lines.push(format!(
176                    "DATABASE_URL=postgres://postgres:postgres@localhost:5432/{}",
177                    project_name
178                ));
179                lines.push(String::new());
180            }
181        }
182        AddFeature::Auth => {
183            if !existing.contains("JWT_SECRET") {
184                lines.push("# Auth".to_string());
185                lines.push("JWT_SECRET=your-super-secret-jwt-key-change-in-production".to_string());
186                lines.push(String::new());
187            }
188        }
189        _ => {}
190    }
191
192    write_file(&env_path, &lines.join("\n"))
193        .with_context(|| format!("Failed to write {}", env_path.display()))?;
194    Ok(())
195}
196
197fn scaffold_auth(
198    project_dir: &Path,
199    project_name: &str,
200    project_name_pascal: &str,
201    force: bool,
202) -> Result<()> {
203    let context = BackendTemplateContext {
204        project_name: project_name.to_string(),
205        project_name_pascal: project_name_pascal.to_string(),
206        has_organizations: false,
207        database: "postgres".to_string(),
208        tideway_version: TIDEWAY_VERSION.to_string(),
209        tideway_features: vec!["auth".to_string()],
210        has_tideway_features: true,
211        has_auth_feature: true,
212        has_database_feature: false,
213        has_openapi_feature: false,
214        needs_arc: true,
215        has_config: false,
216    };
217
218    let engine = BackendTemplateEngine::new(context)?;
219    let auth_dir = project_dir.join("src").join("auth");
220
221    write_file_with_force(
222        &auth_dir.join("mod.rs"),
223        &engine.render("starter/src/auth/mod.rs")?,
224        force,
225    )?;
226    write_file_with_force(
227        &auth_dir.join("provider.rs"),
228        &engine.render("starter/src/auth/provider.rs")?,
229        force,
230    )?;
231    write_file_with_force(
232        &auth_dir.join("routes.rs"),
233        &engine.render("starter/src/auth/routes.rs")?,
234        force,
235    )?;
236
237    Ok(())
238}
239
240fn wire_auth_in_main(project_dir: &Path, project_name: &str) -> Result<()> {
241    let main_path = project_dir.join("src").join("main.rs");
242    if !main_path.exists() {
243        print_warning("src/main.rs not found; skipping auto-wiring");
244        return Ok(());
245    }
246
247    let mut contents = fs::read_to_string(&main_path)
248        .with_context(|| format!("Failed to read {}", main_path.display()))?;
249
250    if !contents.contains("mod auth;") {
251        if contents.contains("mod routes;") {
252            contents = contents.replace("mod routes;\n", "mod routes;\nmod auth;\n");
253        } else {
254            contents = format!("mod auth;\n{}", contents);
255        }
256    }
257
258    contents = ensure_use_line(
259        contents,
260        "use axum::Extension;",
261        "use tideway::auth",
262    );
263    contents = ensure_use_line(
264        contents,
265        "use crate::auth::{AuthModule, SimpleAuthProvider};",
266        "use tideway::auth",
267    );
268    contents = ensure_use_line(contents, "use std::sync::Arc;", "use tideway::");
269    contents = ensure_use_line(
270        contents,
271        "use tideway::auth::{JwtIssuer, JwtIssuerConfig};",
272        "use tideway::auth",
273    );
274
275    let has_jwt_secret = contents.contains("let jwt_secret");
276    let has_jwt_issuer = contents.contains("let jwt_issuer");
277    let has_auth_provider = contents.contains("auth_provider");
278    let has_auth_module = contents.contains("auth_module");
279
280    if has_jwt_secret && has_jwt_issuer {
281        if !has_auth_provider || !has_auth_module {
282            if let Some(insert_at) = contents.find("let jwt_issuer") {
283                let after = contents[insert_at..]
284                    .find(";\n")
285                    .map(|idx| insert_at + idx + 2)
286                    .unwrap_or(insert_at);
287                let insert = format!(
288                    "    let auth_provider = SimpleAuthProvider::from_secret(&jwt_secret);\n    let auth_module = AuthModule::new(jwt_issuer.clone());\n"
289                );
290                contents.insert_str(after, &insert);
291            }
292        }
293    } else {
294        let block = format!(
295            "    let jwt_secret = std::env::var(\"JWT_SECRET\").expect(\"JWT_SECRET is not set\");\n    let jwt_issuer = Arc::new(JwtIssuer::new(JwtIssuerConfig::with_secret(\n        &jwt_secret,\n        \"{}\",\n    )).expect(\"Failed to create JWT issuer\"));\n    let auth_provider = SimpleAuthProvider::from_secret(&jwt_secret);\n    let auth_module = AuthModule::new(jwt_issuer.clone());\n\n",
296            project_name
297        );
298        contents = insert_before_app_builder(contents, &block)?;
299    }
300
301    contents = insert_auth_into_app_builder(contents)?;
302
303    write_file(&main_path, &contents)
304        .with_context(|| format!("Failed to write {}", main_path.display()))?;
305    print_success("Wired auth into src/main.rs");
306    Ok(())
307}
308
309pub fn wire_database_in_main(project_dir: &Path) -> Result<()> {
310    let main_path = project_dir.join("src").join("main.rs");
311    if !main_path.exists() {
312        print_warning("src/main.rs not found; skipping auto-wiring");
313        return Ok(());
314    }
315
316    let mut contents = fs::read_to_string(&main_path)
317        .with_context(|| format!("Failed to read {}", main_path.display()))?;
318
319    if !contents.contains("async fn main") {
320        print_warning("main.rs is not async; skipping database wiring");
321        return Ok(());
322    }
323
324    contents = ensure_use_line(
325        contents,
326        "use tideway::{AppContext, SeaOrmPool};",
327        "use tideway::",
328    );
329    contents = ensure_use_line(contents, "use std::sync::Arc;", "use tideway::");
330
331    let has_database_block = contents.contains("DATABASE_URL")
332        || contents.contains("sea_orm::Database::connect")
333        || contents.contains("with_database");
334
335    if !has_database_block {
336        let block = "    let database_url = std::env::var(\"DATABASE_URL\").expect(\"DATABASE_URL is not set\");\n    let db = sea_orm::Database::connect(&database_url)\n        .await\n        .expect(\"Failed to connect to database\");\n\n";
337        contents = insert_before_app_builder(contents, block)?;
338    }
339
340    if !contents.contains(".with_database(") {
341        contents = insert_database_into_app_builder(contents)?;
342    }
343
344    write_file(&main_path, &contents)
345        .with_context(|| format!("Failed to write {}", main_path.display()))?;
346    print_success("Wired database into src/main.rs");
347    Ok(())
348}
349
/// Ensures `line` (typically a `use` or `mod` item) appears in `contents`.
///
/// If `line` already occurs verbatim, `contents` is returned unchanged. Otherwise `line` is
/// inserted on its own line right after the statement containing the first occurrence of
/// `anchor`. The statement is followed to its terminating `;` and the end of that line, so
/// multi-line groups such as `use foo::{\n    A,\n};` are never split.
///
/// When `anchor` is absent, or no `;` followed by a newline comes after it, `line` is placed
/// after any leading inner doc comments (`//!`) and single-line inner attributes (`#![...]`).
/// Rust requires those to precede all items; with none present, this is the top of the file.
fn ensure_use_line(mut contents: String, line: &str, anchor: &str) -> String {
    /// Byte offset just past the newline ending the statement that contains `pos`.
    fn after_statement(contents: &str, pos: usize) -> Option<usize> {
        let semicolon = pos + contents[pos..].find(';')?;
        let newline = semicolon + contents[semicolon..].find('\n')?;
        Some(newline + 1)
    }

    /// Byte offset just past the leading `//!` / `#![...]` lines (0 when there are none).
    fn after_inner_attributes(contents: &str) -> usize {
        contents
            .split_inclusive('\n')
            .take_while(|text_line| {
                let trimmed = text_line.trim_start();
                text_line.ends_with('\n')
                    && (trimmed.starts_with("//!") || trimmed.starts_with("#!["))
            })
            .map(str::len)
            .sum()
    }

    if contents.contains(line) {
        return contents;
    }

    let insert_at = contents
        .find(anchor)
        .and_then(|pos| after_statement(&contents, pos))
        .unwrap_or_else(|| after_inner_attributes(&contents));
    contents.insert_str(insert_at, &format!("{}\n", line));
    contents
}
366
367fn insert_before_app_builder(mut contents: String, block: &str) -> Result<String> {
368    if let Some(pos) = contents.find("let app = App::") {
369        contents.insert_str(pos, block);
370        Ok(contents)
371    } else {
372        print_warning("Could not find app builder; skipping auth wiring");
373        Ok(contents)
374    }
375}
376
377fn insert_auth_into_app_builder(mut contents: String) -> Result<String> {
378    if contents.contains("register_module(auth_module)") {
379        return Ok(contents);
380    }
381
382    if let Some(pos) = contents.find("let app = App::") {
383        let line_end = contents[pos..]
384            .find('\n')
385            .map(|idx| pos + idx)
386            .unwrap_or(contents.len());
387        let indent = contents[pos..]
388            .chars()
389            .take_while(|c| c.is_whitespace())
390            .collect::<String>();
391        let insert = format!(
392            "{}    .with_global_layer(Extension(auth_provider))\n{}    .register_module(auth_module)\n",
393            indent, indent
394        );
395        contents.insert_str(line_end + 1, &insert);
396        Ok(contents)
397    } else {
398        print_warning("Could not find app builder; skipping auth module registration");
399        Ok(contents)
400    }
401}
402
403fn insert_database_into_app_builder(mut contents: String) -> Result<String> {
404    if let Some(pos) = contents.find("let app = App::") {
405        let line_end = contents[pos..]
406            .find('\n')
407            .map(|idx| pos + idx)
408            .unwrap_or(contents.len());
409        let indent = contents[pos..]
410            .chars()
411            .take_while(|c| c.is_whitespace())
412            .collect::<String>();
413        let insert = format!(
414            "{}    .with_context(\n{}        AppContext::builder()\n{}            .with_database(Arc::new(SeaOrmPool::new(db, database_url)))\n{}            .build()\n{}    )\n",
415            indent, indent, indent, indent, indent
416        );
417        contents.insert_str(line_end + 1, &insert);
418        Ok(contents)
419    } else {
420        print_warning("Could not find app builder; skipping database wiring");
421        Ok(contents)
422    }
423}
424
425fn wire_openapi_in_main(project_dir: &Path) -> Result<()> {
426    let main_path = project_dir.join("src").join("main.rs");
427    if !main_path.exists() {
428        print_warning("src/main.rs not found; skipping auto-wiring");
429        return Ok(());
430    }
431
432    let mut contents = fs::read_to_string(&main_path)
433        .with_context(|| format!("Failed to read {}", main_path.display()))?;
434
435    if contents.contains("openapi::create_openapi_router") || contents.contains("openapi_merge_module") {
436        print_info("OpenAPI already appears wired in main.rs");
437        return Ok(());
438    }
439
440    contents = ensure_use_line(contents, "use tideway::ConfigBuilder;", "use tideway::");
441    if contents.contains("mod config;") {
442        contents = ensure_use_line(contents, "use crate::config::AppConfig;", "use tideway::");
443    }
444    contents = ensure_use_line(contents, "use tideway::openapi;", "use tideway::");
445
446    if !contents.contains("mod openapi_docs;") {
447        if contents.contains("mod routes;") {
448            contents = contents.replace("mod routes;\n", "mod routes;\nmod openapi_docs;\n");
449        } else {
450            contents = format!("mod openapi_docs;\n{}", contents);
451        }
452    }
453
454    let has_config_var = contents.contains("let config = ConfigBuilder::new()")
455        || contents.contains("let config = AppConfig::from_env()");
456    let config_available = contents.contains("ConfigBuilder::new()")
457        || contents.contains("AppConfig::from_env()");
458
459    if !has_config_var && config_available {
460        let config_block = "    let config = ConfigBuilder::new()\n        .from_env()\n        .build()\n        .expect(\"Invalid TIDEWAY_* config\");\n\n";
461        contents = insert_before_app_builder(contents, config_block)?;
462    }
463
464    if contents.contains("let config = AppConfig::from_env()") {
465        contents = insert_openapi_into_app_builder(contents, "config.tideway")?;
466    } else {
467        contents = insert_openapi_into_app_builder(contents, "config")?;
468    }
469
470    write_file(&main_path, &contents)
471        .with_context(|| format!("Failed to write {}", main_path.display()))?;
472    print_success("Wired OpenAPI into src/main.rs");
473    Ok(())
474}
475
476fn insert_openapi_into_app_builder(mut contents: String, config_ref: &str) -> Result<String> {
477    if contents.contains("create_openapi_router") {
478        return Ok(contents);
479    }
480
481    if let Some(pos) = contents.find("let app = App::") {
482        // Insert after app builder block to keep code readable.
483        if let Some(end_pos) = contents[pos..].find(";\n\n") {
484            let insert_at = pos + end_pos + 3;
485            let block = format!(
486                "\n    #[cfg(feature = \"openapi\")]\n    if {config_ref}.openapi.enabled {{\n        let openapi = tideway::openapi_merge_module!(openapi_docs, ApiDoc);\n        let openapi_router = tideway::openapi::create_openapi_router(openapi, &{config_ref}.openapi);\n        app = app.merge_router(openapi_router);\n    }}\n"
487            );
488            contents.insert_str(insert_at, &block);
489        } else {
490            print_warning("Could not find app builder termination; skipping OpenAPI wiring");
491        }
492        Ok(contents)
493    } else {
494        print_warning("Could not find app builder; skipping OpenAPI wiring");
495        Ok(contents)
496    }
497}
498
499fn ensure_openapi_docs_file(project_dir: &Path) -> Result<()> {
500    let docs_path = project_dir.join("src").join("openapi_docs.rs");
501    if docs_path.exists() {
502        return Ok(());
503    }
504
505    let contents = r#"#[cfg(feature = "openapi")]
506tideway::openapi_doc!(pub(crate) ApiDoc, paths());
507"#;
508
509    if let Some(parent) = docs_path.parent() {
510        ensure_dir(parent).with_context(|| format!("Failed to create {}", parent.display()))?;
511    }
512
513    write_file(&docs_path, &contents)
514        .with_context(|| format!("Failed to write {}", docs_path.display()))?;
515    print_success("Created src/openapi_docs.rs");
516    Ok(())
517}
518
519
520fn write_file_with_force(path: &Path, contents: &str, force: bool) -> Result<()> {
521    if path.exists() && !force {
522        print_warning(&format!(
523            "Skipping {} (use --force to overwrite)",
524            path.display()
525        ));
526        return Ok(());
527    }
528
529    if let Some(parent) = path.parent() {
530        ensure_dir(parent).with_context(|| format!("Failed to create {}", parent.display()))?;
531    }
532
533    write_file(path, contents)
534        .with_context(|| format!("Failed to write {}", path.display()))?;
535    Ok(())
536}
537
538fn project_name_from_cargo(contents: &str, project_dir: &Path) -> String {
539    let doc = contents
540        .parse::<toml_edit::DocumentMut>()
541        .ok()
542        .and_then(|doc| doc["package"]["name"].as_str().map(|s| s.to_string()));
543
544    doc.unwrap_or_else(|| {
545        project_dir
546            .file_name()
547            .and_then(|n| n.to_str())
548            .unwrap_or("my_app")
549            .to_string()
550    })
551    .replace('-', "_")
552}
553
/// Converts a `snake_case` name to `PascalCase`.
///
/// Empty segments (from leading, trailing or repeated underscores) are dropped. The first
/// character of each segment is upper-cased (which may expand to several characters,
/// e.g. `ß` -> `SS`), and the rest of the segment is copied unchanged.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut rest = segment.chars();
        if let Some(head) = rest.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
566
567pub fn array_value(values: &[&str]) -> toml_edit::Value {
568    let mut array = toml_edit::Array::new();
569    for value in values {
570        array.push(*value);
571    }
572    toml_edit::Value::Array(array)
573}