#include <linux/mpi.h>
 
 struct dh_ctx {
-       MPI p;
-       MPI g;
-       MPI xa;
+       MPI p;  /* Value is guaranteed to be set. */
+       MPI q;  /* Value is optional. */
+       MPI g;  /* Value is guaranteed to be set. */
+       MPI xa; /* Value is guaranteed to be set. */
 };
 
 static void dh_clear_ctx(struct dh_ctx *ctx)
 {
        mpi_free(ctx->p);
+       mpi_free(ctx->q);
        mpi_free(ctx->g);
        mpi_free(ctx->xa);
        memset(ctx, 0, sizeof(*ctx));
        if (!ctx->p)
                return -EINVAL;
 
+       if (params->q && params->q_size) {
+               ctx->q = mpi_read_raw_data(params->q, params->q_size);
+               if (!ctx->q)
+                       return -EINVAL;
+       }
+
        ctx->g = mpi_read_raw_data(params->g, params->g_size);
        if (!ctx->g)
                return -EINVAL;
        return -EINVAL;
 }
 
+/*
+ * SP800-56A public key verification:
+ *
+ * * If Q is provided as part of the domain paramenters, a full validation
+ *   according to SP800-56A section 5.6.2.3.1 is performed.
+ *
+ * * If Q is not provided, a partial validation according to SP800-56A section
+ *   5.6.2.3.2 is performed.
+ */
+static int dh_is_pubkey_valid(struct dh_ctx *ctx, MPI y)
+{
+       if (unlikely(!ctx->p))
+               return -EINVAL;
+
+       /*
+        * Step 1: Verify that 2 <= y <= p - 2.
+        *
+        * The upper limit check is actually y < p instead of y < p - 1
+        * as the mpi_sub_ui function is yet missing.
+        */
+       if (mpi_cmp_ui(y, 1) < 1 || mpi_cmp(y, ctx->p) >= 0)
+               return -EINVAL;
+
+       /* Step 2: Verify that 1 = y^q mod p */
+       if (ctx->q) {
+               MPI val = mpi_alloc(0);
+               int ret;
+
+               if (!val)
+                       return -ENOMEM;
+
+               ret = mpi_powm(val, y, ctx->q, ctx->p);
+
+               if (ret) {
+                       mpi_free(val);
+                       return ret;
+               }
+
+               ret = mpi_cmp_ui(val, 1);
+
+               mpi_free(val);
+
+               if (ret != 0)
+                       return -EINVAL;
+       }
+
+       return 0;
+}
+
 static int dh_compute_value(struct kpp_request *req)
 {
        struct crypto_kpp *tfm = crypto_kpp_reqtfm(req);
                        ret = -EINVAL;
                        goto err_free_val;
                }
+               ret = dh_is_pubkey_valid(ctx, base);
+               if (ret)
+                       goto err_free_val;
        } else {
                base = ctx->g;
        }
 
 
 static inline unsigned int dh_data_size(const struct dh *p)
 {
-       return p->key_size + p->p_size + p->g_size;
+       return p->key_size + p->p_size + p->q_size + p->g_size;
 }
 
 unsigned int crypto_dh_key_len(const struct dh *p)
        ptr = dh_pack_data(ptr, &secret, sizeof(secret));
        ptr = dh_pack_data(ptr, ¶ms->key_size, sizeof(params->key_size));
        ptr = dh_pack_data(ptr, ¶ms->p_size, sizeof(params->p_size));
+       ptr = dh_pack_data(ptr, ¶ms->q_size, sizeof(params->q_size));
        ptr = dh_pack_data(ptr, ¶ms->g_size, sizeof(params->g_size));
        ptr = dh_pack_data(ptr, params->key, params->key_size);
        ptr = dh_pack_data(ptr, params->p, params->p_size);
+       ptr = dh_pack_data(ptr, params->q, params->q_size);
        dh_pack_data(ptr, params->g, params->g_size);
 
        return 0;
 
        ptr = dh_unpack_data(¶ms->key_size, ptr, sizeof(params->key_size));
        ptr = dh_unpack_data(¶ms->p_size, ptr, sizeof(params->p_size));
+       ptr = dh_unpack_data(¶ms->q_size, ptr, sizeof(params->q_size));
        ptr = dh_unpack_data(¶ms->g_size, ptr, sizeof(params->g_size));
        if (secret.len != crypto_dh_key_len(params))
                return -EINVAL;
         * some drivers assume otherwise.
         */
        if (params->key_size > params->p_size ||
-           params->g_size > params->p_size)
+           params->g_size > params->p_size || params->q_size > params->p_size)
                return -EINVAL;
 
        /* Don't allocate memory. Set pointers to data within
         */
        params->key = (void *)ptr;
        params->p = (void *)(ptr + params->key_size);
-       params->g = (void *)(ptr + params->key_size + params->p_size);
+       params->q = (void *)(ptr + params->key_size + params->p_size);
+       params->g = (void *)(ptr + params->key_size + params->p_size +
+                            params->q_size);
 
        /*
         * Don't permit 'p' to be 0.  It's not a prime number, and it's subject
        if (memchr_inv(params->p, 0, params->p_size) == NULL)
                return -EINVAL;
 
+       /* It is permissible to not provide Q. */
+       if (params->q_size == 0)
+               params->q = NULL;
+
        return 0;
 }
 EXPORT_SYMBOL_GPL(crypto_dh_decode_key);