Mirror of https://github.com/vale981/ray (synced 2025-03-06 10:31:39 -05:00)
[rllib] rename compute_apply to learn_on_batch
commit 8df772867c
parent c4182463f6
12 changed files with 81 additions and 37 deletions
@@ -121,6 +121,25 @@ class PPOAgent(Agent):
         res.update(
             timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps,
             info=dict(fetches, **res.get("info", {})))
+
+        # Warn about bad clipping configs
+        if self.config["vf_clip_param"] <= 0:
+            rew_scale = float("inf")
+        elif res["policy_reward_mean"]:
+            rew_scale = 0  # punt on handling multiagent case
+        else:
+            rew_scale = round(
+                abs(res["episode_reward_mean"]) / self.config["vf_clip_param"],
+                0)
+        if rew_scale > 100:
+            logger.warning(
+                "The magnitude of your environment rewards are more than "
+                "{}x the scale of `vf_clip_param`. ".format(rew_scale) +
+                "This means that it will take more than "
+                "{} iterations for your value ".format(rew_scale) +
+                "function to converge. If this is not intended, consider "
+                "increasing `vf_clip_param`.")
+
         return res

     def _validate_config(self):
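
For intuition, the block added above compares the typical episode return to the value-function clip range. A minimal sketch of that check, using hypothetical reward and config values (not taken from the diff):

    # Hypothetical numbers, only to illustrate the warning threshold above.
    episode_reward_mean = 5000.0   # stands in for res["episode_reward_mean"]
    vf_clip_param = 10.0           # stands in for config["vf_clip_param"]

    # Same arithmetic as the new PPO code: how many multiples of the clip
    # range the observed returns span.
    rew_scale = round(abs(episode_reward_mean) / vf_clip_param, 0)

    if rew_scale > 100:
        print("rewards are {}x the scale of `vf_clip_param`; "
              "consider increasing it".format(rew_scale))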
@@ -234,7 +234,7 @@ class QMixPolicyGraph(PolicyGraph):
         return TupleActions(list(actions.transpose([1, 0]))), hiddens, {}

     @override(PolicyGraph)
-    def compute_apply(self, samples):
+    def learn_on_batch(self, samples):
         obs_batch, action_mask = self._unpack_observation(samples["obs"])
         group_rewards = self._get_group_rewards(samples["infos"])

@@ -31,11 +31,31 @@ class EvaluatorInterface(object):

         raise NotImplementedError

+    @DeveloperAPI
+    def learn_on_batch(self, samples):
+        """Update policies based on the given batch.
+
+        This is the equivalent to apply_gradients(compute_gradients(samples)),
+        but can be optimized to avoid pulling gradients into CPU memory.
+
+        Either this or the combination of compute/apply grads must be
+        implemented by subclasses.
+
+        Returns:
+            info: dictionary of extra metadata from compute_gradients().
+
+        Examples:
+            >>> batch = ev.sample()
+            >>> ev.learn_on_batch(samples)
+        """
+
+        return self.compute_apply(samples)
+
     @DeveloperAPI
     def compute_gradients(self, samples):
         """Returns a gradient computed w.r.t the specified samples.

-        This method must be implemented by subclasses.
+        Either this or learn_on_batch() must be implemented by subclasses.

         Returns:
             (grads, info): A list of gradients that can be applied on a
@@ -54,7 +74,7 @@ class EvaluatorInterface(object):
     def apply_gradients(self, grads):
         """Applies the given gradients to this evaluator's weights.

-        This method must be implemented by subclasses.
+        Either this or learn_on_batch() must be implemented by subclasses.

         Examples:
             >>> samples = ev1.sample()
@@ -95,15 +115,7 @@ class EvaluatorInterface(object):

     @DeveloperAPI
     def compute_apply(self, samples):
-        """Fused compute gradients and apply gradients call.
-
-        Returns:
-            info: dictionary of extra metadata from compute_gradients().
-
-        Examples:
-            >>> batch = ev.sample()
-            >>> ev.compute_apply(samples)
-        """
+        """Deprecated: override learn_on_batch instead."""

         grads, info = self.compute_gradients(samples)
         self.apply_gradients(grads)
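
To see why the deprecation shim above keeps old evaluators working, here is a minimal stand-in sketch (these are not the real RLlib classes): a subclass that only implements compute_gradients/apply_gradients still answers learn_on_batch, because the default learn_on_batch falls back to compute_apply.

    # Stand-in sketch mirroring the shim in EvaluatorInterface above; it is
    # not the actual RLlib implementation.
    class EvaluatorShim(object):
        def learn_on_batch(self, samples):
            # New entry point: delegate to the deprecated method so existing
            # subclasses that only override compute_apply keep working.
            return self.compute_apply(samples)

        def compute_apply(self, samples):
            # Deprecated fused path, expressed via the fine-grained methods.
            grads, info = self.compute_gradients(samples)
            self.apply_gradients(grads)
            return info

    class MyEvaluator(EvaluatorShim):
        # Hypothetical evaluator implementing only the fine-grained API.
        def compute_gradients(self, samples):
            return [0.0] * len(samples), {"loss": 0.1}

        def apply_gradients(self, grads):
            pass

    print(MyEvaluator().learn_on_batch([1, 2, 3]))  # -> {'loss': 0.1}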
@@ -43,7 +43,7 @@ class KerasPolicyGraph(PolicyGraph):
         value = self.critic.predict(state)
         return _sample(policy), [], {"vf_preds": value.flatten()}

-    def compute_apply(self, batch, *args):
+    def learn_on_batch(self, batch, *args):
         self.actor.fit(
             batch["obs"],
             batch["adv_targets"],
@@ -470,16 +470,16 @@ class PolicyEvaluator(EvaluatorInterface):
         return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)

     @override(EvaluatorInterface)
-    def compute_apply(self, samples):
+    def learn_on_batch(self, samples):
         if isinstance(samples, MultiAgentBatch):
             info_out = {}
             if self.tf_sess is not None:
-                builder = TFRunBuilder(self.tf_sess, "compute_apply")
+                builder = TFRunBuilder(self.tf_sess, "learn_on_batch")
                 for pid, batch in samples.policy_batches.items():
                     if pid not in self.policies_to_train:
                         continue
                     info_out[pid], _ = (
-                        self.policy_map[pid]._build_compute_apply(
+                        self.policy_map[pid]._build_learn_on_batch(
                             builder, batch))
                 info_out = {k: builder.get(v) for k, v in info_out.items()}
             else:
@@ -487,11 +487,11 @@ class PolicyEvaluator(EvaluatorInterface):
                     if pid not in self.policies_to_train:
                         continue
                     info_out[pid], _ = (
-                        self.policy_map[pid].compute_apply(batch))
+                        self.policy_map[pid].learn_on_batch(batch))
             return info_out
         else:
             grad_fetch, apply_fetch = (
-                self.policy_map[DEFAULT_POLICY_ID].compute_apply(samples))
+                self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(samples))
             return grad_fetch

     @DeveloperAPI
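
Note that the renamed method returns differently shaped results in the two branches above: a dict keyed by policy id for MultiAgentBatch input, and a single fetch dict otherwise. A hedged sketch of how a caller might consume each shape (the values are made up; the pattern matches the optimizer hunks further down):

    # MultiAgentBatch input -> one info dict per trained policy id.
    multiagent_info = {"pol_a": {"stats": {"loss": 0.2}},
                       "pol_b": {"stats": {"loss": 0.5}}}
    learner_stats = {pid: info["stats"]
                     for pid, info in multiagent_info.items()
                     if "stats" in info}

    # Plain SampleBatch input -> a single fetch dict for the default policy.
    single_info = {"stats": {"loss": 0.3}}
    default_stats = single_info.get("stats", {})

    print(learner_stats)
    print(default_stats)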
@@ -147,10 +147,30 @@ class PolicyGraph(object):
         """
         return sample_batch

+    @DeveloperAPI
+    def learn_on_batch(self, samples):
+        """Fused compute gradients and apply gradients call.
+
+        Either this or the combination of compute/apply grads must be
+        implemented by subclasses.
+
+        Returns:
+            grad_info: dictionary of extra metadata from compute_gradients().
+            apply_info: dictionary of extra metadata from apply_gradients().
+
+        Examples:
+            >>> batch = ev.sample()
+            >>> ev.learn_on_batch(samples)
+        """
+
+        return self.compute_apply(samples)
+
     @DeveloperAPI
     def compute_gradients(self, postprocessed_batch):
         """Computes gradients against a batch of experiences.

+        Either this or learn_on_batch() must be implemented by subclasses.
+
         Returns:
             grads (list): List of gradient output values
             info (dict): Extra policy-specific values
@@ -161,6 +181,8 @@ class PolicyGraph(object):
     def apply_gradients(self, gradients):
         """Applies previously computed gradients.

+        Either this or learn_on_batch() must be implemented by subclasses.
+
         Returns:
             info (dict): Extra policy-specific values
         """
@@ -168,16 +190,7 @@ class PolicyGraph(object):

     @DeveloperAPI
     def compute_apply(self, samples):
-        """Fused compute gradients and apply gradients call.
-
-        Returns:
-            grad_info: dictionary of extra metadata from compute_gradients().
-            apply_info: dictionary of extra metadata from apply_gradients().
-
-        Examples:
-            >>> batch = ev.sample()
-            >>> ev.compute_apply(samples)
-        """
+        """Deprecated: override learn_on_batch instead."""

         grads, grad_info = self.compute_gradients(samples)
         apply_info = self.apply_gradients(grads)
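
With this change, new policy graphs are expected to override learn_on_batch directly rather than compute_apply. A minimal hypothetical sketch of such a policy graph (not a real RLlib PolicyGraph subclass), where the fused update never exposes a gradient list to the caller:

    # Hypothetical policy graph written against the new API; it is not a
    # real RLlib PolicyGraph subclass.
    class ToyPolicyGraph(object):
        def __init__(self):
            self.weights = {"bias": 0.0}

        def learn_on_batch(self, samples):
            # Fused update: compute and apply in one step; no gradients are
            # ever returned to the caller.
            adv = sum(samples["advantages"]) / len(samples["advantages"])
            self.weights["bias"] += 0.01 * adv
            return {"stats": {"mean_advantage": adv}}

    pg = ToyPolicyGraph()
    print(pg.learn_on_batch({"advantages": [1.0, -0.5, 2.0]}))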
@@ -179,9 +179,9 @@ class TFPolicyGraph(PolicyGraph):
         return builder.get(fetches)

     @override(PolicyGraph)
-    def compute_apply(self, postprocessed_batch):
-        builder = TFRunBuilder(self._sess, "compute_apply")
-        fetches = self._build_compute_apply(builder, postprocessed_batch)
+    def learn_on_batch(self, postprocessed_batch):
+        builder = TFRunBuilder(self._sess, "learn_on_batch")
+        fetches = self._build_learn_on_batch(builder, postprocessed_batch)
         return builder.get(fetches)

     @override(PolicyGraph)
@@ -380,7 +380,7 @@ class TFPolicyGraph(PolicyGraph):
             [self._apply_op, self.extra_apply_grad_fetches()])
         return fetches[1]

-    def _build_compute_apply(self, builder, postprocessed_batch):
+    def _build_learn_on_batch(self, builder, postprocessed_batch):
         builder.add_feed_dict(self.extra_compute_grad_feed_dict())
         builder.add_feed_dict(self.extra_apply_grad_feed_dict())
         builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch))
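
The renamed _build_learn_on_batch helper merges the compute-gradient and apply-gradient feed dicts so the whole update runs in a single session call. For illustration only, a plain TF1-style sketch of that fused pattern (this is not RLlib's TFRunBuilder code, and the model and shapes are made up):

    import numpy as np
    import tensorflow.compat.v1 as tf

    tf.disable_eager_execution()

    obs = tf.placeholder(tf.float32, [None, 4])
    target = tf.placeholder(tf.float32, [None, 1])
    w = tf.Variable(tf.zeros([4, 1]))
    loss = tf.reduce_mean(tf.square(tf.matmul(obs, w) - target))
    apply_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        feed = {obs: np.random.randn(32, 4).astype(np.float32),
                target: np.random.randn(32, 1).astype(np.float32)}
        # Fused call: gradients are computed and applied inside one run,
        # instead of a compute_gradients run followed by an apply run.
        _, stats = sess.run([apply_op, loss], feed_dict=feed)
        print({"loss": float(stats)})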
@@ -391,7 +391,7 @@ class LearnerThread(threading.Thread):
         if replay is not None:
             prio_dict = {}
             with self.grad_timer:
-                grad_out = self.local_evaluator.compute_apply(replay)
+                grad_out = self.local_evaluator.learn_on_batch(replay)
                 for pid, info in grad_out.items():
                     prio_dict[pid] = (
                         replay.policy_batches[pid].data.get("batch_indexes"),
@@ -278,7 +278,7 @@ class LearnerThread(threading.Thread):
                 batch, _ = self.minibatch_buffer.get()

             with self.grad_timer:
-                fetches = self.local_evaluator.compute_apply(batch)
+                fetches = self.local_evaluator.learn_on_batch(batch)
                 self.weights_updated = True
                 self.stats = fetches.get("stats", {})

@@ -95,7 +95,7 @@ class SyncBatchReplayOptimizer(PolicyOptimizer):
                 samples.append(random.choice(self.replay_buffer))
             samples = SampleBatch.concat_samples(samples)
         with self.grad_timer:
-            info_dict = self.local_evaluator.compute_apply(samples)
+            info_dict = self.local_evaluator.learn_on_batch(samples)
             for policy_id, info in info_dict.items():
                 if "stats" in info:
                     self.learner_stats[policy_id] = info["stats"]
@@ -126,7 +126,7 @@ class SyncReplayOptimizer(PolicyOptimizer):
             samples = self._replay()

         with self.grad_timer:
-            info_dict = self.local_evaluator.compute_apply(samples)
+            info_dict = self.local_evaluator.learn_on_batch(samples)
             for policy_id, info in info_dict.items():
                 if "stats" in info:
                     self.learner_stats[policy_id] = info["stats"]
@@ -54,7 +54,7 @@ class SyncSamplesOptimizer(PolicyOptimizer):

         with self.grad_timer:
             for i in range(self.num_sgd_iter):
-                fetches = self.local_evaluator.compute_apply(samples)
+                fetches = self.local_evaluator.learn_on_batch(samples)
                 if "stats" in fetches:
                     self.learner_stats = fetches["stats"]
             if self.num_sgd_iter > 1:
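
The last hunk shows the caller-side pattern: a policy optimizer calls learn_on_batch repeatedly on the same sampled batch for num_sgd_iter passes and keeps the latest stats. A stand-in sketch of that loop with stub objects (none of these are the real RLlib classes):

    # Stub evaluator only; the real one is PolicyEvaluator above.
    class StubEvaluator(object):
        def learn_on_batch(self, samples):
            return {"stats": {"policy_loss": 0.42}}

    local_evaluator = StubEvaluator()
    num_sgd_iter = 3              # hypothetical, mirrors config["num_sgd_iter"]
    samples = {"obs": [0.0] * 8}  # placeholder batch

    learner_stats = {}
    for i in range(num_sgd_iter):
        fetches = local_evaluator.learn_on_batch(samples)
        if "stats" in fetches:
            learner_stats = fetches["stats"]
    print(learner_stats)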