Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
RobMcH authored Feb 17, 2018
1 parent cb37bd6 commit 3f47533
Show file tree
Hide file tree
Showing 15 changed files with 1,524 additions and 0 deletions.
255 changes: 255 additions & 0 deletions Tagger.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
package tagger;

import tagger.data.ConfusionMatrix;
import tagger.data.Sentence;
import tagger.data.Token;
import tagger.model.Evaluation;
import tagger.model.FeatureExtractors;
import tagger.model.LabelExtractor;
import tagger.model.Perceptron;
import tagger.utility.Logger;

import java.io.*;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;

/**
* @author Robert McHardy
* @author Alexander Ehmann
* The multi-class perceptron based Part-of-Speech tagger.
*/
public class Tagger {
private static final HashSet<String> classCounter = new HashSet<>();

/**
* Returns the number of part of speech classes observed.
*
* @return number of POS classes.
*/
public static int numClasses() {
return classCounter.size();
}

/**
* Clears the performance critical data structures of the tagger infrastructure. This should be performed if the
* the tagger is run on different data sets in order to maintain the runtime performance.
*/
public static void clear() {
classCounter.clear();
FeatureExtractors.stringMapper.clear();
LabelExtractor.clear();
}

private Tagger() {
}

/**
* Constructs a list of sentences from a file in the CoNLL format including information regarding the gold
* and predicted labels.
*
* @param inputFile The name of the file in the CoNLL format.
* @return The constructed list of sentences.
*/
public static List<Sentence> readData(String inputFile) {
LinkedList<Sentence> sentences = new LinkedList<>();
try (FileReader fr = new FileReader(inputFile); BufferedReader buff = new BufferedReader(fr)) {
String line;
int sentNr = 0;
int previousSentNr = 1;
boolean notAdded = false;
Sentence sentence = new Sentence();
String[] contents;
while ((line = buff.readLine()) != null) {
contents = line.split("\t");
if (contents.length > 5 && contents[0].length() > 0) {
if (contents[0].contains("_")) {
sentNr = Integer.parseInt(contents[0].split("_")[0]);
} else {
if (Integer.parseInt(contents[0]) == 1) {
sentNr++;
}
}
Token token = new Token();
token.word = contents[1];
token.label = contents[4];
token.prediction = contents[5];
classCounter.add(token.label);
classCounter.add(token.prediction);

// Construct a list of determiners, proper nouns and adjectives as features
if (token.label.equals("DT")) {
FeatureExtractors.addDeterminer(token.word);
} else if (token.label.equals("JJ")) {
FeatureExtractors.addAdjective(token.word);
} else if (token.label.equals("NNP")) {
FeatureExtractors.addProperNoun(token.word);
}

if (!sentence.isEmpty()) {
token.previous = sentence.get(sentence.size() - 1);
sentence.get(sentence.size() - 1).next = token;
} else {
token.previous = null;
}
if (previousSentNr == sentNr) {
notAdded = true;
sentence.addToken(token);
} else {
notAdded = false;
sentences.add(sentence);
previousSentNr = sentNr;
sentence = new Sentence();
sentence.add(token);
}
}
}
if (notAdded) {
sentences.add(sentence);
}
} catch (IOException e) {
Logger.printException(e);
}
return sentences;
}

/**
* Extracts the three preceding and subsequent tokens for all tokens in the given list of sentences where the gold
* and predicted label were confused.
*
* @param data The list of sentences.
* @param goldLabel The gold label.
* @param predLabel The predicted label.
*/
public static void extractInstances(List<Sentence> data, String goldLabel, String predLabel) {
for (Sentence s : data) {
for (int i = 0; i < s.size(); i++) {
Token t = s.get(i);
if (t.label.equals(goldLabel) && t.prediction.equals(predLabel)) {
for (int j = Math.max(0, i - 3); j < Math.min(s.size(), i + 4); j++) {
Token temp = s.get(j);
if (temp.equals(t)) {
Logger.printString(String.format("%-13s\t%-10s\t%-10s\n", "*" + temp.word + "*", temp.label,
temp.prediction));
} else {
Logger.printString(String.format("%-15s\t%-10s\t%-10s\n", temp.word, temp.label,
temp.prediction));
}
}
Logger.printString("**********************\n");
}
}
}
}

/**
* Saves the predictions for the given data in the given file. The format is:
* <word> <label> <prediction> [*]
* Where [] marks that the star is optional and only present if the prediction is wrong.
*
* @param data A list of sentences.
* @param filepath A path to a file where the predictions will be stored.
*/
public static void savePredictions(List<Sentence> data, String filepath) {
try (FileWriter fw = new FileWriter(filepath); BufferedWriter buff = new BufferedWriter(fw)) {
for (Sentence sentence : data) {
for (Token token : sentence) {
buff.write(String.format("%s %s %s %s\n", token.word, token.label, token.prediction,
!token.label.equals(token.prediction) ? "*" : ""));
}
}
} catch (IOException e) {
Logger.printException(e);
}
}

/**
* Constructs and trains a perceptron on the given training data. If test data is present, it will be annotated.
*
* @param trainData The training data for the perceptron. This can't be null.
* @param testData The optional test data.
* @return The constructed perceptron.
*/
public static Perceptron pipeline(List<Sentence> trainData, List<Sentence> testData) {
if (trainData == null) {
throw new IllegalArgumentException("The training data can't be null.");
}
FeatureExtractors.extractAllFeatures(trainData);
if (testData != null) {
FeatureExtractors.extractAllFeatures(testData);
}
Perceptron p = new Perceptron(numClasses(), FeatureExtractors.stringMapper.numFeatures());
p.train(trainData, testData, 45);
return p;
}

/**
* Starting point for the tagger. The first argument is expected to always be a file path to a training file for the
* tagger.
*
* @param args Command line arguments.
*/
public static void main(String... args) {
List<Sentence> trainData = null;
List<Sentence> testData = null;
if (args.length > 0 && new File(args[0]).exists()) {
trainData = readData(args[0]);
} else if (args[0].equals("-r")) {
trainData = FeatureExtractors.readFromFile(args[1]);
} else {
Logger.printString("Use the -help command.\n");
System.exit(1);
}
if (args.length > 2 && new File(args[2]).exists()) {
testData = readData(args[2]);
}
Perceptron p = pipeline(trainData, testData);
switch (args.length) {
case 3:
if (args[1].equals("-w") && testData == null) {
// Save weights
p.saveWeights(args[2]);
break;
}
if (args[1].equals("-p") && testData == null) {
// Save predictions
savePredictions(trainData, args[2] + "-train");
break;
}
if (args[1].equals("-t")) {
// Fall through
}
case 4:
if (trainData != null && testData != null) {
// Training and test file exist
Logger.printString("Confusion matrix of test data:\n");
ConfusionMatrix c = new ConfusionMatrix(testData);
c.print(Math.min(5, numClasses()));
Logger.printString("Accuracy on training data: " + Evaluation.accuracy(trainData) + "\n");
Logger.printString("Accuracy on test data: " + Evaluation.accuracy(testData) + "\n");
}
break;
case 5:
if (args[3].equals("-p") && !new File(args[4] + "-train").exists() &&
!new File(args[4] + "-test").exists() && testData != null) {
// Save predictions
savePredictions(trainData, args[4] + "-train");
savePredictions(testData, args[4] + "-test");
}
break;
default:
Logger.printString("Usage: java Tagger <PathToTrainingFile> <Options>\n"); // -r path
Logger.printString("Options:\n-t <PathToTestFile> [OutputPath]: Path to a test file which will be "
+ "annotated. If [OutputPath] is not specified, the accuracy of the tagger and a confusion "
+ "matrix will be printed.");
Logger.printString("\n-w <OutputPath>: Save the weights of the tagger to the file specified by" +
"<OutputPath>.");
Logger.printString("\n-p <OutputPath>: Save the predictions for the training and test file to the file"
+ "specified by <OutputPath> plus an appended suffix.");
Logger.printString("\n-r <InputPath>: Reads sentences and their annotations from a file in the svm-" +
"multiclass format and uses this data to train a model.");
Logger.printString("\n-s <OutputPath>: Saves the sentences enriched with their extracted features in "
+ "the file given by <OutputPath>.");
}
}
}
114 changes: 114 additions & 0 deletions data/ConfusionMatrix.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package tagger.data;

