diff --git a/async.c b/async.c index 2fcafc1..15b6f42 100644 --- a/async.c +++ b/async.c @@ -66,12 +66,19 @@ int __redisAppendCommand(redisContext *c, const char *cmd, size_t len); void __redisSetError(redisContext *c, int type, const char *str); /* Reference counting for callback struct. */ -static void callbackIncrRefCount(redisCallback *cb) { +static void callbackIncrRefCount(redisAsyncContext *ac, redisCallback *cb) { + (void)ac; cb->refcount++; } -static void callbackDecrRefCount(redisCallback *cb) { +static void callbackDecrRefCount(redisAsyncContext *ac, redisCallback *cb) { cb->refcount--; if (cb->refcount == 0) { + if (cb->finalizer != NULL) { + redisContext *c = &(ac->c); + c->flags |= REDIS_IN_CALLBACK; + cb->finalizer(ac, cb->privdata); + c->flags &= ~REDIS_IN_CALLBACK; + } hi_free(cb); } } @@ -98,14 +105,12 @@ static void callbackKeyDestructor(void *privdata, void *key) { } static void *callbackValDup(void *privdata, const void *val) { - (void)privdata; - callbackIncrRefCount((redisCallback *)val); + callbackIncrRefCount((redisAsyncContext *)privdata, (redisCallback *)val); return (void *)val; } static void callbackValDestructor(void *privdata, void *val) { - (void)privdata; - callbackDecrRefCount((redisCallback *)val); + callbackDecrRefCount((redisAsyncContext *)privdata, (redisCallback *)val); } static dictType callbackDict = { @@ -121,22 +126,22 @@ static redisAsyncContext *redisAsyncInitialize(redisContext *c) { redisAsyncContext *ac; dict *channels = NULL, *patterns = NULL, *shard_channels = NULL; - channels = dictCreate(&callbackDict,NULL); + ac = hi_realloc(c,sizeof(redisAsyncContext)); + if (ac == NULL) + goto oom; + + channels = dictCreate(&callbackDict, ac); if (channels == NULL) goto oom; - patterns = dictCreate(&callbackDict,NULL); + patterns = dictCreate(&callbackDict, ac); if (patterns == NULL) goto oom; - shard_channels = dictCreate(&callbackDict,NULL); + shard_channels = dictCreate(&callbackDict, ac); if (shard_channels == NULL) goto oom; - ac = hi_realloc(c,sizeof(redisAsyncContext)); - if (ac == NULL) - goto oom; - c = &(ac->c); /* The regular connect functions will always set the flag REDIS_CONNECTED. @@ -173,6 +178,7 @@ oom: if (channels) dictRelease(channels); if (patterns) dictRelease(patterns); if (shard_channels) dictRelease(shard_channels); + if (ac) hi_free(ac); return NULL; } @@ -386,7 +392,7 @@ static void __redisAsyncFree(redisAsyncContext *ac) { /* Execute pending callbacks with NULL reply. */ while (__redisShiftCallback(&ac->replies,&cb) == REDIS_OK) { __redisRunCallback(ac,cb,NULL); - callbackDecrRefCount(cb); + callbackDecrRefCount(ac, cb); } /* Run subscription callbacks with NULL reply */ @@ -700,7 +706,7 @@ void redisProcessCallbacks(redisAsyncContext *ac) { /* The command needs more repies. Put it first in queue. */ __redisUnshiftCallback(&ac->replies, cb); } else { - callbackDecrRefCount(cb); + callbackDecrRefCount(ac, cb); } } @@ -858,7 +864,7 @@ void redisAsyncHandleTimeout(redisAsyncContext *ac) { while (__redisShiftCallback(&ac->replies, &cb) == REDIS_OK) { __redisRunCallback(ac, cb, NULL); - callbackDecrRefCount(cb); + callbackDecrRefCount(ac, cb); } /** @@ -899,7 +905,9 @@ static int isPubsubCommand(const char *cmd, size_t len) { /* Helper function for the redisAsyncCommand* family of functions. Writes a * formatted command to the output buffer and registers the provided callback * function with the context. */ -static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void *privdata, const char *cmd, size_t len) { +static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, + redisFinalizerCallback *finalizer, void *privdata, + const char *cmd, size_t len) { redisContext *c = &(ac->c); redisCallback *cb; const char *cstr, *astr; @@ -914,6 +922,7 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void if (cb == NULL) goto oom; cb->fn = fn; + cb->finalizer = finalizer; cb->privdata = privdata; cb->refcount = 1; cb->pending_replies = 1; /* Most commands have exactly 1 reply. */ @@ -953,11 +962,16 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void oom: __redisSetError(&(ac->c), REDIS_ERR_OOM, "Out of memory"); __redisAsyncCopyError(ac); - callbackDecrRefCount(cb); + callbackDecrRefCount(ac, cb); return REDIS_ERR; } int redisvAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void *privdata, const char *format, va_list ap) { + return redisvAsyncCommandWithFinalizer(ac, fn, NULL, privdata, format, ap); +} + +int redisvAsyncCommandWithFinalizer(redisAsyncContext *ac, redisCallbackFn *fn, redisFinalizerCallback *finalizer, + void *privdata, const char *format, va_list ap) { char *cmd; int len; int status; @@ -967,7 +981,7 @@ int redisvAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void *privdat if (len < 0) return REDIS_ERR; - status = __redisAsyncCommand(ac,fn,privdata,cmd,len); + status = __redisAsyncCommand(ac,fn,finalizer,privdata,cmd,len); hi_free(cmd); return status; } @@ -981,20 +995,41 @@ int redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void *privdata return status; } +int redisAsyncCommandWithFinalizer(redisAsyncContext *ac, redisCallbackFn *fn, redisFinalizerCallback *finalizer, + void *privdata, const char *format, ...) { + va_list ap; + int status; + va_start(ap,format); + status = redisvAsyncCommandWithFinalizer(ac,fn,finalizer,privdata,format,ap); + va_end(ap); + return status; +} + int redisAsyncCommandArgv(redisAsyncContext *ac, redisCallbackFn *fn, void *privdata, int argc, const char **argv, const size_t *argvlen) { + return redisAsyncCommandArgvWithFinalizer(ac, fn, NULL, privdata, argc, argv, argvlen); +} + +int redisAsyncCommandArgvWithFinalizer(redisAsyncContext *ac, redisCallbackFn *fn, redisFinalizerCallback *finalizer, + void *privdata, int argc, const char **argv, const size_t *argvlen) { sds cmd; long long len; int status; len = redisFormatSdsCommandArgv(&cmd,argc,argv,argvlen); if (len < 0) return REDIS_ERR; - status = __redisAsyncCommand(ac,fn,privdata,cmd,len); + status = __redisAsyncCommand(ac,fn,finalizer,privdata,cmd,len); sdsfree(cmd); return status; } int redisAsyncFormattedCommand(redisAsyncContext *ac, redisCallbackFn *fn, void *privdata, const char *cmd, size_t len) { - int status = __redisAsyncCommand(ac,fn,privdata,cmd,len); + int status = __redisAsyncCommand(ac,fn,NULL,privdata,cmd,len); + return status; +} + +int redisAsyncFormattedCommandWithFinalizer(redisAsyncContext *ac, redisCallbackFn *fn, redisFinalizerCallback *finalizer, + void *privdata, const char *cmd, size_t len) { + int status = __redisAsyncCommand(ac,fn,finalizer,privdata,cmd,len); return status; } diff --git a/async.h b/async.h index f83e1aa..4838300 100644 --- a/async.h +++ b/async.h @@ -42,9 +42,11 @@ struct dict; /* dictionary header is included in async.c */ /* Reply callback prototype and container */ typedef void (redisCallbackFn)(struct redisAsyncContext*, void*, void*); +typedef void (redisFinalizerCallback)(struct redisAsyncContext *ac, void *privdata); typedef struct redisCallback { struct redisCallback *next; /* simple singly linked list */ redisCallbackFn *fn; + redisFinalizerCallback *finalizer; void *privdata; unsigned int refcount; /* Reference counter used when callback is used * for multiple pubsub channels. */ @@ -147,6 +149,11 @@ int redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void *privdata int redisAsyncCommandArgv(redisAsyncContext *ac, redisCallbackFn *fn, void *privdata, int argc, const char **argv, const size_t *argvlen); int redisAsyncFormattedCommand(redisAsyncContext *ac, redisCallbackFn *fn, void *privdata, const char *cmd, size_t len); +int redisvAsyncCommandWithFinalizer(redisAsyncContext *ac, redisCallbackFn *fn, redisFinalizerCallback *finalizer, void *privdata, const char *format, va_list ap); +int redisAsyncCommandWithFinalizer(redisAsyncContext *ac, redisCallbackFn *fn, redisFinalizerCallback *finalizer, void *privdata, const char *format, ...); +int redisAsyncCommandArgvWithFinalizer(redisAsyncContext *ac, redisCallbackFn *fn, redisFinalizerCallback *finalizer, void *privdata, int argc, const char **argv, const size_t *argvlen); +int redisAsyncFormattedCommandWithFinalizer(redisAsyncContext *ac, redisCallbackFn *fn, redisFinalizerCallback *finalizer, void *privdata, const char *cmd, size_t len); + #ifdef __cplusplus } #endif diff --git a/test.c b/test.c index 6ac3ea0..d63f94b 100644 --- a/test.c +++ b/test.c @@ -1682,6 +1682,12 @@ void null_cb(redisAsyncContext *ac, void *r, void *privdata) { state->checkpoint++; } +void finalizer_cb(redisAsyncContext *ac, void *privdata) { + TestState *state = privdata; + (void)ac; + state->checkpoint++; +} + static void test_pubsub_handling(struct config config) { test("Subscribe, handle published message and unsubscribe: "); /* Setup event dispatcher with a testcase timeout */ @@ -1701,10 +1707,10 @@ static void test_pubsub_handling(struct config config) { /* Start subscribe */ TestState state = {.options = &options}; - redisAsyncCommand(ac,subscribe_cb,&state,"subscribe mychannel"); + redisAsyncCommandWithFinalizer(ac,subscribe_cb,finalizer_cb,&state,"subscribe mychannel"); /* Make sure non-subscribe commands are handled */ - redisAsyncCommand(ac,array_cb,&state,"PING"); + redisAsyncCommandWithFinalizer(ac,array_cb,finalizer_cb,&state,"PING"); /* Start event dispatching loop */ test_cond(event_base_dispatch(base) == 0); @@ -1712,7 +1718,7 @@ static void test_pubsub_handling(struct config config) { event_base_free(base); /* Verify test checkpoints */ - assert(state.checkpoint == 3); + assert(state.checkpoint == 5); } /* Unexpected push message, will trigger a failure */ @@ -1930,8 +1936,8 @@ static void test_pubsub_multiple_channels(struct config config) { /* Start subscribing to two channels */ TestState state = {.options = &options}; - redisAsyncCommand(ac,subscribe_channel_a_cb,&state,"subscribe A"); - redisAsyncCommand(ac,subscribe_channel_b_cb,&state,"subscribe B"); + redisAsyncCommandWithFinalizer(ac,subscribe_channel_a_cb,finalizer_cb,&state,"subscribe A"); + redisAsyncCommandWithFinalizer(ac,subscribe_channel_b_cb,finalizer_cb,&state,"subscribe B"); /* Start event dispatching loop */ assert(event_base_dispatch(base) == 0); @@ -1939,7 +1945,7 @@ static void test_pubsub_multiple_channels(struct config config) { event_base_free(base); /* Verify test checkpoints */ - test_cond(state.checkpoint == 6); + test_cond(state.checkpoint == 8); } /* Command callback for test_monitor() */