Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[RLlib] Implement policy_maps (multi-agent case) in RolloutWorkers as LRU caches. (#17031)
Parent: e0640ad0dc
Commit: 18d173b172
22 changed files with 503 additions and 208 deletions
@@ -66,7 +66,7 @@ class TestIMPALA(unittest.TestCase):

                 try:
                     if fw == "tf":
-                        check(policy._sess.run(policy.cur_lr), 0.0005)
+                        check(policy.get_session().run(policy.cur_lr), 0.0005)
                     else:
                         check(policy.cur_lr, 0.0005)
                     r1 = trainer.train()
@@ -36,10 +36,13 @@ class TestTrainer(unittest.TestCase):
                     "p0": (None, env.observation_space, env.action_space, {}),
                 },
                 "policy_mapping_fn": lambda aid, episode, **kwargs: "p0",
+                "policy_map_capacity": 2,
             },
         })

-        for _ in framework_iterator(config):
+        # TODO: (sven) this will work for tf, once we have the DynamicTFPolicy
+        # refactor PR merged.
+        for _ in framework_iterator(config, frameworks=("tf2", "torch")):
             trainer = pg.PGTrainer(config=config)
             r = trainer.train()
             self.assertTrue("p0" in r["policy_reward_min"])
@@ -62,8 +65,8 @@ class TestTrainer(unittest.TestCase):
                 )
                 pol_map = trainer.workers.local_worker().policy_map
                 self.assertTrue(new_pol is not trainer.get_policy("p0"))
-                self.assertTrue("p0" in pol_map)
-                self.assertTrue("p1" in pol_map)
+                for j in range(i):
+                    self.assertTrue(f"p{j}" in pol_map)
                 self.assertTrue(len(pol_map) == i + 1)
                 r = trainer.train()
                 self.assertTrue("p1" in r["policy_reward_min"])
@@ -432,6 +432,13 @@ COMMON_CONFIG: TrainerConfigDict = {
         # of (policy_cls, obs_space, act_space, config). This defines the
         # observation and action spaces of the policies and any extra config.
         "policies": {},
+        # Keep this many policies in the "policy_map" (before writing
+        # least-recently used ones to disk/S3).
+        "policy_map_capacity": 100,
+        # Where to store overflowing (least-recently used) policies?
+        # Could be a directory (str) or an S3 location. None for using
+        # the default output dir.
+        "policy_map_cache": None,
         # Function mapping agent ids to policy ids.
         "policy_mapping_fn": None,
         # Optional list of policies to train, or None for all policies.
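A minimal, editor-added sketch (not part of this commit) of how the two new keys above might be used in a multi-agent config; the spaces, the policy count, and the mapping function are illustrative assumptions only:

import gym

# Hypothetical observation/action spaces; substitute your env's real spaces.
obs_space = gym.spaces.Box(-1.0, 1.0, (4, ))
act_space = gym.spaces.Discrete(2)

config = {
    "multiagent": {
        # policy_id -> (policy_cls or None, obs_space, act_space, config).
        "policies": {
            f"p{i}": (None, obs_space, act_space, {})
            for i in range(100)
        },
        # Assumes integer agent ids; adapt the mapping to your env.
        "policy_mapping_fn": lambda aid, episode, **kwargs: f"p{aid % 100}",
        # Keep at most 10 policies in memory at any time ...
        "policy_map_capacity": 10,
        # ... and stash least-recently used ones in the default output dir.
        "policy_map_cache": None,
    },
}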
@@ -1181,7 +1188,7 @@ class Trainer(Trainable):
                 local worker).
         """

-        def fn(worker):
+        def fn(worker: RolloutWorker):
             # `foreach_worker` function: Adds the policy the the worker (and
             # maybe changes its policy_mapping_fn - if provided here).
             worker.add_policy(
@@ -135,13 +135,10 @@ def before_learn_on_batch(multi_agent_batch, policies, train_batch_size):
         if "new_obs" in k:
             new_obs_n.append(v)

-    target_act_sampler_n = [p.target_act_sampler for p in policies.values()]
-    feed_dict = dict(zip(new_obs_ph_n, new_obs_n))
-
-    new_act_n = p.sess.run(target_act_sampler_n, feed_dict)
-    samples.update(
-        {"new_actions_%d" % i: new_act
-         for i, new_act in enumerate(new_act_n)})
+    for i, p in enumerate(policies.values()):
+        feed_dict = {new_obs_ph_n[i]: new_obs_n[i]}
+        new_act = p.get_session().run(p.target_act_sampler, feed_dict)
+        samples.update({"new_actions_%d" % i: new_act})

     # Share samples among agents.
     policy_batches = {pid: SampleBatch(samples) for pid in policies.keys()}
@@ -230,7 +230,8 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):

         # _____ TensorFlow Initialization

-        self.sess = tf1.get_default_session()
+        sess = tf1.get_default_session()
+        assert sess

         def _make_loss_inputs(placeholders):
             return [(ph.name.split("/")[-1].split(":")[0], ph)

@@ -244,7 +245,7 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):
             obs_space,
             act_space,
             config=config,
-            sess=self.sess,
+            sess=sess,
             obs_input=obs_ph_n[agent_id],
             sampled_action=act_sampler,
             loss=actor_loss + critic_loss,

@@ -254,7 +255,7 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):
         del self.view_requirements["prev_actions"]
         del self.view_requirements["prev_rewards"]

-        self.sess.run(tf1.global_variables_initializer())
+        self.get_session().run(tf1.global_variables_initializer())

         # Hard initial update
         self.update_target(1.0)

@@ -297,11 +298,11 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):
         var_list = []
         for var in self.vars.values():
             var_list += var
-        return {"_state": self.sess.run(var_list)}
+        return {"_state": self.get_session().run(var_list)}

     @override(TFPolicy)
     def set_weights(self, weights):
-        self.sess.run(
+        self.get_session().run(
             self.update_vars,
             feed_dict=dict(zip(self.vars_ph, weights["_state"])))

@@ -377,6 +378,6 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):

     def update_target(self, tau=None):
         if tau is not None:
-            self.sess.run(self.update_target_vars, {self.tau: tau})
+            self.get_session().run(self.update_target_vars, {self.tau: tau})
         else:
-            self.sess.run(self.update_target_vars)
+            self.get_session().run(self.update_target_vars)
@@ -3,7 +3,7 @@ import logging
 from typing import Dict, List, Optional, TYPE_CHECKING, Union

 from ray.rllib.evaluation.episode import MultiAgentEpisode
-from ray.rllib.policy.policy import Policy
+from ray.rllib.policy.policy_map import PolicyMap
 from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
 from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, \
     TensorType

@@ -30,7 +30,7 @@ class SampleCollector(metaclass=ABCMeta):
     """

     def __init__(self,
-                 policy_map: Dict[PolicyID, Policy],
+                 policy_map: PolicyMap,
                  clip_rewards: Union[bool, float],
                  callbacks: "DefaultCallbacks",
                  multiple_episodes_in_batch: bool = True,

@@ -39,8 +39,7 @@ class SampleCollector(metaclass=ABCMeta):
         """Initializes a SampleCollector instance.

         Args:
-            policy_map (Dict[str, Policy]): Maps policy ids to policy
-                instances.
+            policy_map (PolicyMap): Maps policy ids to policy instances.
             clip_rewards (Union[bool, float]): Whether to clip rewards before
                 postprocessing (at +/-1.0) or the actual value to +/- clip.
             callbacks (DefaultCallbacks): RLlib callbacks.
@@ -9,6 +9,7 @@ from ray.rllib.env.base_env import _DUMMY_AGENT_ID
 from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
 from ray.rllib.evaluation.episode import MultiAgentEpisode
 from ray.rllib.policy.policy import Policy
+from ray.rllib.policy.policy_map import PolicyMap
 from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.debug import summarize

@@ -292,7 +293,7 @@ class _PolicyCollector:
             appended to this policy's buffers.
         """

-    def __init__(self, policy):
+    def __init__(self, policy: Policy):
         """Initializes a _PolicyCollector instance.

         Args:

@@ -382,7 +383,7 @@ class SimpleListCollector(SampleCollector):
     """

     def __init__(self,
-                 policy_map: Dict[PolicyID, Policy],
+                 policy_map: PolicyMap,
                  clip_rewards: Union[bool, float],
                  callbacks: "DefaultCallbacks",
                  multiple_episodes_in_batch: bool = True,

@@ -650,8 +651,7 @@ class SimpleListCollector(SampleCollector):
                 post_batches[agent_id] = pre_batch
             if getattr(policy, "exploration", None) is not None:
                 policy.exploration.postprocess_trajectory(
-                    policy, post_batches[agent_id],
-                    getattr(policy, "_sess", None))
+                    policy, post_batches[agent_id], policy.get_session())
             post_batches[agent_id] = policy.postprocess_trajectory(
                 post_batches[agent_id], other_batches, episode)
@@ -4,7 +4,7 @@ import random
 from typing import List, Dict, Callable, Any, TYPE_CHECKING

 from ray.rllib.env.base_env import _DUMMY_AGENT_ID
-from ray.rllib.policy.policy import Policy
+from ray.rllib.policy.policy_map import PolicyMap
 from ray.rllib.utils.annotations import DeveloperAPI
 from ray.rllib.utils.deprecation import deprecation_warning
 from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray

@@ -49,9 +49,8 @@ class MultiAgentEpisode:
     >>> episode.extra_batches.add(batch.build_and_reset())
     """

-    def __init__(self, policies: Dict[PolicyID, Policy],
-                 policy_mapping_fn: Callable[[AgentID, "MultiAgentEpisode"],
-                                             PolicyID],
+    def __init__(self, policies: PolicyMap, policy_mapping_fn: Callable[
+            [AgentID, "MultiAgentEpisode"], PolicyID],
                  batch_builder_factory: Callable[
                      [], "MultiAgentSampleBatchBuilder"],
                  extra_batch_callback: Callable[[SampleBatchType], None],

@@ -71,7 +70,7 @@ class MultiAgentEpisode:
         self.user_data: Dict[str, Any] = {}
         self.hist_data: Dict[str, List[float]] = {}
         self.media: Dict[str, Any] = {}
-        self.policy_map: Dict[PolicyID, Policy] = policies
+        self.policy_map: PolicyMap = policies
         self._policies = self.policy_map  # backward compatibility
         self._policy_mapping_fn: Callable[[AgentID, "MultiAgentEpisode"],
                                           PolicyID] = policy_mapping_fn
@@ -28,6 +28,7 @@ from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator
 from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
 from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
 from ray.rllib.policy.policy import Policy, PolicySpec
+from ray.rllib.policy.policy_map import PolicyMap
 from ray.rllib.policy.tf_policy import TFPolicy
 from ray.rllib.policy.torch_policy import TorchPolicy
 from ray.rllib.utils import merge_dicts

@@ -489,7 +490,6 @@ class RolloutWorker(ParallelIteratorWorker):

         self.make_env_fn = make_env

-        self.tf_sess = None
         policy_dict = _determine_spaces_for_multi_agent_dict(
             policy_spec, self.env, spaces=spaces, policy_config=policy_config)
         # List of IDs of those policies, which should be trained.

@@ -498,7 +498,7 @@ class RolloutWorker(ParallelIteratorWorker):
             policy_dict.keys())
         self.set_policies_to_train(self.policies_to_train)

-        self.policy_map: Dict[PolicyID, Policy] = None
+        self.policy_map: PolicyMap = None
         self.preprocessors: Dict[PolicyID, Preprocessor] = None

         # Set Python random, numpy, env, and torch/tf seeds.

@@ -541,26 +541,11 @@ class RolloutWorker(ParallelIteratorWorker):
         elif tf1 and policy_config.get("framework") == "tfe":
             tf1.set_random_seed(seed)

-        if _has_tensorflow_graph(policy_dict) and not (
-                tf1 and tf1.executing_eagerly()):
-            if not tf1:
-                raise ImportError("Could not import tensorflow")
-            with tf1.Graph().as_default():
-                if tf_session_creator:
-                    self.tf_sess = tf_session_creator()
-                else:
-                    self.tf_sess = tf1.Session(
-                        config=tf1.ConfigProto(
-                            gpu_options=tf1.GPUOptions(allow_growth=True)))
-                with self.tf_sess.as_default():
-                    # set graph-level seed
-                    if seed is not None:
-                        tf1.set_random_seed(seed)
-                    self.policy_map, self.preprocessors = \
-                        self._build_policy_map(policy_dict, policy_config)
-        else:
-            self.policy_map, self.preprocessors = self._build_policy_map(
-                policy_dict, policy_config)
+        self._build_policy_map(
+            policy_dict,
+            policy_config,
+            session_creator=tf_session_creator,
+            seed=seed)

         # Update Policy's view requirements from Model, only if Policy directly
         # inherited from base `Policy` class. At this point here, the Policy

@@ -591,14 +576,13 @@ class RolloutWorker(ParallelIteratorWorker):
         self.multiagent: bool = set(
             self.policy_map.keys()) != {DEFAULT_POLICY_ID}
         if self.multiagent and self.env is not None:
-            if not ((isinstance(self.env, MultiAgentEnv)
-                     or isinstance(self.env, ExternalMultiAgentEnv))
-                    or isinstance(self.env, BaseEnv)):
+            if not isinstance(self.env,
+                              (BaseEnv, ExternalMultiAgentEnv, MultiAgentEnv,
+                               ray.actor.ActorHandle)):
                 raise ValueError(
-                    "Have multiple policies {}, but the env ".format(
-                        self.policy_map) +
-                    "{} is not a subclass of BaseEnv, MultiAgentEnv or "
-                    "ExternalMultiAgentEnv?".format(self.env))
+                    f"Have multiple policies {self.policy_map}, but the "
+                    f"env {self.env} is not a subclass of BaseEnv, "
+                    f"MultiAgentEnv, ActorHandle, or ExternalMultiAgentEnv!")

         self.filters: Dict[PolicyID, Filter] = {
             policy_id: get_filter(self.observation_filter,

@@ -678,7 +662,6 @@ class RolloutWorker(ParallelIteratorWorker):
                 callbacks=self.callbacks,
                 horizon=episode_horizon,
                 multiple_episodes_in_batch=pack,
-                tf_sess=self.tf_sess,
                 normalize_actions=normalize_actions,
                 clip_actions=clip_actions,
                 blackhole_outputs="simulation" in input_evaluation,

@@ -701,7 +684,6 @@ class RolloutWorker(ParallelIteratorWorker):
                 callbacks=self.callbacks,
                 horizon=episode_horizon,
                 multiple_episodes_in_batch=pack,
-                tf_sess=self.tf_sess,
                 normalize_actions=normalize_actions,
                 clip_actions=clip_actions,
                 soft_horizon=soft_horizon,
@@ -923,23 +905,24 @@ class RolloutWorker(ParallelIteratorWorker):
                 summarize(samples)))
         if isinstance(samples, MultiAgentBatch):
             info_out = {}
+            builders = {}
             to_fetch = {}
-            if self.tf_sess is not None:
-                builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
-            else:
-                builder = None
             for pid, batch in samples.policy_batches.items():
                 if pid not in self.policies_to_train:
                     continue
                 # Decompress SampleBatch, in case some columns are compressed.
                 batch.decompress_if_needed()
                 policy = self.policy_map[pid]
-                if builder and hasattr(policy, "_build_learn_on_batch"):
+                tf_session = policy.get_session()
+                if tf_session and hasattr(policy, "_build_learn_on_batch"):
+                    builders[pid] = TFRunBuilder(tf_session, "learn_on_batch")
                     to_fetch[pid] = policy._build_learn_on_batch(
-                        builder, batch)
+                        builders[pid], batch)
                 else:
                     info_out[pid] = policy.learn_on_batch(batch)
-            info_out.update({k: builder.get(v) for k, v in to_fetch.items()})
+            info_out.update(
+                {pid: builders[pid].get(v)
+                 for pid, v in to_fetch.items()})
         else:
             info_out = {
                 DEFAULT_POLICY_ID: self.policy_map[DEFAULT_POLICY_ID]

@@ -1024,12 +1007,15 @@ class RolloutWorker(ParallelIteratorWorker):
         return ret

     @DeveloperAPI
-    def get_policy(
-            self, policy_id: Optional[PolicyID] = DEFAULT_POLICY_ID) -> Policy:
+    def get_policy(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> Policy:
         """Return policy for the specified id, or None.

         Args:
-            policy_id (str): id of policy to return.
+            policy_id (PolicyID): ID of the policy to return.
+
+        Returns:
+            Optional[Policy]: The policy under the given ID (or None if not
+                found).
         """

         return self.policy_map.get(policy_id)
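A compact, editor-added sketch (not taken from the commit) of the pattern the learn_on_batch hunk above adopts: since each static-graph TF policy now owns its session, one TFRunBuilder is created per policy instead of one per worker. The function name and its two dict arguments are hypothetical.

from ray.rllib.utils.tf_run_builder import TFRunBuilder


def learn_on_batches(policies, batches):
    # `policies` maps policy IDs to Policy objects; `batches` maps the same
    # IDs to SampleBatches (assumption for this sketch).
    info_out, builders, to_fetch = {}, {}, {}
    for pid, batch in batches.items():
        policy = policies[pid]
        sess = policy.get_session()  # None for torch/eager policies.
        if sess and hasattr(policy, "_build_learn_on_batch"):
            builders[pid] = TFRunBuilder(sess, "learn_on_batch")
            to_fetch[pid] = policy._build_learn_on_batch(builders[pid], batch)
        else:
            info_out[pid] = policy.learn_on_batch(batch)
    # Resolve all pending TF fetches at the end, one builder per policy.
    info_out.update(
        {pid: builders[pid].get(v) for pid, v in to_fetch.items()})
    return info_out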
@@ -1078,18 +1064,12 @@ class RolloutWorker(ParallelIteratorWorker):
         policy_dict = {
             policy_id: (policy_cls, observation_space, action_space, config)
         }
-        if self.tf_sess is not None:
-            with self.tf_sess.graph.as_default():
-                with self.tf_sess.as_default():
-                    add_map, add_prep = self._build_policy_map(
-                        policy_dict, self.policy_config)
-        else:
-            add_map, add_prep = self._build_policy_map(policy_dict,
-                                                        self.policy_config)
-        new_policy = add_map[policy_id]
+        self._build_policy_map(
+            policy_dict,
+            self.policy_config,
+            seed=self.policy_config.get("seed"))
+        new_policy = self.policy_map[policy_id]

-        self.policy_map.update(add_map)
-        self.preprocessors.update(add_prep)
         self.filters[policy_id] = get_filter(
             self.observation_filter, new_policy.observation_space.shape)
@@ -1301,12 +1281,27 @@ class RolloutWorker(ParallelIteratorWorker):
             return func(self, *args)

     def _build_policy_map(
-            self, policy_dict: MultiAgentPolicyConfigDict,
-            policy_config: TrainerConfigDict
+            self,
+            policy_dict: MultiAgentPolicyConfigDict,
+            policy_config: TrainerConfigDict,
+            session_creator: Optional[Callable[[], "tf1.Session"]] = None,
+            seed: Optional[int] = None,
     ) -> Tuple[Dict[PolicyID, Policy], Dict[PolicyID, Preprocessor]]:
-        policy_map = {}
-        preprocessors = {}
-        for name, (cls, obs_space, act_space,
+        ma_config = policy_config.get("multiagent", {})
+
+        self.policy_map = self.policy_map or PolicyMap(
+            worker_index=self.worker_index,
+            num_workers=self.num_workers,
+            capacity=ma_config.get("policy_map_capacity"),
+            path=ma_config.get("policy_map_cache"),
+            policy_config=policy_config,
+            session_creator=session_creator,
+            seed=seed,
+        )
+        self.preprocessors = self.preprocessors or {}
+
+        for name, (orig_cls, obs_space, act_space,
                    conf) in sorted(policy_dict.items()):
             logger.debug("Creating policy for {}".format(name))
             merged_conf = merge_dicts(policy_config, conf or {})

@@ -1315,43 +1310,23 @@ class RolloutWorker(ParallelIteratorWorker):
             if self.preprocessing_enabled:
                 preprocessor = ModelCatalog.get_preprocessor_for_space(
                     obs_space, merged_conf.get("model"))
-                preprocessors[name] = preprocessor
+                self.preprocessors[name] = preprocessor
                 obs_space = preprocessor.observation_space
             else:
-                preprocessors[name] = NoPreprocessor(obs_space)
+                self.preprocessors[name] = NoPreprocessor(obs_space)

             if isinstance(obs_space, (gym.spaces.Dict, gym.spaces.Tuple)):
                 raise ValueError(
                     "Found raw Tuple|Dict space as input to policy. "
                     "Please preprocess these observations with a "
                     "Tuple|DictFlatteningPreprocessor.")
-            # Tf.
-            framework = policy_config.get("framework", "tf")
-            if framework in ["tf2", "tf", "tfe"]:
-                assert tf1
-                if framework in ["tf2", "tfe"]:
-                    assert tf1.executing_eagerly()
-                    if hasattr(cls, "as_eager"):
-                        cls = cls.as_eager()
-                        if policy_config.get("eager_tracing"):
-                            cls = cls.with_tracing()
-                    elif not issubclass(cls, TFPolicy):
-                        pass  # could be some other type of policy
-                    else:
-                        raise ValueError("This policy does not support eager "
-                                         "execution: {}".format(cls))
-                scope = name + (("_wk" + str(self.worker_index))
-                                if self.worker_index else "")
-                with tf1.variable_scope(scope):
-                    policy_map[name] = cls(obs_space, act_space, merged_conf)
-            # non-tf.
-            else:
-                policy_map[name] = cls(obs_space, act_space, merged_conf)
+            self.policy_map.create_policy(name, orig_cls, obs_space,
+                                          act_space, conf, merged_conf)

         if self.worker_index == 0:
-            logger.info("Built policy map: {}".format(policy_map))
-            logger.info("Built preprocessor map: {}".format(preprocessors))
-        return policy_map, preprocessors
+            logger.info(f"Built policy map: {self.policy_map}")
+            logger.info(f"Built preprocessor map: {self.preprocessors}")

     def setup_torch_data_parallel(self, url: str, world_rank: int,
                                   world_size: int, backend: str) -> None:
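To make the new LRU behavior concrete, here is an editor-added usage sketch of PolicyMap itself (assuming RLlib at this commit; the spaces, the PG policy class, and the cache path are illustrative assumptions), mirroring the _build_policy_map changes above:

import gym

from ray.rllib.agents.pg import DEFAULT_CONFIG
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.utils import merge_dicts

obs_space = gym.spaces.Box(-1.0, 1.0, (4, ))  # hypothetical spaces
act_space = gym.spaces.Discrete(2)

# Keep at most 2 policies in memory; least-recently used ones are pickled
# to `path` and transparently reloaded when accessed again.
policy_map = PolicyMap(
    worker_index=0,
    num_workers=0,
    capacity=2,
    path="/tmp/policy_map_cache",  # hypothetical cache dir
    policy_config=DEFAULT_CONFIG,
)

for pid in ["p0", "p1", "p2"]:
    conf = {}  # per-policy config override (empty here)
    policy_map.create_policy(pid, PGTFPolicy, obs_space, act_space, conf,
                             merge_dicts(DEFAULT_CONFIG, conf))

# All three IDs remain valid keys even though only `capacity` policies stay
# cached in memory (compare the TestTrainer assertions earlier in this diff).
print(len(policy_map), policy_map["p0"])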
@@ -205,8 +205,7 @@ class MultiAgentSampleBatchBuilder:
                 post_batches[agent_id] = pre_batch
             if getattr(policy, "exploration", None) is not None:
                 policy.exploration.postprocess_trajectory(
-                    policy, post_batches[agent_id],
-                    getattr(policy, "_sess", None))
+                    policy, post_batches[agent_id], policy.get_session())
             post_batches[agent_id] = policy.postprocess_trajectory(
                 post_batches[agent_id], other_batches, episode)
@@ -30,7 +30,6 @@ from ray.rllib.utils.filter import Filter
 from ray.rllib.utils.numpy import convert_to_numpy
 from ray.rllib.utils.spaces.space_utils import clip_action, \
     unsquash_action, unbatch
-from ray.rllib.utils.tf_run_builder import TFRunBuilder
 from ray.rllib.utils.typing import SampleBatchType, AgentID, PolicyID, \
     EnvObsType, EnvInfoDict, EnvID, MultiEnvDict, EnvActionType, \
     TensorStructType

@@ -137,7 +136,6 @@ class SyncSampler(SamplerInput):
                  callbacks: "DefaultCallbacks",
                  horizon: int = None,
                  multiple_episodes_in_batch: bool = False,
-                 tf_sess=None,
                  normalize_actions: bool = True,
                  clip_actions: bool = False,
                  soft_horizon: bool = False,

@@ -150,6 +148,7 @@ class SyncSampler(SamplerInput):
                  policy_mapping_fn=None,
                  preprocessors=None,
                  obs_filters=None,
+                 tf_sess=None,
                  ):
         """Initializes a SyncSampler object.

@@ -168,8 +167,6 @@ class SyncSampler(SamplerInput):
             multiple_episodes_in_batch (bool): Whether to pack multiple
                 episodes into each batch. This guarantees batches will be
                 exactly `rollout_fragment_length` in size.
-            tf_sess (Optional[tf.Session]): A tf.Session object to use (only if
-                framework=tf).
             normalize_actions (bool): Whether to normalize actions to the
                 action space's bounds.
             clip_actions (bool): Whether to clip actions according to the

@@ -199,6 +196,8 @@ class SyncSampler(SamplerInput):
             deprecation_warning(old="preprocessors")
         if obs_filters is not None:
             deprecation_warning(old="obs_filters")
+        if tf_sess is not None:
+            deprecation_warning(old="tf_sess")

         self.base_env = BaseEnv.to_base_env(env)
         self.rollout_fragment_length = rollout_fragment_length

@@ -221,7 +220,7 @@ class SyncSampler(SamplerInput):
             worker, self.base_env, self.extra_batches.put,
             self.rollout_fragment_length, self.horizon, clip_rewards,
             normalize_actions, clip_actions, multiple_episodes_in_batch,
-            callbacks, tf_sess, self.perf_stats, soft_horizon, no_done_at_end,
+            callbacks, self.perf_stats, soft_horizon, no_done_at_end,
             observation_fn, self.sample_collector, self.render)
         self.metrics_queue = queue.Queue()
@@ -275,7 +274,6 @@ class AsyncSampler(threading.Thread, SamplerInput):
                  callbacks: "DefaultCallbacks",
                  horizon: int = None,
                  multiple_episodes_in_batch: bool = False,
-                 tf_sess=None,
                  normalize_actions: bool = True,
                  clip_actions: bool = False,
                  blackhole_outputs: bool = False,

@@ -289,6 +287,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
                  policy_mapping_fn=None,
                  preprocessors=None,
                  obs_filters=None,
+                 tf_sess=None,
                  ):
         """Initializes a AsyncSampler object.

@@ -309,8 +308,6 @@ class AsyncSampler(threading.Thread, SamplerInput):
             multiple_episodes_in_batch (bool): Whether to pack multiple
                 episodes into each batch. This guarantees batches will be
                 exactly `rollout_fragment_length` in size.
-            tf_sess (Optional[tf.Session]): A tf.Session object to use (only if
-                framework=tf).
             normalize_actions (bool): Whether to normalize actions to the
                 action space's bounds.
             clip_actions (bool): Whether to clip actions according to the

@@ -342,6 +339,8 @@ class AsyncSampler(threading.Thread, SamplerInput):
             deprecation_warning(old="preprocessors")
         if obs_filters is not None:
             deprecation_warning(old="obs_filters")
+        if tf_sess is not None:
+            deprecation_warning(old="tf_sess")

         self.worker = worker

@@ -359,7 +358,6 @@ class AsyncSampler(threading.Thread, SamplerInput):
         self.clip_rewards = clip_rewards
         self.daemon = True
         self.multiple_episodes_in_batch = multiple_episodes_in_batch
-        self.tf_sess = tf_sess
         self.callbacks = callbacks
         self.normalize_actions = normalize_actions
         self.clip_actions = clip_actions

@@ -400,9 +398,9 @@ class AsyncSampler(threading.Thread, SamplerInput):
             self.worker, self.base_env, extra_batches_putter,
             self.rollout_fragment_length, self.horizon, self.clip_rewards,
             self.normalize_actions, self.clip_actions,
-            self.multiple_episodes_in_batch, self.callbacks, self.tf_sess,
-            self.perf_stats, self.soft_horizon, self.no_done_at_end,
-            self.observation_fn, self.sample_collector, self.render)
+            self.multiple_episodes_in_batch, self.callbacks, self.perf_stats,
+            self.soft_horizon, self.no_done_at_end, self.observation_fn,
+            self.sample_collector, self.render)
         while not self.shutdown:
             # The timeout variable exists because apparently, if one worker
             # dies, the other workers won't die with it, unless the timeout is
@@ -458,7 +456,6 @@ def _env_runner(
         clip_actions: bool,
         multiple_episodes_in_batch: bool,
         callbacks: "DefaultCallbacks",
-        tf_sess: Optional["tf.Session"],
         perf_stats: _PerfStats,
         soft_horizon: bool,
         no_done_at_end: bool,

@@ -484,8 +481,6 @@ def _env_runner(
             space's bounds.
         clip_actions (bool): Whether to clip actions to the space range.
         callbacks (DefaultCallbacks): User callbacks to run on episode events.
-        tf_sess (Session|None): Optional tensorflow session to use for batching
-            TF policy evaluations.
         perf_stats (_PerfStats): Record perf stats into this object.
         soft_horizon (bool): Calculate rewards but don't reset the
             environment when the horizon is hit.

@@ -566,7 +561,7 @@ def _env_runner(
                 policy=p,
                 environment=base_env,
                 episode=episode,
-                tf_sess=getattr(p, "_sess", None))
+                tf_sess=p.get_session())
         callbacks.on_episode_start(
             worker=worker,
             base_env=base_env,

@@ -627,7 +622,6 @@ def _env_runner(
             policy_mapping_fn=worker.policy_mapping_fn,
             sample_collector=sample_collector,
             active_episodes=active_episodes,
-            tf_sess=tf_sess,
         )
         perf_stats.inference_time += time.time() - t2

@@ -915,7 +909,7 @@ def _process_observations(
                 policy=p,
                 environment=base_env,
                 episode=episode,
-                tf_sess=getattr(p, "_sess", None))
+                tf_sess=p.get_session())
         # Call custom on_episode_end callback.
         callbacks.on_episode_end(
             worker=worker,

@@ -986,7 +980,6 @@ def _do_policy_eval(
         policy_mapping_fn: Callable[[AgentID, "MultiAgentEpisode"], PolicyID],
         sample_collector,
         active_episodes: Dict[str, MultiAgentEpisode],
-        tf_sess: Optional["tf.Session"] = None,
 ) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
     """Call compute_actions on collected episode/model data to get next action.

@@ -997,8 +990,6 @@ def _do_policy_eval(
         policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy
             obj.
         sample_collector (SampleCollector): The SampleCollector object to use.
-        tf_sess (Optional[tf.Session]): Optional tensorflow session to use for
-            batching TF policy evaluations.

     Returns:
         eval_results: dict of policy to compute_action() outputs.

@@ -1006,12 +997,6 @@ def _do_policy_eval(

     eval_results: Dict[PolicyID, TensorStructType] = {}

-    if tf_sess:
-        builder = TFRunBuilder(tf_sess, "policy_eval")
-        pending_fetches: Dict[PolicyID, Any] = {}
-    else:
-        builder = None
-
     if log_once("compute_actions_input"):
         logger.info("Inputs to compute_actions():\n\n{}\n".format(
             summarize(to_eval)))

@@ -1033,11 +1018,6 @@ def _do_policy_eval(
             timestep=policy.global_timestep,
             episodes=[active_episodes[t.env_id] for t in eval_data])

-    if builder:
-        # types: PolicyID, Tuple[TensorStructType, StateBatch, dict]
-        for pid, v in pending_fetches.items():
-            eval_results[pid] = builder.get(v)
-
     if log_once("compute_actions_result"):
         logger.info("Outputs of compute_actions():\n\n{}\n".format(
             summarize(eval_results)))
@@ -182,7 +182,7 @@ class TestRolloutWorker(unittest.TestCase):
                 0.1 - ((0.1 - 0.000001) / 100000) * global_timesteps
             lr = policy.cur_lr
             if fw == "tf":
-                lr = policy._sess.run(lr)
+                lr = policy.get_session().run(lr)
             check(lr, expected_lr, rtol=0.05)
         agent.stop()
@@ -78,13 +78,14 @@ class TFMultiGPULearner(LearnerThread):
         if set(self.local_worker.policy_map.keys()) != {DEFAULT_POLICY_ID}:
             raise NotImplementedError("Multi-gpu mode for multi-agent")
         self.policy = self.local_worker.policy_map[DEFAULT_POLICY_ID]
+        tf_session = self.policy.get_session()

         # per-GPU graph copies created below must share vars with the policy
         # reuse is set to AUTO_REUSE because Adam nodes are created after
         # all of the device copies are created.
         self.par_opt = []
-        with self.local_worker.tf_sess.graph.as_default():
-            with self.local_worker.tf_sess.as_default():
+        with tf_session.graph.as_default():
+            with tf_session.as_default():
                 with tf1.variable_scope(
                         DEFAULT_POLICY_ID, reuse=tf1.AUTO_REUSE):
                     if self.policy._state_inputs:

@@ -106,7 +107,7 @@ class TFMultiGPULearner(LearnerThread):
                         999999,  # it will get rounded down
                         self.policy.copy))

-        self.sess = self.local_worker.tf_sess
+        self.sess = tf_session
         self.sess.run(tf1.global_variables_initializer())

         self.idle_optimizers = queue.Queue()
@@ -148,14 +148,9 @@ class TrainTFMultiGPU:
         # reuse is set to AUTO_REUSE because Adam nodes are created after
         # all of the device copies are created.
         self.optimizers = {}
-        with self.workers.local_worker().tf_sess.graph.as_default():
-            with self.workers.local_worker().tf_sess.as_default():
-                for policy_id in (self.policies
-                                  or self.local_worker.policies_to_train):
-                    self.add_optimizer(policy_id)
-
-        self.sess = self.workers.local_worker().tf_sess
-        self.sess.run(tf1.global_variables_initializer())
+        for policy_id in (self.policies
+                          or self.local_worker.policies_to_train):
+            self.add_optimizer(policy_id)

     def __call__(self,
                  samples: SampleBatchType) -> (SampleBatchType, List[dict]):

@@ -181,10 +176,7 @@ class TrainTFMultiGPU:
             # Policy seems to be new and doesn't have an optimizer yet.
             # Add it here and continue.
             elif policy_id not in self.optimizers:
-                with self.workers.local_worker().tf_sess.graph.as_default(
-                ):
-                    with self.workers.local_worker().tf_sess.as_default():
-                        self.add_optimizer(policy_id)
+                self.add_optimizer(policy_id)

             # Decompress SampleBatch, in case some columns are compressed.
             batch.decompress_if_needed()

@@ -200,13 +192,14 @@ class TrainTFMultiGPU:
                 state_keys = []
             num_loaded_tuples[policy_id] = (
                 self.optimizers[policy_id].load_data(
-                    self.sess, [tuples[k] for k in data_keys],
+                    policy.get_session(), [tuples[k] for k in data_keys],
                     [tuples[k] for k in state_keys]))

         # Execute minibatch SGD on loaded data.
         with learn_timer:
             fetches = {}
             for policy_id, tuples_per_device in num_loaded_tuples.items():
+                policy = self.workers.local_worker().get_policy(policy_id)
                 optimizer = self.optimizers[policy_id]
                 num_batches = max(
                     1,

@@ -217,7 +210,7 @@ class TrainTFMultiGPU:
                 batch_fetches_all_towers = []
                 for batch_index in range(num_batches):
                     batch_fetches = optimizer.optimize(
-                        self.sess, permutation[batch_index] *
+                        policy.get_session(), permutation[batch_index] *
                         self.per_device_batch_size)

                     batch_fetches_all_towers.append(

@@ -250,15 +243,20 @@ class TrainTFMultiGPU:

     def add_optimizer(self, policy_id):
         policy = self.workers.local_worker().get_policy(policy_id)
-        with tf1.variable_scope(policy_id, reuse=tf1.AUTO_REUSE):
-            if policy._state_inputs:
-                rnn_inputs = policy._state_inputs + [policy._seq_lens]
-            else:
-                rnn_inputs = []
-            self.optimizers[policy_id] = (LocalSyncParallelOptimizer(
-                policy._optimizer, self.devices,
-                list(policy._loss_input_dict_no_rnn.values()), rnn_inputs,
-                self.per_device_batch_size, policy.copy))
+        tf_session = policy.get_session()
+        with tf_session.graph.as_default():
+            with tf_session.as_default():
+                with tf1.variable_scope(policy_id, reuse=tf1.AUTO_REUSE):
+                    if policy._state_inputs:
+                        rnn_inputs = policy._state_inputs + [policy._seq_lens]
+                    else:
+                        rnn_inputs = []
+                    self.optimizers[policy_id] = (LocalSyncParallelOptimizer(
+                        policy._optimizer, self.devices,
+                        list(policy._loss_input_dict_no_rnn.values()),
+                        rnn_inputs, self.per_device_batch_size, policy.copy))
+
+            tf_session.run(tf1.global_variables_initializer())


 def all_tower_reduce(path, *tower_data):
@@ -166,8 +166,7 @@ class JsonReader(InputReader):
                     self.ioctx.worker.policy_map[pid].action_space_struct)
         # Re-normalize actions (from env's bounds to 0.0 centered), if
         # necessary.
-        if "actions_in_input_normalized" in cfg and \
-                cfg["actions_in_input_normalized"] is False:
+        if cfg.get("actions_in_input_normalized") is False:
             if isinstance(batch, SampleBatch):
                 batch[SampleBatch.ACTIONS] = normalize_action(
                     batch[SampleBatch.ACTIONS], self.ioctx.worker.policy_map[
@@ -467,7 +467,7 @@ class DynamicTFPolicy(TFPolicy):
             self._optimizer = self.optimizer()

             # Test calls depend on variable init, so initialize model first.
-            self._sess.run(tf1.global_variables_initializer())
+            self.get_session().run(tf1.global_variables_initializer())

         logger.info("Testing `compute_actions` w/ dummy batch.")
         actions, state_outs, extra_fetches = \

@@ -486,7 +486,8 @@ class DynamicTFPolicy(TFPolicy):
         dummy_batch = self._dummy_batch

         logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
-        self.exploration.postprocess_trajectory(self, dummy_batch, self._sess)
+        self.exploration.postprocess_trajectory(self, dummy_batch,
+                                                self.get_session())
         _ = self.postprocess_trajectory(dummy_batch)
         # Add new columns automatically to (loss) input_dict.
         for key in dummy_batch.added_keys:

@@ -592,7 +593,7 @@ class DynamicTFPolicy(TFPolicy):
         }

         # Initialize again after loss init.
-        self._sess.run(tf1.global_variables_initializer())
+        self.get_session().run(tf1.global_variables_initializer())

     def _do_loss_init(self, train_batch: SampleBatch):
         loss = self._loss_fn(self, self.model, self.dist_class, train_batch)
@@ -255,7 +255,6 @@ def build_eager_tf_policy(

             self._is_training = False
             self._loss_initialized = False
-            self._sess = None

             self._loss = loss_fn
             self.batch_divisibility_req = get_batch_divisibility_req(self) if \
@@ -461,6 +461,9 @@ class Policy(metaclass=ABCMeta):
     def get_weights(self) -> ModelWeights:
         """Returns model weights.

+        Note: The return value of this method will reside under the "weights"
+            key in the return value of Policy.get_state().
+
         Returns:
             ModelWeights: Serializable copy or view of model weights.
         """

@@ -468,7 +471,7 @@ class Policy(metaclass=ABCMeta):

     @DeveloperAPI
     def set_weights(self, weights: ModelWeights) -> None:
-        """Sets model weights.
+        """Sets this Policy's model's weights.

         Args:
             weights (ModelWeights): Serializable copy or view of model weights.

@@ -523,12 +526,16 @@ class Policy(metaclass=ABCMeta):
     def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
         """Returns all local state.

+        Note: Not to be confused with an RNN model's internal state.
+
         Returns:
             Union[Dict[str, TensorType], List[TensorType]]: Serialized local
                 state.
         """
         state = {
+            # All the policy's weights.
             "weights": self.get_weights(),
+            # The current global timestep.
             "global_timestep": self.global_timestep,
         }
         return state

@@ -581,6 +588,16 @@ class Policy(metaclass=ABCMeta):
         """
         raise NotImplementedError

+    @DeveloperAPI
+    def get_session(self) -> Optional["tf1.Session"]:
+        """Returns tf.Session object to use for computing actions or None.
+
+        Returns:
+            Optional[tf1.Session]: The tf Session to use for computing actions
+                and losses with this policy.
+        """
+        return None
+
     def _create_exploration(self) -> Exploration:
         """Creates the Policy's Exploration object.
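A small, editor-added usage sketch for the new Policy.get_session() hook (assumes RLlib at this commit with TF1 available; PPO and cur_lr are just convenient examples): callers that previously reached for policy._sess or worker.tf_sess can branch on the returned session instead.

import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
trainer = PPOTrainer(config={"framework": "tf", "num_workers": 0},
                     env="CartPole-v0")
policy = trainer.get_policy()

sess = policy.get_session()
if sess is not None:
    # Static-graph TF policy: tensors are evaluated in the policy's session.
    print(sess.run(policy.cur_lr))
else:
    # tf2/tfe/torch policies return None here; values are plain attributes.
    print(policy.cur_lr)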
rllib/policy/policy_map.py (new file, 292 lines)
@ -0,0 +1,292 @@
from collections import deque
import gym
import os
import pickle
from typing import Callable, Dict, Optional, Type, TYPE_CHECKING

from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_ops import get_tf_eager_cls_if_necessary
from ray.rllib.utils.typing import PartialTrainerConfigDict, \
    PolicyID, TrainerConfigDict
from ray.tune.utils.util import merge_dicts

if TYPE_CHECKING:
    from ray.rllib.policy.policy import Policy

tf1, tf, tfv = try_import_tf()


class PolicyMap(dict):
    """Maps policy IDs to Policy objects.

    Thereby, keeps n policies in memory and - when capacity is reached -
    writes the least recently used to disk. This allows adding 100s of
    policies to a Trainer for league-based setups w/o running out of memory.
    """

    def __init__(
            self,
            worker_index: int,
            num_workers: int,
            capacity: Optional[int] = None,
            path: Optional[str] = None,
            policy_config: Optional[TrainerConfigDict] = None,
            session_creator: Optional[Callable[[], "tf1.Session"]] = None,
            seed: Optional[int] = None,
    ):
        """Initializes a PolicyMap instance.

        Args:
            capacity (int): The maximum number of policies to hold in memory.
                The least recently used ones are written to disk/S3 and
                retrieved when needed.
            path (str): The directory (or S3 location) under which
                least-recently used policies are stashed.
        """
        super().__init__()

        self.worker_index = worker_index
        self.num_workers = num_workers
        self.session_creator = session_creator
        self.seed = seed

        # The file extension for stashed policies (that are no longer available
        # in-memory but can be reinstated any time from storage).
        self.extension = ".policy.pkl"

        # Dictionary of keys that may be looked up (cached or not).
        self.valid_keys = set()
        # The actual cache with the in-memory policy objects.
        self.cache = {}
        # The doubly-linked list holding the currently in-memory objects.
        self.deque = deque(maxlen=capacity or 10)
        # The file path where to store overflowing policies.
        self.path = path or "."
        # The core config to use. Each single policy's config override is
        # added on top of this.
        self.policy_config = policy_config or {}
        # The orig classes/obs+act spaces, and config overrides of the
        # Policies.
        self.policy_specs = {}  # type: Dict[PolicyID, PolicySpec]
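A hedged construction sketch (the capacity, path, and seed values are made up for illustration; leaving `capacity=None` falls back to the hard-coded default of 10):

    from ray.rllib.policy.policy_map import PolicyMap

    # Keep at most two policies in memory; stash the rest under /tmp/policies.
    policy_map = PolicyMap(
        worker_index=0,
        num_workers=0,
        capacity=2,
        path="/tmp/policies",
        policy_config=None,  # falls back to {}
        seed=42,
    )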
    def create_policy(self, policy_id: PolicyID, policy_cls: Type["Policy"],
                      observation_space: gym.Space, action_space: gym.Space,
                      config_override: PartialTrainerConfigDict,
                      merged_config: TrainerConfigDict) -> None:
        """Creates a new policy and stores it to the cache.

        Args:
            policy_id (PolicyID): The policy ID. This is the key under which
                the created policy will be stored in this map.
            policy_cls (Type[Policy]): The (original) policy class to use.
                This may still be altered in case tf-eager (and tracing)
                is used.
            observation_space (gym.Space): The observation space of the
                policy.
            action_space (gym.Space): The action space of the policy.
            config_override (PartialTrainerConfigDict): The config override
                dict for this policy. This is the partial dict provided by
                the user.
            merged_config (TrainerConfigDict): The entire config (merged
                default config + `config_override`).
        """
        framework = merged_config.get("framework", "tf")
        class_ = get_tf_eager_cls_if_necessary(policy_cls, merged_config)

        # Tf.
        if framework in ["tf2", "tf", "tfe"]:
            var_scope = policy_id + (("_wk" + str(self.worker_index))
                                     if self.worker_index else "")

            # For tf static graph, build every policy in its own graph
            # and create a new session for it.
            if framework == "tf":
                with tf1.Graph().as_default():
                    if self.session_creator:
                        sess = self.session_creator()
                    else:
                        sess = tf1.Session(
                            config=tf1.ConfigProto(
                                gpu_options=tf1.GPUOptions(
                                    allow_growth=True)))
                    with sess.as_default():
                        # Set graph-level seed.
                        if self.seed is not None:
                            tf1.set_random_seed(self.seed)
                        with tf1.variable_scope(var_scope):
                            self[policy_id] = class_(
                                observation_space, action_space,
                                merged_config)
            # For tf-eager: no graph, no session.
            else:
                with tf1.variable_scope(var_scope):
                    self[policy_id] = \
                        class_(observation_space, action_space, merged_config)
        # Non-tf: No graph, no session.
        else:
            class_ = policy_cls
            self[policy_id] = class_(observation_space, action_space,
                                     merged_config)

        # Store spec (class, obs-space, act-space, and config overrides) such
        # that the map will be able to reproduce on-the-fly added policies
        # from disk.
        self.policy_specs[policy_id] = PolicySpec(
            policy_class=policy_cls,
            observation_space=observation_space,
            action_space=action_space,
            config=config_override)
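The `framework == "tf"` branch above gives every static-graph policy its own Graph and Session, so variables of different policies never collide. A standalone sketch of that isolation pattern, assuming a TF 1.x-compatible TensorFlow install (the names here are illustrative, not RLlib code):

    import tensorflow as tf

    tf1 = tf.compat.v1
    tf1.disable_eager_execution()

    sessions = {}
    for pid in ["p0", "p1"]:
        # One private graph + session per "policy".
        with tf1.Graph().as_default() as graph:
            sess = tf1.Session(
                graph=graph,
                config=tf1.ConfigProto(
                    gpu_options=tf1.GPUOptions(allow_growth=True)))
            with sess.as_default(), tf1.variable_scope(pid):
                tf1.get_variable("w", initializer=tf.zeros([2, 2]))
                sess.run(tf1.global_variables_initializer())
            sessions[pid] = sess

    # Each session's graph only contains its own policy's variables.
    for pid, sess in sessions.items():
        print(pid, [v.name for v in
                    sess.graph.get_collection(tf1.GraphKeys.GLOBAL_VARIABLES)])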
    @override(dict)
    def __getitem__(self, item):
        # Never seen this key -> Error.
        if item not in self.valid_keys:
            raise KeyError(f"'{item}' not a valid key!")

        # Item already in cache -> Rearrange deque (least recently used) and
        # return.
        if item in self.cache:
            self.deque.remove(item)
            self.deque.append(item)
        # Item not currently in cache -> Get from disk and - if at capacity -
        # remove leftmost one.
        else:
            self._read_from_disk(policy_id=item)

        return self.cache[item]

    @override(dict)
    def __setitem__(self, key, value):
        # Item already in cache -> Rearrange deque (least recently used).
        if key in self.cache:
            self.deque.remove(key)
            self.deque.append(key)
            self.cache[key] = value
        # Item not currently in cache -> Store new value and - if at
        # capacity - remove leftmost one.
        else:
            # Cache at capacity -> Drop leftmost item.
            if len(self.deque) == self.deque.maxlen:
                self._stash_to_disk()
            self.deque.append(key)
            self.cache[key] = value
        self.valid_keys.add(key)
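The two methods above implement a small LRU cache with a disk stash. The same pattern, boiled down to a self-contained sketch with plain pickled dicts standing in for Policy objects (all names here are illustrative, not part of RLlib):

    import os
    import pickle
    import tempfile
    from collections import deque


    class LRUStashMap:
        def __init__(self, capacity=2, path=None):
            self.capacity = capacity
            self.path = path or tempfile.mkdtemp()
            self.cache = {}          # key -> in-memory value
            self.order = deque()     # least recently used key sits on the left
            self.valid_keys = set()  # every key ever stored (cached or stashed)

        def _file(self, key):
            return os.path.join(self.path, key + ".pkl")

        def __setitem__(self, key, value):
            if key in self.cache:
                self.order.remove(key)
            elif len(self.order) == self.capacity:
                # Stash the least recently used entry to disk.
                old = self.order.popleft()
                with open(self._file(old), "wb") as f:
                    pickle.dump(self.cache.pop(old), f)
            self.order.append(key)
            self.cache[key] = value
            self.valid_keys.add(key)

        def __getitem__(self, key):
            if key not in self.valid_keys:
                raise KeyError(key)
            if key not in self.cache:
                # Reload from disk, evicting another entry if necessary.
                with open(self._file(key), "rb") as f:
                    self[key] = pickle.load(f)
            else:
                self.order.remove(key)
                self.order.append(key)
            return self.cache[key]


    m = LRUStashMap(capacity=2)
    m["p0"], m["p1"], m["p2"] = {"w": 0}, {"w": 1}, {"w": 2}  # "p0" is stashed
    print(m["p0"])  # transparently reloaded from disk -> {'w': 0}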
    @override(dict)
    def __delitem__(self, key):
        # Make key invalid.
        self.valid_keys.remove(key)
        # Remove policy from memory if currently cached.
        if key in self.cache:
            policy = self.cache[key]
            self._close_session(policy)
            del self.cache[key]
        # Remove file associated with the policy, if it exists.
        filename = self.path + "/" + key + self.extension
        if os.path.isfile(filename):
            os.remove(filename)

    @override(dict)
    def __iter__(self):
        return self.keys()

    @override(dict)
    def items(self):
        """Iterates over all policies, even the stashed-to-disk ones."""

        def gen():
            for key in self.valid_keys:
                yield (key, self[key])

        return gen()

    @override(dict)
    def keys(self):
        def gen():
            for key in self.valid_keys:
                yield key

        return gen()

    @override(dict)
    def values(self):
        def gen():
            for key in self.valid_keys:
                yield self[key]

        return gen()

    @override(dict)
    def update(self, __m, **kwargs):
        for k, v in __m.items():
            self[k] = v
        for k, v in kwargs.items():
            self[k] = v

    @override(dict)
    def get(self, key):
        if key not in self.valid_keys:
            return None
        return self[key]

    @override(dict)
    def __len__(self):
        """Returns number of all policies, including the stashed-to-disk ones.
        """
        return len(self.valid_keys)

    @override(dict)
    def __contains__(self, item):
        return item in self.valid_keys

    def _stash_to_disk(self):
        """Writes the least-recently used policy to disk and rearranges cache.

        Also closes the session - if applicable - of the stashed policy.
        """
        # Get least recently used policy (all the way on the left in deque).
        delkey = self.deque.popleft()
        policy = self.cache[delkey]
        # Get its state for writing to disk.
        policy_state = policy.get_state()
        # Closes policy's tf session, if any.
        self._close_session(policy)
        # Remove from memory.
        # TODO: (sven) This should clear the tf Graph as well, if the Trainer
        #  would not hold parts of the graph (e.g. in tf multi-GPU setups).
        del self.cache[delkey]
        # Write state to disk.
        with open(self.path + "/" + delkey + self.extension, "wb") as f:
            pickle.dump(policy_state, file=f)

    def _read_from_disk(self, policy_id):
        """Reads a policy ID from disk and re-adds it to the cache."""
        # Make sure this policy ID is not in the cache right now.
        assert policy_id not in self.cache
        # Read policy state from disk.
        with open(self.path + "/" + policy_id + self.extension, "rb") as f:
            policy_state = pickle.load(f)

        # Get class and config override.
        merged_conf = merge_dicts(self.policy_config,
                                  self.policy_specs[policy_id].config)

        # Create policy object (from its spec: cls, obs-space, act-space,
        # config).
        self.create_policy(
            policy_id,
            self.policy_specs[policy_id].policy_class,
            self.policy_specs[policy_id].observation_space,
            self.policy_specs[policy_id].action_space,
            self.policy_specs[policy_id].config,
            merged_conf,
        )
        # Restore policy's state.
        policy = self[policy_id]
        policy.set_state(policy_state)

    def _close_session(self, policy):
        sess = policy.get_session()
        # Closes the tf session, if any.
        if sess is not None:
            sess.close()
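Because iteration goes through `valid_keys`, dict-style loops also visit stashed policies and pull them back from disk on access (a hypothetical loop over an already populated `policy_map`):

    for pid, policy in policy_map.items():
        # A stashed policy is transparently re-created (and another one may be
        # stashed in its place) before `policy` is handed back here.
        print(pid, policy.get_session() is not None)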
@@ -263,7 +263,8 @@ class TFPolicy(Policy):
             "`get_placeholder()` can be called"
         return self._loss_input_dict[name]

-    def get_session(self) -> "tf1.Session":
+    @override(Policy)
+    def get_session(self) -> Optional["tf1.Session"]:
         """Returns a reference to the TF session for this policy."""
         return self._sess

@@ -305,7 +306,7 @@ class TFPolicy(Policy):

         if self.model:
             self._variables = ray.experimental.tf_utils.TensorFlowVariables(
-                [], self._sess, self.variables())
+                [], self.get_session(), self.variables())

         # gather update ops for any batch norm layers
         if not self._update_ops:

@@ -323,12 +324,12 @@ class TFPolicy(Policy):
                 "These tensors were used in the loss_fn:\n\n{}\n".format(
                     summarize(self._loss_input_dict)))

-        self._sess.run(tf1.global_variables_initializer())
+        self.get_session().run(tf1.global_variables_initializer())
         self._optimizer_variables = None
         if self._optimizer:
             self._optimizer_variables = \
                 ray.experimental.tf_utils.TensorFlowVariables(
-                    self._optimizer.variables(), self._sess)
+                    self._optimizer.variables(), self.get_session())

     @override(Policy)
     def compute_actions(

@@ -346,7 +347,7 @@ class TFPolicy(Policy):
         explore = explore if explore is not None else self.config["explore"]
         timestep = timestep if timestep is not None else self.global_timestep

-        builder = TFRunBuilder(self._sess, "compute_actions")
+        builder = TFRunBuilder(self.get_session(), "compute_actions")
         to_fetch = self._build_compute_actions(
             builder,
             obs_batch=obs_batch,

@@ -378,7 +379,8 @@ class TFPolicy(Policy):
         explore = explore if explore is not None else self.config["explore"]
         timestep = timestep if timestep is not None else self.global_timestep

-        builder = TFRunBuilder(self._sess, "compute_actions_from_input_dict")
+        builder = TFRunBuilder(self.get_session(),
+                               "compute_actions_from_input_dict")
         obs_batch = input_dict[SampleBatch.OBS]
         to_fetch = self._build_compute_actions(
             builder, input_dict=input_dict, explore=explore, timestep=timestep)

@@ -413,7 +415,7 @@ class TFPolicy(Policy):
         self.exploration.before_compute_actions(
             explore=False, tf_sess=self.get_session())

-        builder = TFRunBuilder(self._sess, "compute_log_likelihoods")
+        builder = TFRunBuilder(self.get_session(), "compute_log_likelihoods")

         # Normalize actions if necessary.
         if actions_normalized is False and self.config["normalize_actions"]:

@@ -451,7 +453,7 @@ class TFPolicy(Policy):
             self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
         assert self.loss_initialized()

-        builder = TFRunBuilder(self._sess, "learn_on_batch")
+        builder = TFRunBuilder(self.get_session(), "learn_on_batch")

         # Callback handling.
         learn_stats = {}

@@ -470,7 +472,7 @@ class TFPolicy(Policy):
                           postprocessed_batch: SampleBatch) -> \
             Tuple[ModelGradients, Dict[str, TensorType]]:
         assert self.loss_initialized()
-        builder = TFRunBuilder(self._sess, "compute_gradients")
+        builder = TFRunBuilder(self.get_session(), "compute_gradients")
         fetches = self._build_compute_gradients(builder, postprocessed_batch)
         return builder.get(fetches)

@@ -478,7 +480,7 @@ class TFPolicy(Policy):
     @DeveloperAPI
     def apply_gradients(self, gradients: ModelGradients) -> None:
         assert self.loss_initialized()
-        builder = TFRunBuilder(self._sess, "apply_gradients")
+        builder = TFRunBuilder(self.get_session(), "apply_gradients")
         fetches = self._build_apply_gradients(builder, gradients)
         builder.get(fetches)

@@ -510,7 +512,7 @@ class TFPolicy(Policy):
         if self._optimizer_variables and \
                 len(self._optimizer_variables.variables) > 0:
             state["_optimizer_variables"] = \
-                self._sess.run(self._optimizer_variables.variables)
+                self.get_session().run(self._optimizer_variables.variables)
         # Add exploration state.
         state["_exploration_state"] = \
             self.exploration.get_state(self.get_session())

@@ -546,7 +548,7 @@ class TFPolicy(Policy):
                     "`tf2onnx` to be installed. Install with "
                     "`pip install tf2onnx`.") from e

-            with self._sess.graph.as_default():
+            with self.get_session().graph.as_default():
                 signature_def_map = self._build_signature_def()

                 sd = signature_def_map[tf1.saved_model.signature_constants.

@@ -558,7 +560,7 @@ class TFPolicy(Policy):
                 frozen_graph_def = tf_loader.freeze_session(
                     self._sess, input_names=inputs, output_names=outputs)

-                with tf.compat.v1.Session(graph=tf.Graph()) as session:
+                with tf1.Session(graph=tf.Graph()) as session:
                     tf.import_graph_def(frozen_graph_def, name="")

                     g = tf2onnx.tfonnx.process_tf_graph(

@@ -574,14 +576,15 @@ class TFPolicy(Policy):
                         feed_dict={},
                         model_proto=model_proto)
         else:
-            with self._sess.graph.as_default():
+            with self.get_session().graph.as_default():
                 signature_def_map = self._build_signature_def()
                 builder = tf1.saved_model.builder.SavedModelBuilder(export_dir)
                 builder.add_meta_graph_and_variables(
-                    self._sess, [tf1.saved_model.tag_constants.SERVING],
+                    self.get_session(),
+                    [tf1.saved_model.tag_constants.SERVING],
                     signature_def_map=signature_def_map,
                     saver=tf1.summary.FileWriter(export_dir).add_graph(
-                        graph=self._sess.graph))
+                        graph=self.get_session().graph))
                 builder.save()

     # TODO: (sven) Deprecate this in favor of `save()`.

@@ -599,17 +602,17 @@ class TFPolicy(Policy):
             if e.errno != errno.EEXIST:
                 raise
         save_path = os.path.join(export_dir, filename_prefix)
-        with self._sess.graph.as_default():
+        with self.get_session().graph.as_default():
             saver = tf1.train.Saver()
-            saver.save(self._sess, save_path)
+            saver.save(self.get_session(), save_path)

     @override(Policy)
     @DeveloperAPI
     def import_model_from_h5(self, import_file: str) -> None:
         """Imports weights into tf model."""
         # Make sure the session is the right one (see issue #7046).
-        with self._sess.graph.as_default():
-            with self._sess.as_default():
+        with self.get_session().graph.as_default():
+            with self.get_session().as_default():
                 return self.model.import_from_h5(import_file)

     @DeveloperAPI

@@ -1026,7 +1029,7 @@ class LearningRateSchedule:
         if self._lr_schedule is not None:
             new_val = self._lr_schedule.value(global_vars["timestep"])
             if self.framework == "tf":
-                self._sess.run(
+                self.get_session().run(
                     self._lr_update, feed_dict={self._lr_placeholder: new_val})
             else:
                 self.cur_lr.assign(new_val, read_value=False)

@@ -1082,7 +1085,7 @@ class EntropyCoeffSchedule:
             new_val = self._entropy_coeff_schedule.value(
                 global_vars["timestep"])
             if self.framework == "tf":
-                self._sess.run(
+                self.get_session().run(
                     self._entropy_coeff_update,
                     feed_dict={self._entropy_coeff_placeholder: new_val})
             else:
@@ -45,7 +45,7 @@ class TestParameterNoise(unittest.TestCase):

         trainer = trainer_cls(config=config, env=env)
         policy = trainer.get_policy()
-        pol_sess = getattr(policy, "_sess", None)
+        pol_sess = policy.get_session()
         # Remove noise that has been added during policy initialization
         # (exploration.postprocess_trajectory does add noise to measure
         # the delta).

@@ -110,7 +110,7 @@ class TestParameterNoise(unittest.TestCase):
         config["explore"] = False
         trainer = trainer_cls(config=config, env=env)
         policy = trainer.get_policy()
-        pol_sess = getattr(policy, "_sess", None)
+        pol_sess = policy.get_session()
         # Remove noise that has been added during policy initialization
         # (exploration.postprocess_trajectory does add noise to measure
         # the delta).
@@ -60,6 +60,31 @@ def get_placeholder(*, space=None, value=None, name=None, time_axis=False):
     )


+def get_tf_eager_cls_if_necessary(orig_cls, config):
+    cls = orig_cls
+    framework = config.get("framework", "tf")
+    if framework in ["tf2", "tf", "tfe"]:
+        if not tf1:
+            raise ImportError("Could not import tensorflow!")
+        if framework in ["tf2", "tfe"]:
+            assert tf1.executing_eagerly()
+
+            from ray.rllib.policy.tf_policy import TFPolicy
+
+            # Create eager-class.
+            if hasattr(orig_cls, "as_eager"):
+                cls = orig_cls.as_eager()
+                if config.get("eager_tracing"):
+                    cls = cls.with_tracing()
+            # Could be some other type of policy.
+            elif not issubclass(orig_cls, TFPolicy):
+                pass
+            else:
+                raise ValueError("This policy does not support eager "
+                                 "execution: {}".format(orig_cls))
+    return cls
+
+
 def huber_loss(x, delta=1.0):
     """Reference: https://en.wikipedia.org/wiki/Huber_loss"""
     return tf.where(
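A small, hedged check of the helper's dispatch behavior (assumes TensorFlow 2.x running eagerly and RLlib at this commit; `PGTFPolicy` is used only because it exposes `as_eager()`, and `NotATFPolicy` is a made-up placeholder):

    from ray.rllib.agents.pg import PGTFPolicy
    from ray.rllib.utils.tf_ops import get_tf_eager_cls_if_necessary


    class NotATFPolicy:
        pass


    # Non-TFPolicy classes pass through unchanged.
    assert get_tf_eager_cls_if_necessary(
        NotATFPolicy, {"framework": "tf2"}) is NotATFPolicy

    # A TFPolicy subclass is swapped for its (optionally traced) eager variant.
    eager_cls = get_tf_eager_cls_if_necessary(
        PGTFPolicy, {"framework": "tf2", "eager_tracing": False})
    assert eager_cls is not PGTFPolicy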