#include "intel_display_types.h"
 #include "intel_global_state.h"
 
+static void __intel_atomic_global_state_free(struct kref *kref)
+{
+       struct intel_global_state *obj_state =
+               container_of(kref, struct intel_global_state, ref);
+       struct intel_global_obj *obj = obj_state->obj;
+
+       obj->funcs->atomic_destroy_state(obj, obj_state);
+}
+
+static void intel_atomic_global_state_put(struct intel_global_state *obj_state)
+{
+       kref_put(&obj_state->ref, __intel_atomic_global_state_free);
+}
+
+static struct intel_global_state *
+intel_atomic_global_state_get(struct intel_global_state *obj_state)
+{
+       kref_get(&obj_state->ref);
+
+       return obj_state;
+}
+
 void intel_atomic_global_obj_init(struct drm_i915_private *dev_priv,
                                  struct intel_global_obj *obj,
                                  struct intel_global_state *state,
 {
        memset(obj, 0, sizeof(*obj));
 
+       state->obj = obj;
+
+       kref_init(&state->ref);
+
        obj->state = state;
        obj->funcs = funcs;
        list_add_tail(&obj->head, &dev_priv->global_obj_list);
 
        list_for_each_entry_safe(obj, next, &dev_priv->global_obj_list, head) {
                list_del(&obj->head);
-               obj->funcs->atomic_destroy_state(obj, obj->state);
+
+               drm_WARN_ON(&dev_priv->drm, kref_read(&obj->state->ref) != 1);
+               intel_atomic_global_state_put(obj->state);
        }
 }
 
        if (!obj_state)
                return ERR_PTR(-ENOMEM);
 
+       obj_state->obj = obj;
        obj_state->changed = false;
 
+       kref_init(&obj_state->ref);
+
        state->global_objs[index].state = obj_state;
-       state->global_objs[index].old_state = obj->state;
+       state->global_objs[index].old_state =
+               intel_atomic_global_state_get(obj->state);
        state->global_objs[index].new_state = obj_state;
        state->global_objs[index].ptr = obj;
        obj_state->state = state;
                new_obj_state->state = NULL;
 
                state->global_objs[i].state = old_obj_state;
-               obj->state = new_obj_state;
+
+               intel_atomic_global_state_put(obj->state);
+               obj->state = intel_atomic_global_state_get(new_obj_state);
        }
 }
 
        int i;
 
        for (i = 0; i < state->num_global_objs; i++) {
-               struct intel_global_obj *obj = state->global_objs[i].ptr;
+               intel_atomic_global_state_put(state->global_objs[i].old_state);
+               intel_atomic_global_state_put(state->global_objs[i].new_state);
 
-               obj->funcs->atomic_destroy_state(obj,
-                                                state->global_objs[i].state);
                state->global_objs[i].ptr = NULL;
                state->global_objs[i].state = NULL;
                state->global_objs[i].old_state = NULL;