Files
chat-deploy/internal/rpc/chat/login.go
kim.dev.6789 b7f8db7d08 复制项目
2026-01-14 22:35:45 +08:00

1496 lines
49 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package chat
import (
	"context"
	crand "crypto/rand"
	"encoding/json"
	"fmt"
	"math/rand"
	"strconv"
	"strings"
	"sync"
	"time"

	constantpb "git.imall.cloud/openim/protocol/constant"
	"github.com/google/uuid"
	"github.com/openimsdk/tools/errs"
	"github.com/openimsdk/tools/log"
	"github.com/openimsdk/tools/mcontext"
	"github.com/openimsdk/tools/utils/datautil"

	"git.imall.cloud/openim/chat/pkg/common/constant"
	"git.imall.cloud/openim/chat/pkg/common/db/dbutil"
	chatdb "git.imall.cloud/openim/chat/pkg/common/db/table/chat"
	"git.imall.cloud/openim/chat/pkg/common/util"
	"git.imall.cloud/openim/chat/pkg/eerrs"
	"git.imall.cloud/openim/chat/pkg/protocol/chat"
	"git.imall.cloud/openim/chat/pkg/sms"
)
// verifyType identifies the channel a verification code was delivered
// through, which decides whether verifyCode consults the Phone or the Mail
// "use" setting of the yml configuration.
type verifyType int

