From a8943376a2c4fd88a6b3932be80d8d6e4f6c86f3 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Sun, 7 Nov 2021 22:55:55 +0100 Subject: [PATCH 1/2] Adapt to mirage/mirage-protocols#28 changes: TCP and UDP layers (direct & socket): type ipinput is gone val handle no longer has a ~listener argument val listen : t -> ~port:int -> (?keepalive) -> callback -> unit val unlisten : t -> ~port:int -> unit On the code level: the listener hash tables moved from stack-* to tcp and udp (tcp/flow; udp/udp; tcp{v4,v6.v4v6}_socket; udp{v4,v6,v4v6}_socket). This moved quite some code in the socket stack from the tcpip_stack_socket to the respective tcp/udp layers. --- src/stack-direct/tcpip_stack_direct.ml | 102 ++------- src/stack-direct/tcpip_stack_direct.mli | 12 -- src/stack-unix/dune | 12 +- src/stack-unix/tcp_socket.ml | 15 +- src/stack-unix/tcpip_stack_socket.ml | 276 ++---------------------- src/stack-unix/tcpv4_socket.ml | 63 +++++- src/stack-unix/tcpv4_socket.mli | 3 +- src/stack-unix/tcpv4v6_socket.ml | 89 +++++++- src/stack-unix/tcpv4v6_socket.mli | 3 +- src/stack-unix/tcpv6_socket.ml | 69 +++++- src/stack-unix/tcpv6_socket.mli | 3 +- src/stack-unix/udpv4_socket.ml | 57 ++++- src/stack-unix/udpv4v6_socket.ml | 64 +++++- src/stack-unix/udpv6_socket.ml | 56 ++++- src/tcp/flow.ml | 33 +-- src/tcp/flow.mli | 1 - src/udp/udp.ml | 19 +- src/udp/udp.mli | 1 - tcpip.opam | 2 +- 19 files changed, 455 insertions(+), 425 deletions(-) diff --git a/src/stack-direct/tcpip_stack_direct.ml b/src/stack-direct/tcpip_stack_direct.ml index 956294e7f..1c70e1e94 100644 --- a/src/stack-direct/tcpip_stack_direct.ml +++ b/src/stack-direct/tcpip_stack_direct.ml @@ -19,14 +19,11 @@ open Lwt.Infix let src = Logs.Src.create "tcpip-stack-direct" ~doc:"Pure OCaml TCP/IP stack" module Log = (val Logs.src_log src : Logs.LOG) -type direct_ipv4_input = src:Ipaddr.V4.t -> dst:Ipaddr.V4.t -> Cstruct.t -> unit Lwt.t module type UDPV4_DIRECT = Mirage_protocols.UDP with type ipaddr = Ipaddr.V4.t - and type ipinput = direct_ipv4_input module type TCPV4_DIRECT = Mirage_protocols.TCP with type ipaddr = Ipaddr.V4.t - and type ipinput = direct_ipv4_input module Make (Time : Mirage_time.S) @@ -51,8 +48,6 @@ module Make icmpv4: Icmpv4.t; udpv4 : Udpv4.t; tcpv4 : Tcpv4.t; - udpv4_listeners: (int, Udpv4.callback) Hashtbl.t; - tcpv4_listeners: (int, Tcpv4.listener) Hashtbl.t; mutable task : unit Lwt.t option; } @@ -64,26 +59,11 @@ module Make let udpv4 { udpv4; _ } = udpv4 let ipv4 { ipv4; _ } = ipv4 - let err_invalid_port p = Printf.sprintf "invalid port number (%d)" p - let listen_udpv4 t ~port callback = - if port < 0 || port > 65535 - then raise (Invalid_argument (err_invalid_port port)) - else Hashtbl.replace t.udpv4_listeners port callback - + Udpv4.listen t.udpv4 ~port callback let listen_tcpv4 ?keepalive t ~port process = - if port < 0 || port > 65535 - then raise (Invalid_argument (err_invalid_port port)) - else Hashtbl.replace t.tcpv4_listeners port { Tcpv4.process; keepalive } - - let udpv4_listeners t ~dst_port = - try Some (Hashtbl.find t.udpv4_listeners dst_port) - with Not_found -> None - - let tcpv4_listeners t dst_port = - try Some (Hashtbl.find t.tcpv4_listeners dst_port) - with Not_found -> None + Tcpv4.listen t.tcpv4 ~port ?keepalive process let listen t = Lwt.catch (fun () -> @@ -92,10 +72,8 @@ module Make ~arpv4:(Arpv4.input t.arpv4) ~ipv4:( Ipv4.input - ~tcp:(Tcpv4.input t.tcpv4 - ~listeners:(tcpv4_listeners t)) - ~udp:(Udpv4.input t.udpv4 - ~listeners:(udpv4_listeners t)) + ~tcp:(Tcpv4.input t.tcpv4) + ~udp:(Udpv4.input t.udpv4) ~default:(fun ~proto ~src ~dst buf -> match proto with | 1 -> Icmpv4.input t.icmpv4 ~src ~dst buf @@ -127,10 +105,7 @@ module Make | e -> Lwt.fail e) let connect netif ethif arpv4 ipv4 icmpv4 udpv4 tcpv4 = - let udpv4_listeners = Hashtbl.create 7 in - let tcpv4_listeners = Hashtbl.create 7 in - let t = { netif; ethif; arpv4; ipv4; icmpv4; tcpv4; udpv4; - udpv4_listeners; tcpv4_listeners; task = None } in + let t = { netif; ethif; arpv4; ipv4; icmpv4; tcpv4; udpv4; task = None } in Log.info (fun f -> f "stack assembled: %a" pp t); Lwt.async (fun () -> let task = listen t in t.task <- Some task; task); Lwt.return t @@ -141,14 +116,11 @@ module Make Lwt.return_unit end -type direct_ipv6_input = src:Ipaddr.V6.t -> dst:Ipaddr.V6.t -> Cstruct.t -> unit Lwt.t module type UDPV6_DIRECT = Mirage_protocols.UDP with type ipaddr = Ipaddr.V6.t - and type ipinput = direct_ipv6_input module type TCPV6_DIRECT = Mirage_protocols.TCP with type ipaddr = Ipaddr.V6.t - and type ipinput = direct_ipv6_input module MakeV6 (Time : Mirage_time.S) @@ -169,8 +141,6 @@ module MakeV6 ipv6 : Ipv6.t; udpv6 : Udpv6.t; tcpv6 : Tcpv6.t; - udpv6_listeners: (int, Udpv6.callback) Hashtbl.t; - tcpv6_listeners: (int, Tcpv6.listener) Hashtbl.t; mutable task : unit Lwt.t option; } @@ -182,25 +152,11 @@ module MakeV6 let udp { udpv6; _ } = udpv6 let ip { ipv6; _ } = ipv6 - let err_invalid_port p = Printf.sprintf "invalid port number (%d)" p - let listen_udp t ~port callback = - if port < 0 || port > 65535 - then raise (Invalid_argument (err_invalid_port port)) - else Hashtbl.replace t.udpv6_listeners port callback + Udpv6.listen t.udpv6 ~port callback let listen_tcp ?keepalive t ~port process = - if port < 0 || port > 65535 - then raise (Invalid_argument (err_invalid_port port)) - else Hashtbl.replace t.tcpv6_listeners port { Tcpv6.process; keepalive } - - let udpv6_listeners t ~dst_port = - try Some (Hashtbl.find t.udpv6_listeners dst_port) - with Not_found -> None - - let tcpv6_listeners t dst_port = - try Some (Hashtbl.find t.tcpv6_listeners dst_port) - with Not_found -> None + Tcpv6.listen t.tcpv6 ~port ?keepalive process let listen t = Lwt.catch (fun () -> @@ -210,10 +166,8 @@ module MakeV6 ~ipv4:(fun _ -> Lwt.return_unit) ~ipv6:( Ipv6.input - ~tcp:(Tcpv6.input t.tcpv6 - ~listeners:(tcpv6_listeners t)) - ~udp:(Udpv6.input t.udpv6 - ~listeners:(udpv6_listeners t)) + ~tcp:(Tcpv6.input t.tcpv6) + ~udp:(Udpv6.input t.udpv6) ~default:(fun ~proto:_ ~src:_ ~dst:_ _ -> Lwt.return_unit) t.ipv6) t.ethif @@ -241,10 +195,7 @@ module MakeV6 | e -> Lwt.fail e) let connect netif ethif ipv6 udpv6 tcpv6 = - let udpv6_listeners = Hashtbl.create 7 in - let tcpv6_listeners = Hashtbl.create 7 in - let t = { netif; ethif; ipv6; tcpv6; udpv6; - udpv6_listeners; tcpv6_listeners; task = None } in + let t = { netif; ethif; ipv6; tcpv6; udpv6; task = None } in Log.info (fun f -> f "stack assembled: %a" pp t); Lwt.async (fun () -> let task = listen t in t.task <- Some task; task); Lwt.return t @@ -256,15 +207,11 @@ module MakeV6 end -type direct_ipv4v6_input = src:Ipaddr.t -> dst:Ipaddr.t -> Cstruct.t -> unit Lwt.t - module type UDPV4V6_DIRECT = Mirage_protocols.UDP with type ipaddr = Ipaddr.t - and type ipinput = direct_ipv4v6_input module type TCPV4V6_DIRECT = Mirage_protocols.TCP with type ipaddr = Ipaddr.t - and type ipinput = direct_ipv4v6_input module IPV4V6 (Ipv4 : Mirage_protocols.IPV4) (Ipv6 : Mirage_protocols.IPV6) = struct @@ -404,8 +351,6 @@ module MakeV4V6 ip : IP.t; udp : Udp.t; tcp : Tcp.t; - udp_listeners: (int, Udp.callback) Hashtbl.t; - tcp_listeners: (int, Tcp.listener) Hashtbl.t; mutable task : unit Lwt.t option; } @@ -417,31 +362,17 @@ module MakeV4V6 let udp { udp; _ } = udp let ip { ip; _ } = ip - let err_invalid_port p = Printf.sprintf "invalid port number (%d)" p - let listen_udp t ~port callback = - if port < 0 || port > 65535 - then raise (Invalid_argument (err_invalid_port port)) - else Hashtbl.replace t.udp_listeners port callback + Udp.listen t.udp ~port callback let listen_tcp ?keepalive t ~port process = - if port < 0 || port > 65535 - then raise (Invalid_argument (err_invalid_port port)) - else Hashtbl.replace t.tcp_listeners port { Tcp.process; keepalive } - - let udp_listeners t ~dst_port = - try Some (Hashtbl.find t.udp_listeners dst_port) - with Not_found -> None - - let tcp_listeners t dst_port = - try Some (Hashtbl.find t.tcp_listeners dst_port) - with Not_found -> None + Tcp.listen t.tcp ~port ?keepalive process let listen t = Lwt.catch (fun () -> Log.debug (fun f -> f "Establishing or updating listener for stack %a" pp t); - let tcp = Tcp.input t.tcp ~listeners:(tcp_listeners t) - and udp = Udp.input t.udp ~listeners:(udp_listeners t) + let tcp = Tcp.input t.tcp + and udp = Udp.input t.udp and default ~proto ~src ~dst buf = match proto, src, dst with | 1, Ipaddr.V4 src, Ipaddr.V4 dst -> Icmpv4.input t.icmpv4 ~src ~dst buf @@ -476,10 +407,7 @@ module MakeV4V6 | e -> Lwt.fail e) let connect netif ethif arpv4 ip icmpv4 udp tcp = - let udp_listeners = Hashtbl.create 7 in - let tcp_listeners = Hashtbl.create 7 in - let t = { netif; ethif; arpv4; ip; icmpv4; tcp; udp; - udp_listeners; tcp_listeners; task = None } in + let t = { netif; ethif; arpv4; ip; icmpv4; tcp; udp; task = None } in Log.info (fun f -> f "stack assembled: %a" pp t); Lwt.async (fun () -> let task = listen t in t.task <- Some task; task); Lwt.return t diff --git a/src/stack-direct/tcpip_stack_direct.mli b/src/stack-direct/tcpip_stack_direct.mli index 4f8da4f3c..0a11fa5b2 100644 --- a/src/stack-direct/tcpip_stack_direct.mli +++ b/src/stack-direct/tcpip_stack_direct.mli @@ -14,15 +14,11 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. *) -type direct_ipv4_input = src:Ipaddr.V4.t -> dst:Ipaddr.V4.t -> Cstruct.t -> unit Lwt.t - module type UDPV4_DIRECT = Mirage_protocols.UDP with type ipaddr = Ipaddr.V4.t - and type ipinput = direct_ipv4_input module type TCPV4_DIRECT = Mirage_protocols.TCP with type ipaddr = Ipaddr.V4.t - and type ipinput = direct_ipv4_input module Make (Time : Mirage_time.S) @@ -48,15 +44,11 @@ module Make connections, they will be able to do so. *) end -type direct_ipv6_input = src:Ipaddr.V6.t -> dst:Ipaddr.V6.t -> Cstruct.t -> unit Lwt.t - module type UDPV6_DIRECT = Mirage_protocols.UDP with type ipaddr = Ipaddr.V6.t - and type ipinput = direct_ipv6_input module type TCPV6_DIRECT = Mirage_protocols.TCP with type ipaddr = Ipaddr.V6.t - and type ipinput = direct_ipv6_input module MakeV6 (Time : Mirage_time.S) @@ -79,15 +71,11 @@ module MakeV6 they will be able to do so. *) end -type direct_ipv4v6_input = src:Ipaddr.t -> dst:Ipaddr.t -> Cstruct.t -> unit Lwt.t - module type UDPV4V6_DIRECT = Mirage_protocols.UDP with type ipaddr = Ipaddr.t - and type ipinput = direct_ipv4v6_input module type TCPV4V6_DIRECT = Mirage_protocols.TCP with type ipaddr = Ipaddr.t - and type ipinput = direct_ipv4v6_input module IPV4V6 (Ipv4 : Mirage_protocols.IPV4) (Ipv6 : Mirage_protocols.IPV6) : sig include Mirage_protocols.IP with type ipaddr = Ipaddr.t diff --git a/src/stack-unix/dune b/src/stack-unix/dune index 035eadb9a..f954ff651 100644 --- a/src/stack-unix/dune +++ b/src/stack-unix/dune @@ -15,7 +15,7 @@ (wrapped false) (instrumentation (backend bisect_ppx)) - (libraries lwt.unix ipaddr.unix cstruct-lwt fmt mirage-protocols)) + (libraries lwt.unix ipaddr.unix cstruct-lwt fmt mirage-protocols logs)) (library (name udpv6_socket) @@ -24,7 +24,7 @@ (wrapped false) (instrumentation (backend bisect_ppx)) - (libraries lwt.unix ipaddr.unix cstruct-lwt fmt mirage-protocols)) + (libraries lwt.unix ipaddr.unix cstruct-lwt fmt mirage-protocols logs)) (library (name udpv4v6_socket) @@ -33,7 +33,7 @@ (wrapped false) (instrumentation (backend bisect_ppx)) - (libraries lwt.unix ipaddr.unix cstruct-lwt fmt mirage-protocols)) + (libraries lwt.unix ipaddr.unix cstruct-lwt fmt mirage-protocols logs)) (library (name tcp_socket_options) @@ -56,7 +56,7 @@ (instrumentation (backend bisect_ppx)) (libraries lwt.unix ipaddr.unix cstruct-lwt fmt mirage-protocols - tcp_socket_options)) + tcp_socket_options logs)) (library (name tcpv6_socket) @@ -66,7 +66,7 @@ (instrumentation (backend bisect_ppx)) (libraries lwt.unix ipaddr.unix cstruct-lwt fmt mirage-protocols - tcpv4_socket tcp_socket_options)) + tcpv4_socket tcp_socket_options logs)) (library (name tcpv4v6_socket) @@ -76,7 +76,7 @@ (instrumentation (backend bisect_ppx)) (libraries lwt.unix ipaddr.unix cstruct-lwt fmt mirage-protocols - tcpv4_socket tcp_socket_options)) + tcpv4_socket tcp_socket_options logs)) (library (name tcpip_stack_socket) diff --git a/src/stack-unix/tcp_socket.ml b/src/stack-unix/tcp_socket.ml index 510a41597..b37a9d26d 100644 --- a/src/stack-unix/tcp_socket.ml +++ b/src/stack-unix/tcp_socket.ml @@ -11,6 +11,10 @@ let pp_write_error ppf = function | #Mirage_protocols.Tcp.write_error as e -> Mirage_protocols.Tcp.pp_write_error ppf e | `Exn e -> Fmt.exn ppf e +let ignore_canceled = function + | Lwt.Canceled -> Lwt.return_unit + | exn -> raise exn + let disconnect _ = return_unit @@ -61,13 +65,4 @@ let close fd = | Unix.Unix_error (Unix.EBADF, _, _) -> Lwt.return_unit | e -> Lwt.fail e) -type listener = { - process: Lwt_unix.file_descr -> unit Lwt.t; - keepalive: Mirage_protocols.Keepalive.t option; -} - -(* FIXME: how does this work at all ?? *) -let input _t ~listeners:_ = - (* TODO terminate when signalled by disconnect *) - let t, _ = Lwt.task () in - t +let input _t ~src:_ ~dst:_ _buf = Lwt.return_unit diff --git a/src/stack-unix/tcpip_stack_socket.ml b/src/stack-unix/tcpip_stack_socket.ml index f706e722e..66ee4ee24 100644 --- a/src/stack-unix/tcpip_stack_socket.ml +++ b/src/stack-unix/tcpip_stack_socket.ml @@ -19,12 +19,6 @@ open Lwt.Infix let src = Logs.Src.create "tcpip-stack-socket" ~doc:"Platform's native TCP/IP stack" module Log = (val Logs.src_log src : Logs.LOG) -let ignore_canceled = function - | Lwt.Canceled -> Lwt.return_unit - | exn -> raise exn - -let safe_close = Tcp_socket.close - module V4 = struct module TCPV4 = Tcpv4_socket module UDPV4 = Udpv4_socket @@ -35,98 +29,28 @@ module V4 = struct tcpv4 : TCPV4.t; stop : unit Lwt.u; switched_off : unit Lwt.t; - mutable active_fds : Lwt_unix.file_descr list; } let udpv4 { udpv4; _ } = udpv4 let tcpv4 { tcpv4; _ } = tcpv4 let ipv4 _ = () - let err_invalid_port p = Printf.sprintf "invalid port number (%d)" p - let listen_udpv4 t ~port callback = - if port < 0 || port > 65535 then - raise (Invalid_argument (err_invalid_port port)) - else - (* FIXME: we should not ignore the result *) - Lwt.async (fun () -> - UDPV4.get_udpv4_listening_fd t.udpv4 port >>= fun (_, fd) -> - let buf = Cstruct.create 4096 in - let rec loop () = - if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; - Lwt.catch (fun () -> - Lwt_cstruct.recvfrom fd buf [] >>= fun (len, sa) -> - let buf = Cstruct.sub buf 0 len in - (match sa with - | Lwt_unix.ADDR_INET (addr, src_port) -> - let src = Ipaddr_unix.V4.of_inet_addr_exn addr in - let dst = Ipaddr.V4.any in (* TODO *) - callback ~src ~dst ~src_port buf - | _ -> Lwt.return_unit) >|= fun () -> - `Continue) - (function - | Unix.Unix_error (Unix.EBADF, _, _) -> - Log.warn (fun m -> m "error bad file descriptor in accept") ; - Lwt.return `Stop - | exn -> - Log.warn (fun m -> m "exception %s in recvfrom" (Printexc.to_string exn)) ; - Lwt.return `Continue) >>= function - | `Continue -> loop () - | `Stop -> Lwt.return_unit - in - Lwt.catch loop ignore_canceled >>= fun () -> - safe_close fd) + UDPV4.listen t.udpv4 ~port callback let listen_tcpv4 ?keepalive t ~port callback = - if port < 0 || port > 65535 then - raise (Invalid_argument (err_invalid_port port)) - else - let fd = Lwt_unix.(socket PF_INET SOCK_STREAM 0) in - t.active_fds <- fd :: t.active_fds; - Lwt_unix.setsockopt fd Lwt_unix.SO_REUSEADDR true; - Unix.bind (Lwt_unix.unix_file_descr fd) (Unix.ADDR_INET (t.udpv4.interface, port)); - Lwt_unix.listen fd 10; - (* FIXME: we should not ignore the result *) - Lwt.async (fun () -> - (* TODO cancellation *) - let rec loop () = - if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; - Lwt.catch (fun () -> - Lwt_unix.accept fd >|= fun (afd, _) -> - t.active_fds <- afd :: t.active_fds; - (match keepalive with - | None -> () - | Some { Mirage_protocols.Keepalive.after; interval; probes } -> - Tcp_socket_options.enable_keepalive ~fd:afd ~after ~interval ~probes); - Lwt.async - (fun () -> - Lwt.catch - (fun () -> callback afd) - (fun exn -> - Log.warn (fun m -> m "error %s in callback" (Printexc.to_string exn)) ; - safe_close afd)); - `Continue) - (function - | Unix.Unix_error (Unix.EBADF, _, _) -> - Log.warn (fun m -> m "error bad file descriptor in accept") ; - Lwt.return `Stop - | exn -> - Log.warn (fun m -> m "error %s in accept" (Printexc.to_string exn)) ; - Lwt.return `Continue) >>= function - | `Continue -> loop () - | `Stop -> Lwt.return_unit - in - Lwt.catch loop ignore_canceled >>= fun () -> safe_close fd) + TCPV4.listen t.tcpv4 ~port ?keepalive callback let listen t = t.switched_off let connect udpv4 tcpv4 = Log.info (fun f -> f "IPv4 socket stack: connect"); let switched_off, stop = Lwt.wait () in - Lwt.return { tcpv4; udpv4; stop; switched_off; active_fds = []; } + TCPV4.set_switched_off tcpv4 switched_off; + UDPV4.set_switched_off udpv4 switched_off; + Lwt.return { tcpv4; udpv4; stop; switched_off } let disconnect t = - Lwt_list.iter_p safe_close t.active_fds >>= fun () -> TCPV4.disconnect t.tcpv4 >>= fun () -> UDPV4.disconnect t.udpv4 >|= fun () -> Lwt.wakeup_later t.stop () @@ -142,98 +66,28 @@ module V6 = struct tcp : TCP.t; stop : unit Lwt.u; switched_off : unit Lwt.t; - mutable active_fds : Lwt_unix.file_descr list; } let udp { udp; _ } = udp let tcp { tcp; _ } = tcp let ip _ = () - let err_invalid_port p = Printf.sprintf "invalid port number (%d)" p - let listen_udp t ~port callback = - if port < 0 || port > 65535 then - raise (Invalid_argument (err_invalid_port port)) - else - (* FIXME: we should not ignore the result *) - Lwt.async (fun () -> - UDP.get_udpv6_listening_fd t.udp port >>= fun (_, fd) -> - let buf = Cstruct.create 4096 in - let rec loop () = - if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; - Lwt.catch (fun () -> - Lwt_cstruct.recvfrom fd buf [] >>= fun (len, sa) -> - let buf = Cstruct.sub buf 0 len in - (match sa with - | Lwt_unix.ADDR_INET (addr, src_port) -> - let src = Ipaddr_unix.V6.of_inet_addr_exn addr in - let dst = Ipaddr.V6.unspecified in (* TODO *) - callback ~src ~dst ~src_port buf - | _ -> Lwt.return_unit) >|= fun () -> - `Continue) - (function - | Unix.Unix_error (Unix.EBADF, _, _) -> - Log.warn (fun m -> m "error bad file descriptor in accept") ; - Lwt.return `Stop - | exn -> - Log.warn (fun m -> m "exception %s in recvfrom" (Printexc.to_string exn)) ; - Lwt.return `Continue) >>= function - | `Continue -> loop () - | `Stop -> Lwt.return_unit - in - Lwt.catch loop ignore_canceled >>= fun () -> safe_close fd) + UDP.listen t.udp ~port callback let listen_tcp ?keepalive t ~port callback = - if port < 0 || port > 65535 then - raise (Invalid_argument (err_invalid_port port)) - else - let fd = Lwt_unix.(socket PF_INET6 SOCK_STREAM 0) in - t.active_fds <- fd :: t.active_fds; - Lwt_unix.setsockopt fd Lwt_unix.SO_REUSEADDR true; - Lwt_unix.(setsockopt fd IPV6_ONLY true); - Unix.bind (Lwt_unix.unix_file_descr fd) (Lwt_unix.ADDR_INET (t.udp.interface, port)); - Lwt_unix.listen fd 10; - (* FIXME: we should not ignore the result *) - Lwt.async (fun () -> - (* TODO cancellation *) - let rec loop () = - if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; - Lwt.catch (fun () -> - Lwt_unix.accept fd >|= fun (afd, _) -> - t.active_fds <- afd :: t.active_fds; - (match keepalive with - | None -> () - | Some { Mirage_protocols.Keepalive.after; interval; probes } -> - Tcp_socket_options.enable_keepalive ~fd:afd ~after ~interval ~probes); - Lwt.async - (fun () -> - Lwt.catch - (fun () -> callback afd) - (fun exn -> - Log.warn (fun m -> m "error %s in callback" (Printexc.to_string exn)) ; - safe_close afd)); - `Continue) - (function - | Unix.Unix_error (Unix.EBADF, _, _) -> - Log.warn (fun m -> m "error bad file descriptor in accept") ; - Lwt.return `Stop - | exn -> - Log.warn (fun m -> m "error %s in accept" (Printexc.to_string exn)) ; - Lwt.return `Continue) >>= function - | `Continue -> loop () - | `Stop -> Lwt.return_unit - in - Lwt.catch loop ignore_canceled >>= fun () -> safe_close fd) + TCP.listen t.tcp ~port ?keepalive callback let listen t = t.switched_off let connect udp tcp = Log.info (fun f -> f "IPv6 socket stack: connect"); let switched_off, stop = Lwt.wait () in - Lwt.return { tcp; udp; stop; switched_off; active_fds = []; } + UDP.set_switched_off udp switched_off; + TCP.set_switched_off tcp switched_off; + Lwt.return { tcp; udp; stop; switched_off } let disconnect t = - Lwt_list.iter_p safe_close t.active_fds >>= fun () -> TCP.disconnect t.tcp >>= fun () -> UDP.disconnect t.udp >|= fun () -> Lwt.wakeup_later t.stop () @@ -249,128 +103,28 @@ module V4V6 = struct tcp : TCP.t; stop : unit Lwt.u; switched_off : unit Lwt.t; - mutable active_fds : Lwt_unix.file_descr list; } let udp { udp; _ } = udp let tcp { tcp; _ } = tcp let ip _ = () - let err_invalid_port p = Printf.sprintf "invalid port number (%d)" p - let listen_udp t ~port callback = - if port < 0 || port > 65535 then - raise (Invalid_argument (err_invalid_port port)) - else - (* FIXME: we should not ignore the result *) - Lwt.async (fun () -> - UDP.get_udpv4v6_listening_fd t.udp port >|= fun (_, fds) -> - t.active_fds <- fds @ t.active_fds; - List.iter (fun fd -> - Lwt.async (fun () -> - let buf = Cstruct.create 4096 in - let rec loop () = - if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; - Lwt.catch (fun () -> - Lwt_cstruct.recvfrom fd buf [] >>= fun (len, sa) -> - let buf = Cstruct.sub buf 0 len in - (match sa with - | Lwt_unix.ADDR_INET (addr, src_port) -> - let src = Ipaddr_unix.of_inet_addr addr in - let src = - match Ipaddr.to_v4 src with - | None -> src - | Some v4 -> Ipaddr.V4 v4 - in - let dst = Ipaddr.(V6 V6.unspecified) in (* TODO *) - callback ~src ~dst ~src_port buf - | _ -> Lwt.return_unit) >|= fun () -> - `Continue) - (function - | Unix.Unix_error (Unix.EBADF, _, _) -> - Log.warn (fun m -> m "error bad file descriptor in accept") ; - Lwt.return `Stop - | exn -> - Log.warn (fun m -> m "exception %s in recvfrom" (Printexc.to_string exn)) ; - Lwt.return `Continue) >>= function - | `Continue -> loop () - | `Stop -> Lwt.return_unit - in - Lwt.catch loop ignore_canceled >>= fun () -> safe_close fd)) fds) + UDP.listen t.udp ~port callback let listen_tcp ?keepalive t ~port callback = - if port < 0 || port > 65535 then - raise (Invalid_argument (err_invalid_port port)) - else - let fds = - match t.udp.interface with - | `Any -> - let fd = Lwt_unix.(socket PF_INET6 SOCK_STREAM 0) in - Lwt_unix.(setsockopt fd SO_REUSEADDR true); - Lwt_unix.(setsockopt fd IPV6_ONLY false); - [ (fd, Lwt_unix.ADDR_INET (UDP.any_v6, port)) ] - | `Ip (v4, v6) -> - let fd = Lwt_unix.(socket PF_INET SOCK_STREAM 0) in - Lwt_unix.(setsockopt fd SO_REUSEADDR true); - let fd' = Lwt_unix.(socket PF_INET6 SOCK_STREAM 0) in - Lwt_unix.(setsockopt fd' SO_REUSEADDR true); - Lwt_unix.(setsockopt fd' IPV6_ONLY true); - [ (fd, Lwt_unix.ADDR_INET (v4, port)) ; (fd', Lwt_unix.ADDR_INET (v6, port)) ] - | `V4_only ip -> - let fd = Lwt_unix.(socket PF_INET SOCK_STREAM 0) in - Lwt_unix.setsockopt fd Lwt_unix.SO_REUSEADDR true; - [ (fd, Lwt_unix.ADDR_INET (ip, port)) ] - | `V6_only ip -> - let fd = Lwt_unix.(socket PF_INET6 SOCK_STREAM 0) in - Lwt_unix.(setsockopt fd SO_REUSEADDR true); - Lwt_unix.(setsockopt fd IPV6_ONLY true); - [ (fd, Lwt_unix.ADDR_INET (ip, port)) ] - in - t.active_fds <- List.map fst fds @ t.active_fds; - List.iter (fun (fd, addr) -> - Unix.bind (Lwt_unix.unix_file_descr fd) addr; - Lwt_unix.listen fd 10; - (* FIXME: we should not ignore the result *) - Lwt.async (fun () -> - (* TODO cancellation *) - let rec loop () = - if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; - Lwt.catch (fun () -> - Lwt_unix.accept fd >|= fun (afd, _) -> - t.active_fds <- afd :: t.active_fds; - (match keepalive with - | None -> () - | Some { Mirage_protocols.Keepalive.after; interval; probes } -> - Tcp_socket_options.enable_keepalive ~fd:afd ~after ~interval ~probes); - Lwt.async - (fun () -> - Lwt.catch - (fun () -> callback afd) - (fun exn -> - Log.warn (fun m -> m "error %s in callback" (Printexc.to_string exn)) ; - safe_close afd)); - `Continue) - (function - | Unix.Unix_error (Unix.EBADF, _, _) -> - Log.warn (fun m -> m "error bad file descriptor in accept") ; - Lwt.return `Stop - | exn -> - Log.warn (fun m -> m "error %s in accept" (Printexc.to_string exn)) ; - Lwt.return `Continue) >>= function - | `Continue -> loop () - | `Stop -> Lwt.return_unit - in - Lwt.catch loop ignore_canceled >>= fun () -> safe_close fd)) fds + TCP.listen t.tcp ~port ?keepalive callback let listen t = t.switched_off let connect udp tcp = Log.info (fun f -> f "Dual IPv4 and IPv6 socket stack: connect"); let switched_off, stop = Lwt.wait () in - Lwt.return { tcp; udp; stop; switched_off; active_fds = [] } + UDP.set_switched_off udp switched_off; + TCP.set_switched_off tcp switched_off; + Lwt.return { tcp; udp; stop; switched_off } let disconnect t = - Lwt_list.iter_p safe_close t.active_fds >>= fun () -> TCP.disconnect t.tcp >>= fun () -> UDP.disconnect t.udp >|= fun () -> Lwt.wakeup_later t.stop () diff --git a/src/stack-unix/tcpv4_socket.ml b/src/stack-unix/tcpv4_socket.ml index e55b4caf4..9f8f472c3 100644 --- a/src/stack-unix/tcpv4_socket.ml +++ b/src/stack-unix/tcpv4_socket.ml @@ -14,15 +14,19 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. *) +let src = Logs.Src.create "tcpv4-socket" ~doc:"TCP socket v4 (platform native)" +module Log = (val Logs.src_log src : Logs.LOG) + open Lwt.Infix type ipaddr = Ipaddr.V4.t type flow = Lwt_unix.file_descr -type ipinput = unit Lwt.t type t = { interface: Unix.inet_addr; (* source ip to bind to *) mutable active_connections : Lwt_unix.file_descr list; + listen_sockets : (int, Lwt_unix.file_descr) Hashtbl.t; + mutable switched_off : unit Lwt.t; } include Tcp_socket @@ -31,10 +35,17 @@ let connect addr = let t = { interface = Ipaddr_unix.V4.to_inet_addr (Ipaddr.V4.Prefix.address addr); active_connections = []; + listen_sockets = Hashtbl.create 7; + switched_off = Lwt.return_unit; } in Lwt.return t -let disconnect t = Lwt_list.iter_p close t.active_connections +let set_switched_off t switched_off = t.switched_off <- switched_off + +let disconnect t = + Lwt_list.iter_p close t.active_connections >>= fun () -> + Lwt_list.iter_p close + (Hashtbl.fold (fun _ fd acc -> fd :: acc) t.listen_sockets []) let dst fd = match Lwt_unix.getpeername fd with @@ -62,3 +73,51 @@ let create_connection ?keepalive t (dst,dst_port) = (fun exn -> close fd >|= fun () -> Error (`Exn exn)) + +let unlisten t ~port = + match Hashtbl.find_opt t.listen_sockets port with + | None -> () + | Some fd -> + Hashtbl.remove t.listen_sockets port; + try Unix.close (Lwt_unix.unix_file_descr fd) with _ -> () + +let listen t ~port ?keepalive callback = + if port < 0 || port > 65535 then + raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port)); + unlisten t ~port; + let fd = Lwt_unix.(socket PF_INET SOCK_STREAM 0) in + Lwt_unix.setsockopt fd Lwt_unix.SO_REUSEADDR true; + Unix.bind (Lwt_unix.unix_file_descr fd) (Unix.ADDR_INET (t.interface, port)); + Hashtbl.replace t.listen_sockets port fd; + Lwt_unix.listen fd 10; + (* FIXME: we should not ignore the result *) + Lwt.async (fun () -> + (* TODO cancellation *) + let rec loop () = + if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; + Lwt.catch (fun () -> + Lwt_unix.accept fd >|= fun (afd, _) -> + t.active_connections <- afd :: t.active_connections; + (match keepalive with + | None -> () + | Some { Mirage_protocols.Keepalive.after; interval; probes } -> + Tcp_socket_options.enable_keepalive ~fd:afd ~after ~interval ~probes); + Lwt.async + (fun () -> + Lwt.catch + (fun () -> callback afd) + (fun exn -> + Log.warn (fun m -> m "error %s in callback" (Printexc.to_string exn)) ; + close afd)); + `Continue) + (function + | Unix.Unix_error (Unix.EBADF, _, _) -> + Log.warn (fun m -> m "error bad file descriptor in accept") ; + Lwt.return `Stop + | exn -> + Log.warn (fun m -> m "error %s in accept" (Printexc.to_string exn)) ; + Lwt.return `Continue) >>= function + | `Continue -> loop () + | `Stop -> Lwt.return_unit + in + Lwt.catch loop ignore_canceled >>= fun () -> close fd) diff --git a/src/stack-unix/tcpv4_socket.mli b/src/stack-unix/tcpv4_socket.mli index 6b5b86bfd..85b0ef377 100644 --- a/src/stack-unix/tcpv4_socket.mli +++ b/src/stack-unix/tcpv4_socket.mli @@ -16,7 +16,6 @@ include Mirage_protocols.TCP with type ipaddr = Ipaddr.V4.t - and type ipinput = unit Lwt.t and type flow = Lwt_unix.file_descr and type error = [ Mirage_protocols.Tcp.error | `Exn of exn ] and type write_error = [ Mirage_protocols.Tcp.write_error | `Exn of exn ] @@ -24,3 +23,5 @@ include Mirage_protocols.TCP val connect : Ipaddr.V4.Prefix.t -> t Lwt.t val disconnect : t -> unit Lwt.t + +val set_switched_off : t -> unit Lwt.t -> unit diff --git a/src/stack-unix/tcpv4v6_socket.ml b/src/stack-unix/tcpv4v6_socket.ml index 9e9e76cef..15dcc2c19 100644 --- a/src/stack-unix/tcpv4v6_socket.ml +++ b/src/stack-unix/tcpv4v6_socket.ml @@ -15,24 +15,31 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. *) +let src = Logs.Src.create "tcpv4v6-socket" ~doc:"TCP socket v4v6 (platform native)" +module Log = (val Logs.src_log src : Logs.LOG) + open Lwt.Infix type ipaddr = Ipaddr.t type flow = Lwt_unix.file_descr -type ipinput = unit Lwt.t type t = { interface: [ `Any | `Ip of Unix.inet_addr * Unix.inet_addr | `V4_only of Unix.inet_addr | `V6_only of Unix.inet_addr ]; (* source ip to bind to *) mutable active_connections : Lwt_unix.file_descr list; + listen_sockets : (int, Lwt_unix.file_descr list) Hashtbl.t; + mutable switched_off : unit Lwt.t; } +let set_switched_off t switched_off = t.switched_off <- switched_off + +let any_v6 = Ipaddr_unix.V6.to_inet_addr Ipaddr.V6.unspecified + include Tcp_socket let connect ~ipv4_only ~ipv6_only ipv4 ipv6 = let interface = let v4 = Ipaddr.V4.Prefix.address ipv4 in let v4_unix = Ipaddr_unix.V4.to_inet_addr v4 in - let any_v6 = Ipaddr_unix.V6.to_inet_addr Ipaddr.V6.unspecified in if ipv4_only then `V4_only v4_unix else if ipv6_only then @@ -50,9 +57,12 @@ let connect ~ipv4_only ~ipv6_only ipv4 ipv6 = else `Ip (v4_unix, Ipaddr_unix.V6.to_inet_addr v6) in - Lwt.return {interface; active_connections = []} + Lwt.return {interface; active_connections = []; listen_sockets = Hashtbl.create 7; switched_off = Lwt.return_unit} -let disconnect t = Lwt_list.iter_p close t.active_connections +let disconnect t = + Lwt_list.iter_p close t.active_connections >>= fun () -> + Lwt_list.iter_p close + (Hashtbl.fold (fun _ fd acc -> fd @ acc) t.listen_sockets []) let dst fd = match Lwt_unix.getpeername fd with @@ -97,3 +107,74 @@ let create_connection ?keepalive t (dst,dst_port) = (fun exn -> close fd >>= fun () -> Lwt.return (Error (`Exn exn))) + +let unlisten t ~port = + match Hashtbl.find_opt t.listen_sockets port with + | None -> () + | Some fds -> + Hashtbl.remove t.listen_sockets port; + try List.iter (fun fd -> Unix.close (Lwt_unix.unix_file_descr fd)) fds with _ -> () + +let listen t ~port ?keepalive callback = + if port < 0 || port > 65535 then + raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port)); + unlisten t ~port; + let fds = + match t.interface with + | `Any -> + let fd = Lwt_unix.(socket PF_INET6 SOCK_STREAM 0) in + Lwt_unix.(setsockopt fd SO_REUSEADDR true); + Lwt_unix.(setsockopt fd IPV6_ONLY false); + [ (fd, Lwt_unix.ADDR_INET (any_v6, port)) ] + | `Ip (v4, v6) -> + let fd = Lwt_unix.(socket PF_INET SOCK_STREAM 0) in + Lwt_unix.(setsockopt fd SO_REUSEADDR true); + let fd' = Lwt_unix.(socket PF_INET6 SOCK_STREAM 0) in + Lwt_unix.(setsockopt fd' SO_REUSEADDR true); + Lwt_unix.(setsockopt fd' IPV6_ONLY true); + [ (fd, Lwt_unix.ADDR_INET (v4, port)) ; (fd', Lwt_unix.ADDR_INET (v6, port)) ] + | `V4_only ip -> + let fd = Lwt_unix.(socket PF_INET SOCK_STREAM 0) in + Lwt_unix.setsockopt fd Lwt_unix.SO_REUSEADDR true; + [ (fd, Lwt_unix.ADDR_INET (ip, port)) ] + | `V6_only ip -> + let fd = Lwt_unix.(socket PF_INET6 SOCK_STREAM 0) in + Lwt_unix.(setsockopt fd SO_REUSEADDR true); + Lwt_unix.(setsockopt fd IPV6_ONLY true); + [ (fd, Lwt_unix.ADDR_INET (ip, port)) ] + in + List.iter (fun (fd, addr) -> + Unix.bind (Lwt_unix.unix_file_descr fd) addr; + Hashtbl.replace t.listen_sockets port (List.map fst fds); + Lwt_unix.listen fd 10; + (* FIXME: we should not ignore the result *) + Lwt.async (fun () -> + (* TODO cancellation *) + let rec loop () = + if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; + Lwt.catch (fun () -> + Lwt_unix.accept fd >|= fun (afd, _) -> + t.active_connections <- afd :: t.active_connections; + (match keepalive with + | None -> () + | Some { Mirage_protocols.Keepalive.after; interval; probes } -> + Tcp_socket_options.enable_keepalive ~fd:afd ~after ~interval ~probes); + Lwt.async + (fun () -> + Lwt.catch + (fun () -> callback afd) + (fun exn -> + Log.warn (fun m -> m "error %s in callback" (Printexc.to_string exn)) ; + close afd)); + `Continue) + (function + | Unix.Unix_error (Unix.EBADF, _, _) -> + Log.warn (fun m -> m "error bad file descriptor in accept") ; + Lwt.return `Stop + | exn -> + Log.warn (fun m -> m "error %s in accept" (Printexc.to_string exn)) ; + Lwt.return `Continue) >>= function + | `Continue -> loop () + | `Stop -> Lwt.return_unit + in + Lwt.catch loop ignore_canceled >>= fun () -> close fd)) fds diff --git a/src/stack-unix/tcpv4v6_socket.mli b/src/stack-unix/tcpv4v6_socket.mli index 0ca1ecfb1..c6f3860bd 100644 --- a/src/stack-unix/tcpv4v6_socket.mli +++ b/src/stack-unix/tcpv4v6_socket.mli @@ -17,9 +17,10 @@ include Mirage_protocols.TCP with type ipaddr = Ipaddr.t - and type ipinput = unit Lwt.t and type flow = Lwt_unix.file_descr and type error = [ Mirage_protocols.Tcp.error | `Exn of exn ] and type write_error = [ Mirage_protocols.Tcp.write_error | `Exn of exn ] val connect : ipv4_only:bool -> ipv6_only:bool -> Ipaddr.V4.Prefix.t -> Ipaddr.V6.Prefix.t option -> t Lwt.t + +val set_switched_off : t -> unit Lwt.t -> unit diff --git a/src/stack-unix/tcpv6_socket.ml b/src/stack-unix/tcpv6_socket.ml index aa17cfd9d..8b0abc320 100644 --- a/src/stack-unix/tcpv6_socket.ml +++ b/src/stack-unix/tcpv6_socket.ml @@ -15,17 +15,23 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. *) +let src = Logs.Src.create "tcpv6-socket" ~doc:"TCP socket v6 (platform native)" +module Log = (val Logs.src_log src : Logs.LOG) + open Lwt.Infix type ipaddr = Ipaddr.V6.t type flow = Lwt_unix.file_descr -type ipinput = unit Lwt.t type t = { interface: Unix.inet_addr; (* source ip to bind to *) mutable active_connections : Lwt_unix.file_descr list; + listen_sockets : (int, Lwt_unix.file_descr) Hashtbl.t; + mutable switched_off : unit Lwt.t; } +let set_switched_off t switched_off = t.switched_off <- switched_off + include Tcp_socket let connect addr = @@ -34,9 +40,17 @@ let connect addr = | None -> Ipaddr.V6.unspecified | Some ip -> Ipaddr.V6.Prefix.address ip in - Lwt.return { interface = Ipaddr_unix.V6.to_inet_addr ip; active_connections = []; } + Lwt.return { + interface = Ipaddr_unix.V6.to_inet_addr ip; + active_connections = []; + listen_sockets = Hashtbl.create 7; + switched_off = Lwt.return_unit + } -let disconnect t = Lwt_list.iter_p close t.active_connections +let disconnect t = + Lwt_list.iter_p close t.active_connections >>= fun () -> + Lwt_list.iter_p close + (Hashtbl.fold (fun _ fd acc -> fd :: acc) t.listen_sockets []) let dst fd = match Lwt_unix.getpeername fd with @@ -65,3 +79,52 @@ let create_connection ?keepalive t (dst,dst_port) = (fun exn -> close fd >>= fun () -> Lwt.return (Error (`Exn exn))) + +let unlisten t ~port = + match Hashtbl.find_opt t.listen_sockets port with + | None -> () + | Some fd -> + Hashtbl.remove t.listen_sockets port; + try Unix.close (Lwt_unix.unix_file_descr fd) with _ -> () + +let listen t ~port ?keepalive callback = + if port < 0 || port > 65535 then + raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port)); + unlisten t ~port; + let fd = Lwt_unix.(socket PF_INET6 SOCK_STREAM 0) in + Lwt_unix.setsockopt fd Lwt_unix.SO_REUSEADDR true; + Lwt_unix.(setsockopt fd IPV6_ONLY true); + Unix.bind (Lwt_unix.unix_file_descr fd) (Lwt_unix.ADDR_INET (t.interface, port)); + Hashtbl.replace t.listen_sockets port fd; + Lwt_unix.listen fd 10; + (* FIXME: we should not ignore the result *) + Lwt.async (fun () -> + (* TODO cancellation *) + let rec loop () = + if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; + Lwt.catch (fun () -> + Lwt_unix.accept fd >|= fun (afd, _) -> + t.active_connections <- afd :: t.active_connections; + (match keepalive with + | None -> () + | Some { Mirage_protocols.Keepalive.after; interval; probes } -> + Tcp_socket_options.enable_keepalive ~fd:afd ~after ~interval ~probes); + Lwt.async + (fun () -> + Lwt.catch + (fun () -> callback afd) + (fun exn -> + Log.warn (fun m -> m "error %s in callback" (Printexc.to_string exn)) ; + close afd)); + `Continue) + (function + | Unix.Unix_error (Unix.EBADF, _, _) -> + Log.warn (fun m -> m "error bad file descriptor in accept") ; + Lwt.return `Stop + | exn -> + Log.warn (fun m -> m "error %s in accept" (Printexc.to_string exn)) ; + Lwt.return `Continue) >>= function + | `Continue -> loop () + | `Stop -> Lwt.return_unit + in + Lwt.catch loop ignore_canceled >>= fun () -> close fd) diff --git a/src/stack-unix/tcpv6_socket.mli b/src/stack-unix/tcpv6_socket.mli index 48a9c06da..f060ecc0d 100644 --- a/src/stack-unix/tcpv6_socket.mli +++ b/src/stack-unix/tcpv6_socket.mli @@ -17,7 +17,6 @@ include Mirage_protocols.TCP with type ipaddr = Ipaddr.V6.t - and type ipinput = unit Lwt.t and type flow = Lwt_unix.file_descr and type error = [ Mirage_protocols.Tcp.error | `Exn of exn ] and type write_error = [ Mirage_protocols.Tcp.write_error | `Exn of exn ] @@ -25,3 +24,5 @@ include Mirage_protocols.TCP val connect : Ipaddr.V6.Prefix.t option -> t Lwt.t val disconnect : t -> unit Lwt.t + +val set_switched_off : t -> unit Lwt.t -> unit diff --git a/src/stack-unix/udpv4_socket.ml b/src/stack-unix/udpv4_socket.ml index 8ff3e2604..b4fabc986 100644 --- a/src/stack-unix/udpv4_socket.ml +++ b/src/stack-unix/udpv4_socket.ml @@ -14,18 +14,27 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. *) +let src = Logs.Src.create "udpv4-socket" ~doc:"UDP socket v4 (platform native)" +module Log = (val Logs.src_log src : Logs.LOG) + open Lwt.Infix type ipaddr = Ipaddr.V4.t -type ipinput = unit Lwt.t type callback = src:ipaddr -> dst:ipaddr -> src_port:int -> Cstruct.t -> unit Lwt.t type t = { interface: Unix.inet_addr; (* source ip to bind to *) listen_fds: ((Unix.inet_addr * int),Lwt_unix.file_descr) Hashtbl.t; (* UDPv4 fds bound to a particular source ip/port *) + mutable switched_off: unit Lwt.t; } -let get_udpv4_listening_fd ?(preserve = true) {listen_fds;interface} port = +let set_switched_off t switched_off = t.switched_off <- switched_off + +let ignore_canceled = function + | Lwt.Canceled -> Lwt.return_unit + | exn -> raise exn + +let get_udpv4_listening_fd ?(preserve = true) {listen_fds;interface;_} port = try Lwt.return (false, Hashtbl.find listen_fds (interface,port)) with Not_found -> @@ -50,14 +59,14 @@ let connect ip = let t = let listen_fds = Hashtbl.create 7 in let interface = Ipaddr_unix.V4.to_inet_addr (Ipaddr.V4.Prefix.address ip) in - { interface; listen_fds } + { interface; listen_fds; switched_off = Lwt.return_unit } in Lwt.return t let disconnect t = Hashtbl.fold (fun _ fd r -> r >>= fun () -> close fd) t.listen_fds Lwt.return_unit -let input ~listeners:_ _ = Lwt.return_unit +let input _t ~src:_ ~dst:_ _buf = Lwt.return_unit let write ?src:_ ?src_port ?ttl:_ttl ~dst ~dst_port t buf = let open Lwt_unix in @@ -75,3 +84,43 @@ let write ?src:_ ?src_port ?ttl:_ttl ~dst ~dst_port t buf = write_to_fd fd buf >>= fun r -> (if created then close fd else Lwt.return_unit) >|= fun () -> r + +let unlisten t ~port = + try + let fd = Hashtbl.find t.listen_fds (t.interface, port) in + Hashtbl.remove t.listen_fds (t.interface, port); + Unix.close (Lwt_unix.unix_file_descr fd) + with _ -> () + +let listen t ~port callback = + if port < 0 || port > 65535 then + raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port)); + unlisten t ~port; + (* FIXME: we should not ignore the result *) + Lwt.async (fun () -> + get_udpv4_listening_fd t port >>= fun (_, fd) -> + let buf = Cstruct.create 4096 in + let rec loop () = + if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; + Lwt.catch (fun () -> + Lwt_cstruct.recvfrom fd buf [] >>= fun (len, sa) -> + let buf = Cstruct.sub buf 0 len in + (match sa with + | Lwt_unix.ADDR_INET (addr, src_port) -> + let src = Ipaddr_unix.V4.of_inet_addr_exn addr in + let dst = Ipaddr.V4.any in (* TODO *) + callback ~src ~dst ~src_port buf + | _ -> Lwt.return_unit) >|= fun () -> + `Continue) + (function + | Unix.Unix_error (Unix.EBADF, _, _) -> + Log.warn (fun m -> m "error bad file descriptor in accept") ; + Lwt.return `Stop + | exn -> + Log.warn (fun m -> m "exception %s in recvfrom" (Printexc.to_string exn)) ; + Lwt.return `Continue) >>= function + | `Continue -> loop () + | `Stop -> Lwt.return_unit + in + Lwt.catch loop ignore_canceled >>= fun () -> + close fd) diff --git a/src/stack-unix/udpv4v6_socket.ml b/src/stack-unix/udpv4v6_socket.ml index 9ca37f994..8a624fb7c 100644 --- a/src/stack-unix/udpv4v6_socket.ml +++ b/src/stack-unix/udpv4v6_socket.ml @@ -15,10 +15,12 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. *) +let src = Logs.Src.create "udpv4v6-socket" ~doc:"UDP socket v4v6 (platform native)" +module Log = (val Logs.src_log src : Logs.LOG) + open Lwt.Infix type ipaddr = Ipaddr.t -type ipinput = unit Lwt.t type callback = src:ipaddr -> dst:ipaddr -> src_port:int -> Cstruct.t -> unit Lwt.t let any_v6 = Ipaddr_unix.V6.to_inet_addr Ipaddr.V6.unspecified @@ -26,9 +28,16 @@ let any_v6 = Ipaddr_unix.V6.to_inet_addr Ipaddr.V6.unspecified type t = { interface: [ `Any | `Ip of Unix.inet_addr * Unix.inet_addr | `V4_only of Unix.inet_addr | `V6_only of Unix.inet_addr ]; (* source ip to bind to *) listen_fds: (int, Lwt_unix.file_descr * Lwt_unix.file_descr option) Hashtbl.t; (* UDP fds bound to a particular port *) + mutable switched_off : unit Lwt.t; } -let get_udpv4v6_listening_fd ?(preserve = true) ?(v4_or_v6 = `Both) {listen_fds;interface} port = +let set_switched_off t switched_off = t.switched_off <- switched_off + +let ignore_canceled = function + | Lwt.Canceled -> Lwt.return_unit + | exn -> raise exn + +let get_udpv4v6_listening_fd ?(preserve = true) ?(v4_or_v6 = `Both) {listen_fds;interface;_} port = try Lwt.return (match Hashtbl.find listen_fds port with @@ -108,7 +117,7 @@ let connect ~ipv4_only ~ipv6_only ipv4 ipv6 = `Ip (v4_unix, Ipaddr_unix.V6.to_inet_addr v6) in let listen_fds = Hashtbl.create 7 in - Lwt.return { interface; listen_fds } + Lwt.return { interface; listen_fds; switched_off = Lwt.return_unit } let disconnect t = Hashtbl.fold (fun _ (fd, fd') r -> @@ -117,7 +126,7 @@ let disconnect t = match fd' with None -> Lwt.return_unit | Some fd -> close fd) t.listen_fds Lwt.return_unit -let input ~listeners:_ _ = Lwt.return_unit +let input _t ~src:_ ~dst:_ _buf = Lwt.return_unit let write ?src:_ ?src_port ?ttl:_ttl ~dst ~dst_port t buf = let open Lwt_unix in @@ -147,3 +156,50 @@ let write ?src:_ ?src_port ?ttl:_ttl ~dst ~dst_port t buf = (if created then close fd else Lwt.return_unit) >|= fun () -> r) | _ -> Lwt.return (Error `Different_ip_version) + +let unlisten t ~port = + try + let fd, fd' = Hashtbl.find t.listen_fds port in + Hashtbl.remove t.listen_fds port; + (match fd' with None -> () | Some fd' -> Unix.close (Lwt_unix.unix_file_descr fd')); + Unix.close (Lwt_unix.unix_file_descr fd) + with _ -> () + +let listen t ~port callback = + if port < 0 || port > 65535 then + raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port)) + else + (* FIXME: we should not ignore the result *) + Lwt.async (fun () -> + get_udpv4v6_listening_fd t port >|= fun (_, fds) -> + List.iter (fun fd -> + Lwt.async (fun () -> + let buf = Cstruct.create 4096 in + let rec loop () = + if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; + Lwt.catch (fun () -> + Lwt_cstruct.recvfrom fd buf [] >>= fun (len, sa) -> + let buf = Cstruct.sub buf 0 len in + (match sa with + | Lwt_unix.ADDR_INET (addr, src_port) -> + let src = Ipaddr_unix.of_inet_addr addr in + let src = + match Ipaddr.to_v4 src with + | None -> src + | Some v4 -> Ipaddr.V4 v4 + in + let dst = Ipaddr.(V6 V6.unspecified) in (* TODO *) + callback ~src ~dst ~src_port buf + | _ -> Lwt.return_unit) >|= fun () -> + `Continue) + (function + | Unix.Unix_error (Unix.EBADF, _, _) -> + Log.warn (fun m -> m "error bad file descriptor in accept") ; + Lwt.return `Stop + | exn -> + Log.warn (fun m -> m "exception %s in recvfrom" (Printexc.to_string exn)) ; + Lwt.return `Continue) >>= function + | `Continue -> loop () + | `Stop -> Lwt.return_unit + in + Lwt.catch loop ignore_canceled >>= fun () -> close fd)) fds) diff --git a/src/stack-unix/udpv6_socket.ml b/src/stack-unix/udpv6_socket.ml index c7ad602c6..e720aa1eb 100644 --- a/src/stack-unix/udpv6_socket.ml +++ b/src/stack-unix/udpv6_socket.ml @@ -15,18 +15,27 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. *) +let src = Logs.Src.create "udpv6-socket" ~doc:"UDP socket v6 (platform native)" +module Log = (val Logs.src_log src : Logs.LOG) + open Lwt.Infix type ipaddr = Ipaddr.V6.t -type ipinput = unit Lwt.t type callback = src:ipaddr -> dst:ipaddr -> src_port:int -> Cstruct.t -> unit Lwt.t type t = { interface: Unix.inet_addr; (* source ip to bind to *) listen_fds: ((Unix.inet_addr * int),Lwt_unix.file_descr) Hashtbl.t; (* UDPv6 fds bound to a particular source ip/port *) + mutable switched_off : unit Lwt.t; } -let get_udpv6_listening_fd ?(preserve = true) {listen_fds;interface} port = +let set_switched_off t switched_off = t.switched_off <- switched_off + +let ignore_canceled = function + | Lwt.Canceled -> Lwt.return_unit + | exn -> raise exn + +let get_udpv6_listening_fd ?(preserve = true) {listen_fds;interface;_} port = try Lwt.return (false, Hashtbl.find listen_fds (interface,port)) with Not_found -> @@ -56,14 +65,14 @@ let connect id = | None -> Ipaddr_unix.V6.to_inet_addr Ipaddr.V6.unspecified | Some ip -> Ipaddr_unix.V6.to_inet_addr (Ipaddr.V6.Prefix.address ip) in - { interface; listen_fds } + { interface; listen_fds; switched_off = Lwt.return_unit } in Lwt.return t let disconnect t = Hashtbl.fold (fun _ fd r -> r >>= fun () -> close fd) t.listen_fds Lwt.return_unit -let input ~listeners:_ _ = Lwt.return_unit +let input _t ~src:_ ~dst:_ _buf = Lwt.return_unit let write ?src:_ ?src_port ?ttl:_ttl ~dst ~dst_port t buf = let open Lwt_unix in @@ -81,3 +90,42 @@ let write ?src:_ ?src_port ?ttl:_ttl ~dst ~dst_port t buf = write_to_fd fd buf >>= fun r -> (if created then close fd else Lwt.return_unit) >|= fun () -> r + +let unlisten t ~port = + try + let fd = Hashtbl.find t.listen_fds (t.interface, port) in + Hashtbl.remove t.listen_fds (t.interface, port); + Unix.close (Lwt_unix.unix_file_descr fd) + with _ -> () + +let listen t ~port callback = + if port < 0 || port > 65535 then + raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port)); + unlisten t ~port; + (* FIXME: we should not ignore the result *) + Lwt.async (fun () -> + get_udpv6_listening_fd t port >>= fun (_, fd) -> + let buf = Cstruct.create 4096 in + let rec loop () = + if not (Lwt.is_sleeping t.switched_off) then raise Lwt.Canceled ; + Lwt.catch (fun () -> + Lwt_cstruct.recvfrom fd buf [] >>= fun (len, sa) -> + let buf = Cstruct.sub buf 0 len in + (match sa with + | Lwt_unix.ADDR_INET (addr, src_port) -> + let src = Ipaddr_unix.V6.of_inet_addr_exn addr in + let dst = Ipaddr.V6.unspecified in (* TODO *) + callback ~src ~dst ~src_port buf + | _ -> Lwt.return_unit) >|= fun () -> + `Continue) + (function + | Unix.Unix_error (Unix.EBADF, _, _) -> + Log.warn (fun m -> m "error bad file descriptor in accept") ; + Lwt.return `Stop + | exn -> + Log.warn (fun m -> m "exception %s in recvfrom" (Printexc.to_string exn)) ; + Lwt.return `Continue) >>= function + | `Continue -> loop () + | `Stop -> Lwt.return_unit + in + Lwt.catch loop ignore_canceled >>= fun () -> close fd) diff --git a/src/tcp/flow.ml b/src/tcp/flow.ml index a4718ed24..45b9f6162 100644 --- a/src/tcp/flow.ml +++ b/src/tcp/flow.ml @@ -45,9 +45,6 @@ struct | #Mirage_protocols.Tcp.write_error as e -> Mirage_protocols.Tcp.pp_write_error ppf e type ipaddr = Ip.ipaddr - type buffer = Cstruct.t - type +'a io = 'a Lwt.t - type ipinput = src:ipaddr -> dst:ipaddr -> buffer -> unit io type pcb = { id: WIRE.t; @@ -64,13 +61,9 @@ struct type flow = pcb type connection = flow * unit Lwt.t - type listener = { - process: flow -> unit io; - keepalive: Mirage_protocols.Keepalive.t option; - } - type t = { ip : Ip.t; + listeners : (int, Mirage_protocols.Keepalive.t option * (flow -> unit Lwt.t)) Hashtbl.t ; mutable active : bool ; mutable localport : int; channels: (WIRE.t, connection) Hashtbl.t; @@ -82,6 +75,14 @@ struct connects: (WIRE.t, ((connection, error) result Lwt.u * Sequence.t * Mirage_protocols.Keepalive.t option)) Hashtbl.t; } + let listen t ~port ?keepalive cb = + if port < 0 || port > 65535 then + raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port)) + else + Hashtbl.replace t.listeners port (keepalive, cb) + + let unlisten t ~port = Hashtbl.remove t.listeners port + let _pp_pcb fmt pcb = Format.fprintf fmt "id=[%a] state=[%a]" WIRE.pp pcb.id State.pp pcb.state @@ -491,10 +492,10 @@ struct Tx.send_rst t id ~sequence ~ack_number ~syn ~fin >>= fun _ -> Lwt.return_unit (* discard errors; we won't retry *) - let process_syn t id ~listeners ~tx_wnd ~ack_number ~sequence ~options ~syn ~fin = + let process_syn t id ~tx_wnd ~ack_number ~sequence ~options ~syn ~fin = log_with_stats "process-syn" t; - match listeners @@ WIRE.src_port id with - | Some { process; keepalive } -> + match Hashtbl.find_opt t.listeners (WIRE.src_port id) with + | Some (keepalive, process) -> let tx_isn = Sequence.of_int32 (Randomconv.int32 Random.generate) in (* TODO: make this configurable per listener *) let rx_wnd = 65535 in @@ -537,7 +538,7 @@ struct Tx.send_rst t id ~sequence ~ack_number ~syn ~fin >>= fun _ -> Lwt.return_unit (* if send fails, who cares *) - let input_no_pcb t listeners (parsed, payload) id = + let input_no_pcb t (parsed, payload) id = if not t.active then (* TODO: eventually send an RST? *) Lwt.return_unit @@ -547,7 +548,7 @@ struct | true, _, _ -> process_reset t id ~ack ~ack_number | false, true, true -> process_synack t id ~ack_number ~sequence ~tx_wnd:window ~options ~syn ~fin - | false, true , false -> process_syn t id ~listeners ~tx_wnd:window + | false, true , false -> process_syn t id ~tx_wnd:window ~ack_number ~sequence ~options ~syn ~fin | false, false, true -> let open RXS in @@ -557,7 +558,7 @@ struct Lwt.return_unit (* Main input function for TCP packets *) - let input t ~listeners ~src ~dst data = + let input t ~src ~dst data = let open Tcp_packet in match Unmarshal.of_cstruct data with | Error s -> Log.debug (fun f -> f "parsing TCP header failed: %s" s); @@ -571,7 +572,7 @@ struct (* PCB exists, so continue the connection state machine in tcp_input *) (Rx.input t RXS.({header = pkt; payload})) (* No existing PCB, so check if it is a SYN for a listening function *) - (input_no_pcb t listeners (pkt, payload)) + (input_no_pcb t (pkt, payload)) (* Blocking read on a PCB *) let read pcb = @@ -734,7 +735,7 @@ struct let listens = Hashtbl.create 1 in let connects = Hashtbl.create 1 in let channels = Hashtbl.create 7 in - Lwt.return { ip; active = true; localport; channels; listens; connects } + Lwt.return { ip; listeners = Hashtbl.create 7; active = true; localport; channels; listens; connects } let disconnect t = t.active <- false; diff --git a/src/tcp/flow.mli b/src/tcp/flow.mli index 8d8a6221e..44b06417b 100644 --- a/src/tcp/flow.mli +++ b/src/tcp/flow.mli @@ -20,6 +20,5 @@ module Make (IP:Mirage_protocols.IP) (R:Mirage_random.S) : sig include Mirage_protocols.TCP with type ipaddr = IP.ipaddr - and type ipinput = src:IP.ipaddr -> dst:IP.ipaddr -> Cstruct.t -> unit Lwt.t val connect : IP.t -> t Lwt.t end diff --git a/src/udp/udp.ml b/src/udp/udp.ml index ce4a52ca3..57b02ee2a 100644 --- a/src/udp/udp.ml +++ b/src/udp/udp.ml @@ -22,7 +22,6 @@ module Log = (val Logs.src_log src : Logs.LOG) module Make(Ip: Mirage_protocols.IP)(Random:Mirage_random.S) = struct type ipaddr = Ip.ipaddr - type ipinput = src:ipaddr -> dst:ipaddr -> Cstruct.t -> unit Lwt.t type callback = src:ipaddr -> dst:ipaddr -> src_port:int -> Cstruct.t -> unit Lwt.t type error = [ `Ip of Ip.error ] @@ -30,24 +29,32 @@ module Make(Ip: Mirage_protocols.IP)(Random:Mirage_random.S) = struct type t = { ip : Ip.t; + listeners : (int, callback) Hashtbl.t; } let pp_ip = Ip.pp_ipaddr + let listen t ~port callback = + if port < 0 || port > 65535 then + raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port)) + else + Hashtbl.replace t.listeners port callback + + let unlisten t ~port = Hashtbl.remove t.listeners port + (* TODO: ought we to check to make sure the destination is relevant here? Currently we process all incoming packets without making sure they're either unicast for us or otherwise interesting. *) - let input ~listeners _t ~src ~dst buf = + let input t ~src ~dst buf = match Udp_packet.Unmarshal.of_cstruct buf with | Error s -> Log.debug (fun f -> f "Discarding received UDP message: error parsing: %s" s); Lwt.return_unit | Ok ({ Udp_packet.src_port; dst_port}, payload) -> - match listeners ~dst_port with + match Hashtbl.find_opt t.listeners dst_port with | None -> Lwt.return_unit - | Some fn -> - fn ~src ~dst ~src_port payload + | Some fn -> fn ~src ~dst ~src_port payload let writev ?src ?src_port ?ttl ~dst ~dst_port t bufs = let src_port = match src_port with @@ -79,7 +86,7 @@ module Make(Ip: Mirage_protocols.IP)(Random:Mirage_random.S) = struct let connect ip = Log.info (fun f -> f "UDP interface connected on %a" (Fmt.list Ip.pp_ipaddr) @@ Ip.get_ip ip); - let t = { ip } in + let t = { ip ; listeners = Hashtbl.create 7 } in Lwt.return t let disconnect t = diff --git a/src/udp/udp.mli b/src/udp/udp.mli index 21249cec0..eb9d5b69b 100644 --- a/src/udp/udp.mli +++ b/src/udp/udp.mli @@ -18,6 +18,5 @@ module Make (IP:Mirage_protocols.IP)(R:Mirage_random.S) : sig include Mirage_protocols.UDP with type ipaddr = IP.ipaddr - and type ipinput = src:IP.ipaddr -> dst:IP.ipaddr -> Cstruct.t -> unit Lwt.t val connect : IP.t -> t Lwt.t end diff --git a/tcpip.opam b/tcpip.opam index bd0b72b06..92ca53c7d 100644 --- a/tcpip.opam +++ b/tcpip.opam @@ -34,7 +34,7 @@ depends: [ "mirage-clock" {>= "3.0.0"} "mirage-random" {>= "2.0.0"} "mirage-stack" {>= "2.2.0"} - "mirage-protocols" {>= "5.0.0"} + "mirage-protocols" {>= "6.0.0"} "mirage-time" {>= "2.0.0"} "ipaddr" {>= "5.0.0"} "macaddr" {>="4.0.0"} From cfb128484022e9dfbda6d66e74b85e7024abfd45 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Wed, 10 Nov 2021 18:14:40 +0100 Subject: [PATCH 2/2] changes for 6.4.0 --- CHANGES.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGES.md b/CHANGES.md index 015109a8f..0665725fd 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,9 @@ +### v6.4.0 (2021-11-11) + +* Adapt to mirage-protocols 6.0.0 API (#457 @hannesm) +* TCP and UDP now have a listen and unlisten function (fixes #452) +* type ipinput (in TCP and UDP) and listener (in TCP) have been removed + ### v6.3.0 (2021-10-25) * Use Cstruct.length instead of deprecated Cstruct.len (#454 @hannesm)