#include <linux/fs.h>
 #include <linux/mm.h>
 #include <linux/mman.h>
+#include <linux/memory-tiers.h>
 #include "dax-private.h"
 #include "bus.h"
 
+/*
+ * Default abstract distance assigned to the NUMA node onlined
+ * by DAX/kmem if the low level platform driver didn't initialize
+ * one for this NUMA node.
+ */
+#define MEMTIER_DEFAULT_DAX_ADISTANCE  (MEMTIER_ADISTANCE_DRAM * 5)
+
 /* Memory resource name used for add_memory_driver_managed(). */
 static const char *kmem_name;
 /* Set if any memory will remain added when the driver will be unloaded. */
        struct resource *res[];
 };
 
+static struct memory_dev_type *dax_slowmem_type;
 static int dev_dax_kmem_probe(struct dev_dax *dev_dax)
 {
        struct device *dev = &dev_dax->dev;
                return -EINVAL;
        }
 
+       init_node_memory_type(numa_node, dax_slowmem_type);
+
+       rc = -ENOMEM;
        data = kzalloc(struct_size(data, res, dev_dax->nr_range), GFP_KERNEL);
        if (!data)
-               return -ENOMEM;
+               goto err_dax_kmem_data;
 
-       rc = -ENOMEM;
        data->res_name = kstrdup(dev_name(dev), GFP_KERNEL);
        if (!data->res_name)
                goto err_res_name;
        kfree(data->res_name);
 err_res_name:
        kfree(data);
+err_dax_kmem_data:
+       clear_node_memory_type(numa_node, dax_slowmem_type);
        return rc;
 }
 
 static void dev_dax_kmem_remove(struct dev_dax *dev_dax)
 {
        int i, success = 0;
+       int node = dev_dax->target_node;
        struct device *dev = &dev_dax->dev;
        struct dax_kmem_data *data = dev_get_drvdata(dev);
 
                kfree(data->res_name);
                kfree(data);
                dev_set_drvdata(dev, NULL);
+               /*
+                * Clear the memtype association on successful unplug.
+                * If not, we have memory blocks left which can be
+                * offlined/onlined later. We need to keep memory_dev_type
+                * for that. This implies this reference will be around
+                * till next reboot.
+                */
+               clear_node_memory_type(node, dax_slowmem_type);
        }
 }
 #else
        if (!kmem_name)
                return -ENOMEM;
 
+       dax_slowmem_type = alloc_memory_type(MEMTIER_DEFAULT_DAX_ADISTANCE);
+       if (IS_ERR(dax_slowmem_type)) {
+               rc = PTR_ERR(dax_slowmem_type);
+               goto err_dax_slowmem_type;
+       }
+
        rc = dax_driver_register(&device_dax_kmem_driver);
        if (rc)
-               kfree_const(kmem_name);
+               goto error_dax_driver;
+
+       return rc;
+
+error_dax_driver:
+       destroy_memory_type(dax_slowmem_type);
+err_dax_slowmem_type:
+       kfree_const(kmem_name);
        return rc;
 }
 
        dax_driver_unregister(&device_dax_kmem_driver);
        if (!any_hotremove_failed)
                kfree_const(kmem_name);
+       destroy_memory_type(dax_slowmem_type);
 }
 
 MODULE_AUTHOR("Intel Corporation");
 
 #ifndef _LINUX_MEMORY_TIERS_H
 #define _LINUX_MEMORY_TIERS_H
 
+#include <linux/types.h>
+#include <linux/nodemask.h>
+#include <linux/kref.h>
 /*
  * Each tier cover a abstrace distance chunk size of 128
  */
 #define MEMTIER_ADISTANCE_DRAM ((4 * MEMTIER_CHUNK_SIZE) + (MEMTIER_CHUNK_SIZE >> 1))
 #define MEMTIER_HOTPLUG_PRIO   100
 
+struct memory_tier;
+struct memory_dev_type {
+       /* list of memory types that are part of same tier as this type */
+       struct list_head tier_sibiling;
+       /* abstract distance for this specific memory type */
+       int adistance;
+       /* Nodes of same abstract distance */
+       nodemask_t nodes;
+       struct kref kref;
+       struct memory_tier *memtier;
+};
+
 #ifdef CONFIG_NUMA
-#include <linux/types.h>
 extern bool numa_demotion_enabled;
+struct memory_dev_type *alloc_memory_type(int adistance);
+void destroy_memory_type(struct memory_dev_type *memtype);
+void init_node_memory_type(int node, struct memory_dev_type *default_type);
+void clear_node_memory_type(int node, struct memory_dev_type *memtype);
 
 #else
 
 #define numa_demotion_enabled  false
+/*
+ * CONFIG_NUMA implementation returns non NULL error.
+ */
+static inline struct memory_dev_type *alloc_memory_type(int adistance)
+{
+       return NULL;
+}
+
+static inline void destroy_memory_type(struct memory_dev_type *memtype)
+{
+
+}
+
+static inline void init_node_memory_type(int node, struct memory_dev_type *default_type)
+{
+
+}
+
+static inline void clear_node_memory_type(int node, struct memory_dev_type *memtype)
+{
+
+}
 #endif /* CONFIG_NUMA */
 #endif  /* _LINUX_MEMORY_TIERS_H */
 
 // SPDX-License-Identifier: GPL-2.0
-#include <linux/types.h>
-#include <linux/nodemask.h>
 #include <linux/slab.h>
 #include <linux/lockdep.h>
 #include <linux/sysfs.h>
        int adistance_start;
 };
 
