from gym.spaces import Box
import numpy as np

from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
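
# Example/test policies exercising RLlib's trajectory view API: each policy
# below declares, via ViewRequirement objects, which extra views (prev.
# actions, recurrent state slots, "t", "env_id", etc.) the sample collector
# must build for it.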


class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
    """A Policy that always knows the current agent index, EpisodeID, and
    EnvID and returns these in its actions."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.state_space = Box(-1.0, 1.0, (1,))

        # Fake model stub: only carries the attributes (`time_major`,
        # `view_requirements`) that the sample collector reads from a real
        # ModelV2.
        class _fake_model:
            pass

        self.model = _fake_model()
        self.model.time_major = True
        self.model.view_requirements = {
            SampleBatch.AGENT_INDEX: ViewRequirement(),
            SampleBatch.EPS_ID: ViewRequirement(),
            "env_id": ViewRequirement(),
            "t": ViewRequirement(),
            SampleBatch.OBS: ViewRequirement(),
            SampleBatch.PREV_ACTIONS: ViewRequirement(
                SampleBatch.ACTIONS, space=self.action_space, shift=-1
            ),
            SampleBatch.PREV_REWARDS: ViewRequirement(SampleBatch.REWARDS, shift=-1),
        }
        # Two LSTM-style state slots: each state_in_[i] is the previous
        # timestep's state_out_[i].
        for i in range(2):
            self.model.view_requirements["state_in_{}".format(i)] = ViewRequirement(
                "state_out_{}".format(i), shift=-1, space=self.state_space
            )
            self.model.view_requirements["state_out_{}".format(i)] = ViewRequirement(
                space=self.state_space
            )

        # Policy-level view requirements: defaults plus everything the (fake)
        # model requests.
        self.view_requirements = dict(
            **{
                SampleBatch.NEXT_OBS: ViewRequirement(SampleBatch.OBS, shift=1),
                SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
                SampleBatch.REWARDS: ViewRequirement(),
                SampleBatch.DONES: ViewRequirement(),
                SampleBatch.UNROLL_ID: ViewRequirement(),
            },
            **self.model.view_requirements
        )

    @override(Policy)
    def is_recurrent(self):
        return True

    @override(Policy)
    def compute_actions_from_input_dict(
        self, input_dict, explore=None, timestep=None, **kwargs
    ):
        ts = input_dict["t"]
        print(ts)
        # Always return [agent-index, episodeID, envID] as actions.
        actions = np.array(
            [
                [
                    input_dict[SampleBatch.AGENT_INDEX][i],
                    input_dict[SampleBatch.EPS_ID][i],
                    input_dict["env_id"][i],
                ]
                for i, _ in enumerate(input_dict["obs"])
            ]
        )
        # Two state outputs, each carrying the current timestep.
        states = [
            np.array([[ts[i]] for i in range(len(input_dict["obs"]))]) for _ in range(2)
        ]
        return actions, states, {}

    @override(Policy)
    def postprocess_trajectory(
        self, sample_batch, other_agent_batches=None, episode=None
    ):
        # Add a custom postprocessed column (2x the observations).
        sample_batch["2xobs"] = sample_batch["obs"] * 2.0
        return sample_batch
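
# A minimal usage sketch (illustrative only; `obs_space`, `act_space`, and the
# policy ID "env_aware" are hypothetical): the policy can be plugged into a
# multi-agent config like any other policy class, e.g.
#
#     config["multiagent"] = {
#         "policies": {
#             "env_aware": (EpisodeEnvAwareLSTMPolicy, obs_space, act_space, {}),
#         },
#         "policy_mapping_fn": lambda agent_id: "env_aware",
#     }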


class EpisodeEnvAwareAttentionPolicy(RandomPolicy):
    """A Policy that always knows the current agent index, EpisodeID, and
    EnvID and returns these in its actions."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.state_space = Box(-1.0, 1.0, (1,))
        self.config["model"] = {"max_seq_len": 50}

        # Fake model stub: only carries the `view_requirements` attribute that
        # the sample collector reads from a real ModelV2.
        class _fake_model:
            pass

        self.model = _fake_model()
        self.model.view_requirements = {
            SampleBatch.AGENT_INDEX: ViewRequirement(),
            SampleBatch.EPS_ID: ViewRequirement(),
            "env_id": ViewRequirement(),
            "t": ViewRequirement(),
            SampleBatch.OBS: ViewRequirement(),
            "state_in_0": ViewRequirement(
                "state_out_0",
                # Provide state outs -50 to -1 as "state-in".
                shift="-50:-1",
                # Repeat the incoming state every n time steps (usually max seq
                # len).
                batch_repeat_value=self.config["model"]["max_seq_len"],
                space=self.state_space,
            ),
            "state_out_0": ViewRequirement(
                space=self.state_space, used_for_compute_actions=False
            ),
        }
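
        # Note on "state_in_0" above: shift="-50:-1" requests, at each
        # timestep t, the window of the 50 most recent "state_out_0" values
        # (relative positions t-50 through t-1; at episode starts, missing
        # entries are padded), while batch_repeat_value causes this window to
        # be materialized only once every max_seq_len (=50) timesteps in the
        # collected batch.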

        # Merge default policy views with the (fake) model's requirements.
        self.view_requirements = dict(
            super()._get_default_view_requirements(), **self.model.view_requirements
        )

    @override(Policy)
    def is_recurrent(self):
        return True

    @override(Policy)
    def compute_actions_from_input_dict(
        self, input_dict, explore=None, timestep=None, **kwargs
    ):
        ts = input_dict["t"]
        print(ts)
        # Always return [agent-index, episodeID, envID] as actions.
        actions = np.array(
            [
                [
                    input_dict[SampleBatch.AGENT_INDEX][i],
                    input_dict[SampleBatch.EPS_ID][i],
                    input_dict["env_id"][i],
                ]
                for i, _ in enumerate(input_dict["obs"])
            ]
        )
        # Single (attention-style) state output carrying the current timestep.
        states = [np.array([[ts[i]] for i in range(len(input_dict["obs"]))])]
        self.global_timestep += 1
        return actions, states, {}

    @override(Policy)
    def postprocess_trajectory(
        self, sample_batch, other_agent_batches=None, episode=None
    ):
        # Add a custom postprocessed column (3x the observations).
        sample_batch["3xobs"] = sample_batch["obs"] * 3.0
        return sample_batch
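
# Illustrative check (the `worker` object here is hypothetical): since actions
# encode [agent_index, eps_id, env_id], a rollout's SampleBatch can be used to
# verify the sampler's episode/env bookkeeping, e.g.
#
#     batch = worker.sample()
#     assert (batch["actions"][:, 1] == batch["eps_id"]).all()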