struct hl_vm_phys_pg_pack *phys_pg_pack;
        u64 paddr = 0, total_size, num_pgs, i;
        u32 num_curr_pgs, page_size;
-       int handle, rc;
        bool contiguous;
+       int handle, rc;
 
        num_curr_pgs = 0;
 
        contiguous = args->flags & HL_MEM_CONTIGUOUS;
 
        if (contiguous) {
-               paddr = (u64) gen_pool_alloc(vm->dram_pg_pool, total_size);
+               if (is_power_of_2(page_size))
+                       paddr = (u64) (uintptr_t) gen_pool_dma_alloc_align(vm->dram_pg_pool,
+                                                               total_size, NULL, page_size);
+               else
+                       paddr = (u64) (uintptr_t) gen_pool_alloc(vm->dram_pg_pool, total_size);
                if (!paddr) {
                        dev_err(hdev->dev,
                                "failed to allocate %llu contiguous pages with total size of %llu\n",
                        phys_pg_pack->pages[i] = paddr + i * page_size;
        } else {
                for (i = 0 ; i < num_pgs ; i++) {
-                       phys_pg_pack->pages[i] = (u64) gen_pool_alloc(
-                                                       vm->dram_pg_pool,
-                                                       page_size);
+                       if (is_power_of_2(page_size))
+                               phys_pg_pack->pages[i] =
+                                               (u64) gen_pool_dma_alloc_align(vm->dram_pg_pool,
+                                                                               page_size, NULL,
+                                                                               page_size);
+                       else
+                               phys_pg_pack->pages[i] = (u64) gen_pool_alloc(vm->dram_pg_pool,
+                                                                               page_size);
                        if (!phys_pg_pack->pages[i]) {
                                dev_err(hdev->dev,
                                        "Failed to allocate device memory (out of memory)\n");