import gym
import pickle
from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING

from ray.rllib.policy.policy import PolicySpec
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import (
    ActionConnectorDataType,
    AgentConnectorDataType,
    AgentConnectorsOutput,
    PartialAlgorithmConfigDict,
    PolicyID,
    PolicyOutputType,
    PolicyState,
    TensorStructType,
)
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    from ray.rllib.policy.policy import Policy

tf1, tf, tfv = try_import_tf()

@PublicAPI
def create_policy_for_framework(
    policy_id: str,
    policy_class: "Policy",
    merged_config: PartialAlgorithmConfigDict,
    observation_space: gym.Space,
    action_space: gym.Space,
    worker_index: int = 0,
    session_creator: Optional[Callable[[], "tf1.Session"]] = None,
    seed: Optional[int] = None,
):
    """Framework-specific policy creation logic.

    Args:
        policy_id: Policy ID.
        policy_class: Policy class type.
        merged_config: Complete policy config.
        observation_space: Observation space of the env.
        action_space: Action space of the env.
        worker_index: Index of the worker holding this policy. Default is 0.
        session_creator: An optional tf1.Session creation callable.
        seed: Optional random seed.
    """
    framework = merged_config.get("framework", "tf")
    # Tf.
    if framework in ["tf2", "tf", "tfe"]:
        var_scope = policy_id + (f"_wk{worker_index}" if worker_index else "")
        # For tf static graph, build every policy in its own graph
        # and create a new session for it.
        if framework == "tf":
            with tf1.Graph().as_default():
                if session_creator:
                    sess = session_creator()
                else:
                    sess = tf1.Session(
                        config=tf1.ConfigProto(
                            gpu_options=tf1.GPUOptions(allow_growth=True)
                        )
                    )
                with sess.as_default():
                    # Set graph-level seed.
                    if seed is not None:
                        tf1.set_random_seed(seed)
                    with tf1.variable_scope(var_scope):
                        return policy_class(
                            observation_space, action_space, merged_config
                        )
        # For tf-eager: no graph, no session.
        else:
            with tf1.variable_scope(var_scope):
                return policy_class(observation_space, action_space, merged_config)
    # Non-tf: no graph, no session.
    else:
        return policy_class(observation_space, action_space, merged_config)

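
# Hedged usage sketch (added for illustration, not part of the original
# module): builds a single small policy via create_policy_for_framework.
# The spaces, the config, and the caller-supplied `policy_class` are
# assumptions, not values defined in this file.
def _example_create_policy(policy_class: "Policy") -> "Policy":
    # A "torch" framework config skips the tf graph/session handling above.
    return create_policy_for_framework(
        policy_id="default_policy",
        policy_class=policy_class,
        merged_config={"framework": "torch"},
        observation_space=gym.spaces.Box(-1.0, 1.0, (4,)),
        action_space=gym.spaces.Discrete(2),
    )
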
@PublicAPI(stability="alpha")
def parse_policy_specs_from_checkpoint(
    path: str,
) -> Tuple[PartialAlgorithmConfigDict, Dict[str, PolicySpec], Dict[str, PolicyState]]:
    """Read and parse policy specifications from a checkpoint file.

    Args:
        path: Path to a policy checkpoint.

    Returns:
        A tuple of: base policy config, dictionary of policy specs, and
        dictionary of policy states.
    """
    with open(path, "rb") as f:
        checkpoint_dict = pickle.load(f)
    # Policy data is stored as serialized binary blobs under the
    # policy ID keys.
    w = pickle.loads(checkpoint_dict["worker"])

    policy_config = w["policy_config"]
    assert policy_config.get("enable_connectors", False), (
        "load_policies_from_checkpoint only works for checkpoints generated by stacks "
        "with connectors enabled."
    )
    policy_states = w["state"]
    serialized_policy_specs = w["policy_specs"]
    policy_specs = {
        id: PolicySpec.deserialize(spec)
        for id, spec in serialized_policy_specs.items()
    }

    return policy_config, policy_specs, policy_states

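
# Hedged usage sketch (illustration only): inspect what a connector-enabled
# checkpoint contains before deciding which policies to restore. `path` is
# a placeholder for a real checkpoint file path.
def _example_inspect_checkpoint(path: str) -> None:
    config, specs, states = parse_policy_specs_from_checkpoint(path)
    print("base framework:", config.get("framework"))
    for pid, spec in specs.items():
        # Each spec carries the class and spaces needed to rebuild the policy.
        print(pid, spec.policy_class, spec.observation_space, spec.action_space)
        print("  has saved state:", pid in states)
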
@PublicAPI(stability="alpha")
def load_policies_from_checkpoint(
    path: str, policy_ids: Optional[List[PolicyID]] = None
) -> Dict[str, "Policy"]:
    """Load the list of policies from a connector-enabled policy checkpoint.

    Args:
        path: File path to the checkpoint file.
        policy_ids: A list of policy IDs to be restored. If missing, we will
            load all policies contained in this checkpoint.

    Returns:
        A dict mapping policy IDs to the restored policies.
    """
    policy_config, policy_specs, policy_states = parse_policy_specs_from_checkpoint(
        path
    )

    policies = {}
    for id, policy_spec in policy_specs.items():
        if policy_ids and id not in policy_ids:
            # The user wants specific policies, and this is not one of them.
            continue

        merged_config = merge_dicts(policy_config, policy_spec.config or {})
        policy = create_policy_for_framework(
            id,
            policy_spec.policy_class,
            merged_config,
            policy_spec.observation_space,
            policy_spec.action_space,
        )
        if id in policy_states:
            policy.set_state(policy_states[id])
        policies[id] = policy

    return policies

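
# Hedged usage sketch (illustration only): restore a single policy from a
# checkpoint. "policy_0" is a hypothetical policy ID; pass policy_ids=None
# to restore everything in the checkpoint.
def _example_load_single_policy(path: str) -> "Policy":
    policies = load_policies_from_checkpoint(path, policy_ids=["policy_0"])
    return policies["policy_0"]
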
@PublicAPI(stability="alpha")
def policy_inference(
    policy: "Policy",
    env_id: str,
    agent_id: str,
    obs: TensorStructType,
) -> List[PolicyOutputType]:
    """Run a connector-enabled policy using an environment observation.

    policy_inference manages policy and agent/action connectors,
    so the user does not have to care about RNN state buffering or
    extra fetch dictionaries.

    Args:
        policy: Policy.
        env_id: Environment ID.
        agent_id: Agent ID.
        obs: Env observation.

    Returns:
        List of outputs from policy forward pass.
    """
    assert (
        policy.agent_connectors
    ), "policy_inference only works with connector-enabled policies."

    policy.agent_connectors.is_training(False)
    # TODO(jungong) : support multiple env, multiple agent inference.
    input_dict = {SampleBatch.NEXT_OBS: obs}
    acd_list: List[AgentConnectorDataType] = [
        AgentConnectorDataType(env_id, agent_id, input_dict)
    ]
    ac_outputs: List[AgentConnectorsOutput] = policy.agent_connectors(acd_list)
    outputs = []
    for ac in ac_outputs:
        policy_output = policy.compute_actions_from_input_dict(ac.data.for_action)
        if policy.action_connectors:
            acd = ActionConnectorDataType(env_id, agent_id, policy_output)
            acd = policy.action_connectors(acd)
            outputs.append(acd.output)
        else:
            # No action connectors: append the raw policy output so callers
            # always get one entry per agent connector output.
            outputs.append(policy_output)
    return outputs
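

# Hedged usage sketch (illustration only): a minimal single-agent inference
# loop on top of policy_inference. The env/agent IDs are arbitrary labels,
# and the unpacking below assumes each output is the standard
# (actions, rnn_state_outs, extra_fetches) policy output tuple per the
# declared return type; actions may still carry a batch dimension depending
# on the policy's action connectors.
def _example_inference_loop(policy: "Policy", env: gym.Env, num_steps: int = 100):
    obs = env.reset()
    for _ in range(num_steps):
        outputs = policy_inference(policy, "env_0", "agent_0", obs)
        actions, _, _ = outputs[0]
        obs, _, done, _ = env.step(actions)
        if done:
            obs = env.reset()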