数据结构-集合

集合作为redis中的一种常用的数据结构,根据集合内元素的特点,底层采用数组和哈希表两种方式进行实现。集合初始化的时候,会先判断第一个元素是否能够转化成整型,然后决定使用数组还是哈希表来实现。哈希表是通用的实现方式,数组只适用于所有元素都是整数的情况。在一定的情况下,会发生转换,转换是单向的,只能从数组转换到哈希表。触发的条件有两个,在数组实现的基础上,一是往集合内加入非整数元素,二是集合的大小超过了限制(默认是512,可以通过参数set-max-intset-entries修改)。这是因为数组实现的集合,在每次新增、删除操作的时候,都会重新去申请内存,当涉及到的内存比较大的时候,效率会有所降低。

概述

在集合的实现里,重点是在集合操作上,包括求交集、并集、差集、随机删除以及随机获取。根据集合的数据组成不同,不同的算法带来的效率是不一样的。

交集

求交集的时候,会先对所有集合按照其元素数量大小进行升序排序,以第一个集合(元素数量最小的)为基准,依次遍历第一个集合内的每一个元素,查看在其他集合中是否都存在。

并集

求并集的时候,先建立一个新的空集合,依次把每个集合的每个元素放入新集合

差集

求差集的时候,是有一个基准,即求第一个集合和其他集合的差集。根据集合的元素数量组成,有两种算法。

  1. 依次遍历基准集合的每一个元素,判断是否在其他集合中都不存在。算法的时间复杂度为O(N*M),N为基准集合元素的数量,M为所有集合的数量
  2. 复制一个基准集合,然后依次遍历其余集合的每一个元素,从复制的基准集合中删除该元素。算法的时间复杂度为O(N),N为所有集合的元素数量和

第一种下,要求从第二个集合开始按照其元素数量大小进行降序排序,因为数量大的集合出现相同元素的概率更大,避免后续的对比

随机删除

随机删除操作,根据集合元素的数量和要删除数量大小的关系,有三种情况

  1. 要删除的数量大于等于集合的元素数量:直接删除整个集合
  2. 要删除的数量小于集合的元素数量,且占集合元素数量比例小:随机选择要删除的元素
  3. 要删除的数量小于集合的元素数量,且占集合元素数量比例大:随机选择要留下来的元素。之所以反向操作,是因为哈希表在使用率很低的时候,随机获取数据的成本很高

随机获取

随机获取的操作,根据数据允不允许多次出现以及集合元素的数量和要获取数量大小的关系,分为四种情况

  1. 允许重复出现:每次都随机获取一个,直到数量满足
  2. 不允许重复出现,且要获取的数量大于等于集合的元素数量:直接返回整个集合
  3. 不允许重复出现,要获取的数量小于集合的元素数量,且占集合元素数量比例小:每次都随机获取一个,如果重复则忽略,直到数量满足
  4. 不允许重复出现,要获取的数量小于集合的元素数量,且占集合元素数量比例大:先把整个集合拷贝到一个辅助哈希表,然后从辅助哈希表中随机删除元素,直到剩余数量等于要获取的数量

源码分析

集合实现(t_set.c)

setTypeCreate

根据数据能否转换成整数,返回空集合(数组实现/哈希实现)

1
2
3
4
5
6
/* Build an empty set object whose encoding is chosen from `value`:
 * createIntsetObject() when the string parses as a long long,
 * createSetObject() otherwise. The value itself is not inserted here. */
robj *setTypeCreate(sds value) {
    int is_integer = (isSdsRepresentableAsLongLong(value, NULL) == C_OK);
    return is_integer ? createIntsetObject() : createSetObject();
}
setTypeAdd

新增元素

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
/* Add `value` to the set `subject`.
 *
 * Returns 1 when the element was inserted and 0 when it was already present.
 * An intset-encoded set is converted to the hash table encoding when the new
 * value is not an integer, or when the intset grows past
 * server.set_max_intset_entries. The caller keeps ownership of `value`
 * (a duplicate is stored where a string key is needed). */
int setTypeAdd(robj *subject, sds value) {
    if (subject->encoding == OBJ_ENCODING_HT) {
        dict *table = subject->ptr;
        dictEntry *entry = dictAddRaw(table, value, NULL);
        if (entry == NULL) return 0;

        /* The entry was created with the caller's string as key; replace
         * it with a private copy. */
        dictSetKey(table, entry, sdsdup(value));
        dictSetVal(table, entry, NULL);
        return 1;
    }

    if (subject->encoding != OBJ_ENCODING_INTSET) {
        serverPanic("Unknown set encoding");
        return 0;
    }

    long long number;
    if (isSdsRepresentableAsLongLong(value, &number) != C_OK) {
        /* Not an integer: switch to the hash table encoding first. */
        setTypeConvert(subject, OBJ_ENCODING_HT);

        /* The set held only integers so far, so this insert must succeed. */
        serverAssert(dictAdd(subject->ptr, sdsdup(value), NULL) == DICT_OK);
        return 1;
    }

    uint8_t inserted = 0;
    subject->ptr = intsetAdd(subject->ptr, number, &inserted);
    if (!inserted) return 0;

    /* Too many entries for the intset encoding: convert. */
    if (intsetLen(subject->ptr) > server.set_max_intset_entries)
        setTypeConvert(subject, OBJ_ENCODING_HT);
    return 1;
}
setTypeRemove

删除元素

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
/* Remove `value` from the set. Returns 1 if the element was present and has
 * been removed, 0 otherwise. */
int setTypeRemove(robj *setobj, sds value) {
    if (setobj->encoding == OBJ_ENCODING_HT) {
        if (dictDelete(setobj->ptr, value) != DICT_OK) return 0;

        /* Give the table a chance to shrink after a successful delete. */
        if (htNeedsResize(setobj->ptr)) dictResize(setobj->ptr);
        return 1;
    }

    if (setobj->encoding == OBJ_ENCODING_INTSET) {
        long long number;
        int removed;

        /* A non-integer string cannot be stored in an intset. */
        if (isSdsRepresentableAsLongLong(value, &number) != C_OK) return 0;

        setobj->ptr = intsetRemove(setobj->ptr, number, &removed);
        return removed ? 1 : 0;
    }

    serverPanic("Unknown set encoding");
    return 0;
}
setTypeIsMember

查找元素

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
/* Membership test: non-zero when `value` is an element of `subject`,
 * 0 otherwise. */
int setTypeIsMember(robj *subject, sds value) {
    if (subject->encoding == OBJ_ENCODING_HT) {
        dictEntry *hit = dictFind((dict*)subject->ptr, value);
        return hit != NULL;
    }

    if (subject->encoding == OBJ_ENCODING_INTSET) {
        long long number;

        /* Strings that are not integers can never be in an intset. */
        if (isSdsRepresentableAsLongLong(value, &number) != C_OK) return 0;
        return intsetFind((intset*)subject->ptr, number);
    }

    serverPanic("Unknown set encoding");
    return 0;
}
setTypeInitIterator

获取迭代器

1
2
3
4
5
6
7
8
9
10
11
12
13
/* Allocate an iterator over `subject`. The iterator records the encoding
 * seen at creation time; release it with setTypeReleaseIterator(). */
