#include <linux/slab.h>
 #include <linux/soc/qcom/mdt_loader.h>
 
+static bool mdt_header_valid(const struct firmware *fw)
+{
+       const struct elf32_hdr *ehdr;
+       size_t phend;
+       size_t shend;
+
+       if (fw->size < sizeof(*ehdr))
+               return false;
+
+       ehdr = (struct elf32_hdr *)fw->data;
+
+       if (memcmp(ehdr->e_ident, ELFMAG, SELFMAG))
+               return false;
+
+       if (ehdr->e_phentsize != sizeof(struct elf32_phdr))
+               return -EINVAL;
+
+       phend = size_add(size_mul(sizeof(struct elf32_phdr), ehdr->e_phnum), ehdr->e_phoff);
+       if (phend > fw->size)
+               return false;
+
+       if (ehdr->e_shentsize != sizeof(struct elf32_shdr))
+               return -EINVAL;
+
+       shend = size_add(size_mul(sizeof(struct elf32_shdr), ehdr->e_shnum), ehdr->e_shoff);
+       if (shend > fw->size)
+               return false;
+
+       return true;
+}
+
 static bool mdt_phdr_valid(const struct elf32_phdr *phdr)
 {
        if (phdr->p_type != PT_LOAD)
        phys_addr_t max_addr = 0;
        int i;
 
+       if (!mdt_header_valid(fw))
+               return -EINVAL;
+
        ehdr = (struct elf32_hdr *)fw->data;
        phdrs = (struct elf32_phdr *)(ehdr + 1);
 
        ssize_t ret;
        void *data;
 
+       if (!mdt_header_valid(fw))
+               return ERR_PTR(-EINVAL);
+
        ehdr = (struct elf32_hdr *)fw->data;
        phdrs = (struct elf32_phdr *)(ehdr + 1);
 
        int ret;
        int i;
 
+       if (!mdt_header_valid(fw))
+               return -EINVAL;
+
        ehdr = (struct elf32_hdr *)fw->data;
        phdrs = (struct elf32_phdr *)(ehdr + 1);
 
        if (!fw || !mem_region || !mem_phys || !mem_size)
                return -EINVAL;
 
+       if (!mdt_header_valid(fw))
+               return -EINVAL;
+
        is_split = qcom_mdt_bins_are_split(fw, fw_name);
        ehdr = (struct elf32_hdr *)fw->data;
        phdrs = (struct elf32_phdr *)(ehdr + 1);