build_encrypt_ctxt((struct smb2_encryption_neg_context *)pneg_ctxt,
                                   conn->cipher_type);
                rsp->NegotiateContextCount = cpu_to_le16(++neg_ctxt_cnt);
-               ctxt_size += sizeof(struct smb2_encryption_neg_context);
+               ctxt_size += sizeof(struct smb2_encryption_neg_context) + 2;
                /* Round to 8 byte boundary */
                pneg_ctxt +=
-                       round_up(sizeof(struct smb2_encryption_neg_context),
+                       round_up(sizeof(struct smb2_encryption_neg_context) + 2,
                                 8);
        }
 
                build_compression_ctxt((struct smb2_compression_ctx *)pneg_ctxt,
                                       conn->compress_algorithm);
                rsp->NegotiateContextCount = cpu_to_le16(++neg_ctxt_cnt);
-               ctxt_size += sizeof(struct smb2_compression_ctx);
+               ctxt_size += sizeof(struct smb2_compression_ctx) + 2;
                /* Round to 8 byte boundary */
-               pneg_ctxt += round_up(sizeof(struct smb2_compression_ctx), 8);
+               pneg_ctxt += round_up(sizeof(struct smb2_compression_ctx) + 2,
+                                     8);
        }
 
        if (conn->posix_ext_supported) {
        return err;
 }
 
-static int decode_encrypt_ctxt(struct ksmbd_conn *conn,
-                              struct smb2_encryption_neg_context *pneg_ctxt)
+static void decode_encrypt_ctxt(struct ksmbd_conn *conn,
+                               struct smb2_encryption_neg_context *pneg_ctxt,
+                               int len_of_ctxts)
 {
-       int i;
        int cph_cnt = le16_to_cpu(pneg_ctxt->CipherCount);
+       int i, cphs_size = cph_cnt * sizeof(__le16);
 
        conn->cipher_type = 0;
 
+       if (sizeof(struct smb2_encryption_neg_context) + cphs_size >
+           len_of_ctxts) {
+               pr_err("Invalid cipher count(%d)\n", cph_cnt);
+               return;
+       }
+
        if (!(server_conf.flags & KSMBD_GLOBAL_FLAG_SMB2_ENCRYPTION))
-               goto out;
+               return;
 
        for (i = 0; i < cph_cnt; i++) {
                if (pneg_ctxt->Ciphers[i] == SMB2_ENCRYPTION_AES128_GCM ||
                        break;
                }
        }
-
-out:
-       /*
-        * Return encrypt context size in request.
-        * So need to plus extra number of ciphers size.
-        */
-       return sizeof(struct smb2_encryption_neg_context) +
-               ((cph_cnt - 1) * 2);
 }
 
