Winsock 套接字状态通知
简介
下表中的套接字状态通知 API 提供了一种可缩放且高效(在 CPU 和内存方面)的方法来获取有关套接字状态更改的通知。 这包括有关非阻塞读取、非阻塞写入、错误条件和其他信息的通知。
API | 说明 |
---|---|
ProcessSocketNotifications 函数 | 将一组套接字与完成端口相关联,并检索该端口上已挂起的任何通知。 关联后,完成端口将接收指定的套接字状态通知。 |
SOCK_NOTIFY_REGISTRATION 结构 | 表示提供给 ProcessSocketNotifications 函数的信息。 |
SocketNotificationRetrieveEvents 函数 | 提供此内联帮助程序函数以方便从 OVERLAPPED_ENTRY检索事件掩码。 |
工作流首先将套接字与 I/O 完成端口关联 (ProcessSocketNotifications 和 SOCK_NOTIFY_REGISTRATION) 。 之后,端口使用常用的 I/O 完成端口查询方法提供有关套接字状态更改的信息。
这些 API 允许轻松构造与平台无关的抽象。 因此,支持持久性和一次性标志,以及级别和边缘触发的标志。 例如,建议多线程服务器采用单次级别触发注册模式。
建议
这些 API 提供了 WSAPoll 和 select API 的可缩放替代项。
它们是与 I/O 完成端口一起使用的重叠套接字 I/O 的替代方法,无需为每个套接字保留永久的 I/O 缓冲区。 但是,如果每个套接字的 I/O 缓冲区不是重要考虑因素(套接字数相对较少,或者它们经常被使用),那么重叠的套接字 I/O 可能会因为内核转换次数较少以及模型更简单而开销较少。
套接字只能与单个 I/O 完成端口相关联。 套接字只能向 I/O 完成端口注册一次。 若要更改完成密钥,请取消注册通知,等待 SOCK_NOTIFY_EVENT_REMOVE 消息 (查看 ProcessSocketNotifications 和 SocketNotificationRetrieveEvents 主题) ,然后重新注册套接字。
为了避免释放仍在使用的内存,应仅在收到该注册的 SOCK_NOTIFY_EVENT_REMOVE 通知后,才释放与该注册关联的数据结构。 使用 closesocket 函数关闭用于注册通知的套接字描述符时,会自动取消注册其通知。 但是,可能仍会传递已排队的通知。 通过 closesocket 自动注销不会生成 SOCK_NOTIFY_EVENT_REMOVE 通知。
如果需要多线程处理,则应使用单个 I/O 完成端口和多个线程处理通知。 这允许 I/O 完成端口根据需要跨多个线程横向扩展工作。 避免使用多个 I/O 完成端口(例如,每个线程一个),因为该设计容易在单个线程上形成瓶颈,而其他线程则处于空闲状态。
如果多个线程使用级别触发的通知对通知数据包取消排队,则应提供 SOCK_NOTIFY_TRIGGER_ONESHOT 以避免多个线程接收状态更改通知。 处理套接字通知后,应重新注册通知。
如果多个线程在面向流的连接上取消排队通知数据包,其中需要在单个线程上处理单个消息,请考虑使用级别触发的一次性通知。 这降低了多个线程接收需要跨线程重新组合的消息片段的可能性。
如果使用边缘触发的通知,则不建议使用一次性通知,因为启用注册后需要清空套接字。 这是一种更复杂的实现模式,并且成本更高,因为它始终需要返回 WSAEWOULDBLOCK 的调用。
如果要在单个侦听套接字上横向扩展连接接受,则服务器应使用 AcceptEx 函数,而不是订阅连接请求的通知。 接受连接以响应通知会隐式限制与处理现有连接请求相关的连接接受率。
下面是演示某些套接字状态通知方案的代码示例。 某些代码包含 TODO 项,需要你在自己的应用程序中完成。
通用代码
首先,下面是一个代码列表,其中包含以下方案使用的一些常见定义和函数。
#include "pch.h"
#include <winsock2.h>
#pragma comment(lib, "Ws2_32")
// Address assigned as-is (no byte-order conversion) to sin_addr.s_addr by the server and client
// helpers.
#define SERVER_ADDRESS 0x0100007f // localhost
// Port assigned as-is (no byte-order conversion) to sin_port by the server and client helpers.
#define SERVER_PORT 0xffff // TODO: select an actual valid port
// Timeout value passed to ProcessSocketNotifications and GetQueuedCompletionStatusEx
// (presumably milliseconds, per those APIs -- confirm against their documentation).
#define MAX_TIMEOUT 1000
// Number of payload copies each client thread is asked to send.
#define CLIENT_LOOP_COUNT 10
// State shared by the server samples. Created by CreateServerContext and released by
// DestroyServerContext.
typedef struct SERVER_CONTEXT {
// I/O completion port on which socket notifications are dequeued.
HANDLE ioCompletionPort;
// Listening TCP socket; INVALID_SOCKET until it has been created.
SOCKET listenerSocket;
} SERVER_CONTEXT;
// Per-client-thread parameters. Allocated by CreateClientThread and freed by
// ClientThreadRoutine once the thread has been started.
typedef struct CLIENT_CONTEXT {
// Number of payload copies the client thread sends before disconnecting.
UINT32 transmitCount;
} CLIENT_CONTEXT;
// Held exclusively by PrintError and the server threads around their wprintf_s calls so that
// their console output is not interleaved.
SRWLOCK g_printLock = SRWLOCK_INIT;
// Tears down a server context: closes the listening socket and the I/O completion port if they
// were created, then frees the context itself. The pointer must not be used after this call.
VOID DestroyServerContext(_Inout_ _Post_invalid_ SERVER_CONTEXT* serverContext) {
    const SOCKET listener = serverContext->listenerSocket;
    const HANDLE completionPort = serverContext->ioCompletionPort;

    if (INVALID_SOCKET != listener) {
        closesocket(listener);
    }

    if (NULL != completionPort) {
        CloseHandle(completionPort);
    }

    free(serverContext);
}
// Builds a ready-to-accept server context: an I/O completion port plus a TCP socket bound to
// SERVER_ADDRESS:SERVER_PORT and placed in the listening state.
//
// On success, *serverContext receives the new context (release it with DestroyServerContext)
// and ERROR_SUCCESS is returned. On failure, everything created so far is released,
// *serverContext is left untouched, and the failure code is returned.
DWORD CreateServerContext(_Outptr_ SERVER_CONTEXT** serverContext) {
    DWORD status;
    sockaddr_in bindAddress = { };

    // calloc hands back zeroed memory, so no separate clearing step is needed.
    SERVER_CONTEXT* context = (SERVER_CONTEXT*)calloc(1, sizeof(*context));
    if (context == NULL) {
        return ERROR_NOT_ENOUGH_MEMORY;
    }

    // Mark the socket as "not created yet" so that cleanup knows whether to close it.
    context->listenerSocket = INVALID_SOCKET;

    context->ioCompletionPort = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0);
    if (context->ioCompletionPort == NULL) {
        status = GetLastError();
        goto Cleanup;
    }

    context->listenerSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    if (context->listenerSocket == INVALID_SOCKET) {
        status = GetLastError();
        goto Cleanup;
    }

    bindAddress.sin_family = AF_INET;
    bindAddress.sin_addr.s_addr = SERVER_ADDRESS;
    bindAddress.sin_port = SERVER_PORT;

    if (bind(context->listenerSocket, (sockaddr*)&bindAddress, sizeof(bindAddress)) != 0) {
        status = GetLastError();
        goto Cleanup;
    }

    if (listen(context->listenerSocket, 0) != 0) {
        status = GetLastError();
        goto Cleanup;
    }

    // Ownership passes to the caller.
    *serverContext = context;
    return ERROR_SUCCESS;

