net/rpmsg: add support for new rpmsg sockets
[rpmsg/rpmsg.git] / net / rpmsg / rpmsg_proto.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* AF_RPMSG: Remote processor messaging sockets
3  *
4  * Copyright (C) 2011-2018 Texas Instruments Incorporated - http://www.ti.com/
5  *
6  * Ohad Ben-Cohen <ohad@wizery.com>
7  * Robert Tivy <rtivy@ti.com>
8  * G Anthony <a0783926@ti.com>
9  * Suman Anna <s-anna@ti.com>
10  */
12 #define pr_fmt(fmt)    "%s: " fmt, __func__
14 #include <linux/kernel.h>
15 #include <linux/module.h>
16 #include <linux/device.h>
17 #include <linux/types.h>
18 #include <linux/list.h>
19 #include <linux/errno.h>
20 #include <linux/skbuff.h>
21 #include <linux/err.h>
22 #include <linux/mutex.h>
23 #include <linux/rpmsg.h>
24 #include <linux/radix-tree.h>
25 #include <linux/remoteproc.h>
26 #include <linux/rpmsg/virtio_rpmsg.h>
27 #include <net/sock.h>
28 #include <uapi/linux/rpmsg_socket.h>
30 #define RPMSG_CB(skb)   (*(struct sockaddr_rpmsg *)&((skb)->cb))
32 /* Maximum buffer size supported by virtio rpmsg transport.
33  * Must match value as in drivers/rpmsg/virtio_rpmsg_bus.c
34  */
35 #define RPMSG_BUF_SIZE               (512)
37 struct rpmsg_socket {
38         struct sock sk;
39         struct rpmsg_device *rpdev;
40         struct rpmsg_endpoint *endpt;
41         int rproc_id;
42 };
44 /* Connection and socket states */
45 enum {
46         RPMSG_CONNECTED = 1,
47         RPMSG_OPEN,
48         RPMSG_LISTENING,
49         RPMSG_CLOSED,
50 };
52 /* A single-level radix-tree-based scheme is used to maintain the rpmsg
53  * channels we're exposing to userland. The radix tree maps a rproc index
54  * id to its published rpmsg-proto channel. Only a single rpmsg device is
55  * supported at the moment from each remote processor. This can be easily
56  * scaled to multiple devices using unique destination addresses but this
57  *_will_ require additional semantic changes on bind() and connect().
58  */
59 static RADIX_TREE(rpmsg_channels, GFP_KERNEL);
61 /* Synchronization of access to the tree is achieved using a mutex,
62  * because we're using non-atomic radix tree allocations.
63  */
64 static DEFINE_MUTEX(rpmsg_channels_lock);
66 static int rpmsg_sock_cb(struct rpmsg_device *rpdev, void *data, int len,
67                          void *priv, u32 src);
69 static struct proto rpmsg_proto = {
70         .name           = "RPMSG",
71         .owner          = THIS_MODULE,
72         .obj_size       = sizeof(struct rpmsg_socket),
73 };
75 /* Retrieve the rproc instance so that it can be used for retrieving
76  * the processor id associated with the rpmsg channel.
77  */
78 static inline struct rproc *rpdev_to_rproc(struct rpmsg_device *rpdev)
79 {
80         return rproc_get_by_child(&rpdev->dev);
81 }
83 /* Retrieve the rproc id. The rproc id _relies_ on aliases being defined
84  * in the DT blob for each of the remoteproc devices, and is essentially
85  * the alias id. These are assumed to match to be fixed for a particular
86  * SoC, and this provides a means to have a fixed interface to identify
87  * a remote processor.
88  */
89 static int rpmsg_sock_get_proc_id(struct rpmsg_device *rpdev)
90 {
91         struct rproc *rproc = rpdev_to_rproc(rpdev);
92         int id;
94         if (!rproc) {
95                 WARN_ON(1);
96                 return -EINVAL;
97         }
99         id = rproc_get_id(rproc);
100         WARN_ON(id < 0);
102         return id;
105 static int rpmsg_sock_connect(struct socket *sock, struct sockaddr *addr,
106                               int alen, int flags)
108         struct sock *sk = sock->sk;
109         struct rpmsg_socket *rpsk;
110         struct sockaddr_rpmsg *sa;
111         int err = 0;
112         struct rpmsg_device *rpdev;
114         if (sk->sk_state != RPMSG_OPEN)
115                 return -EBADFD;
117         if (sk->sk_type != SOCK_SEQPACKET)
118                 return -EINVAL;
120         if (!addr || addr->sa_family != AF_RPMSG)
121                 return -EINVAL;
123         if (alen < sizeof(*sa))
124                 return -EINVAL;
126         sa = (struct sockaddr_rpmsg *)addr;
128         lock_sock(sk);
130         rpsk = container_of(sk, struct rpmsg_socket, sk);
132         mutex_lock(&rpmsg_channels_lock);
134         /* find the set of channels exposed by this remote processor */
135         rpdev = radix_tree_lookup(&rpmsg_channels, sa->vproc_id);
136         if (!rpdev) {
137                 err = -EINVAL;
138                 goto out;
139         }
141         rpsk->rproc_id = sa->vproc_id;
142         rpsk->rpdev = rpdev;
144         /* XXX take care of disconnection state too */
145         sk->sk_state = RPMSG_CONNECTED;
147 out:
148         mutex_unlock(&rpmsg_channels_lock);
149         release_sock(sk);
150         return err;
153 static int rpmsg_sock_sendmsg(struct socket *sock, struct msghdr *msg,
154                               size_t len)
156         struct sock *sk = sock->sk;
157         struct rpmsg_socket *rpsk;
158         char payload[RPMSG_BUF_SIZE];/* todo: sane payload length methodology */
159         int err;
161         /* XXX check for sock_error as well ? */
162         /* XXX handle noblock ? */
163         if (msg->msg_flags & MSG_OOB)
164                 return -EOPNOTSUPP;
166         /* no payload ? */
167         if (!msg->msg_iter.iov->iov_base)
168                 return -EINVAL;
170         /* make sure the length is valid for copying into kernel buffer */
171         if (len > RPMSG_BUF_SIZE - sizeof(struct rpmsg_hdr))
172                 return -EMSGSIZE;
174         lock_sock(sk);
176         /* we don't support loopback at this point */
177         if (sk->sk_state != RPMSG_CONNECTED) {
178                 release_sock(sk);
179                 return -ENOTCONN;
180         }
182         rpsk = container_of(sk, struct rpmsg_socket, sk);
184         /* XXX for now, ignore the peer address. later use it
185          * with rpmsg_sendto, but only if user is root
186          */
187         err = memcpy_from_msg(payload, msg, len);
188         if (err)
189                 goto out;
191         err = rpmsg_send(rpsk->rpdev->ept, payload, len);
192         if (err)
193                 pr_err("rpmsg_send failed: %d\n", err);
195 out:
196         release_sock(sk);
197         return err;
200 static int rpmsg_sock_recvmsg(struct socket *sock, struct msghdr *msg,
201                               size_t len, int flags)
203         struct sock *sk = sock->sk;
204         struct sockaddr_rpmsg *sa;
205         struct sk_buff *skb;
206         int noblock = flags & MSG_DONTWAIT;
207         int ret;
209         if (flags & MSG_OOB) {
210                 pr_err("MSG_OOB: %d\n", EOPNOTSUPP);
211                 return -EOPNOTSUPP;
212         }
214         msg->msg_namelen = 0;
216         skb = skb_recv_datagram(sk, flags, noblock, &ret);
217         if (!skb) {
218                 /* check for shutdown ? */
219                 pr_err("skb_recv_datagram: %d\n", ret);
220                 return ret;
221         }
223         if (msg->msg_name) {
224                 msg->msg_namelen = sizeof(*sa);
225                 sa = (struct sockaddr_rpmsg *)msg->msg_name;
226                 sa->vproc_id = RPMSG_CB(skb).vproc_id;
227                 sa->addr = RPMSG_CB(skb).addr;
228                 sa->family = AF_RPMSG;
229         }
231         if (len > skb->len) {
232                 len = skb->len;
233         } else if (len < skb->len) {
234                 pr_warn("user buffer is too small\n");
235                 /* XXX truncate or error ? */
236                 msg->msg_flags |= MSG_TRUNC;
237         }
239         ret = skb_copy_datagram_msg(skb, 0, msg, len);
240         if (ret) {
241                 pr_err("error copying skb data: %d\n", ret);
242                 goto out_free;
243         }
245         ret = len;
247 out_free:
248         skb_free_datagram(sk, skb);
249         return ret;
252 static __poll_t rpmsg_sock_poll(struct file *file, struct socket *sock,
253                                 poll_table *wait)
255         struct sock *sk = sock->sk;
256         __poll_t mask = 0;
258         poll_wait(file, sk_sleep(sk), wait);
260         /* exceptional events? */
261         if (sk->sk_err || !skb_queue_empty(&sk->sk_error_queue))
262                 mask |= EPOLLERR;
263         if (sk->sk_shutdown & RCV_SHUTDOWN)
264                 mask |= EPOLLRDHUP;
265         if (sk->sk_shutdown == SHUTDOWN_MASK)
266                 mask |= EPOLLHUP;
268         /* readable? */
269         if (!skb_queue_empty(&sk->sk_receive_queue) ||
270             (sk->sk_shutdown & RCV_SHUTDOWN))
271                 mask |= EPOLLIN | EPOLLRDNORM;
273         if (sk->sk_state == RPMSG_CLOSED)
274                 mask |= EPOLLHUP;
276         /* XXX is writable ?
277          * this depends on the destination processor.
278          * if loopback: we're writable unless no memory
279          * if to remote: we need enabled rpmsg buffer or user supplied bufs
280          * for now, let's always be writable.
281          */
282         mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND;
284         return mask;
287 /* return bound socket address information, either local or remote */
288 static int rpmsg_sock_getname(struct socket *sock, struct sockaddr *addr,
289                               int peer)
291         struct sock *sk = sock->sk;
292         struct rpmsg_socket *rpsk;
293         struct rpmsg_device *rpdev;
294         struct sockaddr_rpmsg *sa;
295         int ret;
297         rpsk = container_of(sk, struct rpmsg_socket, sk);
299         lock_sock(sk);
300         rpdev = rpsk->rpdev;
301         if (!rpdev) {
302                 ret = peer ? -ENOTCONN : -EINVAL;
303                 goto out;
304         }
306         addr->sa_family = AF_RPMSG;
307         sa = (struct sockaddr_rpmsg *)addr;
308         ret = sizeof(*sa);
310         if (peer) {
311                 sa->vproc_id = rpsk->rproc_id;
312                 sa->addr = rpdev->dst;
313         } else {
314                 sa->vproc_id = RPMSG_LOCALHOST;
315                 sa->addr = rpsk->endpt ? rpsk->endpt->addr : rpsk->rpdev->src;
316         }
318 out:
319         release_sock(sk);
320         return ret;
323 static int rpmsg_sock_release(struct socket *sock)
325         struct sock *sk = sock->sk;
326         struct rpmsg_socket *rpsk = container_of(sk, struct rpmsg_socket, sk);
328         if (!sk)
329                 return 0;
331         /* function can be called with NULL endpoints, so it is effective for
332          * Rx sockets and a no-op for Tx sockets
333          */
334         rpmsg_destroy_ept(rpsk->endpt);
336         sock_put(sock->sk);
338         return 0;
341 /* Notes:
342  * - calling connect after bind isn't currently supported (is it even needed?).
343  * - userspace arguments to bind aren't intuitive: one needs to provide
344  *   the vproc id of the remote processor that the channel needs to be shared
345  *   with, and the -local- source address the channel is to be bound with
346  */
347 static int
348 rpmsg_sock_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
350         struct sock *sk = sock->sk;
351         struct rpmsg_socket *rpsk = container_of(sk, struct rpmsg_socket, sk);
352         struct rpmsg_device *rpdev;
353         struct rpmsg_endpoint *endpt;
354         struct rpmsg_channel_info chinfo = {};
355         struct sockaddr_rpmsg *sa = (struct sockaddr_rpmsg *)uaddr;
357         if (sock->state == SS_CONNECTED)
358                 return -EINVAL;
360         if (addr_len != sizeof(*sa))
361                 return -EINVAL;
363         if (sa->family != AF_RPMSG)
364                 return -EINVAL;
366         if (rpsk->endpt)
367                 return -EBUSY;
369         if (sk->sk_state != RPMSG_OPEN)
370                 return -EINVAL;
372         rpdev = radix_tree_lookup(&rpmsg_channels, sa->vproc_id);
373         if (!rpdev)
374                 return -EINVAL;
376         /* bind this socket with a receiving endpoint */
377         chinfo.src = sa->addr;
378         chinfo.dst = RPMSG_ADDR_ANY;
379         endpt = rpmsg_create_ept(rpdev, rpmsg_sock_cb, sk, chinfo);
380         if (!endpt)
381                 return -EINVAL;
383         rpsk->rpdev = rpdev;
384         rpsk->endpt = endpt;
385         rpsk->rproc_id = sa->vproc_id;
387         sk->sk_state = RPMSG_LISTENING;
389         return 0;
392 static const struct proto_ops rpmsg_sock_ops = {
393         .family         = PF_RPMSG,
394         .owner          = THIS_MODULE,
396         .release        = rpmsg_sock_release,
397         .connect        = rpmsg_sock_connect,
398         .getname        = rpmsg_sock_getname,
399         .sendmsg        = rpmsg_sock_sendmsg,
400         .recvmsg        = rpmsg_sock_recvmsg,
401         .poll           = rpmsg_sock_poll,
402         .bind           = rpmsg_sock_bind,
404         .listen         = sock_no_listen,
405         .accept         = sock_no_accept,
406         .ioctl          = sock_no_ioctl,
407         .mmap           = sock_no_mmap,
408         .socketpair     = sock_no_socketpair,
409         .shutdown       = sock_no_shutdown,
410         .setsockopt     = sock_no_setsockopt,
411         .getsockopt     = sock_no_getsockopt
412 };
414 static void rpmsg_sock_destruct(struct sock *sk)
418 static int rpmsg_sock_create(struct net *net, struct socket *sock, int proto,
419                              int kern)
421         struct sock *sk;
422         struct rpmsg_socket *rpsk;
424         if (sock->type != SOCK_SEQPACKET)
425                 return -ESOCKTNOSUPPORT;
426         if (proto != 0)
427                 return -EPROTONOSUPPORT;
429         sk = sk_alloc(net, PF_RPMSG, GFP_KERNEL, &rpmsg_proto, kern);
430         if (!sk)
431                 return -ENOMEM;
433         sock->state = SS_UNCONNECTED;
434         sock->ops = &rpmsg_sock_ops;
435         sock_init_data(sock, sk);
437         sk->sk_destruct = rpmsg_sock_destruct;
438         sk->sk_protocol = proto;
440         sk->sk_state = RPMSG_OPEN;
442         rpsk = container_of(sk, struct rpmsg_socket, sk);
443         /* use RPMSG_LOCALHOST to serve as an invalid value */
444         rpsk->rproc_id = RPMSG_LOCALHOST;
446         return 0;
449 static const struct net_proto_family rpmsg_proto_family = {
450         .family = PF_RPMSG,
451         .create = rpmsg_sock_create,
452         .owner = THIS_MODULE,
453 };
455 static int __rpmsg_sock_cb(struct device *dev, int from_vproc_id, void *data,
456                            int len, struct sock *sk, u32 src)
458         struct rpmsg_socket *rpsk = container_of(sk, struct rpmsg_socket, sk);
459         struct sk_buff *skb;
460         int ret;
462 #if defined(CONFIG_DYNAMIC_DEBUG)
463         dynamic_hex_dump("rpmsg_proto Rx: ", DUMP_PREFIX_NONE, 16, 1, data,
464                          len, true);
465 #endif
467         lock_sock(sk);
469         switch (sk->sk_state) {
470         case RPMSG_CONNECTED:
471                 if (rpsk->rpdev->dst != src)
472                         dev_warn(dev, "unexpected source address: %d\n", src);
473                 break;
474         case RPMSG_LISTENING:
475                 /* When an inbound message is received while we're listening,
476                  * we implicitly become connected
477                  */
478                 sk->sk_state = RPMSG_CONNECTED;
479                 rpsk->rpdev->dst = src;
480                 break;
481         default:
482                 dev_warn(dev, "unexpected inbound message (from %d)\n", src);
483                 break;
484         }
486         skb = sock_alloc_send_skb(sk, len, 1, &ret);
487         if (!skb) {
488                 dev_err(dev, "sock_alloc_send_skb failed: %d\n", ret);
489                 ret = -ENOMEM;
490                 goto out;
491         }
493         RPMSG_CB(skb).vproc_id = from_vproc_id;
494         RPMSG_CB(skb).addr = src;
495         RPMSG_CB(skb).family = AF_RPMSG;
497         memcpy(skb_put(skb, len), data, len);
499         ret = sock_queue_rcv_skb(sk, skb);
500         if (ret) {
501                 dev_err(dev, "sock_queue_rcv_skb failed: %d\n", ret);
502                 kfree_skb(skb);
503         }
505 out:
506         release_sock(sk);
507         return ret;
510 static int rpmsg_sock_cb(struct rpmsg_device *rpdev, void *data, int len,
511                          void *priv, u32 src)
513         int id = rpmsg_sock_get_proc_id(rpdev);
515         return __rpmsg_sock_cb(&rpdev->dev, id, data, len, priv, src);
518 static int rpmsg_proto_cb(struct rpmsg_device *rpdev, void *data, int len,
519                           void *priv, u32 src)
521         dev_err(&rpdev->dev, "rpmsg_proto device not designed to receive any messages\n");
522         return 0;
525 static int rpmsg_proto_probe(struct rpmsg_device *rpdev)
527         struct device *dev = &rpdev->dev;
528         int ret, dst = rpdev->dst, id;
529         struct rpmsg_device *vrp_dev;
531         if (WARN_ON(dst == RPMSG_ADDR_ANY))
532                 return -EINVAL;
534         id = rpmsg_sock_get_proc_id(rpdev);
535         if (id < 0)
536                 return -EINVAL;
538         mutex_lock(&rpmsg_channels_lock);
540         /* are we exposing a rpmsg proto device for this remote processor yet?
541          * If not, associate id/device for later lookup in rpmsg_sock_bind().
542          * Multiple devices per remote processor are not supported.
543          */
544         vrp_dev = radix_tree_lookup(&rpmsg_channels, id);
545         if (!vrp_dev) {
546                 ret = radix_tree_insert(&rpmsg_channels, id, rpdev);
547                 if (ret) {
548                         dev_err(dev, "radix_tree_insert failed: %d\n", ret);
549                         goto out;
550                 }
551         } else {
552                 ret = -ENODEV;
553                 dev_err(dev, "multiple rpmsg-proto devices from the same rproc is not supported.\n");
554                 goto out;
555         }
557 out:
558         mutex_unlock(&rpmsg_channels_lock);
560         return ret;
563 static void rpmsg_proto_remove(struct rpmsg_device *rpdev)
565         struct device *dev = &rpdev->dev;
566         int id, dst = rpdev->dst;
567         struct rpmsg_device *vrp_dev;
569         if (dst == RPMSG_ADDR_ANY)
570                 return;
572         id = rpmsg_sock_get_proc_id(rpdev);
574         mutex_lock(&rpmsg_channels_lock);
576         vrp_dev = radix_tree_lookup(&rpmsg_channels, id);
577         if (!vrp_dev) {
578                 dev_err(dev, "can't find rpmsg device for rproc %d\n", id);
579                 goto out;
580         }
581         if (vrp_dev != rpdev)
582                 dev_err(dev, "can't match the stored rpdev for rproc %d\n", id);
584         if (!radix_tree_delete(&rpmsg_channels, id))
585                 dev_err(dev, "failed to delete rpdev for rproc %d\n", id);
587 out:
588         mutex_unlock(&rpmsg_channels_lock);
591 static struct rpmsg_device_id rpmsg_proto_id_table[] = {
592         { .name = "rpmsg-proto" },
593         { },
594 };
595 MODULE_DEVICE_TABLE(rpmsg, rpmsg_proto_id_table);
597 static struct rpmsg_driver rpmsg_proto_driver = {
598         .drv.name       = KBUILD_MODNAME,
599         .id_table       = rpmsg_proto_id_table,
600         .probe          = rpmsg_proto_probe,
601         .callback       = rpmsg_proto_cb,
602         .remove         = rpmsg_proto_remove,
603 };
605 static int __init rpmsg_proto_init(void)
607         int ret;
609         ret = proto_register(&rpmsg_proto, 0);
610         if (ret) {
611                 pr_err("proto_register failed: %d\n", ret);
612                 return ret;
613         }
615         ret = sock_register(&rpmsg_proto_family);
616         if (ret) {
617                 pr_err("sock_register failed: %d\n", ret);
618                 goto proto_unreg;
619         }
621         ret = register_rpmsg_driver(&rpmsg_proto_driver);
622         if (ret) {
623                 pr_err("register_rpmsg_driver failed: %d\n", ret);
624                 goto sock_unreg;
625         }
627         return 0;
629 sock_unreg:
630         sock_unregister(PF_RPMSG);
631 proto_unreg:
632         proto_unregister(&rpmsg_proto);
633         return ret;
636 static void __exit rpmsg_proto_exit(void)
638         unregister_rpmsg_driver(&rpmsg_proto_driver);
639         sock_unregister(PF_RPMSG);
640         proto_unregister(&rpmsg_proto);
643 module_init(rpmsg_proto_init);
644 module_exit(rpmsg_proto_exit);
646 MODULE_DESCRIPTION("Remote processor messaging protocol");
647 MODULE_LICENSE("GPL v2");
648 MODULE_ALIAS("rpmsg:rpmsg-proto");
649 MODULE_ALIAS_NETPROTO(AF_RPMSG);