Skip to content

Commit

Permalink
fix pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
tandav committed Jan 20, 2024
1 parent 1031f9b commit b157428
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 64 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ ignore = [
"FIX002",
"FIX004",
"B028",
"PTH123",
]

[tool.ruff.per-file-ignores]
Expand Down
108 changes: 52 additions & 56 deletions src/musiclib/midi/notation.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
import argparse
import pathlib
import json
import abc
import bisect
import collections
import json
import operator
import pathlib
from typing import NamedTuple

import mido
from mido.midifiles.tracks import _to_reltime

import musiclib
from musiclib.midi.parse import Midi
from musiclib.midi.parse import MidiNote
from mido.midifiles.tracks import _to_reltime
from musiclib.midi.parse import abs_messages
from musiclib.note import SpecificNote
from musiclib.midi.player import Player

import mido
from musiclib.note import SpecificNote


class IntervalEvent(NamedTuple):
Expand All @@ -24,44 +23,36 @@ class IntervalEvent(NamedTuple):
off: int


class Event(abc.ABC):
def __init__(self, code: str):
self.type, *kw = code.splitlines()
kw = dict(kv.split(maxsplit=1) for kv in kw)
self.post_init(**kw)

def post_init(self, **kw):
pass
class Event:
def __init__(self, code: str) -> None:
self.type, *_kw = code.splitlines()
self.kw = dict(kv.split(maxsplit=1) for kv in _kw)


class Header(Event):
def post_init(
self,
version: str,
root: str,
channels: str,
ticks_per_beat: str = '96',
):
if musiclib.__version__ != version:
raise ValueError(f'musiclib must be exact version {version} to parse notation')
self.version = version
self.root = SpecificNote.from_str(root)
self.ticks_per_beat = int(ticks_per_beat)
self.channel_map = json.loads(channels)
def __init__(self, code: str) -> None:
super().__init__(code)
self.version = self.kw['version']
if musiclib.__version__ != self.version:
raise ValueError(f'musiclib must be exact version {self.version} to parse notation')
self.root = SpecificNote.from_str(self.kw['root'])
self.ticks_per_beat = int(self.kw['ticks_per_beat'])
self.channel_map = json.loads(self.kw['channels'])


class Modulation(Event):
def post_init(self, root: str):
self.root = SpecificNote.from_str(root)
def __init__(self, code: str) -> None:
super().__init__(code)
self.root = SpecificNote.from_str(self.kw['root'])


class Voice:
def __init__(self, code: str):
def __init__(self, code: str) -> None:
self.channel, intervals_str = code.split(maxsplit=1)
self.interval_events = self.parse_interval_events(intervals_str)

def parse_interval_events(self, intervals_str: str, ticks_per_beat: int = 96):
interval = None
def parse_interval_events(self, intervals_str: str, ticks_per_beat: int = 96) -> list[IntervalEvent]:
interval: int | None = None
on = 0
off = 0
interval_events = []
Expand All @@ -81,14 +72,17 @@ def parse_interval_events(self, intervals_str: str, ticks_per_beat: int = 96):
on = off
off += ticks_per_beat
interval = int(interval_str, base=12)
if interval is None:
raise ValueError('Cannot have empty voice')
interval_events.append(IntervalEvent(interval, on, off))
return interval_events


class Bar:
def __init__(self, code: str):
def __init__(self, code: str) -> None:
self.voices = [Voice(voice_code) for voice_code in code.splitlines()]

def to_midi(self, root: SpecificNote, figured_bass: bool = True):
def to_midi(self, root: SpecificNote, *, figured_bass: bool = True) -> dict[str, list[MidiNote]]:
if not isinstance(root, SpecificNote):
raise TypeError(f'root must be SpecificNote, got {root}')
channels = collections.defaultdict(list)
Expand Down Expand Up @@ -126,25 +120,25 @@ def to_midi(self, root: SpecificNote, figured_bass: bool = True):
)
return dict(channels)


class Notation:
def __init__(self, code: str) -> None:
self.parse(code)
self.ticks_per_beat = self.header.ticks_per_beat
self.channel_map = self.header.channel_map

def parse(self, code: str):
events = code.strip().split('\n\n')
self.header, *self.events = [self.parse_event(event) for event in events]

def parse_event(self, code: str):
if code.startswith('header'):
return Header(code)
if code.startswith('modulation'):
return Modulation(code)
return Bar(code)
def parse(self, code: str) -> None:
self.events: list[Event | Bar] = []
for event_code in code.strip().split('\n\n'):
if event_code.startswith('header'):
self.header = Header(event_code)
elif event_code.startswith('modulation'):
self.events.append(Modulation(event_code))
else:
self.events.append(Bar(event_code))