Cleanup:
    DestroyServerContext(context);
    return status;
}
// Client thread entry point: create a socket, connect to the server, send transmitCount copies
// of the payload, then disconnect.
//
// clientContextPointer is a heap-allocated CLIENT_CONTEXT; this routine frees it before
// returning. Failures are not reported: the thread stops at the first failing call, cleans up,
// and always returns 0.
DWORD
WINAPI
ClientThreadRoutine(_In_ PVOID clientContextPointer) {
    const UINT32 payload = 0xdeadbeef;
    CLIENT_CONTEXT* clientContext = (CLIENT_CONTEXT*)clientContextPointer;
    sockaddr_in serverAddress = {};
    SOCKET clientSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);

    if (clientSocket == INVALID_SOCKET) {
        goto Exit;
    }

    serverAddress.sin_family = AF_INET;
    serverAddress.sin_addr.s_addr = SERVER_ADDRESS;
    serverAddress.sin_port = SERVER_PORT;

    if (connect(clientSocket, (sockaddr*)&serverAddress, sizeof(serverAddress)) != 0) {
        goto Exit;
    }

    for (UINT32 Index = 0; Index < clientContext->transmitCount; Index += 1) {
        if (send(clientSocket, (const char*)&payload, sizeof(payload), 0) < 0) {
            goto Exit;
        }
    }

    if (shutdown(clientSocket, SD_BOTH) != 0) {
        goto Exit;
    }

