[RLlib] Better exceptions with traceback in TorchPolicy (#17690)

This commit is contained in:
Julius Frost 2021-08-11 09:01:07 -04:00 committed by GitHub
parent 811d71b368
commit 6891dee6ea
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -994,9 +994,10 @@ class TorchPolicy(Policy):
results[shard_idx] = (all_grads, grad_info)
except Exception as e:
with lock:
results[shard_idx] = ValueError(
results[shard_idx] = (ValueError(
e.args[0] + "\n" +
"In tower {} on device {}".format(shard_idx, device))
"In tower {} on device {}".format(shard_idx, device)),
e)
# Single device (GPU) or fake-GPU case (serialize for better
# debugging).
@ -1006,8 +1007,8 @@ class TorchPolicy(Policy):
_worker(shard_idx, model, sample_batch, device)
# Raise errors right away for better debugging.
last_result = results[len(results) - 1]
if isinstance(last_result, ValueError):
raise last_result
if isinstance(last_result[0], ValueError):
raise last_result[0] from last_result[1]
# Multi device (GPU) case: Parallelize via threads.
else:
threads = [
@ -1027,8 +1028,8 @@ class TorchPolicy(Policy):
outputs = []
for shard_idx in range(len(sample_batches)):
output = results[shard_idx]
if isinstance(output, Exception):
raise output
if isinstance(output[0], Exception):
raise output[0] from output[1]
outputs.append(results[shard_idx])
return outputs