Skip to content

Commit

Permalink
Convert Marker from namedtuple to py class
Browse files Browse the repository at this point in the history
  • Loading branch information
gmatteo committed Jul 19, 2024
1 parent f8dd123 commit 9e3b24a
Show file tree
Hide file tree
Showing 10 changed files with 128 additions and 119 deletions.
2 changes: 1 addition & 1 deletion abipy/abio/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,7 @@ def escape(text):
for name in names:
value = vars[name]
if mnemonics and value is not None:
print(f"{name=}")
#print(f"{name=}")
app(escape("#### <" + var_database[name].mnemonics + ">"))

# Build variable, convert to string and append it
Expand Down
15 changes: 8 additions & 7 deletions abipy/core/kpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ def issamek(k1, k2, atol=None):
"""
k1 = np.asarray(k1)
k2 = np.asarray(k2)
#if k1.shape != k2.shape:

return is_integer(k1 - k2, atol=atol)

Expand Down Expand Up @@ -416,15 +415,12 @@ def kpoints_indices(frac_coords, ngkpt, check_mesh=0) -> np.ndarray:
ngkpt:
check_mesh:
"""

# Transforms kpt in its corresponding reduced number in the interval [0,1[
k_indices = [np.round((kpt % 1) * ngkpt) for kpt in frac_coords]
k_indices = np.array(k_indices, dtype=int)

for kpt, inds in zip(frac_coords, k_indices):
if np.any(inds >= ngkpt):
print(f"{kpt=}, {np.round(kpt % 1)=} {inds=})")
#raise ValueError("")

# Debug secction.
if check_mesh:
print(f"kpoints_indices: Testing whether k-points belong to the {ngkpt =} mesh")
ierr = 0
Expand All @@ -435,7 +431,12 @@ def kpoints_indices(frac_coords, ngkpt, check_mesh=0) -> np.ndarray:
ierr += 1; print(kpt, "-->", same_k)
if ierr:
raise ValueError("Wrong mapping")
print("Check succesful!")

#for kpt, inds in zip(frac_coords, k_indices):
# if np.any(inds >= ngkpt):
# raise ValueError(f"inds >= nkgpt for {kpt=}, {np.round(kpt % 1)=} {inds=})")

print("Check succesfull!")

return k_indices

Expand Down
1 change: 0 additions & 1 deletion abipy/core/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1301,7 +1301,6 @@ def check_image(structure, site):

# If the centrosymmetry is broken at a given atomic site of the given structure,
# returns False. Else, return True

sites = self.sites

