#include <linux/sysfs.h>
 #include <linux/kobject.h>
 #include <linux/memory.h>
+#include <linux/mmzone.h>
 #include <linux/memory-tiers.h>
 
 #include "internal.h"
 
 static struct memory_tier *__node_get_memory_tier(int node)
 {
-       struct memory_dev_type *memtype;
+       pg_data_t *pgdat;
 
-       memtype = node_memory_types[node];
-       if (memtype && node_isset(node, memtype->nodes))
-               return memtype->memtier;
-       return NULL;
+       pgdat = NODE_DATA(node);
+       if (!pgdat)
+               return NULL;
+       /*
+        * Since we hold memory_tier_lock, we can avoid
+        * RCU read locks when accessing the details. No
+        * parallel updates are possible here.
+        */
+       return rcu_dereference_check(pgdat->memtier,
+                                    lockdep_is_held(&memory_tier_lock));
 }
 
 #ifdef CONFIG_MIGRATION
 {
        struct memory_tier *memtier;
        struct memory_dev_type *memtype;
+       pg_data_t *pgdat = NODE_DATA(node);
+
 
        lockdep_assert_held_once(&memory_tier_lock);
 
        memtype = node_memory_types[node].memtype;
        node_set(node, memtype->nodes);
        memtier = find_create_memory_tier(memtype);
+       if (!IS_ERR(memtier))
+               rcu_assign_pointer(pgdat->memtier, memtier);
        return memtier;
 }
 
 static void destroy_memory_tier(struct memory_tier *memtier)
 {
        list_del(&memtier->list);
+       /*
+        * synchronize_rcu in clear_node_memory_tier makes sure
+        * we don't have rcu access to this memory tier.
+        */
        kfree(memtier);
 }
 
 static bool clear_node_memory_tier(int node)
 {
        bool cleared = false;
+       pg_data_t *pgdat;
        struct memory_tier *memtier;
 
+       pgdat = NODE_DATA(node);
+       if (!pgdat)
+               return false;
+
+       /*
+        * Make sure that anybody looking at NODE_DATA who finds
+        * a valid memtier finds memory_dev_types with nodes still
+        * linked to the memtier. We achieve this by waiting for
+        * rcu read section to finish using synchronize_rcu.
+        * This also enables us to free the destroyed memory tier
+        * with kfree instead of kfree_rcu
+        */
        memtier = __node_get_memory_tier(node);
        if (memtier) {
                struct memory_dev_type *memtype;
 
+               rcu_assign_pointer(pgdat->memtier, NULL);
+               synchronize_rcu();
                memtype = node_memory_types[node].memtype;
                node_clear(node, memtype->nodes);
                if (nodes_empty(memtype->nodes)) {