[HW2] v0.1: Hung-yi Lee Spring 2021/2022 Machine Learning Course Notes, EP7 (P28-P30)
Starting today I will be working through Professor Hung-yi Lee's machine learning course. Here is the link (highly recommended): 李宏毅2021/2022春机器学习课程_哔哩哔哩_bilibili. There are 155 videos in total, and I hope to get through all of them.
First, the course assumes some coding background: basic Python plus libraries such as NumPy. It also assumes some math: fundamentals of calculus, linear algebra, and probability theory are all needed to follow the lectures.
This post covers homework 2 of the course. Since my fundamentals are still weak, I am posting the TA's sample code with essentially no changes as version 0.1 of HW2. Because almost nothing is modified, the model is weak and the loss is fairly large. Once I have built up the math and programming background, I will definitely come back with a better-performing 1.0 version!
On top of the TA's code I only changed some hyperparameter values, which lowered the loss slightly, but that small improvement came at the cost of a steep increase in training time. I will post a proper improvement after finishing the whole course.
Download the training/test dataset
!pip install --upgrade gdown
# Main link
!gdown --id '1o6Ag-G3qItSmYhTheX6DYiuyNzWyHyTc' --output libriphone.zip
# Backup link 1
# !gdown --id '1R1uQYi4QpX0tBfUWt2mbZcncdBsJkxeW' --output libriphone.zip
# Backup link 2
# !wget -O libriphone.zip "https://www.dropbox.com/s/wqww8c5dbrl2ka9/libriphone.zip?dl=1"
!unzip -q libriphone.zip
!ls libriphone
Requirement already satisfied: gdown in /usr/local/lib/python3.10/dist-packages (4.7.3)
Collecting gdown
  Downloading gdown-5.1.0-py3-none-any.whl (17 kB)
...
Successfully installed gdown-5.1.0
/usr/local/lib/python3.10/dist-packages/gdown/__main__.py:132: FutureWarning: Option `--id` was deprecated in version 4.3.1 and will be removed in 5.0. You don't need to pass it anymore to use a file ID.
Downloading...
From (original): https://drive.google.com/uc?id=1o6Ag-G3qItSmYhTheX6DYiuyNzWyHyTc
From (redirected): https://drive.usercontent.google.com/download?id=1o6Ag-G3qItSmYhTheX6DYiuyNzWyHyTc&confirm=t&uuid=5f6f60a2-b8f1-417e-8539-c50d69f106d6
To: /content/libriphone.zip
100% 479M/479M [00:09<00:00, 51.8MB/s]
feat  test_split.txt  train_labels.txt  train_split.txt
Data preparation
import os
import random
import pandas as pd
import torch
from tqdm import tqdm
def load_feat(path):
feat = torch.load(path)
return feat
def shift(x, n):
if n < 0:
left = x[0].repeat(-n, 1)
right = x[:n]
elif n > 0:
right = x[-1].repeat(n, 1)
left = x[n:]
else:
return x
return torch.cat((left, right), dim=0)
def concat_feat(x, concat_n):
assert concat_n % 2 == 1 # n must be odd
if concat_n < 2:
return x
seq_len, feature_dim = x.size(0), x.size(1)
x = x.repeat(1, concat_n)
x = x.view(seq_len, concat_n, feature_dim).permute(1, 0, 2) # concat_n, seq_len, feature_dim
mid = (concat_n // 2)
for r_idx in range(1, mid+1):
x[mid + r_idx, :] = shift(x[mid + r_idx], r_idx)
x[mid - r_idx, :] = shift(x[mid - r_idx], -r_idx)
return x.permute(1, 0, 2).view(seq_len, concat_n * feature_dim)
def preprocess_data(split, feat_dir, phone_path, concat_nframes, train_ratio=0.8, train_val_seed=1337):
class_num = 41 # NOTE: pre-computed, should not need change
mode = 'train' if (split == 'train' or split == 'val') else 'test'
label_dict = {}
if mode != 'test':
phone_file = open(os.path.join(phone_path, f'{mode}_labels.txt')).readlines()
for line in phone_file:
line = line.strip('\n').split(' ')
label_dict[line[0]] = [int(p) for p in line[1:]]
if split == 'train' or split == 'val':
# split training and validation data
usage_list = open(os.path.join(phone_path, 'train_split.txt')).readlines()
random.seed(train_val_seed)
random.shuffle(usage_list)
percent = int(len(usage_list) * train_ratio)
usage_list = usage_list[:percent] if split == 'train' else usage_list[percent:]
elif split == 'test':
usage_list = open(os.path.join(phone_path, 'test_split.txt')).readlines()
else:
raise ValueError('Invalid \'split\' argument for dataset: PhoneDataset!')
usage_list = [line.strip('\n') for line in usage_list]
print('[Dataset] - # phone classes: ' + str(class_num) + ', number of utterances for ' + split + ': ' + str(len(usage_list)))
max_len = 3000000
X = torch.empty(max_len, 39 * concat_nframes)
if mode != 'test':
y = torch.empty(max_len, dtype=torch.long)
idx = 0
for i, fname in tqdm(enumerate(usage_list)):
feat = load_feat(os.path.join(feat_dir, mode, f'{fname}.pt'))
cur_len = len(feat)
feat = concat_feat(feat, concat_nframes)
if mode != 'test':
label = torch.LongTensor(label_dict[fname])
X[idx: idx + cur_len, :] = feat
if mode != 'test':
y[idx: idx + cur_len] = label
idx += cur_len
X = X[:idx, :]
if mode != 'test':
y = y[:idx]
print(f'[INFO] {split} set')
print(X.shape)
if mode != 'test':
print(y.shape)
return X, y
else:
return X
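To make the frame-handling helpers above concrete, here is a tiny sanity check of shift and concat_feat on a toy tensor (my own illustration, not part of the TA code): with concat_n = 3, output row t should be the flattened window [x(t-1), x(t), x(t+1)], with edge frames repeated at the boundaries.
# Toy sanity check of shift / concat_feat (my own illustration):
x = torch.arange(8).view(4, 2).float()  # 4 frames, 2 features per frame
print(shift(x, 1))    # drops the first frame and repeats the last one
print(shift(x, -1))   # repeats the first frame and drops the last one
y = concat_feat(x, 3) # shape (4, 6): row t is [x(t-1), x(t), x(t+1)] flattened
print(y[0])           # tensor([0., 1., 0., 1., 2., 3.]) -- left edge padded with frame 0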
Define the dataset
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
class LibriDataset(Dataset):
def __init__(self, X, y=None):
self.data = X
if y is not None:
self.label = torch.LongTensor(y)
else:
self.label = None
def __getitem__(self, idx):
if self.label is not None:
return self.data[idx], self.label[idx]
else:
return self.data[idx]
def __len__(self):
return len(self.data)
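A quick usage sketch (again my own illustration, not from the notebook): the dataset simply wraps the feature tensor, returning (feature, label) pairs in train/val mode and bare features in test mode.
# Illustrative usage with random tensors:
dummy_X = torch.randn(10, 39)
dummy_y = torch.randint(0, 41, (10,))
ds = LibriDataset(dummy_X, dummy_y)
print(len(ds))       # 10
feat, label = ds[0]  # one (feature, label) pair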
Define the model
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicBlock(nn.Module):
def __init__(self, input_dim, output_dim):
super(BasicBlock, self).__init__()
self.block = nn.Sequential(
nn.Linear(input_dim, output_dim),
nn.ReLU(),
)
def forward(self, x):
x = self.block(x)
return x
class Classifier(nn.Module):
def __init__(self, input_dim, output_dim=41, hidden_layers=1, hidden_dim=256):
super(Classifier, self).__init__()
self.fc = nn.Sequential(
BasicBlock(input_dim, hidden_dim),
*[BasicBlock(hidden_dim, hidden_dim) for _ in range(hidden_layers)],
nn.Linear(hidden_dim, output_dim)
)
def forward(self, x):
x = self.fc(x)
return x
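The classifier is just a small MLP: one input block, hidden_layers additional blocks, and a final linear layer producing one logit per phone class. A quick shape check (my own sketch):
# Shape sanity check (illustrative only):
m = Classifier(input_dim=39, hidden_layers=1, hidden_dim=256)
out = m(torch.randn(8, 39))  # a batch of 8 frames
print(out.shape)             # torch.Size([8, 41]) -- one logit per phone class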
Hyperparameters (slightly modified here)
# data parameters
concat_nframes = 1              # number of frames to concatenate; n must be odd (total 2k+1 = n frames)
train_ratio = 0.8               # fraction of data used for training; the rest is used for validation

# training parameters
seed = 0                        # random seed
batch_size = 1024               # batch size
num_epoch = 100                 # number of training epochs
learning_rate = 1e-5            # learning rate
model_path = './model.ckpt'    # path to save the checkpoint

# model parameters
input_dim = 39 * concat_nframes # input dimension of the model; do not change this value
hidden_layers = 1               # number of hidden layers
hidden_dim = 256                # hidden dimension
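For reference, the usual directions for pushing this baseline further are a wider context window and a deeper/wider network. The values below are my own untested guesses, so I keep them commented out:
# Untested alternative settings (my own guesses, not used in this run):
# concat_nframes = 19   # wider context: 9 frames on each side
# hidden_layers = 3     # deeper network
# hidden_dim = 1024     # wider hidden layers
# learning_rate = 1e-4  # a larger learning rate may converge faster for this MLP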
Prepare the dataset and model
import gc
# preprocess data
train_X, train_y = preprocess_data(split='train', feat_dir='./libriphone/feat', phone_path='./libriphone', concat_nframes=concat_nframes, train_ratio=train_ratio)
val_X, val_y = preprocess_data(split='val', feat_dir='./libriphone/feat', phone_path='./libriphone', concat_nframes=concat_nframes, train_ratio=train_ratio)
# get dataset
train_set = LibriDataset(train_X, train_y)
val_set = LibriDataset(val_X, val_y)
# remove raw feature to save memory
del train_X, train_y, val_X, val_y
gc.collect()
# get dataloader
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
[Dataset] - # phone classes: 41, number of utterances for train: 3428
3428it [00:05, 639.17it/s]
[INFO] train set
torch.Size([2116368, 39])
torch.Size([2116368])
[Dataset] - # phone classes: 41, number of utterances for val: 858
858it [00:00, 1011.41it/s]
[INFO] val set
torch.Size([527790, 39])
torch.Size([527790])
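One optional tweak I have not benchmarked: DataLoader supports worker processes and pinned memory, which can speed up batching on a multi-core runtime (a sketch, assuming Colab-like hardware):
# Optional, untested here: parallel data loading.
# train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
#                           num_workers=2, pin_memory=True)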
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f'DEVICE: {device}')
DEVICE: cpu
import numpy as np
# fix seed
def same_seeds(seed):
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# fix random seed
same_seeds(seed)
# create model, define a loss function, and optimizer
model = Classifier(input_dim=input_dim, hidden_layers=hidden_layers, hidden_dim=hidden_dim).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
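The very small learning rate above is part of why training is slow. One optional refinement, not used in this run, is a learning-rate schedule; with PyTorch's torch.optim.lr_scheduler.CosineAnnealingLR it would look roughly like this (a sketch):
# Optional, not used in this run: cosine learning-rate decay over all epochs.
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epoch)
# ...and call scheduler.step() once per epoch, after the optimizer updates.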
Training
best_acc = 0.0
for epoch in range(num_epoch):
train_acc = 0.0
train_loss = 0.0
val_acc = 0.0
val_loss = 0.0
# training
model.train() # set the model to training mode
for i, batch in enumerate(tqdm(train_loader)):
features, labels = batch
features = features.to(device)
labels = labels.to(device)
optimizer.zero_grad()
outputs = model(features)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
_, train_pred = torch.max(outputs, 1) # get the index of the class with the highest probability
train_acc += (train_pred.detach() == labels.detach()).sum().item()
train_loss += loss.item()
# validation
if len(val_set) > 0:
model.eval() # set the model to evaluation mode
with torch.no_grad():
for i, batch in enumerate(tqdm(val_loader)):
features, labels = batch
features = features.to(device)
labels = labels.to(device)
outputs = model(features)
loss = criterion(outputs, labels)
                _, val_pred = torch.max(outputs, 1) # get the index of the class with the highest probability
                val_acc += (val_pred.cpu() == labels.cpu()).sum().item()
val_loss += loss.item()
print('[{:03d}/{:03d}] Train Acc: {:3.6f} Loss: {:3.6f} | Val Acc: {:3.6f} loss: {:3.6f}'.format(
epoch + 1, num_epoch, train_acc/len(train_set), train_loss/len(train_loader), val_acc/len(val_set), val_loss/len(val_loader)
))
# if the model improves, save a checkpoint at this epoch
if val_acc > best_acc:
best_acc = val_acc
torch.save(model.state_dict(), model_path)
print('saving model with acc {:.3f}'.format(best_acc/len(val_set)))
else:
print('[{:03d}/{:03d}] Train Acc: {:3.6f} Loss: {:3.6f}'.format(
epoch + 1, num_epoch, train_acc/len(train_set), train_loss/len(train_loader)
))
# if not validating, save the last epoch
if len(val_set) == 0:
torch.save(model.state_dict(), model_path)
print('saving model at last epoch')
Training log (tqdm progress bars stripped; intermediate epochs elided for brevity):
[001/100] Train Acc: 0.283196 Loss: 2.956003 | Val Acc: 0.365916 loss: 2.434805  saving model with acc 0.366
[002/100] Train Acc: 0.390979 Loss: 2.268735 | Val Acc: 0.403733 loss: 2.182312  saving model with acc 0.404
[003/100] Train Acc: 0.414018 Loss: 2.131366 | Val Acc: 0.416668 loss: 2.111082  saving model with acc 0.417
...
[025/100] Train Acc: 0.451586 Loss: 1.923470 | Val Acc: 0.449050 loss: 1.934270  saving model with acc 0.449
...
[050/100] Train Acc: 0.459081 Loss: 1.885782 | Val Acc: 0.455651 loss: 1.900525
...
[075/100] Train Acc: 0.462773 Loss: 1.867054 | Val Acc: 0.458781 loss: 1.884844
...
[098/100] Train Acc: 0.465203 Loss: 1.855148 | Val Acc: 0.460865 loss: 1.875298  saving model with acc 0.461
[099/100] Train Acc: 0.465197 Loss: 1.854699 | Val Acc: 0.460740 loss: 1.874919
[100/100] Train Acc: 0.465276 Loss: 1.854267 | Val Acc: 0.460742 loss: 1.874491
del train_loader, val_loader
gc.collect()
0
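Validation accuracy improves very slowly after the first few dozen epochs, so an early-stopping check would save most of the training time. A minimal sketch of the idea (my own addition, not in the TA code):
# Minimal early-stopping sketch (illustrative): inside the epoch loop,
# after computing val_acc, stop when there has been no improvement
# for `patience` consecutive epochs.
# patience, bad_epochs = 10, 0
# if val_acc > best_acc:
#     best_acc, bad_epochs = val_acc, 0
#     torch.save(model.state_dict(), model_path)
# else:
#     bad_epochs += 1
#     if bad_epochs >= patience:
#         print(f'Early stopping at epoch {epoch + 1}')
#         break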
Testing
# load data
test_X = preprocess_data(split='test', feat_dir='./libriphone/feat', phone_path='./libriphone', concat_nframes=concat_nframes)
test_set = LibriDataset(test_X, None)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
[Dataset] - # phone classes: 41, number of utterances for test: 1078
1078it [00:02, 398.82it/s]
[INFO] test set
torch.Size([646268, 39])
# load model
model = Classifier(input_dim=input_dim, hidden_layers=hidden_layers, hidden_dim=hidden_dim).to(device)
model.load_state_dict(torch.load(model_path))
<All keys matched successfully>
Prediction
pred = np.array([], dtype=np.int32)
model.eval()
with torch.no_grad():
for i, batch in enumerate(tqdm(test_loader)):
features = batch
features = features.to(device)
outputs = model(features)
_, test_pred = torch.max(outputs, 1) # get the index of the class with the highest probability
pred = np.concatenate((pred, test_pred.cpu().numpy()), axis=0)
100%|██████████| 632/632 [00:05<00:00, 113.55it/s]
Write the predictions to a file
with open('prediction.csv', 'w') as f:
f.write('Id,Class\n')
for i, y in enumerate(pred):
f.write('{},{}\n'.format(i, y))
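Since pandas and numpy are already imported, an equivalent two-line version would also work (an alternative sketch producing the same Id,Class format):
# Alternative with pandas (same output format):
# pd.DataFrame({'Id': np.arange(len(pred)), 'Class': pred}).to_csv('prediction.csv', index=False)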
That is all of the baseline code for HW2. I will definitely come back once I have gotten stronger!