Skip to content

Commit

Permalink
stack-unix: catch exceptions in accept
Browse files Browse the repository at this point in the history
  • Loading branch information
hannesm committed Oct 9, 2018
1 parent 62baca0 commit 97ceb19
Showing 1 changed file with 46 additions and 49 deletions.
95 changes: 46 additions & 49 deletions src/stack-unix/tcpip_stack_socket.ml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*)

open Lwt
open Lwt.Infix

let src = Logs.Src.create "tcpip-stack-socket" ~doc:"Platform's native TCP/IP stack"
module Log = (val Logs.src_log src : Logs.LOG)
Expand Down Expand Up @@ -56,13 +56,13 @@ let ipv4 _ = None
(* List of IP addresses to bind to *)
let configure _t addrs =
match addrs with
| [] -> return_unit
| [ip] when (Ipaddr.V4.compare Ipaddr.V4.any ip) = 0 -> return_unit
| [] -> Lwt.return_unit
| [ip] when (Ipaddr.V4.compare Ipaddr.V4.any ip) = 0 -> Lwt.return_unit
| l ->
let pp_iplist fmt l = Format.pp_print_list Ipaddr.V4.pp_hum fmt l in
Log.warn (fun f -> f
"Manager: sockets currently bind to all available IPs. IPs %a were specified, but this will be ignored" pp_iplist l);
return_unit
Lwt.return_unit

let err_invalid_port p = Printf.sprintf "invalid port number (%d)" p

Expand All @@ -71,28 +71,26 @@ let listen_udpv4 t ~port callback =
raise (Invalid_argument (err_invalid_port port))
else
(* FIXME: we should not ignore the result *)
ignore_result (
Udpv4.get_udpv4_listening_fd t.udpv4 port
>>= fun fd ->
Lwt.async (fun () ->
Udpv4.get_udpv4_listening_fd t.udpv4 port >>= fun fd ->
let buf = Cstruct.create 4096 in
let rec loop () =
let continue () =
(* TODO cancellation *)
if true then loop () else return_unit in
Lwt_cstruct.recvfrom fd buf []
>>= fun (len, sa) ->
let buf = Cstruct.sub buf 0 len in
begin match sa with
| Lwt_unix.ADDR_INET (addr, src_port) ->
let src = Ipaddr_unix.V4.of_inet_addr_exn addr in
let dst = Ipaddr.V4.any in (* TODO *)
callback ~src ~dst ~src_port buf
| _ -> return_unit
end >>= fun () ->
continue ()
(* TODO cancellation *)
Lwt.catch (fun () ->
Lwt_cstruct.recvfrom fd buf [] >>= fun (len, sa) ->
let buf = Cstruct.sub buf 0 len in
(match sa with
| Lwt_unix.ADDR_INET (addr, src_port) ->
let src = Ipaddr_unix.V4.of_inet_addr_exn addr in
let dst = Ipaddr.V4.any in (* TODO *)
callback ~src ~dst ~src_port buf
| _ -> Lwt.return_unit))
(fun exn ->
Log.warn (fun m -> m "exception %s in recvfrom" (Printexc.to_string exn)) ;
Lwt.return_unit) >>= fun () ->
loop ()
in
loop ()
)
loop ())

let listen_tcpv4 ?keepalive _t ~port callback =
if port < 0 || port > 65535 then
Expand All @@ -103,30 +101,30 @@ let listen_tcpv4 ?keepalive _t ~port callback =
(* TODO: as elsewhere in the module, we bind all available addresses; it would be better not to do so if the user has requested it *)
let interface = Ipaddr_unix.V4.to_inet_addr Ipaddr.V4.any in
(* FIXME: we should not ignore the result *)
ignore_result (
Lwt_unix.bind fd (Lwt_unix.ADDR_INET (interface, port))
>>= fun () ->
Lwt.async (fun () ->
Lwt_unix.bind fd (Lwt_unix.ADDR_INET (interface, port)) >>= fun () ->
Lwt_unix.listen fd 10;
(* TODO cancellation *)
let rec loop () =
let continue () =
(* TODO cancellation *)
if true then loop () else return_unit in
Lwt_unix.accept fd
>>= fun (afd, _) ->
( match keepalive with
| None -> ()
| Some { Mirage_protocols.Keepalive.after; interval; probes } ->
Tcp_socket_options.enable_keepalive ~fd:afd ~after ~interval ~probes );
Lwt.async (fun () ->
Lwt.catch
(fun () -> callback afd)
(fun _ -> return_unit)
);
return_unit
>>= fun () ->
continue () in
loop ()
)
Lwt.catch (fun () ->
Lwt_unix.accept fd >|= fun (afd, _) ->
(match keepalive with
| None -> ()
| Some { Mirage_protocols.Keepalive.after; interval; probes } ->
Tcp_socket_options.enable_keepalive ~fd:afd ~after ~interval ~probes);
Lwt.async
(fun () ->
Lwt.catch
(fun () -> callback afd)
(fun exn ->
Log.warn (fun m -> m "error %s in callback" (Printexc.to_string exn)) ;
Lwt.return_unit)))
(fun exn ->
Log.warn (fun m -> m "error %s in accept" (Printexc.to_string exn)) ;
Lwt.return_unit) >>= fun () ->
loop ()
in
loop ())

let listen _t =
let t, _ = Lwt.task () in
Expand All @@ -138,8 +136,7 @@ let connect ips udpv4 tcpv4 =
let tcpv4_listeners = Hashtbl.create 7 in
let t = { tcpv4; udpv4; udpv4_listeners; tcpv4_listeners } in
Log.info (fun f -> f "Manager: configuring");
configure t ips
>>= fun () ->
return t
configure t ips >|= fun () ->
t

let disconnect _ = return_unit
let disconnect _ = Lwt.return_unit

0 comments on commit 97ceb19

Please sign in to comment.