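"""Example of using an attention net (a GRU-gated Transformer-XL) as a
custom RLlib model. Trains PPO via Tune on small memory environments
(RepeatAfterMeEnv, RepeatInitialEnv, LookAndPush) that require the policy
to remember past observations."""
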
import argparse

import gym
import numpy as np

import ray
from ray import tune
from ray.tune import registry
from ray.rllib import models
from ray.rllib.utils import try_import_tf
from ray.rllib.models.tf import attention
from ray.rllib.models.tf import recurrent_tf_modelv2
from ray.rllib.examples.custom_keras_rnn_model import RepeatAfterMeEnv
from ray.rllib.examples.custom_keras_rnn_model import RepeatInitialEnv

tf = try_import_tf()

parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="PPO")
parser.add_argument("--env", type=str, default="RepeatAfterMeEnv")
parser.add_argument("--stop", type=int, default=90)
parser.add_argument("--num-cpus", type=int, default=0)

class OneHot(gym.Wrapper):
    """Wraps a Discrete-observation env and emits one-hot encoded obs."""

    def __init__(self, env):
        super(OneHot, self).__init__(env)
        self.observation_space = gym.spaces.Box(0., 1.,
                                                (env.observation_space.n, ))

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        return self._encode_obs(obs)

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return self._encode_obs(obs), reward, done, info

    def _encode_obs(self, obs):
        # One-hot encoding: all zeros except a 1.0 at the observed index.
        new_obs = np.zeros(self.env.observation_space.n)
        new_obs[obs] = 1.0
        return new_obs
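
# Quick sanity check of the wrapper (illustrative comment only):
#   env = OneHot(LookAndPush())
#   env.reset()  # initial state 2 -> array([0., 0., 1., 0., 0.])
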
class LookAndPush(gym.Env):
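    """A small memory task over states 0..4 with a hidden binary "case".

    Action 1 ("push") advances the state by one, except that pushing in
    state 0 jumps back to state 2. Action 0 ("look") reveals the hidden
    case through the observation, but only in state 2. Acting in state 4
    ends the episode: +10 if the agent pushes and the case is 1, else -10.
    Every other step yields reward -1.
    """
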
    def __init__(self):
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Discrete(5)
        self._state = None
        self._case = None

    def reset(self):
        self._state = 2
        self._case = np.random.choice(2)
        return self._state

    def step(self, action):
        assert self.action_space.contains(action)

        if self._state == 4:
            # Terminal decision: pushing pays off only if the case is 1.
            if action and self._case:
                return self._state, 10., True, {}
            else:
                return self._state, -10., True, {}
        else:
            if action:
                # "Push": advance by one; state 0 jumps back to state 2.
                if self._state == 0:
                    self._state = 2
                else:
                    self._state += 1
            elif self._state == 2:
                # "Look": reveal the hidden case through the observation.
                self._state = self._case

        return self._state, -1., False, {}

class GRUTrXL(recurrent_tf_modelv2.RecurrentTFModelV2):
    """TrXL net that keeps a sliding window of past observations as state."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super(GRUTrXL, self).__init__(obs_space, action_space, num_outputs,
                                      model_config, name)
        self.max_seq_len = model_config["max_seq_len"]
        self.obs_dim = obs_space.shape[0]
        input_layer = tf.keras.layers.Input(
            shape=(self.max_seq_len, obs_space.shape[0]),
            name="inputs",
        )

        trxl_out = attention.make_GRU_TrXL(
            seq_length=model_config["max_seq_len"],
            num_layers=model_config["custom_options"]["num_layers"],
            attn_dim=model_config["custom_options"]["attn_dim"],
            num_heads=model_config["custom_options"]["num_heads"],
            head_dim=model_config["custom_options"]["head_dim"],
            ff_hidden_dim=model_config["custom_options"]["ff_hidden_dim"],
        )(input_layer)

        # Postprocess the TrXL output with separate linear heads for the
        # action logits and the value function.
        logits = tf.keras.layers.Dense(
            self.num_outputs,
            activation=tf.keras.activations.linear,
            name="logits")(trxl_out)
        values_out = tf.keras.layers.Dense(
            1, activation=None, name="values")(trxl_out)

        self.trxl_model = tf.keras.Model(
            inputs=[input_layer],
            outputs=[logits, values_out],
        )
        self.register_variables(self.trxl_model.variables)
        self.trxl_model.summary()

    def forward_rnn(self, inputs, state, seq_lens):
        state = state[0]

        # We treat state as the history of recent observations: append the
        # current inputs at the end and keep only the most recent
        # `max_seq_len` timesteps. This lets timestep-wise inference and
        # full-sequence training share the same logic.
        state = tf.concat((state, inputs), axis=1)[:, -self.max_seq_len:]
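        # E.g., with max_seq_len=10 at single-timestep inference: state is
        # (B, 10, obs_dim), inputs is (B, 1, obs_dim); the concat yields
        # (B, 11, obs_dim) and the slice keeps the trailing 10 timesteps.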
        logits, self._value_out = self.trxl_model(state)

        # Return only the outputs for the actual input timesteps.
        in_T = tf.shape(inputs)[1]
        logits = logits[:, -in_T:]
        self._value_out = self._value_out[:, -in_T:]

        return logits, [state]

    def get_initial_state(self):
        # Start with an all-zeros observation history.
        return [np.zeros((self.max_seq_len, self.obs_dim), np.float32)]

    def value_function(self):
        return tf.reshape(self._value_out, [-1])

if __name__ == "__main__":
    args = parser.parse_args()
    ray.init(num_cpus=args.num_cpus or None)
    models.ModelCatalog.register_custom_model("trxl", GRUTrXL)
    registry.register_env("RepeatAfterMeEnv", lambda c: RepeatAfterMeEnv(c))
    registry.register_env("RepeatInitialEnv", lambda _: RepeatInitialEnv())
    registry.register_env("LookAndPush", lambda _: OneHot(LookAndPush()))
    tune.run(
        args.run,
        stop={"episode_reward_mean": args.stop},
        config={
            "env": args.env,
            "env_config": {
                # Only used by RepeatAfterMeEnv.
                "repeat_delay": 2,
            },
            "gamma": 0.99,
            "num_workers": 0,
            "num_envs_per_worker": 20,
            "entropy_coeff": 0.001,
            "num_sgd_iter": 5,
            "vf_loss_coeff": 1e-5,
            "model": {
                "custom_model": "trxl",
                "max_seq_len": 10,
                "custom_options": {
                    "num_layers": 1,
                    "attn_dim": 10,
                    "num_heads": 1,
                    "head_dim": 10,
                    "ff_hidden_dim": 20,
                },
            },
        })
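
# Example invocation (the filename attention_net.py is an assumption):
#   python attention_net.py --run PPO --env RepeatAfterMeEnv --stop 90
# Tune stops training once episode_reward_mean reaches the --stop value.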