Files
open-im-server-deploy/internal/api/wallet.go
kim.dev.6789 e50142a3b9 复制项目
2026-01-14 22:16:44 +08:00

524 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"strconv"
"time"
"github.com/gin-gonic/gin"
"git.imall.cloud/openim/open-im-server-deploy/pkg/apistruct"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"git.imall.cloud/openim/open-im-server-deploy/pkg/rpcli"
"github.com/openimsdk/tools/apiresp"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/utils/idutil"
"github.com/openimsdk/tools/utils/timeutil"
)
// WalletApi serves the admin-facing wallet HTTP endpoints: listing wallets and
// adjusting balances in bulk.
type WalletApi struct {
	walletDB              database.Wallet              // wallet rows (balance and version)
	walletBalanceRecordDB database.WalletBalanceRecord // audit records of balance changes
	userDB                database.User                // user search and existence checks
	userClient            *rpcli.UserClient            // user RPC, used for nicknames and avatars
}
// NewWalletApi wires a WalletApi to its storage layers and the user RPC client.
// The dependencies are stored as given; no validation is performed.
func NewWalletApi(walletDB database.Wallet, walletBalanceRecordDB database.WalletBalanceRecord, userDB database.User, userClient *rpcli.UserClient) *WalletApi {
	wa := new(WalletApi)
	wa.walletDB = walletDB
	wa.walletBalanceRecordDB = walletBalanceRecordDB
	wa.userDB = userDB
	wa.userClient = userClient
	return wa
}
// updateBalanceWithRecord is the single entry point for changing a wallet balance.
// It updates the balance through walletDB.UpdateBalanceWithVersion, passing the
// caller-observed oldBalance/oldVersion so the storage layer can reject concurrent
// modifications. On success it then writes a balance-change record.
//
// Parameters:
//   - userID: owner of the wallet to update.
//   - operation: "set", "add" or "subtract".
//   - amount: amount in cents (分); for "set" it is the target balance.
//   - oldBalance, oldVersion: the wallet state the caller read before deciding on the change.
//   - remark: free-form note stored on the balance record.
//
// It returns the new balance and new version reported by the storage layer. An error
// is returned only when the balance update itself fails. Failing to write the balance
// record is logged as a warning and does not fail the call.
func (w *WalletApi) updateBalanceWithRecord(ctx *gin.Context, userID string, operation string, amount int64, oldBalance int64, oldVersion int64, remark string) (newBalance int64, newVersion int64, err error) {
	// Version-checked update, which prevents concurrent overwrites.
	params := &database.WalletUpdateParams{
		UserID:     userID,
		Operation:  operation,
		Amount:     amount,
		OldBalance: oldBalance,
		OldVersion: oldVersion,
	}
	result, err := w.walletDB.UpdateBalanceWithVersion(ctx, params)
	if err != nil {
		return 0, 0, err
	}

	// Signed change stored on the balance record.
	var recordAmount int64
	switch operation {
	case "set":
		recordAmount = result.NewBalance - oldBalance
	case "add":
		recordAmount = amount
	case "subtract":
		recordAmount = -amount
	}

	// The record ID must be unique per update. Hashing only the formatted wall-clock
	// time plus operation and amount lets two identical operations for the same user
	// collide whenever they fall within the formatted time's resolution (presumably
	// seconds; confirm). The colliding Create then fails and the record is silently
	// lost. A nanosecond timestamp and the version transition keep IDs distinct.
	recordID := idutil.GetMsgIDByMD5(userID +
		timeutil.GetCurrentTimeFormatted() +
		strconv.FormatInt(time.Now().UnixNano(), 10) +
		operation +
		strconv.FormatInt(amount, 10) +
		strconv.FormatInt(oldVersion, 10) + ":" + strconv.FormatInt(result.NewVersion, 10))

	// Every operation is currently recorded with type 99 ("other"). A more specific
	// business type would have to be supplied by the caller.
	const recordType int32 = 99

	balanceRecord := &model.WalletBalanceRecord{
		ID:            recordID,
		UserID:        userID,
		Amount:        recordAmount,
		Type:          recordType,
		BeforeBalance: oldBalance,
		AfterBalance:  result.NewBalance,
		OrderID:       "",
		TransactionID: "",
		RedPacketID:   "",
		Remark:        remark,
		CreateTime:    time.Now(),
	}
	if err := w.walletBalanceRecordDB.Create(ctx, balanceRecord); err != nil {
		// Best effort: the balance has already been updated, so a missing record
		// must not turn the call into a failure. Only log it.
		log.ZWarn(ctx, "updateBalanceWithRecord: failed to create balance record", err,
			"userID", userID,
			"operation", operation,
			"amount", amount,
			"oldBalance", oldBalance,
			"newBalance", result.NewBalance)
	}
	return result.NewBalance, result.NewVersion, nil
}
// walletPaginationWrapper adapts a page number and page size to the
// pagination.Pagination interface consumed by walletDB.FindAllWallets.
// Non-positive values fall back to page 1 and 20 items per page.
type walletPaginationWrapper struct {
	pageNumber int32
	showNumber int32
}

