Skip to content

Commit

Permalink
approximate all double identity comparisons
Browse files Browse the repository at this point in the history
  • Loading branch information
kreeben committed May 14, 2024
1 parent c2c104c commit fccc910
Show file tree
Hide file tree
Showing 23 changed files with 115 additions and 93 deletions.
2 changes: 1 addition & 1 deletion index.bat
Original file line number Diff line number Diff line change
@@ -1 +1 @@
sir.bat indexwikipedia --directory C:\projects\resin\src\Sir.HttpServer\AppData\database --collection wikipedia --skip 0 --take 1000000 --pageSize 10000 --sampleSize 1000 %*
sir.bat indexwikipedia --directory C:\projects\resin\src\Sir.HttpServer\AppData\database --collection wikipedia --skip 0 --take 1000000 --pageSize 1000 --sampleSize 1000 %*
3 changes: 1 addition & 2 deletions src/Sir.Cmd/AnalyzeDocumentCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@ public void Run(IDictionary<string, string> args, ILogger logger)
var select = new HashSet<string>(args["select"].Split(new char[] { ',', ' ' }, StringSplitOptions.RemoveEmptyEntries));
var collectionId = collection.ToHash();
var model = new BagOfCharsModel();
var embedding = new SortedList<int, float>();

using (var documentReader = new DocumentRegistryReader(dataDirectory, collectionId))
{
var doc = DocumentReader.Read(documentId, documentReader, select);

foreach (var field in doc.Fields)
{
var tokens = model.CreateEmbedding(field.Value.ToString(), true, embedding);
var tokens = model.CreateEmbedding(field.Value.ToString(), true);
var tree = new VectorNode();

foreach (var token in tokens)
Expand Down
3 changes: 1 addition & 2 deletions src/Sir.Cmd/BenchmarkCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,12 @@ public void RunTokenizeBenchmark(IDictionary<string, string> args, ILogger logge
var model = new BagOfCharsModel();
var documents = new List<Document>(WikipediaHelper.Read(fileName, skip, take, new HashSet<string> { "text" }));
var timer = Stopwatch.StartNew();
var embedding = new SortedList<int, float>();

for (int i = 0; i < numOfRuns; i++)
{
foreach (var document in documents)
{
var embeddings = new List<ISerializableVector>(model.CreateEmbedding((string)document.Fields[0].Value, false, embedding));
var embeddings = new List<ISerializableVector>(model.CreateEmbedding((string)document.Fields[0].Value, false));
}
}

Expand Down
3 changes: 1 addition & 2 deletions src/Sir.Cmd/ValidateCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,12 @@ public void Run(IDictionary<string, string> args, ILogger logger)
var selectFields = new HashSet<string> { "title" };
var time = Stopwatch.StartNew();
var count = 0;
var embedding = new SortedList<int, float>();

using (var kvReader = new KeyValue.KeyValueReader(dir, collectionId))
using (var validateSession = new ValidateSession<string>(
collectionId,
new SearchSession<string>(dir, model, new LogStructuredIndexingStrategy(model), logger),
new QueryParser<string>(kvReader, model, embedding: embedding, logger: logger)))
new QueryParser<string>(kvReader, model, logger: logger)))
using (var documents = new DocumentStreamSession(dir))
{
foreach (var doc in documents.ReadDocuments<string>(collectionId, selectFields, skip, take))
Expand Down
2 changes: 1 addition & 1 deletion src/Sir.ImageTests/ImageModelTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ public void Can_traverse_streamed()
throw new Exception($"unable to find {word} in tree.");
}

if (hit.Score < model.IdenticalAngle)
if (hit.Score.Approximates(model.IdenticalAngle))
{
throw new Exception($"unable to score {word}.");
}
Expand Down
2 changes: 1 addition & 1 deletion src/Sir.Images/LinearClassifierImageModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ public class LinearClassifierImageModel : Sir.DistanceCalculator, IModel<IImage>
public double FoldAngle => 0.75d;
public override int NumOfDimensions => 784;

public IEnumerable<ISerializableVector> CreateEmbedding(IImage data, bool label, SortedList<int, float> embedding = null)
public IEnumerable<ISerializableVector> CreateEmbedding(IImage data, bool label)
{
var pixels = data.Pixels.Select(x => Convert.ToSingle(x));

Expand Down
2 changes: 1 addition & 1 deletion src/Sir.InformationRetreival/IModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace Sir
/// <typeparam name="T">The type of data the model should consist of.</typeparam>
public interface IModel<T> : IModel
{
IEnumerable<ISerializableVector> CreateEmbedding(T data, bool label, SortedList<int, float> embedding = null);
IEnumerable<ISerializableVector> CreateEmbedding(T data, bool label);
}

/// <summary>
Expand Down
3 changes: 1 addition & 2 deletions src/Sir.InformationRetreival/IO/GraphBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@ public static class GraphBuilder
public static VectorNode CreateTree<T>(this IModel<T> model, IIndexReadWriteStrategy indexingStrategy, params T[] data)
{
var root = new VectorNode();
var embedding = new SortedList<int, float>();

foreach (var item in data)
{
foreach (var vector in model.CreateEmbedding(item, true, embedding))
foreach (var vector in model.CreateEmbedding(item, true))
{
indexingStrategy.Put<T>(root, new VectorNode(vector));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public void SerializePage(string directory, ulong collectionId, long keyId, Vect
var time = Stopwatch.StartNew();

using (var vectorStream = StreamFactory.CreateAppendStream(directory, collectionId, keyId, "vec"))
using (var postingsWriter = new PostingsWriter(StreamFactory.CreateSeekableWriteStream(directory, collectionId, keyId, "pos"), indexCache:null))
using (var postingsWriter = new PostingsWriter(StreamFactory.CreateSeekableWriteStream(directory, collectionId, keyId, "pos"), indexCache: indexCache))
using (var columnWriter = new ColumnWriter(StreamFactory.CreateAppendStream(directory, collectionId, keyId, "ix")))
using (var pageIndexWriter = new PageIndexWriter(StreamFactory.CreateAppendStream(directory, collectionId, keyId, "ixtp")))
{
Expand Down
6 changes: 2 additions & 4 deletions src/Sir.InformationRetreival/Parsers/QueryParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@ public class QueryParser<T>
private readonly KeyValueReader _kvReader;
private readonly IModel<T> _model;
private readonly ILogger _logger;
private readonly SortedList<int, float> _embedding;

public QueryParser(KeyValueReader kvReader, IModel<T> model, SortedList<int, float> embedding = null, ILogger logger = null)
public QueryParser(KeyValueReader kvReader, IModel<T> model, ILogger logger = null)
{
_kvReader = kvReader;
_model = model;
_logger = logger;
_embedding = embedding ?? new SortedList<int, float>();
}

public Query Parse(
Expand Down Expand Up @@ -243,7 +241,7 @@ private IList<Term> CreateTerms(ulong collectionId, string key, T value, bool an

if (_kvReader.TryGetKeyId(key.ToHash(), out keyId))
{
var tokens = _model.CreateEmbedding(value, label, _embedding);
var tokens = _model.CreateEmbedding(value, label);

foreach (var term in tokens)
{
Expand Down
9 changes: 9 additions & 0 deletions src/Sir.InformationRetreival/SerializableVector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ public class SerializableVector : ISerializableVector
public int[] Indices { get { return ((SparseVectorStorage<float>)Value.Storage).Indices; } }
public float[] Values { get { return ((SparseVectorStorage<float>)Value.Storage).Values; } }

public SerializableVector()
{
}

public SerializableVector(int numOfDimensions, object label = null)
{
Value = CreateVector.Sparse<float>(numOfDimensions);
Expand Down Expand Up @@ -48,6 +52,11 @@ public SerializableVector(SortedList<int, float> dictionary, int numOfDimensions
Label = label;
}

public bool IsEmptyVector()
{
return Value == null;
}

public SerializableVector(int[] index, float[] values, int numOfDimensions, object label = null)
{
var tuples = new Tuple<int, float>[Math.Min(index.Length, numOfDimensions)];
Expand Down
2 changes: 1 addition & 1 deletion src/Sir.InformationRetreival/Session/DocumentDatabase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public DocumentDatabase(string directory, ulong collectionId, IModel<T> model =

public QueryParser<T> CreateQueryParser()
{
return new QueryParser<T>(SearchSession.GetKeyValueReader(_collectionId), _model, IndexSession.EmptyEmbedding, _logger);
return new QueryParser<T>(SearchSession.GetKeyValueReader(_collectionId), _model, _logger);
}

public IEnumerable<Document> StreamDocuments(HashSet<string> fieldsOfInterest, int skip, int take)
Expand Down
8 changes: 2 additions & 6 deletions src/Sir.InformationRetreival/Session/DocumentStreamSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,8 @@ public IEnumerable<VectorNode> ReadDocumentValuesAsVectors<T>(
HashSet<string> select,
DocumentRegistryReader documentReader,
IModel<T> model,
bool label,
SortedList<int, float> embedding = null)
bool label)
{
if (embedding == null)
embedding = new SortedList<int, float>();

var docInfo = documentReader.GetDocumentAddress(doc.docId);
var docMap = documentReader.GetDocumentMap(docInfo.offset, docInfo.length);

Expand All @@ -222,7 +218,7 @@ public IEnumerable<VectorNode> ReadDocumentValuesAsVectors<T>(
{
var vInfo = documentReader.GetAddressOfValue(kvp.valId);

foreach (var vector in documentReader.GetValueConvertedToVectors<T>(vInfo.offset, vInfo.len, vInfo.dataType, value => model.CreateEmbedding(value, label, embedding)))
foreach (var vector in documentReader.GetValueConvertedToVectors<T>(vInfo.offset, vInfo.len, vInfo.dataType, value => model.CreateEmbedding(value, label)))
{
tree.AddIfUnique(new VectorNode(vector, docId:doc.docId, keyId:kvp.keyId), model);
}
Expand Down
19 changes: 0 additions & 19 deletions src/Sir.InformationRetreival/Session/IIndexSession.cs

This file was deleted.

4 changes: 2 additions & 2 deletions src/Sir.InformationRetreival/Session/IndexCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public void Put(VectorNode node)
{
var hit = PathFinder.ClosestMatch(tree, vector, _model);

if (hit.Score >= _model.IdenticalAngle)
if (hit.Score.Approximates(_model.IdenticalAngle))
{
return hit.Node.PostingsOffset == -1 ? null : hit.Node.PostingsOffset;
}
Expand All @@ -49,7 +49,7 @@ public void UpdatePostingsOffset(long keyId, ISerializableVector vector, long po
{
var hit = PathFinder.ClosestMatch(tree, vector, _model);

if (hit.Score >= _model.IdenticalAngle)
if (hit.Score.Approximates(_model.IdenticalAngle))
{
hit.Node.PostingsOffset = postingsOffset;
}
Expand Down
2 changes: 1 addition & 1 deletion src/Sir.InformationRetreival/Session/IndexDebugger.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public IndexDebugger(ILogger logger, int sampleSize = 1000)
_logger = logger;
}

public void Step(IIndexSession indexSession, string message = null)
public void Step<T>(IndexSession<T> indexSession, string message = null)
{
_steps++;

Expand Down
34 changes: 7 additions & 27 deletions src/Sir.InformationRetreival/Session/IndexSession.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using Microsoft.Extensions.Logging;
using Sir.IO;
using System;
using System.Collections.Generic;
using System.Diagnostics;
Expand All @@ -10,7 +9,7 @@ namespace Sir
/// Write a paged index.
/// </summary>
/// <typeparam name="T"></typeparam>
public class IndexSession<T> : IIndexSession<T>, IDisposable
public class IndexSession<T> : IDisposable
{
private readonly IModel<T> _model;
private readonly IIndexReadWriteStrategy _indexingStrategy;
Expand All @@ -20,8 +19,6 @@ public class IndexSession<T> : IIndexSession<T>, IDisposable
private readonly ILogger _logger;
private readonly IndexCache _indexCache;

public SortedList<int, float> EmptyEmbedding = new SortedList<int, float>();

public IndexSession(
string directory,
ulong collectionId,
Expand All @@ -41,12 +38,12 @@ public IndexSession(

public void Put(long docId, long keyId, T value, bool label)
{
var tokens = _model.CreateEmbedding(value, label, EmptyEmbedding);
var tokens = _model.CreateEmbedding(value, label);

Put(docId, keyId, tokens);
}

public void Put(long docId, long keyId, IEnumerable<ISerializableVector> tokens)
private void Put(long docId, long keyId, IEnumerable<ISerializableVector> tokens)
{
VectorNode column;

Expand All @@ -58,27 +55,10 @@ public void Put(long docId, long keyId, IEnumerable<ISerializableVector> tokens)

foreach (var token in tokens)
{
_indexingStrategy.Put<T>(
column,
new VectorNode(vector:token, docId:docId, keyId:keyId));
}
}

public void Put(VectorNode token)
{
VectorNode column;

if (!_index.TryGetValue(token.KeyId.Value, out column))
{
column = new VectorNode();
_index.Add(token.KeyId.Value, column);
}

foreach (var node in PathFinder.All(token))
{
_indexingStrategy.Put<T>(
column,
new VectorNode(node.Vector, docIds: node.DocIds));
if (!token.IsEmptyVector())
_indexingStrategy.Put<T>(
column,
new VectorNode(vector:token, docId:docId, keyId:keyId));
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/Sir.InformationRetreival/Session/SearchSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ private void Scan(Query query, bool identicalMatchesOnly)

if (hit != null)
{
if (!identicalMatchesOnly || (hit.Score >= _model.IdenticalAngle))
if (!identicalMatchesOnly || hit.Score.Approximates(_model.IdenticalAngle))
{
term.Score = hit.Score;
term.PostingsOffsets = hit.PostingsOffsets;
Expand Down
1 change: 1 addition & 0 deletions src/Sir.KeyValue/ISerializableVector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@ public interface ISerializableVector
void AverageInPlace(ISerializableVector vector);
ISerializableVector Append(ISerializableVector vector);
ISerializableVector Shift(int numOfPositionsToShift, int numOfDimensions);
bool IsEmptyVector();
}
}
20 changes: 9 additions & 11 deletions src/Sir.Strings/BagOfCharsModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,15 @@ public class BagOfCharsModel : DistanceCalculator, IModel<string>
public double IdenticalAngle => 0.998d;
public double FoldAngle => 0.55d;
public override int NumOfDimensions => System.Text.Unicode.UnicodeRanges.All.Length;
private readonly SortedList<int, float> _embedding = new SortedList<int, float>();

public IEnumerable<ISerializableVector> CreateEmbedding(string data, bool label, SortedList<int, float> embedding = null)
public IEnumerable<ISerializableVector> CreateEmbedding(string data, bool label)
{
var source = data.ToCharArray();

if (source.Length > 0)
{
if (embedding == null)
embedding = new SortedList<int, float>();
else
embedding.Clear();
_embedding.Clear();

var offset = 0;
int index = 0;
Expand All @@ -28,33 +26,33 @@ public IEnumerable<ISerializableVector> CreateEmbedding(string data, bool label,

if (char.IsLetterOrDigit(c) || char.GetUnicodeCategory(c) == System.Globalization.UnicodeCategory.MathSymbol)
{
embedding.AddOrAppendToComponent(c, 1);
_embedding.AddOrAppendToComponent(c, 1);
}
else
{
if (embedding.Count > 0)
if (_embedding.Count > 0)
{
var len = index - offset;

var vector = new SerializableVector(
embedding,
_embedding,
NumOfDimensions,
label ? new string(source, offset, len) : null);

embedding.Clear();
_embedding.Clear();
yield return vector;
}

offset = index + 1;
}
}

if (embedding.Count > 0)
if (_embedding.Count > 0)
{
var len = index - offset;

var vector = new SerializableVector(
embedding,
_embedding,
NumOfDimensions,
label ? new string(source, offset, len) : null);

Expand Down
Loading

0 comments on commit fccc910

Please sign in to comment.