mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[RLlib] Fix connector examples (#27583)
This commit is contained in:
parent
89b2f616fd
commit
5f07987ab1
7 changed files with 51 additions and 24 deletions
26
rllib/BUILD
26
rllib/BUILD
|
@ -3343,36 +3343,36 @@ py_test(
|
|||
py_test(
|
||||
name = "examples/connectors/run_connector_policy",
|
||||
main = "examples/connectors/run_connector_policy.py",
|
||||
tags = ["team:rllib", "exclusive", "examples", ],
|
||||
tags = ["team:rllib", "exclusive", "examples", "examples_C", "examples_C_AtoT"],
|
||||
size = "small",
|
||||
srcs = ["examples/connectors/run_connector_policy.py"],
|
||||
data = [
|
||||
"tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6",
|
||||
],
|
||||
args = ["--checkpoint_file=tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6"]
|
||||
data = glob([
|
||||
"tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6.new",
|
||||
]),
|
||||
args = ["--checkpoint_file=tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6-08062022"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "examples/connectors/adapt_connector_policy",
|
||||
main = "examples/connectors/adapt_connector_policy.py",
|
||||
tags = ["team:rllib", "exclusive", "examples", ],
|
||||
tags = ["team:rllib", "exclusive", "examples", "examples_C", "examples_C_AtoT"],
|
||||
size = "small",
|
||||
srcs = ["examples/connectors/adapt_connector_policy.py"],
|
||||
data = [
|
||||
"tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6",
|
||||
],
|
||||
args = ["--checkpoint_file=tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6"]
|
||||
data = glob([
|
||||
"tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6.old",
|
||||
]),
|
||||
args = ["--checkpoint_file=tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6-07092022"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "examples/connectors/self_play_with_policy_checkpoint",
|
||||
main = "examples/connectors/self_play_with_policy_checkpoint.py",
|
||||
tags = ["team:rllib", "exclusive", "examples", ],
|
||||
tags = ["team:rllib", "exclusive", "examples", "examples_C", "examples_C_AtoT"],
|
||||
size = "small",
|
||||
srcs = ["examples/connectors/self_play_with_policy_checkpoint.py"],
|
||||
data = [
|
||||
data = glob([
|
||||
"tests/data/checkpoints/PPO_open_spiel_checkpoint-6",
|
||||
],
|
||||
]),
|
||||
args = [
|
||||
"--checkpoint_file=tests/data/checkpoints/PPO_open_spiel_checkpoint-6",
|
||||
"--train_iteration=1" # Smoke test.
|
||||
|
|
|
@ -5,6 +5,7 @@ and adapt/use it with a different version of the environment.
|
|||
import argparse
|
||||
import gym
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
from ray.rllib.utils.policy import (
|
||||
|
@ -23,9 +24,11 @@ from ray.rllib.utils.typing import (
|
|||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
# A policy checkpoint that works with this example script can be found at:
|
||||
# rllib/tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6-07092022
|
||||
parser.add_argument(
|
||||
"--checkpoint_file",
|
||||
help="Path to an RLlib checkpoint file.",
|
||||
help="Path to an RLlib checkpoint file, relative to //ray/rllib/ folder.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--policy_id",
|
||||
|
@ -87,13 +90,21 @@ V1ToV2ActionConnector = register_lambda_action_connector(
|
|||
)
|
||||
|
||||
|
||||
def run():
|
||||
def run(checkpoint_path):
|
||||
# Restore policy.
|
||||
policies = load_policies_from_checkpoint(args.checkpoint_file, [args.policy_id])
|
||||
policies = load_policies_from_checkpoint(checkpoint_path, [args.policy_id])
|
||||
policy = policies[args.policy_id]
|
||||
|
||||
# Adapt policy trained for standard CartPole to the new env.
|
||||
ctx: ConnectorContext = ConnectorContext.from_policy(policy)
|
||||
|
||||
# When this policy was trained, it relied on FlattenDataAgentConnector
|
||||
# to add a batch dimension to single observations.
|
||||
# This is not necessary anymore, so we first remove the previously used
|
||||
# FlattenDataAgentConnector.
|
||||
policy.agent_connectors.remove("FlattenDataAgentConnector")
|
||||
|
||||
# We then add the two adapter connectors.
|
||||
policy.agent_connectors.prepend(V2ToV1ObsAgentConnector(ctx))
|
||||
policy.action_connectors.append(V1ToV2ActionConnector(ctx))
|
||||
|
||||
|
@ -115,4 +126,7 @@ def run():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
checkpoint_path = str(
|
||||
Path(__file__).parent.parent.parent.absolute().joinpath(args.checkpoint_file)
|
||||
)
|
||||
run(checkpoint_path)
|
||||
|
|
|
@ -4,6 +4,7 @@ and use it in a serving/inference setting.
|
|||
|
||||
import argparse
|
||||
import gym
|
||||
from pathlib import Path
|
||||
|
||||
from ray.rllib.utils.policy import (
|
||||
load_policies_from_checkpoint,
|
||||
|
@ -15,7 +16,7 @@ parser = argparse.ArgumentParser()
|
|||
# This should be a checkpoint created with connectors enabled.
|
||||
parser.add_argument(
|
||||
"--checkpoint_file",
|
||||
help="Path to an RLlib checkpoint file.",
|
||||
help="Path to an RLlib checkpoint file, relative to //ray/rllib/ folder.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--policy_id",
|
||||
|
@ -27,9 +28,9 @@ args = parser.parse_args()
|
|||
assert args.checkpoint_file, "Must specify flag --checkpoint_file."
|
||||
|
||||
|
||||
def run():
|
||||
def run(checkpoint_path):
|
||||
# Restore policy.
|
||||
policies = load_policies_from_checkpoint(args.checkpoint_file, [args.policy_id])
|
||||
policies = load_policies_from_checkpoint(checkpoint_path, [args.policy_id])
|
||||
policy = policies[args.policy_id]
|
||||
|
||||
# Run CartPole.
|
||||
|
@ -52,4 +53,7 @@ def run():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
checkpoint_path = str(
|
||||
Path(__file__).parent.parent.parent.absolute().joinpath(args.checkpoint_file)
|
||||
)
|
||||
run(checkpoint_path)
|
||||
|
|
|
@ -4,6 +4,7 @@ The checkpointed policy may be trained with a different algorithm too.
|
|||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import pyspiel
|
||||
|
||||
import ray
|
||||
|
@ -23,7 +24,10 @@ parser.add_argument(
|
|||
"--checkpoint_file",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to a connector enabled checkpoint file for restoring.",
|
||||
help=(
|
||||
"Path to a connector enabled checkpoint file for restoring,"
|
||||
"relative to //ray/rllib/ folder."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--policy_id",
|
||||
|
@ -46,8 +50,13 @@ class AddPolicyCallback(DefaultCallbacks):
|
|||
super().__init__()
|
||||
|
||||
def on_algorithm_init(self, *, algorithm, **kwargs):
|
||||
checkpoint_path = str(
|
||||
Path(__file__)
|
||||
.parent.parent.parent.absolute()
|
||||
.joinpath(args.checkpoint_file)
|
||||
)
|
||||
policy_config, policy_specs, policy_states = parse_policy_specs_from_checkpoint(
|
||||
args.checkpoint_file
|
||||
checkpoint_path
|
||||
)
|
||||
|
||||
assert args.policy_id in policy_specs, (
|
||||
|
|
Binary file not shown.
|
@ -1,5 +1,5 @@
|
|||
import gym
|
||||
import pickle
|
||||
import ray.cloudpickle as pickle
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
|
||||
|
||||
from ray.rllib.policy.policy import PolicySpec
|
||||
|
|
Loading…
Add table
Reference in a new issue