-
Notifications
You must be signed in to change notification settings - Fork 52
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
remove boilerplate in LDPCDecodersExt
and PyQDecodersExt
#452
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,31 +6,23 @@ using QuantumClifford | |
using QuantumClifford.ECC | ||
import QuantumClifford.ECC: AbstractSyndromeDecoder, decode, parity_checks | ||
|
||
struct BeliefPropDecoder <: AbstractSyndromeDecoder # TODO all these decoders have the same fields, maybe we can factor out a common type | ||
H | ||
faults_matrix | ||
n | ||
s | ||
k | ||
cx | ||
cz | ||
bpdecoderx | ||
bpdecoderz | ||
end | ||
abstract type AbstractLDPCDecoder <: AbstractSyndromeDecoder end | ||
|
||
struct BitFlipDecoder <: AbstractSyndromeDecoder # TODO all these decoders have the same fields, maybe we can factor out a common type | ||
# A common structure to hold shared fields for different decoders | ||
mutable struct GenericLDPCDecoder{D} <: AbstractLDPCDecoder | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The name does not seem very appropriate to me. These are not that general, although they are a common way to do things. The basic feature here is that the decoders are built out of two classical decoders. Maybe a name like |
||
H | ||
faults_matrix | ||
n | ||
s | ||
k | ||
cx | ||
cz | ||
bfdecoderx | ||
bfdecoderz | ||
decoderx::D | ||
decoderz::D | ||
end | ||
|
||
function BeliefPropDecoder(c; errorrate=nothing, maxiter=nothing) | ||
# Common constructor for any LDPC decoder (both BitFlip and BeliefProp) | ||
function GenericLDPCDecoder(c, DecoderType; errorrate=nothing, maxiter=nothing) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function forces a very specific form that is expected from the decoders. I am not sure that form will be true for the next decoder to be added. This is one of my main worries with this PR |
||
Hx = parity_checks_x(c) | ||
Hz = parity_checks_z(c) | ||
H = parity_checks(c) | ||
|
@@ -40,52 +32,29 @@ function BeliefPropDecoder(c; errorrate=nothing, maxiter=nothing) | |
cx = size(Hx, 1) | ||
cz = size(Hz, 1) | ||
fm = faults_matrix(H) | ||
|
||
isnothing(errorrate) || 0≤errorrate≤1 || error(lazy"BeliefPropDecoder got an invalid error rate argument. `errorrate` must be in the range [0, 1].") | ||
isnothing(errorrate) || 0 ≤ errorrate ≤ 1 || error("`errorrate` must be in the range [0, 1].") | ||
errorrate = isnothing(errorrate) ? 0.0 : errorrate | ||
maxiter = isnothing(maxiter) ? n : maxiter | ||
bpx = LDPCDecoders.BeliefPropagationDecoder(Hx, errorrate, maxiter) | ||
bpz = LDPCDecoders.BeliefPropagationDecoder(Hz, errorrate, maxiter) | ||
decoderx = DecoderType(Hx, errorrate, maxiter) | ||
decoderz = DecoderType(Hz, errorrate, maxiter) | ||
return GenericLDPCDecoder{DecoderType}(H, fm, n, s, k, cx, cz, decoderx, decoderz) | ||
end | ||
|
||
return BeliefPropDecoder(H, fm, n, s, k, cx, cz, bpx, bpz) | ||
function BeliefPropDecoder(c; errorrate=nothing, maxiter=nothing) | ||
return GenericLDPCDecoder(c, LDPCDecoders.BeliefPropagationDecoder; errorrate, maxiter) | ||
end | ||
|
||
function BitFlipDecoder(c; errorrate=nothing, maxiter=nothing) | ||
Hx = parity_checks_x(c) | ||
Hz = parity_checks_z(c) | ||
H = parity_checks(c) | ||
s, n = size(H) | ||
_, _, r = canonicalize!(Base.copy(H), ranks=true) | ||
k = n - r | ||
cx = size(Hx, 1) | ||
cz = size(Hz, 1) | ||
fm = faults_matrix(H) | ||
|
||
isnothing(errorrate) || 0≤errorrate≤1 || error(lazy"BitFlipDecoder got an invalid error rate argument. `errorrate` must be in the range [0, 1].") | ||
errorrate = isnothing(errorrate) ? 0.0 : errorrate | ||
maxiter = isnothing(maxiter) ? n : maxiter | ||
bfx = LDPCDecoders.BitFlipDecoder(Hx, errorrate, maxiter) | ||
bfz = LDPCDecoders.BitFlipDecoder(Hz, errorrate, maxiter) | ||
|
||
return BitFlipDecoder(H, fm, n, s, k, cx, cz, bfx, bfz) | ||
return GenericLDPCDecoder(c, LDPCDecoders.BitFlipDecoder; errorrate, maxiter) | ||
end | ||
|
||
parity_checks(d::BeliefPropDecoder) = d.H | ||
parity_checks(d::BitFlipDecoder) = d.H | ||
|
||
function decode(d::BeliefPropDecoder, syndrome_sample) | ||
row_x = @view syndrome_sample[1:d.cx] | ||
row_z = @view syndrome_sample[d.cx+1:d.cx+d.cz] | ||
guess_z, success = LDPCDecoders.decode!(d.bpdecoderx, row_x) | ||
guess_x, success = LDPCDecoders.decode!(d.bpdecoderz, row_z) | ||
return vcat(guess_x, guess_z) | ||
end | ||
parity_checks(d::GenericLDPCDecoder) = d.H | ||
|
||
function decode(d::BitFlipDecoder, syndrome_sample) | ||
function decode(d::GenericLDPCDecoder, syndrome_sample) | ||
row_x = @view syndrome_sample[1:d.cx] | ||
row_z = @view syndrome_sample[d.cx+1:d.cx+d.cz] | ||
guess_z, success = LDPCDecoders.decode!(d.bfdecoderx, row_x) | ||
guess_x, success = LDPCDecoders.decode!(d.bfdecoderz, row_z) | ||
guess_z, _ = LDPCDecoders.decode!(d.decoderx, row_x) | ||
guess_x, _ = LDPCDecoders.decode!(d.decoderz, row_z) | ||
return vcat(guess_x, guess_z) | ||
end | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,89 +8,53 @@ import QuantumClifford.ECC: AbstractSyndromeDecoder, decode, batchdecode, parity | |
|
||
abstract type PyBP <: AbstractSyndromeDecoder end | ||
|
||
struct PyBeliefPropDecoder <: PyBP # TODO all these decoders have the same fields, maybe we can factor out a common type | ||
# A common structure to hold shared fields for different decoders | ||
mutable struct GenericLDPCDecoder{D} <: PyBP | ||
code | ||
H | ||
Hx | ||
Hz | ||
nx | ||
nz | ||
faults_matrix | ||
pyx | ||
pyz | ||
pyx::D | ||
pyz::D | ||
end | ||
|
||
struct PyBeliefPropOSDecoder <: PyBP # TODO all these decoders have the same fields, maybe we can factor out a common type | ||
code | ||
H | ||
Hx | ||
Hz | ||
nx | ||
nz | ||
faults_matrix | ||
pyx | ||
pyz | ||
end | ||
|
||
function PyBeliefPropDecoder(c; maxiter=nothing, bpmethod=nothing, errorrate=nothing) | ||
Hx = parity_checks_x(c) |> collect # TODO should be sparse | ||
Hz = parity_checks_z(c) |> collect # TODO should be sparse | ||
H = parity_checks(c) | ||
fm = faults_matrix(c) | ||
max_iter=isnothing(maxiter) ? 0 : maxiter | ||
bpmethod ∈ (nothing, :productsum, :minsum, :minsumlog) || error(lazy"PyBeliefPropDecoder got an unknown belief propagation method argument. `bpmethod` must be one of :productsum, :minsum, :minsumlog.") | ||
bp_method = get(Dict(:productsum => 0, :minsum => 1, :minsumlog => 2), bpmethod, 0) | ||
isnothing(errorrate) || 0≤errorrate≤1 || error(lazy"PyBeliefPropDecoder got an invalid error rate argument. `errorrate` must be in the range [0, 1].") | ||
error_rate = isnothing(errorrate) ? PythonCall.Py(nothing) : errorrate | ||
pyx = ldpc.bp_decoder(np.array(Hx); max_iter, bp_method, error_rate) # TODO should be sparse | ||
pyz = ldpc.bp_decoder(np.array(Hz); max_iter, bp_method, error_rate) # TODO should be sparse | ||
return PyBeliefPropDecoder(c, H, Hx, Hz, size(Hx, 1), size(Hz, 1), fm, pyx, pyz) | ||
end | ||
|
||
function PyBeliefPropOSDecoder(c; maxiter=nothing, bpmethod=nothing, errorrate=nothing, osdmethod=nothing, osdorder=0) | ||
Hx = parity_checks_x(c) |> collect # TODO should be sparse | ||
Hz = parity_checks_z(c) |> collect # TODO should be sparse | ||
# Common function to initialize PyBeliefPropDecoder or PyBeliefPropOSDecoder | ||
function initialize_decoder(c, maxiter, bpmethod, errorrate, osdmethod, osdorder, decoder_type) | ||
Hx = parity_checks_x(c) |> collect # TODO keep these sparse | ||
Hz = parity_checks_z(c) |> collect # TODO keep these sparse | ||
H = parity_checks(c) | ||
fm = faults_matrix(c) | ||
max_iter=isnothing(maxiter) ? 0 : maxiter | ||
bpmethod ∈ (nothing, :productsum, :minsum, :minsumlog) || error(lazy"PyBeliefPropDecoder got an unknown belief propagation method argument. `bpmethod` must be one of :productsum, :minsum, :minsumlog.") | ||
max_iter = isnothing(maxiter) ? 0 : maxiter | ||
bpmethod ∈ (nothing, :productsum, :minsum, :minsumlog) || error("Unknown bpmethod") | ||
bp_method = get(Dict(:productsum => 0, :minsum => 1, :minsumlog => 2), bpmethod, 0) | ||
isnothing(errorrate) || 0≤errorrate≤1 || error(lazy"PyBeliefPropDecoder got an invalid error rate argument. `errorrate` must be in the range [0, 1].") | ||
error_rate = isnothing(errorrate) ? PythonCall.Py(nothing) : errorrate | ||
isnothing(osdmethod) || osdmethod ∈ (:zeroorder, :exhaustive, :combinationsweep) || error(lazy"PyBeliefPropOSDecoder got an unknown OSD method argument. `osdmethod` must be one of :zeroorder, :exhaustive, :combinationsweep.") | ||
osd_method = get(Dict(:zeroorder => "osd0", :exhaustive => "osde", :combinationsweep => "osdcs"), osdmethod, 0) | ||
0≤osdorder || error(lazy"PyBeliefPropOSDecoder got an invalid OSD order argument. `osdorder` must be ≥0.") | ||
osd_order = osdorder | ||
pyx = ldpc.bposd_decoder(np.array(Hx); max_iter, bp_method, error_rate, osd_method, osd_order) # TODO should be sparse | ||
pyz = ldpc.bposd_decoder(np.array(Hz); max_iter, bp_method, error_rate, osd_method, osd_order) # TODO should be sparse | ||
return PyBeliefPropOSDecoder(c, H, Hx, Hz, size(Hx, 1), size(Hz, 1), fm, pyx, pyz) | ||
osd_method = isnothing(osdmethod) ? "osd0" : osdmethod | ||
if decoder_type == :beliefprop | ||
pyx = ldpc.bp_decoder(np.array(Hx); max_iter, bp_method, error_rate) # TODO keep these sparse | ||
pyz = ldpc.bp_decoder(np.array(Hz); max_iter, bp_method, error_rate) # TODO keep these sparse | ||
elseif decoder_type == :beliefprop_os | ||
pyx = ldpc.bposd_decoder(np.array(Hx); max_iter, bp_method, error_rate, osd_method, osdorder) | ||
pyz = ldpc.bposd_decoder(np.array(Hz); max_iter, bp_method, error_rate, osd_method, osdorder) | ||
else | ||
error("Unknown decoder type.") | ||
end | ||
Comment on lines
+35
to
+43
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is also a major concern. The previous approach was extendable. This if-else chain is not. Adding a new decoder would require a modification of this function (so it can not be done in an independent package), which is the main problem that Julia's multiple dispatch is meant to solve. |
||
return GenericPyLDPCDecoder(c, H, Hx, Hz, size(Hx, 1), size(Hz, 1), fm, pyx, pyz) | ||
end | ||
|
||
parity_checks(d::PyBP) = d.H | ||
|
||
function decode(d::PyBP, syndrome_sample) | ||
row_x = @view syndrome_sample[1:d.nx] | ||
row_z = @view syndrome_sample[d.nx+1:end] | ||
guess_z_errors = PythonCall.PyArray(d.pyx.decode(np.array(row_x))) | ||
guess_x_errors = PythonCall.PyArray(d.pyz.decode(np.array(row_z))) | ||
vcat(guess_x_errors, guess_z_errors) | ||
function PyBeliefPropDecoder(c; maxiter=nothing, bpmethod=nothing, errorrate=nothing) | ||
return initialize_decoder(c, maxiter, bpmethod, errorrate, nothing, 0, :beliefprop) | ||
end | ||
|
||
struct PyMatchingDecoder <: AbstractSyndromeDecoder # TODO all these decoders have the same fields, maybe we can factor out a common type | ||
code | ||
H | ||
Hx | ||
Hz | ||
nx | ||
nz | ||
faults_matrix | ||
pyx | ||
pyz | ||
function PyBeliefPropOSDecoder(c; maxiter=nothing, bpmethod=nothing, errorrate=nothing, osdmethod=nothing, osdorder=0) | ||
return initialize_decoder(c, maxiter, bpmethod, errorrate, osdmethod, osdorder, :beliefprop_os) | ||
end | ||
|
||
function PyMatchingDecoder(c; weights=nothing) | ||
Hx = parity_checks_x(c) |> collect # TODO keep these sparse | ||
Hz = parity_checks_z(c) |> collect | ||
Hz = parity_checks_z(c) |> collect # TODO keep these sparse | ||
H = parity_checks(c) | ||
fm = faults_matrix(c) | ||
if isnothing(weights) | ||
|
@@ -100,26 +64,25 @@ function PyMatchingDecoder(c; weights=nothing) | |
pyx = pm.Matching.from_check_matrix(Hx, weights=weights) | ||
pyz = pm.Matching.from_check_matrix(Hz, weights=weights) | ||
end | ||
return PyMatchingDecoder(c, H, Hx, Hz, size(Hx, 1), size(Hz, 1), fm, pyx, pyz) | ||
return GenericPyLDPCDecoder(c, H, Hx, Hz, size(Hx, 1), size(Hz, 1), fm, pyx, pyz) | ||
end | ||
|
||
parity_checks(d::PyMatchingDecoder) = d.H | ||
parity_checks(d::GenericPyLDPCDecoder) = d.H | ||
|
||
function decode(d::PyMatchingDecoder, syndrome_sample) | ||
function decode(d::GenericPyLDPCDecoder, syndrome_sample) | ||
row_x = @view syndrome_sample[1:d.nx] | ||
row_z = @view syndrome_sample[d.nx+1:end] | ||
guess_z_errors = PythonCall.PyArray(d.pyx.decode(row_x)) | ||
guess_x_errors = PythonCall.PyArray(d.pyz.decode(row_z)) | ||
vcat(guess_x_errors, guess_z_errors) | ||
guess_z_errors = PythonCall.PyArray(d.pyx.decode(np.array(row_x))) | ||
guess_x_errors = PythonCall.PyArray(d.pyz.decode(np.array(row_z))) | ||
return vcat(guess_x_errors, guess_z_errors) | ||
end | ||
|
||
function batchdecode(d::PyMatchingDecoder, syndrome_samples) | ||
function batchdecode(d::GenericPyLDPCDecoder, syndrome_samples) | ||
row_x = @view syndrome_samples[:,1:d.nx] | ||
row_z = @view syndrome_samples[:,d.nx+1:end] | ||
guess_z_errors = PythonCall.PyArray(d.pyx.decode_batch(row_x)) | ||
guess_x_errors = PythonCall.PyArray(d.pyz.decode_batch(row_z)) | ||
n_cols_x = size(guess_x_errors, 2) | ||
hcat(guess_x_errors, guess_z_errors) | ||
return hcat(guess_x_errors, guess_z_errors) | ||
end | ||
|
||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do not think this extra abstract type is necessary.