Skip to content

Commit

Permalink
sync with abipy/develop (#311)
Browse files Browse the repository at this point in the history
  • Loading branch information
ezhique authored Feb 7, 2025
1 parent 326bef1 commit ea026f0
Showing 1 changed file with 102 additions and 19 deletions.
121 changes: 102 additions & 19 deletions abipy/eph/varpeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ def plot_scf_cycle(self, ax_mat=None, fontsize=8, **kwargs) -> Figure:
@add_fig_kwargs
def plot_ank_with_ebands(self, ebands_kpath,
ebands_kmesh=None, lpratio: int = 5, with_info = True, with_legend=True,
with_ibz_a2dos=True, method="gaussian", step: float = 0.05, width: float = 0.1,
with_ibz_a2dos=True, method="gaussian", step="auto", width="auto",
nksmall: int = 20, normalize: bool = False, with_title=True, interp_method="linear",
ax_mat=None, ylims=None, scale=50, marker_color="gold", marker_edgecolor="gray",
marker_alpha=0.5, fontsize=12, lw_bands=1.0, lw_dos=1.0,
Expand Down Expand Up @@ -623,7 +623,7 @@ def plot_ank_with_ebands(self, ebands_kpath,
x, y, s = [], [], []

a2_max = np.max(np.abs(a_data[pstate]))**2
scale *= 1./a2_max
scale *= 1. / a2_max

for ik, kpoint in enumerate(ebands_kpath.kpoints):
enes_n = ebands_kpath.eigens[self.spin, ik, self.bstart:self.bstop]
Expand Down Expand Up @@ -671,6 +671,14 @@ def plot_ank_with_ebands(self, ebands_kpath,
ebands_kmesh = ElectronBands.as_ebands(ebands_kmesh)

# Get electronic DOS from ebands_kmesh.
# Maybe it's better to set N bandwidth divisions & width factor instead of
# the step and width arguments themselves?
bandwidth = ylims[1] - ylims[0] if ylims else 1.2*(ymax - ymin)
if step == "auto":
step = bandwidth / 200
if width == "auto":
width = 2.0 * step

edos_kws = dict(method=method, step=step, width=width)
edos = ebands_kmesh.get_edos(**edos_kws)
edos_mesh = edos.spin_dos[self.spin].mesh
Expand Down Expand Up @@ -792,18 +800,21 @@ def plot_ank_with_ebands(self, ebands_kpath,
elif pkind == "electron":
fill_from, fill_to = shifted_bm, shifted_bm + filter_value

ax.axhline(fill_from, c='k', zorder=0, lw=lw_dos)
ax.axhline(fill_to, c='k', zorder=0, lw=lw_dos)
ax.fill_between(xrange, ylims[0], fill_from,
color='gray', linewidth=lw_dos, alpha=0.5, zorder=0)
color='lightgray', linewidth=0, alpha=0.5, zorder=0)
ax.fill_between(xrange, fill_to, ylims[1],
color='gray', linewidth=lw_dos, alpha=0.5, zorder=0)
color='lightgray', linewidth=0, alpha=0.5, zorder=0)

if with_title:
fig.suptitle(self.get_title(with_gaps=True))

return fig

@add_fig_kwargs
def plot_bqnu_with_ddb(self, ddb, with_phdos=True, anaddb_kwargs=None, **kwargs) -> Figure:
def plot_bqnu_with_ddb(self, ddb, smearing_ev=0.001,
with_phdos=True, anaddb_kwargs=None, **kwargs) -> Figure:
"""
High-level interface to plot phonon energies with markers whose size is proportional to |B_qnu|^2.
Similar to plot_bqnu_with_phbands but this function receives in input a DdbFile or a
Expand All @@ -817,6 +828,7 @@ def plot_bqnu_with_ddb(self, ddb, with_phdos=True, anaddb_kwargs=None, **kwargs)
"""
ddb = DdbFile.as_ddb(ddb)
anaddb_kwargs = {} if anaddb_kwargs is None else anaddb_kwargs
anaddb_kwargs["dos_method"] = f"gaussian:{smearing_ev} eV"

with ddb.anaget_phbst_and_phdos_files(**anaddb_kwargs) as g:
phbst_file, phdos_file = g[0], g[1]
Expand All @@ -826,11 +838,12 @@ def plot_bqnu_with_ddb(self, ddb, with_phdos=True, anaddb_kwargs=None, **kwargs)
ddb=ddb, **kwargs)

@add_fig_kwargs
def plot_bqnu_with_phbands(self, phbands_qpath, anaddb_file=None,
def plot_bqnu_with_phbands(self, phbands_qpath, with_legend=True,
phdos_file=None, ddb=None, width=0.001, normalize: bool = True,
verbose=0, anaddb_kwargs=None, with_title=True, interp_method="linear",
ax_mat=None, scale=10, marker_color="gold", marker_edgecolor='gray',
marker_alpha=0.8, fontsize=12, lw_bands=1.0, lw_dos=1.0, **kwargs) -> Figure:
ax_mat=None, scale=50, marker_color="gold", marker_edgecolor='gray',
marker_alpha=0.5, fontsize=12, lw_bands=1.0, lw_dos=1.0,
fill_dos=True, **kwargs) -> Figure:
"""
Plot phonon energies with markers whose size is proportional to |B_qnu|^2.
Expand Down Expand Up @@ -859,30 +872,52 @@ def plot_bqnu_with_phbands(self, phbands_qpath, anaddb_file=None,

phbands_qpath = PhononBands.as_phbands(phbands_qpath)

# maybe this is unnecessary
if anaddb_file:
phbands_qpath.read_non_anal_from_file(anaddb_file)

# Get interpolators for B_qnu
b2_interp_state = self.get_b2_interpolator_state(interp_method)

b_data, *_ = self.insert_b_inbox(fill_value=0)

# TODO: need to fix this hardcoded representation
units = 'meV'
units_scale = 1e3 if units == 'meV' else 1

# Plot phonon bands with markers.
ymin, ymax = +np.inf, -np.inf
for pstate in range(self.nstates):
x, y, s = [], [], []

b2_max = np.max(np.abs(b_data[pstate]))**2
scale *= 1. / b2_max

for iq, qpoint in enumerate(phbands_qpath.qpoints):
omegas_nu = phbands_qpath.phfreqs[iq,:]

for w, b2 in zip(omegas_nu, b2_interp_state[pstate].eval_kpoint(qpoint), strict=True):
w *= units_scale
x.append(iq); y.append(w); s.append(scale * b2)
ymin, ymax = min(ymin, w), max(ymax, w)

ax = ax_mat[pstate, 0]
points = Marker(x, y, s, color=marker_color, edgecolors=marker_edgecolor,
alpha=marker_alpha, label=r'$|B_{\nu\mathbf{q}}|^2$')
phbands_qpath.plot(ax=ax, points=points, show=False, linewidth=lw_bands)
phbands_qpath.plot(ax=ax, points=points, show=False, linewidth=lw_bands, units=units)
ax.legend(loc="best", shadow=True, fontsize=fontsize)

if pstate != self.nstates - 1:
if pstate != self.nstates - 1 or not with_legend:
set_visible(ax, False, *["legend", "xlabel"])

# determine bandwidth and set ylims
# if no negative freqs, set ymin exactly to 0
if ymin > -1e-6:
ymin = 0
bandwidth = ymax - ymin
ymin -= 0.1*bandwidth if ymin != 0 else 0
ymax += 0.1*bandwidth

for ax in ax_mat.ravel():
ax.set_ylim(ymin, ymax)


if not with_phdos:
# Return immediately.
if with_title:
Expand All @@ -894,6 +929,7 @@ def plot_bqnu_with_phbands(self, phbands_qpath, anaddb_file=None,
####################
# NB: B_qnu do not necessarily have the symmetry of the lattice so we have to loop over the full BZ.
# The frequency mesh is in eV, values are in states/eV.
# (note the units_scale variable before the phbands calculation)
# Use same q-mesh as phdos
phdos = phdos_file.phdos
phdos_ngqpt = np.diagonal(phdos_file.qptrlatt)
Expand All @@ -914,6 +950,7 @@ def plot_bqnu_with_phbands(self, phbands_qpath, anaddb_file=None,
raise RuntimeError(f"{len(phbands_qmesh.qpoints)=} != {np.product(phdos_ngqpt)=}")

#with_ibz_b2dos = False
xmax = -np.inf
for pstate in range(self.nstates):
# Compute B2(E) by looping over the full BZ.
bqnu_dos = np.zeros(len(phdos_mesh))
Expand All @@ -928,15 +965,21 @@ def plot_bqnu_with_phbands(self, phbands_qpath, anaddb_file=None,
bqnu_dos = Function1D(phdos_mesh, bqnu_dos)

ax = ax_mat[pstate, 1]
phdos.plot_ax(ax, exchange_xy=True, normalize=normalize, label="phDOS(E)", color="black",
linewidth=lw_dos)
bqnu_dos.plot_ax(ax, exchange_xy=True, normalize=normalize, label=r"$B^2$(E)", color=marker_color,
linewidth=lw_dos)
pdos_opts = {"color": "black"}
lines_pdos = phdos.plot_dos_idos(ax, exchange_xy=True, units=units, label="phDOS(E)",
normalize=normalize, linewidth=lw_dos, **pdos_opts)
#phdos.plot_ax(ax, exchange_xy=True, normalize=normalize, label="phDOS(E)", color="black",
# linewidth=lw_dos, units=units)
lines_bdos = bqnu_dos.plot_ax(ax, exchange_xy=True, normalize=normalize,
label=r"$B^2$(E)", color=marker_color, linewidth=lw_dos,
xfactor=units_scale, yfactor=1/units_scale)
set_grid_legend(ax, fontsize, xlabel="Arb. unit")

# Get mapping BZ --> IBZ needed to obtain the KS eigenvalues e_nk from the IBZ for the DOS
# Compute B2(E) using only q-points in the IBZ. This is just for testing.
# B2_IBZ(E) should be equal to B2(E) only if B_qnu fullfill the lattice symmetries. See notes above.
with_ibz_b2dos = False
ibz_dos_opts = {"color": "darkred"}
"""
bqnu_dos = np.zeros(len(phdos_mesh))
for iq_ibz, qpoint in zip(bz2ibz, bz_qpoints):
Expand All @@ -945,8 +988,48 @@ def plot_bqnu_with_phbands(self, phbands_qpath, anaddb_file=None,
bqnu_dos += b2 gaussian(phdos_mesh, width, center=w)
bqnu_dos /= np.product(phdos_ngqpt)
"""
lines_bdos_ibz = None


dos_lines = [lines_pdos, lines_bdos]
colors = [pdos_opts["color"], marker_color]
span = ymax - ymin

if with_ibz_b2dos:
dos_lines.append(lines_bdos_ibz)
colors.append(ibz_dos_opts["color"])
# determine max x value for auto xlims
for dos, c in zip(dos_lines, colors):
for line in dos:
x_data, y_data = line.get_xdata(), line.get_ydata()
mask = (y_data > ymin) & (y_data < ymax+span*0.1)
xmax = max(np.max(x_data[mask]), xmax)

# fill Bqnu dos in order
if fill_dos:
y_common = np.linspace(ymin, ymax+span*0.1, 100)
xleft = np.zeros_like(y_common)
# skip eDOS, fill only BDOS
for dos, c in zip(dos_lines[1:], colors[1:]):
for line in dos:
x_data, y_data = line.get_xdata(), line.get_ydata()
interp_x = interp1d(y_data, x_data, kind='linear', fill_value='extrapolate')
xright = interp_x(y_common)

mask = (xright - xleft) > 0
y, x0, x1 = y_common[mask], xleft[mask], xright[mask]

if pstate != self.nstates - 1:
ax.fill_betweenx(y, x0, x1,
alpha=marker_alpha, color=c, linewidth=0)
xleft = xright

# Auto xlims for DOS
span = xmax
xmax += 0.1 * span
for ax in ax_mat[:,1]:
ax.set_xlim(0, xmax)

if pstate != self.nstates - 1 or not with_legend:
set_visible(ax, False, *["legend", "xlabel"])

if with_title:
Expand Down

0 comments on commit ea026f0

Please sign in to comment.