目前的状态:人有点晕。好多细节的东西都不知道怎么来的。大方向有所把握:
1. 准备数据集:看起来很简单,其实不然。如何把文件读取进来,变成pytorch所需要的数据类型。
图片:你就需要ToTensor,Normalize转换为需要的数据类型
文字:对init,getitem,len进行重写
准备dataset,构建data_loader并返回
2. 构建模型:重写init和forward方法。在forward里对每一层进行处理。包括矩阵变换,激活函数等去得到输出
3. 训练:基本就是循环里面梯度归零,调用,loss,反向传播,更新
data_loader = get_dataloader()
for idx,(input,traget) in enumerate(data_loader):
optimizer.zero_grad() # 梯度归零
output = model(input) # 调用模型得到预测值
loss = F.nll_loss(output,traget) # 得到损失
loss.backward() # 反向传播
optimizer.step() # 梯度更新
4. 测试:pass