tutti_core/
runner.rs

1use std::{
2    collections::{HashMap, HashSet, VecDeque},
3    time::Duration,
4};
5
6use colored::Colorize;
7use futures::StreamExt;
8use tokio::{
9    sync::mpsc::{self, Receiver, Sender},
10    task::JoinHandle,
11};
12use tutti_config::Project;
13
14use crate::{
15    process::{BoxStream, ProcId},
16    CommandSpec, ProcessManager,
17};
18
/// An event emitted while following the output of a running service.
#[derive(Debug)]
pub enum LogEvent {
    /// One item read from a service's stdout or stderr stream.
    Log {
        /// Name of the service that produced the output.
        service_name: String,
        /// The raw bytes of the output item.
        line: Vec<u8>,
    },
    /// Emitted by a service's stdout follower once it stops forwarding output.
    Stop {
        /// Name of the service whose stdout follower finished.
        service_name: String,
    },
}
24
25async fn follow_output(
26    is_stdout: bool,
27    mut output: BoxStream<Vec<u8>>,
28    service_name: String,
29    rx: Sender<LogEvent>,
30) {
31    while let Some(line) = output.next().await {
32        if rx
33            .send(LogEvent::Log {
34                service_name: service_name.clone(),
35                line,
36            })
37            .await
38            .is_err()
39        {
40            break;
41        }
42    }
43    if is_stdout && rx.send(LogEvent::Stop { service_name }).await.is_err() {
44        eprintln!("Failed to send stop event");
45    }
46}
47
48#[derive(Debug)]
49pub struct Runner<M: ProcessManager> {
50    project: Project,
51    pm: M,
52
53    tasks: Vec<JoinHandle<()>>,
54    processes: Vec<ProcId>,
55}
56
57impl<M: ProcessManager> Runner<M> {
58    pub fn new(project: Project, pm: M) -> Self {
59        let tasks = Vec::with_capacity(project.services.len() * 2);
60        let processes = Vec::with_capacity(project.services.len());
61
62        Self {
63            project,
64            pm,
65            tasks,
66            processes,
67        }
68    }
69
70    /// Performs topological sort on services considering their dependencies.
71    /// Returns services in order they should be started.
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if a service is not found or if there is a cycle in the dependencies.
76    fn topological_sort(&self, service_names: &HashSet<String>) -> anyhow::Result<Vec<String>> {
77        let mut graph: HashMap<String, Vec<String>> = HashMap::new();
78        let mut in_degree: HashMap<String, usize> = HashMap::new();
79        let mut all_services = HashSet::new();
80
81        let mut to_process = VecDeque::new();
82        for name in service_names {
83            to_process.push_back(name.clone());
84        }
85
86        while let Some(service_name) = to_process.pop_front() {
87            if all_services.contains(&service_name) {
88                continue;
89            }
90
91            let service = self
92                .project
93                .services
94                .get(&service_name)
95                .ok_or_else(|| anyhow::anyhow!("Service '{service_name}' not found"))?;
96
97            all_services.insert(service_name.clone());
98            graph.entry(service_name.clone()).or_default();
99            in_degree.entry(service_name.clone()).or_insert(0);
100
101            for dep in &service.deps {
102                if !self.project.services.contains_key(dep) {
103                    return Err(anyhow::anyhow!(
104                        "Dependency '{dep}' of service '{service_name}' not found",
105                    ));
106                }
107
108                graph
109                    .entry(dep.clone())
110                    .or_default()
111                    .push(service_name.clone());
112                *in_degree.entry(service_name.clone()).or_insert(0) += 1;
113
114                to_process.push_back(dep.clone());
115            }
116        }
117
118        let mut queue = VecDeque::new();
119        for (service, &degree) in &in_degree {
120            if degree == 0 {
121                queue.push_back(service.clone());
122            }
123        }
124
125        let mut result = Vec::new();
126        while let Some(service) = queue.pop_front() {
127            result.push(service.clone());
128
129            if let Some(dependents) = graph.get(&service) {
130                for dependent in dependents {
131                    if let Some(degree) = in_degree.get_mut(dependent) {
132                        *degree -= 1;
133                        if *degree == 0 {
134                            queue.push_back(dependent.clone());
135                        }
136                    }
137                }
138            }
139        }
140
141        if result.len() != all_services.len() {
142            return Err(anyhow::anyhow!("Circular dependency detected"));
143        }
144
145        Ok(result)
146    }
147
148    /// Starts the services defined in the project configuration.
149    ///
150    /// # Errors
151    ///
152    /// Returns an error if any of the services fail to start.
153    pub async fn up(&mut self, services: Vec<String>) -> anyhow::Result<Receiver<LogEvent>> {
154        let (tx, rx) = mpsc::channel(10);
155
156        let service_names = if services.is_empty() {
157            self.project
158                .services
159                .keys()
160                .cloned()
161                .collect::<HashSet<_>>()
162        } else {
163            services
164                .into_iter()
165                .filter(|name| self.project.services.contains_key(name))
166                .collect::<HashSet<_>>()
167        };
168
169        let sorted_services = self.topological_sort(&service_names)?;
170
171        let to_run = sorted_services
172            .into_iter()
173            .map(|name| {
174                let service = self
175                    .project
176                    .services
177                    .get(&name)
178                    .ok_or_else(|| anyhow::anyhow!("unknown service: {name}"))?;
179                Ok((name, service))
180            })
181            .collect::<anyhow::Result<Vec<_>>>()?;
182
183        for (name, service) in &to_run {
184            let service = self
185                .pm
186                .spawn(CommandSpec {
187                    name: name.to_owned(),
188                    cmd: service.cmd.clone(),
189                    cwd: service.cwd.clone(),
190                    env: service
191                        .env
192                        .clone()
193                        .map(|h| h.into_iter().collect())
194                        .unwrap_or_default(),
195                })
196                .await?;
197            let stdout = service.stdout;
198            let stderr = service.stderr;
199
200            self.tasks.push(tokio::spawn(follow_output(
201                true,
202                stdout,
203                name.clone(),
204                tx.clone(),
205            )));
206            self.tasks.push(tokio::spawn(follow_output(
207                false,
208                stderr,
209                name.clone(),
210                tx.clone(),
211            )));
212            self.processes.push(service.id);
213        }
214
215        Ok(rx)
216    }
217
218    /// Stops all services.
219    ///
220    /// # Errors
221    ///
222    /// Returns an error if any of the services fail to stop.
223    pub async fn down(&mut self) -> anyhow::Result<()> {
224        let duration = Duration::from_millis(100);
225
226        for id in self.processes.drain(..) {
227            println!("Stopping process {id:?}");
228            if self.pm.wait(id, duration).await?.is_some() {
229                println!("process {id:?} already stopped");
230                continue;
231            }
232
233            self.pm.shutdown(id).await?;
234            if let Some(exit_code) = self.pm.wait(id, duration).await? {
235                let line = format!("process {id:?} stopped with {exit_code} code")
236                    .black()
237                    .on_yellow();
238                println!("{line}");
239            } else {
240                self.pm.kill(id).await?;
241                let line = format!("process {id:?} killed").black().on_red();
242                println!("{line}");
243            }
244        }
245
246        Ok(())
247    }
248
249    /// Waits for all services to exit.
250    ///
251    /// # Errors
252    ///
253    /// Returns an error if any of the services fail to exit.
254    pub async fn wait(&mut self) -> anyhow::Result<()> {
255        for task in self.tasks.drain(..) {
256            task.await?;
257        }
258
259        Ok(())
260    }
261}