Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove boilerplate in LDPCDecodersExt and PyQDecodersExt #452

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,23 @@ using QuantumClifford
using QuantumClifford.ECC
import QuantumClifford.ECC: AbstractSyndromeDecoder, decode, parity_checks

struct BeliefPropDecoder <: AbstractSyndromeDecoder # TODO all these decoders have the same fields, maybe we can factor out a common type
H
faults_matrix
n
s
k
cx
cz
bpdecoderx
bpdecoderz
end
abstract type AbstractLDPCDecoder <: AbstractSyndromeDecoder end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not think this extra abstract type is necessary.


struct BitFlipDecoder <: AbstractSyndromeDecoder # TODO all these decoders have the same fields, maybe we can factor out a common type
# A common structure to hold shared fields for different decoders
mutable struct GenericLDPCDecoder{D} <: AbstractLDPCDecoder
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name does not seem very appropriate to me. These are not that general, although they are a common way to do things. The basic feature here is that the decoders are built out of two classical decoders. Maybe a name like CSSDecoderFromClassical or IndependentCSSDecoders or PairOfClassicalDecoders or PairedClassicalDecoder or CSSClassicalPairDecoders. I do not really have a good suggestion right now.

H
faults_matrix
n
s
k
cx
cz
bfdecoderx
bfdecoderz
decoderx::D
decoderz::D
end

function BeliefPropDecoder(c; errorrate=nothing, maxiter=nothing)
# Common constructor for any LDPC decoder (both BitFlip and BeliefProp)
function GenericLDPCDecoder(c, DecoderType; errorrate=nothing, maxiter=nothing)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function forces a very specific form that is expected from the decoders. I am not sure that form will be true for the next decoder to be added. This is one of my main worries with this PR

Hx = parity_checks_x(c)
Hz = parity_checks_z(c)
H = parity_checks(c)
Expand All @@ -40,52 +32,29 @@ function BeliefPropDecoder(c; errorrate=nothing, maxiter=nothing)
cx = size(Hx, 1)
cz = size(Hz, 1)
fm = faults_matrix(H)

isnothing(errorrate) || 0≤errorrate≤1 || error(lazy"BeliefPropDecoder got an invalid error rate argument. `errorrate` must be in the range [0, 1].")
isnothing(errorrate) || 0 ≤ errorrate ≤ 1 || error("`errorrate` must be in the range [0, 1].")
errorrate = isnothing(errorrate) ? 0.0 : errorrate
maxiter = isnothing(maxiter) ? n : maxiter
bpx = LDPCDecoders.BeliefPropagationDecoder(Hx, errorrate, maxiter)
bpz = LDPCDecoders.BeliefPropagationDecoder(Hz, errorrate, maxiter)
decoderx = DecoderType(Hx, errorrate, maxiter)
decoderz = DecoderType(Hz, errorrate, maxiter)
return GenericLDPCDecoder{DecoderType}(H, fm, n, s, k, cx, cz, decoderx, decoderz)
end

return BeliefPropDecoder(H, fm, n, s, k, cx, cz, bpx, bpz)
function BeliefPropDecoder(c; errorrate=nothing, maxiter=nothing)
return GenericLDPCDecoder(c, LDPCDecoders.BeliefPropagationDecoder; errorrate, maxiter)
end

function BitFlipDecoder(c; errorrate=nothing, maxiter=nothing)
Hx = parity_checks_x(c)
Hz = parity_checks_z(c)
H = parity_checks(c)
s, n = size(H)
_, _, r = canonicalize!(Base.copy(H), ranks=true)
k = n - r
cx = size(Hx, 1)
cz = size(Hz, 1)
fm = faults_matrix(H)

isnothing(errorrate) || 0≤errorrate≤1 || error(lazy"BitFlipDecoder got an invalid error rate argument. `errorrate` must be in the range [0, 1].")
errorrate = isnothing(errorrate) ? 0.0 : errorrate
maxiter = isnothing(maxiter) ? n : maxiter
bfx = LDPCDecoders.BitFlipDecoder(Hx, errorrate, maxiter)
bfz = LDPCDecoders.BitFlipDecoder(Hz, errorrate, maxiter)

return BitFlipDecoder(H, fm, n, s, k, cx, cz, bfx, bfz)
return GenericLDPCDecoder(c, LDPCDecoders.BitFlipDecoder; errorrate, maxiter)
end

parity_checks(d::BeliefPropDecoder) = d.H
parity_checks(d::BitFlipDecoder) = d.H

function decode(d::BeliefPropDecoder, syndrome_sample)
row_x = @view syndrome_sample[1:d.cx]
row_z = @view syndrome_sample[d.cx+1:d.cx+d.cz]
guess_z, success = LDPCDecoders.decode!(d.bpdecoderx, row_x)
guess_x, success = LDPCDecoders.decode!(d.bpdecoderz, row_z)
return vcat(guess_x, guess_z)
end
parity_checks(d::GenericLDPCDecoder) = d.H

