524 lines
16 KiB
Go
524 lines
16 KiB
Go
// Copyright © 2023 OpenIM. All rights reserved.
|
||
//
|
||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||
// you may not use this file except in compliance with the License.
|
||
// You may obtain a copy of the License at
|
||
//
|
||
// http://www.apache.org/licenses/LICENSE-2.0
|
||
//
|
||
// Unless required by applicable law or agreed to in writing, software
|
||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
// See the License for the specific language governing permissions and
|
||
// limitations under the License.
|
||
|
||
package api
|
||
|
||
import (
|
||
"strconv"
|
||
"time"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
|
||
"git.imall.cloud/openim/open-im-server-deploy/pkg/apistruct"
|
||
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
|
||
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
|
||
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
|
||
"github.com/openimsdk/tools/apiresp"
|
||
"github.com/openimsdk/tools/errs"
|
||
"github.com/openimsdk/tools/log"
|
||
"github.com/openimsdk/tools/utils/idutil"
|
||
"github.com/openimsdk/tools/utils/timeutil"
|
||
)
|
||
|
||
type WalletApi struct {
|
||
walletDB database.Wallet
|
||
walletBalanceRecordDB database.WalletBalanceRecord
|
||
userDB database.User
|
||
userClient *rpcli.UserClient
|
||
}
|
||
|
||
func NewWalletApi(walletDB database.Wallet, walletBalanceRecordDB database.WalletBalanceRecord, userDB database.User, userClient *rpcli.UserClient) *WalletApi {
|
||
return &WalletApi{
|
||
walletDB: walletDB,
|
||
walletBalanceRecordDB: walletBalanceRecordDB,
|
||
userDB: userDB,
|
||
userClient: userClient,
|
||
}
|
||
}
|
||
|
||
// updateBalanceWithRecord 统一的余额更新方法,包含余额记录创建和并发控制
|
||
// operation: 操作类型(set/add/subtract)
|
||
// amount: 金额(分)
|
||
// oldBalance: 旧余额
|
||
// oldVersion: 旧版本号
|
||
// remark: 备注信息
|
||
// 返回新余额、新版本号和错误
|
||
func (w *WalletApi) updateBalanceWithRecord(ctx *gin.Context, userID string, operation string, amount int64, oldBalance int64, oldVersion int64, remark string) (newBalance int64, newVersion int64, err error) {
|
||
// 使用版本号更新余额(防止并发覆盖)
|
||
params := &database.WalletUpdateParams{
|
||
UserID: userID,
|
||
Operation: operation,
|
||
Amount: amount,
|
||
OldBalance: oldBalance,
|
||
OldVersion: oldVersion,
|
||
}
|
||
|
||
result, err := w.walletDB.UpdateBalanceWithVersion(ctx, params)
|
||
if err != nil {
|
||
return 0, 0, err
|
||
}
|
||
|
||
// 计算变动金额:用于余额记录
|
||
var recordAmount int64
|
||
switch operation {
|
||
case "set":
|
||
recordAmount = result.NewBalance - oldBalance
|
||
case "add":
|
||
recordAmount = amount
|
||
case "subtract":
|
||
recordAmount = -amount
|
||
}
|
||
|
||
// 创建余额记录(新字段)
|
||
recordID := idutil.GetMsgIDByMD5(userID + timeutil.GetCurrentTimeFormatted() + operation + strconv.FormatInt(amount, 10))
|
||
var recordType int32 = 99 // 默认其他
|
||
switch operation {
|
||
case "add":
|
||
recordType = 99 // 通用“加钱”,具体业务可在上层传入
|
||
case "subtract":
|
||
recordType = 99 // 通用“减钱”
|
||
case "set":
|
||
recordType = 99 // 设置,归为其他
|
||
}
|
||
balanceRecord := &model.WalletBalanceRecord{
|
||
ID: recordID,
|
||
UserID: userID,
|
||
Amount: recordAmount,
|
||
Type: recordType,
|
||
BeforeBalance: oldBalance,
|
||
AfterBalance: result.NewBalance,
|
||
OrderID: "",
|
||
TransactionID: "",
|
||
RedPacketID: "",
|
||
Remark: remark,
|
||
CreateTime: time.Now(),
|
||
}
|
||
if err := w.walletBalanceRecordDB.Create(ctx, balanceRecord); err != nil {
|
||
// 余额记录创建失败不影响主流程,只记录警告日志
|
||
log.ZWarn(ctx, "updateBalanceWithRecord: failed to create balance record", err,
|
||
"userID", userID,
|
||
"operation", operation,
|
||
"amount", amount,
|
||
"oldBalance", oldBalance,
|
||
"newBalance", result.NewBalance)
|
||
}
|
||
|
||
return result.NewBalance, result.NewVersion, nil
|
||
}
|
||
|
||
// walletPaginationWrapper satisfies the pagination.Pagination interface for
// wallet queries. Zero or negative stored values are replaced by defaults:
// page 1 and 20 items per page.
type walletPaginationWrapper struct {
	pageNumber int32 // requested page; <= 0 selects the default (1)
	showNumber int32 // requested page size; <= 0 selects the default (20)
}

