summaryrefslogtreecommitdiff
path: root/drivers
diff options
context:
space:
mode:
Diffstat (limited to 'drivers')
-rw-r--r--drivers/vhost/vhost.c64
-rw-r--r--drivers/vhost/vhost.h3
2 files changed, 50 insertions, 17 deletions
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 12203d3893c5..ffbaf7d32e2c 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -280,7 +280,11 @@ EXPORT_SYMBOL_GPL(vhost_vq_flush);
void vhost_dev_flush(struct vhost_dev *dev)
{
- vhost_worker_flush(dev->worker);
+ struct vhost_worker *worker;
+ unsigned long i;
+
+ xa_for_each(&dev->worker_xa, i, worker)
+ vhost_worker_flush(worker);
}
EXPORT_SYMBOL_GPL(vhost_dev_flush);
@@ -482,7 +486,6 @@ void vhost_dev_init(struct vhost_dev *dev,
dev->umem = NULL;
dev->iotlb = NULL;
dev->mm = NULL;
- dev->worker = NULL;
dev->iov_limit = iov_limit;
dev->weight = weight;
dev->byte_weight = byte_weight;
@@ -492,7 +495,7 @@ void vhost_dev_init(struct vhost_dev *dev,
INIT_LIST_HEAD(&dev->read_list);
INIT_LIST_HEAD(&dev->pending_list);
spin_lock_init(&dev->iotlb_lock);
-
+ xa_init_flags(&dev->worker_xa, XA_FLAGS_ALLOC);
for (i = 0; i < dev->nvqs; ++i) {
vq = dev->vqs[i];
@@ -554,15 +557,35 @@ static void vhost_detach_mm(struct vhost_dev *dev)
dev->mm = NULL;
}
-static void vhost_worker_free(struct vhost_dev *dev)
+static void vhost_worker_destroy(struct vhost_dev *dev,
+ struct vhost_worker *worker)
+{
+ if (!worker)
+ return;
+
+ WARN_ON(!llist_empty(&worker->work_list));
+ xa_erase(&dev->worker_xa, worker->id);
+ vhost_task_stop(worker->vtsk);
+ kfree(worker);
+}
+
+static void vhost_workers_free(struct vhost_dev *dev)
{
- if (!dev->worker)
+ struct vhost_worker *worker;
+ unsigned long i;
+
+ if (!dev->use_worker)
return;
- WARN_ON(!llist_empty(&dev->worker->work_list));
- vhost_task_stop(dev->worker->vtsk);
- kfree(dev->worker);
- dev->worker = NULL;
+ for (i = 0; i < dev->nvqs; i++)
+ dev->vqs[i]->worker = NULL;
+ /*
+ * Free the default worker we created and cleanup workers userspace
+ * created but couldn't clean up (it forgot or crashed).
+ */
+ xa_for_each(&dev->worker_xa, i, worker)
+ vhost_worker_destroy(dev, worker);
+ xa_destroy(&dev->worker_xa);
}
static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
@@ -570,6 +593,8 @@ static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
struct vhost_worker *worker;
struct vhost_task *vtsk;
char name[TASK_COMM_LEN];
+ int ret;
+ u32 id;
worker = kzalloc(sizeof(*worker), GFP_KERNEL_ACCOUNT);
if (!worker)
@@ -584,16 +609,18 @@ static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
init_llist_head(&worker->work_list);
worker->kcov_handle = kcov_common_handle();
worker->vtsk = vtsk;
- /*
- * vsock can already try to queue so make sure llist and vtsk are both
- * set before vhost_work_queue sees dev->worker is set.
- */
- smp_wmb();
- dev->worker = worker;
vhost_task_start(vtsk);
+
+ ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL);
+ if (ret < 0)
+ goto stop_worker;
+ worker->id = id;
+
return worker;
+stop_worker:
+ vhost_task_stop(vtsk);
free_worker:
kfree(worker);
return NULL;
@@ -650,6 +677,11 @@ long vhost_dev_set_owner(struct vhost_dev *dev)
err = -ENOMEM;
goto err_worker;
}
+ /*
+ * vsock can already try to queue so make sure the worker
+ * is setup before vhost_vq_work_queue sees vq->worker is set.
+ */
+ smp_wmb();
for (i = 0; i < dev->nvqs; i++)
dev->vqs[i]->worker = worker;
@@ -751,7 +783,7 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
dev->iotlb = NULL;
vhost_clear_msg(dev);
wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM);
- vhost_worker_free(dev);
+ vhost_workers_free(dev);
vhost_detach_mm(dev);
}
EXPORT_SYMBOL_GPL(vhost_dev_cleanup);
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index b850f534bc9a..31937e98c01b 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -30,6 +30,7 @@ struct vhost_worker {
struct vhost_task *vtsk;
struct llist_head work_list;
u64 kcov_handle;
+ u32 id;
};
/* Poll a file (eventfd or socket) */
@@ -159,7 +160,6 @@ struct vhost_dev {
struct vhost_virtqueue **vqs;
int nvqs;
struct eventfd_ctx *log_ctx;
- struct vhost_worker *worker;
struct vhost_iotlb *umem;
struct vhost_iotlb *iotlb;
spinlock_t iotlb_lock;
@@ -169,6 +169,7 @@ struct vhost_dev {
int iov_limit;
int weight;
int byte_weight;
+ struct xarray worker_xa;
bool use_worker;
int (*msg_handler)(struct vhost_dev *dev, u32 asid,
struct vhost_iotlb_msg *msg);