Skip to main content

具有 N-best 解码的 PyTorch CRF

项目描述

具有 N 最佳解码的 PyTorch CRF

PyTorch 1.0 中条件随机场 (CRF) 的实现。它支持前 N 个最可能路径解码。

该软件包基于pytorch-crf仅有以下区别

  • 解码最可能路径的方法_viterbi_decode得到优化。批量大小为 15+ 且序列长度为 20+ 时,运行时间减少到 50% 或更少
  • 该类现在支持通过方法的实现解码前 N 个最可能的路径_viterbi_decode_nbest

要求

  • Python 3 (>= 3.6)
  • PyTorch (>= 1.0)

安装

pip install pytorchcrf

例子

>>> import torch
>>> from pytorchcrf import CRF
>>> num_tags = 5                        # number of tags is 5
>>> model = CRF(num_tags)
>>> seq_length = 3                      # maximum sequence length in a batch
>>> batch_size = 2                      # number of samples in the batch
>>> emissions = torch.randn(seq_length, batch_size, num_tags)

# Computing log likelihood
>>> tags = torch.tensor([[2, 3], [1, 0], [3, 4]], dtype=torch.long)  # (seq_length, batch_size)
>>> model(emissions, tags)

# Decoding
>>> model.decode(emissions)             # decoding the best path
>>> model.decode(emissions, nbest=3)    # decoding the top 3 paths

项目详情


下载文件

下载适用于您平台的文件。如果您不确定要选择哪个,请了解有关安装包的更多信息。

源分布

pytorchcrf-1.2.0.tar.gz (6.4 kB 查看哈希

已上传 source

内置分布

pytorchcrf-1.2.0-py3-none-any.whl (7.1 kB 查看哈希

已上传 py3