فهرست منبع

Implement ssl_recvall & ssl_sendall

Doug Nazar 1 سال پیش
والد
کامیت
846a01acd8
4فایلهای تغییر یافته به همراه156 افزوده شده و 68 حذف شده
  1. 3 0
      include/nrpe-ssl.h
  2. 11 37
      src/check_nrpe.c
  3. 137 0
      src/nrpe-ssl.c
  4. 5 31
      src/nrpe.c

+ 3 - 0
include/nrpe-ssl.h

@@ -44,4 +44,7 @@ void ssl_log_startup(int server);
 int ssl_load_certificates(void);
 int ssl_set_ciphers(void);
 int ssl_verify_callback_common(int preverify_ok, X509_STORE_CTX * ctx, int is_invalid);
+
+int ssl_recvall(SSL *ssl, char *buf, int *len, int timeout);
+int ssl_sendall(SSL *ssl, char *buf, int len);
 #endif

+ 11 - 37
src/check_nrpe.c

@@ -1059,16 +1059,11 @@ int send_request(void)
 	bytes_to_send = pkt_size;
 
 #ifdef HAVE_SSL
-	if (use_ssl == FALSE)
+	if (use_ssl)
+		rc = ssl_sendall(ssl, send_pkt, bytes_to_send);
+	else
 #endif
 		rc = sendall(sd, (char *)send_pkt, &bytes_to_send);
