Files
scheduler-backend/internal/urlrewrite/runner.go
2026-05-28 13:29:24 +08:00

512 lines
15 KiB
Go

// Package urlrewrite rewrites a stored URL prefix across a fixed set of
// MongoDB collections. Run can preview the changes (dry-run), apply them while
// writing one backup record per changed field, roll a batch back from those
// backup records, or delete the Redis message-cache keys of messages touched
// by a batch.
package urlrewrite
import (
"context"
"fmt"
"regexp"
"sort"
"strconv"
"strings"
"time"
"github.com/redis/go-redis/v9"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
const (
// backupCollectionName is the collection that receives one backup record
// per rewritten field and is read back by rollback and invalidate-cache.
backupCollectionName = "url_rewrite_backups"
// msgCachePrefix starts every Redis key deleted by invalidate-cache; the
// full key is msgCachePrefix + conversation ID + ":" + message seq.
msgCachePrefix = "MSG_CACHE:"
// ModeDryRun scans and reports matches without writing anything. It is
// also the mode Run falls back to when Params.Mode is empty.
ModeDryRun = "dry-run"
// ModeApply writes backup records and then updates the matched fields.
ModeApply = "apply"
// ModeRollback restores the old values recorded for a batch ID.
ModeRollback = "rollback"
// ModeInvalidateCache deletes the Redis cache keys of the messages that a
// batch rewrote.
ModeInvalidateCache = "invalidate-cache"
)
// Params describes a single invocation of Run.
type Params struct {
// OldPrefix is the URL prefix to look for. Required unless Mode is
// rollback or invalidate-cache.
OldPrefix string `json:"oldPrefix"`
// NewPrefix is the replacement prefix. Required unless Mode is rollback or
// invalidate-cache.
NewPrefix string `json:"newPrefix"`
// Mode is one of the Mode* constants; an empty value is treated as dry-run.
Mode string `json:"mode"`
// BatchID labels the backup records of an apply run. It is required for
// rollback and invalidate-cache; in apply mode Run generates a
// timestamp-based ID when it is empty.
BatchID string `json:"batchId"`
// SampleSize bounds the number of sample lines collected per summary. Run
// replaces values <= 0 with 5.
SampleSize int `json:"sampleSize"`
}
// RedisConfig tells invalidate-cache how to reach Redis.
type RedisConfig struct {
// Addr, Username and Password are handed to newRedisClient when Client is
// nil.
Addr string
Username string
Password string
// Client, when non-nil, is used as-is instead of opening a new connection;
// a caller-supplied client is not closed by this package.
Client *redis.Client
}
// Summary reports what one rewrite pass found in one collection (or one field
// of a collection).
type Summary struct {
// Collection is the label used in log lines: "<collection>.<field>" for
// simple string fields, the bare collection name for the msg collection.
Collection string
// Scanned is the number of documents returned by the query cursor.
Scanned int64
// Updated counts the scanned documents that were matched for rewriting.
Updated int64
// Samples holds human-readable lines: example rewrites and error messages.
Samples []string
}
// fieldUpdate is one pending change to a single field of a document.
type fieldUpdate struct {
// Path is the dotted field path used as the key of the "$set" update.
Path string
// OldValue is the value before the rewrite; it is stored in the backup record.
OldValue string
// NewValue is the value written to the document.
NewValue string
}
// msgDoc is the subset of a "msg" collection document that this package decodes.
type msgDoc struct {
// ID is the document's _id, kept untyped so it can be passed back to UpdateByID.
ID any `bson:"_id"`
// Doc is the doc_id string; the part before its last ":" is used as the
// conversation ID when building cache keys.
Doc string `bson:"doc_id"`
// Msgs is the embedded message array; positions in it appear in update
// paths as "msgs.<index>.msg.<field>".
Msgs []msgEntryModel `bson:"msgs"`
}
// msgEntryModel is one element of msgDoc.Msgs.
type msgEntryModel struct {
// Msg may be nil; callers skip such entries.
Msg *msgPayload `bson:"msg"`
}
// msgPayload holds the message fields that are inspected or rewritten.
type msgPayload struct {
// Content is passed to rewriteJSONContent; presumably a JSON-encoded
// message body — confirm against that helper.
Content string `bson:"content"`
// SenderFaceURL is rewritten as a plain string value.
SenderFaceURL string `bson:"sender_face_url"`
// Seq is the message sequence number used in the Redis cache key; entries
// with Seq <= 0 are skipped during cache invalidation.
Seq int64 `bson:"seq"`
}
// LogFunc receives progress and result messages. It is called with a
// printf-style format string followed by its arguments.
type LogFunc func(msg string, args ...any)
// Run executes one URL-rewrite operation against db and returns the batch ID
// it operated on.
//
// Behavior by params.Mode:
//   - "" or ModeDryRun: scan every target collection and log what would change.
//   - ModeApply: additionally write backup records and update the documents.
//     When params.BatchID is empty a timestamp-based ID is generated.
//   - ModeRollback: restore the values backed up under params.BatchID.
//   - ModeInvalidateCache: delete the Redis cache keys of the messages that
//     params.BatchID rewrote, using redisCfg.
//
// A nil logf is replaced by a no-op logger. When parameter validation fails the
// returned ID is empty. For every later failure the batch ID is returned
// together with the error: in apply mode earlier collections may already have
// been rewritten and backed up under that ID, and the caller needs it to roll
// back.
func Run(ctx context.Context, db *mongo.Database, params Params, redisCfg RedisConfig, logf LogFunc) (string, error) {
	if err := validateParams(params); err != nil {
		return "", err
	}
	if logf == nil {
		logf = func(string, ...any) {}
	}
	if params.SampleSize <= 0 {
		params.SampleSize = 5
	}
	if params.Mode == "" {
		params.Mode = ModeDryRun
	}
	if params.Mode == ModeApply && params.BatchID == "" {
		params.BatchID = "urlrewrite-" + time.Now().Format("20060102150405")
	}

	backupColl := db.Collection(backupCollectionName)
	switch params.Mode {
	case ModeRollback:
		return params.BatchID, rollbackBatch(ctx, backupColl, db, params, logf)
	case ModeInvalidateCache:
		return params.BatchID, invalidateBatchCache(ctx, db, redisCfg, params, logf)
	}

	// Plain string fields, processed in this order.
	simpleTargets := []struct{ collection, field string }{
		{"wallets", "real_name_auth.id_card_photo_front"},
		{"wallets", "real_name_auth.id_card_photo_back"},
		{"favorites", "content"},
		{"favorites", "thumbnail"},
		{"favorites", "link_url"},
		{"attributes", "face_url"},
		{"attribute", "face_url"},
		{"user", "face_url"},
		{"group", "face_url"},
		{"group_member", "face_url"},
	}
	summaries := make([]Summary, 0, len(simpleTargets)+1)
	for _, target := range simpleTargets {
		summaries = append(summaries, rewriteSimpleFieldCollection(ctx, db.Collection(target.collection), backupColl, target.field, params))
	}

	msgSummary, msgErr := rewriteMsgCollection(ctx, db.Collection("msg"), backupColl, params)
	// On failure the msg summary may still carry partial counts; an empty
	// label means the query itself failed and there is nothing to report.
	if msgErr == nil || msgSummary.Collection != "" {
		summaries = append(summaries, msgSummary)
	}

	// Log what was gathered even when the msg pass failed, so errors recorded
	// by the earlier passes are not lost.
	for _, s := range summaries {
		logf("[%s] scanned=%d updated=%d", s.Collection, s.Scanned, s.Updated)
		for _, sample := range s.Samples {
			logf("[%s] sample: %s", s.Collection, sample)
		}
	}
	if msgErr != nil {
		return params.BatchID, fmt.Errorf("rewrite msg collection: %w", msgErr)
	}

	switch params.Mode {
	case ModeDryRun:
		logf("dry-run complete, rerun with mode=apply to write changes")
	case ModeApply:
		logf("apply complete, batch_id=%s", params.BatchID)
	}
	return params.BatchID, nil
}
// validateParams checks that p carries the fields its mode needs.
//
// Rollback and invalidate-cache require a batch ID. Dry-run, apply and the
// empty mode require both prefixes. Any other mode string is rejected.
func validateParams(p Params) error {
	switch p.Mode {
	case ModeRollback, ModeInvalidateCache:
		if p.BatchID == "" {
			return fmt.Errorf("batchId is required for %s", p.Mode)
		}
	case "", ModeDryRun, ModeApply:
		if p.OldPrefix == "" || p.NewPrefix == "" {
			return fmt.Errorf("oldPrefix and newPrefix are required")
		}
	default:
		return fmt.Errorf("invalid mode: %s", p.Mode)
	}
	return nil
}
// invalidateBatchCache deletes the Redis cache keys of every message that the
// batch params.BatchID rewrote.
//
// It reads the batch's backup records for the "msg" collection, maps each
// record's field path ("msgs.<index>.msg.<field>") to the message at that
// index of the current msg document, and builds one key per message:
// msgCachePrefix + conversation ID + ":" + seq. Backup records whose
// document_id is not an ObjectID, whose path cannot be parsed, or whose index
// no longer points at a message with a positive seq are skipped.
//
// redisCfg.Client is used when set and left open; otherwise a client is created
// from Addr/Username/Password and closed before returning. An error is returned
// when neither a client nor an address is available, when reading the backup
// records or msg documents fails, or when Redis rejects the delete.
func invalidateBatchCache(ctx context.Context, db *mongo.Database, redisCfg RedisConfig, params Params, logf LogFunc) error {
	// An address is only needed when we have to dial ourselves.
	if redisCfg.Client == nil && redisCfg.Addr == "" {
		return fmt.Errorf("REDIS_ADDR is not configured")
	}

	backupColl := db.Collection(backupCollectionName)
	cursor, err := backupColl.Find(ctx, bson.M{
		"batch_id":   params.BatchID,
		"collection": "msg",
		"field":      bson.M{"$regex": primitive.Regex{Pattern: `^msgs\.\d+\.msg\.`}},
	})
	if err != nil {
		return fmt.Errorf("find backup docs: %w", err)
	}
	defer cursor.Close(ctx)

	type backupDoc struct {
		DocumentID any    `bson:"document_id"`
		Field      string `bson:"field"`
	}

	// msg document ID -> set of indexes into its msgs array that were rewritten.
	docIndexes := make(map[primitive.ObjectID]map[int]struct{})
	for cursor.Next(ctx) {
		var item backupDoc
		if err := cursor.Decode(&item); err != nil {
			return fmt.Errorf("decode backup doc: %w", err)
		}
		docID, ok := item.DocumentID.(primitive.ObjectID)
		if !ok {
			continue
		}
		msgIndex, ok := parseMsgFieldIndex(item.Field)
		if !ok {
			continue
		}
		indexes, exists := docIndexes[docID]
		if !exists {
			indexes = make(map[int]struct{})
			docIndexes[docID] = indexes
		}
		indexes[msgIndex] = struct{}{}
	}
	// Next returns false on failure as well as on exhaustion; without this
	// check a broken cursor would silently invalidate only part of the batch.
	if err := cursor.Err(); err != nil {
		return fmt.Errorf("iterate backup docs: %w", err)
	}

	msgColl := db.Collection("msg")
	var keys []string
	for docID, indexes := range docIndexes {
		var doc msgDoc
		if err := msgColl.FindOne(ctx, bson.M{"_id": docID}).Decode(&doc); err != nil {
			return fmt.Errorf("find msg doc %s: %w", docID.Hex(), err)
		}
		conversationID := trimDocIDSuffix(doc.Doc)
		if conversationID == "" {
			continue
		}
		for idx := range indexes {
			if idx < 0 || idx >= len(doc.Msgs) {
				continue
			}
			payload := doc.Msgs[idx].Msg
			if payload == nil || payload.Seq <= 0 {
				continue
			}
			// FormatInt keeps the full int64; converting to int first would
			// truncate on 32-bit platforms.
			keys = append(keys, msgCachePrefix+conversationID+":"+strconv.FormatInt(payload.Seq, 10))
		}
	}

	keys = deduplicateStrings(keys)
	if len(keys) == 0 {
		logf("invalidate-cache complete, batch_id=%s keys=0", params.BatchID)
		return nil
	}

	rdb := redisCfg.Client
	if rdb == nil {
		rdb, err = newRedisClient(ctx, redisCfg.Addr, redisCfg.Username, redisCfg.Password)
		if err != nil {
			return fmt.Errorf("connect redis: %w", err)
		}
		// Only close the client this function created.
		defer rdb.Close()
	}
	if err := rdb.Del(ctx, keys...).Err(); err != nil {
		return fmt.Errorf("redis del: %w", err)
	}
	logf("invalidate-cache complete, batch_id=%s keys=%d", params.BatchID, len(keys))
	return nil
}
// rewriteSimpleFieldCollection rewrites one plain string field of coll whose
// value starts with params.OldPrefix and reports the outcome.
//
// The returned Summary is labelled "<collection>.<field>". Scanned counts the
// documents returned by the query. Updated counts documents that would be
// rewritten (any mode other than apply) or that were rewritten successfully
// (apply). In apply mode a backup record is inserted into backupColl before
// the document is updated; if the backup fails the document is left untouched.
//
// The function never returns an error: failures are reported through
// Summary.Samples. Up to params.SampleSize example rewrites are recorded, and
// errors have their own budget of the same size (at least one) so they are not
// crowded out by examples; when more errors occur a final line gives the total.
func rewriteSimpleFieldCollection(ctx context.Context, coll, backupColl *mongo.Collection, field string, params Params) Summary {
	summary := Summary{Collection: coll.Name() + "." + field}

	filter := bson.M{field: bson.M{"$regex": primitive.Regex{Pattern: "^" + regexp.QuoteMeta(params.OldPrefix)}}}
	cursor, err := coll.Find(ctx, filter)
	if err != nil {
		summary.Samples = []string{fmt.Sprintf("find error: %v", err)}
		return summary
	}
	defer cursor.Close(ctx)

	errorBudget := params.SampleSize
	if errorBudget < 1 {
		errorBudget = 1
	}
	var (
		examples int // example rewrites recorded so far
		failures int // errors seen so far, reported or not
	)
	recordFailure := func(kind string, err error) {
		failures++
		if failures <= errorBudget {
			summary.Samples = append(summary.Samples, fmt.Sprintf("%s error: %v", kind, err))
		}
	}

	for cursor.Next(ctx) {
		summary.Scanned++
		var doc bson.M
		if err := cursor.Decode(&doc); err != nil {
			recordFailure("decode", err)
			continue
		}
		original, ok := nestedString(doc, field)
		if !ok {
			continue
		}
		rewritten, changed := rewriteStringValue(original, params.OldPrefix, params.NewPrefix)
		if !changed {
			continue
		}
		if examples < params.SampleSize {
			examples++
			summary.Samples = append(summary.Samples, fmt.Sprintf("%s -> %s", original, rewritten))
		}
		if params.Mode != ModeApply {
			summary.Updated++
			continue
		}

		updates := []fieldUpdate{{Path: field, OldValue: original, NewValue: rewritten}}
		// Back up first: without a backup record the change could not be rolled back.
		if err := insertBackupDocs(ctx, backupColl, params.BatchID, coll.Name(), doc["_id"], updates); err != nil {
			recordFailure("backup", err)
			continue
		}
		if _, err := coll.UpdateByID(ctx, doc["_id"], bson.M{"$set": bson.M{field: rewritten}}); err != nil {
			recordFailure("update", err)
			continue
		}
		summary.Updated++
	}
	// Next returns false on failure as well as on exhaustion.
	if err := cursor.Err(); err != nil {
		summary.Samples = append(summary.Samples, fmt.Sprintf("cursor error: %v", err))
	}
	if failures > errorBudget {
		summary.Samples = append(summary.Samples, fmt.Sprintf("%d errors in total, first %d shown", failures, errorBudget))
	}
	return summary
}
// rewriteMsgCollection rewrites params.OldPrefix inside the embedded messages
// of coll: anywhere in msgs.msg.content and at the start of
// msgs.msg.sender_face_url.
//
// The returned Summary is labelled with the collection name. Scanned counts
// the documents returned by the query. Updated counts documents that would be
// rewritten (any mode other than apply) or that were rewritten successfully
// (apply). In apply mode backup records are inserted into backupColl before a
// document is updated; a failed backup leaves that document untouched and is
// reported as a sample.
//
// Up to params.SampleSize example lines are recorded; decode and backup errors
// have their own budget of the same size (at least one), with a final line
// giving the total when more occurred.
//
// An error is returned when the query fails, when a document update fails, or
// when the cursor stops because of an error. In the last two cases the Summary
// gathered so far is returned alongside it.
func rewriteMsgCollection(ctx context.Context, coll, backupColl *mongo.Collection, params Params) (Summary, error) {
	quotedPrefix := regexp.QuoteMeta(params.OldPrefix)
	cursor, err := coll.Find(ctx, bson.M{
		"$or": bson.A{
			bson.M{"msgs.msg.content": bson.M{"$regex": primitive.Regex{Pattern: quotedPrefix}}},
			bson.M{"msgs.msg.sender_face_url": bson.M{"$regex": primitive.Regex{Pattern: "^" + quotedPrefix}}},
		},
	})
	if err != nil {
		return Summary{}, fmt.Errorf("find msg docs: %w", err)
	}
	defer cursor.Close(ctx)

	summary := Summary{Collection: coll.Name()}
	errorBudget := params.SampleSize
	if errorBudget < 1 {
		errorBudget = 1
	}
	var (
		examples int // example lines recorded so far
		failures int // non-fatal errors seen so far, reported or not
	)
	recordFailure := func(kind string, err error) {
		failures++
		if failures <= errorBudget {
			summary.Samples = append(summary.Samples, fmt.Sprintf("%s error: %v", kind, err))
		}
	}
	finish := func() {
		if failures > errorBudget {
			summary.Samples = append(summary.Samples, fmt.Sprintf("%d errors in total, first %d shown", failures, errorBudget))
		}
	}

	for cursor.Next(ctx) {
		summary.Scanned++
		var doc msgDoc
		if err := cursor.Decode(&doc); err != nil {
			recordFailure("decode", err)
			continue
		}
		// Only the remaining example budget is offered, so error lines do not
		// reduce the number of examples shown.
		updates, changed, samples := collectMsgFieldUpdates(doc, params.OldPrefix, params.NewPrefix, params.SampleSize-examples)
		if !changed {
			continue
		}
		examples += len(samples)
		summary.Samples = append(summary.Samples, samples...)
		if params.Mode != ModeApply {
			summary.Updated++
			continue
		}

		// Back up first: without backup records the change could not be rolled back.
		if err := insertBackupDocs(ctx, backupColl, params.BatchID, coll.Name(), doc.ID, updates); err != nil {
			recordFailure("backup", err)
			continue
		}
		if _, err := coll.UpdateByID(ctx, doc.ID, bson.M{"$set": toUpdateMap(updates)}); err != nil {
			finish()
			return summary, fmt.Errorf("update msg doc: %w", err)
		}
		summary.Updated++
	}
	finish()
	// Next returns false on failure as well as on exhaustion.
	if err := cursor.Err(); err != nil {
		return summary, fmt.Errorf("iterate msg docs: %w", err)
	}
	return summary, nil
}
// collectMsgFieldUpdates walks the messages of doc and gathers the field
// updates needed to replace oldPrefix with newPrefix.
//
// For each non-nil message it considers a non-empty content (via
// rewriteJSONContent; a content that fails to rewrite is skipped) and a
// non-empty sender_face_url (via rewriteStringValue). It returns the updates,
// whether there is at least one, and up to sampleBudget description lines.
func collectMsgFieldUpdates(doc msgDoc, oldPrefix, newPrefix string, sampleBudget int) ([]fieldUpdate, bool, []string) {
	var (
		pending []fieldUpdate
		notes   []string
	)
	for idx, entry := range doc.Msgs {
		payload := entry.Msg
		if payload == nil {
			continue
		}

		if body := payload.Content; body != "" {
			if replaced, ok, err := rewriteJSONContent(body, oldPrefix, newPrefix); err == nil && ok {
				pending = append(pending, fieldUpdate{
					Path:     fmt.Sprintf("msgs.%d.msg.content", idx),
					OldValue: body,
					NewValue: replaced,
				})
				if len(notes) < sampleBudget {
					notes = append(notes, fmt.Sprintf("doc=%s msgs.%d.msg.content updated", doc.Doc, idx))
				}
			}
		}

		if face := payload.SenderFaceURL; face != "" {
			if replaced, ok := rewriteStringValue(face, oldPrefix, newPrefix); ok {
				pending = append(pending, fieldUpdate{
					Path:     fmt.Sprintf("msgs.%d.msg.sender_face_url", idx),
					OldValue: face,
					NewValue: replaced,
				})
				if len(notes) < sampleBudget {
					notes = append(notes, fmt.Sprintf("doc=%s msgs.%d.msg.sender_face_url updated", doc.Doc, idx))
				}
			}
		}
	}
	return pending, len(pending) > 0, notes
}
// toUpdateMap turns updates into the operand of a "$set" update: field path ->
// new value. When a path occurs more than once the last entry wins.
func toUpdateMap(updates []fieldUpdate) bson.M {
	assignments := make(bson.M, len(updates))
	for i := range updates {
		assignments[updates[i].Path] = updates[i].NewValue
	}
	return assignments
}
// buildBackupDocs creates one backup record per update, all stamped with the
// same creation time, ready to be passed to InsertMany.
func buildBackupDocs(batchID, collection string, documentID any, updates []fieldUpdate) []any {
	createdAt := time.Now()
	records := make([]any, len(updates))
	for i := range updates {
		records[i] = bson.M{
			"batch_id":    batchID,
			"collection":  collection,
			"document_id": documentID,
			"field":       updates[i].Path,
			"old_value":   updates[i].OldValue,
			"new_value":   updates[i].NewValue,
			"created_at":  createdAt,
		}
	}
	return records
}
// insertBackupDocs stores the backup records for updates in backupColl.
// It does nothing and returns nil when backupColl is nil or updates is empty.
func insertBackupDocs(ctx context.Context, backupColl *mongo.Collection, batchID, collection string, documentID any, updates []fieldUpdate) error {
	if backupColl == nil {
		return nil
	}
	if len(updates) == 0 {
		return nil
	}
	records := buildBackupDocs(batchID, collection, documentID, updates)
	if _, err := backupColl.InsertMany(ctx, records); err != nil {
		return err
	}
	return nil
}
// rollbackBatch restores every field recorded in backupColl under
// params.BatchID to its backed-up old value.
//
// Backup records are read sorted by collection, document and field, grouped
// per document, and applied with one "$set" update per document. Collections
// are processed in name order and documents in the order the cursor returned
// them.
//
// An error is returned when the backup records cannot be read completely, when
// the batch has no records, or when an update fails. Updates applied before a
// failure are not undone.
func rollbackBatch(ctx context.Context, backupColl *mongo.Collection, db *mongo.Database, params Params, logf LogFunc) error {
	findOpts := options.Find().SetSort(bson.D{
		{Key: "collection", Value: 1},
		{Key: "document_id", Value: 1},
		{Key: "field", Value: 1},
	})
	cursor, err := backupColl.Find(ctx, bson.M{"batch_id": params.BatchID}, findOpts)
	if err != nil {
		return fmt.Errorf("find backup docs: %w", err)
	}
	defer cursor.Close(ctx)

	type rollbackDoc struct {
		Collection string `bson:"collection"`
		Field      string `bson:"field"`
		OldValue   string `bson:"old_value"`
		DocumentID any    `bson:"document_id"`
	}
	// docRestore collects the fields to restore on one document.
	type docRestore struct {
		documentID any
		updates    []fieldUpdate
	}

	perCollection := map[string][]*docRestore{}
	// The decoded document_id is not used as a map key directly: an ID that
	// decodes to an unhashable Go value (for example an embedded document)
	// would panic. A string built from its type and value is used instead.
	byKey := map[string]*docRestore{}
	for cursor.Next(ctx) {
		var rec rollbackDoc
		if err := cursor.Decode(&rec); err != nil {
			return fmt.Errorf("decode backup doc: %w", err)
		}
		key := fmt.Sprintf("%s\x00%T\x00%v", rec.Collection, rec.DocumentID, rec.DocumentID)
		entry, seen := byKey[key]
		if !seen {
			entry = &docRestore{documentID: rec.DocumentID}
			byKey[key] = entry
			perCollection[rec.Collection] = append(perCollection[rec.Collection], entry)
		}
		entry.updates = append(entry.updates, fieldUpdate{
			Path:     rec.Field,
			NewValue: rec.OldValue,
		})
	}
	// Next returns false on failure as well as on exhaustion; rolling back from
	// a partial read would report success while leaving fields unrestored.
	if err := cursor.Err(); err != nil {
		return fmt.Errorf("iterate backup docs: %w", err)
	}
	if len(perCollection) == 0 {
		return fmt.Errorf("no backup records found for batch_id %s", params.BatchID)
	}

	collectionNames := make([]string, 0, len(perCollection))
	for name := range perCollection {
		collectionNames = append(collectionNames, name)
	}
	sort.Strings(collectionNames)

	for _, name := range collectionNames {
		coll := db.Collection(name)
		for _, entry := range perCollection[name] {
			if _, err := coll.UpdateByID(ctx, entry.documentID, bson.M{"$set": toUpdateMap(entry.updates)}); err != nil {
				return fmt.Errorf("rollback %s: %w", name, err)
			}
		}
	}
	logf("rollback complete, batch_id=%s", params.BatchID)
	return nil
}
// parseMsgFieldIndex extracts the array index from a field path of the form
// "msgs.<index>.<a>.<b>...". It reports false when the path has fewer than four
// dot-separated segments, does not start with "msgs", or has a second segment
// that is not an integer. The sign of the index is not checked here.
func parseMsgFieldIndex(path string) (int, bool) {
	// Four segments are enough to validate the shape; anything after the third
	// dot stays in the last segment.
	segments := strings.SplitN(path, ".", 4)
	if len(segments) != 4 {
		return 0, false
	}
	if segments[0] != "msgs" {
		return 0, false
	}
	if n, err := strconv.Atoi(segments[1]); err == nil {
		return n, true
	}
	return 0, false
}
// trimDocIDSuffix returns docID without its last ":"-separated segment. The
// input is returned unchanged when it contains no ":" or when its only ":" is
// the first character.
func trimDocIDSuffix(docID string) string {
	if cut := strings.LastIndex(docID, ":"); cut > 0 {
		return docID[:cut]
	}
	return docID
}
// deduplicateStrings returns the distinct elements of values in order of first
// appearance. The result is always a non-nil slice, even for nil input.
func deduplicateStrings(values []string) []string {
	unique := make([]string, 0, len(values))
	visited := make(map[string]struct{}, len(values))
	for i := range values {
		if _, dup := visited[values[i]]; !dup {
			visited[values[i]] = struct{}{}
			unique = append(unique, values[i])
		}
	}
	return unique
}
// nestedString follows the dot-separated path through nested maps of doc and
// returns the string found at the end. It reports false when an intermediate
// value is neither a bson.M nor a map[string]any, or when the final value is
// not a string (including a missing key).
func nestedString(doc bson.M, path string) (string, bool) {
	var node any = doc
	for _, key := range strings.Split(path, ".") {
		if m, ok := node.(bson.M); ok {
			node = m[key]
			continue
		}
		if m, ok := node.(map[string]any); ok {
			node = m[key]
			continue
		}
		return "", false
	}
	text, isString := node.(string)
	return text, isString
}