python - 如何访问pytorch分类模型的预测? (伯特)

标签 python deep-learning pytorch pre-trained-model nlp

我正在运行这个文件: https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples/run_classifier.py

这是一个输入批处理的预测代码:

  input_ids = input_ids.to(device)
  input_mask = input_mask.to(device)
  segment_ids = segment_ids.to(device)
  label_ids = label_ids.to(device)

  with torch.no_grad():
       logits = model(input_ids, segment_ids, input_mask, labels=None)

       loss_fct = CrossEntropyLoss()
       tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))

       eval_loss += tmp_eval_loss.mean().item()
       nb_eval_steps += 1
       if len(preds) == 0:
           preds.append(logits.detach().cpu().numpy())
       else:
           preds[0] = np.append(preds[0], logits.detach().cpu().numpy(), axis=0)

该任务是二元分类。 我想访问二进制输出。

我已经尝试过这个:

  curr_pred = logits.detach().cpu()

  if len(preds) == 0:
      preds.append(curr_pred.numpy())
  else:
      preds[0] = np.append(preds[0], curr_pred.numpy(), axis=0)

  probablities = curr_pred.softmax(1).numpy()[:, 1]

但是结果看起来很奇怪。所以我不确定这是否是正确的方法。

我的假设 - 我收到最后一层的输出,因此在 softmax 之后,它是真实概率(dim 2 的向量 - 第一类的概率和第二类的概率。)

最佳答案

查看 run_classifier.py 代码的这一部分后:

    # copied from the run_classifier.py code 
    eval_loss = eval_loss / nb_eval_steps
    preds = preds[0]
    if output_mode == "classification":
        preds = np.argmax(preds, axis=1)
    elif output_mode == "regression":
        preds = np.squeeze(preds)
    result = compute_metrics(task_name, preds, all_label_ids.numpy())

你只是缺少:

    preds = preds[0]
    preds = np.argmax(preds, axis=1)

然后他们只使用 preds 来计算准确性:

    def simple_accuracy(preds, labels):
         return (preds == labels).mean()

关于python - 如何访问pytorch分类模型的预测? (伯特),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56201147/

相关文章:

deep-learning - NVLink 是否可以使用 DistributedDataParallel 加速训练?

python - Pytorch 验证模型错误 : Expected input batch_size (3) to match target batch_size (4)

python - 无法在 Google Colab 中安装 dgl-cu<任何版本>

python - openpyxl : data-validation read/write without treatment

python - 如果另一列中的相应值也为 NaN,则将一列中的所有值设置为 NaN

python - 如何在 Keras 中组合两个具有不同输入大小的 LSTM 层?

python - Pytorch 已安装但在 ubuntu 18.04 上无法运行

python - 如果比较

python - "rpy2"在 Enthought Canopy 中失败(与内部 GFORTRAN 库冲突)

machine-learning - 使用 CNN 进行二值图像分类 - 选择 "negative"数据集的最佳实践?