# ray/rllib/utils/policy.py
import gym
import ray.cloudpickle as pickle
from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import (
    ActionConnectorDataType,
    AgentConnectorDataType,
    AgentConnectorsOutput,
    PartialAlgorithmConfigDict,
    PolicyID,
    PolicyState,
    TensorStructType,
    TensorType,
)
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    from ray.rllib.policy.policy import Policy

tf1, tf, tfv = try_import_tf()

@PublicAPI
def create_policy_for_framework(
    policy_id: str,
    policy_class: "Policy",
    merged_config: PartialAlgorithmConfigDict,
    observation_space: gym.Space,
    action_space: gym.Space,
    worker_index: int = 0,
    session_creator: Optional[Callable[[], "tf1.Session"]] = None,
    seed: Optional[int] = None,
):
    """Framework-specific policy creation logic.

    Args:
        policy_id: Policy ID.
        policy_class: Policy class type.
        merged_config: Complete policy config.
        observation_space: Observation space of the env.
        action_space: Action space of the env.
        worker_index: Index of the worker holding this policy. Default is 0.
        session_creator: An optional tf1.Session creation callable.
        seed: Optional random seed.

    Returns:
        A new policy instance, created under the framework-appropriate
        graph/session setup.
    """
    framework = merged_config.get("framework", "tf")
    # Tf.
    if framework in ["tf2", "tf", "tfe"]:
        var_scope = policy_id + (f"_wk{worker_index}" if worker_index else "")
        # For tf static graph, build every policy in its own graph
        # and create a new session for it.
        if framework == "tf":
            with tf1.Graph().as_default():
                if session_creator:
                    sess = session_creator()
                else:
                    sess = tf1.Session(
                        config=tf1.ConfigProto(
                            gpu_options=tf1.GPUOptions(allow_growth=True)
                        )
                    )
                with sess.as_default():
                    # Set graph-level seed.
                    if seed is not None:
                        tf1.set_random_seed(seed)
                    with tf1.variable_scope(var_scope):
                        return policy_class(
                            observation_space, action_space, merged_config
                        )
        # For tf-eager: no graph, no session.
        else:
            with tf1.variable_scope(var_scope):
                return policy_class(observation_space, action_space, merged_config)
    # Non-tf: No graph, no session.
    else:
        return policy_class(observation_space, action_space, merged_config)
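
# Illustrative usage sketch, not part of the original module: creating a single
# policy directly from a class and a merged config. The policy class is passed
# in as an argument because any concrete choice (e.g. a torch policy taken from
# a PolicySpec) is an assumption; the spaces and config below are placeholders.
def _example_create_policy_for_framework(policy_cls) -> "Policy":
    return create_policy_for_framework(
        policy_id="default_policy",
        policy_class=policy_cls,
        merged_config={"framework": "torch"},
        observation_space=gym.spaces.Box(-1.0, 1.0, (4,)),
        action_space=gym.spaces.Discrete(2),
    )
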
@PublicAPI(stability="alpha")
def parse_policy_specs_from_checkpoint(
    path: str,
) -> Tuple[PartialAlgorithmConfigDict, Dict[str, PolicySpec], Dict[str, PolicyState]]:
    """Read and parse policy specifications from a checkpoint file.

    Args:
        path: Path to a policy checkpoint.

    Returns:
        A tuple of: base policy config, dictionary of policy specs, and
        dictionary of policy states.
    """
    with open(path, "rb") as f:
        checkpoint_dict = pickle.load(f)
    # Per-policy data is stored inside the serialized worker blob,
    # keyed by policy ID.
    w = pickle.loads(checkpoint_dict["worker"])
    policy_config = w["policy_config"]
    assert policy_config.get("enable_connectors", False), (
        "load_policies_from_checkpoint only works for checkpoints generated by stacks "
        "with connectors enabled."
    )
    policy_states = w["state"]
    serialized_policy_specs = w["policy_specs"]
    policy_specs = {
        id: PolicySpec.deserialize(spec)
        for id, spec in serialized_policy_specs.items()
    }
    return policy_config, policy_specs, policy_states
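
# Illustrative usage sketch, not part of the original module: inspecting what a
# connector-enabled checkpoint contains before deciding which policies to load.
# The checkpoint file path passed in here is hypothetical.
def _example_inspect_checkpoint(checkpoint_file: str) -> None:
    policy_config, policy_specs, policy_states = parse_policy_specs_from_checkpoint(
        checkpoint_file
    )
    print("base config:", policy_config)
    for pid, spec in policy_specs.items():
        print(pid, spec.policy_class, "has state:", pid in policy_states)
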
@PublicAPI(stability="alpha")
def load_policies_from_checkpoint(
    path: str, policy_ids: Optional[List[PolicyID]] = None
) -> Dict[str, "Policy"]:
    """Load a set of policies from a connector-enabled policy checkpoint.

    Args:
        path: File path to the checkpoint file.
        policy_ids: A list of policy IDs to be restored. If missing, we will
            load all policies contained in this checkpoint.

    Returns:
        A dictionary mapping policy IDs to the restored policies.
    """
    policy_config, policy_specs, policy_states = parse_policy_specs_from_checkpoint(
        path
    )

    policies = {}
    for id, policy_spec in policy_specs.items():
        if policy_ids and id not in policy_ids:
            # The user wants specific policies, and this is not one of them.
            continue

        merged_config = merge_dicts(policy_config, policy_spec.config or {})

        policy = create_policy_for_framework(
            id,
            policy_spec.policy_class,
            merged_config,
            policy_spec.observation_space,
            policy_spec.action_space,
        )
        if id in policy_states:
            policy.set_state(policy_states[id])
        if policy.agent_connectors:
            policy.agent_connectors.is_training(False)
        policies[id] = policy

    return policies
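
# Illustrative usage sketch, not part of the original module: restoring a single
# policy for local inference. The checkpoint path argument and the
# "default_policy" ID are hypothetical placeholders.
def _example_restore_default_policy(checkpoint_file: str) -> "Policy":
    policies = load_policies_from_checkpoint(
        checkpoint_file, policy_ids=["default_policy"]
    )
    return policies["default_policy"]
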
@PublicAPI(stability="alpha")
def local_policy_inference(
policy: "Policy",
env_id: str,
agent_id: str,
obs: TensorStructType,
) -> TensorStructType:
"""Run a connector enabled policy using environment observation.
policy_inference manages policy and agent/action connectors,
so the user does not have to care about RNN state buffering or
extra fetch dictionaries.
Note that connectors are intentionally run separately from
compute_actions_from_input_dict(), so we can have the option
of running per-user connectors on the client side in a
server-client deployment.
Args:
policy: Policy.
env_id: Environment ID.
agent_id: Agent ID.
obs: Env obseration.
Returns:
List of outputs from policy forward pass.
"""
assert (
policy.agent_connectors
), "policy_inference only works with connector enabled policies."
# TODO(jungong) : support multiple env, multiple agent inference.
input_dict = {SampleBatch.NEXT_OBS: obs}
acd_list: List[AgentConnectorDataType] = [
AgentConnectorDataType(env_id, agent_id, input_dict)
]
ac_outputs: List[AgentConnectorsOutput] = policy.agent_connectors(acd_list)
outputs = []
for ac in ac_outputs:
policy_output = policy.compute_actions_from_input_dict(ac.data.for_action)
if policy.action_connectors:
acd = ActionConnectorDataType(env_id, agent_id, policy_output)
acd = policy.action_connectors(acd)
actions = acd.output
else:
actions = policy_output[0]
outputs.append(actions)
# Notify agent connectors with this new policy output.
# Necessary for state buffering agent connectors, for example.
policy.agent_connectors.on_policy_output(
ActionConnectorDataType(env_id, agent_id, policy_output)
)
return outputs
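
# Illustrative usage sketch, not part of the original module: querying a
# restored, connector-enabled policy for a single agent. The env/agent IDs are
# arbitrary strings chosen for this example, and ``env`` can be any gym.Env.
def _example_local_policy_inference(policy: "Policy", env: gym.Env):
    obs = env.reset()
    # One output per queried agent; here we query a single agent in a single env.
    outputs = local_policy_inference(
        policy, env_id="env_0", agent_id="agent_0", obs=obs
    )
    return outputs[0]
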
@PublicAPI
def compute_log_likelihoods_from_input_dict(
policy: "Policy", batch: Union[SampleBatch, Dict[str, TensorStructType]]
):
"""Returns log likelihood for actions in given batch for policy.
Computes likelihoods by passing the observations through the current
policy's `compute_log_likelihoods()` method
Args:
batch: The SampleBatch or MultiAgentBatch to calculate action
log likelihoods from. This batch/batches must contain OBS
and ACTIONS keys.
Returns:
The probabilities of the actions in the batch, given the
observations and the policy.
"""
num_state_inputs = 0
for k in batch.keys():
if k.startswith("state_in_"):
num_state_inputs += 1
state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
log_likelihoods: TensorType = policy.compute_log_likelihoods(
actions=batch[SampleBatch.ACTIONS],
obs_batch=batch[SampleBatch.OBS],
state_batches=[batch[k] for k in state_keys],
prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
actions_normalized=policy.config.get("actions_in_input_normalized", False),
)
return log_likelihoods
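
# Illustrative usage sketch, not part of the original module: scoring previously
# collected actions under the current policy (e.g. for off-policy evaluation).
# The observation/action values below are made-up placeholders; real batches
# would come from offline data or sampled episodes.
def _example_compute_log_likelihoods(policy: "Policy") -> TensorType:
    batch = SampleBatch(
        {
            SampleBatch.OBS: [[0.1, 0.2, 0.3, 0.4], [0.0, 0.0, 0.0, 0.0]],
            SampleBatch.ACTIONS: [0, 1],
        }
    )
    return compute_log_likelihoods_from_input_dict(policy, batch)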