mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 05:41:19 +08:00
[sgd] Example for Training (#5292)
This commit is contained in:
@@ -0,0 +1,73 @@
|
||||
# A unique identifier for the head node and workers of this cluster.
|
||||
cluster_name: sgd-pytorch
|
||||
|
||||
# The maximum number of worker nodes to launch in addition to the head
|
||||
# node. This takes precedence over min_workers. min_workers defaults to 0.
|
||||
min_workers: 1
|
||||
initial_workers: 1
|
||||
max_workers: 1
|
||||
|
||||
target_utilization_fraction: 0.9
|
||||
|
||||
# If a node is idle for this many minutes, it will be removed.
|
||||
idle_timeout_minutes: 20
|
||||
# docker:
|
||||
# image: tensorflow/tensorflow:1.5.0-py3
|
||||
# container_name: ray_docker
|
||||
|
||||
# Cloud-provider specific configuration.
|
||||
provider:
|
||||
type: aws
|
||||
region: us-east-1
|
||||
availability_zone: us-east-1f
|
||||
|
||||
# How Ray will authenticate with newly launched nodes.
|
||||
auth:
|
||||
ssh_user: ubuntu
|
||||
|
||||
head_node:
|
||||
InstanceType: p3.8xlarge
|
||||
ImageId: ami-0757fc5a639fe7666
|
||||
# InstanceMarketOptions:
|
||||
# MarketType: spot
|
||||
# SpotOptions:
|
||||
# MaxPrice: "9.0"
|
||||
|
||||
|
||||
worker_nodes:
|
||||
InstanceType: p3.8xlarge
|
||||
ImageId: ami-0757fc5a639fe7666
|
||||
# InstanceMarketOptions:
|
||||
# MarketType: spot
|
||||
# SpotOptions:
|
||||
# MaxPrice: "9.0"
|
||||
|
||||
# # Run workers on spot by default. Comment this out to use on-demand.
|
||||
# InstanceMarketOptions:
|
||||
# MarketType: spot
|
||||
|
||||
setup_commands:
|
||||
- ray || pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev2-cp36-cp36m-manylinux1_x86_64.whl
|
||||
- conda install -y pytorch torchvision cudatoolkit=9.0 -c pytorch
|
||||
- pip install -U ipdb ray[rllib]
|
||||
|
||||
|
||||
file_mounts: {
|
||||
}
|
||||
|
||||
# Custom commands that will be run on the head node after common setup.
|
||||
head_setup_commands: []
|
||||
|
||||
# Custom commands that will be run on worker nodes after common setup.
|
||||
worker_setup_commands: []
|
||||
|
||||
# Command to start ray on the head node. You don't need to change this.
|
||||
head_start_ray_commands:
|
||||
- ray stop
|
||||
- ray start --head --redis-port=6379 --object-manager-port=8076 --autoscaling-config=~/ray_bootstrap_config.yaml --object-store-memory=1000000000
|
||||
|
||||
# Command to start ray on worker nodes. You don't need to change this.
|
||||
worker_start_ray_commands:
|
||||
- ray stop
|
||||
- ray start --redis-address=$RAY_HEAD_IP:6379 --object-manager-port=8076 --object-store-memory=1000000000
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
from ray.experimental.sgd.pytorch import PyTorchTrainer, Resources
|
||||
|
||||
from ray.experimental.sgd.tests.pytorch_utils import (
|
||||
model_creator, optimizer_creator, data_creator)
|
||||
|
||||
|
||||
def train_example(num_replicas=1, use_gpu=False):
    """Build a ``PyTorchTrainer``, call ``train()`` once, then shut it down.

    Args:
        num_replicas (int): Forwarded to ``PyTorchTrainer`` as
            ``num_replicas``.
        use_gpu (bool): Converted with ``int()`` into the per-replica
            ``num_gpus`` request (1 when true, 0 otherwise).
    """
    trainer1 = PyTorchTrainer(
        model_creator,
        data_creator,
        optimizer_creator,
        num_replicas=num_replicas,
        resources_per_replica=Resources(
            num_cpus=1, num_gpus=int(use_gpu), resources={}))
    try:
        trainer1.train()
    finally:
        # Always release the trainer, even when train() raises, so a failed
        # run does not skip the shutdown step.
        trainer1.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: parse the command line, connect to Ray at the
    # given Redis address, and run the example with two replicas.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--redis-address",
        required=True,
        type=str,
        help="the address to use for Redis")
    parser.add_argument(
        "--use-gpu",
        action="store_true",
        default=False,
        help="Enables GPU training")
    # parse_known_args() returns unrecognized arguments separately instead
    # of erroring on them; they are discarded here.
    args, _ = parser.parse_known_args()

    import ray

    ray.init(redis_address=args.redis_address)
    # The keyword must match train_example's ``use_gpu`` parameter; any
    # other spelling raises TypeError before training starts.
    train_example(num_replicas=2, use_gpu=args.use_gpu)
|
||||
Reference in New Issue
Block a user