def _to_midi(self):
channels = [[] for _ in range(len(self.channel_map))]
def _to_midi(self) -> list[Midi]:
channels: list[list[MidiNote]] = [[] for _ in range(len(self.channel_map))]
root = self.header.root
t = 0
for event in self.events:
Expand All @@ -161,25 +155,27 @@ def _to_midi(self):
channels[channel_id] += [
MidiNote(
note=note.note,
on=t+note.on,
off=t+note.off,
on=t + note.on,
off=t + note.off,
channel=channel_id,
)
for note in notes
]
t += bar_off
else:
raise ValueError(f'unknown event type: {event}')
raise TypeError(f'unknown event type: {event}')

channels = [Midi(notes=v, ticks_per_beat=self.ticks_per_beat) for v in channels]
return channels
return [Midi(notes=v, ticks_per_beat=self.ticks_per_beat) for v in channels]

def to_midi(self):
tracks = [mido.MidiTrack(_to_reltime(abs_messages(midi))) for midi in self._to_midi()]
return mido.MidiFile(tracks=tracks, type=1, ticks_per_beat=self.ticks_per_beat)
def to_midi(self) -> mido.MidiFile:
return mido.MidiFile(
type=1,
ticks_per_beat=self.ticks_per_beat,
tracks=[mido.MidiTrack(_to_reltime(abs_messages(midi))) for midi in self._to_midi()],
)


def play_file():
def play_file() -> None:
parser = argparse.ArgumentParser()
parser.add_argument('filepath', type=pathlib.Path)
parser.add_argument('--bpm', type=float, default=120)
Expand Down
4 changes: 2 additions & 2 deletions src/musiclib/midi/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,13 @@ def parse_midi(midi: mido.MidiFile) -> Midi:
elif is_note('off', message):
note = playing_notes[message.note]
note['off'] = t
print(message, playing_notes, note)
notes.append(MidiNote(**note))
del playing_notes[message.note]
elif message.type == 'pitchwheel':
pitchbend.append(MidiPitch(time=t, pitch=message.pitch))
return Midi(notes=notes, pitchbend=pitchbend, ticks_per_beat=midi.ticks_per_beat)


def midiobj_to_midifile(midi: Midi) -> mido.MidiFile:
track = mido.MidiTrack(_to_reltime(abs_messages(midi)))
return mido.MidiFile(type=0, tracks=[track], ticks_per_beat=midi.ticks_per_beat)
Expand All @@ -107,7 +107,7 @@ def abs_messages(midi: Midi) -> list[mido.Message]:
out.append(mido.Message(type='note_on', time=note.on, note=note.note.i, velocity=note.velocity, channel=note.channel))
out.append(mido.Message(type='note_off', time=note.off, note=note.note.i, velocity=note.velocity, channel=note.channel))
for pitch in midi.pitchbend:
out.append(mido.Message(type='pitchwheel', time=pitch.time, pitch=pitch.pitch))
out.append(mido.Message(type='pitchwheel', time=pitch.time, pitch=pitch.pitch)) # noqa: PERF401
out.sort(key=lambda m: (m.time, {'note_off': 0, 'pitchwheel': 1, 'note_on': 2}[m.type]))
return out

Expand Down
2 changes: 1 addition & 1 deletion src/musiclib/midi/pitchbend.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def insert_pitch_pattern(

def make_notes_pitchbends(midi: Midi) -> dict[MidiNote, list[MidiPitch]]:
T, P = zip(*[(e.time, e.pitch) for e in midi.pitchbend], strict=True) # noqa: N806
T_set = set(T)
T_set = set(T) # noqa: N806
interp_t = []
for note in midi.notes:
for t in (note.on, note.off):
Expand Down
11 changes: 6 additions & 5 deletions tests/midi/notation_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import json
import pytest
from pathlib import Path

import pytest
from musiclib.midi import notation
from musiclib.midi.notation import IntervalEvent
from musiclib.midi import parse
from musiclib.midi.notation import IntervalEvent


@pytest.mark.parametrize(
'code, channel', [
('code', 'channel'), [
('flute 1 2 3 4', 'flute'),
('bass 9 8 -7 17 4 -12 0', 'bass'),
],
Expand All @@ -17,7 +18,7 @@ def test_voice_channel(code, channel):


@pytest.mark.parametrize(
'code, interval_events', [
('code', 'interval_events'), [
(
'flute 1 2 3 4',
[
Expand Down Expand Up @@ -52,7 +53,7 @@ def test_voice(code, interval_events):
assert notation.Voice(code).interval_events == interval_events


@pytest.mark.parametrize('example_dir', (Path(__file__).parent / f'data/notation').iterdir())
@pytest.mark.parametrize('example_dir', (Path(__file__).parent / 'data/notation').iterdir())
def test_to_midi(example_dir):
code = (example_dir / 'code.txt').read_text()
with open(example_dir / 'midi.json') as f:
Expand Down

0 comments on commit b157428

Please sign in to comment.