Skip to content

Commit

Permalink
move iterative
Browse files Browse the repository at this point in the history
  • Loading branch information
nonhermitian committed Oct 8, 2024
1 parent 920907f commit 119a646
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 85 deletions.
90 changes: 90 additions & 0 deletions mthree/iterative.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# This code is part of Mthree.
#
# (C) Copyright IBM 2024.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
# pylint: disable=no-name-in-module, invalid-name
import numpy as np
import scipy.sparse.linalg as spla

from mthree.norms import ainv_onenorm_est_iter
from mthree.matvec import M3MatVec
from mthree.utils import counts_to_vector, vector_to_quasiprobs
from mthree.exceptions import M3Error


def iterative_solver(
mitigator,
counts,
qubits,
distance,
tol=1e-5,
max_iter=25,
details=0,
callback=None,
return_mitigation_overhead=False,
):
"""Compute solution using GMRES and Jacobi preconditioning.
Parameters:
counts (dict): Input counts dict.
qubits (int): Qubits over which to calibrate.
tol (float): Tolerance to use.
max_iter (int): Maximum number of iterations to perform.
distance (int): Distance to correct for. Default=num_bits
details (bool): Return col norms.
callback (callable): Callback function to record iteration count.
return_mitigation_overhead (bool): Returns the mitigation overhead, default=False.
Returns:
QuasiDistribution: dict of Quasiprobabilites
Raises:
M3Error: Solver did not converge.
"""
cals = mitigator._form_cals(qubits)
M = M3MatVec(dict(counts), cals, distance)
L = spla.LinearOperator(
(M.num_elems, M.num_elems),
matvec=M.matvec,
rmatvec=M.rmatvec,
dtype=np.float32,
)
diags = M.get_diagonal()

def precond_matvec(x):
out = x / diags
return out

P = spla.LinearOperator(
(M.num_elems, M.num_elems), precond_matvec, dtype=np.float32
)
vec = counts_to_vector(M.sorted_counts)

out, error = spla.gmres(
L,
vec,
rtol=tol,
atol=tol,
maxiter=max_iter,
M=P,
callback=callback,
callback_type="legacy",
)
if error:
raise M3Error("GMRES did not converge: {}".format(error))

gamma = None
if return_mitigation_overhead:
gamma = ainv_onenorm_est_iter(M, tol=tol, max_iter=max_iter)

quasi = vector_to_quasiprobs(out, M.sorted_counts)
if details:
return quasi, M.get_col_norms(), gamma
return quasi, gamma
80 changes: 4 additions & 76 deletions mthree/mitigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import psutil
import numpy as np
import scipy.sparse.linalg as spla
import orjson
from qiskit.providers import BackendV2
from qiskit_ibm_runtime import SamplerV2
Expand All @@ -33,11 +32,10 @@
balanced_cal_strings,
balanced_cal_circuits,
)
from mthree.utils import counts_to_vector, vector_to_quasiprobs, gmres
from mthree.direct import direct_solver as direct_solve
from mthree.direct import reduced_cal_matrix as cal_matrix
from mthree.norms import ainv_onenorm_est_iter
from mthree.matvec import M3MatVec
from mthree.iterative import iterative_solver

from mthree.exceptions import M3Error
from mthree.classes import QuasiCollection
from ._helpers import system_info
Expand Down Expand Up @@ -639,7 +637,7 @@ def callback(_):

