Skip to content

Commit

Permalink
Merge pull request #150 from samoht/master
Browse files Browse the repository at this point in the history
make Pcb.write fail if the connection is reset
  • Loading branch information
samoht committed Jun 10, 2015
2 parents aab5709 + 0cefec4 commit 03220b3
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 92 deletions.
5 changes: 5 additions & 0 deletions myocamlbuild.ml
Original file line number Diff line number Diff line change
Expand Up @@ -673,3 +673,8 @@ let dispatch_default = MyOCamlbuildBase.dispatch_default conf package_default;;
# 674 "myocamlbuild.ml"
(* OASIS_STOP *)
Ocamlbuild_plugin.dispatch dispatch_default;;
(* Ocamlbuild_pack.Flags.mark_tag_used "tests";; *)
let () =
flag ["ocaml"; "doc"] (A"-colorize-code");
flag ["ocaml"; "doc"] (A"-short-functors");
flag ["ocaml"; "doc"] (A"-short-paths")
82 changes: 42 additions & 40 deletions tcp/flow.ml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*)

open Lwt
let (>>=) = Lwt.(>>=)
let (>|=) = Lwt.(>|=)

(* TODO: modify V1.TCP to have a proper return type *)

exception Bad_state of State.tcpstate

module Make(IP:V1_LWT.IP)(TM:V1_LWT.TIME)(C:V1.CLOCK)(R:V1.RANDOM) = struct

Expand All @@ -35,57 +40,54 @@ module Make(IP:V1_LWT.IP)(TM:V1_LWT.TIME)(C:V1.CLOCK)(R:V1.RANDOM) = struct
| `Refused
]

let err_timeout () =
(* Printf.printf "Failed to connect to %s:%d\n%!" *)
(* (Ipaddr.V4.to_string daddr) dport; *)
Lwt.return (`Error `Timeout)

let err_refused () =
(* Printf.printf "Refused connection to %s:%d\n%!" *)
(* (Ipaddr.V4.to_string daddr) dport; *)
Lwt.return (`Error `Refused)

let ok x = Lwt.return (`Ok x)

let error_message = function
| `Unknown msg -> msg
| `Timeout -> "Timeout while attempting to connect"
| `Refused -> "Connection refused"

let id t = Pcb.ip t
let err_rewrite = function
| `Error (`Bad_state _) -> `Error `Refused
| `Ok () as x -> x

let err_raise = function
| `Error (`Bad_state s) -> Lwt.fail (Bad_state s)
| `Ok () -> Lwt.return_unit

let get_dest t = Pcb.get_dest t
let id = Pcb.ip
let get_dest = Pcb.get_dest
let close t = Pcb.close t
let input = Pcb.input

let read t =
(* TODO better error interface in Pcb *)
Pcb.read t >>= function
| None -> return `Eof
| Some t -> return (`Ok t)
| None -> Lwt.return `Eof
| Some t -> Lwt.return (`Ok t)

let write t view =
Pcb.write t view >>= fun () ->
return (`Ok ())

let writev t views =
Pcb.writev t views >>= fun () ->
return (`Ok ())

let write_nodelay t view =
Pcb.write_nodelay t view

let writev_nodelay t views =
Pcb.writev_nodelay t views

let close t =
Pcb.close t
let write t view = Pcb.write t view >|= err_rewrite
let writev t views = Pcb.writev t views >|= err_rewrite
let write_nodelay t view = Pcb.write_nodelay t view >>= err_raise
let writev_nodelay t views = Pcb.writev_nodelay t views >>= err_raise
let connect ipv4 = ok (Pcb.create ipv4)
let disconnect _ = Lwt.return_unit

let create_connection tcp (daddr, dport) =
Pcb.connect tcp ~dest_ip:daddr ~dest_port:dport >>= function
| `Timeout ->
(* Printf.printf "Failed to connect to %s:%d\n%!" *)
(* (Ipaddr.V4.to_string daddr) dport; *)
return (`Error `Timeout)
| `Rst ->
(* Printf.printf "Refused connection to %s:%d\n%!" *)
(* (Ipaddr.V4.to_string daddr) dport; *)
return (`Error `Refused)
| `Ok (fl, _) ->
return (`Ok fl)

let input t ~listeners ~src ~dst buf =
Pcb.input t ~listeners ~src ~dst buf

let connect ipv4 =
return (`Ok (Pcb.create ipv4))

let disconnect _ =
return_unit
| `Timeout -> err_timeout ()
| `Rst -> err_refused ()
| `Ok (fl, _) -> ok fl


end
1 change: 1 addition & 0 deletions tcp/flow.mli
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*)

exception Bad_state of State.tcpstate

module Make (IP:V1_LWT.IP)(TM:V1_LWT.TIME)(C:V1.CLOCK)(R:V1.RANDOM) : sig
include V1_LWT.TCP
Expand Down
37 changes: 25 additions & 12 deletions tcp/pcb.ml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,24 @@
open Lwt
open Printf

type error = [`Bad_state of State.tcpstate]

type 'a result = [`Ok of 'a | `Error of error]
let ok x = Lwt.return (`Ok x)
let error s = Lwt.return (`Error (`Bad_state s))

