/* hv_utils_transport.c */
/*
 * Kernel/userspace transport abstraction for Hyper-V util driver.
 *
 * Copyright (C) 2015, Vitaly Kuznetsov <vkuznets@redhat.com>
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 as published
 * by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE, GOOD TITLE or
 * NON INFRINGEMENT.  See the GNU General Public License for more
 * details.
 *
 */

#include <linux/slab.h>
#include <linux/fs.h>
#include <linux/poll.h>

#include "hyperv_vmbus.h"
#include "hv_utils_transport.h"

/*
 * Every transport created by hvutil_transport_init() is linked here so the
 * shared netlink callback can find it by connector id; hvt_list_lock guards
 * all additions, removals and walks of the list.
 */
static DEFINE_SPINLOCK(hvt_list_lock);
static LIST_HEAD(hvt_list);

/*
 * Discard any pending outgoing message and, if the owning driver supplied
 * an on_reset() hook, notify it. Callers in this file hold hvt->lock.
 */
static void hvt_reset(struct hvutil_transport *hvt)
{
	void (*reset_cb)(void) = hvt->on_reset;

	kfree(hvt->outmsg);
	hvt->outmsg_len = 0;
	hvt->outmsg = NULL;

	if (reset_cb != NULL)
		reset_cb();
}

/*
 * read() handler for the character device.
 *
 * Sleeps until hvutil_transport_send() queues a message or the transport
 * leaves CHARDEV mode, then hands the whole pending message to userspace in
 * a single read.
 *
 * Return: number of bytes copied; -EINTR if the wait was interrupted;
 * -EBADF if the transport is being destroyed; -EAGAIN if no message is
 * pending; -EINVAL if @count cannot hold the whole message; -EFAULT if the
 * copy to @buf failed.
 */
static ssize_t hvt_op_read(struct file *file, char __user *buf,
			   size_t count, loff_t *ppos)
{
	struct hvutil_transport *hvt;
	int ret;

	hvt = container_of(file->f_op, struct hvutil_transport, fops);

	if (wait_event_interruptible(hvt->outmsg_q, hvt->outmsg_len > 0 ||
				     hvt->mode != HVUTIL_TRANSPORT_CHARDEV))
		return -EINTR;

	mutex_lock(&hvt->lock);

	if (hvt->mode == HVUTIL_TRANSPORT_DESTROY) {
		ret = -EBADF;
		goto out_unlock;
	}

	/* We may have been woken by a mode change rather than a message. */
	if (!hvt->outmsg) {
		ret = -EAGAIN;
		goto out_unlock;
	}

	/* Messages are never split across several reads. */
	if (count < hvt->outmsg_len) {
		ret = -EINVAL;
		goto out_unlock;
	}

	if (!copy_to_user(buf, hvt->outmsg, hvt->outmsg_len))
		ret = hvt->outmsg_len;
	else
		ret = -EFAULT;

	/*
	 * The message is consumed whether or not the copy succeeded.
	 * NOTE(review): on_read() below also runs on -EFAULT; confirm the
	 * drivers tolerate a message reported as read that userspace never saw.
	 */
	kfree(hvt->outmsg);
	hvt->outmsg = NULL;
	hvt->outmsg_len = 0;

	/* on_read() is a one-shot notification installed per message. */
	if (hvt->on_read)
		hvt->on_read();
	hvt->on_read = NULL;

out_unlock:
	mutex_unlock(&hvt->lock);
	return ret;
}

/*
 * write() handler for the character device.
 *
 * Copies the whole userspace buffer into kernel memory and passes it to the
 * owning driver's on_msg() callback. The copy is freed before returning, so
 * on_msg() must not keep a reference to it.
 *
 * Return: @count when on_msg() returns 0; otherwise the non-zero value
 * on_msg() returned, -EBADF if the transport is being destroyed, or the
 * error from memdup_user().
 */
static ssize_t hvt_op_write(struct file *file, const char __user *buf,
			    size_t count, loff_t *ppos)
{
	struct hvutil_transport *hvt;
	u8 *inmsg;
	int ret;

	hvt = container_of(file->f_op, struct hvutil_transport, fops);

	inmsg = memdup_user(buf, count);
	if (IS_ERR(inmsg))
		return PTR_ERR(inmsg);

	if (hvt->mode == HVUTIL_TRANSPORT_DESTROY)
		ret = -EBADF;
	else
		ret = hvt->on_msg(inmsg, count);

	kfree(inmsg);

	return ret ? ret : count;
}

/*
 * poll() handler for the character device.
 *
 * Reports POLLERR | POLLHUP once the transport is being destroyed,
 * POLLIN | POLLRDNORM while a message is waiting to be read, 0 otherwise.
 * Waiters are woken through hvt->outmsg_q.
 */
