python - 神经网络之后的[np.arange(0, self.batch_size), action]的目的是什么?

标签 python pytorch reinforcement-learning

我按照 PyTorch 教程学习强化学习( TRAIN A MARIO-PLAYING RL AGENT ),但我对以下代码感到困惑:

current_Q = self.net(state, model="online")[np.arange(0, self.batch_size), action] # Q_online(s,a)

神经网络之后的[np.arange(0, self.batch_size), action]的目的是什么?(我知道TD_estimate接受状态和 Action ,只是在编程方面对此感到困惑)这个用法是什么(在 self.net 之后放一个列表)?

教程中引用的更多相关代码:

class MarioNet(nn.Module):

def __init__(self, input_dim, output_dim):
    super().__init__()
    c, h, w = input_dim

    if h != 84:
        raise ValueError(f"Expecting input height: 84, got: {h}")
    if w != 84:
        raise ValueError(f"Expecting input width: 84, got: {w}")

    self.online = nn.Sequential(
        nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
        nn.ReLU(),
        nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
        nn.ReLU(),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(3136, 512),
        nn.ReLU(),
        nn.Linear(512, output_dim),
    )

    self.target = copy.deepcopy(self.online)

    # Q_target parameters are frozen.
    for p in self.target.parameters():
        p.requires_grad = False

def forward(self, input, model):
    if model == "online":
        return self.online(input)
    elif model == "target":
        return self.target(input)

self 网:

self.net = MarioNet(self.state_dim, self.action_dim).float()

感谢您的帮助!

最佳答案

本质上,这里发生的是网络的输出被切片以获得 Q 表的所需部分。

[np.arange(0, self.batch_size), action] 的索引(有点令人困惑)对每个轴进行索引。因此,对于索引为 1 的轴,我们选择由 action 指示的项目。对于索引 0,我们选择 0 到 self.batch_size 之间的所有项目。

如果self.batch_size与该数组第0维的长度相同,则该切片可以简化为[:, action],这可能会更多大多数用户都熟悉。

关于python - 神经网络之后的[np.arange(0, self.batch_size), action]的目的是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/70458347/

相关文章:

python - pytorch如何计算矩阵成对距离?为什么 'self' 距离不为零?

python - Pytorch .to ('cuda' ) 或 .cuda() 不起作用并且卡住了

artificial-intelligence - 神经网络和时间差分学习

c++ - 使用 softmax 进行 Action 选择?

python - 在 python 中启动和暂停进程

python - 无法正确导入聊天机器人及其语料库

python - PyTorch:预期输入batch_size (12) 匹配目标batch_size (64)

python - 在 OSX 10.11 (El Capitan) (系统完整性保护) 中安装 Scrapy 时出现 "OSError: [Errno 1] Operation not permitted"

python - Emacs:使用 pdbtrack (python.el)

python - 如何让这段RL代码获得GPU支持?