let (>+=) x f =
x >>= function
| `Ok x -> f x
| `Error _ as e -> Lwt.return e

let iter_s f l =
let rec aux = function
| [] -> ok ()
| h::t -> f h >+= fun () -> aux t
in
aux l

module Tcp_wire = Wire_structs.Tcp_wire

cstruct pseudo_header {
Expand Down Expand Up @@ -486,28 +504,23 @@ struct
| av_len when av_len < len ->
let first_bit = Cstruct.sub data 0 av_len in
let remaing_bit = Cstruct.sub data av_len (len - av_len) in
writefn pcb wfn first_bit >>= fun () ->
writefn pcb wfn first_bit >+= fun () ->
writefn pcb wfn remaing_bit
| _ ->
match State.state pcb.state with
| State.Established | State.Close_wait -> wfn [data]
(* URG_TODO: return error instead of dropping silently *)
| _ -> return_unit
| State.Established | State.Close_wait -> wfn [data] >>= ok
| e -> error e

(* Blocking write on a PCB *)
let write pcb data = writefn pcb (UTX.write pcb.utx) data
let writev pcb data = Lwt_list.iter_s (fun d -> write pcb d) data

let writev pcb data = iter_s (write pcb) data
let write_nodelay pcb data = writefn pcb (UTX.write_nodelay pcb.utx) data
let writev_nodelay pcb data =
Lwt_list.iter_s (fun d -> write_nodelay pcb d) data
let writev_nodelay pcb data = iter_s (write_nodelay pcb) data

(* Close - no more will be written *)
let close pcb =
Tx.close pcb
let close pcb = Tx.close pcb

let get_dest pcb =
pcb.id.WIRE.dest_ip, pcb.id.WIRE.dest_port
let get_dest pcb = pcb.id.WIRE.dest_ip, pcb.id.WIRE.dest_port

let getid t dest_ip dest_port =
(* TODO: make this more robust and recognise when all ports are gone *)
Expand Down
13 changes: 8 additions & 5 deletions tcp/pcb.mli
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*)