static unsigned int hvt_op_poll(struct file *file, poll_table *wait)
{
	struct hvutil_transport *hvt;

	hvt = container_of(file->f_op, struct hvutil_transport, fops);

	poll_wait(file, &hvt->outmsg_q, wait);

	if (hvt->mode == HVUTIL_TRANSPORT_DESTROY)
		return POLLERR | POLLHUP;

	if (hvt->outmsg_len > 0)
		return POLLIN | POLLRDNORM;

	return 0;
}

static int hvt_op_open(struct inode *inode, struct file *file)
{
	struct hvutil_transport *hvt;
127
128
	int ret = 0;
	bool issue_reset = false;
129
130
131

	hvt = container_of(file->f_op, struct hvutil_transport, fops);

132
133
134
135
136
137
138
139
140
	mutex_lock(&hvt->lock);

	if (hvt->mode == HVUTIL_TRANSPORT_DESTROY) {
		ret = -EBADF;
	} else if (hvt->mode == HVUTIL_TRANSPORT_INIT) {
		/*
		 * Switching to CHARDEV mode. We switch bach to INIT when
		 * device gets released.
		 */
141
		hvt->mode = HVUTIL_TRANSPORT_CHARDEV;
142
	}
143
144
145
146
147
	else if (hvt->mode == HVUTIL_TRANSPORT_NETLINK) {
		/*
		 * We're switching from netlink communication to using char
		 * device. Issue the reset first.
		 */
148
		issue_reset = true;
149
		hvt->mode = HVUTIL_TRANSPORT_CHARDEV;
150
151
152
	} else {
		ret = -EBUSY;
	}
153

154
155
156
157
158
159
	if (issue_reset)
		hvt_reset(hvt);

	mutex_unlock(&hvt->lock);

	return ret;
160
161
}

162
163
164
165
166
167
168
/*
 * Final teardown of a transport: unregister its misc device, then release
 * any message that was never read and the transport structure itself.
 */
static void hvt_transport_free(struct hvutil_transport *hvt)
{
	void *unread;

	misc_deregister(&hvt->mdev);

	unread = hvt->outmsg;
	kfree(unread);

	kfree(hvt);
}

169
170
171
static int hvt_op_release(struct inode *inode, struct file *file)
{
	struct hvutil_transport *hvt;
172
	int mode_old;
173
174
175

	hvt = container_of(file->f_op, struct hvutil_transport, fops);

176
	mutex_lock(&hvt->lock);
177
	mode_old = hvt->mode;
178
179
	if (hvt->mode != HVUTIL_TRANSPORT_DESTROY)
		hvt->mode = HVUTIL_TRANSPORT_INIT;
180
181
182
183
184
185
	/*
	 * Cleanup message buffers to avoid spurious messages when the daemon
	 * connects back.
	 */
	hvt_reset(hvt);

186
	if (mode_old == HVUTIL_TRANSPORT_DESTROY)
187
188
189
		complete(&hvt->release);

	mutex_unlock(&hvt->lock);
190

191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
	return 0;
}

static void hvt_cn_callback(struct cn_msg *msg, struct netlink_skb_parms *nsp)
{
	struct hvutil_transport *hvt, *hvt_found = NULL;

	spin_lock(&hvt_list_lock);
	list_for_each_entry(hvt, &hvt_list, list) {
		if (hvt->cn_id.idx == msg->id.idx &&
		    hvt->cn_id.val == msg->id.val) {
			hvt_found = hvt;
			break;
		}
	}
	spin_unlock(&hvt_list_lock);
	if (!hvt_found) {
		pr_warn("hvt_cn_callback: spurious message received!\n");
		return;
	}

	/*
	 * Switching to NETLINK mode. Switching to CHARDEV happens when someone
	 * opens the device.
	 */
216
	mutex_lock(&hvt->lock);
217
218
219
220
221
222
223
	if (hvt->mode == HVUTIL_TRANSPORT_INIT)
		hvt->mode = HVUTIL_TRANSPORT_NETLINK;

	if (hvt->mode == HVUTIL_TRANSPORT_NETLINK)
		hvt_found->on_msg(msg->data, msg->len);
	else
		pr_warn("hvt_cn_callback: unexpected netlink message!\n");
224
	mutex_unlock(&hvt->lock);
225
226
}

227
228
int hvutil_transport_send(struct hvutil_transport *hvt, void *msg, int len,
			  void (*on_read_cb)(void))
