[rllib] rename compute_apply to learn_on_batch

Eric Liang 2019-02-11 15:22:15 -08:00 committed by GitHub
parent c4182463f6
commit 8df772867c
12 changed files with 81 additions and 37 deletions

@@ -121,6 +121,25 @@ class PPOAgent(Agent):
         res.update(
             timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps,
             info=dict(fetches, **res.get("info", {})))
+        # Warn about bad clipping configs
+        if self.config["vf_clip_param"] <= 0:
+            rew_scale = float("inf")
+        elif res["policy_reward_mean"]:
+            rew_scale = 0  # punt on handling multiagent case
+        else:
+            rew_scale = round(
+                abs(res["episode_reward_mean"]) / self.config["vf_clip_param"],
+                0)
+        if rew_scale > 100:
+            logger.warning(
+                "The magnitude of your environment rewards are more than "
+                "{}x the scale of `vf_clip_param`. ".format(rew_scale) +
+                "This means that it will take more than "
+                "{} iterations for your value ".format(rew_scale) +
+                "function to converge. If this is not intended, consider "
+                "increasing `vf_clip_param`.")
         return res

     def _validate_config(self):
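
To make the new check concrete, here is a self-contained sketch of the scale heuristic added above, using made-up `config` and `res` values in place of the agent's real state (the names mirror the diff; the numbers are illustrative):

```python
# Hypothetical stand-ins for PPOAgent's self.config and its training result dict.
config = {"vf_clip_param": 10.0}
res = {"episode_reward_mean": 5000.0, "policy_reward_mean": {}}

if config["vf_clip_param"] <= 0:
    rew_scale = float("inf")
elif res["policy_reward_mean"]:
    rew_scale = 0  # multiagent case: skip the heuristic, as in the diff
else:
    rew_scale = round(abs(res["episode_reward_mean"]) / config["vf_clip_param"], 0)

if rew_scale > 100:
    print("rewards are about {}x the scale of vf_clip_param; "
          "consider increasing vf_clip_param".format(rew_scale))
```

With these numbers the warning fires, since 5000 / 10 = 500 exceeds the 100x threshold; value targets clipped to roughly the `vf_clip_param` range would then need many iterations to track rewards of that magnitude.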

@@ -234,7 +234,7 @@ class QMixPolicyGraph(PolicyGraph):
         return TupleActions(list(actions.transpose([1, 0]))), hiddens, {}

     @override(PolicyGraph)
-    def compute_apply(self, samples):
+    def learn_on_batch(self, samples):
         obs_batch, action_mask = self._unpack_observation(samples["obs"])
         group_rewards = self._get_group_rewards(samples["infos"])
@@ -31,11 +31,31 @@ class EvaluatorInterface(object):
         raise NotImplementedError

     @DeveloperAPI
+    def learn_on_batch(self, samples):
+        """Update policies based on the given batch.
+
+        This is the equivalent to apply_gradients(compute_gradients(samples)),
+        but can be optimized to avoid pulling gradients into CPU memory.
+
+        Either this or the combination of compute/apply grads must be
+        implemented by subclasses.
+
+        Returns:
+            info: dictionary of extra metadata from compute_gradients().
+
+        Examples:
+            >>> batch = ev.sample()
+            >>> ev.learn_on_batch(samples)
+        """
+
+        return self.compute_apply(samples)
+
+    @DeveloperAPI
     def compute_gradients(self, samples):
         """Returns a gradient computed w.r.t the specified samples.

-        This method must be implemented by subclasses.
+        Either this or learn_on_batch() must be implemented by subclasses.

         Returns:
             (grads, info): A list of gradients that can be applied on a
@@ -54,7 +74,7 @@
     def apply_gradients(self, grads):
         """Applies the given gradients to this evaluator's weights.

-        This method must be implemented by subclasses.
+        Either this or learn_on_batch() must be implemented by subclasses.

         Examples:
             >>> samples = ev1.sample()
@@ -95,15 +115,7 @@
     @DeveloperAPI
     def compute_apply(self, samples):
-        """Fused compute gradients and apply gradients call.
-
-        Returns:
-            info: dictionary of extra metadata from compute_gradients().
-
-        Examples:
-            >>> batch = ev.sample()
-            >>> ev.compute_apply(samples)
-        """
+        """Deprecated: override learn_on_batch instead."""

         grads, info = self.compute_gradients(samples)
         self.apply_gradients(grads)
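
The contract spelled out in these docstrings, override `learn_on_batch` directly or supply `compute_gradients`/`apply_gradients` and fall back to the fused default, can be illustrated with a toy, self-contained evaluator; the class below is hypothetical and only mimics the interface shape:

```python
class MinimalEvaluator:
    """Toy evaluator following the interface sketched above (assumed shape)."""

    def __init__(self):
        self.weight = 0.0

    def compute_gradients(self, samples):
        # Pretend gradient: mean error between the weight and the targets.
        grad = sum(self.weight - t for t in samples["targets"]) / len(samples["targets"])
        return [grad], {"grad_norm": abs(grad)}

    def apply_gradients(self, grads):
        self.weight -= 0.1 * grads[0]

    def learn_on_batch(self, samples):
        # Default fused path: equivalent to apply_gradients(compute_gradients()).
        grads, info = self.compute_gradients(samples)
        self.apply_gradients(grads)
        return info


ev = MinimalEvaluator()
print(ev.learn_on_batch({"targets": [1.0, 2.0, 3.0]}))  # -> {'grad_norm': 2.0}
```

Here `learn_on_batch` is just the fused pair of calls; a real evaluator can override it to avoid pulling gradients into CPU memory, which is the motivation given in the docstring above.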

@@ -43,7 +43,7 @@ class KerasPolicyGraph(PolicyGraph):
         value = self.critic.predict(state)
         return _sample(policy), [], {"vf_preds": value.flatten()}

-    def compute_apply(self, batch, *args):
+    def learn_on_batch(self, batch, *args):
         self.actor.fit(
             batch["obs"],
             batch["adv_targets"],

@@ -470,16 +470,16 @@ class PolicyEvaluator(EvaluatorInterface):
         return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)

     @override(EvaluatorInterface)
-    def compute_apply(self, samples):
+    def learn_on_batch(self, samples):
         if isinstance(samples, MultiAgentBatch):
             info_out = {}
             if self.tf_sess is not None:
-                builder = TFRunBuilder(self.tf_sess, "compute_apply")
+                builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
                 for pid, batch in samples.policy_batches.items():
                     if pid not in self.policies_to_train:
                         continue
                     info_out[pid], _ = (
-                        self.policy_map[pid]._build_compute_apply(
+                        self.policy_map[pid]._build_learn_on_batch(
                             builder, batch))
                 info_out = {k: builder.get(v) for k, v in info_out.items()}
             else:
@@ -487,11 +487,11 @@
                     if pid not in self.policies_to_train:
                         continue
                     info_out[pid], _ = (
-                        self.policy_map[pid].compute_apply(batch))
+                        self.policy_map[pid].learn_on_batch(batch))
             return info_out
         else:
             grad_fetch, apply_fetch = (
-                self.policy_map[DEFAULT_POLICY_ID].compute_apply(samples))
+                self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(samples))
             return grad_fetch

     @DeveloperAPI
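
A simplified sketch of the multiagent dispatch above: per-policy batches are routed to each trainable policy's `learn_on_batch` and the returned info dicts are keyed by policy id. The classes and batch layout below are stand-ins, not the real PolicyEvaluator or MultiAgentBatch:

```python
class ToyPolicy:
    """Hypothetical policy exposing the learn_on_batch entry point."""

    def learn_on_batch(self, batch):
        # Return learner metadata, mirroring the per-policy info dicts above.
        return {"stats": {"num_samples": len(batch["obs"])}}


class ToyEvaluator:
    """Hypothetical evaluator that routes per-policy batches, as sketched above."""

    def __init__(self, policies, policies_to_train):
        self.policy_map = policies
        self.policies_to_train = policies_to_train

    def learn_on_batch(self, policy_batches):
        info_out = {}
        for pid, batch in policy_batches.items():
            if pid not in self.policies_to_train:
                continue  # frozen policies are skipped, as in the diff
            info_out[pid] = self.policy_map[pid].learn_on_batch(batch)
        return info_out


ev = ToyEvaluator({"p0": ToyPolicy(), "p1": ToyPolicy()}, policies_to_train=["p0"])
print(ev.learn_on_batch({"p0": {"obs": [1, 2, 3]}, "p1": {"obs": [4]}}))
# -> {'p0': {'stats': {'num_samples': 3}}}
```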

@@ -147,10 +147,30 @@ class PolicyGraph(object):
         """
         return sample_batch

     @DeveloperAPI
+    def learn_on_batch(self, samples):
+        """Fused compute gradients and apply gradients call.
+
+        Either this or the combination of compute/apply grads must be
+        implemented by subclasses.
+
+        Returns:
+            grad_info: dictionary of extra metadata from compute_gradients().
+            apply_info: dictionary of extra metadata from apply_gradients().
+
+        Examples:
+            >>> batch = ev.sample()
+            >>> ev.learn_on_batch(samples)
+        """
+
+        return self.compute_apply(samples)
+
+    @DeveloperAPI
     def compute_gradients(self, postprocessed_batch):
         """Computes gradients against a batch of experiences.

+        Either this or learn_on_batch() must be implemented by subclasses.
+
         Returns:
             grads (list): List of gradient output values
             info (dict): Extra policy-specific values
@@ -161,6 +181,8 @@
     def apply_gradients(self, gradients):
         """Applies previously computed gradients.

+        Either this or learn_on_batch() must be implemented by subclasses.
+
         Returns:
             info (dict): Extra policy-specific values
         """
@@ -168,16 +190,7 @@
     @DeveloperAPI
     def compute_apply(self, samples):
-        """Fused compute gradients and apply gradients call.
-
-        Returns:
-            grad_info: dictionary of extra metadata from compute_gradients().
-            apply_info: dictionary of extra metadata from apply_gradients().
-
-        Examples:
-            >>> batch = ev.sample()
-            >>> ev.compute_apply(samples)
-        """
+        """Deprecated: override learn_on_batch instead."""

         grads, grad_info = self.compute_gradients(samples)
         apply_info = self.apply_gradients(grads)
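
At the policy level the same choice exists: implement `compute_gradients`/`apply_gradients` and inherit the deprecated fused fallback, or override `learn_on_batch` to do the whole update in one step. A minimal, hypothetical policy graph that takes the second route:

```python
import numpy as np


class LinearPolicyGraph:
    """Hypothetical policy graph that overrides learn_on_batch directly,
    fusing the gradient computation and the update into one step."""

    def __init__(self, dim, lr=0.01):
        self.w = np.zeros(dim)
        self.lr = lr

    def learn_on_batch(self, samples):
        obs, targets = samples["obs"], samples["targets"]
        preds = obs @ self.w
        grad = obs.T @ (preds - targets) / len(targets)
        self.w -= self.lr * grad
        # Mirror the docstring above: return extra metadata from the update.
        return {"grad_norm": float(np.linalg.norm(grad)),
                "loss": float(np.mean((preds - targets) ** 2))}


pg = LinearPolicyGraph(dim=2)
batch = {"obs": np.array([[1.0, 0.0], [0.0, 1.0]]), "targets": np.array([1.0, -1.0])}
print(pg.learn_on_batch(batch))
```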

@@ -179,9 +179,9 @@ class TFPolicyGraph(PolicyGraph):
         return builder.get(fetches)

     @override(PolicyGraph)
-    def compute_apply(self, postprocessed_batch):
-        builder = TFRunBuilder(self._sess, "compute_apply")
-        fetches = self._build_compute_apply(builder, postprocessed_batch)
+    def learn_on_batch(self, postprocessed_batch):
+        builder = TFRunBuilder(self._sess, "learn_on_batch")
+        fetches = self._build_learn_on_batch(builder, postprocessed_batch)
         return builder.get(fetches)

     @override(PolicyGraph)
@@ -380,7 +380,7 @@
             [self._apply_op, self.extra_apply_grad_fetches()])
         return fetches[1]

-    def _build_compute_apply(self, builder, postprocessed_batch):
+    def _build_learn_on_batch(self, builder, postprocessed_batch):
         builder.add_feed_dict(self.extra_compute_grad_feed_dict())
         builder.add_feed_dict(self.extra_apply_grad_feed_dict())
         builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch))
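
The renamed `_build_learn_on_batch` keeps the point of the TF path: feeds and fetches for the gradient computation and the apply op are accumulated on one builder so everything executes in a single session call. A toy, framework-free stand-in for that builder pattern (TFRunBuilder's real API may differ):

```python
class ToyRunBuilder:
    """Very simplified stand-in for TFRunBuilder: collect feeds and fetches,
    then execute them together in one call."""

    def __init__(self, run_fn, name):
        self.run_fn, self.name = run_fn, name
        self.feed = {}
        self.fetches = []

    def add_feed_dict(self, feed):
        self.feed.update(feed)

    def add_fetches(self, fetches):
        self.fetches.extend(fetches)
        return fetches

    def get(self, fetches):
        # One fused execution, analogous to a single sess.run() covering
        # compute-gradients plus apply-gradients.
        results = self.run_fn(self.fetches, self.feed)
        return [results[f] for f in fetches]


def fake_run(fetches, feed):  # hypothetical stand-in for sess.run
    return {f: (f, sorted(feed)) for f in fetches}


builder = ToyRunBuilder(fake_run, "learn_on_batch")
builder.add_feed_dict({"obs": [1, 2], "advantages": [0.5, -0.5]})
handles = builder.add_fetches(["apply_op", "grad_stats"])
print(builder.get(handles))
```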

@@ -391,7 +391,7 @@ class LearnerThread(threading.Thread):
         if replay is not None:
             prio_dict = {}
             with self.grad_timer:
-                grad_out = self.local_evaluator.compute_apply(replay)
+                grad_out = self.local_evaluator.learn_on_batch(replay)
                 for pid, info in grad_out.items():
                     prio_dict[pid] = (
                         replay.policy_batches[pid].data.get("batch_indexes"),
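
The learner output is then fed back into prioritized replay: for each policy, the TD errors returned by `learn_on_batch` are paired with the batch indexes of the replayed samples. A self-contained sketch of that bookkeeping, with plain dicts standing in for the replay batch objects and field names taken from the diff:

```python
def build_priority_updates(grad_out, policy_batches):
    """Pair each policy's TD errors with the indexes of the replayed samples,
    as the learner thread above does (simplified)."""
    prio_dict = {}
    for pid, info in grad_out.items():
        prio_dict[pid] = (policy_batches[pid].get("batch_indexes"),
                          info.get("td_error"))
    return prio_dict


grad_out = {"default_policy": {"td_error": [0.1, 0.4, 0.0]}}
policy_batches = {"default_policy": {"batch_indexes": [7, 2, 9]}}
print(build_priority_updates(grad_out, policy_batches))
# -> {'default_policy': ([7, 2, 9], [0.1, 0.4, 0.0])}
```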

@@ -278,7 +278,7 @@ class LearnerThread(threading.Thread):
             batch, _ = self.minibatch_buffer.get()

         with self.grad_timer:
-            fetches = self.local_evaluator.compute_apply(batch)
+            fetches = self.local_evaluator.learn_on_batch(batch)
             self.weights_updated = True
             self.stats = fetches.get("stats", {})
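
The learner-thread pattern itself is simple: pull a batch off a queue, hand it to `learn_on_batch`, and expose the returned stats. A self-contained sketch with toy stand-ins for the evaluator and the minibatch buffer:

```python
import queue
import threading


class ToyLearnerThread(threading.Thread):
    """Simplified sketch of the learner-thread pattern above; names are stand-ins."""

    def __init__(self, evaluator, inqueue):
        super().__init__(daemon=True)
        self.evaluator = evaluator
        self.inqueue = inqueue
        self.stats = {}

    def run(self):
        while True:
            batch = self.inqueue.get()
            if batch is None:
                break  # shutdown sentinel
            fetches = self.evaluator.learn_on_batch(batch)
            self.stats = fetches.get("stats", {})


class ToyEvaluator:
    def learn_on_batch(self, batch):
        return {"stats": {"batch_size": len(batch)}}


q = queue.Queue()
learner = ToyLearnerThread(ToyEvaluator(), q)
learner.start()
q.put([0] * 32)
q.put(None)
learner.join()
print(learner.stats)  # -> {'batch_size': 32}
```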

@@ -95,7 +95,7 @@ class SyncBatchReplayOptimizer(PolicyOptimizer):
                 samples.append(random.choice(self.replay_buffer))
             samples = SampleBatch.concat_samples(samples)
         with self.grad_timer:
-            info_dict = self.local_evaluator.compute_apply(samples)
+            info_dict = self.local_evaluator.learn_on_batch(samples)
             for policy_id, info in info_dict.items():
                 if "stats" in info:
                     self.learner_stats[policy_id] = info["stats"]

@@ -126,7 +126,7 @@ class SyncReplayOptimizer(PolicyOptimizer):
         samples = self._replay()

         with self.grad_timer:
-            info_dict = self.local_evaluator.compute_apply(samples)
+            info_dict = self.local_evaluator.learn_on_batch(samples)
             for policy_id, info in info_dict.items():
                 if "stats" in info:
                     self.learner_stats[policy_id] = info["stats"]
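
Both replay optimizers follow the same step: draw or replay a training batch, call `learn_on_batch` once on the local evaluator, and record each policy's "stats" entry. A compact, self-contained sketch of that loop with hypothetical stand-in classes:

```python
import random


class ToyReplayOptimizer:
    """Sketch of the sync-replay step above: sample stored batches, train once,
    and keep per-policy stats (class and field names are stand-ins)."""

    def __init__(self, local_evaluator, replay_buffer):
        self.local_evaluator = local_evaluator
        self.replay_buffer = replay_buffer
        self.learner_stats = {}

    def step(self, train_batch_size=4):
        samples = [random.choice(self.replay_buffer) for _ in range(train_batch_size)]
        info_dict = self.local_evaluator.learn_on_batch(samples)
        for policy_id, info in info_dict.items():
            if "stats" in info:
                self.learner_stats[policy_id] = info["stats"]
        return self.learner_stats


class ToyEvaluator:
    def learn_on_batch(self, samples):
        return {"default_policy": {"stats": {"replayed": len(samples)}}}


opt = ToyReplayOptimizer(ToyEvaluator(), replay_buffer=[{"obs": i} for i in range(10)])
print(opt.step())  # -> {'default_policy': {'stats': {'replayed': 4}}}
```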

@@ -54,7 +54,7 @@ class SyncSamplesOptimizer(PolicyOptimizer):
         with self.grad_timer:
             for i in range(self.num_sgd_iter):
-                fetches = self.local_evaluator.compute_apply(samples)
+                fetches = self.local_evaluator.learn_on_batch(samples)
                 if "stats" in fetches:
                     self.learner_stats = fetches["stats"]
                 if self.num_sgd_iter > 1: