Skip to main content

tideway_cli/commands/
add.rs

1//! Add command - enable Tideway features and scaffold modules.
2
3use anyhow::{Context, Result};
4use std::collections::BTreeSet;
5use std::fs;
6use std::path::{Path, PathBuf};
7
8use crate::cli::{AddArgs, AddFeature};
9use crate::commands::app_builder::{
10    find_app_builder_end_insert_at, find_app_builder_marker_range, find_app_builder_start,
11    find_app_builder_var_name, find_unmarked_app_builder_statement_range,
12    insert_snippet_into_builder_block,
13};
14use crate::commands::file_ops::{ensure_module_decl, to_pascal_case, write_file_with_force};
15use crate::commands::messaging::GREENFIELD_NEW_APP_FIRST;
16use crate::templates::{BackendTemplateContext, BackendTemplateEngine};
17use crate::{
18    TIDEWAY_VERSION, ensure_dir, error_contract, print_info, print_success, print_warning,
19    write_file,
20};
21
22pub fn run(args: AddArgs) -> Result<()> {
23    let project_dir = PathBuf::from(&args.path);
24    let cargo_path = project_dir.join("Cargo.toml");
25
26    if !cargo_path.exists() {
27        return Err(anyhow::anyhow!(error_contract(
28            &format!("Cargo.toml not found in {}", project_dir.display()),
29            "Run this command inside a Rust project root.",
30            GREENFIELD_NEW_APP_FIRST,
31        )));
32    }
33
34    let cargo_contents = fs::read_to_string(&cargo_path)
35        .with_context(|| format!("Failed to read {}", cargo_path.display()))?;
36
37    let project_name = project_name_from_cargo(&cargo_contents, &project_dir);
38    let project_name_pascal = to_pascal_case(&project_name);
39
40    update_cargo_toml(&cargo_path, &cargo_contents, args.feature)?;
41    update_env_example(&project_dir, args.feature, &project_name)?;
42
43    if args.feature == AddFeature::Auth {
44        scaffold_auth(
45            &project_dir,
46            &project_name,
47            &project_name_pascal,
48            args.force,
49        )?;
50        print_info("Auth scaffold created in src/auth/");
51        if args.wire {
52            wire_auth_in_main(&project_dir, &project_name)?;
53        } else {
54            print_info("Next steps: wire AuthModule + SimpleAuthProvider in main.rs");
55        }
56    }
57
58    if args.feature == AddFeature::Database && args.wire {
59        wire_database_in_main(&project_dir)?;
60    }
61
62    if args.feature == AddFeature::Openapi {
63        ensure_openapi_docs_file(&project_dir)?;
64        if args.wire {
65            wire_openapi_in_main(&project_dir)?;
66        } else {
67            print_info("Next steps: wire OpenAPI in main.rs");
68        }
69    }
70
71    print_success(&format!("Added {}", args.feature));
72    Ok(())
73}
74
75fn update_cargo_toml(path: &Path, contents: &str, feature: AddFeature) -> Result<()> {
76    let mut doc = contents.parse::<toml_edit::DocumentMut>()?;
77
78    let deps = doc["dependencies"].or_insert(toml_edit::Item::Table(toml_edit::Table::new()));
79
80    let tideway_item = deps
81        .as_table_mut()
82        .expect("dependencies should be a table")
83        .entry("tideway");
84
85    let feature_name = feature.to_string();
86
87    match tideway_item {
88        toml_edit::Entry::Vacant(entry) => {
89            let mut table = toml_edit::InlineTable::new();
90            table.get_or_insert("version", TIDEWAY_VERSION);
91            table.get_or_insert("features", array_value(&[feature_name.as_str()]));
92            entry.insert(toml_edit::Item::Value(toml_edit::Value::InlineTable(table)));
93        }
94        toml_edit::Entry::Occupied(mut entry) => {
95            if entry.get().is_str() {
96                let version = entry.get().as_str().unwrap_or(TIDEWAY_VERSION).to_string();
97                let mut table = toml_edit::InlineTable::new();
98                table.get_or_insert("version", version);
99                table.get_or_insert("features", array_value(&[feature_name.as_str()]));
100                entry.insert(toml_edit::Item::Value(toml_edit::Value::InlineTable(table)));
101            } else {
102                let item = entry.get_mut();
103                let features = item["features"]
104                    .or_insert(toml_edit::Item::Value(toml_edit::Value::Array(
105                        toml_edit::Array::new(),
106                    )))
107                    .as_array_mut()
108                    .expect("features should be an array");
109
110                if !features.iter().any(|v| v.as_str() == Some(&feature_name)) {
111                    features.push(feature_name);
112                }
113            }
114        }
115    }
116
117    if feature == AddFeature::Database {
118        let deps_table = deps.as_table_mut().expect("dependencies should be a table");
119        deps_table
120            .entry("sea-orm")
121            .or_insert(toml_edit::Item::Value(toml_edit::Value::InlineTable({
122                let mut table = toml_edit::InlineTable::new();
123                table.get_or_insert("version", "1.1");
124                table.get_or_insert(
125                    "features",
126                    array_value(&["sqlx-postgres", "runtime-tokio-rustls"]),
127                );
128                table
129            })));
130    }
131
132    if feature == AddFeature::Auth {
133        let deps_table = deps.as_table_mut().expect("dependencies should be a table");
134        deps_table
135            .entry("async-trait")
136            .or_insert(toml_edit::value("0.1"));
137        deps_table
138            .entry("serde")
139            .or_insert(toml_edit::Item::Value(toml_edit::Value::InlineTable({
140                let mut table = toml_edit::InlineTable::new();
141                table.get_or_insert("version", "1.0");
142                table.get_or_insert("features", array_value(&["derive"]));
143                table
144            })));
145        deps_table
146            .entry("serde_json")
147            .or_insert(toml_edit::value("1.0"));
148    }
149
150    write_file(path, &doc.to_string())
151        .with_context(|| format!("Failed to write {}", path.display()))?;
152    Ok(())
153}
154
155fn update_env_example(project_dir: &Path, feature: AddFeature, project_name: &str) -> Result<()> {
156    let env_path = project_dir.join(".env.example");
157    let mut lines = if env_path.exists() {
158        fs::read_to_string(&env_path)
159            .with_context(|| format!("Failed to read {}", env_path.display()))?
160            .lines()
161            .map(|line| line.to_string())
162            .collect::<Vec<_>>()
163    } else {
164        vec![
165            "# Server".to_string(),
166            "TIDEWAY_HOST=0.0.0.0".to_string(),
167            "TIDEWAY_PORT=8000".to_string(),
168            String::new(),
169        ]
170    };
171
172    let mut existing = BTreeSet::new();
173    for line in &lines {
174        if let Some((key, _)) = line.split_once('=') {
175            existing.insert(key.trim().to_string());
176        }
177    }
178
179    match feature {
180        AddFeature::Database => {
181            if !existing.contains("DATABASE_URL") {
182                lines.push("# Database".to_string());
183                lines.push(format!(
184                    "DATABASE_URL=postgres://postgres:postgres@localhost:5432/{}",
185                    project_name
186                ));
187                lines.push(String::new());
188            }
189        }
190        AddFeature::Auth => {
191            if !existing.contains("JWT_SECRET") {
192                lines.push("# Auth".to_string());
193                lines.push("JWT_SECRET=your-super-secret-jwt-key-change-in-production".to_string());
194                lines.push(String::new());
195            }
196        }
197        _ => {}
198    }
199
200    write_file(&env_path, &lines.join("\n"))
201        .with_context(|| format!("Failed to write {}", env_path.display()))?;
202    Ok(())
203}
204
205fn scaffold_auth(
206    project_dir: &Path,
207    project_name: &str,
208    project_name_pascal: &str,
209    force: bool,
210) -> Result<()> {
211    let context = BackendTemplateContext {
212        project_name: project_name.to_string(),
213        project_name_pascal: project_name_pascal.to_string(),
214        has_organizations: false,
215        database: "postgres".to_string(),
216        database_url: format!(
217            "postgres://postgres:postgres@localhost:5432/{}",
218            project_name
219        ),
220        is_sqlite_database: false,
221        tideway_version: TIDEWAY_VERSION.to_string(),
222        tideway_features: vec!["auth".to_string()],
223        has_tideway_features: true,
224        has_auth_feature: true,
225        has_database_feature: false,
226        has_openapi_feature: false,
227        needs_arc: true,
228        has_config: false,
229    };
230
231    let engine = BackendTemplateEngine::new(context)?;
232    let auth_dir = project_dir.join("src").join("auth");
233
234    write_file_with_force(
235        &auth_dir.join("mod.rs"),
236        &engine.render("starter/src/auth/mod.rs")?,
237        force,
238    )?;
239    write_file_with_force(
240        &auth_dir.join("provider.rs"),
241        &engine.render("starter/src/auth/provider.rs")?,
242        force,
243    )?;
244    write_file_with_force(
245        &auth_dir.join("routes.rs"),
246        &engine.render("starter/src/auth/routes.rs")?,
247        force,
248    )?;
249
250    Ok(())
251}
252
253fn wire_auth_in_main(project_dir: &Path, project_name: &str) -> Result<()> {
254    let main_path = project_dir.join("src").join("main.rs");
255    if !main_path.exists() {
256        print_warning("src/main.rs not found; skipping auto-wiring");
257        return Ok(());
258    }
259
260    let mut contents = fs::read_to_string(&main_path)
261        .with_context(|| format!("Failed to read {}", main_path.display()))?;
262
263    contents = ensure_module_decl(&contents, "auth");
264
265    contents = ensure_use_line(contents, "use axum::Extension;", "use tideway::auth");
266    contents = ensure_use_line(
267        contents,
268        "use crate::auth::{AuthModule, SimpleAuthProvider};",
269        "use tideway::auth",
270    );
271    contents = ensure_use_line(contents, "use std::sync::Arc;", "use tideway::");
272    contents = ensure_use_line(
273        contents,
274        "use tideway::auth::{JwtIssuer, JwtIssuerConfig};",
275        "use tideway::auth",
276    );
277
278    let has_jwt_secret = contents.contains("let jwt_secret");
279    let has_jwt_issuer = contents.contains("let jwt_issuer");
280    let has_auth_provider = contents.contains("auth_provider");
281    let has_auth_module = contents.contains("auth_module");
282
283    if has_jwt_secret && has_jwt_issuer {
284        if !has_auth_provider || !has_auth_module {
285            if let Some(insert_at) = contents.find("let jwt_issuer") {
286                let after = contents[insert_at..]
287                    .find(";\n")
288                    .map(|idx| insert_at + idx + 2)
289                    .unwrap_or(insert_at);
290                let insert = "    let auth_provider = SimpleAuthProvider::from_secret(&jwt_secret);\n    let auth_module = AuthModule::new(jwt_issuer.clone());\n".to_string();
291                contents.insert_str(after, &insert);
292            }
293        }
294    } else {
295        let block = format!(
296            "    let jwt_secret = std::env::var(\"JWT_SECRET\").expect(\"JWT_SECRET is not set\");\n    let jwt_issuer = Arc::new(JwtIssuer::new(JwtIssuerConfig::with_secret(\n        &jwt_secret,\n        \"{}\",\n    )).expect(\"Failed to create JWT issuer\"));\n    let auth_provider = SimpleAuthProvider::from_secret(&jwt_secret);\n    let auth_module = AuthModule::new(jwt_issuer.clone());\n\n",
297            project_name
298        );
299        contents = insert_before_app_builder(contents, &block)?;
300    }
301
302    contents = insert_auth_into_app_builder(contents)?;
303
304    write_file(&main_path, &contents)
305        .with_context(|| format!("Failed to write {}", main_path.display()))?;
306    print_success("Wired auth into src/main.rs");
307    Ok(())
308}
309
310pub fn wire_database_in_main(project_dir: &Path) -> Result<()> {
311    let main_path = project_dir.join("src").join("main.rs");
312    if !main_path.exists() {
313        print_warning("src/main.rs not found; skipping auto-wiring");
314        return Ok(());
315    }
316
317    let mut contents = fs::read_to_string(&main_path)
318        .with_context(|| format!("Failed to read {}", main_path.display()))?;
319
320    if !contents.contains("async fn main") {
321        print_warning("main.rs is not async; skipping database wiring");
322        return Ok(());
323    }
324
325    contents = ensure_use_line(
326        contents,
327        "use tideway::{AppContext, SeaOrmPool};",
328        "use tideway::",
329    );
330    contents = ensure_use_line(contents, "use std::sync::Arc;", "use tideway::");
331
332    let has_database_block = contents.contains("DATABASE_URL")
333        || contents.contains("sea_orm::Database::connect")
334        || contents.contains("with_database");
335
336    if !has_database_block {
337        let block = "    let database_url = std::env::var(\"DATABASE_URL\").expect(\"DATABASE_URL is not set\");\n    let db = sea_orm::Database::connect(&database_url)\n        .await\n        .expect(\"Failed to connect to database\");\n\n";
338        contents = insert_before_app_builder(contents, block)?;
339    }
340
341    if !contents.contains(".with_database(") {
342        contents = insert_database_into_app_builder(contents)?;
343    }
344
345    write_file(&main_path, &contents)
346        .with_context(|| format!("Failed to write {}", main_path.display()))?;
347    print_success("Wired database into src/main.rs");
348    Ok(())
349}
350
/// Ensures the `use` statement `line` is present in `contents`, returning the updated source.
///
/// * If `line` already appears verbatim, `contents` is returned unchanged.
/// * Otherwise, if `anchor` occurs, `line` is inserted on its own line right after the
///   statement containing the first occurrence. When that occurrence opens a multi-line
///   `use path::{ ... };` group, insertion happens after the group's closing `;` rather than
///   inside the braces.
/// * If there is no usable anchor, `line` goes at the top of the file, but after any leading
///   inner doc comments (`//!`) and single-line inner attributes (`#![...]`), which must precede
///   every item. Multi-line inner attributes are not recognised.
fn ensure_use_line(mut contents: String, line: &str, anchor: &str) -> String {
    if contents.contains(line) {
        return contents;
    }

    if let Some(insert_at) = contents
        .find(anchor)
        .and_then(|pos| use_statement_insert_at(&contents, pos))
    {
        contents.insert_str(insert_at, &format!("{line}\n"));
        return contents;
    }

    let insert_at = inner_header_end(&contents);
    // If the header's last line has no trailing newline, start a fresh line first.
    let text = if insert_at == 0 || contents[..insert_at].ends_with('\n') {
        format!("{line}\n")
    } else {
        format!("\n{line}\n")
    };
    contents.insert_str(insert_at, &text);
    contents
}

