From cb29b3b08956b7299e2578f4baec9e99a8270045 Mon Sep 17 00:00:00 2001 From: Thomas Gazagnaire Date: Tue, 9 Jun 2015 19:48:22 +0100 Subject: [PATCH 1/3] Minor cleanups --- tcp/segment.ml | 59 ++++++++++++++++++++++++---------------------- tcp/user_buffer.ml | 13 +++++----- 2 files changed, 37 insertions(+), 35 deletions(-) diff --git a/tcp/segment.ml b/tcp/segment.ml index ce148bb81..bcbe9c433 100644 --- a/tcp/segment.ml +++ b/tcp/segment.ml @@ -24,6 +24,18 @@ let peek_opt_l seq = let _ = Lwt_sequence.add_l s seq in Some s +let peek_l seq = + match Lwt_sequence.take_opt_l seq with + | None -> assert false + | Some s -> + let _ = Lwt_sequence.add_l s seq in + s + +let rec reset_seq segs = + match Lwt_sequence.take_opt_l segs with + | None -> () + | Some _ -> reset_seq segs + (* The receive queue stores out-of-order segments, and can coalesece them on input and pass on an ordered list up the stack to the application. @@ -105,21 +117,20 @@ module Rx(Time:V1_LWT.TIME) = struct let input (q:t) seg = (* Check that the segment fits into the valid receive window *) let force_ack = ref false in - (* TODO check that this test for a valid RST is valid *) - if (seg.rst && (Window.valid q.wnd seg.sequence)) then begin - StateTick.tick q.state State.Recv_rst; - (* Dump all the received but out of order frames *) - q.segs <- S.empty; - (* Signal TX side *) - let txalert ack_svcd = match ack_svcd with - | true -> Lwt_mvar.put q.tx_ack ((Window.ack_seq q.wnd), (Window.ack_win q.wnd)) - | false -> return_unit - in - txalert (Window.ack_serviced q.wnd) >>= fun () -> - (* Use the fin path to inform the application of end of stream *) - Lwt_mvar.put q.rx_data (None, Some 0) - end else if not (Window.valid q.wnd seg.sequence) then return_unit - else + if not (Window.valid q.wnd seg.sequence) then Lwt.return_unit + else if seg.rst then ( + StateTick.tick q.state State.Recv_rst; + (* Dump all the received but out of order frames *) + q.segs <- S.empty; + (* Signal TX side *) + let txalert ack_svcd = + if not ack_svcd then Lwt.return_unit + else Lwt_mvar.put q.tx_ack (Window.ack_seq q.wnd, Window.ack_win q.wnd) + in + txalert (Window.ack_serviced q.wnd) >>= fun () -> + (* Use the fin path to inform the application of end of stream *) + Lwt_mvar.put q.rx_data (None, Some 0) + ) else (* Insert the latest segment *) let segs = S.add seg q.segs in (* Walk through the set and get a list of contiguous segments *) @@ -281,13 +292,6 @@ module Tx (Time:V1_LWT.TIME) (Clock:V1.CLOCK) = struct | _ -> Tcptimer.Stoptimer - let peek_l seq = - match Lwt_sequence.take_opt_l seq with - | None -> assert false - | Some s -> - let _ = Lwt_sequence.add_l s seq in - s - let rto_t q tx_ack = (* Listen for incoming TX acks from the receive queue and ACK segments in our retransmission queue *) @@ -342,12 +346,11 @@ module Tx (Time:V1_LWT.TIME) (Clock:V1.CLOCK) = struct let seq = Window.ack_seq q.wnd in let win = Window.ack_win q.wnd in begin match State.state q.state with - | State.Reset -> let rec empty_segs segs = - match Lwt_sequence.take_opt_l segs with - | None -> () - | Some s -> empty_segs segs - in - empty_segs q.segs + | State.Reset -> + (* Note: This is not stricly necessary, as the PCB will be + GCed later on. However, it helps removing pressure on + the GC. *) + reset_seq q.segs | _ -> let ack_len = Sequence.sub seq (Window.tx_una q.wnd) in let dupacktest () = diff --git a/tcp/user_buffer.ml b/tcp/user_buffer.ml index b5223f001..9ebad2872 100644 --- a/tcp/user_buffer.ml +++ b/tcp/user_buffer.ml @@ -321,15 +321,14 @@ module Tx(Time:V1_LWT.TIME)(Clock:V1.CLOCK) = struct clear_buffer t >>= fun () -> inform_app t - let rec dump_buffer t = - match Lwt_sequence.is_empty t.buffer with - | true -> return_unit - | false -> - let _ = Lwt_sequence.take_l t.buffer in - dump_buffer t + (* FIXME: duplicated code with Segment.reset_seq *) + let rec reset_seq segs = + match Lwt_sequence.take_opt_l segs with + | None -> () + | Some s -> reset_seq segs let reset t = - dump_buffer t >>= fun () -> + reset_seq t.buffer; inform_app t end From 2d7a172870ea0d0585f6c13d21ead6a326499667 Mon Sep 17 00:00:00 2001 From: Thomas Gazagnaire Date: Tue, 9 Jun 2015 22:27:26 +0100 Subject: [PATCH 2/3] myocamlbuild runes --- myocamlbuild.ml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/myocamlbuild.ml b/myocamlbuild.ml index 14043d6f2..d06d172f4 100644 --- a/myocamlbuild.ml +++ b/myocamlbuild.ml @@ -673,3 +673,8 @@ let dispatch_default = MyOCamlbuildBase.dispatch_default conf package_default;; # 674 "myocamlbuild.ml" (* OASIS_STOP *) Ocamlbuild_plugin.dispatch dispatch_default;; +(* Ocamlbuild_pack.Flags.mark_tag_used "tests";; *) +let () = + flag ["ocaml"; "doc"] (A"-colorize-code"); + flag ["ocaml"; "doc"] (A"-short-functors"); + flag ["ocaml"; "doc"] (A"-short-paths") From 0cefec41855302fdbe980f525b5250f6d065a8b2 Mon Sep 17 00:00:00 2001 From: Thomas Gazagnaire Date: Tue, 9 Jun 2015 22:28:29 +0100 Subject: [PATCH 3/3] Make Pcb.write fail if the connection is reset --- tcp/flow.ml | 82 +++++++++++++++++++++++++++------------------------- tcp/flow.mli | 1 + tcp/pcb.ml | 37 ++++++++++++++++-------- tcp/pcb.mli | 13 +++++---- 4 files changed, 76 insertions(+), 57 deletions(-) diff --git a/tcp/flow.ml b/tcp/flow.ml index 792b8d6af..e24b6a7a6 100644 --- a/tcp/flow.ml +++ b/tcp/flow.ml @@ -14,7 +14,12 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. *) -open Lwt +let (>>=) = Lwt.(>>=) +let (>|=) = Lwt.(>|=) + +(* TODO: modify V1.TCP to have a proper return type *) + +exception Bad_state of State.tcpstate module Make(IP:V1_LWT.IP)(TM:V1_LWT.TIME)(C:V1.CLOCK)(R:V1.RANDOM) = struct @@ -35,57 +40,54 @@ module Make(IP:V1_LWT.IP)(TM:V1_LWT.TIME)(C:V1.CLOCK)(R:V1.RANDOM) = struct | `Refused ] + let err_timeout () = + (* Printf.printf "Failed to connect to %s:%d\n%!" *) + (* (Ipaddr.V4.to_string daddr) dport; *) + Lwt.return (`Error `Timeout) + + let err_refused () = + (* Printf.printf "Refused connection to %s:%d\n%!" *) + (* (Ipaddr.V4.to_string daddr) dport; *) + Lwt.return (`Error `Refused) + + let ok x = Lwt.return (`Ok x) + let error_message = function | `Unknown msg -> msg | `Timeout -> "Timeout while attempting to connect" | `Refused -> "Connection refused" - let id t = Pcb.ip t + let err_rewrite = function + | `Error (`Bad_state _) -> `Error `Refused + | `Ok () as x -> x + + let err_raise = function + | `Error (`Bad_state s) -> Lwt.fail (Bad_state s) + | `Ok () -> Lwt.return_unit - let get_dest t = Pcb.get_dest t + let id = Pcb.ip + let get_dest = Pcb.get_dest + let close t = Pcb.close t + let input = Pcb.input let read t = (* TODO better error interface in Pcb *) Pcb.read t >>= function - | None -> return `Eof - | Some t -> return (`Ok t) + | None -> Lwt.return `Eof + | Some t -> Lwt.return (`Ok t) - let write t view = - Pcb.write t view >>= fun () -> - return (`Ok ()) - - let writev t views = - Pcb.writev t views >>= fun () -> - return (`Ok ()) - - let write_nodelay t view = - Pcb.write_nodelay t view - - let writev_nodelay t views = - Pcb.writev_nodelay t views - - let close t = - Pcb.close t + let write t view = Pcb.write t view >|= err_rewrite + let writev t views = Pcb.writev t views >|= err_rewrite + let write_nodelay t view = Pcb.write_nodelay t view >>= err_raise + let writev_nodelay t views = Pcb.writev_nodelay t views >>= err_raise + let connect ipv4 = ok (Pcb.create ipv4) + let disconnect _ = Lwt.return_unit let create_connection tcp (daddr, dport) = Pcb.connect tcp ~dest_ip:daddr ~dest_port:dport >>= function - | `Timeout -> - (* Printf.printf "Failed to connect to %s:%d\n%!" *) - (* (Ipaddr.V4.to_string daddr) dport; *) - return (`Error `Timeout) - | `Rst -> - (* Printf.printf "Refused connection to %s:%d\n%!" *) - (* (Ipaddr.V4.to_string daddr) dport; *) - return (`Error `Refused) - | `Ok (fl, _) -> - return (`Ok fl) - - let input t ~listeners ~src ~dst buf = - Pcb.input t ~listeners ~src ~dst buf - - let connect ipv4 = - return (`Ok (Pcb.create ipv4)) - - let disconnect _ = - return_unit + | `Timeout -> err_timeout () + | `Rst -> err_refused () + | `Ok (fl, _) -> ok fl + + end diff --git a/tcp/flow.mli b/tcp/flow.mli index 61332d478..67fab6eb8 100644 --- a/tcp/flow.mli +++ b/tcp/flow.mli @@ -14,6 +14,7 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. *) +exception Bad_state of State.tcpstate module Make (IP:V1_LWT.IP)(TM:V1_LWT.TIME)(C:V1.CLOCK)(R:V1.RANDOM) : sig include V1_LWT.TCP diff --git a/tcp/pcb.ml b/tcp/pcb.ml index b5edd90dd..1ed891ba9 100644 --- a/tcp/pcb.ml +++ b/tcp/pcb.ml @@ -18,6 +18,24 @@ open Lwt open Printf +type error = [`Bad_state of State.tcpstate] + +type 'a result = [`Ok of 'a | `Error of error] +let ok x = Lwt.return (`Ok x) +let error s = Lwt.return (`Error (`Bad_state s)) + +let (>+=) x f = + x >>= function + | `Ok x -> f x + | `Error _ as e -> Lwt.return e + +let iter_s f l = + let rec aux = function + | [] -> ok () + | h::t -> f h >+= fun () -> aux t + in + aux l + module Tcp_wire = Wire_structs.Tcp_wire cstruct pseudo_header { @@ -486,28 +504,23 @@ struct | av_len when av_len < len -> let first_bit = Cstruct.sub data 0 av_len in let remaing_bit = Cstruct.sub data av_len (len - av_len) in - writefn pcb wfn first_bit >>= fun () -> + writefn pcb wfn first_bit >+= fun () -> writefn pcb wfn remaing_bit | _ -> match State.state pcb.state with - | State.Established | State.Close_wait -> wfn [data] - (* URG_TODO: return error instead of dropping silently *) - | _ -> return_unit + | State.Established | State.Close_wait -> wfn [data] >>= ok + | e -> error e (* Blocking write on a PCB *) let write pcb data = writefn pcb (UTX.write pcb.utx) data - let writev pcb data = Lwt_list.iter_s (fun d -> write pcb d) data - + let writev pcb data = iter_s (write pcb) data let write_nodelay pcb data = writefn pcb (UTX.write_nodelay pcb.utx) data - let writev_nodelay pcb data = - Lwt_list.iter_s (fun d -> write_nodelay pcb d) data + let writev_nodelay pcb data = iter_s (write_nodelay pcb) data (* Close - no more will be written *) - let close pcb = - Tx.close pcb + let close pcb = Tx.close pcb - let get_dest pcb = - pcb.id.WIRE.dest_ip, pcb.id.WIRE.dest_port + let get_dest pcb = pcb.id.WIRE.dest_ip, pcb.id.WIRE.dest_port let getid t dest_ip dest_port = (* TODO: make this more robust and recognise when all ports are gone *) diff --git a/tcp/pcb.mli b/tcp/pcb.mli index dcbf85990..6058693b6 100644 --- a/tcp/pcb.mli +++ b/tcp/pcb.mli @@ -14,11 +14,14 @@ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. *) +type error = [`Bad_state of State.tcpstate] + +type 'a result = [`Ok of 'a | `Error of error] + module Make(Ip:V1_LWT.IP)(Time:V1_LWT.TIME)(Clock:V1.CLOCK)(Random:V1.RANDOM) : sig (** Overall state of the TCP stack *) type t - type pcb (** State for an individual connection *) @@ -47,13 +50,13 @@ module Make(Ip:V1_LWT.IP)(Time:V1_LWT.TIME)(Clock:V1.CLOCK)(Random:V1.RANDOM) : val write_wait_for : pcb -> int -> unit Lwt.t (* write - blocks if the write buffer is full *) - val write: pcb -> Cstruct.t -> unit Lwt.t - val writev: pcb -> Cstruct.t list -> unit Lwt.t + val write: pcb -> Cstruct.t -> unit result Lwt.t + val writev: pcb -> Cstruct.t list -> unit result Lwt.t (* version of write with Nagle disabled - will block if write buffer is full *) - val write_nodelay: pcb -> Cstruct.t -> unit Lwt.t - val writev_nodelay: pcb -> Cstruct.t list -> unit Lwt.t + val write_nodelay: pcb -> Cstruct.t -> unit result Lwt.t + val writev_nodelay: pcb -> Cstruct.t list -> unit result Lwt.t val create: Ip.t -> t (* val tcpstats: t -> unit *)