Skip to content

Commit

Permalink
Fixed bugs with .gz file handling.
Browse files: browse the repository at this point in the history.
Added more control over configurations to write and how many per file.
  • Loading branch information
paulsaxe committed Sep 27, 2022
1 parent dd1fbd8 commit 3f44ffa
Showing 1 changed file with 72 additions and 31 deletions.
103 changes: 72 additions & 31 deletions read_structure_step/write_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from seamm_util import ureg, Q_ # noqa: F401
import seamm_util.printing as printing
from seamm_util.printing import FormattedText as __
from .utils import guess_extension

logger = logging.getLogger(__name__)
job = printing.getPrinter()
Expand Down Expand Up @@ -119,19 +118,20 @@ def run(self):

# What type of file?
filename = P["file"].strip()
path = PurePath(filename)
file_type = P["file type"]

if file_type != "from extension":
extension = file_type.split()[0]
else:
path = PurePath(filename)
extension = path.suffix
if extension == ".gz":
extension = path.stem.suffix
extension = path.with_suffix("").suffix

if extension == "":
extension = guess_extension(filename, use_file_name=False)
P["file type"] = extension
raise RuntimeError(
"Can't write the file without knowing the type (extension)"
)

# Print what we are doing
printer.important(__(self.description_text(P), indent=4 * " "))
Expand All @@ -141,37 +141,78 @@ def run(self):
system, configuration = self.get_system_configuration(P)

structures = P["structures"]
configs = P["configurations"]
errors = not P["ignore missing"]
configurations = []
if structures == "current configuration":
configurations = [configuration]
elif structures == "all configurations of current system":
for configuration in system.configurations():
configurations.append(configuration)
elif structures == "all systems":
configurations = []
for system in system_db.systems():
for configuration in system.configurations():
configurations.append(configuration)
elif structures == "current system":
if configs == "all":
for configuration in system.configurations:
configurations.append(configuration)

write(
filename,
configurations,
extension=extension,
remove_hydrogens=P["remove hydrogens"],
printer=printer.important,
references=self.references,
bibliography=self._bibliography,
)
else:
cid = system.get_configuration_id(configs, errors=errors)
if cid is not None:
configurations.append(system.get_configuration(cid))
elif structures == "all systems":
if configs == "all":
for system in system_db.systems:
for configuration in system.configurations:
configurations.append(configuration)
else:
for system in system_db.systems:
cid = system.get_configuration_id(configs, errors=errors)
if cid is not None:
configurations.append(system.get_configuration(cid))

n_per_file = P["number per file"]
n_configurations = len(configurations)
if n_per_file == "all" or n_configurations <= n_per_file:
write(
filename,
configurations,
extension=extension,
remove_hydrogens=P["remove hydrogens"],
printer=printer.important,
references=self.references,
bibliography=self._bibliography,
)
else:
n_per_file = int(n_per_file)
if path.suffix == ".gz":
base = path.with_suffix("")
suffix = base.suffix + ".gz"
stem = str(base.with_suffix(""))
else:
suffix = path.suffix
stem = str(path.with_suffix(""))
last = 1 # Note that counting from 1 for users.

while last <= n_configurations:
first = last
last += n_per_file
tmp_name = stem + f"_{first}" + suffix
write(
tmp_name,
configurations[first - 1 : last - 1],
extension=extension,
remove_hydrogens=P["remove hydrogens"],
printer=printer.important,
references=self.references,
bibliography=self._bibliography,
)

# Finish the output
printer.important(
__(
f"\n Wrote the structure with {configuration.n_atoms} "
"atoms."
f"\n System name = {system.name}"
f"\n Configuration name = {configuration.name}",
indent=4 * " ",
if n_configurations == 1:
printer.important(
__(
f"\n Wrote the structure with {configuration.n_atoms} "
"atoms."
f"\n System name = {system.name}"
f"\n Configuration name = {configuration.name}",
indent=4 * " ",
)
)
)
printer.important("")

return next_node

0 comments on commit 3f44ffa

Please sign in to comment.