shapemaker/synchronization/
midi.rs

1use super::audio::{self, Stem};
2use super::sync::{SyncData, Syncable};
3use crate::ui::{Log, MaybeProgressBar};
4use indicatif::ProgressBar;
5use itertools::Itertools;
6use midly::{MetaMessage, MidiMessage, TrackEvent, TrackEventKind};
7use std::{collections::HashMap, fmt::Debug, path::PathBuf};
8
/// Synchronizer that builds [`SyncData`] out of a MIDI file.
pub struct MidiSynchronizer {
    /// Path of the MIDI file that `load` reads.
    pub midi_path: PathBuf,
}
12
/// Arithmetic mean of a collection of `f32` samples.
trait Averageable {
    /// Returns the arithmetic mean of the values, or `0.0` when there are none.
    fn average(&self) -> f32;
}

impl Averageable for Vec<f32> {
    fn average(&self) -> f32 {
        if self.is_empty() {
            // Without this guard an empty vector would yield 0.0 / 0.0 = NaN.
            return 0.0;
        }
        self.iter().sum::<f32>() / self.len() as f32
    }
}
22
23impl Syncable for MidiSynchronizer {
24    fn new(path: &str) -> Self {
25        Self {
26            midi_path: PathBuf::from(path),
27        }
28    }
29
30    fn load(&self, progressbar: Option<&ProgressBar>) -> SyncData {
31        let (now, notes_per_instrument) = load_notes(&self.midi_path, progressbar);
32
33        SyncData {
34            bpm: tempo_to_bpm(now.tempo),
35            stems: HashMap::from_iter(notes_per_instrument.iter().map(|(name, notes)| {
36                let mut notes_per_ms = HashMap::<usize, Vec<audio::Note>>::new();
37
38                if let Some(pb) = progressbar {
39                    pb.set_length(notes.len() as u64);
40                    pb.set_position(0);
41                }
42                progressbar.set_message(format!("Adding loaded notes for {name}"));
43
44                for note in notes.iter() {
45                    notes_per_ms
46                        .entry(note.ms as usize)
47                        .or_default()
48                        .push(audio::Note {
49                            pitch: note.key,
50                            tick: note.tick,
51                            velocity: note.vel,
52                        });
53                    progressbar.inc(1);
54                }
55
56                let duration_ms = *notes_per_ms.keys().max().unwrap_or(&0);
57
58                if let Some(pb) = progressbar {
59                    pb.set_length(duration_ms as u64 - 1);
60                    pb.set_position(0);
61                }
62                progressbar.set_message(format!("Infering amplitudes for {name}"));
63
64                let mut amplitudes = Vec::<f32>::new();
65                let mut last_amplitude = 0.0;
66                for i in 0..duration_ms {
67                    if let Some(notes) = notes_per_ms.get(&i) {
68                        last_amplitude = notes
69                            .iter()
70                            .map(|n| n.velocity as f32)
71                            .collect::<Vec<f32>>()
72                            .average();
73                    }
74                    amplitudes.push(last_amplitude);
75                    progressbar.inc(1);
76                }
77
78                (
79                    name.clone(),
80                    Stem {
81                        amplitude_max: notes.iter().map(|n| n.vel).max().unwrap_or(0) as f32,
82                        amplitude_db: amplitudes,
83                        duration_ms,
84                        notes: notes_per_ms,
85                        name: name.clone(),
86                    },
87                )
88            })),
89            markers: HashMap::new(),
90        }
91    }
92}
93
/// A note-on or note-off event read from a MIDI track.
///
/// All fields are plain integers, so the type is `Copy`.
#[derive(Clone, Copy)]
struct Note {
    /// Absolute position of the event, in MIDI ticks from the start of the file.
    tick: u32,
    /// Absolute position of the event, in milliseconds.
    ms: u32,
    /// MIDI key number.
    key: u8,
    /// Velocity; 0 marks a note-off (see `is_off`).
    vel: u8,
}
101
/// Running state of the tick-to-millisecond conversion done by `load_notes`.
struct Now {
    /// Ticks per beat, taken from the MIDI header's timing field.
    ticks_per_beat: u16,
    /// Current tempo, in microseconds per beat.
    tempo: usize,
    /// Current position, in milliseconds.
    ms: usize,
}
107
108type Timeline<'a> = HashMap<u32, HashMap<String, TrackEvent<'a>>>;
109
110type StemNotes = HashMap<u32, HashMap<String, Note>>;
111
112impl Note {
113    fn is_off(&self) -> bool {
114        self.vel == 0
115    }
116}
117
/// Converts a MIDI tempo (microseconds per beat) into beats per minute,
/// rounded to the nearest integer.
///
/// A tempo of 0 is not a valid MIDI tempo; it yields 0 rather than the
/// meaningless `usize::MAX` a division by zero would saturate to.
fn tempo_to_bpm(us_per_beat: usize) -> usize {
    if us_per_beat == 0 {
        return 0;
    }
    (60_000_000.0 / us_per_beat as f64).round() as usize
}
121
122// fn to_ms(delta: u32, bpm: f32) -> f32 {
123//     (delta as f32) * (60.0 / bpm) * 1000.0
124// }
125
126impl Debug for Note {
127    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128        write!(
129            f,
130            "{}{}",
131            self.key,
132            if self.is_off() {
133                "↓".to_string()
134            } else if self.vel == 100 {
135                "".to_string()
136            } else {
137                format!("@{}", self.vel)
138            }
139        )
140    }
141}
142
143fn load_notes(
144    source: &PathBuf,
145    progressbar: Option<&ProgressBar>,
146) -> (Now, HashMap<String, Vec<Note>>) {
147    // Read midi file using midly
148    if let Some(pb) = progressbar {
149        pb.set_length(1);
150        pb.set_prefix("Loading");
151        pb.set_message("reading MIDI file");
152        pb.set_position(0);
153    }
154
155    let raw = std::fs::read(source)
156        .unwrap_or_else(|_| panic!("Failed to read MIDI file {}", source.to_str().unwrap()));
157    let midifile = midly::Smf::parse(&raw).unwrap();
158
159    let mut timeline = Timeline::new();
160    progressbar.set_message(format!("MIDI file has {} tracks", midifile.tracks.len()));
161
162    let mut now = Now {
163        ms: 0,
164        tempo: 0,
165        ticks_per_beat: match midifile.header.timing {
166            midly::Timing::Metrical(ticks_per_beat) => ticks_per_beat.as_int(),
167            midly::Timing::Timecode(fps, subframe) => (1.0 / fps.as_f32() / subframe as f32) as u16,
168        },
169    };
170
171    // Get track names and (initial) BPM
172    let mut track_no = 0;
173    let mut track_names = HashMap::<usize, String>::new();
174    for track in midifile.tracks.iter() {
175        track_no += 1;
176        let mut track_name = String::new();
177        for event in track {
178            match event.kind {
179                TrackEventKind::Meta(MetaMessage::TrackName(name_bytes)) => {
180                    track_name = String::from_utf8(name_bytes.to_vec()).unwrap_or_default();
181                }
182                TrackEventKind::Meta(MetaMessage::Tempo(tempo)) => {
183                    if now.tempo == 0 {
184                        now.tempo = tempo.as_int() as usize;
185                    }
186                }
187                _ => {}
188            }
189        }
190        track_names.insert(
191            track_no,
192            if !track_name.is_empty() {
193                track_name
194            } else {
195                format!("Track #{}", track_no)
196            },
197        );
198    }
199
200    progressbar.log(
201        "Detected",
202        &format!(
203            "MIDI file {} with {} stems and initial tempo of {} BPM",
204            source.to_str().unwrap(),
205            track_names.len(),
206            tempo_to_bpm(now.tempo)
207        ),
208    );
209
210    // Convert ticks to absolute
211    let mut track_no = 0;
212    for track in midifile.tracks.iter() {
213        track_no += 1;
214        let mut absolute_tick = 0;
215        for event in track {
216            absolute_tick += event.delta.as_int();
217            timeline
218                .entry(absolute_tick)
219                .or_default()
220                .insert(track_names[&track_no].clone(), *event);
221        }
222    }
223
224    // Convert ticks to ms
225    let mut absolute_tick_to_ms = HashMap::<u32, usize>::new();
226    let mut last_tick = 0;
227    for (tick, tracks) in timeline.iter().sorted_by_key(|(tick, _)| *tick) {
228        for event in tracks.values() {
229            if let TrackEventKind::Meta(MetaMessage::Tempo(tempo)) = event.kind {
230                now.tempo = tempo.as_int() as usize;
231            }
232        }
233        let delta = tick - last_tick;
234        last_tick = *tick;
235        now.ms += midi_tick_to_ms(delta, now.tempo, now.ticks_per_beat as usize);
236        absolute_tick_to_ms.insert(*tick, now.ms);
237    }
238
239    if let Some(pb) = progressbar {
240        pb.set_length(midifile.tracks.iter().map(|t| t.len() as u64).sum::<u64>());
241        pb.set_prefix("Loading");
242        pb.set_message("parsing MIDI events");
243        pb.set_position(0);
244    }
245
246    // Add notes
247    let mut stem_notes = StemNotes::new();
248    for (tick, tracks) in timeline.iter().sorted_by_key(|(tick, _)| *tick) {
249        for (track_name, event) in tracks {
250            if let TrackEventKind::Midi {
251                channel: _,
252                message,
253            } = event.kind
254            {
255                match message {
256                    MidiMessage::NoteOn { key, vel } | MidiMessage::NoteOff { key, vel } => {
257                        stem_notes
258                            .entry(absolute_tick_to_ms[tick] as u32)
259                            .or_default()
260                            .insert(
261                                track_name.clone(),
262                                Note {
263                                    tick: *tick,
264                                    ms: absolute_tick_to_ms[tick] as u32,
265                                    key: key.as_int(),
266                                    vel: if matches!(message, MidiMessage::NoteOff { .. }) {
267                                        0
268                                    } else {
269                                        vel.as_int()
270                                    },
271                                },
272                            );
273                    }
274                    _ => {}
275                }
276            }
277            progressbar.inc(1)
278        }
279    }
280
281    let mut result = HashMap::<String, Vec<Note>>::new();
282
283    for (_ms, notes) in stem_notes.iter().sorted_by_key(|(ms, _)| *ms) {
284        for (track_name, note) in notes {
285            result
286                .entry(track_name.clone())
287                .or_default()
288                .push(note.clone());
289        }
290    }
291
292    (now, result)
293}
294
/// Converts a duration in MIDI ticks into milliseconds, rounded to the nearest ms.
///
/// * `tick` — number of ticks to convert.
/// * `tempo` — microseconds per beat.
/// * `ppq` — ticks per beat (pulses per quarter note) from the file header.
///
/// Returns 0 when `ppq` is 0 (malformed header) instead of producing an
/// infinite or NaN intermediate that would overflow the caller's accumulator.
fn midi_tick_to_ms(tick: u32, tempo: usize, ppq: usize) -> usize {
    if ppq == 0 {
        return 0;
    }
    // f64 keeps the result exact-to-the-ms even for very large tick counts.
    let ms = (tempo as f64 / 1e3) / ppq as f64 * f64::from(tick);
    ms.round() as usize
}