Exit:
    // Close the socket that was actually created. Passing INVALID_SOCKET here would leave the
    // client socket open.
    if (clientSocket != INVALID_SOCKET) {
        closesocket(clientSocket);
    }

    free(clientContext);
    return 0;
}
// Starts one client thread (ClientThreadRoutine) that will send transmitCount payload copies.
//
// Returns ERROR_SUCCESS when the thread was started, ERROR_NOT_ENOUGH_MEMORY when the context
// could not be allocated, or the CreateThread failure code. The thread handle is closed
// immediately; the thread is not waited on.
DWORD CreateClientThread(_In_ UINT32 transmitCount) {
    // calloc hands back zeroed memory, so no separate clearing step is needed.
    CLIENT_CONTEXT* threadArguments = (CLIENT_CONTEXT*)calloc(1, sizeof(*threadArguments));
    if (threadArguments == NULL) {
        return ERROR_NOT_ENOUGH_MEMORY;
    }

    threadArguments->transmitCount = transmitCount;

    HANDLE threadHandle = CreateThread(NULL, 0, ClientThreadRoutine, threadArguments, 0, NULL);
    if (threadHandle == NULL) {
        // The thread never started, so the arguments are still ours to free.
        const DWORD failure = GetLastError();
        free(threadArguments);
        return failure;
    }

    // From here on the running thread owns (and frees) its arguments.
    CloseHandle(threadHandle);
    return ERROR_SUCCESS;
}
// Prints the calling thread's ID and errorCode, followed by the system message text for that
// code when FormatMessageW can supply one. g_printLock is held for the whole report.
VOID PrintError(DWORD errorCode) {
    WCHAR systemMessage[512];
    DWORD messageLength;

    AcquireSRWLockExclusive(&g_printLock);

    wprintf_s(L"Server thread %d encountered an error %d.", GetCurrentThreadId(), errorCode);

    messageLength = FormatMessageW(FORMAT_MESSAGE_FROM_SYSTEM,
                                   NULL,
                                   errorCode,
                                   0,
                                   systemMessage,
                                   RTL_NUMBER_OF(systemMessage),
                                   NULL);

    // FormatMessageW returns 0 on failure; in that case only the numeric code is shown.
    if (messageLength > 0) {
        wprintf_s(L"%s", systemMessage);
    }

    ReleaseSRWLockExclusive(&g_printLock);
}
// Removes the notification registration for `socket` from `ioCompletionPort` and waits until
// the matching SOCK_NOTIFY_EVENT_REMOVE notification has been dequeued.
//
// This routine must be used only if a single socket is registered: every notification that is
// dequeued while waiting and is not a removal notification is dropped.
//
// Returns ERROR_SUCCESS once the removal notification has been observed. Otherwise returns the
// failure from ProcessSocketNotifications (for example, because MAX_TIMEOUT elapsed) or the
// per-registration result.
DWORD DeregisterAndWait(_In_ HANDLE ioCompletionPort, _In_ SOCKET socket) {
    DWORD errorCode;
    SOCK_NOTIFY_REGISTRATION registration = {};
    OVERLAPPED_ENTRY notification;
    UINT32 notificationCount;

    // Keep looping until the registration is removed, or a timeout is hit.
    while (TRUE) {
        registration.operation = SOCK_NOTIFY_OP_REMOVE;
        registration.socket = socket;

        // Submit the removal and dequeue at most one notification in the same call.
        errorCode = ProcessSocketNotifications(ioCompletionPort,
                                               1,
                                               &registration,
                                               MAX_TIMEOUT,
                                               1,
                                               &notification,
                                               &notificationCount);
        if (errorCode != ERROR_SUCCESS) {
            goto Exit;
        }

        if (registration.registrationResult != ERROR_SUCCESS) {
            errorCode = registration.registrationResult;
            goto Exit;
        }

        // Drops all non-removal notifications. Must be used only
        // if a single socket is registered.
        if (SocketNotificationRetrieveEvents(&notification) & SOCK_NOTIFY_EVENT_REMOVE) {
            break;
        }
    }

Exit:
    return errorCode;
}
轮询的简单替换
此方案演示了使用轮询 (WSAPoll) 或类似 API 的应用程序的直接替换。 它是单线程的,使用持久性 (而不是一次性) 注册。 由于注册不需要重新注册,因此它使用 GetQueuedCompletionStatusEx 取消排队通知。
// Single-threaded sample that replaces a poll-style loop with socket state notifications.
//
// It creates the server context and one client thread, accepts the client's connection,
// registers the accepted socket for level-triggered IN and HANGUP notifications, and then
// dequeues notifications with GetQueuedCompletionStatusEx until the client hangs up, an error
// is reported, or a wait fails. Failures are reported through PrintError; nothing is returned.
VOID SimplePollReplacement() {
    DWORD errorCode;
    WSADATA wsaData;
    SERVER_CONTEXT* serverContext = NULL;
    SOCKET tcpAcceptSocket = INVALID_SOCKET;
    u_long nonBlocking = 1;
    SOCKET currentSocket;
    SOCK_NOTIFY_REGISTRATION registration = {};
    OVERLAPPED_ENTRY notification;
    ULONG notificationCount;
    UINT32 events;
    CHAR dataBuffer[512];

    if (WSAStartup(WINSOCK_VERSION, &wsaData) != 0) {
        errorCode = GetLastError();
        PrintError(errorCode);
        return;
    }

    errorCode = CreateServerContext(&serverContext);
    if (errorCode != ERROR_SUCCESS) {
        goto Exit;
    }

    errorCode = CreateClientThread(CLIENT_LOOP_COUNT);
    if (errorCode != ERROR_SUCCESS) {
        goto Exit;
    }

    tcpAcceptSocket = accept(serverContext->listenerSocket, NULL, NULL);
    if (tcpAcceptSocket == INVALID_SOCKET) {
        errorCode = GetLastError();
        goto Exit;
    }

    if (ioctlsocket(tcpAcceptSocket, FIONBIO, &nonBlocking) != 0) {
        errorCode = GetLastError();
        goto Exit;
    }

    // Register the accepted connection.
    registration.completionKey = (PVOID)tcpAcceptSocket;
    registration.eventFilter = SOCK_NOTIFY_REGISTER_EVENT_IN | SOCK_NOTIFY_REGISTER_EVENT_HANGUP;
    registration.operation = SOCK_NOTIFY_OP_ENABLE;
    registration.triggerFlags = SOCK_NOTIFY_TRIGGER_LEVEL;
    registration.socket = tcpAcceptSocket;

    // Registration only: no notifications are dequeued by this call.
    errorCode = ProcessSocketNotifications(serverContext->ioCompletionPort,
                                           1,
                                           &registration,
                                           0,
                                           0,
                                           NULL,
                                           NULL);

    // Make sure all registrations were processed.
    if (errorCode != ERROR_SUCCESS) {
        goto Exit;
    }

    // Make sure each registration was successful.
    if (registration.registrationResult != ERROR_SUCCESS) {
        errorCode = registration.registrationResult;
        goto Exit;
    }

    // Keep receiving data until the client disconnects.
    while (TRUE) {
        wprintf_s(L"Waiting for client action...\r\n");

        if (!GetQueuedCompletionStatusEx(serverContext->ioCompletionPort,
                                         &notification,
                                         1,
                                         &notificationCount,
                                         MAX_TIMEOUT,
                                         FALSE))
        {
            errorCode = GetLastError();
            goto Exit;
        }

        // The completion key is the socket we supplied above.
        //
        // This is true only because the registration supplied the socket as the completion
        // key. A more typical pattern is to supply a context pointer. This example supplies
        // the socket directly, for simplicity.
        //
        // The events are stored in the number-of-bytes-received field.
        events = SocketNotificationRetrieveEvents(&notification);
        currentSocket = (SOCKET)notification.lpCompletionKey;

        if (events & SOCK_NOTIFY_EVENT_IN) {
            // We don't check for a 0-size receive because we subscribed to hang-up notifications.
            if (recv(currentSocket, dataBuffer, sizeof(dataBuffer), 0) < 0) {
                errorCode = GetLastError();
                goto Exit;
            }

            wprintf_s(L"Received client data.\r\n");
        }

        if (events & SOCK_NOTIFY_EVENT_HANGUP) {
            wprintf_s(L"Client hung up. Exiting. \r\n");
            break;
        }

        if (events & SOCK_NOTIFY_EVENT_ERR) {
            wprintf_s(L"The socket was ungracefully reset or another error occurred. Exiting.\r\n");

            // Obtain a more detailed error code by issuing a non-blocking receive.
            recv(currentSocket, dataBuffer, sizeof(dataBuffer), 0);
            errorCode = GetLastError();
            goto Exit;
        }
    }

    errorCode = ERROR_SUCCESS;

Exit:
    if (errorCode != ERROR_SUCCESS) {
        PrintError(errorCode);
    }

    if (serverContext != NULL) {
        // Remove the registration (best effort) before the completion port is closed.
        if (tcpAcceptSocket != INVALID_SOCKET) {
            DeregisterAndWait(serverContext->ioCompletionPort, tcpAcceptSocket);
        }

        DestroyServerContext(serverContext);
    }

    if (tcpAcceptSocket != INVALID_SOCKET) {
        closesocket(tcpAcceptSocket);
    }

    WSACleanup();
}
边缘触发的 UDP 服务器
这是有关如何将 API 与边缘触发配合使用的简单说明。
重要
服务器必须一直接收,直到收到 WSAEWOULDBLOCK。 否则,它无法确定将观察到上升的边缘。 因此,服务器的套接字也必须是非阻塞的。
此示例使用 UDP 来演示缺少 HANGUP 通知的情况。 为简洁起见,此示例假设常见的帮助程序会根据需要创建 UDP 套接字。
// Edge-triggered UDP server sample.
//
// This example assumes that substantially similar helpers are available for UDP sockets.
//
// It registers a UDP socket for edge-triggered IN notifications, switches it to non-blocking
// mode, binds it, starts a client, and then receives until CLIENT_LOOP_COUNT datagrams have
// arrived. After every notification the socket is drained until recv reports WSAEWOULDBLOCK.
// Failures are reported through PrintError; nothing is returned.
VOID SimpleEdgeTriggeredSample() {
    DWORD errorCode;
    WSADATA wsaData;
    SOCKET serverSocket = INVALID_SOCKET;
    SOCKET currentSocket;
    HANDLE ioCompletionPort = NULL;
    sockaddr_in serverAddress = { };
    u_long nonBlocking = 1;
    SOCK_NOTIFY_REGISTRATION registration = {};
    OVERLAPPED_ENTRY notification;
    ULONG notificationCount;
    UINT32 events;
    CHAR dataBuffer[512];
    UINT32 datagramCount;
    int receiveResult;

    if (WSAStartup(WINSOCK_VERSION, &wsaData) != 0) {
        errorCode = GetLastError();
        PrintError(errorCode);
        return;
    }

    ioCompletionPort = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0);
    if (ioCompletionPort == NULL) {
        errorCode = GetLastError();
        goto Exit;
    }

    serverSocket = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
    if (serverSocket == INVALID_SOCKET) {
        errorCode = GetLastError();
        goto Exit;
    }

    // Register the server UDP socket before binding to a port to ensure data doesn't become
    // present before the registration. Otherwise, the server could miss the notification and
    // hang.
    //
    // Edge-triggered is not recommended with one-shot due to the difficulty in re-registering.
    registration.completionKey = (PVOID)serverSocket;
    registration.eventFilter = SOCK_NOTIFY_EVENT_IN;
    registration.operation = SOCK_NOTIFY_OP_ENABLE;
    registration.triggerFlags = SOCK_NOTIFY_TRIGGER_EDGE;
    registration.socket = serverSocket;

    errorCode = ProcessSocketNotifications(ioCompletionPort, 1, &registration, 0, 0, NULL, NULL);
    if (errorCode != ERROR_SUCCESS) {
        goto Exit;
    }

    if (registration.registrationResult != ERROR_SUCCESS) {
        errorCode = registration.registrationResult;
        goto Exit;
    }

    // Use non-blocking sockets with edge-triggered notifications, since the data must be
    // drained before a rising edge can be observed again.
    //
    // ioctlsocket only signals failure through its return value; the actual error code has to
    // be fetched separately.
    if (ioctlsocket(serverSocket, FIONBIO, &nonBlocking) != 0) {
        errorCode = GetLastError();
        goto Exit;
    }

    serverAddress.sin_family = AF_INET;
    serverAddress.sin_addr.s_addr = SERVER_ADDRESS;
    serverAddress.sin_port = SERVER_PORT;

    if (bind(serverSocket, (sockaddr*)&serverAddress, sizeof(serverAddress)) != 0) {
        errorCode = GetLastError();
        goto Exit;
    }

    // Create the client.
    // While CreateClientThread connects to a TCP socket and sends data over it, for this example
    // assume that CreateClientThread creates a UDP socket instead, and sends data over it.
    errorCode = CreateClientThread(CLIENT_LOOP_COUNT);
    if (errorCode != ERROR_SUCCESS) {
        goto Exit;
    }

    // Receive the packets.
    datagramCount = 0;
    while (datagramCount < CLIENT_LOOP_COUNT) {
        wprintf_s(L"Waiting for client action...\r\n");

        if (!GetQueuedCompletionStatusEx(ioCompletionPort,
                                         &notification,
                                         1,
                                         &notificationCount,
                                         MAX_TIMEOUT,
                                         FALSE))
        {
            errorCode = GetLastError();
            goto Exit;
        }

        // The completion key is the socket we supplied above.
        //
        // This is true only because the registration supplied the socket as the completion
        // key. A more typical pattern is to supply a context pointer. This example supplies
        // the socket directly, for simplicity.
        //
        // SocketNotificationRetrieveEvents extracts the event mask from the dequeued entry.
        events = SocketNotificationRetrieveEvents(&notification);
        currentSocket = (SOCKET)notification.lpCompletionKey;

        if (events & SOCK_NOTIFY_EVENT_ERR) {
            // Obtain a more detailed error code by issuing a non-blocking receive.
            recv(currentSocket, dataBuffer, sizeof(dataBuffer), 0);
            errorCode = GetLastError();
            goto Exit;
        }

        if ((events & SOCK_NOTIFY_EVENT_IN) == 0) {
            continue;
        }

        // Keep looping receiving data until the read would block, otherwise the edge may not
        // have been reset.
        while (TRUE) {
            receiveResult = recv(currentSocket, dataBuffer, sizeof(dataBuffer), 0);
            if (receiveResult < 0) {
                errorCode = GetLastError();
                if (errorCode != WSAEWOULDBLOCK) {
                    goto Exit;
                }

                // The socket is drained; go back to waiting for the next edge.
                break;
            }

            datagramCount += 1;
            wprintf_s(L"Received client data.\r\n");
        }
    }

    wprintf_s(L"Received all data. Exiting... \r\n");
    errorCode = ERROR_SUCCESS;

