Skip to content

Commit

Permalink
Merge pull request #12 from mattwthompson/new-qcarhive
Browse files Browse the repository at this point in the history
Use new QCArchive APIs
  • Loading branch information
mattwthompson authored Dec 5, 2023
2 parents 1b661cf + 58277d5 commit bf3dcb4
Show file tree
Hide file tree
Showing 8 changed files with 25 additions and 22 deletions.
8 changes: 4 additions & 4 deletions devtools/conda-envs/dev.yaml
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
name: ib-dev
channels:
- conda-forge
- openeye
- conda-forge
dependencies:
- python

- openff-toolkit =0.14.3
- openff-qcsubmit
- openff-qcsubmit =0.50.1
- openmmforcefields
- smirnoff-plugins =2023.08.0
- espaloma
# espaloma =0.3

- ipython
- ipdb
- pre-commit

- openeye-toolkits
- openeye::openeye-toolkits
- rich

- pytest
Expand Down
4 changes: 2 additions & 2 deletions ibstore/_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class DBQMConformerRecord(DBBase):
id = Column(Integer, primary_key=True, index=True)
parent_id = Column(Integer, ForeignKey("molecules.id"), nullable=False, index=True)

qcarchive_id = Column(String(20), nullable=False)
qcarchive_id = Column(Integer, nullable=False)

mapped_smiles = Column(String, nullable=False)
coordinates = Column(PickleType, nullable=False)
Expand All @@ -33,7 +33,7 @@ class DBMMConformerRecord(DBBase):
id = Column(Integer, primary_key=True, index=True)
parent_id = Column(Integer, ForeignKey("molecules.id"), nullable=False, index=True)

qcarchive_id = Column(String(20), nullable=False)
qcarchive_id = Column(Integer, nullable=False)
force_field = Column(String, nullable=False)

mapped_smiles = Column(String, nullable=False)
Expand Down
8 changes: 4 additions & 4 deletions ibstore/_forcefields.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _espaloma(molecule: Molecule, force_field_name: str) -> openmm.System:
espaloma.graphs.deploy.openmm_system_from_graph, where it will be appended
with .offxml. Raises a ValueError if there is no dash in force_field_name.
"""
import espaloma as esp
import espaloma

if not force_field_name.startswith("espaloma"):
raise NotImplementedError(f"Force field {force_field_name} not implemented.")
Expand All @@ -85,8 +85,8 @@ def _espaloma(molecule: Molecule, force_field_name: str) -> openmm.System:
else:
ff = ff[0]

mol_graph = esp.Graph(molecule)
model = esp.get_model("latest")
mol_graph = espaloma.Graph(molecule)
model = espaloma.get_model("latest")
model(mol_graph.heterograph)

return esp.graphs.deploy.openmm_system_from_graph(mol_graph, forcefield=ff)
return espaloma.graphs.deploy.openmm_system_from_graph(mol_graph, forcefield=ff)
4 changes: 2 additions & 2 deletions ibstore/_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def store_qm_conformer_record(

def _qm_conformer_already_exists(
self,
qcarchive_id: str,
qcarchive_id: int,
) -> bool:
records = self.db.query(
DBQMConformerRecord.qcarchive_id,
Expand All @@ -214,7 +214,7 @@ def store_mm_conformer_record(

def _mm_conformer_already_exists(
self,
qcarchive_id: str,
qcarchive_id: int,
force_field: str,
) -> bool:
records = (
Expand Down
4 changes: 4 additions & 0 deletions ibstore/_tests/unit_tests/test_forcefields.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,16 @@ def test_gaff_unsupported(molecule):


def test_espaloma_basic(molecule):
pytest.importorskip("espaloma")

system = _espaloma(molecule, "espaloma-openff_unconstrained-2.1.0")

assert isinstance(system, openmm.System)
assert system.getNumParticles() == molecule.n_atoms


def test_espaloma_unsupported(molecule):
pytest.importorskip("espaloma")

with pytest.raises(NotImplementedError):
_espaloma(molecule, "foo")
2 changes: 1 addition & 1 deletion ibstore/_tests/unit_tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_load_from_qcsubmit(small_collection):
assert molecule_record.inchi_key == ichi_key

assert isinstance(qm_conformer, QMConformerRecord)
assert qm_conformer.energy == qc_record.get_final_energy() * hartree2kcalmol
assert qm_conformer.energy == qc_record.energies[-1] * hartree2kcalmol
assert qm_conformer.qcarchive_id == qc_record.id
assert numpy.allclose(
qm_conformer.coordinates,
Expand Down
6 changes: 3 additions & 3 deletions ibstore/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class DDE(ImmutableModel):
qcarchive_id: str
qcarchive_id: int
force_field: str
difference: float

Expand All @@ -25,7 +25,7 @@ def to_csv(self, path: str):


class RMSD(ImmutableModel):
qcarchive_id: str
qcarchive_id: int
force_field: str
rmsd: float

Expand All @@ -43,7 +43,7 @@ def to_csv(self, path: str):


class TFD(ImmutableModel):
qcarchive_id: str
qcarchive_id: int
force_field: str
tfd: float

Expand Down
11 changes: 5 additions & 6 deletions ibstore/models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from typing import TypeVar
from typing import Any, TypeVar

import qcelemental
from openff.toolkit import Molecule
from pydantic import Field
from qcportal.models.records import OptimizationRecord

from ibstore._base.array import Array
from ibstore._base.base import ImmutableModel
Expand All @@ -26,7 +25,7 @@ class QMConformerRecord(Record):
...,
description="The ID of the molecule in the database",
)
qcarchive_id: str = Field(
qcarchive_id: int = Field(
...,
description="The ID of the molecule in the QCArchive database",
)
Expand All @@ -50,15 +49,15 @@ def from_qcarchive_record(
cls,
molecule_id: int,
mapped_smiles: str,
qc_record: OptimizationRecord,
qc_record: Any, # qcportal.optimization.OptimizationRecord ?
coordinates,
):
return cls(
molecule_id=molecule_id,
qcarchive_id=qc_record.id,
mapped_smiles=mapped_smiles,
coordinates=coordinates,
energy=qc_record.get_final_energy() * hartree2kcalmol,
energy=qc_record.energies[-1] * hartree2kcalmol,
)


Expand All @@ -67,7 +66,7 @@ class MMConformerRecord(Record):
...,
description="The ID of the molecule in the database",
)
qcarchive_id: str = Field(
qcarchive_id: int = Field(
...,
description="The ID of the molecule in the QCArchive database that this conformer corresponds to",
)
Expand Down

0 comments on commit bf3dcb4

Please sign in to comment.