// tag2upload_service_manager/global.rs
use crate::prelude::*;
const USER_AGENT: &str =
"tag2upload-service-manager https://salsa.debian.org/dgit-team";
#[derive(Debug)]
pub struct Globals {
pub cli_options: CliOptions,
pub config: Config,
pub state: watch::Sender<State>,
pub worker_tracker: Arc<WorkerTracker>,
pub db_trigger: watch::Sender<DbAssocState>,
pub http_client: reqwest::Client,
pub dns_resolver: hickory_resolver::TokioAsyncResolver,
pub running: watch::Receiver<Option<Running>>,
pub scratch_dir: String,
pub temp_dir_retain: Option<tempfile::TempDir>,
pub tera: Tera,
pub test_suppl: test::GlobalSupplement,
}
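// Published on the `running` watch channel from the on_liftoff fairing once
// Rocket is up; carries the port it bound.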
#[derive(Clone, Debug)]
pub struct Running {
pub port: u16,
}
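// Returned by startup(): the ignited (but not yet launched) Rocket instance.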
pub struct Started {
pub rocket: RocketIgnite,
}
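// Payload of the `db_trigger` watch channel. Currently empty; the channel
// serves as a wake-up signal that database-associated state may have changed
// (the shutdown handler below waits on it).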
#[derive(Default, Debug)]
pub struct DbAssocState {
}
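// Two implementations of globals()/set_globals(): normal builds keep a single
// process-wide OnceLock, while test builds use a thread-local so each test
// thread can install its own Globals.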
#[cfg(not(test))]
mod imp_globals {
use super::*;
static GLOBALS: OnceLock<Arc<Globals>> = OnceLock::new();
pub fn globals() -> Arc<Globals> {
GLOBALS.get()
.expect("Using globals() before set")
.clone()
}
pub(crate) fn set_globals(g: Arc<Globals>) {
GLOBALS.set(g)
.expect("set_globals called more than once?");
}
}
#[cfg(test)]
mod imp_globals {
use super::*;
thread_local! {
static GLOBALS: RefCell<Option<Arc<Globals>>> =
const { RefCell::new(None) };
}
pub fn globals() -> Arc<Globals> {
GLOBALS.with_borrow(|gl| gl.clone())
.expect("Using globals() before set")
}
pub(crate) fn set_globals(g: Arc<Globals>) {
if let Some(was) = GLOBALS.replace(Some(g)) {
panic!("set_globals called more than once, previously {was:?}");
}
}
}
pub use imp_globals::*;
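// Marker type: the service is shutting down, whether cleanly or because of an
// internal error.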
#[derive(Debug)]
pub struct ShuttingDown;
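// Mutable service state, broadcast to interested tasks via the `state` watch
// channel in Globals.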
#[derive(Debug)]
pub struct State {
pub shutdown_reason: Option<Result<ShuttingDown, InternalError>>,
pub test_suppl: test::StateSupplement,
}
impl State {
pub fn new(test_suppl: test::StateSupplement) -> Self {
State {
shutdown_reason: None,
test_suppl,
}
}
pub fn check_shutdown(&self) -> Result<(), ShuttingDown> {
match &self.shutdown_reason {
Some(Ok(ShuttingDown)) => Err(ShuttingDown),
Some(Err(_)) => Err(ShuttingDown),
None => Ok(()),
}
}
}
impl Globals {
pub async fn await_shutdown(&self) -> ShuttingDown {
self.state.subscribe()
.wait_for_then(|state| {
state.shutdown_reason
.as_ref()
.map(|_: &Result<ShuttingDown, IE>| ())
})
.await
.unwrap_or_else(|e| {
error!("shutdown handler task recv failed {e}");
});
ShuttingDown
}
}
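// Test instrumentation point: a no-op in non-test builds; in test builds the
// rendered point name is passed to test::hook_point.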
pub async fn test_hook<S: Display>(
#[allow(unused_variables)]
point: impl FnOnce() -> S + Send + Sync,
) {
#[cfg(test)]
test::hook_point(point().to_string()).await;
}
impl Globals {
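// Spawn `fut` as a background task that first waits for the service to reach
// the Running state; if shutdown begins before then, the task completes with
// that error instead of running `fut`.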
pub fn spawn_task_running(
self: &Arc<Self>,
what: impl Display,
fut: impl Future<Output = TaskResult> + Send + 'static,
) {
let gl = self.clone();
self.spawn_task_immediate(what, async move {
let Running { .. } = gl.await_running().await?;
fut.await
})
}
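// Spawn `fut` immediately on the Tokio runtime. Panics are caught and turned
// into an internal-error TaskResult rather than being allowed to unwind.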
pub fn spawn_task_immediate(
self: &Arc<Self>,
what: impl Display,
fut: impl Future<Output = TaskResult> + Send + 'static,
) {
tokio::spawn({
let what = what.to_string();
async move {
let _: TaskResult = AssertUnwindSafe(fut)
.catch_unwind()
.await
.unwrap_or_else(|_: Box<dyn Any + Send>| {
Err(internal!("task {what} panicked!").into())
});
}
});
}
pub fn check_shutdown(&self) -> Result<(), ShuttingDown> {
self.state.borrow().check_shutdown()
}
pub async fn await_running(&self) -> Result<Running, ShuttingDown> {
match self.running.clone().wait_for_then(|p| p.clone()).await {
Ok(y) => Ok(y),
Err(e) => {
debug!("shutting down, no port: {e}");
Err(ShuttingDown)
}
}
}
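// GET `url` and deserialise the JSON body. When `testing.fake_https_dir` is
// configured, the response is read instead from a file under that directory,
// named after the URL with its "https://" prefix stripped.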
pub async fn http_fetch_json<T: serde::de::DeserializeOwned>(
&self,
url: Url,
) -> Result<T, AE> {
if let Some(fake) = &self.config.testing.fake_https_dir {
let url = url.to_string();
let url = url.strip_prefix("https://")
.ok_or_else(|| anyhow!(
"failed to strip https:// prefix from url {url:?}"
))?;
let fake_file = format!("{fake}/{url}");
let data = fs::read_to_string(&fake_file)
.with_context(|| format!("{fake_file:?}"))
.context("failed to read fake file")?;
let data = serde_json::from_str(&data)
.with_context(|| format!("{fake_file:?}"))
.context("failed to deser fake file")?;
return Ok(data);
}
self.http_client.get(url)
.send().await.context("send")?
.error_for_status().context("status")?
.json().await.context("response")
}
}
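// In test builds, rewrite `$url` via test::UrlMappable::map; in normal builds
// the mapping statement is compiled out and the URL is left untouched.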
macro_rules! test_hook_url { { $url:ident } => {
#[cfg(test)]
let $url = crate::test::UrlMappable::map(&$url);
} }
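// Start the two shutdown-related background tasks: one watches for SIGTERM
// and records it in State::shutdown_reason; the other waits for shutdown to
// be requested, lets any job still being built finish, and then terminates
// the whole process group.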
pub fn shutdown_start_tasks(
globals: &Arc<Globals>,
) -> Result<(), StartupError> {
use tokio::signal::unix::{signal, SignalKind as SK};
#[cfg(test)]
match globals.t_shutdown_handlers() {
Ok(()) => {},
Err(()) => return Ok(()),
};
let mut terminate = signal(SK::terminate())
.into_internal("failed to set up SIGTERM handler")?;
globals.spawn_task_immediate("shutdown SIGTEREM watch", {
let globals = globals.clone();
async move {
let () = terminate.recv().await
.ok_or_else(|| internal!("no more SIGTERM reception?!"))?;
globals.state.send_modify(|state| {
match state.shutdown_reason {
None => {
info!("received SIGTERM, shutting down...");
state.shutdown_reason = Some(Ok(ShuttingDown));
}
Some(Err(_)) => {
info!("SIGTERM, but already crashing!");
},
Some(Ok(ShuttingDown)) => {
info!("SIGTERM, but already shutting down");
}
}
});
Ok(TaskWorkComplete {})
}
});
globals.spawn_task_immediate("shutdown handler", {
let globals = globals.clone();
async move {
let _: ShuttingDown = globals.await_shutdown().await;
let mut subscription = globals.db_trigger.subscribe();
match loop {
let job = find_job_deferring_shutdown()?;
let Some::<JobRow>(job) = job
else { break Ok(()); };
info!(jid=%job.jid, "shutdown awaits completion of build");
match subscription.changed().await {
Ok(()) => {}, Err(e) => break Err(e),
}
} {
Ok(()) => info!("clean shutdown complete."),
Err(e) => error!("shutdown terminating early, watch {}", e),
};
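// Clean-up is done: SIGHUP our own process group (pid 0 = the caller's group)
// so any child processes are taken down too; if that somehow fails to kill
// us, abort.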
unsafe {
libc::kill(0, libc::SIGHUP);
error!("SIGHUP didn't kill us!!");
std::process::abort();
}
}
});
Ok(())
}
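// Look for a job that should delay shutdown: one currently assigned to a
// worker (processing != '') and still in the Building state, oldest first.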
pub fn find_job_deferring_shutdown() -> Result<Option<JobRow>, IE> {
db_transaction(TN::Readonly, |dbt| {
dbt.bsql_query_01(bsql!("
SELECT * FROM jobs
WHERE processing != ''
AND status = " (JobStatus::Building) "
ORDER BY last_update ASC
"))
})?
}
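// If files.port_report_file is configured, write the bound port to it.
// Failures are recorded as internal-error notes rather than aborting startup.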
fn write_port_report_file(gl: &Arc<Globals>, port: u16) {
if let Some(file) = gl.config.files.port_report_file.clone() {
trace!(?file, "writing port");
(|| {
let f = fs::File::create(&file).context("open")?;
let mut f = io::BufWriter::new(f);
writeln!(f, "{port}").context("write")?;
f.flush().context("write (flush)")?;
Ok::<_, AE>(())
})()
.with_context(|| format!("{file:?}"))
.unwrap_or_else(|ae| IE::new_without_backtrace(ae).note_only());
}
}
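// Build and start the service: layer the configuration (Rocket release
// defaults, the config file from the CLI options, plus any extra TOML
// snippets given on the command line), set up logging, the scratch directory,
// the HTTP client, DNS resolver and templates, publish the Globals, initialise
// the database and the o2m listener, start the shutdown and expiry tasks, and
// finally assemble and ignite the Rocket instance. Worker tasks are started
// from the on_liftoff fairing once Rocket is actually up.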
pub async fn startup(
cli_options: CliOptions,
base_config: Figment,
test_global_suppl: test::GlobalSupplement,
test_state_suppl: test::StateSupplement,
rocket_hook: impl FnOnce(RocketBuild) -> RocketBuild,
) -> Result<Started, StartupError> {
use StartupError as SE;
let rocket_base_config = {
use rocket::config::*;
rocket::Config {
shutdown: Shutdown {
ctrlc: false,
signals: HashSet::new(),
..Default::default()
},
..rocket::Config::release_default()
}
};
let config = base_config
.merge(figment::providers::Toml::file(&cli_options.config))
.join(figment::providers::Serialized::default(
"rocket",
rocket_base_config,
))
.join(figment::providers::Serialized::default(
"rocket",
json! {{
}},
));
let config = {
let mut c = config;
for s in &cli_options.config_toml {
c = c.merge(figment::providers::Toml::string(s));
}
c
};
let rocket_config = config
.focus("rocket");
let config: Config = config
.extract()
.map_err(SE::ParseConfig)?;
config.check()?;
logging::setup(&config)?;
let scratch_dir;
let temp_dir_retain;
match &config.files.scratch_dir {
Some(s) => {
scratch_dir = s.clone();
temp_dir_retain = None;
}
None => {
let td = tempfile::TempDir::new()
.context("create temp dir")
.map_err(SE::TempDir)?;
scratch_dir = td.path().to_str()
.ok_or_else(|| anyhow!("not utf-8"))
.map_err(SE::TempDir)?
.to_owned();
temp_dir_retain = Some(td);
}
};
remove_dir_all::remove_dir_contents(&scratch_dir)
.context("clean out old contents of scratch directory")
.map_err(SE::TempDir)?;
let http_client = reqwest::Client::builder()
.user_agent(USER_AGENT)
.timeout(*config.timeouts.http_request)
.build()?;
let dns_resolver = dns::Resolver::tokio_from_system_conf()?;
let tera = routes_abstract::tera_templates(&config)?;
let (running_tx, running) = watch::channel(None);
let globals = Arc::new(Globals {
cli_options,
config,
db_trigger: watch::Sender::new(DbAssocState::default()),
state: watch::Sender::new(State::new(test_state_suppl)),
worker_tracker: Default::default(),
http_client,
dns_resolver,
scratch_dir,
running,
tera,
temp_dir_retain,
test_suppl: test_global_suppl,
});
set_globals(globals.clone());
db_support::initialise(&globals)?;
let listener = o2m_listener::Listener::new(&globals)?;
shutdown_start_tasks(&globals)?;
expire::start_task(&globals);
let rocket = rocket::custom(&rocket_config);
let rocket = rocket_hook(rocket);
let rocket = routes::mount_all(rocket);
let rocket = rocket.attach({
let globals = globals.clone();
rocket::fairing::AdHoc::on_liftoff(
"spawn workers",
|rocket: &rocket::Rocket<_>| Box::pin(
async move {
if globals.state.borrow().shutdown_reason.is_some() {
trace!(
"shutdown triggered during startup, not continuing"
);
return;
}
fetcher::start_tasks(&globals);
listener.start_task();
let port = rocket.config().port;
write_port_report_file(&globals, port);
let running = Running {
port
};
running_tx.send(Some(running.clone()))
.expect("no-one wanted our port");
info!(?running, "running");
}
)
)
});
let rocket = rocket.ignite().await?;
Ok(Started {
rocket,
})
}