Skip to content

Commit

Permalink
compartmentalize
Browse files Browse the repository at this point in the history
  • Loading branch information
kreeben committed Nov 3, 2021
1 parent 64cfa80 commit 5055cdb
Show file tree
Hide file tree
Showing 25 changed files with 213 additions and 233 deletions.
6 changes: 3 additions & 3 deletions src/Sir.Core/IModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace Sir
/// <typeparam name="T">The type of data the model should consist of.</typeparam>
public interface IModel<T> : IModel
{
IEnumerable<IVector> Tokenize(T data);
IEnumerable<ISerializableVector> Tokenize(T data);
}

/// <summary>
Expand All @@ -34,7 +34,7 @@ public interface IVectorSpaceConfig
public interface IDistanceCalculator
{
int NumOfDimensions { get; }
double CosAngle(IVector vec1, IVector vec2);
double CosAngle(IVector vector, long vectorOffset, int componentCount, Stream vectorStream);
double CosAngle(ISerializableVector vec1, ISerializableVector vec2);
double CosAngle(ISerializableVector vector, long vectorOffset, int componentCount, Stream vectorStream);
}
}
24 changes: 24 additions & 0 deletions src/Sir.Core/ISerializableVector.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using MathNet.Numerics.LinearAlgebra;
using System.IO;

namespace Sir
{
public interface ISerializableVector
{
int[] Indices { get; }
float[] Values { get; }
Vector<float> Value { get; }
void Serialize(Stream stream);
int ComponentCount { get; }
object Label { get; }
void AddInPlace(ISerializableVector vector);
ISerializableVector Add(ISerializableVector vector);
ISerializableVector Subtract(ISerializableVector vector);
void SubtractInPlace(ISerializableVector vector);
ISerializableVector Multiply(float scalar);
ISerializableVector Divide(float scalar);
void AverageInPlace(ISerializableVector vector);
ISerializableVector Append(ISerializableVector vector);
ISerializableVector Shift(int numOfPositionsToShift, int numOfDimensions, string label = null);
}
}
24 changes: 0 additions & 24 deletions src/Sir.Core/IVector.cs

This file was deleted.

6 changes: 3 additions & 3 deletions src/Sir.Core/VectorNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public class VectorNode
public long ComponentCount { get; set; }
public long VectorOffset { get; set; }
public long PostingsOffset { get; set; }
public IVector Vector { get; set; }
public ISerializableVector Vector { get; set; }

public object Sync { get; } = new object();

Expand Down Expand Up @@ -74,7 +74,7 @@ public VectorNode(long postingsOffset)
VectorOffset = -1;
}

public VectorNode(IVector vector = null, long docId = -1, long postingsOffset = -1, long? keyId = null, List<long> docIds = null)
public VectorNode(ISerializableVector vector = null, long docId = -1, long postingsOffset = -1, long? keyId = null, List<long> docIds = null)
{
Vector = vector;
ComponentCount = vector == null ? 0 : vector.ComponentCount;
Expand All @@ -101,7 +101,7 @@ public VectorNode(IVector vector = null, long docId = -1, long postingsOffset =
}
}

public VectorNode(long postingsOffset, long vecOffset, long terminator, long weight, IVector vector)
public VectorNode(long postingsOffset, long vecOffset, long terminator, long weight, ISerializableVector vector)
{
PostingsOffset = postingsOffset;
VectorOffset = vecOffset;
Expand Down
2 changes: 1 addition & 1 deletion src/Sir.Document/DocumentReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public object GetValue(long offset, int len, byte dataType)
return _vals.Get(offset, len, dataType);
}

public IEnumerable<IVector> GetVectors<T>(long offset, int len, byte dataType, Func<T, IEnumerable<IVector>> tokenizer)
public IEnumerable<ISerializableVector> GetVectors<T>(long offset, int len, byte dataType, Func<T, IEnumerable<ISerializableVector>> tokenizer)
{
return _vals.GetVectors(offset, len, dataType, tokenizer);
}
Expand Down
2 changes: 1 addition & 1 deletion src/Sir.KeyValue/ValueReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public void Dispose()
_stream.Dispose();
}