229
230
231
232
{
	struct cn_msg *cn_msg;
	int ret = 0;

233
234
	if (hvt->mode == HVUTIL_TRANSPORT_INIT ||
	    hvt->mode == HVUTIL_TRANSPORT_DESTROY) {
235
236
237
		return -EINVAL;
	} else if (hvt->mode == HVUTIL_TRANSPORT_NETLINK) {
		cn_msg = kzalloc(sizeof(*cn_msg) + len, GFP_ATOMIC);
238
		if (!cn_msg)
239
240
241
242
243
244
245
			return -ENOMEM;
		cn_msg->id.idx = hvt->cn_id.idx;
		cn_msg->id.val = hvt->cn_id.val;
		cn_msg->len = len;
		memcpy(cn_msg->data, msg, len);
		ret = cn_netlink_send(cn_msg, 0, 0, GFP_ATOMIC);
		kfree(cn_msg);
246
247
248
249
250
251
252
		/*
		 * We don't know when netlink messages are delivered but unlike
		 * in CHARDEV mode we're not blocked and we can send next
		 * messages right away.
		 */
		if (on_read_cb)
			on_read_cb();
253
254
255
		return ret;
	}
	/* HVUTIL_TRANSPORT_CHARDEV */
256
	mutex_lock(&hvt->lock);
257
258
259
260
261
	if (hvt->mode != HVUTIL_TRANSPORT_CHARDEV) {
		ret = -EINVAL;
		goto out_unlock;
	}

262
263
264
265
266
267
	if (hvt->outmsg) {
		/* Previous message wasn't received */
		ret = -EFAULT;
		goto out_unlock;
	}
	hvt->outmsg = kzalloc(len, GFP_KERNEL);
268
269
270
	if (hvt->outmsg) {
		memcpy(hvt->outmsg, msg, len);
		hvt->outmsg_len = len;
271
		hvt->on_read = on_read_cb;
272
273
274
		wake_up_interruptible(&hvt->outmsg_q);
	} else
		ret = -ENOMEM;
275
out_unlock:
276
	mutex_unlock(&hvt->lock);
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
	return ret;
}

/*
 * hvutil_transport_init - create a userspace transport for a util driver.
 * @name:     name for the misc character device, and for the connector
 *            callback when one is registered
 * @cn_idx:   connector index; netlink is set up only when both @cn_idx and
 *            @cn_val are non-zero
 * @cn_val:   connector value
 * @on_msg:   called with every message received from userspace
 * @on_reset: optional; called whenever the transport is reset
 *
 * Return: the new transport, or NULL on failure.
 */
struct hvutil_transport *hvutil_transport_init(const char *name,
					       u32 cn_idx, u32 cn_val,
					       int (*on_msg)(void *, int),
					       void (*on_reset)(void))
{
	struct hvutil_transport *hvt;

	hvt = kzalloc(sizeof(*hvt), GFP_KERNEL);
	if (!hvt)
		return NULL;

	hvt->cn_id.idx = cn_idx;
	hvt->cn_id.val = cn_val;

	hvt->mdev.minor = MISC_DYNAMIC_MINOR;
	hvt->mdev.name = name;

	hvt->fops.owner = THIS_MODULE;
	hvt->fops.read = hvt_op_read;
	hvt->fops.write = hvt_op_write;
	hvt->fops.poll = hvt_op_poll;
	hvt->fops.open = hvt_op_open;
	hvt->fops.release = hvt_op_release;

	hvt->mdev.fops = &hvt->fops;

	init_waitqueue_head(&hvt->outmsg_q);
	mutex_init(&hvt->lock);
	init_completion(&hvt->release);

	/* Fully initialise the callbacks before the transport becomes visible. */
	hvt->on_msg = on_msg;
	hvt->on_reset = on_reset;

	spin_lock(&hvt_list_lock);
	list_add(&hvt->list, &hvt_list);
	spin_unlock(&hvt_list_lock);

	if (misc_register(&hvt->mdev))
		goto err_free_hvt;

	/* Use cn_id.idx/cn_id.val to determine if we need to setup netlink */
	if (hvt->cn_id.idx > 0 && hvt->cn_id.val > 0 &&
	    cn_add_callback(&hvt->cn_id, name, hvt_cn_callback))
		goto err_deregister;

	return hvt;

err_deregister:
	/* The registered misc device and its fops live inside @hvt. */
	misc_deregister(&hvt->mdev);
err_free_hvt:
	spin_lock(&hvt_list_lock);
	list_del(&hvt->list);
	spin_unlock(&hvt_list_lock);
	kfree(hvt);
	return NULL;
}

void hvutil_transport_destroy(struct hvutil_transport *hvt)
{
337
338
	int mode_old;

339
	mutex_lock(&hvt->lock);
340
	mode_old = hvt->mode;
341
342
343
344
	hvt->mode = HVUTIL_TRANSPORT_DESTROY;
	wake_up_interruptible(&hvt->outmsg_q);
	mutex_unlock(&hvt->lock);

345
346
347
348
349
	/*
	 * In case we were in 'chardev' mode we still have an open fd so we
	 * have to defer freeing the device. Netlink interface can be freed
	 * now.
	 */
350
351
352
353
354
	spin_lock(&hvt_list_lock);
	list_del(&hvt->list);
	spin_unlock(&hvt_list_lock);
	if (hvt->cn_id.idx > 0 && hvt->cn_id.val > 0)
		cn_del_callback(&hvt->cn_id);
355

356
357
358
359
	if (mode_old == HVUTIL_TRANSPORT_CHARDEV)
		wait_for_completion(&hvt->release);

	hvt_transport_free(hvt);
360
}