type error = [`Bad_state of State.tcpstate]

type 'a result = [`Ok of 'a | `Error of error]

module Make(Ip:V1_LWT.IP)(Time:V1_LWT.TIME)(Clock:V1.CLOCK)(Random:V1.RANDOM) : sig

(** Overall state of the TCP stack *)
type t

type pcb

(** State for an individual connection *)
Expand Down Expand Up @@ -47,13 +50,13 @@ module Make(Ip:V1_LWT.IP)(Time:V1_LWT.TIME)(Clock:V1.CLOCK)(Random:V1.RANDOM) :
val write_wait_for : pcb -> int -> unit Lwt.t

(* write - blocks if the write buffer is full *)
val write: pcb -> Cstruct.t -> unit Lwt.t
val writev: pcb -> Cstruct.t list -> unit Lwt.t
val write: pcb -> Cstruct.t -> unit result Lwt.t
val writev: pcb -> Cstruct.t list -> unit result Lwt.t

(* version of write with Nagle disabled - will block if write
buffer is full *)
val write_nodelay: pcb -> Cstruct.t -> unit Lwt.t
val writev_nodelay: pcb -> Cstruct.t list -> unit Lwt.t
val write_nodelay: pcb -> Cstruct.t -> unit result Lwt.t
val writev_nodelay: pcb -> Cstruct.t list -> unit result Lwt.t

val create: Ip.t -> t
(* val tcpstats: t -> unit *)
Expand Down
59 changes: 31 additions & 28 deletions tcp/segment.ml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,18 @@ let peek_opt_l seq =
let _ = Lwt_sequence.add_l s seq in
Some s

let peek_l seq =
match Lwt_sequence.take_opt_l seq with
| None -> assert false
| Some s ->
let _ = Lwt_sequence.add_l s seq in
s

let rec reset_seq segs =
match Lwt_sequence.take_opt_l segs with
| None -> ()
| Some _ -> reset_seq segs

(* The receive queue stores out-of-order segments, and can
coalesece them on input and pass on an ordered list up the
stack to the application.
Expand Down Expand Up @@ -105,21 +117,20 @@ module Rx(Time:V1_LWT.TIME) = struct
let input (q:t) seg =
(* Check that the segment fits into the valid receive window *)
let force_ack = ref false in
(* TODO check that this test for a valid RST is valid *)
if (seg.rst && (Window.valid q.wnd seg.sequence)) then begin
StateTick.tick q.state State.Recv_rst;
(* Dump all the received but out of order frames *)
q.segs <- S.empty;
(* Signal TX side *)
let txalert ack_svcd = match ack_svcd with
| true -> Lwt_mvar.put q.tx_ack ((Window.ack_seq q.wnd), (Window.ack_win q.wnd))
| false -> return_unit
in
txalert (Window.ack_serviced q.wnd) >>= fun () ->
(* Use the fin path to inform the application of end of stream *)
Lwt_mvar.put q.rx_data (None, Some 0)
end else if not (Window.valid q.wnd seg.sequence) then return_unit
else
if not (Window.valid q.wnd seg.sequence) then Lwt.return_unit
else if seg.rst then (
StateTick.tick q.state State.Recv_rst;
(* Dump all the received but out of order frames *)
q.segs <- S.empty;
(* Signal TX side *)
let txalert ack_svcd =
if not ack_svcd then Lwt.return_unit
else Lwt_mvar.put q.tx_ack (Window.ack_seq q.wnd, Window.ack_win q.wnd)
in
txalert (Window.ack_serviced q.wnd) >>= fun () ->
(* Use the fin path to inform the application of end of stream *)
Lwt_mvar.put q.rx_data (None, Some 0)
) else
(* Insert the latest segment *)
let segs = S.add seg q.segs in
(* Walk through the set and get a list of contiguous segments *)
Expand Down Expand Up @@ -281,13 +292,6 @@ module Tx (Time:V1_LWT.TIME) (Clock:V1.CLOCK) = struct
| _ ->
Tcptimer.Stoptimer

let peek_l seq =
match Lwt_sequence.take_opt_l seq with
| None -> assert false
| Some s ->
let _ = Lwt_sequence.add_l s seq in
s

let rto_t q tx_ack =
(* Listen for incoming TX acks from the receive queue and ACK
segments in our retransmission queue *)
Expand Down Expand Up @@ -342,12 +346,11 @@ module Tx (Time:V1_LWT.TIME) (Clock:V1.CLOCK) = struct
let seq = Window.ack_seq q.wnd in
let win = Window.ack_win q.wnd in
begin match State.state q.state with
| State.Reset -> let rec empty_segs segs =
match Lwt_sequence.take_opt_l segs with
| None -> ()
| Some s -> empty_segs segs
in
empty_segs q.segs
| State.Reset ->
(* Note: This is not stricly necessary, as the PCB will be
GCed later on. However, it helps removing pressure on
the GC. *)
reset_seq q.segs
| _ ->
let ack_len = Sequence.sub seq (Window.tx_una q.wnd) in
let dupacktest () =
Expand Down
13 changes: 6 additions & 7 deletions tcp/user_buffer.ml
Original file line number Diff line number Diff line change
Expand Up @@ -321,15 +321,14 @@ module Tx(Time:V1_LWT.TIME)(Clock:V1.CLOCK) = struct
clear_buffer t >>= fun () ->
inform_app t

let rec dump_buffer t =
match Lwt_sequence.is_empty t.buffer with
| true -> return_unit
| false ->
let _ = Lwt_sequence.take_l t.buffer in
dump_buffer t
(* FIXME: duplicated code with Segment.reset_seq *)
let rec reset_seq segs =
match Lwt_sequence.take_opt_l segs with
| None -> ()
| Some s -> reset_seq segs

let reset t =
dump_buffer t >>= fun () ->
reset_seq t.buffer;
inform_app t

end

0 comments on commit 03220b3

Please sign in to comment.