Example: Searching for Network Architectures Using the MBNAS Algorithm
Model-based Neural Architecture Search (MBNAS) is a Huawei-developed NAS algorithm.
Its working principle is as follows:
- 200 architectures are sampled from the search space in a random complementary-set manner and trained in parallel to obtain the reward of each architecture.
- The 200 (architecture, reward) pairs are used to train an Evaluator, which can then predict the reward of any sampled architecture.
- Based on this Evaluator, a Controller based on reinforcement learning or an evolutionary algorithm is trained.
- The trained Controller outputs several high-quality architectures.
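The four steps above can be sketched end to end with a toy reward function. Everything in this sketch is illustrative, not the actual MBNAS implementation: the architecture encoding, the least-squares Evaluator, and the mutation-based Controller are stand-ins chosen only to show how the pieces fit together.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy search space: 4 stages, each stage picks one of 4 discrete
# options (encoded as an index 0-3).
NUM_STAGES, NUM_CHOICES = 4, 4

def sample_architecture():
    return rng.integers(0, NUM_CHOICES, size=NUM_STAGES)

def true_reward(arch):
    # Stand-in for "train the architecture and measure accuracy":
    # here, larger choices simply score higher, plus noise.
    return arch.sum() + rng.normal(scale=0.1)

# Step 1: sample 200 architectures at random and "train" them in
# parallel to obtain their rewards.
archs = np.stack([sample_architecture() for _ in range(200)])
rewards = np.array([true_reward(a) for a in archs])

# Step 2: fit an Evaluator (here, plain least-squares regression)
# that predicts reward from the architecture encoding.
X = np.hstack([archs, np.ones((len(archs), 1))])  # add a bias column
w, *_ = np.linalg.lstsq(X, rewards, rcond=None)
predict = lambda a: np.append(a, 1.0) @ w

# Step 3: train a simple evolutionary Controller, scored only by the
# Evaluator (no further real training is needed).
pop = [sample_architecture() for _ in range(20)]
for _ in range(30):
    pop.sort(key=predict, reverse=True)
    parent = pop[rng.integers(0, 5)].copy()       # pick an elite parent
    parent[rng.integers(0, NUM_STAGES)] = rng.integers(0, NUM_CHOICES)
    pop[-1] = parent                              # replace the worst

# Step 4: the Controller outputs its best architectures.
best = max(pop, key=predict)
print(best)
```

In the real algorithm, the surrogate Evaluator is what makes the search cheap: only the initial 200 architectures are actually trained, and the Controller explores the space against predicted rewards.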
For convenience, this section uses simulated data similar to MNIST as an example. You can also modify the following code to use the MNIST dataset, as described in Example: Replacing the Original ResNet-50 with a Better Network Architecture.
Sample Code
In the sample code, comments at the end of the modified lines indicate the changes made to the original code.
```python
import argparse
import logging
import time

import tensorflow as tf
from tensorflow import keras

import autosearch  # Change 1: Import the AutoSearch package.

parser = argparse.ArgumentParser()
parser.add_argument(
    "--max_steps", type=int, default=10 ** 2, help="Number of steps to run trainer."
)
FLAGS, unparsed = parser.parse_known_args()

batch_size = 1
learning_rate = 0.001


def train(config, reporter):
    nas_code_extra = config["nas_code_extra"]  # Change 2: Obtain the parameters delivered by the framework.
    with tf.Graph().as_default():
        sess = tf.InteractiveSession()
        x = tf.ones([batch_size, 784], name="x-input")
        y_true = tf.ones([batch_size, 10], name="y-input", dtype=tf.int64)
        for i, stage in enumerate(nas_code_extra):
            x = keras.layers.Dense(
                units=stage["hidden_units"], activation=stage["activation"],
            )(x)  # Change 3: Construct the model based on the input model parameters.
        y = tf.layers.dense(x, 10)
        cross_entropy = tf.losses.softmax_cross_entropy(y_true, y)
        train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_true, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        tf.global_variables_initializer().run()

        max_acc = 0
        latencies = []
        for i in range(FLAGS.max_steps):
            logging.info("current step: {}".format(i))
            # Evaluate every 10 steps, and always on the last step so
            # that the final report with done=True is sent.
            if i % 10 == 0 or i == (FLAGS.max_steps - 1):
                loss, acc = sess.run([cross_entropy, accuracy])
                if acc > max_acc:
                    max_acc = acc
                if i == (FLAGS.max_steps - 1):
                    autosearch.reporter(
                        loss=loss, acc=max_acc, mean_loss=loss, done=True
                    )  # Change 4: Report precision metrics.
                else:
                    autosearch.reporter(loss=loss, acc=acc, mean_loss=loss)  # Same as change 4.
            else:
                start = time.time()
                loss, _ = sess.run([cross_entropy, train_step])
                end = time.time()
                if i % 10 != 1:
                    latencies.append(end - start)
        latency = sum(latencies) / len(latencies)
        print(max_acc, latency)
```
Compiling the Configuration File
The repeat_discrete search-space type is used together with MBNAS. With the following YAML configuration, the body of the for loop in the preceding code is executed four times, once per sampled stage.
```yaml
general:
  gpu_per_instance: 0

search_space:
  - type: repeat_discrete
    name: mbnas
    repeat: 4
    params:
      - name: activation
        values: ['sigmoid', 'relu']
      - name: hidden_units
        values: [1, 2, 4, 8]

search_algorithm:
  type: mbnas
  reward_attr: acc
  num_of_arcs: 20
```
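With repeat: 4, the framework samples one value per parameter for each of the four repeats and delivers them to train() as config["nas_code_extra"]. The list below is a hypothetical sample constructed for illustration, not output captured from the real framework:

```python
# Hypothetical example of a sampled config["nas_code_extra"] for the
# repeat_discrete space above (repeat: 4, so four stage dicts):
nas_code_extra = [
    {"activation": "relu", "hidden_units": 4},
    {"activation": "sigmoid", "hidden_units": 8},
    {"activation": "relu", "hidden_units": 1},
    {"activation": "sigmoid", "hidden_units": 2},
]

# The for loop in train() then stacks one Dense layer per entry:
for stage in nas_code_extra:
    print(stage["hidden_units"], stage["activation"])
```

Each stage dict draws its values from the `params` lists in the YAML, so every sampled architecture is a 4-layer stack with per-layer width and activation chosen by the search algorithm.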
Starting a Search Job
After uploading the script and the YAML file to OBS, you can start the job on the console page. Because no actual data is required, select any existing dataset or an empty OBS directory. For the remaining configurations, see Starting a Search Job in Example: Searching for Hyperparameters Using Classic Hyperparameter Algorithms.