net: defragment IP packets

The defragmenting code is enabled by CONFIG_IP_DEFRAG; the code is
useful for TFTP and NFS transfers.  The user can specify the maximum
defragmented payload as CONFIG_NET_MAXDEFRAG (default 16k).
Since NFS has a bigger per-packet overhead than TFTP, the static
reassembly buffer can hold CONFIG_NET_MAXDEFRAG + the NFS overhead.

The packet buffer is used as an array of "hole" structures, acting as
a double-linked list. Each new fragment can split a hole in two,
reduce a hole or fill a hole. No support is there for a fragment
overlapping two diffrent holes (i.e., thre new fragment is across an
already-received fragment).

Signed-off-by: Alessandro Rubini <rubini@gnudd.com>
Signed-off-by: Ben Warren <biggerbadderben@gmail.com>
diff --git a/net/net.c b/net/net.c
index d1cc9b2..cab4b2d 100644
--- a/net/net.c
+++ b/net/net.c
@@ -1107,6 +1107,176 @@
 }
 #endif
 
+#ifdef CONFIG_IP_DEFRAG
+/*
+ * This function collects fragments in a single packet, according
+ * to the algorithm in RFC815. It returns NULL or the pointer to
+ * a complete packet, in static storage
+ */
+#ifndef CONFIG_NET_MAXDEFRAG
+#define CONFIG_NET_MAXDEFRAG 16384
+#endif
+/*
+ * MAXDEFRAG, above, is chosen in the config file and  is real data
+ * so we need to add the NFS overhead, which is more than TFTP.
+ * To use sizeof in the internal unnamed structures, we need a real
+ * instance (can't do "sizeof(struct rpc_t.u.reply))", unfortunately).
+ * The compiler doesn't complain nor allocates the actual structure
+ */
+static struct rpc_t rpc_specimen;
+#define IP_PKTSIZE (CONFIG_NET_MAXDEFRAG + sizeof(rpc_specimen.u.reply))
+
+#define IP_MAXUDP (IP_PKTSIZE - IP_HDR_SIZE_NO_UDP)
+
+/*
+ * this is the packet being assembled, either data or frag control.
+ * Fragments go by 8 bytes, so this union must be 8 bytes long
+ */
+struct hole {
+	/* first_byte is address of this structure */
+	u16 last_byte;	/* last byte in this hole + 1 (begin of next hole) */
+	u16 next_hole;	/* index of next (in 8-b blocks), 0 == none */
+	u16 prev_hole;	/* index of prev, 0 == none */
+	u16 unused;
+};
+
+static IP_t *__NetDefragment(IP_t *ip, int *lenp)
+{
+	static uchar pkt_buff[IP_PKTSIZE] __attribute__((aligned(PKTALIGN)));
+	static u16 first_hole, total_len;
+	struct hole *payload, *thisfrag, *h, *newh;
+	IP_t *localip = (IP_t *)pkt_buff;
+	uchar *indata = (uchar *)ip;
+	int offset8, start, len, done = 0;
+	u16 ip_off = ntohs(ip->ip_off);
+
+	/* payload starts after IP header, this fragment is in there */
+	payload = (struct hole *)(pkt_buff + IP_HDR_SIZE_NO_UDP);
+	offset8 =  (ip_off & IP_OFFS);
+	thisfrag = payload + offset8;
+	start = offset8 * 8;
+	len = ntohs(ip->ip_len) - IP_HDR_SIZE_NO_UDP;
+
+	if (start + len > IP_MAXUDP) /* fragment extends too far */
+		return NULL;
+
+	if (!total_len || localip->ip_id != ip->ip_id) {
+		/* new (or different) packet, reset structs */
+		total_len = 0xffff;
+		payload[0].last_byte = ~0;
+		payload[0].next_hole = 0;
+		payload[0].prev_hole = 0;
+		first_hole = 0;
+		/* any IP header will work, copy the first we received */
+		memcpy(localip, ip, IP_HDR_SIZE_NO_UDP);
+	}
+
+	/*
+	 * What follows is the reassembly algorithm. We use the payload
+	 * array as a linked list of hole descriptors, as each hole starts
+	 * at a multiple of 8 bytes. However, last byte can be whatever value,
+	 * so it is represented as byte count, not as 8-byte blocks.
+	 */
+
+	h = payload + first_hole;
+	while (h->last_byte < start) {
+		if (!h->next_hole) {
+			/* no hole that far away */
+			return NULL;
+		}
+		h = payload + h->next_hole;
+	}
+
+	if (offset8 + (len / 8) <= h - payload) {
+		/* no overlap with holes (dup fragment?) */
+		return NULL;
+	}
+
+	if (!(ip_off & IP_FLAGS_MFRAG)) {
+		/* no more fragmentss: truncate this (last) hole */
+		total_len = start + len;
+		h->last_byte = start + len;
+	}
+
+	/*
+	 * There is some overlap: fix the hole list. This code doesn't
+	 * deal with a fragment that overlaps with two different holes
+	 * (thus being a superset of a previously-received fragment).
+	 */
+
+	if ( (h >= thisfrag) && (h->last_byte <= start + len) ) {
+		/* complete overlap with hole: remove hole */
+		if (!h->prev_hole && !h->next_hole) {
+			/* last remaining hole */
+			done = 1;
+		} else if (!h->prev_hole) {
+			/* first hole */
+			first_hole = h->next_hole;
+			payload[h->next_hole].prev_hole = 0;
+		} else if (!h->next_hole) {
+			/* last hole */
+			payload[h->prev_hole].next_hole = 0;
+		} else {
+			/* in the middle of the list */
+			payload[h->next_hole].prev_hole = h->prev_hole;
+			payload[h->prev_hole].next_hole = h->next_hole;
+		}
+
+	} else if (h->last_byte <= start + len) {
+		/* overlaps with final part of the hole: shorten this hole */
+		h->last_byte = start;
+
+	} else if (h >= thisfrag) {
+		/* overlaps with initial part of the hole: move this hole */
+		newh = thisfrag + (len / 8);
+		*newh = *h;
+		h = newh;
+		if (h->next_hole)
+			payload[h->next_hole].prev_hole = (h - payload);
+		if (h->prev_hole)
+			payload[h->prev_hole].next_hole = (h - payload);
+		else
+			first_hole = (h - payload);
+
+	} else {
+		/* fragment sits in the middle: split the hole */
+		newh = thisfrag + (len / 8);
+		*newh = *h;
+		h->last_byte = start;
+		h->next_hole = (newh - payload);
+		newh->prev_hole = (h - payload);
+		if (newh->next_hole)
+			payload[newh->next_hole].prev_hole = (newh - payload);
+	}
+
+	/* finally copy this fragment and possibly return whole packet */
+	memcpy((uchar *)thisfrag, indata + IP_HDR_SIZE_NO_UDP, len);
+	if (!done)
+		return NULL;
+
+	localip->ip_len = htons(total_len);
+	*lenp = total_len + IP_HDR_SIZE_NO_UDP;
+	return localip;
+}
+
+static inline IP_t *NetDefragment(IP_t *ip, int *lenp)
+{
+	u16 ip_off = ntohs(ip->ip_off);
+	if (!(ip_off & (IP_OFFS | IP_FLAGS_MFRAG)))
+		return ip; /* not a fragment */
+	return __NetDefragment(ip, lenp);
+}
+
+#else /* !CONFIG_IP_DEFRAG */
+
+static inline IP_t *NetDefragment(IP_t *ip, int *lenp)
+{
+	u16 ip_off = ntohs(ip->ip_off);
+	if (!(ip_off & (IP_OFFS | IP_FLAGS_MFRAG)))
+		return ip; /* not a fragment */
+	return NULL;
+}
+#endif
 
 void
 NetReceive(volatile uchar * inpkt, int len)
