# Code in this file is adapted from:
# https://github.com/openai/evolution-strategies-starter.

import gym
import numpy as np
import tree  # pip install dm_tree

import ray
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, unbatch
from ray.rllib.utils.torch_utils import convert_to_torch_tensor

torch, _ = try_import_torch()


def before_init(policy, observation_space, action_space, config):
    policy.action_noise_std = config["action_noise_std"]
    policy.action_space_struct = get_base_struct_from_space(action_space)
    policy.preprocessor = ModelCatalog.get_preprocessor_for_space(observation_space)
    policy.observation_filter = get_filter(
        config["observation_filter"], policy.preprocessor.shape
    )
    policy.single_threaded = config.get("single_threaded", False)

    def _set_flat_weights(policy, theta):
        # Write a flat 1D weight vector back into the model's state dict,
        # using the per-parameter shapes cached in `after_init`.
        pos = 0
        theta_dict = policy.model.state_dict()
        new_theta_dict = {}

        for k in sorted(theta_dict.keys()):
            shape = policy.param_shapes[k]
            num_params = int(np.prod(shape))
            new_theta_dict[k] = torch.from_numpy(
                np.reshape(theta[pos : pos + num_params], shape)
            )
            pos += num_params

        policy.model.load_state_dict(new_theta_dict)

    def _get_flat_weights(policy):
        # Get the parameter tensors.
        theta_dict = policy.model.state_dict()
        # Flatten it into a single np.ndarray.
        theta_list = []
        for k in sorted(theta_dict.keys()):
            theta_list.append(torch.reshape(theta_dict[k], (-1,)))
        cat = torch.cat(theta_list, dim=0)
        return cat.cpu().numpy()

    type(policy).set_flat_weights = _set_flat_weights
    type(policy).get_flat_weights = _get_flat_weights

    def _compute_actions(policy, obs_batch, add_noise=False, update=True, **kwargs):
        # Batch is given as list -> Try converting to numpy first.
        if isinstance(obs_batch, list) and len(obs_batch) == 1:
            obs_batch = obs_batch[0]
        observation = policy.preprocessor.transform(obs_batch)
        observation = policy.observation_filter(observation[None], update=update)

        observation = convert_to_torch_tensor(observation, policy.device)
        dist_inputs, _ = policy.model({SampleBatch.CUR_OBS: observation}, [], None)
        dist = policy.dist_class(dist_inputs, policy.model)
        action = dist.sample()

        def _add_noise(single_action, single_action_space):
            single_action = single_action.detach().cpu().numpy()
            # Only add Gaussian exploration noise to float Box actions.
            if (
                add_noise
                and isinstance(single_action_space, gym.spaces.Box)
                and single_action_space.dtype.name.startswith("float")
            ):
                single_action += (
                    np.random.randn(*single_action.shape) * policy.action_noise_std
                )
            return single_action

        action = tree.map_structure(_add_noise, action, policy.action_space_struct)
        action = unbatch(action)
        return action, [], {}

    def _compute_single_action(
        policy, observation, add_noise=False, update=True, **kwargs
    ):
        action, state_outs, extra_fetches = policy.compute_actions(
            [observation], add_noise=add_noise, update=update, **kwargs
        )
        return action[0], state_outs, extra_fetches

    type(policy).compute_actions = _compute_actions
    type(policy).compute_single_action = _compute_single_action


def after_init(policy, observation_space, action_space, config):
    # Cache the per-parameter shapes and the total parameter count; these are
    # used by `set_flat_weights` / `get_flat_weights` to (un)flatten weights.
    state_dict = policy.model.state_dict()
    policy.param_shapes = {
        k: tuple(state_dict[k].size()) for k in sorted(state_dict.keys())
    }
    policy.num_params = sum(np.prod(s) for s in policy.param_shapes.values())


def make_model_and_action_dist(policy, observation_space, action_space, config):
    # Policy network.
    dist_class, dist_dim = ModelCatalog.get_action_dist(
        action_space,
        config["model"],  # model_options
        dist_type="deterministic",
        framework="torch",
    )
    model = ModelCatalog.get_model_v2(
        policy.preprocessor.observation_space,
        action_space,
        num_outputs=dist_dim,
        model_config=config["model"],
        framework="torch",
    )
    # Make all model params not require any gradients.
    for p in model.parameters():
        p.requires_grad = False
    return model, dist_class


ESTorchPolicy = build_policy_class(
    name="ESTorchPolicy",
    framework="torch",
    loss_fn=None,
    get_default_config=lambda: ray.rllib.algorithms.es.es.DEFAULT_CONFIG,
    before_init=before_init,
    after_init=after_init,
    make_model_and_action_dist=make_model_and_action_dist,
)
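

# Minimal usage sketch, kept as a comment: the ES algorithm reads the flat
# weight vector, applies a Gaussian perturbation, writes it back, and rolls
# the perturbed policy out via `compute_single_action`. The observation and
# action spaces, the 0.02 noise scale, and the direct use of DEFAULT_CONFIG
# below are hypothetical placeholders, not values taken from this file.
#
#   import gym
#   import numpy as np
#   from ray.rllib.algorithms.es.es import DEFAULT_CONFIG
#
#   policy = ESTorchPolicy(
#       gym.spaces.Box(-1.0, 1.0, (4,), dtype=np.float32),  # assumed obs space
#       gym.spaces.Box(-1.0, 1.0, (2,), dtype=np.float32),  # assumed act space
#       DEFAULT_CONFIG,
#   )
#   theta = policy.get_flat_weights()
#   perturbed = theta + 0.02 * np.random.randn(len(theta)).astype(theta.dtype)
#   policy.set_flat_weights(perturbed)
#   action, _, _ = policy.compute_single_action(
#       np.zeros(4, dtype=np.float32), add_noise=True
#   )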