[rllib] rename compute_apply to learn_on_batch

Author: Eric Liang (committed by GitHub)
Date:   2019-02-11 15:22:15 -08:00
Commit: 8df772867c
Parent: c4182463f6
12 changed files with 81 additions and 37 deletions


@@ -121,6 +121,25 @@ class PPOAgent(Agent):
         res.update(
             timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps,
             info=dict(fetches, **res.get("info", {})))
+
+        # Warn about bad clipping configs
+        if self.config["vf_clip_param"] <= 0:
+            rew_scale = float("inf")
+        elif res["policy_reward_mean"]:
+            rew_scale = 0  # punt on handling multiagent case
+        else:
+            rew_scale = round(
+                abs(res["episode_reward_mean"]) / self.config["vf_clip_param"],
+                0)
+        if rew_scale > 100:
+            logger.warning(
+                "The magnitude of your environment rewards are more than "
+                "{}x the scale of `vf_clip_param`. ".format(rew_scale) +
+                "This means that it will take more than "
+                "{} iterations for your value ".format(rew_scale) +
+                "function to converge. If this is not intended, consider "
+                "increasing `vf_clip_param`.")
+
         return res

     def _validate_config(self):
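
As a rough illustration of the warning added above (a standalone sketch, not the code the commit ships; the reward value and the vf_clip_param of 10.0 are made-up inputs for the example):

    # Sketch: how the reward-scale check behaves for a single-agent run.
    def reward_scale(episode_reward_mean, vf_clip_param):
        # A non-positive clip value is treated as an unbounded mismatch.
        if vf_clip_param <= 0:
            return float("inf")
        return round(abs(episode_reward_mean) / vf_clip_param, 0)

    # A mean episode reward around 5000 against vf_clip_param=10.0 gives a
    # scale of 500, well past the 100x threshold, so the warning fires and
    # suggests raising `vf_clip_param`.
    print(reward_scale(5000.0, 10.0))  # 500.0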


@@ -234,7 +234,7 @@ class QMixPolicyGraph(PolicyGraph):
         return TupleActions(list(actions.transpose([1, 0]))), hiddens, {}

     @override(PolicyGraph)
-    def compute_apply(self, samples):
+    def learn_on_batch(self, samples):
         obs_batch, action_mask = self._unpack_observation(samples["obs"])
         group_rewards = self._get_group_rewards(samples["infos"])


@@ -31,11 +31,31 @@ class EvaluatorInterface(object):
         raise NotImplementedError

+    @DeveloperAPI
+    def learn_on_batch(self, samples):
+        """Update policies based on the given batch.
+
+        This is the equivalent to apply_gradients(compute_gradients(samples)),
+        but can be optimized to avoid pulling gradients into CPU memory.
+        Either this or the combination of compute/apply grads must be
+        implemented by subclasses.
+
+        Returns:
+            info: dictionary of extra metadata from compute_gradients().
+
+        Examples:
+            >>> batch = ev.sample()
+            >>> ev.learn_on_batch(samples)
+        """
+        return self.compute_apply(samples)
+
     @DeveloperAPI
     def compute_gradients(self, samples):
         """Returns a gradient computed w.r.t the specified samples.

-        This method must be implemented by subclasses.
+        Either this or learn_on_batch() must be implemented by subclasses.

         Returns:
             (grads, info): A list of gradients that can be applied on a
@@ -54,7 +74,7 @@ class EvaluatorInterface(object):
     def apply_gradients(self, grads):
         """Applies the given gradients to this evaluator's weights.

-        This method must be implemented by subclasses.
+        Either this or learn_on_batch() must be implemented by subclasses.

         Examples:
             >>> samples = ev1.sample()
@@ -95,15 +115,7 @@ class EvaluatorInterface(object):
     @DeveloperAPI
     def compute_apply(self, samples):
-        """Fused compute gradients and apply gradients call.
-
-        Returns:
-            info: dictionary of extra metadata from compute_gradients().
-
-        Examples:
-            >>> batch = ev.sample()
-            >>> ev.compute_apply(samples)
-        """
+        """Deprecated: override learn_on_batch instead."""

         grads, info = self.compute_gradients(samples)
         self.apply_gradients(grads)
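
The docstrings above spell out the contract: learn_on_batch() is behaviorally equivalent to applying the gradients computed from the same batch, and a subclass has to supply either the fused call or the compute/apply pair. The toy class below is a hedged sketch of that contract on a one-parameter "policy"; it is invented for illustration and does not subclass the real EvaluatorInterface:

    class ToyEvaluator(object):
        """Minimal evaluator-shaped object following the interface contract."""

        def __init__(self):
            self.weight = 0.0  # a single scalar "policy parameter"

        def compute_gradients(self, samples):
            # Gradient of 0.5 * (weight - target)^2, averaged over the batch.
            targets = samples["targets"]
            grad = sum(self.weight - t for t in targets) / len(targets)
            return [grad], {"mean_target": sum(targets) / len(targets)}

        def apply_gradients(self, grads):
            self.weight -= 0.1 * grads[0]  # plain SGD step

        def learn_on_batch(self, samples):
            # Same fallback the interface provides: fuse compute and apply.
            grads, info = self.compute_gradients(samples)
            self.apply_gradients(grads)
            return info

    ev = ToyEvaluator()
    print(ev.learn_on_batch({"targets": [1.0, 2.0, 3.0]}))  # {'mean_target': 2.0}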


@@ -43,7 +43,7 @@ class KerasPolicyGraph(PolicyGraph):
         value = self.critic.predict(state)
         return _sample(policy), [], {"vf_preds": value.flatten()}

-    def compute_apply(self, batch, *args):
+    def learn_on_batch(self, batch, *args):
         self.actor.fit(
             batch["obs"],
             batch["adv_targets"],


@@ -470,16 +470,16 @@ class PolicyEvaluator(EvaluatorInterface):
         return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)

     @override(EvaluatorInterface)
-    def compute_apply(self, samples):
+    def learn_on_batch(self, samples):
         if isinstance(samples, MultiAgentBatch):
             info_out = {}
             if self.tf_sess is not None:
-                builder = TFRunBuilder(self.tf_sess, "compute_apply")
+                builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
                 for pid, batch in samples.policy_batches.items():
                     if pid not in self.policies_to_train:
                         continue
                     info_out[pid], _ = (
-                        self.policy_map[pid]._build_compute_apply(
+                        self.policy_map[pid]._build_learn_on_batch(
                             builder, batch))
                 info_out = {k: builder.get(v) for k, v in info_out.items()}
             else:
@@ -487,11 +487,11 @@ class PolicyEvaluator(EvaluatorInterface):
                     if pid not in self.policies_to_train:
                         continue
                     info_out[pid], _ = (
-                        self.policy_map[pid].compute_apply(batch))
+                        self.policy_map[pid].learn_on_batch(batch))
             return info_out
         else:
             grad_fetch, apply_fetch = (
-                self.policy_map[DEFAULT_POLICY_ID].compute_apply(samples))
+                self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(samples))
             return grad_fetch

     @DeveloperAPI
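
Stripped of the TF run-builder fast path, the multi-agent branch above is a per-policy loop. The function below restates it as a hedged sketch; the arguments stand in for the evaluator's policy_map, policies_to_train set, and MultiAgentBatch.policy_batches dict, none of which are redefined here:

    def learn_on_batch_simplified(policy_map, policies_to_train, policy_batches):
        """Per-policy dispatch mirroring the loop above (no TF fast path)."""
        info_out = {}
        for pid, batch in policy_batches.items():
            if pid not in policies_to_train:
                continue  # policies excluded from training are never updated
            # A policy graph's learn_on_batch returns (grad_info, apply_info);
            # only grad_info is surfaced to the caller, keyed by policy id.
            info_out[pid], _ = policy_map[pid].learn_on_batch(batch)
        return info_out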


@@ -147,10 +147,30 @@ class PolicyGraph(object):
         """
         return sample_batch

+    @DeveloperAPI
+    def learn_on_batch(self, samples):
+        """Fused compute gradients and apply gradients call.
+
+        Either this or the combination of compute/apply grads must be
+        implemented by subclasses.
+
+        Returns:
+            grad_info: dictionary of extra metadata from compute_gradients().
+            apply_info: dictionary of extra metadata from apply_gradients().
+
+        Examples:
+            >>> batch = ev.sample()
+            >>> ev.learn_on_batch(samples)
+        """
+        return self.compute_apply(samples)
+
     @DeveloperAPI
     def compute_gradients(self, postprocessed_batch):
         """Computes gradients against a batch of experiences.

+        Either this or learn_on_batch() must be implemented by subclasses.
+
         Returns:
             grads (list): List of gradient output values
             info (dict): Extra policy-specific values
@@ -161,6 +181,8 @@ class PolicyGraph(object):
     def apply_gradients(self, gradients):
         """Applies previously computed gradients.

+        Either this or learn_on_batch() must be implemented by subclasses.
+
         Returns:
             info (dict): Extra policy-specific values
         """
@@ -168,16 +190,7 @@ class PolicyGraph(object):
     @DeveloperAPI
     def compute_apply(self, samples):
-        """Fused compute gradients and apply gradients call.
-
-        Returns:
-            grad_info: dictionary of extra metadata from compute_gradients().
-            apply_info: dictionary of extra metadata from apply_gradients().
-
-        Examples:
-            >>> batch = ev.sample()
-            >>> ev.compute_apply(samples)
-        """
+        """Deprecated: override learn_on_batch instead."""

         grads, grad_info = self.compute_gradients(samples)
         apply_info = self.apply_gradients(grads)
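
For policy graphs that only implement the gradient pair, the base class above bottoms out in compute followed by apply. Written out as a plain function (a sketch of the fallback path; the shipped method still routes through the deprecated compute_apply shim so existing overrides keep working):

    def learn_on_batch_fallback(policy_graph, samples):
        # What PolicyGraph.learn_on_batch() effectively does when a subclass
        # has not overridden it: compute gradients, apply them, return both infos.
        grads, grad_info = policy_graph.compute_gradients(samples)
        apply_info = policy_graph.apply_gradients(grads)
        return grad_info, apply_info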


@@ -179,9 +179,9 @@ class TFPolicyGraph(PolicyGraph):
         return builder.get(fetches)

     @override(PolicyGraph)
-    def compute_apply(self, postprocessed_batch):
-        builder = TFRunBuilder(self._sess, "compute_apply")
-        fetches = self._build_compute_apply(builder, postprocessed_batch)
+    def learn_on_batch(self, postprocessed_batch):
+        builder = TFRunBuilder(self._sess, "learn_on_batch")
+        fetches = self._build_learn_on_batch(builder, postprocessed_batch)
         return builder.get(fetches)

     @override(PolicyGraph)
@@ -380,7 +380,7 @@ class TFPolicyGraph(PolicyGraph):
             [self._apply_op, self.extra_apply_grad_fetches()])
         return fetches[1]

-    def _build_compute_apply(self, builder, postprocessed_batch):
+    def _build_learn_on_batch(self, builder, postprocessed_batch):
         builder.add_feed_dict(self.extra_compute_grad_feed_dict())
         builder.add_feed_dict(self.extra_apply_grad_feed_dict())
         builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch))
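
For user code the rename is mechanical. A hedged migration sketch (policy_graph stands for any TFPolicyGraph instance and batch for a postprocessed SampleBatch; neither is defined here):

    # Before this commit:
    #     fetches = policy_graph.compute_apply(batch)
    # After this commit, the same fused single-session-run path:
    fetches = policy_graph.learn_on_batch(batch)

    # The old name still resolves, but only to the deprecated base-class shim,
    # which falls back to compute_gradients() + apply_gradients() rather than
    # the batched run assembled by _build_learn_on_batch() above.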


@@ -391,7 +391,7 @@ class LearnerThread(threading.Thread):
         if replay is not None:
             prio_dict = {}
             with self.grad_timer:
-                grad_out = self.local_evaluator.compute_apply(replay)
+                grad_out = self.local_evaluator.learn_on_batch(replay)
                 for pid, info in grad_out.items():
                     prio_dict[pid] = (
                         replay.policy_batches[pid].data.get("batch_indexes"),


@@ -278,7 +278,7 @@ class LearnerThread(threading.Thread):
             batch, _ = self.minibatch_buffer.get()

         with self.grad_timer:
-            fetches = self.local_evaluator.compute_apply(batch)
+            fetches = self.local_evaluator.learn_on_batch(batch)
             self.weights_updated = True
             self.stats = fetches.get("stats", {})


@@ -95,7 +95,7 @@ class SyncBatchReplayOptimizer(PolicyOptimizer):
                 samples.append(random.choice(self.replay_buffer))
             samples = SampleBatch.concat_samples(samples)
         with self.grad_timer:
-            info_dict = self.local_evaluator.compute_apply(samples)
+            info_dict = self.local_evaluator.learn_on_batch(samples)
             for policy_id, info in info_dict.items():
                 if "stats" in info:
                     self.learner_stats[policy_id] = info["stats"]


@@ -126,7 +126,7 @@ class SyncReplayOptimizer(PolicyOptimizer):
             samples = self._replay()

         with self.grad_timer:
-            info_dict = self.local_evaluator.compute_apply(samples)
+            info_dict = self.local_evaluator.learn_on_batch(samples)
             for policy_id, info in info_dict.items():
                 if "stats" in info:
                     self.learner_stats[policy_id] = info["stats"]


@@ -54,7 +54,7 @@ class SyncSamplesOptimizer(PolicyOptimizer):
         with self.grad_timer:
             for i in range(self.num_sgd_iter):
-                fetches = self.local_evaluator.compute_apply(samples)
+                fetches = self.local_evaluator.learn_on_batch(samples)
                 if "stats" in fetches:
                     self.learner_stats = fetches["stats"]
                 if self.num_sgd_iter > 1:
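
All of the optimizer call sites above share the same shape: build or replay a batch, call learn_on_batch() on the local evaluator inside the grad timer, and keep any "stats" returned in the fetches. A condensed restatement of that loop (the function and its arguments are illustrative, not RLlib API):

    def sgd_step(local_evaluator, samples, num_sgd_iter=1):
        # Mirrors the SyncSamplesOptimizer loop: repeat the fused update
        # num_sgd_iter times and keep the most recent learner stats.
        learner_stats = {}
        for _ in range(num_sgd_iter):
            fetches = local_evaluator.learn_on_batch(samples)
            if "stats" in fetches:
                learner_stats = fetches["stats"]
        return learner_stats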