for s1 in sites:
Expand Down
33 changes: 19 additions & 14 deletions abipy/dfpt/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,7 +1123,8 @@ def plot(self, ax=None, units="eV", qlabels=None, branch_range=None, match_bands
self.plot_ax(ax, branch_range, units=units, match_bands=match_bands, **kwargs)

if points is not None:
ax.scatter(points.x, np.array(points.y), s=np.abs(points.s), marker="o", c="b")
ax.scatter(points.x, np.array(points.y), s=np.abs(points.s),
marker=points.marker, c=points.color, alpha=points.alpha)

if temp is not None:
# Scatter plot with Bose-Einstein occupation factors for T = temp
Expand Down Expand Up @@ -2940,7 +2941,7 @@ def __init__(self, filepath: str):
path: path to the file
"""
super().__init__(filepath)
self.reader = PHBST_Reader(filepath)
self.reader = self.r = PHBST_Reader(filepath)

# Initialize Phonon bands and add metadata from ncfile
self._phbands = PhononBands.from_file(filepath)
Expand Down Expand Up @@ -2982,7 +2983,7 @@ def phbands(self) -> PhononBands:

def close(self) -> None:
"""Close the file."""
self.reader.close()
self.r.close()

@lazy_property
def params(self) -> dict:
Expand Down Expand Up @@ -3623,12 +3624,16 @@ def __init__(self, filepath: str):
# Open the file, read data and create objects.
super().__init__(filepath)

self.reader = r = PhdosReader(filepath)
self.wmesh = r.wmesh
self.reader = self.r = PhdosReader(filepath)
self.wmesh = self.r.wmesh

def close(self) -> None:
"""Close the file."""
self.reader.close()
self.r.close()

@lazy_property
def qptrlatt(self):
return self.r.read_value("qptrlatt")

@lazy_property
def params(self) -> dict:
Expand Down Expand Up @@ -3666,12 +3671,12 @@ def to_string(self, verbose: int = 0) -> str:
@lazy_property
def structure(self) -> Structure:
"""|Structure| object."""
return self.reader.structure
return self.r.structure

@lazy_property
def phdos(self) -> PhononDos:
"""|PhononDos| object."""
return self.reader.read_phdos()
return self.r.read_phdos()

@lazy_property
def pjdos_symbol(self):
Expand All @@ -3680,15 +3685,15 @@ def pjdos_symbol(self):
where PhononDos is the contribution to the total DOS summed over atoms
with chemical symbol `symbol`.
"""
return self.reader.read_pjdos_symbol_dict()
return self.r.read_pjdos_symbol_dict()

@lazy_property
def msqd_dos(self):
"""
|MsqDos| object with Mean square displacement tensor in cartesian coords.
Allows one to calculate Debye Waller factors by integration with 1/omega and the Bose-Einstein factor.
"""
return self.reader.read_msq_dos()
return self.r.read_msq_dos()

@add_fig_kwargs
def plot_pjdos_type(self, units="eV", stacked=True, colormap="jet", alpha=0.7, exchange_xy=False,
Expand Down Expand Up @@ -3851,7 +3856,7 @@ def plot_pjdos_cartdirs_type(self, units="eV", stacked=True, colormap="jet", alp
cmap = plt.get_cmap(colormap)

# symbol --> [three, number_of_frequencies] in cart dirs
pjdos_symbol_rc = self.reader.read_pjdos_symbol_xyz_dict()
pjdos_symbol_rc = self.r.read_pjdos_symbol_xyz_dict()

xx = self.phdos.mesh * factor
for idir, ax in enumerate(ax_list):
Expand All @@ -3865,7 +3870,7 @@ def plot_pjdos_cartdirs_type(self, units="eV", stacked=True, colormap="jet", alp

# Plot Type projected DOSes along cartesian direction idir
cumulative = np.zeros(len(self.wmesh))
for itype, symbol in enumerate(self.reader.chemical_symbols):
for itype, symbol in enumerate(self.r.chemical_symbols):
color = cmap(float(itype) / max(1, ntypat - 1))
yy = pjdos_symbol_rc[symbol][idir] / factor

Expand Down Expand Up @@ -3923,7 +3928,7 @@ def plot_pjdos_cartdirs_site(self, view="inequivalent", units="eV", stacked=True
cmap = plt.get_cmap(colormap)

# [natom, three, nomega] array with PH-DOS projected over atoms and cartesian directions
pjdos_atdir = self.reader.read_pjdos_atdir()
pjdos_atdir = self.r.read_pjdos_atdir()

xx = self.phdos.mesh * factor
for idir, ax in enumerate(ax_list):
Expand Down Expand Up @@ -4019,7 +4024,7 @@ def to_pymatgen(self) -> PmgCompletePhononDos:
total_dos = self.phdos.to_pymatgen()

# [natom, three, nomega] array with PH-DOS projected over atoms and cartesian directions"""
pjdos_atdir = self.reader.read_pjdos_atdir()
pjdos_atdir = self.r.read_pjdos_atdir()

factor = abu.phfactor_ev2units("thz")
summed_pjdos = np.sum(pjdos_atdir, axis=1) / factor
Expand Down
13 changes: 5 additions & 8 deletions abipy/electrons/ebands.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def num_leadblanks(string):
return _TIPS


class ElectronTransition(object):
class ElectronTransition:
"""
This object describes an electronic transition between two single-particle states.
"""
Expand Down Expand Up @@ -516,8 +516,7 @@ def __init__(self, structure, kpoints, eigens, fermie, occfacts, nelect, nspinor
Args:
structure: |Structure| object.
kpoints: |KpointList| instance.
eigens: Array-like object with the eigenvalues (eV) stored as [s, k, b]
where s: spin , k: kpoint, b: band index
eigens: Array-like object with the eigenvalues (eV) stored as [s, k, b] where s: spin , k: kpoint, b: band index
fermie: Fermi level in eV.
occfacts: Occupation factors (same shape as eigens)
nelect: Number of valence electrons in the unit cell.
Expand Down Expand Up @@ -798,8 +797,7 @@ def with_points_along_path(self, frac_bounds=None, knames=None, dist_tol=1e-12):
frac_bounds: [M, 3] array with the vertexes of the k-path in reduced coordinates.
If None, the k-path is automatically selected from the structure.
knames: List of strings with the k-point labels defining the k-path. It has precedence over frac_bounds.
dist_tol: A point is considered to be on the path if its distance from the line
is less than dist_tol.
dist_tol: A point is considered to be on the path if its distance from the line is less than dist_tol.
Return:
namedtuple with the following attributes::
Expand Down Expand Up @@ -2216,7 +2214,8 @@ def plot(self, spin=None, band_range=None, klabels=None, e0="fermie", ax=None, y
self.plot_ax(ax, e0, spin=spin, band=band, **opts)

if points is not None:
ax.scatter(points.x, np.array(points.y) - e0, s=np.abs(points.s), marker="o", c="b")
ax.scatter(points.x, np.array(points.y) - e0, s=np.abs(points.s),
marker=points.marker, c=points.color, alpha=points.alpha)

if with_gaps and (self.mband > self.nspinor * self.nelect // 2):
# Show fundamental and direct gaps for each spin.
Expand Down Expand Up @@ -2775,7 +2774,6 @@ def plot_with_edos(self, edos, klabels=None, ax_list=None, e0="fermie", points=N
* ``edos_fermie``: Use the Fermi energy computed from the DOS to define the zero of energy in both subplots.
* Number e.g ``e0 = 0.5``: shift all eigenvalues to have zero energy at 0.5 eV
* None: Don't shift energies, equivalent to ``e0 = 0``
points: Marker object with the position and the size of the marker.
Used for plotting purpose e.g. QP energies, energy derivatives...
with_gaps: True to add markers and arrows showing the fundamental and the direct gap.
Expand Down Expand Up @@ -2854,7 +2852,6 @@ def plotly_with_edos(self, edos, klabels=None, fig=None, band_rcd=None, dos_rcd=
* ``edos_fermie``: Use the Fermi energy computed from the DOS to define the zero of energy in both subplots.
* Number e.g ``e0 = 0.5``: shift all eigenvalues to have zero energy at 0.5 eV
* None: Don't shift energies, equivalent to ``e0 = 0``
points: Marker object with the position and the size of the marker.
Used for plotting purpose e.g. QP energies, energy derivatives...
with_gaps: True to add markers and arrows showing the fundamental and the direct gap.
Expand Down
Loading

0 comments on commit 9e3b24a

Please sign in to comment.