Skip to content

Commit

Permalink
Merge pull request #457 from hannesm/new-stack
Browse files Browse the repository at this point in the history
Adapt to mirage/mirage-protocols#28 changes:
  • Loading branch information
hannesm authored Nov 11, 2021
2 parents 3e9c163 + cfb1284 commit 6681eda
Show file tree
Hide file tree
Showing 20 changed files with 461 additions and 425 deletions.
6 changes: 6 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
### v6.4.0 (2021-11-11)

* Adapt to mirage-protocols 6.0.0 API (#457 @hannesm)
* TCP and UDP now have a listen and unlisten function (fixes #452)
* type ipinput (in TCP and UDP) and listener (in TCP) have been removed

### v6.3.0 (2021-10-25)

* Use Cstruct.length instead of deprecated Cstruct.len (#454 @hannesm)
Expand Down
102 changes: 15 additions & 87 deletions src/stack-direct/tcpip_stack_direct.ml
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,11 @@ open Lwt.Infix
let src = Logs.Src.create "tcpip-stack-direct" ~doc:"Pure OCaml TCP/IP stack"
module Log = (val Logs.src_log src : Logs.LOG)

type direct_ipv4_input = src:Ipaddr.V4.t -> dst:Ipaddr.V4.t -> Cstruct.t -> unit Lwt.t
module type UDPV4_DIRECT = Mirage_protocols.UDP
with type ipaddr = Ipaddr.V4.t
and type ipinput = direct_ipv4_input

module type TCPV4_DIRECT = Mirage_protocols.TCP
with type ipaddr = Ipaddr.V4.t
and type ipinput = direct_ipv4_input

module Make
(Time : Mirage_time.S)
Expand All @@ -51,8 +48,6 @@ module Make
icmpv4: Icmpv4.t;
udpv4 : Udpv4.t;
tcpv4 : Tcpv4.t;
udpv4_listeners: (int, Udpv4.callback) Hashtbl.t;
tcpv4_listeners: (int, Tcpv4.listener) Hashtbl.t;
mutable task : unit Lwt.t option;
}

Expand All @@ -64,26 +59,11 @@ module Make
let udpv4 { udpv4; _ } = udpv4
let ipv4 { ipv4; _ } = ipv4

let err_invalid_port p = Printf.sprintf "invalid port number (%d)" p

let listen_udpv4 t ~port callback =
if port < 0 || port > 65535
then raise (Invalid_argument (err_invalid_port port))
else Hashtbl.replace t.udpv4_listeners port callback

Udpv4.listen t.udpv4 ~port callback

let listen_tcpv4 ?keepalive t ~port process =
if port < 0 || port > 65535
then raise (Invalid_argument (err_invalid_port port))
else Hashtbl.replace t.tcpv4_listeners port { Tcpv4.process; keepalive }

let udpv4_listeners t ~dst_port =
try Some (Hashtbl.find t.udpv4_listeners dst_port)
with Not_found -> None

let tcpv4_listeners t dst_port =
try Some (Hashtbl.find t.tcpv4_listeners dst_port)
with Not_found -> None
Tcpv4.listen t.tcpv4 ~port ?keepalive process

let listen t =
Lwt.catch (fun () ->
Expand All @@ -92,10 +72,8 @@ module Make
~arpv4:(Arpv4.input t.arpv4)
~ipv4:(
Ipv4.input
~tcp:(Tcpv4.input t.tcpv4
~listeners:(tcpv4_listeners t))
~udp:(Udpv4.input t.udpv4
~listeners:(udpv4_listeners t))
~tcp:(Tcpv4.input t.tcpv4)
~udp:(Udpv4.input t.udpv4)
~default:(fun ~proto ~src ~dst buf ->
match proto with
| 1 -> Icmpv4.input t.icmpv4 ~src ~dst buf
Expand Down Expand Up @@ -127,10 +105,7 @@ module Make
| e -> Lwt.fail e)

let connect netif ethif arpv4 ipv4 icmpv4 udpv4 tcpv4 =
let udpv4_listeners = Hashtbl.create 7 in
let tcpv4_listeners = Hashtbl.create 7 in
let t = { netif; ethif; arpv4; ipv4; icmpv4; tcpv4; udpv4;
udpv4_listeners; tcpv4_listeners; task = None } in
let t = { netif; ethif; arpv4; ipv4; icmpv4; tcpv4; udpv4; task = None } in
Log.info (fun f -> f "stack assembled: %a" pp t);
Lwt.async (fun () -> let task = listen t in t.task <- Some task; task);
Lwt.return t
Expand All @@ -141,14 +116,11 @@ module Make
Lwt.return_unit
end

type direct_ipv6_input = src:Ipaddr.V6.t -> dst:Ipaddr.V6.t -> Cstruct.t -> unit Lwt.t
module type UDPV6_DIRECT = Mirage_protocols.UDP
with type ipaddr = Ipaddr.V6.t
and type ipinput = direct_ipv6_input

module type TCPV6_DIRECT = Mirage_protocols.TCP
with type ipaddr = Ipaddr.V6.t
and type ipinput = direct_ipv6_input

module MakeV6
(Time : Mirage_time.S)
Expand All @@ -169,8 +141,6 @@ module MakeV6
ipv6 : Ipv6.t;
udpv6 : Udpv6.t;
tcpv6 : Tcpv6.t;
udpv6_listeners: (int, Udpv6.callback) Hashtbl.t;
tcpv6_listeners: (int, Tcpv6.listener) Hashtbl.t;
mutable task : unit Lwt.t option;
}

Expand All @@ -182,25 +152,11 @@ module MakeV6
let udp { udpv6; _ } = udpv6
let ip { ipv6; _ } = ipv6

let err_invalid_port p = Printf.sprintf "invalid port number (%d)" p

let listen_udp t ~port callback =
if port < 0 || port > 65535
then raise (Invalid_argument (err_invalid_port port))
else Hashtbl.replace t.udpv6_listeners port callback
Udpv6.listen t.udpv6 ~port callback

let listen_tcp ?keepalive t ~port process =
if port < 0 || port > 65535
then raise (Invalid_argument (err_invalid_port port))
else Hashtbl.replace t.tcpv6_listeners port { Tcpv6.process; keepalive }

let udpv6_listeners t ~dst_port =
try Some (Hashtbl.find t.udpv6_listeners dst_port)
with Not_found -> None

let tcpv6_listeners t dst_port =
try Some (Hashtbl.find t.tcpv6_listeners dst_port)
with Not_found -> None
Tcpv6.listen t.tcpv6 ~port ?keepalive process

let listen t =
Lwt.catch (fun () ->
Expand All @@ -210,10 +166,8 @@ module MakeV6
~ipv4:(fun _ -> Lwt.return_unit)
~ipv6:(
Ipv6.input
~tcp:(Tcpv6.input t.tcpv6
~listeners:(tcpv6_listeners t))
~udp:(Udpv6.input t.udpv6
~listeners:(udpv6_listeners t))
~tcp:(Tcpv6.input t.tcpv6)
~udp:(Udpv6.input t.udpv6)
~default:(fun ~proto:_ ~src:_ ~dst:_ _ -> Lwt.return_unit)
t.ipv6)
t.ethif
Expand Down Expand Up @@ -241,10 +195,7 @@ module MakeV6
| e -> Lwt.fail e)

let connect netif ethif ipv6 udpv6 tcpv6 =
let udpv6_listeners = Hashtbl.create 7 in
let tcpv6_listeners = Hashtbl.create 7 in
let t = { netif; ethif; ipv6; tcpv6; udpv6;
udpv6_listeners; tcpv6_listeners; task = None } in
let t = { netif; ethif; ipv6; tcpv6; udpv6; task = None } in
Log.info (fun f -> f "stack assembled: %a" pp t);
Lwt.async (fun () -> let task = listen t in t.task <- Some task; task);
Lwt.return t
Expand All @@ -256,15 +207,11 @@ module MakeV6

end

type direct_ipv4v6_input = src:Ipaddr.t -> dst:Ipaddr.t -> Cstruct.t -> unit Lwt.t

module type UDPV4V6_DIRECT = Mirage_protocols.UDP
with type ipaddr = Ipaddr.t
and type ipinput = direct_ipv4v6_input

module type TCPV4V6_DIRECT = Mirage_protocols.TCP
with type ipaddr = Ipaddr.t
and type ipinput = direct_ipv4v6_input

module IPV4V6 (Ipv4 : Mirage_protocols.IPV4) (Ipv6 : Mirage_protocols.IPV6) = struct

Expand Down Expand Up @@ -404,8 +351,6 @@ module MakeV4V6
ip : IP.t;
udp : Udp.t;
tcp : Tcp.t;
udp_listeners: (int, Udp.callback) Hashtbl.t;
tcp_listeners: (int, Tcp.listener) Hashtbl.t;
mutable task : unit Lwt.t option;
}

Expand All @@ -417,31 +362,17 @@ module MakeV4V6
let udp { udp; _ } = udp
let ip { ip; _ } = ip

let err_invalid_port p = Printf.sprintf "invalid port number (%d)" p

let listen_udp t ~port callback =
if port < 0 || port > 65535
then raise (Invalid_argument (err_invalid_port port))
else Hashtbl.replace t.udp_listeners port callback
Udp.listen t.udp ~port callback

let listen_tcp ?keepalive t ~port process =
if port < 0 || port > 65535
then raise (Invalid_argument (err_invalid_port port))
else Hashtbl.replace t.tcp_listeners port { Tcp.process; keepalive }

let udp_listeners t ~dst_port =
try Some (Hashtbl.find t.udp_listeners dst_port)
with Not_found -> None

let tcp_listeners t dst_port =
try Some (Hashtbl.find t.tcp_listeners dst_port)
with Not_found -> None
Tcp.listen t.tcp ~port ?keepalive process

let listen t =
Lwt.catch (fun () ->
Log.debug (fun f -> f "Establishing or updating listener for stack %a" pp t);
let tcp = Tcp.input t.tcp ~listeners:(tcp_listeners t)
and udp = Udp.input t.udp ~listeners:(udp_listeners t)
let tcp = Tcp.input t.tcp
and udp = Udp.input t.udp
and default ~proto ~src ~dst buf =
match proto, src, dst with
| 1, Ipaddr.V4 src, Ipaddr.V4 dst -> Icmpv4.input t.icmpv4 ~src ~dst buf
Expand Down Expand Up @@ -476,10 +407,7 @@ module MakeV4V6
| e -> Lwt.fail e)

let connect netif ethif arpv4 ip icmpv4 udp tcp =
let udp_listeners = Hashtbl.create 7 in
let tcp_listeners = Hashtbl.create 7 in
let t = { netif; ethif; arpv4; ip; icmpv4; tcp; udp;
udp_listeners; tcp_listeners; task = None } in
let t = { netif; ethif; arpv4; ip; icmpv4; tcp; udp; task = None } in
Log.info (fun f -> f "stack assembled: %a" pp t);
Lwt.async (fun () -> let task = listen t in t.task <- Some task; task);
Lwt.return t
Expand Down
12 changes: 0 additions & 12 deletions src/stack-direct/tcpip_stack_direct.mli
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,11 @@
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*)

type direct_ipv4_input = src:Ipaddr.V4.t -> dst:Ipaddr.V4.t -> Cstruct.t -> unit Lwt.t

module type UDPV4_DIRECT = Mirage_protocols.UDP
with type ipaddr = Ipaddr.V4.t
and type ipinput = direct_ipv4_input

module type TCPV4_DIRECT = Mirage_protocols.TCP
with type ipaddr = Ipaddr.V4.t
and type ipinput = direct_ipv4_input

module Make
(Time : Mirage_time.S)
Expand All @@ -48,15 +44,11 @@ module Make
connections, they will be able to do so. *)
end