if details:
st = perf_counter()
mit_counts, col_norms, gamma = self._matvec_solver(
mit_counts, col_norms, gamma = iterative_solver(self,
counts,
qubits,
distance,
Expand All @@ -658,7 +656,7 @@ def callback(_):
info["col_norms"] = col_norms
return mit_counts, info
# pylint: disable=unbalanced-tuple-unpacking
mit_counts, gamma = self._matvec_solver(
mit_counts, gamma = iterative_solver(self,
counts,
qubits,
distance,
Expand Down Expand Up @@ -696,76 +694,6 @@ def reduced_cal_matrix(self, counts, qubits, distance=None):
return cal_matrix(self, counts, qubits, distance)


def _matvec_solver(
self,
counts,
qubits,
distance,
tol=1e-5,
max_iter=25,
details=0,
callback=None,
return_mitigation_overhead=False,
):
"""Compute solution using GMRES and Jacobi preconditioning.
Parameters:
counts (dict): Input counts dict.
qubits (int): Qubits over which to calibrate.
tol (float): Tolerance to use.
max_iter (int): Maximum number of iterations to perform.
distance (int): Distance to correct for. Default=num_bits
details (bool): Return col norms.
callback (callable): Callback function to record iteration count.
return_mitigation_overhead (bool): Returns the mitigation overhead, default=False.
Returns:
QuasiDistribution: dict of Quasiprobabilites
Raises:
M3Error: Solver did not converge.
"""
cals = self._form_cals(qubits)
M = M3MatVec(dict(counts), cals, distance)
L = spla.LinearOperator(
(M.num_elems, M.num_elems),
matvec=M.matvec,
rmatvec=M.rmatvec,
dtype=np.float32,
)
diags = M.get_diagonal()

def precond_matvec(x):
out = x / diags
return out

P = spla.LinearOperator(
(M.num_elems, M.num_elems), precond_matvec, dtype=np.float32
)
vec = counts_to_vector(M.sorted_counts)

out, error = gmres(
L,
vec,
rtol=tol,
atol=tol,
maxiter=max_iter,
M=P,
callback=callback,
callback_type="legacy",
)
if error:
raise M3Error("GMRES did not converge: {}".format(error))

gamma = None
if return_mitigation_overhead:
gamma = ainv_onenorm_est_iter(M, tol=tol, max_iter=max_iter)

quasi = vector_to_quasiprobs(out, M.sorted_counts)
if details:
return quasi, M.get_col_norms(), gamma
return quasi, gamma

def readout_fidelity(self, qubits=None):
"""Compute readout fidelity for calibrated qubits.
Expand Down
11 changes: 5 additions & 6 deletions mthree/norms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import scipy.sparse.linalg as spla

from mthree.exceptions import M3Error
from mthree.utils import gmres


def ainv_onenorm_est_lu(A, LU=None):
Expand Down Expand Up @@ -140,12 +139,12 @@ def precond_matvec(x):
v = (1.0 / dims) * np.ones(dims, dtype=np.float32)

# Initial solve
v, error = gmres(L, v, rtol=tol, atol=tol, maxiter=max_iter, M=P)
v, error = spla.gmres(L, v, rtol=tol, atol=tol, maxiter=max_iter, M=P)
if error:
raise M3Error("Iterative solver error {}".format(error))
gamma = la.norm(v, 1)
eta = np.sign(v)
x, error = gmres(LT, eta, rtol=tol, atol=tol, maxiter=max_iter, M=P)
x, error = spla.gmres(LT, eta, rtol=tol, atol=tol, maxiter=max_iter, M=P)
if error:
raise M3Error("Iterative solver error {}".format(error))
# loop over reasonable number of trials
Expand All @@ -155,7 +154,7 @@ def precond_matvec(x):
idx = np.where(np.abs(x) == x_nrm)[0][0]
v = np.zeros(dims, dtype=np.float32)
v[idx] = 1
v, error = gmres(L, v, rtol=tol, atol=tol, maxiter=max_iter, M=P)
v, error = spla.gmres(L, v, rtol=tol, atol=tol, maxiter=max_iter, M=P)
if error:
raise M3Error("Iterative solver error {}".format(error))
gamma_prime = gamma
Expand All @@ -165,7 +164,7 @@ def precond_matvec(x):
break

eta = np.sign(v)
x, error = gmres(LT, eta, rtol=tol, atol=tol, maxiter=max_iter, M=P)
x, error = spla.gmres(LT, eta, rtol=tol, atol=tol, maxiter=max_iter, M=P)
if error:
raise M3Error("Iterative solver error {}".format(error))
if la.norm(x, np.inf) == x[idx]:
Expand All @@ -176,7 +175,7 @@ def precond_matvec(x):
x = np.arange(1, dims + 1, dtype=np.float32)
x = (-1) ** (x + 1) * (1 + (x - 1) / (dims - 1))

x, error = gmres(L, x, rtol=tol, atol=tol, maxiter=max_iter, M=P)
x, error = spla.gmres(L, x, rtol=tol, atol=tol, maxiter=max_iter, M=P)
if error:
raise M3Error("Iterative solver error {}".format(error))

Expand Down
3 changes: 0 additions & 3 deletions mthree/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
"""
import numpy as np
import scipy.sparse.linalg as spla

from qiskit.result import marginal_distribution as marg_dist
from mthree.exceptions import M3Error
Expand All @@ -36,8 +35,6 @@
ProbCollection,
)

gmres = spla.gmres


def final_measurement_mapping(circuit):
"""Return the final measurement mapping for the circuit.
Expand Down

0 comments on commit 119a646

Please sign in to comment.