ray/rllib/evaluation/interface.py
Sven 60d4d5e1aa Remove future imports (#6724)
* Remove all __future__ imports from RLlib.

* Remove (object) again from tf_run_builder.py::TFRunBuilder.

* Fix 2xLINT warnings.

* Fix broken appo_policy import (must be appo_tf_policy)

* Remove future imports from all other ray files (not just RLlib).

* Remove future imports from all other ray files (not just RLlib).

* Remove future import blocks that contain `unicode_literals` as well.
Revert appo_tf_policy.py to appo_policy.py (belongs to another PR).

* Add two empty lines before Schedule class.

* Put back __future__ imports into determine_tests_to_run.py; the script otherwise fails with a py2 print-related error.
2020-01-09 00:15:48 -08:00

124 lines
3.4 KiB
Python

import os
import platform

from ray.rllib.utils.annotations import DeveloperAPI
@DeveloperAPI
class EvaluatorInterface:
    """This is the interface between policy optimizers and policy evaluation.

    Subclasses must implement sample(), get_weights() and set_weights(),
    plus either learn_on_batch() or the pair
    compute_gradients()/apply_gradients().

    See also: RolloutWorker
    """

    @DeveloperAPI
    def sample(self):
        """Returns a batch of experience sampled from this evaluator.

        This method must be implemented by subclasses.

        Returns:
            SampleBatch|MultiAgentBatch: A columnar batch of experiences
            (e.g., tensors), or a multi-agent batch.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.

        Examples:
            >>> print(ev.sample())
            SampleBatch({"obs": [1, 2, 3], "action": [0, 1, 0], ...})
        """
        raise NotImplementedError

    @DeveloperAPI
    def learn_on_batch(self, samples):
        """Update policies based on the given batch.

        This is equivalent to apply_gradients(compute_gradients(samples)),
        but can be optimized to avoid pulling gradients into CPU memory.

        Either this or the combination of compute/apply grads must be
        implemented by subclasses. The default implementation here simply
        chains compute_gradients() and apply_gradients().

        Args:
            samples: The batch of experiences to learn from, as returned
                by sample().

        Returns:
            info: dictionary of extra metadata from compute_gradients().

        Examples:
            >>> samples = ev.sample()
            >>> ev.learn_on_batch(samples)
        """
        grads, info = self.compute_gradients(samples)
        self.apply_gradients(grads)
        return info

    @DeveloperAPI
    def compute_gradients(self, samples):
        """Returns a gradient computed w.r.t the specified samples.

        Either this or learn_on_batch() must be implemented by subclasses.

        Args:
            samples: The batch of experiences to compute gradients for,
                as returned by sample().

        Returns:
            (grads, info): A list of gradients that can be applied on a
            compatible evaluator. In the multi-agent case, returns a dict
            of gradients keyed by policy ids. An info dictionary of
            extra metadata is also returned.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.

        Examples:
            >>> samples = ev1.sample()
            >>> grads, info = ev2.compute_gradients(samples)
        """
        raise NotImplementedError

    @DeveloperAPI
    def apply_gradients(self, grads):
        """Applies the given gradients to this evaluator's weights.

        Either this or learn_on_batch() must be implemented by subclasses.

        Args:
            grads: Gradients as returned (first tuple element) by
                compute_gradients().

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.

        Examples:
            >>> samples = ev1.sample()
            >>> grads, info = ev2.compute_gradients(samples)
            >>> ev1.apply_gradients(grads)
        """
        raise NotImplementedError

    @DeveloperAPI
    def get_weights(self):
        """Returns the model weights of this Evaluator.

        This method must be implemented by subclasses.

        Returns:
            object: weights that can be set on a compatible evaluator.
            info: dictionary of extra metadata.

        NOTE(review): the example below (and the one on set_weights())
        uses the return value directly as the weights object, so the
        ``info`` entry above may be stale -- confirm against the
        concrete implementations.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.

        Examples:
            >>> weights = ev1.get_weights()
        """
        raise NotImplementedError

    @DeveloperAPI
    def set_weights(self, weights):
        """Sets the model weights of this Evaluator.

        This method must be implemented by subclasses.

        Args:
            weights: Weights as returned by get_weights() of a compatible
                evaluator.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.

        Examples:
            >>> weights = ev1.get_weights()
            >>> ev2.set_weights(weights)
        """
        raise NotImplementedError

    @DeveloperAPI
    def get_host(self):
        """Returns the hostname of the process running this evaluator.

        Returns:
            str: The machine's network name. May be an empty string if
            the name cannot be determined.
        """
        # platform.node() is used instead of os.uname()[1] because
        # os.uname() only exists on Unix; where os.uname() is available,
        # platform.node() reports the same node name.
        return platform.node()

    @DeveloperAPI
    def apply(self, func, *args):
        """Apply the given function to this evaluator instance.

        Args:
            func: Callable invoked as ``func(self, *args)``.
            *args: Extra positional arguments forwarded to ``func``.

        Returns:
            Whatever ``func`` returns.
        """
        return func(self, *args)