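Python Mini-Exercise: The Sinkhorn-Knopp Algorithm
Author: 凯鲁嘎吉 (Coral Gajic) - cnblogs http://www.cnblogs.com/kailugaji/
This post walks through Python implementations of the Sinkhorn-Knopp algorithm. It takes two existing implementations and adapts them, as a way to properly understand how the algorithm works. For the detailed theory, see the references listed at the end of the post.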
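The formula is $P = \operatorname{diag}(u)\exp\left(-\frac{S}{\varepsilon}\right)\operatorname{diag}(v)$, where the exponential is applied element-wise. The input is the cost matrix S and the output is P; u and v are the renormalization vectors, and eps ($\varepsilon$) controls the smoothness of P (a larger eps gives a smoother, more uniform P).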
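To make the role of u and v concrete, here is a minimal sketch that tracks the two scaling vectors explicitly instead of rescaling P in place. The function name `sinkhorn_scaling` and the marginals `r` and `c` are illustrative assumptions, not part of the original script; the idea is only to check numerically that the result has the form diag(u) K diag(v) with K = exp(-S/eps).

```python
import numpy as np

def sinkhorn_scaling(S, r, c, eps=0.5, n_iter=200):
    """Hypothetical helper: return P, u, v with P = diag(u) @ K @ diag(v)."""
    K = np.exp(-S / eps)          # element-wise Gibbs kernel
    u = np.ones_like(r)
    v = np.ones_like(c)
    for _ in range(n_iter):
        u = r / (K @ v)           # rescale so that the row sums equal r
        v = c / (K.T @ u)         # rescale so that the column sums equal c
    P = u[:, None] * K * v[None, :]
    return P, u, v

rng = np.random.default_rng(0)
S = rng.random((4, 6))            # toy cost matrix (assumed data)
r = np.ones(4)                    # target row sums
c = np.ones(6) * (4 / 6)          # target column sums, same total mass as r
P, u, v = sinkhorn_scaling(S, r, c)
print(P.sum(axis=1))              # row sums, close to r
print(P.sum(axis=0))              # column sums, equal to c
```

Alternating the two updates is the same row/column normalization described in this post, just written in terms of the scaling vectors rather than the matrix itself.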
1. sinkhorn_test.py
# -*- coding: utf-8 -*-
# Author: 凯鲁嘎吉 Coral Gajic
# https://www.cnblogs.com/kailugaji/
# Sinkhorn-Knopp algorithm (using a square matrix as the example)
# For an n*n square matrix:
# 1) First normalize row by row: divide every element of the first row by the
#    sum of the first row to get the new "first row"; do the same for every row.
# 2) Then normalize column by column in the same way.
# Repeating steps 1) and 2) converges to a doubly stochastic matrix whose
# row sums and column sums are all 1.
import torch
import numpy as np
import time
# Method 1:
'''
https://github.com/miralab-ustc/rl-cbm
'''
# The numpy input is converted to a tensor
def sinkhorn(scores, eps = 5, n_iter = 3):
    def remove_infs(x): # replace INF and 0 entries in the data
        mm = x[torch.isfinite(x)].max().item() # mm is the largest finite value in x
        x[torch.isinf(x)] = mm # replace INF entries with that maximum
        x[x==0] = 1e-38 # replace zero entries with 1e-38
        return x
    # The shapes in the comments below assume a (2, 8) input
    scores = torch.tensor(scores)
    t0 = time.time()
    n, m = scores.shape # torch.Size([2, 8])
    scores1 = scores.view(n*m) # torch.Size([16])
    Q = torch.softmax(-scores1/eps, dim=0) # softmax over all entries
    Q = remove_infs(Q).view(n,m).T # torch.Size([8, 2])
    r, c = torch.ones(n), torch.ones(m) * (n / m)
    # Ensures sum(r) == sum(c)
    # Correspondingly, the row sums of P are r and the column sums of P are c
    for _ in range(n_iter):
        u = (c/torch.sum(Q, dim=1)) # torch.sum(Q, dim=1) gives the row sums of Q, torch.Size([8])
        Q *= remove_infs(u).unsqueeze(1) # torch.Size([8, 2])
        v = (r/torch.sum(Q,dim=0)) # torch.sum(Q, dim=0) gives the column sums of Q, torch.Size([2])
        Q *= remove_infs(v).unsqueeze(0) # torch.Size([8, 2])
    bsum = torch.sum(Q, dim=0, keepdim=True) # column sums of Q, torch.Size([1, 2])
    Q = Q / remove_infs(bsum)
    # bsum = torch.sum(Q, dim=1, keepdim=True)
    # Q = Q / remove_infs(bsum)
    P = Q.T # transpose back, torch.Size([2, 8])
    t1 = time.time()
    compute_time = t1 - t0
    assert not torch.isnan(P.sum())
    P = P.numpy()
    scores = scores.numpy()
    dist = np.sum(P * scores)
    return P, dist, compute_time

# Method 2:
# Sinkhorn-Knopp algorithm
'''
https://michielstock.github.io/posts/2017/2017-11-5-OptimalTransport/
https://zhuanlan.zhihu.com/p/542379144
'''
# numpy
def compute_optimal_transport(scores, eps = 5, n_iter = 3):
    """
    Computes the optimal transport matrix and Sinkhorn distance using the
    Sinkhorn-Knopp algorithm
    Inputs:
        - scores : cost matrix (n x m)
        - eps : strength of the entropic regularization
        - n_iter : number of Sinkhorn iterations
    Outputs:
        - P : optimal transport matrix (n x m)
        - dist : Sinkhorn distance
        - compute_time : elapsed time in seconds
    The marginals r (n, ) and c (m, ) are set inside the function.
    """
    t0 = time.time()
    n, m = scores.shape
    r = np.ones(n) # the row sums of P are r
    c = np.ones(m)*(n/m) # the column sums of P are c
    # Ensures np.sum(r) == np.sum(c)
    P = np.exp(- scores / eps)
    P /= P.sum()
    u = np.zeros(n)
    # normalize this matrix
    # Alternative stopping rule with a convergence threshold epsilon:
    # while np.max(np.abs(u - P.sum(1))) > epsilon:
    for _ in range(n_iter):
        u = P.sum(1)
        P *= (r / u).reshape((-1, 1)) # scale the rows so they sum to r
        P *= (c / P.sum(0)).reshape((1, -1)) # scale the columns so they sum to c
    t1 = time.time()
    compute_time = t1 - t0
    dist = np.sum(P * scores)
    return P, dist, compute_time

np.random.seed(1)
n = 5 # number of rows
m = 5 # number of columns
num = 3 # number of decimal places to keep
n_iter = 100 # number of iterations
eps = 0.5
scores = np.random.rand(n ,m) # cost matrix
print('Original data:\n', np.around(scores, num))
print('------------------------------------------------')
# Method 1:
P, dist, compute_time_1 = sinkhorn(scores, eps = eps, n_iter = n_iter)
print('1. Result:\n', np.around(P, num))
print('1. Column sums:\n', np.sum(P, axis = 0)) # axis=0 sums down each column
print('1. Row sums:\n', np.sum(P, axis = 1)) # axis=1 sums across each row
print('1. Sinkhorn distance:', np.around(dist, num))
print('1. Compute time:', np.around(compute_time_1, 8), 'seconds')
print('------------------------------------------------')
# Method 2:
P, dist, compute_time_2 = compute_optimal_transport(scores, eps = eps, n_iter = n_iter)
print('2. Result:\n', np.around(P, num))
print('2. Column sums:\n', np.sum(P, axis = 0)) # axis=0 sums down each column
print('2. Row sums:\n', np.sum(P, axis = 1)) # axis=1 sums across each row
print('2. Sinkhorn distance:', np.around(dist, num))
print('2. Compute time:', np.around(compute_time_2, 8), 'seconds')
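The script runs a fixed number of iterations, and its commented-out `while` line hints at a tolerance-based stop instead. Below is a rough sketch of that idea under my own assumptions: the function name `sinkhorn_until_converged`, the parameters `tol` and `max_iter`, and the exact stopping test (the largest row-sum error after the column step) are illustrative choices, not taken from either referenced implementation.

```python
import numpy as np

def sinkhorn_until_converged(scores, eps=0.5, tol=1e-9, max_iter=10_000):
    """Hypothetical variant of method 2 that stops once the row sums match r."""
    n, m = scores.shape
    r = np.ones(n)                 # target row sums
    c = np.ones(m) * (n / m)       # target column sums, sum(c) == sum(r)
    P = np.exp(-scores / eps)
    P /= P.sum()
    for n_done in range(1, max_iter + 1):
        P *= (r / P.sum(axis=1))[:, None]   # scale rows to sum to r
        P *= (c / P.sum(axis=0))[None, :]   # scale columns to sum to c
        # The columns now match c exactly, so the rows carry the remaining error.
        err = np.max(np.abs(P.sum(axis=1) - r))
        if err < tol:
            break
    return P, n_done, err

rng = np.random.default_rng(1)
scores = rng.random((10, 5))       # toy cost matrix (assumed data)
P, n_done, err = sinkhorn_until_converged(scores)
print('iterations used:', n_done)
print('largest row-sum error:', err)
```

With a tolerance in place, the iteration count adapts to eps: a smaller eps makes the kernel more peaked and generally needs more iterations to reach the same error.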
2. Results
D:\ProgramData\Anaconda3\python.exe "D:/Python code/2023.3 exercise/sinkhorn_test.py"
Original data:
 [[0.417 0.72  0.    0.302 0.147]
 [0.092 0.186 0.346 0.397 0.539]
 [0.419 0.685 0.204 0.878 0.027]
 [0.67  0.417 0.559 0.14  0.198]
 [0.801 0.968 0.313 0.692 0.876]]
------------------------------------------------
1. Result:
 [[0.178 0.121 0.263 0.214 0.225]
 [0.308 0.318 0.119 0.161 0.093]
 [0.212 0.155 0.209 0.081 0.342]
 [0.117 0.242 0.094 0.324 0.222]
 [0.185 0.164 0.314 0.22  0.117]]
1. Column sums:
 [1. 1. 1. 1. 1.]
1. Row sums:
 [1. 1. 1. 1. 1.]
1. Sinkhorn distance: 1.802
1. Compute time: 0.01776958 seconds
------------------------------------------------
2. Result:
 [[0.178 0.121 0.263 0.214 0.225]
 [0.308 0.318 0.119 0.161 0.093]
 [0.212 0.155 0.209 0.081 0.342]
 [0.117 0.242 0.094 0.324 0.222]
 [0.185 0.164 0.314 0.22  0.117]]
2. Column sums:
 [1. 1. 1. 1. 1.]
2. Row sums:
 [1. 1. 1. 1. 1.]
2. Sinkhorn distance: 1.802
2. Compute time: 0.00101089 seconds

Process finished with exit code 0
With n=10 and m=5, the result is:
D:\ProgramData\Anaconda3\python.exe "D:/Python code/2023.3 exercise/sinkhorn_test.py"
Original data:
 [[0.417 0.72  0.    0.302 0.147]
 [0.092 0.186 0.346 0.397 0.539]
 [0.419 0.685 0.204 0.878 0.027]
 [0.67  0.417 0.559 0.14  0.198]
 [0.801 0.968 0.313 0.692 0.876]
 [0.895 0.085 0.039 0.17  0.878]
 [0.098 0.421 0.958 0.533 0.692]
 [0.316 0.687 0.835 0.018 0.75 ]
 [0.989 0.748 0.28  0.789 0.103]
 [0.448 0.909 0.294 0.288 0.13 ]]
------------------------------------------------
1. Result:
 [[0.17  0.111 0.299 0.183 0.237]
 [0.308 0.306 0.141 0.143 0.102]
 [0.2   0.141 0.234 0.068 0.356]
 [0.118 0.234 0.112 0.29  0.246]
 [0.177 0.152 0.358 0.188 0.124]
 [0.063 0.385 0.268 0.231 0.053]
 [0.422 0.265 0.058 0.151 0.104]
 [0.268 0.153 0.072 0.415 0.091]
 [0.082 0.16  0.259 0.105 0.393]
 [0.191 0.091 0.199 0.225 0.293]]
1. Column sums:
 [2. 2. 2. 2. 2.]
1. Row sums:
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
1. Sinkhorn distance: 3.356
1. Compute time: 0.02233005 seconds
------------------------------------------------
2. Result:
 [[0.17  0.111 0.299 0.183 0.237]
 [0.308 0.306 0.141 0.143 0.102]
 [0.2   0.141 0.234 0.068 0.356]
 [0.118 0.234 0.112 0.29  0.246]
 [0.177 0.152 0.358 0.188 0.124]
 [0.063 0.385 0.268 0.231 0.053]
 [0.422 0.265 0.058 0.151 0.104]
 [0.268 0.153 0.072 0.415 0.091]
 [0.082 0.16  0.259 0.105 0.393]
 [0.191 0.091 0.199 0.225 0.293]]
2. Column sums:
 [2. 2. 2. 2. 2.]
2. Row sums:
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
2. Sinkhorn distance: 3.356
2. Compute time: 0.00100446 seconds

Process finished with exit code 0
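In the rectangular case each of the 5 columns of P sums to 2, while each of the 10 rows sums to 1. This follows directly from the marginals set in the code, r = ones(n) and c = ones(m)·(n/m), which are chosen so that both carry the same total mass. A short check of the arithmetic, using the symbols r_i and c_j for the entries of r and c:

$$\sum_{i=1}^{n} r_i = n, \qquad \sum_{j=1}^{m} c_j = m \cdot \frac{n}{m} = n, \qquad c_j = \frac{n}{m} = \frac{10}{5} = 2 \ \text{ for } n = 10,\ m = 5.$$

In the square case n = m, the same formula gives c_j = 1, which is why both sets of sums are 1 in the 5×5 run.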
3. References
[1] Cuturi M. Sinkhorn distances: Lightspeed computation of optimal transport. Advances in Neural Information Processing Systems, 2013, 26.
[2] Liu Q, Zhou Q, Yang R, et al. Robust Representation Learning by Clustering with Bisimulation Metrics for Visual Reinforcement Learning with Distractions. arXiv preprint arXiv:2302.12003, 2023.
[3] Stock M. Notes on Optimal Transport. https://michielstock.github.io/posts/2017/2017-11-5-OptimalTransport/