#include <linux/bitmap.h>
 #include <linux/lockdep.h>
 
+#include <acpi/acpi_numa.h>
+
 enum virtio_mem_mb_state {
        /* Unplugged, not added to Linux. Can be reused later. */
        VIRTIO_MEM_MB_STATE_UNUSED = 0,
 
        /* The device block size (for communicating with the device). */
        uint32_t device_block_size;
+       /* The translated node id. NUMA_NO_NODE in case not specified. */
+       int nid;
        /* Physical start address of the memory region. */
        uint64_t addr;
        /* Maximum region size in bytes. */
 static int virtio_mem_mb_add(struct virtio_mem *vm, unsigned long mb_id)
 {
        const uint64_t addr = virtio_mem_mb_id_to_phys(mb_id);
-       int nid = memory_add_physaddr_to_nid(addr);
+       int nid = vm->nid;
+
+       if (nid == NUMA_NO_NODE)
+               nid = memory_add_physaddr_to_nid(addr);
 
        dev_dbg(&vm->vdev->dev, "adding memory block: %lu\n", mb_id);
        return add_memory(nid, addr, memory_block_size_bytes());
 static int virtio_mem_mb_remove(struct virtio_mem *vm, unsigned long mb_id)
 {
        const uint64_t addr = virtio_mem_mb_id_to_phys(mb_id);
-       int nid = memory_add_physaddr_to_nid(addr);
+       int nid = vm->nid;
+
+       if (nid == NUMA_NO_NODE)
+               nid = memory_add_physaddr_to_nid(addr);
 
        dev_dbg(&vm->vdev->dev, "removing memory block: %lu\n", mb_id);
        return remove_memory(nid, addr, memory_block_size_bytes());
        spin_unlock_irqrestore(&vm->removal_lock, flags);
 }
 
+static int virtio_mem_translate_node_id(struct virtio_mem *vm, uint16_t node_id)
+{
+       int node = NUMA_NO_NODE;
+
+#if defined(CONFIG_ACPI_NUMA)
+       if (virtio_has_feature(vm->vdev, VIRTIO_MEM_F_ACPI_PXM))
+               node = pxm_to_node(node_id);
+#endif
+       return node;
+}
+
 /*
  * Test if a virtio-mem device overlaps with the given range. Can be called
  * from (notifier) callbacks lockless.
 static int virtio_mem_init(struct virtio_mem *vm)
 {
        const uint64_t phys_limit = 1UL << MAX_PHYSMEM_BITS;
+       uint16_t node_id;
 
        if (!vm->vdev->config->get) {
                dev_err(&vm->vdev->dev, "config access disabled\n");
                     &vm->plugged_size);
        virtio_cread(vm->vdev, struct virtio_mem_config, block_size,
                     &vm->device_block_size);
+       virtio_cread(vm->vdev, struct virtio_mem_config, node_id,
+                    &node_id);
+       vm->nid = virtio_mem_translate_node_id(vm, node_id);
        virtio_cread(vm->vdev, struct virtio_mem_config, addr, &vm->addr);
        virtio_cread(vm->vdev, struct virtio_mem_config, region_size,
                     &vm->region_size);
                 memory_block_size_bytes());
        dev_info(&vm->vdev->dev, "subblock size: 0x%x",
                 vm->subblock_size);
+       if (vm->nid != NUMA_NO_NODE)
+               dev_info(&vm->vdev->dev, "nid: %d", vm->nid);
 
        return 0;
 }
 }
 #endif
 
+static unsigned int virtio_mem_features[] = {
+#if defined(CONFIG_NUMA) && defined(CONFIG_ACPI_NUMA)
+       VIRTIO_MEM_F_ACPI_PXM,
+#endif
+};
+
 static struct virtio_device_id virtio_mem_id_table[] = {
        { VIRTIO_ID_MEM, VIRTIO_DEV_ANY_ID },
        { 0 },
 };
 
 static struct virtio_driver virtio_mem_driver = {
+       .feature_table = virtio_mem_features,
+       .feature_table_size = ARRAY_SIZE(virtio_mem_features),
        .driver.name = KBUILD_MODNAME,
        .driver.owner = THIS_MODULE,
        .id_table = virtio_mem_id_table,