net.c 30.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
/* Copyright (C) 2009 Red Hat, Inc.
 * Author: Michael S. Tsirkin <mst@redhat.com>
 *
 * This work is licensed under the terms of the GNU GPL, version 2.
 *
 * virtio-net server in host kernel.
 */

#include <linux/compat.h>
#include <linux/eventfd.h>
#include <linux/vhost.h>
#include <linux/virtio_net.h>
#include <linux/miscdevice.h>
#include <linux/module.h>
15
#include <linux/moduleparam.h>
16
17
18
#include <linux/mutex.h>
#include <linux/workqueue.h>
#include <linux/file.h>
19
#include <linux/slab.h>
20
#include <linux/vmalloc.h>
21
22
23
24
25

#include <linux/net.h>
#include <linux/if_packet.h>
#include <linux/if_arp.h>
#include <linux/if_tun.h>
Arnd Bergmann's avatar
Arnd Bergmann committed
26
#include <linux/if_macvlan.h>
27
#include <linux/if_vlan.h>
28
29
30
31
32

#include <net/sock.h>

#include "vhost.h"

33
static int experimental_zcopytx = 1;
34
module_param(experimental_zcopytx, int, 0444);
35
36
MODULE_PARM_DESC(experimental_zcopytx, "Enable Zero Copy TX;"
		                       " 1 -Enable; 0 - Disable");
37

38
39
40
41
/*
 * Byte budget for a single handler run. Once it is exceeded the work is
 * requeued, so one virtqueue cannot starve the others.
 */
#define VHOST_NET_WEIGHT 0x80000

/* Upper bound on TX used buffers outstanding for zerocopy. */
#define VHOST_MAX_PEND 128
/* TX payloads shorter than this are never sent zerocopy (see handle_tx). */
#define VHOST_GOODCOPY_LEN 256

/*
 * On transmit the used-ring len field carries no information for the guest,
 * so zerocopy TX reuses it to record the DMA state of each buffer.
 */
#define VHOST_DMA_CLEAR_LEN	((__force __virtio32)0)	/* buffer unused */
#define VHOST_DMA_IN_PROGRESS	((__force __virtio32)1)	/* lower device DMA in progress */
#define VHOST_DMA_DONE_LEN	((__force __virtio32)2)	/* lower device DMA done */
#define VHOST_DMA_FAILED_LEN	((__force __virtio32)3)	/* lower device DMA failed */

/* "Finished" means DONE or FAILED: every state >= VHOST_DMA_DONE_LEN. */
#define VHOST_DMA_IS_DONE(len) ((__force u32)(len) >= (__force u32)VHOST_DMA_DONE_LEN)
60

61
62
63
enum {
	VHOST_NET_FEATURES = VHOST_FEATURES |
			 (1ULL << VHOST_NET_F_VIRTIO_NET_HDR) |
Jason Wang's avatar
Jason Wang committed
64
65
			 (1ULL << VIRTIO_NET_F_MRG_RXBUF) |
			 (1ULL << VIRTIO_F_IOMMU_PLATFORM)
66
67
};

68
69
70
71
72
73
/* Virtqueue indices into vhost_net.vqs[] and vhost_net.poll[]. */
enum {
	VHOST_NET_VQ_RX,	/* 0: served by handle_rx() */
	VHOST_NET_VQ_TX,	/* 1: served by handle_tx() */
	VHOST_NET_VQ_MAX	/* number of virtqueues */
};

74
struct vhost_net_ubuf_ref {
75
76
77
78
79
80
	/* refcount follows semantics similar to kref:
	 *  0: object is released
	 *  1: no outstanding ubufs
	 * >1: outstanding ubufs
	 */
	atomic_t refcount;
81
82
83
84
	wait_queue_head_t wait;
	struct vhost_virtqueue *vq;
};

85
86
struct vhost_net_virtqueue {
	struct vhost_virtqueue vq;
87
88
	size_t vhost_hlen;
	size_t sock_hlen;
89
90
91
92
93
94
95
96
97
	/* vhost zerocopy support fields below: */
	/* last used idx for outstanding DMA zerocopy buffers */
	int upend_idx;
	/* first used idx for DMA done zerocopy buffers */
	int done_idx;
	/* an array of userspace buffers info */
	struct ubuf_info *ubuf_info;
	/* Reference counting for outstanding ubufs.
	 * Protected by vq mutex. Writers must also take device mutex. */
98
	struct vhost_net_ubuf_ref *ubufs;
99
100
};

101
102
struct vhost_net {
	struct vhost_dev dev;
103
	struct vhost_net_virtqueue vqs[VHOST_NET_VQ_MAX];
104
	struct vhost_poll poll[VHOST_NET_VQ_MAX];
105
106
107
108
109
110
	/* Number of TX recently submitted.
	 * Protected by tx vq lock. */
	unsigned tx_packets;
	/* Number of times zerocopy TX recently failed.
	 * Protected by tx vq lock. */
	unsigned tx_zcopy_err;
111
112
	/* Flush in progress. Protected by tx vq lock. */
	bool tx_flush;
113
114
};

115
static unsigned vhost_net_zcopy_mask __read_mostly;
116

117
static void vhost_net_enable_zcopy(int vq)
118
{
119
	vhost_net_zcopy_mask |= 0x1 << vq;
120
121
}

