t-SNE可视化-Python实现

t-SNE

本文主要是对An Introduction to t-SNE with Python Example博客的翻译记录,和一些入门的Python代码,可以的话推荐阅读原文。

主要参考

介绍:
An Introduction to t-SNE with Python Example
GitHub:
sas-python-work/tSneExampleBlogPost.ipynb
t-SNE-tutorial
tSNE
tsne-pytorch
PintheMemory/tsnelib.py
加速包:
Multicore-TSNE
tsne-cuda

t-SNE介绍

t-Distributed Stochastic Neighbor Embedding (t-SNE) 是一种无监督的非线性技术,主要用于数据探索和高维数据的可视化。 简单来说,t-SNE 让您对数据在高维空间中的排列方式有一种感觉或直觉。 它由 Laurens van der Maatens 和 Geoffrey Hinton 于 2008 年开发。

简单来说就是高维数据可视化,目的是观察高维数据的之间的分布情况

t-SNE与PCA的区别

首先要注意的是,PCA 是在 1933 年开发的,而 t-SNE 是在 2008 年开发的。自 1933 年以来,数据科学领域发生了很大变化,主要是在计算和数据大小方面。 其次,PCA 是一种线性降维技术,旨在最大化方差并保持较大的成对距离。 换句话说,不同的事物最终会相距甚远。 这会导致可视化效果不佳,尤其是在处理非线性流形结构时。 将流形结构视为任何几何形状,例如:圆柱体、球体、曲线等。
t-SNE 与 PCA 的不同之处在于仅保留小的成对距离或局部相似性,而 PCA 关注的是保留大的成对距离以最大化方差。 Laurens 使用图 1 [1] 中的 Swiss Roll 数据集很好地说明了 PCA 和 t-SNE 方法。 您可以看到,由于这个玩具数据集(流形)的非线性和保留较大的距离,PCA 会错误地保留数据的结构。

t-SNE原理

t-SNE 算法计算高维空间和低维空间中实例对之间的相似性度量。 然后,它尝试使用成本函数优化这两个相似性度量。 让我们将其分解为 3 个基本步骤:

  1. 第一步,测量高维空间中点之间的相似度。 想想散布在二维空间上的一堆数据点(图 2)。 对于每个数据点 (xi),我们将在该点上以高斯分布为中心。 然后我们测量该高斯分布下所有点 (xj) 的密度。 然后对所有点重新归一化。 这为我们提供了所有点的一组概率 (Pij)。 这些概率与相似性成正比。 这意味着,如果数据点 x1 和 x2 在这个高斯圆下具有相等的值,那么它们的比例和相似性是相等的,因此它们在这个高维空间的结构中具有局部相似性。 高斯分布或圆可以使用所谓的 perplexity 来操纵,它会影响分布的方差(圆的大小)以及最近邻的数量。 perplexity 的正常范围在 5 到 50 之间 [2]。
  2. 第 2 步与第 1 步类似,但不是使用高斯分布,而是使用具有一个自由度的学生 t 分布,也称为柯西分布(图 3)。 这为我们提供了低维空间中的第二组概率(Qij)。 如图所示,学生 t 分布的尾部比正态分布更重。 厚重的尾巴可以更好地模拟远距离。
  3. 最后一步是我们希望这些来自低维空间 (Qij) 的概率集尽可能地反映高维空间 (Pij) 的概率。 我们希望这两个地图结构相似。 我们使用 Kullback-Liebler 散度 (KL) 测量二维空间的概率分布之间的差异。最后,我们使用梯度下降来最小化我们的 KL 成本函数。

t-SNE的Python实现

入门例子

import numpy as np

from sklearn.manifold import TSNE
# For the UCI ML handwritten digits dataset
from sklearn.datasets import load_digits

# Import matplotlib for plotting graphs ans seaborn for attractive graphics.
import matplotlib.pyplot as plt
import matplotlib.patheffects as pe
import seaborn as sns

def plot(x, colors):
    # Choosing color palette
    # https://seaborn.pydata.org/generated/seaborn.color_palette.html
    palette = np.array(sns.color_palette("pastel", 10))
    # pastel, husl, and so on

    # Create a scatter plot.
    f = plt.figure(figsize=(8, 8))
    ax = plt.subplot(aspect='equal')
    sc = ax.scatter(x[:,0], x[:,1], lw=0, s=40, c=palette[colors.astype(np.int8)])
    # Add the labels for each digit.
    txts = []
    for i in range(10):
        # Position of each label.
        xtext, ytext = np.median(x[colors == i, :], axis=0)
        txt = ax.text(xtext, ytext, str(i), fontsize=24)
        txt.set_path_effects([pe.Stroke(linewidth=5, foreground="w"), pe.Normal()])
        txts.append(txt)
    plt.savefig('./digits_tsne-pastel.png', dpi=120)
    return f, ax, txts


digits = load_digits()
print(digits.data.shape)
# There are 10 classes (0 to 9) with alomst 180 images in each class 
# The images are 8x8 and hence 64 pixels(dimensions)

# Place the arrays of data of each digit on top of each other and store in X
X = np.vstack([digits.data[digits.target==i] for i in range(10)])
# Place the arrays of data of each target digit by the side of each other continuosly and store in Y
Y = np.hstack([digits.target[digits.target==i] for i in range(10)])

# Implementing the TSNE Function - ah Scikit learn makes it so easy!
digits_final = TSNE(perplexity=30).fit_transform(X) 
# Play around with varying the parameters like perplexity, random_state to get different plots

plot(digits_final, Y)

生成的图片(分别是husl风格和pastel风格,每次运行结果不一样):
在这里插入图片描述
在这里插入图片描述
在前面的基础上换一种可视化风格:

def plot2(data, x='x', y='y'):
    sns.set_context("notebook", font_scale=1.1)
    sns.set_style("ticks")

    sns.lmplot(x=x,
            y=y,
            data=data,
            fit_reg=False,
            legend=True,
            height=9,
            hue='Label',
            scatter_kws={"s":200, "alpha":0.3})

    plt.title('t-SNE Results: Digits', weight='bold').set_fontsize('14')
    plt.xlabel(x, weight='bold').set_fontsize('10')
    plt.ylabel(y, weight='bold').set_fontsize('10')
    plt.savefig('./digits_tsne-plot2.png', dpi=120)

import pandas as pd
data = {'x': digits_final[:, 0],
        'y': digits_final[:, 1],
        'Label': Y}
data = pd.DataFrame(data)
plot2(data)

生成的图片:
在这里插入图片描述

高级例子

分割网络特征t-SNE可视化

Pending…

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
青葱年少的头像青葱年少普通用户
上一篇 2023年3月8日 下午10:46
下一篇 2023年3月8日 下午10:49

相关推荐