/// Offset of the start of the line following the statement that contains `anchor_pos`, or
/// `None` if that statement's final line has no terminating newline.
fn use_statement_insert_at(contents: &str, anchor_pos: usize) -> Option<usize> {
    let line_end = anchor_pos + contents[anchor_pos..].find('\n')?;
    let anchor_line = &contents[anchor_pos..line_end];
    let statement_end = if !anchor_line.contains(';') && anchor_line.contains('{') {
        // Multi-line `use path::{ ... };` group: continue to the line with the closing `;`.
        let semi = anchor_pos + contents[anchor_pos..].find(';')?;
        semi + contents[semi..].find('\n')?
    } else {
        line_end
    };
    Some(statement_end + 1)
}

/// Offset just past the last leading `//!` / `#![...]` line, skipping over blank lines and
/// ordinary `//` comments between them; `0` when the file has no such header.
fn inner_header_end(contents: &str) -> usize {
    let mut offset = 0;
    let mut header_end = 0;
    for line in contents.split_inclusive('\n') {
        offset += line.len();
        let trimmed = line.trim();
        if trimmed.starts_with("//!") || trimmed.starts_with("#![") {
            header_end = offset;
        } else if !trimmed.is_empty() && !trimmed.starts_with("//") {
            break;
        }
    }
    header_end
}
367
368fn insert_before_app_builder(mut contents: String, block: &str) -> Result<String> {
369    if let Some((start, _)) = find_app_builder_marker_range(&contents) {
370        contents.insert_str(start, block);
371        return Ok(contents);
372    }
373
374    if let Some((start, _)) = find_unmarked_app_builder_statement_range(&contents) {
375        contents.insert_str(start, block);
376        return Ok(contents);
377    }
378
379    print_warning("Could not find app builder; skipping auth wiring");
380    Ok(contents)
381}
382
383fn insert_auth_into_app_builder(mut contents: String) -> Result<String> {
384    if contents.contains("register_module(auth_module)") {
385        return Ok(contents);
386    }
387
388    let insert = ".with_global_layer(Extension(auth_provider))\n.register_module(auth_module)";
389    if let Some((start, end)) = find_app_builder_marker_range(&contents) {
390        let statement = &contents[start..=end];
391        if let Some(updated) = insert_snippet_into_builder_block(statement, insert) {
392            contents.replace_range(start..=end, &updated);
393            return Ok(contents);
394        }
395        print_warning("Could not update app builder; skipping auth module registration");
396        return Ok(contents);
397    }
398
399    if let Some((start, end)) = find_unmarked_app_builder_statement_range(&contents) {
400        let statement = &contents[start..=end];
401        if let Some(updated) = insert_snippet_into_builder_block(statement, insert) {
402            contents.replace_range(start..=end, &updated);
403            return Ok(contents);
404        }
405        print_warning("Could not update app builder; skipping auth module registration");
406        return Ok(contents);
407    }
408
409    print_warning("Could not find app builder; skipping auth module registration");
410    Ok(contents)
411}
412
413fn insert_database_into_app_builder(mut contents: String) -> Result<String> {
414    if contents.contains(".with_database(") {
415        return Ok(contents);
416    }
417
418    let insert = ".with_context(\n    AppContext::builder()\n        .with_database(Arc::new(SeaOrmPool::new(db, database_url)))\n        .build()\n)";
419
420    if let Some((start, end)) = find_app_builder_marker_range(&contents) {
421        let statement = &contents[start..=end];
422        if let Some(updated) = insert_snippet_into_builder_block(statement, insert) {
423            contents.replace_range(start..=end, &updated);
424            return Ok(contents);
425        }
426        print_warning("Could not update app builder; skipping database wiring");
427        return Ok(contents);
428    }
429
430    if let Some((start, end)) = find_unmarked_app_builder_statement_range(&contents) {
431        let statement = &contents[start..=end];
432        if let Some(updated) = insert_snippet_into_builder_block(statement, insert) {
433            contents.replace_range(start..=end, &updated);
434            return Ok(contents);
435        }
436        print_warning("Could not update app builder; skipping database wiring");
437        return Ok(contents);
438    }
439
440    print_warning("Could not find app builder; skipping database wiring");
441    Ok(contents)
442}
443
444fn wire_openapi_in_main(project_dir: &Path) -> Result<()> {
445    let main_path = project_dir.join("src").join("main.rs");
446    if !main_path.exists() {
447        print_warning("src/main.rs not found; skipping auto-wiring");
448        return Ok(());
449    }
450
451    let mut contents = fs::read_to_string(&main_path)
452        .with_context(|| format!("Failed to read {}", main_path.display()))?;
453
454    if contents.contains("openapi::create_openapi_router")
455        || contents.contains("openapi_merge_module")
456    {
457        print_info("OpenAPI already appears wired in main.rs");
458        return Ok(());
459    }
460
461    contents = ensure_use_line(contents, "use tideway::ConfigBuilder;", "use tideway::");
462    if contents.contains("mod config;") {
463        contents = ensure_use_line(contents, "use crate::config::AppConfig;", "use tideway::");
464    }
465    contents = ensure_use_line(contents, "use tideway::openapi;", "use tideway::");
466
467    if !contents.contains("mod openapi_docs;") {
468        contents = ensure_module_decl(&contents, "openapi_docs");
469    }
470
471    let has_config_var = contents.contains("let config = ConfigBuilder::new()")
472        || contents.contains("let config = AppConfig::from_env()");
473    let config_available =
474        contents.contains("ConfigBuilder::new()") || contents.contains("AppConfig::from_env()");
475
476    if !has_config_var && config_available {
477        let config_block = "    let config = ConfigBuilder::new()\n        .from_env()\n        .build()\n        .expect(\"Invalid TIDEWAY_* config\");\n\n";
478        contents = insert_before_app_builder(contents, config_block)?;
479    }
480
481    if contents.contains("let config = AppConfig::from_env()") {
482        contents = insert_openapi_into_app_builder(contents, "config.tideway")?;
483    } else {
484        contents = insert_openapi_into_app_builder(contents, "config")?;
485    }
486
487    write_file(&main_path, &contents)
488        .with_context(|| format!("Failed to write {}", main_path.display()))?;
489    print_success("Wired OpenAPI into src/main.rs");
490    Ok(())
491}
492
493fn insert_openapi_into_app_builder(mut contents: String, config_ref: &str) -> Result<String> {
494    if contents.contains("create_openapi_router") {
495        return Ok(contents);
496    }
497
498    if let Some(pos) = find_app_builder_start(&contents) {
499        let app_var =
500            find_app_builder_var_name(&contents, pos).unwrap_or_else(|| "app".to_string());
501        // Insert after app builder block to keep code readable.
502        if let Some(insert_at) = find_app_builder_end_insert_at(&contents, pos) {
503            let block = format!(
504                "\n    if {config_ref}.openapi.enabled {{\n        let openapi = tideway::openapi_merge_module!(openapi_docs, ApiDoc);\n        let openapi_router = tideway::openapi::create_openapi_router(openapi, &{config_ref}.openapi);\n        {app_var} = {app_var}.merge_router(openapi_router);\n    }}\n"
505            );
506            contents.insert_str(insert_at, &block);
507        } else {
508            print_warning("Could not find app builder termination; skipping OpenAPI wiring");
509        }
510        Ok(contents)
511    } else {
512        print_warning("Could not find app builder; skipping OpenAPI wiring");
513        Ok(contents)
514    }
515}
516
517fn ensure_openapi_docs_file(project_dir: &Path) -> Result<()> {
518    let docs_path = project_dir.join("src").join("openapi_docs.rs");
519    if docs_path.exists() {
520        return Ok(());
521    }
522
523    let contents = r#"tideway::openapi_doc!(pub(crate) ApiDoc, paths());
524"#;
525
526    if let Some(parent) = docs_path.parent() {
527        ensure_dir(parent).with_context(|| format!("Failed to create {}", parent.display()))?;
528    }
529
530    write_file(&docs_path, contents)
531        .with_context(|| format!("Failed to write {}", docs_path.display()))?;
532    print_success("Created src/openapi_docs.rs");
533    Ok(())
534}
535
536fn project_name_from_cargo(contents: &str, project_dir: &Path) -> String {
537    let doc = contents
538        .parse::<toml_edit::DocumentMut>()
539        .ok()
540        .and_then(|doc| doc["package"]["name"].as_str().map(|s| s.to_string()));
541
542    doc.unwrap_or_else(|| {
543        project_dir
544            .file_name()
545            .and_then(|n| n.to_str())
546            .unwrap_or("my_app")
547            .to_string()
548    })
549    .replace('-', "_")
550}
551
552pub fn array_value(values: &[&str]) -> toml_edit::Value {
553    let mut array = toml_edit::Array::new();
554    for value in values {
555        array.push(*value);
556    }
557    toml_edit::Value::Array(array)
558}