122
123
/*
 * Allocate the outstanding-ubuf tracker for @vq.
 *
 * Returns NULL when @zcopy is false (no zerocopy backend, nothing to count),
 * ERR_PTR(-ENOMEM) on allocation failure, otherwise a tracker whose
 * refcount starts at 1 (no outstanding ubufs).
 */
static struct vhost_net_ubuf_ref *
vhost_net_ubuf_alloc(struct vhost_virtqueue *vq, bool zcopy)
{
	struct vhost_net_ubuf_ref *ref = NULL;

	if (zcopy) {
		ref = kmalloc(sizeof(*ref), GFP_KERNEL);
		if (!ref)
			return ERR_PTR(-ENOMEM);
		ref->vq = vq;
		init_waitqueue_head(&ref->wait);
		atomic_set(&ref->refcount, 1);
	}
	return ref;
}

138
static int vhost_net_ubuf_put(struct vhost_net_ubuf_ref *ubufs)
139
{
140
141
142
143
	int r = atomic_sub_return(1, &ubufs->refcount);
	if (unlikely(!r))
		wake_up(&ubufs->wait);
	return r;
144
145
}

146
static void vhost_net_ubuf_put_and_wait(struct vhost_net_ubuf_ref *ubufs)
147
{
148
149
	vhost_net_ubuf_put(ubufs);
	wait_event(ubufs->wait, !atomic_read(&ubufs->refcount));
150
151
152
153
154
}

static void vhost_net_ubuf_put_wait_and_free(struct vhost_net_ubuf_ref *ubufs)
{
	vhost_net_ubuf_put_and_wait(ubufs);
155
156
157
	kfree(ubufs);
}

158
159
160
161
static void vhost_net_clear_ubuf_info(struct vhost_net *n)
{
	int i;

162
163
164
	for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
		kfree(n->vqs[i].ubuf_info);
		n->vqs[i].ubuf_info = NULL;
165
166
167
	}
}

Asias He's avatar
Asias He committed
168
static int vhost_net_set_ubuf_info(struct vhost_net *n)
169
170
171
172
{
	bool zcopy;
	int i;

173
	for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
174
		zcopy = vhost_net_zcopy_mask & (0x1 << i);
175
176
177
178
179
180
181
182
183
184
		if (!zcopy)
			continue;
		n->vqs[i].ubuf_info = kmalloc(sizeof(*n->vqs[i].ubuf_info) *
					      UIO_MAXIOV, GFP_KERNEL);
		if  (!n->vqs[i].ubuf_info)
			goto err;
	}
	return 0;

err:
185
	vhost_net_clear_ubuf_info(n);
186
187
188
	return -ENOMEM;
}

Asias He's avatar
Asias He committed
189
static void vhost_net_vq_reset(struct vhost_net *n)
190
191
192
{
	int i;

193
194
	vhost_net_clear_ubuf_info(n);

195
196
197
198
	for (i = 0; i < VHOST_NET_VQ_MAX; i++) {
		n->vqs[i].done_idx = 0;
		n->vqs[i].upend_idx = 0;
		n->vqs[i].ubufs = NULL;
199
200
		n->vqs[i].vhost_hlen = 0;
		n->vqs[i].sock_hlen = 0;
201
202
203
204
	}

}

205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
/*
 * Count one submitted TX packet. Every 1024 packets both the packet counter
 * and the zerocopy error counter start over from zero.
 */
static void vhost_net_tx_packet(struct vhost_net *net)
{
	if (++net->tx_packets >= 1024) {
		net->tx_packets = 0;
		net->tx_zcopy_err = 0;
	}
}

/* Record one failed zerocopy transmission. */
static void vhost_net_tx_err(struct vhost_net *net)
{
	net->tx_zcopy_err++;
}

static bool vhost_net_tx_select_zcopy(struct vhost_net *net)
{
221
222
223
224
225
	/* TX flush waits for outstanding DMAs to be done.
	 * Don't start new DMAs.
	 */
	return !net->tx_flush &&
		net->tx_packets / 64 >= net->tx_zcopy_err;
226
227
}

228
229
230
231
232
233
static bool vhost_sock_zcopy(struct socket *sock)
{
	return unlikely(experimental_zcopytx) &&
		sock_flag(sock->sk, SOCK_ZEROCOPY);
}

234
235
236
237
238
/* In case of DMA done not in order in lower device driver for some reason.
 * upend_idx is used to track end of used idx, done_idx is used to track head
 * of used idx. Once lower device DMA done contiguously, we will signal KVM
 * guest used idx.
 */
239
240
static void vhost_zerocopy_signal_used(struct vhost_net *net,
				       struct vhost_virtqueue *vq)
241
{
242
243
	struct vhost_net_virtqueue *nvq =
		container_of(vq, struct vhost_net_virtqueue, vq);
244
	int i, add;
245
246
	int j = 0;

247
	for (i = nvq->done_idx; i != nvq->upend_idx; i = (i + 1) % UIO_MAXIOV) {
248
249
		if (vq->heads[i].len == VHOST_DMA_FAILED_LEN)
			vhost_net_tx_err(net);
250
251
252
253
254
255
		if (VHOST_DMA_IS_DONE(vq->heads[i].len)) {
			vq->heads[i].len = VHOST_DMA_CLEAR_LEN;
			++j;
		} else
			break;
	}
256
257
258
259
260
261
262
	while (j) {
		add = min(UIO_MAXIOV - nvq->done_idx, j);
		vhost_add_used_and_signal_n(vq->dev, vq,
					    &vq->heads[nvq->done_idx], add);
		nvq->done_idx = (nvq->done_idx + add) % UIO_MAXIOV;
		j -= add;
	}
263
264
}

