python - 模型包含多个头的 TorchScript

标签 python torch

我的目标是序列化经过 pytorch 训练的模型,并将其加载到定义神经网络的原始类不可用的环境中。为实现这一目标,我决定使用 TorchScript,因为它似乎是唯一可行的方法。

我有一个多任务模型(类型 nn.Module),使用每个任务通用的主体(也是 nn.Module,只是几个线性层)和一组线性头部模型,每个任务一个。 我将头部模型存储在名为 _task_head_models 的字典 Dict[int, nn.Module] 中,并在我的模块类中创建了一个临时转发方法来选择正确的头部在预测时间:

    def forward(self, x: torch.Tensor, task_id: int = 0) -> torch.Tensor:
        if task_id not in self._task_head_models.keys():
            raise ValueError(
                f"The task id {task_id} is not valid. Valid task ids are {self._task_head_models.keys()}."
            )

        return self._task_head_models[task_id](self._model(x))

在我不尝试使用 torchscript 序列化它之前,它工作正常。 当我尝试 torch.jit.script(mymodule) 时,我得到:

Module 'MyModule' has no attribute '_task_head_models' (This attribute exists on the Python module, but we failed to convert Python type: 'dict' to a TorchScript type. Cannot infer concrete type of torch.nn.Module. Its type was inferred; try adding a type annotation for the attribute.)

似乎不对劲的是,我的模块包含一个 Dict,而不是错误消息中提到的 dict。暂时忘记这一点,目前还不清楚为什么会这样。语言引用中似乎支持字典:https://docs.w3cub.com/pytorch/jit_language_reference.html

我还尝试使用 ModuleDict 而不是 Dict(将键类型更改为 str),但这似乎也不起作用:无法提取字符串文字索引。 ModuleDict 索引仅支持字符串文字。支持 ModuleDict 的枚举,例如'for k, v in self.items(): ...':

最佳答案

如果Dict_task_head_models中的项目不多,我想使用if-else分支可以帮到你。示例代码如下:

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._task_head0 = torch.nn.Linear(3, 24)
        self._task_head1 = torch.nn.Linear(3, 24)

    def forward(self, x: torch.Tensor, task_id: int = 0) -> torch.Tensor:
      if task_id == 0:
          return self._task_head0(x)
      elif task_id == 1:
          return self._task_head1(x)
      else:
          raise ValueError(
                f"The task id {task_id} is not valid. Valid task ids are 0,1."
            )

关于python - 模型包含多个头的 TorchScript,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/73290107/

相关文章:

python - subprocess.run() 不返回标准输出或标准错误

Lua 命令行字符串

python - torch 聚中维

python - 如何消除 python 上的 ' the program is still running are you sure you want to kill it' 警告?

python - 尝试在列表 python 的每个值之间添加逗号

lua - 在Torch/Lua中,加载保存的模型和使用Xavier权重初始化方法有什么区别?

lua - 清除变量以释放 Lua/Torch 中的内存(GPU 或 CPU)

python - 是否可以在 TensorFlow 上加载学习模型 (.t7)?

返回属性而不是字符串(unicode)的 Python 字符串格式?

python - 将用 python 训练的 XGBoost 模型移植到另一个用 C/C++ 编写的系统