patch-2.1.61 linux/fs/smbfs/sock.c

Next file: linux/fs/vfat/namei.c
Previous file: linux/fs/smbfs/proc.c
Back to the patch index
Back to the overall index

diff -u --recursive --new-file v2.1.60/linux/fs/smbfs/sock.c linux/fs/smbfs/sock.c
@@ -7,11 +7,9 @@
  */
 
 #include <linux/sched.h>
-#include <linux/smb_fs.h>
 #include <linux/errno.h>
 #include <linux/socket.h>
 #include <linux/fcntl.h>
-#include <linux/stat.h>
 #include <linux/in.h>
 #include <linux/net.h>
 #include <linux/mm.h>
@@ -19,6 +17,7 @@
 #include <net/scm.h>
 #include <net/ip.h>
 
+#include <linux/smb_fs.h>
 #include <linux/smb.h>
 #include <linux/smbno.h>
 
@@ -97,11 +96,15 @@
 
 	while (1)
 	{
+		result = -EIO;
 		if (sk->dead)
 		{
+#ifdef SMBFS_PARANOIA
 			printk("smb_data_callback: sock dead!\n");
-			return;
+#endif
+			break;
 		}
+
 		result = _recvfrom(socket, (void *) peek_buf, 1,
 				   MSG_PEEK | MSG_DONTWAIT);
 		if (result == -EAGAIN)
@@ -362,6 +365,19 @@
 }

 /*
+ * Packet buffers are allocated in PAGE_SIZE increments, so round a
+ * requested length up to the next multiple of PAGE_SIZE. Lengths
+ * that are already page-aligned (including 0) come back unchanged.
+ */
+int
+smb_round_length(int len)
+{
+	unsigned long page_mask = PAGE_SIZE - 1;
+
+	return (len + page_mask) & ~page_mask;
+}
+
+/*
  * smb_receive
  * fs points to the correct segment
  */
