Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: UInt256: Divide, Lsh, Rsh, Exp, ExpMod, SubtractMod; Int256: Multiply - when in and out params are referenced to the same struct #33

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 97 additions & 1 deletion src/Nethermind.Int256.Test/UInt256Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ public virtual void Add((BigInteger A, BigInteger B) test)
res.Convert(out BigInteger resUInt256);

resUInt256.Should().Be(resBigInt);

a.Add(b, out a);
a.Convert(out resUInt256);

resUInt256.Should().Be(resBigInt);
}

[TestCaseSource(typeof(BinaryOps), nameof(BinaryOps.TestCases))]
Expand All @@ -58,6 +63,19 @@ public virtual void AddOverflow((BigInteger A, BigInteger B) test)
{
overflow.Should().Be(true);
}

overflow = T.AddOverflow(uint256a, uint256b, out uint256a);
uint256a.Convert(out resUInt256);

resUInt256.Should().Be(resBigInt);
if (test.A + test.B <= (BigInteger)UInt256.MaxValue)
{
overflow.Should().Be(false);
}
else
{
overflow.Should().Be(true);
}
}

[TestCaseSource(typeof(TernaryOps), nameof(TernaryOps.TestCases))]
Expand All @@ -78,6 +96,11 @@ public virtual void AddMod((BigInteger A, BigInteger B, BigInteger M) test)
res.Convert(out BigInteger resUInt256);

resUInt256.Should().Be(resBigInt);

uint256a.AddMod(uint256b, uint256m, out uint256a);
uint256a.Convert(out resUInt256);

resUInt256.Should().Be(resBigInt);
}

[TestCaseSource(typeof(BinaryOps), nameof(BinaryOps.TestCases))]
Expand All @@ -98,6 +121,11 @@ public virtual void Subtract((BigInteger A, BigInteger B) test)
res.Convert(out BigInteger resUInt256);

resUInt256.Should().Be(resBigInt);

uint256a.Subtract(uint256b, out uint256a);
uint256a.Convert(out resUInt256);

resUInt256.Should().Be(resBigInt);
}

[TestCaseSource(typeof(TernaryOps), nameof(TernaryOps.TestCases))]
Expand Down Expand Up @@ -130,6 +158,11 @@ protected void SubtractModCore((BigInteger A, BigInteger B, BigInteger M) test,
res.Convert(out BigInteger resUInt256);

resUInt256.Should().Be(resBigInt);

uint256a.SubtractMod(uint256b, uint256m, out uint256a);
uint256a.Convert(out resUInt256);

resUInt256.Should().Be(resBigInt);
}

[TestCaseSource(typeof(BinaryOps), nameof(BinaryOps.TestCases))]
Expand All @@ -148,13 +181,18 @@ public virtual void SubtractOverflow((BigInteger A, BigInteger B) test)
{
UInt256 res = a - b;
res.Convert(out resUInt256);
resUInt256.Should().Be(resBigInt);
}
else
{
uint256a.Subtract(uint256b, out T res);
res.Convert(out resUInt256);
resUInt256.Should().Be(resBigInt);

uint256a.Subtract(uint256b, out uint256a);
uint256a.Convert(out resUInt256);
resUInt256.Should().Be(resBigInt);
}
resUInt256.Should().Be(resBigInt);
}
else
{
Expand All @@ -169,6 +207,10 @@ public virtual void SubtractOverflow((BigInteger A, BigInteger B) test)
uint256a.Subtract(uint256b, out T res);
res.Convert(out resUInt256);
resUInt256.Should().Be(resBigInt);

uint256a.Subtract(uint256b, out uint256a);
uint256a.Convert(out resUInt256);
resUInt256.Should().Be(resBigInt);
}
}
}
Expand All @@ -185,6 +227,11 @@ public virtual void Multiply((BigInteger A, BigInteger B) test)
res.Convert(out BigInteger resUInt256);

resUInt256.Should().Be(resBigInt);

uint256a.Multiply(uint256b, out uint256a);
uint256a.Convert(out resUInt256);

resUInt256.Should().Be(resBigInt);
}

[TestCaseSource(typeof(TernaryOps), nameof(TernaryOps.TestCases))]
Expand All @@ -204,6 +251,11 @@ public virtual void MultiplyMod((BigInteger A, BigInteger B, BigInteger M) test)
res.Convert(out BigInteger resUInt256);

resUInt256.Should().Be(resBigInt);

uint256a.MultiplyMod(uint256b, uint256m, out uint256a);
uint256a.Convert(out resUInt256);

resUInt256.Should().Be(resBigInt);
}

[TestCaseSource(typeof(BinaryOps), nameof(BinaryOps.TestCases))]
Expand All @@ -222,6 +274,11 @@ public virtual void Div((BigInteger A, BigInteger B) test)
res.Convert(out BigInteger resUInt256);

resUInt256.Should().Be(resBigInt);

uint256a.Divide(in uint256b, out uint256a);
uint256a.Convert(out resUInt256);

resUInt256.Should().Be(resBigInt);
}