setTypeIterator *setTypeInitIterator(robj *subject) {
    setTypeIterator *iter = zmalloc(sizeof(*iter));

    iter->subject = subject;
    iter->encoding = subject->encoding;

    switch (iter->encoding) {
    case OBJ_ENCODING_HT:
        iter->di = dictGetIterator(subject->ptr);
        break;
    case OBJ_ENCODING_INTSET:
        /* Intsets are walked by position, starting at index 0. */
        iter->ii = 0;
        break;
    default:
        serverPanic("Unknown set encoding");
    }
    return iter;
}
setTypeReleaseIterator

释放迭代器

1
2
3
4
5
/* Free an iterator created by setTypeInitIterator(). Only the hash table
 * flavour owns a nested dict iterator that needs releasing. */
void setTypeReleaseIterator(setTypeIterator *si) {
    int wraps_dict_iterator = (si->encoding == OBJ_ENCODING_HT);

    if (wraps_dict_iterator) {
        dictReleaseIterator(si->di);
    }
    zfree(si);
}
setTypeNext

通过迭代器获取下一个元素的内容,返回实现类型(哈希表/数组),根据返回类型数据存放在sdsele或者llele

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
/* Advance the iterator.
 *
 * Returns -1 when the iteration is over; otherwise returns the iterator's
 * encoding and stores the element in *sdsele (hash table encoding) or
 * *llele (intset encoding). The output that is not used for the returned
 * encoding is set to a dummy value defensively. The string stored in
 * *sdsele is the key held by the dict entry, not a copy. */
int setTypeNext(setTypeIterator *si, sds *sdsele, int64_t *llele) {
    if (si->encoding == OBJ_ENCODING_HT) {
        dictEntry *entry = dictNext(si->di);
        if (!entry) return -1;
        *sdsele = dictGetKey(entry);
        *llele = -123456789; /* Not needed. Defensive. */
        return si->encoding;
    }

    if (si->encoding == OBJ_ENCODING_INTSET) {
        /* The position is advanced whether or not the read succeeds. */
        int found = intsetGet(si->subject->ptr, si->ii, llele);
        si->ii++;
        if (!found) return -1;
        *sdsele = NULL; /* Not needed. Defensive. */
        return si->encoding;
    }

    serverPanic("Wrong set encoding in setTypeNext");
    return si->encoding;
}
setTypeNextObject

通过迭代器获取下一个元素,返回字符串

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
/* Like setTypeNext(), but always hands back a newly created sds string
 * (built with sdsfromlonglong() or sdsdup()) regardless of the encoding.
 * Returns NULL when the iteration is over. */
sds setTypeNextObject(setTypeIterator *si) {
    sds str;
    int64_t num;
    int enc = setTypeNext(si, &str, &num);

    if (enc == -1) return NULL;
    if (enc == OBJ_ENCODING_INTSET) return sdsfromlonglong(num);
    if (enc == OBJ_ENCODING_HT) return sdsdup(str);

    serverPanic("Unsupported encoding");
    return NULL;
}
setTypeRandomElement

从集合中随机获取一个元素的内容(不需要迭代器),返回实现类型(哈希表/数组),根据返回类型数据存放在sdsele或者llele

1
2
3
4
5
6
7
8
9
10
11
12
13
/* Pick a random element of the set without removing it.
 *
 * Returns the set encoding; the element is stored in *sdsele (hash table
 * encoding, pointing at the key held by the dict entry) or in *llele
 * (intset encoding). The unused output receives a dummy value. */
int setTypeRandomElement(robj *setobj, sds *sdsele, int64_t *llele) {
    switch (setobj->encoding) {
    case OBJ_ENCODING_HT: {
        dictEntry *picked = dictGetRandomKey(setobj->ptr);
        *sdsele = dictGetKey(picked);
        *llele = -123456789; /* Not needed. Defensive. */
        break;
    }
    case OBJ_ENCODING_INTSET:
        *llele = intsetRandom(setobj->ptr);
        *sdsele = NULL; /* Not needed. Defensive. */
        break;
    default:
        serverPanic("Unknown set encoding");
    }
    return setobj->encoding;
}
setTypeSize

获取集合大小

1
2
3
4
5
6
7
8
9
/* Number of elements stored in the set, for either encoding. */
unsigned long setTypeSize(const robj *subject) {
    switch (subject->encoding) {
    case OBJ_ENCODING_HT:
        return dictSize((const dict*)subject->ptr);
    case OBJ_ENCODING_INTSET:
        return intsetLen((const intset*)subject->ptr);
    default:
        serverPanic("Unknown set encoding");
    }
}
setTypeConvert

转换集合底层实现方式,只能从数组到哈希表

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
/* Convert an intset-encoded set to encoding `enc`. Only the conversion to
 * OBJ_ENCODING_HT is supported; the input must be an intset-encoded set. */
void setTypeConvert(robj *setobj, int enc) {
    serverAssertWithInfo(NULL,setobj,setobj->type == OBJ_SET &&
                             setobj->encoding == OBJ_ENCODING_INTSET);

    if (enc != OBJ_ENCODING_HT) {
        serverPanic("Unsupported set conversion");
        return;
    }

    dict *table = dictCreate(&setDictType, NULL);

    /* Size the table up front so it does not rehash while being filled. */
    dictExpand(table, intsetLen(setobj->ptr));

    /* Every integer becomes a string key of the new table. */
    setTypeIterator *iter = setTypeInitIterator(setobj);
    sds member;
    int64_t number;
    while (setTypeNext(iter, &member, &number) != -1) {
        member = sdsfromlonglong(number);
        serverAssert(dictAdd(table, member, NULL) == DICT_OK);
    }
    setTypeReleaseIterator(iter);

    /* Swap the payload: free the intset, install the dict. */
    setobj->encoding = OBJ_ENCODING_HT;
    zfree(setobj->ptr);
    setobj->ptr = table;
}
saddCommand

响应sadd命令,添加数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
/* SADD key member [member ...]
 *
 * Creates the set when the key is missing, replies with a wrong-type error
 * when the key holds something else, and otherwise replies with the number
 * of members that were actually added. */
void saddCommand(client *c) {
    robj *key = c->argv[1];
    robj *set = lookupKeyWrite(c->db, key);

    if (set != NULL && set->type != OBJ_SET) {
        addReply(c, shared.wrongtypeerr);
        return;
    }
    if (set == NULL) {
        /* The first member decides the initial encoding. */
        set = setTypeCreate(c->argv[2]->ptr);
        dbAdd(c->db, key, set);
    }

    int added = 0;
    for (int j = 2; j < c->argc; j++) {
        if (setTypeAdd(set, c->argv[j]->ptr)) added++;
    }

    /* Only signal/notify when something really changed. */
    if (added) {
        signalModifiedKey(c->db, key);
        notifyKeyspaceEvent(NOTIFY_SET, "sadd", key, c->db->id);
    }

    server.dirty += added;
    addReplyLongLong(c, added);
}
sremCommand

响应srem命令,删除元素

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
/* SREM key member [member ...]
 *
 * Replies with the number of members removed. The key itself is deleted as
 * soon as the set becomes empty, and the remaining arguments are skipped. */
void sremCommand(client *c) {
    robj *key = c->argv[1];
    robj *set = lookupKeyWriteOrReply(c, key, shared.czero);
    if (set == NULL) return;
    if (checkType(c, set, OBJ_SET)) return;

    int deleted = 0, keyremoved = 0;
    for (int j = 2; j < c->argc && !keyremoved; j++) {
        if (!setTypeRemove(set, c->argv[j]->ptr)) continue;

        deleted++;
        if (setTypeSize(set) == 0) {
            dbDelete(c->db, key);
            keyremoved = 1;
        }
    }

    if (deleted) {
        signalModifiedKey(c->db, key);
        notifyKeyspaceEvent(NOTIFY_SET, "srem", key, c->db->id);
        if (keyremoved) {
            notifyKeyspaceEvent(NOTIFY_GENERIC, "del", key, c->db->id);
        }
        server.dirty += deleted;
    }
    addReplyLongLong(c, deleted);
}
smoveCommand

