Copy project

This commit is contained in:
kim.dev.6789
2026-01-14 22:16:44 +08:00
parent e2577b8cee
commit e50142a3b9
691 changed files with 97009 additions and 1 deletion


@@ -0,0 +1,17 @@
package cache
import (
"context"
)
// BatchDeleter defines a set of methods for batch deleting cache entries and publishing deletion information.
type BatchDeleter interface {
// ChainExecDel is used for chain calls; Clone must be called first to prevent memory pollution.
ChainExecDel(ctx context.Context) error
// ExecDelWithKeys takes keys directly and deletes them.
ExecDelWithKeys(ctx context.Context, keys []string) error
// Clone creates a copy of the BatchDeleter to avoid modifying the original object.
Clone() BatchDeleter
// AddKeys adds keys to be deleted.
AddKeys(keys ...string)
}
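A minimal usage sketch of the intended chain-call pattern (the helper name and keys are illustrative; concrete BatchDeleter implementations appear later in this commit):

package cache

import "context"

// collectAndDelete illustrates the chain-call contract: Clone first so the
// original deleter's key list is untouched, accumulate keys, then execute.
func collectAndDelete(ctx context.Context, d BatchDeleter, keys ...string) error {
	clone := d.Clone()
	clone.AddKeys(keys...)
	return clone.ChainExecDel(ctx)
}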

pkg/common/storage/cache/black.go

@@ -0,0 +1,27 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cache
import (
"context"
)
type BlackCache interface {
BatchDeleter
CloneBlackCache() BlackCache
GetBlackIDs(ctx context.Context, userID string) (blackIDs []string, err error)
// Delete the user's blackIDs from msgCache; executed when a user's blacklist changes.
DelBlackIDs(ctx context.Context, userID string) BlackCache
}


@@ -0,0 +1,8 @@
package cache
import "context"
type ClientConfigCache interface {
DeleteUserCache(ctx context.Context, userIDs []string) error
GetUserConfig(ctx context.Context, userID string) (map[string]string, error)
}


@@ -0,0 +1,65 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cache
import (
"context"
relationtb "git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
)
// The fn argument is executed when there is no data in msgCache.
type ConversationCache interface {
BatchDeleter
CloneConversationCache() ConversationCache
// get user's conversationIDs from msgCache
GetUserConversationIDs(ctx context.Context, ownerUserID string) ([]string, error)
GetUserNotNotifyConversationIDs(ctx context.Context, userID string) ([]string, error)
GetPinnedConversationIDs(ctx context.Context, userID string) ([]string, error)
DelConversationIDs(userIDs ...string) ConversationCache
GetUserConversationIDsHash(ctx context.Context, ownerUserID string) (hash uint64, err error)
DelUserConversationIDsHash(ownerUserIDs ...string) ConversationCache
// get one conversation from msgCache
GetConversation(ctx context.Context, ownerUserID, conversationID string) (*relationtb.Conversation, error)
DelConversations(ownerUserID string, conversationIDs ...string) ConversationCache
DelUsersConversation(conversationID string, ownerUserIDs ...string) ConversationCache
// get multiple conversations from msgCache
GetConversations(ctx context.Context, ownerUserID string,
conversationIDs []string) ([]*relationtb.Conversation, error)
// get all of a user's conversations from msgCache
GetUserAllConversations(ctx context.Context, ownerUserID string) ([]*relationtb.Conversation, error)
// get a user's receive-message option for a conversation from msgCache
GetUserRecvMsgOpt(ctx context.Context, ownerUserID, conversationID string) (opt int, err error)
DelUserRecvMsgOpt(ownerUserID, conversationID string) ConversationCache
// get the userID list of super group members who receive messages but are not notified
// GetSuperGroupRecvMsgNotNotifyUserIDs(ctx context.Context, groupID string) (userIDs []string, err error)
DelSuperGroupRecvMsgNotNotifyUserIDs(groupID string) ConversationCache
// get the hash of the super group's receive-but-not-notify userID list
// GetSuperGroupRecvMsgNotNotifyUserIDsHash(ctx context.Context, groupID string) (hash uint64, err error)
DelSuperGroupRecvMsgNotNotifyUserIDsHash(groupID string) ConversationCache
// GetUserAllHasReadSeqs(ctx context.Context, ownerUserID string) (map[string]int64, error)
DelUserAllHasReadSeqs(ownerUserID string, conversationIDs ...string) ConversationCache
GetConversationNotReceiveMessageUserIDs(ctx context.Context, conversationID string) ([]string, error)
DelConversationNotReceiveMessageUserIDs(conversationIDs ...string) ConversationCache
DelConversationNotNotifyMessageUserIDs(userIDs ...string) ConversationCache
DelUserPinnedConversations(userIDs ...string) ConversationCache
DelConversationVersionUserIDs(userIDs ...string) ConversationCache
FindMaxConversationUserVersion(ctx context.Context, userID string) (*relationtb.VersionLog, error)
}

pkg/common/storage/cache/doc.go

@@ -0,0 +1,15 @@
// Copyright © 2024 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cache // import "git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"

pkg/common/storage/cache/friend.go

@@ -0,0 +1,48 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cache
import (
"context"
relationtb "git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
)
// FriendCache is an interface for caching friend-related data.
type FriendCache interface {
BatchDeleter
CloneFriendCache() FriendCache
GetFriendIDs(ctx context.Context, ownerUserID string) (friendIDs []string, err error)
// Called when the friendID list changes
DelFriendIDs(ownerUserID ...string) FriendCache
// Get single friendInfo from the cache
GetFriend(ctx context.Context, ownerUserID, friendUserID string) (friend *relationtb.Friend, err error)
// Delete a friend when the friend's info changes
DelFriend(ownerUserID, friendUserID string) FriendCache
// Delete friends when the friends' info changes
DelFriends(ownerUserID string, friendUserIDs []string) FriendCache
DelOwner(friendUserID string, ownerUserIDs []string) FriendCache
DelMaxFriendVersion(ownerUserIDs ...string) FriendCache
//DelSortFriendUserIDs(ownerUserIDs ...string) FriendCache
//FindSortFriendUserIDs(ctx context.Context, ownerUserID string) ([]string, error)
//FindFriendIncrVersion(ctx context.Context, ownerUserID string, version uint, limit int) (*relationtb.VersionLog, error)
FindMaxFriendVersion(ctx context.Context, ownerUserID string) (*relationtb.VersionLog, error)
}

pkg/common/storage/cache/group.go

@@ -0,0 +1,70 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cache
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/common"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
)
type GroupHash interface {
GetGroupHash(ctx context.Context, groupID string) (uint64, error)
}
type GroupCache interface {
BatchDeleter
CloneGroupCache() GroupCache
GetGroupsInfo(ctx context.Context, groupIDs []string) (groups []*model.Group, err error)
GetGroupInfo(ctx context.Context, groupID string) (group *model.Group, err error)
DelGroupsInfo(groupIDs ...string) GroupCache
GetGroupMembersHash(ctx context.Context, groupID string) (hashCode uint64, err error)
GetGroupMemberHashMap(ctx context.Context, groupIDs []string) (map[string]*common.GroupSimpleUserID, error)
DelGroupMembersHash(groupID string) GroupCache
GetGroupMemberIDs(ctx context.Context, groupID string) (groupMemberIDs []string, err error)
DelGroupMemberIDs(groupID string) GroupCache
GetJoinedGroupIDs(ctx context.Context, userID string) (joinedGroupIDs []string, err error)
DelJoinedGroupID(userID ...string) GroupCache
GetGroupMemberInfo(ctx context.Context, groupID, userID string) (groupMember *model.GroupMember, err error)
GetGroupMembersInfo(ctx context.Context, groupID string, userID []string) (groupMembers []*model.GroupMember, err error)
GetAllGroupMembersInfo(ctx context.Context, groupID string) (groupMembers []*model.GroupMember, err error)
FindGroupMemberUser(ctx context.Context, groupIDs []string, userID string) ([]*model.GroupMember, error)
GetGroupRoleLevelMemberIDs(ctx context.Context, groupID string, roleLevel int32) ([]string, error)
GetGroupOwner(ctx context.Context, groupID string) (*model.GroupMember, error)
GetGroupsOwner(ctx context.Context, groupIDs []string) ([]*model.GroupMember, error)
DelGroupRoleLevel(groupID string, roleLevel []int32) GroupCache
DelGroupAllRoleLevel(groupID string) GroupCache
DelGroupMembersInfo(groupID string, userID ...string) GroupCache
GetGroupRoleLevelMemberInfo(ctx context.Context, groupID string, roleLevel int32) ([]*model.GroupMember, error)
GetGroupRolesLevelMemberInfo(ctx context.Context, groupID string, roleLevels []int32) ([]*model.GroupMember, error)
GetGroupMemberNum(ctx context.Context, groupID string) (memberNum int64, err error)
DelGroupsMemberNum(groupID ...string) GroupCache
//FindSortGroupMemberUserIDs(ctx context.Context, groupID string) ([]string, error)
//FindSortJoinGroupIDs(ctx context.Context, userID string) ([]string, error)
DelMaxGroupMemberVersion(groupIDs ...string) GroupCache
DelMaxJoinGroupVersion(userIDs ...string) GroupCache
FindMaxGroupMemberVersion(ctx context.Context, groupID string) (*model.VersionLog, error)
BatchFindMaxGroupMemberVersion(ctx context.Context, groupIDs []string) ([]*model.VersionLog, error)
FindMaxJoinGroupVersion(ctx context.Context, userID string) (*model.VersionLog, error)
}


@@ -0,0 +1,50 @@
package mcache
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/cachekey"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"github.com/openimsdk/tools/s3/minio"
)
func NewMinioCache(cache database.Cache) minio.Cache {
return &minioCache{
cache: cache,
expireTime: time.Hour * 24 * 7,
}
}
type minioCache struct {
cache database.Cache
expireTime time.Duration
}
func (g *minioCache) getObjectImageInfoKey(key string) string {
return cachekey.GetObjectImageInfoKey(key)
}
func (g *minioCache) getMinioImageThumbnailKey(key string, format string, width int, height int) string {
return cachekey.GetMinioImageThumbnailKey(key, format, width, height)
}
func (g *minioCache) DelObjectImageInfoKey(ctx context.Context, keys ...string) error {
ks := make([]string, 0, len(keys))
for _, key := range keys {
ks = append(ks, g.getObjectImageInfoKey(key))
}
return g.cache.Del(ctx, ks)
}
func (g *minioCache) DelImageThumbnailKey(ctx context.Context, key string, format string, width int, height int) error {
return g.cache.Del(ctx, []string{g.getMinioImageThumbnailKey(key, format, width, height)})
}
func (g *minioCache) GetImageObjectKeyInfo(ctx context.Context, key string, fn func(ctx context.Context) (*minio.ImageInfo, error)) (*minio.ImageInfo, error) {
return getCache[*minio.ImageInfo](ctx, g.cache, g.getObjectImageInfoKey(key), g.expireTime, fn)
}
func (g *minioCache) GetThumbnailKey(ctx context.Context, key string, format string, width int, height int, minioCache func(ctx context.Context) (string, error)) (string, error) {
return getCache[string](ctx, g.cache, g.getMinioImageThumbnailKey(key, format, width, height), g.expireTime, minioCache)
}


@@ -0,0 +1,132 @@
package mcache
import (
"context"
"strconv"
"sync"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/cachekey"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/open-im-server-deploy/pkg/localcache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/localcache/lru"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/utils/datautil"
"github.com/redis/go-redis/v9"
)
var (
memMsgCache lru.LRU[string, *model.MsgInfoModel]
initMemMsgCache sync.Once
)
func NewMsgCache(cache database.Cache, msgDocDatabase database.Msg) cache.MsgCache {
initMemMsgCache.Do(func() {
memMsgCache = lru.NewLazyLRU[string, *model.MsgInfoModel](1024*8, time.Hour, time.Second*10, localcache.EmptyTarget{}, nil)
})
return &msgCache{
cache: cache,
msgDocDatabase: msgDocDatabase,
memMsgCache: memMsgCache,
}
}
type msgCache struct {
cache database.Cache
msgDocDatabase database.Msg
memMsgCache lru.LRU[string, *model.MsgInfoModel]
}
func (x *msgCache) getSendMsgKey(id string) string {
return cachekey.GetSendMsgKey(id)
}
func (x *msgCache) SetSendMsgStatus(ctx context.Context, id string, status int32) error {
return x.cache.Set(ctx, x.getSendMsgKey(id), strconv.Itoa(int(status)), time.Hour*24)
}
func (x *msgCache) GetSendMsgStatus(ctx context.Context, id string) (int32, error) {
key := x.getSendMsgKey(id)
res, err := x.cache.Get(ctx, []string{key})
if err != nil {
return 0, err
}
val, ok := res[key]
if !ok {
return 0, errs.Wrap(redis.Nil)
}
status, err := strconv.Atoi(val)
if err != nil {
return 0, errs.WrapMsg(err, "GetSendMsgStatus strconv.Atoi error", "val", val)
}
return int32(status), nil
}
func (x *msgCache) getMsgCacheKey(conversationID string, seq int64) string {
return cachekey.GetMsgCacheKey(conversationID, seq)
}
func (x *msgCache) GetMessageBySeqs(ctx context.Context, conversationID string, seqs []int64) ([]*model.MsgInfoModel, error) {
if len(seqs) == 0 {
return nil, nil
}
keys := make([]string, 0, len(seqs))
keySeq := make(map[string]int64, len(seqs))
for _, seq := range seqs {
key := x.getMsgCacheKey(conversationID, seq)
keys = append(keys, key)
keySeq[key] = seq
}
res, err := x.memMsgCache.GetBatch(keys, func(keys []string) (map[string]*model.MsgInfoModel, error) {
findSeqs := make([]int64, 0, len(keys))
for _, key := range keys {
seq, ok := keySeq[key]
if !ok {
continue
}
findSeqs = append(findSeqs, seq)
}
res, err := x.msgDocDatabase.FindSeqs(ctx, conversationID, findSeqs)
if err != nil {
return nil, err
}
kv := make(map[string]*model.MsgInfoModel)
for i := range res {
msg := res[i]
if msg == nil || msg.Msg == nil || msg.Msg.Seq <= 0 {
continue
}
key := x.getMsgCacheKey(conversationID, msg.Msg.Seq)
kv[key] = msg
}
return kv, nil
})
if err != nil {
return nil, err
}
return datautil.Values(res), nil
}
func (x *msgCache) DelMessageBySeqs(ctx context.Context, conversationID string, seqs []int64) error {
if len(seqs) == 0 {
return nil
}
for _, seq := range seqs {
x.memMsgCache.Del(x.getMsgCacheKey(conversationID, seq))
}
return nil
}
func (x *msgCache) SetMessageBySeqs(ctx context.Context, conversationID string, msgs []*model.MsgInfoModel) error {
for i := range msgs {
msg := msgs[i]
if msg == nil || msg.Msg == nil || msg.Msg.Seq <= 0 {
continue
}
x.memMsgCache.Set(x.getMsgCacheKey(conversationID, msg.Msg.Seq), msg)
}
return nil
}
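A short sketch of how this message cache is typically constructed and queried; the conversation ID and seqs are example values, and the database.Cache / database.Msg instances are assumed to be supplied by the caller:

package mcache

import (
	"context"

	"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
	"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
)

// exampleGetMsgs builds the cache and reads three seqs. Seqs missing from the
// process-local LRU are fetched from msgDB in a single batch and cached.
func exampleGetMsgs(ctx context.Context, db database.Cache, msgDB database.Msg) ([]*model.MsgInfoModel, error) {
	mc := NewMsgCache(db, msgDB)
	return mc.GetMessageBySeqs(ctx, "si_123_456", []int64{1, 2, 3})
}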


@@ -0,0 +1,82 @@
package mcache
import (
"context"
"sync"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
)
var (
globalOnlineCache cache.OnlineCache
globalOnlineOnce sync.Once
)
func NewOnlineCache() cache.OnlineCache {
globalOnlineOnce.Do(func() {
globalOnlineCache = &onlineCache{
user: make(map[string]map[int32]struct{}),
}
})
return globalOnlineCache
}
type onlineCache struct {
lock sync.RWMutex
user map[string]map[int32]struct{}
}
func (x *onlineCache) GetOnline(ctx context.Context, userID string) ([]int32, error) {
x.lock.RLock()
defer x.lock.RUnlock()
pSet, ok := x.user[userID]
if !ok {
return nil, nil
}
res := make([]int32, 0, len(pSet))
for k := range pSet {
res = append(res, k)
}
return res, nil
}
func (x *onlineCache) SetUserOnline(ctx context.Context, userID string, online, offline []int32) error {
x.lock.Lock()
defer x.lock.Unlock()
pSet, ok := x.user[userID]
if ok {
for _, p := range offline {
delete(pSet, p)
}
}
if len(online) > 0 {
if !ok {
pSet = make(map[int32]struct{})
x.user[userID] = pSet
}
for _, p := range online {
pSet[p] = struct{}{}
}
}
if len(pSet) == 0 {
delete(x.user, userID)
}
return nil
}
func (x *onlineCache) GetAllOnlineUsers(ctx context.Context, cursor uint64) (map[string][]int32, uint64, error) {
if cursor != 0 {
return nil, 0, nil
}
x.lock.RLock()
defer x.lock.RUnlock()
res := make(map[string][]int32)
for k, v := range x.user {
pSet := make([]int32, 0, len(v))
for p := range v {
pSet = append(pSet, p)
}
res[k] = pSet
}
return res, 0, nil
}
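A brief sketch of the in-memory online cache's semantics; the user ID and platform IDs are example values, and NewOnlineCache returns a process-wide singleton:

package mcache

import "context"

// exampleOnlineUsage marks two platforms online, takes one offline again,
// and returns the remaining online platform IDs (expected: []int32{2}).
func exampleOnlineUsage(ctx context.Context) ([]int32, error) {
	c := NewOnlineCache()
	if err := c.SetUserOnline(ctx, "user-123", []int32{1, 2}, nil); err != nil {
		return nil, err
	}
	if err := c.SetUserOnline(ctx, "user-123", nil, []int32{1}); err != nil {
		return nil, err
	}
	return c.GetOnline(ctx, "user-123")
}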


@@ -0,0 +1,79 @@
package mcache
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
)
func NewSeqConversationCache(sc database.SeqConversation) cache.SeqConversationCache {
return &seqConversationCache{
sc: sc,
}
}
type seqConversationCache struct {
sc database.SeqConversation
}
func (x *seqConversationCache) Malloc(ctx context.Context, conversationID string, size int64) (int64, error) {
return x.sc.Malloc(ctx, conversationID, size)
}
func (x *seqConversationCache) SetMinSeq(ctx context.Context, conversationID string, seq int64) error {
return x.sc.SetMinSeq(ctx, conversationID, seq)
}
func (x *seqConversationCache) GetMinSeq(ctx context.Context, conversationID string) (int64, error) {
return x.sc.GetMinSeq(ctx, conversationID)
}
func (x *seqConversationCache) GetMaxSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error) {
res := make(map[string]int64)
for _, conversationID := range conversationIDs {
seq, err := x.GetMaxSeq(ctx, conversationID)
if err != nil {
return nil, err
}
res[conversationID] = seq
}
return res, nil
}
func (x *seqConversationCache) GetMaxSeqsWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error) {
res := make(map[string]database.SeqTime)
for _, conversationID := range conversationIDs {
seq, err := x.GetMaxSeq(ctx, conversationID)
if err != nil {
return nil, err
}
res[conversationID] = database.SeqTime{Seq: seq}
}
return res, nil
}
func (x *seqConversationCache) GetMaxSeq(ctx context.Context, conversationID string) (int64, error) {
return x.sc.GetMaxSeq(ctx, conversationID)
}
func (x *seqConversationCache) GetMaxSeqWithTime(ctx context.Context, conversationID string) (database.SeqTime, error) {
seq, err := x.GetMaxSeq(ctx, conversationID)
if err != nil {
return database.SeqTime{}, err
}
return database.SeqTime{Seq: seq}, nil
}
func (x *seqConversationCache) SetMinSeqs(ctx context.Context, seqs map[string]int64) error {
for conversationID, seq := range seqs {
if err := x.sc.SetMinSeq(ctx, conversationID, seq); err != nil {
return err
}
}
return nil
}
func (x *seqConversationCache) GetCacheMaxSeqWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error) {
return x.GetMaxSeqsWithTime(ctx, conversationIDs)
}


@@ -0,0 +1,98 @@
package mcache
import (
"context"
"strconv"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/cachekey"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"github.com/openimsdk/tools/errs"
"github.com/redis/go-redis/v9"
)
func NewThirdCache(cache database.Cache) cache.ThirdCache {
return &thirdCache{
cache: cache,
}
}
type thirdCache struct {
cache database.Cache
}
func (c *thirdCache) getGetuiTokenKey() string {
return cachekey.GetGetuiTokenKey()
}
func (c *thirdCache) getGetuiTaskIDKey() string {
return cachekey.GetGetuiTaskIDKey()
}
func (c *thirdCache) getUserBadgeUnreadCountSumKey(userID string) string {
return cachekey.GetUserBadgeUnreadCountSumKey(userID)
}
func (c *thirdCache) getFcmAccountTokenKey(account string, platformID int) string {
return cachekey.GetFcmAccountTokenKey(account, platformID)
}
func (c *thirdCache) get(ctx context.Context, key string) (string, error) {
res, err := c.cache.Get(ctx, []string{key})
if err != nil {
return "", err
}
if val, ok := res[key]; ok {
return val, nil
}
return "", errs.Wrap(redis.Nil)
}
func (c *thirdCache) SetFcmToken(ctx context.Context, account string, platformID int, fcmToken string, expireTime int64) (err error) {
return errs.Wrap(c.cache.Set(ctx, c.getFcmAccountTokenKey(account, platformID), fcmToken, time.Duration(expireTime)*time.Second))
}
func (c *thirdCache) GetFcmToken(ctx context.Context, account string, platformID int) (string, error) {
return c.get(ctx, c.getFcmAccountTokenKey(account, platformID))
}
func (c *thirdCache) DelFcmToken(ctx context.Context, account string, platformID int) error {
return c.cache.Del(ctx, []string{c.getFcmAccountTokenKey(account, platformID)})
}
func (c *thirdCache) IncrUserBadgeUnreadCountSum(ctx context.Context, userID string) (int, error) {
return c.cache.Incr(ctx, c.getUserBadgeUnreadCountSumKey(userID), 1)
}
func (c *thirdCache) SetUserBadgeUnreadCountSum(ctx context.Context, userID string, value int) error {
return c.cache.Set(ctx, c.getUserBadgeUnreadCountSumKey(userID), strconv.Itoa(value), 0)
}
func (c *thirdCache) GetUserBadgeUnreadCountSum(ctx context.Context, userID string) (int, error) {
str, err := c.get(ctx, c.getUserBadgeUnreadCountSumKey(userID))
if err != nil {
return 0, err
}
val, err := strconv.Atoi(str)
if err != nil {
return 0, errs.WrapMsg(err, "strconv.Atoi", "str", str)
}
return val, nil
}
func (c *thirdCache) SetGetuiToken(ctx context.Context, token string, expireTime int64) error {
return c.cache.Set(ctx, c.getGetuiTokenKey(), token, time.Duration(expireTime)*time.Second)
}
func (c *thirdCache) GetGetuiToken(ctx context.Context) (string, error) {
return c.get(ctx, c.getGetuiTokenKey())
}
func (c *thirdCache) SetGetuiTaskID(ctx context.Context, taskID string, expireTime int64) error {
return c.cache.Set(ctx, c.getGetuiTaskIDKey(), taskID, time.Duration(expireTime)*time.Second)
}
func (c *thirdCache) GetGetuiTaskID(ctx context.Context) (string, error) {
return c.get(ctx, c.getGetuiTaskIDKey())
}

pkg/common/storage/cache/mcache/token.go

@@ -0,0 +1,166 @@
package mcache
import (
"context"
"fmt"
"strconv"
"strings"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/cachekey"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
)
func NewTokenCacheModel(cache database.Cache, accessExpire int64) cache.TokenModel {
c := &tokenCache{cache: cache}
c.accessExpire = c.getExpireTime(accessExpire)
return c
}
type tokenCache struct {
cache database.Cache
accessExpire time.Duration
}
func (x *tokenCache) getTokenKey(userID string, platformID int, token string) string {
return cachekey.GetTokenKey(userID, platformID) + ":" + token
}
func (x *tokenCache) SetTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error {
return x.cache.Set(ctx, x.getTokenKey(userID, platformID, token), strconv.Itoa(flag), x.accessExpire)
}
// SetTokenFlagEx sets the token and flag with an expiration time.
func (x *tokenCache) SetTokenFlagEx(ctx context.Context, userID string, platformID int, token string, flag int) error {
return x.SetTokenFlag(ctx, userID, platformID, token, flag)
}
func (x *tokenCache) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) {
prefix := x.getTokenKey(userID, platformID, "")
m, err := x.cache.Prefix(ctx, prefix)
if err != nil {
return nil, errs.Wrap(err)
}
mm := make(map[string]int)
for k, v := range m {
state, err := strconv.Atoi(v)
if err != nil {
log.ZError(ctx, "token value is not int", err, "value", v, "userID", userID, "platformID", platformID)
continue
}
mm[strings.TrimPrefix(k, prefix)] = state
}
return mm, nil
}
func (x *tokenCache) HasTemporaryToken(ctx context.Context, userID string, platformID int, token string) error {
key := cachekey.GetTemporaryTokenKey(userID, platformID, token)
res, err := x.cache.Get(ctx, []string{key})
if err != nil {
return err
}
if _, ok := res[key]; !ok {
return errs.ErrRecordNotFound.WrapMsg("temporary token not found")
}
return nil
}
func (x *tokenCache) GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) {
prefix := cachekey.UidPidToken + userID + ":"
tokens, err := x.cache.Prefix(ctx, prefix)
if err != nil {
return nil, err
}
res := make(map[int]map[string]int)
for key, flagStr := range tokens {
flag, err := strconv.Atoi(flagStr)
if err != nil {
log.ZError(ctx, "token value is not int", err, "key", key, "value", flagStr, "userID", userID)
continue
}
arr := strings.SplitN(strings.TrimPrefix(key, prefix), ":", 2)
if len(arr) != 2 {
log.ZError(ctx, "token value is not int", err, "key", key, "value", flagStr, "userID", userID)
continue
}
platformID, err := strconv.Atoi(arr[0])
if err != nil {
log.ZError(ctx, "token value is not int", err, "key", key, "value", flagStr, "userID", userID)
continue
}
token := arr[1]
if token == "" {
log.ZError(ctx, "token value is not int", err, "key", key, "value", flagStr, "userID", userID)
continue
}
tk, ok := res[platformID]
if !ok {
tk = make(map[string]int)
res[platformID] = tk
}
tk[token] = flag
}
return res, nil
}
func (x *tokenCache) SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error {
for token, flag := range m {
err := x.SetTokenFlag(ctx, userID, platformID, token, flag)
if err != nil {
return err
}
}
return nil
}
func (x *tokenCache) BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error {
for prefix, tokenFlag := range tokens {
for token, flag := range tokenFlag {
flagStr := fmt.Sprintf("%v", flag)
if err := x.cache.Set(ctx, prefix+":"+token, flagStr, x.accessExpire); err != nil {
return err
}
}
}
return nil
}
func (x *tokenCache) DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error {
keys := make([]string, 0, len(fields))
for _, token := range fields {
keys = append(keys, x.getTokenKey(userID, platformID, token))
}
return x.cache.Del(ctx, keys)
}
func (x *tokenCache) getExpireTime(t int64) time.Duration {
return time.Hour * 24 * time.Duration(t)
}
func (x *tokenCache) DeleteTokenByTokenMap(ctx context.Context, userID string, tokens map[int][]string) error {
keys := make([]string, 0, len(tokens))
for platformID, ts := range tokens {
for _, t := range ts {
keys = append(keys, x.getTokenKey(userID, platformID, t))
}
}
return x.cache.Del(ctx, keys)
}
func (x *tokenCache) DeleteAndSetTemporary(ctx context.Context, userID string, platformID int, fields []string) error {
keys := make([]string, 0, len(fields))
for _, f := range fields {
keys = append(keys, x.getTokenKey(userID, platformID, f))
}
if err := x.cache.Del(ctx, keys); err != nil {
return err
}
for _, f := range fields {
k := cachekey.GetTemporaryTokenKey(userID, platformID, f)
if err := x.cache.Set(ctx, k, "", x.accessExpire); err != nil {
return errs.Wrap(err)
}
}
return nil
}
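A short usage sketch of the prefix-keyed token storage above, assuming a database.Cache implementation and that cache.TokenModel exposes SetTokenFlag and GetTokensWithoutError as implemented here; the user ID, platform, token string, and the 90-day expiry are example values:

package mcache

import (
	"context"

	"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
)

// exampleTokenUsage stores one token flag and then lists all tokens the user
// has on that platform (token string -> flag), read back via a prefix scan.
func exampleTokenUsage(ctx context.Context, db database.Cache) (map[string]int, error) {
	tc := NewTokenCacheModel(db, 90)
	if err := tc.SetTokenFlag(ctx, "user-123", 1, "example-token", 0); err != nil {
		return nil, err
	}
	return tc.GetTokensWithoutError(ctx, "user-123", 1)
}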


@@ -0,0 +1,63 @@
package mcache
import (
"context"
"encoding/json"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"github.com/openimsdk/tools/log"
)
func getCache[V any](ctx context.Context, cache database.Cache, key string, expireTime time.Duration, fn func(ctx context.Context) (V, error)) (V, error) {
getDB := func() (V, bool, error) {
res, err := cache.Get(ctx, []string{key})
if err != nil {
var val V
return val, false, err
}
var val V
if str, ok := res[key]; ok {
if err := json.Unmarshal([]byte(str), &val); err != nil {
return val, false, err
}
return val, true, nil
}
return val, false, nil
}
dbVal, ok, err := getDB()
if err != nil {
return dbVal, err
}
if ok {
return dbVal, nil
}
lockValue, err := cache.Lock(ctx, key, time.Minute)
if err != nil {
return dbVal, err
}
defer func() {
if err := cache.Unlock(ctx, key, lockValue); err != nil {
log.ZError(ctx, "unlock cache key", err, "key", key, "value", lockValue)
}
}()
dbVal, ok, err = getDB()
if err != nil {
return dbVal, err
}
if ok {
return dbVal, nil
}
val, err := fn(ctx)
if err != nil {
return val, err
}
data, err := json.Marshal(val)
if err != nil {
return val, err
}
if err := cache.Set(ctx, key, string(data), expireTime); err != nil {
return val, err
}
return val, nil
}
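A sketch of how the generic read-through helper above is meant to be wrapped; the key prefix, value type, and expiry are illustrative only:

package mcache

import (
	"context"
	"time"

	"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
)

// exampleProfile is an illustrative JSON-serializable value type.
type exampleProfile struct {
	Nickname string `json:"nickname"`
}

// loadProfileCached returns the cached value when present; otherwise it takes
// a short lock, re-checks the cache, calls load, and stores the JSON result.
func loadProfileCached(ctx context.Context, c database.Cache, userID string, load func(ctx context.Context) (*exampleProfile, error)) (*exampleProfile, error) {
	return getCache[*exampleProfile](ctx, c, "example:profile:"+userID, time.Hour, load)
}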

pkg/common/storage/cache/msg.go

@@ -0,0 +1,30 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cache
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
)
type MsgCache interface {
SetSendMsgStatus(ctx context.Context, id string, status int32) error
GetSendMsgStatus(ctx context.Context, id string) (int32, error)
GetMessageBySeqs(ctx context.Context, conversationID string, seqs []int64) ([]*model.MsgInfoModel, error)
DelMessageBySeqs(ctx context.Context, conversationID string, seqs []int64) error
SetMessageBySeqs(ctx context.Context, conversationID string, msgs []*model.MsgInfoModel) error
}

pkg/common/storage/cache/online.go

@@ -0,0 +1,9 @@
package cache
import "context"
type OnlineCache interface {
GetOnline(ctx context.Context, userID string) ([]int32, error)
SetUserOnline(ctx context.Context, userID string, online, offline []int32) error
GetAllOnlineUsers(ctx context.Context, cursor uint64) (map[string][]int32, uint64, error)
}

pkg/common/storage/cache/redis/batch.go

@@ -0,0 +1,135 @@
package redis
import (
"context"
"encoding/json"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"github.com/dtm-labs/rockscache"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/redis/go-redis/v9"
)
// GetRocksCacheOptions returns the default configuration options for RocksCache.
func GetRocksCacheOptions() *rockscache.Options {
opts := rockscache.NewDefaultOptions()
opts.LockExpire = rocksCacheTimeout
opts.WaitReplicasTimeout = rocksCacheTimeout
opts.StrongConsistency = true
opts.RandomExpireAdjustment = 0.2
return &opts
}
func newRocksCacheClient(rdb redis.UniversalClient) *rocksCacheClient {
if rdb == nil {
return &rocksCacheClient{}
}
rc := &rocksCacheClient{
rdb: rdb,
client: rockscache.NewClient(rdb, *GetRocksCacheOptions()),
}
return rc
}
type rocksCacheClient struct {
rdb redis.UniversalClient
client *rockscache.Client
}
func (x *rocksCacheClient) GetClient() *rockscache.Client {
return x.client
}
func (x *rocksCacheClient) Disable() bool {
return x.client == nil
}
func (x *rocksCacheClient) GetRedis() redis.UniversalClient {
return x.rdb
}
func (x *rocksCacheClient) GetBatchDeleter(topics ...string) cache.BatchDeleter {
return NewBatchDeleterRedis(x, topics)
}
func batchGetCache2[K comparable, V any](ctx context.Context, rcClient *rocksCacheClient, expire time.Duration, ids []K, idKey func(id K) string, vId func(v *V) K, fn func(ctx context.Context, ids []K) ([]*V, error)) ([]*V, error) {
if len(ids) == 0 {
return nil, nil
}
if rcClient.Disable() {
return fn(ctx, ids)
}
findKeys := make([]string, 0, len(ids))
keyId := make(map[string]K)
for _, id := range ids {
key := idKey(id)
if _, ok := keyId[key]; ok {
continue
}
keyId[key] = id
findKeys = append(findKeys, key)
}
slotKeys, err := groupKeysBySlot(ctx, rcClient.GetRedis(), findKeys)
if err != nil {
return nil, err
}
result := make([]*V, 0, len(findKeys))
for _, keys := range slotKeys {
indexCache, err := rcClient.GetClient().FetchBatch2(ctx, keys, expire, func(idx []int) (map[int]string, error) {
queryIds := make([]K, 0, len(idx))
idIndex := make(map[K]int)
for _, index := range idx {
id := keyId[keys[index]]
idIndex[id] = index
queryIds = append(queryIds, id)
}
values, err := fn(ctx, queryIds)
if err != nil {
log.ZError(ctx, "batchGetCache query database failed", err, "keys", keys, "queryIds", queryIds)
return nil, err
}
if len(values) == 0 {
return map[int]string{}, nil
}
cacheIndex := make(map[int]string)
for _, value := range values {
id := vId(value)
index, ok := idIndex[id]
if !ok {
continue
}
bs, err := json.Marshal(value)
if err != nil {
log.ZError(ctx, "marshal failed", err)
return nil, err
}
cacheIndex[index] = string(bs)
}
return cacheIndex, nil
})
if err != nil {
return nil, errs.WrapMsg(err, "FetchBatch2 failed")
}
for index, data := range indexCache {
if data == "" {
continue
}
var value V
if err := json.Unmarshal([]byte(data), &value); err != nil {
return nil, errs.WrapMsg(err, "Unmarshal failed")
}
if cb, ok := any(&value).(BatchCacheCallback[K]); ok {
cb.BatchCache(keyId[keys[index]])
}
result = append(result, &value)
}
}
return result, nil
}
type BatchCacheCallback[K comparable] interface {
BatchCache(id K)
}
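A sketch of how batchGetCache2 is intended to be called, mirroring the per-cache wrappers later in this commit; the key prefix and value type are illustrative:

package redis

import (
	"context"
	"time"
)

// exampleItem is an illustrative JSON-serializable value type.
type exampleItem struct {
	ID   string `json:"id"`
	Name string `json:"name"`
}

// exampleBatchGet reads items through RocksCache in one batch per Redis slot:
// cached entries are decoded from JSON and only missing IDs reach the loader.
func exampleBatchGet(ctx context.Context, rc *rocksCacheClient, ids []string, load func(ctx context.Context, ids []string) ([]*exampleItem, error)) ([]*exampleItem, error) {
	return batchGetCache2(ctx, rc, time.Hour*12, ids,
		func(id string) string { return "EXAMPLE_ITEM:" + id },
		func(v *exampleItem) string { return v.ID },
		load,
	)
}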


@@ -0,0 +1,149 @@
package redis
import (
"context"
"encoding/json"
"fmt"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/localcache"
"github.com/dtm-labs/rockscache"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/datautil"
"github.com/redis/go-redis/v9"
)
const (
rocksCacheTimeout = 11 * time.Second
)
// BatchDeleterRedis is a concrete implementation of the BatchDeleter interface based on Redis and RocksCache.
type BatchDeleterRedis struct {
redisClient redis.UniversalClient
keys []string
rocksClient *rockscache.Client
redisPubTopics []string
}
// NewBatchDeleterRedis creates a new BatchDeleterRedis instance.
func NewBatchDeleterRedis(rcClient *rocksCacheClient, redisPubTopics []string) *BatchDeleterRedis {
return &BatchDeleterRedis{
redisClient: rcClient.GetRedis(),
rocksClient: rcClient.GetClient(),
redisPubTopics: redisPubTopics,
}
}
// ExecDelWithKeys directly takes keys for batch deletion and publishes deletion information.
func (c *BatchDeleterRedis) ExecDelWithKeys(ctx context.Context, keys []string) error {
distinctKeys := datautil.Distinct(keys)
return c.execDel(ctx, distinctKeys)
}
// ChainExecDel is used for chain calls for batch deletion. Clone must be called first to prevent memory pollution.
func (c *BatchDeleterRedis) ChainExecDel(ctx context.Context) error {
distinctKeys := datautil.Distinct(c.keys)
return c.execDel(ctx, distinctKeys)
}
// execDel performs batch deletion and publishes the keys that have been deleted to update the local cache information of other nodes.
func (c *BatchDeleterRedis) execDel(ctx context.Context, keys []string) error {
if len(keys) > 0 {
log.ZDebug(ctx, "delete cache", "topic", c.redisPubTopics, "keys", keys)
// Batch delete keys
err := ProcessKeysBySlot(ctx, c.redisClient, keys, func(ctx context.Context, slot int64, keys []string) error {
return c.rocksClient.TagAsDeletedBatch2(ctx, keys)
})
if err != nil {
return err
}
// Publish the keys that have been deleted to Redis to update the local cache information of other nodes
if len(c.redisPubTopics) > 0 && len(keys) > 0 {
keysByTopic := localcache.GetPublishKeysByTopic(c.redisPubTopics, keys)
for topic, keys := range keysByTopic {
if len(keys) > 0 {
data, err := json.Marshal(keys)
if err != nil {
log.ZWarn(ctx, "keys json marshal failed", err, "topic", topic, "keys", keys)
} else {
if err := c.redisClient.Publish(ctx, topic, string(data)).Err(); err != nil {
log.ZWarn(ctx, "redis publish cache delete error", err, "topic", topic, "keys", keys)
}
}
}
}
}
}
return nil
}
// Clone creates a copy of BatchDeleterRedis for chain calls to prevent memory pollution.
func (c *BatchDeleterRedis) Clone() cache.BatchDeleter {
return &BatchDeleterRedis{
redisClient: c.redisClient,
keys: c.keys,
rocksClient: c.rocksClient,
redisPubTopics: c.redisPubTopics,
}
}
// AddKeys adds keys to be deleted.
func (c *BatchDeleterRedis) AddKeys(keys ...string) {
c.keys = append(c.keys, keys...)
}
type disableBatchDeleter struct{}
func (x disableBatchDeleter) ChainExecDel(ctx context.Context) error {
return nil
}
func (x disableBatchDeleter) ExecDelWithKeys(ctx context.Context, keys []string) error {
return nil
}
func (x disableBatchDeleter) Clone() cache.BatchDeleter {
return x
}
func (x disableBatchDeleter) AddKeys(keys ...string) {}
func getCache[T any](ctx context.Context, rcClient *rocksCacheClient, key string, expire time.Duration, fn func(ctx context.Context) (T, error)) (T, error) {
if rcClient.Disable() {
return fn(ctx)
}
var t T
var write bool
v, err := rcClient.GetClient().Fetch2(ctx, key, expire, func() (s string, err error) {
t, err = fn(ctx)
if err != nil {
//log.ZError(ctx, "getCache query database failed", err, "key", key)
return "", err
}
bs, err := json.Marshal(t)
if err != nil {
return "", errs.WrapMsg(err, "marshal failed")
}
write = true
return string(bs), nil
})
if err != nil {
return t, errs.Wrap(err)
}
if write {
return t, nil
}
if v == "" {
return t, errs.ErrRecordNotFound.WrapMsg("cache is not found")
}
err = json.Unmarshal([]byte(v), &t)
if err != nil {
errInfo := fmt.Sprintf("cache json.Unmarshal failed, key:%s, value:%s, expire:%s", key, v, expire)
return t, errs.WrapMsg(err, errInfo)
}
return t, nil
}


@@ -0,0 +1,56 @@
package redis
import (
"context"
"testing"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database/mgo"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/db/redisutil"
)
func TestName(t *testing.T) {
//var rocks rockscache.Client
//rdb := getRocksCacheRedisClient(&rocks)
//t.Log(rdb == nil)
ctx := context.Background()
rdb, err := redisutil.NewRedisClient(ctx, (&config.Redis{
Address: []string{"172.16.8.48:16379"},
Password: "openIM123",
DB: 3,
}).Build())
if err != nil {
panic(err)
}
mgocli, err := mongoutil.NewMongoDB(ctx, (&config.Mongo{
Address: []string{"172.16.8.48:37017"},
Database: "openim_v3",
Username: "openIM",
Password: "openIM123",
MaxPoolSize: 100,
MaxRetry: 1,
}).Build())
if err != nil {
panic(err)
}
//userMgo, err := mgo.NewUserMongo(mgocli.GetDB())
//if err != nil {
// panic(err)
//}
//rock := rockscache.NewClient(rdb, rockscache.NewDefaultOptions())
mgoSeqUser, err := mgo.NewSeqUserMongo(mgocli.GetDB())
if err != nil {
panic(err)
}
seqUser := NewSeqUserCacheRedis(rdb, mgoSeqUser)
res, err := seqUser.GetUserReadSeqs(ctx, "2110910952", []string{"sg_2920732023", "sg_345762580"})
if err != nil {
panic(err)
}
t.Log(res)
}

pkg/common/storage/cache/redis/black.go

@@ -0,0 +1,65 @@
package redis
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/cachekey"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"github.com/redis/go-redis/v9"
)
const (
blackExpireTime = time.Second * 60 * 60 * 12
)
type BlackCacheRedis struct {
cache.BatchDeleter
expireTime time.Duration
rcClient *rocksCacheClient
blackDB database.Black
}
func NewBlackCacheRedis(rdb redis.UniversalClient, localCache *config.LocalCache, blackDB database.Black) cache.BlackCache {
rc := newRocksCacheClient(rdb)
return &BlackCacheRedis{
BatchDeleter: rc.GetBatchDeleter(localCache.Friend.Topic),
expireTime: blackExpireTime,
rcClient: rc,
blackDB: blackDB,
}
}
func (b *BlackCacheRedis) CloneBlackCache() cache.BlackCache {
return &BlackCacheRedis{
BatchDeleter: b.BatchDeleter.Clone(),
expireTime: b.expireTime,
rcClient: b.rcClient,
blackDB: b.blackDB,
}
}
func (b *BlackCacheRedis) getBlackIDsKey(ownerUserID string) string {
return cachekey.GetBlackIDsKey(ownerUserID)
}
func (b *BlackCacheRedis) GetBlackIDs(ctx context.Context, userID string) (blackIDs []string, err error) {
return getCache(
ctx,
b.rcClient,
b.getBlackIDsKey(userID),
b.expireTime,
func(ctx context.Context) ([]string, error) {
return b.blackDB.FindBlackUserIDs(ctx, userID)
},
)
}
func (b *BlackCacheRedis) DelBlackIDs(_ context.Context, userID string) cache.BlackCache {
cache := b.CloneBlackCache()
cache.AddKeys(b.getBlackIDsKey(userID))
return cache
}


@@ -0,0 +1,69 @@
package redis
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/cachekey"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"github.com/redis/go-redis/v9"
)
func NewClientConfigCache(rdb redis.UniversalClient, mgo database.ClientConfig) cache.ClientConfigCache {
rc := newRocksCacheClient(rdb)
return &ClientConfigCache{
mgo: mgo,
rcClient: rc,
delete: rc.GetBatchDeleter(),
}
}
type ClientConfigCache struct {
mgo database.ClientConfig
rcClient *rocksCacheClient
delete cache.BatchDeleter
}
func (x *ClientConfigCache) getExpireTime(userID string) time.Duration {
if userID == "" {
return time.Hour * 24
} else {
return time.Hour
}
}
func (x *ClientConfigCache) getClientConfigKey(userID string) string {
return cachekey.GetClientConfigKey(userID)
}
func (x *ClientConfigCache) GetConfig(ctx context.Context, userID string) (map[string]string, error) {
return getCache(ctx, x.rcClient, x.getClientConfigKey(userID), x.getExpireTime(userID), func(ctx context.Context) (map[string]string, error) {
return x.mgo.Get(ctx, userID)
})
}
func (x *ClientConfigCache) DeleteUserCache(ctx context.Context, userIDs []string) error {
keys := make([]string, 0, len(userIDs))
for _, userID := range userIDs {
keys = append(keys, x.getClientConfigKey(userID))
}
return x.delete.ExecDelWithKeys(ctx, keys)
}
func (x *ClientConfigCache) GetUserConfig(ctx context.Context, userID string) (map[string]string, error) {
config, err := x.GetConfig(ctx, "")
if err != nil {
return nil, err
}
if userID != "" {
userConfig, err := x.GetConfig(ctx, userID)
if err != nil {
return nil, err
}
for k, v := range userConfig {
config[k] = v
}
}
return config, nil
}


@@ -0,0 +1,276 @@
package redis
import (
"context"
"math/big"
"strings"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/cachekey"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/utils/datautil"
"github.com/openimsdk/tools/utils/encrypt"
"github.com/redis/go-redis/v9"
)
const (
conversationExpireTime = time.Second * 60 * 60 * 12
)
func NewConversationRedis(rdb redis.UniversalClient, localCache *config.LocalCache, db database.Conversation) cache.ConversationCache {
rc := newRocksCacheClient(rdb)
return &ConversationRedisCache{
BatchDeleter: rc.GetBatchDeleter(localCache.Conversation.Topic),
rcClient: rc,
conversationDB: db,
expireTime: conversationExpireTime,
}
}
type ConversationRedisCache struct {
cache.BatchDeleter
rcClient *rocksCacheClient
conversationDB database.Conversation
expireTime time.Duration
}
func (c *ConversationRedisCache) CloneConversationCache() cache.ConversationCache {
return &ConversationRedisCache{
BatchDeleter: c.BatchDeleter.Clone(),
rcClient: c.rcClient,
conversationDB: c.conversationDB,
expireTime: c.expireTime,
}
}
func (c *ConversationRedisCache) getConversationKey(ownerUserID, conversationID string) string {
return cachekey.GetConversationKey(ownerUserID, conversationID)
}
func (c *ConversationRedisCache) getConversationIDsKey(ownerUserID string) string {
return cachekey.GetConversationIDsKey(ownerUserID)
}
func (c *ConversationRedisCache) getNotNotifyConversationIDsKey(ownerUserID string) string {
return cachekey.GetNotNotifyConversationIDsKey(ownerUserID)
}
func (c *ConversationRedisCache) getPinnedConversationIDsKey(ownerUserID string) string {
return cachekey.GetPinnedConversationIDs(ownerUserID)
}
func (c *ConversationRedisCache) getSuperGroupRecvNotNotifyUserIDsKey(groupID string) string {
return cachekey.GetSuperGroupRecvNotNotifyUserIDsKey(groupID)
}
func (c *ConversationRedisCache) getRecvMsgOptKey(ownerUserID, conversationID string) string {
return cachekey.GetRecvMsgOptKey(ownerUserID, conversationID)
}
func (c *ConversationRedisCache) getSuperGroupRecvNotNotifyUserIDsHashKey(groupID string) string {
return cachekey.GetSuperGroupRecvNotNotifyUserIDsHashKey(groupID)
}
func (c *ConversationRedisCache) getConversationHasReadSeqKey(ownerUserID, conversationID string) string {
return cachekey.GetConversationHasReadSeqKey(ownerUserID, conversationID)
}
func (c *ConversationRedisCache) getConversationNotReceiveMessageUserIDsKey(conversationID string) string {
return cachekey.GetConversationNotReceiveMessageUserIDsKey(conversationID)
}
func (c *ConversationRedisCache) getUserConversationIDsHashKey(ownerUserID string) string {
return cachekey.GetUserConversationIDsHashKey(ownerUserID)
}
func (c *ConversationRedisCache) getConversationUserMaxVersionKey(ownerUserID string) string {
return cachekey.GetConversationUserMaxVersionKey(ownerUserID)
}
func (c *ConversationRedisCache) GetUserConversationIDs(ctx context.Context, ownerUserID string) ([]string, error) {
return getCache(ctx, c.rcClient, c.getConversationIDsKey(ownerUserID), c.expireTime, func(ctx context.Context) ([]string, error) {
return c.conversationDB.FindUserIDAllConversationID(ctx, ownerUserID)
})
}
func (c *ConversationRedisCache) GetUserNotNotifyConversationIDs(ctx context.Context, userID string) ([]string, error) {
return getCache(ctx, c.rcClient, c.getNotNotifyConversationIDsKey(userID), c.expireTime, func(ctx context.Context) ([]string, error) {
return c.conversationDB.FindUserIDAllNotNotifyConversationID(ctx, userID)
})
}
func (c *ConversationRedisCache) GetPinnedConversationIDs(ctx context.Context, userID string) ([]string, error) {
return getCache(ctx, c.rcClient, c.getPinnedConversationIDsKey(userID), c.expireTime, func(ctx context.Context) ([]string, error) {
return c.conversationDB.FindUserIDAllPinnedConversationID(ctx, userID)
})
}
func (c *ConversationRedisCache) DelConversationIDs(userIDs ...string) cache.ConversationCache {
keys := make([]string, 0, len(userIDs))
for _, userID := range userIDs {
keys = append(keys, c.getConversationIDsKey(userID))
}
cache := c.CloneConversationCache()
cache.AddKeys(keys...)
return cache
}
func (c *ConversationRedisCache) GetUserConversationIDsHash(ctx context.Context, ownerUserID string) (hash uint64, err error) {
return getCache(
ctx,
c.rcClient,
c.getUserConversationIDsHashKey(ownerUserID),
c.expireTime,
func(ctx context.Context) (uint64, error) {
conversationIDs, err := c.GetUserConversationIDs(ctx, ownerUserID)
if err != nil {
return 0, err
}
datautil.Sort(conversationIDs, true)
bi := big.NewInt(0)
bi.SetString(encrypt.Md5(strings.Join(conversationIDs, ";"))[0:8], 16)
return bi.Uint64(), nil
},
)
}
func (c *ConversationRedisCache) DelUserConversationIDsHash(ownerUserIDs ...string) cache.ConversationCache {
keys := make([]string, 0, len(ownerUserIDs))
for _, ownerUserID := range ownerUserIDs {
keys = append(keys, c.getUserConversationIDsHashKey(ownerUserID))
}
cache := c.CloneConversationCache()
cache.AddKeys(keys...)
return cache
}
func (c *ConversationRedisCache) GetConversation(ctx context.Context, ownerUserID, conversationID string) (*model.Conversation, error) {
return getCache(ctx, c.rcClient, c.getConversationKey(ownerUserID, conversationID), c.expireTime, func(ctx context.Context) (*model.Conversation, error) {
return c.conversationDB.Take(ctx, ownerUserID, conversationID)
})
}
func (c *ConversationRedisCache) DelConversations(ownerUserID string, conversationIDs ...string) cache.ConversationCache {
keys := make([]string, 0, len(conversationIDs))
for _, conversationID := range conversationIDs {
keys = append(keys, c.getConversationKey(ownerUserID, conversationID))
}
cache := c.CloneConversationCache()
cache.AddKeys(keys...)
return cache
}
func (c *ConversationRedisCache) GetConversations(ctx context.Context, ownerUserID string, conversationIDs []string) ([]*model.Conversation, error) {
return batchGetCache2(ctx, c.rcClient, c.expireTime, conversationIDs, func(conversationID string) string {
return c.getConversationKey(ownerUserID, conversationID)
}, func(conversation *model.Conversation) string {
return conversation.ConversationID
}, func(ctx context.Context, conversationIDs []string) ([]*model.Conversation, error) {
return c.conversationDB.Find(ctx, ownerUserID, conversationIDs)
})
}
func (c *ConversationRedisCache) GetUserAllConversations(ctx context.Context, ownerUserID string) ([]*model.Conversation, error) {
conversationIDs, err := c.GetUserConversationIDs(ctx, ownerUserID)
if err != nil {
return nil, err
}
return c.GetConversations(ctx, ownerUserID, conversationIDs)
}
func (c *ConversationRedisCache) GetUserRecvMsgOpt(ctx context.Context, ownerUserID, conversationID string) (opt int, err error) {
return getCache(ctx, c.rcClient, c.getRecvMsgOptKey(ownerUserID, conversationID), c.expireTime, func(ctx context.Context) (opt int, err error) {
return c.conversationDB.GetUserRecvMsgOpt(ctx, ownerUserID, conversationID)
})
}
func (c *ConversationRedisCache) DelUsersConversation(conversationID string, ownerUserIDs ...string) cache.ConversationCache {
keys := make([]string, 0, len(ownerUserIDs))
for _, ownerUserID := range ownerUserIDs {
keys = append(keys, c.getConversationKey(ownerUserID, conversationID))
}
cache := c.CloneConversationCache()
cache.AddKeys(keys...)
return cache
}
func (c *ConversationRedisCache) DelUserRecvMsgOpt(ownerUserID, conversationID string) cache.ConversationCache {
cache := c.CloneConversationCache()
cache.AddKeys(c.getRecvMsgOptKey(ownerUserID, conversationID))
return cache
}
func (c *ConversationRedisCache) DelSuperGroupRecvMsgNotNotifyUserIDs(groupID string) cache.ConversationCache {
cache := c.CloneConversationCache()
cache.AddKeys(c.getSuperGroupRecvNotNotifyUserIDsKey(groupID))
return cache
}
func (c *ConversationRedisCache) DelSuperGroupRecvMsgNotNotifyUserIDsHash(groupID string) cache.ConversationCache {
cache := c.CloneConversationCache()
cache.AddKeys(c.getSuperGroupRecvNotNotifyUserIDsHashKey(groupID))
return cache
}
func (c *ConversationRedisCache) DelUserAllHasReadSeqs(ownerUserID string, conversationIDs ...string) cache.ConversationCache {
cache := c.CloneConversationCache()
for _, conversationID := range conversationIDs {
cache.AddKeys(c.getConversationHasReadSeqKey(ownerUserID, conversationID))
}
return cache
}
func (c *ConversationRedisCache) GetConversationNotReceiveMessageUserIDs(ctx context.Context, conversationID string) ([]string, error) {
return getCache(ctx, c.rcClient, c.getConversationNotReceiveMessageUserIDsKey(conversationID), c.expireTime, func(ctx context.Context) ([]string, error) {
return c.conversationDB.GetConversationNotReceiveMessageUserIDs(ctx, conversationID)
})
}
func (c *ConversationRedisCache) DelConversationNotReceiveMessageUserIDs(conversationIDs ...string) cache.ConversationCache {
cache := c.CloneConversationCache()
for _, conversationID := range conversationIDs {
cache.AddKeys(c.getConversationNotReceiveMessageUserIDsKey(conversationID))
}
return cache
}
func (c *ConversationRedisCache) DelConversationNotNotifyMessageUserIDs(userIDs ...string) cache.ConversationCache {
cache := c.CloneConversationCache()
for _, userID := range userIDs {
cache.AddKeys(c.getNotNotifyConversationIDsKey(userID))
}
return cache
}
func (c *ConversationRedisCache) DelUserPinnedConversations(userIDs ...string) cache.ConversationCache {
cache := c.CloneConversationCache()
for _, userID := range userIDs {
cache.AddKeys(c.getPinnedConversationIDsKey(userID))
}
return cache
}
func (c *ConversationRedisCache) DelConversationVersionUserIDs(userIDs ...string) cache.ConversationCache {
cache := c.CloneConversationCache()
for _, userID := range userIDs {
cache.AddKeys(c.getConversationUserMaxVersionKey(userID))
}
return cache
}
func (c *ConversationRedisCache) FindMaxConversationUserVersion(ctx context.Context, userID string) (*model.VersionLog, error) {
return getCache(ctx, c.rcClient, c.getConversationUserMaxVersionKey(userID), c.expireTime, func(ctx context.Context) (*model.VersionLog, error) {
return c.conversationDB.FindConversationUserVersion(ctx, userID, 0, 0)
})
}

pkg/common/storage/cache/redis/friend.go

@@ -0,0 +1,167 @@
package redis
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/cachekey"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/utils/datautil"
"github.com/redis/go-redis/v9"
)
const (
friendExpireTime = time.Second * 60 * 60 * 12
)
// FriendCacheRedis is an implementation of the FriendCache interface using Redis.
type FriendCacheRedis struct {
cache.BatchDeleter
friendDB database.Friend
expireTime time.Duration
rcClient *rocksCacheClient
syncCount int
}
// NewFriendCacheRedis creates a new instance of FriendCacheRedis.
func NewFriendCacheRedis(rdb redis.UniversalClient, localCache *config.LocalCache, friendDB database.Friend) cache.FriendCache {
rc := newRocksCacheClient(rdb)
return &FriendCacheRedis{
BatchDeleter: rc.GetBatchDeleter(localCache.Friend.Topic),
friendDB: friendDB,
expireTime: friendExpireTime,
rcClient: rc,
}
}
func (f *FriendCacheRedis) CloneFriendCache() cache.FriendCache {
return &FriendCacheRedis{
BatchDeleter: f.BatchDeleter.Clone(),
friendDB: f.friendDB,
expireTime: f.expireTime,
rcClient: f.rcClient,
}
}
// getFriendIDsKey returns the key for storing friend IDs in the cache.
func (f *FriendCacheRedis) getFriendIDsKey(ownerUserID string) string {
return cachekey.GetFriendIDsKey(ownerUserID)
}
func (f *FriendCacheRedis) getFriendMaxVersionKey(ownerUserID string) string {
return cachekey.GetFriendMaxVersionKey(ownerUserID)
}
// getTwoWayFriendsIDsKey returns the key for storing two-way friend IDs in the cache.
func (f *FriendCacheRedis) getTwoWayFriendsIDsKey(ownerUserID string) string {
return cachekey.GetTwoWayFriendsIDsKey(ownerUserID)
}
// getFriendKey returns the key for storing friend info in the cache.
func (f *FriendCacheRedis) getFriendKey(ownerUserID, friendUserID string) string {
return cachekey.GetFriendKey(ownerUserID, friendUserID)
}
// GetFriendIDs retrieves friend IDs from the cache or the database if not found.
func (f *FriendCacheRedis) GetFriendIDs(ctx context.Context, ownerUserID string) (friendIDs []string, err error) {
return getCache(ctx, f.rcClient, f.getFriendIDsKey(ownerUserID), f.expireTime, func(ctx context.Context) ([]string, error) {
return f.friendDB.FindFriendUserIDs(ctx, ownerUserID)
})
}
// DelFriendIDs deletes friend IDs from the cache.
func (f *FriendCacheRedis) DelFriendIDs(ownerUserIDs ...string) cache.FriendCache {
newFriendCache := f.CloneFriendCache()
keys := make([]string, 0, len(ownerUserIDs))
for _, userID := range ownerUserIDs {
keys = append(keys, f.getFriendIDsKey(userID))
}
newFriendCache.AddKeys(keys...)
return newFriendCache
}
// GetTwoWayFriendIDs retrieves two-way friend IDs from the cache.
func (f *FriendCacheRedis) GetTwoWayFriendIDs(ctx context.Context, ownerUserID string) (twoWayFriendIDs []string, err error) {
friendIDs, err := f.GetFriendIDs(ctx, ownerUserID)
if err != nil {
return nil, err
}
for _, friendID := range friendIDs {
friendFriendID, err := f.GetFriendIDs(ctx, friendID)
if err != nil {
return nil, err
}
if datautil.Contain(ownerUserID, friendFriendID...) {
// the owner appears in this friend's own friend list, so the relationship is mutual
twoWayFriendIDs = append(twoWayFriendIDs, friendID)
}
}
return twoWayFriendIDs, nil
}
// DelTwoWayFriendIDs deletes two-way friend IDs from the cache.
func (f *FriendCacheRedis) DelTwoWayFriendIDs(ctx context.Context, ownerUserID string) cache.FriendCache {
newFriendCache := f.CloneFriendCache()
newFriendCache.AddKeys(f.getTwoWayFriendsIDsKey(ownerUserID))
return newFriendCache
}
// GetFriend retrieves friend info from the cache or the database if not found.
func (f *FriendCacheRedis) GetFriend(ctx context.Context, ownerUserID, friendUserID string) (friend *model.Friend, err error) {
return getCache(ctx, f.rcClient, f.getFriendKey(ownerUserID,
friendUserID), f.expireTime, func(ctx context.Context) (*model.Friend, error) {
return f.friendDB.Take(ctx, ownerUserID, friendUserID)
})
}
// DelFriend deletes friend info from the cache.
func (f *FriendCacheRedis) DelFriend(ownerUserID, friendUserID string) cache.FriendCache {
newFriendCache := f.CloneFriendCache()
newFriendCache.AddKeys(f.getFriendKey(ownerUserID, friendUserID))
return newFriendCache
}
// DelFriends deletes multiple friend infos from the cache.
func (f *FriendCacheRedis) DelFriends(ownerUserID string, friendUserIDs []string) cache.FriendCache {
newFriendCache := f.CloneFriendCache()
for _, friendUserID := range friendUserIDs {
key := f.getFriendKey(ownerUserID, friendUserID)
newFriendCache.AddKeys(key) // Assuming AddKeys marks the keys for deletion
}
return newFriendCache
}
func (f *FriendCacheRedis) DelOwner(friendUserID string, ownerUserIDs []string) cache.FriendCache {
newFriendCache := f.CloneFriendCache()
for _, ownerUserID := range ownerUserIDs {
key := f.getFriendKey(ownerUserID, friendUserID)
newFriendCache.AddKeys(key) // Assuming AddKeys marks the keys for deletion
}
return newFriendCache
}
func (f *FriendCacheRedis) DelMaxFriendVersion(ownerUserIDs ...string) cache.FriendCache {
newFriendCache := f.CloneFriendCache()
for _, ownerUserID := range ownerUserIDs {
key := f.getFriendMaxVersionKey(ownerUserID)
newFriendCache.AddKeys(key) // Assuming AddKeys marks the keys for deletion
}
return newFriendCache
}
func (f *FriendCacheRedis) FindMaxFriendVersion(ctx context.Context, ownerUserID string) (*model.VersionLog, error) {
return getCache(ctx, f.rcClient, f.getFriendMaxVersionKey(ownerUserID), f.expireTime, func(ctx context.Context) (*model.VersionLog, error) {
return f.friendDB.FindIncrVersion(ctx, ownerUserID, 0, 0)
})
}

385
pkg/common/storage/cache/redis/group.go vendored Normal file
View File

@@ -0,0 +1,385 @@
package redis
import (
"context"
"fmt"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/cachekey"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/common"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/protocol/constant"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/redis/go-redis/v9"
)
const (
groupExpireTime = time.Second * 60 * 60 * 12
)
type GroupCacheRedis struct {
cache.BatchDeleter
groupDB database.Group
groupMemberDB database.GroupMember
groupRequestDB database.GroupRequest
expireTime time.Duration
rcClient *rocksCacheClient
groupHash cache.GroupHash
}
func NewGroupCacheRedis(rdb redis.UniversalClient, localCache *config.LocalCache, groupDB database.Group, groupMemberDB database.GroupMember, groupRequestDB database.GroupRequest, hashCode cache.GroupHash) cache.GroupCache {
rc := newRocksCacheClient(rdb)
return &GroupCacheRedis{
BatchDeleter: rc.GetBatchDeleter(localCache.Group.Topic),
rcClient: rc,
expireTime: groupExpireTime,
groupDB: groupDB,
groupMemberDB: groupMemberDB,
groupRequestDB: groupRequestDB,
groupHash: hashCode,
}
}
func (g *GroupCacheRedis) CloneGroupCache() cache.GroupCache {
return &GroupCacheRedis{
BatchDeleter: g.BatchDeleter.Clone(),
rcClient: g.rcClient,
expireTime: g.expireTime,
groupDB: g.groupDB,
groupMemberDB: g.groupMemberDB,
groupRequestDB: g.groupRequestDB,
groupHash: g.groupHash, // keep the hash calculator so clones can still serve member-hash reads
}
}
func (g *GroupCacheRedis) getGroupInfoKey(groupID string) string {
return cachekey.GetGroupInfoKey(groupID)
}
func (g *GroupCacheRedis) getJoinedGroupsKey(userID string) string {
return cachekey.GetJoinedGroupsKey(userID)
}
func (g *GroupCacheRedis) getGroupMembersHashKey(groupID string) string {
return cachekey.GetGroupMembersHashKey(groupID)
}
func (g *GroupCacheRedis) getGroupMemberIDsKey(groupID string) string {
return cachekey.GetGroupMemberIDsKey(groupID)
}
func (g *GroupCacheRedis) getGroupMemberInfoKey(groupID, userID string) string {
return cachekey.GetGroupMemberInfoKey(groupID, userID)
}
func (g *GroupCacheRedis) getGroupMemberNumKey(groupID string) string {
return cachekey.GetGroupMemberNumKey(groupID)
}
func (g *GroupCacheRedis) getGroupRoleLevelMemberIDsKey(groupID string, roleLevel int32) string {
return cachekey.GetGroupRoleLevelMemberIDsKey(groupID, roleLevel)
}
func (g *GroupCacheRedis) getGroupMemberMaxVersionKey(groupID string) string {
return cachekey.GetGroupMemberMaxVersionKey(groupID)
}
func (g *GroupCacheRedis) getJoinGroupMaxVersionKey(userID string) string {
return cachekey.GetJoinGroupMaxVersionKey(userID)
}
func (g *GroupCacheRedis) getGroupID(group *model.Group) string {
return group.GroupID
}
func (g *GroupCacheRedis) GetGroupsInfo(ctx context.Context, groupIDs []string) (groups []*model.Group, err error) {
return batchGetCache2(ctx, g.rcClient, g.expireTime, groupIDs, g.getGroupInfoKey, g.getGroupID, g.groupDB.Find)
}
func (g *GroupCacheRedis) GetGroupInfo(ctx context.Context, groupID string) (group *model.Group, err error) {
return getCache(ctx, g.rcClient, g.getGroupInfoKey(groupID), g.expireTime, func(ctx context.Context) (*model.Group, error) {
return g.groupDB.Take(ctx, groupID)
})
}
func (g *GroupCacheRedis) DelGroupsInfo(groupIDs ...string) cache.GroupCache {
newGroupCache := g.CloneGroupCache()
keys := make([]string, 0, len(groupIDs))
for _, groupID := range groupIDs {
keys = append(keys, g.getGroupInfoKey(groupID))
}
newGroupCache.AddKeys(keys...)
return newGroupCache
}
func (g *GroupCacheRedis) DelGroupsOwner(groupIDs ...string) cache.GroupCache {
newGroupCache := g.CloneGroupCache()
keys := make([]string, 0, len(groupIDs))
for _, groupID := range groupIDs {
keys = append(keys, g.getGroupRoleLevelMemberIDsKey(groupID, constant.GroupOwner))
}
newGroupCache.AddKeys(keys...)
return newGroupCache
}
func (g *GroupCacheRedis) DelGroupRoleLevel(groupID string, roleLevels []int32) cache.GroupCache {
newGroupCache := g.CloneGroupCache()
keys := make([]string, 0, len(roleLevels))
for _, roleLevel := range roleLevels {
keys = append(keys, g.getGroupRoleLevelMemberIDsKey(groupID, roleLevel))
}
newGroupCache.AddKeys(keys...)
return newGroupCache
}
func (g *GroupCacheRedis) DelGroupAllRoleLevel(groupID string) cache.GroupCache {
return g.DelGroupRoleLevel(groupID, []int32{constant.GroupOwner, constant.GroupAdmin, constant.GroupOrdinaryUsers})
}
func (g *GroupCacheRedis) GetGroupMembersHash(ctx context.Context, groupID string) (hashCode uint64, err error) {
if g.groupHash == nil {
return 0, errs.ErrInternalServer.WrapMsg("group hash is nil")
}
return getCache(ctx, g.rcClient, g.getGroupMembersHashKey(groupID), g.expireTime, func(ctx context.Context) (uint64, error) {
return g.groupHash.GetGroupHash(ctx, groupID)
})
}
func (g *GroupCacheRedis) GetGroupMemberHashMap(ctx context.Context, groupIDs []string) (map[string]*common.GroupSimpleUserID, error) {
if g.groupHash == nil {
return nil, errs.ErrInternalServer.WrapMsg("group hash is nil")
}
res := make(map[string]*common.GroupSimpleUserID)
for _, groupID := range groupIDs {
hash, err := g.GetGroupMembersHash(ctx, groupID)
if err != nil {
return nil, err
}
log.ZDebug(ctx, "GetGroupMemberHashMap", "groupID", groupID, "hash", hash)
num, err := g.GetGroupMemberNum(ctx, groupID)
if err != nil {
return nil, err
}
res[groupID] = &common.GroupSimpleUserID{Hash: hash, MemberNum: uint32(num)}
}
return res, nil
}
func (g *GroupCacheRedis) DelGroupMembersHash(groupID string) cache.GroupCache {
cache := g.CloneGroupCache()
cache.AddKeys(g.getGroupMembersHashKey(groupID))
return cache
}
func (g *GroupCacheRedis) GetGroupMemberIDs(ctx context.Context, groupID string) (groupMemberIDs []string, err error) {
return getCache(ctx, g.rcClient, g.getGroupMemberIDsKey(groupID), g.expireTime, func(ctx context.Context) ([]string, error) {
return g.groupMemberDB.FindMemberUserID(ctx, groupID)
})
}
func (g *GroupCacheRedis) DelGroupMemberIDs(groupID string) cache.GroupCache {
cache := g.CloneGroupCache()
cache.AddKeys(g.getGroupMemberIDsKey(groupID))
return cache
}
func (g *GroupCacheRedis) findUserJoinedGroupID(ctx context.Context, userID string) ([]string, error) {
groupIDs, err := g.groupMemberDB.FindUserJoinedGroupID(ctx, userID)
if err != nil {
return nil, err
}
return g.groupDB.FindJoinSortGroupID(ctx, groupIDs)
}
func (g *GroupCacheRedis) GetJoinedGroupIDs(ctx context.Context, userID string) (joinedGroupIDs []string, err error) {
return getCache(ctx, g.rcClient, g.getJoinedGroupsKey(userID), g.expireTime, func(ctx context.Context) ([]string, error) {
return g.findUserJoinedGroupID(ctx, userID)
})
}
func (g *GroupCacheRedis) DelJoinedGroupID(userIDs ...string) cache.GroupCache {
keys := make([]string, 0, len(userIDs))
for _, userID := range userIDs {
keys = append(keys, g.getJoinedGroupsKey(userID))
}
cache := g.CloneGroupCache()
cache.AddKeys(keys...)
return cache
}
func (g *GroupCacheRedis) GetGroupMemberInfo(ctx context.Context, groupID, userID string) (groupMember *model.GroupMember, err error) {
return getCache(ctx, g.rcClient, g.getGroupMemberInfoKey(groupID, userID), g.expireTime, func(ctx context.Context) (*model.GroupMember, error) {
return g.groupMemberDB.Take(ctx, groupID, userID)
})
}
func (g *GroupCacheRedis) GetGroupMembersInfo(ctx context.Context, groupID string, userIDs []string) ([]*model.GroupMember, error) {
return batchGetCache2(ctx, g.rcClient, g.expireTime, userIDs, func(userID string) string {
return g.getGroupMemberInfoKey(groupID, userID)
}, func(member *model.GroupMember) string {
return member.UserID
}, func(ctx context.Context, userIDs []string) ([]*model.GroupMember, error) {
return g.groupMemberDB.Find(ctx, groupID, userIDs)
})
}
func (g *GroupCacheRedis) GetAllGroupMembersInfo(ctx context.Context, groupID string) (groupMembers []*model.GroupMember, err error) {
groupMemberIDs, err := g.GetGroupMemberIDs(ctx, groupID)
if err != nil {
return nil, err
}
return g.GetGroupMembersInfo(ctx, groupID, groupMemberIDs)
}
func (g *GroupCacheRedis) DelGroupMembersInfo(groupID string, userIDs ...string) cache.GroupCache {
keys := make([]string, 0, len(userIDs))
for _, userID := range userIDs {
keys = append(keys, g.getGroupMemberInfoKey(groupID, userID))
}
cache := g.CloneGroupCache()
cache.AddKeys(keys...)
return cache
}
func (g *GroupCacheRedis) GetGroupMemberNum(ctx context.Context, groupID string) (memberNum int64, err error) {
return getCache(ctx, g.rcClient, g.getGroupMemberNumKey(groupID), g.expireTime, func(ctx context.Context) (int64, error) {
return g.groupMemberDB.TakeGroupMemberNum(ctx, groupID)
})
}
func (g *GroupCacheRedis) DelGroupsMemberNum(groupID ...string) cache.GroupCache {
keys := make([]string, 0, len(groupID))
for _, groupID := range groupID {
keys = append(keys, g.getGroupMemberNumKey(groupID))
}
cache := g.CloneGroupCache()
cache.AddKeys(keys...)
return cache
}
func (g *GroupCacheRedis) GetGroupOwner(ctx context.Context, groupID string) (*model.GroupMember, error) {
members, err := g.GetGroupRoleLevelMemberInfo(ctx, groupID, constant.GroupOwner)
if err != nil {
return nil, err
}
if len(members) == 0 {
return nil, errs.ErrRecordNotFound.WrapMsg(fmt.Sprintf("group %s owner not found", groupID))
}
return members[0], nil
}
func (g *GroupCacheRedis) GetGroupsOwner(ctx context.Context, groupIDs []string) ([]*model.GroupMember, error) {
members := make([]*model.GroupMember, 0, len(groupIDs))
for _, groupID := range groupIDs {
items, err := g.GetGroupRoleLevelMemberInfo(ctx, groupID, constant.GroupOwner)
if err != nil {
return nil, err
}
if len(items) > 0 {
members = append(members, items[0])
}
}
return members, nil
}
func (g *GroupCacheRedis) GetGroupRoleLevelMemberIDs(ctx context.Context, groupID string, roleLevel int32) ([]string, error) {
return getCache(ctx, g.rcClient, g.getGroupRoleLevelMemberIDsKey(groupID, roleLevel), g.expireTime, func(ctx context.Context) ([]string, error) {
return g.groupMemberDB.FindRoleLevelUserIDs(ctx, groupID, roleLevel)
})
}
func (g *GroupCacheRedis) GetGroupRoleLevelMemberInfo(ctx context.Context, groupID string, roleLevel int32) ([]*model.GroupMember, error) {
userIDs, err := g.GetGroupRoleLevelMemberIDs(ctx, groupID, roleLevel)
if err != nil {
return nil, err
}
return g.GetGroupMembersInfo(ctx, groupID, userIDs)
}
func (g *GroupCacheRedis) GetGroupRolesLevelMemberInfo(ctx context.Context, groupID string, roleLevels []int32) ([]*model.GroupMember, error) {
var userIDs []string
for _, roleLevel := range roleLevels {
ids, err := g.GetGroupRoleLevelMemberIDs(ctx, groupID, roleLevel)
if err != nil {
return nil, err
}
userIDs = append(userIDs, ids...)
}
return g.GetGroupMembersInfo(ctx, groupID, userIDs)
}
func (g *GroupCacheRedis) FindGroupMemberUser(ctx context.Context, groupIDs []string, userID string) ([]*model.GroupMember, error) {
if len(groupIDs) == 0 {
var err error
groupIDs, err = g.GetJoinedGroupIDs(ctx, userID)
if err != nil {
return nil, err
}
}
return batchGetCache2(ctx, g.rcClient, g.expireTime, groupIDs, func(groupID string) string {
return g.getGroupMemberInfoKey(groupID, userID)
}, func(member *model.GroupMember) string {
return member.GroupID
}, func(ctx context.Context, groupIDs []string) ([]*model.GroupMember, error) {
return g.groupMemberDB.FindInGroup(ctx, userID, groupIDs)
})
}
func (g *GroupCacheRedis) DelMaxGroupMemberVersion(groupIDs ...string) cache.GroupCache {
keys := make([]string, 0, len(groupIDs))
for _, groupID := range groupIDs {
keys = append(keys, g.getGroupMemberMaxVersionKey(groupID))
}
cache := g.CloneGroupCache()
cache.AddKeys(keys...)
return cache
}
func (g *GroupCacheRedis) DelMaxJoinGroupVersion(userIDs ...string) cache.GroupCache {
keys := make([]string, 0, len(userIDs))
for _, userID := range userIDs {
keys = append(keys, g.getJoinGroupMaxVersionKey(userID))
}
cache := g.CloneGroupCache()
cache.AddKeys(keys...)
return cache
}
func (g *GroupCacheRedis) FindMaxGroupMemberVersion(ctx context.Context, groupID string) (*model.VersionLog, error) {
return getCache(ctx, g.rcClient, g.getGroupMemberMaxVersionKey(groupID), g.expireTime, func(ctx context.Context) (*model.VersionLog, error) {
return g.groupMemberDB.FindMemberIncrVersion(ctx, groupID, 0, 0)
})
}
func (g *GroupCacheRedis) BatchFindMaxGroupMemberVersion(ctx context.Context, groupIDs []string) ([]*model.VersionLog, error) {
return batchGetCache2(ctx, g.rcClient, g.expireTime, groupIDs,
func(groupID string) string {
return g.getGroupMemberMaxVersionKey(groupID)
}, func(versionLog *model.VersionLog) string {
return versionLog.DID
}, func(ctx context.Context, groupIDs []string) ([]*model.VersionLog, error) {
// create two zero-valued slices with the same length as groupIDs; only zeros are needed here
versions := make([]uint, len(groupIDs))
limits := make([]int, len(groupIDs))
return g.groupMemberDB.BatchFindMemberIncrVersion(ctx, groupIDs, versions, limits)
})
}
func (g *GroupCacheRedis) FindMaxJoinGroupVersion(ctx context.Context, userID string) (*model.VersionLog, error) {
return getCache(ctx, g.rcClient, g.getJoinGroupMaxVersionKey(userID), g.expireTime, func(ctx context.Context) (*model.VersionLog, error) {
return g.groupMemberDB.FindJoinIncrVersion(ctx, userID, 0, 0)
})
}
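
A hedged sketch (not part of this commit) of how a caller might combine the group Del* methods after a membership change; it assumes the cache.GroupCache interface declares these methods and embeds BatchDeleter, as the implementations returning it suggest. The helper name is hypothetical.
package redis
import "context"
// invalidateGroupMembership stages member info, the member-ID list, the member
// count, the members hash, and the member max-version key for one group on a
// clone, then deletes them in a single batch.
func invalidateGroupMembership(ctx context.Context, g *GroupCacheRedis, groupID string, userIDs ...string) error {
return g.DelGroupMembersInfo(groupID, userIDs...).
DelGroupMemberIDs(groupID).
DelGroupsMemberNum(groupID).
DelGroupMembersHash(groupID).
DelMaxGroupMemberVersion(groupID).
ChainExecDel(ctx)
}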

View File

@@ -0,0 +1,127 @@
package redis
import (
"context"
"errors"
"fmt"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/servererrs"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/redis/go-redis/v9"
)
var (
setBatchWithCommonExpireScript = redis.NewScript(`
local expire = tonumber(ARGV[1])
for i, key in ipairs(KEYS) do
redis.call('SET', key, ARGV[i + 1])
redis.call('EXPIRE', key, expire)
end
return #KEYS
`)
setBatchWithIndividualExpireScript = redis.NewScript(`
local n = #KEYS
for i = 1, n do
redis.call('SET', KEYS[i], ARGV[i])
redis.call('EXPIRE', KEYS[i], ARGV[i + n])
end
return n
`)
deleteBatchScript = redis.NewScript(`
for i, key in ipairs(KEYS) do
redis.call('DEL', key)
end
return #KEYS
`)
getBatchScript = redis.NewScript(`
local values = {}
for i, key in ipairs(KEYS) do
local value = redis.call('GET', key)
table.insert(values, value)
end
return values
`)
)
func callLua(ctx context.Context, rdb redis.Scripter, script *redis.Script, keys []string, args []any) (any, error) {
log.ZDebug(ctx, "callLua args", "scriptHash", script.Hash(), "keys", keys, "args", args)
r := script.EvalSha(ctx, rdb, keys, args)
if redis.HasErrorPrefix(r.Err(), "NOSCRIPT") {
if err := script.Load(ctx, rdb).Err(); err != nil {
r = script.Eval(ctx, rdb, keys, args)
} else {
r = script.EvalSha(ctx, rdb, keys, args)
}
}
v, err := r.Result()
if errors.Is(err, redis.Nil) {
err = nil
}
return v, errs.WrapMsg(err, "call lua err", "scriptHash", script.Hash(), "keys", keys, "args", args)
}
func LuaSetBatchWithCommonExpire(ctx context.Context, rdb redis.Scripter, keys []string, values []string, expire int) error {
// Check if the lengths of keys and values match
if len(keys) != len(values) {
return errs.New("keys and values length mismatch").Wrap()
}
// Ensure allocation size does not overflow
maxAllowedLen := (1 << 31) - 1 // 2GB limit (maximum address space for 32-bit systems)
if len(values) > maxAllowedLen-1 {
return errs.New(fmt.Sprintf("values length %d exceeds the maximum allowed length %d", len(values), maxAllowedLen-1)).Wrap()
}
var vals = make([]any, 0, 1+len(values))
vals = append(vals, expire)
for _, v := range values {
vals = append(vals, v)
}
_, err := callLua(ctx, rdb, setBatchWithCommonExpireScript, keys, vals)
return err
}
func LuaSetBatchWithIndividualExpire(ctx context.Context, rdb redis.Scripter, keys []string, values []string, expires []int) error {
// Check if the lengths of keys, values, and expires match
if len(keys) != len(values) || len(keys) != len(expires) {
return errs.New("keys and values length mismatch").Wrap()
}
// Ensure the allocation size does not overflow
maxAllowedLen := (1 << 31) - 1 // 2GB limit (maximum address space for 32-bit systems)
if len(values) > maxAllowedLen-1 {
return errs.New(fmt.Sprintf("values length %d exceeds the maximum allowed length %d", len(values), maxAllowedLen-1)).Wrap()
}
var vals = make([]any, 0, len(values)+len(expires))
for _, v := range values {
vals = append(vals, v)
}
for _, ex := range expires {
vals = append(vals, ex)
}
_, err := callLua(ctx, rdb, setBatchWithIndividualExpireScript, keys, vals)
return err
}
func LuaDeleteBatch(ctx context.Context, rdb redis.Scripter, keys []string) error {
_, err := callLua(ctx, rdb, deleteBatchScript, keys, nil)
return err
}
func LuaGetBatch(ctx context.Context, rdb redis.Scripter, keys []string) ([]any, error) {
v, err := callLua(ctx, rdb, getBatchScript, keys, nil)
if err != nil {
return nil, err
}
values, ok := v.([]any)
if !ok {
return nil, servererrs.ErrArgs.WrapMsg("invalid lua get batch result")
}
return values, nil
}
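
A short usage sketch of the Lua batch helpers above (illustrative only; the key names are made up). In cluster mode every key passed to a single call must hash to the same slot, since one script invocation touches all of them; the hash tag below keeps both keys together.
package redis
import (
"context"
"fmt"
"github.com/redis/go-redis/v9"
)
// luaBatchExample writes two keys with a shared 60-second TTL, then reads them
// back in one round trip. redis.UniversalClient satisfies redis.Scripter.
func luaBatchExample(ctx context.Context, rdb redis.UniversalClient) error {
keys := []string{"example:{demo}:k1", "example:{demo}:k2"}
if err := LuaSetBatchWithCommonExpire(ctx, rdb, keys, []string{"v1", "v2"}, 60); err != nil {
return err
}
values, err := LuaGetBatch(ctx, rdb, keys)
if err != nil {
return err
}
for i, v := range values {
fmt.Println(keys[i], v) // a missing key comes back as nil
}
return nil
}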

View File

@@ -0,0 +1,75 @@
package redis
import (
"context"
"github.com/go-redis/redismock/v9"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
)
func TestLuaSetBatchWithCommonExpire(t *testing.T) {
rdb, mock := redismock.NewClientMock()
ctx := context.Background()
keys := []string{"key1", "key2"}
values := []string{"value1", "value2"}
expire := 10
mock.ExpectEvalSha(setBatchWithCommonExpireScript.Hash(), keys, []any{expire, "value1", "value2"}).SetVal(int64(len(keys)))
err := LuaSetBatchWithCommonExpire(ctx, rdb, keys, values, expire)
require.NoError(t, err)
assert.NoError(t, mock.ExpectationsWereMet())
}
func TestLuaSetBatchWithIndividualExpire(t *testing.T) {
rdb, mock := redismock.NewClientMock()
ctx := context.Background()
keys := []string{"key1", "key2"}
values := []string{"value1", "value2"}
expires := []int{10, 20}
args := make([]any, 0, len(values)+len(expires))
for _, v := range values {
args = append(args, v)
}
for _, ex := range expires {
args = append(args, ex)
}
mock.ExpectEvalSha(setBatchWithIndividualExpireScript.Hash(), keys, args).SetVal(int64(len(keys)))
err := LuaSetBatchWithIndividualExpire(ctx, rdb, keys, values, expires)
require.NoError(t, err)
assert.NoError(t, mock.ExpectationsWereMet())
}
func TestLuaDeleteBatch(t *testing.T) {
rdb, mock := redismock.NewClientMock()
ctx := context.Background()
keys := []string{"key1", "key2"}
mock.ExpectEvalSha(deleteBatchScript.Hash(), keys, []any{}).SetVal(int64(len(keys)))
err := LuaDeleteBatch(ctx, rdb, keys)
require.NoError(t, err)
assert.NoError(t, mock.ExpectationsWereMet())
}
func TestLuaGetBatch(t *testing.T) {
rdb, mock := redismock.NewClientMock()
ctx := context.Background()
keys := []string{"key1", "key2"}
expectedValues := []any{"value1", "value2"}
mock.ExpectEvalSha(getBatchScript.Hash(), keys, []any{}).SetVal(expectedValues)
values, err := LuaGetBatch(ctx, rdb, keys)
require.NoError(t, err)
assert.NoError(t, mock.ExpectationsWereMet())
assert.Equal(t, expectedValues, values)
}

59
pkg/common/storage/cache/redis/minio.go vendored Normal file
View File

@@ -0,0 +1,59 @@
package redis
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/cachekey"
"github.com/openimsdk/tools/s3/minio"
"github.com/redis/go-redis/v9"
)
func NewMinioCache(rdb redis.UniversalClient) minio.Cache {
rc := newRocksCacheClient(rdb)
return &minioCacheRedis{
BatchDeleter: rc.GetBatchDeleter(),
rcClient: rc,
expireTime: time.Hour * 24 * 7,
}
}
type minioCacheRedis struct {
cache.BatchDeleter
rcClient *rocksCacheClient
expireTime time.Duration
}
func (g *minioCacheRedis) getObjectImageInfoKey(key string) string {
return cachekey.GetObjectImageInfoKey(key)
}
func (g *minioCacheRedis) getMinioImageThumbnailKey(key string, format string, width int, height int) string {
return cachekey.GetMinioImageThumbnailKey(key, format, width, height)
}
func (g *minioCacheRedis) DelObjectImageInfoKey(ctx context.Context, keys ...string) error {
ks := make([]string, 0, len(keys))
for _, key := range keys {
ks = append(ks, g.getObjectImageInfoKey(key))
}
return g.BatchDeleter.ExecDelWithKeys(ctx, ks)
}
func (g *minioCacheRedis) DelImageThumbnailKey(ctx context.Context, key string, format string, width int, height int) error {
return g.BatchDeleter.ExecDelWithKeys(ctx, []string{g.getMinioImageThumbnailKey(key, format, width, height)})
}
func (g *minioCacheRedis) GetImageObjectKeyInfo(ctx context.Context, key string, fn func(ctx context.Context) (*minio.ImageInfo, error)) (*minio.ImageInfo, error) {
info, err := getCache(ctx, g.rcClient, g.getObjectImageInfoKey(key), g.expireTime, fn)
if err != nil {
return nil, err
}
return info, nil
}
func (g *minioCacheRedis) GetThumbnailKey(ctx context.Context, key string, format string, width int, height int, minioCache func(ctx context.Context) (string, error)) (string, error) {
return getCache(ctx, g.rcClient, g.getMinioImageThumbnailKey(key, format, width, height), g.expireTime, minioCache)
}

94
pkg/common/storage/cache/redis/msg.go vendored Normal file
View File

@@ -0,0 +1,94 @@
package redis
import (
"context"
"encoding/json"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/cachekey"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/utils/datautil"
"github.com/redis/go-redis/v9"
)
// msgCacheTimeout is the expiration time of the message cache: 24 hours (86400 seconds).
const msgCacheTimeout = time.Hour * 24
func NewMsgCache(client redis.UniversalClient, db database.Msg) cache.MsgCache {
return &msgCache{
rcClient: newRocksCacheClient(client),
msgDocDatabase: db,
}
}
type msgCache struct {
rcClient *rocksCacheClient
msgDocDatabase database.Msg
}
func (c *msgCache) getSendMsgKey(id string) string {
return cachekey.GetSendMsgKey(id)
}
func (c *msgCache) SetSendMsgStatus(ctx context.Context, id string, status int32) error {
return errs.Wrap(c.rcClient.GetRedis().Set(ctx, c.getSendMsgKey(id), status, time.Hour*24).Err())
}
func (c *msgCache) GetSendMsgStatus(ctx context.Context, id string) (int32, error) {
result, err := c.rcClient.GetRedis().Get(ctx, c.getSendMsgKey(id)).Int()
return int32(result), errs.Wrap(err)
}
func (c *msgCache) GetMessageBySeqs(ctx context.Context, conversationID string, seqs []int64) ([]*model.MsgInfoModel, error) {
if len(seqs) == 0 {
return nil, nil
}
getKey := func(seq int64) string {
return cachekey.GetMsgCacheKey(conversationID, seq)
}
getMsgID := func(msg *model.MsgInfoModel) int64 {
return msg.Msg.Seq
}
find := func(ctx context.Context, seqs []int64) ([]*model.MsgInfoModel, error) {
return c.msgDocDatabase.FindSeqs(ctx, conversationID, seqs)
}
return batchGetCache2(ctx, c.rcClient, msgCacheTimeout, seqs, getKey, getMsgID, find)
}
func (c *msgCache) DelMessageBySeqs(ctx context.Context, conversationID string, seqs []int64) error {
if len(seqs) == 0 {
return nil
}
keys := datautil.Slice(seqs, func(seq int64) string {
return cachekey.GetMsgCacheKey(conversationID, seq)
})
slotKeys, err := groupKeysBySlot(ctx, c.rcClient.GetRedis(), keys)
if err != nil {
return err
}
for _, keys := range slotKeys {
if err := c.rcClient.GetClient().TagAsDeletedBatch2(ctx, keys); err != nil {
return err
}
}
return nil
}
func (c *msgCache) SetMessageBySeqs(ctx context.Context, conversationID string, msgs []*model.MsgInfoModel) error {
for _, msg := range msgs {
if msg == nil || msg.Msg == nil || msg.Msg.Seq <= 0 {
continue
}
data, err := json.Marshal(msg)
if err != nil {
return err
}
if err := c.rcClient.GetClient().RawSet(ctx, cachekey.GetMsgCacheKey(conversationID, msg.Msg.Seq), string(data), msgCacheTimeout); err != nil {
return err
}
}
return nil
}
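
A brief sketch (illustrative only) of the read-through and invalidation flow for the seq-keyed message cache above; the cache.MsgCache interface is assumed to expose these two methods, as NewMsgCache returning it suggests, and the helper name is hypothetical.
package redis
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
)
// messageCacheExample reads three seqs (misses fall back to FindSeqs on the
// message database and are then cached), then drops one entry after an edit.
func messageCacheExample(ctx context.Context, mc cache.MsgCache, conversationID string) error {
msgs, err := mc.GetMessageBySeqs(ctx, conversationID, []int64{1, 2, 3})
if err != nil {
return err
}
_ = msgs // use the messages here
return mc.DelMessageBySeqs(ctx, conversationID, []int64{2})
}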

161
pkg/common/storage/cache/redis/online.go vendored Normal file
View File

@@ -0,0 +1,161 @@
package redis
import (
"context"
"fmt"
"strconv"
"strings"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/cachekey"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/mcache"
"git.imall.cloud/openim/protocol/constant"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/redis/go-redis/v9"
)
func NewUserOnline(rdb redis.UniversalClient) cache.OnlineCache {
if rdb == nil || config.Standalone() {
return mcache.NewOnlineCache()
}
return &userOnline{
rdb: rdb,
expire: cachekey.OnlineExpire,
channelName: cachekey.OnlineChannel,
}
}
type userOnline struct {
rdb redis.UniversalClient
expire time.Duration
channelName string
}
func (s *userOnline) getUserOnlineKey(userID string) string {
return cachekey.GetOnlineKey(userID)
}
func (s *userOnline) GetOnline(ctx context.Context, userID string) ([]int32, error) {
members, err := s.rdb.ZRangeByScore(ctx, s.getUserOnlineKey(userID), &redis.ZRangeBy{
Min: strconv.FormatInt(time.Now().Unix(), 10),
Max: "+inf",
}).Result()
if err != nil {
return nil, errs.Wrap(err)
}
platformIDs := make([]int32, 0, len(members))
for _, member := range members {
val, err := strconv.Atoi(member)
if err != nil {
return nil, errs.Wrap(err)
}
platformIDs = append(platformIDs, int32(val))
}
return platformIDs, nil
}
func (s *userOnline) GetAllOnlineUsers(ctx context.Context, cursor uint64) (map[string][]int32, uint64, error) {
result := make(map[string][]int32)
keys, nextCursor, err := s.rdb.Scan(ctx, cursor, fmt.Sprintf("%s*", cachekey.OnlineKey), constant.ParamMaxLength).Result()
if err != nil {
return nil, 0, err
}
for _, key := range keys {
userID := cachekey.GetOnlineKeyUserID(key)
strValues, err := s.rdb.ZRange(ctx, key, 0, -1).Result()
if err != nil {
return nil, 0, err
}
values := make([]int32, 0, len(strValues))
for _, value := range strValues {
intValue, err := strconv.Atoi(value)
if err != nil {
return nil, 0, errs.Wrap(err)
}
values = append(values, int32(intValue))
}
result[userID] = values
}
return result, nextCursor, nil
}
func (s *userOnline) SetUserOnline(ctx context.Context, userID string, online, offline []int32) error {
// Use a Lua script to atomically update the user's online status and the online user count cache.
script := `
local key = KEYS[1]
local countKey = KEYS[2]
local expire = tonumber(ARGV[1])
local now = ARGV[2]
local score = ARGV[3]
local offlineLen = tonumber(ARGV[4])
redis.call("ZREMRANGEBYSCORE", key, "-inf", now)
for i = 5, offlineLen+4 do
redis.call("ZREM", key, ARGV[i])
end
local before = redis.call("ZCARD", key)
for i = 5+offlineLen, #ARGV do
redis.call("ZADD", key, score, ARGV[i])
end
redis.call("EXPIRE", key, expire)
local after = redis.call("ZCARD", key)
local current = redis.call("GET", countKey)
if not current then
current = 0
else
current = tonumber(current)
end
if before == 0 and after > 0 then
redis.call("SET", countKey, current + 1)
elseif before > 0 and after == 0 then
local next = current - 1
if next < 0 then
next = 0
end
redis.call("SET", countKey, next)
end
if before ~= after then
local members = redis.call("ZRANGE", key, 0, -1)
table.insert(members, "1")
return members
else
return {"0"}
end
`
now := time.Now()
argv := make([]any, 0, 4+len(online)+len(offline))
argv = append(argv, int32(s.expire/time.Second), now.Unix(), now.Add(s.expire).Unix(), int32(len(offline)))
for _, platformID := range offline {
argv = append(argv, platformID)
}
for _, platformID := range online {
argv = append(argv, platformID)
}
keys := []string{s.getUserOnlineKey(userID), cachekey.OnlineUserCountKey}
platformIDs, err := s.rdb.Eval(ctx, script, keys, argv).StringSlice()
if err != nil {
log.ZError(ctx, "redis SetUserOnline", err, "userID", userID, "online", online, "offline", offline)
return err
}
if len(platformIDs) == 0 {
return errs.ErrInternalServer.WrapMsg("SetUserOnline redis lua invalid return value")
}
if platformIDs[len(platformIDs)-1] != "0" {
log.ZDebug(ctx, "redis SetUserOnline push", "userID", userID, "online", online, "offline", offline, "platformIDs", platformIDs[:len(platformIDs)-1])
platformIDs[len(platformIDs)-1] = userID
msg := strings.Join(platformIDs, ":")
if err := s.rdb.Publish(ctx, s.channelName, msg).Err(); err != nil {
return errs.Wrap(err)
}
} else {
log.ZDebug(ctx, "redis SetUserOnline not push", "userID", userID, "online", online, "offline", offline)
}
return nil
}
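
For subscribers of the channel above, a hedged sketch of decoding the published payload: SetUserOnline joins the user's current platformIDs with ":" and appends the userID as the last element (a user with no remaining platforms publishes just the userID). The helper name is hypothetical.
package redis
import (
"strconv"
"strings"
"github.com/openimsdk/tools/errs"
)
// parseOnlinePayload splits a payload such as "1:3:a123456" into the userID
// and its online platformIDs.
func parseOnlinePayload(payload string) (userID string, platformIDs []int32, err error) {
parts := strings.Split(payload, ":")
if len(parts) == 0 || parts[len(parts)-1] == "" {
return "", nil, errs.New("empty online payload")
}
userID = parts[len(parts)-1]
for _, p := range parts[:len(parts)-1] {
id, err := strconv.Atoi(p)
if err != nil {
return "", nil, errs.WrapMsg(err, "invalid platformID in online payload", "payload", payload)
}
platformIDs = append(platformIDs, int32(id))
}
return userID, platformIDs, nil
}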

View File

@@ -0,0 +1,149 @@
package redis
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/cachekey"
"git.imall.cloud/openim/protocol/constant"
"github.com/openimsdk/tools/errs"
"github.com/redis/go-redis/v9"
)
const onlineUserCountHistorySeparator = ":"
// OnlineUserCountSample is one historical sample of the online user count.
type OnlineUserCountSample struct {
// Timestamp is the sample time as a millisecond timestamp.
Timestamp int64
// Count is the sampled online user count.
Count int64
}
// GetOnlineUserCount reads the cached online user count.
func GetOnlineUserCount(ctx context.Context, rdb redis.UniversalClient) (int64, error) {
if rdb == nil {
return 0, errs.ErrInternalServer.WrapMsg("redis client is nil")
}
val, err := rdb.Get(ctx, cachekey.OnlineUserCountKey).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return 0, err
}
return 0, errs.Wrap(err)
}
count, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return 0, errs.WrapMsg(err, "parse online user count failed")
}
return count, nil
}
// RefreshOnlineUserCount recounts currently online users and refreshes the cached count.
func RefreshOnlineUserCount(ctx context.Context, rdb redis.UniversalClient) (int64, error) {
if rdb == nil {
return 0, errs.ErrInternalServer.WrapMsg("redis client is nil")
}
var (
cursor uint64
total int64
)
now := strconv.FormatInt(time.Now().Unix(), 10)
for {
keys, nextCursor, err := rdb.Scan(ctx, cursor, fmt.Sprintf("%s*", cachekey.OnlineKey), constant.ParamMaxLength).Result()
if err != nil {
return 0, errs.Wrap(err)
}
for _, key := range keys {
count, err := rdb.ZCount(ctx, key, now, "+inf").Result()
if err != nil {
return 0, errs.Wrap(err)
}
if count > 0 {
total++
}
}
cursor = nextCursor
if cursor == 0 {
break
}
}
if err := rdb.Set(ctx, cachekey.OnlineUserCountKey, total, 0).Err(); err != nil {
return 0, errs.Wrap(err)
}
return total, nil
}
// AppendOnlineUserCountHistory appends one online user count sample to the history.
func AppendOnlineUserCountHistory(ctx context.Context, rdb redis.UniversalClient, timestamp int64, count int64) error {
if rdb == nil {
return errs.ErrInternalServer.WrapMsg("redis client is nil")
}
if timestamp <= 0 {
return errs.ErrArgs.WrapMsg("invalid timestamp")
}
member := fmt.Sprintf("%d%s%d", timestamp, onlineUserCountHistorySeparator, count)
if err := rdb.ZAdd(ctx, cachekey.OnlineUserCountHistoryKey, redis.Z{
Score: float64(timestamp),
Member: member,
}).Err(); err != nil {
return errs.Wrap(err)
}
// Trim old samples so the history does not grow without bound.
retentionMs := int64(cachekey.OnlineUserCountHistoryRetention / time.Millisecond)
cutoff := timestamp - retentionMs
if cutoff > 0 {
if err := rdb.ZRemRangeByScore(ctx, cachekey.OnlineUserCountHistoryKey, "0", strconv.FormatInt(cutoff, 10)).Err(); err != nil {
return errs.Wrap(err)
}
}
return nil
}
// GetOnlineUserCountHistory reads online user count samples within [startTime, endTime].
func GetOnlineUserCountHistory(ctx context.Context, rdb redis.UniversalClient, startTime int64, endTime int64) ([]OnlineUserCountSample, error) {
if rdb == nil {
return nil, errs.ErrInternalServer.WrapMsg("redis client is nil")
}
if startTime <= 0 || endTime <= 0 || endTime <= startTime {
return nil, nil
}
// Use endTime as the upper bound so samples recorded exactly at endTime are included.
values, err := rdb.ZRangeByScore(ctx, cachekey.OnlineUserCountHistoryKey, &redis.ZRangeBy{
Min: strconv.FormatInt(startTime, 10),
Max: strconv.FormatInt(endTime, 10),
}).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, nil
}
return nil, errs.Wrap(err)
}
if len(values) == 0 {
return nil, nil
}
samples := make([]OnlineUserCountSample, 0, len(values))
for _, val := range values {
parts := strings.SplitN(val, onlineUserCountHistorySeparator, 2)
if len(parts) != 2 {
continue
}
ts, err := strconv.ParseInt(parts[0], 10, 64)
if err != nil {
continue
}
cnt, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
continue
}
samples = append(samples, OnlineUserCountSample{
Timestamp: ts,
Count: cnt,
})
}
return samples, nil
}
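
A minimal sketch (illustrative only) of how a periodic job might combine the helpers above: recount online users, record a sample, then read the last hour of history back. The one-hour window and the helper name are assumptions, not package defaults.
package redis
import (
"context"
"time"
"github.com/redis/go-redis/v9"
)
// sampleOnlineUserCount refreshes the count, appends a sample at the current
// millisecond timestamp, and returns the samples from the past hour.
func sampleOnlineUserCount(ctx context.Context, rdb redis.UniversalClient) ([]OnlineUserCountSample, error) {
count, err := RefreshOnlineUserCount(ctx, rdb)
if err != nil {
return nil, err
}
now := time.Now().UnixMilli()
if err := AppendOnlineUserCountHistory(ctx, rdb, now, count); err != nil {
return nil, err
}
return GetOnlineUserCountHistory(ctx, rdb, now-time.Hour.Milliseconds(), now)
}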

View File

@@ -0,0 +1,52 @@
package redis
import (
"context"
"testing"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"github.com/openimsdk/tools/db/redisutil"
)
/*
address: [ 172.16.8.48:7001, 172.16.8.48:7002, 172.16.8.48:7003, 172.16.8.48:7004, 172.16.8.48:7005, 172.16.8.48:7006 ]
username:
password: passwd123
clusterMode: true
db: 0
maxRetry: 10
*/
func TestName111111(t *testing.T) {
conf := config.Redis{
Address: []string{
"172.16.8.124:7001",
"172.16.8.124:7002",
"172.16.8.124:7003",
"172.16.8.124:7004",
"172.16.8.124:7005",
"172.16.8.124:7006",
},
RedisMode: "cluster",
Password: "passwd123",
//Address: []string{"localhost:16379"},
//Password: "openIM123",
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*1000)
defer cancel()
rdb, err := redisutil.NewRedisClient(ctx, conf.Build())
if err != nil {
panic(err)
}
online := NewUserOnline(rdb)
userID := "a123456"
t.Log(online.GetOnline(ctx, userID))
t.Log(online.SetUserOnline(ctx, userID, []int32{1, 2, 3, 4}, nil))
t.Log(online.GetOnline(ctx, userID))
}
func TestName111(t *testing.T) {
}

View File

@@ -0,0 +1,211 @@
package redis
import (
"context"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/redis/go-redis/v9"
"golang.org/x/sync/errgroup"
)
const (
defaultBatchSize = 50
defaultConcurrentLimit = 3
)
// RedisShardManager is a class for sharding and processing keys
type RedisShardManager struct {
redisClient redis.UniversalClient
config *Config
}
type Config struct {
batchSize int
continueOnError bool
concurrentLimit int
}
// Option is a function type for configuring Config
type Option func(c *Config)
//// NewRedisShardManager creates a new RedisShardManager instance
//func NewRedisShardManager(redisClient redis.UniversalClient, opts ...Option) *RedisShardManager {
// config := &Config{
// batchSize: defaultBatchSize, // Default batch size is 50 keys
// continueOnError: false,
// concurrentLimit: defaultConcurrentLimit, // Default concurrent limit is 3
// }
// for _, opt := range opts {
// opt(config)
// }
// rsm := &RedisShardManager{
// redisClient: redisClient,
// config: config,
// }
// return rsm
//}
//
//// WithBatchSize sets the number of keys to process per batch
//func WithBatchSize(size int) Option {
// return func(c *Config) {
// c.batchSize = size
// }
//}
//
//// WithContinueOnError sets whether to continue processing on error
//func WithContinueOnError(continueOnError bool) Option {
// return func(c *Config) {
// c.continueOnError = continueOnError
// }
//}
//
//// WithConcurrentLimit sets the concurrency limit
//func WithConcurrentLimit(limit int) Option {
// return func(c *Config) {
// c.concurrentLimit = limit
// }
//}
//
//// ProcessKeysBySlot groups keys by their Redis cluster hash slots and processes them using the provided function.
//func (rsm *RedisShardManager) ProcessKeysBySlot(
// ctx context.Context,
// keys []string,
// processFunc func(ctx context.Context, slot int64, keys []string) error,
//) error {
//
// // Group keys by slot
// slots, err := groupKeysBySlot(ctx, rsm.redisClient, keys)
// if err != nil {
// return err
// }
//
// g, ctx := errgroup.WithContext(ctx)
// g.SetLimit(rsm.config.concurrentLimit)
//
// // Process keys in each slot using the provided function
// for slot, singleSlotKeys := range slots {
// batches := splitIntoBatches(singleSlotKeys, rsm.config.batchSize)
// for _, batch := range batches {
// slot, batch := slot, batch // Avoid closure capture issue
// g.Go(func() error {
// err := processFunc(ctx, slot, batch)
// if err != nil {
// log.ZWarn(ctx, "Batch processFunc failed", err, "slot", slot, "keys", batch)
// if !rsm.config.continueOnError {
// return err
// }
// }
// return nil
// })
// }
// }
//
// if err := g.Wait(); err != nil {
// return err
// }
// return nil
//}
// groupKeysBySlot groups keys by their Redis cluster hash slots.
func groupKeysBySlot(ctx context.Context, redisClient redis.UniversalClient, keys []string) (map[int64][]string, error) {
slots := make(map[int64][]string)
clusterClient, isCluster := redisClient.(*redis.ClusterClient)
if isCluster && len(keys) > 1 {
pipe := clusterClient.Pipeline()
cmds := make([]*redis.IntCmd, len(keys))
for i, key := range keys {
cmds[i] = pipe.ClusterKeySlot(ctx, key)
}
_, err := pipe.Exec(ctx)
if err != nil {
return nil, errs.WrapMsg(err, "get slot err")
}
for i, cmd := range cmds {
slot, err := cmd.Result()
if err != nil {
log.ZWarn(ctx, "some key get slot err", err, "key", keys[i])
return nil, errs.WrapMsg(err, "get slot err", "key", keys[i])
}
slots[slot] = append(slots[slot], keys[i])
}
} else {
// If not a cluster client, put all keys in the same slot (0)
slots[0] = keys
}
return slots, nil
}
// splitIntoBatches splits keys into batches of the specified size
func splitIntoBatches(keys []string, batchSize int) [][]string {
var batches [][]string
for batchSize < len(keys) {
keys, batches = keys[batchSize:], append(batches, keys[0:batchSize:batchSize])
}
return append(batches, keys)
}
// ProcessKeysBySlot groups keys by their Redis cluster hash slots and processes them using the provided function.
func ProcessKeysBySlot(
ctx context.Context,
redisClient redis.UniversalClient,
keys []string,
processFunc func(ctx context.Context, slot int64, keys []string) error,
opts ...Option,
) error {
config := &Config{
batchSize: defaultBatchSize,
continueOnError: false,
concurrentLimit: defaultConcurrentLimit,
}
for _, opt := range opts {
opt(config)
}
// Group keys by slot
slots, err := groupKeysBySlot(ctx, redisClient, keys)
if err != nil {
return err
}
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(config.concurrentLimit)
// Process keys in each slot using the provided function
for slot, singleSlotKeys := range slots {
batches := splitIntoBatches(singleSlotKeys, config.batchSize)
for _, batch := range batches {
slot, batch := slot, batch // Avoid closure capture issue
g.Go(func() error {
err := processFunc(ctx, slot, batch)
if err != nil {
log.ZWarn(ctx, "Batch processFunc failed", err, "slot", slot, "keys", batch)
if !config.continueOnError {
return err
}
}
return nil
})
}
}
if err := g.Wait(); err != nil {
return err
}
return nil
}
func DeleteCacheBySlot(ctx context.Context, rcClient *rocksCacheClient, keys []string) error {
switch len(keys) {
case 0:
return nil
case 1:
return rcClient.GetClient().TagAsDeletedBatch2(ctx, keys)
default:
return ProcessKeysBySlot(ctx, rcClient.GetRedis(), keys, func(ctx context.Context, slot int64, keys []string) error {
return rcClient.GetClient().TagAsDeletedBatch2(ctx, keys)
})
}
}
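
A short sketch (illustrative only) of calling ProcessKeysBySlot directly with the Lua delete helper from this package: every batch handed to the callback contains keys from a single cluster slot, so one script call per batch stays valid in cluster mode. The helper name is hypothetical.
package redis
import (
"context"
"github.com/redis/go-redis/v9"
)
// deleteKeysBySlot deletes an arbitrary key set, one slot-local batch at a time.
func deleteKeysBySlot(ctx context.Context, rdb redis.UniversalClient, keys []string) error {
return ProcessKeysBySlot(ctx, rdb, keys, func(ctx context.Context, slot int64, batch []string) error {
return LuaDeleteBatch(ctx, rdb, batch)
})
}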

95
pkg/common/storage/cache/redis/s3.go vendored Normal file
View File

@@ -0,0 +1,95 @@
package redis
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/cachekey"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/s3"
"github.com/openimsdk/tools/s3/cont"
"github.com/redis/go-redis/v9"
)
func NewObjectCacheRedis(rdb redis.UniversalClient, objDB database.ObjectInfo) cache.ObjectCache {
rc := newRocksCacheClient(rdb)
return &objectCacheRedis{
BatchDeleter: rc.GetBatchDeleter(),
rcClient: rc,
expireTime: time.Hour * 12,
objDB: objDB,
}
}
type objectCacheRedis struct {
cache.BatchDeleter
objDB database.ObjectInfo
rcClient *rocksCacheClient
expireTime time.Duration
}
func (g *objectCacheRedis) getObjectKey(engine string, name string) string {
return cachekey.GetObjectKey(engine, name)
}
func (g *objectCacheRedis) CloneObjectCache() cache.ObjectCache {
return &objectCacheRedis{
BatchDeleter: g.BatchDeleter.Clone(),
rcClient: g.rcClient,
expireTime: g.expireTime,
objDB: g.objDB,
}
}
func (g *objectCacheRedis) DelObjectName(engine string, names ...string) cache.ObjectCache {
objectCache := g.CloneObjectCache()
keys := make([]string, 0, len(names))
for _, name := range names {
keys = append(keys, g.getObjectKey(name, engine))
}
objectCache.AddKeys(keys...)
return objectCache
}
func (g *objectCacheRedis) GetName(ctx context.Context, engine string, name string) (*model.Object, error) {
return getCache(ctx, g.rcClient, g.getObjectKey(name, engine), g.expireTime, func(ctx context.Context) (*model.Object, error) {
return g.objDB.Take(ctx, engine, name)
})
}
func NewS3Cache(rdb redis.UniversalClient, s3 s3.Interface) cont.S3Cache {
rc := newRocksCacheClient(rdb)
return &s3CacheRedis{
BatchDeleter: rc.GetBatchDeleter(),
rcClient: rc,
expireTime: time.Hour * 12,
s3: s3,
}
}
type s3CacheRedis struct {
cache.BatchDeleter
s3 s3.Interface
rcClient *rocksCacheClient
expireTime time.Duration
}
func (g *s3CacheRedis) getS3Key(engine string, name string) string {
return cachekey.GetS3Key(engine, name)
}
func (g *s3CacheRedis) DelS3Key(ctx context.Context, engine string, keys ...string) error {
ks := make([]string, 0, len(keys))
for _, key := range keys {
ks = append(ks, g.getS3Key(engine, key))
}
return g.BatchDeleter.ExecDelWithKeys(ctx, ks)
}
func (g *s3CacheRedis) GetKey(ctx context.Context, engine string, name string) (*s3.ObjectInfo, error) {
return getCache(ctx, g.rcClient, g.getS3Key(engine, name), g.expireTime, func(ctx context.Context) (*s3.ObjectInfo, error) {
return g.s3.StatObject(ctx, name)
})
}

View File

@@ -0,0 +1,521 @@
package redis
import (
"context"
"errors"
"fmt"
"strconv"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/cachekey"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/mcache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/msgprocessor"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/redis/go-redis/v9"
)
func NewSeqConversationCacheRedis(rdb redis.UniversalClient, mgo database.SeqConversation) cache.SeqConversationCache {
if rdb == nil {
return mcache.NewSeqConversationCache(mgo)
}
return &seqConversationCacheRedis{
mgo: mgo,
lockTime: time.Second * 3,
dataTime: time.Hour * 24 * 365,
minSeqExpireTime: time.Hour,
rcClient: newRocksCacheClient(rdb),
}
}
type seqConversationCacheRedis struct {
mgo database.SeqConversation
rcClient *rocksCacheClient
lockTime time.Duration
dataTime time.Duration
minSeqExpireTime time.Duration
}
func (s *seqConversationCacheRedis) getMinSeqKey(conversationID string) string {
return cachekey.GetMallocMinSeqKey(conversationID)
}
func (s *seqConversationCacheRedis) SetMinSeq(ctx context.Context, conversationID string, seq int64) error {
return s.SetMinSeqs(ctx, map[string]int64{conversationID: seq})
}
func (s *seqConversationCacheRedis) GetMinSeq(ctx context.Context, conversationID string) (int64, error) {
return getCache(ctx, s.rcClient, s.getMinSeqKey(conversationID), s.minSeqExpireTime, func(ctx context.Context) (int64, error) {
return s.mgo.GetMinSeq(ctx, conversationID)
})
}
func (s *seqConversationCacheRedis) getSingleMaxSeq(ctx context.Context, conversationID string) (map[string]int64, error) {
seq, err := s.GetMaxSeq(ctx, conversationID)
if err != nil {
return nil, err
}
return map[string]int64{conversationID: seq}, nil
}
func (s *seqConversationCacheRedis) getSingleMaxSeqWithTime(ctx context.Context, conversationID string) (map[string]database.SeqTime, error) {
seq, err := s.GetMaxSeqWithTime(ctx, conversationID)
if err != nil {
return nil, err
}
return map[string]database.SeqTime{conversationID: seq}, nil
}
func (s *seqConversationCacheRedis) batchGetMaxSeq(ctx context.Context, keys []string, keyConversationID map[string]string, seqs map[string]int64) error {
result := make([]*redis.StringCmd, len(keys))
pipe := s.rcClient.GetRedis().Pipeline()
for i, key := range keys {
result[i] = pipe.HGet(ctx, key, "CURR")
}
if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
return errs.Wrap(err)
}
var notFoundKey []string
for i, r := range result {
req, err := r.Int64()
if err == nil {
seqs[keyConversationID[keys[i]]] = req
} else if errors.Is(err, redis.Nil) {
notFoundKey = append(notFoundKey, keys[i])
} else {
return errs.Wrap(err)
}
}
for _, key := range notFoundKey {
conversationID := keyConversationID[key]
seq, err := s.GetMaxSeq(ctx, conversationID)
if err != nil {
return err
}
seqs[conversationID] = seq
}
return nil
}
func (s *seqConversationCacheRedis) batchGetMaxSeqWithTime(ctx context.Context, keys []string, keyConversationID map[string]string, seqs map[string]database.SeqTime) error {
result := make([]*redis.SliceCmd, len(keys))
pipe := s.rcClient.GetRedis().Pipeline()
for i, key := range keys {
result[i] = pipe.HMGet(ctx, key, "CURR", "TIME")
}
if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
return errs.Wrap(err)
}
var notFoundKey []string
for i, r := range result {
val, err := r.Result()
if err != nil && !errors.Is(err, redis.Nil) {
return errs.Wrap(err)
}
if len(val) != 2 {
return errs.New("batchGetMaxSeqWithTime invalid result", "key", keys[i], "res", val)
}
if val[0] == nil {
notFoundKey = append(notFoundKey, keys[i])
continue
}
seq, err := s.parseInt64(val[0])
if err != nil {
return err
}
mill, err := s.parseInt64(val[1])
if err != nil {
return err
}
seqs[keyConversationID[keys[i]]] = database.SeqTime{Seq: seq, Time: mill}
}
for _, key := range notFoundKey {
conversationID := keyConversationID[key]
seq, err := s.GetMaxSeqWithTime(ctx, conversationID)
if err != nil {
return err
}
seqs[conversationID] = seq
}
return nil
}
func (s *seqConversationCacheRedis) GetMaxSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error) {
switch len(conversationIDs) {
case 0:
return map[string]int64{}, nil
case 1:
return s.getSingleMaxSeq(ctx, conversationIDs[0])
}
keys := make([]string, 0, len(conversationIDs))
keyConversationID := make(map[string]string, len(conversationIDs))
for _, conversationID := range conversationIDs {
key := s.getSeqMallocKey(conversationID)
if _, ok := keyConversationID[key]; ok {
continue
}
keys = append(keys, key)
keyConversationID[key] = conversationID
}
if len(keys) == 1 {
return s.getSingleMaxSeq(ctx, conversationIDs[0])
}
slotKeys, err := groupKeysBySlot(ctx, s.rcClient.GetRedis(), keys)
if err != nil {
return nil, err
}
seqs := make(map[string]int64, len(conversationIDs))
for _, keys := range slotKeys {
if err := s.batchGetMaxSeq(ctx, keys, keyConversationID, seqs); err != nil {
return nil, err
}
}
return seqs, nil
}
func (s *seqConversationCacheRedis) GetMaxSeqsWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error) {
switch len(conversationIDs) {
case 0:
return map[string]database.SeqTime{}, nil
case 1:
return s.getSingleMaxSeqWithTime(ctx, conversationIDs[0])
}
keys := make([]string, 0, len(conversationIDs))
keyConversationID := make(map[string]string, len(conversationIDs))
for _, conversationID := range conversationIDs {
key := s.getSeqMallocKey(conversationID)
if _, ok := keyConversationID[key]; ok {
continue
}
keys = append(keys, key)
keyConversationID[key] = conversationID
}
if len(keys) == 1 {
return s.getSingleMaxSeqWithTime(ctx, conversationIDs[0])
}
slotKeys, err := groupKeysBySlot(ctx, s.rcClient.GetRedis(), keys)
if err != nil {
return nil, err
}
seqs := make(map[string]database.SeqTime, len(conversationIDs))
for _, keys := range slotKeys {
if err := s.batchGetMaxSeqWithTime(ctx, keys, keyConversationID, seqs); err != nil {
return nil, err
}
}
return seqs, nil
}
func (s *seqConversationCacheRedis) getSeqMallocKey(conversationID string) string {
return cachekey.GetMallocSeqKey(conversationID)
}
func (s *seqConversationCacheRedis) setSeq(ctx context.Context, key string, owner int64, currSeq int64, lastSeq int64, mill int64) (int64, error) {
if lastSeq < currSeq {
return 0, errs.New("lastSeq must be greater than currSeq")
}
// 0: success
// 1: success the lock has expired, but has not been locked by anyone else
// 2: already locked, but not by yourself
script := `
local key = KEYS[1]
local lockValue = ARGV[1]
local dataSecond = ARGV[2]
local curr_seq = tonumber(ARGV[3])
local last_seq = tonumber(ARGV[4])
local mallocTime = ARGV[5]
if redis.call("EXISTS", key) == 0 then
redis.call("HSET", key, "CURR", curr_seq, "LAST", last_seq, "TIME", mallocTime)
redis.call("EXPIRE", key, dataSecond)
return 1
end
if redis.call("HGET", key, "LOCK") ~= lockValue then
return 2
end
redis.call("HDEL", key, "LOCK")
redis.call("HSET", key, "CURR", curr_seq, "LAST", last_seq, "TIME", mallocTime)
redis.call("EXPIRE", key, dataSecond)
return 0
`
result, err := s.rcClient.GetRedis().Eval(ctx, script, []string{key}, owner, int64(s.dataTime/time.Second), currSeq, lastSeq, mill).Int64()
if err != nil {
return 0, errs.Wrap(err)
}
return result, nil
}
// malloc: size == 0 only reads the current seq; size > 0 allocates a block of seqs.
func (s *seqConversationCacheRedis) malloc(ctx context.Context, key string, size int64) ([]int64, error) {
// 0: success
// 1: need to obtain and lock
// 2: already locked
// 3: exceeded the maximum value and locked
script := `
local key = KEYS[1]
local size = tonumber(ARGV[1])
local lockSecond = ARGV[2]
local dataSecond = ARGV[3]
local mallocTime = ARGV[4]
local result = {}
if redis.call("EXISTS", key) == 0 then
local lockValue = math.random(0, 999999999)
redis.call("HSET", key, "LOCK", lockValue)
redis.call("EXPIRE", key, lockSecond)
table.insert(result, 1)
table.insert(result, lockValue)
table.insert(result, mallocTime)
return result
end
if redis.call("HEXISTS", key, "LOCK") == 1 then
table.insert(result, 2)
return result
end
local curr_seq = tonumber(redis.call("HGET", key, "CURR"))
local last_seq = tonumber(redis.call("HGET", key, "LAST"))
if size == 0 then
redis.call("EXPIRE", key, dataSecond)
table.insert(result, 0)
table.insert(result, curr_seq)
table.insert(result, last_seq)
local setTime = redis.call("HGET", key, "TIME")
if setTime then
table.insert(result, setTime)
else
table.insert(result, 0)
end
return result
end
local max_seq = curr_seq + size
if max_seq > last_seq then
local lockValue = math.random(0, 999999999)
redis.call("HSET", key, "LOCK", lockValue)
redis.call("HSET", key, "CURR", last_seq)
redis.call("HSET", key, "TIME", mallocTime)
redis.call("EXPIRE", key, lockSecond)
table.insert(result, 3)
table.insert(result, curr_seq)
table.insert(result, last_seq)
table.insert(result, lockValue)
table.insert(result, mallocTime)
return result
end
redis.call("HSET", key, "CURR", max_seq)
redis.call("HSET", key, "TIME", ARGV[4])
redis.call("EXPIRE", key, dataSecond)
table.insert(result, 0)
table.insert(result, curr_seq)
table.insert(result, last_seq)
table.insert(result, mallocTime)
return result
`
result, err := s.rcClient.GetRedis().Eval(ctx, script, []string{key}, size, int64(s.lockTime/time.Second), int64(s.dataTime/time.Second), time.Now().UnixMilli()).Int64Slice()
if err != nil {
return nil, errs.Wrap(err)
}
return result, nil
}
func (s *seqConversationCacheRedis) wait(ctx context.Context) error {
timer := time.NewTimer(time.Second / 4)
defer timer.Stop()
select {
case <-timer.C:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func (s *seqConversationCacheRedis) setSeqRetry(ctx context.Context, key string, owner int64, currSeq int64, lastSeq int64, mill int64) {
for i := 0; i < 10; i++ {
state, err := s.setSeq(ctx, key, owner, currSeq, lastSeq, mill)
if err != nil {
log.ZError(ctx, "set seq cache failed", err, "key", key, "owner", owner, "currSeq", currSeq, "lastSeq", lastSeq, "count", i+1)
if err := s.wait(ctx); err != nil {
return
}
continue
}
switch state {
case 0: // ideal state
case 1:
log.ZWarn(ctx, "set seq cache lock not found", nil, "key", key, "owner", owner, "currSeq", currSeq, "lastSeq", lastSeq)
case 2:
log.ZWarn(ctx, "set seq cache lock to be held by someone else", nil, "key", key, "owner", owner, "currSeq", currSeq, "lastSeq", lastSeq)
default:
log.ZError(ctx, "set seq cache lock unknown state", nil, "key", key, "owner", owner, "currSeq", currSeq, "lastSeq", lastSeq)
}
return
}
log.ZError(ctx, "set seq cache retrying still failed", nil, "key", key, "owner", owner, "currSeq", currSeq, "lastSeq", lastSeq)
}
func (s *seqConversationCacheRedis) getMallocSize(conversationID string, size int64) int64 {
if size == 0 {
return 0
}
var basicSize int64
if msgprocessor.IsGroupConversationID(conversationID) {
basicSize = 100
} else {
basicSize = 50
}
basicSize += size
return basicSize
}
func (s *seqConversationCacheRedis) Malloc(ctx context.Context, conversationID string, size int64) (int64, error) {
seq, _, err := s.mallocTime(ctx, conversationID, size)
return seq, err
}
func (s *seqConversationCacheRedis) mallocTime(ctx context.Context, conversationID string, size int64) (int64, int64, error) {
if size < 0 {
return 0, 0, errs.New("size must not be negative")
}
key := s.getSeqMallocKey(conversationID)
for i := 0; i < 10; i++ {
states, err := s.malloc(ctx, key, size)
if err != nil {
return 0, 0, err
}
switch states[0] {
case 0: // success
return states[1], states[3], nil
case 1: // not found
mallocSize := s.getMallocSize(conversationID, size)
seq, err := s.mgo.Malloc(ctx, conversationID, mallocSize)
if err != nil {
return 0, 0, err
}
s.setSeqRetry(ctx, key, states[1], seq+size, seq+mallocSize, states[2])
return seq, 0, nil
case 2: // locked
if err := s.wait(ctx); err != nil {
return 0, 0, err
}
continue
case 3: // exceeded cache max value
currSeq := states[1]
lastSeq := states[2]
mill := states[4]
mallocSize := s.getMallocSize(conversationID, size)
seq, err := s.mgo.Malloc(ctx, conversationID, mallocSize)
if err != nil {
return 0, 0, err
}
if lastSeq == seq {
s.setSeqRetry(ctx, key, states[3], currSeq+size, seq+mallocSize, mill)
return currSeq, states[4], nil
} else {
log.ZWarn(ctx, "malloc seq not equal cache last seq", nil, "conversationID", conversationID, "currSeq", currSeq, "lastSeq", lastSeq, "mallocSeq", seq)
s.setSeqRetry(ctx, key, states[3], seq+size, seq+mallocSize, mill)
return seq, mill, nil
}
default:
log.ZError(ctx, "malloc seq unknown state", nil, "state", states[0], "conversationID", conversationID, "size", size)
return 0, 0, errs.New(fmt.Sprintf("unknown state: %d", states[0]))
}
}
log.ZError(ctx, "malloc seq retrying still failed", nil, "conversationID", conversationID, "size", size)
return 0, 0, errs.New("malloc seq waiting for lock timeout", "conversationID", conversationID, "size", size)
}
func (s *seqConversationCacheRedis) GetMaxSeq(ctx context.Context, conversationID string) (int64, error) {
return s.Malloc(ctx, conversationID, 0)
}
func (s *seqConversationCacheRedis) GetMaxSeqWithTime(ctx context.Context, conversationID string) (database.SeqTime, error) {
seq, mill, err := s.mallocTime(ctx, conversationID, 0)
if err != nil {
return database.SeqTime{}, err
}
return database.SeqTime{Seq: seq, Time: mill}, nil
}
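// exampleReserveSeqs is an illustrative sketch (not referenced anywhere in this
// package) of how a caller is expected to consume Malloc: the returned value is
// the seq just before the reserved range, so the newly usable seqs are
// (first, first+count]. GetMaxSeq is simply Malloc with size 0.
func exampleReserveSeqs(ctx context.Context, s *seqConversationCacheRedis, conversationID string, count int64) ([]int64, error) {
first, err := s.Malloc(ctx, conversationID, count)
if err != nil {
return nil, err
}
seqs := make([]int64, 0, count)
for i := int64(1); i <= count; i++ {
seqs = append(seqs, first+i)
}
return seqs, nil
}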
func (s *seqConversationCacheRedis) SetMinSeqs(ctx context.Context, seqs map[string]int64) error {
keys := make([]string, 0, len(seqs))
for conversationID, seq := range seqs {
keys = append(keys, s.getMinSeqKey(conversationID))
if err := s.mgo.SetMinSeq(ctx, conversationID, seq); err != nil {
return err
}
}
return DeleteCacheBySlot(ctx, s.rcClient, keys)
}
// GetCacheMaxSeqWithTime only reads existing cache entries; it does not create cache entries for conversations that are not already cached.
func (s *seqConversationCacheRedis) GetCacheMaxSeqWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error) {
if len(conversationIDs) == 0 {
return map[string]database.SeqTime{}, nil
}
key2conversationID := make(map[string]string)
keys := make([]string, 0, len(conversationIDs))
for _, conversationID := range conversationIDs {
key := s.getSeqMallocKey(conversationID)
if _, ok := key2conversationID[key]; ok {
continue
}
key2conversationID[key] = conversationID
keys = append(keys, key)
}
slotKeys, err := groupKeysBySlot(ctx, s.rcClient.GetRedis(), keys)
if err != nil {
return nil, err
}
res := make(map[string]database.SeqTime)
for _, keys := range slotKeys {
if len(keys) == 0 {
continue
}
pipe := s.rcClient.GetRedis().Pipeline()
cmds := make([]*redis.SliceCmd, 0, len(keys))
for _, key := range keys {
cmds = append(cmds, pipe.HMGet(ctx, key, "CURR", "TIME"))
}
if _, err := pipe.Exec(ctx); err != nil {
return nil, errs.Wrap(err)
}
for i, cmd := range cmds {
val, err := cmd.Result()
if err != nil {
return nil, err
}
if len(val) != 2 {
return nil, errs.WrapMsg(err, "GetCacheMaxSeqWithTime invalid result", "key", keys[i], "res", val)
}
if val[0] == nil {
continue
}
seq, err := s.parseInt64(val[0])
if err != nil {
return nil, err
}
mill, err := s.parseInt64(val[1])
if err != nil {
return nil, err
}
conversationID := key2conversationID[keys[i]]
res[conversationID] = database.SeqTime{Seq: seq, Time: mill}
}
}
return res, nil
}
func (s *seqConversationCacheRedis) parseInt64(val any) (int64, error) {
switch v := val.(type) {
case nil:
return 0, nil
case int:
return int64(v), nil
case int64:
return v, nil
case string:
res, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return 0, errs.WrapMsg(err, "invalid string not int64", "value", v)
}
return res, nil
default:
return 0, errs.New("invalid result not int64", "resType", fmt.Sprintf("%T", v), "value", v)
}
}

View File

@@ -0,0 +1,144 @@
package redis
import (
"context"
"strconv"
"sync"
"sync/atomic"
"testing"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database/mgo"
"github.com/redis/go-redis/v9"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func newTestSeq() *seqConversationCacheRedis {
mgocli, err := mongo.Connect(context.Background(), options.Client().ApplyURI("mongodb://openIM:openIM123@127.0.0.1:37017/openim_v3?maxPoolSize=100").SetConnectTimeout(5*time.Second))
if err != nil {
panic(err)
}
model, err := mgo.NewSeqConversationMongo(mgocli.Database("openim_v3"))
if err != nil {
panic(err)
}
opt := &redis.Options{
Addr: "127.0.0.1:16379",
Password: "openIM123",
DB: 1,
}
rdb := redis.NewClient(opt)
if err := rdb.Ping(context.Background()).Err(); err != nil {
panic(err)
}
return NewSeqConversationCacheRedis(rdb, model).(*seqConversationCacheRedis)
}
func TestSeq(t *testing.T) {
ts := newTestSeq()
var (
wg sync.WaitGroup
speed atomic.Int64
)
const count = 128
wg.Add(count)
for i := 0; i < count; i++ {
index := i + 1
go func() {
defer wg.Done()
var size int64 = 10
cID := strconv.Itoa(index * 1)
for i := 1; ; i++ {
//first, err := ts.mgo.Malloc(context.Background(), cID, size) // mongo
first, err := ts.Malloc(context.Background(), cID, size) // redis
if err != nil {
t.Logf("[%d-%d] %s %s", index, i, cID, err)
return
}
speed.Add(size)
_ = first
//t.Logf("[%d] %d -> %d", i, first+1, first+size)
}
}()
}
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
ticker := time.NewTicker(time.Second)
for {
select {
case <-done:
ticker.Stop()
return
case <-ticker.C:
value := speed.Swap(0)
t.Logf("speed: %d/s", value)
}
}
}
func TestDel(t *testing.T) {
ts := newTestSeq()
for i := 1; i < 100; i++ {
var size int64 = 100
first, err := ts.Malloc(context.Background(), "100", size)
if err != nil {
t.Logf("[%d] %s", i, err)
return
}
t.Logf("[%d] %d -> %d", i, first+1, first+size)
time.Sleep(time.Second)
}
}
func TestSeqMalloc(t *testing.T) {
ts := newTestSeq()
t.Log(ts.GetMaxSeq(context.Background(), "100"))
}
func TestMinSeq(t *testing.T) {
ts := newTestSeq()
t.Log(ts.GetMinSeq(context.Background(), "10000000"))
}
func TestMalloc(t *testing.T) {
ts := newTestSeq()
t.Log(ts.mallocTime(context.Background(), "10000000", 100))
}
func TestHMGET(t *testing.T) {
ts := newTestSeq()
res, err := ts.GetCacheMaxSeqWithTime(context.Background(), []string{"10000000", "123456"})
if err != nil {
panic(err)
}
t.Log(res)
}
func TestGetMaxSeqWithTime(t *testing.T) {
ts := newTestSeq()
t.Log(ts.GetMaxSeqWithTime(context.Background(), "10000000"))
}
func TestGetMaxSeqWithTime1(t *testing.T) {
ts := newTestSeq()
t.Log(ts.GetMaxSeqsWithTime(context.Background(), []string{"10000000", "12345", "111"}))
}
//
//func TestHMGET(t *testing.T) {
// ts := newTestSeq()
// res, err := ts.rdb.HMGet(context.Background(), "MALLOC_SEQ:1", "CURR", "TIME1").Result()
// if err != nil {
// panic(err)
// }
// t.Log(res)
//}

View File

@@ -0,0 +1,184 @@
package redis
import (
"context"
"strconv"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/cachekey"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"github.com/openimsdk/tools/errs"
"github.com/redis/go-redis/v9"
)
func NewSeqUserCacheRedis(rdb redis.UniversalClient, mgo database.SeqUser) cache.SeqUser {
return &seqUserCacheRedis{
mgo: mgo,
readSeqWriteRatio: 100,
expireTime: time.Hour * 24 * 7,
readExpireTime: time.Hour * 24 * 30,
rocks: newRocksCacheClient(rdb),
}
}
type seqUserCacheRedis struct {
mgo database.SeqUser
rocks *rocksCacheClient
expireTime time.Duration
readExpireTime time.Duration
readSeqWriteRatio int64
}
func (s *seqUserCacheRedis) getSeqUserMaxSeqKey(conversationID string, userID string) string {
return cachekey.GetSeqUserMaxSeqKey(conversationID, userID)
}
func (s *seqUserCacheRedis) getSeqUserMinSeqKey(conversationID string, userID string) string {
return cachekey.GetSeqUserMinSeqKey(conversationID, userID)
}
func (s *seqUserCacheRedis) getSeqUserReadSeqKey(conversationID string, userID string) string {
return cachekey.GetSeqUserReadSeqKey(conversationID, userID)
}
func (s *seqUserCacheRedis) GetUserMaxSeq(ctx context.Context, conversationID string, userID string) (int64, error) {
return getCache(ctx, s.rocks, s.getSeqUserMaxSeqKey(conversationID, userID), s.expireTime, func(ctx context.Context) (int64, error) {
return s.mgo.GetUserMaxSeq(ctx, conversationID, userID)
})
}
func (s *seqUserCacheRedis) SetUserMaxSeq(ctx context.Context, conversationID string, userID string, seq int64) error {
if err := s.mgo.SetUserMaxSeq(ctx, conversationID, userID, seq); err != nil {
return err
}
return s.rocks.GetClient().TagAsDeleted2(ctx, s.getSeqUserMaxSeqKey(conversationID, userID))
}
func (s *seqUserCacheRedis) GetUserMinSeq(ctx context.Context, conversationID string, userID string) (int64, error) {
return getCache(ctx, s.rocks, s.getSeqUserMinSeqKey(conversationID, userID), s.expireTime, func(ctx context.Context) (int64, error) {
return s.mgo.GetUserMinSeq(ctx, conversationID, userID)
})
}
func (s *seqUserCacheRedis) SetUserMinSeq(ctx context.Context, conversationID string, userID string, seq int64) error {
return s.SetUserMinSeqs(ctx, userID, map[string]int64{conversationID: seq})
}
func (s *seqUserCacheRedis) GetUserReadSeq(ctx context.Context, conversationID string, userID string) (int64, error) {
return getCache(ctx, s.rocks, s.getSeqUserReadSeqKey(conversationID, userID), s.readExpireTime, func(ctx context.Context) (int64, error) {
return s.mgo.GetUserReadSeq(ctx, conversationID, userID)
})
}
func (s *seqUserCacheRedis) SetUserReadSeq(ctx context.Context, conversationID string, userID string, seq int64) error {
if s.rocks.GetRedis() == nil {
return s.SetUserReadSeqToDB(ctx, conversationID, userID, seq)
}
dbSeq, err := s.GetUserReadSeq(ctx, conversationID, userID)
if err != nil {
return err
}
if dbSeq < seq {
if err := s.rocks.GetClient().RawSet(ctx, s.getSeqUserReadSeqKey(conversationID, userID), strconv.Itoa(int(seq)), s.readExpireTime); err != nil {
return errs.Wrap(err)
}
}
return nil
}
func (s *seqUserCacheRedis) SetUserReadSeqToDB(ctx context.Context, conversationID string, userID string, seq int64) error {
return s.mgo.SetUserReadSeq(ctx, conversationID, userID, seq)
}
func (s *seqUserCacheRedis) SetUserMinSeqs(ctx context.Context, userID string, seqs map[string]int64) error {
keys := make([]string, 0, len(seqs))
for conversationID, seq := range seqs {
if err := s.mgo.SetUserMinSeq(ctx, conversationID, userID, seq); err != nil {
return err
}
keys = append(keys, s.getSeqUserMinSeqKey(conversationID, userID))
}
return DeleteCacheBySlot(ctx, s.rocks, keys)
}
func (s *seqUserCacheRedis) setUserRedisReadSeqs(ctx context.Context, userID string, seqs map[string]int64) error {
keys := make([]string, 0, len(seqs))
keySeq := make(map[string]int64)
for conversationID, seq := range seqs {
key := s.getSeqUserReadSeqKey(conversationID, userID)
keys = append(keys, key)
keySeq[key] = seq
}
slotKeys, err := groupKeysBySlot(ctx, s.rocks.GetRedis(), keys)
if err != nil {
return err
}
for _, keys := range slotKeys {
pipe := s.rocks.GetRedis().Pipeline()
for _, key := range keys {
pipe.HSet(ctx, key, "value", strconv.FormatInt(keySeq[key], 10))
pipe.Expire(ctx, key, s.readExpireTime)
}
if _, err := pipe.Exec(ctx); err != nil {
return err
}
}
return nil
}
func (s *seqUserCacheRedis) SetUserReadSeqs(ctx context.Context, userID string, seqs map[string]int64) error {
if len(seqs) == 0 {
return nil
}
if err := s.setUserRedisReadSeqs(ctx, userID, seqs); err != nil {
return err
}
return nil
}
func (s *seqUserCacheRedis) GetUserReadSeqs(ctx context.Context, userID string, conversationIDs []string) (map[string]int64, error) {
res, err := batchGetCache2(ctx, s.rocks, s.readExpireTime, conversationIDs, func(conversationID string) string {
return s.getSeqUserReadSeqKey(conversationID, userID)
}, func(v *readSeqModel) string {
return v.ConversationID
}, func(ctx context.Context, conversationIDs []string) ([]*readSeqModel, error) {
seqs, err := s.mgo.GetUserReadSeqs(ctx, userID, conversationIDs)
if err != nil {
return nil, err
}
res := make([]*readSeqModel, 0, len(seqs))
for conversationID, seq := range seqs {
res = append(res, &readSeqModel{ConversationID: conversationID, Seq: seq})
}
return res, nil
})
if err != nil {
return nil, err
}
data := make(map[string]int64)
for _, v := range res {
data[v.ConversationID] = v.Seq
}
return data, nil
}
var _ BatchCacheCallback[string] = (*readSeqModel)(nil)
type readSeqModel struct {
ConversationID string
Seq int64
}
func (r *readSeqModel) BatchCache(conversationID string) {
r.ConversationID = conversationID
}
func (r *readSeqModel) UnmarshalJSON(bytes []byte) (err error) {
r.Seq, err = strconv.ParseInt(string(bytes), 10, 64)
return
}
func (r *readSeqModel) MarshalJSON() ([]byte, error) {
return []byte(strconv.FormatInt(r.Seq, 10)), nil
}
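// exampleMarkConversationRead is an illustrative sketch (hypothetical caller, not
// part of this package): marking a conversation as read writes the read seq to the
// Redis cache first; SetUserReadSeqs writes only to Redis, while persisting to the
// database is handled separately (e.g. via SetUserReadSeqToDB).
func exampleMarkConversationRead(ctx context.Context, c cache.SeqUser, userID, conversationID string, readSeq int64) (int64, error) {
if err := c.SetUserReadSeqs(ctx, userID, map[string]int64{conversationID: readSeq}); err != nil {
return 0, err
}
return c.GetUserReadSeq(ctx, conversationID, userID)
}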

View File

@@ -0,0 +1,112 @@
package redis
import (
"context"
"fmt"
"log"
"strconv"
"sync/atomic"
"testing"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/cachekey"
mgo2 "git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database/mgo"
"github.com/redis/go-redis/v9"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func newTestOnline() *userOnline {
opt := &redis.Options{
Addr: "172.16.8.48:16379",
Password: "openIM123",
DB: 0,
}
rdb := redis.NewClient(opt)
if err := rdb.Ping(context.Background()).Err(); err != nil {
panic(err)
}
return &userOnline{rdb: rdb, expire: time.Hour, channelName: "user_online"}
}
func TestOnline(t *testing.T) {
ts := newTestOnline()
var count atomic.Int64
for i := 0; i < 64; i++ {
go func(userID string) {
var err error
for i := 0; ; i++ {
if i%2 == 0 {
err = ts.SetUserOnline(context.Background(), userID, []int32{5, 6}, []int32{7, 8, 9})
} else {
err = ts.SetUserOnline(context.Background(), userID, []int32{1, 2, 3}, []int32{4, 5, 6})
}
if err != nil {
panic(err)
}
count.Add(1)
}
}(strconv.Itoa(10000 + i))
}
ticker := time.NewTicker(time.Second)
for range ticker.C {
t.Log(count.Swap(0))
}
}
func TestGetOnline(t *testing.T) {
ts := newTestOnline()
ctx := context.Background()
pIDs, err := ts.GetOnline(ctx, "10000")
if err != nil {
panic(err)
}
t.Log(pIDs)
}
func TestRecvOnline(t *testing.T) {
ts := newTestOnline()
ctx := context.Background()
pubsub := ts.rdb.Subscribe(ctx, cachekey.OnlineChannel)
_, err := pubsub.Receive(ctx)
if err != nil {
log.Fatalf("Could not subscribe: %v", err)
}
ch := pubsub.Channel()
for msg := range ch {
fmt.Printf("Received message from channel %s: %s\n", msg.Channel, msg.Payload)
}
}
func TestName1(t *testing.T) {
opt := &redis.Options{
Addr: "172.16.8.48:16379",
Password: "openIM123",
DB: 0,
}
rdb := redis.NewClient(opt)
mgo, err := mongo.Connect(context.Background(),
options.Client().
ApplyURI("mongodb://openIM:openIM123@172.16.8.48:37017/openim_v3?maxPoolSize=100").
SetConnectTimeout(5*time.Second))
if err != nil {
panic(err)
}
model, err := mgo2.NewSeqUserMongo(mgo.Database("openim_v3"))
if err != nil {
panic(err)
}
seq := NewSeqUserCacheRedis(rdb, model)
res, err := seq.GetUserReadSeqs(context.Background(), "2110910952", []string{"sg_345762580", "2000", "3000"})
if err != nil {
panic(err)
}
t.Log(res)
}

90
pkg/common/storage/cache/redis/third.go vendored Normal file
View File

@@ -0,0 +1,90 @@
package redis
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/cachekey"
"github.com/openimsdk/tools/errs"
"github.com/redis/go-redis/v9"
)
func NewThirdCache(rdb redis.UniversalClient) cache.ThirdCache {
return &thirdCache{rdb: rdb}
}
type thirdCache struct {
rdb redis.UniversalClient
}
func (c *thirdCache) getGetuiTokenKey() string {
return cachekey.GetGetuiTokenKey()
}
func (c *thirdCache) getGetuiTaskIDKey() string {
return cachekey.GetGetuiTaskIDKey()
}
func (c *thirdCache) getUserBadgeUnreadCountSumKey(userID string) string {
return cachekey.GetUserBadgeUnreadCountSumKey(userID)
}
func (c *thirdCache) getFcmAccountTokenKey(account string, platformID int) string {
return cachekey.GetFcmAccountTokenKey(account, platformID)
}
func (c *thirdCache) SetFcmToken(ctx context.Context, account string, platformID int, fcmToken string, expireTime int64) (err error) {
return errs.Wrap(c.rdb.Set(ctx, c.getFcmAccountTokenKey(account, platformID), fcmToken, time.Duration(expireTime)*time.Second).Err())
}
func (c *thirdCache) GetFcmToken(ctx context.Context, account string, platformID int) (string, error) {
val, err := c.rdb.Get(ctx, c.getFcmAccountTokenKey(account, platformID)).Result()
if err != nil {
return "", errs.Wrap(err)
}
return val, nil
}
func (c *thirdCache) DelFcmToken(ctx context.Context, account string, platformID int) error {
return errs.Wrap(c.rdb.Del(ctx, c.getFcmAccountTokenKey(account, platformID)).Err())
}
func (c *thirdCache) IncrUserBadgeUnreadCountSum(ctx context.Context, userID string) (int, error) {
seq, err := c.rdb.Incr(ctx, c.getUserBadgeUnreadCountSumKey(userID)).Result()
return int(seq), errs.Wrap(err)
}
func (c *thirdCache) SetUserBadgeUnreadCountSum(ctx context.Context, userID string, value int) error {
return errs.Wrap(c.rdb.Set(ctx, c.getUserBadgeUnreadCountSumKey(userID), value, 0).Err())
}
func (c *thirdCache) GetUserBadgeUnreadCountSum(ctx context.Context, userID string) (int, error) {
val, err := c.rdb.Get(ctx, c.getUserBadgeUnreadCountSumKey(userID)).Int()
return val, errs.Wrap(err)
}
func (c *thirdCache) SetGetuiToken(ctx context.Context, token string, expireTime int64) error {
return errs.Wrap(c.rdb.Set(ctx, c.getGetuiTokenKey(), token, time.Duration(expireTime)*time.Second).Err())
}
func (c *thirdCache) GetGetuiToken(ctx context.Context) (string, error) {
val, err := c.rdb.Get(ctx, c.getGetuiTokenKey()).Result()
if err != nil {
return "", errs.Wrap(err)
}
return val, nil
}
func (c *thirdCache) SetGetuiTaskID(ctx context.Context, taskID string, expireTime int64) error {
return errs.Wrap(c.rdb.Set(ctx, c.getGetuiTaskIDKey(), taskID, time.Duration(expireTime)*time.Second).Err())
}
func (c *thirdCache) GetGetuiTaskID(ctx context.Context) (string, error) {
val, err := c.rdb.Get(ctx, c.getGetuiTaskIDKey()).Result()
if err != nil {
return "", errs.Wrap(err)
}
return val, nil
}

248
pkg/common/storage/cache/redis/token.go vendored Normal file
View File

@@ -0,0 +1,248 @@
package redis
import (
"context"
"encoding/json"
"strconv"
"sync"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/cachekey"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/datautil"
"github.com/redis/go-redis/v9"
)
type tokenCache struct {
rdb redis.UniversalClient
accessExpire time.Duration
localCache *config.LocalCache
}
func NewTokenCacheModel(rdb redis.UniversalClient, localCache *config.LocalCache, accessExpire int64) cache.TokenModel {
c := &tokenCache{rdb: rdb, localCache: localCache}
c.accessExpire = c.getExpireTime(accessExpire)
return c
}
func (c *tokenCache) SetTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error {
key := cachekey.GetTokenKey(userID, platformID)
if err := c.rdb.HSet(ctx, key, token, flag).Err(); err != nil {
return errs.Wrap(err)
}
if c.localCache != nil {
c.removeLocalTokenCache(ctx, key)
}
return nil
}
// SetTokenFlagEx sets the token flag and refreshes the key's expiration time.
func (c *tokenCache) SetTokenFlagEx(ctx context.Context, userID string, platformID int, token string, flag int) error {
key := cachekey.GetTokenKey(userID, platformID)
if err := c.rdb.HSet(ctx, key, token, flag).Err(); err != nil {
return errs.Wrap(err)
}
if err := c.rdb.Expire(ctx, key, c.accessExpire).Err(); err != nil {
return errs.Wrap(err)
}
if c.localCache != nil {
c.removeLocalTokenCache(ctx, key)
}
return nil
}
func (c *tokenCache) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) {
m, err := c.rdb.HGetAll(ctx, cachekey.GetTokenKey(userID, platformID)).Result()
if err != nil {
return nil, errs.Wrap(err)
}
mm := make(map[string]int)
for k, v := range m {
state, err := strconv.Atoi(v)
if err != nil {
return nil, errs.WrapMsg(err, "redis token value is not int", "value", v, "userID", userID, "platformID", platformID)
}
mm[k] = state
}
return mm, nil
}
func (c *tokenCache) HasTemporaryToken(ctx context.Context, userID string, platformID int, token string) error {
err := c.rdb.Get(ctx, cachekey.GetTemporaryTokenKey(userID, platformID, token)).Err()
if err != nil {
return errs.Wrap(err)
}
return nil
}
func (c *tokenCache) GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) {
var (
res = make(map[int]map[string]int)
resLock = sync.Mutex{}
)
keys := cachekey.GetAllPlatformTokenKey(userID)
if err := ProcessKeysBySlot(ctx, c.rdb, keys, func(ctx context.Context, slot int64, keys []string) error {
pipe := c.rdb.Pipeline()
mapRes := make([]*redis.MapStringStringCmd, len(keys))
for i, key := range keys {
mapRes[i] = pipe.HGetAll(ctx, key)
}
_, err := pipe.Exec(ctx)
if err != nil {
return err
}
for i, m := range mapRes {
mm := make(map[string]int)
for k, v := range m.Val() {
state, err := strconv.Atoi(v)
if err != nil {
return errs.WrapMsg(err, "redis token value is not int", "value", v, "userID", userID)
}
mm[k] = state
}
resLock.Lock()
res[cachekey.GetPlatformIDByTokenKey(keys[i])] = mm
resLock.Unlock()
}
return nil
}); err != nil {
return nil, err
}
return res, nil
}
func (c *tokenCache) SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error {
mm := make(map[string]any)
for k, v := range m {
mm[k] = v
}
err := c.rdb.HSet(ctx, cachekey.GetTokenKey(userID, platformID), mm).Err()
if err != nil {
return errs.Wrap(err)
}
if c.localCache != nil {
c.removeLocalTokenCache(ctx, cachekey.GetTokenKey(userID, platformID))
}
return nil
}
func (c *tokenCache) BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error {
keys := datautil.Keys(tokens)
if err := ProcessKeysBySlot(ctx, c.rdb, keys, func(ctx context.Context, slot int64, keys []string) error {
pipe := c.rdb.Pipeline()
// write only the keys that belong to this slot group so the pipeline stays within one cluster slot
for _, k := range keys {
pipe.HSet(ctx, k, tokens[k])
}
_, err := pipe.Exec(ctx)
if err != nil {
return errs.Wrap(err)
}
return nil
}); err != nil {
return err
}
if c.localCache != nil {
c.removeLocalTokenCache(ctx, keys...)
}
return nil
}
func (c *tokenCache) DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error {
key := cachekey.GetTokenKey(userID, platformID)
if err := c.rdb.HDel(ctx, key, fields...).Err(); err != nil {
return errs.Wrap(err)
}
if c.localCache != nil {
c.removeLocalTokenCache(ctx, key)
}
return nil
}
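// getExpireTime interprets accessExpire as a number of days and converts it to a Duration.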
func (c *tokenCache) getExpireTime(t int64) time.Duration {
return time.Hour * 24 * time.Duration(t)
}
// DeleteTokenByTokenMap tokens key is platformID, value is token slice
func (c *tokenCache) DeleteTokenByTokenMap(ctx context.Context, userID string, tokens map[int][]string) error {
var (
keys = make([]string, 0, len(tokens))
keyMap = make(map[string][]string)
)
for k, v := range tokens {
k1 := cachekey.GetTokenKey(userID, k)
keys = append(keys, k1)
keyMap[k1] = v
}
if err := ProcessKeysBySlot(ctx, c.rdb, keys, func(ctx context.Context, slot int64, keys []string) error {
pipe := c.rdb.Pipeline()
// delete only the fields of the keys in this slot group (keyMap maps token key -> tokens to remove)
for _, k := range keys {
pipe.HDel(ctx, k, keyMap[k]...)
}
_, err := pipe.Exec(ctx)
if err != nil {
return errs.Wrap(err)
}
return nil
}); err != nil {
return err
}
// Remove local cache for the token
if c.localCache != nil {
c.removeLocalTokenCache(ctx, keys...)
}
return nil
}
func (c *tokenCache) DeleteAndSetTemporary(ctx context.Context, userID string, platformID int, fields []string) error {
for _, f := range fields {
k := cachekey.GetTemporaryTokenKey(userID, platformID, f)
if err := c.rdb.Set(ctx, k, "", c.accessExpire).Err(); err != nil {
return errs.Wrap(err)
}
}
key := cachekey.GetTokenKey(userID, platformID)
if err := c.rdb.HDel(ctx, key, fields...).Err(); err != nil {
return errs.Wrap(err)
}
if c.localCache != nil {
c.removeLocalTokenCache(ctx, key)
}
return nil
}
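// removeLocalTokenCache publishes the invalidated token keys on the configured
// local-cache topic so other nodes evict their in-memory copies; it is a no-op
// when no topic is configured.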
func (c *tokenCache) removeLocalTokenCache(ctx context.Context, keys ...string) {
if len(keys) == 0 {
return
}
topic := c.localCache.Auth.Topic
if topic == "" {
return
}
data, err := json.Marshal(keys)
if err != nil {
log.ZWarn(ctx, "keys json marshal failed", err, "topic", topic, "keys", keys)
} else {
if err := c.rdb.Publish(ctx, topic, string(data)).Err(); err != nil {
log.ZWarn(ctx, "redis publish cache delete error", err, "topic", topic, "keys", keys)
}
}
}

107
pkg/common/storage/cache/redis/user.go vendored Normal file
View File

@@ -0,0 +1,107 @@
package redis
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/cachekey"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/dtm-labs/rockscache"
"github.com/openimsdk/tools/log"
"github.com/redis/go-redis/v9"
)
const (
userExpireTime = time.Second * 60 * 60 * 12
userOlineStatusExpireTime = time.Second * 60 * 60 * 24
statusMod = 501
)
type UserCacheRedis struct {
cache.BatchDeleter
rdb redis.UniversalClient
userDB database.User
expireTime time.Duration
rcClient *rocksCacheClient
}
func NewUserCacheRedis(rdb redis.UniversalClient, localCache *config.LocalCache, userDB database.User, options *rockscache.Options) cache.UserCache {
rc := newRocksCacheClient(rdb)
return &UserCacheRedis{
BatchDeleter: rc.GetBatchDeleter(localCache.User.Topic),
rdb: rdb,
userDB: userDB,
expireTime: userExpireTime,
rcClient: rc,
}
}
func (u *UserCacheRedis) getUserID(user *model.User) string {
return user.UserID
}
func (u *UserCacheRedis) CloneUserCache() cache.UserCache {
return &UserCacheRedis{
BatchDeleter: u.BatchDeleter.Clone(),
rdb: u.rdb,
userDB: u.userDB,
expireTime: u.expireTime,
rcClient: u.rcClient,
}
}
func (u *UserCacheRedis) getUserInfoKey(userID string) string {
return cachekey.GetUserInfoKey(userID)
}
func (u *UserCacheRedis) getUserGlobalRecvMsgOptKey(userID string) string {
return cachekey.GetUserGlobalRecvMsgOptKey(userID)
}
func (u *UserCacheRedis) GetUserInfo(ctx context.Context, userID string) (userInfo *model.User, err error) {
return getCache(ctx, u.rcClient, u.getUserInfoKey(userID), u.expireTime, func(ctx context.Context) (*model.User, error) {
return u.userDB.Take(ctx, userID)
})
}
func (u *UserCacheRedis) GetUsersInfo(ctx context.Context, userIDs []string) ([]*model.User, error) {
log.ZInfo(ctx, "GetUsersInfo start", "userIDs", userIDs)
return batchGetCache2(ctx, u.rcClient, u.expireTime, userIDs, u.getUserInfoKey, u.getUserID, u.userDB.Find)
}
func (u *UserCacheRedis) DelUsersInfo(userIDs ...string) cache.UserCache {
keys := make([]string, 0, len(userIDs))
for _, userID := range userIDs {
keys = append(keys, u.getUserInfoKey(userID))
}
cache := u.CloneUserCache()
cache.AddKeys(keys...)
return cache
}
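// Illustrative usage (hypothetical caller, not part of this file): deletions are
// chained on the clone returned by DelUsersInfo and executed in one batch, e.g.
//
//	if err := userCache.DelUsersInfo(userIDs...).ChainExecDel(ctx); err != nil {
//		// handle cache invalidation failure
//	}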
func (u *UserCacheRedis) GetUserGlobalRecvMsgOpt(ctx context.Context, userID string) (opt int, err error) {
return getCache(
ctx,
u.rcClient,
u.getUserGlobalRecvMsgOptKey(userID),
u.expireTime,
func(ctx context.Context) (int, error) {
return u.userDB.GetUserGlobalRecvMsgOpt(ctx, userID)
},
)
}
func (u *UserCacheRedis) DelUsersGlobalRecvMsgOpt(userIDs ...string) cache.UserCache {
keys := make([]string, 0, len(userIDs))
for _, userID := range userIDs {
keys = append(keys, u.getUserGlobalRecvMsgOptKey(userID))
}
cache := u.CloneUserCache()
cache.AddKeys(keys...)
return cache
}

52
pkg/common/storage/cache/s3.go vendored Normal file
View File

@@ -0,0 +1,52 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cache
import (
"context"
relationtb "git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/s3"
)
type ObjectCache interface {
BatchDeleter
CloneObjectCache() ObjectCache
GetName(ctx context.Context, engine string, name string) (*relationtb.Object, error)
DelObjectName(engine string, names ...string) ObjectCache
}
type S3Cache interface {
BatchDeleter
GetKey(ctx context.Context, engine string, key string) (*s3.ObjectInfo, error)
DelS3Key(engine string, keys ...string) S3Cache
}
// TODO integrating minio.Cache and MinioCache interfaces.
type MinioCache interface {
BatchDeleter
GetImageObjectKeyInfo(ctx context.Context, key string, fn func(ctx context.Context) (*MinioImageInfo, error)) (*MinioImageInfo, error)
GetThumbnailKey(ctx context.Context, key string, format string, width int, height int, minioCache func(ctx context.Context) (string, error)) (string, error)
DelObjectImageInfoKey(keys ...string) MinioCache
DelImageThumbnailKey(key string, format string, width int, height int) MinioCache
}
type MinioImageInfo struct {
IsImg bool `json:"isImg"`
Width int `json:"width"`
Height int `json:"height"`
Format string `json:"format"`
Etag string `json:"etag"`
}

View File

@@ -0,0 +1,19 @@
package cache
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
)
type SeqConversationCache interface {
Malloc(ctx context.Context, conversationID string, size int64) (int64, error)
GetMaxSeq(ctx context.Context, conversationID string) (int64, error)
SetMinSeq(ctx context.Context, conversationID string, seq int64) error
GetMinSeq(ctx context.Context, conversationID string) (int64, error)
GetMaxSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error)
SetMinSeqs(ctx context.Context, seqs map[string]int64) error
GetCacheMaxSeqWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error)
GetMaxSeqsWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error)
GetMaxSeqWithTime(ctx context.Context, conversationID string) (database.SeqTime, error)
}

16
pkg/common/storage/cache/seq_user.go vendored Normal file
View File

@@ -0,0 +1,16 @@
package cache
import "context"
type SeqUser interface {
GetUserMaxSeq(ctx context.Context, conversationID string, userID string) (int64, error)
SetUserMaxSeq(ctx context.Context, conversationID string, userID string, seq int64) error
GetUserMinSeq(ctx context.Context, conversationID string, userID string) (int64, error)
SetUserMinSeq(ctx context.Context, conversationID string, userID string, seq int64) error
GetUserReadSeq(ctx context.Context, conversationID string, userID string) (int64, error)
SetUserReadSeq(ctx context.Context, conversationID string, userID string, seq int64) error
SetUserReadSeqToDB(ctx context.Context, conversationID string, userID string, seq int64) error
SetUserMinSeqs(ctx context.Context, userID string, seqs map[string]int64) error
SetUserReadSeqs(ctx context.Context, userID string, seqs map[string]int64) error
GetUserReadSeqs(ctx context.Context, userID string, conversationIDs []string) (map[string]int64, error)
}

18
pkg/common/storage/cache/third.go vendored Normal file
View File

@@ -0,0 +1,18 @@
package cache
import (
"context"
)
type ThirdCache interface {
SetFcmToken(ctx context.Context, account string, platformID int, fcmToken string, expireTime int64) (err error)
GetFcmToken(ctx context.Context, account string, platformID int) (string, error)
DelFcmToken(ctx context.Context, account string, platformID int) error
IncrUserBadgeUnreadCountSum(ctx context.Context, userID string) (int, error)
SetUserBadgeUnreadCountSum(ctx context.Context, userID string, value int) error
GetUserBadgeUnreadCountSum(ctx context.Context, userID string) (int, error)
SetGetuiToken(ctx context.Context, token string, expireTime int64) error
GetGetuiToken(ctx context.Context) (string, error)
SetGetuiTaskID(ctx context.Context, taskID string, expireTime int64) error
GetGetuiTaskID(ctx context.Context) (string, error)
}

19
pkg/common/storage/cache/token.go vendored Normal file
View File

@@ -0,0 +1,19 @@
package cache
import (
"context"
)
type TokenModel interface {
SetTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error
// SetTokenFlagEx sets the token flag and refreshes the key's expiration time.
SetTokenFlagEx(ctx context.Context, userID string, platformID int, token string, flag int) error
GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error)
HasTemporaryToken(ctx context.Context, userID string, platformID int, token string) error
GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error)
SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error
BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error
DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error
DeleteTokenByTokenMap(ctx context.Context, userID string, tokens map[int][]string) error
DeleteAndSetTemporary(ctx context.Context, userID string, platformID int, fields []string) error
}

33
pkg/common/storage/cache/user.go vendored Normal file
View File

@@ -0,0 +1,33 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cache
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
)
type UserCache interface {
BatchDeleter
CloneUserCache() UserCache
GetUserInfo(ctx context.Context, userID string) (userInfo *model.User, err error)
GetUsersInfo(ctx context.Context, userIDs []string) ([]*model.User, error)
DelUsersInfo(userIDs ...string) UserCache
GetUserGlobalRecvMsgOpt(ctx context.Context, userID string) (opt int, err error)
DelUsersGlobalRecvMsgOpt(userIDs ...string) UserCache
//GetUserStatus(ctx context.Context, userIDs []string) ([]*user.OnlineStatus, error)
//SetUserStatus(ctx context.Context, userID string, status, platformID int32) error
}

View File

@@ -0,0 +1,26 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package common
type BatchUpdateGroupMember struct {
GroupID string
UserID string
Map map[string]any
}
type GroupSimpleUserID struct {
Hash uint64
MemberNum uint32
}

View File

@@ -0,0 +1,253 @@
package controller
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/cachekey"
"git.imall.cloud/openim/protocol/constant"
"github.com/golang-jwt/jwt/v4"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/tokenverify"
)
type AuthDatabase interface {
// If the result is empty, no error is returned.
GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error)
GetTemporaryTokensWithoutError(ctx context.Context, userID string, platformID int, token string) error
// Create token
CreateToken(ctx context.Context, userID string, platformID int) (string, error)
BatchSetTokenMapByUidPid(ctx context.Context, tokens []string) error
SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error
}
type multiLoginConfig struct {
Policy int
MaxNumOneEnd int
}
type authDatabase struct {
cache cache.TokenModel
accessSecret string
accessExpire int64
multiLogin multiLoginConfig
adminUserIDs []string
}
func NewAuthDatabase(cache cache.TokenModel, accessSecret string, accessExpire int64, multiLogin config.MultiLogin, adminUserIDs []string) AuthDatabase {
return &authDatabase{cache: cache, accessSecret: accessSecret, accessExpire: accessExpire, multiLogin: multiLoginConfig{
Policy: multiLogin.Policy,
MaxNumOneEnd: multiLogin.MaxNumOneEnd,
},
adminUserIDs: adminUserIDs,
}
}
// If the result is empty, no error is returned.
func (a *authDatabase) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) {
return a.cache.GetTokensWithoutError(ctx, userID, platformID)
}
func (a *authDatabase) GetTemporaryTokensWithoutError(ctx context.Context, userID string, platformID int, token string) error {
return a.cache.HasTemporaryToken(ctx, userID, platformID, token)
}
func (a *authDatabase) SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error {
return a.cache.SetTokenMapByUidPid(ctx, userID, platformID, m)
}
func (a *authDatabase) BatchSetTokenMapByUidPid(ctx context.Context, tokens []string) error {
setMap := make(map[string]map[string]any)
for _, token := range tokens {
claims, err := tokenverify.GetClaimFromToken(token, authverify.Secret(a.accessSecret))
if err != nil {
continue
}
key := cachekey.GetTokenKey(claims.UserID, claims.PlatformID)
if v, ok := setMap[key]; ok {
v[token] = constant.KickedToken
} else {
setMap[key] = map[string]any{
token: constant.KickedToken,
}
}
}
if err := a.cache.BatchSetTokenMapByUidPid(ctx, setMap); err != nil {
return err
}
return nil
}
// Create Token.
func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformID int) (string, error) {
tokens, err := a.cache.GetAllTokensWithoutError(ctx, userID)
if err != nil {
return "", err
}
deleteTokenKey, kickedTokenKey, adminTokens, err := a.checkToken(ctx, tokens, platformID)
if err != nil {
return "", err
}
if len(deleteTokenKey) != 0 {
err = a.cache.DeleteTokenByTokenMap(ctx, userID, deleteTokenKey)
if err != nil {
return "", err
}
}
if len(kickedTokenKey) != 0 {
for plt, ks := range kickedTokenKey {
for _, k := range ks {
err := a.cache.SetTokenFlagEx(ctx, userID, plt, k, constant.KickedToken)
if err != nil {
return "", err
}
log.ZDebug(ctx, "kicked token in create token", "token", k)
}
}
}
if len(adminTokens) != 0 {
if err = a.cache.DeleteAndSetTemporary(ctx, userID, constant.AdminPlatformID, adminTokens); err != nil {
return "", err
}
}
claims := tokenverify.BuildClaims(userID, platformID, a.accessExpire)
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(a.accessSecret))
if err != nil {
return "", errs.WrapMsg(err, "token.SignedString")
}
if err = a.cache.SetTokenFlagEx(ctx, userID, platformID, tokenString, constant.NormalToken); err != nil {
return "", err
}
return tokenString, nil
}
// checkToken checks the user's existing tokens against the multi-login policy and returns the tokens to delete, the tokens to kick offline, and the admin tokens to delete.
func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string]int, platformID int) (map[int][]string, map[int][]string, []string, error) {
// todo: Asynchronous deletion of old data.
var (
loginTokenMap = make(map[int][]string) // The length of the value of the map must be greater than 0
deleteToken = make(map[int][]string)
kickToken = make(map[int][]string)
adminToken = make([]string, 0)
unkickTerminal = ""
)
for plfID, tks := range tokens {
for k, v := range tks {
_, err := tokenverify.GetClaimFromToken(k, authverify.Secret(a.accessSecret))
if err != nil || v != constant.NormalToken {
deleteToken[plfID] = append(deleteToken[plfID], k)
} else {
if plfID != constant.AdminPlatformID {
loginTokenMap[plfID] = append(loginTokenMap[plfID], k)
} else {
adminToken = append(adminToken, k)
}
}
}
}
switch a.multiLogin.Policy {
case constant.DefalutNotKick:
for plt, ts := range loginTokenMap {
l := len(ts)
if platformID == plt {
l++
}
limit := a.multiLogin.MaxNumOneEnd
if l > limit {
kickToken[plt] = ts[:l-limit]
}
}
case constant.AllLoginButSameTermKick:
for plt, ts := range loginTokenMap {
kickToken[plt] = ts[:len(ts)-1]
if plt == platformID {
kickToken[plt] = append(kickToken[plt], ts[len(ts)-1])
}
}
case constant.PCAndOther:
unkickTerminal = constant.TerminalPC
if constant.PlatformIDToClass(platformID) != unkickTerminal {
for plt, ts := range loginTokenMap {
if constant.PlatformIDToClass(plt) != unkickTerminal {
kickToken[plt] = ts
}
}
} else {
var (
preKickToken string
preKickPlt int
reserveToken = false
)
for plt, ts := range loginTokenMap {
if constant.PlatformIDToClass(plt) != unkickTerminal {
// Keep a token from another end
if !reserveToken {
reserveToken = true
kickToken[plt] = ts[:len(ts)-1]
preKickToken = ts[len(ts)-1]
preKickPlt = plt
continue
} else {
// Prioritize keeping Android
if plt == constant.AndroidPlatformID {
if preKickToken != "" {
kickToken[preKickPlt] = append(kickToken[preKickPlt], preKickToken)
}
kickToken[plt] = ts[:len(ts)-1]
} else {
kickToken[plt] = ts
}
}
}
}
}
case constant.AllLoginButSameClassKick:
var (
reserved = make(map[string]struct{})
)
for plt, ts := range loginTokenMap {
if constant.PlatformIDToClass(plt) == constant.PlatformIDToClass(platformID) {
kickToken[plt] = ts
} else {
if _, ok := reserved[constant.PlatformIDToClass(plt)]; !ok {
reserved[constant.PlatformIDToClass(plt)] = struct{}{}
kickToken[plt] = ts[:len(ts)-1]
continue
} else {
kickToken[plt] = ts
}
}
}
default:
return nil, nil, nil, errs.New("unknown multiLogin policy").Wrap()
}
//var adminTokenMaxNum = a.multiLogin.MaxNumOneEnd
//l := len(adminToken)
//if platformID == constant.AdminPlatformID {
// l++
//}
//if l > adminTokenMaxNum {
// kickToken = append(kickToken, adminToken[:l-adminTokenMaxNum]...)
//}
var deleteAdminToken []string
if platformID == constant.AdminPlatformID {
deleteAdminToken = adminToken
}
return deleteToken, kickToken, deleteAdminToken, nil
}
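// exampleKickCount is an illustrative sketch (not used by this package) of the
// arithmetic in the DefalutNotKick branch above, for the platform being logged
// into: each end may keep at most MaxNumOneEnd normal tokens, the new login
// counts toward that limit, and any excess tokens on that end are kicked.
func exampleKickCount(existingTokens, maxNumOneEnd int) int {
total := existingTokens + 1 // the token about to be created
if total > maxNumOneEnd {
return total - maxNumOneEnd
}
return 0
}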

View File

@@ -0,0 +1,101 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package controller
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/pagination"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/datautil"
)
type BlackDatabase interface {
// Create add BlackList
Create(ctx context.Context, blacks []*model.Black) (err error)
// Delete delete BlackList
Delete(ctx context.Context, blacks []*model.Black) (err error)
// FindOwnerBlacks get BlackList list
FindOwnerBlacks(ctx context.Context, ownerUserID string, pagination pagination.Pagination) (total int64, blacks []*model.Black, err error)
FindBlackInfos(ctx context.Context, ownerUserID string, userIDs []string) (blacks []*model.Black, err error)
// CheckIn checks whether user2 is in user1's blacklist (inUser1Blacks==true) and whether user1 is in user2's blacklist (inUser2Blacks==true).
CheckIn(ctx context.Context, userID1, userID2 string) (inUser1Blacks bool, inUser2Blacks bool, err error)
}
type blackDatabase struct {
black database.Black
cache cache.BlackCache
}
func NewBlackDatabase(black database.Black, cache cache.BlackCache) BlackDatabase {
return &blackDatabase{black, cache}
}
// Create Add Blacklist.
func (b *blackDatabase) Create(ctx context.Context, blacks []*model.Black) (err error) {
if err := b.black.Create(ctx, blacks); err != nil {
return err
}
return b.deleteBlackIDsCache(ctx, blacks)
}
// Delete Delete Blacklist.
func (b *blackDatabase) Delete(ctx context.Context, blacks []*model.Black) (err error) {
if err := b.black.Delete(ctx, blacks); err != nil {
return err
}
return b.deleteBlackIDsCache(ctx, blacks)
}
// FindOwnerBlacks Get Blacklist List.
func (b *blackDatabase) deleteBlackIDsCache(ctx context.Context, blacks []*model.Black) (err error) {
cache := b.cache.CloneBlackCache()
for _, black := range blacks {
cache = cache.DelBlackIDs(ctx, black.OwnerUserID)
}
return cache.ChainExecDel(ctx)
}
// FindOwnerBlacks Get Blacklist List.
func (b *blackDatabase) FindOwnerBlacks(ctx context.Context, ownerUserID string, pagination pagination.Pagination) (total int64, blacks []*model.Black, err error) {
return b.black.FindOwnerBlacks(ctx, ownerUserID, pagination)
}
// CheckIn checks the blacklist relationship between two users in both directions.
func (b *blackDatabase) CheckIn(ctx context.Context, userID1, userID2 string) (inUser1Blacks bool, inUser2Blacks bool, err error) {
userID1BlackIDs, err := b.cache.GetBlackIDs(ctx, userID1)
if err != nil {
return
}
userID2BlackIDs, err := b.cache.GetBlackIDs(ctx, userID2)
if err != nil {
return
}
log.ZDebug(ctx, "blackIDs", "user1BlackIDs", userID1BlackIDs, "user2BlackIDs", userID2BlackIDs)
return datautil.Contain(userID2, userID1BlackIDs...), datautil.Contain(userID1, userID2BlackIDs...), nil
}
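// Illustrative usage of the CheckIn call above (hypothetical caller): a typical
// send-permission check calls
//
//	inRecvBlacks, _, err := blackDB.CheckIn(ctx, recvUserID, sendUserID)
//
// and rejects the message when inRecvBlacks is true, i.e. when sendUserID is in
// recvUserID's blacklist (the first return value reports whether userID2 is in
// userID1's blacklist).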
// FindBlackIDs Get Blacklist List.
func (b *blackDatabase) FindBlackIDs(ctx context.Context, ownerUserID string) (blackIDs []string, err error) {
return b.cache.GetBlackIDs(ctx, ownerUserID)
}
// FindBlackInfos Get Blacklist List.
func (b *blackDatabase) FindBlackInfos(ctx context.Context, ownerUserID string, userIDs []string) (blacks []*model.Black, err error) {
return b.black.FindOwnerBlackInfos(ctx, ownerUserID, userIDs)
}

View File

@@ -0,0 +1,58 @@
package controller
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/pagination"
"github.com/openimsdk/tools/db/tx"
)
type ClientConfigDatabase interface {
SetUserConfig(ctx context.Context, userID string, config map[string]string) error
GetUserConfig(ctx context.Context, userID string) (map[string]string, error)
DelUserConfig(ctx context.Context, userID string, keys []string) error
GetUserConfigPage(ctx context.Context, userID string, key string, pagination pagination.Pagination) (int64, []*model.ClientConfig, error)
}
func NewClientConfigDatabase(db database.ClientConfig, cache cache.ClientConfigCache, tx tx.Tx) ClientConfigDatabase {
return &clientConfigDatabase{
tx: tx,
db: db,
cache: cache,
}
}
type clientConfigDatabase struct {
tx tx.Tx
db database.ClientConfig
cache cache.ClientConfigCache
}
func (x *clientConfigDatabase) SetUserConfig(ctx context.Context, userID string, config map[string]string) error {
return x.tx.Transaction(ctx, func(ctx context.Context) error {
if err := x.db.Set(ctx, userID, config); err != nil {
return err
}
return x.cache.DeleteUserCache(ctx, []string{userID})
})
}
func (x *clientConfigDatabase) GetUserConfig(ctx context.Context, userID string) (map[string]string, error) {
return x.cache.GetUserConfig(ctx, userID)
}
func (x *clientConfigDatabase) DelUserConfig(ctx context.Context, userID string, keys []string) error {
return x.tx.Transaction(ctx, func(ctx context.Context) error {
if err := x.db.Del(ctx, userID, keys); err != nil {
return err
}
return x.cache.DeleteUserCache(ctx, []string{userID})
})
}
func (x *clientConfigDatabase) GetUserConfigPage(ctx context.Context, userID string, key string, pagination pagination.Pagination) (int64, []*model.ClientConfig, error) {
return x.db.GetPage(ctx, userID, key, pagination)
}

View File

@@ -0,0 +1,451 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package controller
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
relationtb "git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/protocol/constant"
"github.com/openimsdk/tools/db/pagination"
"github.com/openimsdk/tools/db/tx"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/datautil"
"github.com/openimsdk/tools/utils/stringutil"
)
type ConversationDatabase interface {
// UpdateUsersConversationField updates the properties of a conversation for specified users.
UpdateUsersConversationField(ctx context.Context, userIDs []string, conversationID string, args map[string]any) error
// CreateConversation creates a batch of new conversations.
CreateConversation(ctx context.Context, conversations []*relationtb.Conversation) error
// SyncPeerUserPrivateConversationTx ensures transactional operation while syncing private conversations between peers.
SyncPeerUserPrivateConversationTx(ctx context.Context, conversation []*relationtb.Conversation) error
// FindConversations retrieves multiple conversations of a user by conversation IDs.
FindConversations(ctx context.Context, ownerUserID string, conversationIDs []string) ([]*relationtb.Conversation, error)
// GetUserAllConversation fetches all conversations of a user on the server.
GetUserAllConversation(ctx context.Context, ownerUserID string) ([]*relationtb.Conversation, error)
// SetUserConversations sets multiple conversation properties for a user, creates new conversations if they do not exist, or updates them otherwise. This operation is atomic.
SetUserConversations(ctx context.Context, ownerUserID string, conversations []*relationtb.Conversation) error
// SetUsersConversationFieldTx updates a specific field for multiple users' conversations, creating new conversations if they do not exist, or updates them otherwise. This operation is
// transactional.
SetUsersConversationFieldTx(ctx context.Context, userIDs []string, conversation *relationtb.Conversation, fieldMap map[string]any) error
// UpdateUserConversations updates all conversations related to a specified user.
// This function does NOT update the user's own conversations but rather the conversations where this user is involved (e.g., other users' conversations referencing this user).
UpdateUserConversations(ctx context.Context, userID string, args map[string]any) error
// CreateGroupChatConversation creates a group chat conversation for the specified group ID and user IDs.
CreateGroupChatConversation(ctx context.Context, groupID string, userIDs []string, conversations *relationtb.Conversation) error
// GetConversationIDs retrieves conversation IDs for a given user.
GetConversationIDs(ctx context.Context, userID string) ([]string, error)
// GetUserConversationIDsHash gets the hash of conversation IDs for a given user.
GetUserConversationIDsHash(ctx context.Context, ownerUserID string) (hash uint64, err error)
// GetAllConversationIDs fetches all conversation IDs.
GetAllConversationIDs(ctx context.Context) ([]string, error)
// GetAllConversationIDsNumber returns the number of all conversation IDs.
GetAllConversationIDsNumber(ctx context.Context) (int64, error)
// PageConversationIDs paginates through conversation IDs based on the specified pagination settings.
PageConversationIDs(ctx context.Context, pagination pagination.Pagination) (conversationIDs []string, err error)
// GetConversationsByConversationID retrieves conversations by their IDs.
GetConversationsByConversationID(ctx context.Context, conversationIDs []string) ([]*relationtb.Conversation, error)
// GetConversationIDsNeedDestruct fetches conversations that need to be destructed based on specific criteria.
GetConversationIDsNeedDestruct(ctx context.Context) ([]*relationtb.Conversation, error)
// GetConversationNotReceiveMessageUserIDs gets user IDs for users in a conversation who have not received messages.
GetConversationNotReceiveMessageUserIDs(ctx context.Context, conversationID string) ([]string, error)
// GetUserAllHasReadSeqs(ctx context.Context, ownerUserID string) (map[string]int64, error)
// FindRecvMsgNotNotifyUserIDs(ctx context.Context, groupID string) ([]string, error)
FindConversationUserVersion(ctx context.Context, userID string, version uint, limit int) (*relationtb.VersionLog, error)
FindMaxConversationUserVersionCache(ctx context.Context, userID string) (*relationtb.VersionLog, error)
GetOwnerConversation(ctx context.Context, ownerUserID string, pagination pagination.Pagination) (int64, []*relationtb.Conversation, error)
// GetNotNotifyConversationIDs gets the do-not-notify conversation IDs for a user.
GetNotNotifyConversationIDs(ctx context.Context, userID string) ([]string, error)
// GetPinnedConversationIDs gets the pinned conversation IDs for a user.
GetPinnedConversationIDs(ctx context.Context, userID string) ([]string, error)
// FindRandConversation finds random conversations based on the specified timestamp and limit.
FindRandConversation(ctx context.Context, ts int64, limit int) ([]*relationtb.Conversation, error)
// DeleteUsersConversations deletes conversations for a user.
DeleteUsersConversations(ctx context.Context, userID string, conversationIDs []string) error
}
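// Illustrative usage (hypothetical caller, not part of this interface): pinning a
// conversation for a set of users goes through SetUsersConversationFieldTx with a
// field map whose keys match the column names inspected by the implementation
// below ("is_pinned", "recv_msg_opt", "has_read_seq"), e.g.
//
//	conv := &relationtb.Conversation{ConversationID: convID, IsPinned: true}
//	err := db.SetUsersConversationFieldTx(ctx, userIDs, conv, map[string]any{"is_pinned": true})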
func NewConversationDatabase(conversation database.Conversation, cache cache.ConversationCache, tx tx.Tx) ConversationDatabase {
return &conversationDatabase{
conversationDB: conversation,
cache: cache,
tx: tx,
}
}
type conversationDatabase struct {
conversationDB database.Conversation
cache cache.ConversationCache
tx tx.Tx
}
func (c *conversationDatabase) SetUsersConversationFieldTx(ctx context.Context, userIDs []string, conversation *relationtb.Conversation, fieldMap map[string]any) (err error) {
return c.tx.Transaction(ctx, func(ctx context.Context) error {
cache := c.cache.CloneConversationCache()
if conversation.GroupID != "" {
cache = cache.DelSuperGroupRecvMsgNotNotifyUserIDs(conversation.GroupID).DelSuperGroupRecvMsgNotNotifyUserIDsHash(conversation.GroupID)
}
haveUserIDs, err := c.conversationDB.FindUserID(ctx, userIDs, []string{conversation.ConversationID})
if err != nil {
return err
}
if len(haveUserIDs) > 0 {
_, err = c.conversationDB.UpdateByMap(ctx, haveUserIDs, conversation.ConversationID, fieldMap)
if err != nil {
return err
}
cache = cache.DelUsersConversation(conversation.ConversationID, haveUserIDs...)
if _, ok := fieldMap["has_read_seq"]; ok {
for _, userID := range haveUserIDs {
cache = cache.DelUserAllHasReadSeqs(userID, conversation.ConversationID)
}
}
if _, ok := fieldMap["recv_msg_opt"]; ok {
cache = cache.DelConversationNotReceiveMessageUserIDs(conversation.ConversationID)
cache = cache.DelConversationNotNotifyMessageUserIDs(userIDs...)
}
if _, ok := fieldMap["is_pinned"]; ok {
cache = cache.DelUserPinnedConversations(userIDs...)
}
cache = cache.DelConversationVersionUserIDs(haveUserIDs...)
}
NotUserIDs := stringutil.DifferenceString(haveUserIDs, userIDs)
log.ZDebug(ctx, "SetUsersConversationFieldTx", "NotUserIDs", NotUserIDs, "haveUserIDs", haveUserIDs, "userIDs", userIDs)
var conversations []*relationtb.Conversation
now := time.Now()
for _, v := range NotUserIDs {
temp := new(relationtb.Conversation)
if err = datautil.CopyStructFields(temp, conversation); err != nil {
return err
}
temp.OwnerUserID = v
temp.CreateTime = now
conversations = append(conversations, temp)
}
if len(conversations) > 0 {
err = c.conversationDB.Create(ctx, conversations)
if err != nil {
return err
}
cache = cache.DelConversationIDs(NotUserIDs...).DelUserConversationIDsHash(NotUserIDs...).DelConversations(conversation.ConversationID, NotUserIDs...)
}
return cache.ChainExecDel(ctx)
})
}
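// Usage sketch (illustrative only; ctx, db and the IDs are assumed to exist in the caller):
// pin one conversation for several users in a single transaction. The map keys are the
// column names this method already special-cases ("is_pinned", "recv_msg_opt", "has_read_seq").
//
//	conv := &relationtb.Conversation{ConversationID: "si_userA_userB", IsPinned: true}
//	err := db.SetUsersConversationFieldTx(ctx, []string{"userA", "userB"}, conv, map[string]any{"is_pinned": true})
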
func (c *conversationDatabase) UpdateUserConversations(ctx context.Context, userID string, args map[string]any) error {
conversations, err := c.conversationDB.UpdateUserConversations(ctx, userID, args)
if err != nil {
return err
}
cache := c.cache.CloneConversationCache()
for _, conversation := range conversations {
cache = cache.DelUsersConversation(conversation.ConversationID, conversation.OwnerUserID).DelConversationVersionUserIDs(conversation.OwnerUserID)
}
return cache.ChainExecDel(ctx)
}
func (c *conversationDatabase) UpdateUsersConversationField(ctx context.Context, userIDs []string, conversationID string, args map[string]any) error {
_, err := c.conversationDB.UpdateByMap(ctx, userIDs, conversationID, args)
if err != nil {
return err
}
cache := c.cache.CloneConversationCache()
cache = cache.DelUsersConversation(conversationID, userIDs...).DelConversationVersionUserIDs(userIDs...)
if _, ok := args["recv_msg_opt"]; ok {
cache = cache.DelConversationNotReceiveMessageUserIDs(conversationID)
cache = cache.DelConversationNotNotifyMessageUserIDs(userIDs...)
}
if _, ok := args["is_pinned"]; ok {
cache = cache.DelUserPinnedConversations(userIDs...)
}
return cache.ChainExecDel(ctx)
}
func (c *conversationDatabase) CreateConversation(ctx context.Context, conversations []*relationtb.Conversation) error {
if err := c.conversationDB.Create(ctx, conversations); err != nil {
return err
}
var (
userIDs []string
notNotifyUserIDs []string
pinnedUserIDs []string
)
cache := c.cache.CloneConversationCache()
for _, conversation := range conversations {
cache = cache.DelConversations(conversation.OwnerUserID, conversation.ConversationID)
cache = cache.DelConversationNotReceiveMessageUserIDs(conversation.ConversationID)
userIDs = append(userIDs, conversation.OwnerUserID)
if conversation.RecvMsgOpt == constant.ReceiveNotNotifyMessage {
notNotifyUserIDs = append(notNotifyUserIDs, conversation.OwnerUserID)
}
if conversation.IsPinned {
pinnedUserIDs = append(pinnedUserIDs, conversation.OwnerUserID)
}
}
return cache.DelConversationIDs(userIDs...).
DelUserConversationIDsHash(userIDs...).
DelConversationVersionUserIDs(userIDs...).
DelConversationNotNotifyMessageUserIDs(notNotifyUserIDs...).
DelUserPinnedConversations(pinnedUserIDs...).
ChainExecDel(ctx)
}
func (c *conversationDatabase) SyncPeerUserPrivateConversationTx(ctx context.Context, conversations []*relationtb.Conversation) error {
return c.tx.Transaction(ctx, func(ctx context.Context) error {
cache := c.cache.CloneConversationCache()
for _, conversation := range conversations {
cache = cache.DelConversationVersionUserIDs(conversation.OwnerUserID, conversation.UserID)
for _, v := range [][2]string{{conversation.OwnerUserID, conversation.UserID}, {conversation.UserID, conversation.OwnerUserID}} {
ownerUserID := v[0]
userID := v[1]
haveUserIDs, err := c.conversationDB.FindUserID(ctx, []string{ownerUserID}, []string{conversation.ConversationID})
if err != nil {
return err
}
if len(haveUserIDs) > 0 {
_, err := c.conversationDB.UpdateByMap(ctx, []string{ownerUserID}, conversation.ConversationID, map[string]any{"is_private_chat": conversation.IsPrivateChat})
if err != nil {
return err
}
cache = cache.DelUsersConversation(conversation.ConversationID, ownerUserID)
} else {
newConversation := *conversation
newConversation.OwnerUserID = ownerUserID
newConversation.UserID = userID
newConversation.ConversationID = conversation.ConversationID
newConversation.IsPrivateChat = conversation.IsPrivateChat
if err := c.conversationDB.Create(ctx, []*relationtb.Conversation{&newConversation}); err != nil {
return err
}
cache = cache.DelConversationIDs(ownerUserID).DelUserConversationIDsHash(ownerUserID)
}
}
}
return cache.ChainExecDel(ctx)
})
}
func (c *conversationDatabase) FindConversations(ctx context.Context, ownerUserID string, conversationIDs []string) ([]*relationtb.Conversation, error) {
return c.cache.GetConversations(ctx, ownerUserID, conversationIDs)
}
func (c *conversationDatabase) GetConversation(ctx context.Context, ownerUserID string, conversationID string) (*relationtb.Conversation, error) {
return c.cache.GetConversation(ctx, ownerUserID, conversationID)
}
func (c *conversationDatabase) GetUserAllConversation(ctx context.Context, ownerUserID string) ([]*relationtb.Conversation, error) {
return c.cache.GetUserAllConversations(ctx, ownerUserID)
}
func (c *conversationDatabase) SetUserConversations(ctx context.Context, ownerUserID string, conversations []*relationtb.Conversation) error {
return c.tx.Transaction(ctx, func(ctx context.Context) error {
cache := c.cache.CloneConversationCache()
cache = cache.DelConversationVersionUserIDs(ownerUserID).
DelConversationNotNotifyMessageUserIDs(ownerUserID).
DelUserPinnedConversations(ownerUserID)
groupIDs := datautil.Distinct(datautil.Filter(conversations, func(e *relationtb.Conversation) (string, bool) {
return e.GroupID, e.GroupID != ""
}))
for _, groupID := range groupIDs {
cache = cache.DelSuperGroupRecvMsgNotNotifyUserIDs(groupID).DelSuperGroupRecvMsgNotNotifyUserIDsHash(groupID)
}
var conversationIDs []string
for _, conversation := range conversations {
conversationIDs = append(conversationIDs, conversation.ConversationID)
cache = cache.DelConversations(conversation.OwnerUserID, conversation.ConversationID)
}
existConversations, err := c.conversationDB.Find(ctx, ownerUserID, conversationIDs)
if err != nil {
return err
}
if len(existConversations) > 0 {
for _, conversation := range conversations {
err = c.conversationDB.Update(ctx, conversation)
if err != nil {
return err
}
}
}
var existConversationIDs []string
for _, conversation := range existConversations {
existConversationIDs = append(existConversationIDs, conversation.ConversationID)
}
var notExistConversations []*relationtb.Conversation
for _, conversation := range conversations {
if !datautil.Contain(conversation.ConversationID, existConversationIDs...) {
notExistConversations = append(notExistConversations, conversation)
}
}
if len(notExistConversations) > 0 {
err = c.conversationDB.Create(ctx, notExistConversations)
if err != nil {
return err
}
cache = cache.DelConversationIDs(ownerUserID).
DelUserConversationIDsHash(ownerUserID).
DelConversationNotReceiveMessageUserIDs(datautil.Slice(notExistConversations, func(e *relationtb.Conversation) string { return e.ConversationID })...)
}
return cache.ChainExecDel(ctx)
})
}
// func (c *conversationDatabase) FindRecvMsgNotNotifyUserIDs(ctx context.Context, groupID string) ([]string, error) {
// return c.cache.GetSuperGroupRecvMsgNotNotifyUserIDs(ctx, groupID)
//}
func (c *conversationDatabase) CreateGroupChatConversation(ctx context.Context, groupID string, userIDs []string, conversation *relationtb.Conversation) error {
return c.tx.Transaction(ctx, func(ctx context.Context) error {
cache := c.cache.CloneConversationCache()
conversationID := conversation.ConversationID
existConversationUserIDs, err := c.conversationDB.FindUserID(ctx, userIDs, []string{conversationID})
if err != nil {
return err
}
notExistUserIDs := stringutil.DifferenceString(userIDs, existConversationUserIDs)
var conversations []*relationtb.Conversation
for _, v := range notExistUserIDs {
conversation := relationtb.Conversation{
ConversationType: conversation.ConversationType, GroupID: groupID, OwnerUserID: v, ConversationID: conversationID,
// the remaining fields default to the values of the template conversation
RecvMsgOpt: conversation.RecvMsgOpt, IsPinned: conversation.IsPinned, IsPrivateChat: conversation.IsPrivateChat,
BurnDuration: conversation.BurnDuration, GroupAtType: conversation.GroupAtType, AttachedInfo: conversation.AttachedInfo,
Ex: conversation.Ex, MaxSeq: conversation.MaxSeq, MinSeq: conversation.MinSeq, CreateTime: conversation.CreateTime,
MsgDestructTime: conversation.MsgDestructTime, IsMsgDestruct: conversation.IsMsgDestruct, LatestMsgDestructTime: conversation.LatestMsgDestructTime,
}
conversations = append(conversations, &conversation)
cache = cache.DelConversations(v, conversationID).DelConversationNotReceiveMessageUserIDs(conversationID)
}
cache = cache.DelConversationIDs(notExistUserIDs...).DelUserConversationIDsHash(notExistUserIDs...)
if len(conversations) > 0 {
err = c.conversationDB.Create(ctx, conversations)
if err != nil {
return err
}
}
_, err = c.conversationDB.UpdateByMap(ctx, existConversationUserIDs, conversationID, map[string]any{"max_seq": conversation.MaxSeq})
if err != nil {
return err
}
for _, v := range existConversationUserIDs {
cache = cache.DelConversations(v, conversationID)
}
return cache.ChainExecDel(ctx)
})
}
func (c *conversationDatabase) GetConversationIDs(ctx context.Context, userID string) ([]string, error) {
return c.cache.GetUserConversationIDs(ctx, userID)
}
func (c *conversationDatabase) GetUserConversationIDsHash(ctx context.Context, ownerUserID string) (hash uint64, err error) {
return c.cache.GetUserConversationIDsHash(ctx, ownerUserID)
}
func (c *conversationDatabase) GetAllConversationIDs(ctx context.Context) ([]string, error) {
return c.conversationDB.GetAllConversationIDs(ctx)
}
func (c *conversationDatabase) GetAllConversationIDsNumber(ctx context.Context) (int64, error) {
return c.conversationDB.GetAllConversationIDsNumber(ctx)
}
func (c *conversationDatabase) PageConversationIDs(ctx context.Context, pagination pagination.Pagination) ([]string, error) {
return c.conversationDB.PageConversationIDs(ctx, pagination)
}
func (c *conversationDatabase) GetConversationsByConversationID(ctx context.Context, conversationIDs []string) ([]*relationtb.Conversation, error) {
return c.conversationDB.GetConversationsByConversationID(ctx, conversationIDs)
}
func (c *conversationDatabase) GetConversationIDsNeedDestruct(ctx context.Context) ([]*relationtb.Conversation, error) {
return c.conversationDB.GetConversationIDsNeedDestruct(ctx)
}
func (c *conversationDatabase) GetConversationNotReceiveMessageUserIDs(ctx context.Context, conversationID string) ([]string, error) {
return c.cache.GetConversationNotReceiveMessageUserIDs(ctx, conversationID)
}
func (c *conversationDatabase) FindConversationUserVersion(ctx context.Context, userID string, version uint, limit int) (*relationtb.VersionLog, error) {
return c.conversationDB.FindConversationUserVersion(ctx, userID, version, limit)
}
func (c *conversationDatabase) FindMaxConversationUserVersionCache(ctx context.Context, userID string) (*relationtb.VersionLog, error) {
return c.cache.FindMaxConversationUserVersion(ctx, userID)
}
func (c *conversationDatabase) GetOwnerConversation(ctx context.Context, ownerUserID string, pagination pagination.Pagination) (int64, []*relationtb.Conversation, error) {
conversationIDs, err := c.cache.GetUserConversationIDs(ctx, ownerUserID)
if err != nil {
return 0, nil, err
}
findConversationIDs := datautil.Paginate(conversationIDs, int(pagination.GetPageNumber()), int(pagination.GetShowNumber()))
conversations := make([]*relationtb.Conversation, 0, len(findConversationIDs))
for _, conversationID := range findConversationIDs {
conversation, err := c.cache.GetConversation(ctx, ownerUserID, conversationID)
if err != nil {
return 0, nil, err
}
conversations = append(conversations, conversation)
}
return int64(len(conversationIDs)), conversations, nil
}
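// Usage sketch (illustrative only; ctx and db are assumed): page through a user's
// conversations. Any value satisfying pagination.Pagination works; the protocol's
// sdkws.RequestPagination is used here as an assumed example.
//
//	total, convs, err := db.GetOwnerConversation(ctx, "userA", &sdkws.RequestPagination{PageNumber: 1, ShowNumber: 20})
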
func (c *conversationDatabase) GetNotNotifyConversationIDs(ctx context.Context, userID string) ([]string, error) {
conversationIDs, err := c.cache.GetUserNotNotifyConversationIDs(ctx, userID)
if err != nil {
return nil, err
}
return conversationIDs, nil
}
func (c *conversationDatabase) GetPinnedConversationIDs(ctx context.Context, userID string) ([]string, error) {
conversationIDs, err := c.cache.GetPinnedConversationIDs(ctx, userID)
if err != nil {
return nil, err
}
return conversationIDs, nil
}
func (c *conversationDatabase) FindRandConversation(ctx context.Context, ts int64, limit int) ([]*relationtb.Conversation, error) {
return c.conversationDB.FindRandConversation(ctx, ts, limit)
}
func (c *conversationDatabase) DeleteUsersConversations(ctx context.Context, userID string, conversationIDs []string) (err error) {
return c.tx.Transaction(ctx, func(ctx context.Context) error {
err = c.conversationDB.DeleteUsersConversations(ctx, userID, conversationIDs)
if err != nil {
return err
}
cache := c.cache.CloneConversationCache()
cache = cache.DelConversations(userID, conversationIDs...).
DelConversationVersionUserIDs(userID).
DelConversationIDs(userID).
DelUserConversationIDsHash(userID).
DelConversationNotNotifyMessageUserIDs(userID).
DelUserPinnedConversations(userID)
return cache.ChainExecDel(ctx)
})
}

View File

@@ -0,0 +1,15 @@
// Copyright © 2024 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package controller // import "git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/controller"

View File

@@ -0,0 +1,406 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package controller
import (
"context"
"fmt"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database/mgo"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/protocol/constant"
"github.com/openimsdk/tools/db/pagination"
"github.com/openimsdk/tools/db/tx"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/datautil"
)
type FriendDatabase interface {
// CheckIn checks if user2 is in user1's friend list (inUser1Friends==true) and if user1 is in user2's friend list (inUser2Friends==true)
CheckIn(ctx context.Context, user1, user2 string) (inUser1Friends bool, inUser2Friends bool, err error)
// AddFriendRequest adds or updates a friend request
AddFriendRequest(ctx context.Context, fromUserID, toUserID string, reqMsg string, ex string) (err error)
// BecomeFriends first checks if the users are already in the friends model; if not, it inserts them as friends
BecomeFriends(ctx context.Context, ownerUserID string, friendUserIDs []string, addSource int32) (err error)
// RefuseFriendRequest refuses a friend request
RefuseFriendRequest(ctx context.Context, friendRequest *model.FriendRequest) (err error)
// AgreeFriendRequest accepts a friend request
AgreeFriendRequest(ctx context.Context, friendRequest *model.FriendRequest) (err error)
// Delete removes a friend or friends from the owner's friend list
Delete(ctx context.Context, ownerUserID string, friendUserIDs []string) (err error)
// UpdateRemark updates the remark for a friend
UpdateRemark(ctx context.Context, ownerUserID, friendUserID, remark string) (err error)
// PageOwnerFriends retrieves the friend list of ownerUserID with pagination
PageOwnerFriends(ctx context.Context, ownerUserID string, pagination pagination.Pagination) (total int64, friends []*model.Friend, err error)
// PageInWhoseFriends finds the users who have friendUserID in their friend list with pagination
PageInWhoseFriends(ctx context.Context, friendUserID string, pagination pagination.Pagination) (total int64, friends []*model.Friend, err error)
// PageFriendRequestFromMe retrieves the friend requests sent by the user with pagination
PageFriendRequestFromMe(ctx context.Context, userID string, handleResults []int, pagination pagination.Pagination) (total int64, friends []*model.FriendRequest, err error)
// PageFriendRequestToMe retrieves the friend requests received by the user with pagination
PageFriendRequestToMe(ctx context.Context, userID string, handleResults []int, pagination pagination.Pagination) (total int64, friends []*model.FriendRequest, err error)
// FindFriendsWithError fetches specified friends of a user and returns an error if any do not exist
FindFriendsWithError(ctx context.Context, ownerUserID string, friendUserIDs []string) (friends []*model.Friend, err error)
// FindFriendUserIDs retrieves the friend IDs of a user
FindFriendUserIDs(ctx context.Context, ownerUserID string) (friendUserIDs []string, err error)
// FindBothFriendRequests finds friend requests sent and received
FindBothFriendRequests(ctx context.Context, fromUserID, toUserID string) (friends []*model.FriendRequest, err error)
// UpdateFriends updates fields for friends
UpdateFriends(ctx context.Context, ownerUserID string, friendUserIDs []string, val map[string]any) (err error)
//FindSortFriendUserIDs(ctx context.Context, ownerUserID string) ([]string, error)
FindFriendIncrVersion(ctx context.Context, ownerUserID string, version uint, limit int) (*model.VersionLog, error)
FindMaxFriendVersionCache(ctx context.Context, ownerUserID string) (*model.VersionLog, error)
FindFriendUserID(ctx context.Context, friendUserID string) ([]string, error)
OwnerIncrVersion(ctx context.Context, ownerUserID string, friendUserIDs []string, state int32) error
GetUnhandledCount(ctx context.Context, userID string, ts int64) (int64, error)
}
type friendDatabase struct {
friend database.Friend
friendRequest database.FriendRequest
tx tx.Tx
cache cache.FriendCache
}
func NewFriendDatabase(friend database.Friend, friendRequest database.FriendRequest, cache cache.FriendCache, tx tx.Tx) FriendDatabase {
return &friendDatabase{friend: friend, friendRequest: friendRequest, cache: cache, tx: tx}
}
// CheckIn verifies if user2 is in user1's friend list (inUser1Friends returns true) and
// if user1 is in user2's friend list (inUser2Friends returns true).
func (f *friendDatabase) CheckIn(ctx context.Context, userID1, userID2 string) (inUser1Friends bool, inUser2Friends bool, err error) {
// Retrieve friend IDs of userID1 from the cache
userID1FriendIDs, err := f.cache.GetFriendIDs(ctx, userID1)
if err != nil {
err = fmt.Errorf("error retrieving friend IDs for user %s: %w", userID1, err)
return
}
// Retrieve friend IDs of userID2 from the cache
userID2FriendIDs, err := f.cache.GetFriendIDs(ctx, userID2)
if err != nil {
err = fmt.Errorf("error retrieving friend IDs for user %s: %w", userID2, err)
return
}
// Check if userID2 is in userID1's friend list and vice versa
inUser1Friends = datautil.Contain(userID2, userID1FriendIDs...)
inUser2Friends = datautil.Contain(userID1, userID2FriendIDs...)
return inUser1Friends, inUser2Friends, nil
}
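// Usage sketch (illustrative only; ctx and db are assumed): a mutual-friendship check.
//
//	inMine, inTheirs, err := db.CheckIn(ctx, "userA", "userB")
//	mutual := err == nil && inMine && inTheirs
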
// AddFriendRequest adds or updates a friend request.
func (f *friendDatabase) AddFriendRequest(ctx context.Context, fromUserID, toUserID string, reqMsg string, ex string) (err error) {
return f.tx.Transaction(ctx, func(ctx context.Context) error {
_, err := f.friendRequest.Take(ctx, fromUserID, toUserID)
switch {
case err == nil:
m := make(map[string]any, 1)
m["handle_result"] = 0
m["handle_msg"] = ""
m["req_msg"] = reqMsg
m["ex"] = ex
m["create_time"] = time.Now()
return f.friendRequest.UpdateByMap(ctx, fromUserID, toUserID, m)
case mgo.IsNotFound(err):
return f.friendRequest.Create(
ctx,
[]*model.FriendRequest{{FromUserID: fromUserID, ToUserID: toUserID, ReqMsg: reqMsg, Ex: ex, CreateTime: time.Now(), HandleTime: time.Unix(0, 0)}},
)
default:
return err
}
})
}
// BecomeFriends first checks whether each target user is already in the friend list (being present or absent is not an error); users not yet in the list are then inserted as friends in both directions.
func (f *friendDatabase) BecomeFriends(ctx context.Context, ownerUserID string, friendUserIDs []string, addSource int32) (err error) {
return f.tx.Transaction(ctx, func(ctx context.Context) error {
cache := f.cache.CloneFriendCache()
// user find friends
myFriends, err := f.friend.FindFriends(ctx, ownerUserID, friendUserIDs)
if err != nil {
return err
}
addOwners, err := f.friend.FindReversalFriends(ctx, ownerUserID, friendUserIDs)
if err != nil {
return err
}
opUserID := mcontext.GetOpUserID(ctx)
friends := make([]*model.Friend, 0, len(friendUserIDs)*2)
myFriendsSet := datautil.SliceSetAny(myFriends, func(friend *model.Friend) string {
return friend.FriendUserID
})
addOwnersSet := datautil.SliceSetAny(addOwners, func(friend *model.Friend) string {
return friend.OwnerUserID
})
newMyFriendIDs := make([]string, 0, len(friendUserIDs))
newMyOwnerIDs := make([]string, 0, len(friendUserIDs))
for _, userID := range friendUserIDs {
if ownerUserID == userID {
continue
}
if _, ok := myFriendsSet[userID]; !ok {
myFriendsSet[userID] = struct{}{}
newMyFriendIDs = append(newMyFriendIDs, userID)
friends = append(friends, &model.Friend{OwnerUserID: ownerUserID, FriendUserID: userID, AddSource: addSource, OperatorUserID: opUserID})
}
if _, ok := addOwnersSet[userID]; !ok {
addOwnersSet[userID] = struct{}{}
newMyOwnerIDs = append(newMyOwnerIDs, userID)
friends = append(friends, &model.Friend{OwnerUserID: userID, FriendUserID: ownerUserID, AddSource: addSource, OperatorUserID: opUserID})
}
}
if len(friends) == 0 {
return nil
}
err = f.friend.Create(ctx, friends)
if err != nil {
return err
}
cache = cache.DelFriendIDs(ownerUserID).DelMaxFriendVersion(ownerUserID)
if len(newMyFriendIDs) > 0 {
cache = cache.DelFriendIDs(newMyFriendIDs...)
cache = cache.DelFriends(ownerUserID, newMyFriendIDs).DelMaxFriendVersion(newMyFriendIDs...)
}
if len(newMyOwnerIDs) > 0 {
cache = cache.DelFriendIDs(newMyOwnerIDs...)
cache = cache.DelOwner(ownerUserID, newMyOwnerIDs).DelMaxFriendVersion(newMyOwnerIDs...)
}
return cache.ChainExecDel(ctx)
})
}
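// Usage sketch (illustrative only; ctx and db are assumed): make userA friends with two users
// in both directions. The addSource constant shown is the one referenced elsewhere in this
// file; a different source constant may be more appropriate for imports.
//
//	err := db.BecomeFriends(ctx, "userA", []string{"userB", "userC"}, int32(constant.BecomeFriendByApply))
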
// RefuseFriendRequest rejects a friend request. It first checks for an existing, unprocessed request.
// If no such request exists, it returns an error. Otherwise, it marks the request as refused.
func (f *friendDatabase) RefuseFriendRequest(ctx context.Context, friendRequest *model.FriendRequest) error {
// Attempt to retrieve the friend request from the database.
fr, err := f.friendRequest.Take(ctx, friendRequest.FromUserID, friendRequest.ToUserID)
if err != nil {
return fmt.Errorf("failed to retrieve friend request from %s to %s: %w", friendRequest.FromUserID, friendRequest.ToUserID, err)
}
// Check if the friend request has already been handled.
if fr.HandleResult != 0 {
return fmt.Errorf("friend request from %s to %s has already been processed", friendRequest.FromUserID, friendRequest.ToUserID)
}
// Log the action of refusing the friend request for debugging and auditing purposes.
log.ZDebug(ctx, "Refusing friend request", map[string]interface{}{
"DB_FriendRequest": fr,
"Arg_FriendRequest": friendRequest,
})
// Mark the friend request as refused and update the handle time.
friendRequest.HandleResult = constant.FriendResponseRefuse
friendRequest.HandleTime = time.Now()
if err := f.friendRequest.Update(ctx, friendRequest); err != nil {
return fmt.Errorf("failed to update friend request from %s to %s as refused: %w", friendRequest.FromUserID, friendRequest.ToUserID, err)
}
return nil
}
// AgreeFriendRequest accepts a friend request. It first checks for an existing, unprocessed request.
func (f *friendDatabase) AgreeFriendRequest(ctx context.Context, friendRequest *model.FriendRequest) (err error) {
return f.tx.Transaction(ctx, func(ctx context.Context) error {
now := time.Now()
fr, err := f.friendRequest.Take(ctx, friendRequest.FromUserID, friendRequest.ToUserID)
if err != nil {
return err
}
if fr.HandleResult != 0 {
return errs.ErrArgs.WrapMsg("the friend request has been processed")
}
friendRequest.HandlerUserID = mcontext.GetOpUserID(ctx)
friendRequest.HandleResult = constant.FriendResponseAgree
friendRequest.HandleTime = now
err = f.friendRequest.Update(ctx, friendRequest)
if err != nil {
return err
}
fr2, err := f.friendRequest.Take(ctx, friendRequest.ToUserID, friendRequest.FromUserID)
if err == nil && fr2.HandleResult == constant.FriendResponseNotHandle {
fr2.HandlerUserID = mcontext.GetOpUserID(ctx)
fr2.HandleResult = constant.FriendResponseAgree
fr2.HandleTime = now
err = f.friendRequest.Update(ctx, fr2)
if err != nil {
return err
}
} else if err != nil && (!mgo.IsNotFound(err)) {
return err
}
exists, err := f.friend.FindUserState(ctx, friendRequest.FromUserID, friendRequest.ToUserID)
if err != nil {
return err
}
existsMap := datautil.SliceSet(datautil.Slice(exists, func(friend *model.Friend) [2]string {
return [...]string{friend.OwnerUserID, friend.FriendUserID} // My - Friend
}))
var adds []*model.Friend
if _, ok := existsMap[[...]string{friendRequest.ToUserID, friendRequest.FromUserID}]; !ok { // My - Friend
adds = append(
adds,
&model.Friend{
OwnerUserID: friendRequest.ToUserID,
FriendUserID: friendRequest.FromUserID,
AddSource: int32(constant.BecomeFriendByApply),
OperatorUserID: friendRequest.FromUserID,
},
)
}
if _, ok := existsMap[[...]string{friendRequest.FromUserID, friendRequest.ToUserID}]; !ok { // My - Friend
adds = append(
adds,
&model.Friend{
OwnerUserID: friendRequest.FromUserID,
FriendUserID: friendRequest.ToUserID,
AddSource: int32(constant.BecomeFriendByApply),
OperatorUserID: friendRequest.FromUserID,
},
)
}
if len(adds) > 0 {
if err := f.friend.Create(ctx, adds); err != nil {
return err
}
}
return f.cache.DelFriendIDs(friendRequest.ToUserID, friendRequest.FromUserID).DelMaxFriendVersion(friendRequest.ToUserID, friendRequest.FromUserID).ChainExecDel(ctx)
})
}
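// Flow sketch (illustrative only; ctx and db are assumed): userA sends a request, userB agrees.
//
//	_ = db.AddFriendRequest(ctx, "userA", "userB", "hi, this is A", "")
//	// ...later, handled from userB's side:
//	err := db.AgreeFriendRequest(ctx, &model.FriendRequest{FromUserID: "userA", ToUserID: "userB"})
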
// Delete removes a friend relationship. It is assumed that the external caller has verified the friendship status.
func (f *friendDatabase) Delete(ctx context.Context, ownerUserID string, friendUserIDs []string) (err error) {
if err := f.friend.Delete(ctx, ownerUserID, friendUserIDs); err != nil {
return err
}
// build a fresh slice so append cannot mutate the caller's friendUserIDs backing array
userIDs := append([]string{ownerUserID}, friendUserIDs...)
return f.cache.DelFriendIDs(userIDs...).DelMaxFriendVersion(userIDs...).ChainExecDel(ctx)
}
// UpdateRemark updates the remark for a friend. Zero value for remark is also supported.
func (f *friendDatabase) UpdateRemark(ctx context.Context, ownerUserID, friendUserID, remark string) (err error) {
if err := f.friend.UpdateRemark(ctx, ownerUserID, friendUserID, remark); err != nil {
return err
}
return f.cache.DelFriend(ownerUserID, friendUserID).DelMaxFriendVersion(ownerUserID).ChainExecDel(ctx)
}
// PageOwnerFriends retrieves the list of friends for the ownerUserID. It does not return an error if the result is empty.
func (f *friendDatabase) PageOwnerFriends(ctx context.Context, ownerUserID string, pagination pagination.Pagination) (total int64, friends []*model.Friend, err error) {
return f.friend.FindOwnerFriends(ctx, ownerUserID, pagination)
}
// PageInWhoseFriends identifies in whose friend lists the friendUserID appears.
func (f *friendDatabase) PageInWhoseFriends(ctx context.Context, friendUserID string, pagination pagination.Pagination) (total int64, friends []*model.Friend, err error) {
return f.friend.FindInWhoseFriends(ctx, friendUserID, pagination)
}
// PageFriendRequestFromMe retrieves friend requests sent by me. It does not return an error if the result is empty.
func (f *friendDatabase) PageFriendRequestFromMe(ctx context.Context, userID string, handleResults []int, pagination pagination.Pagination) (total int64, friends []*model.FriendRequest, err error) {
return f.friendRequest.FindFromUserID(ctx, userID, handleResults, pagination)
}
// PageFriendRequestToMe retrieves friend requests received by me. It does not return an error if the result is empty.
func (f *friendDatabase) PageFriendRequestToMe(ctx context.Context, userID string, handleResults []int, pagination pagination.Pagination) (total int64, friends []*model.FriendRequest, err error) {
return f.friendRequest.FindToUserID(ctx, userID, handleResults, pagination)
}
// FindFriendsWithError retrieves specified friends' information for ownerUserID. Returns an error if any friend does not exist.
func (f *friendDatabase) FindFriendsWithError(ctx context.Context, ownerUserID string, friendUserIDs []string) (friends []*model.Friend, err error) {
friends, err = f.friend.FindFriends(ctx, ownerUserID, friendUserIDs)
if err != nil {
return
}
return
}
func (f *friendDatabase) FindFriendUserIDs(ctx context.Context, ownerUserID string) (friendUserIDs []string, err error) {
return f.cache.GetFriendIDs(ctx, ownerUserID)
}
func (f *friendDatabase) FindBothFriendRequests(ctx context.Context, fromUserID, toUserID string) (friends []*model.FriendRequest, err error) {
return f.friendRequest.FindBothFriendRequests(ctx, fromUserID, toUserID)
}
func (f *friendDatabase) UpdateFriends(ctx context.Context, ownerUserID string, friendUserIDs []string, val map[string]any) (err error) {
if len(val) == 0 {
return nil
}
return f.tx.Transaction(ctx, func(ctx context.Context) error {
if err := f.friend.UpdateFriends(ctx, ownerUserID, friendUserIDs, val); err != nil {
return err
}
return f.cache.DelFriends(ownerUserID, friendUserIDs).DelMaxFriendVersion(ownerUserID).ChainExecDel(ctx)
})
}
//func (f *friendDatabase) FindSortFriendUserIDs(ctx context.Context, ownerUserID string) ([]string, error) {
// return f.cache.FindSortFriendUserIDs(ctx, ownerUserID)
//}
func (f *friendDatabase) FindFriendIncrVersion(ctx context.Context, ownerUserID string, version uint, limit int) (*model.VersionLog, error) {
return f.friend.FindIncrVersion(ctx, ownerUserID, version, limit)
}
func (f *friendDatabase) FindMaxFriendVersionCache(ctx context.Context, ownerUserID string) (*model.VersionLog, error) {
return f.cache.FindMaxFriendVersion(ctx, ownerUserID)
}
func (f *friendDatabase) FindFriendUserID(ctx context.Context, friendUserID string) ([]string, error) {
return f.friend.FindFriendUserID(ctx, friendUserID)
}
//func (f *friendDatabase) SearchFriend(ctx context.Context, ownerUserID, keyword string, pagination pagination.Pagination) (int64, []*model.Friend, error) {
// return f.friend.SearchFriend(ctx, ownerUserID, keyword, pagination)
//}
func (f *friendDatabase) OwnerIncrVersion(ctx context.Context, ownerUserID string, friendUserIDs []string, state int32) error {
if err := f.friend.IncrVersion(ctx, ownerUserID, friendUserIDs, state); err != nil {
return err
}
return f.cache.DelMaxFriendVersion(ownerUserID).ChainExecDel(ctx)
}
func (f *friendDatabase) GetUnhandledCount(ctx context.Context, userID string, ts int64) (int64, error) {
return f.friendRequest.GetUnhandledCount(ctx, userID, ts)
}

View File

@@ -0,0 +1,574 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package controller
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
redis2 "git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/redis"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/common"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/protocol/constant"
"github.com/openimsdk/tools/db/pagination"
"github.com/openimsdk/tools/db/tx"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/utils/datautil"
"github.com/redis/go-redis/v9"
)
type GroupDatabase interface {
// CreateGroup creates new groups along with their members.
CreateGroup(ctx context.Context, groups []*model.Group, groupMembers []*model.GroupMember) error
// TakeGroup retrieves a single group by its ID.
TakeGroup(ctx context.Context, groupID string) (group *model.Group, err error)
// FindGroup retrieves multiple groups by their IDs.
FindGroup(ctx context.Context, groupIDs []string) (groups []*model.Group, err error)
// SearchGroup searches for groups based on a keyword and pagination settings, returns total count and groups.
SearchGroup(ctx context.Context, keyword string, pagination pagination.Pagination) (int64, []*model.Group, error)
// UpdateGroup updates the properties of a group identified by its ID.
UpdateGroup(ctx context.Context, groupID string, data map[string]any) error
// DismissGroup disbands a group and optionally removes its members based on the deleteMember flag.
DismissGroup(ctx context.Context, groupID string, deleteMember bool) error
// TakeGroupMember retrieves a specific group member by group ID and user ID.
TakeGroupMember(ctx context.Context, groupID string, userID string) (groupMember *model.GroupMember, err error)
// TakeGroupOwner retrieves the owner of a group by group ID.
TakeGroupOwner(ctx context.Context, groupID string) (*model.GroupMember, error)
// FindGroupMembers retrieves members of a group filtered by user IDs.
FindGroupMembers(ctx context.Context, groupID string, userIDs []string) (groupMembers []*model.GroupMember, err error)
// FindGroupMemberUser retrieves groups that a user is a member of, filtered by group IDs.
FindGroupMemberUser(ctx context.Context, groupIDs []string, userID string) (groupMembers []*model.GroupMember, err error)
// FindGroupMemberRoleLevels retrieves group members filtered by their role levels within a group.
FindGroupMemberRoleLevels(ctx context.Context, groupID string, roleLevels []int32) (groupMembers []*model.GroupMember, err error)
// FindGroupMemberAll retrieves all members of a group.
FindGroupMemberAll(ctx context.Context, groupID string) (groupMembers []*model.GroupMember, err error)
// FindGroupsOwner retrieves the owners for multiple groups.
FindGroupsOwner(ctx context.Context, groupIDs []string) ([]*model.GroupMember, error)
// FindGroupMemberUserID retrieves the user IDs of all members in a group.
FindGroupMemberUserID(ctx context.Context, groupID string) ([]string, error)
// FindGroupMemberNum retrieves the number of members in a group.
FindGroupMemberNum(ctx context.Context, groupID string) (uint32, error)
// FindUserManagedGroupID retrieves group IDs managed by a user.
FindUserManagedGroupID(ctx context.Context, userID string) (groupIDs []string, err error)
// PageGroupRequest paginates through group requests for specified groups.
PageGroupRequest(ctx context.Context, groupIDs []string, handleResults []int, pagination pagination.Pagination) (int64, []*model.GroupRequest, error)
// GetGroupRoleLevelMemberIDs retrieves user IDs of group members with a specific role level.
GetGroupRoleLevelMemberIDs(ctx context.Context, groupID string, roleLevel int32) ([]string, error)
// PageGetJoinGroup paginates through groups that a user has joined.
PageGetJoinGroup(ctx context.Context, userID string, pagination pagination.Pagination) (total int64, totalGroupMembers []*model.GroupMember, err error)
// PageGetGroupMember paginates through members of a group.
PageGetGroupMember(ctx context.Context, groupID string, pagination pagination.Pagination) (total int64, totalGroupMembers []*model.GroupMember, err error)
// SearchGroupMember searches for group members based on a keyword, group ID, and pagination settings.
SearchGroupMember(ctx context.Context, keyword string, groupID string, pagination pagination.Pagination) (int64, []*model.GroupMember, error)
// SearchGroupMemberByFields searches for group members by multiple independent fields: nickname, userID (account), and phone
SearchGroupMemberByFields(ctx context.Context, groupID string, nickname, userID, phone string, pagination pagination.Pagination) (int64, []*model.GroupMember, error)
// HandlerGroupRequest processes a group join request with a specified result.
HandlerGroupRequest(ctx context.Context, groupID string, userID string, handledMsg string, handleResult int32, member *model.GroupMember) error
// DeleteGroupMember removes specified users from a group.
DeleteGroupMember(ctx context.Context, groupID string, userIDs []string) error
// MapGroupMemberUserID maps group IDs to their members' simplified user IDs.
MapGroupMemberUserID(ctx context.Context, groupIDs []string) (map[string]*common.GroupSimpleUserID, error)
// MapGroupMemberNum maps group IDs to their member count.
MapGroupMemberNum(ctx context.Context, groupIDs []string) (map[string]uint32, error)
// TransferGroupOwner transfers the ownership of a group to another user.
TransferGroupOwner(ctx context.Context, groupID string, oldOwnerUserID, newOwnerUserID string, roleLevel int32) error
// UpdateGroupMember updates properties of a group member.
UpdateGroupMember(ctx context.Context, groupID string, userID string, data map[string]any) error
// UpdateGroupMembers batch updates properties of group members.
UpdateGroupMembers(ctx context.Context, data []*common.BatchUpdateGroupMember) error
// CreateGroupRequest creates new group join requests.
CreateGroupRequest(ctx context.Context, requests []*model.GroupRequest) error
// TakeGroupRequest retrieves a specific group join request.
TakeGroupRequest(ctx context.Context, groupID string, userID string) (*model.GroupRequest, error)
// FindGroupRequests retrieves multiple group join requests.
FindGroupRequests(ctx context.Context, groupID string, userIDs []string) ([]*model.GroupRequest, error)
// PageGroupRequestUser paginates through group join requests made by a user.
PageGroupRequestUser(ctx context.Context, userID string, groupIDs []string, handleResults []int, pagination pagination.Pagination) (int64, []*model.GroupRequest, error)
// CountTotal counts the total number of groups as of a certain date.
CountTotal(ctx context.Context, before *time.Time) (count int64, err error)
// CountRangeEverydayTotal counts the daily group creation total within a specified date range.
CountRangeEverydayTotal(ctx context.Context, start time.Time, end time.Time) (map[string]int64, error)
// DeleteGroupMemberHash deletes the hash entries for group members in specified groups.
DeleteGroupMemberHash(ctx context.Context, groupIDs []string) error
FindMemberIncrVersion(ctx context.Context, groupID string, version uint, limit int) (*model.VersionLog, error)
BatchFindMemberIncrVersion(ctx context.Context, groupIDs []string, versions []uint64, limits []int) (map[string]*model.VersionLog, error)
FindJoinIncrVersion(ctx context.Context, userID string, version uint, limit int) (*model.VersionLog, error)
MemberGroupIncrVersion(ctx context.Context, groupID string, userIDs []string, state int32) error
//FindSortGroupMemberUserIDs(ctx context.Context, groupID string) ([]string, error)
//FindSortJoinGroupIDs(ctx context.Context, userID string) ([]string, error)
FindMaxGroupMemberVersionCache(ctx context.Context, groupID string) (*model.VersionLog, error)
BatchFindMaxGroupMemberVersionCache(ctx context.Context, groupIDs []string) (map[string]*model.VersionLog, error)
FindMaxJoinGroupVersionCache(ctx context.Context, userID string) (*model.VersionLog, error)
SearchJoinGroup(ctx context.Context, userID string, keyword string, pagination pagination.Pagination) (int64, []*model.Group, error)
FindJoinGroupID(ctx context.Context, userID string) ([]string, error)
GetGroupApplicationUnhandledCount(ctx context.Context, groupIDs []string, ts int64) (int64, error)
}
func NewGroupDatabase(
rdb redis.UniversalClient,
localCache *config.LocalCache,
groupDB database.Group,
groupMemberDB database.GroupMember,
groupRequestDB database.GroupRequest,
ctxTx tx.Tx,
groupHash cache.GroupHash,
) GroupDatabase {
return &groupDatabase{
groupDB: groupDB,
groupMemberDB: groupMemberDB,
groupRequestDB: groupRequestDB,
ctxTx: ctxTx,
cache: redis2.NewGroupCacheRedis(rdb, localCache, groupDB, groupMemberDB, groupRequestDB, groupHash),
}
}
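// Note: unlike NewConversationDatabase and NewFriendDatabase, which receive a ready-made cache,
// NewGroupDatabase assembles its Redis-backed cache internally from the raw client and the
// storage layers. Illustrative call (the surrounding variables are placeholders for whatever
// the service wires up):
//
//	groupDatabase := NewGroupDatabase(rdb, &cfg.LocalCache, groupDB, groupMemberDB, groupRequestDB, mongoCli.GetTx(), groupHash)
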
type groupDatabase struct {
groupDB database.Group
groupMemberDB database.GroupMember
groupRequestDB database.GroupRequest
ctxTx tx.Tx
cache cache.GroupCache
}
func (g *groupDatabase) FindJoinGroupID(ctx context.Context, userID string) ([]string, error) {
return g.cache.GetJoinedGroupIDs(ctx, userID)
}
func (g *groupDatabase) FindGroupMembers(ctx context.Context, groupID string, userIDs []string) ([]*model.GroupMember, error) {
return g.cache.GetGroupMembersInfo(ctx, groupID, userIDs)
}
func (g *groupDatabase) FindGroupMemberUser(ctx context.Context, groupIDs []string, userID string) ([]*model.GroupMember, error) {
return g.cache.FindGroupMemberUser(ctx, groupIDs, userID)
}
func (g *groupDatabase) FindGroupMemberRoleLevels(ctx context.Context, groupID string, roleLevels []int32) ([]*model.GroupMember, error) {
return g.cache.GetGroupRolesLevelMemberInfo(ctx, groupID, roleLevels)
}
func (g *groupDatabase) FindGroupMemberAll(ctx context.Context, groupID string) ([]*model.GroupMember, error) {
return g.cache.GetAllGroupMembersInfo(ctx, groupID)
}
func (g *groupDatabase) FindGroupsOwner(ctx context.Context, groupIDs []string) ([]*model.GroupMember, error) {
return g.cache.GetGroupsOwner(ctx, groupIDs)
}
func (g *groupDatabase) GetGroupRoleLevelMemberIDs(ctx context.Context, groupID string, roleLevel int32) ([]string, error) {
return g.cache.GetGroupRoleLevelMemberIDs(ctx, groupID, roleLevel)
}
func (g *groupDatabase) CreateGroup(ctx context.Context, groups []*model.Group, groupMembers []*model.GroupMember) error {
if len(groups)+len(groupMembers) == 0 {
return nil
}
return g.ctxTx.Transaction(ctx, func(ctx context.Context) error {
c := g.cache.CloneGroupCache()
if len(groups) > 0 {
if err := g.groupDB.Create(ctx, groups); err != nil {
return err
}
for _, group := range groups {
c = c.DelGroupsInfo(group.GroupID).
DelGroupMembersHash(group.GroupID).
DelGroupsMemberNum(group.GroupID).
DelGroupMemberIDs(group.GroupID).
DelGroupAllRoleLevel(group.GroupID).
DelMaxGroupMemberVersion(group.GroupID)
}
}
if len(groupMembers) > 0 {
if err := g.groupMemberDB.Create(ctx, groupMembers); err != nil {
return err
}
for _, groupMember := range groupMembers {
c = c.DelGroupMembersHash(groupMember.GroupID).
DelGroupsMemberNum(groupMember.GroupID).
DelGroupMemberIDs(groupMember.GroupID).
DelJoinedGroupID(groupMember.UserID).
DelGroupMembersInfo(groupMember.GroupID, groupMember.UserID).
DelGroupAllRoleLevel(groupMember.GroupID).
DelMaxJoinGroupVersion(groupMember.UserID).
DelMaxGroupMemberVersion(groupMember.GroupID)
}
}
return c.ChainExecDel(ctx)
})
}
func (g *groupDatabase) FindGroupMemberUserID(ctx context.Context, groupID string) ([]string, error) {
return g.cache.GetGroupMemberIDs(ctx, groupID)
}
func (g *groupDatabase) FindGroupMemberNum(ctx context.Context, groupID string) (uint32, error) {
num, err := g.cache.GetGroupMemberNum(ctx, groupID)
if err != nil {
return 0, err
}
return uint32(num), nil
}
func (g *groupDatabase) TakeGroup(ctx context.Context, groupID string) (*model.Group, error) {
return g.cache.GetGroupInfo(ctx, groupID)
}
func (g *groupDatabase) FindGroup(ctx context.Context, groupIDs []string) ([]*model.Group, error) {
return g.cache.GetGroupsInfo(ctx, groupIDs)
}
func (g *groupDatabase) SearchGroup(ctx context.Context, keyword string, pagination pagination.Pagination) (int64, []*model.Group, error) {
return g.groupDB.Search(ctx, keyword, pagination)
}
func (g *groupDatabase) UpdateGroup(ctx context.Context, groupID string, data map[string]any) error {
return g.ctxTx.Transaction(ctx, func(ctx context.Context) error {
if err := g.groupDB.UpdateMap(ctx, groupID, data); err != nil {
return err
}
if err := g.groupMemberDB.MemberGroupIncrVersion(ctx, groupID, []string{""}, model.VersionStateUpdate); err != nil {
return err
}
return g.cache.CloneGroupCache().DelGroupsInfo(groupID).DelMaxGroupMemberVersion(groupID).ChainExecDel(ctx)
})
}
func (g *groupDatabase) DismissGroup(ctx context.Context, groupID string, deleteMember bool) error {
return g.ctxTx.Transaction(ctx, func(ctx context.Context) error {
c := g.cache.CloneGroupCache()
if err := g.groupDB.UpdateStatus(ctx, groupID, constant.GroupStatusDismissed); err != nil {
return err
}
if deleteMember {
userIDs, err := g.cache.GetGroupMemberIDs(ctx, groupID)
if err != nil {
return err
}
if err := g.groupMemberDB.Delete(ctx, groupID, nil); err != nil {
return err
}
c = c.DelJoinedGroupID(userIDs...).
DelGroupMemberIDs(groupID).
DelGroupsMemberNum(groupID).
DelGroupMembersHash(groupID).
DelGroupAllRoleLevel(groupID).
DelGroupMembersInfo(groupID, userIDs...).
DelMaxGroupMemberVersion(groupID).
DelMaxJoinGroupVersion(userIDs...)
for _, userID := range userIDs {
if err := g.groupMemberDB.JoinGroupIncrVersion(ctx, userID, []string{groupID}, model.VersionStateDelete); err != nil {
return err
}
}
} else {
if err := g.groupMemberDB.MemberGroupIncrVersion(ctx, groupID, []string{""}, model.VersionStateUpdate); err != nil {
return err
}
c = c.DelMaxGroupMemberVersion(groupID)
}
return c.DelGroupsInfo(groupID).ChainExecDel(ctx)
})
}
func (g *groupDatabase) TakeGroupMember(ctx context.Context, groupID string, userID string) (*model.GroupMember, error) {
return g.cache.GetGroupMemberInfo(ctx, groupID, userID)
}
func (g *groupDatabase) TakeGroupOwner(ctx context.Context, groupID string) (*model.GroupMember, error) {
return g.cache.GetGroupOwner(ctx, groupID)
}
func (g *groupDatabase) FindUserManagedGroupID(ctx context.Context, userID string) (groupIDs []string, err error) {
return g.groupMemberDB.FindUserManagedGroupID(ctx, userID)
}
func (g *groupDatabase) PageGroupRequest(ctx context.Context, groupIDs []string, handleResults []int, pagination pagination.Pagination) (int64, []*model.GroupRequest, error) {
return g.groupRequestDB.PageGroup(ctx, groupIDs, handleResults, pagination)
}
func (g *groupDatabase) PageGetJoinGroup(ctx context.Context, userID string, pagination pagination.Pagination) (total int64, totalGroupMembers []*model.GroupMember, err error) {
groupIDs, err := g.cache.GetJoinedGroupIDs(ctx, userID)
if err != nil {
return 0, nil, err
}
for _, groupID := range datautil.Paginate(groupIDs, int(pagination.GetPageNumber()), int(pagination.GetShowNumber())) {
groupMembers, err := g.cache.GetGroupMembersInfo(ctx, groupID, []string{userID})
if err != nil {
return 0, nil, err
}
totalGroupMembers = append(totalGroupMembers, groupMembers...)
}
return int64(len(groupIDs)), totalGroupMembers, nil
}
func (g *groupDatabase) PageGetGroupMember(ctx context.Context, groupID string, pagination pagination.Pagination) (total int64, totalGroupMembers []*model.GroupMember, err error) {
groupMemberIDs, err := g.cache.GetGroupMemberIDs(ctx, groupID)
if err != nil {
return 0, nil, err
}
pageIDs := datautil.Paginate(groupMemberIDs, int(pagination.GetPageNumber()), int(pagination.GetShowNumber()))
if len(pageIDs) == 0 {
return int64(len(groupMemberIDs)), nil, nil
}
members, err := g.cache.GetGroupMembersInfo(ctx, groupID, pageIDs)
if err != nil {
return 0, nil, err
}
return int64(len(groupMemberIDs)), members, nil
}
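// Usage sketch (illustrative only; ctx and db are assumed): fetch the second page of 50
// members for a group, using the protocol's sdkws.RequestPagination as an assumed example
// of a pagination.Pagination implementation.
//
//	total, members, err := db.PageGetGroupMember(ctx, "group123", &sdkws.RequestPagination{PageNumber: 2, ShowNumber: 50})
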
func (g *groupDatabase) SearchGroupMember(ctx context.Context, keyword string, groupID string, pagination pagination.Pagination) (int64, []*model.GroupMember, error) {
return g.groupMemberDB.SearchMember(ctx, keyword, groupID, pagination)
}
func (g *groupDatabase) SearchGroupMemberByFields(ctx context.Context, groupID string, nickname, userID, phone string, pagination pagination.Pagination) (int64, []*model.GroupMember, error) {
return g.groupMemberDB.SearchMemberByFields(ctx, groupID, nickname, userID, phone, pagination)
}
func (g *groupDatabase) HandlerGroupRequest(ctx context.Context, groupID string, userID string, handledMsg string, handleResult int32, member *model.GroupMember) error {
return g.ctxTx.Transaction(ctx, func(ctx context.Context) error {
if err := g.groupRequestDB.UpdateHandler(ctx, groupID, userID, handledMsg, handleResult); err != nil {
return err
}
if member != nil {
c := g.cache.CloneGroupCache()
if err := g.groupMemberDB.Create(ctx, []*model.GroupMember{member}); err != nil {
return err
}
c = c.DelGroupMembersHash(groupID).
DelGroupMembersInfo(groupID, member.UserID).
DelGroupMemberIDs(groupID).
DelGroupsMemberNum(groupID).
DelJoinedGroupID(member.UserID).
DelGroupRoleLevel(groupID, []int32{member.RoleLevel}).
DelMaxJoinGroupVersion(userID).
DelMaxGroupMemberVersion(groupID)
if err := c.ChainExecDel(ctx); err != nil {
return err
}
}
return nil
})
}
func (g *groupDatabase) DeleteGroupMember(ctx context.Context, groupID string, userIDs []string) error {
return g.ctxTx.Transaction(ctx, func(ctx context.Context) error {
if err := g.groupMemberDB.Delete(ctx, groupID, userIDs); err != nil {
return err
}
c := g.cache.CloneGroupCache()
return c.DelGroupMembersHash(groupID).
DelGroupMemberIDs(groupID).
DelGroupsMemberNum(groupID).
DelJoinedGroupID(userIDs...).
DelGroupMembersInfo(groupID, userIDs...).
DelGroupAllRoleLevel(groupID).
DelMaxGroupMemberVersion(groupID).
DelMaxJoinGroupVersion(userIDs...).
ChainExecDel(ctx)
})
}
func (g *groupDatabase) MapGroupMemberUserID(ctx context.Context, groupIDs []string) (map[string]*common.GroupSimpleUserID, error) {
return g.cache.GetGroupMemberHashMap(ctx, groupIDs)
}
func (g *groupDatabase) MapGroupMemberNum(ctx context.Context, groupIDs []string) (m map[string]uint32, err error) {
m = make(map[string]uint32)
for _, groupID := range groupIDs {
num, err := g.cache.GetGroupMemberNum(ctx, groupID)
if err != nil {
return nil, err
}
m[groupID] = uint32(num)
}
return m, nil
}
func (g *groupDatabase) TransferGroupOwner(ctx context.Context, groupID string, oldOwnerUserID, newOwnerUserID string, roleLevel int32) error {
return g.ctxTx.Transaction(ctx, func(ctx context.Context) error {
if err := g.groupMemberDB.UpdateUserRoleLevels(ctx, groupID, oldOwnerUserID, roleLevel, newOwnerUserID, constant.GroupOwner); err != nil {
return err
}
c := g.cache.CloneGroupCache()
return c.DelGroupMembersInfo(groupID, oldOwnerUserID, newOwnerUserID).
DelGroupAllRoleLevel(groupID).
DelGroupMembersHash(groupID).
DelMaxGroupMemberVersion(groupID).
DelGroupMemberIDs(groupID).
ChainExecDel(ctx)
})
}
func (g *groupDatabase) UpdateGroupMember(ctx context.Context, groupID string, userID string, data map[string]any) error {
if len(data) == 0 {
return nil
}
return g.ctxTx.Transaction(ctx, func(ctx context.Context) error {
if err := g.groupMemberDB.Update(ctx, groupID, userID, data); err != nil {
return err
}
c := g.cache.CloneGroupCache()
c = c.DelGroupMembersInfo(groupID, userID)
if g.groupMemberDB.IsUpdateRoleLevel(data) {
c = c.DelGroupAllRoleLevel(groupID).DelGroupMemberIDs(groupID)
}
c = c.DelMaxGroupMemberVersion(groupID)
return c.ChainExecDel(ctx)
})
}
func (g *groupDatabase) UpdateGroupMembers(ctx context.Context, data []*common.BatchUpdateGroupMember) error {
return g.ctxTx.Transaction(ctx, func(ctx context.Context) error {
c := g.cache.CloneGroupCache()
for _, item := range data {
if err := g.groupMemberDB.Update(ctx, item.GroupID, item.UserID, item.Map); err != nil {
return err
}
if g.groupMemberDB.IsUpdateRoleLevel(item.Map) {
c = c.DelGroupAllRoleLevel(item.GroupID).DelGroupMemberIDs(item.GroupID)
}
c = c.DelGroupMembersInfo(item.GroupID, item.UserID).DelMaxGroupMemberVersion(item.GroupID).DelGroupMembersHash(item.GroupID)
}
return c.ChainExecDel(ctx)
})
}
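// Usage sketch (illustrative only; ctx and db are assumed, and the map key is a hypothetical
// column name): batch-update two members in one transaction.
//
//	err := db.UpdateGroupMembers(ctx, []*common.BatchUpdateGroupMember{
//		{GroupID: "group123", UserID: "userA", Map: map[string]any{"nickname": "A"}},
//		{GroupID: "group123", UserID: "userB", Map: map[string]any{"nickname": "B"}},
//	})
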
func (g *groupDatabase) CreateGroupRequest(ctx context.Context, requests []*model.GroupRequest) error {
return g.ctxTx.Transaction(ctx, func(ctx context.Context) error {
for _, request := range requests {
if err := g.groupRequestDB.Delete(ctx, request.GroupID, request.UserID); err != nil {
return err
}
}
return g.groupRequestDB.Create(ctx, requests)
})
}
func (g *groupDatabase) TakeGroupRequest(ctx context.Context, groupID string, userID string) (*model.GroupRequest, error) {
return g.groupRequestDB.Take(ctx, groupID, userID)
}
func (g *groupDatabase) PageGroupRequestUser(ctx context.Context, userID string, groupIDs []string, handleResults []int, pagination pagination.Pagination) (int64, []*model.GroupRequest, error) {
return g.groupRequestDB.Page(ctx, userID, groupIDs, handleResults, pagination)
}
func (g *groupDatabase) CountTotal(ctx context.Context, before *time.Time) (count int64, err error) {
return g.groupDB.CountTotal(ctx, before)
}
func (g *groupDatabase) CountRangeEverydayTotal(ctx context.Context, start time.Time, end time.Time) (map[string]int64, error) {
return g.groupDB.CountRangeEverydayTotal(ctx, start, end)
}
func (g *groupDatabase) FindGroupRequests(ctx context.Context, groupID string, userIDs []string) ([]*model.GroupRequest, error) {
return g.groupRequestDB.FindGroupRequests(ctx, groupID, userIDs)
}
func (g *groupDatabase) DeleteGroupMemberHash(ctx context.Context, groupIDs []string) error {
if len(groupIDs) == 0 {
return nil
}
c := g.cache.CloneGroupCache()
for _, groupID := range groupIDs {
c = c.DelGroupMembersHash(groupID)
}
return c.ChainExecDel(ctx)
}
func (g *groupDatabase) FindMemberIncrVersion(ctx context.Context, groupID string, version uint, limit int) (*model.VersionLog, error) {
return g.groupMemberDB.FindMemberIncrVersion(ctx, groupID, version, limit)
}
func (g *groupDatabase) BatchFindMemberIncrVersion(ctx context.Context, groupIDs []string, versions []uint64, limits []int) (map[string]*model.VersionLog, error) {
if len(groupIDs) == 0 {
return nil, errs.Wrap(errs.New("groupIDs is nil."))
}
// convert []uint64 to []uint
var uintVersions []uint
for _, version := range versions {
uintVersions = append(uintVersions, uint(version))
}
versionLogs, err := g.groupMemberDB.BatchFindMemberIncrVersion(ctx, groupIDs, uintVersions, limits)
if err != nil {
return nil, errs.Wrap(err)
}
groupMemberIncrVersionsMap := datautil.SliceToMap(versionLogs, func(e *model.VersionLog) string {
return e.DID
})
return groupMemberIncrVersionsMap, nil
}
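// Usage sketch (illustrative only; ctx and db are assumed): fetch incremental member versions
// for several groups at once; the three slices are parallel (same index, same group), and the
// result is keyed by group ID (VersionLog.DID).
//
//	logs, err := db.BatchFindMemberIncrVersion(ctx,
//		[]string{"group123", "group456"},
//		[]uint64{10, 0},
//		[]int{100, 100})
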
func (g *groupDatabase) FindJoinIncrVersion(ctx context.Context, userID string, version uint, limit int) (*model.VersionLog, error) {
return g.groupMemberDB.FindJoinIncrVersion(ctx, userID, version, limit)
}
func (g *groupDatabase) FindMaxGroupMemberVersionCache(ctx context.Context, groupID string) (*model.VersionLog, error) {
return g.cache.FindMaxGroupMemberVersion(ctx, groupID)
}
func (g *groupDatabase) BatchFindMaxGroupMemberVersionCache(ctx context.Context, groupIDs []string) (map[string]*model.VersionLog, error) {
if len(groupIDs) == 0 {
return nil, errs.Wrap(errs.New("groupIDs is nil in Cache."))
}
versionLogs, err := g.cache.BatchFindMaxGroupMemberVersion(ctx, groupIDs)
if err != nil {
return nil, errs.Wrap(err)
}
maxGroupMemberVersionsMap := datautil.SliceToMap(versionLogs, func(e *model.VersionLog) string {
return e.DID
})
return maxGroupMemberVersionsMap, nil
}
func (g *groupDatabase) FindMaxJoinGroupVersionCache(ctx context.Context, userID string) (*model.VersionLog, error) {
return g.cache.FindMaxJoinGroupVersion(ctx, userID)
}
func (g *groupDatabase) SearchJoinGroup(ctx context.Context, userID string, keyword string, pagination pagination.Pagination) (int64, []*model.Group, error) {
groupIDs, err := g.cache.GetJoinedGroupIDs(ctx, userID)
if err != nil {
return 0, nil, err
}
return g.groupDB.SearchJoin(ctx, groupIDs, keyword, pagination)
}
func (g *groupDatabase) MemberGroupIncrVersion(ctx context.Context, groupID string, userIDs []string, state int32) error {
if err := g.groupMemberDB.MemberGroupIncrVersion(ctx, groupID, userIDs, state); err != nil {
return err
}
return g.cache.DelMaxGroupMemberVersion(groupID).ChainExecDel(ctx)
}
func (g *groupDatabase) GetGroupApplicationUnhandledCount(ctx context.Context, groupIDs []string, ts int64) (int64, error) {
return g.groupRequestDB.GetUnhandledCount(ctx, groupIDs, ts)
}

View File

@@ -0,0 +1,865 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package controller
import (
"context"
"encoding/json"
"errors"
"github.com/openimsdk/tools/mq"
"github.com/openimsdk/tools/utils/jsonutil"
"google.golang.org/protobuf/proto"
"strconv"
"strings"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/redis/go-redis/v9"
"go.mongodb.org/mongo-driver/mongo"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/convert"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/protocol/constant"
pbmsg "git.imall.cloud/openim/protocol/msg"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/datautil"
)
const (
updateKeyMsg = iota
updateKeyRevoke
)
// CommonMsgDatabase defines the interface for message database operations.
type CommonMsgDatabase interface {
// RevokeMsg revokes a message in a conversation.
RevokeMsg(ctx context.Context, conversationID string, seq int64, revoke *model.RevokeModel) error
// MarkSingleChatMsgsAsRead marks messages as read for a single chat by sequence numbers.
MarkSingleChatMsgsAsRead(ctx context.Context, userID string, conversationID string, seqs []int64) error
// GetMsgBySeqsRange retrieves messages from MongoDB by a range of sequence numbers.
GetMsgBySeqsRange(ctx context.Context, userID string, conversationID string, begin, end, num, userMaxSeq int64) (minSeq int64, maxSeq int64, seqMsg []*sdkws.MsgData, err error)
// GetMsgBySeqs retrieves messages for large groups from MongoDB by sequence numbers.
GetMsgBySeqs(ctx context.Context, userID string, conversationID string, seqs []int64) (minSeq int64, maxSeq int64, seqMsg []*sdkws.MsgData, err error)
GetMessagesBySeqWithBounds(ctx context.Context, userID string, conversationID string, seqs []int64, pullOrder sdkws.PullOrder) (bool, int64, []*sdkws.MsgData, error)
// DeleteUserMsgsBySeqs allows a user to delete messages based on sequence numbers.
DeleteUserMsgsBySeqs(ctx context.Context, userID string, conversationID string, seqs []int64) error
// DeleteMsgsPhysicalBySeqs physically deletes messages by emptying them based on sequence numbers.
DeleteMsgsPhysicalBySeqs(ctx context.Context, conversationID string, seqs []int64) error
//SetMaxSeq(ctx context.Context, conversationID string, maxSeq int64) error
GetMaxSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error)
GetMaxSeq(ctx context.Context, conversationID string) (int64, error)
SetMinSeqs(ctx context.Context, seqs map[string]int64) error
SetMinSeq(ctx context.Context, conversationID string, seq int64) error
SetUserConversationsMinSeqs(ctx context.Context, userID string, seqs map[string]int64) (err error)
SetHasReadSeq(ctx context.Context, userID string, conversationID string, hasReadSeq int64) error
GetHasReadSeqs(ctx context.Context, userID string, conversationIDs []string) (map[string]int64, error)
GetHasReadSeq(ctx context.Context, userID string, conversationID string) (int64, error)
UserSetHasReadSeqs(ctx context.Context, userID string, hasReadSeqs map[string]int64) error
GetMaxSeqsWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error)
GetMaxSeqWithTime(ctx context.Context, conversationID string) (database.SeqTime, error)
GetCacheMaxSeqWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error)
SetSendMsgStatus(ctx context.Context, id string, status int32) error
GetSendMsgStatus(ctx context.Context, id string) (int32, error)
SearchMessage(ctx context.Context, req *pbmsg.SearchMessageReq) (total int64, msgData []*pbmsg.SearchedMsgData, err error)
FindOneByDocIDs(ctx context.Context, docIDs []string, seqs map[string]int64) (map[string]*sdkws.MsgData, error)
// to mq
MsgToMQ(ctx context.Context, key string, msg2mq *sdkws.MsgData) error
RangeUserSendCount(ctx context.Context, start time.Time, end time.Time, group bool, ase bool, pageNumber int32, showNumber int32) (msgCount int64, userCount int64, users []*model.UserCount, dateCount map[string]int64, err error)
RangeGroupSendCount(ctx context.Context, start time.Time, end time.Time, ase bool, pageNumber int32, showNumber int32) (msgCount int64, userCount int64, groups []*model.GroupCount, dateCount map[string]int64, err error)
GetRandBeforeMsg(ctx context.Context, ts int64, limit int) ([]*model.MsgDocModel, error)
SetUserConversationsMaxSeq(ctx context.Context, conversationID string, userID string, seq int64) error
SetUserConversationsMinSeq(ctx context.Context, conversationID string, userID string, seq int64) error
DeleteDoc(ctx context.Context, docID string) error
GetLastMessageSeqByTime(ctx context.Context, conversationID string, time int64) (int64, error)
GetLastMessage(ctx context.Context, conversationIDS []string, userID string) (map[string]*sdkws.MsgData, error)
}
func NewCommonMsgDatabase(msgDocModel database.Msg, msg cache.MsgCache, seqUser cache.SeqUser, seqConversation cache.SeqConversationCache, producer mq.Producer) CommonMsgDatabase {
return &commonMsgDatabase{
msgDocDatabase: msgDocModel,
msgCache: msg,
seqUser: seqUser,
seqConversation: seqConversation,
producer: producer,
}
}
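// Illustrative wiring sketch (added for clarity, not part of the original file):
// constructing a CommonMsgDatabase and reading the newest message of a
// conversation, assuming concrete implementations of database.Msg,
// cache.MsgCache, cache.SeqUser, cache.SeqConversationCache and an mq.Producer
// already exist. Variable names are hypothetical.
//
// msgDatabase := NewCommonMsgDatabase(msgDocModel, msgCache, seqUser, seqConversation, producer)
// maxSeq, err := msgDatabase.GetMaxSeq(ctx, conversationID)
// if err != nil {
// return err
// }
// _, _, msgs, err := msgDatabase.GetMsgBySeqs(ctx, userID, conversationID, []int64{maxSeq})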
type commonMsgDatabase struct {
msgDocDatabase database.Msg
msgTable model.MsgDocModel
msgCache cache.MsgCache
seqConversation cache.SeqConversationCache
seqUser cache.SeqUser
producer mq.Producer
}
func (db *commonMsgDatabase) MsgToMQ(ctx context.Context, key string, msg2mq *sdkws.MsgData) error {
data, err := proto.Marshal(msg2mq)
if err != nil {
return err
}
return db.producer.SendMessage(ctx, key, data)
}
func (db *commonMsgDatabase) batchInsertBlock(ctx context.Context, conversationID string, fields []any, key int8, firstSeq int64) error {
if len(fields) == 0 {
return nil
}
num := db.msgTable.GetSingleGocMsgNum()
// num = 100
for i, field := range fields { // Check the type of the field
var ok bool
switch key {
case updateKeyMsg:
var msg *model.MsgDataModel
msg, ok = field.(*model.MsgDataModel)
if msg != nil && msg.Seq != firstSeq+int64(i) {
return errs.ErrInternalServer.WrapMsg("seq is invalid")
}
case updateKeyRevoke:
_, ok = field.(*model.RevokeModel)
default:
return errs.ErrInternalServer.WrapMsg("key is invalid")
}
if !ok {
return errs.ErrInternalServer.WrapMsg("field type is invalid")
}
}
// Returns true if the document exists in the database, false if the document does not exist in the database
updateMsgModel := func(seq int64, i int) (bool, error) {
var (
res *mongo.UpdateResult
err error
)
docID := db.msgTable.GetDocID(conversationID, seq)
index := db.msgTable.GetMsgIndex(seq)
field := fields[i]
switch key {
case updateKeyMsg:
res, err = db.msgDocDatabase.UpdateMsg(ctx, docID, index, "msg", field)
case updateKeyRevoke:
res, err = db.msgDocDatabase.UpdateMsg(ctx, docID, index, "revoke", field)
}
if err != nil {
return false, err
}
return res.MatchedCount > 0, nil
}
tryUpdate := true
for i := 0; i < len(fields); i++ {
seq := firstSeq + int64(i) // Current sequence number
if tryUpdate {
matched, err := updateMsgModel(seq, i)
if err != nil {
return err
}
if matched {
continue // The current data has been updated, skip the current data
}
}
doc := model.MsgDocModel{
DocID: db.msgTable.GetDocID(conversationID, seq),
Msg: make([]*model.MsgInfoModel, num),
}
var insert int // Inserted data number
for j := i; j < len(fields); j++ {
seq = firstSeq + int64(j)
if db.msgTable.GetDocID(conversationID, seq) != doc.DocID {
break
}
insert++
switch key {
case updateKeyMsg:
doc.Msg[db.msgTable.GetMsgIndex(seq)] = &model.MsgInfoModel{
Msg: fields[j].(*model.MsgDataModel),
}
case updateKeyRevoke:
doc.Msg[db.msgTable.GetMsgIndex(seq)] = &model.MsgInfoModel{
Revoke: fields[j].(*model.RevokeModel),
}
}
}
for i, msgInfo := range doc.Msg {
if msgInfo == nil {
msgInfo = &model.MsgInfoModel{}
doc.Msg[i] = msgInfo
}
if msgInfo.DelList == nil {
doc.Msg[i].DelList = []string{}
}
}
if err := db.msgDocDatabase.Create(ctx, &doc); err != nil {
if mongo.IsDuplicateKeyError(err) {
i-- // already inserted
tryUpdate = true // next block use update mode
continue
}
return err
}
tryUpdate = false // The current block was inserted successfully; for the next block, try insert first
i += insert - 1 // Skip the inserted data
}
return nil
}
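// Note on the document layout assumed above (added explanation): msgTable maps
// every seq to a (docID, index) pair so that one MongoDB document stores
// GetSingleGocMsgNum() consecutive messages per conversation. Assuming 100
// messages per document, seqs 1..100 would typically land in the first doc and
// 101..200 in the next; batchInsertBlock first tries an in-place update of the
// slot and only creates the whole document when it does not exist yet.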
func (db *commonMsgDatabase) RevokeMsg(ctx context.Context, conversationID string, seq int64, revoke *model.RevokeModel) error {
if err := db.batchInsertBlock(ctx, conversationID, []any{revoke}, updateKeyRevoke, seq); err != nil {
return err
}
return db.msgCache.DelMessageBySeqs(ctx, conversationID, []int64{seq})
}
func (db *commonMsgDatabase) MarkSingleChatMsgsAsRead(ctx context.Context, userID string, conversationID string, totalSeqs []int64) error {
for docID, seqs := range db.msgTable.GetDocIDSeqsMap(conversationID, totalSeqs) {
var indexes []int64
for _, seq := range seqs {
indexes = append(indexes, db.msgTable.GetMsgIndex(seq))
}
log.ZDebug(ctx, "MarkSingleChatMsgsAsRead", "userID", userID, "docID", docID, "indexes", indexes)
if err := db.msgDocDatabase.MarkSingleChatMsgsAsRead(ctx, userID, docID, indexes); err != nil {
log.ZError(ctx, "MarkSingleChatMsgsAsRead", err, "userID", userID, "docID", docID, "indexes", indexes)
return err
}
}
return db.msgCache.DelMessageBySeqs(ctx, conversationID, totalSeqs)
}
func (db *commonMsgDatabase) getMsgBySeqs(ctx context.Context, userID, conversationID string, seqs []int64) (totalMsgs []*sdkws.MsgData, err error) {
return db.GetMessageBySeqs(ctx, conversationID, userID, seqs)
}
func (db *commonMsgDatabase) handlerDBMsg(ctx context.Context, cache map[int64][]*model.MsgInfoModel, userID, conversationID string, msg *model.MsgInfoModel) {
if msg == nil || msg.Msg == nil {
return
}
if msg.IsRead {
msg.Msg.IsRead = true
}
if msg.Msg.ContentType != constant.Quote {
return
}
if msg.Msg.Content == "" {
return
}
type MsgData struct {
SendID string `json:"sendID"`
RecvID string `json:"recvID"`
GroupID string `json:"groupID"`
ClientMsgID string `json:"clientMsgID"`
ServerMsgID string `json:"serverMsgID"`
SenderPlatformID int32 `json:"senderPlatformID"`
SenderNickname string `json:"senderNickname"`
SenderFaceURL string `json:"senderFaceURL"`
SessionType int32 `json:"sessionType"`
MsgFrom int32 `json:"msgFrom"`
ContentType int32 `json:"contentType"`
Content string `json:"content"`
Seq int64 `json:"seq"`
SendTime int64 `json:"sendTime"`
CreateTime int64 `json:"createTime"`
Status int32 `json:"status"`
IsRead bool `json:"isRead"`
Options map[string]bool `json:"options,omitempty"`
OfflinePushInfo *sdkws.OfflinePushInfo `json:"offlinePushInfo"`
AtUserIDList []string `json:"atUserIDList"`
AttachedInfo string `json:"attachedInfo"`
Ex string `json:"ex"`
KeyVersion int32 `json:"keyVersion"`
DstUserIDs []string `json:"dstUserIDs"`
}
var quoteMsg struct {
Text string `json:"text,omitempty"`
QuoteMessage *MsgData `json:"quoteMessage,omitempty"`
MessageEntityList json.RawMessage `json:"messageEntityList,omitempty"`
}
if err := json.Unmarshal([]byte(msg.Msg.Content), &quoteMsg); err != nil {
log.ZError(ctx, "json.Unmarshal", err)
return
}
if quoteMsg.QuoteMessage == nil {
return
}
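// "e30=" is the base64 encoding of "{}", i.e. an empty JSON object placeholder.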
if quoteMsg.QuoteMessage.Content == "e30=" {
quoteMsg.QuoteMessage.Content = "{}"
data, err := json.Marshal(&quoteMsg)
if err != nil {
return
}
msg.Msg.Content = string(data)
}
if quoteMsg.QuoteMessage.Seq <= 0 && quoteMsg.QuoteMessage.ContentType == constant.MsgRevokeNotification {
return
}
var msgs []*model.MsgInfoModel
if v, ok := cache[quoteMsg.QuoteMessage.Seq]; ok {
msgs = v
} else {
if quoteMsg.QuoteMessage.Seq > 0 {
ms, err := db.msgDocDatabase.GetMsgBySeqIndexIn1Doc(ctx, userID, db.msgTable.GetDocID(conversationID, quoteMsg.QuoteMessage.Seq), []int64{quoteMsg.QuoteMessage.Seq})
if err != nil {
log.ZError(ctx, "GetMsgBySeqIndexIn1Doc", err, "conversationID", conversationID, "seq", quoteMsg.QuoteMessage.Seq)
return
}
msgs = ms
cache[quoteMsg.QuoteMessage.Seq] = ms
}
}
if len(msgs) != 0 && msgs[0].Msg.ContentType != constant.MsgRevokeNotification {
return
}
quoteMsg.QuoteMessage.ContentType = constant.MsgRevokeNotification
if len(msgs) > 0 {
quoteMsg.QuoteMessage.Content = msgs[0].Msg.Content
} else {
quoteMsg.QuoteMessage.Content = "{}"
}
data, err := json.Marshal(&quoteMsg)
if err != nil {
log.ZError(ctx, "json.Marshal", err)
return
}
msg.Msg.Content = string(data)
}
func (db *commonMsgDatabase) findMsgInfoBySeq(ctx context.Context, userID, docID string, conversationID string, seqs []int64) (totalMsgs []*model.MsgInfoModel, err error) {
msgs, err := db.msgDocDatabase.GetMsgBySeqIndexIn1Doc(ctx, userID, docID, seqs)
if err != nil {
return nil, err
}
tempCache := make(map[int64][]*model.MsgInfoModel)
for _, msg := range msgs {
db.handlerDBMsg(ctx, tempCache, userID, conversationID, msg)
}
return msgs, err
}
// GetMsgBySeqsRange In the context of group chat, we have the following parameters:
//
// "maxSeq" of a conversation: It represents the maximum value of messages in the group conversation.
// "minSeq" of a conversation (default: 1): It represents the minimum value of messages in the group conversation.
//
// For a user's perspective regarding the group conversation, we have the following parameters:
//
// "userMaxSeq": It represents the user's upper limit for message retrieval in the group. If not set (default: 0),
// it means the upper limit is the same as the conversation's "maxSeq".
// "userMinSeq": It represents the user's starting point for message retrieval in the group. If not set (default: 0),
// it means the starting point is the same as the conversation's "minSeq".
//
// The scenarios for these parameters are as follows:
//
// For users who have been kicked out of the group, "userMaxSeq" can be set as the maximum value they had before
// being kicked out. This limits their ability to retrieve messages up to a certain point.
// For new users joining the group, if they don't need to receive old messages,
// "userMinSeq" can be set as the same value as the conversation's "maxSeq" at the moment they join the group.
// This ensures that their message retrieval starts from the point they joined.
func (db *commonMsgDatabase) GetMsgBySeqsRange(ctx context.Context, userID string, conversationID string, begin, end, num, userMaxSeq int64) (int64, int64, []*sdkws.MsgData, error) {
userMinSeq, err := db.seqUser.GetUserMinSeq(ctx, conversationID, userID)
if err != nil && !errors.Is(err, redis.Nil) {
return 0, 0, nil, err
}
minSeq, err := db.seqConversation.GetMinSeq(ctx, conversationID)
if err != nil {
return 0, 0, nil, err
}
if userMinSeq > minSeq {
minSeq = userMinSeq
}
// "minSeq" represents the startSeq value that the user can retrieve.
if minSeq > end {
log.ZWarn(ctx, "minSeq > end", errs.New("minSeq>end"), "minSeq", minSeq, "end", end)
return 0, 0, nil, nil
}
maxSeq, err := db.seqConversation.GetMaxSeq(ctx, conversationID)
if err != nil {
return 0, 0, nil, err
}
log.ZDebug(ctx, "GetMsgBySeqsRange", "userMinSeq", userMinSeq, "conMinSeq", minSeq, "conMaxSeq", maxSeq, "userMaxSeq", userMaxSeq)
if userMaxSeq != 0 {
if userMaxSeq < maxSeq {
maxSeq = userMaxSeq
}
}
// "maxSeq" represents the endSeq value that the user can retrieve.
if begin < minSeq {
begin = minSeq
}
if end > maxSeq {
end = maxSeq
}
// "begin" and "end" represent the actual startSeq and endSeq values that the user can retrieve.
if end < begin {
log.ZWarn(ctx, "seq end < begin after adjustment", errs.New("seq end < begin"), "begin", begin, "end", end, "minSeq", minSeq, "maxSeq", maxSeq)
return 0, 0, nil, nil
}
// Limit the maximum number of messages fetched to prevent memory exhaustion (default cap: 5000).
// If a single query range is too large, narrow it further to avoid memory pressure.
const maxFetchLimit = 5000
const maxRangeSize = 10000 // maximum allowed seq range
if num <= 0 {
num = 100 // default value
}
if num > maxFetchLimit {
num = maxFetchLimit
}
var seqs []int64
rangeSize := end - begin + 1
// If the range is too large, cap its size
if rangeSize > maxRangeSize {
log.ZWarn(ctx, "seq range too large, limiting", nil, "conversationID", conversationID, "begin", begin, "end", end, "rangeSize", rangeSize, "maxRangeSize", maxRangeSize)
// Keep only the last maxRangeSize seqs
begin = end - maxRangeSize + 1
rangeSize = maxRangeSize
}
if rangeSize <= num {
// If the range is no larger than num, generate every seq in it
for i := begin; i <= end; i++ {
seqs = append(seqs, i)
}
} else {
// If the range is larger than num, take only the last num seqs
for i := end - num + 1; i <= end; i++ {
seqs = append(seqs, i)
}
}
successMsgs, err := db.GetMessageBySeqs(ctx, conversationID, userID, seqs)
if err != nil {
return 0, 0, nil, err
}
return minSeq, maxSeq, successMsgs, nil
}
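// Worked example (hypothetical numbers, added for illustration): with a
// conversation minSeq of 1 and maxSeq of 500, a user whose userMaxSeq was
// frozen at 300 (for example after being kicked from the group) who calls
// GetMsgBySeqsRange with begin=250, end=400, num=100 has end clamped to 300,
// so only seqs 250..300 are fetched and the returned maxSeq is 300.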
func (db *commonMsgDatabase) GetMsgBySeqs(ctx context.Context, userID string, conversationID string, seqs []int64) (int64, int64, []*sdkws.MsgData, error) {
userMinSeq, err := db.seqUser.GetUserMinSeq(ctx, conversationID, userID)
if err != nil {
return 0, 0, nil, err
}
minSeq, err := db.seqConversation.GetMinSeq(ctx, conversationID)
if err != nil {
return 0, 0, nil, err
}
maxSeq, err := db.seqConversation.GetMaxSeq(ctx, conversationID)
if err != nil {
return 0, 0, nil, err
}
userMaxSeq, err := db.seqUser.GetUserMaxSeq(ctx, conversationID, userID)
if err != nil {
return 0, 0, nil, err
}
if userMinSeq > minSeq {
minSeq = userMinSeq
}
if userMaxSeq > 0 && userMaxSeq < maxSeq {
maxSeq = userMaxSeq
}
newSeqs := make([]int64, 0, len(seqs))
for _, seq := range seqs {
if seq <= 0 {
continue
}
if seq >= minSeq && seq <= maxSeq {
newSeqs = append(newSeqs, seq)
}
}
successMsgs, err := db.GetMessageBySeqs(ctx, conversationID, userID, newSeqs)
if err != nil {
return 0, 0, nil, err
}
return minSeq, maxSeq, successMsgs, nil
}
func (db *commonMsgDatabase) GetMessagesBySeqWithBounds(ctx context.Context, userID string, conversationID string, seqs []int64, pullOrder sdkws.PullOrder) (bool, int64, []*sdkws.MsgData, error) {
var endSeq int64
var isEnd bool
userMinSeq, err := db.seqUser.GetUserMinSeq(ctx, conversationID, userID)
if err != nil {
return false, 0, nil, err
}
minSeq, err := db.seqConversation.GetMinSeq(ctx, conversationID)
if err != nil {
return false, 0, nil, err
}
maxSeq, err := db.seqConversation.GetMaxSeq(ctx, conversationID)
if err != nil {
return false, 0, nil, err
}
userMaxSeq, err := db.seqUser.GetUserMaxSeq(ctx, conversationID, userID)
if err != nil {
return false, 0, nil, err
}
if userMinSeq > minSeq {
minSeq = userMinSeq
}
if userMaxSeq > 0 && userMaxSeq < maxSeq {
maxSeq = userMaxSeq
}
newSeqs := make([]int64, 0, len(seqs))
for _, seq := range seqs {
if seq <= 0 {
continue
}
// The normal range and can fetch messages
if seq >= minSeq && seq <= maxSeq {
newSeqs = append(newSeqs, seq)
continue
}
// If the requested seq is smaller than the minimum seq and the pull order is descending (pulling older messages)
if seq < minSeq && pullOrder == sdkws.PullOrder_PullOrderDesc {
isEnd = true
endSeq = minSeq
}
// If the requested seq is larger than the maximum seq and the pull order is ascending (pulling newer messages)
if seq > maxSeq && pullOrder == sdkws.PullOrder_PullOrderAsc {
isEnd = true
endSeq = maxSeq
}
}
if len(newSeqs) == 0 {
return isEnd, endSeq, nil, nil
}
successMsgs, err := db.GetMessageBySeqs(ctx, conversationID, userID, newSeqs)
if err != nil {
return false, 0, nil, err
}
return isEnd, endSeq, successMsgs, nil
}
func (db *commonMsgDatabase) DeleteMsgsPhysicalBySeqs(ctx context.Context, conversationID string, allSeqs []int64) error {
for docID, seqs := range db.msgTable.GetDocIDSeqsMap(conversationID, allSeqs) {
var indexes []int
for _, seq := range seqs {
indexes = append(indexes, int(db.msgTable.GetMsgIndex(seq)))
}
if err := db.msgDocDatabase.DeleteMsgsInOneDocByIndex(ctx, docID, indexes); err != nil {
return err
}
}
return db.msgCache.DelMessageBySeqs(ctx, conversationID, allSeqs)
}
func (db *commonMsgDatabase) DeleteUserMsgsBySeqs(ctx context.Context, userID string, conversationID string, seqs []int64) error {
for docID, seqs := range db.msgTable.GetDocIDSeqsMap(conversationID, seqs) {
for _, seq := range seqs {
if _, err := db.msgDocDatabase.PushUnique(ctx, docID, db.msgTable.GetMsgIndex(seq), "del_list", []string{userID}); err != nil {
return err
}
}
}
return db.msgCache.DelMessageBySeqs(ctx, conversationID, seqs)
}
func (db *commonMsgDatabase) GetMaxSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error) {
return db.seqConversation.GetMaxSeqs(ctx, conversationIDs)
}
func (db *commonMsgDatabase) GetMaxSeq(ctx context.Context, conversationID string) (int64, error) {
return db.seqConversation.GetMaxSeq(ctx, conversationID)
}
func (db *commonMsgDatabase) SetMinSeqs(ctx context.Context, seqs map[string]int64) error {
return db.seqConversation.SetMinSeqs(ctx, seqs)
}
func (db *commonMsgDatabase) SetUserConversationsMinSeqs(ctx context.Context, userID string, seqs map[string]int64) error {
return db.seqUser.SetUserMinSeqs(ctx, userID, seqs)
}
func (db *commonMsgDatabase) SetUserConversationsMaxSeq(ctx context.Context, conversationID string, userID string, seq int64) error {
return db.seqUser.SetUserMaxSeq(ctx, conversationID, userID, seq)
}
func (db *commonMsgDatabase) SetUserConversationsMinSeq(ctx context.Context, conversationID string, userID string, seq int64) error {
return db.seqUser.SetUserMinSeq(ctx, conversationID, userID, seq)
}
func (db *commonMsgDatabase) UserSetHasReadSeqs(ctx context.Context, userID string, hasReadSeqs map[string]int64) error {
return db.seqUser.SetUserReadSeqs(ctx, userID, hasReadSeqs)
}
func (db *commonMsgDatabase) SetHasReadSeq(ctx context.Context, userID string, conversationID string, hasReadSeq int64) error {
return db.seqUser.SetUserReadSeq(ctx, conversationID, userID, hasReadSeq)
}
func (db *commonMsgDatabase) GetHasReadSeqs(ctx context.Context, userID string, conversationIDs []string) (map[string]int64, error) {
return db.seqUser.GetUserReadSeqs(ctx, userID, conversationIDs)
}
func (db *commonMsgDatabase) GetHasReadSeq(ctx context.Context, userID string, conversationID string) (int64, error) {
return db.seqUser.GetUserReadSeq(ctx, conversationID, userID)
}
func (db *commonMsgDatabase) SetSendMsgStatus(ctx context.Context, id string, status int32) error {
return db.msgCache.SetSendMsgStatus(ctx, id, status)
}
func (db *commonMsgDatabase) GetSendMsgStatus(ctx context.Context, id string) (int32, error) {
return db.msgCache.GetSendMsgStatus(ctx, id)
}
func (db *commonMsgDatabase) GetConversationMinMaxSeqInMongoAndCache(ctx context.Context, conversationID string) (minSeqMongo, maxSeqMongo, minSeqCache, maxSeqCache int64, err error) {
minSeqMongo, maxSeqMongo, err = db.GetMinMaxSeqMongo(ctx, conversationID)
if err != nil {
return
}
minSeqCache, err = db.seqConversation.GetMinSeq(ctx, conversationID)
if err != nil {
return
}
maxSeqCache, err = db.seqConversation.GetMaxSeq(ctx, conversationID)
if err != nil {
return
}
return
}
func (db *commonMsgDatabase) GetMongoMaxAndMinSeq(ctx context.Context, conversationID string) (minSeqMongo, maxSeqMongo int64, err error) {
return db.GetMinMaxSeqMongo(ctx, conversationID)
}
func (db *commonMsgDatabase) GetMinMaxSeqMongo(ctx context.Context, conversationID string) (minSeqMongo, maxSeqMongo int64, err error) {
oldestMsgMongo, err := db.msgDocDatabase.GetOldestMsg(ctx, conversationID)
if err != nil {
return
}
minSeqMongo = oldestMsgMongo.Msg.Seq
newestMsgMongo, err := db.msgDocDatabase.GetNewestMsg(ctx, conversationID)
if err != nil {
return
}
maxSeqMongo = newestMsgMongo.Msg.Seq
return
}
func (db *commonMsgDatabase) RangeUserSendCount(ctx context.Context, start time.Time, end time.Time, group bool, ase bool, pageNumber int32, showNumber int32) (msgCount int64, userCount int64, users []*model.UserCount, dateCount map[string]int64, err error) {
return db.msgDocDatabase.RangeUserSendCount(ctx, start, end, group, ase, pageNumber, showNumber)
}
func (db *commonMsgDatabase) RangeGroupSendCount(ctx context.Context, start time.Time, end time.Time, ase bool, pageNumber int32, showNumber int32) (msgCount int64, userCount int64, groups []*model.GroupCount, dateCount map[string]int64, err error) {
return db.msgDocDatabase.RangeGroupSendCount(ctx, start, end, ase, pageNumber, showNumber)
}
func (db *commonMsgDatabase) SearchMessage(ctx context.Context, req *pbmsg.SearchMessageReq) (total int64, msgData []*pbmsg.SearchedMsgData, err error) {
var totalMsgs []*pbmsg.SearchedMsgData
total, msgs, err := db.msgDocDatabase.SearchMessage(ctx, req)
if err != nil {
return 0, nil, err
}
for _, msg := range msgs {
if msg.IsRead {
msg.Msg.IsRead = true
}
searchedMsgData := &pbmsg.SearchedMsgData{MsgData: convert.MsgDB2Pb(msg.Msg)}
if msg.Revoke != nil {
searchedMsgData.IsRevoked = true
}
totalMsgs = append(totalMsgs, searchedMsgData)
}
return total, totalMsgs, nil
}
func (db *commonMsgDatabase) FindOneByDocIDs(ctx context.Context, conversationIDs []string, seqs map[string]int64) (map[string]*sdkws.MsgData, error) {
totalMsgs := make(map[string]*sdkws.MsgData)
for _, conversationID := range conversationIDs {
seq, ok := seqs[conversationID]
if !ok {
log.ZWarn(ctx, "seq not found for conversationID", errs.New("seq not found for conversation"), "conversationID", conversationID)
continue
}
docID := db.msgTable.GetDocID(conversationID, seq)
msgs, err := db.msgDocDatabase.FindOneByDocID(ctx, docID)
if err != nil {
log.ZWarn(ctx, "FindOneByDocID failed", err, "conversationID", conversationID, "docID", docID, "seq", seq)
continue
}
index := db.msgTable.GetMsgIndex(seq)
totalMsgs[conversationID] = convert.MsgDB2Pb(msgs.Msg[index].Msg)
}
return totalMsgs, nil
}
func (db *commonMsgDatabase) GetRandBeforeMsg(ctx context.Context, ts int64, limit int) ([]*model.MsgDocModel, error) {
return db.msgDocDatabase.GetRandBeforeMsg(ctx, ts, limit)
}
func (db *commonMsgDatabase) SetMinSeq(ctx context.Context, conversationID string, seq int64) error {
dbSeq, err := db.seqConversation.GetMinSeq(ctx, conversationID)
if err != nil {
if errors.Is(errs.Unwrap(err), redis.Nil) {
return nil
}
return err
}
if dbSeq >= seq {
return nil
}
return db.seqConversation.SetMinSeq(ctx, conversationID, seq)
}
func (db *commonMsgDatabase) GetCacheMaxSeqWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error) {
return db.seqConversation.GetCacheMaxSeqWithTime(ctx, conversationIDs)
}
func (db *commonMsgDatabase) GetMaxSeqWithTime(ctx context.Context, conversationID string) (database.SeqTime, error) {
return db.seqConversation.GetMaxSeqWithTime(ctx, conversationID)
}
func (db *commonMsgDatabase) GetMaxSeqsWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error) {
// todo: only the time in the redis cache will be taken, not the message time
return db.seqConversation.GetMaxSeqsWithTime(ctx, conversationIDs)
}
func (db *commonMsgDatabase) DeleteDoc(ctx context.Context, docID string) error {
index := strings.LastIndex(docID, ":")
if index <= 0 {
return errs.ErrInternalServer.WrapMsg("docID is invalid", "docID", docID)
}
docIndex, err := strconv.Atoi(docID[index+1:])
if err != nil {
return errs.WrapMsg(err, "strconv.Atoi", "docID", docID)
}
conversationID := docID[:index]
seqs := make([]int64, db.msgTable.GetSingleGocMsgNum())
minSeq := db.msgTable.GetMinSeq(docIndex)
for i := range seqs {
seqs[i] = minSeq + int64(i)
}
if err := db.msgDocDatabase.DeleteDoc(ctx, docID); err != nil {
return err
}
return db.msgCache.DelMessageBySeqs(ctx, conversationID, seqs)
}
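// Example of the docID convention assumed here (illustrative only): a docID
// such as "sg_group123:3" splits at the last ':' into the conversationID
// "sg_group123" and the doc index 3; every seq covered by that document is
// then cleared from the message cache as well.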
func (db *commonMsgDatabase) GetLastMessageSeqByTime(ctx context.Context, conversationID string, time int64) (int64, error) {
return db.msgDocDatabase.GetLastMessageSeqByTime(ctx, conversationID, time)
}
func (db *commonMsgDatabase) handlerDeleteAndRevoked(ctx context.Context, userID string, msgs []*model.MsgInfoModel) {
for i := range msgs {
msg := msgs[i]
if msg == nil || msg.Msg == nil {
continue
}
msg.Msg.IsRead = msg.IsRead
if datautil.Contain(userID, msg.DelList...) {
msg.Msg.Content = ""
msg.Msg.Status = constant.MsgDeleted
}
if msg.Revoke == nil {
continue
}
msg.Msg.ContentType = constant.MsgRevokeNotification
revokeContent := sdkws.MessageRevokedContent{
RevokerID: msg.Revoke.UserID,
RevokerRole: msg.Revoke.Role,
ClientMsgID: msg.Msg.ClientMsgID,
RevokerNickname: msg.Revoke.Nickname,
RevokeTime: msg.Revoke.Time,
SourceMessageSendTime: msg.Msg.SendTime,
SourceMessageSendID: msg.Msg.SendID,
SourceMessageSenderNickname: msg.Msg.SenderNickname,
SessionType: msg.Msg.SessionType,
Seq: msg.Msg.Seq,
Ex: msg.Msg.Ex,
}
data, err := jsonutil.JsonMarshal(&revokeContent)
if err != nil {
log.ZWarn(ctx, "handlerDeleteAndRevoked JsonMarshal MessageRevokedContent", err, "msg", msg)
continue
}
elem := sdkws.NotificationElem{
Detail: string(data),
}
content, err := jsonutil.JsonMarshal(&elem)
if err != nil {
log.ZWarn(ctx, "handlerDeleteAndRevoked JsonMarshal NotificationElem", err, "msg", msg)
continue
}
msg.Msg.Content = string(content)
}
}
func (db *commonMsgDatabase) handlerQuote(ctx context.Context, userID, conversationID string, msgs []*model.MsgInfoModel) {
temp := make(map[int64][]*model.MsgInfoModel)
for i := range msgs {
db.handlerDBMsg(ctx, temp, userID, conversationID, msgs[i])
}
}
func (db *commonMsgDatabase) GetMessageBySeqs(ctx context.Context, conversationID string, userID string, seqs []int64) ([]*sdkws.MsgData, error) {
msgs, err := db.msgCache.GetMessageBySeqs(ctx, conversationID, seqs)
if err != nil {
return nil, err
}
db.handlerDeleteAndRevoked(ctx, userID, msgs)
db.handlerQuote(ctx, userID, conversationID, msgs)
seqMsgs := make(map[int64]*model.MsgInfoModel)
for i, msg := range msgs {
if msg.Msg == nil {
continue
}
seqMsgs[msg.Msg.Seq] = msgs[i]
}
res := make([]*sdkws.MsgData, 0, len(seqs))
for _, seq := range seqs {
if v, ok := seqMsgs[seq]; ok {
res = append(res, convert.MsgDB2Pb(v.Msg))
} else {
res = append(res, &sdkws.MsgData{Seq: seq, Status: constant.MsgStatusHasDeleted})
}
}
return res, nil
}
func (db *commonMsgDatabase) GetLastMessage(ctx context.Context, conversationIDs []string, userID string) (map[string]*sdkws.MsgData, error) {
res := make(map[string]*sdkws.MsgData)
for _, conversationID := range conversationIDs {
if _, ok := res[conversationID]; ok {
continue
}
msg, err := db.msgDocDatabase.GetLastMessage(ctx, conversationID)
if err != nil {
if errs.Unwrap(err) == mongo.ErrNoDocuments {
continue
}
return nil, err
}
tmp := []*model.MsgInfoModel{msg}
db.handlerDeleteAndRevoked(ctx, userID, tmp)
db.handlerQuote(ctx, userID, conversationID, tmp)
res[conversationID] = convert.MsgDB2Pb(msg.Msg)
}
return res, nil
}

View File

@@ -0,0 +1,277 @@
package controller
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/convert"
"git.imall.cloud/openim/protocol/constant"
"github.com/openimsdk/tools/mq"
"github.com/openimsdk/tools/utils/datautil"
"google.golang.org/protobuf/proto"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
pbmsg "git.imall.cloud/openim/protocol/msg"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"go.mongodb.org/mongo-driver/mongo"
)
type MsgTransferDatabase interface {
// BatchInsertChat2DB inserts a batch of messages into the database for a specific conversation.
BatchInsertChat2DB(ctx context.Context, conversationID string, msgs []*sdkws.MsgData, currentMaxSeq int64) error
// DeleteMessagesFromCache deletes message caches from Redis by sequence numbers.
DeleteMessagesFromCache(ctx context.Context, conversationID string, seqs []int64) error
// BatchInsertChat2Cache increments the sequence number and then batch inserts messages into the cache.
BatchInsertChat2Cache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (seq int64, isNewConversation bool, userHasReadMap map[string]int64, err error)
SetHasReadSeqs(ctx context.Context, conversationID string, userSeqMap map[string]int64) error
SetHasReadSeqToDB(ctx context.Context, conversationID string, userSeqMap map[string]int64) error
// to mq
MsgToPushMQ(ctx context.Context, key, conversationID string, msg2mq *sdkws.MsgData) error
MsgToMongoMQ(ctx context.Context, key, conversationID string, msgs []*sdkws.MsgData, lastSeq int64) error
}
func NewMsgTransferDatabase(msgDocModel database.Msg, msg cache.MsgCache, seqUser cache.SeqUser, seqConversation cache.SeqConversationCache, mongoProducer, pushProducer mq.Producer) (MsgTransferDatabase, error) {
//conf, err := kafka.BuildProducerConfig(*kafkaConf.Build())
//if err != nil {
// return nil, err
//}
//producerToMongo, err := kafka.NewKafkaProducerV2(conf, kafkaConf.Address, kafkaConf.ToMongoTopic)
//if err != nil {
// return nil, err
//}
//producerToPush, err := kafka.NewKafkaProducerV2(conf, kafkaConf.Address, kafkaConf.ToPushTopic)
//if err != nil {
// return nil, err
//}
return &msgTransferDatabase{
msgDocDatabase: msgDocModel,
msgCache: msg,
seqUser: seqUser,
seqConversation: seqConversation,
producerToMongo: mongoProducer,
producerToPush: pushProducer,
}, nil
}
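// Illustrative transfer-side flow (added sketch, assuming the usual OpenIM
// pipeline; variable names are hypothetical): seqs are assigned and the batch
// is written to the cache first, then fanned out to the push and mongo topics.
//
// lastSeq, isNew, readSeqs, err := transferDB.BatchInsertChat2Cache(ctx, conversationID, msgs)
// if err != nil {
// return err
// }
// _ = isNew
// if err := transferDB.SetHasReadSeqs(ctx, conversationID, readSeqs); err != nil {
// return err
// }
// if err := transferDB.MsgToMongoMQ(ctx, key, conversationID, msgs, lastSeq); err != nil {
// return err
// }
// for _, m := range msgs {
// _ = transferDB.MsgToPushMQ(ctx, key, conversationID, m)
// }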
type msgTransferDatabase struct {
msgDocDatabase database.Msg
msgTable model.MsgDocModel
msgCache cache.MsgCache
seqConversation cache.SeqConversationCache
seqUser cache.SeqUser
producerToMongo mq.Producer
producerToPush mq.Producer
}
func (db *msgTransferDatabase) BatchInsertChat2DB(ctx context.Context, conversationID string, msgList []*sdkws.MsgData, currentMaxSeq int64) error {
if len(msgList) == 0 {
return errs.ErrArgs.WrapMsg("msgList is empty")
}
msgs := make([]any, len(msgList))
seqs := make([]int64, len(msgList))
for i, msg := range msgList {
if msg == nil {
continue
}
seqs[i] = msg.Seq
if msg.Status == constant.MsgStatusSending {
msg.Status = constant.MsgStatusSendSuccess
}
msgs[i] = convert.MsgPb2DB(msg)
}
if err := db.BatchInsertBlock(ctx, conversationID, msgs, updateKeyMsg, msgList[0].Seq); err != nil {
return err
}
//return db.msgCache.DelMessageBySeqs(ctx, conversationID, seqs)
return nil
}
func (db *msgTransferDatabase) BatchInsertBlock(ctx context.Context, conversationID string, fields []any, key int8, firstSeq int64) error {
if len(fields) == 0 {
return nil
}
num := db.msgTable.GetSingleGocMsgNum()
// num = 100
for i, field := range fields { // Check the type of the field
var ok bool
switch key {
case updateKeyMsg:
var msg *model.MsgDataModel
msg, ok = field.(*model.MsgDataModel)
if msg != nil && msg.Seq != firstSeq+int64(i) {
return errs.ErrInternalServer.WrapMsg("seq is invalid")
}
case updateKeyRevoke:
_, ok = field.(*model.RevokeModel)
default:
return errs.ErrInternalServer.WrapMsg("key is invalid")
}
if !ok {
return errs.ErrInternalServer.WrapMsg("field type is invalid")
}
}
// Returns true if the document exists in the database, false if the document does not exist in the database
updateMsgModel := func(seq int64, i int) (bool, error) {
var (
res *mongo.UpdateResult
err error
)
docID := db.msgTable.GetDocID(conversationID, seq)
index := db.msgTable.GetMsgIndex(seq)
field := fields[i]
switch key {
case updateKeyMsg:
res, err = db.msgDocDatabase.UpdateMsg(ctx, docID, index, "msg", field)
case updateKeyRevoke:
res, err = db.msgDocDatabase.UpdateMsg(ctx, docID, index, "revoke", field)
}
if err != nil {
return false, err
}
return res.MatchedCount > 0, nil
}
tryUpdate := true
for i := 0; i < len(fields); i++ {
seq := firstSeq + int64(i) // Current sequence number
if tryUpdate {
matched, err := updateMsgModel(seq, i)
if err != nil {
return err
}
if matched {
continue // The current data has been updated, skip the current data
}
}
doc := model.MsgDocModel{
DocID: db.msgTable.GetDocID(conversationID, seq),
Msg: make([]*model.MsgInfoModel, num),
}
var insert int // Inserted data number
for j := i; j < len(fields); j++ {
seq = firstSeq + int64(j)
if db.msgTable.GetDocID(conversationID, seq) != doc.DocID {
break
}
insert++
switch key {
case updateKeyMsg:
doc.Msg[db.msgTable.GetMsgIndex(seq)] = &model.MsgInfoModel{
Msg: fields[j].(*model.MsgDataModel),
}
case updateKeyRevoke:
doc.Msg[db.msgTable.GetMsgIndex(seq)] = &model.MsgInfoModel{
Revoke: fields[j].(*model.RevokeModel),
}
}
}
for i, msgInfo := range doc.Msg {
if msgInfo == nil {
msgInfo = &model.MsgInfoModel{}
doc.Msg[i] = msgInfo
}
if msgInfo.DelList == nil {
doc.Msg[i].DelList = []string{}
}
}
if err := db.msgDocDatabase.Create(ctx, &doc); err != nil {
if mongo.IsDuplicateKeyError(err) {
i-- // already inserted
tryUpdate = true // next block use update mode
continue
}
return err
}
tryUpdate = false // The current block was inserted successfully; for the next block, try insert first
i += insert - 1 // Skip the inserted data
}
return nil
}
func (db *msgTransferDatabase) DeleteMessagesFromCache(ctx context.Context, conversationID string, seqs []int64) error {
return db.msgCache.DelMessageBySeqs(ctx, conversationID, seqs)
}
func (db *msgTransferDatabase) BatchInsertChat2Cache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (seq int64, isNew bool, userHasReadMap map[string]int64, err error) {
lenList := len(msgs)
if int64(lenList) > db.msgTable.GetSingleGocMsgNum() {
return 0, false, nil, errs.New("message count exceeds limit", "limit", db.msgTable.GetSingleGocMsgNum()).Wrap()
}
if lenList < 1 {
return 0, false, nil, errs.New("no messages to insert", "minCount", 1).Wrap()
}
currentMaxSeq, err := db.seqConversation.Malloc(ctx, conversationID, int64(len(msgs)))
if err != nil {
log.ZError(ctx, "storage.seq.Malloc", err)
return 0, false, nil, err
}
isNew = currentMaxSeq == 0
lastMaxSeq := currentMaxSeq
userSeqMap := make(map[string]int64)
seqs := make([]int64, 0, lenList)
for _, m := range msgs {
currentMaxSeq++
m.Seq = currentMaxSeq
userSeqMap[m.SendID] = m.Seq
seqs = append(seqs, m.Seq)
}
msgToDB := func(msg *sdkws.MsgData) *model.MsgInfoModel {
return &model.MsgInfoModel{
Msg: convert.MsgPb2DB(msg),
}
}
if err := db.msgCache.SetMessageBySeqs(ctx, conversationID, datautil.Slice(msgs, msgToDB)); err != nil {
return 0, false, nil, err
}
return lastMaxSeq, isNew, userSeqMap, nil
}
func (db *msgTransferDatabase) SetHasReadSeqs(ctx context.Context, conversationID string, userSeqMap map[string]int64) error {
for userID, seq := range userSeqMap {
if err := db.seqUser.SetUserReadSeq(ctx, conversationID, userID, seq); err != nil {
return err
}
}
return nil
}
func (db *msgTransferDatabase) SetHasReadSeqToDB(ctx context.Context, conversationID string, userSeqMap map[string]int64) error {
for userID, seq := range userSeqMap {
if err := db.seqUser.SetUserReadSeqToDB(ctx, conversationID, userID, seq); err != nil {
return err
}
}
return nil
}
func (db *msgTransferDatabase) MsgToPushMQ(ctx context.Context, key, conversationID string, msg2mq *sdkws.MsgData) error {
data, err := proto.Marshal(&pbmsg.PushMsgDataToMQ{MsgData: msg2mq, ConversationID: conversationID})
if err != nil {
return err
}
if err := db.producerToPush.SendMessage(ctx, key, data); err != nil {
log.ZError(ctx, "MsgToPushMQ", err, "key", key, "conversationID", conversationID)
return err
}
return nil
}
func (db *msgTransferDatabase) MsgToMongoMQ(ctx context.Context, key, conversationID string, messages []*sdkws.MsgData, lastSeq int64) error {
if len(messages) > 0 {
data, err := proto.Marshal(&pbmsg.MsgDataToMongoByMQ{LastSeq: lastSeq, ConversationID: conversationID, MsgData: messages})
if err != nil {
return err
}
if err := db.producerToMongo.SendMessage(ctx, key, data); err != nil {
log.ZError(ctx, "MsgToMongoMQ", err, "key", key, "conversationID", conversationID, "lastSeq", lastSeq)
return err
}
}
return nil
}

View File

@@ -0,0 +1,58 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package controller
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/protocol/push"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mq"
"google.golang.org/protobuf/proto"
)
type PushDatabase interface {
DelFcmToken(ctx context.Context, userID string, platformID int) error
MsgToOfflinePushMQ(ctx context.Context, key string, userIDs []string, msg2mq *sdkws.MsgData) error
}
type pushDataBase struct {
cache cache.ThirdCache
producerToOfflinePush mq.Producer
}
func NewPushDatabase(cache cache.ThirdCache, offlinePushProducer mq.Producer) PushDatabase {
return &pushDataBase{
cache: cache,
producerToOfflinePush: offlinePushProducer,
}
}
func (p *pushDataBase) DelFcmToken(ctx context.Context, userID string, platformID int) error {
return p.cache.DelFcmToken(ctx, userID, platformID)
}
func (p *pushDataBase) MsgToOfflinePushMQ(ctx context.Context, key string, userIDs []string, msg2mq *sdkws.MsgData) error {
data, err := proto.Marshal(&push.PushMsgReq{MsgData: msg2mq, UserIDs: userIDs})
if err != nil {
return err
}
if err := p.producerToOfflinePush.SendMessage(ctx, key, data); err != nil {
log.ZError(ctx, "message is push to offlinePush topic", err, "key", key, "userIDs", userIDs, "msg", msg2mq.String())
}
return err
}
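// Usage sketch (illustrative, not part of the original file): wiring the push
// database and forwarding one message to the offline-push topic; thirdCache and
// offlinePushProducer are assumed to be constructed elsewhere.
//
// pushDB := NewPushDatabase(thirdCache, offlinePushProducer)
// if err := pushDB.MsgToOfflinePushMQ(ctx, conversationID, offlineUserIDs, msg); err != nil {
// log.ZError(ctx, "offline push enqueue failed", err)
// }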

View File

@@ -0,0 +1,136 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package controller
import (
"context"
"path/filepath"
"time"
redisCache "git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/redis"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"github.com/openimsdk/tools/s3"
"github.com/openimsdk/tools/s3/cont"
"github.com/redis/go-redis/v9"
)
type S3Database interface {
PartLimit() (*s3.PartLimit, error)
PartSize(ctx context.Context, size int64) (int64, error)
AuthSign(ctx context.Context, uploadID string, partNumbers []int) (*s3.AuthSignResult, error)
InitiateMultipartUpload(ctx context.Context, hash string, size int64, expire time.Duration, maxParts int, contentType string) (*cont.InitiateUploadResult, error)
CompleteMultipartUpload(ctx context.Context, uploadID string, parts []string) (*cont.UploadResult, error)
AccessURL(ctx context.Context, name string, expire time.Duration, opt *s3.AccessURLOption) (time.Time, string, error)
SetObject(ctx context.Context, info *model.Object) error
StatObject(ctx context.Context, name string) (*s3.ObjectInfo, error)
FormData(ctx context.Context, name string, size int64, contentType string, duration time.Duration) (*s3.FormData, error)
FindExpirationObject(ctx context.Context, engine string, expiration time.Time, needDelType []string, count int64) ([]*model.Object, error)
DeleteSpecifiedData(ctx context.Context, engine string, name []string) error
DelS3Key(ctx context.Context, engine string, keys ...string) error
GetKeyCount(ctx context.Context, engine string, key string) (int64, error)
}
func NewS3Database(rdb redis.UniversalClient, s3 s3.Interface, obj database.ObjectInfo) S3Database {
return &s3Database{
s3: cont.New(redisCache.NewS3Cache(rdb, s3), s3),
cache: redisCache.NewObjectCacheRedis(rdb, obj),
s3cache: redisCache.NewS3Cache(rdb, s3),
db: obj,
}
}
type s3Database struct {
s3 *cont.Controller
cache cache.ObjectCache
s3cache cont.S3Cache
db database.ObjectInfo
}
func (s *s3Database) PartSize(ctx context.Context, size int64) (int64, error) {
return s.s3.PartSize(ctx, size)
}
func (s *s3Database) PartLimit() (*s3.PartLimit, error) {
return s.s3.PartLimit()
}
func (s *s3Database) AuthSign(ctx context.Context, uploadID string, partNumbers []int) (*s3.AuthSignResult, error) {
return s.s3.AuthSign(ctx, uploadID, partNumbers)
}
func (s *s3Database) InitiateMultipartUpload(ctx context.Context, hash string, size int64, expire time.Duration, maxParts int, contentType string) (*cont.InitiateUploadResult, error) {
return s.s3.InitiateUploadContentType(ctx, hash, size, expire, maxParts, contentType)
}
func (s *s3Database) CompleteMultipartUpload(ctx context.Context, uploadID string, parts []string) (*cont.UploadResult, error) {
return s.s3.CompleteUpload(ctx, uploadID, parts)
}
func (s *s3Database) SetObject(ctx context.Context, info *model.Object) error {
info.Engine = s.s3.Engine()
if err := s.db.SetObject(ctx, info); err != nil {
return err
}
return s.cache.DelObjectName(info.Engine, info.Name).ChainExecDel(ctx)
}
func (s *s3Database) AccessURL(ctx context.Context, name string, expire time.Duration, opt *s3.AccessURLOption) (time.Time, string, error) {
obj, err := s.cache.GetName(ctx, s.s3.Engine(), name)
if err != nil {
return time.Time{}, "", err
}
if opt == nil {
opt = &s3.AccessURLOption{}
}
if opt.ContentType == "" {
opt.ContentType = obj.ContentType
}
if opt.Filename == "" {
opt.Filename = filepath.Base(obj.Name)
}
expireTime := time.Now().Add(expire)
rawURL, err := s.s3.AccessURL(ctx, obj.Key, expire, opt)
if err != nil {
return time.Time{}, "", err
}
return expireTime, rawURL, nil
}
func (s *s3Database) StatObject(ctx context.Context, name string) (*s3.ObjectInfo, error) {
return s.s3.StatObject(ctx, name)
}
func (s *s3Database) FormData(ctx context.Context, name string, size int64, contentType string, duration time.Duration) (*s3.FormData, error) {
return s.s3.FormData(ctx, name, size, contentType, duration)
}
func (s *s3Database) FindExpirationObject(ctx context.Context, engine string, expiration time.Time, needDelType []string, count int64) ([]*model.Object, error) {
return s.db.FindExpirationObject(ctx, engine, expiration, needDelType, count)
}
func (s *s3Database) GetKeyCount(ctx context.Context, engine string, key string) (int64, error) {
return s.db.GetKeyCount(ctx, engine, key)
}
func (s *s3Database) DeleteSpecifiedData(ctx context.Context, engine string, name []string) error {
return s.db.Delete(ctx, engine, name)
}
func (s *s3Database) DelS3Key(ctx context.Context, engine string, keys ...string) error {
return s.s3cache.DelS3Key(ctx, engine, keys...)
}
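// Illustrative multipart-upload flow built only from the S3Database methods
// above (hypothetical variable names, error handling elided):
//
// initRes, err := s3DB.InitiateMultipartUpload(ctx, fileHash, fileSize, time.Hour, maxParts, contentType)
// // ... the client uploads its parts using the information in initRes and the
// // per-part signatures from AuthSign(ctx, uploadID, partNumbers) ...
// uploadRes, err := s3DB.CompleteMultipartUpload(ctx, uploadID, parts)
// _ = s3DB.SetObject(ctx, objectInfo) // objectInfo is a *model.Object describing uploadRes
// expireAt, url, err := s3DB.AccessURL(ctx, objectInfo.Name, 24*time.Hour, nil)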

View File

@@ -0,0 +1,73 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package controller
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"github.com/openimsdk/tools/db/pagination"
)
type ThirdDatabase interface {
FcmUpdateToken(ctx context.Context, account string, platformID int, fcmToken string, expireTime int64) error
SetAppBadge(ctx context.Context, userID string, value int) error
// about log for debug
UploadLogs(ctx context.Context, logs []*model.Log) error
DeleteLogs(ctx context.Context, logID []string, userID string) error
SearchLogs(ctx context.Context, keyword string, start time.Time, end time.Time, pagination pagination.Pagination) (int64, []*model.Log, error)
GetLogs(ctx context.Context, LogIDs []string, userID string) ([]*model.Log, error)
}
type thirdDatabase struct {
cache cache.ThirdCache
logdb database.Log
}
// DeleteLogs implements ThirdDatabase.
func (t *thirdDatabase) DeleteLogs(ctx context.Context, logID []string, userID string) error {
return t.logdb.Delete(ctx, logID, userID)
}
// GetLogs implements ThirdDatabase.
func (t *thirdDatabase) GetLogs(ctx context.Context, LogIDs []string, userID string) ([]*model.Log, error) {
return t.logdb.Get(ctx, LogIDs, userID)
}
// SearchLogs implements ThirdDatabase.
func (t *thirdDatabase) SearchLogs(ctx context.Context, keyword string, start time.Time, end time.Time, pagination pagination.Pagination) (int64, []*model.Log, error) {
return t.logdb.Search(ctx, keyword, start, end, pagination)
}
// UploadLogs implements ThirdDatabase.
func (t *thirdDatabase) UploadLogs(ctx context.Context, logs []*model.Log) error {
return t.logdb.Create(ctx, logs)
}
func NewThirdDatabase(cache cache.ThirdCache, logdb database.Log) ThirdDatabase {
return &thirdDatabase{cache: cache, logdb: logdb}
}
func (t *thirdDatabase) FcmUpdateToken(ctx context.Context, account string, platformID int, fcmToken string, expireTime int64) error {
return t.cache.SetFcmToken(ctx, account, platformID, fcmToken, expireTime)
}
func (t *thirdDatabase) SetAppBadge(ctx context.Context, userID string, value int) error {
return t.cache.SetUserBadgeUnreadCountSum(ctx, userID, value)
}
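// Usage sketch (illustrative): refreshing a user's FCM token and badge count
// through ThirdDatabase; thirdCache, logDB, platformID and the other values are
// assumed to come from elsewhere.
//
// thirdDB := NewThirdDatabase(thirdCache, logDB)
// if err := thirdDB.FcmUpdateToken(ctx, userID, platformID, fcmToken, expireTime); err != nil {
// return err
// }
// if err := thirdDB.SetAppBadge(ctx, userID, unreadCount); err != nil {
// return err
// }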

View File

@@ -0,0 +1,263 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package controller
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/user"
"github.com/openimsdk/tools/db/pagination"
"github.com/openimsdk/tools/db/tx"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/utils/datautil"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
)
type UserDatabase interface {
// FindWithError Get the information of the specified user. If the userID is not found, it will also return an error
FindWithError(ctx context.Context, userIDs []string) (users []*model.User, err error)
// Find Get the information of the specified user. If the userID is not found, no error will be returned.
Find(ctx context.Context, userIDs []string) (users []*model.User, err error)
// Find userInfo By Nickname
FindByNickname(ctx context.Context, nickname string) (users []*model.User, err error)
// FindNotification find system account by level
FindNotification(ctx context.Context, level int64) (users []*model.User, err error)
// FindSystemAccount find all system account
FindSystemAccount(ctx context.Context) (users []*model.User, err error)
// Create inserts multiple users; the caller externally guarantees that userIDs are not duplicated and do not already exist in storage.
Create(ctx context.Context, users []*model.User) (err error)
// UpdateByMap updates fields by map (zero values allowed); the caller externally guarantees that the userID exists.
UpdateByMap(ctx context.Context, userID string, args map[string]any) (err error)
// PageFindUser pages users filtered by level range.
PageFindUser(ctx context.Context, level1 int64, level2 int64, pagination pagination.Pagination) (count int64, users []*model.User, err error)
// PageFindUserWithKeyword pages users filtered by level range, userID, and nickname keyword.
PageFindUserWithKeyword(ctx context.Context, level1 int64, level2 int64, userID string, nickName string, pagination pagination.Pagination) (count int64, users []*model.User, err error)
// Page If not found, no error is returned
Page(ctx context.Context, pagination pagination.Pagination) (count int64, users []*model.User, err error)
// IsExist true as long as one exists
IsExist(ctx context.Context, userIDs []string) (exist bool, err error)
// GetAllUserID Get all user IDs
GetAllUserID(ctx context.Context, pagination pagination.Pagination) (int64, []string, error)
// Get user by userID
GetUserByID(ctx context.Context, userID string) (user *model.User, err error)
// SearchUsersByFields searches users by multiple fields: account (userID), phone, nickname
// Returns userIDs that match the search criteria
SearchUsersByFields(ctx context.Context, account, phone, nickname string) (userIDs []string, err error)
// InitOnce checks whether each given user already exists in storage; missing users are inserted, and existing users whose nickname differs are updated.
InitOnce(ctx context.Context, users []*model.User) (err error)
// CountTotal Get the total number of users
CountTotal(ctx context.Context, before *time.Time) (int64, error)
// CountRangeEverydayTotal Get the user increment in the range
CountRangeEverydayTotal(ctx context.Context, start time.Time, end time.Time) (map[string]int64, error)
SortQuery(ctx context.Context, userIDName map[string]string, asc bool) ([]*model.User, error)
// CRUD user command
AddUserCommand(ctx context.Context, userID string, Type int32, UUID string, value string, ex string) error
DeleteUserCommand(ctx context.Context, userID string, Type int32, UUID string) error
UpdateUserCommand(ctx context.Context, userID string, Type int32, UUID string, val map[string]any) error
GetUserCommands(ctx context.Context, userID string, Type int32) ([]*user.CommandInfoResp, error)
GetAllUserCommands(ctx context.Context, userID string) ([]*user.AllCommandInfoResp, error)
}
type userDatabase struct {
tx tx.Tx
userDB database.User
cache cache.UserCache
}
func NewUserDatabase(userDB database.User, cache cache.UserCache, tx tx.Tx) UserDatabase {
return &userDatabase{userDB: userDB, cache: cache, tx: tx}
}
func (u *userDatabase) InitOnce(ctx context.Context, users []*model.User) error {
// Extract user IDs from the given user models.
userIDs := datautil.Slice(users, func(e *model.User) string {
return e.UserID
})
// Find existing users in the database.
existingUsers, err := u.userDB.Find(ctx, userIDs)
if err != nil {
return err
}
// Determine which users are missing from the database.
var (
missing, update []*model.User
)
existMap := datautil.SliceToMap(existingUsers, func(e *model.User) string {
return e.UserID
})
orgMap := datautil.SliceToMap(users, func(e *model.User) string { return e.UserID })
for k, u1 := range orgMap {
if u2, ok := existMap[k]; !ok {
missing = append(missing, u1)
} else if u1.Nickname != u2.Nickname {
update = append(update, u1)
}
}
// Create records for missing users.
if len(missing) > 0 {
if err := u.userDB.Create(ctx, missing); err != nil {
return err
}
}
if len(update) > 0 {
for i := range update {
if err := u.userDB.UpdateByMap(ctx, update[i].UserID, map[string]any{"nickname": update[i].Nickname}); err != nil {
return err
}
}
}
return nil
}
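// Worked example of InitOnce (values illustrative): if storage already holds
// user "u1" with nickname "Alice" and InitOnce is called with
// [{u1, "Alicia"}, {u2, "Bob"}], then "u2" is created and "u1" has its nickname
// updated to "Alicia"; nothing else is touched.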
// FindWithError Get the information of the specified user and return an error if the userID is not found.
func (u *userDatabase) FindWithError(ctx context.Context, userIDs []string) (users []*model.User, err error) {
userIDs = datautil.Distinct(userIDs)
// TODO: Add logic to identify which user IDs are distinct and which user IDs were not found.
users, err = u.cache.GetUsersInfo(ctx, userIDs)
if err != nil {
return
}
if len(users) != len(userIDs) {
err = errs.ErrRecordNotFound.WrapMsg("userID not found")
}
return
}
// Find Get the information of the specified user. If the userID is not found, no error will be returned.
func (u *userDatabase) Find(ctx context.Context, userIDs []string) (users []*model.User, err error) {
return u.cache.GetUsersInfo(ctx, userIDs)
}
func (u *userDatabase) FindByNickname(ctx context.Context, nickname string) (users []*model.User, err error) {
return u.userDB.TakeByNickname(ctx, nickname)
}
func (u *userDatabase) FindNotification(ctx context.Context, level int64) (users []*model.User, err error) {
return u.userDB.TakeNotification(ctx, level)
}
func (u *userDatabase) FindSystemAccount(ctx context.Context) (users []*model.User, err error) {
return u.userDB.TakeGTEAppManagerLevel(ctx, constant.AppNotificationAdmin)
}
// Create inserts multiple users; the caller guarantees that userIDs are not duplicated and do not already exist in storage.
func (u *userDatabase) Create(ctx context.Context, users []*model.User) (err error) {
return u.tx.Transaction(ctx, func(ctx context.Context) error {
if err = u.userDB.Create(ctx, users); err != nil {
return err
}
return u.cache.DelUsersInfo(datautil.Slice(users, func(e *model.User) string {
return e.UserID
})...).ChainExecDel(ctx)
})
}
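// Note (added explanation): Create pairs the MongoDB insert with a cache
// invalidation (DelUsersInfo(...).ChainExecDel) inside a single tx.Transaction,
// so the cached user info is only dropped when the insert succeeds; UpdateByMap
// below follows the same write-then-invalidate pattern.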
// UpdateByMap updates fields by map (zero values allowed); the caller guarantees that the userID exists.
func (u *userDatabase) UpdateByMap(ctx context.Context, userID string, args map[string]any) (err error) {
return u.tx.Transaction(ctx, func(ctx context.Context) error {
if err := u.userDB.UpdateByMap(ctx, userID, args); err != nil {
return err
}
return u.cache.DelUsersInfo(userID).ChainExecDel(ctx)
})
}
// Page returns a page of users; no error is returned if none are found.
func (u *userDatabase) Page(ctx context.Context, pagination pagination.Pagination) (count int64, users []*model.User, err error) {
return u.userDB.Page(ctx, pagination)
}
func (u *userDatabase) PageFindUser(ctx context.Context, level1 int64, level2 int64, pagination pagination.Pagination) (count int64, users []*model.User, err error) {
return u.userDB.PageFindUser(ctx, level1, level2, pagination)
}
func (u *userDatabase) PageFindUserWithKeyword(ctx context.Context, level1 int64, level2 int64, userID, nickName string, pagination pagination.Pagination) (count int64, users []*model.User, err error) {
return u.userDB.PageFindUserWithKeyword(ctx, level1, level2, userID, nickName, pagination)
}
// IsExist reports whether any of the given userIDs exist; true if at least one does.
func (u *userDatabase) IsExist(ctx context.Context, userIDs []string) (exist bool, err error) {
users, err := u.userDB.Find(ctx, userIDs)
if err != nil {
return false, err
}
if len(users) > 0 {
return true, nil
}
return false, nil
}
// GetAllUserID gets all user IDs.
func (u *userDatabase) GetAllUserID(ctx context.Context, pagination pagination.Pagination) (total int64, userIDs []string, err error) {
return u.userDB.GetAllUserID(ctx, pagination)
}
func (u *userDatabase) GetUserByID(ctx context.Context, userID string) (user *model.User, err error) {
return u.cache.GetUserInfo(ctx, userID)
}
func (u *userDatabase) SearchUsersByFields(ctx context.Context, account, phone, nickname string) (userIDs []string, err error) {
return u.userDB.SearchUsersByFields(ctx, account, phone, nickname)
}
// CountTotal gets the total number of users.
func (u *userDatabase) CountTotal(ctx context.Context, before *time.Time) (count int64, err error) {
return u.userDB.CountTotal(ctx, before)
}
// CountRangeEverydayTotal gets the daily user increments within the range.
func (u *userDatabase) CountRangeEverydayTotal(ctx context.Context, start time.Time, end time.Time) (map[string]int64, error) {
return u.userDB.CountRangeEverydayTotal(ctx, start, end)
}
func (u *userDatabase) SortQuery(ctx context.Context, userIDName map[string]string, asc bool) ([]*model.User, error) {
return u.userDB.SortQuery(ctx, userIDName, asc)
}
func (u *userDatabase) AddUserCommand(ctx context.Context, userID string, Type int32, UUID string, value string, ex string) error {
return u.userDB.AddUserCommand(ctx, userID, Type, UUID, value, ex)
}
func (u *userDatabase) DeleteUserCommand(ctx context.Context, userID string, Type int32, UUID string) error {
return u.userDB.DeleteUserCommand(ctx, userID, Type, UUID)
}
func (u *userDatabase) UpdateUserCommand(ctx context.Context, userID string, Type int32, UUID string, val map[string]any) error {
return u.userDB.UpdateUserCommand(ctx, userID, Type, UUID, val)
}
func (u *userDatabase) GetUserCommands(ctx context.Context, userID string, Type int32) ([]*user.CommandInfoResp, error) {
commands, err := u.userDB.GetUserCommand(ctx, userID, Type)
return commands, err
}
func (u *userDatabase) GetAllUserCommands(ctx context.Context, userID string) ([]*user.AllCommandInfoResp, error) {
commands, err := u.userDB.GetAllUserCommand(ctx, userID)
return commands, err
}

View File

@@ -0,0 +1,114 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package database
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/pagination"
)
type Black interface {
Create(ctx context.Context, blacks []*model.Black) (err error)
Delete(ctx context.Context, blacks []*model.Black) (err error)
Find(ctx context.Context, blacks []*model.Black) (blackList []*model.Black, err error)
Take(ctx context.Context, ownerUserID, blockUserID string) (black *model.Black, err error)
FindOwnerBlacks(ctx context.Context, ownerUserID string, pagination pagination.Pagination) (total int64, blacks []*model.Black, err error)
FindOwnerBlackInfos(ctx context.Context, ownerUserID string, userIDs []string) (blacks []*model.Black, err error)
FindBlackUserIDs(ctx context.Context, ownerUserID string) (blackUserIDs []string, err error)
}
var (
_ Black = (*mgoImpl)(nil)
_ Black = (*redisImpl)(nil)
)
type mgoImpl struct {
}
func (m *mgoImpl) Create(ctx context.Context, blacks []*model.Black) (err error) {
//TODO implement me
panic("implement me")
}
func (m *mgoImpl) Delete(ctx context.Context, blacks []*model.Black) (err error) {
//TODO implement me
panic("implement me")
}
func (m *mgoImpl) Find(ctx context.Context, blacks []*model.Black) (blackList []*model.Black, err error) {
//TODO implement me
panic("implement me")
}
func (m *mgoImpl) Take(ctx context.Context, ownerUserID, blockUserID string) (black *model.Black, err error) {
//TODO implement me
panic("implement me")
}
func (m *mgoImpl) FindOwnerBlacks(ctx context.Context, ownerUserID string, pagination pagination.Pagination) (total int64, blacks []*model.Black, err error) {
//TODO implement me
panic("implement me")
}
func (m *mgoImpl) FindOwnerBlackInfos(ctx context.Context, ownerUserID string, userIDs []string) (blacks []*model.Black, err error) {
//TODO implement me
panic("implement me")
}
func (m *mgoImpl) FindBlackUserIDs(ctx context.Context, ownerUserID string) (blackUserIDs []string, err error) {
//TODO implement me
panic("implement me")
}
type redisImpl struct {
}
func (r *redisImpl) Create(ctx context.Context, blacks []*model.Black) (err error) {
//TODO implement me
panic("implement me")
}
func (r *redisImpl) Delete(ctx context.Context, blacks []*model.Black) (err error) {
//TODO implement me
panic("implement me")
}
func (r *redisImpl) Find(ctx context.Context, blacks []*model.Black) (blackList []*model.Black, err error) {
//TODO implement me
panic("implement me")
}
func (r *redisImpl) Take(ctx context.Context, ownerUserID, blockUserID string) (black *model.Black, err error) {
//TODO implement me
panic("implement me")
}
func (r *redisImpl) FindOwnerBlacks(ctx context.Context, ownerUserID string, pagination pagination.Pagination) (total int64, blacks []*model.Black, err error) {
//TODO implement me
panic("implement me")
}
func (r *redisImpl) FindOwnerBlackInfos(ctx context.Context, ownerUserID string, userIDs []string) (blacks []*model.Black, err error) {
//TODO implement me
panic("implement me")
}
func (r *redisImpl) FindBlackUserIDs(ctx context.Context, ownerUserID string) (blackUserIDs []string, err error) {
//TODO implement me
panic("implement me")
}

View File

@@ -0,0 +1,16 @@
package database
import (
"context"
"time"
)
type Cache interface {
Get(ctx context.Context, key []string) (map[string]string, error)
Prefix(ctx context.Context, prefix string) (map[string]string, error)
Set(ctx context.Context, key string, value string, expireAt time.Duration) error
Incr(ctx context.Context, key string, value int) (int, error)
Del(ctx context.Context, key []string) error
Lock(ctx context.Context, key string, duration time.Duration) (string, error)
Unlock(ctx context.Context, key string, value string) error
}
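// Example (illustrative sketch, not part of the original file): basic use of the
// generic cache; keys and values are hypothetical, and Incr assumes the key already
// holds a numeric string.
//
//	_ = c.Set(ctx, "greeting", "hello", time.Hour)
//	kv, _ := c.Get(ctx, []string{"greeting"}) // map[greeting:hello]
//	n, _ := c.Incr(ctx, "counter", 1)         // returns the incremented value as an int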

View File

@@ -0,0 +1,15 @@
package database
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/pagination"
)
type ClientConfig interface {
Set(ctx context.Context, userID string, config map[string]string) error
Get(ctx context.Context, userID string) (map[string]string, error)
Del(ctx context.Context, userID string, keys []string) error
GetPage(ctx context.Context, userID string, key string, pagination pagination.Pagination) (int64, []*model.ClientConfig, error)
}

View File

@@ -0,0 +1,48 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package database
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/pagination"
)
type Conversation interface {
Create(ctx context.Context, conversations []*model.Conversation) (err error)
UpdateByMap(ctx context.Context, userIDs []string, conversationID string, args map[string]any) (rows int64, err error)
UpdateUserConversations(ctx context.Context, userID string, args map[string]any) ([]*model.Conversation, error)
Update(ctx context.Context, conversation *model.Conversation) (err error)
Find(ctx context.Context, ownerUserID string, conversationIDs []string) (conversations []*model.Conversation, err error)
FindUserID(ctx context.Context, userIDs []string, conversationIDs []string) ([]string, error)
FindUserIDAllConversationID(ctx context.Context, userID string) ([]string, error)
FindUserIDAllNotNotifyConversationID(ctx context.Context, userID string) ([]string, error)
FindUserIDAllPinnedConversationID(ctx context.Context, userID string) ([]string, error)
Take(ctx context.Context, userID, conversationID string) (conversation *model.Conversation, err error)
FindConversationID(ctx context.Context, userID string, conversationIDs []string) (existConversationID []string, err error)
FindUserIDAllConversations(ctx context.Context, userID string) (conversations []*model.Conversation, err error)
FindRecvMsgUserIDs(ctx context.Context, conversationID string, recvOpts []int) ([]string, error)
GetUserRecvMsgOpt(ctx context.Context, ownerUserID, conversationID string) (opt int, err error)
GetAllConversationIDs(ctx context.Context) ([]string, error)
GetAllConversationIDsNumber(ctx context.Context) (int64, error)
PageConversationIDs(ctx context.Context, pagination pagination.Pagination) (conversationIDs []string, err error)
GetConversationsByConversationID(ctx context.Context, conversationIDs []string) ([]*model.Conversation, error)
GetConversationIDsNeedDestruct(ctx context.Context) ([]*model.Conversation, error)
GetConversationNotReceiveMessageUserIDs(ctx context.Context, conversationID string) ([]string, error)
FindConversationUserVersion(ctx context.Context, userID string, version uint, limit int) (*model.VersionLog, error)
FindRandConversation(ctx context.Context, ts int64, limit int) ([]*model.Conversation, error)
DeleteUsersConversations(ctx context.Context, userID string, conversationIDs []string) error
}

View File

@@ -0,0 +1,15 @@
// Copyright © 2024 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package database // import "git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model/relation"

View File

@@ -0,0 +1,60 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package database
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/pagination"
)
// Friend defines the operations for managing friends in MongoDB.
type Friend interface {
// Create inserts multiple friend records.
Create(ctx context.Context, friends []*model.Friend) (err error)
// Delete removes specified friends of the owner user.
Delete(ctx context.Context, ownerUserID string, friendUserIDs []string) (err error)
// UpdateByMap updates specific fields of a friend document using a map.
UpdateByMap(ctx context.Context, ownerUserID string, friendUserID string, args map[string]any) (err error)
// UpdateRemark modifies the remark for a friend.
UpdateRemark(ctx context.Context, ownerUserID, friendUserID, remark string) (err error)
// Take retrieves a single friend document. Returns an error if not found.
Take(ctx context.Context, ownerUserID, friendUserID string) (friend *model.Friend, err error)
// FindUserState finds the friendship status between two users.
FindUserState(ctx context.Context, userID1, userID2 string) (friends []*model.Friend, err error)
// FindFriends retrieves a list of friends for a given owner. Missing friends do not cause an error.
FindFriends(ctx context.Context, ownerUserID string, friendUserIDs []string) (friends []*model.Friend, err error)
// FindReversalFriends finds users who have added the specified user as a friend.
FindReversalFriends(ctx context.Context, friendUserID string, ownerUserIDs []string) (friends []*model.Friend, err error)
// FindOwnerFriends retrieves a paginated list of friends for a given owner.
FindOwnerFriends(ctx context.Context, ownerUserID string, pagination pagination.Pagination) (total int64, friends []*model.Friend, err error)
// FindInWhoseFriends finds users who have added the specified user as a friend, with pagination.
FindInWhoseFriends(ctx context.Context, friendUserID string, pagination pagination.Pagination) (total int64, friends []*model.Friend, err error)
// FindFriendUserIDs retrieves a list of friend user IDs for a given owner.
FindFriendUserIDs(ctx context.Context, ownerUserID string) (friendUserIDs []string, err error)
// UpdateFriends update friends' fields
UpdateFriends(ctx context.Context, ownerUserID string, friendUserIDs []string, val map[string]any) (err error)
FindIncrVersion(ctx context.Context, ownerUserID string, version uint, limit int) (*model.VersionLog, error)
FindFriendUserID(ctx context.Context, friendUserID string) ([]string, error)
//SearchFriend(ctx context.Context, ownerUserID, keyword string, pagination pagination.Pagination) (int64, []*model.Friend, error)
FindOwnerFriendUserIds(ctx context.Context, ownerUserID string, limit int) ([]string, error)
IncrVersion(ctx context.Context, ownerUserID string, friendUserIDs []string, state int32) error
}

View File

@@ -0,0 +1,42 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package database
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/pagination"
)
type FriendRequest interface {
// Insert multiple records
Create(ctx context.Context, friendRequests []*model.FriendRequest) (err error)
// Delete record
Delete(ctx context.Context, fromUserID, toUserID string) (err error)
// Update with zero values
UpdateByMap(ctx context.Context, formUserID string, toUserID string, args map[string]any) (err error)
// Update multiple records (non-zero values)
Update(ctx context.Context, friendRequest *model.FriendRequest) (err error)
// Get the friend request between fromUserID and toUserID; no error is returned if not found
Find(ctx context.Context, fromUserID, toUserID string) (friendRequest *model.FriendRequest, err error)
Take(ctx context.Context, fromUserID, toUserID string) (friendRequest *model.FriendRequest, err error)
// Get list of friend requests received by toUserID
FindToUserID(ctx context.Context, toUserID string, handleResults []int, pagination pagination.Pagination) (total int64, friendRequests []*model.FriendRequest, err error)
// Get list of friend requests sent by fromUserID
FindFromUserID(ctx context.Context, fromUserID string, handleResults []int, pagination pagination.Pagination) (total int64, friendRequests []*model.FriendRequest, err error)
FindBothFriendRequests(ctx context.Context, fromUserID, toUserID string) (friends []*model.FriendRequest, err error)
GetUnhandledCount(ctx context.Context, userID string, ts int64) (int64, error)
}

View File

@@ -0,0 +1,40 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package database
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/pagination"
)
type Group interface {
Create(ctx context.Context, groups []*model.Group) (err error)
UpdateMap(ctx context.Context, groupID string, args map[string]any) (err error)
UpdateStatus(ctx context.Context, groupID string, status int32) (err error)
Find(ctx context.Context, groupIDs []string) (groups []*model.Group, err error)
Take(ctx context.Context, groupID string) (group *model.Group, err error)
Search(ctx context.Context, keyword string, pagination pagination.Pagination) (total int64, groups []*model.Group, err error)
// CountTotal gets the total number of groups
CountTotal(ctx context.Context, before *time.Time) (count int64, err error)
// CountRangeEverydayTotal gets the daily group increments within the range
CountRangeEverydayTotal(ctx context.Context, start time.Time, end time.Time) (map[string]int64, error)
FindJoinSortGroupID(ctx context.Context, groupIDs []string) ([]string, error)
SearchJoin(ctx context.Context, groupIDs []string, keyword string, pagination pagination.Pagination) (int64, []*model.Group, error)
}

View File

@@ -0,0 +1,48 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package database
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/pagination"
)
type GroupMember interface {
Create(ctx context.Context, groupMembers []*model.GroupMember) (err error)
Delete(ctx context.Context, groupID string, userIDs []string) (err error)
Update(ctx context.Context, groupID string, userID string, data map[string]any) (err error)
UpdateRoleLevel(ctx context.Context, groupID string, userID string, roleLevel int32) error
UpdateUserRoleLevels(ctx context.Context, groupID string, firstUserID string, firstUserRoleLevel int32, secondUserID string, secondUserRoleLevel int32) error
FindMemberUserID(ctx context.Context, groupID string) (userIDs []string, err error)
Take(ctx context.Context, groupID string, userID string) (groupMember *model.GroupMember, err error)
Find(ctx context.Context, groupID string, userIDs []string) ([]*model.GroupMember, error)
FindInGroup(ctx context.Context, userID string, groupIDs []string) ([]*model.GroupMember, error)
TakeOwner(ctx context.Context, groupID string) (groupMember *model.GroupMember, err error)
SearchMember(ctx context.Context, keyword string, groupID string, pagination pagination.Pagination) (total int64, groupList []*model.GroupMember, err error)
// SearchMemberByFields searches for group members by multiple fields: nickname, userID (account), and optionally phone
SearchMemberByFields(ctx context.Context, groupID string, nickname, userID, phone string, pagination pagination.Pagination) (total int64, groupList []*model.GroupMember, err error)
FindRoleLevelUserIDs(ctx context.Context, groupID string, roleLevel int32) ([]string, error)
FindUserJoinedGroupID(ctx context.Context, userID string) (groupIDs []string, err error)
TakeGroupMemberNum(ctx context.Context, groupID string) (count int64, err error)
FindUserManagedGroupID(ctx context.Context, userID string) (groupIDs []string, err error)
IsUpdateRoleLevel(data map[string]any) bool
JoinGroupIncrVersion(ctx context.Context, userID string, groupIDs []string, state int32) error
MemberGroupIncrVersion(ctx context.Context, groupID string, userIDs []string, state int32) error
FindMemberIncrVersion(ctx context.Context, groupID string, version uint, limit int) (*model.VersionLog, error)
BatchFindMemberIncrVersion(ctx context.Context, groupIDs []string, versions []uint, limits []int) ([]*model.VersionLog, error)
FindJoinIncrVersion(ctx context.Context, userID string, version uint, limit int) (*model.VersionLog, error)
}

View File

@@ -0,0 +1,33 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package database
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/pagination"
)
type GroupRequest interface {
Create(ctx context.Context, groupRequests []*model.GroupRequest) (err error)
Delete(ctx context.Context, groupID string, userID string) (err error)
UpdateHandler(ctx context.Context, groupID string, userID string, handledMsg string, handleResult int32) (err error)
Take(ctx context.Context, groupID string, userID string) (groupRequest *model.GroupRequest, err error)
FindGroupRequests(ctx context.Context, groupID string, userIDs []string) ([]*model.GroupRequest, error)
Page(ctx context.Context, userID string, groupIDs []string, handleResults []int, pagination pagination.Pagination) (total int64, groups []*model.GroupRequest, err error)
PageGroup(ctx context.Context, groupIDs []string, handleResults []int, pagination pagination.Pagination) (total int64, groups []*model.GroupRequest, err error)
GetUnhandledCount(ctx context.Context, groupIDs []string, ts int64) (int64, error)
}

View File

@@ -0,0 +1,30 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package database
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/pagination"
)
type Log interface {
Create(ctx context.Context, log []*model.Log) error
Search(ctx context.Context, keyword string, start time.Time, end time.Time, pagination pagination.Pagination) (int64, []*model.Log, error)
Delete(ctx context.Context, logID []string, userID string) error
Get(ctx context.Context, logIDs []string, userID string) ([]*model.Log, error)
}

View File

@@ -0,0 +1,52 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package database
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/pagination"
)
// Meeting defines the operations for managing meetings in MongoDB.
type Meeting interface {
// Create creates a new meeting record.
Create(ctx context.Context, meeting *model.Meeting) error
// Take retrieves a meeting by meeting ID. Returns an error if not found.
Take(ctx context.Context, meetingID string) (*model.Meeting, error)
// Update updates meeting information.
Update(ctx context.Context, meetingID string, data map[string]any) error
// UpdateStatus updates the status of a meeting.
UpdateStatus(ctx context.Context, meetingID string, status int32) error
// Find finds meetings by meeting IDs.
Find(ctx context.Context, meetingIDs []string) ([]*model.Meeting, error)
// FindByCreator finds meetings created by a specific user.
FindByCreator(ctx context.Context, creatorUserID string, pagination pagination.Pagination) (total int64, meetings []*model.Meeting, err error)
// FindAll finds all meetings with pagination.
FindAll(ctx context.Context, pagination pagination.Pagination) (total int64, meetings []*model.Meeting, err error)
// Search searches meetings by keyword (subject, description).
Search(ctx context.Context, keyword string, pagination pagination.Pagination) (total int64, meetings []*model.Meeting, err error)
// FindByStatus finds meetings by status.
FindByStatus(ctx context.Context, status int32, pagination pagination.Pagination) (total int64, meetings []*model.Meeting, err error)
// FindByScheduledTimeRange finds meetings within a scheduled time range.
FindByScheduledTimeRange(ctx context.Context, startTime, endTime int64, pagination pagination.Pagination) (total int64, meetings []*model.Meeting, err error)
// FindFinishedMeetingsBefore finds finished meetings that ended before the specified time.
// This is used for cleanup tasks like dismissing groups after meetings end.
FindFinishedMeetingsBefore(ctx context.Context, beforeTime time.Time) ([]*model.Meeting, error)
// Delete deletes a meeting by meeting ID.
Delete(ctx context.Context, meetingID string) error
}

View File

@@ -0,0 +1,41 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package database
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/pagination"
)
// MeetingCheckIn defines the operations for managing meeting check-ins in MongoDB.
type MeetingCheckIn interface {
// Create creates a new meeting check-in record.
Create(ctx context.Context, checkIn *model.MeetingCheckIn) error
// Take retrieves a check-in by check-in ID. Returns an error if not found.
Take(ctx context.Context, checkInID string) (*model.MeetingCheckIn, error)
// FindByMeetingID finds all check-ins for a meeting with pagination.
FindByMeetingID(ctx context.Context, meetingID string, pagination pagination.Pagination) (total int64, checkIns []*model.MeetingCheckIn, err error)
// FindByUserAndMeetingID finds if a user has checked in for a specific meeting.
FindByUserAndMeetingID(ctx context.Context, userID, meetingID string) (*model.MeetingCheckIn, error)
// CountByMeetingID counts the number of check-ins for a meeting.
CountByMeetingID(ctx context.Context, meetingID string) (int64, error)
// FindByUser finds all check-ins by a user with pagination.
FindByUser(ctx context.Context, userID string, pagination pagination.Pagination) (total int64, checkIns []*model.MeetingCheckIn, err error)
// Delete deletes a check-in by check-in ID.
Delete(ctx context.Context, checkInID string) error
}

View File

@@ -0,0 +1,106 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mgo
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/db/pagination"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func NewBlackMongo(db *mongo.Database) (database.Black, error) {
coll := db.Collection(database.BlackName)
_, err := coll.Indexes().CreateOne(context.Background(), mongo.IndexModel{
Keys: bson.D{
{Key: "owner_user_id", Value: 1},
{Key: "block_user_id", Value: 1},
},
Options: options.Index().SetUnique(true),
})
if err != nil {
return nil, err
}
return &BlackMgo{coll: coll}, nil
}
type BlackMgo struct {
coll *mongo.Collection
}
func (b *BlackMgo) blackFilter(ownerUserID, blockUserID string) bson.M {
return bson.M{
"owner_user_id": ownerUserID,
"block_user_id": blockUserID,
}
}
func (b *BlackMgo) blacksFilter(blacks []*model.Black) bson.M {
if len(blacks) == 0 {
return nil
}
or := make(bson.A, 0, len(blacks))
for _, black := range blacks {
or = append(or, b.blackFilter(black.OwnerUserID, black.BlockUserID))
}
return bson.M{"$or": or}
}
func (b *BlackMgo) Create(ctx context.Context, blacks []*model.Black) (err error) {
return mongoutil.InsertMany(ctx, b.coll, blacks)
}
func (b *BlackMgo) Delete(ctx context.Context, blacks []*model.Black) (err error) {
if len(blacks) == 0 {
return nil
}
return mongoutil.DeleteMany(ctx, b.coll, b.blacksFilter(blacks))
}
func (b *BlackMgo) UpdateByMap(ctx context.Context, ownerUserID, blockUserID string, args map[string]any) (err error) {
if len(args) == 0 {
return nil
}
return mongoutil.UpdateOne(ctx, b.coll, b.blackFilter(ownerUserID, blockUserID), bson.M{"$set": args}, false)
}
func (b *BlackMgo) Find(ctx context.Context, blacks []*model.Black) (blackList []*model.Black, err error) {
return mongoutil.Find[*model.Black](ctx, b.coll, b.blacksFilter(blacks))
}
func (b *BlackMgo) Take(ctx context.Context, ownerUserID, blockUserID string) (black *model.Black, err error) {
return mongoutil.FindOne[*model.Black](ctx, b.coll, b.blackFilter(ownerUserID, blockUserID))
}
func (b *BlackMgo) FindOwnerBlacks(ctx context.Context, ownerUserID string, pagination pagination.Pagination) (total int64, blacks []*model.Black, err error) {
return mongoutil.FindPage[*model.Black](ctx, b.coll, bson.M{"owner_user_id": ownerUserID}, pagination)
}
func (b *BlackMgo) FindOwnerBlackInfos(ctx context.Context, ownerUserID string, userIDs []string) (blacks []*model.Black, err error) {
if len(userIDs) == 0 {
return mongoutil.Find[*model.Black](ctx, b.coll, bson.M{"owner_user_id": ownerUserID})
}
return mongoutil.Find[*model.Black](ctx, b.coll, bson.M{"owner_user_id": ownerUserID, "block_user_id": bson.M{"$in": userIDs}})
}
func (b *BlackMgo) FindBlackUserIDs(ctx context.Context, ownerUserID string) (blackUserIDs []string, err error) {
return mongoutil.Find[string](ctx, b.coll, bson.M{"owner_user_id": ownerUserID}, options.Find().SetProjection(bson.M{"_id": 0, "block_user_id": 1}))
}
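// Note (added explanation): FindBlackUserIDs uses mongoutil.Find[string] with a
// projection that keeps only block_user_id, so each matching document decodes
// directly into the blocked user's ID instead of a full model.Black.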

View File

@@ -0,0 +1,183 @@
package mgo
import (
"context"
"strconv"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/google/uuid"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/errs"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func NewCacheMgo(db *mongo.Database) (*CacheMgo, error) {
coll := db.Collection(database.CacheName)
_, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{
{
Keys: bson.D{
{Key: "key", Value: 1},
},
Options: options.Index().SetUnique(true),
},
{
Keys: bson.D{
{Key: "expire_at", Value: 1},
},
Options: options.Index().SetExpireAfterSeconds(0),
},
})
if err != nil {
return nil, errs.Wrap(err)
}
return &CacheMgo{coll: coll}, nil
}
type CacheMgo struct {
coll *mongo.Collection
}
func (x *CacheMgo) findToMap(res []model.Cache, now time.Time) map[string]string {
kv := make(map[string]string)
for _, re := range res {
if re.ExpireAt != nil && re.ExpireAt.Before(now) {
continue
}
kv[re.Key] = re.Value
}
return kv
}
func (x *CacheMgo) Get(ctx context.Context, key []string) (map[string]string, error) {
if len(key) == 0 {
return nil, nil
}
now := time.Now()
res, err := mongoutil.Find[model.Cache](ctx, x.coll, bson.M{
"key": bson.M{"$in": key},
"$or": []bson.M{
{"expire_at": bson.M{"$gt": now}},
{"expire_at": nil},
},
})
if err != nil {
return nil, err
}
return x.findToMap(res, now), nil
}
func (x *CacheMgo) Prefix(ctx context.Context, prefix string) (map[string]string, error) {
now := time.Now()
res, err := mongoutil.Find[model.Cache](ctx, x.coll, bson.M{
"key": bson.M{"$regex": "^" + prefix},
"$or": []bson.M{
{"expire_at": bson.M{"$gt": now}},
{"expire_at": nil},
},
})
if err != nil {
return nil, err
}
return x.findToMap(res, now), nil
}
func (x *CacheMgo) Set(ctx context.Context, key string, value string, expireAt time.Duration) error {
cv := &model.Cache{
Key: key,
Value: value,
}
if expireAt > 0 {
now := time.Now().Add(expireAt)
cv.ExpireAt = &now
}
opt := options.Update().SetUpsert(true)
return mongoutil.UpdateOne(ctx, x.coll, bson.M{"key": key}, bson.M{"$set": cv}, false, opt)
}
func (x *CacheMgo) Incr(ctx context.Context, key string, value int) (int, error) {
pipeline := mongo.Pipeline{
{
{"$set", bson.M{
"value": bson.M{
"$toString": bson.M{
"$add": bson.A{
bson.M{"$toInt": "$value"},
value,
},
},
},
}},
},
}
opt := options.FindOneAndUpdate().SetReturnDocument(options.After)
res, err := mongoutil.FindOneAndUpdate[model.Cache](ctx, x.coll, bson.M{"key": key}, pipeline, opt)
if err != nil {
return 0, err
}
return strconv.Atoi(res.Value)
}
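// Note (added explanation): cache values are stored as strings, so the update
// pipeline above converts the current value with $toInt, adds the increment, and
// writes it back with $toString; e.g. a stored "9" incremented by 3 becomes "12",
// and strconv.Atoi returns 12 to the caller. The document must already exist and
// hold a numeric string, otherwise the pipeline (or Atoi) fails.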
func (x *CacheMgo) Del(ctx context.Context, key []string) error {
if len(key) == 0 {
return nil
}
_, err := x.coll.DeleteMany(ctx, bson.M{"key": bson.M{"$in": key}})
return errs.Wrap(err)
}
func (x *CacheMgo) lockKey(key string) string {
return "LOCK_" + key
}
func (x *CacheMgo) Lock(ctx context.Context, key string, duration time.Duration) (string, error) {
tmp, err := uuid.NewUUID()
if err != nil {
return "", err
}
if duration <= 0 || duration > time.Minute*10 {
duration = time.Minute * 10
}
cv := &model.Cache{
Key: x.lockKey(key),
Value: tmp.String(),
ExpireAt: nil,
}
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel()
wait := func() error {
timeout := time.NewTimer(time.Millisecond * 100)
defer timeout.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-timeout.C:
return nil
}
}
for {
// Clean up any expired lock document under the same lock key before retrying.
if err := mongoutil.DeleteOne(ctx, x.coll, bson.M{"key": x.lockKey(key), "expire_at": bson.M{"$lt": time.Now()}}); err != nil {
return "", err
}
expireAt := time.Now().Add(duration)
cv.ExpireAt = &expireAt
if err := mongoutil.InsertMany[*model.Cache](ctx, x.coll, []*model.Cache{cv}); err != nil {
if mongo.IsDuplicateKeyError(err) {
if err := wait(); err != nil {
return "", err
}
continue
}
return "", err
}
return cv.Value, nil
}
}
func (x *CacheMgo) Unlock(ctx context.Context, key string, value string) error {
return mongoutil.DeleteOne(ctx, x.coll, bson.M{"key": x.lockKey(key), "value": value})
}
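// Example (illustrative sketch, not part of the original file): acquire/release of
// the Mongo-backed lock. Lock keeps retrying, backed by the unique "key" index and
// the TTL on expire_at, until it can insert the LOCK_<key> document, and returns a
// token; Unlock only removes the document when the token still matches.
//
//	token, err := c.Lock(ctx, "convSeq:"+conversationID, 30*time.Second)
//	if err != nil {
//		return err
//	}
//	defer c.Unlock(ctx, "convSeq:"+conversationID, token)
//	// ... critical section ...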

View File

@@ -0,0 +1,133 @@
package mgo
import (
"context"
"strings"
"sync"
"testing"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/mongoutil"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func TestName1111(t *testing.T) {
coll := Mongodb().Collection("temp")
//updatePipeline := mongo.Pipeline{
// {
// {"$set", bson.M{
// "age": bson.M{
// "$toString": bson.M{
// "$add": bson.A{
// bson.M{"$toInt": "$age"},
// 1,
// },
// },
// },
// }},
// },
//}
pipeline := mongo.Pipeline{
{
{"$set", bson.M{
"value": bson.M{
"$toString": bson.M{
"$add": bson.A{
bson.M{"$toInt": "$value"},
1,
},
},
},
}},
},
}
opt := options.FindOneAndUpdate().SetUpsert(false).SetReturnDocument(options.After)
res, err := mongoutil.FindOneAndUpdate[model.Cache](context.Background(), coll, bson.M{"key": "123456"}, pipeline, opt)
if err != nil {
panic(err)
}
t.Log(res)
}
func TestName33333(t *testing.T) {
c, err := NewCacheMgo(Mongodb())
if err != nil {
panic(err)
}
if err := c.Set(context.Background(), "123456", "123456", time.Hour); err != nil {
panic(err)
}
if err := c.Set(context.Background(), "123666", "123666", time.Hour); err != nil {
panic(err)
}
res1, err := c.Get(context.Background(), []string{"123456"})
if err != nil {
panic(err)
}
t.Log(res1)
res2, err := c.Prefix(context.Background(), "123")
if err != nil {
panic(err)
}
t.Log(res2)
}
func TestName1111aa(t *testing.T) {
c, err := NewCacheMgo(Mongodb())
if err != nil {
panic(err)
}
var count int
key := "123456"
doFunc := func() {
value, err := c.Lock(context.Background(), key, time.Second*30)
if err != nil {
t.Log("Lock error", err)
return
}
tmp := count
tmp++
count = tmp
t.Log("count", tmp)
if err := c.Unlock(context.Background(), key, value); err != nil {
t.Log("Unlock error", err)
return
}
}
if _, err := c.Lock(context.Background(), key, time.Second*10); err != nil {
t.Log(err)
return
}
var wg sync.WaitGroup
for i := 0; i < 32; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 100; i++ {
doFunc()
}
}()
}
wg.Wait()
}
func TestName111111a(t *testing.T) {
arr := strings.SplitN("1:testkakskdask:1111", ":", 2)
t.Log(arr)
}

View File

@@ -0,0 +1,99 @@
// Copyright © 2023 OpenIM open source community. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mgo
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/db/pagination"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"github.com/openimsdk/tools/errs"
)
func NewClientConfig(db *mongo.Database) (database.ClientConfig, error) {
coll := db.Collection("config")
_, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{
{
Keys: bson.D{
{Key: "key", Value: 1},
{Key: "user_id", Value: 1},
},
Options: options.Index().SetUnique(true),
},
})
if err != nil {
return nil, errs.Wrap(err)
}
return &ClientConfig{
coll: coll,
}, nil
}
type ClientConfig struct {
coll *mongo.Collection
}
func (x *ClientConfig) Set(ctx context.Context, userID string, config map[string]string) error {
if len(config) == 0 {
return nil
}
for key, value := range config {
filter := bson.M{"key": key, "user_id": userID}
update := bson.M{
"value": value,
}
err := mongoutil.UpdateOne(ctx, x.coll, filter, bson.M{"$set": update}, false, options.Update().SetUpsert(true))
if err != nil {
return err
}
}
return nil
}
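// Note (added explanation): Set upserts one document per key (filtered by key and
// user_id), so writing a map of N entries issues N UpdateOne calls rather than a
// single bulk write; a failure mid-loop leaves the earlier keys already written.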
func (x *ClientConfig) Get(ctx context.Context, userID string) (map[string]string, error) {
cs, err := mongoutil.Find[*model.ClientConfig](ctx, x.coll, bson.M{"user_id": userID})
if err != nil {
return nil, err
}
cm := make(map[string]string)
for _, config := range cs {
cm[config.Key] = config.Value
}
return cm, nil
}
func (x *ClientConfig) Del(ctx context.Context, userID string, keys []string) error {
if len(keys) == 0 {
return nil
}
return mongoutil.DeleteMany(ctx, x.coll, bson.M{"key": bson.M{"$in": keys}, "user_id": userID})
}
func (x *ClientConfig) GetPage(ctx context.Context, userID string, key string, pagination pagination.Pagination) (int64, []*model.ClientConfig, error) {
filter := bson.M{}
if userID != "" {
filter["user_id"] = userID
}
if key != "" {
filter["key"] = key
}
return mongoutil.FindPage[*model.ClientConfig](ctx, x.coll, filter, pagination)
}

View File

@@ -0,0 +1,325 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mgo
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"git.imall.cloud/openim/protocol/constant"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/db/pagination"
"github.com/openimsdk/tools/errs"
)
func NewConversationMongo(db *mongo.Database) (*ConversationMgo, error) {
coll := db.Collection(database.ConversationName)
_, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{
{
Keys: bson.D{
{Key: "owner_user_id", Value: 1},
{Key: "conversation_id", Value: 1},
},
Options: options.Index().SetUnique(true),
},
{
Keys: bson.D{
{Key: "user_id", Value: 1},
},
Options: options.Index(),
},
})
if err != nil {
return nil, errs.Wrap(err)
}
version, err := NewVersionLog(db.Collection(database.ConversationVersionName))
if err != nil {
return nil, err
}
return &ConversationMgo{version: version, coll: coll}, nil
}
type ConversationMgo struct {
version database.VersionLog
coll *mongo.Collection
}
func (c *ConversationMgo) Create(ctx context.Context, conversations []*model.Conversation) (err error) {
return mongoutil.IncrVersion(func() error {
return mongoutil.InsertMany(ctx, c.coll, conversations)
}, func() error {
userConversation := make(map[string][]string)
for _, conversation := range conversations {
userConversation[conversation.OwnerUserID] = append(userConversation[conversation.OwnerUserID], conversation.ConversationID)
}
for userID, conversationIDs := range userConversation {
if err := c.version.IncrVersion(ctx, userID, conversationIDs, model.VersionStateInsert); err != nil {
return err
}
}
return nil
})
}
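// Note (added explanation): mongoutil.IncrVersion chains the two callbacks: the
// main write runs first and, when it succeeds, the version-log update runs after
// it. Create groups the new conversation IDs by owner so each owner receives one
// VersionStateInsert entry covering all of that owner's new conversations.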
func (c *ConversationMgo) UpdateByMap(ctx context.Context, userIDs []string, conversationID string, args map[string]any) (int64, error) {
if len(args) == 0 || len(userIDs) == 0 {
return 0, nil
}
filter := bson.M{
"conversation_id": conversationID,
"owner_user_id": bson.M{"$in": userIDs},
}
var rows int64
err := mongoutil.IncrVersion(func() error {
res, err := mongoutil.UpdateMany(ctx, c.coll, filter, bson.M{"$set": args})
if err != nil {
return err
}
rows = res.ModifiedCount
return nil
}, func() error {
for _, userID := range userIDs {
if err := c.version.IncrVersion(ctx, userID, []string{conversationID}, model.VersionStateUpdate); err != nil {
return err
}
}
return nil
})
if err != nil {
return 0, err
}
return rows, nil
}
func (c *ConversationMgo) UpdateUserConversations(ctx context.Context, userID string, args map[string]any) ([]*model.Conversation, error) {
if len(args) == 0 {
return nil, nil
}
filter := bson.M{
"user_id": userID,
}
conversations, err := mongoutil.Find[*model.Conversation](ctx, c.coll, filter, options.Find().SetProjection(bson.M{"_id": 0, "owner_user_id": 1, "conversation_id": 1}))
if err != nil {
return nil, err
}
err = mongoutil.IncrVersion(func() error {
_, err := mongoutil.UpdateMany(ctx, c.coll, filter, bson.M{"$set": args})
if err != nil {
return err
}
return nil
}, func() error {
for _, conversation := range conversations {
if err := c.version.IncrVersion(ctx, conversation.OwnerUserID, []string{conversation.ConversationID}, model.VersionStateUpdate); err != nil {
return err
}
}
return nil
})
if err != nil {
return nil, err
}
return conversations, nil
}
func (c *ConversationMgo) Update(ctx context.Context, conversation *model.Conversation) (err error) {
return mongoutil.IncrVersion(func() error {
return mongoutil.UpdateOne(ctx, c.coll, bson.M{"owner_user_id": conversation.OwnerUserID, "conversation_id": conversation.ConversationID}, bson.M{"$set": conversation}, true)
}, func() error {
return c.version.IncrVersion(ctx, conversation.OwnerUserID, []string{conversation.ConversationID}, model.VersionStateUpdate)
})
}
func (c *ConversationMgo) Find(ctx context.Context, ownerUserID string, conversationIDs []string) (conversations []*model.Conversation, err error) {
return mongoutil.Find[*model.Conversation](ctx, c.coll, bson.M{"owner_user_id": ownerUserID, "conversation_id": bson.M{"$in": conversationIDs}})
}
func (c *ConversationMgo) FindUserID(ctx context.Context, userIDs []string, conversationIDs []string) ([]string, error) {
return mongoutil.Find[string](
ctx,
c.coll,
bson.M{"owner_user_id": bson.M{"$in": userIDs}, "conversation_id": bson.M{"$in": conversationIDs}},
options.Find().SetProjection(bson.M{"_id": 0, "owner_user_id": 1}),
)
}
func (c *ConversationMgo) FindUserIDAllConversationID(ctx context.Context, userID string) ([]string, error) {
return mongoutil.Find[string](ctx, c.coll, bson.M{"owner_user_id": userID}, options.Find().SetProjection(bson.M{"_id": 0, "conversation_id": 1}))
}
func (c *ConversationMgo) FindUserIDAllNotNotifyConversationID(ctx context.Context, userID string) ([]string, error) {
return mongoutil.Find[string](ctx, c.coll, bson.M{
"owner_user_id": userID,
"recv_msg_opt": constant.ReceiveNotNotifyMessage,
}, options.Find().SetProjection(bson.M{"_id": 0, "conversation_id": 1}))
}
func (c *ConversationMgo) FindUserIDAllPinnedConversationID(ctx context.Context, userID string) ([]string, error) {
return mongoutil.Find[string](ctx, c.coll, bson.M{
"owner_user_id": userID,
"is_pinned": true,
}, options.Find().SetProjection(bson.M{"_id": 0, "conversation_id": 1}))
}
func (c *ConversationMgo) Take(ctx context.Context, userID, conversationID string) (conversation *model.Conversation, err error) {
return mongoutil.FindOne[*model.Conversation](ctx, c.coll, bson.M{"owner_user_id": userID, "conversation_id": conversationID})
}
func (c *ConversationMgo) FindConversationID(ctx context.Context, userID string, conversationIDs []string) (existConversationID []string, err error) {
return mongoutil.Find[string](ctx, c.coll, bson.M{"owner_user_id": userID, "conversation_id": bson.M{"$in": conversationIDs}}, options.Find().SetProjection(bson.M{"_id": 0, "conversation_id": 1}))
}
func (c *ConversationMgo) FindUserIDAllConversations(ctx context.Context, userID string) (conversations []*model.Conversation, err error) {
return mongoutil.Find[*model.Conversation](ctx, c.coll, bson.M{"owner_user_id": userID})
}
func (c *ConversationMgo) FindRecvMsgUserIDs(ctx context.Context, conversationID string, recvOpts []int) ([]string, error) {
var filter any
if len(recvOpts) == 0 {
filter = bson.M{"conversation_id": conversationID}
} else {
filter = bson.M{"conversation_id": conversationID, "recv_msg_opt": bson.M{"$in": recvOpts}}
}
return mongoutil.Find[string](ctx, c.coll, filter, options.Find().SetProjection(bson.M{"_id": 0, "owner_user_id": 1}))
}
func (c *ConversationMgo) GetUserRecvMsgOpt(ctx context.Context, ownerUserID, conversationID string) (opt int, err error) {
return mongoutil.FindOne[int](ctx, c.coll, bson.M{"owner_user_id": ownerUserID, "conversation_id": conversationID}, options.FindOne().SetProjection(bson.M{"recv_msg_opt": 1}))
}
func (c *ConversationMgo) GetAllConversationIDs(ctx context.Context) ([]string, error) {
return mongoutil.Aggregate[string](ctx, c.coll, []bson.M{
{"$group": bson.M{"_id": "$conversation_id"}},
{"$project": bson.M{"_id": 0, "conversation_id": "$_id"}},
})
}
func (c *ConversationMgo) GetAllConversationIDsNumber(ctx context.Context) (int64, error) {
counts, err := mongoutil.Aggregate[int64](ctx, c.coll, []bson.M{
{"$group": bson.M{"_id": "$conversation_id"}},
{"$group": bson.M{"_id": nil, "count": bson.M{"$sum": 1}}},
{"$project": bson.M{"_id": 0}},
})
if err != nil {
return 0, err
}
if len(counts) == 0 {
return 0, nil
}
return counts[0], nil
}
func (c *ConversationMgo) PageConversationIDs(ctx context.Context, pagination pagination.Pagination) (conversationIDs []string, err error) {
return mongoutil.FindPageOnly[string](ctx, c.coll, bson.M{}, pagination, options.Find().SetProjection(bson.M{"conversation_id": 1}))
}
func (c *ConversationMgo) GetConversationsByConversationID(ctx context.Context, conversationIDs []string) ([]*model.Conversation, error) {
return mongoutil.Find[*model.Conversation](ctx, c.coll, bson.M{"conversation_id": bson.M{"$in": conversationIDs}})
}
func (c *ConversationMgo) GetConversationIDsNeedDestruct(ctx context.Context) ([]*model.Conversation, error) {
// "is_msg_destruct = 1 && msg_destruct_time != 0 && (UNIX_TIMESTAMP(NOW()) > (msg_destruct_time + UNIX_TIMESTAMP(latest_msg_destruct_time)) || latest_msg_destruct_time is NULL)"
return mongoutil.Find[*model.Conversation](ctx, c.coll, bson.M{
"is_msg_destruct": 1,
"msg_destruct_time": bson.M{"$ne": 0},
"$or": []bson.M{
{
"$expr": bson.M{
"$gt": []any{
time.Now(),
bson.M{"$add": []any{"$msg_destruct_time", "$latest_msg_destruct_time"}},
},
},
},
{
"latest_msg_destruct_time": nil,
},
},
})
}
func (c *ConversationMgo) GetConversationNotReceiveMessageUserIDs(ctx context.Context, conversationID string) ([]string, error) {
return mongoutil.Find[string](
ctx,
c.coll,
bson.M{"conversation_id": conversationID, "recv_msg_opt": bson.M{"$ne": constant.ReceiveMessage}},
options.Find().SetProjection(bson.M{"_id": 0, "owner_user_id": 1}),
)
}
func (c *ConversationMgo) FindConversationUserVersion(ctx context.Context, userID string, version uint, limit int) (*model.VersionLog, error) {
return c.version.FindChangeLog(ctx, userID, version, limit)
}
func (c *ConversationMgo) FindRandConversation(ctx context.Context, ts int64, limit int) ([]*model.Conversation, error) {
pipeline := []bson.M{
{
"$match": bson.M{
"is_msg_destruct": true,
"msg_destruct_time": bson.M{"$ne": 0},
},
},
{
"$addFields": bson.M{
"next_msg_destruct_timestamp": bson.M{
"$add": []any{
bson.M{
"$toLong": "$latest_msg_destruct_time",
},
bson.M{
"$multiply": []any{
"$msg_destruct_time",
1000, // convert to milliseconds
},
},
},
},
},
},
{
"$match": bson.M{
"next_msg_destruct_timestamp": bson.M{"$lt": ts},
},
},
{
"$sample": bson.M{
"size": limit,
},
},
}
return mongoutil.Aggregate[*model.Conversation](ctx, c.coll, pipeline)
}
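// Note (added explanation): the pipeline computes next_msg_destruct_timestamp as
// $toLong(latest_msg_destruct_time), i.e. milliseconds since epoch, plus
// msg_destruct_time * 1000 (seconds converted to milliseconds). For example, a
// conversation last destructed at 1,700,000,000,000 ms with msg_destruct_time of
// 604800 s (7 days) becomes due once ts exceeds 1,700,604,800,000 ms; $sample then
// returns at most `limit` of the due conversations at random.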
func (c *ConversationMgo) DeleteUsersConversations(ctx context.Context, userID string, conversationIDs []string) error {
if len(conversationIDs) == 0 {
return nil
}
filter := bson.M{
"owner_user_id": userID,
"conversation_id": bson.M{"$in": conversationIDs},
}
return mongoutil.IncrVersion(func() error {
return mongoutil.DeleteMany(ctx, c.coll, filter)
}, func() error {
return c.version.IncrVersion(ctx, userID, conversationIDs, model.VersionStateDelete)
})
}

View File

@@ -0,0 +1,15 @@
// Copyright © 2024 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mgo // import "git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"

View File

@@ -0,0 +1,271 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mgo
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"go.mongodb.org/mongo-driver/bson/primitive"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/db/pagination"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// FriendMgo implements Friend using MongoDB as the storage backend.
type FriendMgo struct {
coll *mongo.Collection
owner database.VersionLog
}
// NewFriendMongo creates a new instance of FriendMgo with the provided MongoDB database.
func NewFriendMongo(db *mongo.Database) (database.Friend, error) {
coll := db.Collection(database.FriendName)
_, err := coll.Indexes().CreateOne(context.Background(), mongo.IndexModel{
Keys: bson.D{
{Key: "owner_user_id", Value: 1},
{Key: "friend_user_id", Value: 1},
},
Options: options.Index().SetUnique(true),
})
if err != nil {
return nil, err
}
owner, err := NewVersionLog(db.Collection(database.FriendVersionName))
if err != nil {
return nil, err
}
return &FriendMgo{coll: coll, owner: owner}, nil
}
func (f *FriendMgo) friendSort() any {
return bson.D{{"is_pinned", -1}, {"_id", 1}}
}
// Create inserts multiple friend records.
func (f *FriendMgo) Create(ctx context.Context, friends []*model.Friend) error {
for i, friend := range friends {
if friend.ID.IsZero() {
friends[i].ID = primitive.NewObjectID()
}
if friend.CreateTime.IsZero() {
friends[i].CreateTime = time.Now()
}
}
return mongoutil.IncrVersion(func() error {
return mongoutil.InsertMany(ctx, f.coll, friends)
}, func() error {
mp := make(map[string][]string)
for _, friend := range friends {
mp[friend.OwnerUserID] = append(mp[friend.OwnerUserID], friend.FriendUserID)
}
for ownerUserID, friendUserIDs := range mp {
if err := f.owner.IncrVersion(ctx, ownerUserID, friendUserIDs, model.VersionStateInsert); err != nil {
return err
}
}
return nil
})
}
// Delete removes specified friends of the owner user.
func (f *FriendMgo) Delete(ctx context.Context, ownerUserID string, friendUserIDs []string) error {
filter := bson.M{
"owner_user_id": ownerUserID,
"friend_user_id": bson.M{"$in": friendUserIDs},
}
return mongoutil.IncrVersion(func() error {
return mongoutil.DeleteOne(ctx, f.coll, filter)
}, func() error {
return f.owner.IncrVersion(ctx, ownerUserID, friendUserIDs, model.VersionStateDelete)
})
}
// UpdateByMap updates specific fields of a friend document using a map.
func (f *FriendMgo) UpdateByMap(ctx context.Context, ownerUserID string, friendUserID string, args map[string]any) error {
if len(args) == 0 {
return nil
}
filter := bson.M{
"owner_user_id": ownerUserID,
"friend_user_id": friendUserID,
}
return mongoutil.IncrVersion(func() error {
return mongoutil.UpdateOne(ctx, f.coll, filter, bson.M{"$set": args}, true)
}, func() error {
var friendUserIDs []string
if f.IsUpdateIsPinned(args) {
friendUserIDs = []string{model.VersionSortChangeID, friendUserID}
} else {
friendUserIDs = []string{friendUserID}
}
return f.owner.IncrVersion(ctx, ownerUserID, friendUserIDs, model.VersionStateUpdate)
})
}
// UpdateRemark updates the remark for a specific friend.
func (f *FriendMgo) UpdateRemark(ctx context.Context, ownerUserID, friendUserID, remark string) error {
return f.UpdateByMap(ctx, ownerUserID, friendUserID, map[string]any{"remark": remark})
}
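// fillTime backfills CreateTime from the ObjectID's embedded timestamp for friend records that
// were stored without an explicit create time.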
func (f *FriendMgo) fillTime(friends ...*model.Friend) {
for i, friend := range friends {
if friend.CreateTime.IsZero() {
friends[i].CreateTime = friend.ID.Timestamp()
}
}
}
func (f *FriendMgo) findOne(ctx context.Context, filter any) (*model.Friend, error) {
friend, err := mongoutil.FindOne[*model.Friend](ctx, f.coll, filter)
if err != nil {
return nil, err
}
f.fillTime(friend)
return friend, nil
}
func (f *FriendMgo) find(ctx context.Context, filter any) ([]*model.Friend, error) {
friends, err := mongoutil.Find[*model.Friend](ctx, f.coll, filter)
if err != nil {
return nil, err
}
f.fillTime(friends...)
return friends, nil
}
func (f *FriendMgo) findPage(ctx context.Context, filter any, pagination pagination.Pagination, opts ...*options.FindOptions) (int64, []*model.Friend, error) {
return mongoutil.FindPage[*model.Friend](ctx, f.coll, filter, pagination, opts...)
}
// Take retrieves a single friend document. Returns an error if not found.
func (f *FriendMgo) Take(ctx context.Context, ownerUserID, friendUserID string) (*model.Friend, error) {
filter := bson.M{
"owner_user_id": ownerUserID,
"friend_user_id": friendUserID,
}
return f.findOne(ctx, filter)
}
// FindUserState finds the friendship status between two users.
func (f *FriendMgo) FindUserState(ctx context.Context, userID1, userID2 string) ([]*model.Friend, error) {
filter := bson.M{
"$or": []bson.M{
{"owner_user_id": userID1, "friend_user_id": userID2},
{"owner_user_id": userID2, "friend_user_id": userID1},
},
}
return f.find(ctx, filter)
}
// FindFriends retrieves a list of friends for a given owner. Missing friends do not cause an error.
func (f *FriendMgo) FindFriends(ctx context.Context, ownerUserID string, friendUserIDs []string) ([]*model.Friend, error) {
filter := bson.M{
"owner_user_id": ownerUserID,
"friend_user_id": bson.M{"$in": friendUserIDs},
}
return f.find(ctx, filter)
}
// FindReversalFriends finds users who have added the specified user as a friend.
func (f *FriendMgo) FindReversalFriends(ctx context.Context, friendUserID string, ownerUserIDs []string) ([]*model.Friend, error) {
filter := bson.M{
"owner_user_id": bson.M{"$in": ownerUserIDs},
"friend_user_id": friendUserID,
}
return f.find(ctx, filter)
}
// FindOwnerFriends retrieves a paginated list of friends for a given owner.
func (f *FriendMgo) FindOwnerFriends(ctx context.Context, ownerUserID string, pagination pagination.Pagination) (int64, []*model.Friend, error) {
filter := bson.M{"owner_user_id": ownerUserID}
opt := options.Find().SetSort(f.friendSort())
return f.findPage(ctx, filter, pagination, opt)
}
func (f *FriendMgo) FindOwnerFriendUserIds(ctx context.Context, ownerUserID string, limit int) ([]string, error) {
filter := bson.M{"owner_user_id": ownerUserID}
opt := options.Find().SetProjection(bson.M{"_id": 0, "friend_user_id": 1}).SetSort(f.friendSort()).SetLimit(int64(limit))
return mongoutil.Find[string](ctx, f.coll, filter, opt)
}
// FindInWhoseFriends finds users who have added the specified user as a friend, with pagination.
func (f *FriendMgo) FindInWhoseFriends(ctx context.Context, friendUserID string, pagination pagination.Pagination) (int64, []*model.Friend, error) {
filter := bson.M{"friend_user_id": friendUserID}
opt := options.Find().SetSort(f.friendSort())
return f.findPage(ctx, filter, pagination, opt)
}
// FindFriendUserIDs retrieves a list of friend user IDs for a given owner.
func (f *FriendMgo) FindFriendUserIDs(ctx context.Context, ownerUserID string) ([]string, error) {
filter := bson.M{"owner_user_id": ownerUserID}
return mongoutil.Find[string](ctx, f.coll, filter, options.Find().SetProjection(bson.M{"_id": 0, "friend_user_id": 1}).SetSort(f.friendSort()))
}
func (f *FriendMgo) UpdateFriends(ctx context.Context, ownerUserID string, friendUserIDs []string, val map[string]any) error {
// Nothing to do when there are no friend IDs or no fields to update.
if len(friendUserIDs) == 0 || len(val) == 0 {
return nil
}
// Create a filter to match documents with the specified ownerUserID and any of the friendUserIDs
filter := bson.M{
"owner_user_id": ownerUserID,
"friend_user_id": bson.M{"$in": friendUserIDs},
}
// Create an update document
update := bson.M{"$set": val}
return mongoutil.IncrVersion(func() error {
return mongoutil.Ignore(mongoutil.UpdateMany(ctx, f.coll, filter, update))
}, func() error {
var userIDs []string
if f.IsUpdateIsPinned(val) {
userIDs = append([]string{model.VersionSortChangeID}, friendUserIDs...)
} else {
userIDs = friendUserIDs
}
return f.owner.IncrVersion(ctx, ownerUserID, userIDs, model.VersionStateUpdate)
})
}
func (f *FriendMgo) FindIncrVersion(ctx context.Context, ownerUserID string, version uint, limit int) (*model.VersionLog, error) {
return f.owner.FindChangeLog(ctx, ownerUserID, version, limit)
}
func (f *FriendMgo) FindFriendUserID(ctx context.Context, friendUserID string) ([]string, error) {
filter := bson.M{
"friend_user_id": friendUserID,
}
return mongoutil.Find[string](ctx, f.coll, filter, options.Find().SetProjection(bson.M{"_id": 0, "owner_user_id": 1}).SetSort(f.friendSort()))
}
func (f *FriendMgo) IncrVersion(ctx context.Context, ownerUserID string, friendUserIDs []string, state int32) error {
return f.owner.IncrVersion(ctx, ownerUserID, friendUserIDs, state)
}
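// IsUpdateIsPinned reports whether the update map touches the is_pinned field. Pin changes also
// record a VersionSortChangeID entry in the version log, since pinning reorders the friend list.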
func (f *FriendMgo) IsUpdateIsPinned(data map[string]any) bool {
if data == nil {
return false
}
_, ok := data["is_pinned"]
return ok
}

View File

@@ -0,0 +1,143 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mgo
import (
"context"
"time"
"go.mongodb.org/mongo-driver/mongo/options"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/db/pagination"
)
func NewFriendRequestMongo(db *mongo.Database) (database.FriendRequest, error) {
coll := db.Collection(database.FriendRequestName)
_, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{
{
Keys: bson.D{
{Key: "from_user_id", Value: 1},
{Key: "to_user_id", Value: 1},
},
Options: options.Index().SetUnique(true),
},
{
Keys: bson.D{
{Key: "create_time", Value: -1},
},
},
})
if err != nil {
return nil, err
}
return &FriendRequestMgo{coll: coll}, nil
}
type FriendRequestMgo struct {
coll *mongo.Collection
}
func (f *FriendRequestMgo) sort() any {
return bson.D{{Key: "create_time", Value: -1}}
}
func (f *FriendRequestMgo) FindToUserID(ctx context.Context, toUserID string, handleResults []int, pagination pagination.Pagination) (total int64, friendRequests []*model.FriendRequest, err error) {
filter := bson.M{"to_user_id": toUserID}
if len(handleResults) > 0 {
filter["handle_result"] = bson.M{"$in": handleResults}
}
return mongoutil.FindPage[*model.FriendRequest](ctx, f.coll, filter, pagination, options.Find().SetSort(f.sort()))
}
func (f *FriendRequestMgo) FindFromUserID(ctx context.Context, fromUserID string, handleResults []int, pagination pagination.Pagination) (total int64, friendRequests []*model.FriendRequest, err error) {
filter := bson.M{"from_user_id": fromUserID}
if len(handleResults) > 0 {
filter["handle_result"] = bson.M{"$in": handleResults}
}
return mongoutil.FindPage[*model.FriendRequest](ctx, f.coll, filter, pagination, options.Find().SetSort(f.sort()))
}
func (f *FriendRequestMgo) FindBothFriendRequests(ctx context.Context, fromUserID, toUserID string) (friends []*model.FriendRequest, err error) {
filter := bson.M{"$or": []bson.M{
{"from_user_id": fromUserID, "to_user_id": toUserID},
{"from_user_id": toUserID, "to_user_id": fromUserID},
}}
return mongoutil.Find[*model.FriendRequest](ctx, f.coll, filter)
}
func (f *FriendRequestMgo) Create(ctx context.Context, friendRequests []*model.FriendRequest) error {
return mongoutil.InsertMany(ctx, f.coll, friendRequests)
}
func (f *FriendRequestMgo) Delete(ctx context.Context, fromUserID, toUserID string) (err error) {
return mongoutil.DeleteOne(ctx, f.coll, bson.M{"from_user_id": fromUserID, "to_user_id": toUserID})
}
func (f *FriendRequestMgo) UpdateByMap(ctx context.Context, fromUserID, toUserID string, args map[string]any) (err error) {
if len(args) == 0 {
return nil
}
return mongoutil.UpdateOne(ctx, f.coll, bson.M{"from_user_id": fromUserID, "to_user_id": toUserID}, bson.M{"$set": args}, true)
}
func (f *FriendRequestMgo) Update(ctx context.Context, friendRequest *model.FriendRequest) (err error) {
updater := bson.M{}
if friendRequest.HandleResult != 0 {
updater["handle_result"] = friendRequest.HandleResult
}
if friendRequest.ReqMsg != "" {
updater["req_msg"] = friendRequest.ReqMsg
}
if friendRequest.HandlerUserID != "" {
updater["handler_user_id"] = friendRequest.HandlerUserID
}
if friendRequest.HandleMsg != "" {
updater["handle_msg"] = friendRequest.HandleMsg
}
if !friendRequest.HandleTime.IsZero() {
updater["handle_time"] = friendRequest.HandleTime
}
if friendRequest.Ex != "" {
updater["ex"] = friendRequest.Ex
}
if len(updater) == 0 {
return nil
}
filter := bson.M{"from_user_id": friendRequest.FromUserID, "to_user_id": friendRequest.ToUserID}
return mongoutil.UpdateOne(ctx, f.coll, filter, bson.M{"$set": updater}, true)
}
func (f *FriendRequestMgo) Find(ctx context.Context, fromUserID, toUserID string) (friendRequest *model.FriendRequest, err error) {
return mongoutil.FindOne[*model.FriendRequest](ctx, f.coll, bson.M{"from_user_id": fromUserID, "to_user_id": toUserID})
}
func (f *FriendRequestMgo) Take(ctx context.Context, fromUserID, toUserID string) (friendRequest *model.FriendRequest, err error) {
return f.Find(ctx, fromUserID, toUserID)
}
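// GetUnhandledCount counts friend requests sent to userID that are still unhandled
// (handle_result == 0). If ts is non-zero, only requests created after ts (Unix milliseconds)
// are counted.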
func (f *FriendRequestMgo) GetUnhandledCount(ctx context.Context, userID string, ts int64) (int64, error) {
filter := bson.M{"to_user_id": userID, "handle_result": 0}
if ts != 0 {
filter["create_time"] = bson.M{"$gt": time.UnixMilli(ts)}
}
return mongoutil.Count(ctx, f.coll, filter)
}

View File

@@ -0,0 +1,162 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mgo
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/protocol/constant"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/db/pagination"
"github.com/openimsdk/tools/errs"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func NewGroupMongo(db *mongo.Database) (database.Group, error) {
coll := db.Collection(database.GroupName)
_, err := coll.Indexes().CreateOne(context.Background(), mongo.IndexModel{
Keys: bson.D{
{Key: "group_id", Value: 1},
},
Options: options.Index().SetUnique(true),
})
if err != nil {
return nil, errs.Wrap(err)
}
return &GroupMgo{coll: coll}, nil
}
type GroupMgo struct {
coll *mongo.Collection
}
func (g *GroupMgo) sortGroup() any {
return bson.D{{"group_name", 1}, {"create_time", 1}}
}
func (g *GroupMgo) Create(ctx context.Context, groups []*model.Group) (err error) {
return mongoutil.InsertMany(ctx, g.coll, groups)
}
func (g *GroupMgo) UpdateStatus(ctx context.Context, groupID string, status int32) (err error) {
return g.UpdateMap(ctx, groupID, map[string]any{"status": status})
}
func (g *GroupMgo) UpdateMap(ctx context.Context, groupID string, args map[string]any) (err error) {
if len(args) == 0 {
return nil
}
return mongoutil.UpdateOne(ctx, g.coll, bson.M{"group_id": groupID}, bson.M{"$set": args}, true)
}
func (g *GroupMgo) Find(ctx context.Context, groupIDs []string) (groups []*model.Group, err error) {
return mongoutil.Find[*model.Group](ctx, g.coll, bson.M{"group_id": bson.M{"$in": groupIDs}})
}
func (g *GroupMgo) Take(ctx context.Context, groupID string) (group *model.Group, err error) {
return mongoutil.FindOne[*model.Group](ctx, g.coll, bson.M{"group_id": groupID})
}
func (g *GroupMgo) Search(ctx context.Context, keyword string, pagination pagination.Pagination) (total int64, groups []*model.Group, err error) {
// Define the sorting options
opts := options.Find().SetSort(bson.D{{Key: "create_time", Value: -1}})
// Perform the search with pagination and sorting
return mongoutil.FindPage[*model.Group](ctx, g.coll, bson.M{
"group_name": bson.M{"$regex": keyword},
"status": bson.M{"$ne": constant.GroupStatusDismissed},
}, pagination, opts)
}
func (g *GroupMgo) CountTotal(ctx context.Context, before *time.Time) (count int64, err error) {
if before == nil {
return mongoutil.Count(ctx, g.coll, bson.M{})
}
return mongoutil.Count(ctx, g.coll, bson.M{"create_time": bson.M{"$lt": before}})
}
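// CountRangeEverydayTotal aggregates the number of groups created on each day in [start, end),
// keyed by the "%Y-%m-%d" date string produced by $dateToString.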
func (g *GroupMgo) CountRangeEverydayTotal(ctx context.Context, start time.Time, end time.Time) (map[string]int64, error) {
pipeline := bson.A{
bson.M{
"$match": bson.M{
"create_time": bson.M{
"$gte": start,
"$lt": end,
},
},
},
bson.M{
"$group": bson.M{
"_id": bson.M{
"$dateToString": bson.M{
"format": "%Y-%m-%d",
"date": "$create_time",
},
},
"count": bson.M{
"$sum": 1,
},
},
},
}
type Item struct {
Date string `bson:"_id"`
Count int64 `bson:"count"`
}
items, err := mongoutil.Aggregate[Item](ctx, g.coll, pipeline)
if err != nil {
return nil, err
}
res := make(map[string]int64, len(items))
for _, item := range items {
res[item.Date] = item.Count
}
return res, nil
}
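// FindJoinSortGroupID filters out dismissed groups and returns the remaining group IDs in the
// default group order (group_name, then create_time).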
func (g *GroupMgo) FindJoinSortGroupID(ctx context.Context, groupIDs []string) ([]string, error) {
if len(groupIDs) < 2 {
return groupIDs, nil
}
filter := bson.M{
"group_id": bson.M{"$in": groupIDs},
"status": bson.M{"$ne": constant.GroupStatusDismissed},
}
opt := options.Find().SetSort(g.sortGroup()).SetProjection(bson.M{"_id": 0, "group_id": 1})
return mongoutil.Find[string](ctx, g.coll, filter, opt)
}
func (g *GroupMgo) SearchJoin(ctx context.Context, groupIDs []string, keyword string, pagination pagination.Pagination) (int64, []*model.Group, error) {
if len(groupIDs) == 0 {
return 0, nil, nil
}
filter := bson.M{
"group_id": bson.M{"$in": groupIDs},
"status": bson.M{"$ne": constant.GroupStatusDismissed},
}
if keyword != "" {
filter["group_name"] = bson.M{"$regex": keyword}
}
// Define the sorting options
opts := options.Find().SetSort(g.sortGroup())
// Perform the search with pagination and sorting
return mongoutil.FindPage[*model.Group](ctx, g.coll, filter, pagination, opts)
}

View File

@@ -0,0 +1,282 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mgo
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/log"
"git.imall.cloud/openim/protocol/constant"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/db/pagination"
"github.com/openimsdk/tools/errs"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func NewGroupMember(db *mongo.Database) (database.GroupMember, error) {
coll := db.Collection(database.GroupMemberName)
_, err := coll.Indexes().CreateOne(context.Background(), mongo.IndexModel{
Keys: bson.D{
{Key: "group_id", Value: 1},
{Key: "user_id", Value: 1},
},
Options: options.Index().SetUnique(true),
})
if err != nil {
return nil, errs.Wrap(err)
}
member, err := NewVersionLog(db.Collection(database.GroupMemberVersionName))
if err != nil {
return nil, err
}
join, err := NewVersionLog(db.Collection(database.GroupJoinVersionName))
if err != nil {
return nil, err
}
return &GroupMemberMgo{coll: coll, member: member, join: join}, nil
}
type GroupMemberMgo struct {
coll *mongo.Collection
member database.VersionLog
join database.VersionLog
}
func (g *GroupMemberMgo) memberSort() any {
return bson.D{{Key: "role_level", Value: -1}, {Key: "create_time", Value: 1}}
}
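// Create inserts the group members and, within the same IncrVersion call, bumps two version logs:
// the per-group member version (grouped by GroupID) and the per-user joined-group version
// (grouped by UserID), so member-list and joined-group incremental sync both stay consistent.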
func (g *GroupMemberMgo) Create(ctx context.Context, groupMembers []*model.GroupMember) (err error) {
return mongoutil.IncrVersion(func() error {
return mongoutil.InsertMany(ctx, g.coll, groupMembers)
}, func() error {
gms := make(map[string][]string)
for _, member := range groupMembers {
gms[member.GroupID] = append(gms[member.GroupID], member.UserID)
}
for groupID, userIDs := range gms {
if err := g.member.IncrVersion(ctx, groupID, userIDs, model.VersionStateInsert); err != nil {
return err
}
}
return nil
}, func() error {
gms := make(map[string][]string)
for _, member := range groupMembers {
gms[member.UserID] = append(gms[member.UserID], member.GroupID)
}
for userID, groupIDs := range gms {
if err := g.join.IncrVersion(ctx, userID, groupIDs, model.VersionStateInsert); err != nil {
return err
}
}
return nil
})
}
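// Delete removes members from a group. With an empty userIDs slice the whole group's membership
// is deleted and the group's member version log is dropped; otherwise only the listed users are
// removed, with delete entries recorded in the member version log and each user's joined-group
// version log.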
func (g *GroupMemberMgo) Delete(ctx context.Context, groupID string, userIDs []string) (err error) {
filter := bson.M{"group_id": groupID}
if len(userIDs) > 0 {
filter["user_id"] = bson.M{"$in": userIDs}
}
return mongoutil.IncrVersion(func() error {
return mongoutil.DeleteMany(ctx, g.coll, filter)
}, func() error {
if len(userIDs) == 0 {
return g.member.Delete(ctx, groupID)
} else {
return g.member.IncrVersion(ctx, groupID, userIDs, model.VersionStateDelete)
}
}, func() error {
for _, userID := range userIDs {
if err := g.join.IncrVersion(ctx, userID, []string{groupID}, model.VersionStateDelete); err != nil {
return err
}
}
return nil
})
}
func (g *GroupMemberMgo) UpdateRoleLevel(ctx context.Context, groupID string, userID string, roleLevel int32) error {
return mongoutil.IncrVersion(func() error {
return mongoutil.UpdateOne(ctx, g.coll, bson.M{"group_id": groupID, "user_id": userID},
bson.M{"$set": bson.M{"role_level": roleLevel}}, true)
}, func() error {
return g.member.IncrVersion(ctx, groupID, []string{model.VersionSortChangeID, userID}, model.VersionStateUpdate)
})
}
func (g *GroupMemberMgo) UpdateUserRoleLevels(ctx context.Context, groupID string, firstUserID string, firstUserRoleLevel int32, secondUserID string, secondUserRoleLevel int32) error {
return mongoutil.IncrVersion(func() error {
if err := mongoutil.UpdateOne(ctx, g.coll, bson.M{"group_id": groupID, "user_id": firstUserID},
bson.M{"$set": bson.M{"role_level": firstUserRoleLevel}}, true); err != nil {
return err
}
if err := mongoutil.UpdateOne(ctx, g.coll, bson.M{"group_id": groupID, "user_id": secondUserID},
bson.M{"$set": bson.M{"role_level": secondUserRoleLevel}}, true); err != nil {
return err
}
return nil
}, func() error {
return g.member.IncrVersion(ctx, groupID, []string{model.VersionSortChangeID, firstUserID, secondUserID}, model.VersionStateUpdate)
})
}
func (g *GroupMemberMgo) Update(ctx context.Context, groupID string, userID string, data map[string]any) (err error) {
if len(data) == 0 {
return nil
}
return mongoutil.IncrVersion(func() error {
return mongoutil.UpdateOne(ctx, g.coll, bson.M{"group_id": groupID, "user_id": userID}, bson.M{"$set": data}, true)
}, func() error {
var userIDs []string
if g.IsUpdateRoleLevel(data) {
userIDs = []string{model.VersionSortChangeID, userID}
} else {
userIDs = []string{userID}
}
return g.member.IncrVersion(ctx, groupID, userIDs, model.VersionStateUpdate)
})
}
func (g *GroupMemberMgo) FindMemberUserID(ctx context.Context, groupID string) (userIDs []string, err error) {
return mongoutil.Find[string](ctx, g.coll, bson.M{"group_id": groupID}, options.Find().SetProjection(bson.M{"_id": 0, "user_id": 1}).SetSort(g.memberSort()))
}
func (g *GroupMemberMgo) Find(ctx context.Context, groupID string, userIDs []string) ([]*model.GroupMember, error) {
filter := bson.M{"group_id": groupID}
if len(userIDs) > 0 {
filter["user_id"] = bson.M{"$in": userIDs}
}
return mongoutil.Find[*model.GroupMember](ctx, g.coll, filter)
}
func (g *GroupMemberMgo) FindInGroup(ctx context.Context, userID string, groupIDs []string) ([]*model.GroupMember, error) {
filter := bson.M{"user_id": userID}
if len(groupIDs) > 0 {
filter["group_id"] = bson.M{"$in": groupIDs}
}
return mongoutil.Find[*model.GroupMember](ctx, g.coll, filter)
}
func (g *GroupMemberMgo) Take(ctx context.Context, groupID string, userID string) (groupMember *model.GroupMember, err error) {
return mongoutil.FindOne[*model.GroupMember](ctx, g.coll, bson.M{"group_id": groupID, "user_id": userID})
}
func (g *GroupMemberMgo) TakeOwner(ctx context.Context, groupID string) (groupMember *model.GroupMember, err error) {
return mongoutil.FindOne[*model.GroupMember](ctx, g.coll, bson.M{"group_id": groupID, "role_level": constant.GroupOwner})
}
func (g *GroupMemberMgo) FindRoleLevelUserIDs(ctx context.Context, groupID string, roleLevel int32) ([]string, error) {
return mongoutil.Find[string](ctx, g.coll, bson.M{"group_id": groupID, "role_level": roleLevel}, options.Find().SetProjection(bson.M{"_id": 0, "user_id": 1}))
}
func (g *GroupMemberMgo) SearchMember(ctx context.Context, keyword string, groupID string, pagination pagination.Pagination) (int64, []*model.GroupMember, error) {
// Supports searching by nickname or by user_id (account).
// Uses an $or condition so the keyword can match either field.
filter := bson.M{
"group_id": groupID,
"$or": []bson.M{
{"nickname": bson.M{"$regex": keyword, "$options": "i"}}, // 昵称模糊匹配,不区分大小写
{"user_id": bson.M{"$regex": keyword, "$options": "i"}}, // user_id账号模糊匹配不区分大小写
},
}
return mongoutil.FindPage[*model.GroupMember](ctx, g.coll, filter, pagination, options.Find().SetSort(g.memberSort()))
}
// SearchMemberByFields searches group members by several independent fields: nickname, account (userID), and phone number.
// nickname: the member's nickname (the in-group nickname)
// userID: the member's account (user_id)
// phone: the phone number (if the group member collection has such a field, or if it is stored in the Ex field)
func (g *GroupMemberMgo) SearchMemberByFields(ctx context.Context, groupID string, nickname, userID, phone string, pagination pagination.Pagination) (int64, []*model.GroupMember, error) {
filter := bson.M{"group_id": groupID}
// Build the individual search conditions; $and ensures every provided condition must match.
conditions := []bson.M{}
if nickname != "" {
conditions = append(conditions, bson.M{"nickname": bson.M{"$regex": nickname, "$options": "i"}})
}
if userID != "" {
conditions = append(conditions, bson.M{"user_id": bson.M{"$regex": userID, "$options": "i"}})
}
if phone != "" {
// The phone number may be stored in the Ex field, so match it with a regular expression.
// If Ex holds JSON, a more sophisticated query may be required.
conditions = append(conditions, bson.M{"ex": bson.M{"$regex": phone, "$options": "i"}})
}
// If any search conditions were provided, add them to the filter.
if len(conditions) > 0 {
filter["$and"] = conditions
}
return mongoutil.FindPage[*model.GroupMember](ctx, g.coll, filter, pagination, options.Find().SetSort(g.memberSort()))
}
func (g *GroupMemberMgo) FindUserJoinedGroupID(ctx context.Context, userID string) (groupIDs []string, err error) {
return mongoutil.Find[string](ctx, g.coll, bson.M{"user_id": userID}, options.Find().SetProjection(bson.M{"_id": 0, "group_id": 1}).SetSort(g.memberSort()))
}
func (g *GroupMemberMgo) TakeGroupMemberNum(ctx context.Context, groupID string) (count int64, err error) {
return mongoutil.Count(ctx, g.coll, bson.M{"group_id": groupID})
}
func (g *GroupMemberMgo) FindUserManagedGroupID(ctx context.Context, userID string) (groupIDs []string, err error) {
filter := bson.M{
"user_id": userID,
"role_level": bson.M{
"$in": []int{constant.GroupOwner, constant.GroupAdmin},
},
}
return mongoutil.Find[string](ctx, g.coll, filter, options.Find().SetProjection(bson.M{"_id": 0, "group_id": 1}))
}
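// IsUpdateRoleLevel reports whether the update map touches role_level. Role changes also record
// a VersionSortChangeID entry, since the member list is sorted by role_level.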
func (g *GroupMemberMgo) IsUpdateRoleLevel(data map[string]any) bool {
if len(data) == 0 {
return false
}
_, ok := data["role_level"]
return ok
}
func (g *GroupMemberMgo) JoinGroupIncrVersion(ctx context.Context, userID string, groupIDs []string, state int32) error {
return g.join.IncrVersion(ctx, userID, groupIDs, state)
}
func (g *GroupMemberMgo) MemberGroupIncrVersion(ctx context.Context, groupID string, userIDs []string, state int32) error {
return g.member.IncrVersion(ctx, groupID, userIDs, state)
}
func (g *GroupMemberMgo) FindMemberIncrVersion(ctx context.Context, groupID string, version uint, limit int) (*model.VersionLog, error) {
log.ZDebug(ctx, "find member incr version", "groupID", groupID, "version", version)
return g.member.FindChangeLog(ctx, groupID, version, limit)
}
func (g *GroupMemberMgo) BatchFindMemberIncrVersion(ctx context.Context, groupIDs []string, versions []uint, limits []int) ([]*model.VersionLog, error) {
log.ZDebug(ctx, "Batch find member incr version", "groupIDs", groupIDs, "versions", versions)
return g.member.BatchFindChangeLog(ctx, groupIDs, versions, limits)
}
func (g *GroupMemberMgo) FindJoinIncrVersion(ctx context.Context, userID string, version uint, limit int) (*model.VersionLog, error) {
log.ZDebug(ctx, "find join incr version", "userID", userID, "version", version)
return g.join.FindChangeLog(ctx, userID, version, limit)
}

View File

@@ -0,0 +1,115 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mgo
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/utils/datautil"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/db/pagination"
"github.com/openimsdk/tools/errs"
)
func NewGroupRequestMgo(db *mongo.Database) (database.GroupRequest, error) {
coll := db.Collection(database.GroupRequestName)
_, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{
{
Keys: bson.D{
{Key: "group_id", Value: 1},
{Key: "user_id", Value: 1},
},
Options: options.Index().SetUnique(true),
},
{
Keys: bson.D{
{Key: "req_time", Value: -1},
},
},
})
if err != nil {
return nil, errs.Wrap(err)
}
return &GroupRequestMgo{coll: coll}, nil
}
type GroupRequestMgo struct {
coll *mongo.Collection
}
func (g *GroupRequestMgo) Create(ctx context.Context, groupRequests []*model.GroupRequest) (err error) {
return mongoutil.InsertMany(ctx, g.coll, groupRequests)
}
func (g *GroupRequestMgo) Delete(ctx context.Context, groupID string, userID string) (err error) {
return mongoutil.DeleteOne(ctx, g.coll, bson.M{"group_id": groupID, "user_id": userID})
}
func (g *GroupRequestMgo) UpdateHandler(ctx context.Context, groupID string, userID string, handledMsg string, handleResult int32) (err error) {
return mongoutil.UpdateOne(ctx, g.coll, bson.M{"group_id": groupID, "user_id": userID}, bson.M{"$set": bson.M{"handle_msg": handledMsg, "handle_result": handleResult}}, true)
}
func (g *GroupRequestMgo) Take(ctx context.Context, groupID string, userID string) (groupRequest *model.GroupRequest, err error) {
return mongoutil.FindOne[*model.GroupRequest](ctx, g.coll, bson.M{"group_id": groupID, "user_id": userID})
}
func (g *GroupRequestMgo) FindGroupRequests(ctx context.Context, groupID string, userIDs []string) ([]*model.GroupRequest, error) {
return mongoutil.Find[*model.GroupRequest](ctx, g.coll, bson.M{"group_id": groupID, "user_id": bson.M{"$in": userIDs}})
}
func (g *GroupRequestMgo) sort() any {
return bson.D{{Key: "req_time", Value: -1}}
}
func (g *GroupRequestMgo) Page(ctx context.Context, userID string, groupIDs []string, handleResults []int, pagination pagination.Pagination) (total int64, groups []*model.GroupRequest, err error) {
filter := bson.M{"user_id": userID}
if len(groupIDs) > 0 {
filter["group_id"] = bson.M{"$in": datautil.Distinct(groupIDs)}
}
if len(handleResults) > 0 {
filter["handle_result"] = bson.M{"$in": handleResults}
}
return mongoutil.FindPage[*model.GroupRequest](ctx, g.coll, filter, pagination, options.Find().SetSort(g.sort()))
}
func (g *GroupRequestMgo) PageGroup(ctx context.Context, groupIDs []string, handleResults []int, pagination pagination.Pagination) (total int64, groups []*model.GroupRequest, err error) {
if len(groupIDs) == 0 {
return 0, nil, nil
}
filter := bson.M{"group_id": bson.M{"$in": groupIDs}}
if len(handleResults) > 0 {
filter["handle_result"] = bson.M{"$in": handleResults}
}
return mongoutil.FindPage[*model.GroupRequest](ctx, g.coll, filter, pagination, options.Find().SetSort(g.sort()))
}
func (g *GroupRequestMgo) GetUnhandledCount(ctx context.Context, groupIDs []string, ts int64) (int64, error) {
if len(groupIDs) == 0 {
return 0, nil
}
filter := bson.M{"group_id": bson.M{"$in": groupIDs}, "handle_result": 0}
if ts != 0 {
filter["req_time"] = bson.M{"$gt": time.UnixMilli(ts)}
}
return mongoutil.Count(ctx, g.coll, filter)
}

View File

@@ -0,0 +1,24 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mgo
import (
"github.com/openimsdk/tools/errs"
"go.mongodb.org/mongo-driver/mongo"
)
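// IsNotFound reports whether err wraps mongo.ErrNoDocuments. A typical (hypothetical) call site:
//
//   group, err := groupDB.Take(ctx, groupID)
//   if IsNotFound(err) {
//       // treat the missing document as "not found" rather than a hard failure
//   }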
func IsNotFound(err error) bool {
return errs.Unwrap(err) == mongo.ErrNoDocuments
}

View File

@@ -0,0 +1,85 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mgo
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/db/pagination"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func NewLogMongo(db *mongo.Database) (database.Log, error) {
coll := db.Collection(database.LogName)
_, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{
{
Keys: bson.D{
{Key: "log_id", Value: 1},
},
Options: options.Index().SetUnique(true),
},
{
Keys: bson.D{
{Key: "user_id", Value: 1},
},
},
{
Keys: bson.D{
{Key: "create_time", Value: -1},
},
},
})
if err != nil {
return nil, err
}
return &LogMgo{coll: coll}, nil
}
type LogMgo struct {
coll *mongo.Collection
}
func (l *LogMgo) Create(ctx context.Context, log []*model.Log) error {
return mongoutil.InsertMany(ctx, l.coll, log)
}
func (l *LogMgo) Search(ctx context.Context, keyword string, start time.Time, end time.Time, pagination pagination.Pagination) (int64, []*model.Log, error) {
filter := bson.M{"create_time": bson.M{"$gte": start, "$lte": end}}
if keyword != "" {
filter["user_id"] = bson.M{"$regex": keyword}
}
return mongoutil.FindPage[*model.Log](ctx, l.coll, filter, pagination, options.Find().SetSort(bson.M{"create_time": -1}))
}
func (l *LogMgo) Delete(ctx context.Context, logID []string, userID string) error {
if userID == "" {
return mongoutil.DeleteMany(ctx, l.coll, bson.M{"log_id": bson.M{"$in": logID}})
}
return mongoutil.DeleteMany(ctx, l.coll, bson.M{"log_id": bson.M{"$in": logID}, "user_id": userID})
}
func (l *LogMgo) Get(ctx context.Context, logIDs []string, userID string) ([]*model.Log, error) {
if userID == "" {
return mongoutil.Find[*model.Log](ctx, l.coll, bson.M{"log_id": bson.M{"$in": logIDs}})
}
return mongoutil.Find[*model.Log](ctx, l.coll, bson.M{"log_id": bson.M{"$in": logIDs}, "user_id": userID})
}

View File

@@ -0,0 +1,183 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mgo
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/db/pagination"
"github.com/openimsdk/tools/errs"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// MeetingMgo implements Meeting using MongoDB as the storage backend.
type MeetingMgo struct {
coll *mongo.Collection
}
// NewMeetingMongo creates a new instance of MeetingMgo with the provided MongoDB database.
func NewMeetingMongo(db *mongo.Database) (database.Meeting, error) {
coll := db.Collection(database.MeetingName)
_, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{
{
Keys: bson.D{{Key: "meeting_id", Value: 1}},
Options: options.Index().SetUnique(true),
},
{
Keys: bson.D{{Key: "creator_user_id", Value: 1}},
},
{
Keys: bson.D{{Key: "status", Value: 1}},
},
{
Keys: bson.D{{Key: "scheduled_time", Value: 1}},
},
{
Keys: bson.D{{Key: "create_time", Value: -1}},
},
{
Keys: bson.D{{Key: "update_time", Value: -1}},
},
{
Keys: bson.D{{Key: "subject", Value: "text"}, {Key: "description", Value: "text"}},
},
})
if err != nil {
return nil, errs.Wrap(err)
}
return &MeetingMgo{coll: coll}, nil
}
// Create creates a new meeting record.
func (m *MeetingMgo) Create(ctx context.Context, meeting *model.Meeting) error {
if meeting.CreateTime.IsZero() {
meeting.CreateTime = time.Now()
}
if meeting.UpdateTime.IsZero() {
meeting.UpdateTime = time.Now()
}
return mongoutil.InsertOne(ctx, m.coll, meeting)
}
// Take retrieves a meeting by meeting ID. Returns an error if not found.
func (m *MeetingMgo) Take(ctx context.Context, meetingID string) (*model.Meeting, error) {
return mongoutil.FindOne[*model.Meeting](ctx, m.coll, bson.M{"meeting_id": meetingID})
}
// Update updates meeting information.
func (m *MeetingMgo) Update(ctx context.Context, meetingID string, data map[string]any) error {
data["update_time"] = time.Now()
update := bson.M{"$set": data}
return mongoutil.UpdateOne(ctx, m.coll, bson.M{"meeting_id": meetingID}, update, false)
}
// UpdateStatus updates the status of a meeting.
func (m *MeetingMgo) UpdateStatus(ctx context.Context, meetingID string, status int32) error {
return m.Update(ctx, meetingID, map[string]any{"status": status})
}
// Find finds meetings by meeting IDs.
func (m *MeetingMgo) Find(ctx context.Context, meetingIDs []string) ([]*model.Meeting, error) {
if len(meetingIDs) == 0 {
return []*model.Meeting{}, nil
}
filter := bson.M{"meeting_id": bson.M{"$in": meetingIDs}}
return mongoutil.Find[*model.Meeting](ctx, m.coll, filter)
}
// FindByCreator finds meetings created by a specific user.
func (m *MeetingMgo) FindByCreator(ctx context.Context, creatorUserID string, pagination pagination.Pagination) (total int64, meetings []*model.Meeting, err error) {
filter := bson.M{"creator_user_id": creatorUserID}
return mongoutil.FindPage[*model.Meeting](ctx, m.coll, filter, pagination, &options.FindOptions{
Sort: bson.D{{Key: "scheduled_time", Value: -1}},
})
}
// FindAll finds all meetings with pagination.
func (m *MeetingMgo) FindAll(ctx context.Context, pagination pagination.Pagination) (total int64, meetings []*model.Meeting, err error) {
return mongoutil.FindPage[*model.Meeting](ctx, m.coll, bson.M{}, pagination, &options.FindOptions{
Sort: bson.D{{Key: "scheduled_time", Value: -1}},
})
}
// Search searches meetings by keyword (subject, description).
func (m *MeetingMgo) Search(ctx context.Context, keyword string, pagination pagination.Pagination) (total int64, meetings []*model.Meeting, err error) {
filter := bson.M{
"$or": []bson.M{
{"subject": bson.M{"$regex": keyword, "$options": "i"}},
{"description": bson.M{"$regex": keyword, "$options": "i"}},
},
}
return mongoutil.FindPage[*model.Meeting](ctx, m.coll, filter, pagination, &options.FindOptions{
Sort: bson.D{{Key: "scheduled_time", Value: -1}},
})
}
// FindByStatus finds meetings by status.
func (m *MeetingMgo) FindByStatus(ctx context.Context, status int32, pagination pagination.Pagination) (total int64, meetings []*model.Meeting, err error) {
filter := bson.M{"status": status}
return mongoutil.FindPage[*model.Meeting](ctx, m.coll, filter, pagination, &options.FindOptions{
Sort: bson.D{{Key: "scheduled_time", Value: -1}},
})
}
// FindByScheduledTimeRange finds meetings within a scheduled time range.
func (m *MeetingMgo) FindByScheduledTimeRange(ctx context.Context, startTime, endTime int64, pagination pagination.Pagination) (total int64, meetings []*model.Meeting, err error) {
filter := bson.M{
"scheduled_time": bson.M{
"$gte": time.UnixMilli(startTime),
"$lte": time.UnixMilli(endTime),
},
}
return mongoutil.FindPage[*model.Meeting](ctx, m.coll, filter, pagination, &options.FindOptions{
Sort: bson.D{{Key: "scheduled_time", Value: 1}},
})
}
// FindFinishedMeetingsBefore finds finished meetings that ended before the specified time.
// A meeting is considered finished if its status is 3 (Finished) and its end time (scheduledTime + duration) is before beforeTime.
// Only returns meetings with a non-empty group_id to avoid processing meetings that have already been handled.
func (m *MeetingMgo) FindFinishedMeetingsBefore(ctx context.Context, beforeTime time.Time) ([]*model.Meeting, error) {
// Query meetings whose status is 3 (finished) and whose group_id is not empty.
// End time = scheduledTime + duration (in minutes).
// A meeting qualifies when scheduledTime + duration*60 seconds <= beforeTime.
filter := bson.M{
"status": 3, // 已结束
"group_id": bson.M{"$ne": ""}, // 只查询group_id不为空的会议避免重复处理已清空groupID的会议
"$expr": bson.M{
"$lte": []interface{}{
bson.M{
"$add": []interface{}{
"$scheduled_time",
bson.M{"$multiply": []interface{}{"$duration", int64(60)}}, // duration是分钟转换为秒
},
},
beforeTime,
},
},
}
return mongoutil.Find[*model.Meeting](ctx, m.coll, filter)
}
// Delete deletes a meeting by meeting ID.
func (m *MeetingMgo) Delete(ctx context.Context, meetingID string) error {
return mongoutil.DeleteOne(ctx, m.coll, bson.M{"meeting_id": meetingID})
}

View File

@@ -0,0 +1,110 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mgo
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/db/pagination"
"github.com/openimsdk/tools/errs"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// MeetingCheckInMgo implements MeetingCheckIn using MongoDB as the storage backend.
type MeetingCheckInMgo struct {
coll *mongo.Collection
}
// NewMeetingCheckInMongo creates a new instance of MeetingCheckInMgo with the provided MongoDB database.
func NewMeetingCheckInMongo(db *mongo.Database) (database.MeetingCheckIn, error) {
coll := db.Collection(database.MeetingCheckInName)
_, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{
{
Keys: bson.D{{Key: "check_in_id", Value: 1}},
Options: options.Index().SetUnique(true),
},
{
Keys: bson.D{{Key: "meeting_id", Value: 1}, {Key: "user_id", Value: 1}},
Options: options.Index().SetUnique(true), // a user may check in to a given meeting only once
},
{
Keys: bson.D{{Key: "meeting_id", Value: 1}, {Key: "check_in_time", Value: -1}},
},
{
Keys: bson.D{{Key: "user_id", Value: 1}, {Key: "check_in_time", Value: -1}},
},
})
if err != nil {
return nil, errs.Wrap(err)
}
return &MeetingCheckInMgo{coll: coll}, nil
}
// Create creates a new meeting check-in record.
func (m *MeetingCheckInMgo) Create(ctx context.Context, checkIn *model.MeetingCheckIn) error {
if checkIn.CreateTime.IsZero() {
checkIn.CreateTime = time.Now()
}
if checkIn.CheckInTime.IsZero() {
checkIn.CheckInTime = time.Now()
}
return mongoutil.InsertOne(ctx, m.coll, checkIn)
}
// Take retrieves a check-in by check-in ID. Returns an error if not found.
func (m *MeetingCheckInMgo) Take(ctx context.Context, checkInID string) (*model.MeetingCheckIn, error) {
return mongoutil.FindOne[*model.MeetingCheckIn](ctx, m.coll, bson.M{"check_in_id": checkInID})
}
// FindByMeetingID finds all check-ins for a meeting with pagination.
func (m *MeetingCheckInMgo) FindByMeetingID(ctx context.Context, meetingID string, pagination pagination.Pagination) (total int64, checkIns []*model.MeetingCheckIn, err error) {
filter := bson.M{"meeting_id": meetingID}
return mongoutil.FindPage[*model.MeetingCheckIn](ctx, m.coll, filter, pagination, &options.FindOptions{
Sort: bson.D{{Key: "check_in_time", Value: -1}},
})
}
// FindByUserAndMeetingID finds if a user has checked in for a specific meeting.
func (m *MeetingCheckInMgo) FindByUserAndMeetingID(ctx context.Context, userID, meetingID string) (*model.MeetingCheckIn, error) {
return mongoutil.FindOne[*model.MeetingCheckIn](ctx, m.coll, bson.M{
"user_id": userID,
"meeting_id": meetingID,
})
}
// CountByMeetingID counts the number of check-ins for a meeting.
func (m *MeetingCheckInMgo) CountByMeetingID(ctx context.Context, meetingID string) (int64, error) {
return mongoutil.Count(ctx, m.coll, bson.M{"meeting_id": meetingID})
}
// FindByUser finds all check-ins by a user with pagination.
func (m *MeetingCheckInMgo) FindByUser(ctx context.Context, userID string, pagination pagination.Pagination) (total int64, checkIns []*model.MeetingCheckIn, err error) {
filter := bson.M{"user_id": userID}
return mongoutil.FindPage[*model.MeetingCheckIn](ctx, m.coll, filter, pagination, &options.FindOptions{
Sort: bson.D{{Key: "check_in_time", Value: -1}},
})
}
// Delete deletes a check-in by check-in ID.
func (m *MeetingCheckInMgo) Delete(ctx context.Context, checkInID string) error {
return mongoutil.DeleteOne(ctx, m.coll, bson.M{"check_in_id": checkInID})
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,178 @@
package mgo
import (
"context"
"math"
"math/rand"
"strconv"
"testing"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/protocol/msg"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/db/mongoutil"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func TestName1(t *testing.T) {
//ctx, cancel := context.WithTimeout(context.Background(), time.Second*300)
//defer cancel()
//cli := Result(mongo.Connect(ctx, options.Client().ApplyURI("mongodb://openIM:openIM123@172.16.8.66:37017/openim_v3?maxPoolSize=100").SetConnectTimeout(5*time.Second)))
//
//v := &MsgMgo{
// coll: cli.Database("openim_v3").Collection("msg3"),
//}
//
//req := &msg.SearchMessageReq{
// //RecvID: "3187706596",
// //SendID: "7009965934",
// ContentType: 101,
// //SendTime: "2024-05-06",
// //SessionType: 3,
// Pagination: &sdkws.RequestPagination{
// PageNumber: 1,
// ShowNumber: 10,
// },
//}
//total, res, err := v.SearchMessage(ctx, req)
//if err != nil {
// panic(err)
//}
//
//for i, re := range res {
// t.Logf("%d => %d | %+v", i+1, re.Msg.Seq, re.Msg.Content)
//}
//
//t.Log(total)
//
//msg, err := NewMsgMongo(cli.Database("openim_v3"))
//if err != nil {
// panic(err)
//}
//res, err := msg.GetBeforeMsg(ctx, time.Now().UnixMilli(), []string{"1:0"}, 1000)
//if err != nil {
// panic(err)
//}
//t.Log(len(res))
}
func TestName10(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
cli := Result(mongo.Connect(ctx, options.Client().ApplyURI("mongodb://openIM:openIM123@172.16.8.48:37017/openim_v3?maxPoolSize=100").SetConnectTimeout(5*time.Second)))
v := &MsgMgo{
coll: cli.Database("openim_v3").Collection("msg3"),
}
opt := options.Find().SetLimit(1000)
res, err := mongoutil.Find[model.MsgDocModel](ctx, v.coll, bson.M{}, opt)
if err != nil {
panic(err)
}
ctx = context.Background()
for i := 0; i < 100000; i++ {
for j := range res {
res[j].DocID = strconv.FormatUint(rand.Uint64(), 10) + ":0"
}
if err := mongoutil.InsertMany(ctx, v.coll, res); err != nil {
panic(err)
}
t.Log("====>", time.Now(), i)
}
}
func TestName3(t *testing.T) {
t.Log(uint64(math.MaxUint64))
t.Log(int64(math.MaxInt64))
t.Log(int64(math.MinInt64))
}
func TestName4(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*300)
defer cancel()
cli := Result(mongo.Connect(ctx, options.Client().ApplyURI("mongodb://openIM:openIM123@172.16.8.135:37017/openim_v3?maxPoolSize=100").SetConnectTimeout(5*time.Second)))
msg, err := NewMsgMongo(cli.Database("openim_v3"))
if err != nil {
panic(err)
}
ts := time.Now().Add(-time.Hour * 24 * 5).UnixMilli()
t.Log(ts)
res, err := msg.GetLastMessageSeqByTime(ctx, "sg_1523453548", ts)
if err != nil {
panic(err)
}
t.Log(res)
}
func TestName5(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*300)
defer cancel()
cli := Result(mongo.Connect(ctx, options.Client().ApplyURI("mongodb://openIM:openIM123@172.16.8.135:37017/openim_v3?maxPoolSize=100").SetConnectTimeout(5*time.Second)))
tmp, err := NewMsgMongo(cli.Database("openim_v3"))
if err != nil {
panic(err)
}
msg := tmp.(*MsgMgo)
ts := time.Now().Add(-time.Hour * 24 * 5).UnixMilli()
t.Log(ts)
var seqs []int64
for i := 1; i < 256; i++ {
seqs = append(seqs, int64(i))
}
res, err := msg.FindSeqs(ctx, "si_4924054191_9511766539", seqs)
if err != nil {
panic(err)
}
t.Log(res)
}
//func TestName6(t *testing.T) {
// ctx, cancel := context.WithTimeout(context.Background(), time.Second*300)
// defer cancel()
// cli := Result(mongo.Connect(ctx, options.Client().ApplyURI("mongodb://openIM:openIM123@172.16.8.135:37017/openim_v3?maxPoolSize=100").SetConnectTimeout(5*time.Second)))
//
// tmp, err := NewMsgMongo(cli.Database("openim_v3"))
// if err != nil {
// panic(err)
// }
// msg := tmp.(*MsgMgo)
// seq, sendTime, err := msg.findBeforeSendTime(ctx, "si_4924054191_9511766539", 1144)
// if err != nil {
// panic(err)
// }
// t.Log(seq, sendTime)
//}
func TestSearchMessage(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*300)
defer cancel()
cli := Result(mongo.Connect(ctx, options.Client().ApplyURI("mongodb://openIM:openIM123@172.16.8.135:37017/openim_v3?maxPoolSize=100").SetConnectTimeout(5*time.Second)))
msgMongo, err := NewMsgMongo(cli.Database("openim_v3"))
if err != nil {
panic(err)
}
ts := time.Now().Add(-time.Hour * 24 * 5).UnixMilli()
t.Log(ts)
req := &msg.SearchMessageReq{
//SendID: "yjz",
//RecvID: "aibot",
Pagination: &sdkws.RequestPagination{
PageNumber: 1,
ShowNumber: 20,
},
}
count, resp, err := msgMongo.SearchMessage(ctx, req)
if err != nil {
panic(err)
}
t.Log(resp, count)
}

View File

@@ -0,0 +1,126 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mgo
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/errs"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func NewS3Mongo(db *mongo.Database) (database.ObjectInfo, error) {
coll := db.Collection(database.ObjectName)
// Create index for name
_, err := coll.Indexes().CreateOne(context.Background(), mongo.IndexModel{
Keys: bson.D{
{Key: "name", Value: 1},
},
Options: options.Index().SetUnique(true),
})
if err != nil {
return nil, errs.Wrap(err)
}
// Create index for create_time
_, err = coll.Indexes().CreateOne(context.Background(), mongo.IndexModel{
Keys: bson.D{
{Key: "create_time", Value: 1},
},
})
if err != nil {
return nil, errs.Wrap(err)
}
// Create index for key
_, err = coll.Indexes().CreateOne(context.Background(), mongo.IndexModel{
Keys: bson.D{
{Key: "key", Value: 1},
},
})
if err != nil {
return nil, errs.Wrap(err)
}
return &S3Mongo{coll: coll}, nil
}
type S3Mongo struct {
coll *mongo.Collection
}
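// SetObject upserts the object record keyed by (name, engine): an existing document with the
// same name and engine is overwritten, otherwise a new one is inserted.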
func (o *S3Mongo) SetObject(ctx context.Context, obj *model.Object) error {
filter := bson.M{"name": obj.Name, "engine": obj.Engine}
update := bson.M{
"name": obj.Name,
"engine": obj.Engine,
"key": obj.Key,
"size": obj.Size,
"content_type": obj.ContentType,
"group": obj.Group,
"create_time": obj.CreateTime,
}
return mongoutil.UpdateOne(ctx, o.coll, filter, bson.M{"$set": update}, false, options.Update().SetUpsert(true))
}
func (o *S3Mongo) Take(ctx context.Context, engine string, name string) (*model.Object, error) {
if engine == "" {
return mongoutil.FindOne[*model.Object](ctx, o.coll, bson.M{"name": name})
}
return mongoutil.FindOne[*model.Object](ctx, o.coll, bson.M{"name": name, "engine": engine})
}
func (o *S3Mongo) Delete(ctx context.Context, engine string, name []string) error {
if len(name) == 0 {
return nil
}
return mongoutil.DeleteOne(ctx, o.coll, bson.M{"engine": engine, "name": bson.M{"$in": name}})
}
func (o *S3Mongo) FindExpirationObject(ctx context.Context, engine string, expiration time.Time, needDelType []string, count int64) ([]*model.Object, error) {
opt := options.Find()
if count > 0 {
opt.SetLimit(count)
}
return mongoutil.Find[*model.Object](ctx, o.coll, bson.M{
"engine": engine,
"create_time": bson.M{"$lt": expiration},
"group": bson.M{"$in": needDelType},
}, opt)
}
func (o *S3Mongo) GetKeyCount(ctx context.Context, engine string, key string) (int64, error) {
return mongoutil.Count(ctx, o.coll, bson.M{"engine": engine, "key": key})
}
func (o *S3Mongo) GetEngineCount(ctx context.Context, engine string) (int64, error) {
return mongoutil.Count(ctx, o.coll, bson.M{"engine": engine})
}
func (o *S3Mongo) GetEngineInfo(ctx context.Context, engine string, limit int, skip int) ([]*model.Object, error) {
return mongoutil.Find[*model.Object](ctx, o.coll, bson.M{"engine": engine}, options.Find().SetLimit(int64(limit)).SetSkip(int64(skip)))
}
func (o *S3Mongo) UpdateEngine(ctx context.Context, oldEngine, oldName string, newEngine string) error {
return mongoutil.UpdateOne(ctx, o.coll, bson.M{"engine": oldEngine, "name": oldName}, bson.M{"$set": bson.M{"engine": newEngine}}, false)
}

View File

@@ -0,0 +1,234 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mgo
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/db/pagination"
"github.com/openimsdk/tools/errs"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// RedPacketMgo implements RedPacket using MongoDB as the storage backend.
type RedPacketMgo struct {
coll *mongo.Collection
}
// NewRedPacketMongo creates a new instance of RedPacketMgo with the provided MongoDB database.
func NewRedPacketMongo(db *mongo.Database) (database.RedPacket, error) {
coll := db.Collection(database.RedPacketName)
_, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{
{
Keys: bson.D{{Key: "red_packet_id", Value: 1}},
Options: options.Index().SetUnique(true),
},
{
Keys: bson.D{{Key: "send_user_id", Value: 1}, {Key: "create_time", Value: -1}},
},
{
Keys: bson.D{{Key: "group_id", Value: 1}, {Key: "create_time", Value: -1}},
},
{
Keys: bson.D{{Key: "expire_time", Value: 1}},
},
})
if err != nil {
return nil, err
}
return &RedPacketMgo{coll: coll}, nil
}
// Create creates a new red packet record.
func (r *RedPacketMgo) Create(ctx context.Context, redPacket *model.RedPacket) error {
if redPacket.CreateTime.IsZero() {
redPacket.CreateTime = time.Now()
}
return mongoutil.InsertOne(ctx, r.coll, redPacket)
}
// Take retrieves a red packet by ID. Returns an error if not found.
func (r *RedPacketMgo) Take(ctx context.Context, redPacketID string) (*model.RedPacket, error) {
return mongoutil.FindOne[*model.RedPacket](ctx, r.coll, bson.M{"red_packet_id": redPacketID})
}
// UpdateStatus updates the status of a red packet.
func (r *RedPacketMgo) UpdateStatus(ctx context.Context, redPacketID string, status int32) error {
return mongoutil.UpdateOne(ctx, r.coll, bson.M{"red_packet_id": redPacketID}, bson.M{"$set": bson.M{"status": status}}, false)
}
// UpdateRemain updates the remain amount and count of a red packet.
func (r *RedPacketMgo) UpdateRemain(ctx context.Context, redPacketID string, remainAmount int64, remainCount int32) error {
update := bson.M{
"$set": bson.M{
"remain_amount": remainAmount,
"remain_count": remainCount,
},
}
// If remain count is 0, update status to finished
if remainCount == 0 {
update["$set"].(bson.M)["status"] = model.RedPacketStatusFinished
}
return mongoutil.UpdateOne(ctx, r.coll, bson.M{"red_packet_id": redPacketID}, update, false)
}
// DecreaseRemainAtomic atomically decrements a red packet's remaining count and amount
// to guard against concurrent claims; it only updates when remain_count > 0 and the status is Active.
func (r *RedPacketMgo) DecreaseRemainAtomic(ctx context.Context, redPacketID string, amount int64) (*model.RedPacket, error) {
	// Filter: matching red packet ID, remaining count > 0, and status Active
filter := bson.M{
"red_packet_id": redPacketID,
"remain_count": bson.M{"$gt": 0},
"status": model.RedPacketStatusActive,
}
	// Use $inc to atomically decrement the remaining amount and count
update := bson.M{
"$inc": bson.M{
"remain_amount": -amount,
"remain_count": -1,
},
}
	// Use findOneAndUpdate to return the updated document
opts := options.FindOneAndUpdate().SetReturnDocument(options.After)
var updatedRedPacket model.RedPacket
err := r.coll.FindOneAndUpdate(ctx, filter, update, opts).Decode(&updatedRedPacket)
if err != nil {
if err == mongo.ErrNoDocuments {
			// The red packet does not exist, has been fully claimed, or is not in a claimable state
return nil, errs.ErrArgs.WrapMsg("red packet not available (already finished or expired)")
}
return nil, err
}
	// If the remaining count reaches 0, mark the red packet as finished
if updatedRedPacket.RemainCount == 0 {
statusUpdate := bson.M{"$set": bson.M{"status": model.RedPacketStatusFinished}}
_ = mongoutil.UpdateOne(ctx, r.coll, bson.M{"red_packet_id": redPacketID}, statusUpdate, false)
updatedRedPacket.Status = model.RedPacketStatusFinished
}
return &updatedRedPacket, nil
}
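// exampleClaimShare is an illustrative sketch (not part of the committed file) of how a
// claim flow might use DecreaseRemainAtomic; the helper name and return shape are
// assumptions, shown only to demonstrate that exhaustion is signalled by the error and
// that the returned status tells the caller whether this claim took the last share.
func exampleClaimShare(ctx context.Context, rp *RedPacketMgo, redPacketID string, amount int64) (finished bool, err error) {
	updated, err := rp.DecreaseRemainAtomic(ctx, redPacketID, amount)
	if err != nil {
		// Not claimable: the packet does not exist, is already finished, or has expired.
		return false, err
	}
	return updated.Status == model.RedPacketStatusFinished, nil
}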
// FindExpiredRedPackets finds red packets that have expired.
func (r *RedPacketMgo) FindExpiredRedPackets(ctx context.Context, beforeTime time.Time) ([]*model.RedPacket, error) {
filter := bson.M{
"expire_time": bson.M{"$lt": beforeTime},
"status": model.RedPacketStatusActive,
}
return mongoutil.Find[*model.RedPacket](ctx, r.coll, filter)
}
// FindRedPacketsByUser finds red packets sent by a user with pagination.
func (r *RedPacketMgo) FindRedPacketsByUser(ctx context.Context, userID string, pagination pagination.Pagination) (total int64, redPackets []*model.RedPacket, err error) {
filter := bson.M{"send_user_id": userID}
return mongoutil.FindPage[*model.RedPacket](ctx, r.coll, filter, pagination, &options.FindOptions{
Sort: bson.D{{Key: "create_time", Value: -1}},
})
}
// FindRedPacketsByGroup finds red packets in a group with pagination.
func (r *RedPacketMgo) FindRedPacketsByGroup(ctx context.Context, groupID string, pagination pagination.Pagination) (total int64, redPackets []*model.RedPacket, err error) {
filter := bson.M{"group_id": groupID}
return mongoutil.FindPage[*model.RedPacket](ctx, r.coll, filter, pagination, &options.FindOptions{
Sort: bson.D{{Key: "create_time", Value: -1}},
})
}
// FindAllRedPackets finds all red packets with pagination.
func (r *RedPacketMgo) FindAllRedPackets(ctx context.Context, pagination pagination.Pagination) (total int64, redPackets []*model.RedPacket, err error) {
return mongoutil.FindPage[*model.RedPacket](ctx, r.coll, bson.M{}, pagination, &options.FindOptions{
Sort: bson.D{{Key: "create_time", Value: -1}},
})
}
// RedPacketReceiveMgo implements RedPacketReceive using MongoDB as the storage backend.
type RedPacketReceiveMgo struct {
coll *mongo.Collection
}
// NewRedPacketReceiveMongo creates a new instance of RedPacketReceiveMgo with the provided MongoDB database.
func NewRedPacketReceiveMongo(db *mongo.Database) (database.RedPacketReceive, error) {
coll := db.Collection(database.RedPacketReceiveName)
_, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{
{
Keys: bson.D{{Key: "receive_id", Value: 1}},
Options: options.Index().SetUnique(true),
},
{
Keys: bson.D{{Key: "red_packet_id", Value: 1}, {Key: "receive_time", Value: -1}},
},
{
Keys: bson.D{{Key: "receive_user_id", Value: 1}, {Key: "red_packet_id", Value: 1}},
Options: options.Index().SetUnique(true),
},
{
Keys: bson.D{{Key: "receive_user_id", Value: 1}, {Key: "receive_time", Value: -1}},
},
})
if err != nil {
return nil, err
}
return &RedPacketReceiveMgo{coll: coll}, nil
}
// Create creates a new red packet receive record.
func (r *RedPacketReceiveMgo) Create(ctx context.Context, receive *model.RedPacketReceive) error {
if receive.ReceiveTime.IsZero() {
receive.ReceiveTime = time.Now()
}
return mongoutil.InsertOne(ctx, r.coll, receive)
}
// Take retrieves a receive record by ID. Returns an error if not found.
func (r *RedPacketReceiveMgo) Take(ctx context.Context, receiveID string) (*model.RedPacketReceive, error) {
return mongoutil.FindOne[*model.RedPacketReceive](ctx, r.coll, bson.M{"receive_id": receiveID})
}
// FindByRedPacketID finds all receive records for a red packet.
func (r *RedPacketReceiveMgo) FindByRedPacketID(ctx context.Context, redPacketID string) ([]*model.RedPacketReceive, error) {
return mongoutil.Find[*model.RedPacketReceive](ctx, r.coll, bson.M{"red_packet_id": redPacketID}, &options.FindOptions{
Sort: bson.D{{Key: "receive_time", Value: 1}},
})
}
// FindByUserAndRedPacketID finds if a user has received a specific red packet.
func (r *RedPacketReceiveMgo) FindByUserAndRedPacketID(ctx context.Context, userID, redPacketID string) (*model.RedPacketReceive, error) {
return mongoutil.FindOne[*model.RedPacketReceive](ctx, r.coll, bson.M{
"receive_user_id": userID,
"red_packet_id": redPacketID,
})
}
// FindByUser finds all red packets received by a user with pagination.
func (r *RedPacketReceiveMgo) FindByUser(ctx context.Context, userID string, pagination pagination.Pagination) (total int64, receives []*model.RedPacketReceive, err error) {
filter := bson.M{"receive_user_id": userID}
return mongoutil.FindPage[*model.RedPacketReceive](ctx, r.coll, filter, pagination, &options.FindOptions{
Sort: bson.D{{Key: "receive_time", Value: -1}},
})
}
// DeleteByReceiveID deletes a receive record by receive ID (for cleanup on failure).
func (r *RedPacketReceiveMgo) DeleteByReceiveID(ctx context.Context, receiveID string) error {
return mongoutil.DeleteOne(ctx, r.coll, bson.M{"receive_id": receiveID})
}

View File

@@ -0,0 +1,104 @@
package mgo
import (
"context"
"errors"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/mongoutil"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func NewSeqConversationMongo(db *mongo.Database) (database.SeqConversation, error) {
coll := db.Collection(database.SeqConversationName)
_, err := coll.Indexes().CreateOne(context.Background(), mongo.IndexModel{
Keys: bson.D{
{Key: "conversation_id", Value: 1},
},
})
if err != nil {
return nil, err
}
return &seqConversationMongo{coll: coll}, nil
}
type seqConversationMongo struct {
coll *mongo.Collection
}
func (s *seqConversationMongo) setSeq(ctx context.Context, conversationID string, seq int64, field string) error {
filter := map[string]any{
"conversation_id": conversationID,
}
insert := bson.M{
"conversation_id": conversationID,
"min_seq": 0,
"max_seq": 0,
}
delete(insert, field)
update := map[string]any{
"$set": bson.M{
field: seq,
},
"$setOnInsert": insert,
}
opt := options.Update().SetUpsert(true)
return mongoutil.UpdateOne(ctx, s.coll, filter, update, false, opt)
}
func (s *seqConversationMongo) Malloc(ctx context.Context, conversationID string, size int64) (int64, error) {
if size < 0 {
		return 0, errors.New("size must not be negative")
}
if size == 0 {
return s.GetMaxSeq(ctx, conversationID)
}
filter := map[string]any{"conversation_id": conversationID}
update := map[string]any{
"$inc": map[string]any{"max_seq": size},
"$set": map[string]any{"min_seq": int64(0)},
}
opt := options.FindOneAndUpdate().SetUpsert(true).SetReturnDocument(options.After).SetProjection(map[string]any{"_id": 0, "max_seq": 1})
lastSeq, err := mongoutil.FindOneAndUpdate[int64](ctx, s.coll, filter, update, opt)
if err != nil {
return 0, err
}
return lastSeq - size, nil
}
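// exampleAllocateSeqs is an illustrative sketch (not part of the committed file): Malloc
// reserves a contiguous block atomically and returns its first sequence number, so a caller
// can hand out `size` sequence numbers without another round trip. The helper name is an assumption.
func exampleAllocateSeqs(ctx context.Context, s *seqConversationMongo, conversationID string, size int64) ([]int64, error) {
	first, err := s.Malloc(ctx, conversationID, size)
	if err != nil {
		return nil, err
	}
	seqs := make([]int64, 0, size)
	for i := int64(0); i < size; i++ {
		seqs = append(seqs, first+i) // first .. first+size-1 now belong to this caller
	}
	return seqs, nil
}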
func (s *seqConversationMongo) SetMaxSeq(ctx context.Context, conversationID string, seq int64) error {
return s.setSeq(ctx, conversationID, seq, "max_seq")
}
func (s *seqConversationMongo) GetMaxSeq(ctx context.Context, conversationID string) (int64, error) {
seq, err := mongoutil.FindOne[int64](ctx, s.coll, bson.M{"conversation_id": conversationID}, options.FindOne().SetProjection(map[string]any{"_id": 0, "max_seq": 1}))
if err == nil {
return seq, nil
} else if IsNotFound(err) {
return 0, nil
} else {
return 0, err
}
}
func (s *seqConversationMongo) GetMinSeq(ctx context.Context, conversationID string) (int64, error) {
seq, err := mongoutil.FindOne[int64](ctx, s.coll, bson.M{"conversation_id": conversationID}, options.FindOne().SetProjection(map[string]any{"_id": 0, "min_seq": 1}))
if err == nil {
return seq, nil
} else if IsNotFound(err) {
return 0, nil
} else {
return 0, err
}
}
func (s *seqConversationMongo) SetMinSeq(ctx context.Context, conversationID string, seq int64) error {
return s.setSeq(ctx, conversationID, seq, "min_seq")
}
func (s *seqConversationMongo) GetConversation(ctx context.Context, conversationID string) (*model.SeqConversation, error) {
return mongoutil.FindOne[*model.SeqConversation](ctx, s.coll, bson.M{"conversation_id": conversationID})
}

View File

@@ -0,0 +1,43 @@
package mgo
import (
"context"
"testing"
"time"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func Result[V any](val V, err error) V {
if err != nil {
panic(err)
}
return val
}
func Mongodb() *mongo.Database {
return Result(
mongo.Connect(context.Background(),
options.Client().
ApplyURI("mongodb://openIM:openIM123@172.16.8.135:37017/openim_v3?maxPoolSize=100").
SetConnectTimeout(5*time.Second)),
).Database("openim_v3")
}
func TestUserSeq(t *testing.T) {
uSeq := Result(NewSeqUserMongo(Mongodb())).(*seqUserMongo)
t.Log(uSeq.SetUserMinSeq(context.Background(), "1000", "2000", 4))
}
func TestConversationSeq(t *testing.T) {
cSeq := Result(NewSeqConversationMongo(Mongodb())).(*seqConversationMongo)
t.Log(cSeq.SetMaxSeq(context.Background(), "2000", 10))
t.Log(cSeq.Malloc(context.Background(), "2000", 10))
t.Log(cSeq.GetMaxSeq(context.Background(), "2000"))
}
func TestUserGetUserReadSeqs(t *testing.T) {
uSeq := Result(NewSeqUserMongo(Mongodb())).(*seqUserMongo)
t.Log(uSeq.GetUserReadSeqs(context.Background(), "2110910952", []string{"sg_345762580", "2000", "3000"}))
}

View File

@@ -0,0 +1,127 @@
package mgo
import (
"context"
"errors"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/mongoutil"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func NewSeqUserMongo(db *mongo.Database) (database.SeqUser, error) {
coll := db.Collection(database.SeqUserName)
_, err := coll.Indexes().CreateOne(context.Background(), mongo.IndexModel{
Keys: bson.D{
{Key: "user_id", Value: 1},
{Key: "conversation_id", Value: 1},
},
})
if err != nil {
return nil, err
}
return &seqUserMongo{coll: coll}, nil
}
type seqUserMongo struct {
coll *mongo.Collection
}
func (s *seqUserMongo) setSeq(ctx context.Context, conversationID string, userID string, seq int64, field string) error {
filter := map[string]any{
"user_id": userID,
"conversation_id": conversationID,
}
insert := bson.M{
"user_id": userID,
"conversation_id": conversationID,
"min_seq": 0,
"max_seq": 0,
"read_seq": 0,
}
delete(insert, field)
update := map[string]any{
"$set": bson.M{
field: seq,
},
"$setOnInsert": insert,
}
opt := options.Update().SetUpsert(true)
return mongoutil.UpdateOne(ctx, s.coll, filter, update, false, opt)
}
func (s *seqUserMongo) getSeq(ctx context.Context, conversationID string, userID string, field string) (int64, error) {
filter := map[string]any{
"user_id": userID,
"conversation_id": conversationID,
}
	opt := options.FindOne().SetProjection(bson.M{"_id": 0, field: 1})
seq, err := mongoutil.FindOne[int64](ctx, s.coll, filter, opt)
if err == nil {
return seq, nil
} else if errors.Is(err, mongo.ErrNoDocuments) {
return 0, nil
} else {
return 0, err
}
}
func (s *seqUserMongo) GetUserMaxSeq(ctx context.Context, conversationID string, userID string) (int64, error) {
return s.getSeq(ctx, conversationID, userID, "max_seq")
}
func (s *seqUserMongo) SetUserMaxSeq(ctx context.Context, conversationID string, userID string, seq int64) error {
return s.setSeq(ctx, conversationID, userID, seq, "max_seq")
}
func (s *seqUserMongo) GetUserMinSeq(ctx context.Context, conversationID string, userID string) (int64, error) {
return s.getSeq(ctx, conversationID, userID, "min_seq")
}
func (s *seqUserMongo) SetUserMinSeq(ctx context.Context, conversationID string, userID string, seq int64) error {
return s.setSeq(ctx, conversationID, userID, seq, "min_seq")
}
func (s *seqUserMongo) GetUserReadSeq(ctx context.Context, conversationID string, userID string) (int64, error) {
return s.getSeq(ctx, conversationID, userID, "read_seq")
}
func (s *seqUserMongo) notFoundSet0(seq map[string]int64, conversationIDs []string) {
for _, conversationID := range conversationIDs {
if _, ok := seq[conversationID]; !ok {
seq[conversationID] = 0
}
}
}
func (s *seqUserMongo) GetUserReadSeqs(ctx context.Context, userID string, conversationID []string) (map[string]int64, error) {
if len(conversationID) == 0 {
return map[string]int64{}, nil
}
filter := bson.M{"user_id": userID, "conversation_id": bson.M{"$in": conversationID}}
opt := options.Find().SetProjection(bson.M{"_id": 0, "conversation_id": 1, "read_seq": 1})
seqs, err := mongoutil.Find[*model.SeqUser](ctx, s.coll, filter, opt)
if err != nil {
return nil, err
}
res := make(map[string]int64)
for _, seq := range seqs {
res[seq.ConversationID] = seq.ReadSeq
}
s.notFoundSet0(res, conversationID)
return res, nil
}
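// exampleUnreadCounts is an illustrative sketch (not part of the committed file): read seqs
// are typically compared against each conversation's max seq (kept in the SeqConversation
// store) to derive unread counts; both maps here are assumed inputs.
func exampleUnreadCounts(readSeqs, maxSeqs map[string]int64) map[string]int64 {
	unread := make(map[string]int64, len(readSeqs))
	for conversationID, readSeq := range readSeqs {
		if maxSeq := maxSeqs[conversationID]; maxSeq > readSeq {
			unread[conversationID] = maxSeq - readSeq
		} else {
			unread[conversationID] = 0
		}
	}
	return unread
}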
func (s *seqUserMongo) SetUserReadSeq(ctx context.Context, conversationID string, userID string, seq int64) error {
dbSeq, err := s.GetUserReadSeq(ctx, conversationID, userID)
if err != nil {
return err
}
if dbSeq > seq {
return nil
}
return s.setSeq(ctx, conversationID, userID, seq, "read_seq")
}

View File

@@ -0,0 +1,99 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mgo
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/errs"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// SystemConfigMgo implements SystemConfig using MongoDB as the storage backend.
type SystemConfigMgo struct {
coll *mongo.Collection
}
// NewSystemConfigMongo creates a new instance of SystemConfigMgo with the provided MongoDB database.
func NewSystemConfigMongo(db *mongo.Database) (database.SystemConfig, error) {
coll := db.Collection(database.SystemConfigName)
_, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{
{
Keys: bson.D{{Key: "key", Value: 1}},
Options: options.Index().SetUnique(true),
},
{
Keys: bson.D{{Key: "enabled", Value: 1}},
},
{
Keys: bson.D{{Key: "create_time", Value: -1}},
},
})
if err != nil {
return nil, err
}
return &SystemConfigMgo{coll: coll}, nil
}
// Create creates a new system config record.
func (s *SystemConfigMgo) Create(ctx context.Context, config *model.SystemConfig) error {
config.CreateTime = time.Now()
config.UpdateTime = time.Now()
return mongoutil.InsertOne(ctx, s.coll, config)
}
// Take retrieves a system config by key. Returns an error if not found.
func (s *SystemConfigMgo) Take(ctx context.Context, key string) (*model.SystemConfig, error) {
return mongoutil.FindOne[*model.SystemConfig](ctx, s.coll, bson.M{"key": key})
}
// Update updates system config information.
func (s *SystemConfigMgo) Update(ctx context.Context, key string, data map[string]any) error {
data["update_time"] = time.Now()
return mongoutil.UpdateOne(ctx, s.coll, bson.M{"key": key}, bson.M{"$set": data}, true)
}
// Find finds system configs by keys.
func (s *SystemConfigMgo) Find(ctx context.Context, keys []string) ([]*model.SystemConfig, error) {
return mongoutil.Find[*model.SystemConfig](ctx, s.coll, bson.M{"key": bson.M{"$in": keys}})
}
// FindEnabled finds all enabled system configs.
func (s *SystemConfigMgo) FindEnabled(ctx context.Context) ([]*model.SystemConfig, error) {
return mongoutil.Find[*model.SystemConfig](ctx, s.coll, bson.M{"enabled": true})
}
// FindByKey finds a system config by key (returns nil if not found, no error).
func (s *SystemConfigMgo) FindByKey(ctx context.Context, key string) (*model.SystemConfig, error) {
config, err := mongoutil.FindOne[*model.SystemConfig](ctx, s.coll, bson.M{"key": key})
if err != nil {
if errs.ErrRecordNotFound.Is(err) || err == mongo.ErrNoDocuments {
return nil, nil
}
return nil, err
}
return config, nil
}
// Delete deletes a system config by key.
func (s *SystemConfigMgo) Delete(ctx context.Context, key string) error {
return mongoutil.DeleteOne(ctx, s.coll, bson.M{"key": key})
}

View File

@@ -0,0 +1,699 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mgo
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/protocol/user"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/db/pagination"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func NewUserMongo(db *mongo.Database) (database.User, error) {
coll := db.Collection(database.UserName)
_, err := coll.Indexes().CreateOne(context.Background(), mongo.IndexModel{
Keys: bson.D{
{Key: "user_id", Value: 1},
},
Options: options.Index().SetUnique(true),
})
if err != nil {
return nil, errs.Wrap(err)
}
return &UserMgo{coll: coll}, nil
}
type UserMgo struct {
coll *mongo.Collection
}
func (u *UserMgo) Create(ctx context.Context, users []*model.User) error {
return mongoutil.InsertMany(ctx, u.coll, users)
}
func (u *UserMgo) UpdateByMap(ctx context.Context, userID string, args map[string]any) (err error) {
if len(args) == 0 {
return nil
}
return mongoutil.UpdateOne(ctx, u.coll, bson.M{"user_id": userID}, bson.M{"$set": args}, true)
}
func (u *UserMgo) Find(ctx context.Context, userIDs []string) (users []*model.User, err error) {
query := bson.M{"user_id": bson.M{"$in": userIDs}}
log.ZInfo(ctx, "UserMongo Find query", "collection", u.coll.Name(), "query", query)
users, err = mongoutil.Find[*model.User](ctx, u.coll, query)
log.ZInfo(ctx, "UserMongo Find result", "userCount", len(users), "err", err)
return users, err
}
func (u *UserMgo) Take(ctx context.Context, userID string) (user *model.User, err error) {
return mongoutil.FindOne[*model.User](ctx, u.coll, bson.M{"user_id": userID})
}
func (u *UserMgo) TakeNotification(ctx context.Context, level int64) (user []*model.User, err error) {
return mongoutil.Find[*model.User](ctx, u.coll, bson.M{"app_manger_level": level})
}
func (u *UserMgo) TakeGTEAppManagerLevel(ctx context.Context, level int64) (user []*model.User, err error) {
return mongoutil.Find[*model.User](ctx, u.coll, bson.M{"app_manger_level": bson.M{"$gte": level}})
}
func (u *UserMgo) TakeByNickname(ctx context.Context, nickname string) (user []*model.User, err error) {
return mongoutil.Find[*model.User](ctx, u.coll, bson.M{"nickname": nickname})
}
func (u *UserMgo) Page(ctx context.Context, pagination pagination.Pagination) (count int64, users []*model.User, err error) {
return mongoutil.FindPage[*model.User](ctx, u.coll, bson.M{}, pagination)
}
func (u *UserMgo) PageFindUser(ctx context.Context, level1 int64, level2 int64, pagination pagination.Pagination) (count int64, users []*model.User, err error) {
query := bson.M{
"$or": []bson.M{
{"app_manger_level": level1},
{"app_manger_level": level2},
},
}
return mongoutil.FindPage[*model.User](ctx, u.coll, query, pagination)
}
func (u *UserMgo) PageFindUserWithKeyword(
ctx context.Context,
level1 int64,
level2 int64,
userID string,
nickName string,
pagination pagination.Pagination,
) (count int64, users []*model.User, err error) {
// Initialize the base query with level conditions
query := bson.M{
"$and": []bson.M{
{"app_manger_level": bson.M{"$in": []int64{level1, level2}}},
},
}
// Add userID and userName conditions to the query if they are provided
if userID != "" || nickName != "" {
userConditions := []bson.M{}
if userID != "" {
// Use regex for userID
regexPattern := primitive.Regex{Pattern: userID, Options: "i"} // 'i' for case-insensitive matching
userConditions = append(userConditions, bson.M{"user_id": regexPattern})
}
if nickName != "" {
// Use regex for userName
regexPattern := primitive.Regex{Pattern: nickName, Options: "i"} // 'i' for case-insensitive matching
userConditions = append(userConditions, bson.M{"nickname": regexPattern})
}
query["$and"] = append(query["$and"].([]bson.M), bson.M{"$or": userConditions})
}
// Perform the paginated search
return mongoutil.FindPage[*model.User](ctx, u.coll, query, pagination)
}
func (u *UserMgo) GetAllUserID(ctx context.Context, pagination pagination.Pagination) (int64, []string, error) {
return mongoutil.FindPage[string](ctx, u.coll, bson.M{}, pagination, options.Find().SetProjection(bson.M{"_id": 0, "user_id": 1}))
}
func (u *UserMgo) Exist(ctx context.Context, userID string) (exist bool, err error) {
return mongoutil.Exist(ctx, u.coll, bson.M{"user_id": userID})
}
func (u *UserMgo) GetUserGlobalRecvMsgOpt(ctx context.Context, userID string) (opt int, err error) {
return mongoutil.FindOne[int](ctx, u.coll, bson.M{"user_id": userID}, options.FindOne().SetProjection(bson.M{"_id": 0, "global_recv_msg_opt": 1}))
}
// SearchUsersByFields searches users by multiple fields (account, userID, phone, nickname)
// and returns the list of matching user IDs.
// It uses a MongoDB $lookup to join the attribute collection with the user collection.
func (u *UserMgo) SearchUsersByFields(ctx context.Context, account, phone, nickname string) (userIDs []string, err error) {
log.ZInfo(ctx, "SearchUsersByFields START", "account", account, "phone", phone, "nickname", nickname)
	// Get the attribute collection
attributeColl := u.coll.Database().Collection("attribute")
log.ZInfo(ctx, "SearchUsersByFields collections", "attributeCollection", "attribute", "userCollection", u.coll.Name(), "database", u.coll.Database().Name())
	// Build the aggregation pipeline, joining the collections with $lookup
pipeline := bson.A{}
	// Build the match conditions, starting from the attribute collection
attributeMatch := bson.M{}
attributeOrConditions := []bson.M{}
if account != "" {
attributeOrConditions = append(attributeOrConditions, bson.M{"account": bson.M{"$regex": account, "$options": "i"}})
log.ZInfo(ctx, "SearchUsersByFields add account condition", "account", account)
}
if phone != "" {
		// phone is searched in the attribute collection (note: the field name is phone_number, not phone)
attributeOrConditions = append(attributeOrConditions, bson.M{"phone_number": bson.M{"$regex": phone, "$options": "i"}})
log.ZInfo(ctx, "SearchUsersByFields add phone condition", "phone", phone, "field", "phone_number")
}
	// Choose the query strategy: start from attribute when account or phone is given; start from user when only nickname is given
hasAttributeSearch := len(attributeOrConditions) > 0
hasUserSearch := nickname != ""
	var startCollection *mongo.Collection // records which collection the pipeline starts from
if hasAttributeSearch {
		// Start the query from the attribute collection
startCollection = attributeColl
		// First query the attribute collection directly to inspect the actual document structure
type FullAttributeDoc map[string]interface{}
allDocs, err := mongoutil.Find[*FullAttributeDoc](ctx, attributeColl, bson.M{}, options.Find().SetLimit(5).SetProjection(bson.M{"_id": 0}))
if err == nil && len(allDocs) > 0 {
log.ZInfo(ctx, "SearchUsersByFields attribute sample documents (all fields)", "sampleCount", len(allDocs), "samples", allDocs)
}
		// Try an exact phone match (using the correct field name phone_number)
exactPhoneMatch := bson.M{"phone_number": phone}
exactDocs, err := mongoutil.Find[*FullAttributeDoc](ctx, attributeColl, exactPhoneMatch, options.Find().SetLimit(5).SetProjection(bson.M{"_id": 0}))
if err == nil {
log.ZInfo(ctx, "SearchUsersByFields exact phone_number match", "phone", phone, "matchCount", len(exactDocs), "docs", exactDocs)
}
		// Try a case-insensitive regex match, again on phone_number
regexPhoneMatch := bson.M{"phone_number": bson.M{"$regex": phone, "$options": "i"}}
regexDocs, err := mongoutil.Find[*FullAttributeDoc](ctx, attributeColl, regexPhoneMatch, options.Find().SetLimit(5).SetProjection(bson.M{"_id": 0}))
if err == nil {
log.ZInfo(ctx, "SearchUsersByFields regex phone_number match", "phone", phone, "matchCount", len(regexDocs), "docs", regexDocs)
}
		// Query records whose phone_number field exists and is non-empty
hasPhoneFieldMatch := bson.M{"phone_number": bson.M{"$exists": true, "$ne": ""}}
hasPhoneDocs, err := mongoutil.Find[*FullAttributeDoc](ctx, attributeColl, hasPhoneFieldMatch, options.Find().SetLimit(5).SetProjection(bson.M{"_id": 0}))
if err == nil {
log.ZInfo(ctx, "SearchUsersByFields documents with phone_number field", "matchCount", len(hasPhoneDocs), "docs", hasPhoneDocs)
}
attributeMatch["$or"] = attributeOrConditions
pipeline = append(pipeline, bson.M{"$match": attributeMatch})
log.ZInfo(ctx, "SearchUsersByFields attribute match stage", "match", attributeMatch)
		// Join the user collection with $lookup
lookupStage := bson.M{
"$lookup": bson.M{
"from": u.coll.Name(), // user 集合名称
"localField": "user_id", // attribute 集合的字段
"foreignField": "user_id", // user 集合的字段
"as": "userInfo", // 关联后的字段名
},
}
pipeline = append(pipeline, lookupStage)
log.ZInfo(ctx, "SearchUsersByFields add lookup stage", "from", u.coll.Name(), "localField", "user_id", "foreignField", "user_id")
		// Unwind the userInfo array
pipeline = append(pipeline, bson.M{"$unwind": bson.M{
"path": "$userInfo",
"preserveNullAndEmptyArrays": true,
}})
		// If a nickname condition exists, also match the user collection's nickname
if hasUserSearch {
userMatch := bson.M{
"$or": []bson.M{
{"userInfo.nickname": bson.M{"$regex": nickname, "$options": "i"}},
{"userInfo": bson.M{"$exists": false}}, // 如果没有关联到 user也保留
},
}
pipeline = append(pipeline, bson.M{"$match": userMatch})
log.ZInfo(ctx, "SearchUsersByFields add nickname match", "nickname", nickname)
}
		// Starting from attribute, user_id sits at the root level
pipeline = append(pipeline, bson.M{
"$project": bson.M{
"_id": 0,
"user_id": 1,
},
})
} else if hasUserSearch {
		// Only a nickname condition: start the query from the user collection
startCollection = u.coll
userMatch := bson.M{"nickname": bson.M{"$regex": nickname, "$options": "i"}}
pipeline = append(pipeline, bson.M{"$match": userMatch})
log.ZInfo(ctx, "SearchUsersByFields user match stage", "match", userMatch)
		// Starting from user, user_id sits at the root level
pipeline = append(pipeline, bson.M{
"$project": bson.M{
"_id": 0,
"user_id": 1,
},
})
} else {
		// No search conditions at all: return an empty result
log.ZInfo(ctx, "SearchUsersByFields no search conditions", "returning empty")
return []string{}, nil
}
	// Deduplicate by user_id
pipeline = append(pipeline, bson.M{
"$group": bson.M{
"_id": "$user_id",
"user_id": bson.M{"$first": "$user_id"},
},
})
	// Project only user_id
pipeline = append(pipeline, bson.M{
"$project": bson.M{
"_id": 0,
"user_id": 1,
},
})
log.ZInfo(ctx, "SearchUsersByFields pipeline", "pipeline", pipeline, "startCollection", startCollection.Name())
	// Run the aggregation
type ResultDoc struct {
UserID string `bson:"user_id"`
}
results, err := mongoutil.Aggregate[*ResultDoc](ctx, startCollection, pipeline)
if err != nil {
log.ZError(ctx, "SearchUsersByFields Aggregate failed", err, "pipeline", pipeline)
return nil, err
}
log.ZInfo(ctx, "SearchUsersByFields Aggregate result", "resultCount", len(results), "results", results)
	// Extract the user_id list
userIDs = make([]string, 0, len(results))
for _, result := range results {
if result.UserID != "" {
userIDs = append(userIDs, result.UserID)
}
}
log.ZInfo(ctx, "SearchUsersByFields FINAL result", "totalUserIDs", len(userIDs), "userIDs", userIDs)
return userIDs, nil
}
// SearchUsersByFields_old is the previous implementation, kept as a backup.
func (u *UserMgo) SearchUsersByFields_old(ctx context.Context, account, phone, nickname string) (userIDs []string, err error) {
log.ZInfo(ctx, "SearchUsersByFields START", "account", account, "phone", phone, "nickname", nickname)
	userIDMap := make(map[string]bool) // used for deduplication
	// Get the attribute collection
attributeColl := u.coll.Database().Collection("attribute")
log.ZInfo(ctx, "SearchUsersByFields attribute collection", "collectionName", "attribute", "database", u.coll.Database().Name())
	// Query account and phone from the attribute collection
if account != "" || phone != "" {
attributeFilter := bson.M{}
attributeConditions := []bson.M{}
if account != "" {
			// account is searched in the attribute collection
attributeConditions = append(attributeConditions, bson.M{"account": bson.M{"$regex": account, "$options": "i"}})
log.ZInfo(ctx, "SearchUsersByFields add account condition", "account", account)
}
if phone != "" {
			// phone is searched in the attribute collection
attributeConditions = append(attributeConditions, bson.M{"phone": bson.M{"$regex": phone, "$options": "i"}})
log.ZInfo(ctx, "SearchUsersByFields add phone condition", "phone", phone)
}
if len(attributeConditions) > 0 {
attributeFilter["$or"] = attributeConditions
log.ZInfo(ctx, "SearchUsersByFields query attribute", "filter", attributeFilter, "account", account, "phone", phone, "conditionsCount", len(attributeConditions))
			// The attribute collection contains fields such as user_id, account, and phone
type AttributeDoc struct {
UserID string `bson:"user_id"`
}
			// First check whether the collection contains any data at all
count, err := mongoutil.Count(ctx, attributeColl, bson.M{})
log.ZInfo(ctx, "SearchUsersByFields attribute collection total count", "count", count, "err", err)
			// Sample several records to inspect the structure, especially records that contain phone data
type SampleDoc struct {
UserID string `bson:"user_id"`
Account string `bson:"account"`
Phone string `bson:"phone"`
}
			// Query sample records to inspect the actual document structure
samples, err := mongoutil.Find[*SampleDoc](ctx, attributeColl, bson.M{}, options.Find().SetLimit(10).SetProjection(bson.M{"_id": 0, "user_id": 1, "account": 1, "phone": 1}))
if err == nil && len(samples) > 0 {
log.ZInfo(ctx, "SearchUsersByFields attribute sample documents", "sampleCount", len(samples), "samples", samples)
				// Query records whose phone field exists and is non-empty
phoneFilter := bson.M{"phone": bson.M{"$exists": true, "$ne": ""}}
phoneSamples, err := mongoutil.Find[*SampleDoc](ctx, attributeColl, phoneFilter, options.Find().SetLimit(5).SetProjection(bson.M{"_id": 0, "user_id": 1, "account": 1, "phone": 1}))
if err == nil {
log.ZInfo(ctx, "SearchUsersByFields attribute documents with phone", "phoneSampleCount", len(phoneSamples), "phoneSamples", phoneSamples)
} else {
log.ZWarn(ctx, "SearchUsersByFields cannot find documents with phone", err)
}
				// Fetch one full document to inspect every field
type FullDoc map[string]interface{}
fullSample, err := mongoutil.FindOne[*FullDoc](ctx, attributeColl, bson.M{}, options.FindOne().SetProjection(bson.M{"_id": 0}))
if err == nil && fullSample != nil {
log.ZInfo(ctx, "SearchUsersByFields attribute full document structure", "fullSample", fullSample)
}
} else {
log.ZWarn(ctx, "SearchUsersByFields cannot get samples from attribute", err, "sampleCount", len(samples))
}
			// Try an exact phone match to see whether any data exists
exactPhoneFilter := bson.M{"phone": phone}
exactCount, err := mongoutil.Count(ctx, attributeColl, exactPhoneFilter)
log.ZInfo(ctx, "SearchUsersByFields exact phone match count", "phone", phone, "count", exactCount, "err", err)
attributeDocs, err := mongoutil.Find[*AttributeDoc](ctx, attributeColl, attributeFilter, options.Find().SetProjection(bson.M{"_id": 0, "user_id": 1}))
if err != nil {
log.ZError(ctx, "SearchUsersByFields Find failed in attribute collection", err, "filter", attributeFilter)
return nil, err
}
log.ZInfo(ctx, "SearchUsersByFields Find result from attribute", "userCount", len(attributeDocs), "userIDs", attributeDocs)
for i, doc := range attributeDocs {
log.ZDebug(ctx, "SearchUsersByFields processing attribute doc", "index", i, "userID", doc.UserID)
if doc.UserID != "" && !userIDMap[doc.UserID] {
userIDMap[doc.UserID] = true
log.ZDebug(ctx, "SearchUsersByFields added userID from attribute", "userID", doc.UserID)
}
}
}
}
	// Query nickname from the user collection
if nickname != "" {
userFilter := bson.M{"nickname": bson.M{"$regex": nickname, "$options": "i"}}
log.ZInfo(ctx, "SearchUsersByFields query user", "filter", userFilter, "nickname", nickname)
users, err := mongoutil.Find[*model.User](ctx, u.coll, userFilter, options.Find().SetProjection(bson.M{"_id": 0, "user_id": 1}))
if err != nil {
log.ZError(ctx, "SearchUsersByFields Find failed in user collection", err, "filter", userFilter)
return nil, err
}
log.ZInfo(ctx, "SearchUsersByFields Find result from user", "userCount", len(users))
for i, user := range users {
log.ZDebug(ctx, "SearchUsersByFields processing user doc", "index", i, "userID", user.UserID)
if user.UserID != "" && !userIDMap[user.UserID] {
userIDMap[user.UserID] = true
log.ZDebug(ctx, "SearchUsersByFields added userID from user", "userID", user.UserID)
}
}
}
	// Convert the map into a slice
userIDs = make([]string, 0, len(userIDMap))
for userID := range userIDMap {
userIDs = append(userIDs, userID)
}
log.ZInfo(ctx, "SearchUsersByFields FINAL result", "totalUserIDs", len(userIDs), "userIDs", userIDs)
return userIDs, nil
}
func (u *UserMgo) CountTotal(ctx context.Context, before *time.Time) (count int64, err error) {
if before == nil {
return mongoutil.Count(ctx, u.coll, bson.M{})
}
return mongoutil.Count(ctx, u.coll, bson.M{"create_time": bson.M{"$lt": before}})
}
func (u *UserMgo) AddUserCommand(ctx context.Context, userID string, Type int32, UUID string, value string, ex string) error {
collection := u.coll.Database().Collection("userCommands")
// Create a new document instead of updating an existing one
doc := bson.M{
"userID": userID,
"type": Type,
"uuid": UUID,
"createTime": time.Now().Unix(), // assuming you want the creation time in Unix timestamp
"value": value,
"ex": ex,
}
_, err := collection.InsertOne(ctx, doc)
return errs.Wrap(err)
}
func (u *UserMgo) DeleteUserCommand(ctx context.Context, userID string, Type int32, UUID string) error {
collection := u.coll.Database().Collection("userCommands")
filter := bson.M{"userID": userID, "type": Type, "uuid": UUID}
result, err := collection.DeleteOne(ctx, filter)
// when err is not nil, result might be nil
if err != nil {
return errs.Wrap(err)
}
if result.DeletedCount == 0 {
		// No records found to delete
return errs.Wrap(errs.ErrRecordNotFound)
}
return errs.Wrap(err)
}
func (u *UserMgo) UpdateUserCommand(ctx context.Context, userID string, Type int32, UUID string, val map[string]any) error {
if len(val) == 0 {
return nil
}
collection := u.coll.Database().Collection("userCommands")
filter := bson.M{"userID": userID, "type": Type, "uuid": UUID}
update := bson.M{"$set": val}
result, err := collection.UpdateOne(ctx, filter, update)
if err != nil {
return errs.Wrap(err)
}
if result.MatchedCount == 0 {
// No records found to update
return errs.Wrap(errs.ErrRecordNotFound)
}
return nil
}
func (u *UserMgo) GetUserCommand(ctx context.Context, userID string, Type int32) ([]*user.CommandInfoResp, error) {
collection := u.coll.Database().Collection("userCommands")
filter := bson.M{"userID": userID, "type": Type}
cursor, err := collection.Find(ctx, filter)
if err != nil {
return nil, err
}
defer cursor.Close(ctx)
// Initialize commands as a slice of pointers
commands := []*user.CommandInfoResp{}
for cursor.Next(ctx) {
var document struct {
Type int32 `bson:"type"`
UUID string `bson:"uuid"`
Value string `bson:"value"`
CreateTime int64 `bson:"createTime"`
Ex string `bson:"ex"`
}
if err := cursor.Decode(&document); err != nil {
return nil, err
}
commandInfo := &user.CommandInfoResp{
Type: document.Type,
Uuid: document.UUID,
Value: document.Value,
CreateTime: document.CreateTime,
Ex: document.Ex,
}
commands = append(commands, commandInfo)
}
if err := cursor.Err(); err != nil {
return nil, errs.Wrap(err)
}
return commands, nil
}
func (u *UserMgo) GetAllUserCommand(ctx context.Context, userID string) ([]*user.AllCommandInfoResp, error) {
collection := u.coll.Database().Collection("userCommands")
filter := bson.M{"userID": userID}
cursor, err := collection.Find(ctx, filter)
if err != nil {
return nil, errs.Wrap(err)
}
defer cursor.Close(ctx)
// Initialize commands as a slice of pointers
commands := []*user.AllCommandInfoResp{}
for cursor.Next(ctx) {
var document struct {
Type int32 `bson:"type"`
UUID string `bson:"uuid"`
Value string `bson:"value"`
CreateTime int64 `bson:"createTime"`
Ex string `bson:"ex"`
}
if err := cursor.Decode(&document); err != nil {
return nil, errs.Wrap(err)
}
commandInfo := &user.AllCommandInfoResp{
Type: document.Type,
Uuid: document.UUID,
Value: document.Value,
CreateTime: document.CreateTime,
Ex: document.Ex,
}
commands = append(commands, commandInfo)
}
if err := cursor.Err(); err != nil {
return nil, errs.Wrap(err)
}
return commands, nil
}
func (u *UserMgo) CountRangeEverydayTotal(ctx context.Context, start time.Time, end time.Time) (map[string]int64, error) {
pipeline := bson.A{
bson.M{
"$match": bson.M{
"create_time": bson.M{
"$gte": start,
"$lt": end,
},
},
},
bson.M{
"$group": bson.M{
"_id": bson.M{
"$dateToString": bson.M{
"format": "%Y-%m-%d",
"date": "$create_time",
},
},
"count": bson.M{
"$sum": 1,
},
},
},
}
type Item struct {
Date string `bson:"_id"`
Count int64 `bson:"count"`
}
items, err := mongoutil.Aggregate[Item](ctx, u.coll, pipeline)
if err != nil {
return nil, err
}
res := make(map[string]int64, len(items))
for _, item := range items {
res[item.Date] = item.Count
}
return res, nil
}
func (u *UserMgo) SortQuery(ctx context.Context, userIDName map[string]string, asc bool) ([]*model.User, error) {
if len(userIDName) == 0 {
return nil, nil
}
userIDs := make([]string, 0, len(userIDName))
attached := make(map[string]string)
for userID, name := range userIDName {
userIDs = append(userIDs, userID)
if name == "" {
continue
}
attached[userID] = name
}
var sortValue int
if asc {
sortValue = 1
} else {
sortValue = -1
}
if len(attached) == 0 {
filter := bson.M{"user_id": bson.M{"$in": userIDs}}
opt := options.Find().SetSort(bson.M{"nickname": sortValue})
return mongoutil.Find[*model.User](ctx, u.coll, filter, opt)
}
pipeline := []bson.M{
{
"$match": bson.M{
"user_id": bson.M{"$in": userIDs},
},
},
{
"$addFields": bson.M{
"_query_sort_name": bson.M{
"$arrayElemAt": []any{
bson.M{
"$filter": bson.M{
"input": bson.M{
"$objectToArray": attached,
},
"as": "item",
"cond": bson.M{
"$eq": []any{"$$item.k", "$user_id"},
},
},
},
0,
},
},
},
},
{
"$addFields": bson.M{
"_query_sort_name": bson.M{
"$ifNull": []any{"$_query_sort_name.v", "$nickname"},
},
},
},
{
"$sort": bson.M{
"_query_sort_name": sortValue,
},
},
}
return mongoutil.Aggregate[*model.User](ctx, u.coll, pipeline)
}
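// exampleSortByDisplayName is an illustrative sketch (not part of the committed file):
// SortQuery sorts by the attached display name when one is supplied and falls back to the
// stored nickname otherwise; the user IDs and names below are assumptions.
func exampleSortByDisplayName(ctx context.Context, u *UserMgo) ([]*model.User, error) {
	return u.SortQuery(ctx, map[string]string{
		"user_a": "Zoe (group remark)", // sorted by this attached name
		"user_b": "",                   // empty: sorted by the stored nickname
	}, true)
}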

View File

@@ -0,0 +1,304 @@
package mgo
import (
"context"
"errors"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/versionctx"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func NewVersionLog(coll *mongo.Collection) (database.VersionLog, error) {
lm := &VersionLogMgo{coll: coll}
if err := lm.initIndex(context.Background()); err != nil {
return nil, errs.WrapMsg(err, "init version log index failed", "coll", coll.Name())
}
return lm, nil
}
type VersionLogMgo struct {
coll *mongo.Collection
}
func (l *VersionLogMgo) initIndex(ctx context.Context) error {
_, err := l.coll.Indexes().CreateOne(ctx, mongo.IndexModel{
Keys: bson.M{
"d_id": 1,
},
Options: options.Index().SetUnique(true),
})
return err
}
func (l *VersionLogMgo) IncrVersion(ctx context.Context, dId string, eIds []string, state int32) error {
_, err := l.IncrVersionResult(ctx, dId, eIds, state)
return err
}
func (l *VersionLogMgo) IncrVersionResult(ctx context.Context, dId string, eIds []string, state int32) (*model.VersionLog, error) {
vl, err := l.incrVersionResult(ctx, dId, eIds, state)
if err != nil {
return nil, err
}
versionctx.GetVersionLog(ctx).Append(versionctx.Collection{
Name: l.coll.Name(),
Doc: vl,
})
return vl, nil
}
func (l *VersionLogMgo) incrVersionResult(ctx context.Context, dId string, eIds []string, state int32) (*model.VersionLog, error) {
if len(eIds) == 0 {
return nil, errs.ErrArgs.WrapMsg("elem id is empty", "dId", dId)
}
now := time.Now()
if res, err := l.writeLogBatch2(ctx, dId, eIds, state, now); err == nil {
return res, nil
} else if !errors.Is(err, mongo.ErrNoDocuments) {
return nil, err
}
if res, err := l.initDoc(ctx, dId, eIds, state, now); err == nil {
return res, nil
} else if !mongo.IsDuplicateKeyError(err) {
return nil, err
}
return l.writeLogBatch2(ctx, dId, eIds, state, now)
}
func (l *VersionLogMgo) initDoc(ctx context.Context, dId string, eIds []string, state int32, now time.Time) (*model.VersionLog, error) {
wl := model.VersionLogTable{
ID: primitive.NewObjectID(),
DID: dId,
Logs: make([]model.VersionLogElem, 0, len(eIds)),
Version: database.FirstVersion,
Deleted: database.DefaultDeleteVersion,
LastUpdate: now,
}
for _, eId := range eIds {
wl.Logs = append(wl.Logs, model.VersionLogElem{
EID: eId,
State: state,
Version: database.FirstVersion,
LastUpdate: now,
})
}
if _, err := l.coll.InsertOne(ctx, &wl); err != nil {
return nil, err
}
return wl.VersionLog(), nil
}
func (l *VersionLogMgo) writeLogBatch2(ctx context.Context, dId string, eIds []string, state int32, now time.Time) (*model.VersionLog, error) {
if eIds == nil {
eIds = []string{}
}
filter := bson.M{
"d_id": dId,
}
elems := make([]bson.M, 0, len(eIds))
for _, eId := range eIds {
elems = append(elems, bson.M{
"e_id": eId,
"version": "$version",
"state": state,
"last_update": now,
})
}
pipeline := []bson.M{
{
"$addFields": bson.M{
"delete_e_ids": eIds,
},
},
{
"$set": bson.M{
"version": bson.M{"$add": []any{"$version", 1}},
"last_update": now,
},
},
{
"$set": bson.M{
"logs": bson.M{
"$filter": bson.M{
"input": "$logs",
"as": "log",
"cond": bson.M{
"$not": bson.M{
"$in": []any{"$$log.e_id", "$delete_e_ids"},
},
},
},
},
},
},
{
"$set": bson.M{
"logs": bson.M{
"$concatArrays": []any{
"$logs",
elems,
},
},
},
},
{
"$unset": "delete_e_ids",
},
}
projection := bson.M{
"logs": 0,
}
opt := options.FindOneAndUpdate().SetUpsert(false).SetReturnDocument(options.After).SetProjection(projection)
res, err := mongoutil.FindOneAndUpdate[*model.VersionLog](ctx, l.coll, filter, pipeline, opt)
if err != nil {
return nil, err
}
res.Logs = make([]model.VersionLogElem, 0, len(eIds))
for _, id := range eIds {
res.Logs = append(res.Logs, model.VersionLogElem{
EID: id,
State: state,
Version: res.Version,
LastUpdate: res.LastUpdate,
})
}
return res, nil
}
func (l *VersionLogMgo) findDoc(ctx context.Context, dId string) (*model.VersionLog, error) {
vl, err := mongoutil.FindOne[*model.VersionLogTable](ctx, l.coll, bson.M{"d_id": dId}, options.FindOne().SetProjection(bson.M{"logs": 0}))
if err != nil {
return nil, err
}
return vl.VersionLog(), nil
}
func (l *VersionLogMgo) FindChangeLog(ctx context.Context, dId string, version uint, limit int) (*model.VersionLog, error) {
if wl, err := l.findChangeLog(ctx, dId, version, limit); err == nil {
return wl, nil
} else if !errors.Is(err, mongo.ErrNoDocuments) {
return nil, err
}
log.ZDebug(ctx, "init doc", "dId", dId)
if res, err := l.initDoc(ctx, dId, nil, 0, time.Now()); err == nil {
log.ZDebug(ctx, "init doc success", "dId", dId)
return res, nil
} else if mongo.IsDuplicateKeyError(err) {
return l.findChangeLog(ctx, dId, version, limit)
} else {
return nil, err
}
}
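// exampleChangedElemIDs is an illustrative sketch (not part of the committed file): a
// consumer asks for everything that changed after the version it already holds; the helper
// name and the page limit of 200 are assumptions.
func exampleChangedElemIDs(ctx context.Context, l *VersionLogMgo, dId string, clientVersion uint) ([]string, error) {
	vl, err := l.FindChangeLog(ctx, dId, clientVersion, 200)
	if err != nil {
		return nil, err
	}
	changed := make([]string, 0, len(vl.Logs))
	for _, elem := range vl.Logs {
		changed = append(changed, elem.EID)
	}
	return changed, nil
}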
func (l *VersionLogMgo) BatchFindChangeLog(ctx context.Context, dIds []string, versions []uint, limits []int) (vLogs []*model.VersionLog, err error) {
for i := 0; i < len(dIds); i++ {
		if vLog, err := l.findChangeLog(ctx, dIds[i], versions[i], limits[i]); err == nil {
			vLogs = append(vLogs, vLog)
			continue
		} else if !errors.Is(err, mongo.ErrNoDocuments) {
			log.ZError(ctx, "findChangeLog error:", errs.Wrap(err))
		}
log.ZDebug(ctx, "init doc", "dId", dIds[i])
if res, err := l.initDoc(ctx, dIds[i], nil, 0, time.Now()); err == nil {
log.ZDebug(ctx, "init doc success", "dId", dIds[i])
vLogs = append(vLogs, res)
		} else if mongo.IsDuplicateKeyError(err) {
			// The doc was created concurrently; re-read the change log and keep the result.
			if vLog, err := l.findChangeLog(ctx, dIds[i], versions[i], limits[i]); err == nil {
				vLogs = append(vLogs, vLog)
			}
} else {
log.ZError(ctx, "init doc error:", errs.Wrap(err))
}
}
return vLogs, errs.Wrap(err)
}
func (l *VersionLogMgo) findChangeLog(ctx context.Context, dId string, version uint, limit int) (*model.VersionLog, error) {
if version == 0 && limit == 0 {
return l.findDoc(ctx, dId)
}
pipeline := []bson.M{
{
"$match": bson.M{
"d_id": dId,
},
},
{
"$addFields": bson.M{
"logs": bson.M{
"$cond": bson.M{
"if": bson.M{
"$or": []bson.M{
{"$lt": []any{"$version", version}},
{"$gte": []any{"$deleted", version}},
},
},
"then": []any{},
"else": "$logs",
},
},
},
},
{
"$addFields": bson.M{
"logs": bson.M{
"$filter": bson.M{
"input": "$logs",
"as": "l",
"cond": bson.M{
"$gt": []any{"$$l.version", version},
},
},
},
},
},
{
"$addFields": bson.M{
"log_len": bson.M{"$size": "$logs"},
},
},
{
"$addFields": bson.M{
"logs": bson.M{
"$cond": bson.M{
"if": bson.M{
"$gt": []any{"$log_len", limit},
},
"then": []any{},
"else": "$logs",
},
},
},
},
}
if limit <= 0 {
pipeline = pipeline[:len(pipeline)-1]
}
vl, err := mongoutil.Aggregate[*model.VersionLog](ctx, l.coll, pipeline)
if err != nil {
return nil, err
}
if len(vl) == 0 {
return nil, mongo.ErrNoDocuments
}
return vl[0], nil
}
func (l *VersionLogMgo) DeleteAfterUnchangedLog(ctx context.Context, deadline time.Time) error {
return mongoutil.DeleteMany(ctx, l.coll, bson.M{
"last_update": bson.M{
"$lt": deadline,
},
})
}
func (l *VersionLogMgo) Delete(ctx context.Context, dId string) error {
return mongoutil.DeleteOne(ctx, l.coll, bson.M{"d_id": dId})
}

View File

@@ -0,0 +1,40 @@
package mgo
import (
"context"
"testing"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
//func Result[V any](val V, err error) V {
// if err != nil {
// panic(err)
// }
// return val
//}
func Check(err error) {
if err != nil {
panic(err)
}
}
func TestName(t *testing.T) {
cli := Result(mongo.Connect(context.Background(), options.Client().ApplyURI("mongodb://openIM:openIM123@172.16.8.48:37017/openim_v3?maxPoolSize=100").SetConnectTimeout(5*time.Second)))
coll := cli.Database("openim_v3").Collection("version_test")
tmp, err := NewVersionLog(coll)
if err != nil {
panic(err)
}
vl := tmp.(*VersionLogMgo)
res, err := vl.incrVersionResult(context.Background(), "100", []string{"1000", "1001", "1003"}, model.VersionStateInsert)
if err != nil {
t.Log(err)
return
}
t.Logf("%+v", res)
}

View File

@@ -0,0 +1,231 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mgo
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/db/pagination"
"github.com/openimsdk/tools/errs"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// WalletMgo implements Wallet using MongoDB as the storage backend.
type WalletMgo struct {
coll *mongo.Collection
}
// NewWalletMongo creates a new instance of WalletMgo with the provided MongoDB database.
func NewWalletMongo(db *mongo.Database) (database.Wallet, error) {
coll := db.Collection(database.WalletName)
_, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{
{
Keys: bson.D{{Key: "user_id", Value: 1}},
Options: options.Index().SetUnique(true),
},
{
Keys: bson.D{{Key: "create_time", Value: -1}},
},
{
Keys: bson.D{{Key: "update_time", Value: -1}},
},
})
if err != nil {
return nil, err
}
return &WalletMgo{coll: coll}, nil
}
// Create creates a new wallet record.
func (w *WalletMgo) Create(ctx context.Context, wallet *model.Wallet) error {
if wallet.CreateTime.IsZero() {
wallet.CreateTime = time.Now()
}
if wallet.UpdateTime.IsZero() {
wallet.UpdateTime = time.Now()
}
if wallet.Version == 0 {
wallet.Version = 1
}
return mongoutil.InsertOne(ctx, w.coll, wallet)
}
// Take retrieves a wallet by user ID. Returns an error if not found.
func (w *WalletMgo) Take(ctx context.Context, userID string) (*model.Wallet, error) {
return mongoutil.FindOne[*model.Wallet](ctx, w.coll, bson.M{"user_id": userID})
}
// UpdateBalance updates the balance of a wallet.
func (w *WalletMgo) UpdateBalance(ctx context.Context, userID string, balance int64) error {
update := bson.M{
"$set": bson.M{
"balance": balance,
"update_time": time.Now(),
},
}
return mongoutil.UpdateOne(ctx, w.coll, bson.M{"user_id": userID}, update, false)
}
// UpdateBalanceByAmount updates the balance by adding/subtracting an amount.
func (w *WalletMgo) UpdateBalanceByAmount(ctx context.Context, userID string, amount int64) error {
update := bson.M{
"$inc": bson.M{
"balance": amount,
},
"$set": bson.M{
"update_time": time.Now(),
},
}
return mongoutil.UpdateOne(ctx, w.coll, bson.M{"user_id": userID}, update, false)
}
// UpdateBalanceWithVersion updates the wallet balance using an optimistic version check to prevent concurrent overwrites.
func (w *WalletMgo) UpdateBalanceWithVersion(ctx context.Context, params *database.WalletUpdateParams) (*database.WalletUpdateResult, error) {
	// Relies on single-document atomic operations to avoid concurrent overwrites
if params.Amount < 0 {
return nil, errs.ErrArgs.WrapMsg("amount cannot be negative")
}
	// Stay compatible with legacy data (no version field, or version=0)
filter := bson.M{
"user_id": params.UserID,
"$or": []bson.M{
{"version": params.OldVersion},
{"version": bson.M{"$exists": false}},
{"version": 0},
},
}
	// Default update: refresh update_time and increment the version
update := bson.M{
"$set": bson.M{
"update_time": time.Now(),
},
"$inc": bson.M{
"version": 1,
},
}
switch params.Operation {
case "add":
update["$inc"].(bson.M)["balance"] = params.Amount
case "subtract":
update["$inc"].(bson.M)["balance"] = -params.Amount
		// Prevent a negative balance: the filter requires the current balance >= amount
filter["balance"] = bson.M{"$gte": params.Amount}
case "set":
		// Set the balance directly and bump the version via $set
delete(update, "$inc")
update["$set"].(bson.M)["balance"] = params.Amount
update["$set"].(bson.M)["version"] = params.OldVersion + 1
default:
return nil, errs.ErrArgs.WrapMsg("invalid operation: " + params.Operation)
}
	// "add" and "subtract" keep the version increment from the initial $inc above; "set"
	// bumps the version via $set, so re-adding $inc here would conflict on the same path.
	// Use findOneAndUpdate to return the updated document
opts := options.FindOneAndUpdate().SetReturnDocument(options.After)
var updatedWallet model.Wallet
err := w.coll.FindOneAndUpdate(ctx, filter, update, opts).Decode(&updatedWallet)
if err != nil {
if err == mongo.ErrNoDocuments {
			// Version or balance mismatch: a concurrent modification happened
return nil, errs.ErrInternalServer.WrapMsg("concurrent modification detected: version or balance mismatch")
}
return nil, err
}
return &database.WalletUpdateResult{
NewBalance: updatedWallet.Balance,
NewVersion: updatedWallet.Version,
Success: true,
}, nil
}
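// exampleDebitWithRetry is an illustrative sketch (not part of the committed file): an
// optimistic-concurrency retry loop a caller might wrap around UpdateBalanceWithVersion.
// The helper name and the retry count of three are assumptions, and OldVersion is assumed
// to share wallet.Version's numeric type.
func exampleDebitWithRetry(ctx context.Context, w *WalletMgo, userID string, amount int64) (*database.WalletUpdateResult, error) {
	var lastErr error
	for attempt := 0; attempt < 3; attempt++ {
		wallet, err := w.Take(ctx, userID)
		if err != nil {
			return nil, err
		}
		res, err := w.UpdateBalanceWithVersion(ctx, &database.WalletUpdateParams{
			UserID:     userID,
			Amount:     amount,
			Operation:  "subtract",
			OldVersion: wallet.Version,
		})
		if err == nil {
			return res, nil
		}
		lastErr = err // version or balance mismatch: reload the wallet and retry
	}
	return nil, lastErr
}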
// FindAllWallets finds all wallets with pagination.
func (w *WalletMgo) FindAllWallets(ctx context.Context, pagination pagination.Pagination) (total int64, wallets []*model.Wallet, err error) {
return mongoutil.FindPage[*model.Wallet](ctx, w.coll, bson.M{}, pagination, &options.FindOptions{
Sort: bson.D{{Key: "create_time", Value: -1}},
})
}
// FindWalletsByUserIDs finds wallets by user IDs.
func (w *WalletMgo) FindWalletsByUserIDs(ctx context.Context, userIDs []string) ([]*model.Wallet, error) {
if len(userIDs) == 0 {
return []*model.Wallet{}, nil
}
filter := bson.M{"user_id": bson.M{"$in": userIDs}}
return mongoutil.Find[*model.Wallet](ctx, w.coll, filter)
}
// WalletBalanceRecordMgo implements WalletBalanceRecord using MongoDB as the storage backend.
type WalletBalanceRecordMgo struct {
coll *mongo.Collection
}
// NewWalletBalanceRecordMongo creates a new instance of WalletBalanceRecordMgo with the provided MongoDB database.
func NewWalletBalanceRecordMongo(db *mongo.Database) (database.WalletBalanceRecord, error) {
coll := db.Collection(database.WalletBalanceRecordName)
	// Drop any pre-existing non-sparse record_id index first;
	// ignore drop errors (the index may not exist).
_, _ = coll.Indexes().DropOne(context.Background(), "record_id_1")
	// Create indexes; the sparse unique index allows record_id to be null
_, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{
{
Keys: bson.D{{Key: "user_id", Value: 1}, {Key: "create_time", Value: -1}},
},
{
Keys: bson.D{{Key: "record_id", Value: 1}},
Options: options.Index().SetUnique(true).SetSparse(true),
},
{
Keys: bson.D{{Key: "create_time", Value: -1}},
},
})
if err != nil {
return nil, err
}
return &WalletBalanceRecordMgo{coll: coll}, nil
}
// Create creates a new wallet balance record.
func (w *WalletBalanceRecordMgo) Create(ctx context.Context, record *model.WalletBalanceRecord) error {
if record.CreateTime.IsZero() {
record.CreateTime = time.Now()
}
return mongoutil.InsertOne(ctx, w.coll, record)
}
// FindByUserID finds all balance records for a user with pagination.
func (w *WalletBalanceRecordMgo) FindByUserID(ctx context.Context, userID string, pagination pagination.Pagination) (total int64, records []*model.WalletBalanceRecord, err error) {
filter := bson.M{"user_id": userID}
return mongoutil.FindPage[*model.WalletBalanceRecord](ctx, w.coll, filter, pagination, &options.FindOptions{
Sort: bson.D{{Key: "create_time", Value: -1}},
})
}

Some files were not shown because too many files have changed in this diff.