265
/*
 * Completion callback installed on each zerocopy ubuf_info by handle_tx().
 * Records the DMA outcome in the buffer's heads[] slot, drops the buffer's
 * reference on the tracker and, when appropriate, schedules the TX vq poll
 * so that vhost_zerocopy_signal_used() runs.
 */
static void vhost_zerocopy_callback(struct ubuf_info *ubuf, bool success)
{
	struct vhost_net_ubuf_ref *ref = ubuf->ctx;
	struct vhost_virtqueue *vq = ref->vq;
	int left;
	bool kick;

	rcu_read_lock_bh();

	/* Mark this descriptor's buffers as having finished DMA. */
	if (success)
		vq->heads[ubuf->desc].len = VHOST_DMA_DONE_LEN;
	else
		vq->heads[ubuf->desc].len = VHOST_DMA_FAILED_LEN;

	left = vhost_net_ubuf_put(ref);

	/*
	 * Trigger the polling thread if the guest stopped submitting new
	 * buffers: then the count after decrement eventually reaches 1.
	 * Also trigger periodically, every 16 completions (an arbitrary value
	 * tuned to fire less than 10% of the time).
	 */
	kick = left <= 1 || left % 16 == 0;
	if (kick)
		vhost_poll_queue(&vq->poll);

	rcu_read_unlock_bh();
}

Jason Wang's avatar
Jason Wang committed
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
static inline unsigned long busy_clock(void)
{
	return local_clock() >> 10;
}

static bool vhost_can_busy_poll(struct vhost_dev *dev,
				unsigned long endtime)
{
	return likely(!need_resched()) &&
	       likely(!time_after(busy_clock(), endtime)) &&
	       likely(!signal_pending(current)) &&
	       !vhost_has_work(dev);
}

static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
				    struct vhost_virtqueue *vq,
				    struct iovec iov[], unsigned int iov_size,
				    unsigned int *out_num, unsigned int *in_num)
{
	unsigned long uninitialized_var(endtime);
	int r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
Jason Wang's avatar
Jason Wang committed
312
				  out_num, in_num, NULL, NULL);
Jason Wang's avatar
Jason Wang committed
313
314
315
316
317
318
319
320
321

	if (r == vq->num && vq->busyloop_timeout) {
		preempt_disable();
		endtime = busy_clock() + vq->busyloop_timeout;
		while (vhost_can_busy_poll(vq->dev, endtime) &&
		       vhost_vq_avail_empty(vq->dev, vq))
			cpu_relax_lowlatency();
		preempt_enable();
		r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
Jason Wang's avatar
Jason Wang committed
322
				      out_num, in_num, NULL, NULL);
Jason Wang's avatar
Jason Wang committed
323
324
325
326
327
	}

	return r;
}

328
329
330
331
/* Expects to be always run from workqueue - which acts as
 * read-size critical section for our kind of RCU. */
