keras - 如何将 decode_batch_predictions() 方法添加到 Keras Captcha OCR 模型中?

标签 keras ocr decoding ctc

当前Keras Captcha OCR model返回一个 CTC 编码的输出,需要在推理后解码。
要对此进行解码,需要在推理后作为单独的步骤运行解码效用函数。

preds = prediction_model.predict(batch_images)
pred_texts = decode_batch_predictions(preds)
解码的效用函数使用 keras.backend.ctc_decode ,它反过来使用贪婪或波束搜索解码器。
# A utility function to decode the output of the network
def decode_batch_predictions(pred):
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    # Use greedy search. For complex tasks, you can use beam search
    results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
        :, :max_length
    ]
    # Iterate over the results and get back the text
    output_text = []
    for res in results:
        res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
        output_text.append(res)
    return output_text
我想使用 Keras 训练 Captcha OCR 模型,该模型返回解码为输出的 CTC,推理后不需要额外的解码步骤。
我将如何实现这一目标?

最佳答案

你的问题可以有两种解释。一个是:我想要一个神经网络来解决一个问题,即 CTC 解码步骤已经在网络学到的东西中。另一个是您希望有一个 Model 类在其中执行此 CTC 解码,而不使用外部功能函数。
我不知道第一个问题的答案。我什至无法判断它是否可行。在任何情况下,这听起来都是一个困难的理论问题,如果您在这里没有运气,您可能想尝试将其发布在 datascience.stackexchange.com 中。 ,这是一个更加以理论为导向的社区。
现在,如果您要解决的是问题的第二个工程版本,那么我可以为您提供帮助。该问题的解决方案如下:
您需要子类化 keras.models.Model使用带有您想要的方法的类。我浏览了您发布的链接中的教程,并提供了以下类(class):

class ModifiedModel(keras.models.Model):
    
    # A utility function to decode the output of the network
    def decode_batch_predictions(self, pred):
        input_len = np.ones(pred.shape[0]) * pred.shape[1]
        # Use greedy search. For complex tasks, you can use beam search
        results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
            :, :max_length
        ]
        # Iterate over the results and get back the text
        output_text = []
        for res in results:
            res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
            output_text.append(res)
        return output_text

    
    def predict_texts(self, batch_images):
        preds = self.predict(batch_images)
        return self.decode_batch_predictions(preds)
你可以给它取你想要的名字,这只是为了说明目的。
定义此类后,您将替换该行
# Get the prediction model by extracting layers till the output layer
prediction_model = keras.models.Model(
    model.get_layer(name="image").input, model.get_layer(name="dense2").output
)

prediction_model = ModifiedModel(
    model.get_layer(name="image").input, model.get_layer(name="dense2").output
)
然后你可以替换线条
preds = prediction_model.predict(batch_images)
pred_texts = decode_batch_predictions(preds)

pred_texts = prediction_model.predict_texts(batch_images)

关于keras - 如何将 decode_batch_predictions() 方法添加到 Keras Captcha OCR 模型中?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67068303/

相关文章:

python - 如何在 Python 中使用 ASN1 库解码 .BER

email - 如何解码以Base64文本形式接收的电子邮件附件

java - 如何在 Java 中逐步解码大型多字节字符串文件?

python-3.x - 将 TensorFlow 设置为 Keras 中的 session 时出现问题

keras - 使用 model.fit_generator 时如何获取混淆矩阵

python - 具有适用于 Word2Vec 模型的 Keras 功能 API 的产品合并层

python - 如何修复 'ValueError: input tensor must have rank 4' ?

使用 Tesseract 的 Android OCR 应用程序

iphone - 检测数字并处理它们?

ocr - Kofax Capture 识别 - I vs 1