type direct_ipv6_input = src:Ipaddr.V6.t -> dst:Ipaddr.V6.t -> Cstruct.t -> unit Lwt.t

module type UDPV6_DIRECT = Mirage_protocols.UDP
with type ipaddr = Ipaddr.V6.t
and type ipinput = direct_ipv6_input

module type TCPV6_DIRECT = Mirage_protocols.TCP
with type ipaddr = Ipaddr.V6.t
and type ipinput = direct_ipv6_input

module MakeV6
(Time : Mirage_time.S)
Expand All @@ -79,15 +71,11 @@ module MakeV6
they will be able to do so. *)
end

type direct_ipv4v6_input = src:Ipaddr.t -> dst:Ipaddr.t -> Cstruct.t -> unit Lwt.t

module type UDPV4V6_DIRECT = Mirage_protocols.UDP
with type ipaddr = Ipaddr.t
and type ipinput = direct_ipv4v6_input

module type TCPV4V6_DIRECT = Mirage_protocols.TCP
with type ipaddr = Ipaddr.t
and type ipinput = direct_ipv4v6_input

module IPV4V6 (Ipv4 : Mirage_protocols.IPV4) (Ipv6 : Mirage_protocols.IPV6) : sig
include Mirage_protocols.IP with type ipaddr = Ipaddr.t
Expand Down
12 changes: 6 additions & 6 deletions src/stack-unix/dune
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
(wrapped false)
(instrumentation
(backend bisect_ppx))
(libraries lwt.unix ipaddr.unix cstruct-lwt fmt mirage-protocols))
(libraries lwt.unix ipaddr.unix cstruct-lwt fmt mirage-protocols logs))

