python - Tensorflow - 显示和手动修改学习模型的权重并导出以供进一步重新学习

标签 python machine-learning tensorflow

我尝试使用 Tensorflow 执行的操作如下:

  1. 假设我已经学习了神经网络文件:checkpoint、*.meta、*.data 和 *.index。
  2. 我想提取学习值(权重、偏差等)以显示或处理到文件/其他工具以进行进一步分析。
  3. 我想修改一些学习值(例如用 0 替换一些已经很小的权重,以简化计算)。
  4. 修改后的值应加载回模型。
  5. 因此,我希望获得相同的检查点集、*.meta、*.data 和 *.index 文件,但具有一些修改后的值(来自第 4 步)。

注意:用于生成初始模型的脚本未知。我在第 1 步中只有 4 个列出的文件。

到目前为止我所做的是提取图形定义并显示学习值(使用inspect_checkpoint.py)。我发现不可能更改模型上的值并将其导出回 *.data、*.meta、*.index 和检查点集。在查看 API 之后,我没有看到用于此类操作的明显工具。有可能吗? 谨致问候并感谢您的支持!

最佳答案

在 C++ 中,您可以使用 CheckpointReaderBundleWriter 从检查点文件读取/写入张量:

BundleWriter writer(tensorflow::Env::Default(), "out.ckpt");                                                                                                                                        

TF_Status status;                                                                                                                                                                                        
tensorflow::checkpoint::CheckpointReader reader("in.ckpt", &status);

const auto& var_to_shape_map = reader.GetVariableToShapeMap();                                                                                                                                                                                                                                                                                 
for (const auto& elem : var_to_shape_map) {                                                                                                                                                              
  std::unique_ptr<Tensor> weights;                                                                                                                                                                       
  const string& key = elem.first;                                                                                                                                                                        
  reader.GetTensor(key, &weights, &status);  
  auto weights_flat = weights->flat<float>();
  for (int i = 0; i < weights->NumElements(); ++i) {
    // replace with 0 some weights that are already of small value
    if (weights_flat(i) < SMALL_VALUE_THRESHOLD) {
      weights_flat(i) = 0.f;
    }
  }
  writer.Add(key, *weights.get());                                                                                                                                                        
}
writer.Finish();

运行上述代码后,您将得到out.ckpt.dataout.ckpt.index。 您可以使用原始的 *meta 文件,因为我们只修改了学习权重的值,元信息保持不变。

关于python - Tensorflow - 显示和手动修改学习模型的权重并导出以供进一步重新学习,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46687438/

相关文章:

python - 为什么 numpy ma.average 比 arr.mean 慢 24 倍?

python - Flask:在另一个文件夹中提供模板( 'templates/' 目录除外)

python - MNIST Python numpy 特征向量可视化错误

Tensorflow 的 while 循环在 GPU 上运行缓慢?

python - 检查文件是否是python中的命名管道(fifo)?

python - 为函数设置python递归限制

python - ValueError : Expected 2D array, 得到标量数组而不是 : array=5. 5

python - OpenCV 抛出错误。尝试使用随机森林模型

tensorflow - 无法在 TensorFlow 2 中加载模型权重

python-3.x - 从 Tensorflow Keras 检查点重新加载最佳权重