static void handle_tx(struct vhost_net *net)
{
332
	struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
333
	struct vhost_virtqueue *vq = &nvq->vq;
334
	unsigned out, in;
335
	int head;
336
337
338
339
340
341
342
343
	struct msghdr msg = {
		.msg_name = NULL,
		.msg_namelen = 0,
		.msg_control = NULL,
		.msg_controllen = 0,
		.msg_flags = MSG_DONTWAIT,
	};
	size_t len, total_len = 0;
344
	int err;
345
	size_t hdr_size;
Arnd Bergmann's avatar
Arnd Bergmann committed
346
	struct socket *sock;
347
	struct vhost_net_ubuf_ref *uninitialized_var(ubufs);
348
	bool zcopy, zcopy_used;
Arnd Bergmann's avatar
Arnd Bergmann committed
349

350
351
	mutex_lock(&vq->mutex);
	sock = vq->private_data;
352
	if (!sock)
353
		goto out;
354

Jason Wang's avatar
Jason Wang committed
355
356
357
	if (!vq_iotlb_prefetch(vq))
		goto out;

358
	vhost_disable_notify(&net->dev, vq);
359

360
	hdr_size = nvq->vhost_hlen;
361
	zcopy = nvq->ubufs;
362
363

	for (;;) {
364
365
		/* Release DMAs done buffers first */
		if (zcopy)
366
			vhost_zerocopy_signal_used(net, vq);
367

368
369
370
371
372
373
374
		/* If more outstanding DMAs, queue the work.
		 * Handle upend_idx wrap around
		 */
		if (unlikely((nvq->upend_idx + vq->num - VHOST_MAX_PEND)
			      % UIO_MAXIOV == nvq->done_idx))
			break;

Jason Wang's avatar
Jason Wang committed
375
376
377
		head = vhost_net_tx_get_vq_desc(net, vq, vq->iov,
						ARRAY_SIZE(vq->iov),
						&out, &in);
378
		/* On error, stop handling until the next kick. */
379
		if (unlikely(head < 0))
380
			break;
381
382
		/* Nothing new?  Wait for eventfd to tell us they refilled. */
		if (head == vq->num) {
383
384
			if (unlikely(vhost_enable_notify(&net->dev, vq))) {
				vhost_disable_notify(&net->dev, vq);
385
386
387
388
389
390
391
392
393
394
395
				continue;
			}
			break;
		}
		if (in) {
			vq_err(vq, "Unexpected descriptor format for TX: "
			       "out %d, int %d\n", out, in);
			break;
		}
		/* Skip header. TODO: support TSO. */
		len = iov_length(vq->iov, out);
Al Viro's avatar
Al Viro committed
396
		iov_iter_init(&msg.msg_iter, WRITE, vq->iov, out, len);
397
		iov_iter_advance(&msg.msg_iter, hdr_size);
398
		/* Sanity check */
Al Viro's avatar
Al Viro committed
399
		if (!msg_data_left(&msg)) {
400
401
			vq_err(vq, "Unexpected header len for TX: "
			       "%zd expected %zd\n",
402
			       len, hdr_size);
403
404
			break;
		}
Al Viro's avatar
Al Viro committed
405
		len = msg_data_left(&msg);
406
407
408
409
410

		zcopy_used = zcopy && len >= VHOST_GOODCOPY_LEN
				   && (nvq->upend_idx + 1) % UIO_MAXIOV !=
				      nvq->done_idx
				   && vhost_net_tx_select_zcopy(net);
411

412
		/* use msg_control to pass vhost zerocopy ubuf info to skb */
413
		if (zcopy_used) {
414
415
416
			struct ubuf_info *ubuf;
			ubuf = nvq->ubuf_info + nvq->upend_idx;

417
			vq->heads[nvq->upend_idx].id = cpu_to_vhost32(vq, head);
418
419
420
421
422
423
424
			vq->heads[nvq->upend_idx].len = VHOST_DMA_IN_PROGRESS;
			ubuf->callback = vhost_zerocopy_callback;
			ubuf->ctx = nvq->ubufs;
			ubuf->desc = nvq->upend_idx;
			msg.msg_control = ubuf;
			msg.msg_controllen = sizeof(ubuf);
			ubufs = nvq->ubufs;
425
			atomic_inc(&ubufs->refcount);
426
			nvq->upend_idx = (nvq->upend_idx + 1) % UIO_MAXIOV;
427
		} else {
428
			msg.msg_control = NULL;
429
430
			ubufs = NULL;
		}
431
		/* TODO: Check specific error and bomb out unless ENOBUFS? */
432
		err = sock->ops->sendmsg(sock, &msg, len);
433
		if (unlikely(err < 0)) {
434
			if (zcopy_used) {
435
				vhost_net_ubuf_put(ubufs);
436
437
				nvq->upend_idx = ((unsigned)nvq->upend_idx - 1)
					% UIO_MAXIOV;
438
			}
439
			vhost_discard_vq_desc(vq, 1);
440
441
442
			break;
		}
		if (err != len)
443
444
			pr_debug("Truncated TX packet: "
				 " len %d != %zd\n", err, len);
445
		if (!zcopy_used)
446
			vhost_add_used_and_signal(&net->dev, vq, head, 0);
447
		else
448
			vhost_zerocopy_signal_used(net, vq);
449
		total_len += len;
450
		vhost_net_tx_packet(net);
451
452
453
454
455
		if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
			vhost_poll_queue(&vq->poll);
			break;
		}
	}
456
out:
457
458
459
	mutex_unlock(&vq->mutex);
}

460
461
462
463
static int peek_head_len(struct sock *sk)
{
	struct sk_buff *head;
	int len = 0;
464
	unsigned long flags;
465

466
	spin_lock_irqsave(&sk->sk_receive_queue.lock, flags);
467
	head = skb_peek(&sk->sk_receive_queue);
468
	if (likely(head)) {
469
		len = head->len;
470
		if (skb_vlan_tag_present(head))
471
472
473
			len += VLAN_HLEN;
	}

474
	spin_unlock_irqrestore(&sk->sk_receive_queue.lock, flags);
475
476
477
	return len;
}

Jason Wang's avatar
Jason Wang committed
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
/*
 * Return the length of the next packet on @sk's receive queue, or 0.
 *
 * When the queue is empty and busy polling is configured on the TX
 * virtqueue, spin (bounded by vhost_can_busy_poll()) until a packet arrives
 * or the guest posts a TX buffer, re-arm TX notifications, and peek again.
 *
 * Locking: handle_rx() calls this with the RX vq mutex held, and here the
 * TX vq mutex is taken as well. Both are vhost_virtqueue.mutex (presumably
 * one lockdep class), so the inner acquisition is annotated with subclass 1
 * to tell lockdep this RX -> TX nesting is intentional.
 */
static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk)
{
	struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
	struct vhost_virtqueue *vq = &nvq->vq;
	unsigned long uninitialized_var(endtime);
	int len = peek_head_len(sk);

	if (!len && vq->busyloop_timeout) {
		/* Both tx vq and rx socket were polled here */
		mutex_lock_nested(&vq->mutex, 1);
		vhost_disable_notify(&net->dev, vq);

		preempt_disable();
		endtime = busy_clock() + vq->busyloop_timeout;

		while (vhost_can_busy_poll(&net->dev, endtime) &&
		       skb_queue_empty(&sk->sk_receive_queue) &&
		       vhost_vq_avail_empty(&net->dev, vq))
			cpu_relax_lowlatency();

		preempt_enable();

		/*
		 * Re-arm TX notifications. As in handle_tx(), a true return
		 * is treated as buffers having slipped in, so schedule TX.
		 */
		if (vhost_enable_notify(&net->dev, vq))
			vhost_poll_queue(&vq->poll);
		mutex_unlock(&vq->mutex);

		len = peek_head_len(sk);
	}

	return len;
}