import tagger.utility.BidirectionalMap;
import tagger.utility.Logger;
import tagger.utility.MutableInt;

import java.util.*;

/**
* @author Robert McHardy
* @author Alexander Ehmann
* Calculates a confusion matrix for a given list of sentences.
*/
public class ConfusionMatrix {
private final int[][] matrix;
private final BidirectionalMap<String, Integer> indexMap;

/**
* Used for ordering the labels according to their error frequency.
*/
private static <K, V extends Comparable<? super V>>
SortedSet<Map.Entry<K, V>> entriesSortedByValues(Map<K, V> map) {
SortedSet<Map.Entry<K, V>> sortedEntries = new TreeSet<>(
(e1, e2) -> {
int res = e2.getValue().compareTo(e1.getValue());
return res != 0 ? res : 1;
}
);
sortedEntries.addAll(map.entrySet());
return sortedEntries;
}

/**
* Construct a confusion matrix for a given list of sentences.
*
* @param data The list of sentences.
*/
public ConfusionMatrix(List<Sentence> data) {
TreeMap<String, MutableInt> labelFreq = new TreeMap<>();
// Count how many errors are made per label.
MutableInt count;
for (Sentence s : data) {
for (Token t : s) {
count = labelFreq.get(t.label);
if (count == null) {
labelFreq.put(t.label, new MutableInt());
} else {
count.increment();
}
count = labelFreq.get(t.prediction);
if (count == null) {
labelFreq.put(t.prediction, new MutableInt());
} else {
count.increment();
}
}
}
// Create the matrix.
matrix = new int[labelFreq.size()][];
indexMap = new BidirectionalMap<>(labelFreq.size());
for (int i = 0; i < labelFreq.size(); i++) {
matrix[i] = new int[labelFreq.size()];
}
// Create a list of labels sorted by error frequency.
int i = 0;
for (Map.Entry<String, MutableInt> e : entriesSortedByValues(labelFreq)) {
if (i < labelFreq.size()) {
indexMap.put(e.getKey(), i++);
}
}
// Fill the matrix.
for (Sentence s : data) {
for (Token t : s) {
matrix[indexMap.keyMapGet(t.label)][indexMap.keyMapGet(t.prediction)] += 1;
}
}
}

/**
* Calculate the number of errors for the given gold and predicted label according to the matrix.
*
* @param goldLabel The given gold label.
* @param predLabel The given predicted label.
* @return The number of errors.
*/
public int numberErrors(String goldLabel, String predLabel) {
return matrix[indexMap.keyMapGet(goldLabel)][indexMap.keyMapGet(predLabel)];
}

/**
* Prints the confusion matrix. The printed matrix is sorted such that the most frequent errors are in the upper
* left corner.
*
* @param maxDim The maximum dimensions of the printed matrix.
*/
public void print(int maxDim) {
Logger.printString("", 5);
for (int i = 0; i <= maxDim; i++) {
Logger.printString(indexMap.valueMapGet(i), 6);
}
Logger.printString("\n", 1);
for (int j = 0; j <= maxDim; j++) {
Logger.printString(indexMap.valueMapGet(j), 5);
for (int i = 0; i <= maxDim; i++) {
Logger.printString(Integer.toString(matrix[j][i]), 5);
if (i < maxDim) {
Logger.printString(" ");
} else {
Logger.printString("\n");
}
}
}
}
}
14 changes: 14 additions & 0 deletions data/Sentence.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package tagger.data;

import java.util.ArrayList;

/**
* @author Robert McHardy
* @author Alexander Ehmann
* A wrapper around java.util.ArrayList to store tokens (a sentence).
*/
public class Sentence extends ArrayList<Token> {
public void addToken(Token token) {
this.add(token);
}
}
Loading

0 comments on commit 3f47533

Please sign in to comment.