net/rpmsg: fix return value of rpmsg_sock_sendmsg()
[rpmsg/rpmsg.git] / net / rpmsg / rpmsg_proto.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* AF_RPMSG: Remote processor messaging sockets
3  *
4  * Copyright (C) 2011-2018 Texas Instruments Incorporated - http://www.ti.com/
5  *
6  * Ohad Ben-Cohen <ohad@wizery.com>
7  * Robert Tivy <rtivy@ti.com>
8  * G Anthony <a0783926@ti.com>
9  * Suman Anna <s-anna@ti.com>
10  */
12 #define pr_fmt(fmt)    "%s: " fmt, __func__
14 #include <linux/kernel.h>
15 #include <linux/module.h>
16 #include <linux/device.h>
17 #include <linux/types.h>
18 #include <linux/list.h>
19 #include <linux/errno.h>
20 #include <linux/skbuff.h>
21 #include <linux/err.h>
22 #include <linux/mutex.h>
23 #include <linux/rpmsg.h>
24 #include <linux/radix-tree.h>
25 #include <linux/remoteproc.h>
26 #include <linux/rpmsg/virtio_rpmsg.h>
27 #include <net/sock.h>
28 #include <uapi/linux/rpmsg_socket.h>
30 #define RPMSG_CB(skb)   (*(struct sockaddr_rpmsg *)&((skb)->cb))
32 /* Maximum buffer size supported by virtio rpmsg transport.
33  * Must match value as in drivers/rpmsg/virtio_rpmsg_bus.c
34  */
35 #define RPMSG_BUF_SIZE               (512)
37 struct rpmsg_socket {
38         struct sock sk;
39         struct rpmsg_device *rpdev;
40         struct rpmsg_endpoint *endpt;
41         int rproc_id;
42 };
44 /* Connection and socket states */
45 enum {
46         RPMSG_CONNECTED = 1,
47         RPMSG_OPEN,
48         RPMSG_LISTENING,
49         RPMSG_CLOSED,
50 };
52 /* A single-level radix-tree-based scheme is used to maintain the rpmsg
53  * channels we're exposing to userland. The radix tree maps a rproc index
54  * id to its published rpmsg-proto channel. Only a single rpmsg device is
55  * supported at the moment from each remote processor. This can be easily
56  * scaled to multiple devices using unique destination addresses but this
57  *_will_ require additional semantic changes on bind() and connect().
58  */
59 static RADIX_TREE(rpmsg_channels, GFP_KERNEL);
61 /* Synchronization of access to the tree is achieved using a mutex,
62  * because we're using non-atomic radix tree allocations.
63  */
64 static DEFINE_MUTEX(rpmsg_channels_lock);
66 static int rpmsg_sock_cb(struct rpmsg_device *rpdev, void *data, int len,
67                          void *priv, u32 src);
69 static struct proto rpmsg_proto = {
70         .name           = "RPMSG",
71         .owner          = THIS_MODULE,
72         .obj_size       = sizeof(struct rpmsg_socket),
73 };
75 /* Retrieve the rproc instance so that it can be used for retrieving
76  * the processor id associated with the rpmsg channel.
77  */
78 static inline struct rproc *rpdev_to_rproc(struct rpmsg_device *rpdev)
79 {
80         return rproc_get_by_child(&rpdev->dev);
81 }
83 /* Retrieve the rproc id. The rproc id _relies_ on aliases being defined
84  * in the DT blob for each of the remoteproc devices, and is essentially
85  * the alias id. These are assumed to match to be fixed for a particular
86  * SoC, and this provides a means to have a fixed interface to identify
87  * a remote processor.
88  */
89 static int rpmsg_sock_get_proc_id(struct rpmsg_device *rpdev)
90 {
91         struct rproc *rproc = rpdev_to_rproc(rpdev);
92         int id;
94         if (!rproc) {
95                 WARN_ON(1);
96                 return -EINVAL;
97         }
99         id = rproc_get_id(rproc);
100         WARN_ON(id < 0);
102         return id;
105 static int rpmsg_sock_connect(struct socket *sock, struct sockaddr *addr,
106                               int alen, int flags)
108         struct sock *sk = sock->sk;
109         struct rpmsg_socket *rpsk;
110         struct sockaddr_rpmsg *sa;
111         int err = 0;
112         struct rpmsg_device *rpdev;
114         if (sk->sk_state != RPMSG_OPEN)
115                 return -EBADFD;
117         if (sk->sk_type != SOCK_SEQPACKET)
118                 return -EINVAL;
120         if (!addr || addr->sa_family != AF_RPMSG)
121                 return -EINVAL;
123         if (alen < sizeof(*sa))
124                 return -EINVAL;
126         sa = (struct sockaddr_rpmsg *)addr;
128         lock_sock(sk);
130         rpsk = container_of(sk, struct rpmsg_socket, sk);
132         mutex_lock(&rpmsg_channels_lock);
134         /* find the set of channels exposed by this remote processor */
135         rpdev = radix_tree_lookup(&rpmsg_channels, sa->vproc_id);
136         if (!rpdev) {
137                 err = -EINVAL;
138                 goto out;
139         }
141         rpsk->rproc_id = sa->vproc_id;
142         rpsk->rpdev = rpdev;
144         /* XXX take care of disconnection state too */
145         sk->sk_state = RPMSG_CONNECTED;
147 out:
148         mutex_unlock(&rpmsg_channels_lock);
149         release_sock(sk);
150         return err;
153 static int rpmsg_sock_sendmsg(struct socket *sock, struct msghdr *msg,
154                               size_t len)
156         struct sock *sk = sock->sk;
157         struct rpmsg_socket *rpsk;
158         char payload[RPMSG_BUF_SIZE];/* todo: sane payload length methodology */
159         int err;
161         /* XXX check for sock_error as well ? */
162         /* XXX handle noblock ? */
163         if (msg->msg_flags & MSG_OOB)
164                 return -EOPNOTSUPP;
166         /* no payload ? */
167         if (!msg->msg_iter.iov->iov_base)
168                 return -EINVAL;
170         /* make sure the length is valid for copying into kernel buffer */
171         if (len > RPMSG_BUF_SIZE - sizeof(struct rpmsg_hdr))
172                 return -EMSGSIZE;
174         lock_sock(sk);
176         /* we don't support loopback at this point */
177         if (sk->sk_state != RPMSG_CONNECTED) {
178                 release_sock(sk);
179                 return -ENOTCONN;
180         }
182         rpsk = container_of(sk, struct rpmsg_socket, sk);
184         /* XXX for now, ignore the peer address. later use it
185          * with rpmsg_sendto, but only if user is root
186          */
187         err = memcpy_from_msg(payload, msg, len);
188         if (err)
189                 goto out;
191         err = rpmsg_send(rpsk->rpdev->ept, payload, len);
192         if (err)
193                 pr_err("rpmsg_send failed: %d\n", err);
194         else
195                 err = len;
197 out:
198         release_sock(sk);
199         return err;
202 static int rpmsg_sock_recvmsg(struct socket *sock, struct msghdr *msg,
203                               size_t len, int flags)
205         struct sock *sk = sock->sk;
206         struct sockaddr_rpmsg *sa;
207         struct sk_buff *skb;
208         int noblock = flags & MSG_DONTWAIT;
209         int ret;
211         if (flags & MSG_OOB) {
212                 pr_err("MSG_OOB: %d\n", EOPNOTSUPP);
213                 return -EOPNOTSUPP;
214         }
216         msg->msg_namelen = 0;
218         skb = skb_recv_datagram(sk, flags, noblock, &ret);
219         if (!skb) {
220                 /* check for shutdown ? */
221                 pr_err("skb_recv_datagram: %d\n", ret);
222                 return ret;
223         }
225         if (msg->msg_name) {
226                 msg->msg_namelen = sizeof(*sa);
227                 sa = (struct sockaddr_rpmsg *)msg->msg_name;
228                 sa->vproc_id = RPMSG_CB(skb).vproc_id;
229                 sa->addr = RPMSG_CB(skb).addr;
230                 sa->family = AF_RPMSG;
231         }
233         if (len > skb->len) {
234                 len = skb->len;
235         } else if (len < skb->len) {
236                 pr_warn("user buffer is too small\n");
237                 /* XXX truncate or error ? */
238                 msg->msg_flags |= MSG_TRUNC;
239         }
241         ret = skb_copy_datagram_msg(skb, 0, msg, len);
242         if (ret) {
243                 pr_err("error copying skb data: %d\n", ret);
244                 goto out_free;
245         }
247         ret = len;
249 out_free:
250         skb_free_datagram(sk, skb);
251         return ret;
254 static __poll_t rpmsg_sock_poll(struct file *file, struct socket *sock,
255                                 poll_table *wait)
257         struct sock *sk = sock->sk;
258         __poll_t mask = 0;
260         poll_wait(file, sk_sleep(sk), wait);
262         /* exceptional events? */
263         if (sk->sk_err || !skb_queue_empty(&sk->sk_error_queue))
264                 mask |= EPOLLERR;
265         if (sk->sk_shutdown & RCV_SHUTDOWN)
266                 mask |= EPOLLRDHUP;
267         if (sk->sk_shutdown == SHUTDOWN_MASK)
268                 mask |= EPOLLHUP;
270         /* readable? */
271         if (!skb_queue_empty(&sk->sk_receive_queue) ||
272             (sk->sk_shutdown & RCV_SHUTDOWN))
273                 mask |= EPOLLIN | EPOLLRDNORM;
275         if (sk->sk_state == RPMSG_CLOSED)
276                 mask |= EPOLLHUP;
278         /* XXX is writable ?
279          * this depends on the destination processor.
280          * if loopback: we're writable unless no memory
281          * if to remote: we need enabled rpmsg buffer or user supplied bufs
282          * for now, let's always be writable.
283          */
284         mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND;
286         return mask;
289 /* return bound socket address information, either local or remote */
290 static int rpmsg_sock_getname(struct socket *sock, struct sockaddr *addr,
291                               int peer)
293         struct sock *sk = sock->sk;
294         struct rpmsg_socket *rpsk;
295         struct rpmsg_device *rpdev;
296         struct sockaddr_rpmsg *sa;
297         int ret;
299         rpsk = container_of(sk, struct rpmsg_socket, sk);
301         lock_sock(sk);
302         rpdev = rpsk->rpdev;
303         if (!rpdev) {
304                 ret = peer ? -ENOTCONN : -EINVAL;
305                 goto out;
306         }
308         addr->sa_family = AF_RPMSG;
309         sa = (struct sockaddr_rpmsg *)addr;
310         ret = sizeof(*sa);
312         if (peer) {
313                 sa->vproc_id = rpsk->rproc_id;
314                 sa->addr = rpdev->dst;
315         } else {
316                 sa->vproc_id = RPMSG_LOCALHOST;
317                 sa->addr = rpsk->endpt ? rpsk->endpt->addr : rpsk->rpdev->src;
318         }
320 out:
321         release_sock(sk);
322         return ret;
325 static int rpmsg_sock_release(struct socket *sock)
327         struct sock *sk = sock->sk;
328         struct rpmsg_socket *rpsk = container_of(sk, struct rpmsg_socket, sk);
330         if (!sk)
331                 return 0;
333         /* function can be called with NULL endpoints, so it is effective for
334          * Rx sockets and a no-op for Tx sockets
335          */
336         rpmsg_destroy_ept(rpsk->endpt);
338         sock_put(sock->sk);
340         return 0;
343 /* Notes:
344  * - calling connect after bind isn't currently supported (is it even needed?).
345  * - userspace arguments to bind aren't intuitive: one needs to provide
346  *   the vproc id of the remote processor that the channel needs to be shared
347  *   with, and the -local- source address the channel is to be bound with
348  */
349 static int
350 rpmsg_sock_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
352         struct sock *sk = sock->sk;
353         struct rpmsg_socket *rpsk = container_of(sk, struct rpmsg_socket, sk);
354         struct rpmsg_device *rpdev;
355         struct rpmsg_endpoint *endpt;
356         struct rpmsg_channel_info chinfo = {};
357         struct sockaddr_rpmsg *sa = (struct sockaddr_rpmsg *)uaddr;
359         if (sock->state == SS_CONNECTED)
360                 return -EINVAL;
362         if (addr_len != sizeof(*sa))
363                 return -EINVAL;
365         if (sa->family != AF_RPMSG)
366                 return -EINVAL;
368         if (rpsk->endpt)
369                 return -EBUSY;
371         if (sk->sk_state != RPMSG_OPEN)
372                 return -EINVAL;
374         rpdev = radix_tree_lookup(&rpmsg_channels, sa->vproc_id);
375         if (!rpdev)
376                 return -EINVAL;
378         /* bind this socket with a receiving endpoint */
379         chinfo.src = sa->addr;
380         chinfo.dst = RPMSG_ADDR_ANY;
381         endpt = rpmsg_create_ept(rpdev, rpmsg_sock_cb, sk, chinfo);
382         if (!endpt)
383                 return -EINVAL;
385         rpsk->rpdev = rpdev;
386         rpsk->endpt = endpt;
387         rpsk->rproc_id = sa->vproc_id;
389         sk->sk_state = RPMSG_LISTENING;
391         return 0;
394 static const struct proto_ops rpmsg_sock_ops = {
395         .family         = PF_RPMSG,
396         .owner          = THIS_MODULE,
398         .release        = rpmsg_sock_release,
399         .connect        = rpmsg_sock_connect,
400         .getname        = rpmsg_sock_getname,
401         .sendmsg        = rpmsg_sock_sendmsg,
402         .recvmsg        = rpmsg_sock_recvmsg,
403         .poll           = rpmsg_sock_poll,
404         .bind           = rpmsg_sock_bind,
406         .listen         = sock_no_listen,
407         .accept         = sock_no_accept,
408         .ioctl          = sock_no_ioctl,
409         .mmap           = sock_no_mmap,
410         .socketpair     = sock_no_socketpair,
411         .shutdown       = sock_no_shutdown,
412         .setsockopt     = sock_no_setsockopt,
413         .getsockopt     = sock_no_getsockopt
414 };
416 static void rpmsg_sock_destruct(struct sock *sk)
420 static int rpmsg_sock_create(struct net *net, struct socket *sock, int proto,
421                              int kern)
423         struct sock *sk;
424         struct rpmsg_socket *rpsk;
426         if (sock->type != SOCK_SEQPACKET)
427                 return -ESOCKTNOSUPPORT;
428         if (proto != 0)
429                 return -EPROTONOSUPPORT;
431         sk = sk_alloc(net, PF_RPMSG, GFP_KERNEL, &rpmsg_proto, kern);
432         if (!sk)
433                 return -ENOMEM;
435         sock->state = SS_UNCONNECTED;
436         sock->ops = &rpmsg_sock_ops;
437         sock_init_data(sock, sk);
439         sk->sk_destruct = rpmsg_sock_destruct;
440         sk->sk_protocol = proto;
442         sk->sk_state = RPMSG_OPEN;
444         rpsk = container_of(sk, struct rpmsg_socket, sk);
445         /* use RPMSG_LOCALHOST to serve as an invalid value */
446         rpsk->rproc_id = RPMSG_LOCALHOST;
448         return 0;
451 static const struct net_proto_family rpmsg_proto_family = {
452         .family = PF_RPMSG,
453         .create = rpmsg_sock_create,
454         .owner = THIS_MODULE,
455 };
457 static int __rpmsg_sock_cb(struct device *dev, int from_vproc_id, void *data,
458                            int len, struct sock *sk, u32 src)
460         struct rpmsg_socket *rpsk = container_of(sk, struct rpmsg_socket, sk);
461         struct sk_buff *skb;
462         int ret;
464 #if defined(CONFIG_DYNAMIC_DEBUG)
465         dynamic_hex_dump("rpmsg_proto Rx: ", DUMP_PREFIX_NONE, 16, 1, data,
466                          len, true);
467 #endif
469         lock_sock(sk);
471         switch (sk->sk_state) {
472         case RPMSG_CONNECTED:
473                 if (rpsk->rpdev->dst != src)
474                         dev_warn(dev, "unexpected source address: %d\n", src);
475                 break;
476         case RPMSG_LISTENING:
477                 /* When an inbound message is received while we're listening,
478                  * we implicitly become connected
479                  */
480                 sk->sk_state = RPMSG_CONNECTED;
481                 rpsk->rpdev->dst = src;
482                 break;
483         default:
484                 dev_warn(dev, "unexpected inbound message (from %d)\n", src);
485                 break;
486         }
488         skb = sock_alloc_send_skb(sk, len, 1, &ret);
489         if (!skb) {
490                 dev_err(dev, "sock_alloc_send_skb failed: %d\n", ret);
491                 ret = -ENOMEM;
492                 goto out;
493         }
495         RPMSG_CB(skb).vproc_id = from_vproc_id;
496         RPMSG_CB(skb).addr = src;
497         RPMSG_CB(skb).family = AF_RPMSG;
499         memcpy(skb_put(skb, len), data, len);
501         ret = sock_queue_rcv_skb(sk, skb);
502         if (ret) {
503                 dev_err(dev, "sock_queue_rcv_skb failed: %d\n", ret);
504                 kfree_skb(skb);
505         }
507 out:
508         release_sock(sk);
509         return ret;
512 static int rpmsg_sock_cb(struct rpmsg_device *rpdev, void *data, int len,
513                          void *priv, u32 src)
515         int id = rpmsg_sock_get_proc_id(rpdev);
517         return __rpmsg_sock_cb(&rpdev->dev, id, data, len, priv, src);
520 static int rpmsg_proto_cb(struct rpmsg_device *rpdev, void *data, int len,
521                           void *priv, u32 src)
523         dev_err(&rpdev->dev, "rpmsg_proto device not designed to receive any messages\n");
524         return 0;
527 static int rpmsg_proto_probe(struct rpmsg_device *rpdev)
529         struct device *dev = &rpdev->dev;
530         int ret, dst = rpdev->dst, id;
531         struct rpmsg_device *vrp_dev;
533         if (WARN_ON(dst == RPMSG_ADDR_ANY))
534                 return -EINVAL;
536         id = rpmsg_sock_get_proc_id(rpdev);
537         if (id < 0)
538                 return -EINVAL;
540         mutex_lock(&rpmsg_channels_lock);
542         /* are we exposing a rpmsg proto device for this remote processor yet?
543          * If not, associate id/device for later lookup in rpmsg_sock_bind().
544          * Multiple devices per remote processor are not supported.
545          */
546         vrp_dev = radix_tree_lookup(&rpmsg_channels, id);
547         if (!vrp_dev) {
548                 ret = radix_tree_insert(&rpmsg_channels, id, rpdev);
549                 if (ret) {
550                         dev_err(dev, "radix_tree_insert failed: %d\n", ret);
551                         goto out;
552                 }
553         } else {
554                 ret = -ENODEV;
555                 dev_err(dev, "multiple rpmsg-proto devices from the same rproc is not supported.\n");
556                 goto out;
557         }
559 out:
560         mutex_unlock(&rpmsg_channels_lock);
562         return ret;
565 static void rpmsg_proto_remove(struct rpmsg_device *rpdev)
567         struct device *dev = &rpdev->dev;
568         int id, dst = rpdev->dst;
569         struct rpmsg_device *vrp_dev;
571         if (dst == RPMSG_ADDR_ANY)
572                 return;
574         id = rpmsg_sock_get_proc_id(rpdev);
576         mutex_lock(&rpmsg_channels_lock);
578         vrp_dev = radix_tree_lookup(&rpmsg_channels, id);
579         if (!vrp_dev) {
580                 dev_err(dev, "can't find rpmsg device for rproc %d\n", id);
581                 goto out;
582         }
583         if (vrp_dev != rpdev)
584                 dev_err(dev, "can't match the stored rpdev for rproc %d\n", id);
586         if (!radix_tree_delete(&rpmsg_channels, id))
587                 dev_err(dev, "failed to delete rpdev for rproc %d\n", id);
589 out:
590         mutex_unlock(&rpmsg_channels_lock);
593 static struct rpmsg_device_id rpmsg_proto_id_table[] = {
594         { .name = "rpmsg-proto" },
595         { },
596 };
597 MODULE_DEVICE_TABLE(rpmsg, rpmsg_proto_id_table);
599 static struct rpmsg_driver rpmsg_proto_driver = {
600         .drv.name       = KBUILD_MODNAME,
601         .id_table       = rpmsg_proto_id_table,
602         .probe          = rpmsg_proto_probe,
603         .callback       = rpmsg_proto_cb,
604         .remove         = rpmsg_proto_remove,
605 };
607 static int __init rpmsg_proto_init(void)
609         int ret;
611         ret = proto_register(&rpmsg_proto, 0);
612         if (ret) {
613                 pr_err("proto_register failed: %d\n", ret);
614                 return ret;
615         }
617         ret = sock_register(&rpmsg_proto_family);
618         if (ret) {
619                 pr_err("sock_register failed: %d\n", ret);
620                 goto proto_unreg;
621         }
623         ret = register_rpmsg_driver(&rpmsg_proto_driver);
624         if (ret) {
625                 pr_err("register_rpmsg_driver failed: %d\n", ret);
626                 goto sock_unreg;
627         }
629         return 0;
631 sock_unreg:
632         sock_unregister(PF_RPMSG);
633 proto_unreg:
634         proto_unregister(&rpmsg_proto);
635         return ret;
638 static void __exit rpmsg_proto_exit(void)
640         unregister_rpmsg_driver(&rpmsg_proto_driver);
641         sock_unregister(PF_RPMSG);
642         proto_unregister(&rpmsg_proto);
645 module_init(rpmsg_proto_init);
646 module_exit(rpmsg_proto_exit);
648 MODULE_DESCRIPTION("Remote processor messaging protocol");
649 MODULE_LICENSE("GPL v2");
650 MODULE_ALIAS("rpmsg:rpmsg-proto");
651 MODULE_ALIAS_NETPROTO(AF_RPMSG);