Files
open-im-server-deploy/pkg/common/storage/database/mgo/wallet.go
kim.dev.6789 e50142a3b9 复制项目
2026-01-14 22:16:44 +08:00

232 lines
7.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mgo
import (
"context"
"time"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/db/pagination"
"github.com/openimsdk/tools/errs"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// WalletMgo is the MongoDB-backed repository returned by NewWalletMongo as a
// database.Wallet. Every method operates on the single collection in coll.
type WalletMgo struct {
	// coll is the wallet collection (database.WalletName).
	coll *mongo.Collection
}
// NewWalletMongo binds a WalletMgo to the wallet collection of db after making
// sure the collection's indexes exist.
//
// Indexes ensured:
//   - user_id ascending, unique
//   - create_time descending
//   - update_time descending
//
// If index creation fails, that error is returned and no repository is built.
func NewWalletMongo(db *mongo.Database) (database.Wallet, error) {
	collection := db.Collection(database.WalletName)
	indexes := []mongo.IndexModel{
		{Keys: bson.D{{Key: "user_id", Value: 1}}, Options: options.Index().SetUnique(true)},
		{Keys: bson.D{{Key: "create_time", Value: -1}}},
		{Keys: bson.D{{Key: "update_time", Value: -1}}},
	}
	if _, err := collection.Indexes().CreateMany(context.Background(), indexes); err != nil {
		return nil, err
	}
	return &WalletMgo{coll: collection}, nil
}
// Create inserts wallet as a new document.
//
// Before inserting, it fills in defaults directly on the passed struct, so the
// caller observes them afterwards:
//   - a zero CreateTime and/or UpdateTime is set to the same single instant;
//   - a zero Version is set to 1, so the wallet takes part in the optimistic
//     version check used by UpdateBalanceWithVersion from its first write.
//
// The error from mongoutil.InsertOne is returned as-is (for example when the
// unique user_id index rejects a second wallet for the same user).
func (w *WalletMgo) Create(ctx context.Context, wallet *model.Wallet) error {
	// Read the clock once so a fresh wallet gets identical create/update
	// times rather than two values a few nanoseconds apart.
	now := time.Now()
	if wallet.CreateTime.IsZero() {
		wallet.CreateTime = now
	}
	if wallet.UpdateTime.IsZero() {
		wallet.UpdateTime = now
	}
	if wallet.Version == 0 {
		wallet.Version = 1
	}
	return mongoutil.InsertOne(ctx, w.coll, wallet)
}
// Take loads the wallet whose user_id equals userID. If no document matches,
// the not-found error produced by mongoutil.FindOne is returned.
func (w *WalletMgo) Take(ctx context.Context, userID string) (*model.Wallet, error) {
	byUser := bson.M{"user_id": userID}
	return mongoutil.FindOne[*model.Wallet](ctx, w.coll, byUser)
}
// UpdateBalance overwrites the balance of userID's wallet with balance and
// refreshes update_time.
//
// The version counter is also incremented. UpdateBalanceWithVersion detects
// concurrent modification by comparing versions. An overwrite that left
// version unchanged would let a stale versioned write pass its check and
// silently replace this balance.
//
// NOTE(review): the trailing false is passed to mongoutil.UpdateOne unchanged;
// it presumably means "do not error when no document matches" — confirm.
func (w *WalletMgo) UpdateBalance(ctx context.Context, userID string, balance int64) error {
	update := bson.M{
		"$set": bson.M{
			"balance":     balance,
			"update_time": time.Now(),
		},
		// $inc on a missing field creates it with value 1, so legacy
		// documents without a version are handled too.
		"$inc": bson.M{
			"version": 1,
		},
	}
	return mongoutil.UpdateOne(ctx, w.coll, bson.M{"user_id": userID}, update, false)
}
// UpdateBalanceByAmount atomically adds amount (negative to subtract) to the
// balance of userID's wallet and refreshes update_time. No lower bound is
// enforced on the resulting balance.
//
// The version counter is incremented in the same update. UpdateBalanceWithVersion
// relies on every balance write changing version; without the bump, a
// concurrent versioned "set" holding the old version could overwrite this
// delta.
//
// NOTE(review): the trailing false is passed to mongoutil.UpdateOne unchanged;
// it presumably means "do not error when no document matches" — confirm.
func (w *WalletMgo) UpdateBalanceByAmount(ctx context.Context, userID string, amount int64) error {
	update := bson.M{
		"$inc": bson.M{
			"balance": amount,
			"version": 1,
		},
		"$set": bson.M{
			"update_time": time.Now(),
		},
	}
	return mongoutil.UpdateOne(ctx, w.coll, bson.M{"user_id": userID}, update, false)
}
// UpdateBalanceWithVersion changes a wallet balance under an optimistic
// version check, so concurrent writers cannot overwrite each other. The whole
// change is a single-document findOneAndUpdate, which is atomic.
//
// params.Operation selects the change:
//   - "add":      balance += Amount
//   - "subtract": balance -= Amount, only if the current balance >= Amount
//   - "set":      balance  = Amount
//
// In every case update_time is refreshed and version is incremented by one.
// The update matches only when the stored version equals params.OldVersion.
// Legacy documents (version field missing, or 0) are also accepted so wallets
// created before versioning can still be updated.
//
// On success it returns the balance and version read back after the update.
// Errors:
//   - errs.ErrArgs if params is nil, Amount is negative, or Operation is unknown;
//   - errs.ErrInternalServer if no document matched. This covers a missing
//     wallet, a version changed by another writer, or, for "subtract", an
//     insufficient balance;
//   - any other driver error, wrapped with errs.Wrap.
func (w *WalletMgo) UpdateBalanceWithVersion(ctx context.Context, params *database.WalletUpdateParams) (*database.WalletUpdateResult, error) {
	if params == nil {
		return nil, errs.ErrArgs.WrapMsg("params cannot be nil")
	}
	if params.Amount < 0 {
		return nil, errs.ErrArgs.WrapMsg("amount cannot be negative")
	}

	// Match the exact version the caller read, or a legacy document that has
	// no version yet (field missing or 0).
	filter := bson.M{
		"user_id": params.UserID,
		"$or": []bson.M{
			{"version": params.OldVersion},
			{"version": bson.M{"$exists": false}},
			{"version": 0},
		},
	}

	set := bson.M{"update_time": time.Now()}
	// $inc on a missing field creates it with value 1, so legacy documents
	// end up at version 1 and versioned ones at OldVersion+1.
	inc := bson.M{"version": 1}

	switch params.Operation {
	case "add":
		inc["balance"] = params.Amount
	case "subtract":
		inc["balance"] = -params.Amount
		// Make the filter fail (no match) rather than let the balance go negative.
		filter["balance"] = bson.M{"$gte": params.Amount}
	case "set":
		// Only balance goes into $set. version must stay solely under $inc:
		// MongoDB rejects an update that names the same path in both $set and
		// $inc with a path-conflict error.
		set["balance"] = params.Amount
	default:
		return nil, errs.ErrArgs.WrapMsg("invalid operation: " + params.Operation)
	}

	update := bson.M{
		"$set": set,
		"$inc": inc,
	}

	// ReturnDocument(After) makes Decode see the post-update state.
	opts := options.FindOneAndUpdate().SetReturnDocument(options.After)
	var updatedWallet model.Wallet
	err := w.coll.FindOneAndUpdate(ctx, filter, update, opts).Decode(&updatedWallet)
	if err != nil {
		// SingleResult.Decode returns the driver's ErrNoDocuments sentinel unwrapped.
		if err == mongo.ErrNoDocuments {
			return nil, errs.ErrInternalServer.WrapMsg("concurrent modification detected: version or balance mismatch")
		}
		return nil, errs.Wrap(err)
	}

	return &database.WalletUpdateResult{
		NewBalance: updatedWallet.Balance,
		NewVersion: updatedWallet.Version,
		Success:    true,
	}, nil
}
// FindAllWallets returns one page of wallets across all users, ordered newest
// first (create_time descending). It also returns the total count of matching
// documents as reported by mongoutil.FindPage.
func (w *WalletMgo) FindAllWallets(ctx context.Context, pagination pagination.Pagination) (total int64, wallets []*model.Wallet, err error) {
	newestFirst := options.Find().SetSort(bson.D{{Key: "create_time", Value: -1}})
	matchAll := bson.M{}
	return mongoutil.FindPage[*model.Wallet](ctx, w.coll, matchAll, pagination, newestFirst)
}
// FindWalletsByUserIDs returns the wallets whose user_id appears in userIDs.
// Users that have no wallet are simply absent from the result. An empty or nil
// userIDs returns an empty, non-nil slice without querying MongoDB.
func (w *WalletMgo) FindWalletsByUserIDs(ctx context.Context, userIDs []string) ([]*model.Wallet, error) {
	if len(userIDs) > 0 {
		inUsers := bson.M{"user_id": bson.M{"$in": userIDs}}
		return mongoutil.Find[*model.Wallet](ctx, w.coll, inUsers)
	}
	return []*model.Wallet{}, nil
}
// WalletBalanceRecordMgo is the MongoDB-backed repository returned by
// NewWalletBalanceRecordMongo as a database.WalletBalanceRecord.
type WalletBalanceRecordMgo struct {
	// coll is the balance-record collection (database.WalletBalanceRecordName).
	coll *mongo.Collection
}
// NewWalletBalanceRecordMongo binds a WalletBalanceRecordMgo to the balance
// record collection of db after making sure its indexes exist.
//
// Indexes ensured:
//   - (user_id asc, create_time desc)
//   - record_id asc, unique + sparse
//   - create_time desc
//
// Older deployments may already have a record_id_1 index that is not
// unique+sparse. CreateMany would then fail with an index options conflict,
// so such an index is dropped first. A correct existing index is left alone,
// so the unique constraint is not dropped and rebuilt on every start-up.
//
// NOTE(review): a sparse index skips only documents where record_id is absent.
// Documents storing an explicit null (or a repeated "") are still indexed and
// would collide under the unique constraint. Confirm that the model omits the
// field when it is empty.
func NewWalletBalanceRecordMongo(db *mongo.Database) (database.WalletBalanceRecord, error) {
	ctx := context.Background()
	coll := db.Collection(database.WalletBalanceRecordName)

	if err := dropStaleRecordIDIndex(ctx, coll); err != nil {
		return nil, err
	}

	_, err := coll.Indexes().CreateMany(ctx, []mongo.IndexModel{
		{
			Keys: bson.D{{Key: "user_id", Value: 1}, {Key: "create_time", Value: -1}},
		},
		{
			Keys:    bson.D{{Key: "record_id", Value: 1}},
			Options: options.Index().SetUnique(true).SetSparse(true),
		},
		{
			Keys: bson.D{{Key: "create_time", Value: -1}},
		},
	})
	if err != nil {
		return nil, err
	}
	return &WalletBalanceRecordMgo{coll: coll}, nil
}

// dropStaleRecordIDIndex drops the record_id_1 index only when it exists and
// is not both unique and sparse. A missing index is not an error.
func dropStaleRecordIDIndex(ctx context.Context, coll *mongo.Collection) error {
	specs, err := coll.Indexes().ListSpecifications(ctx)
	if err != nil {
		return err
	}
	for _, spec := range specs {
		if spec.Name != "record_id_1" {
			continue
		}
		unique := spec.Unique != nil && *spec.Unique
		sparse := spec.Sparse != nil && *spec.Sparse
		if unique && sparse {
			// Already in the desired shape; keep it (and its uniqueness guarantee).
			return nil
		}
		_, err := coll.Indexes().DropOne(ctx, spec.Name)
		return err
	}
	return nil
}
// Create inserts record as a new balance-record document. A zero CreateTime is
// first replaced, on the passed struct itself, with the current time. The
// error from mongoutil.InsertOne is returned unchanged.
func (w *WalletBalanceRecordMgo) Create(ctx context.Context, record *model.WalletBalanceRecord) error {
	if !record.CreateTime.IsZero() {
		return mongoutil.InsertOne(ctx, w.coll, record)
	}
	record.CreateTime = time.Now()
	return mongoutil.InsertOne(ctx, w.coll, record)
}
// FindByUserID returns one page of userID's balance records, newest first
// (create_time descending). It also returns the total count reported by
// mongoutil.FindPage. The filter and sort match the (user_id, create_time)
// compound index built in NewWalletBalanceRecordMongo.
func (w *WalletBalanceRecordMgo) FindByUserID(ctx context.Context, userID string, pagination pagination.Pagination) (total int64, records []*model.WalletBalanceRecord, err error) {
	newestFirst := options.Find().SetSort(bson.D{{Key: "create_time", Value: -1}})
	ownedBy := bson.M{"user_id": userID}
	return mongoutil.FindPage[*model.WalletBalanceRecord](ctx, w.coll, ownedBy, pagination, newestFirst)
}