@@ -1333,10 +1503,12 @@
 
 	case PROT_IP:
 		debug("Got IP\n");
+		/* Before we start poking the header, make sure it is there */
 		if (len < IP_HDR_SIZE) {
 			debug("len bad %d < %lu\n", len, (ulong)IP_HDR_SIZE);
 			return;
 		}
+		/* Check the packet length */
 		if (len < ntohs(ip->ip_len)) {
 			printf("len bad %d < %d\n", len, ntohs(ip->ip_len));
 			return;
@@ -1344,21 +1516,20 @@
 		len = ntohs(ip->ip_len);
 		debug("len=%d, v=%02x\n", len, ip->ip_hl_v & 0xff);
 
+		/* Can't deal with anything except IPv4 */
 		if ((ip->ip_hl_v & 0xf0) != 0x40) {
 			return;
 		}
-		/* Can't deal with fragments */
-		if (ip->ip_off & htons(IP_OFFS | IP_FLAGS_MFRAG)) {
-			return;
-		}
-		/* can't deal with headers > 20 bytes */
+		/* Can't deal with IP options (headers != 20 bytes) */
 		if ((ip->ip_hl_v & 0x0f) > 0x05) {
 			return;
 		}
+		/* Check the Checksum of the header */
 		if (!NetCksumOk((uchar *)ip, IP_HDR_SIZE_NO_UDP / 2)) {
 			puts ("checksum bad\n");
 			return;
 		}
+		/* If it is not for us, ignore it */
 		tmp = NetReadIP(&ip->ip_dst);
 		if (NetOurIP && tmp != NetOurIP && tmp != 0xFFFFFFFF) {
 #ifdef CONFIG_MCAST_TFTP
@@ -1367,6 +1538,13 @@
 			return;
 		}
 		/*
+		 * The function returns the unchanged packet if it's not
+		 * a fragment, and either the complete packet or NULL if
+		 * it is a fragment (if !CONFIG_IP_DEFRAG, it returns NULL)
+		 */
+		if (!(ip = NetDefragment(ip, &len)))
+			return;
+		/*
 		 * watch for ICMP host redirects
 		 *
 		 * There is no real handler code (yet). We just watch