This commit is contained in:
Vyacheslav Vanin 2025-02-12 14:18:41 +00:00 committed by GitHub
commit 9ee4899f8e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 42 additions and 20 deletions

61
async.c
View File

@ -105,7 +105,7 @@ static dictType callbackDict = {
static redisAsyncContext *redisAsyncInitialize(redisContext *c) {
redisAsyncContext *ac;
dict *channels = NULL, *patterns = NULL;
dict *channels = NULL, *patterns = NULL, *schannels = NULL;
channels = dictCreate(&callbackDict,NULL);
if (channels == NULL)
@ -115,6 +115,10 @@ static redisAsyncContext *redisAsyncInitialize(redisContext *c) {
if (patterns == NULL)
goto oom;
schannels = dictCreate(&callbackDict,NULL);
if (schannels == NULL)
goto oom;
ac = hi_realloc(c,sizeof(redisAsyncContext));
if (ac == NULL)
goto oom;
@ -149,12 +153,14 @@ static redisAsyncContext *redisAsyncInitialize(redisContext *c) {
ac->sub.replies.tail = NULL;
ac->sub.channels = channels;
ac->sub.patterns = patterns;
ac->sub.schannels = schannels;
ac->sub.pending_unsubs = 0;
return ac;
oom:
if (channels) dictRelease(channels);
if (patterns) dictRelease(patterns);
if (schannels) dictRelease(schannels);
return NULL;
}
@ -388,6 +394,14 @@ static void __redisAsyncFree(redisAsyncContext *ac) {
dictRelease(ac->sub.patterns);
}
if (ac->sub.schannels) {
dictInitIterator(&it,ac->sub.schannels);
while ((de = dictNext(&it)) != NULL)
__redisRunCallback(ac,dictGetEntryVal(de),NULL);
dictRelease(ac->sub.schannels);
}
/* Signal event lib to clean up */
_EL_CLEANUP(ac);
@ -467,12 +481,18 @@ void redisAsyncDisconnect(redisAsyncContext *ac) {
__redisAsyncDisconnect(ac);
}
static int redisIsShardedVariant(const char* cstr) {
return !strncasecmp("sm", cstr, 2) || /* smessage */
!strncasecmp("ss", cstr, 2) || /* ssubscribe */
!strncasecmp("sun",cstr, 3); /* sunsubscribe */
}
static int __redisGetSubscribeCallback(redisAsyncContext *ac, redisReply *reply, redisCallback *dstcb) {
redisContext *c = &(ac->c);
dict *callbacks;
redisCallback *cb = NULL;
dictEntry *de;
int pvariant;
int pvariant, svariant;
char *stype;
sds sname = NULL;
@ -484,11 +504,11 @@ static int __redisGetSubscribeCallback(redisAsyncContext *ac, redisReply *reply,
assert(reply->element[0]->type == REDIS_REPLY_STRING);
stype = reply->element[0]->str;
pvariant = (tolower(stype[0]) == 'p') ? 1 : 0;
svariant = redisIsShardedVariant(stype);
if (pvariant)
callbacks = ac->sub.patterns;
else
callbacks = ac->sub.channels;
callbacks = pvariant ? ac->sub.patterns :
svariant ? ac->sub.schannels :
ac->sub.channels;
/* Locate the right callback */
if (reply->element[1]->type == REDIS_REPLY_STRING) {
@ -502,11 +522,11 @@ static int __redisGetSubscribeCallback(redisAsyncContext *ac, redisReply *reply,
}
/* If this is an subscribe reply decrease pending counter. */
if (strcasecmp(stype+pvariant,"subscribe") == 0) {
if (strcasecmp(stype+pvariant+svariant,"subscribe") == 0) {
assert(cb != NULL);
cb->pending_subs -= 1;
} else if (strcasecmp(stype+pvariant,"unsubscribe") == 0) {
} else if (strcasecmp(stype+pvariant+svariant,"unsubscribe") == 0) {
if (cb == NULL)
ac->sub.pending_unsubs -= 1;
else if (cb->pending_subs == 0)
@ -521,6 +541,7 @@ static int __redisGetSubscribeCallback(redisAsyncContext *ac, redisReply *reply,
if (reply->element[2]->integer == 0
&& dictSize(ac->sub.channels) == 0
&& dictSize(ac->sub.patterns) == 0
&& dictSize(ac->sub.schannels) == 0
&& ac->sub.pending_unsubs == 0) {
c->flags &= ~REDIS_SUBSCRIBED;
@ -558,7 +579,7 @@ static int redisIsSubscribeReply(redisReply *reply) {
}
/* Get the string/len moving past 'p' if needed */
off = tolower(reply->element[0]->str[0]) == 'p';
off = tolower(reply->element[0]->str[0]) == 'p' || redisIsShardedVariant(reply->element[0]->str);
str = reply->element[0]->str + off;
len = reply->element[0]->len - off;
@ -838,7 +859,7 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void
dictIterator it;
dictEntry *de;
redisCallback *existcb;
int pvariant, hasnext;
int pvariant, hasnext, hasprefix, svariant;
const char *cstr, *astr;
size_t clen, alen;
const char *p;
@ -859,8 +880,10 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void
assert(p != NULL);
hasnext = (p[0] == '$');
pvariant = (tolower(cstr[0]) == 'p') ? 1 : 0;
cstr += pvariant;
clen -= pvariant;
svariant = redisIsShardedVariant(cstr);
hasprefix = svariant || pvariant;
cstr += hasprefix;
clen -= hasprefix;
if (hasnext && strncasecmp(cstr,"subscribe\r\n",11) == 0) {
c->flags |= REDIS_SUBSCRIBED;
@ -871,10 +894,9 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void
if (sname == NULL)
goto oom;
if (pvariant)
cbdict = ac->sub.patterns;
else
cbdict = ac->sub.channels;
cbdict = pvariant ? ac->sub.patterns :
svariant ? ac->sub.schannels :
ac->sub.channels;
de = dictFind(cbdict,sname);
@ -892,10 +914,9 @@ static int __redisAsyncCommand(redisAsyncContext *ac, redisCallbackFn *fn, void
* subscribed to one or more channels or patterns. */
if (!(c->flags & REDIS_SUBSCRIBED)) return REDIS_ERR;
if (pvariant)
cbdict = ac->sub.patterns;
else
cbdict = ac->sub.channels;
cbdict = pvariant ? ac->sub.patterns :
svariant ? ac->sub.schannels :
ac->sub.channels;
if (hasnext) {
/* Send an unsubscribe with specific channels/patterns.

View File

@ -108,6 +108,7 @@ typedef struct redisAsyncContext {
redisCallbackList replies;
struct dict *channels;
struct dict *patterns;
struct dict *schannels;
int pending_unsubs;
} sub;