function decode(d::BitFlipDecoder, syndrome_sample)
function decode(d::GenericLDPCDecoder, syndrome_sample)
row_x = @view syndrome_sample[1:d.cx]
row_z = @view syndrome_sample[d.cx+1:d.cx+d.cz]
guess_z, success = LDPCDecoders.decode!(d.bfdecoderx, row_x)
guess_x, success = LDPCDecoders.decode!(d.bfdecoderz, row_z)
guess_z, _ = LDPCDecoders.decode!(d.decoderx, row_x)
guess_x, _ = LDPCDecoders.decode!(d.decoderz, row_z)
return vcat(guess_x, guess_z)
end

Expand Down
105 changes: 34 additions & 71 deletions ext/QuantumCliffordPyQDecodersExt/QuantumCliffordPyQDecodersExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,89 +8,53 @@ import QuantumClifford.ECC: AbstractSyndromeDecoder, decode, batchdecode, parity

abstract type PyBP <: AbstractSyndromeDecoder end

struct PyBeliefPropDecoder <: PyBP # TODO all these decoders have the same fields, maybe we can factor out a common type
# A common structure to hold shared fields for different decoders
mutable struct GenericLDPCDecoder{D} <: PyBP
code
H
Hx
Hz
nx
nz
faults_matrix
pyx
pyz
pyx::D
pyz::D
end

struct PyBeliefPropOSDecoder <: PyBP # TODO all these decoders have the same fields, maybe we can factor out a common type
code
H
Hx
Hz
nx
nz
faults_matrix
pyx
pyz
end

function PyBeliefPropDecoder(c; maxiter=nothing, bpmethod=nothing, errorrate=nothing)
Hx = parity_checks_x(c) |> collect # TODO should be sparse
Hz = parity_checks_z(c) |> collect # TODO should be sparse
H = parity_checks(c)
fm = faults_matrix(c)
max_iter=isnothing(maxiter) ? 0 : maxiter
bpmethod ∈ (nothing, :productsum, :minsum, :minsumlog) || error(lazy"PyBeliefPropDecoder got an unknown belief propagation method argument. `bpmethod` must be one of :productsum, :minsum, :minsumlog.")
bp_method = get(Dict(:productsum => 0, :minsum => 1, :minsumlog => 2), bpmethod, 0)
isnothing(errorrate) || 0≤errorrate≤1 || error(lazy"PyBeliefPropDecoder got an invalid error rate argument. `errorrate` must be in the range [0, 1].")
error_rate = isnothing(errorrate) ? PythonCall.Py(nothing) : errorrate
pyx = ldpc.bp_decoder(np.array(Hx); max_iter, bp_method, error_rate) # TODO should be sparse
pyz = ldpc.bp_decoder(np.array(Hz); max_iter, bp_method, error_rate) # TODO should be sparse
return PyBeliefPropDecoder(c, H, Hx, Hz, size(Hx, 1), size(Hz, 1), fm, pyx, pyz)
end

function PyBeliefPropOSDecoder(c; maxiter=nothing, bpmethod=nothing, errorrate=nothing, osdmethod=nothing, osdorder=0)
Hx = parity_checks_x(c) |> collect # TODO should be sparse
Hz = parity_checks_z(c) |> collect # TODO should be sparse
# Common function to initialize PyBeliefPropDecoder or PyBeliefPropOSDecoder
function initialize_decoder(c, maxiter, bpmethod, errorrate, osdmethod, osdorder, decoder_type)
Hx = parity_checks_x(c) |> collect # TODO keep these sparse
Hz = parity_checks_z(c) |> collect # TODO keep these sparse
H = parity_checks(c)
fm = faults_matrix(c)
max_iter=isnothing(maxiter) ? 0 : maxiter
bpmethod ∈ (nothing, :productsum, :minsum, :minsumlog) || error(lazy"PyBeliefPropDecoder got an unknown belief propagation method argument. `bpmethod` must be one of :productsum, :minsum, :minsumlog.")
max_iter = isnothing(maxiter) ? 0 : maxiter
bpmethod ∈ (nothing, :productsum, :minsum, :minsumlog) || error("Unknown bpmethod")
bp_method = get(Dict(:productsum => 0, :minsum => 1, :minsumlog => 2), bpmethod, 0)
isnothing(errorrate) || 0≤errorrate≤1 || error(lazy"PyBeliefPropDecoder got an invalid error rate argument. `errorrate` must be in the range [0, 1].")
error_rate = isnothing(errorrate) ? PythonCall.Py(nothing) : errorrate
isnothing(osdmethod) || osdmethod ∈ (:zeroorder, :exhaustive, :combinationsweep) || error(lazy"PyBeliefPropOSDecoder got an unknown OSD method argument. `osdmethod` must be one of :zeroorder, :exhaustive, :combinationsweep.")
osd_method = get(Dict(:zeroorder => "osd0", :exhaustive => "osde", :combinationsweep => "osdcs"), osdmethod, 0)
0≤osdorder || error(lazy"PyBeliefPropOSDecoder got an invalid OSD order argument. `osdorder` must be ≥0.")
osd_order = osdorder
pyx = ldpc.bposd_decoder(np.array(Hx); max_iter, bp_method, error_rate, osd_method, osd_order) # TODO should be sparse
pyz = ldpc.bposd_decoder(np.array(Hz); max_iter, bp_method, error_rate, osd_method, osd_order) # TODO should be sparse
return PyBeliefPropOSDecoder(c, H, Hx, Hz, size(Hx, 1), size(Hz, 1), fm, pyx, pyz)
osd_method = isnothing(osdmethod) ? "osd0" : osdmethod
if decoder_type == :beliefprop
pyx = ldpc.bp_decoder(np.array(Hx); max_iter, bp_method, error_rate) # TODO keep these sparse
pyz = ldpc.bp_decoder(np.array(Hz); max_iter, bp_method, error_rate) # TODO keep these sparse
elseif decoder_type == :beliefprop_os
pyx = ldpc.bposd_decoder(np.array(Hx); max_iter, bp_method, error_rate, osd_method, osdorder)
pyz = ldpc.bposd_decoder(np.array(Hz); max_iter, bp_method, error_rate, osd_method, osdorder)
else
error("Unknown decoder type.")
end
Comment on lines +35 to +43
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is also a major concern. The previous approach was extendable. This if-else chain is not. Adding a new decoder would require a modification of this function (so it can not be done in an independent package), which is the main problem that Julia's multiple dispatch is meant to solve.