-#ifdef HAVE_SSL
-	else {
-		rc = SSL_write(ssl, send_pkt, bytes_to_send);
-		if (rc < 0)
-			rc = -1;
-	}
-#endif
 
 	if (v3_send_packet) {
 		free(v3_send_packet);
@@ -1233,9 +1228,6 @@ int read_response(void)
 
 int read_packet(int sock, void *ssl_ptr, v2_packet ** v2_pkt, v3_packet ** v3_pkt)
 {
-#ifdef HAVE_SSL
-	int32_t bytes_read = 0;
-#endif
 	v2_packet	packet;
 	int32_t pkt_size, common_size, tot_bytes, bytes_to_recv, buffer_size;
 	int rc;
@@ -1338,9 +1330,7 @@ int read_packet(int sock, void *ssl_ptr, v2_packet ** v2_pkt, v3_packet ** v3_pk
 	else {
 		SSL *ssl = (SSL *) ssl_ptr;
 
-		while (((rc = SSL_read(ssl, &packet, bytes_to_recv)) <= 0)
-			   && (SSL_get_error(ssl, rc) == SSL_ERROR_WANT_READ)) {
-		}
+		rc = ssl_recvall(ssl, (char *)&packet, &tot_bytes, socket_timeout);
 
 		if (rc <= 0 || rc != bytes_to_recv) {
 			if (rc > 0 && rc < bytes_to_recv) {
@@ -1380,20 +1370,14 @@ int read_packet(int sock, void *ssl_ptr, v2_packet ** v2_pkt, v3_packet ** v3_pk
 
 			/* Read the alignment filler */
 			bytes_to_recv = sizeof(int16_t);
-			while (((rc = SSL_read(ssl, &buffer_size, bytes_to_recv)) <= 0)
-				   && (SSL_get_error(ssl, rc) == SSL_ERROR_WANT_READ)) {
-			}
-
+			rc = ssl_recvall(ssl, (char *)&buffer_size, &bytes_to_recv, socket_timeout);
 			if (rc <= 0 || bytes_to_recv != sizeof(int16_t))
 				return -1;
 			tot_bytes += rc;
 
 			/* Read the buffer size */
 			bytes_to_recv = sizeof(buffer_size);
-			while (((rc = SSL_read(ssl, &buffer_size, bytes_to_recv)) <= 0)
-				   && (SSL_get_error(ssl, rc) == SSL_ERROR_WANT_READ)) {
-			}
-
+			rc = ssl_recvall(ssl, (char *)&buffer_size, &bytes_to_recv, socket_timeout);
 			if (rc <= 0 || bytes_to_recv != sizeof(buffer_size))
 				return -1;
 			tot_bytes += rc;
@@ -1415,19 +1399,9 @@ int read_packet(int sock, void *ssl_ptr, v2_packet ** v2_pkt, v3_packet ** v3_pk
 		}
 
 		bytes_to_recv = buffer_size;
-		for (;;) {
-			while (((rc = SSL_read(ssl, &buff_ptr[bytes_read], bytes_to_recv)) <= 0)
-				   && (SSL_get_error(ssl, rc) == SSL_ERROR_WANT_READ)) {
-			}
+		rc = ssl_recvall(ssl, buff_ptr, &bytes_to_recv, socket_timeout);
 
-			if (rc <= 0)
-				break;
-			bytes_read += rc;
-			bytes_to_recv -= rc;
-			tot_bytes += rc;
-		}
-
-		if (rc < 0 || bytes_read != buffer_size) {
+		if (rc <= 0 || rc != buffer_size) {
 			if (packet_ver >= NRPE_PACKET_VERSION_3) {
 				free(*v3_pkt);
 				*v3_pkt = NULL;
@@ -1435,11 +1409,11 @@ int read_packet(int sock, void *ssl_ptr, v2_packet ** v2_pkt, v3_packet ** v3_pk
 				free(*v2_pkt);
 				*v2_pkt = NULL;
 			}
-			if (bytes_read != buffer_size) {
+			if (rc > 0 && rc < buffer_size) {
 				if (packet_ver >= NRPE_PACKET_VERSION_3) {
-					printf("CHECK_NRPE: Receive buffer size - %ld bytes received (%ld expected).\n", (long)bytes_read, (long)buffer_size);
+					printf("CHECK_NRPE: Receive buffer size - %ld bytes received (%ld expected).\n", (long)rc, (long)buffer_size);
 				} else {
-					printf("CHECK_NRPE: Receive underflow - only %ld bytes received (%ld expected).\n", (long)bytes_read, (long)buffer_size);
+					printf("CHECK_NRPE: Receive underflow - only %ld bytes received (%ld expected).\n", (long)rc, (long)buffer_size);
 				}
 			}
 			return -1;

+ 137 - 0
src/nrpe-ssl.c

@@ -285,3 +285,140 @@ int ssl_verify_callback_common(int preverify_ok, X509_STORE_CTX * ctx, int is_in
 
 	return preverify_ok;
 }
+
+
+int ssl_recvall(SSL *ssl, char *buf, int *len, int timeout)
+{
+	time_t start_time;
+	time_t current_time;
+	int total = 0;
+	int bytesleft = *len;
+	int n = 0;
+	int fd = SSL_get_fd(ssl);
+
+	time(&start_time);
+
+	while (total < *len) {
+		int ern;
+		int ssl_err;
+		unsigned long x;
+
+		n = SSL_read(ssl, buf + total, bytesleft);
+		if (n > 0) {
+			// Success. Adjust and keep going
+			total += n;
+			bytesleft -= n;
+			continue;
+		}
+
+		ern = errno;
+		ssl_err = SSL_get_error(ssl, n);
+
+		if (ssl_err == SSL_ERROR_WANT_READ) {
+			int rc;
+			fd_set rfds;
+			struct timeval tv;
+
+			FD_ZERO(&rfds);
+			FD_SET(fd, &rfds);
+
+			tv.tv_sec = 2;
+			tv.tv_usec = 0;
+
+			rc = select(fd + 1, &rfds, NULL, NULL, &tv);
+			if (rc == -1) {
+				logit(LOG_ERR, "ERROR: ssl_recvall() select failed (errno=%i)", errno);
+				break;
+			}
+
+			if (rc == 0) {
+				// Timed out
+				time(&current_time);
+				if (current_time - start_time > timeout) {
+					logit(LOG_ERR, "ERROR: ssl_recvall() timed out");
+					break;
+				}
+			}
+
+			// we've either timed out or our fd should be ready for more data
+			continue;
+		}
+
+		logit(LOG_ERR, "ERROR: Error reading data! (rc=%i, errno=%i, ssl=%i)", n, ern, ssl_err);
+		while ((x = ERR_get_error()) != 0) {
+			logit(LOG_ERR, "     : %s", ERR_reason_error_string(x));
+		}
+
+		return -1;
+	}
+
+	/* return number of bytes actually received here */
+	*len = total;
+
+	/* return <=0 on failure, bytes received on success */
+	return (n <= 0) ? n : total;
+}
+
+int ssl_sendall(SSL *ssl, char *buf, int len)
+{
+	int total = 0;
+	int bytesleft = len;
+	int n = 0;
+	int timeouts = 0;
+	int fd = SSL_get_fd(ssl);
+
+	while (total < len) {
+		int ern;
+		int ssl_err;
+		unsigned long x;
+
+		n = SSL_write(ssl, buf + total, bytesleft);
+		if (n > 0) {
+			// Success. Adjust and keep going
+			total += n;
+			bytesleft -= n;
+			continue;
+		}
+
+		ern = errno;
+		ssl_err = SSL_get_error(ssl, n);
+
+		if (ssl_err == SSL_ERROR_WANT_WRITE) {
+			int rc;
+			fd_set wfds;
+			struct timeval tv;
+
+			FD_ZERO(&wfds);
+			FD_SET(fd, &wfds);
+
+			tv.tv_sec = 2;
+			tv.tv_usec = 0;
+
+			rc = select(fd + 1, NULL, &wfds, NULL, &tv);
+			if (rc == -1) {
+				logit(LOG_ERR, "ERROR: ssl_sendall() select failed (errno=%i)", errno);
+				break;
+			}
+
+			if (rc == 0) {
+				// Timed out
+				if (++timeouts > 5) {
+					logit(LOG_ERR, "ERROR: ssl_sendall() timed out");
+					break;
+				}
+			}
+
+			// we've either timed out or our fd should be ready for more data
+			continue;
+		}
+
+		logit(LOG_ERR, "ERROR: Error sending data! (rc=%i, errno=%i, ssl=%i)", n, ern, ssl_err);
+		while ((x = ERR_get_error()) != 0) {
+			logit(LOG_ERR, "     : %s", ERR_reason_error_string(x));
+		}
+
+		return -1;
+	}
+
+	return (total == len) ? 0 : -1;	/* return -1 on failure, 0 on success */
+}

+ 5 - 31
src/nrpe.c

@@ -1899,7 +1899,7 @@ void handle_connection(int sock)
 	bytes_to_send = pkt_size;
 #ifdef HAVE_SSL
 	if (use_ssl)
-		SSL_write(ssl, send_pkt, bytes_to_send);
+		ssl_sendall(ssl, send_pkt, bytes_to_send);
 	else
 #endif
 		sendall(sock, send_pkt, &bytes_to_send);
@@ -2109,28 +2109,8 @@ int read_packet(int sock, void *ssl_ptr, v2_packet * v2_pkt, v3_packet ** v3_pkt
 	}
 	else {
 		SSL      *ssl = (SSL *) ssl_ptr;
-		int       sockfd, retval;
-		fd_set    rfds;
-		struct timeval timeout;
 
-		sockfd = SSL_get_fd(ssl);
-
-		FD_ZERO(&rfds);
-		FD_SET(sockfd, &rfds);
-
-		timeout.tv_sec = connection_timeout;
-		timeout.tv_usec = 0;
-
-		do {
-			retval = select(sockfd + 1, &rfds, NULL, NULL, &timeout);
-
-			if (retval > 0) {
-				rc = SSL_read(ssl, v2_pkt, bytes_to_recv);
-			} else {
-				logit(LOG_ERR, "Error (!log_opts): Could not complete SSL_read with %s: timeout %d seconds", remote_host, connection_timeout);
-				return -1;
-			}
-		} while (SSL_get_error(ssl, rc) == SSL_ERROR_WANT_READ);
+		rc = ssl_recvall(ssl, (char *)v2_pkt, &tot_bytes, socket_timeout);
 
 		if (rc <= 0 || rc != bytes_to_recv)
 			return -1;
@@ -2155,9 +2135,7 @@ int read_packet(int sock, void *ssl_ptr, v2_packet * v2_pkt, v3_packet ** v3_pkt
 
 			/* Read the alignment filler */
 			bytes_to_recv = sizeof(int16_t);
-			while (((rc = SSL_read(ssl, &buffer_size, bytes_to_recv)) <= 0)
-				   && (SSL_get_error(ssl, rc) == SSL_ERROR_WANT_READ)) {
-			}
+			rc = ssl_recvall(ssl, (char *)&buffer_size, &bytes_to_recv, socket_timeout);
 
 			if (rc <= 0 || bytes_to_recv != sizeof(int16_t))
 				return -1;
@@ -2165,9 +2143,7 @@ int read_packet(int sock, void *ssl_ptr, v2_packet * v2_pkt, v3_packet ** v3_pkt
 
 			/* Read the buffer size */
 			bytes_to_recv = sizeof(buffer_size);
-			while (((rc = SSL_read(ssl, &buffer_size, bytes_to_recv)) <= 0)
-				   && (SSL_get_error(ssl, rc) == SSL_ERROR_WANT_READ)) {
-			}
+			rc = ssl_recvall(ssl, (char *)&buffer_size, &bytes_to_recv, socket_timeout);
 
 			if (rc <= 0 || bytes_to_recv != sizeof(buffer_size))
 				return -1;
@@ -2190,9 +2166,7 @@ int read_packet(int sock, void *ssl_ptr, v2_packet * v2_pkt, v3_packet ** v3_pkt
 		}
 
 		bytes_to_recv = buffer_size;
-		while (((rc = SSL_read(ssl, buff_ptr, bytes_to_recv)) <= 0)
-			   && (SSL_get_error(ssl, rc) == SSL_ERROR_WANT_READ)) {
-		}
+		rc = ssl_recvall(ssl, buff_ptr, &bytes_to_recv, socket_timeout);
 
 		if (rc <= 0 || rc != buffer_size) {
 			if (packet_ver == NRPE_PACKET_VERSION_3) {