diff --git a/async.c b/async.c index f82f567..465c366 100644 --- a/async.c +++ b/async.c @@ -105,7 +105,7 @@ static dictType callbackDict = { static redisAsyncContext *redisAsyncInitialize(redisContext *c) { redisAsyncContext *ac; - dict *channels = NULL, *patterns = NULL; + dict *channels = NULL, *patterns = NULL, *schannels = NULL; channels = dictCreate(&callbackDict,NULL); if (channels == NULL) @@ -115,6 +115,10 @@ static redisAsyncContext *redisAsyncInitialize(redisContext *c) { if (patterns == NULL) goto oom; + schannels = dictCreate(&callbackDict,NULL); + if (schannels == NULL) + goto oom; + ac = hi_realloc(c,sizeof(redisAsyncContext)); if (ac == NULL) goto oom; @@ -149,12 +153,14 @@ static redisAsyncContext *redisAsyncInitialize(redisContext *c) { ac->sub.replies.tail = NULL; ac->sub.channels = channels; ac->sub.patterns = patterns; + ac->sub.schannels = schannels; ac->sub.pending_unsubs = 0; return ac; oom: if (channels) dictRelease(channels); if (patterns) dictRelease(patterns); + if (schannels) dictRelease(schannels); return NULL; } @@ -388,6 +394,14 @@ static void __redisAsyncFree(redisAsyncContext *ac) { dictRelease(ac->sub.patterns); } + if (ac->sub.schannels) { + dictInitIterator(&it,ac->sub.schannels); + while ((de = dictNext(&it)) != NULL) + __redisRunCallback(ac,dictGetEntryVal(de),NULL); + + dictRelease(ac->sub.schannels); + } + /* Signal event lib to clean up */ _EL_CLEANUP(ac); @@ -467,12 +481,18 @@ void redisAsyncDisconnect(redisAsyncContext *ac) { __redisAsyncDisconnect(ac); } +static int redisIsShardedVariant(const char* cstr) { + return !strncasecmp("sm", cstr, 2) || /* smessage */ + !strncasecmp("ss", cstr, 2) || /* ssubscribe */ + !strncasecmp("sun",cstr, 3); /* sunsubscribe */ +} + static int __redisGetSubscribeCallback(redisAsyncContext *ac, redisReply *reply, redisCallback *dstcb) { redisContext *c = &(ac->c); dict *callbacks; redisCallback *cb = NULL; dictEntry *de; - int pvariant; + int pvariant, svariant; char *stype; sds sname = NULL; @@ -484,11 +504,11 @@ static int __redisGetSubscribeCallback(redisAsyncContext *ac, redisReply *reply, assert(reply->element[0]->type == REDIS_REPLY_STRING); stype = reply->element[0]->str; pvariant = (tolower(stype[0]) == 'p') ? 1 : 0; + svariant = redisIsShardedVariant(stype); - if (pvariant) - callbacks = ac->sub.patterns; - else - callbacks = ac->sub.channels; + callbacks = pvariant ? ac->sub.patterns : + svariant ? ac->sub.schannels : + ac->sub.channels; /* Locate the right callback */ if (reply->element[1]->type == REDIS_REPLY_STRING) { @@ -502,11 +522,11 @@ static int __redisGetSubscribeCallback(redisAsyncContext *ac, redisReply *reply, } /* If this is an subscribe reply decrease pending counter. */ - if (strcasecmp(stype+pvariant,"subscribe") == 0) { + if (strcasecmp(stype+pvariant+svariant,"subscribe") == 0) { assert(cb != NULL); cb->pending_subs -= 1; - } else if (strcasecmp(stype+pvariant,"unsubscribe") == 0) { + } else if (strcasecmp(stype+pvariant+svariant,"unsubscribe") == 0) { if (cb == NULL) ac->sub.pending_unsubs -= 1; else if (cb->pending_subs == 0) @@ -521,6 +541,7 @@ static int __redisGetSubscribeCallback(redisAsyncContext *ac, redisReply *reply, if (reply->element[2]->integer == 0 && dictSize(ac->sub.channels) == 0 && dictSize(ac->sub.patterns) == 0 + && dictSize(ac->sub.schannels) == 0 && ac->sub.pending_unsubs == 0) { c->flags &= ~REDIS_SUBSCRIBED; @@ -558,7 +579,7 @@ static int redisIsSubscribeReply(redisReply *reply) { } /* Get the string/len moving past 'p' if needed */ - off = tolower(reply->element[0]->str[0]) == 'p'; + off = tolower(reply->element[0]->str[0]) == 'p' || redisIsShardedVariant(reply->element[0]->str); str = reply->element[0]->str + off; len = reply->element[0]->len - off; @@ -838,7 +859,7 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void dictIterator it; dictEntry *de; redisCallback *existcb; - int pvariant, hasnext; + int pvariant, hasnext, hasprefix, svariant; const char *cstr, *astr; size_t clen, alen; const char *p; @@ -859,8 +880,10 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void assert(p != NULL); hasnext = (p[0] == '$'); pvariant = (tolower(cstr[0]) == 'p') ? 1 : 0; - cstr += pvariant; - clen -= pvariant; + svariant = redisIsShardedVariant(cstr); + hasprefix = svariant || pvariant; + cstr += hasprefix; + clen -= hasprefix; if (hasnext && strncasecmp(cstr,"subscribe\r\n",11) == 0) { c->flags |= REDIS_SUBSCRIBED; @@ -871,10 +894,9 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void if (sname == NULL) goto oom; - if (pvariant) - cbdict = ac->sub.patterns; - else - cbdict = ac->sub.channels; + cbdict = pvariant ? ac->sub.patterns : + svariant ? ac->sub.schannels : + ac->sub.channels; de = dictFind(cbdict,sname); @@ -892,10 +914,9 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void * subscribed to one or more channels or patterns. */ if (!(c->flags & REDIS_SUBSCRIBED)) return REDIS_ERR; - if (pvariant) - cbdict = ac->sub.patterns; - else - cbdict = ac->sub.channels; + cbdict = pvariant ? ac->sub.patterns : + svariant ? ac->sub.schannels : + ac->sub.channels; if (hasnext) { /* Send an unsubscribe with specific channels/patterns. diff --git a/async.h b/async.h index 4f94660..252537f 100644 --- a/async.h +++ b/async.h @@ -108,6 +108,7 @@ typedef struct redisAsyncContext { redisCallbackList replies; struct dict *channels; struct dict *patterns; + struct dict *schannels; int pending_unsubs; } sub;