PyTorch学习笔记 7.TextCNN文本分类
2014年,Yoon Kim针对CNN的输入层做了一些变形,提出了文本分类模型textCNN。与传统图像的CNN网络相比, textCNN 在网络结构上没有任何变化,包含只有一层卷积,一层最大池化层, 最后将输出外接softmax 来进行n分类。
模型结构:
本文使用的数据集是 THUCNews 。
这里使用bert的预训练模型 bert-base-chinese 实现tokenizer过程。更多与bert分词编码相关知识可以移步到这里查看。
数据加载器使用pytorch 的 dataset,关于DataSet更多知识可以移步到这里查看。
模型定义3个卷积层,卷积大小分别是2,3,4。
卷积激活函数使用relu。
卷积后进行最大池化,池化是在2维上进行,池化后进行降维处理。
根据池化层的输出和分类类别数量,构建全连接层,再经过softmax,得到最终的分类结果。
这里使用torch.nn.Linear(input_num, num_class)定义全连接层,其中input_num是池化层输出的维数,即m,num_class是分类任务的类别数量。
按批次取训练数据,调用模型进行训练,主要是以下几个步骤:
获取loss:输入数据和标签,计算得到预测值,计算损失函数;optimizer.zero_grad() 清空梯度;loss.backward() 反向传播,计算当前梯度;optimizer.step() 根据梯度更新网络参数
测试过程对每次正确率累加,最后打印整体的测试结果:
把输入文本进行分词编码输入模型,通过argmax计算预测值通过id转标签函数计算标签值
举报/反馈