CMD_SET_TYPE(cmd, CMD_INV_DEV_ENTRY);
 }
 
-static void build_inv_iommu_pages(struct iommu_cmd *cmd, u64 address,
-                                 size_t size, u16 domid, int pde)
+/*
+ * Builds an invalidation address which is suitable for one page or multiple
+ * pages. Sets the size bit (S) as needed is more than one page is flushed.
+ */
+static inline u64 build_inv_address(u64 address, size_t size)
 {
-       u64 pages;
-       bool s;
+       u64 pages, end, msb_diff;
 
        pages = iommu_num_pages(address, size, PAGE_SIZE);
-       s     = false;
 
-       if (pages > 1) {
+       if (pages == 1)
+               return address & PAGE_MASK;
+
+       end = address + size - 1;
+
+       /*
+        * msb_diff would hold the index of the most significant bit that
+        * flipped between the start and end.
+        */
+       msb_diff = fls64(end ^ address) - 1;
+
+       /*
+        * Bits 63:52 are sign extended. If for some reason bit 51 is different
+        * between the start and the end, invalidate everything.
+        */
+       if (unlikely(msb_diff > 51)) {
+               address = CMD_INV_IOMMU_ALL_PAGES_ADDRESS;
+       } else {
                /*
-                * If we have to flush more than one page, flush all
-                * TLB entries for this domain
+                * The msb-bit must be clear on the address. Just set all the
+                * lower bits.
                 */
-               address = CMD_INV_IOMMU_ALL_PAGES_ADDRESS;
-               s = true;
+               address |= 1ull << (msb_diff - 1);
        }
 
+       /* Clear bits 11:0 */
        address &= PAGE_MASK;
 
+       /* Set the size bit - we flush more than one 4kb page */
+       return address | CMD_INV_IOMMU_PAGES_SIZE_MASK;
+}
+
+static void build_inv_iommu_pages(struct iommu_cmd *cmd, u64 address,
+                                 size_t size, u16 domid, int pde)
+{
+       u64 inv_address = build_inv_address(address, size);
+
        memset(cmd, 0, sizeof(*cmd));
        cmd->data[1] |= domid;
-       cmd->data[2]  = lower_32_bits(address);
-       cmd->data[3]  = upper_32_bits(address);
+       cmd->data[2]  = lower_32_bits(inv_address);
+       cmd->data[3]  = upper_32_bits(inv_address);
        CMD_SET_TYPE(cmd, CMD_INV_IOMMU_PAGES);
-       if (s) /* size bit - we flush more than one 4kb page */
-               cmd->data[2] |= CMD_INV_IOMMU_PAGES_SIZE_MASK;
        if (pde) /* PDE bit - we want to flush everything, not only the PTEs */
                cmd->data[2] |= CMD_INV_IOMMU_PAGES_PDE_MASK;
 }
 static void build_inv_iotlb_pages(struct iommu_cmd *cmd, u16 devid, int qdep,
                                  u64 address, size_t size)
 {
-       u64 pages;
-       bool s;
-
-       pages = iommu_num_pages(address, size, PAGE_SIZE);
-       s     = false;
-
-       if (pages > 1) {
-               /*
-                * If we have to flush more than one page, flush all
-                * TLB entries for this domain
-                */
-               address = CMD_INV_IOMMU_ALL_PAGES_ADDRESS;
-               s = true;
-       }
-
-       address &= PAGE_MASK;
+       u64 inv_address = build_inv_address(address, size);
 
        memset(cmd, 0, sizeof(*cmd));
        cmd->data[0]  = devid;
        cmd->data[0] |= (qdep & 0xff) << 24;
        cmd->data[1]  = devid;
-       cmd->data[2]  = lower_32_bits(address);
-       cmd->data[3]  = upper_32_bits(address);
+       cmd->data[2]  = lower_32_bits(inv_address);
+       cmd->data[3]  = upper_32_bits(inv_address);
        CMD_SET_TYPE(cmd, CMD_INV_IOTLB_PAGES);
-       if (s)
-               cmd->data[2] |= CMD_INV_IOMMU_PAGES_SIZE_MASK;
 }
 
 static void build_inv_iommu_pasid(struct iommu_cmd *cmd, u16 domid, u32 pasid,