响应smove命令,把一个元素从一个集合移动到另一个集合

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
// SMOVE source destination member
//
// Moves one element from the source set to the destination set. Replies
// with shared.cone when the element was removed from the source, and with
// shared.czero when the source key is missing or does not contain the
// element.
void smoveCommand(client *c) {
robj *srcset, *dstset, *ele;
srcset = lookupKeyWrite(c->db,c->argv[1]);
dstset = lookupKeyWrite(c->db,c->argv[2]);
ele = c->argv[3];

// Missing source key: nothing to move.
if (srcset == NULL) {
addReply(c,shared.czero);
return;
}

// Both keys must hold sets (the destination only if it exists).
if (checkType(c,srcset,OBJ_SET) ||
(dstset && checkType(c,dstset,OBJ_SET))) return;

// Source and destination are the same object: no change is made, just
// report whether the element is a member.
if (srcset == dstset) {
addReply(c,setTypeIsMember(srcset,ele->ptr) ?
shared.cone : shared.czero);
return;
}

// Remove from the source; if the element was not there, reply 0.
if (!setTypeRemove(srcset,ele->ptr)) {
addReply(c,shared.czero);
return;
}
notifyKeyspaceEvent(NOTIFY_SET,"srem",c->argv[1],c->db->id);

// Delete the source key if the removal left the set empty.
if (setTypeSize(srcset) == 0) {
dbDelete(c->db,c->argv[1]);
notifyKeyspaceEvent(NOTIFY_GENERIC,"del",c->argv[1],c->db->id);
}

// Create the destination set when it does not exist yet.
if (!dstset) {
dstset = setTypeCreate(ele->ptr);
dbAdd(c->db,c->argv[2],dstset);
}

signalModifiedKey(c->db,c->argv[1]);
signalModifiedKey(c->db,c->argv[2]);
server.dirty++;

// Insert into the destination. A second change is counted only when the
// element was not already present there.
if (setTypeAdd(dstset,ele->ptr)) {
server.dirty++;
notifyKeyspaceEvent(NOTIFY_SET,"sadd",c->argv[2],c->db->id);
}
addReply(c,shared.cone);
}
sismemberCommand

响应sismember命令,确定元素是否在集合中

1
2
3
4
5
6
7
8
9
10
11
12
13
/* SISMEMBER key member
 *
 * Replies with shared.cone when the member is in the set, shared.czero
 * otherwise (shared.czero is also what the lookup helper is given for a
 * missing key). */
void sismemberCommand(client *c) {
    robj *set = lookupKeyReadOrReply(c, c->argv[1], shared.czero);
    if (set == NULL || checkType(c, set, OBJ_SET)) return;

    int present = setTypeIsMember(set, c->argv[2]->ptr);
    addReply(c, present ? shared.cone : shared.czero);
}
scardCommand

响应scard命令,获取集合大小

1
2
3
4
5
6
7
8
/* SCARD key
 *
 * Replies with the cardinality of the set stored at key. */
void scardCommand(client *c) {
    robj *set = lookupKeyReadOrReply(c, c->argv[1], shared.czero);
    if (set == NULL) return;
    if (checkType(c, set, OBJ_SET)) return;

    addReplyLongLong(c, setTypeSize(set));
}
spopWithCountCommand

随机删除一个或多个元素

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
/* Threshold between the two partial-pop strategies below: elements are
 * popped one by one while remaining*SPOP_MOVE_STRATEGY_MUL > count;
 * otherwise the survivors are moved into a new set instead. */
#define SPOP_MOVE_STRATEGY_MUL 5

/* SPOP key count
 *
 * Removes and returns up to `count` random members of the set.
 *
 * Three cases are handled:
 *   1. count >= set size: the whole set is returned and the key deleted.
 *   2. count is small relative to the set: pop random elements one by one.
 *   3. count is close to the set size: pick the elements that SURVIVE,
 *      move them into a new set, and return what is left of the old one.
 *
 * In cases 2 and 3 every removed element is propagated as an individual
 * SREM and the SPOP itself is not propagated; in case 1 the command is
 * rewritten as DEL. */
void spopWithCountCommand(client *c) {
    long l;
    unsigned long count, size;
    robj *set;

    /* Parse the requested count; negative values are rejected. */
    if (getLongFromObjectOrReply(c,c->argv[2],&l,NULL) != C_OK) return;
    if (l >= 0) {
        count = (unsigned long) l;
    } else {
        addReply(c,shared.outofrangeerr);
        return;
    }

    /* SPOP modifies the key, so it is looked up for writing, consistently
     * with the single-element path in spopCommand(). */
    if ((set = lookupKeyWriteOrReply(c,c->argv[1],shared.emptymultibulk))
        == NULL || checkType(c,set,OBJ_SET)) return;

    /* Nothing requested: reply with an empty list, leave the set alone. */
    if (count == 0) {
        addReply(c,shared.emptymultibulk);
        return;
    }

    size = setTypeSize(set);

    notifyKeyspaceEvent(NOTIFY_SET,"spop",c->argv[1],c->db->id);

    /* At most `size` elements can actually be removed, so do not count
     * more changes than that when count exceeds the set size. */
    server.dirty += (count >= size) ? size : count;

    /* CASE 1: the request covers the whole set. */
    if (count >= size) {
        /* Reply with the entire set (a union over this single key). */
        sunionDiffGenericCommand(c,c->argv+1,1,NULL,SET_OP_UNION);

        /* The set is now logically empty: drop the key. */
        dbDelete(c->db,c->argv[1]);
        notifyKeyspaceEvent(NOTIFY_GENERIC,"del",c->argv[1],c->db->id);

        /* Rewrite the command as a DEL of the key. */
        rewriteClientCommandVector(c,2,shared.del,c->argv[1]);
        signalModifiedKey(c->db,c->argv[1]);
        server.dirty++;
        return;
    }

    /* CASES 2 and 3: fewer elements than the set holds. Each removed
     * element is propagated as "SREM key element". */
    robj *propargv[3];
    propargv[0] = createStringObject("SREM",4);
    propargv[1] = c->argv[1];
    addReplyMultiBulkLen(c,count);

    sds sdsele;
    robj *objele;
    int encoding;
    int64_t llele;

    /* Elements that stay in the set after the pop. */
    unsigned long remaining = size-count;

    if (remaining*SPOP_MOVE_STRATEGY_MUL > count) {
        /* CASE 2: few elements to remove; extract them one at a time. */
        while(count--) {
            encoding = setTypeRandomElement(set,&sdsele,&llele);
            if (encoding == OBJ_ENCODING_INTSET) {
                addReplyBulkLongLong(c,llele);
                objele = createStringObjectFromLongLong(llele);
                set->ptr = intsetRemove(set->ptr,llele,NULL);
            } else {
                /* sdsele points into the set: reply and copy it before
                 * removing the element. */
                addReplyBulkCBuffer(c,sdsele,sdslen(sdsele));
                objele = createStringObject(sdsele,sdslen(sdsele));
                setTypeRemove(set,sdsele);
            }

            /* Propagate the removal as SREM to AOF and replicas. */
            propargv[2] = objele;
            alsoPropagate(server.sremCommand,c->db->id,propargv,3,
                PROPAGATE_AOF|PROPAGATE_REPL);
            decrRefCount(objele);
        }
    } else {
        /* CASE 3: most of the set is removed. Random picks from a nearly
         * empty hash table get expensive, so select the survivors instead
         * and move them into a new set. */
        robj *newset = NULL;

        while(remaining--) {
            encoding = setTypeRandomElement(set,&sdsele,&llele);
            if (encoding == OBJ_ENCODING_INTSET) {
                sdsele = sdsfromlonglong(llele);
            } else {
                sdsele = sdsdup(sdsele);
            }
            if (!newset) newset = setTypeCreate(sdsele);
            setTypeAdd(newset,sdsele);
            setTypeRemove(set,sdsele);
            sdsfree(sdsele);
        }

        /* What is left in the old set is exactly what must be popped:
         * send it to the client and propagate each element as SREM. */
        setTypeIterator *si;
        si = setTypeInitIterator(set);
        while((encoding = setTypeNext(si,&sdsele,&llele)) != -1) {
            if (encoding == OBJ_ENCODING_INTSET) {
                addReplyBulkLongLong(c,llele);
                objele = createStringObjectFromLongLong(llele);
            } else {
                addReplyBulkCBuffer(c,sdsele,sdslen(sdsele));
                objele = createStringObject(sdsele,sdslen(sdsele));
            }

            propargv[2] = objele;
            alsoPropagate(server.sremCommand,c->db->id,propargv,3,
                PROPAGATE_AOF|PROPAGATE_REPL);
            decrRefCount(objele);
        }
        setTypeReleaseIterator(si);

        /* Install the survivors as the new value of the key. */
        dbOverwrite(c->db,c->argv[1],newset);
    }

    /* The SPOP itself is not propagated: the SREMs above already carry
     * the change. */
    decrRefCount(propargv[0]);
    preventCommandPropagation(c);
    signalModifiedKey(c->db,c->argv[1]);
    server.dirty++;
}
spopCommand

