diff --git a/docs/conf.py b/docs/conf.py index 088f8a067..712908eae 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Configuration file for the Sphinx documentation builder. # @@ -46,7 +45,7 @@ def __getattr__(cls, name): # Read version from file version_file = os.path.join(os.path.dirname(__file__), "../stable_baselines3", "version.txt") -with open(version_file, "r") as file_handler: +with open(version_file) as file_handler: __version__ = file_handler.read().strip() # -- Project information ----------------------------------------------------- diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index da3376688..409b672e4 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.5.1a4 (WIP) +Release 1.5.1a5 (WIP) --------------------------- Breaking Changes: @@ -31,6 +31,7 @@ Deprecations: Others: ^^^^^^^ +- Upgraded to Python 3.7+ syntax using ``pyupgrade`` Documentation: ^^^^^^^^^^^^^^ diff --git a/setup.py b/setup.py index bb53f06b5..cd8f2095a 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import find_packages, setup -with open(os.path.join("stable_baselines3", "version.txt"), "r") as file_handler: +with open(os.path.join("stable_baselines3", "version.txt")) as file_handler: __version__ = file_handler.read().strip() diff --git a/stable_baselines3/__init__.py b/stable_baselines3/__init__.py index 4e31c5b3b..d73f5f095 100644 --- a/stable_baselines3/__init__.py +++ b/stable_baselines3/__init__.py @@ -11,7 +11,7 @@ # Read version from file version_file = os.path.join(os.path.dirname(__file__), "version.txt") -with open(version_file, "r") as file_handler: +with open(version_file) as file_handler: __version__ = file_handler.read().strip() diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py index eeeb670c3..13adf6800 100644 --- a/stable_baselines3/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -82,7 +82,7 @@ def __init__( _init_setup_model: bool = True, ): - super(A2C, self).__init__( + super().__init__( policy, env, learning_rate=learning_rate, @@ -194,7 +194,7 @@ def learn( reset_num_timesteps: bool = True, ) -> "A2C": - return super(A2C, self).learn( + return super().learn( total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index 832ad9f23..a9b2eca1f 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -245,4 +245,4 @@ def __init__( if clip_reward: env = ClipRewardEnv(env) - super(AtariWrapper, self).__init__(env) + super().__init__(env) diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index bba2272a5..d7728cbeb 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -42,7 +42,7 @@ def __init__( device: Union[th.device, str] = "cpu", n_envs: int = 1, ): - super(BaseBuffer, self).__init__() + super().__init__() self.buffer_size = buffer_size self.observation_space = observation_space self.action_space = action_space @@ -179,7 +179,7 @@ def __init__( optimize_memory_usage: bool = False, handle_timeout_termination: bool = True, ): - super(ReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) + super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) # Adjust buffer size self.buffer_size = max(buffer_size // n_envs, 
1) @@ -339,7 +339,7 @@ def __init__( n_envs: int = 1, ): - super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) + super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) self.gae_lambda = gae_lambda self.gamma = gamma self.observations, self.actions, self.rewards, self.advantages = None, None, None, None @@ -358,7 +358,7 @@ def reset(self) -> None: self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.generator_ready = False - super(RolloutBuffer, self).reset() + super().reset() def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None: """ diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index 27ce5e639..c5f297c7c 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -19,7 +19,7 @@ class BaseCallback(ABC): """ def __init__(self, verbose: int = 0): - super(BaseCallback, self).__init__() + super().__init__() # The RL model self.model = None # type: Optional[base_class.BaseAlgorithm] # An alias for self.model.get_env(), the environment used for training @@ -127,14 +127,14 @@ class EventCallback(BaseCallback): """ def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0): - super(EventCallback, self).__init__(verbose=verbose) + super().__init__(verbose=verbose) self.callback = callback # Give access to the parent if callback is not None: self.callback.parent = self def init_callback(self, model: "base_class.BaseAlgorithm") -> None: - super(EventCallback, self).init_callback(model) + super().init_callback(model) if self.callback is not None: self.callback.init_callback(self.model) @@ -169,7 +169,7 @@ class CallbackList(BaseCallback): """ def __init__(self, callbacks: List[BaseCallback]): - super(CallbackList, self).__init__() + super().__init__() assert isinstance(callbacks, list) self.callbacks = callbacks @@ -228,7 +228,7 @@ class CheckpointCallback(BaseCallback): """ def __init__(self, save_freq: int, save_path: str, name_prefix: str = "rl_model", verbose: int = 0): - super(CheckpointCallback, self).__init__(verbose) + super().__init__(verbose) self.save_freq = save_freq self.save_path = save_path self.name_prefix = name_prefix @@ -256,7 +256,7 @@ class ConvertCallback(BaseCallback): """ def __init__(self, callback: Callable[[Dict[str, Any], Dict[str, Any]], bool], verbose: int = 0): - super(ConvertCallback, self).__init__(verbose) + super().__init__(verbose) self.callback = callback def _on_step(self) -> bool: @@ -307,7 +307,7 @@ def __init__( verbose: int = 1, warn: bool = True, ): - super(EvalCallback, self).__init__(callback_after_eval, verbose=verbose) + super().__init__(callback_after_eval, verbose=verbose) self.callback_on_new_best = callback_on_new_best if self.callback_on_new_best is not None: @@ -480,7 +480,7 @@ class StopTrainingOnRewardThreshold(BaseCallback): """ def __init__(self, reward_threshold: float, verbose: int = 0): - super(StopTrainingOnRewardThreshold, self).__init__(verbose=verbose) + super().__init__(verbose=verbose) self.reward_threshold = reward_threshold def _on_step(self) -> bool: @@ -505,7 +505,7 @@ class EveryNTimesteps(EventCallback): """ def __init__(self, n_steps: int, callback: BaseCallback): - super(EveryNTimesteps, self).__init__(callback) + super().__init__(callback) self.n_steps = n_steps self.last_time_trigger = 0 @@ -528,7 
+528,7 @@ class StopTrainingOnMaxEpisodes(BaseCallback): """ def __init__(self, max_episodes: int, verbose: int = 0): - super(StopTrainingOnMaxEpisodes, self).__init__(verbose=verbose) + super().__init__(verbose=verbose) self.max_episodes = max_episodes self._total_max_episodes = max_episodes self.n_episodes = 0 @@ -573,7 +573,7 @@ class StopTrainingOnNoModelImprovement(BaseCallback): """ def __init__(self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0): - super(StopTrainingOnNoModelImprovement, self).__init__(verbose=verbose) + super().__init__(verbose=verbose) self.max_no_improvement_evals = max_no_improvement_evals self.min_evals = min_evals self.last_best_mean_reward = -np.inf diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 1c0e54a88..3d1ff5aa0 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -16,7 +16,7 @@ class Distribution(ABC): """Abstract base class for distributions.""" def __init__(self): - super(Distribution, self).__init__() + super().__init__() self.distribution = None @abstractmethod @@ -120,7 +120,7 @@ class DiagGaussianDistribution(Distribution): """ def __init__(self, action_dim: int): - super(DiagGaussianDistribution, self).__init__() + super().__init__() self.action_dim = action_dim self.mean_actions = None self.log_std = None @@ -201,13 +201,13 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution): """ def __init__(self, action_dim: int, epsilon: float = 1e-6): - super(SquashedDiagGaussianDistribution, self).__init__(action_dim) + super().__init__(action_dim) # Avoid NaN (prevents division by zero or log of zero) self.epsilon = epsilon self.gaussian_actions = None def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "SquashedDiagGaussianDistribution": - super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std) + super().proba_distribution(mean_actions, log_std) return self def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor: @@ -219,7 +219,7 @@ def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = N gaussian_actions = TanhBijector.inverse(actions) # Log likelihood for a Gaussian distribution - log_prob = super(SquashedDiagGaussianDistribution, self).log_prob(gaussian_actions) + log_prob = super().log_prob(gaussian_actions) # Squash correction (from original SAC implementation) # this comes from the fact that tanh is bijective and differentiable log_prob -= th.sum(th.log(1 - actions**2 + self.epsilon), dim=1) @@ -254,7 +254,7 @@ class CategoricalDistribution(Distribution): """ def __init__(self, action_dim: int): - super(CategoricalDistribution, self).__init__() + super().__init__() self.action_dim = action_dim def proba_distribution_net(self, latent_dim: int) -> nn.Module: @@ -305,7 +305,7 @@ class MultiCategoricalDistribution(Distribution): """ def __init__(self, action_dims: List[int]): - super(MultiCategoricalDistribution, self).__init__() + super().__init__() self.action_dims = action_dims def proba_distribution_net(self, latent_dim: int) -> nn.Module: @@ -360,7 +360,7 @@ class BernoulliDistribution(Distribution): """ def __init__(self, action_dims: int): - super(BernoulliDistribution, self).__init__() + super().__init__() self.action_dims = action_dims def proba_distribution_net(self, latent_dim: int) -> nn.Module: @@ -433,7 +433,7 @@ def __init__( learn_features: bool = False, epsilon: float = 
1e-6, ): - super(StateDependentNoiseDistribution, self).__init__() + super().__init__() self.action_dim = action_dim self.latent_sde_dim = None self.mean_actions = None @@ -597,7 +597,7 @@ def log_prob_from_params( return actions, log_prob -class TanhBijector(object): +class TanhBijector: """ Bijective transformation of a probability distribution using a squashing function (tanh) @@ -607,7 +607,7 @@ class TanhBijector(object): """ def __init__(self, epsilon: float = 1e-6): - super(TanhBijector, self).__init__() + super().__init__() self.epsilon = epsilon @staticmethod diff --git a/stable_baselines3/common/envs/bit_flipping_env.py b/stable_baselines3/common/envs/bit_flipping_env.py index c5d713aa2..a881b32c9 100644 --- a/stable_baselines3/common/envs/bit_flipping_env.py +++ b/stable_baselines3/common/envs/bit_flipping_env.py @@ -36,7 +36,7 @@ def __init__( image_obs_space: bool = False, channel_first: bool = True, ): - super(BitFlippingEnv, self).__init__() + super().__init__() # Shape of the observation when using image space self.image_shape = (1, 36, 36) if channel_first else (36, 36, 1) # The achieved goal is determined by the current state @@ -115,7 +115,7 @@ def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]: if self.discrete_obs_space: # The internal state is the binary representation of the # observed one - return int(sum([state[i] * 2**i for i in range(len(state))])) + return int(sum(state[i] * 2**i for i in range(len(state)))) if self.image_obs_space: size = np.prod(self.image_shape) @@ -135,7 +135,7 @@ def convert_to_bit_vector(self, state: Union[int, np.ndarray], batch_size: int) if isinstance(state, int): state = np.array(state).reshape(batch_size, -1) # Convert to binary representation - state = (((state[:, :] & (1 << np.arange(len(self.state))))) > 0).astype(int) + state = ((state[:, :] & (1 << np.arange(len(self.state)))) > 0).astype(int) elif self.image_obs_space: state = state.reshape(batch_size, -1)[:, : len(self.state)] / 255 else: diff --git a/stable_baselines3/common/envs/multi_input_envs.py b/stable_baselines3/common/envs/multi_input_envs.py index 177a64166..2e5f13f61 100644 --- a/stable_baselines3/common/envs/multi_input_envs.py +++ b/stable_baselines3/common/envs/multi_input_envs.py @@ -42,7 +42,7 @@ def __init__( discrete_actions: bool = True, channel_last: bool = True, ): - super(SimpleMultiObsEnv, self).__init__() + super().__init__() self.vector_size = 5 if channel_last: diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py index 6493a3e0d..7cc3d0a30 100644 --- a/stable_baselines3/common/logger.py +++ b/stable_baselines3/common/logger.py @@ -24,7 +24,7 @@ DISABLED = 50 -class Video(object): +class Video: """ Video data class storing the video frames and the frame per seconds @@ -37,7 +37,7 @@ def __init__(self, frames: th.Tensor, fps: Union[float, int]): self.fps = fps -class Figure(object): +class Figure: """ Figure data class storing a matplotlib figure and whether to close the figure after logging it @@ -50,7 +50,7 @@ def __init__(self, figure: plt.figure, close: bool): self.close = close -class Image(object): +class Image: """ Image data class storing an image and data format @@ -80,13 +80,13 @@ def __init__(self, unsupported_formats: Sequence[str], value_description: str): format_str = f"formats {', '.join(unsupported_formats)} are" else: format_str = f"format {unsupported_formats[0]} is" - super(FormatUnsupportedError, self).__init__( + super().__init__( f"The {format_str} not supported for the 
{value_description} value logged.\n" f"You can exclude formats via the `exclude` parameter of the logger's `record` function." ) -class KVWriter(object): +class KVWriter: """ Key Value writer """ @@ -108,7 +108,7 @@ def close(self) -> None: raise NotImplementedError -class SeqWriter(object): +class SeqWriter: """ sequence writer """ @@ -427,7 +427,7 @@ def make_output_format(_format: str, log_dir: str, log_suffix: str = "") -> KVWr # ================================================================ -class Logger(object): +class Logger: """ The logger class. @@ -623,7 +623,7 @@ def read_json(filename: str) -> pandas.DataFrame: :return: the data in the json """ data = [] - with open(filename, "rt") as file_handler: + with open(filename) as file_handler: for line in file_handler: data.append(json.loads(line)) return pandas.DataFrame(data) diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py index 04cda2242..a482b72be 100644 --- a/stable_baselines3/common/monitor.py +++ b/stable_baselines3/common/monitor.py @@ -36,7 +36,7 @@ def __init__( reset_keywords: Tuple[str, ...] = (), info_keywords: Tuple[str, ...] = (), ): - super(Monitor, self).__init__(env=env) + super().__init__(env=env) self.t_start = time.time() if filename is not None: self.results_writer = ResultsWriter( @@ -110,7 +110,7 @@ def close(self) -> None: """ Closes the environment """ - super(Monitor, self).close() + super().close() if self.results_writer is not None: self.results_writer.close() @@ -224,7 +224,7 @@ def load_results(path: str) -> pandas.DataFrame: raise LoadMonitorResultsError(f"No monitor files of the form *{Monitor.EXT} found in {path}") data_frames, headers = [], [] for file_name in monitor_files: - with open(file_name, "rt") as file_handler: + with open(file_name) as file_handler: first_line = file_handler.readline() assert first_line[0] == "#" header = json.loads(first_line[1:]) diff --git a/stable_baselines3/common/noise.py b/stable_baselines3/common/noise.py index b1db6f4f2..119ed362e 100644 --- a/stable_baselines3/common/noise.py +++ b/stable_baselines3/common/noise.py @@ -11,7 +11,7 @@ class ActionNoise(ABC): """ def __init__(self): - super(ActionNoise, self).__init__() + super().__init__() def reset(self) -> None: """ @@ -35,7 +35,7 @@ class NormalActionNoise(ActionNoise): def __init__(self, mean: np.ndarray, sigma: np.ndarray): self._mu = mean self._sigma = sigma - super(NormalActionNoise, self).__init__() + super().__init__() def __call__(self) -> np.ndarray: return np.random.normal(self._mu, self._sigma) @@ -72,7 +72,7 @@ def __init__( self.initial_noise = initial_noise self.noise_prev = np.zeros_like(self._mu) self.reset() - super(OrnsteinUhlenbeckActionNoise, self).__init__() + super().__init__() def __call__(self) -> np.ndarray: noise = ( diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index 5905deec3..ca57166f8 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -102,7 +102,7 @@ def __init__( supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None, ): - super(OffPolicyAlgorithm, self).__init__( + super().__init__( policy=policy, env=env, learning_rate=learning_rate, diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 281758c0b..763c108e9 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -72,7 
+72,7 @@ def __init__( supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None, ): - super(OnPolicyAlgorithm, self).__init__( + super().__init__( policy=policy, env=env, learning_rate=learning_rate, diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index c322dc6f1..51a3d378a 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -67,7 +67,7 @@ def __init__( optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, ): - super(BaseModel, self).__init__() + super().__init__() if optimizer_kwargs is None: optimizer_kwargs = {} @@ -267,7 +267,7 @@ class BasePolicy(BaseModel): """ def __init__(self, *args, squash_output: bool = False, **kwargs): - super(BasePolicy, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self._squash_output = squash_output @staticmethod @@ -437,7 +437,7 @@ def __init__( if optimizer_class == th.optim.Adam: optimizer_kwargs["eps"] = 1e-5 - super(ActorCriticPolicy, self).__init__( + super().__init__( observation_space, action_space, features_extractor_class, @@ -724,7 +724,7 @@ def __init__( optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, ): - super(ActorCriticCnnPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, @@ -799,7 +799,7 @@ def __init__( optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, ): - super(MultiInputActorCriticPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, diff --git a/stable_baselines3/common/running_mean_std.py b/stable_baselines3/common/running_mean_std.py index fb3ae8bd5..b48f9223c 100644 --- a/stable_baselines3/common/running_mean_std.py +++ b/stable_baselines3/common/running_mean_std.py @@ -3,7 +3,7 @@ import numpy as np -class RunningMeanStd(object): +class RunningMeanStd: def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] 
= ()): """ Calulates the running mean and std of a data stream diff --git a/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py b/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py index ba70a5f63..377b7f604 100644 --- a/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py +++ b/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py @@ -54,21 +54,21 @@ def __init__( centered: bool = False, ): if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) + raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) + raise ValueError(f"Invalid epsilon value: {eps}") if not 0.0 <= momentum: - raise ValueError("Invalid momentum value: {}".format(momentum)) + raise ValueError(f"Invalid momentum value: {momentum}") if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError(f"Invalid weight_decay value: {weight_decay}") if not 0.0 <= alpha: - raise ValueError("Invalid alpha value: {}".format(alpha)) + raise ValueError(f"Invalid alpha value: {alpha}") defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay) - super(RMSpropTFLike, self).__init__(params, defaults) + super().__init__(params, defaults) def __setstate__(self, state: Dict[str, Any]) -> None: - super(RMSpropTFLike, self).__setstate__(state) + super().__setstate__(state) for group in self.param_groups: group.setdefault("momentum", 0) group.setdefault("centered", False) diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index 589d12eb1..8fd22372f 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -19,7 +19,7 @@ class BaseFeaturesExtractor(nn.Module): """ def __init__(self, observation_space: gym.Space, features_dim: int = 0): - super(BaseFeaturesExtractor, self).__init__() + super().__init__() assert features_dim > 0 self._observation_space = observation_space self._features_dim = features_dim @@ -41,7 +41,7 @@ class FlattenExtractor(BaseFeaturesExtractor): """ def __init__(self, observation_space: gym.Space): - super(FlattenExtractor, self).__init__(observation_space, get_flattened_obs_dim(observation_space)) + super().__init__(observation_space, get_flattened_obs_dim(observation_space)) self.flatten = nn.Flatten() def forward(self, observations: th.Tensor) -> th.Tensor: @@ -61,7 +61,7 @@ class NatureCNN(BaseFeaturesExtractor): """ def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512): - super(NatureCNN, self).__init__(observation_space, features_dim) + super().__init__(observation_space, features_dim) # We assume CxHxW images (channels first) # Re-ordering will be done by pre-preprocessing or wrapper assert is_image_space(observation_space, check_channels=False), ( @@ -169,7 +169,7 @@ def __init__( activation_fn: Type[nn.Module], device: Union[th.device, str] = "auto", ): - super(MlpExtractor, self).__init__() + super().__init__() device = get_device(device) shared_net, policy_net, value_net = [], [], [] policy_only_layers = [] # Layer sizes of the network that only belongs to the policy network @@ -250,7 +250,7 @@ class CombinedExtractor(BaseFeaturesExtractor): def __init__(self, observation_space: gym.spaces.Dict, cnn_output_dim: int = 256): # TODO we do not know features-dim here before going over all the items, so put something there. This is dirty! 
- super(CombinedExtractor, self).__init__(observation_space, features_dim=1) + super().__init__(observation_space, features_dim=1) extractors = {} diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py index affd7756e..733b72833 100644 --- a/stable_baselines3/common/vec_env/stacked_observations.py +++ b/stable_baselines3/common/vec_env/stacked_observations.py @@ -7,7 +7,7 @@ from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first -class StackedObservations(object): +class StackedObservations: """ Frame stacking wrapper for data. diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 04f5d0c58..f723c71f7 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -217,6 +217,6 @@ def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: gym.space elif isinstance(space, gym.spaces.Tuple): assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space" obs_len = len(space.spaces) - return tuple((np.stack([o[i] for o in obs]) for i in range(obs_len))) + return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len)) else: return np.stack(obs) diff --git a/stable_baselines3/common/vec_env/util.py b/stable_baselines3/common/vec_env/util.py index 859f1ec95..ca590cb1c 100644 --- a/stable_baselines3/common/vec_env/util.py +++ b/stable_baselines3/common/vec_env/util.py @@ -37,7 +37,7 @@ def dict_to_obs(obs_space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> return obs_dict elif isinstance(obs_space, gym.spaces.Tuple): assert len(obs_dict) == len(obs_space.spaces), "size of observation does not match size of observation space" - return tuple((obs_dict[i] for i in range(len(obs_space.spaces)))) + return tuple(obs_dict[i] for i in range(len(obs_space.spaces))) else: assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space" return obs_dict[None] diff --git a/stable_baselines3/common/vec_env/vec_transpose.py b/stable_baselines3/common/vec_env/vec_transpose.py index e6f728bec..b6b0ad832 100644 --- a/stable_baselines3/common/vec_env/vec_transpose.py +++ b/stable_baselines3/common/vec_env/vec_transpose.py @@ -26,7 +26,7 @@ def __init__(self, venv: VecEnv, skip: bool = False): self.skip = skip # Do nothing if skip: - super(VecTransposeImage, self).__init__(venv) + super().__init__(venv) return if isinstance(venv.observation_space, spaces.dict.Dict): @@ -39,7 +39,7 @@ def __init__(self, venv: VecEnv, skip: bool = False): observation_space.spaces[key] = self.transpose_space(space, key) else: observation_space = self.transpose_space(venv.observation_space) - super(VecTransposeImage, self).__init__(venv, observation_space=observation_space) + super().__init__(venv, observation_space=observation_space) @staticmethod def transpose_space(observation_space: spaces.Box, key: str = "") -> spaces.Box: diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py index 14293ca3b..53d3fb619 100644 --- a/stable_baselines3/ddpg/ddpg.py +++ b/stable_baselines3/ddpg/ddpg.py @@ -78,7 +78,7 @@ def __init__( _init_setup_model: bool = True, ): - super(DDPG, self).__init__( + super().__init__( policy=policy, env=env, learning_rate=learning_rate, @@ -127,7 +127,7 @@ def learn( reset_num_timesteps: bool = True, ) -> OffPolicyAlgorithm: - return super(DDPG, 
self).learn( + return super().learn( total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index ed6073b25..fe8f39822 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -95,7 +95,7 @@ def __init__( _init_setup_model: bool = True, ): - super(DQN, self).__init__( + super().__init__( policy, env, learning_rate, @@ -138,7 +138,7 @@ def __init__( self._setup_model() def _setup_model(self) -> None: - super(DQN, self)._setup_model() + super()._setup_model() self._create_aliases() self.exploration_schedule = get_linear_fn( self.exploration_initial_eps, @@ -261,7 +261,7 @@ def learn( reset_num_timesteps: bool = True, ) -> OffPolicyAlgorithm: - return super(DQN, self).learn( + return super().learn( total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, @@ -274,7 +274,7 @@ def learn( ) def _excluded_save_params(self) -> List[str]: - return super(DQN, self)._excluded_save_params() + ["q_net", "q_net_target"] + return super()._excluded_save_params() + ["q_net", "q_net_target"] def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: state_dicts = ["policy", "policy.optimizer"] diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index ea00b5cb5..ed3497c68 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -37,7 +37,7 @@ def __init__( activation_fn: Type[nn.Module] = nn.ReLU, normalize_images: bool = True, ): - super(QNetwork, self).__init__( + super().__init__( observation_space, action_space, features_extractor=features_extractor, @@ -118,7 +118,7 @@ def __init__( optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, ): - super(DQNPolicy, self).__init__( + super().__init__( observation_space, action_space, features_extractor_class, @@ -239,7 +239,7 @@ def __init__( optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, ): - super(CnnPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, @@ -284,7 +284,7 @@ def __init__( optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, ): - super(MultiInputPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py index f61a78641..c461d19f2 100644 --- a/stable_baselines3/her/her_replay_buffer.py +++ b/stable_baselines3/her/her_replay_buffer.py @@ -82,7 +82,7 @@ def __init__( handle_timeout_termination: bool = True, ): - super(HerReplayBuffer, self).__init__(buffer_size, env.observation_space, env.action_space, device, env.num_envs) + super().__init__(buffer_size, env.observation_space, env.action_space, device, env.num_envs) # convert goal_selection_strategy into GoalSelectionStrategy if string if isinstance(goal_selection_strategy, str): diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index 346cc022c..5b8d9e2fb 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -99,7 +99,7 @@ def __init__( _init_setup_model: bool = True, ): - super(PPO, self).__init__( + super().__init__( policy, env, learning_rate=learning_rate, @@ -162,7 +162,7 @@ def __init__( self._setup_model() def _setup_model(self) -> None: - super(PPO, self)._setup_model() + super()._setup_model() # 
Initialize schedules for policy/value clipping self.clip_range = get_schedule_fn(self.clip_range) @@ -307,7 +307,7 @@ def learn( reset_num_timesteps: bool = True, ) -> "PPO": - return super(PPO, self).learn( + return super().learn( total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index cb6a61c11..6fcbea168 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -65,7 +65,7 @@ def __init__( clip_mean: float = 2.0, normalize_images: bool = True, ): - super(Actor, self).__init__( + super().__init__( observation_space, action_space, features_extractor=features_extractor, @@ -237,7 +237,7 @@ def __init__( n_critics: int = 2, share_features_extractor: bool = True, ): - super(SACPolicy, self).__init__( + super().__init__( observation_space, action_space, features_extractor_class, @@ -424,7 +424,7 @@ def __init__( n_critics: int = 2, share_features_extractor: bool = True, ): - super(CnnPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, @@ -495,7 +495,7 @@ def __init__( n_critics: int = 2, share_features_extractor: bool = True, ): - super(MultiInputPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index 3703b730b..07f88d9ab 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -110,7 +110,7 @@ def __init__( _init_setup_model: bool = True, ): - super(SAC, self).__init__( + super().__init__( policy, env, learning_rate, @@ -150,7 +150,7 @@ def __init__( self._setup_model() def _setup_model(self) -> None: - super(SAC, self)._setup_model() + super()._setup_model() self._create_aliases() # Target entropy is used when learning the entropy coefficient if self.target_entropy == "auto": @@ -248,7 +248,7 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: current_q_values = self.critic(replay_data.observations, replay_data.actions) # Compute critic loss - critic_loss = 0.5 * sum([F.mse_loss(current_q, target_q_values) for current_q in current_q_values]) + critic_loss = 0.5 * sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values) critic_losses.append(critic_loss.item()) # Optimize the critic @@ -295,7 +295,7 @@ def learn( reset_num_timesteps: bool = True, ) -> OffPolicyAlgorithm: - return super(SAC, self).learn( + return super().learn( total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, @@ -308,7 +308,7 @@ def learn( ) def _excluded_save_params(self) -> List[str]: - return super(SAC, self)._excluded_save_params() + ["actor", "critic", "critic_target"] + return super()._excluded_save_params() + ["actor", "critic", "critic_target"] def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: state_dicts = ["policy", "actor.optimizer", "critic.optimizer"] diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index ce91a0f91..f3ed53055 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -42,7 +42,7 @@ def __init__( activation_fn: Type[nn.Module] = nn.ReLU, normalize_images: bool = True, ): - super(Actor, self).__init__( + super().__init__( observation_space, action_space, features_extractor=features_extractor, @@ -121,7 +121,7 @@ def __init__( n_critics: int = 2, share_features_extractor: bool = True, ): - super(TD3Policy, self).__init__( + super().__init__( 
observation_space, action_space, features_extractor_class, @@ -283,7 +283,7 @@ def __init__( n_critics: int = 2, share_features_extractor: bool = True, ): - super(CnnPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, @@ -337,7 +337,7 @@ def __init__( n_critics: int = 2, share_features_extractor: bool = True, ): - super(MultiInputPolicy, self).__init__( + super().__init__( observation_space, action_space, lr_schedule, diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index d31720b67..34a783d29 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -95,7 +95,7 @@ def __init__( _init_setup_model: bool = True, ): - super(TD3, self).__init__( + super().__init__( policy, env, learning_rate, @@ -129,7 +129,7 @@ def __init__( self._setup_model() def _setup_model(self) -> None: - super(TD3, self)._setup_model() + super()._setup_model() self._create_aliases() def _create_aliases(self) -> None: @@ -168,7 +168,7 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None: current_q_values = self.critic(replay_data.observations, replay_data.actions) # Compute critic loss - critic_loss = sum([F.mse_loss(current_q, target_q_values) for current_q in current_q_values]) + critic_loss = sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values) critic_losses.append(critic_loss.item()) # Optimize the critics @@ -208,7 +208,7 @@ def learn( reset_num_timesteps: bool = True, ) -> OffPolicyAlgorithm: - return super(TD3, self).learn( + return super().learn( total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, @@ -221,7 +221,7 @@ def learn( ) def _excluded_save_params(self) -> List[str]: - return super(TD3, self)._excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"] + return super()._excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"] def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: state_dicts = ["policy", "actor.optimizer", "critic.optimizer"] diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index d6a9f8c61..bccb8c675 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.5.1a4 +1.5.1a5 diff --git a/tests/test_gae.py b/tests/test_gae.py index 54e03b8b1..8e461ed7a 100644 --- a/tests/test_gae.py +++ b/tests/test_gae.py @@ -10,7 +10,7 @@ class CustomEnv(gym.Env): def __init__(self, max_steps=8): - super(CustomEnv, self).__init__() + super().__init__() self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) self.max_steps = max_steps @@ -54,7 +54,7 @@ def step(self, action): class CheckGAECallback(BaseCallback): def __init__(self): - super(CheckGAECallback, self).__init__(verbose=0) + super().__init__(verbose=0) def _on_rollout_end(self): buffer = self.model.rollout_buffer @@ -99,7 +99,7 @@ class CustomPolicy(ActorCriticPolicy): """Custom Policy with a constant value function""" def __init__(self, *args, **kwargs): - super(CustomPolicy, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.constant_value = 0.0 def forward(self, obs, deterministic=False): diff --git a/tests/test_her.py b/tests/test_her.py index 0f6d75f6f..888d36a6e 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -156,7 +156,7 @@ def test_save_load(tmp_path, model_class, use_sde, online_sampling): params = deepcopy(model.policy.state_dict()) # 
Modify all parameters to be random values - random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items()) + random_params = {param_name: th.rand_like(param) for param_name, param in params.items()} # Update model parameters with the new random values model.policy.load_state_dict(random_params) diff --git a/tests/test_monitor.py b/tests/test_monitor.py index d3d041b4d..4c1d3cf59 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -14,7 +14,7 @@ def test_monitor(tmp_path): """ env = gym.make("CartPole-v1") env.seed(0) - monitor_file = os.path.join(str(tmp_path), "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4())) + monitor_file = os.path.join(str(tmp_path), f"stable_baselines-test-{uuid.uuid4()}.monitor.csv") monitor_env = Monitor(env, monitor_file) monitor_env.reset() total_steps = 1000 @@ -37,7 +37,7 @@ def test_monitor(tmp_path): assert sum(monitor_env.get_episode_rewards()) == sum(ep_rewards) _ = monitor_env.get_episode_times() - with open(monitor_file, "rt") as file_handler: + with open(monitor_file) as file_handler: first_line = file_handler.readline() assert first_line.startswith("#") metadata = json.loads(first_line[1:]) @@ -56,7 +56,7 @@ def test_monitor_load_results(tmp_path): tmp_path = str(tmp_path) env1 = gym.make("CartPole-v1") env1.seed(0) - monitor_file1 = os.path.join(tmp_path, "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4())) + monitor_file1 = os.path.join(tmp_path, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv") monitor_env1 = Monitor(env1, monitor_file1) monitor_files = get_monitor_files(tmp_path) @@ -76,7 +76,7 @@ def test_monitor_load_results(tmp_path): env2 = gym.make("CartPole-v1") env2.seed(0) - monitor_file2 = os.path.join(tmp_path, "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4())) + monitor_file2 = os.path.join(tmp_path, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv") monitor_env2 = Monitor(env2, monitor_file2) monitor_files = get_monitor_files(tmp_path) assert len(monitor_files) == 2 diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 452e6fbdc..2fdebbeaa 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -64,7 +64,7 @@ def test_save_load(tmp_path, model_class): model.set_parameters(invalid_object_params, exact_match=False) # Test that exact_match catches when something was missed. 
- missing_object_params = dict((k, v) for k, v in list(original_params.items())[:-1]) + missing_object_params = {k: v for k, v in list(original_params.items())[:-1]} with pytest.raises(ValueError): model.set_parameters(missing_object_params, exact_match=True) @@ -446,7 +446,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str, use_sde): params = deepcopy(policy.state_dict()) # Modify all parameters to be random values - random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items()) + random_params = {param_name: th.rand_like(param) for param_name, param in params.items()} # Update model parameters with the new random values policy.load_state_dict(random_params) @@ -537,7 +537,7 @@ def test_save_load_q_net(tmp_path, model_class, policy_str): params = deepcopy(q_net.state_dict()) # Modify all parameters to be random values - random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items()) + random_params = {param_name: th.rand_like(param) for param_name, param in params.items()} # Update model parameters with the new random values q_net.load_state_dict(random_params) diff --git a/tests/test_spaces.py b/tests/test_spaces.py index 54994b2b5..b75404288 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -9,7 +9,7 @@ class DummyMultiDiscreteSpace(gym.Env): def __init__(self, nvec): - super(DummyMultiDiscreteSpace, self).__init__() + super().__init__() self.observation_space = gym.spaces.MultiDiscrete(nvec) self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) @@ -22,7 +22,7 @@ def step(self, action): class DummyMultiBinary(gym.Env): def __init__(self, n): - super(DummyMultiBinary, self).__init__() + super().__init__() self.observation_space = gym.spaces.MultiBinary(n) self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) diff --git a/tests/test_train_eval_mode.py b/tests/test_train_eval_mode.py index 1ea2efe67..4f023e96a 100644 --- a/tests/test_train_eval_mode.py +++ b/tests/test_train_eval_mode.py @@ -28,7 +28,7 @@ class FlattenBatchNormDropoutExtractor(BaseFeaturesExtractor): """ def __init__(self, observation_space: gym.Space): - super(FlattenBatchNormDropoutExtractor, self).__init__( + super().__init__( observation_space, get_flattened_obs_dim(observation_space), ) diff --git a/tests/test_utils.py b/tests/test_utils.py index b07bbe931..67f2ad1a3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -180,7 +180,7 @@ class AlwaysDoneWrapper(gym.Wrapper): # Pretends that environment only has single step for each # episode. 
def __init__(self, env): - super(AlwaysDoneWrapper, self).__init__(env) + super().__init__(env) self.last_obs = None self.needs_reset = True diff --git a/tests/test_vec_check_nan.py b/tests/test_vec_check_nan.py index 265da2ed9..962355782 100644 --- a/tests/test_vec_check_nan.py +++ b/tests/test_vec_check_nan.py @@ -12,7 +12,7 @@ class NanAndInfEnv(gym.Env): metadata = {"render.modes": ["human"]} def __init__(self): - super(NanAndInfEnv, self).__init__() + super().__init__() self.action_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64) self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64) diff --git a/tests/test_vec_monitor.py b/tests/test_vec_monitor.py index 974202b31..5ccc33e12 100644 --- a/tests/test_vec_monitor.py +++ b/tests/test_vec_monitor.py @@ -36,7 +36,7 @@ def test_vec_monitor(tmp_path): monitor_env.close() - with open(monitor_file, "rt") as file_handler: + with open(monitor_file) as file_handler: first_line = file_handler.readline() assert first_line.startswith("#") metadata = json.loads(first_line[1:]) @@ -66,7 +66,7 @@ def test_vec_monitor_info_keywords(tmp_path): monitor_env.close() - with open(monitor_file, "rt") as f: + with open(monitor_file) as f: reader = csv.reader(f) for i, line in enumerate(reader): if i == 0 or i == 1: diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index 07ad77f22..86a0d841d 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -47,7 +47,7 @@ class DummyDictEnv(gym.GoalEnv): """ def __init__(self): - super(DummyDictEnv, self).__init__() + super().__init__() self.observation_space = spaces.Dict( { "observation": spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32),
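
Note (illustrative, not part of the diff): below is a minimal sketch of the pyupgrade-style modernizations this changeset applies throughout the codebase: zero-argument ``super()``, implicit ``object`` base classes, f-strings instead of ``str.format``, dict comprehensions instead of ``dict()`` over a generator, bare generator arguments to ``sum()``/``tuple()``, and ``open()`` without an explicit ``"r"``/``"rt"`` mode. The class and variable names here are hypothetical examples, not code from the repository.

# Sketch of the before/after syntax patterns seen in this diff.
# LegacyBuffer/ModernBuffer are hypothetical illustration classes.


class LegacyBuffer(object):  # old style: explicit `object` base class
    def __init__(self, size):
        super(LegacyBuffer, self).__init__()  # redundant super() arguments
        self.size = size

    def describe(self):
        return "buffer of size {}".format(self.size)  # str.format


class ModernBuffer:  # Python 3 style: implicit `object` base class
    def __init__(self, size):
        super().__init__()  # zero-argument super()
        self.size = size

    def describe(self):
        return f"buffer of size {self.size}"  # f-string


# dict() over a generator becomes a dict comprehension,
# and sum() takes a generator expression directly (no inner list):
params = {"weight": 1.0, "bias": 2.0}
random_params = {name: value * 2 for name, value in params.items()}
total = sum(value for value in params.values())

# open() defaults to text read mode, so the "r"/"rt" argument is dropped:
# with open("version.txt") as file_handler:
#     version = file_handler.read().strip()

print(ModernBuffer(8).describe(), random_params, total)

All of these forms only require Python 3, which matches the changelog entry added above ("Upgraded to Python 3.7+ syntax using ``pyupgrade``"); the behavior of the rewritten code is unchanged.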