Skip to content

Commit

Permalink
Code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
gmatteo committed Dec 15, 2024
1 parent 861c9b8 commit 5d17288
Show file tree
Hide file tree
Showing 13 changed files with 353 additions and 304 deletions.
10 changes: 5 additions & 5 deletions abipy/eph/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,27 @@ class BaseEphReader(ElectronsReader):
"""

@lazy_property
def ddb_ngqpt(self):
def ddb_ngqpt(self) -> np.ndarray:
"""Q-Mesh for DDB file."""
return self.read_value("ddb_ngqpt")

@lazy_property
def ngqpt(self):
def ngqpt(self) -> np.ndarray:
"""Effective Q-mesh used in to compute integrals (ph_linewidts, e-ph self-energy)."""
return self.read_value("ngqpt")

@lazy_property
def ph_ngqpt(self):
def ph_ngqpt(self) -> np.ndarray:
"""Q-mesh for Phonon DOS, interpolated A2F ..."""
return self.read_value("ph_ngqpt")

@lazy_property
def eph_ngqpt_fine(self):
def eph_ngqpt_fine(self) -> np.ndarray:
"""Q-mesh for interpolated DFPT potentials"""
return self.read_value("eph_ngqpt_fine")

@lazy_property
def common_eph_params(self):
def common_eph_params(self) -> dict:
"""
Read basic parameters (scalars) from the netcdf files produced by the EPH code and cache them
"""
Expand Down
46 changes: 23 additions & 23 deletions abipy/eph/gpath.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
"""
from __future__ import annotations

import dataclasses
#import dataclasses
import numpy as np
import pandas as pd
#import pandas as pd
import abipy.core.abinit_units as abu

from monty.string import marquee
Expand All @@ -27,7 +27,7 @@
from abipy.eph.common import BaseEphReader


def k2s(k_vector, fmt=".3f", threshold = 1e-8) -> str:
def k2s(k_vector, fmt=".3f", threshold=1e-8) -> str:
k_vector = np.asarray(k_vector)
k_vector[np.abs(k_vector) < threshold] = 0

Expand Down Expand Up @@ -99,7 +99,7 @@ def params(self) -> dict:
def __str__(self) -> str:
return self.to_string()

def to_string(self, verbose: int=0) -> str:
def to_string(self, verbose: int = 0) -> str:
"""String representation with verbosiy level ``verbose``."""
lines = []; app = lines.append

Expand All @@ -126,15 +126,15 @@ def _get_which_g_list(which_g: str) -> list[str]:
return all_choices

if which_g not in all_choices:
raise ValueError(f"Invalid {which=}, should be in {all_choices=}")
raise ValueError(f"Invalid {which_g=}, should be in {all_choices=}")

return [which_g]

def _get_band_range(self, band_range):
return (self.r.bstart, self.r.bstop) if band_range is None else band_range