-struct memory_dev_type {
-       /* list of memory types that are part of same tier as this type */
-       struct list_head tier_sibiling;
-       /* abstract distance for this specific memory type */
-       int adistance;
-       /* Nodes of same abstract distance */
-       nodemask_t nodes;
-       struct memory_tier *memtier;
+struct node_memory_type_map {
+       struct memory_dev_type *memtype;
+       int map_count;
 };
 
 static DEFINE_MUTEX(memory_tier_lock);
 static LIST_HEAD(memory_tiers);
-static struct memory_dev_type *node_memory_types[MAX_NUMNODES];
-/*
- * For now we can have 4 faster memory tiers with smaller adistance
- * than default DRAM tier.
- */
-static struct memory_dev_type default_dram_type  = {
-       .adistance = MEMTIER_ADISTANCE_DRAM,
-       .tier_sibiling = LIST_HEAD_INIT(default_dram_type.tier_sibiling),
-};
+static struct node_memory_type_map node_memory_types[MAX_NUMNODES];
+static struct memory_dev_type *default_dram_type;
 
 static struct memory_tier *find_create_memory_tier(struct memory_dev_type *memtype)
 {
        return new_memtier;
 }
 
+static inline void __init_node_memory_type(int node, struct memory_dev_type *memtype)
+{
+       if (!node_memory_types[node].memtype)
+               node_memory_types[node].memtype = memtype;
+       /*
+        * for each device getting added in the same NUMA node
+        * with this specific memtype, bump the map count. We
+        * Only take memtype device reference once, so that
+        * changing a node memtype can be done by droping the
+        * only reference count taken here.
+        */
+
+       if (node_memory_types[node].memtype == memtype) {
+               if (!node_memory_types[node].map_count++)
+                       kref_get(&memtype->kref);
+       }
+}
+
 static struct memory_tier *set_node_memory_tier(int node)
 {
        struct memory_tier *memtier;
        if (!node_state(node, N_MEMORY))
                return ERR_PTR(-EINVAL);
 
-       if (!node_memory_types[node])
-               node_memory_types[node] = &default_dram_type;
+       __init_node_memory_type(node, default_dram_type);
 
-       memtype = node_memory_types[node];
+       memtype = node_memory_types[node].memtype;
        node_set(node, memtype->nodes);
        memtier = find_create_memory_tier(memtype);
        return memtier;
        if (memtier) {
                struct memory_dev_type *memtype;
 
-               memtype = node_memory_types[node];
+               memtype = node_memory_types[node].memtype;
                node_clear(node, memtype->nodes);
                if (nodes_empty(memtype->nodes)) {
                        list_del_init(&memtype->tier_sibiling);
        return cleared;
 }
 
+static void release_memtype(struct kref *kref)
+{
+       struct memory_dev_type *memtype;
+
+       memtype = container_of(kref, struct memory_dev_type, kref);
+       kfree(memtype);
+}
+
+struct memory_dev_type *alloc_memory_type(int adistance)
+{
+       struct memory_dev_type *memtype;
+
+       memtype = kmalloc(sizeof(*memtype), GFP_KERNEL);
+       if (!memtype)
+               return ERR_PTR(-ENOMEM);
+
+       memtype->adistance = adistance;
+       INIT_LIST_HEAD(&memtype->tier_sibiling);
+       memtype->nodes  = NODE_MASK_NONE;
+       memtype->memtier = NULL;
+       kref_init(&memtype->kref);
+       return memtype;
+}
+EXPORT_SYMBOL_GPL(alloc_memory_type);
+
+void destroy_memory_type(struct memory_dev_type *memtype)
+{
+       kref_put(&memtype->kref, release_memtype);
+}
+EXPORT_SYMBOL_GPL(destroy_memory_type);
+
+void init_node_memory_type(int node, struct memory_dev_type *memtype)
+{
+
+       mutex_lock(&memory_tier_lock);
+       __init_node_memory_type(node, memtype);
+       mutex_unlock(&memory_tier_lock);
+}
+EXPORT_SYMBOL_GPL(init_node_memory_type);
+
+void clear_node_memory_type(int node, struct memory_dev_type *memtype)
+{
+       mutex_lock(&memory_tier_lock);
+       if (node_memory_types[node].memtype == memtype)
+               node_memory_types[node].map_count--;
+       /*
+        * If we umapped all the attached devices to this node,
+        * clear the node memory type.
+        */
+       if (!node_memory_types[node].map_count) {
+               node_memory_types[node].memtype = NULL;
+               kref_put(&memtype->kref, release_memtype);
+       }
+       mutex_unlock(&memory_tier_lock);
+}
+EXPORT_SYMBOL_GPL(clear_node_memory_type);
+
 static int __meminit memtier_hotplug_callback(struct notifier_block *self,
                                              unsigned long action, void *_arg)
 {
        struct memory_tier *memtier;
 
        mutex_lock(&memory_tier_lock);
+       /*
+        * For now we can have 4 faster memory tiers with smaller adistance
+        * than default DRAM tier.
+        */
+       default_dram_type = alloc_memory_type(MEMTIER_ADISTANCE_DRAM);
+       if (!default_dram_type)
+               panic("%s() failed to allocate default DRAM tier\n", __func__);
+
        /*
         * Look at all the existing N_MEMORY nodes and add them to
         * default memory tier or to a tier if we already have memory