响应spop命令,随机返回并删除一个或多个元素

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
// SPOP key [count]
//
// With a count argument the work is delegated to spopWithCountCommand().
// Without it, one random element is removed from the set and returned.
// More than one extra argument is a syntax error.
void spopCommand(client *c) {
robj *set, *ele, *aux;
sds sdsele;
int64_t llele;
int encoding;

// Dispatch on the number of arguments.
if (c->argc == 3) {
// "SPOP key count": handled by spopWithCountCommand().
spopWithCountCommand(c);
return;
} else if (c->argc > 3) {
addReply(c,shared.syntaxerr);
return;
}

// The key must exist and hold a set.
if ((set = lookupKeyWriteOrReply(c,c->argv[1],shared.nullbulk)) == NULL ||
checkType(c,set,OBJ_SET)) return;

// No count given: pick a single random element.
encoding = setTypeRandomElement(set,&sdsele,&llele);

// Copy the element into a string object, then remove it from the set.
// For the hash table encoding the copy is made first because sdsele
// points at the key stored inside the set.
if (encoding == OBJ_ENCODING_INTSET) {
ele = createStringObjectFromLongLong(llele);
set->ptr = intsetRemove(set->ptr,llele,NULL);
} else {
ele = createStringObject(sdsele,sdslen(sdsele));
setTypeRemove(set,ele->ptr);
}

// Keyspace notification for the pop.
notifyKeyspaceEvent(NOTIFY_SET,"spop",c->argv[1],c->db->id);

// Rewrite the client's command vector as "SREM key ele" (per the original
// comment this is what reaches replicas and the AOF).
aux = createStringObject("SREM",4);
rewriteClientCommandVector(c,3,aux,c->argv[1],ele);
decrRefCount(aux);

// Reply with the popped element.
addReplyBulk(c,ele);
decrRefCount(ele);

// Delete the key when the set became empty.
if (setTypeSize(set) == 0) {
dbDelete(c->db,c->argv[1]);
notifyKeyspaceEvent(NOTIFY_GENERIC,"del",c->argv[1],c->db->id);
}

signalModifiedKey(c->db,c->argv[1]);
server.dirty++;
}
srandmemberWithCountCommand

随机获取一个或多个元素,当参数为负数时,表示返回的结果内,同一元素可以重复出现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
// Strategy threshold for the unique-elements case: when
// count*SRANDMEMBER_SUB_STRATEGY_MUL > size the whole set is copied and
// trimmed down; otherwise random elements are collected one by one.
#define SRANDMEMBER_SUB_STRATEGY_MUL 3

// SRANDMEMBER key count
//
// Returns random elements of the set without modifying it. A non-negative
// count asks for distinct elements (at most the whole set); a negative
// count asks for |count| elements where the same element may repeat.
void srandmemberWithCountCommand(client *c) {
long l;
unsigned long count, size;

// Whether the elements in the reply must be distinct.
int uniq = 1;
robj *set;
sds ele;
int64_t llele;
int encoding;

dict *d;

// Parse the count argument.
if (getLongFromObjectOrReply(c,c->argv[2],&l,NULL) != C_OK) return;
if (l >= 0) {
count = (unsigned long) l;
} else {
// A negative count means the same element may appear several times.
// NOTE(review): -l overflows when l == LONG_MIN; no range check is
// visible here - confirm whether the parser already rejects that value.
count = -l;
uniq = 0;
}

// The key must exist and hold a set.
if ((set = lookupKeyReadOrReply(c,c->argv[1],shared.emptymultibulk))
== NULL || checkType(c,set,OBJ_SET)) return;
size = setTypeSize(set);

// A count of zero yields an empty reply.
if (count == 0) {
addReply(c,shared.emptymultibulk);
return;
}

// CASE 1: repetitions allowed. Emit `count` independent random picks.
if (!uniq) {
addReplyMultiBulkLen(c,count);
while(count--) {
encoding = setTypeRandomElement(set,&ele,&llele);
if (encoding == OBJ_ENCODING_INTSET) {
addReplyBulkLongLong(c,llele);
} else {
addReplyBulkCBuffer(c,ele,sdslen(ele));
}
}
return;
}

// CASE 2: at least as many distinct elements requested as the set holds:
// return the whole set (a union over this single key).
if (count >= size) {
sunionDiffGenericCommand(c,c->argv+1,1,NULL,SET_OP_UNION);
return;
}

// Auxiliary dict that collects the distinct elements to return.
d = dictCreate(&objectKeyPointerValueDictType,NULL);

// CASE 3: the requested count is a large fraction of the set. Copy every
// element into the auxiliary dict, then delete random entries until only
// `count` of them are left.
if (count*SRANDMEMBER_SUB_STRATEGY_MUL > size) {
setTypeIterator *si;

// Copy all the elements of the set into the dict.
si = setTypeInitIterator(set);
while((encoding = setTypeNext(si,&ele,&llele)) != -1) {
int retval = DICT_ERR;

if (encoding == OBJ_ENCODING_INTSET) {
retval = dictAdd(d,createStringObjectFromLongLong(llele),NULL);
} else {
retval = dictAdd(d,createStringObject(ele,sdslen(ele)),NULL);
}
serverAssert(retval == DICT_OK);
}
setTypeReleaseIterator(si);
serverAssert(dictSize(d) == size);

// Remove random entries until the dict holds exactly `count` elements.
while(size > count) {
dictEntry *de;

de = dictGetRandomKey(d);
dictDelete(d,dictGetKey(de));
size--;
}
}

// CASE 4: the requested count is small compared to the set. Sample random
// elements and add them to the dict until `count` distinct ones have been
// collected; duplicates are discarded.
else {
unsigned long added = 0;
robj *objele;

while(added < count) {
encoding = setTypeRandomElement(set,&ele,&llele);
if (encoding == OBJ_ENCODING_INTSET) {
objele = createStringObjectFromLongLong(llele);
} else {
objele = createStringObject(ele,sdslen(ele));
}

// Try to add; on a duplicate release the object and sample again.
if (dictAdd(d,objele,NULL) == DICT_OK)
added++;
else
decrRefCount(objele);
}
}

// Cases 3 and 4: send the collected elements, then release the dict.
{
dictIterator *di;
dictEntry *de;

addReplyMultiBulkLen(c,count);
di = dictGetIterator(d);
while((de = dictNext(di)) != NULL)
addReplyBulk(c,dictGetKey(de));
dictReleaseIterator(di);
dictRelease(d);
}
}
srandmemberCommand

