const u8 *key, unsigned int keylen)
 {
        unsigned long alignmask = crypto_aead_alignmask(tfm);
+       int err;
 
        if ((unsigned long)key & alignmask)
-               return setkey_unaligned(tfm, key, keylen);
+               err = setkey_unaligned(tfm, key, keylen);
+       else
+               err = crypto_aead_alg(tfm)->setkey(tfm, key, keylen);
+
+       if (err)
+               return err;
 
-       return crypto_aead_alg(tfm)->setkey(tfm, key, keylen);
+       crypto_aead_clear_flags(tfm, CRYPTO_TFM_NEED_KEY);
+       return 0;
 }
 EXPORT_SYMBOL_GPL(crypto_aead_setkey);
 
        struct crypto_aead *aead = __crypto_aead_cast(tfm);
        struct aead_alg *alg = crypto_aead_alg(aead);
 
+       crypto_aead_set_flags(aead, CRYPTO_TFM_NEED_KEY);
+
        aead->authsize = alg->maxauthsize;
 
        if (alg->exit)
 
 
 struct aead_tfm {
        struct crypto_aead *aead;
-       bool has_key;
        struct crypto_skcipher *null_tfm;
 };
 
 
        err = -ENOKEY;
        lock_sock_nested(psk, SINGLE_DEPTH_NESTING);
-       if (!tfm->has_key)
+       if (crypto_aead_get_flags(tfm->aead) & CRYPTO_TFM_NEED_KEY)
                goto unlock;
 
        if (!pask->refcnt++)
 static int aead_setkey(void *private, const u8 *key, unsigned int keylen)
 {
        struct aead_tfm *tfm = private;
-       int err;
-
-       err = crypto_aead_setkey(tfm->aead, key, keylen);
-       tfm->has_key = !err;
 
-       return err;
+       return crypto_aead_setkey(tfm->aead, key, keylen);
 }
 
 static void aead_sock_destruct(struct sock *sk)
 {
        struct aead_tfm *tfm = private;
 
-       if (!tfm->has_key)
+       if (crypto_aead_get_flags(tfm->aead) & CRYPTO_TFM_NEED_KEY)
                return -ENOKEY;
 
        return aead_accept_parent_nokey(private, sk);
 
  */
 static inline int crypto_aead_encrypt(struct aead_request *req)
 {
-       return crypto_aead_alg(crypto_aead_reqtfm(req))->encrypt(req);
+       struct crypto_aead *aead = crypto_aead_reqtfm(req);
+
+       if (crypto_aead_get_flags(aead) & CRYPTO_TFM_NEED_KEY)
+               return -ENOKEY;
+
+       return crypto_aead_alg(aead)->encrypt(req);
 }
 
 /**
 {
        struct crypto_aead *aead = crypto_aead_reqtfm(req);
 
+       if (crypto_aead_get_flags(aead) & CRYPTO_TFM_NEED_KEY)
+               return -ENOKEY;
+
        if (req->cryptlen < crypto_aead_authsize(aead))
                return -EINVAL;