Skip to content

Commit

Permalink
abiml: add support for orb and sevenn
Browse files Browse the repository at this point in the history
  • Loading branch information
gmatteo committed Sep 4, 2024
1 parent 72b2a5d commit fd09860
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions abipy/ml/aseml.py
Original file line number Diff line number Diff line change
Expand Up @@ -1466,6 +1466,8 @@ class CalcBuilder:
"nequip",
"metatensor",
"deepmd",
"orb",
"sevenn",
]

def __init__(self, name: str, dftd3_args=None, **kwargs):
Expand Down Expand Up @@ -1740,6 +1742,45 @@ class MyDpCalculator(_MyCalculator, DP):
cls = MyDpCalculator if with_delta else DP
calc = cls(self.model_path, **self.calc_kwargs)

elif self.nn_type == "orb":
try:
from orb_models.forcefield import pretrained
from orb_models.forcefield.calculator import ORBCalculator
except ImportError as exc:
raise ImportError("orb not installed. See https://github.com/orbital-materials/orb-models") from exc


class MyOrbCalculator(_MyCalculator, ORBCalculator):
"""Add abi_forces and abi_stress"""

model_name = "orb-v1" if self.model_name is None else self.model_name
# Mapping model_name --> function returning the model e.g. {"orb-v1": orb_v1}
f = pretrained.ORB_PRETRAINED_MODELS[model_name]
model = f()

cls = MyOrbCalculator if with_delta else OrbCalculator
calc = cls(model, **self.calc_kwargs)

elif self.nn_type == "sevenn":
try:
from sevenn.sevennet_calculator import SevenNetCalculator
except ImportError as exc:
raise ImportError("sevenn not installed. See https://github.com/MDIL-SNU/SevenNet") from exc

class MySevenNetCalculator(_MyCalculator, SevenNetCalculator):
"""Add abi_forces and abi_stress"""

# 7net-0, SevenNet-0, 7net-0_22May2024, 7net-0_11July2024 ...
# model_name = "7net-0" if self.model_name is None else self.model_name
# SevenNet-0 (11July2024)
# This model was trained on MPtrj. We suggest starting with this model as we found that it performs better
# than the previous SevenNet-0 (22May2024).
# Check Matbench Discovery leaderborad for this model's performance on materials discovery. For more information, click here.

model_name = "SevenNet-0" if self.model_name is None else self.model_name
cls = MySevenNetCalculator if with_delta else SevenNetCalculator
calc = MySevenNetCalculator(model=model_name, **self.calc_kwargs)

else:
raise ValueError(f"Invalid {self.nn_type=}")

Expand Down

0 comments on commit fd09860

Please sign in to comment.