(library
(name udpv6_socket)
Expand All @@ -24,7 +24,7 @@
(wrapped false)
(instrumentation
(backend bisect_ppx))
(libraries lwt.unix ipaddr.unix cstruct-lwt fmt mirage-protocols))
(libraries lwt.unix ipaddr.unix cstruct-lwt fmt mirage-protocols logs))

(library
(name udpv4v6_socket)
Expand All @@ -33,7 +33,7 @@
(wrapped false)
(instrumentation
(backend bisect_ppx))
(libraries lwt.unix ipaddr.unix cstruct-lwt fmt mirage-protocols))
(libraries lwt.unix ipaddr.unix cstruct-lwt fmt mirage-protocols logs))

(library
(name tcp_socket_options)
Expand All @@ -56,7 +56,7 @@
(instrumentation
(backend bisect_ppx))
(libraries lwt.unix ipaddr.unix cstruct-lwt fmt mirage-protocols
tcp_socket_options))
tcp_socket_options logs))

(library
(name tcpv6_socket)
Expand All @@ -66,7 +66,7 @@
(instrumentation
(backend bisect_ppx))
(libraries lwt.unix ipaddr.unix cstruct-lwt fmt mirage-protocols
tcpv4_socket tcp_socket_options))
tcpv4_socket tcp_socket_options logs))

