sockets - 权重在此代码中更新的位置?

标签 sockets tensorflow deep-learning distributed-computing gradient-descent

我想在分布式系统中训练模型。我在 GitHub 上找到了用于分布式训练的代码:worker 节点把梯度发送给参数服务器,参数服务器再把平均后的梯度发回给 worker。但是在客户端/worker 端的代码中,我看不出接收到的梯度是在哪里更新权重和偏置的。

这是客户端/worker 端的代码:它从参数服务器接收初始梯度,然后计算损失和梯度,并把梯度值再次发送到服务器。

from __future__ import division
from __future__ import print_function

import numpy as np
import sys
import pickle as pickle
import socket

from datetime import datetime
import time

import tensorflow as tf

import cifar10

TCP_IP = 'some IP'
TCP_PORT = 5014

port = 0
port_main = 0
s = 0

FLAGS = tf.app.flags.FLAGS


tf.app.flags.DEFINE_string('train_dir', '/home/ubuntu/cifar10_train',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_integer('max_steps', 5000,
                            """Number of batches to run.""")
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")
tf.app.flags.DEFINE_integer('log_frequency', 10,
                            """How often to log results to the console.""")
#gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.30)


def safe_recv(size, server_socket):
    data = ""
    temp = ""
    data = bytearray()
    recv_size = 0
    while 1:
        try:
            temp = server_socket.recv(size-len(data))
            data.extend(temp)
            recv_size = len(data)
            if recv_size >= size:
                break
        except:
            print("Error")
    data = bytes(data)
    return data


def train():
    """Train CIFAR-10 for a number of steps."""

    g1 = tf.Graph()
    with g1.as_default():
        global_step = tf.Variable(-1, name='global_step',
                                  trainable=False, dtype=tf.int32)
        increment_global_step_op = tf.assign(global_step, global_step+1)

        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate loss.
        loss = cifar10.loss(logits, labels)
        grads = cifar10.train_part1(loss, global_step)

        only_gradients = [g for g, _ in grads]

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                                  'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        with tf.train.MonitoredTrainingSession(
            checkpoint_dir=FLAGS.train_dir,
            hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                   tf.train.NanTensorHook(loss),
                   _LoggerHook()],
            config=tf.ConfigProto(
                # log_device_placement=FLAGS.log_device_placement, gpu_options=gpu_options)) as mon_sess:
                log_device_placement=FLAGS.log_device_placement)) as mon_sess:
            global port
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.connect((TCP_IP, port_main))
            recv_size = safe_recv(17, s)
            recv_size = pickle.loads(recv_size)
            recv_data = safe_recv(recv_size, s)
            var_vals = pickle.loads(recv_data)
            s.close()
            feed_dict = {}
            i = 0
            for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
                feed_dict[v] = var_vals[i]
                i = i+1
            print("Received variable values from ps")
            # Opening the socket and connecting to server
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.connect((TCP_IP, port))
            while not mon_sess.should_stop():
                gradients, step_val = mon_sess.run(
                    [only_gradients, increment_global_step_op], feed_dict=feed_dict)
                # sending the gradients
                send_data = pickle.dumps(gradients, pickle.HIGHEST_PROTOCOL)
                to_send_size = len(send_data)
                send_size = pickle.dumps(to_send_size, pickle.HIGHEST_PROTOCOL)
                s.sendall(send_size)
                s.sendall(send_data)
                # receiving the variable values
                recv_size = safe_recv(17, s)
                recv_size = pickle.loads(recv_size)
                recv_data = safe_recv(recv_size, s)
                var_vals = pickle.loads(recv_data)

                feed_dict = {}
                i = 0
                for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
                    feed_dict[v] = var_vals[i]
                    i = i+1
            s.close()


def main(argv=None):  # pylint: disable=unused-argument
    global port
    global port_main
    global s
    if(len(sys.argv) != 3):
        print("<port> <worker-id> required")
        sys.exit()
    port = int(sys.argv[1]) + int(sys.argv[2])
    port_main = int(sys.argv[1])
    print("Connecting to port ", port)
    cifar10.maybe_download_and_extract()
    if tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.DeleteRecursively(FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)
    total_start_time = time.time()
    train()
    print("--- %s seconds ---" % (time.time() - total_start_time))


if __name__ == '__main__':
    tf.app.run()

编辑:

这是train_part1()代码:
def train_part1(total_loss, global_step):
  """Train CIFAR-10 model.

  Create an optimizer and apply to all trainable variables. Add moving
  average for all trainable variables.

  Args:
    total_loss: Total loss from loss().
    global_step: Integer Variable counting the number of training steps
      processed.
  Returns:
    train_op: op for training.
  """
  # Variables that affect learning rate.
  num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size
  decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)

  # Decay the learning rate exponentially based on the number of steps.
  lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
                                  global_step,
                                  decay_steps,
                                  LEARNING_RATE_DECAY_FACTOR,
                                  staircase=True)
  tf.summary.scalar('learning_rate', lr)

  # Generate moving averages of all losses and associated summaries.
  loss_averages_op = _add_loss_summaries(total_loss)

  # Compute gradients.
  with tf.control_dependencies([loss_averages_op]):
    opt = tf.train.GradientDescentOptimizer(lr)
    grads = opt.compute_gradients(total_loss)

  return grads

最佳答案

在我看来,这条线

gradients, step_val = mon_sess.run(
                    [only_gradients, increment_global_step_op], feed_dict=feed_dict)

通过 feed_dict 接收变量的新值,把这些值赋给变量,并执行一次训练步骤——该步骤只计算并返回梯度,然后把梯度发送给参数服务器。我猜测 cifar10.train_part1(生成 only_gradients 的那段代码)依赖这些变量值,并在其中定义了更新操作。

更新:我仔细看了一下代码,改变了想法。我在 Google 上搜索后找到了另一个回答(next answer),它可以解释这里发生的事情。

在这段代码中,梯度实际上没有在任何地方被显式应用。相反,worker 把梯度发送给参数服务器;参数服务器对梯度求平均并将其应用到权重上,再把更新后的权重返回给本地 worker。在随后通过 feed_dict 进行的 session 运行期间,接收到的权重会覆盖本地权重——也就是说,本地权重从未被真正更新过,它们的值其实无关紧要。关键在于:feed_dict 允许覆盖 session 运行中的任何张量,而这段代码正是用它覆盖了变量的值。

关于sockets - 权重在此代码中更新的位置?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58757337/

相关文章:

tensorflow - 如何在spyder上使用tensorflow?

python - tensorflow 2.0,调用函数时给定两个变量,但是定义函数时,没有变量

machine-learning - 如何针对特定用例选择 CNN 模型?

image-processing - 具有反卷积或其他功能的高档层

android - 使用 Wireshark 从手机抓包

sockets - 是否可以将套接字内存映射到虚拟内存?

C# 线程在套接字上接收数据之前退出

c++ - 为什么我无法通过实际的公网 IP 连接到服务器?

python-3.x - "import keras"和 "import tensorflow.keras"有什么区别

python - 如何在 Tensorflow 中更新二维张量的子集?