Skip to content

Commit

Permalink
Remove OncvPlotter.plotly_ methods, use standard mpl methods with obj…
Browse files Browse the repository at this point in the history
….plot(plotly=True)
  • Loading branch information
gmatteo committed Jul 23, 2024
1 parent f080687 commit 090fca1
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 90 deletions.
4 changes: 3 additions & 1 deletion abipy/core/func1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,14 +470,15 @@ def ifft(self, x0=None) -> Function1D:

# return self.__class__(self.mesh, -(2 / np.pi) * wmesh * kk_values)

def plot_ax(self, ax, exchange_xy=False, xfactor=1, yfactor=1, *args, **kwargs) -> list:
def plot_ax(self, ax, exchange_xy=False, normalize=False, xfactor=1, yfactor=1, *args, **kwargs) -> list:
"""
Helper function to plot self on axis ax.
Args:
ax: |matplotlib-Axes|.
exchange_xy: True to exchange the axis in the plot.
args: Positional arguments passed to ax.plot
normalize: Normalize the ydata to 1.
xfactor, yfactor: xvalues and yvalues are multiplied by this factor before plotting.
kwargs: Keyword arguments passed to ``matplotlib``. Accepts
Expand All @@ -504,6 +505,7 @@ def plot_ax(self, ax, exchange_xy=False, xfactor=1, yfactor=1, *args, **kwargs)
xx, yy = self.mesh, data_from_cplx_mode(c, self.values)
if xfactor != 1: xx = xx * xfactor
if yfactor != 1: yy = yy * yfactor
if normalize: yy = np.max(yy)

if exchange_xy:
xx, yy = yy, xx
Expand Down
74 changes: 40 additions & 34 deletions abipy/eph/varpeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(self, gsr_kpath, ddb, verbose = 0, **anaddb_kwargs):
#def analyze(self)


ITER_LABELS = [
_LABELS = [
r'$E_{pol}$',
r'$E_{el}$',
r'$E_{ph}$',
Expand Down Expand Up @@ -185,7 +185,7 @@ def get_last_iteration_dict_ev(self, spin: int) -> dict:
nstep2cv = nstep2cv_spin[spin]
last_iteration = iter_rec_spin[spin, nstep2cv-1, :] * abu.Ha_eV

return dict(zip(ITER_LABELS, last_iteration))
return dict(zip(_LABELS, last_iteration))

@add_fig_kwargs
def plot_scf_cycle(self, ax_mat=None, fontsize=8, **kwargs) -> Figure:
Expand All @@ -211,7 +211,7 @@ def plot_scf_cycle(self, ax_mat=None, fontsize=8, **kwargs) -> Figure:
xs = np.arange(1, nstep2cv + 1)

for iax, ax in enumerate(ax_mat[spin]):
for ilab, label in enumerate(ITER_LABELS):
for ilab, label in enumerate(_LABELS):
ys = iterations[:,ilab]
if iax == 0:
# Plot energies in linear scale.
Expand Down Expand Up @@ -389,15 +389,16 @@ def plot_bz_sampling(self, what="kpoints", fold=False,
return self.structure.plot_bz(show=False, **kws)

@add_fig_kwargs
def plot_ank_with_ebands(self, ebands_kpath, ebands_kmesh=None, nksmall: int = 20,
def plot_ank_with_ebands(self, ebands_kpath, ebands_kmesh=None, nksmall: int = 20, normalize: bool = False,
lpratio: int = 5, step: float = 0.1, width: float = 0.2, method: str = "linear",
ax_list=None, ylims=None, scale=10, fontsize=8, **kwargs) -> Figure:
ax_list=None, ylims=None, scale=10, fontsize=12, **kwargs) -> Figure:
"""
Plot electronic energies with markers whose size is proportional to |A_nk|^2.
Args:
ebands_kpath: ElectronBands or Abipy file providing an electronic band structure along a path.
ebands_kmesh: ElectronBands or Abipy file providing an electronic band structure in the IBZ.
normalize: Rescale the two DOS to plot them on the same scale.
method=Interpolation method.
ax_list: List of |matplotlib-Axes| or None if a new figure should be created.
scale: Scaling factor for |A_nk|^2.
Expand Down Expand Up @@ -427,11 +428,12 @@ def plot_ank_with_ebands(self, ebands_kpath, ebands_kmesh=None, nksmall: int = 2
ymin = min(ymin, e)
ymax = max(ymax, e)

points = Marker(x, y, s, color="y")
points = Marker(x, y, s, color="orange")

nrows, ncols = 1, 2
gridspec_kw = {'width_ratios': [2, 1]}
ax_list, fig, plt = get_axarray_fig_plt(ax_list, nrows=nrows, ncols=ncols,
sharex=False, sharey=True, squeeze=False)
sharex=False, sharey=True, squeeze=False, gridspec_kw=gridspec_kw)
ax_list = ax_list.ravel()

