Skip to main content

tideway_cli/commands/
add.rs

1//! Add command - enable Tideway features and scaffold modules.
2
3use anyhow::{Context, Result};
4use std::collections::BTreeSet;
5use std::fs;
6use std::path::{Path, PathBuf};
7
8use crate::cli::{AddArgs, AddFeature};
9use crate::templates::{BackendTemplateContext, BackendTemplateEngine};
10use crate::{
11    ensure_dir, print_info, print_success, print_warning, write_file, TIDEWAY_VERSION,
12};
13
14pub fn run(args: AddArgs) -> Result<()> {
15    let project_dir = PathBuf::from(&args.path);
16    let cargo_path = project_dir.join("Cargo.toml");
17
18    if !cargo_path.exists() {
19        return Err(anyhow::anyhow!(
20            "Cargo.toml not found in {}",
21            project_dir.display()
22        ));
23    }
24
25    let cargo_contents = fs::read_to_string(&cargo_path)
26        .with_context(|| format!("Failed to read {}", cargo_path.display()))?;
27
28    let project_name = project_name_from_cargo(&cargo_contents, &project_dir);
29    let project_name_pascal = to_pascal_case(&project_name);
30
31    update_cargo_toml(&cargo_path, &cargo_contents, args.feature)?;
32    update_env_example(&project_dir, args.feature, &project_name)?;
33
34    if args.feature == AddFeature::Auth {
35        scaffold_auth(&project_dir, &project_name, &project_name_pascal, args.force)?;
36        print_info("Auth scaffold created in src/auth/");
37        if args.wire {
38            wire_auth_in_main(&project_dir, &project_name)?;
39        } else {
40            print_info("Next steps: wire AuthModule + SimpleAuthProvider in main.rs");
41        }
42    }
43
44    if args.feature == AddFeature::Database && args.wire {
45        wire_database_in_main(&project_dir)?;
46    }
47
48    if args.feature == AddFeature::Openapi {
49        ensure_openapi_docs_file(&project_dir)?;
50        if args.wire {
51            wire_openapi_in_main(&project_dir)?;
52        } else {
53            print_info("Next steps: wire OpenAPI in main.rs");
54        }
55    }
56
57    print_success(&format!("Added {}", args.feature));
58    Ok(())
59}
60
61fn update_cargo_toml(path: &Path, contents: &str, feature: AddFeature) -> Result<()> {
62    let mut doc = contents.parse::<toml_edit::DocumentMut>()?;
63
64    let deps = doc["dependencies"].or_insert(toml_edit::Item::Table(toml_edit::Table::new()));
65
66    let tideway_item = deps
67        .as_table_mut()
68        .expect("dependencies should be a table")
69        .entry("tideway");
70
71    let feature_name = feature.to_string();
72
73    match tideway_item {
74        toml_edit::Entry::Vacant(entry) => {
75            let mut table = toml_edit::InlineTable::new();
76            table.get_or_insert("version", TIDEWAY_VERSION);
77            table.get_or_insert("features", array_value(&[feature_name.as_str()]));
78            entry.insert(toml_edit::Item::Value(toml_edit::Value::InlineTable(table)));
79        }
80        toml_edit::Entry::Occupied(mut entry) => {
81            if entry.get().is_str() {
82                let version = entry
83                    .get()
84                    .as_str()
85                    .unwrap_or(TIDEWAY_VERSION)
86                    .to_string();
87                let mut table = toml_edit::InlineTable::new();
88                table.get_or_insert("version", version);
89                table.get_or_insert("features", array_value(&[feature_name.as_str()]));
90                entry.insert(toml_edit::Item::Value(toml_edit::Value::InlineTable(table)));
91            } else {
92                let item = entry.get_mut();
93                let features = item["features"]
94                    .or_insert(toml_edit::Item::Value(toml_edit::Value::Array(toml_edit::Array::new())))
95                    .as_array_mut()
96                    .expect("features should be an array");
97
98                if !features.iter().any(|v| v.as_str() == Some(&feature_name)) {
99                    features.push(feature_name);
100                }
101            }
102        }
103    }
104
105    if feature == AddFeature::Database {
106        let deps_table = deps.as_table_mut().expect("dependencies should be a table");
107        deps_table
108            .entry("sea-orm")
109            .or_insert(toml_edit::Item::Value(toml_edit::Value::InlineTable(
110                {
111                    let mut table = toml_edit::InlineTable::new();
112                    table.get_or_insert("version", "1.1");
113                    table.get_or_insert(
114                        "features",
115                        array_value(&["sqlx-postgres", "runtime-tokio-rustls"]),
116                    );
117                    table
118                },
119            )));
120    }
121
122    if feature == AddFeature::Auth {
123        let deps_table = deps.as_table_mut().expect("dependencies should be a table");
124        deps_table
125            .entry("async-trait")
126            .or_insert(toml_edit::value("0.1"));
127        deps_table
128            .entry("serde")
129            .or_insert(toml_edit::Item::Value(toml_edit::Value::InlineTable(
130                {
131                    let mut table = toml_edit::InlineTable::new();
132                    table.get_or_insert("version", "1.0");
133                    table.get_or_insert("features", array_value(&["derive"]));
134                    table
135                },
136            )));
137        deps_table
138            .entry("serde_json")
139            .or_insert(toml_edit::value("1.0"));
140    }
141
142    write_file(path, &doc.to_string())
143        .with_context(|| format!("Failed to write {}", path.display()))?;
144    Ok(())
145}
146
147fn update_env_example(project_dir: &Path, feature: AddFeature, project_name: &str) -> Result<()> {
148    let env_path = project_dir.join(".env.example");
149    let mut lines = if env_path.exists() {
150        fs::read_to_string(&env_path)
151            .with_context(|| format!("Failed to read {}", env_path.display()))?
152            .lines()
153            .map(|line| line.to_string())
154            .collect::<Vec<_>>()
155    } else {
156        vec![
157            "# Server".to_string(),
158            "TIDEWAY_HOST=0.0.0.0".to_string(),
159            "TIDEWAY_PORT=8000".to_string(),
160            String::new(),
161        ]
162    };
163
164    let mut existing = BTreeSet::new();
165    for line in &lines {
166        if let Some((key, _)) = line.split_once('=') {
167            existing.insert(key.trim().to_string());
168        }
169    }
170
171    match feature {
172        AddFeature::Database => {
173            if !existing.contains("DATABASE_URL") {
174                lines.push("# Database".to_string());
175                lines.push(format!(
176                    "DATABASE_URL=postgres://postgres:postgres@localhost:5432/{}",
177                    project_name
178                ));
179                lines.push(String::new());
180            }
181        }
182        AddFeature::Auth => {
183            if !existing.contains("JWT_SECRET") {
184                lines.push("# Auth".to_string());
185                lines.push("JWT_SECRET=your-super-secret-jwt-key-change-in-production".to_string());
186                lines.push(String::new());
187            }
188        }
189        _ => {}
190    }
191
192    write_file(&env_path, &lines.join("\n"))
193        .with_context(|| format!("Failed to write {}", env_path.display()))?;
194    Ok(())
195}
196
197fn scaffold_auth(
198    project_dir: &Path,
199    project_name: &str,
200    project_name_pascal: &str,
201    force: bool,
202) -> Result<()> {
203    let context = BackendTemplateContext {
204        project_name: project_name.to_string(),
205        project_name_pascal: project_name_pascal.to_string(),
206        has_organizations: false,
207        database: "postgres".to_string(),
208        tideway_version: TIDEWAY_VERSION.to_string(),
209        tideway_features: vec!["auth".to_string()],
210        has_tideway_features: true,
211        has_auth_feature: true,
212        has_database_feature: false,
213        needs_arc: true,
214        has_config: false,
215    };
216
217    let engine = BackendTemplateEngine::new(context)?;
218    let auth_dir = project_dir.join("src").join("auth");
219
220    write_file_with_force(
221        &auth_dir.join("mod.rs"),
222        &engine.render("starter/src/auth/mod.rs")?,
223        force,
224    )?;
225    write_file_with_force(
226        &auth_dir.join("provider.rs"),
227        &engine.render("starter/src/auth/provider.rs")?,
228        force,
229    )?;
230    write_file_with_force(
231        &auth_dir.join("routes.rs"),
232        &engine.render("starter/src/auth/routes.rs")?,
233        force,
234    )?;
235
236    Ok(())
237}
238
239fn wire_auth_in_main(project_dir: &Path, project_name: &str) -> Result<()> {
240    let main_path = project_dir.join("src").join("main.rs");
241    if !main_path.exists() {
242        print_warning("src/main.rs not found; skipping auto-wiring");
243        return Ok(());
244    }
245
246    let mut contents = fs::read_to_string(&main_path)
247        .with_context(|| format!("Failed to read {}", main_path.display()))?;
248
249    if !contents.contains("mod auth;") {
250        if contents.contains("mod routes;") {
251            contents = contents.replace("mod routes;\n", "mod routes;\nmod auth;\n");
252        } else {
253            contents = format!("mod auth;\n{}", contents);
254        }
255    }
256
257    contents = ensure_use_line(
258        contents,
259        "use axum::Extension;",
260        "use tideway::auth",
261    );
262    contents = ensure_use_line(
263        contents,
264        "use crate::auth::{AuthModule, SimpleAuthProvider};",
265        "use tideway::auth",
266    );
267    contents = ensure_use_line(contents, "use std::sync::Arc;", "use tideway::");
268    contents = ensure_use_line(
269        contents,
270        "use tideway::auth::{JwtIssuer, JwtIssuerConfig};",
271        "use tideway::auth",
272    );
273
274    let has_jwt_secret = contents.contains("let jwt_secret");
275    let has_jwt_issuer = contents.contains("let jwt_issuer");
276    let has_auth_provider = contents.contains("auth_provider");
277    let has_auth_module = contents.contains("auth_module");
278
279    if has_jwt_secret && has_jwt_issuer {
280        if !has_auth_provider || !has_auth_module {
281            if let Some(insert_at) = contents.find("let jwt_issuer") {
282                let after = contents[insert_at..]
283                    .find(";\n")
284                    .map(|idx| insert_at + idx + 2)
285                    .unwrap_or(insert_at);
286                let insert = format!(
287                    "    let auth_provider = SimpleAuthProvider::from_secret(&jwt_secret);\n    let auth_module = AuthModule::new(jwt_issuer.clone());\n"
288                );
289                contents.insert_str(after, &insert);
290            }
291        }
292    } else {
293        let block = format!(
294            "    let jwt_secret = std::env::var(\"JWT_SECRET\").expect(\"JWT_SECRET is not set\");\n    let jwt_issuer = Arc::new(JwtIssuer::new(JwtIssuerConfig::with_secret(\n        &jwt_secret,\n        \"{}\",\n    )).expect(\"Failed to create JWT issuer\"));\n    let auth_provider = SimpleAuthProvider::from_secret(&jwt_secret);\n    let auth_module = AuthModule::new(jwt_issuer.clone());\n\n",
295            project_name
296        );
297        contents = insert_before_app_builder(contents, &block)?;
298    }
299
300    contents = insert_auth_into_app_builder(contents)?;
301
302    write_file(&main_path, &contents)
303        .with_context(|| format!("Failed to write {}", main_path.display()))?;
304    print_success("Wired auth into src/main.rs");
305    Ok(())
306}
307
308pub fn wire_database_in_main(project_dir: &Path) -> Result<()> {
309    let main_path = project_dir.join("src").join("main.rs");
310    if !main_path.exists() {
311        print_warning("src/main.rs not found; skipping auto-wiring");
312        return Ok(());
313    }
314
315    let mut contents = fs::read_to_string(&main_path)
316        .with_context(|| format!("Failed to read {}", main_path.display()))?;
317
318    if !contents.contains("async fn main") {
319        print_warning("main.rs is not async; skipping database wiring");
320        return Ok(());
321    }
322
323    contents = ensure_use_line(
324        contents,
325        "use tideway::{AppContext, SeaOrmPool};",
326        "use tideway::",
327    );
328    contents = ensure_use_line(contents, "use std::sync::Arc;", "use tideway::");
329
330    let has_database_block = contents.contains("DATABASE_URL")
331        || contents.contains("sea_orm::Database::connect")
332        || contents.contains("with_database");
333
334    if !has_database_block {
335        let block = "    let database_url = std::env::var(\"DATABASE_URL\").expect(\"DATABASE_URL is not set\");\n    let db = sea_orm::Database::connect(&database_url)\n        .await\n        .expect(\"Failed to connect to database\");\n\n";
336        contents = insert_before_app_builder(contents, block)?;
337    }
338
339    if !contents.contains(".with_database(") {
340        contents = insert_database_into_app_builder(contents)?;
341    }
342
343    write_file(&main_path, &contents)
344        .with_context(|| format!("Failed to write {}", main_path.display()))?;
345    print_success("Wired database into src/main.rs");
346    Ok(())
347}
348
/// Makes sure `line` appears in `contents`, returning the updated text.
///
/// Nothing changes when `line` is already present. Otherwise `line` is placed
/// on its own line directly after the statement that contains the first
/// occurrence of `anchor`: the insertion point is the first newline following
/// the next `;`, so a multi-line `use foo::{ ... };` is never split. When no
/// `;` follows the anchor, the first newline after the anchor is used. When
/// the anchor is missing, or nothing terminates its line, `line` is prepended
/// to the file instead.
fn ensure_use_line(mut contents: String, line: &str, anchor: &str) -> String {
    if contents.contains(line) {
        return contents;
    }

    let insert_at = contents.find(anchor).and_then(|pos| {
        let rest = &contents[pos..];
        // Skip to the end of the statement so we never land inside `{ ... }`.
        let stmt_end = rest.find(';').unwrap_or(0);
        rest[stmt_end..]
            .find('\n')
            .map(|newline| pos + stmt_end + newline + 1)
    });

    match insert_at {
        Some(at) => {
            contents.insert_str(at, &format!("{line}\n"));
            contents
        }
        None => format!("{line}\n{contents}"),
    }
}
365
366fn insert_before_app_builder(mut contents: String, block: &str) -> Result<String> {
367    if let Some(pos) = contents.find("let app = App::") {
368        contents.insert_str(pos, block);
369        Ok(contents)
370    } else {
371        print_warning("Could not find app builder; skipping auth wiring");
372        Ok(contents)
373    }
374}
375
376fn insert_auth_into_app_builder(mut contents: String) -> Result<String> {
377    if contents.contains("register_module(auth_module)") {
378        return Ok(contents);
379    }
380
381    if let Some(pos) = contents.find("let app = App::") {
382        let line_end = contents[pos..]
383            .find('\n')
384            .map(|idx| pos + idx)
385            .unwrap_or(contents.len());
386        let indent = contents[pos..]
387            .chars()
388            .take_while(|c| c.is_whitespace())
389            .collect::<String>();
390        let insert = format!(
391            "{}    .with_global_layer(Extension(auth_provider))\n{}    .register_module(auth_module)\n",
392            indent, indent
393        );
394        contents.insert_str(line_end + 1, &insert);
395        Ok(contents)
396    } else {
397        print_warning("Could not find app builder; skipping auth module registration");
398        Ok(contents)
399    }
400}
401
402fn insert_database_into_app_builder(mut contents: String) -> Result<String> {
403    if let Some(pos) = contents.find("let app = App::") {
404        let line_end = contents[pos..]
405            .find('\n')
406            .map(|idx| pos + idx)
407            .unwrap_or(contents.len());
408        let indent = contents[pos..]
409            .chars()
410            .take_while(|c| c.is_whitespace())
411            .collect::<String>();
412        let insert = format!(
413            "{}    .with_context(\n{}        AppContext::builder()\n{}            .with_database(Arc::new(SeaOrmPool::new(db, database_url)))\n{}            .build()\n{}    )\n",
414            indent, indent, indent, indent, indent
415        );
416        contents.insert_str(line_end + 1, &insert);
417        Ok(contents)
418    } else {
419        print_warning("Could not find app builder; skipping database wiring");
420        Ok(contents)
421    }
422}
423
424fn wire_openapi_in_main(project_dir: &Path) -> Result<()> {
425    let main_path = project_dir.join("src").join("main.rs");
426    if !main_path.exists() {
427        print_warning("src/main.rs not found; skipping auto-wiring");
428        return Ok(());
429    }
430
431    let mut contents = fs::read_to_string(&main_path)
432        .with_context(|| format!("Failed to read {}", main_path.display()))?;
433
434    if contents.contains("openapi::create_openapi_router") || contents.contains("openapi_merge_module") {
435        print_info("OpenAPI already appears wired in main.rs");
436        return Ok(());
437    }
438
439    contents = ensure_use_line(contents, "use tideway::ConfigBuilder;", "use tideway::");
440    if contents.contains("mod config;") {
441        contents = ensure_use_line(contents, "use crate::config::AppConfig;", "use tideway::");
442    }
443    contents = ensure_use_line(contents, "use tideway::openapi;", "use tideway::");
444
445    if !contents.contains("mod openapi_docs;") {
446        if contents.contains("mod routes;") {
447            contents = contents.replace("mod routes;\n", "mod routes;\nmod openapi_docs;\n");
448        } else {
449            contents = format!("mod openapi_docs;\n{}", contents);
450        }
451    }
452
453    let has_config_var = contents.contains("let config = ConfigBuilder::new()")
454        || contents.contains("let config = AppConfig::from_env()");
455    let config_available = contents.contains("ConfigBuilder::new()")
456        || contents.contains("AppConfig::from_env()");
457
458    if !has_config_var && config_available {
459        let config_block = "    let config = ConfigBuilder::new()\n        .from_env()\n        .build()\n        .expect(\"Invalid TIDEWAY_* config\");\n\n";
460        contents = insert_before_app_builder(contents, config_block)?;
461    }
462
463    if contents.contains("let config = AppConfig::from_env()") {
464        contents = insert_openapi_into_app_builder(contents, "config.tideway")?;
465    } else {
466        contents = insert_openapi_into_app_builder(contents, "config")?;
467    }
468
469    write_file(&main_path, &contents)
470        .with_context(|| format!("Failed to write {}", main_path.display()))?;
471    print_success("Wired OpenAPI into src/main.rs");
472    Ok(())
473}
474
475fn insert_openapi_into_app_builder(mut contents: String, config_ref: &str) -> Result<String> {
476    if contents.contains("create_openapi_router") {
477        return Ok(contents);
478    }
479
480    if let Some(pos) = contents.find("let app = App::") {
481        // Insert after app builder block to keep code readable.
482        if let Some(end_pos) = contents[pos..].find(";\n\n") {
483            let insert_at = pos + end_pos + 3;
484            let block = format!(
485                "\n    #[cfg(feature = \"openapi\")]\n    if {config_ref}.openapi.enabled {{\n        let openapi = tideway::openapi_merge_module!(openapi_docs, ApiDoc);\n        let openapi_router = tideway::openapi::create_openapi_router(openapi, &{config_ref}.openapi);\n        app = app.merge_router(openapi_router);\n    }}\n"
486            );
487            contents.insert_str(insert_at, &block);
488        } else {
489            print_warning("Could not find app builder termination; skipping OpenAPI wiring");
490        }
491        Ok(contents)
492    } else {
493        print_warning("Could not find app builder; skipping OpenAPI wiring");
494        Ok(contents)
495    }
496}
497
498fn ensure_openapi_docs_file(project_dir: &Path) -> Result<()> {
499    let docs_path = project_dir.join("src").join("openapi_docs.rs");
500    if docs_path.exists() {
501        return Ok(());
502    }
503
504    let contents = r#"#[cfg(feature = "openapi")]
505tideway::openapi_doc!(pub(crate) ApiDoc, paths());
506"#;
507
508    if let Some(parent) = docs_path.parent() {
509        ensure_dir(parent).with_context(|| format!("Failed to create {}", parent.display()))?;
510    }
511
512    write_file(&docs_path, &contents)
513        .with_context(|| format!("Failed to write {}", docs_path.display()))?;
514    print_success("Created src/openapi_docs.rs");
515    Ok(())
516}
517
518
519fn write_file_with_force(path: &Path, contents: &str, force: bool) -> Result<()> {
520    if path.exists() && !force {
521        print_warning(&format!(
522            "Skipping {} (use --force to overwrite)",
523            path.display()
524        ));
525        return Ok(());
526    }
527
528    if let Some(parent) = path.parent() {
529        ensure_dir(parent).with_context(|| format!("Failed to create {}", parent.display()))?;
530    }
531
532    write_file(path, contents)
533        .with_context(|| format!("Failed to write {}", path.display()))?;
534    Ok(())
535}
536
537fn project_name_from_cargo(contents: &str, project_dir: &Path) -> String {
538    let doc = contents
539        .parse::<toml_edit::DocumentMut>()
540        .ok()
541        .and_then(|doc| doc["package"]["name"].as_str().map(|s| s.to_string()));
542
543    doc.unwrap_or_else(|| {
544        project_dir
545            .file_name()
546            .and_then(|n| n.to_str())
547            .unwrap_or("my_app")
548            .to_string()
549    })
550    .replace('-', "_")
551}
552
/// Converts a `snake_case` identifier to `PascalCase`.
///
/// The input is split on `_`; empty segments are dropped. The first character
/// of every remaining segment is uppercased and the rest is copied as is.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut rest = segment.chars();
        // An empty segment has no first character and contributes nothing.
        if let Some(initial) = rest.next() {
            pascal.extend(initial.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
565
566pub fn array_value(values: &[&str]) -> toml_edit::Value {
567    let mut array = toml_edit::Array::new();
568    for value in values {
569        array.push(*value);
570    }
571    toml_edit::Value::Array(array)
572}