python - 更新已使用 theano.tensor.cast() 转换的变量

标签 python theano

我正在尝试更新函数中的 theano 变量,简化如下:

copy_func = theano.function(
    inputs=[idx],
    updates=[
        (a_variable, T.set_subtensor(a_variable[some_ptr], another_variable[idx]))
    ]
)

我的问题是我收到错误

TypeError: ('update target must be a SharedVariable', Elemwise{Cast{int32}}.0)

我获取此变量的方式是通过使用以下内容(主要是从 deeplearning.net 教程复制)(another_variable 的初始化类似):

a_variable = theano.shared(np.asarray(data,
                               dtype=theano.config.floatX),
                 borrow=True)
print type(a_variable)
a_variable = T.cast(a_variable, 'int32')
print type(a_variable)

打印内容

<class 'theano.tensor.sharedvar.TensorSharedVariable'>
<class 'theano.tensor.var.TensorVariable'>

即变量不再“共享”,解释了错误。 这是有道理的,因为我猜想该变量现在只是原始共享 float 的转换 View 。但是如何更新有效转换的变量呢?

最佳答案

我自己解决了这个问题,答案当然是显而易见的。

我没有用强制转换的版本覆盖 a_variable 变量,而是保留了未强制转换的版本:

a_variable_casted = T.cast(a_variable, 'int32')

现在对 a_variable 进行更新,而 a_variable_casted 用于执行之前使用 a_variable 进行的计算。

显然可能有一种更优雅的方式来做到这一点,在这种情况下我很想听听!

关于python - 更新已使用 theano.tensor.cast() 转换的变量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/29255940/

相关文章:

python - AngularJS + Flask 客户端路由

python - 使用 Python 在给定日期间隔列表的情况下查找日期子间隔的值

python - arcpy:程序在相交时崩溃

python: theano导入错误 "MKL_THREADING_LAYER=GNU"

python - Lasagne/Theano 维数错误

neural-network - 不使用 GPU 的 Keras 神经网络模型预测

python - 在单个循环中追加列表的 Pithonic 方法

python - 如何在没有内置方法的情况下在Python中查找关键字之前和之后的单词

python - numpy 属性错误 : with theano module 'numpy.core.multiarray' has no attribute _get_ndarray_c_version

python - Theano/numpy 高级索引