mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[RLlib] Fix connector examples (#27583)
This commit is contained in:
parent
89b2f616fd
commit
5f07987ab1
7 changed files with 51 additions and 24 deletions
26
rllib/BUILD
26
rllib/BUILD
|
@ -3343,36 +3343,36 @@ py_test(
|
||||||
py_test(
|
py_test(
|
||||||
name = "examples/connectors/run_connector_policy",
|
name = "examples/connectors/run_connector_policy",
|
||||||
main = "examples/connectors/run_connector_policy.py",
|
main = "examples/connectors/run_connector_policy.py",
|
||||||
tags = ["team:rllib", "exclusive", "examples", ],
|
tags = ["team:rllib", "exclusive", "examples", "examples_C", "examples_C_AtoT"],
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["examples/connectors/run_connector_policy.py"],
|
srcs = ["examples/connectors/run_connector_policy.py"],
|
||||||
data = [
|
data = glob([
|
||||||
"tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6",
|
"tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6.new",
|
||||||
],
|
]),
|
||||||
args = ["--checkpoint_file=tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6"]
|
args = ["--checkpoint_file=tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6-08062022"]
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "examples/connectors/adapt_connector_policy",
|
name = "examples/connectors/adapt_connector_policy",
|
||||||
main = "examples/connectors/adapt_connector_policy.py",
|
main = "examples/connectors/adapt_connector_policy.py",
|
||||||
tags = ["team:rllib", "exclusive", "examples", ],
|
tags = ["team:rllib", "exclusive", "examples", "examples_C", "examples_C_AtoT"],
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["examples/connectors/adapt_connector_policy.py"],
|
srcs = ["examples/connectors/adapt_connector_policy.py"],
|
||||||
data = [
|
data = glob([
|
||||||
"tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6",
|
"tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6.old",
|
||||||
],
|
]),
|
||||||
args = ["--checkpoint_file=tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6"]
|
args = ["--checkpoint_file=tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6-07092022"]
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "examples/connectors/self_play_with_policy_checkpoint",
|
name = "examples/connectors/self_play_with_policy_checkpoint",
|
||||||
main = "examples/connectors/self_play_with_policy_checkpoint.py",
|
main = "examples/connectors/self_play_with_policy_checkpoint.py",
|
||||||
tags = ["team:rllib", "exclusive", "examples", ],
|
tags = ["team:rllib", "exclusive", "examples", "examples_C", "examples_C_AtoT"],
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["examples/connectors/self_play_with_policy_checkpoint.py"],
|
srcs = ["examples/connectors/self_play_with_policy_checkpoint.py"],
|
||||||
data = [
|
data = glob([
|
||||||
"tests/data/checkpoints/PPO_open_spiel_checkpoint-6",
|
"tests/data/checkpoints/PPO_open_spiel_checkpoint-6",
|
||||||
],
|
]),
|
||||||
args = [
|
args = [
|
||||||
"--checkpoint_file=tests/data/checkpoints/PPO_open_spiel_checkpoint-6",
|
"--checkpoint_file=tests/data/checkpoints/PPO_open_spiel_checkpoint-6",
|
||||||
"--train_iteration=1" # Smoke test.
|
"--train_iteration=1" # Smoke test.
|
||||||
|
|
|
@ -5,6 +5,7 @@ and adapt/use it with a different version of the environment.
|
||||||
import argparse
|
import argparse
|
||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from ray.rllib.utils.policy import (
|
from ray.rllib.utils.policy import (
|
||||||
|
@ -23,9 +24,11 @@ from ray.rllib.utils.typing import (
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
# A policy checkpoint that works with this example script can be found at:
|
||||||
|
# rllib/tests/data/checkpoints/APPO_CartPole-v0_checkpoint-6-07092022
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--checkpoint_file",
|
"--checkpoint_file",
|
||||||
help="Path to an RLlib checkpoint file.",
|
help="Path to an RLlib checkpoint file, relative to //ray/rllib/ folder.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--policy_id",
|
"--policy_id",
|
||||||
|
@ -87,13 +90,21 @@ V1ToV2ActionConnector = register_lambda_action_connector(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def run():
|
def run(checkpoint_path):
|
||||||
# Restore policy.
|
# Restore policy.
|
||||||
policies = load_policies_from_checkpoint(args.checkpoint_file, [args.policy_id])
|
policies = load_policies_from_checkpoint(checkpoint_path, [args.policy_id])
|
||||||
policy = policies[args.policy_id]
|
policy = policies[args.policy_id]
|
||||||
|
|
||||||
# Adapt policy trained for standard CartPole to the new env.
|
# Adapt policy trained for standard CartPole to the new env.
|
||||||
ctx: ConnectorContext = ConnectorContext.from_policy(policy)
|
ctx: ConnectorContext = ConnectorContext.from_policy(policy)
|
||||||
|
|
||||||
|
# When this policy was trained, it relied on FlattenDataAgentConnector
|
||||||
|
# to add a batch dimension to single observations.
|
||||||
|
# This is not necessary anymore, so we first remove the previously used
|
||||||
|
# FlattenDataAgentConnector.
|
||||||
|
policy.agent_connectors.remove("FlattenDataAgentConnector")
|
||||||
|
|
||||||
|
# We then add the two adapter connectors.
|
||||||
policy.agent_connectors.prepend(V2ToV1ObsAgentConnector(ctx))
|
policy.agent_connectors.prepend(V2ToV1ObsAgentConnector(ctx))
|
||||||
policy.action_connectors.append(V1ToV2ActionConnector(ctx))
|
policy.action_connectors.append(V1ToV2ActionConnector(ctx))
|
||||||
|
|
||||||
|
@ -115,4 +126,7 @@ def run():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
run()
|
checkpoint_path = str(
|
||||||
|
Path(__file__).parent.parent.parent.absolute().joinpath(args.checkpoint_file)
|
||||||
|
)
|
||||||
|
run(checkpoint_path)
|
||||||
|
|
|
@ -4,6 +4,7 @@ and use it in a serving/inference setting.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import gym
|
import gym
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from ray.rllib.utils.policy import (
|
from ray.rllib.utils.policy import (
|
||||||
load_policies_from_checkpoint,
|
load_policies_from_checkpoint,
|
||||||
|
@ -15,7 +16,7 @@ parser = argparse.ArgumentParser()
|
||||||
# This should be a checkpoint created with connectors enabled.
|
# This should be a checkpoint created with connectors enabled.
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--checkpoint_file",
|
"--checkpoint_file",
|
||||||
help="Path to an RLlib checkpoint file.",
|
help="Path to an RLlib checkpoint file, relative to //ray/rllib/ folder.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--policy_id",
|
"--policy_id",
|
||||||
|
@ -27,9 +28,9 @@ args = parser.parse_args()
|
||||||
assert args.checkpoint_file, "Must specify flag --checkpoint_file."
|
assert args.checkpoint_file, "Must specify flag --checkpoint_file."
|
||||||
|
|
||||||
|
|
||||||
def run():
|
def run(checkpoint_path):
|
||||||
# Restore policy.
|
# Restore policy.
|
||||||
policies = load_policies_from_checkpoint(args.checkpoint_file, [args.policy_id])
|
policies = load_policies_from_checkpoint(checkpoint_path, [args.policy_id])
|
||||||
policy = policies[args.policy_id]
|
policy = policies[args.policy_id]
|
||||||
|
|
||||||
# Run CartPole.
|
# Run CartPole.
|
||||||
|
@ -52,4 +53,7 @@ def run():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
run()
|
checkpoint_path = str(
|
||||||
|
Path(__file__).parent.parent.parent.absolute().joinpath(args.checkpoint_file)
|
||||||
|
)
|
||||||
|
run(checkpoint_path)
|
||||||
|
|
|
@ -4,6 +4,7 @@ The checkpointed policy may be trained with a different algorithm too.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
import pyspiel
|
import pyspiel
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
|
@ -23,7 +24,10 @@ parser.add_argument(
|
||||||
"--checkpoint_file",
|
"--checkpoint_file",
|
||||||
type=str,
|
type=str,
|
||||||
default="",
|
default="",
|
||||||
help="Path to a connector enabled checkpoint file for restoring.",
|
help=(
|
||||||
|
"Path to a connector enabled checkpoint file for restoring,"
|
||||||
|
"relative to //ray/rllib/ folder."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--policy_id",
|
"--policy_id",
|
||||||
|
@ -46,8 +50,13 @@ class AddPolicyCallback(DefaultCallbacks):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def on_algorithm_init(self, *, algorithm, **kwargs):
|
def on_algorithm_init(self, *, algorithm, **kwargs):
|
||||||
|
checkpoint_path = str(
|
||||||
|
Path(__file__)
|
||||||
|
.parent.parent.parent.absolute()
|
||||||
|
.joinpath(args.checkpoint_file)
|
||||||
|
)
|
||||||
policy_config, policy_specs, policy_states = parse_policy_specs_from_checkpoint(
|
policy_config, policy_specs, policy_states = parse_policy_specs_from_checkpoint(
|
||||||
args.checkpoint_file
|
checkpoint_path
|
||||||
)
|
)
|
||||||
|
|
||||||
assert args.policy_id in policy_specs, (
|
assert args.policy_id in policy_specs, (
|
||||||
|
|
Binary file not shown.
|
@ -1,5 +1,5 @@
|
||||||
import gym
|
import gym
|
||||||
import pickle
|
import ray.cloudpickle as pickle
|
||||||
from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
|
from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
|
||||||
|
|
||||||
from ray.rllib.policy.policy import PolicySpec
|
from ray.rllib.policy.policy import PolicySpec
|
||||||
|
|
Loading…
Add table
Reference in a new issue