mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 08:12:53 +08:00
[rllib] PPO doesn't work with fractional num gpus (#3396)
* frac ppo * gpu test
This commit is contained in:
@@ -3,6 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import math
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
import tensorflow as tf
|
||||
@@ -44,7 +45,9 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
if not num_gpus:
|
||||
self.devices = ["/cpu:0"]
|
||||
else:
|
||||
self.devices = ["/gpu:{}".format(i) for i in range(num_gpus)]
|
||||
self.devices = [
|
||||
"/gpu:{}".format(i) for i in range(int(math.ceil(num_gpus)))
|
||||
]
|
||||
self.batch_size = int(sgd_batch_size / len(self.devices)) * len(
|
||||
self.devices)
|
||||
assert self.batch_size % len(self.devices) == 0
|
||||
|
||||
Reference in New Issue
Block a user