diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index a727f6d7910..4c63a00a60c 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -2,3 +2,5 @@ pymatgen/io/ase.py @Andrew-S-Rosen pymatgen/io/abinit/* @gmatteo pymatgen/io/lobster/* @JaGeo +pymatgen/ext/* @ml-evs +tests/ext/* @ml-evs diff --git a/.github/release.yml b/.github/release.yml index 5bc3ec2d50c..1eb9bf3b135 100644 --- a/.github/release.yml +++ b/.github/release.yml @@ -14,6 +14,8 @@ changelog: labels: [housekeeping] - title: ๐Ÿš€ Performance labels: [performance] + - title: ๐Ÿšง CI + labels: [ci] - title: ๐Ÿ’ก Refactoring labels: [refactor] - title: ๐Ÿงช Tests diff --git a/.github/workflows/issue-metrics.yml b/.github/workflows/issue-metrics.yml new file mode 100644 index 00000000000..26e498a56ae --- /dev/null +++ b/.github/workflows/issue-metrics.yml @@ -0,0 +1,42 @@ +name: Monthly issue metrics +on: + workflow_dispatch: + schedule: + - cron: '3 2 1 * *' + +permissions: + contents: read + +jobs: + build: + name: issue metrics + runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: read + steps: + - name: Get dates for last month + shell: bash + run: | + # Calculate the first day of the previous month + first_day=$(date -d "last month" +%Y-%m-01) + + # Calculate the last day of the previous month + last_day=$(date -d "$first_day +1 month -1 day" +%Y-%m-%d) + + #Set an environment variable with the date range + echo "$first_day..$last_day" + echo "last_month=$first_day..$last_day" >> "$GITHUB_ENV" + + - name: Run issue-metrics tool + uses: github/issue-metrics@v3 + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + SEARCH_QUERY: 'repo:materialsproject/pymatgen is:issue created:${{ env.last_month }} -reason:"not planned"' + + - name: Create issue + uses: peter-evans/create-issue-from-file@v5 + with: + title: Monthly issue metrics report + token: ${{ secrets.GITHUB_TOKEN }} + content-filepath: ./issue_metrics.md diff --git a/.github/workflows/jekyll-gh-pages.yml b/.github/workflows/jekyll-gh-pages.yml index 308873fc823..91f17b3ea39 100644 --- a/.github/workflows/jekyll-gh-pages.yml +++ b/.github/workflows/jekyll-gh-pages.yml @@ -1,15 +1,11 @@ -# Sample workflow for building and deploying a Jekyll site to GitHub Pages name: Deploy Jekyll with GitHub Pages dependencies preinstalled on: - # Runs on pushes targeting the default branch push: branches: ["master"] + workflow_dispatch: # enable manual workflow execution - # Allows you to run this workflow manually from the Actions tab - workflow_dispatch: - -# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages +# Set permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages permissions: contents: read pages: write @@ -22,23 +18,26 @@ concurrency: cancel-in-progress: false jobs: - # Build job build: + # prevent this action from running on forks + if: github.repository == 'materialsproject/pymatgen' runs-on: ubuntu-latest steps: - name: Checkout uses: actions/checkout@v4 + - name: Setup Pages uses: actions/configure-pages@v3 + - name: Build with Jekyll uses: actions/jekyll-build-pages@v1 with: source: ./docs destination: ./_site + - name: Upload artifact uses: actions/upload-pages-artifact@v2 - # Deployment job deploy: environment: name: github-pages diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index dbdf64a779c..cacd81dcf36 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -27,7 +27,7 @@ jobs: - name: ruff run: | ruff --version - ruff . + ruff check . ruff format --check . - name: mypy diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ad6a6eecea0..24b61d90aac 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -19,6 +19,9 @@ jobs: test: # prevent this action from running on forks if: github.repository == 'materialsproject/pymatgen' + defaults: + run: + shell: bash -l {0} # enables conda/mamba env activation by reading bash profile strategy: fail-fast: false matrix: @@ -48,65 +51,41 @@ jobs: - name: Check out repo uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - cache: pip - cache-dependency-path: setup.py + - name: Set up micromamba + uses: mamba-org/setup-micromamba@main + + - name: Create mamba environment + run: | + micromamba create -n pmg python=${{ matrix.python-version }} --yes - name: Install uv - run: pip install uv + run: micromamba run -n pmg pip install uv - name: Copy GULP to bin if: matrix.os == 'ubuntu-latest' run: | sudo cp cmd_line/gulp/Linux_64bit/* /usr/local/bin/ - - name: Install Bader + - name: Install ubuntu-only conda dependencies if: matrix.os == 'ubuntu-latest' run: | - wget https://theory.cm.utexas.edu/henkelman/code/bader/download/bader_lnx_64.tar.gz - tar xvzf bader_lnx_64.tar.gz - sudo mv bader /usr/local/bin/ - continue-on-error: true # This is not critical to succeed. + micromamba install -n pmg -c conda-forge enumlib packmol bader openbabel openff-toolkit --yes - - name: Install Enumlib - if: matrix.os == 'ubuntu-latest' + - name: Install pymatgen and dependencies run: | - git clone --recursive https://github.com/msg-byu/enumlib.git - cd enumlib/symlib/src - export F90=gfortran - make - cd ../../src - make enum.x - sudo mv enum.x /usr/local/bin/ - cd .. - sudo cp aux_src/makeStr.py /usr/local/bin/ - continue-on-error: true # This is not critical to succeed. - - - name: Install Packmol - if: matrix.os == 'ubuntu-latest' - run: | - wget -O packmol.tar.gz https://github.com/m3g/packmol/archive/refs/tags/v20.14.2.tar.gz - tar xvzf packmol.tar.gz - export F90=gfortran - cd packmol-20.14.2 - ./configure - make - sudo mv packmol /usr/local/bin/ - cd .. - continue-on-error: true # This is not critical to succeed. - - - name: Install dependencies - run: | - uv pip install numpy cython --system + micromamba activate pmg + # TODO remove temporary fix. added since uv install torch is flaky. + # track https://github.com/astral-sh/uv/issues/1921 for resolution + pip install torch + + uv pip install numpy cython - uv pip install -e '.[dev,optional]' --system + uv pip install --editable '.[dev,optional]' # TODO remove next line installing ase from main branch when FrechetCellFilter is released - uv pip install --upgrade 'ase@git+https://gitlab.com/ase/ase' --system + uv pip install --upgrade 'git+https://gitlab.com/ase/ase' - name: pytest split ${{ matrix.split }} run: | + micromamba activate pmg pytest --splits 10 --group ${{ matrix.split }} --durations-path tests/files/.pytest-split-durations tests diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 81032d2aeda..a6745467041 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,14 +8,14 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.3.3 + rev: v0.3.7 hooks: - id: ruff args: [--fix, --unsafe-fixes] - id: ruff-format - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: check-yaml - id: end-of-file-fixer diff --git a/dev_scripts/chemenv/equivalent_indices.py b/dev_scripts/chemenv/equivalent_indices.py index 4e1250b58e8..1d30564b05f 100644 --- a/dev_scripts/chemenv/equivalent_indices.py +++ b/dev_scripts/chemenv/equivalent_indices.py @@ -99,9 +99,9 @@ # 0. any point for i0 in range(8): # 1. point opposite to point 0. in the square face - if i0 in [0, 2]: + if i0 in {0, 2}: i1 = i0 + 1 - elif i0 in [1, 3]: + elif i0 in {1, 3}: i1 = i0 - 1 elif i0 == 4: i1 = 7 @@ -111,10 +111,14 @@ i1 = 5 elif i0 == 7: i1 = 4 + else: + raise RuntimeError("Cannot determine point.") + # 2. one of the two last points in the square face sfleft = list(sf1) if i0 in sf1 else list(sf2) sfleft.remove(i0) sfleft.remove(i1) + i2 = 0 for i2 in sfleft: sfleft2 = list(sfleft) sfleft2.remove(i2) diff --git a/dev_scripts/chemenv/get_plane_permutations_optimized.py b/dev_scripts/chemenv/get_plane_permutations_optimized.py index d0d86273957..47ed4f13b93 100644 --- a/dev_scripts/chemenv/get_plane_permutations_optimized.py +++ b/dev_scripts/chemenv/get_plane_permutations_optimized.py @@ -279,7 +279,7 @@ def random_permutations_iterator(initial_permutation, n_permutations): f"Get the explicit optimized permutations for geometry {cg.name!r} (symbol : " f'{cg_symbol!r}) ? ("y" to confirm, "q" to quit)\n' ) - if test not in ["y", "q"]: + if test not in ("y", "q"): print("Wrong key, try again") continue if test == "y": diff --git a/dev_scripts/chemenv/strategies/multi_weights_strategy_parameters.py b/dev_scripts/chemenv/strategies/multi_weights_strategy_parameters.py index 7b1f40218bd..06e5841d155 100644 --- a/dev_scripts/chemenv/strategies/multi_weights_strategy_parameters.py +++ b/dev_scripts/chemenv/strategies/multi_weights_strategy_parameters.py @@ -151,17 +151,18 @@ def get_structure(self, morphing_factor): coords = copy.deepcopy(self.abstract_geometry.points_wcs_ctwcc()) bare_points = self.abstract_geometry.bare_points_with_centre + origin = None for morphing in self.morphing_description: - if morphing["site_type"] == "neighbor": - i_site = morphing["ineighbor"] + 1 - if morphing["expansion_origin"] == "central_site": - origin = bare_points[0] - vector = bare_points[i_site] - origin - coords[i_site] += vector * (morphing_factor - 1.0) - else: + if morphing["site_type"] != "neighbor": raise ValueError(f"Key \"site_type\" is {morphing['site_type']} while it can only be neighbor") + i_site = morphing["ineighbor"] + 1 + if morphing["expansion_origin"] == "central_site": + origin = bare_points[0] + vector = bare_points[i_site] - origin + coords[i_site] += vector * (morphing_factor - 1.0) + return Structure(lattice=lattice, species=species, coords=coords, coords_are_cartesian=True) def estimate_parameters(self, dist_factor_min, dist_factor_max, symmetry_measure_type="csm_wcs_ctwcc"): @@ -269,7 +270,7 @@ def get_weights(self, weights_options): "+-------------------------------------------------------------+\n" ) - with open("ce_pairs.json") as file: + with open("ce_pairs.json", encoding="utf-8") as file: ce_pairs = json.load(file) self_weight_max_csms: dict[str, list[float]] = {} self_weight_max_csms_per_cn: dict[str, list[float]] = {} diff --git a/dev_scripts/chemenv/view_environment.py b/dev_scripts/chemenv/view_environment.py index 17a1df52428..2caa22e9f34 100644 --- a/dev_scripts/chemenv/view_environment.py +++ b/dev_scripts/chemenv/view_environment.py @@ -52,6 +52,7 @@ print() # Visualize the separation plane of a given algorithm sep_plane = False + algo = None if any(algo.algorithm_type == SEPARATION_PLANE for algo in cg.algorithms): test = input("Enter index of the algorithm for which you want to visualize the plane : ") if test != "": diff --git a/dev_scripts/potcar_scrambler.py b/dev_scripts/potcar_scrambler.py index fe6d659cf94..639c366bc7c 100644 --- a/dev_scripts/potcar_scrambler.py +++ b/dev_scripts/potcar_scrambler.py @@ -4,6 +4,7 @@ import shutil import warnings from glob import glob +from typing import TYPE_CHECKING import numpy as np from monty.os.path import zpath @@ -14,6 +15,9 @@ from pymatgen.io.vasp.sets import _load_yaml_config from pymatgen.util.testing import VASP_IN_DIR +if TYPE_CHECKING: + from typing_extensions import Self + class PotcarScrambler: """ @@ -34,18 +38,14 @@ class PotcarScrambler: from existing POTCAR `input_filename` """ - def __init__(self, potcars: Potcar | PotcarSingle): - if isinstance(potcars, PotcarSingle): - self.PSP_list = [potcars] - else: - self.PSP_list = potcars + def __init__(self, potcars: Potcar | PotcarSingle) -> None: + self.PSP_list = [potcars] if isinstance(potcars, PotcarSingle) else potcars self.scrambled_potcars_str = "" for psp in self.PSP_list: scrambled_potcar_str = self.scramble_single_potcar(psp) self.scrambled_potcars_str += scrambled_potcar_str - return - def _rand_float_from_str_with_prec(self, input_str: str, bloat: float = 1.5): + def _rand_float_from_str_with_prec(self, input_str: str, bloat: float = 1.5) -> float: n_prec = len(input_str.split(".")[1]) bd = max(1, bloat * abs(float(input_str))) return round(bd * np.random.rand(1)[0], n_prec) @@ -53,7 +53,7 @@ def _rand_float_from_str_with_prec(self, input_str: str, bloat: float = 1.5): def _read_fortran_str_and_scramble(self, input_str: str, bloat: float = 1.5): input_str = input_str.strip() - if input_str.lower() in ("t", "f", "true", "false"): + if input_str.lower() in {"t", "f", "true", "false"}: return bool(np.random.randint(2)) if input_str.upper() == input_str.lower() and input_str[0].isnumeric(): @@ -68,7 +68,7 @@ def _read_fortran_str_and_scramble(self, input_str: str, bloat: float = 1.5): except ValueError: return input_str - def scramble_single_potcar(self, potcar: PotcarSingle): + def scramble_single_potcar(self, potcar: PotcarSingle) -> str: """ Scramble the body of a POTCAR, retain the PSCTR header information. @@ -124,12 +124,12 @@ def scramble_single_potcar(self, potcar: PotcarSingle): ) return scrambled_potcar_str - def to_file(self, filename: str): + def to_file(self, filename: str) -> None: with zopen(filename, mode="wt") as file: file.write(self.scrambled_potcars_str) @classmethod - def from_file(cls, input_filename: str, output_filename: str | None = None): + def from_file(cls, input_filename: str, output_filename: str | None = None) -> Self: psp = Potcar.from_file(input_filename) psp_scrambled = cls(psp) if output_filename: @@ -137,7 +137,7 @@ def from_file(cls, input_filename: str, output_filename: str | None = None): return psp_scrambled -def generate_fake_potcar_libraries(): +def generate_fake_potcar_libraries() -> None: """ To test the `_gen_potcar_summary_stats` function in `pymatgen.io.vasp.inputs`, need a library of fake POTCARs which do not violate copyright @@ -173,7 +173,7 @@ def generate_fake_potcar_libraries(): break -def potcar_cleanser(): +def potcar_cleanser() -> None: """ Function to replace copyrighted POTCARs used in io.vasp.sets testing with dummy POTCARs that have scrambled PSP and kinetic energy values diff --git a/dev_scripts/regen_libxcfunc.py b/dev_scripts/regen_libxcfunc.py index 2965acab22e..1c0491d112c 100755 --- a/dev_scripts/regen_libxcfunc.py +++ b/dev_scripts/regen_libxcfunc.py @@ -50,16 +50,16 @@ def write_libxc_docs_json(xc_funcs, json_path): xc_funcs = deepcopy(xc_funcs) # Remove XC_FAMILY from Family and XC_ from Kind to make strings more human-readable. - for d in xc_funcs.values(): - d["Family"] = d["Family"].replace("XC_FAMILY_", "", 1) - d["Kind"] = d["Kind"].replace("XC_", "", 1) + for dct in xc_funcs.values(): + dct["Family"] = dct["Family"].replace("XC_FAMILY_", "", 1) + dct["Kind"] = dct["Kind"].replace("XC_", "", 1) # Build lightweight version with a subset of keys. - for num, d in xc_funcs.items(): - xc_funcs[num] = {key: d[key] for key in ("Family", "Kind", "References")} + for num, dct in xc_funcs.items(): + xc_funcs[num] = {key: dct[key] for key in ("Family", "Kind", "References")} # Descriptions are optional for opt in ("Description 1", "Description 2"): - desc = d.get(opt) + desc = dct.get(opt) if desc is not None: xc_funcs[num][opt] = desc diff --git a/dev_scripts/update_pt_data.py b/dev_scripts/update_pt_data.py index 3c740e2255a..178c2ce394b 100644 --- a/dev_scripts/update_pt_data.py +++ b/dev_scripts/update_pt_data.py @@ -128,16 +128,16 @@ def parse_radii(): def update_ionic_radii(): data = loadfn(ptable_yaml_path) - for d in data.values(): - if "Ionic_radii" in d: - d["Ionic radii"] = {k: v / 100 for k, v in d["Ionic_radii"].items()} - del d["Ionic_radii"] - if "Ionic_radii_hs" in d: - d["Ionic radii hs"] = {k: v / 100 for k, v in d["Ionic_radii_hs"].items()} - del d["Ionic_radii_hs"] - if "Ionic_radii_ls" in d: - d["Ionic radii ls"] = {k: v / 100 for k, v in d["Ionic_radii_ls"].items()} - del d["Ionic_radii_ls"] + for dct in data.values(): + if "Ionic_radii" in dct: + dct["Ionic radii"] = {k: v / 100 for k, v in dct["Ionic_radii"].items()} + del dct["Ionic_radii"] + if "Ionic_radii_hs" in dct: + dct["Ionic radii hs"] = {k: v / 100 for k, v in dct["Ionic_radii_hs"].items()} + del dct["Ionic_radii_hs"] + if "Ionic_radii_ls" in dct: + dct["Ionic radii ls"] = {k: v / 100 for k, v in dct["Ionic_radii_ls"].items()} + del dct["Ionic_radii_ls"] with open("periodic_table2.yaml", mode="w") as file: yaml.dump(data, file) with open("../pymatgen/core/periodic_table.json", mode="w") as file: @@ -150,9 +150,10 @@ def parse_shannon_radii(): from openpyxl import load_workbook wb = load_workbook("Shannon Radii.xlsx") - print(wb.get_sheet_names()) + print(wb.sheetnames()) sheet = wb["Sheet1"] i = 2 + el = charge = cn = None radii = collections.defaultdict(dict) while sheet[f"E{i}"].value: if sheet[f"A{i}"].value: @@ -162,8 +163,7 @@ def parse_shannon_radii(): radii[el][charge] = {} if sheet[f"C{i}"].value: cn = sheet[f"C{i}"].value - if cn not in radii[el][charge]: - radii[el][charge][cn] = {} + radii[el][charge].setdefault(cn, {}) spin = sheet[f"D{i}"].value if sheet[f"D{i}"].value is not None else "" @@ -236,6 +236,7 @@ def add_electron_affinities(): req = requests.get("https://wikipedia.org/wiki/Electron_affinity_(data_page)") soup = BeautifulSoup(req.text, "html.parser") + table = None for table in soup.find_all("table"): if "Hydrogen" in table.text: break @@ -272,6 +273,7 @@ def add_ionization_energies(): with open("NIST Atomic Ionization Energies Output.html") as file: soup = BeautifulSoup(file.read(), "html.parser") + table = None for table in soup.find_all("table"): if "Hydrogen" in table.text: break diff --git a/docs/CHANGES.md b/docs/CHANGES.md index 6ec54785c29..cd1d4cebf39 100644 --- a/docs/CHANGES.md +++ b/docs/CHANGES.md @@ -6,6 +6,107 @@ nav_order: 4 # Changelog +## v2024.4.13 + +Hot fix release for [v2024.4.12](#v2024412) to be yanked on PyPI due to https://github.com/materialsproject/pymatgen/issues/3751. + +### ๐Ÿ› Bug Fixes + +* Revert mistaken `Cohp.has_antibnd_states_below_efermi` rename by @JaGeo in https://github.com/materialsproject/pymatgen/pull/3750 +* Fix `typing_extension` `ImportError` in downstream packages by @janosh in https://github.com/materialsproject/pymatgen/pull/3752 +* Update some of the OPTIMADE aliases by @ml-evs in https://github.com/materialsproject/pymatgen/pull/3754 + +### ๐Ÿงน House-Keeping + +* Remove duplicate ruff rule in `pyproject.toml` by @Andrew-S-Rosen in https://github.com/materialsproject/pymatgen/pull/3755 + +**Full Changelog**: https://github.com/materialsproject/pymatgen/compare/v2024.4.12...v2024.4.13 + +## v2024.4.12 + +### ๐ŸŽ‰ New Features + +* Add `pymatgen.io.openff` module by @orionarcher in https://github.com/materialsproject/pymatgen/pull/3729 + +### ๐Ÿ› Bug Fixes + +* Fix blank line bug in `io.res.ResWriter` by @stefsmeets in https://github.com/materialsproject/pymatgen/pull/3671 +* Reset label for sites changed by `Structure.replace_species()` by @stefsmeets in https://github.com/materialsproject/pymatgen/pull/3672 +* Fix `phonopy.get_pmg_structure` `site_properties` key for magmoms by @JonathanSchmidt1 in https://github.com/materialsproject/pymatgen/pull/3679 +* Improve Bandoverlaps parser by @naik-aakash in https://github.com/materialsproject/pymatgen/pull/3689 +* Convert some `staticmethod` to `classmethod` by @DanielYang59 in https://github.com/materialsproject/pymatgen/pull/3710 +* Correct units of Element.atomic_orbitals by @esoteric-ephemera in https://github.com/materialsproject/pymatgen/pull/3714 +* Add a fix for if a parameter is None in AimsControlIn by @tpurcell90 in https://github.com/materialsproject/pymatgen/pull/3727 +* Replace general `raise Exception` and add missing `raise` keyword by @DanielYang59 in https://github.com/materialsproject/pymatgen/pull/3728 +* Fix `ChemicalPotentialDiagram` 2D plot not respecting `formal_chempots` setting by @uliaschauer in https://github.com/materialsproject/pymatgen/pull/3734 +* Update ENCUT type to float in incar_parameters.json by @yuuukuma in https://github.com/materialsproject/pymatgen/pull/3741 +* Clean up `core.surface` comments and docstrings by @DanielYang59 in https://github.com/materialsproject/pymatgen/pull/3691 +* Fix `io.cp2k.input.DataFile` by @DanielYang59 in https://github.com/materialsproject/pymatgen/pull/3745 + +### ๐Ÿ›  Enhancements + +* Ensure `MSONAtoms` is indeed `MSONable` when `Atoms.info` is loaded with goodies by @Andrew-S-Rosen in https://github.com/materialsproject/pymatgen/pull/3670 +* Generalize fatband plots from Lobster by @JaGeo in https://github.com/materialsproject/pymatgen/pull/3688 +* Plotting of Multicenter COBIs by @JaGeo in https://github.com/materialsproject/pymatgen/pull/2926 +* Support appending vectors to positions in XSF format by @mturiansky in https://github.com/materialsproject/pymatgen/pull/3704 +* Define `needs_u_correction(comp: CompositionLike) -> set[str]` utility function by @janosh in https://github.com/materialsproject/pymatgen/pull/3703 +* Add more flexibility to `PhononDOSPlotter` and `PhononBSPlotter` by @ab5424 in https://github.com/materialsproject/pymatgen/pull/3700 +* Define `ElementType` enum in `core/periodic_table.py` by @janosh in https://github.com/materialsproject/pymatgen/pull/3726 + +### ๐Ÿšง CI + +* Migrate CI dependency installation from `pip` to `uv` by @janosh in https://github.com/materialsproject/pymatgen/pull/3675 +* Prevent GitHub Actions from running docs-related CI on forks by @lan496 in https://github.com/materialsproject/pymatgen/pull/3697 + +### ๐Ÿ“– Documentation + +* Reformat docstrings to Google style and add type annotations by @DanielYang59 in https://github.com/materialsproject/pymatgen/pull/3694 +* Breaking: all plot methods return `plt.Axes` by @janosh in https://github.com/materialsproject/pymatgen/pull/3749 + +### ๐Ÿงน House-Keeping + +* Clean up test files: VASP outputs by @DanielYang59 in https://github.com/materialsproject/pymatgen/pull/3653 +* Clean up test files: VASP inputs by @DanielYang59 in https://github.com/materialsproject/pymatgen/pull/3674 +* Clean up test files: dedicated VASP directories, `xyz`, `mcif`, `cssr`, `exciting`, `wannier90` by @DanielYang59 in https://github.com/materialsproject/pymatgen/pull/3681 +* Remove exception printing when importing phonopy by @lan496 in https://github.com/materialsproject/pymatgen/pull/3696 +* Standardize test names: e.g. `LatticeTestCase` -> `TestLattice` by @janosh in https://github.com/materialsproject/pymatgen/pull/3693 +* Clean up tests by @janosh in https://github.com/materialsproject/pymatgen/pull/3713 +* Fix import order for `if TYPE_CHECKING:` block by @DanielYang59 in https://github.com/materialsproject/pymatgen/pull/3711 +* Use `Self` type in Method Signatures by @DanielYang59 in https://github.com/materialsproject/pymatgen/pull/3705 +* Remove deprecated `analysis.interface`, rename classes to PascalCase and rename `with_*` to `from_*` by @DanielYang59 in https://github.com/materialsproject/pymatgen/pull/3725 +* Test `EntrySet.ground_states` and CIF writing in `NEBSet.write_input` by @janosh in https://github.com/materialsproject/pymatgen/pull/3732 + +### ๐Ÿš€ Performance + +* Dynamic `__hash__` for `BalancedReaction` by @DanielYang59 in https://github.com/materialsproject/pymatgen/pull/3676 + +### ๐Ÿงช Tests + +* Clean up tests 2 by @janosh in https://github.com/materialsproject/pymatgen/pull/3716 +* Remove unnecessary `unittest.TestCase` subclassing by @janosh in https://github.com/materialsproject/pymatgen/pull/3718 + +### ๐Ÿ”’ Security Fixes + +* Avoid using `exec` in code by @DanielYang59 in https://github.com/materialsproject/pymatgen/pull/3736 +* Avoid using `eval`, replace manual offset in `enumerate` and rename single letter variables by @DanielYang59 in https://github.com/materialsproject/pymatgen/pull/3739 + +### ๐Ÿท๏ธ Type Hints + +* `Self` return type on `from_dict` methods by @janosh in https://github.com/materialsproject/pymatgen/pull/3702 +* Return `self` from `Structure` methods `replace`, `substitute`, `remove_species`, `remove_sites` by @janosh in https://github.com/materialsproject/pymatgen/pull/3706 +* `Self` return type on `Lattice` methods by @janosh in https://github.com/materialsproject/pymatgen/pull/3707 + +### ๐Ÿคทโ€โ™‚๏ธ Other Changes + +* `os.path.(exists->isfile)` by @janosh in https://github.com/materialsproject/pymatgen/pull/3690 + +## New Contributors + +* @JonathanSchmidt1 made their first contribution in https://github.com/materialsproject/pymatgen/pull/3679 +* @uliaschauer made their first contribution in https://github.com/materialsproject/pymatgen/pull/3734 + +**Full Changelog**: https://github.com/materialsproject/pymatgen/compare/v2024.3.1...v2024.4.12 + ## v2024.3.1 ## What's Changed diff --git a/docs/apidoc/pymatgen.io.aims.rst b/docs/apidoc/pymatgen.io.aims.rst index 800001df3c5..7949c7d0e56 100644 --- a/docs/apidoc/pymatgen.io.aims.rst +++ b/docs/apidoc/pymatgen.io.aims.rst @@ -6,6 +6,14 @@ pymatgen.io.aims package :undoc-members: :show-inheritance: +Subpackages +----------- + +.. toctree:: + :maxdepth: 7 + + pymatgen.io.aims.sets + Submodules ---------- diff --git a/docs/apidoc/pymatgen.io.aims.sets.rst b/docs/apidoc/pymatgen.io.aims.sets.rst new file mode 100644 index 00000000000..29a1aafbd4f --- /dev/null +++ b/docs/apidoc/pymatgen.io.aims.sets.rst @@ -0,0 +1,34 @@ +pymatgen.io.aims.sets package +============================= + +.. automodule:: pymatgen.io.aims.sets + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +pymatgen.io.aims.sets.base module +--------------------------------- + +.. automodule:: pymatgen.io.aims.sets.base + :members: + :undoc-members: + :show-inheritance: + +pymatgen.io.aims.sets.bs module +------------------------------- + +.. automodule:: pymatgen.io.aims.sets.bs + :members: + :undoc-members: + :show-inheritance: + +pymatgen.io.aims.sets.core module +--------------------------------- + +.. automodule:: pymatgen.io.aims.sets.core + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/apidoc/pymatgen.io.rst b/docs/apidoc/pymatgen.io.rst index ed205163887..3d276f7d2fe 100644 --- a/docs/apidoc/pymatgen.io.rst +++ b/docs/apidoc/pymatgen.io.rst @@ -104,6 +104,14 @@ pymatgen.io.gaussian module :undoc-members: :show-inheritance: +pymatgen.io.icet module +----------------------- + +.. automodule:: pymatgen.io.icet + :members: + :undoc-members: + :show-inheritance: + pymatgen.io.jarvis module ------------------------- diff --git a/docs/apidoc/pymatgen.util.rst b/docs/apidoc/pymatgen.util.rst index 807645b00d9..8a645ccb065 100644 --- a/docs/apidoc/pymatgen.util.rst +++ b/docs/apidoc/pymatgen.util.rst @@ -6,6 +6,14 @@ pymatgen.util package :undoc-members: :show-inheritance: +Subpackages +----------- + +.. toctree:: + :maxdepth: 7 + + pymatgen.util.testing + Submodules ---------- @@ -89,14 +97,6 @@ pymatgen.util.string module :undoc-members: :show-inheritance: -pymatgen.util.testing module ----------------------------- - -.. automodule:: pymatgen.util.testing - :members: - :undoc-members: - :show-inheritance: - pymatgen.util.typing module --------------------------- diff --git a/docs/apidoc/pymatgen.util.testing.rst b/docs/apidoc/pymatgen.util.testing.rst new file mode 100644 index 00000000000..0ee4fd38001 --- /dev/null +++ b/docs/apidoc/pymatgen.util.testing.rst @@ -0,0 +1,18 @@ +pymatgen.util.testing package +============================= + +.. automodule:: pymatgen.util.testing + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +pymatgen.util.testing.aims module +--------------------------------- + +.. automodule:: pymatgen.util.testing.aims + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/modules.html b/docs/modules.html index 97adb20638f..d6d5696b2a3 100644 --- a/docs/modules.html +++ b/docs/modules.html @@ -4,7 +4,7 @@ - pymatgen — pymatgen 2024.1.27 documentation + pymatgen — pymatgen 2024.3.1 documentation @@ -17,7 +17,7 @@ - + @@ -37,7 +37,7 @@
- 2024.1.27 + 2024.3.1
@@ -2784,6 +2784,7 @@

pymatgen
  • SiteCollection.n_elems
  • SiteCollection.ntypesp
  • SiteCollection.num_sites
  • +
  • SiteCollection.reduced_formula
  • SiteCollection.remove_oxidation_states()
  • SiteCollection.remove_site_property()
  • SiteCollection.remove_spin()
  • @@ -3411,8 +3412,10 @@

    pymatgen
  • Entry.elements
  • Entry.energy
  • Entry.energy_per_atom
  • +
  • Entry.formula
  • Entry.is_element
  • Entry.normalize()
  • +
  • Entry.reduced_formula
  • Submodules
  • @@ -3722,6 +3725,10 @@

    pymatgen
  • pymatgen.io.aims package
  • +
  • pymatgen.io.icet module +
  • pymatgen.io.jarvis module
  • +
  • find_codopant()
  • pymatgen.transformations.site_transformations module
      @@ -5307,6 +5336,28 @@

      pymatgen

  • pymatgen.util package
  • Submodules
  • -
  • pymatgen.dao module
  • +
  • pymatgen.dao module
  • diff --git a/docs/pymatgen.alchemy.html b/docs/pymatgen.alchemy.html index c3d43ab1e1e..31e8b8c3389 100644 --- a/docs/pymatgen.alchemy.html +++ b/docs/pymatgen.alchemy.html @@ -4,7 +4,7 @@ - pymatgen.alchemy package — pymatgen 2024.1.27 documentation + pymatgen.alchemy package — pymatgen 2024.3.1 documentation @@ -17,7 +17,7 @@ - + @@ -37,7 +37,7 @@
    - 2024.1.27 + 2024.3.1
    @@ -175,14 +175,14 @@

    Submodules
    -class AbstractStructureFilter[source]๏ƒ
    -

    Bases: MSONable

    +class AbstractStructureFilter[source]๏ƒ +

    Bases: MSONable, ABC

    AbstractStructureFilter that defines an API to perform testing of Structures. Structures that return True to a test are retained during transmutation while those that return False are removed.

    -abstract test(structure: Structure)[source]๏ƒ
    +abstract test(structure: Structure)[source]๏ƒ

    Method to execute the test.

    Parameters:
    @@ -205,7 +205,7 @@

    Submodules
    -class ChargeBalanceFilter[source]๏ƒ
    +class ChargeBalanceFilter[source]๏ƒ

    Bases: AbstractStructureFilter

    This filter removes structures that are not charge balanced from the transmuter. This only works if the structure is oxidation state @@ -214,7 +214,7 @@

    Submodules
    -test(structure: Structure)[source]๏ƒ
    +test(structure: Structure)[source]๏ƒ

    Method to execute the test.

    Parameters:
    @@ -233,7 +233,7 @@

    Submodules
    -class ContainsSpecieFilter(species, strict_compare=False, AND=True, exclude=False)[source]๏ƒ
    +class ContainsSpecieFilter(species, strict_compare=False, AND=True, exclude=False)[source]๏ƒ

    Bases: AbstractStructureFilter

    Filter for structures containing certain elements or species. By default compares by atomic number.

    @@ -251,13 +251,13 @@

    Submodules
    -as_dict()[source]๏ƒ
    +as_dict()[source]๏ƒ

    Returns: MSONable dict.

    -classmethod from_dict(dct)[source]๏ƒ
    +classmethod from_dict(dct: dict) Self[source]๏ƒ
    Parameters:

    dct (dict) โ€“ Dict representation.

    @@ -270,7 +270,7 @@

    Submodules
    -test(structure: Structure)[source]๏ƒ
    +test(structure: Structure)[source]๏ƒ

    Method to execute the test.

    Returns:
    @@ -286,7 +286,7 @@

    Submodules
    -class RemoveDuplicatesFilter(structure_matcher: dict | StructureMatcher | None = None, symprec: float | None = None)[source]๏ƒ
    +class RemoveDuplicatesFilter(structure_matcher: dict | StructureMatcher | None = None, symprec: float | None = None)[source]๏ƒ

    Bases: AbstractStructureFilter

    This filter removes exact duplicate structures from the transmuter.

    Remove duplicate structures based on the structure matcher @@ -304,7 +304,7 @@

    Submodules
    -test(structure: Structure) bool[source]๏ƒ
    +test(structure: Structure) bool[source]๏ƒ
    Parameters:

    structure (Structure) โ€“ Input structure to test.

    @@ -322,7 +322,7 @@

    Submodules
    -class RemoveExistingFilter(existing_structures, structure_matcher=None, symprec=None)[source]๏ƒ
    +class RemoveExistingFilter(existing_structures, structure_matcher=None, symprec=None)[source]๏ƒ

    Bases: AbstractStructureFilter

    This filter removes structures existing in a given list from the transmuter.

    Remove existing structures based on the structure matcher @@ -341,13 +341,13 @@

    Submodules
    -as_dict()[source]๏ƒ
    +as_dict()[source]๏ƒ

    Returns: MSONable dict.

    -test(structure: Structure)[source]๏ƒ
    +test(structure: Structure)[source]๏ƒ

    Method to execute the test.

    Parameters:
    @@ -366,7 +366,7 @@

    Submodules
    -class SpecieProximityFilter(specie_and_min_dist_dict)[source]๏ƒ
    +class SpecieProximityFilter(specie_and_min_dist_dict)[source]๏ƒ

    Bases: AbstractStructureFilter

    This filter removes structures that have certain species that are too close together.

    @@ -382,13 +382,13 @@

    Submodules
    -as_dict()[source]๏ƒ
    +as_dict()[source]๏ƒ

    Returns: MSONable dict.

    -classmethod from_dict(dct)[source]๏ƒ
    +classmethod from_dict(dct: dict) Self[source]๏ƒ
    Parameters:

    dct (dict) โ€“ Dict representation.

    @@ -401,7 +401,7 @@

    Submodules
    -test(structure: Structure)[source]๏ƒ
    +test(structure: Structure)[source]๏ƒ

    Method to execute the test.

    Parameters:
    @@ -420,7 +420,7 @@

    Submodules
    -class SpeciesMaxDistFilter(sp1, sp2, max_dist)[source]๏ƒ
    +class SpeciesMaxDistFilter(sp1, sp2, max_dist)[source]๏ƒ

    Bases: AbstractStructureFilter

    This filter removes structures that do have two particular species that are not nearest neighbors by a predefined max_dist. For instance, if you are @@ -440,7 +440,7 @@

    Submodules
    -test(structure: Structure)[source]๏ƒ
    +test(structure: Structure)[source]๏ƒ

    Method to execute the test.

    Parameters:
    @@ -469,7 +469,7 @@

    Submodules
    -class TransformedStructure(structure: Structure, transformations: list[AbstractTransformation] | None = None, history: list[AbstractTransformation | dict[str, Any]] | None = None, other_parameters: dict[str, Any] | None = None)[source]๏ƒ
    +class TransformedStructure(structure: Structure, transformations: AbstractTransformation | Sequence[AbstractTransformation] | None = None, history: list[AbstractTransformation | dict[str, Any]] | None = None, other_parameters: dict[str, Any] | None = None)[source]๏ƒ

    Bases: MSONable

    Container object for new structures that include history of transformations.

    @@ -480,8 +480,7 @@

    SubmodulesParameters:
    • structure (Structure) โ€“ Input structure

    • -
    • transformations (list[Transformation]) โ€“ List of transformations to -apply.

    • +
    • transformations (list[Transformation]) โ€“ List of transformations to apply.

    • history (list[Transformation]) โ€“ Previous history.

    • other_parameters (dict) โ€“ Additional parameters to be added.

    @@ -489,7 +488,7 @@

    Submodules
    -append_filter(structure_filter: AbstractStructureFilter) None[source]๏ƒ
    +append_filter(structure_filter: AbstractStructureFilter) None[source]๏ƒ

    Adds a filter.

    Parameters:
    @@ -501,7 +500,7 @@

    Submodules
    -append_transformation(transformation, return_alternatives: bool = False, clear_redo: bool = True) list[TransformedStructure] | None[source]๏ƒ
    +append_transformation(transformation, return_alternatives: bool = False, clear_redo: bool = True) list[TransformedStructure] | None[source]๏ƒ

    Appends a transformation to the TransformedStructure.

    Parameters:
    @@ -523,13 +522,13 @@

    Submodules
    -as_dict() dict[str, Any][source]๏ƒ
    +as_dict() dict[str, Any][source]๏ƒ

    Dict representation of the TransformedStructure.

    -extend_transformations(transformations: list[AbstractTransformation], return_alternatives: bool = False) None[source]๏ƒ
    +extend_transformations(transformations: list[AbstractTransformation], return_alternatives: bool = False) None[source]๏ƒ

    Extends a sequence of transformations to the TransformedStructure.

    Parameters:
    @@ -546,7 +545,7 @@

    Submodules
    -classmethod from_cif_str(cif_string: str, transformations: list[AbstractTransformation] | None = None, primitive: bool = True, occupancy_tolerance: float = 1.0) TransformedStructure[source]๏ƒ
    +classmethod from_cif_str(cif_string: str, transformations: list[AbstractTransformation] | None = None, primitive: bool = True, occupancy_tolerance: float = 1.0) TransformedStructure[source]๏ƒ

    Generates TransformedStructure from a cif string.

    Parameters:
    @@ -574,13 +573,13 @@

    Submodules
    -classmethod from_dict(dct) TransformedStructure[source]๏ƒ
    +classmethod from_dict(dct: dict) TransformedStructure[source]๏ƒ

    Creates a TransformedStructure from a dict.

    -classmethod from_poscar_str(poscar_string: str, transformations: list[AbstractTransformation] | None = None) TransformedStructure[source]๏ƒ
    +classmethod from_poscar_str(poscar_string: str, transformations: list[AbstractTransformation] | None = None) TransformedStructure[source]๏ƒ

    Generates TransformedStructure from a poscar string.

    Parameters:
    @@ -595,7 +594,7 @@

    Submodules
    -classmethod from_snl(snl: StructureNL) TransformedStructure[source]๏ƒ
    +classmethod from_snl(snl: StructureNL) TransformedStructure[source]๏ƒ

    Create TransformedStructure from SNL.

    Parameters:
    @@ -609,7 +608,7 @@

    Submodules
    -get_vasp_input(vasp_input_set: type[~pymatgen.io.vasp.sets.VaspInputSet] = <class 'pymatgen.io.vasp.sets.MPRelaxSet'>, **kwargs) dict[str, Any][source]๏ƒ
    +get_vasp_input(vasp_input_set: type[~pymatgen.io.vasp.sets.VaspInputSet] = <class 'pymatgen.io.vasp.sets.MPRelaxSet'>, **kwargs) dict[str, Any][source]๏ƒ

    Returns VASP input as a dict of VASP objects.

    Parameters:
    @@ -624,7 +623,7 @@

    Submodules
    -redo_next_change() None[source]๏ƒ
    +redo_next_change() None[source]๏ƒ

    Redo the last undone change in the TransformedStructure.

    Raises:
    @@ -635,7 +634,7 @@

    Submodules
    -set_parameter(key: str, value: Any) None[source]๏ƒ
    +set_parameter(key: str, value: Any) TransformedStructure[source]๏ƒ

    Sets a parameter.

    Parameters:
    @@ -644,37 +643,42 @@

    SubmodulesReturns: +

    TransformedStructure

    +

    -property structures: list[Structure][source]๏ƒ
    +property structures: list[Structure][source]๏ƒ

    Copy of all structures in the TransformedStructure. A structure is stored after every single transformation.

    -to_snl(authors, **kwargs) StructureNL[source]๏ƒ
    -

    Generate SNL from TransformedStructure.

    +to_snl(authors: list[str], **kwargs) StructureNL[source]๏ƒ +

    Generate a StructureNL from TransformedStructure.

    Parameters:
      -
    • authors โ€“ List of authors

    • -
    • **kwargs โ€“

      All kwargs supported by StructureNL.

      -

    • +
    • authors (List[str]) โ€“ List of authors contributing to the generated StructureNL.

    • +
    • **kwargs (Any) โ€“ All kwargs supported by StructureNL.

    Returns:
    -

    StructureNL

    +

    The generated StructureNL object.

    +
    +
    Return type:
    +

    StructureNL

    -undo_last_change() None[source]๏ƒ
    +undo_last_change() None[source]๏ƒ

    Undo the last change in the TransformedStructure.

    Raises:
    @@ -685,7 +689,7 @@

    Submodules
    -property was_modified: bool[source]๏ƒ
    +property was_modified: bool[source]๏ƒ

    Boolean describing whether the last transformation on the structure made any alterations to it one example of when this would return false is in the case of performing a substitution transformation on the @@ -694,7 +698,7 @@

    Submodules
    -write_vasp_input(vasp_input_set: type[~pymatgen.io.vasp.sets.VaspInputSet] = <class 'pymatgen.io.vasp.sets.MPRelaxSet'>, output_dir: str = '.', create_directory: bool = True, **kwargs) None[source]๏ƒ
    +write_vasp_input(vasp_input_set: type[~pymatgen.io.vasp.sets.VaspInputSet] = <class 'pymatgen.io.vasp.sets.MPRelaxSet'>, output_dir: str = '.', create_directory: bool = True, **kwargs) None[source]๏ƒ

    Writes VASP input to an output_dir.

    Parameters:
    @@ -723,7 +727,7 @@

    Submodules
    -class CifTransmuter(cif_string, transformations=None, primitive=True, extend_collection=False)[source]๏ƒ
    +class CifTransmuter(cif_string, transformations=None, primitive=True, extend_collection=False)[source]๏ƒ

    Bases: StandardTransmuter

    Generates a Transmuter from a cif string, possibly containing multiple structures.

    @@ -745,7 +749,7 @@

    Submodules
    -classmethod from_filenames(filenames, transformations=None, primitive=True, extend_collection=False)[source]๏ƒ
    +classmethod from_filenames(filenames, transformations=None, primitive=True, extend_collection=False)[source]๏ƒ

    Generates a TransformedStructureCollection from a cif, possibly containing multiple structures.

    @@ -765,7 +769,7 @@

    Submodules
    -class PoscarTransmuter(poscar_string, transformations=None, extend_collection=False)[source]๏ƒ
    +class PoscarTransmuter(poscar_string, transformations=None, extend_collection=False)[source]๏ƒ

    Bases: StandardTransmuter

    Generates a transmuter from a sequence of POSCARs.

    @@ -781,7 +785,7 @@

    Submodules
    -static from_filenames(poscar_filenames, transformations=None, extend_collection=False)[source]๏ƒ
    +static from_filenames(poscar_filenames, transformations=None, extend_collection=False)[source]๏ƒ

    Convenient constructor to generates a POSCAR transmuter from a list of POSCAR filenames.

    @@ -800,13 +804,13 @@

    Submodules
    -class StandardTransmuter(transformed_structures, transformations=None, extend_collection: int = 0, ncores: int | None = None)[source]๏ƒ
    +class StandardTransmuter(transformed_structures, transformations=None, extend_collection: int = 0, ncores: int | None = None)[source]๏ƒ

    Bases: object

    An example of a Transmuter object, which performs a sequence of transformations on many structures to generate TransformedStructures.

    -transformed_structures[source]๏ƒ
    +transformed_structures[source]๏ƒ

    List of all transformed structures.

    Type:
    @@ -836,7 +840,7 @@

    Submodules
    -add_tags(tags)[source]๏ƒ
    +add_tags(tags)[source]๏ƒ

    Add tags for the structures generated by the transmuter.

    Parameters:
    @@ -848,7 +852,7 @@

    Submodules
    -append_transformation(transformation, extend_collection=False, clear_redo=True)[source]๏ƒ
    +append_transformation(transformation, extend_collection=False, clear_redo=True)[source]๏ƒ

    Appends a transformation to all TransformedStructures.

    Parameters:
    @@ -878,7 +882,7 @@

    Submodules
    -append_transformed_structures(trafo_structs_or_transmuter)[source]๏ƒ
    +append_transformed_structures(trafo_structs_or_transmuter)[source]๏ƒ

    Method is overloaded to accept either a list of transformed structures or transmuter, it which case it appends the second transmuterโ€s structures.

    @@ -892,7 +896,7 @@

    Submodules
    -apply_filter(structure_filter)[source]๏ƒ
    +apply_filter(structure_filter)[source]๏ƒ

    Applies a structure_filter to the list of TransformedStructures in the transmuter.

    @@ -904,7 +908,7 @@

    Submodules
    -extend_transformations(transformations)[source]๏ƒ
    +extend_transformations(transformations)[source]๏ƒ

    Extends a sequence of transformations to the TransformedStructure.

    Parameters:
    @@ -915,7 +919,7 @@

    Submodules
    -classmethod from_structures(structures, transformations=None, extend_collection=0)[source]๏ƒ
    +classmethod from_structures(structures, transformations=None, extend_collection=0)[source]๏ƒ

    Alternative constructor from structures rather than TransformedStructures.

    @@ -938,7 +942,7 @@

    Submodules
    -redo_next_change()[source]๏ƒ
    +redo_next_change()[source]๏ƒ

    Redo the last undone transformation in the TransformedStructure.

    Raises:
    @@ -949,7 +953,7 @@

    Submodules
    -set_parameter(key, value)[source]๏ƒ
    +set_parameter(key, value)[source]๏ƒ

    Add parameters to the transmuter. Additional parameters are stored in the as_dict() output.

    @@ -964,7 +968,7 @@

    Submodules
    -undo_last_change()[source]๏ƒ
    +undo_last_change()[source]๏ƒ

    Undo the last transformation in the TransformedStructure.

    Raises:
    @@ -975,7 +979,7 @@

    Submodules
    -write_vasp_input(**kwargs)[source]๏ƒ
    +write_vasp_input(**kwargs)[source]๏ƒ

    Batch write vasp input for a sequence of transformed structures to output_dir, following the format output_dir/{formula}_{number}.

    @@ -989,7 +993,7 @@

    Submodules
    -batch_write_vasp_input(transformed_structures: Sequence[TransformedStructure], vasp_input_set: type[VaspInputSet] = <class 'pymatgen.io.vasp.sets.MPRelaxSet'>, output_dir: str = '.', create_directory: bool = True, subfolder: Callable[[TransformedStructure], str] | None = None, include_cif: bool = False, **kwargs)[source]๏ƒ
    +batch_write_vasp_input(transformed_structures: Sequence[TransformedStructure], vasp_input_set: type[VaspInputSet] = <class 'pymatgen.io.vasp.sets.MPRelaxSet'>, output_dir: str = '.', create_directory: bool = True, subfolder: Callable[[TransformedStructure], str] | None = None, include_cif: bool = False, **kwargs)[source]๏ƒ

    Batch write vasp input for a sequence of transformed structures to output_dir, following the format output_dir/{group}/{formula}_{number}.

    diff --git a/docs/pymatgen.analysis.chemenv.connectivity.html b/docs/pymatgen.analysis.chemenv.connectivity.html index 923a0b85d0a..45d3e8b8ab2 100644 --- a/docs/pymatgen.analysis.chemenv.connectivity.html +++ b/docs/pymatgen.analysis.chemenv.connectivity.html @@ -4,7 +4,7 @@ - pymatgen.analysis.chemenv.connectivity package — pymatgen 2024.1.27 documentation + pymatgen.analysis.chemenv.connectivity package — pymatgen 2024.3.1 documentation @@ -17,7 +17,7 @@ - + @@ -37,7 +37,7 @@
    - 2024.1.27 + 2024.3.1
    @@ -175,7 +175,7 @@

    Submodules
    -class ConnectedComponent(environments=None, links=None, environments_data=None, links_data=None, graph=None)[source]๏ƒ
    +class ConnectedComponent(environments=None, links=None, environments_data=None, links_data=None, graph=None)[source]๏ƒ

    Bases: MSONable

    Class used to describe the connected components in a structure in terms of coordination environments.

    Constructor for the ConnectedComponent object.

    @@ -198,7 +198,7 @@

    Submodules
    -as_dict()[source]๏ƒ
    +as_dict()[source]๏ƒ

    Bson-serializable dict representation of the ConnectedComponent object.

    Returns:
    @@ -212,7 +212,7 @@

    Submodules
    -compute_periodicity(algorithm='all_simple_paths') None[source]๏ƒ
    +compute_periodicity(algorithm='all_simple_paths') None[source]๏ƒ
    Parameters:

    () (algorithm) โ€“

    @@ -222,19 +222,19 @@

    Submodules
    -compute_periodicity_all_simple_paths_algorithm()[source]๏ƒ
    +compute_periodicity_all_simple_paths_algorithm()[source]๏ƒ

    Get the periodicity vectors of the connected component.

    -compute_periodicity_cycle_basis() None[source]๏ƒ
    +compute_periodicity_cycle_basis() None[source]๏ƒ

    Compute periodicity vectors of the connected component.

    -coordination_sequence(source_node, path_size=5, coordination='number', include_source=False)[source]๏ƒ
    +coordination_sequence(source_node, path_size=5, coordination='number', include_source=False)[source]๏ƒ

    Get the coordination sequence for a given node.

    Parameters:
    @@ -284,7 +284,7 @@

    Submodules
    -description(full=False)[source]๏ƒ
    +description(full=False)[source]๏ƒ
    Parameters:

    full (bool) โ€“ Whether to return a short or full description.

    @@ -300,7 +300,7 @@

    Submodules
    -elastic_centered_graph(start_node=None)[source]๏ƒ
    +elastic_centered_graph(start_node=None)[source]๏ƒ
    Parameters:

    () (start_node) โ€“

    @@ -316,12 +316,12 @@

    Submodules
    -classmethod from_dict(d)[source]๏ƒ
    +classmethod from_dict(dct: dict) Self[source]๏ƒ

    Reconstructs the ConnectedComponent object from a dict representation of the ConnectedComponent object created using the as_dict method.

    Parameters:
    -

    d (dict) โ€“ dict representation of the ConnectedComponent object

    +

    dct (dict) โ€“ dict representation of the ConnectedComponent object

    Returns:

    The connected component representing the links of a given set of environments.

    @@ -334,7 +334,7 @@

    Submodules
    -classmethod from_graph(g)[source]๏ƒ
    +classmethod from_graph(g)[source]๏ƒ

    Constructor for the ConnectedComponent object from a graph of the connected component.

    Parameters:
    @@ -351,7 +351,7 @@

    Submodules
    -property graph[source]๏ƒ
    +property graph[source]๏ƒ

    Return the graph of this connected component.

    Returns:
    @@ -369,37 +369,37 @@

    Submodules
    -property is_0d: bool[source]๏ƒ
    +property is_0d: bool[source]๏ƒ

    Whether this connected component is 0-dimensional.

    -property is_1d: bool[source]๏ƒ
    +property is_1d: bool[source]๏ƒ

    Whether this connected component is 1-dimensional.

    -property is_2d: bool[source]๏ƒ
    +property is_2d: bool[source]๏ƒ

    Whether this connected component is 2-dimensional.

    -property is_3d: bool[source]๏ƒ
    +property is_3d: bool[source]๏ƒ

    Whether this connected component is 3-dimensional.

    -property is_periodic: bool[source]๏ƒ
    +property is_periodic: bool[source]๏ƒ

    Whether this connected component is periodic.

    -make_supergraph(multiplicity)[source]๏ƒ
    +make_supergraph(multiplicity)[source]๏ƒ
    Parameters:

    () (multiplicity) โ€“

    @@ -415,19 +415,19 @@

    Submodules
    -property periodicity[source]๏ƒ
    +property periodicity[source]๏ƒ

    Get periodicity of this connected component.

    -property periodicity_vectors[source]๏ƒ
    +property periodicity_vectors[source]๏ƒ

    Get periodicity vectors of this connected component.

    -show_graph(graph: MultiGraph | None = None, save_file: str | None = None, drawing_type: str = 'internal') None[source]๏ƒ
    +show_graph(graph: MultiGraph | None = None, save_file: str | None = None, drawing_type: str = 'internal') None[source]๏ƒ

    Displays the graph using the specified drawing type.

    Parameters:
    @@ -445,7 +445,7 @@

    Submodules
    -draw_network(env_graph, pos, ax, sg=None, periodicity_vectors=None)[source]๏ƒ
    +draw_network(env_graph, pos, ax, sg=None, periodicity_vectors=None)[source]๏ƒ

    Draw network of environments in a matplotlib figure axes.

    Parameters:
    @@ -462,7 +462,7 @@

    Submodules
    -make_supergraph(graph, multiplicity, periodicity_vectors)[source]๏ƒ
    +make_supergraph(graph, multiplicity, periodicity_vectors)[source]๏ƒ

    Make super graph from a graph of environments.

    Parameters:
    @@ -487,31 +487,33 @@

    Submodules
    -class ConnectivityFinder(multiple_environments_choice=None)[source]๏ƒ
    +class ConnectivityFinder(multiple_environments_choice=None)[source]๏ƒ

    Bases: object

    Main class used to find the structure connectivity of a structure.

    Constructor for the ConnectivityFinder.

    Parameters:
    -

    multiple_environments_choice โ€“ defines the procedure to apply when

    +
      +
    • multiple_environments_choice โ€“ defines the procedure to apply when

    • +
    • one (the environment of a given site is described as a "mix" of more than) โ€“

    • +
    • environments. (coordination) โ€“

    • +
    -

    the environment of a given site is described as a โ€œmixโ€ of more than one -coordination environments.

    -get_structure_connectivity(light_structure_environments)[source]๏ƒ
    +get_structure_connectivity(light_structure_environments)[source]๏ƒ

    Get the structure connectivity from the coordination environments provided as an input.

    Parameters:
    -

    light_structure_environments โ€“ LightStructureEnvironments with the

    +
      +
    • light_structure_environments โ€“ LightStructureEnvironments with the

    • +
    • structure (relevant coordination environments in the) โ€“

    • +
    -
    -

    relevant coordination environments in the structure

    -
    -
    Returns:
    -

    a StructureConnectivity object describing the connectivity of

    +
    Returns:
    +

    a StructureConnectivity object describing the connectivity of

    the environments in the structure

    @@ -519,7 +521,7 @@

    Submodules
    -setup_parameters(multiple_environments_choice)[source]๏ƒ
    +setup_parameters(multiple_environments_choice)[source]๏ƒ

    Setup of the parameters for the connectivity finder.

    @@ -531,7 +533,7 @@

    Submodules
    -class AbstractEnvironmentNode(central_site, i_central_site)[source]๏ƒ
    +class AbstractEnvironmentNode(central_site, i_central_site)[source]๏ƒ

    Bases: MSONable

    Abstract class used to define an environment as a node in a graph.

    Constructor for the AbstractEnvironmentNode object.

    @@ -546,104 +548,104 @@

    Submodules
    -ATOM = 6[source]๏ƒ
    +ATOM = 6[source]๏ƒ

    -CE_NNBCES_NBCES_LIGANDS = -1[source]๏ƒ
    +CE_NNBCES_NBCES_LIGANDS = -1[source]๏ƒ
    -COORDINATION_ENVIRONMENT = 0[source]๏ƒ
    +COORDINATION_ENVIRONMENT = 0[source]๏ƒ
    -DEFAULT_EXTENSIONS = (6, 0)[source]๏ƒ
    +DEFAULT_EXTENSIONS = (6, 0)[source]๏ƒ
    -LIGANDS_ARRANGEMENT = 4[source]๏ƒ
    +LIGANDS_ARRANGEMENT = 4[source]๏ƒ
    -NEIGHBORING_CES = 2[source]๏ƒ
    +NEIGHBORING_CES = 2[source]๏ƒ
    -NEIGHBORING_COORDINATION_ENVIRONMENTS = 2[source]๏ƒ
    +NEIGHBORING_COORDINATION_ENVIRONMENTS = 2[source]๏ƒ
    -NEIGHBORS_LIGANDS_ARRANGEMENT = 5[source]๏ƒ
    +NEIGHBORS_LIGANDS_ARRANGEMENT = 5[source]๏ƒ
    -NUMBER_OF_LIGANDS_FOR_EACH_NEIGHBORING_CE = 3[source]๏ƒ
    +NUMBER_OF_LIGANDS_FOR_EACH_NEIGHBORING_CE = 3[source]๏ƒ
    -NUMBER_OF_LIGANDS_FOR_EACH_NEIGHBORING_COORDINATION_ENVIRONMENT = 3[source]๏ƒ
    +NUMBER_OF_LIGANDS_FOR_EACH_NEIGHBORING_COORDINATION_ENVIRONMENT = 3[source]๏ƒ
    -NUMBER_OF_NEIGHBORING_CES = 1[source]๏ƒ
    +NUMBER_OF_NEIGHBORING_CES = 1[source]๏ƒ
    -NUMBER_OF_NEIGHBORING_COORDINATION_ENVIRONMENTS = 1[source]๏ƒ
    +NUMBER_OF_NEIGHBORING_COORDINATION_ENVIRONMENTS = 1[source]๏ƒ
    -property atom_symbol[source]๏ƒ
    +property atom_symbol[source]๏ƒ

    Symbol of the atom on the central site.

    -property ce[source]๏ƒ
    +property ce[source]๏ƒ

    Coordination environment of this node.

    -property ce_symbol[source]๏ƒ
    +property ce_symbol[source]๏ƒ

    Coordination environment of this node.

    -abstract property coordination_environment[source]๏ƒ
    +abstract property coordination_environment[source]๏ƒ

    Coordination environment of this node.

    -everything_equal(other)[source]๏ƒ
    +everything_equal(other)[source]๏ƒ

    Checks equality with respect to another AbstractEnvironmentNode using the index of the central site as well as the central site itself.

    -property isite[source]๏ƒ
    +property isite[source]๏ƒ

    Index of the central site.

    -property mp_symbol[source]๏ƒ
    +property mp_symbol[source]๏ƒ

    Coordination environment of this node.

    @@ -651,7 +653,7 @@

    Submodules
    -class EnvironmentNode(central_site, i_central_site, ce_symbol)[source]๏ƒ
    +class EnvironmentNode(central_site, i_central_site, ce_symbol)[source]๏ƒ

    Bases: AbstractEnvironmentNode

    Class used to define an environment as a node in a graph.

    Constructor for the EnvironmentNode object.

    @@ -667,13 +669,13 @@

    Submodules
    -property coordination_environment[source]๏ƒ
    +property coordination_environment[source]๏ƒ

    Coordination environment of this node.

    -everything_equal(other)[source]๏ƒ
    +everything_equal(other)[source]๏ƒ

    Compare with another environment node.

    Returns:
    @@ -689,7 +691,7 @@

    Submodules
    -get_environment_node(central_site, i_central_site, ce_symbol)[source]๏ƒ
    +get_environment_node(central_site, i_central_site, ce_symbol)[source]๏ƒ

    Get the EnvironmentNode class or subclass for the given site and symbol.

    Parameters:
    @@ -711,7 +713,7 @@

    Submodules
    -class StructureConnectivity(light_structure_environment, connectivity_graph=None, environment_subgraphs=None)[source]๏ƒ
    +class StructureConnectivity(light_structure_environment, connectivity_graph=None, environment_subgraphs=None)[source]๏ƒ

    Bases: MSONable

    Main class containing the connectivity of a structure.

    Constructor for the StructureConnectivity object.

    @@ -732,7 +734,7 @@

    Submodules
    -add_bonds(isite, site_neighbors_set)[source]๏ƒ
    +add_bonds(isite, site_neighbors_set)[source]๏ƒ

    Add the bonds for a given site index to the structure connectivity graph.

    Parameters:
    @@ -746,19 +748,19 @@

    Submodules
    -add_sites()[source]๏ƒ
    +add_sites()[source]๏ƒ

    Add the sites in the structure connectivity graph.

    -as_dict()[source]๏ƒ
    +as_dict()[source]๏ƒ

    Convert to MSONable dict.

    -environment_subgraph(environments_symbols=None, only_atoms=None)[source]๏ƒ
    +environment_subgraph(environments_symbols=None, only_atoms=None)[source]๏ƒ
    Parameters:

    -

    - - - - - - \ No newline at end of file + + + diff --git a/pymatgen/alchemy/filters.py b/pymatgen/alchemy/filters.py index 6d060675acc..3f1bf77091e 100644 --- a/pymatgen/alchemy/filters.py +++ b/pymatgen/alchemy/filters.py @@ -13,6 +13,8 @@ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer if TYPE_CHECKING: + from typing_extensions import Self + from pymatgen.core import Structure @@ -92,7 +94,7 @@ def __repr__(self): ] ) - def as_dict(self): + def as_dict(self) -> dict: """Returns: MSONable dict.""" return { "@module": type(self).__module__, @@ -106,7 +108,7 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """ Args: dct (dict): Dict representation. @@ -165,7 +167,7 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """ Args: dct (dict): Dict representation. diff --git a/pymatgen/alchemy/materials.py b/pymatgen/alchemy/materials.py index f3f5e1efeff..5be0288de4a 100644 --- a/pymatgen/alchemy/materials.py +++ b/pymatgen/alchemy/materials.py @@ -23,6 +23,8 @@ if TYPE_CHECKING: from collections.abc import Sequence + from typing_extensions import Self + from pymatgen.alchemy.filters import AbstractStructureFilter @@ -212,7 +214,7 @@ def write_vasp_input( **kwargs: All keyword args supported by the VASP input set. """ vasp_input_set(self.final_structure, **kwargs).write_input(output_dir, make_dir_if_not_present=create_directory) - with open(f"{output_dir}/transformations.json", mode="w") as file: + with open(f"{output_dir}/transformations.json", mode="w", encoding="utf-8") as file: json.dump(self.as_dict(), file) def __str__(self) -> str: @@ -267,7 +269,7 @@ def from_cif_str( transformations: list[AbstractTransformation] | None = None, primitive: bool = True, occupancy_tolerance: float = 1.0, - ) -> TransformedStructure: + ) -> Self: """Generates TransformedStructure from a cif string. Args: @@ -311,7 +313,7 @@ def from_poscar_str( cls, poscar_string: str, transformations: list[AbstractTransformation] | None = None, - ) -> TransformedStructure: + ) -> Self: """Generates TransformedStructure from a poscar string. Args: @@ -339,24 +341,26 @@ def as_dict(self) -> dict[str, Any]: dct["@module"] = type(self).__module__ dct["@class"] = type(self).__name__ dct["history"] = jsanitize(self.history) - dct["last_modified"] = str(datetime.datetime.utcnow()) + dct["last_modified"] = str(datetime.datetime.now(datetime.timezone.utc)) dct["other_parameters"] = jsanitize(self.other_parameters) return dct @classmethod - def from_dict(cls, dct) -> TransformedStructure: + def from_dict(cls, dct: dict) -> Self: """Creates a TransformedStructure from a dict.""" struct = Structure.from_dict(dct) return cls(struct, history=dct["history"], other_parameters=dct.get("other_parameters")) - def to_snl(self, authors, **kwargs) -> StructureNL: - """Generate SNL from TransformedStructure. + def to_snl(self, authors: list[str], **kwargs) -> StructureNL: + """ + Generate a StructureNL from TransformedStructure. - :param authors: List of authors - :param **kwargs: All kwargs supported by StructureNL. + Args: + authors (List[str]): List of authors contributing to the generated StructureNL. + **kwargs (Any): All kwargs supported by StructureNL. Returns: - StructureNL + StructureNL: The generated StructureNL object. """ if self.other_parameters: warn("Data in TransformedStructure.other_parameters discarded during type conversion to SNL") @@ -374,7 +378,7 @@ def to_snl(self, authors, **kwargs) -> StructureNL: return StructureNL(self.final_structure, authors, history=history, **kwargs) @classmethod - def from_snl(cls, snl: StructureNL) -> TransformedStructure: + def from_snl(cls, snl: StructureNL) -> Self: """Create TransformedStructure from SNL. Args: diff --git a/pymatgen/alchemy/transmuters.py b/pymatgen/alchemy/transmuters.py index 1193c83caa1..e4b4b4e1ffd 100644 --- a/pymatgen/alchemy/transmuters.py +++ b/pymatgen/alchemy/transmuters.py @@ -20,6 +20,8 @@ if TYPE_CHECKING: from collections.abc import Sequence + from typing_extensions import Self + __author__ = "Shyue Ping Ong, Will Richards" __copyright__ = "Copyright 2012, The Materials Project" __version__ = "0.1" @@ -42,7 +44,7 @@ def __init__( transformations=None, extend_collection: int = 0, ncores: int | None = None, - ): + ) -> None: """Initializes a transmuter from an initial list of pymatgen.alchemy.materials.TransformedStructure. @@ -71,7 +73,16 @@ def __getitem__(self, index): def __getattr__(self, name): return [getattr(x, name) for x in self.transformed_structures] - def undo_last_change(self): + def __len__(self): + return len(self.transformed_structures) + + def __str__(self): + output = ["Current structures", "------------"] + for x in self.transformed_structures: + output.append(str(x.final_structure)) + return "\n".join(output) + + def undo_last_change(self) -> None: """Undo the last transformation in the TransformedStructure. Raises: @@ -80,7 +91,7 @@ def undo_last_change(self): for x in self.transformed_structures: x.undo_last_change() - def redo_next_change(self): + def redo_next_change(self) -> None: """Redo the last undone transformation in the TransformedStructure. Raises: @@ -89,9 +100,6 @@ def redo_next_change(self): for x in self.transformed_structures: x.redo_next_change() - def __len__(self): - return len(self.transformed_structures) - def append_transformation(self, transformation, extend_collection=False, clear_redo=True): """Appends a transformation to all TransformedStructures. @@ -178,12 +186,6 @@ def add_tags(self, tags): """ self.set_parameter("tags", tags) - def __str__(self): - output = ["Current structures", "------------"] - for x in self.transformed_structures: - output.append(str(x.final_structure)) - return "\n".join(output) - def append_transformed_structures(self, trafo_structs_or_transmuter): """Method is overloaded to accept either a list of transformed structures or transmuter, it which case it appends the second transmuter"s @@ -201,7 +203,7 @@ def append_transformed_structures(self, trafo_structs_or_transmuter): self.transformed_structures.extend(trafo_structs_or_transmuter) @classmethod - def from_structures(cls, structures, transformations=None, extend_collection=0): + def from_structures(cls, structures, transformations=None, extend_collection=0) -> Self: """Alternative constructor from structures rather than TransformedStructures. @@ -256,7 +258,7 @@ def __init__(self, cif_string, transformations=None, primitive=True, extend_coll super().__init__(transformed_structures, transformations, extend_collection) @classmethod - def from_filenames(cls, filenames, transformations=None, primitive=True, extend_collection=False): + def from_filenames(cls, filenames, transformations=None, primitive=True, extend_collection=False) -> Self: """Generates a TransformedStructureCollection from a cif, possibly containing multiple structures. @@ -269,7 +271,7 @@ def from_filenames(cls, filenames, transformations=None, primitive=True, extend_ """ cif_files = [] for filename in filenames: - with open(filename) as file: + with open(filename, encoding="utf-8") as file: cif_files.append(file.read()) return cls( "\n".join(cif_files), @@ -294,8 +296,8 @@ def __init__(self, poscar_string, transformations=None, extend_collection=False) trafo_struct = TransformedStructure.from_poscar_str(poscar_string, []) super().__init__([trafo_struct], transformations, extend_collection=extend_collection) - @staticmethod - def from_filenames(poscar_filenames, transformations=None, extend_collection=False): + @classmethod + def from_filenames(cls, poscar_filenames, transformations=None, extend_collection=False) -> StandardTransmuter: """Convenient constructor to generates a POSCAR transmuter from a list of POSCAR filenames. @@ -308,7 +310,7 @@ def from_filenames(poscar_filenames, transformations=None, extend_collection=Fal """ trafo_structs = [] for filename in poscar_filenames: - with open(filename) as file: + with open(filename, encoding="utf-8") as file: trafo_structs.append(TransformedStructure.from_poscar_str(file.read(), [])) return StandardTransmuter(trafo_structs, transformations, extend_collection=extend_collection) diff --git a/pymatgen/analysis/adsorption.py b/pymatgen/analysis/adsorption.py index 44d0c9330cb..527b0304085 100644 --- a/pymatgen/analysis/adsorption.py +++ b/pymatgen/analysis/adsorption.py @@ -24,7 +24,11 @@ from pymatgen.util.coord import in_coord_list_pbc if TYPE_CHECKING: + import matplotlib.pyplot as plt from numpy.typing import ArrayLike + from typing_extensions import Self + + from pymatgen.core.surface import Slab __author__ = "Joseph Montoya" __copyright__ = "Copyright 2016, The Materials Project" @@ -54,7 +58,7 @@ class AdsorbateSiteFinder: """ def __init__( - self, slab, selective_dynamics: bool = False, height: float = 0.9, mi_vec: ArrayLike | None = None + self, slab: Slab, selective_dynamics: bool = False, height: float = 0.9, mi_vec: ArrayLike | None = None ) -> None: """Create an AdsorbateSiteFinder object. @@ -88,7 +92,7 @@ def from_bulk_and_miller( center_slab=True, selective_dynamics=False, undercoord_threshold=0.09, - ): + ) -> Self: """This method constructs the adsorbate site finder from a bulk structure and a miller index, which allows the surface sites to be determined from the difference in bulk and slab coordination, as @@ -132,7 +136,7 @@ def from_bulk_and_miller( vnn_surface = VoronoiNN(tol=0.05, allow_pathological=True) - surf_props, undercoords = [], [] + surf_props, under_coords = [], [] this_mi_vec = get_mi_vec(this_slab) mi_mags = [np.dot(this_mi_vec, site.coords) for site in this_slab] average_mi_mag = np.average(mi_mags) @@ -140,20 +144,20 @@ def from_bulk_and_miller( bulk_coord = this_slab.site_properties["bulk_coordinations"][n] slab_coord = len(vnn_surface.get_nn(this_slab, n)) mi_mag = np.dot(this_mi_vec, site.coords) - undercoord = (bulk_coord - slab_coord) / bulk_coord - undercoords += [undercoord] - if undercoord > undercoord_threshold and mi_mag > average_mi_mag: + under_coord = (bulk_coord - slab_coord) / bulk_coord + under_coords += [under_coord] + if under_coord > undercoord_threshold and mi_mag > average_mi_mag: surf_props += ["surface"] else: surf_props += ["subsurface"] new_site_properties = { "surface_properties": surf_props, - "undercoords": undercoords, + "undercoords": under_coords, } new_slab = this_slab.copy(site_properties=new_site_properties) return cls(new_slab, selective_dynamics) - def find_surface_sites_by_height(self, slab, height=0.9, xy_tol=0.05): + def find_surface_sites_by_height(self, slab: Slab, height=0.9, xy_tol=0.05): """This method finds surface sites by determining which sites are within a threshold value in height from the topmost site in a list of sites. @@ -179,7 +183,8 @@ def find_surface_sites_by_height(self, slab, height=0.9, xy_tol=0.05): # sort surface sites by height surf_sites = [s for (h, s) in zip(m_projs[mask], surf_sites)] surf_sites.reverse() - unique_sites, unique_perp_fracs = [], [] + unique_sites: list = [] + unique_perp_fracs: list = [] for site in surf_sites: this_perp = site.coords - np.dot(site.coords, self.mvec) this_perp_frac = slab.lattice.get_fractional_coords(this_perp) @@ -190,7 +195,7 @@ def find_surface_sites_by_height(self, slab, height=0.9, xy_tol=0.05): return surf_sites - def assign_site_properties(self, slab, height=0.9): + def assign_site_properties(self, slab: Slab, height=0.9): """Assigns site properties.""" if "surface_properties" in slab.site_properties: return slab @@ -260,8 +265,8 @@ def find_adsorption_sites( ads_sites["subsurface"] = ss_sites if "bridge" in positions or "hollow" in positions: mesh = self.get_extended_surface_mesh() - sop = get_rot(self.slab) - dt = Delaunay([sop.operate(m.coords)[:2] for m in mesh]) + symm_op = get_rot(self.slab) + dt = Delaunay([symm_op.operate(m.coords)[:2] for m in mesh]) # TODO: refactor below to properly account for >3-fold for v in dt.simplices: if -1 not in v: @@ -383,8 +388,8 @@ def add_adsorbate(self, molecule: Molecule, ads_coord, repeat=None, translate=Tr molecule.translate_sites(vector=[-x, -y, -z]) if reorient: # Reorient the molecule along slab m_index - sop = get_rot(self.slab) - molecule.apply_operation(sop.inverse) + symm_op = get_rot(self.slab) + molecule.apply_operation(symm_op.inverse) struct = self.slab.copy() if repeat: struct.make_supercell(repeat) @@ -553,20 +558,20 @@ def generate_substitution_structures( sym_slab = SpacegroupAnalyzer(self.slab).get_symmetrized_structure() # Define a function for substituting a site - def substitute(site, i): + def substitute(site, idx): slab = self.slab.copy() props = self.slab.site_properties if sub_both_sides: # Find an equivalent site on the other surface - eq_indices = next(indices for indices in sym_slab.equivalent_indices if i in indices) + eq_indices = next(indices for indices in sym_slab.equivalent_indices if idx in indices) for ii in eq_indices: if f"{sym_slab[ii].frac_coords[2]:.6f}" != f"{site.frac_coords[2]:.6f}": props["surface_properties"][ii] = "substitute" slab.replace(ii, atom) break - props["surface_properties"][i] = "substitute" - slab.replace(i, atom) + props["surface_properties"][idx] = "substitute" + slab.replace(idx, atom) slab.add_site_property("surface_properties", props["surface_properties"]) return slab @@ -598,7 +603,7 @@ def get_mi_vec(slab): return mvec / np.linalg.norm(mvec) -def get_rot(slab): +def get_rot(slab: Slab) -> SymmOp: """Gets the transformation to rotate the z axis into the miller index.""" new_z = get_mi_vec(slab) a, _b, _c = slab.lattice.matrix @@ -621,8 +626,8 @@ def reorient_z(structure): to the A-B plane. """ struct = structure.copy() - sop = get_rot(struct) - struct.apply_operation(sop) + symm_op = get_rot(struct) + struct.apply_operation(symm_op) return struct @@ -632,8 +637,8 @@ def reorient_z(structure): def plot_slab( - slab, - ax, + slab: Slab, + ax: plt.Axes, scale=0.8, repeat=5, window=1.5, @@ -667,8 +672,8 @@ def plot_slab( alphas = alphas.clip(min=0) corner = [0, 0, slab.lattice.get_fractional_coords(coords[-1])[-1]] corner = slab.lattice.get_cartesian_coords(corner)[:2] - verts = orig_cell[:2, :2] - lattsum = verts[0] + verts[1] + vertices = orig_cell[:2, :2] + lattice_sum = vertices[0] + vertices[1] # inverse coords, sites, alphas, to show other side of slab if inverse: alphas = np.array(reversed(alphas)) @@ -676,13 +681,13 @@ def plot_slab( coords = np.array(reversed(coords)) # Draw circles at sites and stack them accordingly for n, coord in enumerate(coords): - r = sites[n].species.elements[0].atomic_radius * scale - ax.add_patch(patches.Circle(coord[:2] - lattsum * (repeat // 2), r, color="w", zorder=2 * n)) + radius = sites[n].species.elements[0].atomic_radius * scale + ax.add_patch(patches.Circle(coord[:2] - lattice_sum * (repeat // 2), radius, color="w", zorder=2 * n)) color = color_dict[sites[n].species.elements[0].symbol] ax.add_patch( patches.Circle( - coord[:2] - lattsum * (repeat // 2), - r, + coord[:2] - lattice_sum * (repeat // 2), + radius, facecolor=color, alpha=alphas[n], edgecolor="k", @@ -698,22 +703,22 @@ def plot_slab( inverse_slab.make_supercell([1, 1, -1]) asf = AdsorbateSiteFinder(inverse_slab) ads_sites = asf.find_adsorption_sites()["all"] - sop = get_rot(orig_slab) - ads_sites = [sop.operate(ads_site)[:2].tolist() for ads_site in ads_sites] + symm_op = get_rot(orig_slab) + ads_sites = [symm_op.operate(ads_site)[:2].tolist() for ads_site in ads_sites] ax.plot(*zip(*ads_sites), color="k", marker="x", markersize=10, mew=1, linestyle="", zorder=10000) # Draw unit cell if draw_unit_cell: - verts = np.insert(verts, 1, lattsum, axis=0).tolist() - verts += [[0.0, 0.0]] - verts = [[0.0, 0.0], *verts] + vertices = np.insert(vertices, 1, lattice_sum, axis=0).tolist() + vertices += [[0.0, 0.0]] + vertices = [[0.0, 0.0], *vertices] codes = [Path.MOVETO, Path.LINETO, Path.LINETO, Path.LINETO, Path.CLOSEPOLY] - verts = [(np.array(vert) + corner).tolist() for vert in verts] - path = Path(verts, codes) + vertices = [(np.array(vert) + corner).tolist() for vert in vertices] + path = Path(vertices, codes) patch = patches.PathPatch(path, facecolor="none", lw=2, alpha=0.5, zorder=2 * n + 2) ax.add_patch(patch) ax.set_aspect("equal") - center = corner + lattsum / 2.0 - extent = np.max(lattsum) + center = corner + lattice_sum / 2.0 + extent = np.max(lattice_sum) lim_array = [center - extent * window, center + extent * window] x_lim = [ele[0] for ele in lim_array] y_lim = [ele[1] for ele in lim_array] diff --git a/pymatgen/analysis/bond_dissociation.py b/pymatgen/analysis/bond_dissociation.py index 96c1d853285..7f6ddbbff44 100644 --- a/pymatgen/analysis/bond_dissociation.py +++ b/pymatgen/analysis/bond_dissociation.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +import warnings import networkx as nx from monty.json import MSONable @@ -48,7 +49,7 @@ def __init__( Args: molecule_entry (dict): Entry for the principle molecule. Should have the keys mentioned above. - fragment_entries (list of dicts): List of fragment entries. Each should have the keys mentioned above. + fragment_entries (list[dict]): Fragment entries. Each should have the keys mentioned above. allow_additional_charge_separation (bool): If True, consider larger than normal charge separation among fragments. Defaults to False. See the definition of self.expected_charges below for more specific information. @@ -89,7 +90,7 @@ def __init__( self.expected_charges = [final_charge - 2, final_charge - 1, final_charge, final_charge + 1] # Build principle molecule graph - self.mol_graph = MoleculeGraph.with_local_env_strategy( + self.mol_graph = MoleculeGraph.from_local_env_strategy( Molecule.from_dict(molecule_entry["final_molecule"]), OpenBabelNN() ) # Loop through bonds, aka graph edges, and fragment and process: @@ -98,13 +99,13 @@ def __init__( self.fragment_and_process(bonds) # If multibreak, loop through pairs of ring bonds. if multibreak: - print( + warnings.warn( "Breaking pairs of ring bonds. WARNING: Structure changes much more likely, meaning dissociation values" " are less reliable! This is a bad idea!" ) self.bond_pairs = [] - for ii, bond in enumerate(self.ring_bonds): - for jj in range(ii + 1, len(self.ring_bonds)): + for ii, bond in enumerate(self.ring_bonds, start=1): + for jj in range(ii, len(self.ring_bonds)): bond_pair = [bond, self.ring_bonds[jj]] self.bond_pairs += [bond_pair] for bond_pair in self.bond_pairs: @@ -120,6 +121,7 @@ def fragment_and_process(self, bonds): try: frags = self.mol_graph.split_molecule_subgraphs(bonds, allow_reverse=True) frag_success = True + except MolGraphSplitError: # If split is unsuccessful, then we have encountered a ring bond if len(bonds) == 1: @@ -152,13 +154,13 @@ def fragment_and_process(self, bonds): pb_mol = bb.pybel_mol smiles = pb_mol.write("smi").split()[0] specie = nx.get_node_attributes(self.mol_graph.graph, "specie") - print( + warnings.warn( f"Missing ring opening fragment resulting from the breakage of {specie[bonds[0][0]]} " f"{specie[bonds[0][1]]} bond {bonds[0][0]} {bonds[0][1]} which would yield a " f"molecule with this SMILES string: {smiles}" ) elif len(good_entries) == 1: - # If we have only one good entry, format it and add it to the list that will eventually return: + # If we have only one good entry, format it and add it to the list that will eventually return self.bond_dissociation_energies += [self.build_new_entry(good_entries, bonds)] else: # We shouldn't ever encounter more than one good entry. @@ -166,9 +168,9 @@ def fragment_and_process(self, bonds): elif len(bonds) == 2: raise RuntimeError("Should only be trying to break two bonds if multibreak is true! Exiting...") else: - print("No reason to try and break more than two bonds at once! Exiting...") - raise ValueError + raise ValueError("No reason to try and break more than two bonds at once! Exiting...") frag_success = False + if frag_success: # If the principle did successfully split, then we aren't dealing with a ring bond. # As above, we begin by making sure we haven't already encountered an identical pair of fragments: @@ -203,14 +205,14 @@ def fragment_and_process(self, bonds): smiles = pb_mol.write("smi").split()[0] for charge in self.expected_charges: if charge not in frag1_charges_found: - print(f"Missing {charge=} for fragment {smiles}") + warnings.warn(f"Missing {charge=} for fragment {smiles}") if len(frag2_charges_found) < len(self.expected_charges): bb = BabelMolAdaptor(frags[1].molecule) pb_mol = bb.pybel_mol smiles = pb_mol.write("smi").split()[0] for charge in self.expected_charges: if charge not in frag2_charges_found: - print(f"Missing {charge=} for fragment {smiles}") + warnings.warn(f"Missing {charge=} for fragment {smiles}") # Now we attempt to pair fragments with the right total charge, starting with only fragments with no # structural change: for frag1 in frag1_entries[0]: # 0 -> no structural change @@ -241,12 +243,12 @@ def fragment_and_process(self, bonds): self.bond_dissociation_energies += [self.build_new_entry([frag1, frag2], bonds)] n_entries_for_this_frag_pair += 1 - def search_fragment_entries(self, frag): + def search_fragment_entries(self, frag) -> list: """ Search all fragment entries for those isomorphic to the given fragment. - We distinguish between entries where both initial and final molgraphs are isomorphic to the - given fragment (entries) vs those where only the initial molgraph is isomorphic to the given - fragment (initial_entries) vs those where only the final molgraph is isomorphic (final_entries). + We distinguish between entries where both initial and final MoleculeGraphs are isomorphic to the + given fragment (entries) vs those where only the initial MoleculeGraph is isomorphic to the given + fragment (initial_entries) vs those where only the final MoleculeGraph is isomorphic (final_entries). Args: frag: Fragment @@ -263,13 +265,14 @@ def search_fragment_entries(self, frag): final_entries += [entry] return [entries, initial_entries, final_entries] - def filter_fragment_entries(self, fragment_entries): + def filter_fragment_entries(self, fragment_entries: list) -> None: """ Filter the fragment entries. - :param fragment_entries: + Args: + fragment_entries (List): Fragment entries to be filtered. """ - self.filtered_entries = [] + self.filtered_entries: list = [] for entry in fragment_entries: # Check and make sure that PCM dielectric is consistent with principle: if "pcm_dielectric" in self.molecule_entry: @@ -284,10 +287,10 @@ def filter_fragment_entries(self, fragment_entries): raise RuntimeError(err_msg.replace("[[placeholder]]", "a different")) # Build initial and final molgraphs: - entry["initial_molgraph"] = MoleculeGraph.with_local_env_strategy( + entry["initial_molgraph"] = MoleculeGraph.from_local_env_strategy( Molecule.from_dict(entry["initial_molecule"]), OpenBabelNN() ) - entry["final_molgraph"] = MoleculeGraph.with_local_env_strategy( + entry["final_molgraph"] = MoleculeGraph.from_local_env_strategy( Molecule.from_dict(entry["final_molecule"]), OpenBabelNN() ) # Classify any potential structural change that occurred during optimization: @@ -305,8 +308,9 @@ def filter_fragment_entries(self, fragment_entries): else: entry["structure_change"] = "bond_change" found_similar_entry = False + # Check for uniqueness - for ii, filtered_entry in enumerate(self.filtered_entries): + for idx, filtered_entry in enumerate(self.filtered_entries): if filtered_entry["formula_pretty"] == entry["formula_pretty"] and ( filtered_entry["initial_molgraph"].isomorphic_to(entry["initial_molgraph"]) and filtered_entry["final_molgraph"].isomorphic_to(entry["final_molgraph"]) @@ -316,19 +320,23 @@ def filter_fragment_entries(self, fragment_entries): # If two entries are found that pass the above similarity check, take the one with the lower # energy: if entry["final_energy"] < filtered_entry["final_energy"]: - self.filtered_entries[ii] = entry + self.filtered_entries[idx] = entry # Note that this will essentially choose between singlet and triplet entries assuming both have # the same structural details break if not found_similar_entry: self.filtered_entries += [entry] - def build_new_entry(self, frags, bonds): + def build_new_entry(self, frags: list, bonds: list) -> list: """ - Simple function to format a bond dissociation entry that will eventually be returned to the user. + Build a new entry for bond dissociation that will be returned to the user. - :param frags: - :param bonds: + Args: + frags (list): Fragments involved in the bond dissociation. + bonds (list): Bonds broken in the dissociation process. + + Returns: + list: Formatted bond dissociation entries. """ specie = nx.get_node_attributes(self.mol_graph.graph, "specie") if len(frags) == 2: @@ -348,6 +356,7 @@ def build_new_entry(self, frags, bonds): frags[1]["initial_molecule"]["spin_multiplicity"], frags[1]["final_energy"], ] + else: new_entry = [ self.molecule_entry["final_energy"] - frags[0]["final_energy"], @@ -360,4 +369,5 @@ def build_new_entry(self, frags, bonds): frags[0]["initial_molecule"]["spin_multiplicity"], frags[0]["final_energy"], ] + return new_entry diff --git a/pymatgen/analysis/bond_valence.py b/pymatgen/analysis/bond_valence.py index d6cda8e1c78..36099fb645e 100644 --- a/pymatgen/analysis/bond_valence.py +++ b/pymatgen/analysis/bond_valence.py @@ -36,8 +36,7 @@ def calculate_bv_sum(site, nn_list, scale_factor=1.0): - """ - Calculates the BV sum of a site. + """Calculates the BV sum of a site. Args: site (PeriodicSite): The central site to calculate the bond valence @@ -63,8 +62,7 @@ def calculate_bv_sum(site, nn_list, scale_factor=1.0): def calculate_bv_sum_unordered(site, nn_list, scale_factor=1): - """ - Calculates the BV sum of a site for unordered structures. + """Calculates the BV sum of a site for unordered structures. Args: site (PeriodicSite): The central site to calculate the bond valence @@ -82,7 +80,7 @@ def calculate_bv_sum_unordered(site, nn_list, scale_factor=1): # site "site" is obtained as : # \sum_{nn} \sum_j^N \sum_k^{N_{nn}} f_{site}_j f_{nn_i}_k vij_full # where vij_full is the valence bond of the fully occupied bond - bvsum = 0 + bv_sum = 0 for specie1, occu1 in site.species.items(): el1 = Element(specie1.symbol) for nn in nn_list: @@ -95,8 +93,8 @@ def calculate_bv_sum_unordered(site, nn_list, scale_factor=1): c2 = BV_PARAMS[el2]["c"] R = r1 + r2 - r1 * r2 * (sqrt(c1) - sqrt(c2)) ** 2 / (c1 * r1 + c2 * r2) vij = exp((R - nn.nn_distance * scale_factor) / 0.31) - bvsum += occu1 * occu2 * vij * (1 if el1.X < el2.X else -1) - return bvsum + bv_sum += occu1 * occu2 * vij * (1 if el1.X < el2.X else -1) + return bv_sum class BVAnalyzer: diff --git a/pymatgen/analysis/chemenv/connectivity/connected_components.py b/pymatgen/analysis/chemenv/connectivity/connected_components.py index 134b96b540f..d4e56c8520b 100644 --- a/pymatgen/analysis/chemenv/connectivity/connected_components.py +++ b/pymatgen/analysis/chemenv/connectivity/connected_components.py @@ -4,6 +4,7 @@ import itertools import logging +from typing import TYPE_CHECKING import matplotlib.pyplot as plt import networkx as nx @@ -18,6 +19,9 @@ from pymatgen.analysis.chemenv.utils.graph_utils import get_delta from pymatgen.analysis.chemenv.utils.math_utils import get_linearly_independent_vectors +if TYPE_CHECKING: + from typing_extensions import Self + def draw_network(env_graph, pos, ax, sg=None, periodicity_vectors=None): """Draw network of environments in a matplotlib figure axes. @@ -234,19 +238,20 @@ def __init__( "__init__", "Trying to add edge with some unexistent node ...", ) - if links_data is not None: - if (env_node1, env_node2, key) in links_data: - edge_data = links_data[(env_node1, env_node2, key)] - elif (env_node2, env_node1, key) in links_data: - edge_data = links_data[(env_node2, env_node1, key)] - elif (env_node1, env_node2) in links_data: - edge_data = links_data[(env_node1, env_node2)] - elif (env_node2, env_node1) in links_data: - edge_data = links_data[(env_node2, env_node1)] - else: - edge_data = None + if links_data is None: + edge_data = None + + elif (env_node1, env_node2, key) in links_data: + edge_data = links_data[(env_node1, env_node2, key)] + elif (env_node2, env_node1, key) in links_data: + edge_data = links_data[(env_node2, env_node1, key)] + elif (env_node1, env_node2) in links_data: + edge_data = links_data[(env_node1, env_node2)] + elif (env_node2, env_node1) in links_data: + edge_data = links_data[(env_node2, env_node1)] else: edge_data = None + if edge_data: self._connected_subgraph.add_edge(env_node1, env_node2, key, **edge_data) else: @@ -736,7 +741,11 @@ def elastic_centered_graph(self, start_node=None): check_centered_connected_subgraph = nx.MultiGraph() check_centered_connected_subgraph.add_nodes_from(centered_connected_subgraph.nodes()) check_centered_connected_subgraph.add_edges_from( - [e for e in centered_connected_subgraph.edges(data=True) if np.allclose(e[2]["delta"], np.zeros(3))] + [ + edge + for edge in centered_connected_subgraph.edges(data=True) + if np.allclose(edge[2]["delta"], np.zeros(3)) + ] ) if not is_connected(check_centered_connected_subgraph): raise RuntimeError("Could not find a centered graph.") @@ -827,35 +836,35 @@ def as_dict(self): } @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """ Reconstructs the ConnectedComponent object from a dict representation of the ConnectedComponent object created using the as_dict method. Args: - d (dict): dict representation of the ConnectedComponent object + dct (dict): dict representation of the ConnectedComponent object Returns: ConnectedComponent: The connected component representing the links of a given set of environments. """ nodes_map = { - inode_str: EnvironmentNode.from_dict(nodedict) for inode_str, (nodedict, nodedata) in d["nodes"].items() + inode_str: EnvironmentNode.from_dict(nodedict) for inode_str, (nodedict, nodedata) in dct["nodes"].items() } - nodes_data = {inode_str: nodedata for inode_str, (nodedict, nodedata) in d["nodes"].items()} - dod = {} - for e1, e1dict in d["graph"].items(): - dod[e1] = {} + nodes_data = {inode_str: nodedata for inode_str, (nodedict, nodedata) in dct["nodes"].items()} + nested_dict: dict[str, dict] = {} + for e1, e1dict in dct["graph"].items(): + nested_dict[e1] = {} for e2, e2dict in e1dict.items(): - dod[e1][e2] = { + nested_dict[e1][e2] = { cls._edgedictkey_to_edgekey(ied): cls._retuplify_edgedata(edata) for ied, edata in e2dict.items() } - graph = nx.from_dict_of_dicts(dod, create_using=nx.MultiGraph, multigraph_input=True) + graph = nx.from_dict_of_dicts(nested_dict, create_using=nx.MultiGraph, multigraph_input=True) nx.set_node_attributes(graph, nodes_data) nx.relabel_nodes(graph, nodes_map, copy=False) return cls(graph=graph) @classmethod - def from_graph(cls, g): + def from_graph(cls, g) -> Self: """ Constructor for the ConnectedComponent object from a graph of the connected component. diff --git a/pymatgen/analysis/chemenv/connectivity/connectivity_finder.py b/pymatgen/analysis/chemenv/connectivity/connectivity_finder.py index 0c10a51d6ee..55c72970a95 100644 --- a/pymatgen/analysis/chemenv/connectivity/connectivity_finder.py +++ b/pymatgen/analysis/chemenv/connectivity/connectivity_finder.py @@ -24,9 +24,10 @@ def __init__(self, multiple_environments_choice=None): """ Constructor for the ConnectivityFinder. - :param multiple_environments_choice: defines the procedure to apply when - the environment of a given site is described as a "mix" of more than one - coordination environments. + Args: + multiple_environments_choice: defines the procedure to apply when + the environment of a given site is described as a "mix" of more than one + coordination environments. """ self.setup_parameters(multiple_environments_choice=multiple_environments_choice) @@ -35,8 +36,9 @@ def get_structure_connectivity(self, light_structure_environments): Get the structure connectivity from the coordination environments provided as an input. - :param light_structure_environments: LightStructureEnvironments with the - relevant coordination environments in the structure + Args: + light_structure_environments: LightStructureEnvironments with the + relevant coordination environments in the structure Returns: a StructureConnectivity object describing the connectivity of diff --git a/pymatgen/analysis/chemenv/connectivity/structure_connectivity.py b/pymatgen/analysis/chemenv/connectivity/structure_connectivity.py index 1220fd76244..f940842374e 100644 --- a/pymatgen/analysis/chemenv/connectivity/structure_connectivity.py +++ b/pymatgen/analysis/chemenv/connectivity/structure_connectivity.py @@ -4,6 +4,7 @@ import collections import logging +from typing import TYPE_CHECKING import networkx as nx import numpy as np @@ -13,6 +14,9 @@ from pymatgen.analysis.chemenv.connectivity.environment_nodes import get_environment_node from pymatgen.analysis.chemenv.coordination_environments.structure_environments import LightStructureEnvironments +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "David Waroquiers" __copyright__ = "Copyright 2012, The Materials Project" __credits__ = "Geoffroy Hautier" @@ -299,26 +303,30 @@ def as_dict(self): } @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """ Args: - d (): + dct (dict): Returns: StructureConnectivity """ # Reconstructs the graph with integer as nodes (json's as_dict replaces integer keys with str keys) - cgraph = nx.from_dict_of_dicts(d["connectivity_graph"], create_using=nx.MultiGraph, multigraph_input=True) - cgraph = nx.relabel_nodes(cgraph, int) # Just relabel the nodes using integer casting (maps str->int) + connect_graph = nx.from_dict_of_dicts( + dct["connectivity_graph"], create_using=nx.MultiGraph, multigraph_input=True + ) + connect_graph = nx.relabel_nodes( + connect_graph, int + ) # Just relabel the nodes using integer casting (maps str->int) # Relabel multi-edges (removes multi-edges with str keys and adds them back with int keys) - edges = set(cgraph.edges()) + edges = set(connect_graph.edges()) for n1, n2 in edges: - new_edges = {int(iedge): edata for iedge, edata in cgraph[n1][n2].items()} - cgraph.remove_edges_from([(n1, n2, iedge) for iedge, edata in cgraph[n1][n2].items()]) - cgraph.add_edges_from([(n1, n2, iedge, edata) for iedge, edata in new_edges.items()]) + new_edges = {int(iedge): edata for iedge, edata in connect_graph[n1][n2].items()} + connect_graph.remove_edges_from([(n1, n2, iedge) for iedge, edata in connect_graph[n1][n2].items()]) + connect_graph.add_edges_from([(n1, n2, iedge, edata) for iedge, edata in new_edges.items()]) return cls( - LightStructureEnvironments.from_dict(d["light_structure_environments"]), - connectivity_graph=cgraph, + LightStructureEnvironments.from_dict(dct["light_structure_environments"]), + connectivity_graph=connect_graph, environment_subgraphs=None, ) # TODO: also deserialize the environment_subgraphs diff --git a/pymatgen/analysis/chemenv/coordination_environments/chemenv_strategies.py b/pymatgen/analysis/chemenv/coordination_environments/chemenv_strategies.py index 79edf04d466..63a32c88460 100644 --- a/pymatgen/analysis/chemenv/coordination_environments/chemenv_strategies.py +++ b/pymatgen/analysis/chemenv/coordination_environments/chemenv_strategies.py @@ -9,7 +9,7 @@ import abc import os -from typing import ClassVar +from typing import TYPE_CHECKING, ClassVar import numpy as np from monty.json import MSONable @@ -30,6 +30,9 @@ from pymatgen.core.sites import PeriodicSite from pymatgen.symmetry.analyzer import SpacegroupAnalyzer +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "David Waroquiers" __copyright__ = "Copyright 2012, The Materials Project" __credits__ = "Geoffroy Hautier" @@ -59,10 +62,11 @@ class DistanceCutoffFloat(float, StrategyOption): allowed_values = "Real number between 1 and +infinity" - def __new__(cls, cutoff): + def __new__(cls, cutoff) -> Self: """Special float that should be between 1 and infinity. - :param cutoff: Distance cutoff. + Args: + cutoff: Distance cutoff. """ flt = float.__new__(cls, cutoff) if flt < 1: @@ -78,12 +82,13 @@ def as_dict(self): } @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """Initialize distance cutoff from dict. - :param d: Dict representation of the distance cutoff. + Args: + dct (dict): Dict representation of the distance cutoff. """ - return cls(d["value"]) + return cls(dct["value"]) class AngleCutoffFloat(float, StrategyOption): @@ -91,10 +96,11 @@ class AngleCutoffFloat(float, StrategyOption): allowed_values = "Real number between 0 and 1" - def __new__(cls, cutoff): + def __new__(cls, cutoff) -> Self: """Special float that should be between 0 and 1. - :param cutoff: Angle cutoff. + Args: + cutoff: Angle cutoff. """ flt = float.__new__(cls, cutoff) if not 0 <= flt <= 1: @@ -110,12 +116,13 @@ def as_dict(self): } @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """Initialize angle cutoff from dict. - :param d: Dict representation of the angle cutoff. + Args: + dct (dict): Dict representation of the angle cutoff. """ - return cls(d["value"]) + return cls(dct["value"]) class CSMFloat(float, StrategyOption): @@ -123,10 +130,11 @@ class CSMFloat(float, StrategyOption): allowed_values = "Real number between 0 and 100" - def __new__(cls, cutoff): + def __new__(cls, cutoff) -> Self: """Special float that should be between 0 and 100. - :param cutoff: CSM. + Args: + cutoff: CSM. """ flt = float.__new__(cls, cutoff) if not 0 <= flt <= 100: @@ -142,10 +150,11 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """Initialize CSM from dict. - :param d: Dict representation of the CSM. + Args: + dct (dict): Dict representation of the CSM. """ return cls(dct["value"]) @@ -157,7 +166,7 @@ class AdditionalConditionInt(int, StrategyOption): for integer, description in AdditionalConditions.CONDITION_DESCRIPTION.items(): allowed_values += f" - {integer} for {description!r}\n" - def __new__(cls, integer): + def __new__(cls, integer) -> Self: """Special int representing additional conditions.""" if str(int(integer)) != str(integer): raise ValueError(f"Additional condition {integer} is not an integer") @@ -175,10 +184,11 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """Initialize additional condition from dict. - :param d: Dict representation of the additional condition. + Args: + dct (dict): Dict representation of the additional condition. """ return cls(dct["value"]) @@ -202,8 +212,10 @@ def __init__( ): """ Abstract constructor for the all chemenv strategies. - :param structure_environments: StructureEnvironments object containing all the information on the - coordination of the sites in a structure. + + Args: + structure_environments: StructureEnvironments object containing all the information on the + coordination of the sites in a structure. """ self.structure_environments = None if structure_environments is not None: @@ -218,7 +230,8 @@ def symmetry_measure_type(self): def set_structure_environments(self, structure_environments): """Set the structure environments to this strategy. - :param structure_environments: StructureEnvironments object. + Args: + structure_environments: StructureEnvironments object. """ self.structure_environments = structure_environments if not isinstance(self.structure_environments.voronoi, DetailedVoronoiContainer): @@ -236,7 +249,8 @@ def prepare_symmetries(self): def equivalent_site_index_and_transform(self, psite): """Get the equivalent site and corresponding symmetry+translation transformations. - :param psite: Periodic site. + Args: + psite: Periodic site. Returns: Equivalent site in the unit cell, translations and symmetry transformation. @@ -305,9 +319,11 @@ def equivalent_site_index_and_transform(self, psite): def get_site_neighbors(self, site): """ Applies the strategy to the structure_environments object in order to get the neighbors of a given site. - :param site: Site for which the neighbors are looked for - :param structure_environments: StructureEnvironments object containing all the information needed to get the - neighbors of the site + + Args: + site: Site for which the neighbors are looked for + structure_environments: StructureEnvironments object containing all the information needed to get the + neighbors of the site Returns: The list of neighbors of the site. For complex strategies, where one allows multiple solutions, this @@ -325,7 +341,9 @@ def get_site_coordination_environment(self, site): """ Applies the strategy to the structure_environments object in order to define the coordination environment of a given site. - :param site: Site for which the coordination environment is looked for + + Args: + site: Site for which the coordination environment is looked for Returns: The coordination environment of the site. For complex strategies, where one allows multiple @@ -338,7 +356,9 @@ def get_site_coordination_environments(self, site): """ Applies the strategy to the structure_environments object in order to define the coordination environment of a given site. - :param site: Site for which the coordination environment is looked for + + Args: + site: Site for which the coordination environment is looked for Returns: The coordination environment of the site. For complex strategies, where one allows multiple @@ -362,7 +382,9 @@ def get_site_coordination_environments_fractions( """ Applies the strategy to the structure_environments object in order to define the coordination environment of a given site. - :param site: Site for which the coordination environment is looked for + + Args: + site: Site for which the coordination environment is looked for Returns: The coordination environment of the site. For complex strategies, where one allows multiple @@ -374,7 +396,9 @@ def get_site_ce_fractions_and_neighbors(self, site, full_ce_info=False, strategy """ Applies the strategy to the structure_environments object in order to get coordination environments, their fraction, csm, geometry_info, and neighbors - :param site: Site for which the above information is sought + + Args: + site: Site for which the above information is sought Returns: The list of neighbors of the site. For complex strategies, where one allows multiple solutions, this @@ -408,15 +432,17 @@ def get_site_ce_fractions_and_neighbors(self, site, full_ce_info=False, strategy def set_option(self, option_name, option_value): """Set up a given option for this strategy. - :param option_name: Name of the option. - :param option_value: Value for this option. + Args: + option_name: Name of the option. + option_value: Value for this option. """ setattr(self, option_name, option_value) def setup_options(self, all_options_dict): """Set up options for this strategy based on a dict. - :param all_options_dict: Dict of option_name->option_value. + Args: + all_options_dict: Dict of option_name->option_value. """ for option_name, option_value in all_options_dict.items(): self.set_option(option_name, option_value) @@ -425,7 +451,9 @@ def setup_options(self, all_options_dict): def __eq__(self, other: object) -> bool: """ Equality method that should be implemented for any strategy - :param other: strategy to be compared with the current one + + Args: + other: strategy to be compared with the current one """ raise NotImplementedError @@ -451,11 +479,13 @@ def as_dict(self): raise NotImplementedError @classmethod - def from_dict(cls, dct) -> AbstractChemenvStrategy: + def from_dict(cls, dct) -> Self: """ Reconstructs the SimpleAbundanceChemenvStrategy object from a dict representation of the SimpleAbundanceChemenvStrategy object created using the as_dict method. - :param dct: dict representation of the SimpleAbundanceChemenvStrategy object + + Args: + dct: dict representation of the SimpleAbundanceChemenvStrategy object Returns: StructureEnvironments object. @@ -498,10 +528,10 @@ class SimplestChemenvStrategy(AbstractChemenvStrategy): ) STRATEGY_DESCRIPTION = ( - " Simplest ChemenvStrategy using fixed angle and distance parameters \n" - " for the definition of neighbors in the Voronoi approach. \n" - " The coordination environment is then given as the one with the \n" - " lowest continuous symmetry measure." + "Simplest ChemenvStrategy using fixed angle and distance parameters \n" + "for the definition of neighbors in the Voronoi approach. \n" + "The coordination environment is then given as the one with the \n" + "lowest continuous symmetry measure." ) def __init__( @@ -515,8 +545,10 @@ def __init__( ): """ Constructor for this SimplestChemenvStrategy. - :param distance_cutoff: Distance cutoff used - :param angle_cutoff: Angle cutoff used. + + Args: + distance_cutoff: Distance cutoff used + angle_cutoff: Angle cutoff used. """ AbstractChemenvStrategy.__init__(self, structure_environments, symmetry_measure_type=symmetry_measure_type) self.distance_cutoff = distance_cutoff @@ -538,7 +570,8 @@ def distance_cutoff(self): def distance_cutoff(self, distance_cutoff): """Set the distance cutoff for this strategy. - :param distance_cutoff: Distance cutoff. + Args: + distance_cutoff: Distance cutoff. """ self._distance_cutoff = DistanceCutoffFloat(distance_cutoff) @@ -551,7 +584,8 @@ def angle_cutoff(self): def angle_cutoff(self, angle_cutoff): """Set the angle cutoff for this strategy. - :param angle_cutoff: Angle cutoff. + Args: + angle_cutoff: Angle cutoff. """ self._angle_cutoff = AngleCutoffFloat(angle_cutoff) @@ -564,7 +598,8 @@ def additional_condition(self): def additional_condition(self, additional_condition): """Set the additional condition for this strategy. - :param additional_condition: Additional condition. + Args: + additional_condition: Additional condition. """ self._additional_condition = AdditionalConditionInt(additional_condition) @@ -577,18 +612,20 @@ def continuous_symmetry_measure_cutoff(self): def continuous_symmetry_measure_cutoff(self, continuous_symmetry_measure_cutoff): """Set the CSM cutoff for this strategy. - :param continuous_symmetry_measure_cutoff: CSM cutoff + Args: + continuous_symmetry_measure_cutoff: CSM cutoff """ self._continuous_symmetry_measure_cutoff = CSMFloat(continuous_symmetry_measure_cutoff) def get_site_neighbors(self, site, isite=None, dequivsite=None, dthissite=None, mysym=None): """Get the neighbors of a given site. - :param site: Site for which neighbors are needed. - :param isite: Index of the site. - :param dequivsite: Translation of the equivalent site. - :param dthissite: Translation of this site. - :param mysym: Symmetry to be applied. + Args: + site: Site for which neighbors are needed. + isite: Index of the site. + dequivsite: Translation of the equivalent site. + dthissite: Translation of this site. + mysym: Symmetry to be applied. Returns: List of coordinated neighbors of site. @@ -626,12 +663,13 @@ def get_site_coordination_environment( ): """Get the coordination environment of a given site. - :param site: Site for which coordination environment is needed. - :param isite: Index of the site. - :param dequivsite: Translation of the equivalent site. - :param dthissite: Translation of this site. - :param mysym: Symmetry to be applied. - :param return_map: Whether to return cn_map (identifies the NeighborsSet used). + Args: + site: Site for which coordination environment is needed. + isite: Index of the site. + dequivsite: Translation of the equivalent site. + dthissite: Translation of this site. + mysym: Symmetry to be applied. + return_map: Whether to return cn_map (identifies the NeighborsSet used). Returns: Coordination environment of site. @@ -708,15 +746,16 @@ def get_site_coordination_environments_fractions( ): """Get the coordination environments of a given site and additional information. - :param site: Site for which coordination environment is needed. - :param isite: Index of the site. - :param dequivsite: Translation of the equivalent site. - :param dthissite: Translation of this site. - :param mysym: Symmetry to be applied. - :param ordered: Whether to order the list by fractions. - :param min_fraction: Minimum fraction to include in the list - :param return_maps: Whether to return cn_maps (identifies all the NeighborsSet used). - :param return_strategy_dict_info: Whether to add the info about the strategy used. + Args: + site: Site for which coordination environment is needed. + isite: Index of the site. + dequivsite: Translation of the equivalent site. + dthissite: Translation of this site. + mysym: Symmetry to be applied. + ordered: Whether to order the list by fractions. + min_fraction: Minimum fraction to include in the list + return_maps: Whether to return cn_maps (identifies all the NeighborsSet used). + return_strategy_dict_info: Whether to add the info about the strategy used. Returns: List of Dict with coordination environment, fraction and additional info. @@ -762,12 +801,13 @@ def get_site_coordination_environments( ): """Get the coordination environments of a given site. - :param site: Site for which coordination environment is needed. - :param isite: Index of the site. - :param dequivsite: Translation of the equivalent site. - :param dthissite: Translation of this site. - :param mysym: Symmetry to be applied. - :param return_maps: Whether to return cn_maps (identifies all the NeighborsSet used). + Args: + site: Site for which coordination environment is needed. + isite: Index of the site. + dequivsite: Translation of the equivalent site. + dthissite: Translation of this site. + mysym: Symmetry to be applied. + return_maps: Whether to return cn_maps (identifies all the NeighborsSet used). Returns: List of coordination environment. @@ -780,9 +820,10 @@ def get_site_coordination_environments( def add_strategy_visualization_to_subplot(self, subplot, visualization_options=None, plot_type=None): """Add a visual of the strategy on a distance-angle plot. - :param subplot: Axes object onto the visual should be added. - :param visualization_options: Options for the visual. - :param plot_type: Type of distance-angle plot. + Args: + subplot: Axes object onto the visual should be added. + visualization_options: Options for the visual. + plot_type: Type of distance-angle plot. """ subplot.plot( self._distance_cutoff, self._angle_cutoff, "o", markeredgecolor=None, markerfacecolor="w", markersize=12 @@ -819,11 +860,13 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct: dict) -> SimplestChemenvStrategy: + def from_dict(cls, dct: dict) -> Self: """ Reconstructs the SimplestChemenvStrategy object from a dict representation of the SimplestChemenvStrategy object created using the as_dict method. - :param dct: dict representation of the SimplestChemenvStrategy object + + Args: + dct: dict representation of the SimplestChemenvStrategy object Returns: StructureEnvironments object. @@ -855,10 +898,10 @@ class SimpleAbundanceChemenvStrategy(AbstractChemenvStrategy): ), ) STRATEGY_DESCRIPTION = ( - ' Simple Abundance ChemenvStrategy using the most "abundant" neighbors map \n' - " for the definition of neighbors in the Voronoi approach. \n" - " The coordination environment is then given as the one with the \n" - " lowest continuous symmetry measure." + 'Simple Abundance ChemenvStrategy using the most "abundant" neighbors map \n' + "for the definition of neighbors in the Voronoi approach. \n" + "The coordination environment is then given as the one with the \n" + "lowest continuous symmetry measure." ) def __init__( @@ -869,8 +912,10 @@ def __init__( ): """ Constructor for the SimpleAbundanceChemenvStrategy. - :param structure_environments: StructureEnvironments object containing all the information on the - coordination of the sites in a structure. + + Args: + structure_environments: StructureEnvironments object containing all the information on the + coordination of the sites in a structure. """ raise NotImplementedError("SimpleAbundanceChemenvStrategy not yet implemented") AbstractChemenvStrategy.__init__(self, structure_environments, symmetry_measure_type=symmetry_measure_type) @@ -884,7 +929,8 @@ def uniquely_determines_coordination_environments(self): def get_site_neighbors(self, site): """Get the neighbors of a given site with this strategy. - :param site: Periodic site. + Args: + site: Periodic site. Returns: List of neighbors of site. @@ -910,12 +956,13 @@ def get_site_coordination_environment( ): """Get the coordination environment of a given site. - :param site: Site for which coordination environment is needed. - :param isite: Index of the site. - :param dequivsite: Translation of the equivalent site. - :param dthissite: Translation of this site. - :param mysym: Symmetry to be applied. - :param return_map: Whether to return cn_map (identifies the NeighborsSet used). + Args: + site: Site for which coordination environment is needed. + isite: Index of the site. + dequivsite: Translation of the equivalent site. + dthissite: Translation of this site. + mysym: Symmetry to be applied. + return_map: Whether to return cn_map (identifies the NeighborsSet used). Returns: Coordination environment of site. @@ -951,12 +998,13 @@ def get_site_coordination_environments( ): """Get the coordination environments of a given site. - :param site: Site for which coordination environment is needed. - :param isite: Index of the site. - :param dequivsite: Translation of the equivalent site. - :param dthissite: Translation of this site. - :param mysym: Symmetry to be applied. - :param return_maps: Whether to return cn_maps (identifies all the NeighborsSet used). + Args: + site: Site for which coordination environment is needed. + isite: Index of the site. + dequivsite: Translation of the equivalent site. + dthissite: Translation of this site. + mysym: Symmetry to be applied. + return_maps: Whether to return cn_maps (identifies all the NeighborsSet used). Returns: List of coordination environment. @@ -1017,11 +1065,13 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct: dict) -> SimpleAbundanceChemenvStrategy: + def from_dict(cls, dct: dict) -> Self: """ Reconstructs the SimpleAbundanceChemenvStrategy object from a dict representation of the SimpleAbundanceChemenvStrategy object created using the as_dict method. - :param dct: dict representation of the SimpleAbundanceChemenvStrategy object + + Args: + dct: dict representation of the SimpleAbundanceChemenvStrategy object Returns: StructureEnvironments object. @@ -1053,15 +1103,20 @@ def __init__( """Initialize strategy. Not yet implemented. - :param structure_environments: - :param truncate_dist_ang: - :param additional_condition: - :param max_nabundant: - :param target_environments: - :param target_penalty_type: - :param max_csm: - :param symmetry_measure_type: + + Args: + structure_environments: + truncate_dist_ang: + additional_condition: + max_nabundant: + target_environments: + target_penalty_type: + max_csm: + symmetry_measure_type: """ + + raise NotImplementedError("TargetedPenaltiedAbundanceChemenvStrategy not yet implemented") + super().__init__( self, structure_environments, @@ -1072,7 +1127,6 @@ def __init__( self.target_environments = target_environments self.target_penalty_type = target_penalty_type self.max_csm = max_csm - raise NotImplementedError("TargetedPenaltiedAbundanceChemenvStrategy not yet implemented") def get_site_coordination_environment( self, @@ -1085,12 +1139,13 @@ def get_site_coordination_environment( ): """Get the coordination environment of a given site. - :param site: Site for which coordination environment is needed. - :param isite: Index of the site. - :param dequivsite: Translation of the equivalent site. - :param dthissite: Translation of this site. - :param mysym: Symmetry to be applied. - :param return_map: Whether to return cn_map (identifies the NeighborsSet used). + Args: + site: Site for which coordination environment is needed. + isite: Index of the site. + dequivsite: Translation of the equivalent site. + dthissite: Translation of this site. + mysym: Symmetry to be applied. + return_map: Whether to return cn_map (identifies the NeighborsSet used). Returns: Coordination environment of site. @@ -1172,7 +1227,7 @@ def as_dict(self): "max_csm": self.max_csm, } - def __eq__(self, other: object) -> bool: + def __eq__(self, other): if not isinstance(other, type(self)): return NotImplemented @@ -1185,11 +1240,13 @@ def __eq__(self, other: object) -> bool: ) @classmethod - def from_dict(cls, dct) -> TargetedPenaltiedAbundanceChemenvStrategy: + def from_dict(cls, dct) -> Self: """ Reconstructs the TargetedPenaltiedAbundanceChemenvStrategy object from a dict representation of the TargetedPenaltiedAbundanceChemenvStrategy object created using the as_dict method. - :param dct: dict representation of the TargetedPenaltiedAbundanceChemenvStrategy object + + Args: + dct: dict representation of the TargetedPenaltiedAbundanceChemenvStrategy object Returns: TargetedPenaltiedAbundanceChemenvStrategy object. @@ -1214,10 +1271,11 @@ def as_dict(self): def weight(self, nb_set, structure_environments, cn_map=None, additional_info=None): """Get the weight of a given neighbors set. - :param nb_set: Neighbors set. - :param structure_environments: Structure environments used to estimate weight. - :param cn_map: Mapping index for this neighbors set. - :param additional_info: Additional information. + Args: + nb_set: Neighbors set. + structure_environments: Structure environments used to estimate weight. + cn_map: Mapping index for this neighbors set. + additional_info: Additional information. Returns: Weight of the neighbors set. @@ -1232,7 +1290,8 @@ class AngleNbSetWeight(NbSetWeight): def __init__(self, aa=1): """Initialize AngleNbSetWeight estimator. - :param aa: Exponent of the angle for the estimator. + Args: + aa: Exponent of the angle for the estimator. """ self.aa = aa if self.aa == 1: @@ -1243,10 +1302,11 @@ def __init__(self, aa=1): def weight(self, nb_set, structure_environments, cn_map=None, additional_info=None): """Get the weight of a given neighbors set. - :param nb_set: Neighbors set. - :param structure_environments: Structure environments used to estimate weight. - :param cn_map: Mapping index for this neighbors set. - :param additional_info: Additional information. + Args: + nb_set: Neighbors set. + structure_environments: Structure environments used to estimate weight. + cn_map: Mapping index for this neighbors set. + additional_info: Additional information. Returns: Weight of the neighbors set. @@ -1257,7 +1317,8 @@ def weight(self, nb_set, structure_environments, cn_map=None, additional_info=No def angle_sum(nb_set): """Sum of all angles in a neighbors set. - :param nb_set: Neighbors set. + Args: + nb_set: Neighbors set. Returns: Sum of solid angles for the neighbors set. @@ -1267,7 +1328,8 @@ def angle_sum(nb_set): def angle_sumn(self, nb_set): """Sum of all angles to a given power in a neighbors set. - :param nb_set: Neighbors set. + Args: + nb_set: Neighbors set. Returns: Sum of solid angles to the power aa for the neighbors set. @@ -1288,7 +1350,7 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """Construct AngleNbSetWeight from dict representation.""" return cls(aa=dct["aa"]) @@ -1301,9 +1363,10 @@ class NormalizedAngleDistanceNbSetWeight(NbSetWeight): def __init__(self, average_type, aa, bb): """Initialize NormalizedAngleDistanceNbSetWeight. - :param average_type: Average function. - :param aa: Exponent for the angle values. - :param bb: Exponent for the distance values. + Args: + average_type: Average function. + aa: Exponent for the angle values. + bb: Exponent for the distance values. """ self.average_type = average_type if self.average_type == "geometric": @@ -1351,10 +1414,11 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """Initialize from dict. - :param dct: Dict representation of NormalizedAngleDistanceNbSetWeight. + Args: + dct (dict): Dict representation of NormalizedAngleDistanceNbSetWeight. Returns: NormalizedAngleDistanceNbSetWeight. @@ -1365,7 +1429,8 @@ def from_dict(cls, dct): def invdist(nb_set): """Inverse distance weight. - :param nb_set: Neighbors set. + Args: + nb_set: Neighbors set. Returns: List of inverse distances. @@ -1375,7 +1440,8 @@ def invdist(nb_set): def invndist(self, nb_set): """Inverse power distance weight. - :param nb_set: Neighbors set. + Args: + nb_set: Neighbors set. Returns: List of inverse power distances. @@ -1386,7 +1452,8 @@ def invndist(self, nb_set): def ang(nb_set): """Angle weight. - :param nb_set: Neighbors set. + Args: + nb_set: Neighbors set. Returns: List of angle weights. @@ -1396,7 +1463,8 @@ def ang(nb_set): def angn(self, nb_set): """Power angle weight. - :param nb_set: Neighbors set. + Args: + nb_set: Neighbors set. Returns: List of power angle weights. @@ -1407,7 +1475,8 @@ def angn(self, nb_set): def anginvdist(nb_set): """Angle/distance weight. - :param nb_set: Neighbors set. + Args: + nb_set: Neighbors set. Returns: List of angle/distance weights. @@ -1418,7 +1487,8 @@ def anginvdist(nb_set): def anginvndist(self, nb_set): """Angle/power distance weight. - :param nb_set: Neighbors set. + Args: + nb_set: Neighbors set. Returns: List of angle/power distance weights. @@ -1429,7 +1499,8 @@ def anginvndist(self, nb_set): def angninvdist(self, nb_set): """Power angle/distance weight. - :param nb_set: Neighbors set. + Args: + nb_set: Neighbors set. Returns: List of power angle/distance weights. @@ -1440,7 +1511,8 @@ def angninvdist(self, nb_set): def angninvndist(self, nb_set): """Power angle/power distance weight. - :param nb_set: Neighbors set. + Args: + nb_set: Neighbors set. Returns: List of power angle/power distance weights. @@ -1451,10 +1523,11 @@ def angninvndist(self, nb_set): def weight(self, nb_set, structure_environments, cn_map=None, additional_info=None): """Get the weight of a given neighbors set. - :param nb_set: Neighbors set. - :param structure_environments: Structure environments used to estimate weight. - :param cn_map: Mapping index for this neighbors set. - :param additional_info: Additional information. + Args: + nb_set: Neighbors set. + structure_environments: Structure environments used to estimate weight. + cn_map: Mapping index for this neighbors set. + additional_info: Additional information. Returns: Weight of the neighbors set. @@ -1466,7 +1539,8 @@ def weight(self, nb_set, structure_environments, cn_map=None, additional_info=No def gweight(fda_list): """Geometric mean of the weights. - :param fda_list: List of estimator weights for each neighbor. + Args: + fda_list: List of estimator weights for each neighbor. Returns: Geometric mean of the weights. @@ -1477,7 +1551,8 @@ def gweight(fda_list): def aweight(fda_list): """Standard mean of the weights. - :param fda_list: List of estimator weights for each neighbor. + Args: + fda_list: List of estimator weights for each neighbor. Returns: Standard mean of the weights. @@ -1496,13 +1571,14 @@ def get_effective_csm( ): """Get the effective continuous symmetry measure of a given neighbors set. - :param nb_set: Neighbors set. - :param cn_map: Mapping index of this neighbors set. - :param structure_environments: Structure environments. - :param additional_info: Additional information for the neighbors set. - :param symmetry_measure_type: Type of symmetry measure to be used in the effective CSM. - :param max_effective_csm: Max CSM to use for the effective CSM calculation. - :param effective_csm_estimator_ratio_function: Ratio function to use to compute effective CSM. + Args: + nb_set: Neighbors set. + cn_map: Mapping index of this neighbors set. + structure_environments: Structure environments. + additional_info: Additional information for the neighbors set. + symmetry_measure_type: Type of symmetry measure to be used in the effective CSM. + max_effective_csm: Max CSM to use for the effective CSM calculation. + effective_csm_estimator_ratio_function: Ratio function to use to compute effective CSM. Returns: Effective CSM of a given Neighbors set. """ @@ -1536,16 +1612,15 @@ def get_effective_csm( return effective_csm -def set_info(additional_info, field, isite, cn_map, value): +def set_info(additional_info, field, isite, cn_map, value) -> None: """Set additional information for the weights. - :param additional_info: Additional information. - :param field: Type of additional information. - :param isite: Index of site to add info. - :param cn_map: Mapping index of the neighbors set. - :param value: Value of this additional information. - Returns: - None + Args: + additional_info: Additional information. + field: Type of additional information. + isite: Index of site to add info. + cn_map: Mapping index of the neighbors set. + value: Value of this additional information. """ try: additional_info[field][isite][cn_map] = value @@ -1579,9 +1654,10 @@ def __init__( ): """Initialize SelfCSMNbSetWeight. - :param effective_csm_estimator: Ratio function used for the effective CSM (comparison between neighbors sets). - :param weight_estimator: Weight estimator within a given neighbors set. - :param symmetry_measure_type: Type of symmetry measure to be used. + Args: + effective_csm_estimator: Ratio function used for the effective CSM (comparison between neighbors sets). + weight_estimator: Weight estimator within a given neighbors set. + symmetry_measure_type: Type of symmetry measure to be used. """ self.effective_csm_estimator = effective_csm_estimator self.effective_csm_estimator_rf = CSMInfiniteRatioFunction.from_dict(effective_csm_estimator) @@ -1593,10 +1669,11 @@ def __init__( def weight(self, nb_set, structure_environments, cn_map=None, additional_info=None): """Get the weight of a given neighbors set. - :param nb_set: Neighbors set. - :param structure_environments: Structure environments used to estimate weight. - :param cn_map: Mapping index for this neighbors set. - :param additional_info: Additional information. + Args: + nb_set: Neighbors set. + structure_environments: Structure environments used to estimate weight. + cn_map: Mapping index for this neighbors set. + additional_info: Additional information. Returns: Weight of the neighbors set. @@ -1641,10 +1718,11 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """Initialize from dict. - :param dct: Dict representation of SelfCSMNbSetWeight. + Args: + dct (dict): Dict representation of SelfCSMNbSetWeight. Returns: SelfCSMNbSetWeight. @@ -1680,10 +1758,11 @@ def __init__( ): """Initialize DeltaCSMNbSetWeight. - :param effective_csm_estimator: Ratio function used for the effective CSM (comparison between neighbors sets). - :param weight_estimator: Weight estimator within a given neighbors set. - :param delta_cn_weight_estimators: Specific weight estimators for specific cn - :param symmetry_measure_type: Type of symmetry measure to be used. + Args: + effective_csm_estimator: Ratio function used for the effective CSM (comparison between neighbors sets). + weight_estimator: Weight estimator within a given neighbors set. + delta_cn_weight_estimators: Specific weight estimators for specific cn + symmetry_measure_type: Type of symmetry measure to be used. """ self.effective_csm_estimator = effective_csm_estimator self.effective_csm_estimator_rf = CSMInfiniteRatioFunction.from_dict(effective_csm_estimator) @@ -1701,10 +1780,11 @@ def __init__( def weight(self, nb_set, structure_environments, cn_map=None, additional_info=None): """Get the weight of a given neighbors set. - :param nb_set: Neighbors set. - :param structure_environments: Structure environments used to estimate weight. - :param cn_map: Mapping index for this neighbors set. - :param additional_info: Additional information. + Args: + nb_set: Neighbors set. + structure_environments: Structure environments used to estimate weight. + cn_map: Mapping index for this neighbors set. + additional_info: Additional information. Returns: Weight of the neighbors set. @@ -1817,11 +1897,12 @@ def delta_cn_specifics( ): """Initialize DeltaCSMNbSetWeight from specific coordination number differences. - :param delta_csm_mins: Minimums for each coordination number. - :param delta_csm_maxs: Maximums for each coordination number. - :param function: Ratio function used. - :param symmetry_measure_type: Type of symmetry measure to be used. - :param effective_csm_estimator: Ratio function used for the effective CSM (comparison between neighbors sets). + Args: + delta_csm_mins: Minimums for each coordination number. + delta_csm_maxs: Maximums for each coordination number. + function: Ratio function used. + symmetry_measure_type: Type of symmetry measure to be used. + effective_csm_estimator: Ratio function used for the effective CSM (comparison between neighbors sets). Returns: DeltaCSMNbSetWeight. @@ -1875,10 +1956,11 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """Initialize from dict. - :param dct: Dict representation of DeltaCSMNbSetWeight. + Args: + dct (dict): Dict representation of DeltaCSMNbSetWeight. Returns: DeltaCSMNbSetWeight. @@ -1903,8 +1985,9 @@ class CNBiasNbSetWeight(NbSetWeight): def __init__(self, cn_weights, initialization_options): """Initialize CNBiasNbSetWeight. - :param cn_weights: Weights for each coordination. - :param initialization_options: Options for initialization. + Args: + cn_weights: Weights for each coordination. + initialization_options: Options for initialization. """ self.cn_weights = cn_weights self.initialization_options = initialization_options @@ -1912,10 +1995,11 @@ def __init__(self, cn_weights, initialization_options): def weight(self, nb_set, structure_environments, cn_map=None, additional_info=None): """Get the weight of a given neighbors set. - :param nb_set: Neighbors set. - :param structure_environments: Structure environments used to estimate weight. - :param cn_map: Mapping index for this neighbors set. - :param additional_info: Additional information. + Args: + nb_set: Neighbors set. + structure_environments: Structure environments used to estimate weight. + cn_map: Mapping index for this neighbors set. + additional_info: Additional information. Returns: Weight of the neighbors set. @@ -1938,10 +2022,11 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """Initialize from dict. - :param dct: Dict representation of CNBiasNbSetWeight. + Args: + dct (dict): Dict representation of CNBiasNbSetWeight. Returns: CNBiasNbSetWeight. @@ -1955,8 +2040,9 @@ def from_dict(cls, dct): def linearly_equidistant(cls, weight_cn1, weight_cn13): """Initialize linearly equidistant weights for each coordination. - :param weight_cn1: Weight of coordination 1. - :param weight_cn13: Weight of coordination 13. + Args: + weight_cn1: Weight of coordination 1. + weight_cn13: Weight of coordination 13. Returns: CNBiasNbSetWeight. @@ -1974,8 +2060,9 @@ def linearly_equidistant(cls, weight_cn1, weight_cn13): def geometrically_equidistant(cls, weight_cn1, weight_cn13): """Initialize geometrically equidistant weights for each coordination. - :param weight_cn1: Weight of coordination 1. - :param weight_cn13: Weight of coordination 13. + Arge: + weight_cn1: Weight of coordination 1. + weight_cn13: Weight of coordination 13. Returns: CNBiasNbSetWeight. @@ -1993,7 +2080,8 @@ def geometrically_equidistant(cls, weight_cn1, weight_cn13): def explicit(cls, cn_weights): """Initialize weights explicitly for each coordination. - :param cn_weights: Weights for each coordination. + Args: + cn_weights: Weights for each coordination. Returns: CNBiasNbSetWeight. @@ -2004,10 +2092,11 @@ def explicit(cls, cn_weights): return cls(cn_weights=cn_weights, initialization_options=initialization_options) @classmethod - def from_description(cls, dct): + def from_description(cls, dct: dict) -> Self: """Initialize weights from description. - :param dct: Dictionary description. + Args: + dct (dict): Dictionary description. Returns: CNBiasNbSetWeight. @@ -2018,7 +2107,8 @@ def from_description(cls, dct): return cls.geometrically_equidistant(weight_cn1=dct["weight_cn1"], weight_cn13=dct["weight_cn13"]) if dct["type"] == "explicit": return cls.explicit(cn_weights=dct["cn_weights"]) - return None + + raise RuntimeError("Cannot initialize Weights.") class DistanceAngleAreaNbSetWeight(NbSetWeight): @@ -2045,13 +2135,14 @@ def __init__( ): """Initialize CNBiasNbSetWeight. - :param weight_type: Type of weight. - :param surface_definition: Definition of the surface. - :param nb_sets_from_hints: How to deal with neighbors sets obtained from "hints". - :param other_nb_sets: What to do with other neighbors sets. - :param additional_condition: Additional condition to be used. - :param smoothstep_distance: Smoothstep distance. - :param smoothstep_angle: Smoothstep angle. + Args: + weight_type: Type of weight. + surface_definition: Definition of the surface. + nb_sets_from_hints: How to deal with neighbors sets obtained from "hints". + other_nb_sets: What to do with other neighbors sets. + additional_condition: Additional condition to be used. + smoothstep_distance: Smoothstep distance. + smoothstep_angle: Smoothstep angle. """ self.weight_type = weight_type if weight_type == "has_intersection": @@ -2084,10 +2175,11 @@ def __init__( def weight(self, nb_set, structure_environments, cn_map=None, additional_info=None): """Get the weight of a given neighbors set. - :param nb_set: Neighbors set. - :param structure_environments: Structure environments used to estimate weight. - :param cn_map: Mapping index for this neighbors set. - :param additional_info: Additional information. + Args: + nb_set: Neighbors set. + structure_environments: Structure environments used to estimate weight. + cn_map: Mapping index for this neighbors set. + additional_info: Additional information. Returns: Weight of the neighbors set. @@ -2102,10 +2194,11 @@ def weight(self, nb_set, structure_environments, cn_map=None, additional_info=No def w_area_has_intersection(self, nb_set, structure_environments, cn_map, additional_info): """Get intersection of the neighbors set area with the surface. - :param nb_set: Neighbors set. - :param structure_environments: Structure environments. - :param cn_map: Mapping index of the neighbors set. - :param additional_info: Additional information. + Args: + nb_set: Neighbors set. + structure_environments: Structure environments. + cn_map: Mapping index of the neighbors set. + additional_info: Additional information. Returns: Area intersection between neighbors set and surface. @@ -2120,10 +2213,11 @@ def w_area_has_intersection(self, nb_set, structure_environments, cn_map, additi def w_area_intersection_nbsfh_fbs_onb0(self, nb_set, structure_environments, cn_map, additional_info): """Get intersection of the neighbors set area with the surface. - :param nb_set: Neighbors set. - :param structure_environments: Structure environments. - :param cn_map: Mapping index of the neighbors set. - :param additional_info: Additional information. + Args: + nb_set: Neighbors set. + structure_environments: Structure environments. + cn_map: Mapping index of the neighbors set. + additional_info: Additional information. Returns: Area intersection between neighbors set and surface. @@ -2169,10 +2263,11 @@ def w_area_intersection_nbsfh_fbs_onb0(self, nb_set, structure_environments, cn_ def rectangle_crosses_area(self, d1, d2, a1, a2): """Whether a given rectangle crosses the area defined by the upper and lower curves. - :param d1: lower d. - :param d2: upper d. - :param a1: lower a. - :param a2: upper a. + Args: + d1: lower d. + d2: upper d. + a1: lower a. + a2: upper a. """ # Case 1 if d1 <= self.dmin and d2 <= self.dmin: @@ -2235,10 +2330,11 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """Initialize from dict. - :param dct: Dict representation of DistanceAngleAreaNbSetWeight. + Args: + dct (dict): Dict representation of DistanceAngleAreaNbSetWeight. Returns: DistanceAngleAreaNbSetWeight. @@ -2260,8 +2356,9 @@ class DistancePlateauNbSetWeight(NbSetWeight): def __init__(self, distance_function=None, weight_function=None): """Initialize DistancePlateauNbSetWeight. - :param distance_function: Distance function to use. - :param weight_function: Ratio function to use. + Args: + distance_function: Distance function to use. + weight_function: Ratio function to use. """ if distance_function is None: self.distance_function = {"type": "normalized_distance"} @@ -2279,10 +2376,11 @@ def __init__(self, distance_function=None, weight_function=None): def weight(self, nb_set, structure_environments, cn_map=None, additional_info=None): """Get the weight of a given neighbors set. - :param nb_set: Neighbors set. - :param structure_environments: Structure environments used to estimate weight. - :param cn_map: Mapping index for this neighbors set. - :param additional_info: Additional information. + Args: + nb_set: Neighbors set. + structure_environments: Structure environments used to estimate weight. + cn_map: Mapping index for this neighbors set. + additional_info: Additional information. Returns: Weight of the neighbors set. @@ -2302,10 +2400,11 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """Initialize from dict. - :param dct: Dict representation of DistancePlateauNbSetWeight. + Args: + dct (dict): Dict representation of DistancePlateauNbSetWeight. Returns: DistancePlateauNbSetWeight. @@ -2324,8 +2423,9 @@ class AnglePlateauNbSetWeight(NbSetWeight): def __init__(self, angle_function=None, weight_function=None): """Initialize AnglePlateauNbSetWeight. - :param angle_function: Angle function to use. - :param weight_function: Ratio function to use. + Args: + angle_function: Angle function to use. + weight_function: Ratio function to use. """ if angle_function is None: self.angle_function = {"type": "normalized_angle"} @@ -2343,10 +2443,11 @@ def __init__(self, angle_function=None, weight_function=None): def weight(self, nb_set, structure_environments, cn_map=None, additional_info=None): """Get the weight of a given neighbors set. - :param nb_set: Neighbors set. - :param structure_environments: Structure environments used to estimate weight. - :param cn_map: Mapping index for this neighbors set. - :param additional_info: Additional information. + Args: + nb_set: Neighbors set. + structure_environments: Structure environments used to estimate weight. + cn_map: Mapping index for this neighbors set. + additional_info: Additional information. Returns: Weight of the neighbors set. @@ -2366,10 +2467,11 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """Initialize from dict. - :param dct: Dict representation of AnglePlateauNbSetWeight. + Args: + dct (dict): Dict representation of AnglePlateauNbSetWeight. Returns: AnglePlateauNbSetWeight. @@ -2385,8 +2487,9 @@ class DistanceNbSetWeight(NbSetWeight): def __init__(self, weight_function=None, nbs_source="voronoi"): """Initialize DistanceNbSetWeight. - :param weight_function: Ratio function to use. - :param nbs_source: Source of the neighbors. + Args: + weight_function: Ratio function to use. + nbs_source: Source of the neighbors. """ if weight_function is None: self.weight_function = { @@ -2403,10 +2506,11 @@ def __init__(self, weight_function=None, nbs_source="voronoi"): def weight(self, nb_set, structure_environments, cn_map=None, additional_info=None): """Get the weight of a given neighbors set. - :param nb_set: Neighbors set. - :param structure_environments: Structure environments used to estimate weight. - :param cn_map: Mapping index for this neighbors set. - :param additional_info: Additional information. + Args: + nb_set: Neighbors set. + structure_environments: Structure environments used to estimate weight. + cn_map: Mapping index for this neighbors set. + additional_info: Additional information. Returns: Weight of the neighbors set. @@ -2444,10 +2548,11 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """Initialize from dict. - :param dct: Dict representation of DistanceNbSetWeight. + Args: + dct (dict): Dict representation of DistanceNbSetWeight. Returns: DistanceNbSetWeight. @@ -2463,8 +2568,9 @@ class DeltaDistanceNbSetWeight(NbSetWeight): def __init__(self, weight_function=None, nbs_source="voronoi"): """Initialize DeltaDistanceNbSetWeight. - :param weight_function: Ratio function to use. - :param nbs_source: Source of the neighbors. + Args: + weight_function: Ratio function to use. + nbs_source: Source of the neighbors. """ if weight_function is None: self.weight_function = { @@ -2481,10 +2587,11 @@ def __init__(self, weight_function=None, nbs_source="voronoi"): def weight(self, nb_set, structure_environments, cn_map=None, additional_info=None): """Get the weight of a given neighbors set. - :param nb_set: Neighbors set. - :param structure_environments: Structure environments used to estimate weight. - :param cn_map: Mapping index for this neighbors set. - :param additional_info: Additional information. + Args: + nb_set: Neighbors set. + structure_environments: Structure environments used to estimate weight. + cn_map: Mapping index for this neighbors set. + additional_info: Additional information. Returns: Weight of the neighbors set. @@ -2525,10 +2632,11 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """Initialize from dict. - :param dct: Dict representation of DeltaDistanceNbSetWeight. + Args: + dct (dict): Dict representation of DeltaDistanceNbSetWeight. Returns: DeltaDistanceNbSetWeight. @@ -2555,8 +2663,10 @@ def __init__( ): """ Constructor for the WeightedNbSetChemenvStrategy. - :param structure_environments: StructureEnvironments object containing all the information on the - coordination of the sites in a structure. + + Args: + structure_environments: StructureEnvironments object containing all the information on the + coordination of the sites in a structure. """ if nb_set_weights is None: raise ValueError(f"{nb_set_weights=} must be provided") @@ -2590,15 +2700,16 @@ def get_site_coordination_environments_fractions( ): """Get the coordination environments of a given site and additional information. - :param site: Site for which coordination environment is needed. - :param isite: Index of the site. - :param dequivsite: Translation of the equivalent site. - :param dthissite: Translation of this site. - :param mysym: Symmetry to be applied. - :param ordered: Whether to order the list by fractions. - :param min_fraction: Minimum fraction to include in the list - :param return_maps: Whether to return cn_maps (identifies all the NeighborsSet used). - :param return_strategy_dict_info: Whether to add the info about the strategy used. + Args: + site: Site for which coordination environment is needed. + isite: Index of the site. + dequivsite: Translation of the equivalent site. + dthissite: Translation of this site. + mysym: Symmetry to be applied. + ordered: Whether to order the list by fractions. + min_fraction: Minimum fraction to include in the list + return_maps: Whether to return cn_maps (identifies all the NeighborsSet used). + return_strategy_dict_info: Whether to add the info about the strategy used. Returns: List of Dict with coordination environment, fraction and additional info. @@ -2761,12 +2872,13 @@ def get_site_coordination_environments( ): """Get the coordination environments of a given site. - :param site: Site for which coordination environment is needed. - :param isite: Index of the site. - :param dequivsite: Translation of the equivalent site. - :param dthissite: Translation of this site. - :param mysym: Symmetry to be applied. - :param return_maps: Whether to return cn_maps (identifies all the NeighborsSet used). + Args: + site: Site for which coordination environment is needed. + isite: Index of the site. + dequivsite: Translation of the equivalent site. + dthissite: Translation of this site. + mysym: Symmetry to be applied. + return_maps: Whether to return cn_maps (identifies all the NeighborsSet used). Returns: List of coordination environment. @@ -2812,11 +2924,13 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct) -> WeightedNbSetChemenvStrategy: + def from_dict(cls, dct: dict) -> Self: """ Reconstructs the WeightedNbSetChemenvStrategy object from a dict representation of the WeightedNbSetChemenvStrategy object created using the as_dict method. - :param dct: dict representation of the WeightedNbSetChemenvStrategy object + + Args: + dct: dict representation of the WeightedNbSetChemenvStrategy object Returns: WeightedNbSetChemenvStrategy object. @@ -2832,7 +2946,7 @@ def from_dict(cls, dct) -> WeightedNbSetChemenvStrategy: class MultiWeightsChemenvStrategy(WeightedNbSetChemenvStrategy): """MultiWeightsChemenvStrategy.""" - STRATEGY_DESCRIPTION = " Multi Weights ChemenvStrategy" + STRATEGY_DESCRIPTION = "Multi Weights ChemenvStrategy" # STRATEGY_INFO_FIELDS = ['cn_map_surface_fraction', 'cn_map_surface_weight', # 'cn_map_mean_csm', 'cn_map_csm_weight', # 'cn_map_delta_csm', 'cn_map_delta_csms_cn_map2', 'cn_map_delta_csm_weight', @@ -2858,8 +2972,10 @@ def __init__( ): """ Constructor for the MultiWeightsChemenvStrategy. - :param structure_environments: StructureEnvironments object containing all the information on the - coordination of the sites in a structure. + + Args: + structure_environments: StructureEnvironments object containing all the information on the + coordination of the sites in a structure. """ self._additional_condition = additional_condition self.dist_ang_area_weight = dist_ang_area_weight @@ -2985,11 +3101,13 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct) -> MultiWeightsChemenvStrategy: + def from_dict(cls, dct: dict) -> Self: """ Reconstructs the MultiWeightsChemenvStrategy object from a dict representation of the MultipleAbundanceChemenvStrategy object created using the as_dict method. - :param dct: dict representation of the MultiWeightsChemenvStrategy object + + Args: + dct: dict representation of the MultiWeightsChemenvStrategy object Returns: MultiWeightsChemenvStrategy object. diff --git a/pymatgen/analysis/chemenv/coordination_environments/coordination_geometries.py b/pymatgen/analysis/chemenv/coordination_environments/coordination_geometries.py index 6afc40db794..9605fc7ff4c 100644 --- a/pymatgen/analysis/chemenv/coordination_environments/coordination_geometries.py +++ b/pymatgen/analysis/chemenv/coordination_environments/coordination_geometries.py @@ -13,11 +13,15 @@ import itertools import json import os +from typing import TYPE_CHECKING import numpy as np from monty.json import MontyDecoder, MSONable from scipy.special import factorial +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "David Waroquiers" __copyright__ = "Copyright 2012, The Materials Project" __credits__ = "Geoffroy Hautier" @@ -109,7 +113,7 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """ Reconstruct ExplicitPermutationsAlgorithm from its JSON-serializable dict representation. """ @@ -241,7 +245,7 @@ def safe_separation_permutations(self, ordered_plane=False, ordered_point_groups number of permutations. add_opposite: Whether to add the permutations from the second group before the first group as well. - Returns + Returns: list[int]: safe permutations. """ s0 = list(range(len(self.point_groups[0]))) @@ -324,7 +328,7 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """ Reconstructs the SeparationPlane algorithm from its JSON-serializable dict representation. @@ -497,7 +501,7 @@ def as_dict(self): return {"hints_type": self.hints_type, "options": self.options} @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """Reconstructs the NeighborsSetsHints from its JSON-serializable dict representation.""" return cls(hints_type=dct["hints_type"], options=dct["options"]) @@ -592,7 +596,7 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """ Reconstructs the CoordinationGeometry from its JSON-serializable dict representation. @@ -602,7 +606,6 @@ def from_dict(cls, dct): Returns: CoordinationGeometry """ - dec = MontyDecoder() return cls( mp_symbol=dct["mp_symbol"], name=dct["name"], @@ -620,7 +623,7 @@ def from_dict(cls, dct): deactivate=dct["deactivate"], faces=dct["_faces"], edges=dct["_edges"], - algorithms=[dec.process_decoded(algo_d) for algo_d in dct["_algorithms"]] + algorithms=[MontyDecoder().process_decoded(algo_d) for algo_d in dct["_algorithms"]] if dct["_algorithms"] is not None else None, equivalent_indices=dct.get("equivalent_indices"), @@ -794,7 +797,7 @@ def faces(self, sites, permutation=None): list of its vertices coordinates. """ coords = [site.coords for site in sites] if permutation is None else [sites[ii].coords for ii in permutation] - return [[coords[ii] for ii in f] for f in self._faces] + return [[coords[ii] for ii in face] for face in self._faces] def edges(self, sites, permutation=None, input="sites"): """ @@ -811,7 +814,7 @@ def edges(self, sites, permutation=None, input="sites"): # coords = [sites[ii].coords for ii in permutation] if permutation is not None: coords = [coords[ii] for ii in permutation] - return [[coords[ii] for ii in e] for e in self._edges] + return [[coords[ii] for ii in edge] for edge in self._edges] def solid_angles(self, permutation=None): """ @@ -851,11 +854,11 @@ def get_pmeshes(self, sites, permutation=None): elif len(face) == 4: out += "5\n" else: - for ii, f in enumerate(face): + for ii, f in enumerate(face, start=1): out += "4\n" out += f"{len(_vertices) + iface}\n" out += f"{f}\n" - out += f"{face[np.mod(ii + 1, len(face))]}\n" + out += f"{face[np.mod(ii, len(face))]}\n" out += f"{len(_vertices) + iface}\n" if len(face) in [3, 4]: for face_vertex in face: @@ -1186,7 +1189,8 @@ def is_a_valid_coordination_geometry( return True except LookupError: return True - raise Exception("Should not be here !") + # TODO give a more helpful error message that suggests possible reasons and solutions + raise RuntimeError("Should not be here!") def pretty_print(self, type="implemented_geometries", maxcn=8, additional_info=None): """ diff --git a/pymatgen/analysis/chemenv/coordination_environments/coordination_geometry_finder.py b/pymatgen/analysis/chemenv/coordination_environments/coordination_geometry_finder.py index 165978b3f97..42bc3c2f305 100644 --- a/pymatgen/analysis/chemenv/coordination_environments/coordination_geometry_finder.py +++ b/pymatgen/analysis/chemenv/coordination_environments/coordination_geometry_finder.py @@ -18,6 +18,7 @@ import logging import time from random import shuffle +from typing import TYPE_CHECKING import numpy as np from numpy.linalg import norm, svd @@ -46,6 +47,9 @@ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from pymatgen.util.due import Doi, due +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "David Waroquiers" __copyright__ = "Copyright 2012, The Materials Project" __credits__ = "Geoffroy Hautier" @@ -80,12 +84,16 @@ def __init__( ): """ Constructor for the abstract geometry - :param central_site: Coordinates of the central site - :param bare_coords: Coordinates of the neighbors of the central site - :param centering_type: How to center the abstract geometry - :param include_central_site_in_centroid: When the centering is on the centroid, the central site is included - if this parameter is set to True. - :raise: ValueError if the parameters are not consistent. + + Args: + central_site: Coordinates of the central site + bare_coords: Coordinates of the neighbors of the central site + centering_type: How to center the abstract geometry + include_central_site_in_centroid: When the centering is on the centroid, + the central site is included if this parameter is set to True. + + Raises: + ValueError if the parameters are not consistent. """ bcoords = np.array(bare_coords) self.bare_centre = np.array(central_site) @@ -185,11 +193,12 @@ def __str__(self): return "\n".join(outs) @classmethod - def from_cg(cls, cg, centering_type="standard", include_central_site_in_centroid=False): + def from_cg(cls, cg, centering_type="standard", include_central_site_in_centroid=False) -> Self: """ - :param cg: - :param centering_type: - :param include_central_site_in_centroid: + Args: + cg: + centering_type: + include_central_site_in_centroid: """ central_site = cg.get_central_site() bare_coords = [np.array(pt, float) for pt in cg.points] @@ -202,7 +211,8 @@ def from_cg(cls, cg, centering_type="standard", include_central_site_in_centroid def points_wcs_csc(self, permutation=None): """ - :param permutation: + Args: + permutation: """ if permutation is None: return self._points_wcs_csc @@ -210,7 +220,8 @@ def points_wcs_csc(self, permutation=None): def points_wocs_csc(self, permutation=None): """ - :param permutation: + Args: + permutation: """ if permutation is None: return self._points_wocs_csc @@ -218,7 +229,8 @@ def points_wocs_csc(self, permutation=None): def points_wcs_ctwcc(self, permutation=None): """ - :param permutation: + Args: + permutation: """ if permutation is None: return self._points_wcs_ctwcc @@ -231,7 +243,8 @@ def points_wcs_ctwcc(self, permutation=None): def points_wocs_ctwcc(self, permutation=None): """ - :param permutation: + Args: + permutation: """ if permutation is None: return self._points_wocs_ctwcc @@ -239,7 +252,8 @@ def points_wocs_ctwcc(self, permutation=None): def points_wcs_ctwocc(self, permutation=None): """ - :param permutation: + Args: + permutation: """ if permutation is None: return self._points_wcs_ctwocc @@ -252,7 +266,8 @@ def points_wcs_ctwocc(self, permutation=None): def points_wocs_ctwocc(self, permutation=None): """ - :param permutation: + Args: + permutation: """ if permutation is None: return self._points_wocs_ctwocc @@ -273,10 +288,13 @@ def symmetry_measure(points_distorted, points_perfect): """ Computes the continuous symmetry measure of the (distorted) set of points "points_distorted" with respect to the (perfect) set of points "points_perfect". - :param points_distorted: List of points describing a given (distorted) polyhedron for which the symmetry measure - has to be computed with respect to the model polyhedron described by the list of points - "points_perfect". - :param points_perfect: List of "perfect" points describing a given model polyhedron. + + Args: + points_distorted: List of points describing a given (distorted) polyhedron for which the symmetry measure + has to be computed with respect to the model polyhedron described by the list of points + "points_perfect". + points_perfect: List of "perfect" points describing a given model polyhedron. + Returns: The continuous symmetry measure of the distorted polyhedron with respect to the perfect polyhedron. """ @@ -306,9 +324,12 @@ def find_rotation(points_distorted, points_perfect): """ This finds the rotation matrix that aligns the (distorted) set of points "points_distorted" with respect to the (perfect) set of points "points_perfect" in a least-square sense. - :param points_distorted: List of points describing a given (distorted) polyhedron for which the rotation that - aligns these points in a least-square sense to the set of perfect points "points_perfect" - :param points_perfect: List of "perfect" points describing a given model polyhedron. + + Args: + points_distorted: List of points describing a given (distorted) polyhedron for which the rotation that + aligns these points in a least-square sense to the set of perfect points "points_perfect" + points_perfect: List of "perfect" points describing a given model polyhedron. + Returns: The rotation matrix. """ @@ -321,10 +342,12 @@ def find_scaling_factor(points_distorted, points_perfect, rot): """ This finds the scaling factor between the (distorted) set of points "points_distorted" and the (perfect) set of points "points_perfect" in a least-square sense. - :param points_distorted: List of points describing a given (distorted) polyhedron for which the scaling factor has - to be obtained. - :param points_perfect: List of "perfect" points describing a given model polyhedron. - :param rot: The rotation matrix + + Args: + points_distorted: List of points describing a given (distorted) polyhedron for + which the scaling factor has to be obtained. + points_perfect: List of "perfect" points describing a given model polyhedron. + rot: The rotation matrix Returns: The scaling factor between the two structures and the rotated set of (distorted) points. @@ -403,15 +426,17 @@ def setup_parameters( chosen. This can be the centroid of the structure (including or excluding the atom for which the coordination geometry is looked for) or the atom itself. In the 'standard' centering_type, the reference point is the central atom for coordination numbers 1, 2, 3 and 4 and the centroid for coordination numbers > 4. - :param centering_type: Type of the reference point (centering) 'standard', 'centroid' or 'central_site' - :param include_central_site_in_centroid: In case centering_type is 'centroid', the central site is included if - this value is set to True. - :param bva_distance_scale_factor: Scaling factor for the bond valence analyzer (this might be different whether - the structure is an experimental one, an LDA or a GGA relaxed one, or any other relaxation scheme (where - under- or over-estimation of bond lengths is known). - :param structure_refinement: Refinement of the structure. Can be "none", "refined" or "symmetrized". - :param spg_analyzer_options: Options for the SpaceGroupAnalyzer (dictionary specifying "symprec" - and "angle_tolerance". See pymatgen's SpaceGroupAnalyzer for more information. + + Args: + centering_type: Type of the reference point (centering) 'standard', 'centroid' or 'central_site' + include_central_site_in_centroid: In case centering_type is 'centroid', the central site is included if + this value is set to True. + bva_distance_scale_factor: Scaling factor for the bond valence analyzer (this might be different whether + the structure is an experimental one, an LDA or a GGA relaxed one, or any other relaxation scheme (where + under- or over-estimation of bond lengths is known). + structure_refinement: Refinement of the structure. Can be "none", "refined" or "symmetrized". + spg_analyzer_options: Options for the SpaceGroupAnalyzer (dictionary specifying "symprec" + and "angle_tolerance". See pymatgen's SpaceGroupAnalyzer for more information. """ self.centering_type = centering_type self.include_central_site_in_centroid = include_central_site_in_centroid @@ -429,8 +454,10 @@ def setup_parameter(self, parameter, value): """ Setup of one specific parameter to the given value. The other parameters are unchanged. See setup_parameters method for the list of possible parameters - :param parameter: Parameter to setup/update - :param value: Value of the parameter. + + Args: + parameter: Parameter to setup/update + value: Value of the parameter. """ self.__dict__[parameter] = value @@ -438,7 +465,9 @@ def setup_structure(self, structure: Structure): """ Sets up the structure for which the coordination geometries have to be identified. The structure is analyzed with the space group analyzer and a refined structure is used - :param structure: A pymatgen Structure. + + Args: + structure: A pymatgen Structure. """ self.initial_structure = structure.copy() if self.structure_refinement == self.STRUCTURE_REFINEMENT_NONE: @@ -477,10 +506,12 @@ def set_structure(self, lattice: Lattice, species, coords, coords_are_cartesian) """ Sets up the pymatgen structure for which the coordination geometries have to be identified starting from the lattice, the species and the coordinates - :param lattice: The lattice of the structure - :param species: The species on the sites - :param coords: The coordinates of the sites - :param coords_are_cartesian: If set to True, the coordinates are given in Cartesian coordinates. + + Args: + lattice: The lattice of the structure + species: The species on the sites + coords: The coordinates of the sites + coords_are_cartesian: If set to True, the coordinates are given in Cartesian coordinates. """ self.setup_structure(Structure(lattice, species, coords, coords_are_cartesian)) @@ -494,12 +525,13 @@ def compute_coordination_environments( initial_structure_environments=None, ): """ - :param structure: - :param indices: - :param only_cations: - :param strategy: - :param valences: - :param initial_structure_environments: + Args: + structure: + indices: + only_cations: + strategy: + valences: + initial_structure_environments: """ self.setup_structure(structure=structure) if valences == "bond-valence-analysis": @@ -553,34 +585,36 @@ def compute_structure_environments( """ Computes and returns the StructureEnvironments object containing all the information about the coordination environments in the structure - :param excluded_atoms: Atoms for which the coordination geometries does not have to be identified - :param only_atoms: If not set to None, atoms for which the coordination geometries have to be identified - :param only_cations: If set to True, will only compute environments for cations - :param only_indices: If not set to None, will only compute environments the atoms of the given indices - :param maximum_distance_factor: If not set to None, neighbors beyond - maximum_distance_factor*closest_neighbor_distance are not considered - :param minimum_angle_factor: If not set to None, neighbors for which the angle is lower than - minimum_angle_factor*largest_angle_neighbor are not considered - :param max_cn: maximum coordination number to be considered - :param min_cn: minimum coordination number to be considered - :param only_symbols: if not set to None, consider only coordination environments with the given symbols - :param valences: valences of the atoms - :param additional_conditions: additional conditions to be considered in the bonds (example : only bonds - between cation and anion - :param info: additional info about the calculation - :param timelimit: time limit (in secs) after which the calculation of the StructureEnvironments object stops - :param initial_structure_environments: initial StructureEnvironments object (most probably incomplete) - :param get_from_hints: whether to add neighbors sets from "hints" (e.g. capped environment => test the - neighbors without the cap) - :param voronoi_normalized_distance_tolerance: tolerance for the normalized distance used to distinguish - neighbors sets - :param voronoi_normalized_angle_tolerance: tolerance for the normalized angle used to distinguish - neighbors sets - :param voronoi_distance_cutoff: determines distance of considered neighbors. Especially important to increase it - for molecules in a box. - :param recompute: whether to recompute the sites already computed (when initial_structure_environments - is not None) - :param optimization: optimization algorithm + + Args: + excluded_atoms: Atoms for which the coordination geometries does not have to be identified + only_atoms: If not set to None, atoms for which the coordination geometries have to be identified + only_cations: If set to True, will only compute environments for cations + only_indices: If not set to None, will only compute environments the atoms of the given indices + maximum_distance_factor: If not set to None, neighbors beyond + maximum_distance_factor*closest_neighbor_distance are not considered + minimum_angle_factor: If not set to None, neighbors for which the angle is lower than + minimum_angle_factor*largest_angle_neighbor are not considered + max_cn: maximum coordination number to be considered + min_cn: minimum coordination number to be considered + only_symbols: if not set to None, consider only coordination environments with the given symbols + valences: valences of the atoms + additional_conditions: additional conditions to be considered in the bonds (example : only bonds + between cation and anion + info: additional info about the calculation + timelimit: time limit (in secs) after which the calculation of the StructureEnvironments object stops + initial_structure_environments: initial StructureEnvironments object (most probably incomplete) + get_from_hints: whether to add neighbors sets from "hints" (e.g. capped environment => test the + neighbors without the cap) + voronoi_normalized_distance_tolerance: tolerance for the normalized distance used to distinguish + neighbors sets + voronoi_normalized_angle_tolerance: tolerance for the normalized angle used to distinguish + neighbors sets + voronoi_distance_cutoff: determines distance of considered neighbors. Especially important to increase it + for molecules in a box. + recompute: whether to recompute the sites already computed (when initial_structure_environments + is not None) + optimization: optimization algorithm Returns: The StructureEnvironments object containing all the information about the coordination @@ -848,13 +882,14 @@ def compute_structure_environments( def update_nb_set_environments(self, se, isite, cn, inb_set, nb_set, recompute=False, optimization=None): """ - :param se: - :param isite: - :param cn: - :param inb_set: - :param nb_set: - :param recompute: - :param optimization: + Args: + se: + isite: + cn: + inb_set: + nb_set: + recompute: + optimization: """ ce = se.get_coordination_environments(isite=isite, cn=cn, nb_set=nb_set) if ce is not None and not recompute: @@ -915,8 +950,10 @@ def update_nb_set_environments(self, se, isite, cn, inb_set, nb_set, recompute=F def setup_local_geometry(self, isite, coords, optimization=None): """ Sets up the AbstractGeometry for the local geometry of site with index isite. - :param isite: Index of the site for which the local geometry has to be set up - :param coords: The coordinates of the (local) neighbors. + + Args: + isite: Index of the site for which the local geometry has to be set up + coords: The coordinates of the (local) neighbors. """ self.local_geometry = AbstractGeometry( central_site=self.structure.cart_coords[isite], @@ -939,15 +976,16 @@ def setup_test_perfect_environment( points=None, ): """ - :param symbol: - :param randomness: - :param max_random_dist: - :param symbol_type: - :param indices: - :param random_translation: - :param random_rotation: - :param random_scale: - :param points: + Args: + symbol: + randomness: + max_random_dist: + symbol_type: + indices: + random_translation: + random_rotation: + random_scale: + points: """ if symbol_type == "IUPAC": cg = self.allcg.get_geometry_from_IUPAC_symbol(symbol) @@ -1072,7 +1110,9 @@ def setup_test_perfect_environment( def setup_random_structure(self, coordination): """ Sets up a purely random structure with a given coordination. - :param coordination: coordination number for the random structure. + + Args: + coordination: coordination number for the random structure. """ aa = 0.4 bb = -0.2 @@ -1090,7 +1130,9 @@ def setup_random_structure(self, coordination): def setup_random_indices_local_geometry(self, coordination): """ Sets up random indices for the local geometry, for testing purposes - :param coordination: coordination of the local geometry. + + Args: + coordination: coordination of the local geometry. """ self.icentral_site = 0 self.indices = list(range(1, coordination + 1)) @@ -1099,7 +1141,9 @@ def setup_random_indices_local_geometry(self, coordination): def setup_ordered_indices_local_geometry(self, coordination): """ Sets up ordered indices for the local geometry, for testing purposes - :param coordination: coordination of the local geometry. + + Args: + coordination: coordination of the local geometry. """ self.icentral_site = 0 self.indices = list(range(1, coordination + 1)) @@ -1107,7 +1151,9 @@ def setup_ordered_indices_local_geometry(self, coordination): def setup_explicit_indices_local_geometry(self, explicit_indices): """ Sets up explicit indices for the local geometry, for testing purposes - :param explicit_indices: explicit indices for the neighbors (set of numbers + + Args: + explicit_indices: explicit indices for the neighbors (set of numbers from 0 to CN-1 in a given order). """ self.icentral_site = 0 @@ -1306,7 +1352,8 @@ def coordination_geometry_symmetry_measures( permutations depending on the permutation setup. Depending on the parameters of the LocalGeometryFinder and on the coordination geometry, different methods are called. - :param coordination_geometry: Coordination geometry for which the symmetry measures are looked for + Args: + coordination_geometry: Coordination geometry for which the symmetry measures are looked for Raises: NotImplementedError: if the permutation_setup does not exist @@ -1354,7 +1401,8 @@ def coordination_geometry_symmetry_measures_sepplane_optim( permutations depending on the permutation setup. Depending on the parameters of the LocalGeometryFinder and on the coordination geometry, different methods are called. - :param coordination_geometry: Coordination geometry for which the symmetry measures are looked for + Args: + coordination_geometry: Coordination geometry for which the symmetry measures are looked for Raises: NotImplementedError: if the permutation_setup does not exist @@ -1392,7 +1440,9 @@ def coordination_geometry_symmetry_measures_standard( Returns the symmetry measures for a set of permutations (whose setup depends on the coordination geometry) for the coordination geometry "coordination_geometry". Standard implementation looking for the symmetry measures of each permutation - :param coordination_geometry: The coordination geometry to be investigated + + Args: + coordination_geometry: The coordination geometry to be investigated Returns: The symmetry measures for the given coordination geometry for each permutation investigated. @@ -1471,7 +1521,9 @@ def coordination_geometry_symmetry_measures_separation_plane( """ Returns the symmetry measures of the given coordination geometry "coordination_geometry" using separation facets to reduce the complexity of the system. Caller to the refined 2POINTS, 3POINTS and other ... - :param coordination_geometry: The coordination geometry to be investigated + + Args: + coordination_geometry: The coordination geometry to be investigated Returns: The symmetry measures for the given coordination geometry for each plane and permutation investigated. @@ -1995,8 +2047,10 @@ def coordination_geometry_symmetry_measures_fallback_random( Returns the symmetry measures for a random set of permutations for the coordination geometry "coordination_geometry". Fallback implementation for the plane separation algorithms measures of each permutation - :param coordination_geometry: The coordination geometry to be investigated - :param NRANDOM: Number of random permutations to be tested + + Args: + coordination_geometry: The coordination geometry to be investigated + NRANDOM: Number of random permutations to be tested Returns: The symmetry measures for the given coordination geometry for each permutation investigated. diff --git a/pymatgen/analysis/chemenv/coordination_environments/structure_environments.py b/pymatgen/analysis/chemenv/coordination_environments/structure_environments.py index f4a89a809a8..d4d2217a12b 100644 --- a/pymatgen/analysis/chemenv/coordination_environments/structure_environments.py +++ b/pymatgen/analysis/chemenv/coordination_environments/structure_environments.py @@ -8,6 +8,8 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import matplotlib.pyplot as plt import numpy as np from matplotlib import cm @@ -22,6 +24,9 @@ from pymatgen.analysis.chemenv.utils.defs_utils import AdditionalConditions from pymatgen.core import Element, PeriodicNeighbor, PeriodicSite, Species, Structure +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "David Waroquiers" __copyright__ = "Copyright 2012, The Materials Project" __credits__ = "Geoffroy Hautier" @@ -366,7 +371,7 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct, structure: Structure, detailed_voronoi): + def from_dict(cls, dct, structure: Structure, detailed_voronoi) -> Self: """ Reconstructs the NeighborsSet algorithm from its JSON-serializable dict representation, together with the structure and the DetailedVoronoiContainer. @@ -497,7 +502,7 @@ def init_neighbors_sets(self, isite, additional_conditions=None, valences=None): } site_voronoi_indices = [ inb - for inb, voro_nb_dict in enumerate(site_voronoi) + for inb, _voro_nb_dict in enumerate(site_voronoi) if ( distance_conditions[idp][inb] and angle_conditions[iap][inb] @@ -1247,7 +1252,7 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct: dict) -> StructureEnvironments: + def from_dict(cls, dct: dict) -> Self: """ Reconstructs the StructureEnvironments object from a dict representation of the StructureEnvironments created using the as_dict method. @@ -1419,7 +1424,7 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct, structure: Structure, all_nbs_sites): + def from_dict(cls, dct, structure: Structure, all_nbs_sites) -> Self: """ Reconstructs the NeighborsSet algorithm from its JSON-serializable dict representation, together with the structure and all the possible neighbors sites. @@ -1476,7 +1481,7 @@ def __init__( self.valences_origin = valences_origin @classmethod - def from_structure_environments(cls, strategy, structure_environments, valences=None, valences_origin=None): + def from_structure_environments(cls, strategy, structure_environments, valences=None, valences_origin=None) -> Self: """ Construct a LightStructureEnvironments object from a strategy and a StructureEnvironments object. @@ -1492,10 +1497,10 @@ def from_structure_environments(cls, strategy, structure_environments, valences= """ structure = structure_environments.structure strategy.set_structure_environments(structure_environments=structure_environments) - coordination_environments = [None] * len(structure) - neighbors_sets = [None] * len(structure) - _all_nbs_sites = [] - all_nbs_sites = [] + coordination_environments: list = [None] * len(structure) + neighbors_sets: list = [None] * len(structure) + _all_nbs_sites: list = [] + all_nbs_sites: list = [] if valences is None: valences = structure_environments.valences if valences_origin is None: @@ -1510,7 +1515,7 @@ def from_structure_environments(cls, strategy, structure_environments, valences= coordination_environments[idx] = [] neighbors_sets[idx] = [] site_ces = [] - site_nbs_sets = [] + site_nbs_sets: list = [] for ce_and_neighbors in site_ces_and_nbs_list: _all_nbs_sites_indices = [] # Coordination environment @@ -1556,6 +1561,7 @@ def from_structure_environments(cls, strategy, structure_environments, valences= site_nbs_sets.append(nb_set) coordination_environments[idx] = site_ces neighbors_sets[idx] = site_nbs_sets + return cls( strategy=strategy, coordination_environments=coordination_environments, @@ -2015,7 +2021,7 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct) -> LightStructureEnvironments: + def from_dict(cls, dct) -> Self: """ Reconstructs the LightStructureEnvironments object from a dict representation of the LightStructureEnvironments created using the as_dict method. @@ -2026,11 +2032,10 @@ def from_dict(cls, dct) -> LightStructureEnvironments: Returns: LightStructureEnvironments object. """ - dec = MontyDecoder() - structure = dec.process_decoded(dct["structure"]) + structure = MontyDecoder().process_decoded(dct["structure"]) all_nbs_sites = [] for nb_site in dct["all_nbs_sites"]: - periodic_site = dec.process_decoded(nb_site["site"]) + periodic_site = MontyDecoder().process_decoded(nb_site["site"]) site = PeriodicNeighbor( species=periodic_site.species, coords=periodic_site.frac_coords, @@ -2056,7 +2061,7 @@ def from_dict(cls, dct) -> LightStructureEnvironments: for site_nb_sets in dct["neighbors_sets"] ] return cls( - strategy=dec.process_decoded(dct["strategy"]), + strategy=MontyDecoder().process_decoded(dct["strategy"]), coordination_environments=dct["coordination_environments"], all_nbs_sites=all_nbs_sites, neighbors_sets=neighbors_sets, @@ -2335,7 +2340,7 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct: dict) -> ChemicalEnvironments: + def from_dict(cls, dct: dict) -> Self: """ Reconstructs the ChemicalEnvironments object from a dict representation of the ChemicalEnvironments created using the as_dict method. diff --git a/pymatgen/analysis/chemenv/coordination_environments/voronoi.py b/pymatgen/analysis/chemenv/coordination_environments/voronoi.py index 85fce37f62c..b5bdeed8764 100644 --- a/pymatgen/analysis/chemenv/coordination_environments/voronoi.py +++ b/pymatgen/analysis/chemenv/coordination_environments/voronoi.py @@ -4,6 +4,7 @@ import logging import time +from typing import TYPE_CHECKING import matplotlib.pyplot as plt import numpy as np @@ -20,6 +21,9 @@ from pymatgen.core.sites import PeriodicSite from pymatgen.core.structure import Structure +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "David Waroquiers" __copyright__ = "Copyright 2012, The Materials Project" __credits__ = "Geoffroy Hautier" @@ -29,7 +33,7 @@ __date__ = "Feb 20, 2016" -def from_bson_voronoi_list2(bson_nb_voro_list2, structure): +def from_bson_voronoi_list2(bson_nb_voro_list2: list[PeriodicSite], structure: Structure): """ Returns the voronoi_list needed for the VoronoiContainer object from a bson-encoded voronoi_list. @@ -41,21 +45,22 @@ def from_bson_voronoi_list2(bson_nb_voro_list2, structure): The voronoi_list needed for the VoronoiContainer (with PeriodicSites as keys of the dictionary - not allowed in the BSON format). """ - voronoi_list = [None] * len(bson_nb_voro_list2) - for isite, voro in enumerate(bson_nb_voro_list2): - if voro is None or voro == "None": + voronoi_list: list[list[dict] | None] = [None] * len(bson_nb_voro_list2) + + for idx, voro in enumerate(bson_nb_voro_list2): + if voro in (None, "None"): continue - voronoi_list[isite] = [] + + voronoi_list[idx] = [] for psd, dct in voro: struct_site = structure[dct["index"]] - periodic_site = PeriodicSite( + dct["site"] = PeriodicSite( struct_site._species, struct_site.frac_coords + psd[1], struct_site._lattice, properties=struct_site.properties, ) - dct["site"] = periodic_site - voronoi_list[isite].append(dct) + voronoi_list[idx].append(dct) # type: ignore[union-attr] return voronoi_list @@ -141,8 +146,8 @@ def setup_voronoi_list(self, indices, voronoi_cutoff): logging.debug("Please consider increasing voronoi_distance_cutoff") t1 = time.process_time() logging.debug("Setting up Voronoi list :") - for jj, isite in enumerate(indices): - logging.debug(f" - Voronoi analysis for site #{isite} ({jj + 1}/{len(indices)})") + for jj, isite in enumerate(indices, start=1): + logging.debug(f" - Voronoi analysis for site #{isite} ({jj}/{len(indices)})") site = self.structure[isite] neighbors1 = [(site, 0.0, isite)] neighbors1.extend(struct_neighbors[isite]) @@ -938,7 +943,7 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """ Reconstructs the VoronoiContainer object from a dict representation of the VoronoiContainer created using the as_dict method. @@ -953,6 +958,7 @@ def from_dict(cls, dct): voronoi_list2 = from_bson_voronoi_list2(dct["bson_nb_voro_list2"], structure) maximum_distance_factor = dct.get("maximum_distance_factor") minimum_angle_factor = dct.get("minimum_angle_factor") + return cls( structure=structure, voronoi_list2=voronoi_list2, diff --git a/pymatgen/analysis/chemenv/utils/chemenv_config.py b/pymatgen/analysis/chemenv/utils/chemenv_config.py index bb641207c25..99eeb8d1a62 100644 --- a/pymatgen/analysis/chemenv/utils/chemenv_config.py +++ b/pymatgen/analysis/chemenv/utils/chemenv_config.py @@ -42,7 +42,10 @@ class ChemEnvConfig: ) def __init__(self, package_options=None): - """:param package_options:""" + """ + Args: + package_options: + """ if SETTINGS.get("PMG_MAPI_KEY"): self.materials_project_configuration = SETTINGS.get("PMG_MAPI_KEY") else: @@ -94,7 +97,7 @@ def setup_package_options(self): self.package_options = self.DEFAULT_PACKAGE_OPTIONS print("Choose between the following strategies : ") strategies = list(strategies_class_lookup) - for idx, strategy in enumerate(strategies, 1): + for idx, strategy in enumerate(strategies, start=1): print(f" <{idx}> : {strategy}") test = input(" ... ") self.package_options["default_strategy"] = { @@ -135,7 +138,9 @@ def package_options_description(self): def save(self, root_dir=None): """ Save the options. - :param root_dir: + + Args: + root_dir: """ if root_dir is None: home = expanduser("~") @@ -158,7 +163,9 @@ def save(self, root_dir=None): def auto_load(cls, root_dir=None): """ Autoload options. - :param root_dir: + + Args: + root_dir: """ if root_dir is None: home = expanduser("~") diff --git a/pymatgen/analysis/chemenv/utils/chemenv_errors.py b/pymatgen/analysis/chemenv/utils/chemenv_errors.py index 8e1cae20e2b..479d61454d0 100644 --- a/pymatgen/analysis/chemenv/utils/chemenv_errors.py +++ b/pymatgen/analysis/chemenv/utils/chemenv_errors.py @@ -16,9 +16,10 @@ class AbstractChemenvError(Exception): def __init__(self, cls, method, msg): """ - :param cls: - :param method: - :param msg: + Args: + cls: + method: + msg: """ self.cls = cls self.method = method @@ -32,7 +33,10 @@ class NeighborsNotComputedChemenvError(AbstractChemenvError): """Neighbors not computed error.""" def __init__(self, site): - """:param site:""" + """ + Args: + site: + """ self.site = site def __str__(self): @@ -43,7 +47,10 @@ class EquivalentSiteSearchError(AbstractChemenvError): """Equivalent site search error.""" def __init__(self, site): - """:param site:""" + """ + Args: + site: + """ self.site = site def __str__(self): @@ -54,7 +61,10 @@ class SolidAngleError(AbstractChemenvError): """Solid angle error.""" def __init__(self, cosinus): - """:param cosinus:""" + """ + Args: + cosinus: + """ self.cosinus = cosinus def __str__(self): @@ -66,9 +76,10 @@ class ChemenvError(Exception): def __init__(self, cls: str, method: str, msg: str): """ - :param cls: - :param method: - :param msg: + Args: + cls: + method: + msg: """ self.cls = cls self.method = method diff --git a/pymatgen/analysis/chemenv/utils/coordination_geometry_utils.py b/pymatgen/analysis/chemenv/utils/coordination_geometry_utils.py index 447b7bd29a1..04e55a3ebb4 100644 --- a/pymatgen/analysis/chemenv/utils/coordination_geometry_utils.py +++ b/pymatgen/analysis/chemenv/utils/coordination_geometry_utils.py @@ -15,6 +15,7 @@ if TYPE_CHECKING: from numpy.typing import ArrayLike + from typing_extensions import Self __author__ = "David Waroquiers" __copyright__ = "Copyright 2012, The Materials Project" @@ -28,7 +29,9 @@ def get_lower_and_upper_f(surface_calculation_options): """Get the lower and upper functions defining a surface in the distance-angle space of neighbors. - :param surface_calculation_options: Options for the surface. + Args: + surface_calculation_options: Options for the surface. + Returns: Dictionary containing the "lower" and "upper" functions for the surface. """ @@ -339,13 +342,13 @@ def rectangle_surface_intersection( xmax = min(x2, bounds_lower[1]) def diff(x): - flwx = f_lower(x) - fupx = f_upper(x) - minup = np.min([fupx, y2 * np.ones_like(fupx)], axis=0) - maxlw = np.max([flwx, y1 * np.ones_like(flwx)], axis=0) - zeros = np.zeros_like(fupx) - upper = np.where(y2 >= flwx, np.where(y1 <= fupx, minup, zeros), zeros) - lower = np.where(y1 <= fupx, np.where(y2 >= flwx, maxlw, zeros), zeros) + f_low_x = f_lower(x) + f_up_x = f_upper(x) + min_up = np.min([f_up_x, y2 * np.ones_like(f_up_x)], axis=0) + max_lw = np.max([f_low_x, y1 * np.ones_like(f_low_x)], axis=0) + zeros = np.zeros_like(f_up_x) + upper = np.where(y2 >= f_low_x, np.where(y1 <= f_up_x, min_up, zeros), zeros) + lower = np.where(y1 <= f_up_x, np.where(y2 >= f_low_x, max_lw, zeros), zeros) return upper - lower return quad(diff, xmin, xmax) @@ -356,16 +359,14 @@ def solid_angle(center, coords): Helper method to calculate the solid angle of a set of coords from the center. Args: - center: - Center to measure solid angle from. - coords: - List of coords to determine solid angle. + center: Center to measure solid angle from. + coords: List of coords to determine solid angle. Returns: The solid angle. """ - o = np.array(center) - r = [np.array(c) - o for c in coords] + origin = np.array(center) + r = [np.array(c) - origin for c in coords] r.append(r[0]) n = [np.cross(r[i + 1], r[i]) for i in range(len(r) - 1)] n.append(np.cross(r[1], r[0])) @@ -388,8 +389,10 @@ def solid_angle(center, coords): def vectorsToMatrix(aa, bb): """ Performs the vector multiplication of the elements of two vectors, constructing the 3x3 matrix. - :param aa: One vector of size 3 - :param bb: Another vector of size 3 + + Args: + aa: One vector of size 3 + bb: Another vector of size 3 Returns: A 3x3 matrix M composed of the products of the elements of aa and bb : M_ij = aa_i * bb_j. @@ -404,8 +407,9 @@ def vectorsToMatrix(aa, bb): def matrixTimesVector(MM, aa): """ - :param MM: A matrix of size 3x3 - :param aa: A vector of size 3 + Args: + MM: A matrix of size 3x3 + aa: A vector of size 3 Returns: A vector of size 3 which is the product of the matrix by the vector @@ -419,8 +423,10 @@ def matrixTimesVector(MM, aa): def rotateCoords(coords, R): """ Rotate the list of points using rotation matrix R - :param coords: List of points to be rotated - :param R: Rotation matrix + + Args: + coords: List of points to be rotated + R: Rotation matrix Returns: List of rotated points. @@ -435,8 +441,10 @@ def rotateCoords(coords, R): def rotateCoordsOpt(coords, R): """ Rotate the list of points using rotation matrix R - :param coords: List of points to be rotated - :param R: Rotation matrix + + Args: + coords: List of points to be rotated + R: Rotation matrix Returns: List of rotated points. @@ -448,10 +456,12 @@ def changebasis(uu, vv, nn, pps): """ For a list of points given in standard coordinates (in terms of e1, e2 and e3), returns the same list expressed in the basis (uu, vv, nn), which is supposed to be orthonormal. - :param uu: First vector of the basis - :param vv: Second vector of the basis - :param nn: Third vector of the basis - :param pps: List of points in basis (e1, e2, e3) + + Args: + uu: First vector of the basis + vv: Second vector of the basis + nn: Third vector of the basis + pps: List of points in basis (e1, e2, e3) Returns: List of points in basis (uu, vv, nn). """ @@ -474,10 +484,13 @@ def collinear(p1, p2, p3=None, tolerance=0.25): triangle is less than (tolerance x largest_triangle), then the three points are considered collinear. The largest_triangle is defined as the right triangle whose legs are the two smallest distances between the three points ie, its area is : 0.5 x (min(|p2-p1|,|p3-p1|,|p3-p2|) x second_min(|p2-p1|,|p3-p1|,|p3-p2|)) - :param p1: First point - :param p2: Second point - :param p3: Third point (origin [0.0, 0.0, 0.0 if not given]) - :param tolerance: Area tolerance for the collinearity test (0.25 gives about 0.125 deviation from the line) + + Args: + p1: First point + p2: Second point + p3: Third point (origin [0.0, 0.0, 0.0 if not given]) + tolerance: Area tolerance for the collinearity test (0.25 gives about 0.125 deviation from the line) + Returns: bool: True if the three points are considered as collinear within the given tolerance. """ @@ -494,7 +507,9 @@ def collinear(p1, p2, p3=None, tolerance=0.25): def anticlockwise_sort(pps): """ Sort a list of 2D points in anticlockwise order - :param pps: List of points to be sorted + + Args: + pps: List of points to be sorted Returns: Sorted list of points. @@ -512,7 +527,9 @@ def anticlockwise_sort(pps): def anticlockwise_sort_indices(pps): """ Returns the indices that would sort a list of 2D points in anticlockwise order - :param pps: List of points to be sorted + + Args: + pps: List of points to be sorted Returns: Indices of the sorted list of points. @@ -526,7 +543,9 @@ def anticlockwise_sort_indices(pps): def sort_separation(separation): """Sort a separation. - :param separation: Initial separation. + Args: + separation: Initial separation. + Returns: Sorted list of separation. """ @@ -538,7 +557,8 @@ def sort_separation(separation): def sort_separation_tuple(separation): """Sort a separation. - :param separation: Initial separation + Args: + separation: Initial separation Returns: Sorted tuple of separation @@ -559,8 +579,10 @@ def sort_separation_tuple(separation): def separation_in_list(separation_indices, separation_indices_list): """ Checks if the separation indices of a plane are already in the list - :param separation_indices: list of separation indices (three arrays of integers) - :param separation_indices_list: list of the list of separation indices to be compared to + + Args: + separation_indices: list of separation indices (three arrays of integers) + separation_indices_list: list of the list of separation indices to be compared to Returns: bool: True if the separation indices are already in the list. @@ -575,9 +597,11 @@ def separation_in_list(separation_indices, separation_indices_list): def is_anion_cation_bond(valences, ii, jj) -> bool: """ Checks if two given sites are an anion and a cation. - :param valences: list of site valences - :param ii: index of a site - :param jj: index of another site + + Args: + valences: list of site valences + ii: index of a site + jj: index of another site Returns: bool: True if one site is an anion and the other is a cation (based on valences). @@ -619,7 +643,9 @@ class Plane: def __init__(self, coefficients, p1=None, p2=None, p3=None): """ Initializes a plane from the 4 coefficients a, b, c and d of ax + by + cz + d = 0 - :param coefficients: abcd coefficients of the plane. + + Args: + coefficients: abcd coefficients of the plane. """ # Initializes the normal vector self.normal_vector = np.array([coefficients[0], coefficients[1], coefficients[2]], float) @@ -652,8 +678,9 @@ def __init__(self, coefficients, p1=None, p2=None, p3=None): def init_3points(self, non_zeros, zeros): """Initialize three random points on this plane. - :param non_zeros: Indices of plane coefficients ([a, b, c]) that are not zero. - :param zeros: Indices of plane coefficients ([a, b, c]) that are equal to zero. + Args: + non_zeros: Indices of plane coefficients ([a, b, c]) that are not zero. + zeros: Indices of plane coefficients ([a, b, c]) that are equal to zero. """ if len(non_zeros) == 3: self.p1 = np.array([-self.d / self.a, 0.0, 0.0], float) @@ -690,8 +717,10 @@ def __str__(self): def is_in_plane(self, pp, dist_tolerance) -> bool: """ Determines if point pp is in the plane within the tolerance dist_tolerance - :param pp: point to be tested - :param dist_tolerance: tolerance on the distance to the plane within which point pp is considered in the plane + + Args: + pp: point to be tested + dist_tolerance: tolerance on the distance to the plane within which point pp is considered in the plane Returns: bool: True if pp is in the plane. @@ -701,7 +730,9 @@ def is_in_plane(self, pp, dist_tolerance) -> bool: def is_same_plane_as(self, plane) -> bool: """ Checks whether the plane is identical to another Plane "plane" - :param plane: Plane to be compared to + + Args: + plane: Plane to be compared to Returns: bool: True if the two facets are identical. @@ -711,7 +742,9 @@ def is_same_plane_as(self, plane) -> bool: def is_in_list(self, plane_list) -> bool: """ Checks whether the plane is identical to one of the Planes in the plane_list list of Planes - :param plane_list: List of Planes to be compared to + + Args: + plane_list: List of Planes to be compared to Returns: bool: True if the plane is in the list. @@ -723,9 +756,11 @@ def indices_separate(self, points, dist_tolerance): Returns three lists containing the indices of the points lying on one side of the plane, on the plane and on the other side of the plane. The dist_tolerance parameter controls the tolerance to which a point is considered to lie on the plane or not (distance to the plane) - :param points: list of points - :param dist_tolerance: tolerance to which a point is considered to lie on the plane - or not (distance to the plane) + + Args: + points: list of points + dist_tolerance: tolerance to which a point is considered to lie on the plane + or not (distance to the plane) Returns: The lists of indices of the points on one side of the plane, on the plane and @@ -746,7 +781,9 @@ def indices_separate(self, points, dist_tolerance): def distance_to_point(self, point): """ Computes the absolute distance from the plane to the point - :param point: Point for which distance is computed + + Args: + point: Point for which distance is computed Returns: Distance between the plane and the point. @@ -757,7 +794,9 @@ def distances(self, points): """ Computes the distances from the plane to each of the points. Positive distances are on the side of the normal of the plane while negative distances are on the other side - :param points: Points for which distances are computed + + Args: + points: Points for which distances are computed Returns: Distances from the plane to the points (positive values on the side of the normal to the plane, @@ -770,8 +809,10 @@ def distances_indices_sorted(self, points, sign=False): Computes the distances from the plane to each of the points. Positive distances are on the side of the normal of the plane while negative distances are on the other side. Indices sorting the points from closest to furthest is also computed. - :param points: Points for which distances are computed - :param sign: Whether to add sign information in the indices sorting the points distances + + Args: + points: Points for which distances are computed + sign: Whether to add sign information in the indices sorting the points distances Returns: Distances from the plane to the points (positive values on the side of the normal to the plane, @@ -791,11 +832,13 @@ def distances_indices_groups(self, points, delta=None, delta_factor=0.05, sign=F to furthest is also computed. Grouped indices are also given, for which indices of the distances that are separated by less than delta are grouped together. The delta parameter is either set explicitly or taken as a fraction (using the delta_factor parameter) of the maximal point distance. - :param points: Points for which distances are computed - :param delta: Distance interval for which two points are considered in the same group. - :param delta_factor: If delta is None, the distance interval is taken as delta_factor times the maximal + + Args: + points: Points for which distances are computed + delta: Distance interval for which two points are considered in the same group. + delta_factor: If delta is None, the distance interval is taken as delta_factor times the maximal point distance. - :param sign: Whether to add sign information in the indices sorting the points distances + sign: Whether to add sign information in the indices sorting the points distances Returns: Distances from the plane to the points (positive values on the side of the normal to the plane, @@ -819,7 +862,9 @@ def distances_indices_groups(self, points, delta=None, delta_factor=0.05, sign=F def projectionpoints(self, pps): """ Projects each points in the point list pps on plane and returns the list of projected points - :param pps: List of points to project on plane + + Args: + pps: List of points to project on plane Returns: List of projected point on plane. @@ -845,7 +890,9 @@ def project_and_to2dim_ordered_indices(self, pps, plane_center="mean"): """ Projects each points in the point list pps on plane and returns the indices that would sort the list of projected points in anticlockwise order - :param pps: List of points to project on plane + + Args: + pps: List of points to project on plane Returns: List of indices that would sort the list of projected points. @@ -856,7 +903,9 @@ def project_and_to2dim_ordered_indices(self, pps, plane_center="mean"): def project_and_to2dim(self, pps, plane_center): """ Projects the list of points pps to the plane and changes the basis from 3D to the 2D basis of the plane - :param pps: List of points to be projected + + Args: + pps: List of points to be projected Returns: :raise: @@ -883,8 +932,9 @@ def project_and_to2dim(self, pps, plane_center): def fit_error(self, points, fit="least_square_distance"): """Evaluate the error for a list of points with respect to this plane. - :param points: List of points. - :param fit: Type of fit error. + Args: + points: List of points. + fit: Type of fit error. Returns: Error for a list of points with respect to this plane. @@ -898,7 +948,8 @@ def fit_error(self, points, fit="least_square_distance"): def fit_least_square_distance_error(self, points): """Evaluate the sum of squared distances error for a list of points with respect to this plane. - :param points: List of points. + Args: + points: List of points. Returns: Sum of squared distances error for a list of points with respect to this plane. @@ -908,7 +959,8 @@ def fit_least_square_distance_error(self, points): def fit_maximum_distance_error(self, points): """Evaluate the max distance error for a list of points with respect to this plane. - :param points: List of points. + Args: + points: List of points. Returns: Max distance error for a list of points with respect to this plane. @@ -969,11 +1021,12 @@ def crosses_origin(self): return self._crosses_origin @classmethod - def from_2points_and_origin(cls, p1, p2): + def from_2points_and_origin(cls, p1, p2) -> Self: """Initializes plane from two points and the origin. - :param p1: First point. - :param p2: Second point. + Args: + p1: First point. + p2: Second point. Returns: Plane. @@ -981,12 +1034,13 @@ def from_2points_and_origin(cls, p1, p2): return cls.from_3points(p1, p2, np.zeros(3)) @classmethod - def from_3points(cls, p1, p2, p3): + def from_3points(cls, p1, p2, p3) -> Self: """Initializes plane from three points. - :param p1: First point. - :param p2: Second point. - :param p3: Third point. + Args: + p1: First point. + p2: Second point. + p3: Third point. Returns: Plane. @@ -1001,13 +1055,14 @@ def from_3points(cls, p1, p2, p3): return cls(coefficients, p1=p1, p2=p2, p3=p3) @classmethod - def from_npoints(cls, points, best_fit="least_square_distance"): + def from_npoints(cls, points, best_fit="least_square_distance") -> Self: """Initializes plane from a list of points. If the number of points is larger than 3, will use a least square fitting or max distance fitting. - :param points: List of points. - :param best_fit: Type of fitting procedure for more than 3 points. + Args: + points: List of points. + best_fit: Type of fitting procedure for more than 3 points. Returns: Plane @@ -1020,13 +1075,15 @@ def from_npoints(cls, points, best_fit="least_square_distance"): return cls.from_npoints_least_square_distance(points) if best_fit == "maximum_distance": return cls.from_npoints_maximum_distance(points) - return None + + raise ValueError("Cannot initialize Plane.") @classmethod - def from_npoints_least_square_distance(cls, points): + def from_npoints_least_square_distance(cls, points) -> Self: """Initializes plane from a list of points using a least square fitting procedure. - :param points: List of points. + Args: + points: List of points. Returns: Plane. @@ -1048,14 +1105,15 @@ def from_npoints_least_square_distance(cls, points): return cls(coefficients) @classmethod - def perpendicular_bisector(cls, p1, p2): + def perpendicular_bisector(cls, p1, p2) -> Self: """Initialize a plane from the perpendicular bisector of two points. The perpendicular bisector of two points is the plane perpendicular to the vector joining these two points and passing through the middle of the segment joining the two points. - :param p1: First point. - :param p2: Second point. + Args: + p1: First point. + p2: Second point. Returns: Plane. @@ -1066,10 +1124,11 @@ def perpendicular_bisector(cls, p1, p2): return cls(np.array([normal_vector[0], normal_vector[1], normal_vector[2], dd], float)) @classmethod - def from_npoints_maximum_distance(cls, points): + def from_npoints_maximum_distance(cls, points) -> Self: """Initializes plane from a list of points using a max distance fitting procedure. - :param points: List of points. + Args: + points: List of points. Returns: Plane. @@ -1077,8 +1136,8 @@ def from_npoints_maximum_distance(cls, points): convex_hull = ConvexHull(points) heights = [] ipoints_heights = [] - for isimplex, _simplex in enumerate(convex_hull.simplices): - cc = convex_hull.equations[isimplex] + for idx, _simplex in enumerate(convex_hull.simplices): + cc = convex_hull.equations[idx] plane = Plane.from_coefficients(cc[0], cc[1], cc[2], cc[3]) distances = [plane.distance_to_point(pp) for pp in points] ipoint_height = np.argmax(distances) @@ -1095,13 +1154,14 @@ def from_npoints_maximum_distance(cls, points): return cls(np.array([normal_vector[0], normal_vector[1], normal_vector[2], dd], float)) @classmethod - def from_coefficients(cls, a, b, c, d): + def from_coefficients(cls, a, b, c, d) -> Self: """Initialize plane from its coefficients. - :param a: a coefficient of the plane. - :param b: b coefficient of the plane. - :param c: c coefficient of the plane. - :param d: d coefficient of the plane. + Args: + a: a coefficient of the plane. + b: b coefficient of the plane. + c: c coefficient of the plane. + d: d coefficient of the plane. Returns: Plane. diff --git a/pymatgen/analysis/chemenv/utils/defs_utils.py b/pymatgen/analysis/chemenv/utils/defs_utils.py index 390ad9170b8..01bbc9e8450 100644 --- a/pymatgen/analysis/chemenv/utils/defs_utils.py +++ b/pymatgen/analysis/chemenv/utils/defs_utils.py @@ -70,9 +70,10 @@ class AdditionalConditions: def check_condition(self, condition, structure: Structure, parameters): """ - :param condition: - :param structure: - :param parameters: + Args: + condition: + structure: + parameters: """ if condition == self.NONE: return True diff --git a/pymatgen/analysis/chemenv/utils/func_utils.py b/pymatgen/analysis/chemenv/utils/func_utils.py index 92db68fa7e7..d6b166b6b77 100644 --- a/pymatgen/analysis/chemenv/utils/func_utils.py +++ b/pymatgen/analysis/chemenv/utils/func_utils.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import ClassVar +from typing import TYPE_CHECKING, ClassVar import numpy as np @@ -14,6 +14,9 @@ smoothstep, ) +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "David Waroquiers" __copyright__ = "Copyright 2012, The Materials Project" __credits__ = "Geoffroy Hautier" @@ -31,8 +34,9 @@ class AbstractRatioFunction: def __init__(self, function, options_dict=None): """Constructor for AbstractRatioFunction. - :param function: Ration function name. - :param options_dict: Dictionary containing the parameters for the ratio function. + Args: + function: Ration function name. + options_dict: Dictionary containing the parameters for the ratio function. """ if function not in self.ALLOWED_FUNCTIONS: raise ValueError(f"{function=!r} is not allowed in RatioFunction of type {type(self).__name__}") @@ -43,7 +47,8 @@ def __init__(self, function, options_dict=None): def setup_parameters(self, options_dict): """Set up the parameters for this ratio function. - :param options_dict: Dictionary containing the parameters for the ratio function. + Args: + options_dict: Dictionary containing the parameters for the ratio function. """ function_options = self.ALLOWED_FUNCTIONS[self.function] if len(function_options) > 0: @@ -88,7 +93,8 @@ def setup_parameters(self, options_dict): def evaluate(self, value): """Evaluate the ratio function for the given value. - :param value: Value for which ratio function has to be evaluated. + Args: + value: Value for which ratio function has to be evaluated. Returns: Ratio function corresponding to the value. @@ -96,13 +102,11 @@ def evaluate(self, value): return self.eval(value) @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """Construct ratio function from dict. - :param dct: Dict representation of the ratio function - - Returns: - Ratio function object. + Args: + dct (dict): Dict representation of the ratio function """ return cls(function=dct["function"], options_dict=dct["options"]) @@ -126,7 +130,8 @@ def power2_decreasing_exp(self, vals): The values (i.e. "x"), are scaled to the "max" parameter. The "a" constant correspond to the "alpha" parameter. - :param vals: Values for which the ratio function has to be evaluated. + Args: + vals: Values for which the ratio function has to be evaluated. Returns: Result of the ratio function applied to the values. @@ -138,7 +143,8 @@ def smootherstep(self, vals): The values (i.e. "x"), are scaled between the "lower" and "upper" parameters. - :param vals: Values for which the ratio function has to be evaluated. + Args: + vals: Values for which the ratio function has to be evaluated. Returns: Result of the ratio function applied to the values. @@ -150,7 +156,8 @@ def smoothstep(self, vals): The values (i.e. "x"), are scaled between the "lower" and "upper" parameters. - :param vals: Values for which the ratio function has to be evaluated. + Args: + vals: Values for which the ratio function has to be evaluated. Returns: Result of the ratio function applied to the values. @@ -162,7 +169,8 @@ def inverse_smootherstep(self, vals): The values (i.e. "x"), are scaled between the "lower" and "upper" parameters. - :param vals: Values for which the ratio function has to be evaluated. + Args: + vals: Values for which the ratio function has to be evaluated. Returns: Result of the ratio function applied to the values. @@ -174,7 +182,8 @@ def inverse_smoothstep(self, vals): The values (i.e. "x"), are scaled between the "lower" and "upper" parameters. - :param vals: Values for which the ratio function has to be evaluated. + Args: + vals: Values for which the ratio function has to be evaluated. Returns: Result of the ratio function applied to the values. @@ -186,7 +195,8 @@ def power2_inverse_decreasing(self, vals): The values (i.e. "x"), are scaled to the "max" parameter. - :param vals: Values for which the ratio function has to be evaluated. + Args: + vals: Values for which the ratio function has to be evaluated. Returns: Result of the ratio function applied to the values. @@ -198,7 +208,8 @@ def power2_inverse_power2_decreasing(self, vals): The values (i.e. "x"), are scaled to the "max" parameter. - :param vals: Values for which the ratio function has to be evaluated. + Args: + vals: Values for which the ratio function has to be evaluated. Returns: Result of the ratio function applied to the values. @@ -228,7 +239,8 @@ def power2_decreasing_exp(self, vals): The CSM values (i.e. "x"), are scaled to the "max_csm" parameter. The "a" constant correspond to the "alpha" parameter. - :param vals: CSM values for which the ratio function has to be evaluated. + Args: + vals: CSM values for which the ratio function has to be evaluated. Returns: Result of the ratio function applied to the CSM values. @@ -240,7 +252,8 @@ def smootherstep(self, vals): The CSM values (i.e. "x"), are scaled between the "lower_csm" and "upper_csm" parameters. - :param vals: CSM values for which the ratio function has to be evaluated. + Args: + vals: CSM values for which the ratio function has to be evaluated. Returns: Result of the ratio function applied to the CSM values. @@ -256,7 +269,8 @@ def smoothstep(self, vals): The CSM values (i.e. "x"), are scaled between the "lower_csm" and "upper_csm" parameters. - :param vals: CSM values for which the ratio function has to be evaluated. + Args: + vals: CSM values for which the ratio function has to be evaluated. Returns: Result of the ratio function applied to the CSM values. @@ -270,7 +284,8 @@ def smoothstep(self, vals): def fractions(self, data): """Get the fractions from the CSM ratio function applied to the data. - :param data: List of CSM values to estimate fractions. + Args: + data: List of CSM values to estimate fractions. Returns: Corresponding fractions for each CSM. @@ -285,7 +300,8 @@ def fractions(self, data): def mean_estimator(self, data): """Get the weighted CSM using this CSM ratio function applied to the data. - :param data: List of CSM values to estimate the weighted CSM. + Args: + data: List of CSM values to estimate the weighted CSM. Returns: Weighted CSM from this ratio function. @@ -323,7 +339,8 @@ def power2_inverse_decreasing(self, vals): The CSM values (i.e. "x"), are scaled to the "max_csm" parameter. The "a" constant correspond to the "alpha" parameter. - :param vals: CSM values for which the ratio function has to be evaluated. + Args: + vals: CSM values for which the ratio function has to be evaluated. Returns: Result of the ratio function applied to the CSM values. @@ -336,7 +353,8 @@ def power2_inverse_power2_decreasing(self, vals): The CSM values (i.e. "x"), are scaled to the "max_csm" parameter. The "a" constant correspond to the "alpha" parameter. - :param vals: CSM values for which the ratio function has to be evaluated. + Args: + vals: CSM values for which the ratio function has to be evaluated. Returns: Result of the ratio function applied to the CSM values. @@ -346,7 +364,8 @@ def power2_inverse_power2_decreasing(self, vals): def fractions(self, data): """Get the fractions from the CSM ratio function applied to the data. - :param data: List of CSM values to estimate fractions. + Args: + data: List of CSM values to estimate fractions. Returns: Corresponding fractions for each CSM. @@ -370,7 +389,8 @@ def fractions(self, data): def mean_estimator(self, data): """Get the weighted CSM using this CSM ratio function applied to the data. - :param data: List of CSM values to estimate the weighted CSM. + Args: + data: List of CSM values to estimate the weighted CSM. Returns: Weighted CSM from this ratio function. @@ -406,7 +426,8 @@ def smootherstep(self, vals): The DeltaCSM values (i.e. "x"), are scaled between the "delta_csm_min" and "delta_csm_max" parameters. - :param vals: DeltaCSM values for which the ratio function has to be evaluated. + Args: + vals: DeltaCSM values for which the ratio function has to be evaluated. Returns: Result of the ratio function applied to the DeltaCSM values. diff --git a/pymatgen/analysis/chemenv/utils/graph_utils.py b/pymatgen/analysis/chemenv/utils/graph_utils.py index bb271b25015..42ccb0d348d 100644 --- a/pymatgen/analysis/chemenv/utils/graph_utils.py +++ b/pymatgen/analysis/chemenv/utils/graph_utils.py @@ -4,20 +4,26 @@ import itertools import operator +from typing import TYPE_CHECKING import networkx as nx import numpy as np from monty.json import MSONable +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "waroquiers" def get_delta(node1, node2, edge_data): """ Get the delta. - :param node1: - :param node2: - :param edge_data: + + Args: + node1: + node2: + edge_data: """ if node1.isite == edge_data["start"] and node2.isite == edge_data["end"]: return np.array(edge_data["delta"], dtype=int) @@ -29,11 +35,13 @@ def get_delta(node1, node2, edge_data): def get_all_simple_paths_edges(graph, source, target, cutoff=None, data=True): """ Get all the simple path and edges. - :param graph: - :param source: - :param target: - :param cutoff: - :param data: + + Args: + graph: + source: + target: + cutoff: + data: """ edge_paths = [] if not graph.is_multigraph(): @@ -136,9 +144,10 @@ class SimpleGraphCycle(MSONable): def __init__(self, nodes, validate=True, ordered=None): """ - :param nodes: - :param validate: - :param ordered: + Args: + nodes: + validate: + ordered: """ self.nodes = tuple(nodes) if validate: @@ -181,7 +190,8 @@ def _is_valid(self, check_strict_ordering=False): def validate(self, check_strict_ordering=False): """ - :param check_strict_ordering: + Args: + check_strict_ordering: """ is_valid, msg = self._is_valid(check_strict_ordering=check_strict_ordering) if not is_valid: @@ -249,7 +259,7 @@ def __eq__(self, other: object) -> bool: return self.nodes == other.nodes @classmethod - def from_edges(cls, edges, edges_are_ordered=True): + def from_edges(cls, edges, edges_are_ordered: bool = True) -> Self: """Constructs SimpleGraphCycle from a list edges. By default, the edges list is supposed to be ordered as it will be @@ -258,7 +268,7 @@ def from_edges(cls, edges, edges_are_ordered=True): order in the list. """ if edges_are_ordered: - nodes = [e[0] for e in edges] + nodes = [edge[0] for edge in edges] if not all(e1e2[0][1] == e1e2[1][0] for e1e2 in zip(edges, edges[1:])) or edges[-1][1] != edges[0][0]: raise ValueError("Could not construct a cycle from edges.") else: @@ -282,7 +292,7 @@ def from_edges(cls, edges, edges_are_ordered=True): nodes.pop() return cls(nodes) - def as_dict(self): + def as_dict(self) -> dict: """MSONable dict""" dct = MSONable.as_dict(self) # Transforming tuple object to a list to allow BSON and MongoDB @@ -290,13 +300,15 @@ def as_dict(self): return dct @classmethod - def from_dict(cls, d, validate=False): + def from_dict(cls, dct: dict, validate: bool = False) -> Self: """ Serialize from dict. - :param d: - :param validate: + + Args: + dct (dict): Dict representation. + validate: If True, will validate the cycle. """ - return cls(nodes=d["nodes"], validate=validate, ordered=d["ordered"]) + return cls(nodes=dct["nodes"], validate=validate, ordered=dct["ordered"]) class MultiGraphCycle(MSONable): @@ -314,10 +326,11 @@ class MultiGraphCycle(MSONable): def __init__(self, nodes, edge_indices, validate=True, ordered=None): """ - :param nodes: - :param edge_indices: - :param validate: - :param ordered: + Args: + nodes: + edge_indices: + validate: + ordered: """ self.nodes = tuple(nodes) self.edge_indices = tuple(edge_indices) @@ -366,13 +379,14 @@ def _is_valid(self, check_strict_ordering=False): def validate(self, check_strict_ordering=False): """ - :param check_strict_ordering: + Args: + check_strict_ordering: """ is_valid, msg = self._is_valid(check_strict_ordering=check_strict_ordering) if not is_valid: raise ValueError(f"MultiGraphCycle is not valid : {msg}") - def order(self, raise_on_fail=True): + def order(self, raise_on_fail: bool = True): """Orders the SimpleGraphCycle. The ordering is performed such that the first node is the "lowest" one @@ -380,7 +394,8 @@ def order(self, raise_on_fail=True): first node. If raise_on_fail is set to True a RuntimeError will be raised if the ordering fails. - :param raise_on_fail: If set to True, will raise a RuntimeError if the ordering fails. + Args: + raise_on_fail: If set to True, will raise a RuntimeError if the ordering fails. """ # always validate the cycle if it needs to be ordered # also validates that the nodes can be strictly ordered @@ -453,7 +468,9 @@ def __eq__(self, other: object) -> bool: def get_all_elementary_cycles(graph): """ - :param graph: + + Args: + graph: """ if not isinstance(graph, nx.Graph): raise TypeError("graph should be a networkx Graph object.") @@ -473,8 +490,8 @@ def get_all_elementary_cycles(graph): edge_idx += 1 cycles_matrix = np.zeros(shape=(len(cycle_basis), edge_idx), dtype=bool) for icycle, cycle in enumerate(cycle_basis): - for in1, n1 in enumerate(cycle): - n2 = cycle[(in1 + 1) % len(cycle)] + for in1, n1 in enumerate(cycle, start=1): + n2 = cycle[(in1) % len(cycle)] iedge = all_edges_dict[(n1, n2)] cycles_matrix[icycle, iedge] = True diff --git a/pymatgen/analysis/chemenv/utils/math_utils.py b/pymatgen/analysis/chemenv/utils/math_utils.py index ebe6112f281..1f0175a244a 100644 --- a/pymatgen/analysis/chemenv/utils/math_utils.py +++ b/pymatgen/analysis/chemenv/utils/math_utils.py @@ -44,8 +44,12 @@ def _cartesian_product(lists): def prime_factors(n: int) -> list[int]: """Lists prime factors of a given natural integer, from greatest to smallest - :param n: Natural integer - :rtype : list of all prime factors of the given natural n. + + Args: + n: Natural integer + + Returns: + list of all prime factors of the given natural n. """ idx = 2 while idx <= sqrt(n): @@ -57,13 +61,15 @@ def prime_factors(n: int) -> list[int]: return [n] # n is prime -def _factor_generator(n): +def _factor_generator(n: int) -> dict[int, int]: """ From a given natural integer, returns the prime factors and their multiplicity - :param n: Natural integer + + Args: + n: Natural integer """ p = prime_factors(n) - factors = {} + factors: dict[int, int] = {} for p1 in p: try: factors[p1] += 1 @@ -75,7 +81,9 @@ def _factor_generator(n): def divisors(n): """ From a given natural integer, returns the list of divisors in ascending order - :param n: Natural integer + + Args: + n: Natural integer Returns: List of divisors of n in ascending order. @@ -92,9 +100,11 @@ def divisors(n): def get_center_of_arc(p1, p2, radius): """ - :param p1: - :param p2: - :param radius: + + Args: + p1: + p2: + radius: """ dx = p2[0] - p1[0] dy = p2[1] - p1[1] @@ -110,7 +120,9 @@ def get_center_of_arc(p1, p2, radius): def get_linearly_independent_vectors(vectors_list): """ - :param vectors_list: + + Args: + vectors_list: """ independent_vectors_list = [] for vector in vectors_list: @@ -132,11 +144,13 @@ def get_linearly_independent_vectors(vectors_list): def scale_and_clamp(xx, edge0, edge1, clamp0, clamp1): """ - :param xx: - :param edge0: - :param edge1: - :param clamp0: - :param clamp1: + + Args: + xx: + edge0: + edge1: + clamp0: + clamp1: """ return np.clip((xx - edge0) / (edge1 - edge0), clamp0, clamp1) @@ -144,9 +158,11 @@ def scale_and_clamp(xx, edge0, edge1, clamp0, clamp1): # Step function based on the cumulative distribution function of the normal law def normal_cdf_step(xx, mean, scale): """ - :param xx: - :param mean: - :param scale: + + Args: + xx: + mean: + scale: """ return 0.5 * (1.0 + erf((xx - mean) / (np.sqrt(2.0) * scale))) @@ -160,9 +176,11 @@ def normal_cdf_step(xx, mean, scale): def smoothstep(xx, edges=None, inverse=False): """ - :param xx: - :param edges: - :param inverse: + + Args: + xx: + edges: + inverse: """ if edges is None: xx_clipped = np.clip(xx, 0.0, 1.0) @@ -175,9 +193,11 @@ def smoothstep(xx, edges=None, inverse=False): def smootherstep(xx, edges=None, inverse=False): """ - :param xx: - :param edges: - :param inverse: + + Args: + xx: + edges: + inverse: """ if edges is None: xx_clipped = np.clip(xx, 0.0, 1.0) @@ -190,9 +210,11 @@ def smootherstep(xx, edges=None, inverse=False): def cosinus_step(xx, edges=None, inverse=False): """ - :param xx: - :param edges: - :param inverse: + + Args: + xx: + edges: + inverse: """ if edges is None: xx_clipped = np.clip(xx, 0.0, 1.0) @@ -205,19 +227,23 @@ def cosinus_step(xx, edges=None, inverse=False): def power3_step(xx, edges=None, inverse=False): """ - :param xx: - :param edges: - :param inverse: + + Args: + xx: + edges: + inverse: """ return smoothstep(xx, edges=edges, inverse=inverse) def powern_parts_step(xx, edges=None, inverse=False, nn=2): """ - :param xx: - :param edges: - :param inverse: - :param nn: + + Args: + xx: + edges: + inverse: + nn: """ if edges is None: aa = np.power(0.5, 1.0 - nn) @@ -256,9 +282,11 @@ def powern_parts_step(xx, edges=None, inverse=False, nn=2): def powern_decreasing(xx, edges=None, nn=2): """ - :param xx: - :param edges: - :param nn: + + Args: + xx: + edges: + nn: """ if edges is None: aa = 1.0 / np.power(-1.0, nn) @@ -269,9 +297,11 @@ def powern_decreasing(xx, edges=None, nn=2): def power2_decreasing_exp(xx, edges=None, alpha=1.0): """ - :param xx: - :param edges: - :param alpha: + + Args: + xx: + edges: + alpha: """ if edges is None: aa = 1.0 / np.power(-1.0, 2) @@ -287,9 +317,11 @@ def power2_decreasing_exp(xx, edges=None, alpha=1.0): def power2_tangent_decreasing(xx, edges=None, prefactor=None): """ - :param xx: - :param edges: - :param prefactor: + + Args: + xx: + edges: + prefactor: """ if edges is None: aa = 1.0 / np.power(-1.0, 2) if prefactor is None else prefactor @@ -301,9 +333,11 @@ def power2_tangent_decreasing(xx, edges=None, prefactor=None): def power2_inverse_decreasing(xx, edges=None, prefactor=None): """ - :param xx: - :param edges: - :param prefactor: + + Args: + xx: + edges: + prefactor: """ if edges is None: aa = 1.0 / np.power(-1.0, 2) if prefactor is None else prefactor @@ -315,9 +349,11 @@ def power2_inverse_decreasing(xx, edges=None, prefactor=None): def power2_inverse_power2_decreasing(xx, edges=None, prefactor=None): """ - :param xx: - :param edges: - :param prefactor: + + Args: + xx: + edges: + prefactor: """ if edges is None: aa = 1.0 / np.power(-1.0, 2) if prefactor is None else prefactor @@ -332,10 +368,12 @@ def power2_inverse_power2_decreasing(xx, edges=None, prefactor=None): def power2_inverse_powern_decreasing(xx, edges=None, prefactor=None, powern=2.0): """ - :param xx: - :param edges: - :param prefactor: - :param powern: + + Args: + xx: + edges: + prefactor: + powern: """ if edges is None: aa = 1.0 / np.power(-1.0, 2) if prefactor is None else prefactor diff --git a/pymatgen/analysis/chemenv/utils/scripts_utils.py b/pymatgen/analysis/chemenv/utils/scripts_utils.py index 66394140ce9..5b17eb0df15 100644 --- a/pymatgen/analysis/chemenv/utils/scripts_utils.py +++ b/pymatgen/analysis/chemenv/utils/scripts_utils.py @@ -64,18 +64,19 @@ def draw_cg( """ Draw cg. - :param vis: - :param site: - :param neighbors: - :param cg: - :param perm: - :param perfect2local_map: - :param show_perfect: - :param csm_info: - :param symmetry_measure_type: - :param perfect_radius: - :param show_distorted: - :param faces_color_override: + Args: + site: + vis: + neighbors: + cg: + perm: + perfect2local_map: + show_perfect: + csm_info: + symmetry_measure_type: + perfect_radius: + show_distorted: + faces_color_override: """ if show_perfect: if csm_info is None: @@ -154,12 +155,13 @@ def draw_cg( def visualize(cg, zoom=None, vis=None, factor=1.0, view_index=True, faces_color_override=None): """ Visualizing a coordination geometry - :param cg: - :param zoom: - :param vis: - :param factor: - :param view_index: - :param faces_color_override: + Args: + cg: + zoom: + vis: + factor: + view_index: + faces_color_override: """ if vis is None and StructureVis is not None: vis = StructureVis(show_polyhedron=False, show_unit_cell=False) @@ -192,7 +194,8 @@ def compute_environments(chemenv_configuration): """ Compute the environments. - :param chemenv_configuration: + Args: + chemenv_configuration: """ string_sources = { "cif": {"string": "a Cif file", "regexp": r".*\.cif$"}, diff --git a/pymatgen/analysis/chempot_diagram.py b/pymatgen/analysis/chempot_diagram.py index 2b9613ebd8f..7a1469147aa 100644 --- a/pymatgen/analysis/chempot_diagram.py +++ b/pymatgen/analysis/chempot_diagram.py @@ -103,9 +103,7 @@ def __init__( renormalized_entries = [] for entry in entries: comp_dict = entry.composition.as_dict() - renormalization_energy = sum( - [comp_dict[el] * _el_refs[Element(el)].energy_per_atom for el in comp_dict] - ) + renormalization_energy = sum(comp_dict[el] * _el_refs[Element(el)].energy_per_atom for el in comp_dict) renormalized_entries.append(_renormalize_entry(entry, renormalization_energy / sum(comp_dict.values()))) entries = renormalized_entries @@ -115,6 +113,7 @@ def __init__( self.default_min_limit = default_min_limit self.elements = sorted({els for ent in self.entries for els in ent.elements}) self.dim = len(self.elements) + self.formal_chempots = formal_chempots self._min_entries, self._el_refs = self._get_min_entries_and_el_refs(self.entries) self._entry_dict = {ent.reduced_formula: ent for ent in self._min_entries} self._border_hyperplanes = self._get_border_hyperplanes() @@ -186,6 +185,7 @@ def get_plot( entries=entries, limits=self.limits, default_min_limit=self.default_min_limit, + formal_chempots=self.formal_chempots, ) fig = cpd.get_plot(elements=elems, label_stable=label_stable) # type: ignore else: @@ -666,7 +666,7 @@ def get_centroid_2d(vertices: np.ndarray) -> np.ndarray: polygon. Useful for calculating the location of an annotation on a chemical potential domain within a 3D chemical potential diagram. - **NOTE**: vertices must be ordered circumferentially! + NOTE vertices must be ordered circumferentially! Args: vertices: array of 2-d coordinates corresponding to a polygon, ordered diff --git a/pymatgen/analysis/diffraction/core.py b/pymatgen/analysis/diffraction/core.py index 24a5912cf8e..24f07587fe3 100644 --- a/pymatgen/analysis/diffraction/core.py +++ b/pymatgen/analysis/diffraction/core.py @@ -68,8 +68,9 @@ def get_pattern(self, structure: Structure, scaled=True, two_theta_range=(0, 90) sphere of radius 2 / wavelength. Returns: - (DiffractionPattern) + DiffractionPattern """ + raise NotImplementedError def get_plot( self, diff --git a/pymatgen/analysis/diffraction/neutron.py b/pymatgen/analysis/diffraction/neutron.py index cdd457b09b8..0057669f630 100644 --- a/pymatgen/analysis/diffraction/neutron.py +++ b/pymatgen/analysis/diffraction/neutron.py @@ -79,15 +79,15 @@ def get_pattern(self, structure: Structure, scaled=True, two_theta_range=(0, 90) sphere of radius 2 / wavelength. Returns: - (NDPattern) + DiffractionPattern: ND pattern """ if self.symprec: finder = SpacegroupAnalyzer(structure, symprec=self.symprec) structure = finder.get_refined_structure() wavelength = self.wavelength - latt = structure.lattice - is_hex = latt.is_hexagonal() + lattice = structure.lattice + is_hex = lattice.is_hexagonal() # Obtained from Bragg condition. Note that reciprocal lattice # vector length is 1 / d_hkl. @@ -98,7 +98,7 @@ def get_pattern(self, structure: Structure, scaled=True, two_theta_range=(0, 90) ) # Obtain crystallographic reciprocal lattice points within range - recip_latt = latt.reciprocal_lattice_crystallographic + recip_latt = lattice.reciprocal_lattice_crystallographic recip_pts = recip_latt.get_points_in_sphere([[0, 0, 0]], [0, 0, 0], max_r) if min_r: recip_pts = [pt for pt in recip_pts if pt[1] >= min_r] diff --git a/pymatgen/analysis/diffraction/tem.py b/pymatgen/analysis/diffraction/tem.py index bb00159fc4d..bab041edf66 100644 --- a/pymatgen/analysis/diffraction/tem.py +++ b/pymatgen/analysis/diffraction/tem.py @@ -588,7 +588,7 @@ def get_plot_2d(self, structure: Structure) -> go.Figure: ), ] layout = dict( - title="2D Diffraction Pattern
    Beam Direction: " + "".join(str(e) for e in self.beam_direction), + title="2D Diffraction Pattern
    Beam Direction: " + "".join(map(str, self.beam_direction)), font={"size": 14, "color": "#7f7f7f"}, hovermode="closest", xaxis={ diff --git a/pymatgen/analysis/diffraction/xrd.py b/pymatgen/analysis/diffraction/xrd.py index be5e288ebf3..f2b2c0ff4fd 100644 --- a/pymatgen/analysis/diffraction/xrd.py +++ b/pymatgen/analysis/diffraction/xrd.py @@ -140,15 +140,15 @@ def get_pattern(self, structure: Structure, scaled=True, two_theta_range=(0, 90) sphere of radius 2 / wavelength. Returns: - (XRDPattern) + DiffractionPattern: XRD pattern """ if self.symprec: finder = SpacegroupAnalyzer(structure, symprec=self.symprec) structure = finder.get_refined_structure() wavelength = self.wavelength - latt = structure.lattice - is_hex = latt.is_hexagonal() + lattice = structure.lattice + is_hex = lattice.is_hexagonal() # Obtained from Bragg condition. Note that reciprocal lattice # vector length is 1 / d_hkl. @@ -159,7 +159,7 @@ def get_pattern(self, structure: Structure, scaled=True, two_theta_range=(0, 90) ) # Obtain crystallographic reciprocal lattice points within range - recip_latt = latt.reciprocal_lattice_crystallographic + recip_latt = lattice.reciprocal_lattice_crystallographic recip_pts = recip_latt.get_points_in_sphere([[0, 0, 0]], [0, 0, 0], max_r) if min_r: recip_pts = [pt for pt in recip_pts if pt[1] >= min_r] @@ -215,7 +215,7 @@ def get_pattern(self, structure: Structure, scaled=True, two_theta_range=(0, 90) g_dot_r = np.dot(frac_coords, np.transpose([hkl])).T[0] # Highly vectorized computation of atomic scattering factors. - # Equivalent non-vectorized code is:: + # Equivalent non-vectorized code is: # # for site in structure: # el = site.specie diff --git a/pymatgen/analysis/dimensionality.py b/pymatgen/analysis/dimensionality.py index e4c8d7b931c..b252f51aa9e 100644 --- a/pymatgen/analysis/dimensionality.py +++ b/pymatgen/analysis/dimensionality.py @@ -51,7 +51,7 @@ def get_dimensionality_larsen(bonded_structure): due to periodic boundary conditions. Requires a StructureGraph object as input. This can be generated using one - of the NearNeighbor classes. For example, using the CrystalNN class:: + of the NearNeighbor classes. For example, using the CrystalNN class: bonded_structure = CrystalNN().get_bonded_structure(structure) @@ -67,7 +67,7 @@ def get_dimensionality_larsen(bonded_structure): CrystalNN.get_bonded_structure() method. Returns: - (int): The dimensionality of the structure. + int: The dimensionality of the structure. """ return max(c["dimensionality"] for c in get_structure_components(bonded_structure)) @@ -85,7 +85,7 @@ def get_structure_components( structure type or improper connections due to periodic boundary conditions. Requires a StructureGraph object as input. This can be generated using one - of the NearNeighbor classes. For example, using the CrystalNN class:: + of the NearNeighbor classes. For example, using the CrystalNN class: bonded_structure = CrystalNN().get_bonded_structure(structure) @@ -108,19 +108,19 @@ def get_structure_components( objects for zero-dimensional components. Returns: - (list of dict): Information on the components in a structure as a list - of dictionaries with the keys: - - - "structure_graph": A pymatgen StructureGraph object for the - component. - - "dimensionality": The dimensionality of the structure component as an - int. - - "orientation": If inc_orientation is `True`, the orientation of the - component as a tuple. E.g. (1, 1, 1) - - "site_ids": If inc_site_ids is `True`, the site indices of the - sites in the component as a tuple. - - "molecule_graph": If inc_molecule_graph is `True`, the site a - MoleculeGraph object for zero-dimensional components. + list[dict]: Information on the components in a structure as a list + of dictionaries with the keys: + + - "structure_graph": A pymatgen StructureGraph object for the + component. + - "dimensionality": The dimensionality of the structure component as an + int. + - "orientation": If inc_orientation is `True`, the orientation of the + component as a tuple. E.g. (1, 1, 1) + - "site_ids": If inc_site_ids is `True`, the site indices of the + sites in the component as a tuple. + - "molecule_graph": If inc_molecule_graph is `True`, the site a + MoleculeGraph object for zero-dimensional components. """ comp_graphs = (bonded_structure.graph.subgraph(c) for c in nx.weakly_connected_components(bonded_structure.graph)) @@ -167,8 +167,7 @@ def get_structure_components( def calculate_dimensionality_of_site(bonded_structure, site_index, inc_vertices=False): - """ - Calculates the dimensionality of the component containing the given site. + """Calculates the dimensionality of the component containing the given site. Implements directly the modified breadth-first-search algorithm described in Algorithm 1 of: @@ -186,10 +185,10 @@ def calculate_dimensionality_of_site(bonded_structure, site_index, inc_vertices= images) of the component. Returns: - (int or tuple): If inc_vertices is False, the dimensionality of the - component will be returned as an int. If inc_vertices is true, the - function will return a tuple of (dimensionality, vertices), where - vertices is a list of tuples. E.g. [(0, 0, 0), (1, 1, 1)]. + int | tuple: If inc_vertices is False, the dimensionality of the + component will be returned as an int. If inc_vertices is true, the + function will return a tuple of (dimensionality, vertices), where + vertices is a list of tuples. E.g. [(0, 0, 0), (1, 1, 1)]. """ def neighbors(comp_index): @@ -258,7 +257,7 @@ def zero_d_graph_to_molecule_graph(bonded_structure, graph): interest. Returns: - (MoleculeGraph): A MoleculeGraph object of the component. + MoleculeGraph: A MoleculeGraph object of the component. """ seen_indices = [] sites = [] @@ -287,7 +286,7 @@ def zero_d_graph_to_molecule_graph(bonded_structure, graph): sorted_sites = np.array(sites, dtype=object)[indices_ordering] sorted_graph = nx.convert_node_labels_to_integers(graph, ordering="sorted") mol = Molecule([s.specie for s in sorted_sites], [s.coords for s in sorted_sites]) - return MoleculeGraph.with_edges(mol, nx.Graph(sorted_graph).edges()) + return MoleculeGraph.from_edges(mol, nx.Graph(sorted_graph).edges()) def get_dimensionality_cheon( @@ -327,8 +326,8 @@ def get_dimensionality_cheon( structures. Testing with a larger cell circumvents this problem Returns: - (str): dimension of the largest cluster as a string. If there are ions - or molecules it returns 'intercalated ion/molecule' + str: dimension of the largest cluster as a string. If there are ions + or molecules it returns 'intercalated ion/molecule' """ if ldict is None: ldict = JmolNN().el_radius @@ -396,9 +395,9 @@ def find_connected_atoms(struct, tolerance=0.45, ldict=None): from JMol are used as default Returns: - (np.ndarray): A numpy array of shape (number of atoms, number of atoms); - If any image of atom j is bonded to atom i with periodic boundary - conditions, the matrix element [atom i, atom j] is 1. + np.ndarray: A numpy array of shape (number of atoms, number of atoms); + If any image of atom j is bonded to atom i with periodic boundary + conditions, the matrix element [atom i, atom j] is 1. """ if ldict is None: ldict = JmolNN().el_radius diff --git a/pymatgen/analysis/elasticity/elastic.py b/pymatgen/analysis/elasticity/elastic.py index 416a5636c61..8db995325e7 100644 --- a/pymatgen/analysis/elasticity/elastic.py +++ b/pymatgen/analysis/elasticity/elastic.py @@ -27,6 +27,7 @@ from collections.abc import Sequence from numpy.typing import ArrayLike + from typing_extensions import Self from pymatgen.core import Structure @@ -50,7 +51,7 @@ class NthOrderElasticTensor(Tensor): GPa_to_eV_A3 = Unit("GPa").get_conversion_factor(Unit("eV ang^-3")) symbol = "C" - def __new__(cls, input_array, check_rank=None, tol: float = 1e-4): + def __new__(cls, input_array, check_rank=None, tol: float = 1e-4) -> Self: """ Args: input_array (): @@ -92,7 +93,7 @@ def energy_density(self, strain, convert_GPa_to_eV=True): return e_density @classmethod - def from_diff_fit(cls, strains, stresses, eq_stress=None, order=2, tol: float = 1e-10): + def from_diff_fit(cls, strains, stresses, eq_stress=None, order=2, tol: float = 1e-10) -> Self: """ Takes a list of strains and stresses, and returns a list of coefficients for a polynomial fit of the given order. @@ -132,7 +133,7 @@ class ElasticTensor(NthOrderElasticTensor): in units of eV/A^3. """ - def __new__(cls, input_array, tol: float = 1e-4): + def __new__(cls, input_array, tol: float = 1e-4) -> Self: """ Create an ElasticTensor object. The constructor throws an error if the shape of the input_matrix argument is not 3x3x3x3, i. e. in true tensor notation. Issues a @@ -275,7 +276,7 @@ def snyder_ac(self, structure: Structure) -> float: n_sites = len(structure) n_atoms = structure.composition.num_atoms site_density = 1e30 * n_sites / structure.volume - tot_mass = sum(e.atomic_mass for e in structure.species) + tot_mass = sum(spec.atomic_mass for spec in structure.species) avg_mass = 1.6605e-27 * tot_mass / n_atoms return ( 0.38483 @@ -329,7 +330,7 @@ def clarke_thermalcond(self, structure: Structure) -> float: float: Clarke's thermal conductivity (in SI units) """ n_sites = len(structure) - tot_mass = sum(e.atomic_mass for e in structure.species) + tot_mass = sum(spec.atomic_mass for spec in structure.species) n_atoms = structure.composition.num_atoms weight = float(structure.composition.weight) avg_mass = 1.6605e-27 * tot_mass / n_atoms @@ -459,7 +460,7 @@ def get_structure_property_dict( return sp_dict @classmethod - def from_pseudoinverse(cls, strains, stresses): + def from_pseudoinverse(cls, strains, stresses) -> Self: """ Class method to fit an elastic tensor from stress/strain data. Method uses Moore-Penrose pseudo-inverse to invert @@ -483,7 +484,7 @@ def from_pseudoinverse(cls, strains, stresses): return cls.from_voigt(voigt_fit) @classmethod - def from_independent_strains(cls, strains, stresses, eq_stress=None, vasp=False, tol: float = 1e-10): + def from_independent_strains(cls, strains, stresses, eq_stress=None, vasp=False, tol: float = 1e-10) -> Self: """ Constructs the elastic tensor least-squares fit of independent strains @@ -522,7 +523,7 @@ class ComplianceTensor(Tensor): since the compliance tensor has a unique vscale. """ - def __new__(cls, s_array): + def __new__(cls, s_array) -> Self: """ Args: s_array (): @@ -555,7 +556,7 @@ def __init__(self, c_list: Sequence) -> None: super().__init__(c_list) @classmethod - def from_diff_fit(cls, strains, stresses, eq_stress=None, tol: float = 1e-10, order=3): + def from_diff_fit(cls, strains, stresses, eq_stress=None, tol: float = 1e-10, order=3) -> Self: """ Generates an elastic tensor expansion via the fitting function defined below in diff_fit. @@ -612,7 +613,7 @@ def get_tgt(self, temperature: float | None = None, structure: Structure = None, structure (Structure): Structure to be used in directional heat capacity determination, only necessary if temperature is specified - quad (dict): quadrature for integration, should be + quadct (dict): quadrature for integration, should be dictionary with "points" and "weights" keys defaults to quadpy.sphere.Lebedev(19) as read from file """ @@ -645,7 +646,7 @@ def get_gruneisen_parameter(self, temperature=None, structure=None, quad=None): structure (float): Structure to be used in directional heat capacity determination, only necessary if temperature is specified - quad (dict): quadrature for integration, should be + quadct (dict): quadrature for integration, should be dictionary with "points" and "weights" keys defaults to quadpy.sphere.Lebedev(19) as read from file """ @@ -760,8 +761,8 @@ def get_strain_from_stress(self, stress): """ compl_exp = self.get_compliance_expansion() strain = 0 - for n, compl in enumerate(compl_exp): - strain += compl.einsum_sequence([stress] * (n + 1)) / factorial(n + 1) + for n, compl in enumerate(compl_exp, start=1): + strain += compl.einsum_sequence([stress] * (n)) / factorial(n) return strain def get_effective_ecs(self, strain, order=2): diff --git a/pymatgen/analysis/elasticity/strain.py b/pymatgen/analysis/elasticity/strain.py index 00e428a1be0..9e04c052836 100644 --- a/pymatgen/analysis/elasticity/strain.py +++ b/pymatgen/analysis/elasticity/strain.py @@ -20,6 +20,7 @@ from collections.abc import Sequence from numpy.typing import ArrayLike + from typing_extensions import Self from pymatgen.core.structure import Structure @@ -38,7 +39,7 @@ class Deformation(SquareTensor): symbol = "d" - def __new__(cls, deformation_gradient): + def __new__(cls, deformation_gradient) -> Self: """ Create a Deformation object. Note that the constructor uses __new__ rather than __init__ according to the standard method of subclassing numpy ndarrays. @@ -81,7 +82,7 @@ def apply_to_structure(self, structure: Structure): return def_struct @classmethod - def from_index_amount(cls, matrix_pos, amt): + def from_index_amount(cls, matrix_pos, amt) -> Self: """ Factory method for constructing a Deformation object from a matrix position and amount. @@ -158,7 +159,7 @@ class Strain(SquareTensor): symbol = "e" - def __new__(cls, strain_matrix): + def __new__(cls, strain_matrix) -> Self: """ Create a Strain object. Note that the constructor uses __new__ rather than __init__ according to the standard method of @@ -185,7 +186,7 @@ def __array_finalize__(self, obj): self._vscale = getattr(obj, "_vscale", None) @classmethod - def from_deformation(cls, deformation: ArrayLike) -> Strain: + def from_deformation(cls, deformation: ArrayLike) -> Self: """ Factory method that returns a Strain object from a deformation gradient. @@ -197,7 +198,7 @@ def from_deformation(cls, deformation: ArrayLike) -> Strain: return cls(0.5 * (np.dot(dfm.trans, dfm) - np.eye(3))) @classmethod - def from_index_amount(cls, idx, amount): + def from_index_amount(cls, idx: tuple | int, amount: float) -> Self: """ Like Deformation.from_index_amount, except generates a strain from the zero 3x3 tensor or Voigt vector with @@ -208,7 +209,7 @@ def from_index_amount(cls, idx, amount): idx (tuple or integer): index to be perturbed, can be Voigt or full-tensor notation amount (float): amount to perturb selected index """ - if np.array(idx).ndim == 0: + if isinstance(idx, int): v = np.zeros(6) v[idx] = amount return cls.from_voigt(v) diff --git a/pymatgen/analysis/elasticity/stress.py b/pymatgen/analysis/elasticity/stress.py index f4eaa9234fd..5fce24d4f5e 100644 --- a/pymatgen/analysis/elasticity/stress.py +++ b/pymatgen/analysis/elasticity/stress.py @@ -6,11 +6,15 @@ from __future__ import annotations import math +from typing import TYPE_CHECKING import numpy as np from pymatgen.core.tensors import SquareTensor +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "Joseph Montoya" __copyright__ = "Copyright 2012, The Materials Project" __credits__ = "Maarten de Jong, Mark Asta, Anubhav Jain" @@ -29,7 +33,7 @@ class Stress(SquareTensor): symbol = "s" - def __new__(cls, stress_matrix): + def __new__(cls, stress_matrix) -> Self: """ Create a Stress object. Note that the constructor uses __new__ rather than __init__ according to the standard method of diff --git a/pymatgen/analysis/energy_models.py b/pymatgen/analysis/energy_models.py index eb4ff7591ed..0dafed86a02 100644 --- a/pymatgen/analysis/energy_models.py +++ b/pymatgen/analysis/energy_models.py @@ -15,6 +15,8 @@ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer if TYPE_CHECKING: + from typing_extensions import Self + from pymatgen.core import Structure __version__ = "0.1" @@ -26,7 +28,8 @@ class EnergyModel(MSONable, abc.ABC): @abc.abstractmethod def get_energy(self, structure) -> float: """ - :param structure: Structure + Args: + structure: Structure Returns: Energy value @@ -34,7 +37,7 @@ def get_energy(self, structure) -> float: return 0.0 @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """ Args: dct (dict): Dict representation. @@ -73,7 +76,8 @@ def __init__(self, real_space_cut=None, recip_space_cut=None, eta=None, acc_fact def get_energy(self, structure: Structure): """ - :param structure: Structure + Args: + structure: Structure Returns: Energy value @@ -121,7 +125,8 @@ def __init__(self, symprec: float = 0.1, angle_tolerance=5): def get_energy(self, structure: Structure): """ - :param structure: Structure + Args: + structure: Structure Returns: Energy value @@ -156,7 +161,8 @@ def __init__(self, j, max_radius): def get_energy(self, structure: Structure): """ - :param structure: Structure + Args: + structure: Structure Returns: Energy value @@ -188,7 +194,8 @@ class NsitesModel(EnergyModel): def get_energy(self, structure: Structure): """ - :param structure: Structure + Args: + structure: Structure Returns: Energy value diff --git a/pymatgen/analysis/eos.py b/pymatgen/analysis/eos.py index 5d751618dce..a60d7b765e5 100644 --- a/pymatgen/analysis/eos.py +++ b/pymatgen/analysis/eos.py @@ -467,16 +467,16 @@ def get_rms(x, y): # loop over the data points. while (n_data_fit >= n_data_min) and (e_min in e_v_work): max_poly_order = n_data_fit - max_poly_order_factor - e = [ei[0] for ei in e_v_work] - v = [ei[1] for ei in e_v_work] + energies = [ei[0] for ei in e_v_work] + volumes = [ei[1] for ei in e_v_work] # loop over polynomial order for idx in range(min_poly_order, max_poly_order + 1): - coeffs = np.polyfit(v, e, idx) - pder = np.polyder(coeffs) - a = np.poly1d(pder)(v_before) - b = np.poly1d(pder)(v_after) + coeffs = np.polyfit(volumes, energies, idx) + polyder = np.polyder(coeffs) + a = np.poly1d(polyder)(v_before) + b = np.poly1d(polyder)(v_after) if a * b < 0: - rms = get_rms(e, np.poly1d(coeffs)(v)) + rms = get_rms(energies, np.poly1d(coeffs)(volumes)) rms_min = min(rms_min, rms * idx / n_data_fit) all_coeffs[(idx, n_data_fit)] = [coeffs.tolist(), rms] # store the fit coefficients small to large, @@ -495,12 +495,12 @@ def get_rms(x, y): weighted_avg_coeffs = np.zeros((fit_poly_order,)) # combine all the filtered polynomial candidates to get the final fit. - for k, v in all_coeffs.items(): + for key, val in all_coeffs.items(): # weighted rms = rms * polynomial order / rms_min / ndata_fit - weighted_rms = v[1] * k[0] / rms_min / k[1] + weighted_rms = val[1] * key[0] / rms_min / key[1] weight = np.exp(-(weighted_rms**2)) norm += weight - coeffs = np.array(v[0]) + coeffs = np.array(val[0]) # pad the coefficient array with zeros coeffs = np.lib.pad(coeffs, (0, max(fit_poly_order - len(coeffs), 0)), "constant") weighted_avg_coeffs += weight * coeffs @@ -522,7 +522,7 @@ class EOS: Fit equation of state for bulk systems. - The following equations are supported:: + The following equations are supported: murnaghan: PRB 28, 5480 (1983) @@ -539,7 +539,7 @@ class EOS: numerical_eos: 10.1103/PhysRevB.90.174107. - Usage:: + Usage: eos = EOS(eos_name='murnaghan') eos_fit = eos.fit(volumes, energies) diff --git a/pymatgen/analysis/ewald.py b/pymatgen/analysis/ewald.py index 2041efb8168..092f45266a5 100644 --- a/pymatgen/analysis/ewald.py +++ b/pymatgen/analysis/ewald.py @@ -6,7 +6,7 @@ from copy import copy, deepcopy from datetime import datetime from math import log, pi, sqrt -from typing import Any +from typing import TYPE_CHECKING, Any from warnings import warn import numpy as np @@ -17,6 +17,9 @@ from pymatgen.core.structure import Structure from pymatgen.util.due import Doi, due +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "Shyue Ping Ong, William Davidson Richard" __copyright__ = "Copyright 2011, The Materials Project" __credits__ = "Christopher Fischer" @@ -288,7 +291,7 @@ def get_site_energy(self, site_index): site_index (int): Index of site Returns: - (float) - Energy of that site + float: Energy of that site """ if not self._initialized: self._calc_ewald_terms() @@ -341,8 +344,7 @@ def _calc_recip(self): for g, g2, gr, exp_val, s_real, s_imag in zip(gs, g2s, grs, exp_vals, s_reals, s_imags): # Uses the identity sin(x)+cos(x) = 2**0.5 sin(x + pi/4) - m = (gr[None, :] + pi / 4) - gr[:, None] - np.sin(m, m) + m = np.sin((gr[None, :] + pi / 4) - gr[:, None]) m *= exp_val / g2 e_recip += m @@ -444,11 +446,11 @@ def as_dict(self, verbosity: int = 0) -> dict: } @classmethod - def from_dict(cls, dct: dict[str, Any], fmt: str | None = None, **kwargs) -> EwaldSummation: + def from_dict(cls, dct: dict[str, Any], fmt: str | None = None, **kwargs) -> Self: """Create an EwaldSummation instance from JSON-serialized dictionary. Args: - d (dict): Dictionary representation + dct (dict): Dictionary representation fmt (str, optional): Unused. Defaults to None. Returns: diff --git a/pymatgen/analysis/ferroelectricity/polarization.py b/pymatgen/analysis/ferroelectricity/polarization.py index 349fc133827..1b0f8806914 100644 --- a/pymatgen/analysis/ferroelectricity/polarization.py +++ b/pymatgen/analysis/ferroelectricity/polarization.py @@ -56,6 +56,8 @@ if TYPE_CHECKING: from collections.abc import Sequence + from typing_extensions import Self + from pymatgen.core.sites import PeriodicSite @@ -123,8 +125,8 @@ def get_nearest_site(struct: Structure, coords: Sequence[float], site: PeriodicS Closest site and distance. """ index = struct.index(site) - r = r or np.linalg.norm(np.sum(struct.lattice.matrix, axis=0)) - ns = struct.get_sites_in_sphere(coords, r, include_index=True) + radius = r or np.linalg.norm(np.sum(struct.lattice.matrix, axis=0)) + ns = struct.get_sites_in_sphere(coords, radius, include_index=True) # Get sites with identical index to site ns = [n for n in ns if n[2] == index] # Sort by distance to coords @@ -172,7 +174,7 @@ def __init__(self, p_elecs, p_ions, structures, p_elecs_in_cartesian=True, p_ion self.structures = structures @classmethod - def from_outcars_and_structures(cls, outcars, structures, calc_ionic_from_zval=False): + def from_outcars_and_structures(cls, outcars, structures, calc_ionic_from_zval=False) -> Self: """ Create Polarization object from list of Outcars and Structures in order of nonpolar to polar. @@ -424,7 +426,10 @@ class EnergyTrend: """Class for fitting trends to energies.""" def __init__(self, energies): - """:param energies: Energies""" + """ + Args: + energies: Energies + """ self.energies = energies def spline(self): diff --git a/pymatgen/analysis/fragmenter.py b/pymatgen/analysis/fragmenter.py index 8ef04f8852b..4c83154b848 100644 --- a/pymatgen/analysis/fragmenter.py +++ b/pymatgen/analysis/fragmenter.py @@ -4,6 +4,7 @@ import copy import logging +from typing import TYPE_CHECKING from monty.json import MSONable @@ -11,6 +12,9 @@ from pymatgen.analysis.local_env import OpenBabelNN, metal_edge_extender from pymatgen.io.babel import BabelMolAdaptor +if TYPE_CHECKING: + from pymatgen.core.structure import Molecule + __author__ = "Samuel Blau" __copyright__ = "Copyright 2018, The Materials Project" __version__ = "2.0" @@ -27,14 +31,14 @@ class Fragmenter(MSONable): def __init__( self, - molecule, - edges=None, - depth=1, - open_rings=False, - use_metal_edge_extender=False, - opt_steps=10000, - prev_unique_frag_dict=None, - assume_previous_thoroughness=True, + molecule: Molecule, + edges: list | None = None, + depth: int = 1, + open_rings: bool = False, + use_metal_edge_extender: bool = False, + opt_steps: int = 10000, + prev_unique_frag_dict: dict | None = None, + assume_previous_thoroughness: bool = True, ): """ Standard constructor for molecule fragmentation. @@ -76,10 +80,10 @@ def __init__( self.opt_steps = opt_steps if edges is None: - self.mol_graph = MoleculeGraph.with_local_env_strategy(molecule, OpenBabelNN()) + self.mol_graph = MoleculeGraph.from_local_env_strategy(molecule, OpenBabelNN()) else: - edges = {(e[0], e[1]): None for e in edges} - self.mol_graph = MoleculeGraph.with_edges(molecule, edges) + _edges: dict[tuple[int, int], dict | None] = {(edge[0], edge[1]): None for edge in edges} + self.mol_graph = MoleculeGraph.from_edges(molecule, _edges) if ("Li" in molecule.composition or "Mg" in molecule.composition) and use_metal_edge_extender: self.mol_graph = metal_edge_extender(self.mol_graph) @@ -159,7 +163,7 @@ def __init__( for frag_key in self.unique_frag_dict: self.total_unique_fragments += len(self.unique_frag_dict[frag_key]) - def _fragment_one_level(self, old_frag_dict): + def _fragment_one_level(self, old_frag_dict: dict) -> dict: """ Perform one step of iterative fragmentation on a list of molecule graphs. Loop through the graphs, then loop through each graph's edges and attempt to remove that edge in order to obtain two @@ -210,7 +214,7 @@ def _fragment_one_level(self, old_frag_dict): new_frag_dict[new_frag_key] = [fragment] return new_frag_dict - def _open_all_rings(self): + def _open_all_rings(self) -> None: """ Having already generated all unique fragments that did not require ring opening, now we want to also obtain fragments that do require opening. We achieve this by @@ -221,7 +225,7 @@ def _open_all_rings(self): alph_formula = self.mol_graph.molecule.composition.alphabetical_formula mol_key = f"{alph_formula} E{len(self.mol_graph.graph.edges())}" self.all_unique_frag_dict[mol_key] = [self.mol_graph] - new_frag_keys = {"0": []} + new_frag_keys: dict[str, list] = {"0": []} new_frag_key_dict = {} for key in self.all_unique_frag_dict: for fragment in self.all_unique_frag_dict[key]: @@ -291,7 +295,7 @@ def _open_all_rings(self): self.all_unique_frag_dict.pop(mol_key) -def open_ring(mol_graph, bond, opt_steps): +def open_ring(mol_graph: MoleculeGraph, bond: list, opt_steps: int) -> MoleculeGraph: """ Function to actually open a ring using OpenBabel's local opt. Given a molecule graph and a bond, convert the molecule graph into an OpenBabel molecule, remove @@ -302,4 +306,5 @@ def open_ring(mol_graph, bond, opt_steps): ob_mol = BabelMolAdaptor.from_molecule_graph(mol_graph) ob_mol.remove_bond(bond[0][0] + 1, bond[0][1] + 1) ob_mol.localopt(steps=opt_steps, forcefield="uff") - return MoleculeGraph.with_local_env_strategy(ob_mol.pymatgen_mol, OpenBabelNN()) + + return MoleculeGraph.from_local_env_strategy(ob_mol.pymatgen_mol, OpenBabelNN()) diff --git a/pymatgen/analysis/functional_groups.py b/pymatgen/analysis/functional_groups.py index 7242d13aa78..b54a9c8fb06 100644 --- a/pymatgen/analysis/functional_groups.py +++ b/pymatgen/analysis/functional_groups.py @@ -35,11 +35,11 @@ def __init__(self, molecule, optimize=False): """ Instantiation method for FunctionalGroupExtractor. - :param molecule: Either a filename, a pymatgen.core.structure.Molecule - object, or a pymatgen.analysis.graphs.MoleculeGraph object. - :param optimize: Default False. If True, then the input molecule will be - modified, adding Hydrogens, performing a simple conformer search, - etc. + Args: + molecule: Either a filename, a pymatgen.core.structure.Molecule + object, or a pymatgen.analysis.graphs.MoleculeGraph object. + optimize: Default False. If True, then the input molecule will be + modified, adding Hydrogens, performing a simple conformer search, etc. """ self.molgraph = None @@ -86,7 +86,7 @@ def __init__(self, molecule, optimize=False): raise ValueError("Input to FunctionalGroupExtractor must be str, Molecule, or MoleculeGraph.") if self.molgraph is None: - self.molgraph = MoleculeGraph.with_local_env_strategy(self.molecule, OpenBabelNN()) + self.molgraph = MoleculeGraph.from_local_env_strategy(self.molecule, OpenBabelNN()) # Assign a specie and coordinates to each node in the graph, # corresponding to the Site in the Molecule object @@ -99,22 +99,23 @@ def get_heteroatoms(self, elements=None): Identify non-H, non-C atoms in the MoleculeGraph, returning a list of their node indices. - :param elements: List of elements to identify (if only certain + Args: + elements: List of elements to identify (if only certain functional groups are of interest). Returns: set of ints representing node indices """ - heteroatoms = set() + hetero_atoms = set() for node in self.molgraph.graph.nodes(): if elements is not None: if str(self.species[node]) in elements: - heteroatoms.add(node) + hetero_atoms.add(node) elif str(self.species[node]) not in ["C", "H"]: - heteroatoms.add(node) + hetero_atoms.add(node) - return heteroatoms + return hetero_atoms def get_special_carbon(self, elements=None): """ @@ -129,9 +130,10 @@ def get_special_carbon(self, elements=None): nitrogens or sulfurs; these O, N or S atoms must have only single bonds - all atoms in oxirane, aziridine and thiirane rings" - :param elements: List of elements that will qualify a carbon as special - (if only certain functional groups are of interest). - Default None. + Args: + elements: List of elements that will qualify a carbon as special + (if only certain functional groups are of interest). + Default None. Returns: set of ints representing node indices @@ -196,8 +198,9 @@ def link_marked_atoms(self, atoms): and attempt to connect them, returning a list of disjoint groups of special atoms (and their connected hydrogens). - :param atoms: set of marked "interesting" atoms, presumably identified - using other functions in this class. + Args: + atoms: set of marked "interesting" atoms, presumably identified + using other functions in this class. Returns: list of sets of ints, representing groups of connected atoms @@ -232,9 +235,10 @@ def get_basic_functional_groups(self, func_groups=None): TODO: Think of other functional groups that are important enough to be added (ex: do we need ethyl, butyl, propyl?) - :param func_groups: List of strs representing the functional groups of - interest. Default to None, meaning that all of the functional groups - defined in this function will be sought. + Args: + func_groups: List of strs representing the functional groups of + interest. Default to None, meaning that all of the functional groups + defined in this function will be sought. Returns: list of sets of ints, representing groups of connected atoms @@ -293,14 +297,15 @@ def get_all_functional_groups(self, elements=None, func_groups=None, catch_basic Identify all functional groups (or all within a certain subset) in the molecule, combining the methods described above. - :param elements: List of elements that will qualify a carbon as special - (if only certain functional groups are of interest). - Default None. - :param func_groups: List of strs representing the functional groups of - interest. Default to None, meaning that all of the functional groups - defined in this function will be sought. - :param catch_basic: bool. If True, use get_basic_functional_groups and - other methods + Args: + elements: List of elements that will qualify a carbon as special + (if only certain functional groups are of interest). + Default None. + func_groups: List of strs representing the functional groups of + interest. Default to None, meaning that all of the functional groups + defined in this function will be sought. + catch_basic: bool. If True, use get_basic_functional_groups and + other methods Returns: list of sets of ints, representing groups of connected atoms @@ -318,7 +323,8 @@ def categorize_functional_groups(self, groups): """ Determine classes of functional groups present in a set. - :param groups: Set of functional groups. + Args: + groups: Set of functional groups. Returns: dict containing representations of the groups, the indices of diff --git a/pymatgen/analysis/gb/grain.py b/pymatgen/analysis/gb/grain.py index 7d4e662e3bf..9a5b8c4166a 100644 --- a/pymatgen/analysis/gb/grain.py +++ b/pymatgen/analysis/gb/grain.py @@ -1,2325 +1,11 @@ -"""Module containing classes to generate grain boundaries.""" - from __future__ import annotations -import itertools -import logging import warnings -from fractions import Fraction -from functools import reduce -from math import cos, floor, gcd -from typing import TYPE_CHECKING, Any - -import numpy as np -from monty.fractions import lcm - -from pymatgen.core.lattice import Lattice -from pymatgen.core.sites import PeriodicSite, Site -from pymatgen.core.structure import Structure -from pymatgen.symmetry.analyzer import SpacegroupAnalyzer - -if TYPE_CHECKING: - from collections.abc import Sequence - - from numpy.typing import ArrayLike - - from pymatgen.core.trajectory import Vector3D - from pymatgen.util.typing import CompositionLike - -# This module implements representations of grain boundaries, as well as -# algorithms for generating them. - -__author__ = "Xiang-Guo Li" -__copyright__ = "Copyright 2018, The Materials Virtual Lab" -__version__ = "0.1" -__maintainer__ = "Xiang-Guo Li" -__email__ = "xil110@ucsd.edu" -__date__ = "7/30/18" - -logger = logging.getLogger(__name__) - - -class GrainBoundary(Structure): - """ - Subclass of Structure representing a GrainBoundary (GB) object. Implements additional - attributes pertaining to gbs, but the init method does not actually implement any - algorithm that creates a GB. This is a DUMMY class who's init method only holds - information about the GB. Also has additional methods that returns other information - about a GB such as sigma value. - - Note that all gbs have the GB surface normal oriented in the c-direction. This means - the lattice vectors a and b are in the GB surface plane (at least for one grain) and - the c vector is out of the surface plane (though not necessarily perpendicular to the - surface). - """ - - def __init__( - self, - lattice: np.ndarray | Lattice, - species: Sequence[CompositionLike], - coords: Sequence[ArrayLike], - rotation_axis: Vector3D, - rotation_angle: float, - gb_plane: Vector3D, - join_plane: Vector3D, - init_cell: Structure, - vacuum_thickness: float, - ab_shift: tuple[float, float], - site_properties: dict[str, Any], - oriented_unit_cell: Structure, - validate_proximity: bool = False, - coords_are_cartesian: bool = False, - properties: dict | None = None, - ) -> None: - """ - Makes a GB structure, a structure object with additional information - and methods pertaining to gbs. - - Args: - lattice (Lattice/3x3 array): The lattice, either as an instance or - any 2D array. Each row should correspond to a lattice vector. - species ([Species]): Sequence of species on each site. Can take in - flexible input, including: - - i. A sequence of element / species specified either as string - symbols, e.g. ["Li", "Fe2+", "P", ...] or atomic numbers, - e.g., (3, 56, ...) or actual Element or Species objects. - - ii. List of dict of elements/species and occupancies, e.g., - [{"Fe" : 0.5, "Mn":0.5}, ...]. This allows the setup of - disordered structures. - coords (Nx3 array): list of fractional/cartesian coordinates for each species. - rotation_axis (list[int]): Rotation axis of GB in the form of a list of integers, e.g. [1, 1, 0]. - rotation_angle (float, in unit of degree): rotation angle of GB. - gb_plane (list): Grain boundary plane in the form of a list of integers - e.g.: [1, 2, 3]. - join_plane (list): Joining plane of the second grain in the form of a list of - integers. e.g.: [1, 2, 3]. - init_cell (Structure): initial bulk structure to form the GB. - site_properties (dict): Properties associated with the sites as a - dict of sequences, The sequences have to be the same length as - the atomic species and fractional_coords. For GB, you should - have the 'grain_label' properties to classify the sites as 'top', - 'bottom', 'top_incident', or 'bottom_incident'. - vacuum_thickness (float in angstrom): The thickness of vacuum inserted - between two grains of the GB. - ab_shift (list of float, in unit of crystal vector a, b): The relative - shift along a, b vectors. - oriented_unit_cell (Structure): oriented unit cell of the bulk init_cell. - Helps to accurately calculate the bulk properties that are consistent - with GB calculations. - validate_proximity (bool): Whether to check if there are sites - that are less than 0.01 Ang apart. Defaults to False. - coords_are_cartesian (bool): Set to True if you are providing - coordinates in Cartesian coordinates. Defaults to False. - properties (dict): dictionary containing properties associated - with the whole GrainBoundary. - """ - self.oriented_unit_cell = oriented_unit_cell - self.rotation_axis = rotation_axis - self.rotation_angle = rotation_angle - self.gb_plane = gb_plane - self.join_plane = join_plane - self.init_cell = init_cell - self.vacuum_thickness = vacuum_thickness - self.ab_shift = ab_shift - super().__init__( - lattice, - species, - coords, - validate_proximity=validate_proximity, - coords_are_cartesian=coords_are_cartesian, - site_properties=site_properties, - properties=properties, - ) - - def copy(self): - """ - Convenience method to get a copy of the structure, with options to add - site properties. - - Returns: - A copy of the Structure, with optionally new site_properties and - optionally sanitized. - """ - return GrainBoundary( - self.lattice, - self.species_and_occu, - self.frac_coords, - self.rotation_axis, - self.rotation_angle, - self.gb_plane, - self.join_plane, - self.init_cell, - self.vacuum_thickness, - self.ab_shift, - self.site_properties, - self.oriented_unit_cell, - ) - - def get_sorted_structure(self, key=None, reverse=False): - """ - Get a sorted copy of the structure. The parameters have the same - meaning as in list.sort. By default, sites are sorted by the - electronegativity of the species. Note that Slab has to override this - because of the different __init__ args. - - Args: - key: Specifies a function of one argument that is used to extract - a comparison key from each list element: key=str.lower. The - default value is None (compare the elements directly). - reverse (bool): If set to True, then the list elements are sorted - as if each comparison were reversed. - """ - sites = sorted(self, key=key, reverse=reverse) - struct = Structure.from_sites(sites) - return GrainBoundary( - struct.lattice, - struct.species_and_occu, - struct.frac_coords, - self.rotation_axis, - self.rotation_angle, - self.gb_plane, - self.join_plane, - self.init_cell, - self.vacuum_thickness, - self.ab_shift, - self.site_properties, - self.oriented_unit_cell, - ) - - @property - def sigma(self) -> int: - """ - This method returns the sigma value of the GB. - If using 'quick_gen' to generate GB, this value is not valid. - """ - return int(round(self.oriented_unit_cell.volume / self.init_cell.volume)) - - @property - def sigma_from_site_prop(self) -> int: - """ - This method returns the sigma value of the GB from site properties. - If the GB structure merge some atoms due to the atoms too closer with - each other, this property will not work. - """ - n_coi = 0 - if None in self.site_properties["grain_label"]: - raise RuntimeError("Site were merged, this property do not work") - for tag in self.site_properties["grain_label"]: - if "incident" in tag: - n_coi += 1 - return int(round(len(self) / n_coi)) - - @property - def top_grain(self) -> Structure: - """Return the top grain (Structure) of the GB.""" - top_sites = [] - for i, tag in enumerate(self.site_properties["grain_label"]): - if "top" in tag: - top_sites.append(self.sites[i]) - return Structure.from_sites(top_sites) - - @property - def bottom_grain(self) -> Structure: - """Return the bottom grain (Structure) of the GB.""" - bottom_sites = [] - for i, tag in enumerate(self.site_properties["grain_label"]): - if "bottom" in tag: - bottom_sites.append(self.sites[i]) - return Structure.from_sites(bottom_sites) - - @property - def coincidents(self) -> list[Site]: - """Return the a list of coincident sites.""" - coincident_sites = [] - for idx, tag in enumerate(self.site_properties["grain_label"]): - if "incident" in tag: - coincident_sites.append(self.sites[idx]) - return coincident_sites - - def __str__(self): - comp = self.composition - outs = [ - f"Gb Summary ({comp.formula})", - f"Reduced Formula: {comp.reduced_formula}", - f"Rotation axis: {self.rotation_axis}", - f"Rotation angle: {self.rotation_angle}", - f"GB plane: {self.gb_plane}", - f"Join plane: {self.join_plane}", - f"vacuum thickness: {self.vacuum_thickness}", - f"ab_shift: {self.ab_shift}", - ] - - def to_str(x, rjust=10): - return (f"{x:0.6f}").rjust(rjust) - - outs += ( - f"abc : {' '.join(to_str(i) for i in self.lattice.abc)}", - f"angles: {' '.join(to_str(i) for i in self.lattice.angles)}", - f"Sites ({len(self)})", - ) - for idx, site in enumerate(self): - outs.append(f"{idx + 1} {site.species_string} {' '.join(to_str(coord, 12) for coord in site.frac_coords)}") - return "\n".join(outs) - - def as_dict(self): - """ - Returns: - Dictionary representation of GrainBoundary object. - """ - dct = super().as_dict() - dct["@module"] = type(self).__module__ - dct["@class"] = type(self).__name__ - dct["init_cell"] = self.init_cell.as_dict() - dct["rotation_axis"] = self.rotation_axis - dct["rotation_angle"] = self.rotation_angle - dct["gb_plane"] = self.gb_plane - dct["join_plane"] = self.join_plane - dct["vacuum_thickness"] = self.vacuum_thickness - dct["ab_shift"] = self.ab_shift - dct["oriented_unit_cell"] = self.oriented_unit_cell.as_dict() - return dct - - @classmethod - def from_dict(cls, dct: dict) -> GrainBoundary: # type: ignore[override] - """ - Generates a GrainBoundary object from a dictionary created by as_dict(). - - Args: - dct: dict - - Returns: - GrainBoundary object - """ - lattice = Lattice.from_dict(dct["lattice"]) - sites = [PeriodicSite.from_dict(site_dict, lattice) for site_dict in dct["sites"]] - struct = Structure.from_sites(sites) - - return GrainBoundary( - lattice=lattice, - species=struct.species_and_occu, - coords=struct.frac_coords, - rotation_axis=dct["rotation_axis"], - rotation_angle=dct["rotation_angle"], - gb_plane=dct["gb_plane"], - join_plane=dct["join_plane"], - init_cell=Structure.from_dict(dct["init_cell"]), - vacuum_thickness=dct["vacuum_thickness"], - ab_shift=dct["ab_shift"], - oriented_unit_cell=Structure.from_dict(dct["oriented_unit_cell"]), - site_properties=struct.site_properties, - ) - - -class GrainBoundaryGenerator: - """ - This class is to generate grain boundaries (GBs) from bulk - conventional cell (fcc, bcc can from the primitive cell), and works for Cubic, - Tetragonal, Orthorhombic, Rhombohedral, and Hexagonal systems. - It generate GBs from given parameters, which includes - GB plane, rotation axis, rotation angle. - - This class works for any general GB, including twist, tilt and mixed GBs. - The three parameters, rotation axis, GB plane and rotation angle, are - sufficient to identify one unique GB. While sometimes, users may not be able - to tell what exactly rotation angle is but prefer to use sigma as an parameter, - this class also provides the function that is able to return all possible - rotation angles for a specific sigma value. - The same sigma value (with rotation axis fixed) can correspond to - multiple rotation angles. - Users can use structure matcher in pymatgen to get rid of the redundant structures. - """ - - def __init__(self, initial_structure: Structure, symprec: float = 0.1, angle_tolerance: float = 1) -> None: - """ - Args: - initial_structure (Structure): Initial input structure. It can - be conventional or primitive cell (primitive cell works for bcc and fcc). - For fcc and bcc, using conventional cell can lead to a non-primitive - grain boundary structure. - This code supplies Cubic, Tetragonal, Orthorhombic, Rhombohedral, and - Hexagonal systems. - symprec (float): Tolerance for symmetry finding. Defaults to 0.1 (the value used - in Materials Project), which is for structures with slight deviations - from their proper atomic positions (e.g., structures relaxed with - electronic structure codes). - A smaller value of 0.01 is often used for properly refined - structures with atoms in the proper symmetry coordinates. - User should make sure the symmetry is what you want. - angle_tolerance (float): Angle tolerance for symmetry finding. - """ - analyzer = SpacegroupAnalyzer(initial_structure, symprec, angle_tolerance) - self.lat_type = analyzer.get_lattice_type()[0] - if self.lat_type == "t": - # need to use the conventional cell for tetragonal - initial_structure = analyzer.get_conventional_standard_structure() - a, b, c = initial_structure.lattice.abc - # c axis of tetragonal structure not in the third direction - if abs(a - b) > symprec: - # a == c, rotate b to the third direction - if abs(a - c) < symprec: - initial_structure.make_supercell([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) - # b == c, rotate a to the third direction - else: - initial_structure.make_supercell([[0, 1, 0], [0, 0, 1], [1, 0, 0]]) - elif self.lat_type == "h": - alpha, beta, gamma = initial_structure.lattice.angles - # c axis is not in the third direction - if abs(gamma - 90) < angle_tolerance: - # alpha = 120 or 60, rotate b, c to a, b vectors - if abs(alpha - 90) > angle_tolerance: - initial_structure.make_supercell([[0, 1, 0], [0, 0, 1], [1, 0, 0]]) - # beta = 120 or 60, rotate c, a to a, b vectors - elif abs(beta - 90) > angle_tolerance: - initial_structure.make_supercell([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) - elif self.lat_type == "r": - # need to use primitive cell for rhombohedra - initial_structure = analyzer.get_primitive_standard_structure() - elif self.lat_type == "o": - # need to use the conventional cell for orthorhombic - initial_structure = analyzer.get_conventional_standard_structure() - self.initial_structure = initial_structure - - def gb_from_parameters( - self, - rotation_axis, - rotation_angle, - expand_times=4, - vacuum_thickness=0.0, - ab_shift: tuple[float, float] = (0, 0), - normal=False, - ratio=None, - plane=None, - max_search=20, - tol_coi=1.0e-8, - rm_ratio=0.7, - quick_gen=False, - ): - """ - Args: - rotation_axis (list): Rotation axis of GB in the form of a list of integer - e.g.: [1, 1, 0] - rotation_angle (float, in unit of degree): rotation angle used to generate GB. - Make sure the angle is accurate enough. You can use the enum* functions - in this class to extract the accurate angle. - e.g.: The rotation angle of sigma 3 twist GB with the rotation axis - [1, 1, 1] and GB plane (1, 1, 1) can be 60 degree. - If you do not know the rotation angle, but know the sigma value, we have - provide the function get_rotation_angle_from_sigma which is able to return - all the rotation angles of sigma value you provided. - expand_times (int): The multiple times used to expand one unit grain to larger grain. - This is used to tune the grain length of GB to warrant that the two GBs in one - cell do not interact with each other. Default set to 4. - vacuum_thickness (float, in angstrom): The thickness of vacuum that you want to insert - between two grains of the GB. Default to 0. - ab_shift (list of float, in unit of a, b vectors of Gb): in plane shift of two grains - normal (logic): - determine if need to require the c axis of top grain (first transformation matrix) - perpendicular to the surface or not. - default to false. - ratio (list of integers): - lattice axial ratio. - For cubic system, ratio is not needed. - For tetragonal system, ratio = [mu, mv], list of two integers, - that is, mu/mv = c2/a2. If it is irrational, set it to none. - For orthorhombic system, ratio = [mu, lam, mv], list of 3 integers, - that is, mu:lam:mv = c2:b2:a2. If irrational for one axis, set it to None. - e.g. mu:lam:mv = c2,None,a2, means b2 is irrational. - For rhombohedral system, ratio = [mu, mv], list of two integers, - that is, mu/mv is the ratio of (1+2*cos(alpha))/cos(alpha). - If irrational, set it to None. - For hexagonal system, ratio = [mu, mv], list of two integers, - that is, mu/mv = c2/a2. If it is irrational, set it to none. - This code also supplies a class method to generate the ratio from the - structure (get_ratio). User can also make their own approximation and - input the ratio directly. - plane (list): Grain boundary plane in the form of a list of integers - e.g.: [1, 2, 3]. If none, we set it as twist GB. The plane will be perpendicular - to the rotation axis. - max_search (int): max search for the GB lattice vectors that give the smallest GB - lattice. If normal is true, also max search the GB c vector that perpendicular - to the plane. For complex GB, if you want to speed up, you can reduce this value. - But too small of this value may lead to error. - tol_coi (float): tolerance to find the coincidence sites. When making approximations to - the ratio needed to generate the GB, you probably need to increase this tolerance to - obtain the correct number of coincidence sites. To check the number of coincidence - sites are correct or not, you can compare the generated Gb object's sigma_from_site_prop - with enum* sigma values (what user expected by input). - rm_ratio (float): the criteria to remove the atoms which are too close with each other. - rm_ratio*bond_length of bulk system is the criteria of bond length, below which the atom - will be removed. Default to 0.7. - quick_gen (bool): whether to quickly generate a supercell, if set to true, no need to - find the smallest cell. - - Returns: - Grain boundary structure (GB object). - """ - lat_type = self.lat_type - # if the initial structure is primitive cell in cubic system, - # calculate the transformation matrix from its conventional cell - # to primitive cell, basically for bcc and fcc systems. - trans_cry = np.eye(3) - if lat_type == "c": - analyzer = SpacegroupAnalyzer(self.initial_structure) - convention_cell = analyzer.get_conventional_standard_structure() - vol_ratio = self.initial_structure.volume / convention_cell.volume - # bcc primitive cell, belong to cubic system - if abs(vol_ratio - 0.5) < 1.0e-3: - trans_cry = np.array([[0.5, 0.5, -0.5], [-0.5, 0.5, 0.5], [0.5, -0.5, 0.5]]) - logger.info("Make sure this is for cubic with bcc primitive cell") - # fcc primitive cell, belong to cubic system - elif abs(vol_ratio - 0.25) < 1.0e-3: - trans_cry = np.array([[0.5, 0.5, 0], [0, 0.5, 0.5], [0.5, 0, 0.5]]) - logger.info("Make sure this is for cubic with fcc primitive cell") - else: - logger.info("Make sure this is for cubic with conventional cell") - elif lat_type == "t": - logger.info("Make sure this is for tetragonal system") - if ratio is None: - logger.info("Make sure this is for irrational c2/a2") - elif len(ratio) != 2: - raise RuntimeError("Tetragonal system needs correct c2/a2 ratio") - elif lat_type == "o": - logger.info("Make sure this is for orthorhombic system") - if ratio is None: - raise RuntimeError("CSL does not exist if all axial ratios are irrational for an orthorhombic system") - if len(ratio) != 3: - raise RuntimeError("Orthorhombic system needs correct c2:b2:a2 ratio") - elif lat_type == "h": - logger.info("Make sure this is for hexagonal system") - if ratio is None: - logger.info("Make sure this is for irrational c2/a2") - elif len(ratio) != 2: - raise RuntimeError("Hexagonal system needs correct c2/a2 ratio") - elif lat_type == "r": - logger.info("Make sure this is for rhombohedral system") - if ratio is None: - logger.info("Make sure this is for irrational (1+2*cos(alpha)/cos(alpha) ratio") - elif len(ratio) != 2: - raise RuntimeError("Rhombohedral system needs correct (1+2*cos(alpha)/cos(alpha) ratio") - else: - raise RuntimeError( - "Lattice type not implemented. This code works for cubic, " - "tetragonal, orthorhombic, rhombohedral, hexagonal systems" - ) - - # transform four index notation to three index notation for hexagonal and rhombohedral - if len(rotation_axis) == 4: - u1 = rotation_axis[0] - v1 = rotation_axis[1] - w1 = rotation_axis[3] - if lat_type.lower() == "h": - u = 2 * u1 + v1 - v = 2 * v1 + u1 - w = w1 - rotation_axis = [u, v, w] - elif lat_type.lower() == "r": - u = 2 * u1 + v1 + w1 - v = v1 + w1 - u1 - w = w1 - 2 * v1 - u1 - rotation_axis = [u, v, w] - - # make sure gcd(rotation_axis)==1 - if reduce(gcd, rotation_axis) != 1: - rotation_axis = [int(round(x / reduce(gcd, rotation_axis))) for x in rotation_axis] - # transform four index notation to three index notation for plane - if plane is not None and len(plane) == 4: - u1, v1, w1 = plane[0], plane[1], plane[3] - plane = [u1, v1, w1] - # set the plane for grain boundary when plane is None. - if plane is None: - if lat_type.lower() == "c": - plane = rotation_axis - else: - if lat_type.lower() == "h": - c2_a2_ratio = 1.0 if ratio is None else ratio[0] / ratio[1] - metric = np.array([[1, -0.5, 0], [-0.5, 1, 0], [0, 0, c2_a2_ratio]]) - elif lat_type.lower() == "r": - cos_alpha = 0.5 if ratio is None else 1.0 / (ratio[0] / ratio[1] - 2) - metric = np.array( - [ - [1, cos_alpha, cos_alpha], - [cos_alpha, 1, cos_alpha], - [cos_alpha, cos_alpha, 1], - ] - ) - elif lat_type.lower() == "t": - c2_a2_ratio = 1.0 if ratio is None else ratio[0] / ratio[1] - metric = np.array([[1, 0, 0], [0, 1, 0], [0, 0, c2_a2_ratio]]) - elif lat_type.lower() == "o": - for idx in range(3): - if ratio[idx] is None: - ratio[idx] = 1 - metric = np.array([[1, 0, 0], [0, ratio[1] / ratio[2], 0], [0, 0, ratio[0] / ratio[2]]]) - else: - raise RuntimeError("Lattice type has not implemented.") - - plane = np.matmul(rotation_axis, metric) - fractions = [Fraction(x).limit_denominator() for x in plane] - least_mul = reduce(lcm, [f.denominator for f in fractions]) - plane = [int(round(x * least_mul)) for x in plane] - - if reduce(gcd, plane) != 1: - index = reduce(gcd, plane) - plane = [int(round(x / index)) for x in plane] - - t1, t2 = self.get_trans_mat( - r_axis=rotation_axis, - angle=rotation_angle, - normal=normal, - trans_cry=trans_cry, - lat_type=lat_type, - ratio=ratio, - surface=plane, - max_search=max_search, - quick_gen=quick_gen, - ) - - # find the join_plane - if lat_type.lower() != "c": - if lat_type.lower() == "h": - if ratio is None: - mu, mv = [1, 1] - else: - mu, mv = ratio - trans_cry1 = np.array([[1, 0, 0], [-0.5, np.sqrt(3.0) / 2.0, 0], [0, 0, np.sqrt(mu / mv)]]) - elif lat_type.lower() == "r": - if ratio is None: - c2_a2_ratio = 1.0 - else: - mu, mv = ratio - c2_a2_ratio = 3 / (2 - 6 * mv / mu) - trans_cry1 = np.array( - [ - [0.5, np.sqrt(3.0) / 6.0, 1.0 / 3 * np.sqrt(c2_a2_ratio)], - [-0.5, np.sqrt(3.0) / 6.0, 1.0 / 3 * np.sqrt(c2_a2_ratio)], - [0, -1 * np.sqrt(3.0) / 3.0, 1.0 / 3 * np.sqrt(c2_a2_ratio)], - ] - ) - else: - if lat_type.lower() == "t": - if ratio is None: - mu, mv = [1, 1] - else: - mu, mv = ratio - lam = mv - elif lat_type.lower() == "o": - new_ratio = [1 if v is None else v for v in ratio] - mu, lam, mv = new_ratio - trans_cry1 = np.array([[1, 0, 0], [0, np.sqrt(lam / mv), 0], [0, 0, np.sqrt(mu / mv)]]) - else: - trans_cry1 = trans_cry - grain_matrix = np.dot(t2, trans_cry1) - plane_init = np.cross(grain_matrix[0], grain_matrix[1]) - if lat_type.lower() != "c": - plane_init = np.dot(plane_init, trans_cry1.T) - join_plane = self.vec_to_surface(plane_init) - - parent_structure = self.initial_structure.copy() - # calculate the bond_length in bulk system. - if len(parent_structure) == 1: - temp_str = parent_structure.copy() - temp_str.make_supercell([1, 1, 2]) - distance = temp_str.distance_matrix - else: - distance = parent_structure.distance_matrix - bond_length = np.min(distance[np.nonzero(distance)]) - - # top grain - top_grain = fix_pbc(parent_structure * t1) - - # obtain the smallest oriented cell - if normal and not quick_gen: - t_temp = self.get_trans_mat( - r_axis=rotation_axis, - angle=rotation_angle, - normal=False, - trans_cry=trans_cry, - lat_type=lat_type, - ratio=ratio, - surface=plane, - max_search=max_search, - ) - oriented_unit_cell = fix_pbc(parent_structure * t_temp[0]) - t_matrix = oriented_unit_cell.lattice.matrix - normal_v_plane = np.cross(t_matrix[0], t_matrix[1]) - unit_normal_v = normal_v_plane / np.linalg.norm(normal_v_plane) - unit_ab_adjust = (t_matrix[2] - np.dot(unit_normal_v, t_matrix[2]) * unit_normal_v) / np.dot( - unit_normal_v, t_matrix[2] - ) - else: - oriented_unit_cell = top_grain.copy() - unit_ab_adjust = 0.0 - - # bottom grain, using top grain's lattice matrix - bottom_grain = fix_pbc(parent_structure * t2, top_grain.lattice.matrix) - - # label both grains with 'top','bottom','top_incident','bottom_incident' - n_sites = len(top_grain) - t_and_b = Structure( - top_grain.lattice, - top_grain.species + bottom_grain.species, - list(top_grain.frac_coords) + list(bottom_grain.frac_coords), - ) - t_and_b_dis = t_and_b.lattice.get_all_distances( - t_and_b.frac_coords[0:n_sites], t_and_b.frac_coords[n_sites : n_sites * 2] - ) - index_incident = np.nonzero(t_and_b_dis < np.min(t_and_b_dis) + tol_coi) - - top_labels = [] - for idx in range(n_sites): - if idx in index_incident[0]: - top_labels.append("top_incident") - else: - top_labels.append("top") - bottom_labels = [] - for idx in range(n_sites): - if idx in index_incident[1]: - bottom_labels.append("bottom_incident") - else: - bottom_labels.append("bottom") - top_grain = Structure( - Lattice(top_grain.lattice.matrix), - top_grain.species, - top_grain.frac_coords, - site_properties={"grain_label": top_labels}, - ) - bottom_grain = Structure( - Lattice(bottom_grain.lattice.matrix), - bottom_grain.species, - bottom_grain.frac_coords, - site_properties={"grain_label": bottom_labels}, - ) - - # expand both grains - top_grain.make_supercell([1, 1, expand_times]) - bottom_grain.make_supercell([1, 1, expand_times]) - top_grain = fix_pbc(top_grain) - bottom_grain = fix_pbc(bottom_grain) - - # determine the top-grain location. - edge_b = 1.0 - max(bottom_grain.frac_coords[:, 2]) - edge_t = 1.0 - max(top_grain.frac_coords[:, 2]) - c_adjust = (edge_t - edge_b) / 2.0 - - # construct all species - all_species = [] - all_species.extend([site.specie for site in bottom_grain]) - all_species.extend([site.specie for site in top_grain]) - - half_lattice = top_grain.lattice - # calculate translation vector, perpendicular to the plane - normal_v_plane = np.cross(half_lattice.matrix[0], half_lattice.matrix[1]) - unit_normal_v = normal_v_plane / np.linalg.norm(normal_v_plane) - translation_v = unit_normal_v * vacuum_thickness - - # construct the final lattice - whole_matrix_no_vac = np.array(half_lattice.matrix) - whole_matrix_no_vac[2] = half_lattice.matrix[2] * 2 - whole_matrix_with_vac = whole_matrix_no_vac.copy() - whole_matrix_with_vac[2] = whole_matrix_no_vac[2] + translation_v * 2 - whole_lat = Lattice(whole_matrix_with_vac) - - # construct the coords, move top grain with translation_v - all_coords = [] - grain_labels = bottom_grain.site_properties["grain_label"] + top_grain.site_properties["grain_label"] - for site in bottom_grain: - all_coords.append(site.coords) - for site in top_grain: - all_coords.append( - site.coords - + half_lattice.matrix[2] * (1 + c_adjust) - + unit_ab_adjust * np.linalg.norm(half_lattice.matrix[2] * (1 + c_adjust)) - + translation_v - + ab_shift[0] * whole_matrix_with_vac[0] - + ab_shift[1] * whole_matrix_with_vac[1] - ) - - gb_with_vac = Structure( - whole_lat, - all_species, - all_coords, - coords_are_cartesian=True, - site_properties={"grain_label": grain_labels}, - ) - # merge closer atoms. extract near GB atoms. - cos_c_norm_plane = np.dot(unit_normal_v, whole_matrix_with_vac[2]) / whole_lat.c - range_c_len = abs(bond_length / cos_c_norm_plane / whole_lat.c) - sites_near_gb = [] - sites_away_gb: list[PeriodicSite] = [] - for site in gb_with_vac: - if ( - site.frac_coords[2] < range_c_len - or site.frac_coords[2] > 1 - range_c_len - or (site.frac_coords[2] > 0.5 - range_c_len and site.frac_coords[2] < 0.5 + range_c_len) - ): - sites_near_gb.append(site) - else: - sites_away_gb.append(site) - if len(sites_near_gb) >= 1: - s_near_gb = Structure.from_sites(sites_near_gb) - s_near_gb.merge_sites(tol=bond_length * rm_ratio, mode="d") - all_sites = sites_away_gb + s_near_gb.sites # type: ignore - gb_with_vac = Structure.from_sites(all_sites) - - # move coordinates into the periodic cell. - gb_with_vac = fix_pbc(gb_with_vac, whole_lat.matrix) - return GrainBoundary( - whole_lat, - gb_with_vac.species, - gb_with_vac.cart_coords, # type: ignore[arg-type] - rotation_axis, - rotation_angle, - plane, - join_plane, - self.initial_structure, - vacuum_thickness, - ab_shift, - site_properties=gb_with_vac.site_properties, - oriented_unit_cell=oriented_unit_cell, - coords_are_cartesian=True, - ) - - def get_ratio(self, max_denominator=5, index_none=None): - """ - find the axial ratio needed for GB generator input. - - Args: - max_denominator (int): the maximum denominator for - the computed ratio, default to be 5. - index_none (int): specify the irrational axis. - 0-a, 1-b, 2-c. Only may be needed for orthorhombic system. - - Returns: - axial ratio needed for GB generator (list of integers). - """ - structure = self.initial_structure - lat_type = self.lat_type - if lat_type in ("t", "h"): - # For tetragonal and hexagonal system, ratio = c2 / a2. - a, _, c = structure.lattice.lengths - if c > a: - frac = Fraction(c**2 / a**2).limit_denominator(max_denominator) - ratio = [frac.numerator, frac.denominator] - else: - frac = Fraction(a**2 / c**2).limit_denominator(max_denominator) - ratio = [frac.denominator, frac.numerator] - elif lat_type == "r": - # For rhombohedral system, ratio = (1 + 2 * cos(alpha)) / cos(alpha). - cos_alpha = cos(structure.lattice.alpha / 180 * np.pi) - frac = Fraction((1 + 2 * cos_alpha) / cos_alpha).limit_denominator(max_denominator) - ratio = [frac.numerator, frac.denominator] - elif lat_type == "o": - # For orthorhombic system, ratio = c2:b2:a2.If irrational for one axis, set it to None. - ratio = [None] * 3 - lat = (structure.lattice.c, structure.lattice.b, structure.lattice.a) - index = [0, 1, 2] - if index_none is None: - min_index = np.argmin(lat) - index.pop(min_index) - frac1 = Fraction(lat[index[0]] ** 2 / lat[min_index] ** 2).limit_denominator(max_denominator) - frac2 = Fraction(lat[index[1]] ** 2 / lat[min_index] ** 2).limit_denominator(max_denominator) - com_lcm = lcm(frac1.denominator, frac2.denominator) - ratio[min_index] = com_lcm - ratio[index[0]] = frac1.numerator * int(round(com_lcm / frac1.denominator)) - ratio[index[1]] = frac2.numerator * int(round(com_lcm / frac2.denominator)) - else: - index.pop(index_none) - if lat[index[0]] > lat[index[1]]: - frac = Fraction(lat[index[0]] ** 2 / lat[index[1]] ** 2).limit_denominator(max_denominator) - ratio[index[0]] = frac.numerator - ratio[index[1]] = frac.denominator - else: - frac = Fraction(lat[index[1]] ** 2 / lat[index[0]] ** 2).limit_denominator(max_denominator) - ratio[index[1]] = frac.numerator - ratio[index[0]] = frac.denominator - elif lat_type == "c": - # Cubic system does not need axial ratio. - return None - else: - raise RuntimeError("Lattice type not implemented.") - return ratio - - @staticmethod - def get_trans_mat( - r_axis, - angle, - normal=False, - trans_cry=None, - lat_type="c", - ratio=None, - surface=None, - max_search=20, - quick_gen=False, - ): - """ - Find the two transformation matrix for each grain from given rotation axis, - GB plane, rotation angle and corresponding ratio (see explanation for ratio - below). - The structure of each grain can be obtained by applying the corresponding - transformation matrix to the conventional cell. - The algorithm for this code is from reference, Acta Cryst, A32,783(1976). - - Args: - r_axis (list of 3 integers, e.g. u, v, w - or 4 integers, e.g. u, v, t, w for hex/rho system only): - the rotation axis of the grain boundary. - angle (float, in unit of degree): the rotation angle of the grain boundary - normal (logic): determine if need to require the c axis of one grain associated with - the first transformation matrix perpendicular to the surface or not. - default to false. - trans_cry (np.array): shape 3x3. If the structure given are primitive cell in cubic system, e.g. - bcc or fcc system, trans_cry is the transformation matrix from its - conventional cell to the primitive cell. - lat_type (str): one character to specify the lattice type. Defaults to 'c' for cubic. - 'c' or 'C': cubic system - 't' or 'T': tetragonal system - 'o' or 'O': orthorhombic system - 'h' or 'H': hexagonal system - 'r' or 'R': rhombohedral system - ratio (list of integers): - lattice axial ratio. - For cubic system, ratio is not needed. - For tetragonal system, ratio = [mu, mv], list of two integers, - that is, mu/mv = c2/a2. If it is irrational, set it to none. - For orthorhombic system, ratio = [mu, lam, mv], list of 3 integers, - that is, mu:lam:mv = c2:b2:a2. If irrational for one axis, set it to None. - e.g. mu:lam:mv = c2,None,a2, means b2 is irrational. - For rhombohedral system, ratio = [mu, mv], list of two integers, - that is, mu/mv is the ratio of (1+2*cos(alpha)/cos(alpha). - If irrational, set it to None. - For hexagonal system, ratio = [mu, mv], list of two integers, - that is, mu/mv = c2/a2. If it is irrational, set it to none. - surface (list of 3 integers, e.g. h, k, l - or 4 integers, e.g. h, k, i, l for hex/rho system only): - the miller index of grain boundary plane, with the format of [h,k,l] - if surface is not given, the default is perpendicular to r_axis, which is - a twist grain boundary. - max_search (int): max search for the GB lattice vectors that give the smallest GB - lattice. If normal is true, also max search the GB c vector that perpendicular - to the plane. - quick_gen (bool): whether to quickly generate a supercell, if set to true, no need to - find the smallest cell. - - Returns: - t1 (3 by 3 integer array): The transformation array for one grain. - t2 (3 by 3 integer array): The transformation array for the other grain - """ - if trans_cry is None: - trans_cry = np.eye(3) - # transform four index notation to three index notation - if len(r_axis) == 4: - u1 = r_axis[0] - v1 = r_axis[1] - w1 = r_axis[3] - if lat_type.lower() == "h": - u = 2 * u1 + v1 - v = 2 * v1 + u1 - w = w1 - r_axis = [u, v, w] - elif lat_type.lower() == "r": - u = 2 * u1 + v1 + w1 - v = v1 + w1 - u1 - w = w1 - 2 * v1 - u1 - r_axis = [u, v, w] - - # make sure gcd(r_axis)==1 - if reduce(gcd, r_axis) != 1: - r_axis = [int(round(x / reduce(gcd, r_axis))) for x in r_axis] - - if surface is not None and len(surface) == 4: - u1 = surface[0] - v1 = surface[1] - w1 = surface[3] - surface = [u1, v1, w1] - # set the surface for grain boundary. - if surface is None: - if lat_type.lower() == "c": - surface = r_axis - else: - if lat_type.lower() == "h": - c2_a2_ratio = 1.0 if ratio is None else ratio[0] / ratio[1] - metric = np.array([[1, -0.5, 0], [-0.5, 1, 0], [0, 0, c2_a2_ratio]]) - elif lat_type.lower() == "r": - cos_alpha = 0.5 if ratio is None else 1.0 / (ratio[0] / ratio[1] - 2) - metric = np.array( - [ - [1, cos_alpha, cos_alpha], - [cos_alpha, 1, cos_alpha], - [cos_alpha, cos_alpha, 1], - ] - ) - elif lat_type.lower() == "t": - c2_a2_ratio = 1.0 if ratio is None else ratio[0] / ratio[1] - metric = np.array([[1, 0, 0], [0, 1, 0], [0, 0, c2_a2_ratio]]) - elif lat_type.lower() == "o": - for idx in range(3): - if ratio[idx] is None: - ratio[idx] = 1 - metric = np.array( - [ - [1, 0, 0], - [0, ratio[1] / ratio[2], 0], - [0, 0, ratio[0] / ratio[2]], - ] - ) - else: - raise RuntimeError("Lattice type has not implemented.") - - surface = np.matmul(r_axis, metric) - fractions = [Fraction(x).limit_denominator() for x in surface] - least_mul = reduce(lcm, [f.denominator for f in fractions]) - surface = [int(round(x * least_mul)) for x in surface] - - if reduce(gcd, surface) != 1: - index = reduce(gcd, surface) - surface = [int(round(x / index)) for x in surface] - - if lat_type.lower() == "h": - # set the value for u,v,w,mu,mv,m,n,d,x - # check the reference for the meaning of these parameters - u, v, w = r_axis - # make sure mu, mv are coprime integers. - if ratio is None: - mu, mv = [1, 1] - if w != 0 and (u != 0 or (v != 0)): - raise RuntimeError("For irrational c2/a2, CSL only exist for [0,0,1] or [u,v,0] and m = 0") - else: - mu, mv = ratio - if gcd(mu, mv) != 1: - temp = gcd(mu, mv) - mu = int(round(mu / temp)) - mv = int(round(mv / temp)) - d = (u**2 + v**2 - u * v) * mv + w**2 * mu - if abs(angle - 180.0) < 1.0e0: - m = 0 - n = 1 - else: - fraction = Fraction( - np.tan(angle / 2 / 180.0 * np.pi) / np.sqrt(float(d) / 3.0 / mu) - ).limit_denominator() - m = fraction.denominator - n = fraction.numerator - - # construct the rotation matrix, check reference for details - r_list = [ - (u**2 * mv - v**2 * mv - w**2 * mu) * n**2 + 2 * w * mu * m * n + 3 * mu * m**2, - (2 * v - u) * u * mv * n**2 - 4 * w * mu * m * n, - 2 * u * w * mu * n**2 + 2 * (2 * v - u) * mu * m * n, - (2 * u - v) * v * mv * n**2 + 4 * w * mu * m * n, - (v**2 * mv - u**2 * mv - w**2 * mu) * n**2 - 2 * w * mu * m * n + 3 * mu * m**2, - 2 * v * w * mu * n**2 - 2 * (2 * u - v) * mu * m * n, - (2 * u - v) * w * mv * n**2 - 3 * v * mv * m * n, - (2 * v - u) * w * mv * n**2 + 3 * u * mv * m * n, - (w**2 * mu - u**2 * mv - v**2 * mv + u * v * mv) * n**2 + 3 * mu * m**2, - ] - m = -1 * m - r_list_inv = [ - (u**2 * mv - v**2 * mv - w**2 * mu) * n**2 + 2 * w * mu * m * n + 3 * mu * m**2, - (2 * v - u) * u * mv * n**2 - 4 * w * mu * m * n, - 2 * u * w * mu * n**2 + 2 * (2 * v - u) * mu * m * n, - (2 * u - v) * v * mv * n**2 + 4 * w * mu * m * n, - (v**2 * mv - u**2 * mv - w**2 * mu) * n**2 - 2 * w * mu * m * n + 3 * mu * m**2, - 2 * v * w * mu * n**2 - 2 * (2 * u - v) * mu * m * n, - (2 * u - v) * w * mv * n**2 - 3 * v * mv * m * n, - (2 * v - u) * w * mv * n**2 + 3 * u * mv * m * n, - (w**2 * mu - u**2 * mv - v**2 * mv + u * v * mv) * n**2 + 3 * mu * m**2, - ] - m = -1 * m - F = 3 * mu * m**2 + d * n**2 - all_list = r_list + r_list_inv + [F] - com_fac = reduce(gcd, all_list) - sigma = F / com_fac - r_matrix = (np.array(r_list) / com_fac / sigma).reshape(3, 3) - elif lat_type.lower() == "r": - # set the value for u,v,w,mu,mv,m,n,d - # check the reference for the meaning of these parameters - u, v, w = r_axis - # make sure mu, mv are coprime integers. - if ratio is None: - mu, mv = [1, 1] - if u + v + w != 0 and (u != v or u != w): - raise RuntimeError( - "For irrational ratio_alpha, CSL only exist for [1,1,1] or [u, v, -(u+v)] and m =0" - ) - else: - mu, mv = ratio - if gcd(mu, mv) != 1: - temp = gcd(mu, mv) - mu = int(round(mu / temp)) - mv = int(round(mv / temp)) - d = (u**2 + v**2 + w**2) * (mu - 2 * mv) + 2 * mv * (v * w + w * u + u * v) - if abs(angle - 180.0) < 1.0e0: - m = 0 - n = 1 - else: - fraction = Fraction(np.tan(angle / 2 / 180.0 * np.pi) / np.sqrt(float(d) / mu)).limit_denominator() - m = fraction.denominator - n = fraction.numerator - - # construct the rotation matrix, check reference for details - r_list = [ - (mu - 2 * mv) * (u**2 - v**2 - w**2) * n**2 - + 2 * mv * (v - w) * m * n - - 2 * mv * v * w * n**2 - + mu * m**2, - 2 * (mv * u * n * (w * n + u * n - m) - (mu - mv) * m * w * n + (mu - 2 * mv) * u * v * n**2), - 2 * (mv * u * n * (v * n + u * n + m) + (mu - mv) * m * v * n + (mu - 2 * mv) * w * u * n**2), - 2 * (mv * v * n * (w * n + v * n + m) + (mu - mv) * m * w * n + (mu - 2 * mv) * u * v * n**2), - (mu - 2 * mv) * (v**2 - w**2 - u**2) * n**2 - + 2 * mv * (w - u) * m * n - - 2 * mv * u * w * n**2 - + mu * m**2, - 2 * (mv * v * n * (v * n + u * n - m) - (mu - mv) * m * u * n + (mu - 2 * mv) * w * v * n**2), - 2 * (mv * w * n * (w * n + v * n - m) - (mu - mv) * m * v * n + (mu - 2 * mv) * w * u * n**2), - 2 * (mv * w * n * (w * n + u * n + m) + (mu - mv) * m * u * n + (mu - 2 * mv) * w * v * n**2), - (mu - 2 * mv) * (w**2 - u**2 - v**2) * n**2 - + 2 * mv * (u - v) * m * n - - 2 * mv * u * v * n**2 - + mu * m**2, - ] - m = -1 * m - r_list_inv = [ - (mu - 2 * mv) * (u**2 - v**2 - w**2) * n**2 - + 2 * mv * (v - w) * m * n - - 2 * mv * v * w * n**2 - + mu * m**2, - 2 * (mv * u * n * (w * n + u * n - m) - (mu - mv) * m * w * n + (mu - 2 * mv) * u * v * n**2), - 2 * (mv * u * n * (v * n + u * n + m) + (mu - mv) * m * v * n + (mu - 2 * mv) * w * u * n**2), - 2 * (mv * v * n * (w * n + v * n + m) + (mu - mv) * m * w * n + (mu - 2 * mv) * u * v * n**2), - (mu - 2 * mv) * (v**2 - w**2 - u**2) * n**2 - + 2 * mv * (w - u) * m * n - - 2 * mv * u * w * n**2 - + mu * m**2, - 2 * (mv * v * n * (v * n + u * n - m) - (mu - mv) * m * u * n + (mu - 2 * mv) * w * v * n**2), - 2 * (mv * w * n * (w * n + v * n - m) - (mu - mv) * m * v * n + (mu - 2 * mv) * w * u * n**2), - 2 * (mv * w * n * (w * n + u * n + m) + (mu - mv) * m * u * n + (mu - 2 * mv) * w * v * n**2), - (mu - 2 * mv) * (w**2 - u**2 - v**2) * n**2 - + 2 * mv * (u - v) * m * n - - 2 * mv * u * v * n**2 - + mu * m**2, - ] - m = -1 * m - F = mu * m**2 + d * n**2 - all_list = r_list_inv + r_list + [F] - com_fac = reduce(gcd, all_list) - sigma = F / com_fac - r_matrix = (np.array(r_list) / com_fac / sigma).reshape(3, 3) - else: - u, v, w = r_axis - if lat_type.lower() == "c": - mu = 1 - lam = 1 - mv = 1 - elif lat_type.lower() == "t": - if ratio is None: - mu, mv = [1, 1] - if w != 0 and (u != 0 or (v != 0)): - raise RuntimeError("For irrational c2/a2, CSL only exist for [0,0,1] or [u,v,0] and m = 0") - else: - mu, mv = ratio - lam = mv - elif lat_type.lower() == "o": - if None in ratio: - mu, lam, mv = ratio - non_none = [i for i in ratio if i is not None] - if len(non_none) < 2: - raise RuntimeError("No CSL exist for two irrational numbers") - non1, non2 = non_none - if mu is None: - lam = non1 - mv = non2 - mu = 1 - if w != 0 and (u != 0 or (v != 0)): - raise RuntimeError("For irrational c2, CSL only exist for [0,0,1] or [u,v,0] and m = 0") - elif lam is None: - mu = non1 - mv = non2 - lam = 1 - if v != 0 and (u != 0 or (w != 0)): - raise RuntimeError("For irrational b2, CSL only exist for [0,1,0] or [u,0,w] and m = 0") - elif mv is None: - mu = non1 - lam = non2 - mv = 1 - if u != 0 and (w != 0 or (v != 0)): - raise RuntimeError("For irrational a2, CSL only exist for [1,0,0] or [0,v,w] and m = 0") - else: - mu, lam, mv = ratio - if u == 0 and v == 0: - mu = 1 - if u == 0 and w == 0: - lam = 1 - if v == 0 and w == 0: - mv = 1 - - # make sure mu, lambda, mv are coprime integers. - if reduce(gcd, [mu, lam, mv]) != 1: - temp = reduce(gcd, [mu, lam, mv]) - mu = int(round(mu / temp)) - mv = int(round(mv / temp)) - lam = int(round(lam / temp)) - d = (mv * u**2 + lam * v**2) * mv + w**2 * mu * mv - if abs(angle - 180.0) < 1.0e0: - m = 0 - n = 1 - else: - fraction = Fraction(np.tan(angle / 2 / 180.0 * np.pi) / np.sqrt(d / mu / lam)).limit_denominator() - m = fraction.denominator - n = fraction.numerator - r_list = [ - (u**2 * mv * mv - lam * v**2 * mv - w**2 * mu * mv) * n**2 + lam * mu * m**2, - 2 * lam * (v * u * mv * n**2 - w * mu * m * n), - 2 * mu * (u * w * mv * n**2 + v * lam * m * n), - 2 * mv * (u * v * mv * n**2 + w * mu * m * n), - (v**2 * mv * lam - u**2 * mv * mv - w**2 * mu * mv) * n**2 + lam * mu * m**2, - 2 * mv * mu * (v * w * n**2 - u * m * n), - 2 * mv * (u * w * mv * n**2 - v * lam * m * n), - 2 * lam * mv * (v * w * n**2 + u * m * n), - (w**2 * mu * mv - u**2 * mv * mv - v**2 * mv * lam) * n**2 + lam * mu * m**2, - ] - m = -1 * m - r_list_inv = [ - (u**2 * mv * mv - lam * v**2 * mv - w**2 * mu * mv) * n**2 + lam * mu * m**2, - 2 * lam * (v * u * mv * n**2 - w * mu * m * n), - 2 * mu * (u * w * mv * n**2 + v * lam * m * n), - 2 * mv * (u * v * mv * n**2 + w * mu * m * n), - (v**2 * mv * lam - u**2 * mv * mv - w**2 * mu * mv) * n**2 + lam * mu * m**2, - 2 * mv * mu * (v * w * n**2 - u * m * n), - 2 * mv * (u * w * mv * n**2 - v * lam * m * n), - 2 * lam * mv * (v * w * n**2 + u * m * n), - (w**2 * mu * mv - u**2 * mv * mv - v**2 * mv * lam) * n**2 + lam * mu * m**2, - ] - m = -1 * m - F = mu * lam * m**2 + d * n**2 - all_list = r_list + r_list_inv + [F] - com_fac = reduce(gcd, all_list) - sigma = F / com_fac - r_matrix = (np.array(r_list) / com_fac / sigma).reshape(3, 3) - - if sigma > 1000: - raise RuntimeError("Sigma >1000 too large. Are you sure what you are doing, Please check the GB if exist") - # transform surface, r_axis, r_matrix in terms of primitive lattice - surface = np.matmul(surface, np.transpose(trans_cry)) - fractions = [Fraction(x).limit_denominator() for x in surface] - least_mul = reduce(lcm, [f.denominator for f in fractions]) - surface = [int(round(x * least_mul)) for x in surface] - if reduce(gcd, surface) != 1: - index = reduce(gcd, surface) - surface = [int(round(x / index)) for x in surface] - r_axis = np.rint(np.matmul(r_axis, np.linalg.inv(trans_cry))).astype(int) - if reduce(gcd, r_axis) != 1: - r_axis = [int(round(x / reduce(gcd, r_axis))) for x in r_axis] - r_matrix = np.dot(np.dot(np.linalg.inv(trans_cry.T), r_matrix), trans_cry.T) - # set one vector of the basis to the rotation axis direction, and - # obtain the corresponding transform matrix - eye = np.eye(3, dtype=int) - for hh in range(3): - if abs(r_axis[hh]) != 0: - eye[hh] = np.array(r_axis) - kk = hh + 1 if hh + 1 < 3 else abs(2 - hh) - ll = hh + 2 if hh + 2 < 3 else abs(1 - hh) - break - trans = eye.T - new_rot = np.array(r_matrix) - - # with the rotation matrix to construct the CSL lattice, check reference for details - fractions = [Fraction(x).limit_denominator() for x in new_rot[:, kk]] - least_mul = reduce(lcm, [f.denominator for f in fractions]) - scale = np.zeros((3, 3)) - scale[hh, hh] = 1 - scale[kk, kk] = least_mul - scale[ll, ll] = sigma / least_mul - for idx in range(least_mul): - check_int = idx * new_rot[:, kk] + (sigma / least_mul) * new_rot[:, ll] - if all(np.round(x, 5).is_integer() for x in list(check_int)): - n_final = idx - break - if "n_final" not in locals(): - raise RuntimeError("Something is wrong. Check if this GB exists or not") - scale[kk, ll] = n_final - # each row of mat_csl is the CSL lattice vector - csl_init = np.rint(np.dot(np.dot(r_matrix, trans), scale)).astype(int).T - if abs(r_axis[hh]) > 1: - csl_init = GrainBoundaryGenerator.reduce_mat(np.array(csl_init), r_axis[hh], r_matrix) - csl = np.rint(Lattice(csl_init).get_niggli_reduced_lattice().matrix).astype(int) - - # find the best slab supercell in terms of the conventional cell from the csl lattice, - # which is the transformation matrix - - # now trans_cry is the transformation matrix from crystal to Cartesian coordinates. - # for cubic, do not need to change. - if lat_type.lower() != "c": - if lat_type.lower() == "h": - trans_cry = np.array([[1, 0, 0], [-0.5, np.sqrt(3.0) / 2.0, 0], [0, 0, np.sqrt(mu / mv)]]) - elif lat_type.lower() == "r": - c2_a2_ratio = 1.0 if ratio is None else 3.0 / (2 - 6 * mv / mu) - trans_cry = np.array( - [ - [0.5, np.sqrt(3.0) / 6.0, 1.0 / 3 * np.sqrt(c2_a2_ratio)], - [-0.5, np.sqrt(3.0) / 6.0, 1.0 / 3 * np.sqrt(c2_a2_ratio)], - [0, -1 * np.sqrt(3.0) / 3.0, 1.0 / 3 * np.sqrt(c2_a2_ratio)], - ] - ) - else: - trans_cry = np.array([[1, 0, 0], [0, np.sqrt(lam / mv), 0], [0, 0, np.sqrt(mu / mv)]]) - t1_final = GrainBoundaryGenerator.slab_from_csl( - csl, surface, normal, trans_cry, max_search=max_search, quick_gen=quick_gen - ) - t2_final = np.array(np.rint(np.dot(t1_final, np.linalg.inv(r_matrix.T)))).astype(int) - return t1_final, t2_final - - @staticmethod - def enum_sigma_cubic(cutoff, r_axis): - """ - Find all possible sigma values and corresponding rotation angles - within a sigma value cutoff with known rotation axis in cubic system. - The algorithm for this code is from reference, Acta Cryst, A40,108(1984). - - Args: - cutoff (int): the cutoff of sigma values. - r_axis (list of 3 integers, e.g. u, v, w): - the rotation axis of the grain boundary, with the format of [u,v,w]. - - Returns: - dict: sigmas dictionary with keys as the possible integer sigma values - and values as list of the possible rotation angles to the - corresponding sigma values. e.g. the format as - {sigma1: [angle11,angle12,...], sigma2: [angle21, angle22,...],...} - Note: the angles are the rotation angles of one grain respect to - the other grain. - When generating the microstructures of the grain boundary using these angles, - you need to analyze the symmetry of the structure. Different angles may - result in equivalent microstructures. - """ - sigmas = {} - # make sure gcd(r_axis)==1 - if reduce(gcd, r_axis) != 1: - r_axis = [int(round(x / reduce(gcd, r_axis))) for x in r_axis] - - # count the number of odds in r_axis - odd_r = len(list(filter(lambda x: x % 2 == 1, r_axis))) - # Compute the max n we need to enumerate. - if odd_r == 3: - a_max = 4 - elif odd_r == 0: - a_max = 1 - else: - a_max = 2 - n_max = int(np.sqrt(cutoff * a_max / sum(np.array(r_axis) ** 2))) - # enumerate all possible n, m to give possible sigmas within the cutoff. - for n_loop in range(1, n_max + 1): - n = n_loop - m_max = int(np.sqrt(cutoff * a_max - n**2 * sum(np.array(r_axis) ** 2))) - for m in range(m_max + 1): - if gcd(m, n) == 1 or m == 0: - n = 1 if m == 0 else n_loop - # construct the quadruple [m, U,V,W], count the number of odds in - # quadruple to determine the parameter a, refer to the reference - quadruple = [m] + [x * n for x in r_axis] - odd_qua = len(list(filter(lambda x: x % 2 == 1, quadruple))) - if odd_qua == 4: - a = 4 - elif odd_qua == 2: - a = 2 - else: - a = 1 - sigma = int(round((m**2 + n**2 * sum(np.array(r_axis) ** 2)) / a)) - if 1 < sigma <= cutoff: - if sigma not in list(sigmas): - if m == 0: - angle = 180.0 - else: - angle = 2 * np.arctan(n * np.sqrt(sum(np.array(r_axis) ** 2)) / m) / np.pi * 180 - sigmas[sigma] = [angle] - else: - if m == 0: - angle = 180.0 - else: - angle = 2 * np.arctan(n * np.sqrt(sum(np.array(r_axis) ** 2)) / m) / np.pi * 180 - if angle not in sigmas[sigma]: - sigmas[sigma].append(angle) - return sigmas - - @staticmethod - def enum_sigma_hex(cutoff, r_axis, c2_a2_ratio): - """ - Find all possible sigma values and corresponding rotation angles - within a sigma value cutoff with known rotation axis in hexagonal system. - The algorithm for this code is from reference, Acta Cryst, A38,550(1982). - - Args: - cutoff (int): the cutoff of sigma values. - r_axis (list of 3 integers, e.g. u, v, w - or 4 integers, e.g. u, v, t, w): - the rotation axis of the grain boundary. - c2_a2_ratio (list of two integers, e.g. mu, mv): - mu/mv is the square of the hexagonal axial ratio, which is rational - number. If irrational, set c2_a2_ratio = None - - Returns: - sigmas (dict): - dictionary with keys as the possible integer sigma values - and values as list of the possible rotation angles to the - corresponding sigma values. - e.g. the format as - {sigma1: [angle11,angle12,...], sigma2: [angle21, angle22,...],...} - Note: the angles are the rotation angle of one grain respect to the - other grain. - When generating the microstructure of the grain boundary using these - angles, you need to analyze the symmetry of the structure. Different - angles may result in equivalent microstructures. - """ - sigmas = {} - # make sure gcd(r_axis)==1 - if reduce(gcd, r_axis) != 1: - r_axis = [int(round(x / reduce(gcd, r_axis))) for x in r_axis] - # transform four index notation to three index notation - if len(r_axis) == 4: - u1 = r_axis[0] - v1 = r_axis[1] - w1 = r_axis[3] - u = 2 * u1 + v1 - v = 2 * v1 + u1 - w = w1 - else: - u, v, w = r_axis - - # make sure mu, mv are coprime integers. - if c2_a2_ratio is None: - mu, mv = [1, 1] - if w != 0 and (u != 0 or (v != 0)): - raise RuntimeError("For irrational c2/a2, CSL only exist for [0,0,1] or [u,v,0] and m = 0") - else: - mu, mv = c2_a2_ratio - if gcd(mu, mv) != 1: - temp = gcd(mu, mv) - mu = int(round(mu / temp)) - mv = int(round(mv / temp)) - - # refer to the meaning of d in reference - d = (u**2 + v**2 - u * v) * mv + w**2 * mu - - # Compute the max n we need to enumerate. - n_max = int(np.sqrt((cutoff * 12 * mu * mv) / abs(d))) - - # Enumerate all possible n, m to give possible sigmas within the cutoff. - for n in range(1, n_max + 1): - if (c2_a2_ratio is None) and w == 0: - m_max = 0 - else: - m_max = int(np.sqrt((cutoff * 12 * mu * mv - n**2 * d) / (3 * mu))) - for m in range(m_max + 1): - if gcd(m, n) == 1 or m == 0: - # construct the rotation matrix, refer to the reference - R_list = [ - (u**2 * mv - v**2 * mv - w**2 * mu) * n**2 + 2 * w * mu * m * n + 3 * mu * m**2, - (2 * v - u) * u * mv * n**2 - 4 * w * mu * m * n, - 2 * u * w * mu * n**2 + 2 * (2 * v - u) * mu * m * n, - (2 * u - v) * v * mv * n**2 + 4 * w * mu * m * n, - (v**2 * mv - u**2 * mv - w**2 * mu) * n**2 - 2 * w * mu * m * n + 3 * mu * m**2, - 2 * v * w * mu * n**2 - 2 * (2 * u - v) * mu * m * n, - (2 * u - v) * w * mv * n**2 - 3 * v * mv * m * n, - (2 * v - u) * w * mv * n**2 + 3 * u * mv * m * n, - (w**2 * mu - u**2 * mv - v**2 * mv + u * v * mv) * n**2 + 3 * mu * m**2, - ] - m = -1 * m - # inverse of the rotation matrix - R_list_inv = [ - (u**2 * mv - v**2 * mv - w**2 * mu) * n**2 + 2 * w * mu * m * n + 3 * mu * m**2, - (2 * v - u) * u * mv * n**2 - 4 * w * mu * m * n, - 2 * u * w * mu * n**2 + 2 * (2 * v - u) * mu * m * n, - (2 * u - v) * v * mv * n**2 + 4 * w * mu * m * n, - (v**2 * mv - u**2 * mv - w**2 * mu) * n**2 - 2 * w * mu * m * n + 3 * mu * m**2, - 2 * v * w * mu * n**2 - 2 * (2 * u - v) * mu * m * n, - (2 * u - v) * w * mv * n**2 - 3 * v * mv * m * n, - (2 * v - u) * w * mv * n**2 + 3 * u * mv * m * n, - (w**2 * mu - u**2 * mv - v**2 * mv + u * v * mv) * n**2 + 3 * mu * m**2, - ] - m = -1 * m - F = 3 * mu * m**2 + d * n**2 - all_list = R_list_inv + R_list + [F] - # Compute the max common factors for the elements of the rotation matrix - # and its inverse. - com_fac = reduce(gcd, all_list) - sigma = int(round((3 * mu * m**2 + d * n**2) / com_fac)) - if 1 < sigma <= cutoff: - if sigma not in list(sigmas): - angle = 180.0 if m == 0 else 2 * np.arctan(n / m * np.sqrt(d / 3.0 / mu)) / np.pi * 180 - sigmas[sigma] = [angle] - else: - angle = 180.0 if m == 0 else 2 * np.arctan(n / m * np.sqrt(d / 3.0 / mu)) / np.pi * 180 - if angle not in sigmas[sigma]: - sigmas[sigma].append(angle) - if m_max == 0: - break - return sigmas - - @staticmethod - def enum_sigma_rho(cutoff, r_axis, ratio_alpha): - """ - Find all possible sigma values and corresponding rotation angles - within a sigma value cutoff with known rotation axis in rhombohedral system. - The algorithm for this code is from reference, Acta Cryst, A45,505(1989). - - Args: - cutoff (int): the cutoff of sigma values. - r_axis (list[int]): of 3 integers, e.g. u, v, w - or 4 integers, e.g. u, v, t, w): - the rotation axis of the grain boundary, with the format of [u,v,w] - or Weber indices [u, v, t, w]. - ratio_alpha (list of two integers, e.g. mu, mv): - mu/mv is the ratio of (1+2*cos(alpha))/cos(alpha) with rational number. - If irrational, set ratio_alpha = None. - - Returns: - sigmas (dict): - dictionary with keys as the possible integer sigma values - and values as list of the possible rotation angles to the - corresponding sigma values. - e.g. the format as - {sigma1: [angle11,angle12,...], sigma2: [angle21, angle22,...],...} - Note: the angles are the rotation angle of one grain respect to the - other grain. - When generating the microstructure of the grain boundary using these - angles, you need to analyze the symmetry of the structure. Different - angles may result in equivalent microstructures. - """ - sigmas = {} - # transform four index notation to three index notation - if len(r_axis) == 4: - u1 = r_axis[0] - v1 = r_axis[1] - w1 = r_axis[3] - u = 2 * u1 + v1 + w1 - v = v1 + w1 - u1 - w = w1 - 2 * v1 - u1 - r_axis = [u, v, w] - # make sure gcd(r_axis)==1 - if reduce(gcd, r_axis) != 1: - r_axis = [int(round(x / reduce(gcd, r_axis))) for x in r_axis] - u, v, w = r_axis - # make sure mu, mv are coprime integers. - if ratio_alpha is None: - mu, mv = [1, 1] - if u + v + w != 0 and (u != v or u != w): - raise RuntimeError("For irrational ratio_alpha, CSL only exist for [1,1,1] or [u, v, -(u+v)] and m =0") - else: - mu, mv = ratio_alpha - if gcd(mu, mv) != 1: - temp = gcd(mu, mv) - mu = int(round(mu / temp)) - mv = int(round(mv / temp)) - - # refer to the meaning of d in reference - d = (u**2 + v**2 + w**2) * (mu - 2 * mv) + 2 * mv * (v * w + w * u + u * v) - # Compute the max n we need to enumerate. - n_max = int(np.sqrt((cutoff * abs(4 * mu * (mu - 3 * mv))) / abs(d))) - - # Enumerate all possible n, m to give possible sigmas within the cutoff. - for n in range(1, n_max + 1): - if ratio_alpha is None and u + v + w == 0: - m_max = 0 - else: - m_max = int(np.sqrt((cutoff * abs(4 * mu * (mu - 3 * mv)) - n**2 * d) / (mu))) - for m in range(m_max + 1): - if gcd(m, n) == 1 or m == 0: - # construct the rotation matrix, refer to the reference - R_list = [ - (mu - 2 * mv) * (u**2 - v**2 - w**2) * n**2 - + 2 * mv * (v - w) * m * n - - 2 * mv * v * w * n**2 - + mu * m**2, - 2 * (mv * u * n * (w * n + u * n - m) - (mu - mv) * m * w * n + (mu - 2 * mv) * u * v * n**2), - 2 * (mv * u * n * (v * n + u * n + m) + (mu - mv) * m * v * n + (mu - 2 * mv) * w * u * n**2), - 2 * (mv * v * n * (w * n + v * n + m) + (mu - mv) * m * w * n + (mu - 2 * mv) * u * v * n**2), - (mu - 2 * mv) * (v**2 - w**2 - u**2) * n**2 - + 2 * mv * (w - u) * m * n - - 2 * mv * u * w * n**2 - + mu * m**2, - 2 * (mv * v * n * (v * n + u * n - m) - (mu - mv) * m * u * n + (mu - 2 * mv) * w * v * n**2), - 2 * (mv * w * n * (w * n + v * n - m) - (mu - mv) * m * v * n + (mu - 2 * mv) * w * u * n**2), - 2 * (mv * w * n * (w * n + u * n + m) + (mu - mv) * m * u * n + (mu - 2 * mv) * w * v * n**2), - (mu - 2 * mv) * (w**2 - u**2 - v**2) * n**2 - + 2 * mv * (u - v) * m * n - - 2 * mv * u * v * n**2 - + mu * m**2, - ] - m = -1 * m - # inverse of the rotation matrix - R_list_inv = [ - (mu - 2 * mv) * (u**2 - v**2 - w**2) * n**2 - + 2 * mv * (v - w) * m * n - - 2 * mv * v * w * n**2 - + mu * m**2, - 2 * (mv * u * n * (w * n + u * n - m) - (mu - mv) * m * w * n + (mu - 2 * mv) * u * v * n**2), - 2 * (mv * u * n * (v * n + u * n + m) + (mu - mv) * m * v * n + (mu - 2 * mv) * w * u * n**2), - 2 * (mv * v * n * (w * n + v * n + m) + (mu - mv) * m * w * n + (mu - 2 * mv) * u * v * n**2), - (mu - 2 * mv) * (v**2 - w**2 - u**2) * n**2 - + 2 * mv * (w - u) * m * n - - 2 * mv * u * w * n**2 - + mu * m**2, - 2 * (mv * v * n * (v * n + u * n - m) - (mu - mv) * m * u * n + (mu - 2 * mv) * w * v * n**2), - 2 * (mv * w * n * (w * n + v * n - m) - (mu - mv) * m * v * n + (mu - 2 * mv) * w * u * n**2), - 2 * (mv * w * n * (w * n + u * n + m) + (mu - mv) * m * u * n + (mu - 2 * mv) * w * v * n**2), - (mu - 2 * mv) * (w**2 - u**2 - v**2) * n**2 - + 2 * mv * (u - v) * m * n - - 2 * mv * u * v * n**2 - + mu * m**2, - ] - m = -1 * m - F = mu * m**2 + d * n**2 - all_list = R_list_inv + R_list + [F] - # Compute the max common factors for the elements of the rotation matrix and its inverse. - com_fac = reduce(gcd, all_list) - sigma = int(round(abs(F / com_fac))) - if 1 < sigma <= cutoff: - if sigma not in list(sigmas): - angle = 180.0 if m == 0 else 2 * np.arctan(n / m * np.sqrt(d / mu)) / np.pi * 180 - sigmas[sigma] = [angle] - else: - angle = 180 if m == 0 else 2 * np.arctan(n / m * np.sqrt(d / mu)) / np.pi * 180.0 - if angle not in sigmas[sigma]: - sigmas[sigma].append(angle) - if m_max == 0: - break - return sigmas - - @staticmethod - def enum_sigma_tet(cutoff, r_axis, c2_a2_ratio): - """ - Find all possible sigma values and corresponding rotation angles - within a sigma value cutoff with known rotation axis in tetragonal system. - The algorithm for this code is from reference, Acta Cryst, B46,117(1990). - - Args: - cutoff (int): the cutoff of sigma values. - r_axis (list of 3 integers, e.g. u, v, w): - the rotation axis of the grain boundary, with the format of [u,v,w]. - c2_a2_ratio (list of two integers, e.g. mu, mv): - mu/mv is the square of the tetragonal axial ratio with rational number. - if irrational, set c2_a2_ratio = None - - Returns: - dict: sigmas dictionary with keys as the possible integer sigma values - and values as list of the possible rotation angles to the - corresponding sigma values. e.g. the format as - {sigma1: [angle11,angle12,...], sigma2: [angle21, angle22,...],...} - Note: the angles are the rotation angle of one grain respect to the - other grain. - When generating the microstructure of the grain boundary using these - angles, you need to analyze the symmetry of the structure. Different - angles may result in equivalent microstructures. - """ - sigmas = {} - # make sure gcd(r_axis)==1 - if reduce(gcd, r_axis) != 1: - r_axis = [int(round(x / reduce(gcd, r_axis))) for x in r_axis] - - u, v, w = r_axis - - # make sure mu, mv are coprime integers. - if c2_a2_ratio is None: - mu, mv = [1, 1] - if w != 0 and (u != 0 or (v != 0)): - raise RuntimeError("For irrational c2/a2, CSL only exist for [0,0,1] or [u,v,0] and m = 0") - else: - mu, mv = c2_a2_ratio - if gcd(mu, mv) != 1: - temp = gcd(mu, mv) - mu = int(round(mu / temp)) - mv = int(round(mv / temp)) - - # refer to the meaning of d in reference - d = (u**2 + v**2) * mv + w**2 * mu - - # Compute the max n we need to enumerate. - n_max = int(np.sqrt((cutoff * 4 * mu * mv) / d)) - - # Enumerate all possible n, m to give possible sigmas within the cutoff. - for n in range(1, n_max + 1): - m_max = 0 if c2_a2_ratio is None and w == 0 else int(np.sqrt((cutoff * 4 * mu * mv - n**2 * d) / mu)) - for m in range(m_max + 1): - if gcd(m, n) == 1 or m == 0: - # construct the rotation matrix, refer to the reference - R_list = [ - (u**2 * mv - v**2 * mv - w**2 * mu) * n**2 + mu * m**2, - 2 * v * u * mv * n**2 - 2 * w * mu * m * n, - 2 * u * w * mu * n**2 + 2 * v * mu * m * n, - 2 * u * v * mv * n**2 + 2 * w * mu * m * n, - (v**2 * mv - u**2 * mv - w**2 * mu) * n**2 + mu * m**2, - 2 * v * w * mu * n**2 - 2 * u * mu * m * n, - 2 * u * w * mv * n**2 - 2 * v * mv * m * n, - 2 * v * w * mv * n**2 + 2 * u * mv * m * n, - (w**2 * mu - u**2 * mv - v**2 * mv) * n**2 + mu * m**2, - ] - m = -1 * m - # inverse of rotation matrix - R_list_inv = [ - (u**2 * mv - v**2 * mv - w**2 * mu) * n**2 + mu * m**2, - 2 * v * u * mv * n**2 - 2 * w * mu * m * n, - 2 * u * w * mu * n**2 + 2 * v * mu * m * n, - 2 * u * v * mv * n**2 + 2 * w * mu * m * n, - (v**2 * mv - u**2 * mv - w**2 * mu) * n**2 + mu * m**2, - 2 * v * w * mu * n**2 - 2 * u * mu * m * n, - 2 * u * w * mv * n**2 - 2 * v * mv * m * n, - 2 * v * w * mv * n**2 + 2 * u * mv * m * n, - (w**2 * mu - u**2 * mv - v**2 * mv) * n**2 + mu * m**2, - ] - m = -1 * m - F = mu * m**2 + d * n**2 - all_list = R_list + R_list_inv + [F] - # Compute the max common factors for the elements of the rotation matrix - # and its inverse. - com_fac = reduce(gcd, all_list) - sigma = int(round((mu * m**2 + d * n**2) / com_fac)) - if 1 < sigma <= cutoff: - if sigma not in list(sigmas): - angle = 180.0 if m == 0 else 2 * np.arctan(n / m * np.sqrt(d / mu)) / np.pi * 180 - sigmas[sigma] = [angle] - else: - angle = 180.0 if m == 0 else 2 * np.arctan(n / m * np.sqrt(d / mu)) / np.pi * 180 - if angle not in sigmas[sigma]: - sigmas[sigma].append(angle) - if m_max == 0: - break - - return sigmas - - @staticmethod - def enum_sigma_ort(cutoff, r_axis, c2_b2_a2_ratio): - """ - Find all possible sigma values and corresponding rotation angles - within a sigma value cutoff with known rotation axis in orthorhombic system. - The algorithm for this code is from reference, Scipta Metallurgica 27, 291(1992). - - Args: - cutoff (int): the cutoff of sigma values. - r_axis (list of 3 integers, e.g. u, v, w): - the rotation axis of the grain boundary, with the format of [u,v,w]. - c2_b2_a2_ratio (list of 3 integers, e.g. mu,lambda, mv): - mu:lam:mv is the square of the orthorhombic axial ratio with rational - numbers. If irrational for one axis, set it to None. - e.g. mu:lam:mv = c2,None,a2, means b2 is irrational. - - Returns: - dict: sigmas dictionary with keys as the possible integer sigma values - and values as list of the possible rotation angles to the - corresponding sigma values. e.g. the format as - {sigma1: [angle11,angle12,...], sigma2: [angle21, angle22,...],...} - Note: the angles are the rotation angle of one grain respect to the - other grain. - When generating the microstructure of the grain boundary using these - angles, you need to analyze the symmetry of the structure. Different - angles may result in equivalent microstructures. - """ - sigmas = {} - # make sure gcd(r_axis)==1 - if reduce(gcd, r_axis) != 1: - r_axis = [int(round(x / reduce(gcd, r_axis))) for x in r_axis] - - u, v, w = r_axis - # make sure mu, lambda, mv are coprime integers. - if None in c2_b2_a2_ratio: - mu, lam, mv = c2_b2_a2_ratio - non_none = [i for i in c2_b2_a2_ratio if i is not None] - if len(non_none) < 2: - raise RuntimeError("No CSL exist for two irrational numbers") - non1, non2 = non_none - if reduce(gcd, non_none) != 1: - temp = reduce(gcd, non_none) - non1 = int(round(non1 / temp)) - non2 = int(round(non2 / temp)) - if mu is None: - lam = non1 - mv = non2 - mu = 1 - if w != 0 and (u != 0 or (v != 0)): - raise RuntimeError("For irrational c2, CSL only exist for [0,0,1] or [u,v,0] and m = 0") - elif lam is None: - mu = non1 - mv = non2 - lam = 1 - if v != 0 and (u != 0 or (w != 0)): - raise RuntimeError("For irrational b2, CSL only exist for [0,1,0] or [u,0,w] and m = 0") - elif mv is None: - mu = non1 - lam = non2 - mv = 1 - if u != 0 and (w != 0 or (v != 0)): - raise RuntimeError("For irrational a2, CSL only exist for [1,0,0] or [0,v,w] and m = 0") - else: - mu, lam, mv = c2_b2_a2_ratio - if reduce(gcd, c2_b2_a2_ratio) != 1: - temp = reduce(gcd, c2_b2_a2_ratio) - mu = int(round(mu / temp)) - mv = int(round(mv / temp)) - lam = int(round(lam / temp)) - if u == 0 and v == 0: - mu = 1 - if u == 0 and w == 0: - lam = 1 - if v == 0 and w == 0: - mv = 1 - # refer to the meaning of d in reference - d = (mv * u**2 + lam * v**2) * mv + w**2 * mu * mv - - # Compute the max n we need to enumerate. - n_max = int(np.sqrt((cutoff * 4 * mu * mv * mv * lam) / d)) - # Enumerate all possible n, m to give possible sigmas within the cutoff. - for n in range(1, n_max + 1): - mu_temp, lam_temp, mv_temp = c2_b2_a2_ratio - if (mu_temp is None and w == 0) or (lam_temp is None and v == 0) or (mv_temp is None and u == 0): - m_max = 0 - else: - m_max = int(np.sqrt((cutoff * 4 * mu * mv * lam * mv - n**2 * d) / mu / lam)) - for m in range(m_max + 1): - if gcd(m, n) == 1 or m == 0: - # construct the rotation matrix, refer to the reference - R_list = [ - (u**2 * mv * mv - lam * v**2 * mv - w**2 * mu * mv) * n**2 + lam * mu * m**2, - 2 * lam * (v * u * mv * n**2 - w * mu * m * n), - 2 * mu * (u * w * mv * n**2 + v * lam * m * n), - 2 * mv * (u * v * mv * n**2 + w * mu * m * n), - (v**2 * mv * lam - u**2 * mv * mv - w**2 * mu * mv) * n**2 + lam * mu * m**2, - 2 * mv * mu * (v * w * n**2 - u * m * n), - 2 * mv * (u * w * mv * n**2 - v * lam * m * n), - 2 * lam * mv * (v * w * n**2 + u * m * n), - (w**2 * mu * mv - u**2 * mv * mv - v**2 * mv * lam) * n**2 + lam * mu * m**2, - ] - m = -1 * m - # inverse of rotation matrix - R_list_inv = [ - (u**2 * mv * mv - lam * v**2 * mv - w**2 * mu * mv) * n**2 + lam * mu * m**2, - 2 * lam * (v * u * mv * n**2 - w * mu * m * n), - 2 * mu * (u * w * mv * n**2 + v * lam * m * n), - 2 * mv * (u * v * mv * n**2 + w * mu * m * n), - (v**2 * mv * lam - u**2 * mv * mv - w**2 * mu * mv) * n**2 + lam * mu * m**2, - 2 * mv * mu * (v * w * n**2 - u * m * n), - 2 * mv * (u * w * mv * n**2 - v * lam * m * n), - 2 * lam * mv * (v * w * n**2 + u * m * n), - (w**2 * mu * mv - u**2 * mv * mv - v**2 * mv * lam) * n**2 + lam * mu * m**2, - ] - m = -1 * m - F = mu * lam * m**2 + d * n**2 - all_list = R_list + R_list_inv + [F] - # Compute the max common factors for the elements of the rotation matrix - # and its inverse. - com_fac = reduce(gcd, all_list) - sigma = int(round((mu * lam * m**2 + d * n**2) / com_fac)) - if 1 < sigma <= cutoff: - if sigma not in list(sigmas): - angle = 180.0 if m == 0 else 2 * np.arctan(n / m * np.sqrt(d / mu / lam)) / np.pi * 180 - sigmas[sigma] = [angle] - else: - angle = 180.0 if m == 0 else 2 * np.arctan(n / m * np.sqrt(d / mu / lam)) / np.pi * 180 - if angle not in sigmas[sigma]: - sigmas[sigma].append(angle) - if m_max == 0: - break - - return sigmas - - @staticmethod - def enum_possible_plane_cubic(plane_cutoff, r_axis, r_angle): - """ - Find all possible plane combinations for GBs given a rotation axis and angle for - cubic system, and classify them to different categories, including 'Twist', - 'Symmetric tilt', 'Normal tilt', 'Mixed' GBs. - - Args: - plane_cutoff (int): the cutoff of plane miller index. - r_axis (list of 3 integers, e.g. u, v, w): - the rotation axis of the grain boundary, with the format of [u,v,w]. - r_angle (float): rotation angle of the GBs. - - Returns: - dict: all combinations with keys as GB type, e.g. 'Twist','Symmetric tilt',etc. - and values as the combination of the two plane miller index (GB plane and joining plane). - """ - all_combinations = {} - all_combinations["Symmetric tilt"] = [] - all_combinations["Twist"] = [] - all_combinations["Normal tilt"] = [] - all_combinations["Mixed"] = [] - sym_plane = symm_group_cubic([[1, 0, 0], [1, 1, 0]]) - j = np.arange(0, plane_cutoff + 1) - combination = [] - for idx in itertools.product(j, repeat=3): - if sum(abs(np.array(idx))) != 0: - combination.append(list(idx)) - if len(np.nonzero(idx)[0]) == 3: - for i1 in range(3): - new_i = list(idx).copy() - new_i[i1] = -1 * new_i[i1] - combination.append(new_i) - elif len(np.nonzero(idx)[0]) == 2: - new_i = list(idx).copy() - new_i[np.nonzero(idx)[0][0]] = -1 * new_i[np.nonzero(idx)[0][0]] - combination.append(new_i) - miller = np.array(combination) - miller = miller[np.argsort(np.linalg.norm(miller, axis=1))] - for val in miller: - if reduce(gcd, val) == 1: - matrix = GrainBoundaryGenerator.get_trans_mat(r_axis, r_angle, surface=val, quick_gen=True) - vec = np.cross(matrix[1][0], matrix[1][1]) - miller2 = GrainBoundaryGenerator.vec_to_surface(vec) - if np.all(np.abs(np.array(miller2)) <= plane_cutoff): - cos_1 = abs(np.dot(val, r_axis) / np.linalg.norm(val) / np.linalg.norm(r_axis)) - if 1 - cos_1 < 1.0e-5: - all_combinations["Twist"].append([list(val), miller2]) - elif cos_1 < 1.0e-8: - sym_tilt = False - if np.sum(np.abs(val)) == np.sum(np.abs(miller2)): - ave = (np.array(val) + np.array(miller2)) / 2 - ave1 = (np.array(val) - np.array(miller2)) / 2 - for plane in sym_plane: - cos_2 = abs(np.dot(ave, plane) / np.linalg.norm(ave) / np.linalg.norm(plane)) - cos_3 = abs(np.dot(ave1, plane) / np.linalg.norm(ave1) / np.linalg.norm(plane)) - if 1 - cos_2 < 1.0e-5 or 1 - cos_3 < 1.0e-5: - all_combinations["Symmetric tilt"].append([list(val), miller2]) - sym_tilt = True - break - if not sym_tilt: - all_combinations["Normal tilt"].append([list(val), miller2]) - else: - all_combinations["Mixed"].append([list(val), miller2]) - return all_combinations - - @staticmethod - def get_rotation_angle_from_sigma(sigma, r_axis, lat_type="C", ratio=None): - """ - Find all possible rotation angle for the given sigma value. - - Args: - sigma (int): sigma value provided - r_axis (list of 3 integers, e.g. u, v, w - or 4 integers, e.g. u, v, t, w for hex/rho system only): - the rotation axis of the grain boundary. - lat_type (str): one character to specify the lattice type. Defaults to 'c' for cubic. - 'c' or 'C': cubic system - 't' or 'T': tetragonal system - 'o' or 'O': orthorhombic system - 'h' or 'H': hexagonal system - 'r' or 'R': rhombohedral system - ratio (list of integers): lattice axial ratio. - For cubic system, ratio is not needed. - For tetragonal system, ratio = [mu, mv], list of two integers, - that is, mu/mv = c2/a2. If it is irrational, set it to none. - For orthorhombic system, ratio = [mu, lam, mv], list of 3 integers, - that is, mu:lam:mv = c2:b2:a2. If irrational for one axis, set it to None. - e.g. mu:lam:mv = c2,None,a2, means b2 is irrational. - For rhombohedral system, ratio = [mu, mv], list of two integers, - that is, mu/mv is the ratio of (1+2*cos(alpha)/cos(alpha). - If irrational, set it to None. - For hexagonal system, ratio = [mu, mv], list of two integers, - that is, mu/mv = c2/a2. If it is irrational, set it to none. - - Returns: - rotation_angles corresponding to the provided sigma value. - If the sigma value is not correct, return the rotation angle corresponding - to the correct possible sigma value right smaller than the wrong sigma value provided. - """ - if lat_type.lower() == "c": - logger.info("Make sure this is for cubic system") - sigma_dict = GrainBoundaryGenerator.enum_sigma_cubic(cutoff=sigma, r_axis=r_axis) - elif lat_type.lower() == "t": - logger.info("Make sure this is for tetragonal system") - if ratio is None: - logger.info("Make sure this is for irrational c2/a2 ratio") - elif len(ratio) != 2: - raise RuntimeError("Tetragonal system needs correct c2/a2 ratio") - sigma_dict = GrainBoundaryGenerator.enum_sigma_tet(cutoff=sigma, r_axis=r_axis, c2_a2_ratio=ratio) - elif lat_type.lower() == "o": - logger.info("Make sure this is for orthorhombic system") - if len(ratio) != 3: - raise RuntimeError("Orthorhombic system needs correct c2:b2:a2 ratio") - sigma_dict = GrainBoundaryGenerator.enum_sigma_ort(cutoff=sigma, r_axis=r_axis, c2_b2_a2_ratio=ratio) - elif lat_type.lower() == "h": - logger.info("Make sure this is for hexagonal system") - if ratio is None: - logger.info("Make sure this is for irrational c2/a2 ratio") - elif len(ratio) != 2: - raise RuntimeError("Hexagonal system needs correct c2/a2 ratio") - sigma_dict = GrainBoundaryGenerator.enum_sigma_hex(cutoff=sigma, r_axis=r_axis, c2_a2_ratio=ratio) - elif lat_type.lower() == "r": - logger.info("Make sure this is for rhombohedral system") - if ratio is None: - logger.info("Make sure this is for irrational (1+2*cos(alpha)/cos(alpha) ratio") - elif len(ratio) != 2: - raise RuntimeError("Rhombohedral system needs correct (1+2*cos(alpha)/cos(alpha) ratio") - sigma_dict = GrainBoundaryGenerator.enum_sigma_rho(cutoff=sigma, r_axis=r_axis, ratio_alpha=ratio) - else: - raise RuntimeError("Lattice type not implemented") - - sigmas = list(sigma_dict) - if not sigmas: - raise RuntimeError("This is a wrong sigma value, and no sigma exists smaller than this value.") - if sigma in sigmas: - rotation_angles = sigma_dict[sigma] - else: - sigmas.sort() - warnings.warn( - "This is not the possible sigma value according to the rotation axis!" - "The nearest neighbor sigma and its corresponding angle are returned" - ) - rotation_angles = sigma_dict[sigmas[-1]] - rotation_angles.sort() - return rotation_angles - - @staticmethod - def slab_from_csl(csl, surface, normal, trans_cry, max_search=20, quick_gen=False): - """ - By linear operation of csl lattice vectors to get the best corresponding - slab lattice. That is the area of a,b vectors (within the surface plane) - is the smallest, the c vector first, has shortest length perpendicular - to surface [h,k,l], second, has shortest length itself. - - Args: - csl (3 by 3 integer array): - input csl lattice. - surface (list of 3 integers, e.g. h, k, l): - the miller index of the surface, with the format of [h,k,l] - normal (logic): - determine if the c vector needs to perpendicular to surface - trans_cry (3 by 3 array): - transform matrix from crystal system to orthogonal system - max_search (int): max search for the GB lattice vectors that give the smallest GB - lattice. If normal is true, also max search the GB c vector that perpendicular - to the plane. - quick_gen (bool): whether to quickly generate a supercell, no need to find the smallest - cell if set to true. - - Returns: - t_matrix: a slab lattice ( 3 by 3 integer array): - """ - # set the transform matrix in real space - trans = trans_cry - # transform matrix in reciprocal space - ctrans = np.linalg.inv(trans.T) - - t_matrix = csl.copy() - # vectors constructed from csl that perpendicular to surface - ab_vector = [] - # obtain the miller index of surface in terms of csl. - miller = np.matmul(surface, csl.T) - if reduce(gcd, miller) != 1: - miller = [int(round(x / reduce(gcd, miller))) for x in miller] - miller_nonzero = [] - # quickly generate a supercell, normal is not work in this way - if quick_gen: - scale_factor = [] - eye = np.eye(3, dtype=int) - for ii, jj in enumerate(miller): - if jj == 0: - scale_factor.append(eye[ii]) - else: - miller_nonzero.append(ii) - if len(scale_factor) < 2: - index_len = len(miller_nonzero) - for ii in range(index_len): - for jj in range(ii + 1, index_len): - lcm_miller = lcm(miller[miller_nonzero[ii]], miller[miller_nonzero[jj]]) - scl_factor = [0, 0, 0] - scl_factor[miller_nonzero[ii]] = -int(round(lcm_miller / miller[miller_nonzero[ii]])) - scl_factor[miller_nonzero[jj]] = int(round(lcm_miller / miller[miller_nonzero[jj]])) - scale_factor.append(scl_factor) - if len(scale_factor) == 2: - break - t_matrix[0] = np.array(np.dot(scale_factor[0], csl)) - t_matrix[1] = np.array(np.dot(scale_factor[1], csl)) - t_matrix[2] = csl[miller_nonzero[0]] - if abs(np.linalg.det(t_matrix)) > 1000: - warnings.warn("Too large matrix. Suggest to use quick_gen=False") - return t_matrix - - for ii, jj in enumerate(miller): - if jj == 0: - ab_vector.append(csl[ii]) - else: - c_index = ii - miller_nonzero.append(jj) - - if len(miller_nonzero) > 1: - t_matrix[2] = csl[c_index] - index_len = len(miller_nonzero) - lcm_miller = [] - for ii in range(index_len): - for jj in range(ii + 1, index_len): - com_gcd = gcd(miller_nonzero[ii], miller_nonzero[jj]) - mil1 = int(round(miller_nonzero[ii] / com_gcd)) - mil2 = int(round(miller_nonzero[jj] / com_gcd)) - lcm_miller.append(max(abs(mil1), abs(mil2))) - lcm_sorted = sorted(lcm_miller) - max_j = lcm_sorted[0] if index_len == 2 else lcm_sorted[1] - else: - if not normal: - t_matrix[0] = ab_vector[0] - t_matrix[1] = ab_vector[1] - t_matrix[2] = csl[c_index] - return t_matrix - max_j = abs(miller_nonzero[0]) - max_j = min(max_j, max_search) - # area of a, b vectors - area = None - # length of c vector - c_norm = np.linalg.norm(np.matmul(t_matrix[2], trans)) - # c vector length along the direction perpendicular to surface - c_length = np.abs(np.dot(t_matrix[2], surface)) - # check if the init c vector perpendicular to the surface - if normal: - c_cross = np.cross(np.matmul(t_matrix[2], trans), np.matmul(surface, ctrans)) - normal_init = np.linalg.norm(c_cross) < 1e-8 - - jj = np.arange(0, max_j + 1) - combination = [] - for ii in itertools.product(jj, repeat=3): - if sum(abs(np.array(ii))) != 0: - combination.append(list(ii)) - if len(np.nonzero(ii)[0]) == 3: - for i1 in range(3): - new_i = list(ii).copy() - new_i[i1] = -1 * new_i[i1] - combination.append(new_i) - elif len(np.nonzero(ii)[0]) == 2: - new_i = list(ii).copy() - new_i[np.nonzero(ii)[0][0]] = -1 * new_i[np.nonzero(ii)[0][0]] - combination.append(new_i) - for ii in combination: - if reduce(gcd, ii) == 1: - temp = np.dot(np.array(ii), csl) - if abs(np.dot(temp, surface) - 0) < 1.0e-8: - ab_vector.append(temp) - else: - # c vector length along the direction perpendicular to surface - c_len_temp = np.abs(np.dot(temp, surface)) - # c vector length itself - c_norm_temp = np.linalg.norm(np.matmul(temp, trans)) - if normal: - c_cross = np.cross(np.matmul(temp, trans), np.matmul(surface, ctrans)) - if np.linalg.norm(c_cross) < 1.0e-8: - if normal_init: - if c_norm_temp < c_norm: - t_matrix[2] = temp - c_norm = c_norm_temp - else: - c_norm = c_norm_temp - normal_init = True - t_matrix[2] = temp - elif c_len_temp < c_length or (abs(c_len_temp - c_length) < 1.0e-8 and c_norm_temp < c_norm): - t_matrix[2] = temp - c_norm = c_norm_temp - c_length = c_len_temp - - if normal and (not normal_init): - logger.info("Did not find the perpendicular c vector, increase max_j") - while not normal_init: - if max_j == max_search: - warnings.warn("Cannot find the perpendicular c vector, please increase max_search") - break - max_j = 3 * max_j - max_j = min(max_j, max_search) - jj = np.arange(0, max_j + 1) - combination = [] - for ii in itertools.product(jj, repeat=3): - if sum(abs(np.array(ii))) != 0: - combination.append(list(ii)) - if len(np.nonzero(ii)[0]) == 3: - for i1 in range(3): - new_i = list(ii).copy() - new_i[i1] = -1 * new_i[i1] - combination.append(new_i) - elif len(np.nonzero(ii)[0]) == 2: - new_i = list(ii).copy() - new_i[np.nonzero(ii)[0][0]] = -1 * new_i[np.nonzero(ii)[0][0]] - combination.append(new_i) - for ii in combination: - if reduce(gcd, ii) == 1: - temp = np.dot(np.array(ii), csl) - if abs(np.dot(temp, surface) - 0) > 1.0e-8: - c_cross = np.cross(np.matmul(temp, trans), np.matmul(surface, ctrans)) - if np.linalg.norm(c_cross) < 1.0e-8: - # c vector length itself - c_norm_temp = np.linalg.norm(np.matmul(temp, trans)) - if normal_init: - if c_norm_temp < c_norm: - t_matrix[2] = temp - c_norm = c_norm_temp - else: - c_norm = c_norm_temp - normal_init = True - t_matrix[2] = temp - if normal_init: - logger.info("Found perpendicular c vector") - - # find the best a, b vectors with their formed area smallest and average norm of a,b smallest. - for ii in itertools.combinations(ab_vector, 2): - area_temp = np.linalg.norm(np.cross(np.matmul(ii[0], trans), np.matmul(ii[1], trans))) - if abs(area_temp - 0) > 1.0e-8: - ab_norm_temp = np.linalg.norm(np.matmul(ii[0], trans)) + np.linalg.norm(np.matmul(ii[1], trans)) - if area is None: - area = area_temp - ab_norm = ab_norm_temp - t_matrix[0] = ii[0] - t_matrix[1] = ii[1] - elif area_temp < area or (abs(area - area_temp) < 1.0e-8 and ab_norm_temp < ab_norm): - t_matrix[0] = ii[0] - t_matrix[1] = ii[1] - area = area_temp - ab_norm = ab_norm_temp - - # make sure we have a left-handed crystallographic system - if np.linalg.det(np.matmul(t_matrix, trans)) < 0: - t_matrix *= -1 - - if normal and abs(np.linalg.det(t_matrix)) > 1000: - warnings.warn("Too large matrix. Suggest to use Normal=False") - return t_matrix - - @staticmethod - def reduce_mat(mat, mag, r_matrix): - """ - Reduce integer array mat's determinant mag times by linear combination - of its row vectors, so that the new array after rotation (r_matrix) is - still an integer array. - - Args: - mat (3 by 3 array): input matrix - mag (int): reduce times for the determinant - r_matrix (3 by 3 array): rotation matrix - Return: - the reduced integer array - """ - max_j = abs(int(round(np.linalg.det(mat) / mag))) - reduced = False - for h in range(3): - kk = h + 1 if h + 1 < 3 else abs(2 - h) - ll = h + 2 if h + 2 < 3 else abs(1 - h) - jj = np.arange(-max_j, max_j + 1) - for j1, j2 in itertools.product(jj, repeat=2): - temp = mat[h] + j1 * mat[kk] + j2 * mat[ll] - if all(np.round(x, 5).is_integer() for x in list(temp / mag)): - mat_copy = mat.copy() - mat_copy[h] = np.array([int(round(ele / mag)) for ele in temp]) - new_mat = np.dot(mat_copy, np.linalg.inv(r_matrix.T)) - if all(np.round(x, 5).is_integer() for x in list(np.ravel(new_mat))): - reduced = True - mat[h] = np.array([int(round(ele / mag)) for ele in temp]) - break - if reduced: - break - - if not reduced: - warnings.warn("Matrix reduction not performed, may lead to non-primitive GB cell.") - return mat - - @staticmethod - def vec_to_surface(vec): - """ - Transform a float vector to a surface miller index with integers. - - Args: - vec (1 by 3 array float vector): input float vector - Return: - the surface miller index of the input vector. - """ - miller = [None] * 3 - index = [] - for idx, value in enumerate(vec): - if abs(value) < 1.0e-8: - miller[idx] = 0 - else: - index.append(idx) - if len(index) == 1: - miller[index[0]] = 1 - else: - min_index = np.argmin([i for i in vec if i != 0]) - true_index = index[min_index] - index.pop(min_index) - frac = [] - for value in index: - frac.append(Fraction(vec[value] / vec[true_index]).limit_denominator(100)) - if len(index) == 1: - miller[true_index] = frac[0].denominator - miller[index[0]] = frac[0].numerator - else: - com_lcm = lcm(frac[0].denominator, frac[1].denominator) - miller[true_index] = com_lcm - miller[index[0]] = frac[0].numerator * int(round(com_lcm / frac[0].denominator)) - miller[index[1]] = frac[1].numerator * int(round(com_lcm / frac[1].denominator)) - return miller - - -def fix_pbc(structure, matrix=None): - """ - Set all frac_coords of the input structure within [0,1]. - - Args: - structure (pymatgen structure object): input structure - matrix (lattice matrix, 3 by 3 array/matrix): new structure's lattice matrix, - If None, use input structure's matrix. - - Return: - new structure with fixed frac_coords and lattice matrix - """ - spec = [] - coords = [] - latte = Lattice(structure.lattice.matrix) if matrix is None else Lattice(matrix) - - for site in structure: - spec.append(site.specie) - coord = np.array(site.frac_coords) - for i in range(3): - coord[i] -= floor(coord[i]) - if np.allclose(coord[i], 1) or np.allclose(coord[i], 0): - coord[i] = 0 - else: - coord[i] = round(coord[i], 7) - coords.append(coord) - - return Structure(latte, spec, coords, site_properties=structure.site_properties) - - -def symm_group_cubic(mat): - """ - obtain cubic symmetric equivalents of the list of vectors. - - Args: - matrix (lattice matrix, n by 3 array/matrix) - Return: - cubic symmetric equivalents of the list of vectors. - """ - sym_group = np.zeros([24, 3, 3]) - sym_group[0, :] = np.eye(3) - sym_group[1, :] = [[1, 0, 0], [0, -1, 0], [0, 0, -1]] - sym_group[2, :] = [[-1, 0, 0], [0, 1, 0], [0, 0, -1]] - sym_group[3, :] = [[-1, 0, 0], [0, -1, 0], [0, 0, 1]] - sym_group[4, :] = [[0, -1, 0], [-1, 0, 0], [0, 0, -1]] - sym_group[5, :] = [[0, -1, 0], [1, 0, 0], [0, 0, 1]] - sym_group[6, :] = [[0, 1, 0], [-1, 0, 0], [0, 0, 1]] - sym_group[7, :] = [[0, 1, 0], [1, 0, 0], [0, 0, -1]] - sym_group[8, :] = [[-1, 0, 0], [0, 0, -1], [0, -1, 0]] - sym_group[9, :] = [[-1, 0, 0], [0, 0, 1], [0, 1, 0]] - sym_group[10, :] = [[1, 0, 0], [0, 0, -1], [0, 1, 0]] - sym_group[11, :] = [[1, 0, 0], [0, 0, 1], [0, -1, 0]] - sym_group[12, :] = [[0, 1, 0], [0, 0, 1], [1, 0, 0]] - sym_group[13, :] = [[0, 1, 0], [0, 0, -1], [-1, 0, 0]] - sym_group[14, :] = [[0, -1, 0], [0, 0, 1], [-1, 0, 0]] - sym_group[15, :] = [[0, -1, 0], [0, 0, -1], [1, 0, 0]] - sym_group[16, :] = [[0, 0, 1], [1, 0, 0], [0, 1, 0]] - sym_group[17, :] = [[0, 0, 1], [-1, 0, 0], [0, -1, 0]] - sym_group[18, :] = [[0, 0, -1], [1, 0, 0], [0, -1, 0]] - sym_group[19, :] = [[0, 0, -1], [-1, 0, 0], [0, 1, 0]] - sym_group[20, :] = [[0, 0, -1], [0, -1, 0], [-1, 0, 0]] - sym_group[21, :] = [[0, 0, -1], [0, 1, 0], [1, 0, 0]] - sym_group[22, :] = [[0, 0, 1], [0, -1, 0], [1, 0, 0]] - sym_group[23, :] = [[0, 0, 1], [0, 1, 0], [-1, 0, 0]] +from pymatgen.core.interface import * # noqa: F403 - mat = np.atleast_2d(mat) - all_vectors = [] - for sym in sym_group: - for vec in mat: - all_vectors.append(np.dot(sym, vec)) - return np.unique(np.array(all_vectors), axis=0) +warnings.warn( + "Grain boundary analysis has been moved to pymatgen.core.interface." + "This stub is retained for backwards compatibility and will be removed Dec 31 2024.", + DeprecationWarning, +) diff --git a/pymatgen/analysis/graphs.py b/pymatgen/analysis/graphs.py index b8d4ec7f88c..974a02572b6 100644 --- a/pymatgen/analysis/graphs.py +++ b/pymatgen/analysis/graphs.py @@ -11,11 +11,12 @@ from itertools import combinations from operator import itemgetter from shutil import which -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable, cast import networkx as nx import networkx.algorithms.isomorphism as iso import numpy as np +from monty.dev import deprecated from monty.json import MSONable from networkx.drawing.nx_agraph import write_dot from networkx.readwrite import json_graph @@ -32,6 +33,16 @@ except ImportError: igraph = None +if TYPE_CHECKING: + from collections.abc import Sequence + + from igraph import Graph + from numpy.typing import ArrayLike + from typing_extensions import Self + + from pymatgen.analysis.local_env import NearNeighbors + from pymatgen.core import Species + logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -46,12 +57,12 @@ ConnectedSite = namedtuple("ConnectedSite", "site, jimage, index, weight, dist") -def _compare(g1, g2, i1, i2): +def _compare(g1, g2, i1, i2) -> bool: """Helper function called by isomorphic to ensure comparison of node identities.""" return g1.vs[i1]["species"] == g2.vs[i2]["species"] -def _igraph_from_nxgraph(graph): +def _igraph_from_nxgraph(graph) -> Graph: """Helper function that converts a networkx graph object into an igraph graph object.""" nodes = graph.nodes(data=True) new_igraph = igraph.Graph() @@ -63,7 +74,7 @@ def _igraph_from_nxgraph(graph): def _isomorphic(frag1: nx.Graph, frag2: nx.Graph) -> bool: """ - Internal function to check if two graph objects are isomorphic, using igraph if + Helper function to check if two graph objects are isomorphic, using igraph if if is available and networkx if it is not. """ f1_nodes = frag1.nodes(data=True) @@ -103,10 +114,10 @@ class StructureGraph(MSONable): any kind of information that connects two Sites. """ - def __init__(self, structure: Structure, graph_data=None): + def __init__(self, structure: Structure, graph_data: dict | None = None) -> None: """ - If constructing this class manually, use the with_empty_graph method or - with_local_env_strategy method (using an algorithm provided by the local_env + If constructing this class manually, use the from_empty_graph method or + from_local_env_strategy method (using an algorithm provided by the local_env module, such as O'Keeffe). This class that contains connection information: relationships between sites represented by a Graph structure, and an associated structure object. @@ -132,24 +143,24 @@ def __init__(self, structure: Structure, graph_data=None): self.graph = nx.readwrite.json_graph.adjacency_graph(graph_data) # tidy up edge attr dicts, reading to/from json duplicates information - for _, _, _, dct in self.graph.edges(keys=True, data=True): + for _, _, _, data in self.graph.edges(keys=True, data=True): for key in ("id", "key"): - dct.pop(key, None) + data.pop(key, None) # ensure images are tuples (conversion to lists happens when serializing back # from json), it's important images are hashable/immutable - if to_img := dct.get("to_jimage"): - dct["to_jimage"] = tuple(to_img) - if from_img := dct.get("from_jimage"): - dct["from_jimage"] = tuple(from_img) + if to_img := data.get("to_jimage"): + data["to_jimage"] = tuple(to_img) + if from_img := data.get("from_jimage"): + data["from_jimage"] = tuple(from_img) @classmethod - def with_empty_graph( + def from_empty_graph( cls, structure: Structure, name: str = "bonds", edge_weight_name: str | None = None, edge_weight_units: str | None = None, - ) -> StructureGraph: + ) -> Self: """ Constructor for an empty StructureGraph, i.e. no edges, containing only nodes corresponding to sites in Structure. @@ -186,23 +197,32 @@ def with_empty_graph( return cls(structure, graph_data=graph_data) - @staticmethod - def with_edges(structure, edges): + @classmethod + @deprecated( + from_empty_graph, + "Deprecated on 2024-03-29, to be removed on 2025-03-20.", + ) + def with_empty_graph(cls, *args, **kwargs): + return cls.from_empty_graph(*args, **kwargs) + + @classmethod + def from_edges(cls, structure: Structure, edges: dict) -> Self: """ Constructor for MoleculeGraph, using pre-existing or pre-defined edges with optional edge parameters. - :param molecule: Molecule object - :param edges: dict representing the bonds of the functional - group (format: {(from_index, to_index, from_image, to_image): props}, - where props is a dictionary of properties, including weight. - Props should be None if no additional properties are to be - specified. + Args: + structure: Structure object + edges: dict representing the bonds of the functional + group (format: {(from_index, to_index, from_image, to_image): props}, + where props is a dictionary of properties, including weight. + Props should be None if no additional properties are to be + specified. Returns: sg, a StructureGraph """ - sg = StructureGraph.with_empty_graph(structure, name="bonds", edge_weight_name="weight", edge_weight_units="") + struct_graph = cls.from_empty_graph(structure, name="bonds", edge_weight_name="weight", edge_weight_units="") for edge, props in edges.items(): try: @@ -221,13 +241,13 @@ def with_edges(structure, edges): else: weight = None - nodes = sg.graph.nodes + nodes = struct_graph.graph.nodes if not (from_index in nodes and to_index in nodes): raise ValueError( "Edges cannot be added if nodes are not present in the graph. Please check your indices." ) - sg.add_edge( + struct_graph.add_edge( from_index, to_index, from_jimage=from_image, @@ -236,26 +256,35 @@ def with_edges(structure, edges): edge_properties=props, ) - sg.set_node_attributes() - return sg + struct_graph.set_node_attributes() + return struct_graph + + @classmethod + @deprecated( + from_edges, + "Deprecated on 2024-03-29, to be removed on 2025-03-20.", + ) + def with_edges(cls, *args, **kwargs): + return cls.from_edges(*args, **kwargs) - @staticmethod - def with_local_env_strategy(structure, strategy, weights=False, edge_properties=False): + @classmethod + def from_local_env_strategy( + cls, structure: Structure, strategy: NearNeighbors, weights: bool = False, edge_properties: bool = False + ) -> Self: """ Constructor for StructureGraph, using a strategy from pymatgen.analysis.local_env. - :param structure: Structure object - :param strategy: an instance of a - pymatgen.analysis.local_env.NearNeighbors object - :param weights: if True, use weights from local_env class - (consult relevant class for their meaning) - :param edge_properties: if True, edge_properties from neighbors will be used + Args: + structure: Structure object + strategy: an instance of a pymatgen.analysis.local_env.NearNeighbors object + weights(bool): if True, use weights from local_env class (consult relevant class for their meaning) + edge_properties(bool): if True, edge_properties from neighbors will be used """ if not strategy.structures_allowed: raise ValueError("Chosen strategy is not designed for use with structures! Please choose another strategy.") - sg = StructureGraph.with_empty_graph(structure, name="bonds") + struct_graph = cls.from_empty_graph(structure, name="bonds") for idx, neighbors in enumerate(strategy.get_all_nn_info(structure)): for neighbor in neighbors: @@ -263,7 +292,7 @@ def with_local_env_strategy(structure, strategy, weights=False, edge_properties= # for any one bond, one from site u to site v # and another form site v to site u: this is # harmless, so warn_duplicates=False - sg.add_edge( + struct_graph.add_edge( from_index=idx, from_jimage=(0, 0, 0), to_index=neighbor["site_index"], @@ -273,15 +302,23 @@ def with_local_env_strategy(structure, strategy, weights=False, edge_properties= warn_duplicates=False, ) - return sg + return struct_graph + + @classmethod + @deprecated( + from_local_env_strategy, + "Deprecated on 2024-03-29, to be removed on 2025-03-20.", + ) + def with_local_env_strategy(cls, *args, **kwargs): + return cls.from_local_env_strategy(*args, **kwargs) @property - def name(self): + def name(self) -> str: """Name of graph""" return self.graph.graph["name"] @property - def edge_weight_name(self): + def edge_weight_name(self) -> str: """Name of the edge weight property of graph""" return self.graph.graph["edge_weight_name"] @@ -292,14 +329,14 @@ def edge_weight_unit(self): def add_edge( self, - from_index, - to_index, - from_jimage=(0, 0, 0), - to_jimage=None, - weight=None, - warn_duplicates=True, - edge_properties=None, - ): + from_index: int, + to_index: int, + from_jimage: tuple[int, int, int] = (0, 0, 0), + to_jimage: tuple[int, int, int] | None = None, + weight: float | None = None, + warn_duplicates: bool = True, + edge_properties: dict | None = None, + ) -> None: """ Add edge to graph. @@ -310,17 +347,18 @@ def add_edge( However, images will always be shifted so that from_index < to_index and from_jimage becomes (0, 0, 0). - :param from_index: index of site connecting from - :param to_index: index of site connecting to - :param from_jimage (tuple of ints): lattice vector of periodic - image, e.g. (1, 0, 0) for periodic image in +x direction - :param to_jimage (tuple of ints): lattice vector of image - :param weight (float): e.g. bond length - :param warn_duplicates (bool): if True, will warn if - trying to add duplicate edges (duplicate edges will not - be added in either case) - :param edge_properties (dict): any other information to - store on graph edges, similar to Structure's site_properties + Args: + from_index: index of site connecting from + to_index: index of site connecting to + from_jimage (tuple of ints): lattice vector of periodic + image, e.g. (1, 0, 0) for periodic image in +x direction + to_jimage (tuple of ints): lattice vector of image + weight (float): e.g. bond length + warn_duplicates (bool): if True, will warn if + trying to add duplicate edges (duplicate edges will not + be added in either case) + edge_properties (dict): any other information to + store on graph edges, similar to Structure's site_properties """ # this is not necessary for the class to work, but # just makes it neater @@ -369,7 +407,7 @@ def add_edge( return # sanitize types - from_jimage, to_jimage = tuple(map(int, from_jimage)), tuple(map(int, to_jimage)) + from_jimage, to_jimage = tuple(map(int, from_jimage)), tuple(map(int, to_jimage)) # type: ignore[assignment] from_index, to_index = int(from_index), int(to_index) # if edge is from site i to site i, constrain direction of edge @@ -386,7 +424,7 @@ def add_edge( if not is_positive: # let's flip the jimage, # e.g. (0, 1, 0) is equivalent to (0, -1, 0) in this case - to_jimage = tuple(-idx for idx in to_jimage) + to_jimage = tuple(-idx for idx in to_jimage) # type: ignore[assignment] # check we're not trying to add a duplicate edge # there should only ever be at most one edge @@ -420,28 +458,29 @@ def add_edge( def insert_node( self, - idx, - species, - coords, - coords_are_cartesian=False, - validate_proximity=False, - site_properties=None, - edges=None, - ): + idx: int, + species: Species, + coords: ArrayLike, + coords_are_cartesian: bool = False, + validate_proximity: bool = False, + site_properties: dict | None = None, + edges: list | dict | None = None, + ) -> None: """ A wrapper around Molecule.insert(), which also incorporates the new site into the MoleculeGraph. - :param idx: Index at which to insert the new site - :param species: Species for the new site - :param coords: 3x1 array representing coordinates of the new site - :param coords_are_cartesian: Whether coordinates are cartesian. - Defaults to False. - :param validate_proximity: For Molecule.insert(); if True (default - False), distance will be checked to ensure that site can be safely - added. - :param site_properties: Site properties for Molecule - :param edges: List of dicts representing edges to be added to the + Args: + idx: Index at which to insert the new site + species: Species for the new site + coords: 3x1 array representing coordinates of the new site + coords_are_cartesian: Whether coordinates are cartesian. + Defaults to False. + validate_proximity: For Molecule.insert(); if True (default + False), distance will be checked to ensure that + site can be safely added. + site_properties: Site properties for Molecule + edges: List of dicts representing edges to be added to the MoleculeGraph. These edges must include the index of the new site i, and all indices used for these edges should reflect the MoleculeGraph AFTER the insertion, NOT before. Each dict should at @@ -482,7 +521,7 @@ def insert_node( except KeyError: raise RuntimeError("Some edges are invalid.") - def set_node_attributes(self): + def set_node_attributes(self) -> None: """ Gives each node a "specie" and a "coords" attribute, updated with the current species and coordinates. @@ -501,27 +540,28 @@ def set_node_attributes(self): def alter_edge( self, - from_index, - to_index, - to_jimage=None, - new_weight=None, - new_edge_properties=None, + from_index: int, + to_index: int, + to_jimage: tuple | None = None, + new_weight: float | None = None, + new_edge_properties: dict | None = None, ): """ Alters either the weight or the edge_properties of an edge in the StructureGraph. - :param from_index: int - :param to_index: int - :param to_jimage: tuple - :param new_weight: alter_edge does not require - that weight be altered. As such, by default, this - is None. If weight is to be changed, it should be a - float. - :param new_edge_properties: alter_edge does not require - that edge_properties be altered. As such, by default, - this is None. If any edge properties are to be changed, - it should be a dictionary of edge properties to be changed. + Args: + from_index: int + to_index: int + to_jimage: tuple + new_weight: alter_edge does not require + that weight be altered. As such, by default, this + is None. If weight is to be changed, it should be a + float. + new_edge_properties: alter_edge does not require + that edge_properties be altered. As such, by default, + this is None. If any edge properties are to be changed, + it should be a dictionary of edge properties to be changed. """ existing_edges = self.graph.get_edge_data(from_index, to_index) @@ -545,16 +585,19 @@ def alter_edge( for prop in list(new_edge_properties): self.graph[from_index][to_index][edge_index][prop] = new_edge_properties[prop] - def break_edge(self, from_index, to_index, to_jimage=None, allow_reverse=False): + def break_edge( + self, from_index: int, to_index: int, to_jimage: tuple | None = None, allow_reverse: bool = False + ) -> None: """ Remove an edge from the StructureGraph. If no image is given, this method will fail. - :param from_index: int - :param to_index: int - :param to_jimage: tuple - :param allow_reverse: If allow_reverse is True, then break_edge will - attempt to break both (from_index, to_index) and, failing that, - will attempt to break (to_index, from_index). + Args: + from_index: int + to_index: int + to_jimage: tuple + allow_reverse: If allow_reverse is True, then break_edge will + attempt to break both (from_index, to_index) and, failing that, + will attempt to break (to_index, from_index). """ # ensure that edge exists before attempting to remove it existing_edges = self.graph.get_edge_data(from_index, to_index) @@ -586,12 +629,13 @@ def break_edge(self, from_index, to_index, to_jimage=None, allow_reverse=False): f"no edge exists between those sites." ) - def remove_nodes(self, indices): + def remove_nodes(self, indices: Sequence[int | None]) -> None: """ A wrapper for Molecule.remove_sites(). - :param indices: list of indices in the current Molecule (and graph) to - be removed. + Args: + indices: list of indices in the current Molecule (and graph) to + be removed. """ self.structure.remove_sites(indices) self.graph.remove_nodes_from(indices) @@ -603,12 +647,12 @@ def remove_nodes(self, indices): def substitute_group( self, - index, - func_grp, - strategy, - bond_order=1, - graph_dict=None, - strategy_params=None, + index: int, + func_grp: Molecule | str, + strategy: Any, + bond_order: int = 1, + graph_dict: dict | None = None, + strategy_params: dict | None = None, ): """ Builds off of Structure.substitute to replace an atom in self.structure @@ -619,34 +663,34 @@ def substitute_group( substituted will not place atoms to close to each other, or violate the dimensions of the Lattice. - :param index: Index of atom to substitute. - :param func_grp: Substituent molecule. There are two options: - - 1. Providing an actual Molecule as the input. The first atom - must be a DummySpecies X, indicating the position of - nearest neighbor. The second atom must be the next - nearest atom. For example, for a methyl group - substitution, func_grp should be X-CH3, where X is the - first site and C is the second site. What the code will - do is to remove the index site, and connect the nearest - neighbor to the C atom in CH3. The X-C bond indicates the - directionality to connect the atoms. - 2. A string name. The molecule will be obtained from the - relevant template in func_groups.json. - :param strategy: Class from pymatgen.analysis.local_env. - :param bond_order: A specified bond order to calculate the bond - length between the attached functional group and the nearest - neighbor site. Defaults to 1. - :param graph_dict: Dictionary representing the bonds of the functional - group (format: {(u, v): props}, where props is a dictionary of - properties, including weight. If None, then the algorithm - will attempt to automatically determine bonds using one of - a list of strategies defined in pymatgen.analysis.local_env. - :param strategy_params: dictionary of keyword arguments for strategy. - If None, default parameters will be used. + Args: + index: Index of atom to substitute. + func_grp: Substituent molecule. There are two options: + 1. Providing an actual Molecule as the input. The first atom + must be a DummySpecies X, indicating the position of + nearest neighbor. The second atom must be the next + nearest atom. For example, for a methyl group + substitution, func_grp should be X-CH3, where X is the + first site and C is the second site. What the code will + do is to remove the index site, and connect the nearest + neighbor to the C atom in CH3. The X-C bond indicates the + directionality to connect the atoms. + 2. A string name. The molecule will be obtained from the + relevant template in func_groups.json. + strategy: Class from pymatgen.analysis.local_env. + bond_order: A specified bond order to calculate the bond + length between the attached functional group and the nearest + neighbor site. Defaults to 1. + graph_dict: Dictionary representing the bonds of the functional + group (format: {(u, v): props}, where props is a dictionary of + properties, including weight. If None, then the algorithm + will attempt to automatically determine bonds using one of + a list of strategies defined in pymatgen.analysis.local_env. + strategy_params: dictionary of keyword arguments for strategy. + If None, default parameters will be used. """ - def map_indices(grp): + def map_indices(grp: Molecule) -> dict[int, int]: grp_map = {} # Get indices now occupied by functional group @@ -706,15 +750,17 @@ def map_indices(grp): warn_duplicates=False, ) - def get_connected_sites(self, n, jimage=(0, 0, 0)): + def get_connected_sites(self, n: int, jimage: tuple[int, int, int] = (0, 0, 0)) -> list[ConnectedSite]: """ Returns a named tuple of neighbors of site n: periodic_site, jimage, index, weight. Index is the index of the corresponding site in the original structure, weight can be None if not defined. - :param n: index of Site in Structure - :param jimage: lattice vector of site + + Args: + n: index of Site in Structure + jimage: lattice vector of site Returns: list of ConnectedSite tuples, @@ -726,8 +772,8 @@ def get_connected_sites(self, n, jimage=(0, 0, 0)): out_edges = [(u, v, d, "out") for u, v, d in self.graph.out_edges(n, data=True)] in_edges = [(u, v, d, "in") for u, v, d in self.graph.in_edges(n, data=True)] - for u, v, d, dir in out_edges + in_edges: - to_jimage = d["to_jimage"] + for u, v, data, dir in out_edges + in_edges: + to_jimage = data["to_jimage"] if dir == "in": u, v = v, u @@ -740,9 +786,10 @@ def get_connected_sites(self, n, jimage=(0, 0, 0)): # from_site if jimage arg != (0, 0, 0) relative_jimage = np.subtract(to_jimage, jimage) - dist = self.structure[u].distance(self.structure[v], jimage=relative_jimage) + u_site = cast(PeriodicSite, self.structure[u]) # tell mypy that u_site is a PeriodicSite + dist = u_site.distance(self.structure[v], jimage=relative_jimage) - weight = d.get("weight") + weight = data.get("weight") if (v, to_jimage) not in connected_site_images: connected_site = ConnectedSite(site=site, jimage=to_jimage, index=v, weight=weight, dist=dist) @@ -751,35 +798,38 @@ def get_connected_sites(self, n, jimage=(0, 0, 0)): connected_site_images.add((v, to_jimage)) # return list sorted by closest sites first - connected_sites = list(connected_sites) - connected_sites.sort(key=lambda x: x.dist) + _connected_sites = list(connected_sites) + _connected_sites.sort(key=lambda x: x.dist) - return connected_sites + return _connected_sites - def get_coordination_of_site(self, n): + def get_coordination_of_site(self, n: int) -> int: """ - Returns the number of neighbors of site n. - In graph terms, simply returns degree - of node corresponding to site n. - :param n: index of site - :return (int): + Returns the number of neighbors of site n. In graph terms, + simply returns degree of node corresponding to site n. + + Args: + n: index of site + + Returns: + int: number of neighbors of site n. """ n_self_loops = sum(1 for n, v in self.graph.edges(n) if n == v) return self.graph.degree(n) - n_self_loops def draw_graph_to_file( self, - filename="graph", - diff=None, - hide_unconnected_nodes=False, - hide_image_edges=True, - edge_colors=False, - node_labels=False, - weight_labels=False, - image_labels=False, - color_scheme="VESTA", - keep_dot=False, - algo="fdp", + filename: str = "graph", + diff: StructureGraph = None, + hide_unconnected_nodes: bool = False, + hide_image_edges: bool = True, + edge_colors: bool = False, + node_labels: bool = False, + weight_labels: bool = False, + image_labels: bool = False, + color_scheme: str = "VESTA", + keep_dot: bool = False, + algo: str = "fdp", ): """ Draws graph using GraphViz. @@ -793,31 +843,27 @@ def draw_graph_to_file( `hide_image_edges` can help, especially in larger graphs. - :param filename: filename to output, will detect filetype - from extension (any graphviz filetype supported, such as - pdf or png) - :param diff (StructureGraph): an additional graph to - compare with, will color edges red that do not exist in diff - and edges green that are in diff graph but not in the - reference graph - :param hide_unconnected_nodes: if True, hide unconnected - nodes - :param hide_image_edges: if True, do not draw edges that - go through periodic boundaries - :param edge_colors (bool): if True, use node colors to - color edges - :param node_labels (bool): if True, label nodes with - species and site index - :param weight_labels (bool): if True, label edges with - weights - :param image_labels (bool): if True, label edges with - their periodic images (usually only used for debugging, - edges to periodic images always appear as dashed lines) - :param color_scheme (str): "VESTA" or "JMOL" - :param keep_dot (bool): keep GraphViz .dot file for later - visualization - :param algo: any graphviz algo, "neato" (for simple graphs) - or "fdp" (for more crowded graphs) usually give good outputs + Args: + filename: filename to output, will detect filetype + from extension (any graphviz filetype supported, such as + pdf or png) + diff (StructureGraph): an additional graph to + compare with, will color edges red that do not exist in diff + and edges green that are in diff graph but not in the + reference graph + hide_unconnected_nodes: if True, hide unconnected nodes + hide_image_edges: if True, do not draw edges that + go through periodic boundaries + edge_colors (bool): if True, use node colors to color edges + node_labels (bool): if True, label nodes with species and site index + weight_labels (bool): if True, label edges with weights + image_labels (bool): if True, label edges with + their periodic images (usually only used for debugging, + edges to periodic images always appear as dashed lines) + color_scheme (str): "VESTA" or "JMOL" + keep_dot (bool): keep GraphViz .dot file for later visualization + algo: any graphviz algo, "neato" (for simple graphs) + or "fdp" (for more crowded graphs) usually give good outputs """ if not which(algo): raise RuntimeError("StructureGraph graph drawing requires GraphViz binaries to be in the path.") @@ -907,14 +953,14 @@ def draw_graph_to_file( # optionally highlight differences with another graph if diff: - diff = self.diff(diff, strict=True) + _diff = self.diff(diff, strict=True) green_edges = [] red_edges = [] for u, v, k, d in g.edges(keys=True, data=True): - if (u, v, d["to_jimage"]) in diff["self"]: + if (u, v, d["to_jimage"]) in _diff["self"]: # edge has been deleted red_edges.append((u, v, k)) - elif (u, v, d["to_jimage"]) in diff["other"]: + elif (u, v, d["to_jimage"]) in _diff["other"]: # edge has been added green_edges.append((u, v, k)) for u, v, k in green_edges: @@ -938,7 +984,7 @@ def draw_graph_to_file( os.remove(f"{basename}.dot") @property - def types_and_weights_of_connections(self): + def types_and_weights_of_connections(self) -> dict: """ Extract a dictionary summarizing the types and weights of edges in the graph. @@ -963,7 +1009,7 @@ def get_label(u, v): return dict(types) @property - def weight_statistics(self): + def weight_statistics(self) -> dict: """ Extract a statistical summary of edge weights present in the graph. @@ -983,17 +1029,16 @@ def weight_statistics(self): "variance": stats.variance, } - def types_of_coordination_environments(self, anonymous=False): + def types_of_coordination_environments(self, anonymous: bool = False) -> list[str]: """ Extract information on the different co-ordination environments present in the graph. - :param anonymous: if anonymous, will replace specie names - with A, B, C, etc. + Args: + anonymous: if anonymous, will replace specie names with A, B, C, etc. Returns: - a list of co-ordination environments, - e.g. ['Mo-S(6)', 'S-Mo(3)'] + List of coordination environments, e.g. {'Mo-S(6)', 'S-Mo(3)'} """ motifs = set() for idx, site in enumerate(self.structure): @@ -1002,30 +1047,30 @@ def types_of_coordination_environments(self, anonymous=False): connected_sites = self.get_connected_sites(idx) connected_species = [connected_site.site.species_string for connected_site in connected_sites] - labels = [] + sp_counts = [] for sp in set(connected_species): count = connected_species.count(sp) - labels.append((count, sp)) + sp_counts.append((count, sp)) - labels = sorted(labels, reverse=True) + sp_counts = sorted(sp_counts, reverse=True) if anonymous: mapping = {centre_sp: "A"} available_letters = [chr(66 + idx) for idx in range(25)] - for label in labels: + for label in sp_counts: sp = label[1] if sp not in mapping: mapping[sp] = available_letters.pop(0) centre_sp = "A" - labels = [(label[0], mapping[label[1]]) for label in labels] + sp_counts = [(label[0], mapping[label[1]]) for label in sp_counts] - labels = [f"{label[1]}({label[0]})" for label in labels] + labels = [f"{label[1]}({label[0]})" for label in sp_counts] motif = f"{centre_sp}-{','.join(labels)}" motifs.add(motif) - return sorted(motifs) + return sorted(set(motifs)) - def as_dict(self): + def as_dict(self) -> dict: """ As in pymatgen.core.Structure except with using `to_dict_of_dicts` from NetworkX @@ -1039,12 +1084,12 @@ def as_dict(self): } @classmethod - def from_dict(cls, d): + def from_dict(cls, dct) -> Self: """As in pymatgen.core.Structure except restoring graphs using from_dict_of_dicts from NetworkX to restore graph information. """ - struct = Structure.from_dict(d["structure"]) - return cls(struct, d["graphs"]) + struct = Structure.from_dict(dct["structure"]) + return cls(struct, dct["graphs"]) def __mul__(self, scaling_matrix): """ @@ -1055,7 +1100,9 @@ def __mul__(self, scaling_matrix): graph could also be done on the original graph, but a larger graph can be easier to visualize and reason about. - :param scaling_matrix: same as Structure.__mul__ + + Args: + scaling_matrix: same as Structure.__mul__ """ # Developer note: a different approach was also trialed, using # a simple Graph (instead of MultiDiGraph), with node indices @@ -1134,8 +1181,8 @@ def __mul__(self, scaling_matrix): # this could probably be a lot smaller tol = 0.05 - for u, v, k, dct in new_g.edges(keys=True, data=True): - to_jimage = dct["to_jimage"] # for node v + for u, v, k, data in new_g.edges(keys=True, data=True): + to_jimage = data["to_jimage"] # for node v # reduce unnecessary checking if to_jimage != (0, 0, 0): @@ -1143,29 +1190,24 @@ def __mul__(self, scaling_matrix): n_u = u % len(self.structure) n_v = v % len(self.structure) - # get fractional coordinates of where atoms defined - # by edge are expected to be, relative to original - # lattice (keeping original lattice has + # get fractional coordinates of where atoms defined by edge are expected + # to be, relative to original lattice (keeping original lattice has # significant benefits) v_image_frac = np.add(self.structure[n_v].frac_coords, to_jimage) u_frac = self.structure[n_u].frac_coords - # using the position of node u as a reference, - # get relative Cartesian coordinates of where - # atoms defined by edge are expected to be + # using the position of node u as a reference, get relative Cartesian + # coordinates of where atoms defined by edge are expected to be v_image_cart = orig_lattice.get_cartesian_coords(v_image_frac) u_cart = orig_lattice.get_cartesian_coords(u_frac) v_rel = np.subtract(v_image_cart, u_cart) - # now retrieve position of node v in - # new supercell, and get asgolute Cartesian - # coordinates of where atoms defined by edge - # are expected to be - v_expec = new_structure[u].coords + v_rel + # now retrieve position of node v in new supercell, and get absolute + # Cartesian coordinates of where atoms defined by edge are expected to be + v_expect = new_structure[u].coords + v_rel - # now search in new structure for these atoms - # query returns (distance, index) - v_present = kd_tree.query(v_expec) + # now search in new structure for these atoms query returns (distance, index) + v_present = kd_tree.query(v_expect) v_present = v_present[1] if v_present[0] <= tol else None # check if image sites now present in supercell @@ -1174,10 +1216,10 @@ def __mul__(self, scaling_matrix): if v_present is not None: new_u = u new_v = v_present - new_d = dct.copy() + new_data = data.copy() # node now inside supercell - new_d["to_jimage"] = (0, 0, 0) + new_data["to_jimage"] = (0, 0, 0) edges_to_remove.append((u, v, k)) @@ -1189,7 +1231,7 @@ def __mul__(self, scaling_matrix): new_u, new_v = new_v, new_u edges_inside_supercell.append({new_u, new_v}) - edges_to_add.append((new_u, new_v, new_d)) + edges_to_add.append((new_u, new_v, new_data)) else: # want to find new_v such that we have @@ -1197,7 +1239,7 @@ def __mul__(self, scaling_matrix): # so that nodes on one side of supercell # are connected to nodes on opposite side - v_expec_frac = new_structure.lattice.get_fractional_coords(v_expec) + v_expec_frac = new_structure.lattice.get_fractional_coords(v_expect) # find new to_jimage # use np.around to fix issues with finite precision leading to incorrect image @@ -1205,27 +1247,27 @@ def __mul__(self, scaling_matrix): v_expec_image = v_expec_image - v_expec_image % 1 v_expec_frac = np.subtract(v_expec_frac, v_expec_image) - v_expec = new_structure.lattice.get_cartesian_coords(v_expec_frac) - v_present = kd_tree.query(v_expec) + v_expect = new_structure.lattice.get_cartesian_coords(v_expec_frac) + v_present = kd_tree.query(v_expect) v_present = v_present[1] if v_present[0] <= tol else None if v_present is not None: new_u = u new_v = v_present - new_d = dct.copy() + new_data = data.copy() new_to_jimage = tuple(map(int, v_expec_image)) # normalize direction if new_v < new_u: new_u, new_v = new_v, new_u - new_to_jimage = tuple(np.multiply(-1, dct["to_jimage"]).astype(int)) + new_to_jimage = tuple(np.multiply(-1, data["to_jimage"]).astype(int)) - new_d["to_jimage"] = new_to_jimage + new_data["to_jimage"] = new_to_jimage edges_to_remove.append((u, v, k)) if (new_u, new_v, new_to_jimage) not in new_periodic_images: - edges_to_add.append((new_u, new_v, new_d)) + edges_to_add.append((new_u, new_v, new_data)) new_periodic_images.append((new_u, new_v, new_to_jimage)) logger.debug(f"Removing {len(edges_to_remove)} edges, adding {len(edges_to_add)} new edges.") @@ -1233,18 +1275,18 @@ def __mul__(self, scaling_matrix): # add/delete marked edges for edge in edges_to_remove: new_g.remove_edge(*edge) - for u, v, dct in edges_to_add: - new_g.add_edge(u, v, **dct) + for u, v, data in edges_to_add: + new_g.add_edge(u, v, **data) # return new instance of StructureGraph with supercell - dct = { + data = { "@module": type(self).__module__, "@class": type(self).__name__, "structure": new_structure.as_dict(), "graphs": json_graph.adjacency_data(new_g), } - return StructureGraph.from_dict(dct) + return type(self).from_dict(data) def __rmul__(self, other): return self.__mul__(other) @@ -1299,7 +1341,7 @@ def __len__(self): """length of Structure / number of nodes in graph""" return len(self.structure) - def sort(self, key=None, reverse=False): + def sort(self, key=None, reverse: bool = False) -> None: """Same as Structure.sort(). Also remaps nodes in graph. Args: @@ -1332,7 +1374,7 @@ def sort(self, key=None, reverse=False): self.graph.add_edge(u, v, **d) def __copy__(self): - return StructureGraph.from_dict(self.as_dict()) + return type(self).from_dict(self.as_dict()) def __eq__(self, other: object) -> bool: """ @@ -1340,8 +1382,8 @@ def __eq__(self, other: object) -> bool: and have the same edges between Sites. Edge weights can be different and StructureGraphs can still be considered equal. - :param other: StructureGraph - :return (bool): + Args: + other: StructureGraph """ if not isinstance(other, StructureGraph): return NotImplemented @@ -1351,13 +1393,13 @@ def __eq__(self, other: object) -> bool: other_sorted = other.__copy__() other_sorted.sort(key=lambda site: mapping[tuple(site.frac_coords)]) - edges = {(u, v, d["to_jimage"]) for u, v, d in self.graph.edges(keys=False, data=True)} + edges = {(u, v, data["to_jimage"]) for u, v, data in self.graph.edges(keys=False, data=True)} - edges_other = {(u, v, d["to_jimage"]) for u, v, d in other_sorted.graph.edges(keys=False, data=True)} + edges_other = {(u, v, data["to_jimage"]) for u, v, data in other_sorted.graph.edges(keys=False, data=True)} return (edges == edges_other) and (self.structure == other_sorted.structure) - def diff(self, other, strict=True): + def diff(self, other: StructureGraph, strict: bool = True) -> dict: """ Compares two StructureGraphs. Returns dict with keys 'self', 'other', 'both' with edges that are @@ -1377,14 +1419,15 @@ def diff(self, other, strict=True): same if the underlying Structures are ordered differently. - :param other: StructureGraph - :param strict: if False, will compare bonds - from different Structures, with node indices - replaced by Species strings, will not count - number of occurrences of bonds + Args: + other: StructureGraph + strict: if False, will compare bonds + from different Structures, with node indices + replaced by Species strings, will not count + number of occurrences of bonds """ if self.structure != other.structure and strict: - return ValueError("Meaningless to compare StructureGraphs if corresponding Structures are different.") + raise ValueError("Meaningless to compare StructureGraphs if corresponding Structures are different.") if strict: # sort for consistent node indices @@ -1393,25 +1436,26 @@ def diff(self, other, strict=True): other_sorted = copy.copy(other) other_sorted.sort(key=lambda site: mapping[tuple(site.frac_coords)]) - edges = {(u, v, d["to_jimage"]) for u, v, d in self.graph.edges(keys=False, data=True)} + edges: set[tuple] = {(u, v, data["to_jimage"]) for u, v, data in self.graph.edges(keys=False, data=True)} - edges_other = {(u, v, d["to_jimage"]) for u, v, d in other_sorted.graph.edges(keys=False, data=True)} + edges_other: set[tuple] = { + (u, v, data["to_jimage"]) for u, v, data in other_sorted.graph.edges(keys=False, data=True) + } else: edges = { - (str(self.structure[u].specie), str(self.structure[v].specie)) - for u, v, d in self.graph.edges(keys=False, data=True) + (str(self.structure[u].specie), str(self.structure[v].specie)) for u, v in self.graph.edges(keys=False) } edges_other = { (str(other.structure[u].specie), str(other.structure[v].specie)) - for u, v, d in other.graph.edges(keys=False, data=True) + for u, v in other.graph.edges(keys=False) } if len(edges) == 0 and len(edges_other) == 0: - jaccard_dist = 0 # by definition + jaccard_dist = 0.0 # by definition else: - jaccard_dist = 1 - len(edges & edges_other) / len(edges | edges_other) + jaccard_dist = 1.0 - len(edges & edges_other) / len(edges | edges_other) return { "self": edges - edges_other, @@ -1420,7 +1464,7 @@ def diff(self, other, strict=True): "dist": jaccard_dist, } - def get_subgraphs_as_molecules(self, use_weights=False): + def get_subgraphs_as_molecules(self, use_weights: bool = False) -> list[Molecule]: """ Retrieve subgraphs as molecules, useful for extracting molecules from periodic crystals. @@ -1429,12 +1473,13 @@ def get_subgraphs_as_molecules(self, use_weights=False): present in the crystal (a duplicate defined as an isomorphic subgraph). - :param use_weights (bool): If True, only treat subgraphs - as isomorphic if edges have the same weights. Typically, - this means molecules will need to have the same bond - lengths to be defined as duplicates, otherwise bond - lengths can differ. This is a fairly robust approach, - but will treat e.g. enantiomers as being duplicates. + Args: + use_weights (bool): If True, only treat subgraphs + as isomorphic if edges have the same weights. Typically, + this means molecules will need to have the same bond + lengths to be defined as duplicates, otherwise bond + lengths can differ. This is a fairly robust approach, + but will treat e.g. enantiomers as being duplicates. Returns: list of unique Molecules in Structure @@ -1474,7 +1519,7 @@ def edge_match(e1, e2): return True # prune duplicate subgraphs - unique_subgraphs = [] + unique_subgraphs: list = [] for subgraph in molecule_subgraphs: already_present = [ nx.is_isomorphic(subgraph, g, node_match=node_match, edge_match=edge_match) for g in unique_subgraphs @@ -1516,8 +1561,8 @@ class MoleculeGraph(MSONable): def __init__(self, molecule, graph_data=None): """ - If constructing this class manually, use the `with_empty_graph` - method or `with_local_env_strategy` method (using an algorithm + If constructing this class manually, use the `from_empty_graph` + method or `from_local_env_strategy` method (using an algorithm provided by the `local_env` module, such as O'Keeffe). This class that contains connection information: @@ -1531,11 +1576,12 @@ def __init__(self, molecule, graph_data=None): Use cases for this include storing bonding information, NMR J-couplings, Heisenberg exchange parameters, etc. - :param molecule: Molecule object + Args: + molecule: Molecule object - :param graph_data: dict containing graph information in - dict format (not intended to be constructed manually, - see as_dict method for format) + graph_data: dict containing graph information in + dict format (not intended to be constructed manually, + see as_dict method for format) """ if isinstance(molecule, MoleculeGraph): # just make a copy from input @@ -1546,33 +1592,36 @@ def __init__(self, molecule, graph_data=None): # tidy up edge attr dicts, reading to/from json duplicates # information - for _, _, _, d in self.graph.edges(keys=True, data=True): + for _, _, _, data in self.graph.edges(keys=True, data=True): for key in ("id", "key"): - d.pop(key, None) + data.pop(key, None) # ensure images are tuples (conversion to lists happens # when serializing back from json), it's important images # are hashable/immutable - if "to_jimage" in d: - d["to_jimage"] = tuple(d["to_jimage"]) - if "from_jimage" in d: - d["from_jimage"] = tuple(d["from_jimage"]) + if "to_jimage" in data: + data["to_jimage"] = tuple(data["to_jimage"]) + if "from_jimage" in data: + data["from_jimage"] = tuple(data["from_jimage"]) self.set_node_attributes() @classmethod - def with_empty_graph(cls, molecule, name="bonds", edge_weight_name=None, edge_weight_units=None): + def from_empty_graph(cls, molecule, name="bonds", edge_weight_name=None, edge_weight_units=None) -> Self: """ Constructor for MoleculeGraph, returns a MoleculeGraph object with an empty graph (no edges, only nodes defined that correspond to Sites in Molecule). - :param molecule (Molecule): - :param name (str): name of graph, e.g. "bonds" - :param edge_weight_name (str): name of edge weights, - e.g. "bond_length" or "exchange_constant" - :param edge_weight_units (str): name of edge weight units + Args: + molecule (Molecule): + name (str): name of graph, e.g. "bonds" + edge_weight_name (str): name of edge weights, + e.g. "bond_length" or "exchange_constant" + edge_weight_units (str): name of edge weight units e.g. "ร…" or "eV" - :return (MoleculeGraph): + + Returns: + MoleculeGraph """ if edge_weight_name and (edge_weight_units is None): raise ValueError( @@ -1596,22 +1645,31 @@ def with_empty_graph(cls, molecule, name="bonds", edge_weight_name=None, edge_we return cls(molecule, graph_data=graph_data) - @staticmethod - def with_edges(molecule: Molecule, edges: dict[tuple[int, int], dict]): + @classmethod + @deprecated( + from_empty_graph, + "Deprecated on 2024-03-29, to be removed on 2025-03-20.", + ) + def with_empty_graph(cls, *args, **kwargs): + return cls.from_empty_graph(*args, **kwargs) + + @classmethod + def from_edges(cls, molecule: Molecule, edges: dict[tuple[int, int], None | dict]) -> Self: """ Constructor for MoleculeGraph, using pre-existing or pre-defined edges with optional edge parameters. - :param molecule: Molecule object - :param edges: dict representing the bonds of the functional - group (format: {(u, v): props}, where props is a dictionary of - properties, including weight. Props should be None if no - additional properties are to be specified. + Args: + molecule: Molecule object + edges: dict representing the bonds of the functional + group (format: {(u, v): props}, where props is a dictionary of + properties, including weight. Props should be None if no + additional properties are to be specified. Returns: - mg, a MoleculeGraph + A MoleculeGraph """ - mg = MoleculeGraph.with_empty_graph(molecule, name="bonds", edge_weight_name="weight", edge_weight_units="") + mg = cls.from_empty_graph(molecule, name="bonds", edge_weight_name="weight", edge_weight_units="") for edge, props in edges.items(): try: @@ -1620,12 +1678,13 @@ def with_edges(molecule: Molecule, edges: dict[tuple[int, int], dict]): except TypeError: raise ValueError("Edges must be given as (from_index, to_index) tuples") - if props is not None: + if props is None: + weight = None + + else: weight = props.pop("weight", None) if len(props.items()) == 0: - props = None # type: ignore[assignment] - else: - weight = None + props = None nodes = mg.graph.nodes if not (from_index in nodes and to_index in nodes): @@ -1638,15 +1697,23 @@ def with_edges(molecule: Molecule, edges: dict[tuple[int, int], dict]): mg.set_node_attributes() return mg - @staticmethod - def with_local_env_strategy(molecule, strategy): + @classmethod + @deprecated( + from_edges, + "Deprecated on 2024-03-29, to be removed on 2025-03-20.", + ) + def with_edges(cls, *args, **kwargs): + return cls.from_edges(*args, **kwargs) + + @classmethod + def from_local_env_strategy(cls, molecule, strategy) -> Self: """ Constructor for MoleculeGraph, using a strategy from pymatgen.analysis.local_env. - :param molecule: Molecule object - :param strategy: an instance of a - pymatgen.analysis.local_env.NearNeighbors object + molecule: Molecule object + strategy: an instance of a + pymatgen.analysis.local_env.NearNeighbors object Returns: mg, a MoleculeGraph @@ -1655,7 +1722,7 @@ def with_local_env_strategy(molecule, strategy): raise ValueError(f"{strategy=} is not designed for use with molecules! Choose another strategy.") extend_structure = strategy.extend_structure_molecules - mg = MoleculeGraph.with_empty_graph(molecule, name="bonds", edge_weight_name="weight", edge_weight_units="") + mg = cls.from_empty_graph(molecule, name="bonds", edge_weight_name="weight", edge_weight_units="") # NearNeighbor classes only (generally) work with structures # molecules have to be boxed first @@ -1670,19 +1737,21 @@ def with_local_env_strategy(molecule, strategy): else: structure = None - for n in range(len(molecule)): - neighbors = strategy.get_nn_info(molecule, n) if structure is None else strategy.get_nn_info(structure, n) + for idx in range(len(molecule)): + neighbors = ( + strategy.get_nn_info(molecule, idx) if structure is None else strategy.get_nn_info(structure, idx) + ) for neighbor in neighbors: # all bonds in molecules should not cross # (artificial) periodic boundaries if not np.array_equal(neighbor["image"], [0, 0, 0]): continue - if n > neighbor["site_index"]: + if idx > neighbor["site_index"]: from_index = neighbor["site_index"] - to_index = n + to_index = idx else: - from_index = n + from_index = idx to_index = neighbor["site_index"] mg.add_edge( @@ -1703,6 +1772,14 @@ def with_local_env_strategy(molecule, strategy): mg.set_node_attributes() return mg + @classmethod + @deprecated( + from_local_env_strategy, + "Deprecated on 2024-03-29, to be removed on 2025-03-20.", + ) + def with_local_env_strategy(cls, *args, **kwargs): + return cls.from_local_env_strategy(*args, **kwargs) + @property def name(self): """Name of graph""" @@ -1736,14 +1813,15 @@ def add_edge( However, images will always be shifted so that from_index < to_index and from_jimage becomes (0, 0, 0). - :param from_index: index of site connecting from - :param to_index: index of site connecting to - :param weight (float): e.g. bond length - :param warn_duplicates (bool): if True, will warn if - trying to add duplicate edges (duplicate edges will not - be added in either case) - :param edge_properties (dict): any other information to - store on graph edges, similar to Structure's site_properties + Args: + from_index: index of site connecting from + to_index: index of site connecting to + weight (float): e.g. bond length + warn_duplicates (bool): if True, will warn if + trying to add duplicate edges (duplicate edges will not + be added in either case) + edge_properties (dict): any other information to + store on graph edges, similar to Structure's site_properties """ # this is not necessary for the class to work, but # just makes it neater @@ -1783,19 +1861,20 @@ def insert_node( A wrapper around Molecule.insert(), which also incorporates the new site into the MoleculeGraph. - :param idx: Index at which to insert the new site - :param species: Species for the new site - :param coords: 3x1 array representing coordinates of the new site - :param validate_proximity: For Molecule.insert(); if True (default - False), distance will be checked to ensure that site can be safely - added. - :param site_properties: Site properties for Molecule - :param edges: List of dicts representing edges to be added to the - MoleculeGraph. These edges must include the index of the new site i, - and all indices used for these edges should reflect the - MoleculeGraph AFTER the insertion, NOT before. Each dict should at - least have a "to_index" and "from_index" key, and can also have a - "weight" and a "properties" key. + Args: + idx: Index at which to insert the new site + species: Species for the new site + coords: 3x1 array representing coordinates of the new site + validate_proximity: For Molecule.insert(); if True (default + False), distance will be checked to ensure that site can be safely + added. + site_properties: Site properties for Molecule + edges: List of dicts representing edges to be added to the + MoleculeGraph. These edges must include the index of the new site i, + and all indices used for these edges should reflect the + MoleculeGraph AFTER the insertion, NOT before. Each dict should at + least have a "to_index" and "from_index" key, and can also have a + "weight" and a "properties" key. """ self.molecule.insert( idx, @@ -1850,16 +1929,17 @@ def alter_edge(self, from_index, to_index, new_weight=None, new_edge_properties= Alters either the weight or the edge_properties of an edge in the MoleculeGraph. - :param from_index: int - :param to_index: int - :param new_weight: alter_edge does not require - that weight be altered. As such, by default, this - is None. If weight is to be changed, it should be a - float. - :param new_edge_properties: alter_edge does not require - that edge_properties be altered. As such, by default, - this is None. If any edge properties are to be changed, - it should be a dictionary of edge properties to be changed. + Args: + from_index: int + to_index: int + new_weight: alter_edge does not require + that weight be altered. As such, by default, this + is None. If weight is to be changed, it should be a + float. + new_edge_properties: alter_edge does not require + that edge_properties be altered. As such, by default, + this is None. If any edge properties are to be changed, + it should be a dictionary of edge properties to be changed. """ existing_edge = self.graph.get_edge_data(from_index, to_index) @@ -1881,11 +1961,12 @@ def break_edge(self, from_index, to_index, allow_reverse=False): """ Remove an edge from the MoleculeGraph. - :param from_index: int - :param to_index: int - :param allow_reverse: If allow_reverse is True, then break_edge will - attempt to break both (from_index, to_index) and, failing that, - will attempt to break (to_index, from_index). + Args: + from_index: int + to_index: int + allow_reverse: If allow_reverse is True, then break_edge will + attempt to break both (from_index, to_index) and, failing that, + will attempt to break (to_index, from_index). """ # ensure that edge exists before attempting to remove it existing_edge = self.graph.get_edge_data(from_index, to_index) @@ -1906,12 +1987,12 @@ def break_edge(self, from_index, to_index, allow_reverse=False): f"no edge exists between those sites." ) - def remove_nodes(self, indices): + def remove_nodes(self, indices: list[int]) -> None: """ A wrapper for Molecule.remove_sites(). - :param indices: list of indices in the current Molecule (and graph) to - be removed. + Args: + indices: indices in the current Molecule (and graph) to be removed. """ self.molecule.remove_sites(indices) self.graph.remove_nodes_from(indices) @@ -2009,14 +2090,16 @@ def split_molecule_subgraphs(self, bonds, allow_reverse=False, alterations=None) NOTE: This function does not modify the original MoleculeGraph. It creates a copy, modifies that, and returns two or more new MoleculeGraph objects. - :param bonds: list of tuples (from_index, to_index) - representing bonds to be broken to split the MoleculeGraph. - :param alterations: a dict {(from_index, to_index): alt}, - where alt is a dictionary including weight and/or edge - properties to be changed following the split. - :param allow_reverse: If allow_reverse is True, then break_edge will - attempt to break both (from_index, to_index) and, failing that, - will attempt to break (to_index, from_index). + + Args: + bonds: list of tuples (from_index, to_index) + representing bonds to be broken to split the MoleculeGraph. + alterations: a dict {(from_index, to_index): alt}, + where alt is a dictionary including weight and/or edge + properties to be changed following the split. + allow_reverse: If allow_reverse is True, then break_edge will + attempt to break both (from_index, to_index) and, failing that, + will attempt to break (to_index, from_index). Returns: list of MoleculeGraphs. @@ -2073,8 +2156,8 @@ def build_unique_fragments(self): unique_frags = [] for frag in fragments: found = False - for f in unique_frags: - if _isomorphic(frag, f): + for fragment in unique_frags: + if _isomorphic(frag, fragment): found = True break if not found: @@ -2100,7 +2183,7 @@ def build_unique_fragments(self): edges[(from_index, to_index)] = edge_props unique_mol_graph_list.append( - self.with_edges( + self.from_edges( Molecule(species=species, coords=coords, charge=self.molecule.charge), edges, ) @@ -2128,31 +2211,31 @@ def substitute_group( NOTE: using a MoleculeGraph will generally produce a different graph compared with using a Molecule or str (when not using graph_dict). - :param index: Index of atom to substitute. - :param func_grp: Substituent molecule. There are three options: - - 1. Providing an actual molecule as the input. The first atom - must be a DummySpecies X, indicating the position of - nearest neighbor. The second atom must be the next - nearest atom. For example, for a methyl group - substitution, func_grp should be X-CH3, where X is the - first site and C is the second site. What the code will - do is to remove the index site, and connect the nearest - neighbor to the C atom in CH3. The X-C bond indicates the - directionality to connect the atoms. - 2. A string name. The molecule will be obtained from the - relevant template in func_groups.json. - 3. A MoleculeGraph object. - :param strategy: Class from pymatgen.analysis.local_env. - :param bond_order: A specified bond order to calculate the bond + Args: + index: Index of atom to substitute. + func_grp: Substituent molecule. There are three options: + 1. Providing an actual molecule as the input. The first atom + must be a DummySpecies X, indicating the position of + nearest neighbor. The second atom must be the next + nearest atom. For example, for a methyl group + substitution, func_grp should be X-CH3, where X is the + first site and C is the second site. What the code will + do is to remove the index site, and connect the nearest + neighbor to the C atom in CH3. The X-C bond indicates the + directionality to connect the atoms. + 2. A string name. The molecule will be obtained from the + relevant template in func_groups.json. + 3. A MoleculeGraph object. + strategy: Class from pymatgen.analysis.local_env. + bond_order: A specified bond order to calculate the bond length between the attached functional group and the nearest neighbor site. Defaults to 1. - :param graph_dict: Dictionary representing the bonds of the functional + graph_dict: Dictionary representing the bonds of the functional group (format: {(u, v): props}, where props is a dictionary of properties, including weight. If None, then the algorithm will attempt to automatically determine bonds using one of a list of strategies defined in pymatgen.analysis.local_env. - :param strategy_params: dictionary of keyword arguments for strategy. + strategy_params: dictionary of keyword arguments for strategy. If None, default parameters will be used. """ @@ -2208,7 +2291,7 @@ def map_indices(grp): ) else: - graph = self.with_local_env_strategy(func_grp, strategy(**(strategy_params or {}))) + graph = self.from_local_env_strategy(func_grp, strategy(**(strategy_params or {}))) for u, v in list(graph.graph.edges()): edge_props = graph.graph.get_edge_data(u, v)[0] @@ -2242,32 +2325,32 @@ def replace_group( TODO: Figure out how to replace into a ring structure. - :param index: Index of atom to substitute. - :param func_grp: Substituent molecule. There are three options: - - 1. Providing an actual molecule as the input. The first atom - must be a DummySpecies X, indicating the position of - nearest neighbor. The second atom must be the next - nearest atom. For example, for a methyl group - substitution, func_grp should be X-CH3, where X is the - first site and C is the second site. What the code will - do is to remove the index site, and connect the nearest - neighbor to the C atom in CH3. The X-C bond indicates the - directionality to connect the atoms. - 2. A string name. The molecule will be obtained from the - relevant template in func_groups.json. - 3. A MoleculeGraph object. - :param strategy: Class from pymatgen.analysis.local_env. - :param bond_order: A specified bond order to calculate the bond - length between the attached functional group and the nearest - neighbor site. Defaults to 1. - :param graph_dict: Dictionary representing the bonds of the functional - group (format: {(u, v): props}, where props is a dictionary of - properties, including weight. If None, then the algorithm - will attempt to automatically determine bonds using one of - a list of strategies defined in pymatgen.analysis.local_env. - :param strategy_params: dictionary of keyword arguments for strategy. - If None, default parameters will be used. + Args: + index: Index of atom to substitute. + func_grp: Substituent molecule. There are three options: + 1. Providing an actual molecule as the input. The first atom + must be a DummySpecies X, indicating the position of + nearest neighbor. The second atom must be the next + nearest atom. For example, for a methyl group + substitution, func_grp should be X-CH3, where X is the + first site and C is the second site. What the code will + do is to remove the index site, and connect the nearest + neighbor to the C atom in CH3. The X-C bond indicates the + directionality to connect the atoms. + 2. A string name. The molecule will be obtained from the + relevant template in func_groups.json. + 3. A MoleculeGraph object. + strategy: Class from pymatgen.analysis.local_env. + bond_order: A specified bond order to calculate the bond + length between the attached functional group and the nearest + neighbor site. Defaults to 1. + graph_dict: Dictionary representing the bonds of the functional + group (format: {(u, v): props}, where props is a dictionary of + properties, including weight. If None, then the algorithm + will attempt to automatically determine bonds using one of + a list of strategies defined in pymatgen.analysis.local_env. + strategy_params: dictionary of keyword arguments for strategy. + If None, default parameters will be used. """ self.set_node_attributes() neighbors = self.get_connected_sites(index) @@ -2355,8 +2438,8 @@ def find_rings(self, including=None) -> list[list[tuple[int, int]]]: for cycle in cycles_nodes: edges = [] - for idx, itm in enumerate(cycle): - edges.append((cycle[idx - 1], itm)) + for idx, itm in enumerate(cycle, start=-1): + edges.append((cycle[idx], itm)) cycles_edges.append(edges) return cycles_edges @@ -2368,8 +2451,9 @@ def get_connected_sites(self, n): Index is the index of the corresponding site in the original structure, weight can be None if not defined. - :param n: index of Site in Molecule - :param jimage: lattice vector of site + Args: + n: index of Site in Molecule + jimage: lattice vector of site Returns: list of ConnectedSite tuples, @@ -2402,13 +2486,17 @@ def get_connected_sites(self, n): return connected_sites - def get_coordination_of_site(self, n): + def get_coordination_of_site(self, n) -> int: """ Returns the number of neighbors of site n. In graph terms, simply returns degree of node corresponding to site n. - :param n: index of site - :return (int): + + Args: + n: index of site + + Returns: + int: the number of neighbors of site n. """ n_self_loops = sum(1 for n, v in self.graph.edges(n) if n == v) return self.graph.degree(n) - n_self_loops @@ -2439,31 +2527,28 @@ def draw_graph_to_file( `hide_image_edges` can help, especially in larger graphs. - :param filename: filename to output, will detect filetype - from extension (any graphviz filetype supported, such as - pdf or png) - :param diff (StructureGraph): an additional graph to - compare with, will color edges red that do not exist in diff - and edges green that are in diff graph but not in the - reference graph - :param hide_unconnected_nodes: if True, hide unconnected - nodes - :param hide_image_edges: if True, do not draw edges that - go through periodic boundaries - :param edge_colors (bool): if True, use node colors to - color edges - :param node_labels (bool): if True, label nodes with - species and site index - :param weight_labels (bool): if True, label edges with - weights - :param image_labels (bool): if True, label edges with - their periodic images (usually only used for debugging, - edges to periodic images always appear as dashed lines) - :param color_scheme (str): "VESTA" or "JMOL" - :param keep_dot (bool): keep GraphViz .dot file for later - visualization - :param algo: any graphviz algo, "neato" (for simple graphs) - or "fdp" (for more crowded graphs) usually give good outputs + Args: + filename: filename to output, will detect filetype + from extension (any graphviz filetype supported, such as + pdf or png) + diff (StructureGraph): an additional graph to + compare with, will color edges red that do not exist in diff + and edges green that are in diff graph but not in the + reference graph + hide_unconnected_nodes: if True, hide unconnected nodes + hide_image_edges: if True, do not draw edges that + go through periodic boundaries + edge_colors (bool): if True, use node colors to color edges + node_labels (bool): if True, label nodes with + species and site index + weight_labels (bool): if True, label edges with weights + image_labels (bool): if True, label edges with + their periodic images (usually only used for debugging, + edges to periodic images always appear as dashed lines) + color_scheme (str): "VESTA" or "JMOL" + keep_dot (bool): keep GraphViz .dot file for later visualization + algo: any graphviz algo, "neato" (for simple graphs) + or "fdp" (for more crowded graphs) usually give good outputs """ if not which(algo): raise RuntimeError("StructureGraph graph drawing requires GraphViz binaries to be in the path.") @@ -2595,7 +2680,7 @@ def as_dict(self): return dct @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """ As in pymatgen.core.Molecule except restoring graphs using `from_dict_of_dicts` @@ -2687,16 +2772,13 @@ def sort(self, key: Callable[[Molecule], float] | None = None, reverse: bool = F self.graph.add_edge(u, v, **data) def __copy__(self): - return MoleculeGraph.from_dict(self.as_dict()) + return type(self).from_dict(self.as_dict()) def __eq__(self, other: object) -> bool: """ Two MoleculeGraphs are equal if they have equal Molecules, and have the same edges between Sites. Edge weights can be different and MoleculeGraphs can still be considered equal. - - :param other: MoleculeGraph - :return (bool): """ if not isinstance(other, type(self)): return NotImplemented @@ -2710,19 +2792,20 @@ def __eq__(self, other: object) -> bool: other_sorted = other.__copy__() other_sorted.sort(key=lambda site: mapping[tuple(site.coords)]) - edges = {(u, v) for u, v, d in self.graph.edges(keys=False, data=True)} + edges = set(self.graph.edges(keys=False)) - edges_other = {(u, v) for u, v, d in other_sorted.graph.edges(keys=False, data=True)} + edges_other = set(other_sorted.graph.edges(keys=False)) return (edges == edges_other) and (self.molecule == other_sorted.molecule) - def isomorphic_to(self, other): + def isomorphic_to(self, other: MoleculeGraph) -> bool: """ Checks if the graphs of two MoleculeGraphs are isomorphic to one another. In order to prevent problems with misdirected edges, both graphs are converted into undirected nx.Graph objects. - :param other: MoleculeGraph object to be compared. + Args: + other: MoleculeGraph object to be compared. Returns: bool @@ -2755,11 +2838,11 @@ def diff(self, other, strict=True): same if the underlying Molecules are ordered differently. - :param other: MoleculeGraph - :param strict: if False, will compare bonds - from different Molecules, with node indices - replaced by Species strings, will not count - number of occurrences of bonds + Args: + other: MoleculeGraph + strict: if False, will compare bonds + from different Molecules, with node indices replaced by Species + strings, will not count number of occurrences of bonds """ if self.molecule != other.molecule and strict: return ValueError("Meaningless to compare MoleculeGraphs if corresponding Molecules are different.") @@ -2771,21 +2854,21 @@ def diff(self, other, strict=True): other_sorted = copy.copy(other) other_sorted.sort(key=lambda site: mapping[tuple(site.frac_coords)]) - edges = {(u, v, d.get("to_jimage", (0, 0, 0))) for u, v, d in self.graph.edges(keys=False, data=True)} + edges = {(u, v, data.get("to_jimage", (0, 0, 0))) for u, v, data in self.graph.edges(keys=False, data=True)} edges_other = { - (u, v, d.get("to_jimage", (0, 0, 0))) for u, v, d in other_sorted.graph.edges(keys=False, data=True) + (u, v, data.get("to_jimage", (0, 0, 0))) + for u, v, data in other_sorted.graph.edges(keys=False, data=True) } else: edges = { - (str(self.molecule[u].specie), str(self.molecule[v].specie)) - for u, v, d in self.graph.edges(keys=False, data=True) + (str(self.molecule[u].specie), str(self.molecule[v].specie)) for u, v in self.graph.edges(keys=False) } edges_other = { (str(other.structure[u].specie), str(other.structure[v].specie)) - for u, v, d in other.graph.edges(keys=False, data=True) + for u, v in other.graph.edges(keys=False) } if len(edges) == 0 and len(edges_other) == 0: diff --git a/pymatgen/analysis/interface.py b/pymatgen/analysis/interface.py deleted file mode 100644 index ef2b855bc9a..00000000000 --- a/pymatgen/analysis/interface.py +++ /dev/null @@ -1,23 +0,0 @@ -"""This module provides classes to store, generate, and manipulate material interfaces.""" - -from __future__ import annotations - -import warnings - -from pymatgen.analysis.interfaces import CoherentInterfaceBuilder # noqa: F401 -from pymatgen.core.interface import Interface # noqa: F401 - -__author__ = "Eric Sivonxay, Shyam Dwaraknath, and Kyle Bystrom" -__copyright__ = "Copyright 2019, The Materials Project" -__version__ = "0.1" -__maintainer__ = "Kyle Bystrom" -__email__ = "kylebystrom@gmail.com" -__date__ = "5/29/2019" -__status__ = "Prototype" - -warnings.warn( - "The substrate_analyzer module is being moved to the interfaces submodule in analysis." - " These imports will break in Pymatgen 2023", - category=FutureWarning, - stacklevel=2, -) diff --git a/pymatgen/analysis/interface_reactions.py b/pymatgen/analysis/interface_reactions.py index c8a19222d5c..fc69ad75df7 100644 --- a/pymatgen/analysis/interface_reactions.py +++ b/pymatgen/analysis/interface_reactions.py @@ -316,7 +316,7 @@ def _get_reaction(self, x: float) -> Reaction: return reaction - def _get_elem_amt_in_rxn(self, rxn: Reaction) -> int: + def _get_elem_amt_in_rxn(self, rxn: Reaction) -> float: """ Computes total number of atoms in a reaction formula for elements not in external reservoir. This method is used in the calculation diff --git a/pymatgen/analysis/interfaces/substrate_analyzer.py b/pymatgen/analysis/interfaces/substrate_analyzer.py index 7e069aaea09..5a528f74b20 100644 --- a/pymatgen/analysis/interfaces/substrate_analyzer.py +++ b/pymatgen/analysis/interfaces/substrate_analyzer.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from numpy.typing import ArrayLike + from typing_extensions import Self from pymatgen.core import Structure @@ -39,7 +40,7 @@ def from_zsl( substrate_miller, elasticity_tensor=None, ground_state_energy=0, - ): + ) -> Self: """Generate a substrate match from a ZSL match plus metadata.""" # Get the appropriate surface structure struct = SlabGenerator(film, film_miller, 20, 15, primitive=False).get_slab().oriented_unit_cell diff --git a/pymatgen/analysis/local_env.py b/pymatgen/analysis/local_env.py index 6ca1a1a1f92..6ec084015ca 100644 --- a/pymatgen/analysis/local_env.py +++ b/pymatgen/analysis/local_env.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Literal, get_args import numpy as np -from monty.dev import requires +from monty.dev import deprecated, requires from monty.serialization import loadfn from ruamel.yaml import YAML from scipy.spatial import Voronoi @@ -34,6 +34,8 @@ openbabel = None if TYPE_CHECKING: + from typing_extensions import Self + from pymatgen.core.composition import SpeciesLike @@ -407,7 +409,7 @@ def get_all_nn_info(self, structure: Structure): Args: structure (Structure): Input structure - Return: + Returns: List of NN site information for each site in the structure. Each entry has the same format as `get_nn_info` """ @@ -623,11 +625,13 @@ def get_bonded_structure( order_parameters = [self.get_local_order_parameters(structure, n) for n in range(len(structure))] structure.add_site_property("order_parameters", order_parameters) - sg = StructureGraph.with_local_env_strategy(structure, self, weights=weights, edge_properties=edge_properties) + struct_graph = StructureGraph.from_local_env_strategy( + structure, self, weights=weights, edge_properties=edge_properties + ) # sets the attributes - sg.set_node_attributes() - return sg + struct_graph.set_node_attributes() + return struct_graph def get_local_order_parameters(self, structure: Structure, n: int): """ @@ -641,9 +645,9 @@ def get_local_order_parameters(self, structure: Structure, n: int): structure: Structure object n (int): site index. - Returns (dict[str, float]): - A dict of order parameters (values) and the - underlying motif type (keys; for example, tetrahedral). + Returns: + dict[str, float]: A dict of order parameters (values) and the + underlying motif type (keys; for example, tetrahedral). """ # code from @nisse3000, moved here from graphs to avoid circular # import, also makes sense to have this as a general NN method @@ -1164,7 +1168,7 @@ def _is_in_targets(site, targets): targets ([Element]) List of elements Returns: - (boolean) Whether this site contains a certain list of elements + boolean: Whether this site contains a certain list of elements """ elems = _get_elements(site) return all(elem in targets for elem in elems) @@ -1520,7 +1524,7 @@ def get_bonded_structure(self, structure: Structure, decorate: bool = False) -> order_parameters = [self.get_local_order_parameters(structure, n) for n in range(len(structure))] structure.add_site_property("order_parameters", order_parameters) - return MoleculeGraph.with_local_env_strategy(structure, self) + return MoleculeGraph.from_local_env_strategy(structure, self) def get_nn_shell_info(self, structure: Structure, site_idx, shell): """Get a certain nearest neighbor shell for a certain site. @@ -1612,8 +1616,9 @@ def get_nn_info(self, structure: Structure, n: int): Get all near-neighbor sites and weights (orders) of bonds for a given atom. - :param structure: input Molecule. - :param n: index of site for which to determine near neighbors. + Args: + structure: input Molecule. + n: index of site for which to determine near neighbors. Returns: [dict] representing a neighboring site and the type of @@ -1666,7 +1671,7 @@ def get_bonded_structure(self, structure: Structure, decorate: bool = False) -> order_parameters = [self.get_local_order_parameters(structure, n) for n in range(len(structure))] structure.add_site_property("order_parameters", order_parameters) - return MoleculeGraph.with_local_env_strategy(structure, self) + return MoleculeGraph.from_local_env_strategy(structure, self) def get_nn_shell_info(self, structure: Structure, site_idx, shell): """Get a certain nearest neighbor shell for a certain site. @@ -1960,7 +1965,7 @@ def vol_tetra(vt1, vt2, vt3, vt4): vt4 (array-like): coordinates of vertex 4. Returns: - (float): volume of the tetrahedron. + float: volume of the tetrahedron. """ return np.abs(np.dot((vt1 - vt4), np.cross((vt2 - vt4), (vt3 - vt4)))) / 6 @@ -1977,7 +1982,7 @@ def get_okeeffe_params(el_symbol): el_symbol (str): element symbol. Returns: - (dict): atom-size ('r') and electronegativity-related ('c') parameter. + dict: atom-size ('r') and electronegativity-related ('c') parameter. """ el = Element(el_symbol) if el not in list(BV_PARAMS): @@ -3332,8 +3337,8 @@ def get_order_parameters( if self._geomops2: # Compute all (unique) angles and sort the resulting list. aij = [] - for ir, r in enumerate(rij_norm): - for j in range(ir + 1, len(rij_norm)): + for ir, r in enumerate(rij_norm, start=1): + for j in range(ir, len(rij_norm)): aij.append(acos(max(-1.0, min(np.inner(r, rij_norm[j]), 1.0)))) aijs = sorted(aij) @@ -3365,7 +3370,7 @@ def get_order_parameters( return ops -class BrunnerNN_reciprocal(NearNeighbors): +class BrunnerNNReciprocal(NearNeighbors): """ Determine coordination number using Brunner's algorithm which counts the atoms that are within the largest gap in differences in real space @@ -3410,9 +3415,8 @@ def get_nn_info(self, structure: Structure, n: int): n (int): index of site for which to determine near-neighbor sites. Returns: - siw (list of tuples (Site, array, float)): tuples, each one - of which represents a coordinated site, its image location, - and its weight. + list[tuples[Site, array, float]]: tuples, each one of which represents a + coordinated site, its image location, and its weight. """ site = structure[n] neighs_dists = structure.get_neighbors(site, self.cutoff) @@ -3423,21 +3427,29 @@ def get_nn_info(self, structure: Structure, n: int): d_max = ds[ns.index(max(ns))] siw = [] for nn in neighs_dists: - s, dist = nn, nn.nn_distance + site, dist = nn, nn.nn_distance if dist < d_max + self.tol: w = ds[0] / dist - siw.append( + siw += [ { - "site": s, - "image": self._get_image(structure, s), + "site": site, + "image": self._get_image(structure, site), "weight": w, - "site_index": self._get_original_site(structure, s), + "site_index": self._get_original_site(structure, site), } - ) + ] return siw -class BrunnerNN_relative(NearNeighbors): +@deprecated( + BrunnerNNReciprocal, + "Deprecated on 2024-03-29, to be removed on 2025-03-29.", +) +class BrunnerNN_reciprocal(BrunnerNNReciprocal): + pass + + +class BrunnerNNRelative(NearNeighbors): """ Determine coordination number using Brunner's algorithm which counts the atoms that are within the largest gap in differences in real space @@ -3509,7 +3521,15 @@ def get_nn_info(self, structure: Structure, n: int): return siw -class BrunnerNN_real(NearNeighbors): +@deprecated( + BrunnerNNRelative, + "Deprecated on 2024-03-29, to be removed on 2025-03-29.", +) +class BrunnerNN_relative(BrunnerNNRelative): + pass + + +class BrunnerNNReal(NearNeighbors): """ Determine coordination number using Brunner's algorithm which counts the atoms that are within the largest gap in differences in real space @@ -3581,6 +3601,14 @@ def get_nn_info(self, structure: Structure, n: int): return siw +@deprecated( + BrunnerNNReal, + "Deprecated on 2024-03-29, to be removed on 2025-03-29.", +) +class BrunnerNN_real(BrunnerNNReal): + pass + + class EconNN(NearNeighbors): """ Determines the average effective coordination number for each cation in a @@ -4060,19 +4088,21 @@ def _semicircle_integral(dist_bins, idx): Returns: float: integral of portion of unit semicircle """ - r = 1 + radius = 1 x1 = dist_bins[idx] x2 = dist_bins[idx + 1] if dist_bins[idx] == 1: - area1 = 0.25 * math.pi * r**2 + area1 = 0.25 * math.pi * radius**2 else: - area1 = 0.5 * ((x1 * math.sqrt(r**2 - x1**2)) + (r**2 * math.atan(x1 / math.sqrt(r**2 - x1**2)))) + area1 = 0.5 * ( + (x1 * math.sqrt(radius**2 - x1**2)) + (radius**2 * math.atan(x1 / math.sqrt(radius**2 - x1**2))) + ) - area2 = 0.5 * ((x2 * math.sqrt(r**2 - x2**2)) + (r**2 * math.atan(x2 / math.sqrt(r**2 - x2**2)))) + area2 = 0.5 * ((x2 * math.sqrt(radius**2 - x2**2)) + (radius**2 * math.atan(x2 / math.sqrt(radius**2 - x2**2)))) - return (area1 - area2) / (0.25 * math.pi * r**2) + return (area1 - area2) / (0.25 * math.pi * radius**2) @staticmethod def transform_to_length(nn_data, length): @@ -4134,10 +4164,10 @@ def _get_radius(site): return el.ionic_radii[oxi] # e.g., oxi = 2.667, average together 2+ and 3+ radii - if int(math.floor(oxi)) in el.ionic_radii and int(math.ceil(oxi)) in el.ionic_radii: - oxi_low = el.ionic_radii[int(math.floor(oxi))] - oxi_high = el.ionic_radii[int(math.ceil(oxi))] - x = oxi - int(math.floor(oxi)) + if math.floor(oxi) in el.ionic_radii and math.ceil(oxi) in el.ionic_radii: + oxi_low = el.ionic_radii[math.floor(oxi)] + oxi_high = el.ionic_radii[math.ceil(oxi)] + x = oxi - math.floor(oxi) return (1 - x) * oxi_low + x * oxi_high if oxi > 0 and el.average_cationic_radius > 0: @@ -4213,7 +4243,7 @@ def extend_structure_molecules(self) -> bool: return True @classmethod - def from_preset(cls, preset) -> CutOffDictNN: + def from_preset(cls, preset) -> Self: """ Initialize a CutOffDictNN according to a preset set of cutoffs. diff --git a/pymatgen/analysis/magnetism/analyzer.py b/pymatgen/analysis/magnetism/analyzer.py index a183a1d41ab..4f0002e8ae1 100644 --- a/pymatgen/analysis/magnetism/analyzer.py +++ b/pymatgen/analysis/magnetism/analyzer.py @@ -14,6 +14,7 @@ import numpy as np from monty.serialization import loadfn +from ruamel.yaml.error import MarkedYAMLError from scipy.signal import argrelextrema from scipy.stats import gaussian_kde @@ -40,10 +41,9 @@ try: DEFAULT_MAGMOMS = loadfn(f"{MODULE_DIR}/default_magmoms.yaml") -except Exception: +except (FileNotFoundError, MarkedYAMLError): warnings.warn("Could not load default_magmoms.yaml, falling back to VASPIncarBase.yaml") - DEFAULT_MAGMOMS = loadfn(f"{MODULE_DIR}/../../io/vasp/VASPIncarBase.yaml") - DEFAULT_MAGMOMS = DEFAULT_MAGMOMS["MAGMOM"] + DEFAULT_MAGMOMS = loadfn(f"{MODULE_DIR}/../../io/vasp/VASPIncarBase.yaml")["INCAR"]["MAGMOM"] @unique @@ -63,8 +63,9 @@ class OverwriteMagmomMode(Enum): none = "none" respect_sign = "respect_sign" - respect_zero = "respect_zeros" + respect_zeros = "respect_zeros" replace_all = "replace_all" + replace_all_if_undefined = "replace_all_if_undefined" normalize = "normalize" @@ -77,7 +78,7 @@ class CollinearMagneticStructureAnalyzer: def __init__( self, structure: Structure, - overwrite_magmom_mode: OverwriteMagmomMode | str = "none", + overwrite_magmom_mode: str | OverwriteMagmomMode = OverwriteMagmomMode.none, round_magmoms: bool = False, detect_valences: bool = False, make_primitive: bool = True, @@ -137,6 +138,8 @@ def __init__( in Bohr magneton) below which total magnetization is treated as zero when defining magnetic ordering. Defaults to 1e-8. """ + OverwriteMagmomMode(overwrite_magmom_mode) # raises ValueError on invalid mode + if default_magmoms: self.default_magmoms = default_magmoms else: @@ -177,15 +180,13 @@ def __init__( raise ValueError( "Structure contains magnetic moments on both " "magmom site properties and spin species " - "properties. This is ambiguous. Remove one or " - "the other." + "properties. This is ambiguous. Remove one or the other." ) if has_magmoms: if None in structure.site_properties["magmom"]: warnings.warn( - "Be careful with mixing types in your magmom " - "site properties. Any 'None' magmoms have been " - "replaced with zero." + "Be careful with mixing types in your magmom site properties. " + "Any 'None' magmoms have been replaced with zero." ) magmoms = [m or 0 for m in structure.site_properties["magmom"]] elif has_spin: @@ -209,8 +210,7 @@ def __init__( "give useful results, but use with caution." ) - # this is for collinear structures only, make sure magmoms - # are all floats + # this is for collinear structures only, make sure magmoms are all floats magmoms = list(map(float, magmoms)) # set properties that should be done /before/ we process input magmoms @@ -229,16 +229,6 @@ def __init__( ] # overwrite existing magmoms with default_magmoms - if overwrite_magmom_mode not in ( - "none", - "respect_sign", - "respect_zeros", - "replace_all", - "replace_all_if_undefined", - "normalize", - ): - raise ValueError("Unsupported mode.") - for idx, site in enumerate(structure): if site.species_string in self.default_magmoms: # look for species first, e.g. Fe2+ @@ -252,8 +242,7 @@ def __init__( # overwrite_magmom_mode = "respect_sign" will change magnitude of # existing moments only, and keep zero magmoms as # zero: it will keep the magnetic ordering intact - - if overwrite_magmom_mode == "respect_sign": + if overwrite_magmom_mode == OverwriteMagmomMode.respect_sign.value: set_net_positive = False if magmoms[idx] > 0: magmoms[idx] = default_magmom @@ -262,21 +251,18 @@ def __init__( # overwrite_magmom_mode = "respect_zeros" will give a ferromagnetic # structure but will keep zero magmoms as zero - - elif overwrite_magmom_mode == "respect_zeros": + elif overwrite_magmom_mode == OverwriteMagmomMode.respect_zeros.value: if magmoms[idx] != 0: magmoms[idx] = default_magmom # overwrite_magmom_mode = "replace_all" will ignore input magmoms # and give a ferromagnetic structure with magnetic # moments on *all* atoms it thinks could be magnetic - - elif overwrite_magmom_mode == "replace_all": + elif overwrite_magmom_mode == OverwriteMagmomMode.replace_all.value: magmoms[idx] = default_magmom # overwrite_magmom_mode = "normalize" set magmoms magnitude to 1 - - elif overwrite_magmom_mode == "normalize" and magmoms[idx] != 0: + elif overwrite_magmom_mode == OverwriteMagmomMode.normalize.value and magmoms[idx] != 0: magmoms[idx] = int(magmoms[idx] / abs(magmoms[idx])) # round magmoms, used to smooth out computational data @@ -315,11 +301,11 @@ def _round_magmoms(magmoms: ArrayLike, round_magmoms_mode: float) -> np.ndarray: kernel = gaussian_kde(magmoms, bw_method=round_magmoms_mode) # with a linearly spaced grid 1000x finer than width - xgrid = np.linspace(-range_m, range_m, int(1000 * range_m / round_magmoms_mode)) + x_grid = np.linspace(-range_m, range_m, int(1000 * range_m / round_magmoms_mode)) # and evaluate the kde on this grid, extracting the maxima of the kde peaks - kernel_m = kernel.evaluate(xgrid) - extrema = xgrid[argrelextrema(kernel_m, comparator=np.greater)] + kernel_m = kernel.evaluate(x_grid) + extrema = x_grid[argrelextrema(kernel_m, comparator=np.greater)] # round magmoms to these extrema magmoms = [extrema[(np.abs(extrema - m)).argmin()] for m in magmoms] @@ -414,8 +400,7 @@ def magmoms(self) -> np.ndarray: @property def types_of_magnetic_species(self) -> tuple[Element | Species | DummySpecies, ...]: - """Equivalent to Structure.types_of_specie but only returns - magnetic species. + """Equivalent to Structure.types_of_specie but only returns magnetic species. Returns: tuple: types of Species @@ -552,7 +537,7 @@ def matches_ordering(self, other: Structure) -> bool: Returns: bool: True if magnetic orderings match, False otherwise """ - a = CollinearMagneticStructureAnalyzer( + cmag_analyzer = CollinearMagneticStructureAnalyzer( self.structure, overwrite_magmom_mode="normalize" ).get_structure_with_spin() @@ -572,7 +557,7 @@ def matches_ordering(self, other: Structure) -> bool: b_positive = b_positive.get_structure_with_spin() analyzer = analyzer.get_structure_with_spin() - return a.matches(b_positive) or a.matches(analyzer) + return cmag_analyzer.matches(b_positive) or cmag_analyzer.matches(analyzer) def __str__(self): """ @@ -1039,8 +1024,8 @@ def _add_structures(ordered_structures, ordered_structures_origins, structures_t f"Removing {len(ordered_structures) - len(structs_to_keep)} low symmetry ordered structures" ) - ordered_structures = [ordered_structures[i] for i, _ in structs_to_keep] - ordered_structures_origins = [ordered_structures_origins[i] for i, _ in structs_to_keep] + ordered_structures = [ordered_structures[idx] for idx, _struct in structs_to_keep] + ordered_structures_origins = [ordered_structures_origins[idx] for idx, _struct in structs_to_keep] # and ensure fm is always at index 0 fm_index = ordered_structures_origins.index("fm") diff --git a/pymatgen/analysis/magnetism/heisenberg.py b/pymatgen/analysis/magnetism/heisenberg.py index 83cba241d1f..8d5922aa4db 100644 --- a/pymatgen/analysis/magnetism/heisenberg.py +++ b/pymatgen/analysis/magnetism/heisenberg.py @@ -9,6 +9,7 @@ import copy import logging from ast import literal_eval +from typing import TYPE_CHECKING import numpy as np import pandas as pd @@ -21,6 +22,9 @@ from pymatgen.core.structure import Structure from pymatgen.symmetry.analyzer import SpacegroupAnalyzer +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "ncfrey" __version__ = "0.1" __maintainer__ = "Nathan C. Frey" @@ -30,7 +34,18 @@ class HeisenbergMapper: - """Class to compute exchange parameters from low energy magnetic orderings.""" + """Class to compute exchange parameters from low energy magnetic orderings. + + Attributes: + strategy (object): Class from pymatgen.analysis.local_env for constructing graphs. + sgraphs (list): StructureGraph objects. + unique_site_ids (dict): Maps each site to its unique numerical identifier. + wyckoff_ids (dict): Maps unique numerical identifier to wyckoff position. + nn_interactions (dict): {i: j} pairs of NN interactions between unique sites. + dists (dict): NN, NNN, and NNNN interaction distances + ex_mat (DataFrame): Invertible Heisenberg Hamiltonian for each graph. + ex_params (dict): Exchange parameter values (meV/atom) + """ def __init__(self, ordered_structures, energies, cutoff=0, tol: float = 0.02): """ @@ -53,16 +68,6 @@ def __init__(self, ordered_structures, energies, cutoff=0, tol: float = 0.02): Defaults to 0 (only NN, no NNN, etc.) tol (float): Tolerance (in Angstrom) on nearest neighbor distances being equal. - - Parameters: - strategy (object): Class from pymatgen.analysis.local_env for constructing graphs. - sgraphs (list): StructureGraph objects. - unique_site_ids (dict): Maps each site to its unique numerical identifier. - wyckoff_ids (dict): Maps unique numerical identifier to wyckoff position. - nn_interactions (dict): {i: j} pairs of NN interactions between unique sites. - dists (dict): NN, NNN, and NNNN interaction distances - ex_mat (DataFrame): Invertible Heisenberg Hamiltonian for each graph. - ex_params (dict): Exchange parameter values (meV/atom) """ # Save original copies of inputs self.ordered_structures_ = ordered_structures @@ -111,7 +116,7 @@ def _get_graphs(cutoff, ordered_structures): strategy = MinimumDistanceNN(cutoff=cutoff, get_all_sites=True) if cutoff else MinimumDistanceNN() # only NN # Generate structure graphs - return [StructureGraph.with_local_env_strategy(s, strategy=strategy) for s in ordered_structures] + return [StructureGraph.from_local_env_strategy(s, strategy=strategy) for s in ordered_structures] @staticmethod def _get_unique_sites(structure): @@ -182,9 +187,9 @@ def _get_nn_dict(self): # Keep only up to NNNN and call dists equal if they are within tol all_dists = sorted(set(all_dists)) rm_list = [] - for idx, d in enumerate(all_dists[:-1]): - if abs(d - all_dists[idx + 1]) < tol: - rm_list.append(idx + 1) + for idx, d in enumerate(all_dists[:-1], start=1): + if abs(d - all_dists[idx]) < tol: + rm_list.append(idx) all_dists = [d for idx, d in enumerate(all_dists) if idx not in rm_list] @@ -537,7 +542,7 @@ def get_interaction_graph(self, filename=None): structure = self.ordered_structures[0] sgraph = self.sgraphs[0] - igraph = StructureGraph.with_empty_graph( + igraph = StructureGraph.from_empty_graph( structure, edge_weight_name="exchange_constant", edge_weight_units="meV" ) @@ -562,11 +567,10 @@ def get_interaction_graph(self, filename=None): # Save to a json file if desired if filename: - if filename.endswith(".json"): - dumpfn(igraph, filename) - else: + if not filename.endswith(".json"): filename += ".json" - dumpfn(igraph, filename) + + dumpfn(igraph, filename) return igraph @@ -724,8 +728,9 @@ def _do_cleanup(structures, energies): # Check for duplicate / degenerate states (sometimes different initial # configs relax to the same state) remove_list = [] + e_tol = 6 # 10^-6 eV/atom tol on energies + for idx, energy in enumerate(energies): - e_tol = 6 # 10^-6 eV/atom tol on energies energy = round(energy, e_tol) if idx not in remove_list: for i_check, e_check in enumerate(energies): @@ -735,16 +740,15 @@ def _do_cleanup(structures, energies): # Also discard structures with small |magmoms| < 0.1 uB # xx - get rid of these or just bury them in the list? - # for i, s in enumerate(ordered_structures): - # magmoms = s.site_properties['magmom'] - # if i not in remove_list: - # if any(abs(m) < 0.1 for m in magmoms): - # remove_list.append(i) + # for idx, struct in enumerate(ordered_structures): + # magmoms = struct.site_properties["magmom"] + # if idx not in remove_list and any(abs(m) < 0.1 for m in magmoms): + # remove_list.append(idx) # Remove duplicates - if len(remove_list) > 0: - ordered_structures = [s for i, s in enumerate(ordered_structures) if i not in remove_list] - energies = [e for i, e in enumerate(energies) if i not in remove_list] + if remove_list: + ordered_structures = [struct for idx, struct in enumerate(ordered_structures) if idx not in remove_list] + energies = [energy for idx, energy in enumerate(energies) if idx not in remove_list] # Sort by energy if not already sorted ordered_structures = [s for _, s in sorted(zip(energies, ordered_structures), reverse=False)] @@ -881,63 +885,63 @@ def as_dict(self): return dct @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """Create a HeisenbergModel from a dict.""" # Reconstitute the site ids usids = {} wids = {} nnis = {} - for k, v in d["nn_interactions"].items(): + for k, v in dct["nn_interactions"].items(): nn_dict = {} for k1, v1 in v.items(): key = literal_eval(k1) nn_dict[key] = v1 nnis[k] = nn_dict - for k, v in d["unique_site_ids"].items(): + for k, v in dct["unique_site_ids"].items(): key = literal_eval(k) if isinstance(key, int): usids[(key,)] = v elif isinstance(key, tuple): usids[key] = v - for k, v in d["wyckoff_ids"].items(): + for k, v in dct["wyckoff_ids"].items(): key = literal_eval(k) wids[key] = v # Reconstitute the structure and graph objects structures = [] sgraphs = [] - for v in d["structures"]: + for v in dct["structures"]: structures.append(Structure.from_dict(v)) - for v in d["sgraphs"]: + for v in dct["sgraphs"]: sgraphs.append(StructureGraph.from_dict(v)) # Interaction graph - igraph = StructureGraph.from_dict(d["igraph"]) + igraph = StructureGraph.from_dict(dct["igraph"]) # Reconstitute the exchange matrix DataFrame try: - ex_mat = eval(d["ex_mat"]) + ex_mat = literal_eval(dct["ex_mat"]) ex_mat = pd.DataFrame.from_dict(ex_mat) except SyntaxError: # if ex_mat is empty ex_mat = pd.DataFrame(columns=["E", "E0"]) return HeisenbergModel( - formula=d["formula"], + formula=dct["formula"], structures=structures, - energies=d["energies"], - cutoff=d["cutoff"], - tol=d["tol"], + energies=dct["energies"], + cutoff=dct["cutoff"], + tol=dct["tol"], sgraphs=sgraphs, unique_site_ids=usids, wyckoff_ids=wids, nn_interactions=nnis, - dists=d["dists"], + dists=dct["dists"], ex_mat=ex_mat, - ex_params=d["ex_params"], - javg=d["javg"], + ex_params=dct["ex_params"], + javg=dct["javg"], igraph=igraph, ) diff --git a/pymatgen/analysis/molecule_matcher.py b/pymatgen/analysis/molecule_matcher.py index e4056d31536..88b3115b8c2 100644 --- a/pymatgen/analysis/molecule_matcher.py +++ b/pymatgen/analysis/molecule_matcher.py @@ -17,6 +17,7 @@ import logging import math import re +from typing import TYPE_CHECKING import numpy as np from monty.dev import requires @@ -33,6 +34,9 @@ except ImportError: openbabel = None +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "Xiaohui Qu, Adam Fekete" __version__ = "1.0" @@ -80,7 +84,7 @@ def get_molecule_hash(self, mol): """ @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """ Args: dct (dict): Dict representation. @@ -166,17 +170,17 @@ def uniform_labels(self, mol1, mol2): def get_molecule_hash(self, mol): """Return inchi as molecular hash.""" - obconv = openbabel.OBConversion() - obconv.SetOutFormat("inchi") - obconv.AddOption("X", openbabel.OBConversion.OUTOPTIONS, "DoNotAddH") - inchi_text = obconv.WriteString(mol) + ob_conv = openbabel.OBConversion() + ob_conv.SetOutFormat("inchi") + ob_conv.AddOption("X", openbabel.OBConversion.OUTOPTIONS, "DoNotAddH") + inchi_text = ob_conv.WriteString(mol) match = re.search(r"InChI=(?P.+)\n", inchi_text) return match.group("inchi") def as_dict(self): """ Returns: - Jsonable dict. + JSON-able dict. """ return { "version": __version__, @@ -185,10 +189,10 @@ def as_dict(self): } @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """ Args: - d (dict): Dict representation. + dct (dict): Dict representation. Returns: IsomorphismMolAtomMapper @@ -220,15 +224,15 @@ def as_dict(self): } @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """ Args: - d (dict): Dict Representation. + dct (dict): Dict Representation. Returns: InchiMolAtomMapper """ - return cls(angle_tolerance=d["angle_tolerance"]) + return cls(angle_tolerance=dct["angle_tolerance"]) @staticmethod def _inchi_labels(mol): @@ -299,7 +303,7 @@ def _virtual_molecule(self, mol, ilabels, eq_atoms): farthest_group_idx: The equivalent atom group index in which there is the farthest atom to the centroid - Return: + Returns: The virtual molecule """ vmol = openbabel.OBMol() @@ -350,7 +354,7 @@ def _align_heavy_atoms(mol1, mol2, vmol1, vmol2, ilabel1, ilabel2, eq_atoms): ilabel2: inchi label map of the second molecule eq_atoms: equivalent atom labels - Return: + Returns: corrected inchi labels of heavy atoms of the second molecule """ n_virtual = vmol1.NumAtoms() @@ -423,7 +427,7 @@ def _align_hydrogen_atoms(mol1, mol2, heavy_indices1, heavy_indices2): heavy_indices1: inchi label map of the first molecule heavy_indices2: label map of the second molecule - Return: + Returns: corrected label map of all atoms of the second molecule """ num_atoms = mol2.NumAtoms() @@ -534,7 +538,7 @@ def uniform_labels(self, mol1, mol2): return None, None # Topologically different if iequal_atom1 != iequal_atom2: - raise Exception("Design Error! Equivalent atoms are inconsistent") + raise RuntimeError("Design Error! Equivalent atoms are inconsistent") vmol1 = self._virtual_molecule(ob_mol1, ilabel1, iequal_atom1) vmol2 = self._virtual_molecule(ob_mol2, ilabel2, iequal_atom2) @@ -707,17 +711,17 @@ def as_dict(self): } @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """ Args: - d (dict): Dict representation. + dct (dict): Dict representation. Returns: MoleculeMatcher """ return cls( - tolerance=d["tolerance"], - mapper=AbstractMolAtomMapper.from_dict(d["mapper"]), + tolerance=dct["tolerance"], + mapper=AbstractMolAtomMapper.from_dict(dct["mapper"]), ) diff --git a/pymatgen/analysis/molecule_structure_comparator.py b/pymatgen/analysis/molecule_structure_comparator.py index aea3d71c936..9e89308c370 100644 --- a/pymatgen/analysis/molecule_structure_comparator.py +++ b/pymatgen/analysis/molecule_structure_comparator.py @@ -11,11 +11,15 @@ from __future__ import annotations import itertools +from typing import TYPE_CHECKING from monty.json import MSONable from pymatgen.util.due import Doi, due +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "Xiaohui Qu" __copyright__ = "Copyright 2011, The Materials Project" __version__ = "1.0" @@ -263,10 +267,10 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """ Args: - d (dict): Dict representation. + dct (dict): Dict representation. Returns: MoleculeStructureComparator diff --git a/pymatgen/analysis/nmr.py b/pymatgen/analysis/nmr.py index b5a0c8cb31a..4f597ad2fba 100644 --- a/pymatgen/analysis/nmr.py +++ b/pymatgen/analysis/nmr.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections import namedtuple +from typing import TYPE_CHECKING import numpy as np @@ -11,6 +12,9 @@ from pymatgen.core.units import FloatWithUnit from pymatgen.util.due import Doi, due +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "Shyam Dwaraknath" __copyright__ = "Copyright 2016, The Materials Project" __version__ = "0.2" @@ -36,7 +40,7 @@ class ChemicalShielding(SquareTensor): MehringNotation = namedtuple("MehringNotation", "sigma_iso, sigma_11, sigma_22, sigma_33") MarylandNotation = namedtuple("MarylandNotation", "sigma_iso, omega, kappa") - def __new__(cls, cs_matrix, vscale=None): + def __new__(cls, cs_matrix, vscale=None) -> Self | None: # type: ignore[misc] """ Create a Chemical Shielding tensor. Note that the constructor uses __new__ @@ -100,7 +104,7 @@ def maryland_values(self): return self.MarylandNotation(sigma_iso, omega, kappa) @classmethod - def from_maryland_notation(cls, sigma_iso, omega, kappa): + def from_maryland_notation(cls, sigma_iso, omega, kappa) -> Self: """ Initialize from Maryland notation. @@ -126,7 +130,7 @@ class ElectricFieldGradient(SquareTensor): Authors: Shyam Dwaraknath, Xiaohui Qu """ - def __new__(cls, efg_matrix, vscale=None): + def __new__(cls, efg_matrix, vscale=None) -> Self | None: # type: ignore[misc] """ Create a Chemical Shielding tensor. Note that the constructor uses __new__ @@ -202,7 +206,7 @@ def coupling_constant(self, specie): Can take a isotope or element string, Species object, or Site object - Return: + Returns: the coupling constant as a FloatWithUnit in MHz """ planks_constant = FloatWithUnit(6.62607004e-34, "m^2 kg s^-1") @@ -215,16 +219,16 @@ def coupling_constant(self, specie): if len(specie.split("-")) > 1: isotope = str(specie) specie = Species(specie.split("-")[0]) - Q = specie.get_nmr_quadrupole_moment(isotope) + quad_pol_mom = specie.get_nmr_quadrupole_moment(isotope) else: specie = Species(specie) - Q = specie.get_nmr_quadrupole_moment() + quad_pol_mom = specie.get_nmr_quadrupole_moment() elif isinstance(specie, Site): specie = specie.specie - Q = specie.get_nmr_quadrupole_moment() + quad_pol_mom = specie.get_nmr_quadrupole_moment() elif isinstance(specie, Species): - Q = specie.get_nmr_quadrupole_moment() + quad_pol_mom = specie.get_nmr_quadrupole_moment() else: raise ValueError("Invalid species provided for quadrupolar coupling constant calculations") - return (e * Q * Vzz / planks_constant).to("MHz") + return (e * quad_pol_mom * Vzz / planks_constant).to("MHz") diff --git a/pymatgen/analysis/phase_diagram.py b/pymatgen/analysis/phase_diagram.py index f96695bb7be..397cd6cc90d 100644 --- a/pymatgen/analysis/phase_diagram.py +++ b/pymatgen/analysis/phase_diagram.py @@ -40,10 +40,11 @@ from io import StringIO from numpy.typing import ArrayLike + from typing_extensions import Self logger = logging.getLogger(__name__) -with open(os.path.join(os.path.dirname(__file__), "..", "util", "plotly_pd_layouts.json")) as file: +with open(os.path.join(os.path.dirname(__file__), "..", "util", "plotly_pd_layouts.json"), encoding="utf-8") as file: plotly_layouts = json.load(file) @@ -105,7 +106,7 @@ def as_dict(self): return return_dict @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """ Args: dct (dict): dictionary representation of PDEntry. @@ -199,17 +200,17 @@ def as_dict(self): } @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """ Args: - d (dict): dictionary representation of GrandPotPDEntry. + dct (dict): dictionary representation of GrandPotPDEntry. Returns: GrandPotPDEntry """ - chempots = {Element(symbol): u for symbol, u in d["chempots"].items()} - entry = MontyDecoder().process_decoded(d["entry"]) - return cls(entry, chempots, d["name"]) + chempots = {Element(symbol): u for symbol, u in dct["chempots"].items()} + entry = MontyDecoder().process_decoded(dct["entry"]) + return cls(entry, chempots, dct["name"]) class TransformedPDEntry(PDEntry): @@ -284,17 +285,17 @@ def as_dict(self): } @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """ Args: - d (dict): dictionary representation of TransformedPDEntry. + dct (dict): dictionary representation of TransformedPDEntry. Returns: TransformedPDEntry """ - sp_mapping = d["sp_mapping"] - del d["sp_mapping"] - entry = MontyDecoder().process_decoded(d) + sp_mapping = dct["sp_mapping"] + del dct["sp_mapping"] + entry = MontyDecoder().process_decoded(dct) return cls(entry, sp_mapping) @@ -408,10 +409,10 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct: dict[str, Any]) -> PhaseDiagram: + def from_dict(cls, dct: dict[str, Any]) -> Self: """ Args: - d (dict): dictionary representation of PhaseDiagram. + dct (dict): dictionary representation of PhaseDiagram. Returns: PhaseDiagram @@ -482,7 +483,7 @@ def _compute(self) -> dict[str, Any]: final_facets.append(facet) facets = final_facets - simplexes = [Simplex(qhull_data[f, :-1]) for f in facets] + simplexes = [Simplex(qhull_data[facet, :-1]) for facet in facets] self.elements = elements return { "facets": facets, @@ -617,10 +618,10 @@ def _get_facet_and_simplex(self, comp: Composition) -> tuple[Simplex, Simplex]: Args: comp (Composition): A composition """ - c = self.pd_coords(comp) - for f, s in zip(self.facets, self.simplexes): - if s.in_simplex(c, PhaseDiagram.numerical_tol / 10): - return f, s + coord = self.pd_coords(comp) + for facet, simplex in zip(self.facets, self.simplexes): + if simplex.in_simplex(coord, PhaseDiagram.numerical_tol / 10): + return facet, simplex raise RuntimeError(f"No facet found for {comp = }") @@ -631,10 +632,12 @@ def _get_all_facets_and_simplexes(self, comp): Args: comp (Composition): A composition """ - c = self.pd_coords(comp) + coords = self.pd_coords(comp) all_facets = [ - f for f, s in zip(self.facets, self.simplexes) if s.in_simplex(c, PhaseDiagram.numerical_tol / 10) + facet + for facet, simplex in zip(self.facets, self.simplexes) + if simplex.in_simplex(coords, PhaseDiagram.numerical_tol / 10) ] if not all_facets: @@ -1222,14 +1225,14 @@ def getmu_vertices_stability_phase(self, target_comp, dep_elt, tol_en=1e-2): which each element has a chemical potential set to a given value. "absolute" values (i.e., not referenced to element energies) """ - mu_ref = np.array([self.el_refs[e].energy_per_atom for e in self.elements if e != dep_elt]) - chempot_ranges = self.get_chempot_range_map([e for e in self.elements if e != dep_elt]) + mu_ref = np.array([self.el_refs[elem].energy_per_atom for elem in self.elements if elem != dep_elt]) + chempot_ranges = self.get_chempot_range_map([elem for elem in self.elements if elem != dep_elt]) for elem in self.elements: if elem not in target_comp.elements: target_comp = target_comp + Composition({elem: 0.0}) - coeff = [-target_comp[e] for e in self.elements if e != dep_elt] + coeff = [-target_comp[elem] for elem in self.elements if elem != dep_elt] for elem, chempots in chempot_ranges.items(): if elem.composition.reduced_composition == target_comp.reduced_composition: @@ -1238,7 +1241,7 @@ def getmu_vertices_stability_phase(self, target_comp, dep_elt, tol_en=1e-2): all_coords = [] for simplex in chempots: for v in simplex._coords: - elements = [e for e in self.elements if e != dep_elt] + elements = [elem for elem in self.elements if elem != dep_elt] res = {} for idx, el in enumerate(elements): res[el] = v[idx] + mu_ref[idx] @@ -1274,26 +1277,26 @@ def get_chempot_range_stability_phase(self, target_comp, open_elt): {Element: (mu_min, mu_max)}: Chemical potentials are given in "absolute" values (i.e., not referenced to 0) """ - muref = np.array([self.el_refs[e].energy_per_atom for e in self.elements if e != open_elt]) - chempot_ranges = self.get_chempot_range_map([e for e in self.elements if e != open_elt]) - for e in self.elements: - if e not in target_comp.elements: - target_comp = target_comp + Composition({e: 0.0}) + mu_ref = np.array([self.el_refs[elem].energy_per_atom for elem in self.elements if elem != open_elt]) + chempot_ranges = self.get_chempot_range_map([elem for elem in self.elements if elem != open_elt]) + for elem in self.elements: + if elem not in target_comp.elements: + target_comp = target_comp + Composition({elem: 0.0}) - coeff = [-target_comp[e] for e in self.elements if e != open_elt] + coeff = [-target_comp[elem] for elem in self.elements if elem != open_elt] max_open = -float("inf") min_open = float("inf") max_mus = min_mus = None - for e, chempots in chempot_ranges.items(): - if e.composition.reduced_composition == target_comp.reduced_composition: - multiplicator = e.composition[open_elt] / target_comp[open_elt] - ef = e.energy / multiplicator + for elem, chempots in chempot_ranges.items(): + if elem.composition.reduced_composition == target_comp.reduced_composition: + multiplier = elem.composition[open_elt] / target_comp[open_elt] + ef = elem.energy / multiplier all_coords = [] for s in chempots: for v in s._coords: all_coords.append(v) - test_open = (np.dot(v + muref, coeff) + ef) / target_comp[open_elt] + test_open = (np.dot(v + mu_ref, coeff) + ef) / target_comp[open_elt] if test_open > max_open: max_open = test_open max_mus = v @@ -1301,11 +1304,11 @@ def get_chempot_range_stability_phase(self, target_comp, open_elt): min_open = test_open min_mus = v - elems = [e for e in self.elements if e != open_elt] + elems = [elem for elem in self.elements if elem != open_elt] res = {} for idx, el in enumerate(elems): - res[el] = (min_mus[idx] + muref[idx], max_mus[idx] + muref[idx]) + res[el] = (min_mus[idx] + mu_ref[idx], max_mus[idx] + mu_ref[idx]) res[open_elt] = (min_open, max_open) return res @@ -1419,12 +1422,14 @@ def __init__(self, entries, chempots, elements=None, *, computed_data=None): when generated for the first time. """ if elements is None: - elements = {els for e in entries for els in e.elements} + elements = {els for entry in entries for els in entry.elements} self.chempots = {get_el_sp(el): u for el, u in chempots.items()} elements = set(elements) - set(self.chempots) - all_entries = [GrandPotPDEntry(e, self.chempots) for e in entries if len(elements.intersection(e.elements)) > 0] + all_entries = [ + GrandPotPDEntry(entry, self.chempots) for entry in entries if len(elements.intersection(entry.elements)) > 0 + ] super().__init__(all_entries, elements, computed_data=None) @@ -1447,23 +1452,23 @@ def as_dict(self): return { "@module": type(self).__module__, "@class": type(self).__name__, - "all_entries": [e.as_dict() for e in self.all_entries], + "all_entries": [entry.as_dict() for entry in self.all_entries], "chempots": self.chempots, - "elements": [e.as_dict() for e in self.elements], + "elements": [entry.as_dict() for entry in self.elements], } @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """ Args: - d (dict): dictionary representation of GrandPotentialPhaseDiagram. + dct (dict): dictionary representation of GrandPotentialPhaseDiagram. Returns: GrandPotentialPhaseDiagram """ - entries = MontyDecoder().process_decoded(d["all_entries"]) - elements = MontyDecoder().process_decoded(d["elements"]) - return cls(entries, d["chempots"], elements) + entries = MontyDecoder().process_decoded(dct["all_entries"]) + elements = MontyDecoder().process_decoded(dct["elements"]) + return cls(entries, dct["chempots"], elements) class CompoundPhaseDiagram(PhaseDiagram): @@ -1549,24 +1554,23 @@ def as_dict(self): return { "@module": type(self).__module__, "@class": type(self).__name__, - "original_entries": [e.as_dict() for e in self.original_entries], + "original_entries": [entry.as_dict() for entry in self.original_entries], "terminal_compositions": [c.as_dict() for c in self.terminal_compositions], "normalize_terminal_compositions": self.normalize_terminals, } @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """ Args: - d (dict): dictionary representation of CompoundPhaseDiagram. + dct (dict): dictionary representation of CompoundPhaseDiagram. Returns: CompoundPhaseDiagram """ - dec = MontyDecoder() - entries = dec.process_decoded(d["original_entries"]) - terminal_compositions = dec.process_decoded(d["terminal_compositions"]) - return cls(entries, terminal_compositions, d["normalize_terminal_compositions"]) + entries = MontyDecoder().process_decoded(dct["original_entries"]) + terminal_compositions = MontyDecoder().process_decoded(dct["terminal_compositions"]) + return cls(entries, terminal_compositions, dct["normalize_terminal_compositions"]) class PatchedPhaseDiagram(PhaseDiagram): @@ -1614,7 +1618,7 @@ def __init__( verbose (bool): Whether to show progress bar during convex hull construction. """ if elements is None: - elements = sorted({els for e in entries for els in e.elements}) + elements = sorted({els for entry in entries for els in entry.elements}) self.dim = len(elements) @@ -1639,7 +1643,10 @@ def __init__( raise ValueError(f"There are more terminal elements than dimensions: {extra}") data = np.array( - [[e.composition.get_atomic_fraction(el) for el in elements] + [e.energy_per_atom] for e in min_entries] + [ + [*(entry.composition.get_atomic_fraction(el) for el in elements), entry.energy_per_atom] + for entry in min_entries + ] ) # Use only entries with negative formation energy @@ -1653,7 +1660,7 @@ def __init__( self.qhull_entries = tuple(min_entries[idx] for idx in inds) # make qhull spaces frozensets since they become keys to self.pds dict and frozensets are hashable # prevent repeating elements in chemical space and avoid the ordering problem (i.e. Fe-O == O-Fe automatically) - self._qhull_spaces = tuple(frozenset(e.elements) for e in self.qhull_entries) + self._qhull_spaces = tuple(frozenset(entry.elements) for entry in self.qhull_entries) # Get all unique chemical spaces spaces = {s for s in self._qhull_spaces if len(s) > 1} @@ -1682,7 +1689,7 @@ def __init__( # NOTE add el_refs in case no multielement entries are present for el _stable_entries = {se for pd in self.pds.values() for se in pd._stable_entries} self._stable_entries = tuple(_stable_entries | {*self.el_refs.values()}) - self._stable_spaces = tuple(frozenset(e.elements) for e in self._stable_entries) + self._stable_spaces = tuple(frozenset(entry.elements) for entry in self._stable_entries) def __repr__(self): return f"{type(self).__name__} covering {len(self.spaces)} sub-spaces" @@ -1713,15 +1720,15 @@ def as_dict(self) -> dict[str, Any]: return { "@module": type(self).__module__, "@class": type(self).__name__, - "all_entries": [e.as_dict() for e in self.all_entries], - "elements": [e.as_dict() for e in self.elements], + "all_entries": [entry.as_dict() for entry in self.all_entries], + "elements": [entry.as_dict() for entry in self.elements], } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """ Args: - d (dict): dictionary representation of PatchedPhaseDiagram. + dct (dict): dictionary representation of PatchedPhaseDiagram. Returns: PatchedPhaseDiagram @@ -1938,7 +1945,7 @@ def fmt(fl): for face in itertools.combinations(facet, len(facet) - 1): face_entries = [pd.qhull_entries[idx] for idx in face] - if any(e.reduced_formula in terminal_formulas for e in face_entries): + if any(entry.reduced_formula in terminal_formulas for entry in face_entries): continue try: @@ -1973,8 +1980,8 @@ def fmt(fl): for c, entry in zip(coeffs[:-1], face_entries): if c > tol: - r = entry.composition.reduced_composition - products.append(f"{fmt(c / r.num_atoms * factor)} {r.reduced_formula}") + redu_comp = entry.composition.reduced_composition + products.append(f"{fmt(c / redu_comp.num_atoms * factor)} {redu_comp.reduced_formula}") product_entries.append((c, entry)) energy += c * entry.energy_per_atom @@ -1991,7 +1998,7 @@ def fmt(fl): form_1 = entry1.reduced_formula form_2 = entry2.reduced_formula logger.debug(f"Reactants = {form_1}, {form_2}") - logger.debug(f"Products = {', '.join([e.reduced_formula for e in face_entries])}") + logger.debug(f"Products = {', '.join([entry.reduced_formula for entry in face_entries])}") rxn_entries = sorted(rxn_entries, key=lambda e: e.name, reverse=True) @@ -1999,9 +2006,9 @@ def fmt(fl): self.entry2 = entry2 self.rxn_entries = rxn_entries self.labels = {} - for idx, entry in enumerate(rxn_entries): - self.labels[str(idx + 1)] = entry.attribute - entry.name = str(idx + 1) + for idx, entry in enumerate(rxn_entries, start=1): + self.labels[str(idx)] = entry.attribute + entry.name = str(idx) self.all_entries = all_entries self.pd = pd @@ -2179,7 +2186,7 @@ def __init__( self.ternary_style = ternary_style.lower() self.lines = uniquelines(self._pd.facets) if dim > 1 else [[self._pd.facets[0][0], self._pd.facets[0][0]]] - self._min_energy = min(self._pd.get_form_energy_per_atom(e) for e in self._pd.stable_entries) + self._min_energy = min(self._pd.get_form_energy_per_atom(entry) for entry in self._pd.stable_entries) self._dim = dim self.plotkwargs = plotkwargs or { @@ -2420,7 +2427,7 @@ class (pymatgen.analysis.chempot_diagram). contain_zero = any(comp.get_atomic_fraction(el) == 0 for el in elements) is_boundary = (not contain_zero) and sum(comp.get_atomic_fraction(el) for el in elements) == 1 for line in lines: - (x, y) = line.coords.transpose() + x, y = line.coords.transpose() plt.plot(x, y, "k-") for coord in line.coords: @@ -2724,10 +2731,10 @@ def _create_plotly_fill(self): c = [e0.composition[el_c], e1.composition[el_c], e2.composition[el_c]] e_strs = [] - for e in (e0, e1, e2): - if hasattr(e, "original_entry"): - e = e.original_entry - e_strs.append(htmlify(e.reduced_formula)) + for entry in (e0, e1, e2): + if hasattr(entry, "original_entry"): + entry = entry.original_entry + e_strs.append(htmlify(entry.reduced_formula)) name = f"{e_strs[0]}โ€”{e_strs[1]}โ€”{e_strs[2]}" @@ -2751,7 +2758,7 @@ def _create_plotly_fill(self): coords = np.array( [triangular_coord(c) for c in zip(self._pd.qhull_data[:-1, 0], self._pd.qhull_data[:-1, 1])] ) - energies = np.array([self._pd.get_form_energy_per_atom(e) for e in self._pd.qhull_entries]) + energies = np.array([self._pd.get_form_energy_per_atom(entry) for entry in self._pd.qhull_entries]) traces.append( go.Mesh3d( @@ -3479,7 +3486,7 @@ def _create_plotly_ternary_support_lines(self): """ stable_entry_coords = dict(map(reversed, self.pd_plot_data[1].items())) - elem_coords = [stable_entry_coords[e] for e in self._pd.el_refs.values()] + elem_coords = [stable_entry_coords[entry] for entry in self._pd.el_refs.values()] # add top and bottom triangle guidelines x, y, z = [], [], [] @@ -3756,11 +3763,7 @@ def uniquelines(q): setoflines: A set of tuple of lines. E.g., ((1,2), (1,3), (2,3), ....) """ - setoflines = set() - for facets in q: - for line in itertools.combinations(facets, 2): - setoflines.add(tuple(sorted(line))) - return setoflines + return {tuple(sorted(line)) for facets in q for line in itertools.combinations(facets, 2)} def triangular_coord(coord): @@ -3820,13 +3823,13 @@ def order_phase_diagram(lines, stable_entries, unstable_entries, ordering): 'Left','Right'] Returns: - (newlines, newstable_entries, newunstable_entries): - - newlines is a list of list of coordinates for lines in the PD. - - newstable_entries is a {coordinate : entry} for each stable node - in the phase diagram. (Each coordinate can only have one - stable phase) - - newunstable_entries is a {entry: coordinates} for all unstable - nodes in the phase diagram. + tuple[list, dict, dict]: + - new_lines is a list of list of coordinates for lines in the PD. + - new_stable_entries is a {coordinate: entry} for each stable node + in the phase diagram. (Each coordinate can only have one + stable phase) + - new_unstable_entries is a {entry: coordinates} for all unstable + nodes in the phase diagram. """ yup = -1000.0 xleft = 1000.0 @@ -3856,104 +3859,104 @@ def order_phase_diagram(lines, stable_entries, unstable_entries, ordering): # The coordinates were already in the user ordering return lines, stable_entries, unstable_entries - newlines = [[np.array(1 - x), y] for x, y in lines] - newstable_entries = {(1 - c[0], c[1]): entry for c, entry in stable_entries.items()} - newunstable_entries = {entry: (1 - c[0], c[1]) for entry, c in unstable_entries.items()} - return newlines, newstable_entries, newunstable_entries + new_lines = [[np.array(1 - x), y] for x, y in lines] + new_stable_entries = {(1 - c[0], c[1]): entry for c, entry in stable_entries.items()} + new_unstable_entries = {entry: (1 - c[0], c[1]) for entry, c in unstable_entries.items()} + return new_lines, new_stable_entries, new_unstable_entries if nameup == ordering[1]: if nameleft == ordering[2]: c120 = np.cos(2 * np.pi / 3.0) s120 = np.sin(2 * np.pi / 3.0) - newlines = [] + new_lines = [] for x, y in lines: newx = np.zeros_like(x) newy = np.zeros_like(y) for ii, xx in enumerate(x): newx[ii] = c120 * (xx - cc[0]) - s120 * (y[ii] - cc[1]) + cc[0] newy[ii] = s120 * (xx - cc[0]) + c120 * (y[ii] - cc[1]) + cc[1] - newlines.append([newx, newy]) - newstable_entries = { + new_lines.append([newx, newy]) + new_stable_entries = { ( c120 * (c[0] - cc[0]) - s120 * (c[1] - cc[1]) + cc[0], s120 * (c[0] - cc[0]) + c120 * (c[1] - cc[1]) + cc[1], ): entry for c, entry in stable_entries.items() } - newunstable_entries = { + new_unstable_entries = { entry: ( c120 * (c[0] - cc[0]) - s120 * (c[1] - cc[1]) + cc[0], s120 * (c[0] - cc[0]) + c120 * (c[1] - cc[1]) + cc[1], ) for entry, c in unstable_entries.items() } - return newlines, newstable_entries, newunstable_entries + return new_lines, new_stable_entries, new_unstable_entries c120 = np.cos(2 * np.pi / 3.0) s120 = np.sin(2 * np.pi / 3.0) - newlines = [] + new_lines = [] for x, y in lines: newx = np.zeros_like(x) newy = np.zeros_like(y) for ii, xx in enumerate(x): newx[ii] = -c120 * (xx - 1.0) - s120 * y[ii] + 1.0 newy[ii] = -s120 * (xx - 1.0) + c120 * y[ii] - newlines.append([newx, newy]) - newstable_entries = { + new_lines.append([newx, newy]) + new_stable_entries = { ( -c120 * (c[0] - 1.0) - s120 * c[1] + 1.0, -s120 * (c[0] - 1.0) + c120 * c[1], ): entry for c, entry in stable_entries.items() } - newunstable_entries = { + new_unstable_entries = { entry: ( -c120 * (c[0] - 1.0) - s120 * c[1] + 1.0, -s120 * (c[0] - 1.0) + c120 * c[1], ) for entry, c in unstable_entries.items() } - return newlines, newstable_entries, newunstable_entries + return new_lines, new_stable_entries, new_unstable_entries if nameup == ordering[2]: if nameleft == ordering[0]: c240 = np.cos(4 * np.pi / 3.0) s240 = np.sin(4 * np.pi / 3.0) - newlines = [] + new_lines = [] for x, y in lines: newx = np.zeros_like(x) newy = np.zeros_like(y) for ii, xx in enumerate(x): newx[ii] = c240 * (xx - cc[0]) - s240 * (y[ii] - cc[1]) + cc[0] newy[ii] = s240 * (xx - cc[0]) + c240 * (y[ii] - cc[1]) + cc[1] - newlines.append([newx, newy]) - newstable_entries = { + new_lines.append([newx, newy]) + new_stable_entries = { ( c240 * (c[0] - cc[0]) - s240 * (c[1] - cc[1]) + cc[0], s240 * (c[0] - cc[0]) + c240 * (c[1] - cc[1]) + cc[1], ): entry for c, entry in stable_entries.items() } - newunstable_entries = { + new_unstable_entries = { entry: ( c240 * (c[0] - cc[0]) - s240 * (c[1] - cc[1]) + cc[0], s240 * (c[0] - cc[0]) + c240 * (c[1] - cc[1]) + cc[1], ) for entry, c in unstable_entries.items() } - return newlines, newstable_entries, newunstable_entries + return new_lines, new_stable_entries, new_unstable_entries c240 = np.cos(4 * np.pi / 3.0) s240 = np.sin(4 * np.pi / 3.0) - newlines = [] + new_lines = [] for x, y in lines: newx = np.zeros_like(x) newy = np.zeros_like(y) for ii, xx in enumerate(x): newx[ii] = -c240 * xx - s240 * y[ii] newy[ii] = -s240 * xx + c240 * y[ii] - newlines.append([newx, newy]) - newstable_entries = { + new_lines.append([newx, newy]) + new_stable_entries = { (-c240 * c[0] - s240 * c[1], -s240 * c[0] + c240 * c[1]): entry for c, entry in stable_entries.items() } - newunstable_entries = { + new_unstable_entries = { entry: (-c240 * c[0] - s240 * c[1], -s240 * c[0] + c240 * c[1]) for entry, c in unstable_entries.items() } - return newlines, newstable_entries, newunstable_entries + return new_lines, new_stable_entries, new_unstable_entries raise ValueError("Invalid ordering.") diff --git a/pymatgen/analysis/piezo.py b/pymatgen/analysis/piezo.py index a29d393abf4..2f168e1eef4 100644 --- a/pymatgen/analysis/piezo.py +++ b/pymatgen/analysis/piezo.py @@ -3,11 +3,16 @@ from __future__ import annotations import warnings +from typing import TYPE_CHECKING import numpy as np from pymatgen.core.tensors import Tensor +if TYPE_CHECKING: + from numpy.typing import ArrayLike + from typing_extensions import Self + __author__ = "Shyam Dwaraknath" __copyright__ = "Copyright 2016, The Materials Project" __version__ = "1.0" @@ -20,7 +25,7 @@ class PiezoTensor(Tensor): """This class describes the 3x6 piezo tensor in Voigt notation.""" - def __new__(cls, input_array, tol: float = 1e-3): + def __new__(cls, input_array: ArrayLike, tol: float = 1e-3) -> Self: """ Create an PiezoTensor object. The constructor throws an error if the shape of the input_matrix argument is not 3x3x3, i. e. in true @@ -38,7 +43,7 @@ def __new__(cls, input_array, tol: float = 1e-3): return obj.view(cls) @classmethod - def from_vasp_voigt(cls, input_vasp_array): + def from_vasp_voigt(cls, input_vasp_array: ArrayLike) -> Self: """ Args: input_vasp_array (nd.array): Voigt form of tensor. diff --git a/pymatgen/analysis/piezo_sensitivity.py b/pymatgen/analysis/piezo_sensitivity.py index ad2fce63e09..5cb94c29863 100644 --- a/pymatgen/analysis/piezo_sensitivity.py +++ b/pymatgen/analysis/piezo_sensitivity.py @@ -64,7 +64,7 @@ def get_BEC_operations(self, eigtol=1e-5, opstol=1e-3): opstol (float): tolerance for determining if a symmetry operation relates two sites - Return: + Returns: list of symmetry operations mapping equivalent sites and the indexes of those sites. """ @@ -117,7 +117,7 @@ def get_rand_BEC(self, max_charge=1): Args: max_charge (float): maximum born effective charge value - Return: + Returns: np.array Born effective charge tensor """ n_atoms = len(self.structure) @@ -197,7 +197,7 @@ def get_IST_operations(self, opstol=1e-3): opstol (float): tolerance for determining if a symmetry operation relates two sites - Return: + Returns: list of symmetry operations mapping equivalent sites and the indexes of those sites. """ @@ -231,7 +231,7 @@ def get_rand_IST(self, max_force=1): Args: max_force (float): maximum born effective charge value - Return: + Returns: InternalStrainTensor """ n_atoms = len(self.structure) @@ -289,7 +289,7 @@ def get_FCM_operations(self, eigtol=1e-5, opstol=1e-5): opstol (float): tolerance for determining if a symmetry operation relates two sites - Return: + Returns: list of symmetry operations mapping equivalent sites and the indexes of those sites. """ @@ -357,7 +357,7 @@ def get_unstable_FCM(self, max_force=1): Args: max_charge (float): maximum born effective charge value - Return: + Returns: numpy array representing the force constant matrix """ struct = self.structure @@ -417,11 +417,10 @@ def get_symmetrized_FCM(self, unsymmetrized_fcm, max_force=1): unsymmetrized_fcm (numpy array): unsymmetrized force constant matrix max_charge (float): maximum born effective charge value - Return: + Returns: 3Nx3N numpy array representing the force constant matrix """ operations = self.FCM_operations - D = unsymmetrized_fcm for op in operations: same = 0 transpose = 0 @@ -430,24 +429,24 @@ def get_symmetrized_FCM(self, unsymmetrized_fcm, max_force=1): if op[0] == op[3] and op[1] == op[2]: transpose = 1 if transpose == 0 and same == 0: - D[3 * op[0] : 3 * op[0] + 3, 3 * op[1] : 3 * op[1] + 3] = np.zeros([3, 3]) + unsymmetrized_fcm[3 * op[0] : 3 * op[0] + 3, 3 * op[1] : 3 * op[1] + 3] = np.zeros([3, 3]) for symop in op[4]: - tempfcm = D[3 * op[2] : 3 * op[2] + 3, 3 * op[3] : 3 * op[3] + 3] + tempfcm = unsymmetrized_fcm[3 * op[2] : 3 * op[2] + 3, 3 * op[3] : 3 * op[3] + 3] tempfcm = symop.transform_tensor(tempfcm) - D[3 * op[0] : 3 * op[0] + 3, 3 * op[1] : 3 * op[1] + 3] += tempfcm + unsymmetrized_fcm[3 * op[0] : 3 * op[0] + 3, 3 * op[1] : 3 * op[1] + 3] += tempfcm if len(op[4]) != 0: - D[3 * op[0] : 3 * op[0] + 3, 3 * op[1] : 3 * op[1] + 3] = D[ + unsymmetrized_fcm[3 * op[0] : 3 * op[0] + 3, 3 * op[1] : 3 * op[1] + 3] = unsymmetrized_fcm[ 3 * op[0] : 3 * op[0] + 3, 3 * op[1] : 3 * op[1] + 3 ] / len(op[4]) - D[3 * op[1] : 3 * op[1] + 3, 3 * op[0] : 3 * op[0] + 3] = D[ + unsymmetrized_fcm[3 * op[1] : 3 * op[1] + 3, 3 * op[0] : 3 * op[0] + 3] = unsymmetrized_fcm[ 3 * op[0] : 3 * op[0] + 3, 3 * op[1] : 3 * op[1] + 3 ].T continue - temp_tensor = Tensor(D[3 * op[0] : 3 * op[0] + 3, 3 * op[1] : 3 * op[1] + 3]) + temp_tensor = Tensor(unsymmetrized_fcm[3 * op[0] : 3 * op[0] + 3, 3 * op[1] : 3 * op[1] + 3]) temp_tensor_sum = sum(temp_tensor.transform(symm_op) for symm_op in self.sharedops[op[0]][op[1]]) if len(self.sharedops[op[0]][op[1]]) != 0: temp_tensor_sum = temp_tensor_sum / (len(self.sharedops[op[0]][op[1]])) @@ -462,10 +461,10 @@ def get_symmetrized_FCM(self, unsymmetrized_fcm, max_force=1): else: temp_tensor_sum = (temp_tensor_sum + temp_tensor_sum.T) / 2 - D[3 * op[0] : 3 * op[0] + 3, 3 * op[1] : 3 * op[1] + 3] = temp_tensor_sum - D[3 * op[1] : 3 * op[1] + 3, 3 * op[0] : 3 * op[0] + 3] = temp_tensor_sum.T + unsymmetrized_fcm[3 * op[0] : 3 * op[0] + 3, 3 * op[1] : 3 * op[1] + 3] = temp_tensor_sum + unsymmetrized_fcm[3 * op[1] : 3 * op[1] + 3, 3 * op[0] : 3 * op[0] + 3] = temp_tensor_sum.T - return D + return unsymmetrized_fcm def get_stable_FCM(self, fcm, fcmasum=10): """ @@ -478,7 +477,7 @@ def get_stable_FCM(self, fcm, fcmasum=10): fcmasum (int): number of iterations to attempt to obey the acoustic sum rule - Return: + Returns: 3Nx3N numpy array representing the force constant matrix """ check = 0 @@ -526,7 +525,7 @@ def get_asum_FCM(self, fcm: np.ndarray, numiter: int = 15): numiter (int): number of iterations to attempt to obey the acoustic sum rule - Return: + Returns: numpy array representing the force constant matrix """ # set max force in reciprocal space @@ -617,7 +616,7 @@ def get_rand_FCM(self, asum=15, force=10): asum (int): number of iterations to attempt to obey the acoustic sum rule - Return: + Returns: NxNx3x3 np.array representing the force constant matrix """ from pymatgen.io.phonopy import get_phonopy_structure @@ -664,7 +663,7 @@ def get_piezo(BEC, IST, FCM, rcond=0.0001): FCM (numpy array): NxNx3x3 array representing the born effective charge tensor rcondy (float): condition for excluding eigenvalues in the pseudoinverse - Return: + Returns: 3x3x3 calculated Piezo tensor """ n_sites = len(BEC) @@ -694,7 +693,7 @@ def rand_piezo(struct, pointops, sharedops, BEC, IST, FCM, anumiter=10): IST (numpy array): Nx3x3x3 array representing the internal strain tensor FCM (numpy array): NxNx3x3 array representing the born effective charge tensor anumiter (int): number of iterations for acoustic sum rule convergence - Return: + Returns: list in the form of [Nx3x3 random born effective charge tenosr, Nx3x3x3 random internal strain tensor, NxNx3x3 random force constant matrix, 3x3x3 piezo tensor] """ diff --git a/pymatgen/analysis/pourbaix_diagram.py b/pymatgen/analysis/pourbaix_diagram.py index 8bd534b312b..0cb742a3779 100644 --- a/pymatgen/analysis/pourbaix_diagram.py +++ b/pymatgen/analysis/pourbaix_diagram.py @@ -32,6 +32,7 @@ if TYPE_CHECKING: import matplotlib.pyplot as plt + from typing_extensions import Self __author__ = "Sai Jayaraman" __copyright__ = "Copyright 2012, The Materials Project" @@ -129,20 +130,13 @@ def name(self): @property def energy(self): - """ - Returns (float): total energy of the Pourbaix - entry (at pH, V = 0 vs. SHE). - """ + """Total energy of the Pourbaix entry (at pH, V = 0 vs. SHE).""" # Note: this implicitly depends on formation energies as input return self.uncorrected_energy + self.conc_term - (MU_H2O * self.nH2O) @property def energy_per_atom(self): - """ - energy per atom of the Pourbaix entry. - - Returns (float): energy per atom - """ + """Energy per atom of the Pourbaix entry.""" return self.energy / self.composition.num_atoms @property @@ -225,12 +219,14 @@ def as_dict(self): return dct @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """Invokes a PourbaixEntry from a dictionary.""" - entry_type = d["entry_type"] - entry = IonEntry.from_dict(d["entry"]) if entry_type == "Ion" else MontyDecoder().process_decoded(d["entry"]) - entry_id = d["entry_id"] - concentration = d["concentration"] + entry_type = dct["entry_type"] + entry = ( + IonEntry.from_dict(dct["entry"]) if entry_type == "Ion" else MontyDecoder().process_decoded(dct["entry"]) + ) + entry_id = dct["entry_id"] + concentration = dct["concentration"] return cls(entry, entry_id, concentration) @property @@ -293,7 +289,7 @@ def __getattr__(self, attr): # Attributes that are just lists of entry attributes if attr in ["entry_id", "phase_type"]: - return [getattr(e, attr) for e in self.entry_list] + return [getattr(entry, attr) for entry in self.entry_list] # normalization_factor, num_atoms should work from superclass return self.__getattribute__(attr) @@ -301,7 +297,7 @@ def __getattr__(self, attr): @property def name(self): """MultiEntry name, i. e. the name of each entry joined by ' + '.""" - return " + ".join(e.name for e in self.entry_list) + return " + ".join(entry.name for entry in self.entry_list) def __repr__(self): energy, npH, nPhi, nH2O, entry_id = self.energy, self.npH, self.nPhi, self.nH2O, self.entry_id @@ -313,12 +309,12 @@ def as_dict(self): return { "@module": type(self).__module__, "@class": type(self).__name__, - "entry_list": [e.as_dict() for e in self.entry_list], + "entry_list": [entry.as_dict() for entry in self.entry_list], "weights": self.weights, } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """ Args: dct (dict): Dict representation. @@ -326,7 +322,7 @@ def from_dict(cls, dct): Returns: MultiEntry """ - entry_list = [PourbaixEntry.from_dict(entry) for entry in dct.get("entry_list")] + entry_list = [PourbaixEntry.from_dict(entry) for entry in dct.get("entry_list", ())] return cls(entry_list, dct.get("weights")) @@ -358,9 +354,9 @@ def __init__(self, ion: Ion, energy: float, name: str | None = None, attribute=N super().__init__(composition=ion.composition, energy=energy, name=name, attribute=attribute) @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """Returns an IonEntry object from a dict.""" - return cls(Ion.from_dict(d["ion"]), d["energy"], d.get("name"), d.get("attribute")) + return cls(Ion.from_dict(dct["ion"]), dct["energy"], dct.get("name"), dct.get("attribute")) def as_dict(self): """Creates a dict of composition, energy, and ion name.""" @@ -447,7 +443,7 @@ def __init__( if isinstance(entries[0], MultiEntry): self._processed_entries = entries # Extract individual entries - single_entries = list(set(itertools.chain.from_iterable([e.entry_list for e in entries]))) + single_entries = list(set(itertools.chain.from_iterable([entry.entry_list for entry in entries]))) self._unprocessed_entries = single_entries self._filtered_entries = single_entries self._conc_dict = None @@ -697,7 +693,7 @@ def process_multientry(entry_list, prod_comp, coeff_threshold=1e-4): # Note that we get reduced compositions for solids and non-reduced # compositions for ions because ions aren't normalized due to # their charge state. - entry_comps = [e.composition for e in entry_list] + entry_comps = [entry.composition for entry in entry_list] rxn = Reaction(entry_comps + dummy_oh, [prod_comp]) react_coeffs = [-coeff for coeff in rxn.coeffs[: len(entry_list)]] all_coeffs = [*react_coeffs, rxn.get_coeff(prod_comp)] @@ -805,7 +801,7 @@ def find_stable_entry(self, pH, V): Returns: PourbaixEntry: stable entry at pH, V """ - energies_at_conditions = [e.normalized_energy_at_conditions(pH, V) for e in self.stable_entries] + energies_at_conditions = [entry.normalized_energy_at_conditions(pH, V) for entry in self.stable_entries] return self.stable_entries[np.argmin(energies_at_conditions)] def get_decomposition_energy(self, entry, pH, V): @@ -851,7 +847,7 @@ def get_hull_energy(self, pH, V): Returns: np.array: minimum Pourbaix energy at conditions """ - all_gs = np.array([e.normalized_energy_at_conditions(pH, V) for e in self.stable_entries]) + all_gs = np.array([entry.normalized_energy_at_conditions(pH, V) for entry in self.stable_entries]) return np.min(all_gs, axis=0) def get_stable_entry(self, pH, V): @@ -866,7 +862,7 @@ def get_stable_entry(self, pH, V): PourbaixEntry | MultiEntry: Pourbaix or multi-entry corresponding to the minimum energy entry at a given pH, V condition """ - all_gs = np.array([e.normalized_energy_at_conditions(pH, V) for e in self.stable_entries]) + all_gs = np.array([entry.normalized_energy_at_conditions(pH, V) for entry in self.stable_entries]) return self.stable_entries[np.argmin(all_gs)] @property @@ -877,7 +873,7 @@ def stable_entries(self): @property def unstable_entries(self): """Returns all unstable entries in the Pourbaix diagram.""" - return [e for e in self.all_entries if e not in self.stable_entries] + return [entry for entry in self.all_entries if entry not in self.stable_entries] @property def all_entries(self): @@ -897,23 +893,28 @@ def as_dict(self): return { "@module": type(self).__module__, "@class": type(self).__name__, - "entries": [e.as_dict() for e in self._unprocessed_entries], + "entries": [entry.as_dict() for entry in self._unprocessed_entries], "comp_dict": self._elt_comp, "conc_dict": self._conc_dict, "filter_solids": self.filter_solids, } @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """ Args: - d (): Dict representation. + dct (dict): Dict representation. Returns: PourbaixDiagram """ - decoded_entries = MontyDecoder().process_decoded(d["entries"]) - return cls(decoded_entries, d.get("comp_dict"), d.get("conc_dict"), d.get("filter_solids")) + decoded_entries = MontyDecoder().process_decoded(dct["entries"]) + return cls( + decoded_entries, + comp_dict=dct.get("comp_dict"), + conc_dict=dct.get("conc_dict"), + filter_solids=bool(dct.get("filter_solids")), + ) class PourbaixPlotter: @@ -1077,7 +1078,7 @@ def generate_entry_label(entry): entry (PourbaixEntry or MultiEntry): entry to get a label for """ if isinstance(entry, MultiEntry): - return " + ".join(e.name for e in entry.entry_list) + return " + ".join(entry.name for entry in entry.entry_list) # TODO - a more elegant solution could be added later to Stringify # for example, the pattern re.sub(r"([-+][\d\.]*)", r"$^{\1}$", ) diff --git a/pymatgen/analysis/prototypes.py b/pymatgen/analysis/prototypes.py index 584330a0ef1..01ad61fcfd5 100644 --- a/pymatgen/analysis/prototypes.py +++ b/pymatgen/analysis/prototypes.py @@ -100,10 +100,11 @@ def get_prototypes(self, structure: Structure) -> list | None: Args: structure: structure to match - Returns (list): A list of dicts with keys 'snl' for the matched prototype and - 'tags', a dict of tags ('mineral', 'strukturbericht' and 'aflow') of that - prototype. This should be a list containing just a single entry, but it is - possible a material can match multiple prototypes. + Returns: + list | None: A list of dicts with keys 'snl' for the matched prototype and + 'tags', a dict of tags ('mineral', 'strukturbericht' and 'aflow') of that + prototype. This should be a list containing just a single entry, but it is + possible a material can match multiple prototypes. """ tags = self._match_single_prototype(structure) diff --git a/pymatgen/analysis/quasiharmonic.py b/pymatgen/analysis/quasiharmonic.py index 7aa82d77b70..63fea3bcc04 100644 --- a/pymatgen/analysis/quasiharmonic.py +++ b/pymatgen/analysis/quasiharmonic.py @@ -14,6 +14,7 @@ from collections import defaultdict import numpy as np +from monty.dev import deprecated from scipy.constants import physical_constants from scipy.integrate import quadrature from scipy.misc import derivative @@ -45,7 +46,7 @@ "temperature, and Grรผneisen parameter using a quasiharmonic Debye model", path="pymatgen.analysis.quasiharmonic", ) -class QuasiharmonicDebyeApprox: +class QuasiHarmonicDebyeApprox: """Quasi-harmonic approximation.""" def __init__( @@ -98,7 +99,7 @@ def __init__( "The Mie-Gruneisen formulation and anharmonic contribution are circular referenced and " "cannot be used together." ) - self.mass = sum(e.atomic_mass for e in self.structure.species) + self.mass = sum(spec.atomic_mass for spec in self.structure.species) self.natoms = self.structure.composition.num_atoms self.avg_mass = physical_constants["atomic mass constant"][0] * self.mass / self.natoms # kg self.kb = physical_constants["Boltzmann constant in eV/K"][0] @@ -225,7 +226,7 @@ def debye_temperature(self, volume: float) -> float: parameter at 0K (Gruneisen constant). The anharmonic contribution is toggled by setting the anharmonic_contribution - to True or False in the QuasiharmonicDebyeApprox constructor. + to True or False in the QuasiHarmonicDebyeApprox constructor. Args: volume (float): in Ang^3 @@ -356,3 +357,11 @@ def get_summary_dict(self): dct["gruneisen_parameter"].append(self.gruneisen_parameter(t, v)) dct["thermal_conductivity"].append(self.thermal_conductivity(t, v)) return dct + + +@deprecated( + replacement=QuasiHarmonicDebyeApprox, + message="Deprecated on 2024-03-27, to be removed on 2025-03-27.", +) +class QuasiharmonicDebyeApprox(QuasiHarmonicDebyeApprox): + pass diff --git a/pymatgen/analysis/quasirrho.py b/pymatgen/analysis/quasirrho.py index 8f62d8f8104..5498e5fda5b 100644 --- a/pymatgen/analysis/quasirrho.py +++ b/pymatgen/analysis/quasirrho.py @@ -10,14 +10,6 @@ from __future__ import annotations -__author__ = "Alex Epstein" -__copyright__ = "Copyright 2020, The Materials Project" -__version__ = "0.1" -__maintainer__ = "Alex Epstein" -__email__ = "aepstein@lbl.gov" -__date__ = "August 1, 2023" -__credits__ = "Ryan Kingsbury, Steven Wheeler, Trevor Seguin, Evan Spotte-Smith" - from math import isclose from typing import TYPE_CHECKING @@ -28,10 +20,20 @@ from pymatgen.util.due import Doi, due if TYPE_CHECKING: + from typing_extensions import Self + from pymatgen.core import Molecule from pymatgen.io.gaussian import GaussianOutput from pymatgen.io.qchem.outputs import QCOutput +__author__ = "Alex Epstein" +__copyright__ = "Copyright 2020, The Materials Project" +__version__ = "0.1" +__maintainer__ = "Alex Epstein" +__email__ = "aepstein@lbl.gov" +__date__ = "August 1, 2023" +__credits__ = "Ryan Kingsbury, Steven Wheeler, Trevor Seguin, Evan Spotte-Smith" + # Define useful constants kb = kb_ev * const.eV # Pymatgen kb [J/K] light_speed = const.speed_of_light * 100 # [cm/s] @@ -133,9 +135,8 @@ def __init__( self._get_quasirrho_thermo(mol=mol, mult=mult, frequencies=frequencies, elec_energy=energy, sigma_r=sigma_r) @classmethod - def from_gaussian_output(cls, output: GaussianOutput, **kwargs) -> QuasiRRHO: + def from_gaussian_output(cls, output: GaussianOutput, **kwargs) -> Self: """ - Args: output (GaussianOutput): Pymatgen GaussianOutput object @@ -145,13 +146,12 @@ def from_gaussian_output(cls, output: GaussianOutput, **kwargs) -> QuasiRRHO: mult = output.spin_multiplicity elec_e = output.final_energy mol = output.final_structure - vib_freqs = [f["frequency"] for f in output.frequencies[-1]] + vib_freqs = [freq["frequency"] for freq in output.frequencies[-1]] return cls(mol=mol, frequencies=vib_freqs, energy=elec_e, mult=mult, **kwargs) @classmethod - def from_qc_output(cls, output: QCOutput, **kwargs) -> QuasiRRHO: + def from_qc_output(cls, output: QCOutput, **kwargs) -> Self: """ - Args: output (QCOutput): Pymatgen QCOutput object diff --git a/pymatgen/analysis/reaction_calculator.py b/pymatgen/analysis/reaction_calculator.py index f63263a8615..cfcf0fe6a5c 100644 --- a/pymatgen/analysis/reaction_calculator.py +++ b/pymatgen/analysis/reaction_calculator.py @@ -5,7 +5,7 @@ import logging import re from itertools import chain, combinations -from typing import TYPE_CHECKING, no_type_check +from typing import TYPE_CHECKING, no_type_check, overload import numpy as np from monty.fractions import gcd_float @@ -18,6 +18,9 @@ if TYPE_CHECKING: from collections.abc import Mapping + from typing_extensions import Self + + from pymatgen.core import Element, Species from pymatgen.util.typing import CompositionLike __author__ = "Shyue Ping Ong, Anubhav Jain" @@ -68,7 +71,7 @@ def __init__( # calculate net reaction coefficients self._coeffs: list[float] = [] - self._els: list[str] = [] + self._els: list[Element | Species] = [] self._all_comp: list[Composition] = [] for key in {*reactants_coeffs, *products_coeffs}: coeff = products_coeffs.get(key, 0) - reactants_coeffs.get(key, 0) @@ -77,6 +80,32 @@ def __init__( self._all_comp += [key] self._coeffs += [coeff] + def __eq__(self, other: object) -> bool: + if not isinstance(other, type(self)): + return NotImplemented + for comp in self._all_comp: + coeff2 = other.get_coeff(comp) if comp in other._all_comp else 0 + if abs(self.get_coeff(comp) - coeff2) > self.TOLERANCE: + return False + return True + + def __hash__(self) -> int: + # Necessity for hash method is unclear (see gh-3673) + return hash((frozenset(self.reactants_coeffs.items()), frozenset(self.products_coeffs.items()))) + + def __str__(self): + return self._str_from_comp(self._coeffs, self._all_comp)[0] + + __repr__ = __str__ + + @overload + def calculate_energy(self, energies: dict[Composition, ufloat]) -> ufloat: + pass + + @overload + def calculate_energy(self, energies: dict[Composition, float]) -> float: + pass + def calculate_energy(self, energies): """ Calculates the energy of the reaction. @@ -90,7 +119,7 @@ def calculate_energy(self, energies): """ return sum(amt * energies[c] for amt, c in zip(self._coeffs, self._all_comp)) - def normalize_to(self, comp, factor=1): + def normalize_to(self, comp: Composition, factor: float = 1) -> None: """ Normalizes the reaction to one of the compositions. By default, normalizes such that the composition given has a @@ -103,7 +132,7 @@ def normalize_to(self, comp, factor=1): scale_factor = abs(1 / self._coeffs[self._all_comp.index(comp)] * factor) self._coeffs = [c * scale_factor for c in self._coeffs] - def normalize_to_element(self, element, factor=1): + def normalize_to_element(self, element: Species | Element, factor: float = 1) -> None: """ Normalizes the reaction to one of the elements. By default, normalizes such that the amount of the element is 1. @@ -119,7 +148,7 @@ def normalize_to_element(self, element, factor=1): scale_factor = factor / current_el_amount self._coeffs = [c * scale_factor for c in coeffs] - def get_el_amount(self, element): + def get_el_amount(self, element: Element | Species) -> float: """ Returns the amount of the element in the reaction. @@ -132,35 +161,35 @@ def get_el_amount(self, element): return sum(self._all_comp[i][element] * abs(self._coeffs[i]) for i in range(len(self._all_comp))) / 2 @property - def elements(self): + def elements(self) -> list[Element | Species]: """List of elements in the reaction.""" - return self._els[:] + return self._els @property - def coeffs(self): + def coeffs(self) -> list[float]: """Final coefficients of the calculated reaction.""" return self._coeffs[:] @property - def all_comp(self): + def all_comp(self) -> list[Composition]: """List of all compositions in the reaction.""" return self._all_comp @property - def reactants(self): + def reactants(self) -> list[Composition]: """List of reactants.""" return [self._all_comp[i] for i in range(len(self._all_comp)) if self._coeffs[i] < 0] @property - def products(self): + def products(self) -> list[Composition]: """List of products.""" return [self._all_comp[i] for i in range(len(self._all_comp)) if self._coeffs[i] > 0] - def get_coeff(self, comp): + def get_coeff(self, comp: Composition) -> float: """Returns coefficient for a particular composition.""" return self._coeffs[self._all_comp.index(comp)] - def normalized_repr_and_factor(self): + def normalized_repr_and_factor(self) -> tuple[str, float]: """ Normalized representation for a reaction For example, ``4 Li + 2 O -> 2Li2O`` becomes ``2 Li + O -> Li2O``. @@ -168,28 +197,15 @@ def normalized_repr_and_factor(self): return self._str_from_comp(self._coeffs, self._all_comp, reduce=True) @property - def normalized_repr(self): + def normalized_repr(self) -> str: """ A normalized representation of the reaction. All factors are converted to lowest common factors. """ return self.normalized_repr_and_factor()[0] - def __eq__(self, other: object) -> bool: - if not isinstance(other, type(self)): - return NotImplemented - for comp in self._all_comp: - coeff2 = other.get_coeff(comp) if comp in other._all_comp else 0 - if abs(self.get_coeff(comp) - coeff2) > self.TOLERANCE: - return False - return True - - def __hash__(self) -> int: - # Necessity for hash method is unclear (see gh-3673) - return hash((frozenset(self.reactants_coeffs.items()), frozenset(self.products_coeffs.items()))) - @classmethod - def _str_from_formulas(cls, coeffs, formulas): + def _str_from_formulas(cls, coeffs, formulas) -> str: reactant_str = [] product_str = [] for amt, formula in zip(coeffs, formulas): @@ -205,7 +221,7 @@ def _str_from_formulas(cls, coeffs, formulas): return f"{' + '.join(reactant_str)} -> {' + '.join(product_str)}" @classmethod - def _str_from_comp(cls, coeffs, compositions, reduce=False): + def _str_from_comp(cls, coeffs, compositions, reduce=False) -> tuple[str, float]: r_coeffs = np.zeros(len(coeffs)) r_formulas = [] for idx, (amt, comp) in enumerate(zip(coeffs, compositions)): @@ -219,22 +235,18 @@ def _str_from_comp(cls, coeffs, compositions, reduce=False): factor = 1 return cls._str_from_formulas(r_coeffs, r_formulas), factor - def __str__(self): - return self._str_from_comp(self._coeffs, self._all_comp)[0] - - __repr__ = __str__ - - def as_entry(self, energies): + def as_entry(self, energies) -> ComputedEntry: """ Returns a ComputedEntry representation of the reaction. """ relevant_comp = [comp * abs(coeff) for coeff, comp in zip(self._coeffs, self._all_comp)] - comp = sum(relevant_comp, Composition()) + comp: Composition = sum(relevant_comp, Composition()) # type: ignore[assignment] + entry = ComputedEntry(0.5 * comp, self.calculate_energy(energies)) entry.name = str(self) return entry - def as_dict(self): + def as_dict(self) -> dict: """ Returns: A dictionary representation of BalancedReaction. @@ -247,20 +259,20 @@ def as_dict(self): } @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """ Args: - d (dict): from as_dict(). + dct (dict): from as_dict(). Returns: A BalancedReaction object. """ - reactants = {Composition(comp): coeff for comp, coeff in d["reactants"].items()} - products = {Composition(comp): coeff for comp, coeff in d["products"].items()} + reactants = {Composition(comp): coeff for comp, coeff in dct["reactants"].items()} + products = {Composition(comp): coeff for comp, coeff in dct["products"].items()} return cls(reactants, products) @classmethod - def from_str(cls, rxn_str): + def from_str(cls, rxn_str: str) -> Self: """ Generates a balanced reaction from a string. The reaction must already be balanced. @@ -291,7 +303,7 @@ class Reaction(BalancedReaction): the *FIRST* product (or products, if underdetermined) has a coefficient of one. """ - def __init__(self, reactants, products): + def __init__(self, reactants: list[Composition], products: list[Composition]) -> None: """ Reactants and products to be specified as list of pymatgen.core.structure.Composition. e.g., [comp1, comp2]. @@ -365,11 +377,11 @@ def _balance_coeffs(self, comp_matrix, max_num_constraints): return np.squeeze(best_soln) - def copy(self): + def copy(self) -> Self: """Returns a copy of the Reaction object.""" return Reaction(self.reactants, self.products) - def as_dict(self): + def as_dict(self) -> dict: """ Returns: A dictionary representation of Reaction. @@ -382,7 +394,7 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """ Args: dct (dict): from as_dict(). @@ -401,7 +413,7 @@ class ReactionError(Exception): messages to cover situations not covered by standard exception classes. """ - def __init__(self, msg): + def __init__(self, msg: str) -> None: """ Create a ReactionError. @@ -410,7 +422,7 @@ def __init__(self, msg): """ self.msg = msg - def __str__(self): + def __str__(self) -> str: return self.msg @@ -421,7 +433,7 @@ class ComputedReaction(Reaction): energies. """ - def __init__(self, reactant_entries, product_entries): + def __init__(self, reactant_entries: list[ComputedEntry], product_entries: list[ComputedEntry]) -> None: """ Args: reactant_entries ([ComputedEntry]): List of reactant_entries. @@ -430,9 +442,9 @@ def __init__(self, reactant_entries, product_entries): self._reactant_entries = reactant_entries self._product_entries = product_entries self._all_entries = reactant_entries + product_entries - reactant_comp = [e.composition.reduced_composition for e in reactant_entries] + reactant_comp = [entry.composition.reduced_composition for entry in reactant_entries] - product_comp = [e.composition.reduced_composition for e in product_entries] + product_comp = [entry.composition.reduced_composition for entry in product_entries] super().__init__(list(reactant_comp), list(product_comp)) @@ -443,20 +455,20 @@ def all_entries(self): coefficients. """ entries = [] - for c in self._all_comp: - for e in self._all_entries: - if e.reduced_formula == c.reduced_formula: - entries.append(e) + for comp in self._all_comp: + for entry in self._all_entries: + if entry.reduced_formula == comp.reduced_formula: + entries.append(entry) break return entries @property - def calculated_reaction_energy(self): + def calculated_reaction_energy(self) -> float: """ - Returns (float): - The calculated reaction energy. + Returns: + float: The calculated reaction energy. """ - calc_energies = {} + calc_energies: dict[Composition, float] = {} for entry in self._reactant_entries + self._product_entries: comp, factor = entry.composition.get_reduced_composition_and_factor() @@ -465,12 +477,12 @@ def calculated_reaction_energy(self): return self.calculate_energy(calc_energies) @property - def calculated_reaction_energy_uncertainty(self): + def calculated_reaction_energy_uncertainty(self) -> float: """ Calculates the uncertainty in the reaction energy based on the uncertainty in the energies of the products and reactants. """ - calc_energies = {} + calc_energies: dict[Composition, float] = {} for entry in self._reactant_entries + self._product_entries: comp, factor = entry.composition.get_reduced_composition_and_factor() @@ -479,7 +491,7 @@ def calculated_reaction_energy_uncertainty(self): return self.calculate_energy(calc_energies).std_dev - def as_dict(self): + def as_dict(self) -> dict: """ Returns: A dictionary representation of ComputedReaction. @@ -487,20 +499,19 @@ def as_dict(self): return { "@module": type(self).__module__, "@class": type(self).__name__, - "reactants": [e.as_dict() for e in self._reactant_entries], - "products": [e.as_dict() for e in self._product_entries], + "reactants": [entry.as_dict() for entry in self._reactant_entries], + "products": [entry.as_dict() for entry in self._product_entries], } @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """ Args: - d (dict): from as_dict(). + dct (dict): from as_dict(). Returns: A ComputedReaction object. """ - dec = MontyDecoder() - reactants = [dec.process_decoded(e) for e in d["reactants"]] - products = [dec.process_decoded(e) for e in d["products"]] + reactants = [MontyDecoder().process_decoded(entry) for entry in dct["reactants"]] + products = [MontyDecoder().process_decoded(entry) for entry in dct["products"]] return cls(reactants, products) diff --git a/pymatgen/analysis/solar/slme.py b/pymatgen/analysis/solar/slme.py index e33a17f61ac..dcc7065c421 100644 --- a/pymatgen/analysis/solar/slme.py +++ b/pymatgen/analysis/solar/slme.py @@ -60,7 +60,7 @@ def to_matrix(xx, yy, zz, xy, yz, xz): xz (float): xz component of the matrix. Returns: - (np.array): The matrix, as a 3x3 numpy array. + np.array: The matrix, as a 3x3 numpy array. """ return np.array([[xx, xy, xz], [xy, yy, yz], [xz, yz, zz]]) @@ -78,7 +78,7 @@ def parse_dielectric_data(data): np.array: a Nx3 numpy array. Each row contains the eigenvalues for the corresponding row in `data`. """ - return np.array([np.linalg.eig(to_matrix(*e))[0] for e in data]) + return np.array([np.linalg.eig(to_matrix(*eps))[0] for eps in data]) def absorption_coefficient(dielectric): @@ -94,7 +94,7 @@ def absorption_coefficient(dielectric): - element 2: imaginary dielectric tensors, in ``[xx, yy, zz, xy, xz, yz]`` format. Returns: - (np.array): absorption coefficient using eV as frequency units (cm^-1). + np.array: absorption coefficient using eV as frequency units (cm^-1). """ energies_in_eV = np.array(dielectric[0]) real_dielectric = parse_dielectric_data(dielectric[1]) @@ -116,15 +116,15 @@ def absorption_coefficient(dielectric): def optics(path=""): """Helper function to calculate optical absorption coefficient.""" - dirgap, indirgap = get_dir_indir_gap(path) + dir_gap, indir_gap = get_dir_indir_gap(path) run = Vasprun(path, occu_tol=1e-2) new_en, new_abs = absorption_coefficient(run.dielectric) return ( np.array(new_en, dtype=np.float64), np.array(new_abs, dtype=np.float64), - dirgap, - indirgap, + dir_gap, + indir_gap, ) diff --git a/pymatgen/analysis/structure_analyzer.py b/pymatgen/analysis/structure_analyzer.py index 08baab80817..7984561366b 100644 --- a/pymatgen/analysis/structure_analyzer.py +++ b/pymatgen/analysis/structure_analyzer.py @@ -228,12 +228,12 @@ def get_percentage_bond_dist_changes(self, max_radius: float = 3.0) -> dict[int, reason to duplicate the information or computation. """ data: dict[int, dict[int, float]] = collections.defaultdict(dict) - for inds in itertools.combinations(list(range(len(self.initial))), 2): - (i, j) = sorted(inds) - initial_dist = self.initial[i].distance(self.initial[j]) + for indices in itertools.combinations(list(range(len(self.initial))), 2): + ii, jj = sorted(indices) + initial_dist = self.initial[ii].distance(self.initial[jj]) if initial_dist < max_radius: - final_dist = self.final[i].distance(self.final[j]) - data[i][j] = final_dist / initial_dist - 1 + final_dist = self.final[ii].distance(self.final[jj]) + data[ii][jj] = final_dist / initial_dist - 1 return data @@ -518,7 +518,7 @@ def sulfide_type(structure): structure (Structure): Input structure. Returns: - (str) sulfide/polysulfide or None if structure is a sulfate. + str: sulfide/polysulfide or None if structure is a sulfate. """ structure = structure.copy().remove_oxidation_states() sulphur = Element("S") @@ -549,7 +549,7 @@ def process_site(site): neighbors = sorted(neighbors, key=lambda n: n.nn_distance) dist = neighbors[0].nn_distance coord_elements = [nn.specie for nn in neighbors if nn.nn_distance < dist + 0.4][:4] - avg_electroneg = np.mean([e.X for e in coord_elements]) + avg_electroneg = np.mean([elem.X for elem in coord_elements]) if avg_electroneg > sulphur.X: return "sulfate" if avg_electroneg == sulphur.X and sulphur in coord_elements: diff --git a/pymatgen/analysis/structure_matcher.py b/pymatgen/analysis/structure_matcher.py index 772e41603bd..59b3a5a206b 100644 --- a/pymatgen/analysis/structure_matcher.py +++ b/pymatgen/analysis/structure_matcher.py @@ -17,6 +17,8 @@ if TYPE_CHECKING: from collections.abc import Mapping, Sequence + from typing_extensions import Self + from pymatgen.util.typing import SpeciesLike __author__ = "William Davidson Richards, Stephen Dacek, Shyue Ping Ong" @@ -75,9 +77,10 @@ def get_hash(self, composition): """ @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """ - :param d: Dict representation + Args: + dct (dict): Dict representation Returns: Comparator. @@ -87,11 +90,11 @@ def from_dict(cls, d): f"pymatgen.analysis.{trans_modules}", globals(), locals(), - [d["@class"]], + [dct["@class"]], 0, ) - if hasattr(mod, d["@class"]): - trans = getattr(mod, d["@class"]) + if hasattr(mod, dct["@class"]): + trans = getattr(mod, dct["@class"]) return trans() raise ValueError("Invalid Comparator dict") @@ -267,7 +270,10 @@ def are_equal(self, sp1, sp2) -> bool: def get_hash(self, composition): """ - :param composition: Composition. + Args: + composition: Composition. + + TODO: might need a proper hash method Returns: 1. Difficult to define sensible hash @@ -854,24 +860,25 @@ def as_dict(self): } @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """ - :param d: Dict representation + Args: + dct (dict): Dict representation Returns: StructureMatcher """ return cls( - ltol=d["ltol"], - stol=d["stol"], - angle_tol=d["angle_tol"], - primitive_cell=d["primitive_cell"], - scale=d["scale"], - attempt_supercell=d["attempt_supercell"], - allow_subset=d["allow_subset"], - comparator=AbstractComparator.from_dict(d["comparator"]), - supercell_size=d["supercell_size"], - ignored_species=d["ignored_species"], + ltol=dct["ltol"], + stol=dct["stol"], + angle_tol=dct["angle_tol"], + primitive_cell=dct["primitive_cell"], + scale=dct["scale"], + attempt_supercell=dct["attempt_supercell"], + allow_subset=dct["allow_subset"], + comparator=AbstractComparator.from_dict(dct["comparator"]), + supercell_size=dct["supercell_size"], + ignored_species=dct["ignored_species"], ) def _anonymous_match( @@ -962,9 +969,10 @@ def get_rms_anonymous(self, struct1, struct2): struct2 (Structure): 2nd structure Returns: - tuple: [min_rms, min_mapping]: min_rms is the minimum rms distance, and min_mapping is the - corresponding minimal species mapping that would map - struct1 to struct2. (None, None) is returned if the minimax_rms exceeds the threshold. + tuple[float, float] | tuple[None, None]: 1st element is min_rms, 2nd is min_mapping. + min_rms is the minimum RMS distance, and min_mapping is the corresponding + minimal species mapping that would map struct1 to struct2. (None, None) is + returned if the minimax_rms exceeds the threshold. """ struct1, struct2 = self._process_species([struct1, struct2]) struct1, struct2, fu, s1_supercell = self._preprocess(struct1, struct2) diff --git a/pymatgen/analysis/structure_prediction/dopant_predictor.py b/pymatgen/analysis/structure_prediction/dopant_predictor.py index 62adad5fd18..bb7b306d3e4 100644 --- a/pymatgen/analysis/structure_prediction/dopant_predictor.py +++ b/pymatgen/analysis/structure_prediction/dopant_predictor.py @@ -26,13 +26,13 @@ def get_dopants_from_substitution_probabilities(structure, num_dopants=5, thresh returned. Returns: - (dict): Dopant suggestions, given as a dictionary with keys "n_type" and - "p_type". The suggestions for each doping type are given as a list of - dictionaries, each with they keys: + dict: Dopant suggestions, given as a dictionary with keys "n_type" and + "p_type". The suggestions for each doping type are given as a list of + dictionaries, each with they keys: - - "probability": The probability of substitution. - - "dopant_species": The dopant species. - - "original_species": The substituted species. + - "probability": The probability of substitution. + - "dopant_species": The dopant species. + - "original_species": The substituted species. """ els_have_oxi_states = [hasattr(s, "oxi_state") for s in structure.species] @@ -72,15 +72,15 @@ def get_dopants_from_shannon_radii(bonded_structure, num_dopants=5, match_oxi_si returned. Returns: - (dict): Dopant suggestions, given as a dictionary with keys "n_type" and - "p_type". The suggestions for each doping type are given as a list of - dictionaries, each with they keys: + dict: Dopant suggestions, given as a dictionary with keys "n_type" and + "p_type". The suggestions for each doping type are given as a list of + dictionaries, each with they keys: - - "radii_diff": The difference between the Shannon radii of the species. - - "dopant_spcies": The dopant species. - - "original_species": The substituted species. + - "radii_diff": The difference between the Shannon radii of the species. + - "dopant_species": The dopant species. + - "original_species": The substituted species. """ - # get a list of all Species for all elements in all their common oxid states + # get a list of all Species for all elements in all their common oxidation states all_species = [Species(el, oxi) for el in Element for oxi in el.common_oxidation_states] # get a series of tuples with (coordination number, specie) @@ -164,13 +164,13 @@ def _shannon_radii_from_cn(species_list, cn_roman, radius_to_compare=0): shannon radii and this radius. Returns: - (list of dict): The Shannon radii for all Species in species. Formatted - as a list of dictionaries, with the keys: + list[dict]: The Shannon radii for all Species in species. Formatted + as a list of dictionaries, with the keys: - - "species": The species with charge state. - - "radius": The Shannon radius for the species. - - "radius_diff": The difference between the Shannon radius and the - radius_to_compare optional argument. + - "species": The species with charge state. + - "radius": The Shannon radius for the species. + - "radius_diff": The difference between the Shannon radius and the + radius_to_compare optional argument. """ shannon_radii = [] diff --git a/pymatgen/analysis/structure_prediction/substitution_probability.py b/pymatgen/analysis/structure_prediction/substitution_probability.py index b3c366dcbc1..2270def52d4 100644 --- a/pymatgen/analysis/structure_prediction/substitution_probability.py +++ b/pymatgen/analysis/structure_prediction/substitution_probability.py @@ -13,12 +13,16 @@ import os from collections import defaultdict from operator import mul +from typing import TYPE_CHECKING from monty.design_patterns import cached_class from pymatgen.core import Species, get_el_sp from pymatgen.util.due import Doi, due +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "Will Richards, Geoffroy Hautier" __copyright__ = "Copyright 2012, The Materials Project" __version__ = "1.2" @@ -168,7 +172,7 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """ Args: dct (dict): Dict representation. diff --git a/pymatgen/analysis/structure_prediction/substitutor.py b/pymatgen/analysis/structure_prediction/substitutor.py index e4ac2565262..1294eeb4bcd 100644 --- a/pymatgen/analysis/structure_prediction/substitutor.py +++ b/pymatgen/analysis/structure_prediction/substitutor.py @@ -6,6 +6,7 @@ import itertools import logging from operator import mul +from typing import TYPE_CHECKING from monty.json import MSONable @@ -17,6 +18,9 @@ from pymatgen.transformations.standard_transformations import SubstitutionTransformation from pymatgen.util.due import Doi, due +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "Will Richards, Geoffroy Hautier" __copyright__ = "Copyright 2012, The Materials Project" __version__ = "1.2" @@ -174,7 +178,7 @@ def pred_from_list(self, species_list): There are an exceptionally large number of substitutions to look at (260^n), where n is the number of species in the list. We need a more efficient than brute force way of going - through these possibilities. The brute force method would be:: + through these possibilities. The brute force method would be: output = [] for p in itertools.product(self._sp.species_list, repeat=len(species_list)): @@ -252,14 +256,14 @@ def as_dict(self): } @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """ Args: - d (dict): Dict representation. + dct (dict): Dict representation. Returns: Class """ - t = d["threshold"] - kwargs = d["kwargs"] + t = dct["threshold"] + kwargs = dct["kwargs"] return cls(threshold=t, **kwargs) diff --git a/pymatgen/analysis/surface_analysis.py b/pymatgen/analysis/surface_analysis.py index 7dc94233b65..9535116bbc0 100644 --- a/pymatgen/analysis/surface_analysis.py +++ b/pymatgen/analysis/surface_analysis.py @@ -1,7 +1,7 @@ """ This module defines tools to analyze surface and adsorption related quantities as well as related plots. If you use this module, please -consider citing the following works:: +consider citing the following works: R. Tran, Z. Xu, B. Radhakrishnan, D. Winston, W. Sun, K. A. Persson, S. P. Ong, "Surface Energies of Elemental Crystals", Scientific @@ -20,7 +20,7 @@ Computational Materials, 3(1), 14. https://doi.org/10.1038/s41524-017-0017-z -Todo: +TODO: - Still assumes individual elements have their own chempots in a molecular adsorbate instead of considering a single chempot for a single molecular adsorbate. E.g. for an OH @@ -38,6 +38,7 @@ import itertools import random import warnings +from typing import TYPE_CHECKING import matplotlib.pyplot as plt import numpy as np @@ -54,6 +55,9 @@ from pymatgen.util.due import Doi, due from pymatgen.util.plotting import pretty_plot +if TYPE_CHECKING: + from typing_extensions import Self + EV_PER_ANG2_TO_JOULES_PER_M2 = 16.0217656 __author__ = "Richard Tran" @@ -117,7 +121,7 @@ def __init__( """ self.miller_index = miller_index self.label = label - self.adsorbates = adsorbates if adsorbates else [] + self.adsorbates = adsorbates or [] self.clean_entry = clean_entry self.ads_entries_dict = {str(next(iter(ads.composition.as_dict()))): ads for ads in self.adsorbates} self.mark = marker @@ -172,10 +176,11 @@ def surface_energy(self, ucell_entry, ref_entries=None): of the element ref_entry that is not in the list will be treated as a variable. - Returns (Add (Sympy class)): Surface energy + Returns: + float: The surface energy of the slab. """ # Set up - ref_entries = ref_entries if ref_entries else [] + ref_entries = ref_entries or [] # Check if appropriate ref_entries are present if the slab is non-stoichiometric # TODO: There should be a way to identify which specific species are @@ -272,7 +277,7 @@ def Nsurfs_ads_in_slab(self): return n_surfs @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """Returns a SlabEntry by reading in an dictionary.""" structure = SlabEntry.from_dict(dct["structure"]) energy = SlabEntry.from_dict(dct["energy"]) @@ -325,7 +330,7 @@ def create_slab_label(self): @classmethod def from_computed_structure_entry( cls, entry, miller_index, label=None, adsorbates=None, clean_entry=None, **kwargs - ): + ) -> Self: """Returns SlabEntry from a ComputedStructureEntry.""" return cls( entry.structure, @@ -512,7 +517,7 @@ def wulff_from_chempot( Returns: WulffShape: The WulffShape at u_ref and u_ads. """ - latt = SpacegroupAnalyzer(self.ucell_entry.structure).get_conventional_standard_structure().lattice + lattice = SpacegroupAnalyzer(self.ucell_entry.structure).get_conventional_standard_structure().lattice miller_list = list(self.all_slab_entries) e_surf_list = [] @@ -529,7 +534,7 @@ def wulff_from_chempot( )[1] e_surf_list.append(gamma) - return WulffShape(latt, miller_list, e_surf_list, symprec=symprec) + return WulffShape(lattice, miller_list, e_surf_list, symprec=symprec) def area_frac_vs_chempot_plot( self, @@ -856,6 +861,7 @@ def chempot_vs_gamma_plot_one( """ delu_dict = delu_dict or {} chempot_range = sorted(chempot_range) + ax = ax or plt.gca() # use dashed lines for slabs that are not stoichiometric # w.r.t. bulk. Label with formula if non-stoichiometric @@ -879,9 +885,10 @@ def chempot_vs_gamma_plot_one( se_range = np.array(gamma_range) * EV_PER_ANG2_TO_JOULES_PER_M2 if JPERM2 else gamma_range - mark = entry.mark if entry.mark else mark - c = entry.color if entry.color else self.color_dict[entry] - return plt.plot(chempot_range, se_range, mark, color=c, label=label) + mark = entry.mark or mark + color = entry.color or self.color_dict[entry] + ax.plot(chempot_range, se_range, mark, color=color, label=label) + return ax def chempot_vs_gamma( self, @@ -936,7 +943,7 @@ def chempot_vs_gamma( delu_dict = {} chempot_range = sorted(chempot_range) - plt = plt if plt else pretty_plot(width=8, height=7) + plt = plt or pretty_plot(width=8, height=7) axes = plt.gca() for hkl in self.all_slab_entries: @@ -1170,7 +1177,7 @@ def surface_chempot_range_map( """ # Set up delu_dict = delu_dict or {} - ax = ax if ax else pretty_plot(12, 8) + ax = ax or pretty_plot(12, 8) el1, el2 = str(elements[0]), str(elements[1]) delu1 = Symbol(f"delu_{elements[0]}") delu2 = Symbol(f"delu_{elements[1]}") @@ -1250,7 +1257,7 @@ def surface_chempot_range_map( # Label the phases x = np.mean([max(xvals), min(xvals)]) y = np.mean([max(yvals), min(yvals)]) - label = entry.label if entry.label else entry.reduced_formula + label = entry.label or entry.reduced_formula ax.annotate(label, xy=[x, y], xytext=[x, y], fontsize=fontsize) # Label plot @@ -1299,7 +1306,7 @@ def entry_dict_from_list(all_slab_entries): all_slab_entries (list): List of SlabEntry objects Returns: - (dict): Dictionary of SlabEntry with the Miller index as the main + dict: Dictionary of SlabEntry with the Miller index as the main key to a dictionary with a clean SlabEntry as the key to a list of adsorbed SlabEntry. """ @@ -1309,7 +1316,7 @@ def entry_dict_from_list(all_slab_entries): hkl = tuple(entry.miller_index) if hkl not in entry_dict: entry_dict[hkl] = {} - clean = entry.clean_entry if entry.clean_entry else entry + clean = entry.clean_entry or entry if clean not in entry_dict[hkl]: entry_dict[hkl][clean] = [] if entry.adsorbates: @@ -1420,7 +1427,7 @@ def get_locpot_along_slab_plot(self, label_energies=True, plt=None, label_fontsi Returns plt of the locpot vs c axis """ - plt = plt if plt else pretty_plot(width=6, height=4) + plt = plt or pretty_plot(width=6, height=4) # plot the raw locpot signal along c plt.plot(self.along_c, self.locpot_along_c, "b--") @@ -1564,7 +1571,7 @@ def is_converged(self, min_points_frac=0.015, tol: float = 0.0025): return all(all_flat) @classmethod - def from_files(cls, poscar_filename, locpot_filename, outcar_filename, shift=0, blength=3.5): + def from_files(cls, poscar_filename, locpot_filename, outcar_filename, shift=0, blength=3.5) -> Self: """ Initializes a WorkFunctionAnalyzer from POSCAR, LOCPOT, and OUTCAR files. @@ -1650,13 +1657,13 @@ def solve_equilibrium_point(self, analyzer1, analyzer2, delu_dict=None, delu_def # Now calculate r delta_gamma = wulff1.weighted_surface_energy - wulff2.weighted_surface_energy delta_E = self.bulk_gform(analyzer1.ucell_entry) - self.bulk_gform(analyzer2.ucell_entry) - r = (-3 * delta_gamma) / (delta_E) + radius = (-3 * delta_gamma) / (delta_E) - return r / 10 if units == "nanometers" else r + return radius / 10 if units == "nanometers" else radius def wulff_gform_and_r( self, - wulffshape, + wulff_shape, bulk_entry, r, from_sphere_area=False, @@ -1669,7 +1676,7 @@ def wulff_gform_and_r( Calculates the formation energy of the particle with arbitrary radius r. Args: - wulffshape (WulffShape): Initial, unscaled WulffShape + wulff_shape (WulffShape): Initial unscaled WulffShape bulk_entry (ComputedStructureEntry): Entry of the corresponding bulk. r (float (Ang)): Arbitrary effective radius of the WulffShape from_sphere_area (bool): There are two ways to calculate the bulk @@ -1685,8 +1692,8 @@ def wulff_gform_and_r( particle formation energy (float in keV), effective radius """ # Set up - miller_se_dict = wulffshape.miller_energy_dict - new_wulff = self.scaled_wulff(wulffshape, r) + miller_se_dict = wulff_shape.miller_energy_dict + new_wulff = self.scaled_wulff(wulff_shape, r) new_wulff_area = new_wulff.miller_area_dict # calculate surface energy of the particle @@ -1703,7 +1710,7 @@ def wulff_gform_and_r( # By approximating the particle as a perfect sphere w_vol = (4 / 3) * np.pi * r**3 sphere_sa = 4 * np.pi * r**2 - tot_wulff_se = wulffshape.weighted_surface_energy * sphere_sa + tot_wulff_se = wulff_shape.weighted_surface_energy * sphere_sa Ebulk = self.bulk_gform(bulk_entry) * w_vol new_r = r @@ -1730,7 +1737,7 @@ def bulk_gform(bulk_entry): """ return bulk_entry.energy / bulk_entry.structure.volume - def scaled_wulff(self, wulffshape, r): + def scaled_wulff(self, wulff_shape, r): """ Scales the Wulff shape with an effective radius r. Note that the resulting Wulff does not necessarily have the same effective radius as the one @@ -1739,22 +1746,22 @@ def scaled_wulff(self, wulffshape, r): multiplied by the given effective radius. Args: - wulffshape (WulffShape): Initial, unscaled WulffShape + wulff_shape (WulffShape): Initial, unscaled WulffShape r (float): Arbitrary effective radius of the WulffShape Returns: WulffShape (scaled by r) """ # get the scaling ratio for the energies - r_ratio = r / wulffshape.effective_radius - miller_list = list(wulffshape.miller_energy_dict) + r_ratio = r / wulff_shape.effective_radius + miller_list = list(wulff_shape.miller_energy_dict) # Normalize the magnitude of the facet normal vectors # of the Wulff shape by the minimum surface energy. - se_list = np.array(list(wulffshape.miller_energy_dict.values())) + se_list = np.array(list(wulff_shape.miller_energy_dict.values())) # Scale the magnitudes by r_ratio scaled_se = se_list * r_ratio - return WulffShape(wulffshape.lattice, miller_list, scaled_se, symprec=self.symprec) + return WulffShape(wulff_shape.lattice, miller_list, scaled_se, symprec=self.symprec) def plot_one_stability_map( self, @@ -1764,7 +1771,7 @@ def plot_one_stability_map( label="", increments=50, delu_default=0, - plt=None, + ax=None, from_sphere_area=False, e_units="keV", r_units="nanometers", @@ -1792,17 +1799,20 @@ def plot_one_stability_map( r_units (str): Can be nanometers or Angstrom e_units (str): Can be keV or eV normalize (str): Whether or not to normalize energy by volume + + Returns: + plt.Axes: matplotlib Axes object """ - plt = plt or pretty_plot(width=8, height=7) + ax = ax or pretty_plot(width=8, height=7) - wulffshape = analyzer.wulff_from_chempot(delu_dict=delu_dict, delu_default=delu_default, symprec=self.symprec) + wulff_shape = analyzer.wulff_from_chempot(delu_dict=delu_dict, delu_default=delu_default, symprec=self.symprec) gform_list, r_list = [], [] - for r in np.linspace(1e-6, max_r, increments): - gform, r = self.wulff_gform_and_r( - wulffshape, + for radius in np.linspace(1e-6, max_r, increments): + gform, radius = self.wulff_gform_and_r( + wulff_shape, analyzer.ucell_entry, - r, + radius, from_sphere_area=from_sphere_area, r_units=r_units, e_units=e_units, @@ -1810,16 +1820,16 @@ def plot_one_stability_map( scale_per_atom=scale_per_atom, ) gform_list.append(gform) - r_list.append(r) + r_list.append(radius) ru = "nm" if r_units == "nanometers" else r"\AA" - plt.xlabel(rf"Particle radius (${ru}$)") + ax.xlabel(rf"Particle radius (${ru}$)") eu = f"${e_units}/{ru}^3$" - plt.ylabel(rf"$G_{{form}}$ ({eu})") + ax.ylabel(rf"$G_{{form}}$ ({eu})") - plt.plot(r_list, gform_list, label=label) + ax.plot(r_list, gform_list, label=label) - return plt + return ax def plot_all_stability_map( self, @@ -1827,7 +1837,7 @@ def plot_all_stability_map( increments=50, delu_dict=None, delu_default=0, - plt=None, + ax=None, labels=None, from_sphere_area=False, e_units="keV", @@ -1852,17 +1862,20 @@ def plot_all_stability_map( from_sphere_area (bool): There are two ways to calculate the bulk formation energy. Either by treating the volume and thus surface area of the particle as a perfect sphere, or as a Wulff shape. + + Returns: + plt.Axes: matplotlib Axes object """ - plt = plt or pretty_plot(width=8, height=7) + ax = ax or pretty_plot(width=8, height=7) - for i, analyzer in enumerate(self.se_analyzers): - label = labels[i] if labels else "" - plt = self.plot_one_stability_map( + for idx, analyzer in enumerate(self.se_analyzers): + label = labels[idx] if labels else "" + ax = self.plot_one_stability_map( analyzer, max_r, delu_dict, label=label, - plt=plt, + ax=ax, increments=increments, delu_default=delu_default, from_sphere_area=from_sphere_area, @@ -1872,7 +1885,7 @@ def plot_all_stability_map( scale_per_atom=scale_per_atom, ) - return plt + return ax def sub_chempots(gamma_dict, chempots): diff --git a/pymatgen/analysis/thermochemistry.py b/pymatgen/analysis/thermochemistry.py index 0dd3eb8a6a9..5c4ec46b721 100644 --- a/pymatgen/analysis/thermochemistry.py +++ b/pymatgen/analysis/thermochemistry.py @@ -2,8 +2,13 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from pymatgen.core.composition import Composition +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "Shyue Ping Ong" __copyright__ = "Copyright 2012, The Materials Project" __version__ = "0.1" @@ -63,24 +68,24 @@ def __init__( self.uncertainty = uncertainty @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """ Args: - d (dict): Dict representation. + dct (dict): Dict representation. Returns: ThermoData """ - return ThermoData( - d["type"], - d["compound_name"], - d["phaseinfo"], - d["formula"], - d["value"], - d["ref"], - d["method"], - d["temp_range"], - d.get("uncertainty"), + return cls( + dct["type"], + dct["compound_name"], + dct["phaseinfo"], + dct["formula"], + dct["value"], + dct["ref"], + dct["method"], + dct["temp_range"], + dct.get("uncertainty"), ) def as_dict(self): diff --git a/pymatgen/analysis/topological/spillage.py b/pymatgen/analysis/topological/spillage.py index a1bd2f33416..ed1b9300bb3 100644 --- a/pymatgen/analysis/topological/spillage.py +++ b/pymatgen/analysis/topological/spillage.py @@ -14,7 +14,7 @@ class SOCSpillage: """ - Spin-orbit spillage criteria to predict whether a material is topologically non-trival. + Spin-orbit spillage criteria to predict whether a material is topologically non-trivial. The spillage criteria physically signifies number of band-inverted electrons. A non-zero, high value (generally >0.5) suggests non-trivial behavior. """ @@ -39,28 +39,28 @@ def isclose(n1, n2, rel_tol=1e-7): def orth(A): """Helper function to create orthonormal basis.""" u, s, _vh = np.linalg.svd(A, full_matrices=False) - M, N = A.shape + n_rows, n_cols = A.shape eps = np.finfo(float).eps - tol = max(M, N) * np.amax(s) * eps + tol = max(n_rows, n_cols) * np.amax(s) * eps num = np.sum(s > tol, dtype=int) - Q = u[:, :num] - return Q, num + orthonormal_basis = u[:, :num] + return orthonormal_basis, num def overlap_so_spinpol(self): """Main function to calculate SOC spillage.""" - noso = Wavecar(self.wf_noso) + no_so = Wavecar(self.wf_noso) so = Wavecar(self.wf_so) - bcell = np.linalg.inv(noso.a).T - tmp = np.linalg.norm(np.dot(np.diff(noso.kpoints, axis=0), bcell), axis=1) + b_cell = np.linalg.inv(no_so.a).T + tmp = np.linalg.norm(np.dot(np.diff(no_so.kpoints, axis=0), b_cell), axis=1) noso_k = np.concatenate(([0], np.cumsum(tmp))) - noso_bands = np.array(noso.band_energy)[:, :, :, 0] - noso_kvecs = np.array(noso.kpoints) - noso_occs = np.array(noso.band_energy)[:, :, :, 2] - noso_nkpts = len(noso_k) + noso_bands = np.array(no_so.band_energy)[:, :, :, 0] + noso_kvecs = np.array(no_so.kpoints) + noso_occs = np.array(no_so.band_energy)[:, :, :, 2] + n_kpts_noso = len(noso_k) - bcell = np.linalg.inv(so.a).T - tmp = np.linalg.norm(np.dot(np.diff(so.kpoints, axis=0), bcell), axis=1) + b_cell = np.linalg.inv(so.a).T + tmp = np.linalg.norm(np.dot(np.diff(so.kpoints, axis=0), b_cell), axis=1) so_k = np.concatenate(([0], np.cumsum(tmp))) so_bands = np.array([np.array(so.band_energy)[:, :, 0]]) so_kvecs = np.array(so.kpoints) @@ -68,7 +68,7 @@ def overlap_so_spinpol(self): so_nkpts = len(so_k) nelec_list = [] - for nk1 in range(1, noso_nkpts + 1): # no spin orbit kpoints loop + for nk1 in range(1, n_kpts_noso + 1): # no spin orbit kpoints loop knoso = noso_kvecs[nk1 - 1, :] for nk2 in range(1, so_nkpts + 1): # spin orbit kso = so_kvecs[nk2 - 1, :] @@ -126,7 +126,7 @@ def overlap_so_spinpol(self): y = [] nelec_tot = 0.0 - for nk1 in range(1, noso_nkpts + 1): # no spin orbit kpoints loop + for nk1 in range(1, n_kpts_noso + 1): # no spin orbit kpoints loop knoso = noso_kvecs[nk1 - 1, :] for nk2 in range(1, so_nkpts + 1): # spin orbit kso = so_kvecs[nk2 - 1, :] @@ -141,11 +141,11 @@ def overlap_so_spinpol(self): kpoints.append(kso) Mmn = 0.0 vnoso = np.array( - noso.coeffs[0][nk1 - 1][0] + no_so.coeffs[0][nk1 - 1][0] ) # noso.readBandCoeff(ispin=1, ikpt=nk1, iband=1, norm=False) n_noso1 = vnoso.shape[0] vnoso = np.array( - noso.coeffs[1][nk1 - 1][0] + no_so.coeffs[1][nk1 - 1][0] ) # noso.readBandCoeff(ispin=2, ikpt=nk1, iband=1, norm=False) # n_noso2 = vnoso.shape[0] vso = so.coeffs[nk1 - 1][0].flatten() # so.readBandCoeff(ispin=1, ikpt=nk2, iband=1, norm=False) @@ -155,13 +155,13 @@ def overlap_so_spinpol(self): Vnoso = np.zeros((vs, nelec_tot), dtype=complex) Vso = np.zeros((vs, nelec_tot), dtype=complex) - if np.array(noso.coeffs[1][nk1 - 1]).shape[1] == vs // 2: + if np.array(no_so.coeffs[1][nk1 - 1]).shape[1] == vs // 2: # if nk1==10 and nk2==10: # prepare matrices for n1 in range(1, nelec_up + 1): - Vnoso[0 : vs // 2, n1 - 1] = np.array(noso.coeffs[0][nk1 - 1][n1 - 1])[0 : vs // 2] + Vnoso[0 : vs // 2, n1 - 1] = np.array(no_so.coeffs[0][nk1 - 1][n1 - 1])[0 : vs // 2] for n1 in range(1, nelec_dn + 1): - Vnoso[vs // 2 : vs, n1 - 1 + nelec_up] = np.array(noso.coeffs[1][nk1 - 1][n1 - 1])[ + Vnoso[vs // 2 : vs, n1 - 1 + nelec_up] = np.array(no_so.coeffs[1][nk1 - 1][n1 - 1])[ 0 : vs // 2 ] for n1 in range(1, nelec_tot + 1): diff --git a/pymatgen/analysis/transition_state.py b/pymatgen/analysis/transition_state.py index 542aadf0d49..56be52650a9 100644 --- a/pymatgen/analysis/transition_state.py +++ b/pymatgen/analysis/transition_state.py @@ -10,6 +10,7 @@ import os from glob import glob +from typing import TYPE_CHECKING import matplotlib.pyplot as plt import numpy as np @@ -21,6 +22,9 @@ from pymatgen.io.vasp import Outcar from pymatgen.util.plotting import pretty_plot +if TYPE_CHECKING: + from typing_extensions import Self + class NEBAnalysis(MSONable): """An NEBAnalysis class.""" @@ -82,7 +86,7 @@ def setup_spline(self, spline_options=None): self.spline = CubicSpline(x=self.r, y=relative_energies, bc_type=((1, 0.0), (1, 0.0))) @classmethod - def from_outcars(cls, outcars, structures, **kwargs): + def from_outcars(cls, outcars, structures, **kwargs) -> Self: """ Initializes an NEBAnalysis from Outcar and Structure objects. Use the static constructors, e.g., from_dir instead if you @@ -188,7 +192,7 @@ def get_plot(self, normalize_rxn_coordinate: bool = True, label_barrier: bool = return ax @classmethod - def from_dir(cls, root_dir, relaxation_dirs=None, **kwargs): + def from_dir(cls, root_dir, relaxation_dirs=None, **kwargs) -> Self: """ Initializes a NEBAnalysis object from a directory of a NEB run. Note that OUTCARs must be present in all image directories. For the @@ -296,7 +300,9 @@ def combine_neb_plots(neb_analyses, arranged_neb_analyses=False, reverse_plot=Fa Note that the barrier labeled in y-axis in the combined plot might be different from that in the individual plot due to the reference energy used. reverse_plot: reverse the plot or percolation direction. - return: a NEBAnalysis object + + Returns: + a NEBAnalysis object """ x = StructureMatcher() for neb_index, neb in enumerate(neb_analyses): diff --git a/pymatgen/analysis/wulff.py b/pymatgen/analysis/wulff.py index c7d30805cab..8f23920069c 100644 --- a/pymatgen/analysis/wulff.py +++ b/pymatgen/analysis/wulff.py @@ -74,13 +74,14 @@ class WulffFacet: def __init__(self, normal, e_surf, normal_pt, dual_pt, index, m_ind_orig, miller): """ - :param normal: - :param e_surf: - :param normal_pt: - :param dual_pt: - :param index: - :param m_ind_orig: - :param miller: + Args: + normal: + e_surf: + normal_pt: + dual_pt: + index: + m_ind_orig: + miller: """ self.normal = normal self.e_surf = e_surf @@ -376,8 +377,8 @@ def get_plot( units_in_JPERM2 (bool): Units of surface energy, defaults to Joules per square meter (True) - Return: - (matplotlib.pyplot) + Returns: + mpl_toolkits.mplot3d.Axes3D: 3D plot of the Wulff shape. """ from mpl_toolkits.mplot3d import art3d @@ -449,7 +450,7 @@ def get_plot( cmap = plt.get_cmap(color_set) cmap.set_over("0.25") cmap.set_under("0.75") - bounds = [round(e, 2) for e in e_surf_on_wulff] + bounds = [round(ene, 2) for ene in e_surf_on_wulff] bounds.append(1.2 * bounds[-1]) norm = mpl.colors.BoundaryNorm(bounds, cmap.N) # display surface energies @@ -471,7 +472,7 @@ def get_plot( ax_3d.grid("off") if axis_off: ax_3d.axis("off") - return plt + return ax_3d def get_plotly( self, @@ -496,7 +497,7 @@ def get_plotly( units_in_JPERM2 (bool): Units of surface energy, defaults to Joules per square meter (True) - Return: + Returns: (plotly.graph_objects.Figure) """ units = "Jmโปยฒ" if units_in_JPERM2 else "eVร…โปยฒ" diff --git a/pymatgen/analysis/xps.py b/pymatgen/analysis/xps.py index 64606c1ca61..e1ec8399b2d 100644 --- a/pymatgen/analysis/xps.py +++ b/pymatgen/analysis/xps.py @@ -1,14 +1,14 @@ """ This is a module for XPS analysis. It is modelled after the Galore package (https://github.com/SMTG-UCL/galore), but with some modifications for easier analysis from pymatgen itself. Please cite the following original work if you use -this:: +this: Adam J. Jackson, Alex M. Ganose, Anna Regoutz, Russell G. Egdell, David O. Scanlon (2018). Galore: Broadening and weighting for simulation of photoelectron spectroscopy. Journal of Open Source Software, 3(26), 773, doi: 10.21105/joss.007733 You may wish to look at the optional dependency galore for more functionality such as plotting and other cross-sections. -Note that the atomic_subshell_photoionization_cross_sections.csv has been reparsed from the original compilation:: +Note that the atomic_subshell_photoionization_cross_sections.csv has been reparsed from the original compilation: Yeh, J. J.; Lindau, I. Atomic Subshell Photoionization Cross Sections and Asymmetry Parameters: 1 โฉฝ Z โฉฝ 103. Atomic Data and Nuclear Data Tables 1985, 32 (1), 1-155. https://doi.org/10.1016/0092-640X(85)90016-6. @@ -31,6 +31,8 @@ from pymatgen.util.due import Doi, due if TYPE_CHECKING: + from typing_extensions import Self + from pymatgen.electronic_structure.dos import CompleteDos @@ -76,10 +78,12 @@ class XPS(Spectrum): YLABEL = "Intensity" @classmethod - def from_dos(cls, dos: CompleteDos): + def from_dos(cls, dos: CompleteDos) -> Self: """ - :param dos: CompleteDos object with project element-orbital DOS. Can be obtained from Vasprun.get_complete_dos. - :param sigma: Smearing for Gaussian. + Args: + dos: CompleteDos object with project element-orbital DOS. + Can be obtained from Vasprun.get_complete_dos. + sigma: Smearing for Gaussian. Returns: XPS diff --git a/pymatgen/apps/battery/battery_abc.py b/pymatgen/apps/battery/battery_abc.py index 8a8b3f04f79..56ca216f278 100644 --- a/pymatgen/apps/battery/battery_abc.py +++ b/pymatgen/apps/battery/battery_abc.py @@ -227,7 +227,7 @@ def get_sub_electrodes(self, adjacent_only=True): Returns: A list of Electrode objects """ - NotImplementedError( + raise NotImplementedError( "The get_sub_electrodes function must be implemented for each concrete electrode " f"class {type(self).__name__}" ) diff --git a/pymatgen/apps/battery/conversion_battery.py b/pymatgen/apps/battery/conversion_battery.py index 8fbc99b1a52..a3ccc15f3eb 100644 --- a/pymatgen/apps/battery/conversion_battery.py +++ b/pymatgen/apps/battery/conversion_battery.py @@ -16,6 +16,8 @@ if TYPE_CHECKING: from collections.abc import Iterable + from typing_extensions import Self + from pymatgen.entries.computed_entries import ComputedEntry @@ -43,7 +45,13 @@ def initial_comp(self) -> Composition: return Composition(self.initial_comp_formula) @classmethod - def from_composition_and_pd(cls, comp, pd, working_ion_symbol="Li", allow_unstable=False): + def from_composition_and_pd( + cls, + comp, + pd: PhaseDiagram, + working_ion_symbol: str = "Li", + allow_unstable: bool = False, + ) -> Self | None: """Convenience constructor to make a ConversionElectrode from a composition and a phase diagram. @@ -56,11 +64,11 @@ def from_composition_and_pd(cls, comp, pd, working_ion_symbol="Li", allow_unstab """ working_ion = Element(working_ion_symbol) entry = working_ion_entry = None - for e in pd.stable_entries: - if e.reduced_formula == comp.reduced_formula: - entry = e - elif e.is_element and e.reduced_formula == working_ion_symbol: - working_ion_entry = e + for ent in pd.stable_entries: + if ent.reduced_formula == comp.reduced_formula: + entry = ent + elif ent.is_element and ent.reduced_formula == working_ion_symbol: + working_ion_entry = ent if not allow_unstable and not entry: raise ValueError(f"Not stable compound found at composition {comp}.") @@ -71,14 +79,16 @@ def from_composition_and_pd(cls, comp, pd, working_ion_symbol="Li", allow_unstab profile.reverse() if len(profile) < 2: return None - working_ion = working_ion_entry.elements[0].symbol - normalization_els = {el: amt for el, amt in comp.items() if el != Element(working_ion)} + + assert working_ion_entry is not None + working_ion_symbol = working_ion_entry.elements[0].symbol + normalization_els = {el: amt for el, amt in comp.items() if el != Element(working_ion_symbol)} framework = comp.as_dict() - if working_ion in framework: - framework.pop(working_ion) + if working_ion_symbol in framework: + framework.pop(working_ion_symbol) framework = Composition(framework) - v_pairs = [ + v_pairs: list[ConversionVoltagePair] = [ ConversionVoltagePair.from_steps( profile[i], profile[i + 1], @@ -88,15 +98,17 @@ def from_composition_and_pd(cls, comp, pd, working_ion_symbol="Li", allow_unstab for i in range(len(profile) - 1) ] - return ConversionElectrode( - voltage_pairs=v_pairs, + return cls( + voltage_pairs=v_pairs, # type: ignore[arg-type] working_ion_entry=working_ion_entry, initial_comp_formula=comp.reduced_formula, framework_formula=framework.reduced_formula, ) @classmethod - def from_composition_and_entries(cls, comp, entries_in_chemsys, working_ion_symbol="Li", allow_unstable=False): + def from_composition_and_entries( + cls, comp, entries_in_chemsys, working_ion_symbol="Li", allow_unstable=False + ) -> Self | None: """Convenience constructor to make a ConversionElectrode from a composition and all entries in a chemical system. @@ -111,7 +123,7 @@ def from_composition_and_entries(cls, comp, entries_in_chemsys, working_ion_symb for comparing with insertion electrodes """ pd = PhaseDiagram(entries_in_chemsys) - return ConversionElectrode.from_composition_and_pd(comp, pd, working_ion_symbol, allow_unstable) + return cls.from_composition_and_pd(comp, pd, working_ion_symbol, allow_unstable) def get_sub_electrodes(self, adjacent_only=True): """If this electrode contains multiple voltage steps, then it is possible @@ -278,7 +290,7 @@ class ConversionVoltagePair(AbstractVoltagePair): entries_discharge: Iterable[ComputedEntry] @classmethod - def from_steps(cls, step1, step2, normalization_els, framework_formula): + def from_steps(cls, step1, step2, normalization_els, framework_formula) -> Self: """Creates a ConversionVoltagePair from two steps in the element profile from a PD analysis. @@ -350,7 +362,7 @@ def from_steps(cls, step1, step2, normalization_els, framework_formula): entries_charge = step1["entries"] entries_discharge = step2["entries"] - return ConversionVoltagePair( + return cls( rxn=rxn, voltage=voltage, mAh=mAh, diff --git a/pymatgen/apps/battery/insertion_battery.py b/pymatgen/apps/battery/insertion_battery.py index 5b34947c6b4..74e97031726 100644 --- a/pymatgen/apps/battery/insertion_battery.py +++ b/pymatgen/apps/battery/insertion_battery.py @@ -8,6 +8,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING +from monty.json import MontyDecoder from scipy.constants import N_A from pymatgen.analysis.phase_diagram import PDEntry, PhaseDiagram @@ -19,6 +20,8 @@ if TYPE_CHECKING: from collections.abc import Iterable + from typing_extensions import Self + __author__ = "Anubhav Jain, Shyue Ping Ong" __copyright__ = "Copyright 2012, The Materials Project" @@ -39,7 +42,7 @@ def from_entries( entries: Iterable[ComputedEntry | ComputedStructureEntry], working_ion_entry: ComputedEntry | ComputedStructureEntry | PDEntry, strip_structures: bool = False, - ): + ) -> Self: """Create a new InsertionElectrode. Args: @@ -355,20 +358,17 @@ def __repr__(self): ) @classmethod - def from_dict_legacy(cls, d): + def from_dict_legacy(cls, dct) -> Self: """ Args: - d (dict): Dict representation. + dct (dict): Dict representation. Returns: InsertionElectrode """ - from monty.json import MontyDecoder - - dec = MontyDecoder() return InsertionElectrode( - dec.process_decoded(d["entries"]), - dec.process_decoded(d["working_ion_entry"]), + MontyDecoder().process_decoded(dct["entries"]), + MontyDecoder().process_decoded(dct["working_ion_entry"]), ) def as_dict_legacy(self): @@ -389,7 +389,7 @@ class InsertionVoltagePair(AbstractVoltagePair): entry_discharge: ComputedEntry @classmethod - def from_entries(cls, entry1, entry2, working_ion_entry): + def from_entries(cls, entry1, entry2, working_ion_entry) -> Self: """ Args: entry1: Entry corresponding to one of the entries in the voltage step. @@ -405,7 +405,7 @@ def from_entries(cls, entry1, entry2, working_ion_entry): if entry_charge.composition.get_atomic_fraction(working_element) > entry2.composition.get_atomic_fraction( working_element ): - (entry_charge, entry_discharge) = (entry_discharge, entry_charge) + entry_charge, entry_discharge = entry_discharge, entry_charge comp_charge = entry_charge.composition comp_discharge = entry_discharge.composition diff --git a/pymatgen/apps/borg/hive.py b/pymatgen/apps/borg/hive.py index eec7ff2a2d4..6e44ac409e9 100644 --- a/pymatgen/apps/borg/hive.py +++ b/pymatgen/apps/borg/hive.py @@ -8,6 +8,7 @@ import os import warnings from glob import glob +from typing import TYPE_CHECKING from monty.io import zopen from monty.json import MSONable @@ -17,6 +18,9 @@ from pymatgen.io.vasp.inputs import Incar, Poscar, Potcar from pymatgen.io.vasp.outputs import Dynmat, Oszicar, Vasprun +if TYPE_CHECKING: + from typing_extensions import Self + logger = logging.getLogger(__name__) @@ -174,7 +178,7 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """ Args: dct (dict): Dict Representation. @@ -282,7 +286,7 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """ Args: dct (dict): Dict Representation. @@ -350,22 +354,22 @@ def assimilate(self, path): ComputedEntry """ try: - gaurun = GaussianOutput(path) + gau_run = GaussianOutput(path) except Exception as exc: logger.debug(f"error in {path}: {exc}") return None param = {} for p in self._parameters: - param[p] = getattr(gaurun, p) + param[p] = getattr(gau_run, p) data = {} for d in self._data: - data[d] = getattr(gaurun, d) + data[d] = getattr(gau_run, d) if self._inc_structure: - entry = ComputedStructureEntry(gaurun.final_structure, gaurun.final_energy, parameters=param, data=data) + entry = ComputedStructureEntry(gau_run.final_structure, gau_run.final_energy, parameters=param, data=data) else: entry = ComputedEntry( - gaurun.final_structure.composition, - gaurun.final_energy, + gau_run.final_structure.composition, + gau_run.final_energy, parameters=param, data=data, ) @@ -382,7 +386,7 @@ def get_valid_paths(self, path): List of valid dir/file paths for assimilation """ parent, _subdirs, files = path - return [os.path.join(parent, f) for f in files if os.path.splitext(f)[1] in self._file_extensions] + return [os.path.join(parent, file) for file in files if os.path.splitext(file)[1] in self._file_extensions] def __str__(self): return " GaussianToComputedEntryDrone" @@ -401,7 +405,7 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """ Args: dct (dict): Dict Representation. diff --git a/pymatgen/apps/borg/queen.py b/pymatgen/apps/borg/queen.py index d5a257d5007..9e08d412e33 100644 --- a/pymatgen/apps/borg/queen.py +++ b/pymatgen/apps/borg/queen.py @@ -109,7 +109,7 @@ def load_data(self, filename): def order_assimilation(args): """Internal helper method for BorgQueen to process assimilation.""" - (path, drone, data, status) = args + path, drone, data, status = args new_data = drone.assimilate(path) if new_data: data.append(json.dumps(new_data, cls=MontyEncoder)) diff --git a/pymatgen/cli/pmg.py b/pymatgen/cli/pmg.py index 2524a9cead4..aedb8d5676e 100755 --- a/pymatgen/cli/pmg.py +++ b/pymatgen/cli/pmg.py @@ -22,7 +22,8 @@ def parse_view(args): """Handle view commands. - :param args: Args from command. + Args: + args: Args from command. """ from pymatgen.vis.structure_vtk import StructureVis @@ -37,7 +38,8 @@ def parse_view(args): def diff_incar(args): """Handle diff commands. - :param args: Args from command. + Args: + args: Args from command. """ filepath1 = args.incars[0] filepath2 = args.incars[1] @@ -49,7 +51,7 @@ def format_lists(v): return " ".join(f"{len(tuple(group))}*{i:.2f}" for (i, group) in itertools.groupby(v)) return v - d = incar1.diff(incar2) + diff = incar1.diff(incar2) output = [ ["SAME PARAMS", "", ""], ["---------------", "", ""], @@ -57,20 +59,18 @@ def format_lists(v): ["DIFFERENT PARAMS", "", ""], ["----------------", "", ""], ] - output.extend( - [(k, format_lists(d["Same"][k]), format_lists(d["Same"][k])) for k in sorted(d["Same"]) if k != "SYSTEM"] - ) - output.extend( - [ - ( - k, - format_lists(d["Different"][k]["INCAR1"]), - format_lists(d["Different"][k]["INCAR2"]), - ) - for k in sorted(d["Different"]) - if k != "SYSTEM" - ] - ) + output += [ + (k, format_lists(diff["Same"][k]), format_lists(diff["Same"][k])) for k in sorted(diff["Same"]) if k != "SYSTEM" + ] + output += [ + ( + k, + format_lists(diff["Different"][k]["INCAR1"]), + format_lists(diff["Different"][k]["INCAR2"]), + ) + for k in sorted(diff["Different"]) + if k != "SYSTEM" + ] print(tabulate(output, headers=["", filepath1, filepath2])) return 0 diff --git a/pymatgen/cli/pmg_analyze.py b/pymatgen/cli/pmg_analyze.py index 885a3633ea3..089479576ed 100644 --- a/pymatgen/cli/pmg_analyze.py +++ b/pymatgen/cli/pmg_analyze.py @@ -104,11 +104,11 @@ def get_magnetizations(dir: str, ion_list: list[int]): data = [] max_row = 0 for parent, _subdirs, files in os.walk(dir): - for f in files: - if re.match(r"OUTCAR*", f): + for file in files: + if re.match(r"OUTCAR*", file): try: row = [] - fullpath = os.path.join(parent, f) + fullpath = os.path.join(parent, file) outcar = Outcar(fullpath) mags = outcar.magnetization mags = [m["tot"] for m in mags] @@ -143,15 +143,15 @@ def analyze(args): default_energies = not (args.get_energies or args.ion_list) if args.get_energies or default_energies: - for d in args.directories: - return get_energies(d, args.reanalyze, args.verbose, args.quick, args.sort, args.format) + for folder in args.directories: + return get_energies(folder, args.reanalyze, args.verbose, args.quick, args.sort, args.format) if args.ion_list: if args.ion_list[0] == "All": ion_list = None else: - (start, end) = (int(i) for i in re.split(r"-", args.ion_list[0])) + start, end = (int(i) for i in re.split(r"-", args.ion_list[0])) ion_list = list(range(start, end + 1)) - for d in args.directories: - return get_magnetizations(d, ion_list) + for folder in args.directories: + return get_magnetizations(folder, ion_list) return -1 diff --git a/pymatgen/cli/pmg_config.py b/pymatgen/cli/pmg_config.py index bb9cd1aa9d8..49132e37384 100755 --- a/pymatgen/cli/pmg_config.py +++ b/pymatgen/cli/pmg_config.py @@ -31,8 +31,7 @@ def setup_cp2k_data(cp2k_data_dirs: list[str]) -> None: except OSError: reply = input("Destination directory exists. Continue (y/n)?") if reply != "y": - print("Exiting ...") - raise SystemExit(0) + raise SystemExit("Exiting ...") print("Generating pymatgen resource directory for CP2K...") basis_files = glob(f"{data_dir}/*BASIS*") @@ -42,7 +41,7 @@ def setup_cp2k_data(cp2k_data_dirs: list[str]) -> None: for potential_file in potential_files: print(f"Processing... {potential_file}") - with open(potential_file) as file: + with open(potential_file, encoding="utf-8") as file: try: chunks = chunk(file.read()) except IndexError: @@ -52,9 +51,10 @@ def setup_cp2k_data(cp2k_data_dirs: list[str]) -> None: potential = GthPotential.from_str(chk) potential.filename = os.path.basename(potential_file) potential.version = None - settings[potential.element.symbol]["potentials"][potential.get_hash()] = jsanitize( - potential, strict=True - ) + if potential.element is not None: + settings[potential.element.symbol]["potentials"][potential.get_hash()] = jsanitize( + potential, strict=True + ) except ValueError: # Chunk was readable, but the element is not pmg recognized continue @@ -64,7 +64,7 @@ def setup_cp2k_data(cp2k_data_dirs: list[str]) -> None: for basis_file in basis_files: print(f"Processing... {basis_file}") - with open(basis_file) as file: + with open(basis_file, encoding="utf-8") as file: try: chunks = chunk(file.read()) except IndexError: @@ -87,7 +87,7 @@ def setup_cp2k_data(cp2k_data_dirs: list[str]) -> None: for el in settings: print(f"Writing {el} settings file") - with open(os.path.join(target_dir, el), mode="w") as file: + with open(os.path.join(target_dir, el), mode="w", encoding="utf-8") as file: yaml.dump(settings.get(el), file, default_flow_style=False) print( @@ -111,8 +111,7 @@ def setup_potcars(potcar_dirs: list[str]): except OSError: reply = input("Destination directory exists. Continue (y/n)? ") if reply != "y": - print("Exiting ...") - raise SystemExit(0) + raise SystemExit("Exiting ...") print("Generating pymatgen resources directory...") @@ -170,17 +169,18 @@ def setup_potcars(potcar_dirs: list[str]): def build_enum(fortran_command: str = "gfortran") -> bool: """Build enum. - :param fortran_command: + Args: + fortran_command: The Fortran compiler command. """ cwd = os.getcwd() state = True try: - subprocess.call(["git", "clone", "--recursive", "https://github.com/msg-byu/enumlib.git"]) + subprocess.call(["git", "clone", "--recursive", "https://github.com/msg-byu/enumlib"]) os.chdir(f"{cwd}/enumlib/symlib/src") os.environ["F90"] = fortran_command subprocess.call(["make"]) - enumpath = f"{cwd}/enumlib/src" - os.chdir(enumpath) + enum_path = f"{cwd}/enumlib/src" + os.chdir(enum_path) subprocess.call(["make"]) subprocess.call(["make", "enum.x"]) shutil.copy("enum.x", os.path.join("..", "..")) @@ -196,7 +196,8 @@ def build_enum(fortran_command: str = "gfortran") -> bool: def build_bader(fortran_command="gfortran"): """Build bader package. - :param fortran_command: + Args: + fortran_command: The Fortran compiler command. """ bader_url = "http://theory.cm.utexas.edu/henkelman/code/bader/download/bader.tar.gz" cwd = os.getcwd() diff --git a/pymatgen/cli/pmg_potcar.py b/pymatgen/cli/pmg_potcar.py index 66398ed62b1..aaf064ff374 100755 --- a/pymatgen/cli/pmg_potcar.py +++ b/pymatgen/cli/pmg_potcar.py @@ -16,11 +16,11 @@ def proc_dir(dirname, proc_file_function): dirname (str): Directory name. proc_file_function (callable): Callable to execute on directory. """ - for f in os.listdir(dirname): - if os.path.isdir(os.path.join(dirname, f)): - proc_dir(os.path.join(dirname, f), proc_file_function) + for file in os.listdir(dirname): + if os.path.isdir(os.path.join(dirname, file)): + proc_dir(os.path.join(dirname, file), proc_file_function) else: - proc_file_function(dirname, f) + proc_file_function(dirname, file) def gen_potcar(dirname, filename): diff --git a/pymatgen/command_line/bader_caller.py b/pymatgen/command_line/bader_caller.py index 14590e6c79b..18c5413a437 100644 --- a/pymatgen/command_line/bader_caller.py +++ b/pymatgen/command_line/bader_caller.py @@ -34,6 +34,8 @@ from pymatgen.io.vasp.outputs import Chgcar if TYPE_CHECKING: + from typing_extensions import Self + from pymatgen.core import Structure __author__ = "shyuepingong" @@ -88,7 +90,7 @@ def temp_decompress(file: str | Path, target_dir: str = ".") -> str: """Utility function to copy a compressed file to a target directory (ScratchDir) and decompress it, to avoid modifying files in place. - Parameters: + Args: file (str | Path): The path to the compressed file to be decompressed. target_dir (str, optional): The target directory where the decompressed file will be stored. Defaults to "." (current directory). @@ -441,7 +443,7 @@ def summary(self) -> dict[str, Any]: return summary @classmethod - def from_path(cls, path: str, suffix: str = "") -> BaderAnalysis: + def from_path(cls, path: str, suffix: str = "") -> Self: """Convenient constructor that takes in the path name of VASP run to perform Bader analysis. @@ -508,8 +510,10 @@ def bader_analysis_from_path(path: str, suffix: str = ""): 3. Runs Bader analysis twice: once for charge, and a second time for the charge difference (magnetization density). - :param path: path to folder to search in - :param suffix: specific suffix to look for (e.g. '.relax1' for 'CHGCAR.relax1.gz' + Args: + path: path to folder to search in + suffix: specific suffix to look for (e.g. '.relax1' for 'CHGCAR.relax1.gz' + Returns: summary dict """ @@ -566,10 +570,11 @@ def bader_analysis_from_objects( 2. Runs Bader analysis twice: once for charge, and a second time for the charge difference (magnetization density). - :param chgcar: Chgcar object - :param potcar: (optional) Potcar object - :param aeccar0: (optional) Chgcar object from aeccar0 file - :param aeccar2: (optional) Chgcar object from aeccar2 file + Args: + chgcar: Chgcar object + potcar: (optional) Potcar object + aeccar0: (optional) Chgcar object from aeccar0 file + aeccar2: (optional) Chgcar object from aeccar2 file Returns: summary dict diff --git a/pymatgen/command_line/chargemol_caller.py b/pymatgen/command_line/chargemol_caller.py index 75c8eba084a..7f2fc1bcd91 100644 --- a/pymatgen/command_line/chargemol_caller.py +++ b/pymatgen/command_line/chargemol_caller.py @@ -59,6 +59,8 @@ if TYPE_CHECKING: from pathlib import Path + from pymatgen.core import Structure + __author__ = "Martin Siron, Andrew S. Rosen" __version__ = "0.1" __maintainer__ = "Shyue Ping Ong" @@ -109,21 +111,26 @@ def __init__( self._potcar_path = self._get_filepath(path, "POTCAR") self._aeccar0_path = self._get_filepath(path, "AECCAR0") self._aeccar2_path = self._get_filepath(path, "AECCAR2") + if run_chargemol and not ( self._chgcar_path and self._potcar_path and self._aeccar0_path and self._aeccar2_path ): raise FileNotFoundError("CHGCAR, AECCAR0, AECCAR2, and POTCAR are all needed for Chargemol.") + if self._chgcar_path: - self.chgcar = Chgcar.from_file(self._chgcar_path) - self.structure = self.chgcar.structure - self.natoms = self.chgcar.poscar.natoms + self.chgcar: Chgcar | None = Chgcar.from_file(self._chgcar_path) + self.structure: Structure | None = self.chgcar.structure + self.natoms: list[int] | None = self.chgcar.poscar.natoms + else: self.chgcar = self.structure = self.natoms = None warnings.warn("No CHGCAR found. Some properties may be unavailable.", UserWarning) + if self._potcar_path: self.potcar = Potcar.from_file(self._potcar_path) else: warnings.warn("No POTCAR found. Some properties may be unavailable.", UserWarning) + self.aeccar0 = Chgcar.from_file(self._aeccar0_path) if self._aeccar0_path else None self.aeccar2 = Chgcar.from_file(self._aeccar2_path) if self._aeccar2_path else None diff --git a/pymatgen/command_line/critic2_caller.py b/pymatgen/command_line/critic2_caller.py index 5a8a6215c58..569abd29eea 100644 --- a/pymatgen/command_line/critic2_caller.py +++ b/pymatgen/command_line/critic2_caller.py @@ -42,7 +42,7 @@ import os import subprocess import warnings -from enum import Enum +from enum import Enum, unique from glob import glob from shutil import which from typing import TYPE_CHECKING @@ -61,6 +61,8 @@ from pymatgen.util.due import Doi, due if TYPE_CHECKING: + from typing_extensions import Self + from pymatgen.core import Structure logging.basicConfig(level=logging.INFO) @@ -84,25 +86,26 @@ class Critic2Caller: "Critic2Caller requires the executable critic to be in the path. " "Please follow the instructions at https://github.com/aoterodelaroza/critic2.", ) - def __init__(self, input_script): + def __init__(self, input_script: str): """Run Critic2 on a given input script. - :param input_script: string defining the critic2 input + Args: + input_script: string defining the critic2 input """ # store if examining the input script is useful, # not otherwise used self._input_script = input_script - with open("input_script.cri", mode="w") as file: + with open("input_script.cri", mode="w", encoding="utf-8") as file: file.write(input_script) args = ["critic2", "input_script.cri"] with subprocess.Popen(args, stdout=subprocess.PIPE, stdin=subprocess.PIPE, close_fds=True) as rs: - stdout, stderr = rs.communicate() - stdout = stdout.decode() + _stdout, _stderr = rs.communicate() + stdout = _stdout.decode() - if stderr: - stderr = stderr.decode() + if _stderr: + stderr = _stderr.decode() warnings.warn(stderr) if rs.returncode != 0: @@ -127,7 +130,7 @@ def from_chgcar( write_cml=False, write_json=True, zpsp=None, - ): + ) -> Self: """Run Critic2 in automatic mode on a supplied structure, charge density (chgcar) and reference charge density (chgcar_ref). @@ -165,20 +168,21 @@ def from_chgcar( sub-dividing the Wigner-Seitz cell and between every atom pair closer than 10 Bohr, see critic2 manual for more options - :param structure: Structure to analyze - :param chgcar: Charge density to use for analysis. If None, will - use promolecular density. Should be a Chgcar object or path (string). - :param chgcar_ref: Reference charge density. If None, will use - chgcar as reference. Should be a Chgcar object or path (string). - :param user_input_settings (dict): as explained above - :param write_cml (bool): Useful for debug, if True will write all - critical points to a file 'table.cml' in the working directory - useful for visualization - :param write_json (bool): Whether to write out critical points - and YT json. YT integration will be performed with this setting. - :param zpsp (dict): Dict of element/symbol name to number of electrons - (ZVAL in VASP pseudopotential), with which to properly augment core regions - and calculate charge transfer. Optional. + Args: + structure: Structure to analyze + chgcar: Charge density to use for analysis. If None, will + use promolecular density. Should be a Chgcar object or path (string). + chgcar_ref: Reference charge density. If None, will use + chgcar as reference. Should be a Chgcar object or path (string). + user_input_settings (dict): as explained above + write_cml (bool): Useful for debug, if True will write all + critical points to a file 'table.cml' in the working directory + useful for visualization + write_json (bool): Whether to write out critical points + and YT json. YT integration will be performed with this setting. + zpsp (dict): Dict of element/symbol name to number of electrons + (ZVAL in VASP pseudopotential), with which to properly augment core regions + and calculate charge transfer. Optional. """ settings = {"CPEPS": 0.1, "SEED": ["WS", "PAIR DIST 10"]} if user_input_settings: @@ -219,7 +223,7 @@ def from_chgcar( input_script += ["yt"] input_script += ["yt JSON yt.json"] - input_script = "\n".join(input_script) + input_script_str = "\n".join(input_script) with ScratchDir("."): structure.to(filename="POSCAR") @@ -234,7 +238,7 @@ def from_chgcar( elif chgcar_ref: os.symlink(chgcar_ref, "ref.CHGCAR") - caller = cls(input_script) + caller = cls(input_script_str) caller.output = Critic2Analysis( structure, @@ -248,7 +252,7 @@ def from_chgcar( return caller @classmethod - def from_path(cls, path, suffix="", zpsp=None): + def from_path(cls, path, suffix="", zpsp=None) -> Self: """Convenience method to run critic2 analysis on a folder with typical VASP output files. This method will: @@ -262,10 +266,11 @@ def from_path(cls, path, suffix="", zpsp=None): 3. Runs critic2 analysis twice: once for charge, and a second time for the charge difference (magnetization density). - :param path: path to folder to search in - :param suffix: specific suffix to look for (e.g. '.relax1' for - 'CHGCAR.relax1.gz') - :param zpsp: manually specify ZPSP if POTCAR not present + Args: + path: path to folder to search in + suffix: specific suffix to look for (e.g. '.relax1' for + 'CHGCAR.relax1.gz') + zpsp: manually specify ZPSP if POTCAR not present """ chgcar_path = get_filepath("CHGCAR", "Could not find CHGCAR!", path, suffix) chgcar = Chgcar.from_file(chgcar_path) @@ -306,6 +311,7 @@ def from_path(cls, path, suffix="", zpsp=None): return cls.from_chgcar(chgcar.structure, chgcar, chgcar_ref, zpsp=zpsp) +@unique class CriticalPointType(Enum): """Enum type for the different varieties of critical point.""" @@ -359,15 +365,16 @@ def __init__( Note this class is usually associated with a Structure, so has information on multiplicity/point group symmetry. - :param index: index of point - :param type: type of point, given as a string - :param coords: Cartesian coordinates in Angstroms - :param frac_coords: fractional coordinates - :param point_group: point group associated with critical point - :param multiplicity: number of equivalent critical points - :param field: value of field at point (f) - :param field_gradient: gradient of field at point (grad f) - :param field_hessian: hessian of field at point (del^2 f) + Args: + index: index of point + type: type of point, given as a string + coords: Cartesian coordinates in Angstroms + frac_coords: fractional coordinates + point_group: point group associated with critical point + multiplicity: number of equivalent critical points + field: value of field at point (f) + field_gradient: gradient of field at point (grad f) + field_hessian: hessian of field at point (del^2 f) """ self.index = index self._type = type @@ -441,16 +448,15 @@ class with bonding information. By default, this returns Only one of (stdout, cpreport) required, with cpreport preferred since this is a new, native JSON output from critic2. - :param structure: associated Structure - :param stdout: stdout from running critic2 in automatic - mode - :param stderr: stderr from running critic2 in automatic - mode - :param cpreport: json output from CPREPORT command - :param yt: json output from YT command - :param zpsp (dict): Dict of element/symbol name to number of electrons - (ZVAL in VASP pseudopotential), with which to calculate charge transfer. - Optional. + Args: + structure: associated Structure + stdout: stdout from running critic2 in automatic mode + stderr: stderr from running critic2 in automatic mode + cpreport: json output from CPREPORT command + yt: json output from YT command + zpsp (dict): Dict of element/symbol name to number of electrons + (ZVAL in VASP pseudopotential), with which to calculate charge transfer. + Optional. Args: structure (Structure): Associated Structure. @@ -523,7 +529,7 @@ def structure_graph(self, include_critical_points=("bond", "ring", "cage")): edge_weight = "bond_length" edge_weight_units = "ร…" - sg = StructureGraph.with_empty_graph( + struct_graph = StructureGraph.from_empty_graph( structure, name="bonds", edge_weight_name=edge_weight, @@ -591,7 +597,7 @@ def structure_graph(self, include_critical_points=("bond", "ring", "cage")): "frac_coords": self.nodes[idx]["frac_coords"], } - sg.add_edge( + struct_graph.add_edge( struct_from_idx, struct_to_idx, from_jimage=from_lvec, @@ -600,7 +606,7 @@ def structure_graph(self, include_critical_points=("bond", "ring", "cage")): edge_properties=edge_properties, ) - return sg + return struct_graph def get_critical_point_for_site(self, n: int): """ @@ -867,10 +873,11 @@ def _parse_stdout(self, stdout): def _add_node(self, idx, unique_idx, frac_coords): """Add information about a node describing a critical point. - :param idx: index - :param unique_idx: index of unique CriticalPoint, - used to look up more information of point (field etc.) - :param frac_coord: fractional coordinates of point + Args: + idx: index + unique_idx: index of unique CriticalPoint, + used to look up more information of point (field etc.) + frac_coords: fractional coordinates of point """ self.nodes[idx] = {"unique_idx": unique_idx, "frac_coords": frac_coords} @@ -887,13 +894,14 @@ def _add_edge(self, idx, from_idx, from_lvec, to_idx, to_lvec): this as a single edge linking nuclei with the properties of the bond critical point stored as an edge attribute. - :param idx: index of node - :param from_idx: from index of node - :param from_lvec: vector of lattice image the from node is in - as tuple of ints - :param to_idx: to index of node - :param to_lvec: vector of lattice image the to node is in as - tuple of ints + Args: + idx: index of node + from_idx: from index of node + from_lvec: vector of lattice image the from node is in + as tuple of ints + to_idx: to index of node + to_lvec: vector of lattice image the to node is in as + tuple of ints """ self.edges[idx] = { "from_idx": from_idx, diff --git a/pymatgen/command_line/enumlib_caller.py b/pymatgen/command_line/enumlib_caller.py index 8c7abd6b7dd..d8df42362c7 100644 --- a/pymatgen/command_line/enumlib_caller.py +++ b/pymatgen/command_line/enumlib_caller.py @@ -243,8 +243,8 @@ def get_sg_info(ss): n_disordered * lcm( *( - f.limit_denominator(n_disordered * self.max_cell_size).denominator - for f in map(fractions.Fraction, index_amounts) + fraction.limit_denominator(n_disordered * self.max_cell_size).denominator + for fraction in map(fractions.Fraction, index_amounts) ) ) ) @@ -256,7 +256,7 @@ def get_sg_info(ss): # enumeration. See Cu7Te5.cif test file. base *= 10 - # base = n_disordered # 10 ** int(math.ceil(math.log10(n_disordered))) + # base = n_disordered # 10 ** math.ceil(math.log10(n_disordered)) # To get a reasonable number of structures, we fix concentrations to the # range expected in the original structure. total_amounts = sum(index_amounts) @@ -266,7 +266,7 @@ def get_sg_info(ss): if abs(conc * base - round(conc * base)) < 1e-5: output.append(f"{int(round(conc * base))} {int(round(conc * base))} {base}") else: - min_conc = int(math.floor(conc * base)) + min_conc = math.floor(conc * base) output.append(f"{min_conc - 1} {min_conc + 1} {base}") output.append("") logger.debug("Generated input file:\n" + "\n".join(output)) diff --git a/pymatgen/command_line/gulp_caller.py b/pymatgen/command_line/gulp_caller.py index cbc0cef08e5..09f3ce9e73a 100644 --- a/pymatgen/command_line/gulp_caller.py +++ b/pymatgen/command_line/gulp_caller.py @@ -351,21 +351,21 @@ def library_line(file_name): """ gulp_lib_set = "GULP_LIB" in os.environ - def readable(f): - return os.path.isfile(f) and os.access(f, os.R_OK) + def readable(file): + return os.path.isfile(file) and os.access(file, os.R_OK) gin = "" dirpath, _fname = os.path.split(file_name) if dirpath and readable(file_name): # Full path specified - gin = "library " + file_name + gin = f"library {file_name}" else: fpath = os.path.join(os.getcwd(), file_name) # Check current dir if readable(fpath): - gin = "library " + fpath + gin = f"library {fpath}" elif gulp_lib_set: # Check the GULP_LIB path fpath = os.path.join(os.environ["GULP_LIB"], file_name) if readable(fpath): - gin = "library " + file_name + gin = f"library {file_name}" if gin: return gin + "\n" raise GulpError("GULP library not found") @@ -541,7 +541,7 @@ def get_relaxed_structure(gout: str): gout (str): GULP output string. Returns: - (Structure) relaxed structure. + Structure: relaxed structure. """ # Find the structure lines structure_lines = [] @@ -549,6 +549,7 @@ def get_relaxed_structure(gout: str): output_lines = gout.split("\n") n_lines = len(output_lines) idx = 0 + a = b = c = alpha = beta = gamma = 0.0 # Compute the input lattice parameters while idx < n_lines: line = output_lines[idx] @@ -621,9 +622,13 @@ def get_relaxed_structure(gout: str): alpha = float(cell_param_lines[3].split()[1]) beta = float(cell_param_lines[4].split()[1]) gamma = float(cell_param_lines[5].split()[1]) - latt = Lattice.from_parameters(a, b, c, alpha, beta, gamma) + if not all([a, b, c, alpha, beta, gamma]): + raise ValueError( + f"Missing lattice parameters in Gulp output: {a=}, {b=}, {c=}, {alpha=}, {beta=}, {gamma=}" + ) + lattice = Lattice.from_parameters(a, b, c, alpha, beta, gamma) - return Structure(latt, sp, coords) + return Structure(lattice, sp, coords) class GulpCaller: diff --git a/pymatgen/command_line/vampire_caller.py b/pymatgen/command_line/vampire_caller.py index 7eac6877853..cfd4153f678 100644 --- a/pymatgen/command_line/vampire_caller.py +++ b/pymatgen/command_line/vampire_caller.py @@ -74,7 +74,7 @@ def __init__( If False, attempt to use NN, NNN, etc. interactions. user_input_settings (dict): optional commands for VAMPIRE Monte Carlo - Parameters: + Attributes: sgraph (StructureGraph): Ground state graph. unique_site_ids (dict): Maps each site to its unique identifier nn_interactions (dict): {i: j} pairs of NN interactions @@ -135,9 +135,9 @@ def __init__( stdout = stdout.decode() if stderr: - vanhelsing = stderr.decode() - if len(vanhelsing) > 27: # Suppress blank warning msg - logging.warning(vanhelsing) + van_helsing = stderr.decode() + if len(van_helsing) > 27: # Suppress blank warning msg + logging.warning(van_helsing) if process.returncode != 0: raise RuntimeError(f"Vampire exited with return code {process.returncode}.") @@ -146,9 +146,9 @@ def __init__( self._stderr = stderr # Process output - nmats = max(self.mat_id_dict.values()) - parsed_out, critical_temp = VampireCaller.parse_stdout("output", nmats) - self.output = VampireOutput(parsed_out, nmats, critical_temp) + n_mats = max(self.mat_id_dict.values()) + parsed_out, critical_temp = VampireCaller.parse_stdout("output", n_mats) + self.output = VampireOutput(parsed_out, n_mats, critical_temp) def _create_mat(self): structure = self.structure @@ -158,43 +158,41 @@ def _create_mat(self): # Maps sites to material id for vampire inputs mat_id_dict = {} - nmats = 0 + n_mats = 0 for key in self.unique_site_ids: spin_up, spin_down = False, False - nmats += 1 # at least 1 mat for each unique site + n_mats += 1 # at least 1 mat for each unique site # Check which spin sublattices exist for this site id for site in key: - m = magmoms[site] - if m > 0: + if magmoms[site] > 0: spin_up = True - if m < 0: + if magmoms[site] < 0: spin_down = True # Assign material id for each site for site in key: - m = magmoms[site] if spin_up and not spin_down: - mat_id_dict[site] = nmats + mat_id_dict[site] = n_mats if spin_down and not spin_up: - mat_id_dict[site] = nmats + mat_id_dict[site] = n_mats if spin_up and spin_down: # Check if spin up or down shows up first m0 = magmoms[key[0]] - if m > 0 and m0 > 0: - mat_id_dict[site] = nmats - if m < 0 and m0 < 0: - mat_id_dict[site] = nmats - if m > 0 > m0: - mat_id_dict[site] = nmats + 1 - if m < 0 < m0: - mat_id_dict[site] = nmats + 1 + if magmoms[site] > 0 and m0 > 0: + mat_id_dict[site] = n_mats + if magmoms[site] < 0 and m0 < 0: + mat_id_dict[site] = n_mats + if magmoms[site] > 0 > m0: + mat_id_dict[site] = n_mats + 1 + if magmoms[site] < 0 < m0: + mat_id_dict[site] = n_mats + 1 # Increment index if two sublattices if spin_up and spin_down: - nmats += 1 + n_mats += 1 - mat_file = [f"material:num-materials={nmats}"] + mat_file = [f"material:num-materials={n_mats}"] for key in self.unique_site_ids: i = self.unique_site_ids[key] # unique site id @@ -230,7 +228,7 @@ def _create_mat(self): def _create_input(self): structure = self.structure - mcbs = self.mc_box_size + mc_box_size = self.mc_box_size equil_timesteps = self.equil_timesteps mc_timesteps = self.mc_timesteps mat_name = self.mat_name @@ -255,9 +253,9 @@ def _create_input(self): # System size in nm input_script += [ - f"dimensions:system-size-x = {mcbs:.1f} !nm", - f"dimensions:system-size-y = {mcbs:.1f} !nm", - f"dimensions:system-size-z = {mcbs:.1f} !nm", + f"dimensions:system-size-x = {mc_box_size:.1f} !nm", + f"dimensions:system-size-y = {mc_box_size:.1f} !nm", + f"dimensions:system-size-z = {mc_box_size:.1f} !nm", ] # Critical temperature Monte Carlo calculation diff --git a/pymatgen/core/__init__.py b/pymatgen/core/__init__.py index cd317670ce6..194c6ff4012 100644 --- a/pymatgen/core/__init__.py +++ b/pymatgen/core/__init__.py @@ -1,4 +1,3 @@ -# ruff: noqa: PLC0414 """This package contains core modules and classes for representing structures and operations on them.""" from __future__ import annotations @@ -10,25 +9,13 @@ from ruamel.yaml import YAML -from pymatgen.core.composition import Composition as Composition -from pymatgen.core.lattice import Lattice as Lattice -from pymatgen.core.operations import SymmOp as SymmOp -from pymatgen.core.periodic_table import DummySpecie as DummySpecie -from pymatgen.core.periodic_table import DummySpecies as DummySpecies -from pymatgen.core.periodic_table import Element as Element -from pymatgen.core.periodic_table import Species as Species -from pymatgen.core.periodic_table import get_el_sp as get_el_sp -from pymatgen.core.sites import PeriodicSite as PeriodicSite -from pymatgen.core.sites import Site as Site -from pymatgen.core.structure import IMolecule as IMolecule -from pymatgen.core.structure import IStructure as IStructure -from pymatgen.core.structure import Molecule as Molecule -from pymatgen.core.structure import PeriodicNeighbor as PeriodicNeighbor -from pymatgen.core.structure import SiteCollection as SiteCollection -from pymatgen.core.structure import Structure as Structure -from pymatgen.core.units import ArrayWithUnit as ArrayWithUnit -from pymatgen.core.units import FloatWithUnit as FloatWithUnit -from pymatgen.core.units import Unit as Unit +from pymatgen.core.composition import Composition +from pymatgen.core.lattice import Lattice +from pymatgen.core.operations import SymmOp +from pymatgen.core.periodic_table import DummySpecie, DummySpecies, Element, Species, get_el_sp +from pymatgen.core.sites import PeriodicSite, Site +from pymatgen.core.structure import IMolecule, IStructure, Molecule, PeriodicNeighbor, SiteCollection, Structure +from pymatgen.core.units import ArrayWithUnit, FloatWithUnit, Unit __author__ = "Pymatgen Development Team" __email__ = "pymatgen@googlegroups.com" diff --git a/pymatgen/core/bonds.py b/pymatgen/core/bonds.py index 9b4fdfce8ef..1c45b74f931 100644 --- a/pymatgen/core/bonds.py +++ b/pymatgen/core/bonds.py @@ -117,7 +117,7 @@ def obtain_all_bond_lengths(sp1, sp2, default_bl: float | None = None): bond length as a default value (bond order = 1). If None, a ValueError will be thrown. - Return: + Returns: A dict mapping bond order to bond length in angstrom """ if isinstance(sp1, Element): diff --git a/pymatgen/core/composition.py b/pymatgen/core/composition.py index 6f4875208c0..4ef08da472e 100644 --- a/pymatgen/core/composition.py +++ b/pymatgen/core/composition.py @@ -12,20 +12,23 @@ from functools import total_ordering from itertools import combinations_with_replacement, product from math import isnan -from typing import TYPE_CHECKING, Union, cast +from typing import TYPE_CHECKING, cast from monty.fractions import gcd, gcd_float from monty.json import MSONable from monty.serialization import loadfn -from pymatgen.core.periodic_table import DummySpecies, Element, Species, get_el_sp +from pymatgen.core.periodic_table import DummySpecies, Element, ElementType, Species, get_el_sp from pymatgen.core.units import Mass from pymatgen.util.string import Stringify, formula_double_format if TYPE_CHECKING: from collections.abc import Generator, Iterator -SpeciesLike = Union[str, Element, Species, DummySpecies] + from typing_extensions import Self + + from pymatgen.util.typing import SpeciesLike + module_dir = os.path.dirname(os.path.abspath(__file__)) @@ -125,7 +128,7 @@ def __init__(self, *args, strict: bool = False, **kwargs) -> None: elif len(args) == 1 and isinstance(args[0], str): elem_map = self._parse_formula(args[0]) # type: ignore[assignment] elif len(args) == 1 and isinstance(args[0], float) and isnan(args[0]): - raise ValueError("float('NaN') is not a valid Composition, did you mean str('NaN')?") + raise ValueError("float('NaN') is not a valid Composition, did you mean 'NaN'?") else: elem_map = dict(*args, **kwargs) # type: ignore elem_amt = {} @@ -498,28 +501,10 @@ def contains_element_type(self, category: str) -> bool: Returns: bool: True if any elements in Composition match category, otherwise False """ - allowed_categories = ( - "noble_gas", - "transition_metal", - "post_transition_metal", - "rare_earth_metal", - "metal", - "metalloid", - "alkali", - "alkaline", - "halogen", - "chalcogen", - "lanthanoid", - "actinoid", - "quadrupolar", - "s-block", - "p-block", - "d-block", - "f-block", - ) + allowed_categories = [category.value for category in ElementType] if category not in allowed_categories: - raise ValueError(f"Please pick a category from: {allowed_categories}") + raise ValueError(f"Invalid {category=}, pick from {allowed_categories}") if "block" in category: return any(category[0] in el.block for el in self.elements) @@ -617,7 +602,7 @@ def __repr__(self) -> str: return f"{cls_name}({formula!r})" @classmethod - def from_dict(cls, d) -> Composition: + def from_dict(cls, dct: dict) -> Self: """Creates a composition from a dict generated by as_dict(). Strictly not necessary given that the standard constructor already takes in such an input, but this method preserves the standard pymatgen API of having @@ -625,12 +610,12 @@ def from_dict(cls, d) -> Composition: for easier introspection. Args: - d (dict): {symbol: amount} dict. + dct (dict): {symbol: amount} dict. """ - return cls(d) + return cls(dct) @classmethod - def from_weight_dict(cls, weight_dict) -> Composition: + def from_weight_dict(cls, weight_dict: dict[SpeciesLike, float]) -> Self: """Creates a Composition based on a dict of atomic fractions calculated from a dict of weight fractions. Allows for quick creation of the class from weight-based notations commonly used in the industry, such as @@ -1180,7 +1165,7 @@ def _parse_chomp_and_rank(m, f, m_dict, m_points): m_points1 = m_points m_form1 = fuzzy_formula m_dict1 = dict(m_dict) - (m_form1, m_dict1, m_points1) = _parse_chomp_and_rank(m1, m_form1, m_dict1, m_points1) + m_form1, m_dict1, m_points1 = _parse_chomp_and_rank(m1, m_form1, m_dict1, m_points1) if m_dict1: # there was a real match for match in Composition._comps_from_fuzzy_formula(m_form1, m_dict1, m_points1, factor): @@ -1192,7 +1177,7 @@ def _parse_chomp_and_rank(m, f, m_dict, m_points): m_points2 = m_points m_form2 = fuzzy_formula m_dict2 = dict(m_dict) - (m_form2, m_dict2, m_points2) = _parse_chomp_and_rank(m2, m_form2, m_dict2, m_points2) + m_form2, m_dict2, m_points2 = _parse_chomp_and_rank(m2, m_form2, m_dict2, m_points2) if m_dict2: # there was a real match for match in Composition._comps_from_fuzzy_formula(m_form2, m_dict2, m_points2, factor): @@ -1214,7 +1199,7 @@ def reduce_formula(sym_amt, iupac_ordering: bool = False) -> tuple[str, float]: the elements. Returns: - (reduced_formula, factor). + tuple[str, float]: reduced formula and factor. """ syms = sorted(sym_amt, key=lambda x: [get_el_sp(x).X, x]) @@ -1229,7 +1214,7 @@ def reduce_formula(sym_amt, iupac_ordering: bool = False) -> tuple[str, float]: # if the composition contains a poly anion if len(syms) >= 3 and get_el_sp(syms[-1]).X - get_el_sp(syms[-2]).X < 1.65: poly_sym_amt = {syms[i]: sym_amt[syms[i]] / factor for i in [-2, -1]} - (poly_form, poly_factor) = reduce_formula(poly_sym_amt, iupac_ordering=iupac_ordering) + poly_form, poly_factor = reduce_formula(poly_sym_amt, iupac_ordering=iupac_ordering) if poly_factor != 1: poly_anions.append(f"({poly_form}){poly_factor}") diff --git a/pymatgen/core/interface.py b/pymatgen/core/interface.py index cab443c1070..f47ece317db 100644 --- a/pymatgen/core/interface.py +++ b/pymatgen/core/interface.py @@ -1,21 +1,2330 @@ -"""This module provides classes to store, generate, and manipulate material interfaces.""" +"""This module provides classes to store, generate, and manipulate material interfaces, including grain boundaries.""" from __future__ import annotations +import logging +import warnings +from fractions import Fraction +from functools import reduce from itertools import chain, combinations, product +from math import cos, floor, gcd +from typing import TYPE_CHECKING, Any import numpy as np +from monty.fractions import lcm from numpy.testing import assert_allclose from scipy.cluster.hierarchy import fcluster, linkage from scipy.spatial.distance import squareform from pymatgen.analysis.adsorption import AdsorbateSiteFinder from pymatgen.core.lattice import Lattice -from pymatgen.core.sites import PeriodicSite -from pymatgen.core.structure import Site, Structure +from pymatgen.core.sites import PeriodicSite, Site +from pymatgen.core.structure import Structure from pymatgen.core.surface import Slab from pymatgen.symmetry.analyzer import SpacegroupAnalyzer +if TYPE_CHECKING: + from collections.abc import Sequence + + from numpy.typing import ArrayLike + from typing_extensions import Self + + from pymatgen.core.trajectory import Vector3D + from pymatgen.util.typing import CompositionLike + +# This module implements representations of grain boundaries, as well as +# algorithms for generating them. + +__author__ = "Xiang-Guo Li" +__copyright__ = "Copyright 2018, The Materials Virtual Lab" +__version__ = "0.1" +__maintainer__ = "Xiang-Guo Li" +__email__ = "xil110@ucsd.edu" +__date__ = "7/30/18" + +logger = logging.getLogger(__name__) + + +class GrainBoundary(Structure): + """ + Subclass of Structure representing a GrainBoundary (GB) object. Implements additional + attributes pertaining to gbs, but the init method does not actually implement any + algorithm that creates a GB. This is a DUMMY class who's init method only holds + information about the GB. Also has additional methods that returns other information + about a GB such as sigma value. + + Note that all gbs have the GB surface normal oriented in the c-direction. This means + the lattice vectors a and b are in the GB surface plane (at least for one grain) and + the c vector is out of the surface plane (though not necessarily perpendicular to the + surface). + """ + + def __init__( + self, + lattice: np.ndarray | Lattice, + species: Sequence[CompositionLike], + coords: Sequence[ArrayLike], + rotation_axis: Vector3D, + rotation_angle: float, + gb_plane: Vector3D, + join_plane: Vector3D, + init_cell: Structure, + vacuum_thickness: float, + ab_shift: tuple[float, float], + site_properties: dict[str, Any], + oriented_unit_cell: Structure, + validate_proximity: bool = False, + coords_are_cartesian: bool = False, + properties: dict | None = None, + ) -> None: + """ + Makes a GB structure, a structure object with additional information + and methods pertaining to gbs. + + Args: + lattice (Lattice/3x3 array): The lattice, either as an instance or + any 2D array. Each row should correspond to a lattice vector. + species ([Species]): Sequence of species on each site. Can take in + flexible input, including: + + i. A sequence of element / species specified either as string + symbols, e.g. ["Li", "Fe2+", "P", ...] or atomic numbers, + e.g., (3, 56, ...) or actual Element or Species objects. + + ii. List of dict of elements/species and occupancies, e.g., + [{"Fe" : 0.5, "Mn":0.5}, ...]. This allows the setup of + disordered structures. + coords (Nx3 array): list of fractional/cartesian coordinates for each species. + rotation_axis (list[int]): Rotation axis of GB in the form of a list of integers, e.g. [1, 1, 0]. + rotation_angle (float, in unit of degree): rotation angle of GB. + gb_plane (list): Grain boundary plane in the form of a list of integers + e.g.: [1, 2, 3]. + join_plane (list): Joining plane of the second grain in the form of a list of + integers. e.g.: [1, 2, 3]. + init_cell (Structure): initial bulk structure to form the GB. + site_properties (dict): Properties associated with the sites as a + dict of sequences, The sequences have to be the same length as + the atomic species and fractional_coords. For GB, you should + have the 'grain_label' properties to classify the sites as 'top', + 'bottom', 'top_incident', or 'bottom_incident'. + vacuum_thickness (float in angstrom): The thickness of vacuum inserted + between two grains of the GB. + ab_shift (list of float, in unit of crystal vector a, b): The relative + shift along a, b vectors. + oriented_unit_cell (Structure): oriented unit cell of the bulk init_cell. + Helps to accurately calculate the bulk properties that are consistent + with GB calculations. + validate_proximity (bool): Whether to check if there are sites + that are less than 0.01 Ang apart. Defaults to False. + coords_are_cartesian (bool): Set to True if you are providing + coordinates in Cartesian coordinates. Defaults to False. + properties (dict): dictionary containing properties associated + with the whole GrainBoundary. + """ + self.oriented_unit_cell = oriented_unit_cell + self.rotation_axis = rotation_axis + self.rotation_angle = rotation_angle + self.gb_plane = gb_plane + self.join_plane = join_plane + self.init_cell = init_cell + self.vacuum_thickness = vacuum_thickness + self.ab_shift = ab_shift + super().__init__( + lattice, + species, + coords, + validate_proximity=validate_proximity, + coords_are_cartesian=coords_are_cartesian, + site_properties=site_properties, + properties=properties, + ) + + def copy(self): + """ + Convenience method to get a copy of the structure, with options to add + site properties. + + Returns: + A copy of the Structure, with optionally new site_properties and + optionally sanitized. + """ + return GrainBoundary( + self.lattice, + self.species_and_occu, + self.frac_coords, + self.rotation_axis, + self.rotation_angle, + self.gb_plane, + self.join_plane, + self.init_cell, + self.vacuum_thickness, + self.ab_shift, + self.site_properties, + self.oriented_unit_cell, + ) + + def get_sorted_structure(self, key=None, reverse=False): + """ + Get a sorted copy of the structure. The parameters have the same + meaning as in list.sort. By default, sites are sorted by the + electronegativity of the species. Note that Slab has to override this + because of the different __init__ args. + + Args: + key: Specifies a function of one argument that is used to extract + a comparison key from each list element: key=str.lower. The + default value is None (compare the elements directly). + reverse (bool): If set to True, then the list elements are sorted + as if each comparison were reversed. + """ + sites = sorted(self, key=key, reverse=reverse) + struct = Structure.from_sites(sites) + return GrainBoundary( + struct.lattice, + struct.species_and_occu, + struct.frac_coords, + self.rotation_axis, + self.rotation_angle, + self.gb_plane, + self.join_plane, + self.init_cell, + self.vacuum_thickness, + self.ab_shift, + self.site_properties, + self.oriented_unit_cell, + ) + + @property + def sigma(self) -> int: + """ + This method returns the sigma value of the GB. + If using 'quick_gen' to generate GB, this value is not valid. + """ + return int(round(self.oriented_unit_cell.volume / self.init_cell.volume)) + + @property + def sigma_from_site_prop(self) -> int: + """ + This method returns the sigma value of the GB from site properties. + If the GB structure merge some atoms due to the atoms too closer with + each other, this property will not work. + """ + n_coi = 0 + if None in self.site_properties["grain_label"]: + raise RuntimeError("Site were merged, this property do not work") + for tag in self.site_properties["grain_label"]: + if "incident" in tag: + n_coi += 1 + return int(round(len(self) / n_coi)) + + @property + def top_grain(self) -> Structure: + """Return the top grain (Structure) of the GB.""" + top_sites = [] + for i, tag in enumerate(self.site_properties["grain_label"]): + if "top" in tag: + top_sites.append(self.sites[i]) + return Structure.from_sites(top_sites) + + @property + def bottom_grain(self) -> Structure: + """Return the bottom grain (Structure) of the GB.""" + bottom_sites = [] + for i, tag in enumerate(self.site_properties["grain_label"]): + if "bottom" in tag: + bottom_sites.append(self.sites[i]) + return Structure.from_sites(bottom_sites) + + @property + def coincidents(self) -> list[Site]: + """Return the a list of coincident sites.""" + coincident_sites = [] + for idx, tag in enumerate(self.site_properties["grain_label"]): + if "incident" in tag: + coincident_sites.append(self.sites[idx]) + return coincident_sites + + def __str__(self): + comp = self.composition + outs = [ + f"Gb Summary ({comp.formula})", + f"Reduced Formula: {comp.reduced_formula}", + f"Rotation axis: {self.rotation_axis}", + f"Rotation angle: {self.rotation_angle}", + f"GB plane: {self.gb_plane}", + f"Join plane: {self.join_plane}", + f"vacuum thickness: {self.vacuum_thickness}", + f"ab_shift: {self.ab_shift}", + ] + + def to_str(x, rjust=10): + return (f"{x:0.6f}").rjust(rjust) + + outs += ( + f"abc : {' '.join(to_str(i) for i in self.lattice.abc)}", + f"angles: {' '.join(to_str(i) for i in self.lattice.angles)}", + f"Sites ({len(self)})", + ) + for idx, site in enumerate(self, start=1): + outs.append(f"{idx} {site.species_string} {' '.join(to_str(coord, 12) for coord in site.frac_coords)}") + return "\n".join(outs) + + def as_dict(self): + """ + Returns: + Dictionary representation of GrainBoundary object. + """ + dct = super().as_dict() + dct["@module"] = type(self).__module__ + dct["@class"] = type(self).__name__ + dct["init_cell"] = self.init_cell.as_dict() + dct["rotation_axis"] = self.rotation_axis + dct["rotation_angle"] = self.rotation_angle + dct["gb_plane"] = self.gb_plane + dct["join_plane"] = self.join_plane + dct["vacuum_thickness"] = self.vacuum_thickness + dct["ab_shift"] = self.ab_shift + dct["oriented_unit_cell"] = self.oriented_unit_cell.as_dict() + return dct + + @classmethod + def from_dict(cls, dct: dict) -> GrainBoundary: # type: ignore[override] + """ + Generates a GrainBoundary object from a dictionary created by as_dict(). + + Args: + dct: dict + + Returns: + GrainBoundary object + """ + lattice = Lattice.from_dict(dct["lattice"]) + sites = [PeriodicSite.from_dict(site_dict, lattice) for site_dict in dct["sites"]] + struct = Structure.from_sites(sites) + + return GrainBoundary( + lattice=lattice, + species=struct.species_and_occu, + coords=struct.frac_coords, + rotation_axis=dct["rotation_axis"], + rotation_angle=dct["rotation_angle"], + gb_plane=dct["gb_plane"], + join_plane=dct["join_plane"], + init_cell=Structure.from_dict(dct["init_cell"]), + vacuum_thickness=dct["vacuum_thickness"], + ab_shift=dct["ab_shift"], + oriented_unit_cell=Structure.from_dict(dct["oriented_unit_cell"]), + site_properties=struct.site_properties, + ) + + +class GrainBoundaryGenerator: + """ + This class is to generate grain boundaries (GBs) from bulk + conventional cell (fcc, bcc can from the primitive cell), and works for Cubic, + Tetragonal, Orthorhombic, Rhombohedral, and Hexagonal systems. + It generate GBs from given parameters, which includes + GB plane, rotation axis, rotation angle. + + This class works for any general GB, including twist, tilt and mixed GBs. + The three parameters, rotation axis, GB plane and rotation angle, are + sufficient to identify one unique GB. While sometimes, users may not be able + to tell what exactly rotation angle is but prefer to use sigma as an parameter, + this class also provides the function that is able to return all possible + rotation angles for a specific sigma value. + The same sigma value (with rotation axis fixed) can correspond to + multiple rotation angles. + Users can use structure matcher in pymatgen to get rid of the redundant structures. + """ + + def __init__(self, initial_structure: Structure, symprec: float = 0.1, angle_tolerance: float = 1) -> None: + """ + Args: + initial_structure (Structure): Initial input structure. It can + be conventional or primitive cell (primitive cell works for bcc and fcc). + For fcc and bcc, using conventional cell can lead to a non-primitive + grain boundary structure. + This code supplies Cubic, Tetragonal, Orthorhombic, Rhombohedral, and + Hexagonal systems. + symprec (float): Tolerance for symmetry finding. Defaults to 0.1 (the value used + in Materials Project), which is for structures with slight deviations + from their proper atomic positions (e.g., structures relaxed with + electronic structure codes). + A smaller value of 0.01 is often used for properly refined + structures with atoms in the proper symmetry coordinates. + User should make sure the symmetry is what you want. + angle_tolerance (float): Angle tolerance for symmetry finding. + """ + analyzer = SpacegroupAnalyzer(initial_structure, symprec, angle_tolerance) + self.lat_type = analyzer.get_lattice_type()[0] + if self.lat_type == "t": + # need to use the conventional cell for tetragonal + initial_structure = analyzer.get_conventional_standard_structure() + a, b, c = initial_structure.lattice.abc + # c axis of tetragonal structure not in the third direction + if abs(a - b) > symprec: + # a == c, rotate b to the third direction + if abs(a - c) < symprec: + initial_structure.make_supercell([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) + # b == c, rotate a to the third direction + else: + initial_structure.make_supercell([[0, 1, 0], [0, 0, 1], [1, 0, 0]]) + elif self.lat_type == "h": + alpha, beta, gamma = initial_structure.lattice.angles + # c axis is not in the third direction + if abs(gamma - 90) < angle_tolerance: + # alpha = 120 or 60, rotate b, c to a, b vectors + if abs(alpha - 90) > angle_tolerance: + initial_structure.make_supercell([[0, 1, 0], [0, 0, 1], [1, 0, 0]]) + # beta = 120 or 60, rotate c, a to a, b vectors + elif abs(beta - 90) > angle_tolerance: + initial_structure.make_supercell([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) + elif self.lat_type == "r": + # need to use primitive cell for rhombohedra + initial_structure = analyzer.get_primitive_standard_structure() + elif self.lat_type == "o": + # need to use the conventional cell for orthorhombic + initial_structure = analyzer.get_conventional_standard_structure() + self.initial_structure = initial_structure + + def gb_from_parameters( + self, + rotation_axis, + rotation_angle, + expand_times=4, + vacuum_thickness=0.0, + ab_shift: tuple[float, float] = (0, 0), + normal=False, + ratio=None, + plane=None, + max_search=20, + tol_coi=1.0e-8, + rm_ratio=0.7, + quick_gen=False, + ): + """ + Args: + rotation_axis (list): Rotation axis of GB in the form of a list of integer + e.g.: [1, 1, 0] + rotation_angle (float, in unit of degree): rotation angle used to generate GB. + Make sure the angle is accurate enough. You can use the enum* functions + in this class to extract the accurate angle. + e.g.: The rotation angle of sigma 3 twist GB with the rotation axis + [1, 1, 1] and GB plane (1, 1, 1) can be 60 degree. + If you do not know the rotation angle, but know the sigma value, we have + provide the function get_rotation_angle_from_sigma which is able to return + all the rotation angles of sigma value you provided. + expand_times (int): The multiple times used to expand one unit grain to larger grain. + This is used to tune the grain length of GB to warrant that the two GBs in one + cell do not interact with each other. Default set to 4. + vacuum_thickness (float, in angstrom): The thickness of vacuum that you want to insert + between two grains of the GB. Default to 0. + ab_shift (list of float, in unit of a, b vectors of Gb): in plane shift of two grains + normal (logic): + determine if need to require the c axis of top grain (first transformation matrix) + perpendicular to the surface or not. + default to false. + ratio (list of integers): + lattice axial ratio. + For cubic system, ratio is not needed. + For tetragonal system, ratio = [mu, mv], list of two integers, + that is, mu/mv = c2/a2. If it is irrational, set it to none. + For orthorhombic system, ratio = [mu, lam, mv], list of 3 integers, + that is, mu:lam:mv = c2:b2:a2. If irrational for one axis, set it to None. + e.g. mu:lam:mv = c2,None,a2, means b2 is irrational. + For rhombohedral system, ratio = [mu, mv], list of two integers, + that is, mu/mv is the ratio of (1+2*cos(alpha))/cos(alpha). + If irrational, set it to None. + For hexagonal system, ratio = [mu, mv], list of two integers, + that is, mu/mv = c2/a2. If it is irrational, set it to none. + This code also supplies a class method to generate the ratio from the + structure (get_ratio). User can also make their own approximation and + input the ratio directly. + plane (list): Grain boundary plane in the form of a list of integers + e.g.: [1, 2, 3]. If none, we set it as twist GB. The plane will be perpendicular + to the rotation axis. + max_search (int): max search for the GB lattice vectors that give the smallest GB + lattice. If normal is true, also max search the GB c vector that perpendicular + to the plane. For complex GB, if you want to speed up, you can reduce this value. + But too small of this value may lead to error. + tol_coi (float): tolerance to find the coincidence sites. When making approximations to + the ratio needed to generate the GB, you probably need to increase this tolerance to + obtain the correct number of coincidence sites. To check the number of coincidence + sites are correct or not, you can compare the generated Gb object's sigma_from_site_prop + with enum* sigma values (what user expected by input). + rm_ratio (float): the criteria to remove the atoms which are too close with each other. + rm_ratio*bond_length of bulk system is the criteria of bond length, below which the atom + will be removed. Default to 0.7. + quick_gen (bool): whether to quickly generate a supercell, if set to true, no need to + find the smallest cell. + + Returns: + Grain boundary structure (GB object). + """ + lat_type = self.lat_type + # if the initial structure is primitive cell in cubic system, + # calculate the transformation matrix from its conventional cell + # to primitive cell, basically for bcc and fcc systems. + trans_cry = np.eye(3) + if lat_type == "c": + analyzer = SpacegroupAnalyzer(self.initial_structure) + convention_cell = analyzer.get_conventional_standard_structure() + vol_ratio = self.initial_structure.volume / convention_cell.volume + # bcc primitive cell, belong to cubic system + if abs(vol_ratio - 0.5) < 1.0e-3: + trans_cry = np.array([[0.5, 0.5, -0.5], [-0.5, 0.5, 0.5], [0.5, -0.5, 0.5]]) + logger.info("Make sure this is for cubic with bcc primitive cell") + # fcc primitive cell, belong to cubic system + elif abs(vol_ratio - 0.25) < 1.0e-3: + trans_cry = np.array([[0.5, 0.5, 0], [0, 0.5, 0.5], [0.5, 0, 0.5]]) + logger.info("Make sure this is for cubic with fcc primitive cell") + else: + logger.info("Make sure this is for cubic with conventional cell") + elif lat_type == "t": + logger.info("Make sure this is for tetragonal system") + if ratio is None: + logger.info("Make sure this is for irrational c2/a2") + elif len(ratio) != 2: + raise RuntimeError("Tetragonal system needs correct c2/a2 ratio") + elif lat_type == "o": + logger.info("Make sure this is for orthorhombic system") + if ratio is None: + raise RuntimeError("CSL does not exist if all axial ratios are irrational for an orthorhombic system") + if len(ratio) != 3: + raise RuntimeError("Orthorhombic system needs correct c2:b2:a2 ratio") + elif lat_type == "h": + logger.info("Make sure this is for hexagonal system") + if ratio is None: + logger.info("Make sure this is for irrational c2/a2") + elif len(ratio) != 2: + raise RuntimeError("Hexagonal system needs correct c2/a2 ratio") + elif lat_type == "r": + logger.info("Make sure this is for rhombohedral system") + if ratio is None: + logger.info("Make sure this is for irrational (1+2*cos(alpha)/cos(alpha) ratio") + elif len(ratio) != 2: + raise RuntimeError("Rhombohedral system needs correct (1+2*cos(alpha)/cos(alpha) ratio") + else: + raise RuntimeError( + "Lattice type not implemented. This code works for cubic, " + "tetragonal, orthorhombic, rhombohedral, hexagonal systems" + ) + + # transform four index notation to three index notation for hexagonal and rhombohedral + if len(rotation_axis) == 4: + u1 = rotation_axis[0] + v1 = rotation_axis[1] + w1 = rotation_axis[3] + if lat_type.lower() == "h": + u = 2 * u1 + v1 + v = 2 * v1 + u1 + w = w1 + rotation_axis = [u, v, w] + elif lat_type.lower() == "r": + u = 2 * u1 + v1 + w1 + v = v1 + w1 - u1 + w = w1 - 2 * v1 - u1 + rotation_axis = [u, v, w] + + # make sure gcd(rotation_axis)==1 + if reduce(gcd, rotation_axis) != 1: + rotation_axis = [int(round(x / reduce(gcd, rotation_axis))) for x in rotation_axis] + # transform four index notation to three index notation for plane + if plane is not None and len(plane) == 4: + u1, v1, w1 = plane[0], plane[1], plane[3] + plane = [u1, v1, w1] + # set the plane for grain boundary when plane is None. + if plane is None: + if lat_type.lower() == "c": + plane = rotation_axis + else: + if lat_type.lower() == "h": + c2_a2_ratio = 1.0 if ratio is None else ratio[0] / ratio[1] + metric = np.array([[1, -0.5, 0], [-0.5, 1, 0], [0, 0, c2_a2_ratio]]) + elif lat_type.lower() == "r": + cos_alpha = 0.5 if ratio is None else 1.0 / (ratio[0] / ratio[1] - 2) + metric = np.array( + [ + [1, cos_alpha, cos_alpha], + [cos_alpha, 1, cos_alpha], + [cos_alpha, cos_alpha, 1], + ] + ) + elif lat_type.lower() == "t": + c2_a2_ratio = 1.0 if ratio is None else ratio[0] / ratio[1] + metric = np.array([[1, 0, 0], [0, 1, 0], [0, 0, c2_a2_ratio]]) + elif lat_type.lower() == "o": + for idx in range(3): + if ratio[idx] is None: + ratio[idx] = 1 + metric = np.array([[1, 0, 0], [0, ratio[1] / ratio[2], 0], [0, 0, ratio[0] / ratio[2]]]) + else: + raise RuntimeError("Lattice type has not implemented.") + + plane = np.matmul(rotation_axis, metric) + fractions = [Fraction(x).limit_denominator() for x in plane] + least_mul = reduce(lcm, [fraction.denominator for fraction in fractions]) + plane = [int(round(x * least_mul)) for x in plane] + + if reduce(gcd, plane) != 1: + index = reduce(gcd, plane) + plane = [int(round(x / index)) for x in plane] + + t1, t2 = self.get_trans_mat( + r_axis=rotation_axis, + angle=rotation_angle, + normal=normal, + trans_cry=trans_cry, + lat_type=lat_type, + ratio=ratio, + surface=plane, + max_search=max_search, + quick_gen=quick_gen, + ) + + # find the join_plane + if lat_type.lower() != "c": + if lat_type.lower() == "h": + if ratio is None: + mu, mv = [1, 1] + else: + mu, mv = ratio + trans_cry1 = np.array([[1, 0, 0], [-0.5, np.sqrt(3.0) / 2.0, 0], [0, 0, np.sqrt(mu / mv)]]) + elif lat_type.lower() == "r": + if ratio is None: + c2_a2_ratio = 1.0 + else: + mu, mv = ratio + c2_a2_ratio = 3 / (2 - 6 * mv / mu) + trans_cry1 = np.array( + [ + [0.5, np.sqrt(3.0) / 6.0, 1.0 / 3 * np.sqrt(c2_a2_ratio)], + [-0.5, np.sqrt(3.0) / 6.0, 1.0 / 3 * np.sqrt(c2_a2_ratio)], + [0, -1 * np.sqrt(3.0) / 3.0, 1.0 / 3 * np.sqrt(c2_a2_ratio)], + ] + ) + else: + if lat_type.lower() == "t": + if ratio is None: + mu, mv = [1, 1] + else: + mu, mv = ratio + lam = mv + elif lat_type.lower() == "o": + new_ratio = [1 if v is None else v for v in ratio] + mu, lam, mv = new_ratio + trans_cry1 = np.array([[1, 0, 0], [0, np.sqrt(lam / mv), 0], [0, 0, np.sqrt(mu / mv)]]) + else: + trans_cry1 = trans_cry + grain_matrix = np.dot(t2, trans_cry1) + plane_init = np.cross(grain_matrix[0], grain_matrix[1]) + if lat_type.lower() != "c": + plane_init = np.dot(plane_init, trans_cry1.T) + join_plane = self.vec_to_surface(plane_init) + + parent_structure = self.initial_structure.copy() + # calculate the bond_length in bulk system. + if len(parent_structure) == 1: + temp_str = parent_structure.copy() + temp_str.make_supercell([1, 1, 2]) + distance = temp_str.distance_matrix + else: + distance = parent_structure.distance_matrix + bond_length = np.min(distance[np.nonzero(distance)]) + + # top grain + top_grain = fix_pbc(parent_structure * t1) + + # obtain the smallest oriented cell + if normal and not quick_gen: + t_temp = self.get_trans_mat( + r_axis=rotation_axis, + angle=rotation_angle, + normal=False, + trans_cry=trans_cry, + lat_type=lat_type, + ratio=ratio, + surface=plane, + max_search=max_search, + ) + oriented_unit_cell = fix_pbc(parent_structure * t_temp[0]) + t_matrix = oriented_unit_cell.lattice.matrix + normal_v_plane = np.cross(t_matrix[0], t_matrix[1]) + unit_normal_v = normal_v_plane / np.linalg.norm(normal_v_plane) + unit_ab_adjust = (t_matrix[2] - np.dot(unit_normal_v, t_matrix[2]) * unit_normal_v) / np.dot( + unit_normal_v, t_matrix[2] + ) + else: + oriented_unit_cell = top_grain.copy() + unit_ab_adjust = 0.0 + + # bottom grain, using top grain's lattice matrix + bottom_grain = fix_pbc(parent_structure * t2, top_grain.lattice.matrix) + + # label both grains with 'top','bottom','top_incident','bottom_incident' + n_sites = len(top_grain) + t_and_b = Structure( + top_grain.lattice, + top_grain.species + bottom_grain.species, + list(top_grain.frac_coords) + list(bottom_grain.frac_coords), + ) + t_and_b_dis = t_and_b.lattice.get_all_distances( + t_and_b.frac_coords[0:n_sites], t_and_b.frac_coords[n_sites : n_sites * 2] + ) + index_incident = np.nonzero(t_and_b_dis < np.min(t_and_b_dis) + tol_coi) + + top_labels = [] + for idx in range(n_sites): + if idx in index_incident[0]: + top_labels.append("top_incident") + else: + top_labels.append("top") + bottom_labels = [] + for idx in range(n_sites): + if idx in index_incident[1]: + bottom_labels.append("bottom_incident") + else: + bottom_labels.append("bottom") + top_grain = Structure( + Lattice(top_grain.lattice.matrix), + top_grain.species, + top_grain.frac_coords, + site_properties={"grain_label": top_labels}, + ) + bottom_grain = Structure( + Lattice(bottom_grain.lattice.matrix), + bottom_grain.species, + bottom_grain.frac_coords, + site_properties={"grain_label": bottom_labels}, + ) + + # expand both grains + top_grain.make_supercell([1, 1, expand_times]) + bottom_grain.make_supercell([1, 1, expand_times]) + top_grain = fix_pbc(top_grain) + bottom_grain = fix_pbc(bottom_grain) + + # determine the top-grain location. + edge_b = 1.0 - max(bottom_grain.frac_coords[:, 2]) + edge_t = 1.0 - max(top_grain.frac_coords[:, 2]) + c_adjust = (edge_t - edge_b) / 2.0 + + # construct all species + all_species = [] + all_species.extend([site.specie for site in bottom_grain]) + all_species.extend([site.specie for site in top_grain]) + + half_lattice = top_grain.lattice + # calculate translation vector, perpendicular to the plane + normal_v_plane = np.cross(half_lattice.matrix[0], half_lattice.matrix[1]) + unit_normal_v = normal_v_plane / np.linalg.norm(normal_v_plane) + translation_v = unit_normal_v * vacuum_thickness + + # construct the final lattice + whole_matrix_no_vac = np.array(half_lattice.matrix) + whole_matrix_no_vac[2] = half_lattice.matrix[2] * 2 + whole_matrix_with_vac = whole_matrix_no_vac.copy() + whole_matrix_with_vac[2] = whole_matrix_no_vac[2] + translation_v * 2 + whole_lat = Lattice(whole_matrix_with_vac) + + # construct the coords, move top grain with translation_v + all_coords = [] + grain_labels = bottom_grain.site_properties["grain_label"] + top_grain.site_properties["grain_label"] + for site in bottom_grain: + all_coords.append(site.coords) + for site in top_grain: + all_coords.append( + site.coords + + half_lattice.matrix[2] * (1 + c_adjust) + + unit_ab_adjust * np.linalg.norm(half_lattice.matrix[2] * (1 + c_adjust)) + + translation_v + + ab_shift[0] * whole_matrix_with_vac[0] + + ab_shift[1] * whole_matrix_with_vac[1] + ) + + gb_with_vac = Structure( + whole_lat, + all_species, + all_coords, + coords_are_cartesian=True, + site_properties={"grain_label": grain_labels}, + ) + # merge closer atoms. extract near GB atoms. + cos_c_norm_plane = np.dot(unit_normal_v, whole_matrix_with_vac[2]) / whole_lat.c + range_c_len = abs(bond_length / cos_c_norm_plane / whole_lat.c) + sites_near_gb = [] + sites_away_gb: list[PeriodicSite] = [] + for site in gb_with_vac: + if ( + site.frac_coords[2] < range_c_len + or site.frac_coords[2] > 1 - range_c_len + or (site.frac_coords[2] > 0.5 - range_c_len and site.frac_coords[2] < 0.5 + range_c_len) + ): + sites_near_gb.append(site) + else: + sites_away_gb.append(site) + if len(sites_near_gb) >= 1: + s_near_gb = Structure.from_sites(sites_near_gb) + s_near_gb.merge_sites(tol=bond_length * rm_ratio, mode="d") + all_sites = sites_away_gb + s_near_gb.sites # type: ignore + gb_with_vac = Structure.from_sites(all_sites) + + # move coordinates into the periodic cell. + gb_with_vac = fix_pbc(gb_with_vac, whole_lat.matrix) + return GrainBoundary( + whole_lat, + gb_with_vac.species, + gb_with_vac.cart_coords, # type: ignore[arg-type] + rotation_axis, + rotation_angle, + plane, + join_plane, + self.initial_structure, + vacuum_thickness, + ab_shift, + site_properties=gb_with_vac.site_properties, + oriented_unit_cell=oriented_unit_cell, + coords_are_cartesian=True, + ) + + def get_ratio(self, max_denominator=5, index_none=None): + """ + find the axial ratio needed for GB generator input. + + Args: + max_denominator (int): the maximum denominator for + the computed ratio, default to be 5. + index_none (int): specify the irrational axis. + 0-a, 1-b, 2-c. Only may be needed for orthorhombic system. + + Returns: + axial ratio needed for GB generator (list of integers). + """ + structure = self.initial_structure + lat_type = self.lat_type + if lat_type in ("t", "h"): + # For tetragonal and hexagonal system, ratio = c2 / a2. + a, _, c = structure.lattice.lengths + if c > a: + frac = Fraction(c**2 / a**2).limit_denominator(max_denominator) + ratio = [frac.numerator, frac.denominator] + else: + frac = Fraction(a**2 / c**2).limit_denominator(max_denominator) + ratio = [frac.denominator, frac.numerator] + elif lat_type == "r": + # For rhombohedral system, ratio = (1 + 2 * cos(alpha)) / cos(alpha). + cos_alpha = cos(structure.lattice.alpha / 180 * np.pi) + frac = Fraction((1 + 2 * cos_alpha) / cos_alpha).limit_denominator(max_denominator) + ratio = [frac.numerator, frac.denominator] + elif lat_type == "o": + # For orthorhombic system, ratio = c2:b2:a2.If irrational for one axis, set it to None. + ratio = [None] * 3 + lat = (structure.lattice.c, structure.lattice.b, structure.lattice.a) + index = [0, 1, 2] + if index_none is None: + min_index = np.argmin(lat) + index.pop(min_index) + frac1 = Fraction(lat[index[0]] ** 2 / lat[min_index] ** 2).limit_denominator(max_denominator) + frac2 = Fraction(lat[index[1]] ** 2 / lat[min_index] ** 2).limit_denominator(max_denominator) + com_lcm = lcm(frac1.denominator, frac2.denominator) + ratio[min_index] = com_lcm + ratio[index[0]] = frac1.numerator * int(round(com_lcm / frac1.denominator)) + ratio[index[1]] = frac2.numerator * int(round(com_lcm / frac2.denominator)) + else: + index.pop(index_none) + if lat[index[0]] > lat[index[1]]: + frac = Fraction(lat[index[0]] ** 2 / lat[index[1]] ** 2).limit_denominator(max_denominator) + ratio[index[0]] = frac.numerator + ratio[index[1]] = frac.denominator + else: + frac = Fraction(lat[index[1]] ** 2 / lat[index[0]] ** 2).limit_denominator(max_denominator) + ratio[index[1]] = frac.numerator + ratio[index[0]] = frac.denominator + elif lat_type == "c": + # Cubic system does not need axial ratio. + return None + else: + raise RuntimeError("Lattice type not implemented.") + return ratio + + @staticmethod + def get_trans_mat( + r_axis, + angle, + normal=False, + trans_cry=None, + lat_type="c", + ratio=None, + surface=None, + max_search=20, + quick_gen=False, + ): + """ + Find the two transformation matrix for each grain from given rotation axis, + GB plane, rotation angle and corresponding ratio (see explanation for ratio + below). + The structure of each grain can be obtained by applying the corresponding + transformation matrix to the conventional cell. + The algorithm for this code is from reference, Acta Cryst, A32,783(1976). + + Args: + r_axis (list of 3 integers, e.g. u, v, w or 4 integers, e.g. u, v, t, w for hex/rho system only): the + rotation axis of the grain boundary. + angle (float, in unit of degree): the rotation angle of the grain boundary + normal (logic): determine if need to require the c axis of one grain associated with + the first transformation matrix perpendicular to the surface or not. + default to false. + trans_cry (np.array): shape 3x3. If the structure given are primitive cell in cubic system, e.g. + bcc or fcc system, trans_cry is the transformation matrix from its + conventional cell to the primitive cell. + lat_type (str): one character to specify the lattice type. Defaults to 'c' for cubic. + 'c' or 'C': cubic system + 't' or 'T': tetragonal system + 'o' or 'O': orthorhombic system + 'h' or 'H': hexagonal system + 'r' or 'R': rhombohedral system + ratio (list of integers): lattice axial ratio. + For cubic system, ratio is not needed. + For tetragonal system, ratio = [mu, mv], list of two integers, that is, mu/mv = c2/a2. If it is + irrational, set it to none. + For orthorhombic system, ratio = [mu, lam, mv], list of 3 integers, that is, mu:lam:mv = c2:b2:a2. + If irrational for one axis, set it to None. e.g. mu:lam:mv = c2,None,a2, means b2 is irrational. + For rhombohedral system, ratio = [mu, mv], list of two integers, + that is, mu/mv is the ratio of (1+2*cos(alpha)/cos(alpha). + If irrational, set it to None. + For hexagonal system, ratio = [mu, mv], list of two integers, + that is, mu/mv = c2/a2. If it is irrational, set it to none. + surface (list of 3 integers, e.g. h, k, l or 4 integers, e.g. h, k, i, l for hex/rho system only): The + miller index of grain boundary plane, with the format of [h,k,l] if surface is not given, the default + is perpendicular to r_axis, which is a twist grain boundary. + max_search (int): max search for the GB lattice vectors that give the smallest GB + lattice. If normal is true, also max search the GB c vector that perpendicular + to the plane. + quick_gen (bool): whether to quickly generate a supercell, if set to true, no need to + find the smallest cell. + + Returns: + t1 (3 by 3 integer array): The transformation array for one grain. + t2 (3 by 3 integer array): The transformation array for the other grain + """ + if trans_cry is None: + trans_cry = np.eye(3) + # transform four index notation to three index notation + if len(r_axis) == 4: + u1 = r_axis[0] + v1 = r_axis[1] + w1 = r_axis[3] + if lat_type.lower() == "h": + u = 2 * u1 + v1 + v = 2 * v1 + u1 + w = w1 + r_axis = [u, v, w] + elif lat_type.lower() == "r": + u = 2 * u1 + v1 + w1 + v = v1 + w1 - u1 + w = w1 - 2 * v1 - u1 + r_axis = [u, v, w] + + # make sure gcd(r_axis)==1 + if reduce(gcd, r_axis) != 1: + r_axis = [int(round(x / reduce(gcd, r_axis))) for x in r_axis] + + if surface is not None and len(surface) == 4: + u1 = surface[0] + v1 = surface[1] + w1 = surface[3] + surface = [u1, v1, w1] + # set the surface for grain boundary. + if surface is None: + if lat_type.lower() == "c": + surface = r_axis + else: + if lat_type.lower() == "h": + c2_a2_ratio = 1.0 if ratio is None else ratio[0] / ratio[1] + metric = np.array([[1, -0.5, 0], [-0.5, 1, 0], [0, 0, c2_a2_ratio]]) + elif lat_type.lower() == "r": + cos_alpha = 0.5 if ratio is None else 1.0 / (ratio[0] / ratio[1] - 2) + metric = np.array( + [ + [1, cos_alpha, cos_alpha], + [cos_alpha, 1, cos_alpha], + [cos_alpha, cos_alpha, 1], + ] + ) + elif lat_type.lower() == "t": + c2_a2_ratio = 1.0 if ratio is None else ratio[0] / ratio[1] + metric = np.array([[1, 0, 0], [0, 1, 0], [0, 0, c2_a2_ratio]]) + elif lat_type.lower() == "o": + for idx in range(3): + if ratio[idx] is None: + ratio[idx] = 1 + metric = np.array( + [ + [1, 0, 0], + [0, ratio[1] / ratio[2], 0], + [0, 0, ratio[0] / ratio[2]], + ] + ) + else: + raise RuntimeError("Lattice type has not implemented.") + + surface = np.matmul(r_axis, metric) + fractions = [Fraction(x).limit_denominator() for x in surface] + least_mul = reduce(lcm, [fraction.denominator for fraction in fractions]) + surface = [int(round(x * least_mul)) for x in surface] + + if reduce(gcd, surface) != 1: + index = reduce(gcd, surface) + surface = [int(round(x / index)) for x in surface] + + if lat_type.lower() == "h": + # set the value for u,v,w,mu,mv,m,n,d,x + # check the reference for the meaning of these parameters + u, v, w = r_axis + # make sure mu, mv are coprime integers. + if ratio is None: + mu, mv = [1, 1] + if w != 0 and (u != 0 or (v != 0)): + raise RuntimeError("For irrational c2/a2, CSL only exist for [0,0,1] or [u,v,0] and m = 0") + else: + mu, mv = ratio + if gcd(mu, mv) != 1: + temp = gcd(mu, mv) + mu = int(round(mu / temp)) + mv = int(round(mv / temp)) + d = (u**2 + v**2 - u * v) * mv + w**2 * mu + if abs(angle - 180.0) < 1.0e0: + m = 0 + n = 1 + else: + fraction = Fraction( + np.tan(angle / 2 / 180.0 * np.pi) / np.sqrt(float(d) / 3.0 / mu) + ).limit_denominator() + m = fraction.denominator + n = fraction.numerator + + # construct the rotation matrix, check reference for details + r_list = [ + (u**2 * mv - v**2 * mv - w**2 * mu) * n**2 + 2 * w * mu * m * n + 3 * mu * m**2, + (2 * v - u) * u * mv * n**2 - 4 * w * mu * m * n, + 2 * u * w * mu * n**2 + 2 * (2 * v - u) * mu * m * n, + (2 * u - v) * v * mv * n**2 + 4 * w * mu * m * n, + (v**2 * mv - u**2 * mv - w**2 * mu) * n**2 - 2 * w * mu * m * n + 3 * mu * m**2, + 2 * v * w * mu * n**2 - 2 * (2 * u - v) * mu * m * n, + (2 * u - v) * w * mv * n**2 - 3 * v * mv * m * n, + (2 * v - u) * w * mv * n**2 + 3 * u * mv * m * n, + (w**2 * mu - u**2 * mv - v**2 * mv + u * v * mv) * n**2 + 3 * mu * m**2, + ] + m = -1 * m + r_list_inv = [ + (u**2 * mv - v**2 * mv - w**2 * mu) * n**2 + 2 * w * mu * m * n + 3 * mu * m**2, + (2 * v - u) * u * mv * n**2 - 4 * w * mu * m * n, + 2 * u * w * mu * n**2 + 2 * (2 * v - u) * mu * m * n, + (2 * u - v) * v * mv * n**2 + 4 * w * mu * m * n, + (v**2 * mv - u**2 * mv - w**2 * mu) * n**2 - 2 * w * mu * m * n + 3 * mu * m**2, + 2 * v * w * mu * n**2 - 2 * (2 * u - v) * mu * m * n, + (2 * u - v) * w * mv * n**2 - 3 * v * mv * m * n, + (2 * v - u) * w * mv * n**2 + 3 * u * mv * m * n, + (w**2 * mu - u**2 * mv - v**2 * mv + u * v * mv) * n**2 + 3 * mu * m**2, + ] + m = -1 * m + F = 3 * mu * m**2 + d * n**2 + all_list = r_list + r_list_inv + [F] + com_fac = reduce(gcd, all_list) + sigma = F / com_fac + r_matrix = (np.array(r_list) / com_fac / sigma).reshape(3, 3) + elif lat_type.lower() == "r": + # set the value for u,v,w,mu,mv,m,n,d + # check the reference for the meaning of these parameters + u, v, w = r_axis + # make sure mu, mv are coprime integers. + if ratio is None: + mu, mv = [1, 1] + if u + v + w != 0 and (u != v or u != w): + raise RuntimeError( + "For irrational ratio_alpha, CSL only exist for [1,1,1] or [u, v, -(u+v)] and m =0" + ) + else: + mu, mv = ratio + if gcd(mu, mv) != 1: + temp = gcd(mu, mv) + mu = int(round(mu / temp)) + mv = int(round(mv / temp)) + d = (u**2 + v**2 + w**2) * (mu - 2 * mv) + 2 * mv * (v * w + w * u + u * v) + if abs(angle - 180.0) < 1.0e0: + m = 0 + n = 1 + else: + fraction = Fraction(np.tan(angle / 2 / 180.0 * np.pi) / np.sqrt(float(d) / mu)).limit_denominator() + m = fraction.denominator + n = fraction.numerator + + # construct the rotation matrix, check reference for details + r_list = [ + (mu - 2 * mv) * (u**2 - v**2 - w**2) * n**2 + + 2 * mv * (v - w) * m * n + - 2 * mv * v * w * n**2 + + mu * m**2, + 2 * (mv * u * n * (w * n + u * n - m) - (mu - mv) * m * w * n + (mu - 2 * mv) * u * v * n**2), + 2 * (mv * u * n * (v * n + u * n + m) + (mu - mv) * m * v * n + (mu - 2 * mv) * w * u * n**2), + 2 * (mv * v * n * (w * n + v * n + m) + (mu - mv) * m * w * n + (mu - 2 * mv) * u * v * n**2), + (mu - 2 * mv) * (v**2 - w**2 - u**2) * n**2 + + 2 * mv * (w - u) * m * n + - 2 * mv * u * w * n**2 + + mu * m**2, + 2 * (mv * v * n * (v * n + u * n - m) - (mu - mv) * m * u * n + (mu - 2 * mv) * w * v * n**2), + 2 * (mv * w * n * (w * n + v * n - m) - (mu - mv) * m * v * n + (mu - 2 * mv) * w * u * n**2), + 2 * (mv * w * n * (w * n + u * n + m) + (mu - mv) * m * u * n + (mu - 2 * mv) * w * v * n**2), + (mu - 2 * mv) * (w**2 - u**2 - v**2) * n**2 + + 2 * mv * (u - v) * m * n + - 2 * mv * u * v * n**2 + + mu * m**2, + ] + m = -1 * m + r_list_inv = [ + (mu - 2 * mv) * (u**2 - v**2 - w**2) * n**2 + + 2 * mv * (v - w) * m * n + - 2 * mv * v * w * n**2 + + mu * m**2, + 2 * (mv * u * n * (w * n + u * n - m) - (mu - mv) * m * w * n + (mu - 2 * mv) * u * v * n**2), + 2 * (mv * u * n * (v * n + u * n + m) + (mu - mv) * m * v * n + (mu - 2 * mv) * w * u * n**2), + 2 * (mv * v * n * (w * n + v * n + m) + (mu - mv) * m * w * n + (mu - 2 * mv) * u * v * n**2), + (mu - 2 * mv) * (v**2 - w**2 - u**2) * n**2 + + 2 * mv * (w - u) * m * n + - 2 * mv * u * w * n**2 + + mu * m**2, + 2 * (mv * v * n * (v * n + u * n - m) - (mu - mv) * m * u * n + (mu - 2 * mv) * w * v * n**2), + 2 * (mv * w * n * (w * n + v * n - m) - (mu - mv) * m * v * n + (mu - 2 * mv) * w * u * n**2), + 2 * (mv * w * n * (w * n + u * n + m) + (mu - mv) * m * u * n + (mu - 2 * mv) * w * v * n**2), + (mu - 2 * mv) * (w**2 - u**2 - v**2) * n**2 + + 2 * mv * (u - v) * m * n + - 2 * mv * u * v * n**2 + + mu * m**2, + ] + m = -1 * m + F = mu * m**2 + d * n**2 + all_list = r_list_inv + r_list + [F] + com_fac = reduce(gcd, all_list) + sigma = F / com_fac + r_matrix = (np.array(r_list) / com_fac / sigma).reshape(3, 3) + else: + u, v, w = r_axis + if lat_type.lower() == "c": + mu = 1 + lam = 1 + mv = 1 + elif lat_type.lower() == "t": + if ratio is None: + mu, mv = [1, 1] + if w != 0 and (u != 0 or (v != 0)): + raise RuntimeError("For irrational c2/a2, CSL only exist for [0,0,1] or [u,v,0] and m = 0") + else: + mu, mv = ratio + lam = mv + elif lat_type.lower() == "o": + if None in ratio: + mu, lam, mv = ratio + non_none = [i for i in ratio if i is not None] + if len(non_none) < 2: + raise RuntimeError("No CSL exist for two irrational numbers") + non1, non2 = non_none + if mu is None: + lam = non1 + mv = non2 + mu = 1 + if w != 0 and (u != 0 or (v != 0)): + raise RuntimeError("For irrational c2, CSL only exist for [0,0,1] or [u,v,0] and m = 0") + elif lam is None: + mu = non1 + mv = non2 + lam = 1 + if v != 0 and (u != 0 or (w != 0)): + raise RuntimeError("For irrational b2, CSL only exist for [0,1,0] or [u,0,w] and m = 0") + elif mv is None: + mu = non1 + lam = non2 + mv = 1 + if u != 0 and (w != 0 or (v != 0)): + raise RuntimeError("For irrational a2, CSL only exist for [1,0,0] or [0,v,w] and m = 0") + else: + mu, lam, mv = ratio + if u == 0 and v == 0: + mu = 1 + if u == 0 and w == 0: + lam = 1 + if v == 0 and w == 0: + mv = 1 + + # make sure mu, lambda, mv are coprime integers. + if reduce(gcd, [mu, lam, mv]) != 1: + temp = reduce(gcd, [mu, lam, mv]) + mu = int(round(mu / temp)) + mv = int(round(mv / temp)) + lam = int(round(lam / temp)) + d = (mv * u**2 + lam * v**2) * mv + w**2 * mu * mv + if abs(angle - 180.0) < 1.0e0: + m = 0 + n = 1 + else: + fraction = Fraction(np.tan(angle / 2 / 180.0 * np.pi) / np.sqrt(d / mu / lam)).limit_denominator() + m = fraction.denominator + n = fraction.numerator + r_list = [ + (u**2 * mv * mv - lam * v**2 * mv - w**2 * mu * mv) * n**2 + lam * mu * m**2, + 2 * lam * (v * u * mv * n**2 - w * mu * m * n), + 2 * mu * (u * w * mv * n**2 + v * lam * m * n), + 2 * mv * (u * v * mv * n**2 + w * mu * m * n), + (v**2 * mv * lam - u**2 * mv * mv - w**2 * mu * mv) * n**2 + lam * mu * m**2, + 2 * mv * mu * (v * w * n**2 - u * m * n), + 2 * mv * (u * w * mv * n**2 - v * lam * m * n), + 2 * lam * mv * (v * w * n**2 + u * m * n), + (w**2 * mu * mv - u**2 * mv * mv - v**2 * mv * lam) * n**2 + lam * mu * m**2, + ] + m = -1 * m + r_list_inv = [ + (u**2 * mv * mv - lam * v**2 * mv - w**2 * mu * mv) * n**2 + lam * mu * m**2, + 2 * lam * (v * u * mv * n**2 - w * mu * m * n), + 2 * mu * (u * w * mv * n**2 + v * lam * m * n), + 2 * mv * (u * v * mv * n**2 + w * mu * m * n), + (v**2 * mv * lam - u**2 * mv * mv - w**2 * mu * mv) * n**2 + lam * mu * m**2, + 2 * mv * mu * (v * w * n**2 - u * m * n), + 2 * mv * (u * w * mv * n**2 - v * lam * m * n), + 2 * lam * mv * (v * w * n**2 + u * m * n), + (w**2 * mu * mv - u**2 * mv * mv - v**2 * mv * lam) * n**2 + lam * mu * m**2, + ] + m = -1 * m + F = mu * lam * m**2 + d * n**2 + all_list = r_list + r_list_inv + [F] + com_fac = reduce(gcd, all_list) + sigma = F / com_fac + r_matrix = (np.array(r_list) / com_fac / sigma).reshape(3, 3) + + if sigma > 1000: + raise RuntimeError("Sigma >1000 too large. Are you sure what you are doing, Please check the GB if exist") + # transform surface, r_axis, r_matrix in terms of primitive lattice + surface = np.matmul(surface, np.transpose(trans_cry)) + fractions = [Fraction(x).limit_denominator() for x in surface] + least_mul = reduce(lcm, [fraction.denominator for fraction in fractions]) + surface = [int(round(x * least_mul)) for x in surface] + if reduce(gcd, surface) != 1: + index = reduce(gcd, surface) + surface = [int(round(x / index)) for x in surface] + r_axis = np.rint(np.matmul(r_axis, np.linalg.inv(trans_cry))).astype(int) + if reduce(gcd, r_axis) != 1: + r_axis = [int(round(x / reduce(gcd, r_axis))) for x in r_axis] + r_matrix = np.dot(np.dot(np.linalg.inv(trans_cry.T), r_matrix), trans_cry.T) + # set one vector of the basis to the rotation axis direction, and + # obtain the corresponding transform matrix + eye = np.eye(3, dtype=int) + for hh in range(3): + if abs(r_axis[hh]) != 0: + eye[hh] = np.array(r_axis) + kk = hh + 1 if hh + 1 < 3 else abs(2 - hh) + ll = hh + 2 if hh + 2 < 3 else abs(1 - hh) + break + trans = eye.T + new_rot = np.array(r_matrix) + + # with the rotation matrix to construct the CSL lattice, check reference for details + fractions = [Fraction(x).limit_denominator() for x in new_rot[:, kk]] + least_mul = reduce(lcm, [fraction.denominator for fraction in fractions]) + scale = np.zeros((3, 3)) + scale[hh, hh] = 1 + scale[kk, kk] = least_mul + scale[ll, ll] = sigma / least_mul + for idx in range(least_mul): + check_int = idx * new_rot[:, kk] + (sigma / least_mul) * new_rot[:, ll] + if all(np.round(x, 5).is_integer() for x in list(check_int)): + n_final = idx + break + try: + n_final # noqa: B018 + except NameError: + raise RuntimeError("Something is wrong. Check if this GB exists or not") + scale[kk, ll] = n_final + # each row of mat_csl is the CSL lattice vector + csl_init = np.rint(np.dot(np.dot(r_matrix, trans), scale)).astype(int).T + if abs(r_axis[hh]) > 1: + csl_init = GrainBoundaryGenerator.reduce_mat(np.array(csl_init), r_axis[hh], r_matrix) + csl = np.rint(Lattice(csl_init).get_niggli_reduced_lattice().matrix).astype(int) + + # find the best slab supercell in terms of the conventional cell from the csl lattice, + # which is the transformation matrix + + # now trans_cry is the transformation matrix from crystal to Cartesian coordinates. + # for cubic, do not need to change. + if lat_type.lower() != "c": + if lat_type.lower() == "h": + trans_cry = np.array([[1, 0, 0], [-0.5, np.sqrt(3.0) / 2.0, 0], [0, 0, np.sqrt(mu / mv)]]) + elif lat_type.lower() == "r": + c2_a2_ratio = 1.0 if ratio is None else 3.0 / (2 - 6 * mv / mu) + trans_cry = np.array( + [ + [0.5, np.sqrt(3.0) / 6.0, 1.0 / 3 * np.sqrt(c2_a2_ratio)], + [-0.5, np.sqrt(3.0) / 6.0, 1.0 / 3 * np.sqrt(c2_a2_ratio)], + [0, -1 * np.sqrt(3.0) / 3.0, 1.0 / 3 * np.sqrt(c2_a2_ratio)], + ] + ) + else: + trans_cry = np.array([[1, 0, 0], [0, np.sqrt(lam / mv), 0], [0, 0, np.sqrt(mu / mv)]]) + t1_final = GrainBoundaryGenerator.slab_from_csl( + csl, surface, normal, trans_cry, max_search=max_search, quick_gen=quick_gen + ) + t2_final = np.array(np.rint(np.dot(t1_final, np.linalg.inv(r_matrix.T)))).astype(int) + return t1_final, t2_final + + @staticmethod + def enum_sigma_cubic(cutoff, r_axis): + """ + Find all possible sigma values and corresponding rotation angles + within a sigma value cutoff with known rotation axis in cubic system. + The algorithm for this code is from reference, Acta Cryst, A40,108(1984). + + Args: + cutoff (int): the cutoff of sigma values. + r_axis (list of 3 integers, e.g. u, v, w): + the rotation axis of the grain boundary, with the format of [u,v,w]. + + Returns: + dict: sigmas dictionary with keys as the possible integer sigma values + and values as list of the possible rotation angles to the + corresponding sigma values. e.g. the format as + {sigma1: [angle11,angle12,...], sigma2: [angle21, angle22,...],...} + Note: the angles are the rotation angles of one grain respect to + the other grain. + When generate the microstructures of the grain boundary using these angles, + you need to analyze the symmetry of the structure. Different angles may + result in equivalent microstructures. + """ + sigmas = {} + # make sure gcd(r_axis)==1 + if reduce(gcd, r_axis) != 1: + r_axis = [int(round(x / reduce(gcd, r_axis))) for x in r_axis] + + # count the number of odds in r_axis + odd_r = len(list(filter(lambda x: x % 2 == 1, r_axis))) + # Compute the max n we need to enumerate. + if odd_r == 3: + a_max = 4 + elif odd_r == 0: + a_max = 1 + else: + a_max = 2 + n_max = int(np.sqrt(cutoff * a_max / sum(np.array(r_axis) ** 2))) + # enumerate all possible n, m to give possible sigmas within the cutoff. + for n_loop in range(1, n_max + 1): + n = n_loop + m_max = int(np.sqrt(cutoff * a_max - n**2 * sum(np.array(r_axis) ** 2))) + for m in range(m_max + 1): + if gcd(m, n) == 1 or m == 0: + n = 1 if m == 0 else n_loop + # construct the quadruple [m, U,V,W], count the number of odds in + # quadruple to determine the parameter a, refer to the reference + quadruple = [m] + [x * n for x in r_axis] + odd_qua = len(list(filter(lambda x: x % 2 == 1, quadruple))) + if odd_qua == 4: + a = 4 + elif odd_qua == 2: + a = 2 + else: + a = 1 + sigma = int(round((m**2 + n**2 * sum(np.array(r_axis) ** 2)) / a)) + if 1 < sigma <= cutoff: + if sigma not in list(sigmas): + if m == 0: + angle = 180.0 + else: + angle = 2 * np.arctan(n * np.sqrt(sum(np.array(r_axis) ** 2)) / m) / np.pi * 180 + sigmas[sigma] = [angle] + else: + if m == 0: + angle = 180.0 + else: + angle = 2 * np.arctan(n * np.sqrt(sum(np.array(r_axis) ** 2)) / m) / np.pi * 180 + if angle not in sigmas[sigma]: + sigmas[sigma].append(angle) + return sigmas + + @staticmethod + def enum_sigma_hex(cutoff, r_axis, c2_a2_ratio): + """ + Find all possible sigma values and corresponding rotation angles + within a sigma value cutoff with known rotation axis in hexagonal system. + The algorithm for this code is from reference, Acta Cryst, A38,550(1982). + + Args: + cutoff (int): the cutoff of sigma values. + r_axis (list of 3 integers, e.g. u, v, w or 4 integers, e.g. u, v, t, w): the rotation axis of the grain + boundary. + c2_a2_ratio (list of two integers, e.g. mu, mv): mu/mv is the square of the hexagonal axial ratio, + which is rational number. If irrational, set c2_a2_ratio = None + + Returns: + sigmas (dict): dictionary with keys as the possible integer sigma values and values as list of the + possible rotation angles to the corresponding sigma values. e.g. the format as + {sigma1: [angle11,angle12,...], sigma2: [angle21, angle22,...],...} + Note: the angles are the rotation angle of one grain respect to the + other grain. + When generate the microstructure of the grain boundary using these + angles, you need to analyze the symmetry of the structure. Different + angles may result in equivalent microstructures. + """ + sigmas = {} + # make sure gcd(r_axis)==1 + if reduce(gcd, r_axis) != 1: + r_axis = [int(round(x / reduce(gcd, r_axis))) for x in r_axis] + # transform four index notation to three index notation + if len(r_axis) == 4: + u1 = r_axis[0] + v1 = r_axis[1] + w1 = r_axis[3] + u = 2 * u1 + v1 + v = 2 * v1 + u1 + w = w1 + else: + u, v, w = r_axis + + # make sure mu, mv are coprime integers. + if c2_a2_ratio is None: + mu, mv = [1, 1] + if w != 0 and (u != 0 or (v != 0)): + raise RuntimeError("For irrational c2/a2, CSL only exist for [0,0,1] or [u,v,0] and m = 0") + else: + mu, mv = c2_a2_ratio + if gcd(mu, mv) != 1: + temp = gcd(mu, mv) + mu = int(round(mu / temp)) + mv = int(round(mv / temp)) + + # refer to the meaning of d in reference + d = (u**2 + v**2 - u * v) * mv + w**2 * mu + + # Compute the max n we need to enumerate. + n_max = int(np.sqrt((cutoff * 12 * mu * mv) / abs(d))) + + # Enumerate all possible n, m to give possible sigmas within the cutoff. + for n in range(1, n_max + 1): + if (c2_a2_ratio is None) and w == 0: + m_max = 0 + else: + m_max = int(np.sqrt((cutoff * 12 * mu * mv - n**2 * d) / (3 * mu))) + for m in range(m_max + 1): + if gcd(m, n) == 1 or m == 0: + # construct the rotation matrix, refer to the reference + R_list = [ + (u**2 * mv - v**2 * mv - w**2 * mu) * n**2 + 2 * w * mu * m * n + 3 * mu * m**2, + (2 * v - u) * u * mv * n**2 - 4 * w * mu * m * n, + 2 * u * w * mu * n**2 + 2 * (2 * v - u) * mu * m * n, + (2 * u - v) * v * mv * n**2 + 4 * w * mu * m * n, + (v**2 * mv - u**2 * mv - w**2 * mu) * n**2 - 2 * w * mu * m * n + 3 * mu * m**2, + 2 * v * w * mu * n**2 - 2 * (2 * u - v) * mu * m * n, + (2 * u - v) * w * mv * n**2 - 3 * v * mv * m * n, + (2 * v - u) * w * mv * n**2 + 3 * u * mv * m * n, + (w**2 * mu - u**2 * mv - v**2 * mv + u * v * mv) * n**2 + 3 * mu * m**2, + ] + m = -1 * m + # inverse of the rotation matrix + R_list_inv = [ + (u**2 * mv - v**2 * mv - w**2 * mu) * n**2 + 2 * w * mu * m * n + 3 * mu * m**2, + (2 * v - u) * u * mv * n**2 - 4 * w * mu * m * n, + 2 * u * w * mu * n**2 + 2 * (2 * v - u) * mu * m * n, + (2 * u - v) * v * mv * n**2 + 4 * w * mu * m * n, + (v**2 * mv - u**2 * mv - w**2 * mu) * n**2 - 2 * w * mu * m * n + 3 * mu * m**2, + 2 * v * w * mu * n**2 - 2 * (2 * u - v) * mu * m * n, + (2 * u - v) * w * mv * n**2 - 3 * v * mv * m * n, + (2 * v - u) * w * mv * n**2 + 3 * u * mv * m * n, + (w**2 * mu - u**2 * mv - v**2 * mv + u * v * mv) * n**2 + 3 * mu * m**2, + ] + m = -1 * m + F = 3 * mu * m**2 + d * n**2 + all_list = R_list_inv + R_list + [F] + # Compute the max common factors for the elements of the rotation matrix + # and its inverse. + com_fac = reduce(gcd, all_list) + sigma = int(round((3 * mu * m**2 + d * n**2) / com_fac)) + if 1 < sigma <= cutoff: + if sigma not in list(sigmas): + angle = 180.0 if m == 0 else 2 * np.arctan(n / m * np.sqrt(d / 3.0 / mu)) / np.pi * 180 + sigmas[sigma] = [angle] + else: + angle = 180.0 if m == 0 else 2 * np.arctan(n / m * np.sqrt(d / 3.0 / mu)) / np.pi * 180 + if angle not in sigmas[sigma]: + sigmas[sigma].append(angle) + if m_max == 0: + break + return sigmas + + @staticmethod + def enum_sigma_rho(cutoff, r_axis, ratio_alpha): + """ + Find all possible sigma values and corresponding rotation angles + within a sigma value cutoff with known rotation axis in rhombohedral system. + The algorithm for this code is from reference, Acta Cryst, A45,505(1989). + + Args: + cutoff (int): the cutoff of sigma values. + r_axis (list[int]): of 3 integers, e.g. u, v, w + or 4 integers, e.g. u, v, t, w): + the rotation axis of the grain boundary, with the format of [u,v,w] + or Weber indices [u, v, t, w]. + ratio_alpha (list of two integers, e.g. mu, mv): + mu/mv is the ratio of (1+2*cos(alpha))/cos(alpha) with rational number. + If irrational, set ratio_alpha = None. + + Returns: + sigmas (dict): + dictionary with keys as the possible integer sigma values + and values as list of the possible rotation angles to the + corresponding sigma values. + e.g. the format as + {sigma1: [angle11,angle12,...], sigma2: [angle21, angle22,...],...} + Note: the angles are the rotation angle of one grain respect to the + other grain. + When generate the microstructure of the grain boundary using these + angles, you need to analyze the symmetry of the structure. Different + angles may result in equivalent microstructures. + """ + sigmas = {} + # transform four index notation to three index notation + if len(r_axis) == 4: + u1 = r_axis[0] + v1 = r_axis[1] + w1 = r_axis[3] + u = 2 * u1 + v1 + w1 + v = v1 + w1 - u1 + w = w1 - 2 * v1 - u1 + r_axis = [u, v, w] + # make sure gcd(r_axis)==1 + if reduce(gcd, r_axis) != 1: + r_axis = [int(round(x / reduce(gcd, r_axis))) for x in r_axis] + u, v, w = r_axis + # make sure mu, mv are coprime integers. + if ratio_alpha is None: + mu, mv = [1, 1] + if u + v + w != 0 and (u != v or u != w): + raise RuntimeError("For irrational ratio_alpha, CSL only exist for [1,1,1] or [u, v, -(u+v)] and m =0") + else: + mu, mv = ratio_alpha + if gcd(mu, mv) != 1: + temp = gcd(mu, mv) + mu = int(round(mu / temp)) + mv = int(round(mv / temp)) + + # refer to the meaning of d in reference + d = (u**2 + v**2 + w**2) * (mu - 2 * mv) + 2 * mv * (v * w + w * u + u * v) + # Compute the max n we need to enumerate. + n_max = int(np.sqrt((cutoff * abs(4 * mu * (mu - 3 * mv))) / abs(d))) + + # Enumerate all possible n, m to give possible sigmas within the cutoff. + for n in range(1, n_max + 1): + if ratio_alpha is None and u + v + w == 0: + m_max = 0 + else: + m_max = int(np.sqrt((cutoff * abs(4 * mu * (mu - 3 * mv)) - n**2 * d) / (mu))) + for m in range(m_max + 1): + if gcd(m, n) == 1 or m == 0: + # construct the rotation matrix, refer to the reference + R_list = [ + (mu - 2 * mv) * (u**2 - v**2 - w**2) * n**2 + + 2 * mv * (v - w) * m * n + - 2 * mv * v * w * n**2 + + mu * m**2, + 2 * (mv * u * n * (w * n + u * n - m) - (mu - mv) * m * w * n + (mu - 2 * mv) * u * v * n**2), + 2 * (mv * u * n * (v * n + u * n + m) + (mu - mv) * m * v * n + (mu - 2 * mv) * w * u * n**2), + 2 * (mv * v * n * (w * n + v * n + m) + (mu - mv) * m * w * n + (mu - 2 * mv) * u * v * n**2), + (mu - 2 * mv) * (v**2 - w**2 - u**2) * n**2 + + 2 * mv * (w - u) * m * n + - 2 * mv * u * w * n**2 + + mu * m**2, + 2 * (mv * v * n * (v * n + u * n - m) - (mu - mv) * m * u * n + (mu - 2 * mv) * w * v * n**2), + 2 * (mv * w * n * (w * n + v * n - m) - (mu - mv) * m * v * n + (mu - 2 * mv) * w * u * n**2), + 2 * (mv * w * n * (w * n + u * n + m) + (mu - mv) * m * u * n + (mu - 2 * mv) * w * v * n**2), + (mu - 2 * mv) * (w**2 - u**2 - v**2) * n**2 + + 2 * mv * (u - v) * m * n + - 2 * mv * u * v * n**2 + + mu * m**2, + ] + m = -1 * m + # inverse of the rotation matrix + R_list_inv = [ + (mu - 2 * mv) * (u**2 - v**2 - w**2) * n**2 + + 2 * mv * (v - w) * m * n + - 2 * mv * v * w * n**2 + + mu * m**2, + 2 * (mv * u * n * (w * n + u * n - m) - (mu - mv) * m * w * n + (mu - 2 * mv) * u * v * n**2), + 2 * (mv * u * n * (v * n + u * n + m) + (mu - mv) * m * v * n + (mu - 2 * mv) * w * u * n**2), + 2 * (mv * v * n * (w * n + v * n + m) + (mu - mv) * m * w * n + (mu - 2 * mv) * u * v * n**2), + (mu - 2 * mv) * (v**2 - w**2 - u**2) * n**2 + + 2 * mv * (w - u) * m * n + - 2 * mv * u * w * n**2 + + mu * m**2, + 2 * (mv * v * n * (v * n + u * n - m) - (mu - mv) * m * u * n + (mu - 2 * mv) * w * v * n**2), + 2 * (mv * w * n * (w * n + v * n - m) - (mu - mv) * m * v * n + (mu - 2 * mv) * w * u * n**2), + 2 * (mv * w * n * (w * n + u * n + m) + (mu - mv) * m * u * n + (mu - 2 * mv) * w * v * n**2), + (mu - 2 * mv) * (w**2 - u**2 - v**2) * n**2 + + 2 * mv * (u - v) * m * n + - 2 * mv * u * v * n**2 + + mu * m**2, + ] + m = -1 * m + F = mu * m**2 + d * n**2 + all_list = R_list_inv + R_list + [F] + # Compute the max common factors for the elements of the rotation matrix and its inverse. + com_fac = reduce(gcd, all_list) + sigma = int(round(abs(F / com_fac))) + if 1 < sigma <= cutoff: + if sigma not in list(sigmas): + angle = 180.0 if m == 0 else 2 * np.arctan(n / m * np.sqrt(d / mu)) / np.pi * 180 + sigmas[sigma] = [angle] + else: + angle = 180 if m == 0 else 2 * np.arctan(n / m * np.sqrt(d / mu)) / np.pi * 180.0 + if angle not in sigmas[sigma]: + sigmas[sigma].append(angle) + if m_max == 0: + break + return sigmas + + @staticmethod + def enum_sigma_tet(cutoff, r_axis, c2_a2_ratio): + """ + Find all possible sigma values and corresponding rotation angles + within a sigma value cutoff with known rotation axis in tetragonal system. + The algorithm for this code is from reference, Acta Cryst, B46,117(1990). + + Args: + cutoff (int): the cutoff of sigma values. + r_axis (list of 3 integers, e.g. u, v, w): + the rotation axis of the grain boundary, with the format of [u,v,w]. + c2_a2_ratio (list of two integers, e.g. mu, mv): + mu/mv is the square of the tetragonal axial ratio with rational number. + if irrational, set c2_a2_ratio = None + + Returns: + dict: sigmas dictionary with keys as the possible integer sigma values + and values as list of the possible rotation angles to the + corresponding sigma values. e.g. the format as + {sigma1: [angle11,angle12,...], sigma2: [angle21, angle22,...],...} + Note: the angles are the rotation angle of one grain respect to the + other grain. + When generate the microstructure of the grain boundary using these + angles, you need to analyze the symmetry of the structure. Different + angles may result in equivalent microstructures. + """ + sigmas = {} + # make sure gcd(r_axis)==1 + if reduce(gcd, r_axis) != 1: + r_axis = [int(round(x / reduce(gcd, r_axis))) for x in r_axis] + + u, v, w = r_axis + + # make sure mu, mv are coprime integers. + if c2_a2_ratio is None: + mu, mv = [1, 1] + if w != 0 and (u != 0 or (v != 0)): + raise RuntimeError("For irrational c2/a2, CSL only exist for [0,0,1] or [u,v,0] and m = 0") + else: + mu, mv = c2_a2_ratio + if gcd(mu, mv) != 1: + temp = gcd(mu, mv) + mu = int(round(mu / temp)) + mv = int(round(mv / temp)) + + # refer to the meaning of d in reference + d = (u**2 + v**2) * mv + w**2 * mu + + # Compute the max n we need to enumerate. + n_max = int(np.sqrt((cutoff * 4 * mu * mv) / d)) + + # Enumerate all possible n, m to give possible sigmas within the cutoff. + for n in range(1, n_max + 1): + m_max = 0 if c2_a2_ratio is None and w == 0 else int(np.sqrt((cutoff * 4 * mu * mv - n**2 * d) / mu)) + for m in range(m_max + 1): + if gcd(m, n) == 1 or m == 0: + # construct the rotation matrix, refer to the reference + R_list = [ + (u**2 * mv - v**2 * mv - w**2 * mu) * n**2 + mu * m**2, + 2 * v * u * mv * n**2 - 2 * w * mu * m * n, + 2 * u * w * mu * n**2 + 2 * v * mu * m * n, + 2 * u * v * mv * n**2 + 2 * w * mu * m * n, + (v**2 * mv - u**2 * mv - w**2 * mu) * n**2 + mu * m**2, + 2 * v * w * mu * n**2 - 2 * u * mu * m * n, + 2 * u * w * mv * n**2 - 2 * v * mv * m * n, + 2 * v * w * mv * n**2 + 2 * u * mv * m * n, + (w**2 * mu - u**2 * mv - v**2 * mv) * n**2 + mu * m**2, + ] + m = -1 * m + # inverse of rotation matrix + R_list_inv = [ + (u**2 * mv - v**2 * mv - w**2 * mu) * n**2 + mu * m**2, + 2 * v * u * mv * n**2 - 2 * w * mu * m * n, + 2 * u * w * mu * n**2 + 2 * v * mu * m * n, + 2 * u * v * mv * n**2 + 2 * w * mu * m * n, + (v**2 * mv - u**2 * mv - w**2 * mu) * n**2 + mu * m**2, + 2 * v * w * mu * n**2 - 2 * u * mu * m * n, + 2 * u * w * mv * n**2 - 2 * v * mv * m * n, + 2 * v * w * mv * n**2 + 2 * u * mv * m * n, + (w**2 * mu - u**2 * mv - v**2 * mv) * n**2 + mu * m**2, + ] + m = -1 * m + F = mu * m**2 + d * n**2 + all_list = R_list + R_list_inv + [F] + # Compute the max common factors for the elements of the rotation matrix + # and its inverse. + com_fac = reduce(gcd, all_list) + sigma = int(round((mu * m**2 + d * n**2) / com_fac)) + if 1 < sigma <= cutoff: + if sigma not in list(sigmas): + angle = 180.0 if m == 0 else 2 * np.arctan(n / m * np.sqrt(d / mu)) / np.pi * 180 + sigmas[sigma] = [angle] + else: + angle = 180.0 if m == 0 else 2 * np.arctan(n / m * np.sqrt(d / mu)) / np.pi * 180 + if angle not in sigmas[sigma]: + sigmas[sigma].append(angle) + if m_max == 0: + break + + return sigmas + + @staticmethod + def enum_sigma_ort(cutoff, r_axis, c2_b2_a2_ratio): + """ + Find all possible sigma values and corresponding rotation angles + within a sigma value cutoff with known rotation axis in orthorhombic system. + The algorithm for this code is from reference, Scipta Metallurgica 27, 291(1992). + + Args: + cutoff (int): the cutoff of sigma values. + r_axis (list of 3 integers, e.g. u, v, w): + the rotation axis of the grain boundary, with the format of [u,v,w]. + c2_b2_a2_ratio (list of 3 integers, e.g. mu,lambda, mv): + mu:lam:mv is the square of the orthorhombic axial ratio with rational + numbers. If irrational for one axis, set it to None. + e.g. mu:lam:mv = c2,None,a2, means b2 is irrational. + + Returns: + dict: sigmas dictionary with keys as the possible integer sigma values + and values as list of the possible rotation angles to the + corresponding sigma values. e.g. the format as + {sigma1: [angle11,angle12,...], sigma2: [angle21, angle22,...],...} + Note: the angles are the rotation angle of one grain respect to the + other grain. + When generate the microstructure of the grain boundary using these + angles, you need to analyze the symmetry of the structure. Different + angles may result in equivalent microstructures. + """ + sigmas = {} + # make sure gcd(r_axis)==1 + if reduce(gcd, r_axis) != 1: + r_axis = [int(round(x / reduce(gcd, r_axis))) for x in r_axis] + + u, v, w = r_axis + # make sure mu, lambda, mv are coprime integers. + if None in c2_b2_a2_ratio: + mu, lam, mv = c2_b2_a2_ratio + non_none = [i for i in c2_b2_a2_ratio if i is not None] + if len(non_none) < 2: + raise RuntimeError("No CSL exist for two irrational numbers") + non1, non2 = non_none + if reduce(gcd, non_none) != 1: + temp = reduce(gcd, non_none) + non1 = int(round(non1 / temp)) + non2 = int(round(non2 / temp)) + if mu is None: + lam = non1 + mv = non2 + mu = 1 + if w != 0 and (u != 0 or (v != 0)): + raise RuntimeError("For irrational c2, CSL only exist for [0,0,1] or [u,v,0] and m = 0") + elif lam is None: + mu = non1 + mv = non2 + lam = 1 + if v != 0 and (u != 0 or (w != 0)): + raise RuntimeError("For irrational b2, CSL only exist for [0,1,0] or [u,0,w] and m = 0") + elif mv is None: + mu = non1 + lam = non2 + mv = 1 + if u != 0 and (w != 0 or (v != 0)): + raise RuntimeError("For irrational a2, CSL only exist for [1,0,0] or [0,v,w] and m = 0") + else: + mu, lam, mv = c2_b2_a2_ratio + if reduce(gcd, c2_b2_a2_ratio) != 1: + temp = reduce(gcd, c2_b2_a2_ratio) + mu = int(round(mu / temp)) + mv = int(round(mv / temp)) + lam = int(round(lam / temp)) + if u == 0 and v == 0: + mu = 1 + if u == 0 and w == 0: + lam = 1 + if v == 0 and w == 0: + mv = 1 + # refer to the meaning of d in reference + d = (mv * u**2 + lam * v**2) * mv + w**2 * mu * mv + + # Compute the max n we need to enumerate. + n_max = int(np.sqrt((cutoff * 4 * mu * mv * mv * lam) / d)) + # Enumerate all possible n, m to give possible sigmas within the cutoff. + for n in range(1, n_max + 1): + mu_temp, lam_temp, mv_temp = c2_b2_a2_ratio + if (mu_temp is None and w == 0) or (lam_temp is None and v == 0) or (mv_temp is None and u == 0): + m_max = 0 + else: + m_max = int(np.sqrt((cutoff * 4 * mu * mv * lam * mv - n**2 * d) / mu / lam)) + for m in range(m_max + 1): + if gcd(m, n) == 1 or m == 0: + # construct the rotation matrix, refer to the reference + R_list = [ + (u**2 * mv * mv - lam * v**2 * mv - w**2 * mu * mv) * n**2 + lam * mu * m**2, + 2 * lam * (v * u * mv * n**2 - w * mu * m * n), + 2 * mu * (u * w * mv * n**2 + v * lam * m * n), + 2 * mv * (u * v * mv * n**2 + w * mu * m * n), + (v**2 * mv * lam - u**2 * mv * mv - w**2 * mu * mv) * n**2 + lam * mu * m**2, + 2 * mv * mu * (v * w * n**2 - u * m * n), + 2 * mv * (u * w * mv * n**2 - v * lam * m * n), + 2 * lam * mv * (v * w * n**2 + u * m * n), + (w**2 * mu * mv - u**2 * mv * mv - v**2 * mv * lam) * n**2 + lam * mu * m**2, + ] + m = -1 * m + # inverse of rotation matrix + R_list_inv = [ + (u**2 * mv * mv - lam * v**2 * mv - w**2 * mu * mv) * n**2 + lam * mu * m**2, + 2 * lam * (v * u * mv * n**2 - w * mu * m * n), + 2 * mu * (u * w * mv * n**2 + v * lam * m * n), + 2 * mv * (u * v * mv * n**2 + w * mu * m * n), + (v**2 * mv * lam - u**2 * mv * mv - w**2 * mu * mv) * n**2 + lam * mu * m**2, + 2 * mv * mu * (v * w * n**2 - u * m * n), + 2 * mv * (u * w * mv * n**2 - v * lam * m * n), + 2 * lam * mv * (v * w * n**2 + u * m * n), + (w**2 * mu * mv - u**2 * mv * mv - v**2 * mv * lam) * n**2 + lam * mu * m**2, + ] + m = -1 * m + F = mu * lam * m**2 + d * n**2 + all_list = R_list + R_list_inv + [F] + # Compute the max common factors for the elements of the rotation matrix + # and its inverse. + com_fac = reduce(gcd, all_list) + sigma = int(round((mu * lam * m**2 + d * n**2) / com_fac)) + if 1 < sigma <= cutoff: + if sigma not in list(sigmas): + angle = 180.0 if m == 0 else 2 * np.arctan(n / m * np.sqrt(d / mu / lam)) / np.pi * 180 + sigmas[sigma] = [angle] + else: + angle = 180.0 if m == 0 else 2 * np.arctan(n / m * np.sqrt(d / mu / lam)) / np.pi * 180 + if angle not in sigmas[sigma]: + sigmas[sigma].append(angle) + if m_max == 0: + break + + return sigmas + + @staticmethod + def enum_possible_plane_cubic(plane_cutoff, r_axis, r_angle): + """ + Find all possible plane combinations for GBs given a rotation axis and angle for + cubic system, and classify them to different categories, including 'Twist', + 'Symmetric tilt', 'Normal tilt', 'Mixed' GBs. + + Args: + plane_cutoff (int): the cutoff of plane miller index. + r_axis (list of 3 integers, e.g. u, v, w): + the rotation axis of the grain boundary, with the format of [u,v,w]. + r_angle (float): rotation angle of the GBs. + + Returns: + dict: all combinations with keys as GB type, e.g. 'Twist','Symmetric tilt',etc. + and values as the combination of the two plane miller index (GB plane and joining plane). + """ + all_combinations = {} + all_combinations["Symmetric tilt"] = [] + all_combinations["Twist"] = [] + all_combinations["Normal tilt"] = [] + all_combinations["Mixed"] = [] + sym_plane = symm_group_cubic([[1, 0, 0], [1, 1, 0]]) + j = np.arange(0, plane_cutoff + 1) + combination = [] + for idx in product(j, repeat=3): + if sum(abs(np.array(idx))) != 0: + combination.append(list(idx)) + if len(np.nonzero(idx)[0]) == 3: + for i1 in range(3): + new_i = list(idx).copy() + new_i[i1] = -1 * new_i[i1] + combination.append(new_i) + elif len(np.nonzero(idx)[0]) == 2: + new_i = list(idx).copy() + new_i[np.nonzero(idx)[0][0]] = -1 * new_i[np.nonzero(idx)[0][0]] + combination.append(new_i) + miller = np.array(combination) + miller = miller[np.argsort(np.linalg.norm(miller, axis=1))] + for val in miller: + if reduce(gcd, val) == 1: + matrix = GrainBoundaryGenerator.get_trans_mat(r_axis, r_angle, surface=val, quick_gen=True) + vec = np.cross(matrix[1][0], matrix[1][1]) + miller2 = GrainBoundaryGenerator.vec_to_surface(vec) + if np.all(np.abs(np.array(miller2)) <= plane_cutoff): + cos_1 = abs(np.dot(val, r_axis) / np.linalg.norm(val) / np.linalg.norm(r_axis)) + if 1 - cos_1 < 1.0e-5: + all_combinations["Twist"].append([list(val), miller2]) + elif cos_1 < 1.0e-8: + sym_tilt = False + if np.sum(np.abs(val)) == np.sum(np.abs(miller2)): + ave = (np.array(val) + np.array(miller2)) / 2 + ave1 = (np.array(val) - np.array(miller2)) / 2 + for plane in sym_plane: + cos_2 = abs(np.dot(ave, plane) / np.linalg.norm(ave) / np.linalg.norm(plane)) + cos_3 = abs(np.dot(ave1, plane) / np.linalg.norm(ave1) / np.linalg.norm(plane)) + if 1 - cos_2 < 1.0e-5 or 1 - cos_3 < 1.0e-5: + all_combinations["Symmetric tilt"].append([list(val), miller2]) + sym_tilt = True + break + if not sym_tilt: + all_combinations["Normal tilt"].append([list(val), miller2]) + else: + all_combinations["Mixed"].append([list(val), miller2]) + return all_combinations + + @staticmethod + def get_rotation_angle_from_sigma(sigma, r_axis, lat_type="C", ratio=None): + """ + Find all possible rotation angle for the given sigma value. + + Args: + sigma (int): sigma value provided + r_axis (list of 3 integers, e.g. u, v, w or 4 integers, e.g. u, v, t, w for hex/rho system only): the + rotation axis of the grain boundary. + lat_type (str): one character to specify the lattice type. Defaults to 'c' for cubic. + 'c' or 'C': cubic system + 't' or 'T': tetragonal system + 'o' or 'O': orthorhombic system + 'h' or 'H': hexagonal system + 'r' or 'R': rhombohedral system + ratio (list of integers): lattice axial ratio. + For cubic system, ratio is not needed. + For tetragonal system, ratio = [mu, mv], list of two integers, + that is, mu/mv = c2/a2. If it is irrational, set it to none. + For orthorhombic system, ratio = [mu, lam, mv], list of 3 integers, + that is, mu:lam:mv = c2:b2:a2. If irrational for one axis, set it to None. + e.g. mu:lam:mv = c2,None,a2, means b2 is irrational. + For rhombohedral system, ratio = [mu, mv], list of two integers, + that is, mu/mv is the ratio of (1+2*cos(alpha)/cos(alpha). + If irrational, set it to None. + For hexagonal system, ratio = [mu, mv], list of two integers, + that is, mu/mv = c2/a2. If it is irrational, set it to none. + + Returns: + rotation_angles corresponding to the provided sigma value. + If the sigma value is not correct, return the rotation angle corresponding + to the correct possible sigma value right smaller than the wrong sigma value provided. + """ + if lat_type.lower() == "c": + logger.info("Make sure this is for cubic system") + sigma_dict = GrainBoundaryGenerator.enum_sigma_cubic(cutoff=sigma, r_axis=r_axis) + elif lat_type.lower() == "t": + logger.info("Make sure this is for tetragonal system") + if ratio is None: + logger.info("Make sure this is for irrational c2/a2 ratio") + elif len(ratio) != 2: + raise RuntimeError("Tetragonal system needs correct c2/a2 ratio") + sigma_dict = GrainBoundaryGenerator.enum_sigma_tet(cutoff=sigma, r_axis=r_axis, c2_a2_ratio=ratio) + elif lat_type.lower() == "o": + logger.info("Make sure this is for orthorhombic system") + if len(ratio) != 3: + raise RuntimeError("Orthorhombic system needs correct c2:b2:a2 ratio") + sigma_dict = GrainBoundaryGenerator.enum_sigma_ort(cutoff=sigma, r_axis=r_axis, c2_b2_a2_ratio=ratio) + elif lat_type.lower() == "h": + logger.info("Make sure this is for hexagonal system") + if ratio is None: + logger.info("Make sure this is for irrational c2/a2 ratio") + elif len(ratio) != 2: + raise RuntimeError("Hexagonal system needs correct c2/a2 ratio") + sigma_dict = GrainBoundaryGenerator.enum_sigma_hex(cutoff=sigma, r_axis=r_axis, c2_a2_ratio=ratio) + elif lat_type.lower() == "r": + logger.info("Make sure this is for rhombohedral system") + if ratio is None: + logger.info("Make sure this is for irrational (1+2*cos(alpha)/cos(alpha) ratio") + elif len(ratio) != 2: + raise RuntimeError("Rhombohedral system needs correct (1+2*cos(alpha)/cos(alpha) ratio") + sigma_dict = GrainBoundaryGenerator.enum_sigma_rho(cutoff=sigma, r_axis=r_axis, ratio_alpha=ratio) + else: + raise RuntimeError("Lattice type not implemented") + + sigmas = list(sigma_dict) + if not sigmas: + raise RuntimeError("This is a wrong sigma value, and no sigma exists smaller than this value.") + if sigma in sigmas: + rotation_angles = sigma_dict[sigma] + else: + sigmas.sort() + warnings.warn( + "This is not the possible sigma value according to the rotation axis!" + "The nearest neighbor sigma and its corresponding angle are returned" + ) + rotation_angles = sigma_dict[sigmas[-1]] + rotation_angles.sort() + return rotation_angles + + @staticmethod + def slab_from_csl(csl, surface, normal, trans_cry, max_search=20, quick_gen=False): + """ + By linear operation of csl lattice vectors to get the best corresponding + slab lattice. That is the area of a,b vectors (within the surface plane) + is the smallest, the c vector first, has shortest length perpendicular + to surface [h,k,l], second, has shortest length itself. + + Args: + csl (3 by 3 integer array): + input csl lattice. + surface (list of 3 integers, e.g. h, k, l): + the miller index of the surface, with the format of [h,k,l] + normal (logic): + determine if the c vector needs to perpendicular to surface + trans_cry (3 by 3 array): + transform matrix from crystal system to orthogonal system + max_search (int): max search for the GB lattice vectors that give the smallest GB + lattice. If normal is true, also max search the GB c vector that perpendicular + to the plane. + quick_gen (bool): whether to quickly generate a supercell, no need to find the smallest + cell if set to true. + + Returns: + t_matrix: a slab lattice ( 3 by 3 integer array): + """ + # set the transform matrix in real space + trans = trans_cry + # transform matrix in reciprocal space + ctrans = np.linalg.inv(trans.T) + + t_matrix = csl.copy() + # vectors constructed from csl that perpendicular to surface + ab_vector = [] + # obtain the miller index of surface in terms of csl. + miller = np.matmul(surface, csl.T) + if reduce(gcd, miller) != 1: + miller = [int(round(x / reduce(gcd, miller))) for x in miller] + miller_nonzero = [] + # quickly generate a supercell, normal is not work in this way + if quick_gen: + scale_factor = [] + eye = np.eye(3, dtype=int) + for ii, jj in enumerate(miller): + if jj == 0: + scale_factor.append(eye[ii]) + else: + miller_nonzero.append(ii) + if len(scale_factor) < 2: + index_len = len(miller_nonzero) + for ii in range(index_len): + for jj in range(ii + 1, index_len): + lcm_miller = lcm(miller[miller_nonzero[ii]], miller[miller_nonzero[jj]]) + scl_factor = [0, 0, 0] + scl_factor[miller_nonzero[ii]] = -int(round(lcm_miller / miller[miller_nonzero[ii]])) + scl_factor[miller_nonzero[jj]] = int(round(lcm_miller / miller[miller_nonzero[jj]])) + scale_factor.append(scl_factor) + if len(scale_factor) == 2: + break + t_matrix[0] = np.array(np.dot(scale_factor[0], csl)) + t_matrix[1] = np.array(np.dot(scale_factor[1], csl)) + t_matrix[2] = csl[miller_nonzero[0]] + if abs(np.linalg.det(t_matrix)) > 1000: + warnings.warn("Too large matrix. Suggest to use quick_gen=False") + return t_matrix + + for ii, jj in enumerate(miller): + if jj == 0: + ab_vector.append(csl[ii]) + else: + c_index = ii + miller_nonzero.append(jj) + + if len(miller_nonzero) > 1: + t_matrix[2] = csl[c_index] + index_len = len(miller_nonzero) + lcm_miller = [] + for ii in range(index_len): + for jj in range(ii + 1, index_len): + com_gcd = gcd(miller_nonzero[ii], miller_nonzero[jj]) + mil1 = int(round(miller_nonzero[ii] / com_gcd)) + mil2 = int(round(miller_nonzero[jj] / com_gcd)) + lcm_miller.append(max(abs(mil1), abs(mil2))) + lcm_sorted = sorted(lcm_miller) + max_j = lcm_sorted[0] if index_len == 2 else lcm_sorted[1] + else: + if not normal: + t_matrix[0] = ab_vector[0] + t_matrix[1] = ab_vector[1] + t_matrix[2] = csl[c_index] + return t_matrix + max_j = abs(miller_nonzero[0]) + max_j = min(max_j, max_search) + # area of a, b vectors + area = None + # length of c vector + c_norm = np.linalg.norm(np.matmul(t_matrix[2], trans)) + # c vector length along the direction perpendicular to surface + c_length = np.abs(np.dot(t_matrix[2], surface)) + # check if the init c vector perpendicular to the surface + if normal: + c_cross = np.cross(np.matmul(t_matrix[2], trans), np.matmul(surface, ctrans)) + normal_init = np.linalg.norm(c_cross) < 1e-8 + + jj = np.arange(0, max_j + 1) + combination = [] + for ii in product(jj, repeat=3): + if sum(abs(np.array(ii))) != 0: + combination.append(list(ii)) + if len(np.nonzero(ii)[0]) == 3: + for i1 in range(3): + new_i = list(ii).copy() + new_i[i1] = -1 * new_i[i1] + combination.append(new_i) + elif len(np.nonzero(ii)[0]) == 2: + new_i = list(ii).copy() + new_i[np.nonzero(ii)[0][0]] = -1 * new_i[np.nonzero(ii)[0][0]] + combination.append(new_i) + for ii in combination: + if reduce(gcd, ii) == 1: + temp = np.dot(np.array(ii), csl) + if abs(np.dot(temp, surface) - 0) < 1.0e-8: + ab_vector.append(temp) + else: + # c vector length along the direction perpendicular to surface + c_len_temp = np.abs(np.dot(temp, surface)) + # c vector length itself + c_norm_temp = np.linalg.norm(np.matmul(temp, trans)) + if normal: + c_cross = np.cross(np.matmul(temp, trans), np.matmul(surface, ctrans)) + if np.linalg.norm(c_cross) < 1.0e-8: + if normal_init: + if c_norm_temp < c_norm: + t_matrix[2] = temp + c_norm = c_norm_temp + else: + c_norm = c_norm_temp + normal_init = True + t_matrix[2] = temp + elif c_len_temp < c_length or (abs(c_len_temp - c_length) < 1.0e-8 and c_norm_temp < c_norm): + t_matrix[2] = temp + c_norm = c_norm_temp + c_length = c_len_temp + + if normal and (not normal_init): + logger.info("Did not find the perpendicular c vector, increase max_j") + while not normal_init: + if max_j == max_search: + warnings.warn("Cannot find the perpendicular c vector, please increase max_search") + break + max_j = 3 * max_j + max_j = min(max_j, max_search) + jj = np.arange(0, max_j + 1) + combination = [] + for ii in product(jj, repeat=3): + if sum(abs(np.array(ii))) != 0: + combination.append(list(ii)) + if len(np.nonzero(ii)[0]) == 3: + for i1 in range(3): + new_i = list(ii).copy() + new_i[i1] = -1 * new_i[i1] + combination.append(new_i) + elif len(np.nonzero(ii)[0]) == 2: + new_i = list(ii).copy() + new_i[np.nonzero(ii)[0][0]] = -1 * new_i[np.nonzero(ii)[0][0]] + combination.append(new_i) + for ii in combination: + if reduce(gcd, ii) == 1: + temp = np.dot(np.array(ii), csl) + if abs(np.dot(temp, surface) - 0) > 1.0e-8: + c_cross = np.cross(np.matmul(temp, trans), np.matmul(surface, ctrans)) + if np.linalg.norm(c_cross) < 1.0e-8: + # c vector length itself + c_norm_temp = np.linalg.norm(np.matmul(temp, trans)) + if normal_init: + if c_norm_temp < c_norm: + t_matrix[2] = temp + c_norm = c_norm_temp + else: + c_norm = c_norm_temp + normal_init = True + t_matrix[2] = temp + if normal_init: + logger.info("Found perpendicular c vector") + + # find the best a, b vectors with their formed area smallest and average norm of a,b smallest. + for ii in combinations(ab_vector, 2): + area_temp = np.linalg.norm(np.cross(np.matmul(ii[0], trans), np.matmul(ii[1], trans))) + if abs(area_temp - 0) > 1.0e-8: + ab_norm_temp = np.linalg.norm(np.matmul(ii[0], trans)) + np.linalg.norm(np.matmul(ii[1], trans)) + if area is None: + area = area_temp + ab_norm = ab_norm_temp + t_matrix[0] = ii[0] + t_matrix[1] = ii[1] + elif area_temp < area or (abs(area - area_temp) < 1.0e-8 and ab_norm_temp < ab_norm): + t_matrix[0] = ii[0] + t_matrix[1] = ii[1] + area = area_temp + ab_norm = ab_norm_temp + + # make sure we have a left-handed crystallographic system + if np.linalg.det(np.matmul(t_matrix, trans)) < 0: + t_matrix *= -1 + + if normal and abs(np.linalg.det(t_matrix)) > 1000: + warnings.warn("Too large matrix. Suggest to use Normal=False") + return t_matrix + + @staticmethod + def reduce_mat(mat, mag, r_matrix): + """ + Reduce integer array mat's determinant mag times by linear combination + of its row vectors, so that the new array after rotation (r_matrix) is + still an integer array. + + Args: + mat (3 by 3 array): input matrix + mag (int): reduce times for the determinant + r_matrix (3 by 3 array): rotation matrix + + Returns: + the reduced integer array + """ + max_j = abs(int(round(np.linalg.det(mat) / mag))) + reduced = False + for h in range(3): + kk = h + 1 if h + 1 < 3 else abs(2 - h) + ll = h + 2 if h + 2 < 3 else abs(1 - h) + jj = np.arange(-max_j, max_j + 1) + for j1, j2 in product(jj, repeat=2): + temp = mat[h] + j1 * mat[kk] + j2 * mat[ll] + if all(np.round(x, 5).is_integer() for x in list(temp / mag)): + mat_copy = mat.copy() + mat_copy[h] = np.array([int(round(ele / mag)) for ele in temp]) + new_mat = np.dot(mat_copy, np.linalg.inv(r_matrix.T)) + if all(np.round(x, 5).is_integer() for x in list(np.ravel(new_mat))): + reduced = True + mat[h] = np.array([int(round(ele / mag)) for ele in temp]) + break + if reduced: + break + + if not reduced: + warnings.warn("Matrix reduction not performed, may lead to non-primitive GB cell.") + return mat + + @staticmethod + def vec_to_surface(vec): + """ + Transform a float vector to a surface miller index with integers. + + Args: + vec (1 by 3 array float vector): input float vector + + Returns: + the surface miller index of the input vector. + """ + miller = [None] * 3 + index = [] + for idx, value in enumerate(vec): + if abs(value) < 1.0e-8: + miller[idx] = 0 + else: + index.append(idx) + if len(index) == 1: + miller[index[0]] = 1 + else: + min_index = np.argmin([i for i in vec if i != 0]) + true_index = index[min_index] + index.pop(min_index) + frac = [] + for value in index: + frac.append(Fraction(vec[value] / vec[true_index]).limit_denominator(100)) + if len(index) == 1: + miller[true_index] = frac[0].denominator + miller[index[0]] = frac[0].numerator + else: + com_lcm = lcm(frac[0].denominator, frac[1].denominator) + miller[true_index] = com_lcm + miller[index[0]] = frac[0].numerator * int(round(com_lcm / frac[0].denominator)) + miller[index[1]] = frac[1].numerator * int(round(com_lcm / frac[1].denominator)) + return miller + + +def fix_pbc(structure, matrix=None): + """ + Set all frac_coords of the input structure within [0,1]. + + Args: + structure (pymatgen structure object): input structure + matrix (lattice matrix, 3 by 3 array/matrix): new structure's lattice matrix, + If None, use input structure's matrix. + + + Returns: + new structure with fixed frac_coords and lattice matrix + """ + spec = [] + coords = [] + latte = Lattice(structure.lattice.matrix) if matrix is None else Lattice(matrix) + + for site in structure: + spec.append(site.specie) + coord = np.array(site.frac_coords) + for i in range(3): + coord[i] -= floor(coord[i]) + if np.allclose(coord[i], 1) or np.allclose(coord[i], 0): + coord[i] = 0 + else: + coord[i] = round(coord[i], 7) + coords.append(coord) + + return Structure(latte, spec, coords, site_properties=structure.site_properties) + + +def symm_group_cubic(mat): + """ + Obtain cubic symmetric equivalents of the list of vectors. + + Args: + mat (n by 3 array/matrix): lattice matrix + + + Returns: + cubic symmetric equivalents of the list of vectors. + """ + sym_group = np.zeros([24, 3, 3]) + sym_group[0, :] = np.eye(3) + sym_group[1, :] = [[1, 0, 0], [0, -1, 0], [0, 0, -1]] + sym_group[2, :] = [[-1, 0, 0], [0, 1, 0], [0, 0, -1]] + sym_group[3, :] = [[-1, 0, 0], [0, -1, 0], [0, 0, 1]] + sym_group[4, :] = [[0, -1, 0], [-1, 0, 0], [0, 0, -1]] + sym_group[5, :] = [[0, -1, 0], [1, 0, 0], [0, 0, 1]] + sym_group[6, :] = [[0, 1, 0], [-1, 0, 0], [0, 0, 1]] + sym_group[7, :] = [[0, 1, 0], [1, 0, 0], [0, 0, -1]] + sym_group[8, :] = [[-1, 0, 0], [0, 0, -1], [0, -1, 0]] + sym_group[9, :] = [[-1, 0, 0], [0, 0, 1], [0, 1, 0]] + sym_group[10, :] = [[1, 0, 0], [0, 0, -1], [0, 1, 0]] + sym_group[11, :] = [[1, 0, 0], [0, 0, 1], [0, -1, 0]] + sym_group[12, :] = [[0, 1, 0], [0, 0, 1], [1, 0, 0]] + sym_group[13, :] = [[0, 1, 0], [0, 0, -1], [-1, 0, 0]] + sym_group[14, :] = [[0, -1, 0], [0, 0, 1], [-1, 0, 0]] + sym_group[15, :] = [[0, -1, 0], [0, 0, -1], [1, 0, 0]] + sym_group[16, :] = [[0, 0, 1], [1, 0, 0], [0, 1, 0]] + sym_group[17, :] = [[0, 0, 1], [-1, 0, 0], [0, -1, 0]] + sym_group[18, :] = [[0, 0, -1], [1, 0, 0], [0, -1, 0]] + sym_group[19, :] = [[0, 0, -1], [-1, 0, 0], [0, 1, 0]] + sym_group[20, :] = [[0, 0, -1], [0, -1, 0], [-1, 0, 0]] + sym_group[21, :] = [[0, 0, -1], [0, 1, 0], [1, 0, 0]] + sym_group[22, :] = [[0, 0, 1], [0, -1, 0], [1, 0, 0]] + sym_group[23, :] = [[0, 0, 1], [0, 1, 0], [-1, 0, 0]] + + mat = np.atleast_2d(mat) + all_vectors = [] + for sym in sym_group: + for vec in mat: + all_vectors.append(np.dot(sym, vec)) + return np.unique(np.array(all_vectors), axis=0) + class Interface(Structure): """This class stores data for defining an interface between two structures. @@ -282,8 +2591,10 @@ def as_dict(self): return dct @classmethod - def from_dict(cls, dct: dict) -> Interface: # type: ignore[override] - """:param dct: dict + def from_dict(cls, dct: dict) -> Self: # type: ignore[override] + """ + Args: + dct: dict. Returns: Creates slab from dict. @@ -316,7 +2627,7 @@ def from_slabs( vacuum_over_film: float = 0, interface_properties: dict | None = None, center_slab: bool = True, - ) -> Interface: + ) -> Self: """Makes an interface structure by merging a substrate and film slabs The film a- and b-vectors will be forced to be the substrate slab's a- and b-vectors. @@ -473,20 +2784,20 @@ def count_layers(struct: Structure, el=None) -> int: h = struct.lattice.c # Projection of c lattice vector in # direction of surface normal. - for i, j in combinations(list(range(n)), 2): - if i != j: - cdist = frac_coords[i][2] - frac_coords[j][2] + for ii, jj in combinations(list(range(n)), 2): + if ii != jj: + cdist = frac_coords[ii][2] - frac_coords[jj][2] cdist = abs(cdist - round(cdist)) * h - dist_matrix[i, j] = cdist - dist_matrix[j, i] = cdist + dist_matrix[ii, jj] = cdist + dist_matrix[jj, ii] = cdist condensed_m = squareform(dist_matrix) z = linkage(condensed_m) clusters = fcluster(z, 0.25, criterion="distance") clustered_sites: dict[int, list[Site]] = {c: [] for c in clusters} - for i, c in enumerate(clusters): - clustered_sites[c].append(struct[i]) + for idx, cluster in enumerate(clusters): + clustered_sites[cluster].append(struct[idx]) plane_heights = { np.average(np.mod([s.frac_coords[2] for s in sites], 1)): c for c, sites in clustered_sites.items() diff --git a/pymatgen/core/ion.py b/pymatgen/core/ion.py index 3a1d29e6c57..0af620e4176 100644 --- a/pymatgen/core/ion.py +++ b/pymatgen/core/ion.py @@ -4,12 +4,16 @@ import re from copy import deepcopy +from typing import TYPE_CHECKING from monty.json import MSONable from pymatgen.core.composition import Composition, reduce_formula from pymatgen.util.string import Stringify, charge_string, formula_double_format +if TYPE_CHECKING: + from typing_extensions import Self + class Ion(Composition, MSONable, Stringify): """Ion object. Just a Composition object with an additional variable to store @@ -19,7 +23,7 @@ class Ion(Composition, MSONable, Stringify): Mn[+2]. Note the order of the sign and magnitude in each representation. """ - def __init__(self, composition, charge=0.0, _properties=None) -> None: + def __init__(self, composition: Composition, charge: float = 0.0) -> None: """Flexible Ion construction, similar to Composition. For more information, please see pymatgen.core.Composition. """ @@ -27,28 +31,26 @@ def __init__(self, composition, charge=0.0, _properties=None) -> None: self._charge = charge @classmethod - def from_formula(cls, formula: str) -> Ion: + def from_formula(cls, formula: str) -> Self: """Creates Ion from formula. The net charge can either be represented as Mn++, Mn+2, Mn[2+], Mn[++], or Mn[+2]. Note the order of the sign and magnitude in each representation. Also note that (aq) can be included in the formula, e.g. "NaOH (aq)". - :param formula: + Args: + formula (str): The formula to create ion from. Returns: Ion """ charge = 0.0 - f = formula # strip (aq), if present - m = re.search(r"\(aq\)", f) - if m: - f = f.replace(m.group(), "", 1) + if match := re.search(r"\(aq\)", formula): + formula = formula.replace(match.group(), "", 1) # check for charge in brackets - m = re.search(r"\[([^\[\]]+)\]", f) - if m: - m_chg = re.search(r"([\.\d]*)([+-]*)([\.\d]*)", m.group(1)) + if match := re.search(r"\[([^\[\]]+)\]", formula): + m_chg = re.search(r"([\.\d]*)([+-]*)([\.\d]*)", match.group(1)) if m_chg: if m_chg.group(1) != "": if m_chg.group(3) != "": @@ -60,18 +62,18 @@ def from_formula(cls, formula: str) -> Ion: for i in re.findall("[+-]", m_chg.group(2)): charge += float(i + "1") - f = f.replace(m.group(), "", 1) + formula = formula.replace(match.group(), "", 1) # if no brackets, parse trailing +/- - for m_chg in re.finditer(r"([+-])([\.\d]*)", f): + for m_chg in re.finditer(r"([+-])([\.\d]*)", formula): sign = m_chg.group(1) sgn = float(str(sign + "1")) if m_chg.group(2).strip() != "": charge += float(m_chg.group(2)) * sgn else: charge += sgn - f = f.replace(m_chg.group(), "", 1) - composition = Composition(f) + formula = formula.replace(m_chg.group(), "", 1) + composition = Composition(formula) return cls(composition, charge) @property @@ -119,8 +121,8 @@ def get_reduced_formula_and_factor(self, iupac_ordering: bool = False, hydrates: Ions containing metals. Returns: - A pretty normalized formula and a multiplicative factor, i.e., - H4O4 returns ('H2O2', 2.0). + tuple[str, float]: A pretty normalized formula and a multiplicative factor, i.e., + H4O4 returns ('H2O2', 2.0). """ all_int = all(abs(x - round(x)) < Composition.amount_tolerance for x in self.values()) if not all_int: @@ -137,7 +139,7 @@ def get_reduced_formula_and_factor(self, iupac_ordering: bool = False, hydrates: comp = self.composition - nH2O * Composition("H2O") el_amt_dict = {k: int(round(v)) for k, v in comp.get_el_amt_dict().items()} - (formula, factor) = reduce_formula(el_amt_dict, iupac_ordering=iupac_ordering) + formula, factor = reduce_formula(el_amt_dict, iupac_ordering=iupac_ordering) if self.composition.get("H") == self.composition.get("O") is not None: formula = formula.replace("HO", "OH") @@ -211,7 +213,7 @@ def as_dict(self) -> dict[str, float]: return dct @classmethod - def from_dict(cls, dct) -> Ion: + def from_dict(cls, dct: dict) -> Self: """Generates an ion object from a dict created by as_dict(). Args: diff --git a/pymatgen/core/lattice.py b/pymatgen/core/lattice.py index 70c327efe4b..c4b0c98cbc1 100644 --- a/pymatgen/core/lattice.py +++ b/pymatgen/core/lattice.py @@ -9,7 +9,7 @@ import warnings from fractions import Fraction from functools import reduce -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import numpy as np from monty.dev import deprecated @@ -23,8 +23,10 @@ from collections.abc import Iterator, Sequence from numpy.typing import ArrayLike + from typing_extensions import Self from pymatgen.core.trajectory import Vector3D + from pymatgen.util.typing import PbcLike __author__ = "Shyue Ping Ong, Michael Kocher" __copyright__ = "Copyright 2011, The Materials Project" @@ -39,8 +41,7 @@ class Lattice(MSONable): """ # Properties lazily generated for efficiency. - - def __init__(self, matrix: ArrayLike, pbc: tuple[bool, bool, bool] = (True, True, True)) -> None: + def __init__(self, matrix: ArrayLike, pbc: PbcLike = (True, True, True)) -> None: """Create a lattice from any sequence of 9 numbers. Note that the sequence is assumed to be read one row at a time. Each row represents one lattice vector. @@ -56,7 +57,7 @@ def __init__(self, matrix: ArrayLike, pbc: tuple[bool, bool, bool] = (True, True E.g., [[10, 0, 0], [20, 10, 0], [0, 0, 30]] specifies a lattice with lattice vectors [10, 0, 0], [20, 10, 0] and [0, 0, 30]. pbc: a tuple defining the periodic boundary conditions along the three - axis of the lattice. If None periodic in all directions. + axis of the lattice. """ mat = np.array(matrix, dtype=np.float64).reshape((3, 3)) mat.setflags(write=False) @@ -65,7 +66,13 @@ def __init__(self, matrix: ArrayLike, pbc: tuple[bool, bool, bool] = (True, True self._diags = None self._lll_matrix_mappings: dict[float, tuple[np.ndarray, np.ndarray]] = {} self._lll_inverse = None - self._pbc = tuple(pbc) + if len(pbc) != 3 or {*pbc} - {True, False}: + raise ValueError(f"pbc must be a tuple of three True/False values, got {pbc}") + + # don't import module-level, causes circular import with util/typing.py + from pymatgen.util.typing import PbcLike + + self._pbc = cast(PbcLike, tuple(pbc)) @property def lengths(self) -> Vector3D: @@ -83,13 +90,12 @@ def angles(self) -> Vector3D: Returns: The angles (alpha, beta, gamma) of the lattice. """ - mat = self._matrix - lengths = self.lengths + matrix, lengths = self._matrix, self.lengths angles = np.zeros(3) for dim in range(3): - j = (dim + 1) % 3 - k = (dim + 2) % 3 - angles[dim] = np.clip(np.dot(mat[j], mat[k]) / (lengths[j] * lengths[k]), -1, 1) + jj = (dim + 1) % 3 + kk = (dim + 2) % 3 + angles[dim] = np.clip(np.dot(matrix[jj], matrix[kk]) / (lengths[jj] * lengths[kk]), -1, 1) angles = np.arccos(angles) * 180.0 / np.pi return tuple(angles.tolist()) # type: ignore @@ -134,9 +140,9 @@ def matrix(self) -> np.ndarray: return self._matrix @property - def pbc(self) -> tuple[bool, bool, bool]: + def pbc(self) -> PbcLike: """Tuple defining the periodicity of the Lattice.""" - return self._pbc # type: ignore + return self._pbc @property def is_3d_periodic(self) -> bool: @@ -206,12 +212,12 @@ def d_hkl(self, miller_index: ArrayLike) -> float: Returns: d_hkl (float) """ - gstar = self.reciprocal_lattice_crystallographic.metric_tensor + g_star = self.reciprocal_lattice_crystallographic.metric_tensor hkl = np.array(miller_index) - return 1 / ((np.dot(np.dot(hkl, gstar), hkl.T)) ** (1 / 2)) + return 1 / ((np.dot(np.dot(hkl, g_star), hkl.T)) ** (1 / 2)) - @staticmethod - def cubic(a: float, pbc: tuple[bool, bool, bool] = (True, True, True)) -> Lattice: + @classmethod + def cubic(cls, a: float, pbc: PbcLike = (True, True, True)) -> Self: """Convenience constructor for a cubic lattice. Args: @@ -222,10 +228,10 @@ def cubic(a: float, pbc: tuple[bool, bool, bool] = (True, True, True)) -> Lattic Returns: Cubic lattice of dimensions a x a x a. """ - return Lattice([[a, 0.0, 0.0], [0.0, a, 0.0], [0.0, 0.0, a]], pbc) + return cls([[a, 0.0, 0.0], [0.0, a, 0.0], [0.0, 0.0, a]], pbc) - @staticmethod - def tetragonal(a: float, c: float, pbc: tuple[bool, bool, bool] = (True, True, True)) -> Lattice: + @classmethod + def tetragonal(cls, a: float, c: float, pbc: PbcLike = (True, True, True)) -> Self: """Convenience constructor for a tetragonal lattice. Args: @@ -237,10 +243,10 @@ def tetragonal(a: float, c: float, pbc: tuple[bool, bool, bool] = (True, True, T Returns: Tetragonal lattice of dimensions a x a x c. """ - return Lattice.from_parameters(a, a, c, 90, 90, 90, pbc=pbc) + return cls.from_parameters(a, a, c, 90, 90, 90, pbc=pbc) - @staticmethod - def orthorhombic(a: float, b: float, c: float, pbc: tuple[bool, bool, bool] = (True, True, True)) -> Lattice: + @classmethod + def orthorhombic(cls, a: float, b: float, c: float, pbc: PbcLike = (True, True, True)) -> Self: """Convenience constructor for an orthorhombic lattice. Args: @@ -253,18 +259,16 @@ def orthorhombic(a: float, b: float, c: float, pbc: tuple[bool, bool, bool] = (T Returns: Orthorhombic lattice of dimensions a x b x c. """ - return Lattice.from_parameters(a, b, c, 90, 90, 90, pbc=pbc) + return cls.from_parameters(a, b, c, 90, 90, 90, pbc=pbc) - @staticmethod - def monoclinic( - a: float, b: float, c: float, beta: float, pbc: tuple[bool, bool, bool] = (True, True, True) - ) -> Lattice: + @classmethod + def monoclinic(cls, a: float, b: float, c: float, beta: float, pbc: PbcLike = (True, True, True)) -> Self: """Convenience constructor for a monoclinic lattice. Args: - a (float): *a* lattice parameter of the monoclinc cell. - b (float): *b* lattice parameter of the monoclinc cell. - c (float): *c* lattice parameter of the monoclinc cell. + a (float): *a* lattice parameter of the monoclinic cell. + b (float): *b* lattice parameter of the monoclinic cell. + c (float): *c* lattice parameter of the monoclinic cell. beta (float): *beta* angle between lattice vectors b and c in degrees. pbc (tuple): a tuple defining the periodic boundary conditions along the three @@ -274,10 +278,10 @@ def monoclinic( Monoclinic lattice of dimensions a x b x c with non right-angle beta between lattice vectors a and c. """ - return Lattice.from_parameters(a, b, c, 90, beta, 90, pbc=pbc) + return cls.from_parameters(a, b, c, 90, beta, 90, pbc=pbc) - @staticmethod - def hexagonal(a: float, c: float, pbc: tuple[bool, bool, bool] = (True, True, True)) -> Lattice: + @classmethod + def hexagonal(cls, a: float, c: float, pbc: PbcLike = (True, True, True)) -> Self: """Convenience constructor for a hexagonal lattice. Args: @@ -289,10 +293,10 @@ def hexagonal(a: float, c: float, pbc: tuple[bool, bool, bool] = (True, True, Tr Returns: Hexagonal lattice of dimensions a x a x c. """ - return Lattice.from_parameters(a, a, c, 90, 90, 120, pbc=pbc) + return cls.from_parameters(a, a, c, 90, 90, 120, pbc=pbc) - @staticmethod - def rhombohedral(a: float, alpha: float, pbc: tuple[bool, bool, bool] = (True, True, True)) -> Lattice: + @classmethod + def rhombohedral(cls, a: float, alpha: float, pbc: PbcLike = (True, True, True)) -> Self: """Convenience constructor for a rhombohedral lattice. Args: @@ -304,7 +308,7 @@ def rhombohedral(a: float, alpha: float, pbc: tuple[bool, bool, bool] = (True, T Returns: Rhombohedral lattice of dimensions a x a x a. """ - return Lattice.from_parameters(a, a, a, alpha, alpha, alpha, pbc=pbc) + return cls.from_parameters(a, a, a, alpha, alpha, alpha, pbc=pbc) @classmethod def from_parameters( @@ -315,9 +319,10 @@ def from_parameters( alpha: float, beta: float, gamma: float, + *, # help mypy separate positional and keyword-only arguments vesta: bool = False, - pbc: tuple[bool, bool, bool] = (True, True, True), - ): + pbc: PbcLike = (True, True, True), + ) -> Self: """Create a Lattice using unit cell lengths (in Angstrom) and angles (in degrees). Args: @@ -359,10 +364,10 @@ def from_parameters( ] vector_c = [0.0, 0.0, float(c)] - return Lattice([vector_a, vector_b, vector_c], pbc) + return cls([vector_a, vector_b, vector_c], pbc) @classmethod - def from_dict(cls, dct: dict, fmt: str | None = None, **kwargs): + def from_dict(cls, dct: dict, fmt: str | None = None, **kwargs) -> Self: # type: ignore[override] """Create a Lattice from a dictionary containing the a, b, c, alpha, beta, and gamma parameters if fmt is None. @@ -427,7 +432,7 @@ def volume(self) -> float: @property def parameters(self) -> tuple[float, float, float, float, float, float]: - """Returns (a, b, c, alpha, beta, gamma).""" + """Returns 6-tuple of floats (a, b, c, alpha, beta, gamma).""" return (*self.lengths, *self.angles) @property @@ -436,20 +441,22 @@ def params_dict(self) -> dict[str, float]: return dict(zip("a b c alpha beta gamma".split(), self.parameters)) @property - def reciprocal_lattice(self) -> Lattice: + def reciprocal_lattice(self) -> Self: """Return the reciprocal lattice. Note that this is the standard reciprocal lattice used for solid state physics with a factor of 2 * pi. If you are looking for the crystallographic reciprocal lattice, use the reciprocal_lattice_crystallographic property. The property is lazily generated for efficiency. """ - v = np.linalg.inv(self._matrix).T - return Lattice(v * 2 * np.pi) + inv_mat = np.linalg.inv(self._matrix).T + cls = type(self) + return cls(inv_mat * 2 * np.pi) @property - def reciprocal_lattice_crystallographic(self) -> Lattice: + def reciprocal_lattice_crystallographic(self) -> Self: """Returns the *crystallographic* reciprocal lattice, i.e. no factor of 2 * pi.""" - return Lattice(self.reciprocal_lattice.matrix / (2 * np.pi)) + cls = type(self) + return cls(self.reciprocal_lattice.matrix / (2 * np.pi)) @property def lll_matrix(self) -> np.ndarray: @@ -884,7 +891,7 @@ def find_all_mappings( None is returned if no matches are found. """ lengths = other_lattice.lengths - (alpha, beta, gamma) = other_lattice.angles + alpha, beta, gamma = other_lattice.angles frac, dist, _, _ = self.get_points_in_sphere( # type: ignore[misc] [[0, 0, 0]], [0, 0, 0], max(lengths) * (1 + ltol), zip_results=False @@ -940,8 +947,8 @@ def find_mapping( Defaults to False. Returns: - (aligned_lattice, rotation_matrix, scale_matrix) if a mapping is - found. aligned_lattice is a rotated version of other_lattice that + tuple[Lattice, np.ndarray, np.ndarray]: (aligned_lattice, rotation_matrix, scale_matrix) + if a mapping is found. aligned_lattice is a rotated version of other_lattice that has the same lattice parameters, but which is aligned in the coordinate system of this lattice so that translational points match up in 3D. rotation_matrix is the rotation that has to be @@ -957,15 +964,19 @@ def find_mapping( """ return next(self.find_all_mappings(other_lattice, ltol, atol, skip_rotation_matrix), None) - def get_lll_reduced_lattice(self, delta: float = 0.75) -> Lattice: - """:param delta: Delta parameter. + def get_lll_reduced_lattice(self, delta: float = 0.75) -> Self: + """Lenstra-Lenstra-Lovasz lattice basis reduction. + + Args: + delta: Delta parameter. Returns: - LLL reduced Lattice. + Lattice: LLL reduced """ if delta not in self._lll_matrix_mappings: self._lll_matrix_mappings[delta] = self._calculate_lll() - return Lattice(self._lll_matrix_mappings[delta][0]) + cls = type(self) + return cls(self._lll_matrix_mappings[delta][0]) def _calculate_lll(self, delta: float = 0.75) -> tuple[np.ndarray, np.ndarray]: """Performs a Lenstra-Lenstra-Lovasz lattice basis reduction to obtain a @@ -1090,14 +1101,14 @@ def get_niggli_reduced_lattice(self, tol: float = 1e-5) -> Lattice: if B + e < A or (abs(A - B) < e and abs(E) > abs(N) + e): # A1 - M = [[0, -1, 0], [-1, 0, 0], [0, 0, -1]] + M = np.array([[0, -1, 0], [-1, 0, 0], [0, 0, -1]]) G = np.dot(np.transpose(M), np.dot(G, M)) # update lattice parameters based on new G (gh-3657) A, B, C, E, N, Y = G[0, 0], G[1, 1], G[2, 2], 2 * G[1, 2], 2 * G[0, 2], 2 * G[0, 1] if (C + e < B) or (abs(B - C) < e and abs(N) > abs(Y) + e): # A2 - M = [[-1, 0, 0], [0, 0, -1], [0, -1, 0]] + M = np.array([[-1, 0, 0], [0, 0, -1], [0, -1, 0]]) G = np.dot(np.transpose(M), np.dot(G, M)) continue @@ -1131,25 +1142,25 @@ def get_niggli_reduced_lattice(self, tol: float = 1e-5) -> Lattice: # A5 if abs(E) > B + e or (abs(E - B) < e and Y - e > 2 * N) or (abs(E + B) < e and -e > Y): - M = [[1, 0, 0], [0, 1, -E / abs(E)], [0, 0, 1]] + M = np.array([[1, 0, 0], [0, 1, -E / abs(E)], [0, 0, 1]]) G = np.dot(np.transpose(M), np.dot(G, M)) continue # A6 if abs(N) > A + e or (abs(A - N) < e and Y - e > 2 * E) or (abs(A + N) < e and -e > Y): - M = [[1, 0, -N / abs(N)], [0, 1, 0], [0, 0, 1]] + M = np.array([[1, 0, -N / abs(N)], [0, 1, 0], [0, 0, 1]]) G = np.dot(np.transpose(M), np.dot(G, M)) continue # A7 if abs(Y) > A + e or (abs(A - Y) < e and N - e > 2 * E) or (abs(A + Y) < e and -e > N): - M = [[1, -Y / abs(Y), 0], [0, 1, 0], [0, 0, 1]] + M = np.array([[1, -Y / abs(Y), 0], [0, 1, 0], [0, 0, 1]]) G = np.dot(np.transpose(M), np.dot(G, M)) continue # A8 if -e > E + N + Y + A + B or (abs(E + N + Y + A + B) < e < Y + (A + N) * 2): - M = [[1, 0, 1], [0, 1, 1], [0, 0, 1]] + M = np.array([[1, 0, 1], [0, 1, 1], [0, 0, 1]]) G = np.dot(np.transpose(M), np.dot(G, M)) continue @@ -1167,7 +1178,6 @@ def get_niggli_reduced_lattice(self, tol: float = 1e-5) -> Lattice: alpha = math.acos(E / 2 / b / c) / math.pi * 180 beta = math.acos(N / 2 / a / c) / math.pi * 180 gamma = math.acos(Y / 2 / a / b) / math.pi * 180 - lattice = Lattice.from_parameters(a, b, c, alpha, beta, gamma) mapped = self.find_mapping(lattice, e, skip_rotation_matrix=True) @@ -1178,7 +1188,7 @@ def get_niggli_reduced_lattice(self, tol: float = 1e-5) -> Lattice: raise ValueError("can't find niggli") - def scale(self, new_volume: float) -> Lattice: + def scale(self, new_volume: float) -> Self: """Return a new Lattice with volume new_volume by performing a scaling of the lattice vectors so that length proportions and angles are preserved. @@ -1198,7 +1208,7 @@ def scale(self, new_volume: float) -> Lattice: new_c = (new_volume / (geo_factor * np.prod(ratios))) ** (1 / 3.0) - return Lattice(versors * (new_c * ratios), pbc=self.pbc) + return type(self)(versors * (new_c * ratios), pbc=self.pbc) def get_wigner_seitz_cell(self) -> list[list[np.ndarray]]: """Returns the Wigner-Seitz cell for the given lattice. @@ -1322,14 +1332,13 @@ def get_points_in_sphere( return self.get_points_in_sphere_py(frac_points=frac_points, center=center, r=r, zip_results=zip_results) else: frac_points = np.ascontiguousarray(frac_points, dtype=float) - lattice_matrix = np.ascontiguousarray(self.matrix, dtype=float) + latt_matrix = np.ascontiguousarray(self.matrix, dtype=float) cart_coords = np.ascontiguousarray(self.get_cartesian_coords(frac_points), dtype=float) pbc = np.ascontiguousarray(self.pbc, dtype=int) - r = float(r) center_coords = np.ascontiguousarray([center], dtype=float) _, indices, images, distances = find_points_in_spheres( - all_coords=cart_coords, center_coords=center_coords, r=r, pbc=pbc, lattice=lattice_matrix, tol=1e-8 + all_coords=cart_coords, center_coords=center_coords, r=float(r), pbc=pbc, lattice=latt_matrix, tol=1e-8 ) if len(indices) < 1: return [] if zip_results else [()] * 4 @@ -1509,8 +1518,10 @@ def get_all_distances( return np.sqrt(d2) def is_hexagonal(self, hex_angle_tol: float = 5, hex_length_tol: float = 0.01) -> bool: - """:param hex_angle_tol: Angle tolerance - :param hex_length_tol: Length tolerance + """ + Args: + hex_angle_tol: Angle tolerance + hex_length_tol: Length tolerance. Returns: Whether lattice corresponds to hexagonal lattice. @@ -1551,10 +1562,10 @@ def get_distance_and_image( the image that is nearest to the site is found. Returns: - (distance, jimage): distance and periodic lattice translations - of the other site for which the distance applies. This means that - the distance between frac_coords1 and (jimage + frac_coords2) is - equal to distance. + tuple[float, np.ndarray]: distance and periodic lattice translations (jimage) + of the other site for which the distance applies. This means that + the distance between frac_coords1 and (jimage + frac_coords2) is + equal to distance. """ if jimage is None: v, d2 = pbc_shortest_vectors(self, frac_coords1, frac_coords2, return_d2=True) @@ -1642,7 +1653,7 @@ def get_integer_index(miller_index: Sequence[float], round_dp: int = 4, verbose: verbose (bool, optional): Whether to print warnings. Returns: - (tuple): The Miller index. + tuple: The Miller index. """ mi = np.asarray(miller_index) # deal with the case we have small irregular floats @@ -1687,7 +1698,7 @@ def get_points_in_spheres( all_coords: np.ndarray, center_coords: np.ndarray, r: float, - pbc: bool | list[bool] | tuple[bool, bool, bool] = True, + pbc: bool | list[bool] | PbcLike = True, numerical_tol: float = 1e-8, lattice: Lattice | None = None, return_fcoords: bool = False, @@ -1797,16 +1808,16 @@ def get_points_in_spheres( nn_coords = np.concatenate([cube_to_coords[k] for k in ks], axis=0) nn_images = itertools.chain(*(cube_to_images[k] for k in ks)) nn_indices = itertools.chain(*(cube_to_indices[k] for k in ks)) - dist = np.linalg.norm(nn_coords - ii[None, :], axis=1) + distances = np.linalg.norm(nn_coords - ii[None, :], axis=1) nns: list[tuple[np.ndarray, float, int, np.ndarray]] = [] - for coord, index, image, d in zip(nn_coords, nn_indices, nn_images, dist): + for coord, index, image, dist in zip(nn_coords, nn_indices, nn_images, distances): # filtering out all sites that are beyond the cutoff # Here there is no filtering of overlapping sites - if d < r + numerical_tol: + if dist < r + numerical_tol: if return_fcoords and (lattice is not None): coord = np.round(lattice.get_fractional_coords(coord), 10) - nn = (coord, float(d), int(index), image) - nns.append(nn) + nn = (coord, float(dist), int(index), image) + nns.append(nn) # type: ignore[arg-type] neighbors.append(nns) return neighbors diff --git a/pymatgen/core/libxcfunc.py b/pymatgen/core/libxcfunc.py index 993e140bdd0..b60b112c17b 100644 --- a/pymatgen/core/libxcfunc.py +++ b/pymatgen/core/libxcfunc.py @@ -9,10 +9,14 @@ import json import os -from enum import Enum +from enum import Enum, unique +from typing import TYPE_CHECKING from monty.json import MontyEncoder +if TYPE_CHECKING: + from typing_extensions import Self + # The libxc version used to generate this file! libxc_version = "3.0.0" @@ -25,11 +29,11 @@ __date__ = "May 16, 2016" # Loads libxc info from json file -with open(os.path.join(os.path.dirname(__file__), "libxc_docs.json")) as file: +with open(os.path.join(os.path.dirname(__file__), "libxc_docs.json"), encoding="utf-8") as file: _all_xcfuncs = {int(k): v for k, v in json.load(file).items()} -# @unique +@unique class LibxcFunc(Enum): """Enumerator with the identifiers. This object is used by Xcfunc declared in xcfunc.py to create an internal representation of the XC functional. @@ -486,7 +490,7 @@ def as_dict(self): return {"name": self.name, "@module": type(self).__module__, "@class": type(self).__name__} @classmethod - def from_dict(cls, dct: dict) -> LibxcFunc: + def from_dict(cls, dct: dict) -> Self: """Deserialize from MSONable dict representation.""" return cls[dct["name"]] diff --git a/pymatgen/core/operations.py b/pymatgen/core/operations.py index f1b9d26d988..7bc52f41f4b 100644 --- a/pymatgen/core/operations.py +++ b/pymatgen/core/operations.py @@ -20,6 +20,7 @@ from typing import Any from numpy.typing import ArrayLike + from typing_extensions import Self __author__ = "Shyue Ping Ong, Shyam Dwaraknath, Matthew Horton" @@ -54,12 +55,13 @@ def __init__(self, affine_transformation_matrix: ArrayLike, tol: float = 0.01) - self.affine_matrix = affine_transformation_matrix self.tol = tol - @staticmethod + @classmethod def from_rotation_and_translation( + cls, rotation_matrix: ArrayLike = ((1, 0, 0), (0, 1, 0), (0, 0, 1)), translation_vec: ArrayLike = (0, 0, 0), tol: float = 0.1, - ) -> SymmOp: + ) -> Self: """Creates a symmetry operation from a rotation matrix and a translation vector. @@ -80,7 +82,7 @@ def from_rotation_and_translation( affine_matrix = np.eye(4) affine_matrix[0:3][:, 0:3] = rotation_matrix affine_matrix[0:3][:, 3] = translation_vec - return SymmOp(affine_matrix, tol) + return cls(affine_matrix, tol) def __eq__(self, other: object) -> bool: if not isinstance(other, SymmOp): @@ -192,7 +194,9 @@ def are_symmetrically_related_vectors( tol (float): Absolute tolerance for checking distance. Returns: - (are_related, is_reversed) + tuple[bool, bool]: First bool indicates if the vectors are related, + the second if the vectors are related but the starting and end point + are exchanged. """ from_c = self.operate(from_a) to_c = self.operate(to_a) @@ -231,8 +235,8 @@ def __mul__(self, other): @property def inverse(self) -> SymmOp: """Returns inverse of transformation.""" - invr = np.linalg.inv(self.affine_matrix) - return SymmOp(invr) + inverse = np.linalg.inv(self.affine_matrix) + return SymmOp(inverse) @staticmethod def from_axis_angle_and_translation( @@ -256,22 +260,22 @@ def from_axis_angle_and_translation( vec = np.array(translation_vec) - a = angle if angle_in_radians else angle * pi / 180 - cosa = cos(a) - sina = sin(a) - u = axis / np.linalg.norm(axis) # type: ignore - r = np.zeros((3, 3)) - r[0, 0] = cosa + u[0] ** 2 * (1 - cosa) # type: ignore - r[0, 1] = u[0] * u[1] * (1 - cosa) - u[2] * sina # type: ignore - r[0, 2] = u[0] * u[2] * (1 - cosa) + u[1] * sina # type: ignore - r[1, 0] = u[0] * u[1] * (1 - cosa) + u[2] * sina # type: ignore - r[1, 1] = cosa + u[1] ** 2 * (1 - cosa) # type: ignore - r[1, 2] = u[1] * u[2] * (1 - cosa) - u[0] * sina # type: ignore - r[2, 0] = u[0] * u[2] * (1 - cosa) - u[1] * sina # type: ignore - r[2, 1] = u[1] * u[2] * (1 - cosa) + u[0] * sina # type: ignore - r[2, 2] = cosa + u[2] ** 2 * (1 - cosa) # type: ignore - - return SymmOp.from_rotation_and_translation(r, vec) + ang = angle if angle_in_radians else angle * pi / 180 + cos_a = cos(ang) + sin_a = sin(ang) + unit_vec = axis / np.linalg.norm(axis) # type: ignore + rot_mat = np.zeros((3, 3)) + rot_mat[0, 0] = cos_a + unit_vec[0] ** 2 * (1 - cos_a) # type: ignore + rot_mat[0, 1] = unit_vec[0] * unit_vec[1] * (1 - cos_a) - unit_vec[2] * sin_a # type: ignore + rot_mat[0, 2] = unit_vec[0] * unit_vec[2] * (1 - cos_a) + unit_vec[1] * sin_a # type: ignore + rot_mat[1, 0] = unit_vec[0] * unit_vec[1] * (1 - cos_a) + unit_vec[2] * sin_a # type: ignore + rot_mat[1, 1] = cos_a + unit_vec[1] ** 2 * (1 - cos_a) # type: ignore + rot_mat[1, 2] = unit_vec[1] * unit_vec[2] * (1 - cos_a) - unit_vec[0] * sin_a # type: ignore + rot_mat[2, 0] = unit_vec[0] * unit_vec[2] * (1 - cos_a) - unit_vec[1] * sin_a # type: ignore + rot_mat[2, 1] = unit_vec[1] * unit_vec[2] * (1 - cos_a) + unit_vec[0] * sin_a # type: ignore + rot_mat[2, 2] = cos_a + unit_vec[2] ** 2 * (1 - cos_a) # type: ignore + + return SymmOp.from_rotation_and_translation(rot_mat, vec) @typing.no_type_check @staticmethod @@ -394,7 +398,7 @@ def rotoreflection(axis: ArrayLike, angle: float, origin: ArrayLike = (0, 0, 0)) origin (3x1 array): Point left invariant by roto-reflection. Defaults to (0, 0, 0). - Return: + Returns: Roto-reflection operation """ rot = SymmOp.from_origin_axis_angle(origin, axis, angle) @@ -422,7 +426,7 @@ def as_xyz_str(self) -> str: return transformation_to_string(self.rotation_matrix, translation_vec=self.translation_vector, delim=", ") @classmethod - def from_xyz_str(cls, xyz_str: str) -> SymmOp: + def from_xyz_str(cls, xyz_str: str) -> Self: """ Args: xyz_str: string of the form 'x, y, z', '-x, -y, z', '-2y+1/2, 3x+1/2, z-y+1/2', etc. @@ -451,8 +455,10 @@ def from_xyz_str(cls, xyz_str: str) -> SymmOp: return cls.from_rotation_and_translation(rot_matrix, trans) @classmethod - def from_dict(cls, dct) -> SymmOp: - """:param dct: dict + def from_dict(cls, dct) -> Self: + """ + Args: + dct: dict. Returns: SymmOp from dict representation. @@ -480,9 +486,10 @@ def __init__(self, affine_transformation_matrix: ArrayLike, time_reversal: int, tol (float): Tolerance for determining if matrices are equal. """ SymmOp.__init__(self, affine_transformation_matrix, tol=tol) - if time_reversal not in (-1, 1): - raise Exception(f"Time reversal operator not well defined: {time_reversal}, {type(time_reversal)}") - self.time_reversal = time_reversal + if time_reversal in {-1, 1}: + self.time_reversal = time_reversal + else: + raise RuntimeError(f"Invalid {time_reversal=}, must be 1 or -1") def __eq__(self, other: object) -> bool: if not isinstance(other, SymmOp): @@ -538,7 +545,7 @@ class or as list or np array-like return Magmom.from_global_moment_and_saxis(transformed_moment, magmom.saxis) @classmethod - def from_symmop(cls, symmop: SymmOp, time_reversal) -> MagSymmOp: + def from_symmop(cls, symmop: SymmOp, time_reversal) -> Self: """Initialize a MagSymmOp from a SymmOp and time reversal operator. Args: @@ -575,7 +582,7 @@ def from_rotation_and_translation_and_time_reversal( return MagSymmOp.from_symmop(symm_op, time_reversal) @classmethod - def from_xyzt_str(cls, xyzt_str: str) -> MagSymmOp: + def from_xyzt_str(cls, xyzt_str: str) -> Self: """ Args: xyzt_str (str): of the form 'x, y, z, +1', '-x, -y, z, -1', @@ -588,7 +595,7 @@ def from_xyzt_str(cls, xyzt_str: str) -> MagSymmOp: try: time_reversal = int(xyzt_str.rsplit(",", 1)[1]) except Exception: - raise Exception("Time reversal operator could not be parsed.") + raise RuntimeError("Time reversal operator could not be parsed.") return cls.from_symmop(symm_op, time_reversal) def as_xyzt_str(self) -> str: @@ -609,8 +616,10 @@ def as_dict(self) -> dict[str, Any]: } @classmethod - def from_dict(cls, dct: dict) -> MagSymmOp: - """:param dct: dict + def from_dict(cls, dct: dict) -> Self: + """ + Args: + dct: dict. Returns: MagneticSymmOp from dict representation. diff --git a/pymatgen/core/periodic_table.py b/pymatgen/core/periodic_table.py index 20b1979092c..80ce3d83485 100644 --- a/pymatgen/core/periodic_table.py +++ b/pymatgen/core/periodic_table.py @@ -8,7 +8,7 @@ import re import warnings from collections import Counter -from enum import Enum +from enum import Enum, unique from itertools import combinations, product from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Literal @@ -16,20 +16,23 @@ import numpy as np from monty.json import MSONable -from pymatgen.core.units import SUPPORTED_UNIT_NAMES, FloatWithUnit, Length, Mass, Unit +from pymatgen.core.units import SUPPORTED_UNIT_NAMES, FloatWithUnit, Ha_to_eV, Length, Mass, Unit from pymatgen.util.string import Stringify, formula_double_format if TYPE_CHECKING: + from typing_extensions import Self + from pymatgen.util.typing import SpeciesLike # Loads element data from json file -with open(Path(__file__).absolute().parent / "periodic_table.json") as ptable_json: +with open(Path(__file__).absolute().parent / "periodic_table.json", encoding="utf-8") as ptable_json: _pt_data = json.load(ptable_json) _pt_row_sizes = (2, 8, 8, 18, 18, 32, 32) @functools.total_ordering +@unique class ElementBase(Enum): """Element class defined without any enum values so it can be subclassed. @@ -73,9 +76,10 @@ def __init__(self, symbol: SpeciesLike) -> None: electronic_structure (str): Electronic structure. E.g., The electronic structure for Fe is represented as [Ar].3d6.4s2. atomic_orbitals (dict): Atomic Orbitals. Energy of the atomic orbitals as a dict. E.g., The orbitals - energies in eV are represented as {'1s': -1.0, '2s': -0.1}. Data is obtained from + energies in Hartree are represented as {'1s': -1.0, '2s': -0.1}. Data is obtained from https://www.nist.gov/pml/data/atomic-reference-data-electronic-structure-calculations. The LDA values for neutral atoms are used. + atomic_orbitals_eV (dict): Atomic Orbitals. Same as `atomic_orbitals` but energies are in eV. thermal_conductivity (float): Thermal conductivity. boiling_point (float): Boiling point. melting_point (float): Melting point. @@ -233,26 +237,39 @@ def __getattr__(self, item: str) -> Any: unit = "K^-1" else: unit = tokens[1] - val = FloatWithUnit(tokens[0], unit) + val = FloatWithUnit(float(tokens[0]), unit) else: unit = tokens[1].replace("", "^").replace("", "").replace("Ω", "ohm") units = Unit(unit) if set(units).issubset(SUPPORTED_UNIT_NAMES): - val = FloatWithUnit(tokens[0], unit) + val = FloatWithUnit(float(tokens[0]), unit) except ValueError: # Ignore error. val will just remain a string. pass - if item in ("refractive_index", "melting_point") and isinstance(val, str): - # Final attempt to parse a float. - m = re.findall(r"[\.\d]+", val) - if m: - warnings.warn( - f"Ambiguous values ({val}) for {item} of {self.symbol}. Returning first float value." - ) - return float(m[0]) + if ( + item in ("refractive_index", "melting_point") + and isinstance(val, str) + and (match := re.findall(r"[\.\d]+", val)) + ): + warnings.warn( + f"Ambiguous values ({val}) for {item} of {self.symbol}. Returning first float value." + ) + return float(match[0]) return val raise AttributeError(f"Element has no attribute {item}!") + @property + def atomic_orbitals_eV(self) -> dict[str, float]: + """ + Get the LDA energies in eV for neutral atoms, by orbital. + + This property contains the same info as `self.atomic_orbitals`, + but uses eV for units, per matsci issue https://matsci.org/t/unit-of-atomic-orbitals-energy/54325 + In short, self.atomic_orbitals was meant to be in eV all along but is now kept + as Hartree for backwards compatibility. + """ + return {orb_idx: energy * Ha_to_eV for orb_idx, energy in self.atomic_orbitals.items()} + @property def data(self) -> dict[str, Any]: """Returns dict of data for element.""" @@ -369,9 +386,8 @@ def full_electronic_structure(self) -> list[tuple[int, str, int]]: e_str = self.electronic_structure def parse_orbital(orb_str): - m = re.match(r"(\d+)([spdfg]+)(\d+)", orb_str) - if m: - return int(m.group(1)), m.group(2), int(m.group(3)) + if match := re.match(r"(\d+)([spdfg]+)(\d+)", orb_str): + return int(match.group(1)), match.group(2), int(match.group(3)) return orb_str data = [parse_orbital(s) for s in e_str.split(".")] @@ -602,10 +618,10 @@ def row(self) -> int: return 6 if 89 <= z <= 103: return 7 - for i, size in enumerate(_pt_row_sizes): + for i, size in enumerate(_pt_row_sizes, start=1): total += size if total >= z: - return i + 1 + return i return 8 @property @@ -940,11 +956,11 @@ def __init__( ValueError: If oxidation state passed both in symbol string and via oxidation_state kwarg. """ - if oxidation_state is not None and isinstance(symbol, str) and symbol[-1] in {"+", "-"}: + if oxidation_state is not None and isinstance(symbol, str) and symbol.endswith(("+", "-")): raise ValueError( f"Oxidation state should be specified either in {symbol=} or as {oxidation_state=}, not both." ) - if isinstance(symbol, str) and symbol[-1] in {"+", "-"}: + if isinstance(symbol, str) and symbol.endswith(("+", "-")): # Extract oxidation state from symbol symbol, oxi = re.match(r"([A-Za-z]+)([0-9]*[\+\-])", symbol).groups() # type: ignore[union-attr] self._oxi_state: float | None = (1 if "+" in oxi else -1) * float(oxi[:-1] or 1) @@ -955,15 +971,15 @@ def __init__( self._spin = spin - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> Any: """Allows Specie to inherit properties of underlying element.""" return getattr(self._el, attr) - def __getstate__(self): + def __getstate__(self) -> dict: return self.__dict__ - def __setstate__(self, d): - self.__dict__.update(d) + def __setstate__(self, dct: dict) -> None: + self.__dict__.update(dct) def __eq__(self, other: object) -> bool: """Species is equal to other only if element and oxidation states are exactly the same.""" @@ -1039,7 +1055,7 @@ def ionic_radius(self) -> float | None: return None @classmethod - def from_str(cls, species_string: str) -> Species: + def from_str(cls, species_string: str) -> Self: """Returns a Species from a string representation. Args: @@ -1093,6 +1109,7 @@ def __str__(self) -> str: if isinstance(abs_charge, float): abs_charge = f"{abs_charge:.2f}" output += f"{abs_charge}{'+' if self.oxi_state >= 0 else '-'}" + if self._spin is not None: spin = self._spin output += f",{spin=}" @@ -1185,16 +1202,20 @@ def get_crystal_field_spin( """ if coordination not in ("oct", "tet") or spin_config not in ("high", "low"): raise ValueError("Invalid coordination or spin config") + elec = self.full_electronic_structure if len(elec) < 4 or elec[-1][1] != "s" or elec[-2][1] != "d": raise AttributeError(f"Invalid element {self.symbol} for crystal field calculation") + n_electrons = elec[-1][2] + elec[-2][2] - self.oxi_state if n_electrons < 0 or n_electrons > 10: raise AttributeError(f"Invalid oxidation state {self.oxi_state} for element {self.symbol}") + if spin_config == "high": if n_electrons <= 5: return n_electrons return 10 - n_electrons + if spin_config == "low": if coordination == "oct": if n_electrons <= 3: @@ -1204,6 +1225,7 @@ def get_crystal_field_spin( if n_electrons <= 8: return n_electrons - 6 return 10 - n_electrons + if coordination == "tet": if n_electrons <= 2: return n_electrons @@ -1212,7 +1234,8 @@ def get_crystal_field_spin( if n_electrons <= 7: return n_electrons - 4 return 10 - n_electrons - raise RuntimeError(f"should not reach here, {spin_config=}, {coordination=}") + return None + return None def __deepcopy__(self, memo) -> Species: return Species(self.symbol, self.oxi_state, spin=self._spin) @@ -1228,13 +1251,15 @@ def as_dict(self) -> dict: } @classmethod - def from_dict(cls, d) -> Species: - """:param d: Dict representation. + def from_dict(cls, dct: dict) -> Self: + """ + Args: + dct (dict): Dict representation. Returns: Species. """ - return cls(d["element"], d["oxidation_state"], spin=d.get("spin")) + return cls(dct["element"], dct["oxidation_state"], spin=dct.get("spin")) @functools.total_ordering @@ -1346,7 +1371,7 @@ def __deepcopy__(self, memo): return DummySpecies(self.symbol, self._oxi_state) @classmethod - def from_str(cls, species_string: str) -> DummySpecies: + def from_str(cls, species_string: str) -> Self: """Returns a Dummy from a string representation. Args: @@ -1359,17 +1384,16 @@ def from_str(cls, species_string: str) -> DummySpecies: Raises: ValueError if species_string cannot be interpreted. """ - m = re.search(r"([A-ZAa-z]*)([0-9.]*)([+\-]*)(.*)", species_string) - if m: - sym = m.group(1) - if m.group(2) == m.group(3) == "": + if match := re.search(r"([A-ZAa-z]*)([0-9.]*)([+\-]*)(.*)", species_string): + sym = match.group(1) + if match.group(2) == match.group(3) == "": oxi = 0.0 else: - oxi = 1.0 if m.group(2) == "" else float(m.group(2)) - oxi = -oxi if m.group(3) == "-" else oxi + oxi = 1.0 if match.group(2) == "" else float(match.group(2)) + oxi = -oxi if match.group(3) == "-" else oxi properties = {} - if m.group(4): # has Spin property - tokens = m.group(4).split("=") + if match.group(4): # has Spin property + tokens = match.group(4).split("=") properties = {tokens[0]: float(tokens[1])} return cls(sym, oxi, **properties) raise ValueError("Invalid DummySpecies String") @@ -1385,13 +1409,15 @@ def as_dict(self) -> dict: } @classmethod - def from_dict(cls, d) -> DummySpecies: - """:param d: Dict representation + def from_dict(cls, dct: dict) -> Self: + """ + Args: + dct (dict): Dict representation. Returns: DummySpecies """ - return cls(d["element"], d["oxidation_state"], spin=d.get("spin")) + return cls(dct["element"], dct["oxidation_state"], spin=dct.get("spin")) def __repr__(self) -> str: return f"DummySpecies {self}" @@ -1443,28 +1469,56 @@ def get_el_sp(obj: int | SpeciesLike) -> Element | Species | DummySpecies: Species | Element: with a bias for the maximum number of properties that can be determined. """ + # if obj is already an Element or Species, return as is if isinstance(obj, (Element, Species, DummySpecies)): if getattr(obj, "_is_named_isotope", None): return Element(obj.name) if isinstance(obj, Element) else Species(str(obj)) return obj + # if obj is an integer, return the Element with atomic number obj try: flt = float(obj) - integer = int(flt) - integer = integer if integer == flt else None # type: ignore - return Element.from_Z(integer) - except (ValueError, TypeError): + assert flt == int(flt) + return Element.from_Z(int(flt)) + except (AssertionError, ValueError, TypeError, KeyError): pass + # if obj is a string, attempt to parse it as a Species try: return Species.from_str(obj) # type: ignore - except (ValueError, KeyError): + except (ValueError, TypeError, KeyError): pass + # if Species parsing failed, try Element try: return Element(obj) # type: ignore - except (ValueError, KeyError): + except (ValueError, TypeError, KeyError): pass + + # if Element parsing failed, try DummySpecies try: return DummySpecies.from_str(obj) # type: ignore except Exception: - raise ValueError(f"Can't parse Element or Species from {type(obj).__name__}: {obj}.") + raise ValueError(f"Can't parse Element or Species from {obj!r}") + + +@unique +class ElementType(Enum): + """Enum for element types.""" + + noble_gas = "noble_gas" # He, Ne, Ar, Kr, Xe, Rn + transition_metal = "transition_metal" # Sc-Zn, Y-Cd, La-Hg, Ac-Cn + post_transition_metal = "post_transition_metal" # Al, Ga, In, Tl, Sn, Pb, Bi, Po + rare_earth_metal = "rare_earth_metal" # Ce-Lu, Th-Lr + metal = "metal" + metalloid = "metalloid" # B, Si, Ge, As, Sb, Te, Po + alkali = "alkali" # Li, Na, K, Rb, Cs, Fr + alkaline = "alkaline" # Be, Mg, Ca, Sr, Ba, Ra + halogen = "halogen" # F, Cl, Br, I, At + chalcogen = "chalcogen" # O, S, Se, Te, Po + lanthanoid = "lanthanoid" # La-Lu + actinoid = "actinoid" # Ac-Lr + quadrupolar = "quadrupolar" + s_block = "s-block" + p_block = "p-block" + d_block = "d-block" + f_block = "f-block" diff --git a/pymatgen/core/sites.py b/pymatgen/core/sites.py index 2132238ed83..1cc9af08aef 100644 --- a/pymatgen/core/sites.py +++ b/pymatgen/core/sites.py @@ -16,6 +16,7 @@ if TYPE_CHECKING: from numpy.typing import ArrayLike + from typing_extensions import Self from pymatgen.util.typing import CompositionLike, SpeciesLike @@ -40,21 +41,22 @@ def __init__( ) -> None: """Creates a non-periodic Site. - :param species: Species on the site. Can be: - i. A Composition-type object (preferred) - ii. An element / species specified either as a string - symbols, e.g. "Li", "Fe2+", "P" or atomic numbers, - e.g., 3, 56, or actual Element or Species objects. - iii.Dict of elements/species and occupancies, e.g., - {"Fe" : 0.5, "Mn":0.5}. This allows the setup of - disordered structures. - :param coords: Cartesian coordinates of site. - :param properties: Properties associated with the site as a dict, e.g. - {"magmom": 5}. Defaults to None. - :param label: Label for the site. Defaults to None. - :param skip_checks: Whether to ignore all the usual checks and just - create the site. Use this if the Site is created in a controlled - manner and speed is desired. + Args: + species: Species on the site. Can be: + i. A Composition-type object (preferred) + ii. An element / species specified either as a string + symbols, e.g. "Li", "Fe2+", "P" or atomic numbers, + e.g., 3, 56, or actual Element or Species objects. + iii.Dict of elements/species and occupancies, e.g., + {"Fe" : 0.5, "Mn":0.5}. This allows the setup of + disordered structures. + coords: Cartesian coordinates of site. + properties: Properties associated with the site as a dict, e.g. + {"magmom": 5}. Defaults to None. + label: Label for the site. Defaults to None. + skip_checks: Whether to ignore all the usual checks and just + create the site. Use this if the Site is created in a controlled + manner and speed is desired. """ if not skip_checks: if not isinstance(species, Composition): @@ -163,7 +165,7 @@ def species_string(self) -> str: @property def specie(self) -> Element | Species | DummySpecies: """The Species/Element at the site. Only works for ordered sites. Otherwise - an AttributeError is raised. Use this property sparingly. Robust + an AttributeError is raised. Use this property sparingly. Robust design should make use of the property species instead. Note that the singular of species is also species. So the choice of this variable name is governed by programmatic concerns as opposed to grammar. @@ -261,7 +263,7 @@ def as_dict(self) -> dict: return dct @classmethod - def from_dict(cls, dct: dict) -> Site: + def from_dict(cls, dct: dict) -> Self: """Create Site from dict representation.""" atoms_n_occu = {} for sp_occu in dct["species"]: @@ -298,28 +300,29 @@ def __init__( ) -> None: """Create a periodic site. - :param species: Species on the site. Can be: - i. A Composition-type object (preferred) - ii. An element / species specified either as a string - symbols, e.g. "Li", "Fe2+", "P" or atomic numbers, - e.g., 3, 56, or actual Element or Species objects. - iii.Dict of elements/species and occupancies, e.g., - {"Fe" : 0.5, "Mn":0.5}. This allows the setup of - disordered structures. - :param coords: Coordinates of site, fractional coordinates - by default. See ``coords_are_cartesian`` for more details. - :param lattice: Lattice associated with the site. - :param to_unit_cell: Translates fractional coordinate to the - basic unit cell, i.e. all fractional coordinates satisfy 0 - <= a < 1. Defaults to False. - :param coords_are_cartesian: Set to True if you are providing - Cartesian coordinates. Defaults to False. - :param properties: Properties associated with the site as a dict, e.g. - {"magmom": 5}. Defaults to None. - :param label: Label for the site. Defaults to None. - :param skip_checks: Whether to ignore all the usual checks and just - create the site. Use this if the PeriodicSite is created in a - controlled manner and speed is desired. + Args: + species: Species on the site. Can be: + i. A Composition-type object (preferred) + ii. An element / species specified either as a string + symbols, e.g. "Li", "Fe2+", "P" or atomic numbers, + e.g., 3, 56, or actual Element or Species objects. + iii.Dict of elements/species and occupancies, e.g., + {"Fe" : 0.5, "Mn":0.5}. This allows the setup of + disordered structures. + coords: Coordinates of site, fractional coordinates + by default. See ``coords_are_cartesian`` for more details. + lattice: Lattice associated with the site. + to_unit_cell: Translates fractional coordinate to the + basic unit cell, i.e. all fractional coordinates satisfy 0 + <= a < 1. Defaults to False. + coords_are_cartesian: Set to True if you are providing + Cartesian coordinates. Defaults to False. + properties: Properties associated with the site as a dict, e.g. + {"magmom": 5}. Defaults to None. + label: Label for the site. Defaults to None. + skip_checks: Whether to ignore all the usual checks and just + create the site. Use this if the PeriodicSite is created in a + controlled manner and speed is desired. """ frac_coords = lattice.get_fractional_coords(coords) if coords_are_cartesian else coords @@ -497,15 +500,15 @@ def distance_and_image_from_frac_coords( jimage is also returned. Args: - fcoords (3x1 array): fcoords to get distance from. + fcoords (3x1 array): fractional coordinates to get distance from. jimage (3x1 array): Specific periodic image in terms of lattice translations, e.g., [1,0,0] implies to take periodic image that is one a-lattice vector away. If jimage is None, the image that is nearest to the site is found. Returns: - (distance, jimage): distance and periodic lattice translations - of the other site for which the distance applies. + tuple[float, np.ndarray]: distance and periodic lattice translations (jimage) + of the other site for which the distance applies. """ return self.lattice.get_distance_and_image(self.frac_coords, fcoords, jimage=jimage) @@ -525,8 +528,8 @@ def distance_and_image(self, other: PeriodicSite, jimage: ArrayLike | None = Non the image that is nearest to the site is found. Returns: - (distance, jimage): distance and periodic lattice translations - of the other site for which the distance applies. + tuple[float, np.ndarray]: distance and periodic lattice translations (jimage) + of the other site for which the distance applies. """ return self.distance_and_image_from_frac_coords(other.frac_coords, jimage) @@ -588,7 +591,7 @@ def as_dict(self, verbosity: int = 0) -> dict: return dct @classmethod - def from_dict(cls, dct, lattice=None) -> PeriodicSite: + def from_dict(cls, dct, lattice=None) -> Self: """Create PeriodicSite from dict representation. Args: diff --git a/pymatgen/core/spectrum.py b/pymatgen/core/spectrum.py index b50ad021864..7e819edc397 100644 --- a/pymatgen/core/spectrum.py +++ b/pymatgen/core/spectrum.py @@ -18,9 +18,12 @@ def lorentzian(x, x_0: float = 0, sigma: float = 1.0): - """:param x: x values - :param x_0: Center - :param sigma: FWHM + """ + + Args: + x: x values + x_0: Center + sigma: FWHM. Returns: Value of lorentzian at x. diff --git a/pymatgen/core/structure.py b/pymatgen/core/structure.py index 0372ba10014..353e36b0362 100644 --- a/pymatgen/core/structure.py +++ b/pymatgen/core/structure.py @@ -56,6 +56,7 @@ from ase.optimize.optimize import Optimizer from matgl.ext.ase import TrajectoryObserver from numpy.typing import ArrayLike + from typing_extensions import Self from pymatgen.util.typing import CompositionLike, SpeciesLike @@ -80,12 +81,14 @@ def __init__( index: int = 0, label: str | None = None, ) -> None: - """:param species: Same as Site - :param coords: Same as Site, but must be fractional. - :param properties: Same as Site - :param nn_distance: Distance to some other Site. - :param index: Index within structure. - :param label: Label for the site. Defaults to None. + """ + Args: + species: Same as Site + coords: Same as Site, but must be fractional. + properties: Same as Site + nn_distance: Distance to some other Site. + index: Index within structure. + label: Label for the site. Defaults to None. """ self.coords = coords self._species = species @@ -102,12 +105,12 @@ def __getitem__(self, idx: int): """Make neighbor Tuple-like to retain backwards compatibility.""" return (self, self.nn_distance, self.index)[idx] - def as_dict(self) -> dict: # type: ignore + def as_dict(self) -> dict: """Note that method calls the super of Site, which is MSONable itself.""" return super(Site, self).as_dict() @classmethod - def from_dict(cls, dct: dict) -> Neighbor: # type: ignore + def from_dict(cls, dct: dict) -> Self: """Returns a Neighbor from a dict. Args: @@ -174,12 +177,12 @@ def __getitem__(self, idx: int | slice): """Make neighbor Tuple-like to retain backwards compatibility.""" return (self, self.nn_distance, self.index, self.image)[idx] - def as_dict(self) -> dict: # type: ignore + def as_dict(self) -> dict: # type: ignore[override] """Note that method calls the super of Site, which is MSONable itself.""" return super(Site, self).as_dict() @classmethod - def from_dict(cls, dct: dict) -> PeriodicNeighbor: # type: ignore + def from_dict(cls, dct: dict) -> Self: # type: ignore[override] """Returns a PeriodicNeighbor from a dict. Args: @@ -318,19 +321,18 @@ def atomic_numbers(self) -> tuple[int, ...]: @property def site_properties(self) -> dict[str, Sequence]: - """Returns the site properties as a dict of sequences. E.g. {"magmom": (5, -5), "charge": (-4, 4)}.""" - props: dict[str, Sequence] = {} + """The site properties as a dict of sequences. + E.g. {"magmom": (5, -5), "charge": (-4, 4)}. + """ prop_keys: set[str] = set() for site in self: prop_keys.update(site.properties) - for key in prop_keys: - props[key] = [site.properties.get(key) for site in self] - return props + return {key: [site.properties.get(key) for site in self] for key in prop_keys} @property def labels(self) -> list[str]: - """Return site labels as a list.""" + """Site labels as a list.""" return [site.label for site in self] def __contains__(self, site: object) -> bool: @@ -578,8 +580,7 @@ def add_oxidation_state_by_element(self, oxidation_states: dict[str, float]) -> Returns: SiteCollection: self with oxidation states. """ - missing = {el.symbol for el in self.composition} - {*oxidation_states} - if missing: + if missing := {el.symbol for el in self.composition} - {*oxidation_states}: raise ValueError(f"Oxidation states not specified for all elements, {missing=}") for site in self: new_sp = {} @@ -792,7 +793,7 @@ def _relax( # UIP=universal interatomic potential run_uip = isinstance(calculator, str) and calculator.lower() in ("m3gnet", "chgnet") - calc_params = dict(stress_weight=stress_weight) if not is_molecule else {} + calc_params = {} if is_molecule else dict(stress_weight=stress_weight) calculator = self._prep_calculator(calculator, **calc_params) # check str is valid optimizer key @@ -1030,9 +1031,9 @@ def from_sites( ValueError: If sites is empty or sites do not have the same lattice. Returns: - (Structure) Note that missing properties are set as None. + IStructure: Note that missing properties are set as None. """ - if len(sites) < 1: + if not sites: raise ValueError(f"You need at least 1 site to construct a {cls.__name__}") prop_keys: list[str] = [] props = {} @@ -1123,11 +1124,11 @@ def from_spacegroup( except ValueError: spg = SpaceGroup(sg) # type: ignore - latt = lattice if isinstance(lattice, Lattice) else Lattice(lattice) + lattice = lattice if isinstance(lattice, Lattice) else Lattice(lattice) - if not spg.is_compatible(latt): + if not spg.is_compatible(lattice): raise ValueError( - f"Supplied lattice with parameters {latt.parameters} is incompatible with supplied spacegroup " + f"Supplied lattice with parameters {lattice.parameters} is incompatible with supplied spacegroup " f"{spg.symbol}!" ) @@ -1135,7 +1136,7 @@ def from_spacegroup( raise ValueError(f"Supplied species and coords lengths ({len(species)} vs {len(coords)}) are different!") frac_coords = ( - np.array(coords, dtype=np.float64) if not coords_are_cartesian else latt.get_fractional_coords(coords) + lattice.get_fractional_coords(coords) if coords_are_cartesian else np.array(coords, dtype=np.float64) ) props = {} if site_properties is None else site_properties @@ -1153,7 +1154,7 @@ def from_spacegroup( for k, v in props.items(): all_site_properties[k].extend([v[idx]] * len(cc)) - return cls(latt, all_sp, all_coords, site_properties=all_site_properties, labels=all_labels) + return cls(lattice, all_sp, all_coords, site_properties=all_site_properties, labels=all_labels) @classmethod def from_magnetic_spacegroup( @@ -1225,11 +1226,11 @@ def from_magnetic_spacegroup( if not isinstance(msg, MagneticSpaceGroup): msg = MagneticSpaceGroup(msg) - latt = lattice if isinstance(lattice, Lattice) else Lattice(lattice) + lattice = lattice if isinstance(lattice, Lattice) else Lattice(lattice) - if not msg.is_compatible(latt): + if not msg.is_compatible(lattice): raise ValueError( - f"Supplied lattice with parameters {latt.parameters} is incompatible with supplied spacegroup " + f"Supplied lattice with parameters {lattice.parameters} is incompatible with supplied spacegroup " f"{msg.sg_symbol}!" ) @@ -1237,7 +1238,7 @@ def from_magnetic_spacegroup( if len(var) != len(species): raise ValueError(f"Length mismatch: len({name})={len(var)} != {len(species)=}") - frac_coords = coords if not coords_are_cartesian else latt.get_fractional_coords(coords) + frac_coords = lattice.get_fractional_coords(coords) if coords_are_cartesian else coords all_sp: list[str | Element | Species | DummySpecies | Composition] = [] all_coords: list[list[float]] = [] @@ -1257,7 +1258,7 @@ def from_magnetic_spacegroup( all_site_properties["magmom"] = all_magmoms - return cls(latt, all_sp, all_coords, site_properties=all_site_properties, labels=all_labels) + return cls(lattice, all_sp, all_coords, site_properties=all_site_properties, labels=all_labels) def unset_charge(self) -> None: """Reset the charge to None. E.g. to compute it dynamically based on oxidation states.""" @@ -1723,11 +1724,11 @@ def get_symmetric_neighbor_list( sgp = SpaceGroup(sg) ops = sgp.symmetry_ops - latt = self.lattice + lattice = self.lattice - if not sgp.is_compatible(latt): + if not sgp.is_compatible(lattice): raise ValueError( - f"Supplied lattice with parameters {latt.parameters} is incompatible with " + f"Supplied lattice with parameters {lattice.parameters} is incompatible with " f"supplied spacegroup {sgp.symbol}!" ) @@ -2027,8 +2028,8 @@ def get_all_neighbors_old(self, r, include_index=False, include_image=False, inc nmax = np.ceil(np.max(self.frac_coords, axis=0)) + maxr all_ranges = list(itertools.starmap(np.arange, zip(nmin, nmax))) - latt = self._lattice - matrix = latt.matrix + lattice = self._lattice + matrix = lattice.matrix neighbors = [[] for _ in range(len(self))] all_fcoords = np.mod(self.frac_coords, 1) coords_in_cell = np.dot(all_fcoords, matrix) @@ -2046,7 +2047,7 @@ def get_all_neighbors_old(self, r, include_index=False, include_image=False, inc nnsite = PeriodicSite( self[j].species, coords[j], - latt, + lattice, properties=self[j].properties, coords_are_cartesian=True, skip_checks=True, @@ -2103,19 +2104,22 @@ def get_sorted_structure(self, key: Callable | None = None, reverse: bool = Fals sites = sorted(self, key=key, reverse=reverse) return type(self).from_sites(sites, charge=self._charge, properties=self.properties) - def get_reduced_structure(self, reduction_algo: Literal["niggli", "LLL"] = "niggli") -> IStructure | Structure: + def get_reduced_structure(self, reduction_algo: Literal["niggli", "LLL"] = "niggli") -> Self: """Get a reduced structure. Args: reduction_algo ("niggli" | "LLL"): The lattice reduction algorithm to use. Defaults to "niggli". + + Returns: + Structure: Niggli- or LLL-reduced structure. """ if reduction_algo == "niggli": reduced_latt = self._lattice.get_niggli_reduced_lattice() elif reduction_algo == "LLL": reduced_latt = self._lattice.get_lll_reduced_lattice() else: - raise ValueError(f"Invalid {reduction_algo=}") + raise ValueError(f"Invalid {reduction_algo=}, must be 'niggli' or 'LLL'.") if reduced_latt != self.lattice: return type(self)( @@ -2130,7 +2134,12 @@ def get_reduced_structure(self, reduction_algo: Literal["niggli", "LLL"] = "nigg ) return self.copy() - def copy(self, site_properties=None, sanitize=False, properties=None) -> Structure: + def copy( + self, + site_properties: dict[str, Any] | None = None, + sanitize: bool = False, + properties: dict[str, Any] | None = None, + ) -> Structure: """Convenience method to get a copy of the structure, with options to add site properties. @@ -2238,7 +2247,7 @@ def interpolate( if not (interpolate_lattices or self.lattice == end_structure.lattice): raise ValueError("Structures with different lattices!") - images = np.arange(nimages + 1) / nimages if not isinstance(nimages, collections.abc.Iterable) else nimages + images = nimages if isinstance(nimages, collections.abc.Iterable) else np.arange(nimages + 1) / nimages # Check that both structures have the same species for idx, site in enumerate(self): @@ -2614,7 +2623,7 @@ def get_orderings(self, mode: Literal["enum", "sqs"] = "enum", **kwargs) -> list for idx in range(1, len(dists)): if dists[idx] - dists[idx - 1] > 0.1: unique_dists.append(dists[idx]) - clusters = {(i + 2): d + 0.01 for i, d in enumerate(unique_dists) if i < 2} + clusters = {(idx + 2): dist + 0.01 for idx, dist in enumerate(unique_dists) if idx < 2} kwargs["clusters"] = clusters return [run_mcsqs(self, **kwargs).bestsqs] raise ValueError("Invalid mode!") @@ -2639,7 +2648,7 @@ def as_dict(self, verbosity=1, fmt=None, **kwargs) -> dict[str, Any]: JSON-serializable dict representation. """ if fmt == "abivars": - """Returns a dictionary with the ABINIT variables.""" + # Returns a dictionary with the ABINIT variables from pymatgen.io.abinit.abiobjects import structure_to_abivars return structure_to_abivars(self, **kwargs) @@ -2690,7 +2699,7 @@ def as_dataframe(self): return df @classmethod - def from_dict(cls, dct: dict[str, Any], fmt: Literal["abivars"] | None = None) -> Structure: + def from_dict(cls, dct: dict[str, Any], fmt: Literal["abivars"] | None = None) -> Self: """Reconstitute a Structure object from a dict representation of Structure created using as_dict(). @@ -3503,7 +3512,7 @@ def get_boxed_structure( z_max, z_min = max(new_coords[:, 2]), min(new_coords[:, 2]) if x_max > a or x_min < 0 or y_max > b or y_min < 0 or z_max > c or z_min < 0: raise ValueError("Molecule crosses boundary of box") - if len(all_coords) == 0: + if not all_coords: break distances = lattice.get_all_distances( lattice.get_fractional_coords(new_coords), @@ -3588,7 +3597,7 @@ def to(self, filename: str = "", fmt: str = "") -> str | None: writer: Any if fmt == "xyz" or fnmatch(filename.lower(), "*.xyz*"): writer = XYZ(self) - elif any(fmt == ext or fnmatch(filename.lower(), f"*.{ext}*") for ext in ["gjf", "g03", "g09", "com", "inp"]): + elif any(fmt == ext or fnmatch(filename.lower(), f"*.{ext}*") for ext in ("gjf", "g03", "g09", "com", "inp")): writer = GaussianInput(self) elif fmt == "json" or fnmatch(filename, "*.json*") or fnmatch(filename, "*.mson*"): json_str = json.dumps(self.as_dict()) @@ -3596,7 +3605,7 @@ def to(self, filename: str = "", fmt: str = "") -> str | None: with zopen(filename, mode="wt", encoding="utf8") as file: file.write(json_str) return json_str - elif fmt in ("yaml", "yml") or fnmatch(filename, "*.yaml*") or fnmatch(filename, "*.yml*"): + elif fmt in {"yaml", "yml"} or fnmatch(filename, "*.yaml*") or fnmatch(filename, "*.yml*"): yaml = YAML() str_io = StringIO() yaml.dump(self.as_dict(), str_io) @@ -3655,7 +3664,7 @@ def from_str( # type: ignore[override] return cls.from_sites(mol, properties=mol.properties) @classmethod - def from_file(cls, filename): + def from_file(cls, filename: str | Path) -> Self | None: # type: ignore[override] """Reads a molecule from a file. Supported formats include xyz, gaussian input (gjf|g03|g09|com|inp), Gaussian output (.out|and pymatgen's JSON-serialized molecules. Using openbabel, @@ -3686,8 +3695,7 @@ def from_file(cls, filename): return cls.from_str(contents, fmt="yaml") from pymatgen.io.babel import BabelMolAdaptor - match = re.search(r"\.(pdb|mol|mdl|sdf|sd|ml2|sy2|mol2|cml|mrv)", filename.lower()) - if match: + if match := re.search(r"\.(pdb|mol|mdl|sdf|sd|ml2|sy2|mol2|cml|mrv)", filename.lower()): new = BabelMolAdaptor.from_file(filename, match.group(1)).pymatgen_mol new.__class__ = cls return new @@ -3731,7 +3739,7 @@ def __init__( disordered structures. coords (Nx3 array): list of fractional/cartesian coordinates of each species. - charge (int): overall charge of the structure. Defaults to behavior + charge (float): overall charge of the structure. Defaults to behavior in SiteCollection where total charge is the sum of the oxidation states. validate_proximity (bool): Whether to check if there are sites @@ -3928,7 +3936,7 @@ def replace( coords_are_cartesian: bool = False, properties: dict | None = None, label: str | None = None, - ) -> None: + ) -> Self: """Replace a single site. Takes either a species or a dict of species and occupations. @@ -3941,6 +3949,9 @@ def replace( Defaults to False. properties (dict): Properties associated with the site. label (str): Label associated with the site. + + Returns: + Structure: self with replaced site. """ if coords is None: frac_coords = self[idx].frac_coords @@ -3952,7 +3963,9 @@ def replace( new_site = PeriodicSite(species, frac_coords, self._lattice, properties=properties, label=label) self.sites[idx] = new_site - def substitute(self, index: int, func_group: IMolecule | Molecule | str, bond_order: int = 1) -> None: + return self + + def substitute(self, index: int, func_group: IMolecule | Molecule | str, bond_order: int = 1) -> Self: """Substitute atom at index with a functional group. Args: @@ -3973,6 +3986,9 @@ def substitute(self, index: int, func_group: IMolecule | Molecule | str, bond_or bond_order (int): A specified bond order to calculate the bond length between the attached functional group and the nearest neighbor site. Defaults to 1. + + Returns: + Structure: self with functional group attached. """ # Find the nearest neighbor that is not a terminal atom. all_non_terminal_nn = [] @@ -3995,13 +4011,16 @@ def substitute(self, index: int, func_group: IMolecule | Molecule | str, bond_or # Pass value of functional group--either from user-defined or from # functional.json - if not isinstance(func_group, Molecule): + if isinstance(func_group, Molecule): + fgroup = func_group + + else: # Check to see whether the functional group is in database. if func_group not in FunctionalGroups: - raise RuntimeError("Can't find functional group in list. Provide explicit coordinate instead") + raise ValueError( + f"Can't find functional group {func_group!r} in list. Provide explicit coordinates instead" + ) fgroup = FunctionalGroups[func_group] - else: - fgroup = func_group # If a bond length can be found, modify func_grp so that the X-group # bond length is equal to the bond length. @@ -4046,11 +4065,16 @@ def substitute(self, index: int, func_group: IMolecule | Molecule | str, bond_or s_new = PeriodicSite(site.species, site.coords, self.lattice, coords_are_cartesian=True, label=site.label) self._sites.append(s_new) - def remove_species(self, species: Sequence[SpeciesLike]) -> None: + return self + + def remove_species(self, species: Sequence[SpeciesLike]) -> Self: """Remove all occurrences of several species from a structure. Args: species: Sequence of species to remove, e.g., ["Li", "Na"]. + + Returns: + Structure: self with species removed. """ new_sites = [] species = [get_el_sp(s) for s in species] @@ -4069,15 +4093,22 @@ def remove_species(self, species: Sequence[SpeciesLike]) -> None: ) self.sites = new_sites - def remove_sites(self, indices: Sequence[int | None]) -> None: + return self + + def remove_sites(self, indices: Sequence[int | None]) -> Self: """Delete sites with at indices. Args: indices: Sequence of indices of sites to delete. + + Returns: + Structure: self with sites removed. """ self.sites = [site for idx, site in enumerate(self) if idx not in indices] - def apply_operation(self, symmop: SymmOp, fractional: bool = False) -> Structure: + return self + + def apply_operation(self, symmop: SymmOp, fractional: bool = False) -> Self: """Apply a symmetry operation to the structure in place and return the modified structure. The lattice is operated on by the rotation matrix only. Coords are operated in full and then transformed to the new lattice. @@ -4137,7 +4168,7 @@ def apply_strain(self, strain: ArrayLike, inplace: bool = True) -> Structure: Structure copy. Defaults to True. Returns: - Structure: Structure with strain applied. + Structure: self if inplace=True else new structure with strain applied. """ strain_matrix = (1 + np.array(strain)) * np.eye(3) new_lattice = Lattice(np.dot(self._lattice.matrix.T, strain_matrix).T) @@ -4145,7 +4176,7 @@ def apply_strain(self, strain: ArrayLike, inplace: bool = True) -> Structure: struct.lattice = new_lattice return struct - def sort(self, key: Callable | None = None, reverse: bool = False) -> Structure: + def sort(self, key: Callable | None = None, reverse: bool = False) -> Self: """Sort a structure in place. The parameters have the same meaning as in list.sort(). By default, sites are sorted by the electronegativity of the species. The difference between this method and @@ -4160,14 +4191,14 @@ def sort(self, key: Callable | None = None, reverse: bool = False) -> Structure: as if each comparison were reversed. Returns: - Structure: Sorted structure. + Structure: self sorted. """ self._sites.sort(key=key, reverse=reverse) return self def translate_sites( self, indices: int | Sequence[int], vector: ArrayLike, frac_coords: bool = True, to_unit_cell: bool = True - ) -> Structure: + ) -> Self: """Translate specific sites by some vector, keeping the sites within the unit cell. Modifies the structure in place. @@ -4205,7 +4236,7 @@ def rotate_sites( axis: ArrayLike | None = None, anchor: ArrayLike | None = None, to_unit_cell: bool = True, - ) -> Structure: + ) -> Self: """Rotate specific sites by some angle around vector at anchor. Modifies the structure in place. @@ -4252,7 +4283,7 @@ def rotate_sites( return self - def perturb(self, distance: float, min_distance: float | None = None) -> Structure: + def perturb(self, distance: float, min_distance: float | None = None) -> Self: """Performs a random perturbation of the sites in a structure to break symmetries. Modifies the structure in place. @@ -4308,8 +4339,8 @@ def make_supercell(self, scaling_matrix: ArrayLike, to_unit_cell: bool = True, i Structure: self if in_place is True else self.copy() after making supercell """ # TODO (janosh) maybe default in_place to False after a depreciation period - struct = self if in_place else self.copy() - supercell = struct * scaling_matrix + struct: Structure = self if in_place else self.copy() + supercell: Structure = struct * scaling_matrix if to_unit_cell: for site in supercell: site.to_unit_cell(in_place=True) @@ -4318,7 +4349,7 @@ def make_supercell(self, scaling_matrix: ArrayLike, to_unit_cell: bool = True, i return struct - def scale_lattice(self, volume: float) -> Structure: + def scale_lattice(self, volume: float) -> Self: """Performs a scaling of the lattice vectors so that length proportions and angles are preserved. @@ -4332,7 +4363,7 @@ def scale_lattice(self, volume: float) -> Structure: return self - def merge_sites(self, tol: float = 0.01, mode: Literal["sum", "delete", "average"] = "sum") -> Structure: + def merge_sites(self, tol: float = 0.01, mode: Literal["sum", "delete", "average"] = "sum") -> Self: """Merges sites (adding occupancies) within tol of each other. Removes site properties. @@ -4376,7 +4407,7 @@ def merge_sites(self, tol: float = 0.01, mode: Literal["sum", "delete", "average self._sites = sites return self - def set_charge(self, new_charge: float = 0.0) -> Structure: + def set_charge(self, new_charge: float = 0.0) -> Self: """Sets the overall structure charge. Args: @@ -4447,7 +4478,7 @@ def calculate(self, calculator: str | Calculator = "m3gnet", verbose: bool = Fal return self._calculate(calculator, verbose=verbose) @classmethod - def from_prototype(cls, prototype: str, species: Sequence, **kwargs) -> Structure: + def from_prototype(cls, prototype: str, species: Sequence, **kwargs) -> Self: """Method to rapidly construct common prototype structures. Args: @@ -4479,22 +4510,23 @@ def from_prototype(cls, prototype: str, species: Sequence, **kwargs) -> Structur return Structure.from_spacegroup( "Pm-3m", Lattice.cubic(kwargs["a"]), species, [[0, 0, 0], [0.5, 0.5, 0.5], [0.5, 0.5, 0]] ) - if prototype in ("cscl"): + if prototype == "cscl": return Structure.from_spacegroup( "Pm-3m", Lattice.cubic(kwargs["a"]), species, [[0, 0, 0], [0.5, 0.5, 0.5]] ) - if prototype in ("fluorite", "caf2"): + if prototype in {"fluorite", "caf2"}: return Structure.from_spacegroup( "Fm-3m", Lattice.cubic(kwargs["a"]), species, [[0, 0, 0], [1 / 4, 1 / 4, 1 / 4]] ) - if prototype in ("antifluorite"): + if prototype == "antifluorite": return Structure.from_spacegroup( "Fm-3m", Lattice.cubic(kwargs["a"]), species, [[1 / 4, 1 / 4, 1 / 4], [0, 0, 0]] ) - if prototype in ("zincblende"): + if prototype == "zincblende": return Structure.from_spacegroup( "F-43m", Lattice.cubic(kwargs["a"]), species, [[0, 0, 0], [1 / 4, 1 / 4, 3 / 4]] ) + except KeyError as exc: raise ValueError(f"Required parameter {exc} not specified as a kwargs!") from exc raise ValueError(f"Unsupported {prototype=}!") @@ -4783,8 +4815,8 @@ def rotate_sites( for idx in indices: site = self[idx] - s = ((np.dot(rm, (site.coords - anchor).T)).T + anchor).ravel() - new_site = Site(site.species, s, properties=site.properties, label=site.label) + coords = ((np.dot(rm, (site.coords - anchor).T)).T + anchor).ravel() + new_site = Site(site.species, coords, properties=site.properties, label=site.label) self[idx] = new_site return self @@ -4979,5 +5011,5 @@ class StructureError(Exception): """ -with open(os.path.join(os.path.dirname(__file__), "func_groups.json")) as file: +with open(os.path.join(os.path.dirname(__file__), "func_groups.json"), encoding="utf-8") as file: FunctionalGroups = {k: Molecule(v["species"], v["coords"]) for k, v in json.load(file).items()} diff --git a/pymatgen/core/surface.py b/pymatgen/core/surface.py index 087f1e26651..03af6ad463d 100644 --- a/pymatgen/core/surface.py +++ b/pymatgen/core/surface.py @@ -1,4 +1,6 @@ -"""This module implements representations of slabs and surfaces + algorithms for generating them. +"""This module implements representation of Slab, SlabGenerator +for generating Slabs, ReconstructionGenerator to generate +reconstructed Slabs, and some related utility functions. If you use this module, please consider citing the following work: @@ -20,8 +22,8 @@ import os import warnings from functools import reduce -from math import gcd -from typing import TYPE_CHECKING +from math import gcd, isclose +from typing import TYPE_CHECKING, cast import numpy as np from monty.fractions import lcm @@ -35,6 +37,13 @@ from pymatgen.util.due import Doi, due if TYPE_CHECKING: + from collections.abc import Sequence + from typing import Any + + from numpy.typing import ArrayLike + from typing_extensions import Self + + from pymatgen.core.composition import Element, Species from pymatgen.symmetry.groups import CrystalSystem __author__ = "Richard Tran, Wenhao Sun, Zihan Xu, Shyue Ping Ong" @@ -52,51 +61,42 @@ class Slab(Structure): - """Subclass of Structure representing a Slab. Implements additional + """Class to hold information for a Slab, with additional attributes pertaining to slabs, but the init method does not - actually implement any algorithm that creates a slab. This is a - DUMMY class who's init method only holds information about the - slab. Also has additional methods that returns other information - about a slab such as the surface area, normal, and atom adsorption. + actually create a slab. Also has additional methods that returns other information + about a Slab such as the surface area, normal, and atom adsorption. Note that all Slabs have the surface normal oriented perpendicular to the a and b lattice vectors. This means the lattice vectors a and b are in the surface plane and the c vector is out of the surface plane (though not necessarily perpendicular to the surface). - - Attributes: - miller_index (tuple): Miller index of plane parallel to surface. - scale_factor (float): Final computed scale factor that brings the parent cell to the surface cell. - shift (float): The shift value in Angstrom that indicates how much this slab has been shifted. """ def __init__( self, - lattice, - species, - coords, - miller_index, - oriented_unit_cell, - shift, - scale_factor, - reorient_lattice=True, - validate_proximity=False, - to_unit_cell=False, - reconstruction=None, - coords_are_cartesian=False, - site_properties=None, - energy=None, - properties=None, + lattice: Lattice | np.ndarray, + species: Sequence[Any], + coords: np.ndarray, + miller_index: tuple[int, int, int], + oriented_unit_cell: Structure, + shift: float, + scale_factor: np.ndarray, + reorient_lattice: bool = True, + validate_proximity: bool = False, + to_unit_cell: bool = False, + reconstruction: str | None = None, + coords_are_cartesian: bool = False, + site_properties: dict | None = None, + energy: float | None = None, ) -> None: - """Makes a Slab structure, a structure object with additional information - and methods pertaining to slabs. + """A Structure object with additional information + and methods pertaining to Slabs. Args: lattice (Lattice/3x3 array): The lattice, either as a - pymatgen.core.Lattice or - simply as any 2D array. Each row should correspond to a lattice - vector. E.g., [[10,0,0], [20,10,0], [0,0,30]] specifies a - lattice with lattice vectors [10,0,0], [20,10,0] and [0,0,30]. + pymatgen.core.Lattice or simply as any 2D array. + Each row should correspond to a lattice + vector. E.g., [[10,0,0], [20,10,0], [0,0,30]]. species ([Species]): Sequence of species on each site. Can take in flexible input, including: @@ -105,10 +105,10 @@ def __init__( e.g., (3, 56, ...) or actual Element or Species objects. ii. List of dict of elements/species and occupancies, e.g., - [{"Fe" : 0.5, "Mn":0.5}, ...]. This allows the setup of + [{"Fe": 0.5, "Mn": 0.5}, ...]. This allows the setup of disordered structures. coords (Nx3 array): list of fractional/cartesian coordinates of each species. - miller_index ([h, k, l]): Miller index of plane parallel to + miller_index (tuple[h, k, l]): Miller index of plane parallel to surface. Note that this is referenced to the input structure. If you need this to be based on the conventional cell, you should supply the conventional structure. @@ -132,16 +132,15 @@ def __init__( have to be the same length as the atomic species and fractional_coords. Defaults to None for no properties. energy (float): A value for the energy. - properties (dict): dictionary containing properties associated - with the whole slab. """ self.oriented_unit_cell = oriented_unit_cell - self.miller_index = tuple(miller_index) + self.miller_index = miller_index self.shift = shift self.reconstruction = reconstruction - self.scale_factor = np.array(scale_factor) + self.scale_factor = scale_factor self.energy = energy self.reorient_lattice = reorient_lattice + if self.reorient_lattice: if coords_are_cartesian: coords = lattice.get_fractional_coords(coords) @@ -165,20 +164,298 @@ def __init__( site_properties=site_properties, ) - def get_orthogonal_c_slab(self): - """This method returns a Slab where the normal (c lattice vector) is - "forced" to be exactly orthogonal to the surface a and b lattice - vectors. **Note that this breaks inherent symmetries in the slab.** + def __str__(self) -> str: + outs = [ + f"Slab Summary ({self.composition.formula})", + f"Reduced Formula: {self.composition.reduced_formula}", + f"Miller index: {self.miller_index}", + f"Shift: {self.shift:.4f}, Scale Factor: {self.scale_factor}", + f"abc : {' '.join(f'{i:0.6f}'.rjust(10) for i in self.lattice.abc)}", + f"angles: {' '.join(f'{i:0.6f}'.rjust(10) for i in self.lattice.angles)}", + f"Sites ({len(self)})", + ] + + for idx, site in enumerate(self): + outs.append(f"{idx + 1} {site.species_string} {' '.join(f'{j:0.6f}'.rjust(12) for j in site.frac_coords)}") + + return "\n".join(outs) + + @property + def center_of_mass(self) -> np.ndarray: + """The center of mass of the Slab in fractional coordinates.""" + weights = [site.species.weight for site in self] + return np.average(self.frac_coords, weights=weights, axis=0) + + @property + def dipole(self) -> np.ndarray: + """The dipole moment of the Slab in the direction of the surface normal. + + Note that the Slab must be oxidation state decorated for this to work properly. + Otherwise, the Slab will always have a dipole moment of 0. + """ + centroid = np.sum(self.cart_coords, axis=0) / len(self) + + dipole = np.zeros(3) + for site in self: + charge = sum(getattr(sp, "oxi_state", 0) * amt for sp, amt in site.species.items()) + dipole += charge * np.dot(site.coords - centroid, self.normal) * self.normal + return dipole + + @property + def normal(self) -> np.ndarray: + """The surface normal vector of the Slab, normalized to unit length.""" + normal = np.cross(self.lattice.matrix[0], self.lattice.matrix[1]) + normal /= np.linalg.norm(normal) + return normal + + @property + def surface_area(self) -> float: + """The surface area of the Slab.""" + matrix = self.lattice.matrix + return np.linalg.norm(np.cross(matrix[0], matrix[1])) + + @classmethod + def from_dict(cls, dct: dict[str, Any]) -> Self: # type: ignore[override] + """ + Args: + dct: dict. + + Returns: + Creates slab from dict. + """ + lattice = Lattice.from_dict(dct["lattice"]) + sites = [PeriodicSite.from_dict(sd, lattice) for sd in dct["sites"]] + struct = Structure.from_sites(sites) + + return Slab( + lattice=lattice, + species=struct.species_and_occu, + coords=struct.frac_coords, + miller_index=dct["miller_index"], + oriented_unit_cell=Structure.from_dict(dct["oriented_unit_cell"]), + shift=dct["shift"], + scale_factor=np.array(dct["scale_factor"]), + site_properties=struct.site_properties, + energy=dct["energy"], + ) + + def as_dict(self, **kwargs) -> dict: # type: ignore[override] + """MSONable dict.""" + dct = super().as_dict(**kwargs) + dct["@module"] = type(self).__module__ + dct["@class"] = type(self).__name__ + dct["oriented_unit_cell"] = self.oriented_unit_cell.as_dict() + dct["miller_index"] = self.miller_index + dct["shift"] = self.shift + dct["scale_factor"] = self.scale_factor.tolist() # np.ndarray is not JSON serializable + dct["reconstruction"] = self.reconstruction + dct["energy"] = self.energy + return dct + + def copy(self, site_properties: dict[str, Any] | None = None) -> Slab: # type: ignore[override] + """Get a copy of the structure, with options to update site properties. + + Args: + site_properties (dict): Properties to update. The + properties are specified in the same way as the constructor, + i.e., as a dict of the form {property: [values]}. + + Returns: + A copy of the Structure, with optionally new site_properties + """ + props = self.site_properties + if site_properties: + props.update(site_properties) + + return Slab( + self.lattice, + self.species_and_occu, + self.frac_coords, + self.miller_index, + self.oriented_unit_cell, + self.shift, + self.scale_factor, + site_properties=props, + reorient_lattice=self.reorient_lattice, + ) + + def is_symmetric(self, symprec: float = 0.1) -> bool: + """Check if Slab is symmetric, i.e., contains inversion, mirror on (hkl) plane, + or screw axis (rotation and translation) about [hkl]. + + Args: + symprec (float): Symmetry precision used for SpaceGroup analyzer. + + Returns: + bool: Whether surfaces are symmetric. + """ + spg_analyzer = SpacegroupAnalyzer(self, symprec=symprec) + symm_ops = spg_analyzer.get_point_group_operations() + + # Check for inversion symmetry. Or if sites from surface (a) can be translated + # to surface (b) along the [hkl]-axis, surfaces are symmetric. Or because the + # two surfaces of our slabs are always parallel to the (hkl) plane, + # any operation where there's an (hkl) mirror plane has surface symmetry + return ( + spg_analyzer.is_laue() + or any(op.translation_vector[2] != 0 for op in symm_ops) + or any(np.all(op.rotation_matrix[2] == np.array([0, 0, -1])) for op in symm_ops) + ) + + def is_polar(self, tol_dipole_per_unit_area: float = 1e-3) -> bool: + """Check if the Slab is polar by computing the normalized dipole per unit area. + Normalized dipole per unit area is used as it is more reliable than + using the absolute value, which varies with surface area. + + Note that the Slab must be oxidation state decorated for this to work properly. + Otherwise, the Slab will always have a dipole moment of 0. + + Args: + tol_dipole_per_unit_area (float): A tolerance above which the Slab is + considered polar. + """ + dip_per_unit_area = self.dipole / self.surface_area + return np.linalg.norm(dip_per_unit_area) > tol_dipole_per_unit_area + + def get_surface_sites(self, tag: bool = False) -> dict[str, list]: + """Returns the surface sites and their indices in a dictionary. + Useful for analysis involving broken bonds and for finding adsorption sites. + + The oriented unit cell of the slab will determine the + coordination number of a typical site. + We use VoronoiNN to determine the coordination number of sites. + Due to the pathological error resulting from some surface sites in the + VoronoiNN, we assume any site that has this error is a surface + site as well. This will only work for single-element systems for now. + + Args: + tag (bool): Option to adds site attribute "is_surfsite" (bool) + to all sites of slab. Defaults to False + + Returns: + A dictionary grouping sites on top and bottom of the slab together. + {"top": [sites with indices], "bottom": [sites with indices]} + + Todo: + Is there a way to determine site equivalence between sites in a slab + and bulk system? This would allow us get the coordination number of + a specific site for multi-elemental systems or systems with more + than one inequivalent site. This will allow us to use this for + compound systems. + """ + from pymatgen.analysis.local_env import VoronoiNN + + # Get a dictionary of coordination numbers for each distinct site in the structure + spg_analyzer = SpacegroupAnalyzer(self.oriented_unit_cell) + u_cell = spg_analyzer.get_symmetrized_structure() + cn_dict: dict = {} + voronoi_nn = VoronoiNN() + unique_indices = [equ[0] for equ in u_cell.equivalent_indices] + + for idx in unique_indices: + el = u_cell[idx].species_string + if el not in cn_dict: + cn_dict[el] = [] + # Since this will get the CN as a result of the weighted polyhedra, the + # slightest difference in CN will indicate a different environment for a + # species, eg. bond distance of each neighbor or neighbor species. The + # decimal place to get some CN to be equal. + cn = voronoi_nn.get_cn(u_cell, idx, use_weights=True) + cn = float(f"{round(cn, 5):.5f}") + if cn not in cn_dict[el]: + cn_dict[el].append(cn) + + voronoi_nn = VoronoiNN() + + surf_sites_dict: dict = {"top": [], "bottom": []} + properties: list = [] + for idx, site in enumerate(self): + # Determine if site is closer to the top or bottom of the slab + is_top: bool = site.frac_coords[2] > self.center_of_mass[2] + + try: + # A site is a surface site, if its environment does + # not fit the environment of other sites + cn = float(f"{round(voronoi_nn.get_cn(self, idx, use_weights=True), 5):.5f}") + if cn < min(cn_dict[site.species_string]): + properties.append(True) + key = "top" if is_top else "bottom" + surf_sites_dict[key].append([site, idx]) + else: + properties.append(False) + except RuntimeError: + # or if pathological error is returned, indicating a surface site + properties.append(True) + key = "top" if is_top else "bottom" + surf_sites_dict[key].append([site, idx]) + + if tag: + self.add_site_property("is_surf_site", properties) + return surf_sites_dict + + def get_symmetric_site( + self, + point: ArrayLike, + cartesian: bool = False, + ) -> ArrayLike: + """This method uses symmetry operations to find an equivalent site on + the other side of the slab. Works mainly for slabs with Laue symmetry. + + This is useful for retaining the non-polar and + symmetric properties of a slab when creating adsorbed + structures or symmetric reconstructions. + + TODO (@DanielYang59): use "site" over "point" as arg name for consistency + + Args: + point (ArrayLike): Fractional coordinate of the original site. + cartesian (bool): Use Cartesian coordinates. + + Returns: + ArrayLike: Fractional coordinate. A site equivalent to the + original site, but on the other side of the slab + """ + spg_analyzer = SpacegroupAnalyzer(self) + ops = spg_analyzer.get_symmetry_operations(cartesian=cartesian) + + # Each operation on a site will return an equivalent site. + # We want to find the site on the other side of the slab. + for op in ops: + slab = self.copy() + site_other = op.operate(point) + if f"{site_other[2]:.6f}" == f"{point[2]:.6f}": + continue + + # Add dummy sites to check if the overall structure is symmetric + slab.append("O", point, coords_are_cartesian=cartesian) + slab.append("O", site_other, coords_are_cartesian=cartesian) + if SpacegroupAnalyzer(slab).is_laue(): + break + + # If not symmetric, remove the two added + # sites and try another symmetry operator + slab.remove_sites([len(slab) - 1]) + slab.remove_sites([len(slab) - 1]) + + return site_other + + def get_orthogonal_c_slab(self) -> Slab: + """Generate a Slab where the normal (c lattice vector) is + forced to be orthogonal to the surface a and b lattice vectors. + + **Note that this breaks inherent symmetries in the slab.** + It should be pointed out that orthogonality is not required to get good surface energies, but it can be useful in cases where the slabs are subsequently used for postprocessing of some kind, e.g. generating - GBs or interfaces. + grain boundaries or interfaces. """ a, b, c = self.lattice.matrix - new_c = np.cross(a, b) - new_c /= np.linalg.norm(new_c) - new_c = np.dot(c, new_c) * new_c + _new_c = np.cross(a, b) + _new_c /= np.linalg.norm(_new_c) + new_c = np.dot(c, _new_c) * _new_c new_latt = Lattice([a, b, new_c]) + return Slab( lattice=new_latt, species=self.species_and_occu, @@ -193,21 +470,31 @@ def get_orthogonal_c_slab(self): site_properties=self.site_properties, ) - def get_tasker2_slabs(self, tol: float = 0.01, same_species_only=True): + def get_tasker2_slabs( + self, + tol: float = 0.01, + same_species_only: bool = True, + ) -> list[Slab]: """Get a list of slabs that have been Tasker 2 corrected. Args: - tol (float): Tolerance to determine if atoms are within same plane. - This is a fractional tolerance, not an absolute one. - same_species_only (bool): If True, only that are of the exact same - species as the atom at the outermost surface are considered for - moving. Otherwise, all atoms regardless of species that is - within tol are considered for moving. Default is True (usually - the desired behavior). + tol (float): Fractional tolerance to determine if atoms are within same plane. + same_species_only (bool): If True, only those are of the exact same + species as the atom at the outermost surface are considered for moving. + Otherwise, all atoms regardless of species within tol are considered for moving. + Default is True (usually the desired behavior). Returns: - list[Slab]: tasker 2 corrected slabs. + list[Slab]: Tasker 2 corrected slabs. """ + + def get_equi_index(site: PeriodicSite) -> int: + """Get the index of the equivalent site for a given site.""" + for idx, equi_sites in enumerate(symm_structure.equivalent_sites): + if site in equi_sites: + return idx + raise ValueError("Cannot determine equi index!") + sites = list(self.sites) slabs = [] @@ -219,14 +506,8 @@ def get_tasker2_slabs(self, tol: float = 0.01, same_species_only=True): n_layers_slab = int(round((sorted_csites[-1].c - sorted_csites[0].c) * n_layers_total)) slab_ratio = n_layers_slab / n_layers_total - spga = SpacegroupAnalyzer(self) - symm_structure = spga.get_symmetrized_structure() - - def equi_index(site) -> int: - for idx, equi_sites in enumerate(symm_structure.equivalent_sites): - if site in equi_sites: - return idx - raise ValueError("Cannot determine equi index!") + spg_analyzer = SpacegroupAnalyzer(self) + symm_structure = spg_analyzer.get_symmetrized_structure() for surface_site, shift in [(sorted_csites[0], slab_ratio), (sorted_csites[-1], -slab_ratio)]: to_move = [] @@ -240,9 +521,9 @@ def equi_index(site) -> int: fixed.append(site) # Sort and group the sites by the species and symmetry equivalence - to_move = sorted(to_move, key=equi_index) + to_move = sorted(to_move, key=get_equi_index) - grouped = [list(sites) for k, sites in itertools.groupby(to_move, key=equi_index)] + grouped = [list(sites) for k, sites in itertools.groupby(to_move, key=get_equi_index)] if len(to_move) == 0 or any(len(g) % 2 != 0 for g in grouped): warnings.warn( @@ -260,15 +541,15 @@ def equi_index(site) -> int: species = [site.species for site in fixed] frac_coords = [site.frac_coords for site in fixed] - for s in to_move: - species.append(s.species) + for struct_matcher in to_move: + species.append(struct_matcher.species) for group in selection: - if s in group: - frac_coords.append(s.frac_coords) + if struct_matcher in group: + frac_coords.append(struct_matcher.frac_coords) break else: # Move unselected atom to the opposite surface. - frac_coords.append(s.frac_coords + [0, 0, shift]) # noqa: RUF005 + frac_coords.append(struct_matcher.frac_coords + [0, 0, shift]) # noqa: RUF005 # sort by species to put all similar species together. sp_fcoord = sorted(zip(species, frac_coords), key=lambda x: x[0]) @@ -286,33 +567,10 @@ def equi_index(site) -> int: reorient_lattice=self.reorient_lattice, ) slabs.append(slab) - s = StructureMatcher() - return [ss[0] for ss in s.group_structures(slabs)] - - def is_symmetric(self, symprec: float = 0.1): - """Checks if surfaces are symmetric, i.e., contains inversion, mirror on (hkl) plane, - or screw axis (rotation and translation) about [hkl]. - - Args: - symprec (float): Symmetry precision used for SpaceGroup analyzer. - - Returns: - bool: Whether surfaces are symmetric. - """ - sg = SpacegroupAnalyzer(self, symprec=symprec) - symm_ops = sg.get_point_group_operations() - - # Check for inversion symmetry. Or if sites from surface (a) can be translated - # to surface (b) along the [hkl]-axis, surfaces are symmetric. Or because the - # two surfaces of our slabs are always parallel to the (hkl) plane, - # any operation where there's an (hkl) mirror plane has surface symmetry - return ( - sg.is_laue() - or any(op.translation_vector[2] != 0 for op in symm_ops) - or any(np.all(op.rotation_matrix[2] == np.array([0, 0, -1])) for op in symm_ops) - ) + struct_matcher = StructureMatcher() + return [ss[0] for ss in struct_matcher.group_structures(slabs)] - def get_sorted_structure(self, key=None, reverse=False): + def get_sorted_structure(self, key=None, reverse: bool = False) -> Slab: """Get a sorted copy of the structure. The parameters have the same meaning as in list.sort. By default, sites are sorted by the electronegativity of the species. Note that Slab has to override this @@ -339,521 +597,460 @@ def get_sorted_structure(self, key=None, reverse=False): reorient_lattice=self.reorient_lattice, ) - def copy(self, site_properties=None, sanitize=False): - """Convenience method to get a copy of the structure, with options to add - site properties. + def add_adsorbate_atom( + self, + indices: list[int], + species: str | Element | Species, + distance: float, + specie: Species | Element | str | None = None, + ) -> Self: + """Add adsorbate onto the Slab, along the c lattice vector. Args: - site_properties (dict): Properties to add or override. The - properties are specified in the same way as the constructor, - i.e., as a dict of the form {property: [values]}. The - properties should be in the order of the *original* structure - if you are performing sanitization. - sanitize (bool): If True, this method will return a sanitized - structure. Sanitization performs a few things: (i) The sites are - sorted by electronegativity, (ii) a LLL lattice reduction is - carried out to obtain a relatively orthogonalized cell, - (iii) all fractional coords for sites are mapped into the - unit cell. + indices (list[int]): Indices of sites on which to put the adsorbate. + Adsorbate will be placed relative to the center of these sites. + species (str | Element | Species): The species to add. + distance (float): between centers of the adsorbed atom and the + given site in Angstroms, along the c lattice vector. + specie: Deprecated argument in #3691. Use 'species' instead. Returns: - A copy of the Structure, with optionally new site_properties and - optionally sanitized. - """ - props = self.site_properties - if site_properties: - props.update(site_properties) - return Slab( - self.lattice, - self.species_and_occu, - self.frac_coords, - self.miller_index, - self.oriented_unit_cell, - self.shift, - self.scale_factor, - site_properties=props, - reorient_lattice=self.reorient_lattice, - ) - - @property - def dipole(self): - """Calculates the dipole of the Slab in the direction of the surface - normal. Note that the Slab must be oxidation state-decorated for this - to work properly. Otherwise, the Slab will always have a dipole of 0. + Slab: self with adsorbed atom. """ - dipole = np.zeros(3) - mid_pt = np.sum(self.cart_coords, axis=0) / len(self) - normal = self.normal - for site in self: - charge = sum(getattr(sp, "oxi_state", 0) * amt for sp, amt in site.species.items()) - dipole += charge * np.dot(site.coords - mid_pt, normal) * normal - return dipole + # Check if deprecated argument is used + if specie is not None: + warnings.warn("The argument 'specie' is deprecated. Use 'species' instead.", DeprecationWarning) + species = specie - def is_polar(self, tol_dipole_per_unit_area=1e-3) -> bool: - """Checks whether the surface is polar by computing the dipole per unit - area. Note that the Slab must be oxidation state-decorated for this - to work properly. Otherwise, the Slab will always be non-polar. + # Calculate target site as the center of sites + center = np.sum([self[idx].coords for idx in indices], axis=0) / len(indices) - Args: - tol_dipole_per_unit_area (float): A tolerance. If the dipole - magnitude per unit area is less than this value, the Slab is - considered non-polar. Defaults to 1e-3, which is usually - pretty good. Normalized dipole per unit area is used as it is - more reliable than using the total, which tends to be larger for - slabs with larger surface areas. - """ - dip_per_unit_area = self.dipole / self.surface_area - return np.linalg.norm(dip_per_unit_area) > tol_dipole_per_unit_area + coords = center + self.normal * distance - @property - def normal(self): - """Calculates the surface normal vector of the slab.""" - normal = np.cross(self.lattice.matrix[0], self.lattice.matrix[1]) - normal /= np.linalg.norm(normal) - return normal + self.append(species, coords, coords_are_cartesian=True) - @property - def surface_area(self): - """Calculates the surface area of the slab.""" - matrix = self.lattice.matrix - return np.linalg.norm(np.cross(matrix[0], matrix[1])) + return self - @property - def center_of_mass(self): - """Calculates the center of mass of the slab.""" - weights = [s.species.weight for s in self] - return np.average(self.frac_coords, weights=weights, axis=0) + def symmetrically_add_atom( + self, + species: str | Element | Species, + point: ArrayLike, + specie: str | Element | Species | None = None, + coords_are_cartesian: bool = False, + ) -> None: + """Add a species at a specified site in a slab. Will also add an + equivalent site on the other side of the slab to maintain symmetry. - def add_adsorbate_atom(self, indices, specie, distance) -> Slab: - """Gets the structure of single atom adsorption. - slab structure from the Slab class(in [0, 0, 1]). + TODO (@DanielYang59): use "site" over "point" as arg name for consistency Args: - indices ([int]): Indices of sites on which to put the adsorbate. - Absorbed atom will be displaced relative to the center of - these sites. - specie (Species/Element/str): adsorbed atom species - distance (float): between centers of the adsorbed atom and the - given site in Angstroms. - - Returns: - Slab: self with adsorbed atom. + species (str | Element | Species): The species to add. + point (ArrayLike): The coordinate of the target site. + specie: Deprecated argument name in #3691. Use 'species' instead. + coords_are_cartesian (bool): If the site is in Cartesian coordinates. """ - # Let's work in Cartesian coords - center = np.sum([self[idx].coords for idx in indices], axis=0) / len(indices) + # For now just use the species of the surface atom as the element to add - coords = center + self.normal * distance / np.linalg.norm(self.normal) + # Check if deprecated argument is used + if specie is not None: + warnings.warn("The argument 'specie' is deprecated. Use 'species' instead.", DeprecationWarning) + species = specie - self.append(specie, coords, coords_are_cartesian=True) + # Get the index of the equivalent site on the other side + equi_site = self.get_symmetric_site(point, cartesian=coords_are_cartesian) - return self + self.append(species, point, coords_are_cartesian=coords_are_cartesian) + self.append(species, equi_site, coords_are_cartesian=coords_are_cartesian) - def __str__(self) -> str: - def to_str(x) -> str: - return f"{x:0.6f}" + def symmetrically_remove_atoms(self, indices: list[int]) -> None: + """Remove sites from a list of indices. Will also remove the + equivalent site on the other side of the slab to maintain symmetry. - comp = self.composition - outs = [ - f"Slab Summary ({comp.formula})", - f"Reduced Formula: {comp.reduced_formula}", - f"Miller index: {self.miller_index}", - f"Shift: {self.shift:.4f}, Scale Factor: {self.scale_factor}", - f"abc : {' '.join(f'{i:0.6f}'.rjust(10) for i in self.lattice.abc)}", - f"angles: {' '.join(f'{i:0.6f}'.rjust(10) for i in self.lattice.angles)}", - f"Sites ({len(self)})", - ] + Args: + indices (list[int]): The indices of the sites to remove. - for idx, site in enumerate(self): - outs.append(f"{idx + 1} {site.species_string} {' '.join(f'{j:0.6f}'.rjust(12) for j in site.frac_coords)}") - return "\n".join(outs) + TODO(@DanielYang59): + 1. Reuse public method get_symmetric_site to get equi sites? + 2. If not 1, get_equi_sites has multiple nested loops + """ - def as_dict(self): - """MSONable dict.""" - dct = super().as_dict() - dct["@module"] = type(self).__module__ - dct["@class"] = type(self).__name__ - dct["oriented_unit_cell"] = self.oriented_unit_cell.as_dict() - dct["miller_index"] = self.miller_index - dct["shift"] = self.shift - dct["scale_factor"] = self.scale_factor.tolist() - dct["reconstruction"] = self.reconstruction - dct["energy"] = self.energy - return dct + def get_equi_sites(slab: Slab, sites: list[int]) -> list[int]: + """ + Get the indices of the equivalent sites of given sites. + + Parameters: + slab (Slab): The slab structure. + sites (list[int]): Original indices of sites. + + Returns: + list[int]: Indices of the equivalent sites. + """ + equi_sites = [] + + for pt in sites: + # Get the index of the original site + cart_point = slab.lattice.get_cartesian_coords(pt) + dist = [site.distance_from_point(cart_point) for site in slab] + site1 = dist.index(min(dist)) + + # Get the index of the equivalent site on the other side + for i, eq_sites in enumerate(slab.equivalent_sites): + if slab[site1] in eq_sites: + eq_indices = slab.equivalent_indices[i] + break + i1 = eq_indices[eq_sites.index(slab[site1])] + + for i2 in eq_indices: + if i2 == i1: + continue + if slab[i2].frac_coords[2] == slab[i1].frac_coords[2]: + continue + # Test site remove to see if it results in symmetric slab + slab = self.copy() + slab.remove_sites([i1, i2]) + if slab.is_symmetric(): + equi_sites.append(i2) + break + + return equi_sites + + # Generate the equivalent sites of the original sites + slab_copy = SpacegroupAnalyzer(self.copy()).get_symmetrized_structure() + sites = [slab_copy[i].frac_coords for i in indices] - @classmethod - def from_dict(cls, dct: dict) -> Slab: # type: ignore[override] - """:param dct: dict + equi_sites = get_equi_sites(slab_copy, sites) - Returns: - Creates slab from dict. - """ - lattice = Lattice.from_dict(dct["lattice"]) - sites = [PeriodicSite.from_dict(sd, lattice) for sd in dct["sites"]] - struct = Structure.from_sites(sites) + # Check if found any equivalent sites + if len(equi_sites) == len(indices): + self.remove_sites(indices) + self.remove_sites(equi_sites) - return Slab( - lattice=lattice, - species=struct.species_and_occu, - coords=struct.frac_coords, - miller_index=dct["miller_index"], - oriented_unit_cell=Structure.from_dict(dct["oriented_unit_cell"]), - shift=dct["shift"], - scale_factor=dct["scale_factor"], - site_properties=struct.site_properties, - energy=dct["energy"], - properties=dct.get("properties"), - ) + else: + warnings.warn("Equivalent sites could not be found for some indices. Surface unchanged.") - def get_surface_sites(self, tag=False): - """Returns the surface sites and their indices in a dictionary. The - oriented unit cell of the slab will determine the coordination number - of a typical site. We use VoronoiNN to determine the - coordination number of bulk sites and slab sites. Due to the - pathological error resulting from some surface sites in the - VoronoiNN, we assume any site that has this error is a surface - site as well. This will work for elemental systems only for now. Useful - for analysis involving broken bonds and for finding adsorption sites. - Args: - tag (bool): Option to adds site attribute "is_surfsite" (bool) - to all sites of slab. Defaults to False +def center_slab(slab: Slab) -> Slab: + """Relocate the Slab to the center such that its center + (the slab region) is close to z=0.5. - Returns: - A dictionary grouping sites on top and bottom of the slab together. - {"top": [sites with indices], "bottom": [sites with indices} + This makes it easier to find surface sites and apply + operations like doping. - Todo: - Is there a way to determine site equivalence between sites in a slab - and bulk system? This would allow us get the coordination number of - a specific site for multi-elemental systems or systems with more - than one inequivalent site. This will allow us to use this for - compound systems. - """ - from pymatgen.analysis.local_env import VoronoiNN + There are two possible cases: - # Get a dictionary of coordination numbers - # for each distinct site in the structure - spga = SpacegroupAnalyzer(self.oriented_unit_cell) - u_cell = spga.get_symmetrized_structure() - cn_dict = {} - voronoi_nn = VoronoiNN() - unique_indices = [equ[0] for equ in u_cell.equivalent_indices] + 1. When the slab region is completely positioned between + two vacuum layers in the cell but is not centered, we simply + shift the Slab to the center along z-axis. - for idx in unique_indices: - el = u_cell[idx].species_string - if el not in cn_dict: - cn_dict[el] = [] - # Since this will get the cn as a result of the weighted polyhedra, the - # slightest difference in cn will indicate a different environment for a - # species, eg. bond distance of each neighbor or neighbor species. The - # decimal place to get some cn to be equal. - cn = voronoi_nn.get_cn(u_cell, idx, use_weights=True) - cn = float(f"{round(cn, 5):.5f}") - if cn not in cn_dict[el]: - cn_dict[el].append(cn) + 2. If the Slab completely resides outside the cell either + from the bottom or the top, we iterate through all sites that + spill over and shift all sites such that it is now + on the other side. An edge case being, either the top + of the Slab is at z = 0 or the bottom is at z = 1. - voronoi_nn = VoronoiNN() + TODO (@DanielYang59): this should be a method for `Slab`? - surf_sites_dict, properties = {"top": [], "bottom": []}, [] - for idx, site in enumerate(self): - # Determine if site is closer to the top or bottom of the slab - top = site.frac_coords[2] > self.center_of_mass[2] + Args: + slab (Slab): The Slab to center. - try: - # A site is a surface site, if its environment does - # not fit the environment of other sites - cn = float(f"{round(voronoi_nn.get_cn(self, idx, use_weights=True), 5):.5f}") - if cn < min(cn_dict[site.species_string]): - properties.append(True) - key = "top" if top else "bottom" - surf_sites_dict[key].append([site, idx]) - else: - properties.append(False) - except RuntimeError: - # or if pathological error is returned, indicating a surface site - properties.append(True) - key = "top" if top else "bottom" - surf_sites_dict[key].append([site, idx]) + Returns: + Slab: The centered Slab. + """ + # Get all site indices + all_indices = list(range(len(slab))) + + # Get a reasonable cutoff radius to sample neighbors + bond_dists = sorted(nn[1] for nn in slab.get_neighbors(slab[0], 10) if nn[1] > 0) + # TODO (@DanielYang59): magic number for cutoff radius (would 3 be too large?) + cutoff_radius = bond_dists[0] * 3 + + # TODO (@DanielYang59): do we need the following complex method? + # Why don't we just calculate the center of the Slab and move it to z=0.5? + # Before moving we need to ensure there is only one Slab layer though + + # If structure is case 2, shift all the sites + # to the other side until it is case 1 + for site in slab: # DEBUG (@DanielYang59): Slab position changes during loop? + # DEBUG (@DanielYang59): sites below z=0 is not considered (only check coord > c) + if any(nn[1] >= slab.lattice.c for nn in slab.get_neighbors(site, cutoff_radius)): + # TODO (@DanielYang59): the magic offset "0.05" seems unnecessary, + # as the Slab would be centered later anyway + shift = 1 - site.frac_coords[2] + 0.05 + slab.translate_sites(all_indices, [0, 0, shift]) - if tag: - self.add_site_property("is_surf_site", properties) - return surf_sites_dict + # Now the slab is case 1, move it to the center + weights = [site.species.weight for site in slab] + center_of_mass = np.average(slab.frac_coords, weights=weights, axis=0) + shift = 0.5 - center_of_mass[2] - def get_symmetric_site(self, point, cartesian=False): - """This method uses symmetry operations to find equivalent sites on - both sides of the slab. Works mainly for slabs with Laue - symmetry. This is useful for retaining the non-polar and - symmetric properties of a slab when creating adsorbed - structures or symmetric reconstructions. + slab.translate_sites(all_indices, [0, 0, shift]) - Arg: - point: Fractional coordinate. + return slab - Returns: - point: Fractional coordinate. A point equivalent to the - parameter point, but on the other side of the slab - """ - sg = SpacegroupAnalyzer(self) - ops = sg.get_symmetry_operations(cartesian=cartesian) - # Each operation on a point will return an equivalent point. - # We want to find the point on the other side of the slab. - for op in ops: - slab = self.copy() - site2 = op.operate(point) - if f"{site2[2]:.6f}" == f"{point[2]:.6f}": - continue +def get_slab_regions( + slab: Slab, + blength: float = 3.5, +) -> list[tuple[float, float]]: + """Find the z-ranges for the slab region. - # Add dummy site to check the overall structure is symmetric - slab.append("O", point, coords_are_cartesian=cartesian) - slab.append("O", site2, coords_are_cartesian=cartesian) - sg = SpacegroupAnalyzer(slab) - if sg.is_laue(): - break + Useful for discerning where the slab ends and vacuum begins + if the slab is not fully within the cell. - # If not symmetric, remove the two added - # sites and try another symmetry operator - slab.remove_sites([len(slab) - 1]) - slab.remove_sites([len(slab) - 1]) + TODO (@DanielYang59): this should be a method for `Slab`? - return site2 + TODO (@DanielYang59): maybe project all z coordinates to 1D? - def symmetrically_add_atom(self, specie, point, coords_are_cartesian=False) -> None: - """Class method for adding a site at a specified point in a slab. - Will add the corresponding site on the other side of the - slab to maintain equivalent surfaces. + Args: + slab (Slab): The Slab to analyse. + blength (float): The bond length between atoms in Angstrom. + You generally want this value to be larger than the actual + bond length in order to find atoms that are part of the slab. + """ + frac_coords: list = [] # TODO (@DanielYang59): zip site and coords? + indices: list = [] - Arg: - specie (str): The specie to add - point (coords): The coordinate of the site in the slab to add. - coords_are_cartesian (bool): Is the point in Cartesian coordinates + all_indices: list = [] - Returns: - Slab: The modified slab - """ - # For now just use the species of the - # surface atom as the element to add + for site in slab: + neighbors = slab.get_neighbors(site, blength) + for nn in neighbors: + # TODO (@DanielYang59): use z coordinate (z<0) to check + # if a Slab is contiguous is suspicious (Slab could locate + # entirely below z=0) - # Get the index of the corresponding site at the bottom - point2 = self.get_symmetric_site(point, cartesian=coords_are_cartesian) + # Find sites with z < 0 (sites noncontiguous within cell) + if nn[0].frac_coords[2] < 0: + frac_coords.append(nn[0].frac_coords[2]) + indices.append(nn[-2]) - self.append(specie, point, coords_are_cartesian=coords_are_cartesian) - self.append(specie, point2, coords_are_cartesian=coords_are_cartesian) + if nn[-2] not in all_indices: + all_indices.append(nn[-2]) - def symmetrically_remove_atoms(self, indices) -> None: - """Class method for removing sites corresponding to a list of indices. - Will remove the corresponding site on the other side of the - slab to maintain equivalent surfaces. + # If slab is noncontiguous + if frac_coords: + # Locate the lowest site within the upper Slab + while frac_coords: + last_fcoords = copy.copy(frac_coords) + last_indices = copy.copy(indices) - Arg: - indices ([indices]): The indices of the sites - in the slab to remove. - """ - slab_copy = SpacegroupAnalyzer(self.copy()).get_symmetrized_structure() - points = [slab_copy[i].frac_coords for i in indices] - removal_list = [] - - for pt in points: - # Get the index of the original site on top - cart_point = slab_copy.lattice.get_cartesian_coords(pt) - dist = [site.distance_from_point(cart_point) for site in slab_copy] - site1 = dist.index(min(dist)) - - # Get the index of the corresponding site at the bottom - for i, eq_sites in enumerate(slab_copy.equivalent_sites): - if slab_copy[site1] in eq_sites: - eq_indices = slab_copy.equivalent_indices[i] - break - i1 = eq_indices[eq_sites.index(slab_copy[site1])] + site = slab[indices[frac_coords.index(min(frac_coords))]] + neighbors = slab.get_neighbors(site, blength, include_index=True, include_image=True) + frac_coords, indices = [], [] + for nn in neighbors: + if 1 > nn[0].frac_coords[2] > 0 and nn[0].frac_coords[2] < site.frac_coords[2]: + # Sites are noncontiguous within cell + frac_coords.append(nn[0].frac_coords[2]) + indices.append(nn[-2]) + if nn[-2] not in all_indices: + all_indices.append(nn[-2]) - for i2 in eq_indices: - if i2 == i1: - continue - if slab_copy[i2].frac_coords[2] == slab_copy[i1].frac_coords[2]: - continue - # Test site remove to see if it results in symmetric slab - slab = self.copy() - slab.remove_sites([i1, i2]) - if slab.is_symmetric(): - removal_list.extend([i1, i2]) - break + # Locate the highest site within the lower Slab + upper_fcoords: list = [] + for site in slab: + if all(nn.index not in all_indices for nn in slab.get_neighbors(site, blength)): + upper_fcoords.append(site.frac_coords[2]) + coords: list = copy.copy(frac_coords) if frac_coords else copy.copy(last_fcoords) + min_top = slab[last_indices[coords.index(min(coords))]].frac_coords[2] + return [(0, max(upper_fcoords)), (min_top, 1)] - # If expected, 2 atoms are removed per index - if len(removal_list) == 2 * len(indices): - self.remove_sites(removal_list) - else: - warnings.warn("Equivalent sites could not be found for removal for all indices. Surface unchanged.") + # If the entire slab region is within the cell, just + # set the range as the highest and lowest site in the Slab + sorted_sites = sorted(slab, key=lambda site: site.frac_coords[2]) + return [(sorted_sites[0].frac_coords[2], sorted_sites[-1].frac_coords[2])] class SlabGenerator: - """This class generates different slabs using shift values determined by where - a unique termination can be found along with other criteria such as where a + """Generate different slabs using shift values determined by where + a unique termination can be found, along with other criteria such as where a termination doesn't break a polyhedral bond. The shift value then indicates where the slab layer will begin and terminate in the slab-vacuum system. Attributes: - oriented_unit_cell (Structure): A unit cell of the parent structure with the miller - index of plane parallel to surface. + oriented_unit_cell (Structure): An oriented unit cell of the parent structure. parent (Structure): Parent structure from which Slab was derived. - lll_reduce (bool): Whether or not the slabs will be orthogonalized. - center_slab (bool): Whether or not the slabs will be centered between the vacuum layer. - slab_scale_factor (float): Final computed scale factor that brings the parent cell to the - surface cell. + lll_reduce (bool): Whether the slabs will be orthogonalized. + center_slab (bool): Whether the slabs will be centered in the slab-vacuum system. + slab_scale_factor (float): Scale factor that brings + the parent cell to the surface cell. miller_index (tuple): Miller index of plane parallel to surface. - min_slab_size (float): Minimum size in angstroms of layers containing atoms. - min_vac_size (float): Minimum size in angstroms of layers containing vacuum. + min_slab_size (float): Minimum size of layers containing atoms, in angstroms. + min_vac_size (float): Minimum vacuum layer size, in angstroms. """ def __init__( self, - initial_structure, - miller_index, - min_slab_size, - min_vacuum_size, - lll_reduce=False, - center_slab=False, - in_unit_planes=False, - primitive=True, - max_normal_search=None, - reorient_lattice=True, + initial_structure: Structure, + miller_index: tuple[int, int, int], + min_slab_size: float, + min_vacuum_size: float, + lll_reduce: bool = False, + center_slab: bool = False, + in_unit_planes: bool = False, + primitive: bool = True, + max_normal_search: int | None = None, + reorient_lattice: bool = True, ) -> None: - """Calculates the slab scale factor and uses it to generate a unit cell - of the initial structure that has been oriented by its miller index. + """Calculates the slab scale factor and uses it to generate an + oriented unit cell (OUC) of the initial structure. Also stores the initial information needed later on to generate a slab. Args: initial_structure (Structure): Initial input structure. Note that to - ensure that the miller indices correspond to usual + ensure that the Miller indices correspond to usual crystallographic definitions, you should supply a conventional unit cell structure. - miller_index ([h, k, l]): Miller index of plane parallel to - surface. Note that this is referenced to the input structure. If - you need this to be based on the conventional cell, + miller_index ([h, k, l]): Miller index of the plane parallel to + the surface. Note that this is referenced to the input structure. + If you need this to be based on the conventional cell, you should supply the conventional structure. min_slab_size (float): In Angstroms or number of hkl planes min_vacuum_size (float): In Angstroms or number of hkl planes lll_reduce (bool): Whether to perform an LLL reduction on the - eventual structure. + final structure. center_slab (bool): Whether to center the slab in the cell with equal vacuum spacing from the top and bottom. in_unit_planes (bool): Whether to set min_slab_size and min_vac_size - in units of hkl planes (True) or Angstrom (False/default). - Setting in units of planes is useful for ensuring some slabs - have a certain n_layer of atoms. e.g. for Cs (100), a 10 Ang - slab will result in a slab with only 2 layer of atoms, whereas - Fe (100) will have more layer of atoms. By using units of hkl - planes instead, we ensure both slabs - have the same number of atoms. The slab thickness will be in - min_slab_size/math.ceil(self._proj_height/dhkl) + in number of hkl planes or Angstrom (default). + Setting in units of planes is useful to ensure some slabs + to have a certain number of layers, e.g. for Cs(100), 10 Ang + will result in a slab with only 2 layers, whereas + Fe(100) will have more layers. The slab thickness + will be in min_slab_size/math.ceil(self._proj_height/dhkl) multiples of oriented unit cells. - primitive (bool): Whether to reduce any generated slabs to a - primitive cell (this does **not** mean the slab is generated - from a primitive cell, it simply means that after slab - generation, we attempt to find shorter lattice vectors, - which lead to less surface area and smaller cells). - max_normal_search (int): If set to a positive integer, the code will - conduct a search for a normal lattice vector that is as - perpendicular to the surface as possible by considering - multiples linear combinations of lattice vectors up to - max_normal_search. This has no bearing on surface energies, - but may be useful as a preliminary step to generating slabs - for absorption and other sizes. It is typical that this will - not be the smallest possible cell for simulation. Normality - is not guaranteed, but the oriented cell will have the c - vector as normal as possible (within the search range) to the - surface. A value of up to the max absolute Miller index is - usually sufficient. - reorient_lattice (bool): reorients the lattice parameters such that - the c direction is the third vector of the lattice matrix + primitive (bool): Whether to reduce generated slabs to + primitive cell. Note this does NOT generate a slab + from a primitive cell, it means that after slab + generation, we attempt to reduce the generated slab to + primitive cell. + max_normal_search (int): If set to a positive integer, the code + will search for a normal lattice vector that is as + perpendicular to the surface as possible, by considering + multiple linear combinations of lattice vectors up to + this value. This has no bearing on surface energies, + but may be useful as a preliminary step to generate slabs + for absorption or other sizes. It may not be the smallest possible + cell for simulation. Normality is not guaranteed, but the oriented + cell will have the c vector as normal as possible to the surface. + The max absolute Miller index is usually sufficient. + reorient_lattice (bool): reorient the lattice such that + the c direction is parallel to the third lattice vector """ - # Add Wyckoff symbols of the bulk, will help with - # identifying types of sites in the slab system - if ( - "bulk_wyckoff" not in initial_structure.site_properties - or "bulk_equivalent" not in initial_structure.site_properties - ): - sg = SpacegroupAnalyzer(initial_structure) - initial_structure.add_site_property("bulk_wyckoff", sg.get_symmetry_dataset()["wyckoffs"]) - initial_structure.add_site_property( - "bulk_equivalent", sg.get_symmetry_dataset()["equivalent_atoms"].tolist() - ) - latt = initial_structure.lattice - miller_index = _reduce_vector(miller_index) - # Calculate the surface normal using the reciprocal lattice vector. - recp = latt.reciprocal_lattice_crystallographic - normal = recp.get_cartesian_coords(miller_index) - normal /= np.linalg.norm(normal) - slab_scale_factor = [] - non_orth_ind = [] - eye = np.eye(3, dtype=int) - for ii, jj in enumerate(miller_index): - if jj == 0: - # Lattice vector is perpendicular to surface normal, i.e., - # in plane of surface. We will simply choose this lattice - # vector as one of the basis vectors. - slab_scale_factor.append(eye[ii]) - else: - # Calculate projection of lattice vector onto surface normal. - d = abs(np.dot(normal, latt.matrix[ii])) / latt.abc[ii] - non_orth_ind.append((ii, d)) - - # We want the vector that has maximum magnitude in the - # direction of the surface normal as the c-direction. - # Results in a more "orthogonal" unit cell. - c_index, _dist = max(non_orth_ind, key=lambda t: t[1]) - - if len(non_orth_ind) > 1: - lcm_miller = lcm(*(miller_index[i] for i, d in non_orth_ind)) - for (ii, _di), (jj, _dj) in itertools.combinations(non_orth_ind, 2): - scale_factor = [0, 0, 0] - scale_factor[ii] = -int(round(lcm_miller / miller_index[ii])) - scale_factor[jj] = int(round(lcm_miller / miller_index[jj])) - slab_scale_factor.append(scale_factor) - if len(slab_scale_factor) == 2: - break + def reduce_vector(vector: tuple[int, int, int]) -> tuple[int, int, int]: + """Helper function to reduce vectors.""" + divisor = abs(reduce(gcd, vector)) # type: ignore[arg-type] + return cast(tuple[int, int, int], tuple(int(idx / divisor) for idx in vector)) - if max_normal_search is None: - slab_scale_factor.append(eye[c_index]) - else: - index_range = sorted( - range(-max_normal_search, max_normal_search + 1), - key=lambda x: -abs(x), - ) - candidates = [] - for uvw in itertools.product(index_range, index_range, index_range): - if (not any(uvw)) or abs(np.linalg.det([*slab_scale_factor, uvw])) < 1e-8: - continue - vec = latt.get_cartesian_coords(uvw) - osdm = np.linalg.norm(vec) - cosine = abs(np.dot(vec, normal) / osdm) - candidates.append((uvw, cosine, osdm)) - if abs(abs(cosine) - 1) < 1e-8: - # If cosine of 1 is found, no need to search further. - break - # We want the indices with the maximum absolute cosine, - # but smallest possible length. - uvw, cosine, osdm = max(candidates, key=lambda x: (x[1], -x[2])) - slab_scale_factor.append(uvw) - - slab_scale_factor = np.array(slab_scale_factor) - - # Let's make sure we have a left-handed crystallographic system - if np.linalg.det(slab_scale_factor) < 0: - slab_scale_factor *= -1 + def add_site_types() -> None: + """Add Wyckoff symbols and equivalent sites to the initial structure.""" + if ( + "bulk_wyckoff" not in initial_structure.site_properties + or "bulk_equivalent" not in initial_structure.site_properties + ): + spg_analyzer = SpacegroupAnalyzer(initial_structure) + initial_structure.add_site_property("bulk_wyckoff", spg_analyzer.get_symmetry_dataset()["wyckoffs"]) + initial_structure.add_site_property( + "bulk_equivalent", spg_analyzer.get_symmetry_dataset()["equivalent_atoms"].tolist() + ) - # Make sure the slab_scale_factor is reduced to avoid - # unnecessarily large slabs + def calculate_surface_normal() -> np.ndarray: + """Calculate the unit surface normal vector using the reciprocal + lattice vector. + """ + recip_lattice = lattice.reciprocal_lattice_crystallographic + + normal = recip_lattice.get_cartesian_coords(miller_index) + normal /= np.linalg.norm(normal) + return normal + + def calculate_scaling_factor() -> np.ndarray: + """Calculate scaling factor. + + # TODO (@DanielYang59): revise docstring to add more details. + """ + slab_scale_factor = [] + non_orth_ind = [] + eye = np.eye(3, dtype=int) + for idx, miller_idx in enumerate(miller_index): + if miller_idx == 0: + # If lattice vector is perpendicular to surface normal, i.e., + # in plane of surface. We will simply choose this lattice + # vector as the basis vector + slab_scale_factor.append(eye[idx]) - reduced_scale_factor = [_reduce_vector(v) for v in slab_scale_factor] - slab_scale_factor = np.array(reduced_scale_factor) + else: + # Calculate projection of lattice vector onto surface normal. + d = abs(np.dot(normal, lattice.matrix[idx])) / lattice.abc[idx] + non_orth_ind.append((idx, d)) + + # We want the vector that has maximum magnitude in the + # direction of the surface normal as the c-direction. + # Results in a more "orthogonal" unit cell. + c_index, _dist = max(non_orth_ind, key=lambda t: t[1]) + + if len(non_orth_ind) > 1: + lcm_miller = lcm(*(miller_index[i] for i, _d in non_orth_ind)) + for (ii, _di), (jj, _dj) in itertools.combinations(non_orth_ind, 2): + scale_factor = [0, 0, 0] + scale_factor[ii] = -int(round(lcm_miller / miller_index[ii])) + scale_factor[jj] = int(round(lcm_miller / miller_index[jj])) + slab_scale_factor.append(scale_factor) + if len(slab_scale_factor) == 2: + break + + if max_normal_search is None: + slab_scale_factor.append(eye[c_index]) + else: + index_range = sorted( + range(-max_normal_search, max_normal_search + 1), + key=lambda x: -abs(x), + ) + candidates = [] + for uvw in itertools.product(index_range, index_range, index_range): + if (not any(uvw)) or abs(np.linalg.det([*slab_scale_factor, uvw])) < 1e-8: + continue + vec = lattice.get_cartesian_coords(uvw) + osdm = np.linalg.norm(vec) + cosine = abs(np.dot(vec, normal) / osdm) + candidates.append((uvw, cosine, osdm)) + # Stop searching if cosine equals 1 or -1 + if isclose(abs(cosine), 1, abs_tol=1e-8): + break + # We want the indices with the maximum absolute cosine, + # but smallest possible length. + uvw, cosine, osdm = max(candidates, key=lambda x: (x[1], -x[2])) + slab_scale_factor.append(uvw) + + slab_scale_factor = np.array(slab_scale_factor) + + # Let's make sure we have a left-handed crystallographic system + if np.linalg.det(slab_scale_factor) < 0: + slab_scale_factor *= -1 + + # Make sure the slab_scale_factor is reduced to avoid + # unnecessarily large slabs + reduced_scale_factor = [reduce_vector(v) for v in slab_scale_factor] + return np.array(reduced_scale_factor) + + # Add Wyckoff symbols and equivalent sites to the initial structure, + # to help identify types of sites in the generated slab + add_site_types() + + # Calculate the surface normal + lattice = initial_structure.lattice + miller_index = reduce_vector(miller_index) + normal = calculate_surface_normal() + + # Calculate scale factor + slab_scale_factor = calculate_scaling_factor() single = initial_structure.copy() single.make_supercell(slab_scale_factor) - # When getting the OUC, lets return the most reduced - # structure as possible to reduce calculations + # Calculate the most reduced structure as OUC to minimize calculations self.oriented_unit_cell = Structure.from_sites(single, to_unit_cell=True) + self.max_normal_search = max_normal_search self.parent = initial_structure self.lll_reduce = lll_reduce @@ -864,77 +1061,102 @@ def __init__( self.min_slab_size = min_slab_size self.in_unit_planes = in_unit_planes self.primitive = primitive - self._normal = normal + self._normal = normal # TODO (@DanielYang59): used only in unit test + self.reorient_lattice = reorient_lattice + _a, _b, c = self.oriented_unit_cell.lattice.matrix self._proj_height = abs(np.dot(normal, c)) - self.reorient_lattice = reorient_lattice - def get_slab(self, shift=0, tol: float = 0.1, energy=None): - """This method takes in shift value for the c lattice direction and - generates a slab based on the given shift. You should rarely use this - method. Instead, it is used by other generation algorithms to obtain - all slabs. + def get_slab(self, shift: float = 0, tol: float = 0.1, energy: float | None = None) -> Slab: + """Generate a slab based on a given shift value along the lattice c direction. - Arg: - shift (float): A shift value in Angstrom that determines how much a - slab should be shifted. + Note: + You should rarely use this (private) method directly, which is + intended for other generation methods. + + Args: + shift (float): The shift value along the lattice c direction in Angstrom. tol (float): Tolerance to determine primitive cell. - energy (float): An energy to assign to the slab. + energy (float): The energy to assign to the slab. Returns: - Slab: with a particular shifted oriented unit cell. + Slab: from a shifted oriented unit cell. """ - h = self._proj_height - p = round(h / self.parent.lattice.d_hkl(self.miller_index), 8) + scale_factor = self.slab_scale_factor + + # Calculate total number of layers + height = self._proj_height + height_per_layer = round(height / self.parent.lattice.d_hkl(self.miller_index), 8) + if self.in_unit_planes: - n_layers_slab = int(math.ceil(self.min_slab_size / p)) - n_layers_vac = int(math.ceil(self.min_vac_size / p)) + n_layers_slab = math.ceil(self.min_slab_size / height_per_layer) + n_layers_vac = math.ceil(self.min_vac_size / height_per_layer) else: - n_layers_slab = int(math.ceil(self.min_slab_size / h)) - n_layers_vac = int(math.ceil(self.min_vac_size / h)) + n_layers_slab = math.ceil(self.min_slab_size / height) + n_layers_vac = math.ceil(self.min_vac_size / height) + n_layers = n_layers_slab + n_layers_vac + # Prepare for Slab generation: lattice, species, coords and site_properties + a, b, c = self.oriented_unit_cell.lattice.matrix + new_lattice = [a, b, n_layers * c] + species = self.oriented_unit_cell.species_and_occu - props = self.oriented_unit_cell.site_properties - props = {k: v * n_layers_slab for k, v in props.items()} # type: ignore[operator, misc] + + # Shift all atoms + # DEBUG(@DanielYang59): shift value in Angstrom inconsistent with frac_coordis frac_coords = self.oriented_unit_cell.frac_coords + # DEBUG(@DanielYang59): suspicious shift direction (positive for downwards shift) frac_coords = np.array(frac_coords) + np.array([0, 0, -shift])[None, :] - frac_coords -= np.floor(frac_coords) - a, b, c = self.oriented_unit_cell.lattice.matrix - new_lattice = [a, b, n_layers * c] + frac_coords -= np.floor(frac_coords) # wrap frac_coords to the [0, 1) range + + # Scale down z-coordinate by the number of layers frac_coords[:, 2] = frac_coords[:, 2] / n_layers + + # Duplicate atom layers by stacking along the z-axis all_coords = [] for idx in range(n_layers_slab): - f_coords = frac_coords.copy() - f_coords[:, 2] += idx / n_layers - all_coords.extend(f_coords) + _frac_coords = frac_coords.copy() + _frac_coords[:, 2] += idx / n_layers + all_coords.extend(_frac_coords) + # Scale properties by number of atom layers (excluding vacuum) + props = self.oriented_unit_cell.site_properties + props = {k: v * n_layers_slab for k, v in props.items()} + + # Generate Slab slab = Structure(new_lattice, species * n_layers_slab, all_coords, site_properties=props) - scale_factor = self.slab_scale_factor - # Whether or not to orthogonalize the structure + # (Optionally) Post-process the Slab + # Orthogonalize the structure (through LLL lattice basis reduction) if self.lll_reduce: + # Sanitize Slab (LLL reduction + site sorting + map frac_coords) lll_slab = slab.copy(sanitize=True) - mapping = lll_slab.lattice.find_mapping(slab.lattice) - assert mapping is not None, "LLL reduction has failed" # mypy type narrowing - scale_factor = np.dot(mapping[2], scale_factor) # type: ignore[index] slab = lll_slab - # Whether or not to center the slab layer around the vacuum + # Apply reduction on the scaling factor + mapping = lll_slab.lattice.find_mapping(slab.lattice) + if mapping is None: + raise RuntimeError("LLL reduction has failed") + scale_factor = np.dot(mapping[2], scale_factor) + + # Center the slab layer around the vacuum if self.center_slab: - avg_c = np.average([c[2] for c in slab.frac_coords]) - slab.translate_sites(list(range(len(slab))), [0, 0, 0.5 - avg_c]) + c_center = np.average([coord[2] for coord in slab.frac_coords]) + slab.translate_sites(list(range(len(slab))), [0, 0, 0.5 - c_center]) + # Reduce to primitive cell if self.primitive: - prim = slab.get_primitive_structure(tolerance=tol) + prim_slab = slab.get_primitive_structure(tolerance=tol) + slab = prim_slab + if energy is not None: - energy = prim.volume / slab.volume * energy - slab = prim + energy *= prim_slab.volume / slab.volume - # Reorient the lattice to get the correct reduced cell + # Reorient the lattice to get the correctly reduced cell ouc = self.oriented_unit_cell.copy() if self.primitive: - # find a reduced ouc + # Find a reduced OUC slab_l = slab.lattice ouc = ouc.get_primitive_structure( constrain_latt={ @@ -945,8 +1167,9 @@ def get_slab(self, shift=0, tol: float = 0.1, energy=None): "gamma": slab_l.gamma, } ) - # Check this is the correct oriented unit cell - ouc = self.oriented_unit_cell if slab_l.a != ouc.lattice.a or slab_l.b != ouc.lattice.b else ouc + + # Ensure lattice a and b are consistent between the OUC and the Slab + ouc = ouc if (slab_l.a == ouc.lattice.a and slab_l.b == ouc.lattice.b) else self.oriented_unit_cell return Slab( slab.lattice, @@ -955,255 +1178,313 @@ def get_slab(self, shift=0, tol: float = 0.1, energy=None): self.miller_index, ouc, shift, - scale_factor, - energy=energy, - site_properties=slab.site_properties, + scale_factor, reorient_lattice=self.reorient_lattice, + site_properties=slab.site_properties, + energy=energy, ) - def _calculate_possible_shifts(self, tol: float = 0.1): - frac_coords = self.oriented_unit_cell.frac_coords - n = len(frac_coords) - - if n == 1: - # Clustering does not work when there is only one data point. - shift = frac_coords[0][2] + 0.5 - return [shift - math.floor(shift)] - - # We cluster the sites according to the c coordinates. But we need to - # take into account PBC. Let's compute a fractional c-coordinate - # distance matrix that accounts for PBC. - dist_matrix = np.zeros((n, n)) - h = self._proj_height - # Projection of c lattice vector in - # direction of surface normal. - for i, j in itertools.combinations(list(range(n)), 2): - if i != j: - cdist = frac_coords[i][2] - frac_coords[j][2] - cdist = abs(cdist - round(cdist)) * h - dist_matrix[i, j] = cdist - dist_matrix[j, i] = cdist - - condensed_m = squareform(dist_matrix) - z = linkage(condensed_m) - clusters = fcluster(z, tol, criterion="distance") - - # Generate dict of cluster# to c val - doesn't matter what the c is. - c_loc = {c: frac_coords[i][2] for i, c in enumerate(clusters)} - - # Put all c into the unit cell. - possible_c = [c - math.floor(c) for c in sorted(c_loc.values())] - - # Calculate the shifts - n_shifts = len(possible_c) - shifts = [] - for i in range(n_shifts): - if i == n_shifts - 1: - # There is an additional shift between the first and last c - # coordinate. But this needs special handling because of PBC. - shift = (possible_c[0] + 1 + possible_c[i]) * 0.5 - if shift > 1: - shift -= 1 - else: - shift = (possible_c[i] + possible_c[i + 1]) * 0.5 - shifts.append(shift - math.floor(shift)) - return sorted(shifts) - - def _get_c_ranges(self, bonds): - c_ranges = [] - bonds = {(get_el_sp(s1), get_el_sp(s2)): dist for (s1, s2), dist in bonds.items()} - for (sp1, sp2), bond_dist in bonds.items(): - for site in self.oriented_unit_cell: - if sp1 in site.species: - for nn in self.oriented_unit_cell.get_neighbors(site, bond_dist): - if sp2 in nn.species: - c_range = tuple(sorted([site.frac_coords[2], nn.frac_coords[2]])) - if c_range[1] > 1: - # Takes care of PBC when c coordinate of site - # goes beyond the upper boundary of the cell - c_ranges.extend(((c_range[0], 1), (0, c_range[1] - 1))) - elif c_range[0] < 0: - # Takes care of PBC when c coordinate of site - # is below the lower boundary of the unit cell - c_ranges.extend(((0, c_range[1]), (c_range[0] + 1, 1))) - elif c_range[0] != c_range[1]: - c_ranges.append((c_range[0], c_range[1])) - return c_ranges - def get_slabs( self, - bonds=None, - ftol=0.1, - tol=0.1, - max_broken_bonds=0, - symmetrize=False, - repair=False, - ): - """This method returns a list of slabs that are generated using the list of - shift values from the method, _calculate_possible_shifts(). Before the - shifts are used to create the slabs however, if the user decides to take - into account whether or not a termination will break any polyhedral - structure (bonds is not None), this method will filter out any shift - values that do so. + bonds: dict[tuple[Species | Element, Species | Element], float] | None = None, + ftol: float = 0.1, + tol: float = 0.1, + max_broken_bonds: int = 0, + symmetrize: bool = False, + repair: bool = False, + ) -> list[Slab]: + """Generate slabs with shift values calculated from the internal + calculate_possible_shifts method. If the user decide to avoid breaking + any polyhedral bond (by setting `bonds`), any shift value that do so + would be filtered out. Args: - bonds ({(specie1, specie2): max_bond_dist}: bonds are - specified as a dict of tuples: float of specie1, specie2 - and the max bonding distance. For example, PO4 groups may be - defined as {("P", "O"): 3}. - tol (float): General tolerance parameter for getting primitive - cells and matching structures - ftol (float): Threshold parameter in fcluster in order to check - if two atoms are lying on the same plane. Default thresh set - to 0.1 Angstrom in the direction of the surface normal. + bonds (dict): A {(species1, species2): max_bond_dist} dict. + For example, PO4 groups may be defined as {("P", "O"): 3}. + tol (float): Fractional tolerance for getting primitive cells + and matching structures. + ftol (float): Threshold for fcluster to check if two atoms are + on the same plane. Default to 0.1 Angstrom in the direction of + the surface normal. max_broken_bonds (int): Maximum number of allowable broken bonds - for the slab. Use this to limit # of slabs (some structures - may have a lot of slabs). Defaults to zero, which means no - defined bonds must be broken. - symmetrize (bool): Whether or not to ensure the surfaces of the - slabs are equivalent. - repair (bool): Whether to repair terminations with broken bonds - or just omit them. Set to False as repairing terminations can - lead to many possible slabs as oppose to just omitting them. + for the slab. Use this to limit number of slabs. Defaults to 0, + which means no bonds could be broken. + symmetrize (bool): Whether to enforce the equivalency of slab surfaces. + repair (bool): Whether to repair terminations with broken bonds (True) + or just omit them (False). Default to False as repairing terminations + can lead to many more possible slabs. Returns: - list[Slab]: all possible terminations of a particular surface. - Slabs are sorted by the # of bonds broken. + list[Slab]: All possible Slabs of a particular surface, + sorted by the number of bonds broken. """ - c_ranges = [] if bonds is None else self._get_c_ranges(bonds) + + def gen_possible_shifts(ftol: float) -> list[float]: + """Generate possible shifts by clustering z coordinates. + + Args: + ftol (float): Threshold for fcluster to check if + two atoms are on the same plane. + """ + frac_coords = self.oriented_unit_cell.frac_coords + n_atoms = len(frac_coords) + + # Clustering does not work when there is only one atom + if n_atoms == 1: + # TODO (@DanielYang59): why this magic number 0.5? + shift = frac_coords[0][2] + 0.5 + return [shift - math.floor(shift)] + + # Compute a Cartesian z-coordinate distance matrix + # TODO (@DanielYang59): account for periodic boundary condition + dist_matrix = np.zeros((n_atoms, n_atoms)) + for i, j in itertools.combinations(list(range(n_atoms)), 2): + if i != j: + z_dist = frac_coords[i][2] - frac_coords[j][2] + z_dist = abs(z_dist - round(z_dist)) * self._proj_height + dist_matrix[i, j] = z_dist + dist_matrix[j, i] = z_dist + + # Cluster the sites by z coordinates + z_matrix = linkage(squareform(dist_matrix)) + clusters = fcluster(z_matrix, ftol, criterion="distance") + + # Generate a cluster to z coordinate mapping + clst_loc = {c: frac_coords[i][2] for i, c in enumerate(clusters)} + + # Wrap all clusters into the unit cell ([0, 1) range) + possible_clst = [coord - math.floor(coord) for coord in sorted(clst_loc.values())] + + # Calculate shifts + n_shifts = len(possible_clst) + shifts = [] + for i in range(n_shifts): + # Handle the special case for the first-last + # z coordinate (because of periodic boundary condition) + if i == n_shifts - 1: + # TODO (@DanielYang59): Why calculate the "center" of the + # two clusters, which is not actually the shift? + shift = (possible_clst[0] + 1 + possible_clst[i]) * 0.5 + + else: + shift = (possible_clst[i] + possible_clst[i + 1]) * 0.5 + + shifts.append(shift - math.floor(shift)) + + return sorted(shifts) + + def get_z_ranges(bonds: dict[tuple[Species | Element, Species | Element], float]) -> list[tuple[float, float]]: + """Collect occupied z ranges where each z_range is a (lower_z, upper_z) tuple. + + This method examines all sites in the oriented unit cell (OUC) + and considers all neighboring sites within the specified bond distance + for each site. If a site and its neighbor meet bonding and species + requirements, their respective z-ranges will be collected. + + Args: + bonds (dict): A {(species1, species2): max_bond_dist} dict. + tol (float): Fractional tolerance for determine overlapping positions. + """ + # Sanitize species in dict keys + bonds = {(get_el_sp(s1), get_el_sp(s2)): dist for (s1, s2), dist in bonds.items()} + + z_ranges = [] + for (sp1, sp2), bond_dist in bonds.items(): + for site in self.oriented_unit_cell: + if sp1 in site.species: + for nn in self.oriented_unit_cell.get_neighbors(site, bond_dist): + if sp2 in nn.species: + z_range = tuple(sorted([site.frac_coords[2], nn.frac_coords[2]])) + + # Handle cases when z coordinate of site goes + # beyond the upper boundary + if z_range[1] > 1: + z_ranges.extend([(z_range[0], 1), (0, z_range[1] - 1)]) + + # When z coordinate is below the lower boundary + elif z_range[0] < 0: + z_ranges.extend([(0, z_range[1]), (z_range[0] + 1, 1)]) + + # Neglect overlapping positions + elif z_range[0] != z_range[1]: + # TODO (@DanielYang59): use the following for equality check + # elif not isclose(z_range[0], z_range[1], abs_tol=tol): + z_ranges.append(z_range) + + return z_ranges + + # Get occupied z_ranges + z_ranges = [] if bonds is None else get_z_ranges(bonds) slabs = [] - for shift in self._calculate_possible_shifts(tol=ftol): + for shift in gen_possible_shifts(ftol=ftol): + # Calculate total number of bonds broken (how often the shift + # position fall within the z_range occupied by a bond) bonds_broken = 0 - for r in c_ranges: - if r[0] <= shift <= r[1]: + for z_range in z_ranges: + if z_range[0] <= shift <= z_range[1]: bonds_broken += 1 - slab = self.get_slab(shift, tol=tol, energy=bonds_broken) + + # DEBUG(@DanielYang59): number of bonds broken passed to energy + # As per the docstring this is to sort final Slabs by number + # of bonds broken, but this may very likely lead to errors + # if the "energy" is used literally (Maybe reset energy to None?) + slab = self.get_slab(shift=shift, tol=tol, energy=bonds_broken) + if bonds_broken <= max_broken_bonds: slabs.append(slab) - elif repair: - # If the number of broken bonds is exceeded, - # we repair the broken bonds on the slab - slabs.append(self.repair_broken_bonds(slab, bonds)) - # Further filters out any surfaces made that might be the same + # If the number of broken bonds is exceeded, repair the broken bonds + elif repair and bonds is not None: + slabs.append(self.repair_broken_bonds(slab=slab, bonds=bonds)) + + # Filter out surfaces that might be the same matcher = StructureMatcher(ltol=tol, stol=tol, primitive_cell=False, scale=False) - new_slabs = [] - for g in matcher.group_structures(slabs): - # For each unique termination, symmetrize the - # surfaces by removing sites from the bottom. + final_slabs = [] + for group in matcher.group_structures(slabs): + # For each unique slab, symmetrize the + # surfaces by removing sites from the bottom if symmetrize: - slabs = self.nonstoichiometric_symmetrized_slab(g[0]) - new_slabs.extend(slabs) + sym_slabs = self.nonstoichiometric_symmetrized_slab(group[0]) + final_slabs.extend(sym_slabs) else: - new_slabs.append(g[0]) + final_slabs.append(group[0]) - match = StructureMatcher(ltol=tol, stol=tol, primitive_cell=False, scale=False) - new_slabs = [g[0] for g in match.group_structures(new_slabs)] + # Filter out similar surfaces generated by symmetrization + if symmetrize: + matcher_sym = StructureMatcher(ltol=tol, stol=tol, primitive_cell=False, scale=False) + final_slabs = [group[0] for group in matcher_sym.group_structures(final_slabs)] - return sorted(new_slabs, key=lambda s: s.energy) + return sorted(final_slabs, key=lambda slab: slab.energy) # type: ignore[return-value, arg-type] - def repair_broken_bonds(self, slab, bonds): - """This method will find undercoordinated atoms due to slab - cleaving specified by the bonds parameter and move them - to the other surface to make sure the bond is kept intact. - In a future release of surface.py, the ghost_sites will be - used to tell us how the repair bonds should look like. + def repair_broken_bonds( + self, + slab: Slab, + bonds: dict[tuple[Species | Element, Species | Element], float], + ) -> Slab: + """Repair broken bonds (specified by the bonds parameter) due to + slab cleaving, and repair them by moving undercoordinated atoms + to the other surface. + + How it works: + For example a P-O4 bond may have P and O(4-x) on one side + of the surface, and Ox on the other side, this method would + first move P (the reference atom) to the other side, + find its missing nearest neighbours (Ox), and move P + and Ox back together. - Arg: - slab (structure): A structure object representing a slab. - bonds ({(specie1, specie2): max_bond_dist}: bonds are - specified as a dict of tuples: float of specie1, specie2 - and the max bonding distance. For example, PO4 groups may be - defined as {("P", "O"): 3}. + Args: + slab (Slab): The Slab to repair. + bonds (dict): A {(species1, species2): max_bond_dist} dict. + For example, PO4 groups may be defined as {("P", "O"): 3}. Returns: - (Slab) A Slab object with a particular shifted oriented unit cell. + Slab: The repaired Slab. """ - for pair in bonds: - bond_len = bonds[pair] - - # First lets determine which element should be the - # reference (center element) to determine broken bonds. - # e.g. P for a PO4 bond. Find integer coordination - # numbers of the pair of elements w.r.t. to each other + for species_pair, bond_dist in bonds.items(): + # Determine which element should be the reference (center) + # element for determining broken bonds, e.g. P for PO4 bond. cn_dict = {} - for idx, el in enumerate(pair): + for idx, ele in enumerate(species_pair): cn_list = [] for site in self.oriented_unit_cell: - poly_coord = 0 - if site.species_string == el: - for nn in self.oriented_unit_cell.get_neighbors(site, bond_len): - if nn[0].species_string == pair[idx - 1]: - poly_coord += 1 - cn_list.append(poly_coord) - cn_dict[el] = cn_list - - # We make the element with the higher coordination our reference - if max(cn_dict[pair[0]]) > max(cn_dict[pair[1]]): - element1, element2 = pair + # Find integer coordination numbers for element pairs + ref_cn = 0 + if site.species_string == ele: + for nn in self.oriented_unit_cell.get_neighbors(site, bond_dist): + if nn[0].species_string == species_pair[idx - 1]: + ref_cn += 1 + + cn_list.append(ref_cn) + cn_dict[ele] = cn_list + + # Make the element with higher coordination the reference + if max(cn_dict[species_pair[0]]) > max(cn_dict[species_pair[1]]): + ele_ref, ele_other = species_pair else: - element2, element1 = pair + ele_other, ele_ref = species_pair for idx, site in enumerate(slab): - # Determine the coordination of our reference - if site.species_string == element1: - poly_coord = 0 - for neighbor in slab.get_neighbors(site, bond_len): - poly_coord += 1 if neighbor.species_string == element2 else 0 - - # suppose we find an undercoordinated reference atom - if poly_coord not in cn_dict[element1]: - # We get the reference atom of the broken bonds - # (undercoordinated), move it to the other surface + # Determine the coordination of the reference + if site.species_string == ele_ref: + ref_cn = sum( + 1 if neighbor.species_string == ele_other else 0 + for neighbor in slab.get_neighbors(site, bond_dist) + ) + + # Suppose we find an undercoordinated reference atom + # TODO (@DanielYang59): maybe use the following to + # check if the reference atom is "undercoordinated" + # if ref_cn < min(cn_dict[ele_ref]): + if ref_cn not in cn_dict[ele_ref]: + # Move this reference atom to the other side slab = self.move_to_other_side(slab, [idx]) - # find its NNs with the corresponding - # species it should be coordinated with - neighbors = slab.get_neighbors(slab[idx], bond_len, include_index=True) - to_move = [nn[2] for nn in neighbors if nn[0].species_string == element2] + # Find its NNs (with right species) it should bond to + neighbors = slab.get_neighbors(slab[idx], r=bond_dist) + to_move = [nn[2] for nn in neighbors if nn[0].species_string == ele_other] to_move.append(idx) - # and then move those NNs along with the central - # atom back to the other side of the slab again + + # Move those NNs along with the reference + # atom back to the other side of the slab slab = self.move_to_other_side(slab, to_move) return slab - def move_to_other_side(self, init_slab, index_of_sites): - """This method will Move a set of sites to the - other side of the slab (opposite surface). + def move_to_other_side( + self, + init_slab: Slab, + index_of_sites: list[int], + ) -> Slab: + """Move surface sites to the opposite surface of the Slab. + + If a selected site resides on the top half of the Slab, + it would be moved to the bottom side, and vice versa. + The distance moved is equal to the thickness of the Slab. - Arg: - init_slab (structure): A structure object representing a slab. - index_of_sites (list of ints): The list of indices representing - the sites we want to move to the other side. + Note: + You should only use this method on sites close to the + surface, otherwise it would end up deep inside the + vacuum layer. + + Args: + init_slab (Slab): The Slab whose sites would be moved. + index_of_sites (list[int]): Indices representing + the sites to move. Returns: - (Slab) A Slab object with a particular shifted oriented unit cell. + Slab: The Slab with selected sites moved. """ - slab = init_slab.copy() - - # Determine what fraction the slab is of the total cell size - # in the c direction. Round to nearest rational number. - h = self._proj_height - p = h / self.parent.lattice.d_hkl(self.miller_index) + # Calculate Slab height + height: float = self._proj_height + # Scale height if using number of hkl planes if self.in_unit_planes: - nlayers_slab = int(math.ceil(self.min_slab_size / p)) - nlayers_vac = int(math.ceil(self.min_vac_size / p)) - else: - nlayers_slab = int(math.ceil(self.min_slab_size / h)) - nlayers_vac = int(math.ceil(self.min_vac_size / h)) - nlayers = nlayers_slab + nlayers_vac - slab_ratio = nlayers_slab / nlayers - - # Sort the index of sites based on which side they are on - top_site_index = [i for i in index_of_sites if slab[i].frac_coords[2] > slab.center_of_mass[2]] - bottom_site_index = [i for i in index_of_sites if slab[i].frac_coords[2] < slab.center_of_mass[2]] + height /= self.parent.lattice.d_hkl(self.miller_index) + + # Calculate the moving distance as the fractional height + # of the Slab inside the cell + # DEBUG(@DanielYang59): the use actually sizes for slab/vac + # instead of the input arg (min_slab/vac_size) + n_layers_slab: int = math.ceil(self.min_slab_size / height) + n_layers_vac: int = math.ceil(self.min_vac_size / height) + n_layers: int = n_layers_slab + n_layers_vac + + frac_dist: float = n_layers_slab / n_layers + + # Separate selected sites into top and bottom + top_site_index: list[int] = [] + bottom_site_index: list[int] = [] + for idx in index_of_sites: + if init_slab[idx].frac_coords[2] >= init_slab.center_of_mass[2]: + top_site_index.append(idx) + else: + bottom_site_index.append(idx) - # Translate sites to the opposite surfaces - slab.translate_sites(top_site_index, [0, 0, slab_ratio]) - slab.translate_sites(bottom_site_index, [0, 0, -slab_ratio]) + # Move sites to the opposite surface + slab = init_slab.copy() + slab.translate_sites(top_site_index, vector=[0, 0, -frac_dist], frac_coords=True) + slab.translate_sites(bottom_site_index, vector=[0, 0, frac_dist], frac_coords=True) return Slab( init_slab.lattice, @@ -1216,254 +1497,433 @@ def move_to_other_side(self, init_slab, index_of_sites): energy=init_slab.energy, ) - def nonstoichiometric_symmetrized_slab(self, init_slab): - """This method checks whether or not the two surfaces of the slab are - equivalent. If the point group of the slab has an inversion symmetry ( - ie. belong to one of the Laue groups), then it is assumed that the - surfaces should be equivalent. Otherwise, sites at the bottom of the - slab will be removed until the slab is symmetric. Note the removal of sites - can destroy the stoichiometry of the slab. For non-elemental - structures, the chemical potential will be needed to calculate surface energy. + def nonstoichiometric_symmetrized_slab(self, init_slab: Slab) -> list[Slab]: + """Symmetrize the two surfaces of a Slab, but may break the stoichiometry. + + How it works: + 1. Check whether two surfaces of the slab are equivalent. + If the point group of the slab has an inversion symmetry ( + ie. belong to one of the Laue groups), then it's assumed that the + surfaces are equivalent. + + 2.If not symmetrical, sites at the bottom of the slab will be removed + until the slab is symmetric, which may break the stoichiometry. - Arg: - init_slab (Structure): A single slab structure + Args: + init_slab (Slab): The initial Slab. Returns: - Slab (structure): A symmetrized Slab object. + list[Slabs]: The symmetrized Slabs. """ if init_slab.is_symmetric(): return [init_slab] non_stoich_slabs = [] - # Build an equivalent surface slab for each of the different surfaces - for top in [True, False]: - asym = True + # Build a symmetrical surface slab for each of the different surfaces + for surface in ("top", "bottom"): + is_sym: bool = False slab = init_slab.copy() slab.energy = init_slab.energy - while asym: - # Keep removing sites from the bottom one by one until both - # surfaces are symmetric or the number of sites removed has + while not is_sym: + # Keep removing sites from the bottom until surfaces are + # symmetric or the number of sites removed has # exceeded 10 percent of the original slab + # TODO: (@DanielYang59) comment differs from implementation: + # no "exceeded 10 percent" check + z_coords: list[float] = [site[2] for site in slab.frac_coords] - c_dir = [site[2] for idx, site in enumerate(slab.frac_coords)] - - if top: - slab.remove_sites([c_dir.index(max(c_dir))]) + if surface == "top": + slab.remove_sites([z_coords.index(max(z_coords))]) else: - slab.remove_sites([c_dir.index(min(c_dir))]) + slab.remove_sites([z_coords.index(min(z_coords))]) + if len(slab) <= len(self.parent): + warnings.warn("Too many sites removed, please use a larger slab.") break - # Check if the altered surface is symmetric + # Check if the new Slab is symmetric + # TODO: (@DanielYang59): should have some feedback (warning) + # if cannot symmetrize the Slab if slab.is_symmetric(): - asym = False + is_sym = True non_stoich_slabs.append(slab) - if len(slab) <= len(self.parent): - warnings.warn("Too many sites removed, please use a larger slab size.") - return non_stoich_slabs +def generate_all_slabs( + structure: Structure, + max_index: int, + min_slab_size: float, + min_vacuum_size: float, + bonds: dict | None = None, + tol: float = 0.1, + ftol: float = 0.1, + max_broken_bonds: int = 0, + lll_reduce: bool = False, + center_slab: bool = False, + primitive: bool = True, + max_normal_search: int | None = None, + symmetrize: bool = False, + repair: bool = False, + include_reconstructions: bool = False, + in_unit_planes: bool = False, +) -> list[Slab]: + """Find all unique Slabs up to a given Miller index. + + Slabs oriented along certain Miller indices may be equivalent to + other Miller indices under symmetry operations. To avoid + duplication, such equivalent slabs would be filtered out. + For instance, CsCl has equivalent slabs in the (0,0,1), + (0,1,0), and (1,0,0) directions under symmetry operations. + + Args: + structure (Structure): Initial input structure. To + ensure that the Miller indices correspond to usual + crystallographic definitions, you should supply a + conventional unit cell. + max_index (int): The maximum Miller index to go up to. + min_slab_size (float): The minimum slab size in Angstrom. + min_vacuum_size (float): The minimum vacuum layer thickness in Angstrom. + bonds (dict): A {(species1, species2): max_bond_dist} dict. + For example, PO4 groups may be defined as {("P", "O"): 3}. + tol (float): Tolerance for getting primitive cells and + matching structures. + ftol (float): Tolerance in Angstrom for fcluster to check + if two atoms are on the same plane. Default to 0.1 Angstrom + in the direction of the surface normal. + max_broken_bonds (int): Maximum number of allowable broken bonds + for the slab. Use this to limit the number of slabs. + Defaults to zero, which means no bond can be broken. + lll_reduce (bool): Whether to perform an LLL reduction on the + final Slab. + center_slab (bool): Whether to center the slab in the cell with + equal vacuum spacing from the top and bottom. + primitive (bool): Whether to reduce generated slabs to + primitive cell. Note this does NOT generate a slab + from a primitive cell, it means that after slab + generation, we attempt to reduce the generated slab to + primitive cell. + max_normal_search (int): If set to a positive integer, the code + will search for a normal lattice vector that is as + perpendicular to the surface as possible, by considering + multiple linear combinations of lattice vectors up to + this value. This has no bearing on surface energies, + but may be useful as a preliminary step to generate slabs + for absorption or other sizes. It may not be the smallest possible + cell for simulation. Normality is not guaranteed, but the oriented + cell will have the c vector as normal as possible to the surface. + The max absolute Miller index is usually sufficient. + symmetrize (bool): Whether to ensure the surfaces of the + slabs are equivalent. + repair (bool): Whether to repair terminations with broken bonds + or just omit them. + include_reconstructions (bool): Whether to include reconstructed + slabs available in the reconstructions_archive.json file. Defaults to False. + in_unit_planes (bool): Whether to set min_slab_size and min_vac_size + in number of hkl planes or Angstrom (default). + Setting in units of planes is useful to ensure some slabs + to have a certain number of layers, e.g. for Cs(100), 10 Ang + will result in a slab with only 2 layers, whereas + Fe(100) will have more layers. The slab thickness + will be in min_slab_size/math.ceil(self._proj_height/dhkl) + multiples of oriented unit cells. + """ + all_slabs = [] + + for miller in get_symmetrically_distinct_miller_indices(structure, max_index): + gen = SlabGenerator( + structure, + miller, + min_slab_size, + min_vacuum_size, + lll_reduce=lll_reduce, + center_slab=center_slab, + primitive=primitive, + max_normal_search=max_normal_search, + in_unit_planes=in_unit_planes, + ) + slabs = gen.get_slabs( + bonds=bonds, + tol=tol, + ftol=ftol, + symmetrize=symmetrize, + max_broken_bonds=max_broken_bonds, + repair=repair, + ) + + if len(slabs) > 0: + logger.debug(f"{miller} has {len(slabs)} slabs... ") + all_slabs.extend(slabs) + + if include_reconstructions: + symbol = SpacegroupAnalyzer(structure).get_space_group_symbol() + # Enumerate through all reconstructions in the + # archive available for this particular spacegroup + for name, instructions in RECONSTRUCTIONS_ARCHIVE.items(): + if "base_reconstruction" in instructions: + instructions = RECONSTRUCTIONS_ARCHIVE[instructions["base_reconstruction"]] + + if instructions["spacegroup"]["symbol"] == symbol: + # Make sure this reconstruction has a max index + # equal or less than the given max index + if max(instructions["miller_index"]) > max_index: + continue + recon = ReconstructionGenerator(structure, min_slab_size, min_vacuum_size, name) + all_slabs.extend(recon.build_slabs()) + + return all_slabs + + +# Load the reconstructions_archive json file module_dir = os.path.dirname(os.path.abspath(__file__)) -with open(f"{module_dir}/reconstructions_archive.json") as data_file: - reconstructions_archive = json.load(data_file) +with open(f"{module_dir}/reconstructions_archive.json", encoding="utf-8") as data_file: + RECONSTRUCTIONS_ARCHIVE = json.load(data_file) + + +def get_d(slab: Slab) -> float: + """Determine the z-spacing between the bottom two layers for a Slab. + + TODO (@DanielYang59): this should be private/internal to ReconstructionGenerator? + """ + # Sort all sites by z-coordinates + sorted_sites = sorted(slab, key=lambda site: site.frac_coords[2]) + + for site, next_site in zip(sorted_sites, sorted_sites[1:]): + if not isclose(site.frac_coords[2], next_site.frac_coords[2], abs_tol=1e-6): + # DEBUG (@DanielYang59): code will break if no distinguishable layers found + distance = next_site.frac_coords[2] - site.frac_coords[2] + break + + return slab.lattice.get_cartesian_coords([0, 0, distance])[2] class ReconstructionGenerator: - """This class takes in a pre-defined dictionary specifying the parameters - need to build a reconstructed slab such as the SlabGenerator parameters, - transformation matrix, sites to remove/add and slab/vacuum size. It will - then use the formatted instructions provided by the dictionary to build - the desired reconstructed slab from the initial structure. + """Build a reconstructed Slab from a given initial Structure. + + This class needs a pre-defined dictionary specifying the parameters + needed such as the SlabGenerator parameters, transformation matrix, + sites to remove/add and slab/vacuum sizes. Attributes: slabgen_params (dict): Parameters for the SlabGenerator. - trans_matrix (np.ndarray): A 3x3 transformation matrix to generate the reconstructed - slab. Only the a and b lattice vectors are actually changed while the c vector remains - the same. This matrix is what the Wood's notation is based on. - reconstruction_json (dict): The full json or dictionary containing the instructions for - building the reconstructed slab. - termination (int): The index of the termination of the slab. + trans_matrix (np.ndarray): A 3x3 transformation matrix to generate + the reconstructed slab. Only the a and b lattice vectors are + actually changed while the c vector remains the same. + This matrix is what the Wood's notation is based on. + reconstruction_json (dict): The full json or dictionary containing + the instructions for building the slab. Todo: - - Right now there is no way to specify what atom is being added. In the future, use basis sets? + - Right now there is no way to specify what atom is being added. + Use basis sets in the future? """ - def __init__(self, initial_structure, min_slab_size, min_vacuum_size, reconstruction_name) -> None: - """Generates reconstructed slabs from a set of instructions - specified by a dictionary or json file. + def __init__( + self, + initial_structure: Structure, + min_slab_size: float, + min_vacuum_size: float, + reconstruction_name: str, + ) -> None: + """Generates reconstructed slabs from a set of instructions. Args: initial_structure (Structure): Initial input structure. Note - that to ensure that the miller indices correspond to usual + that to ensure that the Miller indices correspond to usual crystallographic definitions, you should supply a conventional unit cell structure. - min_slab_size (float): In Angstroms - min_vacuum_size (float): In Angstroms - reconstruction_name (str): Name of the dict containing the instructions - for building a reconstructed slab. The dictionary can contain - any item the creator deems relevant, however any instructions - archived in pymatgen for public use needs to contain the - following keys and items to ensure compatibility with the - ReconstructionGenerator: - - "name" (str): A descriptive name for the type of - reconstruction. Typically the name will have the type - of structure the reconstruction is for, the Miller - index, and Wood's notation along with anything to - describe the reconstruction: e.g.: - "fcc_110_missing_row_1x2" - "description" (str): A longer description of your - reconstruction. This is to help future contributors who - want to add other types of reconstructions to the - archive on pymatgen to check if the reconstruction - already exists. Please read the descriptions carefully - before adding a new type of reconstruction to ensure it - is not in the archive yet. - "reference" (str): Optional reference to where the - reconstruction was taken from or first observed. - "spacegroup" (dict): e.g. {"symbol": "Fm-3m", "number": 225} - Indicates what kind of structure is this reconstruction. - "miller_index" ([h,k,l]): Miller index of your reconstruction + min_slab_size (float): Minimum Slab size in Angstrom. + min_vacuum_size (float): Minimum vacuum layer size in Angstrom. + reconstruction_name (str): Name of the dict containing the build + instructions. The dictionary can contain any item, however + any instructions archived in pymatgen for public use need + to contain the following keys and items to ensure + compatibility with the ReconstructionGenerator: + + "name" (str): A descriptive name for the reconstruction, + typically including the type of structure, + the Miller index, the Wood's notation and additional + descriptors for the reconstruction. + Example: "fcc_110_missing_row_1x2" + "description" (str): A detailed description of the + reconstruction, intended to assist future contributors + in avoiding duplicate entries. Please read the description + carefully before adding to prevent duplications. + "reference" (str): Optional reference to the source of + the reconstruction. + "spacegroup" (dict): A dictionary indicating the space group + of the reconstruction. e.g. {"symbol": "Fm-3m", "number": 225}. + "miller_index" ([h, k, l]): Miller index of the reconstruction "Woods_notation" (str): For a reconstruction, the a and b - lattice may change to accommodate the symmetry of the - reconstruction. This notation indicates the change in + lattice may change to accommodate the symmetry. + This notation indicates the change in the vectors relative to the primitive (p) or - conventional (c) slab cell. E.g. p(2x1): + conventional (c) slab cell. E.g. p(2x1). - Wood, E. A. (1964). Vocabulary of surface + Reference: Wood, E. A. (1964). Vocabulary of surface crystallography. Journal of Applied Physics, 35(4), 1306-1312. - "transformation_matrix" (numpy array): A 3x3 matrix to transform the slab. Only the a and b lattice vectors should change while the c vector remains the same. "SlabGenerator_parameters" (dict): A dictionary containing - the parameters for the SlabGenerator class excluding the - miller_index, min_slab_size and min_vac_size as the + the parameters for the SlabGenerator, excluding the + miller_index, min_slab_size and min_vac_size. As the Miller index is already specified and the min_slab_size - and min_vac_size can be changed regardless of what type - of reconstruction is used. Having a consistent set of + and min_vac_size can be changed regardless of the + reconstruction type. Having a consistent set of SlabGenerator parameters allows for the instructions to - be reused to consistently build a reconstructed slab. - "points_to_remove" (list of coords): A list of sites to - remove where the first two indices are fraction (in a - and b) and the third index is in units of 1/d (in c). - "points_to_add" (list of frac_coords): A list of sites to add - where the first two indices are fraction (in a an b) and - the third index is in units of 1/d (in c). - - "base_reconstruction" (dict): Option to base a reconstruction on - an existing reconstruction model also exists to easily build - the instructions without repeating previous work. E.g. the + be reused. + "points_to_remove" (list[site]): A list of sites to + remove where the first two indices are fractional (in a + and b) and the third index is in units of 1/d (in c), + see the below "Notes" for details. + "points_to_add" (list[site]): A list of sites to add + where the first two indices are fractional (in a an b) and + the third index is in units of 1/d (in c), see the below + "Notes" for details. + "base_reconstruction" (dict, Optional): A dictionary specifying + an existing reconstruction model upon which the current + reconstruction is built to avoid repetition. E.g. the alpha reconstruction of halites is based on the octopolar reconstruction but with the topmost atom removed. The dictionary for the alpha reconstruction would therefore contain the item "reconstruction_base": "halite_111_octopolar_2x2", and - additional sites for "points_to_remove" and "points_to_add" - can be added to modify this reconstruction. - - For "points_to_remove" and "points_to_add", the third index for - the c vector is in units of 1/d where d is the spacing - between atoms along hkl (the c vector) and is relative to - the topmost site in the unreconstructed slab. e.g. a point - of [0.5, 0.25, 1] corresponds to the 0.5 frac_coord of a, - 0.25 frac_coord of b and a distance of 1 atomic layer above - the topmost site. [0.5, 0.25, -0.5] where the third index - corresponds to a point half a atomic layer below the topmost - site. [0.5, 0.25, 0] corresponds to a point in the same - position along c as the topmost site. This is done because - while the primitive units of a and b will remain constant, - the user can vary the length of the c direction by changing - the slab layer or the vacuum layer. - - NOTE: THE DICTIONARY SHOULD ONLY CONTAIN "points_to_remove" AND - "points_to_add" FOR THE TOP SURFACE. THE ReconstructionGenerator - WILL MODIFY THE BOTTOM SURFACE ACCORDINGLY TO RETURN A SLAB WITH - EQUIVALENT SURFACES. + additional sites can be added by "points_to_add". + + Notes: + 1. For "points_to_remove" and "points_to_add", the third index + for the c vector is specified in units of 1/d, where d represents + the spacing between atoms along the hkl (the c vector), relative + to the topmost site in the unreconstructed slab. For instance, + a point of [0.5, 0.25, 1] corresponds to the 0.5 fractional + coordinate of a, 0.25 fractional coordinate of b, and a + distance of 1 atomic layer above the topmost site. Similarly, + [0.5, 0.25, -0.5] corresponds to a point half an atomic layer + below the topmost site, and [0.5, 0.25, 0] corresponds to a + point at the same position along c as the topmost site. + This approach is employed because while the primitive units + of a and b remain constant, the user can vary the length + of the c direction by adjusting the slab layer or the vacuum layer. + + 2. The dictionary should only provide "points_to_remove" and + "points_to_add" for the top surface. The ReconstructionGenerator + will modify the bottom surface accordingly to return a symmetric Slab. """ - if reconstruction_name not in reconstructions_archive: - raise KeyError( - f"{reconstruction_name=} does not exist in the archive. Please select " - f"from one of the following reconstructions: {list(reconstructions_archive)} " - "or add the appropriate dictionary to the archive file " - "reconstructions_archive.json." - ) - # Get the instructions to build the reconstruction - # from the reconstruction_archive - recon_json = copy.deepcopy(reconstructions_archive[reconstruction_name]) - new_points_to_add, new_points_to_remove = [], [] - if "base_reconstruction" in recon_json: - if "points_to_add" in recon_json: - new_points_to_add = recon_json["points_to_add"] - if "points_to_remove" in recon_json: - new_points_to_remove = recon_json["points_to_remove"] + def build_recon_json() -> dict: + """Build reconstruction instructions, optionally upon a base instruction set.""" + # Check if reconstruction instruction exists + # TODO (@DanielYang59): can we avoid asking user to modify the source file? + if reconstruction_name not in RECONSTRUCTIONS_ARCHIVE: + raise KeyError( + f"{reconstruction_name=} does not exist in the archive. " + "Please select from one of the following: " + f"{list(RECONSTRUCTIONS_ARCHIVE)} or add it to the " + "archive file 'reconstructions_archive.json'." + ) + + # Get the reconstruction instructions from the archive file + recon_json: dict = copy.deepcopy(RECONSTRUCTIONS_ARCHIVE[reconstruction_name]) # Build new instructions from a base reconstruction - recon_json = copy.deepcopy(reconstructions_archive[recon_json["base_reconstruction"]]) - if "points_to_add" in recon_json: - del recon_json["points_to_add"] - if "points_to_remove" in recon_json: - del recon_json["points_to_remove"] - if new_points_to_add: - recon_json["points_to_add"] = new_points_to_add - if new_points_to_remove: - recon_json["points_to_remove"] = new_points_to_remove - - slabgen_params = copy.deepcopy(recon_json["SlabGenerator_parameters"]) - slabgen_params["initial_structure"] = initial_structure.copy() - slabgen_params["miller_index"] = recon_json["miller_index"] - slabgen_params["min_slab_size"] = min_slab_size - slabgen_params["min_vacuum_size"] = min_vacuum_size + if "base_reconstruction" in recon_json: + new_points_to_add: list = [] + new_points_to_remove: list = [] + + if "points_to_add" in recon_json: + new_points_to_add = recon_json["points_to_add"] + if "points_to_remove" in recon_json: + new_points_to_remove = recon_json["points_to_remove"] + + # DEBUG (@DanielYang59): the following overwrites previously + # loaded "recon_json", use condition to avoid this + recon_json = copy.deepcopy(RECONSTRUCTIONS_ARCHIVE[recon_json["base_reconstruction"]]) + + # TODO (@DanielYang59): use "site" over "point" for consistency? + if "points_to_add" in recon_json: + del recon_json["points_to_add"] + if new_points_to_add: + recon_json["points_to_add"] = new_points_to_add + + if "points_to_remove" in recon_json: + del recon_json["points_to_remove"] + if new_points_to_remove: + recon_json["points_to_remove"] = new_points_to_remove + + return recon_json + def build_slabgen_params() -> dict: + """Build SlabGenerator parameters.""" + slabgen_params: dict = copy.deepcopy(recon_json["SlabGenerator_parameters"]) + slabgen_params["initial_structure"] = initial_structure.copy() + slabgen_params["miller_index"] = recon_json["miller_index"] + slabgen_params["min_slab_size"] = min_slab_size + slabgen_params["min_vacuum_size"] = min_vacuum_size + + return slabgen_params + + # Build reconstruction instructions + recon_json = build_recon_json() + + # Build SlabGenerator parameters + slabgen_params = build_slabgen_params() + + self.name = reconstruction_name self.slabgen_params = slabgen_params - self.trans_matrix = recon_json["transformation_matrix"] self.reconstruction_json = recon_json - self.name = reconstruction_name + self.trans_matrix = recon_json["transformation_matrix"] - def build_slabs(self): - """Builds the reconstructed slab by: - (1) Obtaining the unreconstructed slab using the specified + def build_slabs(self) -> list[Slab]: + """Build reconstructed Slabs by: + (1) Obtaining the unreconstructed Slab using the specified parameters for the SlabGenerator. - (2) Applying the appropriate lattice transformation in the + (2) Applying the appropriate lattice transformation to the a and b lattice vectors. - (3) Remove any specified sites from both surfaces. - (4) Add any specified sites to both surfaces. + (3) Remove and then add specified sites from both surfaces. Returns: - Slab: The reconstructed slab. + list[Slab]: The reconstructed slabs. """ slabs = self.get_unreconstructed_slabs() + recon_slabs = [] for slab in slabs: - d = get_d(slab) + z_spacing = get_d(slab) top_site = sorted(slab, key=lambda site: site.frac_coords[2])[-1].coords - # Remove any specified sites + # Remove specified sites if "points_to_remove" in self.reconstruction_json: - pts_to_rm = copy.deepcopy(self.reconstruction_json["points_to_remove"]) - for p in pts_to_rm: - p[2] = slab.lattice.get_fractional_coords([top_site[0], top_site[1], top_site[2] + p[2] * d])[2] - cart_point = slab.lattice.get_cartesian_coords(p) - dist = [site.distance_from_point(cart_point) for site in slab] - site1 = dist.index(min(dist)) - slab.symmetrically_remove_atoms([site1]) - - # Add any specified sites + sites_to_rm: list = copy.deepcopy(self.reconstruction_json["points_to_remove"]) + for site in sites_to_rm: + site[2] = slab.lattice.get_fractional_coords( + [top_site[0], top_site[1], top_site[2] + site[2] * z_spacing] + )[2] + + # Find and remove nearest site + cart_point = slab.lattice.get_cartesian_coords(site) + distances: list[float] = [site.distance_from_point(cart_point) for site in slab] + nearest_site = distances.index(min(distances)) + slab.symmetrically_remove_atoms(indices=[nearest_site]) + + # Add specified sites if "points_to_add" in self.reconstruction_json: - pts_to_add = copy.deepcopy(self.reconstruction_json["points_to_add"]) - for p in pts_to_add: - p[2] = slab.lattice.get_fractional_coords([top_site[0], top_site[1], top_site[2] + p[2] * d])[2] - slab.symmetrically_add_atom(slab[0].specie, p) + sites_to_add: list = copy.deepcopy(self.reconstruction_json["points_to_add"]) + for site in sites_to_add: + site[2] = slab.lattice.get_fractional_coords( + [top_site[0], top_site[1], top_site[2] + site[2] * z_spacing] + )[2] + # TODO: see ReconstructionGenerator docstring: + # cannot specify species to add + slab.symmetrically_add_atom(species=slab[0].specie, point=site) slab.reconstruction = self.name slab.recon_trans_matrix = self.trans_matrix - # Get the oriented_unit_cell with the same axb area. + # Get the oriented unit cell with the same a*b area ouc = slab.oriented_unit_cell.copy() ouc.make_supercell(self.trans_matrix) slab.oriented_unit_cell = ouc @@ -1471,394 +1931,214 @@ def build_slabs(self): return recon_slabs - def get_unreconstructed_slabs(self): - """Generates the unreconstructed or pristine super slab.""" - slabs = [] - for slab in SlabGenerator(**self.slabgen_params).get_slabs(): - slab.make_supercell(self.trans_matrix) - slabs.append(slab) - return slabs - - -def get_d(slab): - """Determine the distance of space between - each layer of atoms along c. - """ - sorted_sites = sorted(slab, key=lambda site: site.frac_coords[2]) - for idx, site in enumerate(sorted_sites): - if f"{site.frac_coords[2]:.6f}" != f"{sorted_sites[idx + 1].frac_coords[2]:.6f}": - d = abs(site.frac_coords[2] - sorted_sites[idx + 1].frac_coords[2]) - break - return slab.lattice.get_cartesian_coords([0, 0, d])[2] - + def get_unreconstructed_slabs(self) -> list[Slab]: + """Generate the unreconstructed (super) Slabs. -def is_already_analyzed(miller_index: tuple, miller_list: list, symm_ops: list) -> bool: - """Helper function to check if a given Miller index is - part of the family of indices of any index in a list. - - Args: - miller_index (tuple): The Miller index to analyze - miller_list (list): List of Miller indices. If the given - Miller index belongs in the same family as any of the - indices in this list, return True, else return False - symm_ops (list): Symmetry operations of a - lattice, used to define family of indices - """ - return any(in_coord_list(miller_list, op.operate(miller_index)) for op in symm_ops) + TODO (@DanielYang59): this should be a private method. + """ + return [slab.make_supercell(self.trans_matrix) for slab in SlabGenerator(**self.slabgen_params).get_slabs()] def get_symmetrically_equivalent_miller_indices( - structure, - miller_index, - return_hkil=True, + structure: Structure, + miller_index: tuple[int, ...], + return_hkil: bool = True, system: CrystalSystem | None = None, -): - """Returns all symmetrically equivalent indices for a given structure. Analysis - is based on the symmetry of the reciprocal lattice of the structure. +) -> list: + """Get indices for all equivalent sites within a given structure. + Analysis is based on the symmetry of its reciprocal lattice. Args: - structure (Structure): Structure to analyze + structure (Structure): Structure to analyze. miller_index (tuple): Designates the family of Miller indices - to find. Can be hkl or hkil for hexagonal systems - return_hkil (bool): If true, return hkil form of Miller - index for hexagonal systems, otherwise return hkl - system: If known, specify the crystal system of the structure - so that it does not need to be re-calculated. + to find. Can be hkl or hkil for hexagonal systems. + return_hkil (bool): Whether to return hkil (True) form of Miller + index for hexagonal systems, or hkl (False). + system: The crystal system of the structure. """ - # Change to hkl if hkil because in_coord_list only handles tuples of 3 - miller_index = (miller_index[0], miller_index[1], miller_index[3]) if len(miller_index) == 4 else miller_index - mmi = max(np.abs(miller_index)) - rng = list(range(-mmi, mmi + 1)) - rng.reverse() - - sg = None - if not system: - sg = SpacegroupAnalyzer(structure) - system = sg.get_crystal_system() + # Convert to hkl if hkil, because in_coord_list only handles tuples of 3 + if len(miller_index) >= 3: + _miller_index: tuple[int, int, int] = (miller_index[0], miller_index[1], miller_index[-1]) + max_idx = max(np.abs(miller_index)) + idx_range = list(range(-max_idx, max_idx + 1)) + idx_range.reverse() + + # Skip crystal system analysis if already given + if system: + spg_analyzer = None + else: + spg_analyzer = SpacegroupAnalyzer(structure) + system = spg_analyzer.get_crystal_system() # Get distinct hkl planes from the rhombohedral setting if trigonal if system == "trigonal": - if not sg: - sg = SpacegroupAnalyzer(structure) - prim_structure = sg.get_primitive_standard_structure() + if not spg_analyzer: + spg_analyzer = SpacegroupAnalyzer(structure) + prim_structure = spg_analyzer.get_primitive_standard_structure() symm_ops = prim_structure.lattice.get_recp_symmetry_operation() + else: symm_ops = structure.lattice.get_recp_symmetry_operation() - equivalent_millers = [miller_index] - for miller in itertools.product(rng, rng, rng): - if miller == miller_index: + equivalent_millers: list[tuple[int, int, int]] = [_miller_index] + for miller in itertools.product(idx_range, idx_range, idx_range): + if miller == _miller_index: continue - if any(i != 0 for i in miller): - if is_already_analyzed(miller, equivalent_millers, symm_ops): - equivalent_millers.append(miller) - # include larger Miller indices in the family of planes + if any(idx != 0 for idx in miller): + if _is_in_miller_family(miller, equivalent_millers, symm_ops): + equivalent_millers += [miller] + + # Include larger Miller indices in the family of planes if ( - all(mmi > i for i in np.abs(miller)) + all(max_idx > i for i in np.abs(miller)) and not in_coord_list(equivalent_millers, miller) - and is_already_analyzed(mmi * np.array(miller), equivalent_millers, symm_ops) + and _is_in_miller_family(max_idx * np.array(miller), equivalent_millers, symm_ops) ): - equivalent_millers.append(miller) + equivalent_millers += [miller] - if return_hkil and system in ("trigonal", "hexagonal"): + # Convert hkl to hkil if necessary + if return_hkil and system in {"trigonal", "hexagonal"}: return [(hkl[0], hkl[1], -1 * hkl[0] - hkl[1], hkl[2]) for hkl in equivalent_millers] + return equivalent_millers -def get_symmetrically_distinct_miller_indices(structure, max_index, return_hkil=False): - """Returns all symmetrically distinct indices below a certain max-index for - a given structure. Analysis is based on the symmetry of the reciprocal - lattice of the structure. +def get_symmetrically_distinct_miller_indices( + structure: Structure, + max_index: int, + return_hkil: bool = False, +) -> list: + """Find all symmetrically distinct indices below a certain max-index + for a given structure. Analysis is based on the symmetry of the + reciprocal lattice of the structure. Args: - structure (Structure): input structure. - max_index (int): The maximum index. For example, a max_index of 1 - means that (100), (110), and (111) are returned for the cubic - structure. All other indices are equivalent to one of these. - return_hkil (bool): If true, return hkil form of Miller - index for hexagonal systems, otherwise return hkl + structure (Structure): The input structure. + max_index (int): The maximum index. For example, 1 means that + (100), (110), and (111) are returned for the cubic structure. + All other indices are equivalent to one of these. + return_hkil (bool): Whether to return hkil (True) form of Miller + index for hexagonal systems, or hkl (False). """ - r = list(range(-max_index, max_index + 1)) - r.reverse() + # Get a list of all hkls for conventional (including equivalent) + rng = list(range(-max_index, max_index + 1))[::-1] + conv_hkl_list = [miller for miller in itertools.product(rng, rng, rng) if any(i != 0 for i in miller)] - # First we get a list of all hkls for conventional (including equivalent) - conv_hkl_list = [miller for miller in itertools.product(r, r, r) if any(i != 0 for i in miller)] - - # Sort by the maximum of the absolute values of individual Miller indices so that - # low-index planes are first. This is important for trigonal systems. + # Sort by the maximum absolute values of Miller indices so that + # low-index planes come first. This is important for trigonal systems. conv_hkl_list = sorted(conv_hkl_list, key=lambda x: max(np.abs(x))) - sg = SpacegroupAnalyzer(structure) # Get distinct hkl planes from the rhombohedral setting if trigonal - if sg.get_crystal_system() == "trigonal": - transf = sg.get_conventional_to_primitive_transformation_matrix() - miller_list = [hkl_transformation(transf, hkl) for hkl in conv_hkl_list] + spg_analyzer = SpacegroupAnalyzer(structure) + if spg_analyzer.get_crystal_system() == "trigonal": + transf = spg_analyzer.get_conventional_to_primitive_transformation_matrix() + miller_list: list[tuple[int, int, int]] = [hkl_transformation(transf, hkl) for hkl in conv_hkl_list] prim_structure = SpacegroupAnalyzer(structure).get_primitive_standard_structure() symm_ops = prim_structure.lattice.get_recp_symmetry_operation() + else: miller_list = conv_hkl_list symm_ops = structure.lattice.get_recp_symmetry_operation() - unique_millers, unique_millers_conv = [], [] + unique_millers: list = [] + unique_millers_conv: list = [] - for i, miller in enumerate(miller_list): - d = abs(reduce(gcd, miller)) - miller = tuple(int(i / d) for i in miller) - if not is_already_analyzed(miller, unique_millers, symm_ops): - if sg.get_crystal_system() == "trigonal": + for idx, miller in enumerate(miller_list): + denom = abs(reduce(gcd, miller)) # type: ignore[arg-type] + miller = cast(tuple[int, int, int], tuple(int(idx / denom) for idx in miller)) + if not _is_in_miller_family(miller, unique_millers, symm_ops): + if spg_analyzer.get_crystal_system() == "trigonal": # Now we find the distinct primitive hkls using # the primitive symmetry operations and their # corresponding hkls in the conventional setting unique_millers.append(miller) - d = abs(reduce(gcd, conv_hkl_list[i])) - cmiller = tuple(int(i / d) for i in conv_hkl_list[i]) + denom = abs(reduce(gcd, conv_hkl_list[idx])) # type: ignore[arg-type] + cmiller = tuple(int(idx / denom) for idx in conv_hkl_list[idx]) unique_millers_conv.append(cmiller) else: unique_millers.append(miller) unique_millers_conv.append(miller) - if return_hkil and sg.get_crystal_system() in ["trigonal", "hexagonal"]: + if return_hkil and spg_analyzer.get_crystal_system() in {"trigonal", "hexagonal"}: return [(hkl[0], hkl[1], -1 * hkl[0] - hkl[1], hkl[2]) for hkl in unique_millers_conv] + return unique_millers_conv -def hkl_transformation(transf, miller_index): - """Returns the Miller index from setting - A to B using a transformation matrix +def _is_in_miller_family( + miller_index: tuple[int, int, int], + miller_list: list[tuple[int, int, int]], + symm_ops: list, +) -> bool: + """Helper function to check if the given Miller index belongs + to the same family of any index in the provided list. + Args: - transf (3x3 array): The transformation matrix - that transforms a lattice of A to B - miller_index ([h, k, l]): Miller index to transform to setting B. + miller_index (tuple): The Miller index to analyze. + miller_list (list): List of Miller indices. + symm_ops (list): Symmetry operations for a lattice, + used to define the indices family. """ - # Get a matrix of whole numbers (ints) - - def lcm(a, b): - return a * b // math.gcd(a, b) - - reduced_transf = reduce(lcm, [int(1 / i) for i in itertools.chain(*transf) if i != 0]) * transf - reduced_transf = reduced_transf.astype(int) - - # perform the transformation - t_hkl = np.dot(reduced_transf, miller_index) - d = abs(reduce(gcd, t_hkl)) - t_hkl = np.array([int(i / d) for i in t_hkl]) - - # get mostly positive oriented Miller index - if len([i for i in t_hkl if i < 0]) > 1: - t_hkl *= -1 - - return tuple(t_hkl) + return any(in_coord_list(miller_list, op.operate(miller_index)) for op in symm_ops) -def generate_all_slabs( - structure, - max_index, - min_slab_size, - min_vacuum_size, - bonds=None, - tol=0.1, - ftol=0.1, - max_broken_bonds=0, - lll_reduce=False, - center_slab=False, - primitive=True, - max_normal_search=None, - symmetrize=False, - repair=False, - include_reconstructions=False, - in_unit_planes=False, -): - """A function that finds all different slabs up to a certain miller index. - Slabs oriented under certain Miller indices that are equivalent to other - slabs in other Miller indices are filtered out using symmetry operations - to get rid of any repetitive slabs. For example, under symmetry operations, - CsCl has equivalent slabs in the (0,0,1), (0,1,0), and (1,0,0) direction. +def hkl_transformation( + transf: np.ndarray, + miller_index: tuple[int, int, int], +) -> tuple[int, int, int]: + """Transform the Miller index from setting A to B with a transformation matrix. Args: - structure (Structure): Initial input structure. Note that to - ensure that the miller indices correspond to usual - crystallographic definitions, you should supply a conventional - unit cell structure. - max_index (int): The maximum Miller index to go up to. - min_slab_size (float): In Angstroms - min_vacuum_size (float): In Angstroms - bonds ({(specie1, specie2): max_bond_dist}: bonds are - specified as a dict of tuples: float of specie1, specie2 - and the max bonding distance. For example, PO4 groups may be - defined as {("P", "O"): 3}. - tol (float): General tolerance parameter for getting primitive - cells and matching structures - ftol (float): Threshold parameter in fcluster in order to check - if two atoms are lying on the same plane. Default thresh set - to 0.1 Angstrom in the direction of the surface normal. - max_broken_bonds (int): Maximum number of allowable broken bonds - for the slab. Use this to limit # of slabs (some structures - may have a lot of slabs). Defaults to zero, which means no - defined bonds must be broken. - lll_reduce (bool): Whether to perform an LLL reduction on the - eventual structure. - center_slab (bool): Whether to center the slab in the cell with - equal vacuum spacing from the top and bottom. - primitive (bool): Whether to reduce any generated slabs to a - primitive cell (this does **not** mean the slab is generated - from a primitive cell, it simply means that after slab - generation, we attempt to find shorter lattice vectors, - which lead to less surface area and smaller cells). - max_normal_search (int): If set to a positive integer, the code will - conduct a search for a normal lattice vector that is as - perpendicular to the surface as possible by considering - multiples linear combinations of lattice vectors up to - max_normal_search. This has no bearing on surface energies, - but may be useful as a preliminary step to generating slabs - for absorption and other sizes. It is typical that this will - not be the smallest possible cell for simulation. Normality - is not guaranteed, but the oriented cell will have the c - vector as normal as possible (within the search range) to the - surface. A value of up to the max absolute Miller index is - usually sufficient. - symmetrize (bool): Whether or not to ensure the surfaces of the - slabs are equivalent. - repair (bool): Whether to repair terminations with broken bonds - or just omit them - include_reconstructions (bool): Whether to include reconstructed - slabs available in the reconstructions_archive.json file. Defaults to False. - in_unit_planes (bool): Whether to generate slabs in units of the primitive - cell's c lattice vector. This is useful for generating slabs with - a specific number of layers, as the number of layers will be - independent of the Miller index. Defaults to False. - in_unit_planes (bool): Whether to set min_slab_size and min_vac_size - in units of hkl planes (True) or Angstrom (False, the default). Setting in - units of planes is useful for ensuring some slabs have a certain n_layer of - atoms. e.g. for Cs (100), a 10 Ang slab will result in a slab with only 2 - layer of atoms, whereas Fe (100) will have more layer of atoms. By using units - of hkl planes instead, we ensure both slabs have the same number of atoms. The - slab thickness will be in min_slab_size/math.ceil(self._proj_height/dhkl) - multiples of oriented unit cells. + transf (3x3 array): The matrix that transforms a lattice from A to B. + miller_index (tuple[int, int, int]): The Miller index [h, k, l] to transform. """ - all_slabs = [] - - for miller in get_symmetrically_distinct_miller_indices(structure, max_index): - gen = SlabGenerator( - structure, - miller, - min_slab_size, - min_vacuum_size, - lll_reduce=lll_reduce, - center_slab=center_slab, - primitive=primitive, - max_normal_search=max_normal_search, - in_unit_planes=in_unit_planes, - ) - slabs = gen.get_slabs( - bonds=bonds, - tol=tol, - ftol=ftol, - symmetrize=symmetrize, - max_broken_bonds=max_broken_bonds, - repair=repair, - ) - - if len(slabs) > 0: - logger.debug(f"{miller} has {len(slabs)} slabs... ") - all_slabs.extend(slabs) - - if include_reconstructions: - sg = SpacegroupAnalyzer(structure) - symbol = sg.get_space_group_symbol() - # enumerate through all posisble reconstructions in the - # archive available for this particular structure (spacegroup) - for name, instructions in reconstructions_archive.items(): - if "base_reconstruction" in instructions: - instructions = reconstructions_archive[instructions["base_reconstruction"]] - if instructions["spacegroup"]["symbol"] == symbol: - # check if this reconstruction has a max index - # equal or less than the given max index - if max(instructions["miller_index"]) > max_index: - continue - recon = ReconstructionGenerator(structure, min_slab_size, min_vacuum_size, name) - all_slabs.extend(recon.build_slabs()) - - return all_slabs + def math_lcm(a: int, b: int) -> int: + """Calculate the least common multiple.""" + return a * b // math.gcd(a, b) -def get_slab_regions(slab, blength=3.5): - """Function to get the ranges of the slab regions. Useful for discerning where - the slab ends and vacuum begins if the slab is not fully within the cell - Args: - slab (Structure): Structure object modelling the surface - blength (float, Ang): The bondlength between atoms. You generally - want this value to be larger than the actual bondlengths in - order to find atoms that are part of the slab. - """ - fcoords, indices, all_indices = [], [], [] - for site in slab: - # find sites with c < 0 (noncontiguous) - neighbors = slab.get_neighbors(site, blength, include_index=True, include_image=True) - for nn in neighbors: - if nn[0].frac_coords[2] < 0: - # sites are noncontiguous within cell - fcoords.append(nn[0].frac_coords[2]) - indices.append(nn[-2]) - if nn[-2] not in all_indices: - all_indices.append(nn[-2]) + # Convert the elements of the transformation matrix to integers + reduced_transf = reduce(math_lcm, [int(1 / i) for i in itertools.chain(*transf) if i != 0]) * transf + reduced_transf = reduced_transf.astype(int) - if fcoords: - # If slab is noncontiguous, locate the lowest - # site within the upper region of the slab - while fcoords: - last_fcoords = copy.copy(fcoords) - last_indices = copy.copy(indices) - site = slab[indices[fcoords.index(min(fcoords))]] - neighbors = slab.get_neighbors(site, blength, include_index=True, include_image=True) - fcoords, indices = [], [] - for nn in neighbors: - if 1 > nn[0].frac_coords[2] > 0 and nn[0].frac_coords[2] < site.frac_coords[2]: - # sites are noncontiguous within cell - fcoords.append(nn[0].frac_coords[2]) - indices.append(nn[-2]) - if nn[-2] not in all_indices: - all_indices.append(nn[-2]) + # Perform the transformation + transf_hkl = np.dot(reduced_transf, miller_index) + divisor = abs(reduce(gcd, transf_hkl)) # type: ignore[arg-type] + transf_hkl = np.array([idx // divisor for idx in transf_hkl]) - # Now locate the highest site within the lower region of the slab - upper_fcoords = [] - for site in slab: - if all(nn.index not in all_indices for nn in slab.get_neighbors(site, blength)): - upper_fcoords.append(site.frac_coords[2]) - coords = copy.copy(last_fcoords) if not fcoords else copy.copy(fcoords) - min_top = slab[last_indices[coords.index(min(coords))]].frac_coords[2] - ranges = [[0, max(upper_fcoords)], [min_top, 1]] - else: - # If the entire slab region is within the slab cell, just - # set the range as the highest and lowest site in the slab - sorted_sites = sorted(slab, key=lambda site: site.frac_coords[2]) - ranges = [[sorted_sites[0].frac_coords[2], sorted_sites[-1].frac_coords[2]]] + # Get positive Miller index + if len([i for i in transf_hkl if i < 0]) > 1: + transf_hkl *= -1 - return ranges + return tuple(transf_hkl) # type: ignore[return-value] -def miller_index_from_sites(lattice, coords, coords_are_cartesian=True, round_dp=4, verbose=True): - """Get the Miller index of a plane from a list of site coordinates. +def miller_index_from_sites( + lattice: Lattice | ArrayLike, + coords: ArrayLike, + coords_are_cartesian: bool = True, + round_dp: int = 4, + verbose: bool = True, +) -> tuple[int, int, int]: + """Get the Miller index of a plane, determined by a given set of coordinates. - A minimum of 3 sets of coordinates are required. If more than 3 sets of - coordinates are given, the best plane that minimises the distance to all - points will be calculated. + A minimum of 3 sets of coordinates are required. If more than 3 + coordinates are given, the plane that minimises the distance to all + sites will be calculated. Args: - lattice (list or Lattice): A 3x3 lattice matrix or `Lattice` object (for - example obtained from Structure.lattice). - coords (iterable): A list or numpy array of coordinates. Can be - Cartesian or fractional coordinates. If more than three sets of - coordinates are provided, the best plane that minimises the - distance to all sites will be calculated. + lattice (matrix or Lattice): A 3x3 lattice matrix or `Lattice` object. + coords (ArrayLike): A list or numpy array of coordinates. Can be + Cartesian or fractional coordinates. coords_are_cartesian (bool, optional): Whether the coordinates are - in Cartesian space. If using fractional coordinates set to False. + in Cartesian coordinates, or fractional (False). round_dp (int, optional): The number of decimal places to round the - miller index to. + Miller index to. verbose (bool, optional): Whether to print warnings. Returns: - (tuple): The Miller index. + tuple[int]: The Miller index. """ if not isinstance(lattice, Lattice): lattice = Lattice(lattice) @@ -1869,59 +2149,3 @@ def miller_index_from_sites(lattice, coords, coords_are_cartesian=True, round_dp round_dp=round_dp, verbose=verbose, ) - - -def center_slab(slab): - """The goal here is to ensure the center of the slab region - is centered close to c=0.5. This makes it easier to - find the surface sites and apply operations like doping. - - There are three cases where the slab in not centered: - - 1. The slab region is completely between two vacuums in the - box but not necessarily centered. We simply shift the - slab by the difference in its center of mass and 0.5 - along the c direction. - - 2. The slab completely spills outside the box from the bottom - and into the top. This makes it incredibly difficult to - locate surface sites. We iterate through all sites that - spill over (z>c) and shift all sites such that this specific - site is now on the other side. Repeat for all sites with z>c. - - 3. This is a simpler case of scenario 2. Either the top or bottom - slab sites are at c=0 or c=1. Treat as scenario 2. - - Args: - slab (Slab): Slab structure to center - - Returns: - Returns a centered slab structure - """ - # get a reasonable r cutoff to sample neighbors - bdists = sorted(nn[1] for nn in slab.get_neighbors(slab[0], 10) if nn[1] > 0) - r = bdists[0] * 3 - - all_indices = [idx for idx, site in enumerate(slab)] - - # check if structure is case 2 or 3, shift all the - # sites up to the other side until it is case 1 - for site in slab: - if any(nn[1] > slab.lattice.c for nn in slab.get_neighbors(site, r)): - shift = 1 - site.frac_coords[2] + 0.05 - slab.translate_sites(all_indices, [0, 0, shift]) - - # now the slab is case 1, shift the center of mass of the slab to 0.5 - weights = [s.species.weight for s in slab] - center_of_mass = np.average(slab.frac_coords, weights=weights, axis=0) - shift = 0.5 - center_of_mass[2] - slab.translate_sites(all_indices, [0, 0, shift]) - - return slab - - -def _reduce_vector(vector): - # small function to reduce vectors - - d = abs(reduce(gcd, vector)) - return tuple(int(i / d) for i in vector) diff --git a/pymatgen/core/tensors.py b/pymatgen/core/tensors.py index 9aa04d755d6..684cc15520e 100644 --- a/pymatgen/core/tensors.py +++ b/pymatgen/core/tensors.py @@ -25,6 +25,8 @@ if TYPE_CHECKING: from collections.abc import Sequence + from typing_extensions import Self + from pymatgen.core import Structure __author__ = "Joseph Montoya" @@ -44,14 +46,14 @@ class Tensor(np.ndarray, MSONable): symbol = "T" - def __new__(cls, input_array, vscale=None, check_rank=None): + def __new__(cls, input_array, vscale=None, check_rank=None) -> Self: """Create a Tensor object. Note that the constructor uses __new__ rather than __init__ according to the standard method of subclassing numpy ndarrays. Args: input_array: (array-like with shape 3^N): array-like representing - a tensor quantity in standard (i. e. non-voigt) notation + a tensor quantity in standard (i. e. non-Voigt) notation vscale: (N x M array-like): a matrix corresponding to the coefficients of the Voigt-notation tensor check_rank: (int): If not None, checks that input_array's rank == check_rank. @@ -123,8 +125,8 @@ def rotate(self, matrix, tol: float = 1e-3): matrix = SquareTensor(matrix) if not matrix.is_rotation(tol): raise ValueError("Rotation matrix is not valid.") - sop = SymmOp.from_rotation_and_translation(matrix, [0.0, 0.0, 0.0]) - return self.transform(sop) + symm_op = SymmOp.from_rotation_and_translation(matrix, [0.0, 0.0, 0.0]) + return self.transform(symm_op) def einsum_sequence(self, other_arrays, einsum_string=None): """Calculates the result of an einstein summation expression.""" @@ -152,9 +154,9 @@ def project(self, n): Args: n (3x1 array-like): direction to project onto - Returns (float): - scalar value corresponding to the projection of - the tensor into the vector + Returns: + float: scalar value corresponding to the projection of + the tensor into the vector """ n = get_uvec(n) return self.einsum_sequence([n] * self.rank) @@ -254,8 +256,8 @@ def round(self, decimals=0): If decimals is negative, it specifies the number of positions to the left of the decimal point. - Returns (Tensor): - rounded tensor of same type + Returns: + Tensor: rounded tensor of same type """ return type(self)(np.round(self, decimals=decimals)) @@ -361,7 +363,7 @@ def get_voigt_dict(rank): return vdict @classmethod - def from_voigt(cls, voigt_input): + def from_voigt(cls, voigt_input) -> Self: """Constructor based on the voigt notation vector or matrix. Args: @@ -510,7 +512,7 @@ def from_values_indices( voigt_rank=None, vsym=True, verbose=False, - ): + ) -> Self: """Creates a tensor from values and indices, with options for populating the remainder of the tensor. @@ -533,18 +535,21 @@ def from_values_indices( # TODO: refactor rank inheritance to make this easier indices = np.array(indices) if voigt_rank: - shape = [3] * (voigt_rank % 2) + [6] * (voigt_rank // 2) + shape = np.array([3] * (voigt_rank % 2) + [6] * (voigt_rank // 2)) else: shape = np.ceil(np.max(indices + 1, axis=0) / 3.0) * 3 + base = np.zeros(shape.astype(int)) for v, idx in zip(values, indices): base[tuple(idx)] = v obj = cls.from_voigt(base) if 6 in shape else cls(base) + if populate: assert structure, "Populate option must include structure input" obj = obj.populate(structure, vsym=vsym, verbose=verbose) elif structure: obj = obj.fit_to_structure(structure) + return obj def populate( @@ -634,12 +639,11 @@ def as_dict(self, voigt: bool = False) -> dict: """Serializes the tensor object. Args: - voigt (bool): flag for whether to store entries in - Voigt notation. Defaults to false, as information - may be lost in conversion. + voigt (bool): flag for whether to store entries in Voigt notation. + Defaults to false, as information may be lost in conversion. - Returns (dict): - serialized format tensor object + Returns: + dict: serialized format tensor object """ input_array = self.voigt if voigt else self dct = { @@ -652,15 +656,15 @@ def as_dict(self, voigt: bool = False) -> dict: return dct @classmethod - def from_dict(cls, d) -> Tensor: + def from_dict(cls, dct: dict) -> Self: """Instantiate Tensors from dicts (using MSONable API). Returns: Tensor: hydrated tensor object """ - if d.get("voigt"): - return cls.from_voigt(d["input_array"]) - return cls(d["input_array"]) + if dct.get("voigt"): + return cls.from_voigt(dct["input_array"]) + return cls(dct["input_array"]) class TensorCollection(collections.abc.Sequence, MSONable): @@ -669,8 +673,10 @@ class TensorCollection(collections.abc.Sequence, MSONable): """ def __init__(self, tensor_list: Sequence, base_class=Tensor) -> None: - """:param tensor_list: List of tensors. - :param base_class: Class to be used. + """ + Args: + tensor_list: List of tensors. + base_class: Class to be used. """ self.tensors = [tensor if isinstance(tensor, base_class) else base_class(tensor) for tensor in tensor_list] @@ -684,7 +690,9 @@ def __iter__(self): return iter(self.tensors) def zeroed(self, tol: float = 1e-3): - """:param tol: Tolerance + """ + Args: + tol: Tolerance. Returns: TensorCollection where small values are set to 0. @@ -694,7 +702,8 @@ def zeroed(self, tol: float = 1e-3): def transform(self, symm_op): """Transforms TensorCollection with a symmetry operation. - :param symm_op: SymmetryOperation. + Args: + symm_op: SymmetryOperation. Returns: TensorCollection. @@ -704,8 +713,9 @@ def transform(self, symm_op): def rotate(self, matrix, tol: float = 1e-3): """Rotates TensorCollection. - :param matrix: Rotation matrix. - :param tol: tolerance. + Args: + matrix: Rotation matrix. + tol: tolerance. Returns: TensorCollection. @@ -717,8 +727,10 @@ def symmetrized(self): """TensorCollection where all tensors are symmetrized.""" return type(self)([tensor.symmetrized for tensor in self]) - def is_symmetric(self, tol: float = 1e-5): - """:param tol: tolerance + def is_symmetric(self, tol: float = 1e-5) -> bool: + """ + Args: + tol: tolerance. Returns: Whether all tensors are symmetric. @@ -728,8 +740,9 @@ def is_symmetric(self, tol: float = 1e-5): def fit_to_structure(self, structure: Structure, symprec: float = 0.1): """Fits all tensors to a Structure. - :param structure: Structure - :param symprec: symmetry precision. + Args: + structure: Structure + symprec: symmetry precision. Returns: TensorCollection. @@ -737,8 +750,10 @@ def fit_to_structure(self, structure: Structure, symprec: float = 0.1): return type(self)([tensor.fit_to_structure(structure, symprec) for tensor in self]) def is_fit_to_structure(self, structure: Structure, tol: float = 1e-2): - """:param structure: Structure - :param tol: tolerance + """ + Args: + structure: Structure + tol: tolerance. Returns: Whether all tensors are fitted to Structure. @@ -755,8 +770,10 @@ def ranks(self): """Ranks for all tensors.""" return [tensor.rank for tensor in self] - def is_voigt_symmetric(self, tol: float = 1e-6): - """:param tol: tolerance + def is_voigt_symmetric(self, tol: float = 1e-6) -> bool: + """ + Args: + tol: tolerance. Returns: Whether all tensors are voigt symmetric. @@ -764,11 +781,12 @@ def is_voigt_symmetric(self, tol: float = 1e-6): return all(tensor.is_voigt_symmetric(tol) for tensor in self) @classmethod - def from_voigt(cls, voigt_input_list, base_class=Tensor): + def from_voigt(cls, voigt_input_list, base_class=Tensor) -> Self: """Creates TensorCollection from voigt form. - :param voigt_input_list: List of voigt tensors - :param base_class: Class for tensor. + Args: + voigt_input_list: List of voigt tensors + base_class: Class for tensor. Returns: TensorCollection. @@ -778,9 +796,10 @@ def from_voigt(cls, voigt_input_list, base_class=Tensor): def convert_to_ieee(self, structure: Structure, initial_fit=True, refine_rotation=True): """Convert all tensors to IEEE. - :param structure: Structure - :param initial_fit: Whether to perform an initial fit. - :param refine_rotation: Whether to refine the rotation. + Args: + structure: Structure + initial_fit: Whether to perform an initial fit. + refine_rotation: Whether to refine the rotation. Returns: TensorCollection. @@ -790,8 +809,9 @@ def convert_to_ieee(self, structure: Structure, initial_fit=True, refine_rotatio def round(self, *args, **kwargs): """Round all tensors. - :param args: Passthrough to Tensor.round - :param kwargs: Passthrough to Tensor.round + Args: + args: Passthrough to Tensor.round + kwargs: Passthrough to Tensor.round Returns: TensorCollection. @@ -804,7 +824,9 @@ def voigt_symmetrized(self): return type(self)([tensor.voigt_symmetrized for tensor in self]) def as_dict(self, voigt=False): - """:param voigt: Whether to use Voigt form. + """ + Args: + voigt: Whether to use Voigt form. Returns: Dict representation of TensorCollection. @@ -820,10 +842,11 @@ def as_dict(self, voigt=False): return dct @classmethod - def from_dict(cls, dct: dict) -> TensorCollection: + def from_dict(cls, dct: dict) -> Self: """Creates TensorCollection from dict. - :param dct: dict + Args: + dct: dict Returns: TensorCollection @@ -839,7 +862,7 @@ class SquareTensor(Tensor): (stress, strain etc.). """ - def __new__(cls, input_array, vscale=None): + def __new__(cls, input_array, vscale=None) -> Self: """Create a SquareTensor object. Note that the constructor uses __new__ rather than __init__ according to the standard method of subclassing numpy ndarrays. Error is thrown when the class is initialized with non-square matrix. diff --git a/pymatgen/core/trajectory.py b/pymatgen/core/trajectory.py index e7c32dcb59c..966781c7654 100644 --- a/pymatgen/core/trajectory.py +++ b/pymatgen/core/trajectory.py @@ -9,7 +9,7 @@ from collections.abc import Iterator, Sequence from fnmatch import fnmatch from pathlib import Path -from typing import Any, Union +from typing import TYPE_CHECKING, Any, Union import numpy as np from monty.io import zopen @@ -19,6 +19,10 @@ from pymatgen.io.ase import AseAtomsAdaptor from pymatgen.io.vasp.outputs import Vasprun, Xdatcar +if TYPE_CHECKING: + from typing_extensions import Self + + __author__ = "Eric Sivonxay, Shyam Dwaraknath, Mingjian Wen, Evan Spotte-Smith" __version__ = "0.1" __date__ = "Jun 29, 2022" @@ -306,7 +310,7 @@ def __getitem__(self, frames: int | slice | list[int]) -> Molecule | Structure | Args: frames: Indices of the trajectory to return. - Return: + Returns: Subset of trajectory """ # Convert to position mode if not already @@ -467,7 +471,7 @@ def as_dict(self) -> dict: } @classmethod - def from_structures(cls, structures: list[Structure], constant_lattice: bool = True, **kwargs) -> Trajectory: + def from_structures(cls, structures: list[Structure], constant_lattice: bool = True, **kwargs) -> Self: """Create trajectory from a list of structures. Note: Assumes no atoms removed during simulation. @@ -500,7 +504,7 @@ def from_structures(cls, structures: list[Structure], constant_lattice: bool = T ) @classmethod - def from_molecules(cls, molecules: list[Molecule], **kwargs) -> Trajectory: + def from_molecules(cls, molecules: list[Molecule], **kwargs) -> Self: """Create trajectory from a list of molecules. Note: Assumes no atoms removed during simulation. @@ -526,7 +530,7 @@ def from_molecules(cls, molecules: list[Molecule], **kwargs) -> Trajectory: ) @classmethod - def from_file(cls, filename: str | Path, constant_lattice: bool = True, **kwargs) -> Trajectory: + def from_file(cls, filename: str | Path, constant_lattice: bool = True, **kwargs) -> Self: """Create trajectory from XDATCAR, vasprun.xml file, or ASE trajectory (.traj) file. Args: diff --git a/pymatgen/core/units.py b/pymatgen/core/units.py index 27844808af0..9d5b1bb87a7 100644 --- a/pymatgen/core/units.py +++ b/pymatgen/core/units.py @@ -9,14 +9,17 @@ from __future__ import annotations import collections -import numbers import re from functools import partial -from typing import Any +from numbers import Number +from typing import TYPE_CHECKING, Any import numpy as np import scipy.constants as const +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "Shyue Ping Ong, Matteo Giantomassi" __copyright__ = "Copyright 2011, The Materials Project" __version__ = "1.0" @@ -161,8 +164,6 @@ class Unit(collections.abc.Mapping): Only integer powers are supported for units. """ - Error = UnitError - def __init__(self, unit_def) -> None: """Constructs a unit. @@ -224,17 +225,17 @@ def as_base_units(self): """Converts all units to base SI units, including derived units. Returns: - (base_units_dict, scaling factor). base_units_dict will not - contain any constants, which are gathered in the scaling factor. + tuple[dict, float]: (base_units_dict, scaling factor). base_units_dict will not + contain any constants, which are gathered in the scaling factor. """ b = collections.defaultdict(int) factor = 1 for k, v in self.items(): derived = False - for d in DERIVED_UNITS.values(): - if k in d: - for k2, v2 in d[k].items(): - if isinstance(k2, numbers.Number): + for dct in DERIVED_UNITS.values(): + if k in dct: + for k2, v2 in dct[k].items(): + if isinstance(k2, Number): factor *= k2 ** (v2 * v) else: b[k2] += v2 * v @@ -289,40 +290,7 @@ class FloatWithUnit(float): 32.932522246000005 eV """ - Error = UnitError - - @classmethod - def from_str(cls, s): - """Parse string to FloatWithUnit. - - Example: Memory.from_str("1. Mb") - """ - # Extract num and unit string. - s = s.strip() - for idx, char in enumerate(s): # noqa: B007 - if char.isalpha() or char.isspace(): - break - else: - raise Exception(f"Unit is missing in string {s}") - num, unit = float(s[:idx]), s[idx:] - - # Find unit type (set it to None if it cannot be detected) - for unit_type, d in BASE_UNITS.items(): # noqa: B007 - if unit in d: - break - else: - unit_type = None - - return cls(num, unit, unit_type=unit_type) - - def __new__(cls, val, unit, unit_type=None): - """Overrides __new__ since we are subclassing a Python primitive/.""" - new = float.__new__(cls, val) - new._unit = Unit(unit) - new._unit_type = unit_type - return new - - def __init__(self, val, unit, unit_type=None) -> None: + def __init__(self, val: float | Number, unit: str, unit_type: str | None = None) -> None: """Initializes a float with unit. Args: @@ -335,6 +303,13 @@ def __init__(self, val, unit, unit_type=None) -> None: self._unit = Unit(unit) self._unit_type = unit_type + def __new__(cls, val, unit, unit_type=None) -> Self: + """Overrides __new__ since we are subclassing a Python primitive.""" + new = float.__new__(cls, val) + new._unit = Unit(unit) + new._unit_type = unit_type + return new + def __str__(self) -> str: return f"{super().__str__()} {self._unit}" @@ -402,7 +377,7 @@ def __setstate__(self, state): self._unit = state["_unit"] @property - def unit_type(self) -> str: + def unit_type(self) -> str | None: """The type of unit. Energy, Charge, etc.""" return self._unit_type @@ -411,6 +386,26 @@ def unit(self) -> Unit: """The unit, e.g., "eV".""" return self._unit + @classmethod + def from_str(cls, s: str) -> Self: + """Parse string to FloatWithUnit. + Example: Memory.from_str("1. Mb"). + """ + # Extract num and unit string. + s = s.strip() + for _idx, char in enumerate(s): + if char.isalpha() or char.isspace(): + break + else: + raise ValueError(f"Unit is missing in string {s}") + num, unit = float(s[:_idx]), s[_idx:] + + # Find unit type (set it to None if it cannot be detected) + for unit_type, dct in BASE_UNITS.items(): + if unit in dct: + return cls(num, unit, unit_type=unit_type) + return cls(num, unit, unit_type=None) + def to(self, new_unit): """Conversion to a new_unit. Right now, only supports 1 to 1 mapping of units of each type. @@ -466,9 +461,7 @@ class ArrayWithUnit(np.ndarray): array([ 28.21138386, 56.42276772]) eV """ - Error = UnitError - - def __new__(cls, input_array, unit, unit_type=None): + def __new__(cls, input_array, unit, unit_type=None) -> Self: """Override __new__.""" # Input array is an already formed ndarray instance # We first cast to be our class type @@ -711,10 +704,12 @@ def obj_with_unit(obj: Any, unit: str) -> FloatWithUnit | ArrayWithUnit | dict[s """ unit_type = _UNAME2UTYPE[unit] - if isinstance(obj, numbers.Number): + if isinstance(obj, Number): return FloatWithUnit(obj, unit=unit, unit_type=unit_type) + if isinstance(obj, collections.abc.Mapping): - return {k: obj_with_unit(v, unit) for k, v in obj.items()} # type: ignore + return {k: obj_with_unit(v, unit) for k, v in obj.items()} # type: ignore[misc] + return ArrayWithUnit(obj, unit=unit, unit_type=unit_type) @@ -751,7 +746,7 @@ def wrapped_f(*args, **kwargs): if isinstance(val, collections.abc.Mapping): for k, v in val.items(): val[k] = FloatWithUnit(v, unit_type=unit_type, unit=unit) - elif isinstance(val, numbers.Number): + elif isinstance(val, Number): return FloatWithUnit(val, unit_type=unit_type, unit=unit) elif val is None: pass diff --git a/pymatgen/core/xcfunc.py b/pymatgen/core/xcfunc.py index 2354be2e3fe..e8b6b935791 100644 --- a/pymatgen/core/xcfunc.py +++ b/pymatgen/core/xcfunc.py @@ -1,14 +1,18 @@ -"""This module provides.""" +"""This module provides class for XC correlation functional.""" from __future__ import annotations from collections import namedtuple +from typing import TYPE_CHECKING from monty.functools import lazy_property from monty.json import MSONable from pymatgen.core.libxcfunc import LibxcFunc +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "Matteo Giantomassi" __copyright__ = "Copyright 2016, The Materials Project" __version__ = "3.0.0" # The libxc version used to generate this file! @@ -111,13 +115,29 @@ class XcFunc(MSONable): del xcf + def __init__(self, xc: LibxcFunc | None = None, x: LibxcFunc | None = None, c: LibxcFunc | None = None) -> None: + """ + Args: + xc: LibxcFunc for XC functional. + x: LibxcFunc for exchange part. Mutually exclusive with xc. + c: LibxcFunc for correlation part. Mutually exclusive with xc. + """ + # Consistency check + if xc is None: + if x is None or c is None: + raise ValueError("x or c must be specified when xc is None") + elif x is not None or c is not None: + raise ValueError("x and c should be None when xc is specified") + + self.xc, self.x, self.c = xc, x, c + @classmethod - def aliases(cls): + def aliases(cls) -> list[str]: """List of registered names.""" return [nt.name for nt in cls.defined_aliases.values()] @classmethod - def asxc(cls, obj): + def asxc(cls, obj) -> Self: """Convert object into Xcfunc.""" if isinstance(obj, cls): return obj @@ -126,9 +146,8 @@ def asxc(cls, obj): raise TypeError(f"Don't know how to convert <{type(obj)}:{obj}> to Xcfunc") @classmethod - def from_abinit_ixc(cls, ixc): + def from_abinit_ixc(cls, ixc: int) -> Self | None: """Build the object from Abinit ixc (integer).""" - ixc = int(ixc) if ixc == 0: return None if ixc > 0: @@ -146,12 +165,12 @@ def from_abinit_ixc(cls, ixc): return cls(x=x, c=c) @classmethod - def from_name(cls, name): + def from_name(cls, name: str) -> Self: """Build the object from one of the registered names.""" return cls.from_type_name(None, name) @classmethod - def from_type_name(cls, typ, name): + def from_type_name(cls, typ: str | None, name: str) -> Self: """Build the object from (type, name).""" # Try aliases first. for k, nt in cls.defined_aliases.items(): @@ -168,12 +187,11 @@ def from_type_name(cls, typ, name): # name="GGA_X_PBE+GGA_C_PBE" or name=""LDA_XC_TETER93" if "+" in name: x, c = (s.strip() for s in name.split("+")) - x, c = LibxcFunc[x], LibxcFunc[c] - return cls(x=x, c=c) - xc = LibxcFunc[name] - return cls(xc=xc) + return cls(x=LibxcFunc[x], c=LibxcFunc[c]) + + return cls(xc=LibxcFunc[name]) - def as_dict(self): + def as_dict(self) -> dict: """Serialize to MSONable dict representation e.g. to write to disk as JSON.""" dct = {"@module": type(self).__module__, "@class": type(self).__name__} if self.x is not None: @@ -185,53 +203,48 @@ def as_dict(self): return dct @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """Deserialize from MSONable dict representation.""" return cls(xc=dct.get("xc"), x=dct.get("x"), c=dct.get("c")) - def __init__(self, xc=None, x=None, c=None) -> None: - """ - Args: - xc: LibxcFunc for XC functional. - x: LibxcFunc for exchange part. Mutually exclusive with xc. - c: LibxcFunc for correlation part. Mutually exclusive with xc. - """ - # Consistency check - if xc is None: - if x is None or c is None: - raise ValueError("x or c must be specified when xc is None") - elif x is not None or c is not None: - raise ValueError("x and c should be None when xc is specified") - - self.xc, self.x, self.c = xc, x, c - @lazy_property - def type(self): + def type(self) -> str | None: """The type of the functional.""" - if self.xc in self.defined_aliases: - return self.defined_aliases[self.xc].type + if self.xc in self.defined_aliases and self.xc is not None: + return self.defined_aliases[self.xc].type # type: ignore[index] + xc = self.x, self.c if xc in self.defined_aliases: - return self.defined_aliases[xc].type + return self.defined_aliases[xc].type # type: ignore[index] # If self is not in defined_aliases, use LibxcFunc family if self.xc is not None: return self.xc.family - return f"{self.x.family}+{self.c.family}" + + if self.x is not None and self.c is not None: + return f"{self.x.family}+{self.c.family}" + + return None @lazy_property - def name(self) -> str: + def name(self) -> str | None: """The name of the functional. If the functional is not found in the aliases, the string has the form X_NAME+C_NAME. """ if self.xc in self.defined_aliases: - return self.defined_aliases[self.xc].name + return self.defined_aliases[self.xc].name # type: ignore[index] + xc = (self.x, self.c) if xc in self.defined_aliases: - return self.defined_aliases[xc].name + return self.defined_aliases[xc].name # type: ignore[index] + if self.xc is not None: return self.xc.name - return f"{self.x.name}+{self.c.name}" + + if self.x is not None and self.c is not None: + return f"{self.x.name}+{self.c.name}" + + return None def __repr__(self) -> str: return str(self.name) diff --git a/pymatgen/electronic_structure/bandstructure.py b/pymatgen/electronic_structure/bandstructure.py index 9ef3d64377d..e5f7e4effa9 100644 --- a/pymatgen/electronic_structure/bandstructure.py +++ b/pymatgen/electronic_structure/bandstructure.py @@ -7,7 +7,7 @@ import math import re import warnings -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np from monty.json import MSONable @@ -17,6 +17,9 @@ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from pymatgen.util.coord import pbc_diff +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "Geoffroy Hautier, Shyue Ping Ong, Michael Kocher" __copyright__ = "Copyright 2012, The Materials Project" __version__ = "1.0" @@ -130,7 +133,7 @@ def as_dict(self) -> dict[str, Any]: } @classmethod - def from_dict(cls, dct) -> Kpoint: + def from_dict(cls, dct: dict) -> Self: """Create from dict. Args: @@ -187,7 +190,7 @@ def __init__( lattice: The reciprocal lattice as a pymatgen Lattice object. Pymatgen uses the physics convention of reciprocal lattice vectors WITH a 2*pi coefficient - efermi (float): fermi energy + efermi (float): Fermi energy labels_dict: (dict) of {} this links a kpoint (in frac coords or Cartesian coordinates depending on the coords) to a label. coords_are_cartesian: Whether coordinates are cartesian. @@ -211,7 +214,7 @@ def __init__( labels_dict = {} if len(self.projections) != 0 and self.structure is None: - raise Exception("if projections are provided a structure object needs also to be given") + raise RuntimeError("if projections are provided a structure object is also required") for k in kpoints: # let see if this kpoint has been assigned a label @@ -234,8 +237,8 @@ def get_projection_on_elements(self): """Method returning a dictionary of projections on elements. Returns: - a dictionary in the {Spin.up:[][{Element:values}], - Spin.down:[][{Element:values}]} format + a dictionary in the {Spin.up:[][{Element: [values]}], + Spin.down:[][{Element: [values]}]} format if there is no projections in the band structure returns an empty dict """ @@ -266,8 +269,7 @@ def get_projections_on_elements_and_orbitals(self, el_orb_spec): A dictionary of projections on elements in the {Spin.up:[][{Element:{orb:values}}], Spin.down:[][{Element:{orb:values}}]} format - if there is no projections in the band structure returns an empty - dict. + if there is no projections in the band structure returns an empty dict. """ result = {} structure = self.structure @@ -297,11 +299,9 @@ def is_metal(self, efermi_tol=1e-4) -> bool: Returns: bool: True if a metal. """ - for values in self.bands.values(): + for vals in self.bands.values(): for idx in range(self.nb_bands): - if np.any(values[idx, :] - self.efermi < -efermi_tol) and np.any( - values[idx, :] - self.efermi > efermi_tol - ): + if np.any(vals[idx, :] - self.efermi < -efermi_tol) and np.any(vals[idx, :] - self.efermi > efermi_tol): return True return False @@ -592,7 +592,7 @@ def as_dict(self): return dct @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """Create from dict. Args: @@ -640,7 +640,7 @@ def from_dict(cls, dct): return cls.from_old_dict(dct) @classmethod - def from_old_dict(cls, dct): + def from_old_dict(cls, dct) -> Self: """ Args: dct (dict): A dict with all data for a band structure symmetry line object. @@ -650,7 +650,7 @@ def from_old_dict(cls, dct): """ # Strip the label to recover initial string (see trick used in as_dict to handle $ chars) labels_dict = {k.strip(): v for k, v in dct["labels_dict"].items()} - projections = {} + projections: dict = {} structure = None if dct.get("projections"): structure = Structure.from_dict(dct["structure"]) @@ -671,7 +671,7 @@ def from_old_dict(cls, dct): dd.append(np.array(ddd)) projections[Spin(int(spin))] = np.array(dd) - return BandStructure( + return cls( dct["kpoints"], {Spin(int(k)): dct["bands"][k] for k in dct["bands"]}, Lattice(dct["lattice_rec"]["matrix"]), @@ -934,7 +934,7 @@ def as_dict(self): return dct @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """ Args: dct (dict): A dict with all data for a band structure symmetry line @@ -954,7 +954,7 @@ def from_dict(cls, dct): structure = Structure.from_dict(dct["structure"]) projections = {Spin(int(spin)): np.array(v) for spin, v in dct["projections"].items()} - return LobsterBandStructureSymmLine( + return cls( dct["kpoints"], {Spin(int(k)): dct["bands"][k] for k in dct["bands"]}, Lattice(dct["lattice_rec"]["matrix"]), @@ -970,10 +970,10 @@ def from_dict(cls, dct): "format. The old format will be retired in pymatgen " "5.0." ) - return LobsterBandStructureSymmLine.from_old_dict(dct) + return cls.from_old_dict(dct) @classmethod - def from_old_dict(cls, dct): + def from_old_dict(cls, dct) -> Self: """ Args: dct (dict): A dict with all data for a band structure symmetry line @@ -984,7 +984,7 @@ def from_old_dict(cls, dct): """ # Strip the label to recover initial string (see trick used in as_dict to handle $ chars) labels_dict = {k.strip(): v for k, v in dct["labels_dict"].items()} - projections = {} + projections: dict = {} structure = None if "projections" in dct and len(dct["projections"]) != 0: structure = Structure.from_dict(dct["structure"]) @@ -998,7 +998,7 @@ def from_old_dict(cls, dct): dd.append(np.array(ddd)) projections[Spin(int(spin))] = np.array(dd) - return LobsterBandStructureSymmLine( + return cls( dct["kpoints"], {Spin(int(k)): dct["bands"][k] for k in dct["bands"]}, Lattice(dct["lattice_rec"]["matrix"]), diff --git a/pymatgen/electronic_structure/boltztrap.py b/pymatgen/electronic_structure/boltztrap.py index cf9333300ce..d2582cb4689 100644 --- a/pymatgen/electronic_structure/boltztrap.py +++ b/pymatgen/electronic_structure/boltztrap.py @@ -9,7 +9,7 @@ You need version 1.2.3 or higher -References are:: +References are: Madsen, G. K. H., and Singh, D. J. (2006). BoltzTraP. A code for calculating band-structure dependent quantities. @@ -25,7 +25,7 @@ import tempfile import time from shutil import which -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal import numpy as np from monty.dev import requires @@ -45,6 +45,7 @@ if TYPE_CHECKING: from numpy.typing import ArrayLike + from typing_extensions import Self from pymatgen.core.sites import PeriodicSite from pymatgen.core.structure import Structure @@ -252,7 +253,8 @@ def nelec(self): def write_energy(self, output_file) -> None: """Writes the energy to an output file. - :param output_file: Filename + Args: + output_file: Filename """ with open(output_file, mode="w") as file: file.write("test\n") @@ -282,7 +284,7 @@ def write_energy(self, output_file) -> None: # use 90% of bottom bands since highest eigenvalues # are usually incorrect # ask Geoffroy Hautier for more details - nb_bands = int(math.floor(self._bs.nb_bands * (1 - self.cb_cut))) + nb_bands = math.floor(self._bs.nb_bands * (1 - self.cb_cut)) for j in range(nb_bands): eigs.append( Energy( @@ -304,7 +306,8 @@ def write_energy(self, output_file) -> None: def write_struct(self, output_file) -> None: """Writes the structure to an output file. - :param output_file: Filename + Args: + output_file: Filename """ if self._symprec is not None: sym = SpacegroupAnalyzer(self._bs.structure, symprec=self._symprec) @@ -338,7 +341,8 @@ def write_struct(self, output_file) -> None: def write_def(self, output_file) -> None: """Writes the def to an output file. - :param output_file: Filename + Args: + output_file: Filename """ # This function is useless in std version of BoltzTraP code # because x_trans script overwrite BoltzTraP.def @@ -363,10 +367,12 @@ def write_def(self, output_file) -> None: "'formatted',0\n" ) - def write_proj(self, output_file_proj, output_file_def) -> None: + def write_proj(self, output_file_proj: str, output_file_def: str) -> None: """Writes the projections to an output file. - :param output_file: Filename + Args: + output_file_proj: output file name + output_file_def: output file name """ # This function is useless in std version of BoltzTraP code # because x_trans script overwrite BoltzTraP.def @@ -378,7 +384,7 @@ def write_proj(self, output_file_proj, output_file_def) -> None: file.write(str(len(self._bs.kpoints)) + "\n") for i, kpt in enumerate(self._bs.kpoints): tmp_proj = [] - for j in range(int(math.floor(self._bs.nb_bands * (1 - self.cb_cut)))): + for j in range(math.floor(self._bs.nb_bands * (1 - self.cb_cut))): tmp_proj.append(self._bs.projections[Spin(self.spin)][j][i][oi][site_nb]) # TODO deal with the sorting going on at # the energy level!!! @@ -421,7 +427,8 @@ def write_proj(self, output_file_proj, output_file_def) -> None: def write_intrans(self, output_file) -> None: """Writes the intrans to an output file. - :param output_file: Filename + Args: + output_file: Filename """ setgap = 1 if self.scissor > 0.0001 else 0 @@ -497,7 +504,8 @@ def write_intrans(self, output_file) -> None: def write_input(self, output_dir) -> None: """Writes the input files. - :param output_dir: Directory to write the input files. + Args: + output_dir: Directory to write the input files. """ if self._bs.is_spin_polarized or self.soc: self.write_energy(f"{output_dir}/boltztrap.energyso") @@ -892,6 +900,7 @@ def check_acc_bzt_bands(sbs_bz, sbs_ref, warn_thr=(0.03, 0.03)): around the gap (semiconductors) or Fermi level (metals). warn_thr is a threshold to get a warning in the accuracy of Boltztap interpolated bands. + Return a dictionary with these keys: - "N": the index of the band compared; inside each there are: - "Corr": correlation coefficient for the 8 compared bands @@ -1472,9 +1481,9 @@ def is_isotropic(x, isotropy_tolerance) -> bool: d = self.get_zt(output="eigs", doping_levels=True) else: - raise ValueError(f"Target property: {target_prop} not recognized!") + raise ValueError(f"Unrecognized {target_prop=}") - absval = True # take the absolute value of properties + abs_val = True # take the absolute value of properties x_val = x_temp = x_doping = x_isotropic = None output = {} @@ -1491,7 +1500,7 @@ def is_isotropic(x, isotropy_tolerance) -> bool: doping_lvl = self.doping[pn][didx] if min_doping <= doping_lvl <= max_doping: isotropic = is_isotropic(evs, isotropy_tolerance) - if absval: + if abs_val: evs = [abs(x) for x in evs] val = float(sum(evs)) / len(evs) if use_average else max(evs) if x_val is None or (val > x_val and maximize) or (val < x_val and not maximize): @@ -1614,7 +1623,9 @@ def get_complete_dos(self, structure: Structure, analyzer_for_second_spin=None): return CompleteDos(structure, total_dos=total_dos, pdoss=pdoss) def get_mu_bounds(self, temp=300): - """:param temp: Temperature. + """ + Args: + temp: Temperature. Returns: The chemical potential bounds at that temperature. @@ -1783,7 +1794,7 @@ def parse_struct(path_dir): path_dir: (str) dir containing the boltztrap.struct file Returns: - (float) volume + float: volume of the structure in Angstrom^3 """ with open(f"{path_dir}/boltztrap.struct") as file: tokens = file.readlines() @@ -1925,7 +1936,7 @@ def parse_cond_and_hall(path_dir, doping_levels=None): ) @classmethod - def from_files(cls, path_dir, dos_spin=1): + def from_files(cls, path_dir: str, dos_spin: Literal[-1, 1] = 1) -> Self: """Get a BoltztrapAnalyzer object from a set of files. Args: @@ -1946,7 +1957,7 @@ def from_files(cls, path_dir, dos_spin=1): *cond_and_hall, carrier_conc = cls.parse_cond_and_hall(path_dir, doping_levels) - return cls(gap, *cond_and_hall, in_trans, dos, partial_dos, carrier_conc, vol, warning) + return cls(gap, *cond_and_hall, in_trans, dos, partial_dos, carrier_conc, vol, warning) # type: ignore[call-arg] if run_type == "DOS": trim = in_trans["dos_type"] == "HISTO" @@ -1994,9 +2005,11 @@ def as_dict(self): } return jsanitize(results) - @staticmethod - def from_dict(data): - """:param data: Dict representation. + @classmethod + def from_dict(cls, data: dict) -> Self: + """ + Args: + data: Dict representation. Returns: BoltztrapAnalyzer @@ -2112,7 +2125,7 @@ def _make_float_hall(a): vol = data.get("vol") warning = data.get("warning") - return BoltztrapAnalyzer( + return cls( gap=gap, mu_steps=mu_steps, cond=cond, @@ -2134,13 +2147,16 @@ def _make_float_hall(a): def read_cube_file(filename): - """:param filename: Cube filename + """ + + Args: + filename: Cube filename. Returns: Energy data. """ with open(filename) as file: - natoms = 0 + n_atoms = 0 for idx, line in enumerate(file): line = line.rstrip("\n") if idx == 0 and "CUBE" not in line: @@ -2148,7 +2164,7 @@ def read_cube_file(filename): if idx == 2: tokens = line.split() - natoms = int(tokens[0]) + n_atoms = int(tokens[0]) if idx == 3: tokens = line.split() n1 = int(tokens[0]) @@ -2162,12 +2178,12 @@ def read_cube_file(filename): break if "fort.30" in filename: - energy_data = np.genfromtxt(filename, skip_header=natoms + 6, skip_footer=1) + energy_data = np.genfromtxt(filename, skip_header=n_atoms + 6, skip_footer=1) n_lines_data = len(energy_data) - last_line = np.genfromtxt(filename, skip_header=n_lines_data + natoms + 6) + last_line = np.genfromtxt(filename, skip_header=n_lines_data + n_atoms + 6) energy_data = np.append(energy_data.flatten(), last_line).reshape(n1, n2, n3) elif "boltztrap_BZ.cube" in filename: - energy_data = np.loadtxt(filename, skiprows=natoms + 6).reshape(n1, n2, n3) + energy_data = np.loadtxt(filename, skiprows=n_atoms + 6).reshape(n1, n2, n3) energy_data /= Energy(1, "eV").to("Ry") diff --git a/pymatgen/electronic_structure/boltztrap2.py b/pymatgen/electronic_structure/boltztrap2.py index 63f4e963a31..1b78950cf74 100644 --- a/pymatgen/electronic_structure/boltztrap2.py +++ b/pymatgen/electronic_structure/boltztrap2.py @@ -29,6 +29,7 @@ from __future__ import annotations import warnings +from typing import TYPE_CHECKING import matplotlib.pyplot as plt import numpy as np @@ -43,12 +44,18 @@ from pymatgen.io.vasp import Vasprun from pymatgen.symmetry.bandstructure import HighSymmKpath +if TYPE_CHECKING: + from pathlib import Path + + from typing_extensions import Self + try: from BoltzTraP2 import bandlib as BL from BoltzTraP2 import fite, sphere, units except ImportError: raise BoltztrapError("BoltzTraP2 has to be installed and working") + __author__ = "Francesco Ricci" __copyright__ = "Copyright 2018, The Materials Project" __version__ = "1.0" @@ -101,10 +108,7 @@ def __init__(self, obj, structure=None, nelect=None) -> None: self.is_spin_polarized = bs_obj.is_spin_polarized - if bs_obj.is_spin_polarized: - self.dosweight = 1.0 - else: - self.dosweight = 2.0 + self.dosweight = 1.0 if bs_obj.is_spin_polarized else 2.0 self.lattvec = self.atoms.get_cell().T * units.Angstrom self.mommat_all = None # not implemented yet @@ -131,7 +135,7 @@ def __init__(self, obj, structure=None, nelect=None) -> None: raise BoltztrapError("nelect must be given.") @classmethod - def from_file(cls, vasprun_file): + def from_file(cls, vasprun_file: str | Path) -> Self: """Get a vasprun.xml file and return a VasprunBSLoader.""" vrun_obj = Vasprun(vasprun_file, parse_projected_eigen=True) return cls(vrun_obj) @@ -165,7 +169,7 @@ def bandana(self, emin=-np.inf, emax=np.inf): self.proj = {} if self.proj_all: if len(self.proj_all) == 2: - h = int(len(accepted) / 2) + h = len(accepted) // 2 self.proj[Spin.up] = self.proj_all[Spin.up][:, accepted[:h], :, :] self.proj[Spin.down] = self.proj_all[Spin.down][:, accepted[h:], :, :] elif len(self.proj_all) == 1: @@ -203,10 +207,7 @@ def __init__(self, bs_obj, structure=None, nelect=None, mommat=None, magmom=None self.kpoints = np.array([kp.frac_coords for kp in bs_obj.kpoints]) - if structure is None: - self.structure = bs_obj.structure - else: - self.structure = structure + self.structure = bs_obj.structure if structure is None else structure self.atoms = AseAtomsAdaptor.get_atoms(self.structure) self.proj_all = None @@ -219,10 +220,7 @@ def __init__(self, bs_obj, structure=None, nelect=None, mommat=None, magmom=None self.is_spin_polarized = bs_obj.is_spin_polarized - if bs_obj.is_spin_polarized: - self.dosweight = 1.0 - else: - self.dosweight = 2.0 + self.dosweight = 1.0 if bs_obj.is_spin_polarized else 2.0 self.lattvec = self.atoms.get_cell().T * units.Angstrom self.mommat_all = mommat # not implemented yet @@ -263,7 +261,7 @@ def bandana(self, emin=-np.inf, emax=np.inf): self.proj = {} if self.proj_all: if len(self.proj_all) == 2: - h = int(len(accepted) / 2) + h = len(accepted) // 2 self.proj[Spin.up] = self.proj_all[Spin.up][:, accepted[:h], :, :] self.proj[Spin.down] = self.proj_all[Spin.down][:, accepted[h:], :, :] elif len(self.proj) == 1: @@ -347,10 +345,10 @@ def __init__(self, vrun_obj=None) -> None: self.cbm = self.fermi @classmethod - def from_file(cls, vasprun_file): + def from_file(cls, vasprun_file: str | Path) -> Self: """Get a vasprun.xml file and return a VasprunLoader.""" vrun_obj = Vasprun(vasprun_file, parse_projected_eigen=True) - return VasprunLoader(vrun_obj) + return cls(vrun_obj) def get_lattvec(self): """Lattice vectors.""" @@ -508,9 +506,9 @@ def get_band_structure(self, kpaths=None, kpoints_lbls_dict=None, density=20): if isinstance(kpaths, list) and isinstance(kpoints_lbls_dict, dict): kpoints = [] for kpath in kpaths: - for idx, k_pt in enumerate(kpath[:-1]): + for idx, k_pt in enumerate(kpath[:-1], start=1): sta = kpoints_lbls_dict[k_pt] - end = kpoints_lbls_dict[kpath[idx + 1]] + end = kpoints_lbls_dict[kpath[idx]] kpoints.append(np.linspace(sta, end, density)) kpoints = np.concatenate(kpoints) else: @@ -849,19 +847,23 @@ def compute_properties_doping(self, doping, temp_r=None) -> None: self.mu_doping_eV = {k: v / units.eV - self.efermi for k, v in mu_doping.items()} self.contain_props_doping = True - # def find_mu_doping(self, epsilon, dos, N0, T, dosweight=2.): + # def find_mu_doping(self, epsilon, dos, N0, T, dosweight=2.0): # """ - # Find the mu. + # Find the chemical potential (mu). + + # Args: + # epsilon (np.array): Array of energy values. + # dos (np.array): Array of density of states values. + # N0 (float): Background carrier concentration. + # T (float): Temperature in Kelvin. + # dosweight (float, optional): Weighting factor for the density of states. Default is 2.0. - # :param epsilon: - # :param dos: - # :param N0: - # :param T: - # :param dosweight: + # Returns: + # float: The chemical potential (mu) value. # """ # delta = np.empty_like(epsilon) - # for i, e in enumerate(epsilon): - # delta[i] = BL.calc_N(epsilon, dos, e, T, dosweight) + N0 + # for idx, eps in enumerate(epsilon): + # delta[idx] = BL.calc_N(epsilon, dos, eps, T, dosweight) + N0 # delta = np.abs(delta) # # Find the position optimizing this distance # pos = np.abs(delta).argmin() @@ -950,8 +952,13 @@ class BztPlotter: """ def __init__(self, bzt_transP=None, bzt_interp=None) -> None: - """:param bzt_transP: - :param bzt_interp: + """Placeholder. + + TODO: missing docstrings for __init__ + + Args: + bzt_transP (_type_, optional): _description_. Defaults to None. + bzt_interp (_type_, optional): _description_. Defaults to None. """ self.bzt_transP = bzt_transP self.bzt_interp = bzt_interp @@ -987,6 +994,9 @@ def plot_props( xlim: chemical potential range in eV, useful when prop_x='mu' ax: figure.axes where to plot. If None, a new figure is produced. + Returns: + plt.Axes: matplotlib Axes object + Example: bztPlotter.plot_props('S','mu','temp',temps=[600,900,1200]).show() more example are provided in the notebook @@ -1038,9 +1048,9 @@ def plot_props( mu = self.bzt_transP.mu_r_eV if prop_z == "doping" and prop_x == "temp": - p_array = eval(f"self.bzt_transP.{props[idx_prop]}_{prop_z}") + p_array = getattr(self.bzt_transP, f"{props[idx_prop]}_doping") else: - p_array = eval(f"self.bzt_transP.{props[idx_prop]}_{prop_x}") + p_array = getattr(self.bzt_transP, f"{props[idx_prop]}_{prop_x}") if ax is None: plt.figure(figsize=(10, 8)) @@ -1133,7 +1143,7 @@ def plot_props( plt.legend(title=leg_title if leg_title != "" else "", fontsize=15) plt.tight_layout() plt.grid() - return plt + return ax def plot_bands(self): """Plot a band structure on symmetry line using BSPlotter().""" @@ -1160,10 +1170,11 @@ def merge_up_down_doses(dos_up, dos_dn): """Merge the up and down DOSs. Args: - dos_up: Up DOS. - dos_dn: Down DOS - Return: - CompleteDos object + dos_up: Up DOS. + dos_dn: Down DOS + + Returns: + CompleteDos object """ warnings.warn("This function is not useful anymore. VasprunBSLoader deals with spin case.") cdos = Dos( diff --git a/pymatgen/electronic_structure/cohp.py b/pymatgen/electronic_structure/cohp.py index bbaf42c241e..4eddad85a6c 100644 --- a/pymatgen/electronic_structure/cohp.py +++ b/pymatgen/electronic_structure/cohp.py @@ -13,7 +13,7 @@ import re import sys import warnings -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np from monty.json import MSONable @@ -28,6 +28,9 @@ from pymatgen.util.due import Doi, due from pymatgen.util.num import round_to_sigfigs +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "Marco Esters, Janine George" __copyright__ = "Copyright 2017, The Materials Project" __version__ = "0.2" @@ -166,8 +169,8 @@ def has_antibnd_states_below_efermi(self, spin=None, limit=0.01): return None if spin is None: dict_to_return = {} - for sp, cohpvalues in populations.items(): - if (max(cohpvalues[0:n_energies_below_efermi])) > limit: + for sp, cohp_vals in populations.items(): + if (max(cohp_vals[0:n_energies_below_efermi])) > limit: dict_to_return[sp] = True else: dict_to_return[sp] = False @@ -176,8 +179,7 @@ def has_antibnd_states_below_efermi(self, spin=None, limit=0.01): if isinstance(spin, int): spin = Spin(spin) elif isinstance(spin, str): - s = {"up": 1, "down": -1}[spin.lower()] - spin = Spin(s) + spin = Spin({"up": 1, "down": -1}[spin.lower()]) if (max(populations[spin][0:n_energies_below_efermi])) > limit: dict_to_return[spin] = True else: @@ -186,12 +188,12 @@ def has_antibnd_states_below_efermi(self, spin=None, limit=0.01): return dict_to_return @classmethod - def from_dict(cls, dct: dict[str, Any]) -> Cohp: + def from_dict(cls, dct: dict[str, Any]) -> Self: """Returns a COHP object from a dict representation of the COHP.""" icohp = {Spin(int(key)): np.array(val) for key, val in dct["ICOHP"].items()} if "ICOHP" in dct else None are_cobis = dct.get("are_cobis", False) are_multi_center_cobis = dct.get("are_multi_center_cobis", False) - return Cohp( + return cls( dct["efermi"], dct["energies"], {Spin(int(key)): np.array(val) for key, val in dct["COHP"].items()}, @@ -439,7 +441,7 @@ def get_summed_cohp_by_label_and_orbital_list( first_cohpobject = self.get_orbital_resolved_cohp(label_list[0], orbital_list[0]) summed_cohp = first_cohpobject.cohp.copy() summed_icohp = first_cohpobject.icohp.copy() - for ilabel, label in enumerate(label_list[1:], 1): + for ilabel, label in enumerate(label_list[1:], start=1): cohp_here = self.get_orbital_resolved_cohp(label, orbital_list[ilabel]) summed_cohp[Spin.up] = np.sum([summed_cohp[Spin.up], cohp_here.cohp.copy()[Spin.up]], axis=0) if Spin.down in summed_cohp: @@ -543,7 +545,7 @@ def get_orbital_resolved_cohp(self, label, orbitals, summed_spin_channels=False) ) @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """Returns CompleteCohp object from dict representation.""" # TODO: clean that mess up? cohp_dict = {} @@ -584,7 +586,7 @@ def from_dict(cls, dct): cohp_dict[label] = Cohp(efermi, energies, cohp, icohp=icohp) if "orb_res_cohp" in dct: - orb_cohp = {} + orb_cohp: dict[str, dict] = {} for label in dct["orb_res_cohp"]: orb_cohp[label] = {} for orb in dct["orb_res_cohp"][label]: @@ -607,7 +609,7 @@ def from_dict(cls, dct): } # If no total COHPs are present, calculate the total # COHPs from the single-orbital populations. Total COHPs - # may not be present when the cohpgenerator keyword is used + # may not be present when the COHP generator keyword is used # in LOBSTER versions 2.2.0 and earlier. if label not in dct["COHP"] or dct["COHP"][label] is None: cohp = { @@ -640,11 +642,11 @@ def from_dict(cls, dct): except KeyError: pass else: - orb_cohp = None + orb_cohp = {} are_cobis = dct.get("are_cobis", False) - return CompleteCohp( + return cls( structure, avg_cohp, cohp_dict, @@ -658,7 +660,7 @@ def from_dict(cls, dct): @classmethod def from_file( cls, fmt, filename=None, structure_file=None, are_coops=False, are_cobis=False, are_multi_center_cobis=False - ): + ) -> Self: """ Creates a CompleteCohp object from an output file of a COHP calculation. Valid formats are either LMTO (for the Stuttgart @@ -695,7 +697,7 @@ def from_file( structure_file = "CTRL" if filename is None: filename = "COPL" - cohp_file = LMTOCopl(filename=filename, to_eV=True) + cohp_file: LMTOCopl | Cohpcar = LMTOCopl(filename=filename, to_eV=True) elif fmt == "LOBSTER": if ( (are_coops and are_cobis) @@ -762,7 +764,7 @@ def from_file( if fmt == "LMTO": # Calculate the average COHP for the LMTO file to be # consistent with LOBSTER output. - avg_data = {"COHP": {}, "ICOHP": {}} + avg_data: dict[str, dict] = {"COHP": {}, "ICOHP": {}} for i in avg_data: for spin in spins: rows = np.array([v[i][spin] for v in cohp_data.values()]) @@ -770,57 +772,52 @@ def from_file( # LMTO COHPs have 5 significant figures avg_data[i].update({spin: np.array([round_to_sigfigs(a, 5) for a in avg], dtype=float)}) avg_cohp = Cohp(efermi, energies, avg_data["COHP"], icohp=avg_data["ICOHP"]) + elif not are_multi_center_cobis: + avg_cohp = Cohp( + efermi, + energies, + cohp_data["average"]["COHP"], + icohp=cohp_data["average"]["ICOHP"], + are_coops=are_coops, + are_cobis=are_cobis, + are_multi_center_cobis=are_multi_center_cobis, + ) + del cohp_data["average"] else: - if not are_multi_center_cobis: - avg_cohp = Cohp( - efermi, - energies, - cohp_data["average"]["COHP"], - icohp=cohp_data["average"]["ICOHP"], - are_coops=are_coops, - are_cobis=are_cobis, - are_multi_center_cobis=are_multi_center_cobis, - ) - del cohp_data["average"] - else: - # only include two-center cobis in average - # do this for both spin channels - cohp = {} - cohp[Spin.up] = np.array( - [np.array(c["COHP"][Spin.up]) for c in cohp_file.cohp_data.values() if len(c["sites"]) <= 2] + # only include two-center cobis in average + # do this for both spin channels + cohp = {} + cohp[Spin.up] = np.array( + [np.array(c["COHP"][Spin.up]) for c in cohp_file.cohp_data.values() if len(c["sites"]) <= 2] + ).mean(axis=0) + try: + cohp[Spin.down] = np.array( + [np.array(c["COHP"][Spin.down]) for c in cohp_file.cohp_data.values() if len(c["sites"]) <= 2] + ).mean(axis=0) + except KeyError: + pass + try: + icohp = {} + icohp[Spin.up] = np.array( + [np.array(c["ICOHP"][Spin.up]) for c in cohp_file.cohp_data.values() if len(c["sites"]) <= 2] ).mean(axis=0) try: - cohp[Spin.down] = np.array( - [np.array(c["COHP"][Spin.down]) for c in cohp_file.cohp_data.values() if len(c["sites"]) <= 2] + icohp[Spin.down] = np.array( + [np.array(c["ICOHP"][Spin.down]) for c in cohp_file.cohp_data.values() if len(c["sites"]) <= 2] ).mean(axis=0) except KeyError: pass - try: - icohp = {} - icohp[Spin.up] = np.array( - [np.array(c["ICOHP"][Spin.up]) for c in cohp_file.cohp_data.values() if len(c["sites"]) <= 2] - ).mean(axis=0) - try: - icohp[Spin.down] = np.array( - [ - np.array(c["ICOHP"][Spin.down]) - for c in cohp_file.cohp_data.values() - if len(c["sites"]) <= 2 - ] - ).mean(axis=0) - except KeyError: - pass - except KeyError: - icohp = None - avg_cohp = Cohp( - efermi, - energies, - cohp, - icohp=icohp, - are_coops=are_coops, - are_cobis=are_cobis, - are_multi_center_cobis=are_multi_center_cobis, - ) + except KeyError: + icohp = None + avg_cohp = Cohp( + efermi, + energies, + cohp, + icohp=icohp, + are_coops=are_coops, + are_cobis=are_cobis, + are_multi_center_cobis=are_multi_center_cobis, + ) cohp_dict = { key: Cohp( @@ -843,7 +840,7 @@ def from_file( for key, dct in cohp_data.items() } - return CompleteCohp( + return cls( structure, avg_cohp, cohp_dict, @@ -1004,7 +1001,7 @@ def icohpvalue_orbital(self, orbitals, spin=Spin.up): @property def icohp(self): """Dict with icohps for spinup and spindown - Return: + Returns: dict={Spin.up: icohpvalue for spin.up, Spin.down: icohpvalue for spin.down}. """ return self._icohp diff --git a/pymatgen/electronic_structure/core.py b/pymatgen/electronic_structure/core.py index 9a24948aa36..f8cef567609 100644 --- a/pymatgen/electronic_structure/core.py +++ b/pymatgen/electronic_structure/core.py @@ -13,6 +13,10 @@ if TYPE_CHECKING: from collections.abc import Sequence + from typing_extensions import Self + + from pymatgen.core import Lattice + @unique class Spin(Enum): @@ -125,9 +129,11 @@ class Magmom(MSONable): def __init__( self, moment: float | Sequence[float] | np.ndarray | Magmom, saxis: Sequence[float] = (0, 0, 1) ) -> None: - """:param moment: magnetic moment, supplied as float or list/np.ndarray - :param saxis: spin axis, supplied as list/np.ndarray, parameter will - be converted to unit vector (default is [0, 0, 1]) + """ + Args: + moment: magnetic moment, supplied as float or list/np.ndarray + saxis: spin axis, supplied as list/np.ndarray, parameter will + be converted to unit vector (default is [0, 0, 1]). Returns: Magmom object @@ -148,7 +154,7 @@ def __init__( self.saxis = saxis / np.linalg.norm(saxis) @classmethod - def from_global_moment_and_saxis(cls, global_moment, saxis): + def from_global_moment_and_saxis(cls, global_moment, saxis) -> Self: """Convenience method to initialize Magmom from a given global magnetic moment, i.e. magnetic moment with saxis=(0,0,1), and provided saxis. @@ -156,8 +162,9 @@ def from_global_moment_and_saxis(cls, global_moment, saxis): Method is useful if you do not know the components of your magnetic moment in frame of your desired saxis. - :param global_moment: - :param saxis: desired saxis + Args: + global_moment: global magnetic moment + saxis: desired saxis """ magmom = Magmom(global_moment) return cls(magmom.get_moment(saxis=saxis), saxis=saxis) @@ -204,7 +211,8 @@ def get_moment(self, saxis=(0, 0, 1)): Magmom's internal spin quantization axis, i.e. equivalent to Magmom.moment. - :param saxis: (list/numpy array) spin quantization axis + Args: + saxis: (list/numpy array) spin quantization axis Returns: np.ndarray of length 3 @@ -295,7 +303,8 @@ def have_consistent_saxis(magmoms) -> bool: If saxis are inconsistent, can create consistent set with: Magmom.get_consistent_set(magmoms). - :param magmoms: list of magmoms (Magmoms, scalars or vectors) + Args: + magmoms: list of magmoms (Magmoms, scalars or vectors) Returns: bool @@ -312,11 +321,12 @@ def get_consistent_set_and_saxis(magmoms, saxis=None): """Method to ensure a list of magmoms use the same spin axis. Returns a tuple of a list of Magmoms and their global spin axis. - :param magmoms: list of magmoms (Magmoms, scalars or vectors) - :param saxis: can provide a specific global spin axis + Args: + magmoms: list of magmoms (Magmoms, scalars or vectors) + saxis: can provide a specific global spin axis Returns: - (list of Magmoms, global spin axis) tuple + tuple[list[Magmom], np.ndarray]: (list of Magmoms, global spin axis) """ magmoms = [Magmom(magmom) for magmom in magmoms] saxis = Magmom.get_suggested_saxis(magmoms) if saxis is None else saxis / np.linalg.norm(saxis) @@ -330,7 +340,8 @@ def get_suggested_saxis(magmoms): with collinear spins, this would give a sensible saxis for a ncl calculation. - :param magmoms: list of magmoms (Magmoms, scalars or vectors) + Args: + magmoms: list of magmoms (Magmoms, scalars or vectors) Returns: np.ndarray of length 3 @@ -353,7 +364,9 @@ def get_suggested_saxis(magmoms): def are_collinear(magmoms) -> bool: """Method checks to see if a set of magnetic moments are collinear with each other. - :param magmoms: list of magmoms (Magmoms, scalars or vectors). + + Args: + magmoms: list of magmoms (Magmoms, scalars or vectors). Returns: bool. @@ -375,13 +388,15 @@ def are_collinear(magmoms) -> bool: return num_ncl == 0 @classmethod - def from_moment_relative_to_crystal_axes(cls, moment, lattice): + def from_moment_relative_to_crystal_axes(cls, moment: list[float], lattice: Lattice) -> Self: """Obtaining a Magmom object from a magnetic moment provided relative to crystal axes. Used for obtaining moments from magCIF file. - :param moment: list of floats specifying vector magmom - :param lattice: Lattice + + Args: + moment: list of floats specifying vector magmom + lattice: Lattice Returns: Magmom @@ -397,7 +412,8 @@ def get_moment_relative_to_crystal_axes(self, lattice): """If scalar magmoms, moments will be given arbitrarily along z. Used for writing moments to magCIF file. - :param lattice: Lattice + Args: + lattice: Lattice Returns: vector as list of floats diff --git a/pymatgen/electronic_structure/dos.py b/pymatgen/electronic_structure/dos.py index d2ef8678ce6..01da2e83e54 100644 --- a/pymatgen/electronic_structure/dos.py +++ b/pymatgen/electronic_structure/dos.py @@ -22,6 +22,7 @@ from collections.abc import Mapping from numpy.typing import ArrayLike + from typing_extensions import Self from pymatgen.core.sites import PeriodicSite from pymatgen.util.typing import SpeciesLike @@ -64,8 +65,7 @@ def get_interpolated_gap(self, tol: float = 0.001, abs_tol: bool = False, spin: Down - finds the gap in the down spin channel. Returns: - (gap, cbm, vbm): - Tuple of floats in eV corresponding to the gap, cbm and vbm. + tuple[float, float, float]: Energies in eV corresponding to the band gap, cbm and vbm. """ if spin is None: tdos = self.y if len(self.ydim) == 1 else np.sum(self.y, axis=1) @@ -92,7 +92,7 @@ def get_interpolated_gap(self, tol: float = 0.001, abs_tol: bool = False, spin: end = get_linear_interpolated_value(terminal_dens, terminal_energies, tol) return end - start, end, start - def get_cbm_vbm(self, tol: float = 0.001, abs_tol: bool = False, spin=None): + def get_cbm_vbm(self, tol: float = 0.001, abs_tol: bool = False, spin=None) -> tuple[float, float]: """Expects a DOS object and finds the cbm and vbm. Args: @@ -103,7 +103,7 @@ def get_cbm_vbm(self, tol: float = 0.001, abs_tol: bool = False, spin=None): Down - finds the gap in the down spin channel. Returns: - (cbm, vbm): float in eV corresponding to the gap + tuple[float, float]: Energies in eV corresponding to the cbm and vbm. """ # determine tolerance if spin is None: @@ -257,7 +257,9 @@ def get_interpolated_value(self, energy: float) -> dict[Spin, float]: energies[spin] = get_linear_interpolated_value(self.energies, self.densities[spin], energy) return energies - def get_interpolated_gap(self, tol: float = 0.001, abs_tol: bool = False, spin: Spin | None = None): + def get_interpolated_gap( + self, tol: float = 0.001, abs_tol: bool = False, spin: Spin | None = None + ) -> tuple[float, float, float]: """Expects a DOS object and finds the gap. Args: @@ -269,7 +271,7 @@ def get_interpolated_gap(self, tol: float = 0.001, abs_tol: bool = False, spin: Down - finds the gap in the down spin channel. Returns: - (gap, cbm, vbm): Tuple of floats in eV corresponding to the gap, cbm and vbm. + tuple[float, float, float]: Energies in eV corresponding to the band gap, cbm and vbm. """ tdos = self.get_densities(spin) if not abs_tol: @@ -291,7 +293,7 @@ def get_interpolated_gap(self, tol: float = 0.001, abs_tol: bool = False, spin: end = get_linear_interpolated_value(terminal_dens, terminal_energies, tol) return end - start, end, start - def get_cbm_vbm(self, tol: float = 0.001, abs_tol: bool = False, spin: Spin | None = None): + def get_cbm_vbm(self, tol: float = 0.001, abs_tol: bool = False, spin: Spin | None = None) -> tuple[float, float]: """Expects a DOS object and finds the cbm and vbm. Args: @@ -302,7 +304,7 @@ def get_cbm_vbm(self, tol: float = 0.001, abs_tol: bool = False, spin: Spin | No Down - finds the gap in the down spin channel. Returns: - (cbm, vbm): float in eV corresponding to the gap + tuple[float, float]: Energies in eV corresponding to the cbm and vbm. """ # determine tolerance tdos = self.get_densities(spin) @@ -339,7 +341,7 @@ def get_gap(self, tol: float = 0.001, abs_tol: bool = False, spin: Spin | None = Returns: gap in eV """ - (cbm, vbm) = self.get_cbm_vbm(tol, abs_tol, spin) + cbm, vbm = self.get_cbm_vbm(tol, abs_tol, spin) return max(cbm - vbm, 0.0) def __str__(self) -> str: @@ -355,12 +357,12 @@ def __str__(self) -> str: return "\n".join(str_arr) @classmethod - def from_dict(cls, d) -> Dos: + def from_dict(cls, dct: dict) -> Self: """Returns Dos object from dict representation of Dos.""" - return Dos( - d["efermi"], - d["energies"], - {Spin(int(k)): v for k, v in d["densities"].items()}, + return cls( + dct["efermi"], + dct["energies"], + {Spin(int(k)): v for k, v in dct["densities"].items()}, ) def as_dict(self) -> dict: @@ -557,10 +559,10 @@ def get_fermi( fermi = self.efermi # initialize target fermi relative_error = [float("inf")] for _ in range(precision): - f_range = np.arange(-nstep, nstep + 1) * step + fermi - calc_doping = np.array([self.get_doping(f, temperature) for f in f_range]) + fermi_range = np.arange(-nstep, nstep + 1) * step + fermi + calc_doping = np.array([self.get_doping(fermi_lvl, temperature) for fermi_lvl in fermi_range]) relative_error = np.abs(calc_doping / concentration - 1.0) # type: ignore - fermi = f_range[np.argmin(relative_error)] + fermi = fermi_range[np.argmin(relative_error)] step /= 10.0 if min(relative_error) > rtol: @@ -568,14 +570,14 @@ def get_fermi( return fermi @classmethod - def from_dict(cls, d) -> FermiDos: + def from_dict(cls, dct: dict) -> Self: """Returns Dos object from dict representation of Dos.""" dos = Dos( - d["efermi"], - d["energies"], - {Spin(int(k)): v for k, v in d["densities"].items()}, + dct["efermi"], + dct["energies"], + {Spin(int(k)): v for k, v in dct["densities"].items()}, ) - return FermiDos(dos, structure=Structure.from_dict(d["structure"]), nelecs=d["nelecs"]) + return cls(dos, structure=Structure.from_dict(dct["structure"]), nelecs=dct["nelecs"]) def as_dict(self) -> dict: """JSON-serializable dict representation of Dos.""" @@ -1248,19 +1250,19 @@ def get_dos_fp_similarity( ) @classmethod - def from_dict(cls, d) -> CompleteDos: + def from_dict(cls, dct: dict) -> Self: """Returns CompleteDos object from dict representation.""" - tdos = Dos.from_dict(d) - struct = Structure.from_dict(d["structure"]) + tdos = Dos.from_dict(dct) + struct = Structure.from_dict(dct["structure"]) pdoss = {} - for i in range(len(d["pdos"])): - at = struct[i] + for idx in range(len(dct["pdos"])): + at = struct[idx] orb_dos = {} - for orb_str, odos in d["pdos"][i].items(): + for orb_str, odos in dct["pdos"][idx].items(): orb = Orbital[orb_str] orb_dos[orb] = {Spin(int(k)): v for k, v in odos["densities"].items()} pdoss[at] = orb_dos - return CompleteDos(struct, tdos, pdoss) + return cls(struct, tdos, pdoss) def as_dict(self) -> dict: """JSON-serializable dict representation of CompleteDos.""" @@ -1394,19 +1396,19 @@ def get_element_spd_dos(self, el: SpeciesLike) -> dict[str, Dos]: # type: ignor return {orb: Dos(self.efermi, self.energies, densities) for orb, densities in el_dos.items()} # type: ignore @classmethod - def from_dict(cls, d) -> LobsterCompleteDos: + def from_dict(cls, dct: dict) -> Self: """Hydrate CompleteDos object from dict representation.""" - tdos = Dos.from_dict(d) - struct = Structure.from_dict(d["structure"]) + tdos = Dos.from_dict(dct) + struct = Structure.from_dict(dct["structure"]) pdoss = {} - for i in range(len(d["pdos"])): + for i in range(len(dct["pdos"])): at = struct[i] orb_dos = {} - for orb_str, odos in d["pdos"][i].items(): + for orb_str, odos in dct["pdos"][i].items(): orb = orb_str orb_dos[orb] = {Spin(int(k)): v for k, v in odos["densities"].items()} pdoss[at] = orb_dos - return LobsterCompleteDos(struct, tdos, pdoss) + return cls(struct, tdos, pdoss) def add_densities(density1: Mapping[Spin, ArrayLike], density2: Mapping[Spin, ArrayLike]) -> dict[Spin, np.ndarray]: diff --git a/pymatgen/electronic_structure/plotter.py b/pymatgen/electronic_structure/plotter.py index 18b4735d5d4..17416024aed 100644 --- a/pymatgen/electronic_structure/plotter.py +++ b/pymatgen/electronic_structure/plotter.py @@ -796,16 +796,16 @@ def get_ticks_old(self): tick_labels = [] previous_label = bs.kpoints[0].label previous_branch = bs.branches[0]["name"] - for i, c in enumerate(bs.kpoints): - if c.label is not None: - tick_distance.append(bs.distance[i]) + for idx, kpt in enumerate(bs.kpoints): + if kpt.label is not None: + tick_distance.append(bs.distance[idx]) this_branch = None for b in bs.branches: - if b["start_index"] <= i <= b["end_index"]: + if b["start_index"] <= idx <= b["end_index"]: this_branch = b["name"] break - if c.label != previous_label and previous_branch != this_branch: - label1 = c.label + if kpt.label != previous_label and previous_branch != this_branch: + label1 = kpt.label if label1.startswith("\\") or label1.find("_") != -1: label1 = f"${label1}$" label0 = previous_label @@ -814,11 +814,11 @@ def get_ticks_old(self): tick_labels.pop() tick_distance.pop() tick_labels.append(label0 + "$\\mid$" + label1) - elif c.label.startswith("\\") or c.label.find("_") != -1: - tick_labels.append(f"${c.label}$") + elif kpt.label.startswith("\\") or kpt.label.find("_") != -1: + tick_labels.append(f"${kpt.label}$") else: - tick_labels.append(c.label) - previous_label = c.label + tick_labels.append(kpt.label) + previous_label = kpt.label previous_branch = this_branch return {"distance": tick_distance, "label": tick_labels} @@ -960,7 +960,7 @@ def get_projected_plots_dots(self, dictio, zero_to_efermi=True, ylim=None, vbm_c vbm_cbm_marker: Add markers for the VBM and CBM. Defaults to False. Returns: - a pyplot object with different subfigures for each projection + list[plt.Axes]: A list with different subfigures for each projection The blue and red colors are for spin up and spin down. The bigger the red or blue dot in the band structure the higher character for the corresponding element and orbital. @@ -1048,14 +1048,14 @@ def get_elt_projected_plots(self, zero_to_efermi: bool = True, ylim=None, vbm_cb e_min, e_max = -4, 4 if self._bs.is_metal(): e_min, e_max = -10, 10 - for idx, el in enumerate(self._bs.structure.elements, 1): + for idx, el in enumerate(self._bs.structure.elements, start=1): ax = plt.subplot(220 + idx) self._make_ticks(ax) for b in range(len(data["distances"])): - for i in range(self._nb_bands): + for band_idx in range(self._nb_bands): ax.plot( data["distances"][b], - data["energy"][str(Spin.up)][b][i], + data["energy"][str(Spin.up)][b][band_idx], "-", color=[192 / 255, 192 / 255, 192 / 255], linewidth=band_linewidth, @@ -1063,19 +1063,19 @@ def get_elt_projected_plots(self, zero_to_efermi: bool = True, ylim=None, vbm_cb if self._bs.is_spin_polarized: ax.plot( data["distances"][b], - data["energy"][str(Spin.down)][b][i], + data["energy"][str(Spin.down)][b][band_idx], "--", color=[128 / 255, 128 / 255, 128 / 255], linewidth=band_linewidth, ) - for j in range(len(data["energy"][str(Spin.up)][b][i])): + for j in range(len(data["energy"][str(Spin.up)][b][band_idx])): markerscale = sum( - proj[b][str(Spin.down)][i][j][str(el)][o] - for o in proj[b][str(Spin.down)][i][j][str(el)] + proj[b][str(Spin.down)][band_idx][j][str(el)][o] + for o in proj[b][str(Spin.down)][band_idx][j][str(el)] ) ax.plot( data["distances"][b][j], - data["energy"][str(Spin.down)][b][i][j], + data["energy"][str(Spin.down)][b][band_idx][j], "bo", markersize=markerscale * 15.0, color=[ @@ -1084,13 +1084,14 @@ def get_elt_projected_plots(self, zero_to_efermi: bool = True, ylim=None, vbm_cb 0.4 * markerscale, ], ) - for j in range(len(data["energy"][str(Spin.up)][b][i])): + for j in range(len(data["energy"][str(Spin.up)][b][band_idx])): markerscale = sum( - proj[b][str(Spin.up)][i][j][str(el)][o] for o in proj[b][str(Spin.up)][i][j][str(el)] + proj[b][str(Spin.up)][band_idx][j][str(el)][o] + for o in proj[b][str(Spin.up)][band_idx][j][str(el)] ) ax.plot( data["distances"][b][j], - data["energy"][str(Spin.up)][b][i][j], + data["energy"][str(Spin.up)][b][band_idx][j], "o", markersize=markerscale * 15.0, color=[markerscale, 0.3 * markerscale, 0.4 * markerscale], @@ -1150,20 +1151,25 @@ def get_elt_projected_plots_color(self, zero_to_efermi=True, elt_ordered=None): if self._bs.is_spin_polarized: spins = [Spin.up, Spin.down] self._make_ticks(ax) - for s in spins: + for spin in spins: for b in range(len(data["distances"])): - for i in range(self._nb_bands): - for j in range(len(data["energy"][str(s)][b][i]) - 1): + for band_idx in range(self._nb_bands): + for j in range(len(data["energy"][str(spin)][b][band_idx]) - 1): sum_e = 0.0 for el in elt_ordered: sum_e = sum_e + sum( - proj[b][str(s)][i][j][str(el)][o] for o in proj[b][str(s)][i][j][str(el)] + proj[b][str(spin)][band_idx][j][str(el)][o] + for o in proj[b][str(spin)][band_idx][j][str(el)] ) if sum_e == 0.0: color = [0.0] * len(elt_ordered) else: color = [ - sum(proj[b][str(s)][i][j][str(el)][o] for o in proj[b][str(s)][i][j][str(el)]) / sum_e + sum( + proj[b][str(spin)][band_idx][j][str(el)][o] + for o in proj[b][str(spin)][band_idx][j][str(el)] + ) + / sum_e for el in elt_ordered ] if len(color) == 2: @@ -1171,11 +1177,11 @@ def get_elt_projected_plots_color(self, zero_to_efermi=True, elt_ordered=None): color[2] = color[1] color[1] = 0.0 sign = "-" - if s == Spin.down: + if spin == Spin.down: sign = "--" ax.plot( [data["distances"][b][j], data["distances"][b][j + 1]], - [data["energy"][str(s)][b][i][j], data["energy"][str(s)][b][i][j + 1]], + [data["energy"][str(spin)][b][band_idx][j], data["energy"][str(spin)][b][band_idx][j + 1]], sign, color=color, linewidth=band_linewidth, @@ -1250,26 +1256,26 @@ def _get_projections_by_branches_patom_pmorb(self, dictio, dictpa, sum_atoms, su else: proj_br.append({str(Spin.up): [[] for _ in range(self._nb_bands)]}) - for i in range(self._nb_bands): + for band_idx in range(self._nb_bands): for j in range(b["start_index"], b["end_index"] + 1): edict = {} for elt in dictpa: for anum in dictpa[elt]: edict[f"{elt}{anum}"] = {} for morb in dictio[elt]: - edict[f"{elt}{anum}"][morb] = proj[Spin.up][i][j][setos[morb]][anum - 1] - proj_br[-1][str(Spin.up)][i].append(edict) + edict[f"{elt}{anum}"][morb] = proj[Spin.up][band_idx][j][setos[morb]][anum - 1] + proj_br[-1][str(Spin.up)][band_idx].append(edict) if self._bs.is_spin_polarized: - for i in range(self._nb_bands): + for band_idx in range(self._nb_bands): for j in range(b["start_index"], b["end_index"] + 1): edict = {} for elt in dictpa: for anum in dictpa[elt]: edict[f"{elt}{anum}"] = {} for morb in dictio[elt]: - edict[f"{elt}{anum}"][morb] = proj[Spin.up][i][j][setos[morb]][anum - 1] - proj_br[-1][str(Spin.down)][i].append(edict) + edict[f"{elt}{anum}"][morb] = proj[Spin.up][band_idx][j][setos[morb]][anum - 1] + proj_br[-1][str(Spin.down)][band_idx].append(edict) # Adjusting projections for plot dictio_d, dictpa_d = self._summarize_keys_for_plot(dictio, dictpa, sum_atoms, sum_morbs) @@ -1293,9 +1299,9 @@ def _get_projections_by_branches_patom_pmorb(self, dictio, dictpa, sum_atoms, su proj_br_d.append({str(Spin.up): [[] for _ in range(self._nb_bands)]}) if (sum_atoms is not None) and (sum_morbs is None): - for i in range(self._nb_bands): + for band_idx in range(self._nb_bands): for j in range(br["end_index"] - br["start_index"] + 1): - atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.up)][i][j]) + atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.up)][band_idx][j]) edict = {} for elt in dictpa: if elt in sum_atoms: @@ -1310,11 +1316,11 @@ def _get_projections_by_branches_patom_pmorb(self, dictio, dictpa, sum_atoms, su else: for anum in dictpa_d[elt]: edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum]) - proj_br_d[-1][str(Spin.up)][i].append(edict) + proj_br_d[-1][str(Spin.up)][band_idx].append(edict) if self._bs.is_spin_polarized: - for i in range(self._nb_bands): + for band_idx in range(self._nb_bands): for j in range(br["end_index"] - br["start_index"] + 1): - atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.down)][i][j]) + atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.down)][band_idx][j]) edict = {} for elt in dictpa: if elt in sum_atoms: @@ -1329,12 +1335,12 @@ def _get_projections_by_branches_patom_pmorb(self, dictio, dictpa, sum_atoms, su else: for anum in dictpa_d[elt]: edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum]) - proj_br_d[-1][str(Spin.down)][i].append(edict) + proj_br_d[-1][str(Spin.down)][band_idx].append(edict) elif (sum_atoms is None) and (sum_morbs is not None): - for i in range(self._nb_bands): + for band_idx in range(self._nb_bands): for j in range(br["end_index"] - br["start_index"] + 1): - atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.up)][i][j]) + atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.up)][band_idx][j]) edict = {} for elt in dictpa: if elt in sum_morbs: @@ -1349,11 +1355,11 @@ def _get_projections_by_branches_patom_pmorb(self, dictio, dictpa, sum_atoms, su else: for anum in dictpa_d[elt]: edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum]) - proj_br_d[-1][str(Spin.up)][i].append(edict) + proj_br_d[-1][str(Spin.up)][band_idx].append(edict) if self._bs.is_spin_polarized: - for i in range(self._nb_bands): + for band_idx in range(self._nb_bands): for j in range(br["end_index"] - br["start_index"] + 1): - atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.down)][i][j]) + atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.down)][band_idx][j]) edict = {} for elt in dictpa: if elt in sum_morbs: @@ -1368,12 +1374,12 @@ def _get_projections_by_branches_patom_pmorb(self, dictio, dictpa, sum_atoms, su else: for anum in dictpa_d[elt]: edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum]) - proj_br_d[-1][str(Spin.down)][i].append(edict) + proj_br_d[-1][str(Spin.down)][band_idx].append(edict) else: - for i in range(self._nb_bands): + for band_idx in range(self._nb_bands): for j in range(br["end_index"] - br["start_index"] + 1): - atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.up)][i][j]) + atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.up)][band_idx][j]) edict = {} for elt in dictpa: if (elt in sum_atoms) and (elt in sum_morbs): @@ -1424,12 +1430,12 @@ def _get_projections_by_branches_patom_pmorb(self, dictio, dictpa, sum_atoms, su edict[elt + anum] = {} for morb in dictio_d[elt]: edict[elt + anum][morb] = atoms_morbs[elt + anum][morb] - proj_br_d[-1][str(Spin.up)][i].append(edict) + proj_br_d[-1][str(Spin.up)][band_idx].append(edict) if self._bs.is_spin_polarized: - for i in range(self._nb_bands): + for band_idx in range(self._nb_bands): for j in range(br["end_index"] - br["start_index"] + 1): - atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.down)][i][j]) + atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.down)][band_idx][j]) edict = {} for elt in dictpa: if (elt in sum_atoms) and (elt in sum_morbs): @@ -1480,7 +1486,7 @@ def _get_projections_by_branches_patom_pmorb(self, dictio, dictpa, sum_atoms, su edict[elt + anum] = {} for morb in dictio_d[elt]: edict[elt + anum][morb] = atoms_morbs[elt + anum][morb] - proj_br_d[-1][str(Spin.down)][i].append(edict) + proj_br_d[-1][str(Spin.down)][band_idx].append(edict) return proj_br_d, dictio_d, dictpa_d, indices @@ -1611,10 +1617,13 @@ def get_projected_plots_dots_patom_pmorb( br = -1 for b in branches: br += 1 - for i in range(self._nb_bands): + for band_idx in range(self._nb_bands): ax.plot( [x - shift[br] for x in data["distances"][b]], - [data["energy"][str(Spin.up)][b][i][j] for j in range(len(data["distances"][b]))], + [ + data["energy"][str(Spin.up)][b][band_idx][j] + for j in range(len(data["distances"][b])) + ], "b-", linewidth=band_linewidth, ) @@ -1622,24 +1631,27 @@ def get_projected_plots_dots_patom_pmorb( if self._bs.is_spin_polarized: ax.plot( [x - shift[br] for x in data["distances"][b]], - [data["energy"][str(Spin.down)][b][i][j] for j in range(len(data["distances"][b]))], + [ + data["energy"][str(Spin.down)][b][band_idx][j] + for j in range(len(data["distances"][b])) + ], "r--", linewidth=band_linewidth, ) - for j in range(len(data["energy"][str(Spin.up)][b][i])): + for j in range(len(data["energy"][str(Spin.up)][b][band_idx])): ax.plot( data["distances"][b][j] - shift[br], - data["energy"][str(Spin.down)][b][i][j], + data["energy"][str(Spin.down)][b][band_idx][j], "co", - markersize=proj_br_d[br][str(Spin.down)][i][j][elt + numa][o] * 15.0, + markersize=proj_br_d[br][str(Spin.down)][band_idx][j][elt + numa][o] * 15.0, ) - for j in range(len(data["energy"][str(Spin.up)][b][i])): + for j in range(len(data["energy"][str(Spin.up)][b][band_idx])): ax.plot( data["distances"][b][j] - shift[br], - data["energy"][str(Spin.up)][b][i][j], + data["energy"][str(Spin.up)][b][band_idx][j], "go", - markersize=proj_br_d[br][str(Spin.up)][i][j][elt + numa][o] * 15.0, + markersize=proj_br_d[br][str(Spin.up)][band_idx][j][elt + numa][o] * 15.0, ) if ylim is None: @@ -1829,9 +1841,9 @@ def _number_of_subfigures(self, dictio, dictpa, sum_atoms, sum_morbs): raise ValueError(f"The dictpa[{elt}] is empty. We cannot do anything") _sites = self._bs.structure.sites indices = [] - for i in range(len(_sites)): - if next(iter(_sites[i]._species)) == Element(elt): - indices.append(i + 1) + for site_idx in range(len(_sites)): + if next(iter(_sites[site_idx]._species)) == Element(elt): + indices.append(site_idx + 1) for number in dictpa[elt]: if isinstance(number, str): if number.lower() == "all": @@ -1876,9 +1888,9 @@ def _number_of_subfigures(self, dictio, dictpa, sum_atoms, sum_morbs): raise ValueError(f"The sum_atoms[{elt}] is empty. We cannot do anything") _sites = self._bs.structure.sites indices = [] - for i in range(len(_sites)): - if next(iter(_sites[i]._species)) == Element(elt): - indices.append(i + 1) + for site_idx in range(len(_sites)): + if next(iter(_sites[site_idx]._species)) == Element(elt): + indices.append(site_idx + 1) for number in sum_atoms[elt]: if isinstance(number, str): if number.lower() == "all": @@ -1998,9 +2010,9 @@ def orbital_label(list_orbitals): if elt in sum_atoms: _sites = self._bs.structure.sites indices = [] - for i in range(len(_sites)): - if next(iter(_sites[i]._species)) == Element(elt): - indices.append(i + 1) + for site_idx in range(len(_sites)): + if next(iter(_sites[site_idx]._species)) == Element(elt): + indices.append(site_idx + 1) flag_1 = len(set(dictpa[elt]).intersection(indices)) flag_2 = len(set(sum_atoms[elt]).intersection(indices)) if flag_1 == len(indices) and flag_2 == len(indices): @@ -2047,9 +2059,9 @@ def orbital_label(list_orbitals): if elt in sum_atoms: _sites = self._bs.structure.sites indices = [] - for i in range(len(_sites)): - if next(iter(_sites[i]._species)) == Element(elt): - indices.append(i + 1) + for site_idx in range(len(_sites)): + if next(iter(_sites[site_idx]._species)) == Element(elt): + indices.append(site_idx + 1) flag_1 = len(set(dictpa[elt]).intersection(indices)) flag_2 = len(set(sum_atoms[elt]).intersection(indices)) if flag_1 == len(indices) and flag_2 == len(indices): @@ -2076,13 +2088,13 @@ def _make_ticks_selected(self, ax: plt.Axes, branches: list[int]) -> tuple[plt.A distance = [] label = [] rm_elems = [] - for i in range(1, len(ticks["distance"])): - if ticks["label"][i] == ticks["label"][i - 1]: - rm_elems.append(i) - for i in range(len(ticks["distance"])): - if i not in rm_elems: - distance.append(ticks["distance"][i]) - label.append(ticks["label"][i]) + for idx in range(1, len(ticks["distance"])): + if ticks["label"][idx] == ticks["label"][idx - 1]: + rm_elems.append(idx) + for idx in range(len(ticks["distance"])): + if idx not in rm_elems: + distance.append(ticks["distance"][idx]) + label.append(ticks["label"][idx]) l_branches = [distance[i] - distance[i - 1] for i in range(1, len(distance))] n_distance = [] n_label = [] @@ -2104,48 +2116,48 @@ def _make_ticks_selected(self, ax: plt.Axes, branches: list[int]) -> tuple[plt.A f_distance.extend((0.0, n_distance[0])) rf_distance.extend((0.0, n_distance[0])) length = n_distance[0] - for i in range(1, len(n_distance)): - if n_label[i][0] == n_label[i - 1][1]: - f_distance.extend((length, length + n_distance[i])) - f_label.extend((n_label[i][0], n_label[i][1])) + for idx in range(1, len(n_distance)): + if n_label[idx][0] == n_label[idx - 1][1]: + f_distance.extend((length, length + n_distance[idx])) + f_label.extend((n_label[idx][0], n_label[idx][1])) else: - f_distance.append(length + n_distance[i]) - f_label[-1] = n_label[i - 1][1] + "$\\mid$" + n_label[i][0] - f_label.append(n_label[i][1]) - rf_distance.append(length + n_distance[i]) - length += n_distance[i] + f_distance.append(length + n_distance[idx]) + f_label[-1] = n_label[idx - 1][1] + "$\\mid$" + n_label[idx][0] + f_label.append(n_label[idx][1]) + rf_distance.append(length + n_distance[idx]) + length += n_distance[idx] uniq_d = [] uniq_l = [] temp_ticks = list(zip(f_distance, f_label)) - for i, t in enumerate(temp_ticks): - if i == 0: - uniq_d.append(t[0]) - uniq_l.append(t[1]) - logger.debug(f"Adding label {t[0]} at {t[1]}") - elif t[1] == temp_ticks[i - 1][1]: - logger.debug(f"Skipping label {t[1]}") + for idx, tick in enumerate(temp_ticks): + if idx == 0: + uniq_d.append(tick[0]) + uniq_l.append(tick[1]) + logger.debug(f"Adding label {tick[0]} at {tick[1]}") + elif tick[1] == temp_ticks[idx - 1][1]: + logger.debug(f"Skipping label {tick[1]}") else: - logger.debug(f"Adding label {t[0]} at {t[1]}") - uniq_d.append(t[0]) - uniq_l.append(t[1]) + logger.debug(f"Adding label {tick[0]} at {tick[1]}") + uniq_d.append(tick[0]) + uniq_l.append(tick[1]) logger.debug(f"Unique labels are {list(zip(uniq_d, uniq_l))}") ax.set_xticks(uniq_d) ax.set_xticklabels(uniq_l) - for i in range(len(f_label)): - if f_label[i] is not None: + for idx in range(len(f_label)): + if f_label[idx] is not None: # don't print the same label twice - if i != 0: - if f_label[i] == f_label[i - 1]: - logger.debug(f"already print label... skipping label {f_label[i]}") + if idx != 0: + if f_label[idx] == f_label[idx - 1]: + logger.debug(f"already print label... skipping label {f_label[idx]}") else: - logger.debug(f"Adding a line at {f_distance[i]} for label {f_label[i]}") - ax.axvline(f_distance[i], color="k") + logger.debug(f"Adding a line at {f_distance[idx]} for label {f_label[idx]}") + ax.axvline(f_distance[idx], color="k") else: - logger.debug(f"Adding a line at {f_distance[i]} for label {f_label[i]}") - ax.axvline(f_distance[i], color="k") + logger.debug(f"Adding a line at {f_distance[idx]} for label {f_label[idx]}") + ax.axvline(f_distance[idx], color="k") shift = [] br = -1 @@ -2502,11 +2514,11 @@ def _rgbline(ax, k, e, red, green, blue, alpha=1, linestyles="solid") -> None: seg = np.concatenate([pts[:-1], pts[1:]], axis=1) n_seg = len(k) - 1 - r = [0.5 * (red[i] + red[i + 1]) for i in range(n_seg)] - g = [0.5 * (green[i] + green[i + 1]) for i in range(n_seg)] - b = [0.5 * (blue[i] + blue[i + 1]) for i in range(n_seg)] - a = np.ones(n_seg, float) * alpha - lc = LineCollection(seg, colors=list(zip(r, g, b, a)), linewidth=2, linestyles=linestyles) + red = [0.5 * (red[i] + red[i + 1]) for i in range(n_seg)] + green = [0.5 * (green[i] + green[i + 1]) for i in range(n_seg)] + blue = [0.5 * (blue[i] + blue[i + 1]) for i in range(n_seg)] + alpha = np.ones(n_seg, float) * alpha + lc = LineCollection(seg, colors=list(zip(red, green, blue, alpha)), linewidth=2, linestyles=linestyles) ax.add_collection(lc) @staticmethod @@ -2682,16 +2694,13 @@ def _rb_line(ax, r_label, b_label, loc) -> None: inset_ax = inset_axes(ax, width=1.2, height=0.4, loc=loc) - x = [] - y = [] - color = [] - for i in range(1000): - x.append(i / 1800.0 + 0.55) + x, y, color = [], [], [] + for idx in range(1000): + x.append(idx / 1800.0 + 0.55) y.append(0) - color.append([math.sqrt(c) for c in [1 - (i / 1000) ** 2, 0, (i / 1000) ** 2]]) + color.append([math.sqrt(c) for c in [1 - (idx / 1000) ** 2, 0, (idx / 1000) ** 2]]) # plot the bar - inset_ax.scatter(x, y, s=250.0, marker="s", c=color) inset_ax.set_xlim([-0.1, 1.7]) inset_ax.text( @@ -2972,8 +2981,10 @@ def plot_power_factor_mu( a matplotlib object """ ax = pretty_plot(9, 7) - pf = self._bz.get_power_factor(relaxation_time=relaxation_time, output=output, doping_levels=False)[temp] - ax.semilogy(self._bz.mu_steps, pf, linewidth=3.0) + pow_factor = self._bz.get_power_factor(relaxation_time=relaxation_time, output=output, doping_levels=False)[ + temp + ] + ax.semilogy(self._bz.mu_steps, pow_factor, linewidth=3.0) self._plot_bg_limits(ax) self._plot_doping(ax, temp) if output == "eig": @@ -3145,12 +3156,12 @@ def plot_power_factor_temp(self, doping="all", output="average", relaxation_time a matplotlib object """ if output == "average": - pf = self._bz.get_power_factor(relaxation_time=relaxation_time, output="average") + pow_factor = self._bz.get_power_factor(relaxation_time=relaxation_time, output="average") elif output == "eigs": - pf = self._bz.get_power_factor(relaxation_time=relaxation_time, output="eigs") + pow_factor = self._bz.get_power_factor(relaxation_time=relaxation_time, output="eigs") ax = pretty_plot(22, 14) - tlist = sorted(pf["n"]) + tlist = sorted(pow_factor["n"]) doping = self._bz.doping["n"] if doping == "all" else doping for idx, doping_type in enumerate(["n", "p"]): plt.subplot(121 + idx) @@ -3158,7 +3169,7 @@ def plot_power_factor_temp(self, doping="all", output="average", relaxation_time dop_idx = self._bz.doping[doping_type].index(dop) pf_temp = [] for temp in tlist: - pf_temp.append(pf[doping_type][temp][dop_idx]) + pf_temp.append(pow_factor[doping_type][temp][dop_idx]) if output == "average": ax.plot(tlist, pf_temp, marker="s", label=f"{dop} $cm^{-3}$") elif output == "eigs": @@ -3387,11 +3398,11 @@ def plot_power_factor_dop(self, temps="all", output="average", relaxation_time=1 a matplotlib object """ if output == "average": - pf = self._bz.get_power_factor(relaxation_time=relaxation_time, output="average") + pow_factor = self._bz.get_power_factor(relaxation_time=relaxation_time, output="average") elif output == "eigs": - pf = self._bz.get_power_factor(relaxation_time=relaxation_time, output="eigs") + pow_factor = self._bz.get_power_factor(relaxation_time=relaxation_time, output="eigs") - tlist = sorted(pf["n"]) if temps == "all" else temps + tlist = sorted(pow_factor["n"]) if temps == "all" else temps ax = pretty_plot(22, 14) for i, dt in enumerate(["n", "p"]): plt.subplot(121 + i) @@ -3399,10 +3410,13 @@ def plot_power_factor_dop(self, temps="all", output="average", relaxation_time=1 if output == "eigs": for xyz in range(3): ax.semilogx( - self._bz.doping[dt], list(zip(*pf[dt][temp]))[xyz], marker="s", label=f"{xyz} {temp} K" + self._bz.doping[dt], + list(zip(*pow_factor[dt][temp]))[xyz], + marker="s", + label=f"{xyz} {temp} K", ) elif output == "average": - ax.semilogx(self._bz.doping[dt], pf[dt][temp], marker="s", label=f"{temp} K") + ax.semilogx(self._bz.doping[dt], pow_factor[dt][temp], marker="s", label=f"{temp} K") ax.set_title(dt + "-type", fontsize=20) if i == 0: ax.set_ylabel("Power Factor ($\\mu$W/(mK$^2$))", fontsize=30.0) @@ -3861,12 +3875,11 @@ def plot_fermi_surface( run mlab.show(). Returns: - ((mayavi.mlab.figure, mayavi.mlab)): The mlab plotter and an interactive + tuple[mlab.figure, mlab]: The mlab plotter and an interactive figure to control the plot. Note: Experimental. - Please, double check the surface shown by using some - other software and report issues. + Please, double check the surface shown by using some other software and report issues. """ bz = structure.lattice.reciprocal_lattice.get_wigner_seitz_cell() cell = structure.lattice.reciprocal_lattice.matrix diff --git a/pymatgen/entries/compatibility.py b/pymatgen/entries/compatibility.py index e6c70c2dd5f..93e9beb919e 100644 --- a/pymatgen/entries/compatibility.py +++ b/pymatgen/entries/compatibility.py @@ -19,7 +19,7 @@ from uncertainties import ufloat from pymatgen.analysis.structure_analyzer import oxide_type, sulfide_type -from pymatgen.core import SETTINGS, Element +from pymatgen.core import SETTINGS, Composition, Element from pymatgen.entries.computed_entries import ( CompositionEnergyAdjustment, ComputedEntry, @@ -34,6 +34,8 @@ if TYPE_CHECKING: from collections.abc import Sequence + from pymatgen.util.typing import CompositionLike + __author__ = "Amanda Wang, Ryan Kingsbury, Shyue Ping Ong, Anubhav Jain, Stephen Dacek, Sai Jayaraman" __copyright__ = "Copyright 2012-2020, The Materials Project" __version__ = "1.0" @@ -43,6 +45,13 @@ MODULE_DIR = os.path.dirname(os.path.abspath(__file__)) MU_H2O = -2.4583 # Free energy of formation of water, eV/H2O, used by MaterialsProjectAqueousCompatibility +MP2020_COMPAT_CONFIG = loadfn(f"{MODULE_DIR}/MP2020Compatibility.yaml") +MP_COMPAT_CONFIG = loadfn(f"{MODULE_DIR}/MPCompatibility.yaml") + +assert ( # ping @janosh @rkingsbury on GitHub if this fails + MP2020_COMPAT_CONFIG["Corrections"]["GGAUMixingCorrections"]["O"] + == MP2020_COMPAT_CONFIG["Corrections"]["GGAUMixingCorrections"]["F"] +), "MP2020Compatibility.yaml expected to have the same Hubbard U corrections for O and F" AnyComputedEntry = Union[ComputedEntry, ComputedStructureEntry] @@ -189,12 +198,14 @@ def __init__(self, config_file): Args: config_file: Path to the selected compatibility.yaml config file. """ - c = loadfn(config_file) - self.name = c["Name"] - self.cpd_energies = c["Advanced"]["CompoundEnergies"] + config = loadfn(config_file) + self.name = config["Name"] + self.cpd_energies = config["Advanced"]["CompoundEnergies"] def get_correction(self, entry) -> ufloat: - """:param entry: A ComputedEntry/ComputedStructureEntry + """ + Args: + entry: A ComputedEntry/ComputedStructureEntry. Returns: Correction. @@ -228,14 +239,16 @@ def __init__(self, config_file, correct_peroxide=True): correct_peroxide: Specify whether peroxide/superoxide/ozonide corrections are to be applied or not. """ - c = loadfn(config_file) - self.oxide_correction = c["OxideCorrections"] - self.sulfide_correction = c.get("SulfideCorrections", defaultdict(float)) - self.name = c["Name"] + config = loadfn(config_file) + self.oxide_correction = config["OxideCorrections"] + self.sulfide_correction = config.get("SulfideCorrections", defaultdict(float)) + self.name = config["Name"] self.correct_peroxide = correct_peroxide def get_correction(self, entry) -> ufloat: - """:param entry: A ComputedEntry/ComputedStructureEntry + """ + Args: + entry: A ComputedEntry/ComputedStructureEntry. Returns: Correction. @@ -317,14 +330,14 @@ def __init__(self, config_file, error_file=None): config_file: Path to the selected compatibility.yaml config file. error_file: Path to the selected compatibilityErrors.yaml config file. """ - c = loadfn(config_file) - self.cpd_energies = c["AqueousCompoundEnergies"] + config = loadfn(config_file) + self.cpd_energies = config["AqueousCompoundEnergies"] # there will either be a CompositionCorrections OR an OxideCorrections key, # but not both, depending on the compatibility scheme we are using. # MITCompatibility only uses OxideCorrections, and hence self.comp_correction is none. - self.comp_correction = c.get("CompositionCorrections", defaultdict(float)) - self.oxide_correction = c.get("OxideCorrections", defaultdict(float)) - self.name = c["Name"] + self.comp_correction = config.get("CompositionCorrections", defaultdict(float)) + self.oxide_correction = config.get("OxideCorrections", defaultdict(float)) + self.name = config["Name"] if error_file: e = loadfn(error_file) self.cpd_errors = e.get("AqueousCompoundEnergies", defaultdict(float)) @@ -332,7 +345,9 @@ def __init__(self, config_file, error_file=None): self.cpd_errors = defaultdict(float) def get_correction(self, entry) -> ufloat: - """:param entry: A ComputedEntry/ComputedStructureEntry + """ + Args: + entry: A ComputedEntry/ComputedStructureEntry. Returns: Correction, Uncertainty. @@ -426,27 +441,29 @@ def __init__(self, config_file, input_set, compat_type, error_file=None): if compat_type not in ["GGA", "Advanced"]: raise CompatibilityError(f"Invalid {compat_type=}") - c = loadfn(config_file) + config = loadfn(config_file) self.input_set = input_set if compat_type == "Advanced": self.u_settings = self.input_set.CONFIG["INCAR"]["LDAUU"] - self.u_corrections = c["Advanced"]["UCorrections"] + self.u_corrections = config["Advanced"]["UCorrections"] else: self.u_settings = {} self.u_corrections = {} - self.name = c["Name"] + self.name = config["Name"] self.compat_type = compat_type if error_file: - e = loadfn(error_file) - self.u_errors = e["Advanced"]["UCorrections"] + err = loadfn(error_file) + self.u_errors = err["Advanced"]["UCorrections"] else: self.u_errors = {} def get_correction(self, entry) -> ufloat: - """:param entry: A ComputedEntry/ComputedStructureEntry + """ + Args: + entry: A ComputedEntry/ComputedStructureEntry. Returns: Correction, Uncertainty. @@ -455,12 +472,12 @@ def get_correction(self, entry) -> ufloat: comp = entry.composition elements = sorted((el for el in comp.elements if comp[el] > 0), key=lambda el: el.X) - most_electroneg = elements[-1].symbol + most_electro_neg = elements[-1].symbol correction = ufloat(0.0, 0.0) - u_corr = self.u_corrections.get(most_electroneg, {}) - u_settings = self.u_settings.get(most_electroneg, {}) - u_errors = self.u_errors.get(most_electroneg, defaultdict(float)) + u_corr = self.u_corrections.get(most_electro_neg, {}) + u_settings = self.u_settings.get(most_electro_neg, {}) + u_errors = self.u_errors.get(most_electro_neg, defaultdict(float)) for el in comp.elements: sym = el.symbol @@ -690,13 +707,13 @@ def get_explanation_dict(self, entry): entry: A ComputedEntry. Returns: - (dict) of the form - {"Compatibility": "string", - "Uncorrected_energy": float, - "Corrected_energy": float, - "correction_uncertainty:" float, - "Corrections": [{"Name of Correction": { - "Value": float, "Explanation": "string", "Uncertainty": float}]} + dict[str, str | float | list[dict[str, Union[str, float]]]: of the form + {"Compatibility": "string", + "Uncorrected_energy": float, + "Corrected_energy": float, + "correction_uncertainty:" float, + "Corrections": [{"Name of Correction": { + "Value": float, "Explanation": "string", "Uncertainty": float}]} """ corr_entry = self.process_entry(entry) uncorrected_energy = (corr_entry or entry).uncorrected_energy @@ -782,13 +799,13 @@ def __init__( self.compat_type = compat_type self.correct_peroxide = correct_peroxide self.check_potcar_hash = check_potcar_hash - fp = f"{MODULE_DIR}/MPCompatibility.yaml" + file_path = f"{MODULE_DIR}/MPCompatibility.yaml" super().__init__( [ PotcarCorrection(MPRelaxSet, check_hash=check_potcar_hash), - GasCorrection(fp), - AnionCorrection(fp, correct_peroxide=correct_peroxide), - UCorrection(fp, MPRelaxSet, compat_type), + GasCorrection(file_path), + AnionCorrection(file_path, correct_peroxide=correct_peroxide), + UCorrection(file_path, MPRelaxSet, compat_type), ] ) @@ -883,21 +900,22 @@ def __init__( if config_file: if os.path.isfile(config_file): self.config_file: str | None = config_file - c = loadfn(self.config_file) + config = loadfn(self.config_file) else: raise ValueError(f"Custom MaterialsProject2020Compatibility {config_file=} does not exist.") else: self.config_file = None - c = loadfn(f"{MODULE_DIR}/MP2020Compatibility.yaml") - self.name = c["Name"] - self.comp_correction = c["Corrections"].get("CompositionCorrections", defaultdict(float)) - self.comp_errors = c["Uncertainties"].get("CompositionCorrections", defaultdict(float)) + config = MP2020_COMPAT_CONFIG + + self.name = config["Name"] + self.comp_correction = config["Corrections"].get("CompositionCorrections", defaultdict(float)) + self.comp_errors = config["Uncertainties"].get("CompositionCorrections", defaultdict(float)) if self.compat_type == "Advanced": self.u_settings = MPRelaxSet.CONFIG["INCAR"]["LDAUU"] - self.u_corrections = c["Corrections"].get("GGAUMixingCorrections", defaultdict(float)) - self.u_errors = c["Uncertainties"].get("GGAUMixingCorrections", defaultdict(float)) + self.u_corrections = config["Corrections"].get("GGAUMixingCorrections", defaultdict(float)) + self.u_errors = config["Uncertainties"].get("GGAUMixingCorrections", defaultdict(float)) else: self.u_settings = {} self.u_corrections = {} @@ -1064,10 +1082,10 @@ def get_adjustments(self, entry: AnyComputedEntry) -> list[EnergyAdjustment]: for el in comp.elements: symbol = el.symbol # Check for bad U values - expected_u = u_settings.get(symbol, 0) - actual_u = calc_u.get(symbol, 0) + expected_u = float(u_settings.get(symbol, 0)) + actual_u = float(calc_u.get(symbol, 0)) if actual_u != expected_u: - raise CompatibilityError(f"Invalid U value of {actual_u:.1f} on {symbol}, expected {expected_u:.1f}") + raise CompatibilityError(f"Invalid U value of {actual_u:.3} on {symbol}, expected {expected_u:.3}") if symbol in u_corrections: adjustments.append( CompositionEnergyAdjustment( @@ -1449,3 +1467,29 @@ def process_entries( self.h2_energy = h2_entries[0].energy_per_atom # type: ignore[assignment] return super().process_entries(entries, clean=clean, verbose=verbose, inplace=inplace, on_error=on_error) + + +def needs_u_correction( + comp: CompositionLike, + u_config: dict[str, dict[str, float]] = MP2020_COMPAT_CONFIG["Corrections"]["GGAUMixingCorrections"], +) -> set[str]: + """Check if a composition is Hubbard U-corrected in the Materials Project 2020 + GGA/GGA+U mixing scheme. + + Args: + comp (CompositionLike): The formula/composition to check. + u_config (dict): The U-correction configuration to use. Default is the + Materials Project 2020 configuration. + + Returns: + set[str]: The subset of elements whose combination requires a U-correction. Pass + return value to bool(ret_val) if you just want True/False. + """ + elements = set(map(str, Composition(comp).elements)) + has_u_anion = set(u_config) & elements + + u_corrected_cations = set(u_config["O"]) + has_u_cation = u_corrected_cations & elements + if has_u_cation and has_u_anion: + return has_u_cation | has_u_anion + return set() diff --git a/pymatgen/entries/computed_entries.py b/pymatgen/entries/computed_entries.py index 26f22a37d1d..a24ad687412 100644 --- a/pymatgen/entries/computed_entries.py +++ b/pymatgen/entries/computed_entries.py @@ -25,6 +25,8 @@ from pymatgen.util.due import Doi, due if TYPE_CHECKING: + from typing_extensions import Self + from pymatgen.core import Structure __author__ = "Ryan Kingsbury, Matt McDermott, Shyue Ping Ong, Anubhav Jain" @@ -130,10 +132,12 @@ def explain(self): """Return an explanation of how the energy adjustment is calculated.""" return f"{self.description} ({self.value:.3f} eV)" - def normalize(self, factor): + def normalize(self, factor: float) -> None: """Normalize energy adjustment (in place), dividing value/uncertainty by a factor. - :param factor: factor to divide by. + + Args: + factor: factor to divide by. """ self._value /= factor self._uncertainty /= factor @@ -200,10 +204,12 @@ def explain(self): """Return an explanation of how the energy adjustment is calculated.""" return f"{self.description} ({self._adj_per_atom:.3f} eV/atom x {self.n_atoms} atoms)" - def normalize(self, factor): + def normalize(self, factor: float) -> None: """Normalize energy adjustment (in place), dividing value/uncertainty by a factor. - :param factor: factor to divide by. + + Args: + factor: factor to divide by. """ self.n_atoms /= factor @@ -259,10 +265,12 @@ def explain(self): """Return an explanation of how the energy adjustment is calculated.""" return f"{self.description} ({self._adj_per_deg:.4f} eV/K/atom x {self.temp} K x {self.n_atoms} atoms)" - def normalize(self, factor): + def normalize(self, factor: float) -> None: """Normalize energy adjustment (in place), dividing value/uncertainty by a factor. - :param factor: factor to divide by. + + Args: + factor: factor to divide by. """ self.n_atoms /= factor @@ -467,35 +475,36 @@ def __eq__(self, other: object) -> bool: return True @classmethod - def from_dict(cls, d) -> ComputedEntry: - """:param d: Dict representation. + def from_dict(cls, dct: dict) -> Self: + """ + Args: + dct (dict): Dict representation. Returns: ComputedEntry """ - dec = MontyDecoder() # the first block here is for legacy ComputedEntry that were # serialized before we had the energy_adjustments attribute. - if d["correction"] != 0 and not d.get("energy_adjustments"): + if dct["correction"] != 0 and not dct.get("energy_adjustments"): return cls( - d["composition"], - d["energy"], - d["correction"], - parameters={k: dec.process_decoded(v) for k, v in d.get("parameters", {}).items()}, - data={k: dec.process_decoded(v) for k, v in d.get("data", {}).items()}, - entry_id=d.get("entry_id"), + dct["composition"], + dct["energy"], + dct["correction"], + parameters={k: MontyDecoder().process_decoded(v) for k, v in dct.get("parameters", {}).items()}, + data={k: MontyDecoder().process_decoded(v) for k, v in dct.get("data", {}).items()}, + entry_id=dct.get("entry_id"), ) # this is the preferred / modern way of instantiating ComputedEntry # we don't pass correction explicitly because it will be calculated # on the fly from energy_adjustments return cls( - d["composition"], - d["energy"], + dct["composition"], + dct["energy"], correction=0, - energy_adjustments=[dec.process_decoded(e) for e in d.get("energy_adjustments", {})], - parameters={k: dec.process_decoded(v) for k, v in d.get("parameters", {}).items()}, - data={k: dec.process_decoded(v) for k, v in d.get("data", {}).items()}, - entry_id=d.get("entry_id"), + energy_adjustments=[MontyDecoder().process_decoded(e) for e in dct.get("energy_adjustments", {})], + parameters={k: MontyDecoder().process_decoded(v) for k, v in dct.get("parameters", {}).items()}, + data={k: MontyDecoder().process_decoded(v) for k, v in dct.get("data", {}).items()}, + entry_id=dct.get("entry_id"), ) def as_dict(self) -> dict: @@ -602,37 +611,38 @@ def as_dict(self) -> dict: return dct @classmethod - def from_dict(cls, d) -> ComputedStructureEntry: - """:param d: Dict representation. + def from_dict(cls, dct) -> Self: + """ + Args: + dct (dict): Dict representation. Returns: ComputedStructureEntry """ - dec = MontyDecoder() # the first block here is for legacy ComputedEntry that were # serialized before we had the energy_adjustments attribute. - if d["correction"] != 0 and not d.get("energy_adjustments"): - struct = dec.process_decoded(d["structure"]) + if dct["correction"] != 0 and not dct.get("energy_adjustments"): + struct = MontyDecoder().process_decoded(dct["structure"]) return cls( struct, - d["energy"], - correction=d["correction"], - parameters={k: dec.process_decoded(v) for k, v in d.get("parameters", {}).items()}, - data={k: dec.process_decoded(v) for k, v in d.get("data", {}).items()}, - entry_id=d.get("entry_id"), + dct["energy"], + correction=dct["correction"], + parameters={k: MontyDecoder().process_decoded(v) for k, v in dct.get("parameters", {}).items()}, + data={k: MontyDecoder().process_decoded(v) for k, v in dct.get("data", {}).items()}, + entry_id=dct.get("entry_id"), ) # this is the preferred / modern way of instantiating ComputedEntry # we don't pass correction explicitly because it will be calculated # on the fly from energy_adjustments return cls( - dec.process_decoded(d["structure"]), - d["energy"], - composition=d.get("composition"), + MontyDecoder().process_decoded(dct["structure"]), + dct["energy"], + composition=dct.get("composition"), correction=0, - energy_adjustments=[dec.process_decoded(e) for e in d.get("energy_adjustments", {})], - parameters={k: dec.process_decoded(v) for k, v in d.get("parameters", {}).items()}, - data={k: dec.process_decoded(v) for k, v in d.get("data", {}).items()}, - entry_id=d.get("entry_id"), + energy_adjustments=[MontyDecoder().process_decoded(e) for e in dct.get("energy_adjustments", {})], + parameters={k: MontyDecoder().process_decoded(v) for k, v in dct.get("parameters", {}).items()}, + data={k: MontyDecoder().process_decoded(v) for k, v in dct.get("data", {}).items()}, + entry_id=dct.get("entry_id"), ) def normalize(self, mode: Literal["formula_unit", "atom"] = "formula_unit") -> ComputedStructureEntry: @@ -869,7 +879,7 @@ def _g_delta_sisso(vol_per_atom, reduced_mass, temp) -> float: ) @classmethod - def from_pd(cls, pd, temp=300, gibbs_model="SISSO") -> list[GibbsComputedStructureEntry]: + def from_pd(cls, pd, temp=300, gibbs_model="SISSO") -> list[Self]: """Constructor method for initializing a list of GibbsComputedStructureEntry objects from an existing T = 0 K phase diagram composed of ComputedStructureEntry objects, as acquired from a thermochemical database; @@ -903,7 +913,7 @@ def from_pd(cls, pd, temp=300, gibbs_model="SISSO") -> list[GibbsComputedStructu return gibbs_entries @classmethod - def from_entries(cls, entries, temp=300, gibbs_model="SISSO") -> list[GibbsComputedStructureEntry]: + def from_entries(cls, entries, temp=300, gibbs_model="SISSO") -> list[Self]: """Constructor method for initializing GibbsComputedStructureEntry objects from T = 0 K ComputedStructureEntry objects, as acquired from a thermochemical database e.g. The Materials Project. @@ -915,7 +925,7 @@ def from_entries(cls, entries, temp=300, gibbs_model="SISSO") -> list[GibbsCompu gibbs_model (str): Gibbs model to use; currently the only option is "SISSO". Returns: - [GibbsComputedStructureEntry]: list of new entries which replace the orig. + list[GibbsComputedStructureEntry]: new entries which replace the orig. entries with inclusion of Gibbs free energy of formation at the specified temperature. """ @@ -934,29 +944,30 @@ def as_dict(self) -> dict: return dct @classmethod - def from_dict(cls, d) -> GibbsComputedStructureEntry: - """:param d: Dict representation. + def from_dict(cls, dct) -> Self: + """ + Args: + dct (dict): Dict representation. Returns: GibbsComputedStructureEntry """ dec = MontyDecoder() return cls( - dec.process_decoded(d["structure"]), - d["formation_enthalpy_per_atom"], - d["temp"], - d["gibbs_model"], - composition=d.get("composition"), - correction=d["correction"], - energy_adjustments=[dec.process_decoded(e) for e in d.get("energy_adjustments", {})], - parameters={k: dec.process_decoded(v) for k, v in d.get("parameters", {}).items()}, - data={k: dec.process_decoded(v) for k, v in d.get("data", {}).items()}, - entry_id=d.get("entry_id"), + dec.process_decoded(dct["structure"]), + dct["formation_enthalpy_per_atom"], + dct["temp"], + dct["gibbs_model"], + composition=dct.get("composition"), + correction=dct["correction"], + energy_adjustments=[dec.process_decoded(e) for e in dct.get("energy_adjustments", {})], + parameters={k: dec.process_decoded(v) for k, v in dct.get("parameters", {}).items()}, + data={k: dec.process_decoded(v) for k, v in dct.get("data", {}).items()}, + entry_id=dct.get("entry_id"), ) def __repr__(self): - output = [ - f"GibbsComputedStructureEntry {self.entry_id} - {self.formula}", - f"Gibbs Free Energy (Formation) = {self.energy:.4f}", - ] - return "\n".join(output) + return ( + f"GibbsComputedStructureEntry {self.entry_id} - {self.formula}\n" + f"Gibbs Free Energy (Formation) = {self.energy:.4f}" + ) diff --git a/pymatgen/entries/entry_tools.py b/pymatgen/entries/entry_tools.py index 8e431050c58..4bf3978bf79 100644 --- a/pymatgen/entries/entry_tools.py +++ b/pymatgen/entries/entry_tools.py @@ -23,6 +23,8 @@ if TYPE_CHECKING: from collections.abc import Iterable + from typing_extensions import Self + from pymatgen.entries import Entry from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry @@ -211,14 +213,16 @@ def __len__(self): def add(self, element): """Add an entry. - :param element: Entry + Args: + element: Entry """ self.entries.add(element) def discard(self, element): """Discard an entry. - :param element: Entry + Args: + element: Entry """ self.entries.discard(element) @@ -239,10 +243,10 @@ def ground_states(self) -> set: per atom entry at each composition. """ entries = sorted(self.entries, key=lambda e: e.reduced_formula) - ground_states = set() - for _, g in itertools.groupby(entries, key=lambda e: e.reduced_formula): - ground_states.add(min(g, key=lambda e: e.energy_per_atom)) - return ground_states + return { + min(g, key=lambda e: e.energy_per_atom) + for _, g in itertools.groupby(entries, key=lambda e: e.reduced_formula) + } def remove_non_ground_states(self): """Removes all non-ground state entries, i.e., only keep the lowest energy @@ -312,7 +316,7 @@ def to_csv(self, filename: str, latexify_names: bool = False) -> None: writer.writerow(row) @classmethod - def from_csv(cls, filename: str): + def from_csv(cls, filename: str) -> Self: """Imports PDEntries from a csv. Args: diff --git a/pymatgen/entries/exp_entries.py b/pymatgen/entries/exp_entries.py index 8ac366c0d0a..e012ed7a9fb 100644 --- a/pymatgen/entries/exp_entries.py +++ b/pymatgen/entries/exp_entries.py @@ -2,12 +2,17 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from monty.json import MSONable from pymatgen.analysis.phase_diagram import PDEntry from pymatgen.analysis.thermochemistry import ThermoData from pymatgen.core.composition import Composition +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "Shyue Ping Ong" __copyright__ = "Copyright 2012, The Materials Project" __version__ = "0.1" @@ -51,14 +56,16 @@ def __repr__(self): return f"ExpEntry {self.formula}, Energy = {self.energy:.4f}" @classmethod - def from_dict(cls, d): - """:param d: Dict representation. + def from_dict(cls, dct: dict) -> Self: + """ + Args: + dct (dict): Dict representation. Returns: ExpEntry """ - thermodata = [ThermoData.from_dict(td) for td in d["thermodata"]] - return cls(d["composition"], thermodata, d["temperature"]) + thermodata = [ThermoData.from_dict(td) for td in dct["thermodata"]] + return cls(dct["composition"], thermodata, dct["temperature"]) def as_dict(self): """MSONable dict.""" diff --git a/pymatgen/entries/mixing_scheme.py b/pymatgen/entries/mixing_scheme.py index 461887e7ff6..d7fe539b718 100644 --- a/pymatgen/entries/mixing_scheme.py +++ b/pymatgen/entries/mixing_scheme.py @@ -728,10 +728,10 @@ def display_entries(entries): f"{'entry_id':<12}{'formula':<12}{'spacegroup':<12}{'run_type':<10}{'eV/atom':<8}" f"{'corr/atom':<9} {'e_above_hull':<9}" ) - for e in entries: + for entry in entries: print( - f"{e.entry_id:<12}{e.reduced_formula:<12}{e.structure.get_space_group_info()[0]:<12}" - f"{e.parameters['run_type']:<10}{e.energy_per_atom:<8.3f}" - f"{e.correction / e.composition.num_atoms:<9.3f} {pd.get_e_above_hull(e):<9.3f}" + f"{entry.entry_id:<12}{entry.reduced_formula:<12}{entry.structure.get_space_group_info()[0]:<12}" + f"{entry.parameters['run_type']:<10}{entry.energy_per_atom:<8.3f}" + f"{entry.correction / entry.composition.num_atoms:<9.3f} {pd.get_e_above_hull(entry):<9.3f}" ) return diff --git a/pymatgen/ext/cod.py b/pymatgen/ext/cod.py index 9e4d55a31e3..9a720104b2a 100644 --- a/pymatgen/ext/cod.py +++ b/pymatgen/ext/cod.py @@ -47,13 +47,14 @@ class COD: def query(self, sql: str) -> str: """Perform a query. - :param sql: SQL string + Args: + sql: SQL string Returns: Response from SQL query. """ - resp = subprocess.check_output(["mysql", "-u", "cod_reader", "-h", self.url, "-e", sql, "cod"]) - return resp.decode("utf-8") + response = subprocess.check_output(["mysql", "-u", "cod_reader", "-h", self.url, "-e", sql, "cod"]) + return response.decode("utf-8") @requires(which("mysql"), "mysql must be installed to use this query.") def get_cod_ids(self, formula): @@ -89,8 +90,8 @@ def get_structure_by_id(self, cod_id, **kwargs): Returns: A Structure. """ - r = requests.get(f"http://{self.url}/cod/{cod_id}.cif") - return Structure.from_str(r.text, fmt="cif", **kwargs) + response = requests.get(f"http://{self.url}/cod/{cod_id}.cif") + return Structure.from_str(response.text, fmt="cif", **kwargs) @requires(which("mysql"), "mysql must be installed to use this query.") def get_structure_by_formula(self, formula: str, **kwargs) -> list[dict[str, str | int | Structure]]: @@ -111,12 +112,12 @@ def get_structure_by_formula(self, formula: str, **kwargs) -> list[dict[str, str for line in text: if line.strip(): cod_id, sg = line.split("\t") - r = requests.get(f"http://www.crystallography.net/cod/{cod_id.strip()}.cif") + response = requests.get(f"http://www.crystallography.net/cod/{cod_id.strip()}.cif") try: - struct = Structure.from_str(r.text, fmt="cif", **kwargs) + struct = Structure.from_str(response.text, fmt="cif", **kwargs) structures.append({"structure": struct, "cod_id": int(cod_id), "sg": sg}) except Exception: - warnings.warn(f"\nStructure.from_str failed while parsing CIF file:\n{r.text}") + warnings.warn(f"\nStructure.from_str failed while parsing CIF file:\n{response.text}") raise return structures diff --git a/pymatgen/ext/matproj.py b/pymatgen/ext/matproj.py index e86589a7b36..0766350f5c3 100644 --- a/pymatgen/ext/matproj.py +++ b/pymatgen/ext/matproj.py @@ -27,8 +27,12 @@ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer if TYPE_CHECKING: + from mp_api.client import MPRester as _MPResterNew + from typing_extensions import Self + from pymatgen.core.structure import Structure from pymatgen.entries.computed_entries import ComputedStructureEntry + from pymatgen.ext.matproj_legacy import _MPResterLegacy logger = logging.getLogger(__name__) @@ -94,7 +98,7 @@ def __getattr__(self, item): "used by 80% of users. If you are looking for the full functionality MPRester, pls install the mp-api ." ) - def __enter__(self): + def __enter__(self) -> Self: """Support for "with" context.""" return self @@ -208,9 +212,9 @@ def get_structures(self, chemsys_formula: str, final=True) -> list[Structure]: """ query = f"chemsys={chemsys_formula}" if "-" in chemsys_formula else f"formula={chemsys_formula}" prop = "structure" if final else "initial_structure" - resp = self.request(f"materials/summary/?{query}&_all_fields=false&_fields={prop}") + response = self.request(f"materials/summary/?{query}&_all_fields=false&_fields={prop}") - return [d[prop] for d in resp] + return [dct[prop] for dct in response] def get_structure_by_material_id(self, material_id: str, conventional_unit_cell: bool = False) -> Structure: """ @@ -226,8 +230,8 @@ def get_structure_by_material_id(self, material_id: str, conventional_unit_cell: Structure object. """ prop = "structure" - resp = self.request(f"materials/summary/{material_id}/?_fields={prop}") - structure = resp[0][prop] + response = self.request(f"materials/summary/{material_id}/?_fields={prop}") + structure = response[0][prop] if conventional_unit_cell: return SpacegroupAnalyzer(structure).get_conventional_standard_structure() return structure @@ -248,8 +252,8 @@ def get_initial_structures_by_material_id( Structure object. """ prop = "initial_structures" - resp = self.request(f"materials/summary/{material_id}/?_fields={prop}") - structures = resp[0][prop] + response = self.request(f"materials/summary/{material_id}/?_fields={prop}") + structures = response[0][prop] if conventional_unit_cell: return [SpacegroupAnalyzer(s).get_conventional_standard_structure() for s in structures] # type: ignore return structures @@ -299,9 +303,9 @@ def get_entries( query = f"formula={criteria}" entries = [] - r = self.request(f"materials/thermo/?_fields=entries&{query}") - for d in r: - entries.extend(d["entries"].values()) + response = self.request(f"materials/thermo/?_fields=entries&{query}") + for dct in response: + entries.extend(dct["entries"].values()) if compatible_only: from pymatgen.entries.compatibility import MaterialsProject2020Compatibility @@ -372,8 +376,8 @@ class MPRester: for which API to use. """ - def __new__(cls, *args, **kwargs): - r""" + def __new__(cls, *args, **kwargs) -> _MPResterNew | _MPResterBasic | _MPResterLegacy: # type: ignore[misc] + """ Args: *args: Pass through to either legacy or new MPRester. **kwargs: Pass through to either legacy or new MPRester. diff --git a/pymatgen/ext/matproj_legacy.py b/pymatgen/ext/matproj_legacy.py index b45f6210efe..84ed591e560 100644 --- a/pymatgen/ext/matproj_legacy.py +++ b/pymatgen/ext/matproj_legacy.py @@ -35,8 +35,11 @@ if TYPE_CHECKING: from collections.abc import Sequence + from typing_extensions import Self + from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine from pymatgen.phonon.dos import CompletePhononDos + logger = logging.getLogger(__name__) MP_LOG_FILE = os.path.join(os.path.expanduser("~"), ".mprester.log.yaml") @@ -233,7 +236,7 @@ def __init__( except Exception: pass - def __enter__(self): + def __enter__(self) -> Self: """Support for "with" context.""" return self @@ -450,10 +453,10 @@ def find_structure(self, filename_or_structure): payload = {"structure": json.dumps(struct.as_dict(), cls=MontyEncoder)} response = self.session.post(f"{self.preamble}/find_structure", data=payload) if response.status_code in [200, 400]: - resp = json.loads(response.text, cls=MontyDecoder) - if resp["valid_response"]: - return resp["response"] - raise MPRestError(resp["error"]) + response = json.loads(response.text, cls=MontyDecoder) + if response["valid_response"]: + return response["response"] + raise MPRestError(response["error"]) raise MPRestError(f"REST error with status code {response.status_code} and error {response.text}") def get_entries( @@ -1100,12 +1103,12 @@ def submit_snl(self, snl): payload = {"snl": json.dumps(json_data, cls=MontyEncoder)} response = self.session.post(f"{self.preamble}/snl/submit", data=payload) if response.status_code in [200, 400]: - resp = json.loads(response.text, cls=MontyDecoder) - if resp["valid_response"]: - if resp.get("warning"): - warnings.warn(resp["warning"]) - return resp["inserted_ids"] - raise MPRestError(resp["error"]) + response = json.loads(response.text, cls=MontyDecoder) + if response["valid_response"]: + if response.get("warning"): + warnings.warn(response["warning"]) + return response["inserted_ids"] + raise MPRestError(response["error"]) raise MPRestError(f"REST error with status code {response.status_code} and error {response.text}") @@ -1126,12 +1129,12 @@ def delete_snl(self, snl_ids): response = self.session.post(f"{self.preamble}/snl/delete", data=payload) if response.status_code in [200, 400]: - resp = json.loads(response.text, cls=MontyDecoder) - if resp["valid_response"]: - if resp.get("warning"): - warnings.warn(resp["warning"]) - return resp - raise MPRestError(resp["error"]) + response = json.loads(response.text, cls=MontyDecoder) + if response["valid_response"]: + if response.get("warning"): + warnings.warn(response["warning"]) + return response + raise MPRestError(response["error"]) raise MPRestError(f"REST error with status code {response.status_code} and error {response.text}") @@ -1154,12 +1157,12 @@ def query_snl(self, criteria): payload = {"criteria": json.dumps(criteria)} response = self.session.post(f"{self.preamble}/snl/query", data=payload) if response.status_code in [200, 400]: - resp = json.loads(response.text) - if resp["valid_response"]: - if resp.get("warning"): - warnings.warn(resp["warning"]) - return resp["response"] - raise MPRestError(resp["error"]) + response = json.loads(response.text) + if response["valid_response"]: + if response.get("warning"): + warnings.warn(response["warning"]) + return response["response"] + raise MPRestError(response["error"]) raise MPRestError(f"REST error with status code {response.status_code} and error {response.text}") @@ -1217,7 +1220,7 @@ def submit_vasp_directory( histories = [] for e in queen.get_data(): structures.append(e.structure) - m = { + meta_dict = { "_vasp": { "parameters": e.parameters, "final_energy": e.energy, @@ -1228,8 +1231,8 @@ def submit_vasp_directory( if "history" in e.parameters: histories.append(e.parameters["history"]) if master_data is not None: - m.update(master_data) - metadata.append(m) + meta_dict.update(master_data) + metadata.append(meta_dict) if master_history is not None: histories = master_history * len(structures) @@ -1252,12 +1255,12 @@ def get_stability(self, entries): data=payload, ) if response.status_code in [200, 400]: - resp = json.loads(response.text, cls=MontyDecoder) - if resp["valid_response"]: - if resp.get("warning"): - warnings.warn(resp["warning"]) - return resp["response"] - raise MPRestError(resp["error"]) + response = json.loads(response.text, cls=MontyDecoder) + if response["valid_response"]: + if response.get("warning"): + warnings.warn(response["warning"]) + return response["response"] + raise MPRestError(response["error"]) raise MPRestError(f"REST error with status code {response.status_code} and error {response.text}") def get_cohesive_energy(self, material_id, per_atom=False): @@ -1601,9 +1604,9 @@ def parse_criteria(criteria_string): def parse_sym(sym): if sym == "*": return [el.symbol for el in Element] - m = re.match(r"\{(.*)\}", sym) - if m: - return [s.strip() for s in m.group(1).split(",")] + + if match := re.match(r"\{(.*)\}", sym): + return [s.strip() for s in match.group(1).split(",")] return [sym] def parse_tok(t): @@ -1625,13 +1628,13 @@ def parse_tok(t): if ("*" in sym) or ("{" in sym): wild_card_els.append(sym) else: - m = re.match(r"([A-Z][a-z]*)[\.\d]*", sym) - explicit_els.append(m.group(1)) + match = re.match(r"([A-Z][a-z]*)[\.\d]*", sym) + explicit_els.append(match[1]) n_elements = len(wild_card_els) + len(set(explicit_els)) parts = re.split(r"(\*|\{.*\})", t) parts = [parse_sym(s) for s in parts if s != ""] - for f in itertools.product(*parts): - comp = Composition("".join(f)) + for formula in itertools.product(*parts): + comp = Composition("".join(formula)) if len(comp) == n_elements: # Check for valid Elements in keys. for elem in comp: @@ -1657,5 +1660,5 @@ def get_chunks(sequence: Sequence[Any], size=1): Returns: list[Sequence[Any]]: input sequence in chunks of length size. """ - chunks = int(math.ceil(len(sequence) / float(size))) + chunks = math.ceil(len(sequence) / float(size)) return [sequence[i * size : (i + 1) * size] for i in range(chunks)] diff --git a/pymatgen/ext/optimade.py b/pymatgen/ext/optimade.py index 5c4de8f6218..92eecf715e6 100644 --- a/pymatgen/ext/optimade.py +++ b/pymatgen/ext/optimade.py @@ -5,6 +5,7 @@ import logging import sys from collections import namedtuple +from typing import TYPE_CHECKING from urllib.parse import urljoin, urlparse import requests @@ -14,6 +15,9 @@ from pymatgen.util.due import Doi, due from pymatgen.util.provenance import StructureNL +if TYPE_CHECKING: + from typing_extensions import Self + # TODO: importing optimade-python-tool's data structures will make more sense Provider = namedtuple("Provider", ["name", "base_url", "description", "homepage", "prefix"]) @@ -48,7 +52,9 @@ class OptimadeRester: # these aliases are provided as a convenient shortcut for users of the OptimadeRester class aliases = { "aflow": "http://aflow.org/API/optimade/", - "alexandria": "https://alexandria.odbx.science", + "alexandria": "https://alexandria.icams.rub.de/pbe", + "alexandria.pbe": "https://alexandria.icams.rub.de/pbe", + "alexandria.pbesol": "https://alexandria.icams.rub.de/pbesol", "cod": "https://www.crystallography.net/cod/optimade", "cmr": "https://cmr-optimade.fysik.dtu.dk", "mcloud.mc3d": "https://aiida.materialscloud.org/mc3d/optimade", @@ -67,6 +73,7 @@ class OptimadeRester: "nmd": "https://nomad-lab.eu/prod/rae/optimade/", "odbx": "https://optimade.odbx.science", "odbx.odbx_misc": "https://optimade-misc.odbx.science", + "odbx.gnome": "https://optimade-gnome.odbx.science", "omdb.omdb_production": "http://optimade.openmaterialsdb.se", "oqmd": "http://oqmd.org/optimade/", "jarvis": "https://jarvis.nist.gov/optimade/jarvisdft", @@ -126,6 +133,11 @@ def __init__( # and values as the corresponding URL self.resources = {} + # preprocess aliases to ensure they have a trailing slash where appropriate + for alias, url in self.aliases.items(): + if urlparse(url).path is not None and not url.endswith("/"): + self.aliases[alias] += "/" + if not aliases_or_resource_urls: aliases_or_resource_urls = list(self.aliases) _logger.warning( @@ -435,6 +447,10 @@ def _validate_provider(self, provider_url) -> Provider | None: TODO: careful reading of OPTIMADE specification required TODO: add better exception handling, intentionally permissive currently """ + # Add trailing slash to all URLs if missing; prevents urljoin from scrubbing + # sections of the path + if urlparse(provider_url).path is not None and not provider_url.endswith("/"): + provider_url += "/" def is_url(url) -> bool: """Basic URL validation thanks to https://stackoverflow.com/a/52455972.""" @@ -448,6 +464,8 @@ def is_url(url) -> bool: _logger.warning(f"An invalid url was supplied: {provider_url}") return None + url = None + try: url = urljoin(provider_url, "v1/info") provider_info_json = self._get_json(url) @@ -485,6 +503,12 @@ def _parse_provider(self, provider: str, provider_url: str) -> dict[str, Provide A dictionary of keys (in format of "provider.database") to Provider objects. """ + # Add trailing slash to all URLs if missing; prevents urljoin from scrubbing + if urlparse(provider_url).path is not None and not provider_url.endswith("/"): + provider_url += "/" + + url = None + try: url = urljoin(provider_url, "v1/links") provider_link_json = self._get_json(url) @@ -545,8 +569,13 @@ def refresh_aliases(self, providers_url="https://providers.optimade.org/provider self.aliases = {alias: provider.base_url for alias, provider in structure_providers.items()} + # Add missing trailing slashes to any aliases with a path that need them + for alias, url in self.aliases.items(): + if urlparse(url).path is not None and not url.endswith("/"): + self.aliases[alias] += "/" + # TODO: revisit context manager logic here and in MPRester - def __enter__(self): + def __enter__(self) -> Self: """Support for "with" context.""" return self diff --git a/pymatgen/io/abinit/abiobjects.py b/pymatgen/io/abinit/abiobjects.py index 1ddfcb90fcb..0655233da21 100644 --- a/pymatgen/io/abinit/abiobjects.py +++ b/pymatgen/io/abinit/abiobjects.py @@ -9,14 +9,17 @@ from collections import namedtuple from collections.abc import Iterable -from enum import Enum +from enum import Enum, unique from pprint import pformat -from typing import Iterable, cast, Any +from typing import Iterable, cast, Any, TYPE_CHECKING from monty.collections import AttrDict from monty.design_patterns import singleton from monty.json import MontyDecoder, MontyEncoder, MSONable from pymatgen.core import ArrayWithUnit, Lattice, Species, Structure, units +if TYPE_CHECKING: + from typing_extensions import Self + def lattice_from_abivars(cls=None, *args, **kwargs): """ @@ -380,9 +383,9 @@ def as_dict(self) -> dict: return out @classmethod - def from_dict(cls, d: dict) -> SpinMode: + def from_dict(cls, dct: dict) -> Self: """Build object from dict.""" - return cls(**{k: d[k] for k in d if k in cls._fields}) + return cls(**{key: dct[key] for key in dct if key in cls._fields}) # An handy Multiton @@ -494,10 +497,10 @@ def as_dict(self) -> dict: "tsmear": self.tsmear, } - @staticmethod - def from_dict(d: dict) -> Smearing: + @classmethod + def from_dict(cls, dct: dict) -> Self: """Build object from dict.""" - return Smearing(d["occopt"], d["tsmear"]) + return cls(dct["occopt"], dct["tsmear"]) class ElectronsAlgorithm(dict, AbivarAble, MSONable): @@ -534,12 +537,12 @@ def as_dict(self) -> dict: return {"@module": type(self).__module__, "@class": type(self).__name__, **self.copy()} @classmethod - def from_dict(cls, d: dict) -> ElectronsAlgorithm: + def from_dict(cls, dct: dict) -> Self: """Build object from dict.""" - d = d.copy() - d.pop("@module", None) - d.pop("@class", None) - return cls(**d) + dct = dct.copy() + dct.pop("@module", None) + dct.pop("@class", None) + return cls(**dct) class Electrons(AbivarAble, MSONable): @@ -602,15 +605,14 @@ def as_dict(self) -> dict: return dct @classmethod - def from_dict(cls, d: dict) -> Electrons: + def from_dict(cls, dct: dict) -> Self: """Build object from dictionary.""" dct = dct.copy() dct.pop("@module", None) dct.pop("@class", None) - dec = MontyDecoder() - dct["spin_mode"] = dec.process_decoded(dct["spin_mode"]) - dct["smearing"] = dec.process_decoded(dct["smearing"]) - dct["algorithm"] = dec.process_decoded(dct["algorithm"]) if dct["algorithm"] else None + dct["spin_mode"] = MontyDecoder().process_decoded(dct["spin_mode"]) + dct["smearing"] = MontyDecoder().process_decoded(dct["smearing"]) + dct["algorithm"] = MontyDecoder().process_decoded(dct["algorithm"]) if dct["algorithm"] else None return cls(**dct) def to_abivars(self) -> dict: @@ -635,6 +637,7 @@ def to_abivars(self) -> dict: return abivars +@unique class KSamplingModes(Enum): """Enum if the different samplings of the BZ.""" @@ -718,11 +721,11 @@ def __init__( if use_symmetries and use_time_reversal: kptopt = 1 - if not use_symmetries and use_time_reversal: + elif not use_symmetries and use_time_reversal: kptopt = 2 - if not use_symmetries and not use_time_reversal: + elif not use_symmetries and not use_time_reversal: kptopt = 3 - if use_symmetries and not use_time_reversal: + else: # use_symmetries and not use_time_reversal kptopt = 4 abivars.update( @@ -994,14 +997,13 @@ def as_dict(self) -> dict: } @classmethod - def from_dict(cls, d: dict) -> KSampling: + def from_dict(cls, dct: dict) -> Self: """Build object from dict.""" - d = d.copy() - d.pop("@module", None) - d.pop("@class", None) - dec = MontyDecoder() - d["kpts"] = dec.process_decoded(d["kpts"]) - return cls(**d) + dct = dct.copy() + dct.pop("@module", None) + dct.pop("@class", None) + dct["kpts"] = MontyDecoder().process_decoded(dct["kpts"]) + return cls(**dct) class Constraints(AbivarAble): @@ -1124,15 +1126,16 @@ def as_dict(self): return dct @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """Build object from dictionary.""" - d = d.copy() - d.pop("@module", None) - d.pop("@class", None) + dct = dct.copy() + dct.pop("@module", None) + dct.pop("@class", None) - return cls(**d) + return cls(**dct) +@unique class PPModelModes(Enum): """Different kind of plasmon-pole models.""" @@ -1226,10 +1229,10 @@ def as_dict(self) -> dict: "@class": type(self).__name__, } - @staticmethod - def from_dict(d: dict) -> PPModel: + @classmethod + def from_dict(cls, dct: dict) -> Self: """Build object from dictionary.""" - return PPModel(mode=d["mode"], plasmon_freq=d["plasmon_freq"]) + return cls(mode=dct["mode"], plasmon_freq=dct["plasmon_freq"]) class HilbertTransform(AbivarAble): diff --git a/pymatgen/io/abinit/abitimer.py b/pymatgen/io/abinit/abitimer.py index ef20673af3c..b8f94f2e7ec 100644 --- a/pymatgen/io/abinit/abitimer.py +++ b/pymatgen/io/abinit/abitimer.py @@ -68,20 +68,21 @@ def walk(cls, top=".", ext=".abo"): Scan directory tree starting from top, look for files with extension `ext` and parse timing data. - Return: (parser, paths, okfiles) - where `parser` is the new object, `paths` is the list of files found and `okfiles` - is the list of files that have been parsed successfully. - (okfiles == paths) if all files have been parsed. + Returns: + parser: the new object + paths: the list of files found + ok_files: list of files that have been parsed successfully. + (ok_files == paths) if all files have been parsed. """ paths = [] for root, _dirs, files in os.walk(top): - for f in files: - if f.endswith(ext): - paths.append(os.path.join(root, f)) + for file in files: + if file.endswith(ext): + paths.append(os.path.join(root, file)) parser = cls() - okfiles = parser.parse(paths) - return parser, paths, okfiles + ok_files = parser.parse(paths) + return parser, paths, ok_files def __init__(self): """Initialize object.""" @@ -108,7 +109,8 @@ def parse(self, filenames) -> list[str]: Read and parse a filename or a list of filenames. Files that cannot be opened are ignored. A single filename may also be given. - Return: list of successfully read files. + Returns: + list of successfully read files. """ if isinstance(filenames, str): filenames = [filenames] @@ -255,7 +257,8 @@ def pefficiency(self) -> ParallelEfficiency: """ Analyze the parallel efficiency. - Return: ParallelEfficiency object. + Returns: + ParallelEfficiency object. """ timers = self.timers() @@ -741,12 +744,12 @@ def names_and_values(self, key, minval=None, minfract=None, sorted=True): if minval is not None: assert minfract is None - for n, v in zip(names, values): - if v >= minval: - new_names.append(n) - new_values.append(v) + for name, val in zip(names, values): + if val >= minval: + new_names.append(name) + new_values.append(val) else: - other_val += v + other_val += val new_names.append(f"below minval {minval}") new_values.append(other_val) @@ -756,12 +759,12 @@ def names_and_values(self, key, minval=None, minfract=None, sorted=True): total = self.sum_sections(key) - for n, v in zip(names, values): - if v / total >= minfract: - new_names.append(n) - new_values.append(v) + for name, val in zip(names, values): + if val / total >= minfract: + new_names.append(name) + new_values.append(val) else: - other_val += v + other_val += val new_names.append(f"below minfract {minfract}") new_values.append(other_val) @@ -802,8 +805,7 @@ def cpuwall_histogram(self, ax: plt.Axes = None, **kwargs): """ ax, fig = get_ax_fig(ax=ax) - nk = len(self.sections) - ind = np.arange(nk) # the x locations for the groups + ind = np.arange(len(self.sections)) # the x locations for the groups width = 0.35 # the width of the bars cpu_times = self.get_values("cpu_time") diff --git a/pymatgen/io/abinit/inputs.py b/pymatgen/io/abinit/inputs.py index be48629cb11..ba1150d40ac 100644 --- a/pymatgen/io/abinit/inputs.py +++ b/pymatgen/io/abinit/inputs.py @@ -13,7 +13,8 @@ import os from collections import namedtuple from collections.abc import Mapping, MutableMapping, Sequence -from enum import Enum +from enum import Enum, unique +from typing import TYPE_CHECKING import numpy as np from monty.collections import AttrDict @@ -25,6 +26,9 @@ from pymatgen.io.abinit.variable import InputVariable from pymatgen.symmetry.bandstructure import HighSymmKpath +if TYPE_CHECKING: + from typing_extensions import Self + logger = logging.getLogger(__file__) @@ -124,6 +128,7 @@ def as_structure(obj): raise TypeError(f"Don't know how to convert {type(obj)} into a structure") +@unique class ShiftMode(Enum): """ Class defining the mode to be used for the shifts. @@ -140,7 +145,7 @@ class ShiftMode(Enum): OneSymmetric = "O" @classmethod - def from_object(cls, obj): + def from_object(cls, obj) -> Self: """ Returns an instance of ShiftMode based on the type of object passed. Converts strings to ShiftMode depending on the initial letter of the string. G for GammaCentered, M for MonkhorstPack, @@ -174,6 +179,7 @@ def _stopping_criterion(run_level, accuracy): def _find_ecut_pawecutdg(ecut, pawecutdg, pseudos, accuracy): """Return a |AttrDict| with the value of ecut and pawecutdg.""" # Get ecut and pawecutdg from the pseudo hints. + has_hints = False if ecut is None or (pawecutdg is None and any(p.ispaw for p in pseudos)): has_hints = all(p.has_hints for p in pseudos) @@ -481,33 +487,33 @@ def calc_shiftk(structure, symprec: float = 0.01, angle_tolerance=5): This is often the preferred k point sampling. For a non-shifted Monkhorst-Pack grid, use `nshiftk=1` and `shiftk 0.0 0.0 0.0`, but there is little reason to do that. - When the primitive vectors of the lattice form a FCC lattice, with rprim:: + When the primitive vectors of the lattice form a FCC lattice, with rprim: 0.0 0.5 0.5 0.5 0.0 0.5 0.5 0.5 0.0 - the (very efficient) usual Monkhorst-Pack sampling will be generated by using nshiftk= 4 and shiftk:: + the (very efficient) usual Monkhorst-Pack sampling will be generated by using nshiftk= 4 and shiftk: 0.5 0.5 0.5 0.5 0.0 0.0 0.0 0.5 0.0 0.0 0.0 0.5 - When the primitive vectors of the lattice form a BCC lattice, with rprim:: + When the primitive vectors of the lattice form a BCC lattice, with rprim: -0.5 0.5 0.5 0.5 -0.5 0.5 0.5 0.5 -0.5 - the usual Monkhorst-Pack sampling will be generated by using nshiftk= 2 and shiftk:: + the usual Monkhorst-Pack sampling will be generated by using nshiftk= 2 and shiftk: 0.25 0.25 0.25 -0.25 -0.25 -0.25 However, the simple sampling nshiftk=1 and shiftk 0.5 0.5 0.5 is excellent. - For hexagonal lattices with hexagonal axes, e.g. rprim:: + For hexagonal lattices with hexagonal axes, e.g. rprim: 1.0 0.0 0.0 -0.5 sqrt(3)/2 0.0 @@ -545,10 +551,10 @@ def calc_shiftk(structure, symprec: float = 0.01, angle_tolerance=5): elif lattice_type == "hexagonal": # Find the hexagonal axis and set the shift along it. - for i, angle in enumerate(structure.lattice.angles): + for i, angle in enumerate(structure.lattice.angles, start=1): if abs(angle - 120) < 1.0: - j = (i + 1) % 3 - k = (i + 2) % 3 + j = i % 3 + k = (i + 1) % 3 hex_ax = next(ax for ax in range(3) if ax not in [j, k]) break else: @@ -786,10 +792,10 @@ def vars(self): return self._vars @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """JSON interface used in pymatgen for easier serialization.""" - pseudos = [Pseudo.from_file(p["filepath"]) for p in d["pseudos"]] - return cls(d["structure"], pseudos, comment=d["comment"], abi_args=d["abi_args"]) + pseudos = [Pseudo.from_file(p["filepath"]) for p in dct["pseudos"]] + return cls(dct["structure"], pseudos, comment=dct["comment"], abi_args=dct["abi_args"]) def add_abiobjects(self, *abi_objects): """ @@ -837,10 +843,9 @@ def to_str(self, post=None, with_structure=True, with_pseudos=True, exclude=None exclude: List of variable names that should be ignored. """ lines = [] - app = lines.append if self.comment: - app("# " + self.comment.replace("\n", "\n#")) + lines.append("# " + self.comment.replace("\n", "\n#")) post = post if post is not None else "" exclude = set(exclude) if exclude is not None else set() @@ -856,7 +861,7 @@ def to_str(self, post=None, with_structure=True, with_pseudos=True, exclude=None for name, value in items: # Build variable, convert to string and append it vname = name + post - app(str(InputVariable(vname, value))) + lines.append(str(InputVariable(vname, value))) out = "\n".join(lines) if not with_pseudos: @@ -1035,8 +1040,7 @@ class BasicMultiDataset: multi.get("paral_kgb", 0) - .. warning:: - + Warning: BasicMultiDataset does not support calculations done with different sets of pseudopotentials. The inputs can have different crystalline structures (as long as the atom types are equal) but each input in BasicMultiDataset must have the same set of pseudopotentials. @@ -1044,37 +1048,7 @@ class BasicMultiDataset: Error = BasicAbinitInputError - @classmethod - def from_inputs(cls, inputs): - """Build object from a list of BasicAbinitInput objects.""" - for inp in inputs: - if any(p1 != p2 for p1, p2 in zip(inputs[0].pseudos, inp.pseudos)): - raise ValueError("Pseudos must be consistent when from_inputs is invoked.") - - # Build BasicMultiDataset from input structures and pseudos and add inputs. - multi = cls( - structure=[inp.structure for inp in inputs], - pseudos=inputs[0].pseudos, - ndtset=len(inputs), - ) - - # Add variables - for inp, new_inp in zip(inputs, multi): - new_inp.set_vars(**inp) - - return multi - - @classmethod - def replicate_input(cls, input, ndtset): - """Construct a multidataset with ndtset from the BasicAbinitInput input.""" - multi = cls(input.structure, input.pseudos, ndtset=ndtset) - - for inp in multi: - inp.set_vars(**input) - - return multi - - def __init__(self, structure: Structure, pseudos, pseudo_dir="", ndtset=1): + def __init__(self, structure: Structure | Sequence[Structure], pseudos, pseudo_dir="", ndtset=1): """ Args: structure: file with the structure, |Structure| object or dictionary with ABINIT geo variable @@ -1114,6 +1088,36 @@ def __init__(self, structure: Structure, pseudos, pseudo_dir="", ndtset=1): assert len(structure) == ndtset self._inputs = [BasicAbinitInput(structure=s, pseudos=pseudos) for s in structure] + @classmethod + def from_inputs(cls, inputs: list[BasicAbinitInput]) -> Self: + """Build object from a list of BasicAbinitInput objects.""" + for inp in inputs: + if any(p1 != p2 for p1, p2 in zip(inputs[0].pseudos, inp.pseudos)): + raise ValueError("Pseudos must be consistent when from_inputs is invoked.") + + # Build BasicMultiDataset from input structures and pseudos and add inputs. + multi = cls( + structure=[inp.structure for inp in inputs], + pseudos=inputs[0].pseudos, + ndtset=len(inputs), + ) + + # Add variables + for inp, new_inp in zip(inputs, multi): + new_inp.set_vars(**inp) + + return multi + + @classmethod + def replicate_input(cls, input, ndtset): + """Construct a multidataset with ndtset from the BasicAbinitInput input.""" + multi = cls(input.structure, input.pseudos, ndtset=ndtset) + + for inp in multi: + inp.set_vars(**input) + + return multi + @property def ndtset(self): """Number of inputs in self.""" diff --git a/pymatgen/io/abinit/netcdf.py b/pymatgen/io/abinit/netcdf.py index 57fea064f7f..b2c7c6da52a 100644 --- a/pymatgen/io/abinit/netcdf.py +++ b/pymatgen/io/abinit/netcdf.py @@ -7,6 +7,7 @@ import logging import os.path import warnings +from typing import TYPE_CHECKING import numpy as np from monty.collections import AttrDict @@ -18,6 +19,9 @@ #from pymatgen.core.xcfunc import XcFunc #from pymatgen.core.structure import Structure +if TYPE_CHECKING: + from typing_extensions import Self + try: import netCDF4 except ImportError: @@ -93,7 +97,7 @@ def __init__(self, path: str): # See also https://github.com/Unidata/netcdf4-python/issues/785 self.rootgrp.set_auto_mask(False) - def __enter__(self): + def __enter__(self) -> Self: """Activated when used in the with statement.""" return self diff --git a/pymatgen/io/abinit/pseudos.py b/pymatgen/io/abinit/pseudos.py index 052f81d97a1..a5119a1ad7f 100644 --- a/pymatgen/io/abinit/pseudos.py +++ b/pymatgen/io/abinit/pseudos.py @@ -36,6 +36,7 @@ if TYPE_CHECKING: from collections.abc import Iterator, Sequence import matplotlib.pyplot as plt + from typing_extensions import Self from pymatgen.core import Structure logger = logging.getLogger(__name__) @@ -61,11 +62,11 @@ def _read_nlines(filename: str, n_lines: int) -> list[str]: If n_lines is < 0, the entire file is read. """ if n_lines < 0: - with open(filename) as file: + with open(filename, encoding="utf-8") as file: return file.readlines() lines = [] - with open(filename) as file: + with open(filename, encoding="utf-8") as file: for lineno, line in enumerate(file): if lineno == n_lines: break @@ -107,8 +108,8 @@ def as_pseudo(cls, obj: Union[Pseudo, str]): """ return obj if isinstance(obj, cls) else cls.from_file(obj) - @staticmethod - def from_file(filename: str) -> Pseudo: + @classmethod + def from_file(cls, filename: str) -> Self: """ Build an instance of a concrete Pseudo subclass from filename. Note: the parser knows the concrete class that should be instantiated @@ -139,22 +140,21 @@ def to_str(self, verbose=0) -> str: """String representation.""" lines: list[str] = [] - app = lines.append - app(f"<{type(self).__name__}: {self.basename}>") - app(" summary: " + self.summary.strip()) - app(f" number of valence electrons: {self.Z_val}") - app(f" maximum angular momentum: {l2str(self.l_max)}") - app(f" angular momentum for local part: {l2str(self.l_local)}") - app(f" XC correlation: {self.xc}") - app(f" supports spin-orbit: {self.supports_soc}") + lines.append(f"<{type(self).__name__}: {self.basename}>") + lines.append(" summary: " + self.summary.strip()) + lines.append(f" number of valence electrons: {self.Z_val}") + lines.append(f" maximum angular momentum: {l2str(self.l_max)}") + lines.append(f" angular momentum for local part: {l2str(self.l_local)}") + lines.append(f" XC correlation: {self.xc}") + lines.append(f" supports spin-orbit: {self.supports_soc}") if self.isnc: - app(f" radius for non-linear core correction: {self.nlcc_radius}") + lines.append(f" radius for non-linear core correction: {self.nlcc_radius}") if self.has_hints: for accuracy in ("low", "normal", "high"): hint = self.hint_for_accuracy(accuracy=accuracy) - app(f" hint for {accuracy} accuracy: {hint}") + lines.append(f" hint for {accuracy} accuracy: {hint}") return "\n".join(lines) @@ -231,7 +231,7 @@ def md5(self): def compute_md5(self): """Compute and return MD5 hash value.""" - with open(self.path) as file: + with open(self.path, encoding="utf-8") as file: text = file.read() # usedforsecurity=False needed in FIPS mode (Federal Information Processing Standards) # https://github.com/materialsproject/pymatgen/issues/2804 @@ -264,7 +264,7 @@ def as_dict(self, **kwargs) -> dict: } @classmethod - def from_dict(cls, dct: dict): + def from_dict(cls, dct: dict) -> Self: """Build instance from dictionary (MSONable protocol).""" new = cls.from_file(dct["filepath"]) @@ -291,7 +291,7 @@ def as_tmpfile(self, tmpdir=None): # Copy dojo report file if present. root, _ext = os.path.splitext(self.filepath) - dj_report = root + ".djrepo" + dj_report = f"{root}.djrepo" if os.path.isfile(dj_report): shutil.copy(dj_report, os.path.join(tmpdir, os.path.basename(dj_report))) @@ -313,7 +313,7 @@ def djrepo_path(self) -> str: """The path of the djrepo file. None if file does not exist.""" root, _ext = os.path.splitext(self.filepath) - return root + ".djrepo" + return f"{root}.djrepo" # if os.path.isfile(path): return path # return None @@ -477,7 +477,7 @@ def __init__(self, path: str, header): value = header.get(attr_name) # Hide these attributes since one should always use the public interface. - setattr(self, "_" + attr_name, value) + setattr(self, f"_{attr_name}", value) @property def summary(self) -> str: @@ -587,9 +587,9 @@ def as_dict(self) -> dict: } @classmethod - def from_dict(cls, d: dict): + def from_dict(cls, dct: dict) -> Self: """Build instance from dictionary (MSONable protocol).""" - return cls(**{k: v for k, v in d.items() if not k.startswith("@")}) + return cls(**{k: v for k, v in dct.items() if not k.startswith("@")}) def _dict_from_lines(lines: list[str], key_nums: list[int], sep=None): @@ -828,11 +828,12 @@ def tm_header(filename, ppdesc) -> NcAbinitHeader: """ lines = _read_nlines(filename, -1) header = [] + lmax = None for lineno, line in enumerate(lines): header.append(line) if lineno == 2: - # Read lmax. + # Read lmax tokens = line.split() _pspcod, _pspxc, lmax, _lloc = map(int, tokens[:4]) _mmax, _r2well = map(float, tokens[4:6]) @@ -847,13 +848,15 @@ def tm_header(filename, ppdesc) -> NcAbinitHeader: # 0 4.085 6.246 0 2.8786493 l,e99.0,e99.9,nproj,rcpsp # .00000000 .0000000000 .0000000000 .00000000 rms,ekb1,ekb2,epsatm projectors = {} + proj_info = [] + idx = None for idx in range(2 * (lmax + 1)): line = lines[idx] if idx % 2 == 0: proj_info = [ line, ] - if idx % 2 == 1: + else: proj_info.append(line) d = _dict_from_lines(proj_info, [5, 4]) projectors[int(d["l"])] = d @@ -1002,8 +1005,7 @@ class PseudoParser: """ Responsible for parsing pseudopotential files and returning pseudopotential objects. - Usage:: - + Usage: pseudo = PseudoParser().parse("filename") """ @@ -1054,7 +1056,7 @@ def scan_directory(self, dirname: str, exclude_exts=(), exclude_fnames=()): """ for i, ext in enumerate(exclude_exts): if not ext.strip().startswith("."): - exclude_exts[i] = "." + ext.strip() + exclude_exts[i] = f".{ext.strip()}" # Exclude files depending on the extension. paths = [] @@ -1097,8 +1099,8 @@ def read_ppdesc(self, filename: str): # Assume file with the abinit header. lines = _read_nlines(filename, 80) - for lineno, line in enumerate(lines): - if lineno == 2: + for lineno, line in enumerate(lines, start=1): + if lineno == 3: try: tokens = line.split() pspcod, _pspxc = map(int, tokens[:2]) @@ -1114,7 +1116,7 @@ def read_ppdesc(self, filename: str): if pspcod == 7: # PAW -> need to know the format pspfmt - tokens = lines[lineno + 1].split() + tokens = lines[lineno].split() pspfmt, _creatorID = tokens[:2] ppdesc = ppdesc._replace(format=pspfmt) @@ -1450,10 +1452,10 @@ def plot_waves(self, ax: plt.Axes = None, fontsize=8, **kwargs): # ax.annotate("$r_c$", xy=(self.paw_radius + 0.1, 0.1)) for state, rfunc in self.pseudo_partial_waves.items(): - ax.plot(rfunc.mesh, rfunc.mesh * rfunc.values, lw=2, label="PS-WAVE: " + state) # noqa: PD011 + ax.plot(rfunc.mesh, rfunc.mesh * rfunc.values, lw=2, label=f"PS-WAVE: {state}") # noqa: PD011 for state, rfunc in self.ae_partial_waves.items(): - ax.plot(rfunc.mesh, rfunc.mesh * rfunc.values, lw=2, label="AE-WAVE: " + state) # noqa: PD011 + ax.plot(rfunc.mesh, rfunc.mesh * rfunc.values, lw=2, label=f"AE-WAVE: {state}") # noqa: PD011 ax.legend(loc="best", shadow=True, fontsize=fontsize) @@ -1480,7 +1482,7 @@ def plot_projectors(self, ax: plt.Axes = None, fontsize=8, **kwargs): # ax.annotate("$r_c$", xy=(self.paw_radius + 0.1, 0.1)) for state, rfunc in self.projector_functions.items(): - ax.plot(rfunc.mesh, rfunc.mesh * rfunc.values, label="TPROJ: " + state) # noqa: PD011 + ax.plot(rfunc.mesh, rfunc.mesh * rfunc.values, label=f"TPROJ: {state}") # noqa: PD011 ax.legend(loc="best", shadow=True, fontsize=fontsize) @@ -1874,12 +1876,10 @@ class PseudoTable(collections.abc.Sequence, MSONable): @classmethod def as_table(cls, items) -> PseudoTable: """Return an instance of PseudoTable from the iterable items.""" - if isinstance(items, cls): - return items - return cls(items) + return items if isinstance(items, cls) else cls(items) @classmethod - def from_dir(cls, top: str, exts=None, exclude_dirs="_*") -> PseudoTable: + def from_dir(cls, top, exts=None, exclude_dirs="_*") -> Self | None: """ Find all pseudos in the directory tree starting from top. @@ -1889,21 +1889,22 @@ def from_dir(cls, top: str, exts=None, exclude_dirs="_*") -> PseudoTable: we try to open all files in top exclude_dirs: Wildcard used to exclude directories. - return: PseudoTable sorted by atomic number Z. + Returns: + PseudoTable sorted by atomic number Z. """ pseudos = [] if exts == "all_files": - for f in [os.path.join(top, fn) for fn in os.listdir(top)]: - if os.path.isfile(f): + for filepath in [os.path.join(top, fn) for fn in os.listdir(top)]: + if os.path.isfile(filepath): try: - p = Pseudo.from_file(f) - if p: - pseudos.append(p) + pseudo = Pseudo.from_file(filepath) + if pseudo: + pseudos.append(pseudo) else: - logger.info(f"Skipping file {f}") + logger.info(f"Skipping file {filepath}") except Exception: - logger.info(f"Skipping file {f}") + logger.info(f"Skipping file {filepath}") if not pseudos: logger.warning(f"No pseudopotentials parsed from folder {top}") return None @@ -1913,11 +1914,11 @@ def from_dir(cls, top: str, exts=None, exclude_dirs="_*") -> PseudoTable: if exts is None: exts = ("psp8",) - for p in find_exts(top, exts, exclude_dirs=exclude_dirs): + for pseudo in find_exts(top, exts, exclude_dirs=exclude_dirs): try: - pseudos.append(Pseudo.from_file(p)) + pseudos.append(Pseudo.from_file(pseudo)) except Exception as exc: - logger.critical(f"Error in {p}:\n{exc}") + logger.critical(f"Error in {pseudo}:\n{exc}") return cls(pseudos).sort_by_z() @@ -2007,19 +2008,18 @@ def as_dict(self, **kwargs) -> dict: while k in dct: k += f"{k.split('#')[0]}#{count}" count += 1 - dct.update({k: p.as_dict()}) + dct[k] = p.as_dict() dct["@module"] = type(self).__module__ dct["@class"] = type(self).__name__ return dct @classmethod - def from_dict(cls, d: dict) -> PseudoTable: + def from_dict(cls, dct: dict) -> Self: """Build instance from dictionary (MSONable protocol).""" pseudos = [] - dec = MontyDecoder() - for k, v in d.items(): + for k, v in dct.items(): if not k.startswith("@"): - pseudos.append(dec.process_decoded(v)) + pseudos.append(MontyDecoder().process_decoded(v)) return cls(pseudos) def is_complete(self, zmax=118) -> bool: @@ -2032,8 +2032,7 @@ def all_combinations_for_elements(self, element_symbols): for the given list of element_symbols. Each item is a list of pseudopotential objects. - Example:: - + Example: table.all_combinations_for_elements(["Li", "F"]) """ dct = {} @@ -2060,9 +2059,7 @@ def pseudo_with_symbol(self, symbol: str, allow_multi=False) -> Pseudo: if not pseudos or (len(pseudos) > 1 and not allow_multi): raise ValueError(f"Found {len(pseudos)} occurrences of {symbol=}") - if not allow_multi: - return pseudos[0] - return pseudos + return pseudos if allow_multi else pseudos[0] def pseudos_with_symbols(self, symbols: list[str]) -> PseudoTable: """ @@ -2073,13 +2070,11 @@ def pseudos_with_symbols(self, symbols: list[str]) -> PseudoTable: """ pseudos = self.select_symbols(symbols, ret_list=True) found_symbols = [p.symbol for p in pseudos] - duplicated_elements = [s for s, o in collections.Counter(found_symbols).items() if o > 1] - if duplicated_elements: + if duplicated_elements := [s for s, o in collections.Counter(found_symbols).items() if o > 1]: raise ValueError(f"Found multiple occurrences of symbol(s) {', '.join(duplicated_elements)}") - missing_symbols = [s for s in symbols if s not in found_symbols] - if missing_symbols: + if missing_symbols := [s for s in symbols if s not in found_symbols]: raise ValueError(f"Missing data for symbol(s) {', '.join(missing_symbols)}") return pseudos @@ -2161,8 +2156,8 @@ def sorted(self, attrname, reverse=False) -> PseudoTable: """ Sort the table according to the value of attribute attrname. - Return: - New class:`PseudoTable` object + Returns: + New class: `PseudoTable` object """ attrs = [] for i, pseudo in self: @@ -2209,4 +2204,4 @@ def select_family(self, family: str) -> PseudoTable: Return PseudoTable with element belonging to the specified family, e.g. family="alkaline" """ # e.g element.is_alkaline - return type(self)([p for p in self if getattr(p.element, "is_" + family)]) + return type(self)([p for p in self if getattr(p.element, f"is_{family}")]) diff --git a/pymatgen/io/abinit/variable.py b/pymatgen/io/abinit/variable.py index 15bcdd3e229..57abf17082b 100644 --- a/pymatgen/io/abinit/variable.py +++ b/pymatgen/io/abinit/variable.py @@ -193,9 +193,9 @@ def format_list(self, values, float_decimal=0): line = "" # Format the line declaring the value - for i, val in enumerate(values): - line += " " + self.format_scalar(val, float_decimal) - if self.valperline is not None and (i + 1) % self.valperline == 0: + for i, val in enumerate(values, start=1): + line += f" {self.format_scalar(val, float_decimal)}" + if self.valperline is not None and i % self.valperline == 0: line += "\n" # Add a carriage return in case of several lines diff --git a/pymatgen/io/adf.py b/pymatgen/io/adf.py index b2e8f72a7a9..c16a912f442 100644 --- a/pymatgen/io/adf.py +++ b/pymatgen/io/adf.py @@ -16,6 +16,8 @@ if TYPE_CHECKING: from collections.abc import Generator + from typing_extensions import Self + __author__ = "Xin Chen, chenxin13@mails.tsinghua.edu.cn" @@ -83,21 +85,15 @@ def __init__(self, name, options=None, subkeys=None): """ Initialization method. - Parameters - ---------- - name : str - The name of this key. - options : Sized - The options for this key. Each element can be a primitive object or - a tuple/list with two elements: the first is the name and the second - is a primitive object. - subkeys : Sized - The subkeys for this key. + Args: + name (str): The name of this key. + options : Sized + The options for this key. Each element can be a primitive object or + a tuple/list with two elements: the first is the name and the second is a primitive object. + subkeys (Sized): The subkeys for this key. Raises: - ------ - ValueError - If elements in ``subkeys`` are not ``AdfKey`` objects. + ValueError: If elements in ``subkeys`` are not ``AdfKey`` objects. """ self.name = name self.options = options if options is not None else [] @@ -141,9 +137,8 @@ def __str__(self): Return the string representation of this ``AdfKey``. Notes: - ----- - If this key is 'Atoms' and the coordinates are in Cartesian form, a - different string format will be used. + If this key is 'Atoms' and the coordinates are in Cartesian form, + a different string format will be used. """ adf_str = f"{self.key}" if len(self.options) > 0: @@ -172,19 +167,15 @@ def __eq__(self, other: object) -> bool: return False return str(self) == str(other) - def has_subkey(self, subkey): + def has_subkey(self, subkey: str | AdfKey) -> bool: """ Return True if this AdfKey contains the given subkey. - Parameters - ---------- - subkey : str or AdfKey - A key name or an AdfKey object. + Args: + subkey (str or AdfKey): A key name or an AdfKey object. Returns: - ------- - has : bool - True if this key contains the given key. Otherwise False. + bool: Whether this key contains the given key. """ if isinstance(subkey, str): key = subkey @@ -200,14 +191,11 @@ def add_subkey(self, subkey): """ Add a new subkey to this key. - Parameters - ---------- - subkey : AdfKey - A new subkey. + Args: + subkey (AdfKey): A new subkey. Notes: - ----- - Duplicate check will not be performed if this is an 'Atoms' block. + Duplicate check will not be performed if this is an 'Atoms' block. """ if self.key.lower() == "atoms" or not self.has_subkey(subkey): self.subkeys.append(subkey) @@ -216,10 +204,8 @@ def remove_subkey(self, subkey): """ Remove the given subkey, if existed, from this AdfKey. - Parameters - ---------- - subkey : str or AdfKey - The subkey to remove. + Args: + subkey (str or AdfKey): The subkey to remove. """ if len(self.subkeys) > 0: key = subkey if isinstance(subkey, str) else subkey.key @@ -232,16 +218,13 @@ def add_option(self, option): """ Add a new option to this key. - Parameters - ---------- - option : Sized or str or int or float - A new option to add. This must have the same format with existing - options. + Args: + option : Sized or str or int or float + A new option to add. This must have the same format + with existing options. Raises: - ------ - TypeError - If the format of the given ``option`` is different. + TypeError: If the format of the given ``option`` is different. """ if len(self.options) == 0: self.options.append(option) @@ -251,46 +234,38 @@ def add_option(self, option): raise TypeError("Option type is mismatched!") self.options.append(option) - def remove_option(self, option): + def remove_option(self, option: str | int) -> None: """ Remove an option. - Parameters - ---------- - option : str or int - The name (str) or index (int) of the option to remove. + Args: + option (str | int): The name or index of the option to remove. Raises: - ------ - TypeError - If the option has a wrong type. + TypeError: If the option has a wrong type. """ if len(self.options) > 0: if self._sized_op: if not isinstance(option, str): raise TypeError("``option`` should be a name string!") - for i, v in enumerate(self.options): - if v[0] == option: - self.options.pop(i) + for idx, val in enumerate(self.options): + if val[0] == option: + self.options.pop(idx) break else: if not isinstance(option, int): raise TypeError("``option`` should be an integer index!") self.options.pop(option) - def has_option(self, option): + def has_option(self, option: str) -> bool: """ Return True if the option is included in this key. - Parameters - ---------- - option : str - The option. + Args: + option (str): The option. Returns: - ------- - has : bool - True if the option can be found. Otherwise False will be returned. + bool: Whether the option can be found. """ if len(self.options) == 0: return False @@ -312,50 +287,39 @@ def as_dict(self): return dct @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """ Construct a MSONable AdfKey object from the JSON dict. - Parameters - ---------- - d : dict - A dict of saved attributes. + Args: + dct (dict): A dict of saved attributes. Returns: - ------- - adfkey : AdfKey - An AdfKey object recovered from the JSON dict ``d``. + AdfKey: An AdfKey object recovered from the JSON dict. """ - key = d.get("name") - options = d.get("options") - subkey_list = d.get("subkeys", []) + key = dct.get("name") + options = dct.get("options") + subkey_list = dct.get("subkeys", []) subkeys = [AdfKey.from_dict(k) for k in subkey_list] or None return cls(key, options, subkeys) @classmethod - def from_str(cls, string: str) -> AdfKey: + def from_str(cls, string: str) -> Self: """ Construct an AdfKey object from the string. - Parameters - ---------- - string : str - A string. + Args: + string: str Returns: - ------- - adfkey : AdfKey An AdfKey object recovered from the string. Raises: - ------ - ValueError - Currently nested subkeys are not supported. If ``subend`` was found - a ValueError would be raised. + ValueError: Currently nested subkeys are not supported. + If ``subend`` was found a ValueError would be raised. Notes: - ----- - Only the first block key will be returned. + Only the first block key will be returned. """ def is_float(s) -> bool: @@ -377,7 +341,7 @@ def is_float(s) -> bool: if string.find("subend") != -1: raise ValueError("Nested subkeys are not supported!") - def iterlines(s: str) -> Generator[str, None, None]: + def iter_lines(s: str) -> Generator[str, None, None]: r"""A generator form of s.split('\n') for reducing memory overhead. Args: @@ -396,7 +360,7 @@ def iterlines(s: str) -> Generator[str, None, None]: prev_nl = next_nl key = None - for line in iterlines(string): + for line in iter_lines(string): if line == "": continue el = line.strip().split() @@ -412,7 +376,7 @@ def iterlines(s: str) -> Generator[str, None, None]: elif key is not None: key.add_subkey(cls.from_str(line)) - raise Exception("IncompleteKey: 'END' is missing!") + raise KeyError("Incomplete key: 'END' is missing!") class AdfTask(MSONable): @@ -420,9 +384,8 @@ class AdfTask(MSONable): Basic task for ADF. All settings in this class are independent of molecules. Notes: - ----- - Unlike other quantum chemistry packages (NWChem, Gaussian, ...), ADF does - not support calculating force/gradient. + Unlike other quantum chemistry packages (NWChem, Gaussian, ...), + ADF does not support calculating force/gradient. """ operations = dict( @@ -447,24 +410,15 @@ def __init__( """ Initialization method. - Parameters - ---------- - operation : str - The target operation. - basis_set : AdfKey - The basis set definitions for this task. Defaults to 'DZ/Large'. - xc : AdfKey - The exchange-correlation functionals. Defaults to PBE. - title : str - The title of this ADF task. - units : AdfKey - The units. Defaults to Angstroms/Degree. - geo_subkeys : Sized - The subkeys for the block key 'GEOMETRY'. - scf : AdfKey - The scf options. - other_directives : Sized - User-defined directives. + Args: + operation (str): The target operation. + basis_set (AdfKey): The basis set definitions for this task. Defaults to 'DZ/Large'. + xc (AdfKey): The exchange-correlation functionals. Defaults to PBE. + title (str): The title of this ADF task. + units (AdfKey): The units. Defaults to Angstroms/Degree. + geo_subkeys (Sized): The subkeys for the block key 'GEOMETRY'. + scf (AdfKey): The scf options. + other_directives (Sized): User-defined directives. """ if operation not in self.operations: raise AdfInputError(f"Invalid ADF task {operation}") @@ -506,15 +460,12 @@ def _setup_task(self, geo_subkeys): """ Setup the block 'Geometry' given subkeys and the task. - Parameters - ---------- - geo_subkeys : Sized - User-defined subkeys for the block 'Geometry'. + Args: + geo_subkeys (Sized): User-defined subkeys for the block 'Geometry'. Notes: - ----- - Most of the run types of ADF are specified in the Geometry block except - the 'AnalyticFreq'. + Most of the run types of ADF are specified in the Geometry + block except the 'AnalyticFreq'. """ self.geo = AdfKey("Geometry", subkeys=geo_subkeys) if self.operation.lower() == "energy": @@ -562,32 +513,28 @@ def as_dict(self): } @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """ Construct a MSONable AdfTask object from the JSON dict. - Parameters - ---------- - d : dict - A dict of saved attributes. + Args: + dct: A dict of saved attributes. Returns: - ------- - task : AdfTask An AdfTask object recovered from the JSON dict ``d``. """ def _from_dict(_d): return AdfKey.from_dict(_d) if _d is not None else None - operation = d.get("operation") - title = d.get("title") - basis_set = _from_dict(d.get("basis_set")) - xc = _from_dict(d.get("xc")) - units = _from_dict(d.get("units")) - scf = _from_dict(d.get("scf")) - others = [AdfKey.from_dict(o) for o in d.get("others", [])] - geo = _from_dict(d.get("geo")) + operation = dct.get("operation") + title = dct.get("title") + basis_set = _from_dict(dct.get("basis_set")) + xc = _from_dict(dct.get("xc")) + units = _from_dict(dct.get("units")) + scf = _from_dict(dct.get("scf")) + others = [AdfKey.from_dict(o) for o in dct.get("others", [])] + geo = _from_dict(dct.get("geo")) return cls(operation, basis_set, xc, title, units, geo.subkeys, scf, others) @@ -599,10 +546,8 @@ def __init__(self, task): """ Initialization method. - Parameters - ---------- - task : AdfTask - An ADF task. + Args: + task (AdfTask): An ADF task. """ self.task = task @@ -610,12 +555,9 @@ def write_file(self, molecule, inp_file): """ Write an ADF input file. - Parameters - ---------- - molecule : Molecule - The molecule for this task. - inpfile : str - The name where the input file will be saved. + Args: + molecule (Molecule): The molecule for this task. + inpfile (str): The name where the input file will be saved. """ mol_blocks = [] atom_block = AdfKey("Atoms", options=["cartesian"]) @@ -644,41 +586,27 @@ class AdfOutput: A basic ADF output file parser. Attributes: - ---------- - is_failed : bool - True is the ADF job is terminated without success. Otherwise False. - is_internal_crash : bool - True if the job is terminated with internal crash. Please read 'TAPE13' - of the ADF manual for more detail. - error : str - The error description. - run_type : str - The RunType of this ADF job. Possible options are: 'SinglePoint', - 'GeometryOptimization', 'AnalyticalFreq' and 'NUmericalFreq'. - final_energy : float - The final molecule energy (a.u). - final_structure : GMolecule - The final structure of the molecule. - energies : Sized - The energy of each cycle. - structures : Sized - The structure of each cycle If geometry optimization is performed. - frequencies : array_like - The frequencies of the molecule. - normal_modes : array_like - The normal modes of the molecule. - freq_type : str - Either 'Analytical' or 'Numerical'. + is_failed (bool): Whether the ADF job is failed. + is_internal_crash (bool): Whether the job crashed. + Please read 'TAPE13' of the ADF manual for more detail. + error (str): The error description. + run_type (str): The RunType of this ADF job. Possible options are: + 'SinglePoint', 'GeometryOptimization', 'AnalyticalFreq' and 'NUmericalFreq'. + final_energy (float): The final molecule energy (a.u). + final_structure (GMolecule): The final structure of the molecule. + energies (Sized): The energy of each cycle. + structures (Sized): The structure of each cycle If geometry optimization is performed. + frequencies (array_like): The frequencies of the molecule. + normal_modes (array_like): The normal modes of the molecule. + freq_type (syr): Either 'Analytical' or 'Numerical'. """ def __init__(self, filename): """ Initialization method. - Parameters - ---------- - filename : str - The ADF output file to parse. + Args: + filename (str): The ADF output file to parse. """ self.filename = filename self._parse() @@ -716,15 +644,11 @@ def _sites_to_mol(sites): """ Return a ``Molecule`` object given a list of sites. - Parameters - ---------- - sites : list - A list of sites. + Args: + sites : A list of sites. Returns: - ------- - mol : Molecule - A ``Molecule`` object. + mol (Molecule): A ``Molecule`` object. """ return Molecule([site[0] for site in sites], [site[1] for site in sites]) @@ -905,7 +829,7 @@ def _parse_adf_output(self): v = list(chunks(map(float, m.group(3).split()), 3)) if len(v) != n_next: raise AdfOutputError("Odd Error!") - for i, k in enumerate(range(-n_next, 0, 1)): + for i, k in enumerate(range(-n_next, 0)): self.normal_modes[k].extend(v[i]) if int(m.group(1)) == n_atoms: parse_freq = True diff --git a/pymatgen/io/aims/inputs.py b/pymatgen/io/aims/inputs.py index 128f8d89b26..c5a278e66fa 100644 --- a/pymatgen/io/aims/inputs.py +++ b/pymatgen/io/aims/inputs.py @@ -14,6 +14,7 @@ from typing import TYPE_CHECKING, Any import numpy as np +from monty.io import zopen from monty.json import MontyDecoder, MSONable from pymatgen.core import Lattice, Molecule, Structure @@ -21,6 +22,8 @@ if TYPE_CHECKING: from collections.abc import Sequence + from typing_extensions import Self + __author__ = "Thomas A. R. Purcell" __version__ = "1.0" __email__ = "purcellt@arizona.edu" @@ -33,7 +36,7 @@ class AimsGeometryIn(MSONable): Attributes: _content (str): The content of the input file - _structure (Structure or Molecule): The structure or molecule + _structure (Structure | Molecule): The structure or molecule representation of the file """ @@ -41,7 +44,7 @@ class AimsGeometryIn(MSONable): _structure: Structure | Molecule @classmethod - def from_str(cls, contents: str) -> AimsGeometryIn: + def from_str(cls, contents: str) -> Self: """Create an input from the content of an input file Args: @@ -54,12 +57,8 @@ def from_str(cls, contents: str) -> AimsGeometryIn: line.strip() for line in contents.split("\n") if len(line.strip()) > 0 and line.strip()[0] != "#" ] - species = [] - coords = [] - is_frac = [] - lattice_vectors = [] - charges_dct = {} - moments_dct = {} + species, coords, is_frac, lattice_vectors = [], [], [], [] + charges_dct, moments_dct = {}, {} for line in content_lines: inp = line.split() @@ -105,7 +104,7 @@ def from_str(cls, contents: str) -> AimsGeometryIn: return cls(_content="\n".join(content_lines), _structure=structure) @classmethod - def from_file(cls, filepath: str | Path) -> AimsGeometryIn: + def from_file(cls, filepath: str | Path) -> Self: """Create an AimsGeometryIn from an input file. Args: @@ -114,25 +113,21 @@ def from_file(cls, filepath: str | Path) -> AimsGeometryIn: Returns: AimsGeometryIn: The input object represented in the file """ - if str(filepath).endswith(".gz"): - with gzip.open(filepath, mode="rt") as infile: - content = infile.read() - else: - with open(filepath) as infile: - content = infile.read() + with zopen(filepath, mode="rt") as in_file: + content = in_file.read() return cls.from_str(content) @classmethod - def from_structure(cls, structure: Structure | Molecule) -> AimsGeometryIn: + def from_structure(cls, structure: Structure | Molecule) -> Self: """Construct an input file from an input structure. Args: - structure (Structure or Molecule): The structure for the file + structure (Structure | Molecule): The structure for the file Returns: AimsGeometryIn: The input object for the structure """ - content_lines = [] + content_lines: list[str] = [] if isinstance(structure, Structure): for lv in structure.lattice.matrix: @@ -190,7 +185,7 @@ def as_dict(self) -> dict[str, Any]: return dct @classmethod - def from_dict(cls, dct: dict[str, Any]) -> AimsGeometryIn: + def from_dict(cls, dct: dict[str, Any]) -> Self: """Initialize from dictionary. Args: @@ -378,7 +373,7 @@ def as_dict(self) -> dict[str, Any]: return dct @classmethod - def from_dict(cls, dct: dict[str, Any]) -> AimsCube: + def from_dict(cls, dct: dict[str, Any]) -> Self: """Initialize from dictionary. Args: @@ -406,8 +401,7 @@ class AimsControlIn(MSONable): def __post_init__(self) -> None: """Initialize the output list of _parameters""" - if "output" not in self._parameters: - self._parameters["output"] = [] + self._parameters.setdefault("output", []) def __getitem__(self, key: str) -> Any: """Get an input parameter @@ -464,8 +458,7 @@ def parameters(self, parameters: dict[str, Any]) -> None: parameters (dict[str, Any]): The new set of parameters to use """ self._parameters = parameters - if "output" not in self._parameters: - self._parameters["output"] = [] + self._parameters.setdefault("output", []) def get_aims_control_parameter_str(self, key: str, value: Any, fmt: str) -> str: """Get the string needed to add a parameter to the control.in file @@ -478,6 +471,8 @@ def get_aims_control_parameter_str(self, key: str, value: Any, fmt: str) -> str: Returns: str: The line to add to the control.in file """ + if value is None: + return "" return f"{key:35s}{fmt % value}\n" def get_content( @@ -486,7 +481,7 @@ def get_content( """Get the content of the file Args: - structure (Structure or Molecule): The structure to write the input + structure (Structure | Molecule): The structure to write the input file for verbose_header (bool): If True print the input option dictionary directory: str | Path | None = The directory for the calculation, @@ -510,8 +505,7 @@ def get_content( if verbose_header: content += "# \n# List of parameters used to initialize the calculator:" for param, val in parameters.items(): - s = f"# {param}:{val}\n" - content += s + content += f"# {param}:{val}\n" content += lim + "\n" assert ("smearing" in parameters and "occupation_type" in parameters) is False @@ -539,7 +533,7 @@ def get_content( elif isinstance(value, bool): content += self.get_aims_control_parameter_str(key, str(value).lower(), ".%s.") elif isinstance(value, (tuple, list)): - content += self.get_aims_control_parameter_str(key, " ".join([str(x) for x in value]), "%s") + content += self.get_aims_control_parameter_str(key, " ".join(map(str, value)), "%s") elif isinstance(value, str): content += self.get_aims_control_parameter_str(key, value, "%s") else: @@ -565,7 +559,7 @@ def write_file( """Writes the control.in file Args: - structure (Structure or Molecule): The structure to write the input + structure (Structure | Molecule): The structure to write the input file for directory (str or Path): The directory to write the control.in file. If None use cwd @@ -610,20 +604,20 @@ def get_species_block(self, structure: Structure | Molecule, species_dir: str | Raises: ValueError: If a file for the species is not found """ - sb = "" + block = "" species = np.unique(structure.species) for sp in species: filename = f"{species_dir}/{sp.Z:02d}_{sp.symbol}_default" if Path(filename).exists(): with open(filename) as sf: - sb += "".join(sf.readlines()) + block += "".join(sf.readlines()) elif Path(f"{filename}.gz").exists(): with gzip.open(f"{filename}.gz", mode="rt") as sf: - sb += "".join(sf.readlines()) + block += "".join(sf.readlines()) else: raise ValueError(f"Species file for {sp.symbol} not found.") - return sb + return block def as_dict(self) -> dict[str, Any]: """Get a dictionary representation of the geometry.in file.""" @@ -634,7 +628,7 @@ def as_dict(self) -> dict[str, Any]: return dct @classmethod - def from_dict(cls, dct: dict[str, Any]) -> AimsControlIn: + def from_dict(cls, dct: dict[str, Any]) -> Self: """Initialize from dictionary. Args: diff --git a/pymatgen/io/aims/outputs.py b/pymatgen/io/aims/outputs.py index b6a189dece7..9c33c414f85 100644 --- a/pymatgen/io/aims/outputs.py +++ b/pymatgen/io/aims/outputs.py @@ -19,6 +19,7 @@ from pathlib import Path from emmet.core.math import Matrix3D, Vector3D + from typing_extensions import Self from pymatgen.core import Molecule, Structure @@ -61,7 +62,7 @@ def as_dict(self) -> dict[str, Any]: return dct @classmethod - def from_outfile(cls, outfile: str | Path) -> AimsOutput: + def from_outfile(cls, outfile: str | Path) -> Self: """Construct an AimsOutput from an output file. Args: @@ -76,7 +77,7 @@ def from_outfile(cls, outfile: str | Path) -> AimsOutput: return cls(results, metadata, structure_summary) @classmethod - def from_str(cls, content: str) -> AimsOutput: + def from_str(cls, content: str) -> Self: """Construct an AimsOutput from an output file. Args: @@ -91,7 +92,7 @@ def from_str(cls, content: str) -> AimsOutput: return cls(results, metadata, structure_summary) @classmethod - def from_dict(cls, dct: dict[str, Any]) -> AimsOutput: + def from_dict(cls, dct: dict[str, Any]) -> Self: """Construct an AimsOutput from a dictionary. Args: diff --git a/pymatgen/io/aims/parsers.py b/pymatgen/io/aims/parsers.py index 2ef1c179e09..b7f850b96f8 100644 --- a/pymatgen/io/aims/parsers.py +++ b/pymatgen/io/aims/parsers.py @@ -577,48 +577,36 @@ def _parse_lattice_atom_pos( def species(self) -> list[str]: """The list of atomic symbols for all atoms in the structure""" if "species" not in self._cache: - ( - self._cache["species"], - self._cache["coords"], - self._cache["velocities"], - self._cache["lattice"], - ) = self._parse_lattice_atom_pos() + self._cache["species"], self._cache["coords"], self._cache["velocities"], self._cache["lattice"] = ( + self._parse_lattice_atom_pos() + ) return self._cache["species"] @property def coords(self) -> list[Vector3D]: """The cartesian coordinates of the atoms""" if "coords" not in self._cache: - ( - self._cache["species"], - self._cache["coords"], - self._cache["velocities"], - self._cache["lattice"], - ) = self._parse_lattice_atom_pos() + self._cache["species"], self._cache["coords"], self._cache["velocities"], self._cache["lattice"] = ( + self._parse_lattice_atom_pos() + ) return self._cache["coords"] @property def velocities(self) -> list[Vector3D]: """The velocities of the atoms""" if "velocities" not in self._cache: - ( - self._cache["species"], - self._cache["coords"], - self._cache["velocities"], - self._cache["lattice"], - ) = self._parse_lattice_atom_pos() + self._cache["species"], self._cache["coords"], self._cache["velocities"], self._cache["lattice"] = ( + self._parse_lattice_atom_pos() + ) return self._cache["velocities"] @property def lattice(self) -> Lattice: """The Lattice object for the structure""" if "lattice" not in self._cache: - ( - self._cache["species"], - self._cache["coords"], - self._cache["velocities"], - self._cache["lattice"], - ) = self._parse_lattice_atom_pos() + self._cache["species"], self._cache["coords"], self._cache["velocities"], self._cache["lattice"] = ( + self._parse_lattice_atom_pos() + ) return self._cache["lattice"] @property diff --git a/pymatgen/io/aims/sets/base.py b/pymatgen/io/aims/sets/base.py index df8a8597159..5612cdd5016 100644 --- a/pymatgen/io/aims/sets/base.py +++ b/pymatgen/io/aims/sets/base.py @@ -208,9 +208,8 @@ def get_input_set( # type: ignore properties: list[str] System properties that are being calculated - Returns - ------- - The input set for the calculation of structure + Returns: + AimsInputSet: The input set for the calculation of structure """ prev_structure, prev_parameters, _ = self._read_previous(prev_dir) @@ -274,9 +273,8 @@ def _get_properties( parameters: dict[str, Any] The parameters for this calculation - Returns - ------- - The list of properties to calculate + Returns: + list[str]: The list of properties to calculate """ if properties is None: properties = ["energy", "free_energy"] @@ -310,9 +308,8 @@ def _get_input_parameters( prev_parameters: dict[str, Any] The previous calculation's calculation parameters - Returns - ------- - The input object + Returns: + dict: The input object """ # Get the default configuration # FHI-aims recommends using their defaults so bare-bones default parameters @@ -363,9 +360,8 @@ def get_parameter_updates(self, structure: Structure | Molecule, prev_parameters prev_parameters: dict[str, Any] Previous calculation parameters. - Returns - ------- - A dictionary of updates to apply. + Returns: + dict: A dictionary of updates to apply. """ return prev_parameters @@ -388,9 +384,8 @@ def d2k( even: bool Round up to even numbers. - Returns - ------- - Monkhorst-Pack grid size in all directions + Returns: + dict: Monkhorst-Pack grid size in all directions """ recipcell = structure.lattice.inv_matrix return self.d2k_recipcell(recipcell, structure.lattice.pbc, kptdensity, even) @@ -405,9 +400,8 @@ def k2d(self, structure: Structure, k_grid: np.ndarray[int]): k_grid: np.ndarray[int] k_grid that was used. - Returns - ------- - Density of kpoints in each direction. result.mean() computes average density + Returns: + dict: Density of kpoints in each direction. result.mean() computes average density """ recipcell = structure.lattice.inv_matrix densities = k_grid / (2 * np.pi * np.sqrt((recipcell**2).sum(axis=1))) @@ -433,9 +427,8 @@ def d2k_recipcell( even: bool Round up to even numbers. - Returns - ------- - Monkhorst-Pack grid size in all directions + Returns: + dict: Monkhorst-Pack grid size in all directions """ if not isinstance(kptdensity, Iterable): kptdensity = 3 * [float(kptdensity)] @@ -458,14 +451,11 @@ def recursive_update(dct: dict, up: dict) -> dict: Parameters ---------- - dct: Dict - Input dictionary to modify - up: Dict - Dictionary of updates to apply + dct (dict): Input dictionary to modify + up (dict): updates to apply - Returns - ------- - The updated dictionary. + Returns: + dict: The updated dictionary. Example ------- diff --git a/pymatgen/io/aims/sets/bs.py b/pymatgen/io/aims/sets/bs.py index 89cab4f0ada..d0833f16999 100644 --- a/pymatgen/io/aims/sets/bs.py +++ b/pymatgen/io/aims/sets/bs.py @@ -45,9 +45,8 @@ def prepare_band_input(structure: Structure, density: float = 20): current_segment["length"] += 1 lines_and_labels.append(current_segment) current_segment = None - else: - if current_segment is not None: - current_segment["length"] += 1 + elif current_segment is not None: + current_segment["length"] += 1 bands = [] for segment in lines_and_labels: @@ -88,9 +87,8 @@ def get_parameter_updates( prev_parameters: Dict[str, Any] The previous parameters - Returns - ------- - The updated for the parameters for the output section of FHI-aims + Returns: + dict: The updated for the parameters for the output section of FHI-aims """ if isinstance(structure, Molecule): raise ValueError("BandStructures can not be made for non-periodic systems") @@ -126,9 +124,8 @@ def get_parameter_updates(self, structure: Structure | Molecule, prev_parameters prev_parameters: Dict[str, Any] The previous parameters - Returns - ------- - The updated for the parameters for the output section of FHI-aims + Returns: + dict: The updated for the parameters for the output section of FHI-aims """ updates = {"anacon_type": "two-pole"} current_output = prev_parameters.get("output", []) diff --git a/pymatgen/io/aims/sets/core.py b/pymatgen/io/aims/sets/core.py index e1b4880367a..389970abd4e 100644 --- a/pymatgen/io/aims/sets/core.py +++ b/pymatgen/io/aims/sets/core.py @@ -34,9 +34,8 @@ def get_parameter_updates(self, structure: Structure | Molecule, prev_parameters prev_parameters: Dict[str, Any] The previous parameters - Returns - ------- - The updated for the parameters for the output section of FHI-aims + Returns: + dict: The updated for the parameters for the output section of FHI-aims """ return prev_parameters @@ -75,9 +74,8 @@ def get_parameter_updates(self, structure: Structure | Molecule, prev_parameters prev_parameters: Dict[str, Any] The previous parameters - Returns - ------- - The updated for the parameters for the output section of FHI-aims + Returns: + dict: The updated for the parameters for the output section of FHI-aims """ updates = {"relax_geometry": f"{self.method} {self.max_force:e}"} if isinstance(structure, Structure) and self.relax_cell: @@ -116,8 +114,7 @@ def get_parameter_updates(self, structure: Structure | Molecule, prev_parameters prev_parameters: Dict[str, Any] The previous parameters - Returns - ------- - The updated for the parameters for the output section of FHI-aims + Returns: + dict: The updated for the parameters for the output section of FHI-aims """ return {"use_pimd_wrapper": (self.host, self.port)} diff --git a/pymatgen/io/ase.py b/pymatgen/io/ase.py index 7d959179bf5..8cc59cb78ca 100644 --- a/pymatgen/io/ase.py +++ b/pymatgen/io/ase.py @@ -15,13 +15,6 @@ from pymatgen.core.structure import Molecule, Structure -if TYPE_CHECKING: - from typing import Any - - from numpy.typing import ArrayLike - - from pymatgen.core.structure import SiteCollection - try: from ase.atoms import Atoms from ase.calculators.singlepoint import SinglePointDFTCalculator @@ -29,15 +22,24 @@ from ase.io.jsonio import decode, encode from ase.spacegroup import Spacegroup - no_ase_err = None + NO_ASE_ERR = None except ImportError: - no_ase_err = PackageNotFoundError("AseAtomsAdaptor requires the ASE package. Use `pip install ase`") + NO_ASE_ERR = PackageNotFoundError("AseAtomsAdaptor requires the ASE package. Use `pip install ase`") + encode = decode = FixAtoms = SinglePointDFTCalculator = Spacegroup = None class Atoms: # type: ignore[no-redef] def __init__(self, *args, **kwargs): - raise no_ase_err + raise NO_ASE_ERR +if TYPE_CHECKING: + from typing import Any + + from numpy.typing import ArrayLike + from typing_extensions import Self + + from pymatgen.core.structure import SiteCollection + __author__ = "Shyue Ping Ong, Andrew S. Rosen" __copyright__ = "Copyright 2012, The Materials Project" __version__ = "1.0" @@ -63,12 +65,13 @@ def as_dict(atoms: Atoms) -> dict[str, Any]: "atoms_info": jsanitize(atoms.info, strict=True), } - def from_dict(dct: dict[str, Any]) -> MSONAtoms: + @classmethod + def from_dict(cls, dct: dict[str, Any]) -> Self: # Normally, we would want to this to be a wrapper around atoms.fromdict() with @module and # @class key-value pairs inserted. However, atoms.todict()/atoms.fromdict() is not meant # to be used in a round-trip fashion and does not work properly with constraints. # See ASE issue #1387. - mson_atoms = MSONAtoms(decode(dct["atoms_json"])) + mson_atoms = cls(decode(dct["atoms_json"])) atoms_info = MontyDecoder().process_decoded(dct["atoms_info"]) mson_atoms.info = atoms_info return mson_atoms @@ -92,8 +95,8 @@ def get_atoms(structure: SiteCollection, msonable: bool = True, **kwargs) -> MSO Returns: Atoms: ASE Atoms object """ - if no_ase_err: - raise no_ase_err + if NO_ASE_ERR: + raise NO_ASE_ERR if not structure.is_ordered: raise ValueError("ASE Atoms only supports ordered structures") diff --git a/pymatgen/io/atat.py b/pymatgen/io/atat.py index 31860a4c9a3..7badcab828c 100644 --- a/pymatgen/io/atat.py +++ b/pymatgen/io/atat.py @@ -60,7 +60,8 @@ def structure_from_str(data): Parses a rndstr.in, lat.in or bestsqs.out file into pymatgen's Structure format. - :param data: contents of a rndstr.in, lat.in or bestsqs.out file + Args: + data: contents of a rndstr.in, lat.in or bestsqs.out file Returns: Structure object diff --git a/pymatgen/io/babel.py b/pymatgen/io/babel.py index 8704d3a3149..f6f862bc386 100644 --- a/pymatgen/io/babel.py +++ b/pymatgen/io/babel.py @@ -8,6 +8,7 @@ import copy import warnings +from typing import TYPE_CHECKING from monty.dev import requires @@ -16,7 +17,12 @@ try: from openbabel import openbabel, pybel except Exception: - openbabel = None + openbabel = pybel = None + +if TYPE_CHECKING: + from typing_extensions import Self + + from pymatgen.analysis.graphs import MoleculeGraph __author__ = "Shyue Ping Ong, Qi Wang" @@ -64,7 +70,7 @@ def __init__(self, mol: Molecule | openbabel.OBMol | pybel.Molecule) -> None: ob_atom = openbabel.OBAtom() ob_atom.thisown = 0 ob_atom.SetAtomicNum(atom_no) - ob_atom.SetVector(*coords) + ob_atom.SetVector(*map(float, coords)) ob_mol.AddAtom(ob_atom) del ob_atom ob_mol.ConnectTheDots() @@ -82,7 +88,7 @@ def __init__(self, mol: Molecule | openbabel.OBMol | pybel.Molecule) -> None: raise ValueError(f"Unsupported input type {type(mol)}, must be Molecule, openbabel.OBMol or pybel.Molecule") @property - def pymatgen_mol(self): + def pymatgen_mol(self) -> Molecule: """Returns pymatgen Molecule object.""" sp = [] coords = [] @@ -96,7 +102,7 @@ def openbabel_mol(self): """Returns OpenBabel's OBMol.""" return self._ob_mol - def localopt(self, forcefield="mmff94", steps=500): + def localopt(self, forcefield: str = "mmff94", steps: int = 500) -> None: """ A wrapper to pybel's localopt method to optimize a Molecule. @@ -109,7 +115,7 @@ def localopt(self, forcefield="mmff94", steps=500): pybelmol.localopt(forcefield=forcefield, steps=steps) self._ob_mol = pybelmol.OBMol - def make3d(self, forcefield="mmff94", steps=50): + def make3d(self, forcefield: str = "mmff94", steps: int = 50) -> None: """ A wrapper to pybel's make3D method generate a 3D structure from a 2D or 0D structure. @@ -132,11 +138,11 @@ def make3d(self, forcefield="mmff94", steps=50): pybelmol.make3D(forcefield=forcefield, steps=steps) self._ob_mol = pybelmol.OBMol - def add_hydrogen(self): + def add_hydrogen(self) -> None: """Add hydrogens (make all hydrogen explicit).""" self._ob_mol.AddHydrogens() - def remove_bond(self, idx1, idx2): + def remove_bond(self, idx1: int, idx2: int) -> None: """ Remove a bond from an openbabel molecule. @@ -150,7 +156,7 @@ def remove_bond(self, idx1, idx2): ): self._ob_mol.DeleteBond(obbond) - def rotor_conformer(self, *rotor_args, algo="WeightedRotorSearch", forcefield="mmff94"): + def rotor_conformer(self, *rotor_args, algo: str = "WeightedRotorSearch", forcefield: str = "mmff94") -> None: """ Conformer search based on several Rotor Search algorithms of openbabel. If the input molecule is not 3D, make3d will be called (generate 3D @@ -176,14 +182,14 @@ def rotor_conformer(self, *rotor_args, algo="WeightedRotorSearch", forcefield="m else: self.add_hydrogen() - ff = openbabel.OBForceField_FindType(forcefield) + ff = openbabel.OBForceField.FindType(forcefield) if ff == 0: warnings.warn( f"This input {forcefield=} is not supported " "in openbabel. The forcefield will be reset as " "default 'mmff94' for now." ) - ff = openbabel.OBForceField_FindType("mmff94") + ff = openbabel.OBForceField.FindType("mmff94") try: rotor_search = getattr(ff, algo) @@ -200,7 +206,7 @@ def rotor_conformer(self, *rotor_args, algo="WeightedRotorSearch", forcefield="m rotor_search(*rotor_args) ff.GetConformers(self._ob_mol) - def gen3d_conformer(self): + def gen3d_conformer(self) -> None: """ A combined method to first generate 3D structures from 0D or 2D structures and then find the minimum energy conformer: @@ -225,13 +231,13 @@ def gen3d_conformer(self): def confab_conformers( self, - forcefield="mmff94", - freeze_atoms=None, - rmsd_cutoff=0.5, - energy_cutoff=50.0, - conf_cutoff=100000, - verbose=False, - ): + forcefield: str = "mmff94", + freeze_atoms: list[int] | None = None, + rmsd_cutoff: float = 0.5, + energy_cutoff: float = 50.0, + conf_cutoff: int = 100000, + verbose: bool = False, + ) -> list[Molecule]: """ Conformer generation based on Confab to generate all diverse low-energy conformers for molecules. This is different from rotor_conformer or @@ -245,7 +251,7 @@ def confab_conformers( conformer search, default is None. rmsd_cutoff (float): rmsd_cufoff, default is 0.5 Angstrom. energy_cutoff (float): energy_cutoff, default is 50.0 kcal/mol. - conf_cutoff (float): max number of conformers to test, + conf_cutoff (int): max number of conformers to test, default is 1 million. verbose (bool): whether to display information on torsions found, default is False. @@ -258,10 +264,10 @@ def confab_conformers( else: self.add_hydrogen() - ff = openbabel.OBForceField_FindType(forcefield) + ff = openbabel.OBForceField.FindType(forcefield) if ff == 0: print(f"Could not find {forcefield=} in openbabel, the forcefield will be reset as default 'mmff94'") - ff = openbabel.OBForceField_FindType("mmff94") + ff = openbabel.OBForceField.FindType("mmff94") if freeze_atoms: print(f"{len(freeze_atoms)} atoms will be freezed") @@ -289,11 +295,11 @@ def confab_conformers( return conformers @property - def pybel_mol(self): + def pybel_mol(self) -> Molecule: """Returns Pybel's Molecule object.""" return pybel.Molecule(self._ob_mol) - def write_file(self, filename, file_format="xyz"): + def write_file(self, filename: str, file_format: str = "xyz") -> None: """ Uses OpenBabel to output all supported formats. @@ -302,10 +308,12 @@ def write_file(self, filename, file_format="xyz"): file_format: String specifying any OpenBabel supported formats. """ mol = pybel.Molecule(self._ob_mol) - return mol.write(file_format, filename, overwrite=True) + mol.write(file_format, filename, overwrite=True) @classmethod - def from_file(cls, filename, file_format="xyz", return_all_molecules=False): + def from_file( + cls, filename: str, file_format: str = "xyz", return_all_molecules: bool = False + ) -> Self | list[Self]: """ Uses OpenBabel to read a molecule from a file in all supported formats. @@ -325,8 +333,8 @@ def from_file(cls, filename, file_format="xyz", return_all_molecules=False): return cls(next(mols).OBMol) - @staticmethod - def from_molecule_graph(mol): + @classmethod + def from_molecule_graph(cls, mol: MoleculeGraph) -> Self: """ Read a molecule from a pymatgen MoleculeGraph object. @@ -336,11 +344,11 @@ def from_molecule_graph(mol): Returns: BabelMolAdaptor object """ - return BabelMolAdaptor(mol.molecule) + return cls(mol.molecule) - @needs_openbabel @classmethod - def from_str(cls, string_data, file_format="xyz"): + @needs_openbabel + def from_str(cls, string_data: str, file_format: str = "xyz") -> Self: """ Uses OpenBabel to read a molecule from a string in all supported formats. @@ -352,5 +360,5 @@ def from_str(cls, string_data, file_format="xyz"): Returns: BabelMolAdaptor object """ - mols = pybel.readstring(str(file_format), str(string_data)) + mols = pybel.readstring(file_format, string_data) return cls(mols.OBMol) diff --git a/pymatgen/io/cif.py b/pymatgen/io/cif.py index e155d4d5e52..4d8006702f4 100644 --- a/pymatgen/io/cif.py +++ b/pymatgen/io/cif.py @@ -8,9 +8,8 @@ import textwrap import warnings from collections import defaultdict, deque -from datetime import datetime from functools import partial -from inspect import getfullargspec as getargspec +from inspect import getfullargspec from io import StringIO from itertools import groupby from pathlib import Path @@ -31,6 +30,8 @@ from pymatgen.util.coord import find_in_coord_list_pbc, in_coord_list_pbc if TYPE_CHECKING: + from typing_extensions import Self + from pymatgen.core.trajectory import Vector3D __author__ = "Shyue Ping Ong, Will Richards, Matthew Horton" @@ -107,7 +108,7 @@ def _loop_to_str(self, loop): out += line + "\n" + val line = "\n" elif len(line) + len(val) + 2 < self.max_len: - line += " " + val + line += f" {val}" else: out += line line = "\n " + val @@ -118,8 +119,9 @@ def _format_field(self, val) -> str: val = str(val).strip() if len(val) > self.max_len: return f";\n{textwrap.fill(val, self.max_len)}\n;" + # add quotes if necessary - if val == "": + if not val: return '""' if ( (" " in val or val[0] == "_") @@ -168,43 +170,44 @@ def _process_string(cls, string): return deq @classmethod - def from_str(cls, string): + def from_str(cls, string: str) -> Self: """ Reads CifBlock from string. - :param string: String representation. + Args: + string: String representation. Returns: CifBlock """ - q = cls._process_string(string) - header = q.popleft()[0][5:] - data = {} + deq = cls._process_string(string) + header = deq.popleft()[0][5:] + data: dict = {} loops = [] - while q: - s = q.popleft() + while deq: + s = deq.popleft() # cif keys aren't in quotes, so show up in s[0] if s[0] == "_eof": break if s[0].startswith("_"): try: - data[s[0]] = "".join(q.popleft()) + data[s[0]] = "".join(deq.popleft()) except IndexError: data[s[0]] = "" elif s[0].startswith("loop_"): columns = [] items = [] - while q: - s = q[0] + while deq: + s = deq[0] if s[0].startswith("loop_") or not s[0].startswith("_"): break - columns.append("".join(q.popleft())) + columns.append("".join(deq.popleft())) data[columns[-1]] = [] - while q: - s = q[0] + while deq: + s = deq[0] if s[0].startswith(("loop_", "_")): break - items.append("".join(q.popleft())) + items.append("".join(deq.popleft())) n = len(items) // len(columns) assert len(items) % n == 0 loops.append(columns) @@ -234,10 +237,11 @@ def __str__(self): return f"{self.comment}\n{out}\n" @classmethod - def from_str(cls, string) -> CifFile: + def from_str(cls, string: str) -> Self: """Reads CifFile from a string. - :param string: String representation. + Args: + string: String representation. Returns: CifFile @@ -260,11 +264,12 @@ def from_str(cls, string) -> CifFile: return cls(dct, string) @classmethod - def from_file(cls, filename: str | Path) -> CifFile: + def from_file(cls, filename: str | Path) -> Self: """ Reads CifFile from a filename. - :param filename: Filename + Args: + filename: Filename Returns: CifFile @@ -362,7 +367,7 @@ def is_magcif_incommensurate() -> bool: self._cif.data[key] = self._sanitize_data(self._cif.data[key]) @classmethod - def from_str(cls, cif_string: str, **kwargs) -> CifParser: + def from_str(cls, cif_string: str, **kwargs) -> Self: """ Creates a CifParser from a string. @@ -384,16 +389,16 @@ def _sanitize_data(self, data): This function is here so that CifParser can assume its input conforms to spec, simplifying its implementation. - :param data: CifBlock + + Args: + data: CifBlock Returns: data CifBlock """ - """ - This part of the code deals with handling formats of data as found in - CIF files extracted from the Springer Materials/Pauling File - databases, and that are different from standard ICSD formats. - """ + # This part of the code deals with handling formats of data as found in + # CIF files extracted from the Springer Materials/Pauling File + # databases, and that are different from standard ICSD formats. # check for implicit hydrogens, warn if any present if "_atom_site_attached_hydrogens" in data.data: attached_hydrogens = [str2float(x) for x in data.data["_atom_site_attached_hydrogens"] if str2float(x) != 0] @@ -587,6 +592,8 @@ def _unique_coords( # Up to this point, magmoms have been defined relative # to crystal axis. Now convert to Cartesian and into # a Magmom object. + if lattice is None: + raise ValueError("Lattice cannot be None.") magmom = Magmom.from_moment_relative_to_crystal_axes( op.operate_magmom(tmp_magmom), lattice=lattice ) @@ -632,7 +639,7 @@ def get_lattice( if data.data.get(lattice_label): lattice_type = data.data.get(lattice_label).lower() try: - required_args = getargspec(getattr(Lattice, lattice_type)).args + required_args = getfullargspec(getattr(Lattice, lattice_type)).args lengths = (length for length in length_strings if length in required_args) angles = (a for a in angle_strings if a in required_args) @@ -661,8 +668,8 @@ def get_lattice_no_exception( Returns: Lattice object """ - lengths = [str2float(data["_cell_length_" + i]) for i in length_strings] - angles = [str2float(data["_cell_angle_" + i]) for i in angle_strings] + lengths = [str2float(data[f"_cell_length_{i}"]) for i in length_strings] + angles = [str2float(data[f"_cell_angle_{i}"]) for i in angle_strings] if not lattice_type: return Lattice.from_parameters(*lengths, *angles) return getattr(Lattice, lattice_type)(*(lengths + angles)) @@ -806,7 +813,7 @@ def get_magsymops(self, data): # else check to see if it specifies a magnetic space group elif bns_name or bns_num: - label = bns_name if bns_name else list(map(int, (bns_num.split(".")))) + label = bns_name or list(map(int, (bns_num.split(".")))) if data.data.get("_space_group_magn.transform_BNS_Pp_abc") != "a,b,c;0,0,0": jonas_faithful = data.data.get("_space_group_magn.transform_BNS_Pp_abc") @@ -848,7 +855,8 @@ def parse_oxi_states(data): def parse_magmoms(data, lattice=None): """Parse atomic magnetic moments from data dictionary.""" if lattice is None: - raise Exception("Magmoms given in terms of crystal axes in magCIF spec.") + raise ValueError("Magmoms given in terms of crystal axes in magCIF spec.") + try: magmoms = { data["_atom_site_moment_label"][i]: np.array( @@ -897,10 +905,8 @@ def _parse_symbol(self, sym): parsed_sym = sym[:2].title() elif Element.is_valid_symbol(sym[0].upper()): parsed_sym = sym[0].upper() - else: - m = re.match(r"w?[A-Z][a-z]*", sym) - if m: - parsed_sym = m.group() + elif match := re.match(r"w?[A-Z][a-z]*", sym): + parsed_sym = match.group() if parsed_sym is not None and (m_sp or not re.match(rf"{parsed_sym}\d*", sym)): msg = f"{sym} parsed as {parsed_sym}" @@ -1016,6 +1022,7 @@ def get_matching_coord(coord): self.warnings.append(msg) all_species = [] + all_species_noedit = [] all_coords = [] all_magmoms = [] all_hydrogens = [] @@ -1091,7 +1098,7 @@ def get_matching_coord(coord): if self.feature_flags["magcif"]: site_properties["magmom"] = all_magmoms - if len(site_properties) == 0: + if not site_properties: site_properties = None # type: ignore[assignment] if any(all_labels): @@ -1185,12 +1192,6 @@ def parse_structures( Returns: list[Structure]: All structures in CIF file. """ - if ( - os.getenv("CI") - and os.getenv("GITHUB_REPOSITORY") == "materialsproject/pymatgen" - and datetime.now() > datetime(2024, 10, 1) - ): # pragma: no cover - raise RuntimeError("remove the warning about changing default primitive=True to False on 2023-10-24") if primitive is None: primitive = False warnings.warn( @@ -1210,8 +1211,7 @@ def parse_structures( structures = [] for idx, dct in enumerate(self._cif.data.values()): try: - struct = self._get_structure(dct, primitive, symmetrized, check_occu=check_occu) - if struct: + if struct := self._get_structure(dct, primitive, symmetrized, check_occu=check_occu): structures.append(struct) except (KeyError, ValueError) as exc: # A user reported a problem with cif files produced by Avogadro @@ -1232,10 +1232,12 @@ def parse_structures( raise ValueError("Invalid CIF file with no structures!") return structures - def get_bibtex_string(self): + def get_bibtex_string(self) -> str: """ Get BibTeX reference from CIF file. - :param data: + + args: + data: Returns: BibTeX string. @@ -1465,9 +1467,9 @@ def __init__( if symprec is None: block["_symmetry_equiv_pos_site_id"] = ["1"] block["_symmetry_equiv_pos_as_xyz"] = ["x, y, z"] + else: spg_analyzer = SpacegroupAnalyzer(struct, symprec) - symm_ops: list[SymmOp] = [] for op in spg_analyzer.get_symmetry_operations(): v = op.translation_vector @@ -1535,6 +1537,7 @@ def __init__( atom_site_properties[key].append(format_str.format(val)) count += 1 + else: # The following just presents a deterministic ordering. unique_sites = [ @@ -1542,7 +1545,7 @@ def __init__( sorted(sites, key=lambda s: tuple(abs(x) for x in s.frac_coords))[0], len(sites), ) - for sites in spg_analyzer.get_symmetrized_structure().equivalent_sites + for sites in spg_analyzer.get_symmetrized_structure().equivalent_sites # type: ignore[reportPossiblyUnboundVariable] ] for site, mult in sorted( unique_sites, diff --git a/pymatgen/io/common.py b/pymatgen/io/common.py index 8cc909d32b1..b9821ba7f68 100644 --- a/pymatgen/io/common.py +++ b/pymatgen/io/common.py @@ -6,6 +6,7 @@ import json import warnings from copy import deepcopy +from typing import TYPE_CHECKING import numpy as np from monty.io import zopen @@ -16,6 +17,11 @@ from pymatgen.core.units import ang_to_bohr, bohr_to_angstrom from pymatgen.electronic_structure.core import Spin +if TYPE_CHECKING: + from pathlib import Path + + from typing_extensions import Self + class VolumetricData(MSONable): """ @@ -302,11 +308,12 @@ def to_hdf5(self, filename): file.attrs["structure_json"] = json.dumps(self.structure.as_dict()) @classmethod - def from_hdf5(cls, filename, **kwargs): + def from_hdf5(cls, filename: str, **kwargs) -> Self: """ Reads VolumetricData from HDF5 file. - :param filename: Filename + Args: + filename: Filename Returns: VolumetricData @@ -321,7 +328,7 @@ def from_hdf5(cls, filename, **kwargs): structure = Structure.from_dict(json.loads(file.attrs["structure_json"])) return cls(structure, data=data, data_aug=data_aug, **kwargs) - def to_cube(self, filename, comment=None): + def to_cube(self, filename, comment: str = ""): """ Write the total volumetric data to a cube file format, which consists of two comment lines, a header section defining the structure IN BOHR, and the data. @@ -332,7 +339,7 @@ def to_cube(self, filename, comment=None): """ with zopen(filename, mode="wt") as file: file.write(f"# Cube file for {self.structure.formula} generated by Pymatgen\n") - file.write(f"# {comment if comment else ''}\n") + file.write(f"# {comment}\n") file.write(f"\t {len(self.structure)} 0.000000 0.000000 0.000000\n") for idx in range(3): @@ -349,13 +356,13 @@ def to_cube(self, filename, comment=None): f"{ang_to_bohr * site.coords[2]} \n" ) - for idx, dat in enumerate(self.data["total"].flatten()): + for idx, dat in enumerate(self.data["total"].flatten(), start=1): file.write(f"{' ' if dat > 0 else ''}{dat:.6e} ") - if (idx + 1) % 6 == 0: + if idx % 6 == 0: file.write("\n") @classmethod - def from_cube(cls, filename): + def from_cube(cls, filename: str | Path) -> Self: """ Initialize the cube object and store the data as data. @@ -370,7 +377,7 @@ def from_cube(cls, filename): # number of atoms followed by the position of the origin of the volumetric data line = file.readline().split() - natoms = int(line[0]) + n_atoms = int(line[0]) # The number of voxels along each axis (x, y, z) followed by the axis vector. line = file.readline().split() @@ -389,7 +396,7 @@ def from_cube(cls, filename): # the first is the atom number, second is charge, # the last three are the x,y,z coordinates of the atom center. sites = [] - for _ in range(natoms): + for _ in range(n_atoms): line = file.readline().split() sites.append(Site(line[0], np.multiply(bohr_to_angstrom, list(map(float, line[2:]))))) diff --git a/pymatgen/io/core.py b/pymatgen/io/core.py index 0064c9a4af5..898ec149e70 100644 --- a/pymatgen/io/core.py +++ b/pymatgen/io/core.py @@ -39,6 +39,7 @@ if TYPE_CHECKING: from os import PathLike + __author__ = "Ryan Kingsbury" __email__ = "RKingsbury@lbl.gov" __status__ = "Development" @@ -75,7 +76,7 @@ def write_file(self, filename: str | PathLike) -> None: @classmethod @abc.abstractmethod - def from_str(cls, contents: str) -> InputFile: + def from_str(cls, contents: str) -> None: """ Create an InputFile object from a string. @@ -85,9 +86,10 @@ def from_str(cls, contents: str) -> InputFile: Returns: InputFile """ + raise NotImplementedError(f"from_str has not been implemented in {cls.__name__}") @classmethod - def from_file(cls, path: str | Path): + def from_file(cls, path: str | Path) -> None: """ Creates an InputFile object from a file. @@ -99,7 +101,7 @@ def from_file(cls, path: str | Path): """ filename = path if isinstance(path, Path) else Path(path) with zopen(filename, mode="rt") as file: - return cls.from_str(file.read()) + return cls.from_str(file.read()) # from_str not implemented def __str__(self) -> str: return self.get_str() @@ -226,7 +228,7 @@ def write_input( pass @classmethod - def from_directory(cls, directory: str | Path): + def from_directory(cls, directory: str | Path) -> None: """ Construct an InputSet from a directory of one or more files. @@ -253,11 +255,12 @@ class InputGenerator(MSONable): """ @abc.abstractmethod - def get_input_set(self) -> InputSet: + def get_input_set(self, *args, **kwargs): """ Generate an InputSet object. Typically the first argument to this method will be a Structure or other form of atomic coordinates. """ + raise NotImplementedError(f"get_input_set has not been implemented in {type(self).__name__}") class ParseError(SyntaxError): diff --git a/pymatgen/io/cp2k/inputs.py b/pymatgen/io/cp2k/inputs.py index 493df9cdbec..3f582a5bae1 100644 --- a/pymatgen/io/cp2k/inputs.py +++ b/pymatgen/io/cp2k/inputs.py @@ -22,6 +22,7 @@ from __future__ import annotations +import abc import copy import hashlib import itertools @@ -45,6 +46,8 @@ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer if TYPE_CHECKING: + from typing_extensions import Self + from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Molecule, Structure @@ -111,7 +114,7 @@ def __eq__(self, other: object) -> bool: if self.name.upper() == other.name.upper(): v1 = [val.upper() if isinstance(val, str) else val for val in self.values] v2 = [val.upper() if isinstance(val, str) else val for val in other.values] # noqa: PD011 - if v1 == v2 and self.units == self.units: + if v1 == v2 and self.units == other.units: return True return False @@ -123,35 +126,35 @@ def __getitem__(self, item): def as_dict(self): """Get a dictionary representation of the Keyword.""" - dct = {} - dct["@module"] = type(self).__module__ - dct["@class"] = type(self).__name__ - dct["name"] = self.name - dct["values"] = self.values - dct["description"] = self.description - dct["repeats"] = self.repeats - dct["units"] = self.units - dct["verbose"] = self.verbose - return dct + return { + "@module": type(self).__module__, + "@class": type(self).__name__, + "name": self.name, + "values": self.values, + "description": self.description, + "repeats": self.repeats, + "units": self.units, + "verbose": self.verbose, + } def get_str(self) -> str: """String representation of Keyword.""" return str(self) @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """Initialize from dictionary.""" return Keyword( - d["name"], - *d["values"], - description=d["description"], - repeats=d["repeats"], - units=d["units"], - verbose=d["verbose"], + dct["name"], + *dct["values"], + description=dct["description"], + repeats=dct["repeats"], + units=dct["units"], + verbose=dct["verbose"], ) @classmethod - def from_str(cls, s): + def from_str(cls, s: str) -> Self: """ Initialize from a string. @@ -170,8 +173,8 @@ def from_str(cls, s): description = None units = re.findall(r"\[(.*)\]", s) or [None] s = re.sub(r"\[(.*)\]", "", s) - args = s.split() - args = list(map(postprocessor if args[0].upper() != "ELEMENT" else str, args)) + args: list[Any] = s.split() + args = list(map(postprocessor if args[0].upper() != "ELEMENT" else str, args)) # type: ignore[call-overload] args[0] = str(args[0]) return cls(*args, units=units[0], description=description) @@ -287,12 +290,12 @@ def __init__( Keyword objects """ self.name = name - self.subsections = subsections if subsections else {} + self.subsections = subsections or {} self.repeats = repeats self.description = description - keywords = keywords if keywords else {} + keywords = keywords or {} self.keywords = keywords - self.section_parameters = section_parameters if section_parameters else [] + self.section_parameters = section_parameters or [] self.location = location self.verbose = verbose self.alias = alias @@ -317,10 +320,7 @@ def __deepcopy__(self, memodict=None): ) def __getitem__(self, d): - r = self.get_keyword(d) - if not r: - r = self.get_section(d) - if r: + if r := self.get_keyword(d) or self.get_section(d): return r raise KeyError @@ -333,7 +333,7 @@ def __add__(self, other) -> Section: elif isinstance(other, (Section, SectionList)): self.insert(other) else: - TypeError("Can only add sections or keywords.") + raise TypeError("Can only add sections or keywords.") return self @@ -355,8 +355,7 @@ def setitem(self, key, value, strict=False): else: if not isinstance(value, (Keyword, KeywordList)): value = Keyword(key, value) - match = [k for k in self.keywords if key.upper() == k.upper()] - if match: + if match := [k for k in self.keywords if key.upper() == k.upper()]: del self.keywords[match[0]] self.keywords[key] = value elif not strict: @@ -367,14 +366,14 @@ def __delitem__(self, key): Delete section with name matching key OR delete all keywords with names matching this key. """ - lst = [sub_sec for sub_sec in self.subsections if sub_sec.upper() == key.upper()] - if lst: + if lst := [sub_sec for sub_sec in self.subsections if sub_sec.upper() == key.upper()]: del self.subsections[lst[0]] return - lst = [kw for kw in self.keywords if kw.upper() == key.upper()] - if lst: + + if lst := [kw for kw in self.keywords if kw.upper() == key.upper()]: del self.keywords[lst[0]] return + raise KeyError("No section or keyword matching the given key.") def __sub__(self, other): @@ -386,7 +385,7 @@ def add(self, other): raise TypeError(f"Can only add keywords, not {type(other).__name__}") return self + other - def get(self, dct, default=None): + def get(self, d, default=None): """ Similar to get for dictionaries. This will attempt to retrieve the section or keyword matching d. Will not raise an error if d does not exist. @@ -395,10 +394,9 @@ def get(self, dct, default=None): d: the key to retrieve, if present default: what to return if d is not found """ - kw = self.get_keyword(dct) - if kw: + if kw := self.get_keyword(d): return kw - sec = self.get_section(dct) + sec = self.get_section(d) if sec: return sec return default @@ -439,7 +437,7 @@ def update(self, dct: dict, strict=False): of new Section child-classes. Args: - d (dict): A dictionary containing the update information. Should use nested dictionaries + dct (dict): A dictionary containing the update information. Should use nested dictionaries to specify the full path of the update. If a section or keyword does not exist, it will be created, but only with the values that are provided in "d", not using default values from a Section object. @@ -464,14 +462,14 @@ def _update(d1, d2, strict=False): elif isinstance(v, (Keyword, KeywordList)): d1.setitem(k, v, strict=strict) elif isinstance(v, dict): - tmp = [_ for _ in d1.subsections if k.upper() == _.upper()] - if not tmp: + if tmp := [_ for _ in d1.subsections if k.upper() == _.upper()]: + Section._update(d1.subsections[tmp[0]], v, strict=strict) + else: if strict: continue d1.insert(Section(k, subsections={})) Section._update(d1.subsections[k], v, strict=strict) - else: - Section._update(d1.subsections[tmp[0]], v, strict=strict) + elif isinstance(v, Section): if not strict: d1.insert(v) @@ -496,7 +494,7 @@ def unset(self, dct: dict): elif isinstance(v, dict): self[k].unset(v) else: - TypeError("Can only add sections or keywords.") + raise TypeError("Can only add sections or keywords.") def inc(self, dct: dict): """Mongo style dict modification. Include.""" @@ -508,7 +506,7 @@ def inc(self, dct: dict): elif isinstance(val, dict): self[key].inc(val) else: - TypeError("Can only add sections or keywords.") + raise TypeError("Can only add sections or keywords.") def insert(self, d): """Insert a new section as a subsection of the current one.""" @@ -526,8 +524,7 @@ def check(self, path: str): _path = path.split("/") s = self.subsections for p in _path: - tmp = [_ for _ in s if p.upper() == _.upper()] - if tmp: + if tmp := [_ for _ in s if p.upper() == _.upper()]: s = s[tmp[0]].subsections else: return False @@ -677,7 +674,7 @@ class Cp2kInput(Section): def __init__(self, name: str = "CP2K_INPUT", subsections: dict | None = None, **kwargs): """Initialize Cp2kInput by calling the super.""" self.name = name - self.subsections = subsections if subsections else {} + self.subsections = subsections or {} self.kwargs = kwargs description = "CP2K Input" @@ -698,27 +695,27 @@ def get_str(self): return string @classmethod - def _from_dict(cls, d): + def _from_dict(cls, dct): """Initialize from a dictionary.""" return Cp2kInput( "CP2K_INPUT", subsections=getattr( - __import__(d["@module"], globals(), locals(), d["@class"], 0), - d["@class"], + __import__(dct["@module"], globals(), locals(), dct["@class"], 0), + dct["@class"], ) - .from_dict(d) + .from_dict(dct) .subsections, ) @classmethod - def from_file(cls, filename: str): + def from_file(cls, filename: str | Path) -> Self: """Initialize from a file.""" with zopen(filename, mode="rt") as file: txt = preprocessor(file.read(), os.path.dirname(file.name)) return cls.from_str(txt) @classmethod - def from_str(cls, s: str): + def from_str(cls, s: str) -> Self: """Initialize from a string.""" lines = s.splitlines() lines = [line.replace("\t", "") for line in lines] @@ -727,7 +724,7 @@ def from_str(cls, s: str): return cls.from_lines(lines) @classmethod - def from_lines(cls, lines: list | tuple): + def from_lines(cls, lines: list | tuple) -> Self: """Helper method to read lines of file.""" cp2k_input = Cp2kInput("CP2K_INPUT", subsections={}) Cp2kInput._from_lines(cp2k_input, lines) @@ -754,8 +751,7 @@ def _from_lines(self, lines): name, section_parameters=subsection_params, alias=alias, subsections={}, description=description ) description = "" - tmp = self.by_path(current).get_section(sec.alias or sec.name) - if tmp: + if tmp := self.by_path(current).get_section(sec.alias or sec.name): if isinstance(tmp, SectionList): self.by_path(current)[sec.alias or sec.name].append(sec) else: @@ -765,8 +761,7 @@ def _from_lines(self, lines): current = f"{current}/{alias or name}" else: kwd = Keyword.from_str(line) - tmp = self.by_path(current).get(kwd.name) - if tmp: + if tmp := self.by_path(current).get(kwd.name): if isinstance(tmp, KeywordList): self.by_path(current).get(kwd.name).append(kwd) elif isinstance(self.by_path(current), SectionList): @@ -794,7 +789,7 @@ def write_file( if not os.path.isdir(output_dir) and make_dir_if_not_present: os.mkdir(output_dir) filepath = os.path.join(output_dir, input_filename) - with open(filepath, mode="w") as file: + with open(filepath, mode="w", encoding="utf-8") as file: file.write(self.get_str()) @@ -817,7 +812,7 @@ def __init__( """ self.project_name = project_name self.run_type = run_type - keywords = keywords if keywords else {} + keywords = keywords or {} description = ( "Section with general information regarding which kind of simulation to perform an general settings" @@ -843,8 +838,8 @@ class ForceEval(Section): def __init__(self, keywords: dict | None = None, subsections: dict | None = None, **kwargs): """Initialize the ForceEval section.""" - keywords = keywords if keywords else {} - subsections = subsections if subsections else {} + keywords = keywords or {} + subsections = subsections or {} description = "Parameters needed to calculate energy and forces and describe the system you want to analyze." @@ -893,8 +888,8 @@ def __init__( self.potential_filename = potential_filename self.uks = uks self.wfn_restart_file_name = wfn_restart_file_name - keywords = keywords if keywords else {} - subsections = subsections if subsections else {} + keywords = keywords or {} + subsections = subsections or {} description = "Parameter needed by dft programs" @@ -926,8 +921,8 @@ class Subsys(Section): def __init__(self, keywords: dict | None = None, subsections: dict | None = None, **kwargs): """Initialize the subsys section.""" - keywords = keywords if keywords else {} - subsections = subsections if subsections else {} + keywords = keywords or {} + subsections = subsections or {} description = "A subsystem: coordinates, topology, molecules and cell" super().__init__("SUBSYS", keywords=keywords, description=description, subsections=subsections, **kwargs) @@ -967,8 +962,8 @@ def __init__( self.eps_default = eps_default self.eps_pgf_orb = eps_pgf_orb self.extrapolation = extrapolation - keywords = keywords if keywords else {} - subsections = subsections if subsections else {} + keywords = keywords or {} + subsections = subsections or {} description = "Parameters needed to set up the Quickstep framework" _keywords = { @@ -1026,8 +1021,8 @@ def __init__( self.eps_scf = eps_scf self.scf_guess = scf_guess - keywords = keywords if keywords else {} - subsections = subsections if subsections else {} + keywords = keywords or {} + subsections = subsections or {} description = "Parameters needed to perform an SCF run." @@ -1084,8 +1079,8 @@ def __init__( self.ngrids = ngrids self.progression_factor = progression_factor - keywords = keywords if keywords else {} - subsections = subsections if subsections else {} + keywords = keywords or {} + subsections = subsections or {} description = ( "Multigrid information. Multigrid allows for sharp gaussians and diffuse " "gaussians to be treated on different grids, where the spacing of FFT integration " @@ -1130,8 +1125,8 @@ def __init__( self.eps_iter = eps_iter self.eps_jacobi = eps_jacobi self.jacobi_threshold = jacobi_threshold - keywords = keywords if keywords else {} - subsections = subsections if subsections else {} + keywords = keywords or {} + subsections = subsections or {} location = "CP2K_INPUT/FORCE_EVAL/DFT/SCF/DIAGONALIZATION" description = "Settings for the SCF's diagonalization routines" @@ -1187,8 +1182,8 @@ def __init__( """ self.new_prec_each = new_prec_each self.preconditioner = preconditioner - keywords = keywords if keywords else {} - subsections = subsections if subsections else {} + keywords = keywords or {} + subsections = subsections or {} _keywords = { "NEW_PREC_EACH": Keyword("NEW_PREC_EACH", new_prec_each), "PRECONDITIONER": Keyword("PRECONDITIONER", preconditioner), @@ -1266,8 +1261,8 @@ def __init__( self.occupation_preconditioner = occupation_preconditioner self.energy_gap = energy_gap self.linesearch = linesearch - keywords = keywords if keywords else {} - subsections = subsections if subsections else {} + keywords = keywords or {} + subsections = subsections or {} description = ( "Sets the various options for the orbital transformation (OT) method. " @@ -1312,7 +1307,7 @@ def __init__(self, lattice: Lattice, keywords: dict | None = None, **kwargs): keywords: additional keywords """ self.lattice = lattice - keywords = keywords if keywords else {} + keywords = keywords or {} description = "Lattice parameters and optional settings for creating a the CELL" _keywords = { @@ -1369,8 +1364,8 @@ def __init__( self.potential = potential self.ghost = ghost or False # if None, set False self.aux_basis = aux_basis - keywords = keywords if keywords else {} - subsections = subsections if subsections else {} + keywords = keywords or {} + subsections = subsections or {} description = "The description of this kind of atom including basis sets, element, etc." # Special case for closed-shell elements. Cannot impose magnetization in cp2k. @@ -1398,7 +1393,7 @@ def __init__( else aux_basis.get_keyword() ) - kind_name = alias if alias else specie.__str__() + kind_name = alias or specie.__str__() alias = kind_name section_parameters = [kind_name] @@ -1449,8 +1444,8 @@ def __init__( self.l = l self.u_minus_j = u_minus_j self.u_ramping = u_ramping - keywords = keywords if keywords else {} - subsections = subsections if subsections else {} + keywords = keywords or {} + subsections = subsections or {} description = "Settings for on-site Hubbard +U correction for this atom kind." _keywords = { @@ -1485,8 +1480,8 @@ def __init__( """ self.structure = structure self.aliases = aliases - keywords = keywords if keywords else {} - subsections = subsections if subsections else {} + keywords = keywords or {} + subsections = subsections or {} description = ( "The coordinates for simple systems (like small QM cells) are specified " "here by default using explicit XYZ coordinates. More complex systems " @@ -1523,8 +1518,8 @@ def __init__(self, ndigits: int = 6, keywords: dict | None = None, subsections: subsections: additional subsections """ self.ndigits = ndigits - keywords = keywords if keywords else {} - subsections = subsections if subsections else {} + keywords = keywords or {} + subsections = subsections or {} description = "Controls printing of the overall density of states" _keywords = {"NDIGITS": Keyword("NDIGITS", ndigits)} keywords.update(_keywords) @@ -1547,8 +1542,8 @@ def __init__(self, nlumo: int = -1, keywords: dict | None = None, subsections: d subsections: additional subsections """ self.nlumo = nlumo - keywords = keywords if keywords else {} - subsections = subsections if subsections else {} + keywords = keywords or {} + subsections = subsections or {} description = "Controls printing of the projected density of states" _keywords = {"NLUMO": Keyword("NLUMO", nlumo), "COMPONENTS": Keyword("COMPONENTS")} @@ -1577,8 +1572,8 @@ def __init__( subsections: additional subsections """ self.index = index - keywords = keywords if keywords else {} - subsections = subsections if subsections else {} + keywords = keywords or {} + subsections = subsections or {} description = "Controls printing of the projected density of states decomposed by atom type" _keywords = {"COMPONENTS": Keyword("COMPONENTS"), "LIST": Keyword("LIST", index)} keywords.update(_keywords) @@ -1592,12 +1587,12 @@ def __init__( ) -class V_Hartree_Cube(Section): +class VHartreeCube(Section): """Controls printing of the hartree potential as a cube file.""" def __init__(self, keywords: dict | None = None, subsections: dict | None = None, **kwargs): - keywords = keywords if keywords else {} - subsections = subsections if subsections else {} + keywords = keywords or {} + subsections = subsections or {} description = ( "Controls the printing of a cube file with eletrostatic potential generated by " "the total density (electrons+ions). It is valid only for QS with GPW formalism. " @@ -1612,7 +1607,12 @@ def __init__(self, keywords: dict | None = None, subsections: dict | None = None ) -class MO_Cubes(Section): +@deprecated(VHartreeCube, "Deprecated on 2024-03-29, to be removed on 2025-03-29.") +class V_Hartree_Cube(VHartreeCube): + pass + + +class MOCubes(Section): """Controls printing of the molecular orbital eigenvalues.""" def __init__( @@ -1628,8 +1628,8 @@ def __init__( self.write_cube = write_cube self.nhomo = nhomo self.nlumo = nlumo - keywords = keywords if keywords else {} - subsections = subsections if subsections else {} + keywords = keywords or {} + subsections = subsections or {} description = ( "Controls the printing of a cube file with eletrostatic potential generated by " "the total density (electrons+ions). It is valid only for QS with GPW formalism. " @@ -1651,12 +1651,17 @@ def __init__( ) -class E_Density_Cube(Section): +@deprecated(MOCubes, "Deprecated on 2024-03-29, to be removed on 2025-03-29.") +class MO_Cubes(MOCubes): + pass + + +class EDensityCube(Section): """Controls printing of the electron density cube file.""" def __init__(self, keywords: dict | None = None, subsections: dict | None = None, **kwargs): - keywords = keywords if keywords else {} - subsections = subsections if subsections else {} + keywords = keywords or {} + subsections = subsections or {} description = ( "Controls the printing of cube files with the electronic density and, for LSD " "calculations, the spin density." @@ -1671,6 +1676,11 @@ def __init__(self, keywords: dict | None = None, subsections: dict | None = None ) +@deprecated(EDensityCube, "Deprecated on 2024-03-29, to be removed on 2025-03-29.") +class E_Density_Cube(EDensityCube): + pass + + class Smear(Section): """Control electron smearing.""" @@ -1686,8 +1696,8 @@ def __init__( self.elec_temp = elec_temp self.method = method self.fixed_magnetic_moment = fixed_magnetic_moment - keywords = keywords if keywords else {} - subsections = subsections if subsections else {} + keywords = keywords or {} + subsections = subsections or {} description = "Activates smearing of electron occupations" _keywords = { @@ -1769,7 +1779,7 @@ def __init__( ) @classmethod - def from_el(cls, el, oxi_state=0, spin=0): + def from_el(cls, el: Element, oxi_state: int = 0, spin: int = 0) -> Self: """Create section from element, oxidation state, and spin.""" el = el if isinstance(el, Element) else Element(el) @@ -1793,7 +1803,7 @@ def f3(x): nel_beta = [] n_alpha = [] n_beta = [] - unpaired_orbital = None + unpaired_orbital: tuple[int, int, int] = (0, 0, 0) while tmp: tmp2 = -min((esv[0][2], tmp)) if tmp > 0 else min((f2(esv[0][1]) - esv[0][2], -tmp)) l_alpha.append(esv[0][1]) @@ -1806,6 +1816,9 @@ def f3(x): unpaired_orbital = esv[0][0], esv[0][1], esv[0][2] + tmp2 esv.pop(0) + if unpaired_orbital is None: + raise ValueError("unpaired_orbital cannot be None.") + if spin == "low-up": spin = unpaired_orbital[2] % 2 elif spin == "low-down": @@ -1834,7 +1847,7 @@ def f3(x): ) -class Xc_Functional(Section): +class XCFunctional(Section): """Defines the XC functional(s) to use.""" def __init__( @@ -1844,9 +1857,9 @@ def __init__( subsections: dict | None = None, **kwargs, ): - self.functionals = functionals if functionals else [] - keywords = keywords if keywords else {} - subsections = subsections if subsections else {} + self.functionals = functionals or [] + keywords = keywords or {} + subsections = subsections or {} location = "CP2K_INPUT/FORCE_EVAL/DFT/XC/XC_FUNCTIONAL" for functional in self.functionals: @@ -1862,6 +1875,11 @@ def __init__( ) +@deprecated(XCFunctional, "Deprecated on 2024-03-29, to be removed on 2025-03-29.") +class Xc_Functional(XCFunctional): + pass + + class PBE(Section): """Info about the PBE functional.""" @@ -1887,8 +1905,8 @@ def __init__( self.parameterization = parameterization self.scale_c = scale_c self.scale_x = scale_x - keywords = keywords if keywords else {} - subsections = subsections if subsections else {} + keywords = keywords or {} + subsections = subsections or {} location = "CP2K_INPUT/FORCE_EVAL/DFT/XC/XC_FUNCTIONAL/PBE" @@ -1954,7 +1972,7 @@ def __init__( keywords = {} self.kpts = kpts - self.weights = weights if weights else [1] * len(kpts) + self.weights = weights or [1] * len(kpts) assert len(self.kpts) == len(self.weights) self.eps_geo = eps_geo self.full_grid = full_grid @@ -1994,7 +2012,7 @@ def __init__( ) @classmethod - def from_kpoints(cls, kpoints: VaspKpoints, structure=None): + def from_kpoints(cls, kpoints: VaspKpoints, structure=None) -> Self: """ Initialize the section from a Kpoints object (pymatgen.io.vasp.inputs). CP2K does not have an automatic gamma-point constructor, so this is generally used @@ -2014,10 +2032,7 @@ def from_kpoints(cls, kpoints: VaspKpoints, structure=None): if kpoints.style == KpointsSupportedModes.Monkhorst: k = kpts[0] - if isinstance(k, (int, float)): - x, y, z = k, k, k - else: - x, y, z = k + x, y, z = (k, k, k) if isinstance(k, (int, float)) else k scheme = f"MONKHORST-PACK {x} {y} {z}" units = "B_VECTOR" elif kpoints.style == KpointsSupportedModes.Reciprocal: @@ -2089,7 +2104,7 @@ class Kpoint_Set(KpointSet): pass -class Band_Structure(Section): +class BandStructure(Section): """Specifies high symmetry paths for outputting the band structure in CP2K.""" def __init__( @@ -2111,7 +2126,7 @@ def __init__( self.kpoint_sets = SectionList(kpoint_sets) self.filename = filename self.added_mos = added_mos - keywords = keywords if keywords else {} + keywords = keywords or {} _keywords = { "FILE_NAME": Keyword("FILE_NAME", filename), "ADDED_MOS": Keyword("ADDED_MOS", added_mos), @@ -2127,8 +2142,8 @@ def __init__( # TODO kpoints objects are defined in the vasp module instead of a code agnostic module # if this changes in the future as other codes are added, then this will need to change - @staticmethod - def from_kpoints(kpoints: VaspKpoints, kpoints_line_density=20): + @classmethod + def from_kpoints(cls, kpoints: VaspKpoints, kpoints_line_density: int = 20) -> Self: """ Initialize band structure section from a line-mode Kpoint object. @@ -2165,7 +2180,12 @@ def pairwise(iterable): raise ValueError( "Unsupported k-point style. Must be line-mode or explicit k-points (reciprocal/cartesian)." ) - return Band_Structure(kpoint_sets=kpoint_sets, filename="BAND.bs") + return cls(kpoint_sets=kpoint_sets, filename="BAND.bs") + + +@deprecated(BandStructure, "Deprecated on 2024-03-29, to be removed on 2025-03-29.") +class Band_Structure(BandStructure): + pass @dataclass @@ -2217,7 +2237,7 @@ def softmatch(self, other): return all(not (v is not None and v != d2[k]) for k, v in d1.items()) @classmethod - def from_str(cls, string: str) -> BasisInfo: + def from_str(cls, string: str) -> Self: """Get summary info from a string.""" string = string.upper() data: dict[str, Any] = {} @@ -2251,17 +2271,17 @@ def from_str(cls, string: str) -> BasisInfo: data["polarization"] = string.count("P") data["diffuse"] = string.count("X") string = f"#{string}" - for i, s in enumerate(string): - if s == "Z": - z = int(tmp.get(string[i - 1], string[i - 1])) + for idx, char in enumerate(string): + if char == "Z": + z = int(tmp.get(string[idx - 1], string[idx - 1])) data["core"] = z if bool_core else None data["valence"] = z - elif s == "P" and string[i - 1].isnumeric(): - data["polarization"] = int(string[i - 1]) - elif s == "X" and string[i - 1].isnumeric(): - data["diffuse"] = int(string[i - 1]) - elif s == "Q" and string[i + 1].isnumeric(): - data["electrons"] = int("".join(_ for _ in string[i + 1 :] if _.isnumeric())) + elif char == "P" and string[idx - 1].isnumeric(): + data["polarization"] = int(string[idx - 1]) + elif char == "X" and string[idx - 1].isnumeric(): + data["diffuse"] = int(string[idx - 1]) + elif char == "Q" and string[idx + 1].isnumeric(): + data["electrons"] = int("".join(_ for _ in string[idx + 1 :] if _.isnumeric())) if not data["diffuse"]: data["diffuse"] = string.count("AUG") @@ -2385,7 +2405,7 @@ def get_keyword(self) -> Keyword: @property def nexp(self): """Number of exponents.""" - return [len(e) for e in self.exponents] + return [len(exp) for exp in self.exponents] @typing.no_type_check def get_str(self) -> str: @@ -2418,7 +2438,7 @@ def get_str(self) -> str: return out @classmethod - def from_str(cls, string: str) -> GaussianTypeOrbitalBasisSet: + def from_str(cls, string: str) -> Self: """Read from standard cp2k GTO formatted string.""" lines = [line for line in string.split("\n") if line] firstline = lines[0].split() @@ -2448,7 +2468,7 @@ def from_str(cls, string: str) -> GaussianTypeOrbitalBasisSet: line_index = 2 for set_index in range(nset): setinfo = lines[line_index].split() - _n, _lmin, _lmax, _nexp = map(int, setinfo[0:4]) + _n, _lmin, _lmax, _nexp = map(int, setinfo[:4]) n.append(_n) lmin.append(_lmin) lmax.append(_lmax) @@ -2518,17 +2538,17 @@ def softmatch(self, other): return all(not (v is not None and v != d2[k]) for k, v in d1.items()) @classmethod - def from_str(cls, string): + def from_str(cls, string: str) -> Self: """Get a cp2k formatted string representation.""" string = string.upper() - data = {} + data: dict[str, Any] = {} if "NLCC" in string: data["nlcc"] = True if "GTH" in string: data["potential_type"] = "GTH" - for i, s in enumerate(string): - if s == "Q" and string[i + 1].isnumeric(): - data["electrons"] = int("".join(_ for _ in string[i + 1 :] if _.isnumeric())) + for idx, char in enumerate(string, start=1): + if char == "Q" and string[idx].isnumeric(): + data["electrons"] = int("".join(_ for _ in string[idx:] if _.isnumeric())) for x in ("LDA", "PADA", "MGGA", "GGA", "HF", "PBE0", "PBE", "BP", "BLYP", "B3LYP", "SCAN"): if x in string: @@ -2612,7 +2632,7 @@ def get_section(self) -> Section: ) @classmethod - def from_section(cls, section: Section) -> GthPotential: + def from_section(cls, section: Section) -> Self: """Extract GTH-formatted string from a section and convert it to model.""" sec = copy.deepcopy(section) sec.verbosity(verbosity=False) @@ -2654,7 +2674,7 @@ def get_str(self) -> str: return out @classmethod - def from_str(cls, string): + def from_str(cls, string: str) -> Self: """Initialize model from a GTH formatted string.""" lines = [line for line in string.split("\n") if line] firstline = lines[0].split() @@ -2671,8 +2691,8 @@ def from_str(cls, string): info.electrons = Element(element).Z else: potential = "Pseudopotential" - nelecs = {i: int(n) for i, n in enumerate(lines[1].split())} - info.electrons = sum(nelecs.values()) # override, more reliable than name + n_elecs = {idx: int(n_elec) for idx, n_elec in enumerate(lines[1].split())} + info.electrons = sum(n_elecs.values()) # override, more reliable than name thirdline = lines[2].split() r_loc, nexp_ppl, c_exp_ppl = ( float(thirdline[0]), @@ -2681,9 +2701,9 @@ def from_str(cls, string): ) nprj = int(lines[3].split()[0]) if len(lines) > 3 else 0 - radii = {} - nprj_ppnl = {} - hprj_ppnl = {} + radii: dict[int, float] = {} + nprj_ppnl: dict[int, int] = {} + hprj_ppnl: dict[int, dict] = {} lines = lines[4:] i = 0 ll = 0 @@ -2694,8 +2714,8 @@ def from_str(cls, string): radii[ll] = float(line[0]) nprj_ppnl[ll] = int(line[1]) hprj_ppnl[ll] = {x: {} for x in range(nprj_ppnl[ll])} - line = list(map(float, line[2:])) - hprj_ppnl[ll][0] = {j: float(ln) for j, ln in enumerate(line)} + _line = [float(i) for i in line[2:]] + hprj_ppnl[ll][0] = {j: float(ln) for j, ln in enumerate(_line)} L = 1 i += 1 @@ -2711,7 +2731,7 @@ def from_str(cls, string): name=name, alias_names=aliases, potential=potential, - n_elecs=nelecs, + n_elecs=n_elecs, r_loc=r_loc, nexp_ppl=nexp_ppl, c_exp_ppl=c_exp_ppl, @@ -2730,22 +2750,23 @@ class DataFile(MSONable): objects: Sequence | None = None @classmethod - def from_file(cls, filename): - """Load from a file.""" - with open(filename) as file: - data = cls.from_str(file.read()) - for obj in data.objects: + def from_file(cls, filename) -> Self: + """Load from a file, reserved for child classes.""" + with open(filename, encoding="utf-8") as file: + data = cls.from_str(file.read()) # type: ignore[call-arg] + for obj in data.objects: # type: ignore[attr-defined] obj.filename = filename - return data + return data # type: ignore[return-value] @classmethod - def from_str(cls): + @abc.abstractmethod + def from_str(cls, string: str) -> None: """Initialize from a string.""" raise NotImplementedError def write_file(self, filename): """Write to a file.""" - with open(filename, mode="w") as file: + with open(filename, mode="w", encoding="utf-8") as file: file.write(self.get_str()) def get_str(self) -> str: @@ -2761,7 +2782,7 @@ class BasisFile(DataFile): """Data file for basis sets only.""" @classmethod - def from_str(cls, string): + def from_str(cls, string: str) -> Self: # type: ignore[override] """Initialize from a string representation.""" basis_sets = [GaussianTypeOrbitalBasisSet.from_str(c) for c in chunk(string)] return cls(objects=basis_sets) @@ -2772,7 +2793,7 @@ class PotentialFile(DataFile): """Data file for potentials only.""" @classmethod - def from_str(cls, string): + def from_str(cls, string: str) -> Self: # type: ignore[override] """Initialize from a string representation.""" basis_sets = [GthPotential.from_str(c) for c in chunk(string)] return cls(objects=basis_sets) diff --git a/pymatgen/io/cp2k/outputs.py b/pymatgen/io/cp2k/outputs.py index c3841f60944..d4309022a80 100644 --- a/pymatgen/io/cp2k/outputs.py +++ b/pymatgen/io/cp2k/outputs.py @@ -102,8 +102,7 @@ def cp2k_version(self): @property def completed(self): """Did the calculation complete.""" - c = self.data.get("completed", False) - if c: + if c := self.data.get("completed", False): return c[0][0] return c @@ -196,18 +195,12 @@ def is_molecule(self) -> bool: True if the cp2k output was generated for a molecule (i.e. no periodicity in the cell). """ - if self.data.get("poisson_periodicity", [[""]])[0][0].upper() == "NONE": - return True - return False + return self.data.get("poisson_periodicity", [[""]])[0][0].upper() == "NONE" @property def is_metal(self) -> bool: """Was a band gap found? i.e. is it a metal.""" - if self.band_gap is None: - return True - if self.band_gap <= 0: - return True - return False + return True if self.band_gap is None else self.band_gap <= 0 @property def is_hubbard(self) -> bool: @@ -262,9 +255,9 @@ def parse_files(self): self.filenames["wfn.bak"].append(w) else: self.filenames["wfn"] = w - for f in self.filenames.values(): - if hasattr(f, "sort"): - f.sort(key=natural_keys) + for filename in self.filenames.values(): + if hasattr(filename, "sort"): + filename.sort(key=natural_keys) def parse_structures(self, trajectory_file=None, lattice_file=None): """ @@ -276,7 +269,7 @@ def parse_structures(self, trajectory_file=None, lattice_file=None): default, so non static calculations have to reference the trajectory file. """ self.parse_initial_structure() - trajectory_file = trajectory_file if trajectory_file else self.filenames.get("trajectory") + trajectory_file = trajectory_file or self.filenames.get("trajectory") if isinstance(trajectory_file, list): if len(trajectory_file) == 1: trajectory_file = trajectory_file[0] @@ -300,8 +293,7 @@ def parse_structures(self, trajectory_file=None, lattice_file=None): lattices = [latt[2:].reshape(3, 3) for latt in latt_file] if not trajectory_file: - self.structures = [] - self.structures.append(self.initial_structure) + self.structures = [self.initial_structure] self.final_structure = self.structures[-1] else: mols = XYZ.from_file(trajectory_file).all_molecules @@ -609,8 +601,7 @@ def parse_dft_params(self): # Functional if self.input and self.input.check("FORCE_EVAL/DFT/XC/XC_FUNCTIONAL"): - xc_funcs = list(self.input["force_eval"]["dft"]["xc"]["xc_functional"].subsections) - if xc_funcs: + if xc_funcs := list(self.input["force_eval"]["dft"]["xc"]["xc_functional"].subsections): self.data["dft"]["functional"] = xc_funcs else: for v in self.input["force_eval"]["dft"]["xc"].subsections.values(): @@ -1251,10 +1242,7 @@ def parse_dos(self, dos_file=None, pdos_files=None, ldos_files=None): self.data["pdos"] = jsanitize(pdoss, strict=True) self.data["ldos"] = jsanitize(ldoss, strict=True) - if dos_file: - self.data["tdos"] = parse_dos(dos_file) - else: - self.data["tdos"] = tdos + self.data["tdos"] = parse_dos(dos_file) if dos_file else tdos if self.data.get("tdos"): self.band_gap = self.data["tdos"].get_gap() @@ -1295,7 +1283,7 @@ def parse_bandstructure(self, bandstructure_filename=None) -> None: else: return - with open(bandstructure_filename) as file: + with open(bandstructure_filename, encoding="utf-8") as file: lines = file.read().split("\n") data = np.loadtxt(bandstructure_filename) @@ -1317,7 +1305,7 @@ def parse_bandstructure(self, bandstructure_filename=None) -> None: nkpts += int(lines[0].split()[6]) elif line.split()[1] == "Point": kpts.append(list(map(float, line.split()[-4:-1]))) - elif line.split()[1] == "Special" in line: + elif line.split()[1] == "Special": splt = line.split() label = splt[7] if label.upper() == "GAMMA": @@ -1684,12 +1672,15 @@ def parse_dos(dos_file=None): data = np.loadtxt(dos_file) data[:, 0] *= Ha_to_eV energies = data[:, 0] - for i, o in enumerate(data[:, 1]): - if o == 0: + vbm_top = None + for idx, val in enumerate(data[:, 1]): + if val == 0: break - vbmtop = i - efermi = energies[vbmtop] + 1e-6 + vbm_top = idx + + efermi = energies[vbm_top] + 1e-6 densities = {Spin.up: data[:, 1]} + if data.shape[1] > 3: densities[Spin.down] = data[:, 3] return Dos(efermi=efermi, energies=energies, densities=densities) @@ -1727,38 +1718,38 @@ def parse_pdos(dos_file=None, spin_channel=None, total=False): header = re.split(r"\s{2,}", lines[1].replace("#", "").strip())[2:] dat = np.loadtxt(dos_file) - def cp2k_to_pmg_labels(x): - if x == "p": + def cp2k_to_pmg_labels(label: str) -> str: + if label == "p": return "px" - if x == "d": + if label == "d": return "dxy" - if x == "f": + if label == "f": return "f_3" - if x == "d-2": + if label == "d-2": return "dxy" - if x == "d-1": + if label == "d-1": return "dyz" - if x == "d0": + if label == "d0": return "dz2" - if x == "d+1": + if label == "d+1": return "dxz" - if x == "d+2": + if label == "d+2": return "dx2" - if x == "f-3": + if label == "f-3": return "f_3" - if x == "f-2": + if label == "f-2": return "f_2" - if x == "f-1": + if label == "f-1": return "f_1" - if x == "f0": + if label == "f0": return "f0" - if x == "f+1": + if label == "f+1": return "f1" - if x == "f+2": + if label == "f+2": return "f2" - if x == "f+3": + if label == "f+3": return "f3" - return x + return label header = [cp2k_to_pmg_labels(h) for h in header] @@ -1767,15 +1758,16 @@ def cp2k_to_pmg_labels(x): data = np.delete(data, 1, 1) data[:, 0] *= Ha_to_eV energies = data[:, 0] - for i, o in enumerate(occupations): - if o == 0: + vbm_top = None + for idx, occu in enumerate(occupations): + if occu == 0: break - vbmtop = i + vbm_top = idx # set Fermi level to be vbm plus tolerance for # PMG compatibility # *not* middle of the gap, which pdos might report - efermi = energies[vbmtop] + 1e-6 + efermi = energies[vbm_top] + 1e-6 # for pymatgen's dos class. VASP creates an evenly spaced grid of energy states, which # leads to 0 density states in the band gap. CP2K does not do this. PMG's Dos class was @@ -1783,10 +1775,10 @@ def cp2k_to_pmg_labels(x): # in between VBM and CBM, so here we introduce trivial ones energies = np.insert( energies, - vbmtop + 1, - np.linspace(energies[vbmtop] + 1e-6, energies[vbmtop + 1] - 1e-6, 2), + vbm_top + 1, + np.linspace(energies[vbm_top] + 1e-6, energies[vbm_top + 1] - 1e-6, 2), ) - data = np.insert(data, vbmtop + 1, np.zeros((2, data.shape[1])), axis=0) + data = np.insert(data, vbm_top + 1, np.zeros((2, data.shape[1])), axis=0) pdos = { kind: { diff --git a/pymatgen/io/cp2k/sets.py b/pymatgen/io/cp2k/sets.py index 0dee72c866b..60c1c4d21ac 100644 --- a/pymatgen/io/cp2k/sets.py +++ b/pymatgen/io/cp2k/sets.py @@ -35,7 +35,7 @@ PBE, PDOS, QS, - Band_Structure, + BandStructure, BasisFile, BasisInfo, BrokenSymmetry, @@ -43,7 +43,7 @@ Coord, Cp2kInput, Dft, - E_Density_Cube, + EDensityCube, ForceEval, GaussianTypeOrbitalBasisSet, Global, @@ -52,7 +52,7 @@ Kind, Kpoints, Mgrid, - MO_Cubes, + MOCubes, OrbitalTransformation, PotentialFile, PotentialInfo, @@ -61,8 +61,8 @@ SectionList, Smear, Subsys, - V_Hartree_Cube, - Xc_Functional, + VHartreeCube, + XCFunctional, ) from pymatgen.io.cp2k.utils import get_truncated_coulomb_cutoff, get_unique_site_indices from pymatgen.io.vasp.inputs import Kpoints as VaspKpoints @@ -177,7 +177,7 @@ def __init__( super().__init__(name="CP2K_INPUT", subsections={}) self.structure = structure - self.basis_and_potential = basis_and_potential if basis_and_potential else {} + self.basis_and_potential = basis_and_potential or {} self.project_name = project_name self.charge = int(structure.charge) if not multiplicity and isinstance(self.structure, Molecule): @@ -200,7 +200,7 @@ def __init__( self.rel_cutoff = rel_cutoff self.ngrids = ngrids self.progression_factor = progression_factor - self.override_default_params = override_default_params if override_default_params else {} + self.override_default_params = override_default_params or {} self.wfn_restart_file_name = wfn_restart_file_name self.kpoints = kpoints self.smearing = smearing @@ -234,7 +234,7 @@ def __init__( # Build the QS Section qs = QS(method=self.qs_method, eps_default=eps_default, eps_pgf_orb=kwargs.get("eps_pgf_orb")) - max_scf = max_scf if max_scf else 20 if ot else 400 # If ot, max_scf is for inner loop + max_scf = max_scf or 20 if ot else 400 # If ot, max_scf is for inner loop scf = Scf(eps_scf=eps_scf, max_scf=max_scf, subsections={}) if ot: @@ -304,7 +304,7 @@ def __init__( MULTIPLICITY=self.multiplicity, CHARGE=self.charge, uks=self.kwargs.get("spin_polarized", True), - basis_set_filenames=self.basis_set_file_names if self.basis_set_file_names else [], + basis_set_filenames=self.basis_set_file_names or [], potential_filename=self.potential_file_name, subsections={"QS": qs, "SCF": scf, "MGRID": mgrid}, wfn_restart_file_name=wfn_restart_file_name, @@ -317,7 +317,7 @@ def __init__( # Create subsections and insert into them self["FORCE_EVAL"].insert(dft) - xc_functional = Xc_Functional(functionals=self.xc_functionals) + xc_functional = XCFunctional(functionals=self.xc_functionals) xc = Section("XC", subsections={"XC_FUNCTIONAL": xc_functional}) self["FORCE_EVAL"]["DFT"].insert(xc) self["FORCE_EVAL"]["DFT"].insert(Section("PRINT", subsections={})) @@ -399,13 +399,13 @@ def get_basis_and_potential(structure, basis_and_potential): for el in structure.symbol_set: possible_basis_sets = [] possible_potentials = [] - basis, aux_basis, potential = None, None, None + basis, aux_basis, potential, DATA = None, None, None, None desired_basis, desired_aux_basis, desired_potential = None, None, None have_element_file = os.path.isfile(os.path.join(SETTINGS.get("PMG_CP2K_DATA_DIR", "."), el)) # Necessary if matching data to cp2k data files if have_element_file: - with open(os.path.join(SETTINGS.get("PMG_CP2K_DATA_DIR", "."), el)) as file: + with open(os.path.join(SETTINGS.get("PMG_CP2K_DATA_DIR", "."), el), encoding="utf-8") as file: yaml = YAML(typ="unsafe", pure=True) DATA = yaml.load(file) if not DATA.get("basis_sets"): @@ -567,7 +567,7 @@ def get_xc_functionals(xc_functionals: list | str | None = None) -> list: Get XC functionals. If simplified names are provided in kwargs, they will be expanded into their corresponding X and C names. """ - names = xc_functionals if xc_functionals else SETTINGS.get("PMG_DEFAULT_CP2K_FUNCTIONAL") + names = xc_functionals or SETTINGS.get("PMG_DEFAULT_CP2K_FUNCTIONAL") if not names: raise ValueError( "No XC functional provided. Specify kwarg xc_functional or configure PMG_DEFAULT_FUNCTIONAL " @@ -653,7 +653,7 @@ def print_mo_cubes(self, write_cube: bool = False, nlumo: int = -1, nhomo: int = nhomo (int): Controls the number of homos printed and dumped as a cube (-1=all) """ if not self.check("FORCE_EVAL/DFT/PRINT/MO_CUBES"): - self["FORCE_EVAL"]["DFT"]["PRINT"].insert(MO_Cubes(write_cube=write_cube, nlumo=nlumo, nhomo=nhomo)) + self["FORCE_EVAL"]["DFT"]["PRINT"].insert(MOCubes(write_cube=write_cube, nlumo=nlumo, nhomo=nhomo)) def print_mo(self) -> None: """Print molecular orbitals when running non-OT diagonalization.""" @@ -666,12 +666,12 @@ def print_v_hartree(self, stride=(2, 2, 2)) -> None: Note that by convention the potential has opposite sign than the expected physical one. """ if not self.check("FORCE_EVAL/DFT/PRINT/V_HARTREE_CUBE"): - self["FORCE_EVAL"]["DFT"]["PRINT"].insert(V_Hartree_Cube(keywords={"STRIDE": Keyword("STRIDE", *stride)})) + self["FORCE_EVAL"]["DFT"]["PRINT"].insert(VHartreeCube(keywords={"STRIDE": Keyword("STRIDE", *stride)})) def print_e_density(self, stride=(2, 2, 2)) -> None: """Controls the printing of cube files with electronic density and, for UKS, the spin density.""" if not self.check("FORCE_EVAL/DFT/PRINT/E_DENSITY_CUBE"): - self["FORCE_EVAL"]["DFT"]["PRINT"].insert(E_Density_Cube(keywords={"STRIDE": Keyword("STRIDE", *stride)})) + self["FORCE_EVAL"]["DFT"]["PRINT"].insert(EDensityCube(keywords={"STRIDE": Keyword("STRIDE", *stride)})) def print_bandstructure(self, kpoints_line_density: int = 20) -> None: """ @@ -685,7 +685,7 @@ def print_bandstructure(self, kpoints_line_density: int = 20) -> None: if not self.kpoints: raise ValueError("Kpoints must be provided to enable band structure printing") - bs = Band_Structure.from_kpoints( + bs = BandStructure.from_kpoints( self.kpoints, kpoints_line_density=kpoints_line_density, ) @@ -834,9 +834,9 @@ def activate_hybrid( ip_keywords = {} if hybrid_functional == "HSE06": pbe = PBE("ORIG", scale_c=1, scale_x=0) - xc_functional = Xc_Functional(functionals=[], subsections={"PBE": pbe}) + xc_functional = XCFunctional(functionals=[], subsections={"PBE": pbe}) - potential_type = potential_type if potential_type else "SHORTRANGE" + potential_type = potential_type or "SHORTRANGE" xc_functional.insert( Section( "XWPBE", @@ -857,7 +857,7 @@ def activate_hybrid( ) elif hybrid_functional == "PBE0": pbe = PBE("ORIG", scale_c=1, scale_x=1 - hf_fraction) - xc_functional = Xc_Functional(functionals=[], subsections={"PBE": pbe}) + xc_functional = XCFunctional(functionals=[], subsections={"PBE": pbe}) if isinstance(self.structure, Molecule): potential_type = "COULOMB" @@ -877,16 +877,15 @@ def activate_hybrid( ip_keywords["T_C_G_DATA"] = Keyword("T_C_G_DATA", "t_c_g.dat") ip_keywords["POTENTIAL_TYPE"] = Keyword("POTENTIAL_TYPE", potential_type) + elif hybrid_functional == "RSH": - """ - Activates range separated functional using mixing of the truncated - coulomb operator and the long range operator using scale_longrange, - scale_coulomb, cutoff_radius, and omega. - """ + # Activates range separated functional using mixing of the truncated + # coulomb operator and the long range operator using scale_longrange, + # scale_coulomb, cutoff_radius, and omega. pbe = PBE("ORIG", scale_c=1, scale_x=0) - xc_functional = Xc_Functional(functionals=[], subsections={"PBE": pbe}) + xc_functional = XCFunctional(functionals=[], subsections={"PBE": pbe}) - potential_type = potential_type if potential_type else "MIX_CL_TRUNC" + potential_type = potential_type or "MIX_CL_TRUNC" hf_fraction = 1 ip_keywords.update( { @@ -925,7 +924,7 @@ def activate_hybrid( "settings manually. Proceed with caution." ) pbe = PBE("ORIG", scale_c=gga_c_fraction, scale_x=gga_x_fraction) - xc_functional = Xc_Functional(functionals=[], subsections={"PBE": pbe}) + xc_functional = XCFunctional(functionals=[], subsections={"PBE": pbe}) ip_keywords.update( { diff --git a/pymatgen/io/cp2k/utils.py b/pymatgen/io/cp2k/utils.py index 2833d34829c..0c1ad20e157 100644 --- a/pymatgen/io/cp2k/utils.py +++ b/pymatgen/io/cp2k/utils.py @@ -29,11 +29,11 @@ def postprocessor(data: str) -> str | float | bool | None: """ data = data.strip().replace(" ", "_") # remove leading/trailing whitespace, replace spaces with _ - if data.lower() in ("false", "no", "f"): + if data.lower() in {"false", "no", "f"}: return False if data.lower() == "none": return None - if data.lower() in ("true", "yes", "t"): + if data.lower() in {"true", "yes", "t"}: return True if re.match(r"^-?\d+$", data): try: diff --git a/pymatgen/io/cssr.py b/pymatgen/io/cssr.py index 12d920fad65..b507cda1aed 100644 --- a/pymatgen/io/cssr.py +++ b/pymatgen/io/cssr.py @@ -3,12 +3,18 @@ from __future__ import annotations import re +from typing import TYPE_CHECKING from monty.io import zopen from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Structure +if TYPE_CHECKING: + from pathlib import Path + + from typing_extensions import Self + __author__ = "Shyue Ping Ong" __copyright__ = "Copyright 2012, The Materials Project" __version__ = "0.1" @@ -41,8 +47,8 @@ def __str__(self): f"{len(self.structure)} 0", f"0 {self.structure.formula}", ] - for idx, site in enumerate(self.structure): - output.append(f"{idx + 1} {site.specie} {site.a:.4f} {site.b:.4f} {site.c:.4f}") + for idx, site in enumerate(self.structure, start=1): + output.append(f"{idx} {site.specie} {site.a:.4f} {site.b:.4f} {site.c:.4f}") return "\n".join(output) def write_file(self, filename): @@ -56,7 +62,7 @@ def write_file(self, filename): file.write(str(self) + "\n") @classmethod - def from_str(cls, string): + def from_str(cls, string: str) -> Self: """ Reads a string representation to a Cssr object. @@ -71,18 +77,17 @@ def from_str(cls, string): lengths = [float(tok) for tok in tokens] tokens = lines[1].split() angles = [float(tok) for tok in tokens[0:3]] - latt = Lattice.from_parameters(*lengths, *angles) - sp = [] - coords = [] + lattice = Lattice.from_parameters(*lengths, *angles) + sp, coords = [], [] for line in lines[4:]: - m = re.match(r"\d+\s+(\w+)\s+([0-9\-\.]+)\s+([0-9\-\.]+)\s+([0-9\-\.]+)", line.strip()) - if m: - sp.append(m.group(1)) - coords.append([float(m.group(i)) for i in range(2, 5)]) - return cls(Structure(latt, sp, coords)) + match = re.match(r"\d+\s+(\w+)\s+([0-9\-\.]+)\s+([0-9\-\.]+)\s+([0-9\-\.]+)", line.strip()) + if match: + sp.append(match.group(1)) + coords.append([float(match.group(i)) for i in range(2, 5)]) + return cls(Structure(lattice, sp, coords)) @classmethod - def from_file(cls, filename): + def from_file(cls, filename: str | Path) -> Self: """ Reads a CSSR file to a Cssr object. diff --git a/pymatgen/io/exciting/inputs.py b/pymatgen/io/exciting/inputs.py index 794ec56f255..706a012d8c9 100644 --- a/pymatgen/io/exciting/inputs.py +++ b/pymatgen/io/exciting/inputs.py @@ -2,7 +2,9 @@ from __future__ import annotations +import itertools import xml.etree.ElementTree as ET +from typing import TYPE_CHECKING import numpy as np import scipy.constants as const @@ -13,6 +15,11 @@ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from pymatgen.symmetry.bandstructure import HighSymmKpath +if TYPE_CHECKING: + from pathlib import Path + + from typing_extensions import Self + __author__ = "Christian Vorwerk" __copyright__ = "Copyright 2016" __version__ = "1.0" @@ -64,19 +71,27 @@ def lockxyz(self, lockxyz): self.structure.add_site_property("selective_dynamics", lockxyz) @classmethod - def from_str(cls, data): + def from_str(cls, data: str) -> Self: """Reads the exciting input from a string.""" - root = ET.fromstring(data) - species_node = root.find("structure").iter("species") + root: ET.Element = ET.fromstring(data) + struct = root.find("structure") + if struct is None: + raise ValueError("No structure found in input file!") + + species_node = struct.iter("species") elements = [] positions = [] vectors = [] lockxyz = [] # get title - title_in = str(root.find("title").text) + _title = root.find("title") + assert _title is not None, "title cannot be None." + title_in = str(_title.text) # Read elements and coordinates for nodes in species_node: - symbol = nodes.get("speciesfile").split(".")[0] + _speciesfile = nodes.get("speciesfile") + assert _speciesfile is not None, "speciesfile cannot be None." + symbol = _speciesfile.split(".")[0] if len(symbol.split("_")) == 2: symbol = symbol.split("_")[0] if Element.is_valid_symbol(symbol): @@ -84,40 +99,50 @@ def from_str(cls, data): element = symbol else: raise ValueError("Unknown element!") + for atom in nodes.iter("atom"): - x, y, z = atom.get("coord").split() + _coord = atom.get("coord") + assert _coord is not None, "coordinate cannot be None." + x, y, z = _coord.split() positions.append([float(x), float(y), float(z)]) elements.append(element) # Obtain lockxyz for each atom - if atom.get("lockxyz") is not None: + if atom.get("lockxyz") is None: + lockxyz.append([False, False, False]) + else: lxyz = [] - for line in atom.get("lockxyz").split(): + + _lockxyz = atom.get("lockxyz") + assert _lockxyz is not None, "lockxyz cannot be None." + for line in _lockxyz.split(): if line in ("True", "true"): lxyz.append(True) else: lxyz.append(False) lockxyz.append(lxyz) - else: - lockxyz.append([False, False, False]) + # check the atomic positions type - if "cartesian" in root.find("structure").attrib: - if root.find("structure").attrib["cartesian"]: - cartesian = True - for p in positions: - for j in range(3): - p[j] = p[j] * ExcitingInput.bohr2ang - print(positions) - else: - cartesian = False + cartesian = False + if struct.attrib.get("cartesian"): + cartesian = True + for p, j in itertools.product(positions, range(3)): + p[j] = p[j] * ExcitingInput.bohr2ang + + _crystal = struct.find("crystal") + assert _crystal is not None, "crystal cannot be None." + # get the scale attribute - scale_in = root.find("structure").find("crystal").get("scale") + scale_in = _crystal.get("scale") scale = float(scale_in) * ExcitingInput.bohr2ang if scale_in else ExcitingInput.bohr2ang + # get the stretch attribute - stretch_in = root.find("structure").find("crystal").get("stretch") + stretch_in = _crystal.get("stretch") stretch = np.array([float(a) for a in stretch_in]) if stretch_in else np.array([1.0, 1.0, 1.0]) + # get basis vectors and scale them accordingly - basisnode = root.find("structure").find("crystal").iter("basevect") + basisnode = _crystal.iter("basevect") for vect in basisnode: + assert vect.text is not None, "vectors cannot be None." x, y, z = vect.text.split() vectors.append( [ @@ -133,9 +158,10 @@ def from_str(cls, data): return cls(structure_in, title_in, lockxyz) @classmethod - def from_file(cls, filename): + def from_file(cls, filename: str | Path) -> Self: """ - :param filename: Filename + Args: + filename: Filename Returns: ExcitingInput @@ -205,7 +231,7 @@ def write_etree(self, celltype, cartesian=False, bandstr=False, symprec: float = # write atomic positions for each species index = 0 for elem in sorted(new_struct.types_of_species, key=lambda el: el.X): - species = ET.SubElement(structure, "species", speciesfile=elem.symbol + ".xml") + species = ET.SubElement(structure, "species", speciesfile=f"{elem.symbol}.xml") sites = new_struct.indices_from_symbol(elem.symbol) for j in sites: @@ -226,8 +252,12 @@ def write_etree(self, celltype, cartesian=False, bandstr=False, symprec: float = # write atomic positions index = index + 1 _ = ET.SubElement(species, "atom", coord=coord) + # write bandstructure if needed - if bandstr and celltype == "primitive": + if bandstr: + if celltype != "primitive": + raise ValueError("Bandstructure is only implemented for the standard primitive unit cell!") + kpath = HighSymmKpath(new_struct, symprec=symprec, angle_tolerance=angle_tolerance) prop = ET.SubElement(root, "properties") band_struct = ET.SubElement(prop, "bandstructure") @@ -241,8 +271,6 @@ def write_etree(self, celltype, cartesian=False, bandstr=False, symprec: float = symbol_map = {"\\Gamma": "GAMMA", "\\Sigma": "SIGMA", "\\Delta": "DELTA", "\\Lambda": "LAMBDA"} symbol = symbol_map.get(symbol, symbol) _ = ET.SubElement(path, "point", coord=coord, label=symbol) - elif bandstr and celltype != "primitive": - raise ValueError("Bandstructure is only implemented for the standard primitive unit cell!") # write extra parameters from kwargs if provided self._dicttoxml(kwargs, root) @@ -325,13 +353,14 @@ def _indent(elem, level=0): """ Helper method to indent elements. - :param elem: - :param level: + Args: + elem: + level: """ i = "\n" + level * " " if len(elem): if not elem.text or not elem.text.strip(): - elem.text = i + " " + elem.text = f"{i} " if not elem.tail or not elem.tail.strip(): elem.tail = i for el in elem: diff --git a/pymatgen/io/feff/inputs.py b/pymatgen/io/feff/inputs.py index 293c096cede..4d62241142e 100644 --- a/pymatgen/io/feff/inputs.py +++ b/pymatgen/io/feff/inputs.py @@ -10,6 +10,7 @@ import re import warnings +from typing import TYPE_CHECKING import numpy as np from monty.io import zopen @@ -23,6 +24,9 @@ from pymatgen.util.io_utils import clean_lines from pymatgen.util.string import str_delimited +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "Alan Dozier, Kiran Mathew" __credits__ = "Anubhav Jain, Shyue Ping Ong" __copyright__ = "Copyright 2011, The Materials Project" @@ -143,7 +147,7 @@ class Header(MSONable): """ Creates Header for the FEFF input file. - Has the following format:: + Has the following format: * This feff.inp file generated by pymatgen, materialsproject.org TITLE comment: @@ -193,10 +197,10 @@ def __init__( raise ValueError("'struct' argument must be a Structure or Molecule!") self.comment = comment or "None given" - @staticmethod - def from_cif_file(cif_file, source="", comment=""): + @classmethod + def from_cif_file(cls, cif_file: str, source: str = "", comment: str = "") -> Self: """ - Static method to create Header object from cif_file. + Create Header object from cif_file. Args: cif_file: cif_file path and name @@ -209,7 +213,7 @@ def from_cif_file(cif_file, source="", comment=""): """ parser = CifParser(cif_file) structure = parser.parse_structures(primitive=True)[0] - return Header(structure, source, comment) + return cls(structure, source, comment) @property def structure_symmetry(self): @@ -227,13 +231,13 @@ def formula(self): return self.struct.composition.formula @classmethod - def from_file(cls, filename): + def from_file(cls, filename: str) -> Self: """Returns Header object from file.""" hs = cls.header_string_from_file(filename) return cls.from_str(hs) @staticmethod - def header_string_from_file(filename="feff.inp"): + def header_string_from_file(filename: str = "feff.inp"): """ Reads Header string from either a HEADER file or feff.inp file Will also read a header from a non-pymatgen generated feff.inp file. @@ -269,7 +273,7 @@ def header_string_from_file(filename="feff.inp"): # source end = 0 for line in f: - if (line[0] == "*" or line[0] == "T") and end == 0: + if line[0] in {"*", "T"} and end == 0: feff_header_str.append(line.replace("\r", "")) else: end = 1 @@ -277,7 +281,7 @@ def header_string_from_file(filename="feff.inp"): return "".join(feff_header_str) @classmethod - def from_str(cls, header_str): + def from_str(cls, header_str: str) -> Self: """ Reads Header string and returns Header object if header was generated by pymatgen. @@ -305,15 +309,13 @@ def from_str(cls, header_str): a = float(basis_vec[0]) b = float(basis_vec[1]) c = float(basis_vec[2]) - lengths = [a, b, c] # alpha, beta, gamma basis_ang = lines[7].split(":")[-1].split() alpha = float(basis_ang[0]) beta = float(basis_ang[1]) gamma = float(basis_ang[2]) - angles = [alpha, beta, gamma] - lattice = Lattice.from_parameters(*lengths, *angles) + lattice = Lattice.from_parameters(a, b, c, alpha, beta, gamma) n_atoms = int(lines[8].split(":")[-1].split()[0]) @@ -353,12 +355,14 @@ def __str__(self): output.append(f"TITLE sites: {len(self.struct)}") - for idx, site in enumerate(self.struct): + for idx, site in enumerate(self.struct, start=1): if isinstance(self.struct, Structure): coords = [f"{j:0.6f}".rjust(12) for j in site.frac_coords] elif isinstance(self.struct, Molecule): coords = [f"{j:0.6f}".rjust(12) for j in site.coords] - output.append(f"* {idx + 1} {site.species_string} {' '.join(coords)}") + else: + raise TypeError("Unsupported type. Expect Structure or Molecule.") + output.append(f"* {idx} {site.species_string} {' '.join(coords)}") return "\n".join(output) def write_file(self, filename="HEADER"): @@ -368,7 +372,7 @@ def write_file(self, filename="HEADER"): Args: filename: Filename and path for file to be written to disk """ - with open(filename, mode="w") as file: + with open(filename, mode="w", encoding="utf-8") as file: file.write(str(self) + "\n") @@ -491,11 +495,11 @@ def get_lines(self) -> list[list[str | int]]: 0, ] ] - for idx, site in enumerate(self._cluster[1:]): + for idx, site in enumerate(self._cluster[1:], start=1): site_symbol = site.specie.symbol ipot = self.pot_dict[site_symbol] - dist = self._cluster.get_distance(0, idx + 1) - lines += [[f"{site.x}", f"{site.y}", f"{site.z}", ipot, site_symbol, f"{dist}", idx + 1]] + dist = self._cluster.get_distance(0, idx) + lines += [[f"{site.x}", f"{site.y}", f"{site.z}", ipot, site_symbol, f"{dist}", idx]] # sort by distance from absorbing atom return sorted(lines, key=lambda line: float(line[5])) @@ -539,12 +543,12 @@ def __setitem__(self, key, val): Feff tags. Also cleans the parameter and val by stripping leading and trailing white spaces. - Arg: + Args: key: dict key value value: value associated with key in dictionary """ if key.strip().upper() not in VALID_FEFF_TAGS: - warnings.warn(key.strip() + " not in VALID_FEFF_TAGS list") + warnings.warn(f"{key.strip()} not in VALID_FEFF_TAGS list") super().__setitem__( key.strip(), Tags.proc_val(key.strip(), val.strip()) if isinstance(val, str) else val, @@ -562,22 +566,18 @@ def as_dict(self): tags_dict["@class"] = type(self).__name__ return tags_dict - @staticmethod - def from_dict(d): + @classmethod + def from_dict(cls, dct) -> Self: """ Creates Tags object from a dictionary. Args: - d: Dict of feff parameters and values. + dct (dict): Dict of feff parameters and values. Returns: Tags object """ - i = Tags() - for k, v in d.items(): - if k not in ("@module", "@class"): - i[k] = v - return i + return cls({k: v for k, v in dct.items() if k not in ("@module", "@class")}) def get_str(self, sort_keys: bool = False, pretty: bool = False) -> str: """ @@ -644,15 +644,15 @@ def write_file(self, filename="PARAMETERS"): file.write(f"{self}\n") @classmethod - def from_file(cls, filename="feff.inp"): + def from_file(cls, filename: str = "feff.inp") -> Self: """ - Creates a Feff_tag dictionary from a PARAMETER or feff.inp file. + Creates a Tags dictionary from a PARAMETER or feff.inp file. Args: filename: Filename for either PARAMETER or feff.inp file Returns: - Feff_tag object + Tags """ with zopen(filename, mode="rt") as file: lines = list(clean_lines(file.readlines())) @@ -661,10 +661,9 @@ def from_file(cls, filename="feff.inp"): ieels = -1 ieels_max = -1 for idx, line in enumerate(lines): - m = re.match(r"([A-Z]+\d*\d*)\s*(.*)", line) - if m: - key = m.group(1).strip() - val = m.group(2).strip() + if match := re.match(r"([A-Z]+\d*\d*)\s*(.*)", line): + key = match[1].strip() + val = match[2].strip() val = Tags.proc_val(key, val) if key not in ("ATOMS", "POTENTIALS", "END", "TITLE"): if key in ["ELNES", "EXELFS"]: @@ -711,32 +710,29 @@ def proc_val(key, val): boolean_type_keys = () float_type_keys = ("S02", "EXAFS", "RPATH") - def smart_int_or_float(numstr): - if numstr.find(".") != -1 or numstr.lower().find("e") != -1: - return float(numstr) - return int(numstr) + def smart_int_or_float(num_str): + if num_str.find(".") != -1 or num_str.lower().find("e") != -1: + return float(num_str) + return int(num_str) try: if key.lower() == "cif": - m = re.search(r"\w+.cif", val) - return m.group(0) + return re.search(r"\w+.cif", val)[0] if key in list_type_keys: output = [] tokens = re.split(r"\s+", val) for tok in tokens: - m = re.match(r"(\d+)\*([\d\.\-\+]+)", tok) - if m: - output.extend([smart_int_or_float(m.group(2))] * int(m.group(1))) + if match := re.match(r"(\d+)\*([\d\.\-\+]+)", tok): + output.extend([smart_int_or_float(match[2])] * int(match[1])) else: output.append(smart_int_or_float(tok)) return output if key in boolean_type_keys: - m = re.search(r"^\W+([TtFf])", val) - if m: - return m.group(1) in ["T", "t"] - raise ValueError(key + " should be a boolean type!") + if match := re.search(r"^\W+([TtFf])", val): + return match[1] in {"T", "t"} + raise ValueError(f"{key} should be a boolean type!") if key in float_type_keys: return float(val) @@ -796,13 +792,13 @@ def __init__(self, struct, absorbing_atom): struct (Structure): Structure object. absorbing_atom (str/int): Absorbing atom symbol or site index. """ - if struct.is_ordered: - self.struct = struct - atom_sym = get_absorbing_atom_symbol_index(absorbing_atom, struct)[0] - self.pot_dict = get_atom_map(struct, atom_sym) - else: + if not struct.is_ordered: raise ValueError("Structure with partial occupancies cannot be converted into atomic coordinates!") + self.struct = struct + atom_sym = get_absorbing_atom_symbol_index(absorbing_atom, struct)[0] + self.pot_dict = get_atom_map(struct, atom_sym) + self.absorbing_atom, _ = get_absorbing_atom_symbol_index(absorbing_atom, struct) @staticmethod @@ -857,7 +853,7 @@ def pot_dict_from_str(pot_data): Creates atomic symbol/potential number dictionary forward and reverse. - Arg: + Args: pot_data: potential data in string format Returns: @@ -1004,8 +1000,8 @@ def get_atom_map(structure, absorbing_atom=None): unique_pot_atoms.remove(absorbing_atom) atom_map = {} - for i, atom in enumerate(unique_pot_atoms): - atom_map[atom] = i + 1 + for i, atom in enumerate(unique_pot_atoms, start=1): + atom_map[atom] = i return atom_map diff --git a/pymatgen/io/feff/outputs.py b/pymatgen/io/feff/outputs.py index 7895f6c51fd..a5f2763d35d 100644 --- a/pymatgen/io/feff/outputs.py +++ b/pymatgen/io/feff/outputs.py @@ -8,6 +8,7 @@ import re from collections import defaultdict +from typing import TYPE_CHECKING import numpy as np from monty.io import zopen @@ -18,6 +19,9 @@ from pymatgen.electronic_structure.dos import CompleteDos, Dos from pymatgen.io.feff import Header, Potential, Tags +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "Alan Dozier, Kiran Mathew, Chen Zheng" __credits__ = "Anubhav Jain, Shyue Ping Ong" __copyright__ = "Copyright 2011, The Materials Project" @@ -42,7 +46,7 @@ def __init__(self, complete_dos, charge_transfer): self.charge_transfer = charge_transfer @classmethod - def from_file(cls, feff_inp_file="feff.inp", ldos_file="ldos"): + def from_file(cls, feff_inp_file: str = "feff.inp", ldos_file: str = "ldos") -> Self: """ Creates LDos object from raw Feff ldos files by by assuming they are numbered consecutively, i.e. ldos01.dat @@ -107,7 +111,7 @@ def from_file(cls, feff_inp_file="feff.inp", ldos_file="ldos"): for idx in range(len(ldos[1])): dos_energies.append(ldos[1][idx][0]) - all_pdos = [] + all_pdos: list[dict] = [] vorb = {"s": Orbital.s, "p": Orbital.py, "d": Orbital.dxy, "f": Orbital.f0} forb = {"s": 0, "p": 1, "d": 2, "f": 3} @@ -134,13 +138,13 @@ def from_file(cls, feff_inp_file="feff.inp", ldos_file="ldos"): t_dos = [0] * d_length for idx in range(n_sites): pot_index = pot_dict[structure.species[idx].symbol] - for v in forb.values(): - density = [ldos[pot_index][j][v + 1] for j in range(d_length)] + for forb_val in forb.values(): + density = [ldos[pot_index][j][forb_val + 1] for j in range(d_length)] for j in range(d_length): t_dos[j] = t_dos[j] + density[j] - t_dos = {Spin.up: t_dos} + _t_dos: dict = {Spin.up: t_dos} - dos = Dos(efermi, dos_energies, t_dos) + dos = Dos(efermi, dos_energies, _t_dos) complete_dos = CompleteDos(structure, dos, pdoss) charge_transfer = LDos.charge_transfer_from_file(feff_inp_file, ldos_file) return cls(complete_dos, charge_transfer) @@ -287,7 +291,7 @@ def __init__(self, header, parameters, absorbing_atom, data): self.data = np.array(data) @classmethod - def from_file(cls, xmu_dat_file="xmu.dat", feff_inp_file="feff.inp"): + def from_file(cls, xmu_dat_file: str = "xmu.dat", feff_inp_file: str = "feff.inp") -> Self: """ Get Xmu from file. @@ -412,7 +416,7 @@ def fine_structure(self): return self.data[:, 3] @classmethod - def from_file(cls, eels_dat_file="eels.dat"): + def from_file(cls, eels_dat_file: str = "eels.dat") -> Self: """ Parse eels spectrum. @@ -425,7 +429,7 @@ def from_file(cls, eels_dat_file="eels.dat"): data = np.loadtxt(eels_dat_file) return cls(data) - def as_dict(self): + def as_dict(self) -> dict: """Returns dict representations of Xmu object.""" dct = MSONable.as_dict(self) dct["data"] = self.data.tolist() diff --git a/pymatgen/io/feff/sets.py b/pymatgen/io/feff/sets.py index 084ce17e0a4..3870560050d 100644 --- a/pymatgen/io/feff/sets.py +++ b/pymatgen/io/feff/sets.py @@ -14,6 +14,7 @@ import sys import warnings from copy import deepcopy +from typing import TYPE_CHECKING import numpy as np from monty.json import MSONable @@ -23,6 +24,9 @@ from pymatgen.core.structure import Molecule, Structure from pymatgen.io.feff.inputs import Atoms, Header, Potential, Tags +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "Kiran Mathew" __credits__ = "Alan Dozier, Anubhav Jain, Shyue Ping Ong" __version__ = "1.1" @@ -282,15 +286,15 @@ def __str__(self): return "\n".join(output) @classmethod - def from_directory(cls, input_dir): + def from_directory(cls, input_dir: str) -> Self: """ Read in a set of FEFF input files from a directory, which is useful when existing FEFF input needs some adjustment. """ - sub_d = {} - for fname, ftype in [("HEADER", Header), ("PARAMETERS", Tags)]: - full_zpath = zpath(os.path.join(input_dir, fname)) - sub_d[fname.lower()] = ftype.from_file(full_zpath) + sub_d: dict = { + "header": Header.from_file(zpath(os.path.join(input_dir, "HEADER"))), + "parameters": Tags.from_file(zpath(os.path.join(input_dir, "PARAMETERS"))), + } # Generation of FEFFDict set requires absorbing atom, need to search # the index of absorption atom in the structure according to the diff --git a/pymatgen/io/fiesta.py b/pymatgen/io/fiesta.py index 05f8c2ac262..d016a57a5bd 100644 --- a/pymatgen/io/fiesta.py +++ b/pymatgen/io/fiesta.py @@ -15,12 +15,18 @@ import shutil import subprocess from string import Template +from typing import TYPE_CHECKING from monty.io import zopen from monty.json import MSONable from pymatgen.core.structure import Molecule +if TYPE_CHECKING: + from pathlib import Path + + from typing_extensions import Self + __author__ = "ndardenne" __copyright__ = "Copyright 2012, The Materials Project" __version__ = "0.1" @@ -50,9 +56,9 @@ def __init__(self, folder, filename="nwchem", log_file="log_n2f"): self.log_file = log_file self._NWCHEM2FIESTA_cmd = "NWCHEM2FIESTA" - self._nwcheminput_fn = filename + ".nw" - self._nwchemoutput_fn = filename + ".nwout" - self._nwchemmovecs_fn = filename + ".movecs" + self._nwcheminput_fn = f"{filename}.nw" + self._nwchemoutput_fn = f"{filename}.nwout" + self._nwchemmovecs_fn = f"{filename}.movecs" def run(self): """Performs actual NWCHEM2FIESTA run.""" @@ -82,14 +88,15 @@ def as_dict(self): } @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """ - :param d: Dict representation. + Args: + dct (dict): Dict representation. Returns: Nwchem2Fiesta """ - return cls(folder=d["folder"], filename=d["filename"]) + return cls(folder=dct["folder"], filename=dct["filename"]) class FiestaRun(MSONable): @@ -126,8 +133,8 @@ def run(self): def _gw_run(self): """Performs FIESTA (gw) run.""" - if self.folder != os.getcwd(): - init_folder = os.getcwd() + init_folder = os.getcwd() + if self.folder != init_folder: os.chdir(self.folder) with zopen(self.log_file, mode="w") as fout: @@ -144,13 +151,13 @@ def _gw_run(self): stdout=fout, ) - if self.folder != os.getcwd(): + if self.folder != init_folder: os.chdir(init_folder) def bse_run(self): """Performs BSE run.""" - if self.folder != os.getcwd(): - init_folder = os.getcwd() + init_folder = os.getcwd() + if self.folder != init_folder: os.chdir(self.folder) with zopen(self.log_file, mode="w") as fout: @@ -166,7 +173,7 @@ def bse_run(self): stdout=fout, ) - if self.folder != os.getcwd(): + if self.folder != init_folder: os.chdir(init_folder) def as_dict(self): @@ -180,14 +187,15 @@ def as_dict(self): } @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """ - :param d: Dict representation + Args: + dct (dict): Dict representation Returns: FiestaRun """ - return cls(folder=d["folder"], grid=d["grid"], log_file=d["log_file"]) + return cls(folder=dct["folder"], grid=dct["grid"], log_file=dct["log_file"]) class BasisSetReader: @@ -213,7 +221,7 @@ def __init__(self, filename): self.data.update(n_nlmo=self.set_n_nlmo()) @staticmethod - def _parse_file(input): + def _parse_file(lines): lmax_nnlo_patt = re.compile(r"\s* (\d+) \s+ (\d+) \s+ \# .* ", re.VERBOSE) nl_orbital_patt = re.compile(r"\s* (\d+) \s+ (\d+) \s+ (\d+) \s+ \# .* ", re.VERBOSE) @@ -227,25 +235,25 @@ def _parse_file(input): parse_nl_orbital = False nnlo = None lmax = None + l_angular = zeta = ng = None - for line in input.split("\n"): + for line in lines.split("\n"): if parse_nl_orbital: match_orb = nl_orbital_patt.search(line) match_alpha = coef_alpha_patt.search(line) if match_orb: - l_angular = match_orb.group(1) - zeta = match_orb.group(2) - ng = match_orb.group(3) + l_angular = match_orb[1] + zeta = match_orb[2] + ng = match_orb[3] basis_set[f"{l_angular}_{zeta}_{ng}"] = [] elif match_alpha: - alpha = match_alpha.group(1) - coef = match_alpha.group(2) + alpha = match_alpha[1] + coef = match_alpha[2] basis_set[f"{l_angular}_{zeta}_{ng}"].append((alpha, coef)) elif parse_lmax_nnlo: - match_orb = lmax_nnlo_patt.search(line) - if match_orb: - lmax = match_orb.group(1) - nnlo = match_orb.group(2) + if match_orb := lmax_nnlo_patt.search(line): + lmax = match_orb[1] + nnlo = match_orb[2] parse_lmax_nnlo = False parse_nl_orbital = True elif parse_preamble: @@ -301,12 +309,13 @@ def __init__( bse_tddft_options: dict[str, str] | None = None, ): """ - :param mol: pymatgen mol - :param correlation_grid: dict - :param Exc_DFT_option: dict - :param COHSEX_options: dict - :param GW_options: dict - :param BSE_TDDFT_options: dict + Args: + mol: pymatgen mol + correlation_grid: dict + Exc_DFT_option: dict + COHSEX_options: dict + GW_options: dict + BSE_TDDFT_options: dict """ self._mol = mol self.correlation_grid = correlation_grid or {"dE_grid": "0.500", "n_grid": "14"} @@ -333,24 +342,28 @@ def __init__( def set_auxiliary_basis_set(self, folder, auxiliary_folder, auxiliary_basis_set_type="aug_cc_pvtz"): """ copy in the desired folder the needed auxiliary basis set "X2.ion" where X is a specie. - :param auxiliary_folder: folder where the auxiliary basis sets are stored - :param auxiliary_basis_set_type: type of basis set (string to be found in the extension of the file name; must - be in lower case). ex: C2.ion_aug_cc_pvtz_RI_Weigend find "aug_cc_pvtz". + + Args: + auxiliary_folder: folder where the auxiliary basis sets are stored + auxiliary_basis_set_type: type of basis set (string to be found in the extension of the file name; must + be in lower case). ex: C2.ion_aug_cc_pvtz_RI_Weigend find "aug_cc_pvtz". """ list_files = os.listdir(auxiliary_folder) for specie in self._mol.symbol_set: for file in list_files: - if file.upper().find(specie.upper() + "2") != -1 and file.lower().find(auxiliary_basis_set_type) != -1: + if file.upper().find(f"{specie.upper()}2") != -1 and file.lower().find(auxiliary_basis_set_type) != -1: shutil.copyfile(f"{auxiliary_folder}/{file}", f"{folder}/{specie}2.ion") def set_gw_options(self, nv_band=10, nc_band=10, n_iteration=5, n_grid=6, dE_grid=0.5): """ Set parameters in cell.in for a GW computation - :param nv__band: number of valence bands to correct with GW - :param nc_band: number of conduction bands to correct with GW - :param n_iteration: number of iteration - :param n_grid and dE_grid:: number of points and spacing in eV for correlation grid. + + Args: + nv__band: number of valence bands to correct with GW + nc_band: number of conduction bands to correct with GW + n_iteration: number of iteration + n_grid and dE_grid: number of points and spacing in eV for correlation grid. """ self.GW_options.update(nv_corr=nv_band, nc_corr=nc_band, nit_gw=n_iteration) self.correlation_grid.update(dE_grid=dE_grid, n_grid=n_grid) @@ -367,16 +380,19 @@ def make_full_bse_densities_folder(folder): def set_bse_options(self, n_excitations=10, nit_bse=200): """ Set parameters in cell.in for a BSE computation - :param nv_bse: number of valence bands - :param nc_bse: number of conduction bands - :param n_excitations: number of excitations - :param nit_bse: number of iterations. + + Args: + nv_bse: number of valence bands + nc_bse: number of conduction bands + n_excitations: number of excitations + nit_bse: number of iterations. """ self.bse_tddft_options.update(npsi_bse=n_excitations, nit_bse=nit_bse) def dump_bse_data_in_gw_run(self, BSE_dump=True): """ - :param BSE_dump: boolean + Args: + BSE_dump: bool Returns: set the "do_bse" variable to one in cell.in @@ -386,14 +402,15 @@ def dump_bse_data_in_gw_run(self, BSE_dump=True): else: self.bse_tddft_options.update(do_bse=0, do_tddft=0) - def dump_tddft_data_in_gw_run(self, tddft_dump=True): + def dump_tddft_data_in_gw_run(self, tddft_dump: bool = True): """ - :param TDDFT_dump: boolean + Args: + TDDFT_dump: bool Returns: set the do_tddft variable to one in cell.in """ - self.bse_tddft_options.update(do_bse=0, do_tddft=1 if tddft_dump else 0) + self.bse_tddft_options.update(do_bse="0", do_tddft="1" if tddft_dump else "0") @property def infos_on_system(self): @@ -455,9 +472,7 @@ def molecule(self): def __str__(self): symbols = list(self._mol.symbol_set) - geometry = [] - for site in self._mol: - geometry.append(f" {site.x} {site.y} {site.z} {int(symbols.index(site.specie.symbol)) + 1}") + geometry = [f" {site.x} {site.y} {site.z} {symbols.index(site.specie.symbol) + 1}" for site in self._mol] t = Template( """# number of atoms and species @@ -518,10 +533,12 @@ def __str__(self): geometry="\n".join(geometry), ) - def write_file(self, filename): + def write_file(self, filename: str | Path) -> None: """ Write FiestaInput to a file - :param filename: Filename. + + Args: + filename: Filename. """ with zopen(filename, mode="w") as file: file.write(str(self)) @@ -538,24 +555,25 @@ def as_dict(self): } @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """ - :param d: Dict representation + Args: + dct (dict): Dict representation Returns: FiestaInput """ return cls( - Molecule.from_dict(d["mol"]), - correlation_grid=d["correlation_grid"], - Exc_DFT_option=d["Exc_DFT_option"], - COHSEX_options=d["geometry_options"], - GW_options=d["symmetry_options"], - BSE_TDDFT_options=d["memory_options"], + mol=Molecule.from_dict(dct["mol"]), + correlation_grid=dct["correlation_grid"], + exc_dft_option=dct["Exc_DFT_option"], + cohsex_options=dct["geometry_options"], + gw_options=dct["symmetry_options"], + bse_tddft_options=dct["memory_options"], ) @classmethod - def from_str(cls, string_input): + def from_str(cls, string_input: str) -> Self: """ Read an FiestaInput from a string. Currently tested to work with files generated from this class itself. @@ -690,7 +708,7 @@ def from_str(cls, string_input): ) @classmethod - def from_file(cls, filename): + def from_file(cls, filename: str | Path) -> Self: """ Read an Fiesta input from a file. Currently tested to work with files generated from this class itself. @@ -755,13 +773,11 @@ def _parse_job(output): for line in output.split("\n"): if parse_total_time: - m = end_patt.search(line) - if m: + if match := end_patt.search(line): GW_results.update(end_normally=True) - m = total_time_patt.search(line) - if m: - GW_results.update(total_time=m.group(1)) + if match := total_time_patt.search(line): + GW_results.update(total_time=match[1]) if parse_gw_results: if line.find("Dumping eigen energies") != -1: @@ -769,29 +785,27 @@ def _parse_job(output): parse_gw_results = False continue - m = GW_BANDS_results_patt.search(line) - if m: + if match := GW_BANDS_results_patt.search(line): dct = {} dct.update( - band=m.group(1).strip(), - eKS=m.group(2), - eXX=m.group(3), - eQP_old=m.group(4), - z=m.group(5), - sigma_c_Linear=m.group(6), - eQP_Linear=m.group(7), - sigma_c_SCF=m.group(8), - eQP_SCF=m.group(9), + band=match[1].strip(), + eKS=match[2], + eXX=match[3], + eQP_old=match[4], + z=match[5], + sigma_c_Linear=match[6], + eQP_Linear=match[7], + sigma_c_SCF=match[8], + eQP_SCF=match[9], ) - GW_results[m.group(1).strip()] = dct + GW_results[match[1].strip()] = dct - n = GW_GAPS_results_patt.search(line) - if n: + if n := GW_GAPS_results_patt.search(line): dct = {} dct.update( - Egap_KS=n.group(1), - Egap_QP_Linear=n.group(2), - Egap_QP_SCF=n.group(3), + Egap_KS=n[1], + Egap_QP_Linear=n[2], + Egap_QP_SCF=n[3], ) GW_results["Gaps"] = dct @@ -838,13 +852,11 @@ def _parse_job(output): for line in output.split("\n"): if parse_total_time: - m = end_patt.search(line) - if m: + if match := end_patt.search(line): BSE_results.update(end_normally=True) - m = total_time_patt.search(line) - if m: - BSE_results.update(total_time=m.group(1)) + if match := total_time_patt.search(line): + BSE_results.update(total_time=match[1]) if parse_BSE_results: if line.find("FULL BSE main valence -> conduction transitions weight:") != -1: @@ -852,11 +864,10 @@ def _parse_job(output): parse_BSE_results = False continue - m = BSE_exitons_patt.search(line) - if m: + if match := BSE_exitons_patt.search(line): dct = {} - dct.update(bse_eig=m.group(2), osc_strength=m.group(3)) - BSE_results[str(m.group(1).strip())] = dct + dct.update(bse_eig=match[2], osc_strength=match[3]) + BSE_results[str(match[1].strip())] = dct if line.find("FULL BSE eig.(eV), osc. strength and dipoles:") != -1: parse_BSE_results = True diff --git a/pymatgen/io/gaussian.py b/pymatgen/io/gaussian.py index fc0c29320cf..98593b1815c 100644 --- a/pymatgen/io/gaussian.py +++ b/pymatgen/io/gaussian.py @@ -4,6 +4,7 @@ import re import warnings +from typing import TYPE_CHECKING import numpy as np import scipy.constants as cst @@ -17,6 +18,11 @@ from pymatgen.util.coord import get_angle from pymatgen.util.plotting import pretty_plot +if TYPE_CHECKING: + from pathlib import Path + + from typing_extensions import Self + __author__ = "Shyue Ping Ong, Germain Salvato-Vallverdu, Xin Chen" __copyright__ = "Copyright 2013, The Materials Virtual Lab" __version__ = "0.1" @@ -36,7 +42,7 @@ def read_route_line(route): Args: route (str) : the route line - Return: + Returns: functional (str) : the method (HF, PBE ...) basis_set (str) : the basis set route (dict) : dictionary of parameters @@ -55,24 +61,21 @@ def read_route_line(route): route = route.replace(tok, "") for tok in route.split(): - if scrf_patt.match(tok): - m = scrf_patt.match(tok) - route_params[m.group(1)] = m.group(2) + if match := scrf_patt.match(tok): + route_params[match.group(1)] = match.group(2) elif tok.upper() in ["#", "#N", "#P", "#T"]: # does not store # in route to avoid error in input dieze_tag = "#N" if tok == "#" else tok continue + elif match := re.match(multi_params_patt, tok.strip("#")): + pars = {} + for par in match.group(2).split(","): + p = par.split("=") + pars[p[0]] = None if len(p) == 1 else p[1] + route_params[match.group(1)] = pars else: - m = re.match(multi_params_patt, tok.strip("#")) - if m: - pars = {} - for par in m.group(2).split(","): - p = par.split("=") - pars[p[0]] = None if len(p) == 1 else p[1] - route_params[m.group(1)] = pars - else: - d = tok.strip("#").split("=") - route_params[d[0]] = None if len(d) == 1 else d[1] + d = tok.strip("#").split("=") + route_params[d[0]] = None if len(d) == 1 else d[1] return functional, basis_set, route_params, dieze_tag @@ -158,7 +161,7 @@ def __init__( self.link0_parameters = link0_parameters or {} self.route_parameters = route_parameters or {} self.input_parameters = input_parameters or {} - self.dieze_tag = dieze_tag if dieze_tag[0] == "#" else "#" + dieze_tag + self.dieze_tag = dieze_tag if dieze_tag[0] == "#" else f"#{dieze_tag}" self.gen_basis = gen_basis if gen_basis is not None: self.basis_set = "Gen" @@ -275,7 +278,7 @@ def _parse_species(sp_str): return Molecule(species, coords) @classmethod - def from_str(cls, contents): + def from_str(cls, contents: str) -> Self: """ Creates GaussianInput from a string. @@ -292,6 +295,7 @@ def from_str(cls, contents): for line in lines: if link0_patt.match(line): m = link0_patt.match(line) + assert m is not None link0_dict[m.group(1).strip("=")] = m.group(2) route_patt = re.compile(r"^#[sSpPnN]*.*") @@ -299,7 +303,7 @@ def from_str(cls, contents): route_index = None for idx, line in enumerate(lines): if route_patt.match(line): - route += " " + line + route += f" {line}" route_index = idx # This condition allows for route cards spanning multiple lines elif (line == "" or line.isspace()) and route_index: @@ -310,10 +314,11 @@ def from_str(cls, contents): functional, basis_set, route_paras, dieze_tag = read_route_line(route) ind = 2 title = [] + assert route_index is not None, "route_index cannot be None" while lines[route_index + ind].strip(): title.append(lines[route_index + ind].strip()) ind += 1 - title = " ".join(title) + title_str = " ".join(title) ind += 1 tokens = re.split(r"[,\s]+", lines[route_index + ind]) charge = int(float(tokens[0])) @@ -340,7 +345,7 @@ def from_str(cls, contents): mol, charge=charge, spin_multiplicity=spin_mult, - title=title, + title=title_str, functional=functional, basis_set=basis_set, route_parameters=route_paras, @@ -350,7 +355,7 @@ def from_str(cls, contents): ) @classmethod - def from_file(cls, filename): + def from_file(cls, filename: str | Path) -> Self: """ Creates GaussianInput from a file. @@ -458,9 +463,10 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct: dict) -> GaussianInput: + def from_dict(cls, dct: dict) -> Self: """ - :param dct: dict + Args: + dct: dict Returns: GaussianInput @@ -551,22 +557,18 @@ class GaussianOutput: printed using `pop=NBOREAD` and `$nbo bndidx $end`. Methods: - .. method:: to_input() + to_input() + Return a GaussianInput object using the last geometry and the same + calculation parameters. - Return a GaussianInput object using the last geometry and the same - calculation parameters. + read_scan() + Read a potential energy surface from a gaussian scan calculation. - .. method:: read_scan() - - Read a potential energy surface from a gaussian scan calculation. + get_scan_plot() + Get a matplotlib plot of the potential energy surface - .. method:: get_scan_plot() - - Get a matplotlib plot of the potential energy surface - - .. method:: save_scan_plot() - - Save a matplotlib plot of the potential energy surface to a file + save_scan_plot() + Save a matplotlib plot of the potential energy surface to a file """ def __init__(self, filename): @@ -663,6 +665,7 @@ def _parse(self, filename): std_structures = [] geom_orientation = None opt_structures = [] + route_lower = {} with zopen(filename, mode="rt") as file: for line in file: @@ -670,8 +673,8 @@ def _parse(self, filename): if start_patt.search(line): parse_stage = 1 elif link0_patt.match(line): - m = link0_patt.match(line) - self.link0[m.group(1)] = m.group(2) + match = link0_patt.match(line) + self.link0[match.group(1)] = match.group(2) elif route_patt.search(line) or route_line != "": if set(line.strip()) == {"-"}: params = read_route_line(route_line) @@ -690,21 +693,21 @@ def _parse(self, filename): elif self.title == "": self.title = line.strip() elif charge_mul_patt.search(line): - m = charge_mul_patt.search(line) - self.charge = int(m.group(1)) - self.spin_multiplicity = int(m.group(2)) + match = charge_mul_patt.search(line) + self.charge = int(match.group(1)) + self.spin_multiplicity = int(match.group(2)) parse_stage = 2 elif parse_stage == 2: if self.is_pcm: self._check_pcm(line) if "freq" in route_lower and thermo_patt.search(line): - m = thermo_patt.search(line) - if m.group(1) == "Zero-point": - self.corrections["Zero-point"] = float(m.group(3)) + match = thermo_patt.search(line) + if match.group(1) == "Zero-point": + self.corrections["Zero-point"] = float(match.group(3)) else: - key = m.group(2).replace(" to ", "") - self.corrections[key] = float(m.group(3)) + key = match.group(2).replace(" to ", "") + self.corrections[key] = float(match.group(3)) if read_coord: [file.readline() for i in range(3)] @@ -724,9 +727,8 @@ def _parse(self, filename): std_structures.append(Molecule(sp, coords)) if parse_forces: - m = forces_patt.search(line) - if m: - forces.extend([float(_v) for _v in m.groups()[2:5]]) + if match := forces_patt.search(line): + forces.extend([float(_v) for _v in match.groups()[2:5]]) elif forces_off_patt.search(line): self.cart_forces.append(forces) forces = [] @@ -734,8 +736,7 @@ def _parse(self, filename): # read molecular orbital eigenvalues if read_eigen: - m = orbital_patt.search(line) - if m: + if match := orbital_patt.search(line): eigen_txt.append(line) else: read_eigen = False @@ -751,8 +752,8 @@ def _parse(self, filename): # read molecular orbital coefficients if (not num_basis_found) and num_basis_func_patt.search(line): - m = num_basis_func_patt.search(line) - self.num_basis_func = int(m.group(1)) + match = num_basis_func_patt.search(line) + self.num_basis_func = int(match.group(1)) num_basis_found = True elif read_mo: # build a matrix with all coefficients @@ -765,6 +766,8 @@ def _parse(self, filename): mat_mo[spin] = np.zeros((self.num_basis_func, self.num_basis_func)) nMO = 0 end_mo = False + atom_idx = None + coeffs = [] while nMO < self.num_basis_func and not end_mo: file.readline() file.readline() @@ -773,13 +776,13 @@ def _parse(self, filename): line = file.readline() # identify atom and OA labels - m = mo_coeff_name_patt.search(line) - if m.group(1).strip() != "": - atom_idx = int(m.group(2)) - 1 + match = mo_coeff_name_patt.search(line) + if match.group(1).strip() != "": + atom_idx = int(match.group(2)) - 1 # atname = m.group(3) - self.atom_basis_labels.append([m.group(4)]) + self.atom_basis_labels.append([match.group(4)]) else: - self.atom_basis_labels[atom_idx].append(m.group(4)) + self.atom_basis_labels[atom_idx].append(match.group(4)) # MO coefficients coeffs = [float(c) for c in float_patt.findall(line)] @@ -890,8 +893,8 @@ def _parse(self, filename): parse_bond_order = False elif termination_patt.search(line): - m = termination_patt.search(line) - if m.group(1) == "Normal": + match = termination_patt.search(line) + if match.group(1) == "Normal": self.properly_terminated = True terminated = True elif error_patt.search(line): @@ -899,25 +902,25 @@ def _parse(self, filename): "! Non-Optimized Parameters !": "Optimization error", "Convergence failure": "SCF convergence error", } - m = error_patt.search(line) - self.errors.append(error_defs[m.group(1)]) + match = error_patt.search(line) + self.errors.append(error_defs[match.group(1)]) elif num_elec_patt.search(line): - m = num_elec_patt.search(line) - self.electrons = (int(m.group(1)), int(m.group(2))) + match = num_elec_patt.search(line) + self.electrons = (int(match.group(1)), int(match.group(2))) elif (not self.is_pcm) and pcm_patt.search(line): self.is_pcm = True self.pcm = {} elif "freq" in route_lower and "opt" in route_lower and stat_type_patt.search(line): self.stationary_type = "Saddle" elif mp2_patt.search(line): - m = mp2_patt.search(line) - self.energies.append(float(m.group(1).replace("D", "E"))) + match = mp2_patt.search(line) + self.energies.append(float(match.group(1).replace("D", "E"))) elif oniom_patt.search(line): - m = oniom_patt.matcher(line) - self.energies.append(float(m.group(1))) + match = oniom_patt.matcher(line) + self.energies.append(float(match.group(1))) elif scf_patt.search(line): - m = scf_patt.search(line) - self.energies.append(float(m.group(1))) + match = scf_patt.search(line) + self.energies.append(float(match.group(1))) elif std_orientation_patt.search(line): standard_orientation = True geom_orientation = "standard" @@ -939,13 +942,12 @@ def _parse(self, filename): eigen_txt.append(line) read_eigen = True elif mulliken_patt.search(line): - mulliken_txt = [] read_mulliken = True elif not parse_forces and forces_on_patt.search(line): parse_forces = True elif freq_on_patt.search(line): parse_freq = True - [file.readline() for i in range(3)] + _ = [file.readline() for _i in range(3)] elif mo_coeff_patt.search(line): if "Alpha" in line: self.is_spin = True @@ -967,27 +969,28 @@ def _parse(self, filename): parse_bond_order = True if read_mulliken: + mulliken_txt = [] if not end_mulliken_patt.search(line): mulliken_txt.append(line) else: - m = end_mulliken_patt.search(line) + match = end_mulliken_patt.search(line) mulliken_charges = {} for line in mulliken_txt: if mulliken_charge_patt.search(line): - m = mulliken_charge_patt.search(line) - dic = {int(m.group(1)): [m.group(2), float(m.group(3))]} + match = mulliken_charge_patt.search(line) + dic = {int(match.group(1)): [match.group(2), float(match.group(3))]} mulliken_charges.update(dic) read_mulliken = False self.Mulliken_charges = mulliken_charges # store the structures. If symmetry is considered, the standard orientation # is used. Else the input orientation is used. + self.structures_input_orientation = input_structures if standard_orientation: self.structures = std_structures - self.structures_input_orientation = input_structures else: self.structures = input_structures - self.structures_input_orientation = input_structures + # store optimized structure in input orientation self.opt_structures = opt_structures @@ -1010,6 +1013,7 @@ def _parse_hessian(self, file, structure): self.hessian = np.zeros((ndf, ndf)) j_indices = range(5) ndf_idx = 0 + vals = None while ndf_idx < ndf: for i in range(ndf_idx, ndf): line = file.readline() @@ -1136,8 +1140,8 @@ def read_scan(self): while not re.search(r"^\s-+", line): values = list(map(float, line.split())) data["energies"].append(values[-1]) - for i, icname in enumerate(data["coords"]): - data["coords"][icname].append(values[i + 1]) + for i, icname in enumerate(data["coords"], start=1): + data["coords"][icname].append(values[i]) line = file.readline() else: line = file.readline() @@ -1203,7 +1207,7 @@ def read_excitation_energies(self): if td and re.search(r"^\sExcited State\s*\d", line): val = [float(v) for v in float_patt.findall(line)] - transitions.append(tuple(val[0:3])) + transitions.append(tuple(val[:3])) line = file.readline() return transitions diff --git a/pymatgen/io/icet.py b/pymatgen/io/icet.py index 697ee19d0a1..ea153a511cd 100644 --- a/pymatgen/io/icet.py +++ b/pymatgen/io/icet.py @@ -16,7 +16,9 @@ from icet.tools.structure_generation import _get_sqs_cluster_vector, _validate_concentrations, generate_sqs from mchammer.calculators import compare_cluster_vectors except ImportError: - ClusterSpace = None + ClusterSpace = _validate_concentrations = _get_sqs_cluster_vector = compare_cluster_vectors = generate_sqs = ( + enumerate_structures + ) = None if TYPE_CHECKING: diff --git a/pymatgen/io/lammps/data.py b/pymatgen/io/lammps/data.py index 554f94bf899..0a045e4f455 100644 --- a/pymatgen/io/lammps/data.py +++ b/pymatgen/io/lammps/data.py @@ -37,6 +37,8 @@ from collections.abc import Sequence from typing import Any + from typing_extensions import Self + from pymatgen.core.sites import Site from pymatgen.core.structure import SiteCollection @@ -255,7 +257,7 @@ def __init__( ["vx", "vy", "vz"] for Velocities section. Optional with default to None. If not None, its index should be consistent with atoms. - force_field (dict): Data for force field sections. Optional + force_fieldct (dict): Data for force field sections. Optional with default to None. Only keywords in force field and class 2 force field are valid keys, and each value is a DataFrame. @@ -294,7 +296,7 @@ def structure(self) -> Structure: Exports a periodic structure object representing the simulation box. - Return: + Returns: Structure """ masses = self.masses @@ -307,14 +309,14 @@ def structure(self) -> Structure: molecule = topologies[0].sites coords = molecule.cart_coords - np.array(self.box.bounds)[:, 0] species = molecule.species - latt = self.box.to_lattice() + lattice = self.box.to_lattice() site_properties = {} if "q" in atoms: site_properties["charge"] = atoms["q"].to_numpy() if self.velocities is not None: site_properties["velocities"] = self.velocities.to_numpy() return Structure( - latt, + lattice, species, coords, coords_are_cartesian=True, @@ -473,7 +475,7 @@ def write_file(self, filename: str, distance: int = 6, velocity: int = 8, charge charge (int): No. of significant figures to output for charges. Default to 4. """ - with open(filename, mode="w") as file: + with open(filename, mode="w", encoding="utf-8") as file: file.write(self.get_str(distance=distance, velocity=velocity, charge=charge)) def disassemble( @@ -629,7 +631,7 @@ def label_topo(t) -> tuple: return self.box, ff, topo_list @classmethod - def from_file(cls, filename: str, atom_style: str = "full", sort_id: bool = False) -> LammpsData: + def from_file(cls, filename: str, atom_style: str = "full", sort_id: bool = False) -> Self: """ Constructor that parses a file. @@ -656,16 +658,17 @@ def from_file(cls, filename: str, atom_style: str = "full", sort_id: bool = Fals bounds = {} for line in clean_lines(parts[0][1:]): # skip the 1st line match = None - for k, v in header_pattern.items(): # noqa: B007 - match = re.match(v, line) + key = None + for key, val in header_pattern.items(): # noqa: B007 + match = re.match(val, line) if match: break - if match and k in ["counts", "types"]: - header[k][match.group(2)] = int(match.group(1)) - elif match and k == "bounds": + if match and key in {"counts", "types"}: + header[key][match[2]] = int(match[1]) + elif match and key == "bounds": g = match.groups() bounds[g[2]] = [float(i) for i in g[:2]] - elif match and k == "tilt": + elif match and key == "tilt": header["tilt"] = [float(i) for i in match.groups()] header["bounds"] = [bounds.get(i, [-0.5, 0.5]) for i in "xyz"] box = LammpsBox(header["bounds"], header.get("tilt")) @@ -719,7 +722,7 @@ def parse_section(sec_lines) -> tuple[str, pd.DataFrame]: name in ["Velocities"] + SECTION_KEYWORDS["topology"] and not seen_atoms ): # Atoms must appear earlier than these raise RuntimeError(f"{err_msg}{name} section appears before Atoms section") - body.update({name: section}) + body[name] = section err_msg += "Nos. of {} do not match between header and {} section" assert len(body["Masses"]) == header["types"]["atom"], err_msg.format("atom types", "Masses") @@ -743,7 +746,7 @@ def parse_section(sec_lines) -> tuple[str, pd.DataFrame]: @classmethod def from_ff_and_topologies( cls, box: LammpsBox, ff: ForceField, topologies: Sequence[Topology], atom_style: str = "full" - ): + ) -> Self: """ Constructor building LammpsData from a ForceField object and a list of Topology objects. Do not support intermolecular @@ -770,7 +773,7 @@ def from_ff_and_topologies( v_collector: list | None = [] if topologies[0].velocities else None topo_collector: dict[str, list] = {"Bonds": [], "Angles": [], "Dihedrals": [], "Impropers": []} topo_labels: dict[str, list] = {"Bonds": [], "Angles": [], "Dihedrals": [], "Impropers": []} - for i, topo in enumerate(topologies): + for idx, topo in enumerate(topologies): if topo.topologies: shift = len(labels) for k, v in topo.topologies.items(): @@ -778,11 +781,10 @@ def from_ff_and_topologies( topo_labels[k].extend([tuple(topo.type_by_sites[j] for j in t) for t in v]) if isinstance(v_collector, list): v_collector.append(topo.velocities) - mol_ids.extend([i + 1] * len(topo.sites)) + mol_ids.extend([idx + 1] * len(topo.sites)) labels.extend(topo.type_by_sites) coords.append(topo.sites.cart_coords) - q = [0.0] * len(topo.sites) if not topo.charges else topo.charges - charges.extend(q) + charges.extend(topo.charges or [0.0] * len(topo.sites)) atoms = pd.DataFrame(np.concatenate(coords), columns=["x", "y", "z"]) atoms["molecule-ID"] = mol_ids @@ -818,7 +820,7 @@ def from_structure( ff_elements: Sequence[str] | None = None, atom_style: Literal["atomic", "charge"] = "charge", is_sort: bool = False, - ): + ) -> Self: """ Simple constructor building LammpsData from a structure without force field parameters and topologies. @@ -969,7 +971,7 @@ def from_bonding( dihedral: bool = True, tol: float = 0.1, **kwargs, - ): + ) -> Self: """ Another constructor that creates an instance from a molecule. Covalent bonds and other bond-based topologies (angles and @@ -996,8 +998,8 @@ def from_bonding( dests, freq = np.unique(bond_list, return_counts=True) hubs = dests[np.where(freq > 1)].tolist() bond_arr = np.array(bond_list) + hub_spokes = {} if len(hubs) > 0: - hub_spokes = {} for hub in hubs: ix = np.any(np.isin(bond_arr, hub), axis=1) bonds = np.unique(bond_arr[ix]).tolist() @@ -1029,7 +1031,7 @@ class ForceField(MSONable): Attributes: masses (pandas.DataFrame): DataFrame for Masses section. - force_field (dict): Force field section keywords (keys) and + force_fieldct (dict): Force field section keywords (keys) and data (values) as DataFrames. maps (dict): Dict for labeling atoms and topologies. """ @@ -1090,12 +1092,12 @@ def map_mass(v): ) index, masses, self.mass_info, atoms_map = [], [], [], {} - for i, m in enumerate(mass_info): - index.append(i + 1) + for idx, m in enumerate(mass_info, start=1): + index.append(idx) mass = map_mass(m[1]) masses.append(mass) self.mass_info.append((m[0], mass)) - atoms_map[m[0]] = i + 1 + atoms_map[m[0]] = idx self.masses = pd.DataFrame({"mass": masses}, index=index) self.maps = {"Atoms": atoms_map} @@ -1115,7 +1117,7 @@ def map_mass(v): ff_dfs.update(coeffs) self.maps.update(mapper) - self.force_field = None if len(ff_dfs) == 0 else ff_dfs + self.force_field = ff_dfs or None def _process_nonbond(self) -> dict: pair_df = pd.DataFrame(self.nonbond_coeffs) @@ -1160,9 +1162,9 @@ def find_eq_types(label, section) -> list: atoms = set(np.ravel(list(itertools.chain(*distinct_types)))) assert atoms.issubset(self.maps["Atoms"]), f"Undefined atom type found in {kw}" mapper = {} - for i, dt in enumerate(distinct_types): + for i, dt in enumerate(distinct_types, start=1): for t in dt: - mapper[t] = i + 1 + mapper[t] = i def process_data(data) -> pd.DataFrame: df = pd.DataFrame(data) @@ -1189,30 +1191,30 @@ def to_file(self, filename: str) -> None: "nonbond_coeffs": self.nonbond_coeffs, "topo_coeffs": self.topo_coeffs, } - with open(filename, mode="w") as file: + with open(filename, mode="w", encoding="utf-8") as file: yaml = YAML() yaml.dump(dct, file) @classmethod - def from_file(cls, filename: str) -> ForceField: + def from_file(cls, filename: str) -> Self: """ Constructor that reads in a file in YAML format. Args: filename (str): Filename. """ - with open(filename) as file: + with open(filename, encoding="utf-8") as file: yaml = YAML() d = yaml.load(file) return cls.from_dict(d) @classmethod - def from_dict(cls, dct: dict) -> ForceField: + def from_dict(cls, dct: dict) -> Self: """ Constructor that reads in a dictionary. Args: - d (dict): Dictionary to read. + dct (dict): Dictionary to read. """ dct["mass_info"] = [tuple(m) for m in dct["mass_info"]] if dct.get("topo_coeffs"): @@ -1330,7 +1332,7 @@ def structure(self) -> Structure: Exports a periodic structure object representing the simulation box. - Return: + Returns: Structure """ ld_cp = self.as_lammpsdata() @@ -1373,13 +1375,14 @@ def disassemble( for mol in self.mols ] + # NOTE (@janosh): The following two methods for override parent class LammpsData @classmethod - def from_ff_and_topologies(cls): + def from_ff_and_topologies(cls) -> None: # type: ignore[override] """Unsupported constructor for CombinedData objects.""" raise AttributeError("Unsupported constructor for CombinedData objects") @classmethod - def from_structure(cls): + def from_structure(cls) -> None: # type: ignore[override] """Unsupported constructor for CombinedData objects.""" raise AttributeError("Unsupported constructor for CombinedData objects") @@ -1406,33 +1409,38 @@ def parse_xyz(cls, filename: str | Path) -> pd.DataFrame: return df @classmethod - def from_files(cls, coordinate_file: str, list_of_numbers: list, *filenames) -> CombinedData: + def from_files(cls, coordinate_file: str, list_of_numbers: list[int], *filenames: str) -> Self: """ Constructor that parse a series of data file. Args: coordinate_file (str): The filename of xyz coordinates. - list_of_numbers (list): A list of numbers specifying counts for each + list_of_numbers (list[int]): A list of numbers specifying counts for each clusters parsed from files. filenames (str): A series of LAMMPS data filenames in string format. """ names = [] mols = [] styles = [] - coordinates = cls.parse_xyz(filename=coordinate_file) - for idx in range(1, len(filenames) + 1): - exec(f"cluster{idx} = LammpsData.from_file(filenames[{idx - 1}])") + clusters = [] + + for idx, filename in enumerate(filenames, start=1): + cluster = LammpsData.from_file(filename) + clusters.append(cluster) names.append(f"cluster{idx}") - mols.append(eval(f"cluster{idx}")) - styles.append(eval(f"cluster{idx}").atom_style) - style = set(styles) - assert len(style) == 1, "Files have different atom styles." - return cls.from_lammpsdata(mols, names, list_of_numbers, coordinates, style.pop()) + mols.append(cluster) + styles.append(cluster.atom_style) + + if len(set(styles)) != 1: + raise ValueError("Files have different atom styles.") + + coordinates = cls.parse_xyz(filename=coordinate_file) + return cls.from_lammpsdata(mols, names, list_of_numbers, coordinates, styles.pop()) @classmethod def from_lammpsdata( cls, mols: list, names: list, list_of_numbers: list, coordinates: pd.DataFrame, atom_style: str | None = None - ) -> CombinedData: + ) -> Self: """ Constructor that can infer atom_style. The input LammpsData objects are used non-destructively. @@ -1446,14 +1454,15 @@ def from_lammpsdata( columns of ["x", "y", "z"] for coordinates of atoms. atom_style (str): Output atom_style. Default to "full". """ - styles = [] - for mol in mols: - styles.append(mol.atom_style) - style = set(styles) - assert len(style) == 1, "Data have different atom_style." - style_return = style.pop() - if atom_style: - assert atom_style == style_return, "Data have different atom_style as specified." + styles = [mol.atom_style for mol in mols] + + if len(set(styles)) != 1: + raise ValueError("Data have different atom_style.") + style_return = styles.pop() + + if atom_style and atom_style != style_return: + raise ValueError("Data have different atom_style as specified.") + return cls(mols, names, list_of_numbers, coordinates, style_return) def get_str(self, distance: int = 6, velocity: int = 8, charge: int = 4, hybrid: bool = True) -> str: @@ -1494,7 +1503,7 @@ def as_lammpsdata(self): Convert a CombinedData object to a LammpsData object. attributes are deep-copied. box (LammpsBox): Simulation box. - force_field (dict): Data for force field sections. Optional + force_fieldct (dict): Data for force field sections. Optional with default to None. Only keywords in force field and class 2 force field are valid keys, and each value is a DataFrame. diff --git a/pymatgen/io/lammps/generators.py b/pymatgen/io/lammps/generators.py index 0559b41a5dd..baa74a24963 100644 --- a/pymatgen/io/lammps/generators.py +++ b/pymatgen/io/lammps/generators.py @@ -42,7 +42,7 @@ class BaseLammpsGenerator(InputGenerator): The parameters are then replaced based on the values found in the settings dictionary that you provide, e.g., `{"nsteps": 1000}`. - Parameters: + Attributes: template: Path (string) to the template file used to create the InputFile for LAMMPS. calc_type: Human-readable string used to briefly describe the type of computations performed by LAMMPS. settings: Dictionary containing the values of the parameters to replace in the template. @@ -57,17 +57,16 @@ class BaseLammpsGenerator(InputGenerator): (https://github.com/Matgenix/atomate2-lammps). """ + inputfile: LammpsInputFile | None = field(default=None) template: str = field(default_factory=str) + data: LammpsData | CombinedData | None = field(default=None) settings: dict = field(default_factory=dict) calc_type: str = field(default="lammps") keep_stages: bool = field(default=True) - def __post_init__(self): - self.settings = self.settings or {} - - def get_input_set(self, structure: Structure | LammpsData | CombinedData | None) -> LammpsInputSet: # type: ignore + def get_input_set(self, structure: Structure | LammpsData | CombinedData) -> LammpsInputSet: """Generate a LammpsInputSet from the structure/data, tailored to the template file.""" - data = LammpsData.from_structure(structure) if isinstance(structure, Structure) else structure + data: LammpsData = LammpsData.from_structure(structure) if isinstance(structure, Structure) else structure # Load the template with zopen(self.template, mode="r") as file: @@ -83,7 +82,7 @@ def get_input_set(self, structure: Structure | LammpsData | CombinedData | None) class LammpsMinimization(BaseLammpsGenerator): - r""" + """ Generator that yields a LammpsInputSet tailored for minimizing the energy of a system by iteratively adjusting atom coordinates. Example usage: @@ -94,7 +93,7 @@ class LammpsMinimization(BaseLammpsGenerator): Do not forget to specify the force field, otherwise LAMMPS will not be able to run! - /!\ This InputSet and InputGenerator implementation is based on templates and is not intended to be very flexible. + This InputSet and InputGenerator implementation is based on templates and is not intended to be very flexible. For instance, pymatgen will not detect whether a given variable should be adapted based on others (e.g., the number of steps from the temperature), it will not check for convergence nor will it actually run LAMMPS. For additional flexibility and automation, use the atomate2-lammps implementation @@ -112,21 +111,21 @@ def __init__( force_field: str = "Unspecified force field!", keep_stages: bool = False, ) -> None: - r""" + """ Args: template: Path (string) to the template file used to create the InputFile for LAMMPS. - units: units to be used for the LAMMPS calculation (see official LAMMPS documentation). - atom_style: atom_style to be used for the LAMMPS calculation (see official LAMMPS documentation). - dimension: dimension to be used for the LAMMPS calculation (see official LAMMPS documentation). - boundary: boundary to be used for the LAMMPS calculation (see official LAMMPS documentation). - read_data: read_data to be used for the LAMMPS calculation (see official LAMMPS documentation). - force_field: force field to be used for the LAMMPS calculation (see official LAMMPS documentation). - Note that you should provide all the required information as a single string. - In case of multiple lines expected in the input file, - separate them with '\n' in force_field. + units: units to be used for the LAMMPS calculation (see LAMMPS docs). + atom_style: atom_style to be used for the LAMMPS calculation (see LAMMPS docs). + dimension: dimension to be used for the LAMMPS calculation (see LAMMPS docs). + boundary: boundary to be used for the LAMMPS calculation (see LAMMPS docs). + read_data: read_data to be used for the LAMMPS calculation (see LAMMPS docs). + force_field: force field to be used for the LAMMPS calculation (see LAMMPS docs). + Note that you should provide all the required information as a single string. + In case of multiple lines expected in the input file, + separate them with '\n' in force_field. keep_stages: If True, the string is formatted in a block structure with stage names - and newlines that differentiate commands in the respective stages of the InputFile. - If False, stage names are not printed and all commands appear in a single block. + and newlines that differentiate commands in the respective stages of the InputFile. + If False, stage names are not printed and all commands appear in a single block. """ if template is None: template = f"{template_dir}/minimization.template" diff --git a/pymatgen/io/lammps/inputs.py b/pymatgen/io/lammps/inputs.py index b7031b726a0..029a00261e6 100644 --- a/pymatgen/io/lammps/inputs.py +++ b/pymatgen/io/lammps/inputs.py @@ -28,6 +28,8 @@ if TYPE_CHECKING: from os import PathLike + from typing_extensions import Self + from pymatgen.io.core import InputSet __author__ = "Kiran Mathew, Brandon Wood, Zhi Deng, Manas Likhit, Guillaume Brunin (Matgenix)" @@ -104,7 +106,7 @@ def ncomments(self) -> int: n_comments += 1 else: # Else, inline comment each count as one - n_comments += sum(1 for cmd, args in stage["commands"] if cmd.strip().startswith("#")) + n_comments += sum(1 for cmd, _args in stage["commands"] if cmd.strip().startswith("#")) return n_comments @@ -331,7 +333,7 @@ def merge_stages(self, stage_names: list[str]) -> None: Args: stage_names (list): list of strings giving the names of the stages to be merged. """ - if not all(stage in self.stages_names for stage in stage_names): + if any(stage not in self.stages_names for stage in stage_names): raise ValueError("At least one of the stages to be merged is not in the LammpsInputFile.") indices_stages_to_merge = [self.stages_names.index(stage) for stage in stage_names] @@ -476,9 +478,9 @@ def append(self, lmp_input_file: LammpsInputFile) -> None: # Making sure no stage_name of lmp_input_file clash with those from self. # If it is the case, we rename them. - for i_stage, stage in enumerate(lmp_input_file.stages): + for i_stage, stage in enumerate(lmp_input_file.stages, start=1): if stage["stage_name"] in self.stages_names: - stage["stage_name"] = f"Stage {self.nstages + i_stage + 1} (previously {stage['stage_name']})" + stage["stage_name"] = f"Stage {self.nstages + i_stage} (previously {stage['stage_name']})" # Append the two list of stages self.stages += new_list_to_add @@ -536,7 +538,7 @@ def write_file(self, filename: str | PathLike, ignore_comments: bool = False, ke file.write(self.get_str(ignore_comments=ignore_comments, keep_stages=keep_stages)) @classmethod - def from_str(cls, contents: str, ignore_comments: bool = False, keep_stages: bool = False) -> LammpsInputFile: + def from_str(cls, contents: str, ignore_comments: bool = False, keep_stages: bool = False) -> Self: # type: ignore[override] """ Helper method to parse string representation of LammpsInputFile. If you created the input file by hand, there is no guarantee that the representation @@ -555,25 +557,25 @@ def from_str(cls, contents: str, ignore_comments: bool = False, keep_stages: boo Returns: LammpsInputFile """ - LIF = cls() + lammps_in_file = cls() # Strip string from starting and/or ending white spaces - s = contents.strip() + contents = contents.strip() # Remove "&" symbols at the end of lines - while "&" in s: + while "&" in contents: sequence = "&" - index = s.index("&") + index = contents.index("&") next_symbol = "" idx = 0 while next_symbol != "\n": sequence += next_symbol idx += 1 - next_symbol = s[index + idx] - s = s.replace(sequence + "\n", "") + next_symbol = contents[index + idx] + contents = contents.replace(sequence + "\n", "") # Remove unwanted lines from the string - lines = cls._clean_lines(s.splitlines(), ignore_comments=ignore_comments) + lines = cls._clean_lines(contents.splitlines(), ignore_comments=ignore_comments) # Split the string into blocks based on the empty lines of the input file blocks = cls._get_blocks(lines, keep_stages=keep_stages) @@ -584,11 +586,11 @@ def from_str(cls, contents: str, ignore_comments: bool = False, keep_stages: boo if ignore_comments: keep_block = False else: - LIF._add_comment(comment=block[0][1:].strip(), inline=False) - stage_name = f"Comment {LIF.ncomments}" + lammps_in_file._add_comment(comment=block[0][1:].strip(), inline=False) + stage_name = f"Comment {lammps_in_file.ncomments}" if len(block) > 1: for line in block[1:]: - LIF._add_comment(comment=line[1:].strip(), inline=True, stage_name=stage_name) + lammps_in_file._add_comment(comment=line[1:].strip(), inline=True, stage_name=stage_name) # Header of a stage elif block[0][0] == "#" and keep_block: @@ -600,23 +602,21 @@ def from_str(cls, contents: str, ignore_comments: bool = False, keep_stages: boo n_comm_max = idx comments = block[:n_comm_max] - header = "" - for line in comments: - header += line[1:].strip() + " " + header = "".join(f"{line[1:].strip()} " for line in comments) header = header.strip() - stage_name = f"Stage {LIF.nstages + 1}" if (ignore_comments or not keep_stages) else header + stage_name = f"Stage {lammps_in_file.nstages + 1}" if (ignore_comments or not keep_stages) else header commands = block[n_comm_max:] - LIF.add_stage(commands=commands, stage_name=stage_name) + lammps_in_file.add_stage(commands=commands, stage_name=stage_name) # Stage with no header else: - stage_name = f"Stage {LIF.nstages + 1}" - LIF.add_stage(commands=block, stage_name=stage_name) - return LIF + stage_name = f"Stage {lammps_in_file.nstages + 1}" + lammps_in_file.add_stage(commands=block, stage_name=stage_name) + return lammps_in_file @classmethod - def from_file(cls, path: str | Path, ignore_comments: bool = False, keep_stages: bool = False) -> LammpsInputFile: + def from_file(cls, path: str | Path, ignore_comments: bool = False, keep_stages: bool = False) -> Self: # type: ignore[override] """ Creates an InputFile object from a file. @@ -756,7 +756,7 @@ def _add_comment( self.stages.append({"stage_name": stage_name, "commands": [("#", comment)]}) # Inline comment - elif inline and stage_name: + elif stage_name: command = "#" if index_comment: if "Comment" in comment and comment.strip()[9] == ":": @@ -921,7 +921,7 @@ def md( placeholders. """ template_path = os.path.join(template_dir, "md.template") - with open(template_path) as file: + with open(template_path, encoding="utf-8") as file: script_template = file.read() settings = other_settings.copy() if other_settings else {} settings.update({"force_field": force_field, "temperature": temperature, "nsteps": nsteps}) @@ -1068,9 +1068,9 @@ def write_lammps_inputs( os.makedirs(output_dir, exist_ok=True) with open(os.path.join(output_dir, script_filename), mode="w") as file: file.write(input_script) - read_data = re.search(r"read_data\s+(.*)\n", input_script) - if read_data: - data_filename = read_data.group(1).split()[0] + + if read_data := re.search(r"read_data\s+(.*)\n", input_script): + data_filename = read_data[1].split()[0] if isinstance(data, LammpsData): data.write_file(os.path.join(output_dir, data_filename), **kwargs) elif isinstance(data, str) and os.path.isfile(data): diff --git a/pymatgen/io/lammps/outputs.py b/pymatgen/io/lammps/outputs.py index 0fc72f638cc..7a2999ffafd 100644 --- a/pymatgen/io/lammps/outputs.py +++ b/pymatgen/io/lammps/outputs.py @@ -20,6 +20,8 @@ if TYPE_CHECKING: from typing import Any + from typing_extensions import Self + __author__ = "Kiran Mathew, Zhi Deng" __copyright__ = "Copyright 2018, The Materials Virtual Lab" __version__ = "1.0" @@ -47,7 +49,7 @@ def __init__(self, timestep: int, natoms: int, box: LammpsBox, data: pd.DataFram self.data = data @classmethod - def from_str(cls, string: str) -> LammpsDump: + def from_str(cls, string: str) -> Self: """ Constructor from string parsing. @@ -71,10 +73,10 @@ def from_str(cls, string: str) -> LammpsDump: return cls(time_step, n_atoms, box, data) @classmethod - def from_dict(cls, dct: dict) -> LammpsDump: + def from_dict(cls, dct: dict) -> Self: """ Args: - d (dict): Dict representation. + dct (dict): Dict representation. Returns: LammpsDump @@ -111,7 +113,7 @@ def parse_lammps_dumps(file_pattern): files = glob(file_pattern) if len(files) > 1: pattern = file_pattern.replace("*", "([0-9]+)").replace("\\", "\\\\") - files = sorted(files, key=lambda f: int(re.match(pattern, f).group(1))) + files = sorted(files, key=lambda f: int(re.match(pattern, f)[1])) for filename in files: with zopen(filename, mode="rt") as file: @@ -169,7 +171,7 @@ def _parse_thermo(lines: list[str]) -> pd.DataFrame: data = {} step = re.match(multi_pattern, ts[0]) assert step is not None - data["Step"] = int(step.group(1)) + data["Step"] = int(step[1]) data.update({k: float(v) for k, v in re.findall(kv_pattern, "".join(ts[1:]))}) dicts.append(data) df = pd.DataFrame(dicts) diff --git a/pymatgen/io/lammps/sets.py b/pymatgen/io/lammps/sets.py index d303b5ad30a..cb1cdb08e89 100644 --- a/pymatgen/io/lammps/sets.py +++ b/pymatgen/io/lammps/sets.py @@ -20,6 +20,8 @@ if TYPE_CHECKING: from pathlib import Path + from typing_extensions import Self + __author__ = "Ryan Kingsbury, Guillaume Brunin (Matgenix)" __copyright__ = "Copyright 2021, The Materials Project" __version__ = "0.2" @@ -73,7 +75,7 @@ def __init__( super().__init__(inputs={"in.lammps": self.inputfile, "system.data": self.data}) @classmethod - def from_directory(cls, directory: str | Path, keep_stages: bool = False) -> LammpsInputSet: + def from_directory(cls, directory: str | Path, keep_stages: bool = False) -> Self: # type: ignore[override] """ Construct a LammpsInputSet from a directory of two or more files. TODO: accept directories with only the input file, that should include the structure as well. @@ -88,7 +90,7 @@ def from_directory(cls, directory: str | Path, keep_stages: bool = False) -> Lam if isinstance(atom_style, list): raise ValueError("Variable atom_style is specified multiple times in the input file.") data_file = LammpsData.from_file(f"{directory}/system.data", atom_style=atom_style) - return LammpsInputSet(inputfile=input_file, data=data_file, calc_type="read_from_dir") + return cls(inputfile=input_file, data=data_file, calc_type="read_from_dir") def validate(self) -> bool: """ diff --git a/pymatgen/io/lammps/utils.py b/pymatgen/io/lammps/utils.py index b36539ce46f..64670889fb0 100644 --- a/pymatgen/io/lammps/utils.py +++ b/pymatgen/io/lammps/utils.py @@ -18,15 +18,16 @@ from pymatgen.io.packmol import PackmolBoxGen from pymatgen.util.coord import get_angle +try: + from openbabel import pybel +except ImportError: + pybel = None + if TYPE_CHECKING: from collections.abc import Sequence from numpy.typing import ArrayLike -try: - from openbabel import pybel -except ImportError: - pybel = None __author__ = "Kiran Mathew, Brandon Wood, Michael Humbert" __email__ = "kmathew@lbl.gov" @@ -493,4 +494,4 @@ def run(self) -> tuple[bytes, bytes]: auto_box=False, output_file="cocktail.xyz", ) - s = pmr.run() + mol = pmr.run() diff --git a/pymatgen/io/lmto.py b/pymatgen/io/lmto.py index 89740618824..d4cf2c35e96 100644 --- a/pymatgen/io/lmto.py +++ b/pymatgen/io/lmto.py @@ -7,7 +7,7 @@ from __future__ import annotations import re -from typing import no_type_check +from typing import TYPE_CHECKING, no_type_check import numpy as np from monty.io import zopen @@ -18,6 +18,11 @@ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from pymatgen.util.num import round_to_sigfigs +if TYPE_CHECKING: + from pathlib import Path + + from typing_extensions import Self + __author__ = "Marco Esters" __copyright__ = "Copyright 2017, The Materials Project" __version__ = "0.1" @@ -136,7 +141,7 @@ def write_file(self, filename="CTRL", **kwargs): file.write(self.get_str(**kwargs)) @classmethod - def from_file(cls, filename="CTRL", **kwargs): + def from_file(cls, filename: str | Path = "CTRL", **kwargs) -> Self: """ Creates a CTRL file object from an existing file. @@ -148,11 +153,11 @@ def from_file(cls, filename="CTRL", **kwargs): """ with zopen(filename, mode="rt") as file: contents = file.read() - return LMTOCtrl.from_str(contents, **kwargs) + return cls.from_str(contents, **kwargs) @classmethod @no_type_check - def from_str(cls, data: str, sigfigs: int = 8) -> LMTOCtrl: + def from_str(cls, data: str, sigfigs: int = 8) -> Self: """ Creates a CTRL file object from a string. This will mostly be used to read an LMTOCtrl object from a CTRL file. Empty spheres @@ -181,13 +186,13 @@ def from_str(cls, data: str, sigfigs: int = 8) -> LMTOCtrl: for cat in ["STRUC", "CLASS", "SITE"]: fields = struct_lines[cat].split("=") - for f, field in enumerate(fields): + for idx, field in enumerate(fields, start=1): token = field.split()[-1] if token == "ALAT": - alat = round(float(fields[f + 1].split()[0]), sigfigs) - structure_tokens["ALAT"] = alat + a_lat = round(float(fields[idx].split()[0]), sigfigs) + structure_tokens["ALAT"] = a_lat elif token == "ATOM": - atom = fields[f + 1].split()[0] + atom = fields[idx].split()[0] if not bool(re.match("E[0-9]*$", atom)): if cat == "CLASS": structure_tokens["CLASS"].append(atom) @@ -197,9 +202,9 @@ def from_str(cls, data: str, sigfigs: int = 8) -> LMTOCtrl: pass elif token in ["PLAT", "POS"]: try: - arr = np.array([round(float(i), sigfigs) for i in fields[f + 1].split()]) + arr = np.array([round(float(i), sigfigs) for i in fields[idx].split()]) except ValueError: - arr = np.array([round(float(i), sigfigs) for i in fields[f + 1].split()[:-1]]) + arr = np.array([round(float(i), sigfigs) for i in fields[idx].split()[:-1]]) if token == "PLAT": structure_tokens["PLAT"] = arr.reshape([3, 3]) elif not bool(re.match("E[0-9]*$", atom)): @@ -209,9 +214,9 @@ def from_str(cls, data: str, sigfigs: int = 8) -> LMTOCtrl: else: pass try: - spcgrp_index = struct_lines["SYMGRP"].index("SPCGRP") - spcgrp = struct_lines["SYMGRP"][spcgrp_index : spcgrp_index + 12] - structure_tokens["SPCGRP"] = spcgrp.split("=")[1].split()[0] + spc_grp_index = struct_lines["SYMGRP"].index("SPCGRP") + spc_grp = struct_lines["SYMGRP"][spc_grp_index : spc_grp_index + 12] + structure_tokens["SPCGRP"] = spc_grp.split("=")[1].split()[0] except ValueError: pass @@ -224,7 +229,7 @@ def from_str(cls, data: str, sigfigs: int = 8) -> LMTOCtrl: return LMTOCtrl.from_dict(structure_tokens) @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """ Creates a CTRL file object from a dictionary. The dictionary must contain the items "ALAT", PLAT" and "SITE". diff --git a/pymatgen/io/lobster/inputs.py b/pymatgen/io/lobster/inputs.py index d2d619f16a7..f0ec0d41ae1 100644 --- a/pymatgen/io/lobster/inputs.py +++ b/pymatgen/io/lobster/inputs.py @@ -32,6 +32,8 @@ if TYPE_CHECKING: from collections.abc import Sequence + from typing_extensions import Self + from pymatgen.core.composition import Composition @@ -129,7 +131,7 @@ def __init__(self, settingsdict: dict): # check for duplicates keys = [key.lower() for key in settingsdict] if len(keys) != len(set(keys)): - raise OSError("There are duplicates for the keywords! The program will stop here.") + raise KeyError("There are duplicates for the keywords!") self.update(settingsdict) def __setitem__(self, key, val): @@ -139,32 +141,26 @@ def __setitem__(self, key, val): leading and trailing white spaces. Similar to INCAR class. """ # due to the missing case sensitivity of lobster, the following code is necessary - found = False - for key_here in self: - if key.strip().lower() == key_here.lower(): - new_key = key_here - found = True - if not found: - new_key = key + new_key = next((key_here for key_here in self if key.strip().lower() == key_here.lower()), key) + if new_key.lower() not in [element.lower() for element in Lobsterin.AVAILABLE_KEYWORDS]: - raise ValueError("Key is currently not available") + raise KeyError("Key is currently not available") super().__setitem__(new_key, val.strip() if isinstance(val, str) else val) def __getitem__(self, item): """Implements getitem from dict to avoid problems with cases.""" - found = False - for key_here in self: - if item.strip().lower() == key_here.lower(): - new_key = key_here - found = True - if not found: - new_key = item + new_item = next((key_here for key_here in self if item.strip().lower() == key_here.lower()), item) + + if new_item.lower() not in [element.lower() for element in Lobsterin.AVAILABLE_KEYWORDS]: + raise KeyError("Key is currently not available") - return super().__getitem__(new_key) + return super().__getitem__(new_item) def __delitem__(self, key): - del self.data[key.lower()] + new_key = next((key_here for key_here in self if key.strip().lower() == key_here.lower()), key) + + del self.data[new_key] def diff(self, other): """ @@ -182,50 +178,47 @@ def diff(self, other): key_list_others = [element.lower() for element in other] for k1, v1 in self.items(): - k1lower = k1.lower() - if k1lower not in key_list_others: - different_param[k1.upper()] = {"lobsterin1": v1, "lobsterin2": None} - else: - for key_here in other: - if k1.lower() == key_here.lower(): - new_key = key_here - - if isinstance(v1, str): - if v1.strip().lower() != other[new_key].strip().lower(): - different_param[k1.upper()] = { - "lobsterin1": v1, - "lobsterin2": other[new_key], - } - else: - similar_param[k1.upper()] = v1 - elif isinstance(v1, list): - new_set1 = {element.strip().lower() for element in v1} - new_set2 = {element.strip().lower() for element in other[new_key]} - if new_set1 != new_set2: - different_param[k1.upper()] = { - "lobsterin1": v1, - "lobsterin2": other[new_key], - } - elif v1 != other[new_key]: - different_param[k1.upper()] = { + k1_lower = k1.lower() + k1_in_other = next((key_here for key_here in other if key_here.lower() == k1_lower), k1_lower) + if k1_lower not in key_list_others: + different_param[k1.lower()] = {"lobsterin1": v1, "lobsterin2": None} + elif isinstance(v1, str): + if v1.strip().lower() != other[k1_lower].strip().lower(): + different_param[k1.lower()] = { "lobsterin1": v1, - "lobsterin2": other[new_key], + "lobsterin2": other[k1_in_other], } else: - similar_param[k1.upper()] = v1 + similar_param[k1.lower()] = v1 + elif isinstance(v1, list): + new_set1 = {element.strip().lower() for element in v1} + new_set2 = {element.strip().lower() for element in other[k1_in_other]} + if new_set1 != new_set2: + different_param[k1.lower()] = { + "lobsterin1": v1, + "lobsterin2": other[k1_in_other], + } + elif v1 != other[k1_lower]: + different_param[k1.lower()] = { + "lobsterin1": v1, + "lobsterin2": other[k1_in_other], + } + else: + similar_param[k1.lower()] = v1 for k2, v2 in other.items(): - if k2.upper() not in similar_param and k2.upper() not in different_param: - for key_here in self: - new_key = key_here if k2.lower() == key_here.lower() else k2 - if new_key not in self: - different_param[k2.upper()] = {"lobsterin1": None, "lobsterin2": v2} + if ( + k2.lower() not in similar_param + and k2.lower() not in different_param + and k2.lower() not in [key.lower() for key in self] + ): + different_param[k2.lower()] = {"lobsterin1": None, "lobsterin2": v2} return {"Same": similar_param, "Different": different_param} def _get_nbands(self, structure: Structure): """Get number of bands.""" if self.get("basisfunctions") is None: - raise OSError("No basis functions are provided. The program cannot calculate nbands.") + raise ValueError("No basis functions are provided. The program cannot calculate nbands.") basis_functions: list[str] = [] for string_basis in self["basisfunctions"]: @@ -261,13 +254,10 @@ def write_lobsterin(self, path="lobsterin", overwritedict=None): # has to search first if entry is already in Lobsterindict (due to case insensitivity) if overwritedict is not None: for key, entry in overwritedict.items(): - found = False + self[key] = entry for key2 in self: if key.lower() == key2.lower(): self[key2] = entry - found = True - if not found: - self[key] = entry filename = path @@ -280,9 +270,7 @@ def write_lobsterin(self, path="lobsterin", overwritedict=None): # checks if entry is True or False for key_here in self: if key.lower() == key_here.lower(): - new_key = key_here - if self.get(new_key): - file.write(key + "\n") + file.write(key + "\n") elif key.lower() in [element.lower() for element in Lobsterin.STRING_KEYWORDS]: file.write(f"{key} {self.get(key)}\n") elif key.lower() in [element.lower() for element in Lobsterin.LISTKEYWORDS]: @@ -297,7 +285,7 @@ def as_dict(self): return dct @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """ Args: dct (dict): Dict representation. @@ -305,7 +293,7 @@ def from_dict(cls, dct): Returns: Lobsterin """ - return Lobsterin({k: v for k, v in dct.items() if k not in ["@module", "@class"]}) + return cls({k: v for k, v in dct.items() if k not in ["@module", "@class"]}) def write_INCAR( self, @@ -365,18 +353,18 @@ def get_basis( atom_types_potcar = [name.split("_")[0] for name in potcar_names] if set(structure.symbol_set) != set(atom_types_potcar): - raise OSError("Your POSCAR does not correspond to your POTCAR!") - BASIS = loadfn(address_basis_file)["BASIS"] + raise ValueError("Your POSCAR does not correspond to your POTCAR!") + basis = loadfn(address_basis_file)["BASIS"] basis_functions = [] list_forin = [] - for idx, basis in enumerate(potcar_names): - if basis not in BASIS: + for idx, name in enumerate(potcar_names): + if name not in basis: raise ValueError( - f"You have to provide the basis for {basis} manually. We don't have any information on this POTCAR." + f"Missing basis information for POTCAR symbol: {name}. Please provide the basis manually." ) - basis_functions.append(BASIS[basis].split()) - list_forin.append(f"{atom_types_potcar[idx]} {BASIS[basis]}") + basis_functions.append(basis[name].split()) + list_forin.append(f"{atom_types_potcar[idx]} {basis[name]}") return list_forin @staticmethod @@ -472,7 +460,7 @@ def write_KPOINTS( # The following code is taken from: SpacegroupAnalyzer # we need to switch off symmetry here - latt = structure.lattice.matrix + matrix = structure.lattice.matrix positions = structure.frac_coords unique_species: list[Composition] = [] zs = [] @@ -495,7 +483,7 @@ def write_KPOINTS( magmoms.append(0) # For now, we are setting magmom to zero. (Taken from INCAR class) - cell = latt, positions, zs, magmoms + cell = matrix, positions, zs, magmoms # TODO: what about this shift? mapping, grid = spglib.get_ir_reciprocal_mesh(mesh, cell, is_shift=[0, 0, 0]) @@ -549,10 +537,9 @@ def write_KPOINTS( kpts.append(f) weights.append(0.0) all_labels.append(labels[k]) - ISYM = isym - comment = f"{ISYM=}, grid: {mesh} plus kpoint path" if line_mode else f"{ISYM=}, grid: {mesh}" + comment = f"{isym=}, grid: {mesh} plus kpoint path" if line_mode else f"{isym=}, grid: {mesh}" - KpointObject = Kpoints( + kpoint_object = Kpoints( comment=comment, style=Kpoints.supported_modes.Reciprocal, num_kpts=len(kpts), @@ -561,10 +548,10 @@ def write_KPOINTS( labels=all_labels, ) - KpointObject.write_file(filename=KPOINTS_output) + kpoint_object.write_file(filename=KPOINTS_output) @classmethod - def from_file(cls, lobsterin: str): + def from_file(cls, lobsterin: str) -> Self: """ Args: lobsterin (str): path to lobsterin. @@ -575,8 +562,8 @@ def from_file(cls, lobsterin: str): with zopen(lobsterin, mode="rt") as file: data = file.read().split("\n") if len(data) == 0: - raise OSError("lobsterin file contains no data.") - Lobsterindict: dict[str, Any] = {} + raise RuntimeError("lobsterin file contains no data.") + lobsterin_dict: dict[str, Any] = {} for datum in data: # Remove all comments @@ -589,22 +576,22 @@ def from_file(cls, lobsterin: str): # check which type of keyword this is, handle accordingly if key_word[0].lower() not in [datum2.lower() for datum2 in Lobsterin.LISTKEYWORDS]: if key_word[0].lower() not in [datum2.lower() for datum2 in Lobsterin.FLOAT_KEYWORDS]: - if key_word[0].lower() not in Lobsterindict: - Lobsterindict[key_word[0].lower()] = " ".join(key_word[1:]) + if key_word[0].lower() not in lobsterin_dict: + lobsterin_dict[key_word[0].lower()] = " ".join(key_word[1:]) else: raise ValueError(f"Same keyword {key_word[0].lower()} twice!") - elif key_word[0].lower() not in Lobsterindict: - Lobsterindict[key_word[0].lower()] = float(key_word[1]) + elif key_word[0].lower() not in lobsterin_dict: + lobsterin_dict[key_word[0].lower()] = float(key_word[1]) else: raise ValueError(f"Same keyword {key_word[0].lower()} twice!") - elif key_word[0].lower() not in Lobsterindict: - Lobsterindict[key_word[0].lower()] = [" ".join(key_word[1:])] + elif key_word[0].lower() not in lobsterin_dict: + lobsterin_dict[key_word[0].lower()] = [" ".join(key_word[1:])] else: - Lobsterindict[key_word[0].lower()].append(" ".join(key_word[1:])) + lobsterin_dict[key_word[0].lower()].append(" ".join(key_word[1:])) elif len(key_word) > 0: - Lobsterindict[key_word[0].lower()] = True + lobsterin_dict[key_word[0].lower()] = True - return cls(Lobsterindict) + return cls(lobsterin_dict) @staticmethod def _get_potcar_symbols(POTCAR_input: str) -> list: @@ -620,7 +607,7 @@ def _get_potcar_symbols(POTCAR_input: str) -> list: potcar = Potcar.from_file(POTCAR_input) for pot in potcar: if pot.potential_type != "PAW": - raise OSError("Lobster only works with PAW! Use different POTCARs") + raise ValueError("Lobster only works with PAW! Use different POTCARs") # Warning about a bug in lobster-4.1.0 with zopen(POTCAR_input, mode="r") as file: @@ -637,7 +624,7 @@ def _get_potcar_symbols(POTCAR_input: str) -> list: ) if potcar.functional != "PBE": - raise OSError("We only have BASIS options for PBE so far") + raise RuntimeError("We only have BASIS options for PBE so far") return [name["symbol"] for name in potcar.spec] @@ -696,7 +683,7 @@ def standard_calculations_from_vasp_files( ]: raise ValueError("The option is not valid!") - Lobsterindict: dict[str, Any] = { + lobsterin_dict: dict[str, Any] = { # this basis set covers most elements "basisSet": "pbeVaspFit2015", # energies around e-fermi @@ -715,95 +702,95 @@ def standard_calculations_from_vasp_files( "standard_with_fatband", }: # every interaction with a distance of 6.0 is checked - Lobsterindict["cohpGenerator"] = "from 0.1 to 6.0 orbitalwise" + lobsterin_dict["cohpGenerator"] = "from 0.1 to 6.0 orbitalwise" # the projection is saved - Lobsterindict["saveProjectionToFile"] = True + lobsterin_dict["saveProjectionToFile"] = True if option == "standard_from_projection": - Lobsterindict["cohpGenerator"] = "from 0.1 to 6.0 orbitalwise" - Lobsterindict["loadProjectionFromFile"] = True + lobsterin_dict["cohpGenerator"] = "from 0.1 to 6.0 orbitalwise" + lobsterin_dict["loadProjectionFromFile"] = True if option == "standard_with_energy_range_from_vasprun": Vr = Vasprun(Vasprun_output) - Lobsterindict["COHPstartEnergy"] = round(min(Vr.complete_dos.energies - Vr.complete_dos.efermi), 4) - Lobsterindict["COHPendEnergy"] = round(max(Vr.complete_dos.energies - Vr.complete_dos.efermi), 4) - Lobsterindict["COHPSteps"] = len(Vr.complete_dos.energies) + lobsterin_dict["COHPstartEnergy"] = round(min(Vr.complete_dos.energies - Vr.complete_dos.efermi), 4) + lobsterin_dict["COHPendEnergy"] = round(max(Vr.complete_dos.energies - Vr.complete_dos.efermi), 4) + lobsterin_dict["COHPSteps"] = len(Vr.complete_dos.energies) # TODO: add cobi here! might be relevant lobster version if option == "onlycohp": - Lobsterindict["skipdos"] = True - Lobsterindict["skipcoop"] = True - Lobsterindict["skipPopulationAnalysis"] = True - Lobsterindict["skipGrossPopulation"] = True + lobsterin_dict["skipdos"] = True + lobsterin_dict["skipcoop"] = True + lobsterin_dict["skipPopulationAnalysis"] = True + lobsterin_dict["skipGrossPopulation"] = True # lobster-4.1.0 - Lobsterindict["skipcobi"] = True - Lobsterindict["skipMadelungEnergy"] = True + lobsterin_dict["skipcobi"] = True + lobsterin_dict["skipMadelungEnergy"] = True if option == "onlycoop": - Lobsterindict["skipdos"] = True - Lobsterindict["skipcohp"] = True - Lobsterindict["skipPopulationAnalysis"] = True - Lobsterindict["skipGrossPopulation"] = True + lobsterin_dict["skipdos"] = True + lobsterin_dict["skipcohp"] = True + lobsterin_dict["skipPopulationAnalysis"] = True + lobsterin_dict["skipGrossPopulation"] = True # lobster-4.1.0 - Lobsterindict["skipcobi"] = True - Lobsterindict["skipMadelungEnergy"] = True + lobsterin_dict["skipcobi"] = True + lobsterin_dict["skipMadelungEnergy"] = True if option == "onlycohpcoop": - Lobsterindict["skipdos"] = True - Lobsterindict["skipPopulationAnalysis"] = True - Lobsterindict["skipGrossPopulation"] = True + lobsterin_dict["skipdos"] = True + lobsterin_dict["skipPopulationAnalysis"] = True + lobsterin_dict["skipGrossPopulation"] = True # lobster-4.1.0 - Lobsterindict["skipcobi"] = True - Lobsterindict["skipMadelungEnergy"] = True + lobsterin_dict["skipcobi"] = True + lobsterin_dict["skipMadelungEnergy"] = True if option == "onlycohpcoopcobi": - Lobsterindict["skipdos"] = True - Lobsterindict["skipPopulationAnalysis"] = True - Lobsterindict["skipGrossPopulation"] = True - Lobsterindict["skipMadelungEnergy"] = True + lobsterin_dict["skipdos"] = True + lobsterin_dict["skipPopulationAnalysis"] = True + lobsterin_dict["skipGrossPopulation"] = True + lobsterin_dict["skipMadelungEnergy"] = True if option == "onlydos": - Lobsterindict["skipcohp"] = True - Lobsterindict["skipcoop"] = True - Lobsterindict["skipPopulationAnalysis"] = True - Lobsterindict["skipGrossPopulation"] = True + lobsterin_dict["skipcohp"] = True + lobsterin_dict["skipcoop"] = True + lobsterin_dict["skipPopulationAnalysis"] = True + lobsterin_dict["skipGrossPopulation"] = True # lobster-4.1.0 - Lobsterindict["skipcobi"] = True - Lobsterindict["skipMadelungEnergy"] = True + lobsterin_dict["skipcobi"] = True + lobsterin_dict["skipMadelungEnergy"] = True if option == "onlyprojection": - Lobsterindict["skipdos"] = True - Lobsterindict["skipcohp"] = True - Lobsterindict["skipcoop"] = True - Lobsterindict["skipPopulationAnalysis"] = True - Lobsterindict["skipGrossPopulation"] = True - Lobsterindict["saveProjectionToFile"] = True + lobsterin_dict["skipdos"] = True + lobsterin_dict["skipcohp"] = True + lobsterin_dict["skipcoop"] = True + lobsterin_dict["skipPopulationAnalysis"] = True + lobsterin_dict["skipGrossPopulation"] = True + lobsterin_dict["saveProjectionToFile"] = True # lobster-4.1.0 - Lobsterindict["skipcobi"] = True - Lobsterindict["skipMadelungEnergy"] = True + lobsterin_dict["skipcobi"] = True + lobsterin_dict["skipMadelungEnergy"] = True if option == "onlycobi": - Lobsterindict["skipdos"] = True - Lobsterindict["skipcohp"] = True - Lobsterindict["skipPopulationAnalysis"] = True - Lobsterindict["skipGrossPopulation"] = True + lobsterin_dict["skipdos"] = True + lobsterin_dict["skipcohp"] = True + lobsterin_dict["skipPopulationAnalysis"] = True + lobsterin_dict["skipGrossPopulation"] = True # lobster-4.1.0 - Lobsterindict["skipcobi"] = True - Lobsterindict["skipMadelungEnergy"] = True + lobsterin_dict["skipcobi"] = True + lobsterin_dict["skipMadelungEnergy"] = True if option == "onlymadelung": - Lobsterindict["skipdos"] = True - Lobsterindict["skipcohp"] = True - Lobsterindict["skipcoop"] = True - Lobsterindict["skipPopulationAnalysis"] = True - Lobsterindict["skipGrossPopulation"] = True - Lobsterindict["saveProjectionToFile"] = True + lobsterin_dict["skipdos"] = True + lobsterin_dict["skipcohp"] = True + lobsterin_dict["skipcoop"] = True + lobsterin_dict["skipPopulationAnalysis"] = True + lobsterin_dict["skipGrossPopulation"] = True + lobsterin_dict["saveProjectionToFile"] = True # lobster-4.1.0 - Lobsterindict["skipcobi"] = True + lobsterin_dict["skipcobi"] = True incar = Incar.from_file(INCAR_input) if incar["ISMEAR"] == 0: - Lobsterindict["gaussianSmearingWidth"] = incar["SIGMA"] + lobsterin_dict["gaussianSmearingWidth"] = incar["SIGMA"] if incar["ISMEAR"] != 0 and option == "standard_with_fatband": raise ValueError("ISMEAR has to be 0 for a fatband calculation with Lobster") if dict_for_basis is not None: @@ -817,11 +804,11 @@ def standard_calculations_from_vasp_files( basis = Lobsterin.get_basis(structure=Structure.from_file(POSCAR_input), potcar_symbols=potcar_names) else: raise ValueError("basis cannot be generated") - Lobsterindict["basisfunctions"] = basis + lobsterin_dict["basisfunctions"] = basis if option == "standard_with_fatband": - Lobsterindict["createFatband"] = basis + lobsterin_dict["createFatband"] = basis - return cls(Lobsterindict) + return cls(lobsterin_dict) def get_all_possible_basis_combinations(min_basis: list, max_basis: list) -> list: diff --git a/pymatgen/io/lobster/lobsterenv.py b/pymatgen/io/lobster/lobsterenv.py index a15ee86a9ab..342b7bc0917 100644 --- a/pymatgen/io/lobster/lobsterenv.py +++ b/pymatgen/io/lobster/lobsterenv.py @@ -32,6 +32,8 @@ from pymatgen.util.due import Doi, due if TYPE_CHECKING: + from typing_extensions import Self + from pymatgen.core import Structure __author__ = "Janine George" @@ -79,7 +81,6 @@ def __init__( id_blist_sg2: str = "ICOBI", ) -> None: """ - Args: filename_icohp: (str) Path to ICOHPLIST.lobster or ICOOPLIST.lobster or ICOBILIST.lobster obj_icohp: Icohplist object @@ -161,6 +162,8 @@ def __init__( elif self.id_blist_sg2.lower() == "icobi": are_coops_id2 = False are_cobis_id2 = True + else: + raise ValueError("only icoops and icobis can be added") self.bonding_list_2 = Icohplist( filename=self.filename_blist_sg2, @@ -573,7 +576,7 @@ def get_info_cohps_to_neighbors( if present: new_labels.append(key) new_atoms.append(atompair) - if len(new_labels) > 0: + if new_labels: divisor = len(new_labels) if per_bond else 1 plot_label = self._get_plot_label(new_atoms, per_bond) @@ -595,7 +598,7 @@ def _get_plot_label(self, atoms, per_bond): for atoms_names in atoms: new = [self._split_string(atoms_names[0])[0], self._split_string(atoms_names[1])[0]] new.sort() - string_here = new[0] + "-" + new[1] + string_here = f"{new[0]}-{new[1]}" all_labels.append(string_here) count = collections.Counter(all_labels) plotlabels = [] @@ -614,7 +617,7 @@ def get_info_icohps_between_neighbors(self, isites=None, onlycation_isites=True) isites: list of site ids, if isite==None, all isites will be used onlycation_isites: will only use cations, if isite==None - Returns + Returns: ICOHPNeighborsInfo """ lowerlimit = self.lowerlimit @@ -733,7 +736,7 @@ def _evaluate_ce( additional_condition=additional_condition, ) - elif lowerlimit is None and (upperlimit is not None or lowerlimit is not None): + elif upperlimit is None or lowerlimit is None: raise ValueError("Please give two limits or leave them both at None") # find environments based on ICOHP values @@ -844,7 +847,7 @@ def _find_environments(self, additional_condition, lowerlimit, upperlimit, only_ list_icohps = [] list_lengths = [] list_keys = [] - for idx in range(len(self.structure)): + for idx, site in enumerate(self.structure): icohps = self._get_icohps( icohpcollection=self.Icohpcollection, isite=idx, @@ -857,7 +860,7 @@ def _find_environments(self, additional_condition, lowerlimit, upperlimit, only_ keys_from_ICOHPs, lengths_from_ICOHPs, neighbors_from_ICOHPs, selected_ICOHPs = additional_conds if len(neighbors_from_ICOHPs) > 0: - centralsite = self.structure[idx] + centralsite = site neighbors_by_distance_start = self.structure.get_sites_in_sphere( pt=centralsite.coords, @@ -959,6 +962,7 @@ def _find_relevant_atoms_additional_condition(self, isite, icohps, additional_co atomnr2 = self._get_atomnumber(icohp._atom2) # test additional conditions + val1 = val2 = None if additional_condition in (1, 3, 5, 6): val1 = self.valences[atomnr1] val2 = self.valences[atomnr2] @@ -1167,6 +1171,7 @@ def _get_limit_from_extremum( tuple[float, float]: [-inf, min(strongest_icohp*0.15,-noise_cutoff)] / [max(strongest_icohp*0.15, noise_cutoff), inf] """ + extremum_based = None if not adapt_extremum_to_add_cond or additional_condition == 0: extremum_based = icohpcollection.extremum_icohpvalue(summed_spin_channels=True) * percentage @@ -1268,7 +1273,7 @@ def from_Lobster( list_neighisite, structure: Structure, valences=None, - ): + ) -> Self: """ Will set up a LightStructureEnvironments from Lobster. diff --git a/pymatgen/io/lobster/outputs.py b/pymatgen/io/lobster/outputs.py index f670c5206e4..202854d64fa 100644 --- a/pymatgen/io/lobster/outputs.py +++ b/pymatgen/io/lobster/outputs.py @@ -12,6 +12,7 @@ import collections import fnmatch +import itertools import os import re import warnings @@ -116,7 +117,7 @@ def __init__( # contains all parameters that are needed to map the file. parameters = contents[1].split() # Subtract 1 to skip the average - num_bonds = int(parameters[0]) - 1 if not self.are_multi_center_cobis else int(parameters[0]) + num_bonds = int(parameters[0]) if self.are_multi_center_cobis else int(parameters[0]) - 1 self.efermi = float(parameters[-1]) self.is_spin_polarized = int(parameters[1]) == 2 spins = [Spin.up, Spin.down] if int(parameters[1]) == 2 else [Spin.up] @@ -124,7 +125,6 @@ def __init__( if not self.are_multi_center_cobis: # The COHP data start in row num_bonds + 3 data = np.array([np.array(row.split(), dtype=float) for row in contents[num_bonds + 3 :]]).transpose() - self.energies = data[0] cohp_data = { "average": { "COHP": {spin: data[1 + 2 * s * (num_bonds + 1)] for s, spin in enumerate(spins)}, @@ -134,7 +134,8 @@ def __init__( else: # The COBI data start in row num_bonds + 3 if multi-center cobis exist data = np.array([np.array(row.split(), dtype=float) for row in contents[num_bonds + 3 :]]).transpose() - self.energies = data[0] + + self.energies = data[0] orb_cohp: dict[str, Any] = {} # present for Lobster versions older than Lobster 2.2.0 @@ -385,7 +386,7 @@ def __init__( with zopen(filename, mode="rt") as file: data = file.read().split("\n")[1:-1] if len(data) == 0: - raise OSError("ICOHPLIST file contains no data.") + raise RuntimeError("ICOHPLIST file contains no data.") # Which Lobster version? if len(data[0].split()) == 8: @@ -422,7 +423,7 @@ def __init__( # TODO: adapt this for orbital-wise stuff n_bonds = len(data_without_orbitals) // 2 if n_bonds == 0: - raise OSError("ICOHPLIST file contains no data.") + raise RuntimeError("ICOHPLIST file contains no data.") else: n_bonds = len(data_without_orbitals) @@ -549,7 +550,7 @@ def __init__(self, filename: str | None = "NcICOBILIST.lobster"): # LOBSTER < 4 with zopen(filename, mode="rt") as file: # type:ignore data = file.read().split("\n")[1:-1] if len(data) == 0: - raise OSError("NcICOBILIST file contains no data.") + raise RuntimeError("NcICOBILIST file contains no data.") # If the calculation is spin-polarized, the line in the middle # of the file will be another header line. @@ -579,7 +580,7 @@ def __init__(self, filename: str | None = "NcICOBILIST.lobster"): # LOBSTER < 4 # TODO: adapt this for orbitalwise case n_bonds = len(data_without_orbitals) // 2 if n_bonds == 0: - raise OSError("NcICOBILIST file contains no data.") + raise RuntimeError("NcICOBILIST file contains no data.") else: n_bonds = len(data_without_orbitals) @@ -674,11 +675,11 @@ def _parse_doscar(self): tdensities = {} itdensities = {} with zopen(doscar, mode="rt") as file: - natoms = int(file.readline().split()[0]) + n_atoms = int(file.readline().split()[0]) efermi = float([file.readline() for nn in range(4)][3].split()[17]) dos = [] orbitals = [] - for _atom in range(natoms + 1): + for _atom in range(n_atoms + 1): line = file.readline() ndos = int(line.split()[2]) orbitals += [line.split(";")[-1].split()] @@ -702,7 +703,7 @@ def _parse_doscar(self): itdensities[Spin.up] = doshere[:, 2] pdoss = [] spin = Spin.up - for atom in range(natoms): + for atom in range(n_atoms): pdos = defaultdict(dict) data = dos[atom + 1] _, ncol = data.shape @@ -718,7 +719,7 @@ def _parse_doscar(self): itdensities[Spin.up] = doshere[:, 3] itdensities[Spin.down] = doshere[:, 4] pdoss = [] - for atom in range(natoms): + for atom in range(n_atoms): pdos = defaultdict(dict) data = dos[atom + 1] _, ncol = data.shape @@ -820,7 +821,7 @@ def __init__( with zopen(filename, mode="rt") as file: data = file.read().split("\n")[3:-3] if len(data) == 0: - raise OSError("CHARGES file contains no data.") + raise RuntimeError("CHARGES file contains no data.") self.num_atoms = len(data) for atom in range(self.num_atoms): @@ -943,7 +944,7 @@ def __init__(self, filename: str | None, **kwargs) -> None: with zopen(filename, mode="rt") as file: # read in file data = file.read().split("\n") if len(data) == 0: - raise OSError("lobsterout does not contain any data") + raise RuntimeError("lobsterout does not contain any data") # check if Lobster starts from a projection self.is_restart_from_projection = "loading projection from projectionData.lobster..." in data @@ -964,11 +965,11 @@ def __init__(self, filename: str | None, **kwargs) -> None: self.basis_functions = basisfunctions wall_time, user_time, sys_time = self._get_timing(data=data) - timing = {} - timing["wall_time"] = wall_time - timing["user_time"] = user_time - timing["sys_time"] = sys_time - self.timing = timing + self.timing = { + "wall_time": wall_time, + "user_time": user_time, + "sys_time": sys_time, + } warninglines = self._get_all_warning_lines(data=data) self.warning_lines = warninglines @@ -1084,15 +1085,13 @@ def _get_dft_program(data): @staticmethod def _get_number_of_spins(data): - if "spillings for spin channel 2" in data: - return 2 - return 1 + return 2 if "spillings for spin channel 2" in data else 1 @staticmethod def _get_threads(data): for row in data: splitrow = row.split() - if len(splitrow) > 11 and ((splitrow[11]) == "threads" or (splitrow[11] == "thread")): + if len(splitrow) > 11 and splitrow[11] in {"threads", "thread"}: return splitrow[10] raise ValueError("Threads not found.") @@ -1157,9 +1156,9 @@ def _get_timing(data): if "wall" in splitrow: wall_time = splitrow[2:10] if "user" in splitrow: - user_time = splitrow[0:8] + user_time = splitrow[:8] if "sys" in splitrow: - sys_time = splitrow[0:8] + sys_time = splitrow[:8] wall_time_dict = {"h": wall_time[0], "min": wall_time[2], "s": wall_time[4], "ms": wall_time[6]} user_time_dict = {"h": user_time[0], "min": user_time[2], "s": user_time[4], "ms": user_time[6]} @@ -1267,6 +1266,7 @@ def __init__( atom_type = [] atom_names = [] orbital_names = [] + parameters = [] if not isinstance(filenames, list) or filenames is None: filenames_new = [] @@ -1308,7 +1308,9 @@ def __init__( "present" ) - kpoints_array = [] + kpoints_array: list = [] + eigenvals: dict = {} + p_eigenvals: dict = {} for ifilename, filename in enumerate(filenames): with zopen(filename, mode="rt") as file: contents = file.read().split("\n") @@ -1331,7 +1333,7 @@ def __init__( self.is_spinpolarized = len(linenumbers) == 2 if ifilename == 0: - eigenvals = {} # type: dict + eigenvals = {} eigenvals[Spin.up] = [ [collections.defaultdict(float) for _ in range(self.number_kpts)] for _ in range(self.nbands) ] @@ -1340,12 +1342,12 @@ def __init__( [collections.defaultdict(float) for _ in range(self.number_kpts)] for _ in range(self.nbands) ] - p_eigenvals = {} # type: dict + p_eigenvals = {} p_eigenvals[Spin.up] = [ [ { - str(e): {str(orb): collections.defaultdict(float) for orb in atom_orbital_dict[e]} - for e in atom_names + str(elem): {str(orb): collections.defaultdict(float) for orb in atom_orbital_dict[elem]} + for elem in atom_names } for _ in range(self.number_kpts) ] @@ -1356,8 +1358,8 @@ def __init__( p_eigenvals[Spin.down] = [ [ { - str(e): {str(orb): collections.defaultdict(float) for orb in atom_orbital_dict[e]} - for e in atom_names + str(elem): {str(orb): collections.defaultdict(float) for orb in atom_orbital_dict[elem]} + for elem in atom_names } for _ in range(self.number_kpts) ] @@ -1366,6 +1368,7 @@ def __init__( idx_kpt = -1 linenumber = 0 + iband = 0 for line in contents[1:-1]: if line.split()[0] == "#": KPOINT = np.array( @@ -1486,7 +1489,7 @@ def _read(self, contents: list, spin_numbers: list): kpoint = line.split(" ") kpoint_array = [] for kpointel in kpoint: - if kpointel not in ["at", "k-point", ""]: + if kpointel not in {"at", "k-point", ""}: kpoint_array += [float(kpointel)] elif "maxDeviation" in line: @@ -1559,15 +1562,16 @@ def has_good_quality_check_occupied_bands( for matrix in self.band_overlaps_dict[Spin.down]["matrices"]: for iband1, band1 in enumerate(matrix): for iband2, band2 in enumerate(band1): - if number_occ_bands_spin_down is not None: - if iband1 < number_occ_bands_spin_down and iband2 < number_occ_bands_spin_down: - if iband1 == iband2: - if abs(band2 - 1.0).all() > limit_deviation: - return False - elif band2.all() > limit_deviation: + if number_occ_bands_spin_down is None: + raise ValueError("number_occ_bands_spin_down has to be specified") + + if iband1 < number_occ_bands_spin_down and iband2 < number_occ_bands_spin_down: + if iband1 == iband2: + if abs(band2 - 1.0).all() > limit_deviation: return False - else: - ValueError("number_occ_bands_spin_down has to be specified") + elif band2.all() > limit_deviation: + return False + return True @property @@ -1686,10 +1690,9 @@ def _parse_file(filename): real += [float(splitline[4])] imaginary += [float(splitline[5])] - if not len(real) == grid[0] * grid[1] * grid[2]: - raise ValueError("Something went wrong while reading the file") - if not len(imaginary) == grid[0] * grid[1] * grid[2]: + if len(real) != grid[0] * grid[1] * grid[2] or len(imaginary) != grid[0] * grid[1] * grid[2]: raise ValueError("Something went wrong while reading the file") + return grid, points, real, imaginary, distance def set_volumetric_data(self, grid, structure): @@ -1713,36 +1716,31 @@ def set_volumetric_data(self, grid, structure): new_imaginary = [] new_density = [] - runner = 0 - for x in range(Nx + 1): - for y in range(Ny + 1): - for z in range(Nz + 1): - x_here = x / float(Nx) * a[0] + y / float(Ny) * b[0] + z / float(Nz) * c[0] - y_here = x / float(Nx) * a[1] + y / float(Ny) * b[1] + z / float(Nz) * c[1] - z_here = x / float(Nx) * a[2] + y / float(Ny) * b[2] + z / float(Nz) * c[2] - - if x != Nx and y != Ny and z != Nz: - if ( - not np.isclose(self.points[runner][0], x_here, 1e-3) - and not np.isclose(self.points[runner][1], y_here, 1e-3) - and not np.isclose(self.points[runner][2], z_here, 1e-3) - ): - raise ValueError( - "The provided wavefunction from Lobster does not contain all relevant" - " points. " - "Please use a line similar to: printLCAORealSpaceWavefunction kpoint 1 " - "coordinates 0.0 0.0 0.0 coordinates 1.0 1.0 1.0 box bandlist 1 " - ) - - new_x += [x_here] - new_y += [y_here] - new_z += [z_here] - - new_real += [self.real[runner]] - new_imaginary += [self.imaginary[runner]] - new_density += [self.real[runner] ** 2 + self.imaginary[runner] ** 2] - - runner += 1 + for runner, (x, y, z) in enumerate(itertools.product(range(Nx + 1), range(Ny + 1), range(Nz + 1))): + x_here = x / float(Nx) * a[0] + y / float(Ny) * b[0] + z / float(Nz) * c[0] + y_here = x / float(Nx) * a[1] + y / float(Ny) * b[1] + z / float(Nz) * c[1] + z_here = x / float(Nx) * a[2] + y / float(Ny) * b[2] + z / float(Nz) * c[2] + + if x != Nx and y != Ny and z != Nz: + if ( + not np.isclose(self.points[runner][0], x_here, 1e-3) + and not np.isclose(self.points[runner][1], y_here, 1e-3) + and not np.isclose(self.points[runner][2], z_here, 1e-3) + ): + raise ValueError( + "The provided wavefunction from Lobster does not contain all relevant" + " points. " + "Please use a line similar to: printLCAORealSpaceWavefunction kpoint 1 " + "coordinates 0.0 0.0 0.0 coordinates 1.0 1.0 1.0 box bandlist 1 " + ) + + new_x += [x_here] + new_y += [y_here] + new_z += [z_here] + + new_real += [self.real[runner]] + new_imaginary += [self.imaginary[runner]] + new_density += [self.real[runner] ** 2 + self.imaginary[runner] ** 2] self.final_real = np.reshape(new_real, [Nx, Ny, Nz]) self.final_imaginary = np.reshape(new_imaginary, [Nx, Ny, Nz]) @@ -1844,7 +1842,7 @@ def __init__( with zopen(filename, mode="rt") as file: data = file.read().split("\n")[5] if len(data) == 0: - raise OSError("MadelungEnergies file contains no data.") + raise RuntimeError("MadelungEnergies file contains no data.") line = data.split() self._filename = filename self.ewald_splitting = float(line[0]) @@ -1924,7 +1922,7 @@ def __init__( with zopen(filename, mode="rt") as file: data = file.read().split("\n") if len(data) == 0: - raise OSError("SitePotentials file contains no data.") + raise RuntimeError("SitePotentials file contains no data.") self._filename = filename self.ewald_splitting = float(data[0].split()[9]) @@ -2088,7 +2086,7 @@ def __init__(self, e_fermi=None, filename: str = "hamiltonMatrices.lobster"): with zopen(self._filename, mode="rt") as file: file_data = file.readlines() if len(file_data) == 0: - raise OSError("Please check provided input file, it seems to be empty") + raise RuntimeError("Please check provided input file, it seems to be empty") pattern_coeff_hamil_trans = r"(\d+)\s+kpoint\s+(\d+)" # regex pattern to extract spin and k-point number pattern_overlap = r"kpoint\s+(\d+)" # regex pattern to extract k-point number diff --git a/pymatgen/io/nwchem.py b/pymatgen/io/nwchem.py index 43510556ebd..8d7921800d7 100644 --- a/pymatgen/io/nwchem.py +++ b/pymatgen/io/nwchem.py @@ -24,6 +24,7 @@ import re import warnings from string import Template +from typing import TYPE_CHECKING import numpy as np from monty.io import zopen @@ -33,6 +34,11 @@ from pymatgen.core.structure import Molecule, Structure from pymatgen.core.units import Energy, FloatWithUnit +if TYPE_CHECKING: + from pathlib import Path + + from typing_extensions import Self + NWCHEM_BASIS_LIBRARY = None if os.getenv("NWCHEM_BASIS_LIBRARY"): NWCHEM_BASIS_LIBRARY = set(os.listdir(os.environ["NWCHEM_BASIS_LIBRARY"])) @@ -197,24 +203,24 @@ def as_dict(self): } @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """ Args: - d (dict): Dict representation. + dct (dict): Dict representation. Returns: NwTask """ - return NwTask( - charge=d["charge"], - spin_multiplicity=d["spin_multiplicity"], - title=d["title"], - theory=d["theory"], - operation=d["operation"], - basis_set=d["basis_set"], - basis_set_option=d["basis_set_option"], - theory_directives=d["theory_directives"], - alternate_directives=d["alternate_directives"], + return cls( + charge=dct["charge"], + spin_multiplicity=dct["spin_multiplicity"], + title=dct["title"], + theory=dct["theory"], + operation=dct["operation"], + basis_set=dct["basis_set"], + basis_set_option=dct["basis_set_option"], + theory_directives=dct["theory_directives"], + alternate_directives=dct["alternate_directives"], ) @classmethod @@ -230,7 +236,7 @@ def from_molecule( operation="optimize", theory_directives=None, alternate_directives=None, - ): + ) -> Self: """ Very flexible arguments to support many types of potential setups. Users should use more friendly static methods unless they need the @@ -278,7 +284,7 @@ def from_molecule( if isinstance(basis_set, str): basis_set = dict.fromkeys(elements, basis_set) - return NwTask( + return cls( charge, spin_multiplicity, basis_set, @@ -399,25 +405,25 @@ def as_dict(self): } @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """ Args: - d (dict): Dict representation. + dct (dict): Dict representation. Returns: NwInput """ - return NwInput( - Molecule.from_dict(d["mol"]), - tasks=[NwTask.from_dict(dt) for dt in d["tasks"]], - directives=[tuple(li) for li in d["directives"]], - geometry_options=d["geometry_options"], - symmetry_options=d["symmetry_options"], - memory_options=d["memory_options"], + return cls( + Molecule.from_dict(dct["mol"]), + tasks=[NwTask.from_dict(dt) for dt in dct["tasks"]], + directives=[tuple(li) for li in dct["directives"]], + geometry_options=dct["geometry_options"], + symmetry_options=dct["symmetry_options"], + memory_options=dct["memory_options"], ) @classmethod - def from_str(cls, string_input): + def from_str(cls, string_input: str) -> Self: """ Read an NwInput from a string. Currently tested to work with files generated from this class itself. @@ -432,8 +438,10 @@ def from_str(cls, string_input): tasks = [] charge = spin_multiplicity = title = basis_set = None basis_set_option = None - theory_directives = {} + mol = None + theory_directives: dict[str, dict[str, str]] = {} geom_options = symmetry_options = memory_options = None + lines = string_input.strip().split("\n") while len(lines) > 0: line = lines.pop(0).strip() @@ -501,7 +509,7 @@ def from_str(cls, string_input): else: directives.append(line.strip().split()) - return NwInput( + return cls( mol, tasks=tasks, directives=directives, @@ -511,7 +519,7 @@ def from_str(cls, string_input): ) @classmethod - def from_file(cls, filename): + def from_file(cls, filename: str | Path) -> Self: """ Read an NwInput from a file. Currently tested to work with files generated from this class itself. @@ -614,8 +622,7 @@ def parse_tddft(self): def get_excitation_spectrum(self, width=0.1, npoints=2000): """ - Generate an excitation spectra from the singlet roots of TDDFT - calculations. + Generate an excitation spectra from the singlet roots of TDDFT calculations. Args: width (float): Width for Gaussian smearing. @@ -623,8 +630,7 @@ def get_excitation_spectrum(self, width=0.1, npoints=2000): curve. Returns: - (ExcitationSpectrum) which can be plotted using - pymatgen.vis.plotters.SpectrumPlotter. + ExcitationSpectrum: can be plotted using pymatgen.vis.plotters.SpectrumPlotter. """ roots = self.parse_tddft() data = roots["singlet"] @@ -916,15 +922,15 @@ def isfloatstring(in_str): for _freq, mode in normal_frequencies: mode[:] = zip(*[iter(mode)] * 3) if hessian: - n = len(hessian) - for i in range(n): - for j in range(i + 1, n): - hessian[i].append(hessian[j][i]) + len_hess = len(hessian) + for ii in range(len_hess): + for jj in range(ii + 1, len_hess): + hessian[ii].append(hessian[jj][ii]) if projected_hessian: - n = len(projected_hessian) - for i in range(n): - for j in range(i + 1, n): - projected_hessian[i].append(projected_hessian[j][i]) + len_hess = len(projected_hessian) + for ii in range(len_hess): + for jj in range(ii + 1, len_hess): + projected_hessian[ii].append(projected_hessian[jj][ii]) data.update( { diff --git a/pymatgen/io/openff.py b/pymatgen/io/openff.py new file mode 100644 index 00000000000..ecb892a9c0e --- /dev/null +++ b/pymatgen/io/openff.py @@ -0,0 +1,315 @@ +"""Utility functions for classical md subpackage.""" + +from __future__ import annotations + +import warnings +from pathlib import Path + +import numpy as np + +import pymatgen +from pymatgen.analysis.graphs import MoleculeGraph +from pymatgen.analysis.local_env import OpenBabelNN, metal_edge_extender +from pymatgen.core import Element, Molecule + +try: + import openff.toolkit as tk + from openff.units import Quantity, unit +except ImportError: + tk = None + Quantity = None + unit = None + warnings.warn( + "To use the pymatgen.io.openff module install openff-toolkit and openff-units" + "with `conda install -c conda-forge openff-toolkit openff-units`." + ) + + +def mol_graph_to_openff_mol(mol_graph: MoleculeGraph) -> tk.Molecule: + """ + Convert a Pymatgen MoleculeGraph to an OpenFF Molecule. + + Args: + mol_graph (MoleculeGraph): The Pymatgen MoleculeGraph to be converted. + + Returns: + tk.Molecule: The converted OpenFF Molecule. + """ + # create empty openff_mol and prepare a periodic table + p_table = {str(el): el.Z for el in Element} + openff_mol = tk.Molecule() + + # set atom properties + partial_charges = [] + # TODO: should assert that there is only one molecule + for i_node in range(len(mol_graph.graph.nodes)): + node = mol_graph.graph.nodes[i_node] + atomic_number = node.get("atomic_number") or p_table[mol_graph.molecule[i_node].species_string] + + # put formal charge on first atom if there is none present + formal_charge = node.get("formal_charge") + if formal_charge is None: + formal_charge = (i_node == 0) * int(round(mol_graph.molecule.charge, 0)) * unit.elementary_charge + + # assume not aromatic if no info present + is_aromatic = node.get("is_aromatic") or False + + openff_mol.add_atom(atomic_number, formal_charge, is_aromatic=is_aromatic) + + # add to partial charge array + partial_charge = node.get("partial_charge") + if isinstance(partial_charge, Quantity): + partial_charge = partial_charge.magnitude + partial_charges.append(partial_charge) + + charge_array = np.array(partial_charges) + if np.not_equal(charge_array, None).all(): + openff_mol.partial_charges = charge_array * unit.elementary_charge + + # set edge properties, default to single bond and assume not aromatic + for i_node, j, bond_data in mol_graph.graph.edges(data=True): + bond_order = bond_data.get("bond_order") or 1 + is_aromatic = bond_data.get("is_aromatic") or False + openff_mol.add_bond(i_node, j, bond_order, is_aromatic=is_aromatic) + + openff_mol.add_conformer(mol_graph.molecule.cart_coords * unit.angstrom) + return openff_mol + + +def mol_graph_from_openff_mol(molecule: tk.Molecule) -> MoleculeGraph: + """ + This is designed to closely mirror the graph structure generated by tk.Molecule.to_networkx + + Args: + molecule (tk.Molecule): The OpenFF Molecule to convert. + + Returns: + MoleculeGraph: The converted MoleculeGraph. + """ + mol_graph = MoleculeGraph.with_empty_graph( + Molecule([], []), + name="none", + ) + p_table = {el.Z: str(el) for el in Element} + total_charge = 0 + cum_atoms = 0 + + coords = molecule.conformers[0].magnitude if molecule.conformers is not None else np.zeros((molecule.n_atoms, 3)) + for idx, atom in enumerate(molecule.atoms): + mol_graph.insert_node( + cum_atoms + idx, + p_table[atom.atomic_number], + coords[idx, :], + ) + mol_graph.graph.nodes[cum_atoms + idx]["atomic_number"] = atom.atomic_number + mol_graph.graph.nodes[cum_atoms + idx]["is_aromatic"] = atom.is_aromatic + mol_graph.graph.nodes[cum_atoms + idx]["stereochemistry"] = atom.stereochemistry + # set partial charge as a pure float + partial_charge = None if atom.partial_charge is None else atom.partial_charge.magnitude + mol_graph.graph.nodes[cum_atoms + idx]["partial_charge"] = partial_charge + # set formal charge as a pure float + formal_charge = atom.formal_charge.magnitude + mol_graph.graph.nodes[cum_atoms + idx]["formal_charge"] = formal_charge + total_charge += formal_charge + for bond in molecule.bonds: + mol_graph.graph.add_edge( + cum_atoms + bond.atom1_index, + cum_atoms + bond.atom2_index, + bond_order=bond.bond_order, + is_aromatic=bond.is_aromatic, + stereochemistry=bond.stereochemistry, + ) + # formal_charge += molecule.total_charge + cum_atoms += molecule.n_atoms + mol_graph.molecule.set_charge_and_spin(charge=total_charge) + return mol_graph + + +def get_atom_map(inferred_mol: tk.Molecule, openff_mol: tk.Molecule) -> tuple[bool, dict[int, int]]: + """ + Compute an atom mapping between two OpenFF Molecules. + + Attempts to find an isomorphism between the molecules, considering various matching + criteria such as formal charges, stereochemistry, and bond orders. Returns the atom + mapping if an isomorphism is found, otherwise returns an empty mapping. + + Args: + inferred_mol (tk.Molecule): The first OpenFF Molecule. + openff_mol (tk.Molecule): The second OpenFF Molecule. + + Returns: + Tuple[bool, Dict[int, int]]: A tuple containing a boolean indicating if an + isomorphism was found and a dictionary representing the atom mapping. + """ + # do not apply formal charge restrictions + kwargs = dict( + return_atom_map=True, + formal_charge_matching=False, + ) + isomorphic, atom_map = tk.topology.Molecule.are_isomorphic(openff_mol, inferred_mol, **kwargs) + if isomorphic: + return True, atom_map + # relax stereochemistry restrictions + kwargs["atom_stereochemistry_matching"] = False + kwargs["bond_stereochemistry_matching"] = False + isomorphic, atom_map = tk.topology.Molecule.are_isomorphic(openff_mol, inferred_mol, **kwargs) + if isomorphic: + return True, atom_map + # relax bond order restrictions + kwargs["bond_order_matching"] = False + isomorphic, atom_map = tk.topology.Molecule.are_isomorphic(openff_mol, inferred_mol, **kwargs) + if isomorphic: + return True, atom_map + return False, {} + + +def infer_openff_mol( + mol_geometry: pymatgen.core.Molecule, +) -> tk.Molecule: + """Infer an OpenFF Molecule from a Pymatgen Molecule. + + Constructs a MoleculeGraph from the Pymatgen Molecule using the OpenBabelNN local + environment strategy and extends metal edges. Converts the resulting MoleculeGraph + to an OpenFF Molecule using mol_graph_to_openff_mol. + + Args: + mol_geometry (pymatgen.core.Molecule): The Pymatgen Molecule to infer from. + + Returns: + tk.Molecule: The inferred OpenFF Molecule. + """ + mol_graph = MoleculeGraph.with_local_env_strategy(mol_geometry, OpenBabelNN()) + mol_graph = metal_edge_extender(mol_graph) + return mol_graph_to_openff_mol(mol_graph) + + +def add_conformer( + openff_mol: tk.Molecule, geometry: pymatgen.core.Molecule | None +) -> tuple[tk.Molecule, dict[int, int]]: + """ + Add conformers to an OpenFF Molecule based on the provided geometry. + + If a geometry is provided, infers an OpenFF Molecule from it, + finds an atom mapping between the inferred molecule and the + input molecule, and adds the conformer coordinates to the input + molecule. If no geometry is provided, generates a single conformer. + + Args: + openff_mol (tk.Molecule): The OpenFF Molecule to add conformers to. + geometry (Union[pymatgen.core.Molecule, None]): The geometry to use for adding + conformers. + + Returns: + Tuple[tk.Molecule, Dict[int, int]]: A tuple containing the updated OpenFF + Molecule with added conformers and a dictionary representing the atom + mapping. + """ + # TODO: test this + if geometry: + # for geometry in geometries: + inferred_mol = infer_openff_mol(geometry) + is_isomorphic, atom_map = get_atom_map(inferred_mol, openff_mol) + if not is_isomorphic: + raise ValueError( + f"An isomorphism cannot be found between smile {openff_mol.to_smiles()}" + f"and the provided molecule {geometry}." + ) + new_mol = pymatgen.core.Molecule.from_sites([geometry.sites[i] for i in atom_map.values()]) + openff_mol.add_conformer(new_mol.cart_coords * unit.angstrom) + else: + atom_map = {i: i for i in range(openff_mol.n_atoms)} + openff_mol.generate_conformers(n_conformers=1) + return openff_mol, atom_map + + +def assign_partial_charges( + openff_mol: tk.Molecule, + atom_map: dict[int, int], + charge_method: str, + partial_charges: None | list[float], +) -> tk.Molecule: + """ + Assign partial charges to an OpenFF Molecule. + + If partial charges are provided, assigns them to the molecule + based on the atom mapping. If the molecule has only one atom, + assigns the total charge as the partial charge. Otherwise, + assigns partial charges using the specified charge method. + + Args: + openff_mol (tk.Molecule): The OpenFF Molecule to assign partial charges to. + atom_map (Dict[int, int]): A dictionary representing the atom mapping. + charge_method (str): The charge method to use if partial charges are + not provided. + partial_charges (Union[None, List[float]]): A list of partial charges to + assign or None to use the charge method. + + Returns: + tk.Molecule: The OpenFF Molecule with assigned partial charges. + """ + # TODO: test this + # assign partial charges + if partial_charges is not None: + partial_charges = np.array(partial_charges) + chargs = partial_charges[list(atom_map.values())] # type: ignore[index, call-overload] + openff_mol.partial_charges = chargs * unit.elementary_charge + elif openff_mol.n_atoms == 1: + openff_mol.partial_charges = np.array([openff_mol.total_charge.magnitude]) * unit.elementary_charge + else: + openff_mol.assign_partial_charges(charge_method) + return openff_mol + + +def create_openff_mol( + smile: str, + geometry: pymatgen.core.Molecule | str | Path | None = None, + charge_scaling: float = 1, + partial_charges: list[float] | None = None, + backup_charge_method: str = "am1bcc", +) -> tk.Molecule: + """ + Create an OpenFF Molecule from a SMILES string and optional geometry. + + Constructs an OpenFF Molecule from the provided SMILES + string, adds conformers based on the provided geometry (if + any), assigns partial charges using the specified method + or provided partial charges, and applies charge scaling. + + Args: + smile (str): The SMILES string of the molecule. + geometry (Union[pymatgen.core.Molecule, str, Path, None], optional): The + geometry to use for adding conformers. Can be a Pymatgen Molecule, + file path, or None. + charge_scaling (float, optional): The scaling factor for partial charges. + Default is 1. + partial_charges (Union[List[float], None], optional): A list of partial + charges to assign, or None to use the charge method. + backup_charge_method (str, optional): The backup charge method to use if + partial charges are not provided. Default is "am1bcc". + + Returns: + tk.Molecule: The created OpenFF Molecule. + """ + if isinstance(geometry, (str, Path)): + geometry = pymatgen.core.Molecule.from_file(str(geometry)) + + if partial_charges is not None: + if geometry is None: + raise ValueError("geometries must be set if partial_charges is set") + if len(partial_charges) != len(geometry): + raise ValueError("partial charges must have same length & order as geometry") + + openff_mol = tk.Molecule.from_smiles(smile, allow_undefined_stereo=True) + + # add conformer + openff_mol, atom_map = add_conformer(openff_mol, geometry) + # assign partial charges + openff_mol = assign_partial_charges( + openff_mol, + atom_map, + backup_charge_method, + partial_charges, + ) + openff_mol.partial_charges *= charge_scaling + + return openff_mol diff --git a/pymatgen/io/packmol.py b/pymatgen/io/packmol.py index a7506a2eacd..7bd24d85df1 100644 --- a/pymatgen/io/packmol.py +++ b/pymatgen/io/packmol.py @@ -88,7 +88,7 @@ def run(self, path: str | Path, timeout=30): os.chdir(wd) @classmethod - def from_directory(cls, directory: str | Path): + def from_directory(cls, directory: str | Path) -> None: """ Construct an InputSet from a directory of one or more files. @@ -184,6 +184,9 @@ def get_input_set( # type: ignore net_volume = 0.0 for d in molecules: mol = Molecule.from_file(d["coords"]) if not isinstance(d["coords"], Molecule) else d["coords"] + + if mol is None: + raise ValueError("Molecule cannot be None.") # pad the calculated length by an amount related to the tolerance parameter # the amount to add was determined arbitrarily length = ( @@ -196,14 +199,19 @@ def get_input_set( # type: ignore box_list = f"0.0 0.0 0.0 {box_length:.1f} {box_length:.1f} {box_length:.1f}" for d in molecules: + mol = None if isinstance(d["coords"], str): mol = Molecule.from_file(d["coords"]) elif isinstance(d["coords"], Path): mol = Molecule.from_file(str(d["coords"])) elif isinstance(d["coords"], Molecule): mol = d["coords"] + + if mol is None: + raise ValueError("Molecule cannot be None.") + fname = f"packmol_{d['name']}.xyz" - mapping.update({fname: mol.to(fmt="xyz")}) + mapping[fname] = mol.to(fmt="xyz") if " " in str(fname): # NOTE - double quotes are deliberately used inside the f-string here, do not change # fmt: off diff --git a/pymatgen/io/phonopy.py b/pymatgen/io/phonopy.py index 4954bd99071..79a6a073ed4 100644 --- a/pymatgen/io/phonopy.py +++ b/pymatgen/io/phonopy.py @@ -83,12 +83,12 @@ def get_structure_from_dict(dct): return Structure(dct["lattice"], species, frac_coords, site_properties={"phonopy_masses": masses}) -def eigvec_to_eigdispl(v, q, frac_coords, mass): - r""" +def eigvec_to_eigdispl(eig_vec, q, frac_coords, mass): + """ Converts a single eigenvector to an eigendisplacement in the primitive cell - according to the formula:: + according to the formula: - exp(2*pi*i*(frac_coords \\dot q) / sqrt(mass) * v + exp(2*pi*i*(frac_coords dot q) / sqrt(mass) * v Compared to the modulation option in phonopy, here all the additional multiplicative and phase factors are set to 1. @@ -101,7 +101,7 @@ def eigvec_to_eigdispl(v, q, frac_coords, mass): """ c = np.exp(2j * np.pi * np.dot(frac_coords, q)) / np.sqrt(mass) - return c * v + return c * eig_vec def get_ph_bs_symm_line_from_dict(bands_dict, has_nac=False, labels_dict=None): @@ -110,7 +110,7 @@ def get_ph_bs_symm_line_from_dict(bands_dict, has_nac=False, labels_dict=None): extracted by the band.yaml file produced by phonopy. The labels will be extracted from the dictionary, if present. If the 'eigenvector' key is found the eigendisplacements will be calculated according to the - formula:: + formula: exp(2*pi*i*(frac_coords \\dot q) / sqrt(mass) * v @@ -129,31 +129,31 @@ def get_ph_bs_symm_line_from_dict(bands_dict, has_nac=False, labels_dict=None): frequencies = [] eigen_displacements = [] phonopy_labels_dict = {} - for p in bands_dict["phonon"]: - q = p["q-position"] - q_pts.append(q) + for phonon in bands_dict["phonon"]: + q_pos = phonon["q-position"] + q_pts.append(q_pos) bands = [] eig_q = [] - for b in p["band"]: - bands.append(b["frequency"]) - if "eigenvector" in b: + for band in phonon["band"]: + bands.append(band["frequency"]) + if "eigenvector" in band: eig_b = [] - for i, eig_a in enumerate(b["eigenvector"]): - v = np.zeros(3, complex) + for idx, eig_a in enumerate(band["eigenvector"]): + eig_vec = np.zeros(3, complex) for x in range(3): - v[x] = eig_a[x][0] + eig_a[x][1] * 1j + eig_vec[x] = eig_a[x][0] + eig_a[x][1] * 1j eig_b.append( eigvec_to_eigdispl( - v, - q, - structure[i].frac_coords, - structure.site_properties["phonopy_masses"][i], + eig_vec, + q_pos, + structure[idx].frac_coords, + structure.site_properties["phonopy_masses"][idx], ) ) eig_q.append(eig_b) frequencies.append(bands) - if "label" in p: - phonopy_labels_dict[p["label"]] = p["q-position"] + if "label" in phonon: + phonopy_labels_dict[phonon["label"]] = phonon["q-position"] if eig_q: eigen_displacements.append(eig_q) @@ -246,7 +246,7 @@ def get_displaced_structures(pmg_structure, atom_disp=0.01, supercell_matrix=Non the outputting displacement yaml file, e.g. disp.yaml. **kwargs: Parameters used in Phonopy.generate_displacement method. - Return: + Returns: A list of symmetrically inequivalent structures with displacements, in which the first element is the perfect supercell structure. """ @@ -442,8 +442,8 @@ def get_gruneisenparameter(gruneisen_path, structure=None, structure_path=None) phonopy_labels_dict = {} for p in gruneisen_dict["phonon"]: - q = p["q-position"] - q_pts.append(q) + q_pos = p["q-position"] + q_pts.append(q_pos) m = p.get("multiplicity", 1) multiplicities.append(m) bands, gruneisenband = [], [] @@ -479,7 +479,7 @@ def get_gs_ph_bs_symm_line_from_dict( extracted by the gruneisen.yaml file produced by phonopy. The labels will be extracted from the dictionary, if present. If the 'eigenvector' key is found the eigendisplacements will be calculated according to the - formula:: + formula: exp(2*pi*i*(frac_coords \\dot q) / sqrt(mass) * v @@ -520,22 +520,22 @@ def get_gs_ph_bs_symm_line_from_dict( q_pts_temp, frequencies_temp = [], [] gruneisen_temp: list[list[float]] = [] distance: list[float] = [] - for i in range(pa["nqpoint"]): + for idx in range(pa["nqpoint"]): bands = [] gruneisen_band: list[float] = [] - for b in phonon[pa["nqpoint"] - i - 1]["band"]: + for b in phonon[pa["nqpoint"] - idx - 1]["band"]: bands.append(b["frequency"]) # Fraction of leftover points in current band - gruen = _extrapolate_grun(b, distance, gruneisen_temp, gruneisen_band, i, pa) + gruen = _extrapolate_grun(b, distance, gruneisen_temp, gruneisen_band, idx, pa) gruneisen_band.append(gruen) - q = phonon[pa["nqpoint"] - i - 1]["q-position"] - q_pts_temp.append(q) - d = phonon[pa["nqpoint"] - i - 1]["distance"] + q_pos = phonon[pa["nqpoint"] - idx - 1]["q-position"] + q_pts_temp.append(q_pos) + d = phonon[pa["nqpoint"] - idx - 1]["distance"] distance.append(d) frequencies_temp.append(bands) gruneisen_temp.append(gruneisen_band) - if "label" in phonon[pa["nqpoint"] - i - 1]: - phonopy_labels_dict[phonon[pa["nqpoint"] - i - 1]]["label"] = phonon[pa["nqpoint"] - i - 1][ + if "label" in phonon[pa["nqpoint"] - idx - 1]: + phonopy_labels_dict[phonon[pa["nqpoint"] - idx - 1]]["label"] = phonon[pa["nqpoint"] - idx - 1][ "q-position" ] @@ -545,42 +545,42 @@ def get_gs_ph_bs_symm_line_from_dict( elif end["q-position"] == [0, 0, 0]: # Gamma at end of band distance = [] - for i in range(pa["nqpoint"]): + for idx in range(pa["nqpoint"]): bands, gruneisen_band = [], [] - for b in phonon[i]["band"]: + for b in phonon[idx]["band"]: bands.append(b["frequency"]) - gruen = _extrapolate_grun(b, distance, gruneisen_params, gruneisen_band, i, pa) + gruen = _extrapolate_grun(b, distance, gruneisen_params, gruneisen_band, idx, pa) gruneisen_band.append(gruen) - q = phonon[i]["q-position"] - q_points.append(q) - d = phonon[i]["distance"] + q_pos = phonon[idx]["q-position"] + q_points.append(q_pos) + d = phonon[idx]["distance"] distance.append(d) frequencies.append(bands) gruneisen_params.append(gruneisen_band) - if "label" in phonon[i]: - phonopy_labels_dict[phonon[i]["label"]] = phonon[i]["q-position"] + if "label" in phonon[idx]: + phonopy_labels_dict[phonon[idx]["label"]] = phonon[idx]["q-position"] else: # No Gamma in band distance = [] - for i in range(pa["nqpoint"]): + for idx in range(pa["nqpoint"]): bands, gruneisen_band = [], [] - for b in phonon[i]["band"]: + for b in phonon[idx]["band"]: bands.append(b["frequency"]) gruneisen_band.append(b["gruneisen"]) - q = phonon[i]["q-position"] - q_points.append(q) - d = phonon[i]["distance"] + q_pos = phonon[idx]["q-position"] + q_points.append(q_pos) + d = phonon[idx]["distance"] distance.append(d) frequencies.append(bands) gruneisen_params.append(gruneisen_band) - if "label" in phonon[i]: - phonopy_labels_dict[phonon[i]["label"]] = phonon[i]["q-position"] + if "label" in phonon[idx]: + phonopy_labels_dict[phonon[idx]["label"]] = phonon[idx]["q-position"] else: for pa in gruneisen_dict["path"]: for p in pa["phonon"]: - q = p["q-position"] - q_points.append(q) + q_pos = p["q-position"] + q_points.append(q_pos) bands, gruneisen_bands = [], [] for b in p["band"]: bands.append(b["frequency"]) diff --git a/pymatgen/io/pwmat/inputs.py b/pymatgen/io/pwmat/inputs.py index 05c20f4357c..8598a3f5a2d 100644 --- a/pymatgen/io/pwmat/inputs.py +++ b/pymatgen/io/pwmat/inputs.py @@ -13,6 +13,8 @@ from pymatgen.symmetry.kpath import KPathSeek if TYPE_CHECKING: + from typing_extensions import Self + from pymatgen.util.typing import PathLike __author__ = "Hanyu Liu" @@ -375,7 +377,7 @@ def __str__(self): return self.get_str() @classmethod - def from_str(cls, data: str, mag: bool = False) -> AtomConfig: + def from_str(cls, data: str, mag: bool = False) -> Self: """Reads a atom.config from a string. Args: @@ -402,7 +404,7 @@ def from_str(cls, data: str, mag: bool = False) -> AtomConfig: return cls(structure) @classmethod - def from_file(cls, filename: PathLike, mag: bool = False) -> AtomConfig: + def from_file(cls, filename: PathLike, mag: bool = False) -> Self: """Returns a AtomConfig from a file Args: @@ -416,7 +418,7 @@ def from_file(cls, filename: PathLike, mag: bool = False) -> AtomConfig: return cls.from_str(data=file.read(), mag=mag) @classmethod - def from_dict(cls, dct: dict) -> AtomConfig: + def from_dict(cls, dct: dict) -> Self: """Returns a AtomConfig object from a dictionary. Args: @@ -434,17 +436,16 @@ def get_str(self) -> str: str: String representation of atom.config """ # This corrects for VASP really annoying bug of crashing on lattices - # which have triple product < 0. We will just invert the lattice - # vectors. - latt = self.structure.lattice - if np.linalg.det(latt.matrix) < 0: - latt = Lattice(-latt.matrix) + # which have triple product < 0. We will just invert the lattice vectors. + lattice = self.structure.lattice + if np.linalg.det(lattice.matrix) < 0: + lattice = Lattice(-lattice.matrix) lines: list[str] = [] lines.append(f"\t{self.structure.num_sites} atoms\n") lines.append("Lattice vector\n") for ii in range(3): - lines.append(f"{latt.matrix[ii][0]:>15f}{latt.matrix[ii][1]:>15f}{latt.matrix[ii][2]:>15f}\n") + lines.append(f"{lattice.matrix[ii][0]:>15f}{lattice.matrix[ii][1]:>15f}{lattice.matrix[ii][2]:>15f}\n") lines.append("Position, move_x, move_y, move_z\n") for ii in range(self.structure.num_sites): lines.append(f"{int(self.structure.species[ii].Z):>4d}") @@ -500,8 +501,8 @@ def __init__( self.kpath.update({"path": path}) self.density = density - @staticmethod - def from_structure(structure: Structure, dim: int, density: float = 0.01) -> GenKpt: + @classmethod + def from_structure(cls, structure: Structure, dim: int, density: float = 0.01) -> Self: """Obtain a AtomConfig object from Structure object. Args: @@ -531,7 +532,7 @@ def from_structure(structure: Structure, dim: int, density: float = 0.01) -> Gen kpts = kpath_set.kpath["kpoints"] path = kpath_set.kpath["path"] rec_lattice: np.ndarray = structure.lattice.reciprocal_lattice.matrix # with 2*pi - return GenKpt(rec_lattice, kpts, path, density * 2 * np.pi) + return cls(rec_lattice, kpts, path, density * 2 * np.pi) def get_str(self): """Returns a string to be written as a gen.kpt file.""" @@ -602,8 +603,8 @@ def __init__(self, reciprocal_lattice: np.ndarray, kpts: dict[str, list], path: self.kpath.update({"path": path}) self.density = density - @staticmethod - def from_structure(structure: Structure, dim: int, density: float = 0.01) -> HighSymmetryPoint: + @classmethod + def from_structure(cls, structure: Structure, dim: int, density: float = 0.01) -> Self: """Obtain HighSymmetry object from Structure object. Args: @@ -614,9 +615,7 @@ def from_structure(structure: Structure, dim: int, density: float = 0.01) -> Hig """ reciprocal_lattice: np.ndarray = structure.lattice.reciprocal_lattice.matrix gen_kpt = GenKpt.from_structure(structure=structure, dim=dim, density=density) - return HighSymmetryPoint( - reciprocal_lattice, gen_kpt.kpath["kpoints"], gen_kpt.kpath["path"], density * 2 * np.pi - ) + return cls(reciprocal_lattice, gen_kpt.kpath["kpoints"], gen_kpt.kpath["path"], density * 2 * np.pi) def get_str(self) -> str: """Returns a string describing high symmetry points in HIGH_SYMMETRY_POINTS format.""" diff --git a/pymatgen/io/pwmat/outputs.py b/pymatgen/io/pwmat/outputs.py index a2e8e54508a..456c7cdbacb 100644 --- a/pymatgen/io/pwmat/outputs.py +++ b/pymatgen/io/pwmat/outputs.py @@ -77,7 +77,7 @@ def atom_configs(self) -> list[Structure]: Returns: list[Structure]: List of Structure objects for the structure at each ionic step. """ - return [step["atom_config"] for _, step in enumerate(self.ionic_steps)] + return [step["atom_config"] for step in self.ionic_steps] @property def e_tots(self) -> np.ndarray: @@ -87,7 +87,7 @@ def e_tots(self) -> np.ndarray: np.ndarray: Total energy of of each ionic step structure, with shape of (n_ionic_steps,). """ - return np.array([step["e_tot"] for _, step in enumerate(self.ionic_steps)]) + return np.array([step["e_tot"] for step in self.ionic_steps]) @property def atom_forces(self) -> np.ndarray: @@ -97,7 +97,7 @@ def atom_forces(self) -> np.ndarray: np.ndarray: The forces on atoms of each ionic step structure, with shape of (n_ionic_steps, n_atoms, 3). """ - return np.array([step["atom_forces"] for _, step in enumerate(self.ionic_steps)]) + return np.array([step["atom_forces"] for step in self.ionic_steps]) @property def e_atoms(self) -> np.ndarray: @@ -109,7 +109,7 @@ def e_atoms(self) -> np.ndarray: np.ndarray: The individual energy of atoms in each ionic step structure, with shape of (n_ionic_steps, n_atoms). """ - return np.array([step["eatoms"] for _, step in enumerate(self.ionic_steps) if ("eatoms" in step)]) + return np.array([step["eatoms"] for step in self.ionic_steps if ("eatoms" in step)]) @property def virials(self) -> np.ndarray: @@ -119,7 +119,7 @@ def virials(self) -> np.ndarray: np.ndarray: The virial tensor of each ionic step structure, with shape of (n_ionic_steps, 3, 3) """ - return np.array([step["virial"] for _, step in enumerate(self.ionic_steps) if ("virial" in step)]) + return np.array([step["virial"] for step in self.ionic_steps if ("virial" in step)]) def _parse_sefv(self) -> list[dict]: """ diff --git a/pymatgen/io/pwscf.py b/pymatgen/io/pwscf.py index d7b52c1dc6c..b68d7df1648 100644 --- a/pymatgen/io/pwscf.py +++ b/pymatgen/io/pwscf.py @@ -4,6 +4,7 @@ import re from collections import defaultdict +from typing import TYPE_CHECKING from monty.io import zopen from monty.re import regrep @@ -11,6 +12,11 @@ from pymatgen.core import Element, Lattice, Structure from pymatgen.util.io_utils import clean_lines +if TYPE_CHECKING: + from pathlib import Path + + from typing_extensions import Self + class PWInput: """ @@ -133,7 +139,7 @@ def to_str(v): out.append("ATOMIC_SPECIES") for k, v in sorted(site_descriptions.items(), key=lambda i: i[0]): - e = re.match(r"[A-Z][a-z]?", k).group(0) + e = re.match(r"[A-Z][a-z]?", k)[0] p = v if self.pseudo is not None else v["pseudo"] out.append(f" {k} {Element(e).atomic_mass:.4f} {p}") @@ -184,27 +190,27 @@ def as_dict(self): } @classmethod - def from_dict(cls, pwinput_dict): + def from_dict(cls, dct: dict) -> Self: """ Load a PWInput object from a dictionary. Args: - pwinput_dict (dict): dictionary with PWInput data + dct (dict): dictionary with PWInput data Returns: PWInput object """ return cls( - structure=Structure.from_dict(pwinput_dict["structure"]), - pseudo=pwinput_dict["pseudo"], - control=pwinput_dict["sections"]["control"], - system=pwinput_dict["sections"]["system"], - electrons=pwinput_dict["sections"]["electrons"], - ions=pwinput_dict["sections"]["ions"], - cell=pwinput_dict["sections"]["cell"], - kpoints_mode=pwinput_dict["kpoints_mode"], - kpoints_grid=pwinput_dict["kpoints_grid"], - kpoints_shift=pwinput_dict["kpoints_shift"], + structure=Structure.from_dict(dct["structure"]), + pseudo=dct["pseudo"], + control=dct["sections"]["control"], + system=dct["sections"]["system"], + electrons=dct["sections"]["electrons"], + ions=dct["sections"]["ions"], + cell=dct["sections"]["cell"], + kpoints_mode=dct["kpoints_mode"], + kpoints_grid=dct["kpoints_grid"], + kpoints_shift=dct["kpoints_shift"], ) def write_file(self, filename): @@ -214,16 +220,16 @@ def write_file(self, filename): Args: filename (str): The string filename to output to. """ - with open(filename, mode="w") as file: + with open(filename, mode="w", encoding="utf-8") as file: file.write(str(self)) @classmethod - def from_file(cls, filename): + def from_file(cls, filename: str | Path) -> Self: """ Reads an PWInput object from a file. Args: - filename (str): Filename for file + filename (str | Path): Filename for file Returns: PWInput object @@ -232,7 +238,7 @@ def from_file(cls, filename): return cls.from_str(file.read()) @classmethod - def from_str(cls, string): + def from_str(cls, string: str) -> Self: """ Reads an PWInput object from a string. @@ -259,7 +265,7 @@ def input_mode(line): return None return mode - sections = { + sections: dict[str, dict] = { "control": {}, "system": {}, "electrons": {}, @@ -271,19 +277,23 @@ def input_mode(line): species = [] coords = [] structure = None - site_properties = {"pseudo": []} + site_properties: dict[str, list] = {"pseudo": []} mode = None + kpoints_mode = None + kpoints_grid = (1, 1, 1) + kpoints_shift = (0, 0, 0) + coords_are_cartesian = False + for line in lines: mode = input_mode(line) if mode is None: pass elif mode[0] == "sections": section = mode[1] - m = re.match(r"(\w+)\(?(\d*?)\)?\s*=\s*(.*)", line) - if m: - key = m.group(1).strip() - key_ = m.group(2).strip() - val = m.group(3).strip() + if match := re.match(r"(\w+)\(?(\d*?)\)?\s*=\s*(.*)", line): + key = match[1].strip() + key_ = match[2].strip() + val = match[3].strip() if key_ != "": if sections[section].get(key) is None: val_ = [0.0] * 20 # MAX NTYP DEFINITION @@ -297,37 +307,34 @@ def input_mode(line): sections[section][key] = PWInput.proc_val(key, val) elif mode[0] == "pseudo": - m = re.match(r"(\w+)\s+(\d*.\d*)\s+(.*)", line) - if m: - pseudo[m.group(1).strip()] = m.group(3).strip() + if match := re.match(r"(\w+)\s+(\d*.\d*)\s+(.*)", line): + pseudo[match[1].strip()] = match[3].strip() + elif mode[0] == "kpoints": - m = re.match(r"(\d+)\s+(\d+)\s+(\d+)\s+(\d+)\s+(\d+)\s+(\d+)", line) - if m: - kpoints_grid = (int(m.group(1)), int(m.group(2)), int(m.group(3))) - kpoints_shift = (int(m.group(4)), int(m.group(5)), int(m.group(6))) + if match := re.match(r"(\d+)\s+(\d+)\s+(\d+)\s+(\d+)\s+(\d+)\s+(\d+)", line): + kpoints_grid = (int(match[1]), int(match[2]), int(match[3])) + kpoints_shift = (int(match[4]), int(match[5]), int(match[6])) else: kpoints_mode = mode[1] - kpoints_grid = (1, 1, 1) - kpoints_shift = (0, 0, 0) elif mode[0] == "structure": m_l = re.match(r"(-?\d+\.?\d*)\s+(-?\d+\.?\d*)\s+(-?\d+\.?\d*)", line) m_p = re.match(r"(\w+)\s+(-?\d+\.\d*)\s+(-?\d+\.?\d*)\s+(-?\d+\.?\d*)", line) if m_l: lattice += [ - float(m_l.group(1)), - float(m_l.group(2)), - float(m_l.group(3)), + float(m_l[1]), + float(m_l[2]), + float(m_l[3]), ] + elif m_p: - site_properties["pseudo"].append(pseudo[m_p.group(1)]) - species.append(m_p.group(1)) - coords += [[float(m_p.group(2)), float(m_p.group(3)), float(m_p.group(4))]] + site_properties["pseudo"].append(pseudo[m_p[1]]) + species.append(m_p[1]) + coords += [[float(m_p[2]), float(m_p[3]), float(m_p[4])]] if mode[1] == "angstrom": coords_are_cartesian = True - elif mode[1] == "crystal": - coords_are_cartesian = False + structure = Structure( Lattice(lattice), species, @@ -466,10 +473,10 @@ def smart_int_or_float(numstr): raise ValueError(key + " should be a boolean type!") if key in float_keys: - return float(re.search(r"^-?\d*\.?\d*d?-?\d*", val.lower()).group(0).replace("d", "e")) + return float(re.search(r"^-?\d*\.?\d*d?-?\d*", val.lower())[0].replace("d", "e")) if key in int_keys: - return int(re.match(r"^-?[0-9]+", val).group(0)) + return int(re.match(r"^-?[0-9]+", val)[0]) except ValueError: pass @@ -484,9 +491,8 @@ def smart_int_or_float(numstr): if "false" in val.lower(): return False - m = re.match(r"^[\"|'](.+)[\"|']$", val) - if m: - return m.group(1) + if match := re.match(r"^[\"|'](.+)[\"|']$", val): + return match[1] return None diff --git a/pymatgen/io/qchem/inputs.py b/pymatgen/io/qchem/inputs.py index 79fc744348a..99678728278 100644 --- a/pymatgen/io/qchem/inputs.py +++ b/pymatgen/io/qchem/inputs.py @@ -16,6 +16,8 @@ if TYPE_CHECKING: from pathlib import Path + from typing_extensions import Self + __author__ = "Brandon Wood, Samuel Blau, Shyam Dwaraknath, Julian Self, Evan Spotte-Smith, Ryan Kingsbury" __copyright__ = "Copyright 2018-2022, The Materials Project" __version__ = "0.1" @@ -61,7 +63,7 @@ def __init__( Molecule objects, or as the string "read". "read" can be used in multi_job QChem input files where the molecule is read in from the previous calculation. rem (dict): - A dictionary of all the input parameters for the rem section of QChem input file. + A dictionary of all the input parameters for the REM section of QChem input file. Ex. rem = {'method': 'rimp2', 'basis': '6-31*G++' ... } opt (dict of lists): A dictionary of opt sections, where each opt section is a key and the corresponding @@ -292,18 +294,18 @@ def multi_job_string(job_list: list[QCInput]) -> str: job_list (): List of jobs. Returns: - (str) String representation of multi job input file. + str: String representation of a multi-job input file. """ multi_job_string = "" - for i, job_i in enumerate(job_list): - if i < len(job_list) - 1: + for i, job_i in enumerate(job_list, start=1): + if i < len(job_list): multi_job_string += str(job_i) + "\n@@@\n\n" else: multi_job_string += str(job_i) return multi_job_string @classmethod - def from_str(cls, string: str) -> QCInput: + def from_str(cls, string: str) -> Self: # type: ignore[override] """ Read QcInput from string. @@ -378,7 +380,7 @@ def write_multi_job_file(job_list: list[QCInput], filename: str): file.write(QCInput.multi_job_string(job_list)) @classmethod - def from_file(cls, filename: str | Path) -> QCInput: + def from_file(cls, filename: str | Path) -> Self: # type: ignore[override] """ Create QcInput from file. @@ -392,7 +394,7 @@ def from_file(cls, filename: str | Path) -> QCInput: return cls.from_str(file.read()) @classmethod - def from_multi_jobs_file(cls, filename: str) -> list[QCInput]: + def from_multi_jobs_file(cls, filename: str) -> list[Self]: """ Create list of QcInput from a file. @@ -415,7 +417,7 @@ def molecule_template(molecule: Molecule | list[Molecule] | Literal["read"]) -> molecule (Molecule, list of Molecules, or "read"). Returns: - (str) Molecule template. + str: Molecule template. """ # TODO: add ghost atoms mol_list = [] @@ -456,7 +458,7 @@ def rem_template(rem: dict) -> str: rem (): Returns: - (str) + str: REM template. """ rem_list = [] rem_list.append("$rem") @@ -474,7 +476,7 @@ def opt_template(opt: dict[str, list]) -> str: opt (): Returns: - (str) + str: Optimization template. """ opt_list = [] opt_list.append("$opt") @@ -493,13 +495,13 @@ def opt_template(opt: dict[str, list]) -> str: @staticmethod def pcm_template(pcm: dict) -> str: """ - Pcm run template. + PCM run template. Args: pcm (): Returns: - (str) + str: PCM template. """ pcm_list = [] pcm_list.append("$pcm") @@ -517,7 +519,7 @@ def solvent_template(solvent: dict) -> str: solvent (): Returns: - (str) + str: Solvent section. """ solvent_list = [] solvent_list.append("$solvent") @@ -533,7 +535,7 @@ def smx_template(smx: dict) -> str: smx (): Returns: - (str) + str: Solvation model with short-range corrections. """ smx_list = [] smx_list.append("$smx") @@ -608,7 +610,7 @@ def plots_template(plots: dict) -> str: plots (): Returns: - (str) + str: Plots section. """ plots_list = [] plots_list.append("$plots") @@ -624,7 +626,7 @@ def nbo_template(nbo: dict) -> str: nbo (): Returns: - (str) + str: NBO section. """ nbo_list = [] nbo_list.append("$nbo") @@ -660,7 +662,7 @@ def geom_opt_template(geom_opt: dict) -> str: geom_opt (): Returns: - (str) geom_opt parameters. + str: Geometry optimization section. """ geom_opt_list = [] geom_opt_list.append("$geom_opt") @@ -676,11 +678,11 @@ def cdft_template(cdft: list[list[dict]]) -> str: cdft: list of lists of dicts. Returns: - (str) + str: CDFT section. """ cdft_list = [] cdft_list.append("$cdft") - for ii, state in enumerate(cdft): + for ii, state in enumerate(cdft, start=1): for constraint in state: types = constraint["types"] cdft_list.append(f" {constraint['value']}") @@ -701,7 +703,7 @@ def cdft_template(cdft: list[list[dict]]) -> str: cdft_list.append(f" {coef} {first} {last} {type_string}") else: cdft_list.append(f" {coef} {first} {last}") - if len(cdft) != 1 and ii + 1 < len(state): + if len(cdft) != 1 and ii < len(state): cdft_list.append("--------------") # Ensure that you don't have a line indicating a state that doesn't exist @@ -718,7 +720,7 @@ def almo_template(almo_coupling: list[list[tuple[int, int]]]) -> str: almo: list of lists of int 2-tuples. Returns: - (str) + str: ALMO coupling section. """ almo_list = [] almo_list.append("$almo_coupling") @@ -760,7 +762,10 @@ def pcm_nonels_template(pcm_nonels: dict) -> str: } Returns: - (str) + str: the $pcm_nonels section. Note that all parameters will be concatenated onto + a single line formatted as a FORTRAN namelist. This is necessary + because the non-electrostatic part of the CMIRS solvation model in Q-Chem + calls a secondary code. """ pcm_nonels_list = [] pcm_nonels_list.append("$pcm_nonels") @@ -794,7 +799,7 @@ def find_sections(string: str) -> list: if "molecule" not in sections: raise ValueError("Output file does not contain a molecule section") if "rem" not in sections: - raise ValueError("Output file does not contain a rem section") + raise ValueError("Output file does not contain a REM section") return sections @staticmethod @@ -866,7 +871,7 @@ def read_rem(string: str) -> dict: string (str): String Returns: - (dict) rem + dict[str, str]: REM section """ header = r"^\s*\$rem" row = r"\s*([a-zA-Z\_\d]+)\s*=?\s*(\S+)" @@ -883,7 +888,7 @@ def read_opt(string: str) -> dict[str, list]: string (str): String Returns: - (dict) Opt section + dict[str, list]: Opt section """ patterns = { "CONSTRAINT": r"^\s*CONSTRAINT", @@ -944,7 +949,7 @@ def read_pcm(string: str) -> dict: string (str): String Returns: - (dict) PCM parameters + dict[str, str]: PCM parameters """ header = r"^\s*\$pcm" row = r"\s*([a-zA-Z\_]+)\s+(\S+)" @@ -965,7 +970,7 @@ def read_vdw(string: str) -> tuple[str, dict]: string (str): String Returns: - (str, dict) vdW mode ('atomic' or 'sequential') and dict of van der Waals radii. + tuple[str, dict]: (vdW mode ('atomic' or 'sequential'), dict of van der Waals radii) """ header = r"^\s*\$van_der_waals" row = r"[^\d]*(\d+).?(\d+.\d+)?.*" @@ -988,7 +993,7 @@ def read_solvent(string: str) -> dict: string (str): String Returns: - (dict) Solvent parameters + dict[str, str]: Solvent parameters """ header = r"^\s*\$solvent" row = r"\s*([a-zA-Z\_]+)\s+(\S+)" @@ -1009,7 +1014,7 @@ def read_smx(string: str) -> dict: string (str): String Returns: - (dict) SMX parameters. + dict[str, str] SMX parameters. """ header = r"^\s*\$smx" row = r"\s*([a-zA-Z\_]+)\s+(\S+)" @@ -1070,7 +1075,7 @@ def read_plots(string: str) -> dict: string (str): String Returns: - (dict) plots parameters. + dict[str, str]: plots parameters. """ header = r"^\s*\$plots" row = r"\s*([a-zA-Z\_]+)\s+(\S+)" @@ -1090,7 +1095,7 @@ def read_nbo(string: str) -> dict: string (str): String Returns: - (dict) nbo parameters. + dict[str, str]: nbo parameters. """ header = r"^\s*\$nbo" row = r"\s*([a-zA-Z\_\d]+)\s*=?\s*(\S+)" @@ -1110,7 +1115,7 @@ def read_geom_opt(string: str) -> dict: string (str): String Returns: - (dict) geom_opt parameters. + dict[str, str]: geom_opt parameters. """ header = r"^\s*\$geom_opt" row = r"\s*([a-zA-Z\_]+)\s*=?\s*(\S+)" @@ -1188,7 +1193,7 @@ def read_almo(string: str) -> list[list[tuple[int, int]]]: string (str): String Returns: - (list of lists of int 2-tuples) almo_coupling parameters + list[list[tuple[int, int]]]: ALMO coupling parameters """ pattern = { "key": r"\$almo_coupling\s*\n((?:\s*[\-0-9]+\s+[\-0-9]+\s*\n)+)\s*\-\-" @@ -1242,7 +1247,7 @@ def read_pcm_nonels(string: str) -> dict: string (str): String Returns: - (dict) PCM parameters + dict[str, str]: PCM parameters """ header = r"^\s*\$pcm_nonels" row = r"\s*([a-zA-Z\_]+)\s+(.+)" diff --git a/pymatgen/io/qchem/outputs.py b/pymatgen/io/qchem/outputs.py index 1734e04efab..9d61aa1039e 100644 --- a/pymatgen/io/qchem/outputs.py +++ b/pymatgen/io/qchem/outputs.py @@ -677,15 +677,12 @@ def _read_eigenvalues(self): if spin_unrestricted: header_pattern = r"Final Beta MO Eigenvalues" footer_pattern = r"Final Alpha MO Coefficients+\s*" - beta_eigenvalues = read_matrix_pattern( + self.data["beta_eigenvalues"] = read_matrix_pattern( header_pattern, footer_pattern, elements_pattern, self.text, postprocess=float ) self.data["alpha_eigenvalues"] = alpha_eigenvalues - if spin_unrestricted: - self.data["beta_eigenvalues"] = beta_eigenvalues - def _read_fock_matrix(self): """Parses the Fock matrix. The matrix is read in whole from the output file and then transformed into the right dimensions. @@ -705,13 +702,6 @@ def _read_fock_matrix(self): alpha_fock_matrix = read_matrix_pattern( header_pattern, footer_pattern, elements_pattern, self.text, postprocess=float ) - # The beta Fock matrix is only present if this is a spin-unrestricted calculation. - if spin_unrestricted: - header_pattern = r"Final Beta Fock Matrix" - footer_pattern = "SCF time:" - beta_fock_matrix = read_matrix_pattern( - header_pattern, footer_pattern, elements_pattern, self.text, postprocess=float - ) # Convert the matrices to the right dimension. Right now they are simply # one massive list of numbers, but we need to split them into a matrix. The @@ -720,7 +710,14 @@ def _read_fock_matrix(self): alpha_fock_matrix = process_parsed_fock_matrix(alpha_fock_matrix) self.data["alpha_fock_matrix"] = alpha_fock_matrix + # The beta Fock matrix is only present if this is a spin-unrestricted calculation. if spin_unrestricted: + header_pattern = r"Final Beta Fock Matrix" + footer_pattern = "SCF time:" + beta_fock_matrix = read_matrix_pattern( + header_pattern, footer_pattern, elements_pattern, self.text, postprocess=float + ) + # Perform the same transformation for the beta Fock matrix. beta_fock_matrix = process_parsed_fock_matrix(beta_fock_matrix) self.data["beta_fock_matrix"] = beta_fock_matrix @@ -744,12 +741,6 @@ def _read_coefficient_matrix(self): alpha_coeff_matrix = read_matrix_pattern( header_pattern, footer_pattern, elements_pattern, self.text, postprocess=float ) - if spin_unrestricted: - header_pattern = r"Final Beta MO Coefficients" - footer_pattern = "Final Alpha Density Matrix" - beta_coeff_matrix = read_matrix_pattern( - header_pattern, footer_pattern, elements_pattern, self.text, postprocess=float - ) # Convert the matrices to the right dimension. Right now they are simply # one massive list of numbers, but we need to split them into a matrix. The @@ -759,6 +750,12 @@ def _read_coefficient_matrix(self): self.data["alpha_coeff_matrix"] = alpha_coeff_matrix if spin_unrestricted: + header_pattern = r"Final Beta MO Coefficients" + footer_pattern = "Final Alpha Density Matrix" + beta_coeff_matrix = read_matrix_pattern( + header_pattern, footer_pattern, elements_pattern, self.text, postprocess=float + ) + # Perform the same transformation for the beta Fock matrix. beta_coeff_matrix = process_parsed_fock_matrix(beta_coeff_matrix) self.data["beta_coeff_matrix"] = beta_coeff_matrix @@ -2228,9 +2225,9 @@ def check_for_structure_changes(mol1: Molecule, mol2: Molecule) -> str: # Can add logic to check the distances in the future if desired - initial_mol_graph = MoleculeGraph.with_local_env_strategy(mol_list[0], OpenBabelNN()) + initial_mol_graph = MoleculeGraph.from_local_env_strategy(mol_list[0], OpenBabelNN()) initial_graph = initial_mol_graph.graph - last_mol_graph = MoleculeGraph.with_local_env_strategy(mol_list[1], OpenBabelNN()) + last_mol_graph = MoleculeGraph.from_local_env_strategy(mol_list[1], OpenBabelNN()) last_graph = last_mol_graph.graph if initial_mol_graph.isomorphic_to(last_mol_graph): return "no_change" diff --git a/pymatgen/io/qchem/utils.py b/pymatgen/io/qchem/utils.py index f38a541fa5c..93817ce3404 100644 --- a/pymatgen/io/qchem/utils.py +++ b/pymatgen/io/qchem/utils.py @@ -53,7 +53,7 @@ def read_matrix_pattern(header_pattern, footer_pattern, elements_pattern, text, elements = re.findall(elements_pattern, text_between_header_and_footer) # Apply postprocessing to all the elements - return [postprocess(e) for e in elements] + return [postprocess(elem) for elem in elements] def read_table_pattern( diff --git a/pymatgen/io/res.py b/pymatgen/io/res.py index 95a25763d7f..c394689128f 100644 --- a/pymatgen/io/res.py +++ b/pymatgen/io/res.py @@ -6,16 +6,15 @@ from and back to a string/file is not guaranteed to be reversible, i.e. a diff on the output would not be empty. The difference should be limited to whitespace, float precision, and the REM entries. - """ from __future__ import annotations +import datetime import re from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable, Literal -import dateutil.parser # type: ignore[import] from monty.io import zopen from monty.json import MSONable @@ -26,6 +25,9 @@ if TYPE_CHECKING: from collections.abc import Iterator from datetime import date + from pathlib import Path + + from typing_extensions import Self from pymatgen.core.trajectory import Vector3D @@ -237,10 +239,9 @@ def _parse_str(cls, source: str) -> Res: return self._parse_txt() @classmethod - def _parse_file(cls, filename: str) -> Res: + def _parse_file(cls, filename: str | Path) -> Res: """Parses the res file as a file.""" self = cls() - self.filename = filename with zopen(filename, mode="r") as file: self.source = file.read() return self._parse_txt() @@ -333,12 +334,12 @@ def _site_spin(cls, spin: float | None) -> dict[str, float] | None: return {"magmom": spin} @classmethod - def from_str(cls, string: str) -> ResProvider: + def from_str(cls, string: str) -> Self: """Construct a Provider from a string.""" return cls(ResParser._parse_str(string)) @classmethod - def from_file(cls, filename: str) -> ResProvider: + def from_file(cls, filename: str | Path) -> Self: """Construct a Provider from a file.""" return cls(ResParser._parse_file(filename)) @@ -402,12 +403,12 @@ def __init__(self, res: Res, parse_rems: Literal["gentle", "strict"] = "gentle") self.parse_rems = parse_rems @classmethod - def from_str(cls, string: str, parse_rems: Literal["gentle", "strict"] = "gentle") -> AirssProvider: + def from_str(cls, string: str, parse_rems: Literal["gentle", "strict"] = "gentle") -> Self: """Construct a Provider from a string.""" return cls(ResParser._parse_str(string), parse_rems) @classmethod - def from_file(cls, filename: str, parse_rems: Literal["gentle", "strict"] = "gentle") -> AirssProvider: + def from_file(cls, filename: str | Path, parse_rems: Literal["gentle", "strict"] = "gentle") -> Self: """Construct a Provider from a file.""" return cls(ResParser._parse_file(filename), parse_rems) @@ -417,8 +418,11 @@ def _parse_date(cls, string: str) -> date: match = cls._date_fmt.search(string) if match is None: raise ResParseError(f"Could not parse the date from {string=}.") - date_string = match.group(0) - return dateutil.parser.parse(date_string) + + day, month, year, *_ = match.groups() + month_num = datetime.datetime.strptime(month, "%b").month + + return datetime.date(int(year), month_num, int(day)) def _raise_or_none(self, err: ResParseError) -> None: if self.parse_rems != "strict": @@ -430,7 +434,7 @@ def get_run_start_info(self) -> tuple[date, str] | None: Retrieves the run start date and the path it was started in from the REM entries. Returns: - (date, path) + tuple[date, str]: (date, path) """ for rem in self._res.REMS: if rem.strip().startswith("Run started:"): @@ -459,7 +463,7 @@ def get_func_rel_disp(self) -> tuple[str, str, str] | None: Retrieves the functional, relativity scheme, and dispersion correction from the REM entries. Returns: - (functional, relativity, dispersion) + tuple[str, str, str]: (functional, relativity, dispersion) """ for rem in self._res.REMS: if rem.strip().startswith("Functional"): @@ -474,7 +478,7 @@ def get_cut_grid_gmax_fsbc(self) -> tuple[float, float, float, str] | None: from the REM entries. Returns: - (cut-off, grid scale, Gmax, fsbc) + tuple[float, float, float, str]: (cut-off, grid scale, Gmax, fsbc) """ for rem in self._res.REMS: if rem.strip().startswith("Cut-off"): @@ -488,7 +492,7 @@ def get_mpgrid_offset_nkpts_spacing(self) -> tuple[tuple[int, int, int], Vector3 Retrieves the MP grid, the grid offsets, number of kpoints, and maximum kpoint spacing. Returns: - (MP grid), (offsets), No. kpts, max spacing) + tuple[tuple[int, int, int], Vector3D, int, float]: (MP grid), (offsets), No. kpts, max spacing) """ for rem in self._res.REMS: if rem.strip().startswith("MP grid"): @@ -503,8 +507,8 @@ def get_airss_version(self) -> tuple[str, date] | None: """ Retrieves the version of AIRSS that was used along with the build date (not compile date). - Return: - (version string, date) + Returns: + tuple[str, date] (version string, date) """ for rem in self._res.REMS: if rem.strip().startswith("AIRSS Version"): @@ -531,13 +535,13 @@ def get_pspots(self) -> dict[str, str]: Returns: dict[specie, potential] """ - pspots: dict[str, str] = {} + pseudo_pots: dict[str, str] = {} for rem in self._res.REMS: srem = rem.split() if len(srem) == 2 and Element.is_valid_symbol(srem[0]): k, v = srem - pspots[k] = v - return pspots + pseudo_pots[k] = v + return pseudo_pots @property def seed(self) -> str: diff --git a/pymatgen/io/shengbte.py b/pymatgen/io/shengbte.py index c5fc381b732..666d367a963 100644 --- a/pymatgen/io/shengbte.py +++ b/pymatgen/io/shengbte.py @@ -3,7 +3,7 @@ from __future__ import annotations import warnings -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np from monty.dev import requires @@ -12,6 +12,9 @@ from pymatgen.core.structure import Structure from pymatgen.io.vasp import Kpoints +if TYPE_CHECKING: + from typing_extensions import Self + try: import f90nml except ImportError: @@ -124,7 +127,7 @@ def __init__(self, ngrid: list[int] | None = None, temperature: float | dict[str f90nml, "ShengBTE Control object requires f90nml to be installed. Please get it at https://pypi.org/project/f90nml.", ) - def from_file(cls, filepath: str): + def from_file(cls, filepath: str) -> Self: """ Read a CONTROL namelist file and output a 'Control' object. @@ -147,7 +150,7 @@ def from_file(cls, filepath: str): return cls.from_dict(all_dict) @classmethod - def from_dict(cls, control_dict: dict): + def from_dict(cls, control_dict: dict) -> Self: """ Write a CONTROL file from a Python dictionary. Description and default parameters can be found at @@ -164,7 +167,7 @@ def from_dict(cls, control_dict: dict): f90nml, "ShengBTE Control object requires f90nml to be installed. Please get it at https://pypi.org/project/f90nml.", ) - def to_file(self, filename: str = "CONTROL"): + def to_file(self, filename: str = "CONTROL") -> None: """ Writes ShengBTE CONTROL file from 'Control' object. @@ -191,11 +194,11 @@ def to_file(self, filename: str = "CONTROL"): flags_nml = f90nml.Namelist({"flags": flags_dict}) control_str += str(flags_nml) + "\n" - with open(filename, mode="w") as file: + with open(filename, mode="w", encoding="utf-8") as file: file.write(control_str) @classmethod - def from_structure(cls, structure: Structure, reciprocal_density: int | None = 50000, **kwargs): + def from_structure(cls, structure: Structure, reciprocal_density: int | None = 50000, **kwargs) -> Self: """ Get a ShengBTE control object from a structure. diff --git a/pymatgen/io/vasp/help.py b/pymatgen/io/vasp/help.py index 58f09658ea6..6a9059e71b9 100644 --- a/pymatgen/io/vasp/help.py +++ b/pymatgen/io/vasp/help.py @@ -48,8 +48,8 @@ def get_help(cls, tag, fmt="text"): Help text. """ tag = tag.upper() - r = requests.get(f"https://www.vasp.at/wiki/index.php/{tag}", verify=False) - soup = BeautifulSoup(r.text) + response = requests.get(f"https://www.vasp.at/wiki/index.php/{tag}", verify=False) + soup = BeautifulSoup(response.text) main_doc = soup.find(id="mw-content-text") if fmt == "text": output = main_doc.text @@ -67,8 +67,8 @@ def get_incar_tags(cls): "https://www.vasp.at/wiki/index.php/Category:INCAR", "https://www.vasp.at/wiki/index.php?title=Category:INCAR&pagefrom=ML+FF+LCONF+DISCARD#mw-pages", ]: - r = requests.get(page, verify=False) - soup = BeautifulSoup(r.text) + response = requests.get(page, verify=False) + soup = BeautifulSoup(response.text) for div in soup.findAll("div", {"class": "mw-category-group"}): children = div.findChildren("li") for child in children: diff --git a/pymatgen/io/vasp/incar_parameters.json b/pymatgen/io/vasp/incar_parameters.json index c6697647007..c662e725645 100644 --- a/pymatgen/io/vasp/incar_parameters.json +++ b/pymatgen/io/vasp/incar_parameters.json @@ -156,7 +156,7 @@ "type": "float" }, "ENCUT": { - "type": "int" + "type": "float" }, "ENCUTFOCK": { "type": "float" diff --git a/pymatgen/io/vasp/inputs.py b/pymatgen/io/vasp/inputs.py index 12e927ca151..2efd0704c40 100644 --- a/pymatgen/io/vasp/inputs.py +++ b/pymatgen/io/vasp/inputs.py @@ -16,10 +16,10 @@ import subprocess import warnings from collections import namedtuple -from enum import Enum +from enum import Enum, unique from glob import glob from hashlib import sha256 -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, cast import numpy as np import scipy.constants as const @@ -37,8 +37,10 @@ if TYPE_CHECKING: from collections.abc import Iterator, Sequence + from pathlib import Path from numpy.typing import ArrayLike + from typing_extensions import Self from pymatgen.core.trajectory import Vector3D from pymatgen.util.typing import PathLike @@ -218,7 +220,7 @@ def __setattr__(self, name, value): super().__setattr__(name, value) @classmethod - def from_file(cls, filename, check_for_potcar=True, read_velocities=True, **kwargs) -> Poscar: + def from_file(cls, filename, check_for_potcar=True, read_velocities=True, **kwargs) -> Self: """ Reads a Poscar from a file. @@ -269,7 +271,7 @@ def from_file(cls, filename, check_for_potcar=True, read_velocities=True, **kwar return cls.from_str(file.read(), names, read_velocities=read_velocities) @classmethod - def from_str(cls, data, default_names=None, read_velocities=True) -> Poscar: + def from_str(cls, data, default_names=None, read_velocities=True) -> Self: """ Reads a Poscar from a string. @@ -325,9 +327,12 @@ def from_str(cls, data, default_names=None, read_velocities=True) -> Poscar: lattice *= scale vasp5_symbols = False + atomic_symbols = [] + try: n_atoms = [int(i) for i in lines[5].split()] ipos = 6 + except ValueError: vasp5_symbols = True symbols = [symbol.split("/")[0] for symbol in lines[5].split()] @@ -354,13 +359,14 @@ def from_str(cls, data, default_names=None, read_velocities=True) -> Poscar: break except ValueError: pass + for i_line_symbols in range(6, 5 + n_lines_symbols): symbols.extend(lines[i_line_symbols].split()) n_atoms = [] iline_natoms_start = 5 + n_lines_symbols for iline_natoms in range(iline_natoms_start, iline_natoms_start + n_lines_symbols): n_atoms.extend([int(i) for i in lines[iline_natoms].split()]) - atomic_symbols = [] + for i, nat in enumerate(n_atoms): atomic_symbols.extend([symbols[i]] * nat) ipos = 5 + 2 * n_lines_symbols @@ -398,11 +404,12 @@ def from_str(cls, data, default_names=None, read_velocities=True) -> Poscar: if not all(Element.is_valid_symbol(sym) for sym in atomic_symbols): raise ValueError("Non-valid symbols detected.") vasp5_symbols = True + except (ValueError, IndexError): # Defaulting to false names. atomic_symbols = [] - for i, nat in enumerate(n_atoms): - sym = Element.from_Z(i + 1).symbol + for i, nat in enumerate(n_atoms, start=1): + sym = Element.from_Z(i).symbol atomic_symbols.extend([sym] * nat) warnings.warn( f"Elements in POSCAR cannot be determined. Defaulting to false names {atomic_symbols}.", @@ -606,9 +613,10 @@ def as_dict(self) -> dict: } @classmethod - def from_dict(cls, dct: dict) -> Poscar: + def from_dict(cls, dct: dict) -> Self: """ - :param dct: Dict representation. + Args: + dct (dict): Dict representation. Returns: Poscar @@ -716,7 +724,7 @@ def as_dict(self) -> dict: return dct @classmethod - def from_dict(cls, dct: dict[str, Any]) -> Incar: + def from_dict(cls, dct: dict[str, Any]) -> Self: """ Args: dct (dict): Serialized Incar @@ -783,7 +791,7 @@ def write_file(self, filename: PathLike): file.write(str(self)) @classmethod - def from_file(cls, filename: PathLike) -> Incar: + def from_file(cls, filename: PathLike) -> Self: """Reads an Incar object from a file. Args: @@ -796,7 +804,7 @@ def from_file(cls, filename: PathLike) -> Incar: return cls.from_str(file.read()) @classmethod - def from_str(cls, string: str) -> Incar: + def from_str(cls, string: str) -> Self: """Reads an Incar object from a string. Args: @@ -809,9 +817,9 @@ def from_str(cls, string: str) -> Incar: params = {} for line in lines: for sline in line.split(";"): - if m := re.match(r"(\w+)\s*=\s*(.*)", sline.strip()): - key = m.group(1).strip() - val = m.group(2).strip() + if match := re.match(r"(\w+)\s*=\s*(.*)", sline.strip()): + key = match.group(1).strip() + val = match.group(2).strip() val = Incar.proc_val(key, val) params[key] = val return cls(params) @@ -880,11 +888,10 @@ def smart_int_or_float(num_str): output.append(smart_int_or_float(tok[0])) return output if key in bool_keys: - m = re.match(r"^\.?([T|F|t|f])[A-Za-z]*\.?", val) - if m: - return m.group(1).lower() == "t" + if match := re.match(r"^\.?([T|F|t|f])[A-Za-z]*\.?", val): + return match.group(1).lower() == "t" - raise ValueError(key + " should be a boolean type!") + raise ValueError(f"{key} should be a boolean type!") if key in float_keys: return float(re.search(r"^-?\d*\.?\d*[e|E]?-?\d*", val).group(0)) # type: ignore @@ -976,7 +983,7 @@ def check_params(self) -> None: param_type = incar_params[tag].get("type") allowed_values = incar_params[tag].get("values") - if param_type is not None and not isinstance(val, eval(param_type)): + if param_type is not None and type(val).__name__ != param_type: warnings.warn(f"{tag}: {val} is not a {param_type}", BadIncarWarning, stacklevel=2) # Only check value when it's not None, @@ -989,6 +996,7 @@ class BadIncarWarning(UserWarning): """Warning class for bad INCAR parameters.""" +@unique class KpointsSupportedModes(Enum): """Enum type of all supported modes for Kpoint generation.""" @@ -1003,15 +1011,16 @@ def __str__(self): return str(self.name) @classmethod - def from_str(cls, mode: str) -> KpointsSupportedModes: + def from_str(cls, mode: str) -> Self: """ - :param s: String + Args: + mode: String Returns: Kpoints_supported_modes """ initial = mode.lower()[0] - for key in KpointsSupportedModes: + for key in cls: if key.name.lower()[0] == initial: return key raise ValueError(f"Invalid Kpoint {mode=}") @@ -1027,7 +1036,7 @@ def __init__( comment: str = "Default gamma", num_kpts: int = 0, style: KpointsSupportedModes = supported_modes.Gamma, - kpts: Sequence[float | Sequence] = ((1, 1, 1),), + kpts: Sequence[float | Sequence[float]] = ((1, 1, 1),), kpts_shift: Vector3D = (0, 0, 0), kpts_weights=None, coord_type=None, @@ -1085,7 +1094,7 @@ def __init__( self.style = style self.coord_type = coord_type self.kpts_weights = kpts_weights - self.kpts_shift = kpts_shift + self.kpts_shift = tuple(kpts_shift) self.labels = labels self.tet_number = tet_number self.tet_weight = tet_weight @@ -1099,13 +1108,12 @@ def style(self) -> KpointsSupportedModes: return self._style @style.setter - def style(self, style): + def style(self, style) -> None: """ - :param style: Style + Sets the style for the Kpoints. One of Kpoints_supported_modes enum. - Returns: - Sets the style for the Kpoints. One of Kpoints_supported_modes - enum. + Args: + style: Style """ if isinstance(style, str): style = Kpoints.supported_modes.from_str(style) @@ -1123,10 +1131,10 @@ def style(self, style): self._style = style - @staticmethod - def automatic(subdivisions): + @classmethod + def automatic(cls, subdivisions) -> Self: """ - Convenient static constructor for a fully automatic Kpoint grid, with + Constructor for a fully automatic Kpoint grid, with gamma centered Monkhorst-Pack grids and the number of subdivisions along each reciprocal lattice vector determined by the scheme in the VASP manual. @@ -1138,15 +1146,12 @@ def automatic(subdivisions): Returns: Kpoints object """ - return Kpoints( - "Fully automatic kpoint scheme", 0, style=Kpoints.supported_modes.Automatic, kpts=[[subdivisions]] - ) + return cls("Fully automatic kpoint scheme", 0, style=Kpoints.supported_modes.Automatic, kpts=[[subdivisions]]) - @staticmethod - def gamma_automatic(kpts: tuple[int, int, int] = (1, 1, 1), shift: Vector3D = (0, 0, 0)): + @classmethod + def gamma_automatic(cls, kpts: tuple[int, int, int] = (1, 1, 1), shift: Vector3D = (0, 0, 0)) -> Self: """ - Convenient static constructor for an automatic Gamma centered Kpoint - grid. + Constructor for an automatic Gamma centered Kpoint grid. Args: kpts: Subdivisions N_1, N_2 and N_3 along reciprocal lattice @@ -1156,10 +1161,10 @@ def gamma_automatic(kpts: tuple[int, int, int] = (1, 1, 1), shift: Vector3D = (0 Returns: Kpoints object """ - return Kpoints("Automatic kpoint scheme", 0, Kpoints.supported_modes.Gamma, kpts=[kpts], kpts_shift=shift) + return cls("Automatic kpoint scheme", 0, Kpoints.supported_modes.Gamma, kpts=[kpts], kpts_shift=shift) - @staticmethod - def monkhorst_automatic(kpts: tuple[int, int, int] = (2, 2, 2), shift: Vector3D = (0, 0, 0)): + @classmethod + def monkhorst_automatic(cls, kpts: tuple[int, int, int] = (2, 2, 2), shift: Vector3D = (0, 0, 0)) -> Self: """ Convenient static constructor for an automatic Monkhorst pack Kpoint grid. @@ -1172,10 +1177,10 @@ def monkhorst_automatic(kpts: tuple[int, int, int] = (2, 2, 2), shift: Vector3D Returns: Kpoints object """ - return Kpoints("Automatic kpoint scheme", 0, Kpoints.supported_modes.Monkhorst, kpts=[kpts], kpts_shift=shift) + return cls("Automatic kpoint scheme", 0, Kpoints.supported_modes.Monkhorst, kpts=[kpts], kpts_shift=shift) - @staticmethod - def automatic_density(structure: Structure, kppa: float, force_gamma: bool = False): + @classmethod + def automatic_density(cls, structure: Structure, kppa: float, force_gamma: bool = False) -> Self: """ Returns an automatic Kpoint object based on a structure and a kpoint density. Uses Gamma centered meshes for hexagonal cells and face-centered cells, @@ -1202,7 +1207,7 @@ def automatic_density(structure: Structure, kppa: float, force_gamma: bool = Fal ngrid = kppa / len(structure) mult = (ngrid * lengths[0] * lengths[1] * lengths[2]) ** (1 / 3) - num_div = [int(math.floor(max(mult / length, 1))) for length in lengths] + num_div = [math.floor(max(mult / length, 1)) for length in lengths] is_hexagonal = lattice.is_hexagonal() is_face_centered = structure.get_space_group_info()[0][0] == "F" @@ -1212,10 +1217,10 @@ def automatic_density(structure: Structure, kppa: float, force_gamma: bool = Fal else: style = Kpoints.supported_modes.Monkhorst - return Kpoints(comment, 0, style, [num_div], (0, 0, 0)) + return cls(comment, 0, style, [num_div], (0, 0, 0)) - @staticmethod - def automatic_gamma_density(structure: Structure, kppa: float): + @classmethod + def automatic_gamma_density(cls, structure: Structure, kppa: float) -> Self: """ Returns an automatic Kpoint object based on a structure and a kpoint density. Uses Gamma centered meshes always. For GW. @@ -1228,28 +1233,28 @@ def automatic_gamma_density(structure: Structure, kppa: float): structure: Input structure kppa: Grid density """ - latt = structure.lattice - a, b, c = latt.abc - ngrid = kppa / len(structure) + lattice = structure.lattice + a, b, c = lattice.abc + n_grid = kppa / len(structure) - mult = (ngrid * a * b * c) ** (1 / 3) - num_div = [int(round(mult / length)) for length in latt.abc] + multip = (n_grid * a * b * c) ** (1 / 3) + n_div = [int(round(multip / length)) for length in lattice.abc] # ensure that all num_div[i] > 0 - num_div = [idx if idx > 0 else 1 for idx in num_div] + n_div = [idx if idx > 0 else 1 for idx in n_div] # VASP documentation recommends to use even grids for n <= 8 and odd grids for n > 8. - num_div = [idx + idx % 2 if idx <= 8 else idx - idx % 2 + 1 for idx in num_div] + n_div = [idx + idx % 2 if idx <= 8 else idx - idx % 2 + 1 for idx in n_div] style = Kpoints.supported_modes.Gamma comment = f"pymatgen with grid density = {kppa:.0f} / number of atoms" - num_kpts = 0 - return Kpoints(comment, num_kpts, style, [num_div], (0, 0, 0)) + n_kpts = 0 + return cls(comment, n_kpts, style, [n_div], (0, 0, 0)) - @staticmethod - def automatic_density_by_vol(structure: Structure, kppvol: int, force_gamma: bool = False) -> Kpoints: + @classmethod + def automatic_density_by_vol(cls, structure: Structure, kppvol: int, force_gamma: bool = False) -> Self: """ Returns an automatic Kpoint object based on a structure and a kpoint density per inverse Angstrom^3 of reciprocal cell. @@ -1267,12 +1272,12 @@ def automatic_density_by_vol(structure: Structure, kppvol: int, force_gamma: boo """ vol = structure.lattice.reciprocal_lattice.volume kppa = kppvol * vol * len(structure) - return Kpoints.automatic_density(structure, kppa, force_gamma=force_gamma) + return cls.automatic_density(structure, kppa, force_gamma=force_gamma) - @staticmethod + @classmethod def automatic_density_by_lengths( - structure: Structure, length_densities: Sequence[float], force_gamma: bool = False - ): + cls, structure: Structure, length_densities: Sequence[float], force_gamma: bool = False + ) -> Self: """ Returns an automatic Kpoint object based on a structure and a k-point density normalized by lattice constants. @@ -1306,10 +1311,10 @@ def automatic_density_by_lengths( else: style = Kpoints.supported_modes.Monkhorst - return Kpoints(comment, 0, style, [num_div], (0, 0, 0)) + return cls(comment, 0, style, [num_div], (0, 0, 0)) - @staticmethod - def automatic_linemode(divisions, ibz): + @classmethod + def automatic_linemode(cls, divisions, ibz) -> Self: """ Convenient static constructor for a KPOINTS in mode line_mode. gamma centered Monkhorst-Pack grids and the number of subdivisions @@ -1337,7 +1342,7 @@ def automatic_linemode(divisions, ibz): kpoints.append(ibz.kpath["kpoints"][path[-1]]) labels.append(path[-1]) - return Kpoints( + return cls( "Line_mode KPOINTS file", style=Kpoints.supported_modes.Line_mode, coord_type="Reciprocal", @@ -1355,7 +1360,7 @@ def __eq__(self, other: object) -> bool: return self.as_dict() == other.as_dict() @classmethod - def from_file(cls, filename): + def from_file(cls, filename: str | Path) -> Self: """ Reads a Kpoints object from a KPOINTS file. @@ -1369,7 +1374,7 @@ def from_file(cls, filename): return cls.from_str(file.read()) @classmethod - def from_str(cls, string): + def from_str(cls, string: str) -> Self: """ Reads a Kpoints object from a KPOINTS string. @@ -1392,63 +1397,69 @@ def from_str(cls, string): coord_pattern = re.compile(r"^\s*([\d+.\-Ee]+)\s+([\d+.\-Ee]+)\s+([\d+.\-Ee]+)") # Automatic gamma and Monk KPOINTS, with optional shift - if style in ["g", "m"]: - kpts = [int(i) for i in lines[3].split()] - kpts_shift = (0, 0, 0) + if style in {"g", "m"}: + kpts = tuple(int(i) for i in lines[3].split()) + assert len(kpts) == 3 + + kpts_shift: tuple[float, float, float] = (0, 0, 0) if len(lines) > 4 and coord_pattern.match(lines[4]): try: - kpts_shift = [float(i) for i in lines[4].split()] + _kpts_shift = tuple(float(i) for i in lines[4].split()) except ValueError: - pass + _kpts_shift = (0, 0, 0) + + if len(_kpts_shift) == 3: + kpts_shift = _kpts_shift + return cls.gamma_automatic(kpts, kpts_shift) if style == "g" else cls.monkhorst_automatic(kpts, kpts_shift) # Automatic kpoints with basis if num_kpts <= 0: - style = cls.supported_modes.Cartesian if style in "ck" else cls.supported_modes.Reciprocal - kpts = [[float(j) for j in lines[i].split()] for i in range(3, 6)] - kpts_shift = [float(i) for i in lines[6].split()] - return Kpoints( + _style = cls.supported_modes.Cartesian if style in "ck" else cls.supported_modes.Reciprocal + _kpts_shift = tuple(float(i) for i in lines[6].split()) + kpts_shift = _kpts_shift if len(_kpts_shift) == 3 else (0, 0, 0) + + return cls( comment=comment, num_kpts=num_kpts, - style=style, - kpts=kpts, + style=_style, + kpts=[[float(j) for j in lines[i].split()] for i in range(3, 6)], kpts_shift=kpts_shift, ) # Line-mode KPOINTS, usually used with band structures if style == "l": coord_type = "Cartesian" if lines[3].lower()[0] in "ck" else "Reciprocal" - style = cls.supported_modes.Line_mode - kpts = [] + _style = cls.supported_modes.Line_mode + _kpts: list[list[float]] = [] labels = [] patt = re.compile(r"([e0-9.\-]+)\s+([e0-9.\-]+)\s+([e0-9.\-]+)\s*!*\s*(.*)") for idx in range(4, len(lines)): line = lines[idx] - m = patt.match(line) - if m: - kpts.append([float(m.group(1)), float(m.group(2)), float(m.group(3))]) - labels.append(m.group(4).strip()) - return Kpoints( + if match := patt.match(line): + _kpts.append([float(match.group(1)), float(match.group(2)), float(match.group(3))]) + labels.append(match.group(4).strip()) + return cls( comment=comment, num_kpts=num_kpts, - style=style, - kpts=kpts, + style=_style, + kpts=_kpts, coord_type=coord_type, labels=labels, ) # Assume explicit KPOINTS if all else fails. - style = cls.supported_modes.Cartesian if style in "ck" else cls.supported_modes.Reciprocal - kpts = [] + _style = cls.supported_modes.Cartesian if style in "ck" else cls.supported_modes.Reciprocal + _kpts = [] kpts_weights = [] labels = [] tet_number = 0 - tet_weight = 0 + tet_weight: float = 0 tet_connections = None for idx in range(3, 3 + num_kpts): tokens = lines[idx].split() - kpts.append([float(j) for j in tokens[0:3]]) + _kpts.append([float(j) for j in tokens[:3]]) kpts_weights.append(float(tokens[3])) if len(tokens) > 4: labels.append(tokens[4]) @@ -1470,8 +1481,8 @@ def from_str(cls, string): return cls( comment=comment, num_kpts=num_kpts, - style=cls.supported_modes[str(style)], - kpts=kpts, + style=cls.supported_modes[str(_style)], + kpts=_kpts, kpts_weights=kpts_weights, tet_number=tet_number, tet_weight=tet_weight, @@ -1479,7 +1490,7 @@ def from_str(cls, string): labels=labels, ) - def write_file(self, filename): + def write_file(self, filename: str) -> None: """ Write Kpoints to a file. @@ -1518,7 +1529,7 @@ def __repr__(self): lines.append(" ".join(map(str, self.kpts_shift))) return "\n".join(lines) + "\n" - def as_dict(self): + def as_dict(self) -> dict: """MSONable dict.""" dct = { "comment": self.comment, @@ -1543,49 +1554,50 @@ def as_dict(self): return dct @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """ - :param d: Dict representation. + Args: + dct (dict): Dict representation. Returns: Kpoints """ - comment = d.get("comment", "") - generation_style = d.get("generation_style") - kpts = d.get("kpoints", [[1, 1, 1]]) - kpts_shift = d.get("usershift", [0, 0, 0]) - num_kpts = d.get("nkpoints", 0) + comment = dct.get("comment", "") + generation_style = cast(KpointsSupportedModes, dct.get("generation_style")) + kpts = dct.get("kpoints", [[1, 1, 1]]) + kpts_shift = dct.get("usershift", [0, 0, 0]) + num_kpts = dct.get("nkpoints", 0) return cls( comment=comment, kpts=kpts, style=generation_style, kpts_shift=kpts_shift, num_kpts=num_kpts, - kpts_weights=d.get("kpts_weights"), - coord_type=d.get("coord_type"), - labels=d.get("labels"), - tet_number=d.get("tet_number", 0), - tet_weight=d.get("tet_weight", 0), - tet_connections=d.get("tet_connections"), + kpts_weights=dct.get("kpts_weights"), + coord_type=dct.get("coord_type"), + labels=dct.get("labels"), + tet_number=dct.get("tet_number", 0), + tet_weight=dct.get("tet_weight", 0), + tet_connections=dct.get("tet_connections"), ) -def _parse_bool(s): - if m := re.match(r"^\.?([TFtf])[A-Za-z]*\.?", s): - return m[1] in ["T", "t"] - raise ValueError(f"{s} should be a boolean type!") +def _parse_bool(string): + if match := re.match(r"^\.?([TFtf])[A-Za-z]*\.?", string): + return match[1] in {"T", "t"} + raise ValueError(f"{string} should be a boolean type!") -def _parse_float(s): - return float(re.search(r"^-?\d*\.?\d*[eE]?-?\d*", s).group(0)) +def _parse_float(string): + return float(re.search(r"^-?\d*\.?\d*[eE]?-?\d*", string).group(0)) -def _parse_int(s): - return int(re.match(r"^-?[0-9]+", s).group(0)) +def _parse_int(string): + return int(re.match(r"^-?[0-9]+", string).group(0)) -def _parse_list(s): - return [float(y) for y in re.split(r"\s+", s.strip()) if not y.isalpha()] +def _parse_list(string): + return [float(y) for y in re.split(r"\s+", string.strip()) if not y.isalpha()] Orbital = namedtuple("Orbital", ["n", "l", "j", "E", "occ"]) @@ -1829,16 +1841,17 @@ def copy(self) -> PotcarSingle: return PotcarSingle(self.data, symbol=self.symbol) @classmethod - def from_file(cls, filename: str) -> PotcarSingle: + def from_file(cls, filename: str) -> Self: """Reads PotcarSingle from file. - :param filename: Filename. + Args: + filename: Filename. Returns: PotcarSingle """ match = re.search(r"(?<=POTCAR\.)(.*)(?=.gz)", str(filename)) - symbol = match.group(0) if match else "" + symbol = match[0] if match else "" try: with zopen(filename, mode="rt") as file: @@ -1850,7 +1863,7 @@ def from_file(cls, filename: str) -> PotcarSingle: return cls(file.read(), symbol=symbol or None) @classmethod - def from_symbol_and_functional(cls, symbol: str, functional: str | None = None): + def from_symbol_and_functional(cls, symbol: str, functional: str | None = None) -> Self: """Makes a PotcarSingle from a symbol and functional. Args: @@ -2169,12 +2182,10 @@ def md5_header_hash(self) -> str: if k in ("nentries", "Orbitals", "SHA256", "COPYR"): continue hash_str += f"{k}" - if isinstance(v, bool): + if isinstance(v, (bool, int)): hash_str += f"{v}" elif isinstance(v, float): hash_str += f"{v:.3f}" - elif isinstance(v, int): - hash_str += f"{v}" elif isinstance(v, (tuple, list)): for item in v: if isinstance(item, float): @@ -2467,8 +2478,10 @@ def __init__( if symbols is not None: self.set_symbols(symbols, functional, sym_potcar_map) - def __iter__(self) -> Iterator[PotcarSingle]: # boilerplate code. only here to supply - # type hint so `for psingle in Potcar()` is correctly inferred as PotcarSingle + def __iter__(self) -> Iterator[PotcarSingle]: + """Boilerplate code. Only here to supply type hint so + `for psingle in Potcar()` is correctly inferred as PotcarSingle + """ return super().__iter__() def as_dict(self): @@ -2481,21 +2494,23 @@ def as_dict(self): } @classmethod - def from_dict(cls, d): + def from_dict(cls, dct) -> Self: """ - :param d: Dict representation + Args: + dct (dict): Dict representation. Returns: Potcar """ - return Potcar(symbols=d["symbols"], functional=d["functional"]) + return Potcar(symbols=dct["symbols"], functional=dct["functional"]) @classmethod - def from_file(cls, filename: str): + def from_file(cls, filename: str) -> Self: """ Reads Potcar from file. - :param filename: Filename + Args: + filename: Filename Returns: Potcar @@ -2575,7 +2590,7 @@ class VaspInput(dict, MSONable): def __init__( self, - incar: Incar, + incar: dict | Incar, kpoints: Kpoints | None, poscar: Poscar, potcar: Potcar | None, @@ -2614,21 +2629,21 @@ def as_dict(self): return dct @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: """ - :param d: Dict representation. + Args: + dct (dict): Dict representation. Returns: VaspInput """ - dec = MontyDecoder() - sub_dct = {"optional_files": {}} + sub_dct: dict[str, dict] = {"optional_files": {}} for key, val in dct.items(): if key in ["INCAR", "POSCAR", "POTCAR", "KPOINTS"]: - sub_dct[key.lower()] = dec.process_decoded(val) + sub_dct[key.lower()] = MontyDecoder().process_decoded(val) elif key not in ["@module", "@class"]: - sub_dct["optional_files"][key] = dec.process_decoded(val) - return cls(**sub_dct) + sub_dct["optional_files"][key] = MontyDecoder().process_decoded(val) + return cls(**sub_dct) # type: ignore[arg-type] def write_input(self, output_dir=".", make_dir_if_not_present=True): """ @@ -2648,7 +2663,7 @@ def write_input(self, output_dir=".", make_dir_if_not_present=True): file.write(str(v)) @classmethod - def from_directory(cls, input_dir, optional_files=None): + def from_directory(cls, input_dir: str, optional_files: dict | None = None) -> Self: """ Read in a set of VASP input from a directory. Note that only the standard INCAR, POSCAR, POTCAR and KPOINTS files are read unless @@ -2669,14 +2684,14 @@ def from_directory(cls, input_dir, optional_files=None): ]: try: full_zpath = zpath(os.path.join(input_dir, fname)) - sub_dct[fname.lower()] = ftype.from_file(full_zpath) + sub_dct[fname.lower()] = ftype.from_file(full_zpath) # type: ignore[attr-defined] except FileNotFoundError: # handle the case where there is no KPOINTS file sub_dct[fname.lower()] = None - sub_dct["optional_files"] = {} - if optional_files is not None: - for fname, ftype in optional_files.items(): - sub_dct["optional_files"][fname] = ftype.from_file(os.path.join(input_dir, fname)) + sub_dct["optional_files"] = { + fname: ftype.from_file(os.path.join(input_dir, fname)) for fname, ftype in (optional_files or {}).items() + } + return cls(**sub_dct) def copy(self, deep: bool = True): @@ -2691,15 +2706,16 @@ def run_vasp( vasp_cmd: list | None = None, output_file: PathLike = "vasp.out", err_file: PathLike = "vasp.err", - ): + ) -> None: """ Write input files and run VASP. - :param run_dir: Where to write input files and do the run. - :param vasp_cmd: Args to be supplied to run VASP. Otherwise, the - PMG_VASP_EXE in .pmgrc.yaml is used. - :param output_file: File to write output. - :param err_file: File to write err. + Args: + run_dir: Where to write input files and do the run. + vasp_cmd: Args to be supplied to run VASP. Otherwise, the + PMG_VASP_EXE in .pmgrc.yaml is used. + output_file: File to write output. + err_file: File to write err. """ self.write_input(output_dir=run_dir) vasp_cmd = vasp_cmd or SETTINGS.get("PMG_VASP_EXE") # type: ignore[assignment] diff --git a/pymatgen/io/vasp/optics.py b/pymatgen/io/vasp/optics.py index efc17ad6ada..35ccd1a3cdf 100644 --- a/pymatgen/io/vasp/optics.py +++ b/pymatgen/io/vasp/optics.py @@ -19,6 +19,7 @@ from pathlib import Path from numpy.typing import ArrayLike, NDArray + from typing_extensions import Self __author__ = "Jimmy-Xuan Shen" __copyright__ = "Copyright 2022, The Materials Project" @@ -70,7 +71,7 @@ class DielectricFunctionCalculator(MSONable): volume: float @classmethod - def from_vasp_objects(cls, vrun: Vasprun, waveder: Waveder): + def from_vasp_objects(cls, vrun: Vasprun, waveder: Waveder) -> Self: """Construct a DielectricFunction from Vasprun, Kpoint, and Waveder objects. Args: @@ -94,7 +95,7 @@ def from_vasp_objects(cls, vrun: Vasprun, waveder: Waveder): if vrun.parameters["ISYM"] != 0: raise NotImplementedError("ISYM != 0 is not implemented yet") - return DielectricFunctionCalculator( + return cls( cder_real=waveder.cder_real, cder_imag=waveder.cder_imag, eigs=eigs, @@ -110,7 +111,7 @@ def from_vasp_objects(cls, vrun: Vasprun, waveder: Waveder): ) @classmethod - def from_directory(cls, directory: Path | str): + def from_directory(cls, directory: Path | str) -> Self: """Construct a DielectricFunction from a directory containing vasprun.xml and WAVEDER files.""" def _try_reading(dtypes): @@ -306,7 +307,7 @@ def get_delta(x0: float, sigma: float, nx: int, dx: float, ismear: int = 3): dx: The gridspacing of the output grid. ismear: The smearing parameter used by the ``step_func``. - Return: + Returns: np.array: Array of size `nx` with delta function on the desired outputgrid. """ xgrid = np.linspace(0, nx * dx, nx, endpoint=False) @@ -330,7 +331,7 @@ def get_step(x0, sigma, nx, dx, ismear): dx: The gridspacing of the output grid. ismear: The smearing parameter used by the ``step_func``. - Return: + Returns: np.array: Array of size `nx` with step function on the desired outputgrid. """ xgrid = np.linspace(0, nx * dx, nx, endpoint=False) @@ -367,7 +368,7 @@ def epsilon_imag( jdir: The second direction of the dielectric tensor mask: Mask for the bands/kpoint/spin index to include in the calculation - Return: + Returns: np.array: Array of size `nedos` with the imaginary part of the dielectric function. """ norm_kweights = np.array(kweights) / np.sum(kweights) @@ -433,7 +434,7 @@ def kramers_kronig( deltae: The energy grid spacing cshift: The shift of the imaginary part of the dielectric function. - Return: + Returns: np.array: Array of size `nedos` with the complex dielectric function. """ egrid = np.linspace(0, deltae * nedos, nedos) diff --git a/pymatgen/io/vasp/outputs.py b/pymatgen/io/vasp/outputs.py index 16256904e0c..deacb1231c2 100644 --- a/pymatgen/io/vasp/outputs.py +++ b/pymatgen/io/vasp/outputs.py @@ -16,7 +16,7 @@ from glob import glob from io import StringIO from pathlib import Path -from typing import Literal +from typing import TYPE_CHECKING, Literal import numpy as np from monty.io import reverse_readfile, zopen @@ -42,6 +42,9 @@ from pymatgen.util.io_utils import clean_lines, micro_pyawk from pymatgen.util.num import make_symmetric_matrix_from_upper_tri +if TYPE_CHECKING: + from typing_extensions import Self + logger = logging.getLogger(__name__) @@ -86,23 +89,23 @@ def _parse_v_parameters(val_type, val, filename, param_name): elif val_type == "int": try: val = [int(i) for i in val.split()] - except ValueError: + except ValueError as exc: # Fix for stupid error in vasprun sometimes which displays # LDAUL/J as 2**** val = _parse_from_incar(filename, param_name) if val is None: - raise err + raise err from exc elif val_type == "string": val = val.split() else: try: val = [float(i) for i in val.split()] - except ValueError: + except ValueError as exc: # Fix for stupid error in vasprun sometimes which displays # MAGMOM as 2**** val = _parse_from_incar(filename, param_name) if val is None: - raise err + raise err from exc return val @@ -115,10 +118,10 @@ def _parse_vasp_array(elem) -> list[list[float]]: def _parse_from_incar(filename: str, key: str) -> str | None: """Helper function to parse a parameter from the INCAR.""" dirname = os.path.dirname(filename) - for f in os.listdir(dirname): - if re.search(r"INCAR", f): - warnings.warn("INCAR found. Using " + key + " from INCAR.") - incar = Incar.from_file(os.path.join(dirname, f)) + for filename in os.listdir(dirname): + if re.search(r"INCAR", filename): + warnings.warn(f"INCAR found. Using {key} from INCAR.") + incar = Incar.from_file(os.path.join(dirname, filename)) if key in incar: return incar[key] return None @@ -352,7 +355,7 @@ def _parse(self, stream, parse_dos, parse_eigen, parse_projected_eigen): # The start event tells us when we have entered blocks if tag == "calculation": parsed_header = True - elif tag == "eigenvalues_kpoints_opt" or tag == "projected_kpoints_opt": + elif tag in ("eigenvalues_kpoints_opt", "projected_kpoints_opt"): in_kpoints_opt = True else: # event == "end": # The end event happens when we have read a block, so have @@ -408,7 +411,7 @@ def _parse(self, stream, parse_dos, parse_eigen, parse_projected_eigen): self.eigenvalues = self._parse_eigen(elem) elif parse_projected_eigen and tag == "projected" and not in_kpoints_opt: self.projected_eigenvalues, self.projected_magnetisation = self._parse_projected_eigen(elem) - elif tag == "eigenvalues_kpoints_opt" or tag == "projected_kpoints_opt": + elif tag in ("eigenvalues_kpoints_opt", "projected_kpoints_opt"): in_kpoints_opt = False if self.kpoints_opt_props is None: self.kpoints_opt_props = KpointOptProps() @@ -552,7 +555,7 @@ def dielectric(self): return self.dielectric_data["density"] @property - def optical_absorption_coeff(self) -> list[float]: + def optical_absorption_coeff(self) -> list[float] | None: """ Calculate the optical absorption coefficient from the dielectric constants. Note that this method is only @@ -563,11 +566,11 @@ def optical_absorption_coeff(self) -> list[float]: """ if self.dielectric_data["density"]: real_avg = [ - sum(self.dielectric_data["density"][1][i][0:3]) / 3 + sum(self.dielectric_data["density"][1][i][:3]) / 3 for i in range(len(self.dielectric_data["density"][0])) ] imag_avg = [ - sum(self.dielectric_data["density"][2][i][0:3]) / 3 + sum(self.dielectric_data["density"][2][i][:3]) / 3 for i in range(len(self.dielectric_data["density"][0])) ] @@ -579,10 +582,10 @@ def optical_absorb_coeff(freq, real, imag): hc = 1.23984 * 1e-4 # plank constant times speed of light, in the unit of eV*cm return 2 * 3.14159 * np.sqrt(np.sqrt(real**2 + imag**2) - real) * np.sqrt(2) / hc * freq - absorption_coeff = list( + return list( itertools.starmap(optical_absorb_coeff, zip(self.dielectric_data["density"][0], real_avg, imag_avg)) ) - return absorption_coeff + return None @property def converged_electronic(self) -> bool: @@ -883,7 +886,7 @@ def get_band_structure( if not kpoints_filename: kpts_path = os.path.join(os.path.dirname(self.filename), "KPOINTS_OPT" if use_kpoints_opt else "KPOINTS") kpoints_filename = zpath(kpts_path) - if kpoints_filename and not os.path.isfile(kpoints_filename) and line_mode is True: + if kpoints_filename and not os.path.isfile(kpoints_filename) and line_mode: name = "KPOINTS_OPT" if use_kpoints_opt else "KPOINTS" raise VaspParseError(f"{name} not found but needed to obtain band structure along symmetry lines.") @@ -971,10 +974,9 @@ def get_band_structure( eigenvals[Spin.up] = up_eigen else: if "" in kpoint_file.labels: - raise Exception( - "A band structure along symmetry lines " - "requires a label for each kpoint. " - "Check your KPOINTS file" + raise ValueError( + "A band structure along symmetry lines requires a label " + "for each kpoint. Check your KPOINTS file" ) labels_dict = dict(zip(kpoint_file.labels, kpoint_file.kpts)) labels_dict.pop(None, None) @@ -1162,7 +1164,8 @@ def update_charge_from_potcar(self, path): """ Sets the charge of a structure based on the POTCARs found. - :param path: Path to search for POTCARs + Args: + path: Path to search for POTCARs """ potcar = self.get_potcars(path) @@ -1260,7 +1263,7 @@ def as_dict(self): if self.eigenvalues: eigen = {str(spin): v.tolist() for spin, v in self.eigenvalues.items()} vout["eigenvalues"] = eigen - (gap, cbm, vbm, is_direct) = self.eigenvalue_band_properties + gap, cbm, vbm, is_direct = self.eigenvalue_band_properties vout.update({"bandgap": gap, "cbm": cbm, "vbm": vbm, "is_gap_direct": is_direct}) if self.projected_eigenvalues: @@ -1275,7 +1278,7 @@ def as_dict(self): eigen = {str(spin): v.tolist() for spin, v in kpt_opt_props.eigenvalues.items()} vout["eigenvalues_kpoints_opt"] = eigen # TODO implement kpoints_opt eigenvalue_band_proprties. - # (gap, cbm, vbm, is_direct) = self.eigenvalue_band_properties + # gap, cbm, vbm, is_direct = self.eigenvalue_band_properties # vout.update({"bandgap": gap, "cbm": cbm, "vbm": vbm, "is_gap_direct": is_direct}) if kpt_opt_props.projected_eigenvalues: @@ -1322,6 +1325,9 @@ def _parse_params(self, elem): @staticmethod def _parse_atominfo(elem): + atomic_symbols = [] + potcar_symbols = [] + for a in elem.findall("array"): if a.attrib["name"] == "atoms": atomic_symbols = [rc.find("c").text.strip() for rc in a.find("set")] @@ -1350,6 +1356,7 @@ def _parse_kpoints(elem): e = elem.find("generation") k = Kpoints("Kpoints from vasprun.xml") k.style = Kpoints.supported_modes.from_str(e.attrib.get("param", "Reciprocal")) + for v in e.findall("v"): name = v.attrib.get("name") tokens = v.text.split() @@ -1359,6 +1366,9 @@ def _parse_kpoints(elem): k.kpts_shift = [float(i) for i in tokens] elif name in {"genvec1", "genvec2", "genvec3", "shift"}: setattr(k, name, [float(i) for i in tokens]) + + actual_kpoints = [] + weights = [] for va in elem.findall("varray"): name = va.attrib["name"] if name == "kpointlist": @@ -1366,6 +1376,7 @@ def _parse_kpoints(elem): elif name == "weights": weights = [i[0] for i in _parse_vasp_array(va)] elem.clear() + if k.style == Kpoints.supported_modes.Reciprocal: k = Kpoints( comment="Kpoints from vasprun.xml", @@ -1377,12 +1388,12 @@ def _parse_kpoints(elem): return k, actual_kpoints, weights def _parse_structure(self, elem): - latt = _parse_vasp_array(elem.find("crystal").find("varray")) + lattice = _parse_vasp_array(elem.find("crystal").find("varray")) pos = _parse_vasp_array(elem.find("varray")) - struct = Structure(latt, self.atomic_symbols, pos) - sdyn = elem.find("varray/[@name='selective']") - if sdyn: - struct.add_site_property("selective_dynamics", _parse_vasp_array(sdyn)) + struct = Structure(lattice, self.atomic_symbols, pos) + selective_dyn = elem.find("varray/[@name='selective']") + if selective_dyn: + struct.add_site_property("selective_dynamics", _parse_vasp_array(selective_dyn)) return struct @staticmethod @@ -1403,9 +1414,11 @@ def _parse_optical_transition(elem): for va in elem.findall("varray"): if va.attrib.get("name") == "opticaltransitions": # optical transitions array contains oscillator strength and probability of transition - oscillator_strength = np.array(_parse_vasp_array(va))[0:] - probability_transition = np.array(_parse_vasp_array(va))[0:, 1] - return oscillator_strength, probability_transition + oscillator_strength = np.array(_parse_vasp_array(va))[:] + probability_transition = np.array(_parse_vasp_array(va))[:, 1] + + return oscillator_strength, probability_transition + return None def _parse_chemical_shielding_calculation(self, elem): calculation = [] @@ -1612,58 +1625,53 @@ def __init__( or (tag == "dos" and elem.attrib.get("comment") == "kpoints_opt") ): in_kpoints_opt = True - else: # if event == "end": - if not parsed_header: - if tag == "generator": - self.generator = self._parse_params(elem) - elif tag == "incar": - self.incar = self._parse_params(elem) - elif tag == "kpoints": - ( - self.kpoints, - self.actual_kpoints, - self.actual_kpoints_weights, - ) = self._parse_kpoints(elem) - elif tag == "parameters": - self.parameters = self._parse_params(elem) - elif tag == "atominfo": - self.atomic_symbols, self.potcar_symbols = self._parse_atominfo(elem) - self.potcar_spec = [ - {"titel": p, "hash": None, "summary_stats": {}} for p in self.potcar_symbols - ] - parsed_header = True - elif tag == "i" and elem.attrib.get("name") == "efermi": - if in_kpoints_opt: - if self.kpoints_opt_props is None: - self.kpoints_opt_props = KpointOptProps() - self.kpoints_opt_props.efermi = float(elem.text) - in_kpoints_opt = False - else: - self.efermi = float(elem.text) - elif tag == "eigenvalues" and not in_kpoints_opt: - self.eigenvalues = self._parse_eigen(elem) - elif parse_projected_eigen and tag == "projected" and not in_kpoints_opt: - self.projected_eigenvalues, self.projected_magnetisation = self._parse_projected_eigen(elem) - elif tag == "eigenvalues_kpoints_opt" or tag == "projected_kpoints_opt": + elif not parsed_header: + if tag == "generator": + self.generator = self._parse_params(elem) + elif tag == "incar": + self.incar = self._parse_params(elem) + elif tag == "kpoints": + self.kpoints, self.actual_kpoints, self.actual_kpoints_weights = self._parse_kpoints(elem) + elif tag == "parameters": + self.parameters = self._parse_params(elem) + elif tag == "atominfo": + self.atomic_symbols, self.potcar_symbols = self._parse_atominfo(elem) + self.potcar_spec = [ + {"titel": p, "hash": None, "summary_stats": {}} for p in self.potcar_symbols + ] + parsed_header = True + elif tag == "i" and elem.attrib.get("name") == "efermi": + if in_kpoints_opt: if self.kpoints_opt_props is None: self.kpoints_opt_props = KpointOptProps() + self.kpoints_opt_props.efermi = float(elem.text) in_kpoints_opt = False - # projected_kpoints_opt includes occupation information whereas - # eigenvalues_kpoints_opt doesn't. - self.kpoints_opt_props.eigenvalues = self._parse_eigen(elem.find("eigenvalues")) - if tag == "eigenvalues_kpoints_opt": - ( - self.kpoints_opt_props.kpoints, - self.kpoints_opt_props.actual_kpoints, - self.kpoints_opt_props.actual_kpoints_weights, - ) = self._parse_kpoints(elem.find("kpoints")) - elif parse_projected_eigen: # and tag == "projected_kpoints_opt": (implied) - ( - self.kpoints_opt_props.projected_eigenvalues, - self.kpoints_opt_props.projected_magnetisation, - ) = self._parse_projected_eigen(elem) - elif tag == "structure" and elem.attrib.get("name") == "finalpos": - self.final_structure = self._parse_structure(elem) + else: + self.efermi = float(elem.text) + elif tag == "eigenvalues" and not in_kpoints_opt: + self.eigenvalues = self._parse_eigen(elem) + elif parse_projected_eigen and tag == "projected" and not in_kpoints_opt: + self.projected_eigenvalues, self.projected_magnetisation = self._parse_projected_eigen(elem) + elif tag in ("eigenvalues_kpoints_opt", "projected_kpoints_opt"): + if self.kpoints_opt_props is None: + self.kpoints_opt_props = KpointOptProps() + in_kpoints_opt = False + # projected_kpoints_opt includes occupation information whereas + # eigenvalues_kpoints_opt doesn't. + self.kpoints_opt_props.eigenvalues = self._parse_eigen(elem.find("eigenvalues")) + if tag == "eigenvalues_kpoints_opt": + ( + self.kpoints_opt_props.kpoints, + self.kpoints_opt_props.actual_kpoints, + self.kpoints_opt_props.actual_kpoints_weights, + ) = self._parse_kpoints(elem.find("kpoints")) + elif parse_projected_eigen: # and tag == "projected_kpoints_opt": (implied) + ( + self.kpoints_opt_props.projected_eigenvalues, + self.kpoints_opt_props.projected_magnetisation, + ) = self._parse_projected_eigen(elem) + elif tag == "structure" and elem.attrib.get("name") == "finalpos": + self.final_structure = self._parse_structure(elem) self.vasp_version = self.generator["version"] if parse_potcar_file: self.update_potcar_spec(parse_potcar_file) @@ -1727,7 +1735,7 @@ def as_dict(self): for idx, val in enumerate(values): eigen[idx][str(spin)] = val vout["eigenvalues"] = eigen - (gap, cbm, vbm, is_direct) = self.eigenvalue_band_properties + gap, cbm, vbm, is_direct = self.eigenvalue_band_properties vout.update({"bandgap": gap, "cbm": cbm, "vbm": vbm, "is_gap_direct": is_direct}) if self.projected_eigenvalues: @@ -1742,7 +1750,7 @@ def as_dict(self): eigen = {str(spin): v.tolist() for spin, v in kpt_opt_props.eigenvalues.items()} vout["eigenvalues_kpoints_opt"] = eigen # TODO implement kpoints_opt eigenvalue_band_proprties. - # (gap, cbm, vbm, is_direct) = self.eigenvalue_band_properties + # gap, cbm, vbm, is_direct = self.eigenvalue_band_properties # vout.update({"bandgap": gap, "cbm": cbm, "vbm": vbm, "is_gap_direct": is_direct}) if kpt_opt_props.projected_eigenvalues: @@ -1862,36 +1870,29 @@ def __init__(self, filename): except ValueError: run_stats[tok[0].strip()] = None continue - m = efermi_patt.search(clean) - if m: + + if match := efermi_patt.search(clean): try: # try-catch because VASP sometimes prints # 'E-fermi: ******** XC(G=0): -6.1327 # alpha+bet : -1.8238' - efermi = float(m.group(1)) + efermi = float(match.group(1)) continue except ValueError: efermi = None continue - m = nelect_patt.search(clean) - if m: - nelect = float(m.group(1)) - m = mag_patt.search(clean) - if m: - total_mag = float(m.group(1)) - - if e_fr_energy is None: - m = e_fr_energy_pattern.search(clean) - if m: - e_fr_energy = float(m.group(1)) - if e_wo_entrp is None: - m = e_wo_entrp_pattern.search(clean) - if m: - e_wo_entrp = float(m.group(1)) - if e0 is None: - m = e0_pattern.search(clean) - if m: - e0 = float(m.group(1)) + if match := nelect_patt.search(clean): + nelect = float(match.group(1)) + + if match := mag_patt.search(clean): + total_mag = float(match.group(1)) + + if e_fr_energy is None and (match := e_fr_energy_pattern.search(clean)): + e_fr_energy = float(match.group(1)) + if e_wo_entrp is None and (match := e_wo_entrp_pattern.search(clean)): + e_wo_entrp = float(match.group(1)) + if e0 is None and (match := e0_pattern.search(clean)): + e0 = float(match.group(1)) if all([nelect, total_mag is not None, efermi is not None, run_stats]): break @@ -1907,24 +1908,22 @@ def __init__(self, filename): if clean.startswith("# of ion"): header = re.split(r"\s{2,}", clean.strip()) header.pop(0) - else: - m = re.match(r"\s*(\d+)\s+(([\d\.\-]+)\s+)+", clean) - if m: - tokens = [float(i) for i in re.findall(r"[\d\.\-]+", clean)] - tokens.pop(0) - if read_charge: - charge.append(dict(zip(header, tokens))) - elif read_mag_x: - mag_x.append(dict(zip(header, tokens))) - elif read_mag_y: - mag_y.append(dict(zip(header, tokens))) - elif read_mag_z: - mag_z.append(dict(zip(header, tokens))) - elif clean.startswith("tot"): - read_charge = False - read_mag_x = False - read_mag_y = False - read_mag_z = False + elif match := re.match(r"\s*(\d+)\s+(([\d\.\-]+)\s+)+", clean): + tokens = [float(i) for i in re.findall(r"[\d\.\-]+", clean)] + tokens.pop(0) + if read_charge: + charge.append(dict(zip(header, tokens))) + elif read_mag_x: + mag_x.append(dict(zip(header, tokens))) + elif read_mag_y: + mag_y.append(dict(zip(header, tokens))) + elif read_mag_z: + mag_z.append(dict(zip(header, tokens))) + elif clean.startswith("tot"): + read_charge = False + read_mag_x = False + read_mag_y = False + read_mag_z = False if clean == "total charge": charge = [] read_charge = True @@ -2262,9 +2261,8 @@ def _parse_sci_notation(line): Returns: list[float]: numbers if found, empty ist if not """ - m = re.findall(r"[\.\-\d]+E[\+\-]\d{2}", line) - if m: - return [float(t) for t in m] + if match := re.findall(r"[\.\-\d]+E[\+\-]\d{2}", line): + return [float(t) for t in match] return [] def read_freq_dielectric(self): @@ -2750,8 +2748,7 @@ def p_ion(results, match): self.er_bp_tot = self.er_bp[Spin.up] + self.er_bp[Spin.down] except Exception: - self.er_ev_tot = self.er_bp_tot = None - raise Exception("IGPAR OUTCAR could not be parsed.") + raise RuntimeError("IGPAR OUTCAR could not be parsed.") def read_internal_strain_tensor(self): """ @@ -2780,7 +2777,7 @@ def internal_strain_data(results, match): elif match.group(1).lower() == "z": index = 2 else: - raise Exception(f"Couldn't parse row index from symbol for internal strain tensor: {match.group(1)}") + raise IndexError(f"Couldn't parse row index from symbol for internal strain tensor: {match.group(1)}") results.internal_strain_tensor[results.internal_strain_ion][index] = np.array( [float(match.group(i)) for i in range(2, 8)] ) @@ -2952,7 +2949,7 @@ def born_section_stop(results, _match): self.piezo_tensor = self.piezo_tensor.tolist() except Exception: - raise Exception("LEPSILON OUTCAR could not be parsed.") + raise RuntimeError("LEPSILON OUTCAR could not be parsed.") def read_lepsilon_ionic(self): """ @@ -3069,7 +3066,7 @@ def piezo_section_stop(results, _match): self.piezo_ionic_tensor = self.piezo_ionic_tensor.tolist() except Exception: - raise Exception("ionic part of LEPSILON OUTCAR could not be parsed.") + raise RuntimeError("ionic part of LEPSILON OUTCAR could not be parsed.") def read_lcalcpol(self): """ @@ -3161,9 +3158,9 @@ def p_ion(results, match): # fix polarization units in new versions of vasp regex = r"^.*Ionic dipole moment: .*" search = [[regex, None, lambda x, y: x.append(y.group(0))]] - r = micro_pyawk(self.filename, search, []) + results = micro_pyawk(self.filename, search, []) - if "|e|" in r[0]: + if "|e|" in results[0]: self.p_elec *= -1 self.p_ion *= -1 if self.spin and not self.noncollinear: @@ -3172,7 +3169,7 @@ def p_ion(results, match): except Exception as exc: print(exc.args) - raise Exception("LCALCPOL OUTCAR could not be parsed.") from exc + raise RuntimeError("LCALCPOL OUTCAR could not be parsed.") from exc def read_pseudo_zval(self): """Create pseudopotential ZVAL dictionary.""" @@ -3196,14 +3193,14 @@ def zvals(results, match): zval_dict = {} for x, y in zip(self.atom_symbols, self.zvals): - zval_dict.update({x: y}) + zval_dict[x] = y self.zval_dict = zval_dict # Clean-up del self.atom_symbols del self.zvals except Exception: - raise Exception("ZVAL dict could not be parsed.") + raise RuntimeError("ZVAL dict could not be parsed.") def read_core_state_eigen(self): """ @@ -3220,11 +3217,15 @@ def read_core_state_eigen(self): """ with zopen(self.filename, mode="rt") as foutcar: line = foutcar.readline() + cl = [] + while line != "": line = foutcar.readline() + if "NIONS =" in line: natom = int(line.split("NIONS =")[1]) - cl = [defaultdict(list) for i in range(natom)] + cl = [defaultdict(list) for _ in range(natom)] + if "the core state eigen" in line: iat = -1 while line != "": @@ -3557,10 +3558,10 @@ def _print_fortran_float(flt): Returns: str: String representation of float in Fortran format. """ - s = f"{flt:.10E}" + flt_str = f"{flt:.10E}" if flt >= 0: - return f"0.{s[0]}{s[2:12]}E{int(s[13:]) + 1:+03}" - return f"-.{s[1]}{s[3:13]}E{int(s[14:]) + 1:+03}" + return f"0.{flt_str[0]}{flt_str[2:12]}E{int(flt_str[13:]) + 1:+03}" + return f"-.{flt_str[1]}{flt_str[3:13]}E{int(flt_str[14:]) + 1:+03}" with zopen(file_name, mode="wt") as file: poscar = Poscar(self.structure) @@ -3603,12 +3604,13 @@ def write_spin(data_type): file.write("".join(data)) write_spin("total") - if self.is_spin_polarized and self.is_soc: - write_spin("diff_x") - write_spin("diff_y") - write_spin("diff_z") - elif self.is_spin_polarized: - write_spin("diff") + if self.is_spin_polarized: + if self.is_soc: + write_spin("diff_x") + write_spin("diff_y") + write_spin("diff_z") + else: + write_spin("diff") class Locpot(VolumetricData): @@ -3624,7 +3626,7 @@ def __init__(self, poscar: Poscar, data: np.ndarray, **kwargs) -> None: self.name = poscar.comment @classmethod - def from_file(cls, filename, **kwargs): + def from_file(cls, filename: str, **kwargs) -> Self: """Read a LOCPOT file. Args: @@ -3656,12 +3658,14 @@ def __init__(self, poscar, data, data_aug=None): struct = poscar self.poscar = Poscar(poscar) self.name = None + else: + raise TypeError("Unsupported POSCAR type.") super().__init__(struct, data, data_aug=data_aug) self._distance_matrix = {} @classmethod - def from_file(cls, filename: str): + def from_file(cls, filename: str) -> Self: """Read a CHGCAR file. Args: @@ -3676,9 +3680,7 @@ def from_file(cls, filename: str): @property def net_magnetization(self): """Net magnetization from Chgcar""" - if self.is_spin_polarized: - return np.sum(self.data["diff"]) - return None + return np.sum(self.data["diff"]) if self.is_spin_polarized else None class Elfcar(VolumetricData): @@ -3704,6 +3706,8 @@ def __init__(self, poscar, data): elif isinstance(poscar, Structure): tmp_struct = poscar self.poscar = Poscar(poscar) + else: + raise TypeError("Unsupported POSCAR type.") super().__init__(tmp_struct, data) # TODO: modify VolumetricData so that the correct keys can be used. @@ -3713,26 +3717,27 @@ def __init__(self, poscar, data): self.data = data @classmethod - def from_file(cls, filename): + def from_file(cls, filename: str) -> Self: """ Reads a ELFCAR file. - :param filename: Filename + Args: + filename: Filename Returns: Elfcar """ - (poscar, data, _data_aug) = VolumetricData.parse_file(filename) + poscar, data, _data_aug = VolumetricData.parse_file(filename) return cls(poscar, data) def get_alpha(self): """Get the parameter alpha where ELF = 1/(1+alpha^2).""" alpha_data = {} - for k, v in self.data.items(): - alpha = 1 / v + for key, val in self.data.items(): + alpha = 1 / val alpha = alpha - 1 alpha = np.sqrt(alpha) - alpha_data[k] = alpha + alpha_data[key] = alpha return VolumetricData(self.structure, alpha_data) @@ -3760,31 +3765,39 @@ def __init__(self, filename): headers = None with zopen(filename, mode="rt") as file_handle: - preambleexpr = re.compile(r"# of k-points:\s*(\d+)\s+# of bands:\s*(\d+)\s+# of ions:\s*(\d+)") - kpointexpr = re.compile(r"^k-point\s+(\d+).*weight = ([0-9\.]+)") - bandexpr = re.compile(r"^band\s+(\d+)") - ionexpr = re.compile(r"^ion.*") + preamble_expr = re.compile(r"# of k-points:\s*(\d+)\s+# of bands:\s*(\d+)\s+# of ions:\s*(\d+)") + kpoint_expr = re.compile(r"^k-point\s+(\d+).*weight = ([0-9\.]+)") + band_expr = re.compile(r"^band\s+(\d+)") + ion_expr = re.compile(r"^ion.*") expr = re.compile(r"^([0-9]+)\s+") current_kpoint = 0 current_band = 0 done = False spin = Spin.down + weights = None + n_kpoints = None + n_bands = None + n_ions = None + weights = [] + headers = None + data = None + phase_factors = None for line in file_handle: line = line.strip() - if bandexpr.match(line): - m = bandexpr.match(line) - current_band = int(m.group(1)) - 1 + if band_expr.match(line): + match = band_expr.match(line) + current_band = int(match.group(1)) - 1 done = False - elif kpointexpr.match(line): - m = kpointexpr.match(line) - current_kpoint = int(m.group(1)) - 1 - weights[current_kpoint] = float(m.group(2)) + elif kpoint_expr.match(line): + match = kpoint_expr.match(line) + current_kpoint = int(match.group(1)) - 1 + weights[current_kpoint] = float(match.group(2)) if current_kpoint == 0: spin = Spin.up if spin == Spin.down else Spin.down done = False - elif headers is None and ionexpr.match(line): + elif headers is None and ion_expr.match(line): headers = line.split() headers.pop(0) headers.pop(-1) @@ -3792,11 +3805,7 @@ def __init__(self, filename): data = defaultdict(lambda: np.zeros((n_kpoints, n_bands, n_ions, len(headers)))) phase_factors = defaultdict( - lambda: np.full( - (n_kpoints, n_bands, n_ions, len(headers)), - np.nan, - dtype=np.complex128, - ) + lambda: np.full((n_kpoints, n_bands, n_ions, len(headers)), np.nan, dtype=np.complex128) ) elif expr.match(line): tokens = line.split() @@ -3818,11 +3827,11 @@ def __init__(self, filename): phase_factors[spin][current_kpoint, current_band, index, :] += 1j * num_data elif line.startswith("tot"): done = True - elif preambleexpr.match(line): - m = preambleexpr.match(line) - n_kpoints = int(m.group(1)) - n_bands = int(m.group(2)) - n_ions = int(m.group(3)) + elif preamble_expr.match(line): + match = preamble_expr.match(line) + n_kpoints = int(match.group(1)) + n_bands = int(match.group(2)) + n_ions = int(match.group(3)) weights = np.zeros(n_kpoints) self.nkpoints = n_kpoints @@ -3845,7 +3854,7 @@ def get_projection_on_elements(self, structure: Structure): """ dico: dict[Spin, list] = {} for spin in self.data: - dico[spin] = [[defaultdict(float) for i in range(self.nkpoints)] for j in range(self.nbands)] + dico[spin] = [[defaultdict(float) for _ in range(self.nkpoints)] for _ in range(self.nbands)] for iat in range(self.nions): name = structure.species[iat].symbol @@ -4045,12 +4054,13 @@ def __init__(self, filename, ionicstep_start=1, ionicstep_end=None, comment=None structures = [] preamble_done = False if ionicstep_start < 1: - raise Exception("Start ionic step cannot be less than 1") - if ionicstep_end is not None and ionicstep_start < 1: - raise Exception("End ionic step cannot be less than 1") + raise ValueError("Start ionic step cannot be less than 1") + if ionicstep_end is not None and ionicstep_end < 1: + raise ValueError("End ionic step cannot be less than 1") ionicstep_cnt = 1 with zopen(filename, mode="rt") as file: + title = None for line in file: line = line.strip() if preamble is None: @@ -4139,9 +4149,9 @@ def concatenate(self, filename, ionicstep_start=1, ionicstep_end=None): structures = self.structures preamble_done = False if ionicstep_start < 1: - raise Exception("Start ionic step cannot be less than 1") - if ionicstep_end is not None and ionicstep_start < 1: - raise Exception("End ionic step cannot be less than 1") + raise ValueError("Start ionic step cannot be less than 1") + if ionicstep_end is not None and ionicstep_end < 1: + raise ValueError("End ionic step cannot be less than 1") ionicstep_cnt = 1 with zopen(filename, mode="rt") as file: @@ -4190,19 +4200,19 @@ def get_str(self, ionicstep_start: int = 1, ionicstep_end: int | None = None, si significant_figures (int): Number of significant figures. """ if ionicstep_start < 1: - raise Exception("Start ionic step cannot be less than 1") + raise ValueError("Start ionic step cannot be less than 1") if ionicstep_end is not None and ionicstep_end < 1: - raise Exception("End ionic step cannot be less than 1") - latt = self.structures[0].lattice - if np.linalg.det(latt.matrix) < 0: - latt = Lattice(-latt.matrix) - lines = [self.comment, "1.0", str(latt)] + raise ValueError("End ionic step cannot be less than 1") + lattice = self.structures[0].lattice + if np.linalg.det(lattice.matrix) < 0: + lattice = Lattice(-lattice.matrix) + lines = [self.comment, "1.0", str(lattice)] lines.extend((" ".join(self.site_symbols), " ".join(str(x) for x in self.natoms))) format_str = f"{{:.{significant_figures}f}}" ionicstep_cnt = 1 output_cnt = 1 - for cnt, structure in enumerate(self.structures): - ionicstep_cnt = cnt + 1 + for cnt, structure in enumerate(self.structures, start=1): + ionicstep_cnt = cnt if ionicstep_end is None: if ionicstep_cnt >= ionicstep_start: lines.append(f"Direct configuration={' ' * (7 - len(str(output_cnt)))}{output_cnt}") @@ -4416,9 +4426,9 @@ def __init__(self, filename="WAVECAR", verbose=False, precision="normal", vasp_t values are ['std', 'gam', 'ncl'] (only first letter is required) """ self.filename = filename - valid_types = ["std", "gam", "ncl"] + valid_types = {"std", "gam", "ncl"} initials = {x[0] for x in valid_types} - if not (vasp_type is None or vasp_type.lower()[0] in initials): + if vasp_type is not None and vasp_type.lower()[0] not in initials: raise ValueError( f"invalid {vasp_type=}, must be one of {valid_types} (we only check the first letter {initials})" ) @@ -4491,7 +4501,7 @@ def __init__(self, filename="WAVECAR", verbose=False, precision="normal", vasp_t self.Gpoints = [None for _ in range(self.nk)] self.kpoints = [] if spin == 2: - self.coeffs = [[[None for i in range(self.nb)] for j in range(self.nk)] for _ in range(spin)] + self.coeffs = [[[None for _ in range(self.nb)] for _ in range(self.nk)] for _ in range(spin)] self.band_energy = [[] for _ in range(spin)] else: self.coeffs = [[None for i in range(self.nb)] for j in range(self.nk)] @@ -4562,6 +4572,8 @@ def __init__(self, filename="WAVECAR", verbose=False, precision="normal", vasp_t # but I don't have a WAVECAR to test it with data = np.fromfile(file, dtype=np.complex128, count=nplane) np.fromfile(file, dtype=np.float64, count=recl8 - 2 * nplane) + else: + raise RuntimeError("Invalid rtag value.") extra_coeffs = [] if len(extra_coeff_inds) > 0: @@ -4738,9 +4750,7 @@ def fft_mesh(self, kpoint: int, band: int, spin: int = 0, spinor: int = 0, shift t = tuple(gp.astype(int) + (self.ng / 2).astype(int)) mesh[t] = coeff - if shift: - return np.fft.ifftshift(mesh) - return mesh + return np.fft.ifftshift(mesh) if shift else mesh def get_parchg( self, @@ -4822,7 +4832,7 @@ def get_parchg( den = np.abs(np.conj(wfr) * wfr) den += np.abs(np.conj(wfr_t) * wfr_t) - if phase and not (self.vasp_type.lower()[0] == "n" and spinor is None): + if phase and (self.vasp_type.lower()[0] != "n" or spinor is not None): den = np.sign(np.real(wfr)) * den data["total"] = den @@ -4859,7 +4869,7 @@ def write_unks(self, directory: str) -> None: for ib in range(self.nb): data[ib, 0, :, :, :] = np.fft.ifftn(self.fft_mesh(ik, ib, spinor=0)) * N data[ib, 1, :, :, :] = np.fft.ifftn(self.fft_mesh(ik, ib, spinor=1)) * N - Unk(ik + 1, data).write_file(str(out_dir / (fname + "NC"))) + Unk(ik + 1, data).write_file(str(out_dir / f"{fname}NC")) else: data = np.empty((self.nb, *self.ng), dtype=np.complex128) for ispin in range(self.spin): @@ -4897,9 +4907,6 @@ def __init__(self, filename, occu_tol=1e-8, separate_spins=False): reported for each individual spin channel. Defaults to False, which computes the eigenvalue band properties independent of the spin orientation. If True, the calculation must be spin-polarized. - - Returns: - a pymatgen.io.vasp.outputs.Eigenval object """ self.filename = filename self.occu_tol = occu_tol @@ -5023,7 +5030,7 @@ class Waveder(MSONable): cder_imag: np.ndarray @classmethod - def from_formatted(cls, filename): + def from_formatted(cls, filename: str) -> Self: """Reads the WAVEDERF file and returns a Waveder object. Note: This file is only produced when LOPTICS is true AND vasp has been @@ -5053,7 +5060,7 @@ def from_formatted(cls, filename): return cls(cder_real, cder_imag) @classmethod - def from_binary(cls, filename, data_type="complex64"): + def from_binary(cls, filename: str, data_type: str = "complex64") -> Self: """Read the WAVEDER file and returns a Waveder object. Args: @@ -5173,7 +5180,7 @@ def data(self): return self.me_real + 1j * self.me_imag @classmethod - def from_file(cls, filename: str) -> WSWQ: + def from_file(cls, filename: str) -> Self: """Constructs a WSWQ object from a file. Args: diff --git a/pymatgen/io/vasp/sets.py b/pymatgen/io/vasp/sets.py index 0ce1db9732c..3f1c4dba468 100644 --- a/pymatgen/io/vasp/sets.py +++ b/pymatgen/io/vasp/sets.py @@ -34,12 +34,13 @@ import re import shutil import warnings +from collections.abc import Sequence from copy import deepcopy from dataclasses import dataclass, field from glob import glob from itertools import chain from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Union +from typing import TYPE_CHECKING, Any, Literal, Union, cast from zipfile import ZipFile import numpy as np @@ -58,7 +59,7 @@ from pymatgen.util.due import Doi, due if TYPE_CHECKING: - from collections.abc import Sequence + from typing_extensions import Self from pymatgen.core.trajectory import Vector3D @@ -175,6 +176,7 @@ def write_input( same name as the InputSet (e.g., MPStaticSet.zip) """ if potcar_spec: + vasp_input = None if make_dir_if_not_present: os.makedirs(output_dir, exist_ok=True) @@ -190,19 +192,20 @@ def write_input( vasp_input.write_input(output_dir, make_dir_if_not_present=make_dir_if_not_present) cif_name = "" - if include_cif: + if include_cif and vasp_input is not None: struct = vasp_input["POSCAR"].structure cif_name = f"{output_dir}/{struct.formula.replace(' ', '')}.cif" struct.to(filename=cif_name) if zip_output: - filename = type(self).__name__ + ".zip" + filename = f"{type(self).__name__}.zip" with ZipFile(os.path.join(output_dir, filename), mode="w") as zip_file: for file in ["INCAR", "POSCAR", "KPOINTS", "POTCAR", "POTCAR.spec", cif_name]: try: zip_file.write(os.path.join(output_dir, file), arcname=file) except FileNotFoundError: pass + try: os.remove(os.path.join(output_dir, file)) except (FileNotFoundError, PermissionError, IsADirectoryError): @@ -1013,10 +1016,10 @@ def override_from_prev_calc(self, prev_calc_dir="."): wavecar_files = sorted(glob(str(Path(prev_calc_dir) / (fname + "*")))) if wavecar_files: if fname == "WFULL": - for f in wavecar_files: - fname = Path(f).name + for wavecar_file in wavecar_files: + fname = Path(wavecar_file).name fname = fname.split(".")[0] - files_to_transfer[fname] = f + files_to_transfer[fname] = wavecar_file else: files_to_transfer[fname] = str(wavecar_files[-1]) @@ -1024,7 +1027,7 @@ def override_from_prev_calc(self, prev_calc_dir="."): return self @classmethod - def from_prev_calc(cls, prev_calc_dir, **kwargs): + def from_prev_calc(cls, prev_calc_dir: str, **kwargs) -> Self: """ Generate a set of VASP input files for static calculations from a directory of previous VASP run. @@ -1113,11 +1116,10 @@ def calculate_ng( if custom_encut is not None: encut = custom_encut + elif self.incar.get("ENCUT", 0) > 0: + encut = self.incar["ENCUT"] # get the ENCUT val else: - if self.incar.get("ENCUT", 0) > 0: - encut = self.incar["ENCUT"] # get the ENCUT val - else: - encut = max(i_species.enmax for i_species in self.get_vasp_input()["POTCAR"]) + encut = max(i_species.enmax for i_species in self.get_vasp_input()["POTCAR"]) # PREC=Normal is VASP default PREC = self.incar.get("PREC", "Normal") if custom_prec is None else custom_prec @@ -2076,7 +2078,7 @@ def incar_updates(self) -> dict: return updates @classmethod - def from_prev_calc(cls, prev_calc_dir, mode="DIAG", **kwargs): + def from_prev_calc(cls, prev_calc_dir: str, mode: str = "DIAG", **kwargs) -> Self: """ Generate a set of VASP input files for GW or BSE calculations from a directory of previous Exact Diag VASP run. @@ -2347,13 +2349,13 @@ def write_input( self.kpoints.write_file(str(output_dir / "KPOINTS")) self.potcar.write_file(str(output_dir / "POTCAR")) - for i, p in enumerate(self.poscars): - d = output_dir / str(i).zfill(2) + for idx, poscar in enumerate(self.poscars): + d = output_dir / str(idx).zfill(2) if not d.exists(): d.mkdir(parents=True) - p.write_file(str(d / "POSCAR")) + poscar.write_file(str(d / "POSCAR")) if write_cif: - p.structure.to(filename=str(d / f"{i}.cif")) + poscar.structure.to(filename=str(d / f"{idx}.cif")) if write_endpoint_inputs: end_point_param = MITRelaxSet(self.structures[0], user_incar_settings=self.user_incar_settings) @@ -2362,12 +2364,12 @@ def write_input( end_point_param.kpoints.write_file(str(output_dir / image / "KPOINTS")) end_point_param.potcar.write_file(str(output_dir / image / "POTCAR")) if write_path_cif: - sites = set() - lat = self.structures[0].lattice - for site in chain(*(struct for struct in self.structures)): - sites.add(PeriodicSite(site.species, site.frac_coords, lat)) - nebpath = Structure.from_sites(sorted(sites)) - nebpath.to(filename=str(output_dir / "path.cif")) + sites = { + PeriodicSite(site.species, site.frac_coords, self.structures[0].lattice) + for site in chain(*(struct for struct in self.structures)) + } + neb_path = Structure.from_sites(sorted(sites)) + neb_path.to(filename=f"{output_dir}/path.cif") @dataclass @@ -2685,7 +2687,7 @@ def __post_init__(self): def kpoints_updates(self) -> dict | Kpoints: """Get updates to the kpoints configuration for this calculation type.""" # test, if this is okay - return {"reciprocal_density": self.reciprocal_density if self.reciprocal_density else 310} + return {"reciprocal_density": self.reciprocal_density or 310} @property def incar_updates(self) -> dict: @@ -2709,6 +2711,8 @@ def incar_updates(self) -> dict: if atom_type not in self.user_supplied_basis: raise ValueError(f"There are no basis functions for the atom type {atom_type}") basis = [f"{key} {value}" for key, value in self.user_supplied_basis.items()] + else: + basis = None lobsterin = Lobsterin(settingsdict={"basisfunctions": basis}) nbands = lobsterin._get_nbands(structure=self.structure) # type: ignore @@ -2778,8 +2782,8 @@ def get_structure_from_prev_run(vasprun, outcar=None) -> Structure: site_properties["magmom"] = vasprun.parameters["MAGMOM"] # LDAU if vasprun.parameters.get("LDAU", False): - for k in ("LDAUU", "LDAUJ", "LDAUL"): - vals = vasprun.incar[k] + for key in ("LDAUU", "LDAUJ", "LDAUL"): + vals = vasprun.incar[key] m = {} l_val = [] s = 0 @@ -2789,7 +2793,7 @@ def get_structure_from_prev_run(vasprun, outcar=None) -> Structure: s += 1 l_val.append(m[site.specie.symbol]) if len(l_val) == len(structure): - site_properties.update({k.lower(): l_val}) + site_properties.update({key.lower(): l_val}) else: raise ValueError(f"length of list {l_val} not the same as structure") @@ -3070,9 +3074,7 @@ def _get_ispin(vasprun: Vasprun | None, outcar: Outcar | None) -> int: def _combine_kpoints(*kpoints_objects: Kpoints) -> Kpoints: """Combine k-points files together.""" - labels = [] - kpoints = [] - weights = [] + labels, kpoints, weights = [], [], [] for kpoints_object in filter(None, kpoints_objects): if kpoints_object.style != Kpoints.supported_modes.Reciprocal: @@ -3092,7 +3094,7 @@ def _combine_kpoints(*kpoints_objects: Kpoints) -> Kpoints: comment="Combined k-points", style=Kpoints.supported_modes.Reciprocal, num_kpts=len(kpoints), - kpts=kpoints, + kpts=cast(Sequence[Sequence[float]], kpoints), labels=labels, kpts_weights=weights, ) diff --git a/pymatgen/io/wannier90.py b/pymatgen/io/wannier90.py index a63a88e7cf3..295b8149036 100644 --- a/pymatgen/io/wannier90.py +++ b/pymatgen/io/wannier90.py @@ -10,6 +10,8 @@ if TYPE_CHECKING: from collections.abc import Sequence + from typing_extensions import Self + __author__ = "Mark Turiansky" __copyright__ = "Copyright 2011, The Materials Project" __version__ = "0.1" @@ -89,7 +91,7 @@ def data(self, value: np.ndarray) -> None: self.ng = self.data.shape[-3:] @classmethod - def from_file(cls, filename: str) -> object: + def from_file(cls, filename: str) -> Self: """ Reads the UNK data from file. diff --git a/pymatgen/io/xcrysden.py b/pymatgen/io/xcrysden.py index 8e03228cbb8..fc6866f3a44 100644 --- a/pymatgen/io/xcrysden.py +++ b/pymatgen/io/xcrysden.py @@ -2,8 +2,13 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from pymatgen.core import Element, Structure +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "Matteo Giantomassi" __copyright__ = "Copyright 2013, The Materials Project" __version__ = "0.1" @@ -20,7 +25,7 @@ def __init__(self, structure: Structure): """ self.structure = structure - def to_str(self, atom_symbol=True): + def to_str(self, atom_symbol: bool = True) -> str: """ Returns a string with the structure in XSF format See http://www.xcrysden.org/doc/XSF.html. @@ -28,30 +33,32 @@ def to_str(self, atom_symbol=True): Args: atom_symbol (bool): Uses atom symbol instead of atomic number. Defaults to True. """ - lines = [] - app = lines.append + lines: list[str] = [] - app("CRYSTAL") - app("# Primitive lattice vectors in Angstrom") - app("PRIMVEC") + lines.append("CRYSTAL") + lines.append("# Primitive lattice vectors in Angstrom") + lines.append("PRIMVEC") cell = self.structure.lattice.matrix for i in range(3): - app(f" {cell[i][0]:.14f} {cell[i][1]:.14f} {cell[i][2]:.14f}") + lines.append(f" {cell[i][0]:.14f} {cell[i][1]:.14f} {cell[i][2]:.14f}") cart_coords = self.structure.cart_coords - app("# Cartesian coordinates in Angstrom.") - app("PRIMCOORD") - app(f" {len(cart_coords)} 1") + lines.append("# Cartesian coordinates in Angstrom.") + lines.append("PRIMCOORD") + lines.append(f" {len(cart_coords)} 1") for site, coord in zip(self.structure, cart_coords): sp = site.specie.symbol if atom_symbol else f"{site.specie.Z}" x, y, z = coord - app(f"{sp} {x:20.14f} {y:20.14f} {z:20.14f}") + lines.append(f"{sp} {x:20.14f} {y:20.14f} {z:20.14f}") + if "vect" in site.properties: + vx, vy, vz = site.properties["vect"] + lines[-1] += f" {vx:20.14f} {vy:20.14f} {vz:20.14f}" return "\n".join(lines) @classmethod - def from_str(cls, input_string, cls_=None): + def from_str(cls, input_string: str, cls_=None) -> Self: """ Initialize a `Structure` object from a string with data in XSF format. @@ -59,38 +66,39 @@ def from_str(cls, input_string, cls_=None): input_string: String with the structure in XSF format. See http://www.xcrysden.org/doc/XSF.html cls_: Structure class to be created. default: pymatgen structure - """ - # CRYSTAL see (1) - # these are primitive lattice vectors (in Angstroms) - # PRIMVEC - # 0.0000000 2.7100000 2.7100000 see (2) - # 2.7100000 0.0000000 2.7100000 - # 2.7100000 2.7100000 0.0000000 - - # these are conventional lattice vectors (in Angstroms) - # CONVVEC - # 5.4200000 0.0000000 0.0000000 see (3) - # 0.0000000 5.4200000 0.0000000 - # 0.0000000 0.0000000 5.4200000 - - # these are atomic coordinates in a primitive unit cell (in Angstroms) - # PRIMCOORD - # 2 1 see (4) - # 16 0.0000000 0.0000000 0.0000000 see (5) - # 30 1.3550000 -1.3550000 -1.3550000 + Example file: + CRYSTAL see (1) + these are primitive lattice vectors (in Angstroms) + PRIMVEC + 0.0000000 2.7100000 2.7100000 see (2) + 2.7100000 0.0000000 2.7100000 + 2.7100000 2.7100000 0.0000000 + + these are conventional lattice vectors (in Angstroms) + CONVVEC + 5.4200000 0.0000000 0.0000000 see (3) + 0.0000000 5.4200000 0.0000000 + 0.0000000 0.0000000 5.4200000 + + these are atomic coordinates in a primitive unit cell (in Angstroms) + PRIMCOORD + 2 1 see (4) + 16 0.0000000 0.0000000 0.0000000 see (5) + 30 1.3550000 -1.3550000 -1.3550000 + """ lattice, coords, species = [], [], [] lines = input_string.splitlines() - for idx, line in enumerate(lines): + for idx, line in enumerate(lines, start=1): if "PRIMVEC" in line: - for j in range(idx + 1, idx + 4): + for j in range(idx, idx + 3): lattice.append([float(c) for c in lines[j].split()]) if "PRIMCOORD" in line: - num_sites = int(lines[idx + 1].split()[0]) + num_sites = int(lines[idx].split()[0]) - for j in range(idx + 2, idx + 2 + num_sites): + for j in range(idx + 1, idx + 1 + num_sites): tokens = lines[j].split() Z = Element(tokens[0]).Z if tokens[0].isalpha() else int(tokens[0]) species.append(Z) @@ -102,5 +110,4 @@ def from_str(cls, input_string, cls_=None): if cls_ is None: cls_ = Structure - s = cls_(lattice, species, coords, coords_are_cartesian=True) - return XSF(s) + return cls(cls_(lattice, species, coords, coords_are_cartesian=True)) diff --git a/pymatgen/io/xr.py b/pymatgen/io/xr.py index e576f1f47e4..789651e66e4 100644 --- a/pymatgen/io/xr.py +++ b/pymatgen/io/xr.py @@ -11,6 +11,7 @@ import re from math import fabs +from typing import TYPE_CHECKING import numpy as np from monty.io import zopen @@ -18,6 +19,11 @@ from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Structure +if TYPE_CHECKING: + from pathlib import Path + + from typing_extensions import Self + __author__ = "Nils Edvin Richard Zimmermann" __copyright__ = "Copyright 2016, The Materials Project" __version__ = "0.1" @@ -50,15 +56,15 @@ def __str__(self): ] # There are actually 10 more fields per site # in a typical xr file from GULP, for example. - for idx, site in enumerate(self.structure): - output.append(f"{idx + 1} {site.specie} {site.x:.4f} {site.y:.4f} {site.z:.4f}") + for idx, site in enumerate(self.structure, start=1): + output.append(f"{idx } {site.specie} {site.x:.4f} {site.y:.4f} {site.z:.4f}") mat = self.structure.lattice.matrix for _ in range(2): for j in range(3): output.append(f"{mat[j][0]:.4f} {mat[j][1]:.4f} {mat[j][2]:.4f}") return "\n".join(output) - def write_file(self, filename): + def write_file(self, filename: str | Path) -> None: """ Write out an xr file. @@ -69,7 +75,7 @@ def write_file(self, filename): file.write(str(self) + "\n") @classmethod - def from_str(cls, string, use_cores=True, thresh=1.0e-4): + def from_str(cls, string: str, use_cores: bool = True, thresh: float = 1.0e-4) -> Self: """ Creates an Xr object from a string representation. @@ -118,12 +124,11 @@ def from_str(cls, string, use_cores=True, thresh=1.0e-4): sp = [] coords = [] for j in range(n_sites): - m = re.match( + if match := re.match( r"\d+\s+(\w+)\s+([0-9\-\.]+)\s+([0-9\-\.]+)\s+([0-9\-\.]+)", lines[4 + j].strip(), - ) - if m: - tmp_sp = m.group(1) + ): + tmp_sp = match.group(1) if use_cores and tmp_sp[len(tmp_sp) - 2 :] == "_s": continue if not use_cores and tmp_sp[len(tmp_sp) - 2 :] == "_c": @@ -132,11 +137,11 @@ def from_str(cls, string, use_cores=True, thresh=1.0e-4): sp.append(tmp_sp[0 : len(tmp_sp) - 2]) else: sp.append(tmp_sp) - coords.append([float(m.group(i)) for i in range(2, 5)]) + coords.append([float(match.group(i)) for i in range(2, 5)]) return cls(Structure(lattice, sp, coords, coords_are_cartesian=True)) @classmethod - def from_file(cls, filename, use_cores=True, thresh=1.0e-4): + def from_file(cls, filename: str | Path, use_cores: bool = True, thresh: float = 1.0e-4) -> Self: """ Reads an xr-formatted file to create an Xr object. diff --git a/pymatgen/io/xtb/inputs.py b/pymatgen/io/xtb/inputs.py index 6557d7a746c..d753885e08f 100644 --- a/pymatgen/io/xtb/inputs.py +++ b/pymatgen/io/xtb/inputs.py @@ -80,11 +80,11 @@ def constrains_template(molecule, reference_fnm, constraints) -> str: atoms_for_mtd = [idx for idx in range(1, len(mol) + 1) if idx not in atoms_to_constrain] # Write as 1-3,5 instead of 1,2,3,5 interval_list = [atoms_for_mtd[0]] - for i, v in enumerate(atoms_for_mtd): - if v + 1 not in atoms_for_mtd: - interval_list.append(v) - if i != len(atoms_for_mtd) - 1: - interval_list.append(atoms_for_mtd[i + 1]) + for idx, val in enumerate(atoms_for_mtd, start=1): + if val + 1 not in atoms_for_mtd: + interval_list.append(val) + if idx != len(atoms_for_mtd): + interval_list.append(atoms_for_mtd[idx]) allowed_mtd_string = ",".join( [f"{interval_list[i]}-{interval_list[i + 1]}" for i in range(len(interval_list)) if i % 2 == 0] ) diff --git a/pymatgen/io/xtb/outputs.py b/pymatgen/io/xtb/outputs.py index c4a5d98cb9c..1a6132f496b 100644 --- a/pymatgen/io/xtb/outputs.py +++ b/pymatgen/io/xtb/outputs.py @@ -56,7 +56,7 @@ def _parse_crest_output(self): # Get CREST command crest_cmd = None - with open(output_filepath) as xtbout_file: + with open(output_filepath, encoding="utf-8") as xtbout_file: for line in xtbout_file: if "> crest" in line: crest_cmd = line.strip()[8:] @@ -72,12 +72,12 @@ def _parse_crest_output(self): print(f"Input file {split_cmd[0]} not found") # Get CREST input flags - for i, entry in enumerate(split_cmd): + for i, entry in enumerate(split_cmd, start=1): value = None if entry and "-" in entry: option = entry[1:] - if i + 1 < len(split_cmd) and "-" not in split_cmd[i + 1]: - value = split_cmd[i + 1] + if i < len(split_cmd) and "-" not in split_cmd[i]: + value = split_cmd[i] self.cmd_options[option] = value # Get input charge for decorating parsed molecules chg = 0 @@ -112,7 +112,7 @@ def _parse_crest_output(self): ) conformer_degeneracies = [] energies = [] - with open(output_filepath) as xtbout_file: + with open(output_filepath, encoding="utf-8") as xtbout_file: for line in xtbout_file: conformer_match = conformer_pattern.match(line) rotamer_match = rotamer_pattern.match(line) @@ -127,12 +127,12 @@ def _parse_crest_output(self): final_rotamer_filename = "crest_rotamers.xyz" else: n_rot_files = [] - for f in os.listdir(self.path): - if "crest_rotamers" in f: - n_rot_file = int(os.path.splitext(f)[0].split("_")[2]) + for filename in os.listdir(self.path): + if "crest_rotamers" in filename: + n_rot_file = int(os.path.splitext(filename)[0].split("_")[2]) n_rot_files.append(n_rot_file) - if len(n_rot_files) > 0: - final_rotamer_filename = f"crest_rotamers_{max(n_rot_files)}.xyz" + final_rotamer_filename = f"crest_rotamers_{max(n_rot_files)}.xyz" if len(n_rot_files) > 0 else "" + try: rotamers_path = os.path.join(self.path, final_rotamer_filename) rotamer_structures = XYZ.from_file(rotamers_path).all_molecules diff --git a/pymatgen/io/xyz.py b/pymatgen/io/xyz.py index 5b4d66f82b5..74255f7d747 100644 --- a/pymatgen/io/xyz.py +++ b/pymatgen/io/xyz.py @@ -4,7 +4,7 @@ import re from io import StringIO -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import pandas as pd from monty.io import zopen @@ -14,6 +14,9 @@ if TYPE_CHECKING: from collections.abc import Sequence + from pathlib import Path + + from typing_extensions import Self class XYZ: @@ -34,7 +37,7 @@ def __init__(self, mol: Molecule | Structure | Sequence[Molecule | Structure], c mol (Molecule | Structure): Input molecule or structure or list thereof. coord_precision: Precision to be used for coordinates. """ - self._mols = [mol] if isinstance(mol, SiteCollection) else mol + self._mols = cast(list[SiteCollection], [mol] if isinstance(mol, SiteCollection) else mol) self.precision = coord_precision @property @@ -71,7 +74,7 @@ def _from_frame_str(contents) -> Molecule: return Molecule(sp, coords) @classmethod - def from_str(cls, contents) -> XYZ: + def from_str(cls, contents: str) -> Self: """ Creates XYZ object from a string. @@ -96,7 +99,7 @@ def from_str(cls, contents) -> XYZ: return cls(mols) @classmethod - def from_file(cls, filename) -> XYZ: + def from_file(cls, filename: str | Path) -> Self: """ Creates XYZ object from a file. diff --git a/pymatgen/io/zeopp.py b/pymatgen/io/zeopp.py index 0341cf3379f..364eb8c71f7 100644 --- a/pymatgen/io/zeopp.py +++ b/pymatgen/io/zeopp.py @@ -26,6 +26,7 @@ import os import re +from typing import TYPE_CHECKING from monty.dev import requires from monty.io import zopen @@ -43,6 +44,12 @@ zeo_found = True except ImportError: zeo_found = False + AtomNetwork = prune_voronoi_network_close_node = None + +if TYPE_CHECKING: + from pathlib import Path + + from typing_extensions import Self __author__ = "Bharat Medasani" __copyright__ = "Copyright 2013, The Materials Project" @@ -91,7 +98,7 @@ def __str__(self): return "\n".join(output) @classmethod - def from_str(cls, string): + def from_str(cls, string: str) -> Self: """ Reads a string representation to a ZeoCssr object. @@ -112,24 +119,25 @@ def from_str(cls, string): alpha = angles.pop(-1) angles.insert(0, alpha) lattice = Lattice.from_parameters(*lengths, *angles) + sp = [] coords = [] charge = [] for line in lines[4:]: - m = re.match( + match = re.match( r"\d+\s+(\w+)\s+([0-9\-\.]+)\s+([0-9\-\.]+)\s+([0-9\-\.]+)\s+(?:0\s+){8}([0-9\-\.]+)", line.strip(), ) - if m: - sp.append(m.group(1)) + if match: + sp.append(match.group(1)) # coords.append([float(m.group(i)) for i in xrange(2, 5)]) # Zeo++ takes x-axis along a and pymatgen takes z-axis along c - coords.append([float(m.group(i)) for i in [3, 4, 2]]) - charge.append(m.group(5)) + coords.append([float(match.group(i)) for i in [3, 4, 2]]) + charge.append(match.group(5)) return cls(Structure(lattice, sp, coords, site_properties={"charge": charge})) @classmethod - def from_file(cls, filename): + def from_file(cls, filename: str | Path) -> Self: """ Reads a CSSR file to a ZeoCssr object. @@ -158,7 +166,7 @@ def __init__(self, mol): super().__init__(mol) @classmethod - def from_str(cls, contents): + def from_str(cls, contents: str) -> Self: """ Creates Zeo++ Voronoi XYZ object from a string. from_string method of XYZ class is being redefined. @@ -185,7 +193,7 @@ def from_str(cls, contents): return cls(Molecule(sp, coords, site_properties={"voronoi_radius": prop})) @classmethod - def from_file(cls, filename): + def from_file(cls, filename: str | Path) -> Self: """ Creates XYZ object from a file. @@ -234,22 +242,22 @@ def get_voronoi_nodes(structure, rad_dict=None, probe_rad=0.1): """ with ScratchDir("."): name = "temp_zeo1" - zeo_inp_filename = name + ".cssr" + zeo_inp_filename = f"{name}.cssr" ZeoCssr(structure).write_file(zeo_inp_filename) rad_file = None rad_flag = False if rad_dict: - rad_file = name + ".rad" + rad_file = f"{name}.rad" rad_flag = True - with open(rad_file, "w+") as file: + with open(rad_file, "w+", encoding="utf-8") as file: for el in rad_dict: file.write(f"{el} {rad_dict[el].real}\n") atom_net = AtomNetwork.read_from_CSSR(zeo_inp_filename, rad_flag=rad_flag, rad_file=rad_file) vor_net, vor_edge_centers, vor_face_centers = atom_net.perform_voronoi_decomposition() vor_net.analyze_writeto_XYZ(name, probe_rad, atom_net) - voro_out_filename = name + "_voro.xyz" + voro_out_filename = f"{name}_voro.xyz" voro_node_mol = ZeoVoronoiXYZ.from_file(voro_out_filename).molecule species = ["X"] * len(voro_node_mol) @@ -324,8 +332,8 @@ def get_high_accuracy_voronoi_nodes(structure, rad_dict, probe_rad=0.1): zeo_inp_filename = f"{name}.cssr" ZeoCssr(structure).write_file(zeo_inp_filename) rad_flag = True - rad_file = name + ".rad" - with open(rad_file, "w+") as file: + rad_file = f"{name}.rad" + with open(rad_file, "w+", encoding="utf-8") as file: for el in rad_dict: print(f"{el} {rad_dict[el].real}", file=file) @@ -383,15 +391,15 @@ def get_free_sphere_params(structure, rad_dict=None, probe_rad=0.1): """ with ScratchDir("."): name = "temp_zeo1" - zeo_inp_filename = name + ".cssr" + zeo_inp_filename = f"{name}.cssr" ZeoCssr(structure).write_file(zeo_inp_filename) rad_file = None rad_flag = False if rad_dict: - rad_file = name + ".rad" + rad_file = f"{name}.rad" rad_flag = True - with open(rad_file, "w+") as file: + with open(rad_file, "w+", encoding="utf-8") as file: for el in rad_dict: file.write(f"{el} {rad_dict[el].real}\n") @@ -399,16 +407,16 @@ def get_free_sphere_params(structure, rad_dict=None, probe_rad=0.1): out_file = "temp.res" atom_net.calculate_free_sphere_parameters(out_file) if os.path.isfile(out_file) and os.path.getsize(out_file) > 0: - with open(out_file) as file: + with open(out_file, encoding="utf-8") as file: output = file.readline() else: output = "" fields = [val.strip() for val in output.split()][1:4] if len(fields) == 3: fields = [float(field) for field in fields] - free_sphere_params = { + return { "inc_sph_max_dia": fields[0], "free_sph_max_dia": fields[1], "inc_sph_along_free_sph_path_max_dia": fields[2], } - return free_sphere_params + return None diff --git a/pymatgen/optimization/neighbors.pyx b/pymatgen/optimization/neighbors.pyx index 6a2d3bfc26e..f320b4a362e 100644 --- a/pymatgen/optimization/neighbors.pyx +++ b/pymatgen/optimization/neighbors.pyx @@ -110,12 +110,12 @@ def find_points_in_spheres( double[:, ::1] reciprocal_lattice = reciprocal_lattice_arr int count = 0 - int natoms = n_total - double *offsets_p_temp = safe_malloc(natoms * 3 * sizeof(double)) + int n_atoms = n_total + double *offsets_p_temp = safe_malloc(n_atoms * 3 * sizeof(double)) double *expanded_coords_p_temp = safe_malloc( - natoms * 3 * sizeof(double) + n_atoms * 3 * sizeof(double) ) - long *indices_p_temp = safe_malloc(natoms * sizeof(long)) + long *indices_p_temp = safe_malloc(n_atoms * sizeof(long)) double coord_temp[3] long ncube[3] @@ -194,16 +194,16 @@ def find_points_in_spheres( expanded_coords_p_temp[3*count+1] = coord_temp[1] expanded_coords_p_temp[3*count+2] = coord_temp[2] count += 1 - if count >= natoms: # exceeding current memory - natoms += natoms + if count >= n_atoms: # exceeding current memory + n_atoms += n_atoms offsets_p_temp = realloc( - offsets_p_temp, natoms * 3 * sizeof(double) + offsets_p_temp, n_atoms * 3 * sizeof(double) ) expanded_coords_p_temp = realloc( - expanded_coords_p_temp, natoms * 3 * sizeof(double) + expanded_coords_p_temp, n_atoms * 3 * sizeof(double) ) indices_p_temp = realloc( - indices_p_temp, natoms * sizeof(long) + indices_p_temp, n_atoms * sizeof(long) ) if ( offset_final is NULL or @@ -248,7 +248,7 @@ def find_points_in_spheres( return (np.array([], dtype=int), np.array([], dtype=int), np.array([[], [], []], dtype=float).T, np.array([], dtype=float)) - natoms = count + n_atoms = count cdef: # Delete those beyond (min_center_coords - r, max_center_coords + r) double *offsets_p = safe_realloc( @@ -266,11 +266,11 @@ def find_points_in_spheres( long[::1] indices = indices_p # Construct linked cell list - long[:, ::1] all_indices3 = safe_malloc( - natoms * 3 * sizeof(long) + long[:, ::1] all_indices3 = safe_malloc( + n_atoms * 3 * sizeof(long) ) - long[::1] all_indices1 = safe_malloc( - natoms * sizeof(long) + long[::1] all_indices1 = safe_malloc( + n_atoms * sizeof(long) ) for i in range(3): @@ -282,16 +282,16 @@ def find_points_in_spheres( cdef: long nb_cubes = ncube[0] * ncube[1] * ncube[2] long *head = safe_malloc(nb_cubes*sizeof(long)) - long *atom_indices = safe_malloc(natoms*sizeof(long)) + long *atom_indices = safe_malloc(n_atoms*sizeof(long)) long[:, ::1] neighbor_map = safe_malloc( nb_cubes * 27 * sizeof(long) ) memset(head, -1, nb_cubes*sizeof(long)) - memset(atom_indices, -1, natoms*sizeof(long)) + memset(atom_indices, -1, n_atoms*sizeof(long)) get_cube_neighbors(ncube, neighbor_map) - for i in range(natoms): + for i in range(n_atoms): atom_indices[i] = head[all_indices1[i]] head[all_indices1[i]] = i diff --git a/pymatgen/phonon/bandstructure.py b/pymatgen/phonon/bandstructure.py index 90dac271be3..2b51d497ced 100644 --- a/pymatgen/phonon/bandstructure.py +++ b/pymatgen/phonon/bandstructure.py @@ -17,6 +17,7 @@ from os import PathLike from numpy.typing import ArrayLike + from typing_extensions import Self def get_reasonable_repetitions(n_atoms: int) -> tuple[int, int, int]: @@ -44,6 +45,7 @@ def estimate_band_connection(prev_eigvecs, eigvecs, prev_band_order) -> list[int connection_order = [] for overlaps in metric: max_val = 0 + max_idx = 0 for idx in reversed(range(len(metric))): val = overlaps[idx] if idx in connection_order: @@ -304,7 +306,7 @@ def as_dict(self) -> dict[str, Any]: return dct @classmethod - def from_dict(cls, dct: dict[str, Any]) -> PhononBandStructure: + def from_dict(cls, dct: dict[str, Any]) -> Self: """ Args: dct (dict): Dict representation of PhononBandStructure. @@ -630,10 +632,10 @@ def as_dict(self) -> dict: return dct @classmethod - def from_dict(cls, dct: dict) -> PhononBandStructureSymmLine: + def from_dict(cls, dct: dict) -> Self: """ Args: - dct: Dict representation. + dct (dict): Dict representation. Returns: PhononBandStructureSymmLine diff --git a/pymatgen/phonon/dos.py b/pymatgen/phonon/dos.py index 698f71f47f2..29a824cfdeb 100644 --- a/pymatgen/phonon/dos.py +++ b/pymatgen/phonon/dos.py @@ -8,7 +8,7 @@ import scipy.constants as const from monty.functools import lazy_property from monty.json import MSONable -from scipy.ndimage.filters import gaussian_filter1d +from scipy.ndimage import gaussian_filter1d from pymatgen.core.structure import Structure from pymatgen.util.coord import get_linear_interpolated_value @@ -16,6 +16,8 @@ if TYPE_CHECKING: from collections.abc import Sequence + from typing_extensions import Self + BOLTZ_THZ_PER_K = const.value("Boltzmann constant in Hz/K") / const.tera # Boltzmann constant in THz/K THZ_TO_J = const.value("hertz-joule relationship") * const.tera @@ -133,7 +135,7 @@ def __str__(self) -> str: return "\n".join(str_arr) @classmethod - def from_dict(cls, dct: dict[str, Sequence]) -> PhononDos: + def from_dict(cls, dct: dict[str, Sequence]) -> Self: """Returns PhononDos object from dict representation of PhononDos.""" return cls(dct["frequencies"], dct["densities"]) @@ -458,7 +460,7 @@ def get_element_dos(self) -> dict: return {el: PhononDos(self.frequencies, densities) for el, densities in el_dos.items()} @classmethod - def from_dict(cls, dct: dict) -> CompletePhononDos: + def from_dict(cls, dct: dict) -> Self: """Returns CompleteDos object from dict representation.""" total_dos = PhononDos.from_dict(dct) struct = Structure.from_dict(dct["structure"]) diff --git a/pymatgen/phonon/gruneisen.py b/pymatgen/phonon/gruneisen.py index 69bc09fb269..baabee17cd6 100644 --- a/pymatgen/phonon/gruneisen.py +++ b/pymatgen/phonon/gruneisen.py @@ -19,15 +19,16 @@ try: import phonopy from phonopy.phonon.dos import TotalDos -except ImportError as exc: - print(exc) +except ImportError: phonopy = None + TotalDos = None if TYPE_CHECKING: from collections.abc import Sequence from typing import Literal from numpy.typing import ArrayLike + from typing_extensions import Self __author__ = "A. Bonkowski, J. George, G. Petretto" __copyright__ = "Copyright 2021, The Materials Project" @@ -312,7 +313,7 @@ def as_dict(self) -> dict: return dct @classmethod - def from_dict(cls, dct: dict) -> GruneisenPhononBandStructure: + def from_dict(cls, dct: dict) -> Self: """ Args: dct (dict): Dict representation. @@ -393,10 +394,10 @@ def __init__( ) @classmethod - def from_dict(cls, dct: dict) -> GruneisenPhononBandStructureSymmLine: + def from_dict(cls, dct: dict) -> Self: """ Args: - dct: Dict representation. + dct (dict): Dict representation. Returns: GruneisenPhononBandStructureSymmLine diff --git a/pymatgen/phonon/ir_spectra.py b/pymatgen/phonon/ir_spectra.py index 86706fd08ad..f9c1caacfa7 100644 --- a/pymatgen/phonon/ir_spectra.py +++ b/pymatgen/phonon/ir_spectra.py @@ -23,6 +23,7 @@ from matplotlib.axes import Axes from numpy.typing import ArrayLike + from typing_extensions import Self __author__ = "Henrique Miranda, Guido Petretto, Matteo Giantomassi" __copyright__ = "Copyright 2018, The Materials Project" @@ -59,7 +60,7 @@ def __init__( self.epsilon_infinity = np.array(epsilon_infinity) @classmethod - def from_dict(cls, dct: dict) -> IRDielectricTensor: + def from_dict(cls, dct: dict) -> Self: """Returns IRDielectricTensor from dict representation.""" structure = Structure.from_dict(dct["structure"]) oscillator_strength = dct["oscillator_strength"] diff --git a/pymatgen/phonon/plotter.py b/pymatgen/phonon/plotter.py index e756a91a671..5989cf4dd17 100644 --- a/pymatgen/phonon/plotter.py +++ b/pymatgen/phonon/plotter.py @@ -133,6 +133,7 @@ def get_plot( self, xlim: float | None = None, ylim: float | None = None, + invert_axes: bool = False, units: Literal["thz", "ev", "mev", "ha", "cm-1", "cm^-1"] = "thz", legend: dict | None = None, ax: Axes | None = None, @@ -142,6 +143,8 @@ def get_plot( Args: xlim: Specifies the x-axis limits. Set to None for automatic determination. ylim: Specifies the y-axis limits. + invert_axes (bool): Whether to invert the x and y axes. Enables chemist style DOS plotting. + Defaults to False. units (thz | ev | mev | ha | cm-1 | cm^-1): units for the frequencies. Defaults to "thz". legend: dict with legend options. For example, {"loc": "upper right"} will place the legend in the upper right corner. Defaults to {"fontsize": 30}. @@ -155,21 +158,21 @@ def get_plot( n_colors = max(3, len(self._doses)) n_colors = min(9, n_colors) - y = None + ys = None all_densities = [] all_frequencies = [] - ax = pretty_plot(12, 8, ax=ax) + ax = pretty_plot(*(8, 12) if invert_axes else (12, 8), ax=ax) # Note that this complicated processing of frequencies is to allow for # stacked plots in matplotlib. for dos in self._doses.values(): frequencies = dos["frequencies"] * unit.factor densities = dos["densities"] - if y is None: - y = np.zeros(frequencies.shape) + if ys is None: + ys = np.zeros(frequencies.shape) if self.stack: - y += densities - new_dens = y.copy() + ys += densities + new_dens = ys.copy() else: new_dens = densities all_frequencies.append(frequencies) @@ -182,25 +185,46 @@ def get_plot( colors = ("blue", "red", "green", "orange", "purple", "brown", "pink", "gray", "olive") for idx, (key, frequencies, densities) in enumerate(zip(keys, all_frequencies, all_densities)): color = self._doses[key].get("color", colors[idx % n_colors]) + linewidth = self._doses[key].get("linewidth", 3) + kwargs = { + key: val + for key, val in self._doses[key].items() + if key not in ["frequencies", "densities", "color", "linewidth"] + } all_pts.extend(list(zip(frequencies, densities))) + if invert_axes: + xs, ys = densities, frequencies + else: + xs, ys = frequencies, densities if self.stack: - ax.fill(frequencies, densities, color=color, label=str(key)) + ax.fill(xs, ys, color=color, label=str(key), **kwargs) else: - ax.plot(frequencies, densities, color=color, label=str(key), linewidth=3) + ax.plot(xs, ys, color=color, label=str(key), linewidth=linewidth, **kwargs) if xlim: ax.set_xlim(xlim) if ylim: ax.set_ylim(ylim) + elif invert_axes: + _ylim = ax.get_ylim() + relevant_x = [p[1] for p in all_pts if _ylim[0] < p[0] < _ylim[1]] or ax.get_xlim() + ax.set_xlim((min(relevant_x), max(relevant_x))) else: _xlim = ax.get_xlim() relevant_y = [p[1] for p in all_pts if _xlim[0] < p[0] < _xlim[1]] or ax.get_ylim() ax.set_ylim((min(relevant_y), max(relevant_y))) - ax.axvline(0, linewidth=2, color="black", linestyle="--") + if invert_axes: + ax.axhline(0, linewidth=2, color="black", linestyle="--") + + ax.set_xlabel(r"$\mathrm{Density\ of\ states}$", fontsize=legend.get("fontsize", 30)) + ax.set_ylabel(rf"$\mathrm{{Frequencies\ ({unit.label})}}$", fontsize=legend.get("fontsize", 30)) - ax.set_xlabel(rf"$\mathrm{{Frequencies\ ({unit.label})}}$", fontsize=legend.get("fontsize", 30)) - ax.set_ylabel(r"$\mathrm{Density\ of\ states}$", fontsize=legend.get("fontsize", 30)) + else: + ax.axvline(0, linewidth=2, color="black", linestyle="--") + + ax.set_xlabel(rf"$\mathrm{{Frequencies\ ({unit.label})}}$", fontsize=legend.get("fontsize", 30)) + ax.set_ylabel(r"$\mathrm{Density\ of\ states}$", fontsize=legend.get("fontsize", 30)) # only show legend if there are labels if sum(map(len, keys)) > 0: @@ -214,6 +238,7 @@ def save_plot( img_format: str = "eps", xlim: float | None = None, ylim: float | None = None, + invert_axes: bool = False, units: Literal["thz", "ev", "mev", "ha", "cm-1", "cm^-1"] = "thz", ) -> None: """Save matplotlib plot to a file. @@ -224,9 +249,11 @@ def save_plot( xlim: Specifies the x-axis limits. Set to None for automatic determination. ylim: Specifies the y-axis limits. + invert_axes: Whether to invert the x and y axes. Enables chemist style DOS plotting. + Defaults to False. units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1 """ - self.get_plot(xlim, ylim, units=units) + self.get_plot(xlim, ylim, invert_axes=invert_axes, units=units) plt.savefig(filename, format=img_format) plt.close() @@ -234,6 +261,7 @@ def show( self, xlim: float | None = None, ylim: None = None, + invert_axes: bool = False, units: Literal["thz", "ev", "mev", "ha", "cm-1", "cm^-1"] = "thz", ) -> None: """Show the plot using matplotlib. @@ -242,9 +270,11 @@ def show( xlim: Specifies the x-axis limits. Set to None for automatic determination. ylim: Specifies the y-axis limits. + invert_axes: Whether to invert the x and y axes. Enables chemist style DOS plotting. + Defaults to False. units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1. """ - self.get_plot(xlim, ylim, units=units) + self.get_plot(xlim, ylim, invert_axes=invert_axes, units=units) plt.show() @@ -256,7 +286,7 @@ def __init__(self, bs: PhononBandStructureSymmLine, label: str | None = None) -> Args: bs: A PhononBandStructureSymmLine object. label: A label for the plot. Defaults to None for no label. Esp. useful with - the plot_compare method to distinguish the two band structures. + the plot_compare method to distinguish the band structures. """ if not isinstance(bs, PhononBandStructureSymmLine): raise ValueError( @@ -395,10 +425,10 @@ def _make_color(colors: Sequence[int]) -> Sequence[int]: return colors # if there are four groups, use cyan, magenta, yellow and black if len(colors) == 4: - r = (1 - colors[0]) * (1 - colors[3]) - g = (1 - colors[1]) * (1 - colors[3]) - b = (1 - colors[2]) * (1 - colors[3]) - return [r, g, b] + red = (1 - colors[0]) * (1 - colors[3]) + green = (1 - colors[1]) * (1 - colors[3]) + blue = (1 - colors[2]) * (1 - colors[3]) + return [red, green, blue] raise ValueError(f"Expected 2, 3 or 4 colors, got {len(colors)}") def get_proj_plot( @@ -424,7 +454,7 @@ def get_proj_plot( the colors will be automatically generated. """ assert self._bs.structure is not None, "Structure is required for get_proj_plot" - elements = [e.symbol for e in self._bs.structure.elements] + elements = [elem.symbol for elem in self._bs.structure.elements] if site_comb == "element": assert 2 <= len(elements) <= 4, "the compound must have 2, 3 or 4 unique elements" indices: list[list[int]] = [[] for _ in range(len(elements))] @@ -484,7 +514,7 @@ def get_proj_plot( if rgb_labels is not None: labels = rgb_labels # type: ignore[assignment] elif site_comb == "element": - labels = [e.symbol for e in self._bs.structure.elements] + labels = [elem.symbol for elem in self._bs.structure.elements] else: labels = [f"{idx}" for idx in range(len(site_comb))] if len(indices) == 2: @@ -590,30 +620,30 @@ def get_ticks(self) -> dict[str, list]: def plot_compare( self, - other_plotter: PhononBSPlotter, + other_plotter: PhononBSPlotter | dict[str, PhononBSPlotter], units: Literal["thz", "ev", "mev", "ha", "cm-1", "cm^-1"] = "thz", - labels: tuple[str, str] | None = None, + self_label: str = "self", + colors: Sequence[str] | None = None, legend_kwargs: dict | None = None, on_incompatible: Literal["raise", "warn", "ignore"] = "raise", other_kwargs: dict | None = None, **kwargs, ) -> Axes: - """Plot two band structure for comparison. self in blue, other in red. - The two band structures need to be defined on the same symmetry lines! + """Plot two band structure for comparison. self in blue, others in red, green, ... + The band structures need to be defined on the same symmetry lines! The distance between symmetry lines is determined by the band structure used to initialize PhononBSPlotter (self). Args: - other_plotter (PhononBSPlotter): another PhononBSPlotter object defined along the - same symmetry lines + other_plotter (PhononBSPlotter | dict[str, PhononBSPlotter]): Other PhononBSPlotter object(s) defined along + the same symmetry lines units (str): units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1. Defaults to 'thz'. - labels (tuple[str, str] | None): labels for the two band structures. Defaults to None, - which will use the label of the two PhononBSPlotter objects if present. - Label order is (self_label, other_label), i.e. the label of the PhononBSPlotter - on which plot_compare() is called must come first. + self_label (str): label for the self band structure. Defaults to to the label passed to PhononBSPlotter.init + or, if None, 'self'. + colors (list[str]): list of colors for the other band structures. Defaults to None for automatic colors. legend_kwargs: dict[str, Any]: kwargs passed to ax.legend(). - on_incompatible ('raise' | 'warn' | 'ignore'): What to do if the two band structures + on_incompatible ('raise' | 'warn' | 'ignore'): What to do if the band structures are not compatible. Defaults to 'raise'. other_kwargs: dict[str, Any]: kwargs passed to other_plotter ax.plot(). **kwargs: passed to ax.plot(). @@ -625,35 +655,46 @@ def plot_compare( legend_kwargs = legend_kwargs or {} other_kwargs = other_kwargs or {} legend_kwargs.setdefault("fontsize", 20) + _colors = ("blue", "red", "green", "orange", "purple", "brown", "pink", "gray", "olive") + if isinstance(other_plotter, PhononBSPlotter): + other_plotter = {other_plotter._label or "other": other_plotter} + if colors: + assert len(colors) == len(other_plotter) + 1, "Wrong number of colors" self_data = self.bs_plot_data() - other_data = other_plotter.bs_plot_data() - - if len(self_data["distances"]) != len(other_data["distances"]): - if on_incompatible == "raise": - raise ValueError("The two band structures are not compatible.") - if on_incompatible == "warn": - logger.warning("The two band structures are not compatible.") - return None # ignore/warn line_width = kwargs.setdefault("linewidth", 1) + ax = self.get_plot(units=units, color=colors[0] if colors else _colors[0], **kwargs) - ax = self.get_plot(units=units, **kwargs) + colors_other = [] - kwargs.setdefault("color", "red") # don't move this line up! it would mess up self.get_plot color + for idx, plotter in enumerate(other_plotter.values()): + other_data = plotter.bs_plot_data() - for band_idx in range(other_plotter.n_bands): - for dist_idx, dists in enumerate(self_data["distances"]): - xs = dists - ys = [other_data["frequency"][dist_idx][band_idx][j] * unit.factor for j in range(len(dists))] - ax.plot(xs, ys, **(kwargs | other_kwargs)) + if np.asarray(self_data["distances"]).shape != np.asarray(other_data["distances"]).shape: + if on_incompatible == "raise": + raise ValueError("The two band structures are not compatible.") + if on_incompatible == "warn": + logger.warning("The two band structures are not compatible.") + return None # ignore/warn + + color = colors[idx + 1] if colors else _colors[1 + idx % len(_colors)] + _kwargs = kwargs.copy() # Don't set the color in kwargs, or every band will be red + colors_other.append( + _kwargs.setdefault("color", color) + ) # don't move this line up! it would mess up self.get_plot color + + for band_idx in range(plotter.n_bands): + for dist_idx, dists in enumerate(self_data["distances"]): + xs = dists + ys = [other_data["frequency"][dist_idx][band_idx][j] * unit.factor for j in range(len(dists))] + ax.plot(xs, ys, **(_kwargs | other_kwargs)) # add legend showing which color corresponds to which band structure - if labels or (self._label and other_plotter._label): - color_self, color_other = ax.lines[0].get_color(), ax.lines[-1].get_color() - label_self, label_other = labels or (self._label, other_plotter._label) - ax.plot([], [], label=label_self, linewidth=2 * line_width, color=color_self) - linestyle = other_kwargs.get("linestyle", "-") + color_self = ax.lines[0].get_color() + ax.plot([], [], label=self._label or self_label, linewidth=2 * line_width, color=color_self) + linestyle = other_kwargs.get("linestyle", "-") + for color_other, label_other in zip(colors_other, other_plotter): ax.plot([], [], label=label_other, linewidth=2 * line_width, color=color_other, linestyle=linestyle) ax.legend(**legend_kwargs) diff --git a/pymatgen/phonon/thermal_displacements.py b/pymatgen/phonon/thermal_displacements.py index 2bb07ceca24..566e1945db7 100644 --- a/pymatgen/phonon/thermal_displacements.py +++ b/pymatgen/phonon/thermal_displacements.py @@ -4,7 +4,7 @@ import re from functools import partial -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal import numpy as np from monty.json import MSONable @@ -15,16 +15,16 @@ from pymatgen.symmetry.groups import SYMM_DATA from pymatgen.util.due import Doi, due +try: + import phonopy +except ImportError: + phonopy = None + if TYPE_CHECKING: from os import PathLike from numpy.typing import ArrayLike - -try: - import phonopy -except ImportError as exc: - print(exc) - phonopy = None + from typing_extensions import Self __author__ = "J. George" __copyright__ = "Copyright 2022, The Materials Project" @@ -306,7 +306,7 @@ def visualize_directionality_quality_criterion( self, other: ThermalDisplacementMatrices, filename: str | PathLike = "visualization.vesta", - which_structure: int = 0, + which_structure: Literal[0, 1] = 0, ) -> None: """Will create a VESTA file for visualization of the directionality criterion. @@ -328,8 +328,10 @@ def visualize_directionality_quality_criterion( structure = self.structure elif which_structure == 1: structure = other.structure + else: + raise ValueError("Illegal which_structure value.") - with open(filename, mode="w") as file: + with open(filename, mode="w", encoding="utf-8") as file: # file.write("#VESTA_FORMAT_VERSION 3.5.4\n \n \n") file.write("CRYSTAL\n\n") @@ -345,9 +347,9 @@ def visualize_directionality_quality_criterion( file.write(" 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000\n") # error on parameters file.write("STRUC\n") - for isite, site in enumerate(structure): + for isite, site in enumerate(structure, start=1): file.write( - f"{isite + 1} {site.species_string} {site.species_string}{isite + 1} 1.0000 {site.frac_coords[0]} " + f"{isite} {site.species_string} {site.species_string}{isite} 1.0000 {site.frac_coords[0]} " f"{site.frac_coords[1]} {site.frac_coords[2]} 1a 1\n" ) file.write(" 0.000000 0.000000 0.000000 0.00\n") # error on positions - zero here @@ -410,10 +412,13 @@ def ratio_prolate(self) -> np.ndarray: return np.array(ratios) - @staticmethod + @classmethod def from_Ucif( - thermal_displacement_matrix_cif: ArrayLike[ArrayLike], structure: Structure, temperature: float | None = None - ) -> ThermalDisplacementMatrices: + cls, + thermal_displacement_matrix_cif: ArrayLike[ArrayLike], + structure: Structure, + temperature: float | None = None, + ) -> Self: """Starting from a numpy array, it will convert Ucif values into Ucart values and initialize the class. Args: @@ -446,7 +451,7 @@ def from_Ucif( # get ThermalDisplacementMatrices Object - return ThermalDisplacementMatrices( + return cls( thermal_displacement_matrix_cart=thermal_displacement_matrix_cart, thermal_displacement_matrix_cif=thermal_displacement_matrix_cif, structure=structure, @@ -488,9 +493,7 @@ def sort_order(site): return self.structure.copy(site_properties=site_properties) @classmethod - def from_structure_with_site_properties_Ucif( - cls, structure: Structure, temperature: float | None = None - ) -> ThermalDisplacementMatrices: + def from_structure_with_site_properties_Ucif(cls, structure: Structure, temperature: float | None = None) -> Self: """Will create this object with the help of a structure with site properties. Args: diff --git a/pymatgen/symmetry/analyzer.py b/pymatgen/symmetry/analyzer.py index a038eab4d9a..7b1a082d001 100644 --- a/pymatgen/symmetry/analyzer.py +++ b/pymatgen/symmetry/analyzer.py @@ -39,6 +39,7 @@ from pymatgen.symmetry.groups import CrystalSystem logger = logging.getLogger(__name__) + LatticeType = Literal["cubic", "hexagonal", "monoclinic", "orthorhombic", "rhombohedral", "tetragonal", "triclinic"] cite_conventional_cell_algo = due.dcite( @@ -61,7 +62,7 @@ class SpacegroupAnalyzer: Uses spglib to perform various symmetry finding operations. """ - def __init__(self, structure: Structure, symprec: float | None = 0.01, angle_tolerance: float = 5.0) -> None: + def __init__(self, structure: Structure, symprec: float | None = 0.01, angle_tolerance: float = 5) -> None: """ Args: structure (Structure/IStructure): Structure to find symmetry @@ -72,7 +73,7 @@ def __init__(self, structure: Structure, symprec: float | None = 0.01, angle_tol positions (e.g., structures relaxed with electronic structure codes), a looser tolerance of 0.1 (the value used in Materials Project) is often needed. - angle_tolerance (float): Angle tolerance for symmetry finding. + angle_tolerance (float): Angle tolerance for symmetry finding. Defaults to 5 degrees. """ self._symprec = symprec self._angle_tol = angle_tolerance @@ -258,7 +259,7 @@ def _get_symmetry(self): # fractions) translations = [] for t in dct["translations"]: - translations.append([float(Fraction.from_float(c).limit_denominator(1000)) for c in t]) + translations.append([float(Fraction(c).limit_denominator(1000)) for c in t]) translations = np.array(translations) # fractional translations of 1 are more simply 0 @@ -510,15 +511,15 @@ def get_primitive_standard_structure(self, international_monoclinic=True, keep_s ) new_sites = [] - latt = Lattice(np.dot(transf, conv.lattice.matrix)) - for s in conv: + lattice = Lattice(np.dot(transf, conv.lattice.matrix)) + for site in conv: new_s = PeriodicSite( - s.specie, - s.coords, - latt, + site.specie, + site.coords, + lattice, to_unit_cell=True, coords_are_cartesian=True, - properties=s.properties, + properties=site.properties, ) if not any(map(new_s.is_periodic_image, new_sites)): new_sites.append(new_s) @@ -539,14 +540,14 @@ def get_primitive_standard_structure(self, international_monoclinic=True, keep_s ], ] new_sites = [] - latt = Lattice(new_matrix) - for s in prim: + lattice = Lattice(new_matrix) + for site in prim: new_s = PeriodicSite( - s.specie, - s.frac_coords, - latt, + site.specie, + site.frac_coords, + lattice, to_unit_cell=True, - properties=s.properties, + properties=site.properties, ) if not any(map(new_s.is_periodic_image, new_sites)): new_sites.append(new_s) @@ -580,12 +581,13 @@ def get_conventional_standard_structure(self, international_monoclinic=True, kee The structure in a conventional standardized cell """ tol = 1e-5 + transf = None struct = self.get_refined_structure(keep_site_properties=keep_site_properties) - latt = struct.lattice + lattice = struct.lattice latt_type = self.get_lattice_type() - sorted_lengths = sorted(latt.abc) + sorted_lengths = sorted(lattice.abc) sorted_dic = sorted( - ({"vec": latt.matrix[i], "length": latt.abc[i], "orig_index": i} for i in range(3)), + ({"vec": lattice.matrix[i], "length": lattice.abc[i], "orig_index": i} for i in range(3)), key=lambda k: k["length"], ) @@ -595,31 +597,31 @@ def get_conventional_standard_structure(self, international_monoclinic=True, kee transf = np.zeros(shape=(3, 3)) if self.get_space_group_symbol().startswith("C"): transf[2] = [0, 0, 1] - a, b = sorted(latt.abc[:2]) + a, b = sorted(lattice.abc[:2]) sorted_dic = sorted( - ({"vec": latt.matrix[i], "length": latt.abc[i], "orig_index": i} for i in [0, 1]), + ({"vec": lattice.matrix[i], "length": lattice.abc[i], "orig_index": i} for i in [0, 1]), key=lambda k: k["length"], ) for idx in range(2): transf[idx][sorted_dic[idx]["orig_index"]] = 1 - c = latt.abc[2] + c = lattice.abc[2] elif self.get_space_group_symbol().startswith( "A" ): # change to C-centering to match Setyawan/Curtarolo convention transf[2] = [1, 0, 0] - a, b = sorted(latt.abc[1:]) + a, b = sorted(lattice.abc[1:]) sorted_dic = sorted( - ({"vec": latt.matrix[i], "length": latt.abc[i], "orig_index": i} for i in [1, 2]), + ({"vec": lattice.matrix[i], "length": lattice.abc[i], "orig_index": i} for i in [1, 2]), key=lambda k: k["length"], ) for idx in range(2): transf[idx][sorted_dic[idx]["orig_index"]] = 1 - c = latt.abc[0] + c = lattice.abc[0] else: for idx, dct in enumerate(sorted_dic): transf[idx][dct["orig_index"]] = 1 a, b, c = sorted_lengths - latt = Lattice.orthorhombic(a, b, c) + lattice = Lattice.orthorhombic(a, b, c) elif latt_type == "tetragonal": # find the "a" vectors @@ -632,7 +634,7 @@ def get_conventional_standard_structure(self, international_monoclinic=True, kee if abs(b - c) < tol < abs(a - c): a, c = c, a transf = np.dot([[0, 0, 1], [0, 1, 0], [1, 0, 0]], transf) - latt = Lattice.tetragonal(a, c) + lattice = Lattice.tetragonal(a, c) elif latt_type in ("hexagonal", "rhombohedral"): # for the conventional cell representation, # we always show the rhombohedral lattices as hexagonal @@ -640,7 +642,7 @@ def get_conventional_standard_structure(self, international_monoclinic=True, kee # check first if we have the refined structure shows a rhombohedral # cell # if so, make a supercell - a, b, c = latt.abc + a, b, c = lattice.abc if np.all(np.abs([a - b, c - b, a - c]) < 0.001): struct.make_supercell(((1, -1, 0), (0, 1, -1), (1, 1, 1))) a, b, c = sorted(struct.lattice.abc) @@ -652,7 +654,7 @@ def get_conventional_standard_structure(self, international_monoclinic=True, kee [a / 2, a * math.sqrt(3) / 2, 0], [0, 0, c], ] - latt = Lattice(new_matrix) + lattice = Lattice(new_matrix) transf = np.eye(3, 3) elif latt_type == "monoclinic": @@ -662,15 +664,15 @@ def get_conventional_standard_structure(self, international_monoclinic=True, kee transf = np.zeros(shape=(3, 3)) transf[2] = [0, 0, 1] sorted_dic = sorted( - ({"vec": latt.matrix[i], "length": latt.abc[i], "orig_index": i} for i in [0, 1]), + ({"vec": lattice.matrix[i], "length": lattice.abc[i], "orig_index": i} for i in [0, 1]), key=lambda k: k["length"], ) a = sorted_dic[0]["length"] b = sorted_dic[1]["length"] - c = latt.abc[2] + c = lattice.abc[2] new_matrix = None for t in itertools.permutations(list(range(2)), 2): - m = latt.matrix + m = lattice.matrix latt2 = Lattice([m[t[0]], m[t[1]], m[2]]) lengths = latt2.lengths angles = latt2.angles @@ -717,8 +719,9 @@ def get_conventional_standard_structure(self, international_monoclinic=True, kee # keep the ones with the non-90 angle=alpha # and b 90 and b < c: a, b, c, alpha, beta, gamma = Lattice([-m[t[0]], -m[t[1]], m[t[2]]]).parameters @@ -733,6 +736,7 @@ def get_conventional_standard_structure(self, international_monoclinic=True, kee [0, c * cos(alpha), c * sin(alpha)], ] continue + if alpha < 90 and b < c: transf = np.zeros(shape=(3, 3)) transf[0][t[0]] = 1 @@ -744,6 +748,7 @@ def get_conventional_standard_structure(self, international_monoclinic=True, kee [0, b, 0], [0, c * cos(alpha), c * sin(alpha)], ] + if new_matrix is None: # this if is to treat the case # where alpha==90 (but we still have a monoclinic sg @@ -769,15 +774,15 @@ def get_conventional_standard_structure(self, international_monoclinic=True, kee transf = np.dot(op, transf) new_matrix = np.dot(op, new_matrix) - latt = Lattice(new_matrix) + lattice = Lattice(new_matrix) elif latt_type == "triclinic": # we use a LLL Minkowski-like reduction for the triclinic cells struct = struct.get_reduced_structure("LLL") - latt = struct.lattice + lattice = struct.lattice - a, b, c = latt.lengths - alpha, beta, gamma = (math.pi * i / 180 for i in latt.angles) + a, b, c = lattice.lengths + alpha, beta, gamma = (math.pi * i / 180 for i in lattice.angles) new_matrix = None test_matrix = [ [a, 0, 0], @@ -854,11 +859,11 @@ def is_all_acute_or_obtuse(matrix) -> bool: transf = [[1, 0, 0], [0, -1, 0], [0, 0, -1]] new_matrix = test_matrix - latt = Lattice(new_matrix) + lattice = Lattice(new_matrix) new_coords = np.dot(transf, np.transpose(struct.frac_coords)).T new_struct = Structure( - latt, + lattice, struct.species_and_occu, new_coords, site_properties=struct.site_properties, @@ -881,8 +886,8 @@ def get_kpoint_weights(self, kpoints, atol=1e-5): kpts = np.array(kpoints) shift = [] mesh = [] - for i in range(3): - nonzero = [i for i in kpts[:, i] if abs(i) > 1e-5] + for idx in range(3): + nonzero = [i for i in kpts[:, idx] if abs(i) > 1e-5] if len(nonzero) != len(kpts): # gamma centered if not nonzero: @@ -902,11 +907,11 @@ def get_kpoint_weights(self, kpoints, atol=1e-5): grid = (np.array(grid) + np.array(shift) * (0.5, 0.5, 0.5)) / mesh weights = [] mapped = defaultdict(int) - for k in kpoints: - for i, g in enumerate(grid): - if np.allclose(pbc_diff(k, g), (0, 0, 0), atol=atol): + for kpt in kpoints: + for idx, g in enumerate(grid): + if np.allclose(pbc_diff(kpt, g), (0, 0, 0), atol=atol): mapped[tuple(g)] += 1 - weights.append(mapping.count(mapping[i])) + weights.append(mapping.count(mapping[idx])) break if (len(mapped) != len(set(mapping))) or (not all(v == 1 for v in mapped.values())): raise ValueError("Unable to find 1:1 corresponding between input kpoints and irreducible grid!") @@ -1505,9 +1510,9 @@ def cluster_sites(mol: Molecule, tol: float, give_only_index: bool = False) -> t origin site, instead of the site itself. Defaults to False. Returns: - (origin_site, clustered_sites): origin_site is a site at the center - of mass (None if there are no origin atoms). clustered_sites is a - dict of {(avg_dist, species_and_occu): [list of sites]} + tuple[Site | None, dict]: origin_site is a site at the center + of mass (None if there are no origin atoms). clustered_sites is a + dict of {(avg_dist, species_and_occu): [list of sites]} """ # Cluster works for dim > 2 data. We just add a dummy 0 for second # coordinate. diff --git a/pymatgen/symmetry/bandstructure.py b/pymatgen/symmetry/bandstructure.py index e0f7fd3fe4f..69cc639bb66 100644 --- a/pymatgen/symmetry/bandstructure.py +++ b/pymatgen/symmetry/bandstructure.py @@ -270,6 +270,8 @@ def _get_klabels(self, lm_bs, sc_bs, hin_bs, rpg): unlabeled[label_a] = coord_a for label_a, coord_a in unlabeled.items(): + key = None + for op in rpg: coord_a_t = np.dot(op, coord_a) key = [ diff --git a/pymatgen/symmetry/groups.py b/pymatgen/symmetry/groups.py index dd0b863995c..46df57c9439 100644 --- a/pymatgen/symmetry/groups.py +++ b/pymatgen/symmetry/groups.py @@ -22,6 +22,7 @@ if TYPE_CHECKING: from numpy.typing import ArrayLike + from typing_extensions import Self # don't import at runtime to avoid circular import from pymatgen.core.lattice import Lattice @@ -239,17 +240,17 @@ def __init__(self, int_symbol: str) -> None: # TODO: Support different origin choices. enc = list(data["enc"]) inversion = int(enc.pop(0)) - ngen = int(enc.pop(0)) + n_gen = int(enc.pop(0)) symm_ops = [np.eye(4)] if inversion: symm_ops.append(np.array([[-1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])) - for _ in range(ngen): - m = np.eye(4) - m[:3, :3] = SpaceGroup.gen_matrices[enc.pop(0)] - m[0, 3] = SpaceGroup.translations[enc.pop(0)] - m[1, 3] = SpaceGroup.translations[enc.pop(0)] - m[2, 3] = SpaceGroup.translations[enc.pop(0)] - symm_ops.append(m) + for _ in range(n_gen): + matrix = np.eye(4) + matrix[:3, :3] = SpaceGroup.gen_matrices[enc.pop(0)] + matrix[0, 3] = SpaceGroup.translations[enc.pop(0)] + matrix[1, 3] = SpaceGroup.translations[enc.pop(0)] + matrix[2, 3] = SpaceGroup.translations[enc.pop(0)] + symm_ops.append(matrix) self.generators = symm_ops self.full_symbol = data["full_symbol"] self.point_group = data["point_group"] @@ -292,12 +293,15 @@ def get_settings(cls, int_symbol: str) -> set[str]: set[str]: All possible settings for the given international symbol. """ symbols = [] + int_number = None if int_symbol in SpaceGroup.abbrev_sg_mapping: symbols.append(SpaceGroup.abbrev_sg_mapping[int_symbol]) int_number = SpaceGroup.sg_encoding[int_symbol]["int_number"] + elif int_symbol in SpaceGroup.full_sg_mapping: symbols.append(SpaceGroup.full_sg_mapping[int_symbol]) int_number = SpaceGroup.sg_encoding[int_symbol]["int_number"] + else: for spg in SpaceGroup.SYMM_OPS: if int_symbol in [ @@ -476,7 +480,7 @@ def is_supergroup(self, subgroup: SymmetryGroup) -> bool: return subgroup.is_subgroup(self) @classmethod - def from_int_number(cls, int_number: int, hexagonal: bool = True) -> SpaceGroup: + def from_int_number(cls, int_number: int, hexagonal: bool = True) -> Self: """Obtains a SpaceGroup from its international number. Args: @@ -495,7 +499,7 @@ def from_int_number(cls, int_number: int, hexagonal: bool = True) -> SpaceGroup: symbol = sg_symbol_from_int_number(int_number, hexagonal=hexagonal) if not hexagonal and int_number in (146, 148, 155, 160, 161, 166, 167): symbol += ":R" - return SpaceGroup(symbol) + return cls(symbol) def __repr__(self) -> str: symbol = self.symbol @@ -553,7 +557,7 @@ def in_array_list(array_list: list[np.ndarray] | np.ndarray, arr: np.ndarray, to tol (float): The tolerance. Defaults to 1e-5. If 0, an exact match is done. Returns: - (bool) + bool: True if arr is in array_list. """ if len(array_list) == 0: return False diff --git a/pymatgen/symmetry/kpath.py b/pymatgen/symmetry/kpath.py index 98c02e050a4..fc26429fe43 100644 --- a/pymatgen/symmetry/kpath.py +++ b/pymatgen/symmetry/kpath.py @@ -12,7 +12,6 @@ import numpy as np import spglib from monty.dev import requires -from scipy.linalg import sqrtm from pymatgen.core.lattice import Lattice from pymatgen.core.operations import MagSymmOp, SymmOp @@ -1507,18 +1506,18 @@ def _get_key_lines(key_points, bz_as_key_point_inds): # not the face center point (don't need to check it since it's not # shared with other facets) face_center_ind = facet_as_key_point_inds[-1] - for j, ind in enumerate(facet_as_key_point_inds_bndy): + for j, ind in enumerate(facet_as_key_point_inds_bndy, start=-1): if ( - min(ind, facet_as_key_point_inds_bndy[j - 1]), - max(ind, facet_as_key_point_inds_bndy[j - 1]), + min(ind, facet_as_key_point_inds_bndy[j]), + max(ind, facet_as_key_point_inds_bndy[j]), ) not in key_lines: key_lines.append( ( - min(ind, facet_as_key_point_inds_bndy[j - 1]), - max(ind, facet_as_key_point_inds_bndy[j - 1]), + min(ind, facet_as_key_point_inds_bndy[j]), + max(ind, facet_as_key_point_inds_bndy[j]), ) ) - k = j + 1 if j != len(facet_as_key_point_inds_bndy) - 1 else 0 + k = j + 2 if j != len(facet_as_key_point_inds_bndy) - 2 else 0 if ( min(ind, facet_as_key_point_inds_bndy[k]), max(ind, facet_as_key_point_inds_bndy[k]), @@ -1683,10 +1682,10 @@ def _get_magnetic_symmetry_operations(self, struct, grey_ops, atol): sites = [site for idx, site in enumerate(struct) if idx in nonzero_magmom_inds] init_site_coords = [site.frac_coords for site in sites] for op in grey_ops: - r = op.rotation_matrix + rot_mat = op.rotation_matrix t = op.translation_vector - xformed_magmoms = [self._apply_op_to_magmom(r, magmom) for magmom in init_magmoms] - xformed_site_coords = [np.dot(r, site.frac_coords) + t for site in sites] + xformed_magmoms = [self._apply_op_to_magmom(rot_mat, magmom) for magmom in init_magmoms] + xformed_site_coords = [np.dot(rot_mat, site.frac_coords) + t for site in sites] permutation = ["a" for i in range(len(sites))] not_found = list(range(len(sites))) for i in range(len(sites)): @@ -1849,160 +1848,54 @@ def _get_IRBZ(self, recip_point_group, W, key_points, face_center_inds, atol): used_axes = [] + def find_face_center(name: str, IRBZ_points): + for rotn in rpgdict["rotations"][name]: + ax = rotn["axis"] + op = rotn["op"] + rot_boundaries = None + + if not np.any([np.allclose(ax, usedax, atol) for usedax in used_axes]) and self._op_maps_IRBZ_to_self( + op, IRBZ_points, atol + ): + face_center_found = False + for point in IRBZ_points: + if point[0] in face_center_inds: + cross = D * np.dot(g_inv, np.cross(ax, point[1])) + if not np.allclose(cross, 0, atol=atol): + rot_boundaries = [cross, -1 * np.dot(op, cross)] + face_center_found = True + used_axes.append(ax) + break + + if not face_center_found: + print("face center not found") + for point in IRBZ_points: + cross = D * np.dot(g_inv, np.cross(ax, point[1])) + if not np.allclose(cross, 0, atol=atol): + rot_boundaries = [cross, -1 * np.dot(op, cross)] + used_axes.append(ax) + break + + if rot_boundaries is None: + raise RuntimeError("Failed to find rotation boundaries.") + + return self._reduce_IRBZ(IRBZ_points, rot_boundaries, g, atol) + return IRBZ_points + return IRBZ_points + # six-fold rotoinversion always comes with horizontal mirror so don't # need to check - for rotn in rpgdict["rotations"]["six-fold"]: - ax = rotn["axis"] - op = rotn["op"] - if not np.any([np.allclose(ax, usedax, atol) for usedax in used_axes]) and self._op_maps_IRBZ_to_self( - op, IRBZ_points, atol - ): - face_center_found = False - for point in IRBZ_points: - if point[0] in face_center_inds: - cross = D * np.dot(g_inv, np.cross(ax, point[1])) - if not np.allclose(cross, 0, atol=atol): - rot_boundaries = [cross, -1 * np.dot(op, cross)] - face_center_found = True - used_axes.append(ax) - break - if not face_center_found: - print("face center not found") - for point in IRBZ_points: - cross = D * np.dot(g_inv, np.cross(ax, point[1])) - if not np.allclose(cross, 0, atol=atol): - rot_boundaries = [cross, -1 * np.dot(op, cross)] - used_axes.append(ax) - break - IRBZ_points = self._reduce_IRBZ(IRBZ_points, rot_boundaries, g, atol) - - for rotn in rpgdict["rotations"]["rotoinv-four-fold"]: - ax = rotn["axis"] - op = rotn["op"] - if not np.any([np.allclose(ax, usedax, atol) for usedax in used_axes]) and self._op_maps_IRBZ_to_self( - op, IRBZ_points, atol - ): - face_center_found = False - for point in IRBZ_points: - if point[0] in face_center_inds: - cross = D * np.dot(g_inv, np.cross(ax, point[1])) - if not np.allclose(cross, 0, atol=atol): - rot_boundaries = [cross, np.dot(op, cross)] - face_center_found = True - used_axes.append(ax) - break - if not face_center_found: - print("face center not found") - for point in IRBZ_points: - cross = D * np.dot(g_inv, np.cross(ax, point[1])) - if not np.allclose(cross, 0, atol=atol): - rot_boundaries = [cross, -1 * np.dot(op, cross)] - used_axes.append(ax) - break - IRBZ_points = self._reduce_IRBZ(IRBZ_points, rot_boundaries, g, atol) - - for rotn in rpgdict["rotations"]["four-fold"]: - ax = rotn["axis"] - op = rotn["op"] - if not np.any([np.allclose(ax, usedax, atol) for usedax in used_axes]) and self._op_maps_IRBZ_to_self( - op, IRBZ_points, atol - ): - face_center_found = False - for point in IRBZ_points: - if point[0] in face_center_inds: - cross = D * np.dot(g_inv, np.cross(ax, point[1])) - if not np.allclose(cross, 0, atol=atol): - rot_boundaries = [cross, -1 * np.dot(op, cross)] - face_center_found = True - used_axes.append(ax) - break - if not face_center_found: - print("face center not found") - for point in IRBZ_points: - cross = D * np.dot(g_inv, np.cross(ax, point[1])) - if not np.allclose(cross, 0, atol=atol): - rot_boundaries = [cross, -1 * np.dot(op, cross)] - used_axes.append(ax) - break - IRBZ_points = self._reduce_IRBZ(IRBZ_points, rot_boundaries, g, atol) - - for rotn in rpgdict["rotations"]["rotoinv-three-fold"]: - ax = rotn["axis"] - op = rotn["op"] - if not np.any([np.allclose(ax, usedax, atol) for usedax in used_axes]) and self._op_maps_IRBZ_to_self( - op, IRBZ_points, atol - ): - face_center_found = False - for point in IRBZ_points: - if point[0] in face_center_inds: - cross = D * np.dot(g_inv, np.cross(ax, point[1])) - if not np.allclose(cross, 0, atol=atol): - rot_boundaries = [ - cross, - -1 * np.dot(sqrtm(-1 * op), cross), - ] - face_center_found = True - used_axes.append(ax) - break - if not face_center_found: - print("face center not found") - for point in IRBZ_points: - cross = D * np.dot(g_inv, np.cross(ax, point[1])) - if not np.allclose(cross, 0, atol=atol): - rot_boundaries = [cross, -1 * np.dot(op, cross)] - used_axes.append(ax) - break - IRBZ_points = self._reduce_IRBZ(IRBZ_points, rot_boundaries, g, atol) - - for rotn in rpgdict["rotations"]["three-fold"]: - ax = rotn["axis"] - op = rotn["op"] - if not np.any([np.allclose(ax, usedax, atol) for usedax in used_axes]) and self._op_maps_IRBZ_to_self( - op, IRBZ_points, atol - ): - face_center_found = False - for point in IRBZ_points: - if point[0] in face_center_inds: - cross = D * np.dot(g_inv, np.cross(ax, point[1])) - if not np.allclose(cross, 0, atol=atol): - rot_boundaries = [cross, -1 * np.dot(op, cross)] - face_center_found = True - used_axes.append(ax) - break - if not face_center_found: - print("face center not found") - for point in IRBZ_points: - cross = D * np.dot(g_inv, np.cross(ax, point[1])) - if not np.allclose(cross, 0, atol=atol): - rot_boundaries = [cross, -1 * np.dot(op, cross)] - used_axes.append(ax) - break - IRBZ_points = self._reduce_IRBZ(IRBZ_points, rot_boundaries, g, atol) - - for rotn in rpgdict["rotations"]["two-fold"]: - ax = rotn["axis"] - op = rotn["op"] - if not np.any([np.allclose(ax, usedax, atol) for usedax in used_axes]) and self._op_maps_IRBZ_to_self( - op, IRBZ_points, atol - ): - face_center_found = False - for point in IRBZ_points: - if point[0] in face_center_inds: - cross = D * np.dot(g_inv, np.cross(ax, point[1])) - if not np.allclose(cross, 0, atol=atol): - rot_boundaries = [cross, -1 * np.dot(op, cross)] - face_center_found = True - used_axes.append(ax) - break - if not face_center_found: - print("face center not found") - for point in IRBZ_points: - cross = D * np.dot(g_inv, np.cross(ax, point[1])) - if not np.allclose(cross, 0, atol=atol): - rot_boundaries = [cross, -1 * np.dot(op, cross)] - used_axes.append(ax) - break - IRBZ_points = self._reduce_IRBZ(IRBZ_points, rot_boundaries, g, atol) + IRBZ_points = find_face_center("six-fold", IRBZ_points) + + IRBZ_points = find_face_center("rotoinv-four-fold", IRBZ_points) + + IRBZ_points = find_face_center("four-fold", IRBZ_points) + + IRBZ_points = find_face_center("rotoinv-three-fold", IRBZ_points) + + IRBZ_points = find_face_center("three-fold", IRBZ_points) + + IRBZ_points = find_face_center("two-fold", IRBZ_points) return [point[0] for point in IRBZ_points] @@ -2035,22 +1928,29 @@ def _get_reciprocal_point_group_dict(recip_point_group, atol): if np.isclose(tr, 3, atol=atol): continue if np.isclose(tr, -1, atol=atol): # two-fold rotation + ax = None for j in range(3): if np.isclose(evals[j], 1, atol=atol): ax = evects[:, j] dct["rotations"]["two-fold"].append({"ind": idx, "axis": ax, "op": op}) + elif np.isclose(tr, 0, atol=atol): # three-fold rotation + ax = None for j in range(3): if np.isreal(evals[j]) and np.isclose(np.absolute(evals[j]), 1, atol=atol): ax = evects[:, j] dct["rotations"]["three-fold"].append({"ind": idx, "axis": ax, "op": op}) + # four-fold rotation elif np.isclose(tr, 1, atol=atol): + ax = None for j in range(3): if np.isreal(evals[j]) and np.isclose(np.absolute(evals[j]), 1, atol=atol): ax = evects[:, j] dct["rotations"]["four-fold"].append({"ind": idx, "axis": ax, "op": op}) + elif np.isclose(tr, 2, atol=atol): # six-fold rotation + ax = None for j in range(3): if np.isreal(evals[j]) and np.isclose(np.absolute(evals[j]), 1, atol=atol): ax = evects[:, j] @@ -2060,24 +1960,32 @@ def _get_reciprocal_point_group_dict(recip_point_group, atol): if np.isclose(det, -1, atol=atol): if np.isclose(tr, -3, atol=atol): dct["inversion"].append({"ind": idx, "op": PAR}) + elif np.isclose(tr, 1, atol=atol): # two-fold rotation + norm = None for j in range(3): if np.isclose(evals[j], -1, atol=atol): norm = evects[:, j] dct["reflections"].append({"ind": idx, "normal": norm, "op": op}) + elif np.isclose(tr, 0, atol=atol): # three-fold rotoinversion + ax = None for j in range(3): if np.isreal(evals[j]) and np.isclose(np.absolute(evals[j]), 1, atol=atol): ax = evects[:, j] dct["rotations"]["rotoinv-three-fold"].append({"ind": idx, "axis": ax, "op": op}) + # four-fold rotoinversion elif np.isclose(tr, -1, atol=atol): + ax = None for j in range(3): if np.isreal(evals[j]) and np.isclose(np.absolute(evals[j]), 1, atol=atol): ax = evects[:, j] dct["rotations"]["rotoinv-four-fold"].append({"ind": idx, "axis": ax, "op": op}) + # six-fold rotoinversion elif np.isclose(tr, -2, atol=atol): + ax = None for j in range(3): if np.isreal(evals[j]) and np.isclose(np.absolute(evals[j]), 1, atol=atol): ax = evects[:, j] @@ -2212,7 +2120,7 @@ def _get_max_cosine_labels(self, max_cosine_orbits_orig, key_points_inds_orbits, pop_orbits.append(grouped_inds[idx][worst_next_choice]) pop_labels.append(initial_max_cosine_label_inds[grouped_inds[idx][worst_next_choice]]) - if len(unassigned_orbits) != 0: + if unassigned_orbits: max_cosine_orbits_copy = self._reduce_cosines_array(max_cosine_orbits_copy, pop_orbits, pop_labels) unassigned_orbits_labels = self._get_orbit_labels(max_cosine_orbits_copy, key_points_inds_orbits, atol) for idx, unassigned_orbit in enumerate(unassigned_orbits): diff --git a/pymatgen/symmetry/maggroups.py b/pymatgen/symmetry/maggroups.py index 004171e2c1c..e0c2e0afadf 100644 --- a/pymatgen/symmetry/maggroups.py +++ b/pymatgen/symmetry/maggroups.py @@ -21,6 +21,8 @@ if TYPE_CHECKING: from collections.abc import Sequence + from typing_extensions import Self + from pymatgen.core.lattice import Lattice __author__ = "Matthew Horton, Shyue Ping Ong" @@ -191,10 +193,10 @@ def get_label(idx): n = 1 # nth Wyckoff site num_wyckoff = b[0] while len(wyckoff_sites) < num_wyckoff: - m = b[1 + o] # multiplicity - label = str(b[2 + o] * m) + get_label(num_wyckoff - n) + multiplicity = b[1 + o] + label = str(b[2 + o] * multiplicity) + get_label(num_wyckoff - n) sites = [] - for j in range(m): + for j in range(multiplicity): s = b[3 + o + (j * 22) : 3 + o + (j * 22) + 22] # data corresponding to specific Wyckoff position translation_vec = [s[0] / s[3], s[1] / s[3], s[2] / s[3]] matrix = [ @@ -225,7 +227,7 @@ def get_label(idx): # could do something else with these in future wyckoff_sites.append({"label": label, "str": " ".join(s["str"] for s in sites)}) n += 1 - o += m * 22 + 2 + o += multiplicity * 22 + 2 return wyckoff_sites @@ -284,7 +286,7 @@ def _parse_transformation(b): db.close() @classmethod - def from_og(cls, label: Sequence[int] | str) -> MagneticSpaceGroup: + def from_og(cls, label: Sequence[int] | str) -> Self: """Initialize from Opechowski and Guccione (OG) label or number. Args: diff --git a/pymatgen/symmetry/settings.py b/pymatgen/symmetry/settings.py index a329c4f4466..ba22014efce 100644 --- a/pymatgen/symmetry/settings.py +++ b/pymatgen/symmetry/settings.py @@ -4,6 +4,7 @@ import re from fractions import Fraction +from typing import TYPE_CHECKING import numpy as np from sympy import Matrix @@ -13,6 +14,9 @@ from pymatgen.core.operations import MagSymmOp, SymmOp from pymatgen.util.string import transformation_to_string +if TYPE_CHECKING: + from typing_extensions import Self + __author__ = "Matthew Horton" __copyright__ = "Copyright 2017, The Materials Project" __version__ = "0.1" @@ -56,7 +60,7 @@ def __init__(self, P, p): self._P, self._p = P, p @classmethod - def from_transformation_str(cls, transformation_string="a,b,c;0,0,0"): + def from_transformation_str(cls, transformation_string: str = "a,b,c;0,0,0") -> Self: """Construct SpaceGroupTransformation from its transformation string. Args: @@ -69,7 +73,7 @@ def from_transformation_str(cls, transformation_string="a,b,c;0,0,0"): return cls(P, p) @classmethod - def from_origin_shift(cls, origin_shift="0,0,0"): + def from_origin_shift(cls, origin_shift: str = "0,0,0") -> Self: """Construct SpaceGroupTransformation from its origin shift string. Args: @@ -138,8 +142,8 @@ def p(self) -> list[float]: @property def inverse(self) -> JonesFaithfulTransformation: """JonesFaithfulTransformation.""" - Q = np.linalg.inv(self.P) - return JonesFaithfulTransformation(Q, -np.matmul(Q, self.p)) + P_inv = np.linalg.inv(self.P) + return JonesFaithfulTransformation(P_inv, -np.matmul(P_inv, self.p)) @property def transformation_string(self) -> str: @@ -157,9 +161,9 @@ def transform_symmop(self, symmop: SymmOp | MagSymmOp) -> SymmOp | MagSymmOp: """Takes a symmetry operation and transforms it.""" W_rot = symmop.rotation_matrix w_translation = symmop.translation_vector - Q = np.linalg.inv(self.P) - W_ = np.matmul(np.matmul(Q, W_rot), self.P) - w_ = np.matmul(Q, (w_translation + np.matmul(W_rot - np.identity(3), self.p))) + P_inv = np.linalg.inv(self.P) + W_ = np.matmul(np.matmul(P_inv, W_rot), self.P) + w_ = np.matmul(P_inv, (w_translation + np.matmul(W_rot - np.identity(3), self.p))) w_ = np.mod(w_, 1.0) if isinstance(symmop, MagSymmOp): return MagSymmOp.from_rotation_and_translation_and_time_reversal( @@ -176,8 +180,8 @@ def transform_coords(self, coords: list[list[float]] | np.ndarray) -> list[list[ """Takes a list of coordinates and transforms them.""" new_coords = [] for x in coords: - Q = np.linalg.inv(self.P) - x_ = np.matmul(Q, (np.array(x) - self.p)) + P_inv = np.linalg.inv(self.P) + x_ = np.matmul(P_inv, (np.array(x) - self.p)) new_coords.append(x_.tolist()) return new_coords diff --git a/pymatgen/symmetry/site_symmetries.py b/pymatgen/symmetry/site_symmetries.py index 9568fe103b4..1b59ec0e60b 100644 --- a/pymatgen/symmetry/site_symmetries.py +++ b/pymatgen/symmetry/site_symmetries.py @@ -21,7 +21,7 @@ def get_site_symmetries(struct: Structure, precision: float = 0.1) -> list[list[ struct: Pymatgen structure precision (float): tolerance to find symmetry operations - Return: + Returns: list of lists of point operations for each atomic site """ point_ops: list[list[SymmOp]] = [] @@ -55,7 +55,7 @@ def get_shared_symmetry_operations(struct: Structure, pointops: list[list[SymmOp pointops: list of point group operations from get_site_symmetries method tol (float): tolerance to find symmetry operations - Return: + Returns: list of lists of shared point operations for each pair of atomic sites """ n_sites = len(struct) diff --git a/pymatgen/symmetry/structure.py b/pymatgen/symmetry/structure.py index 682bd278c0f..5003fe4a27c 100644 --- a/pymatgen/symmetry/structure.py +++ b/pymatgen/symmetry/structure.py @@ -12,6 +12,8 @@ if TYPE_CHECKING: from collections.abc import Sequence + from typing_extensions import Self + from pymatgen.symmetry.analyzer import SpacegroupOperations @@ -133,7 +135,7 @@ def as_dict(self): } @classmethod - def from_dict(cls, dct): + def from_dict(cls, dct: dict) -> Self: # type: ignore[override] """ Args: dct (dict): Dict representation. diff --git a/pymatgen/transformations/advanced_transformations.py b/pymatgen/transformations/advanced_transformations.py index e0325f87713..338387283c3 100644 --- a/pymatgen/transformations/advanced_transformations.py +++ b/pymatgen/transformations/advanced_transformations.py @@ -40,15 +40,16 @@ ) from pymatgen.transformations.transformation_abc import AbstractTransformation -if TYPE_CHECKING: - from collections.abc import Iterable, Sequence - from typing import Any - try: import hiphive except ImportError: hiphive = None +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + from typing import Any + + __author__ = "Shyue Ping Ong, Stephen Dacek, Anubhav Jain, Matthew Horton, Alex Ganose" logger = logging.getLogger(__name__) @@ -228,7 +229,7 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool | dummy_sp: self.r_fraction, } } - trans = SubstitutionTransformation(mapping) # type: ignore + trans = SubstitutionTransformation(mapping) # type: ignore[arg-type] dummy_structure = trans.apply_transformation(structure) if self.charge_balance_species is not None: cbt = ChargeBalanceTransformation(self.charge_balance_species) @@ -385,7 +386,7 @@ def apply_transformation( raise ValueError(f"Too many disordered sites! ({n_disordered} > {self.max_disordered_sites})") max_cell_sizes: Iterable[int] = range( self.min_cell_size, - int(math.floor(self.max_disordered_sites / n_disordered)) + 1, + math.floor(self.max_disordered_sites / n_disordered) + 1, ) else: max_cell_sizes = [self.max_cell_size] @@ -415,6 +416,8 @@ def apply_transformation( original_latt = structure.lattice inv_latt = np.linalg.inv(original_latt.matrix) ewald_matrices = {} + m3gnet_model = None + if not callable(self.sort_criteria) and self.sort_criteria.startswith("m3gnet"): import matgl from matgl.ext.ase import M3GNetCalculator, Relaxer @@ -451,13 +454,15 @@ def _get_stats(struct): relax_results = m3gnet_model.relax(struct) energy = float(relax_results["trajectory"].energies[-1]) struct = relax_results["final_structure"] - else: - from pymatgen.io.ase import AseAtomsAdaptor + elif self.sort_criteria == "m3gnet": atoms = AseAtomsAdaptor().get_atoms(struct) m3gnet_model.calculate(atoms) energy = float(m3gnet_model.results["energy"]) + else: + raise ValueError("Unsupported sort criteria.") + return { "num_sites": len(struct), "energy": energy, @@ -578,14 +583,16 @@ def __init__( site_constraint_name=None, site_constraints=None, ): - """:param order_parameter (float): any number from 0.0 to 1.0, - typically 0.5 (antiferromagnetic) or 1.0 (ferromagnetic) - :param species_constraint (list): str or list of strings - of Species symbols that the constraint should apply to - :param site_constraint_name (str): name of the site property - that the constraint should apply to, e.g. "coordination_no" - :param site_constraints (list): list of values of the site - property that the constraints should apply to + """ + Args: + order_parameter (float): any number from 0.0 to 1.0, + typically 0.5 (antiferromagnetic) or 1.0 (ferromagnetic) + species_constraints (list): str or list of strings + of Species symbols that the constraint should apply to + site_constraint_name (str): name of the site property + that the constraint should apply to, e.g. "coordination_no" + site_constraints (list): list of values of the site + property that the constraints should apply to. """ # validation if site_constraints and site_constraints != [None] and not site_constraint_name: @@ -632,22 +639,24 @@ class MagOrderingTransformation(AbstractTransformation): """ def __init__(self, mag_species_spin, order_parameter=0.5, energy_model=None, **kwargs): - """:param mag_species_spin: A mapping of elements/species to their - spin magnitudes, e.g. {"Fe3+": 5, "Mn3+": 4} - :param order_parameter (float or list): if float, a specifies a - global order parameter and can take values from 0.0 to 1.0 - (e.g. 0.5 for antiferromagnetic or 1.0 for ferromagnetic), if - list has to be a list of - pymatgen.transformations.advanced_transformations.MagOrderParameterConstraint - to specify more complicated orderings, see documentation for - MagOrderParameterConstraint more details on usage - :param energy_model: Energy model to rank the returned structures, - see :mod: `pymatgen.analysis.energy_models` for more information (note - that this is not necessarily a physical energy). By default, returned - structures use SymmetryModel() which ranks structures from most - symmetric to least. - :param kwargs: Additional kwargs that are passed to - EnumerateStructureTransformation such as min_cell_size etc. + """ + Args: + mag_species_spin: A mapping of elements/species to their + spin magnitudes, e.g. {"Fe3+": 5, "Mn3+": 4} + order_parameter (float or list): if float, a specifies a + global order parameter and can take values from 0.0 to 1.0 + (e.g. 0.5 for antiferromagnetic or 1.0 for ferromagnetic), if + list has to be a list of + pymatgen.transformations.advanced_transformations.MagOrderParameterConstraint + to specify more complicated orderings, see documentation for + MagOrderParameterConstraint more details on usage + energy_model: Energy model to rank the returned structures, + see :mod: `pymatgen.analysis.energy_models` for more information (note + that this is not necessarily a physical energy). By default, returned + structures use SymmetryModel() which ranks structures from most + symmetric to least. + kwargs: Additional kwargs that are passed to + EnumerateStructureTransformation such as min_cell_size etc. """ # checking for sensible order_parameter values if isinstance(order_parameter, float): @@ -710,8 +719,10 @@ def lcm(n1, n2): @staticmethod def _add_dummy_species(structure, order_parameters): - """:param structure: ordered Structure - :param order_parameters: list of MagOrderParameterConstraints + """ + Args: + structure: ordered Structure + order_parameters: list of MagOrderParameterConstraints. Returns: A structure decorated with disordered @@ -765,7 +776,7 @@ def _remove_dummy_species(structure): merged with the original sites. Used after performing enumeration. """ if not structure.is_ordered: - raise Exception("Something went wrong with enumeration.") + raise RuntimeError("Something went wrong with enumeration.") sites_to_remove = [] logger.debug(f"Dummy species structure:\n{structure}") @@ -780,7 +791,7 @@ def _remove_dummy_species(structure): include_index=True, ) if len(neighbors) != 1: - raise Exception(f"This shouldn't happen, found neighbors: {neighbors}") + raise RuntimeError(f"This shouldn't happen, found {neighbors=}") orig_site_idx = neighbors[0][2] orig_specie = structure[orig_site_idx].specie new_specie = Species( @@ -2069,6 +2080,9 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool | sqs_kwargs=self.icet_sqs_kwargs, ).run() + else: + raise RuntimeError(f"Unsupported SQS method {self.sqs_method}.") + return self._get_unique_best_sqs_structs( sqs, best_only=self.best_only, diff --git a/pymatgen/transformations/site_transformations.py b/pymatgen/transformations/site_transformations.py index bfb346508a5..296c87ec04a 100644 --- a/pymatgen/transformations/site_transformations.py +++ b/pymatgen/transformations/site_transformations.py @@ -49,12 +49,12 @@ def __init__(self, species, coords, coords_are_cartesian=False, validate_proximi def apply_transformation(self, structure: Structure): """Apply the transformation. - Arg: + Args: structure (Structure): A structurally similar structure in regards to crystal and site positions. - Return: - Returns a copy of structure with sites inserted. + Returns: + A copy of structure with sites inserted. """ struct = structure.copy() for idx, sp in enumerate(self.species): @@ -72,12 +72,12 @@ def __repr__(self): @property def inverse(self): - """Return: None.""" + """Returns None.""" return @property def is_one_to_many(self) -> bool: - """Return: False.""" + """Returns False.""" return False @@ -99,31 +99,31 @@ def __init__(self, indices_species_map): def apply_transformation(self, structure: Structure): """Apply the transformation. - Arg: + Args: structure (Structure): A structurally similar structure in regards to crystal and site positions. - Return: - Returns a copy of structure with sites replaced. + Returns: + A copy of structure with sites replaced. """ struct = structure.copy() - for i, sp in self.indices_species_map.items(): - struct[int(i)] = sp + for idx, sp in self.indices_species_map.items(): + struct[int(idx)] = sp return struct def __repr__(self): return "ReplaceSiteSpeciesTransformation :" + ", ".join( - [f"{k}->{v}" + v for k, v in self.indices_species_map.items()] + [f"{key}->{val}" + val for key, val in self.indices_species_map.items()] ) @property def inverse(self): - """Return: None.""" + """Returns None.""" return @property def is_one_to_many(self) -> bool: - """Return: False.""" + """Returns False.""" return False @@ -140,12 +140,12 @@ def __init__(self, indices_to_remove): def apply_transformation(self, structure: Structure): """Apply the transformation. - Arg: + Args: structure (Structure): A structurally similar structure in regards to crystal and site positions. - Return: - Returns a copy of structure with sites removed. + Returns: + A copy of structure with sites removed. """ struct = structure.copy() struct.remove_sites(self.indices_to_remove) @@ -156,12 +156,12 @@ def __repr__(self): @property def inverse(self): - """Return: None.""" + """Returns None.""" return @property def is_one_to_many(self) -> bool: - """Return: False.""" + """Returns False.""" return False @@ -187,12 +187,12 @@ def __init__(self, indices_to_move, translation_vector, vector_in_frac_coords=Tr def apply_transformation(self, structure: Structure): """Apply the transformation. - Arg: + Args: structure (Structure): A structurally similar structure in regards to crystal and site positions. - Return: - Returns a copy of structure with sites translated. + Returns: + A copy of structure with sites translated. """ struct = structure.copy() if self.translation_vector.shape == (len(self.indices_to_move), 3): @@ -219,7 +219,7 @@ def inverse(self): @property def is_one_to_many(self) -> bool: - """Return: False.""" + """Returns False.""" return False def as_dict(self): @@ -487,12 +487,12 @@ def __repr__(self): @property def inverse(self) -> None: - """Return: None.""" + """Returns None.""" return @property def is_one_to_many(self) -> bool: - """Return: True.""" + """Returns True.""" return True @@ -509,12 +509,12 @@ def __init__(self, site_properties): def apply_transformation(self, structure: Structure): """Apply the transformation. - Arg: + Args: structure (Structure): A structurally similar structure in regards to crystal and site positions. - Return: - Returns a copy of structure with sites properties added. + Returns: + A copy of structure with sites properties added. """ new_struct = structure.copy() for prop in self.site_properties: @@ -523,12 +523,12 @@ def apply_transformation(self, structure: Structure): @property def inverse(self): - """Return: None.""" + """Returns None.""" return @property def is_one_to_many(self) -> bool: - """Return: False.""" + """Returns False.""" return False diff --git a/pymatgen/transformations/standard_transformations.py b/pymatgen/transformations/standard_transformations.py index ae4bcfed96f..00232e3c99c 100644 --- a/pymatgen/transformations/standard_transformations.py +++ b/pymatgen/transformations/standard_transformations.py @@ -25,6 +25,8 @@ from pymatgen.transformations.transformation_abc import AbstractTransformation if TYPE_CHECKING: + from typing_extensions import Self + from pymatgen.core.sites import PeriodicSite from pymatgen.util.typing import SpeciesLike @@ -209,7 +211,7 @@ def __init__(self, scaling_matrix=((1, 0, 0), (0, 1, 0), (0, 0, 1))): self.scaling_matrix = scaling_matrix @classmethod - def from_scaling_factors(cls, scale_a=1, scale_b=1, scale_c=1): + def from_scaling_factors(cls, scale_a: float = 1, scale_b: float = 1, scale_c: float = 1) -> Self: """Convenience method to get a SupercellTransformation from a simple series of three numbers for scaling each lattice vector. Equivalent to calling the normal with [[scale_a, 0, 0], [0, scale_b, 0], @@ -225,10 +227,10 @@ def from_scaling_factors(cls, scale_a=1, scale_b=1, scale_c=1): """ return cls([[scale_a, 0, 0], [0, scale_b, 0], [0, 0, scale_c]]) - @staticmethod + @classmethod def from_boundary_distance( - structure: Structure, min_boundary_dist: float = 6, allow_rotation: bool = False, max_atoms: float = -1 - ) -> SupercellTransformation: + cls, structure: Structure, min_boundary_dist: float = 6, allow_rotation: bool = False, max_atoms: float = -1 + ) -> Self: """Get a SupercellTransformation according to the desired minimum distance between periodic boundaries of the resulting supercell. @@ -241,7 +243,7 @@ def from_boundary_distance( number of atoms than the SupercellTransformation with unchanged lattice angles can possibly be found. If such a SupercellTransformation cannot be found easily, the SupercellTransformation with unchanged lattice angles will be returned. - max_atoms (int): Maximum number of atoms allowed in the supercell. Defaults to infinity. + max_atoms (int): Maximum number of atoms allowed in the supercell. Defaults to -1 for infinity. Returns: SupercellTransformation. @@ -253,21 +255,21 @@ def from_boundary_distance( if allow_rotation and sum(min_expand != 0) > 1: min1, min2, min3 = map(int, min_expand) # type: ignore # map(int) just for mypy's sake scaling_matrix = [ - [min1 if min1 else 1, 1 if min1 and min2 else 0, 1 if min1 and min3 else 0], - [-1 if min2 and min1 else 0, min2 if min2 else 1, 1 if min2 and min3 else 0], - [-1 if min3 and min1 else 0, -1 if min3 and min2 else 0, min3 if min3 else 1], + [min1 or 1, 1 if min1 and min2 else 0, 1 if min1 and min3 else 0], + [-1 if min2 and min1 else 0, min2 or 1, 1 if min2 and min3 else 0], + [-1 if min3 and min1 else 0, -1 if min3 and min2 else 0, min3 or 1], ] struct_scaled = structure.make_supercell(scaling_matrix, in_place=False) min_expand_scaled = np.int8( min_boundary_dist / np.array([struct_scaled.lattice.d_hkl(plane) for plane in np.eye(3)]) ) if sum(min_expand_scaled != 0) == 0 and len(struct_scaled) <= max_atoms: - return SupercellTransformation(scaling_matrix) + return cls(scaling_matrix) scaling_matrix = np.eye(3) + np.diag(min_expand) # type: ignore[assignment] struct_scaled = structure.make_supercell(scaling_matrix, in_place=False) if len(struct_scaled) <= max_atoms: - return SupercellTransformation(scaling_matrix) + return cls(scaling_matrix) msg = f"{max_atoms=} exceeded while trying to solve for supercell. You can try lowering {min_boundary_dist=}" if not allow_rotation: @@ -511,7 +513,7 @@ def __init__(self, algo=ALGO_FAST, symmetrized_structures=False, no_oxi_states=F self.no_oxi_states = no_oxi_states self.symmetrized_structures = symmetrized_structures - def apply_transformation(self, structure: Structure, return_ranked_list: bool | int = False): + def apply_transformation(self, structure: Structure, return_ranked_list: bool | int = False) -> Structure: """For this transformation, the apply_transformation method will return only the ordered structure with the lowest Ewald energy, to be consistent with the method signature of the other transformations. @@ -521,7 +523,6 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool | Args: structure: Oxidation state decorated disordered structure to order return_ranked_list (bool | int, optional): If return_ranked_list is int, that number of structures - is returned. If False, only the single lowest energy structure is returned. Defaults to False. Returns: @@ -536,11 +537,11 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool | transmuted structure class. """ try: - num_to_return = int(return_ranked_list) + n_to_return = int(return_ranked_list) except ValueError: - num_to_return = 1 + n_to_return = 1 - num_to_return = max(1, num_to_return) + n_to_return = max(1, n_to_return) if self.no_oxi_states: structure = Structure.from_sites(structure) @@ -601,12 +602,12 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool | manipulations.append([0, empty, list(group), None]) matrix = EwaldSummation(struct).total_energy_matrix - ewald_m = EwaldMinimizer(matrix, manipulations, num_to_return, self.algo) + ewald_m = EwaldMinimizer(matrix, manipulations, n_to_return, self.algo) self._all_structures = [] lowest_energy = ewald_m.output_lists[0][0] - num_atoms = sum(structure.composition.values()) + n_atoms = sum(structure.composition.values()) for output in ewald_m.output_lists: struct_copy = struct.copy() @@ -626,13 +627,13 @@ def apply_transformation(self, structure: Structure, return_ranked_list: bool | self._all_structures.append( { "energy": output[0], - "energy_above_minimum": (output[0] - lowest_energy) / num_atoms, + "energy_above_minimum": (output[0] - lowest_energy) / n_atoms, "structure": struct_copy.get_sorted_structure(), } ) if return_ranked_list: - return self._all_structures[:num_to_return] + return self._all_structures[:n_to_return] # type: ignore[return-value] return self._all_structures[0]["structure"] def __repr__(self): @@ -967,7 +968,7 @@ def apply_transformation(self, structure): """Returns a copy of structure with lattice parameters and sites scaled to the same degree as the relaxed_structure. - Arg: + Args: structure (Structure): A structurally similar structure in regards to crystal and site positions. """ diff --git a/pymatgen/util/coord.py b/pymatgen/util/coord.py index 7f969d873e6..a9c26f3405a 100644 --- a/pymatgen/util/coord.py +++ b/pymatgen/util/coord.py @@ -20,6 +20,8 @@ from numpy.typing import ArrayLike + from pymatgen.util.typing import PbcLike + # array size threshold for looping instead of broadcasting LOOP_THRESHOLD = 1e6 @@ -100,7 +102,7 @@ def coord_list_mapping(subset: ArrayLike, superset: ArrayLike, atol: float = 1e- return inds -def coord_list_mapping_pbc(subset, superset, atol: float = 1e-8, pbc: tuple[bool, bool, bool] = (True, True, True)): +def coord_list_mapping_pbc(subset, superset, atol: float = 1e-8, pbc: PbcLike = (True, True, True)): """Gives the index mapping from a subset to a superset. Superset cannot contain duplicate matching rows. @@ -162,7 +164,7 @@ def all_distances(coords1: ArrayLike, coords2: ArrayLike) -> np.ndarray: return np.sum(z, axis=-1) ** 0.5 -def pbc_diff(fcoords1: ArrayLike, fcoords2: ArrayLike, pbc: tuple[bool, bool, bool] = (True, True, True)): +def pbc_diff(fcoords1: ArrayLike, fcoords2: ArrayLike, pbc: PbcLike = (True, True, True)): """Returns the 'fractional distance' between two coordinates taking into account periodic boundary conditions. @@ -206,9 +208,7 @@ def pbc_shortest_vectors(lattice, fcoords1, fcoords2, mask=None, return_d2: bool return coord_cython.pbc_shortest_vectors(lattice, fcoords1, fcoords2, mask, return_d2) -def find_in_coord_list_pbc( - fcoord_list, fcoord, atol: float = 1e-8, pbc: tuple[bool, bool, bool] = (True, True, True) -) -> np.ndarray: +def find_in_coord_list_pbc(fcoord_list, fcoord, atol: float = 1e-8, pbc: PbcLike = (True, True, True)) -> np.ndarray: """Get the indices of all points in a fractional coord list that are equal to a fractional coord (with a tolerance), taking into account periodic boundary conditions. @@ -225,15 +225,13 @@ def find_in_coord_list_pbc( """ if len(fcoord_list) == 0: return [] - fcoords = np.tile(fcoord, (len(fcoord_list), 1)) - fdist = fcoord_list - fcoords - fdist[:, pbc] -= np.round(fdist)[:, pbc] - return np.where(np.all(np.abs(fdist) < atol, axis=1))[0] + frac_coords = np.tile(fcoord, (len(fcoord_list), 1)) + frac_dist = fcoord_list - frac_coords + frac_dist[:, pbc] -= np.round(frac_dist)[:, pbc] + return np.where(np.all(np.abs(frac_dist) < atol, axis=1))[0] -def in_coord_list_pbc( - fcoord_list, fcoord, atol: float = 1e-8, pbc: tuple[bool, bool, bool] = (True, True, True) -) -> bool: +def in_coord_list_pbc(fcoord_list, fcoord, atol: float = 1e-8, pbc: PbcLike = (True, True, True)) -> bool: """Tests if a particular fractional coord is within a fractional coord_list. Args: @@ -249,9 +247,7 @@ def in_coord_list_pbc( return len(find_in_coord_list_pbc(fcoord_list, fcoord, atol=atol, pbc=pbc)) > 0 -def is_coord_subset_pbc( - subset, superset, atol: float = 1e-8, mask=None, pbc: tuple[bool, bool, bool] = (True, True, True) -) -> bool: +def is_coord_subset_pbc(subset, superset, atol: float = 1e-8, mask=None, pbc: PbcLike = (True, True, True)) -> bool: """Tests if all fractional coords in subset are contained in superset. Args: @@ -269,9 +265,9 @@ def is_coord_subset_pbc( """ c1 = np.array(subset, dtype=np.float64) c2 = np.array(superset, dtype=np.float64) - m = np.array(mask, dtype=int) if mask is not None else np.zeros((len(subset), len(superset)), dtype=int) + mask_arr = np.array(mask, dtype=int) if mask is not None else np.zeros((len(subset), len(superset)), dtype=int) atol = np.zeros(3, dtype=np.float64) + atol - return coord_cython.is_coord_subset_pbc(c1, c2, atol, m, pbc) + return coord_cython.is_coord_subset_pbc(c1, c2, atol, mask_arr, pbc) def lattice_points_in_supercell(supercell_matrix): @@ -378,7 +374,7 @@ def volume(self) -> float: def bary_coords(self, point): """ Args: - point (): Point coordinates. + point (ArrayLike): Point coordinates. Returns: Barycentric coordinations. @@ -391,17 +387,17 @@ def bary_coords(self, point): def point_from_bary_coords(self, bary_coords: ArrayLike): """ Args: - bary_coords (): Barycentric coordinates. + bary_coords (ArrayLike): Barycentric coordinates (d+1, d). Returns: - Point coordinates + np.array: Point in the simplex. """ try: return np.dot(bary_coords, self._aug[:, :-1]) except AttributeError as exc: raise ValueError("Simplex is not full-dimensional") from exc - def in_simplex(self, point, tolerance=1e-8): + def in_simplex(self, point: Sequence[float], tolerance: float = 1e-8) -> bool: """Checks if a point is in the simplex using the standard barycentric coordinate system algorithm. @@ -413,7 +409,7 @@ def in_simplex(self, point, tolerance=1e-8): is in the facet. Args: - point ([float]): Point to test + point (list[float]): Point to test tolerance (float): Tolerance to test if point is in simplex. """ return (self.bary_coords(point) >= -tolerance).all() diff --git a/pymatgen/util/coord_cython.pyx b/pymatgen/util/coord_cython.pyx index a5637305db8..f3a853d2e97 100644 --- a/pymatgen/util/coord_cython.pyx +++ b/pymatgen/util/coord_cython.pyx @@ -21,10 +21,10 @@ from libc.math cimport fabs, round from libc.stdlib cimport free, malloc #create images, 2d array of all length 3 combinations of [-1,0,1] -r = np.arange(-1, 2, dtype=np.float_) -arange = r[:, None] * np.array([1, 0, 0])[None, :] -brange = r[:, None] * np.array([0, 1, 0])[None, :] -crange = r[:, None] * np.array([0, 0, 1])[None, :] +rng = np.arange(-1, 2, dtype=np.float_) +arange = rng[:, None] * np.array([1, 0, 0])[None, :] +brange = rng[:, None] * np.array([0, 1, 0])[None, :] +crange = rng[:, None] * np.array([0, 0, 1])[None, :] images_t = arange[:, None, None] + brange[None, :, None] + \ crange[None, None, :] images = images_t.reshape((27, 3)) @@ -126,9 +126,9 @@ def pbc_shortest_vectors(lattice, fcoords1, fcoords2, mask=None, return_d2=False cdef np.float_t[:, ::1] cart_im = malloc(3 * n_pbc_im * sizeof(np.float_t)) cdef bint has_mask = mask is not None - cdef np.int_t[:, :] m + cdef np.int_t[:, :] mask_arr if has_mask: - m = np.array(mask, dtype=np.int_, copy=False, order="C") + mask_arr = np.array(mask, dtype=np.int_, copy=False, order="C") cdef bint has_ftol = (lll_frac_tol is not None) cdef np.float_t[:] ftol @@ -151,7 +151,7 @@ def pbc_shortest_vectors(lattice, fcoords1, fcoords2, mask=None, return_d2=False for i in range(I): for j in range(J): within_frac = False - if (not has_mask) or (m[i, j] == 0): + if (not has_mask) or (mask_arr[i, j] == 0): within_frac = True if has_ftol: for l in range(3): diff --git a/pymatgen/util/plotting.py b/pymatgen/util/plotting.py index ff9efd3ab5f..70ffdfc6e0c 100644 --- a/pymatgen/util/plotting.py +++ b/pymatgen/util/plotting.py @@ -108,7 +108,7 @@ def pretty_plot_two_axis( linewidth, etc. Returns: - matplotlib.pyplot + plt.Axes: matplotlib axes object with properly sized fonts. """ colors = palettable.colorbrewer.diverging.RdYlBu_4.mpl_colors c1 = colors[0] @@ -120,8 +120,8 @@ def pretty_plot_two_axis( height = int(width * golden_ratio) width = 12 - labelsize = int(width * 3) - ticksize = int(width * 2.5) + label_size = int(width * 3) + tick_size = int(width * 2.5) styles = ["-", "--", "-.", "."] fig, ax1 = plt.subplots() @@ -131,34 +131,34 @@ def pretty_plot_two_axis( if isinstance(y1, dict): for idx, (key, val) in enumerate(y1.items()): ax1.plot(x, val, c=c1, marker="s", ls=styles[idx % len(styles)], label=key, **plot_kwargs) - ax1.legend(fontsize=labelsize) + ax1.legend(fontsize=label_size) else: ax1.plot(x, y1, c=c1, marker="s", ls="-", **plot_kwargs) if xlabel: - ax1.set_xlabel(xlabel, fontsize=labelsize) + ax1.set_xlabel(xlabel, fontsize=label_size) if y1label: # Make the y-axis label, ticks and tick labels match the line color. - ax1.set_ylabel(y1label, color=c1, fontsize=labelsize) + ax1.set_ylabel(y1label, color=c1, fontsize=label_size) - ax1.tick_params("x", labelsize=ticksize) - ax1.tick_params("y", colors=c1, labelsize=ticksize) + ax1.tick_params("x", labelsize=tick_size) + ax1.tick_params("y", colors=c1, labelsize=tick_size) ax2 = ax1.twinx() if isinstance(y2, dict): for idx, (key, val) in enumerate(y2.items()): ax2.plot(x, val, c=c2, marker="o", ls=styles[idx % len(styles)], label=key) - ax2.legend(fontsize=labelsize) + ax2.legend(fontsize=label_size) else: ax2.plot(x, y2, c=c2, marker="o", ls="-") if y2label: # Make the y-axis label, ticks and tick labels match the line color. - ax2.set_ylabel(y2label, color=c2, fontsize=labelsize) + ax2.set_ylabel(y2label, color=c2, fontsize=label_size) - ax2.tick_params("y", colors=c2, labelsize=ticksize) - return plt + ax2.tick_params("y", colors=c2, labelsize=tick_size) + return ax1 def pretty_polyfit_plot(x: ArrayLike, y: ArrayLike, deg: int = 1, xlabel=None, ylabel=None, **kwargs): @@ -244,14 +244,17 @@ def periodic_table_heatmap( pymatviz (bool): Whether to use pymatviz to generate the heatmap. Defaults to True. See https://github.com/janosh/pymatviz. kwargs: Passed to pymatviz.ptable_heatmap_plotly + + Returns: + plt.Axes: matplotlib Axes object """ if pymatviz: try: from pymatviz import ptable_heatmap_plotly if elemental_data: - kwargs.setdefault("elem_values", elemental_data) - print('elemental_data is deprecated, use elem_values={"Fe": 4.2, "O": 5.0} instead') + kwargs.setdefault("values", elemental_data) + print('elemental_data is deprecated, use values={"Fe": 4.2, "O": 5.0} instead') if cbar_label: kwargs.setdefault("color_bar", {}).setdefault("title", cbar_label) print('cbar_label is deprecated, use color_bar={"title": cbar_label} instead') @@ -265,8 +268,8 @@ def periodic_table_heatmap( kwargs.setdefault("cscale_range", cmap_range) print("cmap_range is deprecated, use cscale_range instead") if value_format: - kwargs.setdefault("precision", value_format) - print("value_format is deprecated, use precision instead") + kwargs.setdefault("fmt", value_format) + print("value_format is deprecated, use fmt instead") if blank_color != "grey": print("blank_color is deprecated") if edge_color != "white": @@ -340,7 +343,7 @@ def periodic_table_heatmap( # Grey out missing elements in input data cbar.cmap.set_under(blank_color) - # Set the colorbar label and tick marks + # Set the color bar label and tick marks cbar.set_label(cbar_label, rotation=270, labelpad=25, size=cbar_label_size) cbar.ax.tick_params(labelsize=cbar_label_size) @@ -353,15 +356,15 @@ def periodic_table_heatmap( scalar_cmap = cm.ScalarMappable(norm=norm, cmap=cmap) # Label each block with corresponding element and value - for i, row in enumerate(value_table): - for j, el in enumerate(row): + for ii, row in enumerate(value_table): + for jj, el in enumerate(row): if not np.isnan(el): - symbol = Element.from_row_and_group(i + 1, j + 1).symbol + symbol = Element.from_row_and_group(ii + 1, jj + 1).symbol rgba = scalar_cmap.to_rgba(el) fontcolor = _decide_fontcolor(rgba) if readable_fontcolor else "black" plt.text( - j + 0.5, - i + 0.25, + jj + 0.5, + ii + 0.25, symbol, horizontalalignment="center", verticalalignment="center", @@ -370,8 +373,8 @@ def periodic_table_heatmap( ) if el != blank_value and value_format is not None: plt.text( - j + 0.5, - i + 0.5, + jj + 0.5, + ii + 0.5, value_format % el, horizontalalignment="center", verticalalignment="center", @@ -384,7 +387,7 @@ def periodic_table_heatmap( if show_plot: plt.show() - return plt + return ax def format_formula(formula: str) -> str: @@ -396,12 +399,12 @@ def format_formula(formula: str) -> str: """ formatted_formula = "" number_format = "" - for idx, char in enumerate(formula): + for idx, char in enumerate(formula, start=1): if char.isdigit(): if not number_format: number_format = "_{" number_format += char - if idx == len(formula) - 1: + if idx == len(formula): number_format += "}" formatted_formula += number_format else: @@ -416,14 +419,14 @@ def format_formula(formula: str) -> str: def van_arkel_triangle(list_of_materials: Sequence, annotate: bool = True): """A static method that generates a binary van Arkel-Ketelaar triangle to - quantify the ionic, metallic and covalent character of a compound - by plotting the electronegativity difference (y) vs average (x). - See: - A.E. van Arkel, Molecules and Crystals in Inorganic Chemistry, - Interscience, New York (1956) - and - J.A.A Ketelaar, Chemical Constitution (2nd edition), An Introduction - to the Theory of the Chemical Bond, Elsevier, New York (1958). + quantify the ionic, metallic and covalent character of a compound + by plotting the electronegativity difference (y) vs average (x). + See: + A.E. van Arkel, Molecules and Crystals in Inorganic Chemistry, + Interscience, New York (1956) + and + J.A.A Ketelaar, Chemical Constitution (2nd edition), An Introduction + to the Theory of the Chemical Bond, Elsevier, New York (1958). Args: list_of_materials (list): A list of computed entries of binary @@ -431,6 +434,9 @@ def van_arkel_triangle(list_of_materials: Sequence, annotate: bool = True): annotate (bool): Whether or not to label the points on the triangle with reduced formula (if list of entries) or pair of elements (if list of list of str). + + Returns: + plt.Axes: matplotlib Axes object """ # F-Fr has the largest X difference. We set this # as our top corner of the triangle (most ionic) @@ -521,7 +527,7 @@ def van_arkel_triangle(list_of_materials: Sequence, annotate: bool = True): alpha=0.8, ) - # Label the triangle with datapoints + # Label the triangle with data points for entry in list_of_materials: if type(entry).__name__ not in ["ComputedEntry", "ComputedStructureEntry"]: X_pair = [Element(el).X for el in entry] @@ -539,7 +545,7 @@ def van_arkel_triangle(list_of_materials: Sequence, annotate: bool = True): ) plt.tight_layout() - return plt + return ax def get_ax_fig(ax: Axes = None, **kwargs) -> tuple[Axes, Figure]: @@ -691,7 +697,7 @@ def wrapper(*args, **kwargs): return fig # Add docstring to the decorated method. - s = """\n\n + doc_str = """\n\n Keyword arguments controlling the display of the figure: ================ ==================================================== @@ -714,9 +720,9 @@ def wrapper(*args, **kwargs): if wrapper.__doc__ is not None: # Add s at the end of the docstring. - wrapper.__doc__ += "\n" + s + wrapper.__doc__ += f"\n{doc_str}" else: # Use s - wrapper.__doc__ = s + wrapper.__doc__ = doc_str return wrapper diff --git a/pymatgen/util/provenance.py b/pymatgen/util/provenance.py index fb4e457d74f..438bc06ef19 100644 --- a/pymatgen/util/provenance.py +++ b/pymatgen/util/provenance.py @@ -8,22 +8,24 @@ import sys from collections import namedtuple from io import StringIO +from typing import TYPE_CHECKING from monty.json import MontyDecoder, MontyEncoder +from pymatgen.core.structure import Molecule, Structure + try: from pybtex import errors from pybtex.database.input import bibtex except ImportError: - pybtex = bibtex = None - -from typing import TYPE_CHECKING - -from pymatgen.core.structure import Molecule, Structure + pybtex = bibtex = errors = None if TYPE_CHECKING: from collections.abc import Sequence + from typing_extensions import Self + + __author__ = "Anubhav Jain, Shyue Ping Ong" __credits__ = "Dan Gunter" @@ -75,8 +77,8 @@ def as_dict(self) -> dict[str, str]: """Returns: Dict.""" return {"name": self.name, "url": self.url, "description": self.description} - @staticmethod - def from_dict(dct: dict[str, str]) -> HistoryNode: + @classmethod + def from_dict(cls, dct: dict[str, str]) -> Self: """ Args: dct (dict): Dict representation. @@ -84,25 +86,24 @@ def from_dict(dct: dict[str, str]) -> HistoryNode: Returns: HistoryNode """ - return HistoryNode(dct["name"], dct["url"], dct["description"]) + return cls(dct["name"], dct["url"], dct["description"]) - @staticmethod - def parse_history_node(h_node): + @classmethod + def parse_history_node(cls, h_node) -> Self: """Parses a History Node object from either a dict or a tuple. Args: - h_node: A dict with name/url/description fields or a 3-element - tuple. + h_node: A dict with name/url/description fields or a 3-element tuple. Returns: - History node. + HistoryNode """ if isinstance(h_node, dict): - return HistoryNode.from_dict(h_node) + return cls.from_dict(h_node) if len(h_node) != 3: raise ValueError(f"Invalid History node, should be dict or (name, version, description) tuple: {h_node}") - return HistoryNode(h_node[0], h_node[1], h_node[2]) + return cls(h_node[0], h_node[1], h_node[2]) class Author(namedtuple("Author", ["name", "email"])): @@ -120,19 +121,19 @@ def as_dict(self): """Returns: MSONable dict.""" return {"name": self.name, "email": self.email} - @staticmethod - def from_dict(d): + @classmethod + def from_dict(cls, dct: dict) -> Self: """ Args: - d (dict): Dict representation. + dct (dict): Dict representation. Returns: Author """ - return Author(d["name"], d["email"]) + return cls(dct["name"], dct["email"]) - @staticmethod - def parse_author(author): + @classmethod + def parse_author(cls, author) -> Self: """Parses an Author object from either a String, dict, or tuple. Args: @@ -145,15 +146,15 @@ def parse_author(author): if isinstance(author, str): # Regex looks for whitespace, (any name), whitespace, <, (email), # >, whitespace - m = re.match(r"\s*(.*?)\s*<(.*?@.*?)>\s*", author) - if not m or m.start() != 0 or m.end() != len(author): + match = re.match(r"\s*(.*?)\s*<(.*?@.*?)>\s*", author) + if not match or match.start() != 0 or match.end() != len(author): raise ValueError(f"Invalid author format! {author}") - return Author(m.groups()[0], m.groups()[1]) + return cls(match.groups()[0], match.groups()[1]) if isinstance(author, dict): - return Author.from_dict(author) + return cls.from_dict(author) if len(author) != 2: raise ValueError(f"Invalid author, should be String or (name, email) tuple: {author}") - return Author(author[0], author[1]) + return cls(author[0], author[1]) class StructureNL: @@ -272,30 +273,29 @@ def as_dict(self): return dct @classmethod - def from_dict(cls, d): + def from_dict(cls, dct: dict) -> Self: """ Args: - d (dict): Dict representation. + dct (dict): Dict representation. Returns: Class """ - a = d["about"] - dec = MontyDecoder() + about = dct["about"] - created_at = dec.process_decoded(a.get("created_at")) - data = {k: v for k, v in d["about"].items() if k.startswith("_")} - data = dec.process_decoded(data) + created_at = MontyDecoder().process_decoded(about.get("created_at")) + data = {k: v for k, v in dct["about"].items() if k.startswith("_")} + data = MontyDecoder().process_decoded(data) - structure = Structure.from_dict(d) if "lattice" in d else Molecule.from_dict(d) + structure = Structure.from_dict(dct) if "lattice" in dct else Molecule.from_dict(dct) return cls( structure, - a["authors"], - projects=a.get("projects"), - references=a.get("references", ""), - remarks=a.get("remarks"), + about["authors"], + projects=about.get("projects"), + references=about.get("references", ""), + remarks=about.get("remarks"), data=data, - history=a.get("history"), + history=about.get("history"), created_at=created_at, ) @@ -310,7 +310,7 @@ def from_structures( data=None, histories=None, created_at=None, - ): + ) -> list[Self]: """A convenience method for getting a list of StructureNL objects by specifying structures and metadata separately. Some of the metadata is applied to all of the structures for ease of use. @@ -339,7 +339,7 @@ def from_structures( snl_list = [] for idx, struct in enumerate(structures): - snl = StructureNL( + snl = cls( struct, authors, projects=projects, diff --git a/pymatgen/util/string.py b/pymatgen/util/string.py index 33e6cb9b9ff..cbd7da1c5a2 100644 --- a/pymatgen/util/string.py +++ b/pymatgen/util/string.py @@ -162,26 +162,28 @@ def latexify(formula: str, bold: bool = False): return re.sub(r"([A-Za-z\(\)])([\d\.]+)", r"\1$_{\\mathbf{\2}}$" if bold else r"\1$_{\2}$", formula) -def htmlify(formula): +def htmlify(formula: str) -> str: """Generates a HTML formatted formula, e.g. Fe2O3 is transformed to Fe2O3. Note that Composition now has a to_html_string() method that may be used instead. - :param formula: + Args: + formula: The string to format. """ return re.sub(r"([A-Za-z\(\)])([\d\.]+)", r"\1\2", formula) -def unicodeify(formula): +def unicodeify(formula: str) -> str: """Generates a formula with unicode subscripts, e.g. Fe2O3 is transformed to Feโ‚‚Oโ‚ƒ. Does not support formulae with decimal points. Note that Composition now has a to_unicode_string() method that may be used instead. - :param formula: + Args: + formula: The string to format. """ if "." in formula: raise ValueError("No unicode character exists for subscript period.") @@ -296,11 +298,11 @@ def transformation_to_string(matrix, translation_vec=(0, 0, 0), components=("x", parts = [] for idx in range(3): string = "" - m = matrix[idx] + mat = matrix[idx] offset = translation_vec[idx] for j, dim in enumerate(components): - if m[j] != 0: - f = Fraction(m[j]).limit_denominator() + if mat[j] != 0: + f = Fraction(mat[j]).limit_denominator() if string != "" and f >= 0: string += "+" if abs(f.numerator) != 1: @@ -332,7 +334,8 @@ def disordered_formula(disordered_struct, symbols=("x", "y", "z"), fmt="plain"): species more symbols will need to be added fmt (str): 'plain', 'HTML' or 'LaTeX' - Returns (str): a disordered formula string + Returns: + str: a disordered formula string """ # this is in string utils and not in Composition because we need to have access to # site occupancies to calculate this, so have to pass the full structure as an @@ -395,7 +398,10 @@ def disordered_formula(disordered_struct, symbols=("x", "y", "z"), fmt="plain"): elif fmt == "HTML": sub_start = "" sub_end = "" - elif fmt != "plain": + elif fmt == "plain": + sub_start = "" + sub_end = "" + else: raise ValueError("Unsupported output format, choose from: LaTeX, HTML, plain") disordered_formula = [] diff --git a/pymatgen/util/testing/__init__.py b/pymatgen/util/testing/__init__.py index 759a7c7c3af..c0986b2fe63 100644 --- a/pymatgen/util/testing/__init__.py +++ b/pymatgen/util/testing/__init__.py @@ -10,12 +10,12 @@ import json import pickle # use pickle, not cPickle so that we get the traceback in case of errors import string -import unittest from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar +from unittest import TestCase import pytest -from monty.json import MontyDecoder, MSONable +from monty.json import MontyDecoder, MontyEncoder, MSONable from monty.serialization import loadfn from pymatgen.core import ROOT, SETTINGS, Structure @@ -34,7 +34,7 @@ FAKE_POTCAR_DIR = f"{VASP_IN_DIR}/fake_potcars" -class PymatgenTest(unittest.TestCase): +class PymatgenTest(TestCase): """Extends unittest.TestCase with several assert methods for array and str comparison.""" # dict of lazily-loaded test structures (initialized to None) @@ -132,7 +132,7 @@ def serialize_with_pickle(self, objects: Any, protocols: Sequence[int] | None = return [o[0] for o in objects_by_protocol] return objects_by_protocol - def assert_msonable(self, obj, test_is_subclass=True): + def assert_msonable(self, obj: MSONable, test_is_subclass: bool = True) -> str: """Test if obj is MSONable and verify the contract is fulfilled. By default, the method tests whether obj is an instance of MSONable. @@ -141,4 +141,7 @@ def assert_msonable(self, obj, test_is_subclass=True): if test_is_subclass: assert isinstance(obj, MSONable) assert obj.as_dict() == type(obj).from_dict(obj.as_dict()).as_dict() - json.loads(obj.to_json(), cls=MontyDecoder) + json_str = json.dumps(obj.as_dict(), cls=MontyEncoder) + round_trip = json.loads(json_str, cls=MontyDecoder) + assert issubclass(type(round_trip), type(obj)), f"{type(round_trip)} != {type(obj)}" + return json_str diff --git a/pymatgen/util/typing.py b/pymatgen/util/typing.py index f5fa5a7c0bf..7420fbf26d1 100644 --- a/pymatgen/util/typing.py +++ b/pymatgen/util/typing.py @@ -5,7 +5,7 @@ from __future__ import annotations -from pathlib import Path +from os import PathLike as OsPathLike from typing import TYPE_CHECKING, Any, Union from pymatgen.core import Composition, DummySpecies, Element, Species @@ -18,7 +18,8 @@ from pymatgen.entries.exp_entries import ExpEntry -PathLike = Union[str, Path] +PathLike = Union[str, OsPathLike] +PbcLike = tuple[bool, bool, bool] # Things that can be cast to a Species-like object using get_el_sp SpeciesLike = Union[str, Element, Species, DummySpecies] diff --git a/pymatgen/vis/plotters.py b/pymatgen/vis/plotters.py index 43e4592693f..cae4d50e88c 100644 --- a/pymatgen/vis/plotters.py +++ b/pymatgen/vis/plotters.py @@ -13,7 +13,7 @@ class SpectrumPlotter: """ Class for plotting Spectrum objects and subclasses. Note that the interface is extremely flexible given that there are many different ways in which - people want to view spectra. The typical usage is:: + people want to view spectra. The typical usage is: # Initializes plotter with some optional args. Defaults are usually # fine, diff --git a/pymatgen/vis/structure_chemview.py b/pymatgen/vis/structure_chemview.py index 4a62a215c5c..9007bb9446f 100644 --- a/pymatgen/vis/structure_chemview.py +++ b/pymatgen/vis/structure_chemview.py @@ -3,7 +3,6 @@ from __future__ import annotations import numpy as np -from monty.dev import requires from pymatgen.analysis.molecule_structure_comparator import CovalentRadius from pymatgen.symmetry.analyzer import SpacegroupAnalyzer @@ -15,9 +14,9 @@ chemview_loaded = True except ImportError: chemview_loaded = False + MolecularViewer = get_atom_color = None -@requires(chemview_loaded, "To use quick_view, you need to have chemview installed.") def quick_view( structure, bonds=True, @@ -43,6 +42,10 @@ def quick_view( Returns: A chemview.MolecularViewer object """ + # Ensure MolecularViewer is loaded + if not chemview_loaded: + raise RuntimeError("MolecularViewer not loaded.") + struct = structure.copy() if conventional: struct = SpacegroupAnalyzer(struct).get_conventional_standard_structure() diff --git a/pymatgen/vis/structure_vtk.py b/pymatgen/vis/structure_vtk.py index 33d6f28aad8..cfad254aae4 100644 --- a/pymatgen/vis/structure_vtk.py +++ b/pymatgen/vis/structure_vtk.py @@ -16,9 +16,6 @@ from pymatgen.core import PeriodicSite, Species, Structure from pymatgen.util.coord import in_coord_list -if TYPE_CHECKING: - from collections.abc import Sequence - try: import vtk from vtk import vtkInteractorStyleTrackballCamera as TrackballCamera @@ -27,6 +24,9 @@ vtk = None TrackballCamera = object +if TYPE_CHECKING: + from collections.abc import Sequence + module_dir = os.path.dirname(os.path.abspath(__file__)) EL_COLORS = loadfn(f"{module_dir}/ElementColorSchemes.yaml") @@ -227,11 +227,9 @@ def set_structure(self, structure: Structure, reset_camera=True, to_unit_cell=Tr labels = ["a", "b", "c"] colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)] - if has_lattice: - matrix = struct.lattice.matrix + matrix = struct.lattice.matrix if has_lattice else None if self.show_unit_cell and has_lattice: - # matrix = s.lattice.matrix self.add_text([0, 0, 0], "o") for vec in matrix: self.add_line((0, 0, 0), vec, colors[count]) @@ -340,10 +338,8 @@ def add_site(self, site): vis_radius = 0.2 + 0.002 * radius for specie, occu in site.species.items(): - if not specie: - color = (1, 1, 1) - elif specie.symbol in self.el_color_mapping: - color = [i / 255 for i in self.el_color_mapping[specie.symbol]] + color = [i / 255 for i in self.el_color_mapping.get(specie.symbol, (255, 255, 255))] + mapper = self.add_partial_sphere(site.coords, vis_radius, color, start_angle, start_angle + 360 * occu) self.mapper_map[mapper] = [site] start_angle += 360 * occu @@ -487,14 +483,16 @@ def add_polyhedron( # If partial occupations are involved, the color of the specie with # the highest occupation is used max_occu = 0.0 - for specie, occu in center.species.items(): + max_species = next(iter(center.species), None) + for species, occu in center.species.items(): if occu > max_occu: - max_specie = specie + max_species = species max_occu = occu - color = [i / 255 for i in self.el_color_mapping[max_specie.symbol]] + color = [i / 255 for i in self.el_color_mapping[max_species.symbol]] ac.GetProperty().SetColor(color) else: ac.GetProperty().SetColor(color) + if draw_edges: ac.GetProperty().SetEdgeColor(edges_color) ac.GetProperty().SetLineWidth(edges_linewidth) @@ -551,11 +549,12 @@ def add_triangle( # If partial occupations are involved, the color of the specie with # the highest occupation is used max_occu = 0.0 - for specie, occu in center.species.items(): + max_species = next(iter(center.species), None) + for species, occu in center.species.items(): if occu > max_occu: - max_specie = specie + max_species = species max_occu = occu - color = [i / 255 for i in self.el_color_mapping[max_specie.symbol]] + color = [i / 255 for i in self.el_color_mapping[max_species.symbol]] ac.GetProperty().SetColor(color) else: ac.GetProperty().SetColor(color) @@ -602,11 +601,11 @@ def add_faces(self, faces, color, opacity=0.35): for site in face: center += site center /= np.float64(len(face)) - for ii, f in enumerate(face): + for ii, f in enumerate(face, start=1): points = vtk.vtkPoints() triangle = vtk.vtkTriangle() points.InsertNextPoint(f[0], f[1], f[2]) - ii2 = np.mod(ii + 1, len(face)) + ii2 = np.mod(ii, len(face)) points.InsertNextPoint(face[ii2][0], face[ii2][1], face[ii2][2]) points.InsertNextPoint(center[0], center[1], center[2]) for jj in range(3): @@ -868,7 +867,7 @@ def make_movie(structures, output_filename="movie.mp4", zoom=1.0, fps=20, bitrat vis.show_help = False vis.redraw() vis.zoom(zoom) - sig_fig = int(math.floor(math.log10(len(structures))) + 1) + sig_fig = math.floor(math.log10(len(structures))) + 1 filename = f"image{{0:0{sig_fig}d}}.png" for idx, site in enumerate(structures): vis.set_structure(site) @@ -957,17 +956,21 @@ def set_structures(self, structures: Sequence[Structure], tags=None): struct_vis_radii = [] for site in struct: radius = 0 - for specie, occu in site.species.items(): + vis_radius = 0.2 + for species, occu in site.species.items(): radius += occu * ( - specie.ionic_radius - if isinstance(specie, Species) and specie.ionic_radius - else specie.average_ionic_radius + species.ionic_radius + if isinstance(species, Species) and species.ionic_radius + else species.average_ionic_radius ) vis_radius = 0.2 + 0.002 * radius + struct_radii.append(radius) struct_vis_radii.append(vis_radius) + self.all_radii.append(struct_radii) self.all_vis_radii.append(struct_vis_radii) + self.set_structure(self.current_structure, reset_camera=True, to_unit_cell=False) def set_structure(self, structure: Structure, reset_camera=True, to_unit_cell=False): @@ -988,7 +991,7 @@ def apply_tags(self): tags = {} for tag in self.tags: istruct = tag.get("istruct", "all") - if istruct != "all" and istruct != self.istruct: + if istruct not in ("all", self.istruct): continue site_index = tag["site_index"] color = tag.get("color", [0.5, 0.5, 0.5]) diff --git a/pyproject.toml b/pyproject.toml index 64a2afe59e5..5bcb0f6abfc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,36 +23,38 @@ repair-wheel-command = "delocate-wheel --require-archs {delocate_archs} -w {dest target-version = "py39" line-length = 120 lint.select = [ - "B", # flake8-bugbear - "C4", # flake8-comprehensions - "D", # pydocstyle - "E", # pycodestyle error - "EXE", # flake8-executable - "F", # pyflakes - "FA", # flake8-future-annotations - "FBT003", # boolean-positional-value-in-call - "FLY", # flynt - "I", # isort - "ICN", # flake8-import-conventions - "ISC", # flake8-implicit-str-concat - "PD", # pandas-vet - "PERF", # perflint - "PIE", # flake8-pie - "PL", # pylint - "PT", # flake8-pytest-style - "PYI", # flakes8-pyi - "Q", # flake8-quotes - "RET", # flake8-return - "RSE", # flake8-raise - "RUF", # Ruff-specific rules - "SIM", # flake8-simplify - "SLOT", # flake8-slots - "TCH", # flake8-type-checking - "TID", # tidy imports - "TID", # flake8-tidy-imports - "UP", # pyupgrade - "W", # pycodestyle warning - "YTT", # flake8-2020 + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "D", # pydocstyle + "E", # pycodestyle error + "EXE", # flake8-executable + "F", # pyflakes + "FA", # flake8-future-annotations + "FBT003", # boolean-positional-value-in-call + "FLY", # flynt + "I", # isort + "ICN", # flake8-import-conventions + "ISC", # flake8-implicit-str-concat + "PD", # pandas-vet + "PERF", # perflint + "PIE", # flake8-pie + "PL", # pylint + "PLR0402", + "PLR1714", + "PLR5501", + "PT", # flake8-pytest-style + "PYI", # flakes8-pyi + "Q", # flake8-quotes + "RET", # flake8-return + "RSE", # flake8-raise + "RUF", # Ruff-specific rules + "SIM", # flake8-simplify + "SLOT", # flake8-slots + "TCH", # flake8-type-checking + "TID", # flake8-tidy-imports + "UP", # pyupgrade + "W", # pycodestyle warning + "YTT", # flake8-2020 ] lint.ignore = [ "B023", # Function definition does not bind loop variable @@ -70,7 +72,6 @@ lint.ignore = [ "PERF401", # manual-list-comprehension (TODO fix these or wait for autofix) "PLC1901", # can be simplified to ... as empty is falsey "PLR", # pylint refactor - "PLW1514", # open() without explicit encoding argument "PLW2901", # Outer for loop variable overwritten by inner assignment target "PT013", # pytest-incorrect-pytest-import "PTH", # prefer pathlib to os.path @@ -142,3 +143,12 @@ rute,reson,titels,ges,scalr,strat,struc,hda,nin,ons,pres,kno,loos,lamda,lew,atom """ skip = "pymatgen/analysis/aflow_prototypes.json" check-filenames = true + +[tool.pyright] +typeCheckingMode = "off" +reportPossiblyUnboundVariable = true +reportUnboundVariable = true +reportMissingImports = false +reportMissingModuleSource = false +reportInvalidTypeForm = false +exclude = ["**/tests"] diff --git a/requirements-optional.txt b/requirements-optional.txt index a7269ecd2c8..b5a29956285 100644 --- a/requirements-optional.txt +++ b/requirements-optional.txt @@ -1,14 +1,14 @@ ase>=3.22.1 +BoltzTraP2>=22.3.2 chemview>=0.6 -netCDF4>=1.5.8 +f90nml>=1.4.3 fdint>=2.0.2 -phonopy==2.20.0 +galore>=0.7.0 h5py==3.9.0 -BoltzTraP2>=22.3.2 -f90nml>=1.4.3 # hiphive>=0.6 -seekpath>=2.0.1 +icet>=2.2 jarvis-tools>=2022.9.16 -galore>=0.7.0 matgl==1.0.0 -icet>=2.2 +netCDF4>=1.5.8 +phonopy==2.20.0 +seekpath>=2.0.1 diff --git a/setup.py b/setup.py index 673db0f3193..47dd203f786 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ setup( name="pymatgen", packages=find_namespace_packages(include=["pymatgen.*", "pymatgen.**.*", "cmd_line"]), - version="2024.3.1", + version="2024.4.13", python_requires=">=3.9", install_requires=[ "matplotlib>=1.5", @@ -60,6 +60,7 @@ "pytest-split", "pytest", "ruff", + "typing-extensions", ], "docs": [ "sphinx", diff --git a/tasks.py b/tasks.py index 0f5d23b88ba..7cf78730856 100644 --- a/tasks.py +++ b/tasks.py @@ -14,6 +14,7 @@ import re import subprocess import webbrowser +from typing import TYPE_CHECKING import requests from invoke import task @@ -21,13 +22,17 @@ from pymatgen.core import __version__ +if TYPE_CHECKING: + from invoke import Context + @task -def make_doc(ctx): +def make_doc(ctx: Context) -> None: """ Generate API documentation + run Sphinx. - :param ctx: + Args: + ctx (Context): The context. """ with cd("docs"): ctx.run("touch apidoc/index.rst", warn=True) @@ -70,11 +75,12 @@ def make_doc(ctx): @task -def publish(ctx): +def publish(ctx: Context) -> None: """ Upload release to Pypi using twine. - :param ctx: Context + Args: + ctx (Context): The context. """ ctx.run("rm dist/*.*", warn=True) ctx.run("python setup.py sdist bdist_wheel") @@ -82,7 +88,7 @@ def publish(ctx): @task -def set_ver(ctx, version): +def set_ver(ctx: Context, version: str): with open("setup.py") as file: contents = file.read() contents = re.sub(r"version=([^,]+),", f"version={version!r},", contents) @@ -92,11 +98,12 @@ def set_ver(ctx, version): @task -def release_github(ctx, version): +def release_github(ctx: Context, version: str) -> None: """ Release to Github using Github API. - :param ctx: + Args: + version (str): The version. """ with open("docs/CHANGES.md") as file: contents = file.read() @@ -120,11 +127,12 @@ def release_github(ctx, version): print(response.text) -def post_discourse(version): +def post_discourse(version: str) -> None: """ Post release announcement to http://discuss.matsci.org/c/pymatgen. - :param ctx: + Args: + version (str): The version. """ with open("CHANGES.rst") as file: contents = file.read() @@ -149,11 +157,16 @@ def post_discourse(version): @task -def update_changelog(ctx, version=None, dry_run=False): +def update_changelog(ctx: Context, version: str | None = None, dry_run: bool = False) -> None: """ Create a preliminary change log using the git logs. - :param ctx: + Args: + ctx (invoke.Context): The context object. + version (str, optional): The version to use for the change log. If not provided, it will + use the current date in the format 'YYYY.M.D'. Defaults to None. + dry_run (bool, optional): If True, the function will only print the changes without + updating the actual change log file. Defaults to False. """ version = version or f"{datetime.datetime.now():%Y.%-m.%-d}" output = subprocess.check_output(["git", "log", "--pretty=format:%s", f"v{__version__}..HEAD"]) @@ -191,12 +204,14 @@ def update_changelog(ctx, version=None, dry_run=False): @task -def release(ctx, version=None, nodoc=False): +def release(ctx: Context, version: str | None = None, nodoc: bool = False) -> None: """ Run full sequence for releasing pymatgen. - :param ctx: - :param nodoc: Whether to skip doc generation. + Args: + ctx (invoke.Context): The context object. + version (str, optional): The version to release. + nodoc (bool, optional): Whether to skip documentation generation. """ version = version or f"{datetime.datetime.now():%Y.%-m.%-d}" ctx.run("rm -r dist build pymatgen.egg-info", warn=True) @@ -216,7 +231,7 @@ def release(ctx, version=None, nodoc=False): @task -def open_doc(ctx): +def open_doc(ctx: Context) -> None: """ Open local documentation in web browser. """ @@ -225,6 +240,12 @@ def open_doc(ctx): @task -def lint(ctx): +def lint(ctx: Context) -> None: + """ + Run linting tools. + + Args: + ctx (invoke.Context): The context object. + """ for cmd in ["ruff", "mypy", "ruff format"]: ctx.run(f"{cmd} pymatgen") diff --git a/tests/alchemy/test_filters.py b/tests/alchemy/test_filters.py index cc13c27828d..86ab3941258 100644 --- a/tests/alchemy/test_filters.py +++ b/tests/alchemy/test_filters.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -import unittest +from unittest import TestCase from monty.json import MontyDecoder @@ -70,11 +70,11 @@ def test_as_from_dict(self): assert isinstance(SpecieProximityFilter.from_dict(dct), SpecieProximityFilter) -class TestRemoveDuplicatesFilter(unittest.TestCase): +class TestRemoveDuplicatesFilter(TestCase): def setUp(self): with open(f"{TEST_FILES_DIR}/TiO2_entries.json") as file: entries = json.load(file, cls=MontyDecoder) - self._struct_list = [e.structure for e in entries] + self._struct_list = [entry.structure for entry in entries] self._sm = StructureMatcher() def test_filter(self): @@ -89,11 +89,11 @@ def test_as_from_dict(self): assert isinstance(RemoveDuplicatesFilter().from_dict(dct), RemoveDuplicatesFilter) -class TestRemoveExistingFilter(unittest.TestCase): +class TestRemoveExistingFilter(TestCase): def setUp(self): with open(f"{TEST_FILES_DIR}/TiO2_entries.json") as file: entries = json.load(file, cls=MontyDecoder) - self._struct_list = [e.structure for e in entries] + self._struct_list = [entry.structure for entry in entries] self._sm = StructureMatcher() self._existing_structures = self._struct_list[:-1] diff --git a/tests/analysis/chemenv/connectivity/test_connected_components.py b/tests/analysis/chemenv/connectivity/test_connected_components.py index 0d4f05297d2..6692d84879a 100644 --- a/tests/analysis/chemenv/connectivity/test_connected_components.py +++ b/tests/analysis/chemenv/connectivity/test_connected_components.py @@ -22,11 +22,6 @@ from pymatgen.core.structure import Structure from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -try: - import bson # type: ignore -except ModuleNotFoundError: - bson = None # type: ignore[assignment] - __author__ = "waroquiers" @@ -122,24 +117,23 @@ def test_serialization(self): cc = ConnectedComponent(graph=graph) ref_sorted_edges = [[en1, en2], [en1, en3]] - sorted_edges = sorted(sorted(e) for e in cc.graph.edges()) + sorted_edges = sorted(sorted(edge) for edge in cc.graph.edges()) assert sorted_edges == ref_sorted_edges cc_from_dict = ConnectedComponent.from_dict(cc.as_dict()) cc_from_json = ConnectedComponent.from_dict(json.loads(json.dumps(cc.as_dict()))) loaded_cc_list = [cc_from_dict, cc_from_json] - if bson is not None: - bson_data = bson.BSON.encode(cc.as_dict()) - cc_from_bson = ConnectedComponent.from_dict(bson_data.decode()) - loaded_cc_list.append(cc_from_bson) + json_str = self.assert_msonable(cc) + cc_from_json = ConnectedComponent.from_dict(json.loads(json_str)) + loaded_cc_list.append(cc_from_json) for loaded_cc in loaded_cc_list: assert loaded_cc.graph.number_of_nodes() == 3 assert loaded_cc.graph.number_of_edges() == 2 assert set(cc.graph.nodes()) == set(loaded_cc.graph.nodes()) - assert sorted_edges == sorted(sorted(e) for e in loaded_cc.graph.edges()) + assert sorted_edges == sorted(sorted(edge) for edge in loaded_cc.graph.edges()) - for e in sorted_edges: - assert cc.graph[e[0]][e[1]] == loaded_cc.graph[e[0]][e[1]] + for edge in sorted_edges: + assert cc.graph[edge[0]][edge[1]] == loaded_cc.graph[edge[0]][edge[1]] for node in loaded_cc.graph.nodes(): assert isinstance(node.central_site, PeriodicSite) diff --git a/tests/analysis/chemenv/connectivity/test_environment_nodes.py b/tests/analysis/chemenv/connectivity/test_environment_nodes.py index aefdb1c4760..60af01acff7 100644 --- a/tests/analysis/chemenv/connectivity/test_environment_nodes.py +++ b/tests/analysis/chemenv/connectivity/test_environment_nodes.py @@ -1,5 +1,7 @@ from __future__ import annotations +import json + from pymatgen.analysis.chemenv.connectivity.environment_nodes import EnvironmentNode from pymatgen.util.testing import PymatgenTest @@ -34,17 +36,16 @@ def test_equal(self): def test_as_dict(self): struct = self.get_structure("SiO2") - en = EnvironmentNode(central_site=struct[2], i_central_site=2, ce_symbol="T:4") + env_node = EnvironmentNode(central_site=struct[2], i_central_site=2, ce_symbol="T:4") - en_from_dict = EnvironmentNode.from_dict(en.as_dict()) - assert en.everything_equal(en_from_dict) + env_node_from_dict = EnvironmentNode.from_dict(env_node.as_dict()) + assert env_node.everything_equal(env_node_from_dict) - if bson is not None: - bson_data = bson.BSON.encode(en.as_dict()) - en_from_bson = EnvironmentNode.from_dict(bson_data.decode()) - assert en.everything_equal(en_from_bson) + json_str = self.assert_msonable(env_node) + env_node_from_json = EnvironmentNode.from_dict(json.loads(json_str)) + assert env_node.everything_equal(env_node_from_json) def test_str(self): struct = self.get_structure("SiO2") - en = EnvironmentNode(central_site=struct[2], i_central_site=2, ce_symbol="T:4") - assert str(en) == "Node #2 Si (T:4)" + env_node = EnvironmentNode(central_site=struct[2], i_central_site=2, ce_symbol="T:4") + assert str(env_node) == "Node #2 Si (T:4)" diff --git a/tests/analysis/chemenv/connectivity/test_structure_connectivity.py b/tests/analysis/chemenv/connectivity/test_structure_connectivity.py index 66baf66c13d..2160950d087 100644 --- a/tests/analysis/chemenv/connectivity/test_structure_connectivity.py +++ b/tests/analysis/chemenv/connectivity/test_structure_connectivity.py @@ -11,11 +11,6 @@ ) from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -try: - import bson -except ModuleNotFoundError: - bson = None # type: ignore[assignment] - __author__ = "waroquiers" @@ -40,9 +35,8 @@ def test_serialization(self): assert set(sc._graph.nodes()) == set(sc_from_json._graph.nodes()) assert set(sc._graph.edges()) == set(sc_from_json._graph.edges()) - if bson is not None: - bson_data = bson.BSON.encode(sc.as_dict()) - sc_from_bson = StructureConnectivity.from_dict(bson_data.decode()) - assert sc.light_structure_environments == sc_from_bson.light_structure_environments - assert set(sc._graph.nodes()) == set(sc_from_bson._graph.nodes()) - assert set(sc._graph.edges()) == set(sc_from_bson._graph.edges()) + json_str = self.assert_msonable(sc) + sc_from_json = StructureConnectivity.from_dict(json.loads(json_str)) + assert sc.light_structure_environments == sc_from_json.light_structure_environments + assert set(sc._graph.nodes()) == set(sc_from_json._graph.nodes()) + assert set(sc._graph.edges()) == set(sc_from_json._graph.edges()) diff --git a/tests/analysis/chemenv/utils/test_chemenv_config.py b/tests/analysis/chemenv/utils/test_chemenv_config.py index d913c17e73d..d1c03bca235 100644 --- a/tests/analysis/chemenv/utils/test_chemenv_config.py +++ b/tests/analysis/chemenv/utils/test_chemenv_config.py @@ -23,10 +23,10 @@ def test_chemenv_config(self): config.package_options_description() == "Package options :\n" " - Maximum distance factor : 1.8000\n" ' - Default strategy is "SimplestChemenvStrategy" :\n' - " Simplest ChemenvStrategy using fixed angle and distance parameters \n" - " for the definition of neighbors in the Voronoi approach. \n" - " The coordination environment is then given as the one with the \n" - " lowest continuous symmetry measure.\n" + "Simplest ChemenvStrategy using fixed angle and distance parameters \n" + "for the definition of neighbors in the Voronoi approach. \n" + "The coordination environment is then given as the one with the \n" + "lowest continuous symmetry measure.\n" " with options :\n" " - distance_cutoff : 1.4\n" " - angle_cutoff : 0.3\n" diff --git a/tests/analysis/chemenv/utils/test_func_utils.py b/tests/analysis/chemenv/utils/test_func_utils.py index f7832f36d15..265c26d39f4 100644 --- a/tests/analysis/chemenv/utils/test_func_utils.py +++ b/tests/analysis/chemenv/utils/test_func_utils.py @@ -1,7 +1,5 @@ from __future__ import annotations -import unittest - import numpy as np import pytest from pytest import approx @@ -15,7 +13,7 @@ __author__ = "waroquiers" -class TestFuncUtils(unittest.TestCase): +class TestFuncUtils: def test_csm_finite_ratio_function(self): max_csm = 8 alpha = 1 diff --git a/tests/analysis/diffraction/test_tem.py b/tests/analysis/diffraction/test_tem.py index 836dc0cf7bc..7b5a79cea67 100644 --- a/tests/analysis/diffraction/test_tem.py +++ b/tests/analysis/diffraction/test_tem.py @@ -78,8 +78,8 @@ def test_get_interplanar_spacings(self): # Test that the appropriate interplanar spacing is returned tem_calc = TEMCalculator() point = [(3, 9, 0)] - latt = Lattice.cubic(4.209) - cubic = Structure(latt, ["Cs", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]]) + lattice = Lattice.cubic(4.209) + cubic = Structure(lattice, ["Cs", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]]) tet = self.get_structure("Li10GeP2S12") hexa = self.get_structure("Graphite") ortho = self.get_structure("K2O2") @@ -100,8 +100,8 @@ def test_bragg_angles(self): # Test that the appropriate bragg angle is returned. Testing formula with values of x-ray diffraction in # materials project. tem_calc = TEMCalculator() - latt = Lattice.cubic(4.209) - cubic = Structure(latt, ["Cs", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]]) + lattice = Lattice.cubic(4.209) + cubic = Structure(lattice, ["Cs", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]]) point = [(1, 1, 0)] spacings = tem_calc.get_interplanar_spacings(cubic, point) bragg_angles_val = np.arcsin(1.5406 / (2 * spacings[point[0]])) @@ -110,8 +110,8 @@ def test_bragg_angles(self): def test_get_s2(self): # Test that the appropriate s2 factor is returned. tem_calc = TEMCalculator() - latt = Lattice.cubic(4.209) - cubic = Structure(latt, ["Cs", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]]) + lattice = Lattice.cubic(4.209) + cubic = Structure(lattice, ["Cs", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]]) point = [(-10, 3, 0)] spacings = tem_calc.get_interplanar_spacings(cubic, point) angles = tem_calc.bragg_angles(spacings) @@ -121,8 +121,8 @@ def test_get_s2(self): def test_x_ray_factors(self): tem_calc = TEMCalculator() - latt = Lattice.cubic(4.209) - cubic = Structure(latt, ["Cs", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]]) + lattice = Lattice.cubic(4.209) + cubic = Structure(lattice, ["Cs", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]]) point = [(-10, 3, 0)] spacings = tem_calc.get_interplanar_spacings(cubic, point) angles = tem_calc.bragg_angles(spacings) @@ -135,8 +135,8 @@ def test_electron_scattering_factors(self): # international table of crystallography volume C. Rounding error when converting hkl to sin(theta)/lambda. # Error increases as sin(theta)/lambda is smaller. tem_calc = TEMCalculator() - latt = Lattice.cubic(4.209) - cubic = Structure(latt, ["Cs", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]]) + lattice = Lattice.cubic(4.209) + cubic = Structure(lattice, ["Cs", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]]) nacl = Structure.from_spacegroup("Fm-3m", Lattice.cubic(5.692), ["Na", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]]) point = [(2, 1, 3)] point_nacl = [(4, 2, 0)] @@ -164,8 +164,8 @@ def test_cell_scattering_factors(self): def test_cell_intensity(self): # Test that bcc structure gives lower intensity for h + k + l != even. tem_calc = TEMCalculator() - latt = Lattice.cubic(4.209) - cubic = Structure(latt, ["Cs", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]]) + lattice = Lattice.cubic(4.209) + cubic = Structure(lattice, ["Cs", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]]) point = [(2, 1, 0)] point2 = [(2, 2, 0)] spacings = tem_calc.get_interplanar_spacings(cubic, point) @@ -179,8 +179,8 @@ def test_cell_intensity(self): def test_normalized_cell_intensity(self): # Test that the method correctly normalizes a value. tem_calc = TEMCalculator() - latt = Lattice.cubic(4.209) - cubic = Structure(latt, ["Cs", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]]) + lattice = Lattice.cubic(4.209) + cubic = Structure(lattice, ["Cs", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]]) point = [(2, 0, 0)] spacings = tem_calc.get_interplanar_spacings(cubic, point) angles = tem_calc.bragg_angles(spacings) @@ -195,9 +195,9 @@ def test_is_parallel(self): def test_get_first_point(self): tem_calc = TEMCalculator() - latt = Lattice.cubic(4.209) + lattice = Lattice.cubic(4.209) points = tem_calc.generate_points(-2, 2) - cubic = Structure(latt, ["Cs", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]]) + cubic = Structure(lattice, ["Cs", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]]) first_pt = tem_calc.get_first_point(cubic, points) assert 4.209 in first_pt.values() @@ -205,15 +205,15 @@ def test_interplanar_angle(self): # test interplanar angles. Reference values from KW Andrews, # Interpretation of Electron Diffraction pp70-90. tem_calc = TEMCalculator() - latt = Lattice.cubic(4.209) - cubic = Structure(latt, ["Cs", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]]) + lattice = Lattice.cubic(4.209) + cubic = Structure(lattice, ["Cs", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]]) phi = tem_calc.get_interplanar_angle(cubic, (0, 0, -1), (0, -1, 0)) assert phi == approx(90) tet = self.get_structure("Li10GeP2S12") phi = tem_calc.get_interplanar_angle(tet, (0, 0, 1), (1, 0, 3)) assert phi == approx(25.7835, rel=1e-4) - latt = Lattice.hexagonal(2, 4) - hexagonal = Structure(latt, ["Ab"], [[0, 0, 0]]) + lattice = Lattice.hexagonal(2, 4) + hexagonal = Structure(lattice, ["Ab"], [[0, 0, 0]]) phi = tem_calc.get_interplanar_angle(hexagonal, (0, 0, 1), (1, 0, 6)) assert phi == approx(21.0517, rel=1e-4) diff --git a/tests/analysis/elasticity/test_elastic.py b/tests/analysis/elasticity/test_elastic.py index 0a290726523..8f415b7a3b4 100644 --- a/tests/analysis/elasticity/test_elastic.py +++ b/tests/analysis/elasticity/test_elastic.py @@ -335,11 +335,11 @@ def test_get_effective_ecs(self): # Ensure zero strain is same as SOEC test_zero = self.exp_cu.get_effective_ecs(np.zeros((3, 3))) assert_allclose(test_zero, self.exp_cu[0]) - s = np.zeros((3, 3)) - s[0, 0] = 0.02 - test_2percent = self.exp_cu.get_effective_ecs(s) + strain = np.zeros((3, 3)) + strain[0, 0] = 0.02 + test_2percent = self.exp_cu.get_effective_ecs(strain) diff = test_2percent - test_zero - assert_allclose(self.exp_cu[1].einsum_sequence([s]), diff) + assert_allclose(self.exp_cu[1].einsum_sequence([strain]), diff) def test_get_strain_from_stress(self): strain = Strain.from_voigt([0.05, 0, 0, 0, 0, 0]) @@ -411,8 +411,8 @@ def test_get_strain_state_dict(self): strain_states.append(tuple(ss)) vec = np.zeros((4, 6)) rand_values = np.random.uniform(0.1, 1, 4) - for i in strain_ind: - vec[:, i] = rand_values + for idx in strain_ind: + vec[:, idx] = rand_values vecs[strain_ind] = vec all_strains = [Strain.from_voigt(v).zeroed() for vec in vecs.values() for v in vec] random.shuffle(all_strains) diff --git a/tests/analysis/elasticity/test_stress.py b/tests/analysis/elasticity/test_stress.py index fe97ecf95e0..cc79322de83 100644 --- a/tests/analysis/elasticity/test_stress.py +++ b/tests/analysis/elasticity/test_stress.py @@ -36,13 +36,13 @@ def test_properties(self): # von_mises assert self.symm_stress.von_mises == approx(11.52253878275) # piola_kirchoff 1, 2 - f = Deformation.from_index_amount((0, 1), 0.03) + deform = Deformation.from_index_amount((0, 1), 0.03) assert_allclose( - self.symm_stress.piola_kirchoff_1(f), + self.symm_stress.piola_kirchoff_1(deform), [[0.4413, 2.29, 2.42], [2.1358, 5.14, 5.07], [2.2679, 5.07, 5.33]], ) assert_allclose( - self.symm_stress.piola_kirchoff_2(f), + self.symm_stress.piola_kirchoff_2(deform), [[0.377226, 2.1358, 2.2679], [2.1358, 5.14, 5.07], [2.2679, 5.07, 5.33]], ) # voigt diff --git a/tests/analysis/ferroelectricity/test_polarization.py b/tests/analysis/ferroelectricity/test_polarization.py index 6d53def1b66..b28a2430626 100644 --- a/tests/analysis/ferroelectricity/test_polarization.py +++ b/tests/analysis/ferroelectricity/test_polarization.py @@ -17,7 +17,7 @@ TEST_DIR = f"{TEST_FILES_DIR}/vasp/fixtures/BTO_221_99_polarization" bto_folders = ["nonpolar_polarization"] -bto_folders += [f"interpolation_{i}_polarization" for i in range(1, 9)][::-1] +bto_folders += [f"interpolation_{idx}_polarization" for idx in range(8, 0, -1)] bto_folders += ["polar_polarization"] structures = [Structure.from_file(f"{TEST_DIR}/{folder}/POSCAR") for folder in bto_folders] diff --git a/tests/analysis/gb/test_grain.py b/tests/analysis/gb/test_grain.py deleted file mode 100644 index 4fe8dcab866..00000000000 --- a/tests/analysis/gb/test_grain.py +++ /dev/null @@ -1,331 +0,0 @@ -from __future__ import annotations - -import numpy as np -from numpy.testing import assert_allclose -from pytest import approx - -from pymatgen.analysis.gb.grain import GrainBoundary, GrainBoundaryGenerator -from pymatgen.core.structure import Structure -from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest - -__author__ = "Xiang-Guo Li" -__copyright__ = "Copyright 2018, The Materials Virtual Lab" -__email__ = "xil110@eng.ucsd.edu" -__date__ = "07/30/18" - -TEST_DIR = f"{TEST_FILES_DIR}/grain_boundary" - - -class TestGrainBoundary(PymatgenTest): - @classmethod - def setUpClass(cls): - cls.Cu_conv = Structure.from_file(f"{TEST_DIR}/Cu_mp-30_conventional_standard.cif") - GB_Cu_conv = GrainBoundaryGenerator(cls.Cu_conv) - cls.Cu_GB1 = GB_Cu_conv.gb_from_parameters( - [1, 2, 3], - 123.74898859588858, - expand_times=4, - vacuum_thickness=1.5, - ab_shift=[0.0, 0.0], - plane=[1, 3, 1], - rm_ratio=0.0, - ) - cls.Cu_GB2 = GB_Cu_conv.gb_from_parameters( - [1, 2, 3], - 123.74898859588858, - expand_times=4, - vacuum_thickness=1.5, - ab_shift=[0.2, 0.2], - rm_ratio=0.0, - ) - - def test_init(self): - assert self.Cu_GB1.rotation_angle == approx(123.74898859588858) - assert self.Cu_GB1.vacuum_thickness == approx(1.5) - assert self.Cu_GB2.rotation_axis == [1, 2, 3] - assert_allclose(self.Cu_GB1.ab_shift, [0.0, 0.0]) - assert_allclose(self.Cu_GB2.ab_shift, [0.2, 0.2]) - assert self.Cu_GB1.gb_plane == [1, 3, 1] - assert self.Cu_GB2.gb_plane == [1, 2, 3] - assert_allclose(self.Cu_GB1.init_cell.lattice.matrix, self.Cu_conv.lattice.matrix) - - def test_copy(self): - Cu_GB1_copy = self.Cu_GB1.copy() - assert Cu_GB1_copy.sigma == approx(self.Cu_GB1.sigma) - assert Cu_GB1_copy.rotation_angle == approx(self.Cu_GB1.rotation_angle) - assert Cu_GB1_copy.rotation_axis == self.Cu_GB1.rotation_axis - assert Cu_GB1_copy.gb_plane == self.Cu_GB1.gb_plane - assert_allclose(Cu_GB1_copy.init_cell.lattice.matrix, self.Cu_GB1.init_cell.lattice.matrix) - assert_allclose( - Cu_GB1_copy.oriented_unit_cell.lattice.matrix, - self.Cu_GB1.oriented_unit_cell.lattice.matrix, - ) - assert_allclose(Cu_GB1_copy.lattice.matrix, self.Cu_GB1.lattice.matrix) - - def test_sigma(self): - assert self.Cu_GB1.sigma == approx(9) - assert self.Cu_GB2.sigma == approx(9) - - def test_top_grain(self): - assert len(self.Cu_GB1) == len(self.Cu_GB1.top_grain) * 2 - assert_allclose(self.Cu_GB1.lattice.matrix, self.Cu_GB1.top_grain.lattice.matrix) - - def test_bottom_grain(self): - assert len(self.Cu_GB1) == len(self.Cu_GB1.bottom_grain) * 2 - assert_allclose(self.Cu_GB1.lattice.matrix, self.Cu_GB1.bottom_grain.lattice.matrix) - - def test_coincidents(self): - assert len(self.Cu_GB1) / self.Cu_GB1.sigma == len(self.Cu_GB1.coincidents) - assert len(self.Cu_GB2) / self.Cu_GB2.sigma == len(self.Cu_GB2.coincidents) - - def test_as_dict_and_from_dict(self): - d1 = self.Cu_GB1.as_dict() - d2 = self.Cu_GB2.as_dict() - Cu_GB1_new = GrainBoundary.from_dict(d1) - Cu_GB2_new = GrainBoundary.from_dict(d2) - assert Cu_GB1_new.sigma == approx(self.Cu_GB1.sigma) - assert Cu_GB1_new.rotation_angle == approx(self.Cu_GB1.rotation_angle) - assert Cu_GB1_new.rotation_axis == self.Cu_GB1.rotation_axis - assert Cu_GB1_new.gb_plane == self.Cu_GB1.gb_plane - assert_allclose(Cu_GB1_new.init_cell.lattice.matrix, self.Cu_GB1.init_cell.lattice.matrix) - assert_allclose( - Cu_GB1_new.oriented_unit_cell.lattice.matrix, self.Cu_GB1.oriented_unit_cell.lattice.matrix, atol=1e-9 - ) - assert_allclose(Cu_GB1_new.lattice.matrix, self.Cu_GB1.lattice.matrix) - assert Cu_GB2_new.sigma == approx(self.Cu_GB2.sigma) - assert Cu_GB2_new.rotation_angle == approx(self.Cu_GB2.rotation_angle) - assert Cu_GB2_new.rotation_axis == self.Cu_GB2.rotation_axis - assert Cu_GB2_new.gb_plane == self.Cu_GB2.gb_plane - assert_allclose(Cu_GB2_new.init_cell.lattice.matrix, self.Cu_GB2.init_cell.lattice.matrix) - assert_allclose( - Cu_GB2_new.oriented_unit_cell.lattice.matrix, - self.Cu_GB2.oriented_unit_cell.lattice.matrix, - ) - assert_allclose(Cu_GB2_new.lattice.matrix, self.Cu_GB2.lattice.matrix) - - -class TestGrainBoundaryGenerator(PymatgenTest): - @classmethod - def setUpClass(cls): - cls.Cu_prim = Structure.from_file(f"{TEST_DIR}/Cu_mp-30_primitive.cif") - cls.GB_Cu_prim = GrainBoundaryGenerator(cls.Cu_prim) - cls.Cu_conv = Structure.from_file(f"{TEST_DIR}/Cu_mp-30_conventional_standard.cif") - cls.GB_Cu_conv = GrainBoundaryGenerator(cls.Cu_conv) - cls.Be = Structure.from_file(f"{TEST_DIR}/Be_mp-87_conventional_standard.cif") - cls.GB_Be = GrainBoundaryGenerator(cls.Be) - cls.Pa = Structure.from_file(f"{TEST_DIR}/Pa_mp-62_conventional_standard.cif") - cls.GB_Pa = GrainBoundaryGenerator(cls.Pa) - cls.Br = Structure.from_file(f"{TEST_DIR}/Br_mp-23154_conventional_standard.cif") - cls.GB_Br = GrainBoundaryGenerator(cls.Br) - cls.Bi = Structure.from_file(f"{TEST_DIR}/Bi_mp-23152_primitive.cif") - cls.GB_Bi = GrainBoundaryGenerator(cls.Bi) - - def test_gb_from_parameters(self): - # from fcc primitive cell,axis[1,2,3],sigma 9. - gb_cu_123_prim1 = self.GB_Cu_prim.gb_from_parameters([1, 2, 3], 123.74898859588858, expand_times=2) - lat_mat1 = gb_cu_123_prim1.lattice.matrix - c_vec1 = np.cross(lat_mat1[0], lat_mat1[1]) / np.linalg.norm(np.cross(lat_mat1[0], lat_mat1[1])) - c_len1 = np.dot(lat_mat1[2], c_vec1) - vol_ratio = gb_cu_123_prim1.volume / self.Cu_prim.volume - assert vol_ratio == approx(9 * 2 * 2, abs=1e-8) - # test expand_times and vacuum layer - gb_cu_123_prim2 = self.GB_Cu_prim.gb_from_parameters( - [1, 2, 3], 123.74898859588858, expand_times=4, vacuum_thickness=1.5 - ) - lat_mat2 = gb_cu_123_prim2.lattice.matrix - c_vec2 = np.cross(lat_mat2[0], lat_mat2[1]) / np.linalg.norm(np.cross(lat_mat2[0], lat_mat2[1])) - c_len2 = np.dot(lat_mat2[2], c_vec2) - assert (c_len2 - 1.5 * 2) / c_len1 == approx(2) - - # test normal - gb_cu_123_prim3 = self.GB_Cu_prim.gb_from_parameters([1, 2, 3], 123.74898859588858, expand_times=2, normal=True) - lat_mat3 = gb_cu_123_prim3.lattice.matrix - c_vec3 = np.cross(lat_mat3[0], lat_mat3[1]) / np.linalg.norm(np.cross(lat_mat3[0], lat_mat3[1])) - ab_len3 = np.linalg.norm(np.cross(lat_mat3[2], c_vec3)) - assert ab_len3 == approx(0) - - # test normal in tilt boundary - # The 'finfo(np.float32).eps' is the smallest representable positive number in float32, - # which has been introduced because comparing to just zero or one failed the test by rounding errors. - gb_cu_010_conv1 = self.GB_Cu_conv.gb_from_parameters( - rotation_axis=[0, 1, 0], - rotation_angle=36.8698976458, - expand_times=1, - vacuum_thickness=1.0, - ab_shift=[0.0, 0.0], - rm_ratio=0.0, - plane=[0, 0, 1], - normal=True, - ) - assert np.all(-np.finfo(np.float32).eps <= gb_cu_010_conv1.frac_coords) - assert np.all(1 + np.finfo(np.float32).eps >= gb_cu_010_conv1.frac_coords) - - # from fcc conventional cell,axis [1,2,3], siamg 9 - gb_cu_123_conv1 = self.GB_Cu_conv.gb_from_parameters( - [1, 2, 3], 123.74898859588858, expand_times=4, vacuum_thickness=1.5 - ) - lat_mat1 = gb_cu_123_conv1.lattice.matrix - assert np.dot(lat_mat1[0], [1, 2, 3]) == approx(0) - assert np.dot(lat_mat1[1], [1, 2, 3]) == approx(0) - # test plane - gb_cu_123_conv2 = self.GB_Cu_conv.gb_from_parameters( - [1, 2, 3], - 123.74898859588858, - expand_times=2, - vacuum_thickness=1.5, - normal=False, - plane=[1, 3, 1], - ) - lat_mat2 = gb_cu_123_conv2.lattice.matrix - assert np.dot(lat_mat2[0], [1, 3, 1]) == approx(0) - assert np.dot(lat_mat2[1], [1, 3, 1]) == approx(0) - - # from hex cell,axis [1,1,1], sigma 21 - gb_Be_111_1 = self.GB_Be.gb_from_parameters( - [1, 1, 1], - 147.36310249644626, - ratio=[5, 2], - expand_times=4, - vacuum_thickness=1.5, - plane=[1, 2, 1], - ) - lat_priv = self.Be.lattice.matrix - lat_mat1 = np.matmul(gb_Be_111_1.lattice.matrix, np.linalg.inv(lat_priv)) - assert np.dot(lat_mat1[0], [1, 2, 1]) == approx(0) - assert np.dot(lat_mat1[1], [1, 2, 1]) == approx(0) - # test volume associated with sigma value - gb_Be_111_2 = self.GB_Be.gb_from_parameters([1, 1, 1], 147.36310249644626, ratio=[5, 2], expand_times=4) - vol_ratio = gb_Be_111_2.volume / self.Be.volume - assert vol_ratio == approx(19 * 2 * 4) - # test ratio = None, axis [0,0,1], sigma 7 - gb_Be_111_3 = self.GB_Be.gb_from_parameters([0, 0, 1], 21.786789298261812, ratio=[5, 2], expand_times=4) - gb_Be_111_4 = self.GB_Be.gb_from_parameters([0, 0, 1], 21.786789298261812, ratio=None, expand_times=4) - assert gb_Be_111_3.lattice.abc == gb_Be_111_4.lattice.abc - assert gb_Be_111_3.lattice.angles == gb_Be_111_4.lattice.angles - gb_Be_111_5 = self.GB_Be.gb_from_parameters([3, 1, 0], 180.0, ratio=[5, 2], expand_times=4) - gb_Be_111_6 = self.GB_Be.gb_from_parameters([3, 1, 0], 180.0, ratio=None, expand_times=4) - assert gb_Be_111_5.lattice.abc == gb_Be_111_6.lattice.abc - assert gb_Be_111_5.lattice.angles == gb_Be_111_6.lattice.angles - - # gb from tetragonal cell, axis[1,1,1], sigma 15 - gb_Pa_111_1 = self.GB_Pa.gb_from_parameters( - [1, 1, 1], 151.92751306414706, ratio=[2, 3], expand_times=4, max_search=10 - ) - vol_ratio = gb_Pa_111_1.volume / self.Pa.volume - assert vol_ratio == approx(17 * 2 * 4) - - # gb from orthorhombic cell, axis[1,1,1], sigma 83 - gb_Br_111_1 = self.GB_Br.gb_from_parameters( - [1, 1, 1], - 131.5023374652235, - ratio=[21, 20, 5], - expand_times=4, - max_search=10, - ) - vol_ratio = gb_Br_111_1.volume / self.Br.volume - assert vol_ratio == approx(83 * 2 * 4) - - # gb from rhombohedra cell, axis[1,2,0], sigma 63 - gb_Bi_120_1 = self.GB_Bi.gb_from_parameters( - [1, 2, 0], 63.310675060280246, ratio=[19, 5], expand_times=4, max_search=5 - ) - vol_ratio = gb_Bi_120_1.volume / self.Bi.volume - assert vol_ratio == approx(59 * 2 * 4) - - def test_get_ratio(self): - # hexagnal - Be_ratio = self.GB_Be.get_ratio(max_denominator=2) - assert Be_ratio == [5, 2] - Be_ratio = self.GB_Be.get_ratio(max_denominator=5) - assert Be_ratio == [12, 5] - # tetragonal - Pa_ratio = self.GB_Pa.get_ratio(max_denominator=5) - assert Pa_ratio == [2, 3] - # orthorhombic - Br_ratio = self.GB_Br.get_ratio(max_denominator=5) - assert Br_ratio == [21, 20, 5] - # orthorhombic - Bi_ratio = self.GB_Bi.get_ratio(max_denominator=5) - assert Bi_ratio == [19, 5] - - def test_enum_sigma_cubic(self): - true_100 = [5, 13, 17, 25, 29, 37, 41] - true_110 = [3, 9, 11, 17, 19, 27, 33, 41, 43] - true_111 = [3, 7, 13, 19, 21, 31, 37, 39, 43, 49] - sigma_100 = list(GrainBoundaryGenerator.enum_sigma_cubic(50, [1, 0, 0])) - sigma_110 = list(GrainBoundaryGenerator.enum_sigma_cubic(50, [1, 1, 0])) - sigma_111 = list(GrainBoundaryGenerator.enum_sigma_cubic(50, [1, 1, 1])) - sigma_222 = list(GrainBoundaryGenerator.enum_sigma_cubic(50, [2, 2, 2])) - sigma_888 = list(GrainBoundaryGenerator.enum_sigma_cubic(50, [8, 8, 8])) - - assert sorted(true_100) == sorted(sigma_100) - assert sorted(true_110) == sorted(sigma_110) - assert sorted(true_111) == sorted(sigma_111) - assert sorted(true_111) == sorted(sigma_222) - assert sorted(true_111) == sorted(sigma_888) - - def test_enum_sigma_hex(self): - true_100 = [17, 18, 22, 27, 38, 41] - true_001 = [7, 13, 19, 31, 37, 43, 49] - true_210 = [10, 11, 14, 25, 35, 49] - sigma_100 = list(GrainBoundaryGenerator.enum_sigma_hex(50, [1, 0, 0], [8, 3])) - sigma_001 = list(GrainBoundaryGenerator.enum_sigma_hex(50, [0, 0, 1], [8, 3])) - sigma_210 = list(GrainBoundaryGenerator.enum_sigma_hex(50, [2, 1, 0], [8, 3])) - sigma_420 = list(GrainBoundaryGenerator.enum_sigma_hex(50, [4, 2, 0], [8, 3])) - sigma_840 = list(GrainBoundaryGenerator.enum_sigma_hex(50, [8, 4, 0], [8, 3])) - - assert sorted(true_100) == sorted(sigma_100) - assert sorted(true_001) == sorted(sigma_001) - assert sorted(true_210) == sorted(sigma_210) - assert sorted(true_210) == sorted(sigma_420) - assert sorted(true_210) == sorted(sigma_840) - - def test_enum_sigma_tet(self): - true_100 = [5, 37, 41, 13, 3, 15, 39, 25, 17, 29] - true_331 = [9, 3, 21, 39, 7, 31, 43, 13, 19, 37, 49] - sigma_100 = list(GrainBoundaryGenerator.enum_sigma_tet(50, [1, 0, 0], [9, 1])) - sigma_331 = list(GrainBoundaryGenerator.enum_sigma_tet(50, [3, 3, 1], [9, 1])) - - assert sorted(true_100) == sorted(sigma_100) - assert sorted(true_331) == sorted(sigma_331) - - def test_enum_sigma_ort(self): - true_100 = [41, 37, 39, 5, 15, 17, 13, 3, 25, 29] - sigma_100 = list(GrainBoundaryGenerator.enum_sigma_ort(50, [1, 0, 0], [270, 30, 29])) - - assert sorted(true_100) == sorted(sigma_100) - - def test_enum_sigma_rho(self): - true_100 = [7, 11, 43, 13, 41, 19, 47, 31] - sigma_100 = list(GrainBoundaryGenerator.enum_sigma_rho(50, [1, 0, 0], [15, 4])) - - assert sorted(true_100) == sorted(sigma_100) - - def test_enum_possible_plane_cubic(self): - all_plane = GrainBoundaryGenerator.enum_possible_plane_cubic(4, [1, 1, 1], 60) - assert len(all_plane["Twist"]) == 1 - assert len(all_plane["Symmetric tilt"]) == 6 - assert len(all_plane["Normal tilt"]) == 12 - - def test_get_trans_mat(self): - mat1, mat2 = GrainBoundaryGenerator.get_trans_mat( - [1, 1, 1], - 95.55344419565849, - lat_type="o", - ratio=[10, 20, 21], - surface=[21, 20, 10], - normal=True, - ) - assert np.dot(mat1[0], [21, 20, 10]) == approx(0) - assert np.dot(mat1[1], [21, 20, 10]) == approx(0) - assert np.linalg.det(mat1) == approx(np.linalg.det(mat2)) - ab_len1 = np.linalg.norm(np.cross(mat1[2], [1, 1, 1])) - assert ab_len1 == approx(0) - - def test_get_rotation_angle_from_sigma(self): - true_angle = [12.680383491819821, 167.3196165081802] - angle = GrainBoundaryGenerator.get_rotation_angle_from_sigma(41, [1, 0, 0], lat_type="o", ratio=[270, 30, 29]) - assert_allclose(true_angle, angle) - close_angle = [36.86989764584403, 143.13010235415598] - angle = GrainBoundaryGenerator.get_rotation_angle_from_sigma(6, [1, 0, 0], lat_type="o", ratio=[270, 30, 29]) - assert_allclose(close_angle, angle) diff --git a/tests/analysis/magnetism/test_analyzer.py b/tests/analysis/magnetism/test_analyzer.py index d289c6ffd0e..947c68bba88 100644 --- a/tests/analysis/magnetism/test_analyzer.py +++ b/tests/analysis/magnetism/test_analyzer.py @@ -1,7 +1,7 @@ from __future__ import annotations -import unittest from shutil import which +from unittest import TestCase import pytest from monty.serialization import loadfn @@ -22,7 +22,7 @@ enumlib_present = enum_cmd and makestr_cmd -class TestCollinearMagneticStructureAnalyzer(unittest.TestCase): +class TestCollinearMagneticStructureAnalyzer(TestCase): def setUp(self): self.Fe = Structure.from_file(f"{TEST_FILES_DIR}/Fe.cif", primitive=True) @@ -143,6 +143,10 @@ def test_modes(self): magmoms = msa.structure.site_properties["magmom"] assert magmoms == [1, 0] + # test invalid overwrite_magmom_mode + with pytest.raises(ValueError, match="'invalid_mode' is not a valid OverwriteMagmomMode"): + CollinearMagneticStructureAnalyzer(self.NiO, overwrite_magmom_mode="invalid_mode") + def test_net_positive(self): msa = CollinearMagneticStructureAnalyzer(self.NiO_unphysical) magmoms = msa.structure.site_properties["magmom"] @@ -234,16 +238,16 @@ def test_missing_spin(self): # This test catches the case where a structure has some species with # Species.spin=None. This previously raised an error upon construction # of the analyzer). - latt = Lattice([[2.085, 2.085, 0.0], [0.0, -2.085, -2.085], [-2.085, 2.085, -4.17]]) + lattice = Lattice([[2.085, 2.085, 0.0], [0.0, -2.085, -2.085], [-2.085, 2.085, -4.17]]) species = [Species("Ni", spin=-5), Species("Ni", spin=5), Species("O", spin=None), Species("O", spin=None)] coords = [[0.5, 0, 0.5], [0, 0, 0], [0.25, 0.5, 0.25], [0.75, 0.5, 0.75]] - struct = Structure(latt, species, coords) + struct = Structure(lattice, species, coords) msa = CollinearMagneticStructureAnalyzer(struct, round_magmoms=0.001, make_primitive=False) assert msa.structure.site_properties["magmom"] == [-5, 5, 0, 0] -class TestMagneticStructureEnumerator(unittest.TestCase): +class TestMagneticStructureEnumerator: @pytest.mark.skipif(not enumlib_present, reason="enumlib not present") def test_ordering_enumeration(self): # simple afm @@ -282,7 +286,7 @@ def test_ordering_enumeration(self): assert enumerator.input_origin == "afm_by_motif_2a" -class TestMagneticDeformation(unittest.TestCase): +class TestMagneticDeformation: def test_magnetic_deformation(self): test_structs = loadfn(f"{TEST_FILES_DIR}/magnetic_deformation.json") mag_def = magnetic_deformation(test_structs[0], test_structs[1]) diff --git a/tests/analysis/magnetism/test_heisenberg.py b/tests/analysis/magnetism/test_heisenberg.py index 32d17287153..2007e5fa8f6 100644 --- a/tests/analysis/magnetism/test_heisenberg.py +++ b/tests/analysis/magnetism/test_heisenberg.py @@ -1,6 +1,6 @@ from __future__ import annotations -import unittest +from unittest import TestCase import pandas as pd @@ -11,7 +11,7 @@ TEST_DIR = f"{TEST_FILES_DIR}/magnetic_orderings" -class TestHeisenbergMapper(unittest.TestCase): +class TestHeisenbergMapper(TestCase): @classmethod def setUpClass(cls): cls.df = pd.read_json(f"{TEST_DIR}/mag_orderings_test_cases.json") @@ -31,9 +31,6 @@ def setUpClass(cls): hm = HeisenbergMapper(ordered_structures, energies, cutoff=5.0, tol=0.02) cls.hms.append(hm) - def setUp(self): - pass - def test_graphs(self): for hm in self.hms: struct_graphs = hm.sgraphs diff --git a/tests/analysis/magnetism/test_jahnteller.py b/tests/analysis/magnetism/test_jahnteller.py index 6f83c107a83..083bbc87e1b 100644 --- a/tests/analysis/magnetism/test_jahnteller.py +++ b/tests/analysis/magnetism/test_jahnteller.py @@ -1,6 +1,6 @@ from __future__ import annotations -import unittest +from unittest import TestCase import numpy as np from pytest import approx @@ -10,76 +10,76 @@ from pymatgen.util.testing import TEST_FILES_DIR -class TestJahnTeller(unittest.TestCase): +class TestJahnTeller(TestCase): def setUp(self): self.jt = JahnTellerAnalyzer() def test_jahn_teller_species_analysis(self): # 1 d-shell electron - m = self.jt.get_magnitude_of_effect_from_species("Ti3+", "", "oct") - assert m == "weak" + magnitude = self.jt.get_magnitude_of_effect_from_species("Ti3+", "", "oct") + assert magnitude == "weak" # 2 d-shell electrons - m = self.jt.get_magnitude_of_effect_from_species("Ti2+", "", "oct") - assert m == "weak" - m = self.jt.get_magnitude_of_effect_from_species("V3+", "", "oct") - assert m == "weak" + magnitude = self.jt.get_magnitude_of_effect_from_species("Ti2+", "", "oct") + assert magnitude == "weak" + magnitude = self.jt.get_magnitude_of_effect_from_species("V3+", "", "oct") + assert magnitude == "weak" # 3 - m = self.jt.get_magnitude_of_effect_from_species("V2+", "", "oct") - assert m == "none" - m = self.jt.get_magnitude_of_effect_from_species("Cr3+", "", "oct") - assert m == "none" + magnitude = self.jt.get_magnitude_of_effect_from_species("V2+", "", "oct") + assert magnitude == "none" + magnitude = self.jt.get_magnitude_of_effect_from_species("Cr3+", "", "oct") + assert magnitude == "none" # 4 - m = self.jt.get_magnitude_of_effect_from_species("Cr2+", "high", "oct") - assert m == "strong" - m = self.jt.get_magnitude_of_effect_from_species("Cr2+", "low", "oct") - assert m == "weak" - m = self.jt.get_magnitude_of_effect_from_species("Mn3+", "high", "oct") - assert m == "strong" - m = self.jt.get_magnitude_of_effect_from_species("Mn3+", "low", "oct") - assert m == "weak" + magnitude = self.jt.get_magnitude_of_effect_from_species("Cr2+", "high", "oct") + assert magnitude == "strong" + magnitude = self.jt.get_magnitude_of_effect_from_species("Cr2+", "low", "oct") + assert magnitude == "weak" + magnitude = self.jt.get_magnitude_of_effect_from_species("Mn3+", "high", "oct") + assert magnitude == "strong" + magnitude = self.jt.get_magnitude_of_effect_from_species("Mn3+", "low", "oct") + assert magnitude == "weak" # 5 - m = self.jt.get_magnitude_of_effect_from_species("Mn2+", "high", "oct") - assert m == "none" - m = self.jt.get_magnitude_of_effect_from_species("Mn2+", "low", "oct") - assert m == "weak" - m = self.jt.get_magnitude_of_effect_from_species("Fe3+", "high", "oct") - assert m == "none" - m = self.jt.get_magnitude_of_effect_from_species("Fe3+", "low", "oct") - assert m == "weak" + magnitude = self.jt.get_magnitude_of_effect_from_species("Mn2+", "high", "oct") + assert magnitude == "none" + magnitude = self.jt.get_magnitude_of_effect_from_species("Mn2+", "low", "oct") + assert magnitude == "weak" + magnitude = self.jt.get_magnitude_of_effect_from_species("Fe3+", "high", "oct") + assert magnitude == "none" + magnitude = self.jt.get_magnitude_of_effect_from_species("Fe3+", "low", "oct") + assert magnitude == "weak" # 6 - m = self.jt.get_magnitude_of_effect_from_species("Fe2+", "high", "oct") - assert m == "weak" - m = self.jt.get_magnitude_of_effect_from_species("Fe2+", "low", "oct") - assert m == "none" - m = self.jt.get_magnitude_of_effect_from_species("Co3+", "high", "oct") - assert m == "weak" - m = self.jt.get_magnitude_of_effect_from_species("Co3+", "low", "oct") - assert m == "none" + magnitude = self.jt.get_magnitude_of_effect_from_species("Fe2+", "high", "oct") + assert magnitude == "weak" + magnitude = self.jt.get_magnitude_of_effect_from_species("Fe2+", "low", "oct") + assert magnitude == "none" + magnitude = self.jt.get_magnitude_of_effect_from_species("Co3+", "high", "oct") + assert magnitude == "weak" + magnitude = self.jt.get_magnitude_of_effect_from_species("Co3+", "low", "oct") + assert magnitude == "none" # 7 - m = self.jt.get_magnitude_of_effect_from_species("Co2+", "high", "oct") - assert m == "weak" - m = self.jt.get_magnitude_of_effect_from_species("Co2+", "low", "oct") - assert m == "strong" + magnitude = self.jt.get_magnitude_of_effect_from_species("Co2+", "high", "oct") + assert magnitude == "weak" + magnitude = self.jt.get_magnitude_of_effect_from_species("Co2+", "low", "oct") + assert magnitude == "strong" # 8 - m = self.jt.get_magnitude_of_effect_from_species("Ni2+", "", "oct") - assert m == "none" + magnitude = self.jt.get_magnitude_of_effect_from_species("Ni2+", "", "oct") + assert magnitude == "none" # 9 - m = self.jt.get_magnitude_of_effect_from_species("Cu2+", "", "oct") - assert m == "strong" + magnitude = self.jt.get_magnitude_of_effect_from_species("Cu2+", "", "oct") + assert magnitude == "strong" # 10 - m = self.jt.get_magnitude_of_effect_from_species("Cu+", "", "oct") - assert m == "none" - m = self.jt.get_magnitude_of_effect_from_species("Zn2+", "", "oct") - assert m == "none" + magnitude = self.jt.get_magnitude_of_effect_from_species("Cu+", "", "oct") + assert magnitude == "none" + magnitude = self.jt.get_magnitude_of_effect_from_species("Zn2+", "", "oct") + assert magnitude == "none" def test_jahn_teller_structure_analysis(self): LiFePO4 = Structure.from_file(f"{TEST_FILES_DIR}/LiFePO4.cif", primitive=True) diff --git a/tests/analysis/structure_prediction/test_dopant_predictor.py b/tests/analysis/structure_prediction/test_dopant_predictor.py index c42a1cba0ab..f19d032624d 100644 --- a/tests/analysis/structure_prediction/test_dopant_predictor.py +++ b/tests/analysis/structure_prediction/test_dopant_predictor.py @@ -1,6 +1,6 @@ from __future__ import annotations -import unittest +from unittest import TestCase from pytest import approx @@ -12,7 +12,7 @@ from pymatgen.core import Species, Structure -class TestDopantPrediction(unittest.TestCase): +class TestDopantPrediction(TestCase): def setUp(self): self.tin_dioxide = Structure( [3.24, 0, 0, 0, 4.83, 0, 0, 0, 4.84], diff --git a/tests/analysis/structure_prediction/test_substitution_probability.py b/tests/analysis/structure_prediction/test_substitution_probability.py index c66a8ec4483..63fec5470b9 100644 --- a/tests/analysis/structure_prediction/test_substitution_probability.py +++ b/tests/analysis/structure_prediction/test_substitution_probability.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -import unittest +from unittest import TestCase from pytest import approx @@ -24,7 +24,7 @@ def get_table(): return json.load(file) -class TestSubstitutionProbability(unittest.TestCase): +class TestSubstitutionProbability(TestCase): def test_full_lambda_table(self): """ This test tests specific values in the data folder. If the @@ -56,7 +56,7 @@ def test_mini_lambda_table(self): assert prob == approx(0.00102673915742, abs=1e-5), "probability isn't correct" -class TestSubstitutionPredictor(unittest.TestCase): +class TestSubstitutionPredictor(TestCase): def test_prediction(self): sp = SubstitutionPredictor(threshold=8e-3) result = sp.list_prediction(["Na+", "Cl-"], to_this_composition=True)[5] diff --git a/tests/analysis/structure_prediction/test_substitutor.py b/tests/analysis/structure_prediction/test_substitutor.py index 8315d8c0d87..a86548d87dc 100644 --- a/tests/analysis/structure_prediction/test_substitutor.py +++ b/tests/analysis/structure_prediction/test_substitutor.py @@ -20,19 +20,19 @@ def get_table(): class TestSubstitutor(PymatgenTest): def setUp(self): - self.s = Substitutor(threshold=1e-3, lambda_table=get_table(), alpha=-5.0) + self.substitutor = Substitutor(threshold=1e-3, lambda_table=get_table(), alpha=-5.0) def test_substitutor(self): s_list = [Species("O", -2), Species("Li", 1)] - subs = self.s.pred_from_list(s_list) + subs = self.substitutor.pred_from_list(s_list) assert len(subs) == 4, "incorrect number of substitutions" comp = Composition({"O2-": 1, "Li1+": 2}) - subs = self.s.pred_from_comp(comp) + subs = self.substitutor.pred_from_comp(comp) assert len(subs) == 4, "incorrect number of substitutions" structures = [{"structure": PymatgenTest.get_structure("Li2O"), "id": "pmgtest"}] - subs = self.s.pred_from_structures(["Na+", "O2-"], structures) + subs = self.substitutor.pred_from_structures(["Na+", "O2-"], structures) assert subs[0].formula == "Na2 O1" def test_as_dict(self): - Substitutor.from_dict(self.s.as_dict()) + Substitutor.from_dict(self.substitutor.as_dict()) diff --git a/tests/analysis/structure_prediction/test_volume_predictor.py b/tests/analysis/structure_prediction/test_volume_predictor.py index c9d02fd7eb0..893f6b805f2 100644 --- a/tests/analysis/structure_prediction/test_volume_predictor.py +++ b/tests/analysis/structure_prediction/test_volume_predictor.py @@ -18,32 +18,32 @@ def test_predict(self): nacl = PymatgenTest.get_structure("CsCl") nacl.replace_species({"Cs": "Na"}) nacl.scale_lattice(184.384551033) - p = RLSVolumePredictor(radii_type="ionic") - assert p.predict(struct, nacl) == approx(342.84905395082535) - p = RLSVolumePredictor(radii_type="atomic") - assert p.predict(struct, nacl) == approx(391.884366481) + predictor = RLSVolumePredictor(radii_type="ionic") + assert predictor.predict(struct, nacl) == approx(342.84905395082535) + predictor = RLSVolumePredictor(radii_type="atomic") + assert predictor.predict(struct, nacl) == approx(391.884366481) lif = PymatgenTest.get_structure("CsCl") lif.replace_species({"Cs": "Li", "Cl": "F"}) - p = RLSVolumePredictor(radii_type="ionic") - assert p.predict(lif, nacl) == approx(74.268402413690467) - p = RLSVolumePredictor(radii_type="atomic") - assert p.predict(lif, nacl) == approx(62.2808125839) + predictor = RLSVolumePredictor(radii_type="ionic") + assert predictor.predict(lif, nacl) == approx(74.268402413690467) + predictor = RLSVolumePredictor(radii_type="atomic") + assert predictor.predict(lif, nacl) == approx(62.2808125839) lfpo = PymatgenTest.get_structure("LiFePO4") lmpo = PymatgenTest.get_structure("LiFePO4") lmpo.replace_species({"Fe": "Mn"}) - p = RLSVolumePredictor(radii_type="ionic") - assert p.predict(lmpo, lfpo) == approx(310.08253254420134) - p = RLSVolumePredictor(radii_type="atomic") - assert p.predict(lmpo, lfpo) == approx(299.607967711) + predictor = RLSVolumePredictor(radii_type="ionic") + assert predictor.predict(lmpo, lfpo) == approx(310.08253254420134) + predictor = RLSVolumePredictor(radii_type="atomic") + assert predictor.predict(lmpo, lfpo) == approx(299.607967711) sto = PymatgenTest.get_structure("SrTiO3") scoo = PymatgenTest.get_structure("SrTiO3") scoo.replace_species({"Ti4+": "Co4+"}) - p = RLSVolumePredictor(radii_type="ionic") - assert p.predict(scoo, sto) == approx(56.162534974936463) - p = RLSVolumePredictor(radii_type="atomic") - assert p.predict(scoo, sto) == approx(57.4777835108) + predictor = RLSVolumePredictor(radii_type="ionic") + assert predictor.predict(scoo, sto) == approx(56.162534974936463) + predictor = RLSVolumePredictor(radii_type="atomic") + assert predictor.predict(scoo, sto) == approx(57.4777835108) # Use Ag7P3S11 as a test case: @@ -51,16 +51,16 @@ def test_predict(self): aps = Structure.from_file(f"{module_dir}/Ag7P3S11_mp-683910_primitive.cif") apo = Structure.from_file(f"{module_dir}/Ag7P3S11_mp-683910_primitive.cif") apo.replace_species({"S": "O"}) - p = RLSVolumePredictor(radii_type="atomic", check_isostructural=False) - assert p.predict(apo, aps) == approx(1196.31384276) + predictor = RLSVolumePredictor(radii_type="atomic", check_isostructural=False) + assert predictor.predict(apo, aps) == approx(1196.31384276) # (ii) Oxidation states are assigned. apo.add_oxidation_state_by_element({"Ag": 1, "P": 5, "O": -2}) aps.add_oxidation_state_by_element({"Ag": 1, "P": 5, "S": -2}) - p = RLSVolumePredictor(radii_type="ionic") - assert p.predict(apo, aps) == approx(1165.23259079) - p = RLSVolumePredictor(radii_type="atomic") - assert p.predict(apo, aps) == approx(1196.31384276) + predictor = RLSVolumePredictor(radii_type="ionic") + assert predictor.predict(apo, aps) == approx(1165.23259079) + predictor = RLSVolumePredictor(radii_type="atomic") + assert predictor.predict(apo, aps) == approx(1196.31384276) def test_modes(self): cs_cl = PymatgenTest.get_structure("CsCl") diff --git a/tests/analysis/test_adsorption.py b/tests/analysis/test_adsorption.py index 00adbb6e279..9f611a2feaa 100644 --- a/tests/analysis/test_adsorption.py +++ b/tests/analysis/test_adsorption.py @@ -117,9 +117,9 @@ def test_generate_adsorption_structures(self): def test_adsorb_both_surfaces(self): # Test out for monatomic adsorption - o = Molecule("O", [[0, 0, 0]]) - ad_slabs = self.asf_100.adsorb_both_surfaces(o) - ad_slabs_one = self.asf_100.generate_adsorption_structures(o) + oxi = Molecule("O", [[0, 0, 0]]) + ad_slabs = self.asf_100.adsorb_both_surfaces(oxi) + ad_slabs_one = self.asf_100.generate_adsorption_structures(oxi) assert len(ad_slabs) == len(ad_slabs_one) for ad_slab in ad_slabs: sg = SpacegroupAnalyzer(ad_slab) diff --git a/tests/analysis/test_bond_dissociation.py b/tests/analysis/test_bond_dissociation.py index 704ae46f3dd..aa660efc7d8 100644 --- a/tests/analysis/test_bond_dissociation.py +++ b/tests/analysis/test_bond_dissociation.py @@ -1,7 +1,7 @@ from __future__ import annotations import os -import unittest +from unittest import TestCase import pytest from monty.serialization import loadfn @@ -11,7 +11,7 @@ module_dir = os.path.dirname(os.path.abspath(__file__)) -class TestBondDissociation(unittest.TestCase): +class TestBondDissociation(TestCase): def setUp(self): pytest.importorskip("openbabel") self.PC_65_principle = loadfn(f"{module_dir}/PC_65_principle.json") diff --git a/tests/analysis/test_chempot_diagram.py b/tests/analysis/test_chempot_diagram.py index 5c7584dab4b..3f316b797c7 100644 --- a/tests/analysis/test_chempot_diagram.py +++ b/tests/analysis/test_chempot_diagram.py @@ -22,17 +22,12 @@ class TestChemicalPotentialDiagram(PymatgenTest): def setUp(self): self.entries = EntrySet.from_csv(str(module_dir / "pd_entries_test.csv")) - self.cpd_ternary = ChemicalPotentialDiagram(entries=self.entries, default_min_limit=-25, formal_chempots=False) - self.cpd_ternary_formal = ChemicalPotentialDiagram( - entries=self.entries, default_min_limit=-25, formal_chempots=True - ) - elements = [Element("Fe"), Element("O")] - binary_entries = list( - filter( - lambda e: set(e.elements).issubset(elements), - self.entries, - ) + self.cpd_ternary, self.cpd_ternary_formal = ( + ChemicalPotentialDiagram(entries=self.entries, default_min_limit=-25, formal_chempots=formal) + for formal in [False, True] ) + elements = {Element("Fe"), Element("O")} + binary_entries = [entry for entry in self.entries if set(entry.elements) <= elements] self.cpd_binary = ChemicalPotentialDiagram(entries=binary_entries, default_min_limit=-25, formal_chempots=False) def test_dim(self): @@ -139,216 +134,183 @@ def test_get_plot(self): def test_domains(self): correct_domains = { - "Fe": np.array( - [ - [-25.0, -6.596147, -25.0], - [-25.0, -6.596147, -7.115354], - [-3.931615, -6.596147, -7.115354], - [-3.625002, -6.596147, -7.268661], - [-3.351598, -6.596147, -7.610416], - [-1.913015, -6.596147, -25.0], - [-1.913015, -6.596147, -10.487582], - ] - ), - "Fe2O3": np.array( - [ - [-25.0, -10.739688, -4.258278], - [-25.0, -7.29639, -6.55381], - [-5.550202, -10.739688, -4.258278], - [-5.406275, -10.451834, -4.450181], - [-4.35446, -7.29639, -6.55381], - ] - ), - "Fe3O4": np.array( - [ - [-25.0, -7.29639, -6.55381], - [-25.0, -6.741594, -6.969907], - [-4.35446, -7.29639, -6.55381], - [-4.077062, -6.741594, -6.969907], - ] - ), - "FeO": np.array( - [ - [-25.0, -6.741594, -6.969907], - [-25.0, -6.596147, -7.115354], - [-4.077062, -6.741594, -6.969907], - [-3.931615, -6.596147, -7.115354], - ] - ), - "Li": np.array( - [ - [-1.913015, -25.0, -25.0], - [-1.913015, -25.0, -10.487582], - [-1.913015, -6.596147, -25.0], - [-1.913015, -6.596147, -10.487582], - ] - ), - "Li2FeO3": np.array( - [ - [-5.550202, -10.739688, -4.258278], - [-5.442823, -10.954446, -4.258278], - [-5.406275, -10.451834, -4.450181], - [-4.739887, -10.251509, -4.961215], - [-4.662209, -9.707768, -5.194246], - ] - ), - "Li2O": np.array( - [ - [-4.612511, -25.0, -5.088591], - [-4.612511, -10.378885, -5.088591], - [-3.351598, -6.596147, -7.610416], - [-1.913015, -25.0, -10.487582], - [-1.913015, -6.596147, -10.487582], - ] - ), - "Li2O2": np.array( - [ - [-5.442823, -25.0, -4.258278], - [-5.442823, -10.954446, -4.258278], - [-4.739887, -10.251509, -4.961215], - [-4.612511, -25.0, -5.088591], - [-4.612511, -10.378885, -5.088591], - ] - ), - "Li5FeO4": np.array( - [ - [-4.739887, -10.251509, -4.961215], - [-4.662209, -9.707768, -5.194246], - [-4.612511, -10.378885, -5.088591], - [-3.625002, -6.596147, -7.268661], - [-3.351598, -6.596147, -7.610416], - ] - ), - "LiFeO2": np.array( - [ - [-5.406275, -10.451834, -4.450181], - [-4.662209, -9.707768, -5.194246], - [-4.35446, -7.29639, -6.55381], - [-4.077062, -6.741594, -6.969907], - [-3.931615, -6.596147, -7.115354], - [-3.625002, -6.596147, -7.268661], - ] - ), - "O2": np.array( - [ - [-25.0, -25.0, -4.258278], - [-25.0, -10.739688, -4.258278], - [-5.550202, -10.739688, -4.258278], - [-5.442823, -25.0, -4.258278], - [-5.442823, -10.954446, -4.258278], - ] - ), + "Fe": [ + [-25.0, -6.596147, -25.0], + [-25.0, -6.596147, -7.115354], + [-3.931615, -6.596147, -7.115354], + [-3.625002, -6.596147, -7.268661], + [-3.351598, -6.596147, -7.610416], + [-1.913015, -6.596147, -25.0], + [-1.913015, -6.596147, -10.487582], + ], + "Fe2O3": [ + [-25.0, -10.739688, -4.258278], + [-25.0, -7.29639, -6.55381], + [-5.550202, -10.739688, -4.258278], + [-5.406275, -10.451834, -4.450181], + [-4.35446, -7.29639, -6.55381], + ], + "Fe3O4": [ + [-25.0, -7.29639, -6.55381], + [-25.0, -6.741594, -6.969907], + [-4.35446, -7.29639, -6.55381], + [-4.077062, -6.741594, -6.969907], + ], + "FeO": [ + [-25.0, -6.741594, -6.969907], + [-25.0, -6.596147, -7.115354], + [-4.077062, -6.741594, -6.969907], + [-3.931615, -6.596147, -7.115354], + ], + "Li": [ + [-1.913015, -25.0, -25.0], + [-1.913015, -25.0, -10.487582], + [-1.913015, -6.596147, -25.0], + [-1.913015, -6.596147, -10.487582], + ], + "Li2FeO3": [ + [-5.550202, -10.739688, -4.258278], + [-5.442823, -10.954446, -4.258278], + [-5.406275, -10.451834, -4.450181], + [-4.739887, -10.251509, -4.961215], + [-4.662209, -9.707768, -5.194246], + ], + "Li2O": [ + [-4.612511, -25.0, -5.088591], + [-4.612511, -10.378885, -5.088591], + [-3.351598, -6.596147, -7.610416], + [-1.913015, -25.0, -10.487582], + [-1.913015, -6.596147, -10.487582], + ], + "Li2O2": [ + [-5.442823, -25.0, -4.258278], + [-5.442823, -10.954446, -4.258278], + [-4.739887, -10.251509, -4.961215], + [-4.612511, -25.0, -5.088591], + [-4.612511, -10.378885, -5.088591], + ], + "Li5FeO4": [ + [-4.739887, -10.251509, -4.961215], + [-4.662209, -9.707768, -5.194246], + [-4.612511, -10.378885, -5.088591], + [-3.625002, -6.596147, -7.268661], + [-3.351598, -6.596147, -7.610416], + ], + "LiFeO2": [ + [-5.406275, -10.451834, -4.450181], + [-4.662209, -9.707768, -5.194246], + [-4.35446, -7.29639, -6.55381], + [-4.077062, -6.741594, -6.969907], + [-3.931615, -6.596147, -7.115354], + [-3.625002, -6.596147, -7.268661], + ], + "O2": [ + [-25.0, -25.0, -4.258278], + [-25.0, -10.739688, -4.258278], + [-5.550202, -10.739688, -4.258278], + [-5.442823, -25.0, -4.258278], + [-5.442823, -10.954446, -4.258278], + ], } for formula, domain in correct_domains.items(): - d = self.cpd_ternary.domains[formula] - d = d.round(6) # to get rid of numerical errors from qhull - actual_domain_sorted = d[np.lexsort((d[:, 2], d[:, 1], d[:, 0]))] - assert actual_domain_sorted == approx(domain) + dom = self.cpd_ternary.domains[formula] + dom = dom.round(6) # to get rid of numerical errors from qhull + actual_domain_sorted = dom[np.lexsort((dom[:, 2], dom[:, 1], dom[:, 0]))] + assert actual_domain_sorted == approx(np.array(domain)) formal_domains = { - "FeO": np.array( - [ - [-2.50000000e01, 3.55271368e-15, -2.85707600e00], - [-2.01860032e00, 3.55271368e-15, -2.85707600e00], - [-2.50000000e01, -1.45446765e-01, -2.71162923e00], - [-2.16404709e00, -1.45446765e-01, -2.71162923e00], - ] - ), - "Fe2O3": np.array( - [ - [-25.0, -4.14354109, 0.0], - [-3.637187, -4.14354108, 0.0], - [-3.49325969, -3.85568646, -0.19190308], - [-25.0, -0.70024301, -2.29553205], - [-2.44144521, -0.70024301, -2.29553205], - ] - ), - "Fe3O4": np.array( - [ - [-25.0, -0.70024301, -2.29553205], - [-25.0, -0.14544676, -2.71162923], - [-2.44144521, -0.70024301, -2.29553205], - [-2.16404709, -0.14544676, -2.71162923], - ] - ), - "LiFeO2": np.array( - [ - [-3.49325969e00, -3.85568646e00, -1.91903083e-01], - [-2.01860032e00, 3.55271368e-15, -2.85707600e00], - [-2.44144521e00, -7.00243005e-01, -2.29553205e00], - [-2.16404709e00, -1.45446765e-01, -2.71162923e00], - [-1.71198739e00, 3.55271368e-15, -3.01038246e00], - [-2.74919447e00, -3.11162124e00, -9.35968300e-01], - ] - ), - "Li2O": np.array( - [ - [0.00000000e00, -2.50000000e01, -6.22930387e00], - [-2.69949567e00, -2.50000000e01, -8.30312528e-01], - [3.55271368e-15, 3.55271368e-15, -6.22930387e00], - [-1.43858289e00, 3.55271368e-15, -3.35213809e00], - [-2.69949567e00, -3.78273835e00, -8.30312528e-01], - ] - ), - "Li2O2": np.array( - [ - [-3.52980820e00, -2.50000000e01, 0.00000000e00], - [-2.69949567e00, -2.50000000e01, -8.30312528e-01], - [-3.52980820e00, -4.35829869e00, 3.55271368e-15], - [-2.69949567e00, -3.78273835e00, -8.30312528e-01], - [-2.82687176e00, -3.65536226e00, -7.02936437e-01], - ] - ), - "Li2FeO3": np.array( - [ - [-3.52980820e00, -4.35829869e00, 3.55271368e-15], - [-3.63718700e00, -4.14354108e00, 0.00000000e00], - [-3.49325969e00, -3.85568646e00, -1.91903083e-01], - [-2.74919447e00, -3.11162124e00, -9.35968300e-01], - [-2.82687176e00, -3.65536226e00, -7.02936437e-01], - ] - ), - "Li5FeO4": np.array( - [ - [-1.43858289e00, 3.55271368e-15, -3.35213809e00], - [-1.71198739e00, 3.55271368e-15, -3.01038246e00], - [-2.74919447e00, -3.11162124e00, -9.35968300e-01], - [-2.69949567e00, -3.78273835e00, -8.30312528e-01], - [-2.82687176e00, -3.65536226e00, -7.02936437e-01], - ] - ), - "O2": np.array( - [ - [-2.50000000e01, -2.50000000e01, 3.55271368e-15], - [-3.52980820e00, -2.50000000e01, 0.00000000e00], - [-2.50000000e01, -4.14354109e00, 0.00000000e00], - [-3.52980820e00, -4.35829869e00, 3.55271368e-15], - [-3.63718700e00, -4.14354108e00, 0.00000000e00], - ] - ), - "Fe": np.array( - [ - [0.00000000e00, 0.00000000e00, -2.50000000e01], - [-2.50000000e01, 0.00000000e00, -2.50000000e01], - [3.55271368e-15, 3.55271368e-15, -6.22930387e00], - [-2.50000000e01, 3.55271368e-15, -2.85707600e00], - [-2.01860032e00, 3.55271368e-15, -2.85707600e00], - [-1.43858289e00, 3.55271368e-15, -3.35213809e00], - [-1.71198739e00, 3.55271368e-15, -3.01038246e00], - ] - ), - "Li": np.array( - [ - [3.55271368e-15, -2.50000000e01, -2.50000000e01], - [0.00000000e00, -2.50000000e01, -6.22930387e00], - [0.00000000e00, 0.00000000e00, -2.50000000e01], - [3.55271368e-15, 3.55271368e-15, -6.22930387e00], - ] - ), + "FeO": [ + [-2.50000000e01, 3.55271368e-15, -2.85707600e00], + [-2.01860032e00, 3.55271368e-15, -2.85707600e00], + [-2.50000000e01, -1.45446765e-01, -2.71162923e00], + [-2.16404709e00, -1.45446765e-01, -2.71162923e00], + ], + "Fe2O3": [ + [-25.0, -4.14354109, 0.0], + [-3.637187, -4.14354108, 0.0], + [-3.49325969, -3.85568646, -0.19190308], + [-25.0, -0.70024301, -2.29553205], + [-2.44144521, -0.70024301, -2.29553205], + ], + "Fe3O4": [ + [-25.0, -0.70024301, -2.29553205], + [-25.0, -0.14544676, -2.71162923], + [-2.44144521, -0.70024301, -2.29553205], + [-2.16404709, -0.14544676, -2.71162923], + ], + "LiFeO2": [ + [-3.49325969e00, -3.85568646e00, -1.91903083e-01], + [-2.01860032e00, 3.55271368e-15, -2.85707600e00], + [-2.44144521e00, -7.00243005e-01, -2.29553205e00], + [-2.16404709e00, -1.45446765e-01, -2.71162923e00], + [-1.71198739e00, 3.55271368e-15, -3.01038246e00], + [-2.74919447e00, -3.11162124e00, -9.35968300e-01], + ], + "Li2O": [ + [0.00000000e00, -2.50000000e01, -6.22930387e00], + [-2.69949567e00, -2.50000000e01, -8.30312528e-01], + [3.55271368e-15, 3.55271368e-15, -6.22930387e00], + [-1.43858289e00, 3.55271368e-15, -3.35213809e00], + [-2.69949567e00, -3.78273835e00, -8.30312528e-01], + ], + "Li2O2": [ + [-3.52980820e00, -2.50000000e01, 0.00000000e00], + [-2.69949567e00, -2.50000000e01, -8.30312528e-01], + [-3.52980820e00, -4.35829869e00, 3.55271368e-15], + [-2.69949567e00, -3.78273835e00, -8.30312528e-01], + [-2.82687176e00, -3.65536226e00, -7.02936437e-01], + ], + "Li2FeO3": [ + [-3.52980820e00, -4.35829869e00, 3.55271368e-15], + [-3.63718700e00, -4.14354108e00, 0.00000000e00], + [-3.49325969e00, -3.85568646e00, -1.91903083e-01], + [-2.74919447e00, -3.11162124e00, -9.35968300e-01], + [-2.82687176e00, -3.65536226e00, -7.02936437e-01], + ], + "Li5FeO4": [ + [-1.43858289e00, 3.55271368e-15, -3.35213809e00], + [-1.71198739e00, 3.55271368e-15, -3.01038246e00], + [-2.74919447e00, -3.11162124e00, -9.35968300e-01], + [-2.69949567e00, -3.78273835e00, -8.30312528e-01], + [-2.82687176e00, -3.65536226e00, -7.02936437e-01], + ], + "O2": [ + [-2.50000000e01, -2.50000000e01, 3.55271368e-15], + [-3.52980820e00, -2.50000000e01, 0.00000000e00], + [-2.50000000e01, -4.14354109e00, 0.00000000e00], + [-3.52980820e00, -4.35829869e00, 3.55271368e-15], + [-3.63718700e00, -4.14354108e00, 0.00000000e00], + ], + "Fe": [ + [0.00000000e00, 0.00000000e00, -2.50000000e01], + [-2.50000000e01, 0.00000000e00, -2.50000000e01], + [3.55271368e-15, 3.55271368e-15, -6.22930387e00], + [-2.50000000e01, 3.55271368e-15, -2.85707600e00], + [-2.01860032e00, 3.55271368e-15, -2.85707600e00], + [-1.43858289e00, 3.55271368e-15, -3.35213809e00], + [-1.71198739e00, 3.55271368e-15, -3.01038246e00], + ], + "Li": [ + [3.55271368e-15, -2.50000000e01, -2.50000000e01], + [0.00000000e00, -2.50000000e01, -6.22930387e00], + [0.00000000e00, 0.00000000e00, -2.50000000e01], + [3.55271368e-15, 3.55271368e-15, -6.22930387e00], + ], } for formula, domain in formal_domains.items(): - d = self.cpd_ternary_formal.domains[formula] - d = d.round(6) # to get rid of numerical errors from qhull - assert d == approx(domain, abs=1e-5) + dom = self.cpd_ternary_formal.domains[formula] + dom = dom.round(6) # to get rid of numerical errors from qhull + assert dom == approx(np.array(domain), abs=1e-5) + + def test_formal_chempots_get_plot(self): + elems = [Element("Fe"), Element("O")] + fig_2d = self.cpd_ternary.get_plot(elements=elems) + fig_2d_formal = self.cpd_ternary_formal.get_plot(elements=elems) + + assert max(filter(bool, fig_2d.data[0].x)) == approx(-6.5961471) + assert max(filter(bool, fig_2d_formal.data[0].x)) == approx(0) + + assert max(filter(bool, fig_2d.data[0].y)) == approx(-4.2582781) + assert max(filter(bool, fig_2d_formal.data[0].y)) == approx(0) diff --git a/tests/analysis/test_cost.py b/tests/analysis/test_cost.py index 9bbc70bfeb7..137a864d177 100644 --- a/tests/analysis/test_cost.py +++ b/tests/analysis/test_cost.py @@ -1,6 +1,6 @@ from __future__ import annotations -import unittest +from unittest import TestCase from pytest import approx @@ -8,7 +8,7 @@ from pymatgen.util.testing import TEST_FILES_DIR -class TestCostAnalyzer(unittest.TestCase): +class TestCostAnalyzer(TestCase): def setUp(self): self.ca1 = CostAnalyzer(CostDBCSV(f"{TEST_FILES_DIR}/costdb_1.csv")) self.ca2 = CostAnalyzer(CostDBCSV(f"{TEST_FILES_DIR}/costdb_2.csv")) @@ -29,7 +29,7 @@ def test_sanity(self): assert self.ca1.get_cost_per_kg("Ag") == self.ca2.get_cost_per_kg("Ag") -class TestCostDB(unittest.TestCase): +class TestCostDB(TestCase): def test_sanity(self): ca = CostAnalyzer(CostDBElements()) assert ca.get_cost_per_kg("PtO") > ca.get_cost_per_kg("MgO") diff --git a/tests/analysis/test_dimensionality.py b/tests/analysis/test_dimensionality.py index 6cb16893c1b..4fc43e2c522 100644 --- a/tests/analysis/test_dimensionality.py +++ b/tests/analysis/test_dimensionality.py @@ -107,8 +107,8 @@ def test_zero_d_to_molecule_graph(self): zero_d_graph_to_molecule_graph(self.graphite, comp_graphs[0]) # test for a troublesome structure - s = loadfn(f"{TEST_FILES_DIR}/PH7CN3O3F.json.gz") - bs = CrystalNN().get_bonded_structure(s) + struct = loadfn(f"{TEST_FILES_DIR}/PH7CN3O3F.json.gz") + bs = CrystalNN().get_bonded_structure(struct) comp_graphs = [bs.graph.subgraph(c) for c in nx.weakly_connected_components(bs.graph)] mol_graph = zero_d_graph_to_molecule_graph(bs, comp_graphs[0]) assert len(mol_graph.molecule) == 12 diff --git a/tests/analysis/test_energy_models.py b/tests/analysis/test_energy_models.py index 693ab51b291..36f27a22e16 100644 --- a/tests/analysis/test_energy_models.py +++ b/tests/analysis/test_energy_models.py @@ -1,7 +1,5 @@ from __future__ import annotations -import unittest - from pytest import approx from pymatgen.analysis.energy_models import EwaldElectrostaticModel, IsingModel, SymmetryModel @@ -11,7 +9,7 @@ from pymatgen.util.testing import TEST_FILES_DIR -class TestEwaldElectrostaticModel(unittest.TestCase): +class TestEwaldElectrostaticModel: def test_get_energy(self): coords = [[0, 0, 0], [0.75, 0.75, 0.75], [0.5, 0.5, 0.5], [0.25, 0.25, 0.25]] lattice = Lattice([[3.0, 0.0, 0.0], [1.0, 3.0, 0], [0, -2.0, 3.0]]) @@ -26,11 +24,11 @@ def test_get_energy(self): coords, ) - m = EwaldElectrostaticModel() + model = EwaldElectrostaticModel() # large tolerance because scipy constants changed between 0.16.1 and 0.17 - assert m.get_energy(struct) == approx(-264.66364858, abs=1e-2) # Result from GULP + assert model.get_energy(struct) == approx(-264.66364858, abs=1e-2) # Result from GULP s2 = Structure.from_file(f"{TEST_FILES_DIR}/Li2O.cif") - assert m.get_energy(s2) == approx(-145.39050015844839, abs=1e-4) + assert model.get_energy(s2) == approx(-145.39050015844839, abs=1e-4) def test_as_from_dict(self): model = EwaldElectrostaticModel() @@ -39,7 +37,7 @@ def test_as_from_dict(self): assert isinstance(restored, EwaldElectrostaticModel) -class TestSymmetryModel(unittest.TestCase): +class TestSymmetryModel: def test_get_energy(self): model = SymmetryModel() struct = Structure.from_file(f"{TEST_FILES_DIR}/Li2O.cif") @@ -52,7 +50,7 @@ def test_as_from_dict(self): assert restored.symprec == approx(0.2) -class TestIsingModel(unittest.TestCase): +class TestIsingModel: def test_get_energy(self): model = IsingModel(5, 6) diff --git a/tests/analysis/test_ewald.py b/tests/analysis/test_ewald.py index c13cdb24722..285152bb8fd 100644 --- a/tests/analysis/test_ewald.py +++ b/tests/analysis/test_ewald.py @@ -1,6 +1,6 @@ from __future__ import annotations -import unittest +from unittest import TestCase import numpy as np import pytest @@ -11,7 +11,7 @@ from pymatgen.util.testing import VASP_IN_DIR -class TestEwaldSummation(unittest.TestCase): +class TestEwaldSummation(TestCase): def setUp(self): filepath = f"{VASP_IN_DIR}/POSCAR" self.original_struct = Structure.from_file(filepath) @@ -80,7 +80,7 @@ def test_as_dict(self): assert ham.as_dict() == EwaldSummation.from_dict(dct).as_dict() -class TestEwaldMinimizer(unittest.TestCase): +class TestEwaldMinimizer(TestCase): def test_init(self): matrix = np.array( [ @@ -109,11 +109,11 @@ def test_site(self): """Test that uses an uncharged structure.""" filepath = f"{VASP_IN_DIR}/POSCAR" struct = Structure.from_file(filepath) - s = struct.copy() - s.add_oxidation_state_by_element({"Li": 1, "Fe": 3, "P": 5, "O": -2}) + struct = struct.copy() + struct.add_oxidation_state_by_element({"Li": 1, "Fe": 3, "P": 5, "O": -2}) # Comparison to LAMMPS result - ham = EwaldSummation(s, compute_forces=True) + ham = EwaldSummation(struct, compute_forces=True) assert approx(ham.total_energy, abs=1e-3) == -1226.3335 assert approx(ham.get_site_energy(0), abs=1e-3) == -45.8338 assert approx(ham.get_site_energy(8), abs=1e-3) == -27.2978 diff --git a/tests/analysis/test_fragmenter.py b/tests/analysis/test_fragmenter.py index 13894a3b18c..31795ebb05e 100644 --- a/tests/analysis/test_fragmenter.py +++ b/tests/analysis/test_fragmenter.py @@ -72,7 +72,7 @@ def test_babel_pc_old_defaults(self): fragmenter = Fragmenter(molecule=self.pc, open_rings=True) assert fragmenter.open_rings assert fragmenter.opt_steps == 10000 - default_mol_graph = MoleculeGraph.with_local_env_strategy(self.pc, OpenBabelNN()) + default_mol_graph = MoleculeGraph.from_local_env_strategy(self.pc, OpenBabelNN()) assert fragmenter.mol_graph == default_mol_graph assert fragmenter.total_unique_fragments == 13 @@ -81,7 +81,7 @@ def test_babel_pc_defaults(self): fragmenter = Fragmenter(molecule=self.pc) assert fragmenter.open_rings is False assert fragmenter.opt_steps == 10_000 - default_mol_graph = MoleculeGraph.with_local_env_strategy(self.pc, OpenBabelNN()) + default_mol_graph = MoleculeGraph.from_local_env_strategy(self.pc, OpenBabelNN()) assert fragmenter.mol_graph == default_mol_graph assert fragmenter.total_unique_fragments == 8 @@ -95,8 +95,8 @@ def test_edges_given_pc_not_defaults(self): ) assert fragmenter.open_rings is False assert fragmenter.opt_steps == 0 - edges = {(e[0], e[1]): None for e in self.pc_edges} - default_mol_graph = MoleculeGraph.with_edges(self.pc, edges=edges) + edges = {(edge[0], edge[1]): None for edge in self.pc_edges} + default_mol_graph = MoleculeGraph.from_edges(self.pc, edges=edges) assert fragmenter.mol_graph == default_mol_graph assert fragmenter.total_unique_fragments == 20 diff --git a/tests/analysis/test_functional_groups.py b/tests/analysis/test_functional_groups.py index a9abe7ea1f5..978215ab85c 100644 --- a/tests/analysis/test_functional_groups.py +++ b/tests/analysis/test_functional_groups.py @@ -1,6 +1,6 @@ from __future__ import annotations -import unittest +from unittest import TestCase import pytest @@ -24,12 +24,12 @@ __credit__ = "Peiyuan Yu" -class TestFunctionalGroupExtractor(unittest.TestCase): +class TestFunctionalGroupExtractor(TestCase): def setUp(self): self.file = f"{TEST_DIR}/func_group_test.mol" self.mol = Molecule.from_file(self.file) self.strategy = OpenBabelNN() - self.mg = MoleculeGraph.with_local_env_strategy(self.mol, self.strategy) + self.mg = MoleculeGraph.from_local_env_strategy(self.mol, self.strategy) self.extractor = FunctionalGroupExtractor(self.mg) def test_init(self): diff --git a/tests/analysis/test_graphs.py b/tests/analysis/test_graphs.py index 9e570d318d1..990036cb87c 100644 --- a/tests/analysis/test_graphs.py +++ b/tests/analysis/test_graphs.py @@ -2,9 +2,10 @@ import copy import os -import unittest +import re from glob import glob from shutil import which +from unittest import TestCase import networkx as nx import networkx.algorithms.isomorphism as iso @@ -49,11 +50,9 @@ class TestStructureGraph(PymatgenTest): def setUp(self): - self.maxDiff = None - # trivial example, simple square lattice for testing structure = Structure(Lattice.tetragonal(5, 50), ["H"], [[0, 0, 0]]) - self.square_sg = StructureGraph.with_empty_graph(structure, edge_weight_name="", edge_weight_units="") + self.square_sg = StructureGraph.from_empty_graph(structure, edge_weight_name="", edge_weight_units="") self.square_sg.add_edge(0, 0, from_jimage=(0, 0, 0), to_jimage=(1, 0, 0)) self.square_sg.add_edge(0, 0, from_jimage=(0, 0, 0), to_jimage=(-1, 0, 0)) self.square_sg.add_edge(0, 0, from_jimage=(0, 0, 0), to_jimage=(0, 1, 0)) @@ -63,7 +62,7 @@ def setUp(self): # body-centered square lattice for testing structure = Structure(Lattice.tetragonal(5, 50), ["H", "He"], [[0, 0, 0], [0.5, 0.5, 0.5]]) - self.bc_square_sg = StructureGraph.with_empty_graph(structure, edge_weight_name="", edge_weight_units="") + self.bc_square_sg = StructureGraph.from_empty_graph(structure, edge_weight_name="", edge_weight_units="") self.bc_square_sg.add_edge(0, 0, from_jimage=(0, 0, 0), to_jimage=(1, 0, 0)) self.bc_square_sg.add_edge(0, 0, from_jimage=(0, 0, 0), to_jimage=(-1, 0, 0)) self.bc_square_sg.add_edge(0, 0, from_jimage=(0, 0, 0), to_jimage=(0, 1, 0)) @@ -76,7 +75,7 @@ def setUp(self): # body-centered square lattice for testing # directions reversed, should be equivalent to bc_square structure = Structure(Lattice.tetragonal(5, 50), ["H", "He"], [[0, 0, 0], [0.5, 0.5, 0.5]]) - self.bc_square_sg_r = StructureGraph.with_empty_graph(structure, edge_weight_name="", edge_weight_units="") + self.bc_square_sg_r = StructureGraph.from_empty_graph(structure, edge_weight_name="", edge_weight_units="") self.bc_square_sg_r.add_edge(0, 0, from_jimage=(0, 0, 0), to_jimage=(1, 0, 0)) self.bc_square_sg_r.add_edge(0, 0, from_jimage=(0, 0, 0), to_jimage=(-1, 0, 0)) self.bc_square_sg_r.add_edge(0, 0, from_jimage=(0, 0, 0), to_jimage=(0, 1, 0)) @@ -106,7 +105,7 @@ def setUp(self): def test_inappropriate_construction(self): # Check inappropriate strategy with pytest.raises(ValueError, match="Chosen strategy is not designed for use with structures"): - StructureGraph.with_local_env_strategy(self.NiO, CovalentBondNN()) + StructureGraph.from_local_env_strategy(self.NiO, CovalentBondNN()) def test_properties(self): assert self.mos2_sg.name == "bonds" @@ -134,7 +133,7 @@ def test_properties(self): ] nacl = Structure(nacl_lattice, ["Na", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]]) - nacl_graph = StructureGraph.with_local_env_strategy(nacl, CutOffDictNN({("Cl", "Cl"): 5.0})) + nacl_graph = StructureGraph.from_local_env_strategy(nacl, CutOffDictNN({("Cl", "Cl"): 5.0})) assert len(nacl_graph.get_connected_sites(1)) == 12 assert len(nacl_graph.graph.get_edge_data(1, 1)) == 6 @@ -206,17 +205,17 @@ def test_substitute(self): structure_copy = copy.deepcopy(structure) structure_copy_graph = copy.deepcopy(structure) - sg = StructureGraph.with_local_env_strategy(structure, MinimumDistanceNN()) - sg_copy = copy.deepcopy(sg) + struct_graph = StructureGraph.from_local_env_strategy(structure, MinimumDistanceNN()) + sg_copy = copy.deepcopy(struct_graph) # Ensure that strings and molecules lead to equivalent substitutions - sg.substitute_group(1, molecule, MinimumDistanceNN) + struct_graph.substitute_group(1, molecule, MinimumDistanceNN) sg_copy.substitute_group(1, "methyl", MinimumDistanceNN) - assert sg == sg_copy + assert struct_graph == sg_copy # Ensure that the underlying structure has been modified as expected structure_copy.substitute(1, "methyl") - assert structure_copy == sg.structure + assert structure_copy == struct_graph.structure # Test inclusion of graph dictionary graph_dict = { @@ -225,16 +224,16 @@ def test_substitute(self): (0, 3): {"weight": 0.5}, } - sg_with_graph = StructureGraph.with_local_env_strategy(structure_copy_graph, MinimumDistanceNN()) + sg_with_graph = StructureGraph.from_local_env_strategy(structure_copy_graph, MinimumDistanceNN()) sg_with_graph.substitute_group(1, "methyl", MinimumDistanceNN, graph_dict=graph_dict) edge = sg_with_graph.graph.get_edge_data(11, 13)[0] assert edge["weight"] == 0.5 def test_auto_image_detection(self): - sg = StructureGraph.with_empty_graph(self.structure) - sg.add_edge(0, 0) + struct_graph = StructureGraph.from_empty_graph(self.structure) + struct_graph.add_edge(0, 0) - assert len(list(sg.graph.edges(data=True))) == 3 + assert len(list(struct_graph.graph.edges(data=True))) == 3 def test_str(self): square_sg_str_ref = """Structure Graph @@ -330,16 +329,16 @@ def test_mul(self): for idx in mos2_sg_mul.structure.indices_from_symbol("Mo"): assert mos2_sg_mul.get_coordination_of_site(idx) == 6 - mos2_sg_premul = StructureGraph.with_local_env_strategy(self.structure * (3, 3, 1), MinimumDistanceNN()) + mos2_sg_premul = StructureGraph.from_local_env_strategy(self.structure * (3, 3, 1), MinimumDistanceNN()) assert mos2_sg_mul == mos2_sg_premul # test 3D Structure - nio_sg = StructureGraph.with_local_env_strategy(self.NiO, MinimumDistanceNN()) - nio_sg = nio_sg * 3 + nio_struct_graph = StructureGraph.from_local_env_strategy(self.NiO, MinimumDistanceNN()) + nio_struct_graph = nio_struct_graph * 3 - for n in range(len(nio_sg)): - assert nio_sg.get_coordination_of_site(n) == 6 + for n in range(len(nio_struct_graph)): + assert nio_struct_graph.get_coordination_of_site(n) == 6 @pytest.mark.skipif( pygraphviz is None or not (which("neato") and which("fdp")), reason="graphviz executables not present" @@ -356,7 +355,7 @@ def test_draw(self): mos2_sg_2.draw_graph_to_file(f"{self.tmp_path}/MoS2_twice_mul.pdf", algo="neato", hide_image_edges=True) # draw MoS2 graph that's generated from a pre-multiplied Structure - mos2_sg_premul = StructureGraph.with_local_env_strategy(self.structure * (3, 3, 1), MinimumDistanceNN()) + mos2_sg_premul = StructureGraph.from_local_env_strategy(self.structure * (3, 3, 1), MinimumDistanceNN()) mos2_sg_premul.draw_graph_to_file(f"{self.tmp_path}/MoS2_premul.pdf", algo="neato", hide_image_edges=True) # draw graph for a square lattice @@ -397,19 +396,19 @@ def test_as_from_dict(self): assert dct == d2 def test_from_local_env_and_equality_and_diff(self): - nn = MinimumDistanceNN() - sg = StructureGraph.with_local_env_strategy(self.structure, nn) + min_dist_nn = MinimumDistanceNN() + struct_graph = StructureGraph.from_local_env_strategy(self.structure, min_dist_nn) - assert sg.graph.number_of_edges() == 6 + assert struct_graph.graph.number_of_edges() == 6 nn2 = MinimumOKeeffeNN() - sg2 = StructureGraph.with_local_env_strategy(self.structure, nn2) + sg2 = StructureGraph.from_local_env_strategy(self.structure, nn2) - assert sg == sg2 - assert sg == self.mos2_sg + assert struct_graph == sg2 + assert struct_graph == self.mos2_sg # TODO: find better test case where graphs are different - diff = sg.diff(sg2) + diff = struct_graph.diff(sg2) assert diff["dist"] == 0 assert self.square_sg.get_coordination_of_site(0) == 2 @@ -424,19 +423,19 @@ def test_from_edges(self): structure = Structure(Lattice.tetragonal(5.0, 50.0), ["H"], [[0, 0, 0]]) - sg = StructureGraph.with_edges(structure, edges) + struct_graph = StructureGraph.from_edges(structure, edges) - assert sg == self.square_sg + assert struct_graph == self.square_sg def test_extract_molecules(self): structure_file = f"{TEST_FILES_DIR}/H6PbCI3N_mp-977013_symmetrized.cif" struct = Structure.from_file(structure_file) - nn = MinimumDistanceNN() - sg = StructureGraph.with_local_env_strategy(struct, nn) + min_dist_nn = MinimumDistanceNN() + struct_graph = StructureGraph.from_local_env_strategy(struct, min_dist_nn) - molecules = sg.get_subgraphs_as_molecules() + molecules = struct_graph.get_subgraphs_as_molecules() assert molecules[0].formula == "H3 C1" assert len(molecules) == 1 @@ -470,11 +469,11 @@ def test_no_duplicate_hops(self): coords=[[0.005572, 0.994428, 0.151095]], ) - nn = MinimumDistanceNN(cutoff=6, get_all_sites=True) + min_dist_nn = MinimumDistanceNN(cutoff=6, get_all_sites=True) - sg = StructureGraph.with_local_env_strategy(test_structure, nn) + struct_graph = StructureGraph.from_local_env_strategy(test_structure, min_dist_nn) - assert sg.graph.number_of_edges() == 3 + assert struct_graph.graph.number_of_edges() == 3 def test_sort(self): sg = copy.deepcopy(self.bc_square_sg_r) @@ -487,11 +486,11 @@ def test_sort(self): assert list(sg.graph.edges)[-2:] == [(1, 3, 0), (1, 2, 0)] -class TestMoleculeGraph(unittest.TestCase): +class TestMoleculeGraph(TestCase): def setUp(self): cyclohexene_xyz = f"{TEST_FILES_DIR}/graphs/cyclohexene.xyz" cyclohexene = Molecule.from_file(cyclohexene_xyz) - self.cyclohexene = MoleculeGraph.with_empty_graph( + self.cyclohexene = MoleculeGraph.from_empty_graph( cyclohexene, edge_weight_name="strength", edge_weight_units="" ) self.cyclohexene.add_edge(0, 1, weight=1.0) @@ -512,7 +511,7 @@ def setUp(self): self.cyclohexene.add_edge(5, 15, weight=1.0) butadiene = Molecule.from_file(f"{TEST_FILES_DIR}/graphs/butadiene.xyz") - self.butadiene = MoleculeGraph.with_empty_graph(butadiene, edge_weight_name="strength", edge_weight_units="") + self.butadiene = MoleculeGraph.from_empty_graph(butadiene, edge_weight_name="strength", edge_weight_units="") self.butadiene.add_edge(0, 1, weight=2.0) self.butadiene.add_edge(1, 2, weight=1.0) self.butadiene.add_edge(2, 3, weight=2.0) @@ -524,7 +523,7 @@ def setUp(self): self.butadiene.add_edge(3, 9, weight=1.0) ethylene = Molecule.from_file(f"{TEST_FILES_DIR}/graphs/ethylene.xyz") - self.ethylene = MoleculeGraph.with_empty_graph(ethylene, edge_weight_name="strength", edge_weight_units="") + self.ethylene = MoleculeGraph.from_empty_graph(ethylene, edge_weight_name="strength", edge_weight_units="") self.ethylene.add_edge(0, 1, weight=2.0) self.ethylene.add_edge(0, 2, weight=1.0) self.ethylene.add_edge(0, 3, weight=1.0) @@ -569,8 +568,8 @@ def setUp(self): def test_construction(self): pytest.importorskip("openbabel") - edges_frag = {(e[0], e[1]): {"weight": 1.0} for e in self.pc_frag1_edges} - mol_graph = MoleculeGraph.with_edges(self.pc_frag1, edges_frag) + edges_frag = {(edge[0], edge[1]): {"weight": 1.0} for edge in self.pc_frag1_edges} + mol_graph = MoleculeGraph.from_edges(self.pc_frag1, edges_frag) # dumpfn(mol_graph.as_dict(), f"{module_dir}/pc_frag1_mg.json") ref_mol_graph = loadfn(f"{module_dir}/pc_frag1_mg.json") assert mol_graph == ref_mol_graph @@ -580,8 +579,8 @@ def test_construction(self): for ii in range(3): assert mol_graph.graph.nodes[node]["coords"][ii] == ref_mol_graph.graph.nodes[node]["coords"][ii] - edges_pc = {(e[0], e[1]): {"weight": 1.0} for e in self.pc_edges} - mol_graph = MoleculeGraph.with_edges(self.pc, edges_pc) + edges_pc = {(edge[0], edge[1]): {"weight": 1.0} for edge in self.pc_edges} + mol_graph = MoleculeGraph.from_edges(self.pc, edges_pc) # dumpfn(mol_graph.as_dict(), f"{module_dir}/pc_mg.json") ref_mol_graph = loadfn(f"{module_dir}/pc_mg.json") assert mol_graph == ref_mol_graph @@ -591,18 +590,17 @@ def test_construction(self): for ii in range(3): assert mol_graph.graph.nodes[node]["coords"][ii] == ref_mol_graph.graph.nodes[node]["coords"][ii] - mol_graph_edges = MoleculeGraph.with_edges(self.pc, edges=edges_pc) - mol_graph_strat = MoleculeGraph.with_local_env_strategy(self.pc, OpenBabelNN()) + mol_graph_edges = MoleculeGraph.from_edges(self.pc, edges=edges_pc) + mol_graph_strat = MoleculeGraph.from_local_env_strategy(self.pc, OpenBabelNN()) assert mol_graph_edges.isomorphic_to(mol_graph_strat) - # Check inappropriate strategy - non_mol_strategy = VoronoiNN() + # Check error message on using inappropriate strategy for molecules + strategy = VoronoiNN() with pytest.raises( - ValueError, - match=f"strategy='{non_mol_strategy}' is not designed for use with molecules! Choose another strategy", + ValueError, match=re.escape(f"{strategy=} is not designed for use with molecules! Choose another strategy") ): - MoleculeGraph.with_local_env_strategy(self.pc, non_mol_strategy) + MoleculeGraph.from_local_env_strategy(self.pc, strategy) def test_properties(self): assert self.cyclohexene.name == "bonds" @@ -632,7 +630,7 @@ def test_set_node_attributes(self): def test_coordination(self): molecule = Molecule(["C", "C"], [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]) - mg = MoleculeGraph.with_empty_graph(molecule) + mg = MoleculeGraph.from_empty_graph(molecule) assert mg.get_coordination_of_site(0) == 0 assert self.cyclohexene.get_coordination_of_site(0) == 4 @@ -697,7 +695,7 @@ def test_get_disconnected(self): just_he = Molecule(["He"], [[5, 5, 5]]) - dis_mg = MoleculeGraph.with_empty_graph(disconnected) + dis_mg = MoleculeGraph.from_empty_graph(disconnected) dis_mg.add_edge(0, 1) dis_mg.add_edge(0, 2) dis_mg.add_edge(0, 3) @@ -711,7 +709,7 @@ def test_get_disconnected(self): assert fragments[1].molecule == just_he assert index_map == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5} - con_mg = MoleculeGraph.with_empty_graph(no_he) + con_mg = MoleculeGraph.from_empty_graph(no_he) con_mg.add_edge(0, 1) con_mg.add_edge(0, 2) con_mg.add_edge(0, 3) @@ -743,7 +741,7 @@ def test_split(self): # Test naive charge redistribution hydroxide = Molecule(["O", "H"], [[0, 0, 0], [0.5, 0.5, 0.5]], charge=-1) - oh_mg = MoleculeGraph.with_empty_graph(hydroxide) + oh_mg = MoleculeGraph.from_empty_graph(hydroxide) oh_mg.add_edge(0, 1) @@ -766,14 +764,14 @@ def test_split(self): ], ) - diff_spec_mg = MoleculeGraph.with_empty_graph(diff_species) + diff_spec_mg = MoleculeGraph.from_empty_graph(diff_species) diff_spec_mg.add_edge(0, 1) diff_spec_mg.add_edge(0, 2) diff_spec_mg.add_edge(0, 3) diff_spec_mg.add_edge(0, 4) - for i in range(1, 5): - bond = (0, i) + for idx in range(1, 5): + bond = (0, idx) split_mgs = diff_spec_mg.split_molecule_subgraphs([bond]) for split_mg in split_mgs: @@ -784,8 +782,8 @@ def test_split(self): assert species[j] == str(atom.specie) def test_build_unique_fragments(self): - edges = {(e[0], e[1]): None for e in self.pc_edges} - mol_graph = MoleculeGraph.with_edges(self.pc, edges) + edges = {(edge[0], edge[1]): None for edge in self.pc_edges} + mol_graph = MoleculeGraph.from_edges(self.pc, edges) unique_fragment_dict = mol_graph.build_unique_fragments() unique_fragments = [fragment for key in unique_fragment_dict for fragment in unique_fragment_dict[key]] assert len(unique_fragments) == 295 @@ -832,7 +830,7 @@ def test_isomorphic(self): (0, 5): {"weight": 1}, } - ethylene_graph = MoleculeGraph.with_edges(ethylene, edges) + ethylene_graph = MoleculeGraph.from_edges(ethylene, edges) # If they are equal, they must also be isomorphic assert self.ethylene.isomorphic_to(ethylene_graph) assert not self.butadiene.isomorphic_to(self.ethylene) @@ -840,11 +838,11 @@ def test_isomorphic(self): # check fix in https://github.com/materialsproject/pymatgen/pull/3221 # by comparing graph with equal nodes but different edges edges[(1, 4)] = {"weight": 2} - assert not self.ethylene.isomorphic_to(MoleculeGraph.with_edges(ethylene, edges)) + assert not self.ethylene.isomorphic_to(MoleculeGraph.from_edges(ethylene, edges)) def test_substitute(self): molecule = FunctionalGroups["methyl"] - mol_graph = MoleculeGraph.with_edges( + mol_graph = MoleculeGraph.from_edges( molecule, {(0, 1): {"weight": 1}, (0, 2): {"weight": 1}, (0, 3): {"weight": 1}}, ) diff --git a/tests/analysis/test_hhi.py b/tests/analysis/test_hhi.py index df9230192c5..105072426fb 100644 --- a/tests/analysis/test_hhi.py +++ b/tests/analysis/test_hhi.py @@ -1,13 +1,11 @@ from __future__ import annotations -import unittest - from pytest import approx from pymatgen.analysis.hhi import HHIModel -class TestHHIModel(unittest.TestCase): +class TestHHIModel: def test_hhi(self): hhi = HHIModel() assert hhi.get_hhi("He") == (3200, 3900) diff --git a/tests/analysis/test_interface_reactions.py b/tests/analysis/test_interface_reactions.py index 6288a413311..c0030688278 100644 --- a/tests/analysis/test_interface_reactions.py +++ b/tests/analysis/test_interface_reactions.py @@ -1,6 +1,6 @@ from __future__ import annotations -import unittest +from unittest import TestCase import numpy as np import pytest @@ -17,7 +17,7 @@ from pymatgen.entries.computed_entries import ComputedEntry -class TestInterfaceReaction(unittest.TestCase): +class TestInterfaceReaction(TestCase): def setUp(self): self.entries = [ ComputedEntry(Composition("Li"), 0), diff --git a/tests/analysis/test_local_env.py b/tests/analysis/test_local_env.py index 29cc6e94796..63b8db814a3 100644 --- a/tests/analysis/test_local_env.py +++ b/tests/analysis/test_local_env.py @@ -11,9 +11,9 @@ from pymatgen.analysis.graphs import MoleculeGraph, StructureGraph from pymatgen.analysis.local_env import ( - BrunnerNN_real, - BrunnerNN_reciprocal, - BrunnerNN_relative, + BrunnerNNReal, + BrunnerNNReciprocal, + BrunnerNNRelative, CovalentBondNN, Critic2NN, CrystalNN, @@ -403,60 +403,60 @@ def test_all_nn_classes(self): for image in MinimumDistanceNN(tol=0.1).get_nn_images(self.mos2, 0): assert image in [(0, 0, 0), (0, 1, 0), (-1, 0, 0), (0, 0, 0), (0, 1, 0), (-1, 0, 0)] - okeeffe = MinimumOKeeffeNN(tol=0.01) - assert okeeffe.get_cn(self.diamond, 0) == 4 - assert okeeffe.get_cn(self.nacl, 0) == 6 - assert okeeffe.get_cn(self.cscl, 0) == 8 - assert okeeffe.get_cn(self.lifepo4, 0) == 2 + okeeffe_nn = MinimumOKeeffeNN(tol=0.01) + assert okeeffe_nn.get_cn(self.diamond, 0) == 4 + assert okeeffe_nn.get_cn(self.nacl, 0) == 6 + assert okeeffe_nn.get_cn(self.cscl, 0) == 8 + assert okeeffe_nn.get_cn(self.lifepo4, 0) == 2 - virenn = MinimumVIRENN(tol=0.01) - assert virenn.get_cn(self.diamond, 0) == 4 - assert virenn.get_cn(self.nacl, 0) == 6 - assert virenn.get_cn(self.cscl, 0) == 8 - assert virenn.get_cn(self.lifepo4, 0) == 2 + min_vire_nn = MinimumVIRENN(tol=0.01) + assert min_vire_nn.get_cn(self.diamond, 0) == 4 + assert min_vire_nn.get_cn(self.nacl, 0) == 6 + assert min_vire_nn.get_cn(self.cscl, 0) == 8 + assert min_vire_nn.get_cn(self.lifepo4, 0) == 2 - brunner_recip = BrunnerNN_reciprocal(tol=0.01) + brunner_recip = BrunnerNNReciprocal(tol=0.01) assert brunner_recip.get_cn(self.diamond, 0) == 4 assert brunner_recip.get_cn(self.nacl, 0) == 6 assert brunner_recip.get_cn(self.cscl, 0) == 14 assert brunner_recip.get_cn(self.lifepo4, 0) == 6 - brunner_rel = BrunnerNN_relative(tol=0.01) + brunner_rel = BrunnerNNRelative(tol=0.01) assert brunner_rel.get_cn(self.diamond, 0) == 4 assert brunner_rel.get_cn(self.nacl, 0) == 6 assert brunner_rel.get_cn(self.cscl, 0) == 14 assert brunner_rel.get_cn(self.lifepo4, 0) == 6 - brunner_real = BrunnerNN_real(tol=0.01) + brunner_real = BrunnerNNReal(tol=0.01) assert brunner_real.get_cn(self.diamond, 0) == 4 assert brunner_real.get_cn(self.nacl, 0) == 6 assert brunner_real.get_cn(self.cscl, 0) == 14 assert brunner_real.get_cn(self.lifepo4, 0) == 30 - econn = EconNN() - assert econn.get_cn(self.diamond, 0) == 4 - assert econn.get_cn(self.nacl, 0) == 6 - assert econn.get_cn(self.cscl, 0) == 14 - assert econn.get_cn(self.lifepo4, 0) == 6 + econ_nn = EconNN() + assert econ_nn.get_cn(self.diamond, 0) == 4 + assert econ_nn.get_cn(self.nacl, 0) == 6 + assert econ_nn.get_cn(self.cscl, 0) == 14 + assert econ_nn.get_cn(self.lifepo4, 0) == 6 - voroinn = VoronoiNN(tol=0.5) - assert voroinn.get_cn(self.diamond, 0) == 4 - assert voroinn.get_cn(self.nacl, 0) == 6 - assert voroinn.get_cn(self.cscl, 0) == 8 - assert voroinn.get_cn(self.lifepo4, 0) == 6 + voronoi_nn = VoronoiNN(tol=0.5) + assert voronoi_nn.get_cn(self.diamond, 0) == 4 + assert voronoi_nn.get_cn(self.nacl, 0) == 6 + assert voronoi_nn.get_cn(self.cscl, 0) == 8 + assert voronoi_nn.get_cn(self.lifepo4, 0) == 6 - crystalnn = CrystalNN() - assert crystalnn.get_cn(self.diamond, 0) == 4 - assert crystalnn.get_cn(self.nacl, 0) == 6 - assert crystalnn.get_cn(self.cscl, 0) == 8 - assert crystalnn.get_cn(self.lifepo4, 0) == 6 + crystal_nn = CrystalNN() + assert crystal_nn.get_cn(self.diamond, 0) == 4 + assert crystal_nn.get_cn(self.nacl, 0) == 6 + assert crystal_nn.get_cn(self.cscl, 0) == 8 + assert crystal_nn.get_cn(self.lifepo4, 0) == 6 def test_get_local_order_params(self): - nn = MinimumDistanceNN() - ops = nn.get_local_order_parameters(self.diamond, 0) + min_dist_nn = MinimumDistanceNN() + ops = min_dist_nn.get_local_order_parameters(self.diamond, 0) assert ops["tetrahedral"] == approx(0.9999934389036574) - ops = nn.get_local_order_parameters(self.nacl, 0) + ops = min_dist_nn.get_local_order_parameters(self.nacl, 0) assert ops["octahedral"] == approx(0.9999995266669) @@ -534,18 +534,18 @@ def setUp(self): ) def test_site_is_of_motif_type(self): - for i in range(len(self.diamond)): - assert site_is_of_motif_type(self.diamond, i) == "tetrahedral" - for i in range(len(self.nacl)): - assert site_is_of_motif_type(self.nacl, i) == "octahedral" - for i in range(len(self.cscl)): - assert site_is_of_motif_type(self.cscl, i) == "bcc" + for idx in range(len(self.diamond)): + assert site_is_of_motif_type(self.diamond, idx) == "tetrahedral" + for idx in range(len(self.nacl)): + assert site_is_of_motif_type(self.nacl, idx) == "octahedral" + for idx in range(len(self.cscl)): + assert site_is_of_motif_type(self.cscl, idx) == "bcc" assert site_is_of_motif_type(self.square_pyramid, 0) == "square pyramidal" - for i in range(1, len(self.square_pyramid)): - assert site_is_of_motif_type(self.square_pyramid, i) == "unrecognized" + for idx in range(1, len(self.square_pyramid)): + assert site_is_of_motif_type(self.square_pyramid, idx) == "unrecognized" assert site_is_of_motif_type(self.trigonal_bipyramid, 0) == "trigonal bipyramidal" - for i in range(1, len(self.trigonal_bipyramid)): - assert site_is_of_motif_type(self.trigonal_bipyramid, i) == "unrecognized" + for idx in range(1, len(self.trigonal_bipyramid)): + assert site_is_of_motif_type(self.trigonal_bipyramid, idx) == "unrecognized" def test_get_neighbors_of_site_with_index(self): assert len(get_neighbors_of_site_with_index(self.diamond, 0)) == 4 @@ -1003,15 +1003,10 @@ def test_get_order_parameters(self): "tet_max", "sq_face_cap_trig_pris", ] - op_params = [None for i in range(len(op_types))] + op_params = [None] * len(op_types) op_params[1] = {"TA": 1, "IGW_TA": 1.0 / 0.0667} op_params[2] = {"TA": 45.0 / 180, "IGW_TA": 1.0 / 0.0667} - op_params[33] = { - "TA": 0.6081734479693927, - "IGW_TA": 18.33, - "fac_AA": 1.5, - "exp_cos_AA": 2, - } + op_params[33] = {"TA": 0.6081734479693927, "IGW_TA": 18.33, "fac_AA": 1.5, "exp_cos_AA": 2} ops_044 = LocalStructOrderParams(op_types, parameters=op_params, cutoff=0.44) ops_071 = LocalStructOrderParams(op_types, parameters=op_params, cutoff=0.71) ops_087 = LocalStructOrderParams(op_types, parameters=op_params, cutoff=0.87) @@ -1344,7 +1339,7 @@ class TestMetalEdgeExtender(PymatgenTest): def setUp(self): self.LiEC = Molecule.from_file(f"{TEST_DIR}/LiEC.xyz") self.phsh = Molecule.from_file(f"{TEST_DIR}/phsh.xyz") - self.phsh_graph = MoleculeGraph.with_edges( + self.phsh_graph = MoleculeGraph.from_edges( molecule=self.phsh, edges={ (0, 1): None, @@ -1374,7 +1369,7 @@ def setUp(self): (21, 24): None, }, ) - self.LiEC_graph = MoleculeGraph.with_edges( + self.LiEC_graph = MoleculeGraph.from_edges( molecule=self.LiEC, edges={ (0, 2): None, @@ -1396,7 +1391,7 @@ def setUp(self): K_sites = [s.coords for s in uncharged_K_cluster] K_species = [s.species for s in uncharged_K_cluster] charged_K_cluster = Molecule(K_species, K_sites, charge=1) - self.water_cluster_K = MoleculeGraph.with_empty_graph(charged_K_cluster) + self.water_cluster_K = MoleculeGraph.from_empty_graph(charged_K_cluster) assert len(self.water_cluster_K.graph.edges) == 0 # Mg + 6 H2O at 1.94 Ang from Mg @@ -1404,7 +1399,7 @@ def setUp(self): Mg_sites = [s.coords for s in uncharged_Mg_cluster] Mg_species = [s.species for s in uncharged_Mg_cluster] charged_Mg_cluster = Molecule(Mg_species, Mg_sites, charge=2) - self.water_cluster_Mg = MoleculeGraph.with_empty_graph(charged_Mg_cluster) + self.water_cluster_Mg = MoleculeGraph.from_empty_graph(charged_Mg_cluster) def test_metal_edge_extender(self): assert len(self.LiEC_graph.graph.edges) == 11 diff --git a/tests/analysis/test_molecule_matcher.py b/tests/analysis/test_molecule_matcher.py index 783c4ddcea8..095e826c0b6 100644 --- a/tests/analysis/test_molecule_matcher.py +++ b/tests/analysis/test_molecule_matcher.py @@ -1,6 +1,6 @@ from __future__ import annotations -import unittest +from unittest import TestCase import numpy as np import pytest @@ -147,7 +147,7 @@ def generate_Si2O_cluster(): @pytest.mark.skipif(ob_align_missing, reason="OBAlign is missing, Skipping") -class TestMoleculeMatcher(unittest.TestCase): +class TestMoleculeMatcher: def test_fit(self): self.fit_with_mapper(IsomorphismMolAtomMapper()) self.fit_with_mapper(InchiMolAtomMapper()) @@ -261,7 +261,7 @@ def test_cdi_23(self): assert not mol_matcher.fit(mol1, mol2) -class TestKabschMatcher(unittest.TestCase): +class TestKabschMatcher: def test_get_rmsd(self): mol1 = Molecule.from_file(f"{TEST_DIR}/t3.xyz") mol2 = Molecule.from_file(f"{TEST_DIR}/t4.xyz") @@ -335,7 +335,7 @@ def test_fit(self): assert rmsd == approx(0, abs=6) -class TestHungarianOrderMatcher(unittest.TestCase): +class TestHungarianOrderMatcher: def test_get_rmsd(self): mol1 = Molecule.from_file(f"{TEST_DIR}/t3.xyz") mol2 = Molecule.from_file(f"{TEST_DIR}/t4.xyz") @@ -440,7 +440,7 @@ def test_fit(self): assert rmsd == approx(0, abs=6) -class TestGeneticOrderMatcher(unittest.TestCase): +class TestGeneticOrderMatcher: def test_get_rmsd(self): mol1 = Molecule.from_file(f"{TEST_DIR}/t3.xyz") mol2 = Molecule.from_file(f"{TEST_DIR}/t4.xyz") @@ -545,7 +545,7 @@ def test_fit(self): assert rmsd == approx(0, abs=6) -class TestKabschMatcherSi(unittest.TestCase): +class TestKabschMatcherSi(TestCase): @classmethod def setUpClass(cls): cls.mol1 = Molecule.from_file(f"{TEST_DIR}/Si_cluster.xyz") @@ -581,7 +581,7 @@ def test_permuted_atoms_order(self): assert rmsd == approx(2.7962454578966454, abs=1e-6) -class TestBruteForceOrderMatcherSi(unittest.TestCase): +class TestBruteForceOrderMatcherSi(TestCase): @classmethod def setUpClass(cls): cls.mol1 = Molecule.from_file(f"{TEST_DIR}/Si_cluster.xyz") @@ -602,7 +602,7 @@ def test_random_match(self): self.mol_matcher.fit(mol2) -class TestHungarianOrderMatcherSi(unittest.TestCase): +class TestHungarianOrderMatcherSi(TestCase): @classmethod def setUpClass(cls): cls.mol1 = Molecule.from_file(f"{TEST_DIR}/Si_cluster.xyz") @@ -641,7 +641,7 @@ def test_random_match(self): assert rmsd == approx(1.0177241485450828, abs=1e-6) -class TestGeneticOrderMatcherSi(unittest.TestCase): +class TestGeneticOrderMatcherSi(TestCase): @classmethod def setUpClass(cls): cls.mol1 = Molecule.from_file(f"{TEST_DIR}/Si_cluster.xyz") @@ -678,7 +678,7 @@ def test_random_match(self): assert res[0][-1] == approx(0.22163169511782, abs=1e-6) -class TestKabschMatcherSi2O(unittest.TestCase): +class TestKabschMatcherSi2O(TestCase): @classmethod def setUpClass(cls): cls.mol1 = Molecule.from_file(f"{TEST_DIR}/Si2O_cluster.xyz") @@ -711,7 +711,7 @@ def test_permuted_atoms_order(self): self.mol_matcher.fit(mol2) -class TestBruteForceOrderMatcherSi2O(unittest.TestCase): +class TestBruteForceOrderMatcherSi2O(TestCase): @classmethod def setUpClass(cls): cls.mol1 = Molecule.from_file(f"{TEST_DIR}/Si2O_cluster.xyz") @@ -743,7 +743,7 @@ def test_random_match(self): assert rmsd == approx(0.23051587697194997, abs=1e-6) -class TestHungarianOrderMatcherSi2O(unittest.TestCase): +class TestHungarianOrderMatcherSi2O(TestCase): @classmethod def setUpClass(cls): cls.mol1 = Molecule.from_file(f"{TEST_DIR}/Si2O_cluster.xyz") @@ -775,7 +775,7 @@ def test_random_match(self): assert rmsd == approx(0.23231038877573124, abs=1e-6) -class TestGeneticOrderMatcherSi2O(unittest.TestCase): +class TestGeneticOrderMatcherSi2O(TestCase): @classmethod def setUpClass(cls): cls.mol1 = Molecule.from_file(f"{TEST_DIR}/Si2O_cluster.xyz") diff --git a/tests/analysis/test_phase_diagram.py b/tests/analysis/test_phase_diagram.py index 479bb8dadae..669d8a89efc 100644 --- a/tests/analysis/test_phase_diagram.py +++ b/tests/analysis/test_phase_diagram.py @@ -5,6 +5,7 @@ import unittest import unittest.mock from numbers import Number +from unittest import TestCase import matplotlib.pyplot as plt import numpy as np @@ -36,7 +37,7 @@ module_dir = os.path.dirname(os.path.abspath(__file__)) -class TestPDEntry(unittest.TestCase): +class TestPDEntry(TestCase): def setUp(self): comp = Composition("LiFeO2") self.entry = PDEntry(comp, 53, name="mp-757614") @@ -112,7 +113,7 @@ def test_read_csv(self): assert len(entries) == 490, "Wrong number of entries!" -class TestTransformedPDEntry(unittest.TestCase): +class TestTransformedPDEntry(TestCase): def setUp(self): comp = Composition("LiFeO2") entry = PDEntry(comp, 53) @@ -350,7 +351,7 @@ def test_get_equilibrium_reaction_energy(self): def test_get_phase_separation_energy(self): for entry in self.pd.unstable_entries: if entry.composition.fractional_composition not in [ - e.composition.fractional_composition for e in self.pd.stable_entries + entry.composition.fractional_composition for entry in self.pd.stable_entries ]: assert ( self.pd.get_phase_separation_energy(entry) >= 0 @@ -398,7 +399,7 @@ def test_get_phase_separation_energy(self): duplicate_entry = PDEntry("Li2O", -14.31361175) scaled_dup_entry = PDEntry("Li4O2", -14.31361175 * 2) - stable_entry = next(e for e in self.pd.stable_entries if e.name == "Li2O") + stable_entry = next(entry for entry in self.pd.stable_entries if entry.name == "Li2O") assert self.pd.get_phase_separation_energy(duplicate_entry) == self.pd.get_phase_separation_energy( stable_entry @@ -489,8 +490,8 @@ def test_get_hull_energy_per_atom(self): def test_1d_pd(self): entry = PDEntry("H", 0) pd = PhaseDiagram([entry]) - decomp, e = pd.get_decomp_and_e_above_hull(PDEntry("H", 1)) - assert e == 1 + decomp, e_above_hull = pd.get_decomp_and_e_above_hull(PDEntry("H", 1)) + assert e_above_hull == 1 assert decomp[entry] == approx(1.0) def test_get_critical_compositions_fractional(self): @@ -639,7 +640,7 @@ def test_val_err_on_no_entries(self): PhaseDiagram(entries=entries) -class TestGrandPotentialPhaseDiagram(unittest.TestCase): +class TestGrandPotentialPhaseDiagram(TestCase): def setUp(self): self.entries = EntrySet.from_csv(f"{module_dir}/pd_entries_test.csv") self.pd = GrandPotentialPhaseDiagram(self.entries, {Element("O"): -5}) @@ -675,7 +676,7 @@ def test_str(self): ) -class TestCompoundPhaseDiagram(unittest.TestCase): +class TestCompoundPhaseDiagram(TestCase): def setUp(self): self.entries = EntrySet.from_csv(f"{module_dir}/pd_entries_test.csv") self.pd = CompoundPhaseDiagram(self.entries, [Composition("Li2O"), Composition("Fe2O3")]) @@ -701,7 +702,7 @@ def test_str(self): assert str(self.pd) == "Xf-Xg phase diagram\n4 stable phases: \nLiFeO2, Li2O, Li5FeO4, Fe2O3" -class TestPatchedPhaseDiagram(unittest.TestCase): +class TestPatchedPhaseDiagram(TestCase): def setUp(self): self.entries = EntrySet.from_csv(f"{module_dir}/reaction_entries_test.csv") # NOTE add He to test for correct behavior despite no patches involving He @@ -832,14 +833,14 @@ def test_setitem_and_delitem(self): del self.ppd[unlikely_chem_space] # test __delitem__() and restore original state -class TestReactionDiagram(unittest.TestCase): +class TestReactionDiagram(TestCase): def setUp(self): self.entries = list(EntrySet.from_csv(f"{module_dir}/reaction_entries_test.csv").entries) - for e in self.entries: - if e.reduced_formula == "VPO5": - entry1 = e - elif e.reduced_formula == "H4(CO)3": - entry2 = e + for entry in self.entries: + if entry.reduced_formula == "VPO5": + entry1 = entry + elif entry.reduced_formula == "H4(CO)3": + entry2 = entry self.rd = ReactionDiagram(entry1=entry1, entry2=entry2, all_entries=self.entries[2:]) def test_get_compound_pd(self): @@ -852,7 +853,7 @@ def test_formula(self): assert Element.C in entry.composition assert Element.P in entry.composition assert Element.H in entry.composition - # formed_formula = [e.reduced_formula for e in self.rd.rxn_entries] + # formed_formula = [entry.reduced_formula for entry in self.rd.rxn_entries] # expected_formula = [ # "V0.12707182P0.12707182H0.0441989C0.03314917O0.66850829", # "V0.125P0.125H0.05C0.0375O0.6625", @@ -872,15 +873,15 @@ def test_formula(self): # assert formula in formed_formula, f"{formed_formula=} not in {expected_formula=}" -class TestPDPlotter(unittest.TestCase): +class TestPDPlotter(TestCase): def setUp(self): entries = list(EntrySet.from_csv(f"{module_dir}/pd_entries_test.csv")) - elemental_entries = [e for e in entries if e.elements == [Element("Li")]] + elemental_entries = [entry for entry in entries if entry.elements == [Element("Li")]] self.pd_unary = PhaseDiagram(elemental_entries) self.plotter_unary_plotly = PDPlotter(self.pd_unary, backend="plotly") - entries_LiO = [e for e in entries if "Fe" not in e.composition] + entries_LiO = [entry for entry in entries if "Fe" not in entry.composition] self.pd_binary = PhaseDiagram(entries_LiO) self.plotter_binary_mpl = PDPlotter(self.pd_binary, backend="matplotlib") self.plotter_binary_plotly = PDPlotter(self.pd_binary, backend="plotly") @@ -951,7 +952,7 @@ def test_plotly_plots(self): mock_show.assert_called_once() -class TestUtilityFunction(unittest.TestCase): +class TestUtilityFunction: def test_unique_lines(self): testdata = [ [5, 53, 353], diff --git a/tests/analysis/test_piezo_sensitivity.py b/tests/analysis/test_piezo_sensitivity.py index c2a5b64648d..753750375cb 100644 --- a/tests/analysis/test_piezo_sensitivity.py +++ b/tests/analysis/test_piezo_sensitivity.py @@ -156,24 +156,24 @@ def test_get_asum_fcm(self): rand_FCM = fcm.get_asum_FCM(rand_FCM) rand_FCM = np.reshape(rand_FCM, (10, 3, 10, 3)).swapaxes(1, 2) - for i in range(len(self.FCM_operations)): - for j in range(len(self.FCM_operations[i][4])): + for ii in range(len(self.FCM_operations)): + for jj in range(len(self.FCM_operations[ii][4])): assert_allclose( - self.FCM_operations[i][4][j].transform_tensor( - rand_FCM[self.FCM_operations[i][2]][self.FCM_operations[i][3]] + self.FCM_operations[ii][4][jj].transform_tensor( + rand_FCM[self.FCM_operations[ii][2]][self.FCM_operations[ii][3]] ), - rand_FCM[self.FCM_operations[i][0]][self.FCM_operations[i][1]], + rand_FCM[self.FCM_operations[ii][0]][self.FCM_operations[ii][1]], atol=1e-4, ) - for i in range(len(rand_FCM)): - asum1 = np.zeros([3, 3]) - asum2 = np.zeros([3, 3]) - for j in range(len(rand_FCM[i])): - asum1 += rand_FCM[i][j] - asum2 += rand_FCM[j][i] - assert_allclose(asum1, np.zeros([3, 3]), atol=1e-5) - assert_allclose(asum2, np.zeros([3, 3]), atol=1e-5) + for ii in range(len(rand_FCM)): + sum1 = np.zeros([3, 3]) + sum2 = np.zeros([3, 3]) + for jj in range(len(rand_FCM[ii])): + sum1 += rand_FCM[ii][jj] + sum2 += rand_FCM[jj][ii] + assert_allclose(sum1, np.zeros([3, 3]), atol=1e-5) + assert_allclose(sum2, np.zeros([3, 3]), atol=1e-5) def test_get_stable_fcm(self): fcm = ForceConstantMatrix(self.piezo_struct, self.FCM, self.point_ops, self.shared_ops) @@ -181,10 +181,10 @@ def test_get_stable_fcm(self): rand_FCM = fcm.get_unstable_FCM() rand_FCM1 = fcm.get_stable_FCM(rand_FCM) - eigs, _vecs = np.linalg.eig(rand_FCM1) - eigsort = np.argsort(np.abs(eigs)) - for i in range(3, len(eigs)): - assert eigs[eigsort[i]] < 1e-6 + eig_vals, _vecs = np.linalg.eig(rand_FCM1) + eig_sorted = np.argsort(np.abs(eig_vals)) + for i in range(3, len(eig_vals)): + assert eig_vals[eig_sorted[i]] < 1e-6 rand_FCM1 = np.reshape(rand_FCM1, (10, 3, 10, 3)).swapaxes(1, 2) @@ -227,10 +227,10 @@ def test_rand_fcm(self): dyn_mass[m][n] = dyn[m][n] / np.sqrt(masses[m]) / np.sqrt(masses[n]) dyn_mass = np.reshape(np.swapaxes(dyn_mass, 1, 2), (10 * 3, 10 * 3)) - eigs, _vecs = np.linalg.eig(dyn_mass) - eigsort = np.argsort(np.abs(eigs)) - for i in range(3, len(eigs)): - assert eigs[eigsort[i]] < 1e-6 + eig_vals, _eig_vecs = np.linalg.eig(dyn_mass) + eig_sort = np.argsort(np.abs(eig_vals)) + for i in range(3, len(eig_vals)): + assert eig_vals[eig_sort[i]] < 1e-6 # rand_FCM1 = np.reshape(rand_FCM1, (10,3,10,3)).swapaxes(1,2) dyn_mass = np.reshape(dyn_mass, (10, 3, 10, 3)).swapaxes(1, 2) @@ -263,19 +263,19 @@ def test_rand_piezo(self): self.piezo_struct, self.point_ops, self.shared_ops, self.BEC, self.IST, self.FCM ) - for i in range(len(self.BEC_operations)): - for j in range(len(self.BEC_operations[i][2])): + for ii in range(len(self.BEC_operations)): + for jj in range(len(self.BEC_operations[ii][2])): assert_allclose( - rand_BEC[self.BEC_operations[i][0]], - self.BEC_operations[i][2][j].transform_tensor(rand_BEC[self.BEC_operations[i][1]]), + rand_BEC[self.BEC_operations[ii][0]], + self.BEC_operations[ii][2][jj].transform_tensor(rand_BEC[self.BEC_operations[ii][1]]), atol=1e-3, ) - for i in range(len(self.IST_operations)): - for j in range(len(self.IST_operations[i])): + for ii in range(len(self.IST_operations)): + for jj in range(len(self.IST_operations[ii])): assert_allclose( - rand_IST[i], - self.IST_operations[i][j][1].transform_tensor(rand_IST[self.IST_operations[i][j][0]]), + rand_IST[ii], + self.IST_operations[ii][jj][1].transform_tensor(rand_IST[self.IST_operations[ii][jj][0]]), atol=1e-3, ) @@ -294,28 +294,28 @@ def test_rand_piezo(self): dyn_mass[m][n] = dyn[m][n] / np.sqrt(masses[m]) / np.sqrt(masses[n]) dyn_mass = np.reshape(np.swapaxes(dyn_mass, 1, 2), (10 * 3, 10 * 3)) - eigs, _eig_vecs = np.linalg.eig(dyn_mass) - eig_sorted = np.argsort(np.abs(eigs)) - for i in range(3, len(eigs)): - assert eigs[eig_sorted[i]] < 1e-6 + eig_vals, _eig_vecs = np.linalg.eig(dyn_mass) + eig_sort = np.argsort(np.abs(eig_vals)) + for idx in range(3, len(eig_vals)): + assert eig_vals[eig_sort[idx]] < 1e-6 # rand_FCM1 = np.reshape(rand_FCM1, (10,3,10,3)).swapaxes(1,2) dyn_mass = np.reshape(dyn_mass, (10, 3, 10, 3)).swapaxes(1, 2) - for i in range(len(self.FCM_operations)): - for j in range(len(self.FCM_operations[i][4])): + for ii in range(len(self.FCM_operations)): + for jj in range(len(self.FCM_operations[ii][4])): assert_allclose( - self.FCM_operations[i][4][j].transform_tensor( - dyn_mass[self.FCM_operations[i][2]][self.FCM_operations[i][3]] + self.FCM_operations[ii][4][jj].transform_tensor( + dyn_mass[self.FCM_operations[ii][2]][self.FCM_operations[ii][3]] ), - dyn_mass[self.FCM_operations[i][0]][self.FCM_operations[i][1]], + dyn_mass[self.FCM_operations[ii][0]][self.FCM_operations[ii][1]], atol=1e-4, ) - for i in range(len(dyn_mass)): + for ii in range(len(dyn_mass)): asum1 = np.zeros([3, 3]) asum2 = np.zeros([3, 3]) - for j in range(len(dyn_mass[i])): - asum1 += dyn_mass[i][j] - asum2 += dyn_mass[j][i] + for jj in range(len(dyn_mass[ii])): + asum1 += dyn_mass[ii][jj] + asum2 += dyn_mass[jj][ii] assert_allclose(asum1, np.zeros([3, 3]), atol=1e-5) assert_allclose(asum2, np.zeros([3, 3]), atol=1e-5) diff --git a/tests/analysis/test_pourbaix_diagram.py b/tests/analysis/test_pourbaix_diagram.py index c6990ac9401..7273c90caf8 100644 --- a/tests/analysis/test_pourbaix_diagram.py +++ b/tests/analysis/test_pourbaix_diagram.py @@ -2,7 +2,7 @@ import logging import multiprocessing -import unittest +from unittest import TestCase import matplotlib.pyplot as plt import numpy as np @@ -98,7 +98,7 @@ def test_get_elt_fraction(self): assert pb_entry.get_element_fraction("Mn") == approx(0.4) -class TestPourbaixDiagram(unittest.TestCase): +class TestPourbaixDiagram(TestCase): @classmethod def setUpClass(cls): cls.test_data = loadfn(f"{TEST_FILES_DIR}/pourbaix_test_data.json") @@ -106,7 +106,7 @@ def setUpClass(cls): cls.pbx_no_filter = PourbaixDiagram(cls.test_data["Zn"], filter_solids=False) def test_pourbaix_diagram(self): - assert {e.name for e in self.pbx.stable_entries} == { + assert {entry.name for entry in self.pbx.stable_entries} == { "ZnO(s)", "Zn[2+]", "ZnHO2[-]", @@ -114,7 +114,7 @@ def test_pourbaix_diagram(self): "Zn(s)", }, "List of stable entries does not match" - assert {e.name for e in self.pbx_no_filter.stable_entries} == { + assert {entry.name for entry in self.pbx_no_filter.stable_entries} == { "ZnO(s)", "Zn[2+]", "ZnHO2[-]", @@ -124,8 +124,8 @@ def test_pourbaix_diagram(self): "ZnH(s)", }, "List of stable entries for unfiltered pbx does not match" - pbx_lowconc = PourbaixDiagram(self.test_data["Zn"], conc_dict={"Zn": 1e-8}, filter_solids=True) - assert {e.name for e in pbx_lowconc.stable_entries} == { + pbx_low_conc = PourbaixDiagram(self.test_data["Zn"], conc_dict={"Zn": 1e-8}, filter_solids=True) + assert {entry.name for entry in pbx_low_conc.stable_entries} == { "Zn(HO)2(aq)", "Zn[2+]", "ZnHO2[-]", @@ -138,9 +138,9 @@ def test_properties(self): def test_multicomponent(self): # Assure no ions get filtered at high concentration - ag_n = [e for e in self.test_data["Ag-Te-N"] if "Te" not in e.composition] + ag_n = [entry for entry in self.test_data["Ag-Te-N"] if "Te" not in entry.composition] highconc = PourbaixDiagram(ag_n, filter_solids=True, conc_dict={"Ag": 1e-5, "N": 1}) - entry_sets = [set(e.entry_id) for e in highconc.stable_entries] + entry_sets = [set(entry.entry_id) for entry in highconc.stable_entries] assert {"mp-124", "ion-17"} in entry_sets # Binary system @@ -256,7 +256,7 @@ def test_solid_filter(self): def test_serialization(self): dct = self.pbx.as_dict() new = PourbaixDiagram.from_dict(dct) - assert {e.name for e in new.stable_entries} == { + assert {entry.name for entry in new.stable_entries} == { "ZnO(s)", "Zn[2+]", "ZnHO2[-]", @@ -268,7 +268,7 @@ def test_serialization(self): # previously filtered entries being included dct = self.pbx_no_filter.as_dict() new = PourbaixDiagram.from_dict(dct) - assert {e.name for e in new.stable_entries} == { + assert {entry.name for entry in new.stable_entries} == { "ZnO(s)", "Zn[2+]", "ZnHO2[-]", @@ -288,7 +288,7 @@ def test_serialization(self): assert len(pd_binary.stable_entries) == len(new_binary.stable_entries) -class TestPourbaixPlotter(unittest.TestCase): +class TestPourbaixPlotter(TestCase): def setUp(self): self.test_data = loadfn(f"{TEST_FILES_DIR}/pourbaix_test_data.json") self.pd = PourbaixDiagram(self.test_data["Zn"]) diff --git a/tests/analysis/test_quasiharmonic_debye_approx.py b/tests/analysis/test_quasi_harmonic_debye_approx.py similarity index 95% rename from tests/analysis/test_quasiharmonic_debye_approx.py rename to tests/analysis/test_quasi_harmonic_debye_approx.py index ae0a7fb4cca..9a9aa86bce5 100644 --- a/tests/analysis/test_quasiharmonic_debye_approx.py +++ b/tests/analysis/test_quasi_harmonic_debye_approx.py @@ -1,18 +1,18 @@ from __future__ import annotations -import unittest +from unittest import TestCase import numpy as np from numpy.testing import assert_allclose from pymatgen.analysis.eos import EOS -from pymatgen.analysis.quasiharmonic import QuasiharmonicDebyeApprox +from pymatgen.analysis.quasiharmonic import QuasiHarmonicDebyeApprox from pymatgen.core.structure import Structure __author__ = "Kiran Mathew" -class TestQuasiharmociDebyeApprox(unittest.TestCase): +class TestQuasiHarmonicDebyeApprox(TestCase): def setUp(self): struct = Structure.from_dict( { @@ -95,7 +95,7 @@ def setUp(self): ] self.eos = "vinet" self.T = 300 - self.qhda = QuasiharmonicDebyeApprox( + self.qhda = QuasiHarmonicDebyeApprox( self.energies, self.volumes, struct, @@ -137,7 +137,7 @@ def test_vibrational_free_energy(self): assert_allclose(A, 0.494687, atol=1e-3) -class TestAnharmonicQuasiharmociDebyeApprox(unittest.TestCase): +class TestAnharmonicQuasiHarmonicDebyeApprox(TestCase): def setUp(self): struct = Structure.from_str( """FCC Al @@ -172,7 +172,7 @@ def setUp(self): ] self.eos = "vinet" self.T = 500 - self.qhda = QuasiharmonicDebyeApprox( + self.qhda = QuasiHarmonicDebyeApprox( self.energies, self.volumes, struct, diff --git a/tests/analysis/test_quasirrho.py b/tests/analysis/test_quasirrho.py index 1e3da6dfcdb..798ca867396 100644 --- a/tests/analysis/test_quasirrho.py +++ b/tests/analysis/test_quasirrho.py @@ -1,6 +1,6 @@ from __future__ import annotations -import unittest +from unittest import TestCase import pytest @@ -10,7 +10,7 @@ from pymatgen.util.testing import TEST_FILES_DIR -class TestQuasiRRHO(unittest.TestCase): +class TestQuasiRRHO(TestCase): """Test class for QuasiRRHO""" def setUp(self): @@ -48,13 +48,13 @@ def test_rrho_manual(self): """ Test manual input creation. Values from GaussianOutput """ - e = self.gout.final_energy + e_final = self.gout.final_energy mol = self.gout.final_structure - vib_freqs = [f["frequency"] for f in self.gout.frequencies[-1]] + vib_freqs = [freq["frequency"] for freq in self.gout.frequencies[-1]] correct_g = -884.776886 correct_stot = 141.584080 - qrrho = QuasiRRHO(mol=mol, energy=e, frequencies=vib_freqs, mult=1) + qrrho = QuasiRRHO(mol=mol, energy=e_final, frequencies=vib_freqs, mult=1) assert correct_stot == pytest.approx(qrrho.entropy_quasiRRHO, 0.1), "Incorrect total entropy" assert correct_g == pytest.approx(qrrho.free_energy_quasiRRHO), "Incorrect Quasi-RRHO free energy" diff --git a/tests/analysis/test_reaction_calculator.py b/tests/analysis/test_reaction_calculator.py index 90efd95c7ae..dcb06819574 100644 --- a/tests/analysis/test_reaction_calculator.py +++ b/tests/analysis/test_reaction_calculator.py @@ -1,8 +1,8 @@ from __future__ import annotations -import unittest from collections import defaultdict from math import isnan +from unittest import TestCase import numpy as np import pytest @@ -13,7 +13,7 @@ from pymatgen.entries.computed_entries import ComputedEntry -class TestReaction(unittest.TestCase): +class TestReaction: def test_init(self): reactants = [Composition("Fe"), Composition("O2")] products = [Composition("Fe2O3")] @@ -287,7 +287,7 @@ def test_underdetermined_reactants(self): assert str(rxn) == "LiMnCl3 + 3 LiCl + MnCl2 -> 2 Li2MnCl4" -class TestBalancedReaction(unittest.TestCase): +class TestBalancedReaction(TestCase): def setUp(self) -> None: rct = {"K2SO4": 3, "Na2S": 1, "Li": 24} prod = {"KNaS": 2, "K2S": 2, "Li2O": 12} @@ -334,7 +334,7 @@ def test_hash(self): assert hash(self.rxn) == 4774511606373046513 -class TestComputedReaction(unittest.TestCase): +class TestComputedReaction(TestCase): def setUp(self): dct = [ { @@ -372,7 +372,7 @@ def setUp(self): self.rxn = ComputedReaction(reactants, prods) - def test_calculated_reaction_energy(self): + def test_nd_reaction_energy(self): assert self.rxn.calculated_reaction_energy == approx(-5.60748821935) def test_calculated_reaction_energy_uncertainty(self): @@ -432,7 +432,7 @@ def test_calculated_reaction_energy_uncertainty(self): "correction": -1.864, }, ] - entries = [ComputedEntry.from_dict(e) for e in d] + entries = [ComputedEntry.from_dict(entry) for entry in d] reactants = list(filter(lambda e: e.reduced_formula in ["Li", "O2"], entries)) prods = list(filter(lambda e: e.reduced_formula == "Li2O2", entries)) @@ -447,7 +447,7 @@ def test_calculated_reaction_energy_uncertainty_for_no_uncertainty(self): def test_calculated_reaction_energy_uncertainty_for_nan(self): # test that reaction_energy_uncertainty property is nan when the uncertainty # for any product/reactant is nan - d = [ + dicts = [ { "correction": 0, "data": {}, @@ -503,7 +503,7 @@ def test_calculated_reaction_energy_uncertainty_for_nan(self): "correction": -1.864, }, ] - entries = [ComputedEntry.from_dict(e) for e in d] + entries = [ComputedEntry.from_dict(entry) for entry in dicts] reactants = list(filter(lambda e: e.reduced_formula in ["Li", "O2"], entries)) prods = list(filter(lambda e: e.reduced_formula == "Li2O2", entries)) diff --git a/tests/analysis/test_structure_analyzer.py b/tests/analysis/test_structure_analyzer.py index 11a7c65243d..12a0c02f991 100644 --- a/tests/analysis/test_structure_analyzer.py +++ b/tests/analysis/test_structure_analyzer.py @@ -1,6 +1,6 @@ from __future__ import annotations -import unittest +from unittest import TestCase import numpy as np from numpy.testing import assert_allclose @@ -23,21 +23,21 @@ class TestVoronoiAnalyzer(PymatgenTest): def setUp(self): - self.ss = Xdatcar(f"{VASP_OUT_DIR}/XDATCAR.MD").structures - self.s = self.ss[1] + self.structs = Xdatcar(f"{VASP_OUT_DIR}/XDATCAR.MD").structures + self.struct = self.structs[1] self.va = VoronoiAnalyzer(cutoff=4) def test_analyze(self): # Check for the Voronoi index of site i in Structure - single_structure = self.va.analyze(self.s, n=5) + single_structure = self.va.analyze(self.struct, n=5) assert single_structure.view() in np.array([4, 3, 3, 4, 2, 2, 1, 0]).view(), "Cannot find the right polyhedron." # Check for the presence of a Voronoi index and its frequency in # a ensemble (list) of Structures - ensemble = self.va.analyze_structures(self.ss, step_freq=2, most_frequent_polyhedra=10) + ensemble = self.va.analyze_structures(self.structs, step_freq=2, most_frequent_polyhedra=10) assert ("[1 3 4 7 1 0 0 0]", 3) in ensemble, "Cannot find the right polyhedron in ensemble." -class TestRelaxationAnalyzer(unittest.TestCase): +class TestRelaxationAnalyzer(TestCase): def setUp(self): s1 = Structure.from_file(f"{VASP_IN_DIR}/POSCAR_Li2O") s2 = Structure.from_file(f"{VASP_OUT_DIR}/CONTCAR_Li2O") @@ -95,16 +95,16 @@ def test_solid_angle(self): assert solid_angle(center, coords) == approx(1.83570965938, abs=1e-7), "Wrong result returned by solid_angle" def test_contains_peroxide(self): - for f in ["LiFePO4", "NaFePO4", "Li3V2(PO4)3", "Li2O"]: - assert not contains_peroxide(self.get_structure(f)) + for formula in ("LiFePO4", "NaFePO4", "Li3V2(PO4)3", "Li2O"): + assert not contains_peroxide(self.get_structure(formula)) - for f in ["Li2O2", "K2O2"]: - assert contains_peroxide(self.get_structure(f)) + for formula in ("Li2O2", "K2O2"): + assert contains_peroxide(self.get_structure(formula)) def test_oxide_type(self): el_li = Element("Li") el_o = Element("O") - latt = Lattice([[3.985034, 0, 0], [0, 4.881506, 0], [0, 0, 2.959824]]) + lattice = Lattice([[3.985034, 0, 0], [0, 4.881506, 0], [0, 0, 2.959824]]) elems = [el_li, el_li, el_o, el_o, el_o, el_o] coords = [ [0.5, 0.5, 0.5], @@ -114,23 +114,23 @@ def test_oxide_type(self): [0.132568, 0.414910, 0], [0.867432, 0.585090, 0], ] - struct = Structure(latt, elems, coords) + struct = Structure(lattice, elems, coords) assert oxide_type(struct, 1.1) == "superoxide" el_li = Element("Li") el_o = Element("O") elems = [el_li, el_o, el_o, el_o] - latt = Lattice.from_parameters(3.999911, 3.999911, 3.999911, 133.847504, 102.228244, 95.477342) + lattice = Lattice.from_parameters(3.999911, 3.999911, 3.999911, 133.847504, 102.228244, 95.477342) coords = [ [0.513004, 0.513004, 1.000000], [0.017616, 0.017616, 0.000000], [0.649993, 0.874790, 0.775203], [0.099587, 0.874790, 0.224797], ] - struct = Structure(latt, elems, coords) + struct = Structure(lattice, elems, coords) assert oxide_type(struct, 1.1) == "ozonide" - latt = Lattice.from_parameters(3.159597, 3.159572, 7.685205, 89.999884, 89.999674, 60.000510) + lattice = Lattice.from_parameters(3.159597, 3.159572, 7.685205, 89.999884, 89.999674, 60.000510) el_li = Element("Li") el_o = Element("O") elems = [el_li, el_li, el_li, el_li, el_o, el_o, el_o, el_o] @@ -144,13 +144,13 @@ def test_oxide_type(self): [0.666666, 0.666686, 0.350813], [0.666665, 0.666684, 0.149189], ] - struct = Structure(latt, elems, coords) + struct = Structure(lattice, elems, coords) assert oxide_type(struct, 1.1) == "peroxide" el_li = Element("Li") el_o = Element("O") el_h = Element("H") - latt = Lattice.from_parameters(3.565276, 3.565276, 4.384277, 90.000000, 90.000000, 90.000000) + lattice = Lattice.from_parameters(3.565276, 3.565276, 4.384277, 90.000000, 90.000000, 90.000000) elems = [el_h, el_h, el_li, el_li, el_o, el_o] coords = [ [0.000000, 0.500000, 0.413969], @@ -160,13 +160,13 @@ def test_oxide_type(self): [0.000000, 0.500000, 0.192672], [0.500000, 0.000000, 0.807328], ] - struct = Structure(latt, elems, coords) + struct = Structure(lattice, elems, coords) assert oxide_type(struct, 1.1) == "hydroxide" el_li = Element("Li") el_n = Element("N") el_h = Element("H") - latt = Lattice.from_parameters(3.565276, 3.565276, 4.384277, 90.000000, 90.000000, 90.000000) + lattice = Lattice.from_parameters(3.565276, 3.565276, 4.384277, 90.000000, 90.000000, 90.000000) elems = [el_h, el_h, el_li, el_li, el_n, el_n] coords = [ [0.000000, 0.500000, 0.413969], @@ -176,11 +176,11 @@ def test_oxide_type(self): [0.000000, 0.500000, 0.192672], [0.500000, 0.000000, 0.807328], ] - struct = Structure(latt, elems, coords) + struct = Structure(lattice, elems, coords) assert oxide_type(struct, 1.1) == "None" el_o = Element("O") - latt = Lattice.from_parameters(4.389828, 5.369789, 5.369789, 70.786622, 69.244828, 69.244828) + lattice = Lattice.from_parameters(4.389828, 5.369789, 5.369789, 70.786622, 69.244828, 69.244828) elems = [el_o, el_o, el_o, el_o, el_o, el_o, el_o, el_o] coords = [ [0.844609, 0.273459, 0.786089], @@ -192,12 +192,12 @@ def test_oxide_type(self): [0.132641, 0.148222, 0.148222], [0.867359, 0.851778, 0.851778], ] - struct = Structure(latt, elems, coords) + struct = Structure(lattice, elems, coords) assert oxide_type(struct, 1.1) == "None" def test_sulfide_type(self): # NaS2 -> polysulfide - latt = Lattice.tetragonal(9.59650, 11.78850) + lattice = Lattice.tetragonal(9.59650, 11.78850) species = ["Na"] * 2 + ["S"] * 2 coords = [ [0.00000, 0.00000, 0.17000], @@ -205,18 +205,18 @@ def test_sulfide_type(self): [0.03400, 0.25000, 0.29600], [0.14700, 0.11600, 0.40000], ] - struct = Structure.from_spacegroup(122, latt, species, coords) + struct = Structure.from_spacegroup(122, lattice, species, coords) assert sulfide_type(struct) == "polysulfide" # NaCl type NaS -> sulfide - latt = Lattice.cubic(5.75) + lattice = Lattice.cubic(5.75) species = ["Na", "S"] coords = [[0.00000, 0.00000, 0.00000], [0.50000, 0.50000, 0.50000]] - struct = Structure.from_spacegroup(225, latt, species, coords) + struct = Structure.from_spacegroup(225, lattice, species, coords) assert sulfide_type(struct) == "sulfide" # Na2S2O3 -> None (sulfate) - latt = Lattice.monoclinic(6.40100, 8.10000, 8.47400, 96.8800) + lattice = Lattice.monoclinic(6.40100, 8.10000, 8.47400, 96.8800) species = ["Na"] * 2 + ["S"] * 2 + ["O"] * 3 coords = [ [0.29706, 0.62396, 0.08575], @@ -227,11 +227,11 @@ def test_sulfide_type(self): [0.38604, -0.20144, 0.33624], [0.16248, -0.08546, 0.11608], ] - struct = Structure.from_spacegroup(14, latt, species, coords) + struct = Structure.from_spacegroup(14, lattice, species, coords) assert sulfide_type(struct) is None # Na3PS3O -> sulfide - latt = Lattice.orthorhombic(9.51050, 11.54630, 5.93230) + lattice = Lattice.orthorhombic(9.51050, 11.54630, 5.93230) species = ["Na"] * 2 + ["S"] * 2 + ["P", "O"] coords = [ [0.19920, 0.11580, 0.24950], @@ -241,7 +241,7 @@ def test_sulfide_type(self): [0.50000, 0.29400, 0.35500], [0.50000, 0.30300, 0.61140], ] - struct = Structure.from_spacegroup(36, latt, species, coords) + struct = Structure.from_spacegroup(36, lattice, species, coords) assert sulfide_type(struct) == "sulfide" # test for unphysical cells diff --git a/tests/analysis/test_structure_matcher.py b/tests/analysis/test_structure_matcher.py index 1268295cf88..52a58a0c508 100644 --- a/tests/analysis/test_structure_matcher.py +++ b/tests/analysis/test_structure_matcher.py @@ -73,8 +73,8 @@ def test_get_supercell_size(self): assert sm._get_supercell_size(s1, s2) == (1, True) assert sm._get_supercell_size(s2, s1) == (1, True) - sm = StructureMatcher(supercell_size="wfieoh") - with pytest.raises(ValueError, match="Can't parse Element or Species from str: wfieoh"): + sm = StructureMatcher(supercell_size="invalid") + with pytest.raises(ValueError, match="Can't parse Element or Species from 'invalid'"): sm._get_supercell_size(s1, s2) def test_cmp_fstruct(self): @@ -97,7 +97,7 @@ def test_cmp_fstruct(self): def test_cart_dists(self): sm = StructureMatcher() - latt = Lattice.orthorhombic(1, 2, 3) + lattice = Lattice.orthorhombic(1, 2, 3) s1 = np.array([[0.13, 0.25, 0.37], [0.1, 0.2, 0.3]]) s2 = np.array([[0.11, 0.22, 0.33]]) @@ -108,41 +108,41 @@ def test_cart_dists(self): mask3 = np.array([[False, False], [False, False]]) mask4 = np.array([[False, True], [False, True]]) - n1 = (len(s1) / latt.volume) ** (1 / 3) - n2 = (len(s2) / latt.volume) ** (1 / 3) + n1 = (len(s1) / lattice.volume) ** (1 / 3) + n2 = (len(s2) / lattice.volume) ** (1 / 3) with pytest.raises(ValueError, match=r"len\(s1\)=1 must be larger than len\(s2\)=2"): - sm._cart_dists(s2, s1, latt, mask.T, n2) + sm._cart_dists(s2, s1, lattice, mask.T, n2) with pytest.raises(ValueError, match="mask has incorrect shape"): - sm._cart_dists(s1, s2, latt, mask.T, n1) + sm._cart_dists(s1, s2, lattice, mask.T, n1) - d, ft, s = sm._cart_dists(s1, s2, latt, mask, n1) - assert_allclose(d, [0]) - assert_allclose(ft, [-0.01, -0.02, -0.03]) - assert_allclose(s, [1]) + distances, trac_trans_vec, solution = sm._cart_dists(s1, s2, lattice, mask, n1) + assert_allclose(distances, [0]) + assert_allclose(trac_trans_vec, [-0.01, -0.02, -0.03]) + assert_allclose(solution, [1]) # check that masking best value works - d, ft, s = sm._cart_dists(s1, s2, latt, mask2, n1) - assert_allclose(d, [0]) - assert_allclose(ft, [0.02, 0.03, 0.04]) - assert_allclose(s, [0]) + distances, trac_trans_vec, solution = sm._cart_dists(s1, s2, lattice, mask2, n1) + assert_allclose(distances, [0]) + assert_allclose(trac_trans_vec, [0.02, 0.03, 0.04]) + assert_allclose(solution, [0]) # check that averaging of translation is done properly - d, ft, s = sm._cart_dists(s1, s3, latt, mask3, n1) - assert_allclose(d, [0.08093341] * 2) - assert_allclose(ft, [0.01, 0.025, 0.035]) - assert_allclose(s, [1, 0]) + distances, trac_trans_vec, solution = sm._cart_dists(s1, s3, lattice, mask3, n1) + assert_allclose(distances, [0.08093341] * 2) + assert_allclose(trac_trans_vec, [0.01, 0.025, 0.035]) + assert_allclose(solution, [1, 0]) # check distances are large when mask allows no 'real' mapping - d, ft, s = sm._cart_dists(s1, s4, latt, mask4, n1) - assert np.min(d) > 1e8 - assert np.min(ft) > 1e8 + distances, trac_trans_vec, solution = sm._cart_dists(s1, s4, lattice, mask4, n1) + assert np.min(distances) > 1e8 + assert np.min(trac_trans_vec) > 1e8 def test_get_mask(self): sm = StructureMatcher(comparator=ElementComparator()) - latt = Lattice.cubic(1) - s1 = Structure(latt, ["Mg", "Cu", "Ag", "Cu"], [[0] * 3] * 4) - s2 = Structure(latt, ["Cu", "Cu", "Ag"], [[0] * 3] * 3) + lattice = Lattice.cubic(1) + s1 = Structure(lattice, ["Mg", "Cu", "Ag", "Cu"], [[0] * 3] * 4) + s2 = Structure(lattice, ["Cu", "Cu", "Ag"], [[0] * 3] * 3) result = [ [True, False, True, False], @@ -194,8 +194,8 @@ def test_get_mask(self): assert list(inds) == [] # test for multiple translation indices - s1 = Structure(latt, ["Cu", "Ag", "Cu", "Ag", "Ag"], [[0] * 3] * 5) - s2 = Structure(latt, ["Ag", "Cu", "Ag"], [[0] * 3] * 3) + s1 = Structure(lattice, ["Cu", "Ag", "Cu", "Ag", "Ag"], [[0] * 3] * 5) + s2 = Structure(lattice, ["Ag", "Cu", "Ag"], [[0] * 3] * 3) result = [[1, 0, 1, 0, 0], [0, 1, 0, 1, 1], [1, 0, 1, 0, 0]] mask, inds, idx = sm._get_mask(s1, s2, 1, s1_supercell=True) @@ -205,9 +205,9 @@ def test_get_mask(self): def test_get_supercells(self): sm = StructureMatcher(comparator=ElementComparator()) - latt = Lattice.cubic(1) + lattice = Lattice.cubic(1) l2 = Lattice.cubic(0.5) - s1 = Structure(latt, ["Mg", "Cu", "Ag", "Cu"], [[0] * 3] * 4) + s1 = Structure(lattice, ["Mg", "Cu", "Ag", "Cu"], [[0] * 3] * 4) s2 = Structure(l2, ["Cu", "Cu", "Ag"], [[0] * 3] * 3) scs = list(sm._get_supercells(s1, s2, fu=8, s1_supercell=False)) for x in scs: @@ -393,9 +393,9 @@ def test_find_match1(self): scale=True, attempt_supercell=False, ) - latt = Lattice.orthorhombic(1, 2, 3) - s1 = Structure(latt, ["Si", "Si", "Ag"], [[0, 0, 0.1], [0, 0, 0.2], [0.7, 0.4, 0.5]]) - s2 = Structure(latt, ["Si", "Si", "Ag"], [[0, 0.1, 0], [0, 0.1, -0.95], [0.7, 0.5, 0.375]]) + lattice = Lattice.orthorhombic(1, 2, 3) + s1 = Structure(lattice, ["Si", "Si", "Ag"], [[0, 0, 0.1], [0, 0, 0.2], [0.7, 0.4, 0.5]]) + s2 = Structure(lattice, ["Si", "Si", "Ag"], [[0, 0.1, 0], [0, 0.1, -0.95], [0.7, 0.5, 0.375]]) s1, s2, fu, s1_supercell = sm._preprocess(s1, s2, niggli=False) assert s1_supercell is True @@ -406,7 +406,7 @@ def test_find_match1(self): fc -= np.round(fc) assert np.sum(fc) == approx(0.9) assert np.sum(fc[:, :2]) == approx(0.1) - cart_dist = np.sum(match[1] * (latt.volume / 3) ** (1 / 3)) + cart_dist = np.sum(match[1] * (lattice.volume / 3) ** (1 / 3)) assert cart_dist == approx(0.15) def test_find_match2(self): @@ -418,9 +418,9 @@ def test_find_match2(self): scale=True, attempt_supercell=False, ) - latt = Lattice.orthorhombic(1, 2, 3) - s1 = Structure(latt, ["Si", "Si"], [[0, 0, 0.1], [0, 0, 0.2]]) - s2 = Structure(latt, ["Si", "Si"], [[0, 0.1, 0], [0, 0.1, -0.95]]) + lattice = Lattice.orthorhombic(1, 2, 3) + s1 = Structure(lattice, ["Si", "Si"], [[0, 0, 0.1], [0, 0, 0.2]]) + s2 = Structure(lattice, ["Si", "Si"], [[0, 0.1, 0], [0, 0.1, -0.95]]) s1, s2, fu, _s1_supercell = sm._preprocess(s1, s2, niggli=False) @@ -453,10 +453,10 @@ def test_supercell_subsets(self): allow_subset=False, supercell_size="volume", ) - latt = Lattice.orthorhombic(1, 2, 3) - s1 = Structure(latt, ["Ag", "Si", "Si"], [[0.7, 0.4, 0.5], [0, 0, 0.1], [0, 0, 0.2]]) + lattice = Lattice.orthorhombic(1, 2, 3) + s1 = Structure(lattice, ["Ag", "Si", "Si"], [[0.7, 0.4, 0.5], [0, 0, 0.1], [0, 0, 0.2]]) s1.make_supercell([2, 1, 1]) - s2 = Structure(latt, ["Si", "Si", "Ag"], [[0, 0.1, -0.95], [0, 0.1, 0], [-0.7, 0.5, 0.375]]) + s2 = Structure(lattice, ["Si", "Si", "Ag"], [[0, 0.1, -0.95], [0, 0.1, 0], [-0.7, 0.5, 0.375]]) shuffle = [0, 2, 1, 3, 4, 5] s1 = Structure.from_sites([s1[i] for i in shuffle]) @@ -523,8 +523,8 @@ def test_get_s2_large_s2(self): supercell_size="volume", ) - latt = Lattice.orthorhombic(1, 2, 3) - s1 = Structure(latt, ["Ag", "Si", "Si"], [[0.7, 0.4, 0.5], [0, 0, 0.1], [0, 0, 0.2]]) + lattice = Lattice.orthorhombic(1, 2, 3) + s1 = Structure(lattice, ["Ag", "Si", "Si"], [[0.7, 0.4, 0.5], [0, 0, 0.1], [0, 0, 0.2]]) l2 = Lattice.orthorhombic(1.01, 2.01, 3.01) s2 = Structure(l2, ["Si", "Si", "Ag"], [[0, 0.1, -0.95], [0, 0.1, 0], [-0.7, 0.5, 0.375]]) @@ -545,28 +545,28 @@ def test_get_mapping(self): attempt_supercell=False, allow_subset=True, ) - latt = Lattice.orthorhombic(1, 2, 3) - s1 = Structure(latt, ["Ag", "Si", "Si"], [[0.7, 0.4, 0.5], [0, 0, 0.1], [0, 0, 0.2]]) - s1.make_supercell([2, 1, 1]) - s2 = Structure(latt, ["Si", "Si", "Ag"], [[0, 0.1, -0.95], [0, 0.1, 0], [-0.7, 0.5, 0.375]]) + lattice = Lattice.orthorhombic(1, 2, 3) + struct1 = Structure(lattice, ["Ag", "Si", "Si"], [[0.7, 0.4, 0.5], [0, 0, 0.1], [0, 0, 0.2]]) + struct1.make_supercell([2, 1, 1]) + struct2 = Structure(lattice, ["Si", "Si", "Ag"], [[0, 0.1, -0.95], [0, 0.1, 0], [-0.7, 0.5, 0.375]]) shuffle = [2, 0, 1, 3, 5, 4] - s1 = Structure.from_sites([s1[i] for i in shuffle]) + struct1 = Structure.from_sites([struct1[i] for i in shuffle]) # test the mapping - s2.make_supercell([2, 1, 1]) + struct2.make_supercell([2, 1, 1]) # equal sizes - for i, x in enumerate(sm.get_mapping(s1, s2)): - assert s1[x].species == s2[i].species + for ii, jj in enumerate(sm.get_mapping(struct1, struct2)): + assert struct1[jj].species == struct2[ii].species - del s1[0] + del struct1[0] # s1 is subset of s2 - for i, x in enumerate(sm.get_mapping(s2, s1)): - assert s1[i].species == s2[x].species + for ii, jj in enumerate(sm.get_mapping(struct2, struct1)): + assert struct1[ii].species == struct2[jj].species # s2 is smaller than s1 - del s2[0] - del s2[1] + del struct2[0] + del struct2[1] with pytest.raises(ValueError, match="subset is larger than superset"): - sm.get_mapping(s2, s1) + sm.get_mapping(struct2, struct1) def test_get_supercell_matrix(self): sm = StructureMatcher( @@ -578,18 +578,18 @@ def test_get_supercell_matrix(self): attempt_supercell=True, ) - latt = Lattice.orthorhombic(1, 2, 3) + lattice = Lattice.orthorhombic(1, 2, 3) - s1 = Structure(latt, ["Si", "Si", "Ag"], [[0, 0, 0.1], [0, 0, 0.2], [0.7, 0.4, 0.5]]) + s1 = Structure(lattice, ["Si", "Si", "Ag"], [[0, 0, 0.1], [0, 0, 0.2], [0.7, 0.4, 0.5]]) s1.make_supercell([2, 1, 1]) - s2 = Structure(latt, ["Si", "Si", "Ag"], [[0, 0.1, 0], [0, 0.1, -0.95], [-0.7, 0.5, 0.375]]) + s2 = Structure(lattice, ["Si", "Si", "Ag"], [[0, 0.1, 0], [0, 0.1, -0.95], [-0.7, 0.5, 0.375]]) result = sm.get_supercell_matrix(s1, s2) assert (result == [[-2, 0, 0], [0, 1, 0], [0, 0, 1]]).all() - s1 = Structure(latt, ["Si", "Si", "Ag"], [[0, 0, 0.1], [0, 0, 0.2], [0.7, 0.4, 0.5]]) + s1 = Structure(lattice, ["Si", "Si", "Ag"], [[0, 0, 0.1], [0, 0, 0.2], [0.7, 0.4, 0.5]]) s1.make_supercell([[1, -1, 0], [0, 0, -1], [0, 1, 0]]) - s2 = Structure(latt, ["Si", "Si", "Ag"], [[0, 0.1, 0], [0, 0.1, -0.95], [-0.7, 0.5, 0.375]]) + s2 = Structure(lattice, ["Si", "Si", "Ag"], [[0, 0.1, 0], [0, 0.1, -0.95], [-0.7, 0.5, 0.375]]) result = sm.get_supercell_matrix(s1, s2) assert (result == [[-1, -1, 0], [0, 0, -1], [0, 1, 0]]).all() @@ -617,17 +617,17 @@ def test_subset(self): attempt_supercell=False, allow_subset=True, ) - latt = Lattice.orthorhombic(10, 20, 30) - s1 = Structure(latt, ["Si", "Si", "Ag"], [[0, 0, 0.1], [0, 0, 0.2], [0.7, 0.4, 0.5]]) - s2 = Structure(latt, ["Si", "Ag"], [[0, 0.1, 0], [-0.7, 0.5, 0.4]]) + lattice = Lattice.orthorhombic(10, 20, 30) + s1 = Structure(lattice, ["Si", "Si", "Ag"], [[0, 0, 0.1], [0, 0, 0.2], [0.7, 0.4, 0.5]]) + s2 = Structure(lattice, ["Si", "Ag"], [[0, 0.1, 0], [-0.7, 0.5, 0.4]]) result = sm.get_s2_like_s1(s1, s2) assert len(find_in_coord_list_pbc(result.frac_coords, [0, 0, 0.1])) == 1 assert len(find_in_coord_list_pbc(result.frac_coords, [0.7, 0.4, 0.5])) == 1 # test with fewer species in s2 - s1 = Structure(latt, ["Si", "Ag", "Si"], [[0, 0, 0.1], [0, 0, 0.2], [0.7, 0.4, 0.5]]) - s2 = Structure(latt, ["Si", "Si"], [[0, 0.1, 0], [-0.7, 0.5, 0.4]]) + s1 = Structure(lattice, ["Si", "Ag", "Si"], [[0, 0, 0.1], [0, 0, 0.2], [0.7, 0.4, 0.5]]) + s2 = Structure(lattice, ["Si", "Si"], [[0, 0.1, 0], [-0.7, 0.5, 0.4]]) result = sm.get_s2_like_s1(s1, s2) mindists = np.min(s1.lattice.get_all_distances(s1.frac_coords, result.frac_coords), axis=0) assert np.max(mindists) < 1e-6 @@ -637,14 +637,14 @@ def test_subset(self): # test with not enough sites in s1 # test with fewer species in s2 - s1 = Structure(latt, ["Si", "Ag", "Cl"], [[0, 0, 0.1], [0, 0, 0.2], [0.7, 0.4, 0.5]]) - s2 = Structure(latt, ["Si", "Si"], [[0, 0.1, 0], [-0.7, 0.5, 0.4]]) + s1 = Structure(lattice, ["Si", "Ag", "Cl"], [[0, 0, 0.1], [0, 0, 0.2], [0.7, 0.4, 0.5]]) + s2 = Structure(lattice, ["Si", "Si"], [[0, 0.1, 0], [-0.7, 0.5, 0.4]]) assert sm.get_s2_like_s1(s1, s2) is None def test_out_of_cell_s2_like_s1(self): - latt = Lattice.cubic(5) - s1 = Structure(latt, ["Si", "Ag", "Si"], [[0, 0, -0.02], [0, 0, 0.001], [0.7, 0.4, 0.5]]) - s2 = Structure(latt, ["Si", "Ag", "Si"], [[0, 0, 0.98], [0, 0, 0.99], [0.7, 0.4, 0.5]]) + lattice = Lattice.cubic(5) + s1 = Structure(lattice, ["Si", "Ag", "Si"], [[0, 0, -0.02], [0, 0, 0.001], [0.7, 0.4, 0.5]]) + s2 = Structure(lattice, ["Si", "Ag", "Si"], [[0, 0, 0.98], [0, 0, 0.99], [0.7, 0.4, 0.5]]) new_s2 = StructureMatcher(primitive_cell=False).get_s2_like_s1(s1, s2) dists = np.sum((s1.cart_coords - new_s2.cart_coords) ** 2, axis=-1) ** 0.5 assert np.max(dists) < 0.1 @@ -785,11 +785,11 @@ def test_rms_vs_minimax(self): # greater than stol are treated properly # stol=0.3 gives exactly an ftol of 0.1 on the c axis sm = StructureMatcher(ltol=0.2, stol=0.301, angle_tol=1, primitive_cell=False) - latt = Lattice.orthorhombic(1, 2, 12) + lattice = Lattice.orthorhombic(1, 2, 12) sp = ["Si", "Si", "Al"] - s1 = Structure(latt, sp, np.diag((0.5, 0, 0.5))) - s2 = Structure(latt, sp, np.diag((0.5, 0, 0.6))) + s1 = Structure(lattice, sp, np.diag((0.5, 0, 0.5))) + s2 = Structure(lattice, sp, np.diag((0.5, 0, 0.6))) assert_allclose(sm.get_rms_dist(s1, s2), (0.32**0.5 / 2, 0.4)) assert sm.fit(s1, s2) is False diff --git a/tests/analysis/test_surface_analysis.py b/tests/analysis/test_surface_analysis.py index cf54d2cd8c8..1c4364884fe 100644 --- a/tests/analysis/test_surface_analysis.py +++ b/tests/analysis/test_surface_analysis.py @@ -18,7 +18,7 @@ __date__ = "Aug 24, 2017" -TEST_DIR = f"{TEST_FILES_DIR}/surface_tests" +TEST_DIR = f"{TEST_FILES_DIR}/surfaces" class TestSlabEntry(PymatgenTest): @@ -113,8 +113,8 @@ def test_cleaned_up_slab(self): for hkl in val: for clean in self.metals_O_entry_dict[el][hkl]: for ads in self.metals_O_entry_dict[el][hkl][clean]: - s = ads.cleaned_up_slab - assert s.composition.reduced_composition == clean.composition.reduced_composition + slab_clean = ads.cleaned_up_slab + assert slab_clean.composition.reduced_composition == clean.composition.reduced_composition class TestSurfaceEnergyPlotter(PymatgenTest): @@ -285,7 +285,7 @@ def test_entry_dict_from_list(self): # plt = analyzer.chempot_vs_gamma_facet(hkl) -class TestWorkfunctionAnalyzer(PymatgenTest): +class TestWorkFunctionAnalyzer(PymatgenTest): def setUp(self): self.kwargs = { "poscar_filename": f"{TEST_DIR}/CONTCAR.relax1.gz", @@ -320,26 +320,26 @@ def setUp(self): def test_stability_at_r(self): # Check that we have a different polymorph that is # stable below or above the equilibrium particle size - r = self.nanoscale_stability.solve_equilibrium_point(self.La_hcp_analyzer, self.La_fcc_analyzer) * 10 + radius = self.nanoscale_stability.solve_equilibrium_point(self.La_hcp_analyzer, self.La_fcc_analyzer) * 10 # hcp phase of La particle should be the stable # polymorph above the equilibrium radius hcp_wulff = self.La_hcp_analyzer.wulff_from_chempot() bulk = self.La_hcp_analyzer.ucell_entry - ghcp, _rhcp = self.nanoscale_stability.wulff_gform_and_r(hcp_wulff, bulk, r + 10, from_sphere_area=True) + ghcp, _rhcp = self.nanoscale_stability.wulff_gform_and_r(hcp_wulff, bulk, radius + 10, from_sphere_area=True) fcc_wulff = self.La_fcc_analyzer.wulff_from_chempot() bulk = self.La_fcc_analyzer.ucell_entry - gfcc, _rfcc = self.nanoscale_stability.wulff_gform_and_r(fcc_wulff, bulk, r + 10, from_sphere_area=True) + gfcc, _rfcc = self.nanoscale_stability.wulff_gform_and_r(fcc_wulff, bulk, radius + 10, from_sphere_area=True) assert gfcc > ghcp # fcc phase of La particle should be the stable # polymorph below the equilibrium radius hcp_wulff = self.La_hcp_analyzer.wulff_from_chempot() bulk = self.La_hcp_analyzer.ucell_entry - ghcp, _rhcp = self.nanoscale_stability.wulff_gform_and_r(hcp_wulff, bulk, r - 10, from_sphere_area=True) + ghcp, _rhcp = self.nanoscale_stability.wulff_gform_and_r(hcp_wulff, bulk, radius - 10, from_sphere_area=True) fcc_wulff = self.La_fcc_analyzer.wulff_from_chempot() bulk = self.La_fcc_analyzer.ucell_entry - gfcc, _rfcc = self.nanoscale_stability.wulff_gform_and_r(fcc_wulff, bulk, r - 10, from_sphere_area=True) + gfcc, _rfcc = self.nanoscale_stability.wulff_gform_and_r(fcc_wulff, bulk, radius - 10, from_sphere_area=True) assert gfcc < ghcp def test_scaled_wulff(self): @@ -360,23 +360,23 @@ def get_entry_dict(filename): entry_dict = {} with open(filename) as file: entries = json.loads(file.read()) - for k in entries: - n = k[25:] + for entry in entries: + sub_str = entry[25:] miller_index = [] - for i, s in enumerate(n): - if s == "_": + for idx, char in enumerate(sub_str): + if char == "_": break - if s == "-": + if char == "-": continue - t = int(s) - if n[i - 1] == "-": + t = int(char) + if sub_str[idx - 1] == "-": t *= -1 miller_index.append(t) hkl = tuple(miller_index) if hkl not in entry_dict: entry_dict[hkl] = {} - entry = ComputedStructureEntry.from_dict(entries[k]) - entry_dict[hkl][SlabEntry(entry.structure, entry.energy, hkl, label=k)] = [] + entry = ComputedStructureEntry.from_dict(entries[entry]) + entry_dict[hkl][SlabEntry(entry.structure, entry.energy, hkl, label=entry)] = [] return entry_dict diff --git a/tests/analysis/test_wulff.py b/tests/analysis/test_wulff.py index d009df41463..b933debcccf 100644 --- a/tests/analysis/test_wulff.py +++ b/tests/analysis/test_wulff.py @@ -75,9 +75,12 @@ def test_get_plot(self): # Basic test to check figure contains a single Axes3D object for wulff in (self.wulff_Nb, self.wulff_Ir, self.wulff_Ti): - plt = wulff.get_plot() - assert len(plt.gcf().get_axes()) == 1 - assert isinstance(plt.gcf().get_axes()[0], Axes3D) + ax_3d = wulff.get_plot() + assert isinstance(ax_3d, Axes3D) + assert len(ax_3d.collections) in (24, 74, 110) + assert ax_3d.get_title() == "" + assert ax_3d.get_xlabel() == "x" + assert ax_3d.get_ylabel() == "y" def test_get_plotly(self): # Basic test, not really a unittest. diff --git a/tests/analysis/test_xps.py b/tests/analysis/test_xps.py index 7369efd7209..45b250e3100 100644 --- a/tests/analysis/test_xps.py +++ b/tests/analysis/test_xps.py @@ -5,7 +5,7 @@ from pymatgen.util.testing import VASP_OUT_DIR, PymatgenTest -class XPSTestCase(PymatgenTest): +class TestXPS(PymatgenTest): def test_from_dos(self): vasp_run = Vasprun(f"{VASP_OUT_DIR}/vasprun.LiF.xml.gz") dos = vasp_run.complete_dos diff --git a/tests/apps/battery/test_conversion_battery.py b/tests/apps/battery/test_conversion_battery.py index 33fbca9f725..767c07b2bca 100644 --- a/tests/apps/battery/test_conversion_battery.py +++ b/tests/apps/battery/test_conversion_battery.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -import unittest +from unittest import TestCase from monty.json import MontyDecoder from pytest import approx @@ -11,7 +11,7 @@ from pymatgen.util.testing import TEST_FILES_DIR -class TestConversionElectrode(unittest.TestCase): +class TestConversionElectrode(TestCase): def setUp(self): self.formulas = ["LiCoO2", "FeF3", "MnO2"] self.conversion_electrodes = {} @@ -67,16 +67,18 @@ def setUp(self): def test_init(self): # both 'LiCoO2' and "FeF3" are using Li+ as working ion; MnO2 is for the multivalent Mg2+ ion for formula in self.formulas: - c = self.conversion_electrodes[formula]["CE"] + conv_electrode = self.conversion_electrodes[formula]["CE"] - assert len(c.get_sub_electrodes(adjacent_only=True)) == c.num_steps - assert len(c.get_sub_electrodes(adjacent_only=False)) == sum(range(1, c.num_steps + 1)) - p = self.expected_properties[formula] + assert len(conv_electrode.get_sub_electrodes(adjacent_only=True)) == conv_electrode.num_steps + assert len(conv_electrode.get_sub_electrodes(adjacent_only=False)) == sum( + range(1, conv_electrode.num_steps + 1) + ) + props = self.expected_properties[formula] - for k, v in p.items(): - assert getattr(c, f"get_{k}")() == approx(v, abs=1e-2) + for key, val in props.items(): + assert getattr(conv_electrode, f"get_{key}")() == approx(val, abs=1e-2) - assert {*c.get_summary_dict(print_subelectrodes=True)} == { + assert {*conv_electrode.get_summary_dict(print_subelectrodes=True)} == { "adj_pairs", "reactions", "energy_vol", @@ -98,17 +100,17 @@ def test_init(self): } # try to export/import a voltage pair via a dict - pair = c.voltage_pairs[0] + pair = conv_electrode.voltage_pairs[0] dct = pair.as_dict() pair2 = ConversionVoltagePair.from_dict(dct) for prop in ["voltage", "mass_charge", "mass_discharge"]: assert getattr(pair, prop) == getattr(pair2, prop), 2 # try to create an electrode from a dict and test methods - dct = c.as_dict() + dct = conv_electrode.as_dict() electrode = ConversionElectrode.from_dict(dct) - for k, v in p.items(): - assert getattr(electrode, "get_" + k)() == approx(v, abs=1e-2) + for key, val in props.items(): + assert getattr(electrode, f"get_{key}")() == approx(val, abs=1e-2) def test_repr(self): conv_electrode = self.conversion_electrodes[self.formulas[0]]["CE"] @@ -121,11 +123,11 @@ def test_repr(self): def test_summary(self): key_map = {"specific_energy": "energy_grav", "energy_density": "energy_vol"} - for f in self.formulas: - c = self.conversion_electrodes[f]["CE"] - dct = c.get_summary_dict() - p = self.expected_properties[f] - for k, v in p.items(): + for formula in self.formulas: + conv_elec = self.conversion_electrodes[formula]["CE"] + dct = conv_elec.get_summary_dict() + props = self.expected_properties[formula] + for k, v in props.items(): summary_key = key_map.get(k, k) assert dct[summary_key] == approx(v, abs=1e-2) @@ -133,10 +135,10 @@ def test_composite(self): # check entries in charged/discharged state for formula in self.formulas: CE = self.conversion_electrodes[formula]["CE"] - for step, vpair in enumerate(CE.voltage_pairs): + for step, volt_pair in enumerate(CE.voltage_pairs): # entries_charge/entries_discharge attributes should return entries equal with the expected composite_dict = self.expected_composite[formula] for attr in ["entries_charge", "entries_discharge"]: # composite at each discharge step, of which entry object is simplified to reduced formula - entries_formula_list = [entry.reduced_formula for entry in getattr(vpair, attr)] + entries_formula_list = [entry.reduced_formula for entry in getattr(volt_pair, attr)] assert entries_formula_list == composite_dict[attr][step] diff --git a/tests/apps/battery/test_insertion_battery.py b/tests/apps/battery/test_insertion_battery.py index 129cb66159f..68c68cb44a6 100644 --- a/tests/apps/battery/test_insertion_battery.py +++ b/tests/apps/battery/test_insertion_battery.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -import unittest +from unittest import TestCase from monty.json import MontyDecoder, MontyEncoder from pytest import approx @@ -11,7 +11,7 @@ from pymatgen.util.testing import TEST_FILES_DIR -class TestInsertionElectrode(unittest.TestCase): +class TestInsertionElectrode(TestCase): def setUp(self): self.entry_Li = ComputedEntry("Li", -1.90753119) self.entry_Ca = ComputedEntry("Ca", -1.99689568) diff --git a/tests/apps/battery/test_plotter.py b/tests/apps/battery/test_plotter.py index 91ee3fc91ae..354a8e1123c 100644 --- a/tests/apps/battery/test_plotter.py +++ b/tests/apps/battery/test_plotter.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -import unittest +from unittest import TestCase from monty.json import MontyDecoder @@ -13,7 +13,7 @@ from pymatgen.util.testing import TEST_FILES_DIR -class TestVoltageProfilePlotter(unittest.TestCase): +class TestVoltageProfilePlotter(TestCase): def setUp(self): entry_Li = ComputedEntry("Li", -1.90753119) diff --git a/tests/apps/borg/test_hive.py b/tests/apps/borg/test_hive.py index a174803bf00..810a7c5f8ac 100644 --- a/tests/apps/borg/test_hive.py +++ b/tests/apps/borg/test_hive.py @@ -1,7 +1,7 @@ from __future__ import annotations import os -import unittest +from unittest import TestCase from pytest import approx @@ -14,7 +14,7 @@ from pymatgen.util.testing import TEST_FILES_DIR, VASP_OUT_DIR -class TestVaspToComputedEntryDrone(unittest.TestCase): +class TestVaspToComputedEntryDrone(TestCase): def setUp(self): self.drone = VaspToComputedEntryDrone(data=["efermi"]) self.structure_drone = VaspToComputedEntryDrone(inc_structure=True) @@ -47,7 +47,7 @@ def test_as_from_dict(self): assert isinstance(drone, VaspToComputedEntryDrone) -class TestSimpleVaspToComputedEntryDrone(unittest.TestCase): +class TestSimpleVaspToComputedEntryDrone(TestCase): def setUp(self): self.drone = SimpleVaspToComputedEntryDrone() self.structure_drone = SimpleVaspToComputedEntryDrone(inc_structure=True) @@ -63,7 +63,7 @@ def test_as_from_dict(self): assert isinstance(drone, SimpleVaspToComputedEntryDrone) -class TestGaussianToComputedEntryDrone(unittest.TestCase): +class TestGaussianToComputedEntryDrone(TestCase): def setUp(self): self.drone = GaussianToComputedEntryDrone(data=["corrections"]) self.structure_drone = GaussianToComputedEntryDrone(inc_structure=True) diff --git a/tests/apps/borg/test_queen.py b/tests/apps/borg/test_queen.py index c05de552c6c..3536d04b778 100644 --- a/tests/apps/borg/test_queen.py +++ b/tests/apps/borg/test_queen.py @@ -1,6 +1,6 @@ from __future__ import annotations -import unittest +from pytest import approx from pymatgen.apps.borg.hive import VaspToComputedEntryDrone from pymatgen.apps.borg.queen import BorgQueen @@ -13,14 +13,14 @@ TEST_DIR = f"{TEST_FILES_DIR}/app_borg" -class TestBorgQueen(unittest.TestCase): +class TestBorgQueen: def test_get_data(self): """Test get data from vasprun.xml.xe.gz file.""" drone = VaspToComputedEntryDrone() queen = BorgQueen(drone, TEST_DIR, 1) data = queen.get_data() assert len(data) == 1 - assert data[0].energy == 0.5559329 + assert data[0].energy == approx(0.5559329, 1e-6) def test_load_data(self): drone = VaspToComputedEntryDrone() diff --git a/tests/command_line/test_chargemol_caller.py b/tests/command_line/test_chargemol_caller.py index 5f7fb8e5557..8ec8f4a2ee2 100644 --- a/tests/command_line/test_chargemol_caller.py +++ b/tests/command_line/test_chargemol_caller.py @@ -1,13 +1,11 @@ from __future__ import annotations -import unittest - from pymatgen.command_line.chargemol_caller import ChargemolAnalysis from pymatgen.core import Element from pymatgen.util.testing import TEST_FILES_DIR -class TestChargemolAnalysis(unittest.TestCase): +class TestChargemolAnalysis: def test_parse_chargemol(self): test_dir = f"{TEST_FILES_DIR}/chargemol/spin_unpolarized" ca = ChargemolAnalysis(path=test_dir, run_chargemol=False) diff --git a/tests/command_line/test_critic2_caller.py b/tests/command_line/test_critic2_caller.py index fd97e08ba61..3ed30c65231 100644 --- a/tests/command_line/test_critic2_caller.py +++ b/tests/command_line/test_critic2_caller.py @@ -1,7 +1,7 @@ from __future__ import annotations -import unittest from shutil import which +from unittest import TestCase import pytest from pytest import approx @@ -19,7 +19,7 @@ @pytest.mark.skipif(not which("critic2"), reason="critic2 executable not present") -class TestCritic2Caller(unittest.TestCase): +class TestCritic2Caller: def test_from_path(self): # uses CHGCARs c2c = Critic2Caller.from_path(f"{TEST_FILES_DIR}/bader") @@ -74,7 +74,7 @@ def test_from_structure(self): assert "ERROR : load int.CHGCAR id chg_int zpsp Mo 6 S 6" in c2c._input_script -class TestCritic2Analysis(unittest.TestCase): +class TestCritic2Analysis(TestCase): def setUp(self): stdout_file = f"{TEST_FILES_DIR}/critic2/MoS2_critic2_stdout.txt" stdout_file_new_format = f"{TEST_FILES_DIR}/critic2/MoS2_critic2_stdout_new_format.txt" @@ -88,7 +88,7 @@ def setUp(self): self.c2o = Critic2Analysis(structure, reference_stdout) self.c2o_new_format = Critic2Analysis(structure, reference_stdout_new_format) - def test_properties_to_from_dict(self): + def test_to_from_dict(self): assert len(self.c2o.critical_points) == 6 assert len(self.c2o.nodes) == 14 assert len(self.c2o.edges) == 10 diff --git a/tests/command_line/test_enumlib_caller.py b/tests/command_line/test_enumlib_caller.py index e22fb12ba85..1621b34143d 100644 --- a/tests/command_line/test_enumlib_caller.py +++ b/tests/command_line/test_enumlib_caller.py @@ -28,8 +28,8 @@ def test_init(self): adaptor.run() structures = adaptor.structures assert len(structures) == 86 - for s in structures: - assert s.composition.get_atomic_fraction(Element("Li")) == approx(0.5 / 6.5) + for struct_trafo in structures: + assert struct_trafo.composition.get_atomic_fraction(Element("Li")) == approx(0.5 / 6.5) adaptor = EnumlibAdaptor(sub_trans.apply_transformation(struct), 1, 2, refine_structure=True) adaptor.run() structures = adaptor.structures @@ -40,8 +40,8 @@ def test_init(self): adaptor.run() structures = adaptor.structures assert len(structures) == 1 - for s in structures: - assert s.composition.get_atomic_fraction(Element("Li")) == approx(0.25 / 6.25) + for struct_trafo in structures: + assert struct_trafo.composition.get_atomic_fraction(Element("Li")) == approx(0.25 / 6.25) # Make sure it works for completely disordered structures. struct = Structure(np.eye(3) * 10, [{"Fe": 0.5}], [[0, 0, 0]]) @@ -52,11 +52,11 @@ def test_init(self): # Make sure it works properly when symmetry is broken by ordered sites. struct = self.get_structure("LiFePO4") sub_trans = SubstitutionTransformation({"Li": {"Li": 0.25}}) - s = sub_trans.apply_transformation(struct) + struct_trafo = sub_trans.apply_transformation(struct) # REmove some ordered sites to break symmetry. remove_trans = RemoveSitesTransformation([4, 7]) - s = remove_trans.apply_transformation(s) - adaptor = EnumlibAdaptor(s, 1, 1, enum_precision_parameter=0.01) + struct_trafo = remove_trans.apply_transformation(struct_trafo) + adaptor = EnumlibAdaptor(struct_trafo, 1, 1, enum_precision_parameter=0.01) adaptor.run() structures = adaptor.structures assert len(structures) == 4 diff --git a/tests/command_line/test_gulp_caller.py b/tests/command_line/test_gulp_caller.py index b68cfe24c46..f0a00115616 100644 --- a/tests/command_line/test_gulp_caller.py +++ b/tests/command_line/test_gulp_caller.py @@ -10,6 +10,7 @@ import sys import unittest from shutil import which +from unittest import TestCase import numpy as np import pytest @@ -34,7 +35,7 @@ @pytest.mark.skipif(not gulp_present, reason="gulp not present.") -class TestGulpCaller(unittest.TestCase): +class TestGulpCaller: def test_run(self): mgo_lattice = np.eye(3) * 4.212 mgo_specie = ["Mg"] * 4 + ["O"] * 4 @@ -97,7 +98,7 @@ def test_decimal(self): @pytest.mark.skipif(not gulp_present, reason="gulp not present.") -class TestGulpIO(unittest.TestCase): +class TestGulpIO(TestCase): def setUp(self): self.structure = Structure.from_file(f"{VASP_IN_DIR}/POSCAR_Al12O18") self.gio = GulpIO() @@ -259,7 +260,7 @@ def test_tersoff_input(self): @pytest.mark.skipif(not gulp_present, reason="gulp not present.") -class TestGlobalFunctions(unittest.TestCase): +class TestGlobalFunctions(TestCase): def setUp(self): mgo_latt = np.eye(3) * 4.212 mgo_specie = ["Mg", "O"] * 4 @@ -305,7 +306,7 @@ def test_get_energy_relax_structure_buckingham(self): @pytest.mark.skipif(not gulp_present, reason="gulp not present.") -class TestBuckinghamPotentialLewis(unittest.TestCase): +class TestBuckinghamPotentialLewis(TestCase): def setUp(self): self.bpl = BuckinghamPotential("lewis") @@ -333,7 +334,7 @@ def test_spring(self): @pytest.mark.skipif(not gulp_present, reason="gulp not present.") -class TestBuckinghamPotentialBush(unittest.TestCase): +class TestBuckinghamPotentialBush(TestCase): def setUp(self): self.bpb = BuckinghamPotential("bush") diff --git a/tests/core/test_bonds.py b/tests/core/test_bonds.py index ebe768ad9ba..777d938619a 100644 --- a/tests/core/test_bonds.py +++ b/tests/core/test_bonds.py @@ -1,7 +1,5 @@ from __future__ import annotations -import unittest - import pytest from pytest import approx @@ -16,7 +14,7 @@ __date__ = "Jul 26, 2012" -class TestCovalentBond(unittest.TestCase): +class TestCovalentBond: def test_length(self): site1 = Site("C", [0, 0, 0]) site2 = Site("H", [0, 0.7, 0.6]) @@ -46,7 +44,7 @@ def test_str(self): assert CovalentBond(site1, site2) is not None -class TestFunc(unittest.TestCase): +class TestFunc: def test_get_bond_length(self): assert approx(get_bond_length("C", "C", 1) - 1.54) == 0 assert approx(get_bond_length("C", "C", 2) - 1.34) == 0 diff --git a/tests/core/test_composition.py b/tests/core/test_composition.py index ed3e6f85a2a..2d26cebf992 100644 --- a/tests/core/test_composition.py +++ b/tests/core/test_composition.py @@ -7,7 +7,6 @@ from __future__ import annotations import random -import unittest import pytest from numpy.testing import assert_allclose @@ -93,7 +92,7 @@ def test_in(self): # Test float in Composition comp = Composition({Element("Fe"): 2}) - with pytest.raises(TypeError, match="expected string or bytes-like object"): + with pytest.raises(TypeError, match="Invalid key=1.5 for Composition"): assert 1.5 in comp # Test DummySpecies in Composition @@ -132,7 +131,7 @@ def test_init(self): assert Composition({"Fe": 4, "Li": 4, "O": 16, "P": 4}).formula == "Li4 Fe4 P4 O16" - with pytest.raises(TypeError, match="expected string or bytes-like object"): + with pytest.raises(ValueError, match="Can't parse Element or Species from"): Composition({None: 4, "Li": 4, "O": 16, "P": 4}) assert Composition({1: 2, 8: 1}).formula == "H2 O1" @@ -141,6 +140,11 @@ def test_init(self): comp = Composition({"S": Composition.amount_tolerance / 2}) assert len(comp.elements) == 0 + # test Composition from int/float raises + for val in (1, 2.5): + with pytest.raises(TypeError, match=f"{type(val).__name__!r} object is not iterable"): + Composition(val) + def test_str_and_repr(self): test_cases = [ ( @@ -212,10 +216,7 @@ def test_formula(self): assert Composition("(C)((C)0.9(B)0.1)") == Composition("C1.9 B0.1") assert Composition("NaN").reduced_formula == "NaN" - with pytest.raises( - ValueError, - match=r"float\('NaN'\) is not a valid Composition, did you mean str\('NaN'\)\?", - ): + with pytest.raises(ValueError, match=r"float\('NaN'\) is not a valid Composition, did you mean 'NaN'\?"): Composition(float("NaN")) # test bad formulas raise ValueError @@ -672,19 +673,27 @@ def test_oxi_state_decoration(self): def test_metallofullerene(self): # Test: Parse Metallofullerene formula (e.g. Y3N@C80) - formula = "Y3N@C80" - sym_dict = {"Y": 3, "N": 1, "C": 80} - cmp = Composition(formula) - cmp2 = Composition.from_dict(sym_dict) - assert cmp == cmp2 + comp1 = Composition("Y3N@C80") + comp2 = Composition({"Y": 3, "N": 1, "C": 80}) + assert comp1 == comp2 def test_contains_element_type(self): - formula = "EuTiO3" - cmp = Composition(formula) - assert cmp.contains_element_type("lanthanoid") - assert not cmp.contains_element_type("noble_gas") - assert cmp.contains_element_type("f-block") - assert not cmp.contains_element_type("s-block") + EuTiO3 = Composition("EuTiO3") + assert EuTiO3.contains_element_type("lanthanoid") is True + assert EuTiO3.contains_element_type("noble_gas") is False + assert EuTiO3.contains_element_type("f-block") is True + assert EuTiO3.contains_element_type("s-block") is False + assert EuTiO3.contains_element_type("alkali") is False + NaCl = Composition("NaCl") + assert NaCl.contains_element_type("halogen") is True + assert NaCl.contains_element_type("alkali") is True + assert NaCl.contains_element_type("s-block") is True + assert NaCl.contains_element_type("p-block") is True + assert NaCl.contains_element_type("d-block") is False + assert NaCl.contains_element_type("f-block") is False + + with pytest.raises(ValueError, match="Invalid category='invalid', pick from"): + EuTiO3.contains_element_type("invalid") def test_chemical_system(self): assert Composition({"Na": 1, "Cl": 1}).chemical_system == "Cl-Na" @@ -692,34 +701,34 @@ def test_chemical_system(self): def test_is_valid(self): formula = "NaCl" - cmp = Composition(formula) - assert cmp.valid + comp = Composition(formula) + assert comp.valid formula = "NaClX" - cmp = Composition(formula) - assert not cmp.valid + comp = Composition(formula) + assert not comp.valid with pytest.raises(ValueError, match="Composition is not valid, contains: Na, Cl, X0+"): Composition("NaClX", strict=True) def test_remove_charges(self): - cmp1 = Composition({"Al3+": 2.0, "O2-": 3.0}) + comp1 = Composition({"Al3+": 2.0, "O2-": 3.0}) - cmp2 = Composition({"Al": 2.0, "O": 3.0}) - assert str(cmp1) != str(cmp2) + comp2 = Composition({"Al": 2.0, "O": 3.0}) + assert str(comp1) != str(comp2) - cmp1 = cmp1.remove_charges() - assert str(cmp1) == str(cmp2) + comp1 = comp1.remove_charges() + assert str(comp1) == str(comp2) - cmp1 = cmp1.remove_charges() - assert str(cmp1) == str(cmp2) + comp1 = comp1.remove_charges() + assert str(comp1) == str(comp2) - cmp1 = Composition({"Fe3+": 2.0, "Fe2+": 3.0, "O2-": 6.0}) - cmp2 = Composition({"Fe": 5.0, "O": 6.0}) - assert str(cmp1) != str(cmp2) + comp1 = Composition({"Fe3+": 2.0, "Fe2+": 3.0, "O2-": 6.0}) + comp2 = Composition({"Fe": 5.0, "O": 6.0}) + assert str(comp1) != str(comp2) - cmp1 = cmp1.remove_charges() - assert str(cmp1) == str(cmp2) + comp1 = comp1.remove_charges() + assert str(comp1) == str(comp2) def test_replace(self): Fe2O3 = Composition("Fe2O3") @@ -794,7 +803,7 @@ def test_isotopes(self): assert "Deuterium" in [x.long_name for x in composition.elements] -class TestChemicalPotential(unittest.TestCase): +class TestChemicalPotential: def test_init(self): dct = {"Fe": 1, Element("Fe"): 1} with pytest.raises(ValueError, match="Duplicate potential specified"): diff --git a/tests/core/test_interface.py b/tests/core/test_interface.py index 2dc8e064e60..73cfe6c58ec 100644 --- a/tests/core/test_interface.py +++ b/tests/core/test_interface.py @@ -2,11 +2,330 @@ import numpy as np from numpy.testing import assert_allclose +from pytest import approx -from pymatgen.core.interface import Interface +from pymatgen.core.interface import GrainBoundary, GrainBoundaryGenerator, Interface +from pymatgen.core.structure import Structure from pymatgen.core.surface import SlabGenerator from pymatgen.symmetry.analyzer import SpacegroupAnalyzer -from pymatgen.util.testing import PymatgenTest +from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest + +TEST_DIR = f"{TEST_FILES_DIR}/grain_boundary" + + +class TestGrainBoundary(PymatgenTest): + @classmethod + def setUpClass(cls): + cls.Cu_conv = Structure.from_file(f"{TEST_DIR}/Cu_mp-30_conventional_standard.cif") + GB_Cu_conv = GrainBoundaryGenerator(cls.Cu_conv) + cls.Cu_GB1 = GB_Cu_conv.gb_from_parameters( + [1, 2, 3], + 123.74898859588858, + expand_times=4, + vacuum_thickness=1.5, + ab_shift=[0.0, 0.0], + plane=[1, 3, 1], + rm_ratio=0.0, + ) + cls.Cu_GB2 = GB_Cu_conv.gb_from_parameters( + [1, 2, 3], + 123.74898859588858, + expand_times=4, + vacuum_thickness=1.5, + ab_shift=[0.2, 0.2], + rm_ratio=0.0, + ) + + def test_init(self): + assert self.Cu_GB1.rotation_angle == approx(123.74898859588858) + assert self.Cu_GB1.vacuum_thickness == approx(1.5) + assert self.Cu_GB2.rotation_axis == [1, 2, 3] + assert_allclose(self.Cu_GB1.ab_shift, [0.0, 0.0]) + assert_allclose(self.Cu_GB2.ab_shift, [0.2, 0.2]) + assert self.Cu_GB1.gb_plane == [1, 3, 1] + assert self.Cu_GB2.gb_plane == [1, 2, 3] + assert_allclose(self.Cu_GB1.init_cell.lattice.matrix, self.Cu_conv.lattice.matrix) + + def test_copy(self): + Cu_GB1_copy = self.Cu_GB1.copy() + assert Cu_GB1_copy.sigma == approx(self.Cu_GB1.sigma) + assert Cu_GB1_copy.rotation_angle == approx(self.Cu_GB1.rotation_angle) + assert Cu_GB1_copy.rotation_axis == self.Cu_GB1.rotation_axis + assert Cu_GB1_copy.gb_plane == self.Cu_GB1.gb_plane + assert_allclose(Cu_GB1_copy.init_cell.lattice.matrix, self.Cu_GB1.init_cell.lattice.matrix) + assert_allclose( + Cu_GB1_copy.oriented_unit_cell.lattice.matrix, + self.Cu_GB1.oriented_unit_cell.lattice.matrix, + ) + assert_allclose(Cu_GB1_copy.lattice.matrix, self.Cu_GB1.lattice.matrix) + + def test_sigma(self): + assert self.Cu_GB1.sigma == approx(9) + assert self.Cu_GB2.sigma == approx(9) + + def test_top_grain(self): + assert len(self.Cu_GB1) == len(self.Cu_GB1.top_grain) * 2 + assert_allclose(self.Cu_GB1.lattice.matrix, self.Cu_GB1.top_grain.lattice.matrix) + + def test_bottom_grain(self): + assert len(self.Cu_GB1) == len(self.Cu_GB1.bottom_grain) * 2 + assert_allclose(self.Cu_GB1.lattice.matrix, self.Cu_GB1.bottom_grain.lattice.matrix) + + def test_coincidents(self): + assert len(self.Cu_GB1) / self.Cu_GB1.sigma == len(self.Cu_GB1.coincidents) + assert len(self.Cu_GB2) / self.Cu_GB2.sigma == len(self.Cu_GB2.coincidents) + + def test_as_dict_and_from_dict(self): + d1 = self.Cu_GB1.as_dict() + d2 = self.Cu_GB2.as_dict() + Cu_GB1_new = GrainBoundary.from_dict(d1) + Cu_GB2_new = GrainBoundary.from_dict(d2) + assert Cu_GB1_new.sigma == approx(self.Cu_GB1.sigma) + assert Cu_GB1_new.rotation_angle == approx(self.Cu_GB1.rotation_angle) + assert Cu_GB1_new.rotation_axis == self.Cu_GB1.rotation_axis + assert Cu_GB1_new.gb_plane == self.Cu_GB1.gb_plane + assert_allclose(Cu_GB1_new.init_cell.lattice.matrix, self.Cu_GB1.init_cell.lattice.matrix) + assert_allclose( + Cu_GB1_new.oriented_unit_cell.lattice.matrix, self.Cu_GB1.oriented_unit_cell.lattice.matrix, atol=1e-9 + ) + assert_allclose(Cu_GB1_new.lattice.matrix, self.Cu_GB1.lattice.matrix) + assert Cu_GB2_new.sigma == approx(self.Cu_GB2.sigma) + assert Cu_GB2_new.rotation_angle == approx(self.Cu_GB2.rotation_angle) + assert Cu_GB2_new.rotation_axis == self.Cu_GB2.rotation_axis + assert Cu_GB2_new.gb_plane == self.Cu_GB2.gb_plane + assert_allclose(Cu_GB2_new.init_cell.lattice.matrix, self.Cu_GB2.init_cell.lattice.matrix) + assert_allclose( + Cu_GB2_new.oriented_unit_cell.lattice.matrix, + self.Cu_GB2.oriented_unit_cell.lattice.matrix, + ) + assert_allclose(Cu_GB2_new.lattice.matrix, self.Cu_GB2.lattice.matrix) + + +class TestGrainBoundaryGenerator(PymatgenTest): + @classmethod + def setUpClass(cls): + cls.Cu_prim = Structure.from_file(f"{TEST_DIR}/Cu_mp-30_primitive.cif") + cls.GB_Cu_prim = GrainBoundaryGenerator(cls.Cu_prim) + cls.Cu_conv = Structure.from_file(f"{TEST_DIR}/Cu_mp-30_conventional_standard.cif") + cls.GB_Cu_conv = GrainBoundaryGenerator(cls.Cu_conv) + cls.Be = Structure.from_file(f"{TEST_DIR}/Be_mp-87_conventional_standard.cif") + cls.GB_Be = GrainBoundaryGenerator(cls.Be) + cls.Pa = Structure.from_file(f"{TEST_DIR}/Pa_mp-62_conventional_standard.cif") + cls.GB_Pa = GrainBoundaryGenerator(cls.Pa) + cls.Br = Structure.from_file(f"{TEST_DIR}/Br_mp-23154_conventional_standard.cif") + cls.GB_Br = GrainBoundaryGenerator(cls.Br) + cls.Bi = Structure.from_file(f"{TEST_DIR}/Bi_mp-23152_primitive.cif") + cls.GB_Bi = GrainBoundaryGenerator(cls.Bi) + + def test_gb_from_parameters(self): + # from fcc primitive cell,axis[1,2,3],sigma 9. + gb_cu_123_prim1 = self.GB_Cu_prim.gb_from_parameters([1, 2, 3], 123.74898859588858, expand_times=2) + lat_mat1 = gb_cu_123_prim1.lattice.matrix + c_vec1 = np.cross(lat_mat1[0], lat_mat1[1]) / np.linalg.norm(np.cross(lat_mat1[0], lat_mat1[1])) + c_len1 = np.dot(lat_mat1[2], c_vec1) + vol_ratio = gb_cu_123_prim1.volume / self.Cu_prim.volume + assert vol_ratio == approx(9 * 2 * 2, abs=1e-8) + # test expand_times and vacuum layer + gb_cu_123_prim2 = self.GB_Cu_prim.gb_from_parameters( + [1, 2, 3], 123.74898859588858, expand_times=4, vacuum_thickness=1.5 + ) + lat_mat2 = gb_cu_123_prim2.lattice.matrix + c_vec2 = np.cross(lat_mat2[0], lat_mat2[1]) / np.linalg.norm(np.cross(lat_mat2[0], lat_mat2[1])) + c_len2 = np.dot(lat_mat2[2], c_vec2) + assert (c_len2 - 1.5 * 2) / c_len1 == approx(2) + + # test normal + gb_cu_123_prim3 = self.GB_Cu_prim.gb_from_parameters([1, 2, 3], 123.74898859588858, expand_times=2, normal=True) + lat_mat3 = gb_cu_123_prim3.lattice.matrix + c_vec3 = np.cross(lat_mat3[0], lat_mat3[1]) / np.linalg.norm(np.cross(lat_mat3[0], lat_mat3[1])) + ab_len3 = np.linalg.norm(np.cross(lat_mat3[2], c_vec3)) + assert ab_len3 == approx(0) + + # test normal in tilt boundary + # The 'finfo(np.float32).eps' is the smallest representable positive number in float32, + # which has been introduced because comparing to just zero or one failed the test by rounding errors. + gb_cu_010_conv1 = self.GB_Cu_conv.gb_from_parameters( + rotation_axis=[0, 1, 0], + rotation_angle=36.8698976458, + expand_times=1, + vacuum_thickness=1.0, + ab_shift=[0.0, 0.0], + rm_ratio=0.0, + plane=[0, 0, 1], + normal=True, + ) + assert np.all(-np.finfo(np.float32).eps <= gb_cu_010_conv1.frac_coords) + assert np.all(1 + np.finfo(np.float32).eps >= gb_cu_010_conv1.frac_coords) + + # from fcc conventional cell,axis [1,2,3], siamg 9 + gb_cu_123_conv1 = self.GB_Cu_conv.gb_from_parameters( + [1, 2, 3], 123.74898859588858, expand_times=4, vacuum_thickness=1.5 + ) + lat_mat1 = gb_cu_123_conv1.lattice.matrix + assert np.dot(lat_mat1[0], [1, 2, 3]) == approx(0) + assert np.dot(lat_mat1[1], [1, 2, 3]) == approx(0) + # test plane + gb_cu_123_conv2 = self.GB_Cu_conv.gb_from_parameters( + [1, 2, 3], + 123.74898859588858, + expand_times=2, + vacuum_thickness=1.5, + normal=False, + plane=[1, 3, 1], + ) + lat_mat2 = gb_cu_123_conv2.lattice.matrix + assert np.dot(lat_mat2[0], [1, 3, 1]) == approx(0) + assert np.dot(lat_mat2[1], [1, 3, 1]) == approx(0) + + # from hex cell,axis [1,1,1], sigma 21 + gb_Be_111_1 = self.GB_Be.gb_from_parameters( + [1, 1, 1], + 147.36310249644626, + ratio=[5, 2], + expand_times=4, + vacuum_thickness=1.5, + plane=[1, 2, 1], + ) + lat_priv = self.Be.lattice.matrix + lat_mat1 = np.matmul(gb_Be_111_1.lattice.matrix, np.linalg.inv(lat_priv)) + assert np.dot(lat_mat1[0], [1, 2, 1]) == approx(0) + assert np.dot(lat_mat1[1], [1, 2, 1]) == approx(0) + # test volume associated with sigma value + gb_Be_111_2 = self.GB_Be.gb_from_parameters([1, 1, 1], 147.36310249644626, ratio=[5, 2], expand_times=4) + vol_ratio = gb_Be_111_2.volume / self.Be.volume + assert vol_ratio == approx(19 * 2 * 4) + # test ratio = None, axis [0,0,1], sigma 7 + gb_Be_111_3 = self.GB_Be.gb_from_parameters([0, 0, 1], 21.786789298261812, ratio=[5, 2], expand_times=4) + gb_Be_111_4 = self.GB_Be.gb_from_parameters([0, 0, 1], 21.786789298261812, ratio=None, expand_times=4) + assert gb_Be_111_3.lattice.abc == gb_Be_111_4.lattice.abc + assert gb_Be_111_3.lattice.angles == gb_Be_111_4.lattice.angles + gb_Be_111_5 = self.GB_Be.gb_from_parameters([3, 1, 0], 180.0, ratio=[5, 2], expand_times=4) + gb_Be_111_6 = self.GB_Be.gb_from_parameters([3, 1, 0], 180.0, ratio=None, expand_times=4) + assert gb_Be_111_5.lattice.abc == gb_Be_111_6.lattice.abc + assert gb_Be_111_5.lattice.angles == gb_Be_111_6.lattice.angles + + # gb from tetragonal cell, axis[1,1,1], sigma 15 + gb_Pa_111_1 = self.GB_Pa.gb_from_parameters( + [1, 1, 1], 151.92751306414706, ratio=[2, 3], expand_times=4, max_search=10 + ) + vol_ratio = gb_Pa_111_1.volume / self.Pa.volume + assert vol_ratio == approx(17 * 2 * 4) + + # gb from orthorhombic cell, axis[1,1,1], sigma 83 + gb_Br_111_1 = self.GB_Br.gb_from_parameters( + [1, 1, 1], + 131.5023374652235, + ratio=[21, 20, 5], + expand_times=4, + max_search=10, + ) + vol_ratio = gb_Br_111_1.volume / self.Br.volume + assert vol_ratio == approx(83 * 2 * 4) + + # gb from rhombohedra cell, axis[1,2,0], sigma 63 + gb_Bi_120_1 = self.GB_Bi.gb_from_parameters( + [1, 2, 0], 63.310675060280246, ratio=[19, 5], expand_times=4, max_search=5 + ) + vol_ratio = gb_Bi_120_1.volume / self.Bi.volume + assert vol_ratio == approx(59 * 2 * 4) + + def test_get_ratio(self): + # hexagnal + Be_ratio = self.GB_Be.get_ratio(max_denominator=2) + assert Be_ratio == [5, 2] + Be_ratio = self.GB_Be.get_ratio(max_denominator=5) + assert Be_ratio == [12, 5] + # tetragonal + Pa_ratio = self.GB_Pa.get_ratio(max_denominator=5) + assert Pa_ratio == [2, 3] + # orthorhombic + Br_ratio = self.GB_Br.get_ratio(max_denominator=5) + assert Br_ratio == [21, 20, 5] + # orthorhombic + Bi_ratio = self.GB_Bi.get_ratio(max_denominator=5) + assert Bi_ratio == [19, 5] + + def test_enum_sigma_cubic(self): + true_100 = [5, 13, 17, 25, 29, 37, 41] + true_110 = [3, 9, 11, 17, 19, 27, 33, 41, 43] + true_111 = [3, 7, 13, 19, 21, 31, 37, 39, 43, 49] + sigma_100 = list(GrainBoundaryGenerator.enum_sigma_cubic(50, [1, 0, 0])) + sigma_110 = list(GrainBoundaryGenerator.enum_sigma_cubic(50, [1, 1, 0])) + sigma_111 = list(GrainBoundaryGenerator.enum_sigma_cubic(50, [1, 1, 1])) + sigma_222 = list(GrainBoundaryGenerator.enum_sigma_cubic(50, [2, 2, 2])) + sigma_888 = list(GrainBoundaryGenerator.enum_sigma_cubic(50, [8, 8, 8])) + + assert sorted(true_100) == sorted(sigma_100) + assert sorted(true_110) == sorted(sigma_110) + assert sorted(true_111) == sorted(sigma_111) + assert sorted(true_111) == sorted(sigma_222) + assert sorted(true_111) == sorted(sigma_888) + + def test_enum_sigma_hex(self): + true_100 = [17, 18, 22, 27, 38, 41] + true_001 = [7, 13, 19, 31, 37, 43, 49] + true_210 = [10, 11, 14, 25, 35, 49] + sigma_100 = list(GrainBoundaryGenerator.enum_sigma_hex(50, [1, 0, 0], [8, 3])) + sigma_001 = list(GrainBoundaryGenerator.enum_sigma_hex(50, [0, 0, 1], [8, 3])) + sigma_210 = list(GrainBoundaryGenerator.enum_sigma_hex(50, [2, 1, 0], [8, 3])) + sigma_420 = list(GrainBoundaryGenerator.enum_sigma_hex(50, [4, 2, 0], [8, 3])) + sigma_840 = list(GrainBoundaryGenerator.enum_sigma_hex(50, [8, 4, 0], [8, 3])) + + assert sorted(true_100) == sorted(sigma_100) + assert sorted(true_001) == sorted(sigma_001) + assert sorted(true_210) == sorted(sigma_210) + assert sorted(true_210) == sorted(sigma_420) + assert sorted(true_210) == sorted(sigma_840) + + def test_enum_sigma_tet(self): + true_100 = [5, 37, 41, 13, 3, 15, 39, 25, 17, 29] + true_331 = [9, 3, 21, 39, 7, 31, 43, 13, 19, 37, 49] + sigma_100 = list(GrainBoundaryGenerator.enum_sigma_tet(50, [1, 0, 0], [9, 1])) + sigma_331 = list(GrainBoundaryGenerator.enum_sigma_tet(50, [3, 3, 1], [9, 1])) + + assert sorted(true_100) == sorted(sigma_100) + assert sorted(true_331) == sorted(sigma_331) + + def test_enum_sigma_ort(self): + true_100 = [41, 37, 39, 5, 15, 17, 13, 3, 25, 29] + sigma_100 = list(GrainBoundaryGenerator.enum_sigma_ort(50, [1, 0, 0], [270, 30, 29])) + + assert sorted(true_100) == sorted(sigma_100) + + def test_enum_sigma_rho(self): + true_100 = [7, 11, 43, 13, 41, 19, 47, 31] + sigma_100 = list(GrainBoundaryGenerator.enum_sigma_rho(50, [1, 0, 0], [15, 4])) + + assert sorted(true_100) == sorted(sigma_100) + + def test_enum_possible_plane_cubic(self): + all_plane = GrainBoundaryGenerator.enum_possible_plane_cubic(4, [1, 1, 1], 60) + assert len(all_plane["Twist"]) == 1 + assert len(all_plane["Symmetric tilt"]) == 6 + assert len(all_plane["Normal tilt"]) == 12 + + def test_get_trans_mat(self): + mat1, mat2 = GrainBoundaryGenerator.get_trans_mat( + [1, 1, 1], + 95.55344419565849, + lat_type="o", + ratio=[10, 20, 21], + surface=[21, 20, 10], + normal=True, + ) + assert np.dot(mat1[0], [21, 20, 10]) == approx(0) + assert np.dot(mat1[1], [21, 20, 10]) == approx(0) + assert np.linalg.det(mat1) == approx(np.linalg.det(mat2)) + ab_len1 = np.linalg.norm(np.cross(mat1[2], [1, 1, 1])) + assert ab_len1 == approx(0) + + def test_get_rotation_angle_from_sigma(self): + true_angle = [12.680383491819821, 167.3196165081802] + angle = GrainBoundaryGenerator.get_rotation_angle_from_sigma(41, [1, 0, 0], lat_type="o", ratio=[270, 30, 29]) + assert_allclose(true_angle, angle) + close_angle = [36.86989764584403, 143.13010235415598] + angle = GrainBoundaryGenerator.get_rotation_angle_from_sigma(6, [1, 0, 0], lat_type="o", ratio=[270, 30, 29]) + assert_allclose(close_angle, angle) class TestInterface(PymatgenTest): @@ -39,8 +358,8 @@ def test_gap_setter(self): assert_allclose(interface.gap, 2.0) - max_sub_c = np.max(np.array([s.frac_coords for s in interface.substrate])[:, 2]) - min_film_c = np.min(np.array([f.frac_coords for f in interface.film])[:, 2]) + max_sub_c = np.max(np.array([site.frac_coords for site in interface.substrate])[:, 2]) + min_film_c = np.min(np.array([site.frac_coords for site in interface.film])[:, 2]) gap = (min_film_c - max_sub_c) * interface.lattice.c assert_allclose(interface.gap, gap) @@ -48,8 +367,8 @@ def test_gap_setter(self): assert_allclose(interface.gap, 3.0) - max_sub_c = np.max(np.array([s.frac_coords for s in interface.substrate])[:, 2]) - min_film_c = np.min(np.array([f.frac_coords for f in interface.film])[:, 2]) + max_sub_c = np.max(np.array([site.frac_coords for site in interface.substrate])[:, 2]) + min_film_c = np.min(np.array([site.frac_coords for site in interface.film])[:, 2]) gap = (min_film_c - max_sub_c) * interface.lattice.c assert_allclose(interface.gap, gap) @@ -61,8 +380,8 @@ def test_in_plane_offset_setter(self): assert_allclose(interface.in_plane_offset, [0.2, 0.2]) test_coords = np.array(init_coords) - for i in interface.film_indices: - test_coords[i] += [0.2, 0.2, 0] + for idx in interface.film_indices: + test_coords[idx] += [0.2, 0.2, 0] assert_allclose(np.mod(test_coords, 1.0), np.mod(interface.frac_coords, 1.0)) def test_vacuum_over_film_setter(self): diff --git a/tests/core/test_ion.py b/tests/core/test_ion.py index 192e694eb75..16037ec35be 100644 --- a/tests/core/test_ion.py +++ b/tests/core/test_ion.py @@ -1,7 +1,7 @@ from __future__ import annotations import random -import unittest +from unittest import TestCase import pytest @@ -9,7 +9,7 @@ from pymatgen.core.ion import Ion -class TestIon(unittest.TestCase): +class TestIon(TestCase): def setUp(self): self.comp = [] self.comp.append(Ion.from_formula("Li+")) @@ -108,9 +108,9 @@ def test_mixed_valence(self): assert comp.formula == "Li8 Fe6 (aq)" def test_oxi_state_guesses(self): - i = Ion.from_formula("SO4-2") - assert i.oxi_state_guesses()[0].get("S") == 6 - assert i.oxi_state_guesses()[0].get("O") == -2 + ion = Ion.from_formula("SO4-2") + assert ion.oxi_state_guesses()[0].get("S") == 6 + assert ion.oxi_state_guesses()[0].get("O") == -2 def test_alphabetical_formula(self): correct_formulas = [ @@ -129,8 +129,8 @@ def test_alphabetical_formula(self): def test_num_atoms(self): correct_num_atoms = [1, 5, 1, 4, 13, 13, 72, 1, 3] - all_natoms = [c.num_atoms for c in self.comp] - assert all_natoms == correct_num_atoms + all_n_atoms = [c.num_atoms for c in self.comp] + assert all_n_atoms == correct_num_atoms def test_anonymized_formula(self): expected_formulas = [ @@ -144,8 +144,8 @@ def test_anonymized_formula(self): "A+2", "ABC(aq)", ] - for i, _ in enumerate(self.comp): - assert self.comp[i].anonymized_formula == expected_formulas[i] + for idx, expected in enumerate(expected_formulas): + assert self.comp[idx].anonymized_formula == expected def test_from_dict(self): sym_dict = {"P": 1, "O": 4, "charge": -2} diff --git a/tests/core/test_lattice.py b/tests/core/test_lattice.py index 4c1e7adee47..3a12ef6eb8e 100644 --- a/tests/core/test_lattice.py +++ b/tests/core/test_lattice.py @@ -12,7 +12,7 @@ from pymatgen.util.testing import PymatgenTest -class LatticeTestCase(PymatgenTest): +class TestLattice(PymatgenTest): def setUp(self): self.lattice = Lattice.cubic(10.0) self.cubic = self.lattice @@ -24,19 +24,25 @@ def setUp(self): self.cubic_partial_pbc = Lattice.cubic(10.0, pbc=(True, True, False)) - family_names = [ - "cubic", - "tetragonal", - "orthorhombic", - "monoclinic", - "hexagonal", - "rhombohedral", - ] - self.families = {} - for name in family_names: + for name in ("cubic", "tetragonal", "orthorhombic", "monoclinic", "hexagonal", "rhombohedral"): self.families[name] = getattr(self, name) + def test_init(self): + len_a = 9.026 + lattice = Lattice.cubic(len_a) + assert lattice is not None, "Initialization from new_cubic failed" + assert_array_equal(lattice.pbc, (True, True, True)) + lattice2 = Lattice(np.eye(3) * len_a) + for ii in range(3): + for jj in range(3): + assert lattice.matrix[ii][jj] == lattice2.matrix[ii][jj], "Inconsistent matrix from two inits!" + assert_array_equal(self.cubic_partial_pbc.pbc, (True, True, False)) + + for bad_pbc in [(True, True), (True, True, True, True), (True, True, 2)]: + with pytest.raises(ValueError, match="pbc must be a tuple of three True/False values, got"): + Lattice(np.eye(3), pbc=bad_pbc) + def test_equal(self): assert self.cubic == self.cubic assert self.cubic == self.lattice @@ -57,17 +63,6 @@ def test_format(self): assert format(self.lattice, ".3f") == lattice_str assert format(self.lattice, ".1fp") == "{10.0, 10.0, 10.0, 90.0, 90.0, 90.0}" - def test_init(self): - len_a = 9.026 - lattice = Lattice.cubic(len_a) - assert lattice is not None, "Initialization from new_cubic failed" - assert_array_equal(lattice.pbc, (True, True, True)) - lattice2 = Lattice(np.eye(3) * len_a) - for ii in range(3): - for jj in range(3): - assert lattice.matrix[ii][jj] == lattice2.matrix[ii][jj], "Inconsistent matrix from two inits!" - assert_array_equal(self.cubic_partial_pbc.pbc, (True, True, False)) - def test_copy(self): cubic_copy = self.cubic.copy() assert cubic_copy == self.cubic @@ -110,8 +105,8 @@ def test_get_vector_along_lattice_directions(self): def test_d_hkl(self): cubic_copy = self.cubic.copy() hkl = (1, 2, 3) - dhkl = ((hkl[0] ** 2 + hkl[1] ** 2 + hkl[2] ** 2) / (cubic_copy.a**2)) ** (-1 / 2) - assert dhkl == cubic_copy.d_hkl(hkl) + d_hkl = ((hkl[0] ** 2 + hkl[1] ** 2 + hkl[2] ** 2) / (cubic_copy.a**2)) ** (-1 / 2) + assert d_hkl == cubic_copy.d_hkl(hkl) def test_reciprocal_lattice(self): recip_latt = self.lattice.reciprocal_lattice @@ -123,8 +118,8 @@ def test_reciprocal_lattice(self): ) # Test the crystallographic version. - recip_latt_xtal = self.lattice.reciprocal_lattice_crystallographic - assert_allclose(recip_latt.matrix, recip_latt_xtal.matrix * 2 * np.pi, 5) + recip_latt_crystallographic = self.lattice.reciprocal_lattice_crystallographic + assert_allclose(recip_latt.matrix, recip_latt_crystallographic.matrix * 2 * np.pi, 5) def test_static_methods(self): expected_lengths = [3.840198, 3.84019885, 3.8401976] @@ -180,9 +175,9 @@ def test_get_lll_reduced_lattice(self): assert np.linalg.det(np.linalg.solve(expected.matrix, reduced_latt.matrix)) == approx(1) assert_allclose(sorted(reduced_latt.abc), sorted(expected.abc)) assert reduced_latt.volume == approx(lattice.volume) - latt = [7.164750, 2.481942, 0.000000, -4.298850, 2.481942, 0.000000, 0.000000, 0.000000, 14.253000] + lattice = [7.164750, 2.481942, 0.000000, -4.298850, 2.481942, 0.000000, 0.000000, 0.000000, 14.253000] expected = Lattice([-4.298850, 2.481942, 0.000000, 2.865900, 4.963884, 0.000000, 0.000000, 0.000000, 14.253000]) - reduced_latt = Lattice(latt).get_lll_reduced_lattice() + reduced_latt = Lattice(lattice).get_lll_reduced_lattice() assert np.linalg.det(np.linalg.solve(expected.matrix, reduced_latt.matrix)) == approx(1) assert_allclose(sorted(reduced_latt.abc), sorted(expected.abc)) @@ -239,23 +234,23 @@ def test_get_niggli_reduced_lattice(self): def test_find_mapping(self): matrix = [[0.1, 0.2, 0.3], [-0.1, 0.2, 0.7], [0.6, 0.9, 0.2]] - latt = Lattice(matrix) + lattice = Lattice(matrix) op = SymmOp.from_origin_axis_angle([0, 0, 0], [2, 3, 3], 35) rot = op.rotation_matrix scale = np.array([[1, 1, 0], [0, 1, 0], [0, 0, 1]]) latt2 = Lattice(np.dot(rot, np.dot(scale, matrix).T).T) - mapping = latt2.find_mapping(latt) + mapping = latt2.find_mapping(lattice) assert isinstance(mapping, tuple) aligned_out, rot_out, scale_out = mapping assert abs(np.linalg.det(rot)) == approx(1) - rotated = SymmOp.from_rotation_and_translation(rot_out).operate_multi(latt.matrix) + rotated = SymmOp.from_rotation_and_translation(rot_out).operate_multi(lattice.matrix) assert_allclose(rotated, aligned_out.matrix) assert_allclose(np.dot(scale_out, latt2.matrix), aligned_out.matrix) - assert_allclose(aligned_out.parameters, latt.parameters) + assert_allclose(aligned_out.parameters, lattice.parameters) assert not np.allclose(aligned_out.parameters, latt2.parameters) def test_find_all_mappings(self): @@ -283,15 +278,15 @@ def test_find_all_mappings(self): assert isinstance(latt, Lattice) def test_mapping_symmetry(self): - latt = Lattice.cubic(1) + lattice = Lattice.cubic(1) l2 = Lattice.orthorhombic(1.1001, 1, 1) - assert latt.find_mapping(l2, ltol=0.1) == l2.find_mapping(latt, ltol=0.1) - assert l2.find_mapping(latt, ltol=0.1) is None + assert lattice.find_mapping(l2, ltol=0.1) == l2.find_mapping(lattice, ltol=0.1) + assert l2.find_mapping(lattice, ltol=0.1) is None l2 = Lattice.orthorhombic(1.0999, 1, 1) - mapping = l2.find_mapping(latt, ltol=0.1) + mapping = l2.find_mapping(lattice, ltol=0.1) assert isinstance(mapping, tuple) assert len(mapping) == 3 - assert latt.find_mapping(l2, ltol=0.1) is not None + assert lattice.find_mapping(l2, ltol=0.1) is not None def test_as_from_dict(self): dct = self.tetragonal.as_dict() @@ -355,13 +350,13 @@ def test_dot_and_norm(self): def test_get_points_in_sphere(self): # This is a non-niggli representation of a cubic lattice - latt = Lattice([[1, 5, 0], [0, 1, 0], [5, 0, 1]]) + lattice = Lattice([[1, 5, 0], [0, 1, 0], [5, 0, 1]]) # evenly spaced points array between 0 and 1 pts = np.array(list(itertools.product(range(5), repeat=3))) / 5 - pts = latt.get_fractional_coords(pts) + pts = lattice.get_fractional_coords(pts) # Test getting neighbors within 1 neighbor distance of the origin - fcoords, dists, inds, images = latt.get_points_in_sphere(pts, [0, 0, 0], 0.20001, zip_results=False) + fcoords, dists, inds, images = lattice.get_points_in_sphere(pts, [0, 0, 0], 0.20001, zip_results=False) assert len(fcoords) == 7 # There are 7 neighbors assert np.isclose(dists, 0.2).sum() == 6 # 6 are at 0.2 assert np.isclose(dists, 0).sum() == 1 # 1 is at 0 @@ -369,7 +364,7 @@ def test_get_points_in_sphere(self): assert_array_equal(images[np.isclose(dists, 0)], [[0, 0, 0]]) # More complicated case, using the zip output - result = latt.get_points_in_sphere(pts, [0.5, 0.5, 0.5], 1.0001) + result = lattice.get_points_in_sphere(pts, [0.5, 0.5, 0.5], 1.0001) assert len(result) == 552 assert len(result[0]) == 4 # coords, dists, ind, supercell @@ -440,8 +435,8 @@ def test_get_distance_and_image(self): def test_get_distance_and_image_strict(self): for _ in range(10): - lengths = [np.random.randint(1, 100) for i in range(3)] - lattice = [np.random.rand(3) * lengths[i] for i in range(3)] + lengths = np.random.randint(1, 100, 3) + lattice = np.random.rand(3, 3) * lengths lattice = Lattice(lattice) f1 = np.random.rand(3) @@ -486,32 +481,32 @@ def test_lll_basis(self): def test_get_miller_index_from_sites(self): # test on a cubic system - m = Lattice.cubic(1) + cubic = Lattice.cubic(1) s1 = np.array([0.5, -1.5, 3]) s2 = np.array([0.5, 3.0, -1.5]) s3 = np.array([2.5, 1.5, -4.0]) - assert m.get_miller_index_from_coords([s1, s2, s3]) == (2, 1, 1) + assert cubic.get_miller_index_from_coords([s1, s2, s3]) == (2, 1, 1) # test on a hexagonal system - m = Lattice([[2.319, -4.01662582, 0.0], [2.319, 4.01662582, 0.0], [0.0, 0.0, 7.252]]) + hexagonal = Lattice([[2.319, -4.01662582, 0.0], [2.319, 4.01662582, 0.0], [0.0, 0.0, 7.252]]) s1 = np.array([2.319, 1.33887527, 6.3455]) s2 = np.array([1.1595, 0.66943764, 4.5325]) s3 = np.array([1.1595, 0.66943764, 0.9065]) - hkl = m.get_miller_index_from_coords([s1, s2, s3]) + hkl = hexagonal.get_miller_index_from_coords([s1, s2, s3]) assert hkl == (2, -1, 0) # test for previous failing structure - m = Lattice([10, 0, 0, 0, 10, 0, 0, 0, 10]) + cubic_from_flat = Lattice([10, 0, 0, 0, 10, 0, 0, 0, 10]) sites = [[0.5, 0.8, 0.8], [0.5, 0.4, 0.2], [0.5, 0.3, 0.7]] - hkl = m.get_miller_index_from_coords(sites, coords_are_cartesian=False) + hkl = cubic_from_flat.get_miller_index_from_coords(sites, coords_are_cartesian=False) assert hkl == (1, 0, 0) # test for more than 3 sites sites = [[0.5, 0.8, 0.8], [0.5, 0.4, 0.2], [0.5, 0.3, 0.7], [0.5, 0.1, 0.2]] - hkl = m.get_miller_index_from_coords(sites, coords_are_cartesian=False) + hkl = cubic_from_flat.get_miller_index_from_coords(sites, coords_are_cartesian=False) assert hkl == (1, 0, 0) def test_points_in_spheres(self): diff --git a/tests/core/test_molecular_orbitals.py b/tests/core/test_molecular_orbitals.py index 618aedac59d..4476c856d17 100644 --- a/tests/core/test_molecular_orbitals.py +++ b/tests/core/test_molecular_orbitals.py @@ -8,7 +8,7 @@ test_case = MolecularOrbitals("NaCl") -class MolecularOrbitalTestCase(PymatgenTest): +class TestMolecularOrbital(PymatgenTest): def test_max_electronegativity(self): test_elec_neg = 2.23 assert test_elec_neg == test_case.max_electronegativity() diff --git a/tests/core/test_operations.py b/tests/core/test_operations.py index 673490f3f7a..d2dcc26de42 100644 --- a/tests/core/test_operations.py +++ b/tests/core/test_operations.py @@ -8,7 +8,7 @@ from pymatgen.util.testing import PymatgenTest -class SymmOpTestCase(PymatgenTest): +class TestSymmOp(PymatgenTest): def setUp(self): self.op = SymmOp.from_axis_angle_and_translation([0, 0, 1], 30, translation_vec=[0, 0, 1]) @@ -217,13 +217,13 @@ def test_xyz(self): # update PymatgenTest for unittest2? # self.assertWarns(UserWarning, self.op.as_xyz_str) - o = SymmOp.from_xyz_str("0.5+x, 0.25+y, 0.75+z") - assert_allclose(o.translation_vector, [0.5, 0.25, 0.75]) - o = SymmOp.from_xyz_str("x + 0.5, y + 0.25, z + 0.75") - assert_allclose(o.translation_vector, [0.5, 0.25, 0.75]) + symm_op = SymmOp.from_xyz_str("0.5+x, 0.25+y, 0.75+z") + assert_allclose(symm_op.translation_vector, [0.5, 0.25, 0.75]) + symm_op = SymmOp.from_xyz_str("x + 0.5, y + 0.25, z + 0.75") + assert_allclose(symm_op.translation_vector, [0.5, 0.25, 0.75]) -class MagSymmOpTestCase(PymatgenTest): +class TestMagSymmOp(PymatgenTest): def test_xyzt_string(self): xyzt_strings = ["x, y, z, +1", "x, y, z, -1", "-y+1/2, x+1/2, x+1/2, +1"] diff --git a/tests/core/test_periodic_table.py b/tests/core/test_periodic_table.py index b4885cbe852..bc24acf4cd8 100644 --- a/tests/core/test_periodic_table.py +++ b/tests/core/test_periodic_table.py @@ -2,19 +2,20 @@ import math import pickle -import unittest from copy import deepcopy +from enum import Enum import numpy as np import pytest from pytest import approx from pymatgen.core import DummySpecies, Element, Species, get_el_sp -from pymatgen.core.periodic_table import ElementBase +from pymatgen.core.periodic_table import ElementBase, ElementType +from pymatgen.core.units import Ha_to_eV from pymatgen.util.testing import PymatgenTest -class ElementTestCase(PymatgenTest): +class TestElement(PymatgenTest): def test_init(self): assert Element("Fe").symbol == "Fe" @@ -211,8 +212,8 @@ def test_term_symbols(self): ], # f3 "Ne": [["1S0"]], } - for k, v in cases.items(): - assert Element(k).term_symbols == v + for key, val in cases.items(): + assert Element(key).term_symbols == val def test_ground_state_term_symbol(self): cases = { @@ -222,8 +223,8 @@ def test_ground_state_term_symbol(self): "Ti": "3F2.0", # d2 "Pr": "4I4.5", } # f3 - for k, v in cases.items(): - assert Element(k).ground_state_term_symbol == v + for key, val in cases.items(): + assert Element(key).ground_state_term_symbol == val def test_attributes(self): is_true = { @@ -238,71 +239,80 @@ def test_attributes(self): ("O", "Te"): "is_chalcogen", } - for key, v in is_true.items(): + for key, val in is_true.items(): for sym in key: - assert getattr(Element(sym), v), f"{sym=} is false" + assert getattr(Element(sym), val), f"{sym=} is false" - keys = [ - "mendeleev_no", + keys = ( "atomic_mass", - "electronic_structure", + "atomic_orbitals", + "atomic_orbitals_eV", "atomic_radius", - "min_oxidation_state", - "max_oxidation_state", - "electrical_resistivity", - "velocity_of_sound", - "reflectivity", - "refractive_index", - "poissons_ratio", - "molar_volume", - "thermal_conductivity", - "melting_point", + "average_anionic_radius", + "average_cationic_radius", + "average_ionic_radius", "boiling_point", - "liquid_range", - "critical_temperature", - "superconduction_temperature", - "bulk_modulus", - "youngs_modulus", "brinell_hardness", - "rigidity_modulus", - "mineral_hardness", - "vickers_hardness", - "density_of_solid", - "atomic_orbitals", + "bulk_modulus", "coefficient_of_linear_thermal_expansion", - "oxidation_states", "common_oxidation_states", - "average_ionic_radius", - "average_cationic_radius", - "average_anionic_radius", + "critical_temperature", + "density_of_solid", + "electrical_resistivity", + "electronic_structure", + "ground_level", "ionic_radii", + "ionization_energies", + "iupac_ordering", + "liquid_range", "long_name", + "max_oxidation_state", + "melting_point", + "mendeleev_no", "metallic_radius", - "iupac_ordering", - "ground_level", - "ionization_energies", - ] + "min_oxidation_state", + "mineral_hardness", + "molar_volume", + "oxidation_states", + "poissons_ratio", + "reflectivity", + "refractive_index", + "rigidity_modulus", + "superconduction_temperature", + "thermal_conductivity", + "velocity_of_sound", + "vickers_hardness", + "youngs_modulus", + ) # Test all elements up to Uranium for idx in range(1, 104): el = Element.from_Z(idx) for key in keys: - k_str = key.capitalize().replace("_", " ") - if k_str in el.data and (not str(el.data[k_str]).startswith("no data")): + key_str = key.capitalize().replace("_", " ") + if key_str in el.data and (not str(el.data[key_str]).startswith("no data")): assert getattr(el, key) is not None elif key == "long_name": assert el.long_name == el.data["Name"] elif key == "iupac_ordering": assert "IUPAC ordering" in el.data assert getattr(el, key) is not None - el = Element.from_Z(idx) + if len(el.oxidation_states) > 0: assert max(el.oxidation_states) == el.max_oxidation_state assert min(el.oxidation_states) == el.min_oxidation_state - if el.symbol not in ["He", "Ne", "Ar"]: + if el.symbol not in {"He", "Ne", "Ar"}: assert el.X > 0, f"No electroneg for {el}" + # check atomic_orbitals_eV is Ha_to_eV * atomic_orbitals + for el in Element: + if el.atomic_orbitals is None: + continue + assert el.atomic_orbitals_eV == approx( + {orb: energy * Ha_to_eV for orb, energy in el.atomic_orbitals.items()} + ) + with pytest.raises(ValueError, match="Unexpected atomic number Z=1000"): Element.from_Z(1000) @@ -366,7 +376,7 @@ def test_isotope(self): assert [el.atomic_mass for el in elems] == [1.00794, 2.013553212712, 3.0155007134] -class SpeciesTestCase(PymatgenTest): +class TestSpecies(PymatgenTest): def setUp(self): self.specie1 = Species.from_str("Fe2+") self.specie2 = Species("Fe", 3) @@ -443,8 +453,8 @@ def test_get_crystal_field_spin(self): with pytest.raises(ValueError, match="Invalid coordination or spin config"): Species("Fe", 2).get_crystal_field_spin("hex") - s = Species("Co", 3).get_crystal_field_spin("tet", spin_config="low") - assert s == 2 + spin = Species("Co", 3).get_crystal_field_spin("tet", spin_config="low") + assert spin == 2 def test_get_nmr_mom(self): assert Species("H").get_nmr_quadrupole_moment() == 2.860 @@ -525,7 +535,7 @@ def test_symbol_oxi_state_str(symbol_oxi, expected_element, expected_oxi_state): assert species._oxi_state == expected_oxi_state -class DummySpeciesTestCase(unittest.TestCase): +class TestDummySpecies: def test_init(self): self.specie1 = DummySpecies("X") with pytest.raises(ValueError, match="Xe contains Xe, which is a valid element symbol"): @@ -561,29 +571,42 @@ def test_from_str(self): def test_pickle(self): el1 = DummySpecies("X", 3) - o = pickle.dumps(el1) - assert el1 == pickle.loads(o) + pickled = pickle.dumps(el1) + assert el1 == pickle.loads(pickled) def test_sort(self): - r = sorted([Element.Fe, DummySpecies("X")]) - assert r == [DummySpecies("X"), Element.Fe] + Fe, X = Element.Fe, DummySpecies("X") + assert sorted([Fe, X]) == [X, Fe] assert DummySpecies("X", 3) < DummySpecies("X", 4) sp = Species("Fe", 2, spin=5) with pytest.raises(AttributeError) as exc: sp.spin = 6 + # for some reason different message on Windows and Mac. on Linux: 'can't set attribute' assert "can't set attribute" in str(exc.value) or "property 'spin' of 'Species' object has no setter" in str( exc.value - ) # 'can't set attribute' on Linux, for some reason different message on Windows and Mac + ) assert sp.spin == 5 -class TestFunc(unittest.TestCase): - def test_get_el_sp(self): - assert get_el_sp("Fe2+") == Species("Fe", 2) - assert get_el_sp("3") == Element.Li - assert get_el_sp("3.0") == Element.Li - assert get_el_sp("U") == Element.U - assert get_el_sp("X2+") == DummySpecies("X", 2) - assert get_el_sp("Mn3+") == Species("Mn", 3) +def test_get_el_sp(): + assert get_el_sp("Fe2+") == Species("Fe", 2) + assert get_el_sp("3") == Element.Li + assert get_el_sp(5) == Element.B + assert get_el_sp("3.0") == Element.Li + assert get_el_sp("+3.0") == Element.Li + assert get_el_sp(2.0) == Element.He + assert get_el_sp("U") == Element.U + assert get_el_sp("X2+") == DummySpecies("X", 2) + assert get_el_sp("Mn3+") == Species("Mn", 3) + assert get_el_sp("X2+spin=5") == DummySpecies("X", 2, spin=5) + + with pytest.raises(ValueError, match="Can't parse Element or Species from None"): + get_el_sp(None) + + +def test_element_type(): + assert isinstance(ElementType.actinoid, Enum) + assert isinstance(ElementType.metalloid, Enum) + assert len(ElementType) == 17 diff --git a/tests/core/test_settings.py b/tests/core/test_settings.py index c55774da641..eb0ba992bbf 100644 --- a/tests/core/test_settings.py +++ b/tests/core/test_settings.py @@ -2,12 +2,13 @@ from typing import TYPE_CHECKING +from pymatgen.core import _load_pmg_settings + if TYPE_CHECKING: from pathlib import Path from pytest import MonkeyPatch -from pymatgen.core import _load_pmg_settings __author__ = "Janosh Riebesell" __date__ = "2022-10-21" diff --git a/tests/core/test_sites.py b/tests/core/test_sites.py index 8f8caf15650..4c6f67d6fa4 100644 --- a/tests/core/test_sites.py +++ b/tests/core/test_sites.py @@ -141,9 +141,9 @@ def test_distance_and_image(self): assert ( not (abs(dist_old - dist_new) < 1e-8) ^ (jimage_old == jimage_new).all() ), "If old dist == new dist, images must be the same!" - latt = Lattice.from_parameters(3.0, 3.1, 10.0, 2.96, 2.0, 1.0) - site = PeriodicSite("Fe", [0.1, 0.1, 0.1], latt) - site2 = PeriodicSite("Fe", [0.99, 0.99, 0.99], latt) + lattice = Lattice.from_parameters(3.0, 3.1, 10.0, 2.96, 2.0, 1.0) + site = PeriodicSite("Fe", [0.1, 0.1, 0.1], lattice) + site2 = PeriodicSite("Fe", [0.99, 0.99, 0.99], lattice) dist, img = site.distance_and_image(site2) assert dist == approx(0.15495358379511573) assert list(img) == [-11, 6, 0] @@ -250,8 +250,7 @@ def get_distance_and_image_old(site1, site2, jimage=None): nearest to the site is found. Returns: - (distance, jimage): - distance and periodic lattice translations of the other site + tuple[float, np.array]: distance and periodic lattice translations of the other site for which the distance applies. Note: diff --git a/tests/core/test_structure.py b/tests/core/test_structure.py index 94c25e04fd6..c48bf262d0c 100644 --- a/tests/core/test_structure.py +++ b/tests/core/test_structure.py @@ -526,10 +526,10 @@ def test_get_primitive_structure(self): assert len(struct.get_primitive_structure()) == 4 def test_primitive_cell_site_merging(self): - latt = Lattice.cubic(10) + lattice = Lattice.cubic(10) coords = [[0, 0, 0], [0, 0, 0.5], [0, 0, 0.26], [0, 0, 0.74]] sp = ["Ag", "Ag", "Be", "Be"] - struct = Structure(latt, sp, coords) + struct = Structure(lattice, sp, coords) dm = struct.get_primitive_structure().distance_matrix assert_allclose(dm, [[0, 2.5], [2.5, 0]]) @@ -572,7 +572,7 @@ def test_primitive_positions(self): assert prim.distance_matrix[0, 1] == approx(1.0203432356739286) def test_primitive_structure_volume_check(self): - latt = Lattice.tetragonal(10, 30) + lattice = Lattice.tetragonal(10, 30) coords = [ [0.5, 0.8, 0], [0.5, 0.2, 0], @@ -581,7 +581,7 @@ def test_primitive_structure_volume_check(self): [0.5, 0.5, 0.666], [0.5, 0.2, 0.666], ] - struct = IStructure(latt, ["Ag"] * 6, coords) + struct = IStructure(lattice, ["Ag"] * 6, coords) primitive = struct.get_primitive_structure(tolerance=0.1) assert len(primitive) == 6 @@ -670,15 +670,14 @@ def test_get_neighbor_list(self): # @skipIf(not os.getenv("CI"), reason="Only run this in CI tests") # def test_get_all_neighbors_crosscheck_old(self): - # # for i in range(100): # alpha, beta = np.random.rand(2) * 90 # a, b, c = 3 + np.random.rand(3) * 5 # species = ["H"] * 5 # frac_coords = np.random.rand(5, 3) # try: - # latt = Lattice.from_parameters(a, b, c, alpha, beta, 90) - # struct = Structure.from_spacegroup("P1", latt, species, frac_coords) + # lattice = Lattice.from_parameters(a, b, c, alpha, beta, 90) + # struct = Structure.from_spacegroup("P1", lattice, species, frac_coords) # for nn_new, nn_old in zip(struct.get_all_neighbors(4), struct.get_all_neighbors_old(4)): # sites1 = [i[0] for i in nn_new] # sites2 = [i[0] for i in nn_old] @@ -711,7 +710,7 @@ def test_get_neighbor_list(self): # }, # "sites": [ # { - # "species": [{"element": "Mn", "oxidation_state": 0, "spin": Spin.down, "occu": 1}], + # "species": [{"element": "Mn", "oxidation_state": 0, "spin": Spin.down, "occu": 1}], # "abc": [0.0, 0.5, 0.5], # "xyz": [2.8730499999999997, 3.83185, 4.1055671618015446e-16], # "label": "Mn0+,spin=-1", @@ -959,9 +958,27 @@ def test_sort(self): assert self.struct[0].species_string == "Si" assert self.struct[1].species_string == "F" + def test_replace(self): + assert self.struct.formula == "Si2" + struct = self.struct.replace(0, "O") + assert struct is self.struct + assert struct.formula == "Si1 O1" + assert_allclose(struct[0].frac_coords, [0, 0, 0]) + struct.replace(0, "O", coords=[0.25, 0.25, 0.25]) + assert struct.formula == "Si1 O1" + assert_allclose(struct[0].frac_coords, [0.25, 0.25, 0.25]) + struct.replace(0, "O", properties={"magmom": 1}) + assert struct.formula == "Si1 O1" + assert struct[0].magmom == 1 + struct.replace(0, "O", properties={"magmom": 2}, coords=[0.9, 0.9, 0.9]) + assert struct.formula == "Si1 O1" + assert struct[0].magmom == 2 + assert_allclose(struct[0].frac_coords, [0.9, 0.9, 0.9]) + def test_replace_species(self): - struct = self.struct - struct.replace_species({"Si": "Na"}) + assert self.struct.formula == "Si2" + struct = self.struct.replace_species({"Si": "Na"}) + assert struct is self.struct assert struct.formula == "Na2" # test replacement with a dictionary @@ -1023,18 +1040,25 @@ def test_append_insert_remove_replace_substitute(self): assert struct.n_elems == 4 struct.replace_species({"Ge": "Si"}) - struct.substitute(1, "hydroxyl") + substituted = struct.substitute(1, "hydroxyl") + assert substituted is struct assert struct.formula == "Si1 H1 N1 O1" assert struct.symbol_set == ("H", "N", "O", "Si") + with pytest.raises( + ValueError, match="Can't find functional group 'OH' in list. Provide explicit coordinates instead" + ): + substituted = struct.substitute(2, "OH") # Distance between O and H assert struct.get_distance(2, 3) == approx(0.96) # Distance between Si and H assert struct.get_distance(0, 3) == approx(2.09840889) - struct.remove_species(["H"]) + h_removed = struct.remove_species(["H"]) + assert h_removed is struct assert struct.formula == "Si1 N1 O1" - struct.remove_sites([1, 2]) + sites_removed = struct.remove_sites([1, 2]) + assert sites_removed is struct assert struct.formula == "Si1" def test_add_remove_site_property(self): @@ -1137,10 +1161,10 @@ def test_add_oxidation_state_by_guess(self): assert site.specie in expected def test_add_remove_spin_states(self): - latt = Lattice.cubic(4.17) + lattice = Lattice.cubic(4.17) species = ["Ni", "O"] coords = [[0, 0, 0], [0.5, 0.5, 0.5]] - nio = Structure.from_spacegroup(225, latt, species, coords) + nio = Structure.from_spacegroup(225, lattice, species, coords) # should do nothing, but not fail nio1 = nio.remove_spin() @@ -1289,10 +1313,10 @@ def test_make_supercell_labeled(self): assert set(struct.labels) == {"Si1", "Si2"} def test_disordered_supercell_primitive_cell(self): - latt = Lattice.cubic(2) + lattice = Lattice.cubic(2) coords = [[0.5, 0.5, 0.5]] sp = [{"Si": 0.54738}] - struct = Structure(latt, sp, coords) + struct = Structure(lattice, sp, coords) # this supercell often breaks things struct.make_supercell([[0, -1, 1], [-1, 1, 0], [1, 1, 1]]) assert len(struct.get_primitive_structure()) == 1 @@ -1452,19 +1476,19 @@ def test_merge_sites(self): assert_allclose(struct[1].frac_coords, [0.5, 0.5, 0.5005]) # Test for TaS2 with spacegroup 166 in 160 setting. - latt = Lattice.hexagonal(3.374351, 20.308941) + lattice = Lattice.hexagonal(3.374351, 20.308941) species = ["Ta", "S", "S"] coords = [ [0, 0, 0.944333], [0.333333, 0.666667, 0.353424], [0.666667, 0.333333, 0.535243], ] - tas2 = Structure.from_spacegroup(160, latt, species, coords) + tas2 = Structure.from_spacegroup(160, lattice, species, coords) assert len(tas2) == 13 tas2.merge_sites(mode="d") assert len(tas2) == 9 - latt = Lattice.hexagonal(3.587776, 19.622793) + lattice = Lattice.hexagonal(3.587776, 19.622793) species = ["Na", "V", "S", "S"] coords = [ [0.333333, 0.666667, 0.165000], @@ -1472,13 +1496,13 @@ def test_merge_sites(self): [0.333333, 0.666667, 0.399394], [0.666667, 0.333333, 0.597273], ] - navs2 = Structure.from_spacegroup(160, latt, species, coords) + navs2 = Structure.from_spacegroup(160, lattice, species, coords) assert len(navs2) == 18 navs2.merge_sites(mode="d") assert len(navs2) == 12 # Test that we can average the site properties that are floats - latt = Lattice.hexagonal(3.587776, 19.622793) + lattice = Lattice.hexagonal(3.587776, 19.622793) species = ["Na", "V", "S", "S"] coords = [ [0.333333, 0.666667, 0.165000], @@ -1487,7 +1511,7 @@ def test_merge_sites(self): [0.666667, 0.333333, 0.597273], ] site_props = {"prop1": [3.0, 5.0, 7.0, 11.0]} - navs2 = Structure.from_spacegroup(160, latt, species, coords, site_properties=site_props) + navs2 = Structure.from_spacegroup(160, lattice, species, coords, site_properties=site_props) navs2.insert(0, "Na", coords[0], properties={"prop1": 100.0}) navs2.merge_sites(mode="a") assert len(navs2) == 12 @@ -1970,7 +1994,7 @@ def test_get_dist_matrix(self): def test_get_zmatrix(self): mol = IMolecule(["C", "H", "H", "H", "H"], self.coords) - zmatrix = """C + z_matrix = """C H 1 B1 H 1 B2 2 A2 H 1 B3 2 A3 3 D3 @@ -1986,10 +2010,10 @@ def test_get_zmatrix(self): A4=109.471213 D4=119.999966 """ - assert self.assert_str_content_equal(mol.get_zmatrix(), zmatrix) + assert self.assert_str_content_equal(mol.get_zmatrix(), z_matrix) def test_break_bond(self): - (mol1, mol2) = self.mol.break_bond(0, 1) + mol1, mol2 = self.mol.break_bond(0, 1) assert mol1.formula == "H3 C1" assert mol2.formula == "H1" @@ -2162,7 +2186,7 @@ def test_rotate_sites(self): assert returned is self.mol assert_allclose(self.mol.cart_coords[2], [0.889164737, 0.513359500, -0.363000000]) - def test_replace(self): + def test_replace_species(self): self.mol[0] = "Ge" assert self.mol.formula == "Ge1 H4" @@ -2226,8 +2250,7 @@ def test_substitute(self): returned = self.mol.substitute(1, sub) assert returned is self.mol assert self.mol.get_distance(0, 4) == approx(1.54) - f = Molecule(["X", "F"], [[0, 0, 0], [0, 0, 1.11]]) - self.mol.substitute(2, f) + self.mol.substitute(2, Molecule(["X", "F"], [[0, 0, 0], [0, 0, 1.11]])) assert self.mol.get_distance(0, 7) == approx(1.35) oh = Molecule( ["X", "O", "H"], diff --git a/tests/core/test_surface.py b/tests/core/test_surface.py index 4961b9dbc33..0fc58e56284 100644 --- a/tests/core/test_surface.py +++ b/tests/core/test_surface.py @@ -30,7 +30,7 @@ class TestSlab(PymatgenTest): def setUp(self): - zno1 = Structure.from_file(f"{TEST_FILES_DIR}/surface_tests/ZnO-wz.cif", primitive=False) + zno1 = Structure.from_file(f"{TEST_FILES_DIR}/surfaces/ZnO-wz.cif", primitive=False) zno55 = SlabGenerator(zno1, [1, 0, 0], 5, 5, lll_reduce=False, center_slab=False).get_slab() Ti = Structure( @@ -49,18 +49,17 @@ def setUp(self): [[0, 0, 0], [0, 0.5, 0.5], [0.5, 0, 0.5], [0.5, 0.5, 0]], ) - m = [[3.913449, 0, 0], [0, 3.913449, 0], [0, 0, 5.842644]] - latt = Lattice(m) - fcoords = [[0.5, 0, 0.222518], [0, 0.5, 0.777482], [0, 0, 0], [0, 0, 0.5], [0.5, 0.5, 0]] - non_laue = Structure(latt, ["Nb", "Nb", "N", "N", "N"], fcoords) + lattice = Lattice([[3.913449, 0, 0], [0, 3.913449, 0], [0, 0, 5.842644]]) + frac_coords = [[0.5, 0, 0.222518], [0, 0.5, 0.777482], [0, 0, 0], [0, 0, 0.5], [0.5, 0.5, 0]] + non_laue = Structure(lattice, ["Nb", "Nb", "N", "N", "N"], frac_coords) self.ti = Ti - self.agfcc = Ag_fcc + self.ag_fcc = Ag_fcc self.zno1 = zno1 self.zno55 = zno55 self.non_laue = non_laue - self.h = Structure(Lattice.cubic(3), ["H"], [[0, 0, 0]]) - self.libcc = Structure(Lattice.cubic(3.51004), ["Li", "Li"], [[0, 0, 0], [0.5, 0.5, 0.5]]) + self.hydrogen = Structure(Lattice.cubic(3), ["H"], [[0, 0, 0]]) + self.li_bcc = Structure(Lattice.cubic(3.51004), ["Li", "Li"], [[0, 0, 0], [0.5, 0.5, 0.5]]) def test_init(self): zno_slab = Slab( @@ -72,8 +71,8 @@ def test_init(self): 0, self.zno55.scale_factor, ) - m = self.zno55.lattice.matrix - area = np.linalg.norm(np.cross(m[0], m[1])) + matrix = self.zno55.lattice.matrix + area = np.linalg.norm(np.cross(matrix[0], matrix[1])) assert zno_slab.surface_area == approx(area) assert zno_slab.lattice.parameters == self.zno55.lattice.parameters assert zno_slab.oriented_unit_cell.composition == self.zno1.composition @@ -116,8 +115,8 @@ def test_add_adsorbate_atom(self): assert str(zno_slab[8].specie) == "H" assert zno_slab.get_distance(1, 8) == approx(1.0) assert zno_slab[8].c > zno_slab[0].c - m = self.zno55.lattice.matrix - area = np.linalg.norm(np.cross(m[0], m[1])) + matrix = self.zno55.lattice.matrix + area = np.linalg.norm(np.cross(matrix[0], matrix[1])) assert zno_slab.surface_area == approx(area) assert zno_slab.lattice.parameters == self.zno55.lattice.parameters @@ -152,7 +151,7 @@ def test_surface_sites_and_symmetry(self): for boolean in [True, False]: # We will also set the slab to be centered and # off centered in order to test the center of mass - slab_gen = SlabGenerator(self.agfcc, (3, 1, 0), 10, 10, center_slab=boolean) + slab_gen = SlabGenerator(self.ag_fcc, (3, 1, 0), 10, 10, center_slab=boolean) slab = slab_gen.get_slabs()[0] surf_sites_dict = slab.get_surface_sites() assert len(surf_sites_dict["top"]) == len(surf_sites_dict["bottom"]) @@ -163,7 +162,7 @@ def test_surface_sites_and_symmetry(self): # Test if the ratio of surface sites per area is # constant, ie are the surface energies the same r1 = total_surf_sites / (2 * slab.surface_area) - slab_gen = SlabGenerator(self.agfcc, (3, 1, 0), 10, 10, primitive=False) + slab_gen = SlabGenerator(self.ag_fcc, (3, 1, 0), 10, 10, primitive=False) slab = slab_gen.get_slabs()[0] surf_sites_dict = slab.get_surface_sites() total_surf_sites = sum(len(surf_sites_dict[key]) for key in surf_sites_dict) @@ -192,7 +191,7 @@ def test_symmetrization(self): ) all_Ag_fcc_slabs = generate_all_slabs( - self.agfcc, + self.ag_fcc, 2, 10, 10, @@ -267,10 +266,10 @@ def test_oriented_unit_cell(self): # parameter for get_primitive_structure is working properly def surface_area(s): - m = s.lattice.matrix - return np.linalg.norm(np.cross(m[0], m[1])) + matrix = s.lattice.matrix + return np.linalg.norm(np.cross(matrix[0], matrix[1])) - all_slabs = generate_all_slabs(self.agfcc, 2, 10, 10, max_normal_search=3) + all_slabs = generate_all_slabs(self.ag_fcc, 2, 10, 10, max_normal_search=3) for slab in all_slabs: ouc = slab.oriented_unit_cell @@ -314,7 +313,7 @@ def test_as_dict(self): d = json.loads(dict_str) assert slab == Slab.from_dict(d) - # test initialising with a list scale_factor + # test initializing with a list scale_factor slab = Slab( self.zno55.lattice, self.zno55.species, @@ -322,7 +321,7 @@ def test_as_dict(self): self.zno55.miller_index, self.zno55.oriented_unit_cell, 0, - self.zno55.scale_factor.tolist(), + self.zno55.scale_factor, ) dict_str = json.dumps(slab.as_dict()) d = json.loads(dict_str) @@ -376,9 +375,8 @@ def test_get_slab(self): assert len(slab_non_prim) == len(slab) * 4 # Some randomized testing of cell vectors - for _ in range(1, 231): - i = random.randint(1, 230) - sg = SpaceGroup.from_int_number(i) + for spg_int in np.random.randint(1, 230, 10): + sg = SpaceGroup.from_int_number(spg_int) if sg.crystal_system == "hexagonal" or ( sg.crystal_system == "trigonal" and ( @@ -387,11 +385,11 @@ def test_get_slab(self): in [143, 144, 145, 147, 149, 150, 151, 152, 153, 154, 156, 157, 158, 159, 162, 163, 164, 165] ) ): - latt = Lattice.hexagonal(5, 10) + lattice = Lattice.hexagonal(5, 10) else: # Cubic lattice is compatible with all other space groups. - latt = Lattice.cubic(5) - struct = Structure.from_spacegroup(i, latt, ["H"], [[0, 0, 0]]) + lattice = Lattice.cubic(5) + struct = Structure.from_spacegroup(spg_int, lattice, ["H"], [[0, 0, 0]]) miller = (0, 0, 0) while miller == (0, 0, 0): miller = ( @@ -400,9 +398,9 @@ def test_get_slab(self): random.randint(0, 6), ) gen = SlabGenerator(struct, miller, 10, 10) - a, b, _c = gen.oriented_unit_cell.lattice.matrix - assert np.dot(a, gen._normal) == approx(0) - assert np.dot(b, gen._normal) == approx(0) + a_vec, b_vec, _c_vec = gen.oriented_unit_cell.lattice.matrix + assert np.dot(a_vec, gen._normal) == approx(0) + assert np.dot(b_vec, gen._normal) == approx(0) def test_normal_search(self): fcc = Structure.from_spacegroup("Fm-3m", Lattice.cubic(3), ["Fe"], [[0, 0, 0]]) @@ -432,8 +430,8 @@ def test_get_slabs(self): gen = SlabGenerator(self.get_structure("CsCl"), [0, 0, 1], 10, 10) # Test orthogonality of some internal variables. - a, b, _c = gen.oriented_unit_cell.lattice.matrix - assert np.dot(a, gen._normal) == approx(0) + a_len, b, _c = gen.oriented_unit_cell.lattice.matrix + assert np.dot(a_len, gen._normal) == approx(0) assert np.dot(b, gen._normal) == approx(0) assert len(gen.get_slabs()) == 1 @@ -456,12 +454,12 @@ def test_get_slabs(self): # slabs is of sites in LiFePO4 unit cell - 2 + 1. assert len(gen.get_slabs(tol=1e-4, ftol=1e-4)) == 15 - LiCoO2 = Structure.from_file(f"{TEST_FILES_DIR}/surface_tests/icsd_LiCoO2.cif", primitive=False) + LiCoO2 = Structure.from_file(f"{TEST_FILES_DIR}/surfaces/icsd_LiCoO2.cif", primitive=False) gen = SlabGenerator(LiCoO2, [0, 0, 1], 10, 10) lco = gen.get_slabs(bonds={("Co", "O"): 3}) assert len(lco) == 1 - a, b, _c = gen.oriented_unit_cell.lattice.matrix - assert np.dot(a, gen._normal) == approx(0) + a_len, b, _c = gen.oriented_unit_cell.lattice.matrix + assert np.dot(a_len, gen._normal) == approx(0) assert np.dot(b, gen._normal) == approx(0) scc = Structure.from_spacegroup("Pm-3m", Lattice.cubic(3), ["Fe"], [[0, 0, 0]]) @@ -475,13 +473,13 @@ def test_get_slabs(self): # Test whether using units of hkl planes instead of Angstroms for # min_slab_size and min_vac_size will give us the same number of atoms n_atoms = [] - for a in [1, 1.4, 2.5, 3.6]: - struct = Structure.from_spacegroup("Im-3m", Lattice.cubic(a), ["Fe"], [[0, 0, 0]]) + for a_len in [1, 1.4, 2.5, 3.6]: + struct = Structure.from_spacegroup("Im-3m", Lattice.cubic(a_len), ["Fe"], [[0, 0, 0]]) slab_gen = SlabGenerator(struct, (1, 1, 1), 10, 10, in_unit_planes=True, max_normal_search=2) n_atoms.append(len(slab_gen.get_slab())) - n = n_atoms[0] - for i in n_atoms: - assert n == i + # Check if the number of atoms in all slabs is the same + for n_a in n_atoms: + assert n_atoms[0] == n_a def test_triclinic_TeI(self): # Test case for a triclinic structure of TeI. Only these three @@ -490,14 +488,14 @@ def test_triclinic_TeI(self): # in other Miller indices can cause some ambiguity when choosing a # higher tolerance. n_slabs = {(0, 0, 1): 5, (0, 1, 0): 3, (1, 0, 0): 7} - TeI = Structure.from_file(f"{TEST_FILES_DIR}/surface_tests/icsd_TeI.cif", primitive=False) + TeI = Structure.from_file(f"{TEST_FILES_DIR}/surfaces/icsd_TeI.cif", primitive=False) for k, v in n_slabs.items(): triclinic_TeI = SlabGenerator(TeI, k, 10, 10) TeI_slabs = triclinic_TeI.get_slabs() assert v == len(TeI_slabs) def test_get_orthogonal_c_slab(self): - TeI = Structure.from_file(f"{TEST_FILES_DIR}/surface_tests/icsd_TeI.cif", primitive=False) + TeI = Structure.from_file(f"{TEST_FILES_DIR}/surfaces/icsd_TeI.cif", primitive=False) triclinic_TeI = SlabGenerator(TeI, (0, 0, 1), 10, 10) TeI_slabs = triclinic_TeI.get_slabs() slab = TeI_slabs[0] @@ -506,7 +504,7 @@ def test_get_orthogonal_c_slab(self): assert norm_slab.lattice.angles[1] == approx(90) def test_get_orthogonal_c_slab_site_props(self): - TeI = Structure.from_file(f"{TEST_FILES_DIR}/surface_tests/icsd_TeI.cif", primitive=False) + TeI = Structure.from_file(f"{TEST_FILES_DIR}/surfaces/icsd_TeI.cif", primitive=False) triclinic_TeI = SlabGenerator(TeI, (0, 0, 1), 10, 10) TeI_slabs = triclinic_TeI.get_slabs() slab = TeI_slabs[0] @@ -540,7 +538,7 @@ def test_get_tasker2_slabs(self): assert slab.is_symmetric() assert not slab.is_polar() - def test_nonstoichiometric_symmetrized_slab(self): + def test_non_stoichiometric_symmetrized_slab(self): # For the (111) halite slab, sometimes a non-stoichiometric # system is preferred over the stoichiometric Tasker 2. slab_gen = SlabGenerator(self.MgO, (1, 1, 1), 10, 10, max_normal_search=1) @@ -601,14 +599,14 @@ def test_bonds_broken(self): class ReconstructionGeneratorTests(PymatgenTest): def setUp(self): - latt = Lattice.cubic(3.51) + lattice = Lattice.cubic(3.51) species = ["Ni"] coords = [[0, 0, 0]] - self.Ni = Structure.from_spacegroup("Fm-3m", latt, species, coords) - latt = Lattice.cubic(2.819000) + self.Ni = Structure.from_spacegroup("Fm-3m", lattice, species, coords) + lattice = Lattice.cubic(2.819000) species = ["Fe"] coords = [[0, 0, 0]] - self.Fe = Structure.from_spacegroup("Im-3m", latt, species, coords) + self.Fe = Structure.from_spacegroup("Im-3m", lattice, species, coords) self.Si = Structure.from_spacegroup("Fd-3m", Lattice.cubic(5.430500), ["Si"], [(0, 0, 0.5)]) pmg_core_dir = os.path.dirname(pymatgen.core.__file__) @@ -637,14 +635,12 @@ def test_build_slab(self): assert len(slab) == len(recon_slab) - 2 assert recon_slab.is_symmetric() - # If a slab references another slab, - # make sure it is properly generated + # If a slab references another slab, make sure it is properly generated recon = ReconstructionGenerator(self.Ni, 10, 10, "fcc_111_adatom_ft_1x1") slab = recon.build_slabs()[0] assert slab.is_symmetric - # Test a reconstruction where it works on a specific - # termination (Fd-3m (111)) + # Test a reconstruction where it works on a specific termination (Fd-3m (111)) recon = ReconstructionGenerator(self.Si, 10, 10, "diamond_111_1x2") slab = recon.get_unreconstructed_slabs()[0] recon_slab = recon.build_slabs()[0] @@ -687,7 +683,7 @@ def test_previous_reconstructions(self): el = self.Si[0].species_string slabs = rec.build_slabs() - struct = Structure.from_file(f"{TEST_FILES_DIR}/surface_tests/reconstructions/{el}_{idx}.cif") + struct = Structure.from_file(f"{TEST_FILES_DIR}/surfaces/reconstructions/{el}_{idx}.cif") assert any(len(match.group_structures([struct, slab])) == 1 for slab in slabs) @@ -695,11 +691,11 @@ class MillerIndexFinderTests(PymatgenTest): def setUp(self): self.cscl = Structure.from_spacegroup("Pm-3m", Lattice.cubic(4.2), ["Cs", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]]) self.Fe = Structure.from_spacegroup("Im-3m", Lattice.cubic(2.82), ["Fe"], [[0, 0, 0]]) - mglatt = Lattice.from_parameters(3.2, 3.2, 5.13, 90, 90, 120) - self.Mg = Structure(mglatt, ["Mg", "Mg"], [[1 / 3, 2 / 3, 1 / 4], [2 / 3, 1 / 3, 3 / 4]]) + mg_lattice = Lattice.from_parameters(3.2, 3.2, 5.13, 90, 90, 120) + self.Mg = Structure(mg_lattice, ["Mg", "Mg"], [[1 / 3, 2 / 3, 1 / 4], [2 / 3, 1 / 3, 3 / 4]]) self.lifepo4 = self.get_structure("LiFePO4") - self.tei = Structure.from_file(f"{TEST_FILES_DIR}/surface_tests/icsd_TeI.cif", primitive=False) - self.LiCoO2 = Structure.from_file(f"{TEST_FILES_DIR}/surface_tests/icsd_LiCoO2.cif", primitive=False) + self.tei = Structure.from_file(f"{TEST_FILES_DIR}/surfaces/icsd_TeI.cif", primitive=False) + self.LiCoO2 = Structure.from_file(f"{TEST_FILES_DIR}/surfaces/icsd_LiCoO2.cif", primitive=False) self.p1 = Structure( Lattice.from_parameters(3, 4, 5, 31, 43, 50), @@ -707,7 +703,7 @@ def setUp(self): [[0, 0, 0], [0.1, 0.2, 0.3]], ) self.graphite = self.get_structure("Graphite") - self.trigBi = Structure( + self.trig_bi = Structure( Lattice.from_parameters(3, 3, 10, 90, 90, 120), ["Bi", "Bi", "Bi", "Bi", "Bi", "Bi"], [ @@ -744,14 +740,14 @@ def test_get_symmetrically_distinct_miller_indices(self): assert len(indices) == 12 # Now try a trigonal system. - indices = get_symmetrically_distinct_miller_indices(self.trigBi, 2, return_hkil=True) + indices = get_symmetrically_distinct_miller_indices(self.trig_bi, 2, return_hkil=True) assert len(indices) == 17 assert all(len(hkl) == 4 for hkl in indices) # Test to see if the output with max_index i is a subset of the output with max_index i+1 - for i in range(1, 4): - assert set(get_symmetrically_distinct_miller_indices(self.trigBi, i)) <= set( - get_symmetrically_distinct_miller_indices(self.trigBi, i + 1) + for idx in range(1, 4): + assert set(get_symmetrically_distinct_miller_indices(self.trig_bi, idx)) <= set( + get_symmetrically_distinct_miller_indices(self.trig_bi, idx + 1) ) def test_get_symmetrically_equivalent_miller_indices(self): @@ -834,17 +830,17 @@ def test_generate_all_slabs(self): def test_miller_index_from_sites(self): """Test surface miller index convenience function.""" # test on a cubic system - m = Lattice.cubic(1) + cubic = Lattice.cubic(1) s1 = np.array([0.5, -1.5, 3]) s2 = np.array([0.5, 3.0, -1.5]) s3 = np.array([2.5, 1.5, -4.0]) - assert miller_index_from_sites(m, [s1, s2, s3]) == (2, 1, 1) + assert miller_index_from_sites(cubic, [s1, s2, s3]) == (2, 1, 1) # test casting from matrix to Lattice - m = [[2.319, -4.01662582, 0.0], [2.319, 4.01662582, 0.0], [0.0, 0.0, 7.252]] + matrix = [[2.319, -4.01662582, 0.0], [2.319, 4.01662582, 0.0], [0.0, 0.0, 7.252]] s1 = np.array([2.319, 1.33887527, 6.3455]) s2 = np.array([1.1595, 0.66943764, 4.5325]) s3 = np.array([1.1595, 0.66943764, 0.9065]) - hkl = miller_index_from_sites(m, [s1, s2, s3]) + hkl = miller_index_from_sites(matrix, [s1, s2, s3]) assert hkl == (2, -1, 0) diff --git a/tests/core/test_tensors.py b/tests/core/test_tensors.py index 38ba6322070..232ff0855a8 100644 --- a/tests/core/test_tensors.py +++ b/tests/core/test_tensors.py @@ -543,9 +543,9 @@ def test_get_scaled(self): assert self.non_symm.get_scaled(10) == approx(SquareTensor([[1, 2, 3], [4, 5, 6], [2, 5, 5]])) def test_polar_decomposition(self): - u, p = self.rand_sqtensor.polar_decomposition() - assert_allclose(np.dot(u, p), self.rand_sqtensor) - assert_allclose(np.eye(3), np.dot(u, np.conjugate(np.transpose(u))), atol=1e-9) + u_mat, p_mat = self.rand_sqtensor.polar_decomposition() + assert_allclose(np.dot(u_mat, p_mat), self.rand_sqtensor) + assert_allclose(np.eye(3), np.dot(u_mat, np.conjugate(np.transpose(u_mat))), atol=1e-9) def test_serialization(self): # Test base serialize-deserialize diff --git a/tests/core/test_trajectory.py b/tests/core/test_trajectory.py index 936ae5027eb..2bc67c0908e 100644 --- a/tests/core/test_trajectory.py +++ b/tests/core/test_trajectory.py @@ -29,7 +29,7 @@ def setUp(self): mol = Molecule( species, coord, charge=int(last_mol.charge), spin_multiplicity=int(last_mol.spin_multiplicity) ) - self.molecules.append(mol) + self.molecules += [mol] self.traj_mols = Trajectory( species=species, @@ -45,34 +45,30 @@ def _check_traj_equality(self, traj_1, traj_2): if traj_1.species != traj_2.species: return False - return all(i == j for i, j in zip(self.traj, traj_2)) + return all(frame1 == frame2 for frame1, frame2 in zip(self.traj, traj_2)) def _get_lattice_species_and_coords(self): lattice = ((1, 0, 0), (0, 1, 0), (0, 0, 1)) species = ["Si", "Si"] - coords = np.asarray( - [ - [[0, 0, 0], [0.5, 0.5, 0.5]], - [[0.1, 0.1, 0.1], [0.6, 0.6, 0.6]], - [[0.2, 0.2, 0.2], [0.7, 0.7, 0.7]], - ] - ) + coords = [ + [[0, 0, 0], [0.5, 0.5, 0.5]], + [[0.1, 0.1, 0.1], [0.6, 0.6, 0.6]], + [[0.2, 0.2, 0.2], [0.7, 0.7, 0.7]], + ] return lattice, species, coords def _get_species_and_coords(self): species = ["C", "O"] - coords = np.asarray( - [ - [[1.5709474478, -0.16099953, 0.0], [1.9291378639, -1.2161950538, 0.0]], - [[1.5688628148, -0.1548583957, 0.0], [1.9312224969, -1.2223361881, 0.0]], - [[1.5690858055, -0.1555153055, 0.0], [1.9309995062, -1.2216792783, 0.0]], - ] - ) + coords = [ + [[1.5709474478, -0.16099953, 0.0], [1.9291378639, -1.2161950538, 0.0]], + [[1.5688628148, -0.1548583957, 0.0], [1.9312224969, -1.2223361881, 0.0]], + [[1.5690858055, -0.1555153055, 0.0], [1.9309995062, -1.2216792783, 0.0]], + ] return species, coords, 0, 1 def test_single_index_slice(self): - assert all(self.traj[i] == self.structures[i] for i in range(0, len(self.structures), 19)) - assert all(self.traj_mols[i] == self.molecules[i] for i in range(len(self.molecules))) + assert all(self.traj[idx] == self.structures[idx] for idx in range(0, len(self.structures), 19)) + assert all(self.traj_mols[idx] == self.molecules[idx] for idx in range(len(self.molecules))) def test_slice(self): sliced_traj = self.traj[2:99:3] @@ -87,7 +83,7 @@ def test_slice(self): sliced_traj_from_structs = Trajectory.from_structures(self.structures[:-4:2]) if len(sliced_traj) == len(sliced_traj_from_structs): - assert all(sliced_traj[i] == sliced_traj_from_structs[i] for i in range(len(sliced_traj))) + assert all(sliced_traj[idx] == sliced_traj_from_structs[idx] for idx in range(len(sliced_traj))) else: raise AssertionError @@ -187,7 +183,7 @@ def test_site_properties(self): def test_frame_properties(self): lattice, species, coords = self._get_lattice_species_and_coords() - props = [{"energy_per_atom": e} for e in [-3.0001, -3.0971, -3.0465]] + props = [{"energy_per_atom": ene} for ene in [-3.0001, -3.0971, -3.0465]] traj = Trajectory(lattice=lattice, species=species, coords=coords, frame_properties=props) @@ -200,7 +196,9 @@ def test_frame_properties(self): species, coords, charge, spin = self._get_species_and_coords() - props = [{"SCF_energy_in_the_final_basis_set": e} for e in [-113.3256885788, -113.3260019471, -113.326006415]] + props = [ + {"SCF_energy_in_the_final_basis_set": ene} for ene in [-113.3256885788, -113.3260019471, -113.326006415] + ] traj = Trajectory( species=species, @@ -386,7 +384,7 @@ def test_extend_frame_props(self): pressure_2 = [2, 2.5, 2.5] # energy only properties - props_1 = [{"energy": e} for e in energy_1] + props_1 = [{"energy": ene} for ene in energy_1] traj_1 = Trajectory(lattice=lattice, species=species, coords=coords, frame_properties=props_1) # energy and pressure properties @@ -421,11 +419,11 @@ def test_displacements(self): structures = [Structure.from_file(f"{VASP_IN_DIR}/POSCAR")] displacements = np.zeros((11, *np.shape(structures[-1].frac_coords))) - for i in range(10): + for idx in range(10): displacement = np.random.random_sample(np.shape(structures[-1].frac_coords)) / 20 new_coords = displacement + structures[-1].frac_coords structures.append(Structure(structures[-1].lattice, structures[-1].species, new_coords)) - displacements[i + 1, :, :] = displacement + displacements[idx + 1, :, :] = displacement traj = Trajectory.from_structures(structures, constant_lattice=True) traj.to_displacements() diff --git a/tests/core/test_units.py b/tests/core/test_units.py index e59a4f80c09..c21d7b8e304 100644 --- a/tests/core/test_units.py +++ b/tests/core/test_units.py @@ -9,19 +9,32 @@ Energy, EnergyArray, FloatWithUnit, + Ha_to_eV, Length, LengthArray, Mass, Memory, + Ry_to_eV, Time, TimeArray, Unit, UnitError, + amu_to_kg, + bohr_to_angstrom, + eV_to_Ha, unitized, ) from pymatgen.util.testing import PymatgenTest +def test_unit_conversions(): + assert Ha_to_eV == approx(27.211386245988) + assert eV_to_Ha == 1 / Ha_to_eV + assert Ry_to_eV == approx(Ha_to_eV / 2) + assert bohr_to_angstrom == approx(0.529177210903) + assert amu_to_kg == approx(1.66053906660e-27) + + class TestUnit(PymatgenTest): def test_init(self): u1 = Unit((("m", 1), ("s", -1))) @@ -54,11 +67,11 @@ def test_energy(self): assert a + 1 == 2.1 assert str(a / d) == "1.1 eV Ha^-1" - e = Energy(1, "kJ") - f = e.to("kCal") - assert f == approx(0.2390057361376673) - assert str(e + f) == "2.0 kJ" - assert str(f + e) == "0.4780114722753346 kCal" + e_kj = Energy(1, "kJ") + e_kcal = e_kj.to("kCal") + assert e_kcal == approx(0.2390057361376673) + assert str(e_kj + e_kcal) == "2.0 kJ" + assert str(e_kcal + e_kj) == "0.4780114722753346 kCal" def test_time(self): a = Time(20, "h") @@ -93,48 +106,45 @@ def test_memory(self): def test_unitized(self): @unitized("eV") - def f(): + def func1(): return [1, 2, 3] - assert str(f()[0]) == "1.0 eV" - assert isinstance(f(), list) + assert str(func1()[0]) == "1.0 eV" + assert isinstance(func1(), list) @unitized("eV") - def g(): + def func2(): return 2, 3, 4 - assert str(g()[0]) == "2.0 eV" - assert isinstance(g(), tuple) + assert str(func2()[0]) == "2.0 eV" + assert isinstance(func2(), tuple) @unitized("pm") - def h(): - dct = {} - for i in range(3): - dct[i] = i * 20 - return dct + def func3(): + return {idx: idx * 20 for idx in range(3)} - assert str(h()[1]) == "20.0 pm" - assert isinstance(h(), dict) + assert str(func3()[1]) == "20.0 pm" + assert isinstance(func3(), dict) @unitized("kg") - def i(): + def func4(): return FloatWithUnit(5, "g") - assert i() == FloatWithUnit(0.005, "kg") + assert func4() == FloatWithUnit(0.005, "kg") @unitized("kg") - def j(): + def func5(): return ArrayWithUnit([5, 10], "g") - j_out = j() + j_out = func5() assert j_out.unit == Unit("kg") assert j_out[0] == 0.005 assert j_out[1] == 0.01 def test_compound_operations(self): - g = 10 * Length(1, "m") / (Time(1, "s") ** 2) - e = Mass(1, "kg") * g * Length(1, "m") - assert str(e) == "10.0 N m" + earth_acc = 9.81 * Length(1, "m") / (Time(1, "s") ** 2) + e_pot = Mass(1, "kg") * earth_acc * Length(1, "m") + assert str(e_pot) == "9.81 N m" form_e = FloatWithUnit(10, unit="kJ mol^-1").to("eV atom^-1") assert form_e == approx(0.103642691905) assert str(form_e.unit) == "eV atom^-1" @@ -252,16 +262,16 @@ def test_array_algebra(self): _ = ene_ha + time_s def test_factors(self): - e = EnergyArray([27.21138386, 1], "eV").to("Ha") - assert str(e).endswith("Ha") + e_arr = EnergyArray([27.21138386, 1], "eV").to("Ha") + assert str(e_arr).endswith("Ha") len_arr = LengthArray([1.0], "ang").to("bohr") assert str(len_arr).endswith(" bohr") v = ArrayWithUnit([1, 2, 3], "bohr^3").to("ang^3") assert str(v).endswith(" ang^3") def test_as_base_units(self): - x = ArrayWithUnit([5, 10], "MPa") - assert_array_equal(ArrayWithUnit([5000000, 10000000], "Pa"), x.as_base_units) + pressure_arr = ArrayWithUnit([5, 10], "MPa") + assert_array_equal(ArrayWithUnit([5000000, 10000000], "Pa"), pressure_arr.as_base_units) class TestDataPersistence(PymatgenTest): diff --git a/tests/electronic_structure/test_bandstructure.py b/tests/electronic_structure/test_bandstructure.py index e1b667fa7b2..7020d69877b 100644 --- a/tests/electronic_structure/test_bandstructure.py +++ b/tests/electronic_structure/test_bandstructure.py @@ -2,7 +2,7 @@ import copy import json -import unittest +from unittest import TestCase import numpy as np import pytest @@ -23,7 +23,7 @@ from pymatgen.util.testing import TEST_FILES_DIR, VASP_IN_DIR, VASP_OUT_DIR, PymatgenTest -class TestKpoint(unittest.TestCase): +class TestKpoint(TestCase): def setUp(self): self.lattice = Lattice.cubic(10.0) self.kpoint = Kpoint([0.1, 0.4, -0.5], self.lattice, label="X") diff --git a/tests/electronic_structure/test_boltztrap.py b/tests/electronic_structure/test_boltztrap.py index 5dc6cb2d9f1..b7cadaa49cd 100644 --- a/tests/electronic_structure/test_boltztrap.py +++ b/tests/electronic_structure/test_boltztrap.py @@ -1,8 +1,8 @@ from __future__ import annotations import json -import unittest from shutil import which +from unittest import TestCase import pytest from monty.serialization import loadfn @@ -27,7 +27,7 @@ @pytest.mark.skipif(not x_trans, reason="No x_trans.") -class TestBoltztrapAnalyzer(unittest.TestCase): +class TestBoltztrapAnalyzer(TestCase): @classmethod def setUpClass(cls): cls.bz = BoltztrapAnalyzer.from_files(f"{TEST_FILES_DIR}/boltztrap/transp/") @@ -93,9 +93,8 @@ def test_get_seebeck_eff_mass(self): sbk_mass_avg_mu = self.bz.get_seebeck_eff_mass(output="average", doping_levels=False, temp=300)[3] sbk_mass_avg_dop = self.bz.get_seebeck_eff_mass(output="average", doping_levels=True, temp=300)["n"][2] - for i in range(3): - assert sbk_mass_tens_mu[i] == approx(ref2[i], abs=1e-1) - assert sbk_mass_tens_dop[i] == approx(ref[i], abs=1e-4) + assert sbk_mass_tens_mu == approx(ref2, abs=1e-1) + assert sbk_mass_tens_dop == approx(ref, abs=1e-4) assert sbk_mass_avg_mu == approx(4361.4744008038842, abs=1e-1) assert sbk_mass_avg_dop == approx(1.661553842105382, abs=1e-4) @@ -109,52 +108,40 @@ def test_get_complexity_factor(self): sbk_mass_avg_mu = self.bz.get_complexity_factor(output="average", doping_levels=False, temp=300)[3] sbk_mass_avg_dop = self.bz.get_complexity_factor(output="average", doping_levels=True, temp=300)["n"][2] - for i in range(3): - assert sbk_mass_tens_mu[i] == approx(ref2[i], abs=1e-4) - assert sbk_mass_tens_dop[i] == approx(ref[i], abs=1e-4) + assert sbk_mass_tens_mu == approx(ref2, abs=1e-4) + assert sbk_mass_tens_dop == approx(ref, abs=1e-4) assert sbk_mass_avg_mu == approx(0.00628677029221, abs=1e-4) assert sbk_mass_avg_dop == approx(1.12322832119, abs=1e-4) def test_get_seebeck(self): ref = [-768.99078999999995, -724.43919999999991, -686.84682999999973] - for i in range(3): - assert self.bz.get_seebeck()["n"][800][3][i] == approx(ref[i]) + assert self.bz.get_seebeck()["n"][800][3] == approx(ref) assert self.bz.get_seebeck(output="average")["p"][800][3] == approx(697.608936667) assert self.bz.get_seebeck(output="average", doping_levels=False)[500][520] == approx(1266.7056) - assert self.bz.get_seebeck(output="average", doping_levels=False)[300][65] == approx( - -36.2459389333 - ) # TODO: this was originally "eigs" + assert self.bz.get_seebeck(output="average", doping_levels=False)[300][65] == approx(-36.2459389333) def test_get_conductivity(self): ref = [5.9043185000000022, 17.855599000000002, 26.462935000000002] - for i in range(3): - assert self.bz.get_conductivity()["p"][600][2][i] == approx(ref[i]) + assert self.bz.get_conductivity()["p"][600][2] == approx(ref) assert self.bz.get_conductivity(output="average")["n"][700][1] == approx(1.58736609667) assert self.bz.get_conductivity(output="average", doping_levels=False)[300][457] == approx(2.87163566667) - assert self.bz.get_conductivity( - output="average", - doping_levels=False, - # TODO: this was originally "eigs" - relaxation_time=1e-15, - )[200][63] == approx(16573.0536667) + assert self.bz.get_conductivity(output="average", doping_levels=False, relaxation_time=1e-15)[200][ + 63 + ] == approx(16573.0536667) def test_get_power_factor(self): ref = [6.2736602345523362, 17.900184232304138, 26.158282220458144] - for i in range(3): - assert self.bz.get_power_factor()["p"][200][2][i] == approx(ref[i]) + assert self.bz.get_power_factor()["p"][200][2] == approx(ref) assert self.bz.get_power_factor(output="average")["n"][600][4] == approx(411.230962976) assert self.bz.get_power_factor(output="average", doping_levels=False, relaxation_time=1e-15)[500][ 459 ] == approx(6.59277148467) - assert self.bz.get_power_factor(output="average", doping_levels=False)[800][61] == approx( - 2022.67064134 - ) # TODO: this was originally "eigs" + assert self.bz.get_power_factor(output="average", doping_levels=False)[800][61] == approx(2022.67064134) def test_get_thermal_conductivity(self): ref = [2.7719565628862623e-05, 0.00010048046886793946, 0.00015874549392499391] - for i in range(3): - assert self.bz.get_thermal_conductivity()["p"][300][2][i] == approx(ref[i]) + assert self.bz.get_thermal_conductivity()["p"][300][2] == approx(ref) assert self.bz.get_thermal_conductivity(output="average", relaxation_time=1e-15)["n"][500][0] == approx( 1.74466575612e-07 ) @@ -170,27 +157,23 @@ def test_get_thermal_conductivity(self): def test_get_zt(self): ref = [0.097408810215, 0.29335112354, 0.614673998089] - for i in range(3): - assert self.bz.get_zt()["n"][800][4][i] == approx(ref[i]) + assert self.bz.get_zt()["n"][800][4] == approx(ref) assert self.bz.get_zt(output="average", k_l=0.5)["p"][700][2] == approx(0.0170001879916) assert self.bz.get_zt(output="average", doping_levels=False, relaxation_time=1e-15)[300][240] == approx( 0.0041923533238348342 ) - eigs = self.bz.get_zt(output="eigs", doping_levels=False)[700][65] - ref_eigs = [0.082420053399668847, 0.29408035502671648, 0.40822061215079392] - for idx, val in enumerate(ref_eigs): - assert eigs[idx] == approx(val, abs=1e-5) + eig_vals = self.bz.get_zt(output="eigs", doping_levels=False)[700][65] + ref_eig_vals = [0.082420053399668847, 0.29408035502671648, 0.40822061215079392] + assert eig_vals == approx(ref_eig_vals, abs=1e-5) def test_get_average_eff_mass(self): ref = [0.76045816788363574, 0.96181142990667101, 2.9428428773308628] - for i in range(3): - assert self.bz.get_average_eff_mass()["p"][300][2][i] == approx(ref[i]) + assert self.bz.get_average_eff_mass()["p"][300][2] == approx(ref) ref = [1.1295783824744523, 1.3898454041924351, 5.2459984671977935] ref2 = [6.6648842712692078, 31.492540105738343, 37.986369302138954] - for i in range(3): - assert self.bz.get_average_eff_mass()["n"][600][1][i] == approx(ref[i]) - assert self.bz.get_average_eff_mass(doping_levels=False)[300][200][i] == approx(ref2[i]) + assert self.bz.get_average_eff_mass()["n"][600][1] == approx(ref) + assert self.bz.get_average_eff_mass(doping_levels=False)[300][200] == approx(ref2) ref = [ [9.61811430e-01, -8.25159596e-19, -4.70319444e-19], [-8.25159596e-19, 2.94284288e00, 3.00368916e-18], @@ -202,12 +185,8 @@ def test_get_average_eff_mass(self): [-1.36897140e-17, 8.74169648e-17, 2.21151980e01], ] - for i in range(3): - for j in range(3): - assert self.bz.get_average_eff_mass(output="tensor")["p"][300][2][i][j] == approx(ref[i][j], abs=1e-4) - assert self.bz.get_average_eff_mass(output="tensor", doping_levels=False)[300][500][i][j] == approx( - ref2[i][j], 4 - ) + assert self.bz.get_average_eff_mass(output="tensor")["p"][300][2] == approx(ref, abs=1e-4) + assert self.bz.get_average_eff_mass(output="tensor", doping_levels=False)[300][500] == approx(ref2, 4) assert self.bz.get_average_eff_mass(output="average")["n"][300][2] == approx(1.53769093989, abs=1e-4) def test_get_carrier_concentration(self): diff --git a/tests/electronic_structure/test_boltztrap2.py b/tests/electronic_structure/test_boltztrap2.py index 337df8d7b21..e50f3958b91 100644 --- a/tests/electronic_structure/test_boltztrap2.py +++ b/tests/electronic_structure/test_boltztrap2.py @@ -1,6 +1,6 @@ from __future__ import annotations -import unittest +from unittest import TestCase import numpy as np import pytest @@ -41,7 +41,7 @@ @pytest.mark.skipif(not BOLTZTRAP2_PRESENT, reason="No boltztrap2, skipping tests...") -class TestVasprunBSLoader(unittest.TestCase): +class TestVasprunBSLoader(TestCase): def setUp(self): self.loader = VasprunBSLoader(vasp_run) assert self.loader is not None @@ -81,7 +81,7 @@ def test_get_volume(self): @pytest.mark.skipif(not BOLTZTRAP2_PRESENT, reason="No boltztrap2, skipping tests...") -class TestBandstructureLoader(unittest.TestCase): +class TestBandstructureLoader(TestCase): def setUp(self): self.loader = BandstructureLoader(bs, vasp_run.structures[-1]) assert self.loader is not None @@ -110,7 +110,7 @@ def test_get_volume(self): @pytest.mark.skipif(not BOLTZTRAP2_PRESENT, reason="No boltztrap2, skipping tests...") -class TestVasprunLoader(unittest.TestCase): +class TestVasprunLoader(TestCase): def setUp(self): self.loader = VasprunLoader(vasp_run) assert self.loader.proj.shape == (120, 20, 2, 9) @@ -130,7 +130,7 @@ def test_from_file(self): @pytest.mark.skipif(not BOLTZTRAP2_PRESENT, reason="No boltztrap2, skipping tests...") -class TestBztInterpolator(unittest.TestCase): +class TestBztInterpolator(TestCase): def setUp(self): self.loader = VasprunBSLoader(vasp_run) self.bztInterp = BztInterpolator(self.loader, lpfac=2) @@ -207,7 +207,7 @@ def test_tot_proj_dos(self): @pytest.mark.skipif(not BOLTZTRAP2_PRESENT, reason="No boltztrap2, skipping tests...") -class TestBztTransportProperties(unittest.TestCase): +class TestBztTransportProperties(TestCase): def setUp(self): loader = VasprunBSLoader(vasp_run) bztInterp = BztInterpolator(loader, lpfac=2) @@ -307,7 +307,7 @@ def test_compute_properties_doping(self): @pytest.mark.skipif(not BOLTZTRAP2_PRESENT, reason="No boltztrap2, skipping tests...") -class TestBztPlotter(unittest.TestCase): +class TestBztPlotter(TestCase): def test_plot(self): loader = VasprunBSLoader(vasp_run) bztInterp = BztInterpolator(loader, lpfac=2) diff --git a/tests/electronic_structure/test_cohp.py b/tests/electronic_structure/test_cohp.py index 56a525981fd..87154eb33ca 100644 --- a/tests/electronic_structure/test_cohp.py +++ b/tests/electronic_structure/test_cohp.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -import unittest +from unittest import TestCase import pytest from numpy.testing import assert_allclose, assert_array_equal @@ -20,7 +20,7 @@ TEST_DIR = f"{TEST_FILES_DIR}/cohp" -class TestCohp(unittest.TestCase): +class TestCohp(TestCase): def setUp(self): with open(f"{TEST_DIR}/cohp.json") as file: self.cohp = Cohp.from_dict(json.load(file)) @@ -84,7 +84,7 @@ def test_antibnd_states_below_efermi(self): assert self.cohp.has_antibnd_states_below_efermi(spin=Spin.up, limit=0.5) == {Spin.up: False} -class TestIcohpValue(unittest.TestCase): +class TestIcohpValue(TestCase): def setUp(self): # without spin polarization label = "1" @@ -162,7 +162,7 @@ def test_str(self): assert str(self.icohpvalue_sp) == expected -class TestCombinedIcohp(unittest.TestCase): +class TestCombinedIcohp(TestCase): def setUp(self): # without spin polarization: are_coops = False @@ -1245,7 +1245,7 @@ def test_orbital_resolved_cohp_summed_spin_channels(self): ).are_coops -class TestMethod(unittest.TestCase): +class TestMethod(TestCase): def setUp(self): filepath = f"{TEST_DIR}/COHPCAR.lobster.gz" structure = f"{TEST_DIR}/POSCAR" diff --git a/tests/electronic_structure/test_core.py b/tests/electronic_structure/test_core.py index c842431d813..da7ae2257ba 100644 --- a/tests/electronic_structure/test_core.py +++ b/tests/electronic_structure/test_core.py @@ -1,7 +1,5 @@ from __future__ import annotations -import unittest - import numpy as np import pytest from numpy.testing import assert_allclose @@ -10,7 +8,7 @@ from pymatgen.electronic_structure.core import Magmom, Orbital, Spin -class TestSpin(unittest.TestCase): +class TestSpin: def test_init(self): assert int(Spin.up) == 1 assert int(Spin.down) == -1 @@ -25,7 +23,7 @@ def test_cached(self): assert id(Spin(1)) == id(Spin.up) -class TestOrbital(unittest.TestCase): +class TestOrbital: def test_init(self): for orb in Orbital: assert Orbital(orb.value) == orb @@ -36,7 +34,7 @@ def test_cached(self): assert id(Orbital(0)) == id(Orbital.s) -class TestMagmom(unittest.TestCase): +class TestMagmom: def test_init(self): # backwards compatibility for scalar-like magmoms magmom = Magmom(2.0) diff --git a/tests/electronic_structure/test_dos.py b/tests/electronic_structure/test_dos.py index 0e4c5a34472..474b91b6000 100644 --- a/tests/electronic_structure/test_dos.py +++ b/tests/electronic_structure/test_dos.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -import unittest +from unittest import TestCase import numpy as np import pytest @@ -16,7 +16,7 @@ from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -class TestDos(unittest.TestCase): +class TestDos(TestCase): def setUp(self): with open(f"{TEST_FILES_DIR}/complete_dos.json") as file: self.dos = CompleteDos.from_dict(json.load(file)) @@ -51,7 +51,7 @@ def test_as_dict(self): assert not isinstance(dos_dict["densities"]["1"][0], np.float64) -class TestFermiDos(unittest.TestCase): +class TestFermiDos(TestCase): def setUp(self): with open(f"{TEST_FILES_DIR}/complete_dos.json") as file: self.dos = CompleteDos.from_dict(json.load(file)) @@ -61,7 +61,7 @@ def test_doping_fermi(self): T = 300 fermi0 = self.dos.efermi fermi_range = [fermi0 - 0.5, fermi0, fermi0 + 2.0, fermi0 + 2.2] - dopings = [self.dos.get_doping(fermi_level=f, temperature=T) for f in fermi_range] + dopings = [self.dos.get_doping(fermi_level=fermi_lvl, temperature=T) for fermi_lvl in fermi_range] ref_dopings = [3.48077e21, 1.9235e18, -2.6909e16, -4.8723e19] for i, c_ref in enumerate(ref_dopings): assert abs(dopings[i] / c_ref - 1.0) <= 0.01 @@ -98,7 +98,7 @@ def test_as_dict(self): assert not isinstance(dos_dict["densities"]["1"][0], np.float64) -class TestCompleteDos(unittest.TestCase): +class TestCompleteDos(TestCase): def setUp(self): with open(f"{TEST_FILES_DIR}/complete_dos.json") as file: self.dos = CompleteDos.from_dict(json.load(file)) @@ -318,14 +318,14 @@ def test_get_gap(self): assert_allclose(self.dos.get_cbm_vbm(spin=Spin.down), (4.645, 1.8140000000000001)) -class TestSpinPolarization(unittest.TestCase): +class TestSpinPolarization(TestCase): def test_spin_polarization(self): dos_path = f"{TEST_FILES_DIR}/dos_spin_polarization_mp-865805.json" dos = loadfn(dos_path) assert dos.spin_polarization == approx(0.6460514663341762) -class TestLobsterCompleteDos(unittest.TestCase): +class TestLobsterCompleteDos(TestCase): def setUp(self): with open(f"{TEST_FILES_DIR}/LobsterCompleteDos_spin.json") as file: data_spin = json.load(file) diff --git a/tests/electronic_structure/test_plotter.py b/tests/electronic_structure/test_plotter.py index 92a8dc3e8aa..9ecfe62cb25 100644 --- a/tests/electronic_structure/test_plotter.py +++ b/tests/electronic_structure/test_plotter.py @@ -2,8 +2,8 @@ import json import os -import unittest from shutil import which +from unittest import TestCase import matplotlib.pyplot as plt import numpy as np @@ -138,7 +138,7 @@ def test_bs_plot_data(self): assert ( len(self.plotter.bs_plot_data()["distances"][0]) == 16 ), "wrong number of distances in the first sequence of branches" - assert sum(len(e) for e in self.plotter.bs_plot_data()["distances"]) == 160, "wrong number of distances" + assert sum(len(dist) for dist in self.plotter.bs_plot_data()["distances"]) == 160, "wrong number of distances" length = len(self.plotter.bs_plot_data(split_branches=False)["distances"][0]) assert length == 144, "wrong number of distances in the first sequence of branches" @@ -179,7 +179,7 @@ def test_get_plot(self): plt.close("all") -class TestBSPlotterProjected(unittest.TestCase): +class TestBSPlotterProjected(TestCase): def setUp(self): with open(f"{TEST_FILES_DIR}/Cu2O_361_bandstructure.json") as file: dct = json.load(file) @@ -208,7 +208,7 @@ def test_methods(self): self.plotter_PbTe = BSPlotterProjected(self.bs_PbTe) -class TestBSDOSPlotter(unittest.TestCase): +class TestBSDOSPlotter: # Minimal baseline testing for get_plot. not a true test. Just checks that # it can actually execute. def test_methods(self): @@ -229,25 +229,25 @@ def test_methods(self): data_structure = [[[[0 for _ in range(12)] for _ in range(9)] for _ in range(70)] for _ in range(90)] band_struct_dict["projections"]["1"] = data_structure dct = band_struct_dict["projections"]["1"] - for i in range(len(dct)): - for j in range(len(dct[i])): - for k in range(len(dct[i][j])): - for m in range(len(dct[i][j][k])): - dct[i][j][k][m] = 0 + for ii in range(len(dct)): + for jj in range(len(dct[ii])): + for kk in range(len(dct[ii][jj])): + for ll in range(len(dct[ii][jj][kk])): + dct[ii][jj][kk][ll] = 0 # d[i][j][k][m] = np.random.rand() # generate random number for two atoms a = np.random.randint(0, 7) b = np.random.randint(0, 7) # c = np.random.randint(0,7) - dct[i][j][k][a] = np.random.rand() - dct[i][j][k][b] = np.random.rand() + dct[ii][jj][kk][a] = np.random.rand() + dct[ii][jj][kk][b] = np.random.rand() # d[i][j][k][c] = np.random.rand() band_struct = BandStructureSymmLine.from_dict(band_struct_dict) ax = plotter.get_plot(band_struct) assert isinstance(ax, plt.Axes) -class TestPlotBZ(unittest.TestCase): +class TestPlotBZ(TestCase): def setUp(self): self.rec_latt = Structure.from_file(f"{TEST_FILES_DIR}/cssr/Si.cssr").lattice.reciprocal_lattice self.kpath = [[[0.0, 0.0, 0.0], [0.5, 0.0, 0.5], [0.5, 0.25, 0.75], [0.375, 0.375, 0.75]]] @@ -292,7 +292,7 @@ def test_fold_point(self): @pytest.mark.skipif(not which("x_trans"), reason="No x_trans executable found") -class TestBoltztrapPlotter(unittest.TestCase): +class TestBoltztrapPlotter(TestCase): def setUp(self): bz = BoltztrapAnalyzer.from_files(f"{TEST_FILES_DIR}/boltztrap/transp/") self.plotter = BoltztrapPlotter(bz) diff --git a/tests/entries/test_compatibility.py b/tests/entries/test_compatibility.py index cd3f525fb00..a6a779502b0 100644 --- a/tests/entries/test_compatibility.py +++ b/tests/entries/test_compatibility.py @@ -3,10 +3,11 @@ import copy import json import os -import unittest from collections import defaultdict from math import sqrt from pathlib import Path +from typing import TYPE_CHECKING +from unittest import TestCase import pytest from monty.json import MontyDecoder @@ -18,6 +19,8 @@ from pymatgen.core.lattice import Lattice from pymatgen.core.structure import Structure from pymatgen.entries.compatibility import ( + MP2020_COMPAT_CONFIG, + MP_COMPAT_CONFIG, MU_H2O, AqueousCorrection, Compatibility, @@ -27,12 +30,16 @@ MaterialsProjectCompatibility, MITAqueousCompatibility, MITCompatibility, + needs_u_correction, ) from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry, ConstantEnergyAdjustment from pymatgen.util.testing import TEST_FILES_DIR +if TYPE_CHECKING: + from pymatgen.util.typing import CompositionLike -class TestCorrectionSpecificity(unittest.TestCase): + +class TestCorrectionSpecificity(TestCase): """Make sure corrections are only applied to GGA or GGA+U entries.""" def setUp(self): @@ -174,7 +181,7 @@ def test_overlapping_adjustments(): assert len(processed) == 0 -class TestMaterialsProjectCompatibility(unittest.TestCase): +class TestMaterialsProjectCompatibility(TestCase): def setUp(self): self.entry1 = ComputedEntry( "Fe2O3", @@ -481,7 +488,7 @@ def test_msonable(self): assert isinstance(temp_compat, MaterialsProjectCompatibility) -class TestMaterialsProjectCompatibility2020(unittest.TestCase): +class TestMaterialsProjectCompatibility2020(TestCase): def setUp(self): self.entry1 = ComputedEntry( "Fe2O3", @@ -1042,7 +1049,7 @@ def test_process_entry_with_oxidation_state(self): assert processed_entry.energy == approx(-58.97 + -6.084) -class TestMITCompatibility(unittest.TestCase): +class TestMITCompatibility(TestCase): def setUp(self): self.compat = MITCompatibility(check_potcar_hash=True) self.gga_compat = MITCompatibility("GGA", check_potcar_hash=True) @@ -1264,35 +1271,27 @@ def test_revert_to_symbols(self): with pytest.raises(ValueError, match="Cannot check hash without potcar_spec field"): self.compat.process_entry(entry) - def test_potcar_doenst_match_structure(self): + def test_potcar_not_match_structure(self): compat = MITCompatibility() - entry = ComputedEntry( - "Li2O3", - -1, - correction=0.0, - parameters={ - "is_hubbard": True, - "hubbards": {"Fe": 4.0, "O": 0}, - "run_type": "GGA+U", - "potcar_symbols": ["PAW_PBE Fe_pv 06Sep2000", "PAW_PBE O 08Apr2002"], - }, - ) + params = { + "is_hubbard": True, + "hubbards": {"Fe": 4.0, "O": 0}, + "run_type": "GGA+U", + "potcar_symbols": ["PAW_PBE Fe_pv 06Sep2000", "PAW_PBE O 08Apr2002"], + } + entry = ComputedEntry("Li2O3", -1, correction=0.0, parameters=params) assert compat.process_entry(entry) is None def test_potcar_spec_is_none(self): compat = MITCompatibility(check_potcar_hash=True) - entry = ComputedEntry( - "Li2O3", - -1, - correction=0.0, - parameters={ - "is_hubbard": True, - "hubbards": {"Fe": 4.0, "O": 0}, - "run_type": "GGA+U", - "potcar_spec": [None, None], - }, - ) + params = { + "is_hubbard": True, + "hubbards": {"Fe": 4.0, "O": 0}, + "run_type": "GGA+U", + "potcar_spec": [None, None], + } + entry = ComputedEntry("Li2O3", -1, correction=0.0, parameters=params) assert compat.process_entry(entry) is None @@ -1322,7 +1321,7 @@ def test_msonable(self): assert isinstance(temp_compat, MITCompatibility) -class TestOxideTypeCorrection(unittest.TestCase): +class TestOxideTypeCorrection(TestCase): def setUp(self): self.compat = MITCompatibility(check_potcar_hash=True) @@ -1348,7 +1347,7 @@ def test_no_struct_compat(self): def test_process_entry_superoxide(self): el_li = Element("Li") el_o = Element("O") - latt = Lattice([[3.985034, 0.0, 0.0], [0.0, 4.881506, 0.0], [0.0, 0.0, 2.959824]]) + lattice = Lattice([[3.985034, 0.0, 0.0], [0.0, 4.881506, 0.0], [0.0, 0.0, 2.959824]]) elems = [el_li, el_li, el_o, el_o, el_o, el_o] coords = [ [0.5, 0.5, 0.5], @@ -1358,7 +1357,7 @@ def test_process_entry_superoxide(self): [0.132568, 0.41491, 0.0], [0.867432, 0.58509, 0.0], ] - struct = Structure(latt, elems, coords) + struct = Structure(lattice, elems, coords) lio2_entry = ComputedStructureEntry( struct, -3, @@ -1377,7 +1376,7 @@ def test_process_entry_superoxide(self): assert lio2_entry_corrected.energy == approx(-3 - 0.13893 * 4) def test_process_entry_peroxide(self): - latt = Lattice.from_parameters(3.159597, 3.159572, 7.685205, 89.999884, 89.999674, 60.000510) + lattice = Lattice.from_parameters(3.159597, 3.159572, 7.685205, 89.999884, 89.999674, 60.000510) el_li = Element("Li") el_o = Element("O") elems = [el_li, el_li, el_li, el_li, el_o, el_o, el_o, el_o] @@ -1391,7 +1390,7 @@ def test_process_entry_peroxide(self): [0.666666, 0.666686, 0.350813], [0.666665, 0.666684, 0.149189], ] - struct = Structure(latt, elems, coords) + struct = Structure(lattice, elems, coords) li2o2_entry = ComputedStructureEntry( struct, -3, @@ -1413,14 +1412,14 @@ def test_process_entry_ozonide(self): el_li = Element("Li") el_o = Element("O") elems = [el_li, el_o, el_o, el_o] - latt = Lattice.from_parameters(3.999911, 3.999911, 3.999911, 133.847504, 102.228244, 95.477342) + lattice = Lattice.from_parameters(3.999911, 3.999911, 3.999911, 133.847504, 102.228244, 95.477342) coords = [ [0.513004, 0.513004, 1.000000], [0.017616, 0.017616, 0.000000], [0.649993, 0.874790, 0.775203], [0.099587, 0.874790, 0.224797], ] - struct = Structure(latt, elems, coords) + struct = Structure(lattice, elems, coords) lio3_entry = ComputedStructureEntry( struct, -3, @@ -1442,9 +1441,9 @@ def test_process_entry_oxide(self): el_li = Element("Li") el_o = Element("O") elems = [el_li, el_li, el_o] - latt = Lattice.from_parameters(3.278, 3.278, 3.278, 60, 60, 60) + lattice = Lattice.from_parameters(3.278, 3.278, 3.278, 60, 60, 60) coords = [[0.25, 0.25, 0.25], [0.75, 0.75, 0.75], [0.0, 0.0, 0.0]] - struct = Structure(latt, elems, coords) + struct = Structure(lattice, elems, coords) li2o_entry = ComputedStructureEntry( struct, -3, @@ -1463,7 +1462,7 @@ def test_process_entry_oxide(self): assert li2o_entry_corrected.energy == approx(-3.0 - 0.66975) -class TestSulfideTypeCorrection2020(unittest.TestCase): +class TestSulfideTypeCorrection2020(TestCase): def setUp(self): self.compat = MaterialsProject2020Compatibility(check_potcar_hash=False) @@ -1623,7 +1622,7 @@ def test_struct_no_struct(self): assert struct_corrected.correction == approx(nostruct_corrected.correction) -class TestOxideTypeCorrectionNoPeroxideCorr(unittest.TestCase): +class TestOxideTypeCorrectionNoPeroxideCorr(TestCase): def setUp(self): self.compat = MITCompatibility(correct_peroxide=False) @@ -1631,9 +1630,9 @@ def test_oxide_energy_corr(self): el_li = Element("Li") el_o = Element("O") elems = [el_li, el_li, el_o] - latt = Lattice.from_parameters(3.278, 3.278, 3.278, 60, 60, 60) + lattice = Lattice.from_parameters(3.278, 3.278, 3.278, 60, 60, 60) coords = [[0.25, 0.25, 0.25], [0.75, 0.75, 0.75], [0.0, 0.0, 0.0]] - struct = Structure(latt, elems, coords) + struct = Structure(lattice, elems, coords) li2o_entry = ComputedStructureEntry( struct, -3, @@ -1652,7 +1651,7 @@ def test_oxide_energy_corr(self): assert li2o_entry_corrected.energy == approx(-3.0 - 0.66975) def test_peroxide_energy_corr(self): - latt = Lattice.from_parameters(3.159597, 3.159572, 7.685205, 89.999884, 89.999674, 60.000510) + lattice = Lattice.from_parameters(3.159597, 3.159572, 7.685205, 89.999884, 89.999674, 60.000510) el_li = Element("Li") el_o = Element("O") elems = [el_li, el_li, el_li, el_li, el_o, el_o, el_o, el_o] @@ -1666,7 +1665,7 @@ def test_peroxide_energy_corr(self): [0.666666, 0.666686, 0.350813], [0.666665, 0.666684, 0.149189], ] - struct = Structure(latt, elems, coords) + struct = Structure(lattice, elems, coords) cse_params = { "is_hubbard": False, "hubbards": None, @@ -1685,14 +1684,14 @@ def test_ozonide(self): el_li = Element("Li") el_o = Element("O") elems = [el_li, el_o, el_o, el_o] - latt = Lattice.from_parameters(3.999911, 3.999911, 3.999911, 133.847504, 102.228244, 95.477342) + lattice = Lattice.from_parameters(3.999911, 3.999911, 3.999911, 133.847504, 102.228244, 95.477342) coords = [ [0.513004, 0.513004, 1.000000], [0.017616, 0.017616, 0.000000], [0.649993, 0.874790, 0.775203], [0.099587, 0.874790, 0.224797], ] - struct = Structure(latt, elems, coords) + struct = Structure(lattice, elems, coords) lio3_entry = ComputedStructureEntry( struct, -3, @@ -1854,7 +1853,7 @@ def test_processing_entries_inplace(self): assert all(e.correction == e_copy.correction for e, e_copy in zip(entries, entries_copy)) -class TestAqueousCorrection(unittest.TestCase): +class TestAqueousCorrection(TestCase): def setUp(self): module_dir = os.path.dirname(os.path.abspath(pymatgen.entries.__file__)) fp = f"{module_dir}/MITCompatibility.yaml" @@ -1882,7 +1881,7 @@ def test_compound_energy(self): assert entry.energy == approx(-24.344373) -class TestMITAqueousCompatibility(unittest.TestCase): +class TestMITAqueousCompatibility(TestCase): def setUp(self): self.compat = MITCompatibility(check_potcar_hash=True) self.aqcompat = MITAqueousCompatibility(check_potcar_hash=True) @@ -1894,7 +1893,7 @@ def test_aqueous_compat(self): el_li = Element("Li") el_o = Element("O") el_h = Element("H") - latt = Lattice.from_parameters(3.565276, 3.565276, 4.384277, 90.000000, 90.000000, 90.000000) + lattice = Lattice.from_parameters(3.565276, 3.565276, 4.384277, 90.000000, 90.000000, 90.000000) elems = [el_h, el_h, el_li, el_li, el_o, el_o] coords = [ [0.000000, 0.500000, 0.413969], @@ -1904,7 +1903,7 @@ def test_aqueous_compat(self): [0.000000, 0.500000, 0.192672], [0.500000, 0.000000, 0.807328], ] - struct = Structure(latt, elems, coords) + struct = Structure(lattice, elems, coords) lioh_entry = ComputedStructureEntry( struct, -3, @@ -1924,12 +1923,12 @@ def test_aqueous_compat(self): lioh_entry_aqcompat = self.aqcompat.process_entry(lioh_entry) assert lioh_entry_compat_aqcorr.energy == approx(lioh_entry_aqcompat.energy) - def test_potcar_doenst_match_structure(self): + def test_potcar_not_match_structure(self): compat = MITCompatibility() el_li = Element("Li") el_o = Element("O") el_h = Element("H") - latt = Lattice.from_parameters(3.565276, 3.565276, 4.384277, 90.000000, 90.000000, 90.000000) + lattice = Lattice.from_parameters(3.565276, 3.565276, 4.384277, 90.000000, 90.000000, 90.000000) elems = [el_h, el_h, el_li, el_li, el_o, el_o] coords = [ [0.000000, 0.500000, 0.413969], @@ -1939,7 +1938,7 @@ def test_potcar_doenst_match_structure(self): [0.000000, 0.500000, 0.192672], [0.500000, 0.000000, 0.807328], ] - struct = Structure(latt, elems, coords) + struct = Structure(lattice, elems, coords) lioh_entry = ComputedStructureEntry( struct, @@ -1975,7 +1974,7 @@ def test_dont_error_on_weird_elements(self): assert self.compat.process_entry(entry) is None -class TestCorrectionErrors2020Compatibility(unittest.TestCase): +class TestCorrectionErrors2020Compatibility(TestCase): def setUp(self): self.compat = MaterialsProject2020Compatibility() @@ -2042,3 +2041,26 @@ def test_errors(self): ): corrected_entry = self.compat.process_entry(entry) assert corrected_entry.correction_uncertainty == approx(expected) + + +@pytest.mark.parametrize( + "u_config", + [MP2020_COMPAT_CONFIG["Corrections"]["GGAUMixingCorrections"], MP_COMPAT_CONFIG["Advanced"]["UCorrections"]], +) +@pytest.mark.parametrize( + ("comp", "expected"), + [ + ("Fe2O3", {"Fe", "O"}), + ("Fe3O4", {"Fe", "O"}), + ("FeS", set()), + ("FeF3", {"Fe", "F"}), + ("LiH", set()), + ("H", set()), + (Composition("MnO"), {"Mn", "O"}), + (Composition("MnO2"), {"Mn", "O"}), + (Composition("LiFePO4"), {"Fe", "O"}), + (Composition("LiFePS4"), set()), + ], +) +def test_needs_u_correction(comp: CompositionLike, expected: set[str], u_config: dict): + assert needs_u_correction(comp, u_config=u_config) == expected diff --git a/tests/entries/test_computed_entries.py b/tests/entries/test_computed_entries.py index d05aef967ff..890daa91776 100644 --- a/tests/entries/test_computed_entries.py +++ b/tests/entries/test_computed_entries.py @@ -2,8 +2,8 @@ import copy import json -import unittest from collections import defaultdict +from unittest import TestCase import pytest from monty.json import MontyDecoder @@ -89,7 +89,7 @@ def test_temp_energy_adjustment(): assert str(ea_dct) == str(ea2.as_dict()) -class TestComputedEntry(unittest.TestCase): +class TestComputedEntry(TestCase): def setUp(self): self.entry = ComputedEntry( vasp_run.final_structure.composition, @@ -238,7 +238,7 @@ def test_copy(self): assert str(entry) == str(copy) -class TestComputedStructureEntry(unittest.TestCase): +class TestComputedStructureEntry(TestCase): def setUp(self): self.entry = ComputedStructureEntry(vasp_run.final_structure, vasp_run.final_energy, parameters=vasp_run.incar) @@ -435,7 +435,7 @@ def test_eq(self): assert copy3 != copy1 -class TestGibbsComputedStructureEntry(unittest.TestCase): +class TestGibbsComputedStructureEntry(TestCase): def setUp(self): self.temps = [300, 600, 900, 1200, 1500, 1800] self.struct = vasp_run.final_structure @@ -499,7 +499,7 @@ def test_as_from_dict(self): assert test_entry == entry assert entry.energy == approx(test_entry.energy) - def test_str(self): + def test_repr(self): assert str(self.entries_with_temps[300]).startswith( "GibbsComputedStructureEntry test - Li1 Fe4 P4 O16\nGibbs Free Energy (Formation) = -56.2127" ) diff --git a/tests/entries/test_correction_calculator.py b/tests/entries/test_correction_calculator.py index ec88c15143b..b74685a1d86 100644 --- a/tests/entries/test_correction_calculator.py +++ b/tests/entries/test_correction_calculator.py @@ -1,6 +1,6 @@ from __future__ import annotations -import unittest +from unittest import TestCase import pytest @@ -10,7 +10,7 @@ TEST_DIR = f"{TEST_FILES_DIR}/correction_calculator" -class TestCorrectionCalculator(unittest.TestCase): +class TestCorrectionCalculator(TestCase): def setUp(self): self.exclude_polyanions = ["SO4", "CO3", "NO3", "OCl3", "SiO4", "SeO3", "TiO3", "TiO4"] diff --git a/tests/entries/test_entry_tools.py b/tests/entries/test_entry_tools.py index 50429d977ed..a55e0fb6ba8 100644 --- a/tests/entries/test_entry_tools.py +++ b/tests/entries/test_entry_tools.py @@ -34,8 +34,8 @@ def test_group_entries_by_composition(self): # Make sure no entries are left behind assert sum(len(g) for g in groups) == len(entries) # test sorting by energy - for g in groups: - assert g == sorted(g, key=lambda e: e.energy_per_atom) + for group in groups: + assert group == sorted(group, key=lambda e: e.energy_per_atom) class TestEntrySet(PymatgenTest): @@ -48,8 +48,8 @@ def test_chemsys(self): def test_get_subset(self): entries = self.entry_set.get_subset_in_chemsys(["Li", "O"]) - for e in entries: - assert {Element.Li, Element.O}.issuperset(e.composition) + for ent in entries: + assert {Element.Li, Element.O}.issuperset(ent.composition) with pytest.raises(ValueError) as exc: # noqa: PT011 self.entry_set.get_subset_in_chemsys(["Fe", "F"]) assert "['F', 'Fe'] is not a subset of ['Fe', 'Li', 'O', 'P'], extra: {'F'}" in str(exc.value) @@ -63,3 +63,14 @@ def test_as_dict(self): dumpfn(self.entry_set, f"{self.tmp_path}/temp_entry_set.json") entry_set = loadfn(f"{self.tmp_path}/temp_entry_set.json") assert len(entry_set) == len(self.entry_set) + + def test_ground_states(self): + ground_states = self.entry_set.ground_states + assert len(ground_states) < len(self.entry_set) + + # Check if ground states have the lowest energy per atom for each composition + for gs in ground_states: + same_comp_entries = [ + ent for ent in self.entry_set if ent.composition.reduced_formula == gs.composition.reduced_formula + ] + assert gs.energy_per_atom <= min(entry.energy_per_atom for entry in same_comp_entries) diff --git a/tests/entries/test_exp_entries.py b/tests/entries/test_exp_entries.py index 052549c7899..65fdeb569e5 100644 --- a/tests/entries/test_exp_entries.py +++ b/tests/entries/test_exp_entries.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -import unittest +from unittest import TestCase from monty.json import MontyDecoder from pytest import approx @@ -10,7 +10,7 @@ from pymatgen.util.testing import TEST_FILES_DIR -class TestExpEntry(unittest.TestCase): +class TestExpEntry(TestCase): def setUp(self): with open(f"{TEST_FILES_DIR}/Fe2O3_exp.json") as file: thermo_data = json.load(file, cls=MontyDecoder) diff --git a/tests/entries/test_mixing_scheme.py b/tests/entries/test_mixing_scheme.py index 1c237a3dbd9..c98e1ef2e46 100644 --- a/tests/entries/test_mixing_scheme.py +++ b/tests/entries/test_mixing_scheme.py @@ -420,7 +420,7 @@ def ms_gga_1_scan(ms_complete): ground state of SnBr2 (r2scan-4). """ gga_entries = ms_complete.gga_entries - scan_entries = [e for e in ms_complete.scan_entries if e.entry_id == "r2scan-4"] + scan_entries = [entry for entry in ms_complete.scan_entries if entry.entry_id == "r2scan-4"] row_list = [ ["Br", 64, 4, True, "gga-3", None, "GGA", None, 0.0, np.nan, 0.0, np.nan], @@ -472,7 +472,7 @@ def ms_gga_2_scan_same(ms_complete): ground state and one unstable polymorph of SnBr2 (r2scan-4 and r2scan-6). """ gga_entries = ms_complete.gga_entries - scan_entries = [e for e in ms_complete.scan_entries if e.entry_id in ["r2scan-4", "r2scan-6"]] + scan_entries = [entry for entry in ms_complete.scan_entries if entry.entry_id in ["r2scan-4", "r2scan-6"]] row_list = [ ["Br", 64, 4, True, "gga-3", None, "GGA", None, 0.0, np.nan, 0.0, np.nan], @@ -496,7 +496,7 @@ def ms_gga_2_scan_diff_match(ms_complete): r2scan-4 and r2scan-7. """ gga_entries = ms_complete.gga_entries - scan_entries = [e for e in ms_complete.scan_entries if e.entry_id in ["r2scan-4", "r2scan-7"]] + scan_entries = [entry for entry in ms_complete.scan_entries if entry.entry_id in ["r2scan-4", "r2scan-7"]] row_list = [ ["Br", 64, 4, True, "gga-3", None, "GGA", None, 0.0, np.nan, 0.0, np.nan], @@ -519,7 +519,7 @@ def ms_gga_2_scan_diff_no_match(ms_complete): that does not match any GGA material (r2scan-8). """ gga_entries = ms_complete.gga_entries - scan_entries = [e for e in ms_complete.scan_entries if e.entry_id == "r2scan-4"] + scan_entries = [entry for entry in ms_complete.scan_entries if entry.entry_id == "r2scan-4"] scan_entries.append( ComputedStructureEntry( Structure( @@ -560,7 +560,9 @@ def ms_all_gga_scan_gs(ms_complete): ground states, but no others. """ gga_entries = ms_complete.gga_entries - scan_entries = [e for e in ms_complete.scan_entries if e.entry_id in ["r2scan-1", "r2scan-3", "r2scan-4"]] + scan_entries = [ + entry for entry in ms_complete.scan_entries if entry.entry_id in ["r2scan-1", "r2scan-3", "r2scan-4"] + ] row_list = [ ["Br", 64, 4, True, "gga-3", "r2scan-3", "GGA", "R2SCAN", 0.0, 0.0, 0.0, 0.0], @@ -660,7 +662,7 @@ def ms_all_scan_novel(ms_complete): @pytest.fixture() def ms_incomplete_gga_all_scan(ms_complete): """Mixing state with an incomplete GGA phase diagram.""" - gga_entries = [e for e in ms_complete.gga_entries if e.reduced_formula != "Sn"] + gga_entries = [entry for entry in ms_complete.gga_entries if entry.reduced_formula != "Sn"] scan_entries = ms_complete.scan_entries row_list = [ @@ -901,12 +903,12 @@ def test_clean(self, mixing_scheme_no_compat): ] mixing_scheme_no_compat.process_entries(entries, clean=False) - for e in entries: - assert e.correction == -20 + for entry in entries: + assert entry.correction == -20 mixing_scheme_no_compat.process_entries(entries, clean=True) - for e in entries: - assert e.correction == 0 + for entry in entries: + assert entry.correction == 0 def test_no_run_type(self, mixing_scheme_no_compat): """ @@ -986,7 +988,7 @@ def test_incompatible_run_type(self, mixing_scheme_no_compat): state_data = mixing_scheme_no_compat.get_mixing_state_data(entries) with pytest.raises(CompatibilityError, match="Invalid run_type='LDA'"): mixing_scheme_no_compat.get_adjustments( - next(e for e in entries if e.parameters["run_type"] == "LDA"), + next(entry for entry in entries if entry.parameters["run_type"] == "LDA"), state_data, ) @@ -1011,19 +1013,13 @@ def test_no_foreign_entries(self, mixing_scheme_no_compat, ms_complete): """ If process_entries or get_adjustments is called with a populated mixing_state_data kwarg and one or more of the entry_ids is not present in the mixing_state_data, - raise CompatbilityError. + raise CompatibilityError. """ foreign_entry = ComputedStructureEntry( Structure( lattice3, ["Sn", "Br", "Br", "Br", "Br"], - [ - [0, 0, 0], - [0.2, 0.2, 0.2], - [0.4, 0.4, 0.4], - [0.7, 0.7, 0.7], - [1, 1, 1], - ], + [[0, 0, 0], [0.2, 0.2, 0.2], [0.4, 0.4, 0.4], [0.7, 0.7, 0.7], [1, 1, 1]], ), -25, parameters={"run_type": "R2SCAN"}, @@ -1038,9 +1034,9 @@ def test_no_foreign_entries(self, mixing_scheme_no_compat, ms_complete): [*ms_complete.all_entries, foreign_entry], mixing_state_data=ms_complete.state_data ) assert len(entries) == 7 - for e in entries: - assert e.correction == 0 - assert e.parameters["run_type"] == "R2SCAN" + for entry in entries: + assert entry.correction == 0 + assert entry.parameters["run_type"] == "R2SCAN" def test_fuzzy_matching(self, ms_complete): """ @@ -1294,13 +1290,13 @@ def test_state_gga_only(self, mixing_scheme_no_compat, ms_gga_only): state_data = mixing_scheme_no_compat.get_mixing_state_data(ms_gga_only.all_entries) pd.testing.assert_frame_equal(state_data, ms_gga_only.state_data) - for e in ms_gga_only.all_entries: - assert mixing_scheme_no_compat.get_adjustments(e, ms_gga_only.state_data) == [] + for entry in ms_gga_only.all_entries: + assert mixing_scheme_no_compat.get_adjustments(entry, ms_gga_only.state_data) == [] entries = mixing_scheme_no_compat.process_entries(ms_gga_only.all_entries) assert len(entries) == 7 - for e in entries: - assert e.correction == 0 - assert e.parameters["run_type"] == "GGA" + for entry in entries: + assert entry.correction == 0 + assert entry.parameters["run_type"] == "GGA" def test_state_scan_only(self, mixing_scheme_no_compat, ms_scan_only): """ @@ -1311,14 +1307,14 @@ def test_state_scan_only(self, mixing_scheme_no_compat, ms_scan_only): state_data = mixing_scheme_no_compat.get_mixing_state_data(ms_scan_only.all_entries) pd.testing.assert_frame_equal(state_data, ms_scan_only.state_data) - for e in ms_scan_only.all_entries: - assert mixing_scheme_no_compat.get_adjustments(e, ms_scan_only.state_data) == [] + for entry in ms_scan_only.all_entries: + assert mixing_scheme_no_compat.get_adjustments(entry, ms_scan_only.state_data) == [] entries = mixing_scheme_no_compat.process_entries(ms_scan_only.all_entries) assert len(entries) == 7 - for e in entries: - assert e.correction == 0 - assert e.parameters["run_type"] == "R2SCAN" + for entry in entries: + assert entry.correction == 0 + assert entry.parameters["run_type"] == "R2SCAN" def test_state_gga_1_scan(self, mixing_scheme_no_compat, ms_gga_1_scan): """ @@ -1331,26 +1327,26 @@ def test_state_gga_1_scan(self, mixing_scheme_no_compat, ms_gga_1_scan): state_data = mixing_scheme_no_compat.get_mixing_state_data(ms_gga_1_scan.all_entries) pd.testing.assert_frame_equal(state_data, ms_gga_1_scan.state_data) - for e in ms_gga_1_scan.gga_entries: - if e.entry_id == "gga-4": + for entry in ms_gga_1_scan.gga_entries: + if entry.entry_id == "gga-4": with pytest.raises(CompatibilityError, match="because it is a GGA\\(\\+U\\) ground state"): - mixing_scheme_no_compat.get_adjustments(e, ms_gga_1_scan.state_data) + mixing_scheme_no_compat.get_adjustments(entry, ms_gga_1_scan.state_data) else: - assert mixing_scheme_no_compat.get_adjustments(e, ms_gga_1_scan.state_data) == [] + assert mixing_scheme_no_compat.get_adjustments(entry, ms_gga_1_scan.state_data) == [] - for e in ms_gga_1_scan.scan_entries: + for entry in ms_gga_1_scan.scan_entries: # gga-4 energy is -6 eV/atom, r2scan-4 energy is -7 eV/atom. There are 3 atoms. - assert mixing_scheme_no_compat.get_adjustments(e, ms_gga_1_scan.state_data)[0].value == 3 + assert mixing_scheme_no_compat.get_adjustments(entry, ms_gga_1_scan.state_data)[0].value == 3 entries = mixing_scheme_no_compat.process_entries(ms_gga_1_scan.all_entries) assert len(entries) == 7 - for e in entries: - if "4" in e.entry_id: - assert e.correction == 3 - assert e.parameters["run_type"] == "R2SCAN" + for entry in entries: + if "4" in entry.entry_id: + assert entry.correction == 3 + assert entry.parameters["run_type"] == "R2SCAN" else: - assert e.correction == 0, f"{e.entry_id}" - assert e.parameters["run_type"] == "GGA" + assert entry.correction == 0, f"{entry.entry_id}" + assert entry.parameters["run_type"] == "GGA" def test_state_gga_1_scan_plus_novel(self, mixing_scheme_no_compat, ms_gga_1_scan_novel): """ @@ -1362,13 +1358,13 @@ def test_state_gga_1_scan_plus_novel(self, mixing_scheme_no_compat, ms_gga_1_sca state_data = mixing_scheme_no_compat.get_mixing_state_data(ms_gga_1_scan_novel.all_entries) pd.testing.assert_frame_equal(state_data, ms_gga_1_scan_novel.state_data) - for e in ms_gga_1_scan_novel.gga_entries: - if e.entry_id == "gga-4": - assert mixing_scheme_no_compat.get_adjustments(e, ms_gga_1_scan_novel.state_data) == [] + for entry in ms_gga_1_scan_novel.gga_entries: + if entry.entry_id == "gga-4": + assert mixing_scheme_no_compat.get_adjustments(entry, ms_gga_1_scan_novel.state_data) == [] - for e in ms_gga_1_scan_novel.scan_entries: + for entry in ms_gga_1_scan_novel.scan_entries: with pytest.raises(CompatibilityError, match="no R2SCAN ground states at this composition"): - mixing_scheme_no_compat.get_adjustments(e, ms_gga_1_scan_novel.state_data) + mixing_scheme_no_compat.get_adjustments(entry, ms_gga_1_scan_novel.state_data) entries = mixing_scheme_no_compat.process_entries(ms_gga_1_scan_novel.all_entries) assert len(entries) == 7 @@ -1525,9 +1521,9 @@ def test_state_incomplete_gga_all_scan(self, mixing_scheme_no_compat, ms_incompl state_data = mixing_scheme_no_compat.get_mixing_state_data(ms_incomplete_gga_all_scan.all_entries) pd.testing.assert_frame_equal(state_data, ms_incomplete_gga_all_scan.state_data) - for e in ms_incomplete_gga_all_scan.all_entries: + for entry in ms_incomplete_gga_all_scan.all_entries: with pytest.raises(CompatibilityError, match="do not form a complete PhaseDiagram"): - mixing_scheme_no_compat.get_adjustments(e, ms_incomplete_gga_all_scan.state_data) + mixing_scheme_no_compat.get_adjustments(entry, ms_incomplete_gga_all_scan.state_data) # process_entries should discard all entries and issue a warning with pytest.warns(UserWarning, match="do not form a complete PhaseDiagram"): @@ -1630,12 +1626,12 @@ def test_state_energy_modified(self, mixing_scheme_no_compat, ms_complete): """ state_data = mixing_scheme_no_compat.get_mixing_state_data(ms_complete.all_entries) # lower the energy of the SnBr2 ground state - e = next(e for e in ms_complete.gga_entries if e.entry_id == "gga-4") + entry = next(ent for ent in ms_complete.gga_entries if ent.entry_id == "gga-4") d_compat = DummyCompatibility() - d_compat.process_entries(e) + d_compat.process_entries(entry) with pytest.raises(CompatibilityError, match="energy has been modified"): - mixing_scheme_no_compat.get_adjustments(e, state_data) + mixing_scheme_no_compat.get_adjustments(entry, state_data) def test_chemsys_mismatch(self, mixing_scheme_no_compat, ms_scan_chemsys_superset): """ diff --git a/tests/ext/test_cod.py b/tests/ext/test_cod.py index e0802fae6d2..de925483794 100644 --- a/tests/ext/test_cod.py +++ b/tests/ext/test_cod.py @@ -1,7 +1,7 @@ from __future__ import annotations -import unittest from shutil import which +from unittest import TestCase import pytest import requests @@ -15,7 +15,7 @@ @pytest.mark.skipif(website_down, reason="www.crystallography.net is down.") -class TestCOD(unittest.TestCase): +class TestCOD(TestCase): @pytest.mark.skipif(not which("mysql"), reason="No mysql.") def test_get_cod_ids(self): ids = COD().get_cod_ids("Li2O") diff --git a/tests/ext/test_matproj.py b/tests/ext/test_matproj.py index aa3eb42f65c..a76034eb737 100644 --- a/tests/ext/test_matproj.py +++ b/tests/ext/test_matproj.py @@ -108,7 +108,7 @@ def test_get_data(self): assert set(Composition(d["unit_cell_formula"]).elements).issubset(elements) with pytest.raises(MPRestError, match="REST query returned with error status code 404"): - self.rester.get_data("Fe2O3", "badmethod") + self.rester.get_data("Fe2O3", "bad-method") def test_get_materials_id_from_task_id(self): assert self.rester.get_materials_id_from_task_id("mp-540081") == "mp-19017" @@ -133,12 +133,12 @@ def test_get_entries_in_chemsys(self): entries = self.rester.get_entries_in_chemsys(syms) entries2 = self.rester.get_entries_in_chemsys(syms2) elements = {Element(sym) for sym in syms} - for e in entries: - assert isinstance(e, ComputedEntry) - assert set(e.elements).issubset(elements) + for entry in entries: + assert isinstance(entry, ComputedEntry) + assert set(entry.elements).issubset(elements) - e1 = {i.entry_id for i in entries} - e2 = {i.entry_id for i in entries2} + e1 = {ent.entry_id for ent in entries} + e2 = {ent.entry_id for ent in entries2} assert e1 == e2 stable_entries = self.rester.get_entries_in_chemsys(syms, additional_criteria={"e_above_hull": {"$lte": 0.001}}) @@ -280,10 +280,10 @@ def test_get_pourbaix_entries(self): for pbx_entry in pbx_entries: assert isinstance(pbx_entry, PourbaixEntry) - # fe_two_plus = [e for e in pbx_entries if e.entry_id == "ion-0"][0] + # fe_two_plus = next(entry for entry in pbx_entries if entry.entry_id == "ion-0") # assert fe_two_plus.energy == approx(-1.12369, abs=1e-3) - # feo2 = [e for e in pbx_entries if e.entry_id == "mp-25332"][0] + # feo2 = next(entry for entry in pbx_entries if entry.entry_id == "mp-25332") # assert feo2.energy == approx(3.56356, abs=1e-3) # # Test S, which has Na in reference solids @@ -460,11 +460,11 @@ def test_include_user_agent(self): ) headers = self.rester.session.headers assert "user-agent" in headers, "Include user-agent header by default" - m = re.match( + match = re.match( r"pymatgen/(\d+)\.(\d+)\.(\d+)\.?(\d+)? \(Python/(\d+)\.(\d)+\.(\d+) ([^\/]*)/([^\)]*)\)", headers["user-agent"], ) - assert m is not None, f"Unexpected user-agent value {headers['user-agent']}" + assert match is not None, f"Unexpected user-agent value {headers['user-agent']}" self.rester = _MPResterLegacy(include_user_agent=False) assert "user-agent" not in self.rester.session.headers, "user-agent header unwanted" @@ -627,15 +627,15 @@ def test_get_entries_and_in_chemsys(self): entries = self.rester.get_entries_in_chemsys(syms) entries2 = self.rester.get_entries(syms2) elements = {Element(sym) for sym in syms} - for e in entries: - assert isinstance(e, ComputedEntry) - assert set(e.elements).issubset(elements) + for entry in entries: + assert isinstance(entry, ComputedEntry) + assert set(entry.elements).issubset(elements) assert len(entries) > 1000 - for e in entries2: - assert isinstance(e, ComputedEntry) - assert set(e.elements).issubset(elements) + for entry in entries2: + assert isinstance(entry, ComputedEntry) + assert set(entry.elements).issubset(elements) assert len(entries2) < 1000 e1 = {i.entry_id for i in entries} diff --git a/tests/ext/test_optimade.py b/tests/ext/test_optimade.py index e06e4a99755..9e09625a2b8 100644 --- a/tests/ext/test_optimade.py +++ b/tests/ext/test_optimade.py @@ -3,21 +3,33 @@ import pytest import requests -from pymatgen.core import SETTINGS from pymatgen.ext.optimade import OptimadeRester from pymatgen.util.testing import PymatgenTest try: # 403 is returned when server detects bot-like behavior - website_down = requests.get("https://materialsproject.org").status_code not in (200, 403) + website_down = requests.get(OptimadeRester.aliases["mp"]).status_code not in (200, 403) except requests.exceptions.ConnectionError: website_down = True +try: + optimade_providers_down = requests.get("https://providers.optimade.org").status_code not in (200, 403) +except requests.exceptions.ConnectionError: + optimade_providers_down = True + +try: + mc3d_down = requests.get(OptimadeRester.aliases["mcloud.mc3d"] + "/v1/info").status_code not in (200, 403, 301) +except requests.exceptions.ConnectionError: + mc3d_down = True + +try: + mc2d_down = requests.get(OptimadeRester.aliases["mcloud.mc2d"] + "/v1/info").status_code not in (200, 403, 301) +except requests.exceptions.ConnectionError: + mc2d_down = True + class TestOptimade(PymatgenTest): - @pytest.mark.skipif( - not SETTINGS.get("PMG_MAPI_KEY") or website_down, reason="PMG_MAPI_KEY env var not set or MP is down." - ) + @pytest.mark.skipif(website_down, reason="MP OPTIMADE is down.") def test_get_structures_mp(self): with OptimadeRester("mp") as optimade: structs = optimade.get_structures(elements=["Ga", "N"], nelements=2) @@ -35,9 +47,7 @@ def test_get_structures_mp(self): raw_filter_structs["mp"] ), f"Raw filter {_filter} did not return the same number of results as the query builder." - @pytest.mark.skipif( - not SETTINGS.get("PMG_MAPI_KEY") or website_down, reason="PMG_MAPI_KEY env var not set or MP is down." - ) + @pytest.mark.skipif(website_down, reason="MP OPTIMADE is down.") def test_get_snls_mp(self): base_query = dict(elements=["Ga", "N"], nelements=2, nsites=[2, 6]) with OptimadeRester("mp") as optimade: @@ -60,48 +70,50 @@ def test_get_snls_mp(self): struct_nl_set = next(iter(extra_fields_set["mp"].values())) assert field_set <= {*struct_nl_set.data["_optimade"]} - # Tests fail in CI for unknown reason, use for development only. - # def test_get_structures_mcloud_2dstructures(self): - # with OptimadeRester("mcloud.2dstructures") as optimade: - # structs = optimade.get_structures(elements=["B", "N"], nelements=2) + @pytest.mark.skipif(mc3d_down or mc2d_down, reason="At least one MC OPTIMADE API is down.") + def test_get_structures_mcloud(self): + with OptimadeRester(["mcloud.mc2d", "mcloud.mc3d"]) as optimade: + structs = optimade.get_structures(elements=["B", "N"], nelements=2) - # test_struct = next(iter(structs["mcloud.2dstructures"].values())) + test_struct = next(iter(structs["mcloud.mc2d"].values())) - # assert [str(el) for el in test_struct.types_of_species] == ["B", "N"] + assert [str(el) for el in test_struct.types_of_species] == ["B", "N"] - # def test_update_aliases(self): - # - # with OptimadeRester() as optimade: - # optimade.refresh_aliases() - # - # self.assertIn("mp", optimade.aliases) + test_struct = next(iter(structs["mcloud.mc3d"].values())) + assert [str(el) for el in test_struct.types_of_species] == ["B", "N"] + + @pytest.mark.skipif(optimade_providers_down, reason="OPTIMADE providers list is down.") + def test_update_aliases(self): + with OptimadeRester() as optimade: + optimade.refresh_aliases() + + assert "mp" in optimade.aliases def test_build_filter(self): - with OptimadeRester("mp") as optimade: - assert optimade._build_filter( - elements=["Ga", "N"], - nelements=2, - nsites=(1, 100), - chemical_formula_anonymous="A2B", - chemical_formula_hill="GaN", - ) == ( - '(elements HAS ALL "Ga", "N")' - " AND (nsites>=1 AND nsites<=100)" - " AND (nelements=2)" - " AND (chemical_formula_anonymous='A2B')" - " AND (chemical_formula_hill='GaN')" - ) - - assert optimade._build_filter( - elements=["C", "H", "O"], - nelements=(3, 4), - nsites=(1, 100), - chemical_formula_anonymous="A4B3C", - chemical_formula_hill="C4H3O", - ) == ( - '(elements HAS ALL "C", "H", "O")' - " AND (nsites>=1 AND nsites<=100)" - " AND (nelements>=3 AND nelements<=4)" - " AND (chemical_formula_anonymous='A4B3C')" - " AND (chemical_formula_hill='C4H3O')" - ) + assert OptimadeRester._build_filter( + elements=["Ga", "N"], + nelements=2, + nsites=(1, 100), + chemical_formula_anonymous="A2B", + chemical_formula_hill="GaN", + ) == ( + '(elements HAS ALL "Ga", "N")' + " AND (nsites>=1 AND nsites<=100)" + " AND (nelements=2)" + " AND (chemical_formula_anonymous='A2B')" + " AND (chemical_formula_hill='GaN')" + ) + + assert OptimadeRester._build_filter( + elements=["C", "H", "O"], + nelements=(3, 4), + nsites=(1, 100), + chemical_formula_anonymous="A4B3C", + chemical_formula_hill="C4H3O", + ) == ( + '(elements HAS ALL "C", "H", "O")' + " AND (nsites>=1 AND nsites<=100)" + " AND (nelements>=3 AND nelements<=4)" + " AND (chemical_formula_anonymous='A4B3C')" + " AND (chemical_formula_hill='C4H3O')" + ) diff --git a/tests/files/.pytest-split-durations b/tests/files/.pytest-split-durations index 465fac2bb5f..d12ddce1e5d 100644 --- a/tests/files/.pytest-split-durations +++ b/tests/files/.pytest-split-durations @@ -563,19 +563,13 @@ "tests/analysis/test_pourbaix_diagram.py::TestPourbaixPlotter::test_plot_entry_stability": 0.12873674900038168, "tests/analysis/test_pourbaix_diagram.py::TestPourbaixPlotter::test_plot_pourbaix": 0.25079550000373274, "tests/analysis/test_prototypes.py::TestAflowPrototypeMatcher::test_prototype_matching": 3.0040170419961214, - "tests/analysis/test_quasiharmonic_debye_approx.py::TestAnharmonicQuasiharmociDebyeApprox::test_debye_temperature": 0.00467820797348395, - "tests/analysis/test_quasiharmonic_debye_approx.py::TestAnharmonicQuasiharmociDebyeApprox::test_gruneisen_parameter": 0.004110915935598314, - "tests/analysis/test_quasiharmonic_debye_approx.py::TestAnharmonicQuasiharmociDebyeApprox::test_optimum_volume": 0.003661041962914169, - "tests/analysis/test_quasiharmonic_debye_approx.py::TestAnharmonicQuasiharmociDebyeApprox::test_thermal_conductivity": 0.00443675002316013, - "tests/analysis/test_quasiharmonic_debye_approx.py::TestAnharmonicQuasiharmociDebyeApprox::test_vibrational_free_energy": 0.004137665964663029, - "tests/analysis/test_quasiharmonic_debye_approx.py::TestAnharmonicQuasiharmociDebyeApprox::test_vibrational_internal_energy": 0.004504501004703343, - "tests/analysis/test_quasiharmonic_debye_approx.py::TestQuasiharmociDebyeApprox::test_bulk_modulus": 0.005177041050046682, - "tests/analysis/test_quasiharmonic_debye_approx.py::TestQuasiharmociDebyeApprox::test_debye_temperature": 0.0036324569955468178, - "tests/analysis/test_quasiharmonic_debye_approx.py::TestQuasiharmociDebyeApprox::test_gruneisen_parameter": 0.003981249057687819, - "tests/analysis/test_quasiharmonic_debye_approx.py::TestQuasiharmociDebyeApprox::test_optimum_volume": 0.0035152080236002803, - "tests/analysis/test_quasiharmonic_debye_approx.py::TestQuasiharmociDebyeApprox::test_thermal_conductivity": 0.004235040920320898, - "tests/analysis/test_quasiharmonic_debye_approx.py::TestQuasiharmociDebyeApprox::test_vibrational_free_energy": 0.0036306249676272273, - "tests/analysis/test_quasiharmonic_debye_approx.py::TestQuasiharmociDebyeApprox::test_vibrational_internal_energy": 0.003679332963656634, + "tests/analysis/test_quasi_harmonic_debye_approx.py::TestQuasiHarmonicDebyeApprox::test_bulk_modulus": 0.005177041050046682, + "tests/analysis/test_quasi_harmonic_debye_approx.py::TestQuasiHarmonicDebyeApprox::test_debye_temperature": 0.0036324569955468178, + "tests/analysis/test_quasi_harmonic_debye_approx.py::TestQuasiHarmonicDebyeApprox::test_gruneisen_parameter": 0.003981249057687819, + "tests/analysis/test_quasi_harmonic_debye_approx.py::TestQuasiHarmonicDebyeApprox::test_optimum_volume": 0.0035152080236002803, + "tests/analysis/test_quasi_harmonic_debye_approx.py::TestQuasiHarmonicDebyeApprox::test_thermal_conductivity": 0.004235040920320898, + "tests/analysis/test_quasi_harmonic_debye_approx.py::TestQuasiHarmonicDebyeApprox::test_vibrational_free_energy": 0.0036306249676272273, + "tests/analysis/test_quasi_harmonic_debye_approx.py::TestQuasiHarmonicDebyeApprox::test_vibrational_internal_energy": 0.003679332963656634, "tests/analysis/test_reaction_calculator.py::TestBalancedReaction::test_from_str": 0.0005021260003559291, "tests/analysis/test_reaction_calculator.py::TestBalancedReaction::test_init": 0.0004712080117315054, "tests/analysis/test_reaction_calculator.py::TestBalancedReaction::test_remove_spectator_species": 0.0003208750276826322, @@ -651,15 +645,15 @@ "tests/analysis/test_surface_analysis.py::TestSurfaceEnergyPlotter::test_get_surface_equilibrium": 0.009471333003602922, "tests/analysis/test_surface_analysis.py::TestSurfaceEnergyPlotter::test_stable_u_range_dict": 0.009161207941360772, "tests/analysis/test_surface_analysis.py::TestSurfaceEnergyPlotter::test_wulff_from_chempot": 0.14841087599052116, - "tests/analysis/test_surface_analysis.py::TestWorkfunctionAnalyzer::test_is_converged": 0.8860225000535138, - "tests/analysis/test_surface_analysis.py::TestWorkfunctionAnalyzer::test_shift": 1.7440659570274875, + "tests/analysis/test_surface_analysis.py::TestWorkFunctionAnalyzer::test_is_converged": 0.8860225000535138, + "tests/analysis/test_surface_analysis.py::TestWorkFunctionAnalyzer::test_shift": 1.7440659570274875, "tests/analysis/test_transition_state.py::TestNEBAnalysis::test_combine_neb_plots": 3.785049124970101, "tests/analysis/test_wulff.py::TestWulffShape::test_corner_and_edges": 0.048474123992491513, "tests/analysis/test_wulff.py::TestWulffShape::test_get_azimuth_elev": 0.046476999996230006, "tests/analysis/test_wulff.py::TestWulffShape::test_get_plot": 0.1072019999846816, "tests/analysis/test_wulff.py::TestWulffShape::test_get_plotly": 0.1067069589626044, "tests/analysis/test_wulff.py::TestWulffShape::test_properties": 0.04795904195634648, - "tests/analysis/test_xps.py::XPSTestCase::test_from_dos": 0.7830395839991979, + "tests/analysis/test_xps.py::TestXPS::test_from_dos": 0.7830395839991979, "tests/analysis/topological/test_spillage.py::TestSolar::test_spillage_from_vasprun": 3.4113557069795206, "tests/analysis/xas/test_spectrum.py::TestXAS::test_add_mul": 0.002699415956158191, "tests/analysis/xas/test_spectrum.py::TestXAS::test_attributes": 0.0020827510161325336, @@ -709,7 +703,7 @@ "tests/command_line/test_chargemol_caller.py::TestChargemolAnalysis::test_parse_chargemol": 0.05377787502948195, "tests/command_line/test_chargemol_caller.py::TestChargemolAnalysis::test_parse_chargemol2": 0.0007731670048087835, "tests/command_line/test_critic2_caller.py::TestCritic2Analysis::test_graph_output": 0.002192000043578446, - "tests/command_line/test_critic2_caller.py::TestCritic2Analysis::test_properties_to_from_dict": 0.0017835840117186308, + "tests/command_line/test_critic2_caller.py::TestCritic2Analysis::test_to_from_dict": 0.0017835840117186308, "tests/command_line/test_critic2_caller.py::TestCritic2Caller::test_from_path": 0.00018416601233184338, "tests/command_line/test_critic2_caller.py::TestCritic2Caller::test_from_structure": 0.00014708400703966618, "tests/command_line/test_enumlib_caller.py::TestEnumlibAdaptor::test_init": 0.0001795410644263029, @@ -830,102 +824,102 @@ "tests/core/test_ion.py::TestIon::test_oxi_state_guesses": 0.0005007070139981806, "tests/core/test_ion.py::TestIon::test_special_formulas": 0.000947916938457638, "tests/core/test_ion.py::TestIon::test_to_latex_string": 0.0005689989775419235, - "tests/core/test_lattice.py::LatticeTestCase::test_attributes": 0.0013616670621559024, - "tests/core/test_lattice.py::LatticeTestCase::test_copy": 0.00129941594786942, - "tests/core/test_lattice.py::LatticeTestCase::test_d_hkl": 0.001224166015163064, - "tests/core/test_lattice.py::LatticeTestCase::test_dot_and_norm": 0.0021065829787403345, - "tests/core/test_lattice.py::LatticeTestCase::test_equal": 0.0015783329727128148, - "tests/core/test_lattice.py::LatticeTestCase::test_find_all_mappings": 0.014700166007969528, - "tests/core/test_lattice.py::LatticeTestCase::test_find_mapping": 0.001790250011254102, - "tests/core/test_lattice.py::LatticeTestCase::test_format": 0.0012162079801782966, - "tests/core/test_lattice.py::LatticeTestCase::test_get_all_distances": 0.0014519169926643372, - "tests/core/test_lattice.py::LatticeTestCase::test_get_cartesian_or_frac_coord": 0.0013719589915126562, - "tests/core/test_lattice.py::LatticeTestCase::test_get_distance_and_image": 0.001282749988604337, - "tests/core/test_lattice.py::LatticeTestCase::test_get_distance_and_image_strict": 0.013159626047126949, - "tests/core/test_lattice.py::LatticeTestCase::test_get_lll_reduced_lattice": 0.002524374984204769, - "tests/core/test_lattice.py::LatticeTestCase::test_get_miller_index_from_sites": 0.0015758739900775254, - "tests/core/test_lattice.py::LatticeTestCase::test_get_niggli_reduced_lattice": 0.002360417041927576, - "tests/core/test_lattice.py::LatticeTestCase::test_get_points_in_sphere": 0.005142041016370058, - "tests/core/test_lattice.py::LatticeTestCase::test_get_vector_along_lattice_directions": 0.0014230000670067966, - "tests/core/test_lattice.py::LatticeTestCase::test_get_wigner_seitz_cell": 0.0015817929524928331, - "tests/core/test_lattice.py::LatticeTestCase::test_init": 0.0013397909933701158, - "tests/core/test_lattice.py::LatticeTestCase::test_is_3d_periodic": 0.0012610420235432684, - "tests/core/test_lattice.py::LatticeTestCase::test_is_hexagonal": 0.0012544590863399208, - "tests/core/test_lattice.py::LatticeTestCase::test_lattice_matrices": 0.0012341660331003368, - "tests/core/test_lattice.py::LatticeTestCase::test_lll_basis": 0.001683499023783952, - "tests/core/test_lattice.py::LatticeTestCase::test_mapping_symmetry": 0.001692500023636967, - "tests/core/test_lattice.py::LatticeTestCase::test_monoclinic": 0.0012882089940831065, - "tests/core/test_lattice.py::LatticeTestCase::test_points_in_spheres": 0.0026560419937595725, - "tests/core/test_lattice.py::LatticeTestCase::test_reciprocal_lattice": 0.001335958018898964, - "tests/core/test_lattice.py::LatticeTestCase::test_scale": 0.0016442069900222123, - "tests/core/test_lattice.py::LatticeTestCase::test_selling_dist": 0.00350079097552225, - "tests/core/test_lattice.py::LatticeTestCase::test_selling_vector": 0.0014326670207083225, - "tests/core/test_lattice.py::LatticeTestCase::test_static_methods": 0.0013170000165700912, - "tests/core/test_lattice.py::LatticeTestCase::test_as_from_dict": 0.0015288760187104344, + "tests/core/test_lattice.py::TestLattice::test_attributes": 0.0013616670621559024, + "tests/core/test_lattice.py::TestLattice::test_copy": 0.00129941594786942, + "tests/core/test_lattice.py::TestLattice::test_d_hkl": 0.001224166015163064, + "tests/core/test_lattice.py::TestLattice::test_dot_and_norm": 0.0021065829787403345, + "tests/core/test_lattice.py::TestLattice::test_equal": 0.0015783329727128148, + "tests/core/test_lattice.py::TestLattice::test_find_all_mappings": 0.014700166007969528, + "tests/core/test_lattice.py::TestLattice::test_find_mapping": 0.001790250011254102, + "tests/core/test_lattice.py::TestLattice::test_format": 0.0012162079801782966, + "tests/core/test_lattice.py::TestLattice::test_get_all_distances": 0.0014519169926643372, + "tests/core/test_lattice.py::TestLattice::test_get_cartesian_or_frac_coord": 0.0013719589915126562, + "tests/core/test_lattice.py::TestLattice::test_get_distance_and_image": 0.001282749988604337, + "tests/core/test_lattice.py::TestLattice::test_get_distance_and_image_strict": 0.013159626047126949, + "tests/core/test_lattice.py::TestLattice::test_get_lll_reduced_lattice": 0.002524374984204769, + "tests/core/test_lattice.py::TestLattice::test_get_miller_index_from_sites": 0.0015758739900775254, + "tests/core/test_lattice.py::TestLattice::test_get_niggli_reduced_lattice": 0.002360417041927576, + "tests/core/test_lattice.py::TestLattice::test_get_points_in_sphere": 0.005142041016370058, + "tests/core/test_lattice.py::TestLattice::test_get_vector_along_lattice_directions": 0.0014230000670067966, + "tests/core/test_lattice.py::TestLattice::test_get_wigner_seitz_cell": 0.0015817929524928331, + "tests/core/test_lattice.py::TestLattice::test_init": 0.0013397909933701158, + "tests/core/test_lattice.py::TestLattice::test_is_3d_periodic": 0.0012610420235432684, + "tests/core/test_lattice.py::TestLattice::test_is_hexagonal": 0.0012544590863399208, + "tests/core/test_lattice.py::TestLattice::test_lattice_matrices": 0.0012341660331003368, + "tests/core/test_lattice.py::TestLattice::test_lll_basis": 0.001683499023783952, + "tests/core/test_lattice.py::TestLattice::test_mapping_symmetry": 0.001692500023636967, + "tests/core/test_lattice.py::TestLattice::test_monoclinic": 0.0012882089940831065, + "tests/core/test_lattice.py::TestLattice::test_points_in_spheres": 0.0026560419937595725, + "tests/core/test_lattice.py::TestLattice::test_reciprocal_lattice": 0.001335958018898964, + "tests/core/test_lattice.py::TestLattice::test_scale": 0.0016442069900222123, + "tests/core/test_lattice.py::TestLattice::test_selling_dist": 0.00350079097552225, + "tests/core/test_lattice.py::TestLattice::test_selling_vector": 0.0014326670207083225, + "tests/core/test_lattice.py::TestLattice::test_static_methods": 0.0013170000165700912, + "tests/core/test_lattice.py::TestLattice::test_as_from_dict": 0.0015288760187104344, "tests/core/test_libxcfunc.py::TestLibxcFunc::test_libxcfunc_api": 0.0015083340113051236, - "tests/core/test_molecular_orbitals.py::MolecularOrbitalTestCase::test_aos_as_list": 0.0012476249830797315, - "tests/core/test_molecular_orbitals.py::MolecularOrbitalTestCase::test_fractional_compositions": 0.00218645908171311, - "tests/core/test_molecular_orbitals.py::MolecularOrbitalTestCase::test_max_electronegativity": 0.004327957925852388, - "tests/core/test_molecular_orbitals.py::MolecularOrbitalTestCase::test_obtain_band_edges": 0.0013270010240375996, - "tests/core/test_operations.py::MagSymmOpTestCase::test_operate_magmom": 0.0021593329729512334, - "tests/core/test_operations.py::MagSymmOpTestCase::test_as_from_dict": 0.0012797510134987533, - "tests/core/test_operations.py::MagSymmOpTestCase::test_xyzt_string": 0.0013733739615418017, - "tests/core/test_operations.py::SymmOpTestCase::test_apply_rotation_only": 0.0012938339496031404, - "tests/core/test_operations.py::SymmOpTestCase::test_are_symmetrically_related": 0.0012189170229248703, - "tests/core/test_operations.py::SymmOpTestCase::test_are_symmetrically_related_vectors": 0.0014421659871004522, - "tests/core/test_operations.py::SymmOpTestCase::test_inverse": 0.0012554159620776772, - "tests/core/test_operations.py::SymmOpTestCase::test_inversion": 0.0012409999617375433, - "tests/core/test_operations.py::SymmOpTestCase::test_operate": 0.0012290010345168412, - "tests/core/test_operations.py::SymmOpTestCase::test_operate_multi": 0.0013142511015757918, - "tests/core/test_operations.py::SymmOpTestCase::test_properties": 0.0012810410116799176, - "tests/core/test_operations.py::SymmOpTestCase::test_reflection": 0.0012293330510146916, - "tests/core/test_operations.py::SymmOpTestCase::test_as_from_dict": 0.0013375001144595444, - "tests/core/test_operations.py::SymmOpTestCase::test_transform_tensor": 0.0014549170155078173, - "tests/core/test_operations.py::SymmOpTestCase::test_xyz": 0.0016299580456689, - "tests/core/test_periodic_table.py::DummySpeciesTestCase::test_attr": 0.0002989170025102794, - "tests/core/test_periodic_table.py::DummySpeciesTestCase::test_eq": 0.0002703749923966825, - "tests/core/test_periodic_table.py::DummySpeciesTestCase::test_from_string": 0.0003050409723073244, - "tests/core/test_periodic_table.py::DummySpeciesTestCase::test_immutable": 0.00031504296930506825, - "tests/core/test_periodic_table.py::DummySpeciesTestCase::test_init": 0.0004290830693207681, - "tests/core/test_periodic_table.py::DummySpeciesTestCase::test_pickle": 0.00027341704117134213, - "tests/core/test_periodic_table.py::DummySpeciesTestCase::test_sort": 0.0002741679782047868, - "tests/core/test_periodic_table.py::ElementTestCase::test_attributes": 0.062131249986123294, - "tests/core/test_periodic_table.py::ElementTestCase::test_block": 0.0014038329245522618, - "tests/core/test_periodic_table.py::ElementTestCase::test_data": 0.0012909590150229633, - "tests/core/test_periodic_table.py::ElementTestCase::test_deepcopy": 0.0012744999257847667, - "tests/core/test_periodic_table.py::ElementTestCase::test_dict": 0.001254875969607383, - "tests/core/test_periodic_table.py::ElementTestCase::test_from_name": 0.0013069179840385914, - "tests/core/test_periodic_table.py::ElementTestCase::test_from_row_and_group": 0.0016577079659327865, - "tests/core/test_periodic_table.py::ElementTestCase::test_full_electronic_structure": 0.0013409170205704868, - "tests/core/test_periodic_table.py::ElementTestCase::test_ground_state_term_symbol": 0.002299125015269965, - "tests/core/test_periodic_table.py::ElementTestCase::test_group": 0.0012170420959591866, - "tests/core/test_periodic_table.py::ElementTestCase::test_ie_ea": 0.0011918750242330134, - "tests/core/test_periodic_table.py::ElementTestCase::test_init": 0.00143912504427135, - "tests/core/test_periodic_table.py::ElementTestCase::test_is": 0.0012667080154642463, - "tests/core/test_periodic_table.py::ElementTestCase::test_is_metal": 0.0012503340258263052, - "tests/core/test_periodic_table.py::ElementTestCase::test_nan_x": 0.001281375007238239, - "tests/core/test_periodic_table.py::ElementTestCase::test_oxidation_states": 0.0012213330483064055, - "tests/core/test_periodic_table.py::ElementTestCase::test_pickle": 0.010329875047318637, - "tests/core/test_periodic_table.py::ElementTestCase::test_print_periodic_table": 0.009697831992525607, - "tests/core/test_periodic_table.py::ElementTestCase::test_radii": 0.0014834590256214142, - "tests/core/test_periodic_table.py::ElementTestCase::test_row": 0.0012839180417358875, - "tests/core/test_periodic_table.py::ElementTestCase::test_sort": 0.001303499040659517, - "tests/core/test_periodic_table.py::ElementTestCase::test_term_symbols": 0.002003665955271572, - "tests/core/test_periodic_table.py::ElementTestCase::test_valence": 0.0014555829693563282, - "tests/core/test_periodic_table.py::SpeciesTestCase::test_attr": 0.0012557489681057632, - "tests/core/test_periodic_table.py::SpeciesTestCase::test_cmp": 0.001242915983311832, - "tests/core/test_periodic_table.py::SpeciesTestCase::test_deepcopy": 0.0012872919905930758, - "tests/core/test_periodic_table.py::SpeciesTestCase::test_eq": 0.0014182509621605277, - "tests/core/test_periodic_table.py::SpeciesTestCase::test_get_crystal_field_spin": 0.0016555010224692523, - "tests/core/test_periodic_table.py::SpeciesTestCase::test_get_nmr_mom": 0.0015882920124568045, - "tests/core/test_periodic_table.py::SpeciesTestCase::test_get_shannon_radius": 0.0014892080216668546, - "tests/core/test_periodic_table.py::SpeciesTestCase::test_init": 0.0014829590218141675, - "tests/core/test_periodic_table.py::SpeciesTestCase::test_ionic_radius": 0.001733499055262655, - "tests/core/test_periodic_table.py::SpeciesTestCase::test_no_oxidation_state": 0.0012180840130895376, - "tests/core/test_periodic_table.py::SpeciesTestCase::test_pickle": 0.0022111249854788184, - "tests/core/test_periodic_table.py::SpeciesTestCase::test_sort": 0.0013953330926597118, - "tests/core/test_periodic_table.py::SpeciesTestCase::test_stringify": 0.0013461660128086805, - "tests/core/test_periodic_table.py::SpeciesTestCase::test_to_from_string": 0.001370625977870077, + "tests/core/test_molecular_orbitals.py::TestMolecularOrbital::test_aos_as_list": 0.0012476249830797315, + "tests/core/test_molecular_orbitals.py::TestMolecularOrbital::test_fractional_compositions": 0.00218645908171311, + "tests/core/test_molecular_orbitals.py::TestMolecularOrbital::test_max_electronegativity": 0.004327957925852388, + "tests/core/test_molecular_orbitals.py::TestMolecularOrbital::test_obtain_band_edges": 0.0013270010240375996, + "tests/core/test_operations.py::TestMagSymmOp::test_operate_magmom": 0.0021593329729512334, + "tests/core/test_operations.py::TestMagSymmOp::test_as_from_dict": 0.0012797510134987533, + "tests/core/test_operations.py::TestMagSymmOp::test_xyzt_string": 0.0013733739615418017, + "tests/core/test_operations.py::TestSymmOp::test_apply_rotation_only": 0.0012938339496031404, + "tests/core/test_operations.py::TestSymmOp::test_are_symmetrically_related": 0.0012189170229248703, + "tests/core/test_operations.py::TestSymmOp::test_are_symmetrically_related_vectors": 0.0014421659871004522, + "tests/core/test_operations.py::TestSymmOp::test_inverse": 0.0012554159620776772, + "tests/core/test_operations.py::TestSymmOp::test_inversion": 0.0012409999617375433, + "tests/core/test_operations.py::TestSymmOp::test_operate": 0.0012290010345168412, + "tests/core/test_operations.py::TestSymmOp::test_operate_multi": 0.0013142511015757918, + "tests/core/test_operations.py::TestSymmOp::test_properties": 0.0012810410116799176, + "tests/core/test_operations.py::TestSymmOp::test_reflection": 0.0012293330510146916, + "tests/core/test_operations.py::TestSymmOp::test_as_from_dict": 0.0013375001144595444, + "tests/core/test_operations.py::TestSymmOp::test_transform_tensor": 0.0014549170155078173, + "tests/core/test_operations.py::TestSymmOp::test_xyz": 0.0016299580456689, + "tests/core/test_periodic_table.py::TestDummySpecies::test_attr": 0.0002989170025102794, + "tests/core/test_periodic_table.py::TestDummySpecies::test_eq": 0.0002703749923966825, + "tests/core/test_periodic_table.py::TestDummySpecies::test_from_string": 0.0003050409723073244, + "tests/core/test_periodic_table.py::TestDummySpecies::test_immutable": 0.00031504296930506825, + "tests/core/test_periodic_table.py::TestDummySpecies::test_init": 0.0004290830693207681, + "tests/core/test_periodic_table.py::TestDummySpecies::test_pickle": 0.00027341704117134213, + "tests/core/test_periodic_table.py::TestDummySpecies::test_sort": 0.0002741679782047868, + "tests/core/test_periodic_table.py::TestElement::test_attributes": 0.062131249986123294, + "tests/core/test_periodic_table.py::TestElement::test_block": 0.0014038329245522618, + "tests/core/test_periodic_table.py::TestElement::test_data": 0.0012909590150229633, + "tests/core/test_periodic_table.py::TestElement::test_deepcopy": 0.0012744999257847667, + "tests/core/test_periodic_table.py::TestElement::test_dict": 0.001254875969607383, + "tests/core/test_periodic_table.py::TestElement::test_from_name": 0.0013069179840385914, + "tests/core/test_periodic_table.py::TestElement::test_from_row_and_group": 0.0016577079659327865, + "tests/core/test_periodic_table.py::TestElement::test_full_electronic_structure": 0.0013409170205704868, + "tests/core/test_periodic_table.py::TestElement::test_ground_state_term_symbol": 0.002299125015269965, + "tests/core/test_periodic_table.py::TestElement::test_group": 0.0012170420959591866, + "tests/core/test_periodic_table.py::TestElement::test_ie_ea": 0.0011918750242330134, + "tests/core/test_periodic_table.py::TestElement::test_init": 0.00143912504427135, + "tests/core/test_periodic_table.py::TestElement::test_is": 0.0012667080154642463, + "tests/core/test_periodic_table.py::TestElement::test_is_metal": 0.0012503340258263052, + "tests/core/test_periodic_table.py::TestElement::test_nan_x": 0.001281375007238239, + "tests/core/test_periodic_table.py::TestElement::test_oxidation_states": 0.0012213330483064055, + "tests/core/test_periodic_table.py::TestElement::test_pickle": 0.010329875047318637, + "tests/core/test_periodic_table.py::TestElement::test_print_periodic_table": 0.009697831992525607, + "tests/core/test_periodic_table.py::TestElement::test_radii": 0.0014834590256214142, + "tests/core/test_periodic_table.py::TestElement::test_row": 0.0012839180417358875, + "tests/core/test_periodic_table.py::TestElement::test_sort": 0.001303499040659517, + "tests/core/test_periodic_table.py::TestElement::test_term_symbols": 0.002003665955271572, + "tests/core/test_periodic_table.py::TestElement::test_valence": 0.0014555829693563282, + "tests/core/test_periodic_table.py::TestSpecies::test_attr": 0.0012557489681057632, + "tests/core/test_periodic_table.py::TestSpecies::test_cmp": 0.001242915983311832, + "tests/core/test_periodic_table.py::TestSpecies::test_deepcopy": 0.0012872919905930758, + "tests/core/test_periodic_table.py::TestSpecies::test_eq": 0.0014182509621605277, + "tests/core/test_periodic_table.py::TestSpecies::test_get_crystal_field_spin": 0.0016555010224692523, + "tests/core/test_periodic_table.py::TestSpecies::test_get_nmr_mom": 0.0015882920124568045, + "tests/core/test_periodic_table.py::TestSpecies::test_get_shannon_radius": 0.0014892080216668546, + "tests/core/test_periodic_table.py::TestSpecies::test_init": 0.0014829590218141675, + "tests/core/test_periodic_table.py::TestSpecies::test_ionic_radius": 0.001733499055262655, + "tests/core/test_periodic_table.py::TestSpecies::test_no_oxidation_state": 0.0012180840130895376, + "tests/core/test_periodic_table.py::TestSpecies::test_pickle": 0.0022111249854788184, + "tests/core/test_periodic_table.py::TestSpecies::test_sort": 0.0013953330926597118, + "tests/core/test_periodic_table.py::TestSpecies::test_stringify": 0.0013461660128086805, + "tests/core/test_periodic_table.py::TestSpecies::test_to_from_string": 0.001370625977870077, "tests/core/test_periodic_table.py::TestFunc::test_get_el_sp": 0.00031133199809119105, "tests/core/test_periodic_table.py::test_symbol_oxi_state_str[Ca+-Ca-1]": 0.00030433404026553035, "tests/core/test_periodic_table.py::test_symbol_oxi_state_str[Fe-Fe-None]": 0.00034308305475860834, @@ -1028,7 +1022,7 @@ "tests/core/test_structure.py::TestMolecule::test_relax_ase_mol": 0.00900712498696521, "tests/core/test_structure.py::TestMolecule::test_relax_ase_mol_return_traj": 0.006231041974388063, "tests/core/test_structure.py::TestMolecule::test_relax_gfnxtb": 0.00013875000877305865, - "tests/core/test_structure.py::TestMolecule::test_replace": 0.0018936669803224504, + "tests/core/test_structure.py::TestMolecule::test_replace_species": 0.0018936669803224504, "tests/core/test_structure.py::TestMolecule::test_rotate_sites": 0.0017342490027658641, "tests/core/test_structure.py::TestMolecule::test_substitute": 0.005597625044174492, "tests/core/test_structure.py::TestMolecule::test_as_from_dict": 0.0021212499705143273, @@ -1110,7 +1104,7 @@ "tests/core/test_surface.py::TestSlabGenerator::test_get_slabs": 0.6161378750111908, "tests/core/test_surface.py::TestSlabGenerator::test_get_tasker2_slabs": 0.0742877500015311, "tests/core/test_surface.py::TestSlabGenerator::test_move_to_other_side": 0.8720399169833399, - "tests/core/test_surface.py::TestSlabGenerator::test_nonstoichiometric_symmetrized_slab": 3.6920437510707416, + "tests/core/test_surface.py::TestSlabGenerator::test_non_stoichiometric_symmetrized_slab": 3.6920437510707416, "tests/core/test_surface.py::TestSlabGenerator::test_normal_search": 0.42782758397515863, "tests/core/test_surface.py::TestSlabGenerator::test_triclinic_TeI": 0.2443105829297565, "tests/core/test_tensors.py::TestSquareTensor::test_get_scaled": 0.0019302499713376164, @@ -1344,13 +1338,13 @@ "tests/entries/test_compatibility.py::TestMITAqueousCompatibility::test_aqueous_compat": 0.0009554990101605654, "tests/entries/test_compatibility.py::TestMITAqueousCompatibility::test_dont_error_on_weird_elements": 0.00039654201827943325, "tests/entries/test_compatibility.py::TestMITAqueousCompatibility::test_msonable": 0.000427792954724282, - "tests/entries/test_compatibility.py::TestMITAqueousCompatibility::test_potcar_doenst_match_structure": 0.0004772910033352673, + "tests/entries/test_compatibility.py::TestMITAqueousCompatibility::test_potcar_not_match_structure": 0.0004772910033352673, "tests/entries/test_compatibility.py::TestMITCompatibility::test_U_value": 0.015942417085170746, "tests/entries/test_compatibility.py::TestMITCompatibility::test_correction_value": 0.0007064170204102993, "tests/entries/test_compatibility.py::TestMITCompatibility::test_element_processing": 0.00044875004095956683, "tests/entries/test_compatibility.py::TestMITCompatibility::test_get_explanation_dict": 0.0006203330704011023, "tests/entries/test_compatibility.py::TestMITCompatibility::test_msonable": 0.0004487930564209819, - "tests/entries/test_compatibility.py::TestMITCompatibility::test_potcar_doenst_match_structure": 0.0004331250092945993, + "tests/entries/test_compatibility.py::TestMITCompatibility::test_potcar_not_match_structure": 0.0004331250092945993, "tests/entries/test_compatibility.py::TestMITCompatibility::test_potcar_spec_is_none": 0.00044504296965897083, "tests/entries/test_compatibility.py::TestMITCompatibility::test_process_entry": 0.000629501068033278, "tests/entries/test_compatibility.py::TestMITCompatibility::test_revert_to_symbols": 0.0005876249633729458, @@ -1430,7 +1424,7 @@ "tests/entries/test_computed_entries.py::TestGibbsComputedStructureEntry::test_gf_sisso": 0.012284540978726, "tests/entries/test_computed_entries.py::TestGibbsComputedStructureEntry::test_interpolation": 0.014316249988041818, "tests/entries/test_computed_entries.py::TestGibbsComputedStructureEntry::test_normalize": 0.048159833007957786, - "tests/entries/test_computed_entries.py::TestGibbsComputedStructureEntry::test_str": 0.012615332962013781, + "tests/entries/test_computed_entries.py::TestGibbsComputedStructureEntry::test_repr": 0.012615332962013781, "tests/entries/test_computed_entries.py::TestGibbsComputedStructureEntry::test_as_from_dict": 0.01511958398623392, "tests/entries/test_computed_entries.py::test_composition_energy_adjustment": 0.00030474894447252154, "tests/entries/test_computed_entries.py::test_constant_energy_adjustment": 0.00030654098372906446, @@ -1531,9 +1525,9 @@ "tests/io/abinit/test_abiobjects.py::TestRelaxation::test_base": 0.0019369579968042672, "tests/io/abinit/test_abiobjects.py::TestSmearing::test_base": 0.0023311239783652127, "tests/io/abinit/test_abiobjects.py::TestSpinMode::test_base": 0.002095124975312501, - "tests/io/abinit/test_inputs.py::AbinitInputTestCase::test_api": 0.018387624993920326, - "tests/io/abinit/test_inputs.py::AbinitInputTestCase::test_helper_functions": 0.0280475010513328, - "tests/io/abinit/test_inputs.py::AbinitInputTestCase::test_input_errors": 0.0036487909965217113, + "tests/io/abinit/test_inputs.py::TestAbinitInput::test_api": 0.018387624993920326, + "tests/io/abinit/test_inputs.py::TestAbinitInput::test_helper_functions": 0.0280475010513328, + "tests/io/abinit/test_inputs.py::TestAbinitInput::test_input_errors": 0.0036487909965217113, "tests/io/abinit/test_inputs.py::TestFactory::test_ebands_input": 0.059605583955999464, "tests/io/abinit/test_inputs.py::TestFactory::test_gs_input": 0.004122291051317006, "tests/io/abinit/test_inputs.py::TestFactory::test_ion_ioncell_relax_input": 0.003928000049199909, @@ -1541,10 +1535,10 @@ "tests/io/abinit/test_inputs.py::TestShiftMode::test_shiftmode": 0.0020915840286761522, "tests/io/abinit/test_netcdf.py::TestEtsfReader::test_read_si2": 0.001917207962833345, "tests/io/abinit/test_netcdf.py::TestAbinitHeader::test_api": 0.0019235840300098062, - "tests/io/abinit/test_pseudos.py::PseudoTestCase::test_nc_pseudos": 0.008074542041867971, - "tests/io/abinit/test_pseudos.py::PseudoTestCase::test_oncvpsp_pseudo_fr": 0.002990167122334242, - "tests/io/abinit/test_pseudos.py::PseudoTestCase::test_oncvpsp_pseudo_sr": 0.0030953330569900572, - "tests/io/abinit/test_pseudos.py::PseudoTestCase::test_pawxml_pseudos": 0.025607499002944678, + "tests/io/abinit/test_pseudos.py::TestPseudo::test_nc_pseudos": 0.008074542041867971, + "tests/io/abinit/test_pseudos.py::TestPseudo::test_oncvpsp_pseudo_fr": 0.002990167122334242, + "tests/io/abinit/test_pseudos.py::TestPseudo::test_oncvpsp_pseudo_sr": 0.0030953330569900572, + "tests/io/abinit/test_pseudos.py::TestPseudo::test_pawxml_pseudos": 0.025607499002944678, "tests/io/abinit/test_pseudos.py::TestPseudoTable::test_methods": 0.009253333962988108, "tests/io/cp2k/test_inputs.py::TestBasisAndPotential::test_basis": 0.002900875057093799, "tests/io/cp2k/test_inputs.py::TestBasisAndPotential::test_basis_info": 0.002058790996670723, @@ -1558,7 +1552,7 @@ "tests/io/cp2k/test_inputs.py::TestInput::test_mongo": 0.005608417035546154, "tests/io/cp2k/test_inputs.py::TestInput::test_odd_file": 0.011910750006791204, "tests/io/cp2k/test_inputs.py::TestInput::test_preprocessor": 0.005230542039498687, - "tests/io/cp2k/test_inputs.py::TestInput::test_sectionlist": 0.005215583078097552, + "tests/io/cp2k/test_inputs.py::TestInput::test_section_list": 0.005215583078097552, "tests/io/cp2k/test_outputs.py::TestSet::test_band": 0.031144334003329277, "tests/io/cp2k/test_outputs.py::TestSet::test_chi": 0.029846749908756465, "tests/io/cp2k/test_outputs.py::TestSet::test_dos": 0.030147832992952317, @@ -2289,7 +2283,7 @@ "tests/optimization/test_linear_assignment.py::TestLinearAssignment::test_boolean_inputs": 0.00030920800054445863, "tests/optimization/test_linear_assignment.py::TestLinearAssignment::test_rectangular": 0.00033775094198063016, "tests/optimization/test_linear_assignment.py::TestLinearAssignment::test_small_range": 0.0002947920002043247, - "tests/optimization/test_neighbors.py::NeighborsTestCase::test_points_in_spheres": 0.0025064569781534374, + "tests/optimization/test_neighbors.py::TestNeighbors::test_points_in_spheres": 0.0025064569781534374, "tests/phonon/test_bandstructure.py::TestPhononBandStructureSymmLine::test_basic": 0.016336957982275635, "tests/phonon/test_bandstructure.py::TestPhononBandStructureSymmLine::test_branches": 0.014964499976485968, "tests/phonon/test_bandstructure.py::TestPhononBandStructureSymmLine::test_dict_methods": 0.0630996249965392, @@ -2515,11 +2509,11 @@ "tests/util/test_graph_hashing.py::test_graph_hash": 0.00039154296973720193, "tests/util/test_graph_hashing.py::test_subgraph_hashes": 0.0003051679814234376, "tests/util/test_io_utils.py::TestFunc::test_micro_pyawk": 0.08383158198557794, - "tests/util/test_num_utils.py::FuncTestCase::test_abs_cap": 0.0003750419127754867, - "tests/util/test_num_utils.py::FuncTestCase::test_min_max_indexes": 0.0002843319671228528, - "tests/util/test_num_utils.py::FuncTestCase::test_round": 0.0004711660440079868, - "tests/util/test_plotting.py::FuncTestCase::test_plot_periodic_heatmap": 0.0028367919730953872, - "tests/util/test_plotting.py::FuncTestCase::test_van_arkel_triangle": 0.0405823330511339, + "tests/util/test_num_utils.py::TestFunc::test_abs_cap": 0.0003750419127754867, + "tests/util/test_num_utils.py::TestFunc::test_min_max_indexes": 0.0002843319671228528, + "tests/util/test_num_utils.py::TestFunc::test_round": 0.0004711660440079868, + "tests/util/test_plotting.py::TestFunc::test_plot_periodic_heatmap": 0.0028367919730953872, + "tests/util/test_plotting.py::TestFunc::test_van_arkel_triangle": 0.0405823330511339, "tests/util/test_provenance.py::StructureNLCase::test_authors": 0.001689917000476271, "tests/util/test_provenance.py::StructureNLCase::test_data": 0.0003868750645779073, "tests/util/test_provenance.py::StructureNLCase::test_eq": 0.0011642919853329659, diff --git a/tests/files/classical_md_mols/CCO.npy b/tests/files/classical_md_mols/CCO.npy new file mode 100644 index 00000000000..4724234e784 Binary files /dev/null and b/tests/files/classical_md_mols/CCO.npy differ diff --git a/tests/files/classical_md_mols/CCO.xyz b/tests/files/classical_md_mols/CCO.xyz new file mode 100644 index 00000000000..403720cd49b --- /dev/null +++ b/tests/files/classical_md_mols/CCO.xyz @@ -0,0 +1,11 @@ +9 + +C 1.000000 1.000000 0.000000 +C -0.515000 1.000000 0.000000 +O -0.999000 1.000000 1.335000 +H 1.390000 1.001000 -1.022000 +H 1.386000 0.119000 0.523000 +H 1.385000 1.880000 0.526000 +H -0.907000 0.118000 -0.516000 +H -0.897000 1.894000 -0.501000 +H -0.661000 0.198000 1.768000 diff --git a/tests/files/classical_md_mols/FEC-r.xyz b/tests/files/classical_md_mols/FEC-r.xyz new file mode 100644 index 00000000000..f94a8923eef --- /dev/null +++ b/tests/files/classical_md_mols/FEC-r.xyz @@ -0,0 +1,12 @@ +10 + +O 1.000000 1.000000 0.000000 +C -0.219000 1.000000 0.000000 +O -0.984000 1.000000 1.133000 +C -2.322000 0.780000 0.720000 +C -2.300000 1.205000 -0.711000 +H -3.034000 0.686000 -1.332000 +F -2.507000 2.542000 -0.809000 +O -0.983000 0.948000 -1.128000 +H -3.008000 1.375000 1.328000 +H -2.544000 -0.285000 0.838000 diff --git a/tests/files/classical_md_mols/FEC-s.xyz b/tests/files/classical_md_mols/FEC-s.xyz new file mode 100644 index 00000000000..af492afc655 --- /dev/null +++ b/tests/files/classical_md_mols/FEC-s.xyz @@ -0,0 +1,12 @@ +10 + +O 1.000000 1.000000 0.000000 +C -0.219000 1.000000 0.000000 +O -0.981000 1.000000 1.133000 +C -2.323000 0.828000 0.723000 +C -2.305000 1.254000 -0.707000 +H -2.567000 2.305000 -0.862000 +F -3.125000 0.469000 -1.445000 +O -0.983000 1.001000 -1.127000 +H -2.991000 1.447000 1.328000 +H -2.610000 -0.222000 0.848000 diff --git a/tests/files/classical_md_mols/FEC.npy b/tests/files/classical_md_mols/FEC.npy new file mode 100644 index 00000000000..016912f7d64 Binary files /dev/null and b/tests/files/classical_md_mols/FEC.npy differ diff --git a/tests/analysis/gb/__init__.py b/tests/files/classical_md_mols/FEC_bad.npy similarity index 100% rename from tests/analysis/gb/__init__.py rename to tests/files/classical_md_mols/FEC_bad.npy diff --git a/tests/files/classical_md_mols/Li.npy b/tests/files/classical_md_mols/Li.npy new file mode 100644 index 00000000000..b87c6ecebb2 Binary files /dev/null and b/tests/files/classical_md_mols/Li.npy differ diff --git a/tests/files/classical_md_mols/Li.xyz b/tests/files/classical_md_mols/Li.xyz new file mode 100644 index 00000000000..7f08d77c84a --- /dev/null +++ b/tests/files/classical_md_mols/Li.xyz @@ -0,0 +1,3 @@ +1 + +Li 0.0 0.0 0.0 diff --git a/tests/files/classical_md_mols/PF6.npy b/tests/files/classical_md_mols/PF6.npy new file mode 100644 index 00000000000..1a4c723b70c Binary files /dev/null and b/tests/files/classical_md_mols/PF6.npy differ diff --git a/tests/files/classical_md_mols/PF6.xyz b/tests/files/classical_md_mols/PF6.xyz new file mode 100644 index 00000000000..3f3df87fa83 --- /dev/null +++ b/tests/files/classical_md_mols/PF6.xyz @@ -0,0 +1,9 @@ +7 + +P 0.0 0.0 0.0 +F 1.6 0.0 0.0 +F -1.6 0.0 0.0 +F 0.0 1.6 0.0 +F 0.0 -1.6 0.0 +F 0.0 0.0 1.6 +F 0.0 0.0 -1.6 diff --git a/tests/files/cohp/lobsterin.1 b/tests/files/cohp/lobsterin.1 index 046868f0691..8b92bf4613b 100644 --- a/tests/files/cohp/lobsterin.1 +++ b/tests/files/cohp/lobsterin.1 @@ -4,7 +4,7 @@ basisSet pbeVaspFit2015 gaussianSmearingWidth 0.1 basisfunctions Fe 3d 4p 4s basisfunctions Co 3d 4p 4s -skipdos +skipDOS skipcohp skipcoop skipPopulationAnalysis diff --git a/tests/files/cohp/lobsterin.3 b/tests/files/cohp/lobsterin.3 index e0ca783e1d8..e39e8e9230e 100644 --- a/tests/files/cohp/lobsterin.3 +++ b/tests/files/cohp/lobsterin.3 @@ -3,4 +3,6 @@ COHPendEnergy 5.0 basisSet pbeVaspFit2015 gaussianSmearingWidth 0.1 basisfunctions Fe 3d 4p 4s -basisfunctions Co 3d 4p 4s \ No newline at end of file +basisfunctions Co 3d 4p 4s +skipcoBI +SKIPDOS \ No newline at end of file diff --git a/tests/files/surface_tests/Au_slab_init.cif b/tests/files/surfaces/Au_slab_init.cif similarity index 100% rename from tests/files/surface_tests/Au_slab_init.cif rename to tests/files/surfaces/Au_slab_init.cif diff --git a/tests/files/surface_tests/CONTCAR.relax1.gz b/tests/files/surfaces/CONTCAR.relax1.gz similarity index 100% rename from tests/files/surface_tests/CONTCAR.relax1.gz rename to tests/files/surfaces/CONTCAR.relax1.gz diff --git a/tests/files/surface_tests/Cu_entries.txt b/tests/files/surfaces/Cu_entries.txt similarity index 100% rename from tests/files/surface_tests/Cu_entries.txt rename to tests/files/surfaces/Cu_entries.txt diff --git a/tests/files/surface_tests/Cu_slab_fin.cif b/tests/files/surfaces/Cu_slab_fin.cif similarity index 100% rename from tests/files/surface_tests/Cu_slab_fin.cif rename to tests/files/surfaces/Cu_slab_fin.cif diff --git a/tests/files/surface_tests/Cu_slab_init.cif b/tests/files/surfaces/Cu_slab_init.cif similarity index 100% rename from tests/files/surface_tests/Cu_slab_init.cif rename to tests/files/surfaces/Cu_slab_init.cif diff --git a/tests/files/surface_tests/LOCPOT.gz b/tests/files/surfaces/LOCPOT.gz similarity index 100% rename from tests/files/surface_tests/LOCPOT.gz rename to tests/files/surfaces/LOCPOT.gz diff --git a/tests/files/surface_tests/La_fcc_entries.txt b/tests/files/surfaces/La_fcc_entries.txt similarity index 100% rename from tests/files/surface_tests/La_fcc_entries.txt rename to tests/files/surfaces/La_fcc_entries.txt diff --git a/tests/files/surface_tests/La_hcp_entries.txt b/tests/files/surfaces/La_hcp_entries.txt similarity index 100% rename from tests/files/surface_tests/La_hcp_entries.txt rename to tests/files/surfaces/La_hcp_entries.txt diff --git a/tests/files/surface_tests/MgO_slab_entries.txt b/tests/files/surfaces/MgO_slab_entries.txt similarity index 100% rename from tests/files/surface_tests/MgO_slab_entries.txt rename to tests/files/surfaces/MgO_slab_entries.txt diff --git a/tests/files/surface_tests/OUTCAR.relax1.gz b/tests/files/surfaces/OUTCAR.relax1.gz similarity index 100% rename from tests/files/surface_tests/OUTCAR.relax1.gz rename to tests/files/surfaces/OUTCAR.relax1.gz diff --git a/tests/files/surface_tests/ZnO-wz.cif b/tests/files/surfaces/ZnO-wz.cif similarity index 100% rename from tests/files/surface_tests/ZnO-wz.cif rename to tests/files/surfaces/ZnO-wz.cif diff --git a/tests/files/surface_tests/cs_entries_o_ads.json b/tests/files/surfaces/cs_entries_o_ads.json similarity index 100% rename from tests/files/surface_tests/cs_entries_o_ads.json rename to tests/files/surfaces/cs_entries_o_ads.json diff --git a/tests/files/surface_tests/cs_entries_slabs.json b/tests/files/surfaces/cs_entries_slabs.json similarity index 100% rename from tests/files/surface_tests/cs_entries_slabs.json rename to tests/files/surfaces/cs_entries_slabs.json diff --git a/tests/files/surface_tests/icsd_LiCoO2.cif b/tests/files/surfaces/icsd_LiCoO2.cif similarity index 100% rename from tests/files/surface_tests/icsd_LiCoO2.cif rename to tests/files/surfaces/icsd_LiCoO2.cif diff --git a/tests/files/surface_tests/icsd_TeI.cif b/tests/files/surfaces/icsd_TeI.cif similarity index 100% rename from tests/files/surface_tests/icsd_TeI.cif rename to tests/files/surfaces/icsd_TeI.cif diff --git a/tests/files/surface_tests/icsd_batio3.cif b/tests/files/surfaces/icsd_batio3.cif similarity index 100% rename from tests/files/surface_tests/icsd_batio3.cif rename to tests/files/surfaces/icsd_batio3.cif diff --git a/tests/files/surface_tests/isolated_O_entry.txt b/tests/files/surfaces/isolated_O_entry.txt similarity index 100% rename from tests/files/surface_tests/isolated_O_entry.txt rename to tests/files/surfaces/isolated_O_entry.txt diff --git a/tests/files/surface_tests/reconstructions/Fe_bcc_100_zigzag_rt2xrt2.cif b/tests/files/surfaces/reconstructions/Fe_bcc_100_zigzag_rt2xrt2.cif similarity index 100% rename from tests/files/surface_tests/reconstructions/Fe_bcc_100_zigzag_rt2xrt2.cif rename to tests/files/surfaces/reconstructions/Fe_bcc_100_zigzag_rt2xrt2.cif diff --git a/tests/files/surface_tests/reconstructions/Ni_fcc_100_missing_row_1x1.cif b/tests/files/surfaces/reconstructions/Ni_fcc_100_missing_row_1x1.cif similarity index 100% rename from tests/files/surface_tests/reconstructions/Ni_fcc_100_missing_row_1x1.cif rename to tests/files/surfaces/reconstructions/Ni_fcc_100_missing_row_1x1.cif diff --git a/tests/files/surface_tests/reconstructions/Ni_fcc_110_missing_row_1x2.cif b/tests/files/surfaces/reconstructions/Ni_fcc_110_missing_row_1x2.cif similarity index 100% rename from tests/files/surface_tests/reconstructions/Ni_fcc_110_missing_row_1x2.cif rename to tests/files/surfaces/reconstructions/Ni_fcc_110_missing_row_1x2.cif diff --git a/tests/files/surface_tests/reconstructions/Ni_fcc_111_adatom_bridge_1x1.cif b/tests/files/surfaces/reconstructions/Ni_fcc_111_adatom_bridge_1x1.cif similarity index 100% rename from tests/files/surface_tests/reconstructions/Ni_fcc_111_adatom_bridge_1x1.cif rename to tests/files/surfaces/reconstructions/Ni_fcc_111_adatom_bridge_1x1.cif diff --git a/tests/files/surface_tests/reconstructions/Ni_fcc_111_adatom_ft_1x1.cif b/tests/files/surfaces/reconstructions/Ni_fcc_111_adatom_ft_1x1.cif similarity index 100% rename from tests/files/surface_tests/reconstructions/Ni_fcc_111_adatom_ft_1x1.cif rename to tests/files/surfaces/reconstructions/Ni_fcc_111_adatom_ft_1x1.cif diff --git a/tests/files/surface_tests/reconstructions/Ni_fcc_111_adatom_h_1x1.cif b/tests/files/surfaces/reconstructions/Ni_fcc_111_adatom_h_1x1.cif similarity index 100% rename from tests/files/surface_tests/reconstructions/Ni_fcc_111_adatom_h_1x1.cif rename to tests/files/surfaces/reconstructions/Ni_fcc_111_adatom_h_1x1.cif diff --git a/tests/files/surface_tests/reconstructions/Ni_fcc_111_adatom_ht_1x1.cif b/tests/files/surfaces/reconstructions/Ni_fcc_111_adatom_ht_1x1.cif similarity index 100% rename from tests/files/surface_tests/reconstructions/Ni_fcc_111_adatom_ht_1x1.cif rename to tests/files/surfaces/reconstructions/Ni_fcc_111_adatom_ht_1x1.cif diff --git a/tests/files/surface_tests/reconstructions/Ni_fcc_111_adatom_t_1x1.cif b/tests/files/surfaces/reconstructions/Ni_fcc_111_adatom_t_1x1.cif similarity index 100% rename from tests/files/surface_tests/reconstructions/Ni_fcc_111_adatom_t_1x1.cif rename to tests/files/surfaces/reconstructions/Ni_fcc_111_adatom_t_1x1.cif diff --git a/tests/files/surface_tests/reconstructions/Ni_fcc_111_missing_row_1x2.cif b/tests/files/surfaces/reconstructions/Ni_fcc_111_missing_row_1x2.cif similarity index 100% rename from tests/files/surface_tests/reconstructions/Ni_fcc_111_missing_row_1x2.cif rename to tests/files/surfaces/reconstructions/Ni_fcc_111_missing_row_1x2.cif diff --git a/tests/files/surface_tests/reconstructions/Si_diamond_100_2x1.cif b/tests/files/surfaces/reconstructions/Si_diamond_100_2x1.cif similarity index 100% rename from tests/files/surface_tests/reconstructions/Si_diamond_100_2x1.cif rename to tests/files/surfaces/reconstructions/Si_diamond_100_2x1.cif diff --git a/tests/files/surface_tests/reconstructions/Si_diamond_110_1x1.cif b/tests/files/surfaces/reconstructions/Si_diamond_110_1x1.cif similarity index 100% rename from tests/files/surface_tests/reconstructions/Si_diamond_110_1x1.cif rename to tests/files/surfaces/reconstructions/Si_diamond_110_1x1.cif diff --git a/tests/files/surface_tests/reconstructions/Si_diamond_111_1x2.cif b/tests/files/surfaces/reconstructions/Si_diamond_111_1x2.cif similarity index 100% rename from tests/files/surface_tests/reconstructions/Si_diamond_111_1x2.cif rename to tests/files/surfaces/reconstructions/Si_diamond_111_1x2.cif diff --git a/tests/files/surface_tests/ucell_entries.txt b/tests/files/surfaces/ucell_entries.txt similarity index 100% rename from tests/files/surface_tests/ucell_entries.txt rename to tests/files/surfaces/ucell_entries.txt diff --git a/tests/io/abinit/test_inputs.py b/tests/io/abinit/test_inputs.py index c4301c76435..2b281da3bff 100644 --- a/tests/io/abinit/test_inputs.py +++ b/tests/io/abinit/test_inputs.py @@ -33,7 +33,7 @@ def abiref_files(*filenames): return [f"{TEST_DIR}/{file}" for file in filenames] -class AbinitInputTestCase(PymatgenTest): +class TestAbinitInput(PymatgenTest): """Unit tests for BasicAbinitInput.""" def test_api(self): @@ -223,7 +223,7 @@ def test_api(self): split = multi.split_datasets() assert len(split) == 2 - assert all(split[i] == multi[i] for i in range(multi.ndtset)) + assert all(split[idx] == multi[idx] for idx in range(multi.ndtset)) assert multi.to_str(with_pseudos=False) tmpdir = tempfile.mkdtemp() diff --git a/tests/io/abinit/test_pseudos.py b/tests/io/abinit/test_pseudos.py index e2eec094248..b3a94f57a67 100644 --- a/tests/io/abinit/test_pseudos.py +++ b/tests/io/abinit/test_pseudos.py @@ -12,7 +12,7 @@ TEST_DIR = f"{TEST_FILES_DIR}/abinit" -class PseudoTestCase(PymatgenTest): +class TestPseudo(PymatgenTest): def setUp(self): nc_pseudo_fnames = collections.defaultdict(list) nc_pseudo_fnames["Si"] = [f"{TEST_DIR}/{file}" for file in ("14si.pspnc", "14si.4.hgh", "14-Si.LDA.fhi")] diff --git a/tests/io/aims/test_aims_inputs.py b/tests/io/aims/test_aims_inputs.py index 5fecda3dcc1..d0b0d5a30c4 100644 --- a/tests/io/aims/test_aims_inputs.py +++ b/tests/io/aims/test_aims_inputs.py @@ -24,8 +24,8 @@ def test_read_write_si_in(tmp_path: Path): si = AimsGeometryIn.from_file(TEST_DIR / "geometry.in.si.gz") - in_lattice = np.array([[0.0, 2.715, 2.716], [2.717, 0.0, 2.718], [2.719, 2.720, 0.0]]) - in_coords = np.array([[0.0, 0.0, 0.0], [0.25, 0.24, 0.26]]) + in_lattice = np.array([[0, 2.715, 2.716], [2.717, 0, 2.718], [2.719, 2.720, 0]]) + in_coords = np.array([[0, 0, 0], [0.25, 0.24, 0.26]]) assert all(sp.symbol == "Si" for sp in si.structure.species) assert_allclose(si.structure.lattice.matrix, in_lattice) @@ -50,9 +50,9 @@ def test_read_h2o_in(tmp_path: Path): h2o = AimsGeometryIn.from_file(TEST_DIR / "geometry.in.h2o.gz") in_coords = [ - [0.0, 0.0, 0.119262], - [0.0, 0.763239, -0.477047], - [0.0, -0.763239, -0.477047], + [0, 0, 0.119262], + [0, 0.763239, -0.477047], + [0, -0.763239, -0.477047], ] assert all(sp.symbol == symb for sp, symb in zip(h2o.structure.species, ["O", "H", "H"])) @@ -107,12 +107,12 @@ def test_aims_cube(): AimsCube(type=ALLOWED_AIMS_CUBE_TYPES[0], origin=[0]) with pytest.raises(ValueError, match="Only three cube edges can be passed"): - AimsCube(type=ALLOWED_AIMS_CUBE_TYPES[0], edges=[[0.0, 0.0, 0.1]]) + AimsCube(type=ALLOWED_AIMS_CUBE_TYPES[0], edges=[[0, 0, 0.1]]) with pytest.raises(ValueError, match="Each cube edge must have 3 components"): AimsCube( type=ALLOWED_AIMS_CUBE_TYPES[0], - edges=[[0.0, 0.0, 0.1], [0.1, 0.0, 0.0], [0.1, 0.0]], + edges=[[0, 0, 0.1], [0.1, 0, 0], [0.1, 0]], ) with pytest.raises(ValueError, match="elf_type is only used when the cube type is elf. Otherwise it must be None"): @@ -124,7 +124,7 @@ def test_aims_cube(): test_cube = AimsCube( type="elf", origin=[0, 0, 0], - edges=[[0.01, 0, 0], [0.0, 0.01, 0], [0.0, 0, 0.01]], + edges=[[0.01, 0, 0], [0, 0.01, 0], [0, 0, 0.01]], points=[100, 100, 100], spin_state=1, kpoint=1, diff --git a/tests/io/cp2k/test_inputs.py b/tests/io/cp2k/test_inputs.py index af1e998fb55..5303c7d0285 100644 --- a/tests/io/cp2k/test_inputs.py +++ b/tests/io/cp2k/test_inputs.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np +import pytest from numpy.testing import assert_allclose, assert_array_equal from pytest import approx @@ -10,6 +11,7 @@ BasisInfo, Coord, Cp2kInput, + DataFile, GaussianTypeOrbitalBasisSet, GthPotential, Keyword, @@ -22,7 +24,7 @@ ) from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -Si_structure = Structure( +si_struct = Structure( lattice=[[0, 2.734364, 2.734364], [2.734364, 0, 2.734364], [2.734364, 2.734364, 0]], species=["Si", "Si"], coords=[[0, 0, 0], [0.25, 0.25, 0.25]], @@ -34,9 +36,9 @@ coords=[[-1, -1, -1]], ) -molecule = Molecule(species=["C", "H"], coords=[[0, 0, 0], [1, 1, 1]]) +ch_mol = Molecule(species=["C", "H"], coords=[[0, 0, 0], [1, 1, 1]]) -basis = """ +BASIS_FILE_STR = """ H SZV-MOLOPT-GTH SZV-MOLOPT-GTH-q1 1 2 0 0 7 1 @@ -48,59 +50,65 @@ 0.066918004004 0.037148121400 0.021708243634 -0.001125195500 """ -all_hydrogen = """ +ALL_HYDROGEN_STR = """ H ALLELECTRON ALL 1 0 0 0.20000000 0 """ -pot_hydrogen = """ +POT_HYDROGEN_STR = """ H GTH-PBE-q1 GTH-PBE 1 0.20000000 2 -4.17890044 0.72446331 0 """ +CP2K_INPUT_STR = """ +&GLOBAL + RUN_TYPE ENERGY + PROJECT_NAME CP2K ! default name +&END +""" class TestBasisAndPotential(PymatgenTest): def test_basis_info(self): # Ensure basis metadata can be read from string - b = BasisInfo.from_str("cc-pc-DZVP-MOLOPT-q1-SCAN") - assert b.valence == 2 - assert b.molopt - assert b.electrons == 1 - assert b.polarization == 1 - assert b.cc - assert b.pc - assert b.xc == "SCAN" - - # Ensure one-way softmatching works - b2 = BasisInfo.from_str("cc-pc-DZVP-MOLOPT-q1") - assert b2.softmatch(b) - assert not b.softmatch(b2) - - b3 = BasisInfo.from_str("cpFIT3") - assert b3.valence == 3 - assert b3.polarization == 1 - assert b3.contracted, True + basis_info = BasisInfo.from_str("cc-pc-DZVP-MOLOPT-q1-SCAN") + assert basis_info.valence == 2 + assert basis_info.molopt + assert basis_info.electrons == 1 + assert basis_info.polarization == 1 + assert basis_info.cc + assert basis_info.pc + assert basis_info.xc == "SCAN" + + # Ensure one-way soft-matching works + basis_info2 = BasisInfo.from_str("cc-pc-DZVP-MOLOPT-q1") + assert basis_info2.softmatch(basis_info) + assert not basis_info.softmatch(basis_info2) + + basis_info3 = BasisInfo.from_str("cpFIT3") + assert basis_info3.valence == 3 + assert basis_info3.polarization == 1 + assert basis_info3.contracted, True def test_potential_info(self): # Ensure potential metadata can be read from string - p = PotentialInfo.from_str("GTH-PBE-q1-NLCC") - assert p.potential_type == "GTH" - assert p.xc == "PBE" - assert p.nlcc + pot_info = PotentialInfo.from_str("GTH-PBE-q1-NLCC") + assert pot_info.potential_type == "GTH" + assert pot_info.xc == "PBE" + assert pot_info.nlcc - # Ensure one-way softmatching works - p2 = PotentialInfo.from_str("GTH-q1-NLCC") - assert p2.softmatch(p) - assert not p.softmatch(p2) + # Ensure one-way soft-matching works + pot_info2 = PotentialInfo.from_str("GTH-q1-NLCC") + assert pot_info2.softmatch(pot_info) + assert not pot_info.softmatch(pot_info2) def test_basis(self): # Ensure cp2k formatted string can be read for data correctly - mol_opt = GaussianTypeOrbitalBasisSet.from_str(basis) + mol_opt = GaussianTypeOrbitalBasisSet.from_str(BASIS_FILE_STR) assert mol_opt.nexp == [7] # Basis file can read from strings - bf = BasisFile.from_str(basis) + bf = BasisFile.from_str(BASIS_FILE_STR) for obj in [mol_opt, bf.objects[0]]: assert_allclose( obj.exponents[0], @@ -125,17 +133,22 @@ def test_basis(self): def test_potentials(self): # Ensure cp2k formatted string can be read for data correctly - h_all_elec = GthPotential.from_str(all_hydrogen) + h_all_elec = GthPotential.from_str(ALL_HYDROGEN_STR) assert h_all_elec.potential == "All Electron" - pot = GthPotential.from_str(pot_hydrogen) + pot = GthPotential.from_str(POT_HYDROGEN_STR) assert pot.potential == "Pseudopotential" assert pot.r_loc == approx(0.2) assert pot.nexp_ppl == approx(2) assert_allclose(pot.c_exp_ppl, [-4.17890044, 0.72446331]) # Basis file can read from strings - pf = PotentialFile.from_str(pot_hydrogen) - assert pf.objects[0] == pot + pot_file = PotentialFile.from_str(POT_HYDROGEN_STR) + assert pot_file.objects[0] == pot + + pot_file_path = self.tmp_path / "potential-file" + pot_file_path.write_text(POT_HYDROGEN_STR) + pot_from_file = PotentialFile.from_file(pot_file_path) + assert pot_file != pot_from_file # unequal because pot_from_file has filename != None # Ensure keyword can be properly generated kw = pot.get_keyword() @@ -149,27 +162,21 @@ def setUp(self): self.ci = Cp2kInput.from_file(f"{TEST_FILES_DIR}/cp2k/cp2k.inp") def test_basic_sections(self): - s = """ - &GLOBAL - RUN_TYPE ENERGY - PROJECT_NAME CP2K ! default name - &END - """ - ci = Cp2kInput.from_str(s) - assert ci["GLOBAL"]["RUN_TYPE"] == Keyword("RUN_TYPE", "energy") - assert ci["GLOBAL"]["PROJECT_NAME"].description == "default name" - self.assert_msonable(ci) - - def test_sectionlist(self): - s1 = Section("TEST") - sl = SectionList(sections=[s1, s1]) - for s in sl: + cp2k_input = Cp2kInput.from_str(CP2K_INPUT_STR) + assert cp2k_input["GLOBAL"]["RUN_TYPE"] == Keyword("RUN_TYPE", "energy") + assert cp2k_input["GLOBAL"]["PROJECT_NAME"].description == "default name" + self.assert_msonable(cp2k_input) + + def test_section_list(self): + sec1 = Section("TEST") + sec_list = SectionList(sections=[sec1, sec1]) + for s in sec_list: assert isinstance(s, Section) - assert sl[0].name == "TEST" - assert sl[1].name == "TEST" - assert len(sl) == 2 - sl += s1 - assert len(sl) == 3 + assert sec_list[0].name == "TEST" + assert sec_list[1].name == "TEST" + assert len(sec_list) == 2 + sec_list += sec1 + assert len(sec_list) == 3 def test_basic_keywords(self): kwd = Keyword("TEST1", 1, 2) @@ -181,14 +188,14 @@ def test_basic_keywords(self): assert "[Ha]" in kwd.get_str() def test_coords(self): - for struct in [nonsense_struct, Si_structure, molecule]: + for struct in [nonsense_struct, si_struct, ch_mol]: coords = Coord(struct) - for c in coords.keywords.values(): - assert isinstance(c, (Keyword, KeywordList)) + for val in coords.keywords.values(): + assert isinstance(val, (Keyword, KeywordList)) def test_kind(self): - for s in [nonsense_struct, Si_structure, molecule]: - for spec in s.species: + for struct in [nonsense_struct, si_struct, ch_mol]: + for spec in struct.species: assert spec == Kind(spec).specie def test_ci_file(self): @@ -205,20 +212,20 @@ def test_ci_file(self): def test_odd_file(self): scramble = "" - for s in self.ci.get_str(): + for string in self.ci.get_str(): if np.random.rand(1) > 0.5: - if s == "\t": + if string == "\t": scramble += " " - elif s == " ": + elif string == " ": scramble += " " - elif s in ("&", "\n"): - scramble += s - elif s.isalpha(): - scramble += s.lower() + elif string in ("&", "\n"): + scramble += string + elif string.isalpha(): + scramble += string.lower() else: - scramble += s + scramble += string else: - scramble += s + scramble += string # Can you initialize from jumbled input # should be case insensitive and ignore # excessive white space or tabs @@ -236,19 +243,22 @@ def test_preprocessor(self): assert self.ci["FORCE_EVAL"]["DFT"]["SCF"]["MAX_SCF"] == Keyword("MAX_SCF", 1) def test_mongo(self): - s = """ - &GLOBAL - RUN_TYPE ENERGY - PROJECT_NAME CP2K ! default name - &END - """ - s = Cp2kInput.from_str(s) - s.inc({"GLOBAL": {"TEST": 1}}) - assert s["global"]["test"] == Keyword("TEST", 1) - - s.unset({"GLOBAL": "RUN_TYPE"}) - assert "RUN_TYPE" not in s["global"].keywords - - s.set({"GLOBAL": {"SUBSEC": {"TEST2": 2}, "SUBSEC2": {"Test2": 1}}}) - assert s.check("global/SUBSEC") - assert s.check("global/subsec2") + cp2k_input = Cp2kInput.from_str(CP2K_INPUT_STR) + cp2k_input.inc({"GLOBAL": {"TEST": 1}}) + assert cp2k_input["global"]["test"] == Keyword("TEST", 1) + + cp2k_input.unset({"GLOBAL": "RUN_TYPE"}) + assert "RUN_TYPE" not in cp2k_input["global"].keywords + + cp2k_input.set({"GLOBAL": {"SUBSEC": {"TEST2": 2}, "SUBSEC2": {"Test2": 1}}}) + assert cp2k_input.check("global/SUBSEC") + assert cp2k_input.check("global/subsec2") + + +class TestDataFile(PymatgenTest): + def test_data_file(self): + # make temp file with BASIS_FILE_STR + data_file = self.tmp_path / "data-file" + data_file.write_text(BASIS_FILE_STR) + with pytest.raises(NotImplementedError): + DataFile.from_file(data_file) diff --git a/tests/io/cp2k/test_outputs.py b/tests/io/cp2k/test_outputs.py index 4e56fb5ebfe..4efcb22597a 100644 --- a/tests/io/cp2k/test_outputs.py +++ b/tests/io/cp2k/test_outputs.py @@ -1,6 +1,6 @@ from __future__ import annotations -import unittest +from unittest import TestCase import numpy as np from numpy.testing import assert_allclose @@ -10,7 +10,7 @@ from pymatgen.util.testing import TEST_FILES_DIR -class TestSet(unittest.TestCase): +class TestSet(TestCase): def setUp(self): self.out = Cp2kOutput(f"{TEST_FILES_DIR}/cp2k/cp2k.out", auto_load=True) diff --git a/tests/io/feff/test_inputs.py b/tests/io/feff/test_inputs.py index c53ba276cdf..677d62cb3a5 100644 --- a/tests/io/feff/test_inputs.py +++ b/tests/io/feff/test_inputs.py @@ -1,7 +1,7 @@ from __future__ import annotations import os -import unittest +from unittest import TestCase from numpy.testing import assert_allclose from pytest import approx @@ -27,7 +27,7 @@ * 4 O 0.333333 0.666667 0.378675""" -class TestHeader(unittest.TestCase): +class TestHeader(TestCase): def test_init(self): filepath = f"{FEFF_TEST_DIR}/HEADER" header = Header.header_string_from_file(filepath) @@ -43,10 +43,9 @@ def test_from_str(self): def test_get_str(self): cif_file = f"{TEST_FILES_DIR}/CoO19128.cif" - h = Header.from_cif_file(cif_file) - head = str(h) + header = Header.from_cif_file(cif_file) assert ( - head.splitlines()[3].split()[-1] == header_string.splitlines()[3].split()[-1] + str(header).splitlines()[3].split()[-1] == header_string.splitlines()[3].split()[-1] ), "Failed to generate HEADER from structure" def test_as_dict_and_from_dict(self): @@ -57,7 +56,7 @@ def test_as_dict_and_from_dict(self): assert str(header) == str(header2), "Header failed to and from dict test" -class TestFeffAtoms(unittest.TestCase): +class TestFeffAtoms(TestCase): @classmethod def setUpClass(cls): cls.structure = Structure.from_file(f"{TEST_FILES_DIR}/CoO19128.cif") @@ -144,7 +143,7 @@ def test_atom_num(self): assert atoms.formula == "Pt37" -class TestFeffTags(unittest.TestCase): +class TestFeffTags(TestCase): def test_init(self): filepath = f"{FEFF_TEST_DIR}/PARAMETERS" parameters = Tags.from_file(filepath) @@ -209,7 +208,7 @@ def test_eels_tags(self): assert dict(tags_2) == ans_1 -class TestFeffPot(unittest.TestCase): +class TestFeffPot(TestCase): def test_init(self): filepath = f"{FEFF_TEST_DIR}/POTENTIALS" feff_pot = Potential.pot_string_from_file(filepath) @@ -242,7 +241,7 @@ def test_as_dict_and_from_dict(self): assert str(pot) == str(pot2), "Potential to and from dict does not match" -class TestPaths(unittest.TestCase): +class TestPaths(TestCase): def setUp(self): feo = Structure.from_dict( { diff --git a/tests/io/feff/test_outputs.py b/tests/io/feff/test_outputs.py index d090297d26f..9c90e913f9c 100644 --- a/tests/io/feff/test_outputs.py +++ b/tests/io/feff/test_outputs.py @@ -1,6 +1,6 @@ from __future__ import annotations -import unittest +from unittest import TestCase from pymatgen.io.feff.outputs import LDos, Xmu from pymatgen.util.testing import TEST_FILES_DIR @@ -8,7 +8,7 @@ FEFF_TEST_DIR = f"{TEST_FILES_DIR}/feff" -class TestFeffLdos(unittest.TestCase): +class TestFeffLdos(TestCase): filepath1 = f"{FEFF_TEST_DIR}/feff.inp" filepath2 = f"{FEFF_TEST_DIR}/ldos" ldos = LDos.from_file(filepath1, filepath2) @@ -49,7 +49,7 @@ def test_reci_charge(self): assert charge_trans["1"]["O"]["tot"] == -0.594 -class TestXmu(unittest.TestCase): +class TestXmu(TestCase): def test_init(self): filepath1 = f"{FEFF_TEST_DIR}/xmu.dat" filepath2 = f"{FEFF_TEST_DIR}/feff.inp" diff --git a/tests/io/feff/test_sets.py b/tests/io/feff/test_sets.py index 4863d74d0df..bc130d866c1 100644 --- a/tests/io/feff/test_sets.py +++ b/tests/io/feff/test_sets.py @@ -45,8 +45,7 @@ def test_get_header(self): if idx < 9: assert line == ref[idx] else: - s = " ".join(line.split()[2:]) - assert s in last4 + assert " ".join(line.split()[2:]) in last4 def test_get_feff_tags(self): tags = self.mp_xanes.tags.as_dict() diff --git a/tests/io/lammps/test_data.py b/tests/io/lammps/test_data.py index 88c711f0814..e20284a9243 100644 --- a/tests/io/lammps/test_data.py +++ b/tests/io/lammps/test_data.py @@ -3,7 +3,7 @@ import gzip import json import random -import unittest +from unittest import TestCase import numpy as np import pandas as pd @@ -23,13 +23,7 @@ class TestLammpsBox(PymatgenTest): @classmethod def setUpClass(cls): - cls.peptide = LammpsBox( - bounds=[ - [36.840194, 64.211560], - [41.013691, 68.385058], - [29.768095, 57.139462], - ] - ) + cls.peptide = LammpsBox(bounds=[[36.840194, 64.211560], [41.013691, 68.385058], [29.768095, 57.139462]]) cls.quartz = LammpsBox( bounds=[[0, 4.913400], [0, 4.255129], [0, 5.405200]], tilt=[-2.456700, 0.0, 0.0], @@ -307,9 +301,9 @@ def test_disassemble(self): base_kws = ["Bond", "Angle", "Dihedral", "Improper"] for kw in base_kws: ff_kw = f"{kw} Coeffs" - i = random.randint(0, len(c_ff.topo_coeffs[ff_kw]) - 1) - sample_coeff = c_ff.topo_coeffs[ff_kw][i] - np.testing.assert_array_equal(sample_coeff["coeffs"], c.force_field[ff_kw].iloc[i].values, ff_kw) + idx = random.randint(0, len(c_ff.topo_coeffs[ff_kw]) - 1) + sample_coeff = c_ff.topo_coeffs[ff_kw][idx] + np.testing.assert_array_equal(sample_coeff["coeffs"], c.force_field[ff_kw].iloc[idx].values, ff_kw) topo = topos[-1] atoms = c.atoms[c.atoms["molecule-ID"] == 46] assert_allclose(topo.sites.cart_coords, atoms[["x", "y", "z"]]) @@ -331,10 +325,10 @@ def test_disassemble(self): assert topo_type in ff_coeffs[topo_type_idx]["types"], ff_kw # test no guessing element and pair_ij as non-bond coeffs - v = self.virus - _, v_ff, _ = v.disassemble(guess_element=False) + virus = self.virus + _, v_ff, _ = virus.disassemble(guess_element=False) assert v_ff.maps["Atoms"] == {"Qa1": 1, "Qb1": 2, "Qc1": 3, "Qa2": 4} - pair_ij_coeffs = v.force_field["PairIJ Coeffs"].drop(["id1", "id2"], axis=1) + pair_ij_coeffs = virus.force_field["PairIJ Coeffs"].drop(["id1", "id2"], axis=1) np.testing.assert_array_equal(v_ff.nonbond_coeffs, pair_ij_coeffs.values) # test class2 ff _, e_ff, _ = self.ethane.disassemble() @@ -455,10 +449,10 @@ def test_from_ff_and_topologies(self): np.testing.assert_array_equal(bonds.index.values, np.arange(1, len(bonds) + 1)) np.testing.assert_array_equal(angles.index.values, np.arange(1, len(angles) + 1)) - i = random.randint(0, len(topologies) - 1) - sample = topologies[i] - in_atoms = ice.atoms[ice.atoms["molecule-ID"] == i + 1] - np.testing.assert_array_equal(in_atoms.index.values, np.arange(3 * i + 1, 3 * i + 4)) + idx = random.randint(0, len(topologies) - 1) + sample = topologies[idx] + in_atoms = ice.atoms[ice.atoms["molecule-ID"] == idx + 1] + np.testing.assert_array_equal(in_atoms.index.values, np.arange(3 * idx + 1, 3 * idx + 4)) np.testing.assert_array_equal(in_atoms["type"].values, [2, 1, 1]) np.testing.assert_array_equal(in_atoms["q"].values, sample.charges) np.testing.assert_array_equal(in_atoms[["x", "y", "z"]].values, sample.sites.cart_coords) @@ -467,8 +461,8 @@ def test_from_ff_and_topologies(self): "Angle Coeffs": [{"coeffs": [42.1845, 109.4712], "types": [("H", "H", "H")]}], } broken_ff = ForceField(mass.items(), non_bond_coeffs, broken_topo_coeffs) - ld_woangles = LammpsData.from_ff_and_topologies(box=box, ff=broken_ff, topologies=[sample]) - assert "Angles" not in ld_woangles.topology + ld_wo_angles = LammpsData.from_ff_and_topologies(box=box, ff=broken_ff, topologies=[sample]) + assert "Angles" not in ld_wo_angles.topology def test_from_structure(self): lattice = Lattice.monoclinic(9.78746, 4.75058, 8.95892, 115.9693) @@ -480,14 +474,14 @@ def test_from_structure(self): ) velocities = np.random.randn(20, 3) * 0.1 structure.add_site_property("velocities", velocities) - ld = LammpsData.from_structure(structure=structure, ff_elements=["O", "Os", "Na"]) - i = random.randint(0, 19) + lammps_data = LammpsData.from_structure(structure=structure, ff_elements=["O", "Os", "Na"]) + idx = random.randint(0, 19) a = lattice.matrix[0] - va = velocities[i].dot(a) / np.linalg.norm(a) - assert va == approx(ld.velocities.loc[i + 1, "vx"]) - assert velocities[i, 1] == approx(ld.velocities.loc[i + 1, "vy"]) - assert_allclose(ld.masses["mass"], [22.989769, 190.23, 15.9994]) - np.testing.assert_array_equal(ld.atoms["type"], [2] * 4 + [3] * 16) + v_a = velocities[idx].dot(a) / np.linalg.norm(a) + assert v_a == approx(lammps_data.velocities.loc[idx + 1, "vx"]) + assert velocities[idx, 1] == approx(lammps_data.velocities.loc[idx + 1, "vy"]) + assert_allclose(lammps_data.masses["mass"], [22.989769, 190.23, 15.9994]) + np.testing.assert_array_equal(lammps_data.atoms["type"], [2] * 4 + [3] * 16) def test_set_charge_atom(self): peptide = self.peptide @@ -524,7 +518,7 @@ def test_json_dict(self): assert pd.testing.assert_frame_equal(c2h6.topology[key], target_df) is None, key -class TestTopology(unittest.TestCase): +class TestTopology(TestCase): def test_init(self): inner_charge = np.random.rand(10) - 0.5 outer_charge = np.random.rand(10) - 0.5 @@ -694,9 +688,8 @@ def test_init(self): ("B", "C"): 2, ("C", "B"): 2, } - e = self.ethane - assert e.masses.loc[1, "mass"] == 12.01115 - e_ff = e.force_field + assert self.ethane.masses.loc[1, "mass"] == 12.01115 + e_ff = self.ethane.force_field assert isinstance(e_ff, dict) assert "PairIJ Coeffs" not in e_ff assert e_ff["Pair Coeffs"].loc[1, "coeff2"] == 3.854 @@ -712,7 +705,7 @@ def test_init(self): assert e_ff["AngleAngleTorsion Coeffs"].loc[1, "coeff1"] == -12.564 assert e_ff["BondBond13 Coeffs"].loc[1, "coeff1"] == 0.0 assert e_ff["AngleAngle Coeffs"].loc[1, "coeff2"] == -0.4825 - e_maps = e.maps + e_maps = self.ethane.maps assert e_maps["Atoms"] == {"c4": 1, "h1": 2} assert e_maps["Bonds"] == {("c4", "c4"): 1, ("c4", "h1"): 2, ("h1", "c4"): 2} assert e_maps["Angles"] == {("c4", "c4", "h1"): 1, ("h1", "c4", "c4"): 1, ("h1", "c4", "h1"): 2} @@ -735,10 +728,9 @@ def test_to_file(self): assert dct["nonbond_coeffs"] == self.virus.nonbond_coeffs def test_from_file(self): - e = self.ethane - assert e.mass_info == [("c4", 12.01115), ("h1", 1.00797)] - np.testing.assert_array_equal(e.nonbond_coeffs, [[0.062, 3.854], [0.023, 2.878]]) - e_tc = e.topo_coeffs + assert self.ethane.mass_info == [("c4", 12.01115), ("h1", 1.00797)] + np.testing.assert_array_equal(self.ethane.nonbond_coeffs, [[0.062, 3.854], [0.023, 2.878]]) + e_tc = self.ethane.topo_coeffs assert "Bond Coeffs" in e_tc assert "BondAngle Coeffs" in e_tc["Angle Coeffs"][0] assert "BondBond Coeffs" in e_tc["Angle Coeffs"][0] @@ -758,7 +750,7 @@ def test_from_dict(self): assert decoded.topo_coeffs == self.ethane.topo_coeffs -class TestFunc(unittest.TestCase): +class TestFunc(TestCase): def test_lattice_2_lmpbox(self): matrix = np.diag(np.random.randint(5, 14, size=(3,))) + np.random.rand(3, 3) * 0.2 - 0.1 init_latt = Lattice(matrix) @@ -790,7 +782,7 @@ def test_lattice_2_lmpbox(self): ) -class TestCombinedData(unittest.TestCase): +class TestCombinedData(TestCase): @classmethod def setUpClass(cls): cls.ec = LammpsData.from_file(filename=f"{TEST_DIR}/ec.data.gz") diff --git a/tests/io/lammps/test_inputs.py b/tests/io/lammps/test_inputs.py index 68596287e4f..0b840cee7c0 100644 --- a/tests/io/lammps/test_inputs.py +++ b/tests/io/lammps/test_inputs.py @@ -564,8 +564,6 @@ def test_add_comment(self): class TestLammpsRun(PymatgenTest): - maxDiff = None - def test_md(self): struct = Structure.from_spacegroup(225, Lattice.cubic(3.62126), ["Cu"], [[0, 0, 0]]) ld = LammpsData.from_structure(struct, atom_style="atomic") diff --git a/tests/io/lammps/test_outputs.py b/tests/io/lammps/test_outputs.py index 6bff18c5ffa..d627c45500a 100644 --- a/tests/io/lammps/test_outputs.py +++ b/tests/io/lammps/test_outputs.py @@ -2,7 +2,7 @@ import json import os -import unittest +from unittest import TestCase import numpy as np import pandas as pd @@ -14,7 +14,7 @@ TEST_DIR = f"{TEST_FILES_DIR}/lammps" -class TestLammpsDump(unittest.TestCase): +class TestLammpsDump(TestCase): @classmethod def setUpClass(cls): with open(f"{TEST_DIR}/dump.rdx_wc.100") as file: @@ -54,7 +54,7 @@ def test_json_dict(self): pd.testing.assert_frame_equal(rdx.data, self.rdx.data) -class TestFunc(unittest.TestCase): +class TestFunc(TestCase): def test_parse_lammps_dumps(self): # gzipped rdx_10_pattern = f"{TEST_DIR}/dump.rdx.gz" diff --git a/tests/io/lammps/test_utils.py b/tests/io/lammps/test_utils.py index 25f39603928..a7893692a9a 100644 --- a/tests/io/lammps/test_utils.py +++ b/tests/io/lammps/test_utils.py @@ -1,6 +1,6 @@ from __future__ import annotations -import unittest +from unittest import TestCase from pymatgen.core.structure import Molecule from pymatgen.io.lammps.data import Topology @@ -8,7 +8,7 @@ from pymatgen.util.testing import TEST_FILES_DIR -class TestPolymer(unittest.TestCase): +class TestPolymer(TestCase): @classmethod def setUpClass(cls): # head molecule @@ -80,7 +80,7 @@ def test_polymer_chain_topologies(self): assert topology_linear.topologies["Dihedrals"] != topology_random.topologies["Dihedrals"] -class TestPackmolOutput(unittest.TestCase): +class TestPackmolOutput(TestCase): @classmethod def setUpClass(cls): ethanol_coords = [ diff --git a/tests/io/lobster/test_inputs.py b/tests/io/lobster/test_inputs.py index 1ebea491204..1a0be9ab9dc 100644 --- a/tests/io/lobster/test_inputs.py +++ b/tests/io/lobster/test_inputs.py @@ -3,7 +3,7 @@ import json import os import tempfile -import unittest +from unittest import TestCase import numpy as np import pytest @@ -268,7 +268,7 @@ def test_cohp_data(self): assert cohp1 == approx(cohp2, abs=1e-3) def test_orbital_resolved_cohp(self): - orbitals = [(Orbital(i), Orbital(j)) for j in range(4) for i in range(4)] + orbitals = [(Orbital(jj), Orbital(ii)) for ii in range(4) for jj in range(4)] assert self.cohp_bise.orb_res_cohp is None assert self.coop_bise.orb_res_cohp is None assert self.cohp_fe.orb_res_cohp is None @@ -369,7 +369,7 @@ def test_orbital_resolved_cohp(self): assert len(self.cobi6.orb_res_cohp["21"]["2py-1s-2s"]["COHP"][Spin.down]) == 12 -class TestIcohplist(unittest.TestCase): +class TestIcohplist(TestCase): def setUp(self): self.icohp_bise = Icohplist(filename=f"{TEST_FILES_DIR}/cohp/ICOHPLIST.lobster.BiSe") self.icoop_bise = Icohplist( @@ -639,7 +639,7 @@ def test_msonable(self): assert getattr(icohplist_from_dict, attr_name) == attr_value -class TestNciCobiList(unittest.TestCase): +class TestNciCobiList(TestCase): def setUp(self): self.ncicobi = NciCobiList(filename=f"{TEST_FILES_DIR}/cohp/NcICOBILIST.lobster") self.ncicobi_gz = NciCobiList(filename=f"{TEST_FILES_DIR}/cohp/NcICOBILIST.lobster.gz") @@ -677,7 +677,7 @@ def test_ncicobilist(self): ) -class TestDoscar(unittest.TestCase): +class TestDoscar(TestCase): def setUp(self): # first for spin polarized version doscar = f"{VASP_OUT_DIR}/DOSCAR.lobster.spin" @@ -693,7 +693,7 @@ def setUp(self): self.DOSCAR_spin_pol = Doscar(doscar=doscar, structure_file=poscar) self.DOSCAR_nonspin_pol = Doscar(doscar=doscar2, structure_file=poscar2) - with open(f"{TEST_FILES_DIR}/structure_KF.json") as file: + with open(f"{TEST_FILES_DIR}/structure_KF.json", encoding="utf-8") as file: data = json.load(file) self.structure = Structure.from_dict(data) @@ -1556,7 +1556,7 @@ def test_get_bandstructure(self): assert bs_p_x.get_projection_on_elements()[Spin.up][0][0]["Si"] == approx(3 * (0.001 + 0.064), abs=1e-2) -class TestLobsterin(unittest.TestCase): +class TestLobsterin(TestCase): def setUp(self): self.Lobsterinfromfile = Lobsterin.from_file(f"{TEST_FILES_DIR}/cohp/lobsterin.1") self.Lobsterinfromfile2 = Lobsterin.from_file(f"{TEST_FILES_DIR}/cohp/lobsterin.2") @@ -1612,11 +1612,11 @@ def test_initialize_from_dict(self): assert lobsterin["basisfunctions"][0] == "Fe 3d 4p 4s" assert lobsterin["basisfunctions"][1] == "Co 3d 4p 4s" assert {*lobsterin} >= {"skipdos", "skipcohp", "skipcoop", "skippopulationanalysis", "skipgrosspopulation"} - with pytest.raises(IOError, match="There are duplicates for the keywords! The program will stop here."): + with pytest.raises(KeyError, match="There are duplicates for the keywords!"): lobsterin2 = Lobsterin({"cohpstartenergy": -15.0, "cohpstartEnergy": -20.0}) lobsterin2 = Lobsterin({"cohpstartenergy": -15.0}) # can only calculate nbands if basis functions are provided - with pytest.raises(IOError, match="No basis functions are provided. The program cannot calculate nbands"): + with pytest.raises(ValueError, match="No basis functions are provided. The program cannot calculate nbands"): lobsterin2._get_nbands(structure=Structure.from_file(f"{VASP_IN_DIR}/POSCAR_Fe3O4")) def test_standard_settings(self): @@ -1752,7 +1752,7 @@ def test_standard_with_energy_range_from_vasprun(self): def test_diff(self): # test diff assert self.Lobsterinfromfile.diff(self.Lobsterinfromfile2)["Different"] == {} - assert self.Lobsterinfromfile.diff(self.Lobsterinfromfile2)["Same"]["COHPSTARTENERGY"] == approx(-15.0) + assert self.Lobsterinfromfile.diff(self.Lobsterinfromfile2)["Same"]["cohpstartenergy"] == approx(-15.0) # test diff in both directions for entry in self.Lobsterinfromfile.diff(self.Lobsterinfromfile3)["Same"]: @@ -1765,8 +1765,8 @@ def test_diff(self): assert entry in self.Lobsterinfromfile.diff(self.Lobsterinfromfile3)["Different"] assert ( - self.Lobsterinfromfile.diff(self.Lobsterinfromfile3)["Different"]["SKIPCOHP"]["lobsterin1"] - == self.Lobsterinfromfile3.diff(self.Lobsterinfromfile)["Different"]["SKIPCOHP"]["lobsterin2"] + self.Lobsterinfromfile.diff(self.Lobsterinfromfile3)["Different"]["skipcohp"]["lobsterin1"] + == self.Lobsterinfromfile3.diff(self.Lobsterinfromfile)["Different"]["skipcohp"]["lobsterin2"] ) def test_dict_functionality(self): @@ -1793,8 +1793,6 @@ def test_read_write_lobsterin(self): lobsterin2 = Lobsterin.from_file(outfile_path) assert lobsterin1.diff(lobsterin2)["Different"] == {} - # TODO: will integer vs float break cohpsteps? - def test_get_basis(self): # get basis functions lobsterin1 = Lobsterin({}) @@ -2030,7 +2028,7 @@ def test_msonable_implementation(self): new_lobsterin.to_json() -class TestBandoverlaps(unittest.TestCase): +class TestBandoverlaps(TestCase): def setUp(self): # test spin-polarized calc and non spinpolarized calc @@ -2189,7 +2187,7 @@ def test_keys(self): assert len(bo_dict_new[Spin.down]["matrices"]) == 73 -class TestGrosspop(unittest.TestCase): +class TestGrosspop(TestCase): def setUp(self): self.grosspop1 = Grosspop(f"{TEST_FILES_DIR}/cohp/GROSSPOP.lobster") @@ -2590,7 +2588,7 @@ def test_raises(self): self.hamilton_matrices = LobsterMatrices(filename=f"{TEST_FILES_DIR}/cohp/Na_hamiltonMatrices.lobster.gz") with pytest.raises( - OSError, - match=r"Please check provided input file, it seems to be empty", + RuntimeError, + match="Please check provided input file, it seems to be empty", ): self.hamilton_matrices = LobsterMatrices(filename=f"{TEST_FILES_DIR}/cohp/hamiltonMatrices.lobster") diff --git a/tests/io/lobster/test_lobsterenv.py b/tests/io/lobster/test_lobsterenv.py index 8ad8c27ac45..3b985fe0336 100644 --- a/tests/io/lobster/test_lobsterenv.py +++ b/tests/io/lobster/test_lobsterenv.py @@ -1,7 +1,7 @@ from __future__ import annotations import os -import unittest +from unittest import TestCase import numpy as np import pytest @@ -26,7 +26,7 @@ module_dir = os.path.dirname(os.path.abspath(__file__)) -class TestLobsterNeighbors(unittest.TestCase): +class TestLobsterNeighbors(TestCase): def setUp(self): # test additional conditions first # only consider cation anion bonds diff --git a/tests/io/qchem/test_inputs.py b/tests/io/qchem/test_inputs.py index c2c29d10de8..9d70f70780b 100644 --- a/tests/io/qchem/test_inputs.py +++ b/tests/io/qchem/test_inputs.py @@ -43,7 +43,6 @@ def test_molecule_template(self): assert molecule_actual == molecule_test def test_multi_molecule_template(self): - self.maxDiff = None species = ["C", "C", "H", "H", "H", "H"] coords_1 = [ [0.000000, 0.000000, 0.000000], diff --git a/tests/io/qchem/test_outputs.py b/tests/io/qchem/test_outputs.py index ebb7c8d0fa1..179a101341d 100644 --- a/tests/io/qchem/test_outputs.py +++ b/tests/io/qchem/test_outputs.py @@ -7,7 +7,7 @@ import numpy as np import pytest from monty.serialization import dumpfn, loadfn -from numpy.testing import assert_array_equal +from numpy.testing import assert_allclose from pytest import approx from pymatgen.core.structure import Molecule @@ -282,30 +282,34 @@ def _test_property(self, key, single_outs, multi_outs): assert out_data.get(key) == single_job_dict[filename].get(key) except ValueError: try: - assert_array_equal(out_data.get(key), single_job_dict[filename].get(key)) + if isinstance(out_data.get(key), dict): + assert out_data.get(key) == approx(single_job_dict[filename].get(key)) + else: + assert_allclose(out_data.get(key), single_job_dict[filename].get(key), atol=1e-6) except AssertionError: raise RuntimeError(f"Issue with {filename=} Exiting...") except AssertionError: raise RuntimeError(f"Issue with {filename=} Exiting...") for filename, outputs in multi_outs.items(): - for ii, sub_output in enumerate(outputs): + for idx, sub_output in enumerate(outputs): try: - assert sub_output.data.get(key) == multi_job_dict[filename][ii].get(key) + assert sub_output.data.get(key) == multi_job_dict[filename][idx].get(key) except ValueError: - assert_array_equal(sub_output.data.get(key), multi_job_dict[filename][ii].get(key)) + if isinstance(sub_output.data.get(key), dict): + assert sub_output.data.get(key) == approx(multi_job_dict[filename][idx].get(key)) + else: + assert_allclose(sub_output.data.get(key), multi_job_dict[filename][idx].get(key), atol=1e-6) + @pytest.mark.skip() # self._test_property(key, single_outs, multi_outs) fails with + # ValueError: The truth value of an array with more than one element is ambiguous @pytest.mark.skipif(openbabel is None, reason="OpenBabel not installed.") def test_all(self): - self.maxDiff = None - single_outs = {} - for file in single_job_out_names: - single_outs[file] = QCOutput(f"{TEST_FILES_DIR}/molecules/{file}").data + single_outs = {file: QCOutput(f"{TEST_FILES_DIR}/molecules/{file}").data for file in single_job_out_names} - multi_outs = {} - for file in multi_job_out_names: - multi_outs[file] = QCOutput.multiple_outputs_from_file( - f"{TEST_FILES_DIR}/molecules/{file}", keep_sub_files=False - ) + multi_outs = { + file: QCOutput.multiple_outputs_from_file(f"{TEST_FILES_DIR}/molecules/{file}", keep_sub_files=False) + for file in multi_job_out_names + } for key in property_list: self._test_property(key, single_outs, multi_outs) diff --git a/tests/io/test_adf.py b/tests/io/test_adf.py index ac3fb699d4d..188d9735066 100644 --- a/tests/io/test_adf.py +++ b/tests/io/test_adf.py @@ -71,15 +71,11 @@ def readfile(file_object): """` Return the content of the file as a string. - Parameters - ---------- - file_object : file or str - The file to read. This can be either a File object or a file path. + Args: + file_object (file or str): The file to read. This can be either a File object or a file path. Returns: - ------- - content : str - The content of the file. + content (str): The content of the file. """ if hasattr(file_object, "read"): return file_object.read() @@ -167,8 +163,8 @@ def test_option_operations(self): def test_atom_block_key(self): block = AdfKey("atoms") - o = Molecule.from_str(h2o_xyz, "xyz") - for site in o: + mol = Molecule.from_str(h2o_xyz, "xyz") + for site in mol: block.add_subkey(AdfKey(str(site.specie), list(site.coords))) assert str(block) == atoms_string @@ -206,14 +202,14 @@ def test_energy(self): def test_serialization(self): task = AdfTask() - o = AdfTask.from_dict(task.as_dict()) - assert task.title == o.title - assert task.basis_set == o.basis_set - assert task.scf == o.scf - assert task.geo == o.geo - assert task.operation == o.operation - assert task.units == o.units - assert str(task) == str(o) + adf_task = AdfTask.from_dict(task.as_dict()) + assert task.title == adf_task.title + assert task.basis_set == adf_task.basis_set + assert task.scf == adf_task.scf + assert task.geo == adf_task.geo + assert task.operation == adf_task.operation + assert task.units == adf_task.units + assert str(task) == str(adf_task) rhb18 = { @@ -261,35 +257,35 @@ def test_main(self): class TestAdfOutput: def test_analytical_freq(self): filename = f"{TEST_DIR}/adf/analytical_freq/adf.out" - o = AdfOutput(filename) - assert o.final_energy == approx(-0.54340325) - assert len(o.energies) == 4 - assert len(o.structures) == 4 - assert o.frequencies[0] == approx(1553.931) - assert o.frequencies[2] == approx(3793.086) - assert o.normal_modes[0][2] == approx(0.071) - assert o.normal_modes[0][6] == approx(0.000) - assert o.normal_modes[0][7] == approx(-0.426) - assert o.normal_modes[0][8] == approx(-0.562) + adf_out = AdfOutput(filename) + assert adf_out.final_energy == approx(-0.54340325) + assert len(adf_out.energies) == 4 + assert len(adf_out.structures) == 4 + assert adf_out.frequencies[0] == approx(1553.931) + assert adf_out.frequencies[2] == approx(3793.086) + assert adf_out.normal_modes[0][2] == approx(0.071) + assert adf_out.normal_modes[0][6] == approx(0.000) + assert adf_out.normal_modes[0][7] == approx(-0.426) + assert adf_out.normal_modes[0][8] == approx(-0.562) def test_numerical_freq(self): filename = f"{TEST_DIR}/adf/numerical_freq/adf.out" - o = AdfOutput(filename) - assert o.freq_type == "Numerical" - assert len(o.final_structure) == 4 - assert len(o.frequencies) == 6 - assert len(o.normal_modes) == 6 - assert o.frequencies[0] == approx(938.21) - assert o.frequencies[3] == approx(3426.64) - assert o.frequencies[4] == approx(3559.35) - assert o.frequencies[5] == approx(3559.35) - assert o.normal_modes[1][0] == approx(0.067) - assert o.normal_modes[1][3] == approx(-0.536) - assert o.normal_modes[1][7] == approx(0.000) - assert o.normal_modes[1][9] == approx(-0.536) + adf_out = AdfOutput(filename) + assert adf_out.freq_type == "Numerical" + assert len(adf_out.final_structure) == 4 + assert len(adf_out.frequencies) == 6 + assert len(adf_out.normal_modes) == 6 + assert adf_out.frequencies[0] == approx(938.21) + assert adf_out.frequencies[3] == approx(3426.64) + assert adf_out.frequencies[4] == approx(3559.35) + assert adf_out.frequencies[5] == approx(3559.35) + assert adf_out.normal_modes[1][0] == approx(0.067) + assert adf_out.normal_modes[1][3] == approx(-0.536) + assert adf_out.normal_modes[1][7] == approx(0.000) + assert adf_out.normal_modes[1][9] == approx(-0.536) def test_single_point(self): filename = f"{TEST_DIR}/adf/sp/adf.out" - o = AdfOutput(filename) - assert o.final_energy == approx(-0.74399276) - assert len(o.final_structure) == 4 + adf_out = AdfOutput(filename) + assert adf_out.final_energy == approx(-0.74399276) + assert len(adf_out.final_structure) == 4 diff --git a/tests/io/test_babel.py b/tests/io/test_babel.py index 908389e7d4f..d4d302cae14 100644 --- a/tests/io/test_babel.py +++ b/tests/io/test_babel.py @@ -1,7 +1,7 @@ from __future__ import annotations import copy -import unittest +from unittest import TestCase import pytest from pytest import approx @@ -16,7 +16,7 @@ pybel = pytest.importorskip("openbabel.pybel") -class TestBabelMolAdaptor(unittest.TestCase): +class TestBabelMolAdaptor(TestCase): def setUp(self): coords = [ [0.000000, 0.000000, 0.000000], @@ -53,7 +53,7 @@ def test_from_file_return_all_molecules(self): assert len(adaptors) == 302 def test_from_molecule_graph(self): - graph = MoleculeGraph.with_empty_graph(self.mol) + graph = MoleculeGraph.from_empty_graph(self.mol) adaptor = BabelMolAdaptor.from_molecule_graph(graph) ob_mol = adaptor.openbabel_mol assert ob_mol.NumAtoms() == 5 diff --git a/tests/io/test_cif.py b/tests/io/test_cif.py index d4f58201648..467d9ed4ecc 100644 --- a/tests/io/test_cif.py +++ b/tests/io/test_cif.py @@ -499,7 +499,7 @@ def test_symmetrized(self): def test_disordered(self): si = Element("Si") - n = Element("N") + nitrogen = Element("N") coords = [] coords.extend((np.array([0, 0, 0]), np.array([0.75, 0.5, 0.75]))) lattice = [ @@ -507,7 +507,7 @@ def test_disordered(self): [1.9200989668, 3.3257101909, 0.00], [0.00, -2.2171384943, 3.1355090603], ] - struct = Structure(lattice, [si, {si: 0.5, n: 0.5}], coords) + struct = Structure(lattice, [si, {si: 0.5, nitrogen: 0.5}], coords) writer = CifWriter(struct) answer = """# generated using pymatgen data_Si1.5N0.5 @@ -554,7 +554,7 @@ def test_cif_writer_without_refinement(self): def test_specie_cif_writer(self): si4 = Species("Si", 4) si3 = Species("Si", 3) - n = DummySpecies("X", -3) + dummy_spec = DummySpecies("X", -3) coords = [] coords.extend((np.array([0.5, 0.5, 0.5]), np.array([0.75, 0.5, 0.75]), np.array([0, 0, 0]))) lattice = [ @@ -562,7 +562,7 @@ def test_specie_cif_writer(self): [1.9200989668, 3.3257101909, 0.00], [0.00, -2.2171384943, 3.1355090603], ] - struct = Structure(lattice, [n, {si3: 0.5, n: 0.5}, si4], coords) + struct = Structure(lattice, [dummy_spec, {si3: 0.5, dummy_spec: 0.5}, si4], coords) writer = CifWriter(struct) answer = """# generated using pymatgen data_X1.5Si1.5 @@ -740,20 +740,20 @@ def test_bad_cif(self): assert struct[0].species["Al3+"] == approx(0.778) def test_one_line_symm(self): - f = f"{TEST_FILES_DIR}/OneLineSymmP1.cif" - parser = CifParser(f) + cif_file = f"{TEST_FILES_DIR}/OneLineSymmP1.cif" + parser = CifParser(cif_file) struct = parser.parse_structures()[0] assert struct.formula == "Ga4 Pb2 O8" def test_no_symmops(self): - f = f"{TEST_FILES_DIR}/nosymm.cif" - parser = CifParser(f) + cif_file = f"{TEST_FILES_DIR}/nosymm.cif" + parser = CifParser(cif_file) struct = parser.parse_structures()[0] assert struct.formula == "H96 C60 O8" def test_dot_positions(self): - f = f"{TEST_FILES_DIR}/ICSD59959.cif" - parser = CifParser(f) + cif_file = f"{TEST_FILES_DIR}/ICSD59959.cif" + parser = CifParser(cif_file) struct = parser.parse_structures()[0] assert struct.formula == "K1 Mn1 F3" diff --git a/tests/io/test_core.py b/tests/io/test_core.py index bb12ebedb23..892c44d46e3 100644 --- a/tests/io/test_core.py +++ b/tests/io/test_core.py @@ -2,6 +2,7 @@ import copy import os +from typing import TYPE_CHECKING import pytest from monty.serialization import MontyDecoder @@ -11,6 +12,9 @@ from pymatgen.io.core import InputFile, InputSet from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest +if TYPE_CHECKING: + from typing_extensions import Self + class StructInputFile(InputFile): """Test implementation of an InputFile object for CIF.""" @@ -23,7 +27,7 @@ def get_str(self) -> str: return str(cw) @classmethod - def from_str(cls, contents: str): + def from_str(cls, contents: str) -> Self: # type: ignore[override] struct = Structure.from_str(contents, fmt="cif") return cls(structure=struct) diff --git a/tests/io/test_cssr.py b/tests/io/test_cssr.py index ea3aa45f27f..6d624fa7988 100644 --- a/tests/io/test_cssr.py +++ b/tests/io/test_cssr.py @@ -2,7 +2,7 @@ from __future__ import annotations -import unittest +from unittest import TestCase from pymatgen.core.structure import Structure from pymatgen.io.cssr import Cssr @@ -16,7 +16,7 @@ __date__ = "Jan 24, 2012" -class TestCssr(unittest.TestCase): +class TestCssr(TestCase): def setUp(self): filepath = f"{VASP_IN_DIR}/POSCAR" self.cssr = Cssr(Structure.from_file(filepath)) diff --git a/tests/io/test_fiesta.py b/tests/io/test_fiesta.py index 1c4385501a2..acd3703f66a 100644 --- a/tests/io/test_fiesta.py +++ b/tests/io/test_fiesta.py @@ -1,13 +1,13 @@ from __future__ import annotations -import unittest +from unittest import TestCase from pymatgen.core.structure import Molecule from pymatgen.io.fiesta import FiestaInput, FiestaOutput from pymatgen.util.testing import TEST_FILES_DIR -class TestFiestaInput(unittest.TestCase): +class TestFiestaInput(TestCase): def setUp(self): coords = [ [0.000000, 0.000000, 0.000000], @@ -69,7 +69,7 @@ def test_str_and_from_str(self): assert cell_in.cohsex_options["eigMethod"] == "C" -class TestFiestaOutput(unittest.TestCase): +class TestFiestaOutput(TestCase): def setUp(self): self.log_fiesta = FiestaOutput(f"{TEST_FILES_DIR}/fiesta/log_fiesta") diff --git a/tests/io/test_gaussian.py b/tests/io/test_gaussian.py index 2a6858bafdb..6ba00740eb1 100644 --- a/tests/io/test_gaussian.py +++ b/tests/io/test_gaussian.py @@ -1,6 +1,6 @@ from __future__ import annotations -import unittest +from unittest import TestCase import pytest from pytest import approx @@ -13,7 +13,7 @@ TEST_DIR = f"{TEST_FILES_DIR}/molecules" -class TestGaussianInput(unittest.TestCase): +class TestGaussianInput(TestCase): def setUp(self): coords = [ [0, 0, 0], @@ -251,7 +251,7 @@ def test_no_molecule_func_bset_charge_mult(self): assert input_str == gau_str -class TestGaussianOutput(unittest.TestCase): +class TestGaussianOutput(TestCase): # TODO: Add unittest for PCM type output. def setUp(self): diff --git a/tests/io/test_jarvis.py b/tests/io/test_jarvis.py index 5448edc880b..75c3c78eb82 100644 --- a/tests/io/test_jarvis.py +++ b/tests/io/test_jarvis.py @@ -1,7 +1,5 @@ from __future__ import annotations -import unittest - import pytest from pymatgen.core import Structure @@ -10,7 +8,7 @@ @pytest.mark.skipif(Atoms is None, reason="JARVIS-tools not loaded.") -class TestJarvisAtomsAdaptor(unittest.TestCase): +class TestJarvisAtomsAdaptor: def test_get_atoms_from_structure(self): struct = Structure.from_file(f"{VASP_IN_DIR}/POSCAR") atoms = JarvisAtomsAdaptor.get_atoms(struct) diff --git a/tests/io/test_nwchem.py b/tests/io/test_nwchem.py index 5e317c0ab43..102ccf4bd8d 100644 --- a/tests/io/test_nwchem.py +++ b/tests/io/test_nwchem.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -import unittest +from unittest import TestCase import pytest from pytest import approx @@ -22,7 +22,7 @@ mol = Molecule(["C", "H", "H", "H", "H"], coords) -class TestNwTask(unittest.TestCase): +class TestNwTask(TestCase): def setUp(self): self.task = NwTask( 0, @@ -136,7 +136,7 @@ def test_esp_task(self): assert str(task) == answer -class TestNwInput(unittest.TestCase): +class TestNwInput(TestCase): def setUp(self): tasks = [ NwTask.dft_task(mol, operation="optimize", xc="b3lyp", basis_set="6-31++G*"), @@ -396,7 +396,7 @@ def test_from_str_and_file(self): assert nwi_symm.tasks[-1].basis_set["C"] == "6-311++G**" -class TestNwOutput(unittest.TestCase): +class TestNwOutput: def test_read(self): nwo = NwOutput(f"{TEST_DIR}/CH4.nwout") nwo_cosmo = NwOutput(f"{TEST_DIR}/N2O4.nwout") diff --git a/tests/io/test_openff.py b/tests/io/test_openff.py new file mode 100644 index 00000000000..0c2857f584e --- /dev/null +++ b/tests/io/test_openff.py @@ -0,0 +1,253 @@ +from __future__ import annotations + +from pathlib import Path + +import networkx as nx +import networkx.algorithms.isomorphism as iso +import numpy as np +import pytest +from numpy.testing import assert_allclose + +from pymatgen.analysis.graphs import MoleculeGraph +from pymatgen.analysis.local_env import OpenBabelNN +from pymatgen.core import Molecule +from pymatgen.io.openff import ( + add_conformer, + assign_partial_charges, + create_openff_mol, + get_atom_map, + infer_openff_mol, + mol_graph_from_openff_mol, + mol_graph_to_openff_mol, +) + +tk = pytest.importorskip("openff.toolkit") + + +@pytest.fixture() +def mol_files(): + geo_dir = Path(__file__).absolute().parent.parent / "files/classical_md_mols" + return { + "CCO_xyz": str(geo_dir / "CCO.xyz"), + "CCO_charges": str(geo_dir / "CCO.npy"), + "FEC_r_xyz": str(geo_dir / "FEC-r.xyz"), + "FEC_s_xyz": str(geo_dir / "FEC-s.xyz"), + "FEC_charges": str(geo_dir / "FEC.npy"), + "PF6_xyz": str(geo_dir / "PF6.xyz"), + "PF6_charges": str(geo_dir / "PF6.npy"), + "Li_charges": str(geo_dir / "Li.npy"), + "Li_xyz": str(geo_dir / "Li.xyz"), + } + + +def test_mol_graph_from_atom_bonds(mol_files): + pytest.importorskip("openff") + + pf6_openff = tk.Molecule.from_smiles("F[P-](F)(F)(F)(F)F") + + pf6_graph = mol_graph_from_openff_mol(pf6_openff) + + assert len(pf6_graph.molecule) == 7 + assert pf6_graph.molecule.charge == -1 + + em = iso.categorical_edge_match("weight", 1) + + pf6_openff2 = mol_graph_to_openff_mol(pf6_graph) + pf6_graph2 = mol_graph_from_openff_mol(pf6_openff2) + assert nx.is_isomorphic(pf6_graph.graph, pf6_graph2.graph, edge_match=em) + + +def test_mol_graph_from_openff_mol_cco(): + atom_coords = np.array( + [ + [1.000000, 1.000000, 0.000000], + [-0.515000, 1.000000, 0.000000], + [-0.999000, 1.000000, 1.335000], + [1.390000, 1.001000, -1.022000], + [1.386000, 0.119000, 0.523000], + [1.385000, 1.880000, 0.526000], + [-0.907000, 0.118000, -0.516000], + [-0.897000, 1.894000, -0.501000], + [-0.661000, 0.198000, 1.768000], + ] + ) + + atoms = ["C", "C", "O", "H", "H", "H", "H", "H", "H"] + + cco_openff = tk.Molecule.from_smiles("CCO") + cco_openff.assign_partial_charges("mmff94") + + cco_mol_graph_1 = mol_graph_from_openff_mol(cco_openff) + + assert len(cco_mol_graph_1.molecule) == 9 + assert cco_mol_graph_1.molecule.charge == 0 + assert len(cco_mol_graph_1.graph.edges) == 8 + + cco_pmg = Molecule(atoms, atom_coords) + cco_mol_graph_2 = MoleculeGraph.with_local_env_strategy(cco_pmg, OpenBabelNN()) + + em = iso.categorical_edge_match("weight", 1) + + assert nx.is_isomorphic(cco_mol_graph_1.graph, cco_mol_graph_2.graph, edge_match=em) + + +def test_mol_graph_to_openff_pf6(mol_files): + """transform a water MoleculeGraph to a OpenFF water molecule""" + pf6_mol = Molecule.from_file(mol_files["PF6_xyz"]) + pf6_mol.set_charge_and_spin(charge=-1) + pf6_mol_graph = MoleculeGraph.with_edges( + pf6_mol, + { + (0, 1): {"weight": 1}, + (0, 2): {"weight": 1}, + (0, 3): {"weight": 1}, + (0, 4): {"weight": 1}, + (0, 5): {"weight": 1}, + (0, 6): {"weight": 1}, + }, + ) + + pf6_openff_1 = tk.Molecule.from_smiles("F[P-](F)(F)(F)(F)F") + + pf6_openff_2 = mol_graph_to_openff_mol(pf6_mol_graph) + assert pf6_openff_1 == pf6_openff_2 + + +def test_mol_graph_to_openff_cco(mol_files): + cco_pmg = Molecule.from_file(mol_files["CCO_xyz"]) + cco_mol_graph = MoleculeGraph.with_local_env_strategy(cco_pmg, OpenBabelNN()) + + cco_openff_1 = mol_graph_to_openff_mol(cco_mol_graph) + + cco_openff_2 = tk.Molecule.from_smiles("CCO") + cco_openff_2.assign_partial_charges("mmff94") + + assert cco_openff_1 == cco_openff_2 + + +def test_openff_back_and_forth(): + cco_openff = tk.Molecule.from_smiles("CC(=O)O") + cco_openff.assign_partial_charges("mmff94") + + cco_mol_graph_1 = mol_graph_from_openff_mol(cco_openff) + + assert len(cco_mol_graph_1.molecule) == 8 + assert cco_mol_graph_1.molecule.charge == 0 + assert len(cco_mol_graph_1.graph.edges) == 7 + + cco_openff_2 = mol_graph_to_openff_mol(cco_mol_graph_1) + + assert tk.Molecule.is_isomorphic_with(cco_openff, cco_openff_2, bond_order_matching=True) + assert max(bond.bond_order for bond in cco_openff_2.bonds) == 2 + + +@pytest.mark.parametrize( + ("xyz_path", "smile", "map_values"), + [ + ("CCO_xyz", "CCO", [0, 1, 2, 3, 4, 5, 6, 7, 8]), + ("FEC_r_xyz", "O=C1OC[C@@H](F)O1", [0, 1, 2, 3, 4, 6, 7, 9, 8, 5]), + ("FEC_s_xyz", "O=C1OC[C@H](F)O1", [0, 1, 2, 3, 4, 6, 7, 9, 8, 5]), + ("PF6_xyz", "F[P-](F)(F)(F)(F)F", [1, 0, 2, 3, 4, 5, 6]), + ], +) +def test_get_atom_map(xyz_path, smile, map_values, mol_files): + mol = Molecule.from_file(mol_files[xyz_path]) + inferred_mol = infer_openff_mol(mol) + openff_mol = tk.Molecule.from_smiles(smile) + isomorphic, atom_map = get_atom_map(inferred_mol, openff_mol) + assert isomorphic + assert map_values == list(atom_map.values()) + + +@pytest.mark.parametrize( + ("xyz_path", "n_atoms", "n_bonds"), + [ + ("CCO_xyz", 9, 8), + ("FEC_r_xyz", 10, 10), + ("FEC_s_xyz", 10, 10), + ("PF6_xyz", 7, 6), + ], +) +def test_infer_openff_mol(xyz_path, n_atoms, n_bonds, mol_files): + mol = Molecule.from_file(mol_files[xyz_path]) + openff_mol = infer_openff_mol(mol) + assert isinstance(openff_mol, tk.Molecule) + assert openff_mol.n_atoms == n_atoms + assert openff_mol.n_bonds == n_bonds + + +def test_add_conformer(mol_files): + openff_mol = tk.Molecule.from_smiles("CCO") + geometry = Molecule.from_file(mol_files["CCO_xyz"]) + openff_mol, atom_map = add_conformer(openff_mol, geometry) + assert openff_mol.n_conformers == 1 + assert list(atom_map.values()) == list(range(openff_mol.n_atoms)) + + +def test_assign_partial_charges(mol_files): + openff_mol = tk.Molecule.from_smiles("CCO") + geometry = Molecule.from_file(mol_files["CCO_xyz"]) + openff_mol, atom_map = add_conformer(openff_mol, geometry) + partial_charges = np.load(mol_files["CCO_charges"]) + openff_mol = assign_partial_charges(openff_mol, atom_map, "am1bcc", partial_charges) + assert_allclose(openff_mol.partial_charges.magnitude, partial_charges) + + +def test_create_openff_mol(mol_files): + smile = "CCO" + geometry = mol_files["CCO_xyz"] + partial_charges = np.load(mol_files["CCO_charges"]) + openff_mol = create_openff_mol(smile, geometry, 1.0, partial_charges, "am1bcc") + assert isinstance(openff_mol, tk.Molecule) + assert openff_mol.n_atoms == 9 + assert openff_mol.n_bonds == 8 + assert_allclose(openff_mol.partial_charges.magnitude, partial_charges) + + +def test_add_conformer_no_geometry(): + openff_mol = tk.Molecule.from_smiles("CCO") + openff_mol, atom_map = add_conformer(openff_mol, None) + assert openff_mol.n_conformers == 1 + assert list(atom_map.values()) == list(range(openff_mol.n_atoms)) + + +def test_assign_partial_charges_single_atom(mol_files): + openff_mol = tk.Molecule.from_smiles("[Li+]") + geometry = Molecule.from_file(mol_files["Li_xyz"]) + openff_mol, atom_map = add_conformer(openff_mol, geometry) + openff_mol = assign_partial_charges(openff_mol, atom_map, "am1bcc", None) + assert_allclose(openff_mol.partial_charges.magnitude, [1.0]) + + +def test_create_openff_mol_no_geometry(): + smile = "CCO" + openff_mol = create_openff_mol(smile) + assert isinstance(openff_mol, tk.Molecule) + assert openff_mol.n_atoms == 9 + assert openff_mol.n_bonds == 8 + assert openff_mol.n_conformers == 1 + + +def test_create_openff_mol_geometry_path(mol_files): + smile = "CCO" + geometry = mol_files["CCO_xyz"] + openff_mol = create_openff_mol(smile, geometry) + assert isinstance(openff_mol, tk.Molecule) + assert openff_mol.n_atoms == 9 + assert openff_mol.n_bonds == 8 + assert openff_mol.n_conformers == 1 + + +def test_create_openff_mol_partial_charges_no_geometry(): + smile = "CCO" + partial_charges = [-0.4, 0.2, 0.2] + with pytest.raises(ValueError, match="geometries must be set if partial_charges is set"): + create_openff_mol(smile, partial_charges=partial_charges) + + +def test_create_openff_mol_partial_charges_length_mismatch(mol_files): + smile = "CCO" + geometry = mol_files["CCO_xyz"] + partial_charges = [-0.4, 0.2] + with pytest.raises(ValueError, match="partial charges must have same length & order as geometry"): + create_openff_mol(smile, geometry, partial_charges=partial_charges) diff --git a/tests/io/test_packmol.py b/tests/io/test_packmol.py index 751705552e9..c715335fbdf 100644 --- a/tests/io/test_packmol.py +++ b/tests/io/test_packmol.py @@ -12,11 +12,12 @@ from pymatgen.io.packmol import PackmolBoxGen from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest -# ruff: noqa: PT011 - - TEST_DIR = f"{TEST_FILES_DIR}/packmol" - +# error message is different in CI for unknown reasons (as of 2024-04-12) +# macOS: "Packmol failed with error code 173 and stderr: b'STOP 173\\n'" +# CI: "Packmol failed with return code 0 and stdout: Packmol was unable to +# put the molecules in the desired regions even without" +ERR_MSG_173 = "Packmol failed with " if which("packmol") is None: pytest.skip("packmol executable not present", allow_module_level=True) @@ -110,7 +111,7 @@ def test_control_params(self): input_string = file.read() assert "maxit 0" in input_string assert "nloop 0" in input_string - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=ERR_MSG_173): input_set.run(self.tmp_path) def test_timeout(self): @@ -141,7 +142,7 @@ def test_no_return_and_box(self): with open(f"{self.tmp_path}/packmol.inp") as file: input_string = file.read() assert "inside box 0 0 0 2 2 2" in input_string - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=ERR_MSG_173): pw.run(self.tmp_path) def test_chdir_behavior(self): @@ -159,7 +160,8 @@ def test_chdir_behavior(self): box=[0, 0, 0, 2, 2, 2], ) pw.write_input(self.tmp_path) - with pytest.raises(ValueError): + + with pytest.raises(ValueError, match=ERR_MSG_173): pw.run(self.tmp_path) assert str(Path.cwd()) == start_dir diff --git a/tests/io/test_phonopy.py b/tests/io/test_phonopy.py index eab1b3dd251..b8337d56b78 100644 --- a/tests/io/test_phonopy.py +++ b/tests/io/test_phonopy.py @@ -1,8 +1,8 @@ from __future__ import annotations import os -import unittest from pathlib import Path +from unittest import TestCase import numpy as np import pytest @@ -106,8 +106,8 @@ def test_structure_conversion(self): assert struct_pmg_round_trip.matches(struct_pmg) coords_ph = struct_ph.get_scaled_positions() - symbols_pmg = {e.symbol for e in struct_pmg.composition} - symbols_pmg2 = {e.symbol for e in struct_pmg_round_trip.composition} + symbols_pmg = {*map(str, struct_pmg.composition)} + symbols_pmg2 = {*map(str, struct_pmg_round_trip.composition)} assert struct_ph.get_cell()[1, 1] == approx(struct_pmg.lattice._matrix[1, 1], abs=1e-7) assert struct_pmg.lattice._matrix[1, 1] == approx(struct_pmg_round_trip.lattice._matrix[1, 1], abs=1e-7) @@ -156,7 +156,7 @@ def test_get_displaced_structures(self): @pytest.mark.skipif(Phonopy is None, reason="Phonopy not present") -class TestPhonopyFromForceConstants(unittest.TestCase): +class TestPhonopyFromForceConstants(TestCase): def setUp(self) -> None: test_path = Path(TEST_DIR) structure_file = test_path / "POSCAR-NaCl" @@ -206,7 +206,7 @@ def test_get_phonon_band_structure_symm_line_from_fc(self): assert bs.bands[2][10] == approx(2.869229797603161) -class TestGruneisen(unittest.TestCase): +class TestGruneisen: def test_ph_bs_symm_line(self): self.bs_symm_line_1 = get_gruneisen_ph_bs_symm_line( gruneisen_path=f"{TEST_FILES_DIR}/gruneisen/gruneisen_band_Si.yaml", diff --git a/tests/io/test_prismatic.py b/tests/io/test_prismatic.py index 81330adfcd2..4231888cc71 100644 --- a/tests/io/test_prismatic.py +++ b/tests/io/test_prismatic.py @@ -1,13 +1,11 @@ from __future__ import annotations -import unittest - from pymatgen.core.structure import Structure from pymatgen.io.prismatic import Prismatic from pymatgen.util.testing import TEST_FILES_DIR -class TestPrismatic(unittest.TestCase): +class TestPrismatic: def test_to_str(self): structure = Structure.from_file(f"{TEST_FILES_DIR}/CuCl.cif") prismatic = Prismatic(structure) diff --git a/tests/io/test_pwscf.py b/tests/io/test_pwscf.py index 50fb70c6755..373518967c4 100644 --- a/tests/io/test_pwscf.py +++ b/tests/io/test_pwscf.py @@ -364,25 +364,25 @@ def test_read_str(self): ] ) - pwin = PWInput.from_str(string) + pw_in = PWInput.from_str(string) # generate list of coords - pw_sites = np.array([list(site.coords) for site in pwin.structure]) + pw_sites = np.array([list(site.coords) for site in pw_in.structure]) assert_allclose(sites, pw_sites) - assert_allclose(lattice, pwin.structure.lattice.matrix) - assert pwin.sections["system"]["smearing"] == "cold" + assert_allclose(lattice, pw_in.structure.lattice.matrix) + assert pw_in.sections["system"]["smearing"] == "cold" class TestPWOuput(PymatgenTest): def setUp(self): - self.pwout = PWOutput(f"{TEST_FILES_DIR}/Si.pwscf.out") + self.pw_out = PWOutput(f"{TEST_FILES_DIR}/Si.pwscf.out") def test_properties(self): - assert self.pwout.final_energy == approx(-93.45259708) + assert self.pw_out.final_energy == approx(-93.45259708) def test_get_celldm(self): - assert self.pwout.get_celldm(1) == approx(10.323) + assert self.pw_out.get_celldm(1) == approx(10.323) for i in range(2, 7): - assert self.pwout.get_celldm(i) == approx(0) + assert self.pw_out.get_celldm(i) == approx(0) diff --git a/tests/io/test_res.py b/tests/io/test_res.py index 445cd10d5ba..3cb19218d81 100644 --- a/tests/io/test_res.py +++ b/tests/io/test_res.py @@ -52,6 +52,8 @@ def test_misc(self, provider: AirssProvider): date, path = rs_info assert path == "/path/to/airss/run" assert date.day == 16 + assert date.month == 7 + assert date.year == 2021 castep_v = provider.get_castep_version() assert castep_v == "19.11" diff --git a/tests/io/test_xcrysden.py b/tests/io/test_xcrysden.py index 88973b194bd..045a41e466b 100644 --- a/tests/io/test_xcrysden.py +++ b/tests/io/test_xcrysden.py @@ -1,28 +1,42 @@ from __future__ import annotations +import numpy as np + from pymatgen.core.structure import Structure from pymatgen.io.xcrysden import XSF from pymatgen.util.testing import PymatgenTest class TestXSF(PymatgenTest): - def test_xsf(self): - coords = [[0, 0, 0], [0.75, 0.5, 0.75]] - lattice = [ + def setUp(self): + self.coords = [[0, 0, 0], [0.75, 0.5, 0.75]] + self.lattice = [ [3.8401979337, 0.00, 0.00], [1.9200989668, 3.3257101909, 0.00], [0.00, -2.2171384943, 3.1355090603], ] - structure = Structure(lattice, ["Si", "Si"], coords) - xsf = XSF(structure) - assert structure, XSF.from_str(xsf.to_str()) + self.struct = Structure(self.lattice, ["Si", "Si"], self.coords) + + def test_xsf(self): + xsf = XSF(self.struct) + assert self.struct, XSF.from_str(xsf.to_str()) + xsf = XSF(self.struct) + assert self.struct, XSF.from_str(xsf.to_str()) + + def test_append_vect(self): + self.struct.add_site_property("vect", np.eye(2, 3)) + xsf_str = XSF(self.struct).to_str() + last_line_split = xsf_str.split("\n")[-1].split() + assert len(last_line_split) == 7 + assert last_line_split[-1] == "0.00000000000000" + assert last_line_split[-2] == "1.00000000000000" + assert last_line_split[-3] == "0.00000000000000" def test_to_str(self): structure = self.get_structure("Li2O") xsf = XSF(structure) - xsf_str = xsf.to_str() assert ( - xsf_str + xsf.to_str() == """CRYSTAL # Primitive lattice vectors in Angstrom PRIMVEC @@ -36,9 +50,9 @@ def test_to_str(self): Li 3.01213761017484 2.21364440998406 4.74632330032018 Li 1.00309136982516 0.73718000001594 1.58060372967982""" ) - xsf_str = xsf.to_str(atom_symbol=False) + assert ( - xsf_str + xsf.to_str(atom_symbol=False) == """CRYSTAL # Primitive lattice vectors in Angstrom PRIMVEC diff --git a/tests/io/test_xr.py b/tests/io/test_xr.py index 95f24422a84..2497e072b23 100644 --- a/tests/io/test_xr.py +++ b/tests/io/test_xr.py @@ -1,6 +1,6 @@ from __future__ import annotations -import unittest +from unittest import TestCase from pymatgen.core.structure import Structure from pymatgen.io.xr import Xr @@ -14,7 +14,7 @@ __date__ = "June 23, 2016" -class TestXr(unittest.TestCase): +class TestXr(TestCase): def setUp(self): struct = Structure.from_file(f"{VASP_IN_DIR}/POSCAR") self.xr = Xr(struct) diff --git a/tests/io/test_xyz.py b/tests/io/test_xyz.py index 6414e208245..79bc5d37896 100644 --- a/tests/io/test_xyz.py +++ b/tests/io/test_xyz.py @@ -1,6 +1,6 @@ from __future__ import annotations -import unittest +from unittest import TestCase import pandas as pd import pytest @@ -12,7 +12,7 @@ from pymatgen.util.testing import TEST_FILES_DIR, VASP_IN_DIR -class TestXYZ(unittest.TestCase): +class TestXYZ(TestCase): def setUp(self): coords = [ [0, 0, 0], diff --git a/tests/io/test_zeopp.py b/tests/io/test_zeopp.py index 196f1cbecba..dde217b7561 100644 --- a/tests/io/test_zeopp.py +++ b/tests/io/test_zeopp.py @@ -1,6 +1,7 @@ from __future__ import annotations import unittest +from unittest import TestCase import pytest from pytest import approx @@ -30,7 +31,7 @@ @pytest.mark.skipif(zeo is None, reason="zeo not present.") -class TestZeoCssr(unittest.TestCase): +class TestZeoCssr(TestCase): def setUp(self): filepath = f"{VASP_IN_DIR}/POSCAR" self.zeo_cssr = ZeoCssr(Structure.from_file(filepath)) @@ -73,7 +74,7 @@ def test_from_file(self): @pytest.mark.skipif(zeo is None, reason="zeo not present.") -class TestZeoCssrOxi(unittest.TestCase): +class TestZeoCssrOxi(TestCase): def setUp(self): filepath = f"{VASP_IN_DIR}/POSCAR" structure = BVAnalyzer().get_oxi_state_decorated_structure(Structure.from_file(filepath)) @@ -117,7 +118,7 @@ def test_from_file(self): @pytest.mark.skipif(zeo is None, reason="zeo not present.") -class TestZeoVoronoiXYZ(unittest.TestCase): +class TestZeoVoronoiXYZ(TestCase): def setUp(self): coords = [ [0.000000, 0.000000, 0.000000], @@ -148,7 +149,7 @@ def test_from_file(self): @pytest.mark.skipif(zeo is None, reason="zeo not present.") -class TestGetVoronoiNodes(unittest.TestCase): +class TestGetVoronoiNodes(TestCase): def setUp(self): filepath = f"{VASP_IN_DIR}/POSCAR" self.structure = Structure.from_file(filepath) @@ -157,8 +158,8 @@ def setUp(self): el = [site.species_string for site in self.structure] valence_dict = dict(zip(el, valences)) self.rad_dict = {} - for k, v in valence_dict.items(): - self.rad_dict[k] = float(Species(k, v).ionic_radius) + for key, val in valence_dict.items(): + self.rad_dict[key] = float(Species(key, val).ionic_radius) assert len(self.rad_dict) == len(self.structure.composition) @@ -172,18 +173,11 @@ def test_get_voronoi_nodes(self): @unittest.skip("file free_sph.cif not present") -class TestGetFreeSphereParams(unittest.TestCase): +class TestGetFreeSphereParams(TestCase): def setUp(self): filepath = f"{TEST_FILES_DIR}/free_sph.cif" self.structure = Structure.from_file(filepath) - self.rad_dict = { - "Ge": 0.67, - "P": 0.52, - "S": 1.7, - "La": 1.17, - "Zr": 0.86, - "O": 1.26, - } + self.rad_dict = {"Ge": 0.67, "P": 0.52, "S": 1.7, "La": 1.17, "Zr": 0.86, "O": 1.26} def test_get_free_sphere_params(self): free_sph_params = get_free_sphere_params(self.structure, rad_dict=self.rad_dict) @@ -194,7 +188,7 @@ def test_get_free_sphere_params(self): @pytest.mark.skipif(zeo is None, reason="zeo not present.") -class TestGetHighAccuracyVoronoiNodes(unittest.TestCase): +class TestGetHighAccuracyVoronoiNodes(TestCase): def setUp(self): filepath = f"{VASP_IN_DIR}/POSCAR" self.structure = Structure.from_file(filepath) @@ -214,7 +208,7 @@ def test_get_voronoi_nodes(self): @pytest.mark.skipif(zeo is None, reason="zeo not present.") -class TestGetVoronoiNodesMultiOxi(unittest.TestCase): +class TestGetVoronoiNodesMultiOxi(TestCase): def setUp(self): filepath = f"{VASP_IN_DIR}/POSCAR" self.structure = Structure.from_file(filepath) diff --git a/tests/io/vasp/test_inputs.py b/tests/io/vasp/test_inputs.py index 579059782cc..a1bc5c2c5ef 100644 --- a/tests/io/vasp/test_inputs.py +++ b/tests/io/vasp/test_inputs.py @@ -5,8 +5,8 @@ import os import pickle import re -import unittest from shutil import copyfile +from unittest import TestCase import numpy as np import pytest @@ -212,12 +212,12 @@ def test_significant_figures(self): coords = [[0, 0, 0], [0.75, 0.5, 0.75]] # Silicon structure for testing. - latt = [ + lattice = [ [3.8401979337, 0.00, 0.00], [1.9200989668, 3.3257101909, 0.00], [0.00, -2.2171384943, 3.1355090603], ] - struct = Structure(latt, [si, si], coords) + struct = Structure(lattice, [si, si], coords) poscar = Poscar(struct) expected_str = """Si2 1.0 @@ -239,12 +239,12 @@ def test_str(self): coords = [[0, 0, 0], [0.75, 0.5, 0.75]] # Silicon structure for testing. - latt = [ + lattice = [ [3.8401979337, 0.00, 0.00], [1.9200989668, 3.3257101909, 0.00], [0.00, -2.2171384943, 3.1355090603], ] - struct = Structure(latt, [si, si], coords) + struct = Structure(lattice, [si, si], coords) poscar = Poscar(struct) expected_str = """Si2 1.0 @@ -384,12 +384,12 @@ def test_velocities(self): coords = [[0, 0, 0], [0.75, 0.5, 0.75]] # Silicon structure for testing. - latt = [ + lattice = [ [3.8401979337, 0.00, 0.00], [1.9200989668, 3.3257101909, 0.00], [0.00, -2.2171384943, 3.1355090603], ] - struct = Structure(latt, [si, si], coords) + struct = Structure(lattice, [si, si], coords) poscar = Poscar(struct) poscar.set_temperature(900) @@ -644,7 +644,7 @@ def test_write(self): assert incar == self.incar def test_get_str(self): - s = self.incar.get_str(pretty=True, sort_keys=True) + incar_str = self.incar.get_str(pretty=True, sort_keys=True) expected = """ALGO = Damped EDIFF = 0.0001 ENCUT = 500 @@ -673,7 +673,7 @@ def test_get_str(self): SIGMA = 0.05 SYSTEM = Id=[0] dblock_code=[97763-icsd] formula=[li mn (p o4)] sg_name=[p n m a] TIME = 0.4""" - assert s == expected + assert incar_str == expected def test_lsorbit_magmom(self): magmom1 = [[0.0, 0.0, 3.0], [0, 1, 0], [2, 1, 2]] @@ -747,16 +747,16 @@ def test_types(self): NBMOD = -3 PREC = Accurate SIGMA = 0.1""" - i = Incar.from_str(incar_str) - assert isinstance(i["EINT"], list) - assert i["EINT"][0] == -0.85 + incar = Incar.from_str(incar_str) + assert isinstance(incar["EINT"], list) + assert incar["EINT"][0] == -0.85 incar_str += "\nLHFCALC = .TRUE. ; HFSCREEN = 0.2" incar_str += "\nALGO = All;" - i = Incar.from_str(incar_str) - assert i["LHFCALC"] - assert i["HFSCREEN"] == 0.2 - assert i["ALGO"] == "All" + incar = Incar.from_str(incar_str) + assert incar["LHFCALC"] + assert incar["HFSCREEN"] == 0.2 + assert incar["ALGO"] == "All" def test_proc_types(self): assert Incar.proc_val("HELLO", "-0.85 0.85") == "-0.85 0.85" @@ -814,12 +814,12 @@ def test_init(self): filepath = f"{VASP_IN_DIR}/KPOINTS_cartesian" kpoints = Kpoints.from_file(filepath) assert kpoints.kpts == [[0.25, 0, 0], [0, 0.25, 0], [0, 0, 0.25]], "Wrong kpoint lattice read" - assert kpoints.kpts_shift == [0.5, 0.5, 0.5], "Wrong kpoint shift read" + assert kpoints.kpts_shift == (0.5, 0.5, 0.5) filepath = f"{VASP_IN_DIR}/KPOINTS" kpoints = Kpoints.from_file(filepath) self.kpoints = kpoints - assert kpoints.kpts == [[2, 4, 6]] + assert kpoints.kpts == [(2, 4, 6)] filepath = f"{VASP_IN_DIR}/KPOINTS_band" kpoints = Kpoints.from_file(filepath) @@ -997,7 +997,7 @@ def test_automatic_monkhorst_vs_gamma_style_selection(self): assert kpoints.style == Kpoints.supported_modes.Gamma -class TestPotcarSingle(unittest.TestCase): +class TestPotcarSingle(TestCase): def setUp(self): self.psingle_Mn_pv = PotcarSingle.from_file(f"{FAKE_POTCAR_DIR}/POT_GGA_PAW_PBE/POTCAR.Mn_pv.gz") self.psingle_Fe = PotcarSingle.from_file(f"{FAKE_POTCAR_DIR}/POT_GGA_PAW_PBE/POTCAR.Fe.gz") @@ -1151,11 +1151,11 @@ def test_multi_potcar_with_and_without_sha256(self): assert psingle.is_valid # def test_default_functional(self): - # p = PotcarSingle.from_symbol_and_functional("Fe") - # assert p.functional_class == "GGA" + # potcar = PotcarSingle.from_symbol_and_functional("Fe") + # assert potcar.functional_class == "GGA" # SETTINGS["PMG_DEFAULT_FUNCTIONAL"] = "LDA" - # p = PotcarSingle.from_symbol_and_functional("Fe") - # assert p.functional_class == "LDA" + # potcar = PotcarSingle.from_symbol_and_functional("Fe") + # assert potcar.functional_class == "LDA" # SETTINGS["PMG_DEFAULT_FUNCTIONAL"] = "PBE" def test_repr(self): @@ -1225,8 +1225,8 @@ def test_as_from_dict(self): def test_write(self): tmp_file = f"{self.tmp_path}/POTCAR.testing" self.potcar.write_file(tmp_file) - p = Potcar.from_file(tmp_file) - assert p.symbols == self.potcar.symbols + potcar = Potcar.from_file(tmp_file) + assert potcar.symbols == self.potcar.symbols with zopen(self.filepath, mode="rt", encoding="utf-8") as f_ref, open(tmp_file, encoding="utf-8") as f_new: ref_potcar = f_ref.readlines() @@ -1246,13 +1246,13 @@ def test_set_symbol(self): assert self.potcar[0].nelectrons == 14 # def test_default_functional(self): - # p = Potcar(["Fe", "P"]) - # assert p[0].functional_class == "GGA" - # assert p[1].functional_class == "GGA" + # potcar = Potcar(["Fe", "P"]) + # assert potcar[0].functional_class == "GGA" + # assert potcar[1].functional_class == "GGA" # SETTINGS["PMG_DEFAULT_FUNCTIONAL"] = "LDA" - # p = Potcar(["Fe", "P"]) - # assert p[0].functional_class == "LDA" - # assert p[1].functional_class == "LDA" + # potcar = Potcar(["Fe", "P"]) + # assert potcar[0].functional_class == "LDA" + # assert potcar[1].functional_class == "LDA" def test_pickle(self): pickle.dumps(self.potcar) @@ -1267,8 +1267,7 @@ def setUp(self): incar = Incar.from_file(filepath) filepath = f"{VASP_IN_DIR}/POSCAR" poscar = Poscar.from_file(filepath, check_for_potcar=False) - if "PMG_VASP_PSP_DIR" not in os.environ: - os.environ["PMG_VASP_PSP_DIR"] = str(TEST_FILES_DIR) + os.environ.setdefault("PMG_VASP_PSP_DIR", str(TEST_FILES_DIR)) filepath = f"{FAKE_POTCAR_DIR}/POTCAR.gz" potcar = Potcar.from_file(filepath) filepath = f"{VASP_IN_DIR}/KPOINTS_auto" diff --git a/tests/io/vasp/test_outputs.py b/tests/io/vasp/test_outputs.py index 438315e9c98..7fa37307090 100644 --- a/tests/io/vasp/test_outputs.py +++ b/tests/io/vasp/test_outputs.py @@ -4,10 +4,10 @@ import json import os import sys -import xml.etree.ElementTree as ElementTree from io import StringIO from pathlib import Path from shutil import copyfile, copyfileobj +from xml.etree import ElementTree import numpy as np import pytest @@ -171,12 +171,12 @@ def test_energies(self): assert vasp_run.final_energy == approx(-11.18986774) # VASP 5.4.1 - o = Vasprun(f"{VASP_OUT_DIR}/vasprun.etest3.xml.gz") - assert o.final_energy == approx(-15.89355325) + vasp_run = Vasprun(f"{VASP_OUT_DIR}/vasprun.etest3.xml.gz") + assert vasp_run.final_energy == approx(-15.89355325) # VASP 6.2.1 - o = Vasprun(f"{VASP_OUT_DIR}/vasprun.etest4.xml.gz") - assert o.final_energy == approx(-15.89364691) + vasp_run = Vasprun(f"{VASP_OUT_DIR}/vasprun.etest4.xml.gz") + assert vasp_run.final_energy == approx(-15.89364691) def test_nonlmn(self): filepath = f"{VASP_OUT_DIR}/vasprun.nonlm.xml.gz" @@ -240,7 +240,8 @@ def test_standard(self): assert vasp_run.structures[i] == step["structure"] assert all( - vasp_run.structures[i] == vasp_run.ionic_steps[i]["structure"] for i in range(len(vasp_run.ionic_steps)) + vasp_run.structures[idx] == vasp_run.ionic_steps[idx]["structure"] + for idx in range(len(vasp_run.ionic_steps)) ) assert total_sc_steps == 308, "Incorrect number of energies read from vasprun.xml" @@ -1023,11 +1024,10 @@ def test_read_lcalcpol(self): p_sp1 = [2.01124, 2.01124, -2.04426] p_sp2 = [2.01139, 2.01139, -2.04426] - for i in range(3): - assert outcar.p_ion[i] == approx(p_ion[i]) - assert outcar.p_elec[i] == approx(p_elec[i]) - assert outcar.p_sp1[i] == approx(p_sp1[i]) - assert outcar.p_sp2[i] == approx(p_sp2[i]) + assert outcar.p_ion == approx(p_ion) + assert outcar.p_elec == approx(p_elec) + assert outcar.p_sp1 == approx(p_sp1) + assert outcar.p_sp2 == approx(p_sp2) # outcar with |e| Angst units filepath = f"{VASP_OUT_DIR}/OUTCAR_vasp_6.3.gz" @@ -1040,11 +1040,10 @@ def test_read_lcalcpol(self): p_sp1 = [4.50564, 0.0, 1.62154] p_sp2 = [4.50563e00, -1.00000e-05, 1.62154e00] - for i in range(3): - assert outcar.p_ion[i] == approx(p_ion[i]) - assert outcar.p_elec[i] == approx(p_elec[i]) - assert outcar.p_sp1[i] == approx(p_sp1[i]) - assert outcar.p_sp2[i] == approx(p_sp2[i]) + assert outcar.p_ion == approx(p_ion) + assert outcar.p_elec == approx(p_elec) + assert outcar.p_sp1 == approx(p_sp1) + assert outcar.p_sp2 == approx(p_sp2) def test_read_piezo_tensor(self): filepath = f"{VASP_OUT_DIR}/OUTCAR.lepsilon.gz" @@ -1306,28 +1305,28 @@ def test_vasp620_format(self): def test_energies(self): # VASP 5.2.1 - o = Outcar(f"{VASP_OUT_DIR}/OUTCAR.etest1.gz") - assert o.final_energy == approx(-11.18981538) - assert o.final_energy_wo_entrp == approx(-11.13480014) - assert o.final_fr_energy == approx(-11.21732300) + outcar = Outcar(f"{VASP_OUT_DIR}/OUTCAR.etest1.gz") + assert outcar.final_energy == approx(-11.18981538) + assert outcar.final_energy_wo_entrp == approx(-11.13480014) + assert outcar.final_fr_energy == approx(-11.21732300) # VASP 6.2.1 - o = Outcar(f"{VASP_OUT_DIR}/OUTCAR.etest2.gz") - assert o.final_energy == approx(-11.18986774) - assert o.final_energy_wo_entrp == approx(-11.13485250) - assert o.final_fr_energy == approx(-11.21737536) + outcar = Outcar(f"{VASP_OUT_DIR}/OUTCAR.etest2.gz") + assert outcar.final_energy == approx(-11.18986774) + assert outcar.final_energy_wo_entrp == approx(-11.13485250) + assert outcar.final_fr_energy == approx(-11.21737536) # VASP 5.2.1 - o = Outcar(f"{VASP_OUT_DIR}/OUTCAR.etest3.gz") - assert o.final_energy == approx(-15.89355325) - assert o.final_energy_wo_entrp == approx(-15.83853800) - assert o.final_fr_energy == approx(-15.92106087) + outcar = Outcar(f"{VASP_OUT_DIR}/OUTCAR.etest3.gz") + assert outcar.final_energy == approx(-15.89355325) + assert outcar.final_energy_wo_entrp == approx(-15.83853800) + assert outcar.final_fr_energy == approx(-15.92106087) # VASP 6.2.1 - o = Outcar(f"{VASP_OUT_DIR}/OUTCAR.etest4.gz") - assert o.final_energy == approx(-15.89364691) - assert o.final_energy_wo_entrp == approx(-15.83863167) - assert o.final_fr_energy == approx(-15.92115453) + outcar = Outcar(f"{VASP_OUT_DIR}/OUTCAR.etest4.gz") + assert outcar.final_energy == approx(-15.89364691) + assert outcar.final_energy_wo_entrp == approx(-15.83863167) + assert outcar.final_fr_energy == approx(-15.92115453) def test_read_table_pattern(self): outcar = Outcar(f"{VASP_OUT_DIR}/OUTCAR.gz") @@ -1997,16 +1996,15 @@ def _check(wder): first_line = [int(a) for a in file.readline().split()] assert wder.nkpoints == first_line[1] assert wder.nbands == first_line[2] - for i in range(10): - assert wder.get_orbital_derivative_between_states(0, i, 0, 0, 0).real == approx( - wder_ref[i, 6], abs=1e-10 - ) - assert wder.cder[0, i, 0, 0, 0].real == approx(wder_ref[i, 6], abs=1e-10) - assert wder.cder[0, i, 0, 0, 0].imag == approx(wder_ref[i, 7], abs=1e-10) - assert wder.cder[0, i, 0, 0, 1].real == approx(wder_ref[i, 8], abs=1e-10) - assert wder.cder[0, i, 0, 0, 1].imag == approx(wder_ref[i, 9], abs=1e-10) - assert wder.cder[0, i, 0, 0, 2].real == approx(wder_ref[i, 10], abs=1e-10) - assert wder.cder[0, i, 0, 0, 2].imag == approx(wder_ref[i, 11], abs=1e-10) + assert [wder.get_orbital_derivative_between_states(0, idx, 0, 0, 0).real for idx in range(10)] == approx( + wder_ref[:10, 6], abs=1e-10 + ) + assert wder.cder[0, :10, 0, 0, 0].real == approx(wder_ref[:10, 6], abs=1e-10) + assert wder.cder[0, :10, 0, 0, 0].imag == approx(wder_ref[:10, 7], abs=1e-10) + assert wder.cder[0, :10, 0, 0, 1].real == approx(wder_ref[:10, 8], abs=1e-10) + assert wder.cder[0, :10, 0, 0, 1].imag == approx(wder_ref[:10, 9], abs=1e-10) + assert wder.cder[0, :10, 0, 0, 2].real == approx(wder_ref[:10, 10], abs=1e-10) + assert wder.cder[0, :10, 0, 0, 2].imag == approx(wder_ref[:10, 11], abs=1e-10) wder = Waveder.from_binary(f"{VASP_OUT_DIR}/WAVEDER.Si") _check(wder) diff --git a/tests/io/vasp/test_sets.py b/tests/io/vasp/test_sets.py index ba8bf31dec6..9d7003591a8 100644 --- a/tests/io/vasp/test_sets.py +++ b/tests/io/vasp/test_sets.py @@ -56,8 +56,6 @@ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from pymatgen.util.testing import FAKE_POTCAR_DIR, TEST_FILES_DIR, VASP_IN_DIR, VASP_OUT_DIR, PymatgenTest -dec = MontyDecoder() - MonkeyPatch().setitem(SETTINGS, "PMG_VASP_PSP_DIR", str(FAKE_POTCAR_DIR)) NO_PSP_DIR = SETTINGS.get("PMG_VASP_PSP_DIR") is None @@ -276,8 +274,8 @@ def test_potcar_special_defaults(self): @skip_if_no_psp_dir def test_lda_potcar(self): structure = Structure(self.lattice, ["P", "Fe"], self.coords) - p = self.set(structure, user_potcar_functional="LDA").potcar - assert p.functional == "LDA" + potcar = self.set(structure, user_potcar_functional="LDA").potcar + assert potcar.functional == "LDA" @skip_if_no_psp_dir def test_nelect(self): @@ -528,15 +526,15 @@ def test_as_from_dict(self): ) dct = mit_set.as_dict() - val = dec.process_decoded(dct) + val = MontyDecoder().process_decoded(dct) assert val._config_dict["INCAR"]["LDAUU"]["O"]["Fe"] == 4 dct = mp_set.as_dict() - val = dec.process_decoded(dct) + val = MontyDecoder().process_decoded(dct) assert val._config_dict["INCAR"]["LDAUU"]["O"]["Fe"] == 5.3 dct = mp_user_set.as_dict() - val = dec.process_decoded(dct) + val = MontyDecoder().process_decoded(dct) assert isinstance(val, VaspInputSet) assert val.user_incar_settings["MAGMOM"] == {"Fe": 10, "S": -5, "Mn3+": 100} @@ -1080,15 +1078,15 @@ def test_structure_from_prev_run(self): vrun = Vasprun(f"{VASP_OUT_DIR}/vasprun.magmom_ldau.xml.gz") structure = vrun.final_structure poscar = Poscar(structure) - structure_decorated = get_structure_from_prev_run(vrun) + struct_magmom_decorated = get_structure_from_prev_run(vrun) ldau_ans = {"LDAUU": [5.3, 0.0], "LDAUL": [2, 0], "LDAUJ": [0.0, 0.0]} magmom_ans = [5.0, 5.0, 5.0, 5.0, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6] ldau_dict = {} for key in ("LDAUU", "LDAUJ", "LDAUL"): - if hasattr(structure_decorated[0], key.lower()): - m = {site.specie.symbol: getattr(site, key.lower()) for site in structure_decorated} - ldau_dict[key] = [m[sym] for sym in poscar.site_symbols] - magmom = [site.magmom for site in structure_decorated] + if hasattr(struct_magmom_decorated[0], key.lower()): + magmoms = {site.specie.symbol: getattr(site, key.lower()) for site in struct_magmom_decorated} + ldau_dict[key] = [magmoms[sym] for sym in poscar.site_symbols] + magmom = [site.magmom for site in struct_magmom_decorated] assert ldau_dict == ldau_ans assert magmom == magmom_ans @@ -1134,7 +1132,7 @@ def test_params(self): def test_as_from_dict(self): dct = self.mit_md_param.as_dict() - input_set = dec.process_decoded(dct) + input_set = MontyDecoder().process_decoded(dct) assert isinstance(input_set, self.set) assert input_set.incar["TEBEG"] == 300 assert input_set.incar["TEEND"] == 1200 @@ -1183,7 +1181,7 @@ def test_incar(self): def test_as_from_dict(self): dct = self.mvl_npt_set.as_dict() - input_set = dec.process_decoded(dct) + input_set = MontyDecoder().process_decoded(dct) assert isinstance(input_set, MVLNPTMDSet) assert input_set.incar["NSW"] == 1000 @@ -1225,7 +1223,7 @@ def test_incar_ts1(self): def test_as_from_dict(self): dct = self.mp_md_set_noTS.as_dict() - v = dec.process_decoded(dct) + v = MontyDecoder().process_decoded(dct) assert isinstance(v, MPMDSet) assert v.incar["NSW"] == 1000 @@ -1262,7 +1260,7 @@ def test_kpoints(self): def test_as_from_dict(self): dct = self.vis.as_dict() - v = dec.process_decoded(dct) + v = MontyDecoder().process_decoded(dct) assert v.incar["IMAGES"] == 2 @skip_if_no_psp_dir @@ -1270,6 +1268,8 @@ def test_write_input(self): self.vis.write_input(".", write_cif=True, write_endpoint_inputs=True, write_path_cif=True) for file in "INCAR KPOINTS POTCAR 00/POSCAR 01/POSCAR 02/POSCAR 03/POSCAR 00/INCAR path.cif".split(): assert os.path.isfile(file), f"{file=} not written" + # check structures match + assert len(self.vis.structures[0]) + 3 == len(Structure.from_file("path.cif")) assert not os.path.isfile("04/POSCAR") @@ -1608,7 +1608,7 @@ def test_potcar(self): def test_as_from_dict(self): dct = self.mvl_scan_set.as_dict() - v = dec.process_decoded(dct) + v = MontyDecoder().process_decoded(dct) assert isinstance(v, self.set) assert v.incar["METAGGA"] == "Scan" assert v.user_incar_settings["NSW"] == 500 @@ -1712,7 +1712,7 @@ def test_potcar(self): def test_as_from_dict(self): dct = self.mp_scan_set.as_dict() - input_set = dec.process_decoded(dct) + input_set = MontyDecoder().process_decoded(dct) assert isinstance(input_set, MPScanRelaxSet) assert input_set._config_dict["INCAR"]["METAGGA"] == "R2SCAN" assert input_set.user_incar_settings["NSW"] == 500 @@ -1882,7 +1882,7 @@ def test_potcar(self): def test_as_from_dict(self): dct = self.mvl_rlx_set.as_dict() - vasp_input = dec.process_decoded(dct) + vasp_input = MontyDecoder().process_decoded(dct) assert isinstance(vasp_input, self.set) assert vasp_input.incar["NSW"] == 500 @@ -2061,7 +2061,7 @@ def test_as_from_dict(self): prev_run = f"{TEST_FILES_DIR}/vasp/fixtures/absorption/static" absorption_ipa = MPAbsorptionSet.from_prev_calc(prev_calc_dir=prev_run, mode="IPA") dct = absorption_ipa.as_dict() - vasp_input = dec.process_decoded(dct) + vasp_input = MontyDecoder().process_decoded(dct) assert vasp_input.incar["ALGO"] == "Exact" assert vasp_input.incar["LOPTICS"] assert vasp_input.incar["GGA"] == "Ps" @@ -2069,7 +2069,7 @@ def test_as_from_dict(self): prev_run = f"{TEST_FILES_DIR}/vasp/fixtures/absorption/ipa" absorption_rpa = MPAbsorptionSet.from_prev_calc(prev_run, mode="RPA") dct = absorption_rpa.as_dict() - vasp_input = dec.process_decoded(dct) + vasp_input = MontyDecoder().process_decoded(dct) assert vasp_input.incar["ALGO"] == "Chi" assert vasp_input.incar["NBANDS"] == 48 assert vasp_input.incar["GGA"] == "Ps" diff --git a/tests/optimization/test_linear_assignment.py b/tests/optimization/test_linear_assignment.py index 508388ec4d5..961a5032b6d 100644 --- a/tests/optimization/test_linear_assignment.py +++ b/tests/optimization/test_linear_assignment.py @@ -1,6 +1,6 @@ from __future__ import annotations -import unittest +from unittest import TestCase import numpy as np import pytest @@ -9,7 +9,7 @@ from pymatgen.optimization.linear_assignment import LinearAssignment -class TestLinearAssignment(unittest.TestCase): +class TestLinearAssignment(TestCase): def test(self): w0 = np.array( [ diff --git a/tests/optimization/test_neighbors.py b/tests/optimization/test_neighbors.py index 49f8a5cae33..a6374c1f929 100644 --- a/tests/optimization/test_neighbors.py +++ b/tests/optimization/test_neighbors.py @@ -7,7 +7,7 @@ from pymatgen.util.testing import PymatgenTest -class NeighborsTestCase(PymatgenTest): +class TestNeighbors(PymatgenTest): def setUp(self): self.lattice = Lattice.cubic(10.0) self.cubic = self.lattice diff --git a/tests/phonon/test_dos.py b/tests/phonon/test_dos.py index 893616421ef..5cac5e8f0e5 100644 --- a/tests/phonon/test_dos.py +++ b/tests/phonon/test_dos.py @@ -45,7 +45,7 @@ def test_get_smeared_densities(self): def test_dict_methods(self): json_str = json.dumps(self.dos.as_dict()) - assert json_str is not None + assert json_str.startswith('{"@module": "pymatgen.phonon.dos", "@class": "PhononDos", "frequencies":') self.assert_msonable(self.dos) def test_thermodynamic_functions(self): diff --git a/tests/phonon/test_init.py b/tests/phonon/test_init.py index f454409568d..f2339b57c21 100644 --- a/tests/phonon/test_init.py +++ b/tests/phonon/test_init.py @@ -2,9 +2,8 @@ import pymatgen.phonon as ph import pymatgen.phonon.bandstructure as bs -import pymatgen.phonon.dos as dos import pymatgen.phonon.gruneisen as gru -import pymatgen.phonon.plotter as plotter +from pymatgen.phonon import dos, plotter def test_convenience_imports(): diff --git a/tests/phonon/test_plotter.py b/tests/phonon/test_plotter.py index 756312ff91a..48d2cac7447 100644 --- a/tests/phonon/test_plotter.py +++ b/tests/phonon/test_plotter.py @@ -1,24 +1,26 @@ from __future__ import annotations import json -import unittest +from unittest import TestCase -from matplotlib import axes, rc +import matplotlib.pyplot as plt +import pytest +from numpy.testing import assert_allclose -from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine -from pymatgen.phonon.dos import CompletePhononDos +from pymatgen.phonon import CompletePhononDos, PhononBandStructureSymmLine from pymatgen.phonon.plotter import PhononBSPlotter, PhononDosPlotter, ThermoPlotter from pymatgen.util.testing import TEST_FILES_DIR -rc("text", usetex=False) # Disabling latex for testing +plt.rc("text", usetex=False) # Disabling latex for testing -class TestPhononDosPlotter(unittest.TestCase): +class TestPhononDosPlotter(TestCase): def setUp(self): with open(f"{TEST_FILES_DIR}/NaCl_complete_ph_dos.json") as file: self.dos = CompletePhononDos.from_dict(json.load(file)) - self.plotter = PhononDosPlotter(sigma=0.2, stack=True) - self.plotter_no_stack = PhononDosPlotter(sigma=0.2, stack=False) + self.plotter = PhononDosPlotter(sigma=0.2, stack=True) + self.plotter_no_stack = PhononDosPlotter(sigma=0.2, stack=False) + self.plotter_no_sigma = PhononDosPlotter(sigma=None, stack=False) def test_add_dos_dict(self): dct = self.plotter.get_dos_dict() @@ -30,39 +32,49 @@ def test_add_dos_dict(self): def test_get_dos_dict(self): self.plotter.add_dos_dict(self.dos.get_element_dos(), key_sort_func=lambda x: x.X) dct = self.plotter.get_dos_dict() - for el in ["Na", "Cl"]: - assert el in dct + assert {*dct} >= {"Na", "Cl"} def test_plot(self): self.plotter.add_dos("Total", self.dos) self.plotter.get_plot(units="mev") self.plotter_no_stack.add_dos("Total", self.dos) ax = self.plotter_no_stack.get_plot(units="mev") - assert isinstance(ax, axes.Axes) + assert isinstance(ax, plt.Axes) assert ax.get_ylabel() == "$\\mathrm{Density\\ of\\ states}$" assert ax.get_xlabel() == "$\\mathrm{Frequencies\\ (meV)}$" + self.plotter_no_sigma.add_dos("Total", self.dos) + ax2 = self.plotter_no_sigma.get_plot(units="mev") + assert_allclose(ax2.get_ylim(), (min(self.dos.densities), max(self.dos.densities))) + ax3 = self.plotter_no_sigma.get_plot(units="mev", invert_axes=True) + assert ax3.get_ylabel() == "$\\mathrm{Frequencies\\ (meV)}$" + assert ax3.get_xlabel() == "$\\mathrm{Density\\ of\\ states}$" + assert_allclose(ax3.get_xlim(), (min(self.dos.densities), max(self.dos.densities))) + assert ax3.get_ylim() == ax.get_xlim() -class TestPhononBSPlotter(unittest.TestCase): +class TestPhononBSPlotter(TestCase): def setUp(self): with open(f"{TEST_FILES_DIR}/NaCl_phonon_bandstructure.json") as file: dct = json.loads(file.read()) - self.bs = PhononBandStructureSymmLine.from_dict(dct) - self.plotter = PhononBSPlotter(self.bs) + self.bs = PhononBandStructureSymmLine.from_dict(dct) + self.plotter = PhononBSPlotter(self.bs, label="NaCl") with open(f"{TEST_FILES_DIR}/SrTiO3_phonon_bandstructure.json") as file: dct = json.loads(file.read()) - self.bs_sto = PhononBandStructureSymmLine.from_dict(dct) - self.plotter_sto = PhononBSPlotter(self.bs_sto) + self.bs_sto = PhononBandStructureSymmLine.from_dict(dct) + self.plotter_sto = PhononBSPlotter(self.bs_sto) def test_bs_plot_data(self): assert len(self.plotter.bs_plot_data()["distances"][0]) == 51, "wrong number of distances in the first branch" assert len(self.plotter.bs_plot_data()["distances"]) == 4, "wrong number of branches" - assert sum(len(e) for e in self.plotter.bs_plot_data()["distances"]) == 204, "wrong number of distances" + assert sum(len(dist) for dist in self.plotter.bs_plot_data()["distances"]) == 204, "wrong number of distances" assert self.plotter.bs_plot_data()["ticks"]["label"][4] == "Y", "wrong tick label" assert len(self.plotter.bs_plot_data()["ticks"]["label"]) == 8, "wrong number of tick labels" def test_plot(self): - self.plotter.get_plot(units="mev") + ax = self.plotter.get_plot(units="mev") + assert isinstance(ax, plt.Axes) + assert ax.get_ylabel() == "$\\mathrm{Frequencies\\ (meV)}$" + assert ax.get_xlabel() == "$\\mathrm{Wave\\ Vector}$" def test_proj_plot(self): self.plotter.get_proj_plot(units="mev") @@ -77,26 +89,42 @@ def test_proj_plot(self): def test_plot_compare(self): labels = ("NaCl", "NaCl 2") - ax = self.plotter.plot_compare(self.plotter, units="mev", labels=labels) - assert isinstance(ax, axes.Axes) + ax = self.plotter.plot_compare({labels[1]: self.plotter}, units="mev") + assert isinstance(ax, plt.Axes) assert ax.get_ylabel() == "$\\mathrm{Frequencies\\ (meV)}$" assert ax.get_xlabel() == "$\\mathrm{Wave\\ Vector}$" assert ax.get_title() == "" assert [itm.get_text() for itm in ax.get_legend().get_texts()] == list(labels) + ax = self.plotter.plot_compare(self.plotter, units="mev") + assert [itm.get_text() for itm in ax.get_legend().get_texts()] == ["NaCl", "NaCl"] + labels = ("NaCl", "NaCl 2", "NaCl 3") + ax = self.plotter.plot_compare({labels[1]: self.plotter, labels[2]: self.plotter}, units="mev") + assert [itm.get_text() for itm in ax.get_legend().get_texts()] == list(labels) + colors = tuple([itm.get_color() for itm in ax.get_legend().get_lines()]) + assert colors == ("blue", "red", "green") + with pytest.raises(ValueError, match="The two band structures are not compatible."): + self.plotter.plot_compare(self.plotter_sto) + ax = self.plotter.plot_compare(self.plotter_sto, on_incompatible="ignore") + assert ax is None -class TestThermoPlotter(unittest.TestCase): +class TestThermoPlotter(TestCase): def setUp(self): with open(f"{TEST_FILES_DIR}/NaCl_complete_ph_dos.json") as file: self.dos = CompletePhononDos.from_dict(json.load(file)) - self.plotter = ThermoPlotter(self.dos, self.dos.structure) + self.plotter = ThermoPlotter(self.dos, self.dos.structure) def test_plot_functions(self): - self.plotter.plot_cv(5, 100, 5, show=False) - self.plotter.plot_entropy(5, 100, 5, show=False) - self.plotter.plot_internal_energy(5, 100, 5, show=False) - self.plotter.plot_helmholtz_free_energy(5, 100, 5, show=False) - self.plotter.plot_thermodynamic_properties(5, 100, 5, show=False, fig_close=True) + fig = self.plotter.plot_cv(5, 100, 5, show=False) + assert isinstance(fig, plt.Figure) + fig = self.plotter.plot_entropy(5, 100, 5, show=False) + assert isinstance(fig, plt.Figure) + fig = self.plotter.plot_internal_energy(5, 100, 5, show=False) + assert isinstance(fig, plt.Figure) + fig = self.plotter.plot_helmholtz_free_energy(5, 100, 5, show=False) + assert isinstance(fig, plt.Figure) + fig = self.plotter.plot_thermodynamic_properties(5, 100, 5, show=False, fig_close=True) + assert isinstance(fig, plt.Figure) # Gruneisen plotter is already tested in test_gruneisen diff --git a/tests/symmetry/test_analyzer.py b/tests/symmetry/test_analyzer.py index ffe76bb8ed8..bb101146520 100644 --- a/tests/symmetry/test_analyzer.py +++ b/tests/symmetry/test_analyzer.py @@ -1,6 +1,6 @@ from __future__ import annotations -import unittest +from unittest import TestCase import numpy as np import pytest @@ -359,26 +359,26 @@ def test_tricky_structure(self): # for some reason this structure kills spglib1.9 # 1.7 can't find symmetry either, but at least doesn't kill python struct = Structure.from_file(f"{VASP_IN_DIR}/POSCAR_tricky_symmetry") - sa = SpacegroupAnalyzer(struct, 0.1) - assert sa.get_space_group_symbol() == "I4/mmm" - assert sa.get_space_group_number() == 139 - assert sa.get_point_group_symbol() == "4/mmm" - assert sa.get_crystal_system() == "tetragonal" - assert sa.get_hall() == "-I 4 2" + spg_analyzer = SpacegroupAnalyzer(struct, 0.1) + assert spg_analyzer.get_space_group_symbol() == "I4/mmm" + assert spg_analyzer.get_space_group_number() == 139 + assert spg_analyzer.get_point_group_symbol() == "4/mmm" + assert spg_analyzer.get_crystal_system() == "tetragonal" + assert spg_analyzer.get_hall() == "-I 4 2" -class TestSpacegroup(unittest.TestCase): +class TestSpacegroup(TestCase): def setUp(self): self.structure = Structure.from_file(f"{VASP_IN_DIR}/POSCAR") self.sg1 = SpacegroupAnalyzer(self.structure, 0.001).get_space_group_operations() def test_are_symmetrically_equivalent(self): - sites1 = [self.structure[i] for i in [0, 1]] - sites2 = [self.structure[i] for i in [2, 3]] + sites1 = [self.structure[idx] for idx in [0, 1]] + sites2 = [self.structure[idx] for idx in [2, 3]] assert self.sg1.are_symmetrically_equivalent(sites1, sites2, 1e-3) - sites1 = [self.structure[i] for i in [0, 1]] - sites2 = [self.structure[i] for i in [0, 2]] + sites1 = [self.structure[idx] for idx in [0, 1]] + sites2 = [self.structure[idx] for idx in [0, 2]] assert not self.sg1.are_symmetrically_equivalent(sites1, sites2, 1e-3) @@ -623,7 +623,7 @@ def test_get_kpoint_weights(self): spga.get_kpoint_weights(kpts) -class TestFunc(unittest.TestCase): +class TestFunc(TestCase): def test_cluster_sites(self): site, cluster = cluster_sites(CH4, 0.1) assert isinstance(site, Site) diff --git a/tests/symmetry/test_groups.py b/tests/symmetry/test_groups.py index 60dc3e4032f..273a6a6d6c3 100644 --- a/tests/symmetry/test_groups.py +++ b/tests/symmetry/test_groups.py @@ -1,7 +1,5 @@ from __future__ import annotations -import unittest - import numpy as np import pytest from pytest import approx @@ -34,7 +32,7 @@ ).split() -class TestPointGroup(unittest.TestCase): +class TestPointGroup: def test_order(self): orders = {"mmm": 8, "432": 24, "-6m2": 12} for key, val in orders.items(): @@ -64,7 +62,7 @@ def test_is_sub_super_group(self): assert not pg_m3m.is_supergroup(pg_6mmm) -class TestSpaceGroup(unittest.TestCase): +class TestSpaceGroup: def test_renamed_e_symbols(self): assert SpaceGroup.from_int_number(64).symbol == "Cmce" @@ -123,18 +121,18 @@ def test_crystal_system(self): def test_get_orbit(self): sg = SpaceGroup("Fm-3m") - p = np.random.randint(0, 100 + 1, size=(3,)) / 100 - assert len(sg.get_orbit(p)) <= sg.order + rand_percent = np.random.randint(0, 100 + 1, size=(3,)) / 100 + assert len(sg.get_orbit(rand_percent)) <= sg.order def test_get_orbit_and_generators(self): sg = SpaceGroup("Fm-3m") - p = np.random.randint(0, 100 + 1, size=(3,)) / 100 - orbit, generators = sg.get_orbit_and_generators(p) + rand_percent = np.random.randint(0, 100 + 1, size=(3,)) / 100 + orbit, generators = sg.get_orbit_and_generators(rand_percent) assert len(orbit) <= sg.order pp = generators[0].operate(orbit[0]) - assert p[0] == approx(pp[0]) - assert p[1] == approx(pp[1]) - assert p[2] == approx(pp[2]) + assert rand_percent[0] == approx(pp[0]) + assert rand_percent[1] == approx(pp[1]) + assert rand_percent[2] == approx(pp[2]) def test_is_compatible(self): cubic = Lattice.cubic(1) diff --git a/tests/symmetry/test_settings.py b/tests/symmetry/test_settings.py index 447b176e780..2b6fe82ea26 100644 --- a/tests/symmetry/test_settings.py +++ b/tests/symmetry/test_settings.py @@ -1,6 +1,6 @@ from __future__ import annotations -import unittest +from unittest import TestCase import numpy as np from numpy.testing import assert_allclose @@ -16,7 +16,7 @@ __date__ = "Apr 2017" -class TestJonesFaithfulTransformation(unittest.TestCase): +class TestJonesFaithfulTransformation(TestCase): def setUp(self): self.test_strings = [ "a,b,c;0,0,0", # identity diff --git a/tests/transformations/test_advanced_transformations.py b/tests/transformations/test_advanced_transformations.py index 2f0f2b91c4a..c2f70bc382e 100644 --- a/tests/transformations/test_advanced_transformations.py +++ b/tests/transformations/test_advanced_transformations.py @@ -1,7 +1,6 @@ from __future__ import annotations import json -import unittest from shutil import which import numpy as np @@ -71,7 +70,7 @@ def get_table(): enumlib_present = enum_cmd and makestr_cmd -class TestSuperTransformation(unittest.TestCase): +class TestSuperTransformation: def test_apply_transformation(self): trafo = SuperTransformation( [SubstitutionTransformation({"Li+": "Na+"}), SubstitutionTransformation({"Li+": "K+"})] @@ -93,32 +92,32 @@ def test_apply_transformation(self): [0.00, -2.2171384943, 3.1355090603], ] struct = Structure(lattice, ["Li+", "Li+", "Li+", "Li+", "Li+", "Li+", "O2-", "O2-"], coords) - s = trafo.apply_transformation(struct, return_ranked_list=True) + struct_trafo = trafo.apply_transformation(struct, return_ranked_list=True) - for s_and_t in s: + for s_and_t in struct_trafo: assert s_and_t["transformation"].apply_transformation(struct) == s_and_t["structure"] @pytest.mark.skipif(not enumlib_present, reason="enum_lib not present.") def test_apply_transformation_mult(self): # Test returning multiple structures from each transformation. - disord = Structure( + disordered = Structure( np.eye(3) * 4.209, [{"Cs+": 0.5, "K+": 0.5}, "Cl-"], [[0, 0, 0], [0.5, 0.5, 0.5]], ) - disord.make_supercell([2, 2, 1]) + disordered.make_supercell([2, 2, 1]) tl = [ EnumerateStructureTransformation(), OrderDisorderedStructureTransformation(), ] trafo = SuperTransformation(tl, nstructures_per_trans=10) - assert len(trafo.apply_transformation(disord, return_ranked_list=20)) == 8 + assert len(trafo.apply_transformation(disordered, return_ranked_list=20)) == 8 trafo = SuperTransformation(tl) - assert len(trafo.apply_transformation(disord, return_ranked_list=20)) == 2 + assert len(trafo.apply_transformation(disordered, return_ranked_list=20)) == 2 -class TestMultipleSubstitutionTransformation(unittest.TestCase): +class TestMultipleSubstitutionTransformation: def test_apply_transformation(self): sub_dict = {1: ["Na", "K"]} trafo = MultipleSubstitutionTransformation("Li+", 0.5, sub_dict, None) @@ -132,7 +131,7 @@ def test_apply_transformation(self): assert len(trafo.apply_transformation(struct, return_ranked_list=True)) == 2 -class TestChargeBalanceTransformation(unittest.TestCase): +class TestChargeBalanceTransformation: def test_apply_transformation(self): trafo = ChargeBalanceTransformation("Li+") coords = [ @@ -152,13 +151,13 @@ def test_apply_transformation(self): [0.00, -2.2171384943, 3.1355090603], ] struct = Structure(lattice, ["Li+", "Li+", "Li+", "Li+", "Li+", "Li+", "O2-", "O2-"], coords) - s = trafo.apply_transformation(struct) + struct_trafo = trafo.apply_transformation(struct) - assert s.charge == approx(0, abs=1e-5) + assert struct_trafo.charge == approx(0, abs=1e-5) @pytest.mark.skipif(not enumlib_present, reason="enum_lib not present.") -class TestEnumerateStructureTransformation(unittest.TestCase): +class TestEnumerateStructureTransformation: def test_apply_transformation(self): enum_trans = EnumerateStructureTransformation(refine_structure=True) enum_trans2 = EnumerateStructureTransformation(refine_structure=True, sort_criteria="nsites") @@ -166,28 +165,28 @@ def test_apply_transformation(self): expected = [1, 3, 1] for idx, frac in enumerate([0.25, 0.5, 0.75]): trans = SubstitutionTransformation({"Fe": {"Fe": frac}}) - s = trans.apply_transformation(struct) + struct_trafo = trans.apply_transformation(struct) oxi_trans = OxidationStateDecorationTransformation({"Li": 1, "Fe": 2, "P": 5, "O": -2}) - s = oxi_trans.apply_transformation(s) - alls = enum_trans.apply_transformation(s, 100) + struct_trafo = oxi_trans.apply_transformation(struct_trafo) + alls = enum_trans.apply_transformation(struct_trafo, 100) assert len(alls) == expected[idx] - assert isinstance(trans.apply_transformation(s), Structure) + assert isinstance(trans.apply_transformation(struct_trafo), Structure) for ss in alls: assert "energy" in ss - alls = enum_trans2.apply_transformation(s, 100) + alls = enum_trans2.apply_transformation(struct_trafo, 100) assert len(alls) == expected[idx] - assert isinstance(trans.apply_transformation(s), Structure) + assert isinstance(trans.apply_transformation(struct_trafo), Structure) for ss in alls: assert "num_sites" in ss # make sure it works for non-oxidation state decorated structure trans = SubstitutionTransformation({"Fe": {"Fe": 0.5}}) - s = trans.apply_transformation(struct) - alls = enum_trans.apply_transformation(s, 100) + struct_trafo = trans.apply_transformation(struct) + alls = enum_trans.apply_transformation(struct_trafo, 100) assert len(alls) == 3 - assert isinstance(trans.apply_transformation(s), Structure) - for s in alls: - assert "energy" not in s + assert isinstance(trans.apply_transformation(struct_trafo), Structure) + for struct_trafo in alls: + assert "energy" not in struct_trafo @pytest.mark.skip("TODO remove skip once https://github.com/materialsvirtuallab/matgl/issues/238 is resolved") def test_m3gnet(self): @@ -195,10 +194,10 @@ def test_m3gnet(self): enum_trans = EnumerateStructureTransformation(refine_structure=True, sort_criteria="m3gnet_relax") struct = Structure.from_file(f"{VASP_IN_DIR}/POSCAR_LiFePO4") trans = SubstitutionTransformation({"Fe": {"Fe": 0.5, "Mn": 0.5}}) - s = trans.apply_transformation(struct) - alls = enum_trans.apply_transformation(s, 100) + struct_trafo = trans.apply_transformation(struct) + alls = enum_trans.apply_transformation(struct_trafo, 100) assert len(alls) == 3 - assert isinstance(trans.apply_transformation(s), Structure) + assert isinstance(trans.apply_transformation(struct_trafo), Structure) for ss in alls: assert "energy" in ss @@ -222,10 +221,10 @@ def sort_criteria(struct: Structure) -> tuple[Structure, float]: enum_trans = EnumerateStructureTransformation(refine_structure=True, sort_criteria=sort_criteria) struct = Structure.from_file(f"{VASP_IN_DIR}/POSCAR_LiFePO4") trans = SubstitutionTransformation({"Fe": {"Fe": 0.5, "Mn": 0.5}}) - s = trans.apply_transformation(struct) - alls = enum_trans.apply_transformation(s, 100) + struct_trafo = trans.apply_transformation(struct) + alls = enum_trans.apply_transformation(struct_trafo, 100) assert len(alls) == 3 - assert isinstance(trans.apply_transformation(s), Structure) + assert isinstance(trans.apply_transformation(struct_trafo), Structure) for ss in alls: assert "energy" in ss @@ -251,7 +250,7 @@ def test_as_from_dict(self): assert trans.symm_prec == 0.1 -class TestSubstitutionPredictorTransformation(unittest.TestCase): +class TestSubstitutionPredictorTransformation: def test_apply_transformation(self): trafo = SubstitutionPredictorTransformation(threshold=1e-3, alpha=-5, lambda_table=get_table()) coords = [[0, 0, 0], [0.75, 0.75, 0.75], [0.5, 0.5, 0.5]] @@ -276,21 +275,21 @@ def test_as_dict(self): @pytest.mark.skipif(not enumlib_present, reason="enum_lib not present.") class TestMagOrderingTransformation(PymatgenTest): def setUp(self): - latt = Lattice.cubic(4.17) + lattice = Lattice.cubic(4.17) species = ["Ni", "O"] coords = [[0, 0, 0], [0.5, 0.5, 0.5]] - self.NiO = Structure.from_spacegroup(225, latt, species, coords) + self.NiO = Structure.from_spacegroup(225, lattice, species, coords) - latt = Lattice([[2.085, 2.085, 0.0], [0.0, -2.085, -2.085], [-2.085, 2.085, -4.17]]) + lattice = Lattice([[2.085, 2.085, 0.0], [0.0, -2.085, -2.085], [-2.085, 2.085, -4.17]]) species = ["Ni", "Ni", "O", "O"] coords = [[0.5, 0, 0.5], [0, 0, 0], [0.25, 0.5, 0.25], [0.75, 0.5, 0.75]] - self.NiO_AFM_111 = Structure(latt, species, coords) + self.NiO_AFM_111 = Structure(lattice, species, coords) self.NiO_AFM_111.add_spin_by_site([-5, 5, 0, 0]) - latt = Lattice([[2.085, 2.085, 0], [0, 0, -4.17], [-2.085, 2.085, 0]]) + lattice = Lattice([[2.085, 2.085, 0], [0, 0, -4.17], [-2.085, 2.085, 0]]) species = ["Ni", "Ni", "O", "O"] coords = [[0.5, 0.5, 0.5], [0, 0, 0], [0, 0.5, 0], [0.5, 0, 0.5]] - self.NiO_AFM_001 = Structure(latt, species, coords) + self.NiO_AFM_001 = Structure(lattice, species, coords) self.NiO_AFM_001.add_spin_by_site([-5, 5, 0, 0]) self.Fe3O4 = Structure.from_file(f"{TEST_FILES_DIR}/Fe3O4.cif") diff --git a/tests/transformations/test_site_transformations.py b/tests/transformations/test_site_transformations.py index d22a09d61d8..3b30670fad3 100644 --- a/tests/transformations/test_site_transformations.py +++ b/tests/transformations/test_site_transformations.py @@ -1,7 +1,7 @@ from __future__ import annotations -import unittest from shutil import which +from unittest import TestCase import numpy as np import pytest @@ -78,7 +78,7 @@ def test_as_from_dict(self): str(t2) -class TestReplaceSiteSpeciesTransformation(unittest.TestCase): +class TestReplaceSiteSpeciesTransformation(TestCase): def setUp(self): coords = [ [0, 0, 0], @@ -110,7 +110,7 @@ def test_as_from_dict(self): assert struct.formula == "Na1 Li3 O4" -class TestRemoveSitesTransformation(unittest.TestCase): +class TestRemoveSitesTransformation(TestCase): def setUp(self): coords = [ [0, 0, 0], @@ -142,7 +142,7 @@ def test_as_from_dict(self): assert struct.formula == "Li2 O4" -class TestInsertSitesTransformation(unittest.TestCase): +class TestInsertSitesTransformation(TestCase): def setUp(self): coords = [ [0, 0, 0], @@ -179,7 +179,7 @@ def test_as_from_dict(self): assert struct.formula == "Li4 Mn1 Fe1 O4" -class TestPartialRemoveSitesTransformation(unittest.TestCase): +class TestPartialRemoveSitesTransformation(TestCase): def setUp(self): coords = [ [0, 0, 0], diff --git a/tests/transformations/test_standard_transformations.py b/tests/transformations/test_standard_transformations.py index b1dc36b0acd..bdbe2826dbb 100644 --- a/tests/transformations/test_standard_transformations.py +++ b/tests/transformations/test_standard_transformations.py @@ -1,13 +1,10 @@ -# ruff: noqa: N806 - from __future__ import annotations import functools import json import operator -import random -import unittest from shutil import which +from unittest import TestCase import numpy as np import pytest @@ -43,13 +40,13 @@ enumlib_present = which("enum.x") and which("makestr.x") -class TestRotationTransformations(unittest.TestCase): +class TestRotationTransformations(TestCase): def setUp(self): coords = [[0, 0, 0], [0.75, 0.5, 0.75]] lattice = [ - [3.8401979337, 0.00, 0.00], - [1.9200989668, 3.3257101909, 0.00], - [0.00, -2.2171384943, 3.1355090603], + [3.8401979337, 0, 0], + [1.9200989668, 3.3257101909, 0], + [0, -2.2171384943, 3.1355090603], ] self.struct = Structure(lattice, ["Si"] * 2, coords) @@ -65,14 +62,14 @@ def test_rotation_transformation(self): assert (abs(s1.lattice.matrix - self.struct.lattice.matrix) < 1e-8).all() -class TestRemoveSpeciesTransformation(unittest.TestCase): +class TestRemoveSpeciesTransformation: def test_apply_transformation(self): trafo = RemoveSpeciesTransformation(["Li+"]) coords = [[0, 0, 0], [0.75, 0.75, 0.75], [0.5, 0.5, 0.5], [0.25, 0.25, 0.25]] lattice = [ - [3.8401979337, 0.00, 0.00], - [1.9200989668, 3.3257101909, 0.00], - [0.00, -2.2171384943, 3.1355090603], + [3.8401979337, 0, 0], + [1.9200989668, 3.3257101909, 0], + [0, -2.2171384943, 3.1355090603], ] struct = Structure(lattice, ["Li+", "Li+", "O2-", "O2-"], coords) struct = trafo.apply_transformation(struct) @@ -82,18 +79,18 @@ def test_apply_transformation(self): assert isinstance(RemoveSpeciesTransformation.from_dict(dct), RemoveSpeciesTransformation) -class TestSubstitutionTransformation(unittest.TestCase): +class TestSubstitutionTransformation: def test_apply_transformation(self): trafo = SubstitutionTransformation({"Li+": "Na+", "O2-": "S2-"}) coords = [[0, 0, 0], [0.75, 0.75, 0.75], [0.5, 0.5, 0.5], [0.25, 0.25, 0.25]] lattice = [ - [3.8401979337, 0.00, 0.00], - [1.9200989668, 3.3257101909, 0.00], - [0.00, -2.2171384943, 3.1355090603], + [3.8401979337, 0, 0], + [1.9200989668, 3.3257101909, 0], + [0, -2.2171384943, 3.1355090603], ] struct = Structure(lattice, ["Li+", "Li+", "O2-", "O2-"], coords) - s = trafo.apply_transformation(struct) - assert s.formula == "Na2 S2" + struct_trafo = trafo.apply_transformation(struct) + assert struct_trafo.formula == "Na2 S2" def test_fractional_substitution(self): trafo = SubstitutionTransformation({"Li+": "Na+", "O2-": {"S2-": 0.5, "Se2-": 0.5}}) @@ -101,22 +98,22 @@ def test_fractional_substitution(self): trafo = SubstitutionTransformation.from_dict(trafo.as_dict()) coords = [[0, 0, 0], [0.75, 0.75, 0.75], [0.5, 0.5, 0.5], [0.25, 0.25, 0.25]] lattice = [ - [3.8401979337, 0.00, 0.00], - [1.9200989668, 3.3257101909, 0.00], - [0.00, -2.2171384943, 3.1355090603], + [3.8401979337, 0, 0], + [1.9200989668, 3.3257101909, 0], + [0, -2.2171384943, 3.1355090603], ] struct = Structure(lattice, ["Li+", "Li+", "O2-", "O2-"], coords) - s = trafo.apply_transformation(struct) - assert s.formula == "Na2 Se1 S1" + struct_trafo = trafo.apply_transformation(struct) + assert struct_trafo.formula == "Na2 Se1 S1" -class TestSupercellTransformation(unittest.TestCase): +class TestSupercellTransformation(TestCase): def setUp(self): coords = [[0, 0, 0], [0.75, 0.75, 0.75], [0.5, 0.5, 0.5], [0.25, 0.25, 0.25]] lattice = [ - [3.8401979337, 0.00, 0.00], - [1.9200989668, 3.3257101909, 0.00], - [0.00, -2.2171384943, 3.1355090603], + [3.8401979337, 0, 0], + [1.9200989668, 3.3257101909, 0], + [0, -2.2171384943, 3.1355090603], ] self.struct = Structure(lattice, ["Li+", "Li+", "O2-", "O2-"], coords) @@ -126,7 +123,7 @@ def test_apply_transformation(self): assert struct.formula == "Li16 O16" def test_from_scaling_factors(self): - scale_factors = [random.randint(1, 5) for i in range(3)] + scale_factors = np.random.randint(1, 5, 3) trafo = SupercellTransformation.from_scaling_factors(*scale_factors) struct = trafo.apply_transformation(self.struct) assert len(struct) == 4 * functools.reduce(operator.mul, scale_factors) @@ -169,24 +166,24 @@ def test_from_boundary_distance(self): SupercellTransformation.from_boundary_distance(structure=self.struct, max_atoms=max_atoms) -class TestOxidationStateDecorationTransformation(unittest.TestCase): +class TestOxidationStateDecorationTransformation: def test_apply_transformation(self): trafo = OxidationStateDecorationTransformation({"Li": 1, "O": -2}) coords = [[0, 0, 0], [0.75, 0.75, 0.75], [0.5, 0.5, 0.5], [0.25, 0.25, 0.25]] lattice = [ - [3.8401979337, 0.00, 0.00], - [1.9200989668, 3.3257101909, 0.00], - [0.00, -2.2171384943, 3.1355090603], + [3.8401979337, 0, 0], + [1.9200989668, 3.3257101909, 0], + [0, -2.2171384943, 3.1355090603], ] struct = Structure(lattice, ["Li", "Li", "O", "O"], coords) - s = trafo.apply_transformation(struct) - assert s[0].species_string == "Li+" - assert s[2].species_string == "O2-" + struct_trafo = trafo.apply_transformation(struct) + assert struct_trafo[0].species_string == "Li+" + assert struct_trafo[2].species_string == "O2-" dct = trafo.as_dict() assert isinstance(OxidationStateDecorationTransformation.from_dict(dct), OxidationStateDecorationTransformation) -class TestAutoOxiStateDecorationTransformation(unittest.TestCase): +class TestAutoOxiStateDecorationTransformation: def test_apply_transformation(self): trafo = AutoOxiStateDecorationTransformation() struct = trafo.apply_transformation(Structure.from_file(f"{VASP_IN_DIR}/POSCAR_LiFePO4")) @@ -201,33 +198,33 @@ def test_as_from_dict(self): assert trafo.analyzer.dist_scale_factor == 1.015 -class TestOxidationStateRemovalTransformation(unittest.TestCase): +class TestOxidationStateRemovalTransformation: def test_apply_transformation(self): trafo = OxidationStateRemovalTransformation() coords = [[0, 0, 0], [0.75, 0.75, 0.75], [0.5, 0.5, 0.5], [0.25, 0.25, 0.25]] lattice = [ - [3.8401979337, 0.00, 0.00], - [1.9200989668, 3.3257101909, 0.00], - [0.00, -2.2171384943, 3.1355090603], + [3.8401979337, 0, 0], + [1.9200989668, 3.3257101909, 0], + [0, -2.2171384943, 3.1355090603], ] struct = Structure(lattice, ["Li+", "Li+", "O2-", "O2-"], coords) - s = trafo.apply_transformation(struct) - assert s[0].species_string == "Li" - assert s[2].species_string == "O" + struct_trafo = trafo.apply_transformation(struct) + assert struct_trafo[0].species_string == "Li" + assert struct_trafo[2].species_string == "O" dct = trafo.as_dict() assert isinstance(OxidationStateRemovalTransformation.from_dict(dct), OxidationStateRemovalTransformation) @pytest.mark.skipif(not enumlib_present, reason="enum_lib not present.") -class TestPartialRemoveSpecieTransformation(unittest.TestCase): +class TestPartialRemoveSpecieTransformation: def test_apply_transformation(self): trafo = PartialRemoveSpecieTransformation("Li+", 1.0 / 3, 3) coords = [[0, 0, 0], [0.75, 0.75, 0.75], [0.5, 0.5, 0.5], [0.25, 0.25, 0.25]] lattice = [ - [3.8401979337, 0.00, 0.00], - [1.9200989668, 3.3257101909, 0.00], - [0.00, -2.2171384943, 3.1355090603], + [3.8401979337, 0, 0], + [1.9200989668, 3.3257101909, 0], + [0, -2.2171384943, 3.1355090603], ] struct = Structure(lattice, ["Li+", "Li+", "Li+", "O2-"], coords) assert len(trafo.apply_transformation(struct, 100)) == 2 @@ -238,7 +235,7 @@ def test_apply_transformation(self): def test_apply_transformation_fast(self): trafo = PartialRemoveSpecieTransformation("Li+", 0.5) coords = [[0, 0, 0], [0.75, 0.75, 0.75], [0.5, 0.5, 0.5], [0.25, 0.25, 0.25], [0.1, 0.1, 0.1], [0.3, 0.75, 0.3]] - lattice = Lattice([[10, 0.00, 0.00], [0, 10, 0.00], [0.00, 0, 10]]) + lattice = Lattice([[10, 0, 0], [0, 10, 0], [0, 0, 10]]) struct = Structure(lattice, ["Li+"] * 6, coords) fast_opt_s = trafo.apply_transformation(struct) trafo = PartialRemoveSpecieTransformation("Li+", 0.5, PartialRemoveSpecieTransformation.ALGO_COMPLETE) @@ -259,14 +256,14 @@ def test_apply_transformations_best_first(self): assert len(trafo.apply_transformation(struct)) == 26 -class TestOrderDisorderedStructureTransformation(unittest.TestCase): +class TestOrderDisorderedStructureTransformation: def test_apply_transformation(self): trafo = OrderDisorderedStructureTransformation() coords = [[0, 0, 0], [0.75, 0.75, 0.75], [0.5, 0.5, 0.5], [0.25, 0.25, 0.25]] lattice = [ - [3.8401979337, 0.00, 0.00], - [1.9200989668, 3.3257101909, 0.00], - [0.00, -2.2171384943, 3.1355090603], + [3.8401979337, 0, 0], + [1.9200989668, 3.3257101909, 0], + [0, -2.2171384943, 3.1355090603], ] struct = Structure( @@ -311,21 +308,21 @@ def test_apply_transformation(self): def test_no_oxidation(self): specie = {"Cu1+": 0.5, "Au2+": 0.5} - cuau = Structure.from_spacegroup("Fm-3m", Lattice.cubic(3.677), [specie], [[0, 0, 0]]) + cu_au = Structure.from_spacegroup("Fm-3m", Lattice.cubic(3.677), [specie], [[0, 0, 0]]) trans = OrderDisorderedStructureTransformation() - ss = trans.apply_transformation(cuau, return_ranked_list=100) + ss = trans.apply_transformation(cu_au, return_ranked_list=100) assert ss[0]["structure"].composition["Cu+"] == 2 trans = OrderDisorderedStructureTransformation(no_oxi_states=True) - ss = trans.apply_transformation(cuau, return_ranked_list=100) + ss = trans.apply_transformation(cu_au, return_ranked_list=100) assert ss[0]["structure"].composition["Cu+"] == 0 assert ss[0]["structure"].composition["Cu"] == 2 def test_symmetrized_structure(self): trafo = OrderDisorderedStructureTransformation(symmetrized_structures=True) - latt = Lattice.cubic(5) + lattice = Lattice.cubic(5) coords = [[0.5, 0.5, 0.5], [0.45, 0.45, 0.45], [0.56, 0.56, 0.56], [0.25, 0.75, 0.75], [0.75, 0.25, 0.25]] - struct = Structure(latt, [{"Si4+": 1}, *[{"Si4+": 0.5}] * 4], coords) - test_site = PeriodicSite("Si4+", coords[2], latt) + struct = Structure(lattice, [{"Si4+": 1}, *[{"Si4+": 0.5}] * 4], coords) + test_site = PeriodicSite("Si4+", coords[2], lattice) struct = SymmetrizedStructure(struct, "not_real", [0, 1, 1, 2, 2], ["a", "b", "b", "c", "c"]) output = trafo.apply_transformation(struct) assert test_site in output @@ -342,9 +339,9 @@ def test_best_first(self): trafo = OrderDisorderedStructureTransformation(algo=2) coords = [[0, 0, 0], [0.75, 0.75, 0.75], [0.5, 0.5, 0.5], [0.25, 0.25, 0.25]] lattice = [ - [3.8401979337, 0.00, 0.00], - [1.9200989668, 3.3257101909, 0.00], - [0.00, -2.2171384943, 3.1355090603], + [3.8401979337, 0, 0], + [1.9200989668, 3.3257101909, 0], + [0, -2.2171384943, 3.1355090603], ] struct = Structure( @@ -361,7 +358,7 @@ def test_best_first(self): assert output[0]["energy"] == approx(-234.57813667648315, abs=1e-4) -class TestPrimitiveCellTransformation(unittest.TestCase): +class TestPrimitiveCellTransformation: def test_apply_transformation(self): trafo = PrimitiveCellTransformation() coords = [ @@ -376,9 +373,9 @@ def test_apply_transformation(self): ] lattice = [ - [3.8401979337, 0.00, 0.00], - [1.9200989668, 3.3257101909, 0.00], - [0.00, -2.2171384943, 3.1355090603], + [3.8401979337, 0, 0], + [1.9200989668, 3.3257101909, 0], + [0, -2.2171384943, 3.1355090603], ] struct = Structure(lattice, ["Li+", "Li+", "Li+", "Li+", "O2-", "O2-", "O2-", "O2-"], coords) struct = trafo.apply_transformation(struct) @@ -393,21 +390,21 @@ def test_apply_transformation(self): assert isinstance(PrimitiveCellTransformation.from_dict(dct), PrimitiveCellTransformation) -class TestConventionalCellTransformation(unittest.TestCase): +class TestConventionalCellTransformation: def test_apply_transformation(self): trafo = ConventionalCellTransformation() coords = [[0, 0, 0], [0.75, 0.75, 0.75], [0.5, 0.5, 0.5], [0.25, 0.25, 0.25]] lattice = [ - [3.8401979337, 0.00, 0.00], - [1.9200989668, 3.3257101909, 0.00], - [0.00, -2.2171384943, 3.1355090603], + [3.8401979337, 0, 0], + [1.9200989668, 3.3257101909, 0], + [0, -2.2171384943, 3.1355090603], ] struct = Structure(lattice, ["Li+", "Li+", "O2-", "O2-"], coords) conventional_struct = trafo.apply_transformation(struct) assert conventional_struct.lattice.alpha == 90 -class TestPerturbStructureTransformation(unittest.TestCase): +class TestPerturbStructureTransformation: def test_apply_transformation(self): trafo = PerturbStructureTransformation(0.05) coords = [ @@ -422,9 +419,9 @@ def test_apply_transformation(self): ] lattice = [ - [3.8401979337, 0.00, 0.00], - [1.9200989668, 3.3257101909, 0.00], - [0.00, -2.2171384943, 3.1355090603], + [3.8401979337, 0, 0], + [1.9200989668, 3.3257101909, 0], + [0, -2.2171384943, 3.1355090603], ] struct = Structure(lattice, ["Li+", "Li+", "Li+", "Li+", "O2-", "O2-", "O2-", "O2-"], coords) transformed_struct = trafo.apply_transformation(struct) @@ -444,7 +441,7 @@ def test_apply_transformation(self): assert isinstance(PerturbStructureTransformation.from_dict(dct), PerturbStructureTransformation) -class TestDeformStructureTransformation(unittest.TestCase): +class TestDeformStructureTransformation: def test_apply_transformation(self): trafo = DeformStructureTransformation([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.05, 1.0]]) coords = [ @@ -459,9 +456,9 @@ def test_apply_transformation(self): ] lattice = [ - [3.8401979337, 0.00, 0.00], - [1.9200989668, 3.3257101909, 0.00], - [0.00, -2.2171384943, 3.1355090603], + [3.8401979337, 0, 0], + [1.9200989668, 3.3257101909, 0], + [0, -2.2171384943, 3.1355090603], ] struct = Structure(lattice, ["Li+", "Li+", "Li+", "Li+", "O2-", "O2-", "O2-", "O2-"], coords) transformed_s = trafo.apply_transformation(struct) @@ -473,32 +470,32 @@ def test_apply_transformation(self): assert isinstance(DeformStructureTransformation.from_dict(dct), DeformStructureTransformation) -class TestDiscretizeOccupanciesTransformation(unittest.TestCase): +class TestDiscretizeOccupanciesTransformation: def test_apply_transformation(self): - latt = Lattice.cubic(4) - s_orig = Structure( - latt, + lattice = Lattice.cubic(4) + struct_orig = Structure( + lattice, [{"Li": 0.19, "Na": 0.19, "K": 0.62}, {"O": 1}], [[0, 0, 0], [0.5, 0.5, 0.5]], ) dot = DiscretizeOccupanciesTransformation(max_denominator=5, tol=0.5) - s = dot.apply_transformation(s_orig) - assert dict(s[0].species) == {Element("Li"): 0.2, Element("Na"): 0.2, Element("K"): 0.6} + struct = dot.apply_transformation(struct_orig) + assert dict(struct[0].species) == {Element("Li"): 0.2, Element("Na"): 0.2, Element("K"): 0.6} dot = DiscretizeOccupanciesTransformation(max_denominator=5, tol=0.01) with pytest.raises(RuntimeError, match="Cannot discretize structure within tolerance!"): - dot.apply_transformation(s_orig) + dot.apply_transformation(struct_orig) - s_orig_2 = Structure( - latt, + struct_orig_2 = Structure( + lattice, [{"Li": 0.5, "Na": 0.25, "K": 0.25}, {"O": 1}], [[0, 0, 0], [0.5, 0.5, 0.5]], ) dot = DiscretizeOccupanciesTransformation(max_denominator=9, tol=0.25, fix_denominator=False) - s = dot.apply_transformation(s_orig_2) - assert dict(s[0].species) == { + struct = dot.apply_transformation(struct_orig_2) + assert dict(struct[0].species) == { Element("Li"): Fraction(1 / 2), Element("Na"): Fraction(1 / 4), Element("K"): Fraction(1 / 4), @@ -506,47 +503,46 @@ def test_apply_transformation(self): dot = DiscretizeOccupanciesTransformation(max_denominator=9, tol=0.05, fix_denominator=True) with pytest.raises(RuntimeError, match="Cannot discretize structure within tolerance"): - dot.apply_transformation(s_orig_2) + dot.apply_transformation(struct_orig_2) -class TestChargedCellTransformation(unittest.TestCase): +class TestChargedCellTransformation: def test_apply_transformation(self): - lattice = Lattice.cubic(4) - s_orig = Structure( - lattice, + struct_orig = Structure( + np.eye(3) * 4, [{"Li": 0.19, "Na": 0.19, "K": 0.62}, {"O": 1}], [[0, 0, 0], [0.5, 0.5, 0.5]], ) cct = ChargedCellTransformation(charge=3) - s = cct.apply_transformation(s_orig) - assert s.charge == 3 + struct_trafo = cct.apply_transformation(struct_orig) + assert struct_trafo.charge == 3 -class TestScaleToRelaxedTransformation(unittest.TestCase): +class TestScaleToRelaxedTransformation: def test_apply_transformation(self): # Test on slab relaxation where volume is fixed - f = f"{TEST_FILES_DIR}/surface_tests" - Cu_fin = Structure.from_file(f"{f}/Cu_slab_fin.cif") - Cu_init = Structure.from_file(f"{f}/Cu_slab_init.cif") + surf_dir = f"{TEST_FILES_DIR}/surfaces" + Cu_fin = Structure.from_file(f"{surf_dir}/Cu_slab_fin.cif") + Cu_init = Structure.from_file(f"{surf_dir}/Cu_slab_init.cif") slab_scaling = ScaleToRelaxedTransformation(Cu_init, Cu_fin) - Au_init = Structure.from_file(f"{f}/Au_slab_init.cif") + Au_init = Structure.from_file(f"{surf_dir}/Au_slab_init.cif") Au_fin = slab_scaling.apply_transformation(Au_init) assert Au_fin.volume == approx(Au_init.volume) # Test on gb relaxation - f = f"{TEST_FILES_DIR}/grain_boundary" - Be_fin = Structure.from_file(f"{f}/Be_gb_fin.cif") - Be_init = Structure.from_file(f"{f}/Be_gb_init.cif") - Zn_init = Structure.from_file(f"{f}/Zn_gb_init.cif") + gb_dir = f"{TEST_FILES_DIR}/grain_boundary" + Be_fin = Structure.from_file(f"{gb_dir}/Be_gb_fin.cif") + Be_init = Structure.from_file(f"{gb_dir}/Be_gb_init.cif") + Zn_init = Structure.from_file(f"{gb_dir}/Zn_gb_init.cif") gb_scaling = ScaleToRelaxedTransformation(Be_init, Be_fin) Zn_fin = gb_scaling.apply_transformation(Zn_init) assert all(site.species_string == "Zn" for site in Zn_fin) assert (Be_init.lattice.a < Be_fin.lattice.a) == (Zn_init.lattice.a < Zn_fin.lattice.a) assert (Be_init.lattice.b < Be_fin.lattice.b) == (Zn_init.lattice.b < Zn_fin.lattice.b) assert (Be_init.lattice.c < Be_fin.lattice.c) == (Zn_init.lattice.c < Zn_fin.lattice.c) - Fe_fin = Structure.from_file(f"{f}/Fe_gb_fin.cif") - Fe_init = Structure.from_file(f"{f}/Fe_gb_init.cif") - Mo_init = Structure.from_file(f"{f}/Mo_gb_init.cif") + Fe_fin = Structure.from_file(f"{gb_dir}/Fe_gb_fin.cif") + Fe_init = Structure.from_file(f"{gb_dir}/Fe_gb_init.cif") + Mo_init = Structure.from_file(f"{gb_dir}/Mo_gb_init.cif") gb_scaling = ScaleToRelaxedTransformation(Fe_init, Fe_fin) Mo_fin = gb_scaling.apply_transformation(Mo_init) assert all(site.species_string == "Mo" for site in Mo_fin) diff --git a/tests/util/test_coord.py b/tests/util/test_coord.py index 9538a59526c..15cfab9d7dc 100644 --- a/tests/util/test_coord.py +++ b/tests/util/test_coord.py @@ -1,7 +1,7 @@ from __future__ import annotations import random -import unittest +from unittest import TestCase import numpy as np import pytest @@ -236,7 +236,7 @@ def test_get_angle(self): assert coord.get_angle(v1, v2, units="radians") == approx(0.9553166181245092) -class TestSimplex(unittest.TestCase): +class TestSimplex(TestCase): def setUp(self): coords = [[0, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]] self.simplex = coord.Simplex(coords) @@ -273,19 +273,19 @@ def test_str(self): assert repr(self.simplex).startswith("3-simplex in 4D space") def test_bary_coords(self): - s = coord.Simplex([[0, 2], [3, 1], [1, 0]]) + simplex = coord.Simplex([[0, 2], [3, 1], [1, 0]]) point = [0.7, 0.5] - bc = s.bary_coords(point) + bc = simplex.bary_coords(point) assert_allclose(bc, [0.26, -0.02, 0.76]) - new_point = s.point_from_bary_coords(bc) + new_point = simplex.point_from_bary_coords(bc) assert_allclose(point, new_point) def test_intersection(self): # simple test, with 2 intersections at faces - s = coord.Simplex([[0, 2], [3, 1], [1, 0]]) + simplex = coord.Simplex([[0, 2], [3, 1], [1, 0]]) point1 = [0.7, 0.5] point2 = [0.5, 0.7] - intersections = s.line_intersection(point1, point2) + intersections = simplex.line_intersection(point1, point2) expected = np.array([[1.13333333, 0.06666667], [0.8, 0.4]]) assert_allclose(intersections, expected) @@ -293,14 +293,14 @@ def test_intersection(self): point1 = [0, 2] # simplex point point2 = [1, 1] # inside simplex expected = np.array([[1.66666667, 0.33333333], [0, 2]]) - intersections = s.line_intersection(point1, point2) + intersections = simplex.line_intersection(point1, point2) assert_allclose(intersections, expected) # intersection through point only point1 = [0, 2] # simplex point point2 = [0.5, 0.7] expected = np.array([[0, 2]]) - intersections = s.line_intersection(point1, point2) + intersections = simplex.line_intersection(point1, point2) assert_allclose(intersections, expected) # 3d intersection through edge and face @@ -321,21 +321,21 @@ def test_intersection(self): point1 = [-1, 2] point2 = [0, 0] expected = np.array([]) - intersections = s.line_intersection(point1, point2) + intersections = simplex.line_intersection(point1, point2) assert_allclose(intersections, expected) # coplanar to face (with intersection line) point1 = [0, 2] # simplex point point2 = [1, 0] expected = np.array([[1, 0], [0, 2]]) - intersections = s.line_intersection(point1, point2) + intersections = simplex.line_intersection(point1, point2) assert_allclose(intersections, expected) # coplanar to face (with intersection points) point1 = [0.1, 2] point2 = [1.1, 0] expected = np.array([[1.08, 0.04], [0.12, 1.96]]) - intersections = s.line_intersection(point1, point2) + intersections = simplex.line_intersection(point1, point2) assert_allclose(intersections, expected) def test_to_json(self): diff --git a/tests/util/test_plotting.py b/tests/util/test_plotting.py index 0fda04627a4..ae0b6eca76b 100644 --- a/tests/util/test_plotting.py +++ b/tests/util/test_plotting.py @@ -12,14 +12,14 @@ pymatviz = None -class FuncTestCase(PymatgenTest): +class TestFunc(PymatgenTest): def test_plot_periodic_heatmap(self): random_data = {"Te": 0.11083, "Au": 0.75756, "Th": 1.24758, "Ni": -2.0354} - ret_val = periodic_table_heatmap(random_data) + fig = periodic_table_heatmap(random_data) if pymatviz: - assert isinstance(ret_val, Figure) + assert isinstance(fig, Figure) else: - assert ret_val is plt + assert isinstance(fig, plt.Axes) # Test all keywords periodic_table_heatmap( @@ -39,6 +39,10 @@ def test_plot_periodic_heatmap(self): def test_van_arkel_triangle(self): random_list = [("Fe", "C"), ("Ni", "F")] - ret_val = van_arkel_triangle(random_list) - assert ret_val is plt - van_arkel_triangle(random_list, annotate=True) + ax = van_arkel_triangle(random_list) + assert isinstance(ax, plt.Axes) + assert ax.get_title() == "" + assert ax.get_xlabel() == r"$\frac{\chi_{A}+\chi_{B}}{2}$" + assert ax.get_ylabel() == r"$|\chi_{A}-\chi_{B}|$" + ax = van_arkel_triangle(random_list, annotate=True) + assert isinstance(ax, plt.Axes) diff --git a/tests/util/test_provenance.py b/tests/util/test_provenance.py index cec2a69d754..08bc1c0cedc 100644 --- a/tests/util/test_provenance.py +++ b/tests/util/test_provenance.py @@ -1,7 +1,7 @@ from __future__ import annotations import datetime -import unittest +from unittest import TestCase import numpy as np import pytest @@ -22,7 +22,7 @@ __date__ = "2/14/13" -class StructureNLCase(unittest.TestCase): +class StructureNLCase(TestCase): def setUp(self): # set up a Structure self.struct = Structure(np.eye(3, 3) * 3, ["Fe"], [[0, 0, 0]]) diff --git a/tests/util/test_string.py b/tests/util/test_string.py index 73f39a3c6e4..5c04a1d51ea 100644 --- a/tests/util/test_string.py +++ b/tests/util/test_string.py @@ -1,7 +1,5 @@ from __future__ import annotations -import unittest - import numpy as np import pytest @@ -36,7 +34,7 @@ def __str__(self): return "Fe**2+" -class TestStringify(unittest.TestCase): +class TestStringify: def test_to_latex_string(self): assert SubStr().to_latex_string() == "Fe$_{8}$O$_{12}$" assert SupStr().to_latex_string() == "Fe$^{2+}$" @@ -50,7 +48,7 @@ def test_to_unicode_string(self): assert SupStr().to_unicode_string() == "Feยฒโบ" -class TestFunc(unittest.TestCase): +class TestFunc: def test_latexify(self): assert latexify("Li3Fe2(PO4)3") == "Li$_{3}$Fe$_{2}$(PO$_{4}$)$_{3}$" assert latexify("Li0.2Na0.8Cl") == "Li$_{0.2}$Na$_{0.8}$Cl" diff --git a/tests/util/test_typing.py b/tests/util/test_typing.py index eb1ba9dbf7b..6112e05a933 100644 --- a/tests/util/test_typing.py +++ b/tests/util/test_typing.py @@ -1,24 +1,35 @@ +"""This module tests types are as expected and can be imported without circular ImportError.""" + +# mypy: disable-error-code="misc" + from __future__ import annotations -from typing import Any +import sys +from pathlib import Path +from types import GenericAlias +from typing import Any, get_args -# pymatgen.entries needs to be imported before pymatgen.util.typing -# to avoid circular import. -from pymatgen.entries import Entry -from pymatgen.util.typing import CompositionLike, EntryLike, PathLike, SpeciesLike +import pytest -# This module tests types are as expected and can be imported without circular ImportError. +from pymatgen.core import Composition, DummySpecies, Element, Species +from pymatgen.entries import Entry +from pymatgen.util.typing import CompositionLike, EntryLike, PathLike, PbcLike, SpeciesLike __author__ = "Janosh Riebesell" __date__ = "2022-10-20" __email__ = "janosh@lbl.gov" +skip_below_py310 = pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python 3.10 or higher") + def _type_str(some_type: Any) -> str: return str(some_type).replace("typing.", "").replace("pymatgen.core.periodic_table.", "") def test_entry_like(): + # needs to be tested as string to avoid + # TypeError: issubclass() arg 2 must be a class, a tuple of classes, or a union + # since EntryLike is defined as Union[] of strings to avoid circular imports entries = ( "Entry", "ComputedEntry", @@ -36,16 +47,32 @@ def test_entry_like(): assert Entry.__name__ in str(EntryLike) +@skip_below_py310 def test_species_like(): - assert _type_str(SpeciesLike) == "Union[str, Element, Species, DummySpecies]" + assert isinstance("H", SpeciesLike) + assert isinstance(Element("H"), SpeciesLike) + assert isinstance(Species("H+"), SpeciesLike) + assert isinstance(DummySpecies("X"), SpeciesLike) +@skip_below_py310 def test_composition_like(): - assert ( - _type_str(CompositionLike) - == "Union[str, Element, Species, DummySpecies, dict, pymatgen.core.composition.Composition]" - ) + assert isinstance("H", CompositionLike) + assert isinstance(Element("H"), CompositionLike) + assert isinstance(Species("H+"), CompositionLike) + assert isinstance(Composition("H"), CompositionLike) + assert isinstance({"H": 1}, CompositionLike) + assert isinstance(DummySpecies("X"), CompositionLike) + + +def test_pbc_like(): + assert type(PbcLike) == GenericAlias + assert get_args(PbcLike) == (bool, bool, bool) -def test_path_like(): - assert _type_str(PathLike) == "Union[str, pathlib.Path]" +@skip_below_py310 +def test_pathlike(): + assert isinstance("path/to/file", PathLike) + assert isinstance(Path("path/to/file"), PathLike) + assert not isinstance(1, PathLike) + assert not isinstance(1.0, PathLike)