# ray/rllib/policy/policy_map.py
#
# 309 lines
# 10 KiB
# Python

import os
import pickle
import threading
from collections import deque
from typing import TYPE_CHECKING, Callable, Dict, Optional, Set, Type
import gym
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.annotations import PublicAPI, override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.policy import create_policy_for_framework
from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.typing import (
AlgorithmConfigDict,
PartialAlgorithmConfigDict,
PolicyID,
)
from ray.tune.utils.util import merge_dicts
if TYPE_CHECKING:
    # Only imported for static type checkers; `TYPE_CHECKING` is False at
    # runtime, so `Policy` must be referenced via string annotations.
    from ray.rllib.policy.policy import Policy

# `try_import_tf()` yields three objects; presumably the tf.compat.v1 module,
# the tf module, and the tf version - TODO confirm against
# `ray.rllib.utils.framework`.
tf1, tf, tfv = try_import_tf()
@PublicAPI
class PolicyMap(dict):
    """Maps policy IDs to Policy objects.

    Keeps up to `capacity` policies in memory and - when capacity is reached -
    writes the least recently used one to disk. This allows adding 100s of
    policies to an Algorithm for league-based setups w/o running out of memory.

    Note: The storage of the `dict` base class itself is never used. All
    entries live in `self.cache` (in-memory policies) and on disk (stashed
    policy states); `self.valid_keys` holds every ID known to this map.
    """

    def __init__(
        self,
        worker_index: int,
        num_workers: int,
        capacity: Optional[int] = None,
        path: Optional[str] = None,
        policy_config: Optional[AlgorithmConfigDict] = None,
        session_creator: Optional[Callable[[], "tf1.Session"]] = None,
        seed: Optional[int] = None,
    ):
        """Initializes a PolicyMap instance.

        Args:
            worker_index: The worker index of the RolloutWorker this map
                resides in.
            num_workers: The total number of remote workers in the
                WorkerSet to which this map's RolloutWorker belongs to.
            capacity: The maximum number of policies to hold in memory.
                The least recently used ones are written to disk and
                retrieved when needed. Falls back to 10 if None (or 0).
            path: The path to store the policy pickle files to. Files
                will have the name: [policy_id].[worker idx].policy.pkl.
                Falls back to the current directory (".") if not provided.
            policy_config: The Algorithm's base config dict.
            session_creator: An optional
                tf1.Session creation callable.
            seed: An optional seed (used to seed tf policies).
        """
        super().__init__()

        self.worker_index = worker_index
        self.num_workers = num_workers
        self.session_creator = session_creator
        self.seed = seed

        # The file extension for stashed policies (that are no longer available
        # in-memory but can be reinstated any time from storage).
        self.extension = f".{self.worker_index}.policy.pkl"

        # Set of keys that may be looked up (cached or not).
        self.valid_keys: Set[str] = set()
        # The actual cache with the in-memory policy objects.
        self.cache: Dict[str, "Policy"] = {}
        # The deque holding the IDs of the currently in-memory policies,
        # ordered from least recently used (left) to most recently used
        # (right).
        self.deque = deque(maxlen=capacity or 10)
        # The file path where to store overflowing policies.
        self.path = path or "."
        # The core config to use. Each single policy's config override is
        # added on top of this.
        self.policy_config: AlgorithmConfigDict = policy_config or {}
        # The orig classes/obs+act spaces, and config overrides of the
        # Policies.
        self.policy_specs: Dict[PolicyID, PolicySpec] = {}

        # Lock used for locking some methods on the object-level.
        # This prevents possible race conditions when accessing the map
        # and the underlying structures, like self.deque and others.
        # Reentrant, because locked methods call each other (e.g. `get()` ->
        # `__getitem__()` -> `__setitem__()`).
        self._lock = threading.RLock()

    def create_policy(
        self,
        policy_id: PolicyID,
        policy_cls: Type["Policy"],
        observation_space: gym.Space,
        action_space: gym.Space,
        config_override: PartialAlgorithmConfigDict,
        merged_config: AlgorithmConfigDict,
    ) -> None:
        """Creates a new policy and stores it to the cache.

        Storing the new policy goes through `__setitem__` and hence may stash
        the least recently used policy to disk, if the map is at capacity.

        Args:
            policy_id: The policy ID. This is the key under which
                the created policy will be stored in this map.
            policy_cls: The (original) policy class to use.
                This may still be altered in case tf-eager (and tracing)
                is used.
            observation_space: The observation space of the
                policy.
            action_space: The action space of the policy.
            config_override: The config override
                dict for this policy. This is the partial dict provided by
                the user.
            merged_config: The entire config (merged
                default config + `config_override`).
        """
        _class = get_tf_eager_cls_if_necessary(policy_cls, merged_config)

        self[policy_id] = create_policy_for_framework(
            policy_id,
            _class,
            merged_config,
            observation_space,
            action_space,
            self.worker_index,
            self.session_creator,
            self.seed,
        )

        # Store spec (class, obs-space, act-space, and config overrides) such
        # that the map will be able to reproduce on-the-fly added policies
        # from disk. Note that the original class (not `_class`) is stored.
        self.policy_specs[policy_id] = PolicySpec(
            policy_class=policy_cls,
            observation_space=observation_space,
            action_space=action_space,
            config=config_override,
        )

    @with_lock
    @override(dict)
    def __getitem__(self, item):
        """Returns the policy for ID `item`, loading it from disk if stashed.

        Marks `item` as the most recently used policy.

        Raises:
            KeyError: If `item` is not a valid policy ID of this map.
        """
        # Never seen this key -> Error.
        if item not in self.valid_keys:
            raise KeyError(f"PolicyID '{item}' not found in this PolicyMap!")

        # Item already in cache -> Rearrange deque (least recently used) and
        # return.
        if item in self.cache:
            self.deque.remove(item)
            self.deque.append(item)
        # Item not currently in cache -> Get from disk and - if at capacity -
        # remove leftmost one.
        else:
            self._read_from_disk(policy_id=item)

        return self.cache[item]

    @with_lock
    @override(dict)
    def __setitem__(self, key, value):
        """Stores policy `value` under ID `key` as the most recently used one.

        If `key` is not in memory yet and the map is at capacity, the least
        recently used policy is stashed to disk first.
        """
        # Item already in cache -> Rearrange deque (least recently used).
        if key in self.cache:
            self.deque.remove(key)
            self.deque.append(key)
            self.cache[key] = value
        # Item not currently in cache -> store new value and - if at capacity -
        # remove leftmost one.
        else:
            # Cache at capacity -> Drop leftmost item.
            if len(self.deque) == self.deque.maxlen:
                self._stash_to_disk()
            self.deque.append(key)
            self.cache[key] = value
        self.valid_keys.add(key)

    @with_lock
    @override(dict)
    def __delitem__(self, key):
        """Removes policy `key` from memory and deletes its file, if any.

        Raises:
            KeyError: If `key` is not a valid policy ID of this map.
        """
        # Make key invalid.
        self.valid_keys.remove(key)

        # Remove policy from memory if currently cached.
        if key in self.cache:
            # Also drop the ID from the LRU deque. Otherwise, the stale entry
            # would keep counting towards the capacity and a later
            # `_stash_to_disk()` would fail looking it up in `self.cache`.
            if key in self.deque:
                self.deque.remove(key)
            policy = self.cache[key]
            self._close_session(policy)
            del self.cache[key]

        # Remove file associated with the policy, if it exists.
        filename = self._policy_file_path(key)
        if os.path.isfile(filename):
            os.remove(filename)

    @override(dict)
    def __iter__(self):
        """Iterates over all policy IDs, including the stashed-to-disk ones."""
        return iter(self.keys())

    @override(dict)
    def items(self):
        """Iterates over all policies, even the stashed-to-disk ones.

        The policy IDs are snapshotted when this method is called; each policy
        is only looked up (and possibly read from disk) when the returned
        generator reaches it.
        """
        # Snapshot the IDs under the lock (same as `keys()`), so the iteration
        # does not run over the live set while the map is being changed.
        with self._lock:
            ks = list(self.valid_keys)

        def gen():
            for key in ks:
                yield (key, self[key])

        return gen()

    @override(dict)
    def keys(self):
        """Iterates over a snapshot of all policy IDs (cached or stashed)."""
        # `with` releases the lock even if an exception is raised.
        with self._lock:
            ks = list(self.valid_keys)

        def gen():
            for key in ks:
                yield key

        return gen()

    @override(dict)
    def values(self):
        """Iterates over all policies, even the stashed-to-disk ones.

        All policies are looked up eagerly when this method is called. NOTE:
        If the map holds more policies than fit into memory, policies looked
        up early get stashed again (and their sessions closed) while later
        ones are read from disk.
        """
        # `with` guarantees the lock is released, even if one of the lookups
        # raises (e.g. while reading a stashed policy from disk).
        with self._lock:
            vs = [self[k] for k in list(self.valid_keys)]

        def gen():
            for value in vs:
                yield value

        return gen()

    @with_lock
    @override(dict)
    def update(self, __m=None, **kwargs):
        """Stores all given policies in this map (via `__setitem__`).

        Args:
            __m: Optional mapping (anything with an `items()` method) or
                iterable of (policy ID, policy) pairs.
            **kwargs: Additional policies, given as `policy_id=policy`.
        """
        if __m is not None:
            pairs = __m.items() if hasattr(__m, "items") else __m
            for k, v in pairs:
                self[k] = v
        for k, v in kwargs.items():
            self[k] = v

    @with_lock
    @override(dict)
    def get(self, key, default=None):
        """Returns the policy for ID `key` or `default`, if `key` is unknown.

        Args:
            key: The policy ID to look up.
            default: The value to return if `key` is not a valid policy ID.
        """
        if key not in self.valid_keys:
            return default
        return self[key]

    @with_lock
    @override(dict)
    def __len__(self):
        """Returns number of all policies, including the stashed-to-disk ones."""
        return len(self.valid_keys)

    @with_lock
    @override(dict)
    def __contains__(self, item):
        """Returns True if `item` is a valid policy ID (cached or stashed)."""
        return item in self.valid_keys

    def _policy_file_path(self, policy_id):
        """Returns the file path under which `policy_id`'s state is stashed."""
        return os.path.join(self.path, policy_id + self.extension)

    def _stash_to_disk(self):
        """Writes the least-recently used policy to disk and rearranges cache.

        Also closes the session - if applicable - of the stashed policy.

        Raises:
            IndexError: If no policy is currently held in memory.
        """
        # Get least recently used policy (all the way on the left in deque).
        # Only peek for now: The policy must stay in memory until its state
        # has safely been written.
        delkey = self.deque[0]
        policy = self.cache[delkey]
        # Get its state for writing to disk.
        policy_state = policy.get_state()
        # Write state to disk. If this fails, the map is left unchanged.
        with open(self._policy_file_path(delkey), "wb") as f:
            pickle.dump(policy_state, file=f)
        # State is safely stored -> Evict the policy.
        self.deque.popleft()
        # Closes policy's tf session, if any.
        self._close_session(policy)
        # Remove from memory. This will clear the tf Graph as well.
        del self.cache[delkey]

    def _read_from_disk(self, policy_id):
        """Reads a policy ID from disk and re-adds it to the cache."""
        # Make sure this policy ID is not in the cache right now.
        assert policy_id not in self.cache
        # Read policy state from disk.
        # NOTE(review): `pickle.load` can execute arbitrary code. The files
        # read here are expected to be the ones written by `_stash_to_disk`;
        # make sure `self.path` is not writable by untrusted parties.
        with open(self._policy_file_path(policy_id), "rb") as f:
            policy_state = pickle.load(f)

        # Get class and config override.
        spec = self.policy_specs[policy_id]
        merged_conf = merge_dicts(self.policy_config, spec.config)

        # Create policy object (from its spec: cls, obs-space, act-space,
        # config). This also puts the new object into the cache (stashing
        # the least recently used policy, if at capacity).
        self.create_policy(
            policy_id,
            spec.policy_class,
            spec.observation_space,
            spec.action_space,
            spec.config,
            merged_conf,
        )
        # Restore policy's state.
        policy = self[policy_id]
        policy.set_state(policy_state)

    def _close_session(self, policy):
        """Closes the session returned by `policy.get_session()`, if any."""
        sess = policy.get_session()
        # Closes the tf session, if any.
        if sess is not None:
            sess.close()