// GetPageNumber returns the requested page number, or 1 when pageNumber is not positive.
func (p *walletPaginationWrapper) GetPageNumber() int32 {
	if p.pageNumber > 0 {
		return p.pageNumber
	}
	return 1
}

// GetShowNumber returns the requested page size, or 20 when showNumber is not positive.
func (p *walletPaginationWrapper) GetShowNumber() int32 {
	if p.showNumber > 0 {
		return p.showNumber
	}
	return 20
}
// GetWallets lists user wallets for the admin console.
//
// When any of req.UserID, req.PhoneNumber or req.Account is set, only the wallets of
// the matching user(s) are returned. UserID is used directly if present; otherwise
// users are searched by account / phone number. These filtered results are not
// paginated, and Total is the number of wallets found. Without a filter, all wallets
// are returned page by page using req.Pagination (non-positive values default to
// page 1 and 20 items per page).
//
// Each wallet is enriched with its owner's nickname and avatar from the user service.
// This enrichment is best-effort: if the lookup fails, a warning is logged and the
// wallets are returned without profile data.
func (w *WalletApi) GetWallets(c *gin.Context) {
	var (
		req  apistruct.GetWalletsReq
		resp apistruct.GetWalletsResp
	)
	if err := c.BindJSON(&req); err != nil {
		apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
		return
	}

	// Apply default pagination parameters.
	if req.Pagination.PageNumber <= 0 {
		req.Pagination.PageNumber = 1
	}
	if req.Pagination.ShowNumber <= 0 {
		req.Pagination.ShowNumber = 20
	}

	var (
		total        int64
		walletModels []*model.Wallet
	)
	if req.UserID != "" || req.PhoneNumber != "" || req.Account != "" {
		// Filtered query: resolve the target user IDs first.
		userIDs := []string{req.UserID}
		if req.UserID == "" {
			found, err := w.userDB.SearchUsersByFields(c, req.Account, req.PhoneNumber, "")
			if err != nil {
				log.ZError(c, "GetWallets: failed to search users", err, "account", req.Account, "phoneNumber", req.PhoneNumber)
				apiresp.GinError(c, errs.ErrInternalServer.WrapMsg("failed to search users"))
				return
			}
			if len(found) == 0 {
				// No matching user: respond with an empty (non-null) list.
				resp.Total = 0
				resp.Wallets = []*apistruct.WalletInfo{}
				apiresp.GinSuccess(c, resp)
				return
			}
			userIDs = found
		}
		models, err := w.walletDB.FindWalletsByUserIDs(c, userIDs)
		if err != nil {
			log.ZError(c, "GetWallets: failed to find wallets by userIDs", err, "userIDs", userIDs)
			apiresp.GinError(c, errs.ErrInternalServer.WrapMsg("failed to find wallets"))
			return
		}
		walletModels = models
		total = int64(len(models))
	} else {
		// Unfiltered query: page through all wallets.
		pagination := &walletPaginationWrapper{
			pageNumber: req.Pagination.PageNumber,
			showNumber: req.Pagination.ShowNumber,
		}
		count, models, err := w.walletDB.FindAllWallets(c, pagination)
		if err != nil {
			log.ZError(c, "GetWallets: failed to find wallets", err)
			apiresp.GinError(c, errs.ErrInternalServer.WrapMsg("failed to find wallets"))
			return
		}
		total, walletModels = count, models
	}

	// Convert to the response shape. Also index the entries by owner so profile data
	// can be attached after a single batch lookup.
	wallets := make([]*apistruct.WalletInfo, 0, len(walletModels))
	byUser := make(map[string][]*apistruct.WalletInfo, len(walletModels))
	for _, wm := range walletModels {
		info := &apistruct.WalletInfo{
			UserID:     wm.UserID,
			Balance:    wm.Balance,
			CreateTime: wm.CreateTime.UnixMilli(),
			UpdateTime: wm.UpdateTime.UnixMilli(),
		}
		wallets = append(wallets, info)
		if wm.UserID != "" {
			byUser[wm.UserID] = append(byUser[wm.UserID], info)
		}
	}

	// Best-effort enrichment with nickname and avatar.
	if len(byUser) > 0 {
		userIDList := make([]string, 0, len(byUser))
		for userID := range byUser {
			userIDList = append(userIDList, userID)
		}
		userInfos, err := w.userClient.GetUsersInfo(c, userIDList)
		if err != nil {
			log.ZWarn(c, "GetWallets: failed to get users info", err, "userIDs", userIDList)
		} else {
			for _, userInfo := range userInfos {
				if userInfo == nil {
					continue
				}
				for _, info := range byUser[userInfo.UserID] {
					info.Nickname = userInfo.Nickname
					info.FaceURL = userInfo.FaceURL
				}
			}
		}
	}

	resp.Total = total
	resp.Wallets = wallets
	log.ZInfo(c, "GetWallets: success", "userID", req.UserID, "phoneNumber", req.PhoneNumber, "account", req.Account, "total", total, "count", len(resp.Wallets))
	apiresp.GinSuccess(c, resp)
}
// BatchUpdateWalletBalance adjusts the wallet balances of several users in one request
// (admin console endpoint).
//
// Each entry in req.Users identifies a user by UserID or, when UserID is empty, by
// PhoneNumber / Account, which must match exactly one user. A per-entry Amount or
// Operation overrides the request-level default; the default operation is "add".
// Supported operations are "set", "add" and "subtract". An update that would leave a
// negative balance is rejected. An existing user without a wallet gets one created
// with balance 0.
//
// Entries are processed independently. Every entry produces exactly one
// WalletUpdateResult, and resp.Success / resp.Failed count entries. The request itself
// only fails on malformed input: bad JSON, an empty user list, or an invalid default
// operation.
func (w *WalletApi) BatchUpdateWalletBalance(c *gin.Context) {
	var (
		req  apistruct.BatchUpdateWalletBalanceReq
		resp apistruct.BatchUpdateWalletBalanceResp
	)
	if err := c.BindJSON(&req); err != nil {
		apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
		return
	}
	if len(req.Users) == 0 {
		apiresp.GinError(c, errs.ErrArgs.WrapMsg("users list cannot be empty"))
		return
	}

	// isValidOperation reports whether op is a supported balance operation.
	isValidOperation := func(op string) bool {
		return op == "set" || op == "add" || op == "subtract"
	}
	// projectedBalance returns the balance op would produce from current.
	projectedBalance := func(op string, current, amount int64) int64 {
		switch op {
		case "set":
			return amount
		case "add":
			return current + amount
		case "subtract":
			return current - amount
		}
		return 0 // unreachable for validated operations
	}

	// Request-level default operation; individual entries may override it.
	defaultOperation := req.Operation
	if defaultOperation == "" {
		defaultOperation = "add"
	}
	if !isValidOperation(defaultOperation) {
		apiresp.GinError(c, errs.ErrArgs.WrapMsg("operation must be 'set', 'add', or 'subtract'"))
		return
	}

	resp.Total = int32(len(req.Users))
	resp.Results = make([]apistruct.WalletUpdateResult, 0, len(req.Users))

	// fail records a failed entry. Each entry must be recorded exactly once.
	fail := func(result apistruct.WalletUpdateResult, msg string) {
		result.Message = msg
		resp.Results = append(resp.Results, result)
		resp.Failed++
	}

	// Update attempts per entry: the first try plus one retry after a conflict.
	const maxRetries = 2

	for _, user := range req.Users {
		result := apistruct.WalletUpdateResult{
			UserID:      user.UserID.String(),
			PhoneNumber: user.PhoneNumber,
			Account:     user.Account,
			Success:     false,
			Remark:      user.Remark, // echoed back in the result
		}

		// Resolve the target user ID.
		var userID string
		switch {
		case user.UserID != "":
			userID = user.UserID.String()
		case user.PhoneNumber != "" || user.Account != "":
			found, err := w.userDB.SearchUsersByFields(c, user.Account, user.PhoneNumber, "")
			if err != nil {
				fail(result, "failed to search user: "+err.Error())
				continue
			}
			if len(found) == 0 {
				fail(result, "user not found")
				continue
			}
			if len(found) > 1 {
				fail(result, "multiple users found, please use userID")
				continue
			}
			userID = found[0]
		default:
			fail(result, "user identifier is required (userID, phoneNumber, or account)")
			continue
		}

		userExists, err := w.userDB.Exist(c, userID)
		if err != nil {
			fail(result, "failed to check user existence: "+err.Error())
			continue
		}
		if !userExists {
			fail(result, "user not found")
			continue
		}

		// Load the wallet, creating an empty one for a user who has none yet.
		wallet, err := w.walletDB.Take(c, userID)
		if err != nil {
			if !errs.ErrRecordNotFound.Is(err) {
				fail(result, "failed to get wallet: "+err.Error())
				continue
			}
			now := time.Now()
			wallet = &model.Wallet{
				UserID:     userID,
				Balance:    0,
				CreateTime: now,
				UpdateTime: now,
			}
			if err := w.walletDB.Create(c, wallet); err != nil {
				fail(result, "failed to create wallet: "+err.Error())
				continue
			}
		}

		// Snapshot the state this update is based on. Legacy rows may have no version
		// and stay at 0 for compatibility.
		result.OldBalance = wallet.Balance
		oldVersion := wallet.Version

		// NOTE: 0 means "not specified", so "set" to a zero balance cannot be expressed.
		userAmount := user.Amount
		if userAmount == 0 {
			userAmount = req.Amount
		}
		if userAmount == 0 {
			fail(result, "amount is required (either in user object or in request)")
			continue
		}
		userOperation := user.Operation
		if userOperation == "" {
			userOperation = defaultOperation
		}
		if !isValidOperation(userOperation) {
			fail(result, "operation must be 'set', 'add', or 'subtract'")
			continue
		}
		if projectedBalance(userOperation, wallet.Balance, userAmount) < 0 {
			fail(result, "balance cannot be negative")
			continue
		}

		// Version-checked update. After a conflict, reload the latest balance/version
		// and retry once. Every exit path below leaves exactly one outcome for this
		// entry: either updated == true, or result.Message describes the failure.
		var (
			newBalance int64
			updated    bool
		)
		for attempt := 0; attempt < maxRetries; attempt++ {
			if attempt > 0 {
				latest, err := w.walletDB.Take(c, userID)
				if err != nil {
					result.Message = "failed to get wallet for retry: " + err.Error()
					break
				}
				result.OldBalance = latest.Balance
				oldVersion = latest.Version
				if projectedBalance(userOperation, latest.Balance, userAmount) < 0 {
					result.Message = "balance cannot be negative after retry"
					break
				}
			}
			balance, _, err := w.updateBalanceWithRecord(c, userID, userOperation, userAmount, result.OldBalance, oldVersion, user.Remark)
			if err == nil {
				newBalance, updated = balance, true
				break
			}
			// NOTE(review): ErrInternalServer is treated as the version-conflict signal;
			// confirm that UpdateBalanceWithVersion reports conflicts this way.
			if attempt < maxRetries-1 && errs.ErrInternalServer.Is(err) {
				log.ZWarn(c, "BatchUpdateWalletBalance: concurrent modification detected, retrying", err,
					"userID", userID,
					"retry", attempt+1,
					"maxRetries", maxRetries)
				continue
			}
			// Non-retryable error or retries exhausted. Stop retrying; the failure is
			// recorded once, below.
			result.Message = "failed to update balance: " + err.Error()
			break
		}
		if !updated {
			resp.Results = append(resp.Results, result)
			resp.Failed++
			continue
		}

		result.Success = true
		result.NewBalance = newBalance
		result.Message = "success"
		resp.Results = append(resp.Results, result)
		resp.Success++

		if user.Remark != "" {
			log.ZInfo(c, "BatchUpdateWalletBalance: user balance updated",
				"userID", userID,
				"operation", userOperation,
				"amount", userAmount,
				"oldBalance", result.OldBalance,
				"newBalance", newBalance,
				"remark", user.Remark)
		}
	}

	log.ZInfo(c, "BatchUpdateWalletBalance: success", "total", resp.Total, "success", resp.Success, "failed", resp.Failed)
	apiresp.GinSuccess(c, resp)
}