Browse Source

Third round of review changes from Graeme

Sentryo R&D 3 years ago
parent
commit
fa035f329e
7 changed files with 117 additions and 133 deletions
  1. 2 4
      decode.go
  2. 1 1
      dumpcommand/tcpdump.go
  3. 1 1
      examples/reassemblydump/main.go
  4. 1 1
      layers/tcp.go
  5. 15 26
      packet.go
  6. 92 100
      reassembly/tcpassembly.go
  7. 5 0
      reassembly/tcpcheck.go

+ 2 - 4
decode.go

@@ -67,10 +67,8 @@ type PacketBuilder interface {
 	// data will be dumped to stderr so you can create a test.  This should never
 	// be called from a production decoder.
 	DumpPacketData()
-	// DecodeApplicationLayers returns true if the packet was
-	// created with a decode option to enable application-level
-	// decoding (currently only checked by the TCP decoder).
-	DecodeApplicationLayers() bool
+	// DecodeOptions returns the decode options
+	DecodeOptions() *DecodeOptions
 }
 
 // Decoder is an interface for logic to decode a packet layer.  Users may

+ 1 - 1
dumpcommand/tcpdump.go

@@ -44,7 +44,7 @@ func Run(src gopacket.PacketDataSource) {
 	source := gopacket.NewPacketSource(src, dec)
 	source.Lazy = *lazy
 	source.NoCopy = true
-	source.DecodeApplicationLayers = true
+	source.DecodeStreamsAsDatagrams = true
 	fmt.Fprintln(os.Stderr, "Starting to read packets")
 	count := 0
 	bytes := int64(0)

+ 1 - 1
examples/reassemblydump/main.go

@@ -588,7 +588,7 @@ func main() {
 		}
 		if count%*statsevery == 0 {
 			ref := packet.Metadata().CaptureInfo.Timestamp
-			flushed, closed := assembler.FlushCloseOlderThan(ref.Add(-timeout), ref.Add(-closeTimeout))
+			flushed, closed := assembler.FlushWithOptions(reassembly.FlushOptions{T: ref.Add(-timeout), TC: ref.Add(-closeTimeout)})
 			Debug("Forced flush: %d flushed, %d closed (%s)", flushed, closed, ref)
 		}
 

+ 1 - 1
layers/tcp.go

@@ -307,7 +307,7 @@ func decodeTCP(data []byte, p gopacket.PacketBuilder) error {
 	if err != nil {
 		return err
 	}
-	if p.DecodeApplicationLayers() {
+	if p.DecodeOptions().DecodeStreamsAsDatagrams {
 		return p.NextDecoder(tcp.NextLayerType())
 	} else {
 		return p.NextDecoder(gopacket.LayerTypePayload)

+ 15 - 26
packet.go

@@ -112,15 +112,7 @@ type packet struct {
 	// metadata is the PacketMetadata for this packet
 	metadata PacketMetadata
 
-	// recoverPanics is true if we should recover from panics we see while
-	// decoding and set a DecodeFailure layer.
-	recoverPanics bool
-
-	// decodeApplicationLayers is true if we should try go decode
-	// layers after TCP in single packets. This is disabled by
-	// default because the reassembly package drives the decoding
-	// of TCP payload data after reassembly.
-	decodeApplicationLayers bool
+	decodeOptions DecodeOptions
 
 	// Pointers to the various important layers
 	link        LinkLayer
@@ -185,6 +177,10 @@ func (p *packet) Data() []byte {
 	return p.data
 }
 
+func (p *packet) DecodeOptions() *DecodeOptions {
+	return &p.decodeOptions
+}
+
 func (p *packet) addFinalDecodeError(err error, stack []byte) {
 	fail := &DecodeFailure{err: err, stack: stack}
 	if p.last == nil {
@@ -197,7 +193,7 @@ func (p *packet) addFinalDecodeError(err error, stack []byte) {
 }
 
 func (p *packet) recoverDecodeError() {
-	if p.recoverPanics {
+	if !p.decodeOptions.SkipDecodeRecovery {
 		if r := recover(); r != nil {
 			p.addFinalDecodeError(fmt.Errorf("%v", r), debug.Stack())
 		}
@@ -495,8 +491,6 @@ func (p *eagerPacket) LayerClass(lc LayerClass) Layer {
 func (p *eagerPacket) String() string { return p.packetString() }
 func (p *eagerPacket) Dump() string   { return p.packetDump() }
 
-func (p *eagerPacket) DecodeApplicationLayers() bool { return p.decodeApplicationLayers }
-
 // lazyPacket does lazy decoding on its packet data.  On construction it does
 // no initial decoding.  For each function call, it decodes only as many layers
 // as are necessary to compute the return value for that function.
@@ -609,8 +603,6 @@ func (p *lazyPacket) LayerClass(lc LayerClass) Layer {
 func (p *lazyPacket) String() string { p.Layers(); return p.packetString() }
 func (p *lazyPacket) Dump() string   { p.Layers(); return p.packetDump() }
 
-func (p *lazyPacket) DecodeApplicationLayers() bool { return false }
-
 // DecodeOptions tells gopacket how to decode a packet.
 type DecodeOptions struct {
 	// Lazy decoding decodes the minimum number of layers needed to return data
@@ -630,11 +622,11 @@ type DecodeOptions struct {
 	// the issue.  If this flag is set, panics are instead allowed to continue up
 	// the stack.
 	SkipDecodeRecovery bool
-	// DecodeApplicationLayers enables routing of
-	// application-level layers in the TCP decoder.  This is
-	// disabled by default as decoding is driven by the reassembly
-	// package, on reassembled TCP payload data.
-	DecodeApplicationLayers bool
+	// DecodeStreamsAsDatagrams enables routing of application-level layers in the TCP
+	// decoder. If true, we should try to decode layers after TCP in single packets.
+	// This is disabled by default because the reassembly package drives the decoding
+	// of TCP payload data after reassembly.
+	DecodeStreamsAsDatagrams bool
 }
 
 // Default decoding provides the safest (but slowest) method for decoding
@@ -652,8 +644,8 @@ var Lazy = DecodeOptions{Lazy: true}
 // NoCopy is a DecodeOptions with just NoCopy set.
 var NoCopy = DecodeOptions{NoCopy: true}
 
-// DecodeApplicationLayers is a DecodeOptions with just DecodeApplicationLayers set.
-var DecodeApplicationLayers = DecodeOptions{DecodeApplicationLayers: true}
+// DecodeStreamsAsDatagrams is a DecodeOptions with just DecodeStreamsAsDatagrams set.
+var DecodeStreamsAsDatagrams = DecodeOptions{DecodeStreamsAsDatagrams: true}
 
 // NewPacket creates a new Packet object from a set of bytes.  The
 // firstLayerDecoder tells it how to interpret the first layer from the bytes,
@@ -666,11 +658,10 @@ func NewPacket(data []byte, firstLayerDecoder Decoder, options DecodeOptions) Pa
 	}
 	if options.Lazy {
 		p := &lazyPacket{
-			packet: packet{data: data},
+			packet: packet{data: data, decodeOptions: options},
 			next:   firstLayerDecoder,
 		}
 		p.layers = p.initialLayers[:0]
-		p.recoverPanics = !options.SkipDecodeRecovery
 		// Crazy craziness:
 		// If the following return statemet is REMOVED, and Lazy is FALSE, then
 		// eager packet processing becomes 17% FASTER.  No, there is no logical
@@ -684,11 +675,9 @@ func NewPacket(data []byte, firstLayerDecoder Decoder, options DecodeOptions) Pa
 		return p
 	}
 	p := &eagerPacket{
-		packet: packet{data: data},
+		packet: packet{data: data, decodeOptions: options},
 	}
 	p.layers = p.initialLayers[:0]
-	p.recoverPanics = !options.SkipDecodeRecovery
-	p.decodeApplicationLayers = options.DecodeApplicationLayers
 	p.initialDecode(firstLayerDecoder)
 	return p
 }

+ 92 - 100
reassembly/tcpassembly.go

@@ -106,16 +106,16 @@ type ScatterGather interface {
 
 // byteContainer is either a page or a livePacket
 type byteContainer interface {
-	Bytes() []byte
-	Length() int
-	ConvertToPages(*pageCache, int, AssemblerContext) (*page, *page, int)
-	CaptureInfo() gopacket.CaptureInfo
-	AssemblerContext() AssemblerContext
-	Release(*pageCache) int
-	Start() bool
-	End() bool
-	Seq() Sequence
-	IsPacket() bool
+	getBytes() []byte
+	length() int
+	convertToPages(*pageCache, int, AssemblerContext) (*page, *page, int)
+	captureInfo() gopacket.CaptureInfo
+	assemblerContext() AssemblerContext
+	release(*pageCache) int
+	isStart() bool
+	isEnd() bool
+	getSeq() Sequence
+	isPacket() bool
 }
 
 // Implements a ScatterGather
@@ -135,18 +135,18 @@ type reassemblyObject struct {
 func (rl *reassemblyObject) Lengths() (int, int) {
 	l := 0
 	for _, r := range rl.all {
-		l += r.Length()
+		l += r.length()
 	}
 	return l, rl.saved
 }
 
 func (rl *reassemblyObject) Fetch(l int) []byte {
-	if l <= rl.all[0].Length() {
-		return rl.all[0].Bytes()[:l]
+	if l <= rl.all[0].length() {
+		return rl.all[0].getBytes()[:l]
 	}
 	bytes := make([]byte, 0, l)
 	for _, bc := range rl.all {
-		bytes = append(bytes, bc.Bytes()...)
+		bytes = append(bytes, bc.getBytes()...)
 	}
 	return bytes[:l]
 }
@@ -159,22 +159,22 @@ func (rl *reassemblyObject) CaptureInfo(offset int) gopacket.CaptureInfo {
 	current := 0
 	for _, r := range rl.all {
 		if current >= offset {
-			return r.CaptureInfo()
+			return r.captureInfo()
 		}
-		current += r.Length()
+		current += r.length()
 	}
 	// Invalid offset
 	return gopacket.CaptureInfo{}
 }
 
 func (rl *reassemblyObject) Info() (TCPFlowDirection, bool, bool, int) {
-	return rl.Direction, rl.all[0].Start(), rl.all[len(rl.all)-1].End(), rl.Skip
+	return rl.Direction, rl.all[0].isStart(), rl.all[len(rl.all)-1].isEnd(), rl.Skip
 }
 
 func (rl *reassemblyObject) Stats() TCPAssemblyStats {
 	packets := int(0)
 	for _, r := range rl.all {
-		if r.IsPacket() {
+		if r.isPacket() {
 			packets++
 		}
 	}
@@ -226,27 +226,25 @@ func (dir TCPFlowDirection) Reverse() TCPFlowDirection {
 // avoids memory allocation.  Used pages are stored in a doubly-linked list in
 // a connection.
 type page struct {
-	bytes []byte
-	seq   Sequence
-	//index      int
+	bytes      []byte
+	seq        Sequence
 	prev, next *page
 	buf        [pageBytes]byte
-	ac         AssemblerContext
+	ac         AssemblerContext // only set for the first page of a packet
 	seen       time.Time
 	start, end bool
-	isPacket   bool // only the first page of a packet has this being set
 }
 
-func (p *page) Bytes() []byte {
+func (p *page) getBytes() []byte {
 	return p.bytes
 }
-func (p *page) CaptureInfo() gopacket.CaptureInfo {
+func (p *page) captureInfo() gopacket.CaptureInfo {
 	return p.ac.GetCaptureInfo()
 }
-func (p *page) AssemblerContext() AssemblerContext {
+func (p *page) assemblerContext() AssemblerContext {
 	return p.ac
 }
-func (p *page) ConvertToPages(pc *pageCache, skip int, ac AssemblerContext) (*page, *page, int) {
+func (p *page) convertToPages(pc *pageCache, skip int, ac AssemblerContext) (*page, *page, int) {
 	if skip != 0 {
 		p.bytes = p.bytes[skip:]
 		p.seq = p.seq.Add(skip)
@@ -254,27 +252,27 @@ func (p *page) ConvertToPages(pc *pageCache, skip int, ac AssemblerContext) (*pa
 	p.prev, p.next = nil, nil
 	return p, p, 1
 }
-func (p *page) Length() int {
+func (p *page) length() int {
 	return len(p.bytes)
 }
-func (p *page) Release(pc *pageCache) int {
+func (p *page) release(pc *pageCache) int {
 	pc.replace(p)
 	return 1
 }
-func (p *page) Start() bool {
+func (p *page) isStart() bool {
 	return p.start
 }
-func (p *page) End() bool {
+func (p *page) isEnd() bool {
 	return p.end
 }
-func (p *page) Seq() Sequence {
+func (p *page) getSeq() Sequence {
 	return p.seq
 }
-func (p *page) IsPacket() bool {
-	return p.isPacket
+func (p *page) isPacket() bool {
+	return p.ac != nil
 }
 func (p *page) String() string {
-	return fmt.Sprintf("[email protected]%p{seq: %v, Bytes:%d, -> nextSeq:%v} (prev:%p, next:%p)", p, p.seq, len(p.bytes), p.seq+Sequence(len(p.bytes)), p.prev, p.next)
+	return fmt.Sprintf("[email protected]%p{seq: %v, bytes:%d, -> nextSeq:%v} (prev:%p, next:%p)", p, p.seq, len(p.bytes), p.seq+Sequence(len(p.bytes)), p.prev, p.next)
 }
 
 /* livePacket: implements a byteContainer */
@@ -287,39 +285,39 @@ type livePacket struct {
 	seq   Sequence
 }
 
-func (lp *livePacket) Bytes() []byte {
+func (lp *livePacket) getBytes() []byte {
 	return lp.bytes
 }
-func (lp *livePacket) CaptureInfo() gopacket.CaptureInfo {
+func (lp *livePacket) captureInfo() gopacket.CaptureInfo {
 	return lp.ci
 }
-func (lp *livePacket) AssemblerContext() AssemblerContext {
+func (lp *livePacket) assemblerContext() AssemblerContext {
 	return lp.ac
 }
-func (lp *livePacket) Length() int {
+func (lp *livePacket) length() int {
 	return len(lp.bytes)
 }
-func (lp *livePacket) Start() bool {
+func (lp *livePacket) isStart() bool {
 	return lp.start
 }
-func (lp *livePacket) End() bool {
+func (lp *livePacket) isEnd() bool {
 	return lp.end
 }
-func (lp *livePacket) Seq() Sequence {
+func (lp *livePacket) getSeq() Sequence {
 	return lp.seq
 }
-func (lp *livePacket) IsPacket() bool {
+func (lp *livePacket) isPacket() bool {
 	return true
 }
 
 // Creates a page (or set of pages) from a TCP packet: returns the first and last
 // page in its doubly-linked list of new pages.
-func (lp *livePacket) ConvertToPages(pc *pageCache, skip int, ac AssemblerContext) (*page, *page, int) {
+func (lp *livePacket) convertToPages(pc *pageCache, skip int, ac AssemblerContext) (*page, *page, int) {
 	ts := lp.ci.Timestamp
 	first := pc.next(ts)
 	current := first
 	current.prev = nil
-	first.isPacket = true
+	first.ac = ac
 	numPages := 1
 	seq, bytes := lp.seq.Add(skip), lp.bytes[skip:]
 	for {
@@ -327,10 +325,9 @@ func (lp *livePacket) ConvertToPages(pc *pageCache, skip int, ac AssemblerContex
 		current.bytes = current.buf[:length]
 		copy(current.bytes, bytes)
 		current.seq = seq
-		current.ac = ac
 		bytes = bytes[length:]
 		if len(bytes) == 0 {
-			current.end = lp.End()
+			current.end = lp.isEnd()
 			current.next = nil
 			break
 		}
@@ -338,6 +335,7 @@ func (lp *livePacket) ConvertToPages(pc *pageCache, skip int, ac AssemblerContex
 		current.next = pc.next(ts)
 		current.next.prev = current
 		current = current.next
+		current.ac = nil
 		numPages++
 	}
 	return first, current, numPages
@@ -346,7 +344,7 @@ func (lp *livePacket) estimateNumberOfPages() int {
 	return (len(lp.bytes) + pageBytes + 1) / pageBytes
 }
 
-func (lp *livePacket) Release(*pageCache) int {
+func (lp *livePacket) release(*pageCache) int {
 	return 0
 }
 
@@ -451,31 +449,15 @@ type connection struct {
 
 func (c *connection) reset(k key, s Stream, ts time.Time) {
 	c.key = k
-	c.c2s.pages = 0
-	c.s2c.pages = 0
-	c.c2s.first, c.c2s.last = nil, nil
-	c.s2c.first, c.s2c.last = nil, nil
-	c.c2s.nextSeq = invalidSequence
-	c.s2c.nextSeq = invalidSequence
-	c.c2s.ackSeq = invalidSequence
-	c.s2c.ackSeq = invalidSequence
-	c.c2s.created = ts
-	c.s2c.created = ts
-	c.s2c.lastSeen = ts
-	c.c2s.stream = s
-	c.s2c.stream = s
-	c.c2s.closed = false
-	c.s2c.closed = false
-	c.c2s.dir = TCPDirClientToServer
-	c.s2c.dir = TCPDirServerToClient
-	c.c2s.queuedBytes = 0
-	c.s2c.queuedBytes = 0
-	c.c2s.queuedPackets = 0
-	c.s2c.queuedPackets = 0
-	c.c2s.overlapBytes = 0
-	c.s2c.overlapBytes = 0
-	c.c2s.overlapPackets = 0
-	c.s2c.overlapPackets = 0
+	base := halfconnection{
+		nextSeq:  invalidSequence,
+		ackSeq:   invalidSequence,
+		created:  ts,
+		lastSeen: ts,
+		stream:   s,
+	}
+	c.c2s, c.s2c = base, base
+	c.c2s.dir, c.s2c.dir = TCPDirClientToServer, TCPDirServerToClient
 }
 
 func (c *connection) String() string {
@@ -790,7 +772,7 @@ func (a *Assembler) checkOverlap(half *halfconnection, queue bool, ac AssemblerC
 			if *debugLog {
 				log.Printf("case 3\n")
 			}
-			if cur.isPacket {
+			if cur.isPacket() {
 				half.overlapPackets++
 			}
 			half.overlapBytes += len(cur.bytes)
@@ -806,7 +788,7 @@ func (a *Assembler) checkOverlap(half *halfconnection, queue bool, ac AssemblerC
 				half.last = cur.prev
 			}
 			tmp := cur.prev
-			half.pages -= cur.Release(a.pc)
+			half.pages -= cur.release(a.pc)
 			cur = tmp
 			continue
 		}
@@ -850,7 +832,7 @@ func (a *Assembler) checkOverlap(half *halfconnection, queue bool, ac AssemblerC
 	a.cacheLP.bytes = bytes
 	a.cacheLP.seq = start
 	if len(bytes) > 0 && queue {
-		p, p2, numPages := a.cacheLP.ConvertToPages(a.pc, 0, ac)
+		p, p2, numPages := a.cacheLP.convertToPages(a.pc, 0, ac)
 		half.queuedPackets++
 		half.queuedBytes += len(bytes)
 		half.pages += numPages
@@ -909,17 +891,17 @@ func (a *Assembler) dump(text string, half *halfconnection) {
 		log.Printf(" * half.saved = %p\n", half.saved)
 		p = half.saved
 		for p != nil {
-			log.Printf("\tseq:%d %s bytes:%s\n", p.Seq(), p, hex.EncodeToString(p.bytes))
+			log.Printf("\tseq:%d %s bytes:%s\n", p.getSeq(), p, hex.EncodeToString(p.bytes))
 			p = p.next
 		}
 	}
 	log.Printf(" * a.ret\n")
 	for i, r := range a.ret {
-		log.Printf("\t%d: %s b:%s\n", i, r.CaptureInfo(), hex.EncodeToString(r.Bytes()))
+		log.Printf("\t%d: %s b:%s\n", i, r.captureInfo(), hex.EncodeToString(r.getBytes()))
 	}
 	log.Printf(" * a.cacheSG.all\n")
 	for i, r := range a.cacheSG.all {
-		log.Printf("\t%d: %s b:%s\n", i, r.CaptureInfo(), hex.EncodeToString(r.Bytes()))
+		log.Printf("\t%d: %s b:%s\n", i, r.captureInfo(), hex.EncodeToString(r.getBytes()))
 	}
 }
 
@@ -1001,11 +983,11 @@ func (a *Assembler) buildSG(half *halfconnection) (bool, Sequence) {
 	// find if there are skipped bytes
 	skip := -1
 	if half.nextSeq != invalidSequence {
-		skip = half.nextSeq.Difference(a.ret[0].Seq())
+		skip = half.nextSeq.Difference(a.ret[0].getSeq())
 	}
-	last := a.ret[0].Seq().Add(a.ret[0].Length())
+	last := a.ret[0].getSeq().Add(a.ret[0].length())
 	// Prepend saved bytes
-	saved := a.addPending(half, a.ret[0].Seq())
+	saved := a.addPending(half, a.ret[0].getSeq())
 	// Append continuous bytes
 	nextSeq := a.addContiguous(half, last)
 	a.cacheSG.all = a.ret
@@ -1015,7 +997,7 @@ func (a *Assembler) buildSG(half *halfconnection) (bool, Sequence) {
 	a.cacheSG.toKeep = -1
 	a.setStatsToSG(half)
 	a.dump("after buildSG", half)
-	return a.ret[len(a.ret)-1].End(), nextSeq
+	return a.ret[len(a.ret)-1].isEnd(), nextSeq
 }
 
 func (a *Assembler) cleanSG(half *halfconnection, ac AssemblerContext) {
@@ -1033,13 +1015,13 @@ func (a *Assembler) cleanSG(half *halfconnection, ac AssemblerContext) {
 		skip = a.cacheSG.toKeep
 		found := false
 		for ndx, r = range a.cacheSG.all {
-			if a.cacheSG.toKeep < cur+r.Length() {
+			if a.cacheSG.toKeep < cur+r.length() {
 				found = true
 				break
 			}
-			cur += r.Length()
-			if skip >= r.Length() {
-				skip -= r.Length()
+			cur += r.length()
+			if skip >= r.length() {
+				skip -= r.length()
 			}
 		}
 		if !found {
@@ -1063,7 +1045,7 @@ func (a *Assembler) cleanSG(half *halfconnection, ac AssemblerContext) {
 				half.first = half.first.next
 			}
 		}
-		half.pages -= r.Release(a.pc)
+		half.pages -= r.release(a.pc)
 	}
 	a.dump("after consumed release", half)
 	// Keep un-consumed pages
@@ -1071,7 +1053,7 @@ func (a *Assembler) cleanSG(half *halfconnection, ac AssemblerContext) {
 	half.saved = nil
 	var saved *page
 	for _, r := range a.cacheSG.all[ndx:] {
-		first, last, nb := r.ConvertToPages(a.pc, skip, ac)
+		first, last, nb := r.convertToPages(a.pc, skip, ac)
 		if half.saved == nil {
 			half.saved = first
 		} else {
@@ -1125,7 +1107,7 @@ func (a *Assembler) addPending(half *halfconnection, firstSeq Sequence) int {
 		var next *page
 		for p := half.saved; p != nil; p = next {
 			next = p.next
-			p.Release(a.pc)
+			p.release(a.pc)
 		}
 		half.saved = nil
 		ret = []byteContainer{}
@@ -1181,7 +1163,7 @@ func (a *Assembler) skipFlush(conn *connection, half *halfconnection) {
 	}
 	a.ret = a.ret[:0]
 	a.addNextFromConn(half)
-	nextSeq := a.sendToConnection(conn, half, a.ret[0].AssemblerContext())
+	nextSeq := a.sendToConnection(conn, half, a.ret[0].assemblerContext())
 	if nextSeq != invalidSequence {
 		half.nextSeq = nextSeq
 	}
@@ -1222,10 +1204,19 @@ func (a *Assembler) addNextFromConn(conn *halfconnection) {
 	}
 }
 
-// FlushCloseOlderThan finds any streams waiting for packets older than
-// the given time, and pushes through the data they have (IE: tells
+// FlushOptions provide options for flushing connections.
+type FlushOptions struct {
+	T  time.Time // If nonzero, only connections with data older than T are flushed
+	TC time.Time // If nonzero, only connections with data older than TC are closed (if no FIN/RST received)
+}
+
+// FlushWithOptions finds any streams waiting for packets older than
+// the given time T, and pushes through the data they have (IE: tells
 // them to stop waiting and skip the data they're waiting for).
 //
+// It also closes streams older than TC (that can be set to zero, to keep
+// long-lived stream alive, but to flush data anyway).
+//
 // Each Stream maintains a list of zero or more sets of bytes it has received
 // out-of-order.  For example, if it has processed up through sequence number
 // 10, it might have bytes [15-20), [20-25), [30,50) in its list.  Each set of
@@ -1238,13 +1229,9 @@ func (a *Assembler) addNextFromConn(conn *halfconnection) {
 // otherwise it will wait until the next FlushCloseOlderThan to see if bytes
 // [25-30) come in.
 //
-// If it pushes all bytes (or there were no sets of bytes to begin with)
-// AND the connection has not received any bytes since the passed-in time,
-// the connection will be closed.
-//
 // Returns the number of connections flushed, and of those, the number closed
 // because of the flush.
-func (a *Assembler) FlushCloseOlderThan(t time.Time, tc time.Time) (flushed, closed int) {
+func (a *Assembler) FlushWithOptions(opt FlushOptions) (flushed, closed int) {
 	conns := a.connPool.connections()
 	closes := 0
 	flushes := 0
@@ -1252,7 +1239,7 @@ func (a *Assembler) FlushCloseOlderThan(t time.Time, tc time.Time) (flushed, clo
 		remove := false
 		conn.mu.Lock()
 		for _, half := range []*halfconnection{&conn.s2c, &conn.c2s} {
-			flushed, closed := a.flushClose(conn, half, t, tc)
+			flushed, closed := a.flushClose(conn, half, opt.T, opt.TC)
 			if flushed {
 				flushes++
 			}
@@ -1260,7 +1247,7 @@ func (a *Assembler) FlushCloseOlderThan(t time.Time, tc time.Time) (flushed, clo
 				closes++
 			}
 		}
-		if conn.s2c.closed && conn.c2s.closed && conn.s2c.lastSeen.Before(tc) && conn.c2s.lastSeen.Before(tc) {
+		if conn.s2c.closed && conn.c2s.closed && conn.s2c.lastSeen.Before(opt.TC) && conn.c2s.lastSeen.Before(opt.TC) {
 			remove = true
 		}
 		conn.mu.Unlock()
@@ -1271,6 +1258,11 @@ func (a *Assembler) FlushCloseOlderThan(t time.Time, tc time.Time) (flushed, clo
 	return flushes, closes
 }
 
+// FlushCloseOlderThan flushes and closes streams older than given time
+func (a *Assembler) FlushCloseOlderThan(t time.Time) (flushed, closed int) {
+	return a.FlushWithOptions(FlushOptions{T: t, TC: t})
+}
+
 func (a *Assembler) flushClose(conn *connection, half *halfconnection, t time.Time, tc time.Time) (bool, bool) {
 	flushed, closed := false, false
 	if half.closed {

+ 5 - 0
reassembly/tcpcheck.go

@@ -107,6 +107,11 @@ func (t *TCPOptionCheck) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir TC
 
 // TCPSimpleFSM implements a very simple TCP state machine
 //
+// Usage:
+// When implementing a Stream interface and to avoid to consider packets that
+// would be rejected due to client/server's TCP stack, the  Accept() can call
+// TCPSimpleFSM.CheckState().
+//
 // Limitations:
 // - packet should be received in-order.
 // - no check on sequence number is performed