Hiroshi Takahashi
June 08, 2024

Learning Optimal Priors for Task-Invariant Representations in Variational Autoencoders

- KDD 2022
- Paper is available at https://dl.acm.org/doi/10.1145/3534678.3539291

Transcript

  1. KDD 2022 Research Track: Learning Optimal Priors for Task-Invariant Representations in Variational Autoencoders

    Hiroshi Takahashi¹, Tomoharu Iwata¹, Atsutoshi Kumagai¹, Sekitoshi Kanai¹, Masanori Yamada¹, Yuuki Yamanaka¹, Hisashi Kashima² (¹NTT, ²Kyoto University)
  2. [Introduction] Variational Autoencoder

    • The variational autoencoder (VAE) is a powerful latent variable model for unsupervised representation learning.
    [Figure: data 𝐱 → encoder 𝜙 → latent variable 𝐳 → decoder 𝜃 → data 𝐱, with a standard Gaussian prior 𝑝(𝐳); the latent variable feeds downstream applications (such as classification, data generation, and out-of-distribution detection).]
    Copyright 2022 NTT CORPORATION
  3. [Introduction] Multi-Task Learning

    • However, the VAE cannot perform well when data points are insufficient, since it depends on neural networks.
    • To solve this, we focus on obtaining a task-invariant latent variable from multiple tasks.
    [Figure: multiple tasks, some with a lot of data points and some with insufficient data points → encoder 𝜙 → task-invariant latent variable 𝐳, which is useful for tasks with insufficient data points.]
  4. [Introduction] Conditional VAE

    • For multiple tasks, the conditional VAE (CVAE), which tries to obtain a task-invariant latent variable, is widely used.
    [Figure: data 𝐱 and task index 𝑠 → encoder 𝜙 → task-invariant latent variable 𝐳 → decoder 𝜃 (also given the task index 𝑠) → data 𝐱, with a standard Gaussian prior 𝑝(𝐳).]
  5. [Introduction] Problem and Contribution

    • Although the CVAE can reduce the dependency of 𝐳 on 𝑠 to some extent, this dependency remains in many cases.
    • The contributions of this study are as follows:
      1. We investigate the cause of the task-dependency in the CVAE and reveal that the simple prior is one of the causes.
      2. We introduce the optimal prior to reduce the task-dependency.
      3. We theoretically and experimentally show that our learned representation works well on multiple tasks.
  6. [Preliminaries] Reviewing CVAE

    • The CVAE models a conditional probability of 𝐱 given 𝑠 as:
      $p_\theta(\mathbf{x} \mid s) = \int p_\theta(\mathbf{x} \mid \mathbf{z}, s)\, p(\mathbf{z})\, d\mathbf{z} = \mathbb{E}_{q_\phi(\mathbf{z} \mid \mathbf{x}, s)}\left[ \frac{p_\theta(\mathbf{x} \mid \mathbf{z}, s)\, p(\mathbf{z})}{q_\phi(\mathbf{z} \mid \mathbf{x}, s)} \right]$
      where $p_\theta(\mathbf{x} \mid \mathbf{z}, s)$ is the decoder, $p(\mathbf{z})$ is the prior, and $q_\phi(\mathbf{z} \mid \mathbf{x}, s)$ is the encoder.
    • The CVAE is trained by maximizing the ELBO, which is a lower bound of the log-likelihood:
      $\mathcal{F}_{\mathrm{CVAE}}(\theta, \phi) = \mathbb{E}_{p_D(\mathbf{x}, s)\, q_\phi(\mathbf{z} \mid \mathbf{x}, s)}\left[ \ln p_\theta(\mathbf{x} \mid \mathbf{z}, s) \right] - \mathbb{E}_{p_D(\mathbf{x}, s)}\left[ D_{KL}(q_\phi(\mathbf{z} \mid \mathbf{x}, s) \,\|\, p(\mathbf{z})) \right]$
      where $p_D(\mathbf{x}, s)$ is the data distribution and the second term, $\mathbb{E}_{p_D(\mathbf{x}, s)}\left[ D_{KL}(q_\phi(\mathbf{z} \mid \mathbf{x}, s) \,\|\, p(\mathbf{z})) \right]$, is denoted ℛ(𝜙).
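
The ELBO above can be illustrated with a minimal sketch. It assumes a diagonal Gaussian encoder and decoder and a standard Gaussian prior, so the KL term has a closed form; `toy_encoder` and `toy_decoder` are hypothetical hand-written linear maps standing in for the neural networks, not the authors' models.

```python
import math
import random

def gaussian_kl_to_standard_normal(mu, log_var):
    """Closed-form KL( N(mu, diag(exp(log_var))) || N(0, I) )."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mu, log_var))

def gaussian_log_density(x, mean, log_var):
    """Log density of a diagonal Gaussian evaluated at x."""
    return sum(
        -0.5 * (math.log(2.0 * math.pi) + lv + (xi - mi) ** 2 / math.exp(lv))
        for xi, mi, lv in zip(x, mean, log_var)
    )

def toy_encoder(x, s):
    """Hypothetical q_phi(z | x, s): a fixed linear map instead of a network."""
    return [0.5 * x[0] + 0.1 * s], [-1.0]  # mean, log-variance

def toy_decoder(z, s):
    """Hypothetical p_theta(x | z, s): a fixed linear map instead of a network."""
    return [2.0 * z[0] - 0.2 * s], [0.0]  # mean, log-variance

def elbo_single_sample(x, s, rng):
    """One-sample Monte Carlo estimate of the CVAE ELBO for one pair (x, s)."""
    mu, log_var = toy_encoder(x, s)
    # Reparameterization: z = mu + sigma * eps with eps ~ N(0, I).
    z = [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]
    dec_mean, dec_log_var = toy_decoder(z, s)
    reconstruction = gaussian_log_density(x, dec_mean, dec_log_var)
    kl = gaussian_kl_to_standard_normal(mu, log_var)  # the per-example term inside R(phi)
    return reconstruction - kl, kl

rng = random.Random(0)
elbo_value, kl_value = elbo_single_sample([1.2], 1, rng)
print("ELBO estimate:", elbo_value, "KL term:", kl_value)
```

Averaging the KL term over the data distribution gives ℛ(𝜙), the quantity analyzed in the following slides.
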
  7. [Preliminaries] Mutual Information

    • To investigate the cause of the dependency of 𝐳 on 𝑠, we introduce the mutual information 𝐼(𝑆; 𝑍), which measures the mutual dependence between two random variables.
    • 𝐼(𝑆; 𝑍) becomes large if 𝐳 depends on 𝑠.
    • 𝐼(𝑆; 𝑍) becomes small if 𝐳 does NOT depend on 𝑠.
    [Figure: two diagrams of the entropies 𝐻(𝑆) and 𝐻(𝑍), one for each case.]
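
As a small illustration of this behavior, the sketch below computes 𝐼(𝑆; 𝑍) for a discrete joint distribution. The slides work with continuous latent variables; the discrete 2×2 tables here are an assumption made only to keep the example exact.

```python
import math

def mutual_information(joint):
    """I(S; Z) in nats, where joint[s][z] = p(s, z)."""
    p_s = [sum(row) for row in joint]
    p_z = [sum(col) for col in zip(*joint)]
    mi = 0.0
    for i, row in enumerate(joint):
        for j, p in enumerate(row):
            if p > 0.0:  # terms with p(s, z) = 0 contribute nothing
                mi += p * math.log(p / (p_s[i] * p_z[j]))
    return mi

# z is independent of s: the joint factorizes into its marginals.
independent = [[0.25, 0.25], [0.25, 0.25]]
# z is fully determined by s.
dependent = [[0.5, 0.0], [0.0, 0.5]]

mi_independent = mutual_information(independent)
mi_dependent = mutual_information(dependent)
print(mi_independent, mi_dependent)
```

In the independent case the mutual information is zero, and in the fully dependent case it equals ln 2, the entropy of 𝑆.
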
  8. [Proposed] Theorem 1

    • The CVAE tries to minimize the mutual information 𝐼(𝑆; 𝑍) by minimizing its upper bound ℛ(𝜙):
      $\mathcal{R}(\phi) \equiv \mathbb{E}_{p_D(\mathbf{x}, s)}\left[ D_{KL}(q_\phi(\mathbf{z} \mid \mathbf{x}, s) \,\|\, p(\mathbf{z})) \right] = I(S; Z) + D_{KL}(q_\phi(\mathbf{z}) \,\|\, p(\mathbf{z})) + \sum_{k=1}^{K} \pi_k I(X^{(k)}; Z^{(k)})$
      where $\pi_k = p(s = k)$, $q_\phi(\mathbf{z}) = \int q_\phi(\mathbf{z} \mid \mathbf{x}, s)\, p_D(\mathbf{x}, s)\, d\mathbf{x}$, and $I(X^{(k)}; Z^{(k)})$ is the mutual information between 𝐱 and 𝐳 when 𝑠 = 𝑘.
    • However, ℛ(𝜙) is NOT a tight upper bound of 𝐼(𝑆; 𝑍), since $D_{KL}(q_\phi(\mathbf{z}) \,\|\, p(\mathbf{z}))$ usually takes a large value.
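
The decomposition in Theorem 1 can be checked numerically on a toy problem. The sketch below assumes discrete 𝐱, 𝑠, and 𝐳 with randomly generated probability tables (all sizes and tables are hypothetical), and takes the aggregated posterior to be the average of the encoder over all data pairs (𝐱, 𝑠).

```python
import math
import random

def kl(q, p):
    """KL(q || p) for two discrete distributions given as lists."""
    return sum(qi * math.log(qi / pi) for qi, pi in zip(q, p) if qi > 0.0)

def normalize(weights):
    total = sum(weights)
    return [w / total for w in weights]

def decomposition_terms(seed, num_tasks=2, num_x=3, num_z=4):
    """Return (R, I(S;Z), KL(q(z)||p(z)), sum_k pi_k I(X^(k);Z^(k)))."""
    rng = random.Random(seed)
    flat = normalize([rng.random() + 0.1 for _ in range(num_tasks * num_x)])
    # p_xs[k][x] = p_D(x, s=k); enc[k][x] = q(z | x, s=k); prior = p(z).
    p_xs = [flat[k * num_x:(k + 1) * num_x] for k in range(num_tasks)]
    enc = [[normalize([rng.random() + 0.1 for _ in range(num_z)])
            for _ in range(num_x)] for _ in range(num_tasks)]
    prior = normalize([rng.random() + 0.1 for _ in range(num_z)])

    pi = [sum(row) for row in p_xs]  # pi_k = p(s = k)
    # q(z | s=k): encoder averaged over p(x | s=k).
    q_z_s = [[sum(p_xs[k][x] / pi[k] * enc[k][x][z] for x in range(num_x))
              for z in range(num_z)] for k in range(num_tasks)]
    # q(z): aggregated posterior over all (x, s).
    q_z = [sum(pi[k] * q_z_s[k][z] for k in range(num_tasks)) for z in range(num_z)]

    r = sum(p_xs[k][x] * kl(enc[k][x], prior)
            for k in range(num_tasks) for x in range(num_x))
    i_sz = sum(pi[k] * kl(q_z_s[k], q_z) for k in range(num_tasks))
    kl_prior = kl(q_z, prior)
    i_xz = sum(p_xs[k][x] * kl(enc[k][x], q_z_s[k])
               for k in range(num_tasks) for x in range(num_x))
    return r, i_sz, kl_prior, i_xz

r, i_sz, kl_prior, i_xz = decomposition_terms(seed=0)
print("R(phi):", r)
print("I(S;Z) + KL + sum_k pi_k I(X;Z):", i_sz + kl_prior + i_xz)
```

Both printed values agree up to floating-point error, and ℛ(𝜙) is never smaller than 𝐼(𝑆; 𝑍) because the other two terms are non-negative.
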
  9. [Proposed] Effects of Priors

    • ℛ(𝜙) is NOT a tight upper bound of 𝐼(𝑆; 𝑍), since $D_{KL}(q_\phi(\mathbf{z}) \,\|\, p(\mathbf{z}))$ usually takes a large value.
    • When $p(\mathbf{z}) = q_\phi(\mathbf{z})$, ℛ(𝜙) becomes the tightest upper bound of 𝐼(𝑆; 𝑍).
    • That is, the simple prior 𝑝(𝐳) is one of the causes of the task-dependency, and $q_\phi(\mathbf{z})$ is the optimal prior to reduce it.
    [Figure: diagram of ℛ(𝜙) and its components, with one part labeled "Proposed Method".]
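
As a worked step, substituting $p(\mathbf{z}) = q_\phi(\mathbf{z})$ into the decomposition of Theorem 1 makes the KL term vanish. This only restates the slide's claim under the definitions of Theorem 1; it is not an additional result.

```latex
\begin{align*}
\mathcal{R}(\phi)\Big|_{p(\mathbf{z}) = q_\phi(\mathbf{z})}
  &= I(S; Z)
     + \underbrace{D_{KL}\big(q_\phi(\mathbf{z}) \,\|\, q_\phi(\mathbf{z})\big)}_{=\,0}
     + \sum_{k=1}^{K} \pi_k\, I\big(X^{(k)}; Z^{(k)}\big) \\
  &= I(S; Z) + \sum_{k=1}^{K} \pi_k\, I\big(X^{(k)}; Z^{(k)}\big).
\end{align*}
```

Since the KL divergence is non-negative, no other choice of prior gives a smaller gap between ℛ(𝜙) and 𝐼(𝑆; 𝑍).
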
  10. [Proposed] Theorem 2

    • The ELBO with this optimal prior, $\mathcal{F}_{\mathrm{Proposed}}(\theta, \phi)$, is always larger than or equal to the original ELBO $\mathcal{F}_{\mathrm{CVAE}}(\theta, \phi)$:
      $\mathcal{F}_{\mathrm{Proposed}}(\theta, \phi) = \mathcal{F}_{\mathrm{CVAE}}(\theta, \phi) + D_{KL}(q_\phi(\mathbf{z}) \,\|\, p(\mathbf{z})) \geq \mathcal{F}_{\mathrm{CVAE}}(\theta, \phi)$
    • That is, $\mathcal{F}_{\mathrm{Proposed}}(\theta, \phi)$ is also a better lower bound of the log-likelihood than $\mathcal{F}_{\mathrm{CVAE}}(\theta, \phi)$.
    • This contributes to obtaining a better representation and improved performance on the target tasks.
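
The equality in Theorem 2 can be derived in a few lines. The sketch below assumes that $\mathcal{F}_{\mathrm{Proposed}}$ is $\mathcal{F}_{\mathrm{CVAE}}$ with the prior $p(\mathbf{z})$ replaced by $q_\phi(\mathbf{z})$, which is how the slide describes it ("the ELBO with this optimal prior"); the reconstruction term is unchanged, so only the KL terms need comparing.

```latex
\begin{align*}
&\mathbb{E}_{p_D(\mathbf{x}, s)}\big[ D_{KL}(q_\phi(\mathbf{z} \mid \mathbf{x}, s) \,\|\, p(\mathbf{z})) \big]
 - \mathbb{E}_{p_D(\mathbf{x}, s)}\big[ D_{KL}(q_\phi(\mathbf{z} \mid \mathbf{x}, s) \,\|\, q_\phi(\mathbf{z})) \big] \\
&\quad = \mathbb{E}_{p_D(\mathbf{x}, s)\, q_\phi(\mathbf{z} \mid \mathbf{x}, s)}
   \left[ \ln \frac{q_\phi(\mathbf{z})}{p(\mathbf{z})} \right]
 = \int q_\phi(\mathbf{z}) \ln \frac{q_\phi(\mathbf{z})}{p(\mathbf{z})}\, d\mathbf{z}
 = D_{KL}(q_\phi(\mathbf{z}) \,\|\, p(\mathbf{z})) \;\geq\; 0, \\[4pt]
&\text{hence}\quad
 \mathcal{F}_{\mathrm{Proposed}}(\theta, \phi)
 = \mathcal{F}_{\mathrm{CVAE}}(\theta, \phi) + D_{KL}(q_\phi(\mathbf{z}) \,\|\, p(\mathbf{z})).
\end{align*}
```

The middle step uses the fact that averaging the encoder over the data distribution gives the aggregated posterior $q_\phi(\mathbf{z})$.
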
  11. [Proposed] Optimizing $\mathcal{F}_{\mathrm{Proposed}}(\theta, \phi)$

    • We optimize $\mathcal{F}_{\mathrm{Proposed}}(\theta, \phi) = \mathcal{F}_{\mathrm{CVAE}}(\theta, \phi) + D_{KL}(q_\phi(\mathbf{z}) \,\|\, p(\mathbf{z}))$ by approximating the KL divergence $D_{KL}(q_\phi(\mathbf{z}) \,\|\, p(\mathbf{z}))$:
      $D_{KL}(q_\phi(\mathbf{z}) \,\|\, p(\mathbf{z})) = \int q_\phi(\mathbf{z}) \ln \frac{q_\phi(\mathbf{z})}{p(\mathbf{z})}\, d\mathbf{z}$
    • We approximate $q_\phi(\mathbf{z}) / p(\mathbf{z})$ with the density ratio trick, which can estimate the density ratio between two distributions using samples from both distributions (see Section 3.3).
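
The density ratio trick can be sketched as follows: a probabilistic classifier trained to separate samples of the two distributions has a logit that approximates the log density ratio when both sample sets have the same size. The example below is an assumption-laden toy, not the authors' implementation: it uses 1-D Gaussians with $q = \mathcal{N}(1, 1)$ standing in for $q_\phi(\mathbf{z})$ and $p = \mathcal{N}(0, 1)$ for the prior, and a plain logistic regression (hypothetical helper `fit_log_ratio`) in place of a neural network discriminator.

```python
import math
import random

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def fit_log_ratio(q_samples, p_samples, steps=300, lr=0.5):
    """Fit logit(z) = w * z + b to separate q samples (label 1) from p samples (label 0).

    With equally many samples from each side, the fitted logit approximates
    ln q(z) / p(z).
    """
    data = [(z, 1.0) for z in q_samples] + [(z, 0.0) for z in p_samples]
    n = len(data)
    w, b = 0.0, 0.0
    for _ in range(steps):
        grad_w = grad_b = 0.0
        for z, label in data:
            err = sigmoid(w * z + b) - label  # gradient of the log loss w.r.t. the logit
            grad_w += err * z
            grad_b += err
        w -= lr * grad_w / n
        b -= lr * grad_b / n
    return w, b

def estimate_kl(n=2000, seed=0):
    """Estimate KL(q || p) as the average fitted log ratio over samples from q."""
    rng = random.Random(seed)
    q_samples = [rng.gauss(1.0, 1.0) for _ in range(n)]
    p_samples = [rng.gauss(0.0, 1.0) for _ in range(n)]
    w, b = fit_log_ratio(q_samples, p_samples)
    return sum(w * z + b for z in q_samples) / n

kl_estimate = estimate_kl()
print("estimated KL:", kl_estimate, "(closed-form value for these Gaussians: 0.5)")
```

A linear logit suffices here because the true log ratio of these two Gaussians is linear in z; for an aggregated posterior of a trained encoder, a more flexible classifier would be needed.
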
  12. [Proposed] Theoretical Contributions

    • Our theoretical contributions are summarized as follows:
      Theorem 1 shows:
      • The simple prior is one of the causes of the task-dependency.
      • $q_\phi(\mathbf{z})$ is the optimal prior to reduce the task-dependency.
      Theorem 2 shows:
      • $\mathcal{F}_{\mathrm{Proposed}}(\theta, \phi)$ gives a better lower bound of the log-likelihood, which enables us to obtain a better representation than the CVAE.
    • We next evaluate our representation on various datasets.
  13. [Experiments] Datasets

    • We used two handwritten digit datasets (USPS and MNIST), two house number digit datasets (SynthDigits and SVHN), and three face datasets (Frey, Olivetti, and UMist).
  14. [Experiments] Settings

    • On the digit datasets, we conducted two-task experiments, which evaluate the performance on the target task:
      • The source task has a lot of training data points.
      • The target task has only 100 training data points.
      • Pairs are (USPS→MNIST), (MNIST→USPS), (SynthDigits→SVHN), and (SVHN→SynthDigits).
    • On the face datasets, we conducted a three-task experiment, which simultaneously evaluates the performance on each task using a single estimator.
  15. [Results] Density Estimation Performance

    Almost equal to or better performance than other approaches.

    | Task | VAE | CVAE | Proposed |
    |---|---|---|---|
    | USPS→MNIST | −163.25 ± 2.15 | −152.32 ± 1.64 | [illegible in transcript] |
    | MNIST→USPS | −235.23 ± 1.54 | [illegible in transcript] | [illegible in transcript] |
    | Synth→SVHN | 1146.04 ± 35.65 | 1397.36 ± 10.89 | [illegible in transcript] |
    | SVHN→Synth | 760.66 ± 8.85 | 814.63 ± 10.09 | [illegible in transcript] |
    | Face Datasets | 895.41 ± 2.98 | 902.99 ± 3.69 | [illegible in transcript] |
  16. [Results] Downstream Classification

    Almost equal to or better performance than other approaches.

    | Task | VAE | CVAE | Proposed |
    |---|---|---|---|
    | USPS→MNIST | 0.52 ± 2.15 | 0.53 ± 0.02 | [illegible in transcript] |
    | MNIST→USPS | 0.64 ± 0.01 | 0.67 ± 0.01 | [illegible in transcript] |
    | Synth→SVHN | 0.20 ± 0.00 | [illegible in transcript] | 0.19 ± 0.00 |
    | SVHN→Synth | 0.25 ± 0.01 | 0.25 ± 0.00 | [illegible in transcript] |
  17. Conclusion

    • Our contributions to the CVAE are summarized as follows:
      Theorem 1 shows:
      • The simple prior is one of the causes of the task-dependency.
      • We propose the optimal prior to reduce the task-dependency.
      Theorem 2 shows:
      • Our approach gives a better lower bound of the log-likelihood, which enables us to obtain a better representation than the CVAE.
      Experiments show:
      • Our approach achieves better performance on various datasets.