复制项目

This commit is contained in:
kim.dev.6789
2026-01-14 22:16:44 +08:00
parent e2577b8cee
commit e50142a3b9
691 changed files with 97009 additions and 1 deletions

45
internal/api/auth.go Normal file
View File

@@ -0,0 +1,45 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"git.imall.cloud/openim/protocol/auth"
"github.com/gin-gonic/gin"
"github.com/openimsdk/tools/a2r"
)
type AuthApi struct {
Client auth.AuthClient
}
func NewAuthApi(client auth.AuthClient) AuthApi {
return AuthApi{client}
}
func (o *AuthApi) GetAdminToken(c *gin.Context) {
a2r.Call(c, auth.AuthClient.GetAdminToken, o.Client)
}
func (o *AuthApi) GetUserToken(c *gin.Context) {
a2r.Call(c, auth.AuthClient.GetUserToken, o.Client)
}
func (o *AuthApi) ParseToken(c *gin.Context) {
a2r.Call(c, auth.AuthClient.ParseToken, o.Client)
}
func (o *AuthApi) ForceLogout(c *gin.Context) {
a2r.Call(c, auth.AuthClient.ForceLogout, o.Client)
}

View File

@@ -0,0 +1,413 @@
package api
import (
"encoding/json"
"reflect"
"strconv"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/apistruct"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/discovery/etcd"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/version"
"github.com/gin-gonic/gin"
"github.com/openimsdk/tools/apiresp"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/datautil"
"github.com/openimsdk/tools/utils/runtimeenv"
clientv3 "go.etcd.io/etcd/client/v3"
)
const (
// wait for Restart http call return
waitHttp = time.Millisecond * 200
)
type ConfigManager struct {
imAdminUserID []string
config *config.AllConfig
client *clientv3.Client
configPath string
systemConfigDB database.SystemConfig
}
func NewConfigManager(IMAdminUserID []string, cfg *config.AllConfig, client *clientv3.Client, configPath string, systemConfigDB database.SystemConfig) *ConfigManager {
cm := &ConfigManager{
imAdminUserID: IMAdminUserID,
config: cfg,
client: client,
configPath: configPath,
systemConfigDB: systemConfigDB,
}
return cm
}
func (cm *ConfigManager) CheckAdmin(c *gin.Context) {
if err := authverify.CheckAdmin(c); err != nil {
apiresp.GinError(c, err)
c.Abort()
}
}
func (cm *ConfigManager) GetConfig(c *gin.Context) {
var req apistruct.GetConfigReq
if err := c.BindJSON(&req); err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
conf := cm.config.Name2Config(req.ConfigName)
if conf == nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail("config name not found").Wrap())
return
}
b, err := json.Marshal(conf)
if err != nil {
apiresp.GinError(c, err)
return
}
apiresp.GinSuccess(c, string(b))
}
func (cm *ConfigManager) GetConfigList(c *gin.Context) {
var resp apistruct.GetConfigListResp
resp.ConfigNames = cm.config.GetConfigNames()
resp.Environment = runtimeenv.RuntimeEnvironment()
resp.Version = version.Version
apiresp.GinSuccess(c, resp)
}
func (cm *ConfigManager) SetConfig(c *gin.Context) {
if cm.config.Discovery.Enable != config.ETCD {
apiresp.GinError(c, errs.New("only etcd support set config").Wrap())
return
}
var req apistruct.SetConfigReq
if err := c.BindJSON(&req); err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
var err error
switch req.ConfigName {
case cm.config.Discovery.GetConfigFileName():
err = compareAndSave[config.Discovery](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Kafka.GetConfigFileName():
err = compareAndSave[config.Kafka](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.LocalCache.GetConfigFileName():
err = compareAndSave[config.LocalCache](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Log.GetConfigFileName():
err = compareAndSave[config.Log](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Minio.GetConfigFileName():
err = compareAndSave[config.Minio](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Mongo.GetConfigFileName():
err = compareAndSave[config.Mongo](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Notification.GetConfigFileName():
err = compareAndSave[config.Notification](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.API.GetConfigFileName():
err = compareAndSave[config.API](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.CronTask.GetConfigFileName():
err = compareAndSave[config.CronTask](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.MsgGateway.GetConfigFileName():
err = compareAndSave[config.MsgGateway](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.MsgTransfer.GetConfigFileName():
err = compareAndSave[config.MsgTransfer](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Push.GetConfigFileName():
err = compareAndSave[config.Push](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Auth.GetConfigFileName():
err = compareAndSave[config.Auth](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Conversation.GetConfigFileName():
err = compareAndSave[config.Conversation](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Friend.GetConfigFileName():
err = compareAndSave[config.Friend](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Group.GetConfigFileName():
err = compareAndSave[config.Group](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Msg.GetConfigFileName():
err = compareAndSave[config.Msg](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Third.GetConfigFileName():
err = compareAndSave[config.Third](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.User.GetConfigFileName():
err = compareAndSave[config.User](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Redis.GetConfigFileName():
err = compareAndSave[config.Redis](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Share.GetConfigFileName():
err = compareAndSave[config.Share](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Webhooks.GetConfigFileName():
err = compareAndSave[config.Webhooks](c, cm.config.Name2Config(req.ConfigName), &req, cm)
default:
apiresp.GinError(c, errs.ErrArgs.Wrap())
return
}
if err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
apiresp.GinSuccess(c, nil)
}
func (cm *ConfigManager) SetConfigs(c *gin.Context) {
if cm.config.Discovery.Enable != config.ETCD {
apiresp.GinError(c, errs.New("only etcd support set config").Wrap())
return
}
var req apistruct.SetConfigsReq
if err := c.BindJSON(&req); err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
var (
err error
ops []*clientv3.Op
)
for _, cf := range req.Configs {
var op *clientv3.Op
switch cf.ConfigName {
case cm.config.Discovery.GetConfigFileName():
op, err = compareAndOp[config.Discovery](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Kafka.GetConfigFileName():
op, err = compareAndOp[config.Kafka](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.LocalCache.GetConfigFileName():
op, err = compareAndOp[config.LocalCache](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Log.GetConfigFileName():
op, err = compareAndOp[config.Log](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Minio.GetConfigFileName():
op, err = compareAndOp[config.Minio](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Mongo.GetConfigFileName():
op, err = compareAndOp[config.Mongo](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Notification.GetConfigFileName():
op, err = compareAndOp[config.Notification](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.API.GetConfigFileName():
op, err = compareAndOp[config.API](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.CronTask.GetConfigFileName():
op, err = compareAndOp[config.CronTask](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.MsgGateway.GetConfigFileName():
op, err = compareAndOp[config.MsgGateway](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.MsgTransfer.GetConfigFileName():
op, err = compareAndOp[config.MsgTransfer](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Push.GetConfigFileName():
op, err = compareAndOp[config.Push](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Auth.GetConfigFileName():
op, err = compareAndOp[config.Auth](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Conversation.GetConfigFileName():
op, err = compareAndOp[config.Conversation](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Friend.GetConfigFileName():
op, err = compareAndOp[config.Friend](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Group.GetConfigFileName():
op, err = compareAndOp[config.Group](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Msg.GetConfigFileName():
op, err = compareAndOp[config.Msg](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Third.GetConfigFileName():
op, err = compareAndOp[config.Third](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.User.GetConfigFileName():
op, err = compareAndOp[config.User](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Redis.GetConfigFileName():
op, err = compareAndOp[config.Redis](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Share.GetConfigFileName():
op, err = compareAndOp[config.Share](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Webhooks.GetConfigFileName():
op, err = compareAndOp[config.Webhooks](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
default:
apiresp.GinError(c, errs.ErrArgs.Wrap())
return
}
if err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
if op != nil {
ops = append(ops, op)
}
}
if len(ops) > 0 {
tx := cm.client.Txn(c)
if _, err = tx.Then(datautil.Batch(func(op *clientv3.Op) clientv3.Op { return *op }, ops)...).Commit(); err != nil {
apiresp.GinError(c, errs.WrapMsg(err, "save to etcd failed"))
return
}
}
apiresp.GinSuccess(c, nil)
}
func compareAndOp[T any](c *gin.Context, old any, req *apistruct.SetConfigReq, cm *ConfigManager) (*clientv3.Op, error) {
conf := new(T)
err := json.Unmarshal([]byte(req.Data), &conf)
if err != nil {
return nil, errs.ErrArgs.WithDetail(err.Error()).Wrap()
}
eq := reflect.DeepEqual(old, conf)
if eq {
return nil, nil
}
data, err := json.Marshal(conf)
if err != nil {
return nil, errs.ErrArgs.WithDetail(err.Error()).Wrap()
}
op := clientv3.OpPut(etcd.BuildKey(req.ConfigName), string(data))
return &op, nil
}
func compareAndSave[T any](c *gin.Context, old any, req *apistruct.SetConfigReq, cm *ConfigManager) error {
conf := new(T)
err := json.Unmarshal([]byte(req.Data), &conf)
if err != nil {
return errs.ErrArgs.WithDetail(err.Error()).Wrap()
}
eq := reflect.DeepEqual(old, conf)
if eq {
return nil
}
data, err := json.Marshal(conf)
if err != nil {
return errs.ErrArgs.WithDetail(err.Error()).Wrap()
}
_, err = cm.client.Put(c, etcd.BuildKey(req.ConfigName), string(data))
if err != nil {
return errs.WrapMsg(err, "save to etcd failed")
}
return nil
}
func (cm *ConfigManager) ResetConfig(c *gin.Context) {
go func() {
if err := cm.resetConfig(c, true); err != nil {
log.ZError(c, "reset config err", err)
}
}()
apiresp.GinSuccess(c, nil)
}
func (cm *ConfigManager) resetConfig(c *gin.Context, checkChange bool, ops ...clientv3.Op) error {
txn := cm.client.Txn(c)
type initConf struct {
old any
new any
}
configMap := map[string]*initConf{
cm.config.Discovery.GetConfigFileName(): {old: &cm.config.Discovery, new: new(config.Discovery)},
cm.config.Kafka.GetConfigFileName(): {old: &cm.config.Kafka, new: new(config.Kafka)},
cm.config.LocalCache.GetConfigFileName(): {old: &cm.config.LocalCache, new: new(config.LocalCache)},
cm.config.Log.GetConfigFileName(): {old: &cm.config.Log, new: new(config.Log)},
cm.config.Minio.GetConfigFileName(): {old: &cm.config.Minio, new: new(config.Minio)},
cm.config.Mongo.GetConfigFileName(): {old: &cm.config.Mongo, new: new(config.Mongo)},
cm.config.Notification.GetConfigFileName(): {old: &cm.config.Notification, new: new(config.Notification)},
cm.config.API.GetConfigFileName(): {old: &cm.config.API, new: new(config.API)},
cm.config.CronTask.GetConfigFileName(): {old: &cm.config.CronTask, new: new(config.CronTask)},
cm.config.MsgGateway.GetConfigFileName(): {old: &cm.config.MsgGateway, new: new(config.MsgGateway)},
cm.config.MsgTransfer.GetConfigFileName(): {old: &cm.config.MsgTransfer, new: new(config.MsgTransfer)},
cm.config.Push.GetConfigFileName(): {old: &cm.config.Push, new: new(config.Push)},
cm.config.Auth.GetConfigFileName(): {old: &cm.config.Auth, new: new(config.Auth)},
cm.config.Conversation.GetConfigFileName(): {old: &cm.config.Conversation, new: new(config.Conversation)},
cm.config.Friend.GetConfigFileName(): {old: &cm.config.Friend, new: new(config.Friend)},
cm.config.Group.GetConfigFileName(): {old: &cm.config.Group, new: new(config.Group)},
cm.config.Msg.GetConfigFileName(): {old: &cm.config.Msg, new: new(config.Msg)},
cm.config.Third.GetConfigFileName(): {old: &cm.config.Third, new: new(config.Third)},
cm.config.User.GetConfigFileName(): {old: &cm.config.User, new: new(config.User)},
cm.config.Redis.GetConfigFileName(): {old: &cm.config.Redis, new: new(config.Redis)},
cm.config.Share.GetConfigFileName(): {old: &cm.config.Share, new: new(config.Share)},
cm.config.Webhooks.GetConfigFileName(): {old: &cm.config.Webhooks, new: new(config.Webhooks)},
}
changedKeys := make([]string, 0, len(configMap))
for k, v := range configMap {
err := config.Load(cm.configPath, k, config.EnvPrefixMap[k], v.new)
if err != nil {
log.ZError(c, "load config failed", err)
continue
}
equal := reflect.DeepEqual(v.old, v.new)
if !checkChange || !equal {
changedKeys = append(changedKeys, k)
}
}
for _, k := range changedKeys {
data, err := json.Marshal(configMap[k].new)
if err != nil {
log.ZError(c, "marshal config failed", err)
continue
}
ops = append(ops, clientv3.OpPut(etcd.BuildKey(k), string(data)))
}
if len(ops) > 0 {
txn.Then(ops...)
_, err := txn.Commit()
if err != nil {
return errs.WrapMsg(err, "commit etcd txn failed")
}
}
return nil
}
func (cm *ConfigManager) Restart(c *gin.Context) {
go cm.restart(c)
apiresp.GinSuccess(c, nil)
}
func (cm *ConfigManager) restart(c *gin.Context) {
time.Sleep(waitHttp) // wait for Restart http call return
t := time.Now().Unix()
_, err := cm.client.Put(c, etcd.BuildKey(etcd.RestartKey), strconv.Itoa(int(t)))
if err != nil {
log.ZError(c, "restart etcd put key failed", err)
}
}
func (cm *ConfigManager) SetEnableConfigManager(c *gin.Context) {
if cm.config.Discovery.Enable != config.ETCD {
apiresp.GinError(c, errs.New("only etcd support config manager").Wrap())
return
}
var req apistruct.SetEnableConfigManagerReq
if err := c.BindJSON(&req); err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
var enableStr string
if req.Enable {
enableStr = etcd.Enable
} else {
enableStr = etcd.Disable
}
resp, err := cm.client.Get(c, etcd.BuildKey(etcd.EnableConfigCenterKey))
if err != nil {
apiresp.GinError(c, errs.WrapMsg(err, "getEnableConfigManager failed"))
return
}
if !(resp.Count > 0 && string(resp.Kvs[0].Value) == etcd.Enable) && req.Enable {
go func() {
time.Sleep(waitHttp) // wait for Restart http call return
err := cm.resetConfig(c, false, clientv3.OpPut(etcd.BuildKey(etcd.EnableConfigCenterKey), enableStr))
if err != nil {
log.ZError(c, "resetConfig failed", err)
}
}()
} else {
_, err = cm.client.Put(c, etcd.BuildKey(etcd.EnableConfigCenterKey), enableStr)
if err != nil {
apiresp.GinError(c, errs.WrapMsg(err, "setEnableConfigManager failed"))
return
}
}
apiresp.GinSuccess(c, nil)
}
func (cm *ConfigManager) GetEnableConfigManager(c *gin.Context) {
resp, err := cm.client.Get(c, etcd.BuildKey(etcd.EnableConfigCenterKey))
if err != nil {
apiresp.GinError(c, errs.WrapMsg(err, "getEnableConfigManager failed"))
return
}
var enable bool
if resp.Count > 0 && string(resp.Kvs[0].Value) == etcd.Enable {
enable = true
}
apiresp.GinSuccess(c, &apistruct.GetEnableConfigManagerResp{Enable: enable})
}

View File

@@ -0,0 +1,82 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"github.com/gin-gonic/gin"
"git.imall.cloud/openim/protocol/conversation"
"github.com/openimsdk/tools/a2r"
)
type ConversationApi struct {
Client conversation.ConversationClient
}
func NewConversationApi(client conversation.ConversationClient) ConversationApi {
return ConversationApi{client}
}
func (o *ConversationApi) GetAllConversations(c *gin.Context) {
a2r.Call(c, conversation.ConversationClient.GetAllConversations, o.Client)
}
func (o *ConversationApi) GetSortedConversationList(c *gin.Context) {
a2r.Call(c, conversation.ConversationClient.GetSortedConversationList, o.Client)
}
func (o *ConversationApi) GetConversation(c *gin.Context) {
a2r.Call(c, conversation.ConversationClient.GetConversation, o.Client)
}
func (o *ConversationApi) GetConversations(c *gin.Context) {
a2r.Call(c, conversation.ConversationClient.GetConversations, o.Client)
}
func (o *ConversationApi) SetConversations(c *gin.Context) {
a2r.Call(c, conversation.ConversationClient.SetConversations, o.Client)
}
//func (o *ConversationApi) GetConversationOfflinePushUserIDs(c *gin.Context) {
// a2r.Call(c, conversation.ConversationClient.GetConversationOfflinePushUserIDs, o.Client)
//}
func (o *ConversationApi) GetFullOwnerConversationIDs(c *gin.Context) {
a2r.Call(c, conversation.ConversationClient.GetFullOwnerConversationIDs, o.Client)
}
func (o *ConversationApi) GetIncrementalConversation(c *gin.Context) {
a2r.Call(c, conversation.ConversationClient.GetIncrementalConversation, o.Client)
}
func (o *ConversationApi) GetOwnerConversation(c *gin.Context) {
a2r.Call(c, conversation.ConversationClient.GetOwnerConversation, o.Client)
}
func (o *ConversationApi) GetNotNotifyConversationIDs(c *gin.Context) {
a2r.Call(c, conversation.ConversationClient.GetNotNotifyConversationIDs, o.Client)
}
func (o *ConversationApi) GetPinnedConversationIDs(c *gin.Context) {
a2r.Call(c, conversation.ConversationClient.GetPinnedConversationIDs, o.Client)
}
func (o *ConversationApi) UpdateConversationsByUser(c *gin.Context) {
a2r.Call(c, conversation.ConversationClient.UpdateConversationsByUser, o.Client)
}
func (o *ConversationApi) DeleteConversations(c *gin.Context) {
a2r.Call(c, conversation.ConversationClient.DeleteConversations, o.Client)
}

View File

@@ -0,0 +1,34 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"git.imall.cloud/openim/protocol/constant"
"github.com/go-playground/validator/v10"
)
// RequiredIf validates if the specified field is required based on the session type.
func RequiredIf(fl validator.FieldLevel) bool {
sessionType := fl.Parent().FieldByName("SessionType").Int()
switch sessionType {
case constant.SingleChatType, constant.NotificationChatType:
return fl.FieldName() != "RecvID" || fl.Field().String() != ""
case constant.WriteGroupChatType, constant.ReadGroupChatType:
return fl.FieldName() != "GroupID" || fl.Field().String() != ""
default:
return true
}
}

120
internal/api/friend.go Normal file
View File

@@ -0,0 +1,120 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"github.com/gin-gonic/gin"
"git.imall.cloud/openim/protocol/relation"
"github.com/openimsdk/tools/a2r"
)
type FriendApi struct {
Client relation.FriendClient
}
func NewFriendApi(client relation.FriendClient) FriendApi {
return FriendApi{client}
}
func (o *FriendApi) ApplyToAddFriend(c *gin.Context) {
a2r.Call(c, relation.FriendClient.ApplyToAddFriend, o.Client)
}
func (o *FriendApi) RespondFriendApply(c *gin.Context) {
a2r.Call(c, relation.FriendClient.RespondFriendApply, o.Client)
}
func (o *FriendApi) DeleteFriend(c *gin.Context) {
a2r.Call(c, relation.FriendClient.DeleteFriend, o.Client)
}
func (o *FriendApi) GetFriendApplyList(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetPaginationFriendsApplyTo, o.Client)
}
func (o *FriendApi) GetDesignatedFriendsApply(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetDesignatedFriendsApply, o.Client)
}
func (o *FriendApi) GetSelfApplyList(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetPaginationFriendsApplyFrom, o.Client)
}
func (o *FriendApi) GetFriendList(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetPaginationFriends, o.Client)
}
func (o *FriendApi) GetDesignatedFriends(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetDesignatedFriends, o.Client)
}
func (o *FriendApi) SetFriendRemark(c *gin.Context) {
a2r.Call(c, relation.FriendClient.SetFriendRemark, o.Client)
}
func (o *FriendApi) AddBlack(c *gin.Context) {
a2r.Call(c, relation.FriendClient.AddBlack, o.Client)
}
func (o *FriendApi) GetPaginationBlacks(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetPaginationBlacks, o.Client)
}
func (o *FriendApi) GetSpecifiedBlacks(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetSpecifiedBlacks, o.Client)
}
func (o *FriendApi) RemoveBlack(c *gin.Context) {
a2r.Call(c, relation.FriendClient.RemoveBlack, o.Client)
}
func (o *FriendApi) ImportFriends(c *gin.Context) {
a2r.Call(c, relation.FriendClient.ImportFriends, o.Client)
}
func (o *FriendApi) IsFriend(c *gin.Context) {
a2r.Call(c, relation.FriendClient.IsFriend, o.Client)
}
func (o *FriendApi) GetFriendIDs(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetFriendIDs, o.Client)
}
func (o *FriendApi) GetSpecifiedFriendsInfo(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetSpecifiedFriendsInfo, o.Client)
}
func (o *FriendApi) UpdateFriends(c *gin.Context) {
a2r.Call(c, relation.FriendClient.UpdateFriends, o.Client)
}
func (o *FriendApi) GetIncrementalFriends(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetIncrementalFriends, o.Client)
}
// GetIncrementalBlacks is temporarily unused.
// Deprecated: This function is currently unused and may be removed in future versions.
func (o *FriendApi) GetIncrementalBlacks(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetIncrementalBlacks, o.Client)
}
func (o *FriendApi) GetFullFriendUserIDs(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetFullFriendUserIDs, o.Client)
}
func (o *FriendApi) GetSelfUnhandledApplyCount(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetSelfUnhandledApplyCount, o.Client)
}

171
internal/api/group.go Normal file
View File

@@ -0,0 +1,171 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"git.imall.cloud/openim/protocol/group"
"github.com/gin-gonic/gin"
"github.com/openimsdk/tools/a2r"
)
type GroupApi struct {
Client group.GroupClient
}
func NewGroupApi(client group.GroupClient) GroupApi {
return GroupApi{client}
}
func (o *GroupApi) CreateGroup(c *gin.Context) {
a2r.Call(c, group.GroupClient.CreateGroup, o.Client)
}
func (o *GroupApi) SetGroupInfo(c *gin.Context) {
a2r.Call(c, group.GroupClient.SetGroupInfo, o.Client)
}
func (o *GroupApi) SetGroupInfoEx(c *gin.Context) {
a2r.Call(c, group.GroupClient.SetGroupInfoEx, o.Client)
}
func (o *GroupApi) JoinGroup(c *gin.Context) {
a2r.Call(c, group.GroupClient.JoinGroup, o.Client)
}
func (o *GroupApi) QuitGroup(c *gin.Context) {
a2r.Call(c, group.GroupClient.QuitGroup, o.Client)
}
func (o *GroupApi) ApplicationGroupResponse(c *gin.Context) {
a2r.Call(c, group.GroupClient.GroupApplicationResponse, o.Client)
}
func (o *GroupApi) TransferGroupOwner(c *gin.Context) {
a2r.Call(c, group.GroupClient.TransferGroupOwner, o.Client)
}
func (o *GroupApi) GetRecvGroupApplicationList(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetGroupApplicationList, o.Client)
}
func (o *GroupApi) GetUserReqGroupApplicationList(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetUserReqApplicationList, o.Client)
}
func (o *GroupApi) GetGroupUsersReqApplicationList(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetGroupUsersReqApplicationList, o.Client)
}
func (o *GroupApi) GetSpecifiedUserGroupRequestInfo(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetSpecifiedUserGroupRequestInfo, o.Client)
}
func (o *GroupApi) GetGroupsInfo(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetGroupsInfo, o.Client)
//a2r.Call(c, group.GroupClient.GetGroupsInfo, o.Client, c, a2r.NewNilReplaceOption(group.GroupClient.GetGroupsInfo))
}
func (o *GroupApi) KickGroupMember(c *gin.Context) {
a2r.Call(c, group.GroupClient.KickGroupMember, o.Client)
}
func (o *GroupApi) GetGroupMembersInfo(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetGroupMembersInfo, o.Client)
//a2r.Call(c, group.GroupClient.GetGroupMembersInfo, o.Client, c, a2r.NewNilReplaceOption(group.GroupClient.GetGroupMembersInfo))
}
func (o *GroupApi) GetGroupMemberList(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetGroupMemberList, o.Client)
}
func (o *GroupApi) InviteUserToGroup(c *gin.Context) {
a2r.Call(c, group.GroupClient.InviteUserToGroup, o.Client)
}
func (o *GroupApi) GetJoinedGroupList(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetJoinedGroupList, o.Client)
}
func (o *GroupApi) DismissGroup(c *gin.Context) {
a2r.Call(c, group.GroupClient.DismissGroup, o.Client)
}
func (o *GroupApi) MuteGroupMember(c *gin.Context) {
a2r.Call(c, group.GroupClient.MuteGroupMember, o.Client)
}
func (o *GroupApi) CancelMuteGroupMember(c *gin.Context) {
a2r.Call(c, group.GroupClient.CancelMuteGroupMember, o.Client)
}
func (o *GroupApi) MuteGroup(c *gin.Context) {
a2r.Call(c, group.GroupClient.MuteGroup, o.Client)
}
func (o *GroupApi) CancelMuteGroup(c *gin.Context) {
a2r.Call(c, group.GroupClient.CancelMuteGroup, o.Client)
}
func (o *GroupApi) SetGroupMemberInfo(c *gin.Context) {
a2r.Call(c, group.GroupClient.SetGroupMemberInfo, o.Client)
}
func (o *GroupApi) GetGroupAbstractInfo(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetGroupAbstractInfo, o.Client)
}
// func (g *Group) SetGroupMemberNickname(c *gin.Context) {
// a2r.Call(c, group.GroupClient.SetGroupMemberNickname, g.userClient)
//}
//
// func (g *Group) GetGroupAllMemberList(c *gin.Context) {
// a2r.Call(c, group.GroupClient.GetGroupAllMember, g.userClient)
//}
func (o *GroupApi) GroupCreateCount(c *gin.Context) {
a2r.Call(c, group.GroupClient.GroupCreateCount, o.Client)
}
func (o *GroupApi) GetGroups(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetGroups, o.Client)
}
func (o *GroupApi) GetGroupMemberUserIDs(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetGroupMemberUserIDs, o.Client)
}
func (o *GroupApi) GetIncrementalJoinGroup(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetIncrementalJoinGroup, o.Client)
}
func (o *GroupApi) GetIncrementalGroupMember(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetIncrementalGroupMember, o.Client)
}
func (o *GroupApi) GetIncrementalGroupMemberBatch(c *gin.Context) {
a2r.Call(c, group.GroupClient.BatchGetIncrementalGroupMember, o.Client)
}
func (o *GroupApi) GetFullGroupMemberUserIDs(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetFullGroupMemberUserIDs, o.Client)
}
func (o *GroupApi) GetFullJoinGroupIDs(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetFullJoinGroupIDs, o.Client)
}
func (o *GroupApi) GetGroupApplicationUnhandledCount(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetGroupApplicationUnhandledCount, o.Client)
}

104
internal/api/init.go Normal file
View File

@@ -0,0 +1,104 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"strconv"
"time"
conf "git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/datautil"
"github.com/openimsdk/tools/utils/network"
"github.com/openimsdk/tools/utils/runtimeenv"
"google.golang.org/grpc"
)
type Config struct {
conf.AllConfig
ConfigPath conf.Path
Index conf.Index
}
func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryRegistry, service grpc.ServiceRegistrar) error {
apiPort, err := datautil.GetElemByIndex(config.API.Api.Ports, int(config.Index))
if err != nil {
return err
}
router, err := newGinRouter(ctx, client, config)
if err != nil {
return err
}
apiCtx, apiCancel := context.WithCancelCause(context.Background())
done := make(chan struct{})
go func() {
httpServer := &http.Server{
Handler: router,
Addr: net.JoinHostPort(network.GetListenIP(config.API.Api.ListenIP), strconv.Itoa(apiPort)),
}
go func() {
defer close(done)
select {
case <-ctx.Done():
apiCancel(fmt.Errorf("recv ctx %w", context.Cause(ctx)))
case <-apiCtx.Done():
}
log.ZDebug(ctx, "api server is shutting down")
if err := httpServer.Shutdown(context.Background()); err != nil {
log.ZWarn(ctx, "api server shutdown err", err)
}
}()
log.CInfo(ctx, "api server is init", "runtimeEnv", runtimeenv.RuntimeEnvironment(), "address", httpServer.Addr, "apiPort", apiPort)
err := httpServer.ListenAndServe()
if err == nil {
err = errors.New("api done")
}
apiCancel(err)
}()
//if config.Discovery.Enable == conf.ETCD {
// cm := disetcd.NewConfigManager(client.(*etcd.SvcDiscoveryRegistryImpl).GetClient(), config.GetConfigNames())
// cm.Watch(ctx)
//}
//sigs := make(chan os.Signal, 1)
//signal.Notify(sigs, syscall.SIGTERM)
//select {
//case val := <-sigs:
// log.ZDebug(ctx, "recv exit", "signal", val.String())
// cancel(fmt.Errorf("signal %s", val.String()))
//case <-ctx.Done():
//}
<-apiCtx.Done()
exitCause := context.Cause(apiCtx)
log.ZWarn(ctx, "api server exit", exitCause)
timer := time.NewTimer(time.Second * 15)
defer timer.Stop()
select {
case <-timer.C:
log.ZWarn(ctx, "api server graceful stop timeout", nil)
case <-done:
log.ZDebug(ctx, "api server graceful stop done")
}
return exitCause
}

287
internal/api/jssdk/jssdk.go Normal file
View File

@@ -0,0 +1,287 @@
package jssdk
import (
"context"
"sort"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
"git.imall.cloud/openim/protocol/constant"
"github.com/openimsdk/tools/log"
"github.com/gin-gonic/gin"
"git.imall.cloud/openim/protocol/conversation"
"git.imall.cloud/openim/protocol/jssdk"
"git.imall.cloud/openim/protocol/msg"
"git.imall.cloud/openim/protocol/relation"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/datautil"
)
const (
maxGetActiveConversation = 500
defaultGetActiveConversation = 100
)
func NewJSSdkApi(userClient *rpcli.UserClient, relationClient *rpcli.RelationClient, groupClient *rpcli.GroupClient,
conversationClient *rpcli.ConversationClient, msgClient *rpcli.MsgClient) *JSSdk {
return &JSSdk{
userClient: userClient,
relationClient: relationClient,
groupClient: groupClient,
conversationClient: conversationClient,
msgClient: msgClient,
}
}
type JSSdk struct {
userClient *rpcli.UserClient
relationClient *rpcli.RelationClient
groupClient *rpcli.GroupClient
conversationClient *rpcli.ConversationClient
msgClient *rpcli.MsgClient
}
func (x *JSSdk) GetActiveConversations(c *gin.Context) {
call(c, x.getActiveConversations)
}
func (x *JSSdk) GetConversations(c *gin.Context) {
call(c, x.getConversations)
}
func (x *JSSdk) fillConversations(ctx context.Context, conversations []*jssdk.ConversationMsg) error {
if len(conversations) == 0 {
return nil
}
var (
userIDs []string
groupIDs []string
)
for _, c := range conversations {
if c.Conversation.GroupID == "" {
userIDs = append(userIDs, c.Conversation.UserID)
} else {
groupIDs = append(groupIDs, c.Conversation.GroupID)
}
}
var (
userMap map[string]*sdkws.UserInfo
friendMap map[string]*relation.FriendInfoOnly
groupMap map[string]*sdkws.GroupInfo
)
if len(userIDs) > 0 {
users, err := x.userClient.GetUsersInfo(ctx, userIDs)
if err != nil {
return err
}
friends, err := x.relationClient.GetFriendsInfo(ctx, conversations[0].Conversation.OwnerUserID, userIDs)
if err != nil {
return err
}
userMap = datautil.SliceToMap(users, (*sdkws.UserInfo).GetUserID)
friendMap = datautil.SliceToMap(friends, (*relation.FriendInfoOnly).GetFriendUserID)
}
if len(groupIDs) > 0 {
groups, err := x.groupClient.GetGroupsInfo(ctx, groupIDs)
if err != nil {
return err
}
groupMap = datautil.SliceToMap(groups, (*sdkws.GroupInfo).GetGroupID)
}
for _, c := range conversations {
if c.Conversation.GroupID == "" {
c.User = userMap[c.Conversation.UserID]
c.Friend = friendMap[c.Conversation.UserID]
} else {
c.Group = groupMap[c.Conversation.GroupID]
}
}
return nil
}
func (x *JSSdk) getActiveConversations(ctx context.Context, req *jssdk.GetActiveConversationsReq) (*jssdk.GetActiveConversationsResp, error) {
if req.Count <= 0 || req.Count > maxGetActiveConversation {
req.Count = defaultGetActiveConversation
}
req.OwnerUserID = mcontext.GetOpUserID(ctx)
conversationIDs, err := x.conversationClient.GetConversationIDs(ctx, req.OwnerUserID)
if err != nil {
return nil, err
}
if len(conversationIDs) == 0 {
return &jssdk.GetActiveConversationsResp{}, nil
}
activeConversation, err := x.msgClient.GetActiveConversation(ctx, conversationIDs)
if err != nil {
return nil, err
}
if len(activeConversation) == 0 {
return &jssdk.GetActiveConversationsResp{}, nil
}
readSeq, err := x.msgClient.GetHasReadSeqs(ctx, conversationIDs, req.OwnerUserID)
if err != nil {
return nil, err
}
sortConversations := sortActiveConversations{
Conversation: activeConversation,
}
if len(activeConversation) > 1 {
pinnedConversationIDs, err := x.conversationClient.GetPinnedConversationIDs(ctx, req.OwnerUserID)
if err != nil {
return nil, err
}
sortConversations.PinnedConversationIDs = datautil.SliceSet(pinnedConversationIDs)
}
sort.Sort(&sortConversations)
sortList := sortConversations.Top(int(req.Count))
conversations, err := x.conversationClient.GetConversations(ctx, datautil.Slice(sortList, func(c *msg.ActiveConversation) string {
return c.ConversationID
}), req.OwnerUserID)
if err != nil {
return nil, err
}
msgs, err := x.msgClient.GetSeqMessage(ctx, req.OwnerUserID, datautil.Slice(sortList, func(c *msg.ActiveConversation) *msg.ConversationSeqs {
return &msg.ConversationSeqs{
ConversationID: c.ConversationID,
Seqs: []int64{c.MaxSeq},
}
}))
if err != nil {
return nil, err
}
x.checkMessagesAndGetLastMessage(ctx, req.OwnerUserID, msgs)
conversationMap := datautil.SliceToMap(conversations, func(c *conversation.Conversation) string {
return c.ConversationID
})
resp := make([]*jssdk.ConversationMsg, 0, len(sortList))
for _, c := range sortList {
conv, ok := conversationMap[c.ConversationID]
if !ok {
continue
}
if msgList, ok := msgs[c.ConversationID]; ok && len(msgList.Msgs) > 0 {
resp = append(resp, &jssdk.ConversationMsg{
Conversation: conv,
LastMsg: msgList.Msgs[0],
MaxSeq: c.MaxSeq,
ReadSeq: readSeq[c.ConversationID],
})
}
}
if err := x.fillConversations(ctx, resp); err != nil {
return nil, err
}
var unreadCount int64
for _, c := range activeConversation {
count := c.MaxSeq - readSeq[c.ConversationID]
if count > 0 {
unreadCount += count
}
}
return &jssdk.GetActiveConversationsResp{
Conversations: resp,
UnreadCount: unreadCount,
}, nil
}
func (x *JSSdk) getConversations(ctx context.Context, req *jssdk.GetConversationsReq) (*jssdk.GetConversationsResp, error) {
req.OwnerUserID = mcontext.GetOpUserID(ctx)
conversations, err := x.conversationClient.GetConversations(ctx, req.ConversationIDs, req.OwnerUserID)
if err != nil {
return nil, err
}
if len(conversations) == 0 {
return &jssdk.GetConversationsResp{}, nil
}
req.ConversationIDs = datautil.Slice(conversations, func(c *conversation.Conversation) string {
return c.ConversationID
})
maxSeqs, err := x.msgClient.GetMaxSeqs(ctx, req.ConversationIDs)
if err != nil {
return nil, err
}
readSeqs, err := x.msgClient.GetHasReadSeqs(ctx, req.ConversationIDs, req.OwnerUserID)
if err != nil {
return nil, err
}
conversationSeqs := make([]*msg.ConversationSeqs, 0, len(conversations))
for _, c := range conversations {
if seq := maxSeqs[c.ConversationID]; seq > 0 {
conversationSeqs = append(conversationSeqs, &msg.ConversationSeqs{
ConversationID: c.ConversationID,
Seqs: []int64{seq},
})
}
}
var msgs map[string]*sdkws.PullMsgs
if len(conversationSeqs) > 0 {
msgs, err = x.msgClient.GetSeqMessage(ctx, req.OwnerUserID, conversationSeqs)
if err != nil {
return nil, err
}
}
x.checkMessagesAndGetLastMessage(ctx, req.OwnerUserID, msgs)
resp := make([]*jssdk.ConversationMsg, 0, len(conversations))
for _, c := range conversations {
if msgList, ok := msgs[c.ConversationID]; ok && len(msgList.Msgs) > 0 {
resp = append(resp, &jssdk.ConversationMsg{
Conversation: c,
LastMsg: msgList.Msgs[0],
MaxSeq: maxSeqs[c.ConversationID],
ReadSeq: readSeqs[c.ConversationID],
})
}
}
if err := x.fillConversations(ctx, resp); err != nil {
return nil, err
}
var unreadCount int64
for conversationID, maxSeq := range maxSeqs {
count := maxSeq - readSeqs[conversationID]
if count > 0 {
unreadCount += count
}
}
return &jssdk.GetConversationsResp{
Conversations: resp,
UnreadCount: unreadCount,
}, nil
}
// This function checks whether the latest MaxSeq message is valid.
// If not, it needs to fetch a valid message again.
func (x *JSSdk) checkMessagesAndGetLastMessage(ctx context.Context, userID string, messages map[string]*sdkws.PullMsgs) {
var conversationIDs []string
for conversationID, message := range messages {
allInValid := true
for _, data := range message.Msgs {
if data.Status < constant.MsgStatusHasDeleted {
allInValid = false
break
}
}
if allInValid {
conversationIDs = append(conversationIDs, conversationID)
}
}
if len(conversationIDs) > 0 {
resp, err := x.msgClient.GetLastMessage(ctx, &msg.GetLastMessageReq{
UserID: userID,
ConversationIDs: conversationIDs,
})
if err != nil {
log.ZError(ctx, "fetchLatestValidMessages", err, "conversationIDs", conversationIDs)
return
}
for conversationID, message := range resp.Msgs {
messages[conversationID] = &sdkws.PullMsgs{Msgs: []*sdkws.MsgData{message}}
}
}
}

View File

@@ -0,0 +1,33 @@
package jssdk
import "git.imall.cloud/openim/protocol/msg"
type sortActiveConversations struct {
Conversation []*msg.ActiveConversation
PinnedConversationIDs map[string]struct{}
}
func (s sortActiveConversations) Top(limit int) []*msg.ActiveConversation {
if limit > 0 && len(s.Conversation) > limit {
return s.Conversation[:limit]
}
return s.Conversation
}
func (s sortActiveConversations) Len() int {
return len(s.Conversation)
}
func (s sortActiveConversations) Less(i, j int) bool {
iv, jv := s.Conversation[i], s.Conversation[j]
_, ip := s.PinnedConversationIDs[iv.ConversationID]
_, jp := s.PinnedConversationIDs[jv.ConversationID]
if ip != jp {
return ip
}
return iv.LastTime > jv.LastTime
}
func (s sortActiveConversations) Swap(i, j int) {
s.Conversation[i], s.Conversation[j] = s.Conversation[j], s.Conversation[i]
}

View File

@@ -0,0 +1,77 @@
package jssdk
import (
"context"
"github.com/gin-gonic/gin"
"github.com/openimsdk/tools/a2r"
"github.com/openimsdk/tools/apiresp"
"github.com/openimsdk/tools/checker"
"github.com/openimsdk/tools/errs"
"google.golang.org/grpc"
"google.golang.org/protobuf/proto"
"io"
"strings"
)
func field[A, B, C any](ctx context.Context, fn func(ctx context.Context, req *A, opts ...grpc.CallOption) (*B, error), req *A, get func(*B) C) (C, error) {
resp, err := fn(ctx, req)
if err != nil {
var c C
return c, err
}
return get(resp), nil
}
func call[A, B any](c *gin.Context, fn func(ctx context.Context, req *A) (*B, error)) {
var isJSON bool
switch contentType := c.GetHeader("Content-Type"); {
case contentType == "":
isJSON = true
case strings.Contains(contentType, "application/json"):
isJSON = true
case strings.Contains(contentType, "application/protobuf"):
case strings.Contains(contentType, "application/x-protobuf"):
default:
apiresp.GinError(c, errs.ErrArgs.WrapMsg("unsupported content type"))
return
}
var req *A
if isJSON {
var err error
req, err = a2r.ParseRequest[A](c)
if err != nil {
apiresp.GinError(c, err)
return
}
} else {
body, err := io.ReadAll(c.Request.Body)
if err != nil {
apiresp.GinError(c, err)
return
}
req = new(A)
if err := proto.Unmarshal(body, any(req).(proto.Message)); err != nil {
apiresp.GinError(c, err)
return
}
if err := checker.Validate(&req); err != nil {
apiresp.GinError(c, err)
return
}
}
resp, err := fn(c, req)
if err != nil {
apiresp.GinError(c, err)
return
}
if isJSON {
apiresp.GinSuccess(c, resp)
return
}
body, err := proto.Marshal(any(resp).(proto.Message))
if err != nil {
apiresp.GinError(c, err)
return
}
apiresp.GinSuccess(c, body)
}

1371
internal/api/meeting.go Normal file

File diff suppressed because it is too large Load Diff

575
internal/api/msg.go Normal file
View File

@@ -0,0 +1,575 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"encoding/base64"
"encoding/json"
"sync"
"github.com/gin-gonic/gin"
"github.com/go-playground/validator/v10"
"github.com/mitchellh/mapstructure"
"google.golang.org/protobuf/reflect/protoreflect"
"git.imall.cloud/openim/open-im-server-deploy/pkg/apistruct"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/webhook"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/msg"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/a2r"
"github.com/openimsdk/tools/apiresp"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/datautil"
"github.com/openimsdk/tools/utils/idutil"
"github.com/openimsdk/tools/utils/jsonutil"
"github.com/openimsdk/tools/utils/timeutil"
)
var (
msgDataDescriptor []protoreflect.FieldDescriptor
msgDataDescriptorOnce sync.Once
)
func getMsgDataDescriptor() []protoreflect.FieldDescriptor {
msgDataDescriptorOnce.Do(func() {
skip := make(map[string]struct{})
respFields := new(msg.SendMsgResp).ProtoReflect().Descriptor().Fields()
for i := 0; i < respFields.Len(); i++ {
field := respFields.Get(i)
if !field.HasJSONName() {
continue
}
skip[field.JSONName()] = struct{}{}
}
fields := new(sdkws.MsgData).ProtoReflect().Descriptor().Fields()
num := fields.Len()
msgDataDescriptor = make([]protoreflect.FieldDescriptor, 0, num)
for i := 0; i < num; i++ {
field := fields.Get(i)
if !field.HasJSONName() {
continue
}
if _, ok := skip[field.JSONName()]; ok {
continue
}
msgDataDescriptor = append(msgDataDescriptor, fields.Get(i))
}
})
return msgDataDescriptor
}
type MessageApi struct {
Client msg.MsgClient
userClient *rpcli.UserClient
imAdminUserID []string
validate *validator.Validate
}
func NewMessageApi(client msg.MsgClient, userClient *rpcli.UserClient, imAdminUserID []string) MessageApi {
return MessageApi{Client: client, userClient: userClient, imAdminUserID: imAdminUserID, validate: validator.New()}
}
func (*MessageApi) SetOptions(options map[string]bool, value bool) {
datautil.SetSwitchFromOptions(options, constant.IsHistory, value)
datautil.SetSwitchFromOptions(options, constant.IsPersistent, value)
datautil.SetSwitchFromOptions(options, constant.IsSenderSync, value)
datautil.SetSwitchFromOptions(options, constant.IsConversationUpdate, value)
}
func (m *MessageApi) newUserSendMsgReq(_ *gin.Context, params *apistruct.SendMsg, data any) *msg.SendMsgReq {
msgData := &sdkws.MsgData{
SendID: params.SendID,
GroupID: params.GroupID,
ClientMsgID: idutil.GetMsgIDByMD5(params.SendID),
SenderPlatformID: params.SenderPlatformID,
SenderNickname: params.SenderNickname,
SenderFaceURL: params.SenderFaceURL,
SessionType: params.SessionType,
MsgFrom: constant.SysMsgType,
ContentType: params.ContentType,
CreateTime: timeutil.GetCurrentTimestampByMill(),
SendTime: params.SendTime,
OfflinePushInfo: params.OfflinePushInfo,
Ex: params.Ex,
}
var newContent string
options := make(map[string]bool, 5)
switch params.ContentType {
case constant.OANotification:
notification := sdkws.NotificationElem{}
notification.Detail = jsonutil.StructToJsonString(params.Content)
newContent = jsonutil.StructToJsonString(&notification)
case constant.Text:
fallthrough
case constant.AtText:
if atElem, ok := data.(*apistruct.AtElem); ok {
msgData.AtUserIDList = atElem.AtUserList
}
fallthrough
case constant.Picture:
fallthrough
case constant.Custom:
fallthrough
case constant.Voice:
fallthrough
case constant.Video:
fallthrough
case constant.File:
fallthrough
default:
newContent = jsonutil.StructToJsonString(params.Content)
}
if params.IsOnlineOnly {
m.SetOptions(options, false)
}
if params.NotOfflinePush {
datautil.SetSwitchFromOptions(options, constant.IsOfflinePush, false)
}
msgData.Content = []byte(newContent)
msgData.Options = options
pbData := msg.SendMsgReq{
MsgData: msgData,
}
return &pbData
}
func (m *MessageApi) GetSeq(c *gin.Context) {
a2r.Call(c, msg.MsgClient.GetMaxSeq, m.Client)
}
func (m *MessageApi) PullMsgBySeqs(c *gin.Context) {
a2r.Call(c, msg.MsgClient.PullMessageBySeqs, m.Client)
}
func (m *MessageApi) RevokeMsg(c *gin.Context) {
a2r.Call(c, msg.MsgClient.RevokeMsg, m.Client)
}
func (m *MessageApi) MarkMsgsAsRead(c *gin.Context) {
a2r.Call(c, msg.MsgClient.MarkMsgsAsRead, m.Client)
}
func (m *MessageApi) MarkConversationAsRead(c *gin.Context) {
a2r.Call(c, msg.MsgClient.MarkConversationAsRead, m.Client)
}
func (m *MessageApi) GetConversationsHasReadAndMaxSeq(c *gin.Context) {
a2r.Call(c, msg.MsgClient.GetConversationsHasReadAndMaxSeq, m.Client)
}
func (m *MessageApi) SetConversationHasReadSeq(c *gin.Context) {
a2r.Call(c, msg.MsgClient.SetConversationHasReadSeq, m.Client)
}
func (m *MessageApi) ClearConversationsMsg(c *gin.Context) {
a2r.Call(c, msg.MsgClient.ClearConversationsMsg, m.Client)
}
func (m *MessageApi) UserClearAllMsg(c *gin.Context) {
a2r.Call(c, msg.MsgClient.UserClearAllMsg, m.Client)
}
func (m *MessageApi) DeleteMsgs(c *gin.Context) {
a2r.Call(c, msg.MsgClient.DeleteMsgs, m.Client)
}
func (m *MessageApi) DeleteMsgPhysicalBySeq(c *gin.Context) {
a2r.Call(c, msg.MsgClient.DeleteMsgPhysicalBySeq, m.Client)
}
func (m *MessageApi) DeleteMsgPhysical(c *gin.Context) {
a2r.Call(c, msg.MsgClient.DeleteMsgPhysical, m.Client)
}
func (m *MessageApi) getSendMsgReq(c *gin.Context, req apistruct.SendMsg) (sendMsgReq *msg.SendMsgReq, err error) {
var data any
log.ZDebug(c, "getSendMsgReq", "req", req.Content)
switch req.ContentType {
case constant.Text:
data = &apistruct.TextElem{}
case constant.Picture:
data = &apistruct.PictureElem{}
case constant.Voice:
data = &apistruct.SoundElem{}
case constant.Video:
data = &apistruct.VideoElem{}
case constant.File:
data = &apistruct.FileElem{}
case constant.AtText:
data = &apistruct.AtElem{}
case constant.Custom:
data = &apistruct.CustomElem{}
case constant.MarkdownText:
data = &apistruct.MarkdownTextElem{}
case constant.Quote:
data = &apistruct.QuoteElem{}
case constant.OANotification:
data = &apistruct.OANotificationElem{}
req.SessionType = constant.NotificationChatType
if err = m.userClient.GetNotificationByID(c, req.SendID); err != nil {
return nil, err
}
default:
return nil, errs.WrapMsg(errs.ErrArgs, "unsupported content type", "contentType", req.ContentType)
}
if err := mapstructure.WeakDecode(req.Content, data); err != nil {
return nil, errs.WrapMsg(err, "failed to decode message content")
}
log.ZDebug(c, "getSendMsgReq", "decodedContent", data)
if err := m.validate.Struct(data); err != nil {
return nil, errs.WrapMsg(err, "validation error")
}
return m.newUserSendMsgReq(c, &req, data), nil
}
func (m *MessageApi) getModifyFields(req, respModify *sdkws.MsgData) map[string]any {
if req == nil || respModify == nil {
return nil
}
fields := make(map[string]any)
reqProtoReflect := req.ProtoReflect()
respProtoReflect := respModify.ProtoReflect()
for _, descriptor := range getMsgDataDescriptor() {
reqValue := reqProtoReflect.Get(descriptor)
respValue := respProtoReflect.Get(descriptor)
if !reqValue.Equal(respValue) {
val := respValue.Interface()
name := descriptor.JSONName()
if name == "content" {
if bs, ok := val.([]byte); ok {
val = string(bs)
}
}
fields[name] = val
}
}
if len(fields) == 0 {
fields = nil
}
return fields
}
func (m *MessageApi) ginRespSendMsg(c *gin.Context, req *msg.SendMsgReq, resp *msg.SendMsgResp) {
res := m.getModifyFields(req.MsgData, resp.Modify)
resp.Modify = nil
apiresp.GinSuccess(c, &apistruct.SendMsgResp{
SendMsgResp: resp,
Modify: res,
})
}
// SendMessage handles the sending of a message. It's an HTTP handler function to be used with Gin framework.
func (m *MessageApi) SendMessage(c *gin.Context) {
// Initialize a request struct for sending a message.
req := apistruct.SendMsgReq{}
// Bind the JSON request body to the request struct.
if err := c.BindJSON(&req); err != nil {
// Respond with an error if request body binding fails.
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
// Check if the user has the app manager role.
if !authverify.IsAdmin(c) {
// Respond with a permission error if the user is not an app manager.
apiresp.GinError(c, errs.ErrNoPermission.WrapMsg("only app manager can send message"))
return
}
// Prepare the message request with additional required data.
sendMsgReq, err := m.getSendMsgReq(c, req.SendMsg)
if err != nil {
// Log and respond with an error if preparation fails.
apiresp.GinError(c, err)
return
}
// Set the receiver ID in the message data.
sendMsgReq.MsgData.RecvID = req.RecvID
// Attempt to send the message using the client.
respPb, err := m.Client.SendMsg(c, sendMsgReq)
if err != nil {
// Set the status to failed and respond with an error if sending fails.
apiresp.GinError(c, err)
return
}
// Set the status to successful if the message is sent.
var status = constant.MsgSendSuccessed
// Attempt to update the message sending status in the system.
_, err = m.Client.SetSendMsgStatus(c, &msg.SetSendMsgStatusReq{
Status: int32(status),
})
if err != nil {
// Log the error if updating the status fails.
apiresp.GinError(c, err)
return
}
// Respond with a success message and the response payload.
m.ginRespSendMsg(c, sendMsgReq, respPb)
}
func (m *MessageApi) SendBusinessNotification(c *gin.Context) {
req := struct {
Key string `json:"key"`
Data string `json:"data"`
SendUserID string `json:"sendUserID" binding:"required"`
RecvUserID string `json:"recvUserID"`
RecvGroupID string `json:"recvGroupID"`
SendMsg bool `json:"sendMsg"`
ReliabilityLevel *int `json:"reliabilityLevel"`
}{}
if err := c.BindJSON(&req); err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
if req.RecvUserID == "" && req.RecvGroupID == "" {
apiresp.GinError(c, errs.ErrArgs.WrapMsg("recvUserID and recvGroupID cannot be empty at the same time"))
return
}
if req.RecvUserID != "" && req.RecvGroupID != "" {
apiresp.GinError(c, errs.ErrArgs.WrapMsg("recvUserID and recvGroupID cannot be set at the same time"))
return
}
var sessionType int32
if req.RecvUserID != "" {
sessionType = constant.SingleChatType
} else {
sessionType = constant.ReadGroupChatType
}
if req.ReliabilityLevel == nil {
req.ReliabilityLevel = datautil.ToPtr(1)
}
if !authverify.IsAdmin(c) {
apiresp.GinError(c, errs.ErrNoPermission.WrapMsg("only app manager can send message"))
return
}
sendMsgReq := msg.SendMsgReq{
MsgData: &sdkws.MsgData{
SendID: req.SendUserID,
RecvID: req.RecvUserID,
GroupID: req.RecvGroupID,
Content: []byte(jsonutil.StructToJsonString(&sdkws.NotificationElem{
Detail: jsonutil.StructToJsonString(&struct {
Key string `json:"key"`
Data string `json:"data"`
}{Key: req.Key, Data: req.Data}),
})),
MsgFrom: constant.SysMsgType,
ContentType: constant.BusinessNotification,
SessionType: sessionType,
CreateTime: timeutil.GetCurrentTimestampByMill(),
ClientMsgID: idutil.GetMsgIDByMD5(mcontext.GetOpUserID(c)),
Options: config.GetOptionsByNotification(config.NotificationConfig{
IsSendMsg: req.SendMsg,
ReliabilityLevel: *req.ReliabilityLevel,
UnreadCount: false,
}, nil),
},
}
respPb, err := m.Client.SendMsg(c, &sendMsgReq)
if err != nil {
apiresp.GinError(c, err)
return
}
m.ginRespSendMsg(c, &sendMsgReq, respPb)
}
func (m *MessageApi) BatchSendMsg(c *gin.Context) {
var (
req apistruct.BatchSendMsgReq
resp apistruct.BatchSendMsgResp
)
if err := c.BindJSON(&req); err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
if err := authverify.CheckAdmin(c); err != nil {
apiresp.GinError(c, errs.ErrNoPermission.WrapMsg("only app manager can send message"))
return
}
var recvIDs []string
if req.IsSendAll {
var pageNumber int32 = 1
const showNumber = 500
for {
recvIDsPart, err := m.userClient.GetAllUserIDs(c, pageNumber, showNumber)
if err != nil {
apiresp.GinError(c, err)
return
}
recvIDs = append(recvIDs, recvIDsPart...)
if len(recvIDsPart) < showNumber {
break
}
pageNumber++
}
} else {
recvIDs = req.RecvIDs
}
log.ZDebug(c, "BatchSendMsg nums", "nums ", len(recvIDs))
sendMsgReq, err := m.getSendMsgReq(c, req.SendMsg)
if err != nil {
apiresp.GinError(c, err)
return
}
for _, recvID := range recvIDs {
sendMsgReq.MsgData.RecvID = recvID
rpcResp, err := m.Client.SendMsg(c, sendMsgReq)
if err != nil {
resp.FailedIDs = append(resp.FailedIDs, recvID)
continue
}
resp.Results = append(resp.Results, &apistruct.SingleReturnResult{
ServerMsgID: rpcResp.ServerMsgID,
ClientMsgID: rpcResp.ClientMsgID,
SendTime: rpcResp.SendTime,
RecvID: recvID,
Modify: m.getModifyFields(sendMsgReq.MsgData, rpcResp.Modify),
})
}
apiresp.GinSuccess(c, resp)
}
func (m *MessageApi) SendSimpleMessage(c *gin.Context) {
encodedKey, ok := c.GetQuery(webhook.Key)
if !ok {
apiresp.GinError(c, errs.ErrArgs.WithDetail("missing key in query").Wrap())
return
}
decodedData, err := base64.StdEncoding.DecodeString(encodedKey)
if err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
var (
req apistruct.SendSingleMsgReq
keyMsgData apistruct.KeyMsgData
sendID string
sessionType int32
recvID string
)
if err = c.BindJSON(&req); err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
err = json.Unmarshal(decodedData, &keyMsgData)
if err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
if keyMsgData.GroupID != "" {
sessionType = constant.ReadGroupChatType
sendID = req.SendID
} else {
sessionType = constant.SingleChatType
sendID = keyMsgData.RecvID
recvID = keyMsgData.SendID
}
// check param
if keyMsgData.SendID == "" {
apiresp.GinError(c, errs.ErrArgs.WithDetail("missing recvID or GroupID").Wrap())
return
}
if sendID == "" {
apiresp.GinError(c, errs.ErrArgs.WithDetail("missing sendID").Wrap())
return
}
content, err := jsonutil.JsonMarshal(apistruct.MarkdownTextElem{Content: req.Content})
if err != nil {
apiresp.GinError(c, errs.Wrap(err))
return
}
msgData := &sdkws.MsgData{
SendID: sendID,
RecvID: recvID,
GroupID: keyMsgData.GroupID,
ClientMsgID: idutil.GetMsgIDByMD5(sendID),
SenderPlatformID: constant.AdminPlatformID,
SessionType: sessionType,
MsgFrom: constant.UserMsgType,
ContentType: constant.MarkdownText,
Content: content,
OfflinePushInfo: req.OfflinePushInfo,
Ex: req.Ex,
}
sendReq := &msg.SendSimpleMsgReq{
MsgData: msgData,
}
respPb, err := m.Client.SendSimpleMsg(c, sendReq)
if err != nil {
apiresp.GinError(c, err)
return
}
var status = constant.MsgSendSuccessed
_, err = m.Client.SetSendMsgStatus(c, &msg.SetSendMsgStatusReq{
Status: int32(status),
})
if err != nil {
apiresp.GinError(c, err)
return
}
m.ginRespSendMsg(c, &msg.SendMsgReq{MsgData: sendReq.MsgData}, &msg.SendMsgResp{
ServerMsgID: respPb.ServerMsgID,
ClientMsgID: respPb.ClientMsgID,
SendTime: respPb.SendTime,
Modify: respPb.Modify,
})
}
func (m *MessageApi) CheckMsgIsSendSuccess(c *gin.Context) {
a2r.Call(c, msg.MsgClient.GetSendMsgStatus, m.Client)
}
func (m *MessageApi) GetUsersOnlineStatus(c *gin.Context) {
a2r.Call(c, msg.MsgClient.GetSendMsgStatus, m.Client)
}
func (m *MessageApi) GetActiveUser(c *gin.Context) {
a2r.Call(c, msg.MsgClient.GetActiveUser, m.Client)
}
func (m *MessageApi) GetActiveGroup(c *gin.Context) {
a2r.Call(c, msg.MsgClient.GetActiveGroup, m.Client)
}
func (m *MessageApi) SearchMsg(c *gin.Context) {
a2r.Call(c, msg.MsgClient.SearchMessage, m.Client)
}
func (m *MessageApi) GetServerTime(c *gin.Context) {
a2r.Call(c, msg.MsgClient.GetServerTime, m.Client)
}

View File

@@ -0,0 +1,99 @@
package api
import (
"encoding/json"
"errors"
"net/http"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/prommetrics"
"github.com/gin-gonic/gin"
"github.com/openimsdk/tools/apiresp"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/errs"
)
type PrometheusDiscoveryApi struct {
config *Config
kv discovery.KeyValue
}
func NewPrometheusDiscoveryApi(config *Config, client discovery.SvcDiscoveryRegistry) *PrometheusDiscoveryApi {
api := &PrometheusDiscoveryApi{
config: config,
kv: client,
}
return api
}
func (p *PrometheusDiscoveryApi) discovery(c *gin.Context, key string) {
value, err := p.kv.GetKeyWithPrefix(c, prommetrics.BuildDiscoveryKeyPrefix(key))
if err != nil {
if errors.Is(err, discovery.ErrNotSupported) {
c.JSON(http.StatusOK, []struct{}{})
return
}
apiresp.GinError(c, errs.WrapMsg(err, "get key value"))
return
}
if len(value) == 0 {
c.JSON(http.StatusOK, []*prommetrics.RespTarget{})
return
}
var resp prommetrics.RespTarget
for i := range value {
var tmp prommetrics.Target
if err = json.Unmarshal(value[i], &tmp); err != nil {
apiresp.GinError(c, errs.WrapMsg(err, "json unmarshal err"))
return
}
resp.Targets = append(resp.Targets, tmp.Target)
resp.Labels = tmp.Labels // default label is fixed. See prommetrics.BuildDefaultTarget
}
c.JSON(http.StatusOK, []*prommetrics.RespTarget{&resp})
}
func (p *PrometheusDiscoveryApi) Api(c *gin.Context) {
p.discovery(c, prommetrics.APIKeyName)
}
func (p *PrometheusDiscoveryApi) User(c *gin.Context) {
p.discovery(c, p.config.Discovery.RpcService.User)
}
func (p *PrometheusDiscoveryApi) Group(c *gin.Context) {
p.discovery(c, p.config.Discovery.RpcService.Group)
}
func (p *PrometheusDiscoveryApi) Msg(c *gin.Context) {
p.discovery(c, p.config.Discovery.RpcService.Msg)
}
func (p *PrometheusDiscoveryApi) Friend(c *gin.Context) {
p.discovery(c, p.config.Discovery.RpcService.Friend)
}
func (p *PrometheusDiscoveryApi) Conversation(c *gin.Context) {
p.discovery(c, p.config.Discovery.RpcService.Conversation)
}
func (p *PrometheusDiscoveryApi) Third(c *gin.Context) {
p.discovery(c, p.config.Discovery.RpcService.Third)
}
func (p *PrometheusDiscoveryApi) Auth(c *gin.Context) {
p.discovery(c, p.config.Discovery.RpcService.Auth)
}
func (p *PrometheusDiscoveryApi) Push(c *gin.Context) {
p.discovery(c, p.config.Discovery.RpcService.Push)
}
func (p *PrometheusDiscoveryApi) MessageGateway(c *gin.Context) {
p.discovery(c, p.config.Discovery.RpcService.MessageGateway)
}
func (p *PrometheusDiscoveryApi) MessageTransfer(c *gin.Context) {
p.discovery(c, prommetrics.MessageTransferKeyName)
}

83
internal/api/ratelimit.go Normal file
View File

@@ -0,0 +1,83 @@
package api
import (
"fmt"
"math"
"net/http"
"strconv"
"time"
"github.com/gin-gonic/gin"
"github.com/openimsdk/tools/apiresp"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/stability/ratelimit"
"github.com/openimsdk/tools/stability/ratelimit/bbr"
)
type RateLimiter struct {
Enable bool `yaml:"enable"`
Window time.Duration `yaml:"window"` // time duration per window
Bucket int `yaml:"bucket"` // bucket number for each window
CPUThreshold int64 `yaml:"cpuThreshold"` // CPU threshold; valid range 01000 (1000 = 100%)
}
func RateLimitMiddleware(config *RateLimiter) gin.HandlerFunc {
if !config.Enable {
return func(c *gin.Context) {
c.Next()
}
}
limiter := bbr.NewBBRLimiter(
bbr.WithWindow(config.Window),
bbr.WithBucket(config.Bucket),
bbr.WithCPUThreshold(config.CPUThreshold),
)
return func(c *gin.Context) {
status := limiter.Stat()
c.Header("X-BBR-CPU", strconv.FormatInt(status.CPU, 10))
c.Header("X-BBR-MinRT", strconv.FormatInt(status.MinRt, 10))
c.Header("X-BBR-MaxPass", strconv.FormatInt(status.MaxPass, 10))
c.Header("X-BBR-MaxInFlight", strconv.FormatInt(status.MaxInFlight, 10))
c.Header("X-BBR-InFlight", strconv.FormatInt(status.InFlight, 10))
done, err := limiter.Allow()
if err != nil {
c.Header("X-RateLimit-Policy", "BBR")
c.Header("Retry-After", calculateBBRRetryAfter(status))
c.Header("X-RateLimit-Limit", strconv.FormatInt(status.MaxInFlight, 10))
c.Header("X-RateLimit-Remaining", "0") // There is no concept of remaining quota in BBR.
fmt.Println("rate limited:", err, "path:", c.Request.URL.Path)
log.ZWarn(c, "rate limited", err, "path", c.Request.URL.Path)
c.AbortWithStatus(http.StatusTooManyRequests)
apiresp.GinError(c, errs.NewCodeError(http.StatusTooManyRequests, "too many requests, please try again later"))
return
}
c.Next()
done(ratelimit.DoneInfo{})
}
}
func calculateBBRRetryAfter(status bbr.Stat) string {
loadRatio := float64(status.CPU) / float64(status.CPU)
if loadRatio < 0.8 {
return "1"
}
if loadRatio < 0.95 {
return "2"
}
backoff := 1 + int64(math.Pow(loadRatio-0.95, 2)*50)
if backoff > 5 {
backoff = 5
}
return strconv.FormatInt(backoff, 10)
}

1022
internal/api/redpacket.go Normal file

File diff suppressed because it is too large Load Diff

493
internal/api/router.go Normal file
View File

@@ -0,0 +1,493 @@
package api
import (
"context"
"net/http"
"strings"
"git.imall.cloud/openim/open-im-server-deploy/internal/api/jssdk"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/prommetrics"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/servererrs"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database/mgo"
"git.imall.cloud/openim/open-im-server-deploy/pkg/dbbuild"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
pbAuth "git.imall.cloud/openim/protocol/auth"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/conversation"
"git.imall.cloud/openim/protocol/group"
"git.imall.cloud/openim/protocol/msg"
"git.imall.cloud/openim/protocol/relation"
"git.imall.cloud/openim/protocol/third"
"git.imall.cloud/openim/protocol/user"
"github.com/gin-contrib/gzip"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"github.com/go-playground/validator/v10"
"github.com/openimsdk/tools/apiresp"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/discovery/etcd"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mw"
"github.com/openimsdk/tools/mw/api"
clientv3 "go.etcd.io/etcd/client/v3"
)
const (
NoCompression = -1
DefaultCompression = 0
BestCompression = 1
BestSpeed = 2
)
func prommetricsGin() gin.HandlerFunc {
return func(c *gin.Context) {
c.Next()
path := c.FullPath()
if c.Writer.Status() == http.StatusNotFound {
prommetrics.HttpCall("<404>", c.Request.Method, c.Writer.Status())
} else {
prommetrics.HttpCall(path, c.Request.Method, c.Writer.Status())
}
if resp := apiresp.GetGinApiResponse(c); resp != nil {
prommetrics.APICall(path, c.Request.Method, resp.ErrCode)
}
}
}
func newGinRouter(ctx context.Context, client discovery.SvcDiscoveryRegistry, cfg *Config) (*gin.Engine, error) {
authConn, err := client.GetConn(ctx, cfg.Discovery.RpcService.Auth)
if err != nil {
return nil, err
}
userConn, err := client.GetConn(ctx, cfg.Discovery.RpcService.User)
if err != nil {
return nil, err
}
groupConn, err := client.GetConn(ctx, cfg.Discovery.RpcService.Group)
if err != nil {
return nil, err
}
friendConn, err := client.GetConn(ctx, cfg.Discovery.RpcService.Friend)
if err != nil {
return nil, err
}
conversationConn, err := client.GetConn(ctx, cfg.Discovery.RpcService.Conversation)
if err != nil {
return nil, err
}
thirdConn, err := client.GetConn(ctx, cfg.Discovery.RpcService.Third)
if err != nil {
return nil, err
}
msgConn, err := client.GetConn(ctx, cfg.Discovery.RpcService.Msg)
if err != nil {
return nil, err
}
// 初始化数据库连接(用于红包功能)
dbb := dbbuild.NewBuilder(&cfg.AllConfig.Mongo, &cfg.AllConfig.Redis)
mgocli, err := dbb.Mongo(ctx)
if err != nil {
return nil, err
}
redisClient, err := dbb.Redis(ctx)
if err != nil {
return nil, err
}
redPacketDB, err := mgo.NewRedPacketMongo(mgocli.GetDB())
if err != nil {
return nil, err
}
redPacketReceiveDB, err := mgo.NewRedPacketReceiveMongo(mgocli.GetDB())
if err != nil {
return nil, err
}
walletDB, err := mgo.NewWalletMongo(mgocli.GetDB())
if err != nil {
return nil, err
}
walletBalanceRecordDB, err := mgo.NewWalletBalanceRecordMongo(mgocli.GetDB())
if err != nil {
return nil, err
}
userDB, err := mgo.NewUserMongo(mgocli.GetDB())
if err != nil {
return nil, err
}
meetingDB, err := mgo.NewMeetingMongo(mgocli.GetDB())
if err != nil {
return nil, err
}
meetingCheckInDB, err := mgo.NewMeetingCheckInMongo(mgocli.GetDB())
if err != nil {
return nil, err
}
msgDocDatabase, err := mgo.NewMsgMongo(mgocli.GetDB())
if err != nil {
return nil, err
}
startOnlineCountRefresher(ctx, cfg, redisClient)
gin.SetMode(gin.ReleaseMode)
r := gin.New()
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
_ = v.RegisterValidation("required_if", RequiredIf)
}
switch cfg.API.Api.CompressionLevel {
case NoCompression:
case DefaultCompression:
r.Use(gzip.Gzip(gzip.DefaultCompression))
case BestCompression:
r.Use(gzip.Gzip(gzip.BestCompression))
case BestSpeed:
r.Use(gzip.Gzip(gzip.BestSpeed))
}
// Use rate limiter middleware
if cfg.API.RateLimiter.Enable {
rl := &RateLimiter{
Enable: cfg.API.RateLimiter.Enable,
Window: cfg.API.RateLimiter.Window,
Bucket: cfg.API.RateLimiter.Bucket,
CPUThreshold: cfg.API.RateLimiter.CPUThreshold,
}
r.Use(RateLimitMiddleware(rl))
}
if config.Standalone() {
r.Use(func(c *gin.Context) {
c.Set(authverify.CtxAdminUserIDsKey, cfg.Share.IMAdminUser.UserIDs)
})
}
r.Use(api.GinLogger(), prommetricsGin(), gin.RecoveryWithWriter(gin.DefaultErrorWriter, mw.GinPanicErr), mw.CorsHandler(),
mw.GinParseOperationID(), GinParseToken(rpcli.NewAuthClient(authConn)), setGinIsAdmin(cfg.Share.IMAdminUser.UserIDs))
u := NewUserApi(user.NewUserClient(userConn), client, cfg.Discovery.RpcService)
{
userRouterGroup := r.Group("/user")
userRouterGroup.POST("/user_register", u.UserRegister)
userRouterGroup.POST("/update_user_info", u.UpdateUserInfo)
userRouterGroup.POST("/update_user_info_ex", u.UpdateUserInfoEx)
userRouterGroup.POST("/set_global_msg_recv_opt", u.SetGlobalRecvMessageOpt)
userRouterGroup.POST("/get_users_info", u.GetUsersPublicInfo)
userRouterGroup.POST("/get_all_users_uid", u.GetAllUsersID)
userRouterGroup.POST("/account_check", u.AccountCheck)
userRouterGroup.POST("/get_users", u.GetUsers)
userRouterGroup.POST("/get_users_online_status", u.GetUsersOnlineStatus)
userRouterGroup.POST("/get_users_online_token_detail", u.GetUsersOnlineTokenDetail)
userRouterGroup.POST("/subscribe_users_status", u.SubscriberStatus)
userRouterGroup.POST("/get_users_status", u.GetUserStatus)
userRouterGroup.POST("/get_subscribe_users_status", u.GetSubscribeUsersStatus)
userRouterGroup.POST("/process_user_command_add", u.ProcessUserCommandAdd)
userRouterGroup.POST("/process_user_command_delete", u.ProcessUserCommandDelete)
userRouterGroup.POST("/process_user_command_update", u.ProcessUserCommandUpdate)
userRouterGroup.POST("/process_user_command_get", u.ProcessUserCommandGet)
userRouterGroup.POST("/process_user_command_get_all", u.ProcessUserCommandGetAll)
userRouterGroup.POST("/add_notification_account", u.AddNotificationAccount)
userRouterGroup.POST("/update_notification_account", u.UpdateNotificationAccountInfo)
userRouterGroup.POST("/search_notification_account", u.SearchNotificationAccount)
userRouterGroup.POST("/get_user_client_config", u.GetUserClientConfig)
userRouterGroup.POST("/set_user_client_config", u.SetUserClientConfig)
userRouterGroup.POST("/del_user_client_config", u.DelUserClientConfig)
userRouterGroup.POST("/page_user_client_config", u.PageUserClientConfig)
}
// friend routing group
{
f := NewFriendApi(relation.NewFriendClient(friendConn))
friendRouterGroup := r.Group("/friend")
friendRouterGroup.POST("/delete_friend", f.DeleteFriend)
friendRouterGroup.POST("/get_friend_apply_list", f.GetFriendApplyList)
friendRouterGroup.POST("/get_designated_friend_apply", f.GetDesignatedFriendsApply)
friendRouterGroup.POST("/get_self_friend_apply_list", f.GetSelfApplyList)
friendRouterGroup.POST("/get_friend_list", f.GetFriendList)
friendRouterGroup.POST("/get_designated_friends", f.GetDesignatedFriends)
friendRouterGroup.POST("/add_friend", f.ApplyToAddFriend)
friendRouterGroup.POST("/add_friend_response", f.RespondFriendApply)
friendRouterGroup.POST("/set_friend_remark", f.SetFriendRemark)
friendRouterGroup.POST("/add_black", f.AddBlack)
friendRouterGroup.POST("/get_black_list", f.GetPaginationBlacks)
friendRouterGroup.POST("/get_specified_blacks", f.GetSpecifiedBlacks)
friendRouterGroup.POST("/remove_black", f.RemoveBlack)
friendRouterGroup.POST("/get_incremental_blacks", f.GetIncrementalBlacks)
friendRouterGroup.POST("/import_friend", f.ImportFriends)
friendRouterGroup.POST("/is_friend", f.IsFriend)
friendRouterGroup.POST("/get_friend_id", f.GetFriendIDs)
friendRouterGroup.POST("/get_specified_friends_info", f.GetSpecifiedFriendsInfo)
friendRouterGroup.POST("/update_friends", f.UpdateFriends)
friendRouterGroup.POST("/get_incremental_friends", f.GetIncrementalFriends)
friendRouterGroup.POST("/get_full_friend_user_ids", f.GetFullFriendUserIDs)
friendRouterGroup.POST("/get_self_unhandled_apply_count", f.GetSelfUnhandledApplyCount)
}
g := NewGroupApi(group.NewGroupClient(groupConn))
{
groupRouterGroup := r.Group("/group")
groupRouterGroup.POST("/create_group", g.CreateGroup)
groupRouterGroup.POST("/set_group_info", g.SetGroupInfo)
groupRouterGroup.POST("/set_group_info_ex", g.SetGroupInfoEx)
groupRouterGroup.POST("/join_group", g.JoinGroup)
groupRouterGroup.POST("/quit_group", g.QuitGroup)
groupRouterGroup.POST("/group_application_response", g.ApplicationGroupResponse)
groupRouterGroup.POST("/transfer_group", g.TransferGroupOwner)
groupRouterGroup.POST("/get_recv_group_applicationList", g.GetRecvGroupApplicationList)
groupRouterGroup.POST("/get_user_req_group_applicationList", g.GetUserReqGroupApplicationList)
groupRouterGroup.POST("/get_group_users_req_application_list", g.GetGroupUsersReqApplicationList)
groupRouterGroup.POST("/get_specified_user_group_request_info", g.GetSpecifiedUserGroupRequestInfo)
groupRouterGroup.POST("/get_groups_info", g.GetGroupsInfo)
groupRouterGroup.POST("/kick_group", g.KickGroupMember)
groupRouterGroup.POST("/get_group_members_info", g.GetGroupMembersInfo)
groupRouterGroup.POST("/get_group_member_list", g.GetGroupMemberList)
groupRouterGroup.POST("/invite_user_to_group", g.InviteUserToGroup)
groupRouterGroup.POST("/get_joined_group_list", g.GetJoinedGroupList)
groupRouterGroup.POST("/dismiss_group", g.DismissGroup) //
groupRouterGroup.POST("/mute_group_member", g.MuteGroupMember)
groupRouterGroup.POST("/cancel_mute_group_member", g.CancelMuteGroupMember)
groupRouterGroup.POST("/mute_group", g.MuteGroup)
groupRouterGroup.POST("/cancel_mute_group", g.CancelMuteGroup)
groupRouterGroup.POST("/set_group_member_info", g.SetGroupMemberInfo)
groupRouterGroup.POST("/get_group_abstract_info", g.GetGroupAbstractInfo)
groupRouterGroup.POST("/get_groups", g.GetGroups)
groupRouterGroup.POST("/get_group_member_user_id", g.GetGroupMemberUserIDs)
groupRouterGroup.POST("/get_incremental_join_groups", g.GetIncrementalJoinGroup)
groupRouterGroup.POST("/get_incremental_group_members", g.GetIncrementalGroupMember)
groupRouterGroup.POST("/get_incremental_group_members_batch", g.GetIncrementalGroupMemberBatch)
groupRouterGroup.POST("/get_full_group_member_user_ids", g.GetFullGroupMemberUserIDs)
groupRouterGroup.POST("/get_full_join_group_ids", g.GetFullJoinGroupIDs)
groupRouterGroup.POST("/get_group_application_unhandled_count", g.GetGroupApplicationUnhandledCount)
}
// certificate
{
a := NewAuthApi(pbAuth.NewAuthClient(authConn))
authRouterGroup := r.Group("/auth")
authRouterGroup.POST("/get_admin_token", a.GetAdminToken)
authRouterGroup.POST("/get_user_token", a.GetUserToken)
authRouterGroup.POST("/parse_token", a.ParseToken)
authRouterGroup.POST("/force_logout", a.ForceLogout)
}
// Third service
{
t := NewThirdApi(third.NewThirdClient(thirdConn), cfg.API.Prometheus.GrafanaURL)
thirdGroup := r.Group("/third")
thirdGroup.GET("/prometheus", t.GetPrometheus)
thirdGroup.POST("/fcm_update_token", t.FcmUpdateToken)
thirdGroup.POST("/set_app_badge", t.SetAppBadge)
logs := thirdGroup.Group("/logs")
logs.POST("/upload", t.UploadLogs)
logs.POST("/delete", t.DeleteLogs)
logs.POST("/search", t.SearchLogs)
objectGroup := r.Group("/object")
objectGroup.POST("/part_limit", t.PartLimit)
objectGroup.POST("/part_size", t.PartSize)
objectGroup.POST("/initiate_multipart_upload", t.InitiateMultipartUpload)
objectGroup.POST("/auth_sign", t.AuthSign)
objectGroup.POST("/complete_multipart_upload", t.CompleteMultipartUpload)
objectGroup.POST("/access_url", t.AccessURL)
objectGroup.POST("/initiate_form_data", t.InitiateFormData)
objectGroup.POST("/complete_form_data", t.CompleteFormData)
objectGroup.GET("/*name", t.ObjectRedirect)
}
// Message
m := NewMessageApi(msg.NewMsgClient(msgConn), rpcli.NewUserClient(userConn), cfg.Share.IMAdminUser.UserIDs)
{
msgGroup := r.Group("/msg")
msgGroup.POST("/newest_seq", m.GetSeq)
msgGroup.POST("/search_msg", m.SearchMsg)
msgGroup.POST("/send_msg", m.SendMessage)
msgGroup.POST("/send_business_notification", m.SendBusinessNotification)
msgGroup.POST("/pull_msg_by_seq", m.PullMsgBySeqs)
msgGroup.POST("/revoke_msg", m.RevokeMsg)
msgGroup.POST("/mark_msgs_as_read", m.MarkMsgsAsRead)
msgGroup.POST("/mark_conversation_as_read", m.MarkConversationAsRead)
msgGroup.POST("/get_conversations_has_read_and_max_seq", m.GetConversationsHasReadAndMaxSeq)
msgGroup.POST("/set_conversation_has_read_seq", m.SetConversationHasReadSeq)
msgGroup.POST("/clear_conversation_msg", m.ClearConversationsMsg)
msgGroup.POST("/user_clear_all_msg", m.UserClearAllMsg)
msgGroup.POST("/delete_msgs", m.DeleteMsgs)
msgGroup.POST("/delete_msg_phsical_by_seq", m.DeleteMsgPhysicalBySeq)
msgGroup.POST("/delete_msg_physical", m.DeleteMsgPhysical)
msgGroup.POST("/batch_send_msg", m.BatchSendMsg)
msgGroup.POST("/send_simple_msg", m.SendSimpleMessage)
msgGroup.POST("/check_msg_is_send_success", m.CheckMsgIsSendSuccess)
msgGroup.POST("/get_server_time", m.GetServerTime)
}
// RedPacket
{
rp := NewRedPacketApi(rpcli.NewGroupClient(groupConn), rpcli.NewUserClient(userConn), msg.NewMsgClient(msgConn), redPacketDB, redPacketReceiveDB, walletDB, walletBalanceRecordDB, redisClient)
redPacketGroup := r.Group("/redpacket")
redPacketGroup.POST("/send_redpacket", rp.SendRedPacket)
redPacketGroup.POST("/receive", rp.ReceiveRedPacket)
redPacketGroup.POST("/get_detail", rp.GetRedPacketDetail) // 用户端查询红包详情
// 后台管理接口
redPacketGroup.POST("/get_redpackets_by_group", rp.GetRedPacketsByGroup)
redPacketGroup.POST("/get_receive_info", rp.GetRedPacketReceiveInfo)
redPacketGroup.POST("/pause", rp.PauseRedPacket)
}
// Wallet
{
wa := NewWalletApi(walletDB, walletBalanceRecordDB, userDB, rpcli.NewUserClient(userConn))
walletGroup := r.Group("/wallet")
// 后台管理接口
walletGroup.POST("/get_wallets", wa.GetWallets)
walletGroup.POST("/batch_update_balance", wa.BatchUpdateWalletBalance)
}
// Meeting
{
// 使用已初始化的systemConfigDB如果失败则使用nil
meetingSystemConfigDB, initErr := mgo.NewSystemConfigMongo(mgocli.GetDB())
if initErr != nil {
log.ZWarn(ctx, "failed to init system config db for meeting api, webhook attentionIds update will be disabled", initErr)
meetingSystemConfigDB = nil
}
ma := NewMeetingApi(meetingDB, meetingCheckInDB, rpcli.NewGroupClient(groupConn), rpcli.NewUserClient(userConn), rpcli.NewConversationClient(conversationConn), meetingSystemConfigDB)
meetingGroup := r.Group("/meeting")
// 管理接口
meetingGroup.POST("/create_meeting", ma.CreateMeeting)
meetingGroup.POST("/update_meeting", ma.UpdateMeeting)
meetingGroup.POST("/get_meetings", ma.GetMeetings)
meetingGroup.POST("/delete_meeting", ma.DeleteMeeting)
// 用户端接口
meetingGroup.POST("/get_meeting", ma.GetMeetingPublic)
meetingGroup.POST("/get_meetings_public", ma.GetMeetingsPublic)
// 签到接口
meetingGroup.POST("/check_in", ma.CheckInMeeting)
meetingGroup.POST("/check_user_check_in", ma.CheckUserCheckIn)
meetingGroup.POST("/get_check_ins", ma.GetMeetingCheckIns)
meetingGroup.POST("/get_check_in_stats", ma.GetMeetingCheckInStats)
}
// Conversation
{
c := NewConversationApi(conversation.NewConversationClient(conversationConn))
conversationGroup := r.Group("/conversation")
conversationGroup.POST("/get_sorted_conversation_list", c.GetSortedConversationList)
conversationGroup.POST("/get_all_conversations", c.GetAllConversations)
conversationGroup.POST("/get_conversation", c.GetConversation)
conversationGroup.POST("/get_conversations", c.GetConversations)
conversationGroup.POST("/set_conversations", c.SetConversations)
//conversationGroup.POST("/get_conversation_offline_push_user_ids", c.GetConversationOfflinePushUserIDs)
conversationGroup.POST("/get_full_conversation_ids", c.GetFullOwnerConversationIDs)
conversationGroup.POST("/get_incremental_conversations", c.GetIncrementalConversation)
conversationGroup.POST("/get_owner_conversation", c.GetOwnerConversation)
conversationGroup.POST("/get_not_notify_conversation_ids", c.GetNotNotifyConversationIDs)
conversationGroup.POST("/get_pinned_conversation_ids", c.GetPinnedConversationIDs)
conversationGroup.POST("/delete_conversations", c.DeleteConversations)
}
stats := NewStatisticsApi(redisClient, msgDocDatabase, rpcli.NewUserClient(userConn), rpcli.NewGroupClient(groupConn))
{
statisticsGroup := r.Group("/statistics")
statisticsGroup.POST("/user/register", u.UserRegisterCount)
statisticsGroup.POST("/user/active", m.GetActiveUser)
statisticsGroup.POST("/group/create", g.GroupCreateCount)
statisticsGroup.POST("/group/active", m.GetActiveGroup)
statisticsGroup.POST("/online_user_count", stats.OnlineUserCount)
statisticsGroup.POST("/online_user_count_trend", stats.OnlineUserCountTrend)
statisticsGroup.POST("/user_send_msg_count", stats.UserSendMsgCount)
statisticsGroup.POST("/user_send_msg_count_trend", stats.UserSendMsgCountTrend)
statisticsGroup.POST("/user_send_msg_query", stats.UserSendMsgQuery)
}
{
j := jssdk.NewJSSdkApi(rpcli.NewUserClient(userConn), rpcli.NewRelationClient(friendConn),
rpcli.NewGroupClient(groupConn), rpcli.NewConversationClient(conversationConn), rpcli.NewMsgClient(msgConn))
jssdk := r.Group("/jssdk")
jssdk.POST("/get_conversations", j.GetConversations)
jssdk.POST("/get_active_conversations", j.GetActiveConversations)
}
{
pd := NewPrometheusDiscoveryApi(cfg, client)
proDiscoveryGroup := r.Group("/prometheus_discovery")
proDiscoveryGroup.GET("/api", pd.Api)
proDiscoveryGroup.GET("/user", pd.User)
proDiscoveryGroup.GET("/group", pd.Group)
proDiscoveryGroup.GET("/msg", pd.Msg)
proDiscoveryGroup.GET("/friend", pd.Friend)
proDiscoveryGroup.GET("/conversation", pd.Conversation)
proDiscoveryGroup.GET("/third", pd.Third)
proDiscoveryGroup.GET("/auth", pd.Auth)
proDiscoveryGroup.GET("/push", pd.Push)
proDiscoveryGroup.GET("/msg_gateway", pd.MessageGateway)
proDiscoveryGroup.GET("/msg_transfer", pd.MessageTransfer)
}
var etcdClient *clientv3.Client
if cfg.Discovery.Enable == config.ETCD {
etcdClient = client.(*etcd.SvcDiscoveryRegistryImpl).GetClient()
}
// 初始化SystemConfig数据库用于webhook配置管理
var systemConfigDB database.SystemConfig
systemConfigDB, err = mgo.NewSystemConfigMongo(mgocli.GetDB())
if err != nil {
log.ZWarn(ctx, "failed to init system config db, webhook config management will be disabled", err)
}
cm := NewConfigManager(cfg.Share.IMAdminUser.UserIDs, &cfg.AllConfig, etcdClient, string(cfg.ConfigPath), systemConfigDB)
{
configGroup := r.Group("/config", cm.CheckAdmin)
configGroup.POST("/get_config_list", cm.GetConfigList)
configGroup.POST("/get_config", cm.GetConfig)
configGroup.POST("/set_config", cm.SetConfig)
configGroup.POST("/reset_config", cm.ResetConfig)
configGroup.POST("/set_enable_config_manager", cm.SetEnableConfigManager)
configGroup.POST("/get_enable_config_manager", cm.GetEnableConfigManager)
}
{
r.POST("/restart", cm.CheckAdmin, cm.Restart)
}
return r, nil
}
func GinParseToken(authClient *rpcli.AuthClient) gin.HandlerFunc {
return func(c *gin.Context) {
switch c.Request.Method {
case http.MethodPost:
for _, wApi := range Whitelist {
if strings.HasPrefix(c.Request.URL.Path, wApi) {
c.Next()
return
}
}
token := c.Request.Header.Get(constant.Token)
if token == "" {
log.ZWarn(c, "header get token error", servererrs.ErrArgs.WrapMsg("header must have token"))
apiresp.GinError(c, servererrs.ErrArgs.WrapMsg("header must have token"))
c.Abort()
return
}
resp, err := authClient.ParseToken(c, token)
if err != nil {
apiresp.GinError(c, err)
c.Abort()
return
}
c.Set(constant.OpUserPlatform, constant.PlatformIDToName(int(resp.PlatformID)))
c.Set(constant.OpUserID, resp.UserID)
c.Next()
}
}
}
func setGinIsAdmin(imAdminUserID []string) gin.HandlerFunc {
return func(c *gin.Context) {
c.Set(authverify.CtxAdminUserIDsKey, imAdminUserID)
}
}
// Whitelist api not parse token
var Whitelist = []string{
"/auth/get_admin_token",
"/auth/parse_token",
}

555
internal/api/statistics.go Normal file
View File

@@ -0,0 +1,555 @@
package api
import (
"context"
"errors"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/apistruct"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
rediscache "git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/redis"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/gin-gonic/gin"
"github.com/openimsdk/tools/a2r"
"github.com/openimsdk/tools/apiresp"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/datautil"
"github.com/redis/go-redis/v9"
)
type StatisticsApi struct {
rdb redis.UniversalClient
msgDatabase database.Msg
userClient *rpcli.UserClient
groupClient *rpcli.GroupClient
}
func NewStatisticsApi(rdb redis.UniversalClient, msgDatabase database.Msg, userClient *rpcli.UserClient, groupClient *rpcli.GroupClient) *StatisticsApi {
return &StatisticsApi{
rdb: rdb,
msgDatabase: msgDatabase,
userClient: userClient,
groupClient: groupClient,
}
}
const (
trendIntervalMinutes15 = 15
trendIntervalMinutes30 = 30
trendIntervalMinutes60 = 60
trendChatTypeSingle = 1
trendChatTypeGroup = 2
defaultTrendDuration = 24 * time.Hour
)
// refreshOnlineUserCountAndHistory 刷新在线人数并写入历史采样
func refreshOnlineUserCountAndHistory(ctx context.Context, rdb redis.UniversalClient) {
count, err := rediscache.RefreshOnlineUserCount(ctx, rdb)
if err != nil {
log.ZWarn(ctx, "refresh online user count failed", err)
return
}
if err := rediscache.AppendOnlineUserCountHistory(ctx, rdb, time.Now().UnixMilli(), count); err != nil {
log.ZWarn(ctx, "append online user count history failed", err)
}
}
// startOnlineCountRefresher 定时刷新在线人数缓存
func startOnlineCountRefresher(ctx context.Context, cfg *Config, rdb redis.UniversalClient) {
if cfg == nil || rdb == nil {
return
}
refreshCfg := cfg.API.OnlineCountRefresh
if !refreshCfg.Enable || refreshCfg.Interval <= 0 {
return
}
log.ZInfo(ctx, "online user count refresh enabled", "interval", refreshCfg.Interval)
go func() {
refreshOnlineUserCountAndHistory(ctx, rdb)
ticker := time.NewTicker(refreshCfg.Interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
refreshOnlineUserCountAndHistory(ctx, rdb)
}
}
}()
}
// OnlineUserCount 在线人数统计接口
func (s *StatisticsApi) OnlineUserCount(c *gin.Context) {
if err := authverify.CheckAdmin(c); err != nil {
apiresp.GinError(c, err)
return
}
if s.rdb == nil {
apiresp.GinError(c, errs.ErrInternalServer.WrapMsg("redis client is nil"))
return
}
count, err := rediscache.GetOnlineUserCount(c, s.rdb)
if err != nil {
if errors.Is(err, redis.Nil) {
count, err = rediscache.RefreshOnlineUserCount(c, s.rdb)
if err == nil {
if appendErr := rediscache.AppendOnlineUserCountHistory(c, s.rdb, time.Now().UnixMilli(), count); appendErr != nil {
log.ZWarn(c, "append online user count history failed", appendErr)
}
}
}
}
if err != nil {
apiresp.GinError(c, err)
return
}
apiresp.GinSuccess(c, &apistruct.OnlineUserCountResp{OnlineCount: count})
}
// OnlineUserCountTrend 在线人数走势统计接口
func (s *StatisticsApi) OnlineUserCountTrend(c *gin.Context) {
if err := authverify.CheckAdmin(c); err != nil {
apiresp.GinError(c, err)
return
}
req, err := a2r.ParseRequest[apistruct.OnlineUserCountTrendReq](c)
if err != nil {
apiresp.GinError(c, err)
return
}
if s.rdb == nil {
apiresp.GinError(c, errs.ErrInternalServer.WrapMsg("redis client is nil"))
return
}
intervalMillis, err := parseTrendIntervalMillis(req.IntervalMinutes)
if err != nil {
apiresp.GinError(c, err)
return
}
startTime, endTime, err := normalizeTrendTimeRange(req.StartTime, req.EndTime)
if err != nil {
apiresp.GinError(c, err)
return
}
bucketStart, bucketEnd := alignTrendRange(startTime, endTime, intervalMillis)
// 使用对齐后的时间范围获取历史数据,确保数据范围与构建数据点的范围一致
samples, err := rediscache.GetOnlineUserCountHistory(c, s.rdb, bucketStart, bucketEnd)
if err != nil {
apiresp.GinError(c, err)
return
}
// 将当前在线人数作为最新采样,确保最后一个时间段展示该段内的最大在线人数
now := time.Now().UnixMilli()
currentBucket := now - (now % intervalMillis)
if now < 0 && now%intervalMillis != 0 {
currentBucket = now - ((now % intervalMillis) + intervalMillis)
}
if currentBucket >= bucketStart && currentBucket <= bucketEnd {
if currentCount, err := rediscache.GetOnlineUserCount(c, s.rdb); err == nil {
samples = append(samples, rediscache.OnlineUserCountSample{
Timestamp: now,
Count: currentCount,
})
}
}
points := buildOnlineUserCountTrendPoints(samples, bucketStart, bucketEnd, intervalMillis)
apiresp.GinSuccess(c, &apistruct.OnlineUserCountTrendResp{
IntervalMinutes: req.IntervalMinutes,
Points: points,
})
}
// UserSendMsgCount 用户发送消息总数统计
func (s *StatisticsApi) UserSendMsgCount(c *gin.Context) {
if err := authverify.CheckAdmin(c); err != nil {
apiresp.GinError(c, err)
return
}
_, err := a2r.ParseRequest[apistruct.UserSendMsgCountReq](c)
if err != nil {
apiresp.GinError(c, err)
return
}
if s.msgDatabase == nil {
apiresp.GinError(c, errs.ErrInternalServer.WrapMsg("msg database is nil"))
return
}
now := time.Now()
endTime := now.UnixMilli()
start24h := now.Add(-24 * time.Hour).UnixMilli()
start7d := now.Add(-7 * 24 * time.Hour).UnixMilli()
start30d := now.Add(-30 * 24 * time.Hour).UnixMilli()
count24h, err := s.msgDatabase.CountUserSendMessages(c, "", start24h, endTime, "")
if err != nil {
apiresp.GinError(c, err)
return
}
count7d, err := s.msgDatabase.CountUserSendMessages(c, "", start7d, endTime, "")
if err != nil {
apiresp.GinError(c, err)
return
}
count30d, err := s.msgDatabase.CountUserSendMessages(c, "", start30d, endTime, "")
if err != nil {
apiresp.GinError(c, err)
return
}
apiresp.GinSuccess(c, &apistruct.UserSendMsgCountResp{
Count24h: count24h,
Count7d: count7d,
Count30d: count30d,
})
}
// UserSendMsgCountTrend 用户发送消息走势统计
func (s *StatisticsApi) UserSendMsgCountTrend(c *gin.Context) {
if err := authverify.CheckAdmin(c); err != nil {
apiresp.GinError(c, err)
return
}
req, err := a2r.ParseRequest[apistruct.UserSendMsgCountTrendReq](c)
if err != nil {
apiresp.GinError(c, err)
return
}
if s.msgDatabase == nil {
apiresp.GinError(c, errs.ErrInternalServer.WrapMsg("msg database is nil"))
return
}
intervalMillis, err := parseTrendIntervalMillis(req.IntervalMinutes)
if err != nil {
apiresp.GinError(c, err)
return
}
startTime, endTime, err := normalizeTrendTimeRange(req.StartTime, req.EndTime)
if err != nil {
apiresp.GinError(c, err)
return
}
sessionTypes, err := mapTrendChatType(req.ChatType)
if err != nil {
apiresp.GinError(c, err)
return
}
bucketStart, bucketEnd := alignTrendRange(startTime, endTime, intervalMillis)
countMap, err := s.msgDatabase.CountUserSendMessagesTrend(c, req.UserID, sessionTypes, startTime, endTime, intervalMillis)
if err != nil {
apiresp.GinError(c, err)
return
}
points := buildUserSendMsgCountTrendPoints(countMap, bucketStart, bucketEnd, intervalMillis)
apiresp.GinSuccess(c, &apistruct.UserSendMsgCountTrendResp{
UserID: req.UserID,
ChatType: req.ChatType,
IntervalMinutes: req.IntervalMinutes,
Points: points,
})
}
// UserSendMsgQuery 用户发送消息查询
func (s *StatisticsApi) UserSendMsgQuery(c *gin.Context) {
if err := authverify.CheckAdmin(c); err != nil {
apiresp.GinError(c, err)
return
}
req, err := a2r.ParseRequest[apistruct.UserSendMsgQueryReq](c)
if err != nil {
apiresp.GinError(c, err)
return
}
if req.StartTime > 0 && req.EndTime > 0 && req.EndTime <= req.StartTime {
apiresp.GinError(c, errs.ErrArgs.WrapMsg("invalid time range"))
return
}
if s.msgDatabase == nil {
apiresp.GinError(c, errs.ErrInternalServer.WrapMsg("msg database is nil"))
return
}
pageNumber := req.PageNumber
if pageNumber <= 0 {
pageNumber = 1
}
showNumber := req.ShowNumber
if showNumber <= 0 {
showNumber = 50
}
const maxShowNumber int32 = 200
if showNumber > maxShowNumber {
showNumber = maxShowNumber
}
total, msgs, err := s.msgDatabase.SearchUserMessages(c, req.UserID, req.StartTime, req.EndTime, req.Content, pageNumber, showNumber)
if err != nil {
apiresp.GinError(c, err)
return
}
sendIDs := make([]string, 0, len(msgs))
recvIDs := make([]string, 0, len(msgs))
groupIDs := make([]string, 0, len(msgs))
for _, item := range msgs {
if item == nil || item.Msg == nil {
continue
}
msg := item.Msg
if msg.SendID != "" {
sendIDs = append(sendIDs, msg.SendID)
}
switch msg.SessionType {
case constant.ReadGroupChatType, constant.WriteGroupChatType:
if msg.GroupID != "" {
groupIDs = append(groupIDs, msg.GroupID)
}
default:
if msg.RecvID != "" {
recvIDs = append(recvIDs, msg.RecvID)
}
}
}
sendIDs = datautil.Distinct(sendIDs)
recvIDs = datautil.Distinct(recvIDs)
groupIDs = datautil.Distinct(groupIDs)
sendMap, recvMap, groupMap := map[string]*sdkws.UserInfo{}, map[string]*sdkws.UserInfo{}, map[string]*sdkws.GroupInfo{}
if s.userClient != nil {
if len(sendIDs) > 0 {
if users, err := s.userClient.GetUsersInfo(c, sendIDs); err == nil {
sendMap = datautil.SliceToMap(users, (*sdkws.UserInfo).GetUserID)
}
}
if len(recvIDs) > 0 {
if users, err := s.userClient.GetUsersInfo(c, recvIDs); err == nil {
recvMap = datautil.SliceToMap(users, (*sdkws.UserInfo).GetUserID)
}
}
}
if s.groupClient != nil && len(groupIDs) > 0 {
if groups, err := s.groupClient.GetGroupsInfo(c, groupIDs); err == nil {
groupMap = datautil.SliceToMap(groups, (*sdkws.GroupInfo).GetGroupID)
}
}
records := make([]*apistruct.UserSendMsgQueryRecord, 0, len(msgs))
for _, item := range msgs {
if item == nil || item.Msg == nil {
continue
}
msg := item.Msg
msgID := msg.ServerMsgID
if msgID == "" {
msgID = msg.ClientMsgID
}
senderName := msg.SenderNickname
if senderName == "" {
if u := sendMap[msg.SendID]; u != nil {
senderName = u.Nickname
} else {
senderName = msg.SendID
}
}
recvID := msg.RecvID
recvName := ""
if msg.SessionType == constant.ReadGroupChatType || msg.SessionType == constant.WriteGroupChatType {
if msg.GroupID != "" {
recvID = msg.GroupID
}
if g := groupMap[recvID]; g != nil {
recvName = g.GroupName
} else if recvID != "" {
recvName = recvID
}
} else {
if u := recvMap[msg.RecvID]; u != nil {
recvName = u.Nickname
} else if msg.RecvID != "" {
recvName = msg.RecvID
}
}
records = append(records, &apistruct.UserSendMsgQueryRecord{
MsgID: msgID,
SendID: msg.SendID,
SenderName: senderName,
RecvID: recvID,
RecvName: recvName,
ContentType: msg.ContentType,
ContentTypeName: contentTypeName(msg.ContentType),
SessionType: msg.SessionType,
ChatTypeName: chatTypeName(msg.SessionType),
Content: msg.Content,
SendTime: msg.SendTime,
})
}
apiresp.GinSuccess(c, &apistruct.UserSendMsgQueryResp{
Count: total,
PageNumber: pageNumber,
ShowNumber: showNumber,
Records: records,
})
}
// parseTrendIntervalMillis 解析走势统计间隔并转换为毫秒
func parseTrendIntervalMillis(intervalMinutes int32) (int64, error) {
switch intervalMinutes {
case trendIntervalMinutes15, trendIntervalMinutes30, trendIntervalMinutes60:
return int64(intervalMinutes) * int64(time.Minute/time.Millisecond), nil
default:
return 0, errs.ErrArgs.WrapMsg("invalid intervalMinutes")
}
}
// normalizeTrendTimeRange 标准化走势统计时间区间
func normalizeTrendTimeRange(startTime int64, endTime int64) (int64, int64, error) {
now := time.Now().UnixMilli()
if endTime <= 0 {
endTime = now
}
if startTime <= 0 {
startTime = endTime - int64(defaultTrendDuration/time.Millisecond)
}
if startTime < 0 {
startTime = 0
}
if endTime <= startTime {
return 0, 0, errs.ErrArgs.WrapMsg("invalid time range")
}
return startTime, endTime, nil
}
// alignTrendRange 对齐走势统计区间到间隔边界
func alignTrendRange(startTime int64, endTime int64, intervalMillis int64) (int64, int64) {
if intervalMillis <= 0 {
return startTime, endTime
}
// 开始时间向下对齐到间隔边界
bucketStart := startTime - (startTime % intervalMillis)
if startTime < 0 {
bucketStart = startTime - ((startTime % intervalMillis) + intervalMillis)
}
// 结束时间向下对齐到所在间隔的起始(只包含已发生的间隔)
bucketEnd := endTime - (endTime % intervalMillis)
if endTime < 0 && endTime%intervalMillis != 0 {
bucketEnd = endTime - ((endTime % intervalMillis) + intervalMillis)
}
// 确保至少覆盖一个间隔
if bucketEnd < bucketStart {
bucketEnd = bucketStart
}
return bucketStart, bucketEnd
}
// buildOnlineUserCountTrendPoints 构建在线人数走势数据点
func buildOnlineUserCountTrendPoints(samples []rediscache.OnlineUserCountSample, startTime int64, endTime int64, intervalMillis int64) []*apistruct.OnlineUserCountTrendItem {
points := make([]*apistruct.OnlineUserCountTrendItem, 0)
if intervalMillis <= 0 || endTime <= startTime {
return points
}
maxMap := make(map[int64]int64)
for _, sample := range samples {
// 将采样时间戳对齐到间隔边界
bucket := sample.Timestamp - (sample.Timestamp % intervalMillis)
// 处理负数时间戳的情况(虽然通常不会发生)
if sample.Timestamp < 0 && sample.Timestamp%intervalMillis != 0 {
bucket = sample.Timestamp - ((sample.Timestamp % intervalMillis) + intervalMillis)
}
if sample.Count > maxMap[bucket] {
maxMap[bucket] = sample.Count
}
}
// 计算需要生成的数据点数量
// endTime是对齐后的最后一个bucket的起始时间所以需要包含它
estimated := int((endTime-startTime)/intervalMillis) + 1
if estimated > 0 {
points = make([]*apistruct.OnlineUserCountTrendItem, 0, estimated)
}
// 生成从startTime到endTime包含endTime的所有时间点
// endTime已经是对齐后的最后一个bucket的起始时间
for ts := startTime; ts <= endTime; ts += intervalMillis {
maxVal := maxMap[ts]
points = append(points, &apistruct.OnlineUserCountTrendItem{
Timestamp: ts,
OnlineCount: maxVal,
})
}
return points
}
// buildUserSendMsgCountTrendPoints 构建用户发送消息走势数据点
func buildUserSendMsgCountTrendPoints(countMap map[int64]int64, startTime int64, endTime int64, intervalMillis int64) []*apistruct.UserSendMsgCountTrendItem {
points := make([]*apistruct.UserSendMsgCountTrendItem, 0)
if intervalMillis <= 0 || endTime <= startTime {
return points
}
estimated := int((endTime - startTime) / intervalMillis)
if estimated > 0 {
points = make([]*apistruct.UserSendMsgCountTrendItem, 0, estimated)
}
for ts := startTime; ts < endTime; ts += intervalMillis {
points = append(points, &apistruct.UserSendMsgCountTrendItem{
Timestamp: ts,
Count: countMap[ts],
})
}
return points
}
// mapTrendChatType 走势统计聊天类型转为 sessionType 列表
func mapTrendChatType(chatType int32) ([]int32, error) {
switch chatType {
case trendChatTypeSingle:
return []int32{constant.SingleChatType}, nil
case trendChatTypeGroup:
return []int32{constant.ReadGroupChatType, constant.WriteGroupChatType}, nil
default:
return nil, errs.ErrArgs.WrapMsg("invalid chatType")
}
}
// contentTypeName 消息类型名称转换
func contentTypeName(contentType int32) string {
switch contentType {
case constant.Text:
return "文本消息"
case constant.Picture:
return "图片消息"
case constant.Voice:
return "语音消息"
case constant.Video:
return "视频消息"
case constant.File:
return "文件消息"
case constant.AtText:
return "艾特消息"
case constant.Merger:
return "合并消息"
case constant.Card:
return "名片消息"
case constant.Location:
return "位置消息"
case constant.Custom:
return "自定义消息"
case constant.Revoke:
return "撤回消息"
case constant.MarkdownText:
return "Markdown消息"
default:
return "未知消息"
}
}
// chatTypeName 聊天类型名称转换
func chatTypeName(sessionType int32) string {
switch sessionType {
case constant.SingleChatType:
return "单聊"
case constant.ReadGroupChatType, constant.WriteGroupChatType:
return "群聊"
case constant.NotificationChatType:
return "通知"
default:
return "未知"
}
}

175
internal/api/third.go Normal file
View File

@@ -0,0 +1,175 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"context"
"math/rand"
"net/http"
"net/url"
"strconv"
"strings"
"google.golang.org/grpc"
"git.imall.cloud/openim/protocol/third"
"github.com/gin-gonic/gin"
"github.com/openimsdk/tools/a2r"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/mcontext"
)
type ThirdApi struct {
GrafanaUrl string
Client third.ThirdClient
}
func NewThirdApi(client third.ThirdClient, grafanaUrl string) ThirdApi {
return ThirdApi{Client: client, GrafanaUrl: grafanaUrl}
}
func (o *ThirdApi) FcmUpdateToken(c *gin.Context) {
a2r.Call(c, third.ThirdClient.FcmUpdateToken, o.Client)
}
func (o *ThirdApi) SetAppBadge(c *gin.Context) {
a2r.Call(c, third.ThirdClient.SetAppBadge, o.Client)
}
// #################### s3 ####################
func setURLPrefixOption[A, B, C any](_ func(client C, ctx context.Context, req *A, options ...grpc.CallOption) (*B, error), fn func(*A) error) *a2r.Option[A, B] {
return &a2r.Option[A, B]{
BindAfter: fn,
}
}
func setURLPrefix(c *gin.Context, urlPrefix *string) error {
host := c.GetHeader("X-Request-Api")
if host != "" {
if strings.HasSuffix(host, "/") {
*urlPrefix = host + "object/"
return nil
} else {
*urlPrefix = host + "/object/"
return nil
}
}
u := url.URL{
Scheme: "http",
Host: c.Request.Host,
Path: "/object/",
}
if c.Request.TLS != nil {
u.Scheme = "https"
}
*urlPrefix = u.String()
return nil
}
func (o *ThirdApi) PartLimit(c *gin.Context) {
a2r.Call(c, third.ThirdClient.PartLimit, o.Client)
}
func (o *ThirdApi) PartSize(c *gin.Context) {
a2r.Call(c, third.ThirdClient.PartSize, o.Client)
}
func (o *ThirdApi) InitiateMultipartUpload(c *gin.Context) {
opt := setURLPrefixOption(third.ThirdClient.InitiateMultipartUpload, func(req *third.InitiateMultipartUploadReq) error {
return setURLPrefix(c, &req.UrlPrefix)
})
a2r.Call(c, third.ThirdClient.InitiateMultipartUpload, o.Client, opt)
}
func (o *ThirdApi) AuthSign(c *gin.Context) {
a2r.Call(c, third.ThirdClient.AuthSign, o.Client)
}
func (o *ThirdApi) CompleteMultipartUpload(c *gin.Context) {
opt := setURLPrefixOption(third.ThirdClient.CompleteMultipartUpload, func(req *third.CompleteMultipartUploadReq) error {
return setURLPrefix(c, &req.UrlPrefix)
})
a2r.Call(c, third.ThirdClient.CompleteMultipartUpload, o.Client, opt)
}
func (o *ThirdApi) AccessURL(c *gin.Context) {
a2r.Call(c, third.ThirdClient.AccessURL, o.Client)
}
func (o *ThirdApi) InitiateFormData(c *gin.Context) {
a2r.Call(c, third.ThirdClient.InitiateFormData, o.Client)
}
func (o *ThirdApi) CompleteFormData(c *gin.Context) {
opt := setURLPrefixOption(third.ThirdClient.CompleteFormData, func(req *third.CompleteFormDataReq) error {
return setURLPrefix(c, &req.UrlPrefix)
})
a2r.Call(c, third.ThirdClient.CompleteFormData, o.Client, opt)
}
func (o *ThirdApi) ObjectRedirect(c *gin.Context) {
name := c.Param("name")
if name == "" {
c.String(http.StatusBadRequest, "name is empty")
return
}
if name[0] == '/' {
name = name[1:]
}
operationID := c.Query("operationID")
if operationID == "" {
operationID = strconv.Itoa(rand.Int())
}
ctx := mcontext.SetOperationID(c, operationID)
query := make(map[string]string)
for key, values := range c.Request.URL.Query() {
if len(values) == 0 {
continue
}
query[key] = values[0]
}
resp, err := o.Client.AccessURL(ctx, &third.AccessURLReq{Name: name, Query: query})
if err != nil {
if errs.ErrArgs.Is(err) {
c.String(http.StatusBadRequest, err.Error())
return
}
if errs.ErrRecordNotFound.Is(err) {
c.String(http.StatusNotFound, err.Error())
return
}
c.String(http.StatusInternalServerError, err.Error())
return
}
c.Redirect(http.StatusFound, resp.Url)
}
// #################### logs ####################.
func (o *ThirdApi) UploadLogs(c *gin.Context) {
a2r.Call(c, third.ThirdClient.UploadLogs, o.Client)
}
func (o *ThirdApi) DeleteLogs(c *gin.Context) {
a2r.Call(c, third.ThirdClient.DeleteLogs, o.Client)
}
func (o *ThirdApi) SearchLogs(c *gin.Context) {
a2r.Call(c, third.ThirdClient.SearchLogs, o.Client)
}
func (o *ThirdApi) GetPrometheus(c *gin.Context) {
c.Redirect(http.StatusFound, o.GrafanaUrl)
}

260
internal/api/user.go Normal file
View File

@@ -0,0 +1,260 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/msggateway"
"git.imall.cloud/openim/protocol/user"
"github.com/gin-gonic/gin"
"github.com/openimsdk/tools/a2r"
"github.com/openimsdk/tools/apiresp"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
)
type UserApi struct {
Client user.UserClient
discov discovery.Conn
config config.RpcService
}
func NewUserApi(client user.UserClient, discov discovery.Conn, config config.RpcService) UserApi {
return UserApi{Client: client, discov: discov, config: config}
}
func (u *UserApi) UserRegister(c *gin.Context) {
a2r.Call(c, user.UserClient.UserRegister, u.Client)
}
// UpdateUserInfo is deprecated. Use UpdateUserInfoEx
func (u *UserApi) UpdateUserInfo(c *gin.Context) {
a2r.Call(c, user.UserClient.UpdateUserInfo, u.Client)
}
func (u *UserApi) UpdateUserInfoEx(c *gin.Context) {
a2r.Call(c, user.UserClient.UpdateUserInfoEx, u.Client)
}
func (u *UserApi) SetGlobalRecvMessageOpt(c *gin.Context) {
a2r.Call(c, user.UserClient.SetGlobalRecvMessageOpt, u.Client)
}
func (u *UserApi) GetUsersPublicInfo(c *gin.Context) {
a2r.Call(c, user.UserClient.GetDesignateUsers, u.Client)
}
func (u *UserApi) GetAllUsersID(c *gin.Context) {
a2r.Call(c, user.UserClient.GetAllUserID, u.Client)
}
func (u *UserApi) AccountCheck(c *gin.Context) {
a2r.Call(c, user.UserClient.AccountCheck, u.Client)
}
func (u *UserApi) GetUsers(c *gin.Context) {
a2r.Call(c, user.UserClient.GetPaginationUsers, u.Client)
}
// GetUsersOnlineStatus Get user online status.
func (u *UserApi) GetUsersOnlineStatus(c *gin.Context) {
var req msggateway.GetUsersOnlineStatusReq
if err := c.BindJSON(&req); err != nil {
apiresp.GinError(c, err)
return
}
conns, err := u.discov.GetConns(c, u.config.MessageGateway)
if err != nil {
apiresp.GinError(c, err)
return
}
var wsResult []*msggateway.GetUsersOnlineStatusResp_SuccessResult
var respResult []*msggateway.GetUsersOnlineStatusResp_SuccessResult
flag := false
// Online push message
for _, v := range conns {
msgClient := msggateway.NewMsgGatewayClient(v)
reply, err := msgClient.GetUsersOnlineStatus(c, &req)
if err != nil {
log.ZDebug(c, "GetUsersOnlineStatus rpc error", err)
parseError := apiresp.ParseError(err)
if parseError.ErrCode == errs.NoPermissionError {
apiresp.GinError(c, err)
return
}
} else {
wsResult = append(wsResult, reply.SuccessResult...)
}
}
// Traversing the userIDs in the api request body
for _, v1 := range req.UserIDs {
flag = false
res := new(msggateway.GetUsersOnlineStatusResp_SuccessResult)
// Iterate through the online results fetched from various gateways
for _, v2 := range wsResult {
// If matches the above description on the line, and vice versa
if v2.UserID == v1 {
flag = true
res.UserID = v1
res.Status = constant.Online
res.DetailPlatformStatus = append(res.DetailPlatformStatus, v2.DetailPlatformStatus...)
break
}
}
if !flag {
res.UserID = v1
res.Status = constant.Offline
}
respResult = append(respResult, res)
}
apiresp.GinSuccess(c, respResult)
}
func (u *UserApi) UserRegisterCount(c *gin.Context) {
a2r.Call(c, user.UserClient.UserRegisterCount, u.Client)
}
// GetUsersOnlineTokenDetail Get user online token details.
func (u *UserApi) GetUsersOnlineTokenDetail(c *gin.Context) {
var wsResult []*msggateway.GetUsersOnlineStatusResp_SuccessResult
var respResult []*msggateway.SingleDetail
flag := false
var req msggateway.GetUsersOnlineStatusReq
if err := c.BindJSON(&req); err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
conns, err := u.discov.GetConns(c, u.config.MessageGateway)
if err != nil {
apiresp.GinError(c, err)
return
}
// Online push message
for _, v := range conns {
msgClient := msggateway.NewMsgGatewayClient(v)
reply, err := msgClient.GetUsersOnlineStatus(c, &req)
if err != nil {
log.ZWarn(c, "GetUsersOnlineStatus rpc err", err)
continue
} else {
wsResult = append(wsResult, reply.SuccessResult...)
}
}
for _, v1 := range req.UserIDs {
m := make(map[int32][]string, 10)
flag = false
temp := new(msggateway.SingleDetail)
for _, v2 := range wsResult {
if v2.UserID == v1 {
flag = true
temp.UserID = v1
temp.Status = constant.Online
for _, status := range v2.DetailPlatformStatus {
if v, ok := m[status.PlatformID]; ok {
m[status.PlatformID] = append(v, status.Token)
} else {
m[status.PlatformID] = []string{status.Token}
}
}
}
}
for p, tokens := range m {
t := new(msggateway.SinglePlatformToken)
t.PlatformID = p
t.Token = tokens
t.Total = int32(len(tokens))
temp.SinglePlatformToken = append(temp.SinglePlatformToken, t)
}
if flag {
respResult = append(respResult, temp)
}
}
apiresp.GinSuccess(c, respResult)
}
// SubscriberStatus Presence status of subscribed users.
func (u *UserApi) SubscriberStatus(c *gin.Context) {
a2r.Call(c, user.UserClient.SubscribeOrCancelUsersStatus, u.Client)
}
// GetUserStatus Get the online status of the user.
func (u *UserApi) GetUserStatus(c *gin.Context) {
a2r.Call(c, user.UserClient.GetUserStatus, u.Client)
}
// GetSubscribeUsersStatus Get the online status of subscribers.
func (u *UserApi) GetSubscribeUsersStatus(c *gin.Context) {
a2r.Call(c, user.UserClient.GetSubscribeUsersStatus, u.Client)
}
// ProcessUserCommandAdd user general function add.
func (u *UserApi) ProcessUserCommandAdd(c *gin.Context) {
a2r.Call(c, user.UserClient.ProcessUserCommandAdd, u.Client)
}
// ProcessUserCommandDelete user general function delete.
func (u *UserApi) ProcessUserCommandDelete(c *gin.Context) {
a2r.Call(c, user.UserClient.ProcessUserCommandDelete, u.Client)
}
// ProcessUserCommandUpdate user general function update.
func (u *UserApi) ProcessUserCommandUpdate(c *gin.Context) {
a2r.Call(c, user.UserClient.ProcessUserCommandUpdate, u.Client)
}
// ProcessUserCommandGet user general function get.
func (u *UserApi) ProcessUserCommandGet(c *gin.Context) {
a2r.Call(c, user.UserClient.ProcessUserCommandGet, u.Client)
}
// ProcessUserCommandGet user general function get all.
func (u *UserApi) ProcessUserCommandGetAll(c *gin.Context) {
a2r.Call(c, user.UserClient.ProcessUserCommandGetAll, u.Client)
}
func (u *UserApi) AddNotificationAccount(c *gin.Context) {
a2r.Call(c, user.UserClient.AddNotificationAccount, u.Client)
}
func (u *UserApi) UpdateNotificationAccountInfo(c *gin.Context) {
a2r.Call(c, user.UserClient.UpdateNotificationAccountInfo, u.Client)
}
func (u *UserApi) SearchNotificationAccount(c *gin.Context) {
a2r.Call(c, user.UserClient.SearchNotificationAccount, u.Client)
}
func (u *UserApi) GetUserClientConfig(c *gin.Context) {
a2r.Call(c, user.UserClient.GetUserClientConfig, u.Client)
}
func (u *UserApi) SetUserClientConfig(c *gin.Context) {
a2r.Call(c, user.UserClient.SetUserClientConfig, u.Client)
}
func (u *UserApi) DelUserClientConfig(c *gin.Context) {
a2r.Call(c, user.UserClient.DelUserClientConfig, u.Client)
}
func (u *UserApi) PageUserClientConfig(c *gin.Context) {
a2r.Call(c, user.UserClient.PageUserClientConfig, u.Client)
}

523
internal/api/wallet.go Normal file
View File

@@ -0,0 +1,523 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"strconv"
"time"
"github.com/gin-gonic/gin"
"git.imall.cloud/openim/open-im-server-deploy/pkg/apistruct"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
"github.com/openimsdk/tools/apiresp"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/idutil"
"github.com/openimsdk/tools/utils/timeutil"
)
type WalletApi struct {
walletDB database.Wallet
walletBalanceRecordDB database.WalletBalanceRecord
userDB database.User
userClient *rpcli.UserClient
}
func NewWalletApi(walletDB database.Wallet, walletBalanceRecordDB database.WalletBalanceRecord, userDB database.User, userClient *rpcli.UserClient) *WalletApi {
return &WalletApi{
walletDB: walletDB,
walletBalanceRecordDB: walletBalanceRecordDB,
userDB: userDB,
userClient: userClient,
}
}
// updateBalanceWithRecord 统一的余额更新方法,包含余额记录创建和并发控制
// operation: 操作类型set/add/subtract
// amount: 金额(分)
// oldBalance: 旧余额
// oldVersion: 旧版本号
// remark: 备注信息
// 返回新余额、新版本号和错误
func (w *WalletApi) updateBalanceWithRecord(ctx *gin.Context, userID string, operation string, amount int64, oldBalance int64, oldVersion int64, remark string) (newBalance int64, newVersion int64, err error) {
// 使用版本号更新余额(防止并发覆盖)
params := &database.WalletUpdateParams{
UserID: userID,
Operation: operation,
Amount: amount,
OldBalance: oldBalance,
OldVersion: oldVersion,
}
result, err := w.walletDB.UpdateBalanceWithVersion(ctx, params)
if err != nil {
return 0, 0, err
}
// 计算变动金额:用于余额记录
var recordAmount int64
switch operation {
case "set":
recordAmount = result.NewBalance - oldBalance
case "add":
recordAmount = amount
case "subtract":
recordAmount = -amount
}
// 创建余额记录(新字段)
recordID := idutil.GetMsgIDByMD5(userID + timeutil.GetCurrentTimeFormatted() + operation + strconv.FormatInt(amount, 10))
var recordType int32 = 99 // 默认其他
switch operation {
case "add":
recordType = 99 // 通用“加钱”,具体业务可在上层传入
case "subtract":
recordType = 99 // 通用“减钱”
case "set":
recordType = 99 // 设置,归为其他
}
balanceRecord := &model.WalletBalanceRecord{
ID: recordID,
UserID: userID,
Amount: recordAmount,
Type: recordType,
BeforeBalance: oldBalance,
AfterBalance: result.NewBalance,
OrderID: "",
TransactionID: "",
RedPacketID: "",
Remark: remark,
CreateTime: time.Now(),
}
if err := w.walletBalanceRecordDB.Create(ctx, balanceRecord); err != nil {
// 余额记录创建失败不影响主流程,只记录警告日志
log.ZWarn(ctx, "updateBalanceWithRecord: failed to create balance record", err,
"userID", userID,
"operation", operation,
"amount", amount,
"oldBalance", oldBalance,
"newBalance", result.NewBalance)
}
return result.NewBalance, result.NewVersion, nil
}
// paginationWrapper 实现 pagination.Pagination 接口
type walletPaginationWrapper struct {
pageNumber int32
showNumber int32
}
func (p *walletPaginationWrapper) GetPageNumber() int32 {
if p.pageNumber <= 0 {
return 1
}
return p.pageNumber
}
func (p *walletPaginationWrapper) GetShowNumber() int32 {
if p.showNumber <= 0 {
return 20
}
return p.showNumber
}
// GetWallets 查询用户钱包列表(后台管理接口)
func (w *WalletApi) GetWallets(c *gin.Context) {
var (
req apistruct.GetWalletsReq
resp apistruct.GetWalletsResp
)
if err := c.BindJSON(&req); err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
// 设置默认分页参数
if req.Pagination.PageNumber <= 0 {
req.Pagination.PageNumber = 1
}
if req.Pagination.ShowNumber <= 0 {
req.Pagination.ShowNumber = 20
}
// 创建分页对象
pagination := &walletPaginationWrapper{
pageNumber: req.Pagination.PageNumber,
showNumber: req.Pagination.ShowNumber,
}
// 查询钱包列表
var total int64
var wallets []*apistruct.WalletInfo
var err error
// 判断是否有查询条件用户ID、手机号、账号
hasQueryCondition := req.UserID != "" || req.PhoneNumber != "" || req.Account != ""
if hasQueryCondition {
// 如果有查询条件先通过条件查询用户ID
var userIDs []string
if req.UserID != "" {
// 如果直接提供了用户ID直接使用
userIDs = []string{req.UserID}
} else {
// 通过手机号或账号查询用户ID
searchUserIDs, err := w.userDB.SearchUsersByFields(c, req.Account, req.PhoneNumber, "")
if err != nil {
log.ZError(c, "GetWallets: failed to search users", err, "account", req.Account, "phoneNumber", req.PhoneNumber)
apiresp.GinError(c, errs.ErrInternalServer.WrapMsg("failed to search users"))
return
}
if len(searchUserIDs) == 0 {
// 没有找到匹配的用户,返回空列表
resp.Total = 0
resp.Wallets = []*apistruct.WalletInfo{}
apiresp.GinSuccess(c, resp)
return
}
userIDs = searchUserIDs
}
// 根据查询到的用户ID列表查询钱包
walletModels, err := w.walletDB.FindWalletsByUserIDs(c, userIDs)
if err != nil {
log.ZError(c, "GetWallets: failed to find wallets by userIDs", err, "userIDs", userIDs)
apiresp.GinError(c, errs.ErrInternalServer.WrapMsg("failed to find wallets"))
return
}
// 转换为响应格式
wallets = make([]*apistruct.WalletInfo, 0, len(walletModels))
for _, wallet := range walletModels {
wallets = append(wallets, &apistruct.WalletInfo{
UserID: wallet.UserID,
Balance: wallet.Balance,
CreateTime: wallet.CreateTime.UnixMilli(),
UpdateTime: wallet.UpdateTime.UnixMilli(),
})
}
total = int64(len(wallets))
} else {
// 如果没有任何查询条件,查询所有钱包(带分页)
var walletModels []*model.Wallet
total, walletModels, err = w.walletDB.FindAllWallets(c, pagination)
if err != nil {
log.ZError(c, "GetWallets: failed to find wallets", err)
apiresp.GinError(c, errs.ErrInternalServer.WrapMsg("failed to find wallets"))
return
}
// 转换为响应格式
wallets = make([]*apistruct.WalletInfo, 0, len(walletModels))
for _, wallet := range walletModels {
wallets = append(wallets, &apistruct.WalletInfo{
UserID: wallet.UserID,
Balance: wallet.Balance,
CreateTime: wallet.CreateTime.UnixMilli(),
UpdateTime: wallet.UpdateTime.UnixMilli(),
})
}
}
// 收集所有用户ID
userIDMap := make(map[string]bool)
for _, wallet := range wallets {
if wallet.UserID != "" {
userIDMap[wallet.UserID] = true
}
}
// 批量查询用户信息
userIDList := make([]string, 0, len(userIDMap))
for userID := range userIDMap {
userIDList = append(userIDList, userID)
}
userInfoMap := make(map[string]*apistruct.WalletInfo) // userID -> WalletInfo (with nickname and faceURL)
if len(userIDList) > 0 {
userInfos, err := w.userClient.GetUsersInfo(c, userIDList)
if err == nil && len(userInfos) > 0 {
for _, userInfo := range userInfos {
if userInfo != nil {
userInfoMap[userInfo.UserID] = &apistruct.WalletInfo{
Nickname: userInfo.Nickname,
FaceURL: userInfo.FaceURL,
}
}
}
} else {
log.ZWarn(c, "GetWallets: failed to get users info", err, "userIDs", userIDList)
}
}
// 填充用户信息
for _, wallet := range wallets {
if userInfo, ok := userInfoMap[wallet.UserID]; ok {
wallet.Nickname = userInfo.Nickname
wallet.FaceURL = userInfo.FaceURL
}
}
// 填充响应
resp.Total = total
resp.Wallets = wallets
log.ZInfo(c, "GetWallets: success", "userID", req.UserID, "phoneNumber", req.PhoneNumber, "account", req.Account, "total", total, "count", len(resp.Wallets))
apiresp.GinSuccess(c, resp)
}
// BatchUpdateWalletBalance 批量修改用户余额(后台管理接口)
func (w *WalletApi) BatchUpdateWalletBalance(c *gin.Context) {
var (
req apistruct.BatchUpdateWalletBalanceReq
resp apistruct.BatchUpdateWalletBalanceResp
)
if err := c.BindJSON(&req); err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
// 验证请求参数
if len(req.Users) == 0 {
apiresp.GinError(c, errs.ErrArgs.WrapMsg("users list cannot be empty"))
return
}
// 设置默认操作类型
defaultOperation := req.Operation
if defaultOperation == "" {
defaultOperation = "add" // 默认为增加
}
if defaultOperation != "" && defaultOperation != "set" && defaultOperation != "add" && defaultOperation != "subtract" {
apiresp.GinError(c, errs.ErrArgs.WrapMsg("operation must be 'set', 'add', or 'subtract'"))
return
}
// 处理每个用户
resp.Total = int32(len(req.Users))
resp.Results = make([]apistruct.WalletUpdateResult, 0, len(req.Users))
for _, user := range req.Users {
result := apistruct.WalletUpdateResult{
UserID: user.UserID.String(),
PhoneNumber: user.PhoneNumber,
Account: user.Account,
Success: false,
Remark: user.Remark, // 保存备注信息
}
// 确定用户ID
var userID string
if user.UserID != "" {
userID = user.UserID.String()
} else if user.PhoneNumber != "" || user.Account != "" {
// 通过手机号或账号查询用户ID
searchUserIDs, err := w.userDB.SearchUsersByFields(c, user.Account, user.PhoneNumber, "")
if err != nil {
result.Message = "failed to search user: " + err.Error()
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
if len(searchUserIDs) == 0 {
result.Message = "user not found"
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
if len(searchUserIDs) > 1 {
result.Message = "multiple users found, please use userID"
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
userID = searchUserIDs[0]
} else {
result.Message = "user identifier is required (userID, phoneNumber, or account)"
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
// 先验证用户是否存在
userExists, err := w.userDB.Exist(c, userID)
if err != nil {
result.Message = "failed to check user existence: " + err.Error()
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
if !userExists {
result.Message = "user not found"
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
// 查询当前钱包
wallet, err := w.walletDB.Take(c, userID)
if err != nil {
// 如果钱包不存在,但用户存在,创建新钱包
if errs.ErrRecordNotFound.Is(err) {
wallet = &model.Wallet{
UserID: userID,
Balance: 0,
CreateTime: time.Now(),
UpdateTime: time.Now(),
}
if err := w.walletDB.Create(c, wallet); err != nil {
result.Message = "failed to create wallet: " + err.Error()
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
} else {
result.Message = "failed to get wallet: " + err.Error()
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
}
// 记录旧余额和版本号(旧数据可能没有 version保持 0 以兼容)
result.OldBalance = wallet.Balance
oldVersion := wallet.Version
// 确定该用户使用的金额和操作类型
userAmount := user.Amount
if userAmount == 0 {
// 如果用户没有指定金额,使用默认金额
userAmount = req.Amount
}
if userAmount == 0 {
result.Message = "amount is required (either in user object or in request)"
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
userOperation := user.Operation
if userOperation == "" {
// 如果用户没有指定操作类型,使用默认操作类型
userOperation = defaultOperation
}
if userOperation != "set" && userOperation != "add" && userOperation != "subtract" {
result.Message = "operation must be 'set', 'add', or 'subtract'"
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
// 预检查:计算新余额并检查不能为负数
var expectedNewBalance int64
switch userOperation {
case "set":
expectedNewBalance = userAmount
case "add":
expectedNewBalance = wallet.Balance + userAmount
case "subtract":
expectedNewBalance = wallet.Balance - userAmount
}
if expectedNewBalance < 0 {
result.Message = "balance cannot be negative"
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
// 使用统一的余额更新方法(包含并发控制和余额记录创建)
// 如果发生并发冲突,重试一次(重新获取最新余额和版本号)
var newBalance int64
maxRetries := 2
for retry := 0; retry < maxRetries; retry++ {
if retry > 0 {
// 重试时重新获取钱包信息
wallet, err = w.walletDB.Take(c, userID)
if err != nil {
result.Message = "failed to get wallet for retry: " + err.Error()
break
}
result.OldBalance = wallet.Balance
oldVersion = wallet.Version
// 重新计算预期新余额
switch userOperation {
case "set":
expectedNewBalance = userAmount
case "add":
expectedNewBalance = wallet.Balance + userAmount
case "subtract":
expectedNewBalance = wallet.Balance - userAmount
}
if expectedNewBalance < 0 {
result.Message = "balance cannot be negative after retry"
break
}
}
newBalance, _, err = w.updateBalanceWithRecord(c, userID, userOperation, userAmount, result.OldBalance, oldVersion, user.Remark)
if err == nil {
// 更新成功
break
}
// 如果是并发冲突且还有重试机会,继续重试
if retry < maxRetries-1 && errs.ErrInternalServer.Is(err) {
log.ZWarn(c, "BatchUpdateWalletBalance: concurrent modification detected, retrying", err,
"userID", userID,
"retry", retry+1,
"maxRetries", maxRetries)
continue
}
// 其他错误或重试次数用完,返回错误
result.Message = "failed to update balance: " + err.Error()
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
if err != nil {
// 更新失败
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
// 更新成功
result.Success = true
result.NewBalance = newBalance
result.Message = "success"
// 备注信息已在初始化时设置
resp.Results = append(resp.Results, result)
resp.Success++
// 记录日志,包含备注信息
if user.Remark != "" {
log.ZInfo(c, "BatchUpdateWalletBalance: user balance updated",
"userID", userID,
"operation", userOperation,
"amount", userAmount,
"oldBalance", result.OldBalance,
"newBalance", newBalance,
"remark", user.Remark)
}
}
log.ZInfo(c, "BatchUpdateWalletBalance: success", "total", resp.Total, "success", resp.Success, "failed", resp.Failed)
apiresp.GinSuccess(c, resp)
}

View File

@@ -0,0 +1,76 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import (
"context"
"time"
cbapi "git.imall.cloud/openim/open-im-server-deploy/pkg/callbackstruct"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/protocol/constant"
"github.com/openimsdk/tools/mcontext"
)
func (ws *WsServer) webhookAfterUserOnline(ctx context.Context, after *config.AfterConfig, userID string, platformID int, isAppBackground bool, connID string) {
req := cbapi.CallbackUserOnlineReq{
UserStatusCallbackReq: cbapi.UserStatusCallbackReq{
UserStatusBaseCallback: cbapi.UserStatusBaseCallback{
CallbackCommand: cbapi.CallbackAfterUserOnlineCommand,
OperationID: mcontext.GetOperationID(ctx),
PlatformID: platformID,
Platform: constant.PlatformIDToName(platformID),
},
UserID: userID,
},
Seq: time.Now().UnixMilli(),
IsAppBackground: isAppBackground,
ConnID: connID,
}
ws.webhookClient.AsyncPost(ctx, req.GetCallbackCommand(), req, &cbapi.CommonCallbackResp{}, after)
}
func (ws *WsServer) webhookAfterUserOffline(ctx context.Context, after *config.AfterConfig, userID string, platformID int, connID string) {
req := &cbapi.CallbackUserOfflineReq{
UserStatusCallbackReq: cbapi.UserStatusCallbackReq{
UserStatusBaseCallback: cbapi.UserStatusBaseCallback{
CallbackCommand: cbapi.CallbackAfterUserOfflineCommand,
OperationID: mcontext.GetOperationID(ctx),
PlatformID: platformID,
Platform: constant.PlatformIDToName(platformID),
},
UserID: userID,
},
Seq: time.Now().UnixMilli(),
ConnID: connID,
}
ws.webhookClient.AsyncPost(ctx, req.GetCallbackCommand(), req, &cbapi.CallbackUserOfflineResp{}, after)
}
func (ws *WsServer) webhookAfterUserKickOff(ctx context.Context, after *config.AfterConfig, userID string, platformID int) {
req := &cbapi.CallbackUserKickOffReq{
UserStatusCallbackReq: cbapi.UserStatusCallbackReq{
UserStatusBaseCallback: cbapi.UserStatusBaseCallback{
CallbackCommand: cbapi.CallbackAfterUserKickOffCommand,
OperationID: mcontext.GetOperationID(ctx),
PlatformID: platformID,
Platform: constant.PlatformIDToName(platformID),
},
UserID: userID,
},
Seq: time.Now().UnixMilli(),
}
ws.webhookClient.AsyncPost(ctx, req.GetCallbackCommand(), req, &cbapi.CommonCallbackResp{}, after)
}

View File

@@ -0,0 +1,476 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import (
"context"
"encoding/json"
"fmt"
"sync"
"sync/atomic"
"time"
"google.golang.org/protobuf/proto"
"git.imall.cloud/openim/open-im-server-deploy/pkg/msgprocessor"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/apiresp"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/stringutil"
)
var (
ErrConnClosed = errs.New("conn has closed")
ErrNotSupportMessageProtocol = errs.New("not support message protocol")
ErrClientClosed = errs.New("client actively close the connection")
ErrPanic = errs.New("panic error")
)
const (
// MessageText is for UTF-8 encoded text messages like JSON.
MessageText = iota + 1
// MessageBinary is for binary messages like protobufs.
MessageBinary
// CloseMessage denotes a close control message. The optional message
// payload contains a numeric code and text. Use the FormatCloseMessage
// function to format a close message payload.
CloseMessage = 8
// PingMessage denotes a ping control message. The optional message payload
// is UTF-8 encoded text.
PingMessage = 9
// PongMessage denotes a pong control message. The optional message payload
// is UTF-8 encoded text.
PongMessage = 10
)
type PingPongHandler func(string) error
type Client struct {
w *sync.Mutex
conn LongConn
PlatformID int `json:"platformID"`
IsCompress bool `json:"isCompress"`
UserID string `json:"userID"`
IsBackground bool `json:"isBackground"`
SDKType string `json:"sdkType"`
SDKVersion string `json:"sdkVersion"`
Encoder Encoder
ctx *UserConnContext
longConnServer LongConnServer
closed atomic.Bool
closedErr error
token string
hbCtx context.Context
hbCancel context.CancelFunc
subLock *sync.Mutex
subUserIDs map[string]struct{} // client conn subscription list
}
// ResetClient updates the client's state with new connection and context information.
func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, longConnServer LongConnServer) {
c.w = new(sync.Mutex)
c.conn = conn
c.PlatformID = stringutil.StringToInt(ctx.GetPlatformID())
c.IsCompress = ctx.GetCompression()
c.IsBackground = ctx.GetBackground()
c.UserID = ctx.GetUserID()
c.ctx = ctx
c.longConnServer = longConnServer
c.IsBackground = false
c.closed.Store(false)
c.closedErr = nil
c.token = ctx.GetToken()
c.SDKType = ctx.GetSDKType()
c.SDKVersion = ctx.GetSDKVersion()
c.hbCtx, c.hbCancel = context.WithCancel(c.ctx)
c.subLock = new(sync.Mutex)
if c.subUserIDs != nil {
clear(c.subUserIDs)
}
if c.SDKType == GoSDK {
c.Encoder = NewGobEncoder()
} else {
c.Encoder = NewJsonEncoder()
}
c.subUserIDs = make(map[string]struct{})
}
func (c *Client) pingHandler(appData string) error {
if err := c.conn.SetReadDeadline(pongWait); err != nil {
return err
}
log.ZDebug(c.ctx, "ping Handler Success.", "appData", appData)
return c.writePongMsg(appData)
}
func (c *Client) pongHandler(_ string) error {
if err := c.conn.SetReadDeadline(pongWait); err != nil {
return err
}
return nil
}
// readMessage continuously reads messages from the connection.
func (c *Client) readMessage() {
defer func() {
if r := recover(); r != nil {
c.closedErr = ErrPanic
log.ZPanic(c.ctx, "socket have panic err:", errs.ErrPanic(r))
}
c.close()
}()
c.conn.SetReadLimit(maxMessageSize)
_ = c.conn.SetReadDeadline(pongWait)
c.conn.SetPongHandler(c.pongHandler)
c.conn.SetPingHandler(c.pingHandler)
c.activeHeartbeat(c.hbCtx)
for {
log.ZDebug(c.ctx, "readMessage")
messageType, message, returnErr := c.conn.ReadMessage()
if returnErr != nil {
log.ZWarn(c.ctx, "readMessage", returnErr, "messageType", messageType)
c.closedErr = returnErr
return
}
log.ZDebug(c.ctx, "readMessage", "messageType", messageType)
if c.closed.Load() {
// The scenario where the connection has just been closed, but the coroutine has not exited
c.closedErr = ErrConnClosed
return
}
switch messageType {
case MessageBinary:
_ = c.conn.SetReadDeadline(pongWait)
parseDataErr := c.handleMessage(message)
if parseDataErr != nil {
c.closedErr = parseDataErr
return
}
case MessageText:
_ = c.conn.SetReadDeadline(pongWait)
parseDataErr := c.handlerTextMessage(message)
if parseDataErr != nil {
c.closedErr = parseDataErr
return
}
case PingMessage:
err := c.writePongMsg("")
log.ZError(c.ctx, "writePongMsg", err)
case CloseMessage:
c.closedErr = ErrClientClosed
return
default:
}
}
}
// handleMessage processes a single message received by the client.
func (c *Client) handleMessage(message []byte) error {
if c.IsCompress {
var err error
message, err = c.longConnServer.DecompressWithPool(message)
if err != nil {
return errs.Wrap(err)
}
}
var binaryReq = getReq()
defer freeReq(binaryReq)
err := c.Encoder.Decode(message, binaryReq)
if err != nil {
return err
}
if err := c.longConnServer.Validate(binaryReq); err != nil {
return err
}
if binaryReq.SendID != c.UserID {
return errs.New("exception conn userID not same to req userID", "binaryReq", binaryReq.String())
}
ctx := mcontext.WithMustInfoCtx(
[]string{binaryReq.OperationID, binaryReq.SendID, constant.PlatformIDToName(c.PlatformID), c.ctx.GetConnID()},
)
log.ZDebug(ctx, "gateway req message", "req", binaryReq.String())
var (
resp []byte
messageErr error
)
switch binaryReq.ReqIdentifier {
case WSGetNewestSeq:
resp, messageErr = c.longConnServer.GetSeq(ctx, binaryReq)
case WSSendMsg:
resp, messageErr = c.longConnServer.SendMessage(ctx, binaryReq)
case WSSendSignalMsg:
resp, messageErr = c.longConnServer.SendSignalMessage(ctx, binaryReq)
case WSPullMsgBySeqList:
resp, messageErr = c.longConnServer.PullMessageBySeqList(ctx, binaryReq)
case WSPullMsg:
resp, messageErr = c.longConnServer.GetSeqMessage(ctx, binaryReq)
case WSGetConvMaxReadSeq:
resp, messageErr = c.longConnServer.GetConversationsHasReadAndMaxSeq(ctx, binaryReq)
case WsPullConvLastMessage:
resp, messageErr = c.longConnServer.GetLastMessage(ctx, binaryReq)
case WsLogoutMsg:
resp, messageErr = c.longConnServer.UserLogout(ctx, binaryReq)
case WsSetBackgroundStatus:
resp, messageErr = c.setAppBackgroundStatus(ctx, binaryReq)
case WsSubUserOnlineStatus:
resp, messageErr = c.longConnServer.SubUserOnlineStatus(ctx, c, binaryReq)
default:
return fmt.Errorf(
"ReqIdentifier failed,sendID:%s,msgIncr:%s,reqIdentifier:%d",
binaryReq.SendID,
binaryReq.MsgIncr,
binaryReq.ReqIdentifier,
)
}
return c.replyMessage(ctx, binaryReq, messageErr, resp)
}
func (c *Client) setAppBackgroundStatus(ctx context.Context, req *Req) ([]byte, error) {
resp, isBackground, messageErr := c.longConnServer.SetUserDeviceBackground(ctx, req)
if messageErr != nil {
return nil, messageErr
}
c.IsBackground = isBackground
// TODO: callback
return resp, nil
}
func (c *Client) close() {
c.w.Lock()
defer c.w.Unlock()
if c.closed.Load() {
return
}
c.closed.Store(true)
c.conn.Close()
c.hbCancel() // Close server-initiated heartbeat.
c.longConnServer.UnRegister(c)
}
func (c *Client) replyMessage(ctx context.Context, binaryReq *Req, err error, resp []byte) error {
errResp := apiresp.ParseError(err)
mReply := Resp{
ReqIdentifier: binaryReq.ReqIdentifier,
MsgIncr: binaryReq.MsgIncr,
OperationID: binaryReq.OperationID,
ErrCode: errResp.ErrCode,
ErrMsg: errResp.ErrMsg,
Data: resp,
}
t := time.Now()
log.ZDebug(ctx, "gateway reply message", "resp", mReply.String())
err = c.writeBinaryMsg(mReply)
if err != nil {
log.ZWarn(ctx, "wireBinaryMsg replyMessage", err, "resp", mReply.String())
}
log.ZDebug(ctx, "wireBinaryMsg end", "time cost", time.Since(t))
if binaryReq.ReqIdentifier == WsLogoutMsg {
return errs.New("user logout", "operationID", binaryReq.OperationID).Wrap()
}
return nil
}
func (c *Client) PushMessage(ctx context.Context, msgData *sdkws.MsgData) error {
var msg sdkws.PushMessages
conversationID := msgprocessor.GetConversationIDByMsg(msgData)
m := map[string]*sdkws.PullMsgs{conversationID: {Msgs: []*sdkws.MsgData{msgData}}}
if msgprocessor.IsNotification(conversationID) {
msg.NotificationMsgs = m
} else {
msg.Msgs = m
}
log.ZDebug(ctx, "PushMessage", "msg", &msg)
data, err := proto.Marshal(&msg)
if err != nil {
return err
}
resp := Resp{
ReqIdentifier: WSPushMsg,
OperationID: mcontext.GetOperationID(ctx),
Data: data,
}
return c.writeBinaryMsg(resp)
}
func (c *Client) KickOnlineMessage() error {
resp := Resp{
ReqIdentifier: WSKickOnlineMsg,
}
log.ZDebug(c.ctx, "KickOnlineMessage debug ")
// 先发送踢掉通知消息
err := c.writeBinaryMsg(resp)
if err != nil {
log.ZWarn(c.ctx, "KickOnlineMessage writeBinaryMsg failed", err)
// 即使发送失败,也要关闭连接
c.close()
return err
}
// 给消息一些时间发送,然后关闭连接
// 使用一个小延迟确保消息已经写入缓冲区
time.Sleep(10 * time.Millisecond)
c.close()
return nil
}
func (c *Client) PushUserOnlineStatus(data []byte) error {
resp := Resp{
ReqIdentifier: WsSubUserOnlineStatus,
Data: data,
}
return c.writeBinaryMsg(resp)
}
func (c *Client) writeBinaryMsg(resp Resp) error {
if c.closed.Load() {
return nil
}
encodedBuf, err := c.Encoder.Encode(resp)
if err != nil {
return err
}
c.w.Lock()
defer c.w.Unlock()
err = c.conn.SetWriteDeadline(writeWait)
if err != nil {
return err
}
if c.IsCompress {
resultBuf, compressErr := c.longConnServer.CompressWithPool(encodedBuf)
if compressErr != nil {
return compressErr
}
return c.conn.WriteMessage(MessageBinary, resultBuf)
}
return c.conn.WriteMessage(MessageBinary, encodedBuf)
}
// Actively initiate Heartbeat when platform in Web.
func (c *Client) activeHeartbeat(ctx context.Context) {
if c.PlatformID == constant.WebPlatformID {
go func() {
defer func() {
if r := recover(); r != nil {
log.ZPanic(ctx, "activeHeartbeat Panic", errs.ErrPanic(r))
}
}()
log.ZDebug(ctx, "server initiative send heartbeat start.")
ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := c.writePingMsg(); err != nil {
log.ZWarn(c.ctx, "send Ping Message error.", err)
return
}
case <-c.hbCtx.Done():
return
}
}
}()
}
}
func (c *Client) writePingMsg() error {
if c.closed.Load() {
return nil
}
c.w.Lock()
defer c.w.Unlock()
err := c.conn.SetWriteDeadline(writeWait)
if err != nil {
return err
}
return c.conn.WriteMessage(PingMessage, nil)
}
func (c *Client) writePongMsg(appData string) error {
log.ZDebug(c.ctx, "write Pong Msg in Server", "appData", appData)
if c.closed.Load() {
log.ZWarn(c.ctx, "is closed in server", nil, "appdata", appData, "closed err", c.closedErr)
return nil
}
c.w.Lock()
defer c.w.Unlock()
err := c.conn.SetWriteDeadline(writeWait)
if err != nil {
log.ZWarn(c.ctx, "SetWriteDeadline in Server have error", errs.Wrap(err), "writeWait", writeWait, "appData", appData)
return errs.Wrap(err)
}
err = c.conn.WriteMessage(PongMessage, []byte(appData))
if err != nil {
log.ZWarn(c.ctx, "Write Message have error", errs.Wrap(err), "Pong msg", PongMessage)
}
return errs.Wrap(err)
}
func (c *Client) handlerTextMessage(b []byte) error {
var msg TextMessage
if err := json.Unmarshal(b, &msg); err != nil {
return err
}
switch msg.Type {
case TextPong:
return nil
case TextPing:
msg.Type = TextPong
msgData, err := json.Marshal(msg)
if err != nil {
return err
}
c.w.Lock()
defer c.w.Unlock()
if err := c.conn.SetWriteDeadline(writeWait); err != nil {
return err
}
return c.conn.WriteMessage(MessageText, msgData)
default:
return fmt.Errorf("not support message type %s", msg.Type)
}
}

View File

@@ -0,0 +1,113 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import (
"bytes"
"compress/gzip"
"io"
"sync"
"github.com/openimsdk/tools/errs"
)
var (
gzipWriterPool = sync.Pool{New: func() any { return gzip.NewWriter(nil) }}
gzipReaderPool = sync.Pool{New: func() any { return new(gzip.Reader) }}
)
type Compressor interface {
Compress(rawData []byte) ([]byte, error)
CompressWithPool(rawData []byte) ([]byte, error)
DeCompress(compressedData []byte) ([]byte, error)
DecompressWithPool(compressedData []byte) ([]byte, error)
}
type GzipCompressor struct {
compressProtocol string
}
func NewGzipCompressor() *GzipCompressor {
return &GzipCompressor{compressProtocol: "gzip"}
}
func (g *GzipCompressor) Compress(rawData []byte) ([]byte, error) {
gzipBuffer := bytes.Buffer{}
gz := gzip.NewWriter(&gzipBuffer)
if _, err := gz.Write(rawData); err != nil {
return nil, errs.WrapMsg(err, "GzipCompressor.Compress: writing to gzip writer failed")
}
if err := gz.Close(); err != nil {
return nil, errs.WrapMsg(err, "GzipCompressor.Compress: closing gzip writer failed")
}
return gzipBuffer.Bytes(), nil
}
func (g *GzipCompressor) CompressWithPool(rawData []byte) ([]byte, error) {
gz := gzipWriterPool.Get().(*gzip.Writer)
defer gzipWriterPool.Put(gz)
gzipBuffer := bytes.Buffer{}
gz.Reset(&gzipBuffer)
if _, err := gz.Write(rawData); err != nil {
return nil, errs.WrapMsg(err, "GzipCompressor.CompressWithPool: error writing data")
}
if err := gz.Close(); err != nil {
return nil, errs.WrapMsg(err, "GzipCompressor.CompressWithPool: error closing gzip writer")
}
return gzipBuffer.Bytes(), nil
}
func (g *GzipCompressor) DeCompress(compressedData []byte) ([]byte, error) {
buff := bytes.NewBuffer(compressedData)
reader, err := gzip.NewReader(buff)
if err != nil {
return nil, errs.WrapMsg(err, "GzipCompressor.DeCompress: NewReader creation failed")
}
decompressedData, err := io.ReadAll(reader)
if err != nil {
return nil, errs.WrapMsg(err, "GzipCompressor.DeCompress: reading from gzip reader failed")
}
if err = reader.Close(); err != nil {
// Even if closing the reader fails, we've successfully read the data,
// so we return the decompressed data and an error indicating the close failure.
return decompressedData, errs.WrapMsg(err, "GzipCompressor.DeCompress: closing gzip reader failed")
}
return decompressedData, nil
}
func (g *GzipCompressor) DecompressWithPool(compressedData []byte) ([]byte, error) {
reader := gzipReaderPool.Get().(*gzip.Reader)
defer gzipReaderPool.Put(reader)
err := reader.Reset(bytes.NewReader(compressedData))
if err != nil {
return nil, errs.WrapMsg(err, "GzipCompressor.DecompressWithPool: resetting gzip reader failed")
}
decompressedData, err := io.ReadAll(reader)
if err != nil {
return nil, errs.WrapMsg(err, "GzipCompressor.DecompressWithPool: reading from pooled gzip reader failed")
}
if err = reader.Close(); err != nil {
// Similar to DeCompress, return the data and error for close failure.
return decompressedData, errs.WrapMsg(err, "GzipCompressor.DecompressWithPool: closing pooled gzip reader failed")
}
return decompressedData, nil
}

View File

@@ -0,0 +1,139 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import (
"crypto/rand"
"github.com/stretchr/testify/assert"
"sync"
"testing"
"unsafe"
)
func mockRandom() []byte {
bs := make([]byte, 50)
rand.Read(bs)
return bs
}
func TestCompressDecompress(t *testing.T) {
compressor := NewGzipCompressor()
for i := 0; i < 2000; i++ {
src := mockRandom()
// compress
dest, err := compressor.CompressWithPool(src)
if err != nil {
t.Log(err)
}
assert.Equal(t, nil, err)
// decompress
res, err := compressor.DecompressWithPool(dest)
if err != nil {
t.Log(err)
}
assert.Equal(t, nil, err)
// check
assert.EqualValues(t, src, res)
}
}
func TestCompressDecompressWithConcurrency(t *testing.T) {
wg := sync.WaitGroup{}
compressor := NewGzipCompressor()
for i := 0; i < 200; i++ {
wg.Add(1)
go func() {
defer wg.Done()
src := mockRandom()
// compress
dest, err := compressor.CompressWithPool(src)
if err != nil {
t.Log(err)
}
assert.Equal(t, nil, err)
// decompress
res, err := compressor.DecompressWithPool(dest)
if err != nil {
t.Log(err)
}
assert.Equal(t, nil, err)
// check
assert.EqualValues(t, src, res)
}()
}
wg.Wait()
}
func BenchmarkCompress(b *testing.B) {
src := mockRandom()
compressor := NewGzipCompressor()
for i := 0; i < b.N; i++ {
_, err := compressor.Compress(src)
assert.Equal(b, nil, err)
}
}
func BenchmarkCompressWithSyncPool(b *testing.B) {
src := mockRandom()
compressor := NewGzipCompressor()
for i := 0; i < b.N; i++ {
_, err := compressor.CompressWithPool(src)
assert.Equal(b, nil, err)
}
}
func BenchmarkDecompress(b *testing.B) {
src := mockRandom()
compressor := NewGzipCompressor()
comdata, err := compressor.Compress(src)
assert.Equal(b, nil, err)
for i := 0; i < b.N; i++ {
_, err := compressor.DeCompress(comdata)
assert.Equal(b, nil, err)
}
}
func BenchmarkDecompressWithSyncPool(b *testing.B) {
src := mockRandom()
compressor := NewGzipCompressor()
comdata, err := compressor.Compress(src)
assert.Equal(b, nil, err)
for i := 0; i < b.N; i++ {
_, err := compressor.DecompressWithPool(comdata)
assert.Equal(b, nil, err)
}
}
func TestName(t *testing.T) {
t.Log(unsafe.Sizeof(Client{}))
}

View File

@@ -0,0 +1,72 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import "time"
const (
WsUserID = "sendID"
CommonUserID = "userID"
PlatformID = "platformID"
ConnID = "connID"
Token = "token"
OperationID = "operationID"
Compression = "compression"
GzipCompressionProtocol = "gzip"
BackgroundStatus = "isBackground"
SendResponse = "isMsgResp"
SDKType = "sdkType"
SDKVersion = "sdkVersion"
)
const (
GoSDK = "go"
JsSDK = "js"
)
const (
WebSocket = iota + 1
)
const (
// Websocket Protocol.
WSGetNewestSeq = 1001
WSPullMsgBySeqList = 1002
WSSendMsg = 1003
WSSendSignalMsg = 1004
WSPullMsg = 1005
WSGetConvMaxReadSeq = 1006
WsPullConvLastMessage = 1007
WSPushMsg = 2001
WSKickOnlineMsg = 2002
WsLogoutMsg = 2003
WsSetBackgroundStatus = 2004
WsSubUserOnlineStatus = 2005
WSDataError = 3001
)
const (
// Time allowed to write a message to the peer.
writeWait = 10 * time.Second
// Time allowed to read the next pong message from the peer.
pongWait = 30 * time.Second
// Send pings to peer with this period. Must be less than pongWait.
pingPeriod = (pongWait * 9) / 10
// Maximum message size allowed from peer.
maxMessageSize = 51200
)

View File

@@ -0,0 +1,216 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import (
"net/http"
"net/url"
"strconv"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/servererrs"
"git.imall.cloud/openim/protocol/constant"
"github.com/openimsdk/tools/utils/encrypt"
"github.com/openimsdk/tools/utils/stringutil"
"github.com/openimsdk/tools/utils/timeutil"
)
type UserConnContext struct {
RespWriter http.ResponseWriter
Req *http.Request
Path string
Method string
RemoteAddr string
ConnID string
}
func (c *UserConnContext) Deadline() (deadline time.Time, ok bool) {
return
}
func (c *UserConnContext) Done() <-chan struct{} {
return nil
}
func (c *UserConnContext) Err() error {
return nil
}
func (c *UserConnContext) Value(key any) any {
switch key {
case constant.OpUserID:
return c.GetUserID()
case constant.OperationID:
return c.GetOperationID()
case constant.ConnID:
return c.GetConnID()
case constant.OpUserPlatform:
return constant.PlatformIDToName(stringutil.StringToInt(c.GetPlatformID()))
case constant.RemoteAddr:
return c.RemoteAddr
default:
return ""
}
}
func newContext(respWriter http.ResponseWriter, req *http.Request) *UserConnContext {
remoteAddr := req.RemoteAddr
if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != "" {
remoteAddr += "_" + forwarded
}
return &UserConnContext{
RespWriter: respWriter,
Req: req,
Path: req.URL.Path,
Method: req.Method,
RemoteAddr: remoteAddr,
ConnID: encrypt.Md5(req.RemoteAddr + "_" + strconv.Itoa(int(timeutil.GetCurrentTimestampByMill()))),
}
}
func newTempContext() *UserConnContext {
return &UserConnContext{
Req: &http.Request{URL: &url.URL{}},
}
}
func (c *UserConnContext) GetRemoteAddr() string {
return c.RemoteAddr
}
func (c *UserConnContext) Query(key string) (string, bool) {
var value string
if value = c.Req.URL.Query().Get(key); value == "" {
return value, false
}
return value, true
}
func (c *UserConnContext) GetHeader(key string) (string, bool) {
var value string
if value = c.Req.Header.Get(key); value == "" {
return value, false
}
return value, true
}
func (c *UserConnContext) SetHeader(key, value string) {
c.RespWriter.Header().Set(key, value)
}
func (c *UserConnContext) ErrReturn(error string, code int) {
http.Error(c.RespWriter, error, code)
}
func (c *UserConnContext) GetConnID() string {
return c.ConnID
}
func (c *UserConnContext) GetUserID() string {
return c.Req.URL.Query().Get(WsUserID)
}
func (c *UserConnContext) GetPlatformID() string {
return c.Req.URL.Query().Get(PlatformID)
}
func (c *UserConnContext) GetOperationID() string {
return c.Req.URL.Query().Get(OperationID)
}
func (c *UserConnContext) SetOperationID(operationID string) {
values := c.Req.URL.Query()
values.Set(OperationID, operationID)
c.Req.URL.RawQuery = values.Encode()
}
func (c *UserConnContext) GetToken() string {
return c.Req.URL.Query().Get(Token)
}
func (c *UserConnContext) GetSDKVersion() string {
return c.Req.URL.Query().Get(SDKVersion)
}
func (c *UserConnContext) GetCompression() bool {
compression, exists := c.Query(Compression)
if exists && compression == GzipCompressionProtocol {
return true
} else {
compression, exists := c.GetHeader(Compression)
if exists && compression == GzipCompressionProtocol {
return true
}
}
return false
}
func (c *UserConnContext) GetSDKType() string {
sdkType := c.Req.URL.Query().Get(SDKType)
if sdkType == "" {
sdkType = GoSDK
}
return sdkType
}
func (c *UserConnContext) ShouldSendResp() bool {
errResp, exists := c.Query(SendResponse)
if exists {
b, err := strconv.ParseBool(errResp)
if err != nil {
return false
} else {
return b
}
}
return false
}
func (c *UserConnContext) SetToken(token string) {
c.Req.URL.RawQuery = Token + "=" + token
}
func (c *UserConnContext) GetBackground() bool {
b, err := strconv.ParseBool(c.Req.URL.Query().Get(BackgroundStatus))
if err != nil {
return false
}
return b
}
func (c *UserConnContext) ParseEssentialArgs() error {
_, exists := c.Query(Token)
if !exists {
return servererrs.ErrConnArgsErr.WrapMsg("token is empty")
}
_, exists = c.Query(WsUserID)
if !exists {
return servererrs.ErrConnArgsErr.WrapMsg("sendID is empty")
}
platformIDStr, exists := c.Query(PlatformID)
if !exists {
return servererrs.ErrConnArgsErr.WrapMsg("platformID is empty")
}
_, err := strconv.Atoi(platformIDStr)
if err != nil {
return servererrs.ErrConnArgsErr.WrapMsg("platformID is not int")
}
switch sdkType, _ := c.Query(SDKType); sdkType {
case "", GoSDK, JsSDK:
default:
return servererrs.ErrConnArgsErr.WrapMsg("sdkType is not go or js")
}
return nil
}

View File

@@ -0,0 +1,74 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import (
"bytes"
"encoding/gob"
"encoding/json"
"github.com/openimsdk/tools/errs"
)
type Encoder interface {
Encode(data any) ([]byte, error)
Decode(encodeData []byte, decodeData any) error
}
type GobEncoder struct{}
func NewGobEncoder() Encoder {
return GobEncoder{}
}
func (g GobEncoder) Encode(data any) ([]byte, error) {
var buff bytes.Buffer
enc := gob.NewEncoder(&buff)
if err := enc.Encode(data); err != nil {
return nil, errs.WrapMsg(err, "GobEncoder.Encode failed", "action", "encode")
}
return buff.Bytes(), nil
}
func (g GobEncoder) Decode(encodeData []byte, decodeData any) error {
buff := bytes.NewBuffer(encodeData)
dec := gob.NewDecoder(buff)
if err := dec.Decode(decodeData); err != nil {
return errs.WrapMsg(err, "GobEncoder.Decode failed", "action", "decode")
}
return nil
}
type JsonEncoder struct{}
func NewJsonEncoder() Encoder {
return JsonEncoder{}
}
func (g JsonEncoder) Encode(data any) ([]byte, error) {
b, err := json.Marshal(data)
if err != nil {
return nil, errs.New("JsonEncoder.Encode failed", "action", "encode")
}
return b, nil
}
func (g JsonEncoder) Decode(encodeData []byte, decodeData any) error {
err := json.Unmarshal(encodeData, decodeData)
if err != nil {
return errs.New("JsonEncoder.Decode failed", "action", "decode")
}
return nil
}

View File

@@ -0,0 +1,25 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import (
"github.com/openimsdk/tools/apiresp"
"github.com/openimsdk/tools/log"
)
func httpError(ctx *UserConnContext, err error) {
log.ZWarn(ctx, "ws connection error", err)
apiresp.HttpError(ctx.RespWriter, err)
}

View File

@@ -0,0 +1,264 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import (
"context"
"sync/atomic"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/servererrs"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/msggateway"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/mq/memamq"
"github.com/openimsdk/tools/utils/datautil"
"google.golang.org/grpc"
)
func (s *Server) InitServer(ctx context.Context, config *Config, disCov discovery.Conn, server grpc.ServiceRegistrar) error {
userConn, err := disCov.GetConn(ctx, config.Discovery.RpcService.User)
if err != nil {
return err
}
s.userClient = rpcli.NewUserClient(userConn)
if err := s.LongConnServer.SetDiscoveryRegistry(ctx, disCov, config); err != nil {
return err
}
msggateway.RegisterMsgGatewayServer(server, s)
if s.ready != nil {
return s.ready(s)
}
return nil
}
//func (s *Server) Start(ctx context.Context, index int, conf *Config) error {
// return startrpc.Start(ctx, &conf.Discovery, &conf.MsgGateway.Prometheus, conf.MsgGateway.ListenIP,
// conf.MsgGateway.RPC.RegisterIP,
// conf.MsgGateway.RPC.AutoSetPorts, conf.MsgGateway.RPC.Ports, index,
// conf.Discovery.RpcService.MessageGateway,
// nil,
// conf,
// []string{
// conf.Share.GetConfigFileName(),
// conf.Discovery.GetConfigFileName(),
// conf.MsgGateway.GetConfigFileName(),
// conf.WebhooksConfig.GetConfigFileName(),
// conf.RedisConfig.GetConfigFileName(),
// },
// []string{
// conf.Discovery.RpcService.MessageGateway,
// },
// s.InitServer,
// )
//}
type Server struct {
msggateway.UnimplementedMsgGatewayServer
LongConnServer LongConnServer
config *Config
pushTerminal map[int]struct{}
ready func(srv *Server) error
queue *memamq.MemoryQueue
userClient *rpcli.UserClient
}
func (s *Server) SetLongConnServer(LongConnServer LongConnServer) {
s.LongConnServer = LongConnServer
}
func NewServer(longConnServer LongConnServer, conf *Config, ready func(srv *Server) error) *Server {
s := &Server{
LongConnServer: longConnServer,
pushTerminal: make(map[int]struct{}),
config: conf,
ready: ready,
queue: memamq.NewMemoryQueue(512, 1024*16),
}
s.pushTerminal[constant.IOSPlatformID] = struct{}{}
s.pushTerminal[constant.AndroidPlatformID] = struct{}{}
s.pushTerminal[constant.WebPlatformID] = struct{}{}
return s
}
func (s *Server) GetUsersOnlineStatus(ctx context.Context, req *msggateway.GetUsersOnlineStatusReq) (*msggateway.GetUsersOnlineStatusResp, error) {
if !authverify.IsAdmin(ctx) {
return nil, errs.ErrNoPermission.WrapMsg("only app manager")
}
var resp msggateway.GetUsersOnlineStatusResp
for _, userID := range req.UserIDs {
clients, ok := s.LongConnServer.GetUserAllCons(userID)
if !ok {
continue
}
uresp := new(msggateway.GetUsersOnlineStatusResp_SuccessResult)
uresp.UserID = userID
for _, client := range clients {
if client == nil {
continue
}
ps := new(msggateway.GetUsersOnlineStatusResp_SuccessDetail)
ps.PlatformID = int32(client.PlatformID)
ps.ConnID = client.ctx.GetConnID()
ps.Token = client.token
ps.IsBackground = client.IsBackground
uresp.Status = constant.Online
uresp.DetailPlatformStatus = append(uresp.DetailPlatformStatus, ps)
}
if uresp.Status == constant.Online {
resp.SuccessResult = append(resp.SuccessResult, uresp)
}
}
return &resp, nil
}
func (s *Server) pushToUser(ctx context.Context, userID string, msgData *sdkws.MsgData) *msggateway.SingleMsgToUserResults {
clients, ok := s.LongConnServer.GetUserAllCons(userID)
if !ok {
log.ZDebug(ctx, "push user not online", "userID", userID)
return &msggateway.SingleMsgToUserResults{
UserID: userID,
}
}
log.ZDebug(ctx, "push user online", "clients", clients, "userID", userID)
result := &msggateway.SingleMsgToUserResults{
UserID: userID,
Resp: make([]*msggateway.SingleMsgToUserPlatform, 0, len(clients)),
}
for _, client := range clients {
if client == nil {
continue
}
userPlatform := &msggateway.SingleMsgToUserPlatform{
RecvPlatFormID: int32(client.PlatformID),
}
if !client.IsBackground ||
(client.IsBackground && client.PlatformID != constant.IOSPlatformID) {
err := client.PushMessage(ctx, msgData)
if err != nil {
log.ZWarn(ctx, "online push msg failed", err, "userID", userID, "platformID", client.PlatformID)
userPlatform.ResultCode = int64(servererrs.ErrPushMsgErr.Code())
} else {
if _, ok := s.pushTerminal[client.PlatformID]; ok {
result.OnlinePush = true
}
}
} else {
userPlatform.ResultCode = int64(servererrs.ErrIOSBackgroundPushErr.Code())
}
result.Resp = append(result.Resp, userPlatform)
}
return result
}
func (s *Server) SuperGroupOnlineBatchPushOneMsg(ctx context.Context, req *msggateway.OnlineBatchPushOneMsgReq) (*msggateway.OnlineBatchPushOneMsgResp, error) {
if len(req.PushToUserIDs) == 0 {
return &msggateway.OnlineBatchPushOneMsgResp{}, nil
}
ch := make(chan *msggateway.SingleMsgToUserResults, len(req.PushToUserIDs))
var count atomic.Int64
count.Add(int64(len(req.PushToUserIDs)))
for i := range req.PushToUserIDs {
userID := req.PushToUserIDs[i]
err := s.queue.PushCtx(ctx, func() {
ch <- s.pushToUser(ctx, userID, req.MsgData)
if count.Add(-1) == 0 {
close(ch)
}
})
if err != nil {
if count.Add(-1) == 0 {
close(ch)
}
log.ZError(ctx, "pushToUser MemoryQueue failed", err, "userID", userID)
ch <- &msggateway.SingleMsgToUserResults{
UserID: userID,
}
}
}
resp := &msggateway.OnlineBatchPushOneMsgResp{
SinglePushResult: make([]*msggateway.SingleMsgToUserResults, 0, len(req.PushToUserIDs)),
}
for {
select {
case <-ctx.Done():
log.ZError(ctx, "SuperGroupOnlineBatchPushOneMsg ctx done", context.Cause(ctx))
userIDSet := datautil.SliceSet(req.PushToUserIDs)
for _, results := range resp.SinglePushResult {
delete(userIDSet, results.UserID)
}
for userID := range userIDSet {
resp.SinglePushResult = append(resp.SinglePushResult, &msggateway.SingleMsgToUserResults{
UserID: userID,
})
}
return resp, nil
case res, ok := <-ch:
if !ok {
return resp, nil
}
resp.SinglePushResult = append(resp.SinglePushResult, res)
}
}
}
func (s *Server) KickUserOffline(ctx context.Context, req *msggateway.KickUserOfflineReq) (*msggateway.KickUserOfflineResp, error) {
for _, v := range req.KickUserIDList {
clients, _, ok := s.LongConnServer.GetUserPlatformCons(v, int(req.PlatformID))
if !ok {
log.ZDebug(ctx, "conn not exist", "userID", v, "platformID", req.PlatformID)
continue
}
for _, client := range clients {
log.ZDebug(ctx, "kick user offline", "userID", v, "platformID", req.PlatformID, "client", client)
if err := client.longConnServer.KickUserConn(client); err != nil {
log.ZWarn(ctx, "kick user offline failed", err, "userID", v, "platformID", req.PlatformID)
}
}
continue
}
return &msggateway.KickUserOfflineResp{}, nil
}
func (s *Server) MultiTerminalLoginCheck(ctx context.Context, req *msggateway.MultiTerminalLoginCheckReq) (*msggateway.MultiTerminalLoginCheckResp, error) {
if oldClients, userOK, clientOK := s.LongConnServer.GetUserPlatformCons(req.UserID, int(req.PlatformID)); userOK {
tempUserCtx := newTempContext()
tempUserCtx.SetToken(req.Token)
tempUserCtx.SetOperationID(mcontext.GetOperationID(ctx))
client := &Client{}
client.ctx = tempUserCtx
client.token = req.Token
client.UserID = req.UserID
client.PlatformID = int(req.PlatformID)
i := &kickHandler{
clientOK: clientOK,
oldClients: oldClients,
newClient: client,
}
s.LongConnServer.SetKickHandlerInfo(i)
}
return &msggateway.MultiTerminalLoginCheckResp{}, nil
}

117
internal/msggateway/init.go Normal file
View File

@@ -0,0 +1,117 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/dbbuild"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpccache"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/utils/datautil"
"github.com/openimsdk/tools/utils/runtimeenv"
"google.golang.org/grpc"
"github.com/openimsdk/tools/log"
)
type Config struct {
MsgGateway config.MsgGateway
Share config.Share
RedisConfig config.Redis
WebhooksConfig config.Webhooks
Discovery config.Discovery
Index config.Index
}
// Start run ws server.
func Start(ctx context.Context, conf *Config, client discovery.SvcDiscoveryRegistry, server grpc.ServiceRegistrar) error {
log.CInfo(ctx, "MSG-GATEWAY server is initializing", "runtimeEnv", runtimeenv.RuntimeEnvironment(),
"rpcPorts", conf.MsgGateway.RPC.Ports,
"wsPort", conf.MsgGateway.LongConnSvr.Ports, "prometheusPorts", conf.MsgGateway.Prometheus.Ports)
wsPort, err := datautil.GetElemByIndex(conf.MsgGateway.LongConnSvr.Ports, int(conf.Index))
if err != nil {
return err
}
dbb := dbbuild.NewBuilder(nil, &conf.RedisConfig)
rdb, err := dbb.Redis(ctx)
if err != nil {
return err
}
longServer := NewWsServer(
conf,
WithPort(wsPort),
WithMaxConnNum(int64(conf.MsgGateway.LongConnSvr.WebsocketMaxConnNum)),
WithHandshakeTimeout(time.Duration(conf.MsgGateway.LongConnSvr.WebsocketTimeout)*time.Second),
WithMessageMaxMsgLength(conf.MsgGateway.LongConnSvr.WebsocketMaxMsgLen),
)
hubServer := NewServer(longServer, conf, func(srv *Server) error {
var err error
longServer.online, err = rpccache.NewOnlineCache(srv.userClient, nil, rdb, false, longServer.subscriberUserOnlineStatusChanges)
return err
})
if err := hubServer.InitServer(ctx, conf, client, server); err != nil {
return err
}
go longServer.ChangeOnlineStatus(4)
return hubServer.LongConnServer.Run(ctx)
}
//
//// Start run ws server.
//func Start(ctx context.Context, index int, conf *Config) error {
// log.CInfo(ctx, "MSG-GATEWAY server is initializing", "runtimeEnv", runtimeenv.RuntimeEnvironment(),
// "rpcPorts", conf.MsgGateway.RPC.Ports,
// "wsPort", conf.MsgGateway.LongConnSvr.Ports, "prometheusPorts", conf.MsgGateway.Prometheus.Ports)
// wsPort, err := datautil.GetElemByIndex(conf.MsgGateway.LongConnSvr.Ports, index)
// if err != nil {
// return err
// }
//
// rdb, err := redisutil.NewRedisClient(ctx, conf.RedisConfig.Build())
// if err != nil {
// return err
// }
// longServer := NewWsServer(
// conf,
// WithPort(wsPort),
// WithMaxConnNum(int64(conf.MsgGateway.LongConnSvr.WebsocketMaxConnNum)),
// WithHandshakeTimeout(time.Duration(conf.MsgGateway.LongConnSvr.WebsocketTimeout)*time.Second),
// WithMessageMaxMsgLength(conf.MsgGateway.LongConnSvr.WebsocketMaxMsgLen),
// )
//
// hubServer := NewServer(longServer, conf, func(srv *Server) error {
// var err error
// longServer.online, err = rpccache.NewOnlineCache(srv.userClient, nil, rdb, false, longServer.subscriberUserOnlineStatusChanges)
// return err
// })
//
// go longServer.ChangeOnlineStatus(4)
//
// netDone := make(chan error)
// go func() {
// err = hubServer.Start(ctx, index, conf)
// netDone <- err
// }()
// return hubServer.LongConnServer.Run(netDone)
//}

View File

@@ -0,0 +1,179 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import (
"encoding/json"
"net/http"
"time"
"github.com/openimsdk/tools/apiresp"
"github.com/gorilla/websocket"
"github.com/openimsdk/tools/errs"
)
type LongConn interface {
// Close this connection
Close() error
// WriteMessage Write message to connection,messageType means data type,can be set binary(2) and text(1).
WriteMessage(messageType int, message []byte) error
// ReadMessage Read message from connection.
ReadMessage() (int, []byte, error)
// SetReadDeadline sets the read deadline on the underlying network connection,
// after a read has timed out, will return an error.
SetReadDeadline(timeout time.Duration) error
// SetWriteDeadline sets to write deadline when send message,when read has timed out,will return error.
SetWriteDeadline(timeout time.Duration) error
// Dial Try to dial a connection,url must set auth args,header can control compress data
Dial(urlStr string, requestHeader http.Header) (*http.Response, error)
// IsNil Whether the connection of the current long connection is nil
IsNil() bool
// SetConnNil Set the connection of the current long connection to nil
SetConnNil()
// SetReadLimit sets the maximum size for a message read from the peer.bytes
SetReadLimit(limit int64)
SetPongHandler(handler PingPongHandler)
SetPingHandler(handler PingPongHandler)
// GenerateLongConn Check the connection of the current and when it was sent are the same
GenerateLongConn(w http.ResponseWriter, r *http.Request) error
}
type GWebSocket struct {
protocolType int
conn *websocket.Conn
handshakeTimeout time.Duration
writeBufferSize int
}
func newGWebSocket(protocolType int, handshakeTimeout time.Duration, wbs int) *GWebSocket {
return &GWebSocket{protocolType: protocolType, handshakeTimeout: handshakeTimeout, writeBufferSize: wbs}
}
func (d *GWebSocket) Close() error {
return d.conn.Close()
}
func (d *GWebSocket) GenerateLongConn(w http.ResponseWriter, r *http.Request) error {
upgrader := &websocket.Upgrader{
HandshakeTimeout: d.handshakeTimeout,
CheckOrigin: func(r *http.Request) bool { return true },
}
if d.writeBufferSize > 0 { // default is 4kb.
upgrader.WriteBufferSize = d.writeBufferSize
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
// The upgrader.Upgrade method usually returns enough error messages to diagnose problems that may occur during the upgrade
return errs.WrapMsg(err, "GenerateLongConn: WebSocket upgrade failed")
}
d.conn = conn
return nil
}
func (d *GWebSocket) WriteMessage(messageType int, message []byte) error {
// d.setSendConn(d.conn)
return d.conn.WriteMessage(messageType, message)
}
// func (d *GWebSocket) setSendConn(sendConn *websocket.Conn) {
// d.sendConn = sendConn
//}
func (d *GWebSocket) ReadMessage() (int, []byte, error) {
return d.conn.ReadMessage()
}
func (d *GWebSocket) SetReadDeadline(timeout time.Duration) error {
return d.conn.SetReadDeadline(time.Now().Add(timeout))
}
func (d *GWebSocket) SetWriteDeadline(timeout time.Duration) error {
if timeout <= 0 {
return errs.New("timeout must be greater than 0")
}
// TODO SetWriteDeadline Future add error handling
if err := d.conn.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
return errs.WrapMsg(err, "GWebSocket.SetWriteDeadline failed")
}
return nil
}
func (d *GWebSocket) Dial(urlStr string, requestHeader http.Header) (*http.Response, error) {
conn, httpResp, err := websocket.DefaultDialer.Dial(urlStr, requestHeader)
if err != nil {
return httpResp, errs.WrapMsg(err, "GWebSocket.Dial failed", "url", urlStr)
}
d.conn = conn
return httpResp, nil
}
func (d *GWebSocket) IsNil() bool {
return d.conn == nil
//
// if d.conn != nil {
// return false
// }
// return true
}
func (d *GWebSocket) SetConnNil() {
d.conn = nil
}
func (d *GWebSocket) SetReadLimit(limit int64) {
d.conn.SetReadLimit(limit)
}
func (d *GWebSocket) SetPongHandler(handler PingPongHandler) {
d.conn.SetPongHandler(handler)
}
func (d *GWebSocket) SetPingHandler(handler PingPongHandler) {
d.conn.SetPingHandler(handler)
}
func (d *GWebSocket) RespondWithError(err error, w http.ResponseWriter, r *http.Request) error {
if err := d.GenerateLongConn(w, r); err != nil {
return err
}
data, err := json.Marshal(apiresp.ParseError(err))
if err != nil {
_ = d.Close()
return errs.WrapMsg(err, "json marshal failed")
}
if err := d.WriteMessage(MessageText, data); err != nil {
_ = d.Close()
return errs.WrapMsg(err, "WriteMessage failed")
}
_ = d.Close()
return nil
}
func (d *GWebSocket) RespondWithSuccess() error {
data, err := json.Marshal(apiresp.ParseError(nil))
if err != nil {
_ = d.Close()
return errs.WrapMsg(err, "json marshal failed")
}
if err := d.WriteMessage(MessageText, data); err != nil {
_ = d.Close()
return errs.WrapMsg(err, "WriteMessage failed")
}
return nil
}

View File

@@ -0,0 +1,282 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import (
"context"
"encoding/json"
"sync"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
"github.com/go-playground/validator/v10"
"google.golang.org/protobuf/proto"
"git.imall.cloud/openim/protocol/msg"
"git.imall.cloud/openim/protocol/push"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/utils/jsonutil"
)
const (
TextPing = "ping"
TextPong = "pong"
)
type TextMessage struct {
Type string `json:"type"`
Body json.RawMessage `json:"body"`
}
type Req struct {
ReqIdentifier int32 `json:"reqIdentifier" validate:"required"`
Token string `json:"token"`
SendID string `json:"sendID" validate:"required"`
OperationID string `json:"operationID" validate:"required"`
MsgIncr string `json:"msgIncr" validate:"required"`
Data []byte `json:"data"`
}
func (r *Req) String() string {
var tReq Req
tReq.ReqIdentifier = r.ReqIdentifier
tReq.Token = r.Token
tReq.SendID = r.SendID
tReq.OperationID = r.OperationID
tReq.MsgIncr = r.MsgIncr
return jsonutil.StructToJsonString(tReq)
}
var reqPool = sync.Pool{
New: func() any {
return new(Req)
},
}
func getReq() *Req {
req := reqPool.Get().(*Req)
req.Data = nil
req.MsgIncr = ""
req.OperationID = ""
req.ReqIdentifier = 0
req.SendID = ""
req.Token = ""
return req
}
func freeReq(req *Req) {
reqPool.Put(req)
}
type Resp struct {
ReqIdentifier int32 `json:"reqIdentifier"`
MsgIncr string `json:"msgIncr"`
OperationID string `json:"operationID"`
ErrCode int `json:"errCode"`
ErrMsg string `json:"errMsg"`
Data []byte `json:"data"`
}
func (r *Resp) String() string {
var tResp Resp
tResp.ReqIdentifier = r.ReqIdentifier
tResp.MsgIncr = r.MsgIncr
tResp.OperationID = r.OperationID
tResp.ErrCode = r.ErrCode
tResp.ErrMsg = r.ErrMsg
return jsonutil.StructToJsonString(tResp)
}
type MessageHandler interface {
GetSeq(ctx context.Context, data *Req) ([]byte, error)
SendMessage(ctx context.Context, data *Req) ([]byte, error)
SendSignalMessage(ctx context.Context, data *Req) ([]byte, error)
PullMessageBySeqList(ctx context.Context, data *Req) ([]byte, error)
GetConversationsHasReadAndMaxSeq(ctx context.Context, data *Req) ([]byte, error)
GetSeqMessage(ctx context.Context, data *Req) ([]byte, error)
UserLogout(ctx context.Context, data *Req) ([]byte, error)
SetUserDeviceBackground(ctx context.Context, data *Req) ([]byte, bool, error)
GetLastMessage(ctx context.Context, data *Req) ([]byte, error)
}
var _ MessageHandler = (*GrpcHandler)(nil)
type GrpcHandler struct {
validate *validator.Validate
msgClient *rpcli.MsgClient
pushClient *rpcli.PushMsgServiceClient
}
func NewGrpcHandler(validate *validator.Validate, msgClient *rpcli.MsgClient, pushClient *rpcli.PushMsgServiceClient) *GrpcHandler {
return &GrpcHandler{
validate: validate,
msgClient: msgClient,
pushClient: pushClient,
}
}
func (g *GrpcHandler) GetSeq(ctx context.Context, data *Req) ([]byte, error) {
req := sdkws.GetMaxSeqReq{}
if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, errs.WrapMsg(err, "GetSeq: error unmarshaling request", "action", "unmarshal", "dataType", "GetMaxSeqReq")
}
if err := g.validate.Struct(&req); err != nil {
return nil, errs.WrapMsg(err, "GetSeq: validation failed", "action", "validate", "dataType", "GetMaxSeqReq")
}
resp, err := g.msgClient.MsgClient.GetMaxSeq(ctx, &req)
if err != nil {
return nil, err
}
c, err := proto.Marshal(resp)
if err != nil {
return nil, errs.WrapMsg(err, "GetSeq: error marshaling response", "action", "marshal", "dataType", "GetMaxSeqResp")
}
return c, nil
}
// SendMessage handles the sending of messages through gRPC. It unmarshals the request data,
// validates the message, and then sends it using the message RPC client.
func (g *GrpcHandler) SendMessage(ctx context.Context, data *Req) ([]byte, error) {
var msgData sdkws.MsgData
if err := proto.Unmarshal(data.Data, &msgData); err != nil {
return nil, errs.WrapMsg(err, "SendMessage: error unmarshaling message data", "action", "unmarshal", "dataType", "MsgData")
}
if err := g.validate.Struct(&msgData); err != nil {
return nil, errs.WrapMsg(err, "SendMessage: message data validation failed", "action", "validate", "dataType", "MsgData")
}
req := msg.SendMsgReq{MsgData: &msgData}
resp, err := g.msgClient.MsgClient.SendMsg(ctx, &req)
if err != nil {
return nil, err
}
c, err := proto.Marshal(resp)
if err != nil {
return nil, errs.WrapMsg(err, "SendMessage: error marshaling response", "action", "marshal", "dataType", "SendMsgResp")
}
return c, nil
}
func (g *GrpcHandler) SendSignalMessage(ctx context.Context, data *Req) ([]byte, error) {
resp, err := g.msgClient.MsgClient.SendMsg(ctx, nil)
if err != nil {
return nil, err
}
c, err := proto.Marshal(resp)
if err != nil {
return nil, errs.WrapMsg(err, "error marshaling response", "action", "marshal", "dataType", "SendMsgResp")
}
return c, nil
}
func (g *GrpcHandler) PullMessageBySeqList(ctx context.Context, data *Req) ([]byte, error) {
req := sdkws.PullMessageBySeqsReq{}
if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, errs.WrapMsg(err, "err proto unmarshal", "action", "unmarshal", "dataType", "PullMessageBySeqsReq")
}
if err := g.validate.Struct(data); err != nil {
return nil, errs.WrapMsg(err, "validation failed", "action", "validate", "dataType", "PullMessageBySeqsReq")
}
resp, err := g.msgClient.MsgClient.PullMessageBySeqs(ctx, &req)
if err != nil {
return nil, err
}
c, err := proto.Marshal(resp)
if err != nil {
return nil, errs.WrapMsg(err, "error marshaling response", "action", "marshal", "dataType", "PullMessageBySeqsResp")
}
return c, nil
}
func (g *GrpcHandler) GetConversationsHasReadAndMaxSeq(ctx context.Context, data *Req) ([]byte, error) {
req := msg.GetConversationsHasReadAndMaxSeqReq{}
if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, errs.WrapMsg(err, "err proto unmarshal", "action", "unmarshal", "dataType", "GetConversationsHasReadAndMaxSeq")
}
if err := g.validate.Struct(data); err != nil {
return nil, errs.WrapMsg(err, "validation failed", "action", "validate", "dataType", "GetConversationsHasReadAndMaxSeq")
}
resp, err := g.msgClient.MsgClient.GetConversationsHasReadAndMaxSeq(ctx, &req)
if err != nil {
return nil, err
}
c, err := proto.Marshal(resp)
if err != nil {
return nil, errs.WrapMsg(err, "error marshaling response", "action", "marshal", "dataType", "GetConversationsHasReadAndMaxSeq")
}
return c, nil
}
func (g *GrpcHandler) GetSeqMessage(ctx context.Context, data *Req) ([]byte, error) {
req := msg.GetSeqMessageReq{}
if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, errs.WrapMsg(err, "error unmarshaling request", "action", "unmarshal", "dataType", "GetSeqMessage")
}
if err := g.validate.Struct(data); err != nil {
return nil, errs.WrapMsg(err, "validation failed", "action", "validate", "dataType", "GetSeqMessage")
}
resp, err := g.msgClient.MsgClient.GetSeqMessage(ctx, &req)
if err != nil {
return nil, err
}
c, err := proto.Marshal(resp)
if err != nil {
return nil, errs.WrapMsg(err, "error marshaling response", "action", "marshal", "dataType", "GetSeqMessage")
}
return c, nil
}
func (g *GrpcHandler) UserLogout(ctx context.Context, data *Req) ([]byte, error) {
req := push.DelUserPushTokenReq{}
if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, errs.WrapMsg(err, "error unmarshaling request", "action", "unmarshal", "dataType", "DelUserPushTokenReq")
}
resp, err := g.pushClient.PushMsgServiceClient.DelUserPushToken(ctx, &req)
if err != nil {
return nil, err
}
c, err := proto.Marshal(resp)
if err != nil {
return nil, errs.WrapMsg(err, "error marshaling response", "action", "marshal", "dataType", "DelUserPushTokenResp")
}
return c, nil
}
func (g *GrpcHandler) SetUserDeviceBackground(ctx context.Context, data *Req) ([]byte, bool, error) {
req := sdkws.SetAppBackgroundStatusReq{}
if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, false, errs.WrapMsg(err, "error unmarshaling request", "action", "unmarshal", "dataType", "SetAppBackgroundStatusReq")
}
if err := g.validate.Struct(data); err != nil {
return nil, false, errs.WrapMsg(err, "validation failed", "action", "validate", "dataType", "SetAppBackgroundStatusReq")
}
return nil, req.IsBackground, nil
}
func (g *GrpcHandler) GetLastMessage(ctx context.Context, data *Req) ([]byte, error) {
var req msg.GetLastMessageReq
if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, err
}
resp, err := g.msgClient.GetLastMessage(ctx, &req)
if err != nil {
return nil, err
}
return proto.Marshal(resp)
}

View File

@@ -0,0 +1,131 @@
package msggateway
import (
"context"
"crypto/md5"
"encoding/binary"
"fmt"
"math/rand"
"os"
"strconv"
"sync/atomic"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/cachekey"
pbuser "git.imall.cloud/openim/protocol/user"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/datautil"
)
func (ws *WsServer) ChangeOnlineStatus(concurrent int) {
if concurrent < 1 {
concurrent = 1
}
const renewalTime = cachekey.OnlineExpire / 3
//const renewalTime = time.Second * 10
renewalTicker := time.NewTicker(renewalTime)
requestChs := make([]chan *pbuser.SetUserOnlineStatusReq, concurrent)
changeStatus := make([][]UserState, concurrent)
for i := 0; i < concurrent; i++ {
requestChs[i] = make(chan *pbuser.SetUserOnlineStatusReq, 64)
changeStatus[i] = make([]UserState, 0, 100)
}
mergeTicker := time.NewTicker(time.Second)
local2pb := func(u UserState) *pbuser.UserOnlineStatus {
return &pbuser.UserOnlineStatus{
UserID: u.UserID,
Online: u.Online,
Offline: u.Offline,
}
}
rNum := rand.Uint64()
pushUserState := func(us ...UserState) {
for _, u := range us {
sum := md5.Sum([]byte(u.UserID))
i := (binary.BigEndian.Uint64(sum[:]) + rNum) % uint64(concurrent)
changeStatus[i] = append(changeStatus[i], u)
status := changeStatus[i]
if len(status) == cap(status) {
req := &pbuser.SetUserOnlineStatusReq{
Status: datautil.Slice(status, local2pb),
}
changeStatus[i] = status[:0]
select {
case requestChs[i] <- req:
default:
log.ZError(context.Background(), "user online processing is too slow", nil)
}
}
}
}
pushAllUserState := func() {
for i, status := range changeStatus {
if len(status) == 0 {
continue
}
req := &pbuser.SetUserOnlineStatusReq{
Status: datautil.Slice(status, local2pb),
}
changeStatus[i] = status[:0]
select {
case requestChs[i] <- req:
default:
log.ZError(context.Background(), "user online processing is too slow", nil)
}
}
}
var count atomic.Int64
operationIDPrefix := fmt.Sprintf("p_%d_", os.Getpid())
doRequest := func(req *pbuser.SetUserOnlineStatusReq) {
opIdCtx := mcontext.SetOperationID(context.Background(), operationIDPrefix+strconv.FormatInt(count.Add(1), 10))
ctx, cancel := context.WithTimeout(opIdCtx, time.Second*5)
defer cancel()
if err := ws.userClient.SetUserOnlineStatus(ctx, req); err != nil {
log.ZError(ctx, "update user online status", err)
}
for _, ss := range req.Status {
for _, online := range ss.Online {
client, _, _ := ws.clients.Get(ss.UserID, int(online))
back := false
if len(client) > 0 {
back = client[0].IsBackground
}
ws.webhookAfterUserOnline(ctx, &ws.msgGatewayConfig.WebhooksConfig.AfterUserOnline, ss.UserID, int(online), back, ss.ConnID)
}
for _, offline := range ss.Offline {
ws.webhookAfterUserOffline(ctx, &ws.msgGatewayConfig.WebhooksConfig.AfterUserOffline, ss.UserID, int(offline), ss.ConnID)
}
}
}
for i := 0; i < concurrent; i++ {
go func(ch <-chan *pbuser.SetUserOnlineStatusReq) {
for req := range ch {
doRequest(req)
}
}(requestChs[i])
}
for {
select {
case <-mergeTicker.C:
pushAllUserState()
case now := <-renewalTicker.C:
deadline := now.Add(-cachekey.OnlineExpire / 3)
users := ws.clients.GetAllUserStatus(deadline, now)
log.ZDebug(context.Background(), "renewal ticker", "deadline", deadline, "nowtime", now, "num", len(users), "users", users)
pushUserState(users...)
case state := <-ws.clients.UserState():
log.ZDebug(context.Background(), "OnlineCache user online change", "userID", state.UserID, "online", state.Online, "offline", state.Offline)
pushUserState(state)
}
}
}

View File

@@ -0,0 +1,63 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msggateway
import "time"
type (
Option func(opt *configs)
configs struct {
// Long connection listening port
port int
// Maximum number of connections allowed for long connection
maxConnNum int64
// Connection handshake timeout
handshakeTimeout time.Duration
// Maximum length allowed for messages
messageMaxMsgLength int
// Websocket write buffer, default: 4096, 4kb.
writeBufferSize int
}
)
func WithPort(port int) Option {
return func(opt *configs) {
opt.port = port
}
}
func WithMaxConnNum(num int64) Option {
return func(opt *configs) {
opt.maxConnNum = num
}
}
func WithHandshakeTimeout(t time.Duration) Option {
return func(opt *configs) {
opt.handshakeTimeout = t
}
}
func WithMessageMaxMsgLength(length int) Option {
return func(opt *configs) {
opt.messageMaxMsgLength = length
}
}
func WithWriteBufferSize(size int) Option {
return func(opt *configs) {
opt.writeBufferSize = size
}
}

View File

@@ -0,0 +1,167 @@
package msggateway
import (
"context"
"sync"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/datautil"
"google.golang.org/protobuf/proto"
)
func (ws *WsServer) subscriberUserOnlineStatusChanges(ctx context.Context, userID string, platformIDs []int32) {
if ws.clients.RecvSubChange(userID, platformIDs) {
log.ZDebug(ctx, "gateway receive subscription message and go back online", "userID", userID, "platformIDs", platformIDs)
} else {
log.ZDebug(ctx, "gateway ignore user online status changes", "userID", userID, "platformIDs", platformIDs)
}
ws.pushUserIDOnlineStatus(ctx, userID, platformIDs)
}
func (ws *WsServer) SubUserOnlineStatus(ctx context.Context, client *Client, data *Req) ([]byte, error) {
var sub sdkws.SubUserOnlineStatus
if err := proto.Unmarshal(data.Data, &sub); err != nil {
return nil, err
}
ws.subscription.Sub(client, sub.SubscribeUserID, sub.UnsubscribeUserID)
var resp sdkws.SubUserOnlineStatusTips
if len(sub.SubscribeUserID) > 0 {
resp.Subscribers = make([]*sdkws.SubUserOnlineStatusElem, 0, len(sub.SubscribeUserID))
for _, userID := range sub.SubscribeUserID {
platformIDs, err := ws.online.GetUserOnlinePlatform(ctx, userID)
if err != nil {
return nil, err
}
resp.Subscribers = append(resp.Subscribers, &sdkws.SubUserOnlineStatusElem{
UserID: userID,
OnlinePlatformIDs: platformIDs,
})
}
}
return proto.Marshal(&resp)
}
func newSubscription() *Subscription {
return &Subscription{
userIDs: make(map[string]*subClient),
}
}
type subClient struct {
clients map[string]*Client
}
type Subscription struct {
lock sync.RWMutex
userIDs map[string]*subClient // subscribe to the user's client connection
}
func (s *Subscription) DelClient(client *Client) {
client.subLock.Lock()
userIDs := datautil.Keys(client.subUserIDs)
for _, userID := range userIDs {
delete(client.subUserIDs, userID)
}
client.subLock.Unlock()
if len(userIDs) == 0 {
return
}
addr := client.ctx.GetRemoteAddr()
s.lock.Lock()
defer s.lock.Unlock()
for _, userID := range userIDs {
sub, ok := s.userIDs[userID]
if !ok {
continue
}
delete(sub.clients, addr)
if len(sub.clients) == 0 {
delete(s.userIDs, userID)
}
}
}
func (s *Subscription) GetClient(userID string) []*Client {
s.lock.RLock()
defer s.lock.RUnlock()
cs, ok := s.userIDs[userID]
if !ok {
return nil
}
clients := make([]*Client, 0, len(cs.clients))
for _, client := range cs.clients {
clients = append(clients, client)
}
return clients
}
func (s *Subscription) Sub(client *Client, addUserIDs, delUserIDs []string) {
if len(addUserIDs)+len(delUserIDs) == 0 {
return
}
var (
del = make(map[string]struct{})
add = make(map[string]struct{})
)
client.subLock.Lock()
for _, userID := range delUserIDs {
if _, ok := client.subUserIDs[userID]; !ok {
continue
}
del[userID] = struct{}{}
delete(client.subUserIDs, userID)
}
for _, userID := range addUserIDs {
delete(del, userID)
if _, ok := client.subUserIDs[userID]; ok {
continue
}
client.subUserIDs[userID] = struct{}{}
add[userID] = struct{}{}
}
client.subLock.Unlock()
if len(del)+len(add) == 0 {
return
}
addr := client.ctx.GetRemoteAddr()
s.lock.Lock()
defer s.lock.Unlock()
for userID := range del {
sub, ok := s.userIDs[userID]
if !ok {
continue
}
delete(sub.clients, addr)
if len(sub.clients) == 0 {
delete(s.userIDs, userID)
}
}
for userID := range add {
sub, ok := s.userIDs[userID]
if !ok {
sub = &subClient{clients: make(map[string]*Client)}
s.userIDs[userID] = sub
}
sub.clients[addr] = client
}
}
func (ws *WsServer) pushUserIDOnlineStatus(ctx context.Context, userID string, platformIDs []int32) {
clients := ws.subscription.GetClient(userID)
if len(clients) == 0 {
return
}
onlineStatus, err := proto.Marshal(&sdkws.SubUserOnlineStatusTips{
Subscribers: []*sdkws.SubUserOnlineStatusElem{{UserID: userID, OnlinePlatformIDs: platformIDs}},
})
if err != nil {
log.ZError(ctx, "pushUserIDOnlineStatus json.Marshal", err)
return
}
for _, client := range clients {
if err := client.PushUserOnlineStatus(onlineStatus); err != nil {
log.ZError(ctx, "UserSubscribeOnlineStatusNotification push failed", err, "userID", client.UserID, "platformID", client.PlatformID, "changeUserID", userID, "changePlatformID", platformIDs)
}
}
}

View File

@@ -0,0 +1,185 @@
package msggateway
import (
"github.com/openimsdk/tools/utils/datautil"
"sync"
"time"
)
type UserMap interface {
GetAll(userID string) ([]*Client, bool)
Get(userID string, platformID int) ([]*Client, bool, bool)
Set(userID string, v *Client)
DeleteClients(userID string, clients []*Client) (isDeleteUser bool)
UserState() <-chan UserState
GetAllUserStatus(deadline time.Time, nowtime time.Time) []UserState
RecvSubChange(userID string, platformIDs []int32) bool
}
type UserState struct {
UserID string
Online []int32
Offline []int32
}
type UserPlatform struct {
Time time.Time
Clients []*Client
}
func (u *UserPlatform) PlatformIDs() []int32 {
if len(u.Clients) == 0 {
return nil
}
platformIDs := make([]int32, 0, len(u.Clients))
for _, client := range u.Clients {
platformIDs = append(platformIDs, int32(client.PlatformID))
}
return platformIDs
}
func (u *UserPlatform) PlatformIDSet() map[int32]struct{} {
if len(u.Clients) == 0 {
return nil
}
platformIDs := make(map[int32]struct{})
for _, client := range u.Clients {
platformIDs[int32(client.PlatformID)] = struct{}{}
}
return platformIDs
}
func newUserMap() UserMap {
return &userMap{
data: make(map[string]*UserPlatform),
ch: make(chan UserState, 10000),
}
}
type userMap struct {
lock sync.RWMutex
data map[string]*UserPlatform
ch chan UserState
}
func (u *userMap) RecvSubChange(userID string, platformIDs []int32) bool {
u.lock.RLock()
defer u.lock.RUnlock()
result, ok := u.data[userID]
if !ok {
return false
}
localPlatformIDs := result.PlatformIDSet()
for _, platformID := range platformIDs {
delete(localPlatformIDs, platformID)
}
if len(localPlatformIDs) == 0 {
return false
}
u.push(userID, result, nil)
return true
}
func (u *userMap) push(userID string, userPlatform *UserPlatform, offline []int32) bool {
select {
case u.ch <- UserState{UserID: userID, Online: userPlatform.PlatformIDs(), Offline: offline}:
userPlatform.Time = time.Now()
return true
default:
return false
}
}
func (u *userMap) GetAll(userID string) ([]*Client, bool) {
u.lock.RLock()
defer u.lock.RUnlock()
result, ok := u.data[userID]
if !ok {
return nil, false
}
return result.Clients, true
}
func (u *userMap) Get(userID string, platformID int) ([]*Client, bool, bool) {
u.lock.RLock()
defer u.lock.RUnlock()
result, ok := u.data[userID]
if !ok {
return nil, false, false
}
var clients []*Client
for _, client := range result.Clients {
if client.PlatformID == platformID {
clients = append(clients, client)
}
}
return clients, true, len(clients) > 0
}
func (u *userMap) Set(userID string, client *Client) {
u.lock.Lock()
defer u.lock.Unlock()
result, ok := u.data[userID]
if ok {
result.Clients = append(result.Clients, client)
} else {
result = &UserPlatform{
Clients: []*Client{client},
}
u.data[userID] = result
}
u.push(client.UserID, result, nil)
}
func (u *userMap) DeleteClients(userID string, clients []*Client) (isDeleteUser bool) {
if len(clients) == 0 {
return false
}
u.lock.Lock()
defer u.lock.Unlock()
result, ok := u.data[userID]
if !ok {
return false
}
offline := make([]int32, 0, len(clients))
deleteAddr := datautil.SliceSetAny(clients, func(client *Client) string {
return client.ctx.GetRemoteAddr()
})
tmp := result.Clients
result.Clients = result.Clients[:0]
for _, client := range tmp {
if _, delCli := deleteAddr[client.ctx.GetRemoteAddr()]; delCli {
offline = append(offline, int32(client.PlatformID))
} else {
result.Clients = append(result.Clients, client)
}
}
defer u.push(userID, result, offline)
if len(result.Clients) > 0 {
return false
}
delete(u.data, userID)
return true
}
func (u *userMap) GetAllUserStatus(deadline time.Time, nowtime time.Time) (result []UserState) {
u.lock.RLock()
defer u.lock.RUnlock()
result = make([]UserState, 0, len(u.data))
for userID, userPlatform := range u.data {
if deadline.Before(userPlatform.Time) {
continue
}
userPlatform.Time = nowtime
online := make([]int32, 0, len(userPlatform.Clients))
for _, client := range userPlatform.Clients {
online = append(online, int32(client.PlatformID))
}
result = append(result, UserState{UserID: userID, Online: online})
}
return result
}
func (u *userMap) UserState() <-chan UserState {
return u.ch
}

View File

@@ -0,0 +1,568 @@
package msggateway
import (
"context"
"fmt"
"net/http"
"sync"
"sync/atomic"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/webhook"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpccache"
pbAuth "git.imall.cloud/openim/protocol/auth"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/mcontext"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/prommetrics"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/servererrs"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/msggateway"
"github.com/go-playground/validator/v10"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/stringutil"
"golang.org/x/sync/errgroup"
)
type LongConnServer interface {
Run(ctx context.Context) error
wsHandler(w http.ResponseWriter, r *http.Request)
GetUserAllCons(userID string) ([]*Client, bool)
GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool)
Validate(s any) error
SetDiscoveryRegistry(ctx context.Context, client discovery.Conn, config *Config) error
KickUserConn(client *Client) error
UnRegister(c *Client)
SetKickHandlerInfo(i *kickHandler)
SubUserOnlineStatus(ctx context.Context, client *Client, data *Req) ([]byte, error)
Compressor
MessageHandler
}
type WsServer struct {
msgGatewayConfig *Config
port int
wsMaxConnNum int64
registerChan chan *Client
unregisterChan chan *Client
kickHandlerChan chan *kickHandler
clients UserMap
online rpccache.OnlineCache
subscription *Subscription
clientPool sync.Pool
onlineUserNum atomic.Int64
onlineUserConnNum atomic.Int64
handshakeTimeout time.Duration
writeBufferSize int
validate *validator.Validate
disCov discovery.Conn
Compressor
//Encoder
MessageHandler
webhookClient *webhook.Client
userClient *rpcli.UserClient
authClient *rpcli.AuthClient
ready atomic.Bool
}
type kickHandler struct {
clientOK bool
oldClients []*Client
newClient *Client
}
func (ws *WsServer) SetDiscoveryRegistry(ctx context.Context, disCov discovery.Conn, config *Config) error {
userConn, err := disCov.GetConn(ctx, config.Discovery.RpcService.User)
if err != nil {
return err
}
pushConn, err := disCov.GetConn(ctx, config.Discovery.RpcService.Push)
if err != nil {
return err
}
authConn, err := disCov.GetConn(ctx, config.Discovery.RpcService.Auth)
if err != nil {
return err
}
msgConn, err := disCov.GetConn(ctx, config.Discovery.RpcService.Msg)
if err != nil {
return err
}
ws.userClient = rpcli.NewUserClient(userConn)
ws.authClient = rpcli.NewAuthClient(authConn)
ws.MessageHandler = NewGrpcHandler(ws.validate, rpcli.NewMsgClient(msgConn), rpcli.NewPushMsgServiceClient(pushConn))
ws.disCov = disCov
ws.ready.Store(true)
return nil
}
//func (ws *WsServer) SetUserOnlineStatus(ctx context.Context, client *Client, status int32) {
// err := ws.userClient.SetUserStatus(ctx, client.UserID, status, client.PlatformID)
// if err != nil {
// log.ZWarn(ctx, "SetUserStatus err", err)
// }
// switch status {
// case constant.Online:
// ws.webhookAfterUserOnline(ctx, &ws.msgGatewayConfig.WebhooksConfig.AfterUserOnline, client.UserID, client.PlatformID, client.IsBackground, client.ctx.GetConnID())
// case constant.Offline:
// ws.webhookAfterUserOffline(ctx, &ws.msgGatewayConfig.WebhooksConfig.AfterUserOffline, client.UserID, client.PlatformID, client.ctx.GetConnID())
// }
//}
func (ws *WsServer) UnRegister(c *Client) {
ws.unregisterChan <- c
}
func (ws *WsServer) Validate(_ any) error {
return nil
}
func (ws *WsServer) GetUserAllCons(userID string) ([]*Client, bool) {
return ws.clients.GetAll(userID)
}
func (ws *WsServer) GetUserPlatformCons(userID string, platform int) ([]*Client, bool, bool) {
return ws.clients.Get(userID, platform)
}
func NewWsServer(msgGatewayConfig *Config, opts ...Option) *WsServer {
var config configs
for _, o := range opts {
o(&config)
}
//userRpcClient := rpcclient.NewUserRpcClient(client, config.Discovery.RpcService.User, config.Share.IMAdminUser)
v := validator.New()
return &WsServer{
msgGatewayConfig: msgGatewayConfig,
port: config.port,
wsMaxConnNum: config.maxConnNum,
writeBufferSize: config.writeBufferSize,
handshakeTimeout: config.handshakeTimeout,
clientPool: sync.Pool{
New: func() any {
return new(Client)
},
},
registerChan: make(chan *Client, 1000),
unregisterChan: make(chan *Client, 1000),
kickHandlerChan: make(chan *kickHandler, 1000),
validate: v,
clients: newUserMap(),
subscription: newSubscription(),
Compressor: NewGzipCompressor(),
webhookClient: webhook.NewWebhookClient(msgGatewayConfig.WebhooksConfig.URL),
}
}
func (ws *WsServer) Run(ctx context.Context) error {
var client *Client
ctx, cancel := context.WithCancelCause(ctx)
go func() {
for {
select {
case <-ctx.Done():
return
case client = <-ws.registerChan:
ws.registerClient(client)
case client = <-ws.unregisterChan:
ws.unregisterClient(client)
case onlineInfo := <-ws.kickHandlerChan:
ws.multiTerminalLoginChecker(onlineInfo.clientOK, onlineInfo.oldClients, onlineInfo.newClient)
}
}
}()
done := make(chan struct{})
go func() {
wsServer := http.Server{Addr: fmt.Sprintf(":%d", ws.port), Handler: nil}
http.HandleFunc("/", ws.wsHandler)
go func() {
defer close(done)
<-ctx.Done()
_ = wsServer.Shutdown(context.Background())
}()
err := wsServer.ListenAndServe()
if err == nil {
err = fmt.Errorf("http server closed")
}
cancel(fmt.Errorf("msg gateway %w", err))
}()
<-ctx.Done()
timeout := time.NewTimer(time.Second * 15)
defer timeout.Stop()
select {
case <-timeout.C:
log.ZWarn(ctx, "msg gateway graceful stop timeout", nil)
case <-done:
log.ZDebug(ctx, "msg gateway graceful stop done")
}
return context.Cause(ctx)
}
const concurrentRequest = 3
func (ws *WsServer) sendUserOnlineInfoToOtherNode(ctx context.Context, client *Client) error {
conns, err := ws.disCov.GetConns(ctx, ws.msgGatewayConfig.Discovery.RpcService.MessageGateway)
if err != nil {
return err
}
if len(conns) == 0 || (len(conns) == 1 && ws.disCov.IsSelfNode(conns[0])) {
return nil
}
wg := errgroup.Group{}
wg.SetLimit(concurrentRequest)
// Online push user online message to other node
for _, v := range conns {
v := v
log.ZDebug(ctx, "sendUserOnlineInfoToOtherNode conn")
if ws.disCov.IsSelfNode(v) {
log.ZDebug(ctx, "Filter out this node")
continue
}
wg.Go(func() error {
msgClient := msggateway.NewMsgGatewayClient(v)
_, err := msgClient.MultiTerminalLoginCheck(ctx, &msggateway.MultiTerminalLoginCheckReq{
UserID: client.UserID,
PlatformID: int32(client.PlatformID), Token: client.token,
})
if err != nil {
log.ZWarn(ctx, "MultiTerminalLoginCheck err", err)
}
return nil
})
}
_ = wg.Wait()
return nil
}
func (ws *WsServer) SetKickHandlerInfo(i *kickHandler) {
ws.kickHandlerChan <- i
}
func (ws *WsServer) registerClient(client *Client) {
var (
userOK bool
clientOK bool
oldClients []*Client
)
oldClients, userOK, clientOK = ws.clients.Get(client.UserID, client.PlatformID)
log.ZInfo(client.ctx, "registerClient", "userID", client.UserID, "platformID", client.PlatformID,
"sdkVersion", client.SDKVersion)
if !userOK {
ws.clients.Set(client.UserID, client)
log.ZDebug(client.ctx, "user not exist", "userID", client.UserID, "platformID", client.PlatformID)
prommetrics.OnlineUserGauge.Add(1)
ws.onlineUserNum.Add(1)
ws.onlineUserConnNum.Add(1)
} else {
ws.multiTerminalLoginChecker(clientOK, oldClients, client)
log.ZDebug(client.ctx, "user exist", "userID", client.UserID, "platformID", client.PlatformID)
if clientOK {
ws.clients.Set(client.UserID, client)
// There is already a connection to the platform
log.ZDebug(client.ctx, "repeat login", "userID", client.UserID, "platformID",
client.PlatformID, "old remote addr", getRemoteAdders(oldClients))
ws.onlineUserConnNum.Add(1)
} else {
ws.clients.Set(client.UserID, client)
ws.onlineUserConnNum.Add(1)
}
}
wg := sync.WaitGroup{}
log.ZDebug(client.ctx, "ws.msgGatewayConfig.Discovery.Enable", "discoveryEnable", ws.msgGatewayConfig.Discovery.Enable)
if ws.msgGatewayConfig.Discovery.Enable != "k8s" {
wg.Add(1)
go func() {
defer wg.Done()
_ = ws.sendUserOnlineInfoToOtherNode(client.ctx, client)
}()
}
//wg.Add(1)
//go func() {
// defer wg.Done()
// ws.SetUserOnlineStatus(client.ctx, client, constant.Online)
//}()
wg.Wait()
log.ZDebug(client.ctx, "user online", "online user Num", ws.onlineUserNum.Load(), "online user conn Num", ws.onlineUserConnNum.Load())
}
func getRemoteAdders(client []*Client) string {
var ret string
for i, c := range client {
if i == 0 {
ret = c.ctx.GetRemoteAddr()
} else {
ret += "@" + c.ctx.GetRemoteAddr()
}
}
return ret
}
func (ws *WsServer) KickUserConn(client *Client) error {
// 先发送踢掉通知,再删除客户端,确保通知能正确发送
err := client.KickOnlineMessage()
// KickOnlineMessage 内部会调用 close()close() 会调用 UnRegisterUnRegister 会删除客户端
// 所以这里不需要再调用 DeleteClients
return err
}
func (ws *WsServer) multiTerminalLoginChecker(clientOK bool, oldClients []*Client, newClient *Client) {
kickTokenFunc := func(kickClients []*Client) {
if len(kickClients) == 0 {
return
}
var kickTokens []string
// 先发送踢掉通知KickOnlineMessage 内部会关闭连接并删除客户端
// 使用 goroutine 并发发送通知,提高效率
var wg sync.WaitGroup
for _, c := range kickClients {
kickTokens = append(kickTokens, c.token)
wg.Add(1)
go func(client *Client) {
defer wg.Done()
err := client.KickOnlineMessage()
if err != nil {
log.ZWarn(client.ctx, "KickOnlineMessage", err)
}
}(c)
}
// 等待所有通知发送完成
wg.Wait()
ctx := mcontext.WithMustInfoCtx(
[]string{newClient.ctx.GetOperationID(), newClient.ctx.GetUserID(),
constant.PlatformIDToName(newClient.PlatformID), newClient.ctx.GetConnID()},
)
if err := ws.authClient.KickTokens(ctx, kickTokens); err != nil {
log.ZWarn(newClient.ctx, "kickTokens err", err)
}
}
// If reconnect: When multiple msgGateway instances are deployed, a client may disconnect from instance A and reconnect to instance B.
// During this process, instance A might still be executing, resulting in two clients with the same token existing simultaneously.
// This situation needs to be filtered to prevent duplicate clients.
checkSameTokenFunc := func(oldClients []*Client) []*Client {
var clientsNeedToKick []*Client
for _, c := range oldClients {
if c.token == newClient.token {
log.ZDebug(newClient.ctx, "token is same, not kick",
"userID", newClient.UserID,
"platformID", newClient.PlatformID,
"token", newClient.token)
continue
}
clientsNeedToKick = append(clientsNeedToKick, c)
}
return clientsNeedToKick
}
switch ws.msgGatewayConfig.Share.MultiLogin.Policy {
case constant.DefalutNotKick:
case constant.PCAndOther:
if constant.PlatformIDToClass(newClient.PlatformID) == constant.TerminalPC {
return
}
clients, ok := ws.clients.GetAll(newClient.UserID)
clientOK = ok
oldClients = make([]*Client, 0, len(clients))
for _, c := range clients {
if constant.PlatformIDToClass(c.PlatformID) == constant.TerminalPC {
continue
}
oldClients = append(oldClients, c)
}
fallthrough
case constant.AllLoginButSameTermKick:
if !clientOK {
return
}
oldClients = checkSameTokenFunc(oldClients)
if len(oldClients) == 0 {
return
}
// 先发送踢掉通知KickOnlineMessage 内部会关闭连接并删除客户端
// 使用 goroutine 并发发送通知,提高效率
var wg sync.WaitGroup
for _, c := range oldClients {
wg.Add(1)
go func(client *Client) {
defer wg.Done()
err := client.KickOnlineMessage()
if err != nil {
log.ZWarn(client.ctx, "KickOnlineMessage", err)
}
}(c)
}
// 等待所有通知发送完成
wg.Wait()
ctx := mcontext.WithMustInfoCtx(
[]string{newClient.ctx.GetOperationID(), newClient.ctx.GetUserID(),
constant.PlatformIDToName(newClient.PlatformID), newClient.ctx.GetConnID()},
)
req := &pbAuth.InvalidateTokenReq{
PreservedToken: newClient.token,
UserID: newClient.UserID,
PlatformID: int32(newClient.PlatformID),
}
if err := ws.authClient.InvalidateToken(ctx, req); err != nil {
log.ZWarn(newClient.ctx, "InvalidateToken err", err, "userID", newClient.UserID,
"platformID", newClient.PlatformID)
}
case constant.AllLoginButSameClassKick:
clients, ok := ws.clients.GetAll(newClient.UserID)
if !ok {
return
}
var kickClients []*Client
for _, client := range clients {
if constant.PlatformIDToClass(client.PlatformID) == constant.PlatformIDToClass(newClient.PlatformID) {
{
kickClients = append(kickClients, client)
}
}
}
kickClients = checkSameTokenFunc(kickClients)
kickTokenFunc(kickClients)
}
}
func (ws *WsServer) unregisterClient(client *Client) {
defer ws.clientPool.Put(client)
isDeleteUser := ws.clients.DeleteClients(client.UserID, []*Client{client})
if isDeleteUser {
ws.onlineUserNum.Add(-1)
prommetrics.OnlineUserGauge.Dec()
}
ws.onlineUserConnNum.Add(-1)
ws.subscription.DelClient(client)
//ws.SetUserOnlineStatus(client.ctx, client, constant.Offline)
log.ZDebug(client.ctx, "user offline", "close reason", client.closedErr, "online user Num",
ws.onlineUserNum.Load(), "online user conn Num",
ws.onlineUserConnNum.Load(),
)
}
// validateRespWithRequest checks if the response matches the expected userID and platformID.
func (ws *WsServer) validateRespWithRequest(ctx *UserConnContext, resp *pbAuth.ParseTokenResp) error {
userID := ctx.GetUserID()
platformID := stringutil.StringToInt32(ctx.GetPlatformID())
if resp.UserID != userID {
return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token uid %s != userID %s", resp.UserID, userID))
}
if resp.PlatformID != platformID {
return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token platform %d != platformID %d", resp.PlatformID, platformID))
}
return nil
}
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
// Create a new connection context
connContext := newContext(w, r)
if !ws.ready.Load() {
httpError(connContext, errs.New("ws server not ready"))
return
}
// Check if the current number of online user connections exceeds the maximum limit
if ws.onlineUserConnNum.Load() >= ws.wsMaxConnNum {
// If it exceeds the maximum connection number, return an error via HTTP and stop processing
httpError(connContext, servererrs.ErrConnOverMaxNumLimit.WrapMsg("over max conn num limit"))
return
}
// Parse essential arguments (e.g., user ID, Token)
err := connContext.ParseEssentialArgs()
if err != nil {
// If there's an error during parsing, return an error via HTTP and stop processing
httpError(connContext, err)
return
}
if ws.authClient == nil {
httpError(connContext, errs.New("auth client is not initialized"))
return
}
// Call the authentication client to parse the Token obtained from the context
resp, err := ws.authClient.ParseToken(connContext, connContext.GetToken())
if err != nil {
// If there's an error parsing the Token, decide whether to send the error message via WebSocket based on the context flag
shouldSendError := connContext.ShouldSendResp()
if shouldSendError {
// Create a WebSocket connection object and attempt to send the error message via WebSocket
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize)
if err := wsLongConn.RespondWithError(err, w, r); err == nil {
// If the error message is successfully sent via WebSocket, stop processing
return
}
}
// If sending via WebSocket is not required or fails, return the error via HTTP and stop processing
httpError(connContext, err)
return
}
// Validate the authentication response matches the request (e.g., user ID and platform ID)
err = ws.validateRespWithRequest(connContext, resp)
if err != nil {
// If validation fails, return an error via HTTP and stop processing
httpError(connContext, err)
return
}
log.ZDebug(connContext, "new conn", "token", connContext.GetToken())
// Create a WebSocket long connection object
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize)
if err := wsLongConn.GenerateLongConn(w, r); err != nil {
//If the creation of the long connection fails, the error is handled internally during the handshake process.
log.ZWarn(connContext, "long connection fails", err)
return
} else {
// Check if a normal response should be sent via WebSocket
shouldSendSuccessResp := connContext.ShouldSendResp()
if shouldSendSuccessResp {
// Attempt to send a success message through WebSocket
if err := wsLongConn.RespondWithSuccess(); err != nil {
// If the success message is successfully sent, end further processing
return
}
}
}
// Retrieve a client object from the client pool, reset its state, and associate it with the current WebSocket long connection
client := ws.clientPool.Get().(*Client)
client.ResetClient(connContext, wsLongConn, ws)
// Register the client with the server and start message processing
ws.registerChan <- client
go client.readMessage()
}

View File

@@ -0,0 +1,112 @@
package msgtransfer
import (
"context"
"encoding/base64"
"git.imall.cloud/openim/open-im-server-deploy/pkg/apistruct"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/webhook"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/datautil"
"github.com/openimsdk/tools/utils/stringutil"
"google.golang.org/protobuf/proto"
cbapi "git.imall.cloud/openim/open-im-server-deploy/pkg/callbackstruct"
)
func toCommonCallback(ctx context.Context, msg *sdkws.MsgData, command string) cbapi.CommonCallbackReq {
return cbapi.CommonCallbackReq{
SendID: msg.SendID,
ServerMsgID: msg.ServerMsgID,
CallbackCommand: command,
ClientMsgID: msg.ClientMsgID,
OperationID: mcontext.GetOperationID(ctx),
SenderPlatformID: msg.SenderPlatformID,
SenderNickname: msg.SenderNickname,
SessionType: msg.SessionType,
MsgFrom: msg.MsgFrom,
ContentType: msg.ContentType,
Status: msg.Status,
SendTime: msg.SendTime,
CreateTime: msg.CreateTime,
AtUserIDList: msg.AtUserIDList,
SenderFaceURL: msg.SenderFaceURL,
Content: GetContent(msg),
Seq: uint32(msg.Seq),
Ex: msg.Ex,
}
}
func GetContent(msg *sdkws.MsgData) string {
if msg.ContentType >= constant.NotificationBegin && msg.ContentType <= constant.NotificationEnd {
var tips sdkws.TipsComm
_ = proto.Unmarshal(msg.Content, &tips)
content := tips.JsonDetail
return content
} else {
return string(msg.Content)
}
}
func (mc *OnlineHistoryMongoConsumerHandler) webhookAfterMsgSaveDB(ctx context.Context, after *config.AfterConfig, msg *sdkws.MsgData) {
if !filterAfterMsg(msg, after) {
return
}
cbReq := &cbapi.CallbackAfterMsgSaveDBReq{
CommonCallbackReq: toCommonCallback(ctx, msg, cbapi.CallbackAfterMsgSaveDBCommand),
RecvID: msg.RecvID,
GroupID: msg.GroupID,
}
mc.webhookClient.AsyncPostWithQuery(ctx, cbReq.GetCallbackCommand(), cbReq, &cbapi.CallbackAfterMsgSaveDBResp{}, after, buildKeyMsgDataQuery(msg))
}
func buildKeyMsgDataQuery(msg *sdkws.MsgData) map[string]string {
keyMsgData := apistruct.KeyMsgData{
SendID: msg.SendID,
RecvID: msg.RecvID,
GroupID: msg.GroupID,
}
return map[string]string{
webhook.Key: base64.StdEncoding.EncodeToString(stringutil.StructToJsonBytes(keyMsgData)),
}
}
func filterAfterMsg(msg *sdkws.MsgData, after *config.AfterConfig) bool {
return filterMsg(msg, after.AttentionIds, after.DeniedTypes)
}
func filterMsg(msg *sdkws.MsgData, attentionIds []string, deniedTypes []int32) bool {
// According to the attentionIds configuration, only some users are sent
if len(attentionIds) != 0 && msg.ContentType == constant.SingleChatType && !datautil.Contain(msg.RecvID, attentionIds...) {
return false
}
if len(attentionIds) != 0 && msg.ContentType == constant.ReadGroupChatType && !datautil.Contain(msg.GroupID, attentionIds...) {
return false
}
if defaultDeniedTypes(msg.ContentType) {
return false
}
if len(deniedTypes) != 0 && datautil.Contain(msg.ContentType, deniedTypes...) {
return false
}
return true
}
func defaultDeniedTypes(contentType int32) bool {
if contentType >= constant.NotificationBegin && contentType <= constant.NotificationEnd {
return true
}
if contentType == constant.Typing {
return true
}
return false
}

View File

@@ -0,0 +1,188 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msgtransfer
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/mcache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/redis"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database/mgo"
"git.imall.cloud/openim/open-im-server-deploy/pkg/dbbuild"
"git.imall.cloud/openim/open-im-server-deploy/pkg/mqbuild"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/mq"
"github.com/openimsdk/tools/utils/runtimeenv"
conf "git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/controller"
"github.com/openimsdk/tools/log"
"google.golang.org/grpc"
)
type MsgTransfer struct {
historyConsumer mq.Consumer
historyMongoConsumer mq.Consumer
// This consumer aggregated messages, subscribed to the topic:toRedis,
// the message is stored in redis, Incr Redis, and then the message is sent to toPush topic for push,
// and the message is sent to toMongo topic for persistence
historyHandler *OnlineHistoryRedisConsumerHandler
//This consumer handle message to mongo
historyMongoHandler *OnlineHistoryMongoConsumerHandler
ctx context.Context
//cancel context.CancelFunc
}
type Config struct {
MsgTransfer conf.MsgTransfer
RedisConfig conf.Redis
MongodbConfig conf.Mongo
KafkaConfig conf.Kafka
Share conf.Share
WebhooksConfig conf.Webhooks
Discovery conf.Discovery
Index conf.Index
}
func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryRegistry, server grpc.ServiceRegistrar) error {
builder := mqbuild.NewBuilder(&config.KafkaConfig)
log.CInfo(ctx, "MSG-TRANSFER server is initializing", "runTimeEnv", runtimeenv.RuntimeEnvironment(), "prometheusPorts",
config.MsgTransfer.Prometheus.Ports, "index", config.Index)
dbb := dbbuild.NewBuilder(&config.MongodbConfig, &config.RedisConfig)
mgocli, err := dbb.Mongo(ctx)
if err != nil {
return err
}
rdb, err := dbb.Redis(ctx)
if err != nil {
return err
}
//if config.Discovery.Enable == conf.ETCD {
// cm := disetcd.NewConfigManager(client.(*etcd.SvcDiscoveryRegistryImpl).GetClient(), []string{
// config.MsgTransfer.GetConfigFileName(),
// config.RedisConfig.GetConfigFileName(),
// config.MongodbConfig.GetConfigFileName(),
// config.KafkaConfig.GetConfigFileName(),
// config.Share.GetConfigFileName(),
// config.WebhooksConfig.GetConfigFileName(),
// config.Discovery.GetConfigFileName(),
// conf.LogConfigFileName,
// })
// cm.Watch(ctx)
//}
mongoProducer, err := builder.GetTopicProducer(ctx, config.KafkaConfig.ToMongoTopic)
if err != nil {
return err
}
pushProducer, err := builder.GetTopicProducer(ctx, config.KafkaConfig.ToPushTopic)
if err != nil {
return err
}
msgDocModel, err := mgo.NewMsgMongo(mgocli.GetDB())
if err != nil {
return err
}
var msgModel cache.MsgCache
if rdb == nil {
cm, err := mgo.NewCacheMgo(mgocli.GetDB())
if err != nil {
return err
}
msgModel = mcache.NewMsgCache(cm, msgDocModel)
} else {
msgModel = redis.NewMsgCache(rdb, msgDocModel)
}
seqConversation, err := mgo.NewSeqConversationMongo(mgocli.GetDB())
if err != nil {
return err
}
seqConversationCache := redis.NewSeqConversationCacheRedis(rdb, seqConversation)
seqUser, err := mgo.NewSeqUserMongo(mgocli.GetDB())
if err != nil {
return err
}
seqUserCache := redis.NewSeqUserCacheRedis(rdb, seqUser)
msgTransferDatabase, err := controller.NewMsgTransferDatabase(msgDocModel, msgModel, seqUserCache, seqConversationCache, mongoProducer, pushProducer)
if err != nil {
return err
}
historyConsumer, err := builder.GetTopicConsumer(ctx, config.KafkaConfig.ToRedisTopic)
if err != nil {
return err
}
historyMongoConsumer, err := builder.GetTopicConsumer(ctx, config.KafkaConfig.ToMongoTopic)
if err != nil {
return err
}
historyHandler, err := NewOnlineHistoryRedisConsumerHandler(ctx, client, config, msgTransferDatabase)
if err != nil {
return err
}
historyMongoHandler := NewOnlineHistoryMongoConsumerHandler(msgTransferDatabase, config)
msgTransfer := &MsgTransfer{
historyConsumer: historyConsumer,
historyMongoConsumer: historyMongoConsumer,
historyHandler: historyHandler,
historyMongoHandler: historyMongoHandler,
}
return msgTransfer.Start(ctx)
}
func (m *MsgTransfer) Start(ctx context.Context) error {
m.ctx = ctx
go func() {
for {
if err := m.historyConsumer.Subscribe(m.ctx, m.historyHandler.HandlerRedisMessage); err != nil {
log.ZError(m.ctx, "historyConsumer err, will retry in 5 seconds", err)
time.Sleep(5 * time.Second)
continue
}
// Subscribe returned normally (possibly due to context cancellation), retry immediately
log.ZWarn(m.ctx, "historyConsumer Subscribe returned normally, will retry immediately", nil)
}
}()
go func() {
fn := func(msg mq.Message) error {
m.historyMongoHandler.HandleChatWs2Mongo(msg)
return nil
}
for {
if err := m.historyMongoConsumer.Subscribe(m.ctx, fn); err != nil {
log.ZError(m.ctx, "historyMongoConsumer err, will retry in 5 seconds", err)
time.Sleep(5 * time.Second)
continue
}
// Subscribe returned normally (possibly due to context cancellation), retry immediately
log.ZWarn(m.ctx, "historyMongoConsumer Subscribe returned normally, will retry immediately", nil)
}
}()
go m.historyHandler.HandleUserHasReadSeqMessages(m.ctx)
err := m.historyHandler.redisMessageBatches.Start()
if err != nil {
return err
}
<-m.ctx.Done()
return m.ctx.Err()
}

View File

@@ -0,0 +1,410 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msgtransfer
import (
"context"
"encoding/json"
"errors"
"github.com/openimsdk/tools/mq"
"sync"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
"github.com/openimsdk/tools/discovery"
"github.com/go-redis/redis"
"google.golang.org/protobuf/proto"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/prommetrics"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/controller"
"git.imall.cloud/openim/open-im-server-deploy/pkg/msgprocessor"
"git.imall.cloud/openim/open-im-server-deploy/pkg/tools/batcher"
"git.imall.cloud/openim/protocol/constant"
pbconv "git.imall.cloud/openim/protocol/conversation"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/stringutil"
)
const (
size = 1000
mainDataBuffer = 2000
subChanBuffer = 200
worker = 100
interval = 100 * time.Millisecond
hasReadChanBuffer = 2000
)
type ContextMsg struct {
message *sdkws.MsgData
ctx context.Context
}
// This structure is used for asynchronously writing the senders read sequence (seq) regarding a message into MongoDB.
// For example, if the sender sends a message with a seq of 10, then their own read seq for this conversation should be set to 10.
type userHasReadSeq struct {
conversationID string
userHasReadMap map[string]int64
}
type OnlineHistoryRedisConsumerHandler struct {
redisMessageBatches *batcher.Batcher[ConsumerMessage]
msgTransferDatabase controller.MsgTransferDatabase
conversationUserHasReadChan chan *userHasReadSeq
wg sync.WaitGroup
groupClient *rpcli.GroupClient
conversationClient *rpcli.ConversationClient
}
type ConsumerMessage struct {
Ctx context.Context
Key string
Value []byte
Raw mq.Message
}
func NewOnlineHistoryRedisConsumerHandler(ctx context.Context, client discovery.Conn, config *Config, database controller.MsgTransferDatabase) (*OnlineHistoryRedisConsumerHandler, error) {
groupConn, err := client.GetConn(ctx, config.Discovery.RpcService.Group)
if err != nil {
return nil, err
}
conversationConn, err := client.GetConn(ctx, config.Discovery.RpcService.Conversation)
if err != nil {
return nil, err
}
var och OnlineHistoryRedisConsumerHandler
och.msgTransferDatabase = database
och.conversationUserHasReadChan = make(chan *userHasReadSeq, hasReadChanBuffer)
och.groupClient = rpcli.NewGroupClient(groupConn)
och.conversationClient = rpcli.NewConversationClient(conversationConn)
och.wg.Add(1)
b := batcher.New[ConsumerMessage](
batcher.WithSize(size),
batcher.WithWorker(worker),
batcher.WithInterval(interval),
batcher.WithDataBuffer(mainDataBuffer),
batcher.WithSyncWait(true),
batcher.WithBuffer(subChanBuffer),
)
b.Sharding = func(key string) int {
hashCode := stringutil.GetHashCode(key)
return int(hashCode) % och.redisMessageBatches.Worker()
}
b.Key = func(consumerMessage *ConsumerMessage) string {
return consumerMessage.Key
}
b.Do = och.do
och.redisMessageBatches = b
och.redisMessageBatches.OnComplete = func(lastMessage *ConsumerMessage, totalCount int) {
lastMessage.Raw.Mark()
lastMessage.Raw.Commit()
}
return &och, nil
}
func (och *OnlineHistoryRedisConsumerHandler) do(ctx context.Context, channelID int, val *batcher.Msg[ConsumerMessage]) {
ctx = mcontext.WithTriggerIDContext(ctx, val.TriggerID())
ctxMessages := och.parseConsumerMessages(ctx, val.Val())
ctx = withAggregationCtx(ctx, ctxMessages)
log.ZInfo(ctx, "msg arrived channel", "channel id", channelID, "msgList length", len(ctxMessages), "key", val.Key())
och.doSetReadSeq(ctx, ctxMessages)
storageMsgList, notStorageMsgList, storageNotificationList, notStorageNotificationList :=
och.categorizeMessageLists(ctxMessages)
log.ZDebug(ctx, "number of categorized messages", "storageMsgList", len(storageMsgList), "notStorageMsgList",
len(notStorageMsgList), "storageNotificationList", len(storageNotificationList), "notStorageNotificationList", len(notStorageNotificationList))
conversationIDMsg := msgprocessor.GetChatConversationIDByMsg(ctxMessages[0].message)
conversationIDNotification := msgprocessor.GetNotificationConversationIDByMsg(ctxMessages[0].message)
och.handleMsg(ctx, val.Key(), conversationIDMsg, storageMsgList, notStorageMsgList)
och.handleNotification(ctx, val.Key(), conversationIDNotification, storageNotificationList, notStorageNotificationList)
}
func (och *OnlineHistoryRedisConsumerHandler) doSetReadSeq(ctx context.Context, msgs []*ContextMsg) {
// Outer map: conversationID -> (userID -> maxHasReadSeq)
conversationUserSeq := make(map[string]map[string]int64)
for _, msg := range msgs {
if msg.message.ContentType != constant.HasReadReceipt {
continue
}
var elem sdkws.NotificationElem
if err := json.Unmarshal(msg.message.Content, &elem); err != nil {
log.ZWarn(ctx, "Unmarshal NotificationElem error", err, "msg", msg)
continue
}
var tips sdkws.MarkAsReadTips
if err := json.Unmarshal([]byte(elem.Detail), &tips); err != nil {
log.ZWarn(ctx, "Unmarshal MarkAsReadTips error", err, "msg", msg)
continue
}
if len(tips.ConversationID) == 0 || tips.HasReadSeq < 0 {
continue
}
// Calculate the max seq from tips.Seqs
for _, seq := range tips.Seqs {
if tips.HasReadSeq < seq {
tips.HasReadSeq = seq
}
}
if _, ok := conversationUserSeq[tips.ConversationID]; !ok {
conversationUserSeq[tips.ConversationID] = make(map[string]int64)
}
if conversationUserSeq[tips.ConversationID][tips.MarkAsReadUserID] < tips.HasReadSeq {
conversationUserSeq[tips.ConversationID][tips.MarkAsReadUserID] = tips.HasReadSeq
}
}
log.ZInfo(ctx, "doSetReadSeq", "conversationUserSeq", conversationUserSeq)
// persist to db
for convID, userSeqMap := range conversationUserSeq {
if err := och.msgTransferDatabase.SetHasReadSeqToDB(ctx, convID, userSeqMap); err != nil {
log.ZWarn(ctx, "SetHasReadSeqToDB error", err, "conversationID", convID, "userSeqMap", userSeqMap)
}
}
}
func (och *OnlineHistoryRedisConsumerHandler) parseConsumerMessages(ctx context.Context, consumerMessages []*ConsumerMessage) []*ContextMsg {
var ctxMessages []*ContextMsg
for i := 0; i < len(consumerMessages); i++ {
ctxMsg := &ContextMsg{}
msgFromMQ := &sdkws.MsgData{}
err := proto.Unmarshal(consumerMessages[i].Value, msgFromMQ)
if err != nil {
log.ZWarn(ctx, "msg_transfer Unmarshal msg err", err, string(consumerMessages[i].Value))
continue
}
ctxMsg.ctx = consumerMessages[i].Ctx
ctxMsg.message = msgFromMQ
log.ZDebug(ctx, "message parse finish", "message", msgFromMQ, "key", consumerMessages[i].Key)
ctxMessages = append(ctxMessages, ctxMsg)
}
return ctxMessages
}
// Get messages/notifications stored message list, not stored and pushed message list.
func (och *OnlineHistoryRedisConsumerHandler) categorizeMessageLists(totalMsgs []*ContextMsg) (storageMsgList,
notStorageMsgList, storageNotificationList, notStorageNotificationList []*ContextMsg) {
for _, v := range totalMsgs {
options := msgprocessor.Options(v.message.Options)
if !options.IsNotNotification() {
// clone msg from notificationMsg
if options.IsSendMsg() {
msg := proto.Clone(v.message).(*sdkws.MsgData)
// message
if v.message.Options != nil {
msg.Options = msgprocessor.NewMsgOptions()
}
msg.Options = msgprocessor.WithOptions(msg.Options,
msgprocessor.WithOfflinePush(options.IsOfflinePush()),
msgprocessor.WithUnreadCount(options.IsUnreadCount()),
)
v.message.Options = msgprocessor.WithOptions(
v.message.Options,
msgprocessor.WithOfflinePush(false),
msgprocessor.WithUnreadCount(false),
)
ctxMsg := &ContextMsg{
message: msg,
ctx: v.ctx,
}
storageMsgList = append(storageMsgList, ctxMsg)
}
if options.IsHistory() {
storageNotificationList = append(storageNotificationList, v)
} else {
notStorageNotificationList = append(notStorageNotificationList, v)
}
} else {
if options.IsHistory() {
storageMsgList = append(storageMsgList, v)
} else {
notStorageMsgList = append(notStorageMsgList, v)
}
}
}
return
}
func (och *OnlineHistoryRedisConsumerHandler) handleMsg(ctx context.Context, key, conversationID string, storageList, notStorageList []*ContextMsg) {
log.ZInfo(ctx, "handle storage msg")
for _, storageMsg := range storageList {
log.ZDebug(ctx, "handle storage msg", "msg", storageMsg.message.String())
}
och.toPushTopic(ctx, key, conversationID, notStorageList)
var storageMessageList []*sdkws.MsgData
for _, msg := range storageList {
storageMessageList = append(storageMessageList, msg.message)
}
if len(storageMessageList) > 0 {
msg := storageMessageList[0]
lastSeq, isNewConversation, userSeqMap, err := och.msgTransferDatabase.BatchInsertChat2Cache(ctx, conversationID, storageMessageList)
if err != nil && !errors.Is(errs.Unwrap(err), redis.Nil) {
log.ZWarn(ctx, "batch data insert to redis err", err, "storageMsgList", storageMessageList)
return
}
log.ZInfo(ctx, "BatchInsertChat2Cache end")
err = och.msgTransferDatabase.SetHasReadSeqs(ctx, conversationID, userSeqMap)
if err != nil {
log.ZWarn(ctx, "SetHasReadSeqs error", err, "userSeqMap", userSeqMap, "conversationID", conversationID)
prommetrics.SeqSetFailedCounter.Inc()
}
select {
case och.conversationUserHasReadChan <- &userHasReadSeq{
conversationID: conversationID,
userHasReadMap: userSeqMap,
}:
default:
log.ZWarn(ctx, "conversationUserHasReadChan full, drop userHasReadSeq update", nil,
"conversationID", conversationID, "userSeqMapSize", len(userSeqMap))
}
if isNewConversation {
switch msg.SessionType {
case constant.ReadGroupChatType:
log.ZDebug(ctx, "group chat first create conversation", "conversationID",
conversationID)
userIDs, err := och.groupClient.GetGroupMemberUserIDs(ctx, msg.GroupID)
if err != nil {
log.ZWarn(ctx, "get group member ids error", err, "conversationID",
conversationID)
} else {
log.ZInfo(ctx, "GetGroupMemberIDs end")
if err := och.conversationClient.CreateGroupChatConversations(ctx, msg.GroupID, userIDs); err != nil {
log.ZWarn(ctx, "single chat first create conversation error", err,
"conversationID", conversationID)
}
}
case constant.SingleChatType, constant.NotificationChatType:
req := &pbconv.CreateSingleChatConversationsReq{
RecvID: msg.RecvID,
SendID: msg.SendID,
ConversationID: conversationID,
ConversationType: msg.SessionType,
}
if err := och.conversationClient.CreateSingleChatConversations(ctx, req); err != nil {
log.ZWarn(ctx, "single chat or notification first create conversation error", err,
"conversationID", conversationID, "sessionType", msg.SessionType)
}
default:
log.ZWarn(ctx, "unknown session type", nil, "sessionType",
msg.SessionType)
}
}
log.ZInfo(ctx, "success incr to next topic")
err = och.msgTransferDatabase.MsgToMongoMQ(ctx, key, conversationID, storageMessageList, lastSeq)
if err != nil {
log.ZError(ctx, "Msg To MongoDB MQ error", err, "conversationID",
conversationID, "storageList", storageMessageList, "lastSeq", lastSeq)
}
log.ZInfo(ctx, "MsgToMongoMQ end")
och.toPushTopic(ctx, key, conversationID, storageList)
log.ZInfo(ctx, "toPushTopic end")
}
}
func (och *OnlineHistoryRedisConsumerHandler) handleNotification(ctx context.Context, key, conversationID string,
storageList, notStorageList []*ContextMsg) {
och.toPushTopic(ctx, key, conversationID, notStorageList)
var storageMessageList []*sdkws.MsgData
for _, msg := range storageList {
storageMessageList = append(storageMessageList, msg.message)
}
if len(storageMessageList) > 0 {
lastSeq, _, _, err := och.msgTransferDatabase.BatchInsertChat2Cache(ctx, conversationID, storageMessageList)
if err != nil {
log.ZError(ctx, "notification batch insert to redis error", err, "conversationID", conversationID,
"storageList", storageMessageList)
return
}
log.ZDebug(ctx, "success to next topic", "conversationID", conversationID)
err = och.msgTransferDatabase.MsgToMongoMQ(ctx, key, conversationID, storageMessageList, lastSeq)
if err != nil {
log.ZError(ctx, "Msg To MongoDB MQ error", err, "conversationID",
conversationID, "storageList", storageMessageList, "lastSeq", lastSeq)
}
och.toPushTopic(ctx, key, conversationID, storageList)
}
}
func (och *OnlineHistoryRedisConsumerHandler) HandleUserHasReadSeqMessages(ctx context.Context) {
defer func() {
if r := recover(); r != nil {
log.ZPanic(ctx, "HandleUserHasReadSeqMessages Panic", errs.ErrPanic(r))
}
}()
defer och.wg.Done()
for msg := range och.conversationUserHasReadChan {
if err := och.msgTransferDatabase.SetHasReadSeqToDB(ctx, msg.conversationID, msg.userHasReadMap); err != nil {
log.ZWarn(ctx, "set read seq to db error", err, "conversationID", msg.conversationID, "userSeqMap", msg.userHasReadMap)
}
}
log.ZInfo(ctx, "Channel closed, exiting handleUserHasReadSeqMessages")
}
func (och *OnlineHistoryRedisConsumerHandler) Close() {
close(och.conversationUserHasReadChan)
och.wg.Wait()
}
func (och *OnlineHistoryRedisConsumerHandler) toPushTopic(ctx context.Context, key, conversationID string, msgs []*ContextMsg) {
for _, v := range msgs {
log.ZDebug(ctx, "push msg to topic", "msg", v.message.String())
if err := och.msgTransferDatabase.MsgToPushMQ(v.ctx, key, conversationID, v.message); err != nil {
log.ZError(ctx, "msg to push topic error", err, "msg", v.message.String())
}
}
}
func withAggregationCtx(ctx context.Context, values []*ContextMsg) context.Context {
var allMessageOperationID string
for i, v := range values {
if opid := mcontext.GetOperationID(v.ctx); opid != "" {
if i == 0 {
allMessageOperationID += opid
} else {
allMessageOperationID += "$" + opid
}
}
}
return mcontext.SetOperationID(ctx, allMessageOperationID)
}
func (och *OnlineHistoryRedisConsumerHandler) HandlerRedisMessage(msg mq.Message) error { // a instance in the consumer group
err := och.redisMessageBatches.Put(msg.Context(), &ConsumerMessage{Ctx: msg.Context(), Key: msg.Key(), Value: msg.Value(), Raw: msg})
if err != nil {
log.ZWarn(msg.Context(), "put msg to error", err, "key", msg.Key(), "value", msg.Value())
}
return nil
}

View File

@@ -0,0 +1,78 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msgtransfer
import (
"github.com/openimsdk/tools/mq"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/prommetrics"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/controller"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/webhook"
pbmsg "git.imall.cloud/openim/protocol/msg"
"github.com/openimsdk/tools/log"
"google.golang.org/protobuf/proto"
)
type OnlineHistoryMongoConsumerHandler struct {
msgTransferDatabase controller.MsgTransferDatabase
config *Config
webhookClient *webhook.Client
}
func NewOnlineHistoryMongoConsumerHandler(database controller.MsgTransferDatabase, config *Config) *OnlineHistoryMongoConsumerHandler {
return &OnlineHistoryMongoConsumerHandler{
msgTransferDatabase: database,
config: config,
webhookClient: webhook.NewWebhookClient(config.WebhooksConfig.URL),
}
}
func (mc *OnlineHistoryMongoConsumerHandler) HandleChatWs2Mongo(val mq.Message) {
ctx := val.Context()
key := val.Key()
msg := val.Value()
msgFromMQ := pbmsg.MsgDataToMongoByMQ{}
err := proto.Unmarshal(msg, &msgFromMQ)
if err != nil {
log.ZError(ctx, "unmarshall failed", err, "key", key, "len", len(msg))
return
}
if len(msgFromMQ.MsgData) == 0 {
log.ZError(ctx, "msgFromMQ.MsgData is empty", nil, "key", key, "msg", msg)
return
}
log.ZDebug(ctx, "mongo consumer recv msg", "msgs", msgFromMQ.String())
err = mc.msgTransferDatabase.BatchInsertChat2DB(ctx, msgFromMQ.ConversationID, msgFromMQ.MsgData, msgFromMQ.LastSeq)
if err != nil {
log.ZError(ctx, "single data insert to mongo err", err, "msg", msgFromMQ.MsgData, "conversationID", msgFromMQ.ConversationID)
prommetrics.MsgInsertMongoFailedCounter.Inc()
} else {
prommetrics.MsgInsertMongoSuccessCounter.Inc()
val.Mark()
}
for _, msgData := range msgFromMQ.MsgData {
mc.webhookAfterMsgSaveDB(ctx, &mc.config.WebhooksConfig.AfterMsgSaveDB, msgData)
}
//var seqs []int64
//for _, msg := range msgFromMQ.MsgData {
// seqs = append(seqs, msg.Seq)
//}
//if err := mc.msgTransferDatabase.DeleteMessagesFromCache(ctx, msgFromMQ.ConversationID, seqs); err != nil {
// log.ZError(ctx, "remove cache msg from redis err", err, "msg",
// msgFromMQ.MsgData, "conversationID", msgFromMQ.ConversationID)
//}
}

150
internal/push/callback.go Normal file
View File

@@ -0,0 +1,150 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package push
import (
"context"
"encoding/json"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/webhook"
"git.imall.cloud/openim/open-im-server-deploy/pkg/callbackstruct"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/mcontext"
)
func (c *ConsumerHandler) webhookBeforeOfflinePush(ctx context.Context, before *config.BeforeConfig, userIDs []string, msg *sdkws.MsgData, offlinePushUserIDs *[]string) error {
return webhook.WithCondition(ctx, before, func(ctx context.Context) error {
if msg.ContentType == constant.Typing {
return nil
}
req := &callbackstruct.CallbackBeforePushReq{
UserStatusBatchCallbackReq: callbackstruct.UserStatusBatchCallbackReq{
UserStatusBaseCallback: callbackstruct.UserStatusBaseCallback{
CallbackCommand: callbackstruct.CallbackBeforeOfflinePushCommand,
OperationID: mcontext.GetOperationID(ctx),
PlatformID: int(msg.SenderPlatformID),
Platform: constant.PlatformIDToName(int(msg.SenderPlatformID)),
},
UserIDList: userIDs,
},
OfflinePushInfo: msg.OfflinePushInfo,
ClientMsgID: msg.ClientMsgID,
SendID: msg.SendID,
GroupID: msg.GroupID,
ContentType: msg.ContentType,
SessionType: msg.SessionType,
AtUserIDs: msg.AtUserIDList,
Content: GetContent(msg),
}
resp := &callbackstruct.CallbackBeforePushResp{}
if err := c.webhookClient.SyncPost(ctx, req.GetCallbackCommand(), req, resp, before); err != nil {
return err
}
if len(resp.UserIDs) != 0 {
*offlinePushUserIDs = resp.UserIDs
}
if resp.OfflinePushInfo != nil {
msg.OfflinePushInfo = resp.OfflinePushInfo
}
return nil
})
}
func (c *ConsumerHandler) webhookBeforeOnlinePush(ctx context.Context, before *config.BeforeConfig, userIDs []string, msg *sdkws.MsgData) error {
return webhook.WithCondition(ctx, before, func(ctx context.Context) error {
if msg.ContentType == constant.Typing {
return nil
}
req := callbackstruct.CallbackBeforePushReq{
UserStatusBatchCallbackReq: callbackstruct.UserStatusBatchCallbackReq{
UserStatusBaseCallback: callbackstruct.UserStatusBaseCallback{
CallbackCommand: callbackstruct.CallbackBeforeOnlinePushCommand,
OperationID: mcontext.GetOperationID(ctx),
PlatformID: int(msg.SenderPlatformID),
Platform: constant.PlatformIDToName(int(msg.SenderPlatformID)),
},
UserIDList: userIDs,
},
ClientMsgID: msg.ClientMsgID,
SendID: msg.SendID,
GroupID: msg.GroupID,
ContentType: msg.ContentType,
SessionType: msg.SessionType,
AtUserIDs: msg.AtUserIDList,
Content: GetContent(msg),
}
resp := &callbackstruct.CallbackBeforePushResp{}
if err := c.webhookClient.SyncPost(ctx, req.GetCallbackCommand(), req, resp, before); err != nil {
return err
}
return nil
})
}
func (c *ConsumerHandler) webhookBeforeGroupOnlinePush(
ctx context.Context,
before *config.BeforeConfig,
groupID string,
msg *sdkws.MsgData,
pushToUserIDs *[]string,
) error {
return webhook.WithCondition(ctx, before, func(ctx context.Context) error {
if msg.ContentType == constant.Typing {
return nil
}
req := callbackstruct.CallbackBeforeSuperGroupOnlinePushReq{
UserStatusBaseCallback: callbackstruct.UserStatusBaseCallback{
CallbackCommand: callbackstruct.CallbackBeforeGroupOnlinePushCommand,
OperationID: mcontext.GetOperationID(ctx),
PlatformID: int(msg.SenderPlatformID),
Platform: constant.PlatformIDToName(int(msg.SenderPlatformID)),
},
ClientMsgID: msg.ClientMsgID,
SendID: msg.SendID,
GroupID: groupID,
ContentType: msg.ContentType,
SessionType: msg.SessionType,
AtUserIDs: msg.AtUserIDList,
Content: GetContent(msg),
Seq: msg.Seq,
}
resp := &callbackstruct.CallbackBeforeSuperGroupOnlinePushResp{}
if err := c.webhookClient.SyncPost(ctx, req.GetCallbackCommand(), req, resp, before); err != nil {
return err
}
if len(resp.UserIDs) != 0 {
*pushToUserIDs = resp.UserIDs
}
return nil
})
}
func GetContent(msg *sdkws.MsgData) string {
if msg.ContentType >= constant.NotificationBegin && msg.ContentType <= constant.NotificationEnd {
var notification sdkws.NotificationElem
if err := json.Unmarshal(msg.Content, &notification); err != nil {
return ""
}
return notification.Detail
} else {
return string(msg.Content)
}
}

View File

@@ -0,0 +1,38 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dummy
import (
"context"
"sync/atomic"
"git.imall.cloud/openim/open-im-server-deploy/internal/push/offlinepush/options"
"github.com/openimsdk/tools/log"
)
func NewClient() *Dummy {
return &Dummy{}
}
type Dummy struct {
v atomic.Bool
}
func (d *Dummy) Push(ctx context.Context, userIDs []string, title, content string, opts *options.Opts) error {
if d.v.CompareAndSwap(false, true) {
log.ZWarn(ctx, "dummy push", nil, "ps", "the offline push is not configured. to configure it, please go to config/openim-push.yml")
}
return nil
}

View File

@@ -0,0 +1,172 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package fcm
import (
"context"
"errors"
"fmt"
"path/filepath"
"strings"
"git.imall.cloud/openim/open-im-server-deploy/internal/push/offlinepush/options"
"github.com/openimsdk/tools/utils/httputil"
firebase "firebase.google.com/go/v4"
"firebase.google.com/go/v4/messaging"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/protocol/constant"
"github.com/openimsdk/tools/errs"
"github.com/redis/go-redis/v9"
"google.golang.org/api/option"
)
const SinglePushCountLimit = 400
var Terminal = []int{constant.IOSPlatformID, constant.AndroidPlatformID, constant.WebPlatformID}
type Fcm struct {
fcmMsgCli *messaging.Client
cache cache.ThirdCache
}
// NewClient initializes a new FCM client using the Firebase Admin SDK.
// It requires the FCM service account credentials file located within the project's configuration directory.
func NewClient(pushConf *config.Push, cache cache.ThirdCache, fcmConfigPath string) (*Fcm, error) {
var opt option.ClientOption
switch {
case len(pushConf.FCM.FilePath) != 0:
// with file path
credentialsFilePath := filepath.Join(fcmConfigPath, pushConf.FCM.FilePath)
opt = option.WithCredentialsFile(credentialsFilePath)
case len(pushConf.FCM.AuthURL) != 0:
// with authentication URL
client := httputil.NewHTTPClient(httputil.NewClientConfig())
resp, err := client.Get(pushConf.FCM.AuthURL)
if err != nil {
return nil, err
}
opt = option.WithCredentialsJSON(resp)
default:
return nil, errs.New("no FCM config").Wrap()
}
fcmApp, err := firebase.NewApp(context.Background(), nil, opt)
if err != nil {
return nil, errs.Wrap(err)
}
ctx := context.Background()
fcmMsgClient, err := fcmApp.Messaging(ctx)
if err != nil {
return nil, errs.Wrap(err)
}
return &Fcm{fcmMsgCli: fcmMsgClient, cache: cache}, nil
}
func (f *Fcm) Push(ctx context.Context, userIDs []string, title, content string, opts *options.Opts) error {
// accounts->registrationToken
allTokens := make(map[string][]string, 0)
for _, account := range userIDs {
var personTokens []string
for _, v := range Terminal {
Token, err := f.cache.GetFcmToken(ctx, account, v)
if err == nil {
personTokens = append(personTokens, Token)
}
}
allTokens[account] = personTokens
}
Success := 0
Fail := 0
notification := &messaging.Notification{}
notification.Body = content
notification.Title = title
var messages []*messaging.Message
var sendErrBuilder strings.Builder
var msgErrBuilder strings.Builder
for userID, personTokens := range allTokens {
apns := &messaging.APNSConfig{Payload: &messaging.APNSPayload{Aps: &messaging.Aps{Sound: opts.IOSPushSound}}}
messageCount := len(messages)
if messageCount >= SinglePushCountLimit {
response, err := f.fcmMsgCli.SendEach(ctx, messages)
if err != nil {
Fail = Fail + messageCount
// Record push error
sendErrBuilder.WriteString(err.Error())
sendErrBuilder.WriteByte('.')
} else {
Success = Success + response.SuccessCount
Fail = Fail + response.FailureCount
if response.FailureCount != 0 {
// Record message error
for i := range response.Responses {
if !response.Responses[i].Success {
msgErrBuilder.WriteString(response.Responses[i].Error.Error())
msgErrBuilder.WriteByte('.')
}
}
}
}
messages = messages[0:0]
}
if opts.IOSBadgeCount {
unreadCountSum, err := f.cache.IncrUserBadgeUnreadCountSum(ctx, userID)
if err == nil {
apns.Payload.Aps.Badge = &unreadCountSum
} else {
// log.Error(operationID, "IncrUserBadgeUnreadCountSum redis err", err.Error(), uid)
Fail++
continue
}
} else {
unreadCountSum, err := f.cache.GetUserBadgeUnreadCountSum(ctx, userID)
if err == nil && unreadCountSum != 0 {
apns.Payload.Aps.Badge = &unreadCountSum
} else if errors.Is(err, redis.Nil) || unreadCountSum == 0 {
zero := 1
apns.Payload.Aps.Badge = &zero
} else {
// log.Error(operationID, "GetUserBadgeUnreadCountSum redis err", err.Error(), uid)
Fail++
continue
}
}
for _, token := range personTokens {
temp := &messaging.Message{
Data: map[string]string{"ex": opts.Ex},
Token: token,
Notification: notification,
APNS: apns,
}
messages = append(messages, temp)
}
}
messageCount := len(messages)
if messageCount > 0 {
response, err := f.fcmMsgCli.SendEach(ctx, messages)
if err != nil {
Fail = Fail + messageCount
} else {
Success = Success + response.SuccessCount
Fail = Fail + response.FailureCount
}
}
if Fail != 0 {
return errs.New(fmt.Sprintf("%d message send failed;send err:%s;message err:%s",
Fail, sendErrBuilder.String(), msgErrBuilder.String())).Wrap()
}
return nil
}

View File

@@ -0,0 +1,219 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package getui
import (
"fmt"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"github.com/openimsdk/tools/utils/datautil"
)
var (
incOne = datautil.ToPtr("+1")
addNum = "1"
defaultStrategy = strategy{
Default: 1,
}
msgCategory = "CATEGORY_MESSAGE"
)
type Resp struct {
Code int `json:"code"`
Msg string `json:"msg"`
Data any `json:"data"`
}
func (r *Resp) parseError() (err error) {
switch r.Code {
case tokenExpireCode:
err = ErrTokenExpire
case 0:
err = nil
default:
err = fmt.Errorf("code %d, msg %s", r.Code, r.Msg)
}
return err
}
type RespI interface {
parseError() error
}
type AuthReq struct {
Sign string `json:"sign"`
Timestamp string `json:"timestamp"`
AppKey string `json:"appkey"`
}
type AuthResp struct {
ExpireTime string `json:"expire_time"`
Token string `json:"token"`
}
type TaskResp struct {
TaskID string `json:"taskID"`
}
type Settings struct {
TTL *int64 `json:"ttl"`
Strategy strategy `json:"strategy"`
}
type strategy struct {
Default int64 `json:"default"`
//IOS int64 `json:"ios"`
//St int64 `json:"st"`
//Hw int64 `json:"hw"`
//Ho int64 `json:"ho"`
//XM int64 `json:"xm"`
//XMG int64 `json:"xmg"`
//VV int64 `json:"vv"`
//Op int64 `json:"op"`
//OpG int64 `json:"opg"`
//MZ int64 `json:"mz"`
//HosHw int64 `json:"hoshw"`
//WX int64 `json:"wx"`
}
type Audience struct {
Alias []string `json:"alias"`
}
type PushMessage struct {
Notification *Notification `json:"notification,omitempty"`
Transmission *string `json:"transmission,omitempty"`
}
type PushChannel struct {
Ios *Ios `json:"ios"`
Android *Android `json:"android"`
}
type PushReq struct {
RequestID *string `json:"request_id"`
Settings *Settings `json:"settings"`
Audience *Audience `json:"audience"`
PushMessage *PushMessage `json:"push_message"`
PushChannel *PushChannel `json:"push_channel"`
IsAsync *bool `json:"is_async"`
TaskID *string `json:"taskid"`
}
type Ios struct {
NotificationType *string `json:"type"`
AutoBadge *string `json:"auto_badge"`
Aps struct {
Sound string `json:"sound"`
Alert Alert `json:"alert"`
} `json:"aps"`
}
type Alert struct {
Title string `json:"title"`
Body string `json:"body"`
}
type Android struct {
Ups struct {
Notification Notification `json:"notification"`
Options Options `json:"options"`
} `json:"ups"`
}
type Notification struct {
Title string `json:"title"`
Body string `json:"body"`
ChannelID string `json:"channelID"`
ChannelName string `json:"ChannelName"`
ClickType string `json:"click_type"`
BadgeAddNum string `json:"badge_add_num"`
Category string `json:"category"`
}
type Options struct {
HW struct {
DefaultSound bool `json:"/message/android/notification/default_sound"`
ChannelID string `json:"/message/android/notification/channel_id"`
Sound string `json:"/message/android/notification/sound"`
Importance string `json:"/message/android/notification/importance"`
Category string `json:"/message/android/category"`
} `json:"HW"`
XM struct {
ChannelID string `json:"/extra.channel_id"`
} `json:"XM"`
VV struct {
Classification int `json:"/classification"`
} `json:"VV"`
}
type Payload struct {
IsSignal bool `json:"isSignal"`
}
func newPushReq(pushConf *config.Push, title, content string) PushReq {
pushReq := PushReq{PushMessage: &PushMessage{Notification: &Notification{
Title: title,
Body: content,
ClickType: "startapp",
ChannelID: pushConf.GeTui.ChannelID,
ChannelName: pushConf.GeTui.ChannelName,
BadgeAddNum: addNum,
Category: msgCategory,
}}}
return pushReq
}
func newBatchPushReq(userIDs []string, taskID string) PushReq {
IsAsync := true
return PushReq{Audience: &Audience{Alias: userIDs}, IsAsync: &IsAsync, TaskID: &taskID}
}
func (pushReq *PushReq) setPushChannel(title string, body string) {
pushReq.PushChannel = &PushChannel{}
// autoBadge := "+1"
pushReq.PushChannel.Ios = &Ios{}
notify := "notify"
pushReq.PushChannel.Ios.NotificationType = &notify
pushReq.PushChannel.Ios.Aps.Sound = "default"
pushReq.PushChannel.Ios.AutoBadge = incOne
pushReq.PushChannel.Ios.Aps.Alert = Alert{
Title: title,
Body: body,
}
pushReq.PushChannel.Android = &Android{}
pushReq.PushChannel.Android.Ups.Notification = Notification{
Title: title,
Body: body,
ClickType: "startapp",
}
pushReq.PushChannel.Android.Ups.Options = Options{
HW: struct {
DefaultSound bool `json:"/message/android/notification/default_sound"`
ChannelID string `json:"/message/android/notification/channel_id"`
Sound string `json:"/message/android/notification/sound"`
Importance string `json:"/message/android/notification/importance"`
Category string `json:"/message/android/category"`
}{ChannelID: "RingRing4", Sound: "/raw/ring001", Importance: "NORMAL", Category: "IM"},
XM: struct {
ChannelID string `json:"/extra.channel_id"`
}{ChannelID: "high_system"},
VV: struct {
Classification int "json:\"/classification\""
}{
Classification: 1,
},
}
}

View File

@@ -0,0 +1,219 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package getui
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"strconv"
"sync"
"time"
"git.imall.cloud/openim/open-im-server-deploy/internal/push/offlinepush/options"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/httputil"
"github.com/openimsdk/tools/utils/splitter"
"github.com/redis/go-redis/v9"
)
var (
ErrTokenExpire = errs.New("token expire")
ErrUserIDEmpty = errs.New("userIDs is empty")
)
const (
pushURL = "/push/single/alias"
authURL = "/auth"
taskURL = "/push/list/message"
batchPushURL = "/push/list/alias"
// Codes.
tokenExpireCode = 10001
tokenExpireTime = 60 * 60 * 23
taskIDTTL = 1000 * 60 * 60 * 24
)
type Client struct {
cache cache.ThirdCache
tokenExpireTime int64
taskIDTTL int64
pushConf *config.Push
httpClient *httputil.HTTPClient
}
func NewClient(pushConf *config.Push, cache cache.ThirdCache) *Client {
return &Client{cache: cache,
tokenExpireTime: tokenExpireTime,
taskIDTTL: taskIDTTL,
pushConf: pushConf,
httpClient: httputil.NewHTTPClient(httputil.NewClientConfig()),
}
}
func (g *Client) Push(ctx context.Context, userIDs []string, title, content string, opts *options.Opts) error {
token, err := g.cache.GetGetuiToken(ctx)
if err != nil {
if errors.Is(err, redis.Nil) {
log.ZDebug(ctx, "getui token not exist in redis")
token, err = g.getTokenAndSave2Redis(ctx)
if err != nil {
return err
}
} else {
return err
}
}
pushReq := newPushReq(g.pushConf, title, content)
pushReq.setPushChannel(title, content)
if len(userIDs) > 1 {
maxNum := 999
if len(userIDs) > maxNum {
s := splitter.NewSplitter(maxNum, userIDs)
wg := sync.WaitGroup{}
wg.Add(len(s.GetSplitResult()))
for i, v := range s.GetSplitResult() {
go func(index int, userIDs []string) {
defer wg.Done()
for i := 0; i < len(userIDs); i += maxNum {
end := i + maxNum
if end > len(userIDs) {
end = len(userIDs)
}
if err = g.batchPush(ctx, token, userIDs[i:end], pushReq); err != nil {
log.ZError(ctx, "batchPush failed", err, "index", index, "token", token, "req", pushReq)
}
}
if err = g.batchPush(ctx, token, userIDs, pushReq); err != nil {
log.ZError(ctx, "batchPush failed", err, "index", index, "token", token, "req", pushReq)
}
}(i, v.Item)
}
wg.Wait()
} else {
err = g.batchPush(ctx, token, userIDs, pushReq)
}
} else if len(userIDs) == 1 {
err = g.singlePush(ctx, token, userIDs[0], pushReq)
} else {
return ErrUserIDEmpty
}
switch err {
case ErrTokenExpire:
token, err = g.getTokenAndSave2Redis(ctx)
}
return err
}
func (g *Client) Auth(ctx context.Context, timeStamp int64) (token string, expireTime int64, err error) {
h := sha256.New()
h.Write(
[]byte(g.pushConf.GeTui.AppKey + strconv.Itoa(int(timeStamp)) + g.pushConf.GeTui.MasterSecret),
)
sign := hex.EncodeToString(h.Sum(nil))
reqAuth := AuthReq{
Sign: sign,
Timestamp: strconv.Itoa(int(timeStamp)),
AppKey: g.pushConf.GeTui.AppKey,
}
respAuth := AuthResp{}
err = g.request(ctx, authURL, reqAuth, "", &respAuth)
if err != nil {
return "", 0, err
}
expire, err := strconv.Atoi(respAuth.ExpireTime)
return respAuth.Token, int64(expire), err
}
func (g *Client) GetTaskID(ctx context.Context, token string, pushReq PushReq) (string, error) {
respTask := TaskResp{}
ttl := int64(1000 * 60 * 5)
pushReq.Settings = &Settings{TTL: &ttl, Strategy: defaultStrategy}
err := g.request(ctx, taskURL, pushReq, token, &respTask)
if err != nil {
return "", errs.Wrap(err)
}
return respTask.TaskID, nil
}
// max num is 999.
func (g *Client) batchPush(ctx context.Context, token string, userIDs []string, pushReq PushReq) error {
taskID, err := g.GetTaskID(ctx, token, pushReq)
if err != nil {
return err
}
pushReq = newBatchPushReq(userIDs, taskID)
return g.request(ctx, batchPushURL, pushReq, token, nil)
}
func (g *Client) singlePush(ctx context.Context, token, userID string, pushReq PushReq) error {
operationID := mcontext.GetOperationID(ctx)
pushReq.RequestID = &operationID
pushReq.Audience = &Audience{Alias: []string{userID}}
return g.request(ctx, pushURL, pushReq, token, nil)
}
func (g *Client) request(ctx context.Context, url string, input any, token string, output any) error {
header := map[string]string{"token": token}
resp := &Resp{}
resp.Data = output
return g.postReturn(ctx, g.pushConf.GeTui.PushUrl+url, header, input, resp, 3)
}
func (g *Client) postReturn(
ctx context.Context,
url string,
header map[string]string,
input any,
output RespI,
timeout int,
) error {
err := g.httpClient.PostReturn(ctx, url, header, input, output, timeout)
if err != nil {
return err
}
log.ZDebug(ctx, "postReturn", "url", url, "header", header, "input", input, "timeout", timeout, "output", output)
return output.parseError()
}
func (g *Client) getTokenAndSave2Redis(ctx context.Context) (token string, err error) {
token, _, err = g.Auth(ctx, time.Now().UnixNano()/1e6)
if err != nil {
return
}
err = g.cache.SetGetuiToken(ctx, token, 60*60*23)
if err != nil {
return
}
return token, nil
}
func (g *Client) GetTaskIDAndSave2Redis(ctx context.Context, token string, pushReq PushReq) (taskID string, err error) {
pushReq.Settings = &Settings{TTL: &g.taskIDTTL, Strategy: defaultStrategy}
taskID, err = g.GetTaskID(ctx, token, pushReq)
if err != nil {
return
}
err = g.cache.SetGetuiTaskID(ctx, taskID, g.tokenExpireTime)
if err != nil {
return
}
return token, nil
}

View File

@@ -0,0 +1,64 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package body
const (
TAG = "tag"
TAGAND = "tag_and"
TAGNOT = "tag_not"
ALIAS = "alias"
REGISTRATIONID = "registration_id"
)
type Audience struct {
Object any
audience map[string][]string
}
func (a *Audience) set(key string, v []string) {
if a.audience == nil {
a.audience = make(map[string][]string)
a.Object = a.audience
}
// v, ok = this.audience[key]
// if ok {
// return
//}
a.audience[key] = v
}
func (a *Audience) SetTag(tags []string) {
a.set(TAG, tags)
}
func (a *Audience) SetTagAnd(tags []string) {
a.set(TAGAND, tags)
}
func (a *Audience) SetTagNot(tags []string) {
a.set(TAGNOT, tags)
}
func (a *Audience) SetAlias(alias []string) {
a.set(ALIAS, alias)
}
func (a *Audience) SetRegistrationId(ids []string) {
a.set(REGISTRATIONID, ids)
}
func (a *Audience) SetAll() {
a.Object = "all"
}

View File

@@ -0,0 +1,41 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package body
type Message struct {
MsgContent string `json:"msg_content"`
Title string `json:"title,omitempty"`
ContentType string `json:"content_type,omitempty"`
Extras map[string]any `json:"extras,omitempty"`
}
func (m *Message) SetMsgContent(c string) {
m.MsgContent = c
}
func (m *Message) SetTitle(t string) {
m.Title = t
}
func (m *Message) SetContentType(c string) {
m.ContentType = c
}
func (m *Message) SetExtras(key string, value any) {
if m.Extras == nil {
m.Extras = make(map[string]any)
}
m.Extras[key] = value
}

View File

@@ -0,0 +1,72 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package body
import (
"git.imall.cloud/openim/open-im-server-deploy/internal/push/offlinepush/options"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
)
type Notification struct {
Alert string `json:"alert,omitempty"`
Android Android `json:"android,omitempty"`
IOS Ios `json:"ios,omitempty"`
}
type Android struct {
Alert string `json:"alert,omitempty"`
Title string `json:"title,omitempty"`
Intent struct {
URL string `json:"url,omitempty"`
} `json:"intent,omitempty"`
Extras map[string]string `json:"extras,omitempty"`
}
type Ios struct {
Alert IosAlert `json:"alert,omitempty"`
Sound string `json:"sound,omitempty"`
Badge string `json:"badge,omitempty"`
Extras map[string]string `json:"extras,omitempty"`
MutableContent bool `json:"mutable-content"`
}
type IosAlert struct {
Title string `json:"title,omitempty"`
Body string `json:"body,omitempty"`
}
func (n *Notification) SetAlert(alert string, title string, opts *options.Opts) {
n.Alert = alert
n.Android.Alert = alert
n.Android.Title = title
n.IOS.Alert.Body = alert
n.IOS.Alert.Title = title
n.IOS.Sound = opts.IOSPushSound
if opts.IOSBadgeCount {
n.IOS.Badge = "+1"
}
}
func (n *Notification) SetExtras(extras map[string]string) {
n.IOS.Extras = extras
n.Android.Extras = extras
}
func (n *Notification) SetAndroidIntent(pushConf *config.Push) {
n.Android.Intent.URL = pushConf.JPush.PushIntent
}
func (n *Notification) IOSEnableMutableContent() {
n.IOS.MutableContent = true
}

View File

@@ -0,0 +1,23 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package body
type Options struct {
ApnsProduction bool `json:"apns_production"`
}
func (o *Options) SetApnsProduction(c bool) {
o.ApnsProduction = c
}

View File

@@ -0,0 +1,99 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package body
import (
"github.com/openimsdk/tools/errs"
"git.imall.cloud/openim/protocol/constant"
)
const (
ANDROID = "android"
IOS = "ios"
QUICKAPP = "quickapp"
WINDOWSPHONE = "winphone"
ALL = "all"
)
type Platform struct {
Os any
osArry []string
}
func (p *Platform) Set(os string) error {
if p.Os == nil {
p.osArry = make([]string, 0, 4)
} else {
switch p.Os.(type) {
case string:
return errs.New("platform is all")
default:
}
}
for _, value := range p.osArry {
if os == value {
return nil
}
}
switch os {
case IOS:
fallthrough
case ANDROID:
fallthrough
case QUICKAPP:
fallthrough
case WINDOWSPHONE:
p.osArry = append(p.osArry, os)
p.Os = p.osArry
default:
return errs.New("unknow platform")
}
return nil
}
func (p *Platform) SetPlatform(platform string) error {
switch platform {
case constant.AndroidPlatformStr:
return p.SetAndroid()
case constant.IOSPlatformStr:
return p.SetIOS()
default:
return errs.New("platform err")
}
}
func (p *Platform) SetIOS() error {
return p.Set(IOS)
}
func (p *Platform) SetAndroid() error {
return p.Set(ANDROID)
}
func (p *Platform) SetQuickApp() error {
return p.Set(QUICKAPP)
}
func (p *Platform) SetWindowsPhone() error {
return p.Set(WINDOWSPHONE)
}
func (p *Platform) SetAll() {
p.Os = ALL
}

View File

@@ -0,0 +1,43 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package body
type PushObj struct {
Platform any `json:"platform"`
Audience any `json:"audience"`
Notification any `json:"notification,omitempty"`
Message any `json:"message,omitempty"`
Options any `json:"options,omitempty"`
}
func (p *PushObj) SetPlatform(pf *Platform) {
p.Platform = pf.Os
}
func (p *PushObj) SetAudience(ad *Audience) {
p.Audience = ad.Object
}
func (p *PushObj) SetNotification(no *Notification) {
p.Notification = no
}
func (p *PushObj) SetMessage(m *Message) {
p.Message = m
}
func (p *PushObj) SetOptions(o *Options) {
p.Options = o
}

View File

@@ -0,0 +1,107 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package jpush
import (
"context"
"encoding/base64"
"fmt"
"git.imall.cloud/openim/open-im-server-deploy/internal/push/offlinepush/jpush/body"
"git.imall.cloud/openim/open-im-server-deploy/internal/push/offlinepush/options"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"github.com/openimsdk/tools/utils/httputil"
)
type JPush struct {
pushConf *config.Push
httpClient *httputil.HTTPClient
}
func NewClient(pushConf *config.Push) *JPush {
return &JPush{pushConf: pushConf,
httpClient: httputil.NewHTTPClient(httputil.NewClientConfig()),
}
}
func (j *JPush) Auth(apiKey, secretKey string, timeStamp int64) (token string, err error) {
return token, nil
}
func (j *JPush) SetAlias(cid, alias string) (resp string, err error) {
return resp, nil
}
func (j *JPush) getAuthorization(appKey string, masterSecret string) string {
str := fmt.Sprintf("%s:%s", appKey, masterSecret)
buf := []byte(str)
Authorization := fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString(buf))
return Authorization
}
func (j *JPush) Push(ctx context.Context, userIDs []string, title, content string, opts *options.Opts) error {
var pf body.Platform
pf.SetAll()
var au body.Audience
au.SetAlias(userIDs)
var no body.Notification
extras := make(map[string]string)
extras["ex"] = opts.Ex
if opts.Signal.ClientMsgID != "" {
extras["ClientMsgID"] = opts.Signal.ClientMsgID
}
no.IOSEnableMutableContent()
no.SetExtras(extras)
no.SetAlert(content, title, opts)
no.SetAndroidIntent(j.pushConf)
var msg body.Message
msg.SetMsgContent(content)
msg.SetTitle(title)
if opts.Signal.ClientMsgID != "" {
msg.SetExtras("ClientMsgID", opts.Signal.ClientMsgID)
}
msg.SetExtras("ex", opts.Ex)
var opt body.Options
opt.SetApnsProduction(j.pushConf.IOSPush.Production)
var pushObj body.PushObj
pushObj.SetPlatform(&pf)
pushObj.SetAudience(&au)
pushObj.SetNotification(&no)
pushObj.SetMessage(&msg)
pushObj.SetOptions(&opt)
var resp map[string]any
return j.request(ctx, pushObj, &resp, 5)
}
func (j *JPush) request(ctx context.Context, po body.PushObj, resp *map[string]any, timeout int) error {
err := j.httpClient.PostReturn(
ctx,
j.pushConf.JPush.PushURL,
map[string]string{
"Authorization": j.getAuthorization(j.pushConf.JPush.AppKey, j.pushConf.JPush.MasterSecret),
},
po,
resp,
timeout,
)
if err != nil {
return err
}
if (*resp)["sendno"] != "0" {
return fmt.Errorf("jpush push failed %v", resp)
}
return nil
}

View File

@@ -0,0 +1,55 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package offlinepush
import (
"context"
"strings"
"git.imall.cloud/openim/open-im-server-deploy/internal/push/offlinepush/dummy"
"git.imall.cloud/openim/open-im-server-deploy/internal/push/offlinepush/fcm"
"git.imall.cloud/openim/open-im-server-deploy/internal/push/offlinepush/getui"
"git.imall.cloud/openim/open-im-server-deploy/internal/push/offlinepush/jpush"
"git.imall.cloud/openim/open-im-server-deploy/internal/push/offlinepush/options"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
)
const (
geTUI = "getui"
firebase = "fcm"
jPush = "jpush"
)
// OfflinePusher Offline Pusher.
type OfflinePusher interface {
Push(ctx context.Context, userIDs []string, title, content string, opts *options.Opts) error
}
func NewOfflinePusher(pushConf *config.Push, cache cache.ThirdCache, fcmConfigPath string) (OfflinePusher, error) {
var offlinePusher OfflinePusher
pushConf.Enable = strings.ToLower(pushConf.Enable)
switch pushConf.Enable {
case geTUI:
offlinePusher = getui.NewClient(pushConf, cache)
case firebase:
return fcm.NewClient(pushConf, cache, fcmConfigPath)
case jPush:
offlinePusher = jpush.NewClient(pushConf)
default:
offlinePusher = dummy.NewClient()
}
return offlinePusher, nil
}

View File

@@ -0,0 +1,14 @@
package options
// Opts opts.
type Opts struct {
Signal *Signal
IOSPushSound string
IOSBadgeCount bool
Ex string
}
// Signal message id.
type Signal struct {
ClientMsgID string
}

View File

@@ -0,0 +1,105 @@
package push
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/internal/push/offlinepush"
"git.imall.cloud/openim/open-im-server-deploy/internal/push/offlinepush/options"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/prommetrics"
"git.imall.cloud/openim/protocol/constant"
pbpush "git.imall.cloud/openim/protocol/push"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/jsonutil"
"google.golang.org/protobuf/proto"
)
type OfflinePushConsumerHandler struct {
offlinePusher offlinepush.OfflinePusher
}
func NewOfflinePushConsumerHandler(offlinePusher offlinepush.OfflinePusher) *OfflinePushConsumerHandler {
return &OfflinePushConsumerHandler{
offlinePusher: offlinePusher,
}
}
func (o *OfflinePushConsumerHandler) HandleMsg2OfflinePush(ctx context.Context, msg []byte) {
offlinePushMsg := pbpush.PushMsgReq{}
if err := proto.Unmarshal(msg, &offlinePushMsg); err != nil {
log.ZError(ctx, "offline push Unmarshal msg err", err, "msg", string(msg))
return
}
if offlinePushMsg.MsgData == nil || offlinePushMsg.UserIDs == nil {
log.ZError(ctx, "offline push msg is empty", errs.New("offlinePushMsg is empty"), "userIDs", offlinePushMsg.UserIDs, "msg", offlinePushMsg.MsgData)
return
}
if offlinePushMsg.MsgData.Status == constant.MsgStatusSending {
offlinePushMsg.MsgData.Status = constant.MsgStatusSendSuccess
}
log.ZInfo(ctx, "receive to OfflinePush MQ", "userIDs", offlinePushMsg.UserIDs, "msg", offlinePushMsg.MsgData)
err := o.offlinePushMsg(ctx, offlinePushMsg.MsgData, offlinePushMsg.UserIDs)
if err != nil {
log.ZWarn(ctx, "offline push failed", err, "msg", offlinePushMsg.String())
}
}
func (o *OfflinePushConsumerHandler) getOfflinePushInfos(msg *sdkws.MsgData) (title, content string, opts *options.Opts, err error) {
type AtTextElem struct {
Text string `json:"text,omitempty"`
AtUserList []string `json:"atUserList,omitempty"`
IsAtSelf bool `json:"isAtSelf"`
}
opts = &options.Opts{Signal: &options.Signal{ClientMsgID: msg.ClientMsgID}}
if msg.OfflinePushInfo != nil {
opts.IOSBadgeCount = msg.OfflinePushInfo.IOSBadgeCount
opts.IOSPushSound = msg.OfflinePushInfo.IOSPushSound
opts.Ex = msg.OfflinePushInfo.Ex
}
if msg.OfflinePushInfo != nil {
title = msg.OfflinePushInfo.Title
content = msg.OfflinePushInfo.Desc
}
if title == "" {
switch msg.ContentType {
case constant.Text:
fallthrough
case constant.Picture:
fallthrough
case constant.Voice:
fallthrough
case constant.Video:
fallthrough
case constant.File:
title = constant.ContentType2PushContent[int64(msg.ContentType)]
case constant.AtText:
ac := AtTextElem{}
_ = jsonutil.JsonStringToStruct(string(msg.Content), &ac)
case constant.SignalingNotification:
title = constant.ContentType2PushContent[constant.SignalMsg]
default:
title = constant.ContentType2PushContent[constant.Common]
}
}
if content == "" {
content = title
}
return
}
func (o *OfflinePushConsumerHandler) offlinePushMsg(ctx context.Context, msg *sdkws.MsgData, offlinePushUserIDs []string) error {
title, content, opts, err := o.getOfflinePushInfos(msg)
if err != nil {
return err
}
err = o.offlinePusher.Push(ctx, offlinePushUserIDs, title, content, opts)
if err != nil {
prommetrics.MsgOfflinePushFailedCounter.Inc()
return err
}
return nil
}

View File

@@ -0,0 +1,214 @@
package push
import (
"context"
"fmt"
"sync"
"git.imall.cloud/openim/protocol/msggateway"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/datautil"
"github.com/openimsdk/tools/utils/runtimeenv"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
conf "git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
)
type OnlinePusher interface {
GetConnsAndOnlinePush(ctx context.Context, msg *sdkws.MsgData,
pushToUserIDs []string) (wsResults []*msggateway.SingleMsgToUserResults, err error)
GetOnlinePushFailedUserIDs(ctx context.Context, msg *sdkws.MsgData, wsResults []*msggateway.SingleMsgToUserResults,
pushToUserIDs *[]string) []string
}
type emptyOnlinePusher struct{}
func newEmptyOnlinePusher() *emptyOnlinePusher {
return &emptyOnlinePusher{}
}
func (emptyOnlinePusher) GetConnsAndOnlinePush(ctx context.Context, msg *sdkws.MsgData, pushToUserIDs []string) (wsResults []*msggateway.SingleMsgToUserResults, err error) {
log.ZInfo(ctx, "emptyOnlinePusher GetConnsAndOnlinePush", nil)
return nil, nil
}
func (u emptyOnlinePusher) GetOnlinePushFailedUserIDs(ctx context.Context, msg *sdkws.MsgData, wsResults []*msggateway.SingleMsgToUserResults, pushToUserIDs *[]string) []string {
log.ZInfo(ctx, "emptyOnlinePusher GetOnlinePushFailedUserIDs", nil)
return nil
}
func NewOnlinePusher(disCov discovery.Conn, config *Config) (OnlinePusher, error) {
if conf.Standalone() {
return NewDefaultAllNode(disCov, config), nil
}
if runtimeenv.RuntimeEnvironment() == conf.KUBERNETES {
return NewDefaultAllNode(disCov, config), nil
}
switch config.Discovery.Enable {
case conf.ETCD:
return NewDefaultAllNode(disCov, config), nil
default:
return nil, errs.New(fmt.Sprintf("unsupported discovery type %s", config.Discovery.Enable))
}
}
type DefaultAllNode struct {
disCov discovery.Conn
config *Config
}
func NewDefaultAllNode(disCov discovery.Conn, config *Config) *DefaultAllNode {
return &DefaultAllNode{disCov: disCov, config: config}
}
func (d *DefaultAllNode) GetConnsAndOnlinePush(ctx context.Context, msg *sdkws.MsgData,
pushToUserIDs []string) (wsResults []*msggateway.SingleMsgToUserResults, err error) {
conns, err := d.disCov.GetConns(ctx, d.config.Discovery.RpcService.MessageGateway)
if len(conns) == 0 {
log.ZWarn(ctx, "get gateway conn 0 ", nil)
} else {
log.ZDebug(ctx, "get gateway conn", "conn length", len(conns))
}
if err != nil {
return nil, err
}
var (
mu sync.Mutex
wg = errgroup.Group{}
input = &msggateway.OnlineBatchPushOneMsgReq{MsgData: msg, PushToUserIDs: pushToUserIDs}
maxWorkers = d.config.RpcConfig.MaxConcurrentWorkers
)
if maxWorkers < 3 {
maxWorkers = 3
}
wg.SetLimit(maxWorkers)
// Online push message
for _, conn := range conns {
conn := conn // loop var safe
ctx := ctx
wg.Go(func() error {
msgClient := msggateway.NewMsgGatewayClient(conn)
reply, err := msgClient.SuperGroupOnlineBatchPushOneMsg(ctx, input)
if err != nil {
log.ZError(ctx, "SuperGroupOnlineBatchPushOneMsg ", err, "req:", input.String())
return nil
}
log.ZDebug(ctx, "push result", "reply", reply)
if reply != nil && reply.SinglePushResult != nil {
mu.Lock()
wsResults = append(wsResults, reply.SinglePushResult...)
mu.Unlock()
}
return nil
})
}
_ = wg.Wait()
// always return nil
return wsResults, nil
}
func (d *DefaultAllNode) GetOnlinePushFailedUserIDs(_ context.Context, msg *sdkws.MsgData,
wsResults []*msggateway.SingleMsgToUserResults, pushToUserIDs *[]string) []string {
onlineSuccessUserIDs := []string{msg.SendID}
for _, v := range wsResults {
//message sender do not need offline push
if msg.SendID == v.UserID {
continue
}
// mobile online push success
if v.OnlinePush {
onlineSuccessUserIDs = append(onlineSuccessUserIDs, v.UserID)
}
}
return datautil.SliceSub(*pushToUserIDs, onlineSuccessUserIDs)
}
type K8sStaticConsistentHash struct {
disCov discovery.SvcDiscoveryRegistry
config *Config
}
func NewK8sStaticConsistentHash(disCov discovery.SvcDiscoveryRegistry, config *Config) *K8sStaticConsistentHash {
return &K8sStaticConsistentHash{disCov: disCov, config: config}
}
func (k *K8sStaticConsistentHash) GetConnsAndOnlinePush(ctx context.Context, msg *sdkws.MsgData,
pushToUserIDs []string) (wsResults []*msggateway.SingleMsgToUserResults, err error) {
var usersHost = make(map[string][]string)
for _, v := range pushToUserIDs {
tHost, err := k.disCov.GetUserIdHashGatewayHost(ctx, v)
if err != nil {
log.ZError(ctx, "get msg gateway hash error", err)
return nil, err
}
tUsers, tbl := usersHost[tHost]
if tbl {
tUsers = append(tUsers, v)
usersHost[tHost] = tUsers
} else {
usersHost[tHost] = []string{v}
}
}
log.ZDebug(ctx, "genUsers send hosts struct:", "usersHost", usersHost)
var usersConns = make(map[grpc.ClientConnInterface][]string)
for host, userIds := range usersHost {
tconn, _ := k.disCov.GetConn(ctx, host)
usersConns[tconn] = userIds
}
var (
mu sync.Mutex
wg = errgroup.Group{}
maxWorkers = k.config.RpcConfig.MaxConcurrentWorkers
)
if maxWorkers < 3 {
maxWorkers = 3
}
wg.SetLimit(maxWorkers)
for conn, userIds := range usersConns {
tcon := conn
tuserIds := userIds
wg.Go(func() error {
input := &msggateway.OnlineBatchPushOneMsgReq{MsgData: msg, PushToUserIDs: tuserIds}
msgClient := msggateway.NewMsgGatewayClient(tcon)
reply, err := msgClient.SuperGroupOnlineBatchPushOneMsg(ctx, input)
if err != nil {
return nil
}
log.ZDebug(ctx, "push result", "reply", reply)
if reply != nil && reply.SinglePushResult != nil {
mu.Lock()
wsResults = append(wsResults, reply.SinglePushResult...)
mu.Unlock()
}
return nil
})
}
_ = wg.Wait()
return wsResults, nil
}
func (k *K8sStaticConsistentHash) GetOnlinePushFailedUserIDs(_ context.Context, _ *sdkws.MsgData,
wsResults []*msggateway.SingleMsgToUserResults, _ *[]string) []string {
var needOfflinePushUserIDs []string
for _, v := range wsResults {
if !v.OnlinePush {
needOfflinePushUserIDs = append(needOfflinePushUserIDs, v.UserID)
}
}
return needOfflinePushUserIDs
}

163
internal/push/push.go Normal file
View File

@@ -0,0 +1,163 @@
package push
import (
"context"
"math/rand"
"strconv"
"time"
"github.com/openimsdk/tools/mq"
"git.imall.cloud/openim/open-im-server-deploy/internal/push/offlinepush"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/mcache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/redis"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/controller"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database/mgo"
"git.imall.cloud/openim/open-im-server-deploy/pkg/dbbuild"
"git.imall.cloud/openim/open-im-server-deploy/pkg/mqbuild"
pbpush "git.imall.cloud/openim/protocol/push"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mcontext"
"google.golang.org/grpc"
)
type pushServer struct {
pbpush.UnimplementedPushMsgServiceServer
database controller.PushDatabase
disCov discovery.Conn
offlinePusher offlinepush.OfflinePusher
}
type Config struct {
RpcConfig config.Push
RedisConfig config.Redis
MongoConfig config.Mongo
KafkaConfig config.Kafka
NotificationConfig config.Notification
Share config.Share
WebhooksConfig config.Webhooks
LocalCacheConfig config.LocalCache
Discovery config.Discovery
FcmConfigPath config.Path
}
func (p pushServer) DelUserPushToken(ctx context.Context,
req *pbpush.DelUserPushTokenReq) (resp *pbpush.DelUserPushTokenResp, err error) {
if err = p.database.DelFcmToken(ctx, req.UserID, int(req.PlatformID)); err != nil {
return nil, err
}
return &pbpush.DelUserPushTokenResp{}, nil
}
func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryRegistry, server grpc.ServiceRegistrar) error {
dbb := dbbuild.NewBuilder(&config.MongoConfig, &config.RedisConfig)
rdb, err := dbb.Redis(ctx)
if err != nil {
return err
}
var cacheModel cache.ThirdCache
if rdb == nil {
mdb, err := dbb.Mongo(ctx)
if err != nil {
return err
}
mc, err := mgo.NewCacheMgo(mdb.GetDB())
if err != nil {
return err
}
cacheModel = mcache.NewThirdCache(mc)
} else {
cacheModel = redis.NewThirdCache(rdb)
}
offlinePusher, err := offlinepush.NewOfflinePusher(&config.RpcConfig, cacheModel, string(config.FcmConfigPath))
if err != nil {
return err
}
builder := mqbuild.NewBuilder(&config.KafkaConfig)
offlinePushProducer, err := builder.GetTopicProducer(ctx, config.KafkaConfig.ToOfflinePushTopic)
if err != nil {
return err
}
database := controller.NewPushDatabase(cacheModel, offlinePushProducer)
pushConsumer, err := builder.GetTopicConsumer(ctx, config.KafkaConfig.ToPushTopic)
if err != nil {
return err
}
offlinePushConsumer, err := builder.GetTopicConsumer(ctx, config.KafkaConfig.ToOfflinePushTopic)
if err != nil {
return err
}
pushHandler, err := NewConsumerHandler(ctx, config, database, offlinePusher, rdb, client)
if err != nil {
return err
}
offlineHandler := NewOfflinePushConsumerHandler(offlinePusher)
pbpush.RegisterPushMsgServiceServer(server, &pushServer{
database: database,
disCov: client,
offlinePusher: offlinePusher,
})
go func() {
consumerCtx := mcontext.SetOperationID(context.Background(), "push_"+strconv.Itoa(int(rand.Uint32())))
waitDone := make(chan struct{})
go func() {
defer func() {
if r := recover(); r != nil {
log.ZError(consumerCtx, "WaitCache panic", nil, "panic", r)
}
close(waitDone)
}()
pushHandler.WaitCache()
}()
select {
case <-waitDone:
log.ZInfo(consumerCtx, "WaitCache completed successfully")
case <-time.After(30 * time.Second):
log.ZWarn(consumerCtx, "WaitCache timeout after 30s, will start subscribe anyway", nil)
}
fn := func(msg mq.Message) error {
pushHandler.HandleMs2PsChat(authverify.WithTempAdmin(msg.Context()), msg.Value())
return nil
}
log.ZInfo(consumerCtx, "begin consume messages")
for {
if err := pushConsumer.Subscribe(consumerCtx, fn); err != nil {
log.ZError(consumerCtx, "subscribe err, will retry in 5 seconds", err)
time.Sleep(5 * time.Second)
continue
}
// Subscribe returned normally (possibly due to context cancellation), retry immediately
log.ZWarn(consumerCtx, "Subscribe returned normally, will retry immediately", nil)
}
}()
go func() {
fn := func(msg mq.Message) error {
offlineHandler.HandleMsg2OfflinePush(msg.Context(), msg.Value())
return nil
}
consumerCtx := mcontext.SetOperationID(context.Background(), "push_"+strconv.Itoa(int(rand.Uint32())))
log.ZInfo(consumerCtx, "begin consume messages")
for {
if err := offlinePushConsumer.Subscribe(consumerCtx, fn); err != nil {
log.ZError(consumerCtx, "subscribe err, will retry in 5 seconds", err)
time.Sleep(5 * time.Second)
continue
}
// Subscribe returned normally (possibly due to context cancellation), retry immediately
log.ZWarn(consumerCtx, "Subscribe returned normally, will retry immediately", nil)
}
}()
return nil
}

View File

@@ -0,0 +1,615 @@
package push
import (
"context"
"encoding/json"
"time"
"git.imall.cloud/openim/open-im-server-deploy/internal/push/offlinepush"
"git.imall.cloud/openim/open-im-server-deploy/internal/push/offlinepush/options"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/prommetrics"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/controller"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/webhook"
"git.imall.cloud/openim/open-im-server-deploy/pkg/msgprocessor"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpccache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
"git.imall.cloud/openim/open-im-server-deploy/pkg/util/conversationutil"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/msggateway"
pbpush "git.imall.cloud/openim/protocol/push"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/datautil"
"github.com/openimsdk/tools/utils/jsonutil"
"github.com/openimsdk/tools/utils/timeutil"
"github.com/redis/go-redis/v9"
"google.golang.org/protobuf/proto"
)
type ConsumerHandler struct {
//pushConsumerGroup mq.Consumer
offlinePusher offlinepush.OfflinePusher
onlinePusher OnlinePusher
pushDatabase controller.PushDatabase
onlineCache rpccache.OnlineCache
groupLocalCache *rpccache.GroupLocalCache
conversationLocalCache *rpccache.ConversationLocalCache
webhookClient *webhook.Client
config *Config
userClient *rpcli.UserClient
groupClient *rpcli.GroupClient
msgClient *rpcli.MsgClient
conversationClient *rpcli.ConversationClient
}
func NewConsumerHandler(ctx context.Context, config *Config, database controller.PushDatabase, offlinePusher offlinepush.OfflinePusher, rdb redis.UniversalClient, client discovery.Conn) (*ConsumerHandler, error) {
userConn, err := client.GetConn(ctx, config.Discovery.RpcService.User)
if err != nil {
return nil, err
}
groupConn, err := client.GetConn(ctx, config.Discovery.RpcService.Group)
if err != nil {
return nil, err
}
msgConn, err := client.GetConn(ctx, config.Discovery.RpcService.Msg)
if err != nil {
return nil, err
}
conversationConn, err := client.GetConn(ctx, config.Discovery.RpcService.Conversation)
if err != nil {
return nil, err
}
onlinePusher, err := NewOnlinePusher(client, config)
if err != nil {
return nil, err
}
var consumerHandler ConsumerHandler
consumerHandler.userClient = rpcli.NewUserClient(userConn)
consumerHandler.groupClient = rpcli.NewGroupClient(groupConn)
consumerHandler.msgClient = rpcli.NewMsgClient(msgConn)
consumerHandler.conversationClient = rpcli.NewConversationClient(conversationConn)
consumerHandler.offlinePusher = offlinePusher
consumerHandler.onlinePusher = onlinePusher
consumerHandler.groupLocalCache = rpccache.NewGroupLocalCache(consumerHandler.groupClient, &config.LocalCacheConfig, rdb)
consumerHandler.conversationLocalCache = rpccache.NewConversationLocalCache(consumerHandler.conversationClient, &config.LocalCacheConfig, rdb)
consumerHandler.webhookClient = webhook.NewWebhookClient(config.WebhooksConfig.URL)
consumerHandler.config = config
consumerHandler.pushDatabase = database
consumerHandler.onlineCache, err = rpccache.NewOnlineCache(consumerHandler.userClient, consumerHandler.groupLocalCache, rdb, config.RpcConfig.FullUserCache, nil)
if err != nil {
return nil, err
}
return &consumerHandler, nil
}
func (c *ConsumerHandler) HandleMs2PsChat(ctx context.Context, msg []byte) {
msgFromMQ := pbpush.PushMsgReq{}
if err := proto.Unmarshal(msg, &msgFromMQ); err != nil {
log.ZError(ctx, "push Unmarshal msg err", err, "msg", string(msg))
return
}
sec := msgFromMQ.MsgData.SendTime / 1000
nowSec := timeutil.GetCurrentTimestampBySecond()
if nowSec-sec > 10 {
prommetrics.MsgLoneTimePushCounter.Inc()
log.ZWarn(ctx, "its been a while since the message was sent", nil, "msg", msgFromMQ.String(), "sec", sec, "nowSec", nowSec, "nowSec-sec", nowSec-sec)
}
var err error
switch msgFromMQ.MsgData.SessionType {
case constant.ReadGroupChatType:
err = c.Push2Group(ctx, msgFromMQ.MsgData.GroupID, msgFromMQ.MsgData)
default:
var pushUserIDList []string
isSenderSync := datautil.GetSwitchFromOptions(msgFromMQ.MsgData.Options, constant.IsSenderSync)
if !isSenderSync || msgFromMQ.MsgData.SendID == msgFromMQ.MsgData.RecvID {
pushUserIDList = append(pushUserIDList, msgFromMQ.MsgData.RecvID)
} else {
pushUserIDList = append(pushUserIDList, msgFromMQ.MsgData.RecvID, msgFromMQ.MsgData.SendID)
}
err = c.Push2User(ctx, pushUserIDList, msgFromMQ.MsgData)
}
if err != nil {
log.ZWarn(ctx, "push failed", err, "msg", msgFromMQ.String())
}
}
func (c *ConsumerHandler) WaitCache() {
c.onlineCache.WaitCache()
}
// Push2User Suitable for two types of conversations, one is SingleChatType and the other is NotificationChatType.
func (c *ConsumerHandler) Push2User(ctx context.Context, userIDs []string, msg *sdkws.MsgData) (err error) {
log.ZInfo(ctx, "Get msg from msg_transfer And push msg", "userIDs", userIDs, "msg", msg.String())
defer func(duration time.Time) {
t := time.Since(duration)
log.ZInfo(ctx, "Get msg from msg_transfer And push msg end", "msg", msg.String(), "time cost", t)
}(time.Now())
if err := c.webhookBeforeOnlinePush(ctx, &c.config.WebhooksConfig.BeforeOnlinePush, userIDs, msg); err != nil {
return err
}
wsResults, err := c.GetConnsAndOnlinePush(ctx, msg, userIDs)
if err != nil {
return err
}
log.ZDebug(ctx, "single and notification push result", "result", wsResults, "msg", msg, "push_to_userID", userIDs)
log.ZInfo(ctx, "single and notification push end")
if !c.shouldPushOffline(ctx, msg) {
return nil
}
log.ZInfo(ctx, "pushOffline start")
for _, v := range wsResults {
//message sender do not need offline push
if msg.SendID == v.UserID {
continue
}
//receiver online push success
if v.OnlinePush {
return nil
}
}
needOfflinePushUserID := []string{msg.RecvID}
var offlinePushUserID []string
//receiver offline push
if err = c.webhookBeforeOfflinePush(ctx, &c.config.WebhooksConfig.BeforeOfflinePush, needOfflinePushUserID, msg, &offlinePushUserID); err != nil {
return err
}
if len(offlinePushUserID) > 0 {
needOfflinePushUserID = offlinePushUserID
}
err = c.offlinePushMsg(ctx, msg, needOfflinePushUserID)
if err != nil {
log.ZDebug(ctx, "offlinePushMsg failed", err, "needOfflinePushUserID", needOfflinePushUserID, "msg", msg)
log.ZWarn(ctx, "offlinePushMsg failed", err, "needOfflinePushUserID length", len(needOfflinePushUserID), "msg", msg)
return nil
}
return nil
}
func (c *ConsumerHandler) shouldPushOffline(_ context.Context, msg *sdkws.MsgData) bool {
isOfflinePush := datautil.GetSwitchFromOptions(msg.Options, constant.IsOfflinePush)
if !isOfflinePush {
return false
}
switch msg.ContentType {
case constant.RoomParticipantsConnectedNotification:
return false
case constant.RoomParticipantsDisconnectedNotification:
return false
}
return true
}
func (c *ConsumerHandler) GetConnsAndOnlinePush(ctx context.Context, msg *sdkws.MsgData, pushToUserIDs []string) ([]*msggateway.SingleMsgToUserResults, error) {
if msg != nil && msg.Status == constant.MsgStatusSending {
msg.Status = constant.MsgStatusSendSuccess
}
onlineUserIDs, offlineUserIDs, err := c.onlineCache.GetUsersOnline(ctx, pushToUserIDs)
if err != nil {
return nil, err
}
log.ZDebug(ctx, "GetConnsAndOnlinePush online cache", "sendID", msg.SendID, "recvID", msg.RecvID, "groupID", msg.GroupID, "sessionType", msg.SessionType, "clientMsgID", msg.ClientMsgID, "serverMsgID", msg.ServerMsgID, "offlineUserIDs", offlineUserIDs, "onlineUserIDs", onlineUserIDs)
var result []*msggateway.SingleMsgToUserResults
if len(onlineUserIDs) > 0 {
var err error
result, err = c.onlinePusher.GetConnsAndOnlinePush(ctx, msg, onlineUserIDs)
if err != nil {
return nil, err
}
}
for _, userID := range offlineUserIDs {
result = append(result, &msggateway.SingleMsgToUserResults{
UserID: userID,
})
}
return result, nil
}
func (c *ConsumerHandler) Push2Group(ctx context.Context, groupID string, msg *sdkws.MsgData) (err error) {
log.ZInfo(ctx, "Get group msg from msg_transfer and push msg", "msg", msg.String(), "groupID", groupID, "contentType", msg.ContentType)
defer func(duration time.Time) {
t := time.Since(duration)
log.ZInfo(ctx, "Get group msg from msg_transfer and push msg end", "msg", msg.String(), "groupID", groupID, "time cost", t)
}(time.Now())
var pushToUserIDs []string
if err = c.webhookBeforeGroupOnlinePush(ctx, &c.config.WebhooksConfig.BeforeGroupOnlinePush, groupID, msg,
&pushToUserIDs); err != nil {
log.ZWarn(ctx, "Push2Group webhookBeforeGroupOnlinePush failed", err, "groupID", groupID)
return err
}
log.ZDebug(ctx, "Push2Group after webhook", "groupID", groupID, "pushToUserIDsFromWebhook", pushToUserIDs, "count", len(pushToUserIDs))
err = c.groupMessagesHandler(ctx, groupID, &pushToUserIDs, msg)
if err != nil {
log.ZWarn(ctx, "Push2Group groupMessagesHandler failed", err, "groupID", groupID)
return err
}
log.ZDebug(ctx, "Push2Group after groupMessagesHandler", "groupID", groupID, "pushToUserIDs", pushToUserIDs, "count", len(pushToUserIDs))
wsResults, err := c.GetConnsAndOnlinePush(ctx, msg, pushToUserIDs)
if err != nil {
log.ZWarn(ctx, "Push2Group GetConnsAndOnlinePush failed", err, "groupID", groupID)
return err
}
log.ZDebug(ctx, "Push2Group online push completed", "groupID", groupID, "wsResultsCount", len(wsResults))
log.ZDebug(ctx, "group push result", "result", wsResults, "msg", msg)
log.ZInfo(ctx, "online group push end")
if !c.shouldPushOffline(ctx, msg) {
return nil
}
needOfflinePushUserIDs := c.onlinePusher.GetOnlinePushFailedUserIDs(ctx, msg, wsResults, &pushToUserIDs)
//filter some user, like don not disturb or don't need offline push etc.
needOfflinePushUserIDs, err = c.filterGroupMessageOfflinePush(ctx, groupID, msg, needOfflinePushUserIDs)
if err != nil {
return err
}
log.ZInfo(ctx, "filterGroupMessageOfflinePush end")
// Use offline push messaging
if len(needOfflinePushUserIDs) > 0 {
c.asyncOfflinePush(ctx, needOfflinePushUserIDs, msg)
}
return nil
}
func (c *ConsumerHandler) asyncOfflinePush(ctx context.Context, needOfflinePushUserIDs []string, msg *sdkws.MsgData) {
var offlinePushUserIDs []string
err := c.webhookBeforeOfflinePush(ctx, &c.config.WebhooksConfig.BeforeOfflinePush, needOfflinePushUserIDs, msg, &offlinePushUserIDs)
if err != nil {
log.ZWarn(ctx, "webhookBeforeOfflinePush failed", err, "msg", msg)
return
}
if len(offlinePushUserIDs) > 0 {
needOfflinePushUserIDs = offlinePushUserIDs
}
if err := c.pushDatabase.MsgToOfflinePushMQ(ctx, conversationutil.GenConversationUniqueKeyForSingle(msg.SendID, msg.RecvID), needOfflinePushUserIDs, msg); err != nil {
log.ZDebug(ctx, "Msg To OfflinePush MQ error", err, "needOfflinePushUserIDs",
needOfflinePushUserIDs, "msg", msg)
log.ZWarn(ctx, "Msg To OfflinePush MQ error", err, "needOfflinePushUserIDs length",
len(needOfflinePushUserIDs), "msg", msg)
prommetrics.GroupChatMsgProcessFailedCounter.Inc()
return
}
}
func (c *ConsumerHandler) groupMessagesHandler(ctx context.Context, groupID string, pushToUserIDs *[]string, msg *sdkws.MsgData) (err error) {
if len(*pushToUserIDs) == 0 {
*pushToUserIDs, err = c.groupLocalCache.GetGroupMemberIDs(ctx, groupID)
if err != nil {
return err
}
switch msg.ContentType {
case constant.MemberQuitNotification:
var tips sdkws.MemberQuitTips
if unmarshalNotificationElem(msg.Content, &tips) != nil {
return err
}
if err = c.DeleteMemberAndSetConversationSeq(ctx, groupID, []string{tips.QuitUser.UserID}); err != nil {
log.ZError(ctx, "MemberQuitNotification DeleteMemberAndSetConversationSeq", err, "groupID", groupID, "userID", tips.QuitUser.UserID)
}
// 退出群聊通知只通知群主和管理员,不通知退出者本人
case constant.MemberKickedNotification:
var tips sdkws.MemberKickedTips
if unmarshalNotificationElem(msg.Content, &tips) != nil {
return err
}
kickedUsers := datautil.Slice(tips.KickedUserList, func(e *sdkws.GroupMemberFullInfo) string { return e.UserID })
if err = c.DeleteMemberAndSetConversationSeq(ctx, groupID, kickedUsers); err != nil {
log.ZError(ctx, "MemberKickedNotification DeleteMemberAndSetConversationSeq", err, "groupID", groupID, "userIDs", kickedUsers)
}
// 被踢出群聊通知只通知群主和管理员,不通知被踢出的用户本人
case constant.GroupDismissedNotification:
if msgprocessor.IsNotification(msgprocessor.GetConversationIDByMsg(msg)) {
var tips sdkws.GroupDismissedTips
if unmarshalNotificationElem(msg.Content, &tips) != nil {
return err
}
log.ZDebug(ctx, "GroupDismissedNotificationInfo****", "groupID", groupID, "num", len(*pushToUserIDs), "list", pushToUserIDs)
if len(c.config.Share.IMAdminUser.UserIDs) > 0 {
ctx = mcontext.WithOpUserIDContext(ctx, c.config.Share.IMAdminUser.UserIDs[0])
}
defer func(groupID string) {
if err := c.groupClient.DismissGroup(ctx, groupID, true); err != nil {
log.ZError(ctx, "DismissGroup Notification clear members", err, "groupID", groupID)
}
}(groupID)
}
}
}
// 过滤通知消息,只通知群主、管理员和相关成员本人
switch msg.ContentType {
case constant.GroupMemberMutedNotification, constant.GroupMemberCancelMutedNotification:
// 禁言和取消禁言:通知群主、管理员和本人
if err := c.filterNotificationWithUser(ctx, groupID, pushToUserIDs, msg, true); err != nil {
return err
}
case constant.MemberQuitNotification, constant.MemberEnterNotification, constant.GroupMemberInfoSetNotification:
// 退出、进入、成员信息设置:只通知群主和管理员
if err := c.filterNotificationWithUser(ctx, groupID, pushToUserIDs, msg, false); err != nil {
return err
}
case constant.MemberKickedNotification:
// 被踢出:通知群主、管理员和本人
if err := c.filterNotificationWithUser(ctx, groupID, pushToUserIDs, msg, true); err != nil {
return err
}
case constant.MemberInvitedNotification:
// 被邀请:通知群主、管理员和被邀请的本人
if err := c.filterNotificationWithUser(ctx, groupID, pushToUserIDs, msg, true); err != nil {
return err
}
case constant.GroupMemberSetToAdminNotification, constant.GroupMemberSetToOrdinaryUserNotification:
// 设置为管理员/普通用户:通知群主、管理员和本人
if err := c.filterNotificationWithUser(ctx, groupID, pushToUserIDs, msg, true); err != nil {
return err
}
}
return err
}
// filterNotificationWithUser 过滤通知消息,只通知群主、管理员,可选是否包含相关用户本人
// includeUser: true 表示包含相关用户本人false 表示排除相关用户
func (c *ConsumerHandler) filterNotificationWithUser(ctx context.Context, groupID string, pushToUserIDs *[]string, msg *sdkws.MsgData, includeUser bool) error {
notificationName := getNotificationName(msg.ContentType)
log.ZInfo(ctx, notificationName+" push filter start", "groupID", groupID, "originalPushToUserIDs", *pushToUserIDs, "originalCount", len(*pushToUserIDs), "includeUser", includeUser)
// 解析通知内容提取相关用户ID
var relatedUserIDs []string
var excludeUserIDs []string
switch msg.ContentType {
case constant.GroupMemberMutedNotification:
var tips sdkws.GroupMemberMutedTips
if err := unmarshalNotificationElem(msg.Content, &tips); err != nil {
log.ZWarn(ctx, notificationName+" unmarshalNotificationElem failed", err, "groupID", groupID)
return err
}
if includeUser {
relatedUserIDs = append(relatedUserIDs, tips.MutedUser.UserID)
}
log.ZDebug(ctx, notificationName+" parsed tips", "mutedUserID", tips.MutedUser.UserID, "opUserID", tips.OpUser.UserID)
case constant.GroupMemberCancelMutedNotification:
var tips sdkws.GroupMemberCancelMutedTips
if err := unmarshalNotificationElem(msg.Content, &tips); err != nil {
log.ZWarn(ctx, notificationName+" unmarshalNotificationElem failed", err, "groupID", groupID)
return err
}
if includeUser {
relatedUserIDs = append(relatedUserIDs, tips.MutedUser.UserID)
}
log.ZDebug(ctx, notificationName+" parsed tips", "cancelMutedUserID", tips.MutedUser.UserID, "opUserID", tips.OpUser.UserID)
case constant.MemberQuitNotification:
var tips sdkws.MemberQuitTips
if err := unmarshalNotificationElem(msg.Content, &tips); err != nil {
log.ZWarn(ctx, notificationName+" unmarshalNotificationElem failed", err, "groupID", groupID)
return err
}
excludeUserIDs = append(excludeUserIDs, tips.QuitUser.UserID)
log.ZDebug(ctx, notificationName+" parsed tips", "quitUserID", tips.QuitUser.UserID)
case constant.MemberInvitedNotification:
var tips sdkws.MemberInvitedTips
if err := unmarshalNotificationElem(msg.Content, &tips); err != nil {
log.ZWarn(ctx, notificationName+" unmarshalNotificationElem failed", err, "groupID", groupID)
return err
}
invitedUserIDs := datautil.Slice(tips.InvitedUserList, func(e *sdkws.GroupMemberFullInfo) string { return e.UserID })
if includeUser {
relatedUserIDs = append(relatedUserIDs, invitedUserIDs...)
} else {
excludeUserIDs = append(excludeUserIDs, invitedUserIDs...)
}
log.ZDebug(ctx, notificationName+" parsed tips", "invitedUserIDs", invitedUserIDs, "opUserID", tips.OpUser.UserID)
case constant.MemberEnterNotification:
var tips sdkws.MemberEnterTips
if err := unmarshalNotificationElem(msg.Content, &tips); err != nil {
log.ZWarn(ctx, notificationName+" unmarshalNotificationElem failed", err, "groupID", groupID)
return err
}
excludeUserIDs = append(excludeUserIDs, tips.EntrantUser.UserID)
log.ZDebug(ctx, notificationName+" parsed tips", "entrantUserID", tips.EntrantUser.UserID)
case constant.MemberKickedNotification:
var tips sdkws.MemberKickedTips
if err := unmarshalNotificationElem(msg.Content, &tips); err != nil {
log.ZWarn(ctx, notificationName+" unmarshalNotificationElem failed", err, "groupID", groupID)
return err
}
kickedUserIDs := datautil.Slice(tips.KickedUserList, func(e *sdkws.GroupMemberFullInfo) string { return e.UserID })
if includeUser {
relatedUserIDs = append(relatedUserIDs, kickedUserIDs...)
} else {
excludeUserIDs = append(excludeUserIDs, kickedUserIDs...)
}
log.ZDebug(ctx, notificationName+" parsed tips", "kickedUserIDs", kickedUserIDs, "opUserID", tips.OpUser.UserID)
case constant.GroupMemberInfoSetNotification, constant.GroupMemberSetToAdminNotification, constant.GroupMemberSetToOrdinaryUserNotification:
var tips sdkws.GroupMemberInfoSetTips
if err := unmarshalNotificationElem(msg.Content, &tips); err != nil {
log.ZWarn(ctx, notificationName+" unmarshalNotificationElem failed", err, "groupID", groupID)
return err
}
if includeUser {
relatedUserIDs = append(relatedUserIDs, tips.ChangedUser.UserID)
} else {
excludeUserIDs = append(excludeUserIDs, tips.ChangedUser.UserID)
}
log.ZDebug(ctx, notificationName+" parsed tips", "changedUserID", tips.ChangedUser.UserID, "opUserID", tips.OpUser.UserID)
default:
log.ZWarn(ctx, notificationName+" unsupported notification type", nil, "contentType", msg.ContentType)
return nil
}
// 获取所有群成员信息(如果还没有获取)
allMemberIDs := *pushToUserIDs
if len(allMemberIDs) == 0 {
var err error
allMemberIDs, err = c.groupLocalCache.GetGroupMemberIDs(ctx, groupID)
if err != nil {
log.ZWarn(ctx, notificationName+" GetGroupMemberIDs failed", err, "groupID", groupID)
return err
}
log.ZDebug(ctx, notificationName+" fetched all member IDs", "groupID", groupID, "memberCount", len(allMemberIDs))
}
members, err := c.groupLocalCache.GetGroupMembers(ctx, groupID, allMemberIDs)
if err != nil {
log.ZWarn(ctx, notificationName+" GetGroupMembers failed", err, "groupID", groupID)
return err
}
log.ZDebug(ctx, notificationName+" got members", "groupID", groupID, "memberCount", len(members))
// 筛选群主和管理员
var targetUserIDs []string
var ownerUserIDs []string
var adminUserIDs []string
for _, member := range members {
// 排除相关用户
if len(excludeUserIDs) > 0 && datautil.Contain(member.UserID, excludeUserIDs...) {
continue
}
if member.RoleLevel == constant.GroupOwner {
targetUserIDs = append(targetUserIDs, member.UserID)
ownerUserIDs = append(ownerUserIDs, member.UserID)
} else if member.RoleLevel == constant.GroupAdmin {
targetUserIDs = append(targetUserIDs, member.UserID)
adminUserIDs = append(adminUserIDs, member.UserID)
}
}
log.ZDebug(ctx, notificationName+" filtered owners and admins", "ownerCount", len(ownerUserIDs), "ownerIDs", ownerUserIDs, "adminCount", len(adminUserIDs), "adminIDs", adminUserIDs)
// 添加相关用户本人(如果需要)
if includeUser && len(relatedUserIDs) > 0 {
targetUserIDs = append(targetUserIDs, relatedUserIDs...)
log.ZDebug(ctx, notificationName+" added related users", "relatedUserIDs", relatedUserIDs, "targetCountBeforeDistinct", len(targetUserIDs))
}
// 去重并更新推送列表
*pushToUserIDs = datautil.Distinct(targetUserIDs)
log.ZInfo(ctx, notificationName+" push filter completed", "groupID", groupID, "finalPushToUserIDs", *pushToUserIDs, "finalCount", len(*pushToUserIDs), "filteredFrom", len(allMemberIDs))
return nil
}
// getNotificationName 根据通知类型获取通知名称
func getNotificationName(contentType int32) string {
switch contentType {
case constant.GroupMemberMutedNotification:
return "GroupMemberMutedNotification"
case constant.GroupMemberCancelMutedNotification:
return "GroupMemberCancelMutedNotification"
case constant.MemberQuitNotification:
return "MemberQuitNotification"
case constant.MemberInvitedNotification:
return "MemberInvitedNotification"
case constant.MemberEnterNotification:
return "MemberEnterNotification"
case constant.MemberKickedNotification:
return "MemberKickedNotification"
case constant.GroupMemberInfoSetNotification:
return "GroupMemberInfoSetNotification"
case constant.GroupMemberSetToAdminNotification:
return "GroupMemberSetToAdminNotification"
case constant.GroupMemberSetToOrdinaryUserNotification:
return "GroupMemberSetToOrdinaryUserNotification"
default:
return "UnknownNotification"
}
}
func (c *ConsumerHandler) offlinePushMsg(ctx context.Context, msg *sdkws.MsgData, offlinePushUserIDs []string) error {
title, content, opts, err := c.getOfflinePushInfos(msg)
if err != nil {
log.ZError(ctx, "getOfflinePushInfos failed", err, "msg", msg)
return err
}
err = c.offlinePusher.Push(ctx, offlinePushUserIDs, title, content, opts)
if err != nil {
prommetrics.MsgOfflinePushFailedCounter.Inc()
return err
}
return nil
}
func (c *ConsumerHandler) filterGroupMessageOfflinePush(ctx context.Context, groupID string, msg *sdkws.MsgData,
offlinePushUserIDs []string) (userIDs []string, err error) {
needOfflinePushUserIDs, err := c.conversationClient.GetConversationOfflinePushUserIDs(ctx, conversationutil.GenGroupConversationID(groupID), offlinePushUserIDs)
if err != nil {
return nil, err
}
return needOfflinePushUserIDs, nil
}
func (c *ConsumerHandler) getOfflinePushInfos(msg *sdkws.MsgData) (title, content string, opts *options.Opts, err error) {
type AtTextElem struct {
Text string `json:"text,omitempty"`
AtUserList []string `json:"atUserList,omitempty"`
IsAtSelf bool `json:"isAtSelf"`
}
opts = &options.Opts{Signal: &options.Signal{ClientMsgID: msg.ClientMsgID}}
if msg.OfflinePushInfo != nil {
opts.IOSBadgeCount = msg.OfflinePushInfo.IOSBadgeCount
opts.IOSPushSound = msg.OfflinePushInfo.IOSPushSound
opts.Ex = msg.OfflinePushInfo.Ex
}
if msg.OfflinePushInfo != nil {
title = msg.OfflinePushInfo.Title
content = msg.OfflinePushInfo.Desc
}
if title == "" {
switch msg.ContentType {
case constant.Text:
fallthrough
case constant.Picture:
fallthrough
case constant.Voice:
fallthrough
case constant.Video:
fallthrough
case constant.File:
title = constant.ContentType2PushContent[int64(msg.ContentType)]
case constant.AtText:
ac := AtTextElem{}
_ = jsonutil.JsonStringToStruct(string(msg.Content), &ac)
case constant.SignalingNotification:
title = constant.ContentType2PushContent[constant.SignalMsg]
default:
title = constant.ContentType2PushContent[constant.Common]
}
}
if content == "" {
content = title
}
return
}
func (c *ConsumerHandler) DeleteMemberAndSetConversationSeq(ctx context.Context, groupID string, userIDs []string) error {
conversationID := msgprocessor.GetConversationIDBySessionType(constant.ReadGroupChatType, groupID)
maxSeq, err := c.msgClient.GetConversationMaxSeq(ctx, conversationID)
if err != nil {
return err
}
return c.conversationClient.SetConversationMaxSeq(ctx, conversationID, userIDs, maxSeq)
}
func unmarshalNotificationElem(bytes []byte, t any) error {
var notification sdkws.NotificationElem
if err := json.Unmarshal(bytes, &notification); err != nil {
return err
}
return json.Unmarshal([]byte(notification.Detail), t)
}

311
internal/rpc/auth/auth.go Normal file
View File

@@ -0,0 +1,311 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package auth
import (
"context"
"errors"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/convert"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/mcache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database/mgo"
"git.imall.cloud/openim/open-im-server-deploy/pkg/dbbuild"
"git.imall.cloud/openim/open-im-server-deploy/pkg/localcache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpccache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
redis2 "git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/redis"
"github.com/openimsdk/tools/utils/datautil"
"github.com/redis/go-redis/v9"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/prommetrics"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/servererrs"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/controller"
pbauth "git.imall.cloud/openim/protocol/auth"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/msggateway"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/tokenverify"
"google.golang.org/grpc"
)
type authServer struct {
pbauth.UnimplementedAuthServer
authDatabase controller.AuthDatabase
AuthLocalCache *rpccache.AuthLocalCache
RegisterCenter discovery.Conn
config *Config
userClient *rpcli.UserClient
adminUserIDs []string
}
type Config struct {
RpcConfig config.Auth
RedisConfig config.Redis
MongoConfig config.Mongo
Share config.Share
LocalCacheConfig config.LocalCache
Discovery config.Discovery
}
func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryRegistry, server grpc.ServiceRegistrar) error {
dbb := dbbuild.NewBuilder(&config.MongoConfig, &config.RedisConfig)
rdb, err := dbb.Redis(ctx)
if err != nil {
return err
}
var token cache.TokenModel
if rdb == nil {
mdb, err := dbb.Mongo(ctx)
if err != nil {
return err
}
mc, err := mgo.NewCacheMgo(mdb.GetDB())
if err != nil {
return err
}
token = mcache.NewTokenCacheModel(mc, config.RpcConfig.TokenPolicy.Expire)
} else {
token = redis2.NewTokenCacheModel(rdb, &config.LocalCacheConfig, config.RpcConfig.TokenPolicy.Expire)
}
userConn, err := client.GetConn(ctx, config.Discovery.RpcService.User)
if err != nil {
return err
}
authConn, err := client.GetConn(ctx, config.Discovery.RpcService.Auth)
if err != nil {
return err
}
localcache.InitLocalCache(&config.LocalCacheConfig)
pbauth.RegisterAuthServer(server, &authServer{
RegisterCenter: client,
authDatabase: controller.NewAuthDatabase(
token,
config.Share.Secret,
config.RpcConfig.TokenPolicy.Expire,
config.Share.MultiLogin,
config.Share.IMAdminUser.UserIDs,
),
AuthLocalCache: rpccache.NewAuthLocalCache(rpcli.NewAuthClient(authConn), &config.LocalCacheConfig, rdb),
config: config,
userClient: rpcli.NewUserClient(userConn),
adminUserIDs: config.Share.IMAdminUser.UserIDs,
})
return nil
}
func (s *authServer) GetAdminToken(ctx context.Context, req *pbauth.GetAdminTokenReq) (*pbauth.GetAdminTokenResp, error) {
resp := pbauth.GetAdminTokenResp{}
if req.Secret != s.config.Share.Secret {
return nil, errs.ErrNoPermission.WrapMsg("secret invalid")
}
if !datautil.Contain(req.UserID, s.adminUserIDs...) {
return nil, errs.ErrArgs.WrapMsg("userID is error.", "userID", req.UserID, "adminUserID", s.adminUserIDs)
}
if err := s.userClient.CheckUser(ctx, []string{req.UserID}); err != nil {
return nil, err
}
token, err := s.authDatabase.CreateToken(ctx, req.UserID, int(constant.AdminPlatformID))
if err != nil {
return nil, err
}
prommetrics.UserLoginCounter.Inc()
resp.Token = token
resp.ExpireTimeSeconds = s.config.RpcConfig.TokenPolicy.Expire * 24 * 60 * 60
return &resp, nil
}
func (s *authServer) GetUserToken(ctx context.Context, req *pbauth.GetUserTokenReq) (*pbauth.GetUserTokenResp, error) {
if err := authverify.CheckAdmin(ctx); err != nil {
return nil, err
}
if req.PlatformID == constant.AdminPlatformID {
return nil, errs.ErrNoPermission.WrapMsg("platformID invalid. platformID must not be adminPlatformID")
}
resp := pbauth.GetUserTokenResp{}
if authverify.CheckUserIsAdmin(ctx, req.UserID) {
return nil, errs.ErrNoPermission.WrapMsg("don't get Admin token")
}
user, err := s.userClient.GetUserInfo(ctx, req.UserID)
if err != nil {
return nil, err
}
if user.AppMangerLevel >= constant.AppNotificationAdmin {
return nil, errs.ErrArgs.WrapMsg("app account can`t get token")
}
token, err := s.authDatabase.CreateToken(ctx, req.UserID, int(req.PlatformID))
if err != nil {
return nil, err
}
prommetrics.UserLoginCounter.Inc()
resp.Token = token
resp.ExpireTimeSeconds = s.config.RpcConfig.TokenPolicy.Expire * 24 * 60 * 60
return &resp, nil
}
func (s *authServer) GetExistingToken(ctx context.Context, req *pbauth.GetExistingTokenReq) (*pbauth.GetExistingTokenResp, error) {
m, err := s.authDatabase.GetTokensWithoutError(ctx, req.UserID, int(req.PlatformID))
if err != nil {
return nil, err
}
return &pbauth.GetExistingTokenResp{
TokenStates: convert.TokenMapDB2Pb(m),
}, nil
}
func (s *authServer) parseToken(ctx context.Context, tokensString string) (claims *tokenverify.Claims, err error) {
claims, err = tokenverify.GetClaimFromToken(tokensString, authverify.Secret(s.config.Share.Secret))
if err != nil {
return nil, err
}
m, err := s.AuthLocalCache.GetExistingToken(ctx, claims.UserID, claims.PlatformID)
if err != nil {
return nil, err
}
if len(m) == 0 {
isAdmin := authverify.CheckUserIsAdmin(ctx, claims.UserID)
if isAdmin {
if err = s.authDatabase.GetTemporaryTokensWithoutError(ctx, claims.UserID, claims.PlatformID, tokensString); err == nil {
return claims, nil
}
}
return nil, servererrs.ErrTokenNotExist.Wrap()
}
if v, ok := m[tokensString]; ok {
switch v {
case constant.NormalToken:
return claims, nil
case constant.KickedToken:
return nil, servererrs.ErrTokenKicked.Wrap()
default:
return nil, errs.Wrap(errs.ErrTokenUnknown)
}
} else {
isAdmin := authverify.CheckUserIsAdmin(ctx, claims.UserID)
if isAdmin {
if err = s.authDatabase.GetTemporaryTokensWithoutError(ctx, claims.UserID, claims.PlatformID, tokensString); err == nil {
return claims, nil
}
}
}
return nil, servererrs.ErrTokenNotExist.Wrap()
}
func (s *authServer) ParseToken(ctx context.Context, req *pbauth.ParseTokenReq) (resp *pbauth.ParseTokenResp, err error) {
resp = &pbauth.ParseTokenResp{}
claims, err := s.parseToken(ctx, req.Token)
if err != nil {
return nil, err
}
resp.UserID = claims.UserID
resp.PlatformID = int32(claims.PlatformID)
resp.ExpireTimeSeconds = claims.ExpiresAt.Unix()
return resp, nil
}
func (s *authServer) ForceLogout(ctx context.Context, req *pbauth.ForceLogoutReq) (*pbauth.ForceLogoutResp, error) {
if err := authverify.CheckAdmin(ctx); err != nil {
return nil, err
}
if err := s.forceKickOff(ctx, req.UserID, req.PlatformID); err != nil {
return nil, err
}
return &pbauth.ForceLogoutResp{}, nil
}
func (s *authServer) forceKickOff(ctx context.Context, userID string, platformID int32) error {
conns, err := s.RegisterCenter.GetConns(ctx, s.config.Discovery.RpcService.MessageGateway)
if err != nil {
return err
}
for _, v := range conns {
log.ZDebug(ctx, "forceKickOff", "userID", userID, "platformID", platformID)
client := msggateway.NewMsgGatewayClient(v)
kickReq := &msggateway.KickUserOfflineReq{KickUserIDList: []string{userID}, PlatformID: platformID}
_, err := client.KickUserOffline(ctx, kickReq)
if err != nil {
log.ZError(ctx, "forceKickOff", err, "kickReq", kickReq)
}
}
m, err := s.authDatabase.GetTokensWithoutError(ctx, userID, int(platformID))
if err != nil && !errors.Is(err, redis.Nil) {
return err
}
for k := range m {
m[k] = constant.KickedToken
log.ZDebug(ctx, "set token map is ", "token map", m, "userID",
userID, "token", k)
err = s.authDatabase.SetTokenMapByUidPid(ctx, userID, int(platformID), m)
if err != nil {
return err
}
}
return nil
}
func (s *authServer) InvalidateToken(ctx context.Context, req *pbauth.InvalidateTokenReq) (*pbauth.InvalidateTokenResp, error) {
m, err := s.authDatabase.GetTokensWithoutError(ctx, req.UserID, int(req.PlatformID))
if err != nil && !errors.Is(err, redis.Nil) {
return nil, err
}
if m == nil {
return nil, errs.New("token map is empty").Wrap()
}
log.ZDebug(ctx, "get token from redis", "userID", req.UserID, "platformID",
req.PlatformID, "tokenMap", m)
for k := range m {
if k != req.GetPreservedToken() {
m[k] = constant.KickedToken
}
}
log.ZDebug(ctx, "set token map is ", "token map", m, "userID",
req.UserID, "token", req.GetPreservedToken())
err = s.authDatabase.SetTokenMapByUidPid(ctx, req.UserID, int(req.PlatformID), m)
if err != nil {
return nil, err
}
return &pbauth.InvalidateTokenResp{}, nil
}
func (s *authServer) KickTokens(ctx context.Context, req *pbauth.KickTokensReq) (*pbauth.KickTokensResp, error) {
if err := s.authDatabase.BatchSetTokenMapByUidPid(ctx, req.Tokens); err != nil {
return nil, err
}
return &pbauth.KickTokensResp{}, nil
}

View File

@@ -0,0 +1,117 @@
package conversation
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/callbackstruct"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
dbModel "git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/webhook"
"github.com/openimsdk/tools/utils/datautil"
)
func (c *conversationServer) webhookBeforeCreateSingleChatConversations(ctx context.Context, before *config.BeforeConfig, req *dbModel.Conversation) error {
return webhook.WithCondition(ctx, before, func(ctx context.Context) error {
cbReq := &callbackstruct.CallbackBeforeCreateSingleChatConversationsReq{
CallbackCommand: callbackstruct.CallbackBeforeCreateSingleChatConversationsCommand,
OwnerUserID: req.OwnerUserID,
ConversationID: req.ConversationID,
ConversationType: req.ConversationType,
UserID: req.UserID,
RecvMsgOpt: req.RecvMsgOpt,
IsPinned: req.IsPinned,
IsPrivateChat: req.IsPrivateChat,
BurnDuration: req.BurnDuration,
GroupAtType: req.GroupAtType,
AttachedInfo: req.AttachedInfo,
Ex: req.Ex,
}
resp := &callbackstruct.CallbackBeforeCreateSingleChatConversationsResp{}
if err := c.webhookClient.SyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, resp, before); err != nil {
return err
}
datautil.NotNilReplace(&req.RecvMsgOpt, resp.RecvMsgOpt)
datautil.NotNilReplace(&req.IsPinned, resp.IsPinned)
datautil.NotNilReplace(&req.IsPrivateChat, resp.IsPrivateChat)
datautil.NotNilReplace(&req.BurnDuration, resp.BurnDuration)
datautil.NotNilReplace(&req.GroupAtType, resp.GroupAtType)
datautil.NotNilReplace(&req.AttachedInfo, resp.AttachedInfo)
datautil.NotNilReplace(&req.Ex, resp.Ex)
return nil
})
}
func (c *conversationServer) webhookAfterCreateSingleChatConversations(ctx context.Context, after *config.AfterConfig, req *dbModel.Conversation) error {
cbReq := &callbackstruct.CallbackAfterCreateSingleChatConversationsReq{
CallbackCommand: callbackstruct.CallbackAfterCreateSingleChatConversationsCommand,
OwnerUserID: req.OwnerUserID,
ConversationID: req.ConversationID,
ConversationType: req.ConversationType,
UserID: req.UserID,
RecvMsgOpt: req.RecvMsgOpt,
IsPinned: req.IsPinned,
IsPrivateChat: req.IsPrivateChat,
BurnDuration: req.BurnDuration,
GroupAtType: req.GroupAtType,
AttachedInfo: req.AttachedInfo,
Ex: req.Ex,
}
c.webhookClient.AsyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, &callbackstruct.CallbackAfterCreateSingleChatConversationsResp{}, after)
return nil
}
func (c *conversationServer) webhookBeforeCreateGroupChatConversations(ctx context.Context, before *config.BeforeConfig, req *dbModel.Conversation) error {
return webhook.WithCondition(ctx, before, func(ctx context.Context) error {
cbReq := &callbackstruct.CallbackBeforeCreateGroupChatConversationsReq{
CallbackCommand: callbackstruct.CallbackBeforeCreateGroupChatConversationsCommand,
ConversationID: req.ConversationID,
ConversationType: req.ConversationType,
GroupID: req.GroupID,
RecvMsgOpt: req.RecvMsgOpt,
IsPinned: req.IsPinned,
IsPrivateChat: req.IsPrivateChat,
BurnDuration: req.BurnDuration,
GroupAtType: req.GroupAtType,
AttachedInfo: req.AttachedInfo,
Ex: req.Ex,
}
resp := &callbackstruct.CallbackBeforeCreateGroupChatConversationsResp{}
if err := c.webhookClient.SyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, resp, before); err != nil {
return err
}
datautil.NotNilReplace(&req.RecvMsgOpt, resp.RecvMsgOpt)
datautil.NotNilReplace(&req.IsPinned, resp.IsPinned)
datautil.NotNilReplace(&req.IsPrivateChat, resp.IsPrivateChat)
datautil.NotNilReplace(&req.BurnDuration, resp.BurnDuration)
datautil.NotNilReplace(&req.GroupAtType, resp.GroupAtType)
datautil.NotNilReplace(&req.AttachedInfo, resp.AttachedInfo)
datautil.NotNilReplace(&req.Ex, resp.Ex)
return nil
})
}
func (c *conversationServer) webhookAfterCreateGroupChatConversations(ctx context.Context, after *config.AfterConfig, req *dbModel.Conversation) error {
cbReq := &callbackstruct.CallbackAfterCreateGroupChatConversationsReq{
CallbackCommand: callbackstruct.CallbackAfterCreateGroupChatConversationsCommand,
ConversationID: req.ConversationID,
ConversationType: req.ConversationType,
GroupID: req.GroupID,
RecvMsgOpt: req.RecvMsgOpt,
IsPinned: req.IsPinned,
IsPrivateChat: req.IsPrivateChat,
BurnDuration: req.BurnDuration,
GroupAtType: req.GroupAtType,
AttachedInfo: req.AttachedInfo,
Ex: req.Ex,
}
c.webhookClient.AsyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, &callbackstruct.CallbackAfterCreateGroupChatConversationsResp{}, after)
return nil
}

View File

@@ -0,0 +1,865 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package conversation
import (
"context"
"sort"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/open-im-server-deploy/pkg/dbbuild"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
"google.golang.org/grpc"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/convert"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/servererrs"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/redis"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/controller"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database/mgo"
dbModel "git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/webhook"
"git.imall.cloud/openim/open-im-server-deploy/pkg/localcache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/msgprocessor"
"git.imall.cloud/openim/protocol/constant"
pbconversation "git.imall.cloud/openim/protocol/conversation"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/datautil"
)
type conversationServer struct {
pbconversation.UnimplementedConversationServer
conversationDatabase controller.ConversationDatabase
conversationNotificationSender *ConversationNotificationSender
config *Config
webhookClient *webhook.Client
userClient *rpcli.UserClient
msgClient *rpcli.MsgClient
groupClient *rpcli.GroupClient
}
type Config struct {
RpcConfig config.Conversation
RedisConfig config.Redis
MongodbConfig config.Mongo
NotificationConfig config.Notification
Share config.Share
WebhooksConfig config.Webhooks
LocalCacheConfig config.LocalCache
Discovery config.Discovery
}
func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryRegistry, server grpc.ServiceRegistrar) error {
dbb := dbbuild.NewBuilder(&config.MongodbConfig, &config.RedisConfig)
mgocli, err := dbb.Mongo(ctx)
if err != nil {
return err
}
rdb, err := dbb.Redis(ctx)
if err != nil {
return err
}
conversationDB, err := mgo.NewConversationMongo(mgocli.GetDB())
if err != nil {
return err
}
userConn, err := client.GetConn(ctx, config.Discovery.RpcService.User)
if err != nil {
return err
}
groupConn, err := client.GetConn(ctx, config.Discovery.RpcService.Group)
if err != nil {
return err
}
msgConn, err := client.GetConn(ctx, config.Discovery.RpcService.Msg)
if err != nil {
return err
}
msgClient := rpcli.NewMsgClient(msgConn)
// 初始化webhook配置管理器支持从数据库读取配置
var webhookClient *webhook.Client
systemConfigDB, err := mgo.NewSystemConfigMongo(mgocli.GetDB())
if err == nil {
// 如果SystemConfig数据库初始化成功使用配置管理器
webhookConfigManager := webhook.NewConfigManager(systemConfigDB, &config.WebhooksConfig)
if err := webhookConfigManager.Start(ctx); err != nil {
log.ZWarn(ctx, "failed to start webhook config manager, using default config", err)
webhookClient = webhook.NewWebhookClient(config.WebhooksConfig.URL)
} else {
webhookClient = webhook.NewWebhookClientWithManager(webhookConfigManager)
}
} else {
// 如果SystemConfig数据库初始化失败使用默认配置
log.ZWarn(ctx, "failed to init system config db, using default webhook config", err)
webhookClient = webhook.NewWebhookClient(config.WebhooksConfig.URL)
}
cs := conversationServer{
config: config,
webhookClient: webhookClient,
userClient: rpcli.NewUserClient(userConn),
groupClient: rpcli.NewGroupClient(groupConn),
msgClient: msgClient,
}
cs.conversationNotificationSender = NewConversationNotificationSender(&config.NotificationConfig, msgClient)
cs.conversationDatabase = controller.NewConversationDatabase(
conversationDB,
redis.NewConversationRedis(rdb, &config.LocalCacheConfig, conversationDB),
mgocli.GetTx())
localcache.InitLocalCache(&config.LocalCacheConfig)
pbconversation.RegisterConversationServer(server, &cs)
return nil
}
func (c *conversationServer) GetConversation(ctx context.Context, req *pbconversation.GetConversationReq) (*pbconversation.GetConversationResp, error) {
if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil {
return nil, err
}
conversations, err := c.conversationDatabase.FindConversations(ctx, req.OwnerUserID, []string{req.ConversationID})
if err != nil {
return nil, err
}
if len(conversations) < 1 {
return nil, errs.ErrRecordNotFound.WrapMsg("conversation not found")
}
resp := &pbconversation.GetConversationResp{Conversation: &pbconversation.Conversation{}}
resp.Conversation = convert.ConversationDB2Pb(conversations[0])
return resp, nil
}
// Deprecated: Use `GetConversations` instead.
func (c *conversationServer) GetSortedConversationList(ctx context.Context, req *pbconversation.GetSortedConversationListReq) (resp *pbconversation.GetSortedConversationListResp, err error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
var conversationIDs []string
if len(req.ConversationIDs) == 0 {
conversationIDs, err = c.conversationDatabase.GetConversationIDs(ctx, req.UserID)
if err != nil {
return nil, err
}
} else {
conversationIDs = req.ConversationIDs
}
conversations, err := c.conversationDatabase.FindConversations(ctx, req.UserID, conversationIDs)
if err != nil {
return nil, err
}
if len(conversations) == 0 {
return nil, errs.ErrRecordNotFound.Wrap()
}
maxSeqs, err := c.msgClient.GetMaxSeqs(ctx, conversationIDs)
if err != nil {
return nil, err
}
chatLogs, err := c.msgClient.GetMsgByConversationIDs(ctx, conversationIDs, maxSeqs)
if err != nil {
return nil, err
}
conversationMsg, err := c.getConversationInfo(ctx, chatLogs, req.UserID)
if err != nil {
return nil, err
}
hasReadSeqs, err := c.msgClient.GetHasReadSeqs(ctx, conversationIDs, req.UserID)
if err != nil {
return nil, err
}
var unreadTotal int64
conversation_unreadCount := make(map[string]int64)
for conversationID, maxSeq := range maxSeqs {
unreadCount := maxSeq - hasReadSeqs[conversationID]
conversation_unreadCount[conversationID] = unreadCount
unreadTotal += unreadCount
}
conversation_isPinTime := make(map[int64]string)
conversation_notPinTime := make(map[int64]string)
for _, v := range conversations {
conversationID := v.ConversationID
var time int64
if _, ok := conversationMsg[conversationID]; ok {
time = conversationMsg[conversationID].MsgInfo.LatestMsgRecvTime
} else {
conversationMsg[conversationID] = &pbconversation.ConversationElem{
ConversationID: conversationID,
IsPinned: v.IsPinned,
MsgInfo: nil,
}
time = v.CreateTime.UnixMilli()
}
conversationMsg[conversationID].RecvMsgOpt = v.RecvMsgOpt
if v.IsPinned {
conversationMsg[conversationID].IsPinned = v.IsPinned
conversation_isPinTime[time] = conversationID
continue
}
conversation_notPinTime[time] = conversationID
}
resp = &pbconversation.GetSortedConversationListResp{
ConversationTotal: int64(len(chatLogs)),
ConversationElems: []*pbconversation.ConversationElem{},
UnreadTotal: unreadTotal,
}
c.conversationSort(conversation_isPinTime, resp, conversation_unreadCount, conversationMsg)
c.conversationSort(conversation_notPinTime, resp, conversation_unreadCount, conversationMsg)
resp.ConversationElems = datautil.Paginate(resp.ConversationElems, int(req.Pagination.GetPageNumber()), int(req.Pagination.GetShowNumber()))
return resp, nil
}
func (c *conversationServer) GetAllConversations(ctx context.Context, req *pbconversation.GetAllConversationsReq) (*pbconversation.GetAllConversationsResp, error) {
if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil {
return nil, err
}
conversations, err := c.conversationDatabase.GetUserAllConversation(ctx, req.OwnerUserID)
if err != nil {
return nil, err
}
resp := &pbconversation.GetAllConversationsResp{Conversations: []*pbconversation.Conversation{}}
resp.Conversations = convert.ConversationsDB2Pb(conversations)
return resp, nil
}
func (c *conversationServer) GetConversations(ctx context.Context, req *pbconversation.GetConversationsReq) (*pbconversation.GetConversationsResp, error) {
if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil {
return nil, err
}
conversations, err := c.getConversations(ctx, req.OwnerUserID, req.ConversationIDs)
if err != nil {
return nil, err
}
return &pbconversation.GetConversationsResp{
Conversations: conversations,
}, nil
}
func (c *conversationServer) getConversations(ctx context.Context, ownerUserID string, conversationIDs []string) ([]*pbconversation.Conversation, error) {
conversations, err := c.conversationDatabase.FindConversations(ctx, ownerUserID, conversationIDs)
if err != nil {
return nil, err
}
resp := &pbconversation.GetConversationsResp{Conversations: []*pbconversation.Conversation{}}
resp.Conversations = convert.ConversationsDB2Pb(conversations)
return convert.ConversationsDB2Pb(conversations), nil
}
// Deprecated
func (c *conversationServer) SetConversation(ctx context.Context, req *pbconversation.SetConversationReq) (*pbconversation.SetConversationResp, error) {
if err := authverify.CheckAccess(ctx, req.GetConversation().GetUserID()); err != nil {
return nil, err
}
var conversation dbModel.Conversation
conversation.CreateTime = time.Now()
if err := datautil.CopyStructFields(&conversation, req.Conversation); err != nil {
return nil, err
}
err := c.conversationDatabase.SetUserConversations(ctx, req.Conversation.OwnerUserID, []*dbModel.Conversation{&conversation})
if err != nil {
return nil, err
}
c.conversationNotificationSender.ConversationChangeNotification(ctx, req.Conversation.OwnerUserID, []string{req.Conversation.ConversationID})
resp := &pbconversation.SetConversationResp{}
return resp, nil
}
func (c *conversationServer) SetConversations(ctx context.Context, req *pbconversation.SetConversationsReq) (*pbconversation.SetConversationsResp, error) {
for _, userID := range req.UserIDs {
if err := authverify.CheckAccess(ctx, userID); err != nil {
return nil, err
}
}
if req.Conversation.ConversationType == constant.WriteGroupChatType {
groupInfo, err := c.groupClient.GetGroupInfo(ctx, req.Conversation.GroupID)
if err != nil {
return nil, err
}
if groupInfo == nil {
return nil, servererrs.ErrGroupIDNotFound.WrapMsg(req.Conversation.GroupID)
}
if groupInfo.Status == constant.GroupStatusDismissed {
return nil, servererrs.ErrDismissedAlready.WrapMsg("group dismissed")
}
}
conversationMap := make(map[string]*dbModel.Conversation)
var needUpdateUsersList []string
for _, userID := range req.UserIDs {
conversationList, err := c.conversationDatabase.FindConversations(ctx, userID, []string{req.Conversation.ConversationID})
if err != nil {
return nil, err
}
if len(conversationList) != 0 {
conversationMap[userID] = conversationList[0]
} else {
needUpdateUsersList = append(needUpdateUsersList, userID)
}
}
var conversation dbModel.Conversation
conversation.ConversationID = req.Conversation.ConversationID
conversation.ConversationType = req.Conversation.ConversationType
conversation.UserID = req.Conversation.UserID
conversation.GroupID = req.Conversation.GroupID
conversation.CreateTime = time.Now()
m, conversation, err := UpdateConversationsMap(ctx, req)
if err != nil {
return nil, err
}
for userID := range conversationMap {
unequal := UserUpdateCheckMap(ctx, userID, req.Conversation, conversationMap[userID])
if unequal {
needUpdateUsersList = append(needUpdateUsersList, userID)
}
}
if len(m) != 0 && len(needUpdateUsersList) != 0 {
if err := c.conversationDatabase.SetUsersConversationFieldTx(ctx, needUpdateUsersList, &conversation, m); err != nil {
return nil, err
}
for _, userID := range needUpdateUsersList {
c.conversationNotificationSender.ConversationChangeNotification(ctx, userID, []string{req.Conversation.ConversationID})
}
}
if req.Conversation.IsPrivateChat != nil && req.Conversation.ConversationType != constant.ReadGroupChatType {
var conversations []*dbModel.Conversation
for _, ownerUserID := range req.UserIDs {
transConversation := conversation
transConversation.OwnerUserID = ownerUserID
transConversation.IsPrivateChat = req.Conversation.IsPrivateChat.Value
conversations = append(conversations, &transConversation)
}
if err := c.conversationDatabase.SyncPeerUserPrivateConversationTx(ctx, conversations); err != nil {
return nil, err
}
for _, userID := range req.UserIDs {
c.conversationNotificationSender.ConversationSetPrivateNotification(ctx, userID, req.Conversation.UserID,
req.Conversation.IsPrivateChat.Value, req.Conversation.ConversationID)
}
}
return &pbconversation.SetConversationsResp{}, nil
}
func (c *conversationServer) UpdateConversationsByUser(ctx context.Context, req *pbconversation.UpdateConversationsByUserReq) (*pbconversation.UpdateConversationsByUserResp, error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
m := make(map[string]any)
if req.Ex != nil {
m["ex"] = req.Ex.Value
}
if len(m) > 0 {
if err := c.conversationDatabase.UpdateUserConversations(ctx, req.UserID, m); err != nil {
return nil, err
}
}
return &pbconversation.UpdateConversationsByUserResp{}, nil
}
// create conversation without notification for msg redis transfer.
func (c *conversationServer) CreateSingleChatConversations(ctx context.Context, req *pbconversation.CreateSingleChatConversationsReq) (*pbconversation.CreateSingleChatConversationsResp, error) {
var conversation dbModel.Conversation
conversation.CreateTime = time.Now()
switch req.ConversationType {
case constant.SingleChatType:
// sendUser create
conversation.ConversationID = req.ConversationID
conversation.ConversationType = req.ConversationType
conversation.OwnerUserID = req.SendID
conversation.UserID = req.RecvID
if err := c.webhookBeforeCreateSingleChatConversations(ctx, &c.config.WebhooksConfig.BeforeCreateSingleChatConversations, &conversation); err != nil && err != servererrs.ErrCallbackContinue {
return nil, err
}
err := c.conversationDatabase.CreateConversation(ctx, []*dbModel.Conversation{&conversation})
if err != nil {
log.ZWarn(ctx, "create conversation failed", err, "conversation", conversation)
}
c.webhookAfterCreateSingleChatConversations(ctx, &c.config.WebhooksConfig.AfterCreateSingleChatConversations, &conversation)
// recvUser create
conversation2 := conversation
conversation2.OwnerUserID = req.RecvID
conversation2.UserID = req.SendID
if err := c.webhookBeforeCreateSingleChatConversations(ctx, &c.config.WebhooksConfig.BeforeCreateSingleChatConversations, &conversation); err != nil && err != servererrs.ErrCallbackContinue {
return nil, err
}
err = c.conversationDatabase.CreateConversation(ctx, []*dbModel.Conversation{&conversation2})
if err != nil {
log.ZWarn(ctx, "create conversation failed", err, "conversation2", conversation)
}
c.webhookAfterCreateSingleChatConversations(ctx, &c.config.WebhooksConfig.AfterCreateSingleChatConversations, &conversation2)
case constant.NotificationChatType:
conversation.ConversationID = req.ConversationID
conversation.ConversationType = req.ConversationType
conversation.OwnerUserID = req.RecvID
conversation.UserID = req.SendID
if err := c.webhookBeforeCreateSingleChatConversations(ctx, &c.config.WebhooksConfig.BeforeCreateSingleChatConversations, &conversation); err != nil && err != servererrs.ErrCallbackContinue {
return nil, err
}
err := c.conversationDatabase.CreateConversation(ctx, []*dbModel.Conversation{&conversation})
if err != nil {
log.ZWarn(ctx, "create conversation failed", err, "conversation2", conversation)
}
c.webhookAfterCreateSingleChatConversations(ctx, &c.config.WebhooksConfig.AfterCreateSingleChatConversations, &conversation)
}
return &pbconversation.CreateSingleChatConversationsResp{}, nil
}
func (c *conversationServer) CreateGroupChatConversations(ctx context.Context, req *pbconversation.CreateGroupChatConversationsReq) (*pbconversation.CreateGroupChatConversationsResp, error) {
var conversation dbModel.Conversation
conversation.ConversationID = msgprocessor.GetConversationIDBySessionType(constant.ReadGroupChatType, req.GroupID)
conversation.GroupID = req.GroupID
conversation.ConversationType = constant.ReadGroupChatType
conversation.CreateTime = time.Now()
if err := c.webhookBeforeCreateGroupChatConversations(ctx, &c.config.WebhooksConfig.BeforeCreateGroupChatConversations, &conversation); err != nil {
return nil, err
}
err := c.conversationDatabase.CreateGroupChatConversation(ctx, req.GroupID, req.UserIDs, &conversation)
if err != nil {
return nil, err
}
if err := c.msgClient.SetUserConversationMaxSeq(ctx, conversation.ConversationID, req.UserIDs, 0); err != nil {
return nil, err
}
c.webhookAfterCreateGroupChatConversations(ctx, &c.config.WebhooksConfig.AfterCreateGroupChatConversations, &conversation)
return &pbconversation.CreateGroupChatConversationsResp{}, nil
}
func (c *conversationServer) SetConversationMaxSeq(ctx context.Context, req *pbconversation.SetConversationMaxSeqReq) (*pbconversation.SetConversationMaxSeqResp, error) {
if err := c.msgClient.SetUserConversationMaxSeq(ctx, req.ConversationID, req.OwnerUserID, req.MaxSeq); err != nil {
return nil, err
}
if err := c.conversationDatabase.UpdateUsersConversationField(ctx, req.OwnerUserID, req.ConversationID,
map[string]any{"max_seq": req.MaxSeq}); err != nil {
return nil, err
}
for _, userID := range req.OwnerUserID {
c.conversationNotificationSender.ConversationChangeNotification(ctx, userID, []string{req.ConversationID})
}
return &pbconversation.SetConversationMaxSeqResp{}, nil
}
func (c *conversationServer) SetConversationMinSeq(ctx context.Context, req *pbconversation.SetConversationMinSeqReq) (*pbconversation.SetConversationMinSeqResp, error) {
if err := c.msgClient.SetUserConversationMin(ctx, req.ConversationID, req.OwnerUserID, req.MinSeq); err != nil {
return nil, err
}
if err := c.conversationDatabase.UpdateUsersConversationField(ctx, req.OwnerUserID, req.ConversationID,
map[string]any{"min_seq": req.MinSeq}); err != nil {
return nil, err
}
for _, userID := range req.OwnerUserID {
c.conversationNotificationSender.ConversationChangeNotification(ctx, userID, []string{req.ConversationID})
}
return &pbconversation.SetConversationMinSeqResp{}, nil
}
func (c *conversationServer) GetConversationIDs(ctx context.Context, req *pbconversation.GetConversationIDsReq) (*pbconversation.GetConversationIDsResp, error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
conversationIDs, err := c.conversationDatabase.GetConversationIDs(ctx, req.UserID)
if err != nil {
return nil, err
}
return &pbconversation.GetConversationIDsResp{ConversationIDs: conversationIDs}, nil
}
func (c *conversationServer) GetUserConversationIDsHash(ctx context.Context, req *pbconversation.GetUserConversationIDsHashReq) (*pbconversation.GetUserConversationIDsHashResp, error) {
if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil {
return nil, err
}
hash, err := c.conversationDatabase.GetUserConversationIDsHash(ctx, req.OwnerUserID)
if err != nil {
return nil, err
}
return &pbconversation.GetUserConversationIDsHashResp{Hash: hash}, nil
}
func (c *conversationServer) GetConversationsByConversationID(ctx context.Context, req *pbconversation.GetConversationsByConversationIDReq) (*pbconversation.GetConversationsByConversationIDResp, error) {
conversations, err := c.conversationDatabase.GetConversationsByConversationID(ctx, req.ConversationIDs)
if err != nil {
return nil, err
}
return &pbconversation.GetConversationsByConversationIDResp{Conversations: convert.ConversationsDB2Pb(conversations)}, nil
}
func (c *conversationServer) GetConversationOfflinePushUserIDs(ctx context.Context, req *pbconversation.GetConversationOfflinePushUserIDsReq) (*pbconversation.GetConversationOfflinePushUserIDsResp, error) {
if req.ConversationID == "" {
return nil, errs.ErrArgs.WrapMsg("conversationID is empty")
}
if len(req.UserIDs) == 0 {
return &pbconversation.GetConversationOfflinePushUserIDsResp{}, nil
}
userIDs, err := c.conversationDatabase.GetConversationNotReceiveMessageUserIDs(ctx, req.ConversationID)
if err != nil {
return nil, err
}
if len(userIDs) == 0 {
return &pbconversation.GetConversationOfflinePushUserIDsResp{UserIDs: req.UserIDs}, nil
}
userIDSet := make(map[string]struct{})
for _, userID := range req.UserIDs {
userIDSet[userID] = struct{}{}
}
for _, userID := range userIDs {
delete(userIDSet, userID)
}
return &pbconversation.GetConversationOfflinePushUserIDsResp{UserIDs: datautil.Keys(userIDSet)}, nil
}
func (c *conversationServer) conversationSort(conversations map[int64]string, resp *pbconversation.GetSortedConversationListResp, conversation_unreadCount map[string]int64, conversationMsg map[string]*pbconversation.ConversationElem) {
keys := []int64{}
for key := range conversations {
keys = append(keys, key)
}
sort.Slice(keys, func(i, j int) bool {
return keys[i] > keys[j]
})
index := 0
cons := make([]*pbconversation.ConversationElem, len(conversations))
for _, v := range keys {
conversationID := conversations[v]
conversationElem := conversationMsg[conversationID]
conversationElem.UnreadCount = conversation_unreadCount[conversationID]
cons[index] = conversationElem
index++
}
resp.ConversationElems = append(resp.ConversationElems, cons...)
}
func (c *conversationServer) getConversationInfo(ctx context.Context, chatLogs map[string]*sdkws.MsgData, userID string) (map[string]*pbconversation.ConversationElem, error) {
var (
sendIDs []string
groupIDs []string
sendMap = make(map[string]*sdkws.UserInfo)
groupMap = make(map[string]*sdkws.GroupInfo)
conversationMsg = make(map[string]*pbconversation.ConversationElem)
)
for _, chatLog := range chatLogs {
switch chatLog.SessionType {
case constant.SingleChatType:
if chatLog.SendID == userID {
sendIDs = append(sendIDs, chatLog.RecvID)
}
sendIDs = append(sendIDs, chatLog.SendID)
case constant.WriteGroupChatType, constant.ReadGroupChatType:
groupIDs = append(groupIDs, chatLog.GroupID)
sendIDs = append(sendIDs, chatLog.SendID)
}
}
if len(sendIDs) != 0 {
sendInfos, err := c.userClient.GetUsersInfo(ctx, sendIDs)
if err != nil {
return nil, err
}
for _, sendInfo := range sendInfos {
sendMap[sendInfo.UserID] = sendInfo
}
}
if len(groupIDs) != 0 {
groupInfos, err := c.groupClient.GetGroupsInfo(ctx, groupIDs)
if err != nil {
return nil, err
}
for _, groupInfo := range groupInfos {
groupMap[groupInfo.GroupID] = groupInfo
}
}
for conversationID, chatLog := range chatLogs {
pbchatLog := &pbconversation.ConversationElem{}
msgInfo := &pbconversation.MsgInfo{}
if err := datautil.CopyStructFields(msgInfo, chatLog); err != nil {
return nil, err
}
switch chatLog.SessionType {
case constant.SingleChatType:
if chatLog.SendID == userID {
if recv, ok := sendMap[chatLog.RecvID]; ok {
msgInfo.FaceURL = recv.FaceURL
msgInfo.SenderName = recv.Nickname
}
break
}
if send, ok := sendMap[chatLog.SendID]; ok {
msgInfo.FaceURL = send.FaceURL
msgInfo.SenderName = send.Nickname
}
case constant.WriteGroupChatType, constant.ReadGroupChatType:
msgInfo.GroupID = chatLog.GroupID
if group, ok := groupMap[chatLog.GroupID]; ok {
msgInfo.GroupName = group.GroupName
msgInfo.GroupFaceURL = group.FaceURL
msgInfo.GroupMemberCount = group.MemberCount
msgInfo.GroupType = group.GroupType
}
if send, ok := sendMap[chatLog.SendID]; ok {
msgInfo.SenderName = send.Nickname
}
}
pbchatLog.ConversationID = conversationID
msgInfo.LatestMsgRecvTime = chatLog.SendTime
pbchatLog.MsgInfo = msgInfo
conversationMsg[conversationID] = pbchatLog
}
return conversationMsg, nil
}
func (c *conversationServer) GetConversationNotReceiveMessageUserIDs(ctx context.Context, req *pbconversation.GetConversationNotReceiveMessageUserIDsReq) (*pbconversation.GetConversationNotReceiveMessageUserIDsResp, error) {
userIDs, err := c.conversationDatabase.GetConversationNotReceiveMessageUserIDs(ctx, req.ConversationID)
if err != nil {
return nil, err
}
return &pbconversation.GetConversationNotReceiveMessageUserIDsResp{UserIDs: userIDs}, nil
}
func (c *conversationServer) UpdateConversation(ctx context.Context, req *pbconversation.UpdateConversationReq) (*pbconversation.UpdateConversationResp, error) {
for _, userID := range req.UserIDs {
if err := authverify.CheckAccess(ctx, userID); err != nil {
return nil, err
}
}
m := make(map[string]any)
if req.RecvMsgOpt != nil {
m["recv_msg_opt"] = req.RecvMsgOpt.Value
}
if req.AttachedInfo != nil {
m["attached_info"] = req.AttachedInfo.Value
}
if req.Ex != nil {
m["ex"] = req.Ex.Value
}
if req.IsPinned != nil {
m["is_pinned"] = req.IsPinned.Value
}
if req.GroupAtType != nil {
m["group_at_type"] = req.GroupAtType.Value
}
if req.MsgDestructTime != nil {
m["msg_destruct_time"] = req.MsgDestructTime.Value
}
if req.IsMsgDestruct != nil {
m["is_msg_destruct"] = req.IsMsgDestruct.Value
}
if req.BurnDuration != nil {
m["burn_duration"] = req.BurnDuration.Value
}
if req.IsPrivateChat != nil {
m["is_private_chat"] = req.IsPrivateChat.Value
}
if req.MinSeq != nil {
m["min_seq"] = req.MinSeq.Value
}
if req.MaxSeq != nil {
m["max_seq"] = req.MaxSeq.Value
}
if req.LatestMsgDestructTime != nil {
m["latest_msg_destruct_time"] = time.UnixMilli(req.LatestMsgDestructTime.Value)
}
if len(m) > 0 {
if err := c.conversationDatabase.UpdateUsersConversationField(ctx, req.UserIDs, req.ConversationID, m); err != nil {
return nil, err
}
}
return &pbconversation.UpdateConversationResp{}, nil
}
func (c *conversationServer) GetOwnerConversation(ctx context.Context, req *pbconversation.GetOwnerConversationReq) (*pbconversation.GetOwnerConversationResp, error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
total, conversations, err := c.conversationDatabase.GetOwnerConversation(ctx, req.UserID, req.Pagination)
if err != nil {
return nil, err
}
return &pbconversation.GetOwnerConversationResp{
Total: total,
Conversations: convert.ConversationsDB2Pb(conversations),
}, nil
}
func (c *conversationServer) GetConversationsNeedClearMsg(ctx context.Context, _ *pbconversation.GetConversationsNeedClearMsgReq) (*pbconversation.GetConversationsNeedClearMsgResp, error) {
num, err := c.conversationDatabase.GetAllConversationIDsNumber(ctx)
if err != nil {
log.ZError(ctx, "GetAllConversationIDsNumber failed", err)
return nil, err
}
const batchNum = 100
if num == 0 {
return nil, errs.New("Need Destruct Msg is nil").Wrap()
}
maxPage := (num + batchNum - 1) / batchNum
temp := make([]*dbModel.Conversation, 0, maxPage*batchNum)
for pageNumber := 0; pageNumber < int(maxPage); pageNumber++ {
pagination := &sdkws.RequestPagination{
PageNumber: int32(pageNumber),
ShowNumber: batchNum,
}
conversationIDs, err := c.conversationDatabase.PageConversationIDs(ctx, pagination)
if err != nil {
log.ZError(ctx, "PageConversationIDs failed", err, "pageNumber", pageNumber)
continue
}
// log.ZDebug(ctx, "PageConversationIDs success", "pageNumber", pageNumber, "conversationIDsNum", len(conversationIDs), "conversationIDs", conversationIDs)
if len(conversationIDs) == 0 {
continue
}
conversations, err := c.conversationDatabase.GetConversationsByConversationID(ctx, conversationIDs)
if err != nil {
log.ZError(ctx, "GetConversationsByConversationID failed", err, "conversationIDs", conversationIDs)
continue
}
for _, conversation := range conversations {
if conversation.IsMsgDestruct && conversation.MsgDestructTime != 0 && ((time.Now().UnixMilli() > (conversation.MsgDestructTime + conversation.LatestMsgDestructTime.UnixMilli() + 8*60*60)) || // 8*60*60 is UTC+8
conversation.LatestMsgDestructTime.IsZero()) {
temp = append(temp, conversation)
}
}
}
return &pbconversation.GetConversationsNeedClearMsgResp{Conversations: convert.ConversationsDB2Pb(temp)}, nil
}
func (c *conversationServer) GetNotNotifyConversationIDs(ctx context.Context, req *pbconversation.GetNotNotifyConversationIDsReq) (*pbconversation.GetNotNotifyConversationIDsResp, error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
conversationIDs, err := c.conversationDatabase.GetNotNotifyConversationIDs(ctx, req.UserID)
if err != nil {
return nil, err
}
return &pbconversation.GetNotNotifyConversationIDsResp{ConversationIDs: conversationIDs}, nil
}
func (c *conversationServer) GetPinnedConversationIDs(ctx context.Context, req *pbconversation.GetPinnedConversationIDsReq) (*pbconversation.GetPinnedConversationIDsResp, error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
conversationIDs, err := c.conversationDatabase.GetPinnedConversationIDs(ctx, req.UserID)
if err != nil {
return nil, err
}
return &pbconversation.GetPinnedConversationIDsResp{ConversationIDs: conversationIDs}, nil
}
func (c *conversationServer) DeleteConversations(ctx context.Context, req *pbconversation.DeleteConversationsReq) (*pbconversation.DeleteConversationsResp, error) {
if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil {
return nil, err
}
if err := c.conversationDatabase.DeleteUsersConversations(ctx, req.OwnerUserID, req.ConversationIDs); err != nil {
return nil, err
}
return &pbconversation.DeleteConversationsResp{}, nil
}
func (c *conversationServer) ClearUserConversationMsg(ctx context.Context, req *pbconversation.ClearUserConversationMsgReq) (*pbconversation.ClearUserConversationMsgResp, error) {
conversations, err := c.conversationDatabase.FindRandConversation(ctx, req.Timestamp, int(req.Limit))
if err != nil {
return nil, err
}
latestMsgDestructTime := time.UnixMilli(req.Timestamp)
for i, conversation := range conversations {
if conversation.IsMsgDestruct == false || conversation.MsgDestructTime == 0 {
continue
}
seq, err := c.msgClient.GetLastMessageSeqByTime(ctx, conversation.ConversationID, req.Timestamp-(conversation.MsgDestructTime*1000))
if err != nil {
return nil, err
}
if seq <= 0 {
log.ZDebug(ctx, "ClearUserConversationMsg GetLastMessageSeqByTime seq <= 0", "index", i, "conversationID", conversation.ConversationID, "ownerUserID", conversation.OwnerUserID, "msgDestructTime", conversation.MsgDestructTime, "seq", seq)
if err := c.setConversationMinSeqAndLatestMsgDestructTime(ctx, conversation.ConversationID, conversation.OwnerUserID, -1, latestMsgDestructTime); err != nil {
return nil, err
}
continue
}
seq++
if err := c.setConversationMinSeqAndLatestMsgDestructTime(ctx, conversation.ConversationID, conversation.OwnerUserID, seq, latestMsgDestructTime); err != nil {
return nil, err
}
log.ZDebug(ctx, "ClearUserConversationMsg set min seq", "index", i, "conversationID", conversation.ConversationID, "ownerUserID", conversation.OwnerUserID, "seq", seq, "msgDestructTime", conversation.MsgDestructTime)
}
return &pbconversation.ClearUserConversationMsgResp{Count: int32(len(conversations))}, nil
}
func (c *conversationServer) setConversationMinSeqAndLatestMsgDestructTime(ctx context.Context, conversationID string, ownerUserID string, minSeq int64, latestMsgDestructTime time.Time) error {
update := map[string]any{
"latest_msg_destruct_time": latestMsgDestructTime,
}
if minSeq >= 0 {
if err := c.msgClient.SetUserConversationMin(ctx, conversationID, []string{ownerUserID}, minSeq); err != nil {
return err
}
update["min_seq"] = minSeq
}
if err := c.conversationDatabase.UpdateUsersConversationField(ctx, []string{ownerUserID}, conversationID, update); err != nil {
return err
}
c.conversationNotificationSender.ConversationChangeNotification(ctx, ownerUserID, []string{conversationID})
return nil
}

View File

@@ -0,0 +1,85 @@
package conversation
import (
"context"
dbModel "git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/protocol/conversation"
)
func UpdateConversationsMap(ctx context.Context, req *conversation.SetConversationsReq) (m map[string]any, conversation dbModel.Conversation, err error) {
m = make(map[string]any)
conversation.ConversationID = req.Conversation.ConversationID
conversation.ConversationType = req.Conversation.ConversationType
conversation.UserID = req.Conversation.UserID
conversation.GroupID = req.Conversation.GroupID
if req.Conversation.RecvMsgOpt != nil {
conversation.RecvMsgOpt = req.Conversation.RecvMsgOpt.Value
m["recv_msg_opt"] = req.Conversation.RecvMsgOpt.Value
}
if req.Conversation.AttachedInfo != nil {
conversation.AttachedInfo = req.Conversation.AttachedInfo.Value
m["attached_info"] = req.Conversation.AttachedInfo.Value
}
if req.Conversation.Ex != nil {
conversation.Ex = req.Conversation.Ex.Value
m["ex"] = req.Conversation.Ex.Value
}
if req.Conversation.IsPinned != nil {
conversation.IsPinned = req.Conversation.IsPinned.Value
m["is_pinned"] = req.Conversation.IsPinned.Value
}
if req.Conversation.GroupAtType != nil {
conversation.GroupAtType = req.Conversation.GroupAtType.Value
m["group_at_type"] = req.Conversation.GroupAtType.Value
}
if req.Conversation.MsgDestructTime != nil {
conversation.MsgDestructTime = req.Conversation.MsgDestructTime.Value
m["msg_destruct_time"] = req.Conversation.MsgDestructTime.Value
}
if req.Conversation.IsMsgDestruct != nil {
conversation.IsMsgDestruct = req.Conversation.IsMsgDestruct.Value
m["is_msg_destruct"] = req.Conversation.IsMsgDestruct.Value
}
if req.Conversation.BurnDuration != nil {
conversation.BurnDuration = req.Conversation.BurnDuration.Value
m["burn_duration"] = req.Conversation.BurnDuration.Value
}
return m, conversation, nil
}
func UserUpdateCheckMap(ctx context.Context, userID string, req *conversation.ConversationReq, conversation *dbModel.Conversation) (unequal bool) {
unequal = false
if req.RecvMsgOpt != nil && conversation.RecvMsgOpt != req.RecvMsgOpt.Value {
unequal = true
}
if req.AttachedInfo != nil && conversation.AttachedInfo != req.AttachedInfo.Value {
unequal = true
}
if req.Ex != nil && conversation.Ex != req.Ex.Value {
unequal = true
}
if req.IsPinned != nil && conversation.IsPinned != req.IsPinned.Value {
unequal = true
}
if req.GroupAtType != nil && conversation.GroupAtType != req.GroupAtType.Value {
unequal = true
}
if req.MsgDestructTime != nil && conversation.MsgDestructTime != req.MsgDestructTime.Value {
unequal = true
}
if req.IsMsgDestruct != nil && conversation.IsMsgDestruct != req.IsMsgDestruct.Value {
unequal = true
}
if req.BurnDuration != nil && conversation.BurnDuration != req.BurnDuration.Value {
unequal = true
}
return unequal
}

View File

@@ -0,0 +1,75 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package conversation
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
"git.imall.cloud/openim/protocol/msg"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/notification"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/sdkws"
)
type ConversationNotificationSender struct {
*notification.NotificationSender
}
func NewConversationNotificationSender(conf *config.Notification, msgClient *rpcli.MsgClient) *ConversationNotificationSender {
return &ConversationNotificationSender{notification.NewNotificationSender(conf, notification.WithRpcClient(func(ctx context.Context, req *msg.SendMsgReq) (*msg.SendMsgResp, error) {
return msgClient.SendMsg(ctx, req)
}))}
}
// SetPrivate invote.
func (c *ConversationNotificationSender) ConversationSetPrivateNotification(ctx context.Context, sendID, recvID string,
isPrivateChat bool, conversationID string,
) {
tips := &sdkws.ConversationSetPrivateTips{
RecvID: recvID,
SendID: sendID,
IsPrivate: isPrivateChat,
ConversationID: conversationID,
}
c.Notification(ctx, sendID, recvID, constant.ConversationPrivateChatNotification, tips)
}
func (c *ConversationNotificationSender) ConversationChangeNotification(ctx context.Context, userID string, conversationIDs []string) {
tips := &sdkws.ConversationUpdateTips{
UserID: userID,
ConversationIDList: conversationIDs,
}
c.Notification(ctx, userID, userID, constant.ConversationChangeNotification, tips)
}
func (c *ConversationNotificationSender) ConversationUnreadChangeNotification(
ctx context.Context,
userID, conversationID string,
unreadCountTime, hasReadSeq int64,
) {
tips := &sdkws.ConversationHasReadTips{
UserID: userID,
ConversationID: conversationID,
HasReadSeq: hasReadSeq,
UnreadCountTime: unreadCountTime,
}
c.Notification(ctx, userID, userID, constant.ConversationUnreadNotification, tips)
}

View File

@@ -0,0 +1,63 @@
package conversation
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/internal/rpc/incrversion"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/open-im-server-deploy/pkg/util/hashutil"
"git.imall.cloud/openim/protocol/conversation"
)
func (c *conversationServer) GetFullOwnerConversationIDs(ctx context.Context, req *conversation.GetFullOwnerConversationIDsReq) (*conversation.GetFullOwnerConversationIDsResp, error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
vl, err := c.conversationDatabase.FindMaxConversationUserVersionCache(ctx, req.UserID)
if err != nil {
return nil, err
}
conversationIDs, err := c.conversationDatabase.GetConversationIDs(ctx, req.UserID)
if err != nil {
return nil, err
}
idHash := hashutil.IdHash(conversationIDs)
if req.IdHash == idHash {
conversationIDs = nil
}
return &conversation.GetFullOwnerConversationIDsResp{
Version: idHash,
VersionID: vl.ID.Hex(),
Equal: req.IdHash == idHash,
ConversationIDs: conversationIDs,
}, nil
}
func (c *conversationServer) GetIncrementalConversation(ctx context.Context, req *conversation.GetIncrementalConversationReq) (*conversation.GetIncrementalConversationResp, error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
opt := incrversion.Option[*conversation.Conversation, conversation.GetIncrementalConversationResp]{
Ctx: ctx,
VersionKey: req.UserID,
VersionID: req.VersionID,
VersionNumber: req.Version,
Version: c.conversationDatabase.FindConversationUserVersion,
CacheMaxVersion: c.conversationDatabase.FindMaxConversationUserVersionCache,
Find: func(ctx context.Context, conversationIDs []string) ([]*conversation.Conversation, error) {
return c.getConversations(ctx, req.UserID, conversationIDs)
},
Resp: func(version *model.VersionLog, delIDs []string, insertList, updateList []*conversation.Conversation, full bool) *conversation.GetIncrementalConversationResp {
return &conversation.GetIncrementalConversationResp{
VersionID: version.ID.Hex(),
Version: uint64(version.Version),
Full: full,
Delete: delIDs,
Insert: insertList,
Update: updateList,
}
},
}
return opt.Build()
}

View File

@@ -0,0 +1,46 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package group
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/convert"
pbgroup "git.imall.cloud/openim/protocol/group"
)
// GetGroupInfoCache get group info from cache.
func (g *groupServer) GetGroupInfoCache(ctx context.Context, req *pbgroup.GetGroupInfoCacheReq) (*pbgroup.GetGroupInfoCacheResp, error) {
group, err := g.db.TakeGroup(ctx, req.GroupID)
if err != nil {
return nil, err
}
return &pbgroup.GetGroupInfoCacheResp{
GroupInfo: convert.Db2PbGroupInfo(group, "", 0),
}, nil
}
func (g *groupServer) GetGroupMemberCache(ctx context.Context, req *pbgroup.GetGroupMemberCacheReq) (*pbgroup.GetGroupMemberCacheResp, error) {
if err := g.checkAdminOrInGroup(ctx, req.GroupID); err != nil {
return nil, err
}
members, err := g.db.TakeGroupMember(ctx, req.GroupID, req.GroupMemberID)
if err != nil {
return nil, err
}
return &pbgroup.GetGroupMemberCacheResp{
Member: convert.Db2PbGroupMember(members),
}, nil
}

View File

@@ -0,0 +1,476 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package group
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/apistruct"
"git.imall.cloud/openim/open-im-server-deploy/pkg/callbackstruct"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/webhook"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/group"
"git.imall.cloud/openim/protocol/wrapperspb"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/datautil"
)
// CallbackBeforeCreateGroup callback before create group.
func (g *groupServer) webhookBeforeCreateGroup(ctx context.Context, before *config.BeforeConfig, req *group.CreateGroupReq) error {
return webhook.WithCondition(ctx, before, func(ctx context.Context) error {
cbReq := &callbackstruct.CallbackBeforeCreateGroupReq{
CallbackCommand: callbackstruct.CallbackBeforeCreateGroupCommand,
OperationID: mcontext.GetOperationID(ctx),
GroupInfo: req.GroupInfo,
}
cbReq.InitMemberList = append(cbReq.InitMemberList, &apistruct.GroupAddMemberInfo{
UserID: req.OwnerUserID,
RoleLevel: constant.GroupOwner,
})
for _, userID := range req.AdminUserIDs {
cbReq.InitMemberList = append(cbReq.InitMemberList, &apistruct.GroupAddMemberInfo{
UserID: userID,
RoleLevel: constant.GroupAdmin,
})
}
for _, userID := range req.MemberUserIDs {
cbReq.InitMemberList = append(cbReq.InitMemberList, &apistruct.GroupAddMemberInfo{
UserID: userID,
RoleLevel: constant.GroupOrdinaryUsers,
})
}
resp := &callbackstruct.CallbackBeforeCreateGroupResp{}
if err := g.webhookClient.SyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, resp, before); err != nil {
return err
}
datautil.NotNilReplace(&req.GroupInfo.GroupID, resp.GroupID)
datautil.NotNilReplace(&req.GroupInfo.GroupName, resp.GroupName)
datautil.NotNilReplace(&req.GroupInfo.Notification, resp.Notification)
datautil.NotNilReplace(&req.GroupInfo.Introduction, resp.Introduction)
datautil.NotNilReplace(&req.GroupInfo.FaceURL, resp.FaceURL)
datautil.NotNilReplace(&req.GroupInfo.OwnerUserID, resp.OwnerUserID)
datautil.NotNilReplace(&req.GroupInfo.Ex, resp.Ex)
datautil.NotNilReplace(&req.GroupInfo.Status, resp.Status)
datautil.NotNilReplace(&req.GroupInfo.CreatorUserID, resp.CreatorUserID)
datautil.NotNilReplace(&req.GroupInfo.GroupType, resp.GroupType)
datautil.NotNilReplace(&req.GroupInfo.NeedVerification, resp.NeedVerification)
datautil.NotNilReplace(&req.GroupInfo.LookMemberInfo, resp.LookMemberInfo)
return nil
})
}
func (g *groupServer) webhookAfterCreateGroup(ctx context.Context, after *config.AfterConfig, req *group.CreateGroupReq) {
cbReq := &callbackstruct.CallbackAfterCreateGroupReq{
CallbackCommand: callbackstruct.CallbackAfterCreateGroupCommand,
GroupInfo: req.GroupInfo,
}
cbReq.InitMemberList = append(cbReq.InitMemberList, &apistruct.GroupAddMemberInfo{
UserID: req.OwnerUserID,
RoleLevel: constant.GroupOwner,
})
for _, userID := range req.AdminUserIDs {
cbReq.InitMemberList = append(cbReq.InitMemberList, &apistruct.GroupAddMemberInfo{
UserID: userID,
RoleLevel: constant.GroupAdmin,
})
}
for _, userID := range req.MemberUserIDs {
cbReq.InitMemberList = append(cbReq.InitMemberList, &apistruct.GroupAddMemberInfo{
UserID: userID,
RoleLevel: constant.GroupOrdinaryUsers,
})
}
g.webhookClient.AsyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, &callbackstruct.CallbackAfterCreateGroupResp{}, after)
}
func (g *groupServer) webhookBeforeMembersJoinGroup(ctx context.Context, before *config.BeforeConfig, groupMembers []*model.GroupMember, groupID string, groupEx string) error {
return webhook.WithCondition(ctx, before, func(ctx context.Context) error {
groupMembersMap := datautil.SliceToMap(groupMembers, func(e *model.GroupMember) string {
return e.UserID
})
var groupMembersCallback []*callbackstruct.CallbackGroupMember
for _, member := range groupMembers {
groupMembersCallback = append(groupMembersCallback, &callbackstruct.CallbackGroupMember{
UserID: member.UserID,
Ex: member.Ex,
})
}
cbReq := &callbackstruct.CallbackBeforeMembersJoinGroupReq{
CallbackCommand: callbackstruct.CallbackBeforeMembersJoinGroupCommand,
GroupID: groupID,
MembersList: groupMembersCallback,
GroupEx: groupEx,
}
resp := &callbackstruct.CallbackBeforeMembersJoinGroupResp{}
if err := g.webhookClient.SyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, resp, before); err != nil {
return err
}
// 记录webhook响应用于排查自动禁言问题
if len(resp.MemberCallbackList) > 0 {
log.ZInfo(ctx, "webhookBeforeMembersJoinGroup: webhook response received",
"groupID", groupID,
"memberCallbackListCount", len(resp.MemberCallbackList),
"memberCallbackList", resp.MemberCallbackList)
}
for _, memberCallbackResp := range resp.MemberCallbackList {
if _, ok := groupMembersMap[(*memberCallbackResp.UserID)]; ok {
if memberCallbackResp.MuteEndTime != nil {
muteEndTimeTimestamp := *memberCallbackResp.MuteEndTime
now := time.Now()
nowUnixMilli := now.UnixMilli()
// 检查时间戳是否合理(防止时间戳单位错误导致自动禁言)
// 如果时间戳小于10000000000002001-09-09可能是秒级时间戳需要转换为毫秒
// 如果时间戳大于当前时间+10年可能是时间戳单位错误
var muteEndTime time.Time
if muteEndTimeTimestamp < 1000000000000 {
// 可能是秒级时间戳,转换为毫秒
log.ZWarn(ctx, "webhookBeforeMembersJoinGroup: MuteEndTime appears to be in seconds, converting to milliseconds", nil,
"groupID", groupID,
"userID", *memberCallbackResp.UserID,
"originalTimestamp", muteEndTimeTimestamp)
muteEndTime = time.Unix(muteEndTimeTimestamp, 0)
} else if muteEndTimeTimestamp > nowUnixMilli+10*365*24*3600*1000 {
// 时间戳超过当前时间10年可能是单位错误忽略
log.ZWarn(ctx, "webhookBeforeMembersJoinGroup: MuteEndTime is too far in the future, ignoring", nil,
"groupID", groupID,
"userID", *memberCallbackResp.UserID,
"muteEndTimeTimestamp", muteEndTimeTimestamp,
"nowUnixMilli", nowUnixMilli)
continue
} else {
muteEndTime = time.UnixMilli(muteEndTimeTimestamp)
}
// 记录webhook返回的禁言时间用于排查自动禁言问题
log.ZInfo(ctx, "webhookBeforeMembersJoinGroup: webhook returned MuteEndTime",
"groupID", groupID,
"userID", *memberCallbackResp.UserID,
"muteEndTimeTimestamp", muteEndTimeTimestamp,
"muteEndTime", muteEndTime.Format(time.RFC3339),
"now", now.Format(time.RFC3339),
"isMuted", muteEndTime.After(now),
"mutedDurationSeconds", muteEndTime.Sub(now).Seconds())
groupMembersMap[(*memberCallbackResp.UserID)].MuteEndTime = muteEndTime
}
datautil.NotNilReplace(&groupMembersMap[(*memberCallbackResp.UserID)].FaceURL, memberCallbackResp.FaceURL)
datautil.NotNilReplace(&groupMembersMap[(*memberCallbackResp.UserID)].Ex, memberCallbackResp.Ex)
datautil.NotNilReplace(&groupMembersMap[(*memberCallbackResp.UserID)].Nickname, memberCallbackResp.Nickname)
datautil.NotNilReplace(&groupMembersMap[(*memberCallbackResp.UserID)].RoleLevel, memberCallbackResp.RoleLevel)
}
}
return nil
})
}
func (g *groupServer) webhookBeforeSetGroupMemberInfo(ctx context.Context, before *config.BeforeConfig, req *group.SetGroupMemberInfo) error {
return webhook.WithCondition(ctx, before, func(ctx context.Context) error {
cbReq := callbackstruct.CallbackBeforeSetGroupMemberInfoReq{
CallbackCommand: callbackstruct.CallbackBeforeSetGroupMemberInfoCommand,
GroupID: req.GroupID,
UserID: req.UserID,
}
if req.Nickname != nil {
cbReq.Nickname = &req.Nickname.Value
}
if req.FaceURL != nil {
cbReq.FaceURL = &req.FaceURL.Value
}
if req.RoleLevel != nil {
cbReq.RoleLevel = &req.RoleLevel.Value
}
if req.Ex != nil {
cbReq.Ex = &req.Ex.Value
}
resp := &callbackstruct.CallbackBeforeSetGroupMemberInfoResp{}
if err := g.webhookClient.SyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, resp, before); err != nil {
return err
}
if resp.FaceURL != nil {
req.FaceURL = wrapperspb.String(*resp.FaceURL)
}
if resp.Nickname != nil {
req.Nickname = wrapperspb.String(*resp.Nickname)
}
if resp.RoleLevel != nil {
req.RoleLevel = wrapperspb.Int32(*resp.RoleLevel)
}
if resp.Ex != nil {
req.Ex = wrapperspb.String(*resp.Ex)
}
return nil
})
}
func (g *groupServer) webhookAfterSetGroupMemberInfo(ctx context.Context, after *config.AfterConfig, req *group.SetGroupMemberInfo) {
cbReq := callbackstruct.CallbackAfterSetGroupMemberInfoReq{
CallbackCommand: callbackstruct.CallbackAfterSetGroupMemberInfoCommand,
GroupID: req.GroupID,
UserID: req.UserID,
}
if req.Nickname != nil {
cbReq.Nickname = &req.Nickname.Value
}
if req.FaceURL != nil {
cbReq.FaceURL = &req.FaceURL.Value
}
if req.RoleLevel != nil {
cbReq.RoleLevel = &req.RoleLevel.Value
}
if req.Ex != nil {
cbReq.Ex = &req.Ex.Value
}
g.webhookClient.AsyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, &callbackstruct.CallbackAfterSetGroupMemberInfoResp{}, after)
}
func (g *groupServer) webhookAfterQuitGroup(ctx context.Context, after *config.AfterConfig, req *group.QuitGroupReq) {
cbReq := &callbackstruct.CallbackQuitGroupReq{
CallbackCommand: callbackstruct.CallbackAfterQuitGroupCommand,
GroupID: req.GroupID,
UserID: req.UserID,
}
g.webhookClient.AsyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, &callbackstruct.CallbackQuitGroupResp{}, after)
}
func (g *groupServer) webhookAfterKickGroupMember(ctx context.Context, after *config.AfterConfig, req *group.KickGroupMemberReq) {
cbReq := &callbackstruct.CallbackKillGroupMemberReq{
CallbackCommand: callbackstruct.CallbackAfterKickGroupCommand,
GroupID: req.GroupID,
KickedUserIDs: req.KickedUserIDs,
Reason: req.Reason,
}
g.webhookClient.AsyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, &callbackstruct.CallbackKillGroupMemberResp{}, after)
}
func (g *groupServer) webhookAfterDismissGroup(ctx context.Context, after *config.AfterConfig, req *callbackstruct.CallbackDisMissGroupReq) {
req.CallbackCommand = callbackstruct.CallbackAfterDisMissGroupCommand
g.webhookClient.AsyncPost(ctx, req.GetCallbackCommand(), req, &callbackstruct.CallbackDisMissGroupResp{}, after)
}
func (g *groupServer) webhookBeforeApplyJoinGroup(ctx context.Context, before *config.BeforeConfig, req *callbackstruct.CallbackJoinGroupReq) (err error) {
return webhook.WithCondition(ctx, before, func(ctx context.Context) error {
req.CallbackCommand = callbackstruct.CallbackBeforeJoinGroupCommand
resp := &callbackstruct.CallbackJoinGroupResp{}
if err := g.webhookClient.SyncPost(ctx, req.GetCallbackCommand(), req, resp, before); err != nil {
return err
}
return nil
})
}
func (g *groupServer) webhookAfterTransferGroupOwner(ctx context.Context, after *config.AfterConfig, req *group.TransferGroupOwnerReq) {
cbReq := &callbackstruct.CallbackTransferGroupOwnerReq{
CallbackCommand: callbackstruct.CallbackAfterTransferGroupOwnerCommand,
GroupID: req.GroupID,
OldOwnerUserID: req.OldOwnerUserID,
NewOwnerUserID: req.NewOwnerUserID,
}
g.webhookClient.AsyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, &callbackstruct.CallbackTransferGroupOwnerResp{}, after)
}
func (g *groupServer) webhookBeforeInviteUserToGroup(ctx context.Context, before *config.BeforeConfig, req *group.InviteUserToGroupReq) (err error) {
return webhook.WithCondition(ctx, before, func(ctx context.Context) error {
cbReq := &callbackstruct.CallbackBeforeInviteUserToGroupReq{
CallbackCommand: callbackstruct.CallbackBeforeInviteJoinGroupCommand,
OperationID: mcontext.GetOperationID(ctx),
GroupID: req.GroupID,
Reason: req.Reason,
InvitedUserIDs: req.InvitedUserIDs,
}
resp := &callbackstruct.CallbackBeforeInviteUserToGroupResp{}
if err := g.webhookClient.SyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, resp, before); err != nil {
return err
}
// Handle the scenario where certain members are refused
// You might want to update the req.Members list or handle it as per your business logic
// if len(resp.RefusedMembersAccount) > 0 {
// implement members are refused
// }
return nil
})
}
func (g *groupServer) webhookAfterJoinGroup(ctx context.Context, after *config.AfterConfig, req *group.JoinGroupReq) {
cbReq := &callbackstruct.CallbackAfterJoinGroupReq{
CallbackCommand: callbackstruct.CallbackAfterJoinGroupCommand,
OperationID: mcontext.GetOperationID(ctx),
GroupID: req.GroupID,
ReqMessage: req.ReqMessage,
JoinSource: req.JoinSource,
InviterUserID: req.InviterUserID,
}
g.webhookClient.AsyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, &callbackstruct.CallbackAfterJoinGroupResp{}, after)
}
func (g *groupServer) webhookBeforeSetGroupInfo(ctx context.Context, before *config.BeforeConfig, req *group.SetGroupInfoReq) error {
return webhook.WithCondition(ctx, before, func(ctx context.Context) error {
cbReq := &callbackstruct.CallbackBeforeSetGroupInfoReq{
CallbackCommand: callbackstruct.CallbackBeforeSetGroupInfoCommand,
GroupID: req.GroupInfoForSet.GroupID,
Notification: req.GroupInfoForSet.Notification,
Introduction: req.GroupInfoForSet.Introduction,
FaceURL: req.GroupInfoForSet.FaceURL,
GroupName: req.GroupInfoForSet.GroupName,
}
if req.GroupInfoForSet.Ex != nil {
cbReq.Ex = req.GroupInfoForSet.Ex.Value
}
log.ZDebug(ctx, "debug CallbackBeforeSetGroupInfo", "ex", cbReq.Ex)
if req.GroupInfoForSet.NeedVerification != nil {
cbReq.NeedVerification = req.GroupInfoForSet.NeedVerification.Value
}
if req.GroupInfoForSet.LookMemberInfo != nil {
cbReq.LookMemberInfo = req.GroupInfoForSet.LookMemberInfo.Value
}
if req.GroupInfoForSet.ApplyMemberFriend != nil {
cbReq.ApplyMemberFriend = req.GroupInfoForSet.ApplyMemberFriend.Value
}
resp := &callbackstruct.CallbackBeforeSetGroupInfoResp{}
if err := g.webhookClient.SyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, resp, before); err != nil {
return err
}
if resp.Ex != nil {
req.GroupInfoForSet.Ex = wrapperspb.String(*resp.Ex)
}
if resp.NeedVerification != nil {
req.GroupInfoForSet.NeedVerification = wrapperspb.Int32(*resp.NeedVerification)
}
if resp.LookMemberInfo != nil {
req.GroupInfoForSet.LookMemberInfo = wrapperspb.Int32(*resp.LookMemberInfo)
}
if resp.ApplyMemberFriend != nil {
req.GroupInfoForSet.ApplyMemberFriend = wrapperspb.Int32(*resp.ApplyMemberFriend)
}
datautil.NotNilReplace(&req.GroupInfoForSet.GroupID, &resp.GroupID)
datautil.NotNilReplace(&req.GroupInfoForSet.GroupName, &resp.GroupName)
datautil.NotNilReplace(&req.GroupInfoForSet.FaceURL, &resp.FaceURL)
datautil.NotNilReplace(&req.GroupInfoForSet.Introduction, &resp.Introduction)
return nil
})
}
func (g *groupServer) webhookAfterSetGroupInfo(ctx context.Context, after *config.AfterConfig, req *group.SetGroupInfoReq) {
cbReq := &callbackstruct.CallbackAfterSetGroupInfoReq{
CallbackCommand: callbackstruct.CallbackAfterSetGroupInfoCommand,
GroupID: req.GroupInfoForSet.GroupID,
Notification: req.GroupInfoForSet.Notification,
Introduction: req.GroupInfoForSet.Introduction,
FaceURL: req.GroupInfoForSet.FaceURL,
GroupName: req.GroupInfoForSet.GroupName,
}
if req.GroupInfoForSet.Ex != nil {
cbReq.Ex = &req.GroupInfoForSet.Ex.Value
}
if req.GroupInfoForSet.NeedVerification != nil {
cbReq.NeedVerification = &req.GroupInfoForSet.NeedVerification.Value
}
if req.GroupInfoForSet.LookMemberInfo != nil {
cbReq.LookMemberInfo = &req.GroupInfoForSet.LookMemberInfo.Value
}
if req.GroupInfoForSet.ApplyMemberFriend != nil {
cbReq.ApplyMemberFriend = &req.GroupInfoForSet.ApplyMemberFriend.Value
}
g.webhookClient.AsyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, &callbackstruct.CallbackAfterSetGroupInfoResp{}, after)
}
func (g *groupServer) webhookBeforeSetGroupInfoEx(ctx context.Context, before *config.BeforeConfig, req *group.SetGroupInfoExReq) error {
return webhook.WithCondition(ctx, before, func(ctx context.Context) error {
cbReq := &callbackstruct.CallbackBeforeSetGroupInfoExReq{
CallbackCommand: callbackstruct.CallbackBeforeSetGroupInfoExCommand,
GroupID: req.GroupID,
GroupName: req.GroupName,
Notification: req.Notification,
Introduction: req.Introduction,
FaceURL: req.FaceURL,
}
if req.Ex != nil {
cbReq.Ex = req.Ex
}
log.ZDebug(ctx, "debug CallbackBeforeSetGroupInfoEx", "ex", cbReq.Ex)
if req.NeedVerification != nil {
cbReq.NeedVerification = req.NeedVerification
}
if req.LookMemberInfo != nil {
cbReq.LookMemberInfo = req.LookMemberInfo
}
if req.ApplyMemberFriend != nil {
cbReq.ApplyMemberFriend = req.ApplyMemberFriend
}
resp := &callbackstruct.CallbackBeforeSetGroupInfoExResp{}
if err := g.webhookClient.SyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, resp, before); err != nil {
return err
}
datautil.NotNilReplace(&req.GroupID, &resp.GroupID)
datautil.NotNilReplace(&req.GroupName, &resp.GroupName)
datautil.NotNilReplace(&req.FaceURL, &resp.FaceURL)
datautil.NotNilReplace(&req.Introduction, &resp.Introduction)
datautil.NotNilReplace(&req.Ex, &resp.Ex)
datautil.NotNilReplace(&req.NeedVerification, &resp.NeedVerification)
datautil.NotNilReplace(&req.LookMemberInfo, &resp.LookMemberInfo)
datautil.NotNilReplace(&req.ApplyMemberFriend, &resp.ApplyMemberFriend)
return nil
})
}
func (g *groupServer) webhookAfterSetGroupInfoEx(ctx context.Context, after *config.AfterConfig, req *group.SetGroupInfoExReq) {
cbReq := &callbackstruct.CallbackAfterSetGroupInfoExReq{
CallbackCommand: callbackstruct.CallbackAfterSetGroupInfoExCommand,
GroupID: req.GroupID,
GroupName: req.GroupName,
Notification: req.Notification,
Introduction: req.Introduction,
FaceURL: req.FaceURL,
}
if req.Ex != nil {
cbReq.Ex = req.Ex
}
if req.NeedVerification != nil {
cbReq.NeedVerification = req.NeedVerification
}
if req.LookMemberInfo != nil {
cbReq.LookMemberInfo = req.LookMemberInfo
}
if req.ApplyMemberFriend != nil {
cbReq.ApplyMemberFriend = req.ApplyMemberFriend
}
g.webhookClient.AsyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, &callbackstruct.CallbackAfterSetGroupInfoExResp{}, after)
}

View File

@@ -0,0 +1,63 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package group
import (
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/protocol/sdkws"
)
func (g *groupServer) groupDB2PB(group *model.Group, ownerUserID string, memberCount uint32) *sdkws.GroupInfo {
return &sdkws.GroupInfo{
GroupID: group.GroupID,
GroupName: group.GroupName,
Notification: group.Notification,
Introduction: group.Introduction,
FaceURL: group.FaceURL,
OwnerUserID: ownerUserID,
CreateTime: group.CreateTime.UnixMilli(),
MemberCount: memberCount,
Ex: group.Ex,
Status: group.Status,
CreatorUserID: group.CreatorUserID,
GroupType: group.GroupType,
NeedVerification: group.NeedVerification,
LookMemberInfo: group.LookMemberInfo,
ApplyMemberFriend: group.ApplyMemberFriend,
NotificationUpdateTime: group.NotificationUpdateTime.UnixMilli(),
NotificationUserID: group.NotificationUserID,
}
}
func (g *groupServer) groupMemberDB2PB(member *model.GroupMember, appMangerLevel int32) *sdkws.GroupMemberFullInfo {
return &sdkws.GroupMemberFullInfo{
GroupID: member.GroupID,
UserID: member.UserID,
RoleLevel: member.RoleLevel,
JoinTime: member.JoinTime.UnixMilli(),
Nickname: member.Nickname,
FaceURL: member.FaceURL,
AppMangerLevel: appMangerLevel,
JoinSource: member.JoinSource,
OperatorUserID: member.OperatorUserID,
Ex: member.Ex,
MuteEndTime: member.MuteEndTime.UnixMilli(),
InviterUserID: member.InviterUserID,
}
}
func (g *groupServer) groupMemberDB2PB2(member *model.GroupMember) *sdkws.GroupMemberFullInfo {
return g.groupMemberDB2PB(member, 0)
}

View File

@@ -0,0 +1,134 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package group
import (
"context"
"strings"
"time"
pbgroup "git.imall.cloud/openim/protocol/group"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/mcontext"
)
func UpdateGroupInfoMap(ctx context.Context, group *sdkws.GroupInfoForSet) map[string]any {
m := make(map[string]any)
if group.GroupName != "" {
m["group_name"] = group.GroupName
}
if group.Notification != "" {
m["notification"] = group.Notification
m["notification_update_time"] = time.Now()
m["notification_user_id"] = mcontext.GetOpUserID(ctx)
}
if group.Introduction != "" {
m["introduction"] = group.Introduction
}
if group.FaceURL != "" {
m["face_url"] = group.FaceURL
}
if group.NeedVerification != nil {
m["need_verification"] = group.NeedVerification.Value
}
if group.LookMemberInfo != nil {
m["look_member_info"] = group.LookMemberInfo.Value
}
if group.ApplyMemberFriend != nil {
m["apply_member_friend"] = group.ApplyMemberFriend.Value
}
if group.Ex != nil {
m["ex"] = group.Ex.Value
}
return m
}
func UpdateGroupInfoExMap(ctx context.Context, group *pbgroup.SetGroupInfoExReq) (m map[string]any, normalFlag, groupNameFlag, notificationFlag bool, err error) {
m = make(map[string]any)
if group.GroupName != nil {
if strings.TrimSpace(group.GroupName.Value) != "" {
m["group_name"] = group.GroupName.Value
groupNameFlag = true
} else {
return nil, normalFlag, notificationFlag, groupNameFlag, errs.ErrArgs.WrapMsg("group name is empty")
}
}
if group.Notification != nil {
notificationFlag = true
group.Notification.Value = strings.TrimSpace(group.Notification.Value) // if Notification only contains spaces, set it to empty string
m["notification"] = group.Notification.Value
m["notification_user_id"] = mcontext.GetOpUserID(ctx)
m["notification_update_time"] = time.Now()
}
if group.Introduction != nil {
m["introduction"] = group.Introduction.Value
normalFlag = true
}
if group.FaceURL != nil {
m["face_url"] = group.FaceURL.Value
normalFlag = true
}
if group.NeedVerification != nil {
m["need_verification"] = group.NeedVerification.Value
normalFlag = true
}
if group.LookMemberInfo != nil {
m["look_member_info"] = group.LookMemberInfo.Value
normalFlag = true
}
if group.ApplyMemberFriend != nil {
m["apply_member_friend"] = group.ApplyMemberFriend.Value
normalFlag = true
}
if group.Ex != nil {
m["ex"] = group.Ex.Value
normalFlag = true
}
return m, normalFlag, groupNameFlag, notificationFlag, nil
}
func UpdateGroupStatusMap(status int) map[string]any {
return map[string]any{
"status": status,
}
}
func UpdateGroupMemberMutedTimeMap(t time.Time) map[string]any {
return map[string]any{
"mute_end_time": t,
}
}
func UpdateGroupMemberMap(req *pbgroup.SetGroupMemberInfo) map[string]any {
m := make(map[string]any)
if req.Nickname != nil {
m["nickname"] = req.Nickname.Value
}
if req.FaceURL != nil {
m["face_url"] = req.FaceURL.Value
}
if req.RoleLevel != nil {
m["role_level"] = req.RoleLevel.Value
}
if req.Ex != nil {
m["ex"] = req.Ex.Value
}
return m
}

View File

@@ -0,0 +1,25 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package group
import (
"context"
relationtb "git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
)
func (g *groupServer) PopulateGroupMember(ctx context.Context, members ...*relationtb.GroupMember) error {
return g.notification.PopulateGroupMember(ctx, members...)
}

2096
internal/rpc/group/group.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,930 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package group
import (
"context"
"errors"
"fmt"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
"github.com/google/uuid"
"go.mongodb.org/mongo-driver/mongo"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/convert"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/servererrs"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/controller"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/versionctx"
"git.imall.cloud/openim/open-im-server-deploy/pkg/msgprocessor"
"git.imall.cloud/openim/open-im-server-deploy/pkg/notification"
"git.imall.cloud/openim/open-im-server-deploy/pkg/notification/common_user"
"git.imall.cloud/openim/protocol/constant"
pbgroup "git.imall.cloud/openim/protocol/group"
"git.imall.cloud/openim/protocol/msg"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/datautil"
"github.com/openimsdk/tools/utils/jsonutil"
"github.com/openimsdk/tools/utils/stringutil"
)
// GroupApplicationReceiver
const (
applicantReceiver = iota
adminReceiver
)
func NewNotificationSender(db controller.GroupDatabase, config *Config, userClient *rpcli.UserClient, msgClient *rpcli.MsgClient, conversationClient *rpcli.ConversationClient) *NotificationSender {
return &NotificationSender{
NotificationSender: notification.NewNotificationSender(&config.NotificationConfig,
notification.WithRpcClient(func(ctx context.Context, req *msg.SendMsgReq) (*msg.SendMsgResp, error) {
return msgClient.SendMsg(ctx, req)
}),
notification.WithUserRpcClient(userClient.GetUserInfo),
),
getUsersInfo: func(ctx context.Context, userIDs []string) ([]common_user.CommonUser, error) {
users, err := userClient.GetUsersInfo(ctx, userIDs)
if err != nil {
return nil, err
}
return datautil.Slice(users, func(e *sdkws.UserInfo) common_user.CommonUser { return e }), nil
},
db: db,
config: config,
msgClient: msgClient,
conversationClient: conversationClient,
}
}
type NotificationSender struct {
*notification.NotificationSender
getUsersInfo func(ctx context.Context, userIDs []string) ([]common_user.CommonUser, error)
db controller.GroupDatabase
config *Config
msgClient *rpcli.MsgClient
conversationClient *rpcli.ConversationClient
}
func (g *NotificationSender) PopulateGroupMember(ctx context.Context, members ...*model.GroupMember) error {
if len(members) == 0 {
return nil
}
// 收集所有需要填充用户信息的UserID
userIDsMap := make(map[string]struct{})
for _, member := range members {
userIDsMap[member.UserID] = struct{}{}
}
// 获取所有用户信息
users, err := g.getUsersInfo(ctx, datautil.Keys(userIDsMap))
if err != nil {
return err
}
// 构建用户信息map
userMap := make(map[string]common_user.CommonUser)
for i, user := range users {
userMap[user.GetUserID()] = users[i]
}
// 填充群成员信息
for i, member := range members {
user, ok := userMap[member.UserID]
if !ok {
continue
}
// 填充昵称和头像
if member.Nickname == "" {
members[i].Nickname = user.GetNickname()
}
if member.FaceURL == "" {
members[i].FaceURL = user.GetFaceURL()
}
// 填充UserType和UserFlag到Ex字段
// 先解析现有的Ex字段如果有的话
var exData map[string]interface{}
if members[i].Ex != "" {
_ = jsonutil.JsonUnmarshal([]byte(members[i].Ex), &exData)
}
if exData == nil {
exData = make(map[string]interface{})
}
// 添加userType和userFlag
exData["userType"] = user.GetUserType()
exData["userFlag"] = user.GetUserFlag()
exJSON, _ := jsonutil.JsonMarshal(exData)
members[i].Ex = string(exJSON)
}
return nil
}
func (g *NotificationSender) getUser(ctx context.Context, userID string) (*sdkws.PublicUserInfo, error) {
users, err := g.getUsersInfo(ctx, []string{userID})
if err != nil {
return nil, err
}
if len(users) == 0 {
return nil, servererrs.ErrUserIDNotFound.WrapMsg(fmt.Sprintf("user %s not found", userID))
}
return &sdkws.PublicUserInfo{
UserID: users[0].GetUserID(),
Nickname: users[0].GetNickname(),
FaceURL: users[0].GetFaceURL(),
Ex: users[0].GetEx(),
UserType: users[0].GetUserType(),
}, nil
}
func (g *NotificationSender) getGroupInfo(ctx context.Context, groupID string) (*sdkws.GroupInfo, error) {
gm, err := g.db.TakeGroup(ctx, groupID)
if err != nil {
return nil, err
}
num, err := g.db.FindGroupMemberNum(ctx, groupID)
if err != nil {
return nil, err
}
ownerUserIDs, err := g.db.GetGroupRoleLevelMemberIDs(ctx, groupID, constant.GroupOwner)
if err != nil {
return nil, err
}
var ownerUserID string
if len(ownerUserIDs) > 0 {
ownerUserID = ownerUserIDs[0]
}
return convert.Db2PbGroupInfo(gm, ownerUserID, num), nil
}
func (g *NotificationSender) getGroupMembers(ctx context.Context, groupID string, userIDs []string) ([]*sdkws.GroupMemberFullInfo, error) {
members, err := g.db.FindGroupMembers(ctx, groupID, userIDs)
if err != nil {
return nil, err
}
if err := g.PopulateGroupMember(ctx, members...); err != nil {
return nil, err
}
log.ZDebug(ctx, "getGroupMembers", "members", members)
res := make([]*sdkws.GroupMemberFullInfo, 0, len(members))
for _, member := range members {
res = append(res, g.groupMemberDB2PB(member, 0))
}
return res, nil
}
func (g *NotificationSender) getGroupMemberMap(ctx context.Context, groupID string, userIDs []string) (map[string]*sdkws.GroupMemberFullInfo, error) {
members, err := g.getGroupMembers(ctx, groupID, userIDs)
if err != nil {
return nil, err
}
m := make(map[string]*sdkws.GroupMemberFullInfo)
for i, member := range members {
m[member.UserID] = members[i]
}
return m, nil
}
func (g *NotificationSender) getGroupMember(ctx context.Context, groupID string, userID string) (*sdkws.GroupMemberFullInfo, error) {
members, err := g.getGroupMembers(ctx, groupID, []string{userID})
if err != nil {
return nil, err
}
if len(members) == 0 {
return nil, errs.ErrInternalServer.WrapMsg(fmt.Sprintf("group %s member %s not found", groupID, userID))
}
return members[0], nil
}
func (g *NotificationSender) getGroupOwnerAndAdminUserID(ctx context.Context, groupID string) ([]string, error) {
members, err := g.db.FindGroupMemberRoleLevels(ctx, groupID, []int32{constant.GroupOwner, constant.GroupAdmin})
if err != nil {
return nil, err
}
if err := g.PopulateGroupMember(ctx, members...); err != nil {
return nil, err
}
fn := func(e *model.GroupMember) string { return e.UserID }
return datautil.Slice(members, fn), nil
}
func (g *NotificationSender) groupMemberDB2PB(member *model.GroupMember, appMangerLevel int32) *sdkws.GroupMemberFullInfo {
return &sdkws.GroupMemberFullInfo{
GroupID: member.GroupID,
UserID: member.UserID,
RoleLevel: member.RoleLevel,
JoinTime: member.JoinTime.UnixMilli(),
Nickname: member.Nickname,
FaceURL: member.FaceURL,
AppMangerLevel: appMangerLevel,
JoinSource: member.JoinSource,
OperatorUserID: member.OperatorUserID,
Ex: member.Ex,
MuteEndTime: member.MuteEndTime.UnixMilli(),
InviterUserID: member.InviterUserID,
}
}
/* func (g *NotificationSender) getUsersInfoMap(ctx context.Context, userIDs []string) (map[string]*sdkws.UserInfo, error) {
users, err := g.getUsersInfo(ctx, userIDs)
if err != nil {
return nil, err
}
result := make(map[string]*sdkws.UserInfo)
for _, user := range users {
result[user.GetUserID()] = user.(*sdkws.UserInfo)
}
return result, nil
} */
func (g *NotificationSender) fillOpUser(ctx context.Context, targetUser **sdkws.GroupMemberFullInfo, groupID string) (err error) {
return g.fillUserByUserID(ctx, mcontext.GetOpUserID(ctx), targetUser, groupID)
}
func (g *NotificationSender) fillUserByUserID(ctx context.Context, userID string, targetUser **sdkws.GroupMemberFullInfo, groupID string) error {
if targetUser == nil {
return errs.ErrInternalServer.WrapMsg("**sdkws.GroupMemberFullInfo is nil")
}
if groupID != "" {
if authverify.CheckUserIsAdmin(ctx, userID) {
*targetUser = &sdkws.GroupMemberFullInfo{
GroupID: groupID,
UserID: userID,
RoleLevel: constant.GroupAdmin,
AppMangerLevel: constant.AppAdmin,
}
} else {
member, err := g.db.TakeGroupMember(ctx, groupID, userID)
if err == nil {
*targetUser = g.groupMemberDB2PB(member, 0)
} else if !(errors.Is(err, mongo.ErrNoDocuments) || errs.ErrRecordNotFound.Is(err)) {
return err
}
}
}
user, err := g.getUser(ctx, userID)
if err != nil {
return err
}
if *targetUser == nil {
*targetUser = &sdkws.GroupMemberFullInfo{
GroupID: groupID,
UserID: userID,
Nickname: user.Nickname,
FaceURL: user.FaceURL,
OperatorUserID: userID,
}
} else {
if (*targetUser).Nickname == "" {
(*targetUser).Nickname = user.Nickname
}
if (*targetUser).FaceURL == "" {
(*targetUser).FaceURL = user.FaceURL
}
}
return nil
}
func (g *NotificationSender) setVersion(ctx context.Context, version *uint64, versionID *string, collName string, id string) {
versions := versionctx.GetVersionLog(ctx).Get()
for i := len(versions) - 1; i >= 0; i-- {
coll := versions[i]
if coll.Name == collName && coll.Doc.DID == id {
*version = uint64(coll.Doc.Version)
*versionID = coll.Doc.ID.Hex()
return
}
}
}
func (g *NotificationSender) setSortVersion(ctx context.Context, version *uint64, versionID *string, collName string, id string, sortVersion *uint64) {
versions := versionctx.GetVersionLog(ctx).Get()
for _, coll := range versions {
if coll.Name == collName && coll.Doc.DID == id {
*version = uint64(coll.Doc.Version)
*versionID = coll.Doc.ID.Hex()
for _, elem := range coll.Doc.Logs {
if elem.EID == model.VersionSortChangeID {
*sortVersion = uint64(elem.Version)
}
}
}
}
}
func (g *NotificationSender) GroupCreatedNotification(ctx context.Context, tips *sdkws.GroupCreatedTips, SendMessage *bool) {
var err error
defer func() {
if err != nil {
log.ZError(ctx, stringutil.GetFuncName(1)+" failed", err)
}
}()
if err = g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil {
return
}
g.setVersion(ctx, &tips.GroupMemberVersion, &tips.GroupMemberVersionID, database.GroupMemberVersionName, tips.Group.GroupID)
g.Notification(ctx, mcontext.GetOpUserID(ctx), tips.Group.GroupID, constant.GroupCreatedNotification, tips, notification.WithSendMessage(SendMessage))
}
func (g *NotificationSender) GroupInfoSetNotification(ctx context.Context, tips *sdkws.GroupInfoSetTips) {
var err error
defer func() {
if err != nil {
log.ZError(ctx, stringutil.GetFuncName(1)+" failed", err)
}
}()
if err = g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil {
return
}
g.setVersion(ctx, &tips.GroupMemberVersion, &tips.GroupMemberVersionID, database.GroupMemberVersionName, tips.Group.GroupID)
g.Notification(ctx, mcontext.GetOpUserID(ctx), tips.Group.GroupID, constant.GroupInfoSetNotification, tips, notification.WithRpcGetUserName())
}
func (g *NotificationSender) GroupInfoSetNameNotification(ctx context.Context, tips *sdkws.GroupInfoSetNameTips) {
var err error
defer func() {
if err != nil {
log.ZError(ctx, stringutil.GetFuncName(1)+" failed", err)
}
}()
if err = g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil {
return
}
g.setVersion(ctx, &tips.GroupMemberVersion, &tips.GroupMemberVersionID, database.GroupMemberVersionName, tips.Group.GroupID)
g.Notification(ctx, mcontext.GetOpUserID(ctx), tips.Group.GroupID, constant.GroupInfoSetNameNotification, tips)
}
func (g *NotificationSender) GroupInfoSetAnnouncementNotification(ctx context.Context, tips *sdkws.GroupInfoSetAnnouncementTips, sendMessage *bool) {
var err error
defer func() {
if err != nil {
log.ZError(ctx, stringutil.GetFuncName(1)+" failed", err)
}
}()
if err = g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil {
return
}
g.setVersion(ctx, &tips.GroupMemberVersion, &tips.GroupMemberVersionID, database.GroupMemberVersionName, tips.Group.GroupID)
g.Notification(ctx, mcontext.GetOpUserID(ctx), tips.Group.GroupID, constant.GroupInfoSetAnnouncementNotification, tips, notification.WithRpcGetUserName(), notification.WithSendMessage(sendMessage))
}
func (g *NotificationSender) uuid() string {
return uuid.New().String()
}
func (g *NotificationSender) getGroupRequest(ctx context.Context, groupID string, userID string) (*sdkws.GroupRequest, error) {
request, err := g.db.TakeGroupRequest(ctx, groupID, userID)
if err != nil {
return nil, err
}
users, err := g.getUsersInfo(ctx, []string{userID})
if err != nil {
return nil, err
}
if len(users) == 0 {
return nil, servererrs.ErrUserIDNotFound.WrapMsg(fmt.Sprintf("user %s not found", userID))
}
info, ok := users[0].(*sdkws.UserInfo)
if !ok {
info = &sdkws.UserInfo{
UserID: users[0].GetUserID(),
Nickname: users[0].GetNickname(),
FaceURL: users[0].GetFaceURL(),
Ex: users[0].GetEx(),
}
}
return convert.Db2PbGroupRequest(request, info, nil), nil
}
func (g *NotificationSender) JoinGroupApplicationNotification(ctx context.Context, req *pbgroup.JoinGroupReq, dbReq *model.GroupRequest) {
var err error
defer func() {
if err != nil {
log.ZError(ctx, stringutil.GetFuncName(1)+" failed", err)
}
}()
request, err := g.getGroupRequest(ctx, dbReq.GroupID, dbReq.UserID)
if err != nil {
log.ZError(ctx, "JoinGroupApplicationNotification getGroupRequest", err, "dbReq", dbReq)
return
}
var group *sdkws.GroupInfo
group, err = g.getGroupInfo(ctx, req.GroupID)
if err != nil {
return
}
var user *sdkws.PublicUserInfo
user, err = g.getUser(ctx, req.InviterUserID)
if err != nil {
return
}
userIDs, err := g.getGroupOwnerAndAdminUserID(ctx, req.GroupID)
if err != nil {
return
}
userIDs = append(userIDs, req.InviterUserID, mcontext.GetOpUserID(ctx))
tips := &sdkws.JoinGroupApplicationTips{
Group: group,
Applicant: user,
ReqMsg: req.ReqMessage,
Uuid: g.uuid(),
Request: request,
}
for _, userID := range datautil.Distinct(userIDs) {
g.Notification(ctx, mcontext.GetOpUserID(ctx), userID, constant.JoinGroupApplicationNotification, tips)
}
}
func (g *NotificationSender) MemberQuitNotification(ctx context.Context, member *sdkws.GroupMemberFullInfo) {
var err error
defer func() {
if err != nil {
log.ZError(ctx, stringutil.GetFuncName(1)+" failed", err)
}
}()
var group *sdkws.GroupInfo
group, err = g.getGroupInfo(ctx, member.GroupID)
if err != nil {
return
}
tips := &sdkws.MemberQuitTips{Group: group, QuitUser: member}
g.setVersion(ctx, &tips.GroupMemberVersion, &tips.GroupMemberVersionID, database.GroupMemberVersionName, member.GroupID)
g.Notification(ctx, mcontext.GetOpUserID(ctx), member.GroupID, constant.MemberQuitNotification, tips)
}
func (g *NotificationSender) GroupApplicationAcceptedNotification(ctx context.Context, req *pbgroup.GroupApplicationResponseReq) {
var err error
defer func() {
if err != nil {
log.ZError(ctx, stringutil.GetFuncName(1)+" failed", err)
}
}()
request, err := g.getGroupRequest(ctx, req.GroupID, req.FromUserID)
if err != nil {
log.ZError(ctx, "GroupApplicationAcceptedNotification getGroupRequest", err, "req", req)
return
}
var group *sdkws.GroupInfo
group, err = g.getGroupInfo(ctx, req.GroupID)
if err != nil {
return
}
var userIDs []string
userIDs, err = g.getGroupOwnerAndAdminUserID(ctx, req.GroupID)
if err != nil {
return
}
var opUser *sdkws.GroupMemberFullInfo
if err = g.fillOpUser(ctx, &opUser, group.GroupID); err != nil {
return
}
tips := &sdkws.GroupApplicationAcceptedTips{
Group: group,
OpUser: opUser,
HandleMsg: req.HandledMsg,
Uuid: g.uuid(),
Request: request,
}
for _, userID := range append(userIDs, req.FromUserID) {
if userID == req.FromUserID {
tips.ReceiverAs = applicantReceiver
} else {
tips.ReceiverAs = adminReceiver
}
g.Notification(ctx, mcontext.GetOpUserID(ctx), userID, constant.GroupApplicationAcceptedNotification, tips)
}
}
func (g *NotificationSender) GroupApplicationRejectedNotification(ctx context.Context, req *pbgroup.GroupApplicationResponseReq) {
var err error
defer func() {
if err != nil {
log.ZError(ctx, stringutil.GetFuncName(1)+" failed", err)
}
}()
request, err := g.getGroupRequest(ctx, req.GroupID, req.FromUserID)
if err != nil {
log.ZError(ctx, "GroupApplicationAcceptedNotification getGroupRequest", err, "req", req)
return
}
var group *sdkws.GroupInfo
group, err = g.getGroupInfo(ctx, req.GroupID)
if err != nil {
return
}
var userIDs []string
userIDs, err = g.getGroupOwnerAndAdminUserID(ctx, req.GroupID)
if err != nil {
return
}
var opUser *sdkws.GroupMemberFullInfo
if err = g.fillOpUser(ctx, &opUser, group.GroupID); err != nil {
return
}
tips := &sdkws.GroupApplicationRejectedTips{
Group: group,
OpUser: opUser,
HandleMsg: req.HandledMsg,
Uuid: g.uuid(),
Request: request,
}
for _, userID := range append(userIDs, req.FromUserID) {
if userID == req.FromUserID {
tips.ReceiverAs = applicantReceiver
} else {
tips.ReceiverAs = adminReceiver
}
g.Notification(ctx, mcontext.GetOpUserID(ctx), userID, constant.GroupApplicationRejectedNotification, tips)
}
}
func (g *NotificationSender) GroupOwnerTransferredNotification(ctx context.Context, req *pbgroup.TransferGroupOwnerReq) {
var err error
defer func() {
if err != nil {
log.ZError(ctx, stringutil.GetFuncName(1)+" failed", err)
}
}()
var group *sdkws.GroupInfo
group, err = g.getGroupInfo(ctx, req.GroupID)
if err != nil {
return
}
opUserID := mcontext.GetOpUserID(ctx)
var member map[string]*sdkws.GroupMemberFullInfo
member, err = g.getGroupMemberMap(ctx, req.GroupID, []string{opUserID, req.NewOwnerUserID, req.OldOwnerUserID})
if err != nil {
return
}
tips := &sdkws.GroupOwnerTransferredTips{
Group: group,
OpUser: member[opUserID],
NewGroupOwner: member[req.NewOwnerUserID],
OldGroupOwnerInfo: member[req.OldOwnerUserID],
}
if err = g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil {
return
}
g.setVersion(ctx, &tips.GroupMemberVersion, &tips.GroupMemberVersionID, database.GroupMemberVersionName, req.GroupID)
g.Notification(ctx, mcontext.GetOpUserID(ctx), group.GroupID, constant.GroupOwnerTransferredNotification, tips)
}
func (g *NotificationSender) MemberKickedNotification(ctx context.Context, tips *sdkws.MemberKickedTips, SendMessage *bool) {
var err error
defer func() {
if err != nil {
log.ZError(ctx, stringutil.GetFuncName(1)+" failed", err)
}
}()
if err = g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil {
return
}
g.setVersion(ctx, &tips.GroupMemberVersion, &tips.GroupMemberVersionID, database.GroupMemberVersionName, tips.Group.GroupID)
g.Notification(ctx, mcontext.GetOpUserID(ctx), tips.Group.GroupID, constant.MemberKickedNotification, tips, notification.WithSendMessage(SendMessage))
}
func (g *NotificationSender) GroupApplicationAgreeMemberEnterNotification(ctx context.Context, groupID string, SendMessage *bool, invitedOpUserID string, entrantUserID ...string) error {
return g.groupApplicationAgreeMemberEnterNotification(ctx, groupID, SendMessage, invitedOpUserID, entrantUserID...)
}
func (g *NotificationSender) groupApplicationAgreeMemberEnterNotification(ctx context.Context, groupID string, SendMessage *bool, invitedOpUserID string, entrantUserID ...string) error {
var err error
defer func() {
if err != nil {
log.ZError(ctx, stringutil.GetFuncName(1)+" failed", err)
}
}()
if !g.config.RpcConfig.EnableHistoryForNewMembers {
conversationID := msgprocessor.GetConversationIDBySessionType(constant.ReadGroupChatType, groupID)
maxSeq, err := g.msgClient.GetConversationMaxSeq(ctx, conversationID)
if err != nil {
return err
}
if err := g.msgClient.SetUserConversationsMinSeq(ctx, conversationID, entrantUserID, maxSeq+1); err != nil {
return err
}
}
if err := g.conversationClient.CreateGroupChatConversations(ctx, groupID, entrantUserID); err != nil {
return err
}
var group *sdkws.GroupInfo
group, err = g.getGroupInfo(ctx, groupID)
if err != nil {
return err
}
users, err := g.getGroupMembers(ctx, groupID, entrantUserID)
if err != nil {
return err
}
tips := &sdkws.MemberInvitedTips{
Group: group,
InvitedUserList: users,
}
opUserID := mcontext.GetOpUserID(ctx)
if err = g.fillUserByUserID(ctx, opUserID, &tips.OpUser, tips.Group.GroupID); err != nil {
return nil
}
if invitedOpUserID == opUserID {
tips.InviterUser = tips.OpUser
} else {
if err = g.fillUserByUserID(ctx, invitedOpUserID, &tips.InviterUser, tips.Group.GroupID); err != nil {
return err
}
}
g.setVersion(ctx, &tips.GroupMemberVersion, &tips.GroupMemberVersionID, database.GroupMemberVersionName, tips.Group.GroupID)
g.Notification(ctx, mcontext.GetOpUserID(ctx), group.GroupID, constant.MemberInvitedNotification, tips, notification.WithSendMessage(SendMessage))
return nil
}
func (g *NotificationSender) MemberEnterNotification(ctx context.Context, groupID string, entrantUserID string) error {
var err error
defer func() {
if err != nil {
log.ZError(ctx, stringutil.GetFuncName(1)+" failed", err)
}
}()
if !g.config.RpcConfig.EnableHistoryForNewMembers {
conversationID := msgprocessor.GetConversationIDBySessionType(constant.ReadGroupChatType, groupID)
maxSeq, err := g.msgClient.GetConversationMaxSeq(ctx, conversationID)
if err != nil {
return err
}
if err := g.msgClient.SetUserConversationsMinSeq(ctx, conversationID, []string{entrantUserID}, maxSeq+1); err != nil {
return err
}
}
if err := g.conversationClient.CreateGroupChatConversations(ctx, groupID, []string{entrantUserID}); err != nil {
return err
}
var group *sdkws.GroupInfo
group, err = g.getGroupInfo(ctx, groupID)
if err != nil {
return err
}
user, err := g.getGroupMember(ctx, groupID, entrantUserID)
if err != nil {
return err
}
tips := &sdkws.MemberEnterTips{
Group: group,
EntrantUser: user,
OperationTime: time.Now().UnixMilli(),
}
g.setVersion(ctx, &tips.GroupMemberVersion, &tips.GroupMemberVersionID, database.GroupMemberVersionName, tips.Group.GroupID)
// 群内广播:通知所有群成员有人入群
g.Notification(ctx, mcontext.GetOpUserID(ctx), group.GroupID, constant.MemberEnterNotification, tips)
// 给入群本人发送一条系统通知(单聊),便于客户端展示“你已加入群聊”
g.NotificationWithSessionType(ctx, mcontext.GetOpUserID(ctx), entrantUserID,
constant.MemberEnterNotification, constant.SingleChatType, tips)
return nil
}
func (g *NotificationSender) GroupDismissedNotification(ctx context.Context, tips *sdkws.GroupDismissedTips, SendMessage *bool) {
var err error
defer func() {
if err != nil {
log.ZError(ctx, stringutil.GetFuncName(1)+" failed", err)
}
}()
if err = g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil {
return
}
g.Notification(ctx, mcontext.GetOpUserID(ctx), tips.Group.GroupID, constant.GroupDismissedNotification, tips, notification.WithSendMessage(SendMessage))
}
func (g *NotificationSender) GroupMemberMutedNotification(ctx context.Context, groupID, groupMemberUserID string, mutedSeconds uint32) {
var err error
defer func() {
if err != nil {
log.ZError(ctx, stringutil.GetFuncName(1)+" failed", err)
}
}()
log.ZDebug(ctx, "GroupMemberMutedNotification start", "groupID", groupID, "groupMemberUserID", groupMemberUserID, "mutedSeconds", mutedSeconds, "opUserID", mcontext.GetOpUserID(ctx))
var group *sdkws.GroupInfo
group, err = g.getGroupInfo(ctx, groupID)
if err != nil {
log.ZWarn(ctx, "GroupMemberMutedNotification getGroupInfo failed", err, "groupID", groupID)
return
}
log.ZDebug(ctx, "GroupMemberMutedNotification got group info", "groupID", groupID, "groupName", group.GroupName)
var user map[string]*sdkws.GroupMemberFullInfo
user, err = g.getGroupMemberMap(ctx, groupID, []string{mcontext.GetOpUserID(ctx), groupMemberUserID})
if err != nil {
log.ZWarn(ctx, "GroupMemberMutedNotification getGroupMemberMap failed", err, "groupID", groupID)
return
}
log.ZDebug(ctx, "GroupMemberMutedNotification got user map", "opUser", user[mcontext.GetOpUserID(ctx)], "mutedUser", user[groupMemberUserID])
tips := &sdkws.GroupMemberMutedTips{
Group: group, MutedSeconds: mutedSeconds,
OpUser: user[mcontext.GetOpUserID(ctx)], MutedUser: user[groupMemberUserID],
}
if err = g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil {
log.ZWarn(ctx, "GroupMemberMutedNotification fillOpUser failed", err)
return
}
g.setVersion(ctx, &tips.GroupMemberVersion, &tips.GroupMemberVersionID, database.GroupMemberVersionName, tips.Group.GroupID)
// 在群聊中通知,但推送服务会过滤只推送给群主、管理员和被禁言成员本人
log.ZInfo(ctx, "GroupMemberMutedNotification sending notification", "groupID", groupID, "recvID", group.GroupID, "contentType", constant.GroupMemberMutedNotification, "mutedUserID", groupMemberUserID, "mutedSeconds", mutedSeconds)
g.Notification(ctx, mcontext.GetOpUserID(ctx), group.GroupID, constant.GroupMemberMutedNotification, tips)
log.ZDebug(ctx, "GroupMemberMutedNotification notification sent", "groupID", groupID)
}
func (g *NotificationSender) GroupMemberCancelMutedNotification(ctx context.Context, groupID, groupMemberUserID string) {
var err error
defer func() {
if err != nil {
log.ZError(ctx, stringutil.GetFuncName(1)+" failed", err)
}
}()
log.ZDebug(ctx, "GroupMemberCancelMutedNotification start", "groupID", groupID, "groupMemberUserID", groupMemberUserID, "opUserID", mcontext.GetOpUserID(ctx))
var group *sdkws.GroupInfo
group, err = g.getGroupInfo(ctx, groupID)
if err != nil {
log.ZWarn(ctx, "GroupMemberCancelMutedNotification getGroupInfo failed", err, "groupID", groupID)
return
}
log.ZDebug(ctx, "GroupMemberCancelMutedNotification got group info", "groupID", groupID, "groupName", group.GroupName)
var user map[string]*sdkws.GroupMemberFullInfo
user, err = g.getGroupMemberMap(ctx, groupID, []string{mcontext.GetOpUserID(ctx), groupMemberUserID})
if err != nil {
log.ZWarn(ctx, "GroupMemberCancelMutedNotification getGroupMemberMap failed", err, "groupID", groupID)
return
}
log.ZDebug(ctx, "GroupMemberCancelMutedNotification got user map", "opUser", user[mcontext.GetOpUserID(ctx)], "mutedUser", user[groupMemberUserID])
tips := &sdkws.GroupMemberCancelMutedTips{Group: group, OpUser: user[mcontext.GetOpUserID(ctx)], MutedUser: user[groupMemberUserID]}
if err = g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil {
log.ZWarn(ctx, "GroupMemberCancelMutedNotification fillOpUser failed", err)
return
}
g.setVersion(ctx, &tips.GroupMemberVersion, &tips.GroupMemberVersionID, database.GroupMemberVersionName, tips.Group.GroupID)
// 在群聊中通知,但推送服务会过滤只推送给群主、管理员和被取消禁言成员本人
log.ZInfo(ctx, "GroupMemberCancelMutedNotification sending notification", "groupID", groupID, "recvID", group.GroupID, "contentType", constant.GroupMemberCancelMutedNotification, "cancelMutedUserID", groupMemberUserID)
g.Notification(ctx, mcontext.GetOpUserID(ctx), group.GroupID, constant.GroupMemberCancelMutedNotification, tips)
log.ZDebug(ctx, "GroupMemberCancelMutedNotification notification sent", "groupID", groupID)
}
func (g *NotificationSender) GroupMutedNotification(ctx context.Context, groupID string) {
var err error
defer func() {
if err != nil {
log.ZError(ctx, stringutil.GetFuncName(1)+" failed", err)
}
}()
var group *sdkws.GroupInfo
group, err = g.getGroupInfo(ctx, groupID)
if err != nil {
return
}
var users []*sdkws.GroupMemberFullInfo
users, err = g.getGroupMembers(ctx, groupID, []string{mcontext.GetOpUserID(ctx)})
if err != nil {
return
}
tips := &sdkws.GroupMutedTips{Group: group}
if len(users) > 0 {
tips.OpUser = users[0]
}
if err = g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil {
return
}
g.setVersion(ctx, &tips.GroupMemberVersion, &tips.GroupMemberVersionID, database.GroupMemberVersionName, groupID)
g.Notification(ctx, mcontext.GetOpUserID(ctx), group.GroupID, constant.GroupMutedNotification, tips)
}
func (g *NotificationSender) GroupCancelMutedNotification(ctx context.Context, groupID string) {
var err error
defer func() {
if err != nil {
log.ZError(ctx, stringutil.GetFuncName(1)+" failed", err)
}
}()
var group *sdkws.GroupInfo
group, err = g.getGroupInfo(ctx, groupID)
if err != nil {
return
}
var users []*sdkws.GroupMemberFullInfo
users, err = g.getGroupMembers(ctx, groupID, []string{mcontext.GetOpUserID(ctx)})
if err != nil {
return
}
tips := &sdkws.GroupCancelMutedTips{Group: group}
if len(users) > 0 {
tips.OpUser = users[0]
}
if err = g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil {
return
}
g.setVersion(ctx, &tips.GroupMemberVersion, &tips.GroupMemberVersionID, database.GroupMemberVersionName, groupID)
g.Notification(ctx, mcontext.GetOpUserID(ctx), group.GroupID, constant.GroupCancelMutedNotification, tips)
}
func (g *NotificationSender) GroupMemberInfoSetNotification(ctx context.Context, groupID, groupMemberUserID string) {
var err error
defer func() {
if err != nil {
log.ZError(ctx, stringutil.GetFuncName(1)+" failed", err)
}
}()
var group *sdkws.GroupInfo
group, err = g.getGroupInfo(ctx, groupID)
if err != nil {
return
}
var user map[string]*sdkws.GroupMemberFullInfo
user, err = g.getGroupMemberMap(ctx, groupID, []string{groupMemberUserID})
if err != nil {
return
}
tips := &sdkws.GroupMemberInfoSetTips{Group: group, OpUser: user[mcontext.GetOpUserID(ctx)], ChangedUser: user[groupMemberUserID]}
if err = g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil {
return
}
g.setSortVersion(ctx, &tips.GroupMemberVersion, &tips.GroupMemberVersionID, database.GroupMemberVersionName, tips.Group.GroupID, &tips.GroupSortVersion)
g.Notification(ctx, mcontext.GetOpUserID(ctx), group.GroupID, constant.GroupMemberInfoSetNotification, tips)
}
func (g *NotificationSender) GroupMemberSetToAdminNotification(ctx context.Context, groupID, groupMemberUserID string) {
var err error
defer func() {
if err != nil {
log.ZError(ctx, stringutil.GetFuncName(1)+" failed", err)
}
}()
var group *sdkws.GroupInfo
group, err = g.getGroupInfo(ctx, groupID)
if err != nil {
return
}
user, err := g.getGroupMemberMap(ctx, groupID, []string{mcontext.GetOpUserID(ctx), groupMemberUserID})
if err != nil {
return
}
tips := &sdkws.GroupMemberInfoSetTips{Group: group, OpUser: user[mcontext.GetOpUserID(ctx)], ChangedUser: user[groupMemberUserID]}
if err = g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil {
return
}
g.setSortVersion(ctx, &tips.GroupMemberVersion, &tips.GroupMemberVersionID, database.GroupMemberVersionName, tips.Group.GroupID, &tips.GroupSortVersion)
g.Notification(ctx, mcontext.GetOpUserID(ctx), group.GroupID, constant.GroupMemberSetToAdminNotification, tips)
}
func (g *NotificationSender) GroupMemberSetToOrdinaryUserNotification(ctx context.Context, groupID, groupMemberUserID string) {
var err error
defer func() {
if err != nil {
log.ZError(ctx, stringutil.GetFuncName(1)+" failed", err)
}
}()
var group *sdkws.GroupInfo
group, err = g.getGroupInfo(ctx, groupID)
if err != nil {
return
}
var user map[string]*sdkws.GroupMemberFullInfo
user, err = g.getGroupMemberMap(ctx, groupID, []string{mcontext.GetOpUserID(ctx), groupMemberUserID})
if err != nil {
return
}
tips := &sdkws.GroupMemberInfoSetTips{Group: group, OpUser: user[mcontext.GetOpUserID(ctx)], ChangedUser: user[groupMemberUserID]}
if err = g.fillOpUser(ctx, &tips.OpUser, tips.Group.GroupID); err != nil {
return
}
g.setSortVersion(ctx, &tips.GroupMemberVersion, &tips.GroupMemberVersionID, database.GroupMemberVersionName, tips.Group.GroupID, &tips.GroupSortVersion)
g.Notification(ctx, mcontext.GetOpUserID(ctx), group.GroupID, constant.GroupMemberSetToOrdinaryUserNotification, tips)
}

View File

@@ -0,0 +1,47 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package group
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/protocol/group"
"github.com/openimsdk/tools/errs"
)
func (g *groupServer) GroupCreateCount(ctx context.Context, req *group.GroupCreateCountReq) (*group.GroupCreateCountResp, error) {
if req.Start > req.End {
return nil, errs.ErrArgs.WrapMsg("start > end: %d > %d", req.Start, req.End)
}
if err := authverify.CheckAdmin(ctx); err != nil {
return nil, err
}
total, err := g.db.CountTotal(ctx, nil)
if err != nil {
return nil, err
}
start := time.UnixMilli(req.Start)
before, err := g.db.CountTotal(ctx, &start)
if err != nil {
return nil, err
}
count, err := g.db.CountRangeEverydayTotal(ctx, start, time.UnixMilli(req.End))
if err != nil {
return nil, err
}
return &group.GroupCreateCountResp{Total: total, Before: before, Count: count}, nil
}

197
internal/rpc/group/sync.go Normal file
View File

@@ -0,0 +1,197 @@
package group
import (
"context"
"errors"
"git.imall.cloud/openim/open-im-server-deploy/internal/rpc/incrversion"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/servererrs"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/open-im-server-deploy/pkg/util/hashutil"
"git.imall.cloud/openim/protocol/constant"
pbgroup "git.imall.cloud/openim/protocol/group"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/log"
)
const versionSyncLimit = 500
func (g *groupServer) GetFullGroupMemberUserIDs(ctx context.Context, req *pbgroup.GetFullGroupMemberUserIDsReq) (*pbgroup.GetFullGroupMemberUserIDsResp, error) {
userIDs, err := g.db.FindGroupMemberUserID(ctx, req.GroupID)
if err != nil {
return nil, err
}
if err := authverify.CheckAccessIn(ctx, userIDs...); err != nil {
return nil, err
}
vl, err := g.db.FindMaxGroupMemberVersionCache(ctx, req.GroupID)
if err != nil {
return nil, err
}
idHash := hashutil.IdHash(userIDs)
if req.IdHash == idHash {
userIDs = nil
}
return &pbgroup.GetFullGroupMemberUserIDsResp{
Version: idHash,
VersionID: vl.ID.Hex(),
Equal: req.IdHash == idHash,
UserIDs: userIDs,
}, nil
}
func (g *groupServer) GetFullJoinGroupIDs(ctx context.Context, req *pbgroup.GetFullJoinGroupIDsReq) (*pbgroup.GetFullJoinGroupIDsResp, error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
vl, err := g.db.FindMaxJoinGroupVersionCache(ctx, req.UserID)
if err != nil {
return nil, err
}
groupIDs, err := g.db.FindJoinGroupID(ctx, req.UserID)
if err != nil {
return nil, err
}
idHash := hashutil.IdHash(groupIDs)
if req.IdHash == idHash {
groupIDs = nil
}
return &pbgroup.GetFullJoinGroupIDsResp{
Version: idHash,
VersionID: vl.ID.Hex(),
Equal: req.IdHash == idHash,
GroupIDs: groupIDs,
}, nil
}
func (g *groupServer) GetIncrementalGroupMember(ctx context.Context, req *pbgroup.GetIncrementalGroupMemberReq) (*pbgroup.GetIncrementalGroupMemberResp, error) {
if err := g.checkAdminOrInGroup(ctx, req.GroupID); err != nil {
return nil, err
}
group, err := g.db.TakeGroup(ctx, req.GroupID)
if err != nil {
return nil, err
}
if group.Status == constant.GroupStatusDismissed {
return nil, servererrs.ErrDismissedAlready.Wrap()
}
var (
hasGroupUpdate bool
sortVersion uint64
)
opt := incrversion.Option[*sdkws.GroupMemberFullInfo, pbgroup.GetIncrementalGroupMemberResp]{
Ctx: ctx,
VersionKey: req.GroupID,
VersionID: req.VersionID,
VersionNumber: req.Version,
Version: func(ctx context.Context, groupID string, version uint, limit int) (*model.VersionLog, error) {
vl, err := g.db.FindMemberIncrVersion(ctx, groupID, version, limit)
if err != nil {
return nil, err
}
logs := make([]model.VersionLogElem, 0, len(vl.Logs))
for i, log := range vl.Logs {
switch log.EID {
case model.VersionGroupChangeID:
vl.LogLen--
hasGroupUpdate = true
case model.VersionSortChangeID:
vl.LogLen--
sortVersion = uint64(log.Version)
default:
logs = append(logs, vl.Logs[i])
}
}
vl.Logs = logs
if vl.LogLen > 0 {
hasGroupUpdate = true
}
return vl, nil
},
CacheMaxVersion: g.db.FindMaxGroupMemberVersionCache,
Find: func(ctx context.Context, ids []string) ([]*sdkws.GroupMemberFullInfo, error) {
return g.getGroupMembersInfo(ctx, req.GroupID, ids)
},
Resp: func(version *model.VersionLog, delIDs []string, insertList, updateList []*sdkws.GroupMemberFullInfo, full bool) *pbgroup.GetIncrementalGroupMemberResp {
return &pbgroup.GetIncrementalGroupMemberResp{
VersionID: version.ID.Hex(),
Version: uint64(version.Version),
Full: full,
Delete: delIDs,
Insert: insertList,
Update: updateList,
SortVersion: sortVersion,
}
},
}
resp, err := opt.Build()
if err != nil {
return nil, err
}
if resp.Full || hasGroupUpdate {
count, err := g.db.FindGroupMemberNum(ctx, group.GroupID)
if err != nil {
return nil, err
}
owner, err := g.db.TakeGroupOwner(ctx, group.GroupID)
if err != nil {
return nil, err
}
resp.Group = g.groupDB2PB(group, owner.UserID, count)
}
return resp, nil
}
func (g *groupServer) GetIncrementalJoinGroup(ctx context.Context, req *pbgroup.GetIncrementalJoinGroupReq) (*pbgroup.GetIncrementalJoinGroupResp, error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
opt := incrversion.Option[*sdkws.GroupInfo, pbgroup.GetIncrementalJoinGroupResp]{
Ctx: ctx,
VersionKey: req.UserID,
VersionID: req.VersionID,
VersionNumber: req.Version,
Version: g.db.FindJoinIncrVersion,
CacheMaxVersion: g.db.FindMaxJoinGroupVersionCache,
Find: g.getGroupsInfo,
Resp: func(version *model.VersionLog, delIDs []string, insertList, updateList []*sdkws.GroupInfo, full bool) *pbgroup.GetIncrementalJoinGroupResp {
return &pbgroup.GetIncrementalJoinGroupResp{
VersionID: version.ID.Hex(),
Version: uint64(version.Version),
Full: full,
Delete: delIDs,
Insert: insertList,
Update: updateList,
}
},
}
return opt.Build()
}
func (g *groupServer) BatchGetIncrementalGroupMember(ctx context.Context, req *pbgroup.BatchGetIncrementalGroupMemberReq) (*pbgroup.BatchGetIncrementalGroupMemberResp, error) {
var num int
resp := make(map[string]*pbgroup.GetIncrementalGroupMemberResp)
for _, memberReq := range req.ReqList {
if _, ok := resp[memberReq.GroupID]; ok {
continue
}
memberResp, err := g.GetIncrementalGroupMember(ctx, memberReq)
if err != nil {
if errors.Is(err, servererrs.ErrDismissedAlready) {
log.ZWarn(ctx, "Failed to get incremental group member", err, "groupID", memberReq.GroupID, "request", memberReq)
continue
}
return nil, err
}
resp[memberReq.GroupID] = memberResp
num += len(memberResp.Insert) + len(memberResp.Update) + len(memberResp.Delete)
if num >= versionSyncLimit {
break
}
}
return &pbgroup.BatchGetIncrementalGroupMemberResp{RespList: resp}, nil
}

View File

@@ -0,0 +1,207 @@
package incrversion
import (
"context"
"fmt"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/errs"
"go.mongodb.org/mongo-driver/bson/primitive"
)
type BatchOption[A, B any] struct {
Ctx context.Context
TargetKeys []string
VersionIDs []string
VersionNumbers []uint64
//SyncLimit int
Versions func(ctx context.Context, dIds []string, versions []uint64, limits []int) (map[string]*model.VersionLog, error)
CacheMaxVersions func(ctx context.Context, dIds []string) (map[string]*model.VersionLog, error)
Find func(ctx context.Context, dId string, ids []string) (A, error)
Resp func(versionsMap map[string]*model.VersionLog, deleteIdsMap map[string][]string, insertListMap, updateListMap map[string]A, fullMap map[string]bool) *B
}
func (o *BatchOption[A, B]) newError(msg string) error {
return errs.ErrInternalServer.WrapMsg(msg)
}
func (o *BatchOption[A, B]) check() error {
if o.Ctx == nil {
return o.newError("opt ctx is nil")
}
if len(o.TargetKeys) == 0 {
return o.newError("targetKeys is empty")
}
if o.Versions == nil {
return o.newError("func versions is nil")
}
if o.Find == nil {
return o.newError("func find is nil")
}
if o.Resp == nil {
return o.newError("func resp is nil")
}
return nil
}
func (o *BatchOption[A, B]) validVersions() []bool {
valids := make([]bool, len(o.VersionIDs))
for i, versionID := range o.VersionIDs {
objID, err := primitive.ObjectIDFromHex(versionID)
valids[i] = (err == nil && (!objID.IsZero()) && o.VersionNumbers[i] > 0)
}
return valids
}
func (o *BatchOption[A, B]) equalIDs(objIDs []primitive.ObjectID) []bool {
equals := make([]bool, len(o.VersionIDs))
for i, versionID := range o.VersionIDs {
equals[i] = versionID == objIDs[i].Hex()
}
return equals
}
func (o *BatchOption[A, B]) getVersions(tags *[]int) (versions map[string]*model.VersionLog, err error) {
var dIDs []string
var versionNums []uint64
var limits []int
valids := o.validVersions()
if o.CacheMaxVersions == nil {
for i, valid := range valids {
if valid {
(*tags)[i] = tagQuery
dIDs = append(dIDs, o.TargetKeys[i])
versionNums = append(versionNums, o.VersionNumbers[i])
limits = append(limits, syncLimit)
} else {
(*tags)[i] = tagFull
dIDs = append(dIDs, o.TargetKeys[i])
versionNums = append(versionNums, 0)
limits = append(limits, 0)
}
}
versions, err = o.Versions(o.Ctx, dIDs, versionNums, limits)
if err != nil {
return nil, errs.Wrap(err)
}
return versions, nil
} else {
caches, err := o.CacheMaxVersions(o.Ctx, o.TargetKeys)
if err != nil {
return nil, errs.Wrap(err)
}
objIDs := make([]primitive.ObjectID, len(o.VersionIDs))
for i, versionID := range o.VersionIDs {
objID, _ := primitive.ObjectIDFromHex(versionID)
objIDs[i] = objID
}
equals := o.equalIDs(objIDs)
for i, valid := range valids {
if !valid {
(*tags)[i] = tagFull
} else if !equals[i] {
(*tags)[i] = tagFull
} else if o.VersionNumbers[i] == uint64(caches[o.TargetKeys[i]].Version) {
(*tags)[i] = tagEqual
} else {
(*tags)[i] = tagQuery
dIDs = append(dIDs, o.TargetKeys[i])
versionNums = append(versionNums, o.VersionNumbers[i])
limits = append(limits, syncLimit)
delete(caches, o.TargetKeys[i])
}
}
if dIDs != nil {
versionMap, err := o.Versions(o.Ctx, dIDs, versionNums, limits)
if err != nil {
return nil, errs.Wrap(err)
}
for k, v := range versionMap {
caches[k] = v
}
}
versions = caches
}
return versions, nil
}
func (o *BatchOption[A, B]) Build() (*B, error) {
if err := o.check(); err != nil {
return nil, errs.Wrap(err)
}
tags := make([]int, len(o.TargetKeys))
versions, err := o.getVersions(&tags)
if err != nil {
return nil, errs.Wrap(err)
}
fullMap := make(map[string]bool)
for i, tag := range tags {
switch tag {
case tagQuery:
vLog := versions[o.TargetKeys[i]]
fullMap[o.TargetKeys[i]] = vLog.ID.Hex() != o.VersionIDs[i] || uint64(vLog.Version) < o.VersionNumbers[i] || len(vLog.Logs) != vLog.LogLen
case tagFull:
fullMap[o.TargetKeys[i]] = true
case tagEqual:
fullMap[o.TargetKeys[i]] = false
default:
panic(fmt.Errorf("undefined tag %d", tag))
}
}
var (
insertIdsMap = make(map[string][]string)
deleteIdsMap = make(map[string][]string)
updateIdsMap = make(map[string][]string)
)
for _, targetKey := range o.TargetKeys {
if !fullMap[targetKey] {
version := versions[targetKey]
insertIds, deleteIds, updateIds := version.DeleteAndChangeIDs()
insertIdsMap[targetKey] = insertIds
deleteIdsMap[targetKey] = deleteIds
updateIdsMap[targetKey] = updateIds
}
}
var (
insertListMap = make(map[string]A)
updateListMap = make(map[string]A)
)
for targetKey, insertIds := range insertIdsMap {
if len(insertIds) > 0 {
insertList, err := o.Find(o.Ctx, targetKey, insertIds)
if err != nil {
return nil, errs.Wrap(err)
}
insertListMap[targetKey] = insertList
}
}
for targetKey, updateIds := range updateIdsMap {
if len(updateIds) > 0 {
updateList, err := o.Find(o.Ctx, targetKey, updateIds)
if err != nil {
return nil, errs.Wrap(err)
}
updateListMap[targetKey] = updateList
}
}
return o.Resp(versions, deleteIdsMap, insertListMap, updateListMap, fullMap), nil
}

View File

@@ -0,0 +1,153 @@
package incrversion
import (
"context"
"fmt"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/errs"
"go.mongodb.org/mongo-driver/bson/primitive"
)
//func Limit(maxSync int, version uint64) int {
// if version == 0 {
// return 0
// }
// return maxSync
//}
const syncLimit = 200
const (
tagQuery = iota + 1
tagFull
tagEqual
)
type Option[A, B any] struct {
Ctx context.Context
VersionKey string
VersionID string
VersionNumber uint64
//SyncLimit int
CacheMaxVersion func(ctx context.Context, dId string) (*model.VersionLog, error)
Version func(ctx context.Context, dId string, version uint, limit int) (*model.VersionLog, error)
//SortID func(ctx context.Context, dId string) ([]string, error)
Find func(ctx context.Context, ids []string) ([]A, error)
Resp func(version *model.VersionLog, deleteIds []string, insertList, updateList []A, full bool) *B
}
func (o *Option[A, B]) newError(msg string) error {
return errs.ErrInternalServer.WrapMsg(msg)
}
func (o *Option[A, B]) check() error {
if o.Ctx == nil {
return o.newError("opt ctx is nil")
}
if o.VersionKey == "" {
return o.newError("versionKey is empty")
}
//if o.SyncLimit <= 0 {
// return o.newError("invalid synchronization quantity")
//}
if o.Version == nil {
return o.newError("func version is nil")
}
//if o.SortID == nil {
// return o.newError("func allID is nil")
//}
if o.Find == nil {
return o.newError("func find is nil")
}
if o.Resp == nil {
return o.newError("func resp is nil")
}
return nil
}
func (o *Option[A, B]) validVersion() bool {
objID, err := primitive.ObjectIDFromHex(o.VersionID)
return err == nil && (!objID.IsZero()) && o.VersionNumber > 0
}
func (o *Option[A, B]) equalID(objID primitive.ObjectID) bool {
return o.VersionID == objID.Hex()
}
func (o *Option[A, B]) getVersion(tag *int) (*model.VersionLog, error) {
if o.CacheMaxVersion == nil {
if o.validVersion() {
*tag = tagQuery
return o.Version(o.Ctx, o.VersionKey, uint(o.VersionNumber), syncLimit)
}
*tag = tagFull
return o.Version(o.Ctx, o.VersionKey, 0, 0)
} else {
cache, err := o.CacheMaxVersion(o.Ctx, o.VersionKey)
if err != nil {
return nil, err
}
if !o.validVersion() {
*tag = tagFull
return cache, nil
}
if !o.equalID(cache.ID) {
*tag = tagFull
return cache, nil
}
if o.VersionNumber == uint64(cache.Version) {
*tag = tagEqual
return cache, nil
}
*tag = tagQuery
return o.Version(o.Ctx, o.VersionKey, uint(o.VersionNumber), syncLimit)
}
}
func (o *Option[A, B]) Build() (*B, error) {
if err := o.check(); err != nil {
return nil, err
}
var tag int
version, err := o.getVersion(&tag)
if err != nil {
return nil, err
}
var full bool
switch tag {
case tagQuery:
full = version.ID.Hex() != o.VersionID || uint64(version.Version) < o.VersionNumber || len(version.Logs) != version.LogLen
case tagFull:
full = true
case tagEqual:
full = false
default:
panic(fmt.Errorf("undefined tag %d", tag))
}
var (
insertIds []string
deleteIds []string
updateIds []string
)
if !full {
insertIds, deleteIds, updateIds = version.DeleteAndChangeIDs()
}
var (
insertList []A
updateList []A
)
if len(insertIds) > 0 {
insertList, err = o.Find(o.Ctx, insertIds)
if err != nil {
return nil, err
}
}
if len(updateIds) > 0 {
updateList, err = o.Find(o.Ctx, updateIds)
if err != nil {
return nil, err
}
}
return o.Resp(version, deleteIds, insertList, updateList, full), nil
}

231
internal/rpc/msg/as_read.go Normal file
View File

@@ -0,0 +1,231 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msg
import (
"context"
"errors"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
cbapi "git.imall.cloud/openim/open-im-server-deploy/pkg/callbackstruct"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/msg"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/datautil"
"github.com/redis/go-redis/v9"
)
func (m *msgServer) GetConversationsHasReadAndMaxSeq(ctx context.Context, req *msg.GetConversationsHasReadAndMaxSeqReq) (*msg.GetConversationsHasReadAndMaxSeqResp, error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
var conversationIDs []string
if len(req.ConversationIDs) == 0 {
var err error
conversationIDs, err = m.ConversationLocalCache.GetConversationIDs(ctx, req.UserID)
if err != nil {
return nil, err
}
} else {
conversationIDs = req.ConversationIDs
}
hasReadSeqs, err := m.MsgDatabase.GetHasReadSeqs(ctx, req.UserID, conversationIDs)
if err != nil {
return nil, err
}
conversations, err := m.ConversationLocalCache.GetConversations(ctx, req.UserID, conversationIDs)
if err != nil {
return nil, err
}
conversationMaxSeqMap := make(map[string]int64)
for _, conversation := range conversations {
if conversation.MaxSeq != 0 {
conversationMaxSeqMap[conversation.ConversationID] = conversation.MaxSeq
}
}
maxSeqs, err := m.MsgDatabase.GetMaxSeqsWithTime(ctx, conversationIDs)
if err != nil {
return nil, err
}
resp := &msg.GetConversationsHasReadAndMaxSeqResp{Seqs: make(map[string]*msg.Seqs)}
if req.ReturnPinned {
pinnedConversationIDs, err := m.ConversationLocalCache.GetPinnedConversationIDs(ctx, req.UserID)
if err != nil {
return nil, err
}
resp.PinnedConversationIDs = pinnedConversationIDs
}
for conversationID, maxSeq := range maxSeqs {
resp.Seqs[conversationID] = &msg.Seqs{
HasReadSeq: hasReadSeqs[conversationID],
MaxSeq: maxSeq.Seq,
MaxSeqTime: maxSeq.Time,
}
if v, ok := conversationMaxSeqMap[conversationID]; ok {
resp.Seqs[conversationID].MaxSeq = v
}
}
return resp, nil
}
func (m *msgServer) SetConversationHasReadSeq(ctx context.Context, req *msg.SetConversationHasReadSeqReq) (*msg.SetConversationHasReadSeqResp, error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
maxSeq, err := m.MsgDatabase.GetMaxSeq(ctx, req.ConversationID)
if err != nil {
return nil, err
}
if req.HasReadSeq > maxSeq {
return nil, errs.ErrArgs.WrapMsg("hasReadSeq must not be bigger than maxSeq")
}
if err := m.MsgDatabase.SetHasReadSeq(ctx, req.UserID, req.ConversationID, req.HasReadSeq); err != nil {
return nil, err
}
m.sendMarkAsReadNotification(ctx, req.ConversationID, constant.SingleChatType, req.UserID, req.UserID, nil, req.HasReadSeq)
return &msg.SetConversationHasReadSeqResp{}, nil
}
func (m *msgServer) MarkMsgsAsRead(ctx context.Context, req *msg.MarkMsgsAsReadReq) (*msg.MarkMsgsAsReadResp, error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
maxSeq, err := m.MsgDatabase.GetMaxSeq(ctx, req.ConversationID)
if err != nil {
return nil, err
}
hasReadSeq := req.Seqs[len(req.Seqs)-1]
if hasReadSeq > maxSeq {
return nil, errs.ErrArgs.WrapMsg("hasReadSeq must not be bigger than maxSeq")
}
conversation, err := m.ConversationLocalCache.GetConversation(ctx, req.UserID, req.ConversationID)
if err != nil {
return nil, err
}
webhookCfg := m.webhookConfig()
if err := m.MsgDatabase.MarkSingleChatMsgsAsRead(ctx, req.UserID, req.ConversationID, req.Seqs); err != nil {
return nil, err
}
currentHasReadSeq, err := m.MsgDatabase.GetHasReadSeq(ctx, req.UserID, req.ConversationID)
if err != nil && !errors.Is(err, redis.Nil) {
return nil, err
}
if hasReadSeq > currentHasReadSeq {
err = m.MsgDatabase.SetHasReadSeq(ctx, req.UserID, req.ConversationID, hasReadSeq)
if err != nil {
return nil, err
}
}
reqCallback := &cbapi.CallbackSingleMsgReadReq{
ConversationID: conversation.ConversationID,
UserID: req.UserID,
Seqs: req.Seqs,
ContentType: conversation.ConversationType,
}
m.webhookAfterSingleMsgRead(ctx, &webhookCfg.AfterSingleMsgRead, reqCallback)
m.sendMarkAsReadNotification(ctx, req.ConversationID, conversation.ConversationType, req.UserID,
m.conversationAndGetRecvID(conversation, req.UserID), req.Seqs, hasReadSeq)
return &msg.MarkMsgsAsReadResp{}, nil
}
func (m *msgServer) MarkConversationAsRead(ctx context.Context, req *msg.MarkConversationAsReadReq) (*msg.MarkConversationAsReadResp, error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
conversation, err := m.ConversationLocalCache.GetConversation(ctx, req.UserID, req.ConversationID)
if err != nil {
return nil, err
}
webhookCfg := m.webhookConfig()
hasReadSeq, err := m.MsgDatabase.GetHasReadSeq(ctx, req.UserID, req.ConversationID)
if err != nil && !errors.Is(err, redis.Nil) {
return nil, err
}
var seqs []int64
log.ZDebug(ctx, "MarkConversationAsRead", "hasReadSeq", hasReadSeq, "req.HasReadSeq", req.HasReadSeq)
if conversation.ConversationType == constant.SingleChatType {
for i := hasReadSeq + 1; i <= req.HasReadSeq; i++ {
seqs = append(seqs, i)
}
// avoid client missed call MarkConversationMessageAsRead by order
for _, val := range req.Seqs {
if !datautil.Contain(val, seqs...) {
seqs = append(seqs, val)
}
}
if len(seqs) > 0 {
log.ZDebug(ctx, "MarkConversationAsRead", "seqs", seqs, "conversationID", req.ConversationID)
if err = m.MsgDatabase.MarkSingleChatMsgsAsRead(ctx, req.UserID, req.ConversationID, seqs); err != nil {
return nil, err
}
}
if req.HasReadSeq > hasReadSeq {
err = m.MsgDatabase.SetHasReadSeq(ctx, req.UserID, req.ConversationID, req.HasReadSeq)
if err != nil {
return nil, err
}
hasReadSeq = req.HasReadSeq
}
m.sendMarkAsReadNotification(ctx, req.ConversationID, conversation.ConversationType, req.UserID,
m.conversationAndGetRecvID(conversation, req.UserID), seqs, hasReadSeq)
} else if conversation.ConversationType == constant.ReadGroupChatType ||
conversation.ConversationType == constant.NotificationChatType {
if req.HasReadSeq > hasReadSeq {
err = m.MsgDatabase.SetHasReadSeq(ctx, req.UserID, req.ConversationID, req.HasReadSeq)
if err != nil {
return nil, err
}
hasReadSeq = req.HasReadSeq
}
m.sendMarkAsReadNotification(ctx, req.ConversationID, constant.SingleChatType, req.UserID,
req.UserID, seqs, hasReadSeq)
}
if conversation.ConversationType == constant.SingleChatType {
reqCall := &cbapi.CallbackSingleMsgReadReq{
ConversationID: conversation.ConversationID,
UserID: conversation.OwnerUserID,
Seqs: req.Seqs,
ContentType: conversation.ConversationType,
}
m.webhookAfterSingleMsgRead(ctx, &webhookCfg.AfterSingleMsgRead, reqCall)
} else if conversation.ConversationType == constant.ReadGroupChatType {
reqCall := &cbapi.CallbackGroupMsgReadReq{
SendID: conversation.OwnerUserID,
ReceiveID: req.UserID,
UnreadMsgNum: req.HasReadSeq,
ContentType: int64(conversation.ConversationType),
}
m.webhookAfterGroupMsgRead(ctx, &webhookCfg.AfterGroupMsgRead, reqCall)
}
return &msg.MarkConversationAsReadResp{}, nil
}
func (m *msgServer) sendMarkAsReadNotification(ctx context.Context, conversationID string, sessionType int32, sendID, recvID string, seqs []int64, hasReadSeq int64) {
tips := &sdkws.MarkAsReadTips{
MarkAsReadUserID: sendID,
ConversationID: conversationID,
Seqs: seqs,
HasReadSeq: hasReadSeq,
}
m.notificationSender.NotificationWithSessionType(ctx, sendID, recvID, constant.HasReadReceipt, sessionType, tips)
}

View File

@@ -0,0 +1,236 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msg
import (
"context"
"encoding/base64"
"encoding/json"
"git.imall.cloud/openim/open-im-server-deploy/pkg/apistruct"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/webhook"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
cbapi "git.imall.cloud/openim/open-im-server-deploy/pkg/callbackstruct"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/protocol/constant"
pbchat "git.imall.cloud/openim/protocol/msg"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/datautil"
"github.com/openimsdk/tools/utils/stringutil"
"google.golang.org/protobuf/proto"
)
func toCommonCallback(ctx context.Context, msg *pbchat.SendMsgReq, command string) cbapi.CommonCallbackReq {
return cbapi.CommonCallbackReq{
SendID: msg.MsgData.SendID,
ServerMsgID: msg.MsgData.ServerMsgID,
CallbackCommand: command,
ClientMsgID: msg.MsgData.ClientMsgID,
OperationID: mcontext.GetOperationID(ctx),
SenderPlatformID: msg.MsgData.SenderPlatformID,
SenderNickname: msg.MsgData.SenderNickname,
SessionType: msg.MsgData.SessionType,
MsgFrom: msg.MsgData.MsgFrom,
ContentType: msg.MsgData.ContentType,
Status: msg.MsgData.Status,
SendTime: msg.MsgData.SendTime,
CreateTime: msg.MsgData.CreateTime,
AtUserIDList: msg.MsgData.AtUserIDList,
SenderFaceURL: msg.MsgData.SenderFaceURL,
Content: GetContent(msg.MsgData),
Seq: uint32(msg.MsgData.Seq),
Ex: msg.MsgData.Ex,
}
}
func GetContent(msg *sdkws.MsgData) string {
if msg.ContentType >= constant.NotificationBegin && msg.ContentType <= constant.NotificationEnd {
var tips sdkws.TipsComm
_ = proto.Unmarshal(msg.Content, &tips)
content := tips.JsonDetail
return content
} else {
return string(msg.Content)
}
}
func (m *msgServer) webhookBeforeSendSingleMsg(ctx context.Context, before *config.BeforeConfig, msg *pbchat.SendMsgReq) error {
return webhook.WithCondition(ctx, before, func(ctx context.Context) error {
if msg.MsgData.ContentType == constant.Typing {
return nil
}
if !filterBeforeMsg(msg, before) {
return nil
}
cbReq := &cbapi.CallbackBeforeSendSingleMsgReq{
CommonCallbackReq: toCommonCallback(ctx, msg, cbapi.CallbackBeforeSendSingleMsgCommand),
RecvID: msg.MsgData.RecvID,
}
resp := &cbapi.CallbackBeforeSendSingleMsgResp{}
if err := m.webhookClient.SyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, resp, before); err != nil {
return err
}
return nil
})
}
func (m *msgServer) webhookAfterSendSingleMsg(ctx context.Context, after *config.AfterConfig, msg *pbchat.SendMsgReq) {
if msg.MsgData.ContentType == constant.Typing {
return
}
if !filterAfterMsg(msg, after) {
return
}
cbReq := &cbapi.CallbackAfterSendSingleMsgReq{
CommonCallbackReq: toCommonCallback(ctx, msg, cbapi.CallbackAfterSendSingleMsgCommand),
RecvID: msg.MsgData.RecvID,
}
m.webhookClient.AsyncPostWithQuery(ctx, cbReq.GetCallbackCommand(), cbReq, &cbapi.CallbackAfterSendSingleMsgResp{}, after, buildKeyMsgDataQuery(msg.MsgData))
}
func (m *msgServer) webhookBeforeSendGroupMsg(ctx context.Context, before *config.BeforeConfig, msg *pbchat.SendMsgReq) error {
return webhook.WithCondition(ctx, before, func(ctx context.Context) error {
if !filterBeforeMsg(msg, before) {
return nil
}
if msg.MsgData.ContentType == constant.Typing {
return nil
}
cbReq := &cbapi.CallbackBeforeSendGroupMsgReq{
CommonCallbackReq: toCommonCallback(ctx, msg, cbapi.CallbackBeforeSendGroupMsgCommand),
GroupID: msg.MsgData.GroupID,
}
resp := &cbapi.CallbackBeforeSendGroupMsgResp{}
if err := m.webhookClient.SyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, resp, before); err != nil {
return err
}
return nil
})
}
func (m *msgServer) webhookAfterSendGroupMsg(ctx context.Context, after *config.AfterConfig, msg *pbchat.SendMsgReq) {
if after == nil {
return
}
if msg.MsgData.ContentType == constant.Typing {
log.ZDebug(ctx, "webhook skipped: typing message", "contentType", msg.MsgData.ContentType)
return
}
log.ZInfo(ctx, "webhook afterSendGroupMsg checking", "enable", after.Enable, "groupID", msg.MsgData.GroupID, "contentType", msg.MsgData.ContentType, "attentionIds", after.AttentionIds, "deniedTypes", after.DeniedTypes)
if !filterAfterMsg(msg, after) {
log.ZDebug(ctx, "webhook filtered out by filterAfterMsg", "groupID", msg.MsgData.GroupID)
return
}
if !after.Enable {
log.ZDebug(ctx, "webhook afterSendGroupMsg disabled, skipping", "enable", after.Enable)
return
}
log.ZInfo(ctx, "webhook afterSendGroupMsg sending", "groupID", msg.MsgData.GroupID, "sendID", msg.MsgData.SendID, "contentType", msg.MsgData.ContentType)
cbReq := &cbapi.CallbackAfterSendGroupMsgReq{
CommonCallbackReq: toCommonCallback(ctx, msg, cbapi.CallbackAfterSendGroupMsgCommand),
GroupID: msg.MsgData.GroupID,
}
m.webhookClient.AsyncPostWithQuery(ctx, cbReq.GetCallbackCommand(), cbReq, &cbapi.CallbackAfterSendGroupMsgResp{}, after, buildKeyMsgDataQuery(msg.MsgData))
}
func (m *msgServer) webhookBeforeMsgModify(ctx context.Context, before *config.BeforeConfig, msg *pbchat.SendMsgReq, beforeMsgData **sdkws.MsgData) error {
return webhook.WithCondition(ctx, before, func(ctx context.Context) error {
//if msg.MsgData.ContentType != constant.Text {
// return nil
//}
if !filterBeforeMsg(msg, before) {
return nil
}
cbReq := &cbapi.CallbackMsgModifyCommandReq{
CommonCallbackReq: toCommonCallback(ctx, msg, cbapi.CallbackBeforeMsgModifyCommand),
}
resp := &cbapi.CallbackMsgModifyCommandResp{}
if err := m.webhookClient.SyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, resp, before); err != nil {
return err
}
if beforeMsgData != nil {
*beforeMsgData = proto.Clone(msg.MsgData).(*sdkws.MsgData)
}
if resp.Content != nil {
msg.MsgData.Content = []byte(*resp.Content)
if err := json.Unmarshal(msg.MsgData.Content, &struct{}{}); err != nil {
return errs.ErrArgs.WrapMsg("webhook msg modify content is not json", "content", string(msg.MsgData.Content))
}
}
datautil.NotNilReplace(msg.MsgData.OfflinePushInfo, resp.OfflinePushInfo)
datautil.NotNilReplace(&msg.MsgData.RecvID, resp.RecvID)
datautil.NotNilReplace(&msg.MsgData.GroupID, resp.GroupID)
datautil.NotNilReplace(&msg.MsgData.ClientMsgID, resp.ClientMsgID)
datautil.NotNilReplace(&msg.MsgData.ServerMsgID, resp.ServerMsgID)
datautil.NotNilReplace(&msg.MsgData.SenderPlatformID, resp.SenderPlatformID)
datautil.NotNilReplace(&msg.MsgData.SenderNickname, resp.SenderNickname)
datautil.NotNilReplace(&msg.MsgData.SenderFaceURL, resp.SenderFaceURL)
datautil.NotNilReplace(&msg.MsgData.SessionType, resp.SessionType)
datautil.NotNilReplace(&msg.MsgData.MsgFrom, resp.MsgFrom)
datautil.NotNilReplace(&msg.MsgData.ContentType, resp.ContentType)
datautil.NotNilReplace(&msg.MsgData.Status, resp.Status)
datautil.NotNilReplace(&msg.MsgData.Options, resp.Options)
datautil.NotNilReplace(&msg.MsgData.AtUserIDList, resp.AtUserIDList)
datautil.NotNilReplace(&msg.MsgData.AttachedInfo, resp.AttachedInfo)
datautil.NotNilReplace(&msg.MsgData.Ex, resp.Ex)
return nil
})
}
func (m *msgServer) webhookAfterGroupMsgRead(ctx context.Context, after *config.AfterConfig, req *cbapi.CallbackGroupMsgReadReq) {
req.CallbackCommand = cbapi.CallbackAfterGroupMsgReadCommand
m.webhookClient.AsyncPost(ctx, req.GetCallbackCommand(), req, &cbapi.CallbackGroupMsgReadResp{}, after)
}
func (m *msgServer) webhookAfterSingleMsgRead(ctx context.Context, after *config.AfterConfig, req *cbapi.CallbackSingleMsgReadReq) {
req.CallbackCommand = cbapi.CallbackAfterSingleMsgReadCommand
m.webhookClient.AsyncPost(ctx, req.GetCallbackCommand(), req, &cbapi.CallbackSingleMsgReadResp{}, after)
}
func (m *msgServer) webhookAfterRevokeMsg(ctx context.Context, after *config.AfterConfig, req *pbchat.RevokeMsgReq) {
callbackReq := &cbapi.CallbackAfterRevokeMsgReq{
CallbackCommand: cbapi.CallbackAfterRevokeMsgCommand,
ConversationID: req.ConversationID,
Seq: req.Seq,
UserID: req.UserID,
}
m.webhookClient.AsyncPost(ctx, callbackReq.GetCallbackCommand(), callbackReq, &cbapi.CallbackAfterRevokeMsgResp{}, after)
}
func buildKeyMsgDataQuery(msg *sdkws.MsgData) map[string]string {
keyMsgData := apistruct.KeyMsgData{
SendID: msg.SendID,
RecvID: msg.RecvID,
GroupID: msg.GroupID,
}
return map[string]string{
webhook.Key: base64.StdEncoding.EncodeToString(stringutil.StructToJsonBytes(keyMsgData)),
}
}

61
internal/rpc/msg/clear.go Normal file
View File

@@ -0,0 +1,61 @@
package msg
import (
"context"
"strings"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/protocol/msg"
"github.com/openimsdk/tools/log"
)
// DestructMsgs hard delete in Database.
func (m *msgServer) DestructMsgs(ctx context.Context, req *msg.DestructMsgsReq) (*msg.DestructMsgsResp, error) {
if err := authverify.CheckAdmin(ctx); err != nil {
return nil, err
}
docs, err := m.MsgDatabase.GetRandBeforeMsg(ctx, req.Timestamp, int(req.Limit))
if err != nil {
return nil, err
}
for i, doc := range docs {
if err := m.MsgDatabase.DeleteDoc(ctx, doc.DocID); err != nil {
return nil, err
}
log.ZDebug(ctx, "DestructMsgs delete doc", "index", i, "docID", doc.DocID)
index := strings.LastIndex(doc.DocID, ":")
if index < 0 {
continue
}
var minSeq int64
for _, model := range doc.Msg {
if model.Msg == nil {
continue
}
if model.Msg.Seq > minSeq {
minSeq = model.Msg.Seq
}
}
if minSeq <= 0 {
continue
}
conversationID := doc.DocID[:index]
if conversationID == "" {
continue
}
minSeq++
if err := m.MsgDatabase.SetMinSeq(ctx, conversationID, minSeq); err != nil {
return nil, err
}
log.ZDebug(ctx, "DestructMsgs delete doc set min seq", "index", i, "docID", doc.DocID, "conversationID", conversationID, "setMinSeq", minSeq)
}
return &msg.DestructMsgsResp{Count: int32(len(docs))}, nil
}
func (m *msgServer) GetLastMessageSeqByTime(ctx context.Context, req *msg.GetLastMessageSeqByTimeReq) (*msg.GetLastMessageSeqByTimeResp, error) {
seq, err := m.MsgDatabase.GetLastMessageSeqByTime(ctx, req.ConversationID, req.Time)
if err != nil {
return nil, err
}
return &msg.GetLastMessageSeqByTimeResp{Seq: seq}, nil
}

251
internal/rpc/msg/delete.go Normal file
View File

@@ -0,0 +1,251 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msg
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/conversation"
"git.imall.cloud/openim/protocol/msg"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/datautil"
"github.com/openimsdk/tools/utils/timeutil"
)
func (m *msgServer) getMinSeqs(maxSeqs map[string]int64) map[string]int64 {
minSeqs := make(map[string]int64)
for k, v := range maxSeqs {
minSeqs[k] = v + 1
}
return minSeqs
}
func (m *msgServer) validateDeleteSyncOpt(opt *msg.DeleteSyncOpt) (isSyncSelf, isSyncOther bool) {
if opt == nil {
return
}
return opt.IsSyncSelf, opt.IsSyncOther
}
func (m *msgServer) ClearConversationsMsg(ctx context.Context, req *msg.ClearConversationsMsgReq) (*msg.ClearConversationsMsgResp, error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
if err := m.clearConversation(ctx, req.ConversationIDs, req.UserID, req.DeleteSyncOpt); err != nil {
return nil, err
}
return &msg.ClearConversationsMsgResp{}, nil
}
func (m *msgServer) UserClearAllMsg(ctx context.Context, req *msg.UserClearAllMsgReq) (*msg.UserClearAllMsgResp, error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
conversationIDs, err := m.ConversationLocalCache.GetConversationIDs(ctx, req.UserID)
if err != nil {
return nil, err
}
if err := m.clearConversation(ctx, conversationIDs, req.UserID, req.DeleteSyncOpt); err != nil {
return nil, err
}
return &msg.UserClearAllMsgResp{}, nil
}
func (m *msgServer) DeleteMsgs(ctx context.Context, req *msg.DeleteMsgsReq) (*msg.DeleteMsgsResp, error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
// 获取要删除的消息信息,用于权限检查
_, _, msgs, err := m.MsgDatabase.GetMsgBySeqs(ctx, req.UserID, req.ConversationID, req.Seqs)
if err != nil {
return nil, err
}
if len(msgs) == 0 {
return nil, errs.ErrRecordNotFound.WrapMsg("messages not found")
}
// 权限检查:如果不是管理员,需要检查删除权限
if !authverify.IsAdmin(ctx) {
// 收集所有消息的发送者ID
sendIDs := make([]string, 0, len(msgs))
for _, msg := range msgs {
if msg != nil && msg.SendID != "" {
sendIDs = append(sendIDs, msg.SendID)
}
}
sendIDs = datautil.Distinct(sendIDs)
// 检查第一条消息的会话类型(假设所有消息来自同一会话)
sessionType := msgs[0].SessionType
switch sessionType {
case constant.SingleChatType:
// 单聊:只能删除自己发送的消息
for _, msg := range msgs {
if msg != nil && msg.SendID != req.UserID {
return nil, errs.ErrNoPermission.WrapMsg("can only delete own messages in single chat")
}
}
case constant.ReadGroupChatType:
// 群聊:检查权限
groupID := msgs[0].GroupID
if groupID == "" {
return nil, errs.ErrArgs.WrapMsg("groupID is empty")
}
// 获取操作者和所有消息发送者的群成员信息
allUserIDs := append([]string{req.UserID}, sendIDs...)
members, err := m.GroupLocalCache.GetGroupMemberInfoMap(ctx, groupID, datautil.Distinct(allUserIDs))
if err != nil {
return nil, err
}
// 检查操作者的角色
opMember, ok := members[req.UserID]
if !ok {
return nil, errs.ErrNoPermission.WrapMsg("user not in group")
}
// 检查每条消息的删除权限
for _, msg := range msgs {
if msg == nil || msg.SendID == "" {
continue
}
// 如果是自己发送的消息,可以删除
if msg.SendID == req.UserID {
continue
}
// 如果不是自己发送的消息,需要检查权限
switch opMember.RoleLevel {
case constant.GroupOwner:
// 群主可以删除任何人的消息
case constant.GroupAdmin:
// 管理员只能删除普通成员的消息
sendMember, ok := members[msg.SendID]
if !ok {
return nil, errs.ErrNoPermission.WrapMsg("message sender not in group")
}
if sendMember.RoleLevel != constant.GroupOrdinaryUsers {
return nil, errs.ErrNoPermission.WrapMsg("group admin can only delete messages from ordinary members")
}
default:
// 普通成员只能删除自己的消息
return nil, errs.ErrNoPermission.WrapMsg("can only delete own messages")
}
}
default:
return nil, errs.ErrInternalServer.WrapMsg("sessionType not supported", "sessionType", sessionType)
}
}
isSyncSelf, isSyncOther := m.validateDeleteSyncOpt(req.DeleteSyncOpt)
if isSyncOther {
if err := m.MsgDatabase.DeleteMsgsPhysicalBySeqs(ctx, req.ConversationID, req.Seqs); err != nil {
return nil, err
}
conv, err := m.conversationClient.GetConversationsByConversationID(ctx, req.ConversationID)
if err != nil {
return nil, err
}
tips := &sdkws.DeleteMsgsTips{UserID: req.UserID, ConversationID: req.ConversationID, Seqs: req.Seqs}
m.notificationSender.NotificationWithSessionType(ctx, req.UserID, m.conversationAndGetRecvID(conv, req.UserID),
constant.DeleteMsgsNotification, conv.ConversationType, tips)
} else {
if err := m.MsgDatabase.DeleteUserMsgsBySeqs(ctx, req.UserID, req.ConversationID, req.Seqs); err != nil {
return nil, err
}
if isSyncSelf {
tips := &sdkws.DeleteMsgsTips{UserID: req.UserID, ConversationID: req.ConversationID, Seqs: req.Seqs}
m.notificationSender.NotificationWithSessionType(ctx, req.UserID, req.UserID, constant.DeleteMsgsNotification, constant.SingleChatType, tips)
}
}
return &msg.DeleteMsgsResp{}, nil
}
func (m *msgServer) DeleteMsgPhysicalBySeq(ctx context.Context, req *msg.DeleteMsgPhysicalBySeqReq) (*msg.DeleteMsgPhysicalBySeqResp, error) {
if err := authverify.CheckAdmin(ctx); err != nil {
return nil, err
}
err := m.MsgDatabase.DeleteMsgsPhysicalBySeqs(ctx, req.ConversationID, req.Seqs)
if err != nil {
return nil, err
}
return &msg.DeleteMsgPhysicalBySeqResp{}, nil
}
func (m *msgServer) DeleteMsgPhysical(ctx context.Context, req *msg.DeleteMsgPhysicalReq) (*msg.DeleteMsgPhysicalResp, error) {
if err := authverify.CheckAdmin(ctx); err != nil {
return nil, err
}
remainTime := timeutil.GetCurrentTimestampBySecond() - req.Timestamp
if _, err := m.DestructMsgs(ctx, &msg.DestructMsgsReq{Timestamp: remainTime, Limit: 9999}); err != nil {
return nil, err
}
return &msg.DeleteMsgPhysicalResp{}, nil
}
func (m *msgServer) clearConversation(ctx context.Context, conversationIDs []string, userID string, deleteSyncOpt *msg.DeleteSyncOpt) error {
conversations, err := m.conversationClient.GetConversationsByConversationIDs(ctx, conversationIDs)
if err != nil {
return err
}
var existConversations []*conversation.Conversation
var existConversationIDs []string
for _, conversation := range conversations {
existConversations = append(existConversations, conversation)
existConversationIDs = append(existConversationIDs, conversation.ConversationID)
}
log.ZDebug(ctx, "ClearConversationsMsg", "existConversationIDs", existConversationIDs)
maxSeqs, err := m.MsgDatabase.GetMaxSeqs(ctx, existConversationIDs)
if err != nil {
return err
}
isSyncSelf, isSyncOther := m.validateDeleteSyncOpt(deleteSyncOpt)
if !isSyncOther {
setSeqs := m.getMinSeqs(maxSeqs)
if err := m.MsgDatabase.SetUserConversationsMinSeqs(ctx, userID, setSeqs); err != nil {
return err
}
ownerUserIDs := []string{userID}
for conversationID, seq := range setSeqs {
if err := m.conversationClient.SetConversationMinSeq(ctx, conversationID, ownerUserIDs, seq); err != nil {
return err
}
}
// notification 2 self
if isSyncSelf {
tips := &sdkws.ClearConversationTips{UserID: userID, ConversationIDs: existConversationIDs}
m.notificationSender.NotificationWithSessionType(ctx, userID, userID, constant.ClearConversationNotification, constant.SingleChatType, tips)
}
} else {
if err := m.MsgDatabase.SetMinSeqs(ctx, m.getMinSeqs(maxSeqs)); err != nil {
return err
}
for _, conversation := range existConversations {
tips := &sdkws.ClearConversationTips{UserID: userID, ConversationIDs: []string{conversation.ConversationID}}
m.notificationSender.NotificationWithSessionType(ctx, userID, m.conversationAndGetRecvID(conversation, userID), constant.ClearConversationNotification, conversation.ConversationType, tips)
}
}
if err := m.MsgDatabase.UserSetHasReadSeqs(ctx, userID, maxSeqs); err != nil {
return err
}
return nil
}

106
internal/rpc/msg/filter.go Normal file
View File

@@ -0,0 +1,106 @@
package msg
import (
"strconv"
"strings"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/protocol/constant"
pbchat "git.imall.cloud/openim/protocol/msg"
"github.com/openimsdk/tools/utils/datautil"
)
const (
separator = "-"
)
func filterAfterMsg(msg *pbchat.SendMsgReq, after *config.AfterConfig) bool {
result := filterMsg(msg, after.AttentionIds, after.DeniedTypes)
// 添加调试日志
if !result {
// 只在过滤掉时记录,避免日志过多
}
return result
}
func filterBeforeMsg(msg *pbchat.SendMsgReq, before *config.BeforeConfig) bool {
return filterMsg(msg, nil, before.DeniedTypes)
}
func filterMsg(msg *pbchat.SendMsgReq, attentionIds []string, deniedTypes []int32) bool {
// According to the attentionIds configuration, only some users are sent
// 注意对于群消息应该检查GroupID而不是RecvID
if len(attentionIds) != 0 {
// 单聊消息检查RecvID群聊消息检查GroupID
if msg.MsgData.SessionType == constant.SingleChatType {
if !datautil.Contain(msg.MsgData.RecvID, attentionIds...) {
return false
}
} else if msg.MsgData.SessionType == constant.ReadGroupChatType {
if !datautil.Contain(msg.MsgData.GroupID, attentionIds...) {
return false
}
}
}
if defaultDeniedTypes(msg.MsgData.ContentType) {
return false
}
if len(deniedTypes) != 0 && datautil.Contain(msg.MsgData.ContentType, deniedTypes...) {
return false
}
//if len(allowedTypes) != 0 && !isInInterval(msg.MsgData.ContentType, allowedTypes) {
// return false
//}
//if len(deniedTypes) != 0 && isInInterval(msg.MsgData.ContentType, deniedTypes) {
// return false
//}
return true
}
func defaultDeniedTypes(contentType int32) bool {
if contentType >= constant.NotificationBegin && contentType <= constant.NotificationEnd {
return true
}
if contentType == constant.Typing {
return true
}
return false
}
// isInInterval if data is in interval
// Supports two formats: a single type or a range. The range is defined by the lower and upper bounds connected with a hyphen ("-")
// e.g. [1, 100, 200-500, 600-700] means that only data within the range
// {1, 100} [200, 500] [600, 700] will return true.
func isInInterval(data int32, interval []string) bool {
for _, v := range interval {
if strings.Contains(v, separator) {
// is interval
bounds := strings.Split(v, separator)
if len(bounds) != 2 {
continue
}
bottom, err := strconv.Atoi(bounds[0])
if err != nil {
continue
}
top, err := strconv.Atoi(bounds[1])
if err != nil {
continue
}
if datautil.BetweenEq(int(data), bottom, top) {
return true
}
} else {
iv, err := strconv.Atoi(v)
if err != nil {
continue
}
if int(data) == iv {
return true
}
}
}
return false
}

View File

@@ -0,0 +1,44 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msg
import (
"context"
"git.imall.cloud/openim/protocol/constant"
pbmsg "git.imall.cloud/openim/protocol/msg"
"github.com/openimsdk/tools/mcontext"
)
func (m *msgServer) SetSendMsgStatus(ctx context.Context, req *pbmsg.SetSendMsgStatusReq) (*pbmsg.SetSendMsgStatusResp, error) {
resp := &pbmsg.SetSendMsgStatusResp{}
if err := m.MsgDatabase.SetSendMsgStatus(ctx, mcontext.GetOperationID(ctx), req.Status); err != nil {
return nil, err
}
return resp, nil
}
func (m *msgServer) GetSendMsgStatus(ctx context.Context, req *pbmsg.GetSendMsgStatusReq) (*pbmsg.GetSendMsgStatusResp, error) {
resp := &pbmsg.GetSendMsgStatusResp{}
status, err := m.MsgDatabase.GetSendMsgStatus(ctx, mcontext.GetOperationID(ctx))
if IsNotFound(err) {
resp.Status = constant.MsgStatusNotExist
return resp, nil
} else if err != nil {
return nil, err
}
resp.Status = status
return resp, nil
}

View File

@@ -0,0 +1,50 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msg
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/notification"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/sdkws"
)
type MsgNotificationSender struct {
*notification.NotificationSender
}
func NewMsgNotificationSender(config *Config, opts ...notification.NotificationSenderOptions) *MsgNotificationSender {
return &MsgNotificationSender{notification.NewNotificationSender(&config.NotificationConfig, opts...)}
}
func (m *MsgNotificationSender) UserDeleteMsgsNotification(ctx context.Context, userID, conversationID string, seqs []int64) {
tips := sdkws.DeleteMsgsTips{
UserID: userID,
ConversationID: conversationID,
Seqs: seqs,
}
m.Notification(ctx, userID, userID, constant.DeleteMsgsNotification, &tips)
}
func (m *MsgNotificationSender) MarkAsReadNotification(ctx context.Context, conversationID string, sessionType int32, sendID, recvID string, seqs []int64, hasReadSeq int64) {
tips := &sdkws.MarkAsReadTips{
MarkAsReadUserID: sendID,
ConversationID: conversationID,
Seqs: seqs,
HasReadSeq: hasReadSeq,
}
m.NotificationWithSessionType(ctx, sendID, recvID, constant.HasReadReceipt, sessionType, tips)
}

View File

@@ -0,0 +1,971 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msg
import (
"context"
"fmt"
"image"
"image/color"
_ "image/jpeg"
_ "image/png"
"math"
"os"
"sync"
"github.com/makiuchi-d/gozxing"
"github.com/makiuchi-d/gozxing/qrcode"
"github.com/openimsdk/tools/log"
)
// openImageFile 打开图片文件
func openImageFile(imagePath string) (*os.File, error) {
file, err := os.Open(imagePath)
if err != nil {
return nil, fmt.Errorf("无法打开文件: %v", err)
}
return file, nil
}
// QRDecoder 二维码解码器接口
type QRDecoder interface {
Name() string
Decode(ctx context.Context, imagePath string, logPrefix string) (bool, error) // 返回是否检测到二维码
}
// ============================================================================
// QuircDecoder - Quirc解码器包装
// ============================================================================
// QuircDecoder 使用 Quirc 库的解码器
type QuircDecoder struct {
detectFunc func([]uint8, int, int) (bool, error)
}
// NewQuircDecoder 创建Quirc解码器
func NewQuircDecoder(detectFunc func([]uint8, int, int) (bool, error)) *QuircDecoder {
return &QuircDecoder{detectFunc: detectFunc}
}
func (d *QuircDecoder) Name() string {
return "quirc"
}
func (d *QuircDecoder) Decode(ctx context.Context, imagePath string, logPrefix string) (bool, error) {
if d.detectFunc == nil {
return false, fmt.Errorf("quirc 解码器未启用")
}
// 打开并解码图片
file, err := openImageFile(imagePath)
if err != nil {
log.ZError(ctx, "打开图片文件失败", err, "imagePath", imagePath)
return false, err
}
defer file.Close()
img, _, err := image.Decode(file)
if err != nil {
log.ZError(ctx, "解码图片失败", err, "imagePath", imagePath)
return false, fmt.Errorf("无法解码图片: %v", err)
}
bounds := img.Bounds()
width := bounds.Dx()
height := bounds.Dy()
// 转换为灰度图
grayData := convertToGrayscale(img, width, height)
// 调用Quirc检测
hasQRCode, err := d.detectFunc(grayData, width, height)
if err != nil {
log.ZError(ctx, "Quirc检测失败", err, "width", width, "height", height)
return false, err
}
return hasQRCode, nil
}
// ============================================================================
// CustomQRDecoder - 自定义解码器(兼容圆形角)
// ============================================================================
// CustomQRDecoder 自定义二维码解码器,兼容圆形角等特殊格式
type CustomQRDecoder struct{}
func (d *CustomQRDecoder) Name() string {
return "custom (圆形角兼容)"
}
// Decode 解码二维码,返回是否检测到二维码
func (d *CustomQRDecoder) Decode(ctx context.Context, imagePath string, logPrefix string) (bool, error) {
file, err := openImageFile(imagePath)
if err != nil {
log.ZError(ctx, "打开图片文件失败", err, "imagePath", imagePath)
return false, fmt.Errorf("无法打开文件: %v", err)
}
defer file.Close()
img, _, err := image.Decode(file)
if err != nil {
log.ZError(ctx, "解码图片失败", err, "imagePath", imagePath)
return false, fmt.Errorf("无法解码图片: %v", err)
}
bounds := img.Bounds()
width := bounds.Dx()
height := bounds.Dy()
reader := qrcode.NewQRCodeReader()
hints := make(map[gozxing.DecodeHintType]interface{})
hints[gozxing.DecodeHintType_TRY_HARDER] = true
hints[gozxing.DecodeHintType_PURE_BARCODE] = false
hints[gozxing.DecodeHintType_CHARACTER_SET] = "UTF-8"
// 尝试直接解码
bitmap, err := gozxing.NewBinaryBitmapFromImage(img)
if err == nil {
if _, err := reader.Decode(bitmap, hints); err == nil {
return true, nil
}
// 尝试不使用PURE_BARCODE
delete(hints, gozxing.DecodeHintType_PURE_BARCODE)
if _, err := reader.Decode(bitmap, hints); err == nil {
return true, nil
}
hints[gozxing.DecodeHintType_PURE_BARCODE] = false
}
// 尝试多尺度缩放
scales := []float64{1.0, 1.5, 2.0, 0.75, 0.5}
for _, scale := range scales {
scaledImg := scaleImage(img, width, height, scale)
if scaledImg == nil {
continue
}
scaledBitmap, err := gozxing.NewBinaryBitmapFromImage(scaledImg)
if err == nil {
if _, err := reader.Decode(scaledBitmap, hints); err == nil {
return true, nil
}
}
}
// 转换为灰度图进行预处理
grayData := convertToGrayscale(img, width, height)
// 尝试多种预处理方法
preprocessMethods := []struct {
name string
fn func([]byte, int, int) []byte
}{
{"Otsu二值化", enhanceImageOtsu},
{"标准增强", enhanceImage},
{"强对比度", enhanceImageStrong},
{"圆形角处理", enhanceImageForRoundedCorners},
{"去噪+锐化", enhanceImageDenoiseSharpen},
{"高斯模糊+锐化", enhanceImageGaussianSharpen},
}
scalesForPreprocessed := []float64{1.0, 2.0, 1.5, 1.2, 0.8}
for _, method := range preprocessMethods {
processed := method.fn(grayData, width, height)
// 快速检测定位图案
corners := detectCornersFast(processed, width, height)
if len(corners) < 2 {
// 如果没有检测到足够的定位图案,仍然尝试解码
}
processedImg := createImageFromGrayscale(processed, width, height)
bitmap2, err := gozxing.NewBinaryBitmapFromImage(processedImg)
if err == nil {
if _, err := reader.Decode(bitmap2, hints); err == nil {
return true, nil
}
delete(hints, gozxing.DecodeHintType_PURE_BARCODE)
if _, err := reader.Decode(bitmap2, hints); err == nil {
return true, nil
}
hints[gozxing.DecodeHintType_PURE_BARCODE] = false
}
// 对预处理后的图像进行多尺度缩放
for _, scale := range scalesForPreprocessed {
scaledProcessed := scaleGrayscaleImage(processed, width, height, scale)
if scaledProcessed == nil {
continue
}
scaledImg := createImageFromGrayscale(scaledProcessed.data, scaledProcessed.width, scaledProcessed.height)
scaledBitmap, err := gozxing.NewBinaryBitmapFromImage(scaledImg)
if err == nil {
if _, err := reader.Decode(scaledBitmap, hints); err == nil {
return true, nil
}
delete(hints, gozxing.DecodeHintType_PURE_BARCODE)
if _, err := reader.Decode(scaledBitmap, hints); err == nil {
return true, nil
}
hints[gozxing.DecodeHintType_PURE_BARCODE] = false
}
}
}
return false, nil
}
// ============================================================================
// ParallelQRDecoder - 并行解码器
// ============================================================================
// ParallelQRDecoder 并行解码器,同时运行 quirc 和 custom 解码器
type ParallelQRDecoder struct {
quircDecoder QRDecoder
customDecoder QRDecoder
}
// NewParallelQRDecoder 创建并行解码器
func NewParallelQRDecoder(quircDecoder, customDecoder QRDecoder) *ParallelQRDecoder {
return &ParallelQRDecoder{
quircDecoder: quircDecoder,
customDecoder: customDecoder,
}
}
func (d *ParallelQRDecoder) Name() string {
return "parallel (quirc + custom)"
}
// Decode 并行解码:同时运行 quirc 和 custom任一成功立即返回
func (d *ParallelQRDecoder) Decode(ctx context.Context, imagePath string, logPrefix string) (bool, error) {
type decodeResult struct {
hasQRCode bool
err error
name string
}
resultChan := make(chan decodeResult, 2)
var wg sync.WaitGroup
var mu sync.Mutex
var quircErr error
var customErr error
// 启动Quirc解码
if d.quircDecoder != nil {
wg.Add(1)
go func() {
defer wg.Done()
hasQRCode, err := d.quircDecoder.Decode(ctx, imagePath, logPrefix)
mu.Lock()
if err != nil {
quircErr = err
}
mu.Unlock()
resultChan <- decodeResult{
hasQRCode: hasQRCode,
err: err,
name: d.quircDecoder.Name(),
}
}()
}
// 启动Custom解码
if d.customDecoder != nil {
wg.Add(1)
go func() {
defer wg.Done()
hasQRCode, err := d.customDecoder.Decode(ctx, imagePath, logPrefix)
mu.Lock()
if err != nil {
customErr = err
}
mu.Unlock()
resultChan <- decodeResult{
hasQRCode: hasQRCode,
err: err,
name: d.customDecoder.Name(),
}
}()
}
// 等待第一个结果
var firstResult decodeResult
var secondResult decodeResult
firstResult = <-resultChan
if firstResult.hasQRCode {
// 如果检测到二维码,立即返回
go func() {
<-resultChan
wg.Wait()
}()
return true, nil
}
// 等待第二个结果
if d.quircDecoder != nil && d.customDecoder != nil {
secondResult = <-resultChan
if secondResult.hasQRCode {
wg.Wait()
return true, nil
}
}
wg.Wait()
// 如果都失败,返回错误
if firstResult.err != nil && secondResult.err != nil {
log.ZError(ctx, "并行解码失败,两个解码器都失败", fmt.Errorf("quirc错误=%v, custom错误=%v", quircErr, customErr),
"quircError", quircErr,
"customError", customErr)
return false, fmt.Errorf("quirc 和 custom 都解码失败: quirc错误=%v, custom错误=%v", quircErr, customErr)
}
return false, nil
}
// ============================================================================
// 辅助函数
// ============================================================================
// convertToGrayscale 转换为灰度图
func convertToGrayscale(img image.Image, width, height int) []byte {
grayData := make([]byte, width*height)
bounds := img.Bounds()
if ycbcr, ok := img.(*image.YCbCr); ok {
for y := 0; y < height; y++ {
for x := 0; x < width; x++ {
yi := ycbcr.YOffset(x+bounds.Min.X, y+bounds.Min.Y)
grayData[y*width+x] = ycbcr.Y[yi]
}
}
return grayData
}
for y := 0; y < height; y++ {
for x := 0; x < width; x++ {
r, g, b, _ := img.At(x+bounds.Min.X, y+bounds.Min.Y).RGBA()
r8 := uint8(r >> 8)
g8 := uint8(g >> 8)
b8 := uint8(b >> 8)
gray := byte((uint16(r8)*299 + uint16(g8)*587 + uint16(b8)*114) / 1000)
grayData[y*width+x] = gray
}
}
return grayData
}
// enhanceImage 图像增强(标准方法)
func enhanceImage(data []byte, width, height int) []byte {
enhanced := make([]byte, len(data))
copy(enhanced, data)
minVal := uint8(255)
maxVal := uint8(0)
for _, v := range data {
if v < minVal {
minVal = v
}
if v > maxVal {
maxVal = v
}
}
if maxVal-minVal < 50 {
rangeVal := maxVal - minVal
if rangeVal == 0 {
rangeVal = 1
}
for i, v := range data {
stretched := uint8((uint16(v-minVal) * 255) / uint16(rangeVal))
enhanced[i] = stretched
}
}
return enhanced
}
// enhanceImageStrong 强对比度增强
func enhanceImageStrong(data []byte, width, height int) []byte {
enhanced := make([]byte, len(data))
histogram := make([]int, 256)
for _, v := range data {
histogram[v]++
}
cdf := make([]int, 256)
cdf[0] = histogram[0]
for i := 1; i < 256; i++ {
cdf[i] = cdf[i-1] + histogram[i]
}
total := len(data)
for i, v := range data {
if total > 0 {
enhanced[i] = uint8((cdf[v] * 255) / total)
}
}
return enhanced
}
// enhanceImageForRoundedCorners 针对圆形角的特殊处理
func enhanceImageForRoundedCorners(data []byte, width, height int) []byte {
enhanced := make([]byte, len(data))
copy(enhanced, data)
minVal := uint8(255)
maxVal := uint8(0)
for _, v := range data {
if v < minVal {
minVal = v
}
if v > maxVal {
maxVal = v
}
}
if maxVal-minVal < 100 {
rangeVal := maxVal - minVal
if rangeVal == 0 {
rangeVal = 1
}
for i, v := range data {
stretched := uint8((uint16(v-minVal) * 255) / uint16(rangeVal))
enhanced[i] = stretched
}
}
// 形态学操作:先腐蚀后膨胀
dilated := make([]byte, len(enhanced))
kernelSize := 3
halfKernel := kernelSize / 2
// 腐蚀(最小值滤波)
for y := halfKernel; y < height-halfKernel; y++ {
for x := halfKernel; x < width-halfKernel; x++ {
minVal := uint8(255)
for ky := -halfKernel; ky <= halfKernel; ky++ {
for kx := -halfKernel; kx <= halfKernel; kx++ {
idx := (y+ky)*width + (x + kx)
if enhanced[idx] < minVal {
minVal = enhanced[idx]
}
}
}
dilated[y*width+x] = minVal
}
}
// 膨胀(最大值滤波)
for y := halfKernel; y < height-halfKernel; y++ {
for x := halfKernel; x < width-halfKernel; x++ {
maxVal := uint8(0)
for ky := -halfKernel; ky <= halfKernel; ky++ {
for kx := -halfKernel; kx <= halfKernel; kx++ {
idx := (y+ky)*width + (x + kx)
if dilated[idx] > maxVal {
maxVal = dilated[idx]
}
}
}
enhanced[y*width+x] = maxVal
}
}
// 边界保持原值
for y := 0; y < height; y++ {
for x := 0; x < width; x++ {
if y < halfKernel || y >= height-halfKernel || x < halfKernel || x >= width-halfKernel {
enhanced[y*width+x] = data[y*width+x]
}
}
}
return enhanced
}
// enhanceImageDenoiseSharpen 去噪+锐化处理
func enhanceImageDenoiseSharpen(data []byte, width, height int) []byte {
denoised := medianFilter(data, width, height, 3)
sharpened := sharpenImage(denoised, width, height)
return sharpened
}
// medianFilter 中值滤波去噪
func medianFilter(data []byte, width, height, kernelSize int) []byte {
filtered := make([]byte, len(data))
halfKernel := kernelSize / 2
kernelArea := kernelSize * kernelSize
values := make([]byte, kernelArea)
for y := halfKernel; y < height-halfKernel; y++ {
for x := halfKernel; x < width-halfKernel; x++ {
idx := 0
for ky := -halfKernel; ky <= halfKernel; ky++ {
for kx := -halfKernel; kx <= halfKernel; kx++ {
values[idx] = data[(y+ky)*width+(x+kx)]
idx++
}
}
filtered[y*width+x] = quickSelectMedian(values)
}
}
// 边界保持原值
for y := 0; y < height; y++ {
for x := 0; x < width; x++ {
if y < halfKernel || y >= height-halfKernel || x < halfKernel || x >= width-halfKernel {
filtered[y*width+x] = data[y*width+x]
}
}
}
return filtered
}
// quickSelectMedian 快速选择中值
func quickSelectMedian(arr []byte) byte {
n := len(arr)
if n <= 7 {
// 小数组使用插入排序
for i := 1; i < n; i++ {
key := arr[i]
j := i - 1
for j >= 0 && arr[j] > key {
arr[j+1] = arr[j]
j--
}
arr[j+1] = key
}
return arr[n/2]
}
return quickSelect(arr, 0, n-1, n/2)
}
// quickSelect 快速选择第k小的元素
func quickSelect(arr []byte, left, right, k int) byte {
if left == right {
return arr[left]
}
pivotIndex := partition(arr, left, right)
if k == pivotIndex {
return arr[k]
} else if k < pivotIndex {
return quickSelect(arr, left, pivotIndex-1, k)
}
return quickSelect(arr, pivotIndex+1, right, k)
}
func partition(arr []byte, left, right int) int {
pivot := arr[right]
i := left
for j := left; j < right; j++ {
if arr[j] <= pivot {
arr[i], arr[j] = arr[j], arr[i]
i++
}
}
arr[i], arr[right] = arr[right], arr[i]
return i
}
// sharpenImage 锐化处理
func sharpenImage(data []byte, width, height int) []byte {
sharpened := make([]byte, len(data))
kernel := []int{0, -1, 0, -1, 5, -1, 0, -1, 0}
for y := 1; y < height-1; y++ {
for x := 1; x < width-1; x++ {
sum := 0
idx := 0
for ky := -1; ky <= 1; ky++ {
for kx := -1; kx <= 1; kx++ {
pixelIdx := (y+ky)*width + (x + kx)
sum += int(data[pixelIdx]) * kernel[idx]
idx++
}
}
if sum < 0 {
sum = 0
}
if sum > 255 {
sum = 255
}
sharpened[y*width+x] = uint8(sum)
}
}
// 边界保持原值
for y := 0; y < height; y++ {
for x := 0; x < width; x++ {
if y == 0 || y == height-1 || x == 0 || x == width-1 {
sharpened[y*width+x] = data[y*width+x]
}
}
}
return sharpened
}
// enhanceImageOtsu Otsu自适应阈值二值化
func enhanceImageOtsu(data []byte, width, height int) []byte {
threshold := calculateOtsuThreshold(data)
binary := make([]byte, len(data))
for i := range data {
if data[i] > threshold {
binary[i] = 255
}
}
return binary
}
// calculateOtsuThreshold 计算Otsu自适应阈值
func calculateOtsuThreshold(data []byte) uint8 {
histogram := make([]int, 256)
for _, v := range data {
histogram[v]++
}
total := len(data)
if total == 0 {
return 128
}
var threshold uint8
var maxVar float64
var sum int
for i := 0; i < 256; i++ {
sum += i * histogram[i]
}
var sum1 int
var wB int
for i := 0; i < 256; i++ {
wB += histogram[i]
if wB == 0 {
continue
}
wF := total - wB
if wF == 0 {
break
}
sum1 += i * histogram[i]
mB := float64(sum1) / float64(wB)
mF := float64(sum-sum1) / float64(wF)
varBetween := float64(wB) * float64(wF) * (mB - mF) * (mB - mF)
if varBetween > maxVar {
maxVar = varBetween
threshold = uint8(i)
}
}
return threshold
}
// enhanceImageGaussianSharpen 高斯模糊+锐化
func enhanceImageGaussianSharpen(data []byte, width, height int) []byte {
blurred := gaussianBlur(data, width, height, 1.0)
sharpened := sharpenImage(blurred, width, height)
enhanced := enhanceImage(sharpened, width, height)
return enhanced
}
// gaussianBlur 高斯模糊
func gaussianBlur(data []byte, width, height int, sigma float64) []byte {
blurred := make([]byte, len(data))
kernelSize := 5
halfKernel := kernelSize / 2
kernel := make([]float64, kernelSize*kernelSize)
sum := 0.0
for y := -halfKernel; y <= halfKernel; y++ {
for x := -halfKernel; x <= halfKernel; x++ {
idx := (y+halfKernel)*kernelSize + (x + halfKernel)
val := math.Exp(-(float64(x*x+y*y) / (2 * sigma * sigma)))
kernel[idx] = val
sum += val
}
}
for i := range kernel {
kernel[i] /= sum
}
for y := halfKernel; y < height-halfKernel; y++ {
for x := halfKernel; x < width-halfKernel; x++ {
var val float64
idx := 0
for ky := -halfKernel; ky <= halfKernel; ky++ {
for kx := -halfKernel; kx <= halfKernel; kx++ {
pixelIdx := (y+ky)*width + (x + kx)
val += float64(data[pixelIdx]) * kernel[idx]
idx++
}
}
blurred[y*width+x] = uint8(val)
}
}
// 边界保持原值
for y := 0; y < height; y++ {
for x := 0; x < width; x++ {
if y < halfKernel || y >= height-halfKernel || x < halfKernel || x >= width-halfKernel {
blurred[y*width+x] = data[y*width+x]
}
}
}
return blurred
}
// createImageFromGrayscale 从灰度数据创建图像
func createImageFromGrayscale(data []byte, width, height int) image.Image {
img := image.NewGray(image.Rect(0, 0, width, height))
for y := 0; y < height; y++ {
rowStart := y * width
rowEnd := rowStart + width
if rowEnd > len(data) {
rowEnd = len(data)
}
copy(img.Pix[y*img.Stride:], data[rowStart:rowEnd])
}
return img
}
// scaledImage 缩放后的图像数据
type scaledImage struct {
data []byte
width int
height int
}
// scaleImage 缩放图像
func scaleImage(img image.Image, origWidth, origHeight int, scale float64) image.Image {
if scale == 1.0 {
return img
}
newWidth := int(float64(origWidth) * scale)
newHeight := int(float64(origHeight) * scale)
if newWidth < 50 || newHeight < 50 {
return nil
}
if newWidth > 1500 || newHeight > 1500 {
return nil
}
scaled := image.NewRGBA(image.Rect(0, 0, newWidth, newHeight))
bounds := img.Bounds()
for y := 0; y < newHeight; y++ {
for x := 0; x < newWidth; x++ {
srcX := float64(x) / scale
srcY := float64(y) / scale
x1 := int(srcX)
y1 := int(srcY)
x2 := x1 + 1
y2 := y1 + 1
if x2 >= bounds.Dx() {
x2 = bounds.Dx() - 1
}
if y2 >= bounds.Dy() {
y2 = bounds.Dy() - 1
}
fx := srcX - float64(x1)
fy := srcY - float64(y1)
c11 := getPixelColor(img, bounds.Min.X+x1, bounds.Min.Y+y1)
c12 := getPixelColor(img, bounds.Min.X+x1, bounds.Min.Y+y2)
c21 := getPixelColor(img, bounds.Min.X+x2, bounds.Min.Y+y1)
c22 := getPixelColor(img, bounds.Min.X+x2, bounds.Min.Y+y2)
r := uint8(float64(c11.R)*(1-fx)*(1-fy) + float64(c21.R)*fx*(1-fy) + float64(c12.R)*(1-fx)*fy + float64(c22.R)*fx*fy)
g := uint8(float64(c11.G)*(1-fx)*(1-fy) + float64(c21.G)*fx*(1-fy) + float64(c12.G)*(1-fx)*fy + float64(c22.G)*fx*fy)
b := uint8(float64(c11.B)*(1-fx)*(1-fy) + float64(c21.B)*fx*(1-fy) + float64(c12.B)*(1-fx)*fy + float64(c22.B)*fx*fy)
scaled.Set(x, y, color.RGBA{R: r, G: g, B: b, A: 255})
}
}
return scaled
}
// getPixelColor 获取像素颜色
func getPixelColor(img image.Image, x, y int) color.RGBA {
r, g, b, _ := img.At(x, y).RGBA()
return color.RGBA{
R: uint8(r >> 8),
G: uint8(g >> 8),
B: uint8(b >> 8),
A: 255,
}
}
// scaleGrayscaleImage 缩放灰度图像
func scaleGrayscaleImage(data []byte, origWidth, origHeight int, scale float64) *scaledImage {
if scale == 1.0 {
return &scaledImage{data: data, width: origWidth, height: origHeight}
}
newWidth := int(float64(origWidth) * scale)
newHeight := int(float64(origHeight) * scale)
if newWidth < 21 || newHeight < 21 || newWidth > 2000 || newHeight > 2000 {
return nil
}
scaled := make([]byte, newWidth*newHeight)
for y := 0; y < newHeight; y++ {
for x := 0; x < newWidth; x++ {
srcX := float64(x) / scale
srcY := float64(y) / scale
x1 := int(srcX)
y1 := int(srcY)
x2 := x1 + 1
y2 := y1 + 1
if x2 >= origWidth {
x2 = origWidth - 1
}
if y2 >= origHeight {
y2 = origHeight - 1
}
fx := srcX - float64(x1)
fy := srcY - float64(y1)
v11 := float64(data[y1*origWidth+x1])
v12 := float64(data[y2*origWidth+x1])
v21 := float64(data[y1*origWidth+x2])
v22 := float64(data[y2*origWidth+x2])
val := v11*(1-fx)*(1-fy) + v21*fx*(1-fy) + v12*(1-fx)*fy + v22*fx*fy
scaled[y*newWidth+x] = uint8(val)
}
}
return &scaledImage{data: scaled, width: newWidth, height: newHeight}
}
// Point 表示一个点
type Point struct {
X, Y int
}
// Corner 表示检测到的定位图案角点
type Corner struct {
Center Point
Size int
Type int
}
// detectCornersFast 快速检测定位图案
func detectCornersFast(data []byte, width, height int) []Corner {
var corners []Corner
scanStep := max(2, min(width, height)/80)
if scanStep < 1 {
scanStep = 1
}
for y := scanStep * 3; y < height-scanStep*3; y += scanStep {
for x := scanStep * 3; x < width-scanStep*3; x += scanStep {
if isFinderPatternFast(data, width, height, x, y) {
corners = append(corners, Corner{
Center: Point{X: x, Y: y},
Size: 20,
})
if len(corners) >= 3 {
return corners
}
}
}
}
return corners
}
// isFinderPatternFast 快速检测定位图案
func isFinderPatternFast(data []byte, width, height, x, y int) bool {
centerIdx := y*width + x
if centerIdx < 0 || centerIdx >= len(data) {
return false
}
if data[centerIdx] > 180 {
return false
}
radius := min(width, height) / 15
if radius < 3 {
radius = 3
}
if radius > 30 {
radius = 30
}
directions := []struct{ dx, dy int }{{radius, 0}, {-radius, 0}, {0, radius}, {0, -radius}}
blackCount := 0
whiteCount := 0
for _, dir := range directions {
nx := x + dir.dx
ny := y + dir.dy
if nx >= 0 && nx < width && ny >= 0 && ny < height {
idx := ny*width + nx
if idx >= 0 && idx < len(data) {
if data[idx] < 128 {
blackCount++
} else {
whiteCount++
}
}
}
}
return blackCount >= 2 && whiteCount >= 2
}
// 辅助函数
func max(a, b int) int {
if a > b {
return a
}
return b
}
func min(a, b int) int {
if a < b {
return a
}
return b
}

View File

@@ -0,0 +1,191 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msg
import (
"context"
"encoding/json"
"fmt"
_ "image/jpeg"
_ "image/png"
"io"
"net/http"
"os"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/servererrs"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
)
// PictureElem 用于解析图片消息内容
type PictureElem struct {
SourcePicture struct {
URL string `json:"url"`
} `json:"sourcePicture"`
BigPicture struct {
URL string `json:"url"`
} `json:"bigPicture"`
SnapshotPicture struct {
URL string `json:"url"`
} `json:"snapshotPicture"`
}
// checkImageContainsQRCode 检测图片中是否包含二维码
// userType: 0-普通用户不能发送包含二维码的图片1-特殊用户(可以发送)
func (m *msgServer) checkImageContainsQRCode(ctx context.Context, msgData *sdkws.MsgData, userType int32) error {
// userType=1 的用户可以发送包含二维码的图片,不进行检测
if userType == 1 {
return nil
}
// 只检测图片类型的消息
if msgData.ContentType != constant.Picture {
return nil
}
// 解析图片消息内容
var pictureElem PictureElem
if err := json.Unmarshal(msgData.Content, &pictureElem); err != nil {
// 如果解析失败,记录警告但不拦截
log.ZWarn(ctx, "failed to parse picture message", err, "content", string(msgData.Content))
return nil
}
// 获取图片URL优先使用原图如果没有则使用大图
imageURL := pictureElem.SourcePicture.URL
if imageURL == "" {
imageURL = pictureElem.BigPicture.URL
}
if imageURL == "" {
imageURL = pictureElem.SnapshotPicture.URL
}
if imageURL == "" {
// 没有有效的图片URL无法检测
log.ZWarn(ctx, "no valid image URL found in picture message", nil, "pictureElem", pictureElem)
return nil
}
// 下载图片并检测二维码
hasQRCode, err := m.detectQRCodeInImage(ctx, imageURL, "")
if err != nil {
// 检测失败时,记录错误但不拦截(避免误拦截)
log.ZWarn(ctx, "QR code detection failed", err, "imageURL", imageURL)
return nil
}
if hasQRCode {
log.ZWarn(ctx, "检测到二维码,拒绝发送", nil, "imageURL", imageURL, "userType", userType)
return servererrs.ErrImageContainsQRCode.WrapMsg("userType=0的用户不能发送包含二维码的图片")
}
return nil
}
// detectQRCodeInImage 下载图片并检测是否包含二维码
func (m *msgServer) detectQRCodeInImage(ctx context.Context, imageURL string, logPrefix string) (bool, error) {
// 创建带超时的HTTP客户端
client := &http.Client{
Timeout: 5 * time.Second,
}
// 下载图片
req, err := http.NewRequestWithContext(ctx, http.MethodGet, imageURL, nil)
if err != nil {
log.ZError(ctx, "创建HTTP请求失败", err, "imageURL", imageURL)
return false, errs.WrapMsg(err, "failed to create request")
}
resp, err := client.Do(req)
if err != nil {
log.ZError(ctx, "下载图片失败", err, "imageURL", imageURL)
return false, errs.WrapMsg(err, "failed to download image")
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
log.ZError(ctx, "下载图片状态码异常", nil, "statusCode", resp.StatusCode, "imageURL", imageURL)
return false, errs.WrapMsg(fmt.Errorf("unexpected status code: %d", resp.StatusCode), "failed to download image")
}
// 限制图片大小最大10MB
const maxImageSize = 10 * 1024 * 1024
limitedReader := io.LimitReader(resp.Body, maxImageSize+1)
// 创建临时文件
tmpFile, err := os.CreateTemp("", "qrcode_detect_*.tmp")
if err != nil {
log.ZError(ctx, "创建临时文件失败", err)
return false, errs.WrapMsg(err, "failed to create temp file")
}
tmpFilePath := tmpFile.Name()
// 确保检测完成后删除临时文件(无论成功还是失败)
defer func() {
// 确保文件已关闭后再删除
if tmpFile != nil {
_ = tmpFile.Close()
}
// 删除临时文件,忽略文件不存在的错误
if err := os.Remove(tmpFilePath); err != nil && !os.IsNotExist(err) {
log.ZWarn(ctx, "删除临时文件失败", err, "tmpFile", tmpFilePath)
}
}()
// 保存图片到临时文件
written, err := io.Copy(tmpFile, limitedReader)
if err != nil {
log.ZError(ctx, "保存图片到临时文件失败", err, "tmpFile", tmpFilePath)
return false, errs.WrapMsg(err, "failed to save image")
}
// 关闭文件以便后续读取
if err := tmpFile.Close(); err != nil {
log.ZError(ctx, "关闭临时文件失败", err, "tmpFile", tmpFilePath)
return false, errs.WrapMsg(err, "failed to close temp file")
}
// 检查文件大小
if written > maxImageSize {
log.ZWarn(ctx, "图片过大", nil, "size", written, "maxSize", maxImageSize)
return false, errs.WrapMsg(fmt.Errorf("image too large: %d bytes", written), "image size exceeds limit")
}
// 使用优化的并行解码器检测二维码
hasQRCode, err := m.detectQRCodeWithDecoder(ctx, tmpFilePath, "")
if err != nil {
log.ZError(ctx, "二维码检测失败", err, "tmpFile", tmpFilePath)
return false, err
}
return hasQRCode, nil
}
// detectQRCodeWithDecoder 使用优化的解码器检测二维码
func (m *msgServer) detectQRCodeWithDecoder(ctx context.Context, imagePath string, logPrefix string) (bool, error) {
// 使用Custom解码器已移除Quirc解码器依赖
customDecoder := &CustomQRDecoder{}
// 执行解码
hasQRCode, err := customDecoder.Decode(ctx, imagePath, logPrefix)
if err != nil {
log.ZError(ctx, "解码器检测失败", err, "decoder", customDecoder.Name())
return false, err
}
return hasQRCode, nil
}

134
internal/rpc/msg/revoke.go Normal file
View File

@@ -0,0 +1,134 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msg
import (
"context"
"encoding/json"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/servererrs"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/msg"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/datautil"
)
func (m *msgServer) RevokeMsg(ctx context.Context, req *msg.RevokeMsgReq) (*msg.RevokeMsgResp, error) {
if req.UserID == "" {
return nil, errs.ErrArgs.WrapMsg("user_id is empty")
}
if req.ConversationID == "" {
return nil, errs.ErrArgs.WrapMsg("conversation_id is empty")
}
if req.Seq < 0 {
return nil, errs.ErrArgs.WrapMsg("seq is invalid")
}
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
user, err := m.UserLocalCache.GetUserInfo(ctx, req.UserID)
if err != nil {
return nil, err
}
_, _, msgs, err := m.MsgDatabase.GetMsgBySeqs(ctx, req.UserID, req.ConversationID, []int64{req.Seq})
if err != nil {
return nil, err
}
if len(msgs) == 0 || msgs[0] == nil {
return nil, errs.ErrRecordNotFound.WrapMsg("msg not found")
}
if msgs[0].ContentType == constant.MsgRevokeNotification {
return nil, servererrs.ErrMsgAlreadyRevoke.WrapMsg("msg already revoke")
}
data, _ := json.Marshal(msgs[0])
log.ZDebug(ctx, "GetMsgBySeqs", "conversationID", req.ConversationID, "seq", req.Seq, "msg", string(data))
var role int32
if !authverify.IsAdmin(ctx) {
sessionType := msgs[0].SessionType
switch sessionType {
case constant.SingleChatType:
if err := authverify.CheckAccess(ctx, msgs[0].SendID); err != nil {
return nil, err
}
role = user.AppMangerLevel
case constant.ReadGroupChatType:
members, err := m.GroupLocalCache.GetGroupMemberInfoMap(ctx, msgs[0].GroupID, datautil.Distinct([]string{req.UserID, msgs[0].SendID}))
if err != nil {
return nil, err
}
if req.UserID != msgs[0].SendID {
switch members[req.UserID].RoleLevel {
case constant.GroupOwner:
case constant.GroupAdmin:
if sendMember, ok := members[msgs[0].SendID]; ok {
if sendMember.RoleLevel != constant.GroupOrdinaryUsers {
return nil, errs.ErrNoPermission.WrapMsg("no permission")
}
}
default:
return nil, errs.ErrNoPermission.WrapMsg("no permission")
}
}
if member := members[req.UserID]; member != nil {
role = member.RoleLevel
}
default:
return nil, errs.ErrInternalServer.WrapMsg("msg sessionType not supported", "sessionType", sessionType)
}
}
now := time.Now().UnixMilli()
err = m.MsgDatabase.RevokeMsg(ctx, req.ConversationID, req.Seq, &model.RevokeModel{
Role: role,
UserID: req.UserID,
Nickname: user.Nickname,
Time: now,
})
if err != nil {
return nil, err
}
revokerUserID := mcontext.GetOpUserID(ctx)
var flag bool
if len(m.config.Share.IMAdminUser.UserIDs) > 0 {
flag = datautil.Contain(revokerUserID, m.adminUserIDs...)
}
tips := sdkws.RevokeMsgTips{
RevokerUserID: revokerUserID,
ClientMsgID: msgs[0].ClientMsgID,
RevokeTime: now,
Seq: req.Seq,
SesstionType: msgs[0].SessionType,
ConversationID: req.ConversationID,
IsAdminRevoke: flag,
}
var recvID string
if msgs[0].SessionType == constant.ReadGroupChatType {
recvID = msgs[0].GroupID
} else {
recvID = msgs[0].RecvID
}
m.notificationSender.NotificationWithSessionType(ctx, req.UserID, recvID, constant.MsgRevokeNotification, msgs[0].SessionType, &tips)
webhookCfg := m.webhookConfig()
m.webhookAfterRevokeMsg(ctx, &webhookCfg.AfterRevokeMsg, req)
return &msg.RevokeMsgResp{}, nil
}

231
internal/rpc/msg/send.go Normal file
View File

@@ -0,0 +1,231 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msg
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"google.golang.org/protobuf/proto"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/prommetrics"
"git.imall.cloud/openim/open-im-server-deploy/pkg/msgprocessor"
"git.imall.cloud/openim/open-im-server-deploy/pkg/util/conversationutil"
"git.imall.cloud/openim/protocol/constant"
pbconv "git.imall.cloud/openim/protocol/conversation"
pbmsg "git.imall.cloud/openim/protocol/msg"
"git.imall.cloud/openim/protocol/sdkws"
"git.imall.cloud/openim/protocol/wrapperspb"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/datautil"
)
func (m *msgServer) SendMsg(ctx context.Context, req *pbmsg.SendMsgReq) (*pbmsg.SendMsgResp, error) {
if req.MsgData == nil {
return nil, errs.ErrArgs.WrapMsg("msgData is nil")
}
if err := authverify.CheckAccess(ctx, req.MsgData.SendID); err != nil {
return nil, err
}
before := new(*sdkws.MsgData)
resp, err := m.sendMsg(ctx, req, before)
if err != nil {
return nil, err
}
if *before != nil && proto.Equal(*before, req.MsgData) == false {
resp.Modify = req.MsgData
}
return resp, nil
}
func (m *msgServer) sendMsg(ctx context.Context, req *pbmsg.SendMsgReq, before **sdkws.MsgData) (*pbmsg.SendMsgResp, error) {
m.encapsulateMsgData(req.MsgData)
switch req.MsgData.SessionType {
case constant.SingleChatType:
return m.sendMsgSingleChat(ctx, req, before)
case constant.NotificationChatType:
return m.sendMsgNotification(ctx, req, before)
case constant.ReadGroupChatType:
return m.sendMsgGroupChat(ctx, req, before)
default:
return nil, errs.ErrArgs.WrapMsg("unknown sessionType")
}
}
func (m *msgServer) sendMsgGroupChat(ctx context.Context, req *pbmsg.SendMsgReq, before **sdkws.MsgData) (resp *pbmsg.SendMsgResp, err error) {
if err = m.messageVerification(ctx, req); err != nil {
prommetrics.GroupChatMsgProcessFailedCounter.Inc()
return nil, err
}
webhookCfg := m.webhookConfig()
if err = m.webhookBeforeSendGroupMsg(ctx, &webhookCfg.BeforeSendGroupMsg, req); err != nil {
return nil, err
}
if err := m.webhookBeforeMsgModify(ctx, &webhookCfg.BeforeMsgModify, req, before); err != nil {
return nil, err
}
err = m.MsgDatabase.MsgToMQ(ctx, conversationutil.GenConversationUniqueKeyForGroup(req.MsgData.GroupID), req.MsgData)
if err != nil {
return nil, err
}
if req.MsgData.ContentType == constant.AtText {
go m.setConversationAtInfo(ctx, req.MsgData)
}
// 获取webhook配置优先从配置管理器获取
m.webhookAfterSendGroupMsg(ctx, &webhookCfg.AfterSendGroupMsg, req)
prommetrics.GroupChatMsgProcessSuccessCounter.Inc()
resp = &pbmsg.SendMsgResp{}
resp.SendTime = req.MsgData.SendTime
resp.ServerMsgID = req.MsgData.ServerMsgID
resp.ClientMsgID = req.MsgData.ClientMsgID
return resp, nil
}
func (m *msgServer) setConversationAtInfo(nctx context.Context, msg *sdkws.MsgData) {
log.ZDebug(nctx, "setConversationAtInfo", "msg", msg)
defer func() {
if r := recover(); r != nil {
log.ZPanic(nctx, "setConversationAtInfo Panic", errs.ErrPanic(r))
}
}()
ctx := mcontext.NewCtx("@@@" + mcontext.GetOperationID(nctx))
var atUserID []string
conversation := &pbconv.ConversationReq{
ConversationID: msgprocessor.GetConversationIDByMsg(msg),
ConversationType: msg.SessionType,
GroupID: msg.GroupID,
}
memberUserIDList, err := m.GroupLocalCache.GetGroupMemberIDs(ctx, msg.GroupID)
if err != nil {
log.ZWarn(ctx, "GetGroupMemberIDs", err)
return
}
tagAll := datautil.Contain(constant.AtAllString, msg.AtUserIDList...)
if tagAll {
memberUserIDList = datautil.DeleteElems(memberUserIDList, msg.SendID)
atUserID = datautil.Single([]string{constant.AtAllString}, msg.AtUserIDList)
if len(atUserID) == 0 { // just @everyone
conversation.GroupAtType = &wrapperspb.Int32Value{Value: constant.AtAll}
} else { // @Everyone and @other people
conversation.GroupAtType = &wrapperspb.Int32Value{Value: constant.AtAllAtMe}
atUserID = datautil.SliceIntersectFuncs(atUserID, memberUserIDList, func(a string) string { return a }, func(b string) string {
return b
})
if err := m.conversationClient.SetConversations(ctx, atUserID, conversation); err != nil {
log.ZWarn(ctx, "SetConversations", err, "userID", atUserID, "conversation", conversation)
}
memberUserIDList = datautil.Single(atUserID, memberUserIDList)
}
conversation.GroupAtType = &wrapperspb.Int32Value{Value: constant.AtAll}
if err := m.conversationClient.SetConversations(ctx, memberUserIDList, conversation); err != nil {
log.ZWarn(ctx, "SetConversations", err, "userID", memberUserIDList, "conversation", conversation)
}
return
}
atUserID = datautil.SliceIntersectFuncs(msg.AtUserIDList, memberUserIDList, func(a string) string { return a }, func(b string) string {
return b
})
conversation.GroupAtType = &wrapperspb.Int32Value{Value: constant.AtMe}
if err := m.conversationClient.SetConversations(ctx, atUserID, conversation); err != nil {
log.ZWarn(ctx, "SetConversations", err, atUserID, conversation)
}
}
func (m *msgServer) sendMsgNotification(ctx context.Context, req *pbmsg.SendMsgReq, before **sdkws.MsgData) (resp *pbmsg.SendMsgResp, err error) {
if err := m.MsgDatabase.MsgToMQ(ctx, conversationutil.GenConversationUniqueKeyForSingle(req.MsgData.SendID, req.MsgData.RecvID), req.MsgData); err != nil {
return nil, err
}
resp = &pbmsg.SendMsgResp{
ServerMsgID: req.MsgData.ServerMsgID,
ClientMsgID: req.MsgData.ClientMsgID,
SendTime: req.MsgData.SendTime,
}
return resp, nil
}
func (m *msgServer) sendMsgSingleChat(ctx context.Context, req *pbmsg.SendMsgReq, before **sdkws.MsgData) (resp *pbmsg.SendMsgResp, err error) {
if err := m.messageVerification(ctx, req); err != nil {
return nil, err
}
webhookCfg := m.webhookConfig()
isSend := true
isNotification := msgprocessor.IsNotificationByMsg(req.MsgData)
if !isNotification {
isSend, err = m.modifyMessageByUserMessageReceiveOpt(authverify.WithTempAdmin(ctx), req.MsgData.RecvID, conversationutil.GenConversationIDForSingle(req.MsgData.SendID, req.MsgData.RecvID), constant.SingleChatType, req)
if err != nil {
return nil, err
}
}
if !isSend {
prommetrics.SingleChatMsgProcessFailedCounter.Inc()
return nil, errs.ErrArgs.WrapMsg("message is not sent")
} else {
if err := m.webhookBeforeMsgModify(ctx, &webhookCfg.BeforeMsgModify, req, before); err != nil {
return nil, err
}
if err := m.MsgDatabase.MsgToMQ(ctx, conversationutil.GenConversationUniqueKeyForSingle(req.MsgData.SendID, req.MsgData.RecvID), req.MsgData); err != nil {
prommetrics.SingleChatMsgProcessFailedCounter.Inc()
return nil, err
}
m.webhookAfterSendSingleMsg(ctx, &webhookCfg.AfterSendSingleMsg, req)
prommetrics.SingleChatMsgProcessSuccessCounter.Inc()
return &pbmsg.SendMsgResp{
ServerMsgID: req.MsgData.ServerMsgID,
ClientMsgID: req.MsgData.ClientMsgID,
SendTime: req.MsgData.SendTime,
}, nil
}
}
func (m *msgServer) SendSimpleMsg(ctx context.Context, req *pbmsg.SendSimpleMsgReq) (*pbmsg.SendSimpleMsgResp, error) {
if req.MsgData == nil {
return nil, errs.ErrArgs.WrapMsg("msg data is nil")
}
sender, err := m.UserLocalCache.GetUserInfo(ctx, req.MsgData.SendID)
if err != nil {
return nil, err
}
req.MsgData.SenderFaceURL = sender.FaceURL
req.MsgData.SenderNickname = sender.Nickname
resp, err := m.SendMsg(ctx, &pbmsg.SendMsgReq{MsgData: req.MsgData})
if err != nil {
return nil, err
}
return &pbmsg.SendSimpleMsgResp{
ServerMsgID: resp.ServerMsgID,
ClientMsgID: resp.ClientMsgID,
SendTime: resp.SendTime,
Modify: resp.Modify,
}, nil
}

105
internal/rpc/msg/seq.go Normal file
View File

@@ -0,0 +1,105 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msg
import (
"context"
"errors"
"sort"
pbmsg "git.imall.cloud/openim/protocol/msg"
"github.com/redis/go-redis/v9"
)
func (m *msgServer) GetConversationMaxSeq(ctx context.Context, req *pbmsg.GetConversationMaxSeqReq) (*pbmsg.GetConversationMaxSeqResp, error) {
maxSeq, err := m.MsgDatabase.GetMaxSeq(ctx, req.ConversationID)
if err != nil && !errors.Is(err, redis.Nil) {
return nil, err
}
return &pbmsg.GetConversationMaxSeqResp{MaxSeq: maxSeq}, nil
}
func (m *msgServer) GetMaxSeqs(ctx context.Context, req *pbmsg.GetMaxSeqsReq) (*pbmsg.SeqsInfoResp, error) {
maxSeqs, err := m.MsgDatabase.GetMaxSeqs(ctx, req.ConversationIDs)
if err != nil {
return nil, err
}
return &pbmsg.SeqsInfoResp{MaxSeqs: maxSeqs}, nil
}
func (m *msgServer) GetHasReadSeqs(ctx context.Context, req *pbmsg.GetHasReadSeqsReq) (*pbmsg.SeqsInfoResp, error) {
hasReadSeqs, err := m.MsgDatabase.GetHasReadSeqs(ctx, req.UserID, req.ConversationIDs)
if err != nil {
return nil, err
}
return &pbmsg.SeqsInfoResp{MaxSeqs: hasReadSeqs}, nil
}
func (m *msgServer) GetMsgByConversationIDs(ctx context.Context, req *pbmsg.GetMsgByConversationIDsReq) (*pbmsg.GetMsgByConversationIDsResp, error) {
Msgs, err := m.MsgDatabase.FindOneByDocIDs(ctx, req.ConversationIDs, req.MaxSeqs)
if err != nil {
return nil, err
}
return &pbmsg.GetMsgByConversationIDsResp{MsgDatas: Msgs}, nil
}
func (m *msgServer) SetUserConversationsMinSeq(ctx context.Context, req *pbmsg.SetUserConversationsMinSeqReq) (*pbmsg.SetUserConversationsMinSeqResp, error) {
for _, userID := range req.UserIDs {
if err := m.MsgDatabase.SetUserConversationsMinSeqs(ctx, userID, map[string]int64{req.ConversationID: req.Seq}); err != nil {
return nil, err
}
}
return &pbmsg.SetUserConversationsMinSeqResp{}, nil
}
func (m *msgServer) GetActiveConversation(ctx context.Context, req *pbmsg.GetActiveConversationReq) (*pbmsg.GetActiveConversationResp, error) {
res, err := m.MsgDatabase.GetCacheMaxSeqWithTime(ctx, req.ConversationIDs)
if err != nil {
return nil, err
}
conversations := make([]*pbmsg.ActiveConversation, 0, len(res))
for conversationID, val := range res {
conversations = append(conversations, &pbmsg.ActiveConversation{
MaxSeq: val.Seq,
LastTime: val.Time,
ConversationID: conversationID,
})
}
if req.Limit > 0 {
sort.Sort(activeConversations(conversations))
if len(conversations) > int(req.Limit) {
conversations = conversations[:req.Limit]
}
}
return &pbmsg.GetActiveConversationResp{Conversations: conversations}, nil
}
func (m *msgServer) SetUserConversationMaxSeq(ctx context.Context, req *pbmsg.SetUserConversationMaxSeqReq) (*pbmsg.SetUserConversationMaxSeqResp, error) {
for _, userID := range req.OwnerUserID {
if err := m.MsgDatabase.SetUserConversationsMaxSeq(ctx, req.ConversationID, userID, req.MaxSeq); err != nil {
return nil, err
}
}
return &pbmsg.SetUserConversationMaxSeqResp{}, nil
}
func (m *msgServer) SetUserConversationMinSeq(ctx context.Context, req *pbmsg.SetUserConversationMinSeqReq) (*pbmsg.SetUserConversationMinSeqResp, error) {
for _, userID := range req.OwnerUserID {
if err := m.MsgDatabase.SetUserConversationsMinSeq(ctx, req.ConversationID, userID, req.MinSeq); err != nil {
return nil, err
}
}
return &pbmsg.SetUserConversationMinSeqResp{}, nil
}

218
internal/rpc/msg/server.go Normal file
View File

@@ -0,0 +1,218 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msg
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/mcache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/dbbuild"
"git.imall.cloud/openim/open-im-server-deploy/pkg/localcache"
"git.imall.cloud/openim/open-im-server-deploy/pkg/mqbuild"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
"google.golang.org/grpc"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/redis"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/controller"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database/mgo"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/webhook"
"git.imall.cloud/openim/open-im-server-deploy/pkg/notification"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpccache"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/conversation"
"git.imall.cloud/openim/protocol/msg"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/log"
)
type MessageInterceptorFunc func(ctx context.Context, globalConfig *Config, req *msg.SendMsgReq) (*sdkws.MsgData, error)
// MessageInterceptorChain defines a chain of message interceptor functions.
type MessageInterceptorChain []MessageInterceptorFunc
type Config struct {
RpcConfig config.Msg
RedisConfig config.Redis
MongodbConfig config.Mongo
KafkaConfig config.Kafka
NotificationConfig config.Notification
Share config.Share
WebhooksConfig config.Webhooks
LocalCacheConfig config.LocalCache
Discovery config.Discovery
}
// MsgServer encapsulates dependencies required for message handling.
type msgServer struct {
msg.UnimplementedMsgServer
RegisterCenter discovery.Conn // Service discovery registry for service registration.
MsgDatabase controller.CommonMsgDatabase // Interface for message database operations.
UserLocalCache *rpccache.UserLocalCache // Local cache for user data.
FriendLocalCache *rpccache.FriendLocalCache // Local cache for friend data.
GroupLocalCache *rpccache.GroupLocalCache // Local cache for group data.
ConversationLocalCache *rpccache.ConversationLocalCache // Local cache for conversation data.
Handlers MessageInterceptorChain // Chain of handlers for processing messages.
notificationSender *notification.NotificationSender // RPC client for sending notifications.
msgNotificationSender *MsgNotificationSender // RPC client for sending msg notifications.
config *Config // Global configuration settings.
webhookClient *webhook.Client
webhookConfigManager *webhook.ConfigManager // Webhook配置管理器支持从数据库读取
conversationClient *rpcli.ConversationClient
redPacketDB database.RedPacket // Database for red packet records.
redPacketReceiveDB database.RedPacketReceive // Database for red packet receive records.
adminUserIDs []string
}
func (m *msgServer) addInterceptorHandler(interceptorFunc ...MessageInterceptorFunc) {
m.Handlers = append(m.Handlers, interceptorFunc...)
}
// webhookConfig returns the latest webhook config from the client/manager with fallback.
func (m *msgServer) webhookConfig() *config.Webhooks {
if m.webhookClient == nil {
return &m.config.WebhooksConfig
}
return m.webhookClient.GetConfig(&m.config.WebhooksConfig)
}
func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryRegistry, server grpc.ServiceRegistrar) error {
builder := mqbuild.NewBuilder(&config.KafkaConfig)
redisProducer, err := builder.GetTopicProducer(ctx, config.KafkaConfig.ToRedisTopic)
if err != nil {
return err
}
dbb := dbbuild.NewBuilder(&config.MongodbConfig, &config.RedisConfig)
mgocli, err := dbb.Mongo(ctx)
if err != nil {
return err
}
rdb, err := dbb.Redis(ctx)
if err != nil {
return err
}
msgDocModel, err := mgo.NewMsgMongo(mgocli.GetDB())
if err != nil {
return err
}
var msgModel cache.MsgCache
if rdb == nil {
cm, err := mgo.NewCacheMgo(mgocli.GetDB())
if err != nil {
return err
}
msgModel = mcache.NewMsgCache(cm, msgDocModel)
} else {
msgModel = redis.NewMsgCache(rdb, msgDocModel)
}
seqConversation, err := mgo.NewSeqConversationMongo(mgocli.GetDB())
if err != nil {
return err
}
seqConversationCache := redis.NewSeqConversationCacheRedis(rdb, seqConversation)
seqUser, err := mgo.NewSeqUserMongo(mgocli.GetDB())
if err != nil {
return err
}
seqUserCache := redis.NewSeqUserCacheRedis(rdb, seqUser)
redPacketDB, err := mgo.NewRedPacketMongo(mgocli.GetDB())
if err != nil {
return err
}
redPacketReceiveDB, err := mgo.NewRedPacketReceiveMongo(mgocli.GetDB())
if err != nil {
return err
}
userConn, err := client.GetConn(ctx, config.Discovery.RpcService.User)
if err != nil {
return err
}
groupConn, err := client.GetConn(ctx, config.Discovery.RpcService.Group)
if err != nil {
return err
}
friendConn, err := client.GetConn(ctx, config.Discovery.RpcService.Friend)
if err != nil {
return err
}
conversationConn, err := client.GetConn(ctx, config.Discovery.RpcService.Conversation)
if err != nil {
return err
}
conversationClient := rpcli.NewConversationClient(conversationConn)
msgDatabase := controller.NewCommonMsgDatabase(msgDocModel, msgModel, seqUserCache, seqConversationCache, redisProducer)
localcache.InitLocalCache(&config.LocalCacheConfig)
// 初始化webhook配置管理器支持从数据库读取配置
var webhookClient *webhook.Client
var webhookConfigManager *webhook.ConfigManager
systemConfigDB, err := mgo.NewSystemConfigMongo(mgocli.GetDB())
if err == nil {
// 如果SystemConfig数据库初始化成功使用配置管理器
webhookConfigManager = webhook.NewConfigManager(systemConfigDB, &config.WebhooksConfig)
if err := webhookConfigManager.Start(ctx); err != nil {
log.ZWarn(ctx, "failed to start webhook config manager, using default config", err)
webhookClient = webhook.NewWebhookClient(config.WebhooksConfig.URL)
} else {
webhookClient = webhook.NewWebhookClientWithManager(webhookConfigManager)
}
} else {
// 如果SystemConfig数据库初始化失败使用默认配置
log.ZWarn(ctx, "failed to init system config db, using default webhook config", err)
webhookClient = webhook.NewWebhookClient(config.WebhooksConfig.URL)
}
s := &msgServer{
MsgDatabase: msgDatabase,
RegisterCenter: client,
UserLocalCache: rpccache.NewUserLocalCache(rpcli.NewUserClient(userConn), &config.LocalCacheConfig, rdb),
GroupLocalCache: rpccache.NewGroupLocalCache(rpcli.NewGroupClient(groupConn), &config.LocalCacheConfig, rdb),
ConversationLocalCache: rpccache.NewConversationLocalCache(conversationClient, &config.LocalCacheConfig, rdb),
FriendLocalCache: rpccache.NewFriendLocalCache(rpcli.NewRelationClient(friendConn), &config.LocalCacheConfig, rdb),
config: config,
webhookClient: webhookClient,
webhookConfigManager: webhookConfigManager,
conversationClient: conversationClient,
redPacketDB: redPacketDB,
redPacketReceiveDB: redPacketReceiveDB,
adminUserIDs: config.Share.IMAdminUser.UserIDs,
}
s.notificationSender = notification.NewNotificationSender(&config.NotificationConfig, notification.WithLocalSendMsg(s.SendMsg))
s.msgNotificationSender = NewMsgNotificationSender(config, notification.WithLocalSendMsg(s.SendMsg))
msg.RegisterMsgServer(server, s)
return nil
}
func (m *msgServer) conversationAndGetRecvID(conversation *conversation.Conversation, userID string) string {
if conversation.ConversationType == constant.SingleChatType ||
conversation.ConversationType == constant.NotificationChatType {
if userID == conversation.OwnerUserID {
return conversation.UserID
} else {
return conversation.OwnerUserID
}
} else if conversation.ConversationType == constant.ReadGroupChatType {
return conversation.GroupID
}
return ""
}

View File

@@ -0,0 +1,107 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msg
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/protocol/msg"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/utils/datautil"
)
func (m *msgServer) GetActiveUser(ctx context.Context, req *msg.GetActiveUserReq) (*msg.GetActiveUserResp, error) {
if err := authverify.CheckAdmin(ctx); err != nil {
return nil, err
}
msgCount, userCount, users, dateCount, err := m.MsgDatabase.RangeUserSendCount(ctx, time.UnixMilli(req.Start), time.UnixMilli(req.End), req.Group, req.Ase, req.Pagination.PageNumber, req.Pagination.ShowNumber)
if err != nil {
return nil, err
}
var pbUsers []*msg.ActiveUser
if len(users) > 0 {
userIDs := datautil.Slice(users, func(e *model.UserCount) string { return e.UserID })
userMap, err := m.UserLocalCache.GetUsersInfoMap(ctx, userIDs)
if err != nil {
return nil, err
}
pbUsers = make([]*msg.ActiveUser, 0, len(users))
for _, user := range users {
pbUser := userMap[user.UserID]
if pbUser == nil {
pbUser = &sdkws.UserInfo{
UserID: user.UserID,
Nickname: user.UserID,
}
}
pbUsers = append(pbUsers, &msg.ActiveUser{
User: pbUser,
Count: user.Count,
})
}
}
return &msg.GetActiveUserResp{
MsgCount: msgCount,
UserCount: userCount,
DateCount: dateCount,
Users: pbUsers,
}, nil
}
func (m *msgServer) GetActiveGroup(ctx context.Context, req *msg.GetActiveGroupReq) (*msg.GetActiveGroupResp, error) {
if err := authverify.CheckAdmin(ctx); err != nil {
return nil, err
}
msgCount, groupCount, groups, dateCount, err := m.MsgDatabase.RangeGroupSendCount(ctx, time.UnixMilli(req.Start), time.UnixMilli(req.End), req.Ase, req.Pagination.PageNumber, req.Pagination.ShowNumber)
if err != nil {
return nil, err
}
var pbgroups []*msg.ActiveGroup
if len(groups) > 0 {
groupIDs := datautil.Slice(groups, func(e *model.GroupCount) string { return e.GroupID })
resp, err := m.GroupLocalCache.GetGroupInfos(ctx, groupIDs)
if err != nil {
return nil, err
}
groupMap := make(map[string]*sdkws.GroupInfo, len(groups))
for i, group := range groups {
groupMap[group.GroupID] = resp[i]
}
pbgroups = make([]*msg.ActiveGroup, 0, len(groups))
for _, group := range groups {
pbgroup := groupMap[group.GroupID]
if pbgroup == nil {
pbgroup = &sdkws.GroupInfo{
GroupID: group.GroupID,
GroupName: group.GroupID,
}
}
pbgroups = append(pbgroups, &msg.ActiveGroup{
Group: pbgroup,
Count: group.Count,
})
}
}
return &msg.GetActiveGroupResp{
MsgCount: msgCount,
GroupCount: groupCount,
DateCount: dateCount,
Groups: pbgroups,
}, nil
}

View File

@@ -0,0 +1,658 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msg
import (
"context"
"encoding/json"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/apistruct"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/open-im-server-deploy/pkg/msgprocessor"
"git.imall.cloud/openim/open-im-server-deploy/pkg/util/conversationutil"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/msg"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/datautil"
"github.com/openimsdk/tools/utils/jsonutil"
"github.com/openimsdk/tools/utils/timeutil"
)
func (m *msgServer) PullMessageBySeqs(ctx context.Context, req *sdkws.PullMessageBySeqsReq) (*sdkws.PullMessageBySeqsResp, error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
// 设置请求超时防止大量数据拉取导致pod异常
// 对于大量大群用户需要更长的超时时间60秒
queryTimeout := 60 * time.Second
queryCtx, cancel := context.WithTimeout(ctx, queryTimeout)
defer cancel()
// 参数验证:限制 SeqRanges 数量和每个范围的大小,防止内存溢出
const maxSeqRanges = 100
const maxSeqRangeSize = 10000
// 限制响应总大小防止pod内存溢出估算每条消息平均1KB最多50MB
const maxTotalMessages = 50000
if len(req.SeqRanges) > maxSeqRanges {
log.ZWarn(ctx, "SeqRanges count exceeds limit", nil, "count", len(req.SeqRanges), "limit", maxSeqRanges)
return nil, errs.ErrArgs.WrapMsg("too many seq ranges", "count", len(req.SeqRanges), "limit", maxSeqRanges)
}
for _, seq := range req.SeqRanges {
// 验证每个 seq range 的合理性
if seq.Begin < 0 || seq.End < 0 {
log.ZWarn(ctx, "invalid seq range: negative values", nil, "begin", seq.Begin, "end", seq.End)
continue
}
if seq.End < seq.Begin {
log.ZWarn(ctx, "invalid seq range: end < begin", nil, "begin", seq.Begin, "end", seq.End)
continue
}
seqRangeSize := seq.End - seq.Begin + 1
if seqRangeSize > maxSeqRangeSize {
log.ZWarn(ctx, "seq range size exceeds limit, will be limited", nil, "conversationID", seq.ConversationID, "begin", seq.Begin, "end", seq.End, "size", seqRangeSize, "limit", maxSeqRangeSize)
}
}
resp := &sdkws.PullMessageBySeqsResp{}
resp.Msgs = make(map[string]*sdkws.PullMsgs)
resp.NotificationMsgs = make(map[string]*sdkws.PullMsgs)
var totalMessages int
for _, seq := range req.SeqRanges {
// 检查总消息数,防止内存溢出
if totalMessages >= maxTotalMessages {
log.ZWarn(ctx, "total messages count exceeds limit, stopping", nil, "totalMessages", totalMessages, "limit", maxTotalMessages)
break
}
if !msgprocessor.IsNotification(seq.ConversationID) {
conversation, err := m.ConversationLocalCache.GetConversation(queryCtx, req.UserID, seq.ConversationID)
if err != nil {
log.ZError(ctx, "GetConversation error", err, "conversationID", seq.ConversationID)
continue
}
minSeq, maxSeq, msgs, err := m.MsgDatabase.GetMsgBySeqsRange(queryCtx, req.UserID, seq.ConversationID,
seq.Begin, seq.End, seq.Num, conversation.MaxSeq)
if err != nil {
// 如果是超时错误,记录更详细的日志
if queryCtx.Err() == context.DeadlineExceeded {
log.ZWarn(ctx, "GetMsgBySeqsRange timeout", err, "conversationID", seq.ConversationID, "seq", seq, "timeout", queryTimeout)
return nil, errs.ErrInternalServer.WrapMsg("message pull timeout, data too large or query too slow")
}
log.ZWarn(ctx, "GetMsgBySeqsRange error", err, "conversationID", seq.ConversationID, "seq", seq)
continue
}
totalMessages += len(msgs)
var isEnd bool
switch req.Order {
case sdkws.PullOrder_PullOrderAsc:
isEnd = maxSeq <= seq.End
case sdkws.PullOrder_PullOrderDesc:
isEnd = seq.Begin <= minSeq
}
if len(msgs) == 0 {
log.ZWarn(ctx, "not have msgs", nil, "conversationID", seq.ConversationID, "seq", seq)
continue
}
// 过滤禁言通知消息(只保留群主、管理员和被禁言成员本人可以看到的)
msgs = m.filterMuteNotificationMsgs(ctx, req.UserID, seq.ConversationID, msgs)
// 填充红包消息的领取信息
msgs = m.enrichRedPacketMessages(ctx, req.UserID, msgs)
resp.Msgs[seq.ConversationID] = &sdkws.PullMsgs{Msgs: msgs, IsEnd: isEnd}
} else {
// 限制通知消息的查询范围,防止内存溢出
const maxNotificationSeqRange = 5000
var seqs []int64
seqRange := seq.End - seq.Begin + 1
if seqRange > maxNotificationSeqRange {
log.ZWarn(ctx, "notification seq range too large, limiting", nil, "conversationID", seq.ConversationID, "begin", seq.Begin, "end", seq.End, "range", seqRange, "limit", maxNotificationSeqRange)
// 只取最后 maxNotificationSeqRange 条
for i := seq.End - maxNotificationSeqRange + 1; i <= seq.End; i++ {
seqs = append(seqs, i)
}
} else {
for i := seq.Begin; i <= seq.End; i++ {
seqs = append(seqs, i)
}
}
minSeq, maxSeq, notificationMsgs, err := m.MsgDatabase.GetMsgBySeqs(queryCtx, req.UserID, seq.ConversationID, seqs)
if err != nil {
// 如果是超时错误,记录更详细的日志
if queryCtx.Err() == context.DeadlineExceeded {
log.ZWarn(ctx, "GetMsgBySeqs timeout", err, "conversationID", seq.ConversationID, "seq", seq, "timeout", queryTimeout)
return nil, errs.ErrInternalServer.WrapMsg("notification message pull timeout, data too large or query too slow")
}
log.ZWarn(ctx, "GetMsgBySeqs error", err, "conversationID", seq.ConversationID, "seq", seq)
continue
}
totalMessages += len(notificationMsgs)
var isEnd bool
switch req.Order {
case sdkws.PullOrder_PullOrderAsc:
isEnd = maxSeq <= seq.End
case sdkws.PullOrder_PullOrderDesc:
isEnd = seq.Begin <= minSeq
}
if len(notificationMsgs) == 0 {
log.ZWarn(ctx, "not have notificationMsgs", nil, "conversationID", seq.ConversationID, "seq", seq)
continue
}
// 过滤禁言通知消息(只保留群主、管理员和被禁言成员本人可以看到的)
notificationMsgs = m.filterMuteNotificationMsgs(ctx, req.UserID, seq.ConversationID, notificationMsgs)
// 填充红包消息的领取信息
notificationMsgs = m.enrichRedPacketMessages(ctx, req.UserID, notificationMsgs)
resp.NotificationMsgs[seq.ConversationID] = &sdkws.PullMsgs{Msgs: notificationMsgs, IsEnd: isEnd}
}
}
return resp, nil
}
func (m *msgServer) GetSeqMessage(ctx context.Context, req *msg.GetSeqMessageReq) (*msg.GetSeqMessageResp, error) {
resp := &msg.GetSeqMessageResp{
Msgs: make(map[string]*sdkws.PullMsgs),
NotificationMsgs: make(map[string]*sdkws.PullMsgs),
}
for _, conv := range req.Conversations {
isEnd, endSeq, msgs, err := m.MsgDatabase.GetMessagesBySeqWithBounds(ctx, req.UserID, conv.ConversationID, conv.Seqs, req.GetOrder())
if err != nil {
return nil, err
}
var pullMsgs *sdkws.PullMsgs
if ok := false; conversationutil.IsNotificationConversationID(conv.ConversationID) {
pullMsgs, ok = resp.NotificationMsgs[conv.ConversationID]
if !ok {
pullMsgs = &sdkws.PullMsgs{}
resp.NotificationMsgs[conv.ConversationID] = pullMsgs
}
} else {
pullMsgs, ok = resp.Msgs[conv.ConversationID]
if !ok {
pullMsgs = &sdkws.PullMsgs{}
resp.Msgs[conv.ConversationID] = pullMsgs
}
}
// 过滤禁言通知消息(只保留群主、管理员和被禁言成员本人可以看到的)
filteredMsgs := m.filterMuteNotificationMsgs(ctx, req.UserID, conv.ConversationID, msgs)
// 填充红包消息的领取信息
filteredMsgs = m.enrichRedPacketMessages(ctx, req.UserID, filteredMsgs)
pullMsgs.Msgs = append(pullMsgs.Msgs, filteredMsgs...)
pullMsgs.IsEnd = isEnd
pullMsgs.EndSeq = endSeq
}
return resp, nil
}
// filterMuteNotificationMsgs 过滤禁言、取消禁言、踢出群聊、退出群聊、进入群聊、群成员信息设置和角色变更通知消息,只保留群主、管理员和相关成员本人可以看到的消息
func (m *msgServer) filterMuteNotificationMsgs(ctx context.Context, userID, conversationID string, msgs []*sdkws.MsgData) []*sdkws.MsgData {
// 如果不是群聊会话,直接返回
if !conversationutil.IsGroupConversationID(conversationID) {
return msgs
}
// 提取群ID
groupID := conversationutil.GetGroupIDFromConversationID(conversationID)
if groupID == "" {
log.ZWarn(ctx, "filterMuteNotificationMsgs: invalid group conversationID", nil, "conversationID", conversationID)
return msgs
}
var filteredMsgs []*sdkws.MsgData
var needCheckPermission bool
// 先检查是否有需要过滤的消息
for _, msg := range msgs {
if msg.ContentType == constant.GroupMemberMutedNotification ||
msg.ContentType == constant.GroupMemberCancelMutedNotification ||
msg.ContentType == constant.MemberKickedNotification ||
msg.ContentType == constant.MemberQuitNotification ||
msg.ContentType == constant.MemberInvitedNotification ||
msg.ContentType == constant.MemberEnterNotification ||
msg.ContentType == constant.GroupMemberInfoSetNotification ||
msg.ContentType == constant.GroupMemberSetToAdminNotification ||
msg.ContentType == constant.GroupMemberSetToOrdinaryUserNotification {
needCheckPermission = true
break
}
}
// 如果没有需要过滤的消息,直接返回
if !needCheckPermission {
return msgs
}
// 对于被踢出的用户,可能无法获取成员信息,需要特殊处理
// 先收集所有被踢出的用户ID以便后续判断
var allKickedUserIDs []string
if needCheckPermission {
for _, msg := range msgs {
if msg.ContentType == constant.MemberKickedNotification {
var tips sdkws.MemberKickedTips
if err := unmarshalNotificationContent(string(msg.Content), &tips); err == nil {
kickedUserIDs := datautil.Slice(tips.KickedUserList, func(e *sdkws.GroupMemberFullInfo) string { return e.UserID })
allKickedUserIDs = append(allKickedUserIDs, kickedUserIDs...)
}
}
}
allKickedUserIDs = datautil.Distinct(allKickedUserIDs)
}
// 检查用户是否是被踢出的用户(即使已经被踢出,也应该能看到自己被踢出的通知)
isKickedUserInMsgs := datautil.Contain(userID, allKickedUserIDs...)
// 获取当前用户在群中的角色(如果用户已经被踢出,这里会返回错误)
// 添加超时保护防止大群查询阻塞3秒超时
memberCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
member, err := m.GroupLocalCache.GetGroupMember(memberCtx, groupID, userID)
isOwnerOrAdmin := false
if err != nil {
if memberCtx.Err() == context.DeadlineExceeded {
log.ZWarn(ctx, "filterMuteNotificationMsgs: GetGroupMember timeout", err, "groupID", groupID, "userID", userID)
} else {
log.ZDebug(ctx, "filterMuteNotificationMsgs: GetGroupMember failed (user may be kicked)", err, "groupID", groupID, "userID", userID, "isKickedUserInMsgs", isKickedUserInMsgs)
}
// 如果获取失败(可能已经被踢出或超时),仍然需要检查是否是相关成员本人
// 继续处理,但 isOwnerOrAdmin 保持为 false
// 如果是被踢出的用户,仍然可以查看自己被踢出的通知
} else {
isOwnerOrAdmin = member.RoleLevel == constant.GroupOwner || member.RoleLevel == constant.GroupAdmin
}
// 过滤消息
for _, msg := range msgs {
if msg.ContentType == constant.GroupMemberMutedNotification {
var tips sdkws.GroupMemberMutedTips
if err := unmarshalNotificationContent(string(msg.Content), &tips); err != nil {
log.ZWarn(ctx, "filterMuteNotificationMsgs: unmarshal GroupMemberMutedTips failed", err)
filteredMsgs = append(filteredMsgs, msg)
continue
}
mutedUserID := tips.MutedUser.UserID
if isOwnerOrAdmin || userID == mutedUserID {
filteredMsgs = append(filteredMsgs, msg)
log.ZDebug(ctx, "filterMuteNotificationMsgs: GroupMemberMutedNotification allowed", "userID", userID, "mutedUserID", mutedUserID, "isOwnerOrAdmin", isOwnerOrAdmin)
} else {
log.ZDebug(ctx, "filterMuteNotificationMsgs: GroupMemberMutedNotification filtered", "userID", userID, "mutedUserID", mutedUserID, "isOwnerOrAdmin", isOwnerOrAdmin)
}
} else if msg.ContentType == constant.GroupMemberCancelMutedNotification {
var tips sdkws.GroupMemberCancelMutedTips
if err := unmarshalNotificationContent(string(msg.Content), &tips); err != nil {
log.ZWarn(ctx, "filterMuteNotificationMsgs: unmarshal GroupMemberCancelMutedTips failed", err)
filteredMsgs = append(filteredMsgs, msg)
continue
}
cancelMutedUserID := tips.MutedUser.UserID
if isOwnerOrAdmin || userID == cancelMutedUserID {
filteredMsgs = append(filteredMsgs, msg)
log.ZDebug(ctx, "filterMuteNotificationMsgs: GroupMemberCancelMutedNotification allowed", "userID", userID, "cancelMutedUserID", cancelMutedUserID, "isOwnerOrAdmin", isOwnerOrAdmin)
} else {
log.ZDebug(ctx, "filterMuteNotificationMsgs: GroupMemberCancelMutedNotification filtered", "userID", userID, "cancelMutedUserID", cancelMutedUserID, "isOwnerOrAdmin", isOwnerOrAdmin)
}
} else if msg.ContentType == constant.MemberQuitNotification {
var tips sdkws.MemberQuitTips
if err := unmarshalNotificationContent(string(msg.Content), &tips); err != nil {
log.ZWarn(ctx, "filterMuteNotificationMsgs: unmarshal MemberQuitTips failed", err)
filteredMsgs = append(filteredMsgs, msg)
continue
}
quitUserID := tips.QuitUser.UserID
// 退出群聊通知只允许群主和管理员看到,退出者本人不通知
if isOwnerOrAdmin {
filteredMsgs = append(filteredMsgs, msg)
log.ZDebug(ctx, "filterMuteNotificationMsgs: MemberQuitNotification allowed", "userID", userID, "quitUserID", quitUserID, "isOwnerOrAdmin", isOwnerOrAdmin)
} else {
log.ZDebug(ctx, "filterMuteNotificationMsgs: MemberQuitNotification filtered", "userID", userID, "quitUserID", quitUserID, "isOwnerOrAdmin", isOwnerOrAdmin)
}
} else if msg.ContentType == constant.MemberInvitedNotification {
var tips sdkws.MemberInvitedTips
if err := unmarshalNotificationContent(string(msg.Content), &tips); err != nil {
log.ZWarn(ctx, "filterMuteNotificationMsgs: unmarshal MemberInvitedTips failed", err)
filteredMsgs = append(filteredMsgs, msg)
continue
}
// 获取被邀请的用户ID列表
invitedUserIDs := datautil.Slice(tips.InvitedUserList, func(e *sdkws.GroupMemberFullInfo) string { return e.UserID })
isInvitedUser := datautil.Contain(userID, invitedUserIDs...)
// 邀请入群通知:允许群主、管理员和被邀请的用户本人看到
if isOwnerOrAdmin || isInvitedUser {
filteredMsgs = append(filteredMsgs, msg)
log.ZDebug(ctx, "filterMuteNotificationMsgs: MemberInvitedNotification allowed", "userID", userID, "invitedUserIDs", invitedUserIDs, "isOwnerOrAdmin", isOwnerOrAdmin, "isInvitedUser", isInvitedUser)
} else {
log.ZDebug(ctx, "filterMuteNotificationMsgs: MemberInvitedNotification filtered", "userID", userID, "invitedUserIDs", invitedUserIDs, "isOwnerOrAdmin", isOwnerOrAdmin, "isInvitedUser", isInvitedUser)
}
} else if msg.ContentType == constant.MemberEnterNotification {
var tips sdkws.MemberEnterTips
if err := unmarshalNotificationContent(string(msg.Content), &tips); err != nil {
log.ZWarn(ctx, "filterMuteNotificationMsgs: unmarshal MemberEnterTips failed", err)
filteredMsgs = append(filteredMsgs, msg)
continue
}
entrantUserID := tips.EntrantUser.UserID
// 进入群聊通知只允许群主和管理员看到
if isOwnerOrAdmin {
filteredMsgs = append(filteredMsgs, msg)
log.ZDebug(ctx, "filterMuteNotificationMsgs: MemberEnterNotification allowed", "userID", userID, "entrantUserID", entrantUserID, "isOwnerOrAdmin", isOwnerOrAdmin)
} else {
log.ZDebug(ctx, "filterMuteNotificationMsgs: MemberEnterNotification filtered", "userID", userID, "entrantUserID", entrantUserID, "isOwnerOrAdmin", isOwnerOrAdmin)
}
} else if msg.ContentType == constant.MemberKickedNotification {
var tips sdkws.MemberKickedTips
if err := unmarshalNotificationContent(string(msg.Content), &tips); err != nil {
log.ZWarn(ctx, "filterMuteNotificationMsgs: unmarshal MemberKickedTips failed", err)
filteredMsgs = append(filteredMsgs, msg)
continue
}
// 获取被踢出的用户ID列表
kickedUserIDs := datautil.Slice(tips.KickedUserList, func(e *sdkws.GroupMemberFullInfo) string { return e.UserID })
isKickedUser := datautil.Contain(userID, kickedUserIDs...)
// 被踢出群聊通知:允许群主、管理员和被踢出的用户本人看到
if isOwnerOrAdmin || isKickedUser {
filteredMsgs = append(filteredMsgs, msg)
log.ZDebug(ctx, "filterMuteNotificationMsgs: MemberKickedNotification allowed", "userID", userID, "kickedUserIDs", kickedUserIDs, "isOwnerOrAdmin", isOwnerOrAdmin, "isKickedUser", isKickedUser)
} else {
log.ZDebug(ctx, "filterMuteNotificationMsgs: MemberKickedNotification filtered", "userID", userID, "kickedUserIDs", kickedUserIDs, "isOwnerOrAdmin", isOwnerOrAdmin, "isKickedUser", isKickedUser)
}
} else if msg.ContentType == constant.GroupMemberInfoSetNotification {
var tips sdkws.GroupMemberInfoSetTips
if err := unmarshalNotificationContent(string(msg.Content), &tips); err != nil {
log.ZWarn(ctx, "filterMuteNotificationMsgs: unmarshal GroupMemberInfoSetTips failed", err)
filteredMsgs = append(filteredMsgs, msg)
continue
}
changedUserID := tips.ChangedUser.UserID
// 群成员信息设置通知(如背景音)只允许群主和管理员看到
if isOwnerOrAdmin {
filteredMsgs = append(filteredMsgs, msg)
log.ZDebug(ctx, "filterMuteNotificationMsgs: GroupMemberInfoSetNotification allowed", "userID", userID, "changedUserID", changedUserID, "isOwnerOrAdmin", isOwnerOrAdmin)
} else {
log.ZDebug(ctx, "filterMuteNotificationMsgs: GroupMemberInfoSetNotification filtered", "userID", userID, "changedUserID", changedUserID, "isOwnerOrAdmin", isOwnerOrAdmin)
}
} else if msg.ContentType == constant.GroupMemberSetToAdminNotification || msg.ContentType == constant.GroupMemberSetToOrdinaryUserNotification {
var tips sdkws.GroupMemberInfoSetTips
if err := unmarshalNotificationContent(string(msg.Content), &tips); err != nil {
log.ZWarn(ctx, "filterMuteNotificationMsgs: unmarshal GroupMemberInfoSetTips failed", err)
filteredMsgs = append(filteredMsgs, msg)
continue
}
changedUserID := tips.ChangedUser.UserID
// 设置为管理员/普通用户通知:允许群主、管理员和本人看到
if isOwnerOrAdmin || userID == changedUserID {
filteredMsgs = append(filteredMsgs, msg)
log.ZDebug(ctx, "filterMuteNotificationMsgs: GroupMemberSetToAdmin/OrdinaryUserNotification allowed", "userID", userID, "changedUserID", changedUserID, "isOwnerOrAdmin", isOwnerOrAdmin, "contentType", msg.ContentType)
} else {
log.ZDebug(ctx, "filterMuteNotificationMsgs: GroupMemberSetToAdmin/OrdinaryUserNotification filtered", "userID", userID, "changedUserID", changedUserID, "isOwnerOrAdmin", isOwnerOrAdmin, "contentType", msg.ContentType)
}
} else {
// 其他消息直接通过
filteredMsgs = append(filteredMsgs, msg)
}
}
return filteredMsgs
}
// unmarshalNotificationContent 解析通知消息内容
func unmarshalNotificationContent(content string, v interface{}) error {
var notificationElem sdkws.NotificationElem
if err := json.Unmarshal([]byte(content), &notificationElem); err != nil {
return err
}
return jsonutil.JsonUnmarshal([]byte(notificationElem.Detail), v)
}
// enrichRedPacketMessages 填充红包消息的领取信息和状态
func (m *msgServer) enrichRedPacketMessages(ctx context.Context, userID string, msgs []*sdkws.MsgData) []*sdkws.MsgData {
if m.redPacketReceiveDB == nil || m.redPacketDB == nil {
return msgs
}
for _, msg := range msgs {
// 只处理自定义消息类型
if msg.ContentType != constant.Custom {
continue
}
// 解析自定义消息内容
var customElem apistruct.CustomElem
if err := json.Unmarshal(msg.Content, &customElem); err != nil {
continue
}
// 检查是否为红包消息通过description字段判断二级类型
if customElem.Description != "redpacket" {
continue
}
// 解析红包消息内容从data字段中解析
var redPacketElem apistruct.RedPacketElem
if err := json.Unmarshal([]byte(customElem.Data), &redPacketElem); err != nil {
log.ZWarn(ctx, "enrichRedPacketMessages: failed to unmarshal red packet data", err, "redPacketID", redPacketElem.RedPacketID)
continue
}
// 查询红包记录
redPacket, err := m.redPacketDB.Take(ctx, redPacketElem.RedPacketID)
if err != nil {
log.ZWarn(ctx, "enrichRedPacketMessages: failed to get red packet", err, "redPacketID", redPacketElem.RedPacketID)
// 如果查询失败,保持原有状态,不填充信息
continue
}
// 填充红包状态信息
redPacketElem.Status = redPacket.Status
// 判断是否已过期(检查过期时间和状态)
now := time.Now()
isExpired := redPacket.Status == model.RedPacketStatusExpired || (redPacket.ExpireTime.Before(now) && redPacket.Status == model.RedPacketStatusActive)
redPacketElem.IsExpired = isExpired
// 判断是否已领完
isFinished := redPacket.Status == model.RedPacketStatusFinished || redPacket.RemainCount <= 0
redPacketElem.IsFinished = isFinished
// 如果已过期,更新状态
if isExpired && redPacket.Status == model.RedPacketStatusActive {
redPacket.Status = model.RedPacketStatusExpired
redPacketElem.Status = model.RedPacketStatusExpired
}
// 查询用户是否已领取
receive, err := m.redPacketReceiveDB.FindByUserAndRedPacketID(ctx, userID, redPacketElem.RedPacketID)
if err != nil {
// 如果查询失败或未找到记录,说明未领取
redPacketElem.IsReceived = false
redPacketElem.ReceiveInfo = nil
} else {
// 已领取,填充领取信息(包括金额)
redPacketElem.IsReceived = true
redPacketElem.ReceiveInfo = &apistruct.RedPacketReceiveInfo{
Amount: receive.Amount,
ReceiveTime: receive.ReceiveTime.UnixMilli(),
IsLucky: false, // 已去掉手气最佳功能,始终返回 false
}
}
// 更新自定义消息的data字段包含领取信息和状态
redPacketData := jsonutil.StructToJsonString(redPacketElem)
customElem.Data = redPacketData
// 重新序列化并更新消息内容
newContent, err := json.Marshal(customElem)
if err != nil {
log.ZWarn(ctx, "enrichRedPacketMessages: failed to marshal custom elem", err, "redPacketID", redPacketElem.RedPacketID)
continue
}
msg.Content = newContent
}
return msgs
}
func (m *msgServer) GetMaxSeq(ctx context.Context, req *sdkws.GetMaxSeqReq) (*sdkws.GetMaxSeqResp, error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
conversationIDs, err := m.ConversationLocalCache.GetConversationIDs(ctx, req.UserID)
if err != nil {
return nil, err
}
for _, conversationID := range conversationIDs {
conversationIDs = append(conversationIDs, conversationutil.GetNotificationConversationIDByConversationID(conversationID))
}
conversationIDs = append(conversationIDs, conversationutil.GetSelfNotificationConversationID(req.UserID))
log.ZDebug(ctx, "GetMaxSeq", "conversationIDs", conversationIDs)
maxSeqs, err := m.MsgDatabase.GetMaxSeqs(ctx, conversationIDs)
if err != nil {
log.ZWarn(ctx, "GetMaxSeqs error", err, "conversationIDs", conversationIDs, "maxSeqs", maxSeqs)
return nil, err
}
// avoid pulling messages from sessions with a large number of max seq values of 0
for conversationID, seq := range maxSeqs {
if seq == 0 {
delete(maxSeqs, conversationID)
}
}
resp := new(sdkws.GetMaxSeqResp)
resp.MaxSeqs = maxSeqs
return resp, nil
}
func (m *msgServer) SearchMessage(ctx context.Context, req *msg.SearchMessageReq) (resp *msg.SearchMessageResp, err error) {
// var chatLogs []*sdkws.MsgData
var chatLogs []*msg.SearchedMsgData
var total int64
resp = &msg.SearchMessageResp{}
if total, chatLogs, err = m.MsgDatabase.SearchMessage(ctx, req); err != nil {
return nil, err
}
var (
sendIDs []string
recvIDs []string
groupIDs []string
sendNameMap = make(map[string]string)
sendFaceMap = make(map[string]string)
recvMap = make(map[string]string)
groupMap = make(map[string]*sdkws.GroupInfo)
seenSendIDs = make(map[string]struct{})
seenRecvIDs = make(map[string]struct{})
seenGroupIDs = make(map[string]struct{})
)
for _, chatLog := range chatLogs {
if chatLog.MsgData.SenderNickname == "" || chatLog.MsgData.SenderFaceURL == "" {
if _, ok := seenSendIDs[chatLog.MsgData.SendID]; !ok {
seenSendIDs[chatLog.MsgData.SendID] = struct{}{}
sendIDs = append(sendIDs, chatLog.MsgData.SendID)
}
}
switch chatLog.MsgData.SessionType {
case constant.SingleChatType, constant.NotificationChatType:
if _, ok := seenRecvIDs[chatLog.MsgData.RecvID]; !ok {
seenRecvIDs[chatLog.MsgData.RecvID] = struct{}{}
recvIDs = append(recvIDs, chatLog.MsgData.RecvID)
}
case constant.WriteGroupChatType, constant.ReadGroupChatType:
if _, ok := seenGroupIDs[chatLog.MsgData.GroupID]; !ok {
seenGroupIDs[chatLog.MsgData.GroupID] = struct{}{}
groupIDs = append(groupIDs, chatLog.MsgData.GroupID)
}
}
}
// Retrieve sender and receiver information
if len(sendIDs) != 0 {
sendInfos, err := m.UserLocalCache.GetUsersInfo(ctx, sendIDs)
if err != nil {
return nil, err
}
for _, sendInfo := range sendInfos {
sendNameMap[sendInfo.UserID] = sendInfo.Nickname
sendFaceMap[sendInfo.UserID] = sendInfo.FaceURL
}
}
if len(recvIDs) != 0 {
recvInfos, err := m.UserLocalCache.GetUsersInfo(ctx, recvIDs)
if err != nil {
return nil, err
}
for _, recvInfo := range recvInfos {
recvMap[recvInfo.UserID] = recvInfo.Nickname
}
}
// Retrieve group information including member counts
if len(groupIDs) != 0 {
groupInfos, err := m.GroupLocalCache.GetGroupInfos(ctx, groupIDs)
if err != nil {
return nil, err
}
for _, groupInfo := range groupInfos {
groupMap[groupInfo.GroupID] = groupInfo
// Get actual member count
memberIDs, err := m.GroupLocalCache.GetGroupMemberIDs(ctx, groupInfo.GroupID)
if err == nil {
groupInfo.MemberCount = uint32(len(memberIDs)) // Update the member count with actual number
}
}
}
// Construct response with updated information
for _, chatLog := range chatLogs {
pbchatLog := &msg.ChatLog{}
datautil.CopyStructFields(pbchatLog, chatLog.MsgData)
pbchatLog.SendTime = chatLog.MsgData.SendTime
pbchatLog.CreateTime = chatLog.MsgData.CreateTime
if chatLog.MsgData.SenderNickname == "" {
pbchatLog.SenderNickname = sendNameMap[chatLog.MsgData.SendID]
}
if chatLog.MsgData.SenderFaceURL == "" {
pbchatLog.SenderFaceURL = sendFaceMap[chatLog.MsgData.SendID]
}
switch chatLog.MsgData.SessionType {
case constant.SingleChatType, constant.NotificationChatType:
pbchatLog.RecvNickname = recvMap[chatLog.MsgData.RecvID]
case constant.ReadGroupChatType:
groupInfo := groupMap[chatLog.MsgData.GroupID]
pbchatLog.GroupMemberCount = groupInfo.MemberCount // Reflects actual member count
pbchatLog.RecvID = groupInfo.GroupID
pbchatLog.GroupName = groupInfo.GroupName
pbchatLog.GroupOwner = groupInfo.OwnerUserID
pbchatLog.GroupType = groupInfo.GroupType
}
searchChatLog := &msg.SearchChatLog{ChatLog: pbchatLog, IsRevoked: chatLog.IsRevoked}
resp.ChatLogs = append(resp.ChatLogs, searchChatLog)
}
resp.ChatLogsNum = int32(total)
return resp, nil
}
func (m *msgServer) GetServerTime(ctx context.Context, _ *msg.GetServerTimeReq) (*msg.GetServerTimeResp, error) {
return &msg.GetServerTimeResp{ServerTime: timeutil.GetCurrentTimestampByMill()}, nil
}
func (m *msgServer) GetLastMessage(ctx context.Context, req *msg.GetLastMessageReq) (*msg.GetLastMessageResp, error) {
msgs, err := m.MsgDatabase.GetLastMessage(ctx, req.ConversationIDs, req.UserID)
if err != nil {
return nil, err
}
return &msg.GetLastMessageResp{Msgs: msgs}, nil
}

91
internal/rpc/msg/utils.go Normal file
View File

@@ -0,0 +1,91 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msg
import (
"git.imall.cloud/openim/protocol/msg"
"github.com/openimsdk/tools/errs"
"github.com/redis/go-redis/v9"
"go.mongodb.org/mongo-driver/mongo"
)
func IsNotFound(err error) bool {
switch errs.Unwrap(err) {
case redis.Nil, mongo.ErrNoDocuments:
return true
default:
return false
}
}
type activeConversations []*msg.ActiveConversation
func (s activeConversations) Len() int {
return len(s)
}
func (s activeConversations) Less(i, j int) bool {
return s[i].LastTime > s[j].LastTime
}
func (s activeConversations) Swap(i, j int) {
s[i], s[j] = s[j], s[i]
}
//type seqTime struct {
// ConversationID string
// Seq int64
// Time int64
// Unread int64
// Pinned bool
//}
//
//func (s seqTime) String() string {
// return fmt.Sprintf("<Time_%d,Unread_%d,Pinned_%t>", s.Time, s.Unread, s.Pinned)
//}
//
//type seqTimes []seqTime
//
//func (s seqTimes) Len() int {
// return len(s)
//}
//
//// Less sticky priority, unread priority, time descending
//func (s seqTimes) Less(i, j int) bool {
// iv, jv := s[i], s[j]
// if iv.Pinned && (!jv.Pinned) {
// return true
// }
// if jv.Pinned && (!iv.Pinned) {
// return false
// }
// if iv.Unread > 0 && jv.Unread == 0 {
// return true
// }
// if jv.Unread > 0 && iv.Unread == 0 {
// return false
// }
// return iv.Time > jv.Time
//}
//
//func (s seqTimes) Swap(i, j int) {
// s[i], s[j] = s[j], s[i]
//}
//
//type conversationStatus struct {
// ConversationID string
// Pinned bool
// Recv bool
//}

405
internal/rpc/msg/verify.go Normal file
View File

@@ -0,0 +1,405 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package msg
import (
"context"
"encoding/json"
"math/rand"
"regexp"
"strconv"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/servererrs"
"github.com/openimsdk/tools/utils/datautil"
"github.com/openimsdk/tools/utils/encrypt"
"github.com/openimsdk/tools/utils/timeutil"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/msg"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
)
var ExcludeContentType = []int{constant.HasReadReceipt}
// URL正则表达式匹配常见的URL格式
// 匹配以下格式的链接:
// 1. http:// 或 https:// 开头的完整URL
// 2. // 开头的协议相对URL如 //s.yam.com/MvQzr
// 3. www. 开头的链接
// 4. 直接包含域名的链接(如 s.yam.com/MvQzr、xxx.cc/csd、t.cn/AX4fYkFZ匹配包含至少一个点的域名格式后面可跟路径
// 域名格式至少包含一个点和一个2位以上的顶级域名如 xxx.com、xxx.cn、s.yam.com、xxx.cc、t.cn 等
// 注意:匹配 // 后面必须跟着域名格式,避免误匹配其他 // 开头的文本
// 修复:支持单字符域名(如 t.cn移除了点之前必须有两个字符的限制
var urlRegex = regexp.MustCompile(`(?i)(https?://[^\s<>"{}|\\^` + "`" + `\[\]]+|//[a-zA-Z0-9][a-zA-Z0-9\-\.]*\.[a-zA-Z]{2,}(/[^\s<>"{}|\\^` + "`" + `\[\]]*)?|www\.[^\s<>"{}|\\^` + "`" + `\[\]]+|[a-zA-Z0-9][a-zA-Z0-9\-\.]*\.[a-zA-Z]{2,}(/[^\s<>"{}|\\^` + "`" + `\[\]]*)?)`)
type Validator interface {
validate(pb *msg.SendMsgReq) (bool, int32, string)
}
// TextElem 用于解析文本消息内容
type TextElem struct {
Content string `json:"content"`
}
// AtElem 用于解析@消息内容
type AtElem struct {
Text string `json:"text"`
}
// checkMessageContainsLink 检测消息内容中是否包含链接
// userType: 0-普通用户不能发送链接1-特殊用户(可以发送链接)
func (m *msgServer) checkMessageContainsLink(msgData *sdkws.MsgData, userType int32) error {
// userType=1 的用户可以发送链接,不进行检测
if userType == 1 {
return nil
}
// 只检测文本类型的消息
if msgData.ContentType != constant.Text && msgData.ContentType != constant.AtText {
return nil
}
var textContent string
var err error
// 解析消息内容
if msgData.ContentType == constant.Text {
var textElem TextElem
if err = json.Unmarshal(msgData.Content, &textElem); err != nil {
// 如果解析失败,尝试直接使用字符串
textContent = string(msgData.Content)
} else {
textContent = textElem.Content
}
} else if msgData.ContentType == constant.AtText {
var atElem AtElem
if err = json.Unmarshal(msgData.Content, &atElem); err != nil {
// 如果解析失败,尝试直接使用字符串
textContent = string(msgData.Content)
} else {
textContent = atElem.Text
}
}
// 检测是否包含链接
if urlRegex.MatchString(textContent) {
return servererrs.ErrMessageContainsLink.WrapMsg("userType=0的用户不能发送包含链接的消息")
}
return nil
}
type MessageRevoked struct {
RevokerID string `json:"revokerID"`
RevokerRole int32 `json:"revokerRole"`
ClientMsgID string `json:"clientMsgID"`
RevokerNickname string `json:"revokerNickname"`
RevokeTime int64 `json:"revokeTime"`
SourceMessageSendTime int64 `json:"sourceMessageSendTime"`
SourceMessageSendID string `json:"sourceMessageSendID"`
SourceMessageSenderNickname string `json:"sourceMessageSenderNickname"`
SessionType int32 `json:"sessionType"`
Seq uint32 `json:"seq"`
}
func (m *msgServer) messageVerification(ctx context.Context, data *msg.SendMsgReq) error {
webhookCfg := m.webhookConfig()
switch data.MsgData.SessionType {
case constant.SingleChatType:
if datautil.Contain(data.MsgData.SendID, m.adminUserIDs...) {
return nil
}
if data.MsgData.ContentType <= constant.NotificationEnd &&
data.MsgData.ContentType >= constant.NotificationBegin {
return nil
}
if err := m.webhookBeforeSendSingleMsg(ctx, &webhookCfg.BeforeSendSingleMsg, data); err != nil {
return err
}
u, err := m.UserLocalCache.GetUserInfo(ctx, data.MsgData.SendID)
if err != nil {
return err
}
if authverify.CheckSystemAccount(ctx, u.AppMangerLevel) {
return nil
}
// userType=1 的用户可以发送链接和二维码,不进行检测
userType := u.GetUserType()
if userType != 1 {
// 检测userType=0的用户是否发送了包含链接的消息
if err := m.checkMessageContainsLink(data.MsgData, userType); err != nil {
return err
}
// 检测userType=0的用户是否发送了包含二维码的图片
if err := m.checkImageContainsQRCode(ctx, data.MsgData, userType); err != nil {
return err
}
}
// 单聊中:不限制文件发送权限,所有用户都可以发送文件
black, err := m.FriendLocalCache.IsBlack(ctx, data.MsgData.SendID, data.MsgData.RecvID)
if err != nil {
return err
}
if black {
return servererrs.ErrBlockedByPeer.Wrap()
}
if m.config.RpcConfig.FriendVerify {
friend, err := m.FriendLocalCache.IsFriend(ctx, data.MsgData.SendID, data.MsgData.RecvID)
if err != nil {
return err
}
if !friend {
return servererrs.ErrNotPeersFriend.Wrap()
}
return nil
}
return nil
case constant.ReadGroupChatType:
groupInfo, err := m.GroupLocalCache.GetGroupInfo(ctx, data.MsgData.GroupID)
if err != nil {
return err
}
if groupInfo.Status == constant.GroupStatusDismissed &&
data.MsgData.ContentType != constant.GroupDismissedNotification {
return servererrs.ErrDismissedAlready.Wrap()
}
// 检查是否为系统管理员,系统管理员跳过部分检查
isSystemAdmin := datautil.Contain(data.MsgData.SendID, m.adminUserIDs...)
// 通知消息类型跳过检查
if data.MsgData.ContentType <= constant.NotificationEnd &&
data.MsgData.ContentType >= constant.NotificationBegin {
return nil
}
// SuperGroup 跳过部分检查,但仍需检查文件权限
if groupInfo.GroupType == constant.SuperGroup {
// SuperGroup 也需要检查文件发送权限
if data.MsgData.ContentType == constant.File {
if isSystemAdmin {
// 系统管理员可以发送文件
return nil
}
memberIDs, err := m.GroupLocalCache.GetGroupMemberIDMap(ctx, data.MsgData.GroupID)
if err != nil {
return err
}
if _, ok := memberIDs[data.MsgData.SendID]; !ok {
return servererrs.ErrNotInGroupYet.Wrap()
}
groupMemberInfo, err := m.GroupLocalCache.GetGroupMember(ctx, data.MsgData.GroupID, data.MsgData.SendID)
if err != nil {
if errs.ErrRecordNotFound.Is(err) {
return servererrs.ErrNotInGroupYet.WrapMsg(err.Error())
}
return err
}
u, err := m.UserLocalCache.GetUserInfo(ctx, data.MsgData.SendID)
if err != nil {
return err
}
isGroupOwner := groupMemberInfo.RoleLevel == constant.GroupOwner
isGroupAdmin := groupMemberInfo.RoleLevel == constant.GroupAdmin
canSendFile := u.GetUserType() == 1 || isGroupOwner || isGroupAdmin
if !canSendFile {
return servererrs.ErrNoPermission.WrapMsg("only group owner, admin, or userType=1 can send files in group chat")
}
}
return nil
}
// 先获取用户信息用于检查userType系统管理员和userType=1的用户可以发送文件
memberIDs, err := m.GroupLocalCache.GetGroupMemberIDMap(ctx, data.MsgData.GroupID)
if err != nil {
return err
}
if _, ok := memberIDs[data.MsgData.SendID]; !ok {
return servererrs.ErrNotInGroupYet.Wrap()
}
groupMemberInfo, err := m.GroupLocalCache.GetGroupMember(ctx, data.MsgData.GroupID, data.MsgData.SendID)
if err != nil {
if errs.ErrRecordNotFound.Is(err) {
return servererrs.ErrNotInGroupYet.WrapMsg(err.Error())
}
return err
}
// 获取用户信息以获取userType
u, err := m.UserLocalCache.GetUserInfo(ctx, data.MsgData.SendID)
if err != nil {
return err
}
// 群聊中:检查文件发送权限(优先检查,确保不会被其他逻辑跳过)
// 只有群主、管理员、userType=1的用户可以发送文件
// 系统管理员也可以发送文件
if data.MsgData.ContentType == constant.File {
isGroupOwner := groupMemberInfo.RoleLevel == constant.GroupOwner
isGroupAdmin := groupMemberInfo.RoleLevel == constant.GroupAdmin
userType := u.GetUserType()
// 如果是文件消息且 userType=0 且不是群主/管理员,可能是缓存问题,尝试清除缓存并重新获取
if userType == 0 && !isGroupOwner && !isGroupAdmin && !isSystemAdmin {
// 清除本地缓存
m.UserLocalCache.DelUserInfo(ctx, data.MsgData.SendID)
// 重新获取用户信息
u, err = m.UserLocalCache.GetUserInfo(ctx, data.MsgData.SendID)
if err != nil {
return err
}
userType = u.GetUserType()
}
canSendFile := isSystemAdmin || userType == 1 || isGroupOwner || isGroupAdmin
if !canSendFile {
return servererrs.ErrNoPermission.WrapMsg("only group owner, admin, or userType=1 can send files in group chat")
}
}
if isSystemAdmin {
// 系统管理员跳过大部分检查
return nil
}
// 群聊中userType=1、群主、群管理员可以发送链接
isGroupOwner := groupMemberInfo.RoleLevel == constant.GroupOwner
isGroupAdmin := groupMemberInfo.RoleLevel == constant.GroupAdmin
canSendLink := u.GetUserType() == 1 || isGroupOwner || isGroupAdmin
// 如果不符合发送链接的条件,进行链接检测
if !canSendLink {
if err := m.checkMessageContainsLink(data.MsgData, u.GetUserType()); err != nil {
return err
}
}
// 群聊中检测userType=0的普通成员是否发送了包含二维码的图片
// userType=1、群主、群管理员可以发送包含二维码的图片不进行检测
if !canSendLink {
if err := m.checkImageContainsQRCode(ctx, data.MsgData, u.GetUserType()); err != nil {
return err
}
}
if isGroupOwner {
return nil
} else {
nowUnixMilli := time.Now().UnixMilli()
// 记录禁言检查信息,用于排查自动禁言问题
if groupMemberInfo.MuteEndTime > 0 {
muteEndTime := time.UnixMilli(groupMemberInfo.MuteEndTime)
isMuted := groupMemberInfo.MuteEndTime >= nowUnixMilli
log.ZInfo(ctx, "messageVerification: checking mute status",
"groupID", data.MsgData.GroupID,
"userID", data.MsgData.SendID,
"muteEndTimeTimestamp", groupMemberInfo.MuteEndTime,
"muteEndTime", muteEndTime.Format(time.RFC3339),
"now", time.UnixMilli(nowUnixMilli).Format(time.RFC3339),
"isMuted", isMuted,
"mutedDurationSeconds", (groupMemberInfo.MuteEndTime-nowUnixMilli)/1000,
"roleLevel", groupMemberInfo.RoleLevel)
}
if groupMemberInfo.MuteEndTime >= nowUnixMilli {
return servererrs.ErrMutedInGroup.Wrap()
}
if groupInfo.Status == constant.GroupStatusMuted && !isGroupAdmin {
return servererrs.ErrMutedGroup.Wrap()
}
}
return nil
default:
return nil
}
}
func (m *msgServer) encapsulateMsgData(msg *sdkws.MsgData) {
msg.ServerMsgID = GetMsgID(msg.SendID)
if msg.SendTime == 0 {
msg.SendTime = timeutil.GetCurrentTimestampByMill()
}
switch msg.ContentType {
case constant.Text, constant.Picture, constant.Voice, constant.Video,
constant.File, constant.AtText, constant.Merger, constant.Card,
constant.Location, constant.Custom, constant.Quote, constant.AdvancedText, constant.MarkdownText:
case constant.Revoke:
datautil.SetSwitchFromOptions(msg.Options, constant.IsUnreadCount, false)
datautil.SetSwitchFromOptions(msg.Options, constant.IsOfflinePush, false)
case constant.HasReadReceipt:
datautil.SetSwitchFromOptions(msg.Options, constant.IsConversationUpdate, false)
datautil.SetSwitchFromOptions(msg.Options, constant.IsSenderConversationUpdate, false)
datautil.SetSwitchFromOptions(msg.Options, constant.IsUnreadCount, false)
datautil.SetSwitchFromOptions(msg.Options, constant.IsOfflinePush, false)
case constant.Typing:
datautil.SetSwitchFromOptions(msg.Options, constant.IsHistory, false)
datautil.SetSwitchFromOptions(msg.Options, constant.IsPersistent, false)
datautil.SetSwitchFromOptions(msg.Options, constant.IsSenderSync, false)
datautil.SetSwitchFromOptions(msg.Options, constant.IsConversationUpdate, false)
datautil.SetSwitchFromOptions(msg.Options, constant.IsSenderConversationUpdate, false)
datautil.SetSwitchFromOptions(msg.Options, constant.IsUnreadCount, false)
datautil.SetSwitchFromOptions(msg.Options, constant.IsOfflinePush, false)
}
}
func GetMsgID(sendID string) string {
t := timeutil.GetCurrentTimeFormatted()
return encrypt.Md5(t + "-" + sendID + "-" + strconv.Itoa(rand.Int()))
}
func (m *msgServer) modifyMessageByUserMessageReceiveOpt(ctx context.Context, userID, conversationID string, sessionType int, pb *msg.SendMsgReq) (bool, error) {
opt, err := m.UserLocalCache.GetUserGlobalMsgRecvOpt(ctx, userID)
if err != nil {
return false, err
}
switch opt {
case constant.ReceiveMessage:
case constant.NotReceiveMessage:
return false, nil
case constant.ReceiveNotNotifyMessage:
if pb.MsgData.Options == nil {
pb.MsgData.Options = make(map[string]bool, 10)
}
datautil.SetSwitchFromOptions(pb.MsgData.Options, constant.IsOfflinePush, false)
return true, nil
}
singleOpt, err := m.ConversationLocalCache.GetSingleConversationRecvMsgOpt(ctx, userID, conversationID)
if errs.ErrRecordNotFound.Is(err) {
return true, nil
} else if err != nil {
return false, err
}
switch singleOpt {
case constant.ReceiveMessage:
return true, nil
case constant.NotReceiveMessage:
if datautil.Contain(int(pb.MsgData.ContentType), ExcludeContentType...) {
return true, nil
}
return false, nil
case constant.ReceiveNotNotifyMessage:
if pb.MsgData.Options == nil {
pb.MsgData.Options = make(map[string]bool, 10)
}
datautil.SetSwitchFromOptions(pb.MsgData.Options, constant.IsOfflinePush, false)
return true, nil
}
return true, nil
}

View File

@@ -0,0 +1,165 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package relation
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/convert"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/protocol/relation"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/datautil"
)
func (s *friendServer) GetPaginationBlacks(ctx context.Context, req *relation.GetPaginationBlacksReq) (resp *relation.GetPaginationBlacksResp, err error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
total, blacks, err := s.blackDatabase.FindOwnerBlacks(ctx, req.UserID, req.Pagination)
if err != nil {
return nil, err
}
resp = &relation.GetPaginationBlacksResp{}
resp.Blacks, err = convert.BlackDB2Pb(ctx, blacks, s.userClient.GetUsersInfoMap)
if err != nil {
return nil, err
}
resp.Total = int32(total)
return resp, nil
}
func (s *friendServer) IsBlack(ctx context.Context, req *relation.IsBlackReq) (*relation.IsBlackResp, error) {
if err := authverify.CheckAccessIn(ctx, req.UserID1, req.UserID2); err != nil {
return nil, err
}
in1, in2, err := s.blackDatabase.CheckIn(ctx, req.UserID1, req.UserID2)
if err != nil {
return nil, err
}
resp := &relation.IsBlackResp{}
resp.InUser1Blacks = in1
resp.InUser2Blacks = in2
return resp, nil
}
func (s *friendServer) RemoveBlack(ctx context.Context, req *relation.RemoveBlackReq) (*relation.RemoveBlackResp, error) {
if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil {
return nil, err
}
if err := s.blackDatabase.Delete(ctx, []*model.Black{{OwnerUserID: req.OwnerUserID, BlockUserID: req.BlackUserID}}); err != nil {
return nil, err
}
s.notificationSender.BlackDeletedNotification(ctx, req)
s.webhookAfterRemoveBlack(ctx, &s.config.WebhooksConfig.AfterRemoveBlack, req)
return &relation.RemoveBlackResp{}, nil
}
func (s *friendServer) AddBlack(ctx context.Context, req *relation.AddBlackReq) (*relation.AddBlackResp, error) {
if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil {
return nil, err
}
if err := s.webhookBeforeAddBlack(ctx, &s.config.WebhooksConfig.BeforeAddBlack, req); err != nil {
return nil, err
}
if err := s.userClient.CheckUser(ctx, []string{req.OwnerUserID, req.BlackUserID}); err != nil {
return nil, err
}
black := model.Black{
OwnerUserID: req.OwnerUserID,
BlockUserID: req.BlackUserID,
OperatorUserID: mcontext.GetOpUserID(ctx),
CreateTime: time.Now(),
Ex: req.Ex,
}
if err := s.blackDatabase.Create(ctx, []*model.Black{&black}); err != nil {
return nil, err
}
s.notificationSender.BlackAddedNotification(ctx, req)
return &relation.AddBlackResp{}, nil
}
func (s *friendServer) GetSpecifiedBlacks(ctx context.Context, req *relation.GetSpecifiedBlacksReq) (*relation.GetSpecifiedBlacksResp, error) {
if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil {
return nil, err
}
if len(req.UserIDList) == 0 {
return nil, errs.ErrArgs.WrapMsg("userIDList is empty")
}
if datautil.Duplicate(req.UserIDList) {
return nil, errs.ErrArgs.WrapMsg("userIDList repeated")
}
userMap, err := s.userClient.GetUsersInfoMap(ctx, req.UserIDList)
if err != nil {
return nil, err
}
blacks, err := s.blackDatabase.FindBlackInfos(ctx, req.OwnerUserID, req.UserIDList)
if err != nil {
return nil, err
}
blackMap := datautil.SliceToMap(blacks, func(e *model.Black) string {
return e.BlockUserID
})
resp := &relation.GetSpecifiedBlacksResp{
Blacks: make([]*sdkws.BlackInfo, 0, len(req.UserIDList)),
}
toPublcUser := func(userID string) *sdkws.PublicUserInfo {
v, ok := userMap[userID]
if !ok {
return nil
}
return &sdkws.PublicUserInfo{
UserID: v.UserID,
Nickname: v.Nickname,
FaceURL: v.FaceURL,
Ex: v.Ex,
UserType: v.UserType,
}
}
for _, userID := range req.UserIDList {
if black := blackMap[userID]; black != nil {
resp.Blacks = append(resp.Blacks,
&sdkws.BlackInfo{
OwnerUserID: black.OwnerUserID,
CreateTime: black.CreateTime.UnixMilli(),
BlackUserInfo: toPublcUser(userID),
AddSource: black.AddSource,
OperatorUserID: black.OperatorUserID,
Ex: black.Ex,
})
}
}
resp.Total = int32(len(resp.Blacks))
return resp, nil
}

View File

@@ -0,0 +1,169 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package relation
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/webhook"
cbapi "git.imall.cloud/openim/open-im-server-deploy/pkg/callbackstruct"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/protocol/relation"
)
func (s *friendServer) webhookAfterDeleteFriend(ctx context.Context, after *config.AfterConfig, req *relation.DeleteFriendReq) {
cbReq := &cbapi.CallbackAfterDeleteFriendReq{
CallbackCommand: cbapi.CallbackAfterDeleteFriendCommand,
OwnerUserID: req.OwnerUserID,
FriendUserID: req.FriendUserID,
}
s.webhookClient.AsyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, &cbapi.CallbackAfterDeleteFriendResp{}, after)
}
func (s *friendServer) webhookBeforeAddFriend(ctx context.Context, before *config.BeforeConfig, req *relation.ApplyToAddFriendReq) error {
return webhook.WithCondition(ctx, before, func(ctx context.Context) error {
cbReq := &cbapi.CallbackBeforeAddFriendReq{
CallbackCommand: cbapi.CallbackBeforeAddFriendCommand,
FromUserID: req.FromUserID,
ToUserID: req.ToUserID,
ReqMsg: req.ReqMsg,
Ex: req.Ex,
}
resp := &cbapi.CallbackBeforeAddFriendResp{}
if err := s.webhookClient.SyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, resp, before); err != nil {
return err
}
return nil
})
}
func (s *friendServer) webhookAfterAddFriend(ctx context.Context, after *config.AfterConfig, req *relation.ApplyToAddFriendReq) {
cbReq := &cbapi.CallbackAfterAddFriendReq{
CallbackCommand: cbapi.CallbackAfterAddFriendCommand,
FromUserID: req.FromUserID,
ToUserID: req.ToUserID,
ReqMsg: req.ReqMsg,
}
resp := &cbapi.CallbackAfterAddFriendResp{}
s.webhookClient.AsyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, resp, after)
}
func (s *friendServer) webhookAfterSetFriendRemark(ctx context.Context, after *config.AfterConfig, req *relation.SetFriendRemarkReq) {
cbReq := &cbapi.CallbackAfterSetFriendRemarkReq{
CallbackCommand: cbapi.CallbackAfterSetFriendRemarkCommand,
OwnerUserID: req.OwnerUserID,
FriendUserID: req.FriendUserID,
Remark: req.Remark,
}
resp := &cbapi.CallbackAfterSetFriendRemarkResp{}
s.webhookClient.AsyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, resp, after)
}
func (s *friendServer) webhookAfterImportFriends(ctx context.Context, after *config.AfterConfig, req *relation.ImportFriendReq) {
cbReq := &cbapi.CallbackAfterImportFriendsReq{
CallbackCommand: cbapi.CallbackAfterImportFriendsCommand,
OwnerUserID: req.OwnerUserID,
FriendUserIDs: req.FriendUserIDs,
}
resp := &cbapi.CallbackAfterImportFriendsResp{}
s.webhookClient.AsyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, resp, after)
}
func (s *friendServer) webhookAfterRemoveBlack(ctx context.Context, after *config.AfterConfig, req *relation.RemoveBlackReq) {
cbReq := &cbapi.CallbackAfterRemoveBlackReq{
CallbackCommand: cbapi.CallbackAfterRemoveBlackCommand,
OwnerUserID: req.OwnerUserID,
BlackUserID: req.BlackUserID,
}
resp := &cbapi.CallbackAfterRemoveBlackResp{}
s.webhookClient.AsyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, resp, after)
}
func (s *friendServer) webhookBeforeSetFriendRemark(ctx context.Context, before *config.BeforeConfig, req *relation.SetFriendRemarkReq) error {
return webhook.WithCondition(ctx, before, func(ctx context.Context) error {
cbReq := &cbapi.CallbackBeforeSetFriendRemarkReq{
CallbackCommand: cbapi.CallbackBeforeSetFriendRemarkCommand,
OwnerUserID: req.OwnerUserID,
FriendUserID: req.FriendUserID,
Remark: req.Remark,
}
resp := &cbapi.CallbackBeforeSetFriendRemarkResp{}
if err := s.webhookClient.SyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, resp, before); err != nil {
return err
}
if resp.Remark != "" {
req.Remark = resp.Remark
}
return nil
})
}
func (s *friendServer) webhookBeforeAddBlack(ctx context.Context, before *config.BeforeConfig, req *relation.AddBlackReq) error {
return webhook.WithCondition(ctx, before, func(ctx context.Context) error {
cbReq := &cbapi.CallbackBeforeAddBlackReq{
CallbackCommand: cbapi.CallbackBeforeAddBlackCommand,
OwnerUserID: req.OwnerUserID,
BlackUserID: req.BlackUserID,
}
resp := &cbapi.CallbackBeforeAddBlackResp{}
return s.webhookClient.SyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, resp, before)
})
}
func (s *friendServer) webhookBeforeAddFriendAgree(ctx context.Context, before *config.BeforeConfig, req *relation.RespondFriendApplyReq) error {
return webhook.WithCondition(ctx, before, func(ctx context.Context) error {
cbReq := &cbapi.CallbackBeforeAddFriendAgreeReq{
CallbackCommand: cbapi.CallbackBeforeAddFriendAgreeCommand,
FromUserID: req.FromUserID,
ToUserID: req.ToUserID,
HandleMsg: req.HandleMsg,
HandleResult: req.HandleResult,
}
resp := &cbapi.CallbackBeforeAddFriendAgreeResp{}
return s.webhookClient.SyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, resp, before)
})
}
func (s *friendServer) webhookAfterAddFriendAgree(ctx context.Context, after *config.AfterConfig, req *relation.RespondFriendApplyReq) {
cbReq := &cbapi.CallbackAfterAddFriendAgreeReq{
CallbackCommand: cbapi.CallbackAfterAddFriendAgreeCommand,
FromUserID: req.FromUserID,
ToUserID: req.ToUserID,
HandleMsg: req.HandleMsg,
HandleResult: req.HandleResult,
}
resp := &cbapi.CallbackAfterAddFriendAgreeResp{}
s.webhookClient.AsyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, resp, after)
}
func (s *friendServer) webhookBeforeImportFriends(ctx context.Context, before *config.BeforeConfig, req *relation.ImportFriendReq) error {
return webhook.WithCondition(ctx, before, func(ctx context.Context) error {
cbReq := &cbapi.CallbackBeforeImportFriendsReq{
CallbackCommand: cbapi.CallbackBeforeImportFriendsCommand,
OwnerUserID: req.OwnerUserID,
FriendUserIDs: req.FriendUserIDs,
}
resp := &cbapi.CallbackBeforeImportFriendsResp{}
if err := s.webhookClient.SyncPost(ctx, cbReq.GetCallbackCommand(), cbReq, resp, before); err != nil {
return err
}
if len(resp.FriendUserIDs) > 0 {
req.FriendUserIDs = resp.FriendUserIDs
}
return nil
})
}

View File

@@ -0,0 +1,597 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package relation
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/dbbuild"
"git.imall.cloud/openim/open-im-server-deploy/pkg/notification/common_user"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
"github.com/openimsdk/tools/mq/memamq"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/convert"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/servererrs"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/redis"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/controller"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database/mgo"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/webhook"
"git.imall.cloud/openim/open-im-server-deploy/pkg/localcache"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/relation"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/datautil"
"google.golang.org/grpc"
)
type friendServer struct {
relation.UnimplementedFriendServer
db controller.FriendDatabase
blackDatabase controller.BlackDatabase
notificationSender *FriendNotificationSender
RegisterCenter discovery.Conn
config *Config
webhookClient *webhook.Client
queue *memamq.MemoryQueue
userClient *rpcli.UserClient
}
type Config struct {
RpcConfig config.Friend
RedisConfig config.Redis
MongodbConfig config.Mongo
// ZookeeperConfig config.ZooKeeper
NotificationConfig config.Notification
Share config.Share
WebhooksConfig config.Webhooks
LocalCacheConfig config.LocalCache
Discovery config.Discovery
}
func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryRegistry, server grpc.ServiceRegistrar) error {
dbb := dbbuild.NewBuilder(&config.MongodbConfig, &config.RedisConfig)
mgocli, err := dbb.Mongo(ctx)
if err != nil {
return err
}
rdb, err := dbb.Redis(ctx)
if err != nil {
return err
}
friendMongoDB, err := mgo.NewFriendMongo(mgocli.GetDB())
if err != nil {
return err
}
friendRequestMongoDB, err := mgo.NewFriendRequestMongo(mgocli.GetDB())
if err != nil {
return err
}
blackMongoDB, err := mgo.NewBlackMongo(mgocli.GetDB())
if err != nil {
return err
}
userConn, err := client.GetConn(ctx, config.Discovery.RpcService.User)
if err != nil {
return err
}
msgConn, err := client.GetConn(ctx, config.Discovery.RpcService.Msg)
if err != nil {
return err
}
userClient := rpcli.NewUserClient(userConn)
database := controller.NewFriendDatabase(
friendMongoDB,
friendRequestMongoDB,
redis.NewFriendCacheRedis(rdb, &config.LocalCacheConfig, friendMongoDB),
mgocli.GetTx(),
)
// Initialize notification sender
notificationSender := NewFriendNotificationSender(
&config.NotificationConfig,
rpcli.NewMsgClient(msgConn),
WithRpcFunc(userClient.GetUsersInfo),
WithFriendDB(database),
)
localcache.InitLocalCache(&config.LocalCacheConfig)
// 初始化webhook配置管理器支持从数据库读取配置
var webhookClient *webhook.Client
systemConfigDB, err := mgo.NewSystemConfigMongo(mgocli.GetDB())
if err == nil {
// 如果SystemConfig数据库初始化成功使用配置管理器
webhookConfigManager := webhook.NewConfigManager(systemConfigDB, &config.WebhooksConfig)
if err := webhookConfigManager.Start(ctx); err != nil {
log.ZWarn(ctx, "failed to start webhook config manager, using default config", err)
webhookClient = webhook.NewWebhookClient(config.WebhooksConfig.URL)
} else {
webhookClient = webhook.NewWebhookClientWithManager(webhookConfigManager)
}
} else {
// 如果SystemConfig数据库初始化失败使用默认配置
log.ZWarn(ctx, "failed to init system config db, using default webhook config", err)
webhookClient = webhook.NewWebhookClient(config.WebhooksConfig.URL)
}
// Register Friend server with refactored MongoDB and Redis integrations
relation.RegisterFriendServer(server, &friendServer{
db: database,
blackDatabase: controller.NewBlackDatabase(
blackMongoDB,
redis.NewBlackCacheRedis(rdb, &config.LocalCacheConfig, blackMongoDB),
),
notificationSender: notificationSender,
RegisterCenter: client,
config: config,
webhookClient: webhookClient,
queue: memamq.NewMemoryQueue(16, 1024*1024),
userClient: userClient,
})
return nil
}
// ok.
func (s *friendServer) ApplyToAddFriend(ctx context.Context, req *relation.ApplyToAddFriendReq) (resp *relation.ApplyToAddFriendResp, err error) {
resp = &relation.ApplyToAddFriendResp{}
if err := authverify.CheckAccess(ctx, req.FromUserID); err != nil {
return nil, err
}
if req.ToUserID == req.FromUserID {
return nil, servererrs.ErrCanNotAddYourself.WrapMsg("req.ToUserID", req.ToUserID)
}
if err = s.webhookBeforeAddFriend(ctx, &s.config.WebhooksConfig.BeforeAddFriend, req); err != nil && err != servererrs.ErrCallbackContinue {
return nil, err
}
if err := s.userClient.CheckUser(ctx, []string{req.ToUserID, req.FromUserID}); err != nil {
return nil, err
}
in1, in2, err := s.db.CheckIn(ctx, req.FromUserID, req.ToUserID)
if err != nil {
return nil, err
}
if in1 && in2 {
return nil, servererrs.ErrRelationshipAlready.WrapMsg("already friends has f")
}
if err = s.db.AddFriendRequest(ctx, req.FromUserID, req.ToUserID, req.ReqMsg, req.Ex); err != nil {
return nil, err
}
s.notificationSender.FriendApplicationAddNotification(ctx, req)
s.webhookAfterAddFriend(ctx, &s.config.WebhooksConfig.AfterAddFriend, req)
return resp, nil
}
// ok.
func (s *friendServer) ImportFriends(ctx context.Context, req *relation.ImportFriendReq) (resp *relation.ImportFriendResp, err error) {
if err := authverify.CheckAdmin(ctx); err != nil {
return nil, err
}
if err := s.userClient.CheckUser(ctx, append([]string{req.OwnerUserID}, req.FriendUserIDs...)); err != nil {
return nil, err
}
if datautil.Contain(req.OwnerUserID, req.FriendUserIDs...) {
return nil, servererrs.ErrCanNotAddYourself.WrapMsg("can not add yourself")
}
if datautil.Duplicate(req.FriendUserIDs) {
return nil, errs.ErrArgs.WrapMsg("friend userID repeated")
}
if err := s.webhookBeforeImportFriends(ctx, &s.config.WebhooksConfig.BeforeImportFriends, req); err != nil && err != servererrs.ErrCallbackContinue {
return nil, err
}
if err := s.db.BecomeFriends(ctx, req.OwnerUserID, req.FriendUserIDs, constant.BecomeFriendByImport); err != nil {
return nil, err
}
for _, userID := range req.FriendUserIDs {
s.notificationSender.FriendApplicationAgreedNotification(ctx, &relation.RespondFriendApplyReq{
FromUserID: req.OwnerUserID,
ToUserID: userID,
HandleResult: constant.FriendResponseAgree,
}, false)
}
s.webhookAfterImportFriends(ctx, &s.config.WebhooksConfig.AfterImportFriends, req)
return &relation.ImportFriendResp{}, nil
}
// ok.
func (s *friendServer) RespondFriendApply(ctx context.Context, req *relation.RespondFriendApplyReq) (resp *relation.RespondFriendApplyResp, err error) {
resp = &relation.RespondFriendApplyResp{}
if err := authverify.CheckAccess(ctx, req.ToUserID); err != nil {
return nil, err
}
friendRequest := model.FriendRequest{
FromUserID: req.FromUserID,
ToUserID: req.ToUserID,
HandleMsg: req.HandleMsg,
HandleResult: req.HandleResult,
}
if req.HandleResult == constant.FriendResponseAgree {
if err := s.webhookBeforeAddFriendAgree(ctx, &s.config.WebhooksConfig.BeforeAddFriendAgree, req); err != nil && err != servererrs.ErrCallbackContinue {
return nil, err
}
err := s.db.AgreeFriendRequest(ctx, &friendRequest)
if err != nil {
return nil, err
}
s.webhookAfterAddFriendAgree(ctx, &s.config.WebhooksConfig.AfterAddFriendAgree, req)
s.notificationSender.FriendApplicationAgreedNotification(ctx, req, true)
return resp, nil
}
if req.HandleResult == constant.FriendResponseRefuse {
err := s.db.RefuseFriendRequest(ctx, &friendRequest)
if err != nil {
return nil, err
}
s.notificationSender.FriendApplicationRefusedNotification(ctx, req)
return resp, nil
}
return nil, errs.ErrArgs.WrapMsg("req.HandleResult != -1/1")
}
// ok.
func (s *friendServer) DeleteFriend(ctx context.Context, req *relation.DeleteFriendReq) (resp *relation.DeleteFriendResp, err error) {
if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil {
return nil, err
}
_, err = s.db.FindFriendsWithError(ctx, req.OwnerUserID, []string{req.FriendUserID})
if err != nil {
return nil, err
}
if err := s.db.Delete(ctx, req.OwnerUserID, []string{req.FriendUserID}); err != nil {
return nil, err
}
s.notificationSender.FriendDeletedNotification(ctx, req)
s.webhookAfterDeleteFriend(ctx, &s.config.WebhooksConfig.AfterDeleteFriend, req)
return &relation.DeleteFriendResp{}, nil
}
// ok.
func (s *friendServer) SetFriendRemark(ctx context.Context, req *relation.SetFriendRemarkReq) (resp *relation.SetFriendRemarkResp, err error) {
if err = s.webhookBeforeSetFriendRemark(ctx, &s.config.WebhooksConfig.BeforeSetFriendRemark, req); err != nil && err != servererrs.ErrCallbackContinue {
return nil, err
}
if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil {
return nil, err
}
_, err = s.db.FindFriendsWithError(ctx, req.OwnerUserID, []string{req.FriendUserID})
if err != nil {
return nil, err
}
if err := s.db.UpdateRemark(ctx, req.OwnerUserID, req.FriendUserID, req.Remark); err != nil {
return nil, err
}
s.webhookAfterSetFriendRemark(ctx, &s.config.WebhooksConfig.AfterSetFriendRemark, req)
s.notificationSender.FriendRemarkSetNotification(ctx, req.OwnerUserID, req.FriendUserID)
return &relation.SetFriendRemarkResp{}, nil
}
func (s *friendServer) GetFriendInfo(ctx context.Context, req *relation.GetFriendInfoReq) (*relation.GetFriendInfoResp, error) {
if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil {
return nil, err
}
friends, err := s.db.FindFriendsWithError(ctx, req.OwnerUserID, req.FriendUserIDs)
if err != nil {
return nil, err
}
return &relation.GetFriendInfoResp{FriendInfos: convert.FriendOnlyDB2PbOnly(friends)}, nil
}
func (s *friendServer) GetDesignatedFriends(ctx context.Context, req *relation.GetDesignatedFriendsReq) (resp *relation.GetDesignatedFriendsResp, err error) {
if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil {
return nil, err
}
resp = &relation.GetDesignatedFriendsResp{}
if datautil.Duplicate(req.FriendUserIDs) {
return nil, errs.ErrArgs.WrapMsg("friend userID repeated")
}
friends, err := s.getFriend(ctx, req.OwnerUserID, req.FriendUserIDs)
if err != nil {
return nil, err
}
return &relation.GetDesignatedFriendsResp{
FriendsInfo: friends,
}, nil
}
func (s *friendServer) getFriend(ctx context.Context, ownerUserID string, friendUserIDs []string) ([]*sdkws.FriendInfo, error) {
if len(friendUserIDs) == 0 {
return nil, nil
}
friends, err := s.db.FindFriendsWithError(ctx, ownerUserID, friendUserIDs)
if err != nil {
return nil, err
}
return convert.FriendsDB2Pb(ctx, friends, s.userClient.GetUsersInfoMap)
}
// Get the list of friend requests sent out proactively.
func (s *friendServer) GetDesignatedFriendsApply(ctx context.Context, req *relation.GetDesignatedFriendsApplyReq) (resp *relation.GetDesignatedFriendsApplyResp, err error) {
if err := authverify.CheckAccessIn(ctx, req.FromUserID, req.ToUserID); err != nil {
return nil, err
}
friendRequests, err := s.db.FindBothFriendRequests(ctx, req.FromUserID, req.ToUserID)
if err != nil {
return nil, err
}
resp = &relation.GetDesignatedFriendsApplyResp{}
resp.FriendRequests, err = convert.FriendRequestDB2Pb(ctx, friendRequests, s.getCommonUserMap)
if err != nil {
return nil, err
}
return resp, nil
}
// Get received friend requests (i.e., those initiated by others).
func (s *friendServer) GetPaginationFriendsApplyTo(ctx context.Context, req *relation.GetPaginationFriendsApplyToReq) (resp *relation.GetPaginationFriendsApplyToResp, err error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
handleResults := datautil.Slice(req.HandleResults, func(e int32) int {
return int(e)
})
total, friendRequests, err := s.db.PageFriendRequestToMe(ctx, req.UserID, handleResults, req.Pagination)
if err != nil {
return nil, err
}
resp = &relation.GetPaginationFriendsApplyToResp{}
resp.FriendRequests, err = convert.FriendRequestDB2Pb(ctx, friendRequests, s.getCommonUserMap)
if err != nil {
return nil, err
}
resp.Total = int32(total)
return resp, nil
}
func (s *friendServer) GetPaginationFriendsApplyFrom(ctx context.Context, req *relation.GetPaginationFriendsApplyFromReq) (resp *relation.GetPaginationFriendsApplyFromResp, err error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
handleResults := datautil.Slice(req.HandleResults, func(e int32) int {
return int(e)
})
total, friendRequests, err := s.db.PageFriendRequestFromMe(ctx, req.UserID, handleResults, req.Pagination)
if err != nil {
return nil, err
}
resp = &relation.GetPaginationFriendsApplyFromResp{}
resp.FriendRequests, err = convert.FriendRequestDB2Pb(ctx, friendRequests, s.getCommonUserMap)
if err != nil {
return nil, err
}
resp.Total = int32(total)
return resp, nil
}
// ok.
func (s *friendServer) IsFriend(ctx context.Context, req *relation.IsFriendReq) (resp *relation.IsFriendResp, err error) {
if err := authverify.CheckAccessIn(ctx, req.UserID1, req.UserID2); err != nil {
return nil, err
}
resp = &relation.IsFriendResp{}
resp.InUser1Friends, resp.InUser2Friends, err = s.db.CheckIn(ctx, req.UserID1, req.UserID2)
if err != nil {
return nil, err
}
return resp, nil
}
func (s *friendServer) GetPaginationFriends(ctx context.Context, req *relation.GetPaginationFriendsReq) (resp *relation.GetPaginationFriendsResp, err error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
total, friends, err := s.db.PageOwnerFriends(ctx, req.UserID, req.Pagination)
if err != nil {
return nil, err
}
resp = &relation.GetPaginationFriendsResp{}
resp.FriendsInfo, err = convert.FriendsDB2Pb(ctx, friends, s.userClient.GetUsersInfoMap)
if err != nil {
return nil, err
}
resp.Total = int32(total)
return resp, nil
}
func (s *friendServer) GetFriendIDs(ctx context.Context, req *relation.GetFriendIDsReq) (resp *relation.GetFriendIDsResp, err error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
resp = &relation.GetFriendIDsResp{}
resp.FriendIDs, err = s.db.FindFriendUserIDs(ctx, req.UserID)
if err != nil {
return nil, err
}
return resp, nil
}
func (s *friendServer) GetSpecifiedFriendsInfo(ctx context.Context, req *relation.GetSpecifiedFriendsInfoReq) (*relation.GetSpecifiedFriendsInfoResp, error) {
if len(req.UserIDList) == 0 {
return nil, errs.ErrArgs.WrapMsg("userIDList is empty")
}
if datautil.Duplicate(req.UserIDList) {
return nil, errs.ErrArgs.WrapMsg("userIDList repeated")
}
if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil {
return nil, err
}
userMap, err := s.userClient.GetUsersInfoMap(ctx, req.UserIDList)
if err != nil {
return nil, err
}
friends, err := s.db.FindFriendsWithError(ctx, req.OwnerUserID, req.UserIDList)
if err != nil {
return nil, err
}
blacks, err := s.blackDatabase.FindBlackInfos(ctx, req.OwnerUserID, req.UserIDList)
if err != nil {
return nil, err
}
friendMap := datautil.SliceToMap(friends, func(e *model.Friend) string {
return e.FriendUserID
})
blackMap := datautil.SliceToMap(blacks, func(e *model.Black) string {
return e.BlockUserID
})
resp := &relation.GetSpecifiedFriendsInfoResp{
Infos: make([]*relation.GetSpecifiedFriendsInfoInfo, 0, len(req.UserIDList)),
}
for _, userID := range req.UserIDList {
user := userMap[userID]
if user == nil {
continue
}
var friendInfo *sdkws.FriendInfo
if friend := friendMap[userID]; friend != nil {
friendInfo = &sdkws.FriendInfo{
OwnerUserID: friend.OwnerUserID,
Remark: friend.Remark,
CreateTime: friend.CreateTime.UnixMilli(),
AddSource: friend.AddSource,
OperatorUserID: friend.OperatorUserID,
Ex: friend.Ex,
IsPinned: friend.IsPinned,
}
}
var blackInfo *sdkws.BlackInfo
if black := blackMap[userID]; black != nil {
blackInfo = &sdkws.BlackInfo{
OwnerUserID: black.OwnerUserID,
CreateTime: black.CreateTime.UnixMilli(),
AddSource: black.AddSource,
OperatorUserID: black.OperatorUserID,
Ex: black.Ex,
}
}
resp.Infos = append(resp.Infos, &relation.GetSpecifiedFriendsInfoInfo{
UserInfo: user,
FriendInfo: friendInfo,
BlackInfo: blackInfo,
})
}
return resp, nil
}
func (s *friendServer) UpdateFriends(ctx context.Context, req *relation.UpdateFriendsReq) (*relation.UpdateFriendsResp, error) {
if len(req.FriendUserIDs) == 0 {
return nil, errs.ErrArgs.WrapMsg("friendIDList is empty")
}
if datautil.Duplicate(req.FriendUserIDs) {
return nil, errs.ErrArgs.WrapMsg("friendIDList repeated")
}
if err := authverify.CheckAccess(ctx, req.OwnerUserID); err != nil {
return nil, err
}
_, err := s.db.FindFriendsWithError(ctx, req.OwnerUserID, req.FriendUserIDs)
if err != nil {
return nil, err
}
val := make(map[string]any)
if req.IsPinned != nil {
val["is_pinned"] = req.IsPinned.Value
}
if req.Remark != nil {
val["remark"] = req.Remark.Value
}
if req.Ex != nil {
val["ex"] = req.Ex.Value
}
if err = s.db.UpdateFriends(ctx, req.OwnerUserID, req.FriendUserIDs, val); err != nil {
return nil, err
}
resp := &relation.UpdateFriendsResp{}
s.notificationSender.FriendsInfoUpdateNotification(ctx, req.OwnerUserID, req.FriendUserIDs)
return resp, nil
}
func (s *friendServer) GetSelfUnhandledApplyCount(ctx context.Context, req *relation.GetSelfUnhandledApplyCountReq) (*relation.GetSelfUnhandledApplyCountResp, error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
count, err := s.db.GetUnhandledCount(ctx, req.UserID, req.Time)
if err != nil {
return nil, err
}
return &relation.GetSelfUnhandledApplyCountResp{
Count: count,
}, nil
}
func (s *friendServer) getCommonUserMap(ctx context.Context, userIDs []string) (map[string]common_user.CommonUser, error) {
users, err := s.userClient.GetUsersInfo(ctx, userIDs)
if err != nil {
return nil, err
}
return datautil.SliceToMapAny(users, func(e *sdkws.UserInfo) (string, common_user.CommonUser) {
return e.UserID, e
}), nil
}

View File

@@ -0,0 +1,303 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package relation
import (
"context"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
"git.imall.cloud/openim/protocol/msg"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/datautil"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/versionctx"
relationtb "git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/convert"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/controller"
"git.imall.cloud/openim/open-im-server-deploy/pkg/notification"
"git.imall.cloud/openim/open-im-server-deploy/pkg/notification/common_user"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/relation"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/mcontext"
)
type FriendNotificationSender struct {
*notification.NotificationSender
// Target not found err
getUsersInfo func(ctx context.Context, userIDs []string) ([]common_user.CommonUser, error)
// db controller
db controller.FriendDatabase
}
type friendNotificationSenderOptions func(*FriendNotificationSender)
func WithFriendDB(db controller.FriendDatabase) friendNotificationSenderOptions {
return func(s *FriendNotificationSender) {
s.db = db
}
}
func WithDBFunc(fn func(ctx context.Context, userIDs []string) (users []*relationtb.User, err error)) friendNotificationSenderOptions {
return func(s *FriendNotificationSender) {
f := func(ctx context.Context, userIDs []string) (result []common_user.CommonUser, err error) {
users, err := fn(ctx, userIDs)
if err != nil {
return nil, err
}
for _, user := range users {
result = append(result, user)
}
return result, nil
}
s.getUsersInfo = f
}
}
func WithRpcFunc(fn func(ctx context.Context, userIDs []string) ([]*sdkws.UserInfo, error)) friendNotificationSenderOptions {
return func(s *FriendNotificationSender) {
f := func(ctx context.Context, userIDs []string) (result []common_user.CommonUser, err error) {
users, err := fn(ctx, userIDs)
if err != nil {
return nil, err
}
for _, user := range users {
result = append(result, user)
}
return result, err
}
s.getUsersInfo = f
}
}
func NewFriendNotificationSender(conf *config.Notification, msgClient *rpcli.MsgClient, opts ...friendNotificationSenderOptions) *FriendNotificationSender {
f := &FriendNotificationSender{
NotificationSender: notification.NewNotificationSender(conf, notification.WithRpcClient(func(ctx context.Context, req *msg.SendMsgReq) (*msg.SendMsgResp, error) {
return msgClient.SendMsg(ctx, req)
})),
}
for _, opt := range opts {
opt(f)
}
return f
}
func (f *FriendNotificationSender) getUsersInfoMap(ctx context.Context, userIDs []string) (map[string]*sdkws.UserInfo, error) {
users, err := f.getUsersInfo(ctx, userIDs)
if err != nil {
return nil, err
}
result := make(map[string]*sdkws.UserInfo)
for _, user := range users {
result[user.GetUserID()] = user.(*sdkws.UserInfo)
}
return result, nil
}
//nolint:unused
func (f *FriendNotificationSender) getFromToUserNickname(ctx context.Context, fromUserID, toUserID string) (string, string, error) {
users, err := f.getUsersInfoMap(ctx, []string{fromUserID, toUserID})
if err != nil {
return "", "", nil
}
return users[fromUserID].Nickname, users[toUserID].Nickname, nil
}
func (f *FriendNotificationSender) UserInfoUpdatedNotification(ctx context.Context, changedUserID string) {
tips := sdkws.UserInfoUpdatedTips{UserID: changedUserID}
f.Notification(ctx, mcontext.GetOpUserID(ctx), changedUserID, constant.UserInfoUpdatedNotification, &tips)
}
func (f *FriendNotificationSender) getCommonUserMap(ctx context.Context, userIDs []string) (map[string]common_user.CommonUser, error) {
users, err := f.getUsersInfo(ctx, userIDs)
if err != nil {
return nil, err
}
return datautil.SliceToMap(users, func(e common_user.CommonUser) string {
return e.GetUserID()
}), nil
}
func (f *FriendNotificationSender) getFriendRequests(ctx context.Context, fromUserID, toUserID string) (*sdkws.FriendRequest, error) {
if f.db == nil {
return nil, errs.ErrInternalServer.WithDetail("db is nil")
}
friendRequests, err := f.db.FindBothFriendRequests(ctx, fromUserID, toUserID)
if err != nil {
return nil, err
}
requests, err := convert.FriendRequestDB2Pb(ctx, friendRequests, f.getCommonUserMap)
if err != nil {
return nil, err
}
for _, request := range requests {
if request.FromUserID == fromUserID && request.ToUserID == toUserID {
return request, nil
}
}
return nil, errs.ErrRecordNotFound.WrapMsg("friend request not found", "fromUserID", fromUserID, "toUserID", toUserID)
}
func (f *FriendNotificationSender) FriendApplicationAddNotification(ctx context.Context, req *relation.ApplyToAddFriendReq) {
request, err := f.getFriendRequests(ctx, req.FromUserID, req.ToUserID)
if err != nil {
log.ZError(ctx, "FriendApplicationAddNotification get friend request", err, "fromUserID", req.FromUserID, "toUserID", req.ToUserID)
return
}
tips := sdkws.FriendApplicationTips{
FromToUserID: &sdkws.FromToUserID{
FromUserID: req.FromUserID,
ToUserID: req.ToUserID,
},
Request: request,
}
f.Notification(ctx, req.FromUserID, req.ToUserID, constant.FriendApplicationNotification, &tips)
}
func (f *FriendNotificationSender) FriendApplicationAgreedNotification(ctx context.Context, req *relation.RespondFriendApplyReq, checkReq bool) {
var (
request *sdkws.FriendRequest
err error
)
if checkReq {
request, err = f.getFriendRequests(ctx, req.FromUserID, req.ToUserID)
if err != nil {
log.ZError(ctx, "FriendApplicationAgreedNotification get friend request", err, "fromUserID", req.FromUserID, "toUserID", req.ToUserID)
return
}
}
tips := sdkws.FriendApplicationApprovedTips{
FromToUserID: &sdkws.FromToUserID{
FromUserID: req.FromUserID,
ToUserID: req.ToUserID,
},
HandleMsg: req.HandleMsg,
Request: request,
}
f.Notification(ctx, req.ToUserID, req.FromUserID, constant.FriendApplicationApprovedNotification, &tips)
}
func (f *FriendNotificationSender) FriendApplicationRefusedNotification(ctx context.Context, req *relation.RespondFriendApplyReq) {
request, err := f.getFriendRequests(ctx, req.FromUserID, req.ToUserID)
if err != nil {
log.ZError(ctx, "FriendApplicationRefusedNotification get friend request", err, "fromUserID", req.FromUserID, "toUserID", req.ToUserID)
return
}
tips := sdkws.FriendApplicationRejectedTips{
FromToUserID: &sdkws.FromToUserID{
FromUserID: req.FromUserID,
ToUserID: req.ToUserID,
},
HandleMsg: req.HandleMsg,
Request: request,
}
f.Notification(ctx, req.ToUserID, req.FromUserID, constant.FriendApplicationRejectedNotification, &tips)
}
//func (f *FriendNotificationSender) FriendAddedNotification(ctx context.Context, operationID, opUserID, fromUserID, toUserID string) error {
// tips := sdkws.FriendAddedTips{Friend: &sdkws.FriendInfo{}, OpUser: &sdkws.PublicUserInfo{}}
// user, err := f.getUsersInfo(ctx, []string{opUserID})
// if err != nil {
// return err
// }
// tips.OpUser.UserID = user[0].GetUserID()
// tips.OpUser.Ex = user[0].GetEx()
// tips.OpUser.Nickname = user[0].GetNickname()
// tips.OpUser.FaceURL = user[0].GetFaceURL()
// friends, err := f.db.FindFriendsWithError(ctx, fromUserID, []string{toUserID})
// if err != nil {
// return err
// }
// tips.Friend, err = convert.FriendDB2Pb(ctx, friends[0], f.getUsersInfoMap)
// if err != nil {
// return err
// }
// f.Notification(ctx, fromUserID, toUserID, constant.FriendAddedNotification, &tips)
// return nil
//}
func (f *FriendNotificationSender) FriendDeletedNotification(ctx context.Context, req *relation.DeleteFriendReq) {
tips := sdkws.FriendDeletedTips{FromToUserID: &sdkws.FromToUserID{
FromUserID: req.OwnerUserID,
ToUserID: req.FriendUserID,
}}
f.Notification(ctx, req.OwnerUserID, req.FriendUserID, constant.FriendDeletedNotification, &tips)
}
func (f *FriendNotificationSender) setVersion(ctx context.Context, version *uint64, versionID *string, collName string, id string) {
versions := versionctx.GetVersionLog(ctx).Get()
for _, coll := range versions {
if coll.Name == collName && coll.Doc.DID == id {
*version = uint64(coll.Doc.Version)
*versionID = coll.Doc.ID.Hex()
return
}
}
}
func (f *FriendNotificationSender) setSortVersion(ctx context.Context, version *uint64, versionID *string, collName string, id string, sortVersion *uint64) {
versions := versionctx.GetVersionLog(ctx).Get()
for _, coll := range versions {
if coll.Name == collName && coll.Doc.DID == id {
*version = uint64(coll.Doc.Version)
*versionID = coll.Doc.ID.Hex()
for _, elem := range coll.Doc.Logs {
if elem.EID == relationtb.VersionSortChangeID {
*sortVersion = uint64(elem.Version)
}
}
}
}
}
func (f *FriendNotificationSender) FriendRemarkSetNotification(ctx context.Context, fromUserID, toUserID string) {
tips := sdkws.FriendInfoChangedTips{FromToUserID: &sdkws.FromToUserID{}}
tips.FromToUserID.FromUserID = fromUserID
tips.FromToUserID.ToUserID = toUserID
f.setSortVersion(ctx, &tips.FriendVersion, &tips.FriendVersionID, database.FriendVersionName, toUserID, &tips.FriendSortVersion)
f.Notification(ctx, fromUserID, toUserID, constant.FriendRemarkSetNotification, &tips)
}
func (f *FriendNotificationSender) FriendsInfoUpdateNotification(ctx context.Context, toUserID string, friendIDs []string) {
tips := sdkws.FriendsInfoUpdateTips{FromToUserID: &sdkws.FromToUserID{}}
tips.FromToUserID.ToUserID = toUserID
tips.FriendIDs = friendIDs
f.Notification(ctx, toUserID, toUserID, constant.FriendsInfoUpdateNotification, &tips)
}
func (f *FriendNotificationSender) BlackAddedNotification(ctx context.Context, req *relation.AddBlackReq) {
tips := sdkws.BlackAddedTips{FromToUserID: &sdkws.FromToUserID{}}
tips.FromToUserID.FromUserID = req.OwnerUserID
tips.FromToUserID.ToUserID = req.BlackUserID
f.Notification(ctx, req.OwnerUserID, req.BlackUserID, constant.BlackAddedNotification, &tips)
}
func (f *FriendNotificationSender) BlackDeletedNotification(ctx context.Context, req *relation.RemoveBlackReq) {
blackDeletedTips := sdkws.BlackDeletedTips{FromToUserID: &sdkws.FromToUserID{
FromUserID: req.OwnerUserID,
ToUserID: req.BlackUserID,
}}
f.Notification(ctx, req.OwnerUserID, req.BlackUserID, constant.BlackDeletedNotification, &blackDeletedTips)
}
func (f *FriendNotificationSender) FriendInfoUpdatedNotification(ctx context.Context, changedUserID string, needNotifiedUserID string) {
tips := sdkws.UserInfoUpdatedTips{UserID: changedUserID}
f.Notification(ctx, mcontext.GetOpUserID(ctx), needNotifiedUserID, constant.FriendInfoUpdatedNotification, &tips)
}

View File

@@ -0,0 +1,108 @@
package relation
import (
"context"
"slices"
"git.imall.cloud/openim/open-im-server-deploy/pkg/util/hashutil"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/log"
"git.imall.cloud/openim/open-im-server-deploy/internal/rpc/incrversion"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/protocol/relation"
)
func (s *friendServer) NotificationUserInfoUpdate(ctx context.Context, req *relation.NotificationUserInfoUpdateReq) (*relation.NotificationUserInfoUpdateResp, error) {
userIDs, err := s.db.FindFriendUserIDs(ctx, req.UserID)
if err != nil {
return nil, err
}
if len(userIDs) > 0 {
friendUserIDs := []string{req.UserID}
noCancelCtx := context.WithoutCancel(ctx)
err := s.queue.PushCtx(ctx, func() {
for _, userID := range userIDs {
if err := s.db.OwnerIncrVersion(noCancelCtx, userID, friendUserIDs, model.VersionStateUpdate); err != nil {
log.ZError(ctx, "OwnerIncrVersion", err, "userID", userID, "friendUserIDs", friendUserIDs)
}
}
for _, userID := range userIDs {
s.notificationSender.FriendInfoUpdatedNotification(noCancelCtx, req.UserID, userID)
}
})
if err != nil {
log.ZError(ctx, "NotificationUserInfoUpdate timeout", err, "userID", req.UserID)
}
}
return &relation.NotificationUserInfoUpdateResp{}, nil
}
func (s *friendServer) GetFullFriendUserIDs(ctx context.Context, req *relation.GetFullFriendUserIDsReq) (*relation.GetFullFriendUserIDsResp, error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
vl, err := s.db.FindMaxFriendVersionCache(ctx, req.UserID)
if err != nil {
return nil, err
}
userIDs, err := s.db.FindFriendUserIDs(ctx, req.UserID)
if err != nil {
return nil, err
}
idHash := hashutil.IdHash(userIDs)
if req.IdHash == idHash {
userIDs = nil
}
return &relation.GetFullFriendUserIDsResp{
Version: idHash,
VersionID: vl.ID.Hex(),
Equal: req.IdHash == idHash,
UserIDs: userIDs,
}, nil
}
func (s *friendServer) GetIncrementalFriends(ctx context.Context, req *relation.GetIncrementalFriendsReq) (*relation.GetIncrementalFriendsResp, error) {
if err := authverify.CheckAccess(ctx, req.UserID); err != nil {
return nil, err
}
var sortVersion uint64
opt := incrversion.Option[*sdkws.FriendInfo, relation.GetIncrementalFriendsResp]{
Ctx: ctx,
VersionKey: req.UserID,
VersionID: req.VersionID,
VersionNumber: req.Version,
Version: func(ctx context.Context, ownerUserID string, version uint, limit int) (*model.VersionLog, error) {
vl, err := s.db.FindFriendIncrVersion(ctx, ownerUserID, version, limit)
if err != nil {
return nil, err
}
vl.Logs = slices.DeleteFunc(vl.Logs, func(elem model.VersionLogElem) bool {
if elem.EID == model.VersionSortChangeID {
vl.LogLen--
sortVersion = uint64(elem.Version)
return true
}
return false
})
return vl, nil
},
CacheMaxVersion: s.db.FindMaxFriendVersionCache,
Find: func(ctx context.Context, ids []string) ([]*sdkws.FriendInfo, error) {
return s.getFriend(ctx, req.UserID, ids)
},
Resp: func(version *model.VersionLog, deleteIds []string, insertList, updateList []*sdkws.FriendInfo, full bool) *relation.GetIncrementalFriendsResp {
return &relation.GetIncrementalFriendsResp{
VersionID: version.ID.Hex(),
Version: uint64(version.Version),
Full: full,
Delete: deleteIds,
Insert: insertList,
Update: updateList,
SortVersion: sortVersion,
}
},
}
return opt.Build()
}

233
internal/rpc/third/log.go Normal file
View File

@@ -0,0 +1,233 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package third
import (
"context"
"crypto/rand"
"net/url"
"strings"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/servererrs"
relationtb "git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/third"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/datautil"
)
func genLogID() string {
const dataLen = 10
data := make([]byte, dataLen)
rand.Read(data)
chars := []byte("0123456789")
for i := 0; i < len(data); i++ {
if i == 0 {
data[i] = chars[1:][data[i]%9]
} else {
data[i] = chars[data[i]%10]
}
}
return string(data)
}
// extractKeyFromLogURL 从日志URL中提取S3的key
// URL格式: https://s3.jizhying.com/images/openim/data/hash/{hash}?...
// 或: https://chatall.oss-ap-southeast-1.aliyuncs.com/openim%2Fdata%2Fhash%2F{hash}
// key格式: openim/data/hash/{hash}不包含bucket名称
// bucket名称在URL路径的第一段如images需要去掉
func extractKeyFromLogURL(logURL string, bucketName string) string {
if logURL == "" {
return ""
}
parsedURL, err := url.Parse(logURL)
if err != nil {
return ""
}
// 获取路径部分,去掉开头的'/'
path := strings.TrimPrefix(parsedURL.Path, "/")
if path == "" {
return ""
}
// 如果配置了bucket名称且路径以bucket名称开头则去掉bucket名称前缀
if bucketName != "" && strings.HasPrefix(path, bucketName+"/") {
path = strings.TrimPrefix(path, bucketName+"/")
} else {
// 如果没有匹配到bucket名称尝试去掉路径的第一段可能是bucket名称
// 这种情况下假设路径的第一段是bucket名称
parts := strings.SplitN(path, "/", 2)
if len(parts) > 1 {
path = parts[1]
}
}
// URL.Path已经是解码后的路径所以直接返回即可
return path
}
func (t *thirdServer) UploadLogs(ctx context.Context, req *third.UploadLogsReq) (*third.UploadLogsResp, error) {
var dbLogs []*relationtb.Log
userID := mcontext.GetOpUserID(ctx)
platform := constant.PlatformID2Name[int(req.Platform)]
for _, fileURL := range req.FileURLs {
log := relationtb.Log{
Platform: platform,
UserID: userID,
CreateTime: time.Now(),
Url: fileURL.URL,
FileName: fileURL.Filename,
AppFramework: req.AppFramework,
Version: req.Version,
Ex: req.Ex,
}
for i := 0; i < 20; i++ {
id := genLogID()
logs, err := t.thirdDatabase.GetLogs(ctx, []string{id}, "")
if err != nil {
return nil, err
}
if len(logs) == 0 {
log.LogID = id
break
}
}
if log.LogID == "" {
return nil, servererrs.ErrData.WrapMsg("Log id gen error")
}
dbLogs = append(dbLogs, &log)
}
err := t.thirdDatabase.UploadLogs(ctx, dbLogs)
if err != nil {
return nil, err
}
return &third.UploadLogsResp{}, nil
}
func (t *thirdServer) DeleteLogs(ctx context.Context, req *third.DeleteLogsReq) (*third.DeleteLogsResp, error) {
if err := authverify.CheckAdmin(ctx); err != nil {
return nil, err
}
userID := ""
logs, err := t.thirdDatabase.GetLogs(ctx, req.LogIDs, userID)
if err != nil {
return nil, err
}
var logIDs []string
for _, log := range logs {
logIDs = append(logIDs, log.LogID)
}
if ids := datautil.Single(req.LogIDs, logIDs); len(ids) > 0 {
return nil, errs.ErrRecordNotFound.WrapMsg("logIDs not found", "logIDs", ids)
}
// 在删除日志记录前先删除对应的S3文件
engine := t.config.RpcConfig.Object.Enable
if engine != "" && t.s3 != nil {
// 获取bucket名称从minio配置中
bucketName := ""
if engine == "minio" {
bucketName = t.config.MinioConfig.Bucket
}
for _, logRecord := range logs {
if logRecord.Url == "" {
continue
}
// 从URL中提取S3的key不包含bucket名称
key := extractKeyFromLogURL(logRecord.Url, bucketName)
if key == "" {
log.ZDebug(ctx, "DeleteLogs: cannot extract key from URL, skipping S3 deletion", "logID", logRecord.LogID, "url", logRecord.Url)
continue
}
// 直接使用key删除S3文件
log.ZInfo(ctx, "DeleteLogs: attempting to delete S3 file", "logID", logRecord.LogID, "url", logRecord.Url, "key", key, "bucket", bucketName, "engine", engine)
if err := t.s3.DeleteObject(ctx, key); err != nil {
// S3文件删除失败返回错误不删除数据库记录
log.ZError(ctx, "DeleteLogs: S3 file delete failed", err, "logID", logRecord.LogID, "url", logRecord.Url, "key", key, "bucket", bucketName, "engine", engine)
return nil, errs.WrapMsg(err, "failed to delete S3 file for log", "logID", logRecord.LogID, "url", logRecord.Url, "key", key)
}
log.ZInfo(ctx, "DeleteLogs: S3 file delete command executed successfully", "logID", logRecord.LogID, "url", logRecord.Url, "key", key, "bucket", bucketName, "engine", engine)
}
}
err = t.thirdDatabase.DeleteLogs(ctx, req.LogIDs, userID)
if err != nil {
return nil, err
}
return &third.DeleteLogsResp{}, nil
}
func dbToPbLogInfos(logs []*relationtb.Log) []*third.LogInfo {
db2pbForLogInfo := func(log *relationtb.Log) *third.LogInfo {
return &third.LogInfo{
Filename: log.FileName,
UserID: log.UserID,
Platform: log.Platform,
Url: log.Url,
CreateTime: log.CreateTime.UnixMilli(),
LogID: log.LogID,
SystemType: log.SystemType,
Version: log.Version,
Ex: log.Ex,
}
}
return datautil.Slice(logs, db2pbForLogInfo)
}
func (t *thirdServer) SearchLogs(ctx context.Context, req *third.SearchLogsReq) (*third.SearchLogsResp, error) {
if err := authverify.CheckAdmin(ctx); err != nil {
return nil, err
}
var (
resp third.SearchLogsResp
userIDs []string
)
if req.StartTime > req.EndTime {
return nil, errs.ErrArgs.WrapMsg("startTime>endTime")
}
if req.StartTime == 0 && req.EndTime == 0 {
t := time.Date(2019, time.January, 1, 0, 0, 0, 0, time.UTC)
timestampMills := t.UnixNano() / int64(time.Millisecond)
req.StartTime = timestampMills
req.EndTime = time.Now().UnixNano() / int64(time.Millisecond)
}
total, logs, err := t.thirdDatabase.SearchLogs(ctx, req.Keyword, time.UnixMilli(req.StartTime), time.UnixMilli(req.EndTime), req.Pagination)
if err != nil {
return nil, err
}
pbLogs := dbToPbLogInfos(logs)
for _, log := range logs {
userIDs = append(userIDs, log.UserID)
}
userMap, err := t.userClient.GetUsersInfoMap(ctx, userIDs)
if err != nil {
return nil, err
}
for _, pbLog := range pbLogs {
if user, ok := userMap[pbLog.UserID]; ok {
pbLog.Nickname = user.Nickname
}
}
resp.LogsInfos = pbLogs
resp.Total = uint32(total)
return &resp, nil
}

446
internal/rpc/third/r2.go Normal file
View File

@@ -0,0 +1,446 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package third
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
"github.com/aws/aws-sdk-go-v2/credentials"
aws3 "github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/openimsdk/tools/s3"
)
const (
minPartSize int64 = 1024 * 1024 * 5 // 5MB
maxPartSize int64 = 1024 * 1024 * 1024 * 5 // 5GB
maxNumSize int64 = 10000
)
type R2Config struct {
Endpoint string
Region string
Bucket string
AccessKeyID string
SecretAccessKey string
SessionToken string
}
// NewR2 创建支持 Cloudflare R2 的 S3 客户端
func NewR2(conf R2Config) (*R2, error) {
if conf.Endpoint == "" {
return nil, errors.New("endpoint is required for R2")
}
// 创建 HTTP 客户端,设置合理的超时
httpClient := &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
},
}
cfg := aws.Config{
Region: conf.Region,
Credentials: credentials.NewStaticCredentialsProvider(conf.AccessKeyID, conf.SecretAccessKey, conf.SessionToken),
HTTPClient: httpClient,
}
// 创建 S3 客户端启用路径风格访问R2 要求)并设置自定义 endpoint
client := aws3.NewFromConfig(cfg, func(o *aws3.Options) {
o.BaseEndpoint = aws.String(conf.Endpoint)
o.UsePathStyle = true
})
r2 := &R2{
bucket: conf.Bucket,
client: client,
presign: aws3.NewPresignClient(client),
}
// 测试连接:尝试列出 bucket验证 bucket 存在且有权限),设置 5 秒超时
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
fmt.Printf("[R2] Testing connection to bucket '%s' at endpoint '%s'...\n", conf.Bucket, conf.Endpoint)
_, err := client.ListObjectsV2(ctx, &aws3.ListObjectsV2Input{
Bucket: aws.String(conf.Bucket),
MaxKeys: aws.Int32(1),
})
if err != nil {
// 详细的错误信息
var respErr *awshttp.ResponseError
if errors.As(err, &respErr) {
fmt.Printf("[R2] Bucket verification HTTP error:\n")
fmt.Printf(" Status Code: %d\n", respErr.Response.StatusCode)
fmt.Printf(" Status: %s\n", respErr.Response.Status)
}
fmt.Printf("[R2] Warning: failed to verify R2 bucket '%s' at endpoint '%s': %v\n", conf.Bucket, conf.Endpoint, err)
fmt.Printf("[R2] Please ensure:\n")
fmt.Printf(" 1. Bucket '%s' exists in your R2 account\n", conf.Bucket)
fmt.Printf(" 2. API credentials have correct permissions (Object Read & Write)\n")
fmt.Printf(" 3. Account ID in endpoint matches your R2 account\n")
} else {
fmt.Printf("[R2] Successfully connected to bucket '%s'\n", conf.Bucket)
}
return r2, nil
}
type R2 struct {
bucket string
client *aws3.Client
presign *aws3.PresignClient
}
func (r *R2) Engine() string {
return "aws"
}
func (r *R2) PartLimit() (*s3.PartLimit, error) {
return &s3.PartLimit{
MinPartSize: minPartSize,
MaxPartSize: maxPartSize,
MaxNumSize: maxNumSize,
}, nil
}
func (r *R2) formatETag(etag string) string {
return strings.Trim(etag, `"`)
}
func (r *R2) PartSize(ctx context.Context, size int64) (int64, error) {
if size <= 0 {
return 0, errors.New("size must be greater than 0")
}
if size > maxPartSize*maxNumSize {
return 0, fmt.Errorf("size must be less than the maximum allowed limit")
}
if size <= minPartSize*maxNumSize {
return minPartSize, nil
}
partSize := size / maxNumSize
if size%maxNumSize != 0 {
partSize++
}
return partSize, nil
}
func (r *R2) IsNotFound(err error) bool {
var respErr *awshttp.ResponseError
if !errors.As(err, &respErr) {
return false
}
if respErr == nil || respErr.Response == nil {
return false
}
return respErr.Response.StatusCode == http.StatusNotFound
}
func (r *R2) PresignedPutObject(ctx context.Context, name string, expire time.Duration, opt *s3.PutOption) (*s3.PresignedPutResult, error) {
res, err := r.presign.PresignPutObject(ctx, &aws3.PutObjectInput{
Bucket: aws.String(r.bucket),
Key: aws.String(name),
}, aws3.WithPresignExpires(expire), withDisableHTTPPresignerHeaderV4(nil))
if err != nil {
return nil, err
}
return &s3.PresignedPutResult{URL: res.URL}, nil
}
func (r *R2) DeleteObject(ctx context.Context, name string) error {
_, err := r.client.DeleteObject(ctx, &aws3.DeleteObjectInput{
Bucket: aws.String(r.bucket),
Key: aws.String(name),
})
return err
}
func (r *R2) CopyObject(ctx context.Context, src string, dst string) (*s3.CopyObjectInfo, error) {
res, err := r.client.CopyObject(ctx, &aws3.CopyObjectInput{
Bucket: aws.String(r.bucket),
CopySource: aws.String(r.bucket + "/" + src),
Key: aws.String(dst),
})
if err != nil {
return nil, err
}
if res.CopyObjectResult == nil || res.CopyObjectResult.ETag == nil || *res.CopyObjectResult.ETag == "" {
return nil, errors.New("CopyObject etag is nil")
}
return &s3.CopyObjectInfo{
Key: dst,
ETag: r.formatETag(*res.CopyObjectResult.ETag),
}, nil
}
func (r *R2) StatObject(ctx context.Context, name string) (*s3.ObjectInfo, error) {
res, err := r.client.HeadObject(ctx, &aws3.HeadObjectInput{
Bucket: aws.String(r.bucket),
Key: aws.String(name),
})
if err != nil {
return nil, err
}
if res.ETag == nil || *res.ETag == "" {
return nil, errors.New("GetObjectAttributes etag is nil")
}
if res.ContentLength == nil {
return nil, errors.New("GetObjectAttributes object size is nil")
}
info := &s3.ObjectInfo{
ETag: r.formatETag(*res.ETag),
Key: name,
Size: *res.ContentLength,
}
if res.LastModified == nil {
info.LastModified = time.Unix(0, 0)
} else {
info.LastModified = *res.LastModified
}
return info, nil
}
func (r *R2) InitiateMultipartUpload(ctx context.Context, name string, opt *s3.PutOption) (*s3.InitiateMultipartUploadResult, error) {
startTime := time.Now()
fmt.Printf("[R2] InitiateMultipartUpload start: bucket=%s, key=%s\n", r.bucket, name)
input := &aws3.CreateMultipartUploadInput{
Bucket: aws.String(r.bucket),
Key: aws.String(name),
}
// 如果提供了 ContentType添加到请求中
if opt != nil && opt.ContentType != "" {
input.ContentType = aws.String(opt.ContentType)
fmt.Printf("[R2] ContentType: %s\n", opt.ContentType)
}
res, err := r.client.CreateMultipartUpload(ctx, input)
duration := time.Since(startTime)
if err != nil {
// 详细错误信息
var respErr *awshttp.ResponseError
if errors.As(err, &respErr) {
fmt.Printf("[R2] HTTP Response Error after %v:\n", duration)
fmt.Printf(" Status Code: %d\n", respErr.Response.StatusCode)
fmt.Printf(" Status: %s\n", respErr.Response.Status)
if respErr.Response.Body != nil {
body := make([]byte, 1024)
n, _ := respErr.Response.Body.Read(body)
fmt.Printf(" Body: %s\n", string(body[:n]))
}
}
fmt.Printf("[R2] InitiateMultipartUpload failed after %v: %v\n", duration, err)
return nil, fmt.Errorf("CreateMultipartUpload failed (bucket=%s, key=%s): %w", r.bucket, name, err)
}
if res.UploadId == nil || *res.UploadId == "" {
return nil, errors.New("CreateMultipartUpload upload id is nil")
}
fmt.Printf("[R2] InitiateMultipartUpload success after %v: uploadID=%s\n", duration, *res.UploadId)
return &s3.InitiateMultipartUploadResult{
Key: name,
Bucket: r.bucket,
UploadID: *res.UploadId,
}, nil
}
func (r *R2) CompleteMultipartUpload(ctx context.Context, uploadID string, name string, parts []s3.Part) (*s3.CompleteMultipartUploadResult, error) {
params := &aws3.CompleteMultipartUploadInput{
Bucket: aws.String(r.bucket),
Key: aws.String(name),
UploadId: aws.String(uploadID),
MultipartUpload: &types.CompletedMultipartUpload{
Parts: make([]types.CompletedPart, 0, len(parts)),
},
}
for _, part := range parts {
params.MultipartUpload.Parts = append(params.MultipartUpload.Parts, types.CompletedPart{
ETag: aws.String(part.ETag),
PartNumber: aws.Int32(int32(part.PartNumber)),
})
}
res, err := r.client.CompleteMultipartUpload(ctx, params)
if err != nil {
return nil, err
}
if res.ETag == nil || *res.ETag == "" {
return nil, errors.New("CompleteMultipartUpload etag is nil")
}
info := &s3.CompleteMultipartUploadResult{
Key: name,
Bucket: r.bucket,
ETag: r.formatETag(*res.ETag),
}
if res.Location != nil {
info.Location = *res.Location
}
return info, nil
}
func (r *R2) AbortMultipartUpload(ctx context.Context, uploadID string, name string) error {
_, err := r.client.AbortMultipartUpload(ctx, &aws3.AbortMultipartUploadInput{
Bucket: aws.String(r.bucket),
Key: aws.String(name),
UploadId: aws.String(uploadID),
})
return err
}
func (r *R2) ListUploadedParts(ctx context.Context, uploadID string, name string, partNumberMarker int, maxParts int) (*s3.ListUploadedPartsResult, error) {
params := &aws3.ListPartsInput{
Bucket: aws.String(r.bucket),
Key: aws.String(name),
UploadId: aws.String(uploadID),
PartNumberMarker: aws.String(strconv.Itoa(partNumberMarker)),
MaxParts: aws.Int32(int32(maxParts)),
}
res, err := r.client.ListParts(ctx, params)
if err != nil {
return nil, err
}
info := &s3.ListUploadedPartsResult{
Key: name,
UploadID: uploadID,
UploadedParts: make([]s3.UploadedPart, 0, len(res.Parts)),
}
if res.MaxParts != nil {
info.MaxParts = int(*res.MaxParts)
}
if res.NextPartNumberMarker != nil {
info.NextPartNumberMarker, _ = strconv.Atoi(*res.NextPartNumberMarker)
}
for _, part := range res.Parts {
var val s3.UploadedPart
if part.PartNumber != nil {
val.PartNumber = int(*part.PartNumber)
}
if part.LastModified != nil {
val.LastModified = *part.LastModified
}
if part.Size != nil {
val.Size = *part.Size
}
info.UploadedParts = append(info.UploadedParts, val)
}
return info, nil
}
func (r *R2) AuthSign(ctx context.Context, uploadID string, name string, expire time.Duration, partNumbers []int) (*s3.AuthSignResult, error) {
res := &s3.AuthSignResult{
Parts: make([]s3.SignPart, 0, len(partNumbers)),
}
params := &aws3.UploadPartInput{
Bucket: aws.String(r.bucket),
Key: aws.String(name),
UploadId: aws.String(uploadID),
}
opt := aws3.WithPresignExpires(expire)
for _, number := range partNumbers {
params.PartNumber = aws.Int32(int32(number))
val, err := r.presign.PresignUploadPart(ctx, params, opt)
if err != nil {
return nil, err
}
u, err := url.Parse(val.URL)
if err != nil {
return nil, err
}
query := u.Query()
u.RawQuery = ""
urlstr := u.String()
if res.URL == "" {
res.URL = urlstr
}
if res.URL == urlstr {
urlstr = ""
}
res.Parts = append(res.Parts, s3.SignPart{
PartNumber: number,
URL: urlstr,
Query: query,
Header: val.SignedHeader,
})
}
return res, nil
}
func (r *R2) AccessURL(ctx context.Context, name string, expire time.Duration, opt *s3.AccessURLOption) (string, error) {
params := &aws3.GetObjectInput{
Bucket: aws.String(r.bucket),
Key: aws.String(name),
}
res, err := r.presign.PresignGetObject(ctx, params, aws3.WithPresignExpires(expire), withDisableHTTPPresignerHeaderV4(opt))
if err != nil {
return "", err
}
return res.URL, nil
}
func (r *R2) FormData(ctx context.Context, name string, size int64, contentType string, duration time.Duration) (*s3.FormData, error) {
return nil, errors.New("R2 does not currently support form data file uploads")
}
func withDisableHTTPPresignerHeaderV4(opt *s3.AccessURLOption) func(options *aws3.PresignOptions) {
return func(options *aws3.PresignOptions) {
options.Presigner = &disableHTTPPresignerHeaderV4{
opt: opt,
presigner: options.Presigner,
}
}
}
type disableHTTPPresignerHeaderV4 struct {
opt *s3.AccessURLOption
presigner aws3.HTTPPresignerV4
}
func (d *disableHTTPPresignerHeaderV4) PresignHTTP(ctx context.Context, credentials aws.Credentials, r *http.Request, payloadHash string, service string, region string, signingTime time.Time, optFns ...func(*v4.SignerOptions)) (url string, signedHeader http.Header, err error) {
optFns = append(optFns, func(options *v4.SignerOptions) {
options.DisableHeaderHoisting = true
})
r.Header.Del("Amz-Sdk-Request")
d.setOption(r.URL)
return d.presigner.PresignHTTP(ctx, credentials, r, payloadHash, service, region, signingTime, optFns...)
}
func (d *disableHTTPPresignerHeaderV4) setOption(u *url.URL) {
if d.opt == nil {
return
}
query := u.Query()
if d.opt.ContentType != "" {
query.Set("response-content-type", d.opt.ContentType)
}
if d.opt.Filename != "" {
query.Set("response-content-disposition", `attachment; filename*=UTF-8''`+url.PathEscape(d.opt.Filename))
}
u.RawQuery = query.Encode()
}

Some files were not shown because too many files have changed in this diff Show More