shapemaker/synchronization/
midi.rs1use super::audio::{self, Stem};
2use super::sync::{SyncData, Syncable};
3use crate::ui::{Log, MaybeProgressBar};
4use indicatif::ProgressBar;
5use itertools::Itertools;
6use midly::{MetaMessage, MidiMessage, TrackEvent, TrackEventKind};
7use std::{collections::HashMap, fmt::Debug, path::PathBuf};
8
/// [`Syncable`] source that derives synchronization data from a MIDI file.
pub struct MidiSynchronizer {
    /// Path of the MIDI file to load.
    pub midi_path: PathBuf,
}
12
/// Arithmetic mean of a collection of samples.
trait Averageable {
    fn average(&self) -> f32;
}

impl Averageable for Vec<f32> {
    /// Sum of the elements divided by their count; an empty vector yields NaN (0 / 0).
    fn average(&self) -> f32 {
        let count = self.len() as f32;
        let total: f32 = self.iter().sum();
        total / count
    }
}
22
23impl Syncable for MidiSynchronizer {
24 fn new(path: &str) -> Self {
25 Self {
26 midi_path: PathBuf::from(path),
27 }
28 }
29
30 fn load(&self, progressbar: Option<&ProgressBar>) -> SyncData {
31 let (now, notes_per_instrument) = load_notes(&self.midi_path, progressbar);
32
33 SyncData {
34 bpm: tempo_to_bpm(now.tempo),
35 stems: HashMap::from_iter(notes_per_instrument.iter().map(|(name, notes)| {
36 let mut notes_per_ms = HashMap::<usize, Vec<audio::Note>>::new();
37
38 if let Some(pb) = progressbar {
39 pb.set_length(notes.len() as u64);
40 pb.set_position(0);
41 }
42 progressbar.set_message(format!("Adding loaded notes for {name}"));
43
44 for note in notes.iter() {
45 notes_per_ms
46 .entry(note.ms as usize)
47 .or_default()
48 .push(audio::Note {
49 pitch: note.key,
50 tick: note.tick,
51 velocity: note.vel,
52 });
53 progressbar.inc(1);
54 }
55
56 let duration_ms = *notes_per_ms.keys().max().unwrap_or(&0);
57
58 if let Some(pb) = progressbar {
59 pb.set_length(duration_ms as u64 - 1);
60 pb.set_position(0);
61 }
62 progressbar.set_message(format!("Infering amplitudes for {name}"));
63
64 let mut amplitudes = Vec::<f32>::new();
65 let mut last_amplitude = 0.0;
66 for i in 0..duration_ms {
67 if let Some(notes) = notes_per_ms.get(&i) {
68 last_amplitude = notes
69 .iter()
70 .map(|n| n.velocity as f32)
71 .collect::<Vec<f32>>()
72 .average();
73 }
74 amplitudes.push(last_amplitude);
75 progressbar.inc(1);
76 }
77
78 (
79 name.clone(),
80 Stem {
81 amplitude_max: notes.iter().map(|n| n.vel).max().unwrap_or(0) as f32,
82 amplitude_db: amplitudes,
83 duration_ms,
84 notes: notes_per_ms,
85 name: name.clone(),
86 },
87 )
88 })),
89 markers: HashMap::new(),
90 }
91 }
92}
93
/// A single note event read from a MIDI track.
#[derive(Clone)]
struct Note {
    /// Absolute position, in MIDI ticks.
    tick: u32,
    /// Absolute position, in milliseconds.
    ms: u32,
    /// MIDI key number.
    key: u8,
    /// Velocity; 0 for note-offs.
    vel: u8,
}
101
/// Timing state tracked while walking through a MIDI file's events.
struct Now {
    /// Elapsed time, in milliseconds.
    ms: usize,
    /// Tempo, in microseconds per beat.
    tempo: usize,
    /// Number of MIDI ticks in one beat.
    ticks_per_beat: u16,
}
107
108type Timeline<'a> = HashMap<u32, HashMap<String, TrackEvent<'a>>>;
109
110type StemNotes = HashMap<u32, HashMap<String, Note>>;
111
112impl Note {
113 fn is_off(&self) -> bool {
114 self.vel == 0
115 }
116}
117
/// Converts a MIDI tempo (microseconds per beat) into beats per minute, rounded to the nearest
/// integer.
fn tempo_to_bpm(micros_per_beat: usize) -> usize {
    const MICROS_PER_MINUTE: f32 = 60_000_000.0;
    let beats_per_minute = MICROS_PER_MINUTE / micros_per_beat as f32;
    beats_per_minute.round() as usize
}
121
122impl Debug for Note {
127 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128 write!(
129 f,
130 "{}{}",
131 self.key,
132 if self.is_off() {
133 "↓".to_string()
134 } else if self.vel == 100 {
135 "".to_string()
136 } else {
137 format!("@{}", self.vel)
138 }
139 )
140 }
141}
142
143fn load_notes(
144 source: &PathBuf,
145 progressbar: Option<&ProgressBar>,
146) -> (Now, HashMap<String, Vec<Note>>) {
147 if let Some(pb) = progressbar {
149 pb.set_length(1);
150 pb.set_prefix("Loading");
151 pb.set_message("reading MIDI file");
152 pb.set_position(0);
153 }
154
155 let raw = std::fs::read(source)
156 .unwrap_or_else(|_| panic!("Failed to read MIDI file {}", source.to_str().unwrap()));
157 let midifile = midly::Smf::parse(&raw).unwrap();
158
159 let mut timeline = Timeline::new();
160 progressbar.set_message(format!("MIDI file has {} tracks", midifile.tracks.len()));
161
162 let mut now = Now {
163 ms: 0,
164 tempo: 0,
165 ticks_per_beat: match midifile.header.timing {
166 midly::Timing::Metrical(ticks_per_beat) => ticks_per_beat.as_int(),
167 midly::Timing::Timecode(fps, subframe) => (1.0 / fps.as_f32() / subframe as f32) as u16,
168 },
169 };
170
171 let mut track_no = 0;
173 let mut track_names = HashMap::<usize, String>::new();
174 for track in midifile.tracks.iter() {
175 track_no += 1;
176 let mut track_name = String::new();
177 for event in track {
178 match event.kind {
179 TrackEventKind::Meta(MetaMessage::TrackName(name_bytes)) => {
180 track_name = String::from_utf8(name_bytes.to_vec()).unwrap_or_default();
181 }
182 TrackEventKind::Meta(MetaMessage::Tempo(tempo)) => {
183 if now.tempo == 0 {
184 now.tempo = tempo.as_int() as usize;
185 }
186 }
187 _ => {}
188 }
189 }
190 track_names.insert(
191 track_no,
192 if !track_name.is_empty() {
193 track_name
194 } else {
195 format!("Track #{}", track_no)
196 },
197 );
198 }
199
200 progressbar.log(
201 "Detected",
202 &format!(
203 "MIDI file {} with {} stems and initial tempo of {} BPM",
204 source.to_str().unwrap(),
205 track_names.len(),
206 tempo_to_bpm(now.tempo)
207 ),
208 );
209
210 let mut track_no = 0;
212 for track in midifile.tracks.iter() {
213 track_no += 1;
214 let mut absolute_tick = 0;
215 for event in track {
216 absolute_tick += event.delta.as_int();
217 timeline
218 .entry(absolute_tick)
219 .or_default()
220 .insert(track_names[&track_no].clone(), *event);
221 }
222 }
223
224 let mut absolute_tick_to_ms = HashMap::<u32, usize>::new();
226 let mut last_tick = 0;
227 for (tick, tracks) in timeline.iter().sorted_by_key(|(tick, _)| *tick) {
228 for event in tracks.values() {
229 if let TrackEventKind::Meta(MetaMessage::Tempo(tempo)) = event.kind {
230 now.tempo = tempo.as_int() as usize;
231 }
232 }
233 let delta = tick - last_tick;
234 last_tick = *tick;
235 now.ms += midi_tick_to_ms(delta, now.tempo, now.ticks_per_beat as usize);
236 absolute_tick_to_ms.insert(*tick, now.ms);
237 }
238
239 if let Some(pb) = progressbar {
240 pb.set_length(midifile.tracks.iter().map(|t| t.len() as u64).sum::<u64>());
241 pb.set_prefix("Loading");
242 pb.set_message("parsing MIDI events");
243 pb.set_position(0);
244 }
245
246 let mut stem_notes = StemNotes::new();
248 for (tick, tracks) in timeline.iter().sorted_by_key(|(tick, _)| *tick) {
249 for (track_name, event) in tracks {
250 if let TrackEventKind::Midi {
251 channel: _,
252 message,
253 } = event.kind
254 {
255 match message {
256 MidiMessage::NoteOn { key, vel } | MidiMessage::NoteOff { key, vel } => {
257 stem_notes
258 .entry(absolute_tick_to_ms[tick] as u32)
259 .or_default()
260 .insert(
261 track_name.clone(),
262 Note {
263 tick: *tick,
264 ms: absolute_tick_to_ms[tick] as u32,
265 key: key.as_int(),
266 vel: if matches!(message, MidiMessage::NoteOff { .. }) {
267 0
268 } else {
269 vel.as_int()
270 },
271 },
272 );
273 }
274 _ => {}
275 }
276 }
277 progressbar.inc(1)
278 }
279 }
280
281 let mut result = HashMap::<String, Vec<Note>>::new();
282
283 for (_ms, notes) in stem_notes.iter().sorted_by_key(|(ms, _)| *ms) {
284 for (track_name, note) in notes {
285 result
286 .entry(track_name.clone())
287 .or_default()
288 .push(note.clone());
289 }
290 }
291
292 (now, result)
293}
294
/// Converts a duration of `tick` MIDI ticks into milliseconds, rounded to the nearest whole ms.
///
/// `tempo` is in microseconds per beat and `ppq` is the number of ticks per beat, so one tick
/// lasts `tempo / 1000 / ppq` milliseconds. The math is done in `f64`. `f32`'s 24-bit
/// mantissa cannot represent tick counts (or results) above 2^24 exactly. A `ppq` of 0 has
/// no meaningful result.
fn midi_tick_to_ms(tick: u32, tempo: usize, ppq: usize) -> usize {
    let ms_per_tick = tempo as f64 / 1e3 / ppq as f64;
    (ms_per_tick * f64::from(tick)).round() as usize
}