}
 
 /* Digest hash size if it is too large */
-static int hash_digest_key(struct caam_hash_ctx *ctx, const u8 *key_in,
-                          u32 *keylen, u8 *key_out, u32 digestsize)
+static int hash_digest_key(struct caam_hash_ctx *ctx, u32 *keylen, u8 *key,
+                          u32 digestsize)
 {
        struct device *jrdev = ctx->jrdev;
        u32 *desc;
        struct split_key_result result;
-       dma_addr_t src_dma, dst_dma;
+       dma_addr_t key_dma;
        int ret;
 
        desc = kmalloc(CAAM_CMD_SZ * 8 + CAAM_PTR_SZ * 2, GFP_KERNEL | GFP_DMA);
 
        init_job_desc(desc, 0);
 
-       src_dma = dma_map_single(jrdev, (void *)key_in, *keylen,
-                                DMA_TO_DEVICE);
-       if (dma_mapping_error(jrdev, src_dma)) {
-               dev_err(jrdev, "unable to map key input memory\n");
-               kfree(desc);
-               return -ENOMEM;
-       }
-       dst_dma = dma_map_single(jrdev, (void *)key_out, digestsize,
-                                DMA_FROM_DEVICE);
-       if (dma_mapping_error(jrdev, dst_dma)) {
-               dev_err(jrdev, "unable to map key output memory\n");
-               dma_unmap_single(jrdev, src_dma, *keylen, DMA_TO_DEVICE);
+       key_dma = dma_map_single(jrdev, key, *keylen, DMA_BIDIRECTIONAL);
+       if (dma_mapping_error(jrdev, key_dma)) {
+               dev_err(jrdev, "unable to map key memory\n");
                kfree(desc);
                return -ENOMEM;
        }
        /* Job descriptor to perform unkeyed hash on key_in */
        append_operation(desc, ctx->adata.algtype | OP_ALG_ENCRYPT |
                         OP_ALG_AS_INITFINAL);
-       append_seq_in_ptr(desc, src_dma, *keylen, 0);
+       append_seq_in_ptr(desc, key_dma, *keylen, 0);
        append_seq_fifo_load(desc, *keylen, FIFOLD_CLASS_CLASS2 |
                             FIFOLD_TYPE_LAST2 | FIFOLD_TYPE_MSG);
-       append_seq_out_ptr(desc, dst_dma, digestsize, 0);
+       append_seq_out_ptr(desc, key_dma, digestsize, 0);
        append_seq_store(desc, digestsize, LDST_CLASS_2_CCB |
                         LDST_SRCDST_BYTE_CONTEXT);
 
 #ifdef DEBUG
        print_hex_dump(KERN_ERR, "key_in@"__stringify(__LINE__)": ",
-                      DUMP_PREFIX_ADDRESS, 16, 4, key_in, *keylen, 1);
+                      DUMP_PREFIX_ADDRESS, 16, 4, key, *keylen, 1);
        print_hex_dump(KERN_ERR, "jobdesc@"__stringify(__LINE__)": ",
                       DUMP_PREFIX_ADDRESS, 16, 4, desc, desc_bytes(desc), 1);
 #endif
 #ifdef DEBUG
                print_hex_dump(KERN_ERR,
                               "digested key@"__stringify(__LINE__)": ",
-                              DUMP_PREFIX_ADDRESS, 16, 4, key_in,
-                              digestsize, 1);
+                              DUMP_PREFIX_ADDRESS, 16, 4, key, digestsize, 1);
 #endif
        }
-       dma_unmap_single(jrdev, src_dma, *keylen, DMA_TO_DEVICE);
-       dma_unmap_single(jrdev, dst_dma, digestsize, DMA_FROM_DEVICE);
+       dma_unmap_single(jrdev, key_dma, *keylen, DMA_BIDIRECTIONAL);
 
        *keylen = digestsize;
 
 #endif
 
        if (keylen > blocksize) {
-               hashed_key = kmalloc_array(digestsize,
-                                          sizeof(*hashed_key),
-                                          GFP_KERNEL | GFP_DMA);
+               hashed_key = kmemdup(key, keylen, GFP_KERNEL | GFP_DMA);
                if (!hashed_key)
                        return -ENOMEM;
-               ret = hash_digest_key(ctx, key, &keylen, hashed_key,
-                                     digestsize);
+               ret = hash_digest_key(ctx, &keylen, hashed_key, digestsize);
                if (ret)
                        goto bad_free_key;
                key = hashed_key;
 
 {
        u32 *desc;
        struct split_key_result result;
-       dma_addr_t dma_addr_in, dma_addr_out;
+       dma_addr_t dma_addr;
        int ret = -ENOMEM;
 
        adata->keylen = split_key_len(adata->algtype & OP_ALG_ALGSEL_MASK);
                return ret;
        }
 
-       dma_addr_in = dma_map_single(jrdev, (void *)key_in, keylen,
-                                    DMA_TO_DEVICE);
-       if (dma_mapping_error(jrdev, dma_addr_in)) {
-               dev_err(jrdev, "unable to map key input memory\n");
-               goto out_free;
-       }
+       memcpy(key_out, key_in, keylen);
 
-       dma_addr_out = dma_map_single(jrdev, key_out, adata->keylen_pad,
-                                     DMA_FROM_DEVICE);
-       if (dma_mapping_error(jrdev, dma_addr_out)) {
-               dev_err(jrdev, "unable to map key output memory\n");
-               goto out_unmap_in;
+       dma_addr = dma_map_single(jrdev, key_out, adata->keylen_pad,
+                                 DMA_BIDIRECTIONAL);
+       if (dma_mapping_error(jrdev, dma_addr)) {
+               dev_err(jrdev, "unable to map key memory\n");
+               goto out_free;
        }
 
        init_job_desc(desc, 0);
-       append_key(desc, dma_addr_in, keylen, CLASS_2 | KEY_DEST_CLASS_REG);
+       append_key(desc, dma_addr, keylen, CLASS_2 | KEY_DEST_CLASS_REG);
 
        /* Sets MDHA up into an HMAC-INIT */
        append_operation(desc, (adata->algtype & OP_ALG_ALGSEL_MASK) |
         * FIFO_STORE with the explicit split-key content store
         * (0x26 output type)
         */
-       append_fifo_store(desc, dma_addr_out, adata->keylen,
+       append_fifo_store(desc, dma_addr, adata->keylen,
                          LDST_CLASS_2_CCB | FIFOST_TYPE_SPLIT_KEK);
 
 #ifdef DEBUG
-       print_hex_dump(KERN_ERR, "ctx.key@"__stringify(__LINE__)": ",
-                      DUMP_PREFIX_ADDRESS, 16, 4, key_in, keylen, 1);
        print_hex_dump(KERN_ERR, "jobdesc@"__stringify(__LINE__)": ",
                       DUMP_PREFIX_ADDRESS, 16, 4, desc, desc_bytes(desc), 1);
 #endif
 #endif
        }
 
-       dma_unmap_single(jrdev, dma_addr_out, adata->keylen_pad,
-                        DMA_FROM_DEVICE);
-out_unmap_in:
-       dma_unmap_single(jrdev, dma_addr_in, keylen, DMA_TO_DEVICE);
+       dma_unmap_single(jrdev, dma_addr, adata->keylen_pad, DMA_BIDIRECTIONAL);
 out_free:
        kfree(desc);
        return ret;