响应srandmember命令,随机获取一个或多个元素

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
/* SRANDMEMBER key [count]
 *
 * With a count the request is handed to srandmemberWithCountCommand();
 * without one a single random element is returned. The set is not
 * modified. More than one extra argument is a syntax error. */
void srandmemberCommand(client *c) {
    if (c->argc == 3) {
        srandmemberWithCountCommand(c);
        return;
    }
    if (c->argc > 3) {
        addReply(c, shared.syntaxerr);
        return;
    }

    robj *set = lookupKeyReadOrReply(c, c->argv[1], shared.nullbulk);
    if (set == NULL || checkType(c, set, OBJ_SET)) return;

    sds member;
    int64_t number;
    if (setTypeRandomElement(set, &member, &number) == OBJ_ENCODING_INTSET) {
        addReplyBulkLongLong(c, number);
    } else {
        addReplyBulkCBuffer(c, member, sdslen(member));
    }
}
qsortCompareSetsByCardinality

集合快速排序比较方法,根据集合大小

1
2
3
4
5
/* qsort() comparator over an array of robj* sets: orders by ascending
 * cardinality. */
int qsortCompareSetsByCardinality(const void *s1, const void *s2) {
    unsigned long left = setTypeSize(*(robj**)s1);
    unsigned long right = setTypeSize(*(robj**)s2);

    return (left > right) - (left < right);
}
qsortCompareSetsByRevCardinality

集合快速排序比较方法,根据集合大小倒序排,用于计算集合差集时用

1
2
3
4
5
6
7
8
9
int qsortCompareSetsByRevCardinality(const void *s1, const void *s2) {
robj *o1 = *(robj**)s1, *o2 = *(robj**)s2;
unsigned long first = o1 ? setTypeSize(o1) : 0;
unsigned long second = o2 ? setTypeSize(o2) : 0;

if (first < second) return 1;
if (first > second) return -1;
return 0;
}
sinterGenericCommand

计算集合交集通用方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
// Shared implementation of SINTER and SINTERSTORE.
//
// setkeys/setnum: the keys of the sets to intersect.
// dstkey: NULL to send the intersection to the client; otherwise the key
//         under which the result is stored (its previous value is deleted).
void sinterGenericCommand(client *c, robj **setkeys,
unsigned long setnum, robj *dstkey) {
robj **sets = zmalloc(sizeof(robj*)*setnum);
setTypeIterator *si;
robj *dstset = NULL;
sds elesds;
int64_t intobj;
void *replylen = NULL;
unsigned long j, cardinality = 0;
int encoding;

// Look up every input key. A single missing key makes the intersection
// empty, so the command can finish right away.
for (j = 0; j < setnum; j++) {
robj *setobj = dstkey ?
lookupKeyWrite(c->db,setkeys[j]) :
lookupKeyRead(c->db,setkeys[j]);
if (!setobj) {
zfree(sets);
if (dstkey) {
// Storing an empty result means removing the destination key.
if (dbDelete(c->db,dstkey)) {
signalModifiedKey(c->db,dstkey);
server.dirty++;
}
addReply(c,shared.czero);
} else {
addReply(c,shared.emptymultibulk);
}
return;
}
if (checkType(c,setobj,OBJ_SET)) {
zfree(sets);
return;
}
sets[j] = setobj;
}

// Sort the sets by ascending cardinality so the smallest one drives the loop.
qsort(sets,setnum,sizeof(robj*),qsortCompareSetsByCardinality);

// Without a destination the reply length is not known yet, so it is
// deferred; with a destination the result is collected in a new set.
if (!dstkey) {
replylen = addDeferredMultiBulkLength(c);
} else {
dstset = createIntsetObject();
}

// Walk the smallest set and test each element against all the other sets.
// The inner loop breaks as soon as one set lacks the element, so reaching
// j == setnum means every set contains it.
si = setTypeInitIterator(sets[0]);
while((encoding = setTypeNext(si,&elesds,&intobj)) != -1) {
for (j = 1; j < setnum; j++) {
if (sets[j] == sets[0]) continue;

// The element comes from an intset-encoded first set.
if (encoding == OBJ_ENCODING_INTSET) {

// Both intset-encoded: direct integer lookup. When the lookup succeeds
// neither branch body runs and the loop moves on to the next set.
if (sets[j]->encoding == OBJ_ENCODING_INTSET &&
!intsetFind((intset*)sets[j]->ptr,intobj))
{
break;

// Integer element against a hash table set: compare as a temporary string.
} else if (sets[j]->encoding == OBJ_ENCODING_HT) {
elesds = sdsfromlonglong(intobj);
if (!setTypeIsMember(sets[j],elesds)) {
sdsfree(elesds);
break;
}
sdsfree(elesds);
}
} else if (encoding == OBJ_ENCODING_HT) {
if (!setTypeIsMember(sets[j],elesds)) {
break;
}
}
}

// The element is present in every set: emit or store it.
if (j == setnum) {
if (!dstkey) {
if (encoding == OBJ_ENCODING_HT)
addReplyBulkCBuffer(c,elesds,sdslen(elesds));
else
addReplyBulkLongLong(c,intobj);
cardinality++;
} else {
if (encoding == OBJ_ENCODING_INTSET) {
elesds = sdsfromlonglong(intobj);
setTypeAdd(dstset,elesds);
sdsfree(elesds);
} else {
setTypeAdd(dstset,elesds);
}
}
}
}
setTypeReleaseIterator(si);

// Finish: store the result under dstkey, or fix up the deferred length.
if (dstkey) {
int deleted = dbDelete(c->db,dstkey);

// Non-empty intersection: store it and reply with its size.
if (setTypeSize(dstset) > 0) {
dbAdd(c->db,dstkey,dstset);
addReplyLongLong(c,setTypeSize(dstset));
notifyKeyspaceEvent(NOTIFY_SET,"sinterstore",
dstkey,c->db->id);
} else {
// Empty intersection: nothing is stored; report a "del" only if an
// old destination value was actually removed.
decrRefCount(dstset);
addReply(c,shared.czero);
if (deleted)
notifyKeyspaceEvent(NOTIFY_GENERIC,"del",
dstkey,c->db->id);
}
signalModifiedKey(c->db,dstkey);
server.dirty++;
} else {
setDeferredMultiBulkLength(c,replylen,cardinality);
}
zfree(sets);
}
sinterCommand

