radeon_ring_write(ring, 0);
        radeon_ring_write(ring, 1 << vm_id);
 
+       /* wait for the invalidate to complete */
+       radeon_ring_write(ring, PACKET3(PACKET3_WAIT_REG_MEM, 5));
+       radeon_ring_write(ring, (WAIT_REG_MEM_FUNCTION(0) |  /* always */
+                                WAIT_REG_MEM_ENGINE(0))); /* me */
+       radeon_ring_write(ring, VM_INVALIDATE_REQUEST >> 2);
+       radeon_ring_write(ring, 0);
+       radeon_ring_write(ring, 0); /* ref */
+       radeon_ring_write(ring, 0); /* mask */
+       radeon_ring_write(ring, 0x20); /* poll interval */
+
        /* sync PFP to ME, otherwise we might get invalid PFP reads */
        radeon_ring_write(ring, PACKET3(PACKET3_PFP_SYNC_ME, 0));
        radeon_ring_write(ring, 0x0);
 
        radeon_ring_write(ring, DMA_PACKET(DMA_PACKET_SRBM_WRITE, 0, 0, 0, 0));
        radeon_ring_write(ring, (0xf << 16) | (VM_INVALIDATE_REQUEST >> 2));
        radeon_ring_write(ring, 1 << vm_id);
+
+       /* wait for invalidate to complete */
+       radeon_ring_write(ring, DMA_PACKET(DMA_PACKET_POLL_REG_MEM, 0, 0, 0, 0));
+       radeon_ring_write(ring, VM_INVALIDATE_REQUEST);
+       radeon_ring_write(ring, 0xff << 16); /* retry */
+       radeon_ring_write(ring, 1 << vm_id); /* mask */
+       radeon_ring_write(ring, 0); /* value */
+       radeon_ring_write(ring, (0 << 28) | 0x20); /* func(always) | poll interval */
 }
 
 /**
 
 #define        PACKET3_MPEG_INDEX                              0x3A
 #define        PACKET3_COPY_DW                                 0x3B
 #define        PACKET3_WAIT_REG_MEM                            0x3C
+#define                WAIT_REG_MEM_FUNCTION(x)                ((x) << 0)
+                /* 0 - always
+                * 1 - <
+                * 2 - <=
+                * 3 - ==
+                * 4 - !=
+                * 5 - >=
+                * 6 - >
+                */
+#define                WAIT_REG_MEM_MEM_SPACE(x)               ((x) << 4)
+                /* 0 - reg
+                * 1 - mem
+                */
+#define                WAIT_REG_MEM_ENGINE(x)                  ((x) << 8)
+                /* 0 - me
+                * 1 - pfp
+                */
 #define        PACKET3_MEM_WRITE                               0x3D
 #define        PACKET3_COPY_DATA                               0x40
 #define        PACKET3_CP_DMA                                  0x41
 #define        DMA_PACKET_TRAP                                   0x7
 #define        DMA_PACKET_SRBM_WRITE                             0x9
 #define        DMA_PACKET_CONSTANT_FILL                          0xd
+#define        DMA_PACKET_POLL_REG_MEM                           0xe
 #define        DMA_PACKET_NOP                                    0xf
 
 #define VCE_STATUS                                     0x20004