[rllib] Fix num_gpus cast and raise error on large batch (#4652)

This commit is contained in:
Eric Liang 2019-04-18 15:23:29 -07:00 committed by GitHub
parent be2cbdf130
commit 5a562bbf12
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 13 additions and 4 deletions

View file

@ -61,8 +61,8 @@ class Preprocessor(object):
self._obs_space, observation)
except AttributeError:
raise ValueError(
"Observation for a Box space should be an np.array, "
"not a Python list.", observation)
"Observation for a Box/MultiBinary/MultiDiscrete space "
"should be an np.array, not a Python list.", observation)
self._i += 1
@property

View file

@ -47,9 +47,13 @@ class TFMultiGPULearner(LearnerThread):
if not num_gpus:
self.devices = ["/cpu:0"]
elif _fake_gpus:
self.devices = ["/cpu:{}".format(i) for i in range(num_gpus)]
self.devices = [
"/cpu:{}".format(i) for i in range(int(math.ceil(num_gpus)))
]
else:
self.devices = ["/gpu:{}".format(i) for i in range(num_gpus)]
self.devices = [
"/gpu:{}".format(i) for i in range(int(math.ceil(num_gpus)))
]
logger.info("TFMultiGPULearner devices {}".format(self.devices))
assert self.train_batch_size % len(self.devices) == 0
assert self.train_batch_size >= len(self.devices), "batch too small"

View file

@ -68,6 +68,11 @@ class SyncBatchReplayOptimizer(PolicyOptimizer):
batches = tmp
for batch in batches:
if batch.count > self.max_buffer_size:
raise ValueError(
"The size of a single sample batch exceeds the replay "
"buffer size ({} > {})".format(batch.count,
self.max_buffer_size))
self.replay_buffer.append(batch)
self.num_steps_sampled += batch.count
self.buffer_size += batch.count