diff --git a/djmarkov.py b/djmarkov.py index 65a6f61..e0f90af 100644 --- a/djmarkov.py +++ b/djmarkov.py @@ -1,7 +1,7 @@ #!/usr/bin/env python from PyLyrics import * -from markov import Markov +from markov import TextMarkov import musicbrainzngs import sys import argparse @@ -120,6 +120,10 @@ def main(): print(str(len(songs)) + " songs available.") print("Done.") + if len(songs) == 0: + print("Artist does not have songs on file. Exiting.") + exit() + print("Parsing available songs...") parsed_songs = [] for song in songs: @@ -130,10 +134,10 @@ def main(): .replace(":", "")) print("Done.") print("Building Markov Chain...") - markov_chain = Markov(" ".join(parsed_songs)) + markov_chain = TextMarkov(" ".join(parsed_songs)) print("Done.") result = markov_chain.traverse(150) - print(result) + print(" ".join(result)) if __name__ == "__main__": main() diff --git a/markov.py b/markov.py index 40d13e7..1962c33 100644 --- a/markov.py +++ b/markov.py @@ -3,9 +3,10 @@ from collections import defaultdict import numpy as np from random import randint - +import json + class Markov(object): - def __init__(self, text=None): + def __init__(self, obj_list=None, load=None): self.freq_table = None self.dict_index = 0 @@ -13,109 +14,158 @@ def __init__(self, text=None): self.inner_freq_index = 1 self.total_index = 1 - self.root_word = None + self.root = None self.chain = None self.reset() - if text: - self.add_text(text) + if load: + self.freq_table = json.load(open(load, 'r')) + if obj_list: + self.add(obj_list) #Function for consuming new text to add to existing Markov chain. - def add_text(self, text): - tokens = text.split(" ") - for i in range(len(tokens) - 1): - orig = tokens[i] - after = tokens[i + 1] - self.add_word(orig, after) - self.calc_frequency() + def add(self, obj_list): + changed_keys = set() + for i in range(len(obj_list) - 1): + orig = obj_list[i] + changed_keys.add(orig) + after = obj_list[i + 1] + self._add_pair(orig, after) + self._calc_frequency(changed_keys=changed_keys) #Helper function for add_text. #Adds a word to the frequency table WITHOUT recalculating frequency value. - def add_word(self, orig, next_word): - self.freq_table[orig][self.dict_index][next_word][self.inner_count_index] += 1 + def _add_pair(self, orig, after): + self.freq_table[orig][self.dict_index][after][self.inner_count_index] += 1 self.freq_table[orig][self.total_index] += 1 #Calculates frequencies of potential next words for every original word. - def calc_frequency(self): - for k in list(self.freq_table.keys()): + def _calc_frequency(self, changed_keys=None): + if changed_keys is None: + changed_keys = list(self.freq_table.keys()) + for k in changed_keys: for ik in list(self.freq_table[k][self.dict_index].keys()): self.freq_table[k][self.dict_index][ik][self.inner_freq_index] = self.freq_table[k][self.dict_index][ik][self.inner_count_index] / self.freq_table[k][self.total_index] def reset(self): self.freq_table = defaultdict(lambda: [defaultdict(lambda: [0, 0.0]), 0]) - self.root_word = None + self.root = None self.chain = None - def random_word(self): + def random(self): if not self.freq_table: - raise ValueError("No text available") - all_words = list(self.freq_table.keys()) - root = all_words[randint(0, len(all_words) - 1)] + raise ValueError("No objects available in frequency table") + all_objs = list(self.freq_table.keys()) + root = all_objs[randint(0, len(all_objs) - 1)] return root - def traverse(self, dist, min_dist=None, root_word=None, restart_on_error=True): + def traverse(self, dist, min_dist=None, root=None, restart_on_error=True): if not self.freq_table: raise ValueError("No text available") - if root_word is None: - self.root_word = self.random_word() + if root is None: + self.root = self.random() else: - self.root_word = root_word + self.root = root self.chain = self.gen_chain() - result = "" + result = [] chain_index = 0 while chain_index < dist: try: - result += next(self.chain) + " " + result.append(next(self.chain)) chain_index += 1 except StopIteration: if restart_on_error: if not min_dist is None: if chain_index >= (min_dist - 1): break - self.root_word = self.random_word() + self.root == self.random() self.chain = self.gen_chain() continue else: break return result + def best(dist, root=None): + if not self.freq_table: + raise ValueError("No text available") + if root is None: + self.root = self.random() + else: + self.root = root + result = [] + chain_index = 0 + while chain_index < dist: + item = self.retrieve(root) + + chain_index += 1 + def gen_chain(self): - if not self.root_word: - raise ValueError("Root word has not been set") - yield self.root_word - frequencies = self.retrieve(self.root_word) - while frequencies and not ((len(frequencies) == 1) and frequencies[0][0] == None): - frequencies = self.retrieve(self.root_word) + if not self.root: + raise ValueError("Root has not been set") + yield self.root + frequencies = self.retrieve(self.root) + while frequencies and not ((len(frequencies) == 1) and + frequencies[0][0] == None): + frequencies = self.retrieve(self.root) choices = [c for c, w in frequencies] weights = [w for c, w in frequencies] try: - new_word = np.random.choice(choices, size=1, p=weights)[0] + new_obj = np.random.choice(choices, size=1, p=weights)[0] except ValueError: #Weights don't sum to 1. Add a dummy and try again. choices.append(None) weights.append(1 - sum(weights)) - new_word = np.random.choice(choices, size=1, p=weights)[0] - if new_word is None: + new_obj = np.random.choice(choices, size=1, p=weights)[0] + if new_obj is None: continue else: - self.root_word = new_word - yield new_word + self.root = new_obj + yield new_obj - def next_words(self, orig): - next_words = list(self.freq_table[orig][self.dict_index].keys()) - return next_words + def next_objs(self, orig): + next_objs = list(self.freq_table[orig][self.dict_index].keys()) + return next_objs #Return a list of tuples in the format of [("next_word", frequency)] def retrieve(self, orig): - freq_list = [(i, self.freq_table[orig][self.dict_index][i][self.inner_freq_index]) + freq_list = [(i, self.freq_table[orig][self.dict_index][i] + [self.inner_freq_index]) for i in self.freq_table[orig][self.dict_index].keys()] return freq_list #Print entire frequency table, prettily. def print_table(self): for i in list(self.freq_table.keys()): - self.pretty_print_freq(i) + self.print_freq(i) #Print single entry in the frequency table. def print_freq(self, orig): freq_list = self.retrieve(orig) print(orig, ":", freq_list) + + def dumps(self): + for i in self.freq_table: + print(i, self.freq_table[i]) + + def consume(target): + #Iterate over rows of frequency table, borrowing rows with new keys, + #adjusting and recalculating frequencies of rows with a familiar key. + pass + + def save(self, name): + if self.freq_table: + f = open(name, 'w') + json.dump(self.freq_table, f) + else: + raise ValueError("Cannot save an empty frequency table.") + +class TextMarkov(Markov): + def add(self, obj_list): + if not isinstance(obj_list, str): + raise ValueError("Object must be str type") + data = obj_list.split(" ") + for i in range(len(data) - 1): + orig = data[i] + after = data[i + 1] + self._add_pair(orig, after) + self._calc_frequency() +