Python小练习:Sinkhorn-Knopp算法

发布时间 2023-03-24 15:29:13作者: 凯鲁嘎吉

Python小练习:Sinkhorn-Knopp算法

作者:凯鲁嘎吉 - 博客园 http://www.cnblogs.com/kailugaji/

本文介绍Sinkhorn-Knopp算法的Python实现,通过参考并修改两种不同的实现方法,来真正弄懂算法原理。详细的原理部分可参考文末给出的参考文献。
公式为:$P = diag(u)\exp \left( {\frac{{ - S}}{\varepsilon }} \right)diag(v)$。输入S,输出P,其中u与v是renormalization向量,eps用来控制P的平滑性。

1. sinkhorn_test.py

  1 # -*- coding: utf-8 -*-
  2 # Author:凯鲁嘎吉 Coral Gajic
  3 # https://www.cnblogs.com/kailugaji/
  4 # Sinkhorn-Knopp算法(以方阵为例)
  5 # 对于一个n*n方阵
  6 # 1) 先逐行做归一化:将第一行的每个元素除以第一行所有元素之和,得到新的"第一行",每行都做相同的操作
  7 # 2) 再逐列做归一化,操作同上
  8 # 重复以上的两步1)与2),最终可以收敛到一个行和为1,列和也为1的双随机矩阵。
  9 import torch
 10 import numpy as np
 11 import time
 12 # 方法1:
 13 '''
 14     https://github.com/miralab-ustc/rl-cbm
 15 '''
 16 # numpy转换成tensor
 17 def sinkhorn(scores, eps = 5, n_iter = 3):
 18     def remove_infs(x): # 替换掉数据里面的INF与0
 19         mm = x[torch.isfinite(x)].max().item() # m是x的最大值
 20         x[torch.isinf(x)] = mm # 用最大值替换掉数据里面的INF
 21         x[x==0] = 1e-38 # 将数据里面的0元素替换为1e-38
 22         return x
 23     # 若以(2, 8)为例
 24     scores = torch.tensor(scores)
 25     t0 = time.time()
 26     n, m = scores.shape # torch.Size([2, 8])
 27     scores1 = scores.view(n*m) # torch.Size([16])
 28     Q = torch.softmax(-scores1/eps, dim=0) # softmax
 29     Q = remove_infs(Q).view(n,m).T # torch.Size([8, 2])
 30     r, c = torch.ones(n), torch.ones(m) * (n / m)
 31     # 确保sum(r)=sum(c)
 32     # 对应地P的行和为r,列和为c
 33     for _ in range(n_iter):
 34         u = (c/torch.sum(Q, dim=1)) # torch.sum(Q, dim=1)按列求和,得到1行8列的数torch.Size([8])
 35         Q *= remove_infs(u).unsqueeze(1) #  torch.Size([8, 2])
 36         v = (r/torch.sum(Q,dim=0)) # torch.sum(Q,dim=0)按行求和,得到torch.Size([2])
 37         Q *= remove_infs(v).unsqueeze(0) # torch.Size([8, 2])
 38     bsum = torch.sum(Q, dim=0, keepdim=True) # 按行求和,torch.Size([1, 2])
 39     Q = Q / remove_infs(bsum)
 40     # bsum = torch.sum(Q, dim=1, keepdim=True)
 41     # Q = Q / remove_infs(bsum)
 42     P = Q.T # 转置,torch.Size([2, 8])
 43     t1 = time.time()
 44     compute_time = t1 - t0
 45     assert torch.isnan(P.sum())==False
 46     P = np.array(P)
 47     scores = np.array(scores)
 48     dist = np.sum(P * scores)
 49     return P, dist, compute_time
 50 
 51 # 方法2:
 52 # Sinkhorn-Knopp算法
 53 '''
 54     https://michielstock.github.io/posts/2017/2017-11-5-OptimalTransport/
 55     https://zhuanlan.zhihu.com/p/542379144
 56 '''
 57 # numpy
 58 def compute_optimal_transport(scores, eps = 5, n_iter = 3):
 59     """
 60     Computes the optimal transport matrix and Sinkhorn distance using the
 61     Sinkhorn-Knopp algorithm
 62     Inputs:
 63         - scores : cost matrix (n * m)
 64         - r : vector of marginals (n, )
 65         - c : vector of marginals (m, )
 66         - eps : strength of the entropic regularization
 67         - epsilon : convergence parameter
 68     Outputs:
 69         - P : optimal transport matrix (n x m)
 70         - dist : Sinkhorn distance
 71     """
 72     t0 = time.time()
 73     n, m = scores.shape
 74     r = np.ones(n)  # P矩阵列和为r
 75     c = np.ones(m)*(n/m)  # P矩阵行和为c
 76     # 确保:np.sum(r)==np.sum(c)
 77     P = np.exp(- scores / eps)
 78     P /= P.sum()
 79     u = np.zeros(n)
 80     # normalize this matrix
 81     # while np.max(np.abs(u - P.sum(1))) > epsilon:
 82     for _ in range(n_iter):
 83         u = P.sum(1)
 84         P *= (r / u).reshape((-1, 1)) # 行归r化
 85         P *= (c / P.sum(0)).reshape((1, -1)) # 列归c化
 86     t1 = time.time()
 87     compute_time = t1 - t0
 88     dist = np.sum(P * scores)
 89     return P, dist, compute_time
 90 
 91 np.random.seed(1)
 92 n = 5 # 行数
 93 m = 5 # 列数
 94 num = 3 # 保留小数位数
 95 n_iter = 100 # 迭代次数
 96 eps = 0.5
 97 scores = np.random.rand(n ,m) # cost matrix
 98 print('原始数据:\n', np.around(scores, num))
 99 print('------------------------------------------------')
