#include <asm/iommu.h>
 
 /*
- * The RMP entry format is not architectural. The format is defined in PPR
- * Family 19h Model 01h, Rev B1 processor.
+ * The RMP entry information as returned by the RMPREAD instruction.
  */
 struct rmpentry {
+       u64 gpa;
+       u8  assigned            :1,
+           rsvd1               :7;
+       u8  pagesize            :1,
+           hpage_region_status :1,
+           rsvd2               :6;
+       u8  immutable           :1,
+           rsvd3               :7;
+       u8  rsvd4;
+       u32 asid;
+} __packed;
+
+/*
+ * The raw RMP entry format is not architectural. The format is defined in PPR
+ * Family 19h Model 01h, Rev B1 processor. This format represents the actual
+ * entry in the RMP table memory. The bitfield definitions are used for machines
+ * without the RMPREAD instruction (Zen3 and Zen4), otherwise the "hi" and "lo"
+ * fields are only used for dumping the raw data.
+ */
+struct rmpentry_raw {
        union {
                struct {
                        u64 assigned    : 1,
 #define PFN_PMD_MASK   GENMASK_ULL(63, PMD_SHIFT - PAGE_SHIFT)
 
 static u64 probed_rmp_base, probed_rmp_size;
-static struct rmpentry *rmptable __ro_after_init;
+static struct rmpentry_raw *rmptable __ro_after_init;
 static u64 rmptable_max_pfn __ro_after_init;
 
 static LIST_HEAD(snp_leaked_pages_list);
        rmptable_start += RMPTABLE_CPU_BOOKKEEPING_SZ;
        rmptable_size = probed_rmp_size - RMPTABLE_CPU_BOOKKEEPING_SZ;
 
-       rmptable = (struct rmpentry *)rmptable_start;
-       rmptable_max_pfn = rmptable_size / sizeof(struct rmpentry) - 1;
+       rmptable = (struct rmpentry_raw *)rmptable_start;
+       rmptable_max_pfn = rmptable_size / sizeof(struct rmpentry_raw) - 1;
 
        cpuhp_setup_state(CPUHP_AP_ONLINE_DYN, "x86/rmptable_init:online", __snp_enable, NULL);
 
  */
 device_initcall(snp_rmptable_init);
 
-static struct rmpentry *get_rmpentry(u64 pfn)
+static struct rmpentry_raw *get_raw_rmpentry(u64 pfn)
 {
-       if (WARN_ON_ONCE(pfn > rmptable_max_pfn))
+       if (!rmptable)
+               return ERR_PTR(-ENODEV);
+
+       if (unlikely(pfn > rmptable_max_pfn))
                return ERR_PTR(-EFAULT);
 
-       return &rmptable[pfn];
+       return rmptable + pfn;
+}
+
+static int get_rmpentry(u64 pfn, struct rmpentry *e)
+{
+       struct rmpentry_raw *e_raw;
+
+       e_raw = get_raw_rmpentry(pfn);
+       if (IS_ERR(e_raw))
+               return PTR_ERR(e_raw);
+
+       /*
+        * Map the raw RMP table entry onto the RMPREAD output format.
+        * The 2MB region status indicator (hpage_region_status field) is not
+        * calculated, since the overhead could be significant and the field
+        * is not used.
+        */
+       memset(e, 0, sizeof(*e));
+       e->gpa       = e_raw->gpa << PAGE_SHIFT;
+       e->asid      = e_raw->asid;
+       e->assigned  = e_raw->assigned;
+       e->pagesize  = e_raw->pagesize;
+       e->immutable = e_raw->immutable;
+
+       return 0;
 }
 
-static struct rmpentry *__snp_lookup_rmpentry(u64 pfn, int *level)
+static int __snp_lookup_rmpentry(u64 pfn, struct rmpentry *e, int *level)
 {
-       struct rmpentry *large_entry, *entry;
+       struct rmpentry e_large;
+       int ret;
 
        if (!cc_platform_has(CC_ATTR_HOST_SEV_SNP))
-               return ERR_PTR(-ENODEV);
+               return -ENODEV;
 
-       entry = get_rmpentry(pfn);
-       if (IS_ERR(entry))
-               return entry;
+       ret = get_rmpentry(pfn, e);
+       if (ret)
+               return ret;
 
        /*
         * Find the authoritative RMP entry for a PFN. This can be either a 4K
         * RMP entry or a special large RMP entry that is authoritative for a
         * whole 2M area.
         */
-       large_entry = get_rmpentry(pfn & PFN_PMD_MASK);
-       if (IS_ERR(large_entry))
-               return large_entry;
+       ret = get_rmpentry(pfn & PFN_PMD_MASK, &e_large);
+       if (ret)
+               return ret;
 
-       *level = RMP_TO_PG_LEVEL(large_entry->pagesize);
+       *level = RMP_TO_PG_LEVEL(e_large.pagesize);
 
-       return entry;
+       return 0;
 }
 
 int snp_lookup_rmpentry(u64 pfn, bool *assigned, int *level)
 {
-       struct rmpentry *e;
+       struct rmpentry e;
+       int ret;
 
-       e = __snp_lookup_rmpentry(pfn, level);
-       if (IS_ERR(e))
-               return PTR_ERR(e);
+       ret = __snp_lookup_rmpentry(pfn, &e, level);
+       if (ret)
+               return ret;
 
-       *assigned = !!e->assigned;
+       *assigned = !!e.assigned;
        return 0;
 }
 EXPORT_SYMBOL_GPL(snp_lookup_rmpentry);
  */
 static void dump_rmpentry(u64 pfn)
 {
+       struct rmpentry_raw *e_raw;
        u64 pfn_i, pfn_end;
-       struct rmpentry *e;
-       int level;
+       struct rmpentry e;
+       int level, ret;
 
-       e = __snp_lookup_rmpentry(pfn, &level);
-       if (IS_ERR(e)) {
-               pr_err("Failed to read RMP entry for PFN 0x%llx, error %ld\n",
-                      pfn, PTR_ERR(e));
+       ret = __snp_lookup_rmpentry(pfn, &e, &level);
+       if (ret) {
+               pr_err("Failed to read RMP entry for PFN 0x%llx, error %d\n",
+                      pfn, ret);
                return;
        }
 
-       if (e->assigned) {
+       if (e.assigned) {
+               e_raw = get_raw_rmpentry(pfn);
+               if (IS_ERR(e_raw)) {
+                       pr_err("Failed to read RMP contents for PFN 0x%llx, error %ld\n",
+                              pfn, PTR_ERR(e_raw));
+                       return;
+               }
+
                pr_info("PFN 0x%llx, RMP entry: [0x%016llx - 0x%016llx]\n",
-                       pfn, e->lo, e->hi);
+                       pfn, e_raw->lo, e_raw->hi);
                return;
        }
 
                pfn, pfn_i, pfn_end);
 
        while (pfn_i < pfn_end) {
-               e = __snp_lookup_rmpentry(pfn_i, &level);
-               if (IS_ERR(e)) {
-                       pr_err("Error %ld reading RMP entry for PFN 0x%llx\n",
-                              PTR_ERR(e), pfn_i);
+               e_raw = get_raw_rmpentry(pfn_i);
+               if (IS_ERR(e_raw)) {
+                       pr_err("Error %ld reading RMP contents for PFN 0x%llx\n",
+                              PTR_ERR(e_raw), pfn_i);
                        pfn_i++;
                        continue;
                }
 
-               if (e->lo || e->hi)
-                       pr_info("PFN: 0x%llx, [0x%016llx - 0x%016llx]\n", pfn_i, e->lo, e->hi);
+               if (e_raw->lo || e_raw->hi)
+                       pr_info("PFN: 0x%llx, [0x%016llx - 0x%016llx]\n", pfn_i, e_raw->lo, e_raw->hi);
                pfn_i++;
        }
 }