python - 如何在 Tensorflow Fedarated 中加载 Fashion MNIST 数据集?

标签 python tensorflow keras tensorflow-federated federated-learning

我正在使用 Tensorflow federated 开展一个项目。我已经设法使用 TensorFlow 联邦学习模拟提供的库来加载、训练和测试一些数据集。

比如我加载emnist数据集

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

它获得了 load_data() 返回的数据集作为 tff.simulation.ClientData 的实例。这是一个允许我迭代客户端 ID 并允许我选择数据子集进行模拟的接口(interface)。

len(emnist_train.client_ids)

3383


emnist_train.element_type_structure


OrderedDict([('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None)), ('label', TensorSpec(shape=(), dtype=tf.int32, name=None))])


example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

我正在尝试使用 Keras 加载 fashion_mnist 数据集以执行一些联合操作:

fashion_train,fashion_test=tf.keras.datasets.fashion_mnist.load_data()

但是我得到了这个错误

AttributeError: 'tuple' object has no attribute 'element_spec'

因为 Keras 返回一个 Numpy 数组元组而不是像以前那样的 tff.simulation.ClientData:

def tff_model_fn() -> tff.learning.Model:
    return tff.learning.from_keras_model(
        keras_model=factory.retrieve_model(True),
        input_spec=fashion_test.element_spec,
        loss=loss_builder(),
        metrics=metrics_builder())

iterative_process = tff.learning.build_federated_averaging_process(
    tff_model_fn, Parameters.server_adam_optimizer_fn, Parameters.client_adam_optimizer_fn)
server_state = iterative_process.initialize()

总结一下,

  1. 有什么方法可以从 Keras Tuple Numpy 数组创建 tff.simulation.ClientData 的元组元素?

  2. 我想到的另一个解决方案是使用 tff.simulation.HDF5ClientData 并加载 以 HDF5(train.h5, test.h5) 格式手动选择适当的文件以获得 tff.simulation.ClientData,但是我的问题是我找不到 fashion_mnist HDF5 文件格式的 url 我的意思是对于训练和测试都是这样的:

          fileprefix = 'fed_emnist_digitsonly'
          sha256 = '55333deb8546765427c385710ca5e7301e16f4ed8b60c1dc5ae224b42bd5b14b'
          filename = fileprefix + '.tar.bz2'
          path = tf.keras.utils.get_file(
              filename,
              origin='https://storage.googleapis.com/tff-datasets-public/' + filename,
              file_hash=sha256,
              hash_algorithm='sha256',
              extract=True,
              archive_format='tar',
              cache_dir=cache_dir)
    
          dir_path = os.path.dirname(path)
          train_client_data = hdf5_client_data.HDF5ClientData(
              os.path.join(dir_path, fileprefix + '_train.h5'))
          test_client_data = hdf5_client_data.HDF5ClientData(
              os.path.join(dir_path, fileprefix + '_test.h5'))
    
          return train_client_data, test_client_data
    

我的最终目标是让 fashion_mnist 数据集与 TensorFlow 联邦学习一起工作。

最佳答案

您走在正确的轨道上。回顾一下:tff.simulation.dataset 返回的数据集API 是 tff.simulation.ClientData对象。 tf.keras.datasets.fashion_mnist.load_data 返回的对象是 numpy 数组的 元组

因此需要实现一个tff.simulation.ClientData来包装tf.keras.datasets.fashion_mnist.load_data返回的数据集。之前关于实现ClientData对象的一些问题:

这确实需要回答一个重要问题:Fashion MNIST 数据应如何拆分为单个用户?数据集不包含可用于分区的特征。研究人员提出了几种综合划分数据的方法,例如为每个参与者随机抽取一些标签,但这对模型训练有很大的影响,在这里投入一些思考是有用的。

关于python - 如何在 Tensorflow Fedarated 中加载 Fashion MNIST 数据集?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64760396/

相关文章:

python - tf.keras.losses.categorical_crossentropy 是返回数组还是单个值?

python - Remote_api_stub 的路径问题

python - 用于训练/验证/测试集拆分的 SHA 哈希

tensorflow - 在 tensorflow 检查点中修改张量的形状

python - 使用 tflearn、tensorflow、numpy 的 Python 聊天机器人出现错误

python - 模型错误 : Layer model_1 expects 1 input(s), 但它收到了 2 个输入张量

tensorflow - 如何在tensorflow keras中访问自定义层的递归层

python - 检查 stdout 是否支持 unicode?

python - 在 python boto 中使用 cognito 获取 AWS 凭证

python - 如何将元组作为参数传递给 `divmod()` 函数