mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 17:58:35 +08:00
[sgd] Add file lock to protect compilation of sgd op (#3486)
* add file lock to protect compilation of sgd op * lint * update * fix * fix * lint * update * rebase on arrow * Update sgd_worker.py
This commit is contained in:
@@ -9,9 +9,10 @@ import pyarrow.plasma as plasma
|
||||
import tensorflow as tf
|
||||
|
||||
import ray
|
||||
from ray.experimental.sgd.util import fetch, run_timeline, warmup
|
||||
from ray.experimental.sgd.modified_allreduce import sum_gradients_all_reduce, \
|
||||
unpack_small_tensors
|
||||
from ray.experimental.sgd.util import (ensure_plasma_tensorflow_op, fetch,
|
||||
run_timeline, warmup)
|
||||
from ray.experimental.sgd.modified_allreduce import (sum_gradients_all_reduce,
|
||||
unpack_small_tensors)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -112,8 +113,7 @@ class SGDWorker(object):
|
||||
ray.worker.global_worker.plasma_client.store_socket_name)
|
||||
manager_socket = (
|
||||
ray.worker.global_worker.plasma_client.manager_socket_name)
|
||||
if not plasma.tf_plasma_op:
|
||||
plasma.build_plasma_tensorflow_op()
|
||||
ensure_plasma_tensorflow_op()
|
||||
|
||||
# For fetching grads -> plasma
|
||||
self.plasma_in_grads = []
|
||||
|
||||
@@ -2,10 +2,13 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import filelock
|
||||
import json
|
||||
import logging
|
||||
import numpy as np
|
||||
import os
|
||||
import pyarrow
|
||||
import pyarrow.plasma as plasma
|
||||
import time
|
||||
import tensorflow as tf
|
||||
|
||||
@@ -120,6 +123,16 @@ class Timeline(object):
|
||||
logger.info("Wrote chrome timeline to", filename)
|
||||
|
||||
|
||||
def ensure_plasma_tensorflow_op():
    """Build or load the plasma TensorFlow op under a file lock.

    A lock file named ``compile_op.lock`` inside the ``tensorflow``
    directory of the installed ``pyarrow`` package is held for the whole
    operation, so that concurrent callers do not compile the op at the
    same time.

    While the lock is held:

    * if ``plasma_op.so`` is not present in that directory,
      ``plasma.build_plasma_tensorflow_op()`` is called;
    * otherwise ``plasma.load_plasma_tensorflow_op()`` is called.
    """
    tf_dir = os.path.join(pyarrow.__path__[0], "tensorflow")
    compile_lock = filelock.FileLock(os.path.join(tf_dir, "compile_op.lock"))

    with compile_lock:
        # The existence check must happen while holding the lock; otherwise
        # two processes could both see the library missing and both build.
        op_already_built = os.path.exists(os.path.join(tf_dir, "plasma_op.so"))
        if op_already_built:
            plasma.load_plasma_tensorflow_op()
        else:
            plasma.build_plasma_tensorflow_op()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
a = Timeline(1)
|
||||
b = Timeline(2)
|
||||
|
||||
@@ -136,6 +136,7 @@ def find_version(*filepath):
|
||||
|
||||
requires = [
|
||||
"numpy",
|
||||
"filelock",
|
||||
"funcsigs",
|
||||
"click",
|
||||
"colorama",
|
||||
|
||||
Reference in New Issue
Block a user