复制项目
This commit is contained in:
231
pkg/common/storage/database/mgo/wallet.go
Normal file
231
pkg/common/storage/database/mgo/wallet.go
Normal file
@@ -0,0 +1,231 @@
|
||||
// Copyright © 2023 OpenIM. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mgo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
|
||||
"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
|
||||
"github.com/openimsdk/tools/db/mongoutil"
|
||||
"github.com/openimsdk/tools/db/pagination"
|
||||
"github.com/openimsdk/tools/errs"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
// WalletMgo implements Wallet using MongoDB as the storage backend.
|
||||
type WalletMgo struct {
|
||||
coll *mongo.Collection
|
||||
}
|
||||
|
||||
// NewWalletMongo creates a new instance of WalletMgo with the provided MongoDB database.
|
||||
func NewWalletMongo(db *mongo.Database) (database.Wallet, error) {
|
||||
coll := db.Collection(database.WalletName)
|
||||
_, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{
|
||||
{
|
||||
Keys: bson.D{{Key: "user_id", Value: 1}},
|
||||
Options: options.Index().SetUnique(true),
|
||||
},
|
||||
{
|
||||
Keys: bson.D{{Key: "create_time", Value: -1}},
|
||||
},
|
||||
{
|
||||
Keys: bson.D{{Key: "update_time", Value: -1}},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &WalletMgo{coll: coll}, nil
|
||||
}
|
||||
|
||||
// Create creates a new wallet record.
|
||||
func (w *WalletMgo) Create(ctx context.Context, wallet *model.Wallet) error {
|
||||
if wallet.CreateTime.IsZero() {
|
||||
wallet.CreateTime = time.Now()
|
||||
}
|
||||
if wallet.UpdateTime.IsZero() {
|
||||
wallet.UpdateTime = time.Now()
|
||||
}
|
||||
if wallet.Version == 0 {
|
||||
wallet.Version = 1
|
||||
}
|
||||
return mongoutil.InsertOne(ctx, w.coll, wallet)
|
||||
}
|
||||
|
||||
// Take retrieves a wallet by user ID. Returns an error if not found.
|
||||
func (w *WalletMgo) Take(ctx context.Context, userID string) (*model.Wallet, error) {
|
||||
return mongoutil.FindOne[*model.Wallet](ctx, w.coll, bson.M{"user_id": userID})
|
||||
}
|
||||
|
||||
// UpdateBalance updates the balance of a wallet.
|
||||
func (w *WalletMgo) UpdateBalance(ctx context.Context, userID string, balance int64) error {
|
||||
update := bson.M{
|
||||
"$set": bson.M{
|
||||
"balance": balance,
|
||||
"update_time": time.Now(),
|
||||
},
|
||||
}
|
||||
return mongoutil.UpdateOne(ctx, w.coll, bson.M{"user_id": userID}, update, false)
|
||||
}
|
||||
|
||||
// UpdateBalanceByAmount updates the balance by adding/subtracting an amount.
|
||||
func (w *WalletMgo) UpdateBalanceByAmount(ctx context.Context, userID string, amount int64) error {
|
||||
update := bson.M{
|
||||
"$inc": bson.M{
|
||||
"balance": amount,
|
||||
},
|
||||
"$set": bson.M{
|
||||
"update_time": time.Now(),
|
||||
},
|
||||
}
|
||||
return mongoutil.UpdateOne(ctx, w.coll, bson.M{"user_id": userID}, update, false)
|
||||
}
|
||||
|
||||
// UpdateBalanceWithVersion 使用版本号更新余额(防止并发覆盖)
|
||||
func (w *WalletMgo) UpdateBalanceWithVersion(ctx context.Context, params *database.WalletUpdateParams) (*database.WalletUpdateResult, error) {
|
||||
// 基于单文档原子操作,避免并发覆盖
|
||||
if params.Amount < 0 {
|
||||
return nil, errs.ErrArgs.WrapMsg("amount cannot be negative")
|
||||
}
|
||||
|
||||
// 兼容旧数据(无 version 字段或 version=0)
|
||||
filter := bson.M{
|
||||
"user_id": params.UserID,
|
||||
"$or": []bson.M{
|
||||
{"version": params.OldVersion},
|
||||
{"version": bson.M{"$exists": false}},
|
||||
{"version": 0},
|
||||
},
|
||||
}
|
||||
|
||||
// 默认更新:更新时间、版本号自增
|
||||
update := bson.M{
|
||||
"$set": bson.M{
|
||||
"update_time": time.Now(),
|
||||
},
|
||||
"$inc": bson.M{
|
||||
"version": 1,
|
||||
},
|
||||
}
|
||||
|
||||
switch params.Operation {
|
||||
case "add":
|
||||
update["$inc"].(bson.M)["balance"] = params.Amount
|
||||
case "subtract":
|
||||
update["$inc"].(bson.M)["balance"] = -params.Amount
|
||||
// 防止出现负数,过滤条件要求当前余额 >= amount
|
||||
filter["balance"] = bson.M{"$gte": params.Amount}
|
||||
case "set":
|
||||
// 直接设置余额,同时递增版本
|
||||
delete(update, "$inc")
|
||||
update["$set"].(bson.M)["balance"] = params.Amount
|
||||
update["$set"].(bson.M)["version"] = params.OldVersion + 1
|
||||
default:
|
||||
return nil, errs.ErrArgs.WrapMsg("invalid operation: " + params.Operation)
|
||||
}
|
||||
|
||||
// 如果 set 分支删除了 $inc,需要补上版本递增
|
||||
if _, ok := update["$inc"]; !ok {
|
||||
update["$inc"] = bson.M{"version": 1}
|
||||
} else if _, ok := update["$inc"].(bson.M)["version"]; !ok {
|
||||
update["$inc"].(bson.M)["version"] = 1
|
||||
}
|
||||
|
||||
// 使用 findOneAndUpdate 返回更新后的文档
|
||||
opts := options.FindOneAndUpdate().SetReturnDocument(options.After)
|
||||
var updatedWallet model.Wallet
|
||||
err := w.coll.FindOneAndUpdate(ctx, filter, update, opts).Decode(&updatedWallet)
|
||||
if err != nil {
|
||||
if err == mongo.ErrNoDocuments {
|
||||
// 版本号或余额不匹配,说明有并发修改
|
||||
return nil, errs.ErrInternalServer.WrapMsg("concurrent modification detected: version or balance mismatch")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &database.WalletUpdateResult{
|
||||
NewBalance: updatedWallet.Balance,
|
||||
NewVersion: updatedWallet.Version,
|
||||
Success: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// FindAllWallets finds all wallets with pagination.
|
||||
func (w *WalletMgo) FindAllWallets(ctx context.Context, pagination pagination.Pagination) (total int64, wallets []*model.Wallet, err error) {
|
||||
return mongoutil.FindPage[*model.Wallet](ctx, w.coll, bson.M{}, pagination, &options.FindOptions{
|
||||
Sort: bson.D{{Key: "create_time", Value: -1}},
|
||||
})
|
||||
}
|
||||
|
||||
// FindWalletsByUserIDs finds wallets by user IDs.
|
||||
func (w *WalletMgo) FindWalletsByUserIDs(ctx context.Context, userIDs []string) ([]*model.Wallet, error) {
|
||||
if len(userIDs) == 0 {
|
||||
return []*model.Wallet{}, nil
|
||||
}
|
||||
filter := bson.M{"user_id": bson.M{"$in": userIDs}}
|
||||
return mongoutil.Find[*model.Wallet](ctx, w.coll, filter)
|
||||
}
|
||||
|
||||
// WalletBalanceRecordMgo implements WalletBalanceRecord using MongoDB as the storage backend.
|
||||
type WalletBalanceRecordMgo struct {
|
||||
coll *mongo.Collection
|
||||
}
|
||||
|
||||
// NewWalletBalanceRecordMongo creates a new instance of WalletBalanceRecordMgo with the provided MongoDB database.
|
||||
func NewWalletBalanceRecordMongo(db *mongo.Database) (database.WalletBalanceRecord, error) {
|
||||
coll := db.Collection(database.WalletBalanceRecordName)
|
||||
|
||||
// 先尝试删除可能存在的旧 record_id 索引(如果不是稀疏索引)
|
||||
// 忽略删除失败的错误(索引可能不存在)
|
||||
_, _ = coll.Indexes().DropOne(context.Background(), "record_id_1")
|
||||
|
||||
// 创建索引,使用稀疏索引允许 record_id 为 null
|
||||
_, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{
|
||||
{
|
||||
Keys: bson.D{{Key: "user_id", Value: 1}, {Key: "create_time", Value: -1}},
|
||||
},
|
||||
{
|
||||
Keys: bson.D{{Key: "record_id", Value: 1}},
|
||||
Options: options.Index().SetUnique(true).SetSparse(true),
|
||||
},
|
||||
{
|
||||
Keys: bson.D{{Key: "create_time", Value: -1}},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &WalletBalanceRecordMgo{coll: coll}, nil
|
||||
}
|
||||
|
||||
// Create creates a new wallet balance record.
|
||||
func (w *WalletBalanceRecordMgo) Create(ctx context.Context, record *model.WalletBalanceRecord) error {
|
||||
if record.CreateTime.IsZero() {
|
||||
record.CreateTime = time.Now()
|
||||
}
|
||||
return mongoutil.InsertOne(ctx, w.coll, record)
|
||||
}
|
||||
|
||||
// FindByUserID finds all balance records for a user with pagination.
|
||||
func (w *WalletBalanceRecordMgo) FindByUserID(ctx context.Context, userID string, pagination pagination.Pagination) (total int64, records []*model.WalletBalanceRecord, err error) {
|
||||
filter := bson.M{"user_id": userID}
|
||||
return mongoutil.FindPage[*model.WalletBalanceRecord](ctx, w.coll, filter, pagination, &options.FindOptions{
|
||||
Sort: bson.D{{Key: "create_time", Value: -1}},
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user