python - 如何从 model.predict() 结果中获取单个值

标签 python tensorflow machine-learning neural-network keras

我正在尝试使用神经网络来预测在另一个文件中运行的汽车模拟器游戏的 Action 。我需要获得传递到游戏中的 Action 的预测值,但我正在努力做到这一点。调用 model.predict 后,我​​尝试像从数组中一样访问该值,但这会返回越界错误。

总的来说,我对 python 相当陌生,更不用说使用 keras,但我的想法是,神经网络将根据玩游戏的其他人收集的一些数据(保存在 CSV 文件中)进行训练。然后,当神经网络开始玩游戏时,我将能够在每一帧传递游戏值以生成预测 Action 。我认为我已经做到了,但无法获得预测的操作。这是神经网络;

  def build_nn(self):
    model = Sequential()
    model.add(Dense(12, input_dim=8, activation='relu')) 
    model.add(Dense(8, activation='relu'))
    model.add(Dense(1, activation='linear'))
    model.compile(loss='mean_squared_error', optimizer=Adam(lr=self.learning_rate))

    return model

以及我用于预测 Action 的代码(stateVector 未被按原样接受,因此我必须从中获取值)

 def action(self, stateVector):
    a_action = stateVector['lastAction']
    currentLane = stateVector['currentLane']
    offRoad = stateVector['offRoad']
    collision = stateVector['collision']
    lane1 = stateVector['lane1Distance']
    lane2 = stateVector['lane2Distance']
    lane3 = stateVector['lane3Distance']
    a = [currentLane, offRoad, collision, lane1, lane2, lane3, reward, a_action]   

    act_value = self.model.predict(a)
    act = act_values[a_action]

    return act 

最佳答案

当网络需要可迭代的 8 维样本时,您正在向网络提供 8 个元素的列表。实际上:

>>> a = [currentLane, offRoad, collision, lane1, lane2, lane3, reward, a_action]
>>> a = np.array(a) # convert to a numpy array
>>> a = np.expand_dims(a, 0) # change shape from (8,) to (1,8)
>>> model.predict(a) # voila!

关于python - 如何从 model.predict() 结果中获取单个值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49239478/

相关文章:

python - 如何从 Python 在 Odoo-8 中执行查询?

python - Tensorflow 宽深模型,具有不同数据集的 AttributeError

Tensorflow - ValueError : Parent directory of trained_variables. ckpt 不存在,无法保存

python - 在 tensorflow 中初始化矩阵

python - 如何从控制面板的程序文件中找不到的 windows 中删除 python

Python:在列表中嵌入的字典中搜索

machine-learning - 如何针对噪声(分散)数据选择回归算法?

python - 如何在sklearn中使用一种热编码处理 "unseen"分类变量

python 3.4 和 mysql 信号

c++ - fatal error : google/protobuf/port_def. inc : No such file or directory #include <google/protobuf/port_def. inc>