// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package mgo

import (
	"context"
	"errors"
	"time"

	"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database"
	"git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model"
	"github.com/openimsdk/tools/db/mongoutil"
	"github.com/openimsdk/tools/db/pagination"
	"github.com/openimsdk/tools/errs"
	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
)

// WalletMgo implements Wallet using MongoDB as the storage backend.
|
||
type WalletMgo struct {
|
||
coll *mongo.Collection
|
||
}
|
||
|
||
// NewWalletMongo creates a new instance of WalletMgo with the provided MongoDB database.
|
||
func NewWalletMongo(db *mongo.Database) (database.Wallet, error) {
|
||
coll := db.Collection(database.WalletName)
|
||
_, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{
|
||
{
|
||
Keys: bson.D{{Key: "user_id", Value: 1}},
|
||
Options: options.Index().SetUnique(true),
|
||
},
|
||
{
|
||
Keys: bson.D{{Key: "create_time", Value: -1}},
|
||
},
|
||
{
|
||
Keys: bson.D{{Key: "update_time", Value: -1}},
|
||
},
|
||
})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &WalletMgo{coll: coll}, nil
|
||
}
|
||
|
||
// Create creates a new wallet record.
|
||
func (w *WalletMgo) Create(ctx context.Context, wallet *model.Wallet) error {
|
||
if wallet.CreateTime.IsZero() {
|
||
wallet.CreateTime = time.Now()
|
||
}
|
||
if wallet.UpdateTime.IsZero() {
|
||
wallet.UpdateTime = time.Now()
|
||
}
|
||
if wallet.Version == 0 {
|
||
wallet.Version = 1
|
||
}
|
||
return mongoutil.InsertOne(ctx, w.coll, wallet)
|
||
}
|
||
|
||
// Take retrieves a wallet by user ID. Returns an error if not found.
|
||
func (w *WalletMgo) Take(ctx context.Context, userID string) (*model.Wallet, error) {
|
||
return mongoutil.FindOne[*model.Wallet](ctx, w.coll, bson.M{"user_id": userID})
|
||
}
|
||
|
||
// UpdateBalance updates the balance of a wallet.
|
||
func (w *WalletMgo) UpdateBalance(ctx context.Context, userID string, balance int64) error {
|
||
update := bson.M{
|
||
"$set": bson.M{
|
||
"balance": balance,
|
||
"update_time": time.Now(),
|
||
},
|
||
}
|
||
return mongoutil.UpdateOne(ctx, w.coll, bson.M{"user_id": userID}, update, false)
|
||
}
|
||
|
||
// UpdateBalanceByAmount updates the balance by adding/subtracting an amount.
|
||
func (w *WalletMgo) UpdateBalanceByAmount(ctx context.Context, userID string, amount int64) error {
|
||
update := bson.M{
|
||
"$inc": bson.M{
|
||
"balance": amount,
|
||
},
|
||
"$set": bson.M{
|
||
"update_time": time.Now(),
|
||
},
|
||
}
|
||
return mongoutil.UpdateOne(ctx, w.coll, bson.M{"user_id": userID}, update, false)
|
||
}
|
||
|
||
// UpdateBalanceWithVersion 使用版本号更新余额(防止并发覆盖)
|
||
func (w *WalletMgo) UpdateBalanceWithVersion(ctx context.Context, params *database.WalletUpdateParams) (*database.WalletUpdateResult, error) {
|
||
// 基于单文档原子操作,避免并发覆盖
|
||
if params.Amount < 0 {
|
||
return nil, errs.ErrArgs.WrapMsg("amount cannot be negative")
|
||
}
|
||
|
||
// 兼容旧数据(无 version 字段或 version=0)
|
||
filter := bson.M{
|
||
"user_id": params.UserID,
|
||
"$or": []bson.M{
|
||
{"version": params.OldVersion},
|
||
{"version": bson.M{"$exists": false}},
|
||
{"version": 0},
|
||
},
|
||
}
|
||
|
||
// 默认更新:更新时间、版本号自增
|
||
update := bson.M{
|
||
"$set": bson.M{
|
||
"update_time": time.Now(),
|
||
},
|
||
"$inc": bson.M{
|
||
"version": 1,
|
||
},
|
||
}
|
||
|
||
switch params.Operation {
|
||
case "add":
|
||
update["$inc"].(bson.M)["balance"] = params.Amount
|
||
case "subtract":
|
||
update["$inc"].(bson.M)["balance"] = -params.Amount
|
||
// 防止出现负数,过滤条件要求当前余额 >= amount
|
||
filter["balance"] = bson.M{"$gte": params.Amount}
|
||
case "set":
|
||
// 直接设置余额,同时递增版本
|
||
delete(update, "$inc")
|
||
update["$set"].(bson.M)["balance"] = params.Amount
|
||
update["$set"].(bson.M)["version"] = params.OldVersion + 1
|
||
default:
|
||
return nil, errs.ErrArgs.WrapMsg("invalid operation: " + params.Operation)
|
||
}
|
||
|
||
// 如果 set 分支删除了 $inc,需要补上版本递增
|
||
if _, ok := update["$inc"]; !ok {
|
||
update["$inc"] = bson.M{"version": 1}
|
||
} else if _, ok := update["$inc"].(bson.M)["version"]; !ok {
|
||
update["$inc"].(bson.M)["version"] = 1
|
||
}
|
||
|
||
// 使用 findOneAndUpdate 返回更新后的文档
|
||
opts := options.FindOneAndUpdate().SetReturnDocument(options.After)
|
||
var updatedWallet model.Wallet
|
||
err := w.coll.FindOneAndUpdate(ctx, filter, update, opts).Decode(&updatedWallet)
|
||
if err != nil {
|
||
if err == mongo.ErrNoDocuments {
|
||
// 版本号或余额不匹配,说明有并发修改
|
||
return nil, errs.ErrInternalServer.WrapMsg("concurrent modification detected: version or balance mismatch")
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
return &database.WalletUpdateResult{
|
||
NewBalance: updatedWallet.Balance,
|
||
NewVersion: updatedWallet.Version,
|
||
Success: true,
|
||
}, nil
|
||
}
|
||
|
||
// FindAllWallets finds all wallets with pagination.
|
||
func (w *WalletMgo) FindAllWallets(ctx context.Context, pagination pagination.Pagination) (total int64, wallets []*model.Wallet, err error) {
|
||
return mongoutil.FindPage[*model.Wallet](ctx, w.coll, bson.M{}, pagination, &options.FindOptions{
|
||
Sort: bson.D{{Key: "create_time", Value: -1}},
|
||
})
|
||
}
|
||
|
||
// FindWalletsByUserIDs finds wallets by user IDs.
|
||
func (w *WalletMgo) FindWalletsByUserIDs(ctx context.Context, userIDs []string) ([]*model.Wallet, error) {
|
||
if len(userIDs) == 0 {
|
||
return []*model.Wallet{}, nil
|
||
}
|
||
filter := bson.M{"user_id": bson.M{"$in": userIDs}}
|
||
return mongoutil.Find[*model.Wallet](ctx, w.coll, filter)
|
||
}
|
||
|
||
// WalletBalanceRecordMgo implements WalletBalanceRecord using MongoDB as the storage backend.
|
||
type WalletBalanceRecordMgo struct {
|
||
coll *mongo.Collection
|
||
}
|
||
|
||
// NewWalletBalanceRecordMongo creates a new instance of WalletBalanceRecordMgo with the provided MongoDB database.
|
||
func NewWalletBalanceRecordMongo(db *mongo.Database) (database.WalletBalanceRecord, error) {
|
||
coll := db.Collection(database.WalletBalanceRecordName)
|
||
|
||
// 先尝试删除可能存在的旧 record_id 索引(如果不是稀疏索引)
|
||
// 忽略删除失败的错误(索引可能不存在)
|
||
_, _ = coll.Indexes().DropOne(context.Background(), "record_id_1")
|
||
|
||
// 创建索引,使用稀疏索引允许 record_id 为 null
|
||
_, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{
|
||
{
|
||
Keys: bson.D{{Key: "user_id", Value: 1}, {Key: "create_time", Value: -1}},
|
||
},
|
||
{
|
||
Keys: bson.D{{Key: "record_id", Value: 1}},
|
||
Options: options.Index().SetUnique(true).SetSparse(true),
|
||
},
|
||
{
|
||
Keys: bson.D{{Key: "create_time", Value: -1}},
|
||
},
|
||
})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &WalletBalanceRecordMgo{coll: coll}, nil
|
||
}
|
||
|
||
// Create creates a new wallet balance record.
|
||
func (w *WalletBalanceRecordMgo) Create(ctx context.Context, record *model.WalletBalanceRecord) error {
|
||
if record.CreateTime.IsZero() {
|
||
record.CreateTime = time.Now()
|
||
}
|
||
return mongoutil.InsertOne(ctx, w.coll, record)
|
||
}
|
||
|
||
// FindByUserID finds all balance records for a user with pagination.
|
||
func (w *WalletBalanceRecordMgo) FindByUserID(ctx context.Context, userID string, pagination pagination.Pagination) (total int64, records []*model.WalletBalanceRecord, err error) {
|
||
filter := bson.M{"user_id": userID}
|
||
return mongoutil.FindPage[*model.WalletBalanceRecord](ctx, w.coll, filter, pagination, &options.FindOptions{
|
||
Sort: bson.D{{Key: "create_time", Value: -1}},
|
||
})
|
||
}
|