(library
(name tcpv4v6_socket)
Expand All @@ -76,7 +76,7 @@
(instrumentation
(backend bisect_ppx))
(libraries lwt.unix ipaddr.unix cstruct-lwt fmt mirage-protocols
tcpv4_socket tcp_socket_options))
tcpv4_socket tcp_socket_options logs))

(library
(name tcpip_stack_socket)
Expand Down
15 changes: 5 additions & 10 deletions src/stack-unix/tcp_socket.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ let pp_write_error ppf = function
| #Mirage_protocols.Tcp.write_error as e -> Mirage_protocols.Tcp.pp_write_error ppf e
| `Exn e -> Fmt.exn ppf e

let ignore_canceled = function
| Lwt.Canceled -> Lwt.return_unit
| exn -> raise exn

let disconnect _ =
return_unit

Expand Down Expand Up @@ -61,13 +65,4 @@ let close fd =
| Unix.Unix_error (Unix.EBADF, _, _) -> Lwt.return_unit
| e -> Lwt.fail e)

type listener = {
process: Lwt_unix.file_descr -> unit Lwt.t;
keepalive: Mirage_protocols.Keepalive.t option;
}

(* FIXME: how does this work at all ?? *)
let input _t ~listeners:_ =
(* TODO terminate when signalled by disconnect *)
let t, _ = Lwt.task () in
t
let input _t ~src:_ ~dst:_ _buf = Lwt.return_unit
Loading

0 comments on commit 6681eda

Please sign in to comment.