ax = ax_list[0]
Expand Down Expand Up @@ -462,22 +464,20 @@ def plot_ank_with_ebands(self, ebands_kpath, ebands_kmesh=None, nksmall: int = 2
enes_n = ebands_kmesh.eigens[self.spin, ik, self.bstart:self.bstop]
a2_n = a2_interp.eval_kpoint(kpoint)
for e, a2 in zip(enes_n, a2_n):
ank_dos += weight * a2 * gaussian(mesh, width, center=e - e0)
ank_dos += weight * a2 * gaussian(mesh, width, center=e-e0)

ank_dos = Function1D(mesh, ank_dos)
print("A2(E) integrates to:", ank_dos.integral_value, " Ideally, it should be 1.")

# Rescale the two DOS to plot them on the same scale.
ank_dos = ank_dos / ank_dos.max

ax = ax_list[1]
edos.plot_ax(ax, e0, spin=self.spin, normalize=True, exchange_xy=True, label="eDOS(E)")
edos.plot_ax(ax, e0, spin=self.spin, normalize=normalize, exchange_xy=True, label="eDOS(E)")
ank_dos.plot_ax(ax, exchange_xy=True, normalize=normalize, label=r"$A^2$(E)", color=points.color)
ax.set_xlabel("Arbitrary units", fontsize=fontsize)
ank_dos.plot_ax(ax, exchange_xy=True, label=r"$A^2$(E)", color=points.color)
ax.grid(True)
ax.legend(loc="best", shadow=True, fontsize=fontsize)

if ylims is None:
# Automatic ylims.
ymin -= 0.1 * abs(ymin)
ymin -= e0
ymax += 0.1 * abs(ymax)
Expand Down Expand Up @@ -512,26 +512,30 @@ def plot_bqnu_with_ddb(self, ddb, with_phdos=True, anaget_kwargs=None, **kwargs)
**kwargs)

@add_fig_kwargs
def plot_bqnu_with_phbands(self, phbands_qpath, phdos_file=None, ddb=None, width = 0.001,
def plot_bqnu_with_phbands(self, phbands_qpath, phdos_file=None, ddb=None, width = 0.001, normalize: bool=False,
method="linear", verbose=0, anaddb_kwargs=None,
ax=None, scale=10, fontsize=12, **kwargs) -> Figure:
"""
Plot phonon energies with markers whose size is proportional to |B_qnu|^2.
Args:
phbands_qpath: PhononBands or Abipy file providing a phonon band structure.
normalize: Rescale the two DOS to plot them on the same scale.
phdos_file:
method=Interpolation method.
ax: |matplotlib-Axes| or None if a new figure should be created.
scale: Scaling factor for |B_qnu|^2.
"""
with_phdos = phdos_file is not None and ddb is not None
nrows, ncols = 1, 2 if with_phdos else 1
nrows, ncols = 1, 1
gridspec_kw = None
if with_phdos:
ncols, gridspec_kw = 2, {'width_ratios': [2, 1]}

ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
sharex=False, sharey=True, squeeze=False)
sharex=False, sharey=True, squeeze=False, gridspec_kw=gridspec_kw)
ax_list = ax_list.ravel()


phbands_qpath = PhononBands.as_phbands(phbands_qpath)
b2_interp = self.get_b2_interpolator(method)