[TestCaseSource(typeof(BinaryOps), nameof(BinaryOps.TestCases))]
Expand All @@ -236,6 +293,11 @@ public virtual void And((BigInteger A, BigInteger B) test)
res.Convert(out BigInteger resUInt256);

resUInt256.Should().Be(resBigInt);

T.And(uint256a, uint256b, out uint256a);
uint256a.Convert(out resUInt256);

resUInt256.Should().Be(resBigInt);
}

[TestCaseSource(typeof(BinaryOps), nameof(BinaryOps.TestCases))]
Expand All @@ -250,6 +312,11 @@ public virtual void Or((BigInteger A, BigInteger B) test)
res.Convert(out BigInteger resUInt256);

resUInt256.Should().Be(resBigInt);

T.Or(uint256a, uint256b, out uint256a);
uint256a.Convert(out resUInt256);

resUInt256.Should().Be(resBigInt);
}

[TestCaseSource(typeof(BinaryOps), nameof(BinaryOps.TestCases))]
Expand All @@ -264,6 +331,11 @@ public virtual void Xor((BigInteger A, BigInteger B) test)
res.Convert(out BigInteger resUInt256);

resUInt256.Should().Be(resBigInt);

T.Xor(uint256a, uint256b, out uint256a);
uint256a.Convert(out resUInt256);

resUInt256.Should().Be(resBigInt);
}

[TestCaseSource(typeof(BinaryOps), nameof(BinaryOps.ShiftTestCases))]
Expand All @@ -279,6 +351,11 @@ public virtual void Exp((BigInteger A, int n) test)
res.Convert(out BigInteger resUInt256);

resUInt256.Should().Be(resBigInt);

uint256a.Exp(convertFromInt(test.n), out uint256a);
uint256a.Convert(out resUInt256);

resUInt256.Should().Be(resBigInt);
}

[TestCaseSource(typeof(TernaryOps), nameof(TernaryOps.TestCases))]
Expand All @@ -300,6 +377,11 @@ public virtual void ExpMod((BigInteger A, BigInteger B, BigInteger M) test)
res.Convert(out BigInteger resUInt256);

resUInt256.Should().Be(resBigInt);

uint256a.ExpMod(uint256b, uint256m, out uint256a);
uint256a.Convert(out resUInt256);

resUInt256.Should().Be(resBigInt);
}

[TestCaseSource(typeof(BinaryOps), nameof(BinaryOps.ShiftTestCases))]
Expand All @@ -318,6 +400,11 @@ public virtual void Lsh((BigInteger A, int n) test)
res.Convert(out BigInteger resUInt256);

resUInt256.Should().Be(resBigInt);

uint256a.LeftShift(test.n, out uint256a);
uint256a.Convert(out resUInt256);

resUInt256.Should().Be(resBigInt);
}

[TestCaseSource(typeof(BinaryOps), nameof(BinaryOps.ShiftTestCases))]
Expand All @@ -336,6 +423,11 @@ public virtual void Rsh((BigInteger A, int n) test)
res.Convert(out BigInteger resUInt256);

resUInt256.Should().Be(resBigInt);

uint256a.RightShift(test.n, out uint256a);
uint256a.Convert(out resUInt256);

resUInt256.Should().Be(resBigInt);
}

[TestCaseSource(typeof(UnaryOps), nameof(UnaryOps.TestCases))]
Expand All @@ -359,6 +451,10 @@ public virtual void Not(BigInteger test)
T.Not(uint256, out T res);
res.Convert(out BigInteger resUInt256);
resUInt256.Should().Be(resBigInt);

T.Not(in uint256, out uint256);
uint256.Convert(out resUInt256);
resUInt256.Should().Be(resBigInt);
}