Exit:
    if (errorCode != ERROR_SUCCESS) {
        PrintError(errorCode);
    }

    if (serverSocket != INVALID_SOCKET) {
        // Remove the registration (best effort) before closing the socket and the port.
        if (ioCompletionPort != NULL) {
            DeregisterAndWait(ioCompletionPort, serverSocket);
        }

        closesocket(serverSocket);
    }

    if (ioCompletionPort != NULL) {
        CloseHandle(ioCompletionPort);
    }

    WSACleanup();
}
多线程服务器
此示例演示了一种更真实的多线程使用模式,该模式使用 I/O 完成端口的横向扩展功能跨多个服务器线程分配工作。 服务器使用一次性级别触发来避免多个线程为同一套接字拾取通知,并允许每个线程一次一个区块排出收到的数据。
它还演示了与完成端口一起使用的一些常见模式。 完成键用于提供每个套接字的上下文指针。 上下文指针具有一个标头,用于描述所使用的套接字类型,以便可以在单个完成端口上使用多个套接字类型。 示例中的注释强调,可以取消排队任意完成(与使用 GetQueuedCompletionStatusEx 函数时一样),而不仅仅是套接字通知。 PostQueuedCompletionStatus API 用于将消息发布到线程,并唤醒它们,而无需等待套接字通知的到来。
最后,该示例演示了正确取消注册和清理线程工作负载中的套接字上下文的一些复杂之处。 在此示例中,套接字上下文由接收通知的线程隐式拥有。 如果无法注册通知,线程将保留所有权。
// Number of client threads started by MultiThreadedTcpServer. The server threads are told to
// exit once this many accepted-socket registrations have been removed.
#define CLIENT_THREAD_COUNT 100
// The I/O completion port infrastructure ensures that the system isn't over-subscribed by
// ensuring server-side threads block if they exceed the number of logical processors. If the
// machine has more than 16 logical processors, then this can be observed by increasing this number.
#define SERVER_THREAD_COUNT 16
// Maximum number of notifications a server thread dequeues per ProcessSocketNotifications call.
#define SERVER_DEQUEUE_COUNT 3
// Completion key posted with PostQueuedCompletionStatus to tell a server thread to exit.
#define SERVER_EXIT_KEY ((ULONG_PTR)-1)
// State shared by all server threads of the multi-threaded sample.
typedef struct SERVER_THREAD_CONTEXT {
// Completion port and listening socket shared by the threads.
SERVER_CONTEXT* commonContext;
// Protects deregisterCount and shouldExit.
SRWLOCK stateLock;
// Number of accepted-socket registrations whose removal notification has been processed.
_Guarded_by_(stateLock) UINT32 deregisterCount;
// Set once the server threads should stop processing notifications.
_Guarded_by_(stateLock) BOOLEAN shouldExit;
} SERVER_THREAD_CONTEXT;
// Discriminator stored at the start of every SOCKET_CONTEXT so a server thread can tell which
// kind of socket a dequeued notification belongs to.
typedef enum SOCKET_TYPE {
// The listening socket.
SOCKET_TYPE_LISTENER,
// A connection returned by accept.
SOCKET_TYPE_ACCEPT
} SOCKET_TYPE;
// Per-socket context; a pointer to it is supplied as the completion key when the socket is
// registered for notifications.
typedef struct SOCKET_CONTEXT {
// Kind of socket this context describes.
SOCKET_TYPE socketType;
// The socket itself.
SOCKET socket;
} SOCKET_CONTEXT;
// Requests that the server threads exit by setting shouldExit under the state lock. This only
// raises the flag; it does not wake or wait for any thread.
VOID CancelServerThreadsAsync(_Inout_ SERVER_THREAD_CONTEXT* serverThreadContext) {
    PSRWLOCK const stateLock = &serverThreadContext->stateLock;

    AcquireSRWLockExclusive(stateLock);
    serverThreadContext->shouldExit = TRUE;
    ReleaseSRWLockExclusive(stateLock);
}
// Posts a SERVER_EXIT_KEY completion to the port so that one server thread wakes up and exits;
// that thread can then notify the next one when it exits.
VOID IndicateServerThreadExit(_In_ HANDLE ioCompletionPort) {
    const BOOL posted = PostQueuedCompletionStatus(ioCompletionPort, 0, SERVER_EXIT_KEY, NULL);
    if (posted) {
        return;
    }

    // Without the exit message, server threads may hang and the program would never terminate.
    // That is unrecoverable, so fail fast.
    RaiseFailFastException(NULL, NULL, 0);
}
// Closes the context's socket, if it holds one, and frees the context. The pointer must not be
// used after this call.
VOID DestroySocketContext(_Inout_ _Post_invalid_ SOCKET_CONTEXT* socketContext) {
    const SOCKET ownedSocket = socketContext->socket;

    if (INVALID_SOCKET != ownedSocket) {
        closesocket(ownedSocket);
    }

    free(socketContext);
}
// Accepts one connection from listenSocket and wraps it in a newly allocated SOCKET_CONTEXT of
// type SOCKET_TYPE_ACCEPT.
//
// On success, *socketContextOut receives the context (release it with DestroySocketContext) and
// ERROR_SUCCESS is returned. On failure, the partially built context is released,
// *socketContextOut is left untouched, and the failure code is returned.
DWORD AcceptConnection(_In_ SOCKET listenSocket, _Outptr_ SOCKET_CONTEXT** socketContextOut) {
    DWORD errorCode;
    SOCKET_CONTEXT* socketContext = NULL;

    socketContext = (SOCKET_CONTEXT*)malloc(sizeof(*socketContext));
    if (socketContext == NULL) {
        errorCode = ERROR_NOT_ENOUGH_MEMORY;
        goto Exit;
    }

    ZeroMemory(socketContext, sizeof(*socketContext));
    socketContext->socketType = SOCKET_TYPE_ACCEPT;

    socketContext->socket = accept(listenSocket, NULL, NULL);
    if (socketContext->socket == INVALID_SOCKET) {
        errorCode = GetLastError();
        goto Exit;
    }

    // Ownership passes to the caller.
    *socketContextOut = socketContext;
    socketContext = NULL;

    // The success path must set the result explicitly; otherwise an uninitialized value would
    // be returned and callers could not tell success from failure.
    errorCode = ERROR_SUCCESS;

Exit:
    if (socketContext != NULL) {
        _ASSERT(errorCode != ERROR_SUCCESS);
        DestroySocketContext(socketContext);
    }

    return errorCode;
}
// Server thread entry point for the multi-threaded sample.
//
// serverThreadContextPointer is the SERVER_THREAD_CONTEXT shared by all server threads. Each
// iteration submits the registrations queued by the previous iteration, dequeues up to
// SERVER_DEQUEUE_COUNT notifications, and handles them:
//   - listener notification: accept a connection, queue its registration, re-arm the listener;
//   - accepted-socket notification: receive data and queue a re-arm, or queue a removal on
//     hang-up or receive failure; on the removal notification, destroy the socket context;
//   - SERVER_EXIT_KEY: exit.
// On exit the thread always posts another exit message so the next server thread wakes up too.
// Returns ERROR_SUCCESS or the failure code that made the thread exit.
DWORD
WINAPI
ServerThreadRoutine(_In_ PVOID serverThreadContextPointer) {
    DWORD errorCode;
    DWORD removeError;
    SERVER_THREAD_CONTEXT* serverThreadContext;
    HANDLE ioCompletionPort;

    // Accepting a connection requires two registrations: one to re-enable the listening socket
    // notification, and one to register the newly-accepted connection.
    SOCK_NOTIFY_REGISTRATION registrationBuffer[SERVER_DEQUEUE_COUNT * 2];
    UINT32 registrationCount;
    SOCK_NOTIFY_REGISTRATION* registration;
    OVERLAPPED_ENTRY notifications[SERVER_DEQUEUE_COUNT];
    UINT32 notificationCount;
    UINT32 events;
    SOCKET_CONTEXT* socketContext;
    SOCKET_CONTEXT* acceptedContext;
    BOOLEAN shouldExit;
    CHAR dataBuffer[512];

    serverThreadContext = (SERVER_THREAD_CONTEXT*)serverThreadContextPointer;
    ioCompletionPort = serverThreadContext->commonContext->ioCompletionPort;

    // Boot-strap the loop process.
    registrationCount = 0;

    // Keep looping, processing notifications until exit has been requested.
    while (TRUE) {
        AcquireSRWLockExclusive(&serverThreadContext->stateLock);
        shouldExit = serverThreadContext->shouldExit;
        ReleaseSRWLockExclusive(&serverThreadContext->stateLock);

        if (shouldExit) {
            goto Exit;
        }

        AcquireSRWLockExclusive(&g_printLock);
        wprintf_s(L"Server thread %d waiting for client action...\r\n", GetCurrentThreadId());
        ReleaseSRWLockExclusive(&g_printLock);

        // Process notifications and re-register one-shot notifications that were processed on a
        // previous iteration.
        errorCode = ProcessSocketNotifications(ioCompletionPort,
                                               registrationCount,
                                               (registrationCount == 0) ? NULL : registrationBuffer,
                                               MAX_TIMEOUT,
                                               RTL_NUMBER_OF(notifications),
                                               notifications,
                                               &notificationCount);

        // TODO: Production code should handle failure better. This can fail due to transient memory conditions, or due to
        // invalid input such as a bad handle. Retrying in case the memory conditions abate is
        // a reasonable strategy.
        if (errorCode != ERROR_SUCCESS) {
            goto Exit;
        }

        // Check whether any registrations failed, and attempt to clean up if they did.
        errorCode = ERROR_SUCCESS;
        for (UINT32 i = 0; i < registrationCount; i += 1) {
            registration = &registrationBuffer[i];
            if (registration->registrationResult == ERROR_SUCCESS) {
                continue;
            }

            // Preserve the first failure code.
            if (errorCode == ERROR_SUCCESS) {
                errorCode = registration->registrationResult;
            }

            // All the registrations are oneshot, so if the registration failed, then only this thread
            // has access to the context. Attempt to clean up fully:
            //  - The listening socket is owned by the main thread, so ignore that.
            //  - If the socket hasn't been registered, just free its memory.
            //  - Otherwise, attempt to deregister it.
            socketContext = (SOCKET_CONTEXT*)registration->completionKey;
            if (socketContext->socketType == SOCKET_TYPE_LISTENER) {
                continue;
            }

            // Best-effort de-registration. In case of failure, simply get rid of the socket and
            // context. This is safe to do because the notification for the socket can't be enabled.
            // Either it was never registered in the first place, or re-registration failed, and it
            // was previously disabled by nature of being a one-shot registration.
            //
            // The result goes into its own local so the preserved first failure code above is
            // not overwritten.
            registration->operation = SOCK_NOTIFY_OP_REMOVE;
            removeError = ProcessSocketNotifications(ioCompletionPort,
                                                     1,
                                                     registration,
                                                     0,
                                                     0,
                                                     NULL,
                                                     NULL);
            if ((removeError != ERROR_SUCCESS) ||
                (registration->registrationResult != ERROR_SUCCESS)) {
                DestroySocketContext(socketContext);
            }
        }

        // Process the notifications. Many will need to be re-enabled because they are one-shot,
        // so ensure that we can build that incrementally.
        registrationCount = 0;
        ZeroMemory(registrationBuffer, sizeof(registrationBuffer));

        for (UINT32 i = 0; i < notificationCount; i += 1) {
            if (notifications[i].lpCompletionKey == SERVER_EXIT_KEY) {
                _ASSERT(serverThreadContext->shouldExit);

                // On exit, this thread will post the next exit message.
                errorCode = ERROR_SUCCESS;
                goto Exit;
            }

            socketContext = (SOCKET_CONTEXT*)notifications[i].lpCompletionKey;
            events = SocketNotificationRetrieveEvents(&notifications[i]);

            // Process the socket notification, taking socket-specific actions.
            switch (socketContext->socketType) {
            case SOCKET_TYPE_LISTENER:
                // Accepting connections in response to notifications implicitly throttles
                // the rate at which incoming connections are accepted, and limits scale-out for
                // new connection acceptance. Consider using AcceptEx if greater scaling of
                // connection acceptance is desired.

                // Perform an accept regardless of the notification. The only possible notifications
                // are for available connections or error conditions. Any possible error conditions
                // will be processed as part of the accept.
                errorCode = AcceptConnection(socketContext->socket, &acceptedContext);
                if (errorCode == ERROR_SUCCESS) {
                    // Register the accepted connection.
                    registration = &registrationBuffer[registrationCount];
                    registration->socket = acceptedContext->socket;
                    registration->completionKey = acceptedContext;
                    registration->eventFilter = SOCK_NOTIFY_EVENT_IN | SOCK_NOTIFY_EVENT_HANGUP;
                    registration->operation =
                        SOCK_NOTIFY_OP_ENABLE;
                    registration->triggerFlags = SOCK_NOTIFY_TRIGGER_ONESHOT | SOCK_NOTIFY_TRIGGER_LEVEL;
                    registrationCount += 1;
                }

                // Re-arm the existing listening socket registration.
                registration = &registrationBuffer[registrationCount];
                registration->socket = socketContext->socket;
                registration->completionKey = socketContext;
                registration->eventFilter = SOCK_NOTIFY_EVENT_IN;
                registration->operation =
                    SOCK_NOTIFY_OP_ENABLE;
                registration->triggerFlags = SOCK_NOTIFY_TRIGGER_ONESHOT | SOCK_NOTIFY_TRIGGER_LEVEL;
                registrationCount += 1;
                break;

            case SOCKET_TYPE_ACCEPT:
                // The registration was removed. Clean up the context.
                if (events & SOCK_NOTIFY_EVENT_REMOVE) {
                    AcquireSRWLockExclusive(&serverThreadContext->stateLock);
                    serverThreadContext->deregisterCount += 1;
                    if (serverThreadContext->deregisterCount >= CLIENT_THREAD_COUNT) {
                        serverThreadContext->shouldExit = TRUE;
                    }
                    ReleaseSRWLockExclusive(&serverThreadContext->stateLock);

                    DestroySocketContext(socketContext);

                    // Move on to the next notification; nothing is queued for this socket.
                    continue;
                }

                registration = &registrationBuffer[registrationCount];

                // If a hangup occurred, then remove the registration.
                if (events & SOCK_NOTIFY_EVENT_HANGUP) {
                    registration->eventFilter = 0;
                    registration->operation = SOCK_NOTIFY_OP_REMOVE;
                }

                // Receive data.
                if (events & (SOCK_NOTIFY_EVENT_IN | SOCK_NOTIFY_EVENT_ERR)) {
                    // TODO: Handle errors (for example, due to connection reset). The error from recv can
                    // be used to retrieve the underlying socket for a SOCK_NOTIFY_EVENT_ERR.
                    if (recv(socketContext->socket, dataBuffer, sizeof(dataBuffer), 0) < 0) {
                        registration->operation = SOCK_NOTIFY_OP_REMOVE;
                        registration->eventFilter = 0;
                    }
                    else {
                        registration->operation |=
                            SOCK_NOTIFY_OP_ENABLE;
                        registration->triggerFlags =
                            SOCK_NOTIFY_TRIGGER_ONESHOT | SOCK_NOTIFY_TRIGGER_LEVEL;
                        registration->eventFilter = SOCK_NOTIFY_EVENT_IN | SOCK_NOTIFY_EVENT_HANGUP;
                    }
                }

                registration->socket = socketContext->socket;
                registration->completionKey = socketContext;
                registrationCount += 1;
                break;

            // TODO:
            //
            // Other (potentially non-socket) I/O completion can be processed here. For instance,
            // this could also be processing disk I/O. The contexts will need to have a common
            // header that can be used to differentiate between the different context types,
            // similar to how the listening and accepted sockets are differentiated.
            //
            // case ... :
            default:
                _ASSERT(!"Unexpected socket type!");
                errorCode = ERROR_UNIDENTIFIED_ERROR;
                goto Exit;
            }
        }
    }

    errorCode = ERROR_SUCCESS;

Exit:
    // If an error occurred, then ensure the other threads know they should exit.
    // TODO: use an error handling strategy that isn't just exiting.
    if (errorCode != ERROR_SUCCESS) {
        PrintError(errorCode);
        CancelServerThreadsAsync(serverThreadContext);
    }

    // Wake a remaining server thread.
    IndicateServerThreadExit(ioCompletionPort);

    AcquireSRWLockExclusive(&g_printLock);
    wprintf_s(L"Server thread %d exited\r\n", GetCurrentThreadId());
    ReleaseSRWLockExclusive(&g_printLock);

    return errorCode;
}
// Multi-threaded TCP server sample.
//
// It creates the shared server context, registers the listening socket for one-shot,
// level-triggered notifications, starts SERVER_THREAD_COUNT server threads and
// CLIENT_THREAD_COUNT client threads, then waits for every server thread to exit before
// releasing the shared state. Failures are reported through PrintError; nothing is returned.
VOID MultiThreadedTcpServer() {
    DWORD errorCode;
    WSADATA wsaData;
    SERVER_THREAD_CONTEXT serverContext = { NULL, SRWLOCK_INIT, 0, FALSE };
    SOCKET_CONTEXT listenContext = {};
    SOCK_NOTIFY_REGISTRATION registration = {};
    HANDLE serverThreads[SERVER_THREAD_COUNT] = {};
    UINT32 serverThreadCount = 0;

    if (WSAStartup(WINSOCK_VERSION, &wsaData) != 0) {
        errorCode = GetLastError();
        PrintError(errorCode);
        return;
    }

    listenContext.socket = INVALID_SOCKET;
    listenContext.socketType = SOCKET_TYPE_LISTENER;

    errorCode = CreateServerContext(&serverContext.commonContext);
    if (errorCode != ERROR_SUCCESS) {
        goto Exit;
    }

    // Register the listening socket with the I/O completion port so the server threads are notified
    // of incoming connections.
    //
    // The registration is one-shot so that only one server thread picks up each listener
    // notification; that thread re-arms it with the same trigger flags after accepting.
    listenContext.socket = serverContext.commonContext->listenerSocket;
    registration.completionKey = &listenContext;
    registration.eventFilter = SOCK_NOTIFY_EVENT_IN;
    registration.operation = SOCK_NOTIFY_OP_ENABLE;
    registration.triggerFlags = SOCK_NOTIFY_TRIGGER_ONESHOT | SOCK_NOTIFY_TRIGGER_LEVEL;
    registration.socket = listenContext.socket;

    errorCode = ProcessSocketNotifications(serverContext.commonContext->ioCompletionPort,
                                           1,
                                           &registration,
                                           0,
                                           0,
                                           NULL,
                                           NULL);

    // Make sure the registration request was processed.
    if (errorCode != ERROR_SUCCESS) {
        goto Exit;
    }

    // Make sure the registration itself was successful.
    if (registration.registrationResult != ERROR_SUCCESS) {
        errorCode = registration.registrationResult;
        goto Exit;
    }

    // Create the server threads. These are likely over-subscribed, but the I/O completion port
    // ensures that they scale appropriately.
    while (serverThreadCount < RTL_NUMBER_OF(serverThreads)) {
        serverThreads[serverThreadCount] =
            CreateThread(NULL, 0, ServerThreadRoutine, &serverContext, 0, NULL);
        if (serverThreads[serverThreadCount] == NULL) {
            errorCode = GetLastError();
            goto Exit;
        }

        // Count only threads that were actually created; the wait and handle cleanup below
        // rely on this count.
        serverThreadCount += 1;
    }

    // Create the client threads, which are badly over-subscribed.
    for (UINT32 i = 0; i < CLIENT_THREAD_COUNT; i += 1) {
        errorCode = CreateClientThread(CLIENT_LOOP_COUNT);
        if (errorCode != ERROR_SUCCESS) {
            goto Exit;
        }
    }

    errorCode = ERROR_SUCCESS;

Exit:
    if (errorCode != ERROR_SUCCESS) {
        PrintError(errorCode);

        // In case of error, ensure that all server threads know to exit.
        if (serverContext.commonContext != NULL) {
            CancelServerThreadsAsync(&serverContext);
            IndicateServerThreadExit(serverContext.commonContext->ioCompletionPort);
        }
    }

    if (serverThreadCount > 0) {
        wprintf_s(L"Waiting for %d server threads to exit...\r\n", serverThreadCount);
        errorCode = WaitForMultipleObjects(serverThreadCount, serverThreads, TRUE, INFINITE);
        _ASSERT(errorCode == ERROR_SUCCESS);
    }

    // TODO: In case of failure, clean up remaining state. For example, Accepted connections can be kept in
    // a global list, which can be closed from this thread.
    for (UINT32 i = 0; i < serverThreadCount; i += 1) {
        CloseHandle(serverThreads[i]);
    }

    // The shared context does not exist if CreateServerContext failed.
    if (serverContext.commonContext != NULL) {
        DestroyServerContext(serverContext.commonContext);
    }

    WSACleanup();
}