import gym

import ray.cloudpickle as pickle

from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING

from ray.rllib.policy.policy import PolicySpec
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import (
    ActionConnectorDataType,
    AgentConnectorDataType,
    AgentConnectorsOutput,
    PartialAlgorithmConfigDict,
    PolicyID,
    PolicyState,
    TensorStructType,
    TensorType,
)
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    from ray.rllib.policy.policy import Policy

tf1, tf, tfv = try_import_tf()


@PublicAPI
def create_policy_for_framework(
    policy_id: str,
    policy_class: "Policy",
    merged_config: PartialAlgorithmConfigDict,
    observation_space: gym.Space,
    action_space: gym.Space,
    worker_index: int = 0,
    session_creator: Optional[Callable[[], "tf1.Session"]] = None,
    seed: Optional[int] = None,
):
    """Framework-specific policy creation logic.

    Args:
        policy_id: Policy ID.
        policy_class: Policy class type.
        merged_config: Complete policy config.
        observation_space: Observation space of env.
        action_space: Action space of env.
        worker_index: Index of the worker holding this policy. Default is 0.
        session_creator: An optional tf1.Session creation callable.
        seed: Optional random seed.
    """
    framework = merged_config.get("framework", "tf")
    # Tf.
    if framework in ["tf2", "tf", "tfe"]:
        var_scope = policy_id + (f"_wk{worker_index}" if worker_index else "")
        # For tf static graph, build every policy in its own graph
        # and create a new session for it.
        if framework == "tf":
            with tf1.Graph().as_default():
                if session_creator:
                    sess = session_creator()
                else:
                    sess = tf1.Session(
                        config=tf1.ConfigProto(
                            gpu_options=tf1.GPUOptions(allow_growth=True)
                        )
                    )
                with sess.as_default():
                    # Set graph-level seed.
                    if seed is not None:
                        tf1.set_random_seed(seed)
                    with tf1.variable_scope(var_scope):
                        return policy_class(
                            observation_space, action_space, merged_config
                        )
        # For tf-eager: no graph, no session.
        else:
            with tf1.variable_scope(var_scope):
                return policy_class(observation_space, action_space, merged_config)
    # Non-tf: no graph, no session.
    else:
        return policy_class(observation_space, action_space, merged_config)
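

# A minimal usage sketch for create_policy_for_framework (illustrative only,
# not part of the original module). The spaces and config below are
# placeholder assumptions; `policy_class` would normally come from an
# algorithm's default policy or from a restored PolicySpec.
def _example_create_policy(policy_class: "Policy"):
    obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,))
    act_space = gym.spaces.Discrete(2)
    config = {"framework": "tf"}
    # For framework="tf", the policy is built in its own graph with a fresh
    # session; for "tf2"/"tfe" and non-tf frameworks, no graph or session
    # is created.
    return create_policy_for_framework(
        "default_policy", policy_class, config, obs_space, act_space
    )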
@PublicAPI(stability="alpha")
|
|
def parse_policy_specs_from_checkpoint(
|
|
path: str,
|
|
) -> Tuple[PartialAlgorithmConfigDict, Dict[str, PolicySpec], Dict[str, PolicyState]]:
|
|
"""Read and parse policy specifications from a checkpoint file.
|
|
|
|
Args:
|
|
path: Path to a policy checkpoint.
|
|
|
|
Returns:
|
|
A tuple of: base policy config, dictionary of policy specs, and
|
|
dictionary of policy states.
|
|
"""
|
|
with open(path, "rb") as f:
|
|
checkpoint_dict = pickle.load(f)
|
|
# Policy data is contained as a serialized binary blob under their
|
|
# ID keys.
|
|
w = pickle.loads(checkpoint_dict["worker"])
|
|
|
|
policy_config = w["policy_config"]
|
|
assert policy_config.get("enable_connectors", False), (
|
|
"load_policies_from_checkpoint only works for checkpoints generated by stacks "
|
|
"with connectors enabled."
|
|
)
|
|
policy_states = w["state"]
|
|
serialized_policy_specs = w["policy_specs"]
|
|
policy_specs = {
|
|
id: PolicySpec.deserialize(spec) for id, spec in serialized_policy_specs.items()
|
|
}
|
|
|
|
return policy_config, policy_specs, policy_states
|
|
|
|
|
|
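

# Hedged usage sketch (not part of the original module): inspect what a
# connector-enabled checkpoint contains before deciding which policies to
# restore. `checkpoint_path` is a hypothetical local file path.
def _example_inspect_checkpoint(checkpoint_path: str):
    config, specs, states = parse_policy_specs_from_checkpoint(checkpoint_path)
    for pid, spec in specs.items():
        # Each spec carries the policy class, spaces, and per-policy config;
        # a state entry exists only for policies checkpointed with state.
        print(pid, spec.policy_class, pid in states)
    return config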
@PublicAPI(stability="alpha")
|
|
def load_policies_from_checkpoint(
|
|
path: str, policy_ids: Optional[List[PolicyID]] = None
|
|
) -> Dict[str, "Policy"]:
|
|
"""Load the list of policies from a connector enabled policy checkpoint.
|
|
|
|
Args:
|
|
path: File path to the checkpoint file.
|
|
policy_ids: a list of policy IDs to be restored. If missing, we will
|
|
load all policies contained in this checkpoint.
|
|
|
|
Returns:
|
|
|
|
"""
|
|
policy_config, policy_specs, policy_states = parse_policy_specs_from_checkpoint(
|
|
path
|
|
)
|
|
|
|
policies = {}
|
|
for id, policy_spec in policy_specs.items():
|
|
if policy_ids and id not in policy_ids:
|
|
# User want specific policies, and this is not one of them.
|
|
continue
|
|
|
|
merged_config = merge_dicts(policy_config, policy_spec.config or {})
|
|
policy = create_policy_for_framework(
|
|
id,
|
|
policy_spec.policy_class,
|
|
merged_config,
|
|
policy_spec.observation_space,
|
|
policy_spec.action_space,
|
|
)
|
|
if id in policy_states:
|
|
policy.set_state(policy_states[id])
|
|
if policy.agent_connectors:
|
|
policy.agent_connectors.is_training(False)
|
|
policies[id] = policy
|
|
|
|
return policies
|
|
|
|
|
|
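

# Minimal restore sketch (assumed usage, not from the original file): load a
# single named policy from a connector-enabled checkpoint. "default_policy"
# is a placeholder ID; pass policy_ids=None to restore everything.
def _example_restore_policy(checkpoint_path: str) -> "Policy":
    policies = load_policies_from_checkpoint(checkpoint_path, ["default_policy"])
    return policies["default_policy"]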
@PublicAPI(stability="alpha")
|
|
def local_policy_inference(
|
|
policy: "Policy",
|
|
env_id: str,
|
|
agent_id: str,
|
|
obs: TensorStructType,
|
|
) -> TensorStructType:
|
|
"""Run a connector enabled policy using environment observation.
|
|
|
|
policy_inference manages policy and agent/action connectors,
|
|
so the user does not have to care about RNN state buffering or
|
|
extra fetch dictionaries.
|
|
Note that connectors are intentionally run separately from
|
|
compute_actions_from_input_dict(), so we can have the option
|
|
of running per-user connectors on the client side in a
|
|
server-client deployment.
|
|
|
|
Args:
|
|
policy: Policy.
|
|
env_id: Environment ID.
|
|
agent_id: Agent ID.
|
|
obs: Env obseration.
|
|
|
|
Returns:
|
|
List of outputs from policy forward pass.
|
|
"""
|
|
assert (
|
|
policy.agent_connectors
|
|
), "policy_inference only works with connector enabled policies."
|
|
|
|
# TODO(jungong) : support multiple env, multiple agent inference.
|
|
input_dict = {SampleBatch.NEXT_OBS: obs}
|
|
acd_list: List[AgentConnectorDataType] = [
|
|
AgentConnectorDataType(env_id, agent_id, input_dict)
|
|
]
|
|
ac_outputs: List[AgentConnectorsOutput] = policy.agent_connectors(acd_list)
|
|
outputs = []
|
|
for ac in ac_outputs:
|
|
policy_output = policy.compute_actions_from_input_dict(ac.data.for_action)
|
|
|
|
if policy.action_connectors:
|
|
acd = ActionConnectorDataType(env_id, agent_id, policy_output)
|
|
acd = policy.action_connectors(acd)
|
|
actions = acd.output
|
|
else:
|
|
actions = policy_output[0]
|
|
|
|
outputs.append(actions)
|
|
|
|
# Notify agent connectors with this new policy output.
|
|
# Necessary for state buffering agent connectors, for example.
|
|
policy.agent_connectors.on_policy_output(
|
|
ActionConnectorDataType(env_id, agent_id, policy_output)
|
|
)
|
|
return outputs
|
|
|
|
|
|
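

# Sketch of a simple local serving loop built on local_policy_inference
# (illustrative assumption, not part of the original module). `env` is any
# single-agent gym.Env; the env_id/agent_id strings are arbitrary keys the
# connectors use to buffer per-agent state.
def _example_serving_loop(policy: "Policy", env: gym.Env, num_steps: int = 100):
    obs = env.reset()
    for _ in range(num_steps):
        # Agent/action connectors handle RNN state and extra fetches for us.
        actions = local_policy_inference(policy, "env_0", "agent_0", obs)
        obs, reward, done, info = env.step(actions[0])
        if done:
            obs = env.reset()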


@PublicAPI
def compute_log_likelihoods_from_input_dict(
    policy: "Policy", batch: Union[SampleBatch, Dict[str, TensorStructType]]
):
    """Returns log likelihoods for the actions in the given batch for the policy.

    Computes likelihoods by passing the observations through the current
    policy's `compute_log_likelihoods()` method.

    Args:
        policy: The Policy to compute action log likelihoods with.
        batch: The SampleBatch or MultiAgentBatch to calculate action
            log likelihoods from. This batch/batches must contain OBS
            and ACTIONS keys.

    Returns:
        The probabilities of the actions in the batch, given the
        observations and the policy.
    """
    num_state_inputs = 0
    for k in batch.keys():
        if k.startswith("state_in_"):
            num_state_inputs += 1
    state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
    log_likelihoods: TensorType = policy.compute_log_likelihoods(
        actions=batch[SampleBatch.ACTIONS],
        obs_batch=batch[SampleBatch.OBS],
        state_batches=[batch[k] for k in state_keys],
        prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
        prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
        actions_normalized=policy.config.get("actions_in_input_normalized", False),
    )
    return log_likelihoods
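

# Hedged sketch (not part of the original module): scoring the actions of a
# small offline batch under a policy. Shapes below are placeholders for an
# env with a 4-dim Box observation space and a Discrete(2) action space.
def _example_log_likelihoods(policy: "Policy") -> TensorType:
    # numpy is not imported at module level, so import it locally here.
    import numpy as np

    batch = SampleBatch(
        {
            SampleBatch.OBS: np.zeros((2, 4), dtype=np.float32),
            SampleBatch.ACTIONS: np.array([0, 1]),
        }
    )
    # Any state_in_* columns present in the batch would be forwarded
    # automatically for RNN policies.
    return compute_log_likelihoods_from_input_dict(policy, batch)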