Expand All @@ -549,8 +553,7 @@ def plot_bqnu_with_phbands(self, phbands_qpath, phdos_file=None, ddb=None, width
for w, b2 in zip(omegas_nu, b2_nu):
x.append(iq); y.append(w); s.append(scale * b2)

points = Marker(x, y, s, color="yellow")

points = Marker(x, y, s, color="orange")
phbands_qpath.plot(ax=ax_list[0], points=points, show=False)

if not with_phdos:
Expand All @@ -575,13 +578,10 @@ def plot_bqnu_with_phbands(self, phbands_qpath, phdos_file=None, ddb=None, width

bqnu_dos = Function1D(mesh, bqnu_dos)

# Rescale the two DOS to plot them on the same scale.
phdos = phdos / phdos.max
bqnu_dos = bqnu_dos / bqnu_dos.max

ax = ax_list[1]
phdos.plot_ax(ax, exchange_xy=True, label="phDOS(E)")
bqnu_dos.plot_ax(ax, exchange_xy=True, label=r"$B^2$(E)", color=points.color)
phdos.plot_ax(ax, exchange_xy=True, normalize=normalize, label="phDOS(E)")
bqnu_dos.plot_ax(ax, exchange_xy=True, normalize=normalize, label=r"$B^2$(E)", color=points.color)

ax.set_xlabel("Arbitrary units", fontsize=fontsize)
ax.grid(True)
ax.legend(loc="best", shadow=True, fontsize=fontsize)
Expand Down Expand Up @@ -771,7 +771,7 @@ def sort_func(abifile):
data = defaultdict(list)

# Now loop over the sorted files and extract the results of the final iteration.
for i, (label, abifile, nktot) in zip(labels, abifiles, nktot_list):
for i, (label, abifile, nktot) in enumerate(zip(labels, abifiles, nktot_list)):
for k, v in abifile.get_last_iteration_dict_ev(spin).items():
data[k].append(v)

Expand Down Expand Up @@ -807,30 +807,36 @@ def plot_scf_cycle(self, **kwargs) -> Figure:
return fig

@add_fig_kwargs
def plot_kdata(self, fontsie=12, **kwargs) -> Figure:
def plot_kconv(self, fontsize=12, **kwargs) -> Figure:
"""
Plot the convergence of the data wrt to the k-point sampling.
"""
nsppol = self.getattr_alleq("nsppol")

# Build grid of plots.
nrows, ncols = len(ITER_LABELS), nsppol
nrows, ncols = len(_LABELS), nsppol
ax_mat, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
sharex=True, sharey=True, squeeze=False)
deg = 1
for spin in range(nsppol):
kdata = self.get_kdata_spin(spin)
xs = kdata["minibz_vol"]
xvals = np.linspace(0, 1.1 * xs.max(), 100)
for ix, label in enumerate(ITER_LABELS):
for ix, label in enumerate(_LABELS):
ax = ax_mat[ix, spin]
color = "k"
ys = kdata[label]
# plot ab-initio points.
ax.scatter(xs, ys, color=color, marker="o")
# plot fit.
p = np.poly1d(np.polyfit(xs, ys, deg))
ax = ax_mat[ix,spin]
ax.scatter(xs, ys, marker="o")
ax.plot(xvals, p[xvals], style="k--")
ax.plot(xvals, p(xvals), color=color, ls="--")

ax.grid(True)
ax.legend(loc="best", shadow=True, fontsize=fontsize)
#ax.set_xlabel("Iteration", fontsize=fontsize)
#ax.set_ylabel("Energy (eV)" if iax == 0 else r"$|\Delta|$ Energy (eV)", fontsize=fontsize)
ax.set_ylabel(label, fontsize=fontsize)
#ax.legend(loc="right", shadow=True, fontsize=fontsize)
#print([(0, p(0)), (xs[0], ys[0]), (xs[1], ys[1])])

return fig

Expand Down
55 changes: 0 additions & 55 deletions abipy/ppcodes/oncv_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,6 @@ def _add_rc_vlines_ax(self, ax, with_lloc=False) -> None:
ax.axvline(self.parser.rc5, lw=2, color=color, ls="--")
ax._custom_rc_lines.append((self.parser.rc5, color))

def plotly_atan_logders(self, *args, **kwargs):
"""Generate plotly figure from matplotly."""
from plotly.tools import mpl_to_plotly
return mpl_to_plotly(self.plot_atan_logders(*args, show=False, **kwargs))

@add_fig_kwargs
def plot_atan_logders(self, ax=None, with_xlabel=True,
fontsize: int = 8, **kwargs) -> Figure:
Expand Down Expand Up @@ -156,11 +151,6 @@ def _get_ae_ps_wfs(self, what) -> tuple:
raise ValueError(f"Invalid value for {what=}")
return ae_wfs, ps_wfs

def plotly_radial_wfs(self, *args, **kwargs):
"""Generate plotly figure from matplotly."""
from plotly.tools import mpl_to_plotly
return mpl_to_plotly(self.plot_radial_wfs(*args, show=False, **kwargs))

@add_fig_kwargs
def plot_radial_wfs(self, ax=None, what="bound_states",
fontsize: int = 8, **kwargs) -> Figure:
Expand Down Expand Up @@ -201,11 +191,6 @@ def plot_radial_wfs(self, ax=None, what="bound_states",

return fig

def plotly_projectors(self, *args, **kwargs):
"""Generate plotly figure from matplotly."""
from plotly.tools import mpl_to_plotly
return mpl_to_plotly(self.plot_projects(*args, show=False, **kwargs))

@add_fig_kwargs
def plot_projectors(self, ax=None, fontsize: int = 8, **kwargs) -> Figure:
"""
Expand Down Expand Up @@ -235,11 +220,6 @@ def plot_projectors(self, ax=None, fontsize: int = 8, **kwargs) -> Figure:

return fig

def plotly_densities(self, *args, **kwargs):
"""Generate plotly figure from matplotly."""
from plotly.tools import mpl_to_plotly
return mpl_to_plotly(self.plot_densities(*args, show=False, **kwargs))

@add_fig_kwargs
def plot_densities(self, ax=None, timesr2=False, fontsize: int = 8, **kwargs) -> Figure:
"""
Expand All @@ -262,11 +242,6 @@ def plot_densities(self, ax=None, timesr2=False, fontsize: int = 8, **kwargs) ->
)
return fig

def plotly_der_densities(self, *args, **kwargs):
"""Generate plotly figure from matplotly."""
from plotly.tools import mpl_to_plotly
return mpl_to_plotly(self.plot_der_densities(*args, show=False, **kwargs))

@add_fig_kwargs
def plot_der_densities(self, ax=None, order=1, acc=4, fontsize=8, **kwargs) -> Figure:
"""
Expand Down Expand Up @@ -295,11 +270,6 @@ def plot_der_densities(self, ax=None, order=1, acc=4, fontsize=8, **kwargs) -> F
)
return fig

def plotly_potentials(self, *args, **kwargs):
"""Generate plotly figure from matplotly."""
from plotly.tools import mpl_to_plotly
return mpl_to_plotly(self.plot_potentials(*args, show=False, **kwargs))

@add_fig_kwargs
def plot_potentials(self, ax=None, fontsize: int = 8, **kwargs) -> Figure:
"""
Expand All @@ -322,11 +292,6 @@ def plot_potentials(self, ax=None, fontsize: int = 8, **kwargs) -> Figure:

return fig

def plotly_vtau(self, *args, **kwargs):
"""Generate plotly figure from matplotly."""
from plotly.tools import mpl_to_plotly
return mpl_to_plotly(self.plot_vtau(*args, show=False, **kwargs))

def plot_vtau(self, xscale="log", ax=None, fontsize: int = 8, **kwargs) -> Figure:
"""
Plot v_tau and v_tau(model+pseudo) potentials on axis ax.
Expand All @@ -352,11 +317,6 @@ def plot_vtau(self, xscale="log", ax=None, fontsize: int = 8, **kwargs) -> Figur

return fig

def plotly_tau(self, *args, **kwargs):
"""Generate plotly figure from matplotly."""
from plotly.tools import mpl_to_plotly
return mpl_to_plotly(self.plot_tau(*args, show=False, **kwargs))

def plot_tau(self, ax=None, yscale="log", fontsize: int = 8, **kwargs) -> Figure:
"""
Plot kinetic energy densities tauPS and tau(M+PS) on axis ax.
Expand All @@ -381,11 +341,6 @@ def plot_tau(self, ax=None, yscale="log", fontsize: int = 8, **kwargs) -> Figure

return fig

def plotly_der_potentials(self, *args, **kwargs):
"""Generate plotly figure from matplotly."""
from plotly.tools import mpl_to_plotly
return mpl_to_plotly(self.plot_der_potentials(*args, show=False, **kwargs))

@add_fig_kwargs
def plot_der_potentials(self, ax=None, order=1, acc=4, fontsize: int = 8, **kwargs) -> Figure:
"""
Expand Down Expand Up @@ -417,11 +372,6 @@ def plot_der_potentials(self, ax=None, order=1, acc=4, fontsize: int = 8, **kwar

return fig

def plotly_kene_vs_ecut(self, *args, **kwargs):
"""Generate plotly figure from matplotly."""
from plotly.tools import mpl_to_plotly
return mpl_to_plotly(self.plot_kene_vs_ecut(*args, show=False, **kwargs))

@add_fig_kwargs
def plot_kene_vs_ecut(self, ax=None, fontsize: int = 8, **kwargs) -> Figure:
"""
Expand Down Expand Up @@ -449,11 +399,6 @@ def plot_kene_vs_ecut(self, ax=None, fontsize: int = 8, **kwargs) -> Figure:

return fig

def plotly_atanlogder_econv(self, *args, **kwargs):
"""Generate plotly figure from matplotly."""
from plotly.tools import mpl_to_plotly
return mpl_to_plotly(self.plot_atan_logder_econv(*args, show=False, **kwargs))

@add_fig_kwargs
def plot_atanlogder_econv(self, ax_list=None, fontsize: int = 6, **kwargs) -> Figure:
"""
Expand Down

0 comments on commit 090fca1

Please sign in to comment.