package urlrewrite import ( "context" "fmt" "regexp" "sort" "strconv" "strings" "time" "github.com/redis/go-redis/v9" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) const ( backupCollectionName = "url_rewrite_backups" msgCachePrefix = "MSG_CACHE:" ModeDryRun = "dry-run" ModeApply = "apply" ModeRollback = "rollback" ModeInvalidateCache = "invalidate-cache" ) type Params struct { OldPrefix string `json:"oldPrefix"` NewPrefix string `json:"newPrefix"` Mode string `json:"mode"` BatchID string `json:"batchId"` SampleSize int `json:"sampleSize"` } type RedisConfig struct { Addr string Username string Password string Client *redis.Client } type Summary struct { Collection string Scanned int64 Updated int64 Samples []string } type fieldUpdate struct { Path string OldValue string NewValue string } type msgDoc struct { ID any `bson:"_id"` Doc string `bson:"doc_id"` Msgs []msgEntryModel `bson:"msgs"` } type msgEntryModel struct { Msg *msgPayload `bson:"msg"` } type msgPayload struct { Content string `bson:"content"` SenderFaceURL string `bson:"sender_face_url"` Seq int64 `bson:"seq"` } type LogFunc func(msg string, args ...any) func Run(ctx context.Context, db *mongo.Database, params Params, redisCfg RedisConfig, logf LogFunc) (string, error) { if err := validateParams(params); err != nil { return "", err } if params.SampleSize <= 0 { params.SampleSize = 5 } if params.Mode == "" { params.Mode = ModeDryRun } if params.Mode == ModeApply && params.BatchID == "" { params.BatchID = "urlrewrite-" + time.Now().Format("20060102150405") } backupColl := db.Collection(backupCollectionName) if params.Mode == ModeRollback { return params.BatchID, rollbackBatch(ctx, backupColl, db, params, logf) } if params.Mode == ModeInvalidateCache { return params.BatchID, invalidateBatchCache(ctx, db, redisCfg, params, logf) } summaries := []Summary{ rewriteSimpleFieldCollection(ctx, db.Collection("wallets"), backupColl, "real_name_auth.id_card_photo_front", params), rewriteSimpleFieldCollection(ctx, db.Collection("wallets"), backupColl, "real_name_auth.id_card_photo_back", params), rewriteSimpleFieldCollection(ctx, db.Collection("favorites"), backupColl, "content", params), rewriteSimpleFieldCollection(ctx, db.Collection("favorites"), backupColl, "thumbnail", params), rewriteSimpleFieldCollection(ctx, db.Collection("favorites"), backupColl, "link_url", params), rewriteSimpleFieldCollection(ctx, db.Collection("attributes"), backupColl, "face_url", params), rewriteSimpleFieldCollection(ctx, db.Collection("attribute"), backupColl, "face_url", params), rewriteSimpleFieldCollection(ctx, db.Collection("user"), backupColl, "face_url", params), rewriteSimpleFieldCollection(ctx, db.Collection("group"), backupColl, "face_url", params), rewriteSimpleFieldCollection(ctx, db.Collection("group_member"), backupColl, "face_url", params), } msgSummary, err := rewriteMsgCollection(ctx, db.Collection("msg"), backupColl, params) if err != nil { return "", fmt.Errorf("rewrite msg collection: %w", err) } summaries = append(summaries, msgSummary) for _, s := range summaries { logf("[%s] scanned=%d updated=%d", s.Collection, s.Scanned, s.Updated) for _, sample := range s.Samples { logf("[%s] sample: %s", s.Collection, sample) } } switch params.Mode { case ModeDryRun: logf("dry-run complete, rerun with mode=apply to write changes") case ModeApply: logf("apply complete, batch_id=%s", params.BatchID) } return params.BatchID, nil } func validateParams(p Params) error { switch p.Mode { case ModeDryRun, ModeApply, ModeRollback, ModeInvalidateCache: case "": default: return fmt.Errorf("invalid mode: %s", p.Mode) } if p.Mode == ModeRollback || p.Mode == ModeInvalidateCache { if p.BatchID == "" { return fmt.Errorf("batchId is required for %s", p.Mode) } } else { if p.OldPrefix == "" || p.NewPrefix == "" { return fmt.Errorf("oldPrefix and newPrefix are required") } } return nil } func invalidateBatchCache(ctx context.Context, db *mongo.Database, redisCfg RedisConfig, params Params, logf LogFunc) error { if redisCfg.Addr == "" { return fmt.Errorf("REDIS_ADDR is not configured") } backupColl := db.Collection(backupCollectionName) cursor, err := backupColl.Find(ctx, bson.M{ "batch_id": params.BatchID, "collection": "msg", "field": bson.M{"$regex": primitive.Regex{Pattern: `^msgs\.\d+\.msg\.`}}, }) if err != nil { return fmt.Errorf("find backup docs: %w", err) } defer cursor.Close(ctx) type backupDoc struct { DocumentID any `bson:"document_id"` Field string `bson:"field"` } docIndexes := make(map[string]map[int]struct{}) for cursor.Next(ctx) { var item backupDoc if err := cursor.Decode(&item); err != nil { return fmt.Errorf("decode backup doc: %w", err) } docID, ok := item.DocumentID.(primitive.ObjectID) if !ok { continue } msgIndex, ok := parseMsgFieldIndex(item.Field) if !ok { continue } id := docID.Hex() if _, exists := docIndexes[id]; !exists { docIndexes[id] = make(map[int]struct{}) } docIndexes[id][msgIndex] = struct{}{} } msgColl := db.Collection("msg") var keys []string for hexID, indexes := range docIndexes { docObjectID, err := primitive.ObjectIDFromHex(hexID) if err != nil { return fmt.Errorf("parse object id %s: %w", hexID, err) } var doc msgDoc if err := msgColl.FindOne(ctx, bson.M{"_id": docObjectID}).Decode(&doc); err != nil { return fmt.Errorf("find msg doc %s: %w", hexID, err) } conversationID := trimDocIDSuffix(doc.Doc) if conversationID == "" { continue } for idx := range indexes { if idx < 0 || idx >= len(doc.Msgs) || doc.Msgs[idx].Msg == nil || doc.Msgs[idx].Msg.Seq <= 0 { continue } keys = append(keys, msgCachePrefix+conversationID+":"+strconv.Itoa(int(doc.Msgs[idx].Msg.Seq))) } } keys = deduplicateStrings(keys) if len(keys) == 0 { logf("invalidate-cache complete, batch_id=%s keys=0", params.BatchID) return nil } rdb := redisCfg.Client shouldClose := false if rdb == nil { var err error rdb, err = newRedisClient(ctx, redisCfg.Addr, redisCfg.Username, redisCfg.Password) if err != nil { return fmt.Errorf("connect redis: %w", err) } shouldClose = true } if shouldClose { defer rdb.Close() } if err := rdb.Del(ctx, keys...).Err(); err != nil { return fmt.Errorf("redis del: %w", err) } logf("invalidate-cache complete, batch_id=%s keys=%d", params.BatchID, len(keys)) return nil } func rewriteSimpleFieldCollection(ctx context.Context, coll, backupColl *mongo.Collection, field string, params Params) Summary { filter := bson.M{field: bson.M{"$regex": primitive.Regex{Pattern: "^" + regexp.QuoteMeta(params.OldPrefix)}}} cursor, err := coll.Find(ctx, filter) if err != nil { return Summary{Collection: coll.Name() + "." + field, Samples: []string{fmt.Sprintf("find error: %v", err)}} } defer cursor.Close(ctx) summary := Summary{Collection: coll.Name() + "." + field} for cursor.Next(ctx) { summary.Scanned++ var doc bson.M if err := cursor.Decode(&doc); err != nil { if len(summary.Samples) < params.SampleSize { summary.Samples = append(summary.Samples, fmt.Sprintf("decode error: %v", err)) } continue } original, ok := nestedString(doc, field) if !ok { continue } rewritten, changed := rewriteStringValue(original, params.OldPrefix, params.NewPrefix) if !changed { continue } updates := []fieldUpdate{{ Path: field, OldValue: original, NewValue: rewritten, }} summary.Updated++ if len(summary.Samples) < params.SampleSize { summary.Samples = append(summary.Samples, fmt.Sprintf("%s -> %s", original, rewritten)) } if params.Mode != ModeApply { continue } if err := insertBackupDocs(ctx, backupColl, params.BatchID, coll.Name(), doc["_id"], updates); err != nil { if len(summary.Samples) < params.SampleSize { summary.Samples = append(summary.Samples, fmt.Sprintf("backup error: %v", err)) } continue } if _, err := coll.UpdateByID(ctx, doc["_id"], bson.M{"$set": bson.M{field: rewritten}}); err != nil { if len(summary.Samples) < params.SampleSize { summary.Samples = append(summary.Samples, fmt.Sprintf("update error: %v", err)) } } } return summary } func rewriteMsgCollection(ctx context.Context, coll, backupColl *mongo.Collection, params Params) (Summary, error) { cursor, err := coll.Find(ctx, bson.M{ "$or": bson.A{ bson.M{"msgs.msg.content": bson.M{"$regex": primitive.Regex{Pattern: regexp.QuoteMeta(params.OldPrefix)}}}, bson.M{"msgs.msg.sender_face_url": bson.M{"$regex": primitive.Regex{Pattern: "^" + regexp.QuoteMeta(params.OldPrefix)}}}, }, }) if err != nil { return Summary{}, fmt.Errorf("find msg docs: %w", err) } defer cursor.Close(ctx) summary := Summary{Collection: coll.Name()} for cursor.Next(ctx) { summary.Scanned++ var doc msgDoc if err := cursor.Decode(&doc); err != nil { if len(summary.Samples) < params.SampleSize { summary.Samples = append(summary.Samples, fmt.Sprintf("decode error: %v", err)) } continue } updates, changed, samples := collectMsgFieldUpdates(doc, params.OldPrefix, params.NewPrefix, params.SampleSize-len(summary.Samples)) if !changed { continue } summary.Updated++ summary.Samples = append(summary.Samples, samples...) if params.Mode != ModeApply { continue } if err := insertBackupDocs(ctx, backupColl, params.BatchID, coll.Name(), doc.ID, updates); err != nil { if len(summary.Samples) < params.SampleSize { summary.Samples = append(summary.Samples, fmt.Sprintf("backup error: %v", err)) } continue } if _, err := coll.UpdateByID(ctx, doc.ID, bson.M{"$set": toUpdateMap(updates)}); err != nil { return summary, fmt.Errorf("update msg doc: %w", err) } } return summary, nil } func collectMsgFieldUpdates(doc msgDoc, oldPrefix, newPrefix string, sampleBudget int) ([]fieldUpdate, bool, []string) { var updates []fieldUpdate var samples []string for i := range doc.Msgs { if doc.Msgs[i].Msg == nil { continue } if doc.Msgs[i].Msg.Content != "" { rewritten, contentChanged, err := rewriteJSONContent(doc.Msgs[i].Msg.Content, oldPrefix, newPrefix) if err == nil && contentChanged { updates = append(updates, fieldUpdate{ Path: fmt.Sprintf("msgs.%d.msg.content", i), OldValue: doc.Msgs[i].Msg.Content, NewValue: rewritten, }) if len(samples) < sampleBudget { samples = append(samples, fmt.Sprintf("doc=%s msgs.%d.msg.content updated", doc.Doc, i)) } } } if doc.Msgs[i].Msg.SenderFaceURL != "" { rewritten, faceChanged := rewriteStringValue(doc.Msgs[i].Msg.SenderFaceURL, oldPrefix, newPrefix) if faceChanged { updates = append(updates, fieldUpdate{ Path: fmt.Sprintf("msgs.%d.msg.sender_face_url", i), OldValue: doc.Msgs[i].Msg.SenderFaceURL, NewValue: rewritten, }) if len(samples) < sampleBudget { samples = append(samples, fmt.Sprintf("doc=%s msgs.%d.msg.sender_face_url updated", doc.Doc, i)) } } } } return updates, len(updates) > 0, samples } func toUpdateMap(updates []fieldUpdate) bson.M { sets := bson.M{} for _, u := range updates { sets[u.Path] = u.NewValue } return sets } func buildBackupDocs(batchID, collection string, documentID any, updates []fieldUpdate) []any { now := time.Now() docs := make([]any, 0, len(updates)) for _, u := range updates { docs = append(docs, bson.M{ "batch_id": batchID, "collection": collection, "document_id": documentID, "field": u.Path, "old_value": u.OldValue, "new_value": u.NewValue, "created_at": now, }) } return docs } func insertBackupDocs(ctx context.Context, backupColl *mongo.Collection, batchID, collection string, documentID any, updates []fieldUpdate) error { if backupColl == nil || len(updates) == 0 { return nil } _, err := backupColl.InsertMany(ctx, buildBackupDocs(batchID, collection, documentID, updates)) return err } func rollbackBatch(ctx context.Context, backupColl *mongo.Collection, db *mongo.Database, params Params, logf LogFunc) error { cursor, err := backupColl.Find(ctx, bson.M{"batch_id": params.BatchID}, options.Find().SetSort(bson.D{ {Key: "collection", Value: 1}, {Key: "document_id", Value: 1}, {Key: "field", Value: 1}, })) if err != nil { return fmt.Errorf("find backup docs: %w", err) } defer cursor.Close(ctx) type rollbackDoc struct { Collection string `bson:"collection"` Field string `bson:"field"` OldValue string `bson:"old_value"` DocumentID any `bson:"document_id"` } grouped := map[string]map[any][]fieldUpdate{} for cursor.Next(ctx) { var doc rollbackDoc if err := cursor.Decode(&doc); err != nil { return fmt.Errorf("decode backup doc: %w", err) } if _, ok := grouped[doc.Collection]; !ok { grouped[doc.Collection] = map[any][]fieldUpdate{} } grouped[doc.Collection][doc.DocumentID] = append(grouped[doc.Collection][doc.DocumentID], fieldUpdate{ Path: doc.Field, NewValue: doc.OldValue, }) } if len(grouped) == 0 { return fmt.Errorf("no backup records found for batch_id %s", params.BatchID) } var collectionNames []string for name := range grouped { collectionNames = append(collectionNames, name) } sort.Strings(collectionNames) for _, name := range collectionNames { coll := db.Collection(name) for documentID, updates := range grouped[name] { if _, err := coll.UpdateByID(ctx, documentID, bson.M{"$set": toUpdateMap(updates)}); err != nil { return fmt.Errorf("rollback %s: %w", name, err) } } } logf("rollback complete, batch_id=%s", params.BatchID) return nil } func parseMsgFieldIndex(path string) (int, bool) { parts := strings.Split(path, ".") if len(parts) < 4 || parts[0] != "msgs" { return 0, false } idx, err := strconv.Atoi(parts[1]) if err != nil { return 0, false } return idx, true } func trimDocIDSuffix(docID string) string { pos := strings.LastIndex(docID, ":") if pos <= 0 { return docID } return docID[:pos] } func deduplicateStrings(values []string) []string { seen := make(map[string]struct{}, len(values)) result := make([]string, 0, len(values)) for _, v := range values { if _, ok := seen[v]; ok { continue } seen[v] = struct{}{} result = append(result, v) } return result } func nestedString(doc bson.M, path string) (string, bool) { current := any(doc) for _, segment := range strings.Split(path, ".") { switch m := current.(type) { case bson.M: current = m[segment] case map[string]any: current = m[segment] default: return "", false } } val, ok := current.(string) return val, ok }