复制项目

This commit is contained in:
kim.dev.6789
2026-01-14 22:16:44 +08:00
parent e2577b8cee
commit e50142a3b9
691 changed files with 97009 additions and 1 deletions

45
internal/api/auth.go Normal file
View File

@@ -0,0 +1,45 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"git.imall.cloud/openim/protocol/auth"
"github.com/gin-gonic/gin"
"github.com/openimsdk/tools/a2r"
)
type AuthApi struct {
Client auth.AuthClient
}
func NewAuthApi(client auth.AuthClient) AuthApi {
return AuthApi{client}
}
func (o *AuthApi) GetAdminToken(c *gin.Context) {
a2r.Call(c, auth.AuthClient.GetAdminToken, o.Client)
}
func (o *AuthApi) GetUserToken(c *gin.Context) {
a2r.Call(c, auth.AuthClient.GetUserToken, o.Client)
}
func (o *AuthApi) ParseToken(c *gin.Context) {
a2r.Call(c, auth.AuthClient.ParseToken, o.Client)
}
func (o *AuthApi) ForceLogout(c *gin.Context) {
a2r.Call(c, auth.AuthClient.ForceLogout, o.Client)
}

View File

@@ -0,0 +1,413 @@
package api
import (
"encoding/json"
"reflect"
"strconv"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/apistruct"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/discovery/etcd"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/version"
"github.com/gin-gonic/gin"
"github.com/openimsdk/tools/apiresp"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/datautil"
"github.com/openimsdk/tools/utils/runtimeenv"
clientv3 "go.etcd.io/etcd/client/v3"
)
const (
// wait for Restart http call return
waitHttp = time.Millisecond * 200
)
type ConfigManager struct {
imAdminUserID []string
config *config.AllConfig
client *clientv3.Client
configPath string
systemConfigDB database.SystemConfig
}
func NewConfigManager(IMAdminUserID []string, cfg *config.AllConfig, client *clientv3.Client, configPath string, systemConfigDB database.SystemConfig) *ConfigManager {
cm := &ConfigManager{
imAdminUserID: IMAdminUserID,
config: cfg,
client: client,
configPath: configPath,
systemConfigDB: systemConfigDB,
}
return cm
}
func (cm *ConfigManager) CheckAdmin(c *gin.Context) {
if err := authverify.CheckAdmin(c); err != nil {
apiresp.GinError(c, err)
c.Abort()
}
}
func (cm *ConfigManager) GetConfig(c *gin.Context) {
var req apistruct.GetConfigReq
if err := c.BindJSON(&req); err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
conf := cm.config.Name2Config(req.ConfigName)
if conf == nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail("config name not found").Wrap())
return
}
b, err := json.Marshal(conf)
if err != nil {
apiresp.GinError(c, err)
return
}
apiresp.GinSuccess(c, string(b))
}
func (cm *ConfigManager) GetConfigList(c *gin.Context) {
var resp apistruct.GetConfigListResp
resp.ConfigNames = cm.config.GetConfigNames()
resp.Environment = runtimeenv.RuntimeEnvironment()
resp.Version = version.Version
apiresp.GinSuccess(c, resp)
}
func (cm *ConfigManager) SetConfig(c *gin.Context) {
if cm.config.Discovery.Enable != config.ETCD {
apiresp.GinError(c, errs.New("only etcd support set config").Wrap())
return
}
var req apistruct.SetConfigReq
if err := c.BindJSON(&req); err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
var err error
switch req.ConfigName {
case cm.config.Discovery.GetConfigFileName():
err = compareAndSave[config.Discovery](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Kafka.GetConfigFileName():
err = compareAndSave[config.Kafka](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.LocalCache.GetConfigFileName():
err = compareAndSave[config.LocalCache](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Log.GetConfigFileName():
err = compareAndSave[config.Log](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Minio.GetConfigFileName():
err = compareAndSave[config.Minio](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Mongo.GetConfigFileName():
err = compareAndSave[config.Mongo](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Notification.GetConfigFileName():
err = compareAndSave[config.Notification](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.API.GetConfigFileName():
err = compareAndSave[config.API](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.CronTask.GetConfigFileName():
err = compareAndSave[config.CronTask](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.MsgGateway.GetConfigFileName():
err = compareAndSave[config.MsgGateway](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.MsgTransfer.GetConfigFileName():
err = compareAndSave[config.MsgTransfer](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Push.GetConfigFileName():
err = compareAndSave[config.Push](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Auth.GetConfigFileName():
err = compareAndSave[config.Auth](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Conversation.GetConfigFileName():
err = compareAndSave[config.Conversation](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Friend.GetConfigFileName():
err = compareAndSave[config.Friend](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Group.GetConfigFileName():
err = compareAndSave[config.Group](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Msg.GetConfigFileName():
err = compareAndSave[config.Msg](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Third.GetConfigFileName():
err = compareAndSave[config.Third](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.User.GetConfigFileName():
err = compareAndSave[config.User](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Redis.GetConfigFileName():
err = compareAndSave[config.Redis](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Share.GetConfigFileName():
err = compareAndSave[config.Share](c, cm.config.Name2Config(req.ConfigName), &req, cm)
case cm.config.Webhooks.GetConfigFileName():
err = compareAndSave[config.Webhooks](c, cm.config.Name2Config(req.ConfigName), &req, cm)
default:
apiresp.GinError(c, errs.ErrArgs.Wrap())
return
}
if err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
apiresp.GinSuccess(c, nil)
}
func (cm *ConfigManager) SetConfigs(c *gin.Context) {
if cm.config.Discovery.Enable != config.ETCD {
apiresp.GinError(c, errs.New("only etcd support set config").Wrap())
return
}
var req apistruct.SetConfigsReq
if err := c.BindJSON(&req); err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
var (
err error
ops []*clientv3.Op
)
for _, cf := range req.Configs {
var op *clientv3.Op
switch cf.ConfigName {
case cm.config.Discovery.GetConfigFileName():
op, err = compareAndOp[config.Discovery](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Kafka.GetConfigFileName():
op, err = compareAndOp[config.Kafka](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.LocalCache.GetConfigFileName():
op, err = compareAndOp[config.LocalCache](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Log.GetConfigFileName():
op, err = compareAndOp[config.Log](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Minio.GetConfigFileName():
op, err = compareAndOp[config.Minio](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Mongo.GetConfigFileName():
op, err = compareAndOp[config.Mongo](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Notification.GetConfigFileName():
op, err = compareAndOp[config.Notification](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.API.GetConfigFileName():
op, err = compareAndOp[config.API](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.CronTask.GetConfigFileName():
op, err = compareAndOp[config.CronTask](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.MsgGateway.GetConfigFileName():
op, err = compareAndOp[config.MsgGateway](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.MsgTransfer.GetConfigFileName():
op, err = compareAndOp[config.MsgTransfer](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Push.GetConfigFileName():
op, err = compareAndOp[config.Push](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Auth.GetConfigFileName():
op, err = compareAndOp[config.Auth](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Conversation.GetConfigFileName():
op, err = compareAndOp[config.Conversation](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Friend.GetConfigFileName():
op, err = compareAndOp[config.Friend](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Group.GetConfigFileName():
op, err = compareAndOp[config.Group](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Msg.GetConfigFileName():
op, err = compareAndOp[config.Msg](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Third.GetConfigFileName():
op, err = compareAndOp[config.Third](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.User.GetConfigFileName():
op, err = compareAndOp[config.User](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Redis.GetConfigFileName():
op, err = compareAndOp[config.Redis](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Share.GetConfigFileName():
op, err = compareAndOp[config.Share](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
case cm.config.Webhooks.GetConfigFileName():
op, err = compareAndOp[config.Webhooks](c, cm.config.Name2Config(cf.ConfigName), &cf, cm)
default:
apiresp.GinError(c, errs.ErrArgs.Wrap())
return
}
if err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
if op != nil {
ops = append(ops, op)
}
}
if len(ops) > 0 {
tx := cm.client.Txn(c)
if _, err = tx.Then(datautil.Batch(func(op *clientv3.Op) clientv3.Op { return *op }, ops)...).Commit(); err != nil {
apiresp.GinError(c, errs.WrapMsg(err, "save to etcd failed"))
return
}
}
apiresp.GinSuccess(c, nil)
}
func compareAndOp[T any](c *gin.Context, old any, req *apistruct.SetConfigReq, cm *ConfigManager) (*clientv3.Op, error) {
conf := new(T)
err := json.Unmarshal([]byte(req.Data), &conf)
if err != nil {
return nil, errs.ErrArgs.WithDetail(err.Error()).Wrap()
}
eq := reflect.DeepEqual(old, conf)
if eq {
return nil, nil
}
data, err := json.Marshal(conf)
if err != nil {
return nil, errs.ErrArgs.WithDetail(err.Error()).Wrap()
}
op := clientv3.OpPut(etcd.BuildKey(req.ConfigName), string(data))
return &op, nil
}
func compareAndSave[T any](c *gin.Context, old any, req *apistruct.SetConfigReq, cm *ConfigManager) error {
conf := new(T)
err := json.Unmarshal([]byte(req.Data), &conf)
if err != nil {
return errs.ErrArgs.WithDetail(err.Error()).Wrap()
}
eq := reflect.DeepEqual(old, conf)
if eq {
return nil
}
data, err := json.Marshal(conf)
if err != nil {
return errs.ErrArgs.WithDetail(err.Error()).Wrap()
}
_, err = cm.client.Put(c, etcd.BuildKey(req.ConfigName), string(data))
if err != nil {
return errs.WrapMsg(err, "save to etcd failed")
}
return nil
}
func (cm *ConfigManager) ResetConfig(c *gin.Context) {
go func() {
if err := cm.resetConfig(c, true); err != nil {
log.ZError(c, "reset config err", err)
}
}()
apiresp.GinSuccess(c, nil)
}
func (cm *ConfigManager) resetConfig(c *gin.Context, checkChange bool, ops ...clientv3.Op) error {
txn := cm.client.Txn(c)
type initConf struct {
old any
new any
}
configMap := map[string]*initConf{
cm.config.Discovery.GetConfigFileName(): {old: &cm.config.Discovery, new: new(config.Discovery)},
cm.config.Kafka.GetConfigFileName(): {old: &cm.config.Kafka, new: new(config.Kafka)},
cm.config.LocalCache.GetConfigFileName(): {old: &cm.config.LocalCache, new: new(config.LocalCache)},
cm.config.Log.GetConfigFileName(): {old: &cm.config.Log, new: new(config.Log)},
cm.config.Minio.GetConfigFileName(): {old: &cm.config.Minio, new: new(config.Minio)},
cm.config.Mongo.GetConfigFileName(): {old: &cm.config.Mongo, new: new(config.Mongo)},
cm.config.Notification.GetConfigFileName(): {old: &cm.config.Notification, new: new(config.Notification)},
cm.config.API.GetConfigFileName(): {old: &cm.config.API, new: new(config.API)},
cm.config.CronTask.GetConfigFileName(): {old: &cm.config.CronTask, new: new(config.CronTask)},
cm.config.MsgGateway.GetConfigFileName(): {old: &cm.config.MsgGateway, new: new(config.MsgGateway)},
cm.config.MsgTransfer.GetConfigFileName(): {old: &cm.config.MsgTransfer, new: new(config.MsgTransfer)},
cm.config.Push.GetConfigFileName(): {old: &cm.config.Push, new: new(config.Push)},
cm.config.Auth.GetConfigFileName(): {old: &cm.config.Auth, new: new(config.Auth)},
cm.config.Conversation.GetConfigFileName(): {old: &cm.config.Conversation, new: new(config.Conversation)},
cm.config.Friend.GetConfigFileName(): {old: &cm.config.Friend, new: new(config.Friend)},
cm.config.Group.GetConfigFileName(): {old: &cm.config.Group, new: new(config.Group)},
cm.config.Msg.GetConfigFileName(): {old: &cm.config.Msg, new: new(config.Msg)},
cm.config.Third.GetConfigFileName(): {old: &cm.config.Third, new: new(config.Third)},
cm.config.User.GetConfigFileName(): {old: &cm.config.User, new: new(config.User)},
cm.config.Redis.GetConfigFileName(): {old: &cm.config.Redis, new: new(config.Redis)},
cm.config.Share.GetConfigFileName(): {old: &cm.config.Share, new: new(config.Share)},
cm.config.Webhooks.GetConfigFileName(): {old: &cm.config.Webhooks, new: new(config.Webhooks)},
}
changedKeys := make([]string, 0, len(configMap))
for k, v := range configMap {
err := config.Load(cm.configPath, k, config.EnvPrefixMap[k], v.new)
if err != nil {
log.ZError(c, "load config failed", err)
continue
}
equal := reflect.DeepEqual(v.old, v.new)
if !checkChange || !equal {
changedKeys = append(changedKeys, k)
}
}
for _, k := range changedKeys {
data, err := json.Marshal(configMap[k].new)
if err != nil {
log.ZError(c, "marshal config failed", err)
continue
}
ops = append(ops, clientv3.OpPut(etcd.BuildKey(k), string(data)))
}
if len(ops) > 0 {
txn.Then(ops...)
_, err := txn.Commit()
if err != nil {
return errs.WrapMsg(err, "commit etcd txn failed")
}
}
return nil
}
func (cm *ConfigManager) Restart(c *gin.Context) {
go cm.restart(c)
apiresp.GinSuccess(c, nil)
}
func (cm *ConfigManager) restart(c *gin.Context) {
time.Sleep(waitHttp) // wait for Restart http call return
t := time.Now().Unix()
_, err := cm.client.Put(c, etcd.BuildKey(etcd.RestartKey), strconv.Itoa(int(t)))
if err != nil {
log.ZError(c, "restart etcd put key failed", err)
}
}
func (cm *ConfigManager) SetEnableConfigManager(c *gin.Context) {
if cm.config.Discovery.Enable != config.ETCD {
apiresp.GinError(c, errs.New("only etcd support config manager").Wrap())
return
}
var req apistruct.SetEnableConfigManagerReq
if err := c.BindJSON(&req); err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
var enableStr string
if req.Enable {
enableStr = etcd.Enable
} else {
enableStr = etcd.Disable
}
resp, err := cm.client.Get(c, etcd.BuildKey(etcd.EnableConfigCenterKey))
if err != nil {
apiresp.GinError(c, errs.WrapMsg(err, "getEnableConfigManager failed"))
return
}
if !(resp.Count > 0 && string(resp.Kvs[0].Value) == etcd.Enable) && req.Enable {
go func() {
time.Sleep(waitHttp) // wait for Restart http call return
err := cm.resetConfig(c, false, clientv3.OpPut(etcd.BuildKey(etcd.EnableConfigCenterKey), enableStr))
if err != nil {
log.ZError(c, "resetConfig failed", err)
}
}()
} else {
_, err = cm.client.Put(c, etcd.BuildKey(etcd.EnableConfigCenterKey), enableStr)
if err != nil {
apiresp.GinError(c, errs.WrapMsg(err, "setEnableConfigManager failed"))
return
}
}
apiresp.GinSuccess(c, nil)
}
func (cm *ConfigManager) GetEnableConfigManager(c *gin.Context) {
resp, err := cm.client.Get(c, etcd.BuildKey(etcd.EnableConfigCenterKey))
if err != nil {
apiresp.GinError(c, errs.WrapMsg(err, "getEnableConfigManager failed"))
return
}
var enable bool
if resp.Count > 0 && string(resp.Kvs[0].Value) == etcd.Enable {
enable = true
}
apiresp.GinSuccess(c, &apistruct.GetEnableConfigManagerResp{Enable: enable})
}

View File

@@ -0,0 +1,82 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"github.com/gin-gonic/gin"
"git.imall.cloud/openim/protocol/conversation"
"github.com/openimsdk/tools/a2r"
)
type ConversationApi struct {
Client conversation.ConversationClient
}
func NewConversationApi(client conversation.ConversationClient) ConversationApi {
return ConversationApi{client}
}
func (o *ConversationApi) GetAllConversations(c *gin.Context) {
a2r.Call(c, conversation.ConversationClient.GetAllConversations, o.Client)
}
func (o *ConversationApi) GetSortedConversationList(c *gin.Context) {
a2r.Call(c, conversation.ConversationClient.GetSortedConversationList, o.Client)
}
func (o *ConversationApi) GetConversation(c *gin.Context) {
a2r.Call(c, conversation.ConversationClient.GetConversation, o.Client)
}
func (o *ConversationApi) GetConversations(c *gin.Context) {
a2r.Call(c, conversation.ConversationClient.GetConversations, o.Client)
}
func (o *ConversationApi) SetConversations(c *gin.Context) {
a2r.Call(c, conversation.ConversationClient.SetConversations, o.Client)
}
//func (o *ConversationApi) GetConversationOfflinePushUserIDs(c *gin.Context) {
// a2r.Call(c, conversation.ConversationClient.GetConversationOfflinePushUserIDs, o.Client)
//}
func (o *ConversationApi) GetFullOwnerConversationIDs(c *gin.Context) {
a2r.Call(c, conversation.ConversationClient.GetFullOwnerConversationIDs, o.Client)
}
func (o *ConversationApi) GetIncrementalConversation(c *gin.Context) {
a2r.Call(c, conversation.ConversationClient.GetIncrementalConversation, o.Client)
}
func (o *ConversationApi) GetOwnerConversation(c *gin.Context) {
a2r.Call(c, conversation.ConversationClient.GetOwnerConversation, o.Client)
}
func (o *ConversationApi) GetNotNotifyConversationIDs(c *gin.Context) {
a2r.Call(c, conversation.ConversationClient.GetNotNotifyConversationIDs, o.Client)
}
func (o *ConversationApi) GetPinnedConversationIDs(c *gin.Context) {
a2r.Call(c, conversation.ConversationClient.GetPinnedConversationIDs, o.Client)
}
func (o *ConversationApi) UpdateConversationsByUser(c *gin.Context) {
a2r.Call(c, conversation.ConversationClient.UpdateConversationsByUser, o.Client)
}
func (o *ConversationApi) DeleteConversations(c *gin.Context) {
a2r.Call(c, conversation.ConversationClient.DeleteConversations, o.Client)
}

View File

@@ -0,0 +1,34 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"git.imall.cloud/openim/protocol/constant"
"github.com/go-playground/validator/v10"
)
// RequiredIf validates if the specified field is required based on the session type.
func RequiredIf(fl validator.FieldLevel) bool {
sessionType := fl.Parent().FieldByName("SessionType").Int()
switch sessionType {
case constant.SingleChatType, constant.NotificationChatType:
return fl.FieldName() != "RecvID" || fl.Field().String() != ""
case constant.WriteGroupChatType, constant.ReadGroupChatType:
return fl.FieldName() != "GroupID" || fl.Field().String() != ""
default:
return true
}
}

120
internal/api/friend.go Normal file
View File

@@ -0,0 +1,120 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"github.com/gin-gonic/gin"
"git.imall.cloud/openim/protocol/relation"
"github.com/openimsdk/tools/a2r"
)
type FriendApi struct {
Client relation.FriendClient
}
func NewFriendApi(client relation.FriendClient) FriendApi {
return FriendApi{client}
}
func (o *FriendApi) ApplyToAddFriend(c *gin.Context) {
a2r.Call(c, relation.FriendClient.ApplyToAddFriend, o.Client)
}
func (o *FriendApi) RespondFriendApply(c *gin.Context) {
a2r.Call(c, relation.FriendClient.RespondFriendApply, o.Client)
}
func (o *FriendApi) DeleteFriend(c *gin.Context) {
a2r.Call(c, relation.FriendClient.DeleteFriend, o.Client)
}
func (o *FriendApi) GetFriendApplyList(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetPaginationFriendsApplyTo, o.Client)
}
func (o *FriendApi) GetDesignatedFriendsApply(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetDesignatedFriendsApply, o.Client)
}
func (o *FriendApi) GetSelfApplyList(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetPaginationFriendsApplyFrom, o.Client)
}
func (o *FriendApi) GetFriendList(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetPaginationFriends, o.Client)
}
func (o *FriendApi) GetDesignatedFriends(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetDesignatedFriends, o.Client)
}
func (o *FriendApi) SetFriendRemark(c *gin.Context) {
a2r.Call(c, relation.FriendClient.SetFriendRemark, o.Client)
}
func (o *FriendApi) AddBlack(c *gin.Context) {
a2r.Call(c, relation.FriendClient.AddBlack, o.Client)
}
func (o *FriendApi) GetPaginationBlacks(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetPaginationBlacks, o.Client)
}
func (o *FriendApi) GetSpecifiedBlacks(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetSpecifiedBlacks, o.Client)
}
func (o *FriendApi) RemoveBlack(c *gin.Context) {
a2r.Call(c, relation.FriendClient.RemoveBlack, o.Client)
}
func (o *FriendApi) ImportFriends(c *gin.Context) {
a2r.Call(c, relation.FriendClient.ImportFriends, o.Client)
}
func (o *FriendApi) IsFriend(c *gin.Context) {
a2r.Call(c, relation.FriendClient.IsFriend, o.Client)
}
func (o *FriendApi) GetFriendIDs(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetFriendIDs, o.Client)
}
func (o *FriendApi) GetSpecifiedFriendsInfo(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetSpecifiedFriendsInfo, o.Client)
}
func (o *FriendApi) UpdateFriends(c *gin.Context) {
a2r.Call(c, relation.FriendClient.UpdateFriends, o.Client)
}
func (o *FriendApi) GetIncrementalFriends(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetIncrementalFriends, o.Client)
}
// GetIncrementalBlacks is temporarily unused.
// Deprecated: This function is currently unused and may be removed in future versions.
func (o *FriendApi) GetIncrementalBlacks(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetIncrementalBlacks, o.Client)
}
func (o *FriendApi) GetFullFriendUserIDs(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetFullFriendUserIDs, o.Client)
}
func (o *FriendApi) GetSelfUnhandledApplyCount(c *gin.Context) {
a2r.Call(c, relation.FriendClient.GetSelfUnhandledApplyCount, o.Client)
}

171
internal/api/group.go Normal file
View File

@@ -0,0 +1,171 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"git.imall.cloud/openim/protocol/group"
"github.com/gin-gonic/gin"
"github.com/openimsdk/tools/a2r"
)
type GroupApi struct {
Client group.GroupClient
}
func NewGroupApi(client group.GroupClient) GroupApi {
return GroupApi{client}
}
func (o *GroupApi) CreateGroup(c *gin.Context) {
a2r.Call(c, group.GroupClient.CreateGroup, o.Client)
}
func (o *GroupApi) SetGroupInfo(c *gin.Context) {
a2r.Call(c, group.GroupClient.SetGroupInfo, o.Client)
}
func (o *GroupApi) SetGroupInfoEx(c *gin.Context) {
a2r.Call(c, group.GroupClient.SetGroupInfoEx, o.Client)
}
func (o *GroupApi) JoinGroup(c *gin.Context) {
a2r.Call(c, group.GroupClient.JoinGroup, o.Client)
}
func (o *GroupApi) QuitGroup(c *gin.Context) {
a2r.Call(c, group.GroupClient.QuitGroup, o.Client)
}
func (o *GroupApi) ApplicationGroupResponse(c *gin.Context) {
a2r.Call(c, group.GroupClient.GroupApplicationResponse, o.Client)
}
func (o *GroupApi) TransferGroupOwner(c *gin.Context) {
a2r.Call(c, group.GroupClient.TransferGroupOwner, o.Client)
}
func (o *GroupApi) GetRecvGroupApplicationList(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetGroupApplicationList, o.Client)
}
func (o *GroupApi) GetUserReqGroupApplicationList(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetUserReqApplicationList, o.Client)
}
func (o *GroupApi) GetGroupUsersReqApplicationList(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetGroupUsersReqApplicationList, o.Client)
}
func (o *GroupApi) GetSpecifiedUserGroupRequestInfo(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetSpecifiedUserGroupRequestInfo, o.Client)
}
func (o *GroupApi) GetGroupsInfo(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetGroupsInfo, o.Client)
//a2r.Call(c, group.GroupClient.GetGroupsInfo, o.Client, c, a2r.NewNilReplaceOption(group.GroupClient.GetGroupsInfo))
}
func (o *GroupApi) KickGroupMember(c *gin.Context) {
a2r.Call(c, group.GroupClient.KickGroupMember, o.Client)
}
func (o *GroupApi) GetGroupMembersInfo(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetGroupMembersInfo, o.Client)
//a2r.Call(c, group.GroupClient.GetGroupMembersInfo, o.Client, c, a2r.NewNilReplaceOption(group.GroupClient.GetGroupMembersInfo))
}
func (o *GroupApi) GetGroupMemberList(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetGroupMemberList, o.Client)
}
func (o *GroupApi) InviteUserToGroup(c *gin.Context) {
a2r.Call(c, group.GroupClient.InviteUserToGroup, o.Client)
}
func (o *GroupApi) GetJoinedGroupList(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetJoinedGroupList, o.Client)
}
func (o *GroupApi) DismissGroup(c *gin.Context) {
a2r.Call(c, group.GroupClient.DismissGroup, o.Client)
}
func (o *GroupApi) MuteGroupMember(c *gin.Context) {
a2r.Call(c, group.GroupClient.MuteGroupMember, o.Client)
}
func (o *GroupApi) CancelMuteGroupMember(c *gin.Context) {
a2r.Call(c, group.GroupClient.CancelMuteGroupMember, o.Client)
}
func (o *GroupApi) MuteGroup(c *gin.Context) {
a2r.Call(c, group.GroupClient.MuteGroup, o.Client)
}
func (o *GroupApi) CancelMuteGroup(c *gin.Context) {
a2r.Call(c, group.GroupClient.CancelMuteGroup, o.Client)
}
func (o *GroupApi) SetGroupMemberInfo(c *gin.Context) {
a2r.Call(c, group.GroupClient.SetGroupMemberInfo, o.Client)
}
func (o *GroupApi) GetGroupAbstractInfo(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetGroupAbstractInfo, o.Client)
}
// func (g *Group) SetGroupMemberNickname(c *gin.Context) {
// a2r.Call(c, group.GroupClient.SetGroupMemberNickname, g.userClient)
//}
//
// func (g *Group) GetGroupAllMemberList(c *gin.Context) {
// a2r.Call(c, group.GroupClient.GetGroupAllMember, g.userClient)
//}
func (o *GroupApi) GroupCreateCount(c *gin.Context) {
a2r.Call(c, group.GroupClient.GroupCreateCount, o.Client)
}
func (o *GroupApi) GetGroups(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetGroups, o.Client)
}
func (o *GroupApi) GetGroupMemberUserIDs(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetGroupMemberUserIDs, o.Client)
}
func (o *GroupApi) GetIncrementalJoinGroup(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetIncrementalJoinGroup, o.Client)
}
func (o *GroupApi) GetIncrementalGroupMember(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetIncrementalGroupMember, o.Client)
}
func (o *GroupApi) GetIncrementalGroupMemberBatch(c *gin.Context) {
a2r.Call(c, group.GroupClient.BatchGetIncrementalGroupMember, o.Client)
}
func (o *GroupApi) GetFullGroupMemberUserIDs(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetFullGroupMemberUserIDs, o.Client)
}
func (o *GroupApi) GetFullJoinGroupIDs(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetFullJoinGroupIDs, o.Client)
}
func (o *GroupApi) GetGroupApplicationUnhandledCount(c *gin.Context) {
a2r.Call(c, group.GroupClient.GetGroupApplicationUnhandledCount, o.Client)
}

104
internal/api/init.go Normal file
View File

@@ -0,0 +1,104 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"strconv"
"time"
conf "git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/datautil"
"github.com/openimsdk/tools/utils/network"
"github.com/openimsdk/tools/utils/runtimeenv"
"google.golang.org/grpc"
)
type Config struct {
conf.AllConfig
ConfigPath conf.Path
Index conf.Index
}
func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryRegistry, service grpc.ServiceRegistrar) error {
apiPort, err := datautil.GetElemByIndex(config.API.Api.Ports, int(config.Index))
if err != nil {
return err
}
router, err := newGinRouter(ctx, client, config)
if err != nil {
return err
}
apiCtx, apiCancel := context.WithCancelCause(context.Background())
done := make(chan struct{})
go func() {
httpServer := &http.Server{
Handler: router,
Addr: net.JoinHostPort(network.GetListenIP(config.API.Api.ListenIP), strconv.Itoa(apiPort)),
}
go func() {
defer close(done)
select {
case <-ctx.Done():
apiCancel(fmt.Errorf("recv ctx %w", context.Cause(ctx)))
case <-apiCtx.Done():
}
log.ZDebug(ctx, "api server is shutting down")
if err := httpServer.Shutdown(context.Background()); err != nil {
log.ZWarn(ctx, "api server shutdown err", err)
}
}()
log.CInfo(ctx, "api server is init", "runtimeEnv", runtimeenv.RuntimeEnvironment(), "address", httpServer.Addr, "apiPort", apiPort)
err := httpServer.ListenAndServe()
if err == nil {
err = errors.New("api done")
}
apiCancel(err)
}()
//if config.Discovery.Enable == conf.ETCD {
// cm := disetcd.NewConfigManager(client.(*etcd.SvcDiscoveryRegistryImpl).GetClient(), config.GetConfigNames())
// cm.Watch(ctx)
//}
//sigs := make(chan os.Signal, 1)
//signal.Notify(sigs, syscall.SIGTERM)
//select {
//case val := <-sigs:
// log.ZDebug(ctx, "recv exit", "signal", val.String())
// cancel(fmt.Errorf("signal %s", val.String()))
//case <-ctx.Done():
//}
<-apiCtx.Done()
exitCause := context.Cause(apiCtx)
log.ZWarn(ctx, "api server exit", exitCause)
timer := time.NewTimer(time.Second * 15)
defer timer.Stop()
select {
case <-timer.C:
log.ZWarn(ctx, "api server graceful stop timeout", nil)
case <-done:
log.ZDebug(ctx, "api server graceful stop done")
}
return exitCause
}

287
internal/api/jssdk/jssdk.go Normal file
View File

@@ -0,0 +1,287 @@
package jssdk
import (
"context"
"sort"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
"git.imall.cloud/openim/protocol/constant"
"github.com/openimsdk/tools/log"
"github.com/gin-gonic/gin"
"git.imall.cloud/openim/protocol/conversation"
"git.imall.cloud/openim/protocol/jssdk"
"git.imall.cloud/openim/protocol/msg"
"git.imall.cloud/openim/protocol/relation"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/datautil"
)
const (
maxGetActiveConversation = 500
defaultGetActiveConversation = 100
)
func NewJSSdkApi(userClient *rpcli.UserClient, relationClient *rpcli.RelationClient, groupClient *rpcli.GroupClient,
conversationClient *rpcli.ConversationClient, msgClient *rpcli.MsgClient) *JSSdk {
return &JSSdk{
userClient: userClient,
relationClient: relationClient,
groupClient: groupClient,
conversationClient: conversationClient,
msgClient: msgClient,
}
}
type JSSdk struct {
userClient *rpcli.UserClient
relationClient *rpcli.RelationClient
groupClient *rpcli.GroupClient
conversationClient *rpcli.ConversationClient
msgClient *rpcli.MsgClient
}
func (x *JSSdk) GetActiveConversations(c *gin.Context) {
call(c, x.getActiveConversations)
}
func (x *JSSdk) GetConversations(c *gin.Context) {
call(c, x.getConversations)
}
func (x *JSSdk) fillConversations(ctx context.Context, conversations []*jssdk.ConversationMsg) error {
if len(conversations) == 0 {
return nil
}
var (
userIDs []string
groupIDs []string
)
for _, c := range conversations {
if c.Conversation.GroupID == "" {
userIDs = append(userIDs, c.Conversation.UserID)
} else {
groupIDs = append(groupIDs, c.Conversation.GroupID)
}
}
var (
userMap map[string]*sdkws.UserInfo
friendMap map[string]*relation.FriendInfoOnly
groupMap map[string]*sdkws.GroupInfo
)
if len(userIDs) > 0 {
users, err := x.userClient.GetUsersInfo(ctx, userIDs)
if err != nil {
return err
}
friends, err := x.relationClient.GetFriendsInfo(ctx, conversations[0].Conversation.OwnerUserID, userIDs)
if err != nil {
return err
}
userMap = datautil.SliceToMap(users, (*sdkws.UserInfo).GetUserID)
friendMap = datautil.SliceToMap(friends, (*relation.FriendInfoOnly).GetFriendUserID)
}
if len(groupIDs) > 0 {
groups, err := x.groupClient.GetGroupsInfo(ctx, groupIDs)
if err != nil {
return err
}
groupMap = datautil.SliceToMap(groups, (*sdkws.GroupInfo).GetGroupID)
}
for _, c := range conversations {
if c.Conversation.GroupID == "" {
c.User = userMap[c.Conversation.UserID]
c.Friend = friendMap[c.Conversation.UserID]
} else {
c.Group = groupMap[c.Conversation.GroupID]
}
}
return nil
}
func (x *JSSdk) getActiveConversations(ctx context.Context, req *jssdk.GetActiveConversationsReq) (*jssdk.GetActiveConversationsResp, error) {
if req.Count <= 0 || req.Count > maxGetActiveConversation {
req.Count = defaultGetActiveConversation
}
req.OwnerUserID = mcontext.GetOpUserID(ctx)
conversationIDs, err := x.conversationClient.GetConversationIDs(ctx, req.OwnerUserID)
if err != nil {
return nil, err
}
if len(conversationIDs) == 0 {
return &jssdk.GetActiveConversationsResp{}, nil
}
activeConversation, err := x.msgClient.GetActiveConversation(ctx, conversationIDs)
if err != nil {
return nil, err
}
if len(activeConversation) == 0 {
return &jssdk.GetActiveConversationsResp{}, nil
}
readSeq, err := x.msgClient.GetHasReadSeqs(ctx, conversationIDs, req.OwnerUserID)
if err != nil {
return nil, err
}
sortConversations := sortActiveConversations{
Conversation: activeConversation,
}
if len(activeConversation) > 1 {
pinnedConversationIDs, err := x.conversationClient.GetPinnedConversationIDs(ctx, req.OwnerUserID)
if err != nil {
return nil, err
}
sortConversations.PinnedConversationIDs = datautil.SliceSet(pinnedConversationIDs)
}
sort.Sort(&sortConversations)
sortList := sortConversations.Top(int(req.Count))
conversations, err := x.conversationClient.GetConversations(ctx, datautil.Slice(sortList, func(c *msg.ActiveConversation) string {
return c.ConversationID
}), req.OwnerUserID)
if err != nil {
return nil, err
}
msgs, err := x.msgClient.GetSeqMessage(ctx, req.OwnerUserID, datautil.Slice(sortList, func(c *msg.ActiveConversation) *msg.ConversationSeqs {
return &msg.ConversationSeqs{
ConversationID: c.ConversationID,
Seqs: []int64{c.MaxSeq},
}
}))
if err != nil {
return nil, err
}
x.checkMessagesAndGetLastMessage(ctx, req.OwnerUserID, msgs)
conversationMap := datautil.SliceToMap(conversations, func(c *conversation.Conversation) string {
return c.ConversationID
})
resp := make([]*jssdk.ConversationMsg, 0, len(sortList))
for _, c := range sortList {
conv, ok := conversationMap[c.ConversationID]
if !ok {
continue
}
if msgList, ok := msgs[c.ConversationID]; ok && len(msgList.Msgs) > 0 {
resp = append(resp, &jssdk.ConversationMsg{
Conversation: conv,
LastMsg: msgList.Msgs[0],
MaxSeq: c.MaxSeq,
ReadSeq: readSeq[c.ConversationID],
})
}
}
if err := x.fillConversations(ctx, resp); err != nil {
return nil, err
}
var unreadCount int64
for _, c := range activeConversation {
count := c.MaxSeq - readSeq[c.ConversationID]
if count > 0 {
unreadCount += count
}
}
return &jssdk.GetActiveConversationsResp{
Conversations: resp,
UnreadCount: unreadCount,
}, nil
}
func (x *JSSdk) getConversations(ctx context.Context, req *jssdk.GetConversationsReq) (*jssdk.GetConversationsResp, error) {
req.OwnerUserID = mcontext.GetOpUserID(ctx)
conversations, err := x.conversationClient.GetConversations(ctx, req.ConversationIDs, req.OwnerUserID)
if err != nil {
return nil, err
}
if len(conversations) == 0 {
return &jssdk.GetConversationsResp{}, nil
}
req.ConversationIDs = datautil.Slice(conversations, func(c *conversation.Conversation) string {
return c.ConversationID
})
maxSeqs, err := x.msgClient.GetMaxSeqs(ctx, req.ConversationIDs)
if err != nil {
return nil, err
}
readSeqs, err := x.msgClient.GetHasReadSeqs(ctx, req.ConversationIDs, req.OwnerUserID)
if err != nil {
return nil, err
}
conversationSeqs := make([]*msg.ConversationSeqs, 0, len(conversations))
for _, c := range conversations {
if seq := maxSeqs[c.ConversationID]; seq > 0 {
conversationSeqs = append(conversationSeqs, &msg.ConversationSeqs{
ConversationID: c.ConversationID,
Seqs: []int64{seq},
})
}
}
var msgs map[string]*sdkws.PullMsgs
if len(conversationSeqs) > 0 {
msgs, err = x.msgClient.GetSeqMessage(ctx, req.OwnerUserID, conversationSeqs)
if err != nil {
return nil, err
}
}
x.checkMessagesAndGetLastMessage(ctx, req.OwnerUserID, msgs)
resp := make([]*jssdk.ConversationMsg, 0, len(conversations))
for _, c := range conversations {
if msgList, ok := msgs[c.ConversationID]; ok && len(msgList.Msgs) > 0 {
resp = append(resp, &jssdk.ConversationMsg{
Conversation: c,
LastMsg: msgList.Msgs[0],
MaxSeq: maxSeqs[c.ConversationID],
ReadSeq: readSeqs[c.ConversationID],
})
}
}
if err := x.fillConversations(ctx, resp); err != nil {
return nil, err
}
var unreadCount int64
for conversationID, maxSeq := range maxSeqs {
count := maxSeq - readSeqs[conversationID]
if count > 0 {
unreadCount += count
}
}
return &jssdk.GetConversationsResp{
Conversations: resp,
UnreadCount: unreadCount,
}, nil
}
// This function checks whether the latest MaxSeq message is valid.
// If not, it needs to fetch a valid message again.
func (x *JSSdk) checkMessagesAndGetLastMessage(ctx context.Context, userID string, messages map[string]*sdkws.PullMsgs) {
var conversationIDs []string
for conversationID, message := range messages {
allInValid := true
for _, data := range message.Msgs {
if data.Status < constant.MsgStatusHasDeleted {
allInValid = false
break
}
}
if allInValid {
conversationIDs = append(conversationIDs, conversationID)
}
}
if len(conversationIDs) > 0 {
resp, err := x.msgClient.GetLastMessage(ctx, &msg.GetLastMessageReq{
UserID: userID,
ConversationIDs: conversationIDs,
})
if err != nil {
log.ZError(ctx, "fetchLatestValidMessages", err, "conversationIDs", conversationIDs)
return
}
for conversationID, message := range resp.Msgs {
messages[conversationID] = &sdkws.PullMsgs{Msgs: []*sdkws.MsgData{message}}
}
}
}

View File

@@ -0,0 +1,33 @@
package jssdk
import "git.imall.cloud/openim/protocol/msg"
type sortActiveConversations struct {
Conversation []*msg.ActiveConversation
PinnedConversationIDs map[string]struct{}
}
func (s sortActiveConversations) Top(limit int) []*msg.ActiveConversation {
if limit > 0 && len(s.Conversation) > limit {
return s.Conversation[:limit]
}
return s.Conversation
}
func (s sortActiveConversations) Len() int {
return len(s.Conversation)
}
func (s sortActiveConversations) Less(i, j int) bool {
iv, jv := s.Conversation[i], s.Conversation[j]
_, ip := s.PinnedConversationIDs[iv.ConversationID]
_, jp := s.PinnedConversationIDs[jv.ConversationID]
if ip != jp {
return ip
}
return iv.LastTime > jv.LastTime
}
func (s sortActiveConversations) Swap(i, j int) {
s.Conversation[i], s.Conversation[j] = s.Conversation[j], s.Conversation[i]
}

View File

@@ -0,0 +1,77 @@
package jssdk
import (
"context"
"github.com/gin-gonic/gin"
"github.com/openimsdk/tools/a2r"
"github.com/openimsdk/tools/apiresp"
"github.com/openimsdk/tools/checker"
"github.com/openimsdk/tools/errs"
"google.golang.org/grpc"
"google.golang.org/protobuf/proto"
"io"
"strings"
)
func field[A, B, C any](ctx context.Context, fn func(ctx context.Context, req *A, opts ...grpc.CallOption) (*B, error), req *A, get func(*B) C) (C, error) {
resp, err := fn(ctx, req)
if err != nil {
var c C
return c, err
}
return get(resp), nil
}
func call[A, B any](c *gin.Context, fn func(ctx context.Context, req *A) (*B, error)) {
var isJSON bool
switch contentType := c.GetHeader("Content-Type"); {
case contentType == "":
isJSON = true
case strings.Contains(contentType, "application/json"):
isJSON = true
case strings.Contains(contentType, "application/protobuf"):
case strings.Contains(contentType, "application/x-protobuf"):
default:
apiresp.GinError(c, errs.ErrArgs.WrapMsg("unsupported content type"))
return
}
var req *A
if isJSON {
var err error
req, err = a2r.ParseRequest[A](c)
if err != nil {
apiresp.GinError(c, err)
return
}
} else {
body, err := io.ReadAll(c.Request.Body)
if err != nil {
apiresp.GinError(c, err)
return
}
req = new(A)
if err := proto.Unmarshal(body, any(req).(proto.Message)); err != nil {
apiresp.GinError(c, err)
return
}
if err := checker.Validate(&req); err != nil {
apiresp.GinError(c, err)
return
}
}
resp, err := fn(c, req)
if err != nil {
apiresp.GinError(c, err)
return
}
if isJSON {
apiresp.GinSuccess(c, resp)
return
}
body, err := proto.Marshal(any(resp).(proto.Message))
if err != nil {
apiresp.GinError(c, err)
return
}
apiresp.GinSuccess(c, body)
}

1371
internal/api/meeting.go Normal file

File diff suppressed because it is too large Load Diff

575
internal/api/msg.go Normal file
View File

@@ -0,0 +1,575 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"encoding/base64"
"encoding/json"
"sync"
"github.com/gin-gonic/gin"
"github.com/go-playground/validator/v10"
"github.com/mitchellh/mapstructure"
"google.golang.org/protobuf/reflect/protoreflect"
"git.imall.cloud/openim/open-im-server-deploy/pkg/apistruct"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/webhook"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/msg"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/openimsdk/tools/a2r"
"github.com/openimsdk/tools/apiresp"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mcontext"
"github.com/openimsdk/tools/utils/datautil"
"github.com/openimsdk/tools/utils/idutil"
"github.com/openimsdk/tools/utils/jsonutil"
"github.com/openimsdk/tools/utils/timeutil"
)
var (
msgDataDescriptor []protoreflect.FieldDescriptor
msgDataDescriptorOnce sync.Once
)
func getMsgDataDescriptor() []protoreflect.FieldDescriptor {
msgDataDescriptorOnce.Do(func() {
skip := make(map[string]struct{})
respFields := new(msg.SendMsgResp).ProtoReflect().Descriptor().Fields()
for i := 0; i < respFields.Len(); i++ {
field := respFields.Get(i)
if !field.HasJSONName() {
continue
}
skip[field.JSONName()] = struct{}{}
}
fields := new(sdkws.MsgData).ProtoReflect().Descriptor().Fields()
num := fields.Len()
msgDataDescriptor = make([]protoreflect.FieldDescriptor, 0, num)
for i := 0; i < num; i++ {
field := fields.Get(i)
if !field.HasJSONName() {
continue
}
if _, ok := skip[field.JSONName()]; ok {
continue
}
msgDataDescriptor = append(msgDataDescriptor, fields.Get(i))
}
})
return msgDataDescriptor
}
type MessageApi struct {
Client msg.MsgClient
userClient *rpcli.UserClient
imAdminUserID []string
validate *validator.Validate
}
func NewMessageApi(client msg.MsgClient, userClient *rpcli.UserClient, imAdminUserID []string) MessageApi {
return MessageApi{Client: client, userClient: userClient, imAdminUserID: imAdminUserID, validate: validator.New()}
}
func (*MessageApi) SetOptions(options map[string]bool, value bool) {
datautil.SetSwitchFromOptions(options, constant.IsHistory, value)
datautil.SetSwitchFromOptions(options, constant.IsPersistent, value)
datautil.SetSwitchFromOptions(options, constant.IsSenderSync, value)
datautil.SetSwitchFromOptions(options, constant.IsConversationUpdate, value)
}
func (m *MessageApi) newUserSendMsgReq(_ *gin.Context, params *apistruct.SendMsg, data any) *msg.SendMsgReq {
msgData := &sdkws.MsgData{
SendID: params.SendID,
GroupID: params.GroupID,
ClientMsgID: idutil.GetMsgIDByMD5(params.SendID),
SenderPlatformID: params.SenderPlatformID,
SenderNickname: params.SenderNickname,
SenderFaceURL: params.SenderFaceURL,
SessionType: params.SessionType,
MsgFrom: constant.SysMsgType,
ContentType: params.ContentType,
CreateTime: timeutil.GetCurrentTimestampByMill(),
SendTime: params.SendTime,
OfflinePushInfo: params.OfflinePushInfo,
Ex: params.Ex,
}
var newContent string
options := make(map[string]bool, 5)
switch params.ContentType {
case constant.OANotification:
notification := sdkws.NotificationElem{}
notification.Detail = jsonutil.StructToJsonString(params.Content)
newContent = jsonutil.StructToJsonString(&notification)
case constant.Text:
fallthrough
case constant.AtText:
if atElem, ok := data.(*apistruct.AtElem); ok {
msgData.AtUserIDList = atElem.AtUserList
}
fallthrough
case constant.Picture:
fallthrough
case constant.Custom:
fallthrough
case constant.Voice:
fallthrough
case constant.Video:
fallthrough
case constant.File:
fallthrough
default:
newContent = jsonutil.StructToJsonString(params.Content)
}
if params.IsOnlineOnly {
m.SetOptions(options, false)
}
if params.NotOfflinePush {
datautil.SetSwitchFromOptions(options, constant.IsOfflinePush, false)
}
msgData.Content = []byte(newContent)
msgData.Options = options
pbData := msg.SendMsgReq{
MsgData: msgData,
}
return &pbData
}
func (m *MessageApi) GetSeq(c *gin.Context) {
a2r.Call(c, msg.MsgClient.GetMaxSeq, m.Client)
}
func (m *MessageApi) PullMsgBySeqs(c *gin.Context) {
a2r.Call(c, msg.MsgClient.PullMessageBySeqs, m.Client)
}
func (m *MessageApi) RevokeMsg(c *gin.Context) {
a2r.Call(c, msg.MsgClient.RevokeMsg, m.Client)
}
func (m *MessageApi) MarkMsgsAsRead(c *gin.Context) {
a2r.Call(c, msg.MsgClient.MarkMsgsAsRead, m.Client)
}
func (m *MessageApi) MarkConversationAsRead(c *gin.Context) {
a2r.Call(c, msg.MsgClient.MarkConversationAsRead, m.Client)
}
func (m *MessageApi) GetConversationsHasReadAndMaxSeq(c *gin.Context) {
a2r.Call(c, msg.MsgClient.GetConversationsHasReadAndMaxSeq, m.Client)
}
func (m *MessageApi) SetConversationHasReadSeq(c *gin.Context) {
a2r.Call(c, msg.MsgClient.SetConversationHasReadSeq, m.Client)
}
func (m *MessageApi) ClearConversationsMsg(c *gin.Context) {
a2r.Call(c, msg.MsgClient.ClearConversationsMsg, m.Client)
}
func (m *MessageApi) UserClearAllMsg(c *gin.Context) {
a2r.Call(c, msg.MsgClient.UserClearAllMsg, m.Client)
}
func (m *MessageApi) DeleteMsgs(c *gin.Context) {
a2r.Call(c, msg.MsgClient.DeleteMsgs, m.Client)
}
func (m *MessageApi) DeleteMsgPhysicalBySeq(c *gin.Context) {
a2r.Call(c, msg.MsgClient.DeleteMsgPhysicalBySeq, m.Client)
}
func (m *MessageApi) DeleteMsgPhysical(c *gin.Context) {
a2r.Call(c, msg.MsgClient.DeleteMsgPhysical, m.Client)
}
func (m *MessageApi) getSendMsgReq(c *gin.Context, req apistruct.SendMsg) (sendMsgReq *msg.SendMsgReq, err error) {
var data any
log.ZDebug(c, "getSendMsgReq", "req", req.Content)
switch req.ContentType {
case constant.Text:
data = &apistruct.TextElem{}
case constant.Picture:
data = &apistruct.PictureElem{}
case constant.Voice:
data = &apistruct.SoundElem{}
case constant.Video:
data = &apistruct.VideoElem{}
case constant.File:
data = &apistruct.FileElem{}
case constant.AtText:
data = &apistruct.AtElem{}
case constant.Custom:
data = &apistruct.CustomElem{}
case constant.MarkdownText:
data = &apistruct.MarkdownTextElem{}
case constant.Quote:
data = &apistruct.QuoteElem{}
case constant.OANotification:
data = &apistruct.OANotificationElem{}
req.SessionType = constant.NotificationChatType
if err = m.userClient.GetNotificationByID(c, req.SendID); err != nil {
return nil, err
}
default:
return nil, errs.WrapMsg(errs.ErrArgs, "unsupported content type", "contentType", req.ContentType)
}
if err := mapstructure.WeakDecode(req.Content, data); err != nil {
return nil, errs.WrapMsg(err, "failed to decode message content")
}
log.ZDebug(c, "getSendMsgReq", "decodedContent", data)
if err := m.validate.Struct(data); err != nil {
return nil, errs.WrapMsg(err, "validation error")
}
return m.newUserSendMsgReq(c, &req, data), nil
}
func (m *MessageApi) getModifyFields(req, respModify *sdkws.MsgData) map[string]any {
if req == nil || respModify == nil {
return nil
}
fields := make(map[string]any)
reqProtoReflect := req.ProtoReflect()
respProtoReflect := respModify.ProtoReflect()
for _, descriptor := range getMsgDataDescriptor() {
reqValue := reqProtoReflect.Get(descriptor)
respValue := respProtoReflect.Get(descriptor)
if !reqValue.Equal(respValue) {
val := respValue.Interface()
name := descriptor.JSONName()
if name == "content" {
if bs, ok := val.([]byte); ok {
val = string(bs)
}
}
fields[name] = val
}
}
if len(fields) == 0 {
fields = nil
}
return fields
}
func (m *MessageApi) ginRespSendMsg(c *gin.Context, req *msg.SendMsgReq, resp *msg.SendMsgResp) {
res := m.getModifyFields(req.MsgData, resp.Modify)
resp.Modify = nil
apiresp.GinSuccess(c, &apistruct.SendMsgResp{
SendMsgResp: resp,
Modify: res,
})
}
// SendMessage handles the sending of a message. It's an HTTP handler function to be used with Gin framework.
func (m *MessageApi) SendMessage(c *gin.Context) {
// Initialize a request struct for sending a message.
req := apistruct.SendMsgReq{}
// Bind the JSON request body to the request struct.
if err := c.BindJSON(&req); err != nil {
// Respond with an error if request body binding fails.
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
// Check if the user has the app manager role.
if !authverify.IsAdmin(c) {
// Respond with a permission error if the user is not an app manager.
apiresp.GinError(c, errs.ErrNoPermission.WrapMsg("only app manager can send message"))
return
}
// Prepare the message request with additional required data.
sendMsgReq, err := m.getSendMsgReq(c, req.SendMsg)
if err != nil {
// Log and respond with an error if preparation fails.
apiresp.GinError(c, err)
return
}
// Set the receiver ID in the message data.
sendMsgReq.MsgData.RecvID = req.RecvID
// Attempt to send the message using the client.
respPb, err := m.Client.SendMsg(c, sendMsgReq)
if err != nil {
// Set the status to failed and respond with an error if sending fails.
apiresp.GinError(c, err)
return
}
// Set the status to successful if the message is sent.
var status = constant.MsgSendSuccessed
// Attempt to update the message sending status in the system.
_, err = m.Client.SetSendMsgStatus(c, &msg.SetSendMsgStatusReq{
Status: int32(status),
})
if err != nil {
// Log the error if updating the status fails.
apiresp.GinError(c, err)
return
}
// Respond with a success message and the response payload.
m.ginRespSendMsg(c, sendMsgReq, respPb)
}
func (m *MessageApi) SendBusinessNotification(c *gin.Context) {
req := struct {
Key string `json:"key"`
Data string `json:"data"`
SendUserID string `json:"sendUserID" binding:"required"`
RecvUserID string `json:"recvUserID"`
RecvGroupID string `json:"recvGroupID"`
SendMsg bool `json:"sendMsg"`
ReliabilityLevel *int `json:"reliabilityLevel"`
}{}
if err := c.BindJSON(&req); err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
if req.RecvUserID == "" && req.RecvGroupID == "" {
apiresp.GinError(c, errs.ErrArgs.WrapMsg("recvUserID and recvGroupID cannot be empty at the same time"))
return
}
if req.RecvUserID != "" && req.RecvGroupID != "" {
apiresp.GinError(c, errs.ErrArgs.WrapMsg("recvUserID and recvGroupID cannot be set at the same time"))
return
}
var sessionType int32
if req.RecvUserID != "" {
sessionType = constant.SingleChatType
} else {
sessionType = constant.ReadGroupChatType
}
if req.ReliabilityLevel == nil {
req.ReliabilityLevel = datautil.ToPtr(1)
}
if !authverify.IsAdmin(c) {
apiresp.GinError(c, errs.ErrNoPermission.WrapMsg("only app manager can send message"))
return
}
sendMsgReq := msg.SendMsgReq{
MsgData: &sdkws.MsgData{
SendID: req.SendUserID,
RecvID: req.RecvUserID,
GroupID: req.RecvGroupID,
Content: []byte(jsonutil.StructToJsonString(&sdkws.NotificationElem{
Detail: jsonutil.StructToJsonString(&struct {
Key string `json:"key"`
Data string `json:"data"`
}{Key: req.Key, Data: req.Data}),
})),
MsgFrom: constant.SysMsgType,
ContentType: constant.BusinessNotification,
SessionType: sessionType,
CreateTime: timeutil.GetCurrentTimestampByMill(),
ClientMsgID: idutil.GetMsgIDByMD5(mcontext.GetOpUserID(c)),
Options: config.GetOptionsByNotification(config.NotificationConfig{
IsSendMsg: req.SendMsg,
ReliabilityLevel: *req.ReliabilityLevel,
UnreadCount: false,
}, nil),
},
}
respPb, err := m.Client.SendMsg(c, &sendMsgReq)
if err != nil {
apiresp.GinError(c, err)
return
}
m.ginRespSendMsg(c, &sendMsgReq, respPb)
}
func (m *MessageApi) BatchSendMsg(c *gin.Context) {
var (
req apistruct.BatchSendMsgReq
resp apistruct.BatchSendMsgResp
)
if err := c.BindJSON(&req); err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
if err := authverify.CheckAdmin(c); err != nil {
apiresp.GinError(c, errs.ErrNoPermission.WrapMsg("only app manager can send message"))
return
}
var recvIDs []string
if req.IsSendAll {
var pageNumber int32 = 1
const showNumber = 500
for {
recvIDsPart, err := m.userClient.GetAllUserIDs(c, pageNumber, showNumber)
if err != nil {
apiresp.GinError(c, err)
return
}
recvIDs = append(recvIDs, recvIDsPart...)
if len(recvIDsPart) < showNumber {
break
}
pageNumber++
}
} else {
recvIDs = req.RecvIDs
}
log.ZDebug(c, "BatchSendMsg nums", "nums ", len(recvIDs))
sendMsgReq, err := m.getSendMsgReq(c, req.SendMsg)
if err != nil {
apiresp.GinError(c, err)
return
}
for _, recvID := range recvIDs {
sendMsgReq.MsgData.RecvID = recvID
rpcResp, err := m.Client.SendMsg(c, sendMsgReq)
if err != nil {
resp.FailedIDs = append(resp.FailedIDs, recvID)
continue
}
resp.Results = append(resp.Results, &apistruct.SingleReturnResult{
ServerMsgID: rpcResp.ServerMsgID,
ClientMsgID: rpcResp.ClientMsgID,
SendTime: rpcResp.SendTime,
RecvID: recvID,
Modify: m.getModifyFields(sendMsgReq.MsgData, rpcResp.Modify),
})
}
apiresp.GinSuccess(c, resp)
}
func (m *MessageApi) SendSimpleMessage(c *gin.Context) {
encodedKey, ok := c.GetQuery(webhook.Key)
if !ok {
apiresp.GinError(c, errs.ErrArgs.WithDetail("missing key in query").Wrap())
return
}
decodedData, err := base64.StdEncoding.DecodeString(encodedKey)
if err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
var (
req apistruct.SendSingleMsgReq
keyMsgData apistruct.KeyMsgData
sendID string
sessionType int32
recvID string
)
if err = c.BindJSON(&req); err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
err = json.Unmarshal(decodedData, &keyMsgData)
if err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
if keyMsgData.GroupID != "" {
sessionType = constant.ReadGroupChatType
sendID = req.SendID
} else {
sessionType = constant.SingleChatType
sendID = keyMsgData.RecvID
recvID = keyMsgData.SendID
}
// check param
if keyMsgData.SendID == "" {
apiresp.GinError(c, errs.ErrArgs.WithDetail("missing recvID or GroupID").Wrap())
return
}
if sendID == "" {
apiresp.GinError(c, errs.ErrArgs.WithDetail("missing sendID").Wrap())
return
}
content, err := jsonutil.JsonMarshal(apistruct.MarkdownTextElem{Content: req.Content})
if err != nil {
apiresp.GinError(c, errs.Wrap(err))
return
}
msgData := &sdkws.MsgData{
SendID: sendID,
RecvID: recvID,
GroupID: keyMsgData.GroupID,
ClientMsgID: idutil.GetMsgIDByMD5(sendID),
SenderPlatformID: constant.AdminPlatformID,
SessionType: sessionType,
MsgFrom: constant.UserMsgType,
ContentType: constant.MarkdownText,
Content: content,
OfflinePushInfo: req.OfflinePushInfo,
Ex: req.Ex,
}
sendReq := &msg.SendSimpleMsgReq{
MsgData: msgData,
}
respPb, err := m.Client.SendSimpleMsg(c, sendReq)
if err != nil {
apiresp.GinError(c, err)
return
}
var status = constant.MsgSendSuccessed
_, err = m.Client.SetSendMsgStatus(c, &msg.SetSendMsgStatusReq{
Status: int32(status),
})
if err != nil {
apiresp.GinError(c, err)
return
}
m.ginRespSendMsg(c, &msg.SendMsgReq{MsgData: sendReq.MsgData}, &msg.SendMsgResp{
ServerMsgID: respPb.ServerMsgID,
ClientMsgID: respPb.ClientMsgID,
SendTime: respPb.SendTime,
Modify: respPb.Modify,
})
}
func (m *MessageApi) CheckMsgIsSendSuccess(c *gin.Context) {
a2r.Call(c, msg.MsgClient.GetSendMsgStatus, m.Client)
}
func (m *MessageApi) GetUsersOnlineStatus(c *gin.Context) {
a2r.Call(c, msg.MsgClient.GetSendMsgStatus, m.Client)
}
func (m *MessageApi) GetActiveUser(c *gin.Context) {
a2r.Call(c, msg.MsgClient.GetActiveUser, m.Client)
}
func (m *MessageApi) GetActiveGroup(c *gin.Context) {
a2r.Call(c, msg.MsgClient.GetActiveGroup, m.Client)
}
func (m *MessageApi) SearchMsg(c *gin.Context) {
a2r.Call(c, msg.MsgClient.SearchMessage, m.Client)
}
func (m *MessageApi) GetServerTime(c *gin.Context) {
a2r.Call(c, msg.MsgClient.GetServerTime, m.Client)
}

View File

@@ -0,0 +1,99 @@
package api
import (
"encoding/json"
"errors"
"net/http"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/prommetrics"
"github.com/gin-gonic/gin"
"github.com/openimsdk/tools/apiresp"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/errs"
)
type PrometheusDiscoveryApi struct {
config *Config
kv discovery.KeyValue
}
func NewPrometheusDiscoveryApi(config *Config, client discovery.SvcDiscoveryRegistry) *PrometheusDiscoveryApi {
api := &PrometheusDiscoveryApi{
config: config,
kv: client,
}
return api
}
func (p *PrometheusDiscoveryApi) discovery(c *gin.Context, key string) {
value, err := p.kv.GetKeyWithPrefix(c, prommetrics.BuildDiscoveryKeyPrefix(key))
if err != nil {
if errors.Is(err, discovery.ErrNotSupported) {
c.JSON(http.StatusOK, []struct{}{})
return
}
apiresp.GinError(c, errs.WrapMsg(err, "get key value"))
return
}
if len(value) == 0 {
c.JSON(http.StatusOK, []*prommetrics.RespTarget{})
return
}
var resp prommetrics.RespTarget
for i := range value {
var tmp prommetrics.Target
if err = json.Unmarshal(value[i], &tmp); err != nil {
apiresp.GinError(c, errs.WrapMsg(err, "json unmarshal err"))
return
}
resp.Targets = append(resp.Targets, tmp.Target)
resp.Labels = tmp.Labels // default label is fixed. See prommetrics.BuildDefaultTarget
}
c.JSON(http.StatusOK, []*prommetrics.RespTarget{&resp})
}
func (p *PrometheusDiscoveryApi) Api(c *gin.Context) {
p.discovery(c, prommetrics.APIKeyName)
}
func (p *PrometheusDiscoveryApi) User(c *gin.Context) {
p.discovery(c, p.config.Discovery.RpcService.User)
}
func (p *PrometheusDiscoveryApi) Group(c *gin.Context) {
p.discovery(c, p.config.Discovery.RpcService.Group)
}
func (p *PrometheusDiscoveryApi) Msg(c *gin.Context) {
p.discovery(c, p.config.Discovery.RpcService.Msg)
}
func (p *PrometheusDiscoveryApi) Friend(c *gin.Context) {
p.discovery(c, p.config.Discovery.RpcService.Friend)
}
func (p *PrometheusDiscoveryApi) Conversation(c *gin.Context) {
p.discovery(c, p.config.Discovery.RpcService.Conversation)
}
func (p *PrometheusDiscoveryApi) Third(c *gin.Context) {
p.discovery(c, p.config.Discovery.RpcService.Third)
}
func (p *PrometheusDiscoveryApi) Auth(c *gin.Context) {
p.discovery(c, p.config.Discovery.RpcService.Auth)
}
func (p *PrometheusDiscoveryApi) Push(c *gin.Context) {
p.discovery(c, p.config.Discovery.RpcService.Push)
}
func (p *PrometheusDiscoveryApi) MessageGateway(c *gin.Context) {
p.discovery(c, p.config.Discovery.RpcService.MessageGateway)
}
func (p *PrometheusDiscoveryApi) MessageTransfer(c *gin.Context) {
p.discovery(c, prommetrics.MessageTransferKeyName)
}

83
internal/api/ratelimit.go Normal file
View File

@@ -0,0 +1,83 @@
package api
import (
"fmt"
"math"
"net/http"
"strconv"
"time"
"github.com/gin-gonic/gin"
"github.com/openimsdk/tools/apiresp"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/stability/ratelimit"
"github.com/openimsdk/tools/stability/ratelimit/bbr"
)
type RateLimiter struct {
Enable bool `yaml:"enable"`
Window time.Duration `yaml:"window"` // time duration per window
Bucket int `yaml:"bucket"` // bucket number for each window
CPUThreshold int64 `yaml:"cpuThreshold"` // CPU threshold; valid range 01000 (1000 = 100%)
}
func RateLimitMiddleware(config *RateLimiter) gin.HandlerFunc {
if !config.Enable {
return func(c *gin.Context) {
c.Next()
}
}
limiter := bbr.NewBBRLimiter(
bbr.WithWindow(config.Window),
bbr.WithBucket(config.Bucket),
bbr.WithCPUThreshold(config.CPUThreshold),
)
return func(c *gin.Context) {
status := limiter.Stat()
c.Header("X-BBR-CPU", strconv.FormatInt(status.CPU, 10))
c.Header("X-BBR-MinRT", strconv.FormatInt(status.MinRt, 10))
c.Header("X-BBR-MaxPass", strconv.FormatInt(status.MaxPass, 10))
c.Header("X-BBR-MaxInFlight", strconv.FormatInt(status.MaxInFlight, 10))
c.Header("X-BBR-InFlight", strconv.FormatInt(status.InFlight, 10))
done, err := limiter.Allow()
if err != nil {
c.Header("X-RateLimit-Policy", "BBR")
c.Header("Retry-After", calculateBBRRetryAfter(status))
c.Header("X-RateLimit-Limit", strconv.FormatInt(status.MaxInFlight, 10))
c.Header("X-RateLimit-Remaining", "0") // There is no concept of remaining quota in BBR.
fmt.Println("rate limited:", err, "path:", c.Request.URL.Path)
log.ZWarn(c, "rate limited", err, "path", c.Request.URL.Path)
c.AbortWithStatus(http.StatusTooManyRequests)
apiresp.GinError(c, errs.NewCodeError(http.StatusTooManyRequests, "too many requests, please try again later"))
return
}
c.Next()
done(ratelimit.DoneInfo{})
}
}
func calculateBBRRetryAfter(status bbr.Stat) string {
loadRatio := float64(status.CPU) / float64(status.CPU)
if loadRatio < 0.8 {
return "1"
}
if loadRatio < 0.95 {
return "2"
}
backoff := 1 + int64(math.Pow(loadRatio-0.95, 2)*50)
if backoff > 5 {
backoff = 5
}
return strconv.FormatInt(backoff, 10)
}

1022
internal/api/redpacket.go Normal file

File diff suppressed because it is too large Load Diff

493
internal/api/router.go Normal file
View File

@@ -0,0 +1,493 @@
package api
import (
"context"
"net/http"
"strings"
"git.imall.cloud/openim/open-im-server-deploy/internal/api/jssdk"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/prommetrics"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/servererrs"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database/mgo"
"git.imall.cloud/openim/open-im-server-deploy/pkg/dbbuild"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
pbAuth "git.imall.cloud/openim/protocol/auth"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/conversation"
"git.imall.cloud/openim/protocol/group"
"git.imall.cloud/openim/protocol/msg"
"git.imall.cloud/openim/protocol/relation"
"git.imall.cloud/openim/protocol/third"
"git.imall.cloud/openim/protocol/user"
"github.com/gin-contrib/gzip"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"github.com/go-playground/validator/v10"
"github.com/openimsdk/tools/apiresp"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/discovery/etcd"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mw"
"github.com/openimsdk/tools/mw/api"
clientv3 "go.etcd.io/etcd/client/v3"
)
const (
NoCompression = -1
DefaultCompression = 0
BestCompression = 1
BestSpeed = 2
)
func prommetricsGin() gin.HandlerFunc {
return func(c *gin.Context) {
c.Next()
path := c.FullPath()
if c.Writer.Status() == http.StatusNotFound {
prommetrics.HttpCall("<404>", c.Request.Method, c.Writer.Status())
} else {
prommetrics.HttpCall(path, c.Request.Method, c.Writer.Status())
}
if resp := apiresp.GetGinApiResponse(c); resp != nil {
prommetrics.APICall(path, c.Request.Method, resp.ErrCode)
}
}
}
func newGinRouter(ctx context.Context, client discovery.SvcDiscoveryRegistry, cfg *Config) (*gin.Engine, error) {
authConn, err := client.GetConn(ctx, cfg.Discovery.RpcService.Auth)
if err != nil {
return nil, err
}
userConn, err := client.GetConn(ctx, cfg.Discovery.RpcService.User)
if err != nil {
return nil, err
}
groupConn, err := client.GetConn(ctx, cfg.Discovery.RpcService.Group)
if err != nil {
return nil, err
}
friendConn, err := client.GetConn(ctx, cfg.Discovery.RpcService.Friend)
if err != nil {
return nil, err
}
conversationConn, err := client.GetConn(ctx, cfg.Discovery.RpcService.Conversation)
if err != nil {
return nil, err
}
thirdConn, err := client.GetConn(ctx, cfg.Discovery.RpcService.Third)
if err != nil {
return nil, err
}
msgConn, err := client.GetConn(ctx, cfg.Discovery.RpcService.Msg)
if err != nil {
return nil, err
}
// 初始化数据库连接(用于红包功能)
dbb := dbbuild.NewBuilder(&cfg.AllConfig.Mongo, &cfg.AllConfig.Redis)
mgocli, err := dbb.Mongo(ctx)
if err != nil {
return nil, err
}
redisClient, err := dbb.Redis(ctx)
if err != nil {
return nil, err
}
redPacketDB, err := mgo.NewRedPacketMongo(mgocli.GetDB())
if err != nil {
return nil, err
}
redPacketReceiveDB, err := mgo.NewRedPacketReceiveMongo(mgocli.GetDB())
if err != nil {
return nil, err
}
walletDB, err := mgo.NewWalletMongo(mgocli.GetDB())
if err != nil {
return nil, err
}
walletBalanceRecordDB, err := mgo.NewWalletBalanceRecordMongo(mgocli.GetDB())
if err != nil {
return nil, err
}
userDB, err := mgo.NewUserMongo(mgocli.GetDB())
if err != nil {
return nil, err
}
meetingDB, err := mgo.NewMeetingMongo(mgocli.GetDB())
if err != nil {
return nil, err
}
meetingCheckInDB, err := mgo.NewMeetingCheckInMongo(mgocli.GetDB())
if err != nil {
return nil, err
}
msgDocDatabase, err := mgo.NewMsgMongo(mgocli.GetDB())
if err != nil {
return nil, err
}
startOnlineCountRefresher(ctx, cfg, redisClient)
gin.SetMode(gin.ReleaseMode)
r := gin.New()
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
_ = v.RegisterValidation("required_if", RequiredIf)
}
switch cfg.API.Api.CompressionLevel {
case NoCompression:
case DefaultCompression:
r.Use(gzip.Gzip(gzip.DefaultCompression))
case BestCompression:
r.Use(gzip.Gzip(gzip.BestCompression))
case BestSpeed:
r.Use(gzip.Gzip(gzip.BestSpeed))
}
// Use rate limiter middleware
if cfg.API.RateLimiter.Enable {
rl := &RateLimiter{
Enable: cfg.API.RateLimiter.Enable,
Window: cfg.API.RateLimiter.Window,
Bucket: cfg.API.RateLimiter.Bucket,
CPUThreshold: cfg.API.RateLimiter.CPUThreshold,
}
r.Use(RateLimitMiddleware(rl))
}
if config.Standalone() {
r.Use(func(c *gin.Context) {
c.Set(authverify.CtxAdminUserIDsKey, cfg.Share.IMAdminUser.UserIDs)
})
}
r.Use(api.GinLogger(), prommetricsGin(), gin.RecoveryWithWriter(gin.DefaultErrorWriter, mw.GinPanicErr), mw.CorsHandler(),
mw.GinParseOperationID(), GinParseToken(rpcli.NewAuthClient(authConn)), setGinIsAdmin(cfg.Share.IMAdminUser.UserIDs))
u := NewUserApi(user.NewUserClient(userConn), client, cfg.Discovery.RpcService)
{
userRouterGroup := r.Group("/user")
userRouterGroup.POST("/user_register", u.UserRegister)
userRouterGroup.POST("/update_user_info", u.UpdateUserInfo)
userRouterGroup.POST("/update_user_info_ex", u.UpdateUserInfoEx)
userRouterGroup.POST("/set_global_msg_recv_opt", u.SetGlobalRecvMessageOpt)
userRouterGroup.POST("/get_users_info", u.GetUsersPublicInfo)
userRouterGroup.POST("/get_all_users_uid", u.GetAllUsersID)
userRouterGroup.POST("/account_check", u.AccountCheck)
userRouterGroup.POST("/get_users", u.GetUsers)
userRouterGroup.POST("/get_users_online_status", u.GetUsersOnlineStatus)
userRouterGroup.POST("/get_users_online_token_detail", u.GetUsersOnlineTokenDetail)
userRouterGroup.POST("/subscribe_users_status", u.SubscriberStatus)
userRouterGroup.POST("/get_users_status", u.GetUserStatus)
userRouterGroup.POST("/get_subscribe_users_status", u.GetSubscribeUsersStatus)
userRouterGroup.POST("/process_user_command_add", u.ProcessUserCommandAdd)
userRouterGroup.POST("/process_user_command_delete", u.ProcessUserCommandDelete)
userRouterGroup.POST("/process_user_command_update", u.ProcessUserCommandUpdate)
userRouterGroup.POST("/process_user_command_get", u.ProcessUserCommandGet)
userRouterGroup.POST("/process_user_command_get_all", u.ProcessUserCommandGetAll)
userRouterGroup.POST("/add_notification_account", u.AddNotificationAccount)
userRouterGroup.POST("/update_notification_account", u.UpdateNotificationAccountInfo)
userRouterGroup.POST("/search_notification_account", u.SearchNotificationAccount)
userRouterGroup.POST("/get_user_client_config", u.GetUserClientConfig)
userRouterGroup.POST("/set_user_client_config", u.SetUserClientConfig)
userRouterGroup.POST("/del_user_client_config", u.DelUserClientConfig)
userRouterGroup.POST("/page_user_client_config", u.PageUserClientConfig)
}
// friend routing group
{
f := NewFriendApi(relation.NewFriendClient(friendConn))
friendRouterGroup := r.Group("/friend")
friendRouterGroup.POST("/delete_friend", f.DeleteFriend)
friendRouterGroup.POST("/get_friend_apply_list", f.GetFriendApplyList)
friendRouterGroup.POST("/get_designated_friend_apply", f.GetDesignatedFriendsApply)
friendRouterGroup.POST("/get_self_friend_apply_list", f.GetSelfApplyList)
friendRouterGroup.POST("/get_friend_list", f.GetFriendList)
friendRouterGroup.POST("/get_designated_friends", f.GetDesignatedFriends)
friendRouterGroup.POST("/add_friend", f.ApplyToAddFriend)
friendRouterGroup.POST("/add_friend_response", f.RespondFriendApply)
friendRouterGroup.POST("/set_friend_remark", f.SetFriendRemark)
friendRouterGroup.POST("/add_black", f.AddBlack)
friendRouterGroup.POST("/get_black_list", f.GetPaginationBlacks)
friendRouterGroup.POST("/get_specified_blacks", f.GetSpecifiedBlacks)
friendRouterGroup.POST("/remove_black", f.RemoveBlack)
friendRouterGroup.POST("/get_incremental_blacks", f.GetIncrementalBlacks)
friendRouterGroup.POST("/import_friend", f.ImportFriends)
friendRouterGroup.POST("/is_friend", f.IsFriend)
friendRouterGroup.POST("/get_friend_id", f.GetFriendIDs)
friendRouterGroup.POST("/get_specified_friends_info", f.GetSpecifiedFriendsInfo)
friendRouterGroup.POST("/update_friends", f.UpdateFriends)
friendRouterGroup.POST("/get_incremental_friends", f.GetIncrementalFriends)
friendRouterGroup.POST("/get_full_friend_user_ids", f.GetFullFriendUserIDs)
friendRouterGroup.POST("/get_self_unhandled_apply_count", f.GetSelfUnhandledApplyCount)
}
g := NewGroupApi(group.NewGroupClient(groupConn))
{
groupRouterGroup := r.Group("/group")
groupRouterGroup.POST("/create_group", g.CreateGroup)
groupRouterGroup.POST("/set_group_info", g.SetGroupInfo)
groupRouterGroup.POST("/set_group_info_ex", g.SetGroupInfoEx)
groupRouterGroup.POST("/join_group", g.JoinGroup)
groupRouterGroup.POST("/quit_group", g.QuitGroup)
groupRouterGroup.POST("/group_application_response", g.ApplicationGroupResponse)
groupRouterGroup.POST("/transfer_group", g.TransferGroupOwner)
groupRouterGroup.POST("/get_recv_group_applicationList", g.GetRecvGroupApplicationList)
groupRouterGroup.POST("/get_user_req_group_applicationList", g.GetUserReqGroupApplicationList)
groupRouterGroup.POST("/get_group_users_req_application_list", g.GetGroupUsersReqApplicationList)
groupRouterGroup.POST("/get_specified_user_group_request_info", g.GetSpecifiedUserGroupRequestInfo)
groupRouterGroup.POST("/get_groups_info", g.GetGroupsInfo)
groupRouterGroup.POST("/kick_group", g.KickGroupMember)
groupRouterGroup.POST("/get_group_members_info", g.GetGroupMembersInfo)
groupRouterGroup.POST("/get_group_member_list", g.GetGroupMemberList)
groupRouterGroup.POST("/invite_user_to_group", g.InviteUserToGroup)
groupRouterGroup.POST("/get_joined_group_list", g.GetJoinedGroupList)
groupRouterGroup.POST("/dismiss_group", g.DismissGroup) //
groupRouterGroup.POST("/mute_group_member", g.MuteGroupMember)
groupRouterGroup.POST("/cancel_mute_group_member", g.CancelMuteGroupMember)
groupRouterGroup.POST("/mute_group", g.MuteGroup)
groupRouterGroup.POST("/cancel_mute_group", g.CancelMuteGroup)
groupRouterGroup.POST("/set_group_member_info", g.SetGroupMemberInfo)
groupRouterGroup.POST("/get_group_abstract_info", g.GetGroupAbstractInfo)
groupRouterGroup.POST("/get_groups", g.GetGroups)
groupRouterGroup.POST("/get_group_member_user_id", g.GetGroupMemberUserIDs)
groupRouterGroup.POST("/get_incremental_join_groups", g.GetIncrementalJoinGroup)
groupRouterGroup.POST("/get_incremental_group_members", g.GetIncrementalGroupMember)
groupRouterGroup.POST("/get_incremental_group_members_batch", g.GetIncrementalGroupMemberBatch)
groupRouterGroup.POST("/get_full_group_member_user_ids", g.GetFullGroupMemberUserIDs)
groupRouterGroup.POST("/get_full_join_group_ids", g.GetFullJoinGroupIDs)
groupRouterGroup.POST("/get_group_application_unhandled_count", g.GetGroupApplicationUnhandledCount)
}
// certificate
{
a := NewAuthApi(pbAuth.NewAuthClient(authConn))
authRouterGroup := r.Group("/auth")
authRouterGroup.POST("/get_admin_token", a.GetAdminToken)
authRouterGroup.POST("/get_user_token", a.GetUserToken)
authRouterGroup.POST("/parse_token", a.ParseToken)
authRouterGroup.POST("/force_logout", a.ForceLogout)
}
// Third service
{
t := NewThirdApi(third.NewThirdClient(thirdConn), cfg.API.Prometheus.GrafanaURL)
thirdGroup := r.Group("/third")
thirdGroup.GET("/prometheus", t.GetPrometheus)
thirdGroup.POST("/fcm_update_token", t.FcmUpdateToken)
thirdGroup.POST("/set_app_badge", t.SetAppBadge)
logs := thirdGroup.Group("/logs")
logs.POST("/upload", t.UploadLogs)
logs.POST("/delete", t.DeleteLogs)
logs.POST("/search", t.SearchLogs)
objectGroup := r.Group("/object")
objectGroup.POST("/part_limit", t.PartLimit)
objectGroup.POST("/part_size", t.PartSize)
objectGroup.POST("/initiate_multipart_upload", t.InitiateMultipartUpload)
objectGroup.POST("/auth_sign", t.AuthSign)
objectGroup.POST("/complete_multipart_upload", t.CompleteMultipartUpload)
objectGroup.POST("/access_url", t.AccessURL)
objectGroup.POST("/initiate_form_data", t.InitiateFormData)
objectGroup.POST("/complete_form_data", t.CompleteFormData)
objectGroup.GET("/*name", t.ObjectRedirect)
}
// Message
m := NewMessageApi(msg.NewMsgClient(msgConn), rpcli.NewUserClient(userConn), cfg.Share.IMAdminUser.UserIDs)
{
msgGroup := r.Group("/msg")
msgGroup.POST("/newest_seq", m.GetSeq)
msgGroup.POST("/search_msg", m.SearchMsg)
msgGroup.POST("/send_msg", m.SendMessage)
msgGroup.POST("/send_business_notification", m.SendBusinessNotification)
msgGroup.POST("/pull_msg_by_seq", m.PullMsgBySeqs)
msgGroup.POST("/revoke_msg", m.RevokeMsg)
msgGroup.POST("/mark_msgs_as_read", m.MarkMsgsAsRead)
msgGroup.POST("/mark_conversation_as_read", m.MarkConversationAsRead)
msgGroup.POST("/get_conversations_has_read_and_max_seq", m.GetConversationsHasReadAndMaxSeq)
msgGroup.POST("/set_conversation_has_read_seq", m.SetConversationHasReadSeq)
msgGroup.POST("/clear_conversation_msg", m.ClearConversationsMsg)
msgGroup.POST("/user_clear_all_msg", m.UserClearAllMsg)
msgGroup.POST("/delete_msgs", m.DeleteMsgs)
msgGroup.POST("/delete_msg_phsical_by_seq", m.DeleteMsgPhysicalBySeq)
msgGroup.POST("/delete_msg_physical", m.DeleteMsgPhysical)
msgGroup.POST("/batch_send_msg", m.BatchSendMsg)
msgGroup.POST("/send_simple_msg", m.SendSimpleMessage)
msgGroup.POST("/check_msg_is_send_success", m.CheckMsgIsSendSuccess)
msgGroup.POST("/get_server_time", m.GetServerTime)
}
// RedPacket
{
rp := NewRedPacketApi(rpcli.NewGroupClient(groupConn), rpcli.NewUserClient(userConn), msg.NewMsgClient(msgConn), redPacketDB, redPacketReceiveDB, walletDB, walletBalanceRecordDB, redisClient)
redPacketGroup := r.Group("/redpacket")
redPacketGroup.POST("/send_redpacket", rp.SendRedPacket)
redPacketGroup.POST("/receive", rp.ReceiveRedPacket)
redPacketGroup.POST("/get_detail", rp.GetRedPacketDetail) // 用户端查询红包详情
// 后台管理接口
redPacketGroup.POST("/get_redpackets_by_group", rp.GetRedPacketsByGroup)
redPacketGroup.POST("/get_receive_info", rp.GetRedPacketReceiveInfo)
redPacketGroup.POST("/pause", rp.PauseRedPacket)
}
// Wallet
{
wa := NewWalletApi(walletDB, walletBalanceRecordDB, userDB, rpcli.NewUserClient(userConn))
walletGroup := r.Group("/wallet")
// 后台管理接口
walletGroup.POST("/get_wallets", wa.GetWallets)
walletGroup.POST("/batch_update_balance", wa.BatchUpdateWalletBalance)
}
// Meeting
{
// 使用已初始化的systemConfigDB如果失败则使用nil
meetingSystemConfigDB, initErr := mgo.NewSystemConfigMongo(mgocli.GetDB())
if initErr != nil {
log.ZWarn(ctx, "failed to init system config db for meeting api, webhook attentionIds update will be disabled", initErr)
meetingSystemConfigDB = nil
}
ma := NewMeetingApi(meetingDB, meetingCheckInDB, rpcli.NewGroupClient(groupConn), rpcli.NewUserClient(userConn), rpcli.NewConversationClient(conversationConn), meetingSystemConfigDB)
meetingGroup := r.Group("/meeting")
// 管理接口
meetingGroup.POST("/create_meeting", ma.CreateMeeting)
meetingGroup.POST("/update_meeting", ma.UpdateMeeting)
meetingGroup.POST("/get_meetings", ma.GetMeetings)
meetingGroup.POST("/delete_meeting", ma.DeleteMeeting)
// 用户端接口
meetingGroup.POST("/get_meeting", ma.GetMeetingPublic)
meetingGroup.POST("/get_meetings_public", ma.GetMeetingsPublic)
// 签到接口
meetingGroup.POST("/check_in", ma.CheckInMeeting)
meetingGroup.POST("/check_user_check_in", ma.CheckUserCheckIn)
meetingGroup.POST("/get_check_ins", ma.GetMeetingCheckIns)
meetingGroup.POST("/get_check_in_stats", ma.GetMeetingCheckInStats)
}
// Conversation
{
c := NewConversationApi(conversation.NewConversationClient(conversationConn))
conversationGroup := r.Group("/conversation")
conversationGroup.POST("/get_sorted_conversation_list", c.GetSortedConversationList)
conversationGroup.POST("/get_all_conversations", c.GetAllConversations)
conversationGroup.POST("/get_conversation", c.GetConversation)
conversationGroup.POST("/get_conversations", c.GetConversations)
conversationGroup.POST("/set_conversations", c.SetConversations)
//conversationGroup.POST("/get_conversation_offline_push_user_ids", c.GetConversationOfflinePushUserIDs)
conversationGroup.POST("/get_full_conversation_ids", c.GetFullOwnerConversationIDs)
conversationGroup.POST("/get_incremental_conversations", c.GetIncrementalConversation)
conversationGroup.POST("/get_owner_conversation", c.GetOwnerConversation)
conversationGroup.POST("/get_not_notify_conversation_ids", c.GetNotNotifyConversationIDs)
conversationGroup.POST("/get_pinned_conversation_ids", c.GetPinnedConversationIDs)
conversationGroup.POST("/delete_conversations", c.DeleteConversations)
}
stats := NewStatisticsApi(redisClient, msgDocDatabase, rpcli.NewUserClient(userConn), rpcli.NewGroupClient(groupConn))
{
statisticsGroup := r.Group("/statistics")
statisticsGroup.POST("/user/register", u.UserRegisterCount)
statisticsGroup.POST("/user/active", m.GetActiveUser)
statisticsGroup.POST("/group/create", g.GroupCreateCount)
statisticsGroup.POST("/group/active", m.GetActiveGroup)
statisticsGroup.POST("/online_user_count", stats.OnlineUserCount)
statisticsGroup.POST("/online_user_count_trend", stats.OnlineUserCountTrend)
statisticsGroup.POST("/user_send_msg_count", stats.UserSendMsgCount)
statisticsGroup.POST("/user_send_msg_count_trend", stats.UserSendMsgCountTrend)
statisticsGroup.POST("/user_send_msg_query", stats.UserSendMsgQuery)
}
{
j := jssdk.NewJSSdkApi(rpcli.NewUserClient(userConn), rpcli.NewRelationClient(friendConn),
rpcli.NewGroupClient(groupConn), rpcli.NewConversationClient(conversationConn), rpcli.NewMsgClient(msgConn))
jssdk := r.Group("/jssdk")
jssdk.POST("/get_conversations", j.GetConversations)
jssdk.POST("/get_active_conversations", j.GetActiveConversations)
}
{
pd := NewPrometheusDiscoveryApi(cfg, client)
proDiscoveryGroup := r.Group("/prometheus_discovery")
proDiscoveryGroup.GET("/api", pd.Api)
proDiscoveryGroup.GET("/user", pd.User)
proDiscoveryGroup.GET("/group", pd.Group)
proDiscoveryGroup.GET("/msg", pd.Msg)
proDiscoveryGroup.GET("/friend", pd.Friend)
proDiscoveryGroup.GET("/conversation", pd.Conversation)
proDiscoveryGroup.GET("/third", pd.Third)
proDiscoveryGroup.GET("/auth", pd.Auth)
proDiscoveryGroup.GET("/push", pd.Push)
proDiscoveryGroup.GET("/msg_gateway", pd.MessageGateway)
proDiscoveryGroup.GET("/msg_transfer", pd.MessageTransfer)
}
var etcdClient *clientv3.Client
if cfg.Discovery.Enable == config.ETCD {
etcdClient = client.(*etcd.SvcDiscoveryRegistryImpl).GetClient()
}
// 初始化SystemConfig数据库用于webhook配置管理
var systemConfigDB database.SystemConfig
systemConfigDB, err = mgo.NewSystemConfigMongo(mgocli.GetDB())
if err != nil {
log.ZWarn(ctx, "failed to init system config db, webhook config management will be disabled", err)
}
cm := NewConfigManager(cfg.Share.IMAdminUser.UserIDs, &cfg.AllConfig, etcdClient, string(cfg.ConfigPath), systemConfigDB)
{
configGroup := r.Group("/config", cm.CheckAdmin)
configGroup.POST("/get_config_list", cm.GetConfigList)
configGroup.POST("/get_config", cm.GetConfig)
configGroup.POST("/set_config", cm.SetConfig)
configGroup.POST("/reset_config", cm.ResetConfig)
configGroup.POST("/set_enable_config_manager", cm.SetEnableConfigManager)
configGroup.POST("/get_enable_config_manager", cm.GetEnableConfigManager)
}
{
r.POST("/restart", cm.CheckAdmin, cm.Restart)
}
return r, nil
}
func GinParseToken(authClient *rpcli.AuthClient) gin.HandlerFunc {
return func(c *gin.Context) {
switch c.Request.Method {
case http.MethodPost:
for _, wApi := range Whitelist {
if strings.HasPrefix(c.Request.URL.Path, wApi) {
c.Next()
return
}
}
token := c.Request.Header.Get(constant.Token)
if token == "" {
log.ZWarn(c, "header get token error", servererrs.ErrArgs.WrapMsg("header must have token"))
apiresp.GinError(c, servererrs.ErrArgs.WrapMsg("header must have token"))
c.Abort()
return
}
resp, err := authClient.ParseToken(c, token)
if err != nil {
apiresp.GinError(c, err)
c.Abort()
return
}
c.Set(constant.OpUserPlatform, constant.PlatformIDToName(int(resp.PlatformID)))
c.Set(constant.OpUserID, resp.UserID)
c.Next()
}
}
}
func setGinIsAdmin(imAdminUserID []string) gin.HandlerFunc {
return func(c *gin.Context) {
c.Set(authverify.CtxAdminUserIDsKey, imAdminUserID)
}
}
// Whitelist api not parse token
var Whitelist = []string{
"/auth/get_admin_token",
"/auth/parse_token",
}

555
internal/api/statistics.go Normal file
View File

@@ -0,0 +1,555 @@
package api
import (
"context"
"errors"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/apistruct"
"git.imall.cloud/openim/open-im-server-deploy/pkg/authverify"
rediscache "git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/cache/redis"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/sdkws"
"github.com/gin-gonic/gin"
"github.com/openimsdk/tools/a2r"
"github.com/openimsdk/tools/apiresp"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/datautil"
"github.com/redis/go-redis/v9"
)
type StatisticsApi struct {
rdb redis.UniversalClient
msgDatabase database.Msg
userClient *rpcli.UserClient
groupClient *rpcli.GroupClient
}
func NewStatisticsApi(rdb redis.UniversalClient, msgDatabase database.Msg, userClient *rpcli.UserClient, groupClient *rpcli.GroupClient) *StatisticsApi {
return &StatisticsApi{
rdb: rdb,
msgDatabase: msgDatabase,
userClient: userClient,
groupClient: groupClient,
}
}
const (
trendIntervalMinutes15 = 15
trendIntervalMinutes30 = 30
trendIntervalMinutes60 = 60
trendChatTypeSingle = 1
trendChatTypeGroup = 2
defaultTrendDuration = 24 * time.Hour
)
// refreshOnlineUserCountAndHistory 刷新在线人数并写入历史采样
func refreshOnlineUserCountAndHistory(ctx context.Context, rdb redis.UniversalClient) {
count, err := rediscache.RefreshOnlineUserCount(ctx, rdb)
if err != nil {
log.ZWarn(ctx, "refresh online user count failed", err)
return
}
if err := rediscache.AppendOnlineUserCountHistory(ctx, rdb, time.Now().UnixMilli(), count); err != nil {
log.ZWarn(ctx, "append online user count history failed", err)
}
}
// startOnlineCountRefresher 定时刷新在线人数缓存
func startOnlineCountRefresher(ctx context.Context, cfg *Config, rdb redis.UniversalClient) {
if cfg == nil || rdb == nil {
return
}
refreshCfg := cfg.API.OnlineCountRefresh
if !refreshCfg.Enable || refreshCfg.Interval <= 0 {
return
}
log.ZInfo(ctx, "online user count refresh enabled", "interval", refreshCfg.Interval)
go func() {
refreshOnlineUserCountAndHistory(ctx, rdb)
ticker := time.NewTicker(refreshCfg.Interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
refreshOnlineUserCountAndHistory(ctx, rdb)
}
}
}()
}
// OnlineUserCount 在线人数统计接口
func (s *StatisticsApi) OnlineUserCount(c *gin.Context) {
if err := authverify.CheckAdmin(c); err != nil {
apiresp.GinError(c, err)
return
}
if s.rdb == nil {
apiresp.GinError(c, errs.ErrInternalServer.WrapMsg("redis client is nil"))
return
}
count, err := rediscache.GetOnlineUserCount(c, s.rdb)
if err != nil {
if errors.Is(err, redis.Nil) {
count, err = rediscache.RefreshOnlineUserCount(c, s.rdb)
if err == nil {
if appendErr := rediscache.AppendOnlineUserCountHistory(c, s.rdb, time.Now().UnixMilli(), count); appendErr != nil {
log.ZWarn(c, "append online user count history failed", appendErr)
}
}
}
}
if err != nil {
apiresp.GinError(c, err)
return
}
apiresp.GinSuccess(c, &apistruct.OnlineUserCountResp{OnlineCount: count})
}
// OnlineUserCountTrend 在线人数走势统计接口
func (s *StatisticsApi) OnlineUserCountTrend(c *gin.Context) {
if err := authverify.CheckAdmin(c); err != nil {
apiresp.GinError(c, err)
return
}
req, err := a2r.ParseRequest[apistruct.OnlineUserCountTrendReq](c)
if err != nil {
apiresp.GinError(c, err)
return
}
if s.rdb == nil {
apiresp.GinError(c, errs.ErrInternalServer.WrapMsg("redis client is nil"))
return
}
intervalMillis, err := parseTrendIntervalMillis(req.IntervalMinutes)
if err != nil {
apiresp.GinError(c, err)
return
}
startTime, endTime, err := normalizeTrendTimeRange(req.StartTime, req.EndTime)
if err != nil {
apiresp.GinError(c, err)
return
}
bucketStart, bucketEnd := alignTrendRange(startTime, endTime, intervalMillis)
// 使用对齐后的时间范围获取历史数据,确保数据范围与构建数据点的范围一致
samples, err := rediscache.GetOnlineUserCountHistory(c, s.rdb, bucketStart, bucketEnd)
if err != nil {
apiresp.GinError(c, err)
return
}
// 将当前在线人数作为最新采样,确保最后一个时间段展示该段内的最大在线人数
now := time.Now().UnixMilli()
currentBucket := now - (now % intervalMillis)
if now < 0 && now%intervalMillis != 0 {
currentBucket = now - ((now % intervalMillis) + intervalMillis)
}
if currentBucket >= bucketStart && currentBucket <= bucketEnd {
if currentCount, err := rediscache.GetOnlineUserCount(c, s.rdb); err == nil {
samples = append(samples, rediscache.OnlineUserCountSample{
Timestamp: now,
Count: currentCount,
})
}
}
points := buildOnlineUserCountTrendPoints(samples, bucketStart, bucketEnd, intervalMillis)
apiresp.GinSuccess(c, &apistruct.OnlineUserCountTrendResp{
IntervalMinutes: req.IntervalMinutes,
Points: points,
})
}
// UserSendMsgCount 用户发送消息总数统计
func (s *StatisticsApi) UserSendMsgCount(c *gin.Context) {
if err := authverify.CheckAdmin(c); err != nil {
apiresp.GinError(c, err)
return
}
_, err := a2r.ParseRequest[apistruct.UserSendMsgCountReq](c)
if err != nil {
apiresp.GinError(c, err)
return
}
if s.msgDatabase == nil {
apiresp.GinError(c, errs.ErrInternalServer.WrapMsg("msg database is nil"))
return
}
now := time.Now()
endTime := now.UnixMilli()
start24h := now.Add(-24 * time.Hour).UnixMilli()
start7d := now.Add(-7 * 24 * time.Hour).UnixMilli()
start30d := now.Add(-30 * 24 * time.Hour).UnixMilli()
count24h, err := s.msgDatabase.CountUserSendMessages(c, "", start24h, endTime, "")
if err != nil {
apiresp.GinError(c, err)
return
}
count7d, err := s.msgDatabase.CountUserSendMessages(c, "", start7d, endTime, "")
if err != nil {
apiresp.GinError(c, err)
return
}
count30d, err := s.msgDatabase.CountUserSendMessages(c, "", start30d, endTime, "")
if err != nil {
apiresp.GinError(c, err)
return
}
apiresp.GinSuccess(c, &apistruct.UserSendMsgCountResp{
Count24h: count24h,
Count7d: count7d,
Count30d: count30d,
})
}
// UserSendMsgCountTrend 用户发送消息走势统计
func (s *StatisticsApi) UserSendMsgCountTrend(c *gin.Context) {
if err := authverify.CheckAdmin(c); err != nil {
apiresp.GinError(c, err)
return
}
req, err := a2r.ParseRequest[apistruct.UserSendMsgCountTrendReq](c)
if err != nil {
apiresp.GinError(c, err)
return
}
if s.msgDatabase == nil {
apiresp.GinError(c, errs.ErrInternalServer.WrapMsg("msg database is nil"))
return
}
intervalMillis, err := parseTrendIntervalMillis(req.IntervalMinutes)
if err != nil {
apiresp.GinError(c, err)
return
}
startTime, endTime, err := normalizeTrendTimeRange(req.StartTime, req.EndTime)
if err != nil {
apiresp.GinError(c, err)
return
}
sessionTypes, err := mapTrendChatType(req.ChatType)
if err != nil {
apiresp.GinError(c, err)
return
}
bucketStart, bucketEnd := alignTrendRange(startTime, endTime, intervalMillis)
countMap, err := s.msgDatabase.CountUserSendMessagesTrend(c, req.UserID, sessionTypes, startTime, endTime, intervalMillis)
if err != nil {
apiresp.GinError(c, err)
return
}
points := buildUserSendMsgCountTrendPoints(countMap, bucketStart, bucketEnd, intervalMillis)
apiresp.GinSuccess(c, &apistruct.UserSendMsgCountTrendResp{
UserID: req.UserID,
ChatType: req.ChatType,
IntervalMinutes: req.IntervalMinutes,
Points: points,
})
}
// UserSendMsgQuery 用户发送消息查询
func (s *StatisticsApi) UserSendMsgQuery(c *gin.Context) {
if err := authverify.CheckAdmin(c); err != nil {
apiresp.GinError(c, err)
return
}
req, err := a2r.ParseRequest[apistruct.UserSendMsgQueryReq](c)
if err != nil {
apiresp.GinError(c, err)
return
}
if req.StartTime > 0 && req.EndTime > 0 && req.EndTime <= req.StartTime {
apiresp.GinError(c, errs.ErrArgs.WrapMsg("invalid time range"))
return
}
if s.msgDatabase == nil {
apiresp.GinError(c, errs.ErrInternalServer.WrapMsg("msg database is nil"))
return
}
pageNumber := req.PageNumber
if pageNumber <= 0 {
pageNumber = 1
}
showNumber := req.ShowNumber
if showNumber <= 0 {
showNumber = 50
}
const maxShowNumber int32 = 200
if showNumber > maxShowNumber {
showNumber = maxShowNumber
}
total, msgs, err := s.msgDatabase.SearchUserMessages(c, req.UserID, req.StartTime, req.EndTime, req.Content, pageNumber, showNumber)
if err != nil {
apiresp.GinError(c, err)
return
}
sendIDs := make([]string, 0, len(msgs))
recvIDs := make([]string, 0, len(msgs))
groupIDs := make([]string, 0, len(msgs))
for _, item := range msgs {
if item == nil || item.Msg == nil {
continue
}
msg := item.Msg
if msg.SendID != "" {
sendIDs = append(sendIDs, msg.SendID)
}
switch msg.SessionType {
case constant.ReadGroupChatType, constant.WriteGroupChatType:
if msg.GroupID != "" {
groupIDs = append(groupIDs, msg.GroupID)
}
default:
if msg.RecvID != "" {
recvIDs = append(recvIDs, msg.RecvID)
}
}
}
sendIDs = datautil.Distinct(sendIDs)
recvIDs = datautil.Distinct(recvIDs)
groupIDs = datautil.Distinct(groupIDs)
sendMap, recvMap, groupMap := map[string]*sdkws.UserInfo{}, map[string]*sdkws.UserInfo{}, map[string]*sdkws.GroupInfo{}
if s.userClient != nil {
if len(sendIDs) > 0 {
if users, err := s.userClient.GetUsersInfo(c, sendIDs); err == nil {
sendMap = datautil.SliceToMap(users, (*sdkws.UserInfo).GetUserID)
}
}
if len(recvIDs) > 0 {
if users, err := s.userClient.GetUsersInfo(c, recvIDs); err == nil {
recvMap = datautil.SliceToMap(users, (*sdkws.UserInfo).GetUserID)
}
}
}
if s.groupClient != nil && len(groupIDs) > 0 {
if groups, err := s.groupClient.GetGroupsInfo(c, groupIDs); err == nil {
groupMap = datautil.SliceToMap(groups, (*sdkws.GroupInfo).GetGroupID)
}
}
records := make([]*apistruct.UserSendMsgQueryRecord, 0, len(msgs))
for _, item := range msgs {
if item == nil || item.Msg == nil {
continue
}
msg := item.Msg
msgID := msg.ServerMsgID
if msgID == "" {
msgID = msg.ClientMsgID
}
senderName := msg.SenderNickname
if senderName == "" {
if u := sendMap[msg.SendID]; u != nil {
senderName = u.Nickname
} else {
senderName = msg.SendID
}
}
recvID := msg.RecvID
recvName := ""
if msg.SessionType == constant.ReadGroupChatType || msg.SessionType == constant.WriteGroupChatType {
if msg.GroupID != "" {
recvID = msg.GroupID
}
if g := groupMap[recvID]; g != nil {
recvName = g.GroupName
} else if recvID != "" {
recvName = recvID
}
} else {
if u := recvMap[msg.RecvID]; u != nil {
recvName = u.Nickname
} else if msg.RecvID != "" {
recvName = msg.RecvID
}
}
records = append(records, &apistruct.UserSendMsgQueryRecord{
MsgID: msgID,
SendID: msg.SendID,
SenderName: senderName,
RecvID: recvID,
RecvName: recvName,
ContentType: msg.ContentType,
ContentTypeName: contentTypeName(msg.ContentType),
SessionType: msg.SessionType,
ChatTypeName: chatTypeName(msg.SessionType),
Content: msg.Content,
SendTime: msg.SendTime,
})
}
apiresp.GinSuccess(c, &apistruct.UserSendMsgQueryResp{
Count: total,
PageNumber: pageNumber,
ShowNumber: showNumber,
Records: records,
})
}
// parseTrendIntervalMillis 解析走势统计间隔并转换为毫秒
func parseTrendIntervalMillis(intervalMinutes int32) (int64, error) {
switch intervalMinutes {
case trendIntervalMinutes15, trendIntervalMinutes30, trendIntervalMinutes60:
return int64(intervalMinutes) * int64(time.Minute/time.Millisecond), nil
default:
return 0, errs.ErrArgs.WrapMsg("invalid intervalMinutes")
}
}
// normalizeTrendTimeRange 标准化走势统计时间区间
func normalizeTrendTimeRange(startTime int64, endTime int64) (int64, int64, error) {
now := time.Now().UnixMilli()
if endTime <= 0 {
endTime = now
}
if startTime <= 0 {
startTime = endTime - int64(defaultTrendDuration/time.Millisecond)
}
if startTime < 0 {
startTime = 0
}
if endTime <= startTime {
return 0, 0, errs.ErrArgs.WrapMsg("invalid time range")
}
return startTime, endTime, nil
}
// alignTrendRange 对齐走势统计区间到间隔边界
func alignTrendRange(startTime int64, endTime int64, intervalMillis int64) (int64, int64) {
if intervalMillis <= 0 {
return startTime, endTime
}
// 开始时间向下对齐到间隔边界
bucketStart := startTime - (startTime % intervalMillis)
if startTime < 0 {
bucketStart = startTime - ((startTime % intervalMillis) + intervalMillis)
}
// 结束时间向下对齐到所在间隔的起始(只包含已发生的间隔)
bucketEnd := endTime - (endTime % intervalMillis)
if endTime < 0 && endTime%intervalMillis != 0 {
bucketEnd = endTime - ((endTime % intervalMillis) + intervalMillis)
}
// 确保至少覆盖一个间隔
if bucketEnd < bucketStart {
bucketEnd = bucketStart
}
return bucketStart, bucketEnd
}
// buildOnlineUserCountTrendPoints 构建在线人数走势数据点
func buildOnlineUserCountTrendPoints(samples []rediscache.OnlineUserCountSample, startTime int64, endTime int64, intervalMillis int64) []*apistruct.OnlineUserCountTrendItem {
points := make([]*apistruct.OnlineUserCountTrendItem, 0)
if intervalMillis <= 0 || endTime <= startTime {
return points
}
maxMap := make(map[int64]int64)
for _, sample := range samples {
// 将采样时间戳对齐到间隔边界
bucket := sample.Timestamp - (sample.Timestamp % intervalMillis)
// 处理负数时间戳的情况(虽然通常不会发生)
if sample.Timestamp < 0 && sample.Timestamp%intervalMillis != 0 {
bucket = sample.Timestamp - ((sample.Timestamp % intervalMillis) + intervalMillis)
}
if sample.Count > maxMap[bucket] {
maxMap[bucket] = sample.Count
}
}
// 计算需要生成的数据点数量
// endTime是对齐后的最后一个bucket的起始时间所以需要包含它
estimated := int((endTime-startTime)/intervalMillis) + 1
if estimated > 0 {
points = make([]*apistruct.OnlineUserCountTrendItem, 0, estimated)
}
// 生成从startTime到endTime包含endTime的所有时间点
// endTime已经是对齐后的最后一个bucket的起始时间
for ts := startTime; ts <= endTime; ts += intervalMillis {
maxVal := maxMap[ts]
points = append(points, &apistruct.OnlineUserCountTrendItem{
Timestamp: ts,
OnlineCount: maxVal,
})
}
return points
}
// buildUserSendMsgCountTrendPoints 构建用户发送消息走势数据点
func buildUserSendMsgCountTrendPoints(countMap map[int64]int64, startTime int64, endTime int64, intervalMillis int64) []*apistruct.UserSendMsgCountTrendItem {
points := make([]*apistruct.UserSendMsgCountTrendItem, 0)
if intervalMillis <= 0 || endTime <= startTime {
return points
}
estimated := int((endTime - startTime) / intervalMillis)
if estimated > 0 {
points = make([]*apistruct.UserSendMsgCountTrendItem, 0, estimated)
}
for ts := startTime; ts < endTime; ts += intervalMillis {
points = append(points, &apistruct.UserSendMsgCountTrendItem{
Timestamp: ts,
Count: countMap[ts],
})
}
return points
}
// mapTrendChatType 走势统计聊天类型转为 sessionType 列表
func mapTrendChatType(chatType int32) ([]int32, error) {
switch chatType {
case trendChatTypeSingle:
return []int32{constant.SingleChatType}, nil
case trendChatTypeGroup:
return []int32{constant.ReadGroupChatType, constant.WriteGroupChatType}, nil
default:
return nil, errs.ErrArgs.WrapMsg("invalid chatType")
}
}
// contentTypeName 消息类型名称转换
func contentTypeName(contentType int32) string {
switch contentType {
case constant.Text:
return "文本消息"
case constant.Picture:
return "图片消息"
case constant.Voice:
return "语音消息"
case constant.Video:
return "视频消息"
case constant.File:
return "文件消息"
case constant.AtText:
return "艾特消息"
case constant.Merger:
return "合并消息"
case constant.Card:
return "名片消息"
case constant.Location:
return "位置消息"
case constant.Custom:
return "自定义消息"
case constant.Revoke:
return "撤回消息"
case constant.MarkdownText:
return "Markdown消息"
default:
return "未知消息"
}
}
// chatTypeName 聊天类型名称转换
func chatTypeName(sessionType int32) string {
switch sessionType {
case constant.SingleChatType:
return "单聊"
case constant.ReadGroupChatType, constant.WriteGroupChatType:
return "群聊"
case constant.NotificationChatType:
return "通知"
default:
return "未知"
}
}

175
internal/api/third.go Normal file
View File

@@ -0,0 +1,175 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"context"
"math/rand"
"net/http"
"net/url"
"strconv"
"strings"
"google.golang.org/grpc"
"git.imall.cloud/openim/protocol/third"
"github.com/gin-gonic/gin"
"github.com/openimsdk/tools/a2r"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/mcontext"
)
type ThirdApi struct {
GrafanaUrl string
Client third.ThirdClient
}
func NewThirdApi(client third.ThirdClient, grafanaUrl string) ThirdApi {
return ThirdApi{Client: client, GrafanaUrl: grafanaUrl}
}
func (o *ThirdApi) FcmUpdateToken(c *gin.Context) {
a2r.Call(c, third.ThirdClient.FcmUpdateToken, o.Client)
}
func (o *ThirdApi) SetAppBadge(c *gin.Context) {
a2r.Call(c, third.ThirdClient.SetAppBadge, o.Client)
}
// #################### s3 ####################
func setURLPrefixOption[A, B, C any](_ func(client C, ctx context.Context, req *A, options ...grpc.CallOption) (*B, error), fn func(*A) error) *a2r.Option[A, B] {
return &a2r.Option[A, B]{
BindAfter: fn,
}
}
func setURLPrefix(c *gin.Context, urlPrefix *string) error {
host := c.GetHeader("X-Request-Api")
if host != "" {
if strings.HasSuffix(host, "/") {
*urlPrefix = host + "object/"
return nil
} else {
*urlPrefix = host + "/object/"
return nil
}
}
u := url.URL{
Scheme: "http",
Host: c.Request.Host,
Path: "/object/",
}
if c.Request.TLS != nil {
u.Scheme = "https"
}
*urlPrefix = u.String()
return nil
}
func (o *ThirdApi) PartLimit(c *gin.Context) {
a2r.Call(c, third.ThirdClient.PartLimit, o.Client)
}
func (o *ThirdApi) PartSize(c *gin.Context) {
a2r.Call(c, third.ThirdClient.PartSize, o.Client)
}
func (o *ThirdApi) InitiateMultipartUpload(c *gin.Context) {
opt := setURLPrefixOption(third.ThirdClient.InitiateMultipartUpload, func(req *third.InitiateMultipartUploadReq) error {
return setURLPrefix(c, &req.UrlPrefix)
})
a2r.Call(c, third.ThirdClient.InitiateMultipartUpload, o.Client, opt)
}
func (o *ThirdApi) AuthSign(c *gin.Context) {
a2r.Call(c, third.ThirdClient.AuthSign, o.Client)
}
func (o *ThirdApi) CompleteMultipartUpload(c *gin.Context) {
opt := setURLPrefixOption(third.ThirdClient.CompleteMultipartUpload, func(req *third.CompleteMultipartUploadReq) error {
return setURLPrefix(c, &req.UrlPrefix)
})
a2r.Call(c, third.ThirdClient.CompleteMultipartUpload, o.Client, opt)
}
func (o *ThirdApi) AccessURL(c *gin.Context) {
a2r.Call(c, third.ThirdClient.AccessURL, o.Client)
}
func (o *ThirdApi) InitiateFormData(c *gin.Context) {
a2r.Call(c, third.ThirdClient.InitiateFormData, o.Client)
}
func (o *ThirdApi) CompleteFormData(c *gin.Context) {
opt := setURLPrefixOption(third.ThirdClient.CompleteFormData, func(req *third.CompleteFormDataReq) error {
return setURLPrefix(c, &req.UrlPrefix)
})
a2r.Call(c, third.ThirdClient.CompleteFormData, o.Client, opt)
}
func (o *ThirdApi) ObjectRedirect(c *gin.Context) {
name := c.Param("name")
if name == "" {
c.String(http.StatusBadRequest, "name is empty")
return
}
if name[0] == '/' {
name = name[1:]
}
operationID := c.Query("operationID")
if operationID == "" {
operationID = strconv.Itoa(rand.Int())
}
ctx := mcontext.SetOperationID(c, operationID)
query := make(map[string]string)
for key, values := range c.Request.URL.Query() {
if len(values) == 0 {
continue
}
query[key] = values[0]
}
resp, err := o.Client.AccessURL(ctx, &third.AccessURLReq{Name: name, Query: query})
if err != nil {
if errs.ErrArgs.Is(err) {
c.String(http.StatusBadRequest, err.Error())
return
}
if errs.ErrRecordNotFound.Is(err) {
c.String(http.StatusNotFound, err.Error())
return
}
c.String(http.StatusInternalServerError, err.Error())
return
}
c.Redirect(http.StatusFound, resp.Url)
}
// #################### logs ####################.
func (o *ThirdApi) UploadLogs(c *gin.Context) {
a2r.Call(c, third.ThirdClient.UploadLogs, o.Client)
}
func (o *ThirdApi) DeleteLogs(c *gin.Context) {
a2r.Call(c, third.ThirdClient.DeleteLogs, o.Client)
}
func (o *ThirdApi) SearchLogs(c *gin.Context) {
a2r.Call(c, third.ThirdClient.SearchLogs, o.Client)
}
func (o *ThirdApi) GetPrometheus(c *gin.Context) {
c.Redirect(http.StatusFound, o.GrafanaUrl)
}

260
internal/api/user.go Normal file
View File

@@ -0,0 +1,260 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/config"
"git.imall.cloud/openim/protocol/constant"
"git.imall.cloud/openim/protocol/msggateway"
"git.imall.cloud/openim/protocol/user"
"github.com/gin-gonic/gin"
"github.com/openimsdk/tools/a2r"
"github.com/openimsdk/tools/apiresp"
"github.com/openimsdk/tools/discovery"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
)
type UserApi struct {
Client user.UserClient
discov discovery.Conn
config config.RpcService
}
func NewUserApi(client user.UserClient, discov discovery.Conn, config config.RpcService) UserApi {
return UserApi{Client: client, discov: discov, config: config}
}
func (u *UserApi) UserRegister(c *gin.Context) {
a2r.Call(c, user.UserClient.UserRegister, u.Client)
}
// UpdateUserInfo is deprecated. Use UpdateUserInfoEx
func (u *UserApi) UpdateUserInfo(c *gin.Context) {
a2r.Call(c, user.UserClient.UpdateUserInfo, u.Client)
}
func (u *UserApi) UpdateUserInfoEx(c *gin.Context) {
a2r.Call(c, user.UserClient.UpdateUserInfoEx, u.Client)
}
func (u *UserApi) SetGlobalRecvMessageOpt(c *gin.Context) {
a2r.Call(c, user.UserClient.SetGlobalRecvMessageOpt, u.Client)
}
func (u *UserApi) GetUsersPublicInfo(c *gin.Context) {
a2r.Call(c, user.UserClient.GetDesignateUsers, u.Client)
}
func (u *UserApi) GetAllUsersID(c *gin.Context) {
a2r.Call(c, user.UserClient.GetAllUserID, u.Client)
}
func (u *UserApi) AccountCheck(c *gin.Context) {
a2r.Call(c, user.UserClient.AccountCheck, u.Client)
}
func (u *UserApi) GetUsers(c *gin.Context) {
a2r.Call(c, user.UserClient.GetPaginationUsers, u.Client)
}
// GetUsersOnlineStatus Get user online status.
func (u *UserApi) GetUsersOnlineStatus(c *gin.Context) {
var req msggateway.GetUsersOnlineStatusReq
if err := c.BindJSON(&req); err != nil {
apiresp.GinError(c, err)
return
}
conns, err := u.discov.GetConns(c, u.config.MessageGateway)
if err != nil {
apiresp.GinError(c, err)
return
}
var wsResult []*msggateway.GetUsersOnlineStatusResp_SuccessResult
var respResult []*msggateway.GetUsersOnlineStatusResp_SuccessResult
flag := false
// Online push message
for _, v := range conns {
msgClient := msggateway.NewMsgGatewayClient(v)
reply, err := msgClient.GetUsersOnlineStatus(c, &req)
if err != nil {
log.ZDebug(c, "GetUsersOnlineStatus rpc error", err)
parseError := apiresp.ParseError(err)
if parseError.ErrCode == errs.NoPermissionError {
apiresp.GinError(c, err)
return
}
} else {
wsResult = append(wsResult, reply.SuccessResult...)
}
}
// Traversing the userIDs in the api request body
for _, v1 := range req.UserIDs {
flag = false
res := new(msggateway.GetUsersOnlineStatusResp_SuccessResult)
// Iterate through the online results fetched from various gateways
for _, v2 := range wsResult {
// If matches the above description on the line, and vice versa
if v2.UserID == v1 {
flag = true
res.UserID = v1
res.Status = constant.Online
res.DetailPlatformStatus = append(res.DetailPlatformStatus, v2.DetailPlatformStatus...)
break
}
}
if !flag {
res.UserID = v1
res.Status = constant.Offline
}
respResult = append(respResult, res)
}
apiresp.GinSuccess(c, respResult)
}
func (u *UserApi) UserRegisterCount(c *gin.Context) {
a2r.Call(c, user.UserClient.UserRegisterCount, u.Client)
}
// GetUsersOnlineTokenDetail Get user online token details.
func (u *UserApi) GetUsersOnlineTokenDetail(c *gin.Context) {
var wsResult []*msggateway.GetUsersOnlineStatusResp_SuccessResult
var respResult []*msggateway.SingleDetail
flag := false
var req msggateway.GetUsersOnlineStatusReq
if err := c.BindJSON(&req); err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
conns, err := u.discov.GetConns(c, u.config.MessageGateway)
if err != nil {
apiresp.GinError(c, err)
return
}
// Online push message
for _, v := range conns {
msgClient := msggateway.NewMsgGatewayClient(v)
reply, err := msgClient.GetUsersOnlineStatus(c, &req)
if err != nil {
log.ZWarn(c, "GetUsersOnlineStatus rpc err", err)
continue
} else {
wsResult = append(wsResult, reply.SuccessResult...)
}
}
for _, v1 := range req.UserIDs {
m := make(map[int32][]string, 10)
flag = false
temp := new(msggateway.SingleDetail)
for _, v2 := range wsResult {
if v2.UserID == v1 {
flag = true
temp.UserID = v1
temp.Status = constant.Online
for _, status := range v2.DetailPlatformStatus {
if v, ok := m[status.PlatformID]; ok {
m[status.PlatformID] = append(v, status.Token)
} else {
m[status.PlatformID] = []string{status.Token}
}
}
}
}
for p, tokens := range m {
t := new(msggateway.SinglePlatformToken)
t.PlatformID = p
t.Token = tokens
t.Total = int32(len(tokens))
temp.SinglePlatformToken = append(temp.SinglePlatformToken, t)
}
if flag {
respResult = append(respResult, temp)
}
}
apiresp.GinSuccess(c, respResult)
}
// SubscriberStatus Presence status of subscribed users.
func (u *UserApi) SubscriberStatus(c *gin.Context) {
a2r.Call(c, user.UserClient.SubscribeOrCancelUsersStatus, u.Client)
}
// GetUserStatus Get the online status of the user.
func (u *UserApi) GetUserStatus(c *gin.Context) {
a2r.Call(c, user.UserClient.GetUserStatus, u.Client)
}
// GetSubscribeUsersStatus Get the online status of subscribers.
func (u *UserApi) GetSubscribeUsersStatus(c *gin.Context) {
a2r.Call(c, user.UserClient.GetSubscribeUsersStatus, u.Client)
}
// ProcessUserCommandAdd user general function add.
func (u *UserApi) ProcessUserCommandAdd(c *gin.Context) {
a2r.Call(c, user.UserClient.ProcessUserCommandAdd, u.Client)
}
// ProcessUserCommandDelete user general function delete.
func (u *UserApi) ProcessUserCommandDelete(c *gin.Context) {
a2r.Call(c, user.UserClient.ProcessUserCommandDelete, u.Client)
}
// ProcessUserCommandUpdate user general function update.
func (u *UserApi) ProcessUserCommandUpdate(c *gin.Context) {
a2r.Call(c, user.UserClient.ProcessUserCommandUpdate, u.Client)
}
// ProcessUserCommandGet user general function get.
func (u *UserApi) ProcessUserCommandGet(c *gin.Context) {
a2r.Call(c, user.UserClient.ProcessUserCommandGet, u.Client)
}
// ProcessUserCommandGet user general function get all.
func (u *UserApi) ProcessUserCommandGetAll(c *gin.Context) {
a2r.Call(c, user.UserClient.ProcessUserCommandGetAll, u.Client)
}
func (u *UserApi) AddNotificationAccount(c *gin.Context) {
a2r.Call(c, user.UserClient.AddNotificationAccount, u.Client)
}
func (u *UserApi) UpdateNotificationAccountInfo(c *gin.Context) {
a2r.Call(c, user.UserClient.UpdateNotificationAccountInfo, u.Client)
}
func (u *UserApi) SearchNotificationAccount(c *gin.Context) {
a2r.Call(c, user.UserClient.SearchNotificationAccount, u.Client)
}
func (u *UserApi) GetUserClientConfig(c *gin.Context) {
a2r.Call(c, user.UserClient.GetUserClientConfig, u.Client)
}
func (u *UserApi) SetUserClientConfig(c *gin.Context) {
a2r.Call(c, user.UserClient.SetUserClientConfig, u.Client)
}
func (u *UserApi) DelUserClientConfig(c *gin.Context) {
a2r.Call(c, user.UserClient.DelUserClientConfig, u.Client)
}
func (u *UserApi) PageUserClientConfig(c *gin.Context) {
a2r.Call(c, user.UserClient.PageUserClientConfig, u.Client)
}

523
internal/api/wallet.go Normal file
View File

@@ -0,0 +1,523 @@
// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"strconv"
"time"
"github.com/gin-gonic/gin"
"git.imall.cloud/openim/open-im-server-deploy/pkg/apistruct"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
"github.com/openimsdk/tools/apiresp"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/idutil"
"github.com/openimsdk/tools/utils/timeutil"
)
type WalletApi struct {
walletDB database.Wallet
walletBalanceRecordDB database.WalletBalanceRecord
userDB database.User
userClient *rpcli.UserClient
}
func NewWalletApi(walletDB database.Wallet, walletBalanceRecordDB database.WalletBalanceRecord, userDB database.User, userClient *rpcli.UserClient) *WalletApi {
return &WalletApi{
walletDB: walletDB,
walletBalanceRecordDB: walletBalanceRecordDB,
userDB: userDB,
userClient: userClient,
}
}
// updateBalanceWithRecord 统一的余额更新方法,包含余额记录创建和并发控制
// operation: 操作类型set/add/subtract
// amount: 金额(分)
// oldBalance: 旧余额
// oldVersion: 旧版本号
// remark: 备注信息
// 返回新余额、新版本号和错误
func (w *WalletApi) updateBalanceWithRecord(ctx *gin.Context, userID string, operation string, amount int64, oldBalance int64, oldVersion int64, remark string) (newBalance int64, newVersion int64, err error) {
// 使用版本号更新余额(防止并发覆盖)
params := &database.WalletUpdateParams{
UserID: userID,
Operation: operation,
Amount: amount,
OldBalance: oldBalance,
OldVersion: oldVersion,
}
result, err := w.walletDB.UpdateBalanceWithVersion(ctx, params)
if err != nil {
return 0, 0, err
}
// 计算变动金额:用于余额记录
var recordAmount int64
switch operation {
case "set":
recordAmount = result.NewBalance - oldBalance
case "add":
recordAmount = amount
case "subtract":
recordAmount = -amount
}
// 创建余额记录(新字段)
recordID := idutil.GetMsgIDByMD5(userID + timeutil.GetCurrentTimeFormatted() + operation + strconv.FormatInt(amount, 10))
var recordType int32 = 99 // 默认其他
switch operation {
case "add":
recordType = 99 // 通用“加钱”,具体业务可在上层传入
case "subtract":
recordType = 99 // 通用“减钱”
case "set":
recordType = 99 // 设置,归为其他
}
balanceRecord := &model.WalletBalanceRecord{
ID: recordID,
UserID: userID,
Amount: recordAmount,
Type: recordType,
BeforeBalance: oldBalance,
AfterBalance: result.NewBalance,
OrderID: "",
TransactionID: "",
RedPacketID: "",
Remark: remark,
CreateTime: time.Now(),
}
if err := w.walletBalanceRecordDB.Create(ctx, balanceRecord); err != nil {
// 余额记录创建失败不影响主流程,只记录警告日志
log.ZWarn(ctx, "updateBalanceWithRecord: failed to create balance record", err,
"userID", userID,
"operation", operation,
"amount", amount,
"oldBalance", oldBalance,
"newBalance", result.NewBalance)
}
return result.NewBalance, result.NewVersion, nil
}
// paginationWrapper 实现 pagination.Pagination 接口
type walletPaginationWrapper struct {
pageNumber int32
showNumber int32
}
func (p *walletPaginationWrapper) GetPageNumber() int32 {
if p.pageNumber <= 0 {
return 1
}
return p.pageNumber
}
func (p *walletPaginationWrapper) GetShowNumber() int32 {
if p.showNumber <= 0 {
return 20
}
return p.showNumber
}
// GetWallets 查询用户钱包列表(后台管理接口)
func (w *WalletApi) GetWallets(c *gin.Context) {
var (
req apistruct.GetWalletsReq
resp apistruct.GetWalletsResp
)
if err := c.BindJSON(&req); err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
// 设置默认分页参数
if req.Pagination.PageNumber <= 0 {
req.Pagination.PageNumber = 1
}
if req.Pagination.ShowNumber <= 0 {
req.Pagination.ShowNumber = 20
}
// 创建分页对象
pagination := &walletPaginationWrapper{
pageNumber: req.Pagination.PageNumber,
showNumber: req.Pagination.ShowNumber,
}
// 查询钱包列表
var total int64
var wallets []*apistruct.WalletInfo
var err error
// 判断是否有查询条件用户ID、手机号、账号
hasQueryCondition := req.UserID != "" || req.PhoneNumber != "" || req.Account != ""
if hasQueryCondition {
// 如果有查询条件先通过条件查询用户ID
var userIDs []string
if req.UserID != "" {
// 如果直接提供了用户ID直接使用
userIDs = []string{req.UserID}
} else {
// 通过手机号或账号查询用户ID
searchUserIDs, err := w.userDB.SearchUsersByFields(c, req.Account, req.PhoneNumber, "")
if err != nil {
log.ZError(c, "GetWallets: failed to search users", err, "account", req.Account, "phoneNumber", req.PhoneNumber)
apiresp.GinError(c, errs.ErrInternalServer.WrapMsg("failed to search users"))
return
}
if len(searchUserIDs) == 0 {
// 没有找到匹配的用户,返回空列表
resp.Total = 0
resp.Wallets = []*apistruct.WalletInfo{}
apiresp.GinSuccess(c, resp)
return
}
userIDs = searchUserIDs
}
// 根据查询到的用户ID列表查询钱包
walletModels, err := w.walletDB.FindWalletsByUserIDs(c, userIDs)
if err != nil {
log.ZError(c, "GetWallets: failed to find wallets by userIDs", err, "userIDs", userIDs)
apiresp.GinError(c, errs.ErrInternalServer.WrapMsg("failed to find wallets"))
return
}
// 转换为响应格式
wallets = make([]*apistruct.WalletInfo, 0, len(walletModels))
for _, wallet := range walletModels {
wallets = append(wallets, &apistruct.WalletInfo{
UserID: wallet.UserID,
Balance: wallet.Balance,
CreateTime: wallet.CreateTime.UnixMilli(),
UpdateTime: wallet.UpdateTime.UnixMilli(),
})
}
total = int64(len(wallets))
} else {
// 如果没有任何查询条件,查询所有钱包(带分页)
var walletModels []*model.Wallet
total, walletModels, err = w.walletDB.FindAllWallets(c, pagination)
if err != nil {
log.ZError(c, "GetWallets: failed to find wallets", err)
apiresp.GinError(c, errs.ErrInternalServer.WrapMsg("failed to find wallets"))
return
}
// 转换为响应格式
wallets = make([]*apistruct.WalletInfo, 0, len(walletModels))
for _, wallet := range walletModels {
wallets = append(wallets, &apistruct.WalletInfo{
UserID: wallet.UserID,
Balance: wallet.Balance,
CreateTime: wallet.CreateTime.UnixMilli(),
UpdateTime: wallet.UpdateTime.UnixMilli(),
})
}
}
// 收集所有用户ID
userIDMap := make(map[string]bool)
for _, wallet := range wallets {
if wallet.UserID != "" {
userIDMap[wallet.UserID] = true
}
}
// 批量查询用户信息
userIDList := make([]string, 0, len(userIDMap))
for userID := range userIDMap {
userIDList = append(userIDList, userID)
}
userInfoMap := make(map[string]*apistruct.WalletInfo) // userID -> WalletInfo (with nickname and faceURL)
if len(userIDList) > 0 {
userInfos, err := w.userClient.GetUsersInfo(c, userIDList)
if err == nil && len(userInfos) > 0 {
for _, userInfo := range userInfos {
if userInfo != nil {
userInfoMap[userInfo.UserID] = &apistruct.WalletInfo{
Nickname: userInfo.Nickname,
FaceURL: userInfo.FaceURL,
}
}
}
} else {
log.ZWarn(c, "GetWallets: failed to get users info", err, "userIDs", userIDList)
}
}
// 填充用户信息
for _, wallet := range wallets {
if userInfo, ok := userInfoMap[wallet.UserID]; ok {
wallet.Nickname = userInfo.Nickname
wallet.FaceURL = userInfo.FaceURL
}
}
// 填充响应
resp.Total = total
resp.Wallets = wallets
log.ZInfo(c, "GetWallets: success", "userID", req.UserID, "phoneNumber", req.PhoneNumber, "account", req.Account, "total", total, "count", len(resp.Wallets))
apiresp.GinSuccess(c, resp)
}
// BatchUpdateWalletBalance 批量修改用户余额(后台管理接口)
func (w *WalletApi) BatchUpdateWalletBalance(c *gin.Context) {
var (
req apistruct.BatchUpdateWalletBalanceReq
resp apistruct.BatchUpdateWalletBalanceResp
)
if err := c.BindJSON(&req); err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return
}
// 验证请求参数
if len(req.Users) == 0 {
apiresp.GinError(c, errs.ErrArgs.WrapMsg("users list cannot be empty"))
return
}
// 设置默认操作类型
defaultOperation := req.Operation
if defaultOperation == "" {
defaultOperation = "add" // 默认为增加
}
if defaultOperation != "" && defaultOperation != "set" && defaultOperation != "add" && defaultOperation != "subtract" {
apiresp.GinError(c, errs.ErrArgs.WrapMsg("operation must be 'set', 'add', or 'subtract'"))
return
}
// 处理每个用户
resp.Total = int32(len(req.Users))
resp.Results = make([]apistruct.WalletUpdateResult, 0, len(req.Users))
for _, user := range req.Users {
result := apistruct.WalletUpdateResult{
UserID: user.UserID.String(),
PhoneNumber: user.PhoneNumber,
Account: user.Account,
Success: false,
Remark: user.Remark, // 保存备注信息
}
// 确定用户ID
var userID string
if user.UserID != "" {
userID = user.UserID.String()
} else if user.PhoneNumber != "" || user.Account != "" {
// 通过手机号或账号查询用户ID
searchUserIDs, err := w.userDB.SearchUsersByFields(c, user.Account, user.PhoneNumber, "")
if err != nil {
result.Message = "failed to search user: " + err.Error()
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
if len(searchUserIDs) == 0 {
result.Message = "user not found"
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
if len(searchUserIDs) > 1 {
result.Message = "multiple users found, please use userID"
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
userID = searchUserIDs[0]
} else {
result.Message = "user identifier is required (userID, phoneNumber, or account)"
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
// 先验证用户是否存在
userExists, err := w.userDB.Exist(c, userID)
if err != nil {
result.Message = "failed to check user existence: " + err.Error()
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
if !userExists {
result.Message = "user not found"
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
// 查询当前钱包
wallet, err := w.walletDB.Take(c, userID)
if err != nil {
// 如果钱包不存在,但用户存在,创建新钱包
if errs.ErrRecordNotFound.Is(err) {
wallet = &model.Wallet{
UserID: userID,
Balance: 0,
CreateTime: time.Now(),
UpdateTime: time.Now(),
}
if err := w.walletDB.Create(c, wallet); err != nil {
result.Message = "failed to create wallet: " + err.Error()
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
} else {
result.Message = "failed to get wallet: " + err.Error()
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
}
// 记录旧余额和版本号(旧数据可能没有 version保持 0 以兼容)
result.OldBalance = wallet.Balance
oldVersion := wallet.Version
// 确定该用户使用的金额和操作类型
userAmount := user.Amount
if userAmount == 0 {
// 如果用户没有指定金额,使用默认金额
userAmount = req.Amount
}
if userAmount == 0 {
result.Message = "amount is required (either in user object or in request)"
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
userOperation := user.Operation
if userOperation == "" {
// 如果用户没有指定操作类型,使用默认操作类型
userOperation = defaultOperation
}
if userOperation != "set" && userOperation != "add" && userOperation != "subtract" {
result.Message = "operation must be 'set', 'add', or 'subtract'"
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
// 预检查:计算新余额并检查不能为负数
var expectedNewBalance int64
switch userOperation {
case "set":
expectedNewBalance = userAmount
case "add":
expectedNewBalance = wallet.Balance + userAmount
case "subtract":
expectedNewBalance = wallet.Balance - userAmount
}
if expectedNewBalance < 0 {
result.Message = "balance cannot be negative"
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
// 使用统一的余额更新方法(包含并发控制和余额记录创建)
// 如果发生并发冲突,重试一次(重新获取最新余额和版本号)
var newBalance int64
maxRetries := 2
for retry := 0; retry < maxRetries; retry++ {
if retry > 0 {
// 重试时重新获取钱包信息
wallet, err = w.walletDB.Take(c, userID)
if err != nil {
result.Message = "failed to get wallet for retry: " + err.Error()
break
}
result.OldBalance = wallet.Balance
oldVersion = wallet.Version
// 重新计算预期新余额
switch userOperation {
case "set":
expectedNewBalance = userAmount
case "add":
expectedNewBalance = wallet.Balance + userAmount
case "subtract":
expectedNewBalance = wallet.Balance - userAmount
}
if expectedNewBalance < 0 {
result.Message = "balance cannot be negative after retry"
break
}
}
newBalance, _, err = w.updateBalanceWithRecord(c, userID, userOperation, userAmount, result.OldBalance, oldVersion, user.Remark)
if err == nil {
// 更新成功
break
}
// 如果是并发冲突且还有重试机会,继续重试
if retry < maxRetries-1 && errs.ErrInternalServer.Is(err) {
log.ZWarn(c, "BatchUpdateWalletBalance: concurrent modification detected, retrying", err,
"userID", userID,
"retry", retry+1,
"maxRetries", maxRetries)
continue
}
// 其他错误或重试次数用完,返回错误
result.Message = "failed to update balance: " + err.Error()
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
if err != nil {
// 更新失败
resp.Results = append(resp.Results, result)
resp.Failed++
continue
}
// 更新成功
result.Success = true
result.NewBalance = newBalance
result.Message = "success"
// 备注信息已在初始化时设置
resp.Results = append(resp.Results, result)
resp.Success++
// 记录日志,包含备注信息
if user.Remark != "" {
log.ZInfo(c, "BatchUpdateWalletBalance: user balance updated",
"userID", userID,
"operation", userOperation,
"amount", userAmount,
"oldBalance", result.OldBalance,
"newBalance", newBalance,
"remark", user.Remark)
}
}
log.ZInfo(c, "BatchUpdateWalletBalance: success", "total", resp.Total, "success", resp.Success, "failed", resp.Failed)
apiresp.GinSuccess(c, resp)
}