// GetPageNumber returns the requested page number, or 1 when it is not positive.
func (p *walletPaginationWrapper) GetPageNumber() int32 {
	if page := p.pageNumber; page > 0 {
		return page
	}
	return 1
}

// GetShowNumber returns the requested page size, or 20 when it is not positive.
func (p *walletPaginationWrapper) GetShowNumber() int32 {
	if size := p.showNumber; size > 0 {
		return size
	}
	return 20
}
|
||
|
||
// GetWallets 查询用户钱包列表(后台管理接口)
|
||
func (w *WalletApi) GetWallets(c *gin.Context) {
|
||
var (
|
||
req apistruct.GetWalletsReq
|
||
resp apistruct.GetWalletsResp
|
||
)
|
||
if err := c.BindJSON(&req); err != nil {
|
||
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
|
||
return
|
||
}
|
||
|
||
// 设置默认分页参数
|
||
if req.Pagination.PageNumber <= 0 {
|
||
req.Pagination.PageNumber = 1
|
||
}
|
||
if req.Pagination.ShowNumber <= 0 {
|
||
req.Pagination.ShowNumber = 20
|
||
}
|
||
|
||
// 创建分页对象
|
||
pagination := &walletPaginationWrapper{
|
||
pageNumber: req.Pagination.PageNumber,
|
||
showNumber: req.Pagination.ShowNumber,
|
||
}
|
||
|
||
// 查询钱包列表
|
||
var total int64
|
||
var wallets []*apistruct.WalletInfo
|
||
var err error
|
||
|
||
// 判断是否有查询条件(用户ID、手机号、账号)
|
||
hasQueryCondition := req.UserID != "" || req.PhoneNumber != "" || req.Account != ""
|
||
|
||
if hasQueryCondition {
|
||
// 如果有查询条件,先通过条件查询用户ID
|
||
var userIDs []string
|
||
if req.UserID != "" {
|
||
// 如果直接提供了用户ID,直接使用
|
||
userIDs = []string{req.UserID}
|
||
} else {
|
||
// 通过手机号或账号查询用户ID
|
||
searchUserIDs, err := w.userDB.SearchUsersByFields(c, req.Account, req.PhoneNumber, "")
|
||
if err != nil {
|
||
log.ZError(c, "GetWallets: failed to search users", err, "account", req.Account, "phoneNumber", req.PhoneNumber)
|
||
apiresp.GinError(c, errs.ErrInternalServer.WrapMsg("failed to search users"))
|
||
return
|
||
}
|
||
if len(searchUserIDs) == 0 {
|
||
// 没有找到匹配的用户,返回空列表
|
||
resp.Total = 0
|
||
resp.Wallets = []*apistruct.WalletInfo{}
|
||
apiresp.GinSuccess(c, resp)
|
||
return
|
||
}
|
||
userIDs = searchUserIDs
|
||
}
|
||
|
||
// 根据查询到的用户ID列表查询钱包
|
||
walletModels, err := w.walletDB.FindWalletsByUserIDs(c, userIDs)
|
||
if err != nil {
|
||
log.ZError(c, "GetWallets: failed to find wallets by userIDs", err, "userIDs", userIDs)
|
||
apiresp.GinError(c, errs.ErrInternalServer.WrapMsg("failed to find wallets"))
|
||
return
|
||
}
|
||
|
||
// 转换为响应格式
|
||
wallets = make([]*apistruct.WalletInfo, 0, len(walletModels))
|
||
for _, wallet := range walletModels {
|
||
wallets = append(wallets, &apistruct.WalletInfo{
|
||
UserID: wallet.UserID,
|
||
Balance: wallet.Balance,
|
||
CreateTime: wallet.CreateTime.UnixMilli(),
|
||
UpdateTime: wallet.UpdateTime.UnixMilli(),
|
||
})
|
||
}
|
||
total = int64(len(wallets))
|
||
} else {
|
||
// 如果没有任何查询条件,查询所有钱包(带分页)
|
||
var walletModels []*model.Wallet
|
||
total, walletModels, err = w.walletDB.FindAllWallets(c, pagination)
|
||
if err != nil {
|
||
log.ZError(c, "GetWallets: failed to find wallets", err)
|
||
apiresp.GinError(c, errs.ErrInternalServer.WrapMsg("failed to find wallets"))
|
||
return
|
||
}
|
||
// 转换为响应格式
|
||
wallets = make([]*apistruct.WalletInfo, 0, len(walletModels))
|
||
for _, wallet := range walletModels {
|
||
wallets = append(wallets, &apistruct.WalletInfo{
|
||
UserID: wallet.UserID,
|
||
Balance: wallet.Balance,
|
||
CreateTime: wallet.CreateTime.UnixMilli(),
|
||
UpdateTime: wallet.UpdateTime.UnixMilli(),
|
||
})
|
||
}
|
||
}
|
||
|
||
// 收集所有用户ID
|
||
userIDMap := make(map[string]bool)
|
||
for _, wallet := range wallets {
|
||
if wallet.UserID != "" {
|
||
userIDMap[wallet.UserID] = true
|
||
}
|
||
}
|
||
|
||
// 批量查询用户信息
|
||
userIDList := make([]string, 0, len(userIDMap))
|
||
for userID := range userIDMap {
|
||
userIDList = append(userIDList, userID)
|
||
}
|
||
|
||
userInfoMap := make(map[string]*apistruct.WalletInfo) // userID -> WalletInfo (with nickname and faceURL)
|
||
if len(userIDList) > 0 {
|
||
userInfos, err := w.userClient.GetUsersInfo(c, userIDList)
|
||
if err == nil && len(userInfos) > 0 {
|
||
for _, userInfo := range userInfos {
|
||
if userInfo != nil {
|
||
userInfoMap[userInfo.UserID] = &apistruct.WalletInfo{
|
||
Nickname: userInfo.Nickname,
|
||
FaceURL: userInfo.FaceURL,
|
||
}
|
||
}
|
||
}
|
||
} else {
|
||
log.ZWarn(c, "GetWallets: failed to get users info", err, "userIDs", userIDList)
|
||
}
|
||
}
|
||
|
||
// 填充用户信息
|
||
for _, wallet := range wallets {
|
||
if userInfo, ok := userInfoMap[wallet.UserID]; ok {
|
||
wallet.Nickname = userInfo.Nickname
|
||
wallet.FaceURL = userInfo.FaceURL
|
||
}
|
||
}
|
||
|
||
// 填充响应
|
||
resp.Total = total
|
||
resp.Wallets = wallets
|
||
|
||
log.ZInfo(c, "GetWallets: success", "userID", req.UserID, "phoneNumber", req.PhoneNumber, "account", req.Account, "total", total, "count", len(resp.Wallets))
|
||
apiresp.GinSuccess(c, resp)
|
||
}
|
||
|
||
// BatchUpdateWalletBalance 批量修改用户余额(后台管理接口)
|
||
func (w *WalletApi) BatchUpdateWalletBalance(c *gin.Context) {
|
||
var (
|
||
req apistruct.BatchUpdateWalletBalanceReq
|
||
resp apistruct.BatchUpdateWalletBalanceResp
|
||
)
|
||
if err := c.BindJSON(&req); err != nil {
|
||
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
|
||
return
|
||
}
|
||
|
||
// 验证请求参数
|
||
if len(req.Users) == 0 {
|
||
apiresp.GinError(c, errs.ErrArgs.WrapMsg("users list cannot be empty"))
|
||
return
|
||
}
|
||
|
||
// 设置默认操作类型
|
||
defaultOperation := req.Operation
|
||
if defaultOperation == "" {
|
||
defaultOperation = "add" // 默认为增加
|
||
}
|
||
if defaultOperation != "" && defaultOperation != "set" && defaultOperation != "add" && defaultOperation != "subtract" {
|
||
apiresp.GinError(c, errs.ErrArgs.WrapMsg("operation must be 'set', 'add', or 'subtract'"))
|
||
return
|
||
}
|
||
|
||
// 处理每个用户
|
||
resp.Total = int32(len(req.Users))
|
||
resp.Results = make([]apistruct.WalletUpdateResult, 0, len(req.Users))
|
||
|
||
for _, user := range req.Users {
|
||
result := apistruct.WalletUpdateResult{
|
||
UserID: user.UserID.String(),
|
||
PhoneNumber: user.PhoneNumber,
|
||
Account: user.Account,
|
||
Success: false,
|
||
Remark: user.Remark, // 保存备注信息
|
||
}
|
||
|
||
// 确定用户ID
|
||
var userID string
|
||
if user.UserID != "" {
|
||
userID = user.UserID.String()
|
||
} else if user.PhoneNumber != "" || user.Account != "" {
|
||
// 通过手机号或账号查询用户ID
|
||
searchUserIDs, err := w.userDB.SearchUsersByFields(c, user.Account, user.PhoneNumber, "")
|
||
if err != nil {
|
||
result.Message = "failed to search user: " + err.Error()
|
||
resp.Results = append(resp.Results, result)
|
||
resp.Failed++
|
||
continue
|
||
}
|
||
if len(searchUserIDs) == 0 {
|
||
result.Message = "user not found"
|
||
resp.Results = append(resp.Results, result)
|
||
resp.Failed++
|
||
continue
|
||
}
|
||
if len(searchUserIDs) > 1 {
|
||
result.Message = "multiple users found, please use userID"
|
||
resp.Results = append(resp.Results, result)
|
||
resp.Failed++
|
||
continue
|
||
}
|
||
userID = searchUserIDs[0]
|
||
} else {
|
||
result.Message = "user identifier is required (userID, phoneNumber, or account)"
|
||
resp.Results = append(resp.Results, result)
|
||
resp.Failed++
|
||
continue
|
||
}
|
||
|
||
// 先验证用户是否存在
|
||
userExists, err := w.userDB.Exist(c, userID)
|
||
if err != nil {
|
||
result.Message = "failed to check user existence: " + err.Error()
|
||
resp.Results = append(resp.Results, result)
|
||
resp.Failed++
|
||
continue
|
||
}
|
||
if !userExists {
|
||
result.Message = "user not found"
|
||
resp.Results = append(resp.Results, result)
|
||
resp.Failed++
|
||
continue
|
||
}
|
||
|
||
// 查询当前钱包
|
||
wallet, err := w.walletDB.Take(c, userID)
|
||
if err != nil {
|
||
// 如果钱包不存在,但用户存在,创建新钱包
|
||
if errs.ErrRecordNotFound.Is(err) {
|
||
wallet = &model.Wallet{
|
||
UserID: userID,
|
||
Balance: 0,
|
||
CreateTime: time.Now(),
|
||
UpdateTime: time.Now(),
|
||
}
|
||
if err := w.walletDB.Create(c, wallet); err != nil {
|
||
result.Message = "failed to create wallet: " + err.Error()
|
||
resp.Results = append(resp.Results, result)
|
||
resp.Failed++
|
||
continue
|
||
}
|
||
} else {
|
||
result.Message = "failed to get wallet: " + err.Error()
|
||
resp.Results = append(resp.Results, result)
|
||
resp.Failed++
|
||
continue
|
||
}
|
||
}
|
||
|
||
// 记录旧余额和版本号(旧数据可能没有 version,保持 0 以兼容)
|
||
result.OldBalance = wallet.Balance
|
||
oldVersion := wallet.Version
|
||
|
||
// 确定该用户使用的金额和操作类型
|
||
userAmount := user.Amount
|
||
if userAmount == 0 {
|
||
// 如果用户没有指定金额,使用默认金额
|
||
userAmount = req.Amount
|
||
}
|
||
if userAmount == 0 {
|
||
result.Message = "amount is required (either in user object or in request)"
|
||
resp.Results = append(resp.Results, result)
|
||
resp.Failed++
|
||
continue
|
||
}
|
||
|
||
userOperation := user.Operation
|
||
if userOperation == "" {
|
||
// 如果用户没有指定操作类型,使用默认操作类型
|
||
userOperation = defaultOperation
|
||
}
|
||
if userOperation != "set" && userOperation != "add" && userOperation != "subtract" {
|
||
result.Message = "operation must be 'set', 'add', or 'subtract'"
|
||
resp.Results = append(resp.Results, result)
|
||
resp.Failed++
|
||
continue
|
||
}
|
||
|
||
// 预检查:计算新余额并检查不能为负数
|
||
var expectedNewBalance int64
|
||
switch userOperation {
|
||
case "set":
|
||
expectedNewBalance = userAmount
|
||
case "add":
|
||
expectedNewBalance = wallet.Balance + userAmount
|
||
case "subtract":
|
||
expectedNewBalance = wallet.Balance - userAmount
|
||
}
|
||
if expectedNewBalance < 0 {
|
||
result.Message = "balance cannot be negative"
|
||
resp.Results = append(resp.Results, result)
|
||
resp.Failed++
|
||
continue
|
||
}
|
||
|
||
// 使用统一的余额更新方法(包含并发控制和余额记录创建)
|
||
// 如果发生并发冲突,重试一次(重新获取最新余额和版本号)
|
||
var newBalance int64
|
||
maxRetries := 2
|
||
for retry := 0; retry < maxRetries; retry++ {
|
||
if retry > 0 {
|
||
// 重试时重新获取钱包信息
|
||
wallet, err = w.walletDB.Take(c, userID)
|
||
if err != nil {
|
||
result.Message = "failed to get wallet for retry: " + err.Error()
|
||
break
|
||
}
|
||
result.OldBalance = wallet.Balance
|
||
oldVersion = wallet.Version
|
||
// 重新计算预期新余额
|
||
switch userOperation {
|
||
case "set":
|
||
expectedNewBalance = userAmount
|
||
case "add":
|
||
expectedNewBalance = wallet.Balance + userAmount
|
||
case "subtract":
|
||
expectedNewBalance = wallet.Balance - userAmount
|
||
}
|
||
if expectedNewBalance < 0 {
|
||
result.Message = "balance cannot be negative after retry"
|
||
break
|
||
}
|
||
}
|
||
|
||
newBalance, _, err = w.updateBalanceWithRecord(c, userID, userOperation, userAmount, result.OldBalance, oldVersion, user.Remark)
|
||
if err == nil {
|
||
// 更新成功
|
||
break
|
||
}
|
||
|
||
// 如果是并发冲突且还有重试机会,继续重试
|
||
if retry < maxRetries-1 && errs.ErrInternalServer.Is(err) {
|
||
log.ZWarn(c, "BatchUpdateWalletBalance: concurrent modification detected, retrying", err,
|
||
"userID", userID,
|
||
"retry", retry+1,
|
||
"maxRetries", maxRetries)
|
||
continue
|
||
}
|
||
|
||
// 其他错误或重试次数用完,返回错误
|
||
result.Message = "failed to update balance: " + err.Error()
|
||
resp.Results = append(resp.Results, result)
|
||
resp.Failed++
|
||
continue
|
||
}
|
||
|
||
if err != nil {
|
||
// 更新失败
|
||
resp.Results = append(resp.Results, result)
|
||
resp.Failed++
|
||
continue
|
||
}
|
||
|
||
// 更新成功
|
||
result.Success = true
|
||
result.NewBalance = newBalance
|
||
result.Message = "success"
|
||
// 备注信息已在初始化时设置
|
||
resp.Results = append(resp.Results, result)
|
||
resp.Success++
|
||
|
||
// 记录日志,包含备注信息
|
||
if user.Remark != "" {
|
||
log.ZInfo(c, "BatchUpdateWalletBalance: user balance updated",
|
||
"userID", userID,
|
||
"operation", userOperation,
|
||
"amount", userAmount,
|
||
"oldBalance", result.OldBalance,
|
||
"newBalance", newBalance,
|
||
"remark", user.Remark)
|
||
}
|
||
}
|
||
|
||
log.ZInfo(c, "BatchUpdateWalletBalance: success", "total", resp.Total, "success", resp.Success, "failed", resp.Failed)
|
||
apiresp.GinSuccess(c, resp)
|
||
}
|