python - 如何使用 MNIST 在 tensorflow 中减少权重位或更改为较低位类型?

标签 python tensorflow training-data mnist 8-bit

我正在做 CNN 模型压缩,并试图减少权重的位以获得位的长度和精度之间的关系。但是当我使用 Tensorflow 网站的方法更改 CNN 的权重类型时,出现错误:

“TypeError:传递给参数 'a' 的值具有不在允许值列表中的数据类型 int8:float16、float32、float64、int32、complex64、complex128”。

看来weight的不能是其他Dtype。但是我读了一些类似的论文https://arxiv.org/pdf/1502.02551.pdf .可以将权重的位数减少到 6bits , 4bits ,甚至更低的位。

我的代码在这里(忽略 import somethings):

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
x = tf.placeholder(tf.int8,[None,784])
W = tf.Variable(tf.zeros([784,10]),tf.int8)
b = tf.Variable(tf.zeros([10]),tf.int8)

y = tf.nn.softmax(tf.matmul(x,W)+b)
#the error come out with "y = tf.nn.softmax(tf.matmul(x,W)+b)"

这只是一个标准的tensorflow官方代码,只是改变了变量的Dtype。我也试过 tf.cast ,但它仍然出现错误。

tf.cast(W,tf.int8)
tf.cast(b,tf.int8)

谁能告诉我如何克服这种情况?非常感谢!!

最佳答案

Tensorflow 不允许小于 16 位的数字。无论如何,少于 16 位是不切实际的,因为假设您使用 8 位(整数 4 位,小数 4 位),可能的最低十进制数将是 0.0625(1/16 -> 4 位只有 16 个不同的数字)。

您引用的论文使用 16 位数字,16 位分为 8-8、10-6 和 14-2 位(第一个是十进制位)。它还对变量进行四舍五入,然后将它们转换为上述位分布,而不是直接截断变量而不进行四舍五入。


更新:我做了一点挖掘,如果您使用 float16 或 14-2 固定位,它实际上没有太大区别-分布:

  • 最低的 float16 十进制数:0.0000610352
  • a 或 14-2 固定位分布的最低十进制数:0.00006103515625

So I would suggest, you just use float16 instad of fixed bit-distributions, and just use the stochastic rounding algorithm, described in the paper.


更新 2:我使用 float32float16 训练了 MNIST 数据集。 float16 网络的性能几乎与 float32 网络相同。该网络有两个隐藏层,每层有 1000 个神经元,tf.nn.relu 作为激活函数。我使用了学习率为 0.1 的标准 tensorflow tf.train.GradientDescentOptimizer 优化器。作为成本函数,我使用了 tf.nn.softmax_cross_entropy_with_logits。该网络训练了 120 多个时期,每 600 步,批量大小为 100。float16 网络的测试精度为 98.189997673,而 float32测试精度为 98.1599986553

一些有趣的链接:

更新 3: 我认为在 tensorflow 中实现混合精度会很困难,因为您必须为反向传播编写自定义处理。 TensorFlow 团队已经在致力于实现原生半精度。同时,我认为实现这一点的最佳方法是使用 caffe,其中已经实现了 native 混合精度(至少在 nvidia branch 中)。参见 this ticket .

关于python - 如何使用 MNIST 在 tensorflow 中减少权重位或更改为较低位类型?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47994597/

相关文章:

python - 读取文件后不关闭文件有什么缺点吗?

python - 余弦波的贝叶斯拟合花费的时间比预期的要长

queue - 使用队列时如何在 tensorflow 训练期间测试网络

machine-learning - 构建模型,在两个类别之间的差异过大时做出决策

python - 比较 2 个整数列表

python - Pandas - 将字符串转换为没有日期的时间

java - 如何从 Java 在 TensorFlow 模型中提供稀疏占位符

TensorFlow:具有非图像输入的卷积神经网络

python - 如何确定Rank-3输入张量的权重维度?

machine-learning - 训练 SyntaxNet 需要多少数据?