#include "ufshcd.h"
 
+#define UFSHCD_ENABLE_INTRS    (UTP_TRANSFER_REQ_COMPL |\
+                                UTP_TASK_REQ_COMPL |\
+                                UFSHCD_ERROR_MASK)
+
 enum {
        UFSHCD_MAX_CHANNEL      = 0,
        UFSHCD_MAX_ID           = 1,
        INT_AGGR_CONFIG,
 };
 
+/**
+ * ufshcd_get_intr_mask - Get the interrupt bit mask
+ * @hba - Pointer to adapter instance
+ *
+ * Returns interrupt bit mask per version
+ */
+static inline u32 ufshcd_get_intr_mask(struct ufs_hba *hba)
+{
+       if (hba->ufs_version == UFSHCI_VERSION_10)
+               return INTERRUPT_MASK_ALL_VER_10;
+       else
+               return INTERRUPT_MASK_ALL_VER_11;
+}
+
 /**
  * ufshcd_get_ufs_version - Get the UFS version supported by the HBA
  * @hba - Pointer to adapter instance
 }
 
 /**
- * ufshcd_int_config - enable/disable interrupts
+ * ufshcd_enable_intr - enable interrupts
  * @hba: per adapter instance
- * @option: interrupt option
+ * @intrs: interrupt bits
  */
-static void ufshcd_int_config(struct ufs_hba *hba, u32 option)
+static void ufshcd_enable_intr(struct ufs_hba *hba, u32 intrs)
 {
-       switch (option) {
-       case UFSHCD_INT_ENABLE:
-               ufshcd_writel(hba, hba->int_enable_mask, REG_INTERRUPT_ENABLE);
-               break;
-       case UFSHCD_INT_DISABLE:
-               if (hba->ufs_version == UFSHCI_VERSION_10)
-                       ufshcd_writel(hba, INTERRUPT_DISABLE_MASK_10,
-                                     REG_INTERRUPT_ENABLE);
-               else
-                       ufshcd_writel(hba, INTERRUPT_DISABLE_MASK_11,
-                                     REG_INTERRUPT_ENABLE);
-               break;
+       u32 set = ufshcd_readl(hba, REG_INTERRUPT_ENABLE);
+
+       if (hba->ufs_version == UFSHCI_VERSION_10) {
+               u32 rw;
+               rw = set & INTERRUPT_MASK_RW_VER_10;
+               set = rw | ((set ^ intrs) & intrs);
+       } else {
+               set |= intrs;
+       }
+
+       ufshcd_writel(hba, set, REG_INTERRUPT_ENABLE);
+}
+
+/**
+ * ufshcd_disable_intr - disable interrupts
+ * @hba: per adapter instance
+ * @intrs: interrupt bits
+ */
+static void ufshcd_disable_intr(struct ufs_hba *hba, u32 intrs)
+{
+       u32 set = ufshcd_readl(hba, REG_INTERRUPT_ENABLE);
+
+       if (hba->ufs_version == UFSHCI_VERSION_10) {
+               u32 rw;
+               rw = (set & INTERRUPT_MASK_RW_VER_10) &
+                       ~(intrs & INTERRUPT_MASK_RW_VER_10);
+               set = rw | ((set & intrs) & ~INTERRUPT_MASK_RW_VER_10);
+
+       } else {
+               set &= ~intrs;
        }
+
+       ufshcd_writel(hba, set, REG_INTERRUPT_ENABLE);
 }
 
 /**
        uic_cmd->argument3 = 0;
 
        /* enable UIC related interrupts */
-       hba->int_enable_mask |= UIC_COMMAND_COMPL;
-       ufshcd_int_config(hba, UFSHCD_INT_ENABLE);
+       ufshcd_enable_intr(hba, UIC_COMMAND_COMPL);
 
        /* sending UIC commands to controller */
        ufshcd_send_uic_command(hba, uic_cmd);
        }
 
        /* Enable required interrupts */
-       hba->int_enable_mask |= (UTP_TRANSFER_REQ_COMPL |
-                                UIC_ERROR |
-                                UTP_TASK_REQ_COMPL |
-                                DEVICE_FATAL_ERROR |
-                                CONTROLLER_FATAL_ERROR |
-                                SYSTEM_BUS_FATAL_ERROR);
-       ufshcd_int_config(hba, UFSHCD_INT_ENABLE);
+       ufshcd_enable_intr(hba, UFSHCD_ENABLE_INTRS);
 
        /* Configure interrupt aggregation */
        ufshcd_config_int_aggr(hba, INT_AGGR_CONFIG);
 void ufshcd_remove(struct ufs_hba *hba)
 {
        /* disable interrupts */
-       ufshcd_int_config(hba, UFSHCD_INT_DISABLE);
+       ufshcd_disable_intr(hba, hba->intr_mask);
 
        ufshcd_hba_stop(hba);
        ufshcd_hba_free(hba);
        /* Get UFS version supported by the controller */
        hba->ufs_version = ufshcd_get_ufs_version(hba);
 
+       /* Get Interrupt bit mask per version */
+       hba->intr_mask = ufshcd_get_intr_mask(hba);
+
        /* Allocate memory for host memory space */
        err = ufshcd_memory_alloc(hba);
        if (err) {