Train for the Worst, Plan for the Best: Understanding Token Ordering in Masked Diffusions

Best AI papers explained - A podcast by Enoch H. Kang

This academic paper examines masked diffusion models (MDMs), a promising approach to generative modeling over discrete domains, and the trade-off they strike between training complexity and inference flexibility relative to autoregressive models (ARMs). The authors show that MDM training objectives include computationally hard subproblems (predicting tokens under arbitrary masking patterns), which leads to uneven performance across those subproblems. However, adaptive inference strategies that choose the token decoding order on the fly let MDMs sidestep exactly those hard cases and substantially improve their capabilities. Notably, adaptively decoded MDMs excel at logic puzzles such as Sudoku, even surpassing ARMs that have more parameters and are explicitly trained on the decoding order.
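
To make the adaptive-inference idea concrete: one common instantiation is to unmask, at each step, the position where the model's predictive distribution is most confident, committing easy tokens first so the harder subproblems are solved with more context already filled in. Below is a minimal, illustrative sketch of that greedy rule, not the paper's implementation; the model_logits placeholder, the MASK sentinel, and the top-1 confidence criterion are all assumptions made for the sake of the example.

```python
# Minimal sketch of confidence-based adaptive decoding for a masked
# diffusion model. `model_logits` is a hypothetical stand-in for a real
# trained MDM forward pass; everything here is illustrative.
import numpy as np

MASK = -1  # hypothetical sentinel id for a masked position


def model_logits(tokens: np.ndarray, vocab_size: int) -> np.ndarray:
    """Placeholder for an MDM forward pass: returns per-position logits
    of shape (seq_len, vocab_size). Replace with a trained model."""
    rng = np.random.default_rng(abs(hash(tokens.tobytes())) % (2**32))
    return rng.standard_normal((tokens.shape[0], vocab_size))


def adaptive_decode(tokens: np.ndarray, vocab_size: int) -> np.ndarray:
    """Unmask one position per step, always choosing the position where
    the model is most confident (highest top-1 probability)."""
    tokens = tokens.copy()
    while (tokens == MASK).any():
        logits = model_logits(tokens, vocab_size)
        # Numerically stable softmax over the vocabulary at each position.
        probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
        probs /= probs.sum(axis=-1, keepdims=True)
        conf = probs.max(axis=-1)          # top-1 confidence per position
        conf[tokens != MASK] = -np.inf     # only consider masked slots
        pos = int(conf.argmax())           # decode the easiest slot first
        tokens[pos] = int(probs[pos].argmax())
    return tokens


seq = np.full(9, MASK)  # e.g., one fully masked Sudoku row
print(adaptive_decode(seq, vocab_size=10))
```

A vanilla MDM sampler would instead pick `pos` at random (or left to right), which is exactly what forces it to confront the hard subproblems the paper identifies; the one-line change to a confidence-ranked order is what "planning for the best" amounts to here.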