响应sinter命令,获取集合的交集

1
2
3
/* SINTER key [key ...]: intersect the sets named by argv[1..] and send the
 * result to the client (no destination key). */
void sinterCommand(client *c) {
    robj **keys = c->argv + 1;
    unsigned long numkeys = c->argc - 1;

    sinterGenericCommand(c, keys, numkeys, NULL);
}
sinterstoreCommand

响应sinterstore命令,获取集合的交集,并把结果存起来

1
2
3
/* SINTERSTORE destination key [key ...]: intersect the sets named by
 * argv[2..] and store the result under argv[1]. */
void sinterstoreCommand(client *c) {
    robj *destination = c->argv[1];
    robj **keys = c->argv + 2;
    unsigned long numkeys = c->argc - 2;

    sinterGenericCommand(c, keys, numkeys, destination);
}
sunionDiffGenericCommand

获取集合并集和差集,通用操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
// Operation selectors for the generic set-operation code. Only
// SET_OP_UNION and SET_OP_DIFF are handled by sunionDiffGenericCommand().
#define SET_OP_UNION 0
#define SET_OP_DIFF 1
#define SET_OP_INTER 2

// Shared implementation of SUNION, SUNIONSTORE, SDIFF and SDIFFSTORE.
//
// setkeys/setnum: the keys of the input sets (missing keys are treated as
//                 empty sets).
// dstkey: NULL to send the result to the client; otherwise the key under
//         which the result is stored (its previous value is deleted).
// op: SET_OP_UNION or SET_OP_DIFF. For a difference, the first set is the
//     base and the elements of all the other sets are subtracted from it.
void sunionDiffGenericCommand(client *c, robj **setkeys, int setnum,
robj *dstkey, int op) {
robj **sets = zmalloc(sizeof(robj*)*setnum);
setTypeIterator *si;
robj *dstset = NULL;
sds ele;
int j, cardinality = 0;
int diff_algo = 1;

// Look up the input keys; a missing key is recorded as NULL and skipped later.
for (j = 0; j < setnum; j++) {
robj *setobj = dstkey ?
lookupKeyWrite(c->db,setkeys[j]) :
lookupKeyRead(c->db,setkeys[j]);
if (!setobj) {
sets[j] = NULL;
continue;
}
if (checkType(c,setobj,OBJ_SET)) {
zfree(sets);
return;
}
sets[j] = setobj;
}

// There are two ways to compute a difference; estimate which is cheaper.
// Algorithm 1 is O(N*M), N being the size of the first set and M the number
// of sets: for each element of the first set, check that no other set
// contains it.
// Algorithm 2 is O(N), N being the total number of elements in all sets:
// copy the first set, then remove from the copy every element found in the
// other sets.
if (op == SET_OP_DIFF && sets[0]) {
long long algo_one_work = 0, algo_two_work = 0;

for (j = 0; j < setnum; j++) {
if (sets[j] == NULL) continue;

algo_one_work += setTypeSize(sets[0]);
algo_two_work += setTypeSize(sets[j]);
}

// Algorithm 1 does less work when elements are shared between sets, so
// its estimate is halved to give it an advantage.
algo_one_work /= 2;

// Pick the algorithm with the lower estimated cost.
diff_algo = (algo_one_work <= algo_two_work) ? 1 : 2;

if (diff_algo == 1 && setnum > 1) {
// With algorithm 1, sort the sets after the first by descending size
// (the first stays in place as the base) so that a common element is
// likely to be found early.
qsort(sets+1,setnum-1,sizeof(robj*),
qsortCompareSetsByRevCardinality);
}
}

// The result is accumulated in a temporary set (starting as an intset).
dstset = createIntsetObject();

// Union: add every element of every set to the result.
if (op == SET_OP_UNION) {
for (j = 0; j < setnum; j++) {
if (!sets[j]) continue;

si = setTypeInitIterator(sets[j]);
while((ele = setTypeNextObject(si)) != NULL) {
if (setTypeAdd(dstset,ele)) cardinality++;
sdsfree(ele);
}
setTypeReleaseIterator(si);
}
} else if (op == SET_OP_DIFF && sets[0] && diff_algo == 1) {
// Difference, algorithm 1: keep an element of the first set only if the
// scan over the other sets completes without a break (j == setnum).
si = setTypeInitIterator(sets[0]);
while((ele = setTypeNextObject(si)) != NULL) {
for (j = 1; j < setnum; j++) {
if (!sets[j]) continue;
if (sets[j] == sets[0]) break;
if (setTypeIsMember(sets[j],ele)) break;
}

// No other set contains the element: it belongs to the result.
if (j == setnum) {
setTypeAdd(dstset,ele);
cardinality++;
}
sdsfree(ele);
}
setTypeReleaseIterator(si);
} else if (op == SET_OP_DIFF && sets[0] && diff_algo == 2) {
// Difference, algorithm 2: load the first set into the result, then
// remove the elements of every following set.
for (j = 0; j < setnum; j++) {
if (!sets[j]) continue;

si = setTypeInitIterator(sets[j]);
while((ele = setTypeNextObject(si)) != NULL) {
if (j == 0) {
if (setTypeAdd(dstset,ele)) cardinality++;
} else {
if (setTypeRemove(dstset,ele)) cardinality--;
}
sdsfree(ele);
}
setTypeReleaseIterator(si);

// The result is already empty: the remaining sets cannot change it.
if (cardinality == 0) break;
}
}

// No destination key: send the result to the client and release it.
if (!dstkey) {
addReplyMultiBulkLen(c,cardinality);
si = setTypeInitIterator(dstset);
while((ele = setTypeNextObject(si)) != NULL) {
addReplyBulkCBuffer(c,ele,sdslen(ele));
sdsfree(ele);
}
setTypeReleaseIterator(si);
decrRefCount(dstset);
} else {
// Destination key given: replace its value with the result, or leave it
// deleted when the result is empty.
int deleted = dbDelete(c->db,dstkey);
if (setTypeSize(dstset) > 0) {
dbAdd(c->db,dstkey,dstset);
addReplyLongLong(c,setTypeSize(dstset));
notifyKeyspaceEvent(NOTIFY_SET,
op == SET_OP_UNION ? "sunionstore" : "sdiffstore",
dstkey,c->db->id);
} else {
decrRefCount(dstset);
addReply(c,shared.czero);
if (deleted)
notifyKeyspaceEvent(NOTIFY_GENERIC,"del",
dstkey,c->db->id);
}
signalModifiedKey(c->db,dstkey);
server.dirty++;
}
zfree(sets);
}
sunionCommand

响应sunion命令,获取集合的并集

1
2
3
/* SUNION key [key ...]: send the union of the sets named by argv[1..] to
 * the client. */
void sunionCommand(client *c) {
    robj **keys = c->argv + 1;
    int numkeys = c->argc - 1;

    sunionDiffGenericCommand(c, keys, numkeys, NULL, SET_OP_UNION);
}
sunionstoreCommand

响应sunionstore命令,获取集合的并集,并把结果存储到新的key中

1
2
3
/* SUNIONSTORE destination key [key ...]: store the union of the sets named
 * by argv[2..] under argv[1]. */
void sunionstoreCommand(client *c) {
    robj *destination = c->argv[1];
    robj **keys = c->argv + 2;
    int numkeys = c->argc - 2;

    sunionDiffGenericCommand(c, keys, numkeys, destination, SET_OP_UNION);
}
sdiffCommand