-static int decode_compress_ctxt(struct ksmbd_conn *conn,
-                               struct smb2_compression_ctx *pneg_ctxt)
+static void decode_compress_ctxt(struct ksmbd_conn *conn,
+                                struct smb2_compression_ctx *pneg_ctxt)
 {
-       int algo_cnt = le16_to_cpu(pneg_ctxt->CompressionAlgorithmCount);
-
        conn->compress_algorithm = SMB3_COMPRESS_NONE;
-
-       /*
-        * Return compression context size in request.
-        * So need to plus extra number of CompressionAlgorithms size.
-        */
-       return sizeof(struct smb2_compression_ctx) +
-               ((algo_cnt - 1) * 2);
 }
 
 static __le32 deassemble_neg_contexts(struct ksmbd_conn *conn,
                                      struct smb2_negotiate_req *req)
 {
-       int i = 0;
-       __le32 status = 0;
        /* +4 is to account for the RFC1001 len field */
-       char *pneg_ctxt = (char *)req +
-                       le32_to_cpu(req->NegotiateContextOffset) + 4;
-       __le16 *ContextType = (__le16 *)pneg_ctxt;
+       struct smb2_neg_context *pctx = (struct smb2_neg_context *)((char *)req + 4);
+       int i = 0, len_of_ctxts;
+       int offset = le32_to_cpu(req->NegotiateContextOffset);
        int neg_ctxt_cnt = le16_to_cpu(req->NegotiateContextCount);
-       int ctxt_size;
+       int len_of_smb = be32_to_cpu(req->hdr.smb2_buf_length);
+       __le32 status = STATUS_INVALID_PARAMETER;
+
+       ksmbd_debug(SMB, "decoding %d negotiate contexts\n", neg_ctxt_cnt);
+       if (len_of_smb <= offset) {
+               ksmbd_debug(SMB, "Invalid response: negotiate context offset\n");
+               return status;
+       }
+
+       len_of_ctxts = len_of_smb - offset;
 
-       ksmbd_debug(SMB, "negotiate context count = %d\n", neg_ctxt_cnt);
-       status = STATUS_INVALID_PARAMETER;
        while (i++ < neg_ctxt_cnt) {
-               if (*ContextType == SMB2_PREAUTH_INTEGRITY_CAPABILITIES) {
+               int clen;
+
+               /* check that offset is not beyond end of SMB */
+               if (len_of_ctxts == 0)
+                       break;
+
+               if (len_of_ctxts < sizeof(struct smb2_neg_context))
+                       break;
+
+               pctx = (struct smb2_neg_context *)((char *)pctx + offset);
+               clen = le16_to_cpu(pctx->DataLength);
+               if (clen + sizeof(struct smb2_neg_context) > len_of_ctxts)
+                       break;
+
+               if (pctx->ContextType == SMB2_PREAUTH_INTEGRITY_CAPABILITIES) {
                        ksmbd_debug(SMB,
                                    "deassemble SMB2_PREAUTH_INTEGRITY_CAPABILITIES context\n");
                        if (conn->preauth_info->Preauth_HashId)
                                break;
 
                        status = decode_preauth_ctxt(conn,
-                                                    (struct smb2_preauth_neg_context *)pneg_ctxt);
-                       pneg_ctxt += DIV_ROUND_UP(sizeof(struct smb2_preauth_neg_context), 8) * 8;
-               } else if (*ContextType == SMB2_ENCRYPTION_CAPABILITIES) {
+                                                    (struct smb2_preauth_neg_context *)pctx);
+                       if (status != STATUS_SUCCESS)
+                               break;
+               } else if (pctx->ContextType == SMB2_ENCRYPTION_CAPABILITIES) {
                        ksmbd_debug(SMB,
                                    "deassemble SMB2_ENCRYPTION_CAPABILITIES context\n");
                        if (conn->cipher_type)
                                break;
 
-                       ctxt_size = decode_encrypt_ctxt(conn,
-                               (struct smb2_encryption_neg_context *)pneg_ctxt);
-                       pneg_ctxt += DIV_ROUND_UP(ctxt_size, 8) * 8;
-               } else if (*ContextType == SMB2_COMPRESSION_CAPABILITIES) {
+                       decode_encrypt_ctxt(conn,
+                                           (struct smb2_encryption_neg_context *)pctx,
+                                           len_of_ctxts);
+               } else if (pctx->ContextType == SMB2_COMPRESSION_CAPABILITIES) {
                        ksmbd_debug(SMB,
                                    "deassemble SMB2_COMPRESSION_CAPABILITIES context\n");
                        if (conn->compress_algorithm)
                                break;
 
-                       ctxt_size = decode_compress_ctxt(conn,
-                               (struct smb2_compression_ctx *)pneg_ctxt);
-                       pneg_ctxt += DIV_ROUND_UP(ctxt_size, 8) * 8;
-               } else if (*ContextType == SMB2_NETNAME_NEGOTIATE_CONTEXT_ID) {
+                       decode_compress_ctxt(conn,
+                                            (struct smb2_compression_ctx *)pctx);
+               } else if (pctx->ContextType == SMB2_NETNAME_NEGOTIATE_CONTEXT_ID) {
                        ksmbd_debug(SMB,
                                    "deassemble SMB2_NETNAME_NEGOTIATE_CONTEXT_ID context\n");
-                       ctxt_size = sizeof(struct smb2_netname_neg_context);
-                       ctxt_size += DIV_ROUND_UP(le16_to_cpu(((struct smb2_netname_neg_context *)
-                                                              pneg_ctxt)->DataLength), 8) * 8;
-                       pneg_ctxt += ctxt_size;
-               } else if (*ContextType == SMB2_POSIX_EXTENSIONS_AVAILABLE) {
+               } else if (pctx->ContextType == SMB2_POSIX_EXTENSIONS_AVAILABLE) {
                        ksmbd_debug(SMB,
                                    "deassemble SMB2_POSIX_EXTENSIONS_AVAILABLE context\n");
                        conn->posix_ext_supported = true;
-                       pneg_ctxt += DIV_ROUND_UP(sizeof(struct smb2_posix_neg_context), 8) * 8;
                }
-               ContextType = (__le16 *)pneg_ctxt;
 
-               if (status != STATUS_SUCCESS)
-                       break;
+               /* offsets must be 8 byte aligned */
+               clen = (clen + 7) & ~0x7;
+               offset = clen + sizeof(struct smb2_neg_context);
+               len_of_ctxts -= clen + sizeof(struct smb2_neg_context);
        }
        return status;
 }