return GenericPyLDPCDecoder(c, H, Hx, Hz, size(Hx, 1), size(Hz, 1), fm, pyx, pyz)
end

parity_checks(d::PyBP) = d.H

function decode(d::PyBP, syndrome_sample)
row_x = @view syndrome_sample[1:d.nx]
row_z = @view syndrome_sample[d.nx+1:end]
guess_z_errors = PythonCall.PyArray(d.pyx.decode(np.array(row_x)))
guess_x_errors = PythonCall.PyArray(d.pyz.decode(np.array(row_z)))
vcat(guess_x_errors, guess_z_errors)
function PyBeliefPropDecoder(c; maxiter=nothing, bpmethod=nothing, errorrate=nothing)
return initialize_decoder(c, maxiter, bpmethod, errorrate, nothing, 0, :beliefprop)
end

struct PyMatchingDecoder <: AbstractSyndromeDecoder # TODO all these decoders have the same fields, maybe we can factor out a common type
code
H
Hx
Hz
nx
nz
faults_matrix
pyx
pyz
function PyBeliefPropOSDecoder(c; maxiter=nothing, bpmethod=nothing, errorrate=nothing, osdmethod=nothing, osdorder=0)
return initialize_decoder(c, maxiter, bpmethod, errorrate, osdmethod, osdorder, :beliefprop_os)
end

function PyMatchingDecoder(c; weights=nothing)
Hx = parity_checks_x(c) |> collect # TODO keep these sparse
Hz = parity_checks_z(c) |> collect
Hz = parity_checks_z(c) |> collect # TODO keep these sparse
H = parity_checks(c)
fm = faults_matrix(c)
if isnothing(weights)
Expand All @@ -100,26 +64,25 @@ function PyMatchingDecoder(c; weights=nothing)
pyx = pm.Matching.from_check_matrix(Hx, weights=weights)
pyz = pm.Matching.from_check_matrix(Hz, weights=weights)
end
return PyMatchingDecoder(c, H, Hx, Hz, size(Hx, 1), size(Hz, 1), fm, pyx, pyz)
return GenericPyLDPCDecoder(c, H, Hx, Hz, size(Hx, 1), size(Hz, 1), fm, pyx, pyz)
end

parity_checks(d::PyMatchingDecoder) = d.H
parity_checks(d::GenericPyLDPCDecoder) = d.H

function decode(d::PyMatchingDecoder, syndrome_sample)
function decode(d::GenericPyLDPCDecoder, syndrome_sample)
row_x = @view syndrome_sample[1:d.nx]
row_z = @view syndrome_sample[d.nx+1:end]
guess_z_errors = PythonCall.PyArray(d.pyx.decode(row_x))
guess_x_errors = PythonCall.PyArray(d.pyz.decode(row_z))
vcat(guess_x_errors, guess_z_errors)
guess_z_errors = PythonCall.PyArray(d.pyx.decode(np.array(row_x)))
guess_x_errors = PythonCall.PyArray(d.pyz.decode(np.array(row_z)))
return vcat(guess_x_errors, guess_z_errors)
end

function batchdecode(d::PyMatchingDecoder, syndrome_samples)
function batchdecode(d::GenericPyLDPCDecoder, syndrome_samples)
row_x = @view syndrome_samples[:,1:d.nx]
row_z = @view syndrome_samples[:,d.nx+1:end]
guess_z_errors = PythonCall.PyArray(d.pyx.decode_batch(row_x))
guess_x_errors = PythonCall.PyArray(d.pyz.decode_batch(row_z))
n_cols_x = size(guess_x_errors, 2)
hcat(guess_x_errors, guess_z_errors)
return hcat(guess_x_errors, guess_z_errors)
end

end
Loading