Add sharded subscribe support

This commit is contained in:
Vyacheslav 2023-11-09 10:36:57 +03:00
parent 869f3d0ef1
commit 23ea5aa4a7
2 changed files with 42 additions and 20 deletions

61
async.c
View File

@ -105,7 +105,7 @@ static dictType callbackDict = {
static redisAsyncContext *redisAsyncInitialize(redisContext *c) { static redisAsyncContext *redisAsyncInitialize(redisContext *c) {
redisAsyncContext *ac; redisAsyncContext *ac;
dict *channels = NULL, *patterns = NULL; dict *channels = NULL, *patterns = NULL, *schannels = NULL;
channels = dictCreate(&callbackDict,NULL); channels = dictCreate(&callbackDict,NULL);
if (channels == NULL) if (channels == NULL)
@ -115,6 +115,10 @@ static redisAsyncContext *redisAsyncInitialize(redisContext *c) {
if (patterns == NULL) if (patterns == NULL)
goto oom; goto oom;
schannels = dictCreate(&callbackDict,NULL);
if (schannels == NULL)
goto oom;
ac = hi_realloc(c,sizeof(redisAsyncContext)); ac = hi_realloc(c,sizeof(redisAsyncContext));
if (ac == NULL) if (ac == NULL)
goto oom; goto oom;
@ -149,12 +153,14 @@ static redisAsyncContext *redisAsyncInitialize(redisContext *c) {
ac->sub.replies.tail = NULL; ac->sub.replies.tail = NULL;
ac->sub.channels = channels; ac->sub.channels = channels;
ac->sub.patterns = patterns; ac->sub.patterns = patterns;
ac->sub.schannels = schannels;
ac->sub.pending_unsubs = 0; ac->sub.pending_unsubs = 0;
return ac; return ac;
oom: oom:
if (channels) dictRelease(channels); if (channels) dictRelease(channels);
if (patterns) dictRelease(patterns); if (patterns) dictRelease(patterns);
if (schannels) dictRelease(schannels);
return NULL; return NULL;
} }
@ -388,6 +394,14 @@ static void __redisAsyncFree(redisAsyncContext *ac) {
dictRelease(ac->sub.patterns); dictRelease(ac->sub.patterns);
} }
if (ac->sub.schannels) {
dictInitIterator(&it,ac->sub.schannels);
while ((de = dictNext(&it)) != NULL)
__redisRunCallback(ac,dictGetEntryVal(de),NULL);
dictRelease(ac->sub.schannels);
}
/* Signal event lib to clean up */ /* Signal event lib to clean up */
_EL_CLEANUP(ac); _EL_CLEANUP(ac);
@ -467,12 +481,18 @@ void redisAsyncDisconnect(redisAsyncContext *ac) {
__redisAsyncDisconnect(ac); __redisAsyncDisconnect(ac);
} }
static int redisIsShardedVariant(const char* cstr) {
return !strncasecmp("sm", cstr, 2) || /* smessage */
!strncasecmp("ss", cstr, 2) || /* ssubscribe */
!strncasecmp("sun",cstr, 3); /* sunsubscribe */
}
static int __redisGetSubscribeCallback(redisAsyncContext *ac, redisReply *reply, redisCallback *dstcb) { static int __redisGetSubscribeCallback(redisAsyncContext *ac, redisReply *reply, redisCallback *dstcb) {
redisContext *c = &(ac->c); redisContext *c = &(ac->c);
dict *callbacks; dict *callbacks;
redisCallback *cb = NULL; redisCallback *cb = NULL;
dictEntry *de; dictEntry *de;
int pvariant; int pvariant, svariant;
char *stype; char *stype;
sds sname = NULL; sds sname = NULL;
@ -484,11 +504,11 @@ static int __redisGetSubscribeCallback(redisAsyncContext *ac, redisReply *reply,
assert(reply->element[0]->type == REDIS_REPLY_STRING); assert(reply->element[0]->type == REDIS_REPLY_STRING);
stype = reply->element[0]->str; stype = reply->element[0]->str;
pvariant = (tolower(stype[0]) == 'p') ? 1 : 0; pvariant = (tolower(stype[0]) == 'p') ? 1 : 0;
svariant = redisIsShardedVariant(stype);
if (pvariant) callbacks = pvariant ? ac->sub.patterns :
callbacks = ac->sub.patterns; svariant ? ac->sub.schannels :
else ac->sub.channels;
callbacks = ac->sub.channels;
/* Locate the right callback */ /* Locate the right callback */
if (reply->element[1]->type == REDIS_REPLY_STRING) { if (reply->element[1]->type == REDIS_REPLY_STRING) {
@ -502,11 +522,11 @@ static int __redisGetSubscribeCallback(redisAsyncContext *ac, redisReply *reply,
} }
/* If this is an subscribe reply decrease pending counter. */ /* If this is an subscribe reply decrease pending counter. */
if (strcasecmp(stype+pvariant,"subscribe") == 0) { if (strcasecmp(stype+pvariant+svariant,"subscribe") == 0) {
assert(cb != NULL); assert(cb != NULL);
cb->pending_subs -= 1; cb->pending_subs -= 1;
} else if (strcasecmp(stype+pvariant,"unsubscribe") == 0) { } else if (strcasecmp(stype+pvariant+svariant,"unsubscribe") == 0) {
if (cb == NULL) if (cb == NULL)
ac->sub.pending_unsubs -= 1; ac->sub.pending_unsubs -= 1;
else if (cb->pending_subs == 0) else if (cb->pending_subs == 0)
@ -521,6 +541,7 @@ static int __redisGetSubscribeCallback(redisAsyncContext *ac, redisReply *reply,
if (reply->element[2]->integer == 0 if (reply->element[2]->integer == 0
&& dictSize(ac->sub.channels) == 0 && dictSize(ac->sub.channels) == 0
&& dictSize(ac->sub.patterns) == 0 && dictSize(ac->sub.patterns) == 0
&& dictSize(ac->sub.schannels) == 0
&& ac->sub.pending_unsubs == 0) { && ac->sub.pending_unsubs == 0) {
c->flags &= ~REDIS_SUBSCRIBED; c->flags &= ~REDIS_SUBSCRIBED;
@ -558,7 +579,7 @@ static int redisIsSubscribeReply(redisReply *reply) {
} }
/* Get the string/len moving past 'p' if needed */ /* Get the string/len moving past 'p' if needed */
off = tolower(reply->element[0]->str[0]) == 'p'; off = tolower(reply->element[0]->str[0]) == 'p' || redisIsShardedVariant(reply->element[0]->str);
str = reply->element[0]->str + off; str = reply->element[0]->str + off;
len = reply->element[0]->len - off; len = reply->element[0]->len - off;
@ -838,7 +859,7 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void
dictIterator it; dictIterator it;
dictEntry *de; dictEntry *de;
redisCallback *existcb; redisCallback *existcb;
int pvariant, hasnext; int pvariant, hasnext, hasprefix, svariant;
const char *cstr, *astr; const char *cstr, *astr;
size_t clen, alen; size_t clen, alen;
const char *p; const char *p;
@ -859,8 +880,10 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void
assert(p != NULL); assert(p != NULL);
hasnext = (p[0] == '$'); hasnext = (p[0] == '$');
pvariant = (tolower(cstr[0]) == 'p') ? 1 : 0; pvariant = (tolower(cstr[0]) == 'p') ? 1 : 0;
cstr += pvariant; svariant = redisIsShardedVariant(cstr);
clen -= pvariant; hasprefix = svariant || pvariant;
cstr += hasprefix;
clen -= hasprefix;
if (hasnext && strncasecmp(cstr,"subscribe\r\n",11) == 0) { if (hasnext && strncasecmp(cstr,"subscribe\r\n",11) == 0) {
c->flags |= REDIS_SUBSCRIBED; c->flags |= REDIS_SUBSCRIBED;
@ -871,10 +894,9 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void
if (sname == NULL) if (sname == NULL)
goto oom; goto oom;
if (pvariant) cbdict = pvariant ? ac->sub.patterns :
cbdict = ac->sub.patterns; svariant ? ac->sub.schannels :
else ac->sub.channels;
cbdict = ac->sub.channels;
de = dictFind(cbdict,sname); de = dictFind(cbdict,sname);
@ -892,10 +914,9 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void
* subscribed to one or more channels or patterns. */ * subscribed to one or more channels or patterns. */
if (!(c->flags & REDIS_SUBSCRIBED)) return REDIS_ERR; if (!(c->flags & REDIS_SUBSCRIBED)) return REDIS_ERR;
if (pvariant) cbdict = pvariant ? ac->sub.patterns :
cbdict = ac->sub.patterns; svariant ? ac->sub.schannels :
else ac->sub.channels;
cbdict = ac->sub.channels;
if (hasnext) { if (hasnext) {
/* Send an unsubscribe with specific channels/patterns. /* Send an unsubscribe with specific channels/patterns.

View File

@ -108,6 +108,7 @@ typedef struct redisAsyncContext {
redisCallbackList replies; redisCallbackList replies;
struct dict *channels; struct dict *channels;
struct dict *patterns; struct dict *patterns;
struct dict *schannels;
int pending_unsubs; int pending_unsubs;
} sub; } sub;