diff options
author | Vsevolod Stakhov <vsevolod@rspamd.com> | 2024-10-01 12:20:08 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@rspamd.com> | 2024-10-01 12:20:08 +0100 |
commit | d04c83839c44737e7fbf76d04464d5fe681bc9d4 (patch) | |
tree | 75bc0b83cb02748c6a05009b61cf6e91766075d1 /contrib/hiredis/async.c | |
parent | da5fa7ec0d65e6cdcd6211c6b4dc87c467eef354 (diff) | |
download | rspamd-d04c83839c44737e7fbf76d04464d5fe681bc9d4.tar.gz rspamd-d04c83839c44737e7fbf76d04464d5fe681bc9d4.zip |
[Rework] Update hiredis to 1.2.0
Diffstat (limited to 'contrib/hiredis/async.c')
-rw-r--r-- | contrib/hiredis/async.c | 665 |
1 files changed, 503 insertions, 162 deletions
diff --git a/contrib/hiredis/async.c b/contrib/hiredis/async.c index afa541304..f82f567f8 100644 --- a/contrib/hiredis/async.c +++ b/contrib/hiredis/async.c @@ -30,9 +30,12 @@ */ #include "fmacros.h" +#include "alloc.h" #include <stdlib.h> #include <string.h> +#ifndef _MSC_VER #include <strings.h> +#endif #include <assert.h> #include <ctype.h> #include <errno.h> @@ -40,25 +43,18 @@ #include "net.h" #include "dict.c" #include "sds.h" +#include "win32.h" -#define _EL_ADD_READ(ctx) do { \ - if ((ctx)->ev.addRead) (ctx)->ev.addRead((ctx)->ev.data); \ - } while(0) -#define _EL_DEL_READ(ctx) do { \ - if ((ctx)->ev.delRead) (ctx)->ev.delRead((ctx)->ev.data); \ - } while(0) -#define _EL_ADD_WRITE(ctx) do { \ - if ((ctx)->ev.addWrite) (ctx)->ev.addWrite((ctx)->ev.data); \ - } while(0) -#define _EL_DEL_WRITE(ctx) do { \ - if ((ctx)->ev.delWrite) (ctx)->ev.delWrite((ctx)->ev.data); \ - } while(0) -#define _EL_CLEANUP(ctx) do { \ - if ((ctx)->ev.cleanup) (ctx)->ev.cleanup((ctx)->ev.data); \ - } while(0); - -/* Forward declaration of function in hiredis.c */ +#include "async_private.h" + +#ifdef NDEBUG +#undef assert +#define assert(e) (void)(e) +#endif + +/* Forward declarations of hiredis.c functions */ int __redisAppendCommand(redisContext *c, const char *cmd, size_t len); +void __redisSetError(redisContext *c, int type, const char *str); /* Functions managing dictionary of callbacks for pub/sub. */ static unsigned int callbackHash(const void *key) { @@ -68,7 +64,12 @@ static unsigned int callbackHash(const void *key) { static void *callbackValDup(void *privdata, const void *src) { ((void) privdata); - redisCallback *dup = malloc(sizeof(*dup)); + redisCallback *dup; + + dup = hi_malloc(sizeof(*dup)); + if (dup == NULL) + return NULL; + memcpy(dup,src,sizeof(*dup)); return dup; } @@ -90,7 +91,7 @@ static void callbackKeyDestructor(void *privdata, void *key) { static void callbackValDestructor(void *privdata, void *val) { ((void) privdata); - free(val); + hi_free(val); } static dictType callbackDict = { @@ -104,10 +105,19 @@ static dictType callbackDict = { static redisAsyncContext *redisAsyncInitialize(redisContext *c) { redisAsyncContext *ac; + dict *channels = NULL, *patterns = NULL; + + channels = dictCreate(&callbackDict,NULL); + if (channels == NULL) + goto oom; + + patterns = dictCreate(&callbackDict,NULL); + if (patterns == NULL) + goto oom; - ac = realloc(c,sizeof(redisAsyncContext)); + ac = hi_realloc(c,sizeof(redisAsyncContext)); if (ac == NULL) - return NULL; + goto oom; c = &(ac->c); @@ -119,6 +129,7 @@ static redisAsyncContext *redisAsyncInitialize(redisContext *c) { ac->err = 0; ac->errstr = NULL; ac->data = NULL; + ac->dataCleanup = NULL; ac->ev.data = NULL; ac->ev.addRead = NULL; @@ -126,17 +137,25 @@ static redisAsyncContext *redisAsyncInitialize(redisContext *c) { ac->ev.addWrite = NULL; ac->ev.delWrite = NULL; ac->ev.cleanup = NULL; + ac->ev.scheduleTimer = NULL; ac->onConnect = NULL; + ac->onConnectNC = NULL; ac->onDisconnect = NULL; ac->replies.head = NULL; ac->replies.tail = NULL; - ac->sub.invalid.head = NULL; - ac->sub.invalid.tail = NULL; - ac->sub.channels = dictCreate(&callbackDict,NULL); - ac->sub.patterns = dictCreate(&callbackDict,NULL); + ac->sub.replies.head = NULL; + ac->sub.replies.tail = NULL; + ac->sub.channels = channels; + ac->sub.patterns = patterns; + ac->sub.pending_unsubs = 0; + return ac; +oom: + if (channels) dictRelease(channels); + if (patterns) dictRelease(patterns); + return NULL; } /* We want the error field to be accessible directly instead of requiring @@ -150,13 +169,21 @@ static void __redisAsyncCopyError(redisAsyncContext *ac) { ac->errstr = c->errstr; } -redisAsyncContext *redisAsyncConnect(const char *ip, int port) { +redisAsyncContext *redisAsyncConnectWithOptions(const redisOptions *options) { + redisOptions myOptions = *options; redisContext *c; redisAsyncContext *ac; - c = redisConnectNonBlock(ip,port); - if (c == NULL) + /* Clear any erroneously set sync callback and flag that we don't want to + * use freeReplyObject by default. */ + myOptions.push_cb = NULL; + myOptions.options |= REDIS_OPT_NO_PUSH_AUTOFREE; + + myOptions.options |= REDIS_OPT_NONBLOCK; + c = redisConnectWithOptions(&myOptions); + if (c == NULL) { return NULL; + } ac = redisAsyncInitialize(c); if (ac == NULL) { @@ -164,55 +191,70 @@ redisAsyncContext *redisAsyncConnect(const char *ip, int port) { return NULL; } + /* Set any configured async push handler */ + redisAsyncSetPushCallback(ac, myOptions.async_push_cb); + __redisAsyncCopyError(ac); return ac; } +redisAsyncContext *redisAsyncConnect(const char *ip, int port) { + redisOptions options = {0}; + REDIS_OPTIONS_SET_TCP(&options, ip, port); + return redisAsyncConnectWithOptions(&options); +} + redisAsyncContext *redisAsyncConnectBind(const char *ip, int port, const char *source_addr) { - redisContext *c = redisConnectBindNonBlock(ip,port,source_addr); - redisAsyncContext *ac = redisAsyncInitialize(c); - __redisAsyncCopyError(ac); - return ac; + redisOptions options = {0}; + REDIS_OPTIONS_SET_TCP(&options, ip, port); + options.endpoint.tcp.source_addr = source_addr; + return redisAsyncConnectWithOptions(&options); } redisAsyncContext *redisAsyncConnectBindWithReuse(const char *ip, int port, const char *source_addr) { - redisContext *c = redisConnectBindNonBlockWithReuse(ip,port,source_addr); - redisAsyncContext *ac = redisAsyncInitialize(c); - __redisAsyncCopyError(ac); - return ac; + redisOptions options = {0}; + REDIS_OPTIONS_SET_TCP(&options, ip, port); + options.options |= REDIS_OPT_REUSEADDR; + options.endpoint.tcp.source_addr = source_addr; + return redisAsyncConnectWithOptions(&options); } redisAsyncContext *redisAsyncConnectUnix(const char *path) { - redisContext *c; - redisAsyncContext *ac; + redisOptions options = {0}; + REDIS_OPTIONS_SET_UNIX(&options, path); + return redisAsyncConnectWithOptions(&options); +} - c = redisConnectUnixNonBlock(path); - if (c == NULL) - return NULL; +static int +redisAsyncSetConnectCallbackImpl(redisAsyncContext *ac, redisConnectCallback *fn, + redisConnectCallbackNC *fn_nc) +{ + /* If either are already set, this is an error */ + if (ac->onConnect || ac->onConnectNC) + return REDIS_ERR; - ac = redisAsyncInitialize(c); - if (ac == NULL) { - redisFree(c); - return NULL; + if (fn) { + ac->onConnect = fn; + } else if (fn_nc) { + ac->onConnectNC = fn_nc; } - __redisAsyncCopyError(ac); - return ac; + /* The common way to detect an established connection is to wait for + * the first write event to be fired. This assumes the related event + * library functions are already set. */ + _EL_ADD_WRITE(ac); + + return REDIS_OK; } int redisAsyncSetConnectCallback(redisAsyncContext *ac, redisConnectCallback *fn) { - if (ac->onConnect == NULL) { - ac->onConnect = fn; + return redisAsyncSetConnectCallbackImpl(ac, fn, NULL); +} - /* The common way to detect an established connection is to wait for - * the first write event to be fired. This assumes the related event - * library functions are already set. */ - _EL_ADD_WRITE(ac); - return REDIS_OK; - } - return REDIS_ERR; +int redisAsyncSetConnectCallbackNC(redisAsyncContext *ac, redisConnectCallbackNC *fn) { + return redisAsyncSetConnectCallbackImpl(ac, NULL, fn); } int redisAsyncSetDisconnectCallback(redisAsyncContext *ac, redisDisconnectCallback *fn) { @@ -228,7 +270,7 @@ static int __redisPushCallback(redisCallbackList *list, redisCallback *source) { redisCallback *cb; /* Copy callback from stack to heap */ - cb = malloc(sizeof(*cb)); + cb = hi_malloc(sizeof(*cb)); if (cb == NULL) return REDIS_ERR_OOM; @@ -256,7 +298,7 @@ static int __redisShiftCallback(redisCallbackList *list, redisCallback *target) /* Copy callback from heap to stack */ if (target != NULL) memcpy(target,cb,sizeof(*cb)); - free(cb); + hi_free(cb); return REDIS_OK; } return REDIS_ERR; @@ -271,46 +313,95 @@ static void __redisRunCallback(redisAsyncContext *ac, redisCallback *cb, redisRe } } +static void __redisRunPushCallback(redisAsyncContext *ac, redisReply *reply) { + if (ac->push_cb != NULL) { + ac->c.flags |= REDIS_IN_CALLBACK; + ac->push_cb(ac, reply); + ac->c.flags &= ~REDIS_IN_CALLBACK; + } +} + +static void __redisRunConnectCallback(redisAsyncContext *ac, int status) +{ + if (ac->onConnect == NULL && ac->onConnectNC == NULL) + return; + + if (!(ac->c.flags & REDIS_IN_CALLBACK)) { + ac->c.flags |= REDIS_IN_CALLBACK; + if (ac->onConnect) { + ac->onConnect(ac, status); + } else { + ac->onConnectNC(ac, status); + } + ac->c.flags &= ~REDIS_IN_CALLBACK; + } else { + /* already in callback */ + if (ac->onConnect) { + ac->onConnect(ac, status); + } else { + ac->onConnectNC(ac, status); + } + } +} + +static void __redisRunDisconnectCallback(redisAsyncContext *ac, int status) +{ + if (ac->onDisconnect) { + if (!(ac->c.flags & REDIS_IN_CALLBACK)) { + ac->c.flags |= REDIS_IN_CALLBACK; + ac->onDisconnect(ac, status); + ac->c.flags &= ~REDIS_IN_CALLBACK; + } else { + /* already in callback */ + ac->onDisconnect(ac, status); + } + } +} + /* Helper function to free the context. */ static void __redisAsyncFree(redisAsyncContext *ac) { redisContext *c = &(ac->c); redisCallback cb; - dictIterator *it; + dictIterator it; dictEntry *de; /* Execute pending callbacks with NULL reply. */ while (__redisShiftCallback(&ac->replies,&cb) == REDIS_OK) __redisRunCallback(ac,&cb,NULL); - - /* Execute callbacks for invalid commands */ - while (__redisShiftCallback(&ac->sub.invalid,&cb) == REDIS_OK) + while (__redisShiftCallback(&ac->sub.replies,&cb) == REDIS_OK) __redisRunCallback(ac,&cb,NULL); - /* Run subscription callbacks callbacks with NULL reply */ - it = dictGetIterator(ac->sub.channels); - while ((de = dictNext(it)) != NULL) - __redisRunCallback(ac,dictGetEntryVal(de),NULL); - dictReleaseIterator(it); - dictRelease(ac->sub.channels); + /* Run subscription callbacks with NULL reply */ + if (ac->sub.channels) { + dictInitIterator(&it,ac->sub.channels); + while ((de = dictNext(&it)) != NULL) + __redisRunCallback(ac,dictGetEntryVal(de),NULL); + + dictRelease(ac->sub.channels); + } + + if (ac->sub.patterns) { + dictInitIterator(&it,ac->sub.patterns); + while ((de = dictNext(&it)) != NULL) + __redisRunCallback(ac,dictGetEntryVal(de),NULL); - it = dictGetIterator(ac->sub.patterns); - while ((de = dictNext(it)) != NULL) - __redisRunCallback(ac,dictGetEntryVal(de),NULL); - dictReleaseIterator(it); - dictRelease(ac->sub.patterns); + dictRelease(ac->sub.patterns); + } /* Signal event lib to clean up */ _EL_CLEANUP(ac); /* Execute disconnect callback. When redisAsyncFree() initiated destroying * this context, the status will always be REDIS_OK. */ - if (ac->onDisconnect && (c->flags & REDIS_CONNECTED)) { - if (c->flags & REDIS_FREEING) { - ac->onDisconnect(ac,REDIS_OK); - } else { - c->flags |= REDIS_FREEING; - ac->onDisconnect(ac,(ac->err == 0) ? REDIS_OK : REDIS_ERR); - } + if (c->flags & REDIS_CONNECTED) { + int status = ac->err == 0 ? REDIS_OK : REDIS_ERR; + if (c->flags & REDIS_FREEING) + status = REDIS_OK; + __redisRunDisconnectCallback(ac, status); + } + + if (ac->dataCleanup) { + ac->dataCleanup(ac->data); } /* Cleanup self */ @@ -322,14 +413,18 @@ static void __redisAsyncFree(redisAsyncContext *ac) { * free'ing. To do so, a flag is set on the context which is picked up by * redisProcessCallbacks(). Otherwise, the context is immediately free'd. */ void redisAsyncFree(redisAsyncContext *ac) { + if (ac == NULL) + return; + redisContext *c = &(ac->c); + c->flags |= REDIS_FREEING; if (!(c->flags & REDIS_IN_CALLBACK)) __redisAsyncFree(ac); } /* Helper function to make the disconnect happen and clean up. */ -static void __redisAsyncDisconnect(redisAsyncContext *ac) { +void __redisAsyncDisconnect(redisAsyncContext *ac) { redisContext *c = &(ac->c); /* Make sure error is accessible if there is any */ @@ -337,16 +432,23 @@ static void __redisAsyncDisconnect(redisAsyncContext *ac) { if (ac->err == 0) { /* For clean disconnects, there should be no pending callbacks. */ - assert(__redisShiftCallback(&ac->replies,NULL) == REDIS_ERR); + int ret = __redisShiftCallback(&ac->replies,NULL); + assert(ret == REDIS_ERR); } else { /* Disconnection is caused by an error, make sure that pending * callbacks cannot call new commands. */ c->flags |= REDIS_DISCONNECTING; } + /* cleanup event library on disconnect. + * this is safe to call multiple times */ + _EL_CLEANUP(ac); + /* For non-clean disconnects, __redisAsyncFree() will execute pending * callbacks with a NULL-reply. */ - __redisAsyncFree(ac); + if (!(c->flags & REDIS_NO_AUTO_FREE)) { + __redisAsyncFree(ac); + } } /* Tries to do a clean disconnect from Redis, meaning it stops new commands @@ -358,6 +460,9 @@ static void __redisAsyncDisconnect(redisAsyncContext *ac) { void redisAsyncDisconnect(redisAsyncContext *ac) { redisContext *c = &(ac->c); c->flags |= REDIS_DISCONNECTING; + + /** unset the auto-free flag here, because disconnect undoes this */ + c->flags &= ~REDIS_NO_AUTO_FREE; if (!(c->flags & REDIS_IN_CALLBACK) && ac->replies.head == NULL) __redisAsyncDisconnect(ac); } @@ -365,15 +470,17 @@ void redisAsyncDisconnect(redisAsyncContext *ac) { static int __redisGetSubscribeCallback(redisAsyncContext *ac, redisReply *reply, redisCallback *dstcb) { redisContext *c = &(ac->c); dict *callbacks; + redisCallback *cb = NULL; dictEntry *de; int pvariant; char *stype; - sds sname; + sds sname = NULL; - /* Custom reply functions are not supported for pub/sub. This will fail - * very hard when they are used... */ - if (reply->type == REDIS_REPLY_ARRAY) { - assert(reply->elements >= 2); + /* Match reply with the expected format of a pushed message. + * The type and number of elements (3 to 4) are specified at: + * https://redis.io/topics/pubsub#format-of-pushed-messages */ + if ((reply->type == REDIS_REPLY_ARRAY && !(c->flags & REDIS_SUPPORTS_PUSH) && reply->elements >= 3) || + reply->type == REDIS_REPLY_PUSH) { assert(reply->element[0]->type == REDIS_REPLY_STRING); stype = reply->element[0]->str; pvariant = (tolower(stype[0]) == 'p') ? 1 : 0; @@ -384,34 +491,84 @@ static int __redisGetSubscribeCallback(redisAsyncContext *ac, redisReply *reply, callbacks = ac->sub.channels; /* Locate the right callback */ - assert(reply->element[1]->type == REDIS_REPLY_STRING); - sname = sdsnewlen(reply->element[1]->str,reply->element[1]->len); - de = dictFind(callbacks,sname); - if (de != NULL) { - memcpy(dstcb,dictGetEntryVal(de),sizeof(*dstcb)); - - /* If this is an unsubscribe message, remove it. */ - if (strcasecmp(stype+pvariant,"unsubscribe") == 0) { + if (reply->element[1]->type == REDIS_REPLY_STRING) { + sname = sdsnewlen(reply->element[1]->str,reply->element[1]->len); + if (sname == NULL) goto oom; + + if ((de = dictFind(callbacks,sname)) != NULL) { + cb = dictGetEntryVal(de); + memcpy(dstcb,cb,sizeof(*dstcb)); + } + } + + /* If this is an subscribe reply decrease pending counter. */ + if (strcasecmp(stype+pvariant,"subscribe") == 0) { + assert(cb != NULL); + cb->pending_subs -= 1; + + } else if (strcasecmp(stype+pvariant,"unsubscribe") == 0) { + if (cb == NULL) + ac->sub.pending_unsubs -= 1; + else if (cb->pending_subs == 0) dictDelete(callbacks,sname); - /* If this was the last unsubscribe message, revert to - * non-subscribe mode. */ - assert(reply->element[2]->type == REDIS_REPLY_INTEGER); - if (reply->element[2]->integer == 0) - c->flags &= ~REDIS_SUBSCRIBED; + /* If this was the last unsubscribe message, revert to + * non-subscribe mode. */ + assert(reply->element[2]->type == REDIS_REPLY_INTEGER); + + /* Unset subscribed flag only when no pipelined pending subscribe + * or pending unsubscribe replies. */ + if (reply->element[2]->integer == 0 + && dictSize(ac->sub.channels) == 0 + && dictSize(ac->sub.patterns) == 0 + && ac->sub.pending_unsubs == 0) { + c->flags &= ~REDIS_SUBSCRIBED; + + /* Move ongoing regular command callbacks. */ + redisCallback cb; + while (__redisShiftCallback(&ac->sub.replies,&cb) == REDIS_OK) { + __redisPushCallback(&ac->replies,&cb); + } } } sdsfree(sname); } else { - /* Shift callback for invalid commands. */ - __redisShiftCallback(&ac->sub.invalid,dstcb); + /* Shift callback for pending command in subscribed context. */ + __redisShiftCallback(&ac->sub.replies,dstcb); } return REDIS_OK; +oom: + __redisSetError(&(ac->c), REDIS_ERR_OOM, "Out of memory"); + __redisAsyncCopyError(ac); + return REDIS_ERR; +} + +#define redisIsSpontaneousPushReply(r) \ + (redisIsPushReply(r) && !redisIsSubscribeReply(r)) + +static int redisIsSubscribeReply(redisReply *reply) { + char *str; + size_t len, off; + + /* We will always have at least one string with the subscribe/message type */ + if (reply->elements < 1 || reply->element[0]->type != REDIS_REPLY_STRING || + reply->element[0]->len < sizeof("message") - 1) + { + return 0; + } + + /* Get the string/len moving past 'p' if needed */ + off = tolower(reply->element[0]->str[0]) == 'p'; + str = reply->element[0]->str + off; + len = reply->element[0]->len - off; + + return !strncasecmp(str, "subscribe", len) || + !strncasecmp(str, "message", len) || + !strncasecmp(str, "unsubscribe", len); } void redisProcessCallbacks(redisAsyncContext *ac) { redisContext *c = &(ac->c); - redisCallback cb = {NULL, NULL, NULL}; void *reply = NULL; int status; @@ -424,19 +581,27 @@ void redisProcessCallbacks(redisAsyncContext *ac) { __redisAsyncDisconnect(ac); return; } - - /* If monitor mode, repush callback */ - if(c->flags & REDIS_MONITORING) { - __redisPushCallback(&ac->replies,&cb); - } - /* When the connection is not being disconnected, simply stop * trying to get replies and wait for the next loop tick. */ break; } - /* Even if the context is subscribed, pending regular callbacks will - * get a reply before pub/sub messages arrive. */ + /* Keep track of push message support for subscribe handling */ + if (redisIsPushReply(reply)) c->flags |= REDIS_SUPPORTS_PUSH; + + /* Send any non-subscribe related PUSH messages to our PUSH handler + * while allowing subscribe related PUSH messages to pass through. + * This allows existing code to be backward compatible and work in + * either RESP2 or RESP3 mode. */ + if (redisIsSpontaneousPushReply(reply)) { + __redisRunPushCallback(ac, reply); + c->reader->fn->freeObject(reply); + continue; + } + + /* Even if the context is subscribed, pending regular + * callbacks will get a reply before pub/sub messages arrive. */ + redisCallback cb = {NULL, NULL, 0, 0, NULL}; if (__redisShiftCallback(&ac->replies,&cb) != REDIS_OK) { /* * A spontaneous reply in a not-subscribed context can be the error @@ -460,15 +625,17 @@ void redisProcessCallbacks(redisAsyncContext *ac) { __redisAsyncDisconnect(ac); return; } - /* No more regular callbacks and no errors, the context *must* be subscribed or monitoring. */ - assert((c->flags & REDIS_SUBSCRIBED || c->flags & REDIS_MONITORING)); - if(c->flags & REDIS_SUBSCRIBED) + /* No more regular callbacks and no errors, the context *must* be subscribed. */ + assert(c->flags & REDIS_SUBSCRIBED); + if (c->flags & REDIS_SUBSCRIBED) __redisGetSubscribeCallback(ac,reply,&cb); } if (cb.fn != NULL) { __redisRunCallback(ac,&cb,reply); - c->reader->fn->freeObject(reply); + if (!(c->flags & REDIS_NO_AUTO_FREE_REPLIES)){ + c->reader->fn->freeObject(reply); + } /* Proceed with free'ing when redisAsyncFree() was called. */ if (c->flags & REDIS_FREEING) { @@ -481,11 +648,11 @@ void redisProcessCallbacks(redisAsyncContext *ac) { * abort with an error, but simply ignore it because the client * doesn't know what the server will spit out over the wire. */ c->reader->fn->freeObject(reply); - /* Proceed with free'ing when redisAsyncFree() was called. */ - if (c->flags & REDIS_FREEING) { - __redisAsyncFree(ac); - return; - } + } + + /* If in monitor mode, repush the callback */ + if (c->flags & REDIS_MONITORING) { + __redisPushCallback(&ac->replies,&cb); } } @@ -494,26 +661,61 @@ void redisProcessCallbacks(redisAsyncContext *ac) { __redisAsyncDisconnect(ac); } +static void __redisAsyncHandleConnectFailure(redisAsyncContext *ac) { + __redisRunConnectCallback(ac, REDIS_ERR); + __redisAsyncDisconnect(ac); +} + /* Internal helper function to detect socket status the first time a read or * write event fires. When connecting was not successful, the connect callback * is called with a REDIS_ERR status and the context is free'd. */ static int __redisAsyncHandleConnect(redisAsyncContext *ac) { + int completed = 0; redisContext *c = &(ac->c); - if (redisCheckSocketError(c) == REDIS_ERR) { - /* Try again later when connect(2) is still in progress. */ - if (errno == EINPROGRESS) - return REDIS_OK; - - if (ac->onConnect) ac->onConnect(ac,REDIS_ERR); - __redisAsyncDisconnect(ac); + if (redisCheckConnectDone(c, &completed) == REDIS_ERR) { + /* Error! */ + if (redisCheckSocketError(c) == REDIS_ERR) + __redisAsyncCopyError(ac); + __redisAsyncHandleConnectFailure(ac); return REDIS_ERR; + } else if (completed == 1) { + /* connected! */ + if (c->connection_type == REDIS_CONN_TCP && + redisSetTcpNoDelay(c) == REDIS_ERR) { + __redisAsyncHandleConnectFailure(ac); + return REDIS_ERR; + } + + /* flag us as fully connect, but allow the callback + * to disconnect. For that reason, permit the function + * to delete the context here after callback return. + */ + c->flags |= REDIS_CONNECTED; + __redisRunConnectCallback(ac, REDIS_OK); + if ((ac->c.flags & REDIS_DISCONNECTING)) { + redisAsyncDisconnect(ac); + return REDIS_ERR; + } else if ((ac->c.flags & REDIS_FREEING)) { + redisAsyncFree(ac); + return REDIS_ERR; + } + return REDIS_OK; + } else { + return REDIS_OK; } +} - /* Mark context as connected. */ - c->flags |= REDIS_CONNECTED; - if (ac->onConnect) ac->onConnect(ac,REDIS_OK); - return REDIS_OK; +void redisAsyncRead(redisAsyncContext *ac) { + redisContext *c = &(ac->c); + + if (redisBufferRead(c) == REDIS_ERR) { + __redisAsyncDisconnect(ac); + } else { + /* Always re-schedule reads */ + _EL_ADD_READ(ac); + redisProcessCallbacks(ac); + } } /* This function should be called when the socket is readable. @@ -521,6 +723,8 @@ static int __redisAsyncHandleConnect(redisAsyncContext *ac) { */ void redisAsyncHandleRead(redisAsyncContext *ac) { redisContext *c = &(ac->c); + /* must not be called from a callback */ + assert(!(c->flags & REDIS_IN_CALLBACK)); if (!(c->flags & REDIS_CONNECTED)) { /* Abort connect was not successful. */ @@ -531,18 +735,31 @@ void redisAsyncHandleRead(redisAsyncContext *ac) { return; } - if (redisBufferRead(c) == REDIS_ERR) { + c->funcs->async_read(ac); +} + +void redisAsyncWrite(redisAsyncContext *ac) { + redisContext *c = &(ac->c); + int done = 0; + + if (redisBufferWrite(c,&done) == REDIS_ERR) { __redisAsyncDisconnect(ac); } else { - /* Always re-schedule reads */ + /* Continue writing when not done, stop writing otherwise */ + if (!done) + _EL_ADD_WRITE(ac); + else + _EL_DEL_WRITE(ac); + + /* Always schedule reads after writes */ _EL_ADD_READ(ac); - redisProcessCallbacks(ac); } } void redisAsyncHandleWrite(redisAsyncContext *ac) { redisContext *c = &(ac->c); - int done = 0; + /* must not be called from a callback */ + assert(!(c->flags & REDIS_IN_CALLBACK)); if (!(c->flags & REDIS_CONNECTED)) { /* Abort connect was not successful. */ @@ -553,18 +770,46 @@ void redisAsyncHandleWrite(redisAsyncContext *ac) { return; } - if (redisBufferWrite(c,&done) == REDIS_ERR) { - __redisAsyncDisconnect(ac); - } else { - /* Continue writing when not done, stop writing otherwise */ - if (!done) - _EL_ADD_WRITE(ac); - else - _EL_DEL_WRITE(ac); + c->funcs->async_write(ac); +} - /* Always schedule reads after writes */ - _EL_ADD_READ(ac); +void redisAsyncHandleTimeout(redisAsyncContext *ac) { + redisContext *c = &(ac->c); + redisCallback cb; + /* must not be called from a callback */ + assert(!(c->flags & REDIS_IN_CALLBACK)); + + if ((c->flags & REDIS_CONNECTED)) { + if (ac->replies.head == NULL && ac->sub.replies.head == NULL) { + /* Nothing to do - just an idle timeout */ + return; + } + + if (!ac->c.command_timeout || + (!ac->c.command_timeout->tv_sec && !ac->c.command_timeout->tv_usec)) { + /* A belated connect timeout arriving, ignore */ + return; + } } + + if (!c->err) { + __redisSetError(c, REDIS_ERR_TIMEOUT, "Timeout"); + __redisAsyncCopyError(ac); + } + + if (!(c->flags & REDIS_CONNECTED)) { + __redisRunConnectCallback(ac, REDIS_ERR); + } + + while (__redisShiftCallback(&ac->replies, &cb) == REDIS_OK) { + __redisRunCallback(ac, &cb, NULL); + } + + /** + * TODO: Don't automatically sever the connection, + * rather, allow to ignore <x> responses before the queue is clear + */ + __redisAsyncDisconnect(ac); } /* Sets a pointer to the first argument and its length starting at p. Returns @@ -589,6 +834,10 @@ static const char *nextArgument(const char *start, const char **str, size_t *len static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void *privdata, const char *cmd, size_t len) { redisContext *c = &(ac->c); redisCallback cb; + struct dict *cbdict; + dictIterator it; + dictEntry *de; + redisCallback *existcb; int pvariant, hasnext; const char *cstr, *astr; size_t clen, alen; @@ -602,6 +851,8 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void /* Setup callback */ cb.fn = fn; cb.privdata = privdata; + cb.pending_subs = 1; + cb.unsubscribe_sent = 0; /* Find out which command will be appended. */ p = nextArgument(cmd,&cstr,&clen); @@ -611,38 +862,97 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void cstr += pvariant; clen -= pvariant; - if (hasnext && clen >= 9 && strncasecmp(cstr,"subscribe\r\n",9) == 0) { + if (hasnext && strncasecmp(cstr,"subscribe\r\n",11) == 0) { c->flags |= REDIS_SUBSCRIBED; /* Add every channel/pattern to the list of subscription callbacks. */ while ((p = nextArgument(p,&astr,&alen)) != NULL) { sname = sdsnewlen(astr,alen); + if (sname == NULL) + goto oom; + if (pvariant) - ret = dictReplace(ac->sub.patterns,sname,&cb); + cbdict = ac->sub.patterns; else - ret = dictReplace(ac->sub.channels,sname,&cb); + cbdict = ac->sub.channels; + + de = dictFind(cbdict,sname); + + if (de != NULL) { + existcb = dictGetEntryVal(de); + cb.pending_subs = existcb->pending_subs + 1; + } + + ret = dictReplace(cbdict,sname,&cb); if (ret == 0) sdsfree(sname); } - } else if (clen >= 11 && strncasecmp(cstr,"unsubscribe\r\n",11) == 0) { + } else if (strncasecmp(cstr,"unsubscribe\r\n",13) == 0) { /* It is only useful to call (P)UNSUBSCRIBE when the context is * subscribed to one or more channels or patterns. */ if (!(c->flags & REDIS_SUBSCRIBED)) return REDIS_ERR; + if (pvariant) + cbdict = ac->sub.patterns; + else + cbdict = ac->sub.channels; + + if (hasnext) { + /* Send an unsubscribe with specific channels/patterns. + * Bookkeeping the number of expected replies */ + while ((p = nextArgument(p,&astr,&alen)) != NULL) { + sname = sdsnewlen(astr,alen); + if (sname == NULL) + goto oom; + + de = dictFind(cbdict,sname); + if (de != NULL) { + existcb = dictGetEntryVal(de); + if (existcb->unsubscribe_sent == 0) + existcb->unsubscribe_sent = 1; + else + /* Already sent, reply to be ignored */ + ac->sub.pending_unsubs += 1; + } else { + /* Not subscribed to, reply to be ignored */ + ac->sub.pending_unsubs += 1; + } + sdsfree(sname); + } + } else { + /* Send an unsubscribe without specific channels/patterns. + * Bookkeeping the number of expected replies */ + int no_subs = 1; + dictInitIterator(&it,cbdict); + while ((de = dictNext(&it)) != NULL) { + existcb = dictGetEntryVal(de); + if (existcb->unsubscribe_sent == 0) { + existcb->unsubscribe_sent = 1; + no_subs = 0; + } + } + /* Unsubscribing to all channels/patterns, where none is + * subscribed to, results in a single reply to be ignored. */ + if (no_subs == 1) + ac->sub.pending_unsubs += 1; + } + /* (P)UNSUBSCRIBE does not have its own response: every channel or * pattern that is unsubscribed will receive a message. This means we * should not append a callback function for this command. */ - } else if(clen >= 7 && strncasecmp(cstr,"monitor\r\n",7) == 0) { - /* Set monitor flag and push callback */ - c->flags |= REDIS_MONITORING; - __redisPushCallback(&ac->replies,&cb); + } else if (strncasecmp(cstr,"monitor\r\n",9) == 0) { + /* Set monitor flag and push callback */ + c->flags |= REDIS_MONITORING; + if (__redisPushCallback(&ac->replies,&cb) != REDIS_OK) + goto oom; } else { - if (c->flags & REDIS_SUBSCRIBED) - /* This will likely result in an error reply, but it needs to be - * received and passed to the callback. */ - __redisPushCallback(&ac->sub.invalid,&cb); - else - __redisPushCallback(&ac->replies,&cb); + if (c->flags & REDIS_SUBSCRIBED) { + if (__redisPushCallback(&ac->sub.replies,&cb) != REDIS_OK) + goto oom; + } else { + if (__redisPushCallback(&ac->replies,&cb) != REDIS_OK) + goto oom; + } } __redisAppendCommand(c,cmd,len); @@ -651,6 +961,10 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void _EL_ADD_WRITE(ac); return REDIS_OK; +oom: + __redisSetError(&(ac->c), REDIS_ERR_OOM, "Out of memory"); + __redisAsyncCopyError(ac); + return REDIS_ERR; } int redisvAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void *privdata, const char *format, va_list ap) { @@ -664,7 +978,7 @@ int redisvAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void *privdat return REDIS_ERR; status = __redisAsyncCommand(ac,fn,privdata,cmd,len); - free(cmd); + hi_free(cmd); return status; } @@ -679,9 +993,11 @@ int redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void *privdata int redisAsyncCommandArgv(redisAsyncContext *ac, redisCallbackFn *fn, void *privdata, int argc, const char **argv, const size_t *argvlen) { sds cmd; - int len; + long long len; int status; len = redisFormatSdsCommandArgv(&cmd,argc,argv,argvlen); + if (len < 0) + return REDIS_ERR; status = __redisAsyncCommand(ac,fn,privdata,cmd,len); sdsfree(cmd); return status; @@ -691,3 +1007,28 @@ int redisAsyncFormattedCommand(redisAsyncContext *ac, redisCallbackFn *fn, void int status = __redisAsyncCommand(ac,fn,privdata,cmd,len); return status; } + +redisAsyncPushFn *redisAsyncSetPushCallback(redisAsyncContext *ac, redisAsyncPushFn *fn) { + redisAsyncPushFn *old = ac->push_cb; + ac->push_cb = fn; + return old; +} + +int redisAsyncSetTimeout(redisAsyncContext *ac, struct timeval tv) { + if (!ac->c.command_timeout) { + ac->c.command_timeout = hi_calloc(1, sizeof(tv)); + if (ac->c.command_timeout == NULL) { + __redisSetError(&ac->c, REDIS_ERR_OOM, "Out of memory"); + __redisAsyncCopyError(ac); + return REDIS_ERR; + } + } + + if (tv.tv_sec != ac->c.command_timeout->tv_sec || + tv.tv_usec != ac->c.command_timeout->tv_usec) + { + *ac->c.command_timeout = tv; + } + + return REDIS_OK; +} |