go - 如何在 golang 中的文本上执行 DL - RNN 模型?

标签 go tensorflow nlp rnn

我已经基于 reddit/twitter 对话在 tensor-flow 中构建了 RNN 模型。我将它保存在 pb 中。有谁知道如何通过 golang 中的模型传递原始文本字符串并生成输出?

modeldir := "/my_model.pb"

// Buffer input text
var buffer bytes.Buffer

args := os.Args[1:]

for _, arg := range args {
    buffer.WriteString(arg + " ")
}

inputText := buffer.String()

// Load the serialized GraphDef from a file.

model, err := ioutil.ReadFile(modeldir)
if err != nil {
    log.Fatal(err)
}
// Construct an in-memory graph from the serialized form.
graph := tf.NewGraph()
if err := graph.Import(model, ""); err != nil {
    log.Fatal(err)
}
// Create a session for inference over graph.
session, err := tf.NewSession(graph, nil)
if err != nil {
    log.Fatal(err)
}
defer session.Close()

最佳答案

您可以使用 tfgo轻松加载到 Go 并使用经过训练的 tensorflow 模型:只需使用 tf.saved_model.builder.SavedModelBuilder 导出经过训练的模型,如 tfgo 自述文件中所示。

但是,您只需从图中提取输入占位符,然后使用它来馈送网络。

假设您导出模型并将其命名为 my_model 并使用标签 tag 对其进行标记。我们还假设您的输入占位符被命名为“Placeholder”。此外,您必须知道输出节点的名称。我们称它为 output/node/path/op。那么你的代码应该是这样的:

import (
        "fmt"
        tg "github.com/galeone/tfgo"
        tf "github.com/tensorflow/tensorflow/tensorflow/go"
        "flags"
)

func main() {
        model := tg.LoadModel("my_model", []string{"tag"}, nil)

        // Buffer input text
        var buffer bytes.Buffer
        args := os.Args[1:]

        for _, arg := range args {
            buffer.WriteString(arg + " ")
        }
        // handle the retunred error below, if any
        inputText, _ := tf.NewTensor(buffer.String())

        results := model.Exec([]tf.Output{
                model.Op("output/node/path/op", 0),
        }, map[tf.Output]*tf.Tensor{
                model.Op("Placeholder", 0): inputText,
        })
        // do something with results[0]
}

关于go - 如何在 golang 中的文本上执行 DL - RNN 模型?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47142976/

相关文章:

c++ - Tensorflow C++ 内存泄漏 - Valgrind

python - NLTK/pyNLTK 可以工作 "per language"(即非英语),如何工作?

用于标记英文文本的正则表达式

go - 什么是 reflect.Value 在 golang 中的零值

go - 为什么范围内没有帖子?戈朗

go - Kafka 生产者不通过分区分发消息

python - 在 pandas DataFrame 列中存储列表

go - "value semantics’ "和 "pointer semantics"在 Go 中是什么意思?

python - 使用 WALS 方法在 tensorflow 2.0 中进行矩阵分解

python - 使用 TensorFlow eager execution 嵌入可视化