mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:38:18 +08:00
60d4d5e1aa
* Remove all __future__ imports from RLlib. * Remove (object) again from tf_run_builder.py::TFRunBuilder. * Fix 2xLINT warnings. * Fix broken appo_policy import (must be appo_tf_policy) * Remove future imports from all other ray files (not just RLlib). * Remove future imports from all other ray files (not just RLlib). * Remove future import blocks that contain `unicode_literals` as well. Revert appo_tf_policy.py to appo_policy.py (belongs to another PR). * Add two empty lines before Schedule class. * Put back __future__ imports into determine_tests_to_run.py. Fails otherwise on a py2/print related error.
132 lines
3.4 KiB
Python
132 lines
3.4 KiB
Python
import argparse
|
|
import tensorflow as tf
|
|
from tensorflow.keras.models import Sequential
|
|
from tensorflow.keras.layers import Dense
|
|
import numpy as np
|
|
|
|
import ray
|
|
from ray import tune
|
|
from ray.experimental.sgd.tf.tf_trainer import TFTrainer, TFTrainable
|
|
|
|
# Number of synthetic samples generated for the training split; also used as
# the shuffle buffer size and to derive steps_per_epoch in create_config.
NUM_TRAIN_SAMPLES = 1000
# Number of synthetic samples generated for the evaluation split; used to
# derive the evaluation step count in create_config.
NUM_TEST_SAMPLES = 400
|
|
|
|
|
|
def create_config(batch_size, num_train_samples=None, num_test_samples=None):
    """Build the trainer config dict for a given batch size.

    Args:
        batch_size: Number of samples per training / evaluation batch.
        num_train_samples: Size of the training split used to derive
            ``steps_per_epoch``. Defaults to ``NUM_TRAIN_SAMPLES``.
        num_test_samples: Size of the evaluation split used to derive the
            evaluation step count. Defaults to ``NUM_TEST_SAMPLES``.

    Returns:
        A dict with ``"batch_size"``, ``"fit_config"`` and
        ``"evaluate_config"`` entries. The two nested dicts are presumably
        forwarded by TFTrainer as keyword arguments to Keras ``fit`` /
        ``evaluate`` -- TODO confirm against TFTrainer.

    Raises:
        ZeroDivisionError: if ``batch_size`` is 0.
    """
    # Resolve defaults at call time so the module constants stay the single
    # source of truth unless a caller overrides them.
    if num_train_samples is None:
        num_train_samples = NUM_TRAIN_SAMPLES
    if num_test_samples is None:
        num_test_samples = NUM_TEST_SAMPLES

    # The tf.data pipelines in this example repeat indefinitely, so these
    # step counts are what bound an epoch / an evaluation pass. Floor
    # division would give 0 steps when batch_size exceeds the split size,
    # which is not a meaningful epoch, so clamp to at least one step.
    return {
        "batch_size": batch_size,
        "fit_config": {
            "steps_per_epoch": max(1, num_train_samples // batch_size)
        },
        "evaluate_config": {
            "steps": max(1, num_test_samples // batch_size),
        }
    }
|
|
|
|
|
|
def linear_dataset(a=2, size=1000):
    """Generate a noiseless synthetic regression dataset with ``y = x / a``.

    Args:
        a: Divisor relating targets to inputs. It was previously accepted
            but ignored (the ratio was hard-coded to 2); the default of 2
            reproduces the original output exactly.
        size: Number of samples to draw.

    Returns:
        A tuple ``(x, y)`` of float arrays, each of shape ``(size, 1)``.
        ``x`` is drawn with ``np.random.rand``, i.e. uniformly from [0, 1).
    """
    x = np.random.rand(size)
    y = x / a

    # Add a trailing feature axis, (size,) -> (size, 1), so each sample has
    # one feature, matching the model's input_shape=(1,).
    x = x.reshape((-1, 1))
    y = y.reshape((-1, 1))

    return x, y
|
|
|
|
|
|
def simple_dataset(config):
    """Create the (train, test) tf.data pipelines for the linear toy problem.

    Args:
        config: Dict that must provide ``"batch_size"``.

    Returns:
        ``(train_dataset, test_dataset)``. Both repeat indefinitely and are
        batched by ``config["batch_size"]``; the training pipeline is also
        shuffled with a buffer as large as the whole training split.
    """
    bs = config["batch_size"]

    # Draw the training split first, then the test split.
    train_xy = linear_dataset(size=NUM_TRAIN_SAMPLES)
    test_xy = linear_dataset(size=NUM_TEST_SAMPLES)

    train_ds = tf.data.Dataset.from_tensor_slices(train_xy)
    test_ds = tf.data.Dataset.from_tensor_slices(test_xy)

    # Infinite repetition: epoch / evaluation length is bounded elsewhere by
    # explicit step counts.
    train_ds = train_ds.shuffle(NUM_TRAIN_SAMPLES).repeat().batch(bs)
    test_ds = test_ds.repeat().batch(bs)
    return train_ds, test_ds
|
|
|
|
|
|
def simple_model(config):
    """Build and compile a small two-layer dense regressor.

    Args:
        config: Not read by this function; presumably present because the
            trainer invokes model creators as ``model_creator(config)``.

    Returns:
        A compiled Keras ``Sequential`` model mapping one input feature
        through a 10-unit ``Dense`` layer to a single output, trained with
        the ``"sgd"`` optimizer on mean squared error.
    """
    net = Sequential()
    net.add(Dense(10, input_shape=(1, )))
    net.add(Dense(1))

    net.compile(
        optimizer="sgd",
        loss="mean_squared_error",
        metrics=["mean_squared_error"],
    )
    return net
|
|
|
|
|
|
def train_example(num_replicas=1, batch_size=128, use_gpu=False):
    """Train the toy model for two epochs with TFTrainer, printing metrics.

    Args:
        num_replicas: Passed through to TFTrainer (presumably the number of
            parallel training workers -- confirm against TFTrainer).
        batch_size: Batch size used to build the trainer config.
        use_gpu: Passed through to TFTrainer.
    """
    trainer_config = create_config(batch_size)
    trainer = TFTrainer(
        model_creator=simple_model,
        data_creator=simple_dataset,
        num_replicas=num_replicas,
        use_gpu=use_gpu,
        verbose=True,
        config=trainer_config)

    # Two rounds of: one training epoch, then merge the validation metrics
    # into that epoch's training stats and print the combined dict.
    for _ in range(2):
        round_stats = trainer.train()
        round_stats.update(trainer.validate())
        print(round_stats)

    # One more standalone validation pass on the final model.
    print(trainer.validate())
    print("success!")
|
|
|
|
|
|
def tune_example(num_replicas=1, use_gpu=False):
    """Run a small Ray Tune experiment over TFTrainable.

    Two trials are sampled, and each is stopped after two training
    iterations.

    Args:
        num_replicas: Forwarded to each trial via the ``"num_replicas"``
            config key.
        use_gpu: Forwarded to each trial via the ``"use_gpu"`` config key.

    Returns:
        The config of the trial with the lowest ``validation_loss``.
    """
    # The creator callables are wrapped with tune.function, presumably so
    # Tune passes them through as-is instead of treating them as values to
    # resolve -- confirm against the Tune docs for this Ray version.
    trial_config = dict(
        model_creator=tune.function(simple_model),
        data_creator=tune.function(simple_dataset),
        num_replicas=num_replicas,
        use_gpu=use_gpu,
        trainer_config=create_config(batch_size=128),
    )

    analysis = tune.run(
        TFTrainable,
        config=trial_config,
        num_samples=2,
        stop={"training_iteration": 2},
        verbose=1,
    )
    return analysis.get_best_config(metric="validation_loss", mode="min")
|
|
|
|
|
|
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--address", required=False, type=str,
                     help="the address to use for Ray")
    cli.add_argument("--num-replicas", "-n", type=int, default=1,
                     help="Sets number of replicas for training.")
    cli.add_argument("--use-gpu", action="store_true", default=False,
                     help="Enables GPU training")
    cli.add_argument("--tune", action="store_true", default=False,
                     help="Tune training")
    # parse_known_args: unrecognised flags are tolerated and discarded
    # instead of aborting the script.
    args = cli.parse_known_args()[0]

    ray.init(address=args.address)

    # --tune selects the Ray Tune run; otherwise run plain TFTrainer training.
    runner = tune_example if args.tune else train_example
    runner(num_replicas=args.num_replicas, use_gpu=args.use_gpu)
|