ray/rllib/agents/dqn/simple_q_policy.py
Sven Mika 0db2046b0a
[RLlib] Policy.compute_log_likelihoods() and SAC refactor. (issue #7107) (#7124)
* Exploration API (+EpsilonGreedy sub-class).

* Exploration API (+EpsilonGreedy sub-class).

* Cleanup/LINT.

* Add `deterministic` to generic Trainer config (NOTE: this is still ignored by most Agents).

* Add `error` option to deprecation_warning().

* WIP.

* Bug fix: Get exploration-info for tf framework.
Bug fix: Properly deprecate some DQN config keys.

* WIP.

* LINT.

* WIP.

* Split PerWorkerEpsilonGreedy out of EpsilonGreedy.
Docstrings.

* Fix bug in sampler.py in case Policy has self.exploration = None

* Update rllib/agents/dqn/dqn.py

Co-Authored-By: Eric Liang <ekhliang@gmail.com>

* WIP.

* Update rllib/agents/trainer.py

Co-Authored-By: Eric Liang <ekhliang@gmail.com>

* WIP.

* Change requests.

* LINT

* In tune/utils/util.py::deep_update(): Only keep deep-updating if both original and value are dicts. If value is not a dict, set it directly.

* Completely obsolete sync_replay_optimizer.py's parameters schedule_max_timesteps AND beta_annealing_fraction (replaced with prioritized_replay_beta_annealing_timesteps).

* Update rllib/evaluation/worker_set.py

Co-Authored-By: Eric Liang <ekhliang@gmail.com>

* Review fixes.

* Fix default value for DQN's exploration spec.

* LINT

* Fix recursion bug (wrong parent c'tor).

* Do not pass timestep to get_exploration_info.

* Update tf_policy.py

* Fix some remaining issues with test cases and remove more deprecated DQN/APEX exploration configs.

* Bug fix tf-action-dist

* DDPG incompatibility bug fix with new DQN exploration handling (which is imported by DDPG).

* Switch off exploration when getting action probs from off-policy-estimator's policy.

* LINT

* Fix test_checkpoint_restore.py.

* Deprecate all SAC exploration (unused) configs.

* Properly use `model.last_output()` everywhere. Instead of `model._last_output`.

* WIP.

* Take out set_epsilon from multi-agent-env test (not needed, decays anyway).

* WIP.

* Trigger re-test (flaky checkpoint-restore test).

* WIP.

* WIP.

* Add test case for deterministic action sampling in PPO.

* bug fix.

* Added deterministic test cases for different Agents.

* Fix problem with TupleActions in dynamic-tf-policy.

* Separate supported_spaces tests so they can be run separately for easier debugging.

* LINT.

* Fix autoregressive_action_dist.py test case.

* Re-test.

* Fix.

* Remove duplicate py_test rule from bazel.

* LINT.

* WIP.

* WIP.

* SAC fix.

* SAC fix.

* WIP.

* WIP.

* WIP.

* FIX 2 examples tests.

* WIP.

* WIP.

* WIP.

* WIP.

* WIP.

* Fix.

* LINT.

* Renamed test file.

* WIP.

* Add unittest.main.

* Make action_dist_class mandatory.

* fix

* FIX.

* WIP.

* WIP.

* Fix.

* Fix.

* Fix explorations test case (contextlib cannot find its own nullcontext??).

* Force torch to be installed for QMIX.

* LINT.

* Fix determine_tests_to_run.py.

* Fix determine_tests_to_run.py.

* WIP

* Add Random exploration component to tests (fixed issue with "static-graph randomness" via py_function).

* Add Random exploration component to tests (fixed issue with "static-graph randomness" via py_function).

* Rename some stuff.

* Rename some stuff.

* WIP.

* WIP.

* Fix SAC.

* Fix SAC.

* Fix strange tf-error in ray core tests.

* Fix strange ray-core tf-error in test_memory_scheduling test case.

* Fix test_io.py.

* LINT.

* Update SAC yaml files' config.

Co-authored-by: Eric Liang <ekhliang@gmail.com>
2020-02-22 14:19:49 -08:00

187 lines
6.7 KiB
Python

"""Basic example of a DQN policy without any optimizations."""
from gym.spaces import Discrete
import logging
import ray
from ray.rllib.agents.dqn.simple_q_model import SimpleQModel
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.utils.annotations import override
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils import try_import_tf
from ray.rllib.utils.tf_ops import huber_loss, make_tf_callable
tf = try_import_tf()
logger = logging.getLogger(__name__)
Q_SCOPE = "q_func"
Q_TARGET_SCOPE = "target_q_func"
class ParameterNoiseMixin:
    """Exposes a method to run the policy's parameter-noise op.

    NOTE(review): `self.add_noise_op` and `self.sess` are attached to the
    policy elsewhere (not in this file) — this mixin only triggers the op.
    """

    def __init__(self, obs_space, action_space, config):
        # Nothing to initialize; the noise op is created by other code.
        pass

    def add_parameter_noise(self):
        """Run the parameter-noise op, if enabled in the config."""
        if not self.config["parameter_noise"]:
            return
        self.sess.run(self.add_noise_op)
class TargetNetworkMixin:
    """Provides `update_target()` to copy Q-net weights into the target net.

    The copy function is built once at construction time and wrapped via
    `make_tf_callable` so it can be invoked like a regular Python function.
    """

    def __init__(self, obs_space, action_space, config):
        @make_tf_callable(self.get_session())
        def sync_target_network():
            # The two variable lists must correspond one-to-one.
            assert len(self.q_func_vars) == len(self.target_q_func_vars), \
                (self.q_func_vars, self.target_q_func_vars)
            assign_ops = []
            for source_var, target_var in zip(self.q_func_vars,
                                              self.target_q_func_vars):
                assign_ops.append(target_var.assign(source_var))
                logger.debug("Update target op {}".format(target_var))
            return tf.group(*assign_ops)

        # Called periodically by the trainer to sync Q -> target-Q weights.
        self.update_target = sync_target_network

    @override(TFPolicy)
    def variables(self):
        """Return all Q-network plus target-Q-network variables."""
        return self.q_func_vars + self.target_q_func_vars
def build_q_models(policy, obs_space, action_space, config):
    """Build the Q- and target-Q-models and attach both to `policy`.

    Returns:
        The main Q-model (the target model is stored as
        `policy.target_q_model`).

    Raises:
        UnsupportedSpaceException: If `action_space` is not Discrete.
    """
    if not isinstance(action_space, Discrete):
        raise UnsupportedSpaceException(
            "Action space {} is not supported for DQN.".format(action_space))

    if config["hiddens"]:
        # With extra Q-hiddens: the base model outputs a 256-dim embedding
        # (no final linear layer); the Q-head sits on top of it.
        num_outputs = 256
        config["model"]["no_final_linear"] = True
    else:
        num_outputs = action_space.n

    def _make_q_model(scope):
        # Q- and target-networks are identical except for their name scope.
        return ModelCatalog.get_model_v2(
            obs_space,
            action_space,
            num_outputs,
            config["model"],
            framework="tf",
            name=scope,
            model_interface=SimpleQModel,
            q_hiddens=config["hiddens"])

    policy.q_model = _make_q_model(Q_SCOPE)
    policy.target_q_model = _make_q_model(Q_TARGET_SCOPE)
    return policy.q_model
def get_log_likelihood(policy, q_model, actions, input_dict, obs_space,
                       action_space, config):
    """Return log-likelihoods of `actions` under a Categorical over Q-values."""
    # Forward pass through the action Q-network.
    q_values = _compute_q_values(policy, q_model,
                                 input_dict[SampleBatch.CUR_OBS], obs_space,
                                 action_space)
    if isinstance(q_values, tuple):
        q_values = q_values[0]
    # Treat Q-values as Categorical logits and score the given actions.
    return Categorical(q_values, q_model).logp(actions)
def simple_sample_action_from_q_network(policy, q_model, input_dict, obs_space,
                                        action_space, explore, config,
                                        timestep):
    """Sample actions by feeding Q-values through the exploration component.

    Returns:
        Tuple of (sampled actions, log-prob of the sampled actions).
    """
    # Forward pass through the action Q-network.
    q_values = _compute_q_values(policy, q_model,
                                 input_dict[SampleBatch.CUR_OBS], obs_space,
                                 action_space)
    if isinstance(q_values, tuple):
        q_values = q_values[0]
    # Cache values/vars on the policy for extra fetches and loss building.
    policy.q_values = q_values
    policy.q_func_vars = q_model.variables()
    actions, logp = policy.exploration.get_exploration_action(
        policy.q_values, Categorical, q_model, explore, timestep)
    policy.output_actions = actions
    policy.sampled_action_logp = logp
    return actions, logp
def build_q_losses(policy, model, dist_class, train_batch):
    """Build the Huber TD-loss for the simple-Q policy.

    Also stores the per-sample TD-errors on the policy (`policy.td_error`)
    for outside access (e.g. learn fetches).
    """
    # Q-network evaluation on observations at t.
    q_t = _compute_q_values(policy, policy.q_model,
                            train_batch[SampleBatch.CUR_OBS],
                            policy.observation_space, policy.action_space)
    # Target-Q-network evaluation on observations at t+1.
    q_tp1 = _compute_q_values(policy, policy.target_q_model,
                              train_batch[SampleBatch.NEXT_OBS],
                              policy.observation_space, policy.action_space)
    policy.target_q_func_vars = policy.target_q_model.variables()

    num_actions = policy.action_space.n

    # Q-scores of the actions actually taken in the given states.
    taken_action_one_hot = tf.one_hot(
        tf.cast(train_batch[SampleBatch.ACTIONS], tf.int32), num_actions)
    q_t_selected = tf.reduce_sum(q_t * taken_action_one_hot, 1)

    # Best possible value at t+1; masked to 0.0 on terminal transitions.
    dones = tf.cast(train_batch[SampleBatch.DONES], tf.float32)
    best_action_one_hot = tf.one_hot(tf.argmax(q_tp1, 1), num_actions)
    q_tp1_best = tf.reduce_sum(q_tp1 * best_action_one_hot, 1)
    q_tp1_best_masked = (1.0 - dones) * q_tp1_best

    # RHS of the Bellman equation: r + gamma * max_a' Q_target(s', a').
    q_t_selected_target = (train_batch[SampleBatch.REWARDS] +
                           policy.config["gamma"] * q_tp1_best_masked)

    # TD-error (no gradient flows through the target) -> Huber loss.
    td_error = q_t_selected - tf.stop_gradient(q_t_selected_target)
    # Save TD-error as an attribute for outside access.
    policy.td_error = td_error
    return tf.reduce_mean(huber_loss(td_error))
def _compute_q_values(policy, model, obs, obs_space, action_space):
input_dict = {
"obs": obs,
"is_training": policy._get_is_training_placeholder(),
}
model_out, _ = model(input_dict, [], None)
return model.get_q_values(model_out)
def setup_early_mixins(policy, obs_space, action_space, config):
    """Initialize mixins that must be set up before the Policy's __init__."""
    ParameterNoiseMixin.__init__(policy, obs_space, action_space, config)
def setup_late_mixins(policy, obs_space, action_space, config):
    """Initialize mixins that need a fully built Policy.

    TargetNetworkMixin reads `policy.get_session()` and the Q-/target-Q
    variable lists, so it must run after the graph has been constructed.
    """
    TargetNetworkMixin.__init__(policy, obs_space, action_space, config)
# Basic DQN policy (no optimizations) assembled via the TF policy template.
SimpleQPolicy = build_tf_policy(
    name="SimpleQPolicy",
    get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
    make_model=build_q_models,
    # Custom sampling via the policy's Exploration component.
    action_sampler_fn=simple_sample_action_from_q_network,
    log_likelihood_fn=get_log_likelihood,
    loss_fn=build_q_losses,
    # Expose Q-values in compute_actions() fetches ...
    extra_action_fetches_fn=lambda policy: {"q_values": policy.q_values},
    # ... and TD-errors in learn_on_batch() fetches.
    extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error},
    before_init=setup_early_mixins,
    after_init=setup_late_mixins,
    # Simple-Q's model does not consume previous actions/rewards.
    obs_include_prev_action_reward=False,
    mixins=[ParameterNoiseMixin, TargetNetworkMixin])