python - python中的联合平均实现

标签 python machine-learning artificial-intelligence conv-neural-network tensorflow-federated

我正在研究联邦学习。我正在使用一个全局服务器,我在其中定义了一个基于 cnn 的分类器。全局服务器使用超参数编译模型并将其发送到边缘(客户端),目前我使用两个客户端。每个客户端都使用其本地数据(现在我在每个客户端上使用相同的数据和模型)。训练模型后,每个客户的本地模型的准确度、精确度和召回率均超过 95%。客户端将经过训练的本地模型发送到服务器。服务器获取模型,并从每个接收到的模型中获取权重,并根据 this formula 计算平均值。 。下面是我为在 python 中实现这个公式而编写的代码。当我为模型设置平均权重并尝试进行预测时,准确率、召回率和精确率均降至 20% 以下。

我在实现过程中做错了什么吗?

# initial weights of global model, set to zer0.  
  ave_weights=model.get_weights()
  ave_weights=[i * 0 for i in ave_weights]
  count=0
# Multithreaded Python server : TCP Server Socket Thread Pool
def ClientThread_send(conn,address,weights):
    # send model to client
    conn.send(model)

    print("Model Sent to :",address)
    print("waiting for weights")
    model_recv=conn.recv(1024)
    print("weights received from:",address)
    global count
    global ave_weights

    
    #receive weights from clients
    rec_weight=model.get_weights()
    #multiply the client weights by number of local data samples in client local data
    rec_weight=  [i * 100000 for i in rec_weight]
    # divide the weights by total number of samples of all participants
    rec_weight=  [i / 200000 for i in rec_weight]

    #sum the weights of all clients
    ave_weights=[x + y for x, y in zip(ave_weights,rec_weight)]
  
    count=count+1
    conn.close()
if count==2:
    # set the global model weights if the count(number of clients is two)
    model.set_weights(ave_weights)


 while True:
     conn, address = s.accept()
     start_new_thread(ClientThread_send,(conn,address,ave_weights))   
     

最佳答案

我认为问题可能出在训练步骤而不是“平均”算法上。

根据提出 FedAvg 算法的论文( https://arxiv.org/pdf/1602.05629.pdf ),局部模型将随机梯度下降应用于全局模型,而不是从头开始训练新的局部模型。

这里有来自 TensorFlow 的应用联合平均的教程:https://www.tensorflow.org/federated/tutorials/custom_federated_algorithms_2#gradient_descent_on_a_single_batch

关于python - python中的联合平均实现,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66472157/

相关文章:

machine-learning - 将同一物体的多个图像输入神经网络进行物体检测的方法

machine-learning - 逻辑回归中的特征范围

python-3.x - 用python进行图像膨胀

python - 从模块访问全局变量

python - 如何解决文件上传时python webapp2中的 "AttributeError: ' unicode'对象没有属性 'file'“一劳永逸?

artificial-intelligence - 反向传播问题

matlab 在 3D 散点图上绘制线性回归

artificial-intelligence - 进化算法 'approaches' 之间的主要区别是什么?

python - django-extra-views 获取当前用户

python - 在 Python 中,是否可以访问包含方法的类,只给定一个方法对象?