[TestCaseSource(typeof(UnaryOps), nameof(UnaryOps.TestCases))]
Expand Down
2 changes: 1 addition & 1 deletion src/Nethermind.Int256/Int256.cs
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,9 @@ public static void Multiply(in Int256 a, in Int256 b, out Int256 res)
b.Neg(out bv);
}
UInt256.Multiply(av._value, bv._value, out UInt256 ures);
res = new Int256(ures);
int aSign = a.Sign;
int bSign = b.Sign;
res = new Int256(ures);
if ((aSign < 0 && bSign < 0) || (aSign >= 0 && bSign >= 0))
{
return;
Expand Down
27 changes: 12 additions & 15 deletions src/Nethermind.Int256/UInt256.cs
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,7 @@ private static bool SubtractImpl(in UInt256 a, in UInt256 b, out UInt256 res)

public void Subtract(in UInt256 b, out UInt256 res) => Subtract(this, b, out res);

public static void SubtractMod(in UInt256 a, in UInt256 b, in UInt256 m, out UInt256 res)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason to remove in here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I made it intentionaly.
This will create a copy of a.
So when you use it like
UInt256.SubtractMod(a, in b, out a)
and inside the method you reach
if (SubtractUnderflow(a, b, out res))
this line will not override a (as a and res are the same) (and we need original value of a later)

I could just made a copy of a inside the method instead, but why do we need to pass it as reference and then just copy it?

Same situation you can find in
static void Exp(in UInt256 b, in UInt256 e, out UInt256 result)
where I removed a line UInt256 bs = b; and change signature to
public static void Exp(UInt256 b, in UInt256 e, out UInt256 result)

Same in ExpMod

Copy link
Author

@Abalioha Abalioha Aug 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apple M1 Pro, .NET 7.0 : .NET 7.0.5 (7.0.523.17405), Arm64 RyuJIT AdvSIMD

Method Mean Error StdDev Ratio
SubtractMod_UInt256 5.878 ns 0.0035 ns 0.0029 ns 1.00
SubtractMod_UInt256_withoutIn 5.741 ns 0.0039 ns 0.0032 ns 0.98

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And if I will change

public static void Multiply(in Int256 a, in Int256 b, out Int256 res)
{
    Int256 av = a, bv = b;
    if (a.Sign < 0)
    {
        a.Neg(out av);
    }
    if (b.Sign < 0)
    {
        b.Neg(out bv);
    }
    UInt256.Multiply(av._value, bv._value, out UInt256 ures);
    int aSign = a.Sign;
    int bSign = b.Sign;
    res = new Int256(ures);
...

to

public static void MultiplyWithoutIn(Int256 a, Int256 b, out Int256 res)
{
    int aSign = a.Sign;
    int bSign = b.Sign;

    if (aSign < 0) a.Neg(out a);
    if (bSign < 0) b.Neg(out b);
    UInt256.Multiply(a._value, b._value, out UInt256 ures);

    res = new Int256(ures);
...

(remove in and Int256 av = a, bv = b;)
this will give improvement too

Apple M1 Pro, .NET 7.0 : .NET 7.0.5 (7.0.523.17405), Arm64 RyuJIT AdvSIMD

Method EnvironmentVariables Mean Ratio
Multiply_Int256 Empty 9.861 ns 1.00
Multiply_Int256_withoutIn Empty 9.463 ns 0.96
Multiply_Int256 DOTNET_EnableHWIntrinsic=0 19.376 ns 1.96
Multiply_Int256_withoutIn DOTNET_EnableHWIntrinsic=0 18.132 ns 1.84

public static void SubtractMod(UInt256 a, in UInt256 b, in UInt256 m, out UInt256 res)
{
if (SubtractUnderflow(a, b, out res))
{
Expand Down Expand Up @@ -1024,40 +1024,38 @@ private void Squared(out UInt256 result)
result = new UInt256(res);
}

public static void Exp(in UInt256 b, in UInt256 e, out UInt256 result)
public static void Exp(UInt256 b, in UInt256 e, out UInt256 result)
{
result = One;
UInt256 bs = b;
int len = e.BitLen;
for (int i = 0; i < len; i++)
{
if (e.Bit(i))
{
Multiply(result, bs, out result);
Multiply(result, b, out result);
}
bs.Squared(out bs);
b.Squared(out b);
}
}

public void Exp(in UInt256 exp, out UInt256 res) => Exp(this, exp, out res);

public static void ExpMod(in UInt256 b, in UInt256 e, in UInt256 m, out UInt256 result)
public static void ExpMod(UInt256 b, in UInt256 e, in UInt256 m, out UInt256 result)
{
if (m.IsOne)
{
result = Zero;
return;
}
result = One;
UInt256 bs = b;
int len = e.BitLen;
for (int i = 0; i < len; i++)
{
if (e.Bit(i))
{
MultiplyMod(result, bs, m, out result);
MultiplyMod(result, b, m, out result);
}
MultiplyMod(bs, bs, m, out bs);
MultiplyMod(b, b, m, out b);
}
}

Expand Down Expand Up @@ -1193,9 +1191,10 @@ public static void Divide(in UInt256 x, in UInt256 y, out UInt256 res)
// At this point, we know
// x/y ; x > y > 0

res = default; // initialize with zeros
const int length = 4;
Udivrem(ref Unsafe.As<UInt256, ulong>(ref res), ref Unsafe.As<UInt256, ulong>(ref Unsafe.AsRef(in x)), length, y, out UInt256 _);
UInt256 quot = default;
Udivrem(ref Unsafe.As<UInt256, ulong>(ref quot), ref Unsafe.As<UInt256, ulong>(ref Unsafe.AsRef(in x)), length, y, out UInt256 _);
res = quot;
}

public void Divide(in UInt256 a, out UInt256 res) => Divide(this, a, out res);
Expand Down Expand Up @@ -1278,8 +1277,7 @@ public static void Lsh(in UInt256 x, int n, out UInt256 res)
}
}

res = Zero;
ulong z0 = res.u0, z1 = res.u1, z2 = res.u2, z3 = res.u3;
ulong z0 = 0, z1 = 0, z2 = 0, z3 = 0;
ulong a = 0, b = 0;
// Big swaps first
if (n > 192)
Expand Down Expand Up @@ -1372,8 +1370,7 @@ public static void Rsh(in UInt256 x, int n, out UInt256 res)
}
}

res = Zero;
ulong z0 = res.u0, z1 = res.u1, z2 = res.u2, z3 = res.u3;
ulong z0 = 0, z1 = 0, z2 = 0, z3 = 0;
ulong a = 0, b = 0;
// Big swaps first
if (n > 192)
Expand Down