@add_fig_kwargs
def plot_g_qpath(self, band_range=None, which_g="avg", with_qexp: int=0, scale=1, gmax_mev=250,
def plot_g_qpath(self, band_range=None, which_g="avg", with_qexp: int = 0, scale=1, gmax_mev=250,
ph_modes=None, with_phbands=True, with_ebands=False,
ax_mat=None, fontsize=8, **kwargs) -> Figure:
"""
Expand Down Expand Up @@ -481,16 +481,16 @@ def get_gnuq_average_spin(self, spin: int, band_range: list|tuple|None, eps_mev:
# Average over degenerate k+q electrons taking bstart into account.
absg = absg_avg.copy()
for iq in range(self.nq_path):
for n_k in range(nb_in_g):
for m_kq in range(nb_in_g):
w_1 = all_eigens_kq[spin, iq, m_kq + bstart]
g2_nu[:], nn = 0.0, 0
for bsum_kq in range(nb_in_g):
w_2 = all_eigens_kq[spin, iq, bsum_kq + bstart]
if abs(w_2 - w_1) >= eps_ev: continue
nn += 1
g2_nu += absg[iq,:,bsum_kq,n_k] ** 2
absg_avg[iq,:,m_kq,n_k] = np.sqrt(g2_nu / nn)
for n_k in range(nb_in_g):
for m_kq in range(nb_in_g):
w_1 = all_eigens_kq[spin, iq, m_kq + bstart]
g2_nu[:], nn = 0.0, 0
for bsum_kq in range(nb_in_g):
w_2 = all_eigens_kq[spin, iq, bsum_kq + bstart]
if abs(w_2 - w_1) >= eps_ev: continue
nn += 1
g2_nu += absg[iq,:,bsum_kq,n_k] ** 2
absg_avg[iq,:,m_kq,n_k] = np.sqrt(g2_nu / nn)

# Transpose the data: (nq_path, natom3, nb_in_g, nb_in_g) -> (natom3, nq_path, nb_in_g, nb_in_g)
absg_avg, absg_raw = absg_avg.transpose(1, 0, 2, 3).copy(), absg_raw.transpose(1, 0, 2, 3).copy()
Expand Down Expand Up @@ -553,11 +553,11 @@ def get_gnuk_average_spin(self, spin: int, band_range: list|tuple|None, eps_mev:
absg_avg = np.zeros_like(absg)
for ik in range(self.nk_path):
for nu in range(natom3):
# Sum the squared values of absg over the degenerate phonon mu indices.
mask_nu = np.abs(phfreqs_ha[iq, :] - phfreqs_ha[iq, nu]) < eps_ha
g2_mn = np.sum(absg[ik, mask_nu, :, :]**2, axis=0)
# Compute the symmetrized value and divide by the number of degenerate ph-modes for this iq.
absg_avg[ik, nu, :, :] = np.sqrt(g2_mn / np.sum(mask_nu))
# Sum the squared values of absg over the degenerate phonon mu indices.
mask_nu = np.abs(phfreqs_ha[iq, :] - phfreqs_ha[iq, nu]) < eps_ha
g2_mn = np.sum(absg[ik, mask_nu, :, :]**2, axis=0)
# Compute the symmetrized value and divide by the number of degenerate ph-modes for this iq.
absg_avg[ik, nu, :, :] = np.sqrt(g2_mn / np.sum(mask_nu))

# MG FIXME: Note the difference with a similar function in gkq here I use absg and not absgk
# Average over degenerate k electrons taking bstart into account.
Expand Down Expand Up @@ -645,7 +645,7 @@ def plot_g_qpath(self, which_g="avg", gmax_mev=250, ph_modes=None,

# TODO: Compute common band range.
band_range = None
ref_ifile= 0
ref_ifile = 0
#q_label = r"$|q|^{%d}$" % with_qexp if with_qexp else ""
#g_units = "(meV)" if with_qexp == 0 else r"(meV $\AA^-{%s}$)" % with_qexp

Expand Down Expand Up @@ -703,4 +703,4 @@ def write_notebook(self, nbpath=None) -> str:
#nb.cells.extend(self.get_baserobot_code_cells())
#nb.cells.extend(self.get_ebands_code_cells())

return self._write_nb_nbpath(nb, nbpath)
return self._write_nb_nbpath(nb, nbpath)
14 changes: 7 additions & 7 deletions abipy/eph/gstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def find_iq_glob_qpoint(self, qpoint, spin: int):
#print(f"Found {qpoint = } with index {iq_g = }")
return iq_g, qpoint

raise ValueError(f"Cannot find {qpoint = } in GSTORE.nc")
raise ValueError(f"Cannot find {qpoint=} in GSTORE.nc")

def find_ik_glob_kpoint(self, kpoint, spin: int):
"""Find the internal indices of the kpoint needed to access the gvals array."""
Expand All @@ -486,7 +486,7 @@ def find_ik_glob_kpoint(self, kpoint, spin: int):
#print(f"Found {kpoint = } with index {ik_g = }")
return ik_g, kpoint

raise ValueError(f"Cannot find {kpoint = } in GSTORE.nc")
raise ValueError(f"Cannot find {kpoint=} in GSTORE.nc")

# TODO: This fix to read groups should be imported in pymatgen.
@lazy_property
Expand Down Expand Up @@ -539,7 +539,7 @@ def neq(self, ref_basename: str | None = None, verbose: int = 0) -> int:
return ierr

@staticmethod
def _neq_two_gstores(gstore1: GstoreFile, gstore2: GstoreFile, verbose: int) -> int:
def _neq_two_gstores(self: GstoreFile, gstore2: GstoreFile, verbose: int) -> int:
"""
Helper function to compare two GSTORE files.
"""
Expand All @@ -550,12 +550,12 @@ def _neq_two_gstores(gstore1: GstoreFile, gstore2: GstoreFile, verbose: int) ->
]

for aname in aname_list:
self._compare_attr_name(aname, gstore1, gstore2)
self._compare_attr_name(aname, self, gstore2)

# Now compare the gkq objects for each spin.
ierr = 0
for spin in range(gstore1.nsppol):
gqk1, gqk2 = gstore1.gqk_spin[spin], gstore2.gqk_spin[spin]
for spin in range(self.nsppol):
gqk1, gqk2 = self.gqk_spin[spin], gstore2.gqk_spin[spin]
ierr += gqk1.neq(gqk2, verbose)

return ierr
Expand Down Expand Up @@ -585,4 +585,4 @@ def write_notebook(self, nbpath=None) -> str:
#nb.cells.extend(self.get_baserobot_code_cells())
#nb.cells.extend(self.get_ebands_code_cells())

return self._write_nb_nbpath(nb, nbpath)
return self._write_nb_nbpath(nb, nbpath)
2 changes: 1 addition & 1 deletion abipy/eph/gwan.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def neq(self, other: Gqk, verbose: int) -> int:
return ierr


class GstoreReader(BaseEphReader):
class GwanReader(BaseEphReader):
"""
Reads data from file and constructs objects.
Expand Down
38 changes: 18 additions & 20 deletions abipy/eph/varpeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@
import pandas as pd
import abipy.core.abinit_units as abu

from collections import defaultdict
from monty.string import marquee
from monty.functools import lazy_property
from monty.termcolor import cprint
#from monty.termcolor import cprint
from abipy.core.func1d import Function1D
from abipy.core.structure import Structure
from abipy.core.kpoints import kpoints_indices, kmesh_from_mpdivs, map_grid2ibz, IrredZone
from abipy.core.kpoints import kpoints_indices, kmesh_from_mpdivs, map_grid2ibz #, IrredZone
from abipy.core.mixins import AbinitNcFile, Has_Structure, Has_ElectronBands, NotebookWriter
from abipy.tools.typing import PathLike
from abipy.tools.plotting import (add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt, set_axlims, set_visible,
Expand Down Expand Up @@ -107,7 +106,6 @@ class Entry:
# Convert to dictionary: name --> Entry
_ALL_ENTRIES = {e.name: e for e in _ALL_ENTRIES}


class VarpeqFile(AbinitNcFile, Has_Structure, Has_ElectronBands, NotebookWriter):
"""
This file stores the results of a VARPEQ calculations: SCF cycle, A_nk, B_qnu coefficients
Expand Down Expand Up @@ -174,7 +172,7 @@ def params(self) -> dict:
def __str__(self) -> str:
return self.to_string()

def to_string(self, verbose: int=0) -> str:
def to_string(self, verbose: int = 0) -> str:
"""String representation with verbosiy level ``verbose``."""
lines = []; app = lines.append

Expand All @@ -197,7 +195,7 @@ def to_string(self, verbose: int=0) -> str:
for spin in range(self.nsppol):
polaron = self.polaron_spin[spin]
df = polaron.get_final_results_df()
app(f"Last SCF iteration. Energies in eV units")
app("Last SCF iteration. Energies in eV units")
app(str(df))
app("")

Expand Down Expand Up @@ -322,7 +320,7 @@ def ufact_k(k):

return df_list

def get_final_results_df(self, with_params: bool=False) -> pd.DataFrame:
def get_final_results_df(self, with_params: bool = False) -> pd.DataFrame:
"""
Return daframe with the last iteration for all polaronic states.
NB: Energies are in eV.
Expand All @@ -342,7 +340,7 @@ def get_final_results_df(self, with_params: bool=False) -> pd.DataFrame:
def __str__(self) -> str:
return self.to_string()

def to_string(self, verbose: int=0) -> str:
def to_string(self, verbose: int = 0) -> str:
"""
String representation with verbosiy level verbose.
"""
Expand Down Expand Up @@ -380,7 +378,7 @@ def ngkpt_and_shifts(self) -> tuple:

return ngkpt, shifts

def get_title(self, with_gaps: bool=True) -> str:
def get_title(self, with_gaps: bool = True) -> str:
"""
Return string with title for matplotlib plots.
"""
Expand Down Expand Up @@ -462,7 +460,7 @@ def get_b2_interpolator_state(self, interp_method) -> BzRegularGridInterpolator:
return [BzRegularGridInterpolator(self.structure, shifts, np.abs(b_data[pstate])**2, method=interp_method)
for pstate in range(self.nstates)]

def write_a2_bxsf(self, filepath: PathLike, fill_value: float=0.0) -> None:
def write_a2_bxsf(self, filepath: PathLike, fill_value: float = 0.0) -> None:
r"""
Export \sum_n |A_{pnk}|^2 in BXSF format suitable for visualization with xcrysden (use ``xcrysden --bxsf FILE``).
Requires gamma-centered k-mesh.
Expand All @@ -480,7 +478,7 @@ def write_a2_bxsf(self, filepath: PathLike, fill_value: float=0.0) -> None:

bxsf_write(filepath, self.structure, 1, self.nstates, ngkpt, a2_data, fermie, unit="Ha")

def write_b2_bxsf(self, filepath: PathLike, fill_value: float=0.0) -> None:
def write_b2_bxsf(self, filepath: PathLike, fill_value: float = 0.0) -> None:
r"""
Export \sum_{\nu} |B_{q\nu}|^2 in BXSF format suitable for visualization with xcrysden (use ``xcrysden --bxsf FILE``).
Expand Down Expand Up @@ -554,9 +552,9 @@ def plot_scf_cycle(self, ax_mat=None, fontsize=8, **kwargs) -> Figure:

@add_fig_kwargs
def plot_ank_with_ebands(self, ebands_kpath,
ebands_kmesh=None, lpratio: int=5,
with_ibz_a2dos=True, method="gaussian", step: float=0.05, width: float=0.1,
nksmall: int=20, normalize: bool=False, with_title=True, interp_method="linear",
ebands_kmesh=None, lpratio: int = 5,
with_ibz_a2dos=True, method="gaussian", step: float = 0.05, width: float = 0.1,
nksmall: int = 20, normalize: bool = False, with_title=True, interp_method="linear",
ax_mat=None, ylims=None, scale=10, marker_color="gold", fontsize=12, **kwargs) -> Figure:
"""
Plot electron bands with markers whose size is proportional to |A_nk|^2.
Expand Down Expand Up @@ -721,7 +719,7 @@ def plot_bqnu_with_ddb(self, ddb, with_phdos=True, anaddb_kwargs=None, **kwargs)

@add_fig_kwargs
def plot_bqnu_with_phbands(self, phbands_qpath,
phdos_file=None, ddb=None, width=0.001, normalize: bool=True,
phdos_file=None, ddb=None, width=0.001, normalize: bool = True,
verbose=0, anaddb_kwargs=None, with_title=True, interp_method="linear",
ax_mat=None, scale=10, marker_color="gold", fontsize=12, **kwargs) -> Figure:
"""
Expand Down Expand Up @@ -773,7 +771,8 @@ def plot_bqnu_with_phbands(self, phbands_qpath,

if not with_phdos:
# Return immediately.
if with_title: fig.suptitle(self.get_title(with_gaps=True))
if with_title:
fig.suptitle(self.get_title(with_gaps=True))
return fig

####################
Expand Down Expand Up @@ -803,7 +802,6 @@ def plot_bqnu_with_phbands(self, phbands_qpath,
#with_ibz_b2dos = False

for pstate in range(self.nstates):

# Compute B2(E) by looping over the full BZ.
bqnu_dos = np.zeros(len(phdos_mesh))
for iq_bz, qpoint in enumerate(phbands_qmesh.qpoints):
Expand All @@ -815,7 +813,7 @@ def plot_bqnu_with_phbands(self, phbands_qpath,
for w, b2 in zip(freqs_nu, b2_interp_state[pstate].eval_kpoint(qpoint), strict=True):
bqnu_dos += q_weight * b2 * gaussian(phdos_mesh, width, center=w)

bqnu_dos = Function1D(wmesh, bqnu_dos)
bqnu_dos = Function1D(phdos_mesh, bqnu_dos)

ax = ax_mat[pstate, 1]
phdos.plot_ax(ax, exchange_xy=True, normalize=normalize, label="phDOS(E)", color="black")
Expand Down Expand Up @@ -946,7 +944,7 @@ def to_string(self, verbose=0) -> str:

return "\n".join(lines)

def get_final_results_df(self, spin=None, sortby=None, with_params: bool=True) -> pd.DataFrame:
def get_final_results_df(self, spin=None, sortby=None, with_params: bool = True) -> pd.DataFrame:
"""
Return dataframe with the last iteration for all polaronic states.
NB: Energies are in eV.
Expand Down Expand Up @@ -1046,4 +1044,4 @@ def write_notebook(self, nbpath=None) -> str:
#nbv.new_code_cell("ebands_plotter = robot.get_ebands_plotter()"),
])

return self._write_nb_nbpath(nb, nbpath)
return self._write_nb_nbpath(nb, nbpath)
Loading

0 comments on commit 5d17288

Please sign in to comment.