510
511
512
513
514
515
516
/* This is a multi-buffer version of vhost_get_desc, that works if
 *	vq has read descriptors only.
 * @vq		- the relevant virtqueue
 * @datalen	- data length we'll be reading
 * @iovcount	- returned count of io vectors we fill
 * @log		- vhost log
 * @log_num	- log offset
517
 * @quota       - headcount quota, 1 for big buffer
518
519
520
521
522
523
524
 *	returns number of buffer heads allocated, negative on error
 */
static int get_rx_bufs(struct vhost_virtqueue *vq,
		       struct vring_used_elem *heads,
		       int datalen,
		       unsigned *iovcount,
		       struct vhost_log *log,
525
526
		       unsigned *log_num,
		       unsigned int quota)
527
528
529
530
531
532
{
	unsigned int out, in;
	int seg = 0;
	int headcount = 0;
	unsigned d;
	int r, nlogs = 0;
533
534
535
536
	/* len is always initialized before use since we are always called with
	 * datalen > 0.
	 */
	u32 uninitialized_var(len);
537

538
	while (datalen > 0 && headcount < quota) {
Jason Wang's avatar
Jason Wang committed
539
		if (unlikely(seg >= UIO_MAXIOV)) {
540
541
542
			r = -ENOBUFS;
			goto err;
		}
543
		r = vhost_get_vq_desc(vq, vq->iov + seg,
544
545
				      ARRAY_SIZE(vq->iov) - seg, &out,
				      &in, log, log_num);
546
547
548
549
		if (unlikely(r < 0))
			goto err;

		d = r;
550
551
552
553
554
555
556
557
558
559
560
561
562
563
		if (d == vq->num) {
			r = 0;
			goto err;
		}
		if (unlikely(out || in <= 0)) {
			vq_err(vq, "unexpected descriptor format for RX: "
				"out %d, in %d\n", out, in);
			r = -EINVAL;
			goto err;
		}
		if (unlikely(log)) {
			nlogs += *log_num;
			log += *log_num;
		}
564
565
566
567
		heads[headcount].id = cpu_to_vhost32(vq, d);
		len = iov_length(vq->iov + seg, in);
		heads[headcount].len = cpu_to_vhost32(vq, len);
		datalen -= len;
568
569
570
		++headcount;
		seg += in;
	}
571
	heads[headcount - 1].len = cpu_to_vhost32(vq, len + datalen);
572
573
574
	*iovcount = seg;
	if (unlikely(log))
		*log_num = nlogs;
575
576
577
578
579
580

	/* Detect overrun */
	if (unlikely(datalen > 0)) {
		r = UIO_MAXIOV + 1;
		goto err;
	}
581
582
583
584
585
586
	return headcount;
err:
	vhost_discard_vq_desc(vq, headcount);
	return r;
}

587
588
/* Expects to be always run from workqueue - which acts as
 * read-size critical section for our kind of RCU. */
589
static void handle_rx(struct vhost_net *net)
590
{
591
592
	struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_RX];
	struct vhost_virtqueue *vq = &nvq->vq;
593
594
595
596
597
598
599
600
601
	unsigned uninitialized_var(in), log;
	struct vhost_log *vq_log;
	struct msghdr msg = {
		.msg_name = NULL,
		.msg_namelen = 0,
		.msg_control = NULL, /* FIXME: get and handle RX aux data. */
		.msg_controllen = 0,
		.msg_flags = MSG_DONTWAIT,
	};
602
603
604
	struct virtio_net_hdr hdr = {
		.flags = 0,
		.gso_type = VIRTIO_NET_HDR_GSO_NONE
605
606
	};
	size_t total_len = 0;
607
608
	int err, mergeable;
	s16 headcount;
609
610
	size_t vhost_hlen, sock_hlen;
	size_t vhost_len, sock_len;
611
	struct socket *sock;
612
	struct iov_iter fixup;
613
	__virtio16 num_buffers;
614
615

	mutex_lock(&vq->mutex);
616
617
618
	sock = vq->private_data;
	if (!sock)
		goto out;
Jason Wang's avatar
Jason Wang committed
619
620
621
622

	if (!vq_iotlb_prefetch(vq))
		goto out;

623
	vhost_disable_notify(&net->dev, vq);
624

625
626
	vhost_hlen = nvq->vhost_hlen;
	sock_hlen = nvq->sock_hlen;
627

628
	vq_log = unlikely(vhost_has_feature(vq, VHOST_F_LOG_ALL)) ?
629
		vq->log : NULL;
630
	mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF);
631