响应sdiff命令,获取集合的差集

1
2
3
/* SDIFF key [key ...]: send to the client the elements of the first set
 * that are in none of the following sets. */
void sdiffCommand(client *c) {
    robj **keys = c->argv + 1;
    int numkeys = c->argc - 1;

    sunionDiffGenericCommand(c, keys, numkeys, NULL, SET_OP_DIFF);
}
sdiffstoreCommand

响应sdiffstore命令,获取集合的差集

1
2
3
/* SDIFFSTORE destination key [key ...]: compute the difference between the
 * first set and the following ones (argv[2..]) and store it under argv[1]. */
void sdiffstoreCommand(client *c) {
    robj *destination = c->argv[1];
    robj **keys = c->argv + 2;
    int numkeys = c->argc - 2;

    sunionDiffGenericCommand(c, keys, numkeys, destination, SET_OP_DIFF);
}
sscanCommand

响应sscan命令,迭代获取集合内的元素

1
2
3
4
5
6
7
8
9
/* SSCAN key cursor [...]: parse the cursor, validate that the key holds a
 * set, then delegate the iteration to scanGenericCommand(). */
void sscanCommand(client *c) {
    unsigned long cursor;
    if (parseScanCursorOrReply(c, c->argv[2], &cursor) == C_ERR) return;

    robj *set = lookupKeyReadOrReply(c, c->argv[1], shared.emptyscan);
    if (set == NULL) return;
    if (checkType(c, set, OBJ_SET)) return;

    scanGenericCommand(c, set, cursor);
}

数组实现

数据结构

数组实现集合的代码在intset.h、intset.c文件中。数据结构很简单,只有三个属性,且集合的每一次新增、删除都是通过realloc(zrealloc)调整内存实现,不涉及其他的更高级用法。新增元素的时候,数组始终保持从小到大有序,查询的时候采用二分查找法。

1
2
3
4
5
// Sorted array of integers. All elements share one width, selected by
// `encoding`. Accessors in this file read and write both header fields
// through intrev32ifbe().
typedef struct intset {
uint32_t encoding; // element width tag: one of INTSET_ENC_INT16/INT32/INT64
uint32_t length; // number of elements in the set
int8_t contents[]; // element storage (flexible array member)
} intset;
编码类型

用来标识集合内元素的最大值的范围

1
2
3
// Encoding tags. Each tag equals the byte width of one element, so a wider
// integer type always has a numerically larger tag.
#define INTSET_ENC_INT16 (sizeof(int16_t))
#define INTSET_ENC_INT32 (sizeof(int32_t))
#define INTSET_ENC_INT64 (sizeof(int64_t))
_intsetValueEncoding

获取一个整数的编码类型

1
2
3
4
5
6
7
8
/* Return the narrowest encoding tag able to hold `v`. */
static uint8_t _intsetValueEncoding(int64_t v) {
    if (v >= INT16_MIN && v <= INT16_MAX) return INTSET_ENC_INT16;
    if (v >= INT32_MIN && v <= INT32_MAX) return INTSET_ENC_INT32;
    return INTSET_ENC_INT64;
}
_intsetGetEncoded

获取集合内具体索引位置的数值,需要指定编码类型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
/* Read the element at index `pos`, interpreting the storage as an array of
 * elements of encoding `enc` (which may differ from the set's own encoding,
 * e.g. during an upgrade). The value is copied out with memcpy() and passed
 * through the memrev*ifbe() helper for its width before being returned. */
static int64_t _intsetGetEncoded(intset *is, int pos, uint8_t enc) {
    if (enc == INTSET_ENC_INT64) {
        int64_t wide;
        memcpy(&wide, (int64_t*)is->contents + pos, sizeof(wide));
        memrev64ifbe(&wide);
        return wide;
    }

    if (enc == INTSET_ENC_INT32) {
        int32_t medium;
        memcpy(&medium, (int32_t*)is->contents + pos, sizeof(medium));
        memrev32ifbe(&medium);
        return medium;
    }

    int16_t narrow;
    memcpy(&narrow, (int16_t*)is->contents + pos, sizeof(narrow));
    memrev16ifbe(&narrow);
    return narrow;
}
_intsetGet

获取集合内具体索引位置的数值,采用集合的自己的编码类型

1
2
3
/* Read the element at index `pos` using the set's current encoding. */
static int64_t _intsetGet(intset *is, int pos) {
    uint8_t own_encoding = intrev32ifbe(is->encoding);
    return _intsetGetEncoded(is, pos, own_encoding);
}
_intsetSet

写入集合内具体索引位置的值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
/* Store `value` at index `pos` using the set's current encoding. The slot
 * is written first and then passed through the memrev*ifbe() helper for
 * its width. */
static void _intsetSet(intset *is, int pos, int64_t value) {
    uint32_t encoding = intrev32ifbe(is->encoding);

    if (encoding == INTSET_ENC_INT64) {
        int64_t *slot = (int64_t*)is->contents + pos;
        *slot = value;
        memrev64ifbe(slot);
    } else if (encoding == INTSET_ENC_INT32) {
        int32_t *slot = (int32_t*)is->contents + pos;
        *slot = value;
        memrev32ifbe(slot);
    } else {
        int16_t *slot = (int16_t*)is->contents + pos;
        *slot = value;
        memrev16ifbe(slot);
    }
}
intsetNew

创建新的集合

1
2
3
4
5
6
/* Allocate an empty intset. It starts with the narrowest (16-bit) encoding
 * and no element storage beyond the header. */
intset *intsetNew(void) {
    intset *fresh = zmalloc(sizeof(*fresh));

    fresh->length = 0;
    fresh->encoding = intrev32ifbe(INTSET_ENC_INT16);
    return fresh;
}
intsetResize

调整集合的大小,在新增和删除元素的时候会用到

1
2
3
4
5
/* Reallocate the intset so that it can hold `len` elements of its current
 * encoding. Returns the (possibly moved) intset; the length field is NOT
 * updated here, callers do that themselves.
 *
 * The payload size is computed in size_t rather than uint32_t: with a
 * 32-bit product, len * element-width wraps for large sets, which would
 * allocate a buffer smaller than the elements subsequently written into
 * it. (On targets where size_t is itself 32 bits the product can still
 * wrap.) */
static intset *intsetResize(intset *is, uint32_t len) {
    size_t size = (size_t)len * intrev32ifbe(is->encoding);
    is = zrealloc(is, sizeof(intset) + size);
    return is;
}
intsetSearch

二分查找数据,返回值表示查询结果:1表示找到,0表示没找到。pos表示值在数组中的索引;如果没有找到的话,pos表示需要在此位置进行插入

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
// Binary search for `value` in the sorted intset.
//
// Returns 1 when the value is found and stores its index in *pos.
// Returns 0 when it is not found and stores in *pos the index at which the
// value would have to be inserted to keep the array sorted.
// `pos` may be NULL when the caller does not need the position.
static uint8_t intsetSearch(intset *is, int64_t value, uint32_t *pos) {
int min = 0, max = intrev32ifbe(is->length)-1, mid = -1;
int64_t cur = -1;

// Empty set: nothing to search, the insert position is 0.
if (intrev32ifbe(is->length) == 0) {
if (pos) *pos = 0;
return 0;
} else {

// The array is sorted, so check the two boundaries first: a value beyond
// the last element goes at the end, one below the first goes at index 0.
if (value > _intsetGet(is,max)) {
if (pos) *pos = intrev32ifbe(is->length);
return 0;
} else if (value < _intsetGet(is,0)) {
if (pos) *pos = 0;
return 0;
}
}

// Standard binary search over [min, max]. The midpoint is computed on
// unsigned values so the addition cannot overflow a signed int.
while(max >= min) {
mid = ((unsigned int)min + (unsigned int)max) >> 1;
cur = _intsetGet(is,mid);
if (value > cur) {
min = mid+1;
} else if (value < cur) {
max = mid-1;
} else {
break;
}
}

// On a miss the loop ended with min > max, and min is the insert position.
if (value == cur) {
if (pos) *pos = mid;
return 1;
} else {
if (pos) *pos = min;
return 0;
}
}
intsetUpgradeAndAdd