@@ -369,6 +382,7 @@
 smb_receive(struct smb_sb_info *server)
 {
 	struct socket *socket = server_sock(server);
+	unsigned char * packet = server->packet;
 	int len, result;
 	unsigned char peek_buf[4];
 
@@ -383,19 +397,22 @@
 	 */
 	if (len + 4 > server->packet_size)
 	{
-		char * packet;
-		pr_debug("smb_receive: Increase packet size from %d to %d\n",
-			server->packet_size, len + 4);
+		int new_len = smb_round_length(len + 4);
+
+#ifdef SMBFS_PARANOIA
+printk("smb_receive: Increase packet size from %d to %d\n",
+server->packet_size, new_len);
+#endif
 		result = -ENOMEM;
-		packet = smb_vmalloc(len + 4);
+		packet = smb_vmalloc(new_len);
 		if (packet == NULL)
 			goto out;
 		smb_vfree(server->packet);
 		server->packet = packet;
-		server->packet_size = len + 4;
+		server->packet_size = new_len;
 	}
-	memcpy(server->packet, peek_buf, 4);
-	result = smb_receive_raw(socket, server->packet + 4, len);
+	memcpy(packet, peek_buf, 4);
+	result = smb_receive_raw(socket, packet + 4, len);
 	if (result < 0)
 	{
 #ifdef SMBFS_DEBUG_VERBOSE
@@ -403,8 +420,8 @@
 #endif
 		goto out;
 	}
-	server->rcls = *(server->packet+9);
-	server->err = WVAL(server->packet, 11);
+	server->rcls = *(packet+9);
+	server->err = WVAL(packet, 11);
 
 #ifdef SMBFS_DEBUG_VERBOSE
 if (server->rcls != 0)
@@ -415,136 +432,194 @@
 }

 /*
- * This routine needs a lot of work.  We should check whether the packet
- * is all one part before allocating a new one, and should try first to
- * copy to a temp buffer before allocating.
- * The final server->packet should be the larger of the two.
+ * This routine checks first for "fast track" processing, as most
+ * packets won't need to be copied. Otherwise, it allocates a new
+ * packet to hold the incoming data.
+ *
+ * On success, *parm and *data point at the parameter and data areas
+ * (both inside server->packet) and *lparm and *ldata hold the number
+ * of bytes received for each. If the server reports an error class
+ * (server->rcls != 0), both pointers are set to server->packet and
+ * both lengths to 0. Returns a negative errno if a receive fails or
+ * the reply is malformed, otherwise a non-negative value.
+ *
+ * Note that the final server packet must be the larger of the two;
+ * server packets aren't allowed to shrink.
  */
 static int
 smb_receive_trans2(struct smb_sb_info *server,
 		   int *ldata, unsigned char **data,
-		   int *lparam, unsigned char **param)
+		   int *lparm, unsigned char **parm)
 {
-	int total_data = 0;
-	int total_param = 0;
+	unsigned char *inbuf, *base, *rcv_buf = NULL;
+	unsigned int parm_disp, parm_offset, parm_count, parm_tot, parm_len = 0;
+	unsigned int data_disp, data_offset, data_count, data_tot, data_len = 0;
+	unsigned int total_p = 0, total_d = 0, buf_len = 0;
 	int result;
-	unsigned char *rcv_buf;
-	int buf_len;
-	int data_len = 0;
-	int param_len = 0;
-
-	if ((result = smb_receive(server)) < 0)
-	{
-		return result;
-	}
-	if (server->rcls != 0)
-	{
-		*param = *data = server->packet;
-		*ldata = *lparam = 0;
-		return 0;
-	}
-	total_data = WVAL(server->packet, smb_tdrcnt);
-	total_param = WVAL(server->packet, smb_tprcnt);
-
-	pr_debug("smb_receive_trans2: td=%d,tp=%d\n", total_data, total_param);
-
-	if ((total_data > TRANS2_MAX_TRANSFER)
-	    || (total_param > TRANS2_MAX_TRANSFER))
-	{
-		pr_debug("smb_receive_trans2: data/param too long\n");
-		return -EIO;
-	}
-	buf_len = total_data + total_param;
-	if (server->packet_size > buf_len)
-	{
-		buf_len = server->packet_size;
-	}
-	if ((rcv_buf = smb_vmalloc(buf_len)) == NULL)
-	{
-		pr_debug("smb_receive_trans2: could not alloc data area\n");
-		return -ENOMEM;
-	}
-	*param = rcv_buf;
-	*data = rcv_buf + total_param;

 	while (1)
 	{
-		unsigned char *inbuf = server->packet;
-
-		if (WVAL(inbuf, smb_prdisp) + WVAL(inbuf, smb_prcnt)
-		    > total_param)
-		{
-			pr_debug("smb_receive_trans2: invalid parameters\n");
-			result = -EIO;
-			goto fail;
-		}
-		memcpy(*param + WVAL(inbuf, smb_prdisp),
-		       smb_base(inbuf) + WVAL(inbuf, smb_proff),
-		       WVAL(inbuf, smb_prcnt));
-		param_len += WVAL(inbuf, smb_prcnt);
-
-		if (WVAL(inbuf, smb_drdisp) + WVAL(inbuf, smb_drcnt)
-		    > total_data)
-		{
-			pr_debug("smb_receive_trans2: invalid data block\n");
-			result = -EIO;
-			goto fail;
-		}
-		pr_debug("disp: %d, off: %d, cnt: %d\n",
-			 WVAL(inbuf, smb_drdisp), WVAL(inbuf, smb_droff),
-			 WVAL(inbuf, smb_drcnt));
-
-		memcpy(*data + WVAL(inbuf, smb_drdisp),
-		       smb_base(inbuf) + WVAL(inbuf, smb_droff),
-		       WVAL(inbuf, smb_drcnt));
-		data_len += WVAL(inbuf, smb_drcnt);
-
-		if ((WVAL(inbuf, smb_tdrcnt) > total_data)
-		    || (WVAL(inbuf, smb_tprcnt) > total_param))
+		result = smb_receive(server);
+		if (result < 0)
+			goto out;
+		inbuf = server->packet;
+		if (server->rcls != 0)
 		{
-			pr_debug("smb_receive_trans2: data/params grew!\n");
-			result = -EIO;
-			goto fail;
+			*parm = *data = inbuf;
+			*ldata = *lparm = 0;
+			goto out;
 		}
-		/* the total lengths might shrink! */
-		total_data = WVAL(inbuf, smb_tdrcnt);
-		total_param = WVAL(inbuf, smb_tprcnt);
+		/*
+		 * Extract the control data from the packet.
+		 */
+		data_tot    = WVAL(inbuf, smb_tdrcnt);
+		parm_tot    = WVAL(inbuf, smb_tprcnt);
+		parm_disp   = WVAL(inbuf, smb_prdisp);
+		parm_offset = WVAL(inbuf, smb_proff);
+		parm_count  = WVAL(inbuf, smb_prcnt);
+		data_disp   = WVAL(inbuf, smb_drdisp);
+		data_offset = WVAL(inbuf, smb_droff);
+		data_count  = WVAL(inbuf, smb_drcnt);
+		base = smb_base(inbuf);
+
+		/*
+		 * The offsets and counts come straight off the wire: make
+		 * sure the areas they describe lie within the packet buffer
+		 * before anything is copied from (or pointed into) it.
+		 */
+		if ((unsigned int) (base - inbuf) + parm_offset + parm_count
+		    > server->packet_size ||
+		    (unsigned int) (base - inbuf) + data_offset + data_count
+		    > server->packet_size)
+			goto out_bad_offset;
+
+		/*
+		 * Assume success and increment lengths.
+		 */
+		parm_len += parm_count;
+		data_len += data_count;
+
+		if (!rcv_buf)
+		{
+			/*
+			 * Check for fast track processing ... just this packet.
+			 */
+			if (parm_count == parm_tot && data_count == data_tot)
+			{
+#ifdef SMBFS_DEBUG_VERBOSE
+printk("smb_receive_trans2: fast track, parm=%u %u %u, data=%u %u %u\n",
+parm_disp, parm_offset, parm_count, data_disp, data_offset, data_count);
+#endif
+				*parm  = base + parm_offset;
+				*data  = base + data_offset;
+				goto success;
+			}
+
+			if (parm_tot > TRANS2_MAX_TRANSFER ||
+			    data_tot > TRANS2_MAX_TRANSFER)
+				goto out_too_long;
+
+			/*
+			 * Save the total parameter and data length.
+			 */
+			total_d = data_tot;
+			total_p = parm_tot;
+
+			buf_len = total_d + total_p;
+			if (server->packet_size > buf_len)
+				buf_len = server->packet_size;
+			buf_len = smb_round_length(buf_len);
+
+			rcv_buf = smb_vmalloc(buf_len);
+			if (!rcv_buf)
+				goto out_no_mem;
+			*parm = rcv_buf;
+			*data = rcv_buf + total_p;
+		}
+		else if (data_tot > total_d || parm_tot > total_p)
+			goto out_data_grew;
+
+		if (parm_disp + parm_count > total_p)
+			goto out_bad_parm;
+		if (data_disp + data_count > total_d)
+			goto out_bad_data;
+		memcpy(*parm + parm_disp, base + parm_offset, parm_count);
+		memcpy(*data + data_disp, base + data_offset, data_count);

 #ifdef SMBFS_PARANOIA
-if ((data_len >= total_data || param_len >= total_param) &&
-   !(data_len >= total_data && param_len >= total_param))
-printk("smb_receive_trans2: dlen=%d, tdata=%d, plen=%d, tlen=%d\n",
-data_len, total_data, param_len, total_param);
+printk("smb_receive_trans2: copied, parm=%u of %u, data=%u of %u\n",
+parm_len, parm_tot, data_len, data_tot);
 #endif
-		/* shouldn't this be an OR test? don't want to overrun */
-		if ((data_len >= total_data) && (param_len >= total_param))
-		{
+		/*
+		 * Check whether we've received all of the data. Note that
+		 * we use the packet totals -- total lengths might shrink!
+		 */
+		if (data_len >= data_tot && parm_len >= parm_tot)
 			break;
-		}
-		if ((result = smb_receive(server)) < 0)
-		{
-			goto fail;
-		}
-		result = -EIO;
-		if (server->rcls != 0)
-			goto fail;
 	}
-	*ldata = data_len;
-	*lparam = param_len;

+	/*
+	 * Install the new packet.  Note that it's possible, though
+	 * unlikely, that the new packet could be smaller than the
+	 * old one, in which case we just copy the data.
+	 */
+	inbuf = server->packet;
+	if (buf_len >= server->packet_size)
+	{
+		server->packet_size = buf_len;
+		server->packet = rcv_buf;
+		rcv_buf = inbuf;
+	} else
+	{
 #ifdef SMBFS_PARANOIA
-if (buf_len < server->packet_size)
-printk("smb_receive_trans2: changing packet, old size=%d, new size=%d\n",
+printk("smb_receive_trans2: copying data, old size=%d, new size=%u\n",
 server->packet_size, buf_len);
 #endif
-	smb_vfree(server->packet);
-	server->packet = rcv_buf;
-	server->packet_size = buf_len;
-	return 0;
+		/*
+		 * Copy the whole parameter and data area (the data starts at
+		 * offset total_p), then repoint the caller at the copy:
+		 * rcv_buf itself is freed on the way out.
+		 */
+		memcpy(inbuf, rcv_buf, total_p + total_d);
+		*parm = inbuf;
+		*data = inbuf + total_p;
+	}

-      fail:
-	smb_vfree(rcv_buf);
+success:
+	*ldata = data_len;
+	*lparm = parm_len;
+out:
+	if (rcv_buf)
+		smb_vfree(rcv_buf);
 	return result;
+
+out_no_mem:
+#ifdef SMBFS_PARANOIA
+	printk("smb_receive_trans2: couldn't allocate data area\n");
+#endif
+	result = -ENOMEM;
+	goto out;
+out_too_long:
+	printk("smb_receive_trans2: data/param too long, data=%u, parm=%u\n",
+		data_tot, parm_tot);
+	goto out_error;
+out_data_grew:
+	printk("smb_receive_trans2: data/params grew!\n");
+	goto out_error;
+out_bad_offset:
+	printk("smb_receive_trans2: invalid offsets, parm=%u+%u, data=%u+%u, size=%d\n",
+		parm_offset, parm_count, data_offset, data_count, server->packet_size);
+	goto out_error;
+out_bad_parm:
+	printk("smb_receive_trans2: invalid parms, disp=%u, cnt=%u, tot=%u\n",
+		parm_disp, parm_count, parm_tot);
+	goto out_error;
+out_bad_data:
+	printk("smb_receive_trans2: invalid data, disp=%u, cnt=%u, tot=%u\n",
+		data_disp, data_count, data_tot);
+out_error:
+	result = -EIO;
+	goto out;
 }
 
 /*
@@ -759,14 +805,13 @@
 	}
 	if (result < 0)
 		goto bad_conn;
-	pr_debug("smb_trans2_request: result = %d\n", result);
 
 out:
 	return result;
 
 bad_conn:
 #ifdef SMBFS_PARANOIA
-printk("smb_trans2_request: connection bad, setting invalid\n");
+printk("smb_trans2_request: result=%d, setting invalid\n", result);
 #endif
 	server->state = CONN_INVALID;
 	smb_invalidate_inodes(server);

FUNET's LINUX-ADM group, linux-adm@nic.funet.fi
TCL-scripts by Sam Shen, slshen@lbl.gov