python - 使用 TorchScript 类作为 pytorch 模块中的成员

标签 python pytorch torchscript

我试图让一些现有的 pytorch 模型支持 TorchScript jit 编译器,但我遇到了非原始类型成员的问题。

这个小例子说明了这个问题:

import torch

@torch.jit.script
class Factory(object):
    def __init__(self):
        pass

    def create(self, x: float) -> torch.Tensor:
        return torch.tensor([x])

class Foo(torch.nn.Module):
    def __init__(self):
        super(Foo, self).__init__()
        self.factory: Factory = Factory()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.factory.create(0)

mod = torch.jit.script(Foo())

运行时,jit编译报错

RuntimeError:
module has no attribute 'factory':
at example.py:17:15
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.factory.create(0)
               ~~~~~~~~~~~~ <--- HERE

我已经测试过 Factory 类可用于 forward 方法内的 jit,但是当我将它存储为成员时它不承认它。为什么是这样?有什么办法可以让 jit 编译器将这种成员保存到编译后的模块中?

最佳答案

这是 PyTorch 中的一个错误,在您发布问题后很快就解决了:https://discuss.pytorch.org/t/jit-scripted-attributes-inside-module/60645 , https://github.com/pytorch/pytorch/issues/27495 .

更新 PyTorch 应该可以解决这个问题。

关于python - 使用 TorchScript 类作为 pytorch 模块中的成员,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58998441/

相关文章:

Python 相同的字符不等于

python - 42000 您的 SQL 语法有误;查看与您的 MySQL 服务器对应的手册

Python split 无法识别连字符

android - 如何利用 GPU 在 Android 上运行神经网络模型?

javascript - 如何将 javascript 源文件插入到我的 Pyramid python 应用程序中并在我的模板中使用它们?

python-3.x - 类型错误 : 'module' object is not callable error?

parameters - 使用pytorch的rnn中隐藏维度和n_layers之间的区别

python - 按照 numpy.ndarray 的顺序就地打乱 torch.Tensor

pytorch - 将 PyTorch 模型转换为 TorchScript 时出错