1496 lines
49 KiB
Go
1496 lines
49 KiB
Go
package chat
|
||
|
||
import (
	"context"
	crand "crypto/rand"
	"encoding/json"
	"fmt"
	"math/rand"
	"strconv"
	"strings"
	"sync"
	"time"

	constantpb "git.imall.cloud/openim/protocol/constant"
	"github.com/openimsdk/tools/utils/datautil"

	"github.com/google/uuid"
	"github.com/openimsdk/tools/errs"
	"github.com/openimsdk/tools/log"
	"github.com/openimsdk/tools/mcontext"

	"git.imall.cloud/openim/chat/pkg/common/constant"
	"git.imall.cloud/openim/chat/pkg/common/db/dbutil"
	chatdb "git.imall.cloud/openim/chat/pkg/common/db/table/chat"
	"git.imall.cloud/openim/chat/pkg/common/util"
	"git.imall.cloud/openim/chat/pkg/eerrs"
	"git.imall.cloud/openim/chat/pkg/protocol/chat"
	"git.imall.cloud/openim/chat/pkg/sms"
)
|
||
|
||
// verifyType selects which channel configuration (phone or e-mail) governs a
// verification-code check in verifyCode.
type verifyType int

// Verification channels understood by verifyCode.
const (
	phone verifyType = 0 // account is "areaCode phoneNumber"; checked against o.conf.Phone
	mail  verifyType = 1 // account is an e-mail address; checked against o.conf.Mail
)
|
||
|
||
func (o *chatSvr) verifyCodeJoin(areaCode, phoneNumber string) string {
|
||
return areaCode + " " + phoneNumber
|
||
}
|
||
|
||
func (o *chatSvr) SendVerifyCode(ctx context.Context, req *chat.SendVerifyCodeReq) (*chat.SendVerifyCodeResp, error) {
|
||
switch int(req.UsedFor) {
|
||
case constant.VerificationCodeForRegister, constant.VerificationCodeForH5Register:
|
||
if err := o.Admin.CheckRegister(ctx, req.Ip); err != nil {
|
||
return nil, err
|
||
}
|
||
if req.Email == "" {
|
||
if req.AreaCode == "" || req.PhoneNumber == "" {
|
||
return nil, errs.ErrArgs.WrapMsg("area code or phone number is empty")
|
||
}
|
||
if !strings.HasPrefix(req.AreaCode, "+") {
|
||
req.AreaCode = "+" + req.AreaCode
|
||
}
|
||
if _, err := strconv.ParseUint(req.AreaCode[1:], 10, 64); err != nil {
|
||
return nil, errs.ErrArgs.WrapMsg("area code must be number")
|
||
}
|
||
if _, err := strconv.ParseUint(req.PhoneNumber, 10, 64); err != nil {
|
||
return nil, errs.ErrArgs.WrapMsg("phone number must be number")
|
||
}
|
||
} else {
|
||
if err := chat.EmailCheck(req.Email); err != nil {
|
||
return nil, errs.ErrArgs.WrapMsg("email must be right")
|
||
}
|
||
}
|
||
conf, err := o.Admin.GetConfig(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if val := conf[constant.NeedInvitationCodeRegisterConfigKey]; datautil.Contain(strings.ToLower(val), "1", "true", "yes") {
|
||
if req.InvitationCode == "" {
|
||
return nil, errs.ErrArgs.WrapMsg("invitation code is empty")
|
||
}
|
||
if err := o.Admin.CheckInvitationCode(ctx, req.InvitationCode); err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
case constant.VerificationCodeForLogin, constant.VerificationCodeForResetPassword:
|
||
if req.Email == "" {
|
||
_, err := o.Database.TakeAttributeByPhone(ctx, req.AreaCode, req.PhoneNumber)
|
||
if dbutil.IsDBNotFound(err) {
|
||
return nil, eerrs.ErrAccountNotFound.WrapMsg("phone unregistered")
|
||
} else if err != nil {
|
||
return nil, err
|
||
}
|
||
} else {
|
||
_, err := o.Database.TakeAttributeByEmail(ctx, req.Email)
|
||
if dbutil.IsDBNotFound(err) {
|
||
return nil, eerrs.ErrAccountNotFound.WrapMsg("email unregistered")
|
||
} else if err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
|
||
default:
|
||
return nil, errs.ErrArgs.WrapMsg("used unknown")
|
||
}
|
||
|
||
// 先尝试从数据库 SystemConfig 集合读取 SMS 配置
|
||
var dbSMSClient sms.SMS
|
||
var dbSMSClientErr error
|
||
if req.AreaCode != "" {
|
||
// 先读取数据库配置(从 system_configs 集合)
|
||
dbSMSClient, dbSMSClientErr = o.getSMSClientFromDB(ctx)
|
||
if dbSMSClientErr != nil {
|
||
log.ZWarn(ctx, "Failed to get SMS client from SystemConfig collection", dbSMSClientErr)
|
||
} else if dbSMSClient != nil {
|
||
log.ZInfo(ctx, "Using SMS client from database config", "clientName", dbSMSClient.Name())
|
||
}
|
||
}
|
||
|
||
// 确定最终使用的 SMS 客户端:优先使用数据库配置,如果没有则使用 yml 配置
|
||
var finalSMSClient sms.SMS
|
||
if dbSMSClient != nil {
|
||
finalSMSClient = dbSMSClient
|
||
} else {
|
||
finalSMSClient = o.SMS
|
||
if dbSMSClientErr != nil {
|
||
log.ZWarn(ctx, "Failed to get SMS client from DB, using yml config", dbSMSClientErr)
|
||
}
|
||
}
|
||
|
||
// 检查最终是否有可用的 SMS 或 Mail 客户端
|
||
if finalSMSClient == nil && o.Mail == nil {
|
||
log.ZInfo(ctx, "SMS and Mail clients are both nil, using superCode",
|
||
"areaCode", req.AreaCode,
|
||
"phoneNumber", req.PhoneNumber,
|
||
"email", req.Email,
|
||
"usedFor", req.UsedFor,
|
||
"superCode", o.Code.SuperCode,
|
||
"dbSMSAvailable", dbSMSClient != nil,
|
||
"ymlSMSAvailable", o.SMS != nil)
|
||
return &chat.SendVerifyCodeResp{}, nil // super code
|
||
}
|
||
|
||
if req.Email != "" {
|
||
switch o.conf.Mail.Use {
|
||
case constant.VerifySuperCode:
|
||
log.ZInfo(ctx, "Using superCode for email verification, skip email sending",
|
||
"email", req.Email,
|
||
"usedFor", req.UsedFor,
|
||
"superCode", o.Code.SuperCode)
|
||
return &chat.SendVerifyCodeResp{}, nil // super code
|
||
case constant.VerifyMail:
|
||
default:
|
||
return nil, errs.ErrInternalServer.WrapMsg("email verification code is not enabled")
|
||
}
|
||
}
|
||
|
||
if req.AreaCode != "" {
|
||
// 如果数据库有配置,直接使用;否则检查 yml 配置
|
||
if dbSMSClient == nil {
|
||
// 使用 yml 配置,检查配置类型
|
||
switch o.conf.Phone.Use {
|
||
case constant.VerifySuperCode:
|
||
log.ZInfo(ctx, "Using superCode for phone verification, skip SMS sending",
|
||
"areaCode", req.AreaCode,
|
||
"phoneNumber", req.PhoneNumber,
|
||
"usedFor", req.UsedFor,
|
||
"superCode", o.Code.SuperCode)
|
||
return &chat.SendVerifyCodeResp{}, nil // super code
|
||
case constant.VerifyALi, constant.VerifyBao:
|
||
default:
|
||
return nil, errs.ErrInternalServer.WrapMsg("phone verification code is not enabled")
|
||
}
|
||
}
|
||
}
|
||
|
||
isEmail := req.Email != ""
|
||
var account string
|
||
var timeWindowKey string // 用于在发送成功后设置时间窗口标志(手机号注册场景)
|
||
|
||
// 先创建 account,并检查时间窗口
|
||
if isEmail {
|
||
account = req.Email
|
||
} else {
|
||
account = o.verifyCodeJoin(req.AreaCode, req.PhoneNumber)
|
||
|
||
// 对于手机号注册场景,先检查时间窗口(60秒内只能发送一次)
|
||
// 同一个号码在短时间内只能发送一次,不管请求几次
|
||
if int(req.UsedFor) == constant.VerificationCodeForRegister || int(req.UsedFor) == constant.VerificationCodeForH5Register {
|
||
timeWindowKey = "verify_code_sent_window:" + account
|
||
timeWindow := 60 * time.Second // 60秒时间窗口
|
||
|
||
if o.rdb != nil {
|
||
// 检查是否在时间窗口内已经发送过
|
||
exists, err := o.rdb.Exists(ctx, timeWindowKey).Result()
|
||
if err == nil && exists > 0 {
|
||
log.ZWarn(ctx, "Verify code already sent within time window for this account", nil,
|
||
"account", account, "timeWindow", timeWindow)
|
||
return nil, eerrs.ErrVerifyCodeSendFrequently.Wrap()
|
||
}
|
||
|
||
// 获取分布式锁,防止并发请求
|
||
lockKey := "send_verify_code_lock:" + account
|
||
lockValue := uuid.New().String()
|
||
lockTTL := 10 * time.Second
|
||
|
||
acquired, err := o.rdb.SetNX(ctx, lockKey, lockValue, lockTTL).Result()
|
||
if err != nil {
|
||
log.ZWarn(ctx, "Failed to acquire distributed lock for send verify code", err, "account", account)
|
||
} else if !acquired {
|
||
log.ZWarn(ctx, "Another request is sending verify code for this account", nil, "account", account)
|
||
return nil, eerrs.ErrVerifyCodeSendFrequently.Wrap()
|
||
} else {
|
||
// 成功获取锁,确保在函数返回时释放锁
|
||
defer func() {
|
||
script := `
|
||
if redis.call("get", KEYS[1]) == ARGV[1] then
|
||
return redis.call("del", KEYS[1])
|
||
else
|
||
return 0
|
||
end
|
||
`
|
||
if err := o.rdb.Eval(ctx, script, []string{lockKey}, lockValue).Err(); err != nil {
|
||
log.ZWarn(ctx, "Failed to release distributed lock for send verify code", err, "account", account)
|
||
}
|
||
}()
|
||
|
||
// 再次检查时间窗口(双重检查,防止在获取锁的过程中其他请求已经发送)
|
||
exists, err = o.rdb.Exists(ctx, timeWindowKey).Result()
|
||
if err == nil && exists > 0 {
|
||
log.ZWarn(ctx, "Verify code already sent within time window (double check)", nil,
|
||
"account", account, "timeWindow", timeWindow)
|
||
return nil, eerrs.ErrVerifyCodeSendFrequently.Wrap()
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 在获取锁之后才生成验证码,确保同一账号只有一个请求能生成验证码
|
||
var (
|
||
code = o.genVerifyCode()
|
||
sendCode func() error
|
||
)
|
||
|
||
if isEmail {
|
||
sendCode = func() error {
|
||
log.ZInfo(ctx, "Send email verify code request", "email", req.Email, "code", code, "usedFor", req.UsedFor)
|
||
err := o.Mail.SendMail(ctx, req.Email, code)
|
||
if err != nil {
|
||
log.ZError(ctx, "Send email verify code failed", err, "email", req.Email, "code", code)
|
||
} else {
|
||
log.ZInfo(ctx, "Send email verify code success", "email", req.Email, "code", code)
|
||
}
|
||
return err
|
||
}
|
||
} else {
|
||
// 使用之前从数据库读取的 SMS 客户端,如果没有则使用 yml 配置
|
||
smsClient := finalSMSClient
|
||
smsClientType := "yml"
|
||
configSource := "yml"
|
||
|
||
if dbSMSClient != nil {
|
||
// 使用数据库配置
|
||
smsClientType = "database"
|
||
configSource = "database"
|
||
} else {
|
||
// 使用 yml 配置
|
||
smsClient = o.SMS
|
||
}
|
||
|
||
if smsClient == nil {
|
||
return nil, errs.ErrInternalServer.WrapMsg("SMS client is not available")
|
||
}
|
||
|
||
// 记录使用的 SMS 客户端类型和配置来源
|
||
clientName := smsClient.Name()
|
||
log.ZInfo(ctx, "Send SMS verify code request",
|
||
"areaCode", req.AreaCode,
|
||
"phoneNumber", req.PhoneNumber,
|
||
"code", code,
|
||
"usedFor", req.UsedFor,
|
||
"configSource", configSource,
|
||
"smsClientType", smsClientType,
|
||
"smsClientName", clientName)
|
||
|
||
sendCode = func() error {
|
||
err := smsClient.SendCode(ctx, req.AreaCode, req.PhoneNumber, code)
|
||
if err != nil {
|
||
log.ZError(ctx, "Send SMS verify code failed", err,
|
||
"areaCode", req.AreaCode,
|
||
"phoneNumber", req.PhoneNumber,
|
||
"code", code,
|
||
"configSource", configSource,
|
||
"smsClientType", smsClientType,
|
||
"smsClientName", clientName)
|
||
} else {
|
||
log.ZInfo(ctx, "Send SMS verify code success",
|
||
"areaCode", req.AreaCode,
|
||
"phoneNumber", req.PhoneNumber,
|
||
"code", code,
|
||
"configSource", configSource,
|
||
"smsClientType", smsClientType,
|
||
"smsClientName", clientName)
|
||
}
|
||
return err
|
||
}
|
||
}
|
||
|
||
now := time.Now()
|
||
count, err := o.Database.CountVerifyCodeRange(ctx, account, now.Add(-o.Code.UintTime), now)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if o.Code.MaxCount < int(count) {
|
||
return nil, eerrs.ErrVerifyCodeSendFrequently.Wrap()
|
||
}
|
||
platformName := constantpb.PlatformIDToName(int(req.Platform))
|
||
if platformName == "" {
|
||
platformName = fmt.Sprintf("platform:%d", req.Platform)
|
||
}
|
||
|
||
// 使用 Redis 标志确保短信只发送一次,即使事务重试也不会重复
|
||
sentFlagKey := "verify_code_sent:" + account + ":" + code
|
||
var smsSent bool
|
||
if o.rdb != nil {
|
||
// 检查是否已经发送过
|
||
exists, _ := o.rdb.Exists(ctx, sentFlagKey).Result()
|
||
if exists > 0 {
|
||
log.ZInfo(ctx, "Verify code already sent, skipping send", "account", account, "code", code)
|
||
smsSent = true
|
||
}
|
||
}
|
||
|
||
// 创建 sendCode 包装函数,确保只发送一次
|
||
var sendCodeOnce func() error
|
||
if smsSent {
|
||
// 已经发送过,直接返回成功
|
||
sendCodeOnce = func() error { return nil }
|
||
} else {
|
||
// 使用 sync.Once 确保即使事务重试也只发送一次
|
||
// 注意:sync.Once 必须在外部作用域,不能在闭包内创建
|
||
var sendOnce sync.Once
|
||
var sendErr error
|
||
sendCodeOnce = func() error {
|
||
sendOnce.Do(func() {
|
||
sendErr = sendCode()
|
||
// 发送成功后,设置 Redis 标志和时间窗口标志
|
||
if sendErr == nil && o.rdb != nil {
|
||
_ = o.rdb.Set(ctx, sentFlagKey, "1", 5*time.Minute).Err()
|
||
// 设置时间窗口标志(60秒内不允许再次发送)
|
||
if timeWindowKey != "" {
|
||
_ = o.rdb.Set(ctx, timeWindowKey, "1", 60*time.Second).Err()
|
||
}
|
||
}
|
||
})
|
||
return sendErr
|
||
}
|
||
}
|
||
|
||
vc := &chatdb.VerifyCode{
|
||
Account: account,
|
||
Code: code,
|
||
Platform: platformName,
|
||
Duration: uint(o.Code.ValidTime / time.Second),
|
||
Count: 0,
|
||
Used: false,
|
||
CreateTime: now,
|
||
}
|
||
if err := o.Database.AddVerifyCode(ctx, vc, sendCodeOnce); err != nil {
|
||
log.ZError(ctx, "Failed to add verify code to database or send code", err,
|
||
"account", account)
|
||
return nil, err
|
||
}
|
||
|
||
// H5注册时,缓存手机号到Redis
|
||
if int(req.UsedFor) == constant.VerificationCodeForH5Register && req.PhoneNumber != "" {
|
||
// 构建完整手机号(包含区号)
|
||
fullPhone := o.verifyCodeJoin(req.AreaCode, req.PhoneNumber)
|
||
// 生成一个唯一ID用于关联验证码和手机号
|
||
phoneKey := "h5_register_phone:" + account
|
||
// 缓存手机号,过期时间与验证码相同(5分钟)
|
||
if err := o.storePhoneToRedis(ctx, phoneKey, fullPhone); err != nil {
|
||
log.ZError(ctx, "store phone to redis failed", err, "phoneKey", phoneKey)
|
||
// 不阻断流程,只记录日志
|
||
}
|
||
}
|
||
|
||
return &chat.SendVerifyCodeResp{}, nil
|
||
}
|
||
|
||
func (o *chatSvr) verifyCode(ctx context.Context, account string, verifyCode string, type_ verifyType) (string, error) {
|
||
if verifyCode == "" {
|
||
return "", errs.ErrArgs.WrapMsg("verify code is empty")
|
||
}
|
||
|
||
// 先检查数据库中的 sms_type 配置
|
||
// 如果 sms_type 为 true,说明使用的是数据库配置发送的真实验证码,应该直接验证数据库中的验证码
|
||
usePlainText, err := o.checkSMSTypeFromDB(ctx)
|
||
if err != nil {
|
||
log.ZWarn(ctx, "Failed to check sms_type from DB in verifyCode, will check yml config", err)
|
||
usePlainText = false
|
||
}
|
||
|
||
// 如果 sms_type 为 true,跳过 yml 配置检查,直接去数据库验证验证码
|
||
if !usePlainText {
|
||
// sms_type 为 false 或不存在,检查 yml 配置
|
||
switch type_ {
|
||
case phone:
|
||
switch o.conf.Phone.Use {
|
||
case constant.VerifySuperCode:
|
||
log.ZInfo(ctx, "Using superCode for phone verification",
|
||
"superCode", o.Code.SuperCode,
|
||
"inputCode", verifyCode)
|
||
if o.Code.SuperCode != verifyCode {
|
||
log.ZError(ctx, "SuperCode verification failed", nil,
|
||
"superCode", o.Code.SuperCode,
|
||
"inputCode", verifyCode,
|
||
"match", false)
|
||
return "", eerrs.ErrVerifyCodeNotMatch.Wrap()
|
||
}
|
||
log.ZInfo(ctx, "SuperCode verification success")
|
||
return "", nil
|
||
case constant.VerifyALi, constant.VerifyBao:
|
||
default:
|
||
return "", errs.ErrInternalServer.WrapMsg("phone verification code is not enabled", "use", o.conf.Phone.Use)
|
||
}
|
||
case mail:
|
||
switch o.conf.Mail.Use {
|
||
case constant.VerifySuperCode:
|
||
log.ZInfo(ctx, "Using superCode for email verification",
|
||
"superCode", o.Code.SuperCode,
|
||
"inputCode", verifyCode)
|
||
if o.Code.SuperCode != verifyCode {
|
||
log.ZError(ctx, "SuperCode verification failed", nil,
|
||
"superCode", o.Code.SuperCode,
|
||
"inputCode", verifyCode,
|
||
"match", false)
|
||
return "", eerrs.ErrVerifyCodeNotMatch.Wrap()
|
||
}
|
||
log.ZInfo(ctx, "SuperCode verification success")
|
||
return "", nil
|
||
case constant.VerifyMail:
|
||
default:
|
||
return "", errs.ErrInternalServer.WrapMsg("email verification code is not enabled")
|
||
}
|
||
}
|
||
}
|
||
|
||
last, err := o.Database.TakeLastVerifyCode(ctx, account)
|
||
if err != nil {
|
||
if dbutil.IsDBNotFound(err) {
|
||
log.ZError(ctx, "Verify code not found in database", err,
|
||
"account", account)
|
||
return "", eerrs.ErrVerifyCodeExpired.Wrap()
|
||
}
|
||
log.ZError(ctx, "Failed to get verify code from database", err,
|
||
"account", account)
|
||
return "", err
|
||
}
|
||
|
||
if last.CreateTime.Unix()+int64(last.Duration) < time.Now().Unix() {
|
||
log.ZError(ctx, "Verify code expired", nil,
|
||
"account", account,
|
||
"createTime", last.CreateTime.Unix(),
|
||
"duration", last.Duration,
|
||
"currentTime", time.Now().Unix(),
|
||
"expiredTime", last.CreateTime.Unix()+int64(last.Duration))
|
||
return last.ID, eerrs.ErrVerifyCodeExpired.Wrap()
|
||
}
|
||
if last.Used {
|
||
log.ZError(ctx, "Verify code already used", nil,
|
||
"account", account,
|
||
"codeID", last.ID)
|
||
return last.ID, eerrs.ErrVerifyCodeUsed.Wrap()
|
||
}
|
||
if n := o.Code.ValidCount; n > 0 {
|
||
if last.Count >= n {
|
||
log.ZError(ctx, "Verify code max count exceeded", nil,
|
||
"account", account,
|
||
"count", last.Count,
|
||
"maxCount", n)
|
||
return last.ID, eerrs.ErrVerifyCodeMaxCount.Wrap()
|
||
}
|
||
if last.Code != verifyCode {
|
||
log.ZWarn(ctx, "Verify code mismatch, incrementing count", nil,
|
||
"account", account,
|
||
"storedCode", last.Code,
|
||
"inputCode", verifyCode,
|
||
"currentCount", last.Count)
|
||
if err := o.Database.UpdateVerifyCodeIncrCount(ctx, last.ID); err != nil {
|
||
return last.ID, err
|
||
}
|
||
}
|
||
}
|
||
|
||
// 比较验证码
|
||
if last.Code != verifyCode {
|
||
log.ZError(ctx, "Verify code not match", nil,
|
||
"account", account,
|
||
"storedCode", last.Code,
|
||
"inputCode", verifyCode)
|
||
return last.ID, eerrs.ErrVerifyCodeNotMatch.Wrap()
|
||
}
|
||
return last.ID, nil
|
||
}
|
||
|
||
// verifyCodeByAccount 使用account作为key验证验证码(用于account注册场景)
|
||
func (o *chatSvr) verifyCodeByAccount(ctx context.Context, account string, verifyCode string) (string, error) {
|
||
if verifyCode == "" {
|
||
return "", errs.ErrArgs.WrapMsg("verify code is empty")
|
||
}
|
||
|
||
// 检查超级验证码
|
||
if o.Code.SuperCode != "" && o.Code.SuperCode == verifyCode {
|
||
log.ZInfo(ctx, "SuperCode verification success for account",
|
||
"account", account,
|
||
"superCode", o.Code.SuperCode)
|
||
return "", nil
|
||
}
|
||
|
||
last, err := o.Database.TakeLastVerifyCode(ctx, account)
|
||
if err != nil {
|
||
if dbutil.IsDBNotFound(err) {
|
||
log.ZError(ctx, "Verify code not found in database for account", err,
|
||
"account", account)
|
||
return "", eerrs.ErrVerifyCodeExpired.Wrap()
|
||
}
|
||
log.ZError(ctx, "Failed to get verify code from database for account", err,
|
||
"account", account)
|
||
return "", err
|
||
}
|
||
|
||
if last.CreateTime.Unix()+int64(last.Duration) < time.Now().Unix() {
|
||
log.ZError(ctx, "Verify code expired for account", nil,
|
||
"account", account,
|
||
"createTime", last.CreateTime.Unix(),
|
||
"duration", last.Duration,
|
||
"currentTime", time.Now().Unix())
|
||
return last.ID, eerrs.ErrVerifyCodeExpired.Wrap()
|
||
}
|
||
if last.Used {
|
||
log.ZError(ctx, "Verify code already used for account", nil,
|
||
"account", account,
|
||
"codeID", last.ID)
|
||
return last.ID, eerrs.ErrVerifyCodeUsed.Wrap()
|
||
}
|
||
if n := o.Code.ValidCount; n > 0 {
|
||
if last.Count >= n {
|
||
log.ZError(ctx, "Verify code max count exceeded for account", nil,
|
||
"account", account,
|
||
"count", last.Count,
|
||
"maxCount", n)
|
||
return last.ID, eerrs.ErrVerifyCodeMaxCount.Wrap()
|
||
}
|
||
if last.Code != verifyCode {
|
||
log.ZWarn(ctx, "Verify code mismatch for account, incrementing count", nil,
|
||
"account", account,
|
||
"storedCode", last.Code,
|
||
"inputCode", verifyCode,
|
||
"currentCount", last.Count)
|
||
if err := o.Database.UpdateVerifyCodeIncrCount(ctx, last.ID); err != nil {
|
||
return last.ID, err
|
||
}
|
||
}
|
||
}
|
||
|
||
// 比较验证码
|
||
if last.Code != verifyCode {
|
||
log.ZError(ctx, "Verify code not match for account", nil,
|
||
"account", account,
|
||
"storedCode", last.Code,
|
||
"inputCode", verifyCode)
|
||
return last.ID, eerrs.ErrVerifyCodeNotMatch.Wrap()
|
||
}
|
||
return last.ID, nil
|
||
}
|
||
|
||
func (o *chatSvr) VerifyCode(ctx context.Context, req *chat.VerifyCodeReq) (*chat.VerifyCodeResp, error) {
|
||
var verifyCode string
|
||
|
||
// 如果提供了captchaID,则是图片验证码验证
|
||
if req.CaptchaID != "" {
|
||
// 从Redis获取图片验证码
|
||
storedCode, err := o.getCaptchaFromRedis(ctx, req.CaptchaID)
|
||
if err != nil {
|
||
return nil, errs.ErrArgs.WrapMsg("验证码已过期或不存在,请重新获取")
|
||
}
|
||
|
||
// 比较验证码
|
||
if storedCode != req.VerifyCode {
|
||
return nil, eerrs.ErrVerifyCodeNotMatch.Wrap()
|
||
}
|
||
|
||
// 验证成功,删除Redis中的验证码(一次性使用)
|
||
if err := o.delCaptchaFromRedis(ctx, req.CaptchaID); err != nil {
|
||
log.ZWarn(ctx, "delete captcha from redis failed", err, "captchaID", req.CaptchaID)
|
||
}
|
||
|
||
// 如果是H5注册场景(提供了手机号),生成registerToken
|
||
if req.PhoneNumber != "" {
|
||
// 从Redis获取之前缓存的手机号
|
||
account := o.verifyCodeJoin(req.AreaCode, req.PhoneNumber)
|
||
phoneKey := "h5_register_phone:" + account
|
||
cachedPhone, err := o.getPhoneFromRedis(ctx, phoneKey)
|
||
if err != nil {
|
||
// 如果Redis中没有,使用请求中的手机号
|
||
cachedPhone = account
|
||
}
|
||
|
||
// 生成registerToken
|
||
registerToken := uuid.New().String()
|
||
|
||
// 加密手机号并存储到Redis(token -> 加密手机号),有效期120秒
|
||
encryptedPhone, err := util.EncryptPhone(cachedPhone)
|
||
if err != nil {
|
||
log.ZError(ctx, "encrypt phone failed", err, "phone", cachedPhone)
|
||
// 如果加密失败,直接存储明文(不推荐,但保证流程继续)
|
||
if err := o.storeRegisterTokenToRedis(ctx, registerToken, cachedPhone); err != nil {
|
||
log.ZError(ctx, "store register token failed", err, "token", registerToken)
|
||
}
|
||
} else {
|
||
// 存储加密后的手机号
|
||
if err := o.storeRegisterTokenToRedis(ctx, registerToken, encryptedPhone); err != nil {
|
||
log.ZError(ctx, "store register token failed", err, "token", registerToken)
|
||
}
|
||
}
|
||
|
||
// 返回registerToken
|
||
return &chat.VerifyCodeResp{
|
||
RegisterToken: registerToken,
|
||
}, nil
|
||
}
|
||
|
||
// 图片验证码验证成功(非H5注册场景)
|
||
return &chat.VerifyCodeResp{}, nil
|
||
}
|
||
|
||
// 否则是短信/邮件验证码验证
|
||
// 根据数据库中的 sms_type 配置决定验证码格式
|
||
// 检查数据库中的 sms_type 配置
|
||
usePlainText, err := o.checkSMSTypeFromDB(ctx)
|
||
if err != nil {
|
||
log.ZWarn(ctx, "Failed to check sms_type from DB, will try decrypt", err)
|
||
// 如果查询失败,默认尝试解密(保持原有逻辑)
|
||
usePlainText = false
|
||
}
|
||
|
||
if usePlainText {
|
||
// sms_type 为 true,使用明文验证码
|
||
verifyCode = req.VerifyCode
|
||
} else {
|
||
// sms_type 为 false 或不存在,需要解密验证码(base64解码)
|
||
decryptedCode, err := util.DecryptVerifyCode(req.VerifyCode)
|
||
if err != nil {
|
||
log.ZError(ctx, "Failed to decrypt verify code", err)
|
||
return nil, errs.ErrArgs.WrapMsg("验证码解密失败", "error", err.Error())
|
||
}
|
||
|
||
// 解析验证码和时间戳(格式:验证码-时间戳)
|
||
code, timestamp, err := util.ParseVerifyCodeWithTimestamp(decryptedCode)
|
||
if err != nil {
|
||
log.ZError(ctx, "Failed to parse verify code with timestamp", err)
|
||
return nil, errs.ErrArgs.WrapMsg("解析验证码失败", "error", err.Error())
|
||
}
|
||
|
||
// 验证时间戳(必须在1分钟内)
|
||
if !util.ValidateTimestamp(timestamp, 30) {
|
||
return nil, errs.ErrArgs.WrapMsg("验证码已过期,请重新获取")
|
||
}
|
||
|
||
// 使用解析出的验证码
|
||
verifyCode = code
|
||
}
|
||
|
||
var account string
|
||
if req.Account != "" {
|
||
// 使用account进行验证(注册场景)
|
||
account = req.Account
|
||
|
||
// 检查账号是否已存在(注册场景不允许使用已存在的账号)
|
||
_, err := o.Database.TakeAttributeByAccount(ctx, account)
|
||
if err == nil {
|
||
// 账号已存在,返回错误告知
|
||
return nil, eerrs.ErrAccountAlreadyRegister.WrapMsg("该账号已被注册,请使用其他账号")
|
||
} else if !dbutil.IsDBNotFound(err) {
|
||
// 数据库查询出错
|
||
return nil, err
|
||
}
|
||
|
||
// 对于account,我们需要查找验证码,但verifyCode函数只支持phone和mail
|
||
// 这里我们直接使用account作为key查找验证码
|
||
if _, err := o.verifyCodeByAccount(ctx, account, verifyCode); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 验证成功,生成registerToken用于注册
|
||
registerToken := uuid.New().String()
|
||
|
||
// 加密account并存储到Redis(token -> 加密account),有效期120秒
|
||
encryptedAccount, err := util.EncryptPhone(account) // 复用EncryptPhone函数,它实际是AES加密
|
||
if err != nil {
|
||
log.ZError(ctx, "encrypt account failed", err, "account", account)
|
||
// 如果加密失败,直接存储明文(不推荐,但保证流程继续)
|
||
if err := o.storeRegisterTokenToRedis(ctx, registerToken, account); err != nil {
|
||
log.ZError(ctx, "store register token failed", err, "token", registerToken)
|
||
}
|
||
} else {
|
||
// 存储加密后的account
|
||
if err := o.storeRegisterTokenToRedis(ctx, registerToken, encryptedAccount); err != nil {
|
||
log.ZError(ctx, "store register token failed", err, "token", registerToken)
|
||
}
|
||
}
|
||
|
||
// 返回registerToken
|
||
return &chat.VerifyCodeResp{
|
||
RegisterToken: registerToken,
|
||
}, nil
|
||
} else if req.PhoneNumber != "" {
|
||
account = o.verifyCodeJoin(req.AreaCode, req.PhoneNumber)
|
||
if _, err := o.verifyCode(ctx, account, verifyCode, phone); err != nil {
|
||
return nil, err
|
||
}
|
||
} else {
|
||
account = req.Email
|
||
if _, err := o.verifyCode(ctx, account, verifyCode, mail); err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
|
||
return &chat.VerifyCodeResp{}, nil
|
||
}
|
||
|
||
func (o *chatSvr) genUserID() string {
|
||
const l = 10
|
||
data := make([]byte, l)
|
||
rand.Read(data)
|
||
chars := []byte("0123456789")
|
||
for i := 0; i < len(data); i++ {
|
||
if i == 0 {
|
||
data[i] = chars[1:][data[i]%9]
|
||
} else {
|
||
data[i] = chars[data[i]%10]
|
||
}
|
||
}
|
||
return string(data)
|
||
}
|
||
|
||
func (o *chatSvr) genVerifyCode() string {
|
||
data := make([]byte, o.Code.Len)
|
||
rand.Read(data)
|
||
chars := []byte("0123456789")
|
||
for i := 0; i < len(data); i++ {
|
||
data[i] = chars[data[i]%10]
|
||
}
|
||
return string(data)
|
||
}
|
||
|
||
func (o *chatSvr) RegisterUser(ctx context.Context, req *chat.RegisterUserReq) (*chat.RegisterUserResp, error) {
|
||
isAdmin, err := o.Admin.CheckNilOrAdmin(ctx)
|
||
ctx = o.WithAdminUser(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if err = o.checkRegisterInfo(ctx, req.User, isAdmin); err != nil {
|
||
return nil, err
|
||
}
|
||
var usedInvitationCode bool
|
||
if !isAdmin {
|
||
if !o.AllowRegister {
|
||
return nil, errs.ErrNoPermission.WrapMsg("register user is disabled")
|
||
}
|
||
if req.User.UserID != "" {
|
||
return nil, errs.ErrNoPermission.WrapMsg("only admin can set user id")
|
||
}
|
||
if err := o.Admin.CheckRegister(ctx, req.Ip); err != nil {
|
||
return nil, err
|
||
}
|
||
conf, err := o.Admin.GetConfig(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if val := conf[constant.NeedInvitationCodeRegisterConfigKey]; datautil.Contain(strings.ToLower(val), "1", "true", "yes") {
|
||
usedInvitationCode = true
|
||
if req.InvitationCode == "" {
|
||
return nil, errs.ErrArgs.WrapMsg("invitation code is empty")
|
||
}
|
||
if err := o.Admin.CheckInvitationCode(ctx, req.InvitationCode); err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
|
||
// 如果提供了registerToken(H5注册场景或account注册场景),验证token
|
||
if req.RegisterToken != "" {
|
||
// 从Redis获取token关联的数据(可能是手机号或account)
|
||
tokenData, err := o.getRegisterTokenFromRedis(ctx, req.RegisterToken)
|
||
if err != nil {
|
||
return nil, errs.ErrArgs.WrapMsg("registerToken已失效,请重新验证")
|
||
}
|
||
|
||
// 尝试解密(可能是加密存储的)
|
||
var tokenDataDecrypted string
|
||
decryptedData, err := util.DecryptPhone(tokenData)
|
||
if err != nil {
|
||
// 如果解密失败,可能是明文存储的(兼容旧数据)
|
||
tokenDataDecrypted = tokenData
|
||
} else {
|
||
tokenDataDecrypted = decryptedData
|
||
}
|
||
|
||
// 判断是account注册还是手机号注册
|
||
if req.User.Account != "" {
|
||
// account注册场景
|
||
if tokenDataDecrypted != req.User.Account {
|
||
return nil, errs.ErrArgs.WrapMsg("账号不匹配,请重新验证")
|
||
}
|
||
} else if req.User.PhoneNumber != "" {
|
||
// 手机号注册场景
|
||
// 构建请求中的完整手机号
|
||
account := o.verifyCodeJoin(req.User.AreaCode, req.User.PhoneNumber)
|
||
if tokenDataDecrypted != account {
|
||
return nil, errs.ErrArgs.WrapMsg("手机号不匹配,请重新验证")
|
||
}
|
||
} else {
|
||
return nil, errs.ErrArgs.WrapMsg("registerToken需要配合账号或手机号使用")
|
||
}
|
||
|
||
// 验证成功,删除token(一次性使用)
|
||
if err := o.delRegisterTokenFromRedis(ctx, req.RegisterToken); err != nil {
|
||
log.ZWarn(ctx, "delete register token failed", err, "token", req.RegisterToken)
|
||
}
|
||
|
||
// 已通过token验证,不需要验证短信验证码
|
||
} else {
|
||
// 普通注册场景,验证短信/邮件验证码
|
||
// 根据数据库中的 sms_type 配置决定验证码格式
|
||
// 检查数据库中的 sms_type 配置
|
||
usePlainText, err := o.checkSMSTypeFromDB(ctx)
|
||
if err != nil {
|
||
log.ZWarn(ctx, "Failed to check sms_type from DB for register, will try decrypt", err)
|
||
// 如果查询失败,默认尝试解密(保持原有逻辑)
|
||
usePlainText = false
|
||
}
|
||
|
||
var verifyCode string
|
||
if usePlainText {
|
||
// sms_type 为 true,使用明文验证码
|
||
verifyCode = req.VerifyCode
|
||
} else {
|
||
// sms_type 为 false 或不存在,需要解密验证码(base64解码)
|
||
decryptedCode, err := util.DecryptVerifyCode(req.VerifyCode)
|
||
if err != nil {
|
||
log.ZError(ctx, "Failed to decrypt verify code for register", err)
|
||
return nil, errs.ErrArgs.WrapMsg("验证码解密失败", "error", err.Error())
|
||
}
|
||
|
||
// 解析验证码和时间戳(格式:验证码-时间戳)
|
||
code, timestamp, err := util.ParseVerifyCodeWithTimestamp(decryptedCode)
|
||
if err != nil {
|
||
log.ZError(ctx, "Failed to parse verify code with timestamp for register", err)
|
||
return nil, errs.ErrArgs.WrapMsg("解析验证码失败", "error", err.Error())
|
||
}
|
||
|
||
// 验证时间戳(必须在1分钟内)
|
||
if !util.ValidateTimestamp(timestamp, 30) {
|
||
return nil, errs.ErrArgs.WrapMsg("验证码已过期,请重新获取")
|
||
}
|
||
|
||
// 使用解析出的验证码
|
||
verifyCode = code
|
||
}
|
||
|
||
if req.User.Email == "" {
|
||
if _, err := o.verifyCode(ctx, o.verifyCodeJoin(req.User.AreaCode, req.User.PhoneNumber), verifyCode, phone); err != nil {
|
||
return nil, err
|
||
}
|
||
} else {
|
||
if _, err := o.verifyCode(ctx, req.User.Email, verifyCode, mail); err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
}
|
||
}
|
||
if req.User.UserID == "" {
|
||
for i := 0; i < 20; i++ {
|
||
userID := o.genUserID()
|
||
_, err := o.Database.GetUser(ctx, userID)
|
||
if err == nil {
|
||
continue
|
||
} else if dbutil.IsDBNotFound(err) {
|
||
req.User.UserID = userID
|
||
break
|
||
} else {
|
||
return nil, err
|
||
}
|
||
}
|
||
if req.User.UserID == "" {
|
||
return nil, errs.ErrInternalServer.WrapMsg("gen user id failed")
|
||
}
|
||
} else {
|
||
_, err := o.Database.GetUser(ctx, req.User.UserID)
|
||
if err == nil {
|
||
return nil, errs.ErrArgs.WrapMsg("appoint user id already register")
|
||
} else if !dbutil.IsDBNotFound(err) {
|
||
return nil, err
|
||
}
|
||
}
|
||
var (
|
||
credentials []*chatdb.Credential
|
||
registerType int32
|
||
)
|
||
|
||
if req.User.PhoneNumber != "" {
|
||
registerType = constant.PhoneRegister
|
||
credentials = append(credentials, &chatdb.Credential{
|
||
UserID: req.User.UserID,
|
||
Account: BuildCredentialPhone(req.User.AreaCode, req.User.PhoneNumber),
|
||
Type: constant.CredentialPhone,
|
||
AllowChange: true,
|
||
})
|
||
}
|
||
|
||
if req.User.Account != "" {
|
||
credentials = append(credentials, &chatdb.Credential{
|
||
UserID: req.User.UserID,
|
||
Account: req.User.Account,
|
||
Type: constant.CredentialAccount,
|
||
AllowChange: true,
|
||
})
|
||
registerType = constant.AccountRegister
|
||
}
|
||
|
||
if req.User.Email != "" {
|
||
registerType = constant.EmailRegister
|
||
credentials = append(credentials, &chatdb.Credential{
|
||
UserID: req.User.UserID,
|
||
Account: req.User.Email,
|
||
Type: constant.CredentialEmail,
|
||
AllowChange: true,
|
||
})
|
||
}
|
||
register := &chatdb.Register{
|
||
UserID: req.User.UserID,
|
||
DeviceID: req.DeviceID,
|
||
IP: req.Ip,
|
||
Platform: constantpb.PlatformID2Name[int(req.Platform)],
|
||
AccountType: "",
|
||
Mode: constant.UserMode,
|
||
CreateTime: time.Now(),
|
||
}
|
||
account := &chatdb.Account{
|
||
UserID: req.User.UserID,
|
||
Password: req.User.Password,
|
||
OperatorUserID: mcontext.GetOpUserID(ctx),
|
||
ChangeTime: register.CreateTime,
|
||
CreateTime: register.CreateTime,
|
||
}
|
||
|
||
attribute := &chatdb.Attribute{
|
||
UserID: req.User.UserID,
|
||
Account: req.User.Account,
|
||
PhoneNumber: req.User.PhoneNumber,
|
||
AreaCode: req.User.AreaCode,
|
||
Email: req.User.Email,
|
||
Nickname: req.User.Nickname,
|
||
FaceURL: req.User.FaceURL,
|
||
Gender: req.User.Gender,
|
||
UserType: req.User.UserType, // 使用传入的用户类型,如果没有则默认为0
|
||
UserFlag: req.User.UserFlag, // 使用传入的用户标签,如果没有则默认为空
|
||
BirthTime: time.UnixMilli(req.User.Birth),
|
||
ChangeTime: register.CreateTime,
|
||
CreateTime: register.CreateTime,
|
||
AllowVibration: constant.DefaultAllowVibration,
|
||
AllowBeep: constant.DefaultAllowBeep,
|
||
AllowAddFriend: constant.DefaultAllowAddFriend,
|
||
RegisterType: registerType,
|
||
}
|
||
if err := o.Database.RegisterUser(ctx, register, account, attribute, credentials); err != nil {
|
||
return nil, err
|
||
}
|
||
if usedInvitationCode {
|
||
if err := o.Admin.UseInvitationCode(ctx, req.User.UserID, req.InvitationCode); err != nil {
|
||
log.ZError(ctx, "UseInvitationCode", err, "userID", req.User.UserID, "invitationCode", req.InvitationCode)
|
||
}
|
||
}
|
||
|
||
// 检查系统配置是否开启钱包功能
|
||
walletConfig, err := o.Database.GetSystemConfig(ctx, "wallet.enabled")
|
||
if err == nil && walletConfig.Value == "true" {
|
||
// 为新注册用户创建默认钱包
|
||
wallet := &chatdb.Wallet{
|
||
UserID: req.User.UserID,
|
||
Balance: 0,
|
||
CreateTime: time.Now(),
|
||
UpdateTime: time.Now(),
|
||
}
|
||
if err := o.Database.CreateWallet(ctx, wallet); err != nil {
|
||
log.ZError(ctx, "CreateWallet failed on register", err, "userID", req.User.UserID)
|
||
// 不影响注册流程,只记录错误
|
||
} else {
|
||
log.ZInfo(ctx, "Auto created wallet for user on register", "userID", req.User.UserID)
|
||
}
|
||
}
|
||
|
||
// 查询注册奖励配置,如果启用则给用户钱包赠送金额
|
||
o.handleRegisterReward(ctx, req.User.UserID)
|
||
|
||
var resp chat.RegisterUserResp
|
||
if req.AutoLogin {
|
||
chatToken, err := o.Admin.CreateToken(ctx, req.User.UserID, constant.NormalUser)
|
||
if err == nil {
|
||
resp.ChatToken = chatToken.Token
|
||
} else {
|
||
log.ZError(ctx, "Admin CreateToken Failed", err, "userID", req.User.UserID, "platform", req.Platform)
|
||
}
|
||
|
||
// 注册成功后自动登录时,创建登录记录
|
||
record := &chatdb.UserLoginRecord{
|
||
UserID: req.User.UserID,
|
||
LoginTime: time.Now(),
|
||
IP: req.Ip,
|
||
DeviceID: req.DeviceID,
|
||
Platform: constantpb.PlatformID2Name[int(req.Platform)],
|
||
}
|
||
if err := o.Database.LoginRecord(ctx, record, nil); err != nil {
|
||
log.ZError(ctx, "LoginRecord failed on register auto login", err, "userID", req.User.UserID)
|
||
// 不影响注册流程,只记录错误
|
||
} else {
|
||
log.ZInfo(ctx, "Created login record on register auto login", "userID", req.User.UserID, "ip", req.Ip)
|
||
}
|
||
}
|
||
resp.UserID = req.User.UserID
|
||
return &resp, nil
|
||
}
|
||
|
||
func (o *chatSvr) Login(ctx context.Context, req *chat.LoginReq) (*chat.LoginResp, error) {
|
||
resp := &chat.LoginResp{}
|
||
if req.Password == "" && req.VerifyCode == "" {
|
||
return nil, errs.ErrArgs.WrapMsg("password or code must be set")
|
||
}
|
||
var (
|
||
err error
|
||
credential *chatdb.Credential
|
||
acc string
|
||
)
|
||
|
||
switch {
|
||
case req.Account != "":
|
||
acc = req.Account
|
||
case req.PhoneNumber != "":
|
||
if req.AreaCode == "" {
|
||
return nil, errs.ErrArgs.WrapMsg("area code must")
|
||
}
|
||
if !strings.HasPrefix(req.AreaCode, "+") {
|
||
req.AreaCode = "+" + req.AreaCode
|
||
}
|
||
if _, err := strconv.ParseUint(req.AreaCode[1:], 10, 64); err != nil {
|
||
return nil, errs.ErrArgs.WrapMsg("area code must be number")
|
||
}
|
||
acc = BuildCredentialPhone(req.AreaCode, req.PhoneNumber)
|
||
case req.Email != "":
|
||
acc = req.Email
|
||
default:
|
||
return nil, errs.ErrArgs.WrapMsg("account or phone number or email must be set")
|
||
}
|
||
|
||
credential, err = o.Database.TakeCredentialByAccount(ctx, acc)
|
||
if err != nil {
|
||
if dbutil.IsDBNotFound(err) {
|
||
|
||
return nil, eerrs.ErrAccountNotFound.WrapMsg("user unregistered")
|
||
}
|
||
|
||
return nil, err
|
||
}
|
||
|
||
if err := o.Admin.CheckLogin(ctx, credential.UserID, req.Ip); err != nil {
|
||
// 检查是否是账户被封禁的错误,确保错误码正确传递
|
||
// 即使错误被包装,错误字符串中也会包含原始错误信息
|
||
errStr := err.Error()
|
||
if strings.Contains(errStr, "20015") || strings.Contains(errStr, "AccountBlocked") || strings.Contains(errStr, "账户已被封禁") {
|
||
log.ZInfo(ctx, "Detected AccountBlocked error in CheckLogin, returning explicit error", "originalError", errStr)
|
||
return nil, eerrs.ErrAccountBlocked.WrapMsg("账户已被封禁")
|
||
}
|
||
return nil, err
|
||
}
|
||
var verifyCodeID *string
|
||
if req.Password == "" {
|
||
var (
|
||
id string
|
||
)
|
||
|
||
if req.Email == "" {
|
||
account := o.verifyCodeJoin(req.AreaCode, req.PhoneNumber)
|
||
id, err = o.verifyCode(ctx, account, req.VerifyCode, phone)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
} else {
|
||
account := req.Email
|
||
id, err = o.verifyCode(ctx, account, req.VerifyCode, mail)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
|
||
if id != "" {
|
||
verifyCodeID = &id
|
||
}
|
||
} else {
|
||
account, err := o.Database.TakeAccount(ctx, credential.UserID)
|
||
if err != nil {
|
||
|
||
return nil, err
|
||
}
|
||
if account.Password != req.Password {
|
||
|
||
return nil, eerrs.ErrPassword.Wrap()
|
||
}
|
||
|
||
}
|
||
|
||
chatToken, err := o.Admin.CreateToken(ctx, credential.UserID, constant.NormalUser)
|
||
if err != nil {
|
||
|
||
return nil, err
|
||
}
|
||
|
||
record := &chatdb.UserLoginRecord{
|
||
UserID: credential.UserID,
|
||
LoginTime: time.Now(),
|
||
IP: req.Ip,
|
||
DeviceID: req.DeviceID,
|
||
Platform: constantpb.PlatformIDToName(int(req.Platform)),
|
||
}
|
||
|
||
if err := o.Database.LoginRecord(ctx, record, verifyCodeID); err != nil {
|
||
log.ZError(ctx, "LoginRecord failed", err, "userID", credential.UserID)
|
||
return nil, err
|
||
}
|
||
|
||
if verifyCodeID != nil {
|
||
if err := o.Database.DelVerifyCode(ctx, *verifyCodeID); err != nil {
|
||
log.ZError(ctx, "DelVerifyCode failed", err, "verifyCodeID", *verifyCodeID)
|
||
return nil, err
|
||
}
|
||
}
|
||
|
||
// 检查系统配置是否开启钱包功能
|
||
walletConfig, err := o.Database.GetSystemConfig(ctx, "wallet.enabled")
|
||
if err == nil && walletConfig.Value == "true" {
|
||
// 检查用户钱包是否存在
|
||
_, err := o.Database.GetWallet(ctx, credential.UserID)
|
||
if err != nil {
|
||
if dbutil.IsDBNotFound(err) {
|
||
// 钱包不存在,创建默认钱包
|
||
wallet := &chatdb.Wallet{
|
||
UserID: credential.UserID,
|
||
Balance: 0,
|
||
CreateTime: time.Now(),
|
||
UpdateTime: time.Now(),
|
||
}
|
||
if err := o.Database.CreateWallet(ctx, wallet); err != nil {
|
||
log.ZError(ctx, "CreateWallet failed", err, "userID", credential.UserID)
|
||
// 不影响登录流程,只记录错误
|
||
} else {
|
||
log.ZInfo(ctx, "Auto created wallet for user on login", "userID", credential.UserID)
|
||
}
|
||
} else {
|
||
// 其他错误,只记录
|
||
log.ZError(ctx, "GetWallet failed", err, "userID", credential.UserID)
|
||
}
|
||
}
|
||
}
|
||
|
||
resp.UserID = credential.UserID
|
||
resp.ChatToken = chatToken.Token
|
||
return resp, nil
|
||
}
|
||
|
||
func (o *chatSvr) GetCaptchaImage(ctx context.Context, req *chat.GetCaptchaImageReq) (*chat.GetCaptchaImageResp, error) {
|
||
// 生成6位随机数字验证码
|
||
code := o.genVerifyCode()
|
||
|
||
// 生成唯一的验证码ID
|
||
captchaID := uuid.New().String()
|
||
|
||
// 将验证码存储到Redis,过期时间5分钟
|
||
if err := o.storeCaptchaToRedis(ctx, captchaID, code); err != nil {
|
||
log.ZError(ctx, "store captcha to redis failed", err, "captchaID", captchaID)
|
||
return nil, err
|
||
}
|
||
|
||
log.ZDebug(ctx, "generate captcha code success", "captchaID", captchaID)
|
||
|
||
return &chat.GetCaptchaImageResp{
|
||
CaptchaID: captchaID,
|
||
Code: code,
|
||
}, nil
|
||
}
|
||
|
||
// storeCaptchaToRedis 将验证码存储到Redis,过期时间5分钟
|
||
func (o *chatSvr) storeCaptchaToRedis(ctx context.Context, captchaID, code string) error {
|
||
if o.rdb == nil {
|
||
return errs.ErrInternalServer.WrapMsg("redis client not initialized")
|
||
}
|
||
key := "captcha:" + captchaID
|
||
return errs.Wrap(o.rdb.Set(ctx, key, code, 5*time.Minute).Err())
|
||
}
|
||
|
||
// storePhoneToRedis 将手机号存储到Redis,过期时间5分钟
|
||
func (o *chatSvr) storePhoneToRedis(ctx context.Context, key, phoneNumber string) error {
|
||
if o.rdb == nil {
|
||
return errs.ErrInternalServer.WrapMsg("redis client not initialized")
|
||
}
|
||
return errs.Wrap(o.rdb.Set(ctx, key, phoneNumber, 5*time.Minute).Err())
|
||
}
|
||
|
||
// getCaptchaFromRedis 从Redis获取图片验证码
|
||
func (o *chatSvr) getCaptchaFromRedis(ctx context.Context, captchaID string) (string, error) {
|
||
if o.rdb == nil {
|
||
return "", errs.ErrInternalServer.WrapMsg("redis client not initialized")
|
||
}
|
||
key := "captcha:" + captchaID
|
||
code, err := o.rdb.Get(ctx, key).Result()
|
||
if err != nil {
|
||
return "", errs.WrapMsg(err, "验证码不存在或已过期")
|
||
}
|
||
return code, nil
|
||
}
|
||
|
||
// delCaptchaFromRedis 从Redis删除图片验证码
|
||
func (o *chatSvr) delCaptchaFromRedis(ctx context.Context, captchaID string) error {
|
||
if o.rdb == nil {
|
||
return errs.ErrInternalServer.WrapMsg("redis client not initialized")
|
||
}
|
||
key := "captcha:" + captchaID
|
||
return errs.Wrap(o.rdb.Del(ctx, key).Err())
|
||
}
|
||
|
||
// getPhoneFromRedis 从Redis获取手机号
|
||
func (o *chatSvr) getPhoneFromRedis(ctx context.Context, key string) (string, error) {
|
||
if o.rdb == nil {
|
||
return "", errs.ErrInternalServer.WrapMsg("redis client not initialized")
|
||
}
|
||
phone, err := o.rdb.Get(ctx, key).Result()
|
||
if err != nil {
|
||
return "", errs.WrapMsg(err, "手机号不存在或已过期")
|
||
}
|
||
return phone, nil
|
||
}
|
||
|
||
// storeRegisterTokenToRedis 将注册token存储到Redis,关联手机号,有效期120秒
|
||
func (o *chatSvr) storeRegisterTokenToRedis(ctx context.Context, token, phoneNumber string) error {
|
||
if o.rdb == nil {
|
||
return errs.ErrInternalServer.WrapMsg("redis client not initialized")
|
||
}
|
||
key := "h5_register_token:" + token
|
||
return errs.Wrap(o.rdb.Set(ctx, key, phoneNumber, 120*time.Second).Err())
|
||
}
|
||
|
||
// getRegisterTokenFromRedis 从Redis获取注册token关联的手机号
|
||
func (o *chatSvr) getRegisterTokenFromRedis(ctx context.Context, token string) (string, error) {
|
||
if o.rdb == nil {
|
||
return "", errs.ErrInternalServer.WrapMsg("redis client not initialized")
|
||
}
|
||
key := "h5_register_token:" + token
|
||
phone, err := o.rdb.Get(ctx, key).Result()
|
||
if err != nil {
|
||
return "", errs.WrapMsg(err, "token不存在或已过期")
|
||
}
|
||
return phone, nil
|
||
}
|
||
|
||
// delRegisterTokenFromRedis 从Redis删除注册token(一次性使用)
|
||
func (o *chatSvr) delRegisterTokenFromRedis(ctx context.Context, token string) error {
|
||
if o.rdb == nil {
|
||
return errs.ErrInternalServer.WrapMsg("redis client not initialized")
|
||
}
|
||
key := "h5_register_token:" + token
|
||
return errs.Wrap(o.rdb.Del(ctx, key).Err())
|
||
}
|
||
|
||
// handleRegisterReward 处理注册奖励:查询系统配置,如果启用则给用户钱包赠送金额
|
||
func (o *chatSvr) handleRegisterReward(ctx context.Context, userID string) {
|
||
// 查询系统配置 key="register"
|
||
config, err := o.Database.GetSystemConfig(ctx, "register")
|
||
if err != nil {
|
||
// 配置不存在或查询失败,记录日志但不影响注册流程
|
||
if dbutil.IsDBNotFound(err) {
|
||
log.ZDebug(ctx, "Register reward config not found", "key", "register")
|
||
} else {
|
||
log.ZWarn(ctx, "Get register reward config failed", err, "key", "register", "userID", userID)
|
||
}
|
||
return
|
||
}
|
||
|
||
// 检查配置是否启用
|
||
if !config.Enabled {
|
||
log.ZDebug(ctx, "Register reward config is disabled", "key", "register", "userID", userID)
|
||
return
|
||
}
|
||
|
||
// 解析配置值中的金额
|
||
var amount int64
|
||
switch config.ValueType {
|
||
case chatdb.ConfigValueTypeNumber:
|
||
// 数字类型:直接解析为整数(单位:分)
|
||
amountFloat, err := strconv.ParseFloat(config.Value, 64)
|
||
if err != nil {
|
||
log.ZWarn(ctx, "Parse register reward amount failed", err, "value", config.Value, "valueType", config.ValueType, "userID", userID)
|
||
return
|
||
}
|
||
amount = int64(amountFloat)
|
||
log.ZDebug(ctx, "Parsed register reward amount from number type", "value", config.Value, "amount", amount, "userID", userID)
|
||
case chatdb.ConfigValueTypeString:
|
||
// 字符串类型:尝试解析为数字
|
||
amountFloat, err := strconv.ParseFloat(config.Value, 64)
|
||
if err != nil {
|
||
log.ZWarn(ctx, "Parse register reward amount from string failed", err, "value", config.Value, "valueType", config.ValueType, "userID", userID)
|
||
return
|
||
}
|
||
amount = int64(amountFloat)
|
||
log.ZDebug(ctx, "Parsed register reward amount from string type", "value", config.Value, "amount", amount, "userID", userID)
|
||
default:
|
||
log.ZWarn(ctx, "Register reward config value type not supported", nil, "valueType", config.ValueType, "userID", userID)
|
||
return
|
||
}
|
||
|
||
// 金额必须大于0
|
||
if amount <= 0 {
|
||
log.ZDebug(ctx, "Register reward amount is zero or negative", "amount", amount, "userID", userID)
|
||
return
|
||
}
|
||
|
||
// 给用户钱包充值
|
||
beforeBalance, afterBalance, err := o.Database.IncrementWalletBalance(ctx, userID, amount)
|
||
if err != nil {
|
||
log.ZError(ctx, "Increment wallet balance for register reward failed", err, "userID", userID, "amount", amount)
|
||
return
|
||
}
|
||
|
||
// 创建余额变动记录
|
||
record := &chatdb.WalletBalanceRecord{
|
||
ID: uuid.New().String(),
|
||
UserID: userID,
|
||
Amount: amount,
|
||
Type: 5, // 5-奖励
|
||
BeforeBalance: beforeBalance,
|
||
AfterBalance: afterBalance,
|
||
Remark: "注册奖励",
|
||
CreateTime: time.Now(),
|
||
}
|
||
if err := o.Database.CreateWalletBalanceRecord(ctx, record); err != nil {
|
||
// 记录创建失败不影响充值,只记录警告日志
|
||
log.ZWarn(ctx, "Create wallet balance record for register reward failed", err, "userID", userID, "amount", amount)
|
||
} else {
|
||
log.ZInfo(ctx, "Register reward granted successfully", "userID", userID, "amount", amount, "beforeBalance", beforeBalance, "afterBalance", afterBalance)
|
||
}
|
||
}
|
||
|
||
// checkSMSTypeFromDB 检查数据库中的 sms_type 配置
|
||
// 返回: true 表示 sms_type 为 true(使用明文验证码),false 表示 sms_type 为 false 或不存在(需要解密)
|
||
func (o *chatSvr) checkSMSTypeFromDB(ctx context.Context) (bool, error) {
|
||
// 从 SystemConfig 集合(system_configs)中读取 sms_type 配置
|
||
smsTypeConfig, err := o.Database.GetSystemConfig(ctx, "sms_type")
|
||
if err != nil {
|
||
if dbutil.IsDBNotFound(err) {
|
||
// SystemConfig 集合中没有 sms_type 配置,返回 false(需要解密,保持原有逻辑)
|
||
return false, nil
|
||
}
|
||
// 从 SystemConfig 集合查询出错,返回错误
|
||
return false, errs.WrapMsg(err, "failed to get sms_type config")
|
||
}
|
||
|
||
// 检查 sms_type 的值
|
||
var usePlainText bool
|
||
switch smsTypeConfig.ValueType {
|
||
case chatdb.ConfigValueTypeBool:
|
||
// 布尔类型:直接解析
|
||
usePlainText = smsTypeConfig.Value == "true" || smsTypeConfig.Value == "1"
|
||
case chatdb.ConfigValueTypeString:
|
||
// 字符串类型:检查是否为 "true" 或 "1"
|
||
usePlainText = strings.ToLower(smsTypeConfig.Value) == "true" || smsTypeConfig.Value == "1"
|
||
case chatdb.ConfigValueTypeNumber:
|
||
// 数字类型:1 表示 true,0 表示 false
|
||
val, err := strconv.ParseInt(smsTypeConfig.Value, 10, 64)
|
||
if err == nil {
|
||
usePlainText = val != 0
|
||
} else {
|
||
usePlainText = false
|
||
}
|
||
default:
|
||
// 其他类型,默认使用加密格式(false)
|
||
usePlainText = false
|
||
}
|
||
|
||
return usePlainText, nil
|
||
}
|
||
|
||
// getSMSClientFromDB 从数据库读取短信配置
|
||
// 逻辑:先读数据库,如果数据库没有 sms_type 配置,返回 nil 使用 yml 配置
|
||
// 如果数据库有 sms_type 且为 true,使用数据库中的 sms_info 配置创建 SMS 客户端
|
||
// 如果数据库有 sms_type 但为 false,返回 nil 使用 yml 配置
|
||
func (o *chatSvr) getSMSClientFromDB(ctx context.Context) (sms.SMS, error) {
|
||
// 从 SystemConfig 集合(system_configs)中读取 sms_type 配置
|
||
smsTypeConfig, err := o.Database.GetSystemConfig(ctx, "sms_type")
|
||
if err != nil {
|
||
if dbutil.IsDBNotFound(err) {
|
||
// SystemConfig 集合中没有 sms_type 配置,返回 nil 使用 yml 配置
|
||
return nil, nil
|
||
}
|
||
// 从 SystemConfig 集合查询出错,返回错误
|
||
log.ZError(ctx, "Failed to get sms_type config from SystemConfig collection", err)
|
||
return nil, errs.WrapMsg(err, "failed to get sms_type config")
|
||
}
|
||
|
||
// 检查 sms_type 的值
|
||
var useDBSMS bool
|
||
switch smsTypeConfig.ValueType {
|
||
case chatdb.ConfigValueTypeBool:
|
||
// 布尔类型:直接解析
|
||
useDBSMS = smsTypeConfig.Value == "true" || smsTypeConfig.Value == "1"
|
||
case chatdb.ConfigValueTypeString:
|
||
// 字符串类型:检查是否为 "true" 或 "1"
|
||
useDBSMS = strings.ToLower(smsTypeConfig.Value) == "true" || smsTypeConfig.Value == "1"
|
||
case chatdb.ConfigValueTypeNumber:
|
||
// 数字类型:1 表示 true,0 表示 false
|
||
val, err := strconv.ParseInt(smsTypeConfig.Value, 10, 64)
|
||
if err == nil {
|
||
useDBSMS = val != 0
|
||
}
|
||
default:
|
||
// 其他类型,默认不使用数据库配置
|
||
return nil, nil
|
||
}
|
||
|
||
// 如果 sms_type 为 false,返回 nil 使用 yml 配置
|
||
if !useDBSMS {
|
||
return nil, nil
|
||
}
|
||
|
||
// sms_type 为 true,从 SystemConfig 集合读取 sms_info 配置
|
||
smsInfoConfig, err := o.Database.GetSystemConfig(ctx, "sms_info")
|
||
if err != nil {
|
||
if dbutil.IsDBNotFound(err) {
|
||
log.ZWarn(ctx, "sms_info config not found but sms_type is true, will use yml config", nil)
|
||
return nil, nil
|
||
}
|
||
log.ZError(ctx, "Failed to get sms_info config from SystemConfig collection", err)
|
||
return nil, errs.WrapMsg(err, "failed to get sms_info config")
|
||
}
|
||
|
||
// 解析 sms_info JSON
|
||
type SMSInfo struct {
|
||
AccessKeyID string `json:"accessKeyId"`
|
||
AccessKeySecret string `json:"accessKeySecret"`
|
||
Endpoint string `json:"endpoint"`
|
||
SignName string `json:"signName"`
|
||
Type string `json:"type"`
|
||
VerificationCodeTemplateCode string `json:"verificationCodeTemplateCode"`
|
||
}
|
||
|
||
var smsInfo SMSInfo
|
||
if err := json.Unmarshal([]byte(smsInfoConfig.Value), &smsInfo); err != nil {
|
||
log.ZError(ctx, "Failed to parse sms_info JSON", err)
|
||
return nil, errs.WrapMsg(err, "failed to parse sms_info JSON")
|
||
}
|
||
|
||
// 根据 type 字段创建对应的 SMS 客户端
|
||
switch strings.ToLower(smsInfo.Type) {
|
||
case "bao":
|
||
client, err := sms.NewBao(
|
||
smsInfo.Endpoint,
|
||
smsInfo.AccessKeyID,
|
||
smsInfo.AccessKeySecret,
|
||
smsInfo.SignName,
|
||
smsInfo.VerificationCodeTemplateCode,
|
||
)
|
||
if err != nil {
|
||
log.ZError(ctx, "Failed to create Bao SMS client from database config", err)
|
||
return nil, errs.WrapMsg(err, "failed to create Bao SMS client")
|
||
}
|
||
log.ZInfo(ctx, "Created Bao SMS client from database config")
|
||
return client, nil
|
||
case "ali":
|
||
client, err := sms.NewAli(
|
||
smsInfo.Endpoint,
|
||
smsInfo.AccessKeyID,
|
||
smsInfo.AccessKeySecret,
|
||
smsInfo.SignName,
|
||
smsInfo.VerificationCodeTemplateCode,
|
||
)
|
||
if err != nil {
|
||
log.ZError(ctx, "Failed to create Ali SMS client from database config", err)
|
||
return nil, errs.WrapMsg(err, "failed to create Ali SMS client")
|
||
}
|
||
log.ZInfo(ctx, "Created Ali SMS client from database config")
|
||
return client, nil
|
||
default:
|
||
log.ZWarn(ctx, "Unknown SMS type in sms_info config, using default SMS client", nil, "type", smsInfo.Type)
|
||
return nil, nil
|
||
}
|
||
}
|