Jason Wang's avatar
Jason Wang committed
632
	while ((sock_len = vhost_net_rx_peek_head_len(net, sock->sk))) {
633
634
635
		sock_len += sock_hlen;
		vhost_len = sock_len + vhost_hlen;
		headcount = get_rx_bufs(vq, vq->heads, vhost_len,
636
637
					&in, vq_log, &log,
					likely(mergeable) ? UIO_MAXIOV : 1);
638
639
640
		/* On error, stop handling until the next kick. */
		if (unlikely(headcount < 0))
			break;
641
642
		/* On overrun, truncate and discard */
		if (unlikely(headcount > UIO_MAXIOV)) {
Al Viro's avatar
Al Viro committed
643
			iov_iter_init(&msg.msg_iter, READ, vq->iov, 1, 1);
644
			err = sock->ops->recvmsg(sock, &msg,
645
646
647
648
						 1, MSG_DONTWAIT | MSG_TRUNC);
			pr_debug("Discarded rx packet: len %zd\n", sock_len);
			continue;
		}
649
650
		/* OK, now we need to know about added descriptors. */
		if (!headcount) {
651
			if (unlikely(vhost_enable_notify(&net->dev, vq))) {
652
653
				/* They have slipped one in as we were
				 * doing that: check again. */
654
				vhost_disable_notify(&net->dev, vq);
655
656
657
658
659
660
661
				continue;
			}
			/* Nothing new?  Wait for eventfd to tell us
			 * they refilled. */
			break;
		}
		/* We don't need to be notified again. */
662
663
664
665
666
667
668
669
		iov_iter_init(&msg.msg_iter, READ, vq->iov, in, vhost_len);
		fixup = msg.msg_iter;
		if (unlikely((vhost_hlen))) {
			/* We will supply the header ourselves
			 * TODO: support TSO.
			 */
			iov_iter_advance(&msg.msg_iter, vhost_hlen);
		}
670
		err = sock->ops->recvmsg(sock, &msg,
671
672
673
674
675
676
677
678
679
680
					 sock_len, MSG_DONTWAIT | MSG_TRUNC);
		/* Userspace might have consumed the packet meanwhile:
		 * it's not supposed to do this usually, but might be hard
		 * to prevent. Discard data we got (if any) and keep going. */
		if (unlikely(err != sock_len)) {
			pr_debug("Discarded rx packet: "
				 " len %d, expected %zd\n", err, sock_len);
			vhost_discard_vq_desc(vq, headcount);
			continue;
		}
681
		/* Supply virtio_net_hdr if VHOST_NET_F_VIRTIO_NET_HDR */
682
683
684
685
686
687
688
689
690
691
692
693
		if (unlikely(vhost_hlen)) {
			if (copy_to_iter(&hdr, sizeof(hdr),
					 &fixup) != sizeof(hdr)) {
				vq_err(vq, "Unable to write vnet_hdr "
				       "at addr %p\n", vq->iov->iov_base);
				break;
			}
		} else {
			/* Header came from socket; we'll need to patch
			 * ->num_buffers over if VIRTIO_NET_F_MRG_RXBUF
			 */
			iov_iter_advance(&fixup, sizeof(hdr));
694
695
		}
		/* TODO: Should check and handle checksum. */
696

697
		num_buffers = cpu_to_vhost16(vq, headcount);
698
		if (likely(mergeable) &&
699
700
		    copy_to_iter(&num_buffers, sizeof num_buffers,
				 &fixup) != sizeof num_buffers) {
701
702
703
704
705
706
707
708
709
710
711
712
713
714
			vq_err(vq, "Failed num_buffers write");
			vhost_discard_vq_desc(vq, headcount);
			break;
		}
		vhost_add_used_and_signal_n(&net->dev, vq, vq->heads,
					    headcount);
		if (unlikely(vq_log))
			vhost_log_write(vq, vq_log, log, vhost_len);
		total_len += vhost_len;
		if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
			vhost_poll_queue(&vq->poll);
			break;
		}
	}
715
out:
716
717
718
	mutex_unlock(&vq->mutex);
}

719
static void handle_tx_kick(struct vhost_work *work)
720
{
721
722
723
724
	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
						  poll.work);
	struct vhost_net *net = container_of(vq->dev, struct vhost_net, dev);

725
726
727
	handle_tx(net);
}

728
static void handle_rx_kick(struct vhost_work *work)
729
{
730
731
732
733
	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
						  poll.work);
	struct vhost_net *net = container_of(vq->dev, struct vhost_net, dev);

734
735
736
	handle_rx(net);
}

737
static void handle_tx_net(struct vhost_work *work)
738
{
739
740
	struct vhost_net *net = container_of(work, struct vhost_net,
					     poll[VHOST_NET_VQ_TX].work);
741
742
743
	handle_tx(net);
}

744
static void handle_rx_net(struct vhost_work *work)
745
{
746
747
	struct vhost_net *net = container_of(work, struct vhost_net,
					     poll[VHOST_NET_VQ_RX].work);
748
749
750
751
752
	handle_rx(net);
}

static int vhost_net_open(struct inode *inode, struct file *f)
{
753
	struct vhost_net *n;
754
	struct vhost_dev *dev;
755
	struct vhost_virtqueue **vqs;
Zhi Yong Wu's avatar
Zhi Yong Wu committed
756
	int i;
757

758
759
760
761
762
763
	n = kmalloc(sizeof *n, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
	if (!n) {
		n = vmalloc(sizeof *n);
		if (!n)
			return -ENOMEM;
	}
764
765
	vqs = kmalloc(VHOST_NET_VQ_MAX * sizeof(*vqs), GFP_KERNEL);
	if (!vqs) {
766
		kvfree(n);
767
768
		return -ENOMEM;
	}
769
770

	dev = &n->dev;
771
772
773
774
	vqs[VHOST_NET_VQ_TX] = &n->vqs[VHOST_NET_VQ_TX].vq;
	vqs[VHOST_NET_VQ_RX] = &n->vqs[VHOST_NET_VQ_RX].vq;
	n->vqs[VHOST_NET_VQ_TX].vq.handle_kick = handle_tx_kick;
	n->vqs[VHOST_NET_VQ_RX].vq.handle_kick = handle_rx_kick;
775
776
777
778
779
	for (i = 0; i < VHOST_NET_VQ_MAX; i++) {
		n->vqs[i].ubufs = NULL;
		n->vqs[i].ubuf_info = NULL;
		n->vqs[i].upend_idx = 0;
		n->vqs[i].done_idx = 0;
780
781
		n->vqs[i].vhost_hlen = 0;
		n->vqs[i].sock_hlen = 0;
782
	}
Zhi Yong Wu's avatar
Zhi Yong Wu committed
783
	vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX);
784

785
786
	vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, POLLOUT, dev);
	vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, POLLIN, dev);
