# ray/rllib/algorithms/mock.py

import os
import pickle

import numpy as np

from ray.tune import result as tune_result
from ray.rllib.algorithms.algorithm import Algorithm, with_common_config
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import AlgorithmConfigDict


class _MockTrainer(Algorithm):
    """Mock trainer for use in tests."""

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfigDict:
        return with_common_config(
            {
                "mock_error": False,
                "persistent_error": False,
                "test_variable": 1,
                "num_workers": 0,
                "user_checkpoint_freq": 0,
                "framework": "tf",
            }
        )

    @classmethod
    def default_resource_request(cls, config):
        return None

    @override(Algorithm)
    def setup(self, config):
        # Set up our config: Merge the user-supplied config (which could
        # be a partial config dict) with the class' default config.
        self.config = self.merge_trainer_configs(
            self.get_default_config(), config, self._allow_unknown_configs
        )
        self.config["env"] = self._env_id
        self.validate_config(self.config)
        self.callbacks = self.config["callbacks"]()
        # Add needed properties.
        self.info = None
        self.restored = False

    @override(Algorithm)
    def step(self):
        # Optionally fail on the second iteration (or on every iteration,
        # if `persistent_error` is set) to test error handling and recovery.
        if (
            self.config["mock_error"]
            and self.iteration == 1
            and (self.config["persistent_error"] or not self.restored)
        ):
            raise Exception("mock error")
        result = dict(
            episode_reward_mean=10, episode_len_mean=10, timesteps_this_iter=10, info={}
        )
        if self.config["user_checkpoint_freq"] > 0 and self.iteration > 0:
            if self.iteration % self.config["user_checkpoint_freq"] == 0:
                # Signal to Tune that a checkpoint should be created.
                result.update({tune_result.SHOULD_CHECKPOINT: True})
        return result

    @override(Algorithm)
    def save_checkpoint(self, checkpoint_dir):
        path = os.path.join(checkpoint_dir, "mock_agent.pkl")
        with open(path, "wb") as f:
            pickle.dump(self.info, f)
        return path

    @override(Algorithm)
    def load_checkpoint(self, checkpoint_path):
        with open(checkpoint_path, "rb") as f:
            info = pickle.load(f)
        self.info = info
        self.restored = True

    @staticmethod
    @override(Algorithm)
    def _get_env_id_and_creator(env_specifier, config):
        # No env to register.
        return None, None

    def set_info(self, info):
        self.info = info
        return info

    def get_info(self, sess=None):
        return self.info
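

# Illustrative usage sketch (not part of the original module): how a test
# might drive the mock. The helper name `_example_mock_trainer_usage` and
# the exact config values are assumptions; `Algorithm.train()` delegates to
# the `step()` defined above.
def _example_mock_trainer_usage(checkpoint_dir="/tmp"):
    trainer = _MockTrainer(config={"user_checkpoint_freq": 2})
    result = trainer.train()  # Calls _MockTrainer.step() once.
    assert result["episode_reward_mean"] == 10
    # Round-trip the (trivial) state through a checkpoint.
    trainer.set_info({"foo": "bar"})
    path = trainer.save_checkpoint(checkpoint_dir)
    trainer.load_checkpoint(path)
    assert trainer.restored and trainer.get_info() == {"foo": "bar"}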


class _SigmoidFakeData(_MockTrainer):
    """Trainer that returns sigmoid learning curves.

    This can be helpful for evaluating early stopping algorithms."""

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfigDict:
        return with_common_config(
            {
                "width": 100,
                "height": 100,
                "offset": 0,
                "iter_time": 10,
                "iter_timesteps": 1,
                "num_workers": 0,
            }
        )

    def step(self):
        i = max(0, self.iteration - self.config["offset"])
        v = np.tanh(float(i) / self.config["width"])
        v *= self.config["height"]
        return dict(
            episode_reward_mean=v,
            episode_len_mean=v,
            timesteps_this_iter=self.config["iter_timesteps"],
            time_this_iter_s=self.config["iter_time"],
            info={},
        )
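

# A quick sanity sketch (hypothetical helper, not in the original file):
# with the default width=100 and height=100, the fake reward curve is
# 100 * tanh(i / 100), i.e. ~0.0 at iteration 0, ~76.2 at iteration 100,
# and ~99.5 at iteration 300.
def _example_sigmoid_values(width=100.0, height=100.0):
    return [height * np.tanh(i / width) for i in (0, 100, 300)]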


class _ParameterTuningTrainer(_MockTrainer):
    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfigDict:
        return with_common_config(
            {
                "reward_amt": 10,
                "dummy_param": 10,
                "dummy_param2": 15,
                "iter_time": 10,
                "iter_timesteps": 1,
                "num_workers": 0,
            }
        )

    def step(self):
        return dict(
            episode_reward_mean=self.config["reward_amt"] * self.iteration,
            episode_len_mean=self.config["reward_amt"],
            timesteps_this_iter=self.config["iter_timesteps"],
            time_this_iter_s=self.config["iter_time"],
            info={},
        )


def _algorithm_import_failed(trace):
    """Returns a dummy Algorithm class for use when an import (e.g. PyTorch)
    has failed."""

    class _TrainerImportFailed(Algorithm):
        _name = "TrainerImportFailed"

        def setup(self, config):
            raise ImportError(trace)

    return _TrainerImportFailed
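

# Illustrative sketch of the intended call pattern (the import guard below
# is an assumption for demonstration, not RLlib's actual registry code):
def _example_import_guard():
    """Return a usable Algorithm class, or a failing stub if torch is missing."""
    try:
        import torch  # noqa: F401
    except ImportError:
        import traceback

        # The stub raises the captured ImportError traceback only when a
        # user actually tries to set up the algorithm.
        return _algorithm_import_failed(traceback.format_exc())
    return _MockTrainer  # Stand-in for the real, torch-dependent class.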