289 lines
8.0 KiB
Go
289 lines
8.0 KiB
Go
// Copyright © 2023 OpenIM open source community. All rights reserved.
|
||
//
|
||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||
// you may not use this file except in compliance with the License.
|
||
// You may obtain a copy of the License at
|
||
//
|
||
// http://www.apache.org/licenses/LICENSE-2.0
|
||
//
|
||
// Unless required by applicable law or agreed to in writing, software
|
||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
// See the License for the specific language governing permissions and
|
||
// limitations under the License.
|
||
|
||
package chat
|
||
|
||
import (
|
||
"context"
|
||
"time"
|
||
|
||
chatdb "git.imall.cloud/openim/chat/pkg/common/db/table/chat"
|
||
"git.imall.cloud/openim/chat/pkg/common/mctx"
|
||
"git.imall.cloud/openim/chat/pkg/protocol/chat"
|
||
"github.com/openimsdk/tools/errs"
|
||
)
|
||
|
||
// ==================== Scheduled-task RPCs ====================
|
||
|
||
// CreateScheduledTask 创建定时任务
|
||
func (o *chatSvr) CreateScheduledTask(ctx context.Context, req *chat.CreateScheduledTaskReq) (*chat.CreateScheduledTaskResp, error) {
|
||
// 获取当前用户ID
|
||
userID, _, err := mctx.Check(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 验证必填字段
|
||
if req.Name == "" {
|
||
return nil, errs.ErrArgs.WrapMsg("task name is required")
|
||
}
|
||
if req.CronExpression == "" {
|
||
return nil, errs.ErrArgs.WrapMsg("cron expression is required")
|
||
}
|
||
if len(req.Messages) == 0 {
|
||
return nil, errs.ErrArgs.WrapMsg("messages is required")
|
||
}
|
||
if len(req.RecvIDs) == 0 && len(req.GroupIDs) == 0 {
|
||
return nil, errs.ErrArgs.WrapMsg("recvIDs or groupIDs is required")
|
||
}
|
||
|
||
// 验证消息类型
|
||
for _, msg := range req.Messages {
|
||
if msg.Type < 1 || msg.Type > 3 {
|
||
return nil, errs.ErrArgs.WrapMsg("invalid message type")
|
||
}
|
||
}
|
||
|
||
// 转换消息列表
|
||
messages := make([]chatdb.Message, 0, len(req.Messages))
|
||
for _, msg := range req.Messages {
|
||
messages = append(messages, chatdb.Message{
|
||
Type: msg.Type,
|
||
Content: msg.Content,
|
||
Thumbnail: msg.Thumbnail,
|
||
Duration: msg.Duration,
|
||
FileSize: msg.FileSize,
|
||
Width: msg.Width,
|
||
Height: msg.Height,
|
||
})
|
||
}
|
||
|
||
// 创建定时任务对象
|
||
task := &chatdb.ScheduledTask{
|
||
UserID: userID,
|
||
Name: req.Name,
|
||
CronExpression: req.CronExpression,
|
||
Messages: messages,
|
||
RecvIDs: req.RecvIDs,
|
||
GroupIDs: req.GroupIDs,
|
||
Status: req.Status,
|
||
CreateTime: time.Now(),
|
||
UpdateTime: time.Now(),
|
||
}
|
||
|
||
// 如果状态未设置,默认为启用
|
||
if task.Status == 0 {
|
||
task.Status = 1
|
||
}
|
||
|
||
// 保存到数据库
|
||
if err := o.Database.CreateScheduledTask(ctx, task); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return &chat.CreateScheduledTaskResp{
|
||
TaskID: task.ID,
|
||
}, nil
|
||
}
|
||
|
||
// GetScheduledTask 获取定时任务详情
|
||
func (o *chatSvr) GetScheduledTask(ctx context.Context, req *chat.GetScheduledTaskReq) (*chat.GetScheduledTaskResp, error) {
|
||
// 获取当前用户ID
|
||
userID, _, err := mctx.Check(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 获取任务
|
||
task, err := o.Database.GetScheduledTask(ctx, req.TaskID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 验证是否为当前用户的任务
|
||
if task.UserID != userID {
|
||
return nil, errs.ErrNoPermission.WrapMsg("not your task")
|
||
}
|
||
|
||
return &chat.GetScheduledTaskResp{
|
||
Task: convertScheduledTaskToProto(task),
|
||
}, nil
|
||
}
|
||
|
||
// GetScheduledTasks 获取定时任务列表
|
||
func (o *chatSvr) GetScheduledTasks(ctx context.Context, req *chat.GetScheduledTasksReq) (*chat.GetScheduledTasksResp, error) {
|
||
// 获取当前用户ID
|
||
userID, _, err := mctx.Check(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 获取任务列表
|
||
total, tasks, err := o.Database.GetScheduledTasksByUserID(ctx, userID, req.Pagination)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 转换为响应格式
|
||
taskInfos := make([]*chat.ScheduledTaskInfo, 0, len(tasks))
|
||
for _, task := range tasks {
|
||
taskInfos = append(taskInfos, convertScheduledTaskToProto(task))
|
||
}
|
||
|
||
return &chat.GetScheduledTasksResp{
|
||
Total: uint32(total),
|
||
Tasks: taskInfos,
|
||
}, nil
|
||
}
|
||
|
||
// UpdateScheduledTask 更新定时任务
|
||
func (o *chatSvr) UpdateScheduledTask(ctx context.Context, req *chat.UpdateScheduledTaskReq) (*chat.UpdateScheduledTaskResp, error) {
|
||
// 获取当前用户ID
|
||
userID, _, err := mctx.Check(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 获取任务,验证所有权
|
||
task, err := o.Database.GetScheduledTask(ctx, req.TaskID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if task.UserID != userID {
|
||
return nil, errs.ErrNoPermission.WrapMsg("not your task")
|
||
}
|
||
|
||
// 构建更新数据
|
||
updateData := make(map[string]any)
|
||
if req.Name != "" {
|
||
updateData["name"] = req.Name
|
||
}
|
||
if req.CronExpression != "" {
|
||
updateData["cron_expression"] = req.CronExpression
|
||
}
|
||
if len(req.Messages) > 0 {
|
||
// 验证消息类型
|
||
for _, msg := range req.Messages {
|
||
if msg.Type < 1 || msg.Type > 3 {
|
||
return nil, errs.ErrArgs.WrapMsg("invalid message type")
|
||
}
|
||
}
|
||
// 转换消息列表
|
||
messages := make([]chatdb.Message, 0, len(req.Messages))
|
||
for _, msg := range req.Messages {
|
||
messages = append(messages, chatdb.Message{
|
||
Type: msg.Type,
|
||
Content: msg.Content,
|
||
Thumbnail: msg.Thumbnail,
|
||
Duration: msg.Duration,
|
||
FileSize: msg.FileSize,
|
||
Width: msg.Width,
|
||
Height: msg.Height,
|
||
})
|
||
}
|
||
updateData["messages"] = messages
|
||
}
|
||
if req.RecvIDs != nil {
|
||
updateData["recv_ids"] = req.RecvIDs
|
||
}
|
||
if req.GroupIDs != nil {
|
||
updateData["group_ids"] = req.GroupIDs
|
||
}
|
||
// status字段:0-已禁用,1-已启用,允许设置为0
|
||
if req.Status == 0 || req.Status == 1 {
|
||
updateData["status"] = req.Status
|
||
}
|
||
|
||
// 验证:如果更新后没有接收者,返回错误
|
||
if req.RecvIDs != nil && req.GroupIDs != nil {
|
||
if len(req.RecvIDs) == 0 && len(req.GroupIDs) == 0 {
|
||
return nil, errs.ErrArgs.WrapMsg("recvIDs or groupIDs is required")
|
||
}
|
||
} else if req.RecvIDs != nil && len(req.RecvIDs) == 0 {
|
||
// 如果只更新了RecvIDs且为空,检查GroupIDs是否也为空
|
||
if len(task.GroupIDs) == 0 {
|
||
return nil, errs.ErrArgs.WrapMsg("recvIDs or groupIDs is required")
|
||
}
|
||
} else if req.GroupIDs != nil && len(req.GroupIDs) == 0 {
|
||
// 如果只更新了GroupIDs且为空,检查RecvIDs是否也为空
|
||
if len(task.RecvIDs) == 0 {
|
||
return nil, errs.ErrArgs.WrapMsg("recvIDs or groupIDs is required")
|
||
}
|
||
}
|
||
|
||
// 更新任务
|
||
if err := o.Database.UpdateScheduledTask(ctx, req.TaskID, updateData); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return &chat.UpdateScheduledTaskResp{}, nil
|
||
}
|
||
|
||
// DeleteScheduledTask 删除定时任务
|
||
func (o *chatSvr) DeleteScheduledTask(ctx context.Context, req *chat.DeleteScheduledTaskReq) (*chat.DeleteScheduledTaskResp, error) {
|
||
// 获取当前用户ID
|
||
userID, _, err := mctx.Check(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 验证所有权(批量验证)
|
||
for _, taskID := range req.TaskIDs {
|
||
task, err := o.Database.GetScheduledTask(ctx, taskID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if task.UserID != userID {
|
||
return nil, errs.ErrNoPermission.WrapMsg("not your task")
|
||
}
|
||
}
|
||
|
||
// 删除任务
|
||
if err := o.Database.DeleteScheduledTask(ctx, req.TaskIDs); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return &chat.DeleteScheduledTaskResp{}, nil
|
||
}
|
||
|
||
// convertScheduledTaskToProto 将数据库模型转换为 protobuf 消息
|
||
func convertScheduledTaskToProto(task *chatdb.ScheduledTask) *chat.ScheduledTaskInfo {
|
||
messages := make([]*chat.ScheduledTaskMessage, 0, len(task.Messages))
|
||
for _, msg := range task.Messages {
|
||
messages = append(messages, &chat.ScheduledTaskMessage{
|
||
Type: msg.Type,
|
||
Content: msg.Content,
|
||
Thumbnail: msg.Thumbnail,
|
||
Duration: msg.Duration,
|
||
FileSize: msg.FileSize,
|
||
Width: msg.Width,
|
||
Height: msg.Height,
|
||
})
|
||
}
|
||
|
||
return &chat.ScheduledTaskInfo{
|
||
Id: task.ID,
|
||
UserID: task.UserID,
|
||
Name: task.Name,
|
||
CronExpression: task.CronExpression,
|
||
Messages: messages,
|
||
RecvIDs: task.RecvIDs,
|
||
GroupIDs: task.GroupIDs,
|
||
Status: task.Status,
|
||
CreateTime: task.CreateTime.UnixMilli(),
|
||
UpdateTime: task.UpdateTime.UnixMilli(),
|
||
}
|
||
}
|