787
788
789
790
791
792
793
794
795

	f->private_data = n;

	return 0;
}

static void vhost_net_disable_vq(struct vhost_net *n,
				 struct vhost_virtqueue *vq)
{
796
797
798
	struct vhost_net_virtqueue *nvq =
		container_of(vq, struct vhost_net_virtqueue, vq);
	struct vhost_poll *poll = n->poll + (nvq - n->vqs);
799
800
	if (!vq->private_data)
		return;
801
	vhost_poll_stop(poll);
802
803
}

804
static int vhost_net_enable_vq(struct vhost_net *n,
805
806
				struct vhost_virtqueue *vq)
{
807
808
809
	struct vhost_net_virtqueue *nvq =
		container_of(vq, struct vhost_net_virtqueue, vq);
	struct vhost_poll *poll = n->poll + (nvq - n->vqs);
Arnd Bergmann's avatar
Arnd Bergmann committed
810
811
	struct socket *sock;

812
	sock = vq->private_data;
813
	if (!sock)
814
815
		return 0;

816
	return vhost_poll_start(poll, sock->file);
817
818
819
820
821
822
823
824
}

static struct socket *vhost_net_stop_vq(struct vhost_net *n,
					struct vhost_virtqueue *vq)
{
	struct socket *sock;

	mutex_lock(&vq->mutex);
825
	sock = vq->private_data;
826
	vhost_net_disable_vq(n, vq);
827
	vq->private_data = NULL;
828
829
830
831
832
833
834
	mutex_unlock(&vq->mutex);
	return sock;
}

static void vhost_net_stop(struct vhost_net *n, struct socket **tx_sock,
			   struct socket **rx_sock)
{
835
836
	*tx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_TX].vq);
	*rx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_RX].vq);
837
838
839
840
841
}

static void vhost_net_flush_vq(struct vhost_net *n, int index)
{
	vhost_poll_flush(n->poll + index);
842
	vhost_poll_flush(&n->vqs[index].vq.poll);
843
844
845
846
847
848
}

static void vhost_net_flush(struct vhost_net *n)
{
	vhost_net_flush_vq(n, VHOST_NET_VQ_TX);
	vhost_net_flush_vq(n, VHOST_NET_VQ_RX);
849
	if (n->vqs[VHOST_NET_VQ_TX].ubufs) {
850
		mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
851
		n->tx_flush = true;
852
		mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
853
		/* Wait for all lower device DMAs done. */
854
		vhost_net_ubuf_put_and_wait(n->vqs[VHOST_NET_VQ_TX].ubufs);
855
		mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
856
		n->tx_flush = false;
857
		atomic_set(&n->vqs[VHOST_NET_VQ_TX].ubufs->refcount, 1);
858
		mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
859
	}
860
861
862
863
864
865
866
867
868
869
}

static int vhost_net_release(struct inode *inode, struct file *f)
{
	struct vhost_net *n = f->private_data;
	struct socket *tx_sock;
	struct socket *rx_sock;

	vhost_net_stop(n, &tx_sock, &rx_sock);
	vhost_net_flush(n);
870
	vhost_dev_stop(&n->dev);
871
	vhost_dev_cleanup(&n->dev, false);
872
	vhost_net_vq_reset(n);
873
	if (tx_sock)
Al Viro's avatar
Al Viro committed
874
		sockfd_put(tx_sock);
875
	if (rx_sock)
Al Viro's avatar
Al Viro committed
876
		sockfd_put(rx_sock);
877
878
	/* Make sure no callbacks are outstanding */
	synchronize_rcu_bh();
879
880
881
	/* We do an extra flush before freeing memory,
	 * since jobs can re-queue themselves. */
	vhost_net_flush(n);
882
	kfree(n->dev.vqs);
883
	kvfree(n);
884
885
886
887
888
889
890
891
892
893
894
	return 0;
}

static struct socket *get_raw_socket(int fd)
{
	struct {
		struct sockaddr_ll sa;
		char  buf[MAX_ADDR_LEN];
	} uaddr;
	int uaddr_len = sizeof uaddr, r;
	struct socket *sock = sockfd_lookup(fd, &r);
895

896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
	if (!sock)
		return ERR_PTR(-ENOTSOCK);

	/* Parameter checking */
	if (sock->sk->sk_type != SOCK_RAW) {
		r = -ESOCKTNOSUPPORT;
		goto err;
	}

	r = sock->ops->getname(sock, (struct sockaddr *)&uaddr.sa,
			       &uaddr_len, 0);
	if (r)
		goto err;

