ray/rllib/examples/connectors/adapt_connector_policy.py

"""This example script shows how to load a connector enabled policy,
and adapt/use it with a different version of the environment.
"""
import argparse
import gym
import numpy as np
from typing import Dict

from ray.rllib.utils.policy import (
    load_policies_from_checkpoint,
    local_policy_inference,
)
from ray.rllib.connectors.connector import ConnectorContext
from ray.rllib.connectors.action.lambdas import register_lambda_action_connector
from ray.rllib.connectors.agent.lambdas import register_lambda_agent_connector
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import (
    PolicyOutputType,
    StateBatches,
    TensorStructType,
)

parser = argparse.ArgumentParser()
parser.add_argument(
    "--checkpoint_file",
    help="Path to an RLlib checkpoint file.",
)
parser.add_argument(
    "--policy_id",
    default="default_policy",
    help="ID of policy to load.",
)
args = parser.parse_args()

assert args.checkpoint_file, "Must specify flag --checkpoint_file."
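
# The checkpoint is expected to come from a connector-enabled training run.
# A minimal sketch of how such a checkpoint could be produced (assuming a
# Ray 2.x PPO setup; the config below is illustrative and not part of this
# script):
#
#   from ray.rllib.algorithms.ppo import PPOConfig
#
#   algo = (
#       PPOConfig()
#       .environment("CartPole-v0")
#       .rollouts(enable_connectors=True)
#       .build()
#   )
#   algo.train()
#   print(algo.save())  # Pass the resulting path via --checkpoint_file.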


class MyCartPole(gym.Env):
    """A mock CartPole environment.

    Gives 2 additional observation states and takes 2 discrete actions.
    """

    def __init__(self):
        self._env = gym.make("CartPole-v0")
        self.observation_space = gym.spaces.Box(low=-10, high=10, shape=(6,))
        self.action_space = gym.spaces.MultiDiscrete(nvec=[2, 2])

    def step(self, actions):
        # Take the first action.
        action = actions[0]
        obs, reward, done, info = self._env.step(action)
        # Fake additional data points to the obs.
        obs = np.hstack((obs, [8.0, 6.0]))
        return obs, reward, done, info

    def reset(self):
        return np.hstack((self._env.reset(), [8.0, 6.0]))
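
# Quick illustration of the adapted spaces (hypothetical interactive session,
# not executed by this script):
#
#   >>> env = MyCartPole()
#   >>> env.reset().shape
#   (6,)
#   >>> env.step(env.action_space.sample())[0].shape
#   (6,)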


# Custom agent connector.
def v2_to_v1_obs(data: Dict[str, TensorStructType]) -> Dict[str, TensorStructType]:
    data[SampleBatch.NEXT_OBS] = data[SampleBatch.NEXT_OBS][:-2]
    return data


# Agent connector that adapts observations from the new CartPole env
# into old format.
V2ToV1ObsAgentConnector = register_lambda_agent_connector(
    "V2ToV1ObsAgentConnector", v2_to_v1_obs
)
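
# For illustration (values are made up): an incoming observation dict like
#
#   {SampleBatch.NEXT_OBS: np.array([0.1, 0.2, 0.3, 0.4, 8.0, 6.0])}
#
# leaves this connector with the 2 fake dims stripped,
#
#   {SampleBatch.NEXT_OBS: np.array([0.1, 0.2, 0.3, 0.4])},
#
# matching the 4-dim observation space the policy was trained on.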


# Custom action connector.
def v1_to_v2_action(
    actions: TensorStructType, states: StateBatches, fetches: Dict
) -> PolicyOutputType:
    return np.hstack((actions, [0])), states, fetches


# Action connector that adapts action outputs from the old policy
# into new actions for the mock environment.
V1ToV2ActionConnector = register_lambda_action_connector(
    "V1ToV2ActionConnector", v1_to_v2_action
)
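
# For illustration: a single Discrete(2) action from the old policy, e.g.
# actions == 1, leaves this connector as np.array([1, 0]), which fits
# MyCartPole's MultiDiscrete([2, 2]) action space (the appended second
# action is a placeholder that MyCartPole ignores).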


def run():
    # Restore the policy from the checkpoint.
    policies = load_policies_from_checkpoint(args.checkpoint_file, [args.policy_id])
    policy = policies[args.policy_id]

    # Adapt the policy trained for standard CartPole to the new env:
    # prepend the obs-adapting agent connector and append the
    # action-adapting action connector to the policy's connector pipelines.
    ctx: ConnectorContext = ConnectorContext.from_policy(policy)
    policy.agent_connectors.prepend(V2ToV1ObsAgentConnector(ctx))
    policy.action_connectors.append(V1ToV2ActionConnector(ctx))

    # Run an episode of the adapted CartPole env.
    env = MyCartPole()
    obs = env.reset()
    done = False
    step = 0
    while not done:
        step += 1

        # Use local_policy_inference() to easily run the policy with observations.
        policy_outputs = local_policy_inference(policy, "env_1", "agent_1", obs)
        assert len(policy_outputs) == 1
        actions, _, _ = policy_outputs[0]
        print(f"step {step}", obs, actions)

        obs, _, done, _ = env.step(actions)


if __name__ == "__main__":
    run()