Skip to content
This repository has been archived by the owner on Dec 11, 2022. It is now read-only.

Add dump pickle method for environments #291

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions rl_coach/base_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ def __init__(self,
dump_signals_to_csv_every_x_episodes=5,
dump_gifs=False,
dump_mp4=False,
dump_pickle=False,
video_dump_methods=None,
dump_in_episode_signals=False,
dump_parameters_documentation=True,
Expand All @@ -454,6 +455,9 @@ def __init__(self,
:param dump_mp4:
If set to True, MP4 videos of the environment will be stored into the experiment directory according to
the filters defined in video_dump_methods.
:param dump_pickle:
If set to True, gzip-compressed pickles of the rendered episode frames will be stored into the experiment
directory according to the filters defined in video_dump_methods.
:param dump_in_episode_signals:
If set to True, csv files will be dumped for each episode for inspecting different metrics within the
episode. This means that for each step in each episode, different metrics such as the reward, the
Expand Down Expand Up @@ -496,6 +500,7 @@ def __init__(self,
self.dump_csv = dump_csv
self.dump_gifs = dump_gifs
self.dump_mp4 = dump_mp4
self.dump_pickle = dump_pickle
self.dump_signals_to_csv_every_x_episodes = dump_signals_to_csv_every_x_episodes
self.dump_in_episode_signals = dump_in_episode_signals
self.dump_parameters_documentation = dump_parameters_documentation
Expand Down
6 changes: 5 additions & 1 deletion rl_coach/environments/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,8 @@ def step(self, action: ActionType) -> EnvResponse:

# store observations for video / gif dumping
if self.should_dump_video_of_the_current_episode(episode_terminated=False) and \
(self.visualization_parameters.dump_mp4 or self.visualization_parameters.dump_gifs):
(self.visualization_parameters.dump_mp4 or self.visualization_parameters.dump_gifs
or self.visualization_parameters.dump_pickle):
self.last_episode_images.append(self.get_rendered_image())

return self.last_env_response
Expand Down Expand Up @@ -442,6 +443,9 @@ def dump_video_of_last_episode(self):
logger.create_gif(self.last_episode_images[::frame_skipping], name=file_name, fps=fps)
if self.visualization_parameters.dump_mp4:
logger.create_mp4(self.last_episode_images[::frame_skipping], name=file_name, fps=fps)
if self.visualization_parameters.dump_pickle:
logger.create_pickle(self.last_episode_images[::frame_skipping], name=file_name)


# The following functions define the interaction with the environment.
# Any new environment that inherits the Environment class should use these signatures.
Expand Down
13 changes: 13 additions & 0 deletions rl_coach/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
#
import atexit
import datetime
import gzip
import os
import pickle
import re
import shutil
import signal
Expand Down Expand Up @@ -359,6 +361,17 @@ def create_mp4(images, fps=10, name="mp4"):
p.wait()


def create_pickle(images, name='pickle'):
    """
    Serialize ``images`` into a gzip-compressed pickle file.

    The file is written to a ``pickles`` sub-directory of the experiment directory, or of the current
    working directory when the experiment path is not set. The file name is the current timestamp
    (one-second resolution) followed by ``name`` and the ``.pkl.gz`` extension.

    :param images: the object to pickle; the environment passes the list of rendered episode frames
    :param name: suffix placed after the timestamp in the output file name
    :return: None
    """
    # experiment_path is a module-level global that is only read here, so no `global` statement is needed
    output_file = '{}_{}.pkl.gz'.format(datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S'), name)
    output_dir = os.path.join(experiment_path or os.getcwd(), 'pickles')

    # exist_ok=True avoids a FileExistsError if the directory is created (e.g. by another process)
    # between an existence check and the creation call
    os.makedirs(output_dir, exist_ok=True)

    output_path = os.path.join(output_dir, output_file)
    with gzip.open(output_path, 'wb') as fh:
        pickle.dump(images, fh)


def remove_experiment_dir():
    """
    Recursively delete the directory tree located at the module-level ``experiment_path``.

    :return: None
    """
    directory_to_delete = experiment_path
    shutil.rmtree(directory_to_delete)

Expand Down