升级编码类型并插入数据,用于新插入数据编码类型大于集合的编码类型的情况

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
// Upgrade the intset to the wider encoding required by `value`, then insert
// the value. intsetAdd() calls this when the value's encoding is larger
// than the set's current encoding. Returns the (possibly reallocated) set.
static intset *intsetUpgradeAndAdd(intset *is, int64_t value) {
uint8_t curenc = intrev32ifbe(is->encoding);
uint8_t newenc = _intsetValueEncoding(value);
int length = intrev32ifbe(is->length);

// A value that does not fit the current encoding lies outside the range of
// every existing element, so it is either the new minimum (negative:
// insert at the head) or the new maximum (insert at the tail).
int prepend = value < 0 ? 1 : 0;

// Switch the header to the new encoding before resizing, so the resize is
// computed with the new element width.
is->encoding = intrev32ifbe(newenc);

// Make room for one more element.
is = intsetResize(is,intrev32ifbe(is->length)+1);

// Re-encode in place, from the last element to the first: each value is
// read with the old encoding and written with the new one. Going backwards
// keeps the wider writes from overwriting old values not yet read.
// `prepend` shifts everything by one slot when inserting at the head.
while(length--)
_intsetSet(is,length+prepend,_intsetGetEncoded(is,length,curenc));

// Store the new value at the head or at the tail.
if (prepend)
_intsetSet(is,0,value);
else
_intsetSet(is,intrev32ifbe(is->length),value);

// Account for the new element.
is->length = intrev32ifbe(intrev32ifbe(is->length)+1);
return is;
}
intsetMoveTail

移动数据,新增的时候调用此方法,给新数据腾位置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
/* Move the tail of the array - the elements from index `from` up to the
 * current length - so that it starts at index `to`. Used to open a gap
 * before an insert or to close one after a removal. The regions may
 * overlap, hence memmove(). */
static void intsetMoveTail(intset *is, uint32_t from, uint32_t to) {
    uint32_t encoding = intrev32ifbe(is->encoding);
    size_t width;

    if (encoding == INTSET_ENC_INT64) {
        width = sizeof(int64_t);
    } else if (encoding == INTSET_ENC_INT32) {
        width = sizeof(int32_t);
    } else {
        width = sizeof(int16_t);
    }

    /* Element count first, then scaled to bytes. */
    uint32_t bytes = intrev32ifbe(is->length) - from;
    bytes *= width;

    void *src = is->contents + (size_t)from * width;
    void *dst = is->contents + (size_t)to * width;
    memmove(dst, src, bytes);
}
intsetAdd

插入新元素,success表示成功与否

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
/* Insert `value` into the intset, keeping the array sorted.
 *
 * Returns the (possibly reallocated) intset. When `success` is not NULL it
 * is set to 1 if the value was inserted and to 0 if it was already there. */
intset *intsetAdd(intset *is, int64_t value, uint8_t *success) {
    uint8_t needed = _intsetValueEncoding(value);
    if (success) *success = 1;

    /* The value needs a wider encoding: it cannot be in the set yet, so
     * upgrade and insert in one step. */
    if (needed > intrev32ifbe(is->encoding))
        return intsetUpgradeAndAdd(is, value);

    /* Already present: nothing to do. */
    uint32_t slot;
    if (intsetSearch(is, value, &slot)) {
        if (success) *success = 0;
        return is;
    }

    /* Grow by one element, open a gap at the insert position (unless the
     * value goes at the end), then store it and bump the length. */
    uint32_t oldlen = intrev32ifbe(is->length);
    is = intsetResize(is, oldlen + 1);
    if (slot < oldlen) intsetMoveTail(is, slot, slot + 1);

    _intsetSet(is, slot, value);
    is->length = intrev32ifbe(oldlen + 1);
    return is;
}
intsetRemove

删除元素

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
/* Remove `value` from the intset.
 *
 * Returns the (possibly reallocated) intset. When `success` is not NULL it
 * is set to 1 if the value was removed and to 0 if it was not in the set. */
intset *intsetRemove(intset *is, int64_t value, int *success) {
    if (success) *success = 0;

    /* A value wider than the set's encoding cannot be a member. */
    if (_intsetValueEncoding(value) > intrev32ifbe(is->encoding)) return is;

    uint32_t slot;
    if (!intsetSearch(is, value, &slot)) return is;

    uint32_t len = intrev32ifbe(is->length);
    if (success) *success = 1;

    /* Close the gap (unless the last element is removed), then shrink. */
    if (slot < (len - 1)) intsetMoveTail(is, slot + 1, slot);
    is = intsetResize(is, len - 1);
    is->length = intrev32ifbe(len - 1);
    return is;
}
intsetFind

查找元素

1
2
3
4
/* Return 1 when `value` is a member of the intset, 0 otherwise. */
uint8_t intsetFind(intset *is, int64_t value) {
    /* A value wider than the set's encoding cannot be stored in it. */
    if (_intsetValueEncoding(value) > intrev32ifbe(is->encoding)) return 0;

    return intsetSearch(is, value, NULL) != 0;
}
intsetRandom

随机获取元素

1
2
3
/* Return a random element of the intset, chosen with rand().
 * NOTE(review): the modulo divides by the set length, so this must not be
 * called on an empty intset - confirm that callers guarantee this. */
int64_t intsetRandom(intset *is) {
    uint32_t index = rand() % intrev32ifbe(is->length);
    return _intsetGet(is, index);
}
intsetGet

获取指定位置的数据

1
2
3
4
5
6
7
/* Bounds-checked read: store the element at index `pos` in *value and
 * return 1, or return 0 (leaving *value untouched) when `pos` is out of
 * range. */
uint8_t intsetGet(intset *is, uint32_t pos, int64_t *value) {
    if (pos >= intrev32ifbe(is->length)) return 0;

    *value = _intsetGet(is, pos);
    return 1;
}
intsetLen

获取集合大小

1
2
3
/* Number of elements stored in the intset. */
uint32_t intsetLen(const intset *is) {
    uint32_t count = intrev32ifbe(is->length);
    return count;
}
intsetBlobLen

获取集合所占用的内存大小(字节数,包含头部和所有元素)

1
2
3
/* Total size in bytes of the intset blob: the header plus length elements
 * of the current encoding width.
 *
 * Both factors are widened to size_t before multiplying. The header fields
 * are 32-bit, so a 32-bit product would wrap for very large sets and the
 * reported size would be too small. */
size_t intsetBlobLen(intset *is) {
    size_t count = intrev32ifbe(is->length);
    size_t width = intrev32ifbe(is->encoding);

    return sizeof(intset) + count * width;
}

哈希表实现

参照哈希表文章