ライブラリを使えば簡単に最適輸送を求められる • 数値例: 二つの正規分布からの点群の比較 import numpy as np import matplotlib.pyplot as plt import ot # POT ライブラリ n = 100 # 点群サイズ mu = np.random.randn(n, 2) # 入力分布 1 (青) nu = np.random.randn(n, 2) + 1 # 入力分布 2 (オレンジ) a = np.ones(n) / n # 点重み (1/n, ..., 1/n) b = np.ones(n) / n # 点重み (1/n, ..., 1/n) C = np.linalg.norm(nu[np.newaxis] - mu[:, np.newaxis], axis=2) # コスト行列 P = ot.emd(a, b, C) # 最適輸送行列の計算(POT ライブラリを使用) plt.scatter(mu[:, 0], mu[:, 1]) # mu の散布図描写 plt.scatter(nu[:, 0], nu[:, 1]) # nu の散布図描写 for i in range(n): j = P[i].argmax() # i の対応相手: 最もたくさん輸送している先 plt.plot([mu[i, 0], nu[j, 0]], [mu[i, 1], nu[j, 1]], c='grey', zorder=-1)