python - Tensorflow:如何修改张量中的值

标签 python numpy tensorflow

由于在使用Tensorflow训练模型之前需要对数据进行一些预处理,因此需要对张量进行一些修改。但是,我不知道如何像使用 numpy 那样修改 tensor 中的值。

这样做的最佳方式是它能够直接修改tensor。然而,在当前版本的 Tensorflow 中似乎不可能。另一种方法是将 tensor 更改为 ndarray 进行处理,然后使用 tf.convert_to_tensor 改回来。

关键是如何把tensor变成ndarray
1)tf.contrib.util.make_ndarray(张量): https://www.tensorflow.org/versions/r0.8/api_docs/python/contrib.util.html#make_ndarray
根据文档,这似乎是最简单的方法,但我在当前版本的 Tensorflow 中找不到此功能。其次,它的输入是TensorProto而不是tensor
2) 使用a.eval()a复制到另一个ndarray
然而,它仅适用于在笔记本中使用 tf.InteractiveSession()

一个带有代码的简单案例如下所示。这段代码的目的是使 tfc 在处理后具有与 npc 相同的输出。

提示
您应该将tfcnpc 视为相互独立的。这满足了最初检索到的训练数据为 tensor 格式且 tf.placeholder() 的情况。 .


源代码

import numpy as np
import tensorflow as tf
tf.InteractiveSession()

tfc = tf.constant([[1.,2.],[3.,4.]])
npc = np.array([[1.,2.],[3.,4.]])
row = np.array([[.1,.2]])
print('tfc:\n', tfc.eval())
print('npc:\n', npc)
for i in range(2):
    for j in range(2):
        npc[i,j] += row[0,j]

print('modified tfc:\n', tfc.eval())
print('modified npc:\n', npc)

输出:

转会:
[[ 1. 2.]
[ 3. 4.]]
全国人民代表大会:
[[ 1. 2.]
[ 3. 4.]]
修改后的 tfc:
[[ 1. 2.]
[ 3. 4.]]
修改过的npc:
[[ 1.1 2.2]
[ 3.1 4.2]]

最佳答案

使用分配和评估(或 sess.run)分配:

import numpy as np
import tensorflow as tf

npc = np.array([[1.,2.],[3.,4.]])
tfc = tf.Variable(npc) # Use variable 

row = np.array([[.1,.2]])

with tf.Session() as sess:   
    tf.initialize_all_variables().run() # need to initialize all variables

    print('tfc:\n', tfc.eval())
    print('npc:\n', npc)
    for i in range(2):
        for j in range(2):
            npc[i,j] += row[0,j]
    tfc.assign(npc).eval() # assign_sub/assign_add is also available.
    print('modified tfc:\n', tfc.eval())
    print('modified npc:\n', npc)

输出:

tfc:
 [[ 1.  2.]
 [ 3.  4.]]
npc:
 [[ 1.  2.]
 [ 3.  4.]]
modified tfc:
 [[ 1.1  2.2]
 [ 3.1  4.2]]
modified npc:
 [[ 1.1  2.2]
 [ 3.1  4.2]]

关于python - Tensorflow:如何修改张量中的值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37071788/

相关文章:

python - 根据字符位置从字符串中提取子字符串 - Python

python - NoReverseMatch 位于/blog/Django1.9 URL

Python LDAP 模块 --- 我必须始终运行 "whoami_s()"才能成功进行身份验证吗?

python - 我应该如何修改SVM方法的测试数据才能正确使用 `precomputed`核函数?

python循环不断扩展的列表

python - Tensorflow ctc_loss_calculator : No valid path found

python - 如何使用 TF1.3 中的新数据集 api 映射具有附加参数的函数?

python - 如何 "draw"二维网格上的几何形状?

python - 在 sklearn NearestNeighbor 中使用每个邻居一次

tensorflow - sample 无需更换