100 # 方法1:
101 P, dist, compute_time_1 = sinkhorn(scores, eps = eps, n_iter = n_iter)
102 print('1. 处理后的结果:\n', np.around(P, num))
103 print('1. 行和:\n', np.sum(P, axis = 0))
104 print('1. 列和:\n', np.sum(P, axis = 1))
105 print('1. Sinkhorn距离:', np.around(dist, num))
106 print('1. 计算时间:', np.around(compute_time_1, 8), '')
107 print('------------------------------------------------')
108 # 方法2:
109 P, dist, compute_time_2 = compute_optimal_transport(scores, eps = eps, n_iter = n_iter)
110 print('2. 处理后的结果:\n', np.around(P, num))
111 print('2. 行和:\n', np.sum(P, axis = 0))
112 print('2. 列和:\n', np.sum(P, axis = 1))
113 print('2. Sinkhorn距离:', np.around(dist, num))
114 print('2. 计算时间:', np.around(compute_time_2, 8), '')

2. 结果

D:\ProgramData\Anaconda3\python.exe "D:/Python code/2023.3 exercise/sinkhorn_test.py"
原始数据:
 [[0.417 0.72  0.    0.302 0.147]
 [0.092 0.186 0.346 0.397 0.539]
 [0.419 0.685 0.204 0.878 0.027]
 [0.67  0.417 0.559 0.14  0.198]
 [0.801 0.968 0.313 0.692 0.876]]
------------------------------------------------
1. 处理后的结果:
 [[0.178 0.121 0.263 0.214 0.225]
 [0.308 0.318 0.119 0.161 0.093]
 [0.212 0.155 0.209 0.081 0.342]
 [0.117 0.242 0.094 0.324 0.222]
 [0.185 0.164 0.314 0.22  0.117]]
1. 行和:
 [1. 1. 1. 1. 1.]
1. 列和:
 [1. 1. 1. 1. 1.]
1. Sinkhorn距离: 1.802
1. 计算时间: 0.01776958------------------------------------------------
2. 处理后的结果:
 [[0.178 0.121 0.263 0.214 0.225]
 [0.308 0.318 0.119 0.161 0.093]
 [0.212 0.155 0.209 0.081 0.342]
 [0.117 0.242 0.094 0.324 0.222]
 [0.185 0.164 0.314 0.22  0.117]]
2. 行和:
 [1. 1. 1. 1. 1.]
2. 列和:
 [1. 1. 1. 1. 1.]
2. Sinkhorn距离: 1.802
2. 计算时间: 0.00101089 秒

Process finished with exit code 0

当 n=10, m=5 时,结果为

D:\ProgramData\Anaconda3\python.exe "D:/Python code/2023.3 exercise/sinkhorn_test.py"
原始数据:
 [[0.417 0.72  0.    0.302 0.147]
 [0.092 0.186 0.346 0.397 0.539]
 [0.419 0.685 0.204 0.878 0.027]
 [0.67  0.417 0.559 0.14  0.198]
 [0.801 0.968 0.313 0.692 0.876]
 [0.895 0.085 0.039 0.17  0.878]
 [0.098 0.421 0.958 0.533 0.692]
 [0.316 0.687 0.835 0.018 0.75 ]
 [0.989 0.748 0.28  0.789 0.103]
 [0.448 0.909 0.294 0.288 0.13 ]]
------------------------------------------------
1. 处理后的结果:
 [[0.17  0.111 0.299 0.183 0.237]
 [0.308 0.306 0.141 0.143 0.102]
 [0.2   0.141 0.234 0.068 0.356]
 [0.118 0.234 0.112 0.29  0.246]
 [0.177 0.152 0.358 0.188 0.124]
 [0.063 0.385 0.268 0.231 0.053]
 [0.422 0.265 0.058 0.151 0.104]
 [0.268 0.153 0.072 0.415 0.091]
 [0.082 0.16  0.259 0.105 0.393]
 [0.191 0.091 0.199 0.225 0.293]]
1. 行和:
 [2. 2. 2. 2. 2.]
1. 列和:
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
1. Sinkhorn距离: 3.356
1. 计算时间: 0.02233005------------------------------------------------
2. 处理后的结果:
 [[0.17  0.111 0.299 0.183 0.237]
 [0.308 0.306 0.141 0.143 0.102]
 [0.2   0.141 0.234 0.068 0.356]
 [0.118 0.234 0.112 0.29  0.246]
 [0.177 0.152 0.358 0.188 0.124]
 [0.063 0.385 0.268 0.231 0.053]
 [0.422 0.265 0.058 0.151 0.104]
 [0.268 0.153 0.072 0.415 0.091]
 [0.082 0.16  0.259 0.105 0.393]
 [0.191 0.091 0.199 0.225 0.293]]
2. 行和:
 [2. 2. 2. 2. 2.]
2. 列和:
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
2. Sinkhorn距离: 3.356
2. 计算时间: 0.00100446 秒

Process finished with exit code 0

3. 参考文献

[1] Cuturi M. Sinkhorn distances: Lightspeed computation of optimal transport[J]. Advances in neural information processing systems, 2013, 26.
[2] Liu Q, Zhou Q, Yang R, et al. Robust Representation Learning by Clustering with Bisimulation Metrics for Visual Reinforcement Learning with Distractions[J]. arXiv preprint arXiv:2302.12003, 2023.
[3] Michiel Stock, Notes on Optimal Transport, https://michielstock.github.io/posts/2017/2017-11-5-OptimalTransport/