[RLlib] Fix connector examples (#27583)

This commit is contained in:
Jun Gong 2022-08-07 17:48:09 -07:00 committed by GitHub
parent 89b2f616fd
commit 5f07987ab1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 51 additions and 24 deletions

View file

@ -3343,36 +3343,36 @@ py_test(
py_test( py_test(
name = "examples/connectors/run_connector_policy", name = "examples/connectors/run_connector_policy",
main = "examples/connectors/run_connector_policy.py", main = "examples/connectors/run_connector_policy.py",
tags = ["team:rllib", "exclusive", "examples", ], tags = ["team:rllib", "exclusive", "examples", "examples_C", "examples_C_AtoT"],
size = "small", size = "small",
srcs = ["examples/connectors/run_connector_policy.py"], srcs = ["examples/connectors/run_connector_policy.py"],
data = [ data = glob([
"tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6", "tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6.new",
], ]),
args = ["--checkpoint_file=tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6"] args = ["--checkpoint_file=tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6-08062022"]
) )
py_test( py_test(
name = "examples/connectors/adapt_connector_policy", name = "examples/connectors/adapt_connector_policy",
main = "examples/connectors/adapt_connector_policy.py", main = "examples/connectors/adapt_connector_policy.py",
tags = ["team:rllib", "exclusive", "examples", ], tags = ["team:rllib", "exclusive", "examples", "examples_C", "examples_C_AtoT"],
size = "small", size = "small",
srcs = ["examples/connectors/adapt_connector_policy.py"], srcs = ["examples/connectors/adapt_connector_policy.py"],
data = [ data = glob([
"tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6", "tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6.old",
], ]),
args = ["--checkpoint_file=tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6"] args = ["--checkpoint_file=tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6-07092022"]
) )
py_test( py_test(
name = "examples/connectors/self_play_with_policy_checkpoint", name = "examples/connectors/self_play_with_policy_checkpoint",
main = "examples/connectors/self_play_with_policy_checkpoint.py", main = "examples/connectors/self_play_with_policy_checkpoint.py",
tags = ["team:rllib", "exclusive", "examples", ], tags = ["team:rllib", "exclusive", "examples", "examples_C", "examples_C_AtoT"],
size = "small", size = "small",
srcs = ["examples/connectors/self_play_with_policy_checkpoint.py"], srcs = ["examples/connectors/self_play_with_policy_checkpoint.py"],
data = [ data = glob([
"tests/data/checkpoints/PPO_open_spiel_checkpoint-6", "tests/data/checkpoints/PPO_open_spiel_checkpoint-6",
], ]),
args = [ args = [
"--checkpoint_file=tests/data/checkpoints/PPO_open_spiel_checkpoint-6", "--checkpoint_file=tests/data/checkpoints/PPO_open_spiel_checkpoint-6",
"--train_iteration=1" # Smoke test. "--train_iteration=1" # Smoke test.

View file

@ -5,6 +5,7 @@ and adapt/use it with a different version of the environment.
import argparse import argparse
import gym import gym
import numpy as np import numpy as np
from pathlib import Path
from typing import Dict from typing import Dict
from ray.rllib.utils.policy import ( from ray.rllib.utils.policy import (
@ -23,9 +24,11 @@ from ray.rllib.utils.typing import (
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
# A policy checkpoint that works with this example script can be found at:
# rllib/tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6-07092022
parser.add_argument( parser.add_argument(
"--checkpoint_file", "--checkpoint_file",
help="Path to an RLlib checkpoint file.", help="Path to an RLlib checkpoint file, relative to //ray/rllib/ folder.",
) )
parser.add_argument( parser.add_argument(
"--policy_id", "--policy_id",
@ -87,13 +90,21 @@ V1ToV2ActionConnector = register_lambda_action_connector(
) )
def run(): def run(checkpoint_path):
# Restore policy. # Restore policy.
policies = load_policies_from_checkpoint(args.checkpoint_file, [args.policy_id]) policies = load_policies_from_checkpoint(checkpoint_path, [args.policy_id])
policy = policies[args.policy_id] policy = policies[args.policy_id]
# Adapt policy trained for standard CartPole to the new env. # Adapt policy trained for standard CartPole to the new env.
ctx: ConnectorContext = ConnectorContext.from_policy(policy) ctx: ConnectorContext = ConnectorContext.from_policy(policy)
# When this policy was trained, it relied on FlattenDataAgentConnector
# to add a batch dimension to single observations.
# This is not necessary anymore, so we first remove the previously used
# FlattenDataAgentConnector.
policy.agent_connectors.remove("FlattenDataAgentConnector")
# We then add the two adapter connectors.
policy.agent_connectors.prepend(V2ToV1ObsAgentConnector(ctx)) policy.agent_connectors.prepend(V2ToV1ObsAgentConnector(ctx))
policy.action_connectors.append(V1ToV2ActionConnector(ctx)) policy.action_connectors.append(V1ToV2ActionConnector(ctx))
@ -115,4 +126,7 @@ def run():
if __name__ == "__main__": if __name__ == "__main__":
run() checkpoint_path = str(
Path(__file__).parent.parent.parent.absolute().joinpath(args.checkpoint_file)
)
run(checkpoint_path)

View file

@ -4,6 +4,7 @@ and use it in a serving/inference setting.
import argparse import argparse
import gym import gym
from pathlib import Path
from ray.rllib.utils.policy import ( from ray.rllib.utils.policy import (
load_policies_from_checkpoint, load_policies_from_checkpoint,
@ -15,7 +16,7 @@ parser = argparse.ArgumentParser()
# This should be a checkpoint created with connectors enabled. # This should be a checkpoint created with connectors enabled.
parser.add_argument( parser.add_argument(
"--checkpoint_file", "--checkpoint_file",
help="Path to an RLlib checkpoint file.", help="Path to an RLlib checkpoint file, relative to //ray/rllib/ folder.",
) )
parser.add_argument( parser.add_argument(
"--policy_id", "--policy_id",
@ -27,9 +28,9 @@ args = parser.parse_args()
assert args.checkpoint_file, "Must specify flag --checkpoint_file." assert args.checkpoint_file, "Must specify flag --checkpoint_file."
def run(): def run(checkpoint_path):
# Restore policy. # Restore policy.
policies = load_policies_from_checkpoint(args.checkpoint_file, [args.policy_id]) policies = load_policies_from_checkpoint(checkpoint_path, [args.policy_id])
policy = policies[args.policy_id] policy = policies[args.policy_id]
# Run CartPole. # Run CartPole.
@ -52,4 +53,7 @@ def run():
if __name__ == "__main__": if __name__ == "__main__":
run() checkpoint_path = str(
Path(__file__).parent.parent.parent.absolute().joinpath(args.checkpoint_file)
)
run(checkpoint_path)

View file

@ -4,6 +4,7 @@ The checkpointed policy may be trained with a different algorithm too.
""" """
import argparse import argparse
from pathlib import Path
import pyspiel import pyspiel
import ray import ray
@ -23,7 +24,10 @@ parser.add_argument(
"--checkpoint_file", "--checkpoint_file",
type=str, type=str,
default="", default="",
help="Path to a connector enabled checkpoint file for restoring.", help=(
"Path to a connector enabled checkpoint file for restoring,"
"relative to //ray/rllib/ folder."
),
) )
parser.add_argument( parser.add_argument(
"--policy_id", "--policy_id",
@ -46,8 +50,13 @@ class AddPolicyCallback(DefaultCallbacks):
super().__init__() super().__init__()
def on_algorithm_init(self, *, algorithm, **kwargs): def on_algorithm_init(self, *, algorithm, **kwargs):
checkpoint_path = str(
Path(__file__)
.parent.parent.parent.absolute()
.joinpath(args.checkpoint_file)
)
policy_config, policy_specs, policy_states = parse_policy_specs_from_checkpoint( policy_config, policy_specs, policy_states = parse_policy_specs_from_checkpoint(
args.checkpoint_file checkpoint_path
) )
assert args.policy_id in policy_specs, ( assert args.policy_id in policy_specs, (

View file

@ -1,5 +1,5 @@
import gym import gym
import pickle import ray.cloudpickle as pickle
from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
from ray.rllib.policy.policy import PolicySpec from ray.rllib.policy.policy import PolicySpec