#include <linux/bitmap.h>
 
 #include "regs/xe_gt_regs.h"
+#include "xe_assert.h"
 #include "xe_gt.h"
 #include "xe_mmio.h"
 
        bitmap_from_arr32(mask, &val, XE_MAX_EU_FUSE_BITS);
 }
 
+/**
+ * gen_l3_mask_from_pattern - Replicate a bit pattern according to a mask
+ *
+ * It is used to compute the L3 bank masks in a generic format on
+ * various platforms where the internal representation of L3 node
+ * and masks from registers are different.
+ *
+ * @xe: device
+ * @dst: destination
+ * @pattern: pattern to replicate
+ * @patternbits: size of the pattern, in bits
+ * @mask: mask describing where to replicate the pattern
+ *
+ * Example 1:
+ * ----------
+ * @pattern =    0b1111
+ *                 └┬─┘
+ * @patternbits =   4 (bits)
+ * @mask = 0b0101
+ *           ││││
+ *           │││└────────────────── 0b1111 (=1×0b1111)
+ *           ││└──────────── 0b0000    │   (=0×0b1111)
+ *           │└────── 0b1111    │      │   (=1×0b1111)
+ *           └ 0b0000    │      │      │   (=0×0b1111)
+ *                │      │      │      │
+ * @dst =      0b0000 0b1111 0b0000 0b1111
+ *
+ * Example 2:
+ * ----------
+ * @pattern =    0b11111111
+ *                 └┬─────┘
+ * @patternbits =   8 (bits)
+ * @mask = 0b10
+ *           ││
+ *           ││
+ *           ││
+ *           │└────────── 0b00000000 (=0×0b11111111)
+ *           └ 0b11111111      │     (=1×0b11111111)
+ *                  │          │
+ * @dst =      0b11111111 0b00000000
+ */
+static void
+gen_l3_mask_from_pattern(struct xe_device *xe, xe_l3_bank_mask_t dst,
+                        xe_l3_bank_mask_t pattern, int patternbits,
+                        unsigned long mask)
+{
+       unsigned long bit;
+
+       xe_assert(xe, fls(mask) <= patternbits);
+       for_each_set_bit(bit, &mask, 32) {
+               xe_l3_bank_mask_t shifted_pattern = {};
+
+               bitmap_shift_left(shifted_pattern, pattern, bit * patternbits,
+                                 XE_MAX_L3_BANK_MASK_BITS);
+               bitmap_or(dst, dst, shifted_pattern, XE_MAX_L3_BANK_MASK_BITS);
+       }
+}
+
+static void
+load_l3_bank_mask(struct xe_gt *gt, xe_l3_bank_mask_t l3_bank_mask)
+{
+       struct xe_device *xe = gt_to_xe(gt);
+       u32 fuse3 = xe_mmio_read32(gt, MIRROR_FUSE3);
+
+       if (GRAPHICS_VER(xe) >= 20) {
+               xe_l3_bank_mask_t per_node = {};
+               u32 meml3_en = REG_FIELD_GET(XE2_NODE_ENABLE_MASK, fuse3);
+               u32 bank_val = REG_FIELD_GET(XE2_GT_L3_MODE_MASK, fuse3);
+
+               bitmap_from_arr32(per_node, &bank_val, 32);
+               gen_l3_mask_from_pattern(xe, l3_bank_mask, per_node, 4,
+                                        meml3_en);
+       } else if (GRAPHICS_VERx100(xe) >= 1270) {
+               xe_l3_bank_mask_t per_node = {};
+               xe_l3_bank_mask_t per_mask_bit = {};
+               u32 meml3_en = REG_FIELD_GET(MEML3_EN_MASK, fuse3);
+               u32 fuse4 = xe_mmio_read32(gt, XEHP_FUSE4);
+               u32 bank_val = REG_FIELD_GET(GT_L3_EXC_MASK, fuse4);
+
+               bitmap_set_value8(per_mask_bit, 0x3, 0);
+               gen_l3_mask_from_pattern(xe, per_node, per_mask_bit, 2, bank_val);
+               gen_l3_mask_from_pattern(xe, l3_bank_mask, per_node, 4,
+                                        meml3_en);
+       } else if (xe->info.platform == XE_PVC) {
+               xe_l3_bank_mask_t per_node = {};
+               xe_l3_bank_mask_t per_mask_bit = {};
+               u32 meml3_en = REG_FIELD_GET(MEML3_EN_MASK, fuse3);
+               u32 bank_val = REG_FIELD_GET(XEHPC_GT_L3_MODE_MASK, fuse3);
+
+               bitmap_set_value8(per_mask_bit, 0xf, 0);
+               gen_l3_mask_from_pattern(xe, per_node, per_mask_bit, 4,
+                                        bank_val);
+               gen_l3_mask_from_pattern(xe, l3_bank_mask, per_node, 16,
+                                        meml3_en);
+       } else if (xe->info.platform == XE_DG2) {
+               xe_l3_bank_mask_t per_node = {};
+               u32 mask = REG_FIELD_GET(MEML3_EN_MASK, fuse3);
+
+               bitmap_set_value8(per_node, 0xff, 0);
+               gen_l3_mask_from_pattern(xe, l3_bank_mask, per_node, 8, mask);
+       } else {
+               /* 1:1 register bit to mask bit (inverted register bits) */
+               u32 mask = REG_FIELD_GET(XELP_GT_L3_MODE_MASK, ~fuse3);
+
+               bitmap_from_arr32(l3_bank_mask, &mask, 32);
+       }
+}
+
 static void
 get_num_dss_regs(struct xe_device *xe, int *geometry_regs, int *compute_regs)
 {
                      XEHPC_GT_COMPUTE_DSS_ENABLE_EXT,
                      XE2_GT_COMPUTE_DSS_2);
        load_eu_mask(gt, gt->fuse_topo.eu_mask_per_dss);
+       load_l3_bank_mask(gt, gt->fuse_topo.l3_bank_mask);
 
        p = drm_dbg_printer(>_to_xe(gt)->drm, DRM_UT_DRIVER, "GT topology");
 
        drm_printf(p, "EU mask per DSS:     %*pb\n", XE_MAX_EU_FUSE_BITS,
                   gt->fuse_topo.eu_mask_per_dss);
 
+       drm_printf(p, "L3 bank mask:        %*pb\n", XE_MAX_L3_BANK_MASK_BITS,
+                  gt->fuse_topo.l3_bank_mask);
 }
 
 /*