1use anyhow::{Context, Result};
4use std::collections::BTreeSet;
5use std::fs;
6use std::path::{Path, PathBuf};
7
8use crate::cli::{AddArgs, AddFeature};
9use crate::templates::{BackendTemplateContext, BackendTemplateEngine};
10use crate::{
11 ensure_dir, print_info, print_success, print_warning, write_file, TIDEWAY_VERSION,
12};
13
14pub fn run(args: AddArgs) -> Result<()> {
15 let project_dir = PathBuf::from(&args.path);
16 let cargo_path = project_dir.join("Cargo.toml");
17
18 if !cargo_path.exists() {
19 return Err(anyhow::anyhow!(
20 "Cargo.toml not found in {}",
21 project_dir.display()
22 ));
23 }
24
25 let cargo_contents = fs::read_to_string(&cargo_path)
26 .with_context(|| format!("Failed to read {}", cargo_path.display()))?;
27
28 let project_name = project_name_from_cargo(&cargo_contents, &project_dir);
29 let project_name_pascal = to_pascal_case(&project_name);
30
31 update_cargo_toml(&cargo_path, &cargo_contents, args.feature)?;
32 update_env_example(&project_dir, args.feature, &project_name)?;
33
34 if args.feature == AddFeature::Auth {
35 scaffold_auth(&project_dir, &project_name, &project_name_pascal, args.force)?;
36 print_info("Auth scaffold created in src/auth/");
37 if args.wire {
38 wire_auth_in_main(&project_dir, &project_name)?;
39 } else {
40 print_info("Next steps: wire AuthModule + SimpleAuthProvider in main.rs");
41 }
42 }
43
44 if args.feature == AddFeature::Database && args.wire {
45 wire_database_in_main(&project_dir)?;
46 }
47
48 if args.feature == AddFeature::Openapi {
49 ensure_openapi_docs_file(&project_dir)?;
50 if args.wire {
51 wire_openapi_in_main(&project_dir)?;
52 } else {
53 print_info("Next steps: wire OpenAPI in main.rs");
54 }
55 }
56
57 print_success(&format!("Added {}", args.feature));
58 Ok(())
59}
60
61fn update_cargo_toml(path: &Path, contents: &str, feature: AddFeature) -> Result<()> {
62 let mut doc = contents.parse::<toml_edit::DocumentMut>()?;
63
64 let deps = doc["dependencies"].or_insert(toml_edit::Item::Table(toml_edit::Table::new()));
65
66 let tideway_item = deps
67 .as_table_mut()
68 .expect("dependencies should be a table")
69 .entry("tideway");
70
71 let feature_name = feature.to_string();
72
73 match tideway_item {
74 toml_edit::Entry::Vacant(entry) => {
75 let mut table = toml_edit::InlineTable::new();
76 table.get_or_insert("version", TIDEWAY_VERSION);
77 table.get_or_insert("features", array_value(&[feature_name.as_str()]));
78 entry.insert(toml_edit::Item::Value(toml_edit::Value::InlineTable(table)));
79 }
80 toml_edit::Entry::Occupied(mut entry) => {
81 if entry.get().is_str() {
82 let version = entry
83 .get()
84 .as_str()
85 .unwrap_or(TIDEWAY_VERSION)
86 .to_string();
87 let mut table = toml_edit::InlineTable::new();
88 table.get_or_insert("version", version);
89 table.get_or_insert("features", array_value(&[feature_name.as_str()]));
90 entry.insert(toml_edit::Item::Value(toml_edit::Value::InlineTable(table)));
91 } else {
92 let item = entry.get_mut();
93 let features = item["features"]
94 .or_insert(toml_edit::Item::Value(toml_edit::Value::Array(toml_edit::Array::new())))
95 .as_array_mut()
96 .expect("features should be an array");
97
98 if !features.iter().any(|v| v.as_str() == Some(&feature_name)) {
99 features.push(feature_name);
100 }
101 }
102 }
103 }
104
105 if feature == AddFeature::Database {
106 let deps_table = deps.as_table_mut().expect("dependencies should be a table");
107 deps_table
108 .entry("sea-orm")
109 .or_insert(toml_edit::Item::Value(toml_edit::Value::InlineTable(
110 {
111 let mut table = toml_edit::InlineTable::new();
112 table.get_or_insert("version", "1.1");
113 table.get_or_insert(
114 "features",
115 array_value(&["sqlx-postgres", "runtime-tokio-rustls"]),
116 );
117 table
118 },
119 )));
120 }
121
122 if feature == AddFeature::Auth {
123 let deps_table = deps.as_table_mut().expect("dependencies should be a table");
124 deps_table
125 .entry("async-trait")
126 .or_insert(toml_edit::value("0.1"));
127 deps_table
128 .entry("serde")
129 .or_insert(toml_edit::Item::Value(toml_edit::Value::InlineTable(
130 {
131 let mut table = toml_edit::InlineTable::new();
132 table.get_or_insert("version", "1.0");
133 table.get_or_insert("features", array_value(&["derive"]));
134 table
135 },
136 )));
137 deps_table
138 .entry("serde_json")
139 .or_insert(toml_edit::value("1.0"));
140 }
141
142 write_file(path, &doc.to_string())
143 .with_context(|| format!("Failed to write {}", path.display()))?;
144 Ok(())
145}
146
147fn update_env_example(project_dir: &Path, feature: AddFeature, project_name: &str) -> Result<()> {
148 let env_path = project_dir.join(".env.example");
149 let mut lines = if env_path.exists() {
150 fs::read_to_string(&env_path)
151 .with_context(|| format!("Failed to read {}", env_path.display()))?
152 .lines()
153 .map(|line| line.to_string())
154 .collect::<Vec<_>>()
155 } else {
156 vec![
157 "# Server".to_string(),
158 "TIDEWAY_HOST=0.0.0.0".to_string(),
159 "TIDEWAY_PORT=8000".to_string(),
160 String::new(),
161 ]
162 };
163
164 let mut existing = BTreeSet::new();
165 for line in &lines {
166 if let Some((key, _)) = line.split_once('=') {
167 existing.insert(key.trim().to_string());
168 }
169 }
170
171 match feature {
172 AddFeature::Database => {
173 if !existing.contains("DATABASE_URL") {
174 lines.push("# Database".to_string());
175 lines.push(format!(
176 "DATABASE_URL=postgres://postgres:postgres@localhost:5432/{}",
177 project_name
178 ));
179 lines.push(String::new());
180 }
181 }
182 AddFeature::Auth => {
183 if !existing.contains("JWT_SECRET") {
184 lines.push("# Auth".to_string());
185 lines.push("JWT_SECRET=your-super-secret-jwt-key-change-in-production".to_string());
186 lines.push(String::new());
187 }
188 }
189 _ => {}
190 }
191
192 write_file(&env_path, &lines.join("\n"))
193 .with_context(|| format!("Failed to write {}", env_path.display()))?;
194 Ok(())
195}
196
197fn scaffold_auth(
198 project_dir: &Path,
199 project_name: &str,
200 project_name_pascal: &str,
201 force: bool,
202) -> Result<()> {
203 let context = BackendTemplateContext {
204 project_name: project_name.to_string(),
205 project_name_pascal: project_name_pascal.to_string(),
206 has_organizations: false,
207 database: "postgres".to_string(),
208 tideway_version: TIDEWAY_VERSION.to_string(),
209 tideway_features: vec!["auth".to_string()],
210 has_tideway_features: true,
211 has_auth_feature: true,
212 has_database_feature: false,
213 needs_arc: true,
214 has_config: false,
215 };
216
217 let engine = BackendTemplateEngine::new(context)?;
218 let auth_dir = project_dir.join("src").join("auth");
219
220 write_file_with_force(
221 &auth_dir.join("mod.rs"),
222 &engine.render("starter/src/auth/mod.rs")?,
223 force,
224 )?;
225 write_file_with_force(
226 &auth_dir.join("provider.rs"),
227 &engine.render("starter/src/auth/provider.rs")?,
228 force,
229 )?;
230 write_file_with_force(
231 &auth_dir.join("routes.rs"),
232 &engine.render("starter/src/auth/routes.rs")?,
233 force,
234 )?;
235
236 Ok(())
237}
238
239fn wire_auth_in_main(project_dir: &Path, project_name: &str) -> Result<()> {
240 let main_path = project_dir.join("src").join("main.rs");
241 if !main_path.exists() {
242 print_warning("src/main.rs not found; skipping auto-wiring");
243 return Ok(());
244 }
245
246 let mut contents = fs::read_to_string(&main_path)
247 .with_context(|| format!("Failed to read {}", main_path.display()))?;
248
249 if !contents.contains("mod auth;") {
250 if contents.contains("mod routes;") {
251 contents = contents.replace("mod routes;\n", "mod routes;\nmod auth;\n");
252 } else {
253 contents = format!("mod auth;\n{}", contents);
254 }
255 }
256
257 contents = ensure_use_line(
258 contents,
259 "use axum::Extension;",
260 "use tideway::auth",
261 );
262 contents = ensure_use_line(
263 contents,
264 "use crate::auth::{AuthModule, SimpleAuthProvider};",
265 "use tideway::auth",
266 );
267 contents = ensure_use_line(contents, "use std::sync::Arc;", "use tideway::");
268 contents = ensure_use_line(
269 contents,
270 "use tideway::auth::{JwtIssuer, JwtIssuerConfig};",
271 "use tideway::auth",
272 );
273
274 let has_jwt_secret = contents.contains("let jwt_secret");
275 let has_jwt_issuer = contents.contains("let jwt_issuer");
276 let has_auth_provider = contents.contains("auth_provider");
277 let has_auth_module = contents.contains("auth_module");
278
279 if has_jwt_secret && has_jwt_issuer {
280 if !has_auth_provider || !has_auth_module {
281 if let Some(insert_at) = contents.find("let jwt_issuer") {
282 let after = contents[insert_at..]
283 .find(";\n")
284 .map(|idx| insert_at + idx + 2)
285 .unwrap_or(insert_at);
286 let insert = format!(
287 " let auth_provider = SimpleAuthProvider::from_secret(&jwt_secret);\n let auth_module = AuthModule::new(jwt_issuer.clone());\n"
288 );
289 contents.insert_str(after, &insert);
290 }
291 }
292 } else {
293 let block = format!(
294 " let jwt_secret = std::env::var(\"JWT_SECRET\").expect(\"JWT_SECRET is not set\");\n let jwt_issuer = Arc::new(JwtIssuer::new(JwtIssuerConfig::with_secret(\n &jwt_secret,\n \"{}\",\n )).expect(\"Failed to create JWT issuer\"));\n let auth_provider = SimpleAuthProvider::from_secret(&jwt_secret);\n let auth_module = AuthModule::new(jwt_issuer.clone());\n\n",
295 project_name
296 );
297 contents = insert_before_app_builder(contents, &block)?;
298 }
299
300 contents = insert_auth_into_app_builder(contents)?;
301
302 write_file(&main_path, &contents)
303 .with_context(|| format!("Failed to write {}", main_path.display()))?;
304 print_success("Wired auth into src/main.rs");
305 Ok(())
306}
307
308pub fn wire_database_in_main(project_dir: &Path) -> Result<()> {
309 let main_path = project_dir.join("src").join("main.rs");
310 if !main_path.exists() {
311 print_warning("src/main.rs not found; skipping auto-wiring");
312 return Ok(());
313 }
314
315 let mut contents = fs::read_to_string(&main_path)
316 .with_context(|| format!("Failed to read {}", main_path.display()))?;
317
318 if !contents.contains("async fn main") {
319 print_warning("main.rs is not async; skipping database wiring");
320 return Ok(());
321 }
322
323 contents = ensure_use_line(
324 contents,
325 "use tideway::{AppContext, SeaOrmPool};",
326 "use tideway::",
327 );
328 contents = ensure_use_line(contents, "use std::sync::Arc;", "use tideway::");
329
330 let has_database_block = contents.contains("DATABASE_URL")
331 || contents.contains("sea_orm::Database::connect")
332 || contents.contains("with_database");
333
334 if !has_database_block {
335 let block = " let database_url = std::env::var(\"DATABASE_URL\").expect(\"DATABASE_URL is not set\");\n let db = sea_orm::Database::connect(&database_url)\n .await\n .expect(\"Failed to connect to database\");\n\n";
336 contents = insert_before_app_builder(contents, block)?;
337 }
338
339 if !contents.contains(".with_database(") {
340 contents = insert_database_into_app_builder(contents)?;
341 }
342
343 write_file(&main_path, &contents)
344 .with_context(|| format!("Failed to write {}", main_path.display()))?;
345 print_success("Wired database into src/main.rs");
346 Ok(())
347}
348
/// Returns `contents` with `line` present, adding it next to `anchor` if needed.
///
/// * If `contents` already contains `line`, it is returned unchanged.
/// * Otherwise, if `anchor` occurs, `line` is inserted on its own line right
///   after the statement that starts at the anchor — that is, after the first
///   `;` following the anchor and the newline that ends that line. Skipping to
///   the `;` keeps the insertion outside a multi-line `use foo::{ … };` group.
/// * If the anchor is missing, or no terminated statement plus newline follows
///   it, `line` is prepended to the file.
fn ensure_use_line(mut contents: String, line: &str, anchor: &str) -> String {
    if contents.contains(line) {
        return contents;
    }

    let insert_at = contents.find(anchor).and_then(|start| {
        let stmt_end = start + contents[start..].find(';')?;
        let newline = contents[stmt_end..].find('\n')?;
        Some(stmt_end + newline + 1)
    });

    match insert_at {
        Some(at) => {
            contents.insert_str(at, &format!("{}\n", line));
            contents
        }
        None => format!("{}\n{}", line, contents),
    }
}
365
366fn insert_before_app_builder(mut contents: String, block: &str) -> Result<String> {
367 if let Some(pos) = contents.find("let app = App::") {
368 contents.insert_str(pos, block);
369 Ok(contents)
370 } else {
371 print_warning("Could not find app builder; skipping auth wiring");
372 Ok(contents)
373 }
374}
375
376fn insert_auth_into_app_builder(mut contents: String) -> Result<String> {
377 if contents.contains("register_module(auth_module)") {
378 return Ok(contents);
379 }
380
381 if let Some(pos) = contents.find("let app = App::") {
382 let line_end = contents[pos..]
383 .find('\n')
384 .map(|idx| pos + idx)
385 .unwrap_or(contents.len());
386 let indent = contents[pos..]
387 .chars()
388 .take_while(|c| c.is_whitespace())
389 .collect::<String>();
390 let insert = format!(
391 "{} .with_global_layer(Extension(auth_provider))\n{} .register_module(auth_module)\n",
392 indent, indent
393 );
394 contents.insert_str(line_end + 1, &insert);
395 Ok(contents)
396 } else {
397 print_warning("Could not find app builder; skipping auth module registration");
398 Ok(contents)
399 }
400}
401
402fn insert_database_into_app_builder(mut contents: String) -> Result<String> {
403 if let Some(pos) = contents.find("let app = App::") {
404 let line_end = contents[pos..]
405 .find('\n')
406 .map(|idx| pos + idx)
407 .unwrap_or(contents.len());
408 let indent = contents[pos..]
409 .chars()
410 .take_while(|c| c.is_whitespace())
411 .collect::<String>();
412 let insert = format!(
413 "{} .with_context(\n{} AppContext::builder()\n{} .with_database(Arc::new(SeaOrmPool::new(db, database_url)))\n{} .build()\n{} )\n",
414 indent, indent, indent, indent, indent
415 );
416 contents.insert_str(line_end + 1, &insert);
417 Ok(contents)
418 } else {
419 print_warning("Could not find app builder; skipping database wiring");
420 Ok(contents)
421 }
422}
423
424fn wire_openapi_in_main(project_dir: &Path) -> Result<()> {
425 let main_path = project_dir.join("src").join("main.rs");
426 if !main_path.exists() {
427 print_warning("src/main.rs not found; skipping auto-wiring");
428 return Ok(());
429 }
430
431 let mut contents = fs::read_to_string(&main_path)
432 .with_context(|| format!("Failed to read {}", main_path.display()))?;
433
434 if contents.contains("openapi::create_openapi_router") || contents.contains("openapi_merge_module") {
435 print_info("OpenAPI already appears wired in main.rs");
436 return Ok(());
437 }
438
439 contents = ensure_use_line(contents, "use tideway::ConfigBuilder;", "use tideway::");
440 if contents.contains("mod config;") {
441 contents = ensure_use_line(contents, "use crate::config::AppConfig;", "use tideway::");
442 }
443 contents = ensure_use_line(contents, "use tideway::openapi;", "use tideway::");
444
445 if !contents.contains("mod openapi_docs;") {
446 if contents.contains("mod routes;") {
447 contents = contents.replace("mod routes;\n", "mod routes;\nmod openapi_docs;\n");
448 } else {
449 contents = format!("mod openapi_docs;\n{}", contents);
450 }
451 }
452
453 let has_config_var = contents.contains("let config = ConfigBuilder::new()")
454 || contents.contains("let config = AppConfig::from_env()");
455 let config_available = contents.contains("ConfigBuilder::new()")
456 || contents.contains("AppConfig::from_env()");
457
458 if !has_config_var && config_available {
459 let config_block = " let config = ConfigBuilder::new()\n .from_env()\n .build()\n .expect(\"Invalid TIDEWAY_* config\");\n\n";
460 contents = insert_before_app_builder(contents, config_block)?;
461 }
462
463 if contents.contains("let config = AppConfig::from_env()") {
464 contents = insert_openapi_into_app_builder(contents, "config.tideway")?;
465 } else {
466 contents = insert_openapi_into_app_builder(contents, "config")?;
467 }
468
469 write_file(&main_path, &contents)
470 .with_context(|| format!("Failed to write {}", main_path.display()))?;
471 print_success("Wired OpenAPI into src/main.rs");
472 Ok(())
473}
474
475fn insert_openapi_into_app_builder(mut contents: String, config_ref: &str) -> Result<String> {
476 if contents.contains("create_openapi_router") {
477 return Ok(contents);
478 }
479
480 if let Some(pos) = contents.find("let app = App::") {
481 if let Some(end_pos) = contents[pos..].find(";\n\n") {
483 let insert_at = pos + end_pos + 3;
484 let block = format!(
485 "\n #[cfg(feature = \"openapi\")]\n if {config_ref}.openapi.enabled {{\n let openapi = tideway::openapi_merge_module!(openapi_docs, ApiDoc);\n let openapi_router = tideway::openapi::create_openapi_router(openapi, &{config_ref}.openapi);\n app = app.merge_router(openapi_router);\n }}\n"
486 );
487 contents.insert_str(insert_at, &block);
488 } else {
489 print_warning("Could not find app builder termination; skipping OpenAPI wiring");
490 }
491 Ok(contents)
492 } else {
493 print_warning("Could not find app builder; skipping OpenAPI wiring");
494 Ok(contents)
495 }
496}
497
498fn ensure_openapi_docs_file(project_dir: &Path) -> Result<()> {
499 let docs_path = project_dir.join("src").join("openapi_docs.rs");
500 if docs_path.exists() {
501 return Ok(());
502 }
503
504 let contents = r#"#[cfg(feature = "openapi")]
505tideway::openapi_doc!(pub(crate) ApiDoc, paths());
506"#;
507
508 if let Some(parent) = docs_path.parent() {
509 ensure_dir(parent).with_context(|| format!("Failed to create {}", parent.display()))?;
510 }
511
512 write_file(&docs_path, &contents)
513 .with_context(|| format!("Failed to write {}", docs_path.display()))?;
514 print_success("Created src/openapi_docs.rs");
515 Ok(())
516}
517
518
519fn write_file_with_force(path: &Path, contents: &str, force: bool) -> Result<()> {
520 if path.exists() && !force {
521 print_warning(&format!(
522 "Skipping {} (use --force to overwrite)",
523 path.display()
524 ));
525 return Ok(());
526 }
527
528 if let Some(parent) = path.parent() {
529 ensure_dir(parent).with_context(|| format!("Failed to create {}", parent.display()))?;
530 }
531
532 write_file(path, contents)
533 .with_context(|| format!("Failed to write {}", path.display()))?;
534 Ok(())
535}
536
537fn project_name_from_cargo(contents: &str, project_dir: &Path) -> String {
538 let doc = contents
539 .parse::<toml_edit::DocumentMut>()
540 .ok()
541 .and_then(|doc| doc["package"]["name"].as_str().map(|s| s.to_string()));
542
543 doc.unwrap_or_else(|| {
544 project_dir
545 .file_name()
546 .and_then(|n| n.to_str())
547 .unwrap_or("my_app")
548 .to_string()
549 })
550 .replace('-', "_")
551}
552
/// Converts an underscore-separated name to PascalCase.
///
/// The input is split on `_`; empty segments are dropped. The first character
/// of each segment is upper-cased (using its full Unicode upper-case mapping)
/// and the remainder of the segment is kept as is.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut rest = segment.chars();
        if let Some(head) = rest.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
565
566pub fn array_value(values: &[&str]) -> toml_edit::Value {
567 let mut array = toml_edit::Array::new();
568 for value in values {
569 array.push(*value);
570 }
571 toml_edit::Value::Array(array)
572}