	if (uaddr.sa.sll_family != AF_PACKET) {
		r = -EPFNOSUPPORT;
		goto err;
	}
	return sock;
err:
Al Viro's avatar
Al Viro committed
916
	sockfd_put(sock);
917
918
919
	return ERR_PTR(r);
}

Arnd Bergmann's avatar
Arnd Bergmann committed
920
static struct socket *get_tap_socket(int fd)
921
922
923
{
	struct file *file = fget(fd);
	struct socket *sock;
924

925
926
927
	if (!file)
		return ERR_PTR(-EBADF);
	sock = tun_get_socket(file);
Arnd Bergmann's avatar
Arnd Bergmann committed
928
929
930
	if (!IS_ERR(sock))
		return sock;
	sock = macvtap_get_socket(file);
931
932
933
934
935
936
937
938
	if (IS_ERR(sock))
		fput(file);
	return sock;
}

static struct socket *get_socket(int fd)
{
	struct socket *sock;
939

940
941
942
943
944
945
	/* special case to disable backend */
	if (fd == -1)
		return NULL;
	sock = get_raw_socket(fd);
	if (!IS_ERR(sock))
		return sock;
Arnd Bergmann's avatar
Arnd Bergmann committed
946
	sock = get_tap_socket(fd);
947
948
949
950
951
952
953
954
955
	if (!IS_ERR(sock))
		return sock;
	return ERR_PTR(-ENOTSOCK);
}

static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
{
	struct socket *sock, *oldsock;
	struct vhost_virtqueue *vq;
956
	struct vhost_net_virtqueue *nvq;
957
	struct vhost_net_ubuf_ref *ubufs, *oldubufs = NULL;
958
959
960
961
962
963
964
965
966
967
968
	int r;

	mutex_lock(&n->dev.mutex);
	r = vhost_dev_check_owner(&n->dev);
	if (r)
		goto err;

	if (index >= VHOST_NET_VQ_MAX) {
		r = -ENOBUFS;
		goto err;
	}
969
	vq = &n->vqs[index].vq;
970
	nvq = &n->vqs[index];
971
972
973
974
975
	mutex_lock(&vq->mutex);

	/* Verify that ring has been setup correctly. */
	if (!vhost_vq_access_ok(vq)) {
		r = -EFAULT;
976
		goto err_vq;
977
978
979
980
	}
	sock = get_socket(fd);
	if (IS_ERR(sock)) {
		r = PTR_ERR(sock);
981
		goto err_vq;
982
983
984
	}

	/* start polling new socket */
985
	oldsock = vq->private_data;
986
	if (sock != oldsock) {
987
988
		ubufs = vhost_net_ubuf_alloc(vq,
					     sock && vhost_sock_zcopy(sock));
989
990
991
992
		if (IS_ERR(ubufs)) {
			r = PTR_ERR(ubufs);
			goto err_ubufs;
		}
993

994
		vhost_net_disable_vq(n, vq);
995
		vq->private_data = sock;
Greg Kurz's avatar
Greg Kurz committed
996
		r = vhost_vq_init_access(vq);
997
		if (r)
998
			goto err_used;
999
1000
1001
		r = vhost_net_enable_vq(n, vq);
		if (r)
			goto err_used;
1002

1003
1004
		oldubufs = nvq->ubufs;
		nvq->ubufs = ubufs;
1005
1006
1007

		n->tx_packets = 0;
		n->tx_zcopy_err = 0;
1008
		n->tx_flush = false;
Jeff Dike's avatar
Jeff Dike committed
1009
	}
1010

1011
1012
	mutex_unlock(&vq->mutex);

1013
	if (oldubufs) {
1014
		vhost_net_ubuf_put_wait_and_free(oldubufs);
1015
		mutex_lock(&vq->mutex);
1016
		vhost_zerocopy_signal_used(n, vq);
1017
1018
		mutex_unlock(&vq->mutex);
	}
1019

1020
1021
	if (oldsock) {
		vhost_net_flush_vq(n, index);
Al Viro's avatar
Al Viro committed
1022
		sockfd_put(oldsock);
1023
	}
1024

1025
1026
1027
	mutex_unlock(&n->dev.mutex);
	return 0;

1028
err_used:
1029
	vq->private_data = oldsock;
1030
1031
	vhost_net_enable_vq(n, vq);
	if (ubufs)
1032
		vhost_net_ubuf_put_wait_and_free(ubufs);
1033
err_ubufs:
Al Viro's avatar
Al Viro committed
1034
	sockfd_put(sock);
1035
1036
err_vq:
	mutex_unlock(&vq->mutex);
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
err:
	mutex_unlock(&n->dev.mutex);
	return r;
}

static long vhost_net_reset_owner(struct vhost_net *n)
{
	struct socket *tx_sock = NULL;
	struct socket *rx_sock = NULL;
	long err;
1047
	struct vhost_umem *umem;
1048

1049
1050
1051
1052
	mutex_lock(&n->dev.mutex);
	err = vhost_dev_check_owner(&n->dev);
	if (err)
		goto done;
1053
1054
	umem = vhost_dev_reset_owner_prepare();
	if (!umem) {
1055
1056
1057
		err = -ENOMEM;
		goto done;
	}
1058
1059
	vhost_net_stop(n, &tx_sock, &rx_sock);
	vhost_net_flush(n);
1060
	vhost_dev_reset_owner(&n->dev, umem);