public IEnumerable<IVector> GetVectors<T>(long offset, int len, byte dataType, Func<T, IEnumerable<IVector>> tokenizer)
public IEnumerable<ISerializableVector> GetVectors<T>(long offset, int len, byte dataType, Func<T, IEnumerable<ISerializableVector>> tokenizer)
{
int read;
Span<byte> buf = new byte[len];
Expand Down
6 changes: 3 additions & 3 deletions src/Sir.Search/Field.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ namespace Sir.Search
[DebuggerDisplay("{Name}")]
public class Field
{
private IEnumerable<IVector> _tokens;
private IEnumerable<ISerializableVector> _tokens;

public VectorNode Tree { get; private set; }
public long KeyId { get; set; }
public long DocumentId { get; set; }
public string Name { get; }
public object Value { get; set; }
public IEnumerable<IVector> Tokens { get { return _tokens; } }
public IEnumerable<ISerializableVector> Tokens { get { return _tokens; } }

public Field(string name, object value, long keyId = -1, long documentId = -1)
{
Expand All @@ -28,7 +28,7 @@ public Field(string name, object value, long keyId = -1, long documentId = -1)
DocumentId = documentId;
}

private IEnumerable<IVector> GetTokens()
private IEnumerable<ISerializableVector> GetTokens()
{
foreach (var node in PathFinder.All(Tree))
yield return node.Vector;
Expand Down
31 changes: 24 additions & 7 deletions src/Sir.Search/Models/BagOfCharsModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public void ExecutePut<T>(VectorNode column, VectorNode node)
column.MergeOrAddConcurrent(node, this);
}

public IEnumerable<IVector> Tokenize(string data)
public IEnumerable<ISerializableVector> Tokenize(string data)
{
ReadOnlyMemory<char> source = data.AsMemory();

Expand All @@ -39,7 +39,7 @@ public IEnumerable<IVector> Tokenize(string data)
{
var len = index - offset;

var vector = new IndexedVector(
var vector = new SerializableVector(
embedding,
NumOfDimensions,
new string(source.Span.Slice(offset, len)));
Expand All @@ -56,7 +56,7 @@ public IEnumerable<IVector> Tokenize(string data)
{
var len = index - offset;

var vector = new IndexedVector(
var vector = new SerializableVector(
embedding,
NumOfDimensions,
new string(source.Span.Slice(offset, len)));
Expand All @@ -67,6 +67,23 @@ public IEnumerable<IVector> Tokenize(string data)
}
}

/// <summary>
/// Helpers used while building sparse component lists during tokenization.
/// </summary>
public static class TokenizeOperations
{
    /// <summary>
    /// Increments the component stored under <paramref name="key"/> by one,
    /// or creates it with the value 1 when the key is not present yet.
    /// </summary>
    /// <param name="vec">Sorted component list, keyed by component index.</param>
    /// <param name="key">Index of the component to bump.</param>
    public static void AddOrAppendToComponent(this SortedList<int, float> vec, int key)
    {
        // The indexer setter inserts the key when it is missing, so one
        // assignment covers both the "increment" and the "create" case.
        vec[key] = vec.TryGetValue(key, out var current) ? current + 1 : 1;
    }
}

public class BocEmbeddingsModel : DistanceCalculator, IModel<string>
{
public double IdenticalAngle => 0.95d;
Expand All @@ -86,7 +103,7 @@ public void ExecutePut<T>(VectorNode column, VectorNode node)
column.Build(node, this);
}

public IEnumerable<IVector> Tokenize(string data)
public IEnumerable<ISerializableVector> Tokenize(string data)
{
return _wordTokenizer.Tokenize(data);
}
Expand All @@ -110,16 +127,16 @@ public void ExecutePut<T>(VectorNode column, VectorNode node)
column.MergeOrAdd(node, this);
}

public IEnumerable<IVector> Tokenize(string data)
public IEnumerable<ISerializableVector> Tokenize(string data)
{
var tokens = (IList<IVector>)_wordTokenizer.Tokenize(data);
var tokens = (IList<ISerializableVector>)_wordTokenizer.Tokenize(data);

for (int i = 0; i < tokens.Count; i++)
{
var context0 = i - 1;
var context1 = i + 1;
var token = tokens[i];
var vector = new IndexedVector(NumOfDimensions, token.Label);
var vector = new SerializableVector(NumOfDimensions, token.Label);

if (context0 >= 0)
{
Expand Down
4 changes: 2 additions & 2 deletions src/Sir.Search/Models/DistanceCalculator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public abstract class DistanceCalculator : IDistanceCalculator
{
public abstract int NumOfDimensions { get; }

public double CosAngle(IVector vec1, IVector vec2)
public double CosAngle(ISerializableVector vec1, ISerializableVector vec2)
{
var dotSelf1 = vec1.Value.Norm(2);
var dotSelf2 = vec2.Value.Norm(2);
Expand All @@ -19,7 +19,7 @@ public double CosAngle(IVector vec1, IVector vec2)
return dotProduct / (dotSelf1 * dotSelf2);
}

public double CosAngle(IVector vector, long vectorOffset, int componentCount, Stream vectorStream)
public double CosAngle(ISerializableVector vector, long vectorOffset, int componentCount, Stream vectorStream)
{
Span<byte> buf = new byte[componentCount * 2 * sizeof(int)];

Expand Down
4 changes: 2 additions & 2 deletions src/Sir.Search/Models/LinearClassifierImageModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ public void ExecutePut<T>(VectorNode column, VectorNode node)
column.MergeOrAddSupervised(node, this);
}

public IEnumerable<IVector> Tokenize(IImage data)
public IEnumerable<ISerializableVector> Tokenize(IImage data)
{
var pixels = data.Pixels.Select(x => Convert.ToSingle(x));

yield return new IndexedVector(pixels, data.Label);
yield return new SerializableVector(pixels, data.Label);
}
}
}
2 changes: 1 addition & 1 deletion src/Sir.Search/Session/IndexSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public void Put(long docId, long keyId, T value)
Put(docId, keyId, tokens);
}

public void Put(long docId, long keyId, IEnumerable<IVector> tokens)
public void Put(long docId, long keyId, IEnumerable<ISerializableVector> tokens)
{
var tree = new VectorNode(keyId: keyId);

Expand Down
55 changes: 14 additions & 41 deletions src/Sir.Store.Tests/ImageModelTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using Microsoft.Extensions.Logging;
using NUnit.Framework;
using Sir.Mnist;
using Sir.Search;
Expand All @@ -12,17 +11,17 @@ namespace Sir.Tests
{
public class ImageModelTests
{
private ILoggerFactory _loggerFactory;
private ILogger<ImageModelTests> _logger;
private Database _sessionFactory;
private IImage[] _data;
private string _directory = @"c:\temp\sir_tests";

[Test]
public void Can_train_in_memory()
public void Can_create_in_memory_linear_classifier()
{
// Use the same set of images to both create and validate a linear classifier.

var trainingData = new MnistReader(
@"resources\t10k-images.idx3-ubyte",
@"resources\t10k-labels.idx1-ubyte").Read().Take(100).ToArray();

var model = new LinearClassifierImageModel();
var tree = model.CreateTree(model, _data);
var tree = model.CreateTree(model, trainingData);

Print(tree);

Expand All @@ -31,23 +30,23 @@ public void Can_train_in_memory()
var count = 0;
var errors = 0;

foreach (var word in _data)
foreach (var image in trainingData)
{
foreach (var queryVector in model.Tokenize(word))
foreach (var queryVector in model.Tokenize(image))
{
var hit = PathFinder.ClosestMatch(tree, queryVector, model);

if (hit == null)
{
throw new Exception($"unable to find {word} in tree.");
throw new Exception($"unable to find {image} in tree.");
}

if (!hit.Node.Vector.Label.Equals(word.Label))
if (!hit.Node.Vector.Label.Equals(image.Label))
{
errors++;
}

Debug.WriteLine($"{word} matched with {hit.Node.Vector.Label} with {hit.Score * 100}% certainty.");
Debug.WriteLine($"{image} matched with {hit.Node.Vector.Label} with {hit.Score * 100}% certainty.");

count++;
}
Expand All @@ -64,36 +63,10 @@ public void Can_train_in_memory()
});
}

[SetUp]
public void Setup()
{
_loggerFactory = LoggerFactory.Create(builder =>
{
builder
.AddFilter("Microsoft", LogLevel.Warning)
.AddFilter("System", LogLevel.Warning)
.AddDebug();
});

_logger = _loggerFactory.CreateLogger<ImageModelTests>();

_sessionFactory = new Database(logger: _loggerFactory.CreateLogger<Database>());

_data = new MnistReader(
@"C:\temp\mnist\t10k-images.idx3-ubyte",
@"C:\temp\mnist\t10k-labels.idx1-ubyte").Read().Take(100).ToArray();
}

[TearDown]
public void TearDown()
{
_sessionFactory.Dispose();
}

private static void Print(VectorNode tree)
{
var diagram = PathFinder.Visualize(tree);
File.WriteAllText(@"c:\temp\imagemodeltesttree.txt", diagram);
File.WriteAllText("imagemodeltesttree.txt", diagram);
Debug.WriteLine(diagram);
}
}
Expand Down
Loading

0 comments on commit 5055cdb

Please sign in to comment.