// Copyright © 2023 OpenIM. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package mgo import ( "context" "time" "git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/database" "git.imall.cloud/openim/open-im-server-deploy/pkg/common/storage/model" "github.com/openimsdk/tools/db/mongoutil" "github.com/openimsdk/tools/db/pagination" "github.com/openimsdk/tools/errs" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) // WalletMgo implements Wallet using MongoDB as the storage backend. type WalletMgo struct { coll *mongo.Collection } // NewWalletMongo creates a new instance of WalletMgo with the provided MongoDB database. func NewWalletMongo(db *mongo.Database) (database.Wallet, error) { coll := db.Collection(database.WalletName) _, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{ { Keys: bson.D{{Key: "user_id", Value: 1}}, Options: options.Index().SetUnique(true), }, { Keys: bson.D{{Key: "create_time", Value: -1}}, }, { Keys: bson.D{{Key: "update_time", Value: -1}}, }, }) if err != nil { return nil, err } return &WalletMgo{coll: coll}, nil } // Create creates a new wallet record. 
func (w *WalletMgo) Create(ctx context.Context, wallet *model.Wallet) error { if wallet.CreateTime.IsZero() { wallet.CreateTime = time.Now() } if wallet.UpdateTime.IsZero() { wallet.UpdateTime = time.Now() } if wallet.Version == 0 { wallet.Version = 1 } return mongoutil.InsertOne(ctx, w.coll, wallet) } // Take retrieves a wallet by user ID. Returns an error if not found. func (w *WalletMgo) Take(ctx context.Context, userID string) (*model.Wallet, error) { return mongoutil.FindOne[*model.Wallet](ctx, w.coll, bson.M{"user_id": userID}) } // UpdateBalance updates the balance of a wallet. func (w *WalletMgo) UpdateBalance(ctx context.Context, userID string, balance int64) error { update := bson.M{ "$set": bson.M{ "balance": balance, "update_time": time.Now(), }, } return mongoutil.UpdateOne(ctx, w.coll, bson.M{"user_id": userID}, update, false) } // UpdateBalanceByAmount updates the balance by adding/subtracting an amount. func (w *WalletMgo) UpdateBalanceByAmount(ctx context.Context, userID string, amount int64) error { update := bson.M{ "$inc": bson.M{ "balance": amount, }, "$set": bson.M{ "update_time": time.Now(), }, } return mongoutil.UpdateOne(ctx, w.coll, bson.M{"user_id": userID}, update, false) } // UpdateBalanceWithVersion 使用版本号更新余额(防止并发覆盖) func (w *WalletMgo) UpdateBalanceWithVersion(ctx context.Context, params *database.WalletUpdateParams) (*database.WalletUpdateResult, error) { // 基于单文档原子操作,避免并发覆盖 if params.Amount < 0 { return nil, errs.ErrArgs.WrapMsg("amount cannot be negative") } // 兼容旧数据(无 version 字段或 version=0) filter := bson.M{ "user_id": params.UserID, "$or": []bson.M{ {"version": params.OldVersion}, {"version": bson.M{"$exists": false}}, {"version": 0}, }, } // 默认更新:更新时间、版本号自增 update := bson.M{ "$set": bson.M{ "update_time": time.Now(), }, "$inc": bson.M{ "version": 1, }, } switch params.Operation { case "add": update["$inc"].(bson.M)["balance"] = params.Amount case "subtract": update["$inc"].(bson.M)["balance"] = -params.Amount // 
防止出现负数,过滤条件要求当前余额 >= amount filter["balance"] = bson.M{"$gte": params.Amount} case "set": // 直接设置余额,同时递增版本 delete(update, "$inc") update["$set"].(bson.M)["balance"] = params.Amount update["$set"].(bson.M)["version"] = params.OldVersion + 1 default: return nil, errs.ErrArgs.WrapMsg("invalid operation: " + params.Operation) } // 如果 set 分支删除了 $inc,需要补上版本递增 if _, ok := update["$inc"]; !ok { update["$inc"] = bson.M{"version": 1} } else if _, ok := update["$inc"].(bson.M)["version"]; !ok { update["$inc"].(bson.M)["version"] = 1 } // 使用 findOneAndUpdate 返回更新后的文档 opts := options.FindOneAndUpdate().SetReturnDocument(options.After) var updatedWallet model.Wallet err := w.coll.FindOneAndUpdate(ctx, filter, update, opts).Decode(&updatedWallet) if err != nil { if err == mongo.ErrNoDocuments { // 版本号或余额不匹配,说明有并发修改 return nil, errs.ErrInternalServer.WrapMsg("concurrent modification detected: version or balance mismatch") } return nil, err } return &database.WalletUpdateResult{ NewBalance: updatedWallet.Balance, NewVersion: updatedWallet.Version, Success: true, }, nil } // FindAllWallets finds all wallets with pagination. func (w *WalletMgo) FindAllWallets(ctx context.Context, pagination pagination.Pagination) (total int64, wallets []*model.Wallet, err error) { return mongoutil.FindPage[*model.Wallet](ctx, w.coll, bson.M{}, pagination, &options.FindOptions{ Sort: bson.D{{Key: "create_time", Value: -1}}, }) } // FindWalletsByUserIDs finds wallets by user IDs. func (w *WalletMgo) FindWalletsByUserIDs(ctx context.Context, userIDs []string) ([]*model.Wallet, error) { if len(userIDs) == 0 { return []*model.Wallet{}, nil } filter := bson.M{"user_id": bson.M{"$in": userIDs}} return mongoutil.Find[*model.Wallet](ctx, w.coll, filter) } // WalletBalanceRecordMgo implements WalletBalanceRecord using MongoDB as the storage backend. 
type WalletBalanceRecordMgo struct {
	coll *mongo.Collection
}

// NewWalletBalanceRecordMongo creates a new instance of WalletBalanceRecordMgo with the provided MongoDB database.
//
// It ensures these indexes on the balance-record collection:
//   - {user_id: 1, create_time: -1}
//   - {record_id: 1}, unique and sparse
//   - {create_time: -1}
//
// Older deployments may already hold an index named "record_id_1" that is not
// unique+sparse. Such an index conflicts with the required definition, so it
// is dropped before index creation. A correctly defined index is left in place,
// which keeps the uniqueness guarantee in force across restarts.
func NewWalletBalanceRecordMongo(db *mongo.Database) (database.WalletBalanceRecord, error) {
	ctx := context.Background()
	coll := db.Collection(database.WalletBalanceRecordName)

	const recordIDIndex = "record_id_1"

	cursor, err := coll.Indexes().List(ctx)
	if err != nil {
		return nil, err
	}
	// Only the fields needed to decide whether the existing index is usable.
	var existing []struct {
		Name   string `bson:"name"`
		Unique bool   `bson:"unique"`
		Sparse bool   `bson:"sparse"`
	}
	// All drains and closes the cursor.
	if err := cursor.All(ctx, &existing); err != nil {
		return nil, err
	}
	for _, idx := range existing {
		if idx.Name != recordIDIndex {
			continue
		}
		if !(idx.Unique && idx.Sparse) {
			if _, err := coll.Indexes().DropOne(ctx, recordIDIndex); err != nil {
				return nil, err
			}
		}
		break
	}

	// NOTE(review): a sparse index skips only documents that LACK record_id;
	// documents that store record_id as null (or "") are still indexed and
	// would collide on the unique constraint — confirm the model omits the
	// field when it is empty.
	_, err = coll.Indexes().CreateMany(ctx, []mongo.IndexModel{
		{
			Keys: bson.D{{Key: "user_id", Value: 1}, {Key: "create_time", Value: -1}},
		},
		{
			Keys:    bson.D{{Key: "record_id", Value: 1}},
			Options: options.Index().SetUnique(true).SetSparse(true),
		},
		{
			Keys: bson.D{{Key: "create_time", Value: -1}},
		},
	})
	if err != nil {
		return nil, err
	}
	return &WalletBalanceRecordMgo{coll: coll}, nil
}

// Create creates a new wallet balance record. A zero CreateTime is set to the
// current time; record is modified in place before insertion.
func (w *WalletBalanceRecordMgo) Create(ctx context.Context, record *model.WalletBalanceRecord) error {
	if record.CreateTime.IsZero() {
		record.CreateTime = time.Now()
	}
	return mongoutil.InsertOne(ctx, w.coll, record)
}

// FindByUserID finds all balance records for a user with pagination, newest
// (by create_time) first.
func (w *WalletBalanceRecordMgo) FindByUserID(ctx context.Context, userID string, pagination pagination.Pagination) (total int64, records []*model.WalletBalanceRecord, err error) {
	filter := bson.M{"user_id": userID}
	return mongoutil.FindPage[*model.WalletBalanceRecord](ctx, w.coll, filter, pagination, &options.FindOptions{
		Sort: bson.D{{Key: "create_time", Value: -1}},
	})
}