#include "rxe.h"
 
+#define RXE_POOL_TIMEOUT       (200)
 #define RXE_POOL_ALIGN         (16)
 
 static const struct rxe_type_info {
        elem->pool = pool;
        elem->obj = obj;
        kref_init(&elem->ref_cnt);
+       init_completion(&elem->complete);
 
-       err = xa_alloc_cyclic(&pool->xa, &elem->index, elem, pool->limit,
+       /* allocate index in array but leave pointer as NULL so it
+        * can't be looked up until rxe_finalize() is called
+        */
+       err = xa_alloc_cyclic(&pool->xa, &elem->index, NULL, pool->limit,
                              &pool->next, GFP_KERNEL);
        if (err < 0)
                goto err_free;
        return NULL;
 }
 
-int __rxe_add_to_pool(struct rxe_pool *pool, struct rxe_pool_elem *elem)
+int __rxe_add_to_pool(struct rxe_pool *pool, struct rxe_pool_elem *elem,
+                               bool sleepable)
 {
        int err;
+       gfp_t gfp_flags;
 
        if (WARN_ON(pool->type == RXE_TYPE_MR))
                return -EINVAL;
        elem->pool = pool;
        elem->obj = (u8 *)elem - pool->elem_offset;
        kref_init(&elem->ref_cnt);
-
-       err = xa_alloc_cyclic(&pool->xa, &elem->index, elem, pool->limit,
-                             &pool->next, GFP_KERNEL);
+       init_completion(&elem->complete);
+
+       /* AH objects are unique in that the create_ah verb
+        * can be called in atomic context. If the create_ah
+        * call is not sleepable use GFP_ATOMIC.
+        */
+       gfp_flags = sleepable ? GFP_KERNEL : GFP_ATOMIC;
+
+       if (sleepable)
+               might_sleep();
+       err = xa_alloc_cyclic(&pool->xa, &elem->index, NULL, pool->limit,
+                             &pool->next, gfp_flags);
        if (err < 0)
                goto err_cnt;
 
 static void rxe_elem_release(struct kref *kref)
 {
        struct rxe_pool_elem *elem = container_of(kref, typeof(*elem), ref_cnt);
+
+       complete(&elem->complete);
+}
+
+int __rxe_cleanup(struct rxe_pool_elem *elem, bool sleepable)
+{
        struct rxe_pool *pool = elem->pool;
+       struct xarray *xa = &pool->xa;
+       static int timeout = RXE_POOL_TIMEOUT;
+       unsigned long flags;
+       int ret, err = 0;
+       void *xa_ret;
 
-       xa_erase(&pool->xa, elem->index);
+       if (sleepable)
+               might_sleep();
+
+       /* erase xarray entry to prevent looking up
+        * the pool elem from its index
+        */
+       xa_lock_irqsave(xa, flags);
+       xa_ret = __xa_erase(xa, elem->index);
+       xa_unlock_irqrestore(xa, flags);
+       WARN_ON(xa_err(xa_ret));
+
+       /* if this is the last call to rxe_put complete the
+        * object. It is safe to touch obj->elem after this since
+        * it is freed below
+        */
+       __rxe_put(elem);
+
+       /* wait until all references to the object have been
+        * dropped before final object specific cleanup and
+        * return to rdma-core
+        */
+       if (sleepable) {
+               if (!completion_done(&elem->complete) && timeout) {
+                       ret = wait_for_completion_timeout(&elem->complete,
+                                       timeout);
+
+                       /* Shouldn't happen. There are still references to
+                        * the object but, rather than deadlock, free the
+                        * object or pass back to rdma-core.
+                        */
+                       if (WARN_ON(!ret))
+                               err = -EINVAL;
+               }
+       } else {
+               unsigned long until = jiffies + timeout;
+
+               /* AH objects are unique in that the destroy_ah verb
+                * can be called in atomic context. This delay
+                * replaces the wait_for_completion call above
+                * when the destroy_ah call is not sleepable
+                */
+               while (!completion_done(&elem->complete) &&
+                               time_before(jiffies, until))
+                       mdelay(1);
+
+               if (WARN_ON(!completion_done(&elem->complete)))
+                       err = -EINVAL;
+       }
 
        if (pool->cleanup)
                pool->cleanup(elem);
                kfree(elem->obj);
 
        atomic_dec(&pool->num_elem);
+
+       return err;
 }
 
 int __rxe_get(struct rxe_pool_elem *elem)
 {
        return kref_put(&elem->ref_cnt, rxe_elem_release);
 }
+
+void __rxe_finalize(struct rxe_pool_elem *elem)
+{
+       struct xarray *xa = &elem->pool->xa;
+       unsigned long flags;
+       void *ret;
+
+       xa_lock_irqsave(xa, flags);
+       ret = __xa_store(&elem->pool->xa, elem->index, elem, GFP_KERNEL);
+       xa_unlock_irqrestore(xa, flags);
+       WARN_ON(xa_err(ret));
+}
 
        void                    *obj;
        struct kref             ref_cnt;
        struct list_head        list;
+       struct completion       complete;
        u32                     index;
 };
 
 void *rxe_alloc(struct rxe_pool *pool);
 
 /* connect already allocated object to pool */
-int __rxe_add_to_pool(struct rxe_pool *pool, struct rxe_pool_elem *elem);
-
-#define rxe_add_to_pool(pool, obj) __rxe_add_to_pool(pool, &(obj)->elem)
+int __rxe_add_to_pool(struct rxe_pool *pool, struct rxe_pool_elem *elem,
+                               bool sleepable);
+#define rxe_add_to_pool(pool, obj) __rxe_add_to_pool(pool, &(obj)->elem, true)
+#define rxe_add_to_pool_ah(pool, obj, sleepable) __rxe_add_to_pool(pool, \
+                               &(obj)->elem, sleepable)
 
 /* lookup an indexed object from index. takes a reference on object */
 void *rxe_pool_get_index(struct rxe_pool *pool, u32 index);
 
 int __rxe_get(struct rxe_pool_elem *elem);
-
 #define rxe_get(obj) __rxe_get(&(obj)->elem)
 
 int __rxe_put(struct rxe_pool_elem *elem);
-
 #define rxe_put(obj) __rxe_put(&(obj)->elem)
 
+int __rxe_cleanup(struct rxe_pool_elem *elem, bool sleepable);
+#define rxe_cleanup(obj) __rxe_cleanup(&(obj)->elem, true)
+#define rxe_cleanup_ah(obj, sleepable) __rxe_cleanup(&(obj)->elem, sleepable)
+
 #define rxe_read(obj) kref_read(&(obj)->elem.ref_cnt)
 
+void __rxe_finalize(struct rxe_pool_elem *elem);
+#define rxe_finalize(obj) __rxe_finalize(&(obj)->elem)
+
 #endif /* RXE_POOL_H */
 
 {
        struct rxe_ucontext *uc = to_ruc(ibuc);
 
-       rxe_put(uc);
+       rxe_cleanup(uc);
 }
 
 static int rxe_port_immutable(struct ib_device *dev, u32 port_num,
 {
        struct rxe_pd *pd = to_rpd(ibpd);
 
-       rxe_put(pd);
+       rxe_cleanup(pd);
        return 0;
 }
 
        if (err)
                return err;
 
-       err = rxe_add_to_pool(&rxe->ah_pool, ah);
+       err = rxe_add_to_pool_ah(&rxe->ah_pool, ah,
+                       init_attr->flags & RDMA_CREATE_AH_SLEEPABLE);
        if (err)
                return err;
 
                err = copy_to_user(&uresp->ah_num, &ah->ah_num,
                                         sizeof(uresp->ah_num));
                if (err) {
-                       rxe_put(ah);
+                       rxe_cleanup(ah);
                        return -EFAULT;
                }
        } else if (ah->is_user) {
        }
 
        rxe_init_av(init_attr->ah_attr, &ah->av);
+       rxe_finalize(ah);
+
        return 0;
 }
 
 {
        struct rxe_ah *ah = to_rah(ibah);
 
-       rxe_put(ah);
+       rxe_cleanup_ah(ah, flags & RDMA_DESTROY_AH_SLEEPABLE);
+
        return 0;
 }
 
 
        err = rxe_srq_from_init(rxe, srq, init, udata, uresp);
        if (err)
-               goto err_put;
+               goto err_cleanup;
 
        return 0;
 
-err_put:
-       rxe_put(srq);
+err_cleanup:
+       rxe_cleanup(srq);
+
        return err;
 }
 
 {
        struct rxe_srq *srq = to_rsrq(ibsrq);
 
-       rxe_put(srq);
+       rxe_cleanup(srq);
        return 0;
 }
 
        if (err)
                goto qp_init;
 
+       rxe_finalize(qp);
        return 0;
 
 qp_init:
-       rxe_put(qp);
+       rxe_cleanup(qp);
        return err;
 }
 
        if (ret)
                return ret;
 
-       rxe_put(qp);
+       rxe_cleanup(qp);
        return 0;
 }
 
 
        rxe_cq_disable(cq);
 
-       rxe_put(cq);
+       rxe_cleanup(cq);
        return 0;
 }
 
 
        rxe_get(pd);
        rxe_mr_init_dma(pd, access, mr);
+       rxe_finalize(mr);
 
        return &mr->ibmr;
 }
        if (err)
                goto err3;
 
+       rxe_finalize(mr);
+
        return &mr->ibmr;
 
 err3:
        rxe_put(pd);
-       rxe_put(mr);
+       rxe_cleanup(mr);
 err2:
        return ERR_PTR(err);
 }
        if (err)
                goto err2;
 
+       rxe_finalize(mr);
+
        return &mr->ibmr;
 
 err2:
        rxe_put(pd);
-       rxe_put(mr);
+       rxe_cleanup(mr);
 err1:
        return ERR_PTR(err);
 }