python - PyTorch 不能 pickle lambda

标签 python lambda pytorch pickle

我有一个使用自定义 LambdaLayer 的模型,如下所示:

class LambdaLayer(LightningModule):
    def __init__(self, fun):
        super(LambdaLayer, self).__init__()
        self.fun = fun

    def forward(self, x):
        return self.fun(x)


class TorchCatEmbedding(LightningModule):
    def __init__(self, start, end):
        super(TorchCatEmbedding, self).__init__()
        self.lb = LambdaLayer(lambda x: x[:, start:end])
        self.embedding = torch.nn.Embedding(50, 5)

    def forward(self, inputs):
        o = self.lb(inputs).to(torch.int32)
        o = self.embedding(o)
        return o.squeeze()

该模型在 CPU 或 1 个 GPU 上运行完美。但是,当使用 PyTorch Lightning 在 2+ GPU 上运行它时,会发生此错误:

AttributeError: Can't pickle local object 'TorchCatEmbedding.__init__.<locals>.<lambda>'

这里使用 lambda 函数的目的是给定一个 inputs 张量,我只想将 inputs[:, start:end] 传递给 嵌入层。

我的问题:

  • 在这种情况下是否有替代方法来使用 lambda?
  • 如果不是,应该怎么做才能让 lambda 函数在这种情况下工作?

最佳答案

所以问题不在于 lambda 函数本身,而是 pickle 不适用于不仅仅是模块级函数的函数(pickle 处理函数的方式就像对某些模块级名称的引用) .所以,不幸的是,如果你需要捕获 startend 参数,你将无法使用闭包,你通常只需要像这样的东西:

def function_maker(start, end):
    def function(x):
        return x[:, start:end]
    return function

但是就 pickle 问题而言,这会让您回到起点。

所以,尝试这样的事情:

class Slicer:
    def __init__(self, start, end):
        self.start = start
        self.end = end
    def __call__(self, x):
        return x[:, self.start:self.end])

然后你可以使用:

LambdaLayer(Slicer(start, end))

我不熟悉 PyTorch,我很惊讶它不提供使用不同序列化后端的能力。悲伤/dill例如,project 可以 pickle 任意函数,而且通常更容易使用它。但我相信以上应该可以解决问题。

关于python - PyTorch 不能 pickle lambda,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/70608810/

相关文章:

python - PyTorch - normal() 初始化对梯度的影响

Python 从单个字符串列表创建浮点列表列表

python - 如何使用按钮将sqlite Flask后端的数据显示到前端?

c++ - 使用 lambda 作为 GLFWkeyfun

mvvm - RelayCommand lambda 语法问题

python - PyTorch 的张量是如何实现的?

根据子文本节点从大 XML 中提取子 XML 的 Java 或 Python 方法

python - 如何使用 openpyxl 将单元格中的文本对齐到顶部?

ruby - 为什么 STDOUT 在 Ruby 中只显示一次返回的消息?

python - 我怎样才能解决向后()得到一个意外的关键字参数 'retain_variables' ?