[RLlib] Fix connector examples (#27583)

This commit is contained in:
Jun Gong 2022-08-07 17:48:09 -07:00 committed by GitHub
parent 89b2f616fd
commit 5f07987ab1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 51 additions and 24 deletions

View file

@ -3343,36 +3343,36 @@ py_test(
py_test(
name = "examples/connectors/run_connector_policy",
main = "examples/connectors/run_connector_policy.py",
tags = ["team:rllib", "exclusive", "examples", ],
tags = ["team:rllib", "exclusive", "examples", "examples_C", "examples_C_AtoT"],
size = "small",
srcs = ["examples/connectors/run_connector_policy.py"],
data = [
"tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6",
],
args = ["--checkpoint_file=tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6"]
data = glob([
"tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6.new",
]),
args = ["--checkpoint_file=tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6-08062022"]
)
py_test(
name = "examples/connectors/adapt_connector_policy",
main = "examples/connectors/adapt_connector_policy.py",
tags = ["team:rllib", "exclusive", "examples", ],
tags = ["team:rllib", "exclusive", "examples", "examples_C", "examples_C_AtoT"],
size = "small",
srcs = ["examples/connectors/adapt_connector_policy.py"],
data = [
"tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6",
],
args = ["--checkpoint_file=tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6"]
data = glob([
"tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6.old",
]),
args = ["--checkpoint_file=tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6-07092022"]
)
py_test(
name = "examples/connectors/self_play_with_policy_checkpoint",
main = "examples/connectors/self_play_with_policy_checkpoint.py",
tags = ["team:rllib", "exclusive", "examples", ],
tags = ["team:rllib", "exclusive", "examples", "examples_C", "examples_C_AtoT"],
size = "small",
srcs = ["examples/connectors/self_play_with_policy_checkpoint.py"],
data = [
data = glob([
"tests/data/checkpoints/PPO_open_spiel_checkpoint-6",
],
]),
args = [
"--checkpoint_file=tests/data/checkpoints/PPO_open_spiel_checkpoint-6",
"--train_iteration=1" # Smoke test.

View file

@ -5,6 +5,7 @@ and adapt/use it with a different version of the environment.
import argparse
import gym
import numpy as np
from pathlib import Path
from typing import Dict
from ray.rllib.utils.policy import (
@ -23,9 +24,11 @@ from ray.rllib.utils.typing import (
parser = argparse.ArgumentParser()
# A policy checkpoint that works with this example script can be found at:
# rllib/tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6-07092022
parser.add_argument(
"--checkpoint_file",
help="Path to an RLlib checkpoint file.",
help="Path to an RLlib checkpoint file, relative to //ray/rllib/ folder.",
)
parser.add_argument(
"--policy_id",
@ -87,13 +90,21 @@ V1ToV2ActionConnector = register_lambda_action_connector(
)
def run():
def run(checkpoint_path):
# Restore policy.
policies = load_policies_from_checkpoint(args.checkpoint_file, [args.policy_id])
policies = load_policies_from_checkpoint(checkpoint_path, [args.policy_id])
policy = policies[args.policy_id]
# Adapt policy trained for standard CartPole to the new env.
ctx: ConnectorContext = ConnectorContext.from_policy(policy)
# When this policy was trained, it relied on FlattenDataAgentConnector
# to add a batch dimension to single observations.
# This is not necessary anymore, so we first remove the previously used
# FlattenDataAgentConnector.
policy.agent_connectors.remove("FlattenDataAgentConnector")
# We then add the two adapter connectors.
policy.agent_connectors.prepend(V2ToV1ObsAgentConnector(ctx))
policy.action_connectors.append(V1ToV2ActionConnector(ctx))
@ -115,4 +126,7 @@ def run():
if __name__ == "__main__":
run()
checkpoint_path = str(
Path(__file__).parent.parent.parent.absolute().joinpath(args.checkpoint_file)
)
run(checkpoint_path)

View file

@ -4,6 +4,7 @@ and use it in a serving/inference setting.
import argparse
import gym
from pathlib import Path
from ray.rllib.utils.policy import (
load_policies_from_checkpoint,
@ -15,7 +16,7 @@ parser = argparse.ArgumentParser()
# This should be a checkpoint created with connectors enabled.
parser.add_argument(
"--checkpoint_file",
help="Path to an RLlib checkpoint file.",
help="Path to an RLlib checkpoint file, relative to //ray/rllib/ folder.",
)
parser.add_argument(
"--policy_id",
@ -27,9 +28,9 @@ args = parser.parse_args()
assert args.checkpoint_file, "Must specify flag --checkpoint_file."
def run():
def run(checkpoint_path):
# Restore policy.
policies = load_policies_from_checkpoint(args.checkpoint_file, [args.policy_id])
policies = load_policies_from_checkpoint(checkpoint_path, [args.policy_id])
policy = policies[args.policy_id]
# Run CartPole.
@ -52,4 +53,7 @@ def run():
if __name__ == "__main__":
run()
checkpoint_path = str(
Path(__file__).parent.parent.parent.absolute().joinpath(args.checkpoint_file)
)
run(checkpoint_path)

View file

@ -4,6 +4,7 @@ The checkpointed policy may be trained with a different algorithm too.
"""
import argparse
from pathlib import Path
import pyspiel
import ray
@ -23,7 +24,10 @@ parser.add_argument(
"--checkpoint_file",
type=str,
default="",
help="Path to a connector enabled checkpoint file for restoring.",
help=(
"Path to a connector enabled checkpoint file for restoring, "
"relative to //ray/rllib/ folder."
),
)
parser.add_argument(
"--policy_id",
@ -46,8 +50,13 @@ class AddPolicyCallback(DefaultCallbacks):
super().__init__()
def on_algorithm_init(self, *, algorithm, **kwargs):
checkpoint_path = str(
Path(__file__)
.parent.parent.parent.absolute()
.joinpath(args.checkpoint_file)
)
policy_config, policy_specs, policy_states = parse_policy_specs_from_checkpoint(
args.checkpoint_file
checkpoint_path
)
assert args.policy_id in policy_specs, (

View file

@ -1,5 +1,5 @@
import gym
import pickle
import ray.cloudpickle as pickle
from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
from ray.rllib.policy.policy import PolicySpec