Skip to content

Commit

Permalink
Csse layout 536 input_data (#358)
Browse files Browse the repository at this point in the history
* AtRes.input_data

* external_input_data

* fix print

* Apply suggestions from code review

* Update Lint.yml

* Update procedures.py
  • Loading branch information
loriab authored Nov 27, 2024
1 parent c38fa3c commit 8b56ec8
Show file tree
Hide file tree
Showing 10 changed files with 256 additions and 48 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/Lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ jobs:
python-version: "3.8"
- name: Install black
run: pip install "black>=22.1.0,<23.0a0"
- name: Print code formatting with black
run: black --diff .
- name: Check code formatting with black
run: black --check .

Expand Down
4 changes: 4 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ New Features

Enhancements
++++++++++++
- (536b) ``v1.AtomicResult.convert_v`` learned a ``external_input_data`` option to inject that field (if known) rather than using incomplete reconstruction from the v1 Result. may not be the final sol'n.
- (536b) ``v2.FailedOperation`` gained schema_name and schema_version=2.
- (536b) ``v2.AtomicResult`` no longer inherits from ``v2.AtomicInput``. It gained a ``input_data`` field for the corresponding ``AtomicInput`` and independent ``id`` and ``molecule`` fields (the latter being equivalvent to ``v1.AtomicResult.molecule`` with the frame of the results; ``v2.AtomicResult.input_data.molecule`` is new, preserving the input frame). Gained independent ``extras``
- (536b) Both v1/v2 ``AtomicResult.convert_v()`` learned to handle the new ``input_data`` layout.
- (:pr:`357`, :issue:`536`) ``v2.AtomicResult``, ``v2.OptimizationResult``, and ``v2.TorsionDriveResult`` have the ``success`` field enforced to ``True``. Previously it could be set T/F. Now validation errors if not T. Likewise ``v2.FailedOperation.success`` is enforced to ``False``.
- (:pr:`357`, :issue:`536`) ``v2.AtomicResult``, ``v2.OptimizationResult``, and ``v2.TorsionDriveResult`` have the ``error`` field removed. This isn't used now that ``success=True`` and failure should be routed to ``FailedOperation``.
- (:pr:`357`) ``v1.Molecule`` had its schema_version changed to a Literal[2] (remember Mol is one-ahead of general numbering scheme) so new instances will be 2 even if another value is passed in. Ditto ``v2.BasisSet.schema_version=2``. Ditto ``v1.BasisSet.schema_version=1`` Ditto ``v1.QCInputSpecification.schema_version=1`` and ``v1.OptimizationSpecification.schema_version=1``.
Expand Down
10 changes: 5 additions & 5 deletions qcelemental/models/v1/procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,7 @@ def convert_v(
dself = self.dict()
if version == 2:
# remove harmless empty error field that v2 won't accept. if populated, pydantic will catch it.
if dself.pop("error", None):
pass
dself.pop("error", None)

dself["trajectory"] = [trajectory_class(**atres).convert_v(version) for atres in dself["trajectory"]]
dself["input_specification"].pop("schema_version", None)
Expand Down Expand Up @@ -356,11 +355,12 @@ def convert_v(
dself = self.dict()
if version == 2:
# remove harmless empty error field that v2 won't accept. if populated, pydantic will catch it.
if dself.pop("error", None):
pass
dself.pop("error", None)

dself["input_specification"].pop("schema_version", None)
dself["optimization_spec"].pop("schema_version", None)
dself["optimization_history"] = {
(k, [opthist_class(**res).convert_v(version) for res in lst])
k: [opthist_class(**res).convert_v(version) for res in lst]
for k, lst in dself["optimization_history"].items()
}

Expand Down
47 changes: 43 additions & 4 deletions qcelemental/models/v1/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,9 +797,29 @@ def _native_file_protocol(cls, value, values):
return ret

def convert_v(
self, version: int
self,
version: int,
*,
external_input_data: Optional[Any] = None,
) -> Union["qcelemental.models.v1.AtomicResult", "qcelemental.models.v2.AtomicResult"]:
"""Convert to instance of particular QCSchema version."""
"""Convert to instance of particular QCSchema version.
Parameters
----------
version
The version to convert to.
external_input_data
Since self contains data merged from input, this allows passing in the original input, particularly for `molecule` and `extras` fields.
Can be model or dictionary and should be *already* converted to the desired version.
Replaces ``input_data`` field entirely (not merges with extracts from self) and w/o consistency checking.
Returns
-------
AtomicResult
Returns self (not a copy) if ``version`` already satisfied.
Returns a new AtomicResult of ``version`` otherwise.
"""
import qcelemental as qcel

if check_convertible_version(version, error="AtomicResult") == "self":
Expand All @@ -808,8 +828,27 @@ def convert_v(
dself = self.dict()
if version == 2:
# remove harmless empty error field that v2 won't accept. if populated, pydantic will catch it.
if dself.pop("error", None):
pass
dself.pop("error", None)

input_data = {
k: dself.pop(k) for k in list(dself.keys()) if k in ["driver", "keywords", "model", "protocols"]
}
input_data["molecule"] = dself["molecule"] # duplicate since input mol has been overwritten
# any input provenance has been overwritten
input_data["extras"] = {
k: dself["extras"].pop(k) for k in list(dself["extras"].keys()) if k in []
} # sep any merged extras
if external_input_data:
# Note: overwriting with external, not updating. reconsider?
dself["input_data"] = external_input_data
in_extras = (
external_input_data.get("extras", {})
if isinstance(external_input_data, dict)
else external_input_data.extras
)
dself["extras"] = {k: v for k, v in dself["extras"].items() if (k, v) not in in_extras.items()}
else:
dself["input_data"] = input_data

self_vN = qcel.models.v2.AtomicResult(**dself)

Expand Down
17 changes: 17 additions & 0 deletions qcelemental/models/v2/common_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,16 @@ class FailedOperation(ProtoModel):
and containing the reason and input data which generated the failure.
"""

schema_name: Literal["qcschema_failed_operation"] = Field(
"qcschema_failed_operation",
description=(
f"The QCSchema specification this model conforms to. Explicitly fixed as qcschema_failed_operation."
),
)
schema_version: Literal[2] = Field(
2,
description="The version number of :attr:`~qcelemental.models.FailedOperation.schema_name` to which this model conforms.",
)
id: Optional[str] = Field( # type: ignore
None,
description="A unique identifier which links this FailedOperation, often of the same Id of the operation "
Expand Down Expand Up @@ -132,6 +142,10 @@ class FailedOperation(ProtoModel):
def __repr_args__(self) -> "ReprArgs":
return [("error", self.error)]

@field_validator("schema_version", mode="before")
def _version_stamp(cls, v):
return 2

def convert_v(
self, version: int
) -> Union["qcelemental.models.v1.FailedOperation", "qcelemental.models.v2.FailedOperation"]:
Expand All @@ -143,6 +157,9 @@ def convert_v(

dself = self.model_dump()
if version == 1:
dself.pop("schema_name")
dself.pop("schema_version")

self_vN = qcel.models.v1.FailedOperation(**dself)

return self_vN
Expand Down
14 changes: 14 additions & 0 deletions qcelemental/models/v2/procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,11 @@ def convert_v(

dself = self.model_dump()
if version == 1:
trajectory_class = self.trajectory[0].__class__

dself["trajectory"] = [trajectory_class(**atres).convert_v(version) for atres in dself["trajectory"]]
dself["input_specification"].pop("schema_version", None)

self_vN = qcel.models.v1.OptimizationResult(**dself)

return self_vN
Expand Down Expand Up @@ -297,6 +301,9 @@ def convert_v(

dself = self.model_dump()
if version == 1:
if dself["optimization_spec"].pop("extras", None):
pass

self_vN = qcel.models.v1.TorsionDriveInput(**dself)

return self_vN
Expand Down Expand Up @@ -350,9 +357,16 @@ def convert_v(

dself = self.model_dump()
if version == 1:
opthist_class = next(iter(self.optimization_history.values()))[0].__class__

if dself["optimization_spec"].pop("extras", None):
pass

dself["optimization_history"] = {
k: [opthist_class(**res).convert_v(version) for res in lst]
for k, lst in dself["optimization_history"].items()
}

self_vN = qcel.models.v1.TorsionDriveResult(**dself)

return self_vN
43 changes: 32 additions & 11 deletions qcelemental/models/v2/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ def convert_v(
return self_vN


class AtomicResult(AtomicInput):
class AtomicResult(ProtoModel):
r"""Results from a CMS program execution."""

schema_name: constr(strip_whitespace=True, pattern=r"^(qc\_?schema_output)$") = Field( # type: ignore
Expand All @@ -736,6 +736,9 @@ class AtomicResult(AtomicInput):
2,
description="The version number of :attr:`~qcelemental.models.AtomicResult.schema_name` to which this model conforms.",
)
id: Optional[str] = Field(None, description="The optional ID for the computation.")
input_data: AtomicInput = Field(..., description=str(AtomicInput.__doc__))
molecule: Molecule = Field(..., description="The molecule with frame and orientation of the results.")
properties: AtomicResultProperties = Field(..., description=str(AtomicResultProperties.__doc__))
wavefunction: Optional[WavefunctionProperties] = Field(None, description=str(WavefunctionProperties.__doc__))

Expand All @@ -755,6 +758,10 @@ class AtomicResult(AtomicInput):
True, description="The success of program execution. If False, other fields may be blank."
)
provenance: Provenance = Field(..., description=str(Provenance.__doc__))
extras: Dict[str, Any] = Field(
{},
description="Additional information to bundle with the computation. Use for schema development and scratch space.",
)

@field_validator("schema_name", mode="before")
@classmethod
Expand All @@ -774,12 +781,16 @@ def _version_stamp(cls, v):
@field_validator("return_result")
@classmethod
def _validate_return_result(cls, v, info):
if info.data["driver"] == "energy":
# Do not propagate validation errors
if "input_data" not in info.data:
raise ValueError("Input_data was not properly formed.")
driver = info.data["input_data"].driver
if driver == "energy":
if isinstance(v, np.ndarray) and v.size == 1:
v = v.item(0)
elif info.data["driver"] == "gradient":
elif driver == "gradient":
v = np.asarray(v).reshape(-1, 3)
elif info.data["driver"] == "hessian":
elif driver == "hessian":
v = np.asarray(v)
nsq = int(v.size**0.5)
v.shape = (nsq, nsq)
Expand All @@ -800,8 +811,8 @@ def _wavefunction_protocol(cls, value, info):
raise ValueError("wavefunction must be None, a dict, or a WavefunctionProperties object.")

# Do not propagate validation errors
if "protocols" not in info.data:
raise ValueError("Protocols was not properly formed.")
if "input_data" not in info.data:
raise ValueError("Input_data was not properly formed.")

# Handle restricted
restricted = wfn.get("restricted", None)
Expand All @@ -814,7 +825,7 @@ def _wavefunction_protocol(cls, value, info):
wfn.pop(k)

# Handle protocols
wfnp = info.data["protocols"].wavefunction
wfnp = info.data["input_data"].protocols.wavefunction
return_keep = None
if wfnp == "all":
pass
Expand Down Expand Up @@ -861,10 +872,10 @@ def _wavefunction_protocol(cls, value, info):
@classmethod
def _stdout_protocol(cls, value, info):
# Do not propagate validation errors
if "protocols" not in info.data:
raise ValueError("Protocols was not properly formed.")
if "input_data" not in info.data:
raise ValueError("Input_data was not properly formed.")

outp = info.data["protocols"].stdout
outp = info.data["input_data"].protocols.stdout
if outp is True:
return value
elif outp is False:
Expand All @@ -875,7 +886,11 @@ def _stdout_protocol(cls, value, info):
@field_validator("native_files")
@classmethod
def _native_file_protocol(cls, value, info):
ancp = info.data["protocols"].native_files
# Do not propagate validation errors
if "input_data" not in info.data:
raise ValueError("Input_data was not properly formed.")

ancp = info.data["input_data"].protocols.native_files
if ancp == "all":
return value
elif ancp == "none":
Expand Down Expand Up @@ -905,6 +920,12 @@ def convert_v(

dself = self.model_dump()
if version == 1:
# input_data = self.input_data.convert_v(1) # TODO probably later
input_data = dself.pop("input_data")
input_data.pop("molecule", None) # discard
input_data.pop("provenance", None) # discard
dself["extras"] = {**input_data.pop("extras", {}), **dself.pop("extras", {})} # merge
dself = {**input_data, **dself}
self_vN = qcel.models.v1.AtomicResult(**dself)

return self_vN
Loading

0 comments on commit 8b56ec8

Please sign in to comment.