const (
	// phone: the code was sent by SMS to an "areaCode phoneNumber" key.
	phone verifyType = 0
	// mail: the code was sent to an email address.
	mail verifyType = 1
)
// verifyCodeJoin builds the key under which phone verification codes are
// stored: the area code and the phone number separated by a single space.
func (o *chatSvr) verifyCodeJoin(areaCode, phoneNumber string) string {
	const separator = " "
	key := areaCode + separator + phoneNumber
	return key
}
// SendVerifyCode generates a numeric verification code for a phone number or
// an email address, stores it, and delivers it by SMS or mail.
//
// Pre-checks depend on req.UsedFor:
//   - register / H5 register: the IP must pass Admin.CheckRegister. A phone
//     number must be numeric, with the area code normalised to a leading "+".
//     An email must pass chat.EmailCheck. An invitation code is required when
//     the NeedInvitationCodeRegister config is enabled.
//   - login / reset password: the phone (area code normalised the same way)
//     or the email must already belong to a user.
//   - anything else is rejected with ErrArgs.
//
// SMS delivery prefers a client built from the system_configs collection
// (getSMSClientFromDB) and falls back to the yml-configured o.SMS. When no
// client exists, or the yml mode is "superCode", nothing is sent and the
// configured super code is expected instead.
//
// For phone registrations, a 60s Redis time window plus a 10s Redis lock
// reject repeated or concurrent requests for the same number. Requests per
// account are also capped by o.Code.MaxCount within o.Code.UintTime.
//
// The code is persisted with Database.AddVerifyCode together with a send
// callback wrapped in sync.Once. The message is therefore delivered at most
// once per request, even if the callback is invoked again (presumably on a
// transaction retry — confirm in AddVerifyCode).
//
// Verification codes and the super code are credentials and are never
// written to the logs.
func (o *chatSvr) SendVerifyCode(ctx context.Context, req *chat.SendVerifyCodeReq) (*chat.SendVerifyCodeResp, error) {
	switch int(req.UsedFor) {
	case constant.VerificationCodeForRegister, constant.VerificationCodeForH5Register:
		if err := o.Admin.CheckRegister(ctx, req.Ip); err != nil {
			return nil, err
		}
		if req.Email == "" {
			if req.AreaCode == "" || req.PhoneNumber == "" {
				return nil, errs.ErrArgs.WrapMsg("area code or phone number is empty")
			}
			if !strings.HasPrefix(req.AreaCode, "+") {
				req.AreaCode = "+" + req.AreaCode
			}
			if _, err := strconv.ParseUint(req.AreaCode[1:], 10, 64); err != nil {
				return nil, errs.ErrArgs.WrapMsg("area code must be number")
			}
			if _, err := strconv.ParseUint(req.PhoneNumber, 10, 64); err != nil {
				return nil, errs.ErrArgs.WrapMsg("phone number must be number")
			}
		} else if err := chat.EmailCheck(req.Email); err != nil {
			return nil, errs.ErrArgs.WrapMsg("email must be right")
		}
		conf, err := o.Admin.GetConfig(ctx)
		if err != nil {
			return nil, err
		}
		if val := conf[constant.NeedInvitationCodeRegisterConfigKey]; datautil.Contain(strings.ToLower(val), "1", "true", "yes") {
			if req.InvitationCode == "" {
				return nil, errs.ErrArgs.WrapMsg("invitation code is empty")
			}
			if err := o.Admin.CheckInvitationCode(ctx, req.InvitationCode); err != nil {
				return nil, err
			}
		}
	case constant.VerificationCodeForLogin, constant.VerificationCodeForResetPassword:
		if req.Email == "" {
			// Normalise "86" -> "+86" exactly like the register branch and Login do.
			// Otherwise the user lookup misses phones stored with "+", and the
			// stored code key would differ from the one Login verifies against.
			if req.AreaCode != "" && !strings.HasPrefix(req.AreaCode, "+") {
				req.AreaCode = "+" + req.AreaCode
			}
			_, err := o.Database.TakeAttributeByPhone(ctx, req.AreaCode, req.PhoneNumber)
			if dbutil.IsDBNotFound(err) {
				return nil, eerrs.ErrAccountNotFound.WrapMsg("phone unregistered")
			} else if err != nil {
				return nil, err
			}
		} else {
			_, err := o.Database.TakeAttributeByEmail(ctx, req.Email)
			if dbutil.IsDBNotFound(err) {
				return nil, eerrs.ErrAccountNotFound.WrapMsg("email unregistered")
			} else if err != nil {
				return nil, err
			}
		}
	default:
		return nil, errs.ErrArgs.WrapMsg("used unknown")
	}

	// For phone requests, try the SMS configuration stored in the
	// system_configs collection first.
	var dbSMSClient sms.SMS
	var dbSMSClientErr error
	if req.AreaCode != "" {
		dbSMSClient, dbSMSClientErr = o.getSMSClientFromDB(ctx)
		if dbSMSClientErr != nil {
			log.ZWarn(ctx, "Failed to get SMS client from SystemConfig collection", dbSMSClientErr)
		} else if dbSMSClient != nil {
			log.ZInfo(ctx, "Using SMS client from database config", "clientName", dbSMSClient.Name())
		}
	}
	// The database client wins; otherwise fall back to the yml client.
	var finalSMSClient sms.SMS
	if dbSMSClient != nil {
		finalSMSClient = dbSMSClient
	} else {
		finalSMSClient = o.SMS
		if dbSMSClientErr != nil {
			log.ZWarn(ctx, "Failed to get SMS client from DB, using yml config", dbSMSClientErr)
		}
	}

	// No delivery channel at all: the super code is the only valid code.
	if finalSMSClient == nil && o.Mail == nil {
		log.ZInfo(ctx, "SMS and Mail clients are both nil, using superCode",
			"areaCode", req.AreaCode,
			"phoneNumber", req.PhoneNumber,
			"email", req.Email,
			"usedFor", req.UsedFor,
			"dbSMSAvailable", dbSMSClient != nil,
			"ymlSMSAvailable", o.SMS != nil)
		return &chat.SendVerifyCodeResp{}, nil // super code
	}
	if req.Email != "" {
		switch o.conf.Mail.Use {
		case constant.VerifySuperCode:
			log.ZInfo(ctx, "Using superCode for email verification, skip email sending",
				"email", req.Email,
				"usedFor", req.UsedFor)
			return &chat.SendVerifyCodeResp{}, nil // super code
		case constant.VerifyMail:
		default:
			return nil, errs.ErrInternalServer.WrapMsg("email verification code is not enabled")
		}
	}
	// A database SMS config is used as-is; only the yml config has a mode switch.
	if req.AreaCode != "" && dbSMSClient == nil {
		switch o.conf.Phone.Use {
		case constant.VerifySuperCode:
			log.ZInfo(ctx, "Using superCode for phone verification, skip SMS sending",
				"areaCode", req.AreaCode,
				"phoneNumber", req.PhoneNumber,
				"usedFor", req.UsedFor)
			return &chat.SendVerifyCodeResp{}, nil // super code
		case constant.VerifyALi, constant.VerifyBao:
		default:
			return nil, errs.ErrInternalServer.WrapMsg("phone verification code is not enabled")
		}
	}

	isEmail := req.Email != ""
	var account string
	// Set only for phone registrations. After a successful send it blocks
	// further codes for the same number for 60 seconds.
	var timeWindowKey string
	if isEmail {
		account = req.Email
	} else {
		account = o.verifyCodeJoin(req.AreaCode, req.PhoneNumber)
		if int(req.UsedFor) == constant.VerificationCodeForRegister || int(req.UsedFor) == constant.VerificationCodeForH5Register {
			timeWindowKey = "verify_code_sent_window:" + account
			timeWindow := 60 * time.Second
			if o.rdb != nil {
				// Reject if a code was already sent within the window.
				exists, err := o.rdb.Exists(ctx, timeWindowKey).Result()
				if err == nil && exists > 0 {
					log.ZWarn(ctx, "Verify code already sent within time window for this account", nil,
						"account", account, "timeWindow", timeWindow)
					return nil, eerrs.ErrVerifyCodeSendFrequently.Wrap()
				}
				// Serialise concurrent requests for the same number. If Redis fails
				// here the request proceeds without the lock (best effort).
				lockKey := "send_verify_code_lock:" + account
				lockValue := uuid.New().String()
				lockTTL := 10 * time.Second
				acquired, err := o.rdb.SetNX(ctx, lockKey, lockValue, lockTTL).Result()
				if err != nil {
					log.ZWarn(ctx, "Failed to acquire distributed lock for send verify code", err, "account", account)
				} else if !acquired {
					log.ZWarn(ctx, "Another request is sending verify code for this account", nil, "account", account)
					return nil, eerrs.ErrVerifyCodeSendFrequently.Wrap()
				} else {
					// Release the lock on return, but only if it is still ours.
					defer func() {
						script := `
if redis.call("get", KEYS[1]) == ARGV[1] then
	return redis.call("del", KEYS[1])
else
	return 0
end
`
						if err := o.rdb.Eval(ctx, script, []string{lockKey}, lockValue).Err(); err != nil {
							log.ZWarn(ctx, "Failed to release distributed lock for send verify code", err, "account", account)
						}
					}()
					// Re-check now that we hold the lock: another request may have sent
					// a code between the first check and the lock acquisition.
					exists, err = o.rdb.Exists(ctx, timeWindowKey).Result()
					if err == nil && exists > 0 {
						log.ZWarn(ctx, "Verify code already sent within time window (double check)", nil,
							"account", account, "timeWindow", timeWindow)
						return nil, eerrs.ErrVerifyCodeSendFrequently.Wrap()
					}
				}
			}
		}
	}

	// Generate the code only after the lock is held.
	code := o.genVerifyCode()
	var sendCode func() error
	if isEmail {
		if o.Mail == nil {
			return nil, errs.ErrInternalServer.WrapMsg("Mail client is not available")
		}
		sendCode = func() error {
			log.ZInfo(ctx, "Send email verify code request", "email", req.Email, "usedFor", req.UsedFor)
			err := o.Mail.SendMail(ctx, req.Email, code)
			if err != nil {
				log.ZError(ctx, "Send email verify code failed", err, "email", req.Email)
			} else {
				log.ZInfo(ctx, "Send email verify code success", "email", req.Email)
			}
			return err
		}
	} else {
		smsClient := finalSMSClient
		if smsClient == nil {
			return nil, errs.ErrInternalServer.WrapMsg("SMS client is not available")
		}
		configSource := "yml"
		if dbSMSClient != nil {
			configSource = "database"
		}
		clientName := smsClient.Name()
		log.ZInfo(ctx, "Send SMS verify code request",
			"areaCode", req.AreaCode,
			"phoneNumber", req.PhoneNumber,
			"usedFor", req.UsedFor,
			"configSource", configSource,
			"smsClientType", configSource,
			"smsClientName", clientName)
		sendCode = func() error {
			err := smsClient.SendCode(ctx, req.AreaCode, req.PhoneNumber, code)
			if err != nil {
				log.ZError(ctx, "Send SMS verify code failed", err,
					"areaCode", req.AreaCode,
					"phoneNumber", req.PhoneNumber,
					"configSource", configSource,
					"smsClientType", configSource,
					"smsClientName", clientName)
			} else {
				log.ZInfo(ctx, "Send SMS verify code success",
					"areaCode", req.AreaCode,
					"phoneNumber", req.PhoneNumber,
					"configSource", configSource,
					"smsClientType", configSource,
					"smsClientName", clientName)
			}
			return err
		}
	}

	now := time.Now()
	count, err := o.Database.CountVerifyCodeRange(ctx, account, now.Add(-o.Code.UintTime), now)
	if err != nil {
		return nil, err
	}
	if o.Code.MaxCount < int(count) {
		return nil, eerrs.ErrVerifyCodeSendFrequently.Wrap()
	}
	platformName := constantpb.PlatformIDToName(int(req.Platform))
	if platformName == "" {
		platformName = fmt.Sprintf("platform:%d", req.Platform)
	}

	// Deliver at most once per request, even if AddVerifyCode calls the
	// callback repeatedly. On success, mark the code as sent and open the
	// phone-registration time window. Both Redis writes are best effort:
	// failing them only weakens duplicate-send protection.
	sentFlagKey := "verify_code_sent:" + account + ":" + code
	var (
		sendOnce sync.Once
		sendErr  error
	)
	sendCodeOnce := func() error {
		sendOnce.Do(func() {
			sendErr = sendCode()
			if sendErr != nil || o.rdb == nil {
				return
			}
			_ = o.rdb.Set(ctx, sentFlagKey, "1", 5*time.Minute).Err()
			if timeWindowKey != "" {
				_ = o.rdb.Set(ctx, timeWindowKey, "1", 60*time.Second).Err()
			}
		})
		return sendErr
	}

	vc := &chatdb.VerifyCode{
		Account:    account,
		Code:       code,
		Platform:   platformName,
		Duration:   uint(o.Code.ValidTime / time.Second),
		Count:      0,
		Used:       false,
		CreateTime: now,
	}
	if err := o.Database.AddVerifyCode(ctx, vc, sendCodeOnce); err != nil {
		log.ZError(ctx, "Failed to add verify code to database or send code", err,
			"account", account)
		return nil, err
	}

	// H5 registration: remember the phone (5 minutes, see storePhoneToRedis) so
	// VerifyCode can bind the register token to it. Failure is only logged.
	if int(req.UsedFor) == constant.VerificationCodeForH5Register && req.PhoneNumber != "" {
		fullPhone := o.verifyCodeJoin(req.AreaCode, req.PhoneNumber)
		phoneKey := "h5_register_phone:" + account
		if err := o.storePhoneToRedis(ctx, phoneKey, fullPhone); err != nil {
			log.ZError(ctx, "store phone to redis failed", err, "phoneKey", phoneKey)
		}
	}
	return &chat.SendVerifyCodeResp{}, nil
}
// verifyCode checks verifyCode against the code last issued for account and
// returns the ID of the stored code it checked. For phone the account is an
// "areaCode phoneNumber" key; for mail it is an email address.
//
// Unless the sms_type system config is true, the yml "use" setting for type_
// is consulted first:
//   - in superCode mode only o.Code.SuperCode is accepted, and the returned
//     ID is empty;
//   - modes other than ali/bao (phone) or mail (email) are rejected with an
//     internal error.
//
// In every other case the stored code is validated by checkStoredVerifyCode.
// Neither the submitted code nor the super code is ever logged.
func (o *chatSvr) verifyCode(ctx context.Context, account string, verifyCode string, type_ verifyType) (string, error) {
	if verifyCode == "" {
		return "", errs.ErrArgs.WrapMsg("verify code is empty")
	}
	// sms_type == true means the code was sent through the database-configured
	// SMS client, so the yml superCode/mode settings do not apply.
	usePlainText, err := o.checkSMSTypeFromDB(ctx)
	if err != nil {
		log.ZWarn(ctx, "Failed to check sms_type from DB in verifyCode, will check yml config", err)
		usePlainText = false
	}
	if !usePlainText {
		switch type_ {
		case phone:
			switch o.conf.Phone.Use {
			case constant.VerifySuperCode:
				log.ZInfo(ctx, "Using superCode for phone verification")
				return o.matchSuperCode(ctx, verifyCode)
			case constant.VerifyALi, constant.VerifyBao:
			default:
				return "", errs.ErrInternalServer.WrapMsg("phone verification code is not enabled", "use", o.conf.Phone.Use)
			}
		case mail:
			switch o.conf.Mail.Use {
			case constant.VerifySuperCode:
				log.ZInfo(ctx, "Using superCode for email verification")
				return o.matchSuperCode(ctx, verifyCode)
			case constant.VerifyMail:
			default:
				return "", errs.ErrInternalServer.WrapMsg("email verification code is not enabled")
			}
		}
	}
	return o.checkStoredVerifyCode(ctx, account, verifyCode, "")
}

// matchSuperCode accepts verifyCode only if it equals o.Code.SuperCode. The
// returned ID is always empty because no stored code is involved.
func (o *chatSvr) matchSuperCode(ctx context.Context, verifyCode string) (string, error) {
	if o.Code.SuperCode != verifyCode {
		log.ZError(ctx, "SuperCode verification failed", nil, "match", false)
		return "", eerrs.ErrVerifyCodeNotMatch.Wrap()
	}
	log.ZInfo(ctx, "SuperCode verification success")
	return "", nil
}

// checkStoredVerifyCode validates verifyCode against the most recent code
// stored for account and returns that code's ID ("" if none exists).
//
// It rejects:
//   - missing codes, reported as ErrVerifyCodeExpired;
//   - expired codes (CreateTime + Duration seconds in the past);
//   - codes already marked Used;
//   - when o.Code.ValidCount > 0, codes that reached ValidCount failed
//     attempts. In that mode each mismatch increments the stored counter.
//
// logSuffix is appended to log messages so the account-registration flow
// stays distinguishable from the phone/email flow.
func (o *chatSvr) checkStoredVerifyCode(ctx context.Context, account, verifyCode, logSuffix string) (string, error) {
	last, err := o.Database.TakeLastVerifyCode(ctx, account)
	if err != nil {
		if dbutil.IsDBNotFound(err) {
			log.ZError(ctx, "Verify code not found in database"+logSuffix, err,
				"account", account)
			return "", eerrs.ErrVerifyCodeExpired.Wrap()
		}
		log.ZError(ctx, "Failed to get verify code from database"+logSuffix, err,
			"account", account)
		return "", err
	}
	now := time.Now().Unix()
	expiredAt := last.CreateTime.Unix() + int64(last.Duration)
	if expiredAt < now {
		log.ZError(ctx, "Verify code expired"+logSuffix, nil,
			"account", account,
			"createTime", last.CreateTime.Unix(),
			"duration", last.Duration,
			"currentTime", now,
			"expiredTime", expiredAt)
		return last.ID, eerrs.ErrVerifyCodeExpired.Wrap()
	}
	if last.Used {
		log.ZError(ctx, "Verify code already used"+logSuffix, nil,
			"account", account,
			"codeID", last.ID)
		return last.ID, eerrs.ErrVerifyCodeUsed.Wrap()
	}
	if n := o.Code.ValidCount; n > 0 {
		if last.Count >= n {
			log.ZError(ctx, "Verify code max count exceeded"+logSuffix, nil,
				"account", account,
				"count", last.Count,
				"maxCount", n)
			return last.ID, eerrs.ErrVerifyCodeMaxCount.Wrap()
		}
		if last.Code != verifyCode {
			log.ZWarn(ctx, "Verify code mismatch"+logSuffix+", incrementing count", nil,
				"account", account,
				"currentCount", last.Count)
			if err := o.Database.UpdateVerifyCodeIncrCount(ctx, last.ID); err != nil {
				return last.ID, err
			}
		}
	}
	if last.Code != verifyCode {
		log.ZError(ctx, "Verify code not match"+logSuffix, nil,
			"account", account)
		return last.ID, eerrs.ErrVerifyCodeNotMatch.Wrap()
	}
	return last.ID, nil
}

// verifyCodeByAccount validates a code for the account-name registration
// flow, where the stored code is keyed by the account name itself. A
// non-empty o.Code.SuperCode is always accepted here, whatever the Phone/Mail
// "use" settings are.
// NOTE(review): this super-code bypass is active even when real SMS delivery
// is configured — confirm that is intended.
func (o *chatSvr) verifyCodeByAccount(ctx context.Context, account string, verifyCode string) (string, error) {
	if verifyCode == "" {
		return "", errs.ErrArgs.WrapMsg("verify code is empty")
	}
	if o.Code.SuperCode != "" && o.Code.SuperCode == verifyCode {
		log.ZInfo(ctx, "SuperCode verification success for account",
			"account", account)
		return "", nil
	}
	return o.checkStoredVerifyCode(ctx, account, verifyCode, " for account")
}
// VerifyCode checks either an image captcha or an SMS/email/account
// verification code.
//
// With req.CaptchaID set, req.VerifyCode is compared with the captcha stored
// in Redis, which is deleted after a successful match. If req.PhoneNumber is
// also set (H5 registration), a registerToken bound to the phone number is
// returned.
//
// Otherwise req.VerifyCode is used as-is when the sms_type system config is
// true. If not, it is decrypted with util.DecryptVerifyCode, parsed as
// "code-timestamp", and the timestamp is checked with
// util.ValidateTimestamp(timestamp, 30) (window unit presumably seconds —
// confirm). The code is then checked against, in priority order:
//   - req.Account (account registration: the account must not exist yet, and
//     a registerToken is returned);
//   - req.PhoneNumber;
//   - req.Email.
//
// A registerToken is only returned once it has been stored successfully.
func (o *chatSvr) VerifyCode(ctx context.Context, req *chat.VerifyCodeReq) (*chat.VerifyCodeResp, error) {
	if req.CaptchaID != "" {
		storedCode, err := o.getCaptchaFromRedis(ctx, req.CaptchaID)
		if err != nil {
			return nil, errs.ErrArgs.WrapMsg("验证码已过期或不存在,请重新获取")
		}
		// NOTE(review): a wrong guess leaves the captcha in Redis, so it can be
		// retried until it expires.
		if storedCode != req.VerifyCode {
			return nil, eerrs.ErrVerifyCodeNotMatch.Wrap()
		}
		// Captchas are single use.
		if err := o.delCaptchaFromRedis(ctx, req.CaptchaID); err != nil {
			log.ZWarn(ctx, "delete captcha from redis failed", err, "captchaID", req.CaptchaID)
		}
		if req.PhoneNumber == "" {
			return &chat.VerifyCodeResp{}, nil
		}
		// H5 registration: prefer the phone cached by SendVerifyCode, falling back
		// to the phone in this request.
		account := o.verifyCodeJoin(req.AreaCode, req.PhoneNumber)
		cachedPhone, err := o.getPhoneFromRedis(ctx, "h5_register_phone:"+account)
		if err != nil {
			cachedPhone = account
		}
		registerToken, err := o.issueRegisterToken(ctx, cachedPhone)
		if err != nil {
			return nil, err
		}
		return &chat.VerifyCodeResp{RegisterToken: registerToken}, nil
	}

	usePlainText, err := o.checkSMSTypeFromDB(ctx)
	if err != nil {
		log.ZWarn(ctx, "Failed to check sms_type from DB, will try decrypt", err)
		// Fall back to the encrypted format when the config cannot be read.
		usePlainText = false
	}
	verifyCode := req.VerifyCode
	if !usePlainText {
		decryptedCode, err := util.DecryptVerifyCode(req.VerifyCode)
		if err != nil {
			log.ZError(ctx, "Failed to decrypt verify code", err)
			return nil, errs.ErrArgs.WrapMsg("验证码解密失败", "error", err.Error())
		}
		code, timestamp, err := util.ParseVerifyCodeWithTimestamp(decryptedCode)
		if err != nil {
			log.ZError(ctx, "Failed to parse verify code with timestamp", err)
			return nil, errs.ErrArgs.WrapMsg("解析验证码失败", "error", err.Error())
		}
		if !util.ValidateTimestamp(timestamp, 30) {
			return nil, errs.ErrArgs.WrapMsg("验证码已过期,请重新获取")
		}
		verifyCode = code
	}

	switch {
	case req.Account != "":
		// Account registration: the account name must still be free.
		account := req.Account
		_, err := o.Database.TakeAttributeByAccount(ctx, account)
		if err == nil {
			return nil, eerrs.ErrAccountAlreadyRegister.WrapMsg("该账号已被注册,请使用其他账号")
		}
		if !dbutil.IsDBNotFound(err) {
			return nil, err
		}
		if _, err := o.verifyCodeByAccount(ctx, account, verifyCode); err != nil {
			return nil, err
		}
		registerToken, err := o.issueRegisterToken(ctx, account)
		if err != nil {
			return nil, err
		}
		return &chat.VerifyCodeResp{RegisterToken: registerToken}, nil
	case req.PhoneNumber != "":
		if _, err := o.verifyCode(ctx, o.verifyCodeJoin(req.AreaCode, req.PhoneNumber), verifyCode, phone); err != nil {
			return nil, err
		}
	default:
		if _, err := o.verifyCode(ctx, req.Email, verifyCode, mail); err != nil {
			return nil, err
		}
	}
	return &chat.VerifyCodeResp{}, nil
}

// issueRegisterToken creates a one-time registration token (a random UUID)
// bound to subject, an "areaCode phoneNumber" string or an account name, and
// stores it with storeRegisterTokenToRedis.
//
// The subject is stored encrypted with util.EncryptPhone. If encryption fails
// it is stored in plain text instead; RegisterUser falls back to plain text
// when decryption fails.
//
// If the token cannot be stored, an error is returned. Handing out an
// unusable token would only fail later with a misleading "registerToken
// expired" message.
func (o *chatSvr) issueRegisterToken(ctx context.Context, subject string) (string, error) {
	token := uuid.New().String()
	payload, err := util.EncryptPhone(subject)
	if err != nil {
		log.ZError(ctx, "encrypt register token data failed, storing plain text", err)
		payload = subject
	}
	if err := o.storeRegisterTokenToRedis(ctx, token, payload); err != nil {
		return "", err
	}
	return token, nil
}
// genUserID returns a random 10-digit numeric user ID whose first digit is
// never '0'. Uniqueness is not guaranteed here; callers check the database.
func (o *chatSvr) genUserID() string {
	const (
		idLen  = 10
		digits = "0123456789"
	)
	buf := make([]byte, idLen)
	rand.Read(buf)
	for i, b := range buf {
		if i == 0 {
			// Leading digit drawn from 1-9 so the ID never starts with zero.
			buf[i] = digits[1+b%9]
			continue
		}
		buf[i] = digits[b%10]
	}
	return string(buf)
}
// genVerifyCode returns a random numeric code of o.Code.Len digits, or ""
// when the configured length is not positive.
//
// The code gates logins, registrations and captchas, so digits come from
// crypto/rand rather than math/rand: they must not be predictable from
// earlier outputs or from the process start. Bytes >= 250 are discarded
// (rejection sampling) so every digit is equally likely. A plain b%10 on a
// byte would favour 0-5, because 256 is not a multiple of 10.
func (o *chatSvr) genVerifyCode() string {
	n := int(o.Code.Len)
	if n <= 0 {
		return ""
	}
	const (
		digits = "0123456789"
		limit  = 250 // largest multiple of 10 that fits in a byte
	)
	code := make([]byte, 0, n)
	buf := make([]byte, n)
	for len(code) < n {
		if _, err := crand.Read(buf); err != nil {
			// Only happens if the OS entropy source is broken. There is no safe
			// fallback for a security code (Go 1.24+ crashes here as well).
			panic("genVerifyCode: crypto/rand unavailable: " + err.Error())
		}
		for _, b := range buf {
			if b >= limit {
				continue
			}
			code = append(code, digits[b%10])
			if len(code) == n {
				break
			}
		}
	}
	return string(code)
}
// RegisterUser creates a new user (register record, account, attribute and
// login credentials) and optionally logs the user in.
//
// Callers that are not admins (per Admin.CheckNilOrAdmin) must satisfy all of:
//   - the AllowRegister switch is on;
//   - they do not choose their own user ID;
//   - their IP passes Admin.CheckRegister;
//   - they present a valid invitation code when the
//     NeedInvitationCodeRegister config is enabled;
//   - they prove ownership of the identifier, either with a registerToken
//     (issued by VerifyCode) or with an SMS/email verification code.
//
// Admin callers skip those checks; checkRegisterInfo runs for everyone.
//
// After the user is stored:
//   - the invitation code is consumed (failure only logged);
//   - a zero-balance wallet is created when the "wallet.enabled" system config
//     is "true";
//   - handleRegisterReward is invoked;
//   - with req.AutoLogin, a chat token and a login record are created
//     (failures only logged).
func (o *chatSvr) RegisterUser(ctx context.Context, req *chat.RegisterUserReq) (*chat.RegisterUserResp, error) {
	isAdmin, err := o.Admin.CheckNilOrAdmin(ctx)
	ctx = o.WithAdminUser(ctx)
	if err != nil {
		return nil, err
	}
	if err = o.checkRegisterInfo(ctx, req.User, isAdmin); err != nil {
		return nil, err
	}
	var usedInvitationCode bool
	if !isAdmin {
		if !o.AllowRegister {
			return nil, errs.ErrNoPermission.WrapMsg("register user is disabled")
		}
		if req.User.UserID != "" {
			return nil, errs.ErrNoPermission.WrapMsg("only admin can set user id")
		}
		if err := o.Admin.CheckRegister(ctx, req.Ip); err != nil {
			return nil, err
		}
		conf, err := o.Admin.GetConfig(ctx)
		if err != nil {
			return nil, err
		}
		if val := conf[constant.NeedInvitationCodeRegisterConfigKey]; datautil.Contain(strings.ToLower(val), "1", "true", "yes") {
			usedInvitationCode = true
			if req.InvitationCode == "" {
				return nil, errs.ErrArgs.WrapMsg("invitation code is empty")
			}
			if err := o.Admin.CheckInvitationCode(ctx, req.InvitationCode); err != nil {
				return nil, err
			}
		}
		// A registerToken (H5 phone registration or account-name registration)
		// replaces the verification code check.
		if req.RegisterToken != "" {
			// The token maps to either an "areaCode phoneNumber" string or an
			// account name.
			tokenData, err := o.getRegisterTokenFromRedis(ctx, req.RegisterToken)
			if err != nil {
				return nil, errs.ErrArgs.WrapMsg("registerToken已失效请重新验证")
			}
			// The data is normally stored encrypted. If decryption fails, treat it
			// as plain text (compatibility with older or fallback entries).
			var tokenDataDecrypted string
			decryptedData, err := util.DecryptPhone(tokenData)
			if err != nil {
				tokenDataDecrypted = tokenData
			} else {
				tokenDataDecrypted = decryptedData
			}
			// The token must match the identifier being registered.
			if req.User.Account != "" {
				// Account-name registration.
				if tokenDataDecrypted != req.User.Account {
					return nil, errs.ErrArgs.WrapMsg("账号不匹配,请重新验证")
				}
			} else if req.User.PhoneNumber != "" {
				// Phone registration: compare with the same key format VerifyCode used.
				account := o.verifyCodeJoin(req.User.AreaCode, req.User.PhoneNumber)
				if tokenDataDecrypted != account {
					return nil, errs.ErrArgs.WrapMsg("手机号不匹配,请重新验证")
				}
			} else {
				return nil, errs.ErrArgs.WrapMsg("registerToken需要配合账号或手机号使用")
			}
			// Tokens are single use; a failed delete is only logged.
			// NOTE(review): the token is consumed before the user is actually
			// stored, so a later failure forces the client to verify again.
			if err := o.delRegisterTokenFromRedis(ctx, req.RegisterToken); err != nil {
				log.ZWarn(ctx, "delete register token failed", err, "token", req.RegisterToken)
			}
			// Token verified; no SMS/email code check needed.
		} else {
			// Regular registration: verify the SMS/email code. The sms_type system
			// config decides whether the client sends it in plain text.
			usePlainText, err := o.checkSMSTypeFromDB(ctx)
			if err != nil {
				log.ZWarn(ctx, "Failed to check sms_type from DB for register, will try decrypt", err)
				// Fall back to the encrypted format when the config cannot be read.
				usePlainText = false
			}
			var verifyCode string
			if usePlainText {
				// sms_type is true: the code is sent in plain text.
				verifyCode = req.VerifyCode
			} else {
				// sms_type is false or unset: the code must be decrypted.
				decryptedCode, err := util.DecryptVerifyCode(req.VerifyCode)
				if err != nil {
					log.ZError(ctx, "Failed to decrypt verify code for register", err)
					return nil, errs.ErrArgs.WrapMsg("验证码解密失败", "error", err.Error())
				}
				// The decrypted payload has the form "code-timestamp".
				code, timestamp, err := util.ParseVerifyCodeWithTimestamp(decryptedCode)
				if err != nil {
					log.ZError(ctx, "Failed to parse verify code with timestamp for register", err)
					return nil, errs.ErrArgs.WrapMsg("解析验证码失败", "error", err.Error())
				}
				// Reject stale payloads. The window passed is 30, presumably
				// seconds (the original comment said one minute) — confirm against
				// util.ValidateTimestamp.
				if !util.ValidateTimestamp(timestamp, 30) {
					return nil, errs.ErrArgs.WrapMsg("验证码已过期,请重新获取")
				}
				verifyCode = code
			}
			if req.User.Email == "" {
				if _, err := o.verifyCode(ctx, o.verifyCodeJoin(req.User.AreaCode, req.User.PhoneNumber), verifyCode, phone); err != nil {
					return nil, err
				}
			} else {
				if _, err := o.verifyCode(ctx, req.User.Email, verifyCode, mail); err != nil {
					return nil, err
				}
			}
		}
	}
	// Pick a free random user ID (up to 20 attempts) or confirm the
	// admin-supplied one is unused.
	if req.User.UserID == "" {
		for i := 0; i < 20; i++ {
			userID := o.genUserID()
			_, err := o.Database.GetUser(ctx, userID)
			if err == nil {
				continue
			} else if dbutil.IsDBNotFound(err) {
				req.User.UserID = userID
				break
			} else {
				return nil, err
			}
		}
		if req.User.UserID == "" {
			return nil, errs.ErrInternalServer.WrapMsg("gen user id failed")
		}
	} else {
		_, err := o.Database.GetUser(ctx, req.User.UserID)
		if err == nil {
			return nil, errs.ErrArgs.WrapMsg("appoint user id already register")
		} else if !dbutil.IsDBNotFound(err) {
			return nil, err
		}
	}
	// One credential per supplied identifier. registerType ends up as the last
	// one assigned: email > account > phone.
	var (
		credentials  []*chatdb.Credential
		registerType int32
	)
	if req.User.PhoneNumber != "" {
		registerType = constant.PhoneRegister
		credentials = append(credentials, &chatdb.Credential{
			UserID:      req.User.UserID,
			Account:     BuildCredentialPhone(req.User.AreaCode, req.User.PhoneNumber),
			Type:        constant.CredentialPhone,
			AllowChange: true,
		})
	}
	if req.User.Account != "" {
		credentials = append(credentials, &chatdb.Credential{
			UserID:      req.User.UserID,
			Account:     req.User.Account,
			Type:        constant.CredentialAccount,
			AllowChange: true,
		})
		registerType = constant.AccountRegister
	}
	if req.User.Email != "" {
		registerType = constant.EmailRegister
		credentials = append(credentials, &chatdb.Credential{
			UserID:      req.User.UserID,
			Account:     req.User.Email,
			Type:        constant.CredentialEmail,
			AllowChange: true,
		})
	}
	register := &chatdb.Register{
		UserID:      req.User.UserID,
		DeviceID:    req.DeviceID,
		IP:          req.Ip,
		Platform:    constantpb.PlatformID2Name[int(req.Platform)],
		AccountType: "",
		Mode:        constant.UserMode,
		CreateTime:  time.Now(),
	}
	account := &chatdb.Account{
		UserID:         req.User.UserID,
		Password:       req.User.Password,
		OperatorUserID: mcontext.GetOpUserID(ctx),
		ChangeTime:     register.CreateTime,
		CreateTime:     register.CreateTime,
	}
	attribute := &chatdb.Attribute{
		UserID:         req.User.UserID,
		Account:        req.User.Account,
		PhoneNumber:    req.User.PhoneNumber,
		AreaCode:       req.User.AreaCode,
		Email:          req.User.Email,
		Nickname:       req.User.Nickname,
		FaceURL:        req.User.FaceURL,
		Gender:         req.User.Gender,
		UserType:       req.User.UserType, // caller-supplied user type; zero value when unset
		UserFlag:       req.User.UserFlag, // caller-supplied user flag; empty when unset
		BirthTime:      time.UnixMilli(req.User.Birth),
		ChangeTime:     register.CreateTime,
		CreateTime:     register.CreateTime,
		AllowVibration: constant.DefaultAllowVibration,
		AllowBeep:      constant.DefaultAllowBeep,
		AllowAddFriend: constant.DefaultAllowAddFriend,
		RegisterType:   registerType,
	}
	if err := o.Database.RegisterUser(ctx, register, account, attribute, credentials); err != nil {
		return nil, err
	}
	// The user already exists at this point, so a failure to consume the
	// invitation code is only logged.
	if usedInvitationCode {
		if err := o.Admin.UseInvitationCode(ctx, req.User.UserID, req.InvitationCode); err != nil {
			log.ZError(ctx, "UseInvitationCode", err, "userID", req.User.UserID, "invitationCode", req.InvitationCode)
		}
	}
	// Create a default wallet when the wallet feature is enabled in the system
	// config. A lookup error simply skips this step.
	walletConfig, err := o.Database.GetSystemConfig(ctx, "wallet.enabled")
	if err == nil && walletConfig.Value == "true" {
		wallet := &chatdb.Wallet{
			UserID:     req.User.UserID,
			Balance:    0,
			CreateTime: time.Now(),
			UpdateTime: time.Now(),
		}
		if err := o.Database.CreateWallet(ctx, wallet); err != nil {
			log.ZError(ctx, "CreateWallet failed on register", err, "userID", req.User.UserID)
			// Does not fail the registration; only logged.
		} else {
			log.ZInfo(ctx, "Auto created wallet for user on register", "userID", req.User.UserID)
		}
	}
	// Grant the configured registration reward, if any (see handleRegisterReward).
	o.handleRegisterReward(ctx, req.User.UserID)
	var resp chat.RegisterUserResp
	if req.AutoLogin {
		chatToken, err := o.Admin.CreateToken(ctx, req.User.UserID, constant.NormalUser)
		if err == nil {
			resp.ChatToken = chatToken.Token
		} else {
			log.ZError(ctx, "Admin CreateToken Failed", err, "userID", req.User.UserID, "platform", req.Platform)
		}
		// Auto-login also records a login, just like Login does.
		record := &chatdb.UserLoginRecord{
			UserID:    req.User.UserID,
			LoginTime: time.Now(),
			IP:        req.Ip,
			DeviceID:  req.DeviceID,
			Platform:  constantpb.PlatformID2Name[int(req.Platform)],
		}
		if err := o.Database.LoginRecord(ctx, record, nil); err != nil {
			log.ZError(ctx, "LoginRecord failed on register auto login", err, "userID", req.User.UserID)
			// Does not fail the registration; only logged.
		} else {
			log.ZInfo(ctx, "Created login record on register auto login", "userID", req.User.UserID, "ip", req.Ip)
		}
	}
	resp.UserID = req.User.UserID
	return &resp, nil
}
// Login authenticates a user with a password or a verification code and
// returns a chat token.
//
// The credential is looked up by req.Account, else by the phone (area code
// normalised to a leading "+"), else by req.Email. Admin.CheckLogin then
// gates the user and IP; a blocked account is reported as ErrAccountBlocked.
// Password logins compare against the stored account password. Code logins
// verify the code against the email when one is given, otherwise against the
// "areaCode phoneNumber" key.
//
// On success, a login record is written and the used verification code is
// deleted; errors in either step fail the login. When the "wallet.enabled"
// system config is "true", a missing wallet is created; wallet problems are
// only logged.
func (o *chatSvr) Login(ctx context.Context, req *chat.LoginReq) (*chat.LoginResp, error) {
	if req.Password == "" && req.VerifyCode == "" {
		return nil, errs.ErrArgs.WrapMsg("password or code must be set")
	}

	var loginKey string
	switch {
	case req.Account != "":
		loginKey = req.Account
	case req.PhoneNumber != "":
		if req.AreaCode == "" {
			return nil, errs.ErrArgs.WrapMsg("area code must")
		}
		if !strings.HasPrefix(req.AreaCode, "+") {
			req.AreaCode = "+" + req.AreaCode
		}
		if _, err := strconv.ParseUint(req.AreaCode[1:], 10, 64); err != nil {
			return nil, errs.ErrArgs.WrapMsg("area code must be number")
		}
		loginKey = BuildCredentialPhone(req.AreaCode, req.PhoneNumber)
	case req.Email != "":
		loginKey = req.Email
	default:
		return nil, errs.ErrArgs.WrapMsg("account or phone number or email must be set")
	}

	credential, err := o.Database.TakeCredentialByAccount(ctx, loginKey)
	if err != nil {
		if dbutil.IsDBNotFound(err) {
			return nil, eerrs.ErrAccountNotFound.WrapMsg("user unregistered")
		}
		return nil, err
	}
	userID := credential.UserID

	if err := o.Admin.CheckLogin(ctx, userID, req.Ip); err != nil {
		// The blocked-account error may arrive wrapped, so detect it by its text
		// and return the dedicated error code explicitly.
		msg := err.Error()
		blocked := strings.Contains(msg, "20015") ||
			strings.Contains(msg, "AccountBlocked") ||
			strings.Contains(msg, "账户已被封禁")
		if !blocked {
			return nil, err
		}
		log.ZInfo(ctx, "Detected AccountBlocked error in CheckLogin, returning explicit error", "originalError", msg)
		return nil, eerrs.ErrAccountBlocked.WrapMsg("账户已被封禁")
	}

	var verifyCodeID *string
	if req.Password != "" {
		account, err := o.Database.TakeAccount(ctx, userID)
		if err != nil {
			return nil, err
		}
		if account.Password != req.Password {
			return nil, eerrs.ErrPassword.Wrap()
		}
	} else {
		codeKey, codeType := req.Email, mail
		if req.Email == "" {
			codeKey, codeType = o.verifyCodeJoin(req.AreaCode, req.PhoneNumber), phone
		}
		id, err := o.verifyCode(ctx, codeKey, req.VerifyCode, codeType)
		if err != nil {
			return nil, err
		}
		if id != "" {
			verifyCodeID = &id
		}
	}

	chatToken, err := o.Admin.CreateToken(ctx, userID, constant.NormalUser)
	if err != nil {
		return nil, err
	}
	record := &chatdb.UserLoginRecord{
		UserID:    userID,
		LoginTime: time.Now(),
		IP:        req.Ip,
		DeviceID:  req.DeviceID,
		Platform:  constantpb.PlatformIDToName(int(req.Platform)),
	}
	if err := o.Database.LoginRecord(ctx, record, verifyCodeID); err != nil {
		log.ZError(ctx, "LoginRecord failed", err, "userID", userID)
		return nil, err
	}
	if verifyCodeID != nil {
		if err := o.Database.DelVerifyCode(ctx, *verifyCodeID); err != nil {
			log.ZError(ctx, "DelVerifyCode failed", err, "verifyCodeID", *verifyCodeID)
			return nil, err
		}
	}

	// Lazily create a wallet for users who predate the wallet feature.
	if walletConfig, err := o.Database.GetSystemConfig(ctx, "wallet.enabled"); err == nil && walletConfig.Value == "true" {
		_, getErr := o.Database.GetWallet(ctx, userID)
		switch {
		case getErr == nil:
			// Wallet already exists.
		case dbutil.IsDBNotFound(getErr):
			wallet := &chatdb.Wallet{
				UserID:     userID,
				Balance:    0,
				CreateTime: time.Now(),
				UpdateTime: time.Now(),
			}
			if err := o.Database.CreateWallet(ctx, wallet); err != nil {
				// Does not fail the login.
				log.ZError(ctx, "CreateWallet failed", err, "userID", userID)
			} else {
				log.ZInfo(ctx, "Auto created wallet for user on login", "userID", userID)
			}
		default:
			log.ZError(ctx, "GetWallet failed", getErr, "userID", userID)
		}
	}

	return &chat.LoginResp{
		UserID:    userID,
		ChatToken: chatToken.Token,
	}, nil
}
// GetCaptchaImage issues a new captcha challenge.
//
// A code is produced by o.genVerifyCode and stored in Redis under a freshly
// generated UUID (via storeCaptchaToRedis). Both the ID and the code are then
// returned to the caller. req is not read. If storing fails, the error is
// logged and returned and no response is produced.
//
// NOTE(review): the plaintext code is sent back in the response. Presumably the
// client renders it as an image; confirm that exposing it here is intended.
func (o *chatSvr) GetCaptchaImage(ctx context.Context, req *chat.GetCaptchaImageReq) (*chat.GetCaptchaImageResp, error) {
	// Field order keeps the original call order: code first, then the ID.
	resp := &chat.GetCaptchaImageResp{
		Code:      o.genVerifyCode(),
		CaptchaID: uuid.New().String(),
	}
	if err := o.storeCaptchaToRedis(ctx, resp.CaptchaID, resp.Code); err != nil {
		log.ZError(ctx, "store captcha to redis failed", err, "captchaID", resp.CaptchaID)
		return nil, err
	}
	log.ZDebug(ctx, "generate captcha code success", "captchaID", resp.CaptchaID)
	return resp, nil
}
// storeCaptchaToRedis saves an image-captcha code under the key
// "captcha:<captchaID>" with a 5-minute TTL.
// It returns ErrInternalServer when no Redis client is configured.
func (o *chatSvr) storeCaptchaToRedis(ctx context.Context, captchaID, code string) error {
	if o.rdb == nil {
		return errs.ErrInternalServer.WrapMsg("redis client not initialized")
	}
	const ttl = 5 * time.Minute
	err := o.rdb.Set(ctx, "captcha:"+captchaID, code, ttl).Err()
	return errs.Wrap(err)
}
// storePhoneToRedis saves phoneNumber under the caller-supplied key with a
// 5-minute TTL.
// It returns ErrInternalServer when no Redis client is configured.
func (o *chatSvr) storePhoneToRedis(ctx context.Context, key, phoneNumber string) error {
	if o.rdb != nil {
		return errs.Wrap(o.rdb.Set(ctx, key, phoneNumber, time.Minute*5).Err())
	}
	return errs.ErrInternalServer.WrapMsg("redis client not initialized")
}
// getCaptchaFromRedis reads the image-captcha code stored under
// "captcha:<captchaID>".
//
// Any error from the Redis read is wrapped with the message
// "验证码不存在或已过期" ("captcha missing or expired").
// NOTE(review): this includes connectivity errors, not only a missing key.
// It returns ErrInternalServer when no Redis client is configured.
func (o *chatSvr) getCaptchaFromRedis(ctx context.Context, captchaID string) (string, error) {
	if o.rdb == nil {
		return "", errs.ErrInternalServer.WrapMsg("redis client not initialized")
	}
	stored, err := o.rdb.Get(ctx, "captcha:"+captchaID).Result()
	if err == nil {
		return stored, nil
	}
	return "", errs.WrapMsg(err, "验证码不存在或已过期")
}
// delCaptchaFromRedis removes the image-captcha entry "captcha:<captchaID>".
// It returns ErrInternalServer when no Redis client is configured.
func (o *chatSvr) delCaptchaFromRedis(ctx context.Context, captchaID string) error {
	if o.rdb != nil {
		delErr := o.rdb.Del(ctx, "captcha:"+captchaID).Err()
		return errs.Wrap(delErr)
	}
	return errs.ErrInternalServer.WrapMsg("redis client not initialized")
}
// getPhoneFromRedis reads the phone number stored under the caller-supplied key.
//
// Any Redis read error is wrapped with the message "手机号不存在或已过期"
// ("phone number missing or expired").
// It returns ErrInternalServer when no Redis client is configured.
func (o *chatSvr) getPhoneFromRedis(ctx context.Context, key string) (string, error) {
	if o.rdb == nil {
		return "", errs.ErrInternalServer.WrapMsg("redis client not initialized")
	}
	// Named "number" rather than "phone" to avoid shadowing the package-level
	// `phone` verifyType constant.
	number, err := o.rdb.Get(ctx, key).Result()
	if err == nil {
		return number, nil
	}
	return "", errs.WrapMsg(err, "手机号不存在或已过期")
}
// storeRegisterTokenToRedis saves the phone number associated with an H5
// registration token under "h5_register_token:<token>", valid for 120 seconds.
// It returns ErrInternalServer when no Redis client is configured.
func (o *chatSvr) storeRegisterTokenToRedis(ctx context.Context, token, phoneNumber string) error {
	if o.rdb == nil {
		return errs.ErrInternalServer.WrapMsg("redis client not initialized")
	}
	const ttl = 2 * time.Minute // 120 seconds
	setErr := o.rdb.Set(ctx, "h5_register_token:"+token, phoneNumber, ttl).Err()
	return errs.Wrap(setErr)
}
// getRegisterTokenFromRedis returns the phone number bound to an H5
// registration token (key "h5_register_token:<token>").
//
// Any Redis read error is wrapped with the message "token不存在或已过期"
// ("token missing or expired").
// It returns ErrInternalServer when no Redis client is configured.
func (o *chatSvr) getRegisterTokenFromRedis(ctx context.Context, token string) (string, error) {
	if o.rdb == nil {
		return "", errs.ErrInternalServer.WrapMsg("redis client not initialized")
	}
	boundNumber, err := o.rdb.Get(ctx, "h5_register_token:"+token).Result()
	if err == nil {
		return boundNumber, nil
	}
	return "", errs.WrapMsg(err, "token不存在或已过期")
}
// delRegisterTokenFromRedis removes an H5 registration token
// ("h5_register_token:<token>") so it can be used only once.
// It returns ErrInternalServer when no Redis client is configured.
func (o *chatSvr) delRegisterTokenFromRedis(ctx context.Context, token string) error {
	if o.rdb != nil {
		delErr := o.rdb.Del(ctx, "h5_register_token:"+token).Err()
		return errs.Wrap(delErr)
	}
	return errs.ErrInternalServer.WrapMsg("redis client not initialized")
}
// handleRegisterReward grants the sign-up bonus configured under the "register"
// system config to userID's wallet.
//
// The reward is applied only when all of the following hold:
//   - the config exists and is enabled;
//   - its value type is number or string;
//   - its value parses to a positive amount (the unit was documented as cents).
//
// The amount is added via IncrementWalletBalance, and a WalletBalanceRecord of
// type 5 (reward) with remark "注册奖励" is written. Every failure is logged and
// swallowed, so the reward can never fail the registration itself.
func (o *chatSvr) handleRegisterReward(ctx context.Context, userID string) {
	config, err := o.Database.GetSystemConfig(ctx, "register")
	if err != nil {
		// A missing config is the normal "feature off" case; anything else is worth a warning.
		if dbutil.IsDBNotFound(err) {
			log.ZDebug(ctx, "Register reward config not found", "key", "register")
		} else {
			log.ZWarn(ctx, "Get register reward config failed", err, "key", "register", "userID", userID)
		}
		return
	}
	if !config.Enabled {
		log.ZDebug(ctx, "Register reward config is disabled", "key", "register", "userID", userID)
		return
	}

	// Number and string values are parsed identically; only the log messages differ.
	var parseFailedMsg, parsedMsg string
	switch config.ValueType {
	case chatdb.ConfigValueTypeNumber:
		parseFailedMsg = "Parse register reward amount failed"
		parsedMsg = "Parsed register reward amount from number type"
	case chatdb.ConfigValueTypeString:
		parseFailedMsg = "Parse register reward amount from string failed"
		parsedMsg = "Parsed register reward amount from string type"
	default:
		log.ZWarn(ctx, "Register reward config value type not supported", nil, "valueType", config.ValueType, "userID", userID)
		return
	}
	amount, err := parseRegisterRewardAmount(config.Value)
	if err != nil {
		log.ZWarn(ctx, parseFailedMsg, err, "value", config.Value, "valueType", config.ValueType, "userID", userID)
		return
	}
	log.ZDebug(ctx, parsedMsg, "value", config.Value, "amount", amount, "userID", userID)

	if amount <= 0 {
		log.ZDebug(ctx, "Register reward amount is zero or negative", "amount", amount, "userID", userID)
		return
	}

	beforeBalance, afterBalance, err := o.Database.IncrementWalletBalance(ctx, userID, amount)
	if err != nil {
		log.ZError(ctx, "Increment wallet balance for register reward failed", err, "userID", userID, "amount", amount)
		return
	}

	// NOTE(review): the balance update and this audit record are two separate
	// writes. If the record fails, the credit stands without a history entry.
	record := &chatdb.WalletBalanceRecord{
		ID:            uuid.New().String(),
		UserID:        userID,
		Amount:        amount,
		Type:          5, // 5 = reward
		BeforeBalance: beforeBalance,
		AfterBalance:  afterBalance,
		Remark:        "注册奖励",
		CreateTime:    time.Now(),
	}
	if err := o.Database.CreateWalletBalanceRecord(ctx, record); err != nil {
		// The credit has already been applied; only warn about the missing record.
		log.ZWarn(ctx, "Create wallet balance record for register reward failed", err, "userID", userID, "amount", amount)
	} else {
		log.ZInfo(ctx, "Register reward granted successfully", "userID", userID, "amount", amount, "beforeBalance", beforeBalance, "afterBalance", afterBalance)
	}
}

// parseRegisterRewardAmount parses a configured reward amount and truncates it
// toward zero to an int64. For example, "100" -> 100, "12.9" -> 12, "-3" -> -3.
//
// NaN, ±Inf and values outside the int64 range are rejected. Go leaves the
// result of converting such floats to int64 implementation-dependent; arm64,
// for example, saturates to the maximum int64. Without this check, a typo in
// the config could turn into an enormous credit.
func parseRegisterRewardAmount(value string) (int64, error) {
	f, err := strconv.ParseFloat(value, 64)
	if err != nil {
		return 0, err
	}
	const limit = 1 << 63 // 2^63, exactly representable as float64
	// Written as a negated in-range test so NaN (which fails every comparison) is rejected too.
	if !(f >= -limit && f < limit) {
		return 0, fmt.Errorf("register reward amount %q is out of int64 range", value)
	}
	return int64(f), nil
}
// checkSMSTypeFromDB reports whether the "sms_type" system config is switched on.
//
// The existing convention is that true means verification codes are used in
// plain text, and false means they must be decrypted.
//
// Interpretation by value type:
//   - bool:   exactly "true" or "1";
//   - string: "true" in any letter case, or "1";
//   - number: any non-zero base-10 integer (unparseable values count as false).
//
// Any other value type, or a missing config, yields false with a nil error.
// Other database errors are wrapped and returned.
func (o *chatSvr) checkSMSTypeFromDB(ctx context.Context) (bool, error) {
	cfg, err := o.Database.GetSystemConfig(ctx, "sms_type")
	switch {
	case err == nil:
		// Fall through to value interpretation below.
	case dbutil.IsDBNotFound(err):
		return false, nil
	default:
		return false, errs.WrapMsg(err, "failed to get sms_type config")
	}

	raw := cfg.Value
	switch cfg.ValueType {
	case chatdb.ConfigValueTypeBool:
		return raw == "1" || raw == "true", nil
	case chatdb.ConfigValueTypeString:
		return raw == "1" || strings.ToLower(raw) == "true", nil
	case chatdb.ConfigValueTypeNumber:
		n, parseErr := strconv.ParseInt(raw, 10, 64)
		return parseErr == nil && n != 0, nil
	}
	// Unknown value types default to the encrypted (false) behaviour.
	return false, nil
}
// getSMSClientFromDB builds an SMS client from database configuration.
//
// Resolution order:
//   - The "sms_type" system config is evaluated by checkSMSTypeFromDB. If it is
//     missing, false, or of an unsupported value type, (nil, nil) is returned so
//     the caller falls back to the yml configuration.
//   - If sms_type is true, the "sms_info" JSON config selects and configures a
//     "bao" or "ali" client. The type is matched case-insensitively.
//   - A missing sms_info or an unknown type also yields (nil, nil).
//
// A non-nil error is returned only for database failures, malformed sms_info
// JSON, or client construction failures.
func (o *chatSvr) getSMSClientFromDB(ctx context.Context) (sms.SMS, error) {
	// Reuse the single sms_type parser so this function and checkSMSTypeFromDB
	// can never disagree about what "enabled" means.
	useDBSMS, err := o.checkSMSTypeFromDB(ctx)
	if err != nil {
		log.ZError(ctx, "Failed to get sms_type config from SystemConfig collection", err)
		return nil, err
	}
	if !useDBSMS {
		return nil, nil
	}

	// sms_type is on: sms_info holds the provider settings as JSON.
	smsInfoConfig, err := o.Database.GetSystemConfig(ctx, "sms_info")
	if err != nil {
		if dbutil.IsDBNotFound(err) {
			log.ZWarn(ctx, "sms_info config not found but sms_type is true, will use yml config", nil)
			return nil, nil
		}
		log.ZError(ctx, "Failed to get sms_info config from SystemConfig collection", err)
		return nil, errs.WrapMsg(err, "failed to get sms_info config")
	}

	type SMSInfo struct {
		AccessKeyID                  string `json:"accessKeyId"`
		AccessKeySecret              string `json:"accessKeySecret"`
		Endpoint                     string `json:"endpoint"`
		SignName                     string `json:"signName"`
		Type                         string `json:"type"`
		VerificationCodeTemplateCode string `json:"verificationCodeTemplateCode"`
	}
	var smsInfo SMSInfo
	if err := json.Unmarshal([]byte(smsInfoConfig.Value), &smsInfo); err != nil {
		log.ZError(ctx, "Failed to parse sms_info JSON", err)
		return nil, errs.WrapMsg(err, "failed to parse sms_info JSON")
	}

	// Build the provider-specific client selected by sms_info.type.
	switch strings.ToLower(smsInfo.Type) {
	case "bao":
		client, err := sms.NewBao(
			smsInfo.Endpoint,
			smsInfo.AccessKeyID,
			smsInfo.AccessKeySecret,
			smsInfo.SignName,
			smsInfo.VerificationCodeTemplateCode,
		)
		if err != nil {
			log.ZError(ctx, "Failed to create Bao SMS client from database config", err)
			return nil, errs.WrapMsg(err, "failed to create Bao SMS client")
		}
		log.ZInfo(ctx, "Created Bao SMS client from database config")
		return client, nil
	case "ali":
		client, err := sms.NewAli(
			smsInfo.Endpoint,
			smsInfo.AccessKeyID,
			smsInfo.AccessKeySecret,
			smsInfo.SignName,
			smsInfo.VerificationCodeTemplateCode,
		)
		if err != nil {
			log.ZError(ctx, "Failed to create Ali SMS client from database config", err)
			return nil, errs.WrapMsg(err, "failed to create Ali SMS client")
		}
		log.ZInfo(ctx, "Created Ali SMS client from database config")
		return client, nil
	default:
		log.ZWarn(ctx, "Unknown SMS type in sms_info config, using default SMS client", nil, "type", smsInfo.Type)
		return nil, nil
	}
}