Hung-yi Lee (李宏毅, 2020) Homework 9: Unsupervised Learning: Dimensionality Reduction, Clustering, Autoencoders



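Dataset

Assignment

Task 1

Use at least two methods (autoencoder architecture, optimizer, data preprocessing, the downstream dimensionality-reduction method, the clustering algorithm, etc.) to improve the accuracy of the baseline code.

  • Report the accuracy before and after the improvement.
  • For both the original and the improved method, plot the dimensionality-reduced val data (embeddings) together with their corresponding labels.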

Task 2

Using your autoencoder with the highest accuracy, take the 6 images at indices 1, 2, 3, 6, 7, 9 from trainX.

  • Plot the original images and their reconstructions.

Task 3

Pick at least 10 checkpoints during autoencoder training.

  • For those checkpoints, plot the model's reconstruction error (MSE computed over all of trainX) and the val accuracy.
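One possible way to compute the reconstruction error for a single checkpoint is sketched below. It assumes the images are already preprocessed to shape (N, 3, 32, 32) and that the model's forward returns (latent, reconstruction), like the AE in this post; the function name `reconstruction_mse` is hypothetical. The val accuracy for each checkpoint comes from the clustering step and is not covered by this sketch.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


@torch.no_grad()  # no gradients are needed for evaluation
def reconstruction_mse(model, images, batch_size=256, device="cpu"):
    """Per-element MSE between `images` and the model's reconstructions.

    Assumes `images` has shape (N, 3, 32, 32) and `model(x)` returns
    (latent, reconstruction). The model must already be on `device`.
    """
    model.eval()
    loader = DataLoader(torch.as_tensor(images), batch_size=batch_size, shuffle=False)
    sq_err, n_elems = 0.0, 0
    for x in loader:
        x = x.to(device)
        _, recon = model(x)
        sq_err += torch.sum((recon - x) ** 2).item()  # summed squared error
        n_elems += x.numel()
    return sq_err / n_elems  # same value as nn.MSELoss over the whole set


# Hypothetical usage over checkpoints saved every 10 epochs:
# errors = []
# for ep in range(10, 1001, 10):
#     model.load_state_dict(torch.load(f"./checkpoints/checkpoint_{ep}.pth"))
#     errors.append(reconstruction_mse(model, trainX_preprocessed, device="cuda"))
```

Accumulating the summed error and dividing once at the end keeps the result independent of the batch size, even when the last batch is smaller. Call `model.train()` again afterwards if training continues.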

Data

Load the data with np.load(). valX.npy and valY.npy are only for checking how well the training works; they must not be used for training.

trainX.npy

  • Contains 8500 RGB images in total, each of size 32 * 32 * 3
  • shape is (8500, 32, 32, 3)

valX.npy

  • Do not use for training
  • Contains 500 RGB images in total, each of size 32 * 32 * 3
  • shape is (500, 32, 32, 3)

valY.npy

  • Do not use for training
  • Labels corresponding to valX.npy
  • shape is (500,)
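A quick sanity check of the loaded arrays against the shapes listed above can catch a wrong or partially downloaded file before training. This is a minimal sketch; `check_hw9_arrays` is a hypothetical helper, and the file names are assumed to be in the working directory.

```python
import numpy as np


def check_hw9_arrays(trainX, valX, valY):
    """Raise ValueError if the arrays do not match the expected shapes."""
    expected = {
        "trainX": (trainX, (8500, 32, 32, 3)),
        "valX": (valX, (500, 32, 32, 3)),
        "valY": (valY, (500,)),
    }
    for name, (arr, shape) in expected.items():
        if arr.shape != shape:
            raise ValueError(f"{name}: expected shape {shape}, got {arr.shape}")
    # Raw pixels should still be in 0~255 (preprocess() rescales them later)
    for name, arr in (("trainX", trainX), ("valX", valX)):
        if arr.min() < 0 or arr.max() > 255:
            raise ValueError(f"{name}: pixel values outside 0~255")
    return True


# Usage on the real files:
# trainX, valX, valY = (np.load(f) for f in ("trainX.npy", "valX.npy", "valY.npy"))
# check_hw9_arrays(trainX, valX, valY)
```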

Download the dataset

Create the checkpoints folder

# Uncomment the gdown lines below to download the data
# !gdown --id '1BZb2AqOHHaad7Mo82St1qTBaXo_xtcUc' --output trainX.npy
# !gdown --id '152NKCpj8S_zuIx3bQy0NN5oqpvBjdPIq' --output valX.npy
# !gdown --id '1_hRGsFtm5KEazUg2ZvPZcuNScGF-ANh4' --output valY.npy
!mkdir checkpoints
!ls
mkdir: cannot create directory 'checkpoints': File exists
checkpoints	       trainX.npy
p1_baseline.png        valX.npy
prediction.csv	       valY.npy
prediction_invert.csv  李宏毅机器学习2020-作业9:无监督学习.ipynb

Prepare the training data

Define our preprocess: linearly map the image pixel values from ints in 0~255 to floats in -1~1.

import numpy as np

def preprocess(image_list):
    """ Normalize Image and Permute (N,H,W,C) to (N,C,H,W)
    Args:
      image_list: List of images (N, 32, 32, 3), int values in 0~255
    Returns:
      image_list: Array of images (N, 3, 32, 32), float32 values in -1~1
    """
    image_list = np.array(image_list)
    image_list = np.transpose(image_list, (0, 3, 1, 2))  # (N, H, W, C) -> (N, C, H, W)
    image_list = (image_list / 255.0) * 2 - 1            # [0, 255] -> [-1, 1]
    image_list = image_list.astype(np.float32)
    return image_list
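To plot images after this preprocessing (for example the originals and reconstructions in Task 2), the mapping has to be undone. Below is a minimal sketch of the inverse, assuming inputs in [-1, 1] with shape (N, 3, H, W), such as the Tanh output of the decoder; the name `postprocess` is hypothetical.

```python
import numpy as np


def postprocess(images):
    """Invert preprocess(): (N, C, H, W) floats in [-1, 1] -> (N, H, W, C) uint8."""
    images = np.asarray(images, dtype=np.float64)
    images = (images + 1) / 2 * 255                   # [-1, 1] -> [0, 255]
    images = np.clip(np.rint(images), 0, 255).astype(np.uint8)
    return np.transpose(images, (0, 2, 3, 1))         # (N, C, H, W) -> (N, H, W, C)
```

For Task 2, one would pass `trainX_preprocessed[[1, 2, 3, 6, 7, 9]]` through the model, move the reconstruction to the CPU as a NumPy array, and call `postprocess` on it before `plt.imshow`. Rounding before the cast avoids values such as 254.99998 being truncated to 254.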

Custom Dataset

from torch.utils.data import Dataset

class Image_Dataset(Dataset):
    def __init__(self, image_list):
        self.image_list = image_list
    def __len__(self):
        return len(self.image_list)
    def __getitem__(self, idx):
        images = self.image_list[idx]
        return images

Load the training data and preprocess it, then turn the preprocessed data into the dataset we need. Do not use valX and valY for training.

from torch.utils.data import DataLoader

trainX = np.load('trainX.npy')
trainX_preprocessed = preprocess(trainX)
img_dataset = Image_Dataset(trainX_preprocessed)
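As a sanity check on batching, the sketch below uses a zero-filled stand-in array with the same shape and dtype as `trainX_preprocessed` (an assumption made so it runs without the data files). A NumPy array already supports `len()` and integer indexing, which is all `Image_Dataset` adds, so `DataLoader` can batch it directly.

```python
import math

import numpy as np
import torch
from torch.utils.data import DataLoader

# Stand-in for trainX_preprocessed: same shape and dtype, contents irrelevant here
stand_in = np.zeros((8500, 3, 32, 32), dtype=np.float32)

loader = DataLoader(stand_in, batch_size=64, shuffle=True)
batch = next(iter(loader))  # default collate stacks 64 arrays into one tensor

n_batches = len(loader)                            # ceil(8500 / 64)
last_batch = len(stand_in) - (n_batches - 1) * 64  # size of the final, partial batch
print(n_batches, tuple(batch.shape), batch.dtype, last_batch)
```

This should report 133 batches of shape (64, 3, 32, 32) in float32, with a final batch of 52 images. This matters when reading the training log, because the loop there adds up one mean loss per batch.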

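Utility functions

Here are some useful functions: one counts the model's parameters (needed for the report), and the other fixes the random seeds used in training (for reproducibility).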

import random
import torch

def count_parameters(model, only_trainable=False):
    if only_trainable:
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
    else:
        return sum(p.numel() for p in model.parameters())

def same_seeds(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    np.random.seed(seed)  # Numpy module.
    random.seed(seed)  # Python random module.
    torch.backends.cudnn.benchmark = False  # disable cuDNN auto-tuning of convolution algorithms
    torch.backends.cudnn.deterministic = True  # always use deterministic convolution algorithms
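`count_parameters` sums `p.numel()` over all parameters. As a cross-check, the sketch below counts the parameters of the AE from the Model section by hand. It assumes PyTorch's defaults (`bias=True`, `groups=1`); ReLU, MaxPool2d and Tanh have no parameters, and a ConvTranspose2d weight of shape (in, out, k, k) has as many entries as a Conv2d weight of shape (out, in, k, k).

```python
def conv_params(c_in, c_out, k):
    """Weights + biases of a Conv2d / ConvTranspose2d with a k x k kernel."""
    return c_in * c_out * k * k + c_out


# (in_channels, out_channels, kernel_size) for each layer of the AE
encoder = [(3, 64, 3), (64, 128, 3), (128, 256, 3), (256, 512, 3)]
decoder = [(512, 256, 3), (256, 128, 5), (128, 64, 9), (64, 3, 17)]

enc_total = sum(conv_params(*layer) for layer in encoder)
dec_total = sum(conv_params(*layer) for layer in decoder)
print(enc_total, dec_total, enc_total + dec_total)
```

The total (4,269,315 by this count) should match `count_parameters(AE())`; the decoder holds most of the parameters because of its large transposed-convolution kernels.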

Model

Define our baseline autoencoder
ConvTranspose2d: transposed convolution (sometimes called deconvolution)
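The spatial size produced by a transposed convolution follows the output-shape formula in the PyTorch `nn.ConvTranspose2d` documentation. The sketch below applies it to the decoder used in this post (kernel sizes 3, 5, 9, 17, all with stride 1 and no padding); with stride 1, each layer simply grows the feature map by kernel_size - 1. The helper name is hypothetical.

```python
def conv_transpose_out(h_in, kernel, stride=1, padding=0, dilation=1, output_padding=0):
    """Output height/width of nn.ConvTranspose2d (formula from the PyTorch docs)."""
    return (h_in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


# Encoder: each 3x3 conv with padding=1 keeps the size and each MaxPool2d(2)
# halves it, so a 32x32 input ends up as a 2x2 map after four blocks.
h = 32
for _ in range(4):
    h //= 2

sizes = [h]
for k in (3, 5, 9, 17):  # decoder kernel sizes, all stride=1
    h = conv_transpose_out(h, k)
    sizes.append(h)
print(sizes)  # [2, 4, 8, 16, 32]
```

The kernel sizes are chosen so the decoder lands exactly back on 32x32. A common alternative is stride-2 transposed convolutions, e.g. kernel 4, stride 2, padding 1, which double the size per layer.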

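For the model improvement, I only added one more layer to both the encoder and the decoder, which gives better results. The only hyperparameter I changed was the number of epochs, which I set to 1000.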

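import torch.nn as nn

class AE(nn.Module):
    def __init__(self):
        super(AE, self).__init__()

        # Encoder: four Conv(3x3, padding=1) + MaxPool(2) blocks
        # (N, 3, 32, 32) -> (N, 512, 2, 2)
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=1, padding=1),     # -> (N, 64, 32, 32)
            nn.ReLU(True),
            nn.MaxPool2d(2),                              # -> (N, 64, 16, 16)
            nn.Conv2d(64, 128, 3, stride=1, padding=1),   # -> (N, 128, 16, 16)
            nn.ReLU(True),
            nn.MaxPool2d(2),                              # -> (N, 128, 8, 8)
            nn.Conv2d(128, 256, 3, stride=1, padding=1),  # -> (N, 256, 8, 8)
            nn.ReLU(True),
            nn.MaxPool2d(2),                              # -> (N, 256, 4, 4)
            nn.Conv2d(256, 512, 3, stride=1, padding=1),  # -> (N, 512, 4, 4)
            nn.ReLU(True),
            nn.MaxPool2d(2)                               # -> (N, 512, 2, 2)
        )

        # Decoder: stride-1 transposed convs grow the map by (kernel - 1) each
        # (N, 512, 2, 2) -> (N, 3, 32, 32); Tanh matches the [-1, 1] input range
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 3, stride=1),    # -> (N, 256, 4, 4)
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 5, stride=1),    # -> (N, 128, 8, 8)
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 9, stride=1),     # -> (N, 64, 16, 16)
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 17, stride=1),      # -> (N, 3, 32, 32)
            nn.Tanh()
        )

    def forward(self, x):
        x1 = self.encoder(x)   # latent code, (N, 512, 2, 2)
        x  = self.decoder(x1)  # reconstruction, (N, 3, 32, 32)
        return x1, x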

!nvidia-smi
Thu Nov  4 17:03:39 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.91.03    Driver Version: 460.91.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  GeForce RTX 3090    Off  | 00000000:1A:00.0 Off |                  N/A |
| 57%   70C    P2   325W / 350W |   8107MiB / 24268MiB |     91%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 3090    Off  | 00000000:68:00.0 Off |                  N/A |
|  0%   29C    P8    25W / 350W |    299MiB / 24265MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A     11145      C   python                           8103MiB |
|    1   N/A  N/A      2432      G   /usr/lib/xorg/Xorg                 14MiB |
|    1   N/A  N/A      4006      G   /usr/bin/gnome-shell               17MiB |
|    1   N/A  N/A      4984      G   /usr/lib/xorg/Xorg                 70MiB |
|    1   N/A  N/A      5058      G   /usr/lib/xorg/Xorg                 18MiB |
|    1   N/A  N/A      5233      G   /usr/bin/gnome-shell              100MiB |
|    1   N/A  N/A      5384      G   /usr/bin/gnome-shell               36MiB |
|    1   N/A  N/A      6548      G   ...2179,14311511775341437302       36MiB |
+-----------------------------------------------------------------------------+


Training

This is the main training stage. We first pass the prepared dataset to a DataLoader. Once the dataloader, model, loss criterion and optimizer are all set up, training can start. After training, we save the model.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

same_seeds(0)

# Prepare the model, loss criterion and optimizer
model = AE().cuda()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-5)

model.train()
n_epoch = 1000

# Prepare the dataloader
img_dataloader = DataLoader(img_dataset, batch_size=64, shuffle=True)

# Main training loop
for epoch in range(n_epoch):
    epoch_loss = 0
    for img in img_dataloader:
        img = img.cuda()

        _, output = model(img)
        loss = criterion(output, img)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # epoch_loss is the sum of per-batch mean MSEs over ceil(8500 / 64) = 133
        # batches, not a per-pixel average
        epoch_loss += loss.item()

    # Save a checkpoint every 10 epochs (once per epoch, not after every batch)
    if (epoch + 1) % 10 == 0:
        torch.save(model.state_dict(), './checkpoints/checkpoint_{}.pth'.format(epoch + 1))

    print('epoch [{}/{}], loss:{:.5f}'.format(epoch + 1, n_epoch, epoch_loss))

# Save the model after training finishes
torch.save(model.state_dict(), './checkpoints/last_checkpoint.pth')
epoch [1/1000], loss:30.54165
epoch [2/1000], loss:26.34405
epoch [3/1000], loss:21.83250
epoch [4/1000], loss:19.13653
epoch [5/1000], loss:16.89123
epoch [6/1000], loss:15.81137
epoch [7/1000], loss:15.24495
epoch [8/1000], loss:14.82142
epoch [9/1000], loss:14.43517
epoch [10/1000], loss:14.08439
epoch [11/1000], loss:13.73920
epoch [12/1000], loss:13.40639
epoch [13/1000], loss:13.08327
epoch [14/1000], loss:12.66554
epoch [15/1000], loss:12.26715
epoch [16/1000], loss:11.93717
epoch [17/1000], loss:11.67487
epoch [18/1000], loss:11.45737
epoch [19/1000], loss:11.28208
epoch [20/1000], loss:11.08628
epoch [21/1000], loss:10.94622
epoch [22/1000], loss:10.80847
epoch [23/1000], loss:10.70417
epoch [24/1000], loss:10.58255
epoch [25/1000], loss:10.48495
epoch [26/1000], loss:10.39527
epoch [27/1000], loss:10.30006
epoch [28/1000], loss:10.20910
epoch [29/1000], loss:10.13124
epoch [30/1000], loss:10.04456
epoch [31/1000], loss:9.96836
epoch [32/1000], loss:9.88246
epoch [33/1000], loss:9.81235
epoch [34/1000], loss:9.72425
epoch [35/1000], loss:9.65545
epoch [36/1000], loss:9.57657
epoch [37/1000], loss:9.51310
epoch [38/1000], loss:9.45421
epoch [39/1000], loss:9.38250
epoch [40/1000], loss:9.31712
epoch [41/1000], loss:9.25833
epoch [42/1000], loss:9.20196
epoch [43/1000], loss:9.14868
epoch [44/1000], loss:9.08939
epoch [45/1000], loss:9.02597
epoch [46/1000], loss:8.95911
epoch [47/1000], loss:8.91480
epoch [48/1000], loss:8.86116
epoch [49/1000], loss:8.79443
epoch [50/1000], loss:8.73779
epoch [51/1000], loss:8.68570
epoch [52/1000], loss:8.62910
epoch [53/1000], loss:8.57338
epoch [54/1000], loss:8.53807
epoch [55/1000], loss:8.48156
epoch [56/1000], loss:8.43463
epoch [57/1000], loss:8.39641
epoch [58/1000], loss:8.34074
epoch [59/1000], loss:8.30465
epoch [60/1000], loss:8.27341
epoch [61/1000], loss:8.23230
epoch [62/1000], loss:8.18089
epoch [63/1000], loss:8.15129
epoch [64/1000], loss:8.11520
epoch [65/1000], loss:8.07959
epoch [66/1000], loss:8.04687
epoch [67/1000], loss:8.02380
epoch [68/1000], loss:7.98933
epoch [69/1000], loss:7.95649
epoch [70/1000], loss:7.92910
epoch [71/1000], loss:7.88972
epoch [72/1000], loss:7.85813
epoch [73/1000], loss:7.82851
epoch [74/1000], loss:7.81065
epoch [75/1000], loss:7.78497
epoch [76/1000], loss:7.73110
epoch [77/1000], loss:7.71461
epoch [78/1000], loss:7.68887
epoch [79/1000], loss:7.65523
epoch [80/1000], loss:7.63705
epoch [81/1000], loss:7.61096
epoch [82/1000], loss:7.57877
epoch [83/1000], loss:7.54703
epoch [84/1000], loss:7.52961
epoch [85/1000], loss:7.48876
epoch [86/1000], loss:7.46642
epoch [87/1000], loss:7.43804
epoch [88/1000], loss:7.41458
epoch [89/1000], loss:7.38298
epoch [90/1000], loss:7.38157
epoch [91/1000], loss:7.34053
epoch [92/1000], loss:7.32307
epoch [93/1000], loss:7.28897
epoch [94/1000], loss:7.27476
epoch [95/1000], loss:7.25432
epoch [96/1000], loss:7.23210
epoch [97/1000], loss:7.20764
epoch [98/1000], loss:7.17726
epoch [99/1000], loss:7.16785
epoch [100/1000], loss:7.14477
epoch [101/1000], loss:7.12776
epoch [102/1000], loss:7.10490
epoch [103/1000], loss:7.08108
epoch [104/1000], loss:7.06430
epoch [105/1000], loss:7.04382
epoch [106/1000], loss:7.01336
epoch [107/1000], loss:7.00099
epoch [108/1000], loss:6.97758
epoch [109/1000], loss:6.95376
epoch [110/1000], loss:6.94354
epoch [111/1000], loss:6.91744
epoch [112/1000], loss:6.91015
epoch [113/1000], loss:6.88055
epoch [114/1000], loss:6.86521
epoch [115/1000], loss:6.84671
epoch [116/1000], loss:6.82973
epoch [117/1000], loss:6.80817
epoch [118/1000], loss:6.78769
epoch [119/1000], loss:6.77140
epoch [120/1000], loss:6.76178
epoch [121/1000], loss:6.74296
epoch [122/1000], loss:6.71641
epoch [123/1000], loss:6.69564
epoch [124/1000], loss:6.67923
epoch [125/1000], loss:6.66339
epoch [126/1000], loss:6.64667
epoch [127/1000], loss:6.62993
epoch [128/1000], loss:6.60127
epoch [129/1000], loss:6.58229
epoch [130/1000], loss:6.57563
epoch [131/1000], loss:6.55139
epoch [132/1000], loss:6.53123
epoch [133/1000], loss:6.51448
epoch [134/1000], loss:6.49753
epoch [135/1000], loss:6.46827
epoch [136/1000], loss:6.45886
epoch [137/1000], loss:6.43451
epoch [138/1000], loss:6.41819
epoch [139/1000], loss:6.39429
epoch [140/1000], loss:6.38479
epoch [141/1000], loss:6.36964
epoch [142/1000], loss:6.34008
epoch [143/1000], loss:6.32599
epoch [144/1000], loss:6.30631
epoch [145/1000], loss:6.29071
epoch [146/1000], loss:6.27065
epoch [147/1000], loss:6.25629
epoch [148/1000], loss:6.23477
epoch [149/1000], loss:6.22027
epoch [150/1000], loss:6.20892
epoch [151/1000], loss:6.18379
epoch [152/1000], loss:6.16717
epoch [153/1000], loss:6.15294
epoch [154/1000], loss:6.13922
epoch [155/1000], loss:6.12273
epoch [156/1000], loss:6.09983
epoch [157/1000], loss:6.09613
epoch [158/1000], loss:6.08098
epoch [159/1000], loss:6.06648
epoch [160/1000], loss:6.05687
epoch [161/1000], loss:6.03163
epoch [162/1000], loss:6.00917
epoch [163/1000], loss:6.00572
epoch [164/1000], loss:5.99157
epoch [165/1000], loss:5.97707
epoch [166/1000], loss:5.96627
epoch [167/1000], loss:5.96171
epoch [168/1000], loss:5.93227
epoch [169/1000], loss:5.92656
epoch [170/1000], loss:5.92673
epoch [171/1000], loss:5.90135
epoch [172/1000], loss:5.89017
epoch [173/1000], loss:5.87263
epoch [174/1000], loss:5.86483
epoch [175/1000], loss:5.85099
epoch [176/1000], loss:5.83615
epoch [177/1000], loss:5.83101
epoch [178/1000], loss:5.82030
epoch [179/1000], loss:5.82544
epoch [180/1000], loss:5.78977
epoch [181/1000], loss:5.78293
epoch [182/1000], loss:5.77460
epoch [183/1000], loss:5.76192
epoch [184/1000], loss:5.75049
epoch [185/1000], loss:5.74188
epoch [186/1000], loss:5.73882
epoch [187/1000], loss:5.72205
epoch [188/1000], loss:5.70864
epoch [189/1000], loss:5.70273
epoch [190/1000], loss:5.69353
epoch [191/1000], loss:5.68343
epoch [192/1000], loss:5.67216
epoch [193/1000], loss:5.66239
epoch [194/1000], loss:5.65125
epoch [195/1000], loss:5.63932
epoch [196/1000], loss:5.63388
epoch [197/1000], loss:5.62116
epoch [198/1000], loss:5.61385
epoch [199/1000], loss:5.61483
epoch [200/1000], loss:5.59609
epoch [201/1000], loss:5.57955
epoch [202/1000], loss:5.57469
epoch [203/1000], loss:5.56383
epoch [204/1000], loss:5.55489
epoch [205/1000], loss:5.54320
epoch [206/1000], loss:5.52971
epoch [207/1000], loss:5.53083
epoch [208/1000], loss:5.51958
epoch [209/1000], loss:5.50594
epoch [210/1000], loss:5.50194
epoch [211/1000], loss:5.49632
epoch [212/1000], loss:5.47642
epoch [213/1000], loss:5.47105
epoch [214/1000], loss:5.46103
epoch [215/1000], loss:5.45136
epoch [216/1000], loss:5.44927
epoch [217/1000], loss:5.43372
epoch [218/1000], loss:5.43229
epoch [219/1000], loss:5.41293
epoch [220/1000], loss:5.40677
epoch [221/1000], loss:5.39713
epoch [222/1000], loss:5.39402
epoch [223/1000], loss:5.38856
epoch [224/1000], loss:5.37551
epoch [225/1000], loss:5.36045
epoch [226/1000], loss:5.35389
epoch [227/1000], loss:5.34672
epoch [228/1000], loss:5.33802
epoch [229/1000], loss:5.33105
epoch [230/1000], loss:5.32277
epoch [231/1000], loss:5.30828
epoch [232/1000], loss:5.29910
epoch [233/1000], loss:5.29399
epoch [234/1000], loss:5.28984
epoch [235/1000], loss:5.27597
epoch [236/1000], loss:5.26934
epoch [237/1000], loss:5.26663
epoch [238/1000], loss:5.25943
epoch [239/1000], loss:5.24395
epoch [240/1000], loss:5.24214
epoch [241/1000], loss:5.23017
epoch [242/1000], loss:5.21525
epoch [243/1000], loss:5.21001
epoch [244/1000], loss:5.20533
epoch [245/1000], loss:5.19778
epoch [246/1000], loss:5.19444
epoch [247/1000], loss:5.17834
epoch [248/1000], loss:5.17032
epoch [249/1000], loss:5.16573
epoch [250/1000], loss:5.16030
epoch [251/1000], loss:5.15691
epoch [252/1000], loss:5.14337
epoch [253/1000], loss:5.13357
epoch [254/1000], loss:5.12614
epoch [255/1000], loss:5.12397
epoch [256/1000], loss:5.11111
epoch [257/1000], loss:5.09905
epoch [258/1000], loss:5.09718
epoch [259/1000], loss:5.09271
epoch [260/1000], loss:5.08443
epoch [261/1000], loss:5.07630
epoch [262/1000], loss:5.06473
epoch [263/1000], loss:5.06329
epoch [264/1000], loss:5.05452
epoch [265/1000], loss:5.04306
epoch [266/1000], loss:5.04899
epoch [267/1000], loss:5.03139
epoch [268/1000], loss:5.02383
epoch [269/1000], loss:5.01982
epoch [270/1000], loss:5.01273
epoch [271/1000], loss:5.00642
epoch [272/1000], loss:4.99454
epoch [273/1000], loss:4.99690
epoch [274/1000], loss:4.98375
epoch [275/1000], loss:4.98370
epoch [276/1000], loss:4.96812
epoch [277/1000], loss:4.96210
epoch [278/1000], loss:4.96167
epoch [279/1000], loss:4.94264
epoch [280/1000], loss:4.94708
epoch [281/1000], loss:4.93381
epoch [282/1000], loss:4.92656
epoch [283/1000], loss:4.92751
epoch [284/1000], loss:4.91519
epoch [285/1000], loss:4.90649
epoch [286/1000], loss:4.90130
epoch [287/1000], loss:4.89965
epoch [288/1000], loss:4.88647
epoch [289/1000], loss:4.88522
epoch [290/1000], loss:4.87119
epoch [291/1000], loss:4.86967
epoch [292/1000], loss:4.86545
epoch [293/1000], loss:4.85670
epoch [294/1000], loss:4.84635
epoch [295/1000], loss:4.84253
epoch [296/1000], loss:4.84705
epoch [297/1000], loss:4.82709
epoch [298/1000], loss:4.82251
epoch [299/1000], loss:4.81915
epoch [300/1000], loss:4.81493
epoch [301/1000], loss:4.80140
epoch [302/1000], loss:4.79302
epoch [303/1000], loss:4.79099
epoch [304/1000], loss:4.78271
epoch [305/1000], loss:4.77509
epoch [306/1000], loss:4.76755
epoch [307/1000], loss:4.76485
epoch [308/1000], loss:4.76169
epoch [309/1000], loss:4.75328
epoch [310/1000], loss:4.74254
epoch [311/1000], loss:4.74224
epoch [312/1000], loss:4.74067
epoch [313/1000], loss:4.72933
epoch [314/1000], loss:4.71486
epoch [315/1000], loss:4.71784
epoch [316/1000], loss:4.70222
epoch [317/1000], loss:4.70290
epoch [318/1000], loss:4.69542
epoch [319/1000], loss:4.69025
epoch [320/1000], loss:4.68246
epoch [321/1000], loss:4.67295
epoch [322/1000], loss:4.67523
epoch [323/1000], loss:4.67207
epoch [324/1000], loss:4.66636
epoch [325/1000], loss:4.64616
epoch [326/1000], loss:4.64512
epoch [327/1000], loss:4.64286
epoch [328/1000], loss:4.63428
epoch [329/1000], loss:4.62759
epoch [330/1000], loss:4.62275
epoch [331/1000], loss:4.61570
epoch [332/1000], loss:4.61228
epoch [333/1000], loss:4.60109
epoch [334/1000], loss:4.60413
epoch [335/1000], loss:4.58950
epoch [336/1000], loss:4.59071
epoch [337/1000], loss:4.58295
epoch [338/1000], loss:4.57782
epoch [339/1000], loss:4.57129
epoch [340/1000], loss:4.56505
epoch [341/1000], loss:4.56037
epoch [342/1000], loss:4.55598
epoch [343/1000], loss:4.54537
epoch [344/1000], loss:4.54019
epoch [345/1000], loss:4.53571
epoch [346/1000], loss:4.53185
epoch [347/1000], loss:4.53183
epoch [348/1000], loss:4.52009
epoch [349/1000], loss:4.51411
epoch [350/1000], loss:4.50916
epoch [351/1000], loss:4.50595
epoch [352/1000], loss:4.50171
epoch [353/1000], loss:4.49431
epoch [354/1000], loss:4.48945
epoch [355/1000], loss:4.48904
epoch [356/1000], loss:4.47484
epoch [357/1000], loss:4.47601
epoch [358/1000], loss:4.46283
epoch [359/1000], loss:4.46043
epoch [360/1000], loss:4.45623
epoch [361/1000], loss:4.45144
epoch [387/1000], loss:4.32588
epoch [388/1000], loss:4.31738
epoch [389/1000], loss:4.31798
epoch [390/1000], loss:4.31714
epoch [391/1000], loss:4.30985
epoch [392/1000], loss:4.29957
epoch [393/1000], loss:4.29696
epoch [394/1000], loss:4.29420
epoch [395/1000], loss:4.28667
epoch [396/1000], loss:4.28612
epoch [397/1000], loss:4.27635
epoch [398/1000], loss:4.27332
epoch [399/1000], loss:4.27225
epoch [400/1000], loss:4.26569
epoch [401/1000], loss:4.26683
epoch [402/1000], loss:4.25562
epoch [403/1000], loss:4.24940
epoch [404/1000], loss:4.24415
epoch [405/1000], loss:4.24422
epoch [406/1000], loss:4.24053
epoch [407/1000], loss:4.23612
epoch [408/1000], loss:4.23212
epoch [409/1000], loss:4.23014
epoch [410/1000], loss:4.22054
epoch [411/1000], loss:4.21572
epoch [412/1000], loss:4.21339
epoch [413/1000], loss:4.20922
epoch [414/1000], loss:4.20910
epoch [415/1000], loss:4.20353
epoch [416/1000], loss:4.19610
epoch [417/1000], loss:4.19232
epoch [418/1000], loss:4.18926
epoch [419/1000], loss:4.18134
epoch [420/1000], loss:4.17638
epoch [421/1000], loss:4.17397
epoch [422/1000], loss:4.17142
epoch [423/1000], loss:4.16676
epoch [424/1000], loss:4.17102
epoch [425/1000], loss:4.15542
epoch [426/1000], loss:4.15438
epoch [427/1000], loss:4.15161
epoch [428/1000], loss:4.14431
epoch [429/1000], loss:4.14308
epoch [430/1000], loss:4.14248
epoch [431/1000], loss:4.13705
epoch [432/1000], loss:4.13069
epoch [433/1000], loss:4.12359
epoch [434/1000], loss:4.12440
epoch [435/1000], loss:4.12047
epoch [436/1000], loss:4.11715
epoch [437/1000], loss:4.11095
epoch [438/1000], loss:4.10556
epoch [439/1000], loss:4.10342
epoch [440/1000], loss:4.10314
epoch [441/1000], loss:4.09450
epoch [442/1000], loss:4.08683
epoch [443/1000], loss:4.08545
epoch [444/1000], loss:4.08673
epoch [445/1000], loss:4.07830
epoch [446/1000], loss:4.07518
epoch [447/1000], loss:4.06704
epoch [448/1000], loss:4.06815
epoch [449/1000], loss:4.06158
epoch [450/1000], loss:4.06410
epoch [451/1000], loss:4.05870
epoch [452/1000], loss:4.05462
epoch [453/1000], loss:4.04799
epoch [454/1000], loss:4.04455
epoch [455/1000], loss:4.03678
epoch [456/1000], loss:4.04038
epoch [457/1000], loss:4.03390
epoch [458/1000], loss:4.02727
epoch [459/1000], loss:4.02408
epoch [460/1000], loss:4.02337
epoch [461/1000], loss:4.01824
epoch [462/1000], loss:4.01433
epoch [463/1000], loss:4.00995
epoch [464/1000], loss:4.00826
epoch [465/1000], loss:4.00209
epoch [466/1000], loss:4.00384
epoch [467/1000], loss:3.99173
epoch [468/1000], loss:3.99856
epoch [469/1000], loss:3.99148
epoch [470/1000], loss:3.98304
epoch [471/1000], loss:3.98313
epoch [472/1000], loss:3.97725
epoch [473/1000], loss:3.97736
epoch [474/1000], loss:3.97326
epoch [475/1000], loss:3.96900
epoch [476/1000], loss:3.96096
epoch [477/1000], loss:3.96076
epoch [478/1000], loss:3.96005
epoch [479/1000], loss:3.95441
epoch [480/1000], loss:3.95287
epoch [481/1000], loss:3.94587
epoch [482/1000], loss:3.94024
epoch [483/1000], loss:3.93922
epoch [484/1000], loss:3.93559
epoch [485/1000], loss:3.93831
epoch [486/1000], loss:3.92520
epoch [487/1000], loss:3.92634
epoch [488/1000], loss:3.92151
epoch [489/1000], loss:3.91649
epoch [490/1000], loss:3.91573
epoch [491/1000], loss:3.91516
epoch [492/1000], loss:3.90679
epoch [493/1000], loss:3.90961
epoch [494/1000], loss:3.89975
epoch [495/1000], loss:3.89675
epoch [496/1000], loss:3.89311
epoch [497/1000], loss:3.89344
epoch [498/1000], loss:3.89109
epoch [499/1000], loss:3.88556
epoch [500/1000], loss:3.87982
epoch [501/1000], loss:3.87826
epoch [502/1000], loss:3.87651
epoch [503/1000], loss:3.87134
epoch [504/1000], loss:3.86625
epoch [505/1000], loss:3.86563
epoch [506/1000], loss:3.86109
epoch [507/1000], loss:3.86168
epoch [508/1000], loss:3.85732
epoch [509/1000], loss:3.84998
epoch [510/1000], loss:3.85233
epoch [511/1000], loss:3.84760
epoch [512/1000], loss:3.84713
epoch [513/1000], loss:3.83537
epoch [514/1000], loss:3.83900
epoch [515/1000], loss:3.82796
epoch [516/1000], loss:3.82622
epoch [517/1000], loss:3.83100
epoch [518/1000], loss:3.82413
epoch [519/1000], loss:3.81903
epoch [520/1000], loss:3.81732
epoch [521/1000], loss:3.81084
epoch [522/1000], loss:3.81144
epoch [523/1000], loss:3.80305
epoch [524/1000], loss:3.80411
epoch [525/1000], loss:3.80302
epoch [526/1000], loss:3.79430
epoch [527/1000], loss:3.79282
epoch [528/1000], loss:3.79408
epoch [529/1000], loss:3.79307
epoch [530/1000], loss:3.78673
epoch [531/1000], loss:3.78254
epoch [532/1000], loss:3.77649
epoch [533/1000], loss:3.77460
epoch [534/1000], loss:3.77207
epoch [535/1000], loss:3.76966
epoch [536/1000], loss:3.76757
epoch [537/1000], loss:3.76382
epoch [538/1000], loss:3.75726
epoch [539/1000], loss:3.76330
epoch [540/1000], loss:3.75130
epoch [541/1000], loss:3.74979
epoch [542/1000], loss:3.74968
epoch [543/1000], loss:3.73983
epoch [544/1000], loss:3.73901
epoch [545/1000], loss:3.73932
epoch [546/1000], loss:3.73718
epoch [547/1000], loss:3.73794
epoch [548/1000], loss:3.72818
epoch [549/1000], loss:3.72528
epoch [550/1000], loss:3.72475
epoch [551/1000], loss:3.71988
epoch [552/1000], loss:3.71729
epoch [553/1000], loss:3.71119
epoch [554/1000], loss:3.71207
epoch [555/1000], loss:3.71167
epoch [556/1000], loss:3.70275
epoch [557/1000], loss:3.70654
epoch [558/1000], loss:3.69792
epoch [559/1000], loss:3.69927
epoch [560/1000], loss:3.69409
epoch [561/1000], loss:3.69188
epoch [562/1000], loss:3.68632
epoch [563/1000], loss:3.68308
epoch [564/1000], loss:3.68161
epoch [565/1000], loss:3.68463
epoch [566/1000], loss:3.67181
epoch [567/1000], loss:3.67101
epoch [568/1000], loss:3.66956
epoch [569/1000], loss:3.66723
epoch [570/1000], loss:3.66829
epoch [571/1000], loss:3.66422
epoch [572/1000], loss:3.66120
epoch [573/1000], loss:3.65323
epoch [574/1000], loss:3.65280
epoch [575/1000], loss:3.65279
epoch [576/1000], loss:3.64698
epoch [577/1000], loss:3.64525
epoch [578/1000], loss:3.64385
epoch [579/1000], loss:3.63892
epoch [580/1000], loss:3.63570
epoch [581/1000], loss:3.63038
epoch [582/1000], loss:3.63306
epoch [583/1000], loss:3.62456
epoch [584/1000], loss:3.62961
epoch [585/1000], loss:3.61710
epoch [586/1000], loss:3.62218
epoch [587/1000], loss:3.61367
epoch [588/1000], loss:3.61351
epoch [589/1000], loss:3.61048
epoch [590/1000], loss:3.60863
epoch [591/1000], loss:3.60503
epoch [592/1000], loss:3.60068
epoch [593/1000], loss:3.59856
epoch [594/1000], loss:3.59472
epoch [595/1000], loss:3.59365
epoch [596/1000], loss:3.59324
epoch [597/1000], loss:3.58769
epoch [598/1000], loss:3.58214
epoch [599/1000], loss:3.58244
epoch [600/1000], loss:3.57799
epoch [601/1000], loss:3.57877
epoch [602/1000], loss:3.57055
epoch [603/1000], loss:3.57307
epoch [604/1000], loss:3.57202
epoch [605/1000], loss:3.56517
epoch [606/1000], loss:3.56280
epoch [607/1000], loss:3.56200
epoch [608/1000], loss:3.56267
epoch [609/1000], loss:3.55470
epoch [610/1000], loss:3.55250
epoch [611/1000], loss:3.54826
epoch [612/1000], loss:3.55154
epoch [613/1000], loss:3.54208
epoch [614/1000], loss:3.54206
epoch [615/1000], loss:3.54105
epoch [616/1000], loss:3.53665
epoch [617/1000], loss:3.53198
epoch [618/1000], loss:3.52956
epoch [619/1000], loss:3.52716
epoch [620/1000], loss:3.52535
epoch [621/1000], loss:3.52693
epoch [622/1000], loss:3.51926
epoch [623/1000], loss:3.51655
epoch [624/1000], loss:3.51352
epoch [625/1000], loss:3.51410
epoch [626/1000], loss:3.50871
epoch [627/1000], loss:3.50490
epoch [628/1000], loss:3.50470
epoch [629/1000], loss:3.50429
epoch [630/1000], loss:3.50063
epoch [631/1000], loss:3.49522
epoch [632/1000], loss:3.49489
epoch [633/1000], loss:3.49385
epoch [634/1000], loss:3.48804
epoch [635/1000], loss:3.48522
epoch [636/1000], loss:3.48331
epoch [637/1000], loss:3.47941
epoch [638/1000], loss:3.47592
epoch [639/1000], loss:3.47459
epoch [640/1000], loss:3.47359
epoch [641/1000], loss:3.47270
epoch [642/1000], loss:3.46967
epoch [643/1000], loss:3.46600
epoch [644/1000], loss:3.46549
epoch [645/1000], loss:3.46019
epoch [646/1000], loss:3.45748
epoch [647/1000], loss:3.45389
epoch [648/1000], loss:3.44896
epoch [649/1000], loss:3.44991
epoch [650/1000], loss:3.44311
epoch [651/1000], loss:3.44865
epoch [652/1000], loss:3.44133
epoch [653/1000], loss:3.43858
epoch [654/1000], loss:3.44189
epoch [655/1000], loss:3.43480
epoch [656/1000], loss:3.43255
epoch [657/1000], loss:3.42989
epoch [658/1000], loss:3.42864
epoch [659/1000], loss:3.42396
epoch [660/1000], loss:3.42112
epoch [661/1000], loss:3.42302
epoch [662/1000], loss:3.41736
epoch [663/1000], loss:3.41416
epoch [664/1000], loss:3.41132
epoch [665/1000], loss:3.41046
epoch [666/1000], loss:3.40492
epoch [667/1000], loss:3.40502
epoch [668/1000], loss:3.40614
epoch [669/1000], loss:3.40063
epoch [670/1000], loss:3.40028
epoch [671/1000], loss:3.39271
epoch [672/1000], loss:3.39536
epoch [673/1000], loss:3.39127
epoch [674/1000], loss:3.38746
epoch [675/1000], loss:3.38874
epoch [676/1000], loss:3.38427
epoch [677/1000], loss:3.38143
epoch [678/1000], loss:3.37742
epoch [679/1000], loss:3.37587
epoch [680/1000], loss:3.37513
epoch [681/1000], loss:3.37196
epoch [682/1000], loss:3.36916
epoch [683/1000], loss:3.36594
epoch [684/1000], loss:3.36606
epoch [685/1000], loss:3.36292
epoch [686/1000], loss:3.35892
epoch [687/1000], loss:3.35532
epoch [688/1000], loss:3.35597
epoch [689/1000], loss:3.35689
epoch [690/1000], loss:3.34953
epoch [691/1000], loss:3.34964
epoch [692/1000], loss:3.34474
epoch [693/1000], loss:3.34500
epoch [694/1000], loss:3.34074
epoch [695/1000], loss:3.34088
epoch [696/1000], loss:3.33748
epoch [697/1000], loss:3.33662
epoch [698/1000], loss:3.33202
epoch [699/1000], loss:3.33229
epoch [700/1000], loss:3.32739
epoch [701/1000], loss:3.32630
epoch [702/1000], loss:3.32807
epoch [703/1000], loss:3.32146
epoch [704/1000], loss:3.31806
epoch [705/1000], loss:3.31831
epoch [706/1000], loss:3.31332
epoch [707/1000], loss:3.31269
epoch [708/1000], loss:3.30964
epoch [709/1000], loss:3.30984
epoch [710/1000], loss:3.30538
epoch [711/1000], loss:3.30281
epoch [712/1000], loss:3.30262
epoch [713/1000], loss:3.29772
epoch [714/1000], loss:3.29625
epoch [715/1000], loss:3.29219
epoch [716/1000], loss:3.29506
epoch [717/1000], loss:3.28936
epoch [718/1000], loss:3.28897
epoch [719/1000], loss:3.29049
epoch [720/1000], loss:3.28375
epoch [721/1000], loss:3.28123
epoch [722/1000], loss:3.27900
epoch [723/1000], loss:3.27359
epoch [724/1000], loss:3.27611
epoch [725/1000], loss:3.27433
epoch [726/1000], loss:3.27112
epoch [727/1000], loss:3.26646
epoch [728/1000], loss:3.26737
epoch [729/1000], loss:3.26536
epoch [730/1000], loss:3.26612
epoch [731/1000], loss:3.26075
epoch [732/1000], loss:3.26027
epoch [733/1000], loss:3.25291
epoch [734/1000], loss:3.25916
epoch [735/1000], loss:3.24919
epoch [736/1000], loss:3.25470
epoch [737/1000], loss:3.24516
epoch [738/1000], loss:3.24314
epoch [739/1000], loss:3.24429
epoch [740/1000], loss:3.24261
epoch [741/1000], loss:3.23813
epoch [742/1000], loss:3.23578
epoch [743/1000], loss:3.23666
epoch [744/1000], loss:3.23200
epoch [745/1000], loss:3.23238
epoch [746/1000], loss:3.22988
epoch [747/1000], loss:3.22826
epoch [748/1000], loss:3.23023
epoch [749/1000], loss:3.22209
epoch [750/1000], loss:3.21966
epoch [751/1000], loss:3.21754
epoch [752/1000], loss:3.21620
epoch [753/1000], loss:3.21760
epoch [754/1000], loss:3.21165
epoch [755/1000], loss:3.21131
epoch [756/1000], loss:3.21038
epoch [757/1000], loss:3.20712
epoch [758/1000], loss:3.20317
epoch [759/1000], loss:3.20223
epoch [760/1000], loss:3.20180
epoch [761/1000], loss:3.20010
epoch [762/1000], loss:3.19946
epoch [763/1000], loss:3.19183
epoch [764/1000], loss:3.19291
epoch [765/1000], loss:3.18863
epoch [766/1000], loss:3.18918
epoch [767/1000], loss:3.18898
epoch [768/1000], loss:3.18414
epoch [769/1000], loss:3.18572
epoch [770/1000], loss:3.18738
epoch [771/1000], loss:3.17861
epoch [772/1000], loss:3.17652
epoch [773/1000], loss:3.17587
epoch [774/1000], loss:3.17144
epoch [775/1000], loss:3.17319
epoch [776/1000], loss:3.17009
epoch [777/1000], loss:3.16943
epoch [778/1000], loss:3.16559
epoch [779/1000], loss:3.16415
epoch [780/1000], loss:3.16417
epoch [781/1000], loss:3.16414
epoch [782/1000], loss:3.15878
epoch [783/1000], loss:3.15620
epoch [784/1000], loss:3.15162
epoch [785/1000], loss:3.15188
epoch [786/1000], loss:3.15056
epoch [787/1000], loss:3.14792
epoch [788/1000], loss:3.14884
epoch [789/1000], loss:3.14594
epoch [790/1000], loss:3.14544
epoch [791/1000], loss:3.14156
epoch [792/1000], loss:3.13851
epoch [793/1000], loss:3.13792
epoch [794/1000], loss:3.13770
epoch [795/1000], loss:3.13333
epoch [796/1000], loss:3.13036
epoch [797/1000], loss:3.12862
epoch [798/1000], loss:3.13088
epoch [799/1000], loss:3.12679
epoch [800/1000], loss:3.12329
epoch [801/1000], loss:3.12549
epoch [802/1000], loss:3.12244
epoch [803/1000], loss:3.11828
epoch [804/1000], loss:3.11357
epoch [805/1000], loss:3.11698
epoch [806/1000], loss:3.11326
epoch [807/1000], loss:3.11584
epoch [808/1000], loss:3.10921
epoch [809/1000], loss:3.10769
epoch [810/1000], loss:3.10721
epoch [811/1000], loss:3.10426
epoch [812/1000], loss:3.10207
epoch [813/1000], loss:3.09837
epoch [814/1000], loss:3.09836
epoch [815/1000], loss:3.09801
epoch [816/1000], loss:3.09438
epoch [817/1000], loss:3.09267
epoch [818/1000], loss:3.09224
epoch [819/1000], loss:3.08851
epoch [820/1000], loss:3.08578
epoch [821/1000], loss:3.08942
epoch [822/1000], loss:3.08425
epoch [823/1000], loss:3.08528
epoch [824/1000], loss:3.08140
epoch [825/1000], loss:3.07830
epoch [826/1000], loss:3.07588
epoch [827/1000], loss:3.07775
epoch [828/1000], loss:3.07456
epoch [829/1000], loss:3.07019
epoch [830/1000], loss:3.07405
epoch [831/1000], loss:3.06494
epoch [832/1000], loss:3.06572
epoch [833/1000], loss:3.06405
epoch [834/1000], loss:3.06366
epoch [835/1000], loss:3.05963
epoch [836/1000], loss:3.05978
epoch [837/1000], loss:3.05587
epoch [838/1000], loss:3.05641
epoch [839/1000], loss:3.05452
epoch [840/1000], loss:3.05307
epoch [841/1000], loss:3.04878
epoch [842/1000], loss:3.05134
epoch [843/1000], loss:3.04592
epoch [844/1000], loss:3.04432
epoch [845/1000], loss:3.04292
epoch [846/1000], loss:3.04020
epoch [847/1000], loss:3.04101
epoch [848/1000], loss:3.04131
epoch [849/1000], loss:3.03655
epoch [850/1000], loss:3.03434
epoch [851/1000], loss:3.03037
epoch [852/1000], loss:3.03011
epoch [853/1000], loss:3.03031
epoch [854/1000], loss:3.02658
epoch [855/1000], loss:3.02762
epoch [856/1000], loss:3.02805
epoch [857/1000], loss:3.02052
epoch [858/1000], loss:3.02101
epoch [859/1000], loss:3.01820
epoch [860/1000], loss:3.01740
epoch [861/1000], loss:3.01673
epoch [862/1000], loss:3.01265
epoch [863/1000], loss:3.00953
epoch [864/1000], loss:3.01045
epoch [865/1000], loss:3.00850
epoch [866/1000], loss:3.01031
epoch [867/1000], loss:3.00408
epoch [868/1000], loss:3.00111
epoch [869/1000], loss:3.00130
epoch [870/1000], loss:3.00163
epoch [871/1000], loss:2.99810
epoch [872/1000], loss:2.99874
epoch [873/1000], loss:2.99178
epoch [874/1000], loss:2.99280
epoch [875/1000], loss:2.99230
epoch [876/1000], loss:2.98815
epoch [877/1000], loss:2.98851
epoch [878/1000], loss:2.98612
epoch [879/1000], loss:2.98797
epoch [880/1000], loss:2.98337
epoch [881/1000], loss:2.98161
epoch [882/1000], loss:2.98003
epoch [883/1000], loss:2.97484
epoch [884/1000], loss:2.97611
epoch [885/1000], loss:2.97621
epoch [886/1000], loss:2.97396
epoch [887/1000], loss:2.96927
epoch [888/1000], loss:2.96680
epoch [889/1000], loss:2.96926
epoch [890/1000], loss:2.96575
epoch [891/1000], loss:2.96431
epoch [892/1000], loss:2.96193
epoch [893/1000], loss:2.95761
epoch [894/1000], loss:2.96028
epoch [895/1000], loss:2.96046
epoch [896/1000], loss:2.95814
epoch [897/1000], loss:2.95228
epoch [898/1000], loss:2.94921
epoch [899/1000], loss:2.95213
epoch [900/1000], loss:2.94890
epoch [901/1000], loss:2.94738
epoch [902/1000], loss:2.94390
epoch [903/1000], loss:2.94118
epoch [904/1000], loss:2.94426
epoch [905/1000], loss:2.94239
epoch [906/1000], loss:2.93883
epoch [907/1000], loss:2.93823
epoch [908/1000], loss:2.93640
epoch [909/1000], loss:2.93234
epoch [910/1000], loss:2.93235
epoch [911/1000], loss:2.92981
epoch [912/1000], loss:2.93039
epoch [913/1000], loss:2.93373
epoch [914/1000], loss:2.92795
epoch [915/1000], loss:2.92420
epoch [916/1000], loss:2.92136
epoch [917/1000], loss:2.91813
epoch [918/1000], loss:2.91754
epoch [919/1000], loss:2.91795
epoch [920/1000], loss:2.91643
epoch [921/1000], loss:2.91321
epoch [922/1000], loss:2.91369
epoch [923/1000], loss:2.91094
epoch [924/1000], loss:2.91049
epoch [925/1000], loss:2.90867
epoch [926/1000], loss:2.90595
epoch [927/1000], loss:2.90455
epoch [928/1000], loss:2.90523
epoch [929/1000], loss:2.90355
epoch [930/1000], loss:2.90085
epoch [931/1000], loss:2.89791
epoch [932/1000], loss:2.89439
epoch [933/1000], loss:2.89587
epoch [934/1000], loss:2.89358
epoch [935/1000], loss:2.89229
epoch [936/1000], loss:2.88939
epoch [937/1000], loss:2.89070
epoch [938/1000], loss:2.88834
epoch [939/1000], loss:2.88700
epoch [940/1000], loss:2.88633
epoch [941/1000], loss:2.88195
epoch [942/1000], loss:2.88308
epoch [943/1000], loss:2.87824
epoch [944/1000], loss:2.87709
epoch [945/1000], loss:2.87709
epoch [946/1000], loss:2.87699
epoch [947/1000], loss:2.87330
epoch [948/1000], loss:2.87141
epoch [949/1000], loss:2.87136
epoch [950/1000], loss:2.86982
epoch [951/1000], loss:2.86829
epoch [952/1000], loss:2.86615
epoch [953/1000], loss:2.86325
epoch [954/1000], loss:2.86094
epoch [955/1000], loss:2.86219
epoch [956/1000], loss:2.85894
epoch [957/1000], loss:2.86180
epoch [958/1000], loss:2.85887
epoch [959/1000], loss:2.85384
epoch [960/1000], loss:2.85410
epoch [961/1000], loss:2.85243
epoch [962/1000], loss:2.85051
epoch [963/1000], loss:2.84668
epoch [964/1000], loss:2.84494
epoch [965/1000], loss:2.84352
epoch [966/1000], loss:2.84500
epoch [967/1000], loss:2.84642
epoch [968/1000], loss:2.83922
epoch [969/1000], loss:2.83965
epoch [970/1000], loss:2.84072
epoch [971/1000], loss:2.83823
epoch [972/1000], loss:2.83543
epoch [973/1000], loss:2.83415
epoch [974/1000], loss:2.83639
epoch [975/1000], loss:2.82995
epoch [976/1000], loss:2.82914
epoch [977/1000], loss:2.82669
epoch [978/1000], loss:2.83094
epoch [979/1000], loss:2.82190
epoch [980/1000], loss:2.82548
epoch [981/1000], loss:2.82011
epoch [982/1000], loss:2.82137
epoch [983/1000], loss:2.81966
epoch [984/1000], loss:2.81743
epoch [985/1000], loss:2.81949
epoch [986/1000], loss:2.81346
epoch [987/1000], loss:2.81393
epoch [988/1000], loss:2.81204
epoch [989/1000], loss:2.81101
epoch [990/1000], loss:2.81068
epoch [991/1000], loss:2.80631
epoch [992/1000], loss:2.80828
epoch [993/1000], loss:2.80407
epoch [994/1000], loss:2.80417
epoch [995/1000], loss:2.80385
epoch [996/1000], loss:2.80122
epoch [997/1000], loss:2.80091
epoch [998/1000], loss:2.79750
epoch [999/1000], loss:2.79585
epoch [1000/1000], loss:2.79409

Dimensionality reduction and clustering

import numpy as np

def cal_acc(gt, pred):
    """ Computes categorization accuracy of our task.
    Args:
      gt: Ground truth labels, shape (N,)
      pred: Predicted labels, shape (N,)
    Returns:
      acc: Accuracy (0~1 scalar)
    """
    # Count correct predictions
    correct = np.sum(gt == pred)
    acc = correct / gt.shape[0]
    # This is binary unsupervised clustering, so the cluster IDs 0/1 are arbitrary:
    # we only care whether the images were split into two groups, hence max(acc, 1-acc)
    return max(acc, 1-acc)
import matplotlib.pyplot as plt

def plot_scatter(feat, label, savefig=None):
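    """ Plot Scatter Image.
    Args:
      feat: the (x, y) coordinate of clustering result, shape: (N, 2)
      label: ground truth label of image (0/1), shape: (N,)
      savefig: optional file path to save the figure to
    Returns:
      None
    """
    X = feat[:, 0]
    Y = feat[:, 1]
    scatter = plt.scatter(X, Y, c=label)
    # One legend entry per label color; a bare plt.legend() finds no labeled artists and only warns
    plt.legend(*scatter.legend_elements(), loc='best')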
    if savefig is not None:
        plt.savefig(savefig)
    plt.show()
    return
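Because the two cluster IDs are arbitrary, `cal_acc` scores a prediction by the better of the two possible labelings. The minimal sketch below, assuming only NumPy and SciPy, checks that a perfect but label-swapped split still scores 1.0, and shows one possible generalization to k clusters via Hungarian matching with `scipy.optimize.linear_sum_assignment`. Both `binary_cluster_acc` and `cluster_acc` are hypothetical helpers, not part of the assignment code.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def binary_cluster_acc(gt, pred):
    """Same rule as cal_acc: cluster IDs 0/1 are arbitrary, so keep the better labeling."""
    acc = np.mean(gt == pred)
    return max(acc, 1 - acc)


def cluster_acc(gt, pred):
    """Hypothetical k-cluster version: pick the cluster->label mapping with the most matches."""
    k = max(gt.max(), pred.max()) + 1
    # counts[c, l] = number of samples in cluster c whose true label is l
    counts = np.zeros((k, k), dtype=np.int64)
    for c, l in zip(pred, gt):
        counts[c, l] += 1
    # linear_sum_assignment minimizes total cost, so negate counts to maximize matches
    rows, cols = linear_sum_assignment(-counts)
    return counts[rows, cols].sum() / gt.shape[0]


gt = np.array([0, 0, 0, 1, 1, 1])
swapped = 1 - gt  # perfect split, but with the cluster IDs flipped
print(binary_cluster_acc(gt, swapped))  # 1.0
print(cluster_acc(gt, swapped))         # 1.0
```

With two clusters both functions agree: the optimal matching over a 2×2 table is simply the better of the identity and the swapped mapping, which is also why `invert(pred)` is enough for the Kaggle submission.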

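Next, we use the trained model to predict the class of each image in the testing data.

Since the testing data is the same as the training data (trainX), we reuse the same dataset class to build the dataloader. The only difference from training is that shuffle is set to False here, so the predictions stay in the original image order.

With the model and dataloader ready, we can run prediction.

We only need the encoder's output (the latents): after clustering the latents, each image is assigned to one of the two classes.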
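The chain used in `predict` below (latents → KernelPCA with an RBF kernel → t-SNE to 2-D → MiniBatchKMeans with 2 clusters) can be tried in isolation. This is a scaled-down sketch, assuming synthetic Gaussian "latents" (two shifted blobs in 64-D) instead of real encoder outputs, and 20 Kernel PCA components instead of 200, so it runs in seconds on a CPU.

```python
import numpy as np
from sklearn.decomposition import KernelPCA
from sklearn.manifold import TSNE
from sklearn.cluster import MiniBatchKMeans

rng = np.random.default_rng(0)
# Synthetic stand-in for encoder latents: 2 groups of 100 points in 64-D, shifted apart
latents = np.vstack([
    rng.normal(0.0, 1.0, size=(100, 64)),
    rng.normal(3.0, 1.0, size=(100, 64)),
]).astype(np.float32)
true_labels = np.array([0] * 100 + [1] * 100)

# Stage 1: non-linear reduction with an RBF-kernel PCA
kpca = KernelPCA(n_components=20, kernel='rbf').fit_transform(latents)
# Stage 2: t-SNE down to 2-D for clustering and plotting
emb = TSNE(n_components=2, init='pca', random_state=0).fit_transform(kpca)
# Stage 3: split the 2-D embedding into 2 clusters
pred = MiniBatchKMeans(n_clusters=2, random_state=0, n_init=3).fit(emb).labels_

acc = np.mean(pred == true_labels)
print(kpca.shape, emb.shape, max(acc, 1 - acc))
```

Here t-SNE gets a fixed `random_state`; `predict` does not set one, so its 2-D embedding (and therefore the clusters) may differ between runs.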

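import torch
from torch.utils.data import DataLoader
from sklearn.decomposition import KernelPCA
# Principal Component Analysis (PCA)
# PCA is by far the most popular dimensionality reduction algorithm. It first finds the hyperplane
# that lies closest to the data, then projects all data points onto that hyperplane,
# choosing the hyperplane that preserves the largest variance.
# Kernel PCA (kPCA) is unsupervised, so there is no obvious performance metric for choosing the best
# kernel and hyperparameters. However, dimensionality reduction is often a preparation step for a
# supervised task (e.g. classification).
from sklearn.manifold import TSNE
from sklearn.cluster import MiniBatchKMeans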

def inference(X, model, batch_size=256):
    X = preprocess(X)
    dataset = Image_Dataset(X)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    latents = []
    with torch.no_grad():  # inference only: no gradients needed
        for x in dataloader:
            # Convert the batch to a float tensor and move it to the GPU
            x = torch.FloatTensor(x)
            vec, img = model(x.cuda())
            # view() works like reshape: flatten each latent to (batch_size, -1),
            # where -1 lets PyTorch infer the number of columns from the tensor size
            latents.append(vec.view(img.size()[0], -1).cpu().numpy())
    # Concatenate all batches along axis 0
    latents = np.concatenate(latents, axis=0)
    print('Latents Shape:', latents.shape)
    return latents

def predict(latents):
    # First dimension reduction: Kernel PCA with an RBF kernel
    # n_components: number of principal components (features) to keep
    # n_jobs: number of parallel jobs; -1 uses all CPUs, n_jobs < -1 uses (n_cpus + 1 + n_jobs) CPUs
    transformer = KernelPCA(n_components=200, kernel='rbf', n_jobs=-1)

    # fit_transform = fit + transform: it learns the projection from `latents` and returns the
    # projected data in one call. fit() stores the learned parameters so that later transform()
    # calls can reuse them; transform() alone fails on an unfitted model, so it cannot replace
    # fit_transform here.
    kpca = transformer.fit_transform(latents)
    print('First Reduction Shape:', kpca.shape)

    # Second dimension reduction: t-SNE down to 2-D
    X_embedded = TSNE(n_components=2).fit_transform(kpca)
    print('Second Reduction Shape:', X_embedded.shape)

    # Clustering
    # n_clusters: number of cluster centers (default 8); we need 2
    # random_state: int, RandomState instance or None; controls random number generation
    pred = MiniBatchKMeans(n_clusters=2, random_state=0).fit(X_embedded)
    pred = np.array([int(i) for i in pred.labels_])
    return pred, X_embedded

def invert(pred):
    return np.abs(1-pred)

def save_prediction(pred, out_csv='prediction.csv'):
    with open(out_csv, 'w') as f:
        f.write('id,label\n')
        for i, p in enumerate(pred):
            f.write(f'{i},{p}\n')
    print(f'Save prediction to {out_csv}.')

# load model
model = AE().cuda()
model.load_state_dict(torch.load('./checkpoints/last_checkpoint.pth'))
model.eval()

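# Prepare data
trainX = np.load('trainX.npy')

# Predict labels
latents = inference(X=trainX, model=model)
pred, X_embedded = predict(latents)

# Save the predictions to a file for Kaggle submission
save_prediction(pred, 'prediction.csv')

# This is an unsupervised binary problem: we only care whether the images were split into two groups.
# If the file above scores below 0.5 on Kaggle, simply flip the labels.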
save_prediction(invert(pred), 'prediction_invert.csv')
Latents Shape: (8500, 2048)
First Reduction Shape: (8500, 200)
Second Reduction Shape: (8500, 2)
Save prediction to prediction.csv.
Save prediction to prediction_invert.csv.
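The comments in `predict` about `fit_transform` versus `transform` can be checked directly. This sketch uses plain `PCA` (not the `KernelPCA` of the pipeline) on random data, assuming only NumPy and scikit-learn: after fitting, `transform` reproduces `fit_transform` on the same data and can project unseen data, while calling `transform` on an unfitted estimator raises `NotFittedError`.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.exceptions import NotFittedError

rng = np.random.default_rng(0)
X_train = rng.normal(size=(50, 8))
X_new = rng.normal(size=(5, 8))

pca = PCA(n_components=3, svd_solver='full')
Z_train = pca.fit_transform(X_train)  # learn mean/components from X_train, then project it
Z_again = pca.transform(X_train)      # reuse the stored parameters: same projection
Z_new = pca.transform(X_new)          # project unseen data with the same parameters

print(np.allclose(Z_train, Z_again))  # True

try:
    PCA(n_components=3).transform(X_new)  # never fitted
except NotFittedError:
    print('transform() before fit() raises NotFittedError')
```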

Problem 1 (plotting)

Plot the dimension-reduced val data (embedding) together with the corresponding labels.

valX = np.load('valX.npy')
valY = np.load('valY.npy')

# ==============================================
#  Here we demonstrate the plot for the baseline model;
#  for the report, also draw a second plot for your improved model.
# ==============================================
model.load_state_dict(torch.load('./checkpoints/last_checkpoint.pth'))
model.eval()
latents = inference(valX, model)
pred_from_latent, emb_from_latent = predict(latents)
acc_latent = cal_acc(valY, pred_from_latent)
print('The clustering accuracy is:', acc_latent)
print('The clustering result:')
plot_scatter(emb_from_latent, valY, savefig='p1_baseline.png')
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)


No handles with labels found to put in legend.


Second Reduction Shape: (500, 2)
The clustering accuracy is: 0.75
The clustering result:

[Figure: clustering result, 2-D embedding of the val data colored by ground-truth label (p1_baseline.png)]

Problem 2

Using the autoencoder with your highest test accuracy, take the 6 images at index 1, 2, 3, 6, 7, 9 from trainX, and plot the original images together with their reconstructions.

import matplotlib.pyplot as plt
import numpy as np
import torch

# Plot the original images (top row)
plt.figure(figsize=(10,4))
indexes = [1,2,3,6,7,9]
imgs = trainX[indexes,]
for i, img in enumerate(imgs):
    plt.subplot(2, 6, i+1, xticks=[], yticks=[])
    plt.imshow(img)

# Plot the reconstructed images (bottom row)
inp = torch.Tensor(trainX_preprocessed[indexes,]).cuda()
latents, recs = model(inp)
# Map outputs from [-1, 1] back to [0, 1] and permute (N,C,H,W) -> (N,H,W,C) for imshow
recs = ((recs+1)/2).cpu().detach().numpy()
recs = recs.transpose(0, 2, 3, 1)
for i, img in enumerate(recs):
    plt.subplot(2, 6, 6+i+1, xticks=[], yticks=[])
    plt.imshow(np.clip(img, 0, 1))

plt.tight_layout()
plt.show()

[Figure: the six original images (top row) and their reconstructions (bottom row)]
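Displaying reconstructions means undoing `preprocess`: mapping values from -1~1 back to 0~255 and permuting (N,C,H,W) back to (N,H,W,C). A minimal NumPy sketch, assuming the decoder outputs lie in [-1, 1]; `postprocess` is a hypothetical helper, not part of the original code, and `preprocess` is restated so the block is self-contained.

```python
import numpy as np


def preprocess(image_list):
    """Same as the original: uint8 (N,H,W,C) in 0~255 -> float32 (N,C,H,W) in -1~1."""
    image_list = np.transpose(np.array(image_list), (0, 3, 1, 2))
    return ((image_list / 255.0) * 2 - 1).astype(np.float32)


def postprocess(batch):
    """Hypothetical inverse: float (N,C,H,W) in -1~1 -> uint8 (N,H,W,C) in 0~255."""
    batch = np.clip(batch, -1.0, 1.0)          # guard against slight overshoot
    batch = (batch + 1) / 2 * 255.0            # -1~1 -> 0~255
    batch = np.transpose(batch, (0, 2, 3, 1))  # (N,C,H,W) -> (N,H,W,C)
    return np.round(batch).astype(np.uint8)


rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(6, 32, 32, 3), dtype=np.uint8)
restored = postprocess(preprocess(images))
print(restored.shape, np.array_equal(restored, images))  # (6, 32, 32, 3) True
```

The plotting code above stops at floats in [0, 1], which `imshow` also accepts; converting to uint8 is only needed if you want to save or compare images in their original format.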

Problem 3

During autoencoder training, pick at least 10 checkpoints. Plot the model's train reconstruction error against val accuracy for those checkpoints, and briefly describe what you observe.
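The reconstruction error in the loop below is accumulated with `MSELoss(reduction='sum')` and divided by the total element count `n`, so `err/n` is the exact per-element MSE over all of trainX even though the last batch is smaller (8500 = 132 × 64 + 52). Averaging per-batch means instead would over-weight that last batch. A NumPy sketch, with random arrays as hypothetical stand-ins for images and reconstructions:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, size=(100, 3, 4, 4))  # stand-in for preprocessed images
rec = x + rng.normal(0, 0.1, size=x.shape)   # stand-in for reconstructions
rec[96:] += 1.0                              # make the last, smaller batch much worse

batch_size = 32                              # 100 = 32 + 32 + 32 + 4
err, n, batch_means = 0.0, 0, []
for start in range(0, len(x), batch_size):
    xb, rb = x[start:start + batch_size], rec[start:start + batch_size]
    sq = ((xb - rb) ** 2).sum()              # like MSELoss(reduction='sum')
    err += sq
    n += xb.size                             # like x.flatten().size(0)
    batch_means.append(sq / xb.size)

exact = ((x - rec) ** 2).mean()
print(np.isclose(err / n, exact))               # True
print(np.isclose(np.mean(batch_means), exact))  # False: the 4-sample batch is over-weighted
```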

import os
import glob
checkpoints_list = sorted(glob.glob('checkpoints/checkpoint_*.pth'), key= lambda x: int(os.path.splitext(os.path.basename(x))[0].split('_')[-1]))
print(checkpoints_list)
# load data
dataset = Image_Dataset(trainX_preprocessed)
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)

points = []
with torch.no_grad():
    for i, checkpoint in enumerate(checkpoints_list):
        print('[{}/{}] {}'.format(i+1, len(checkpoints_list), checkpoint))
        model.load_state_dict(torch.load(checkpoint))
        model.eval()
        err = 0
        n = 0
        for x in dataloader:
            x = x.cuda()
            _, rec = model(x)
            err += torch.nn.MSELoss(reduction='sum')(x, rec).item()
            n += x.flatten().size(0)
        print('Reconstruction error (MSE):', err/n)
        latents = inference(X=valX, model=model)
        pred, X_embedded = predict(latents)
        acc = cal_acc(valY, pred)
        print('Accuracy:', acc)
        points.append((err/n, acc))
['checkpoints/checkpoint_10.pth', 'checkpoints/checkpoint_20.pth', 'checkpoints/checkpoint_30.pth', 'checkpoints/checkpoint_40.pth', 'checkpoints/checkpoint_50.pth', 'checkpoints/checkpoint_60.pth', 'checkpoints/checkpoint_70.pth', 'checkpoints/checkpoint_80.pth', 'checkpoints/checkpoint_90.pth', 'checkpoints/checkpoint_100.pth', 'checkpoints/checkpoint_110.pth', 'checkpoints/checkpoint_120.pth', 'checkpoints/checkpoint_130.pth', 'checkpoints/checkpoint_140.pth', 'checkpoints/checkpoint_150.pth', 'checkpoints/checkpoint_160.pth', 'checkpoints/checkpoint_170.pth', 'checkpoints/checkpoint_180.pth', 'checkpoints/checkpoint_190.pth', 'checkpoints/checkpoint_200.pth', 'checkpoints/checkpoint_210.pth', 'checkpoints/checkpoint_220.pth', 'checkpoints/checkpoint_230.pth', 'checkpoints/checkpoint_240.pth', 'checkpoints/checkpoint_250.pth', 'checkpoints/checkpoint_260.pth', 'checkpoints/checkpoint_270.pth', 'checkpoints/checkpoint_280.pth', 'checkpoints/checkpoint_290.pth', 'checkpoints/checkpoint_300.pth', 'checkpoints/checkpoint_310.pth', 'checkpoints/checkpoint_320.pth', 'checkpoints/checkpoint_330.pth', 'checkpoints/checkpoint_340.pth', 'checkpoints/checkpoint_350.pth', 'checkpoints/checkpoint_360.pth', 'checkpoints/checkpoint_370.pth', 'checkpoints/checkpoint_380.pth', 'checkpoints/checkpoint_390.pth', 'checkpoints/checkpoint_400.pth', 'checkpoints/checkpoint_410.pth', 'checkpoints/checkpoint_420.pth', 'checkpoints/checkpoint_430.pth', 'checkpoints/checkpoint_440.pth', 'checkpoints/checkpoint_450.pth', 'checkpoints/checkpoint_460.pth', 'checkpoints/checkpoint_470.pth', 'checkpoints/checkpoint_480.pth', 'checkpoints/checkpoint_490.pth', 'checkpoints/checkpoint_500.pth', 'checkpoints/checkpoint_510.pth', 'checkpoints/checkpoint_520.pth', 'checkpoints/checkpoint_530.pth', 'checkpoints/checkpoint_540.pth', 'checkpoints/checkpoint_550.pth', 'checkpoints/checkpoint_560.pth', 'checkpoints/checkpoint_570.pth', 'checkpoints/checkpoint_580.pth', 'checkpoints/checkpoint_590.pth', 
'checkpoints/checkpoint_600.pth', 'checkpoints/checkpoint_610.pth', 'checkpoints/checkpoint_620.pth', 'checkpoints/checkpoint_630.pth', 'checkpoints/checkpoint_640.pth', 'checkpoints/checkpoint_650.pth', 'checkpoints/checkpoint_660.pth', 'checkpoints/checkpoint_670.pth', 'checkpoints/checkpoint_680.pth', 'checkpoints/checkpoint_690.pth', 'checkpoints/checkpoint_700.pth', 'checkpoints/checkpoint_710.pth', 'checkpoints/checkpoint_720.pth', 'checkpoints/checkpoint_730.pth', 'checkpoints/checkpoint_740.pth', 'checkpoints/checkpoint_750.pth', 'checkpoints/checkpoint_760.pth', 'checkpoints/checkpoint_770.pth', 'checkpoints/checkpoint_780.pth', 'checkpoints/checkpoint_790.pth', 'checkpoints/checkpoint_800.pth', 'checkpoints/checkpoint_810.pth', 'checkpoints/checkpoint_820.pth', 'checkpoints/checkpoint_830.pth', 'checkpoints/checkpoint_840.pth', 'checkpoints/checkpoint_850.pth', 'checkpoints/checkpoint_860.pth', 'checkpoints/checkpoint_870.pth', 'checkpoints/checkpoint_880.pth', 'checkpoints/checkpoint_890.pth', 'checkpoints/checkpoint_900.pth', 'checkpoints/checkpoint_910.pth', 'checkpoints/checkpoint_920.pth', 'checkpoints/checkpoint_930.pth', 'checkpoints/checkpoint_940.pth', 'checkpoints/checkpoint_950.pth', 'checkpoints/checkpoint_960.pth', 'checkpoints/checkpoint_970.pth', 'checkpoints/checkpoint_980.pth', 'checkpoints/checkpoint_990.pth', 'checkpoints/checkpoint_1000.pth']
[1/100] checkpoints/checkpoint_10.pth
Reconstruction error (MSE): 0.10465650191961551
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.56
[2/100] checkpoints/checkpoint_20.pth
Reconstruction error (MSE): 0.08282024884691426
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.564
[3/100] checkpoints/checkpoint_30.pth
Reconstruction error (MSE): 0.07529751972123688
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.552
[4/100] checkpoints/checkpoint_40.pth
Reconstruction error (MSE): 0.06996455570295745
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.528
[5/100] checkpoints/checkpoint_50.pth
Reconstruction error (MSE): 0.06556768215403837
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.53
[6/100] checkpoints/checkpoint_60.pth
Reconstruction error (MSE): 0.0623151860704609
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.524
[7/100] checkpoints/checkpoint_70.pth
Reconstruction error (MSE): 0.05941213181439568
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.536
[8/100] checkpoints/checkpoint_80.pth
Reconstruction error (MSE): 0.057350127874636184
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.52
[9/100] checkpoints/checkpoint_90.pth
Reconstruction error (MSE): 0.05522508699753705
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.516
[10/100] checkpoints/checkpoint_100.pth
Reconstruction error (MSE): 0.05384457483478621
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.53
[11/100] checkpoints/checkpoint_110.pth
Reconstruction error (MSE): 0.05195554022695504
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.5
[12/100] checkpoints/checkpoint_120.pth
Reconstruction error (MSE): 0.05074959627787272
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.524
[13/100] checkpoints/checkpoint_130.pth
Reconstruction error (MSE): 0.04992709094402837
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.514
[14/100] checkpoints/checkpoint_140.pth
Reconstruction error (MSE): 0.04817914684146058
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.502
[15/100] checkpoints/checkpoint_150.pth
Reconstruction error (MSE): 0.04657277587815827
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.512
[16/100] checkpoints/checkpoint_160.pth
Reconstruction error (MSE): 0.045626810316945994
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.504
[17/100] checkpoints/checkpoint_170.pth
Reconstruction error (MSE): 0.04440261214387183
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.548
[18/100] checkpoints/checkpoint_180.pth
Reconstruction error (MSE): 0.04345548491384469
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.53
[19/100] checkpoints/checkpoint_190.pth
Reconstruction error (MSE): 0.04282478637321323
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.518
[20/100] checkpoints/checkpoint_200.pth
Reconstruction error (MSE): 0.042173867076051
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.522
[21/100] checkpoints/checkpoint_210.pth
Reconstruction error (MSE): 0.041361579988517014
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.504
[22/100] checkpoints/checkpoint_220.pth
Reconstruction error (MSE): 0.040615920683916874
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.522
[23/100] checkpoints/checkpoint_230.pth
Reconstruction error (MSE): 0.039873808785980826
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.522
[24/100] checkpoints/checkpoint_240.pth
Reconstruction error (MSE): 0.03932966136932373
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[25/100] checkpoints/checkpoint_250.pth
Reconstruction error (MSE): 0.038771576432620775
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[26/100] checkpoints/checkpoint_260.pth
Reconstruction error (MSE): 0.0381339080099966
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.504
[27/100] checkpoints/checkpoint_270.pth
Reconstruction error (MSE): 0.03751208638209923
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.504
[28/100] checkpoints/checkpoint_280.pth
Reconstruction error (MSE): 0.037052626366708794
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[29/100] checkpoints/checkpoint_290.pth
Reconstruction error (MSE): 0.03666375287373861
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.528
[30/100] checkpoints/checkpoint_300.pth
Reconstruction error (MSE): 0.03611967169069776
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[31/100] checkpoints/checkpoint_310.pth
Reconstruction error (MSE): 0.03564991631227381
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.518
[32/100] checkpoints/checkpoint_320.pth
Reconstruction error (MSE): 0.035199516689076144
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.514
[33/100] checkpoints/checkpoint_330.pth
Reconstruction error (MSE): 0.034691137108148314
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.524
[34/100] checkpoints/checkpoint_340.pth
Reconstruction error (MSE): 0.03432022960513246
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.51
[35/100] checkpoints/checkpoint_350.pth
Reconstruction error (MSE): 0.033855246824376725
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.52
[36/100] checkpoints/checkpoint_360.pth
Reconstruction error (MSE): 0.03339220189113243
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.53
[37/100] checkpoints/checkpoint_370.pth
Reconstruction error (MSE): 0.03329624884736304
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.502
[38/100] checkpoints/checkpoint_380.pth
Reconstruction error (MSE): 0.03264928217495189
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.53
[39/100] checkpoints/checkpoint_390.pth
Reconstruction error (MSE): 0.03237577991859586
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.534
[40/100] checkpoints/checkpoint_400.pth
Reconstruction error (MSE): 0.03208851829229617
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.514
[41/100] checkpoints/checkpoint_410.pth
Reconstruction error (MSE): 0.0316866933037253
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[42/100] checkpoints/checkpoint_420.pth
Reconstruction error (MSE): 0.031364078933117434
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.51
[43/100] checkpoints/checkpoint_430.pth
Reconstruction error (MSE): 0.031136608348173254
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[44/100] checkpoints/checkpoint_440.pth
Reconstruction error (MSE): 0.0309632776297775
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.528
[45/100] checkpoints/checkpoint_450.pth
Reconstruction error (MSE): 0.030496950392629587
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.534
[46/100] checkpoints/checkpoint_460.pth
Reconstruction error (MSE): 0.030128193126005284
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.51
[47/100] checkpoints/checkpoint_470.pth
Reconstruction error (MSE): 0.029998875262690527
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[48/100] checkpoints/checkpoint_480.pth
Reconstruction error (MSE): 0.029572404412662283
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.51
[49/100] checkpoints/checkpoint_490.pth
Reconstruction error (MSE): 0.02939370559243595
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.508
[50/100] checkpoints/checkpoint_500.pth
Reconstruction error (MSE): 0.02911538221321854
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.554
[51/100] checkpoints/checkpoint_510.pth
Reconstruction error (MSE): 0.02889633548960966
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.546
[52/100] checkpoints/checkpoint_520.pth
Reconstruction error (MSE): 0.02860628096262614
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.648
[53/100] checkpoints/checkpoint_530.pth
Reconstruction error (MSE): 0.028405724600249645
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.54
[54/100] checkpoints/checkpoint_540.pth
Reconstruction error (MSE): 0.028084655219433353
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.506
[55/100] checkpoints/checkpoint_550.pth
Reconstruction error (MSE): 0.02798689774905934
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.538
[56/100] checkpoints/checkpoint_560.pth
Reconstruction error (MSE): 0.027731095856311276
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.552
[57/100] checkpoints/checkpoint_570.pth
Reconstruction error (MSE): 0.027528591081207875
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.522
[58/100] checkpoints/checkpoint_580.pth
Reconstruction error (MSE): 0.02748092877631094
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.516
[59/100] checkpoints/checkpoint_590.pth
Reconstruction error (MSE): 0.027148202690423704
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.536
[60/100] checkpoints/checkpoint_600.pth
Reconstruction error (MSE): 0.02693716204400156
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.538
[61/100] checkpoints/checkpoint_610.pth
Reconstruction error (MSE): 0.02663602849548938
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.656
[62/100] checkpoints/checkpoint_620.pth
Reconstruction error (MSE): 0.026486863996468338
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.51
[63/100] checkpoints/checkpoint_630.pth
Reconstruction error (MSE): 0.026279585034239526
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.518
[64/100] checkpoints/checkpoint_640.pth
Reconstruction error (MSE): 0.02615043982337503
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.61
[65/100] checkpoints/checkpoint_650.pth
Reconstruction error (MSE): 0.025924385631785674
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.534
[66/100] checkpoints/checkpoint_660.pth
Reconstruction error (MSE): 0.025687772582559023
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.59
[67/100] checkpoints/checkpoint_670.pth
Reconstruction error (MSE): 0.025555453281776577
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.628
[68/100] checkpoints/checkpoint_680.pth
Reconstruction error (MSE): 0.025691534911884983
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.648
[69/100] checkpoints/checkpoint_690.pth
Reconstruction error (MSE): 0.025101487271925984
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.708
[70/100] checkpoints/checkpoint_700.pth
Reconstruction error (MSE): 0.02504801980186911
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.732
[71/100] checkpoints/checkpoint_710.pth
Reconstruction error (MSE): 0.02484492769428328
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.752
[72/100] checkpoints/checkpoint_720.pth
Reconstruction error (MSE): 0.02478704075719796
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.76
[73/100] checkpoints/checkpoint_730.pth
Reconstruction error (MSE): 0.02446424291648117
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.608
[74/100] checkpoints/checkpoint_740.pth
Reconstruction error (MSE): 0.024349503433003145
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.508
[75/100] checkpoints/checkpoint_750.pth
Reconstruction error (MSE): 0.02417324640236649
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.764
[76/100] checkpoints/checkpoint_760.pth
Reconstruction error (MSE): 0.024010706882850796
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.726
[77/100] checkpoints/checkpoint_770.pth
Reconstruction error (MSE): 0.02394120900771197
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.762
[78/100] checkpoints/checkpoint_780.pth
Reconstruction error (MSE): 0.023713757514953613
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.55
[79/100] checkpoints/checkpoint_790.pth
Reconstruction error (MSE): 0.02374166191325468
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.554
[80/100] checkpoints/checkpoint_800.pth
Reconstruction error (MSE): 0.023461397339315977
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.64
[81/100] checkpoints/checkpoint_810.pth
Reconstruction error (MSE): 0.023291605500613943
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.756
[82/100] checkpoints/checkpoint_820.pth
Reconstruction error (MSE): 0.023138159677094105
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.526
[83/100] checkpoints/checkpoint_830.pth
Reconstruction error (MSE): 0.02306466459760479
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.792
[84/100] checkpoints/checkpoint_840.pth
Reconstruction error (MSE): 0.022922015835257138
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.664
[85/100] checkpoints/checkpoint_850.pth
Reconstruction error (MSE): 0.022727084767584706
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.67
[86/100] checkpoints/checkpoint_860.pth
Reconstruction error (MSE): 0.022709223756603166
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.76
[87/100] checkpoints/checkpoint_870.pth
Reconstruction error (MSE): 0.022506213861353257
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.542
[88/100] checkpoints/checkpoint_880.pth
Reconstruction error (MSE): 0.022330569921755323
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.682
[89/100] checkpoints/checkpoint_890.pth
Reconstruction error (MSE): 0.022259797694636325
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.792
[90/100] checkpoints/checkpoint_900.pth
Reconstruction error (MSE): 0.022161509541904226
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.534
[91/100] checkpoints/checkpoint_910.pth
Reconstruction error (MSE): 0.022015575642679253
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.712
[92/100] checkpoints/checkpoint_920.pth
Reconstruction error (MSE): 0.021944920754900166
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.758
[93/100] checkpoints/checkpoint_930.pth
Reconstruction error (MSE): 0.021774335898605047
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.59
[94/100] checkpoints/checkpoint_940.pth
Reconstruction error (MSE): 0.021657160151238534
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.76
[95/100] checkpoints/checkpoint_950.pth
Reconstruction error (MSE): 0.021555810619803037
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.698
[96/100] checkpoints/checkpoint_960.pth
Reconstruction error (MSE): 0.021441521494996313
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.528
[97/100] checkpoints/checkpoint_970.pth
Reconstruction error (MSE): 0.02138799679513071
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.802
[98/100] checkpoints/checkpoint_980.pth
Reconstruction error (MSE): 0.021166577629014558
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.508
[99/100] checkpoints/checkpoint_990.pth
Reconstruction error (MSE): 0.02112917330685784
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.736
[100/100] checkpoints/checkpoint_1000.pth
Reconstruction error (MSE): 0.02094145799150654
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.554
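The evaluation loop above prints one block per checkpoint, and the plotting code below expects those results in a list called `points`. If that list is lost (for example after a kernel restart), it can be rebuilt from the printed log. Below is a minimal sketch, assuming the log text has been saved into a string and that every checkpoint block prints its `Reconstruction error (MSE): ...` line before its `Accuracy: ...` line, as in the output above. `parse_checkpoint_log` is a hypothetical helper and not part of the original notebook.

```python
import re

# Hypothetical helper (not part of the original notebook): rebuilds the
# (MSE, accuracy) pairs from the printed evaluation log, one pair per checkpoint.
NUMBER = r"([-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)"
MSE_RE = re.compile(r"Reconstruction error \(MSE\):\s*" + NUMBER)
ACC_RE = re.compile(r"Accuracy:\s*" + NUMBER)


def parse_checkpoint_log(text):
    """Return a list of (mse, acc) tuples in checkpoint order.

    An "Accuracy:" line is only paired with the most recent unused MSE line,
    so a block cut off at the start of the log (accuracy without MSE) is skipped.
    """
    points = []
    pending_mse = None
    for line in text.splitlines():
        line = line.strip()
        mse_match = MSE_RE.search(line)
        if mse_match:
            pending_mse = float(mse_match.group(1))
            continue
        acc_match = ACC_RE.match(line)
        if acc_match and pending_mse is not None:
            points.append((pending_mse, float(acc_match.group(1))))
            pending_mse = None
    return points


# Usage on an excerpt of the log (the first line is the tail of the previous block).
log_excerpt = """Accuracy: 0.732
[71/100] checkpoints/checkpoint_710.pth
Reconstruction error (MSE): 0.02484492769428328
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.752
[72/100] checkpoints/checkpoint_720.pth
Reconstruction error (MSE): 0.02478704075719796
Latents Shape: (500, 2048)
First Reduction Shape: (500, 200)
Second Reduction Shape: (500, 2)
Accuracy: 0.76"""

points = parse_checkpoint_log(log_excerpt)
print(len(points))  # 2
```

Pairing each accuracy with the last unused MSE, instead of zipping two independent lists, keeps a truncated first block (such as the orphan `Accuracy: 0.732` here) from shifting every pair by one checkpoint.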
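import matplotlib.pyplot as plt

# points: list of (reconstruction MSE on trainX, val accuracy), one entry per checkpoint
ps = list(zip(*points))  # ps[0] = MSE values, ps[1] = accuracy values
plt.figure(figsize=(6, 6))
plt.subplot(211, title='Reconstruction error (MSE)').plot(ps[0])
plt.subplot(212, title='Accuracy (val)').plot(ps[1])
plt.xlabel('Checkpoint index')
plt.tight_layout()  # keep the two subplot titles from overlapping
plt.show()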

[Figure: Reconstruction error (MSE) (top) and Accuracy (val) (bottom) across checkpoints]

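The accuracy fluctuates quite violently: in the log above, the reconstruction error falls almost steadily (from about 0.0248 at checkpoint 710 to about 0.0209 at checkpoint 1000), while val accuracy jumps between roughly 0.51 and 0.80 from one checkpoint to the next. Unsupervised learning really is hard to train.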
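To put numbers on this fluctuation, here is a minimal numpy sketch that smooths the accuracy curve with a simple moving average and computes the Pearson correlation between reconstruction error and val accuracy. It uses the 30 checkpoints 710 to 1000 copied from the log (MSE rounded to 6 decimals). In the notebook you would take `mse, acc = map(np.array, zip(*points))` instead. The window size of 5 is an arbitrary assumption for illustration, and `moving_average` is a hypothetical helper.

```python
import numpy as np

# Checkpoints 710..1000 copied from the log above (MSE rounded to 6 decimals).
mse = np.array([
    0.024845, 0.024787, 0.024464, 0.024350, 0.024173, 0.024011, 0.023941, 0.023714, 0.023742, 0.023461,
    0.023292, 0.023138, 0.023065, 0.022922, 0.022727, 0.022709, 0.022506, 0.022331, 0.022260, 0.022162,
    0.022016, 0.021945, 0.021774, 0.021657, 0.021556, 0.021442, 0.021388, 0.021167, 0.021129, 0.020941,
])
acc = np.array([
    0.752, 0.76, 0.608, 0.508, 0.764, 0.726, 0.762, 0.55, 0.554, 0.64,
    0.756, 0.526, 0.792, 0.664, 0.67, 0.76, 0.542, 0.682, 0.792, 0.534,
    0.712, 0.758, 0.59, 0.76, 0.698, 0.528, 0.802, 0.508, 0.736, 0.554,
])


def moving_average(x, window):
    """Hypothetical helper: unweighted moving average, returns len(x) - window + 1 values."""
    kernel = np.ones(window) / window
    return np.convolve(x, kernel, mode="valid")


# window=5 is an arbitrary choice; larger windows smooth more but lag more
smoothed_acc = moving_average(acc, window=5)
# Pearson correlation between reconstruction error and val accuracy across checkpoints
r = np.corrcoef(mse, acc)[0, 1]

print(f"raw accuracy:      min {acc.min():.3f}, max {acc.max():.3f}, std {acc.std():.3f}")
print(f"smoothed accuracy: min {smoothed_acc.min():.3f}, max {smoothed_acc.max():.3f}")
print(f"Pearson r(MSE, accuracy) = {r:.3f}")
```

A correlation close to zero together with a steadily falling MSE would suggest that, at least over these checkpoints, a lower reconstruction error alone is not a reliable signal for picking the checkpoint that clusters best.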

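Copyright notice: This is an original article by the blogger, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.
Article link: https://blog.csdn.net/wl1780852311/article/details/121149036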
