Files
scheduler-backend/internal/s3migrate/runner.go
2026-05-28 00:16:19 +08:00

221 lines
6.3 KiB
Go

package s3migrate
import (
"context"
"fmt"
"io"
"net/http"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// Settings shared by the migration routines in this file.
const (
// collectionName is the MongoDB collection whose documents are read as
// object records and updated once an object has been copied.
collectionName = "s3"
// presignExpiry is the validity period requested for every presigned GET
// and PUT URL generated during a copy.
presignExpiry = time.Hour
)
// Params describes one migration job: which records to pick up, where their
// bytes currently live, and where they should be copied to.
//
// validateParams requires every field except the two regions; Run substitutes
// "us-east-1" for an empty region.
type Params struct {
// SourceEngine is matched against the "engine" field of a record to select
// the objects to migrate.
SourceEngine string `json:"sourceEngine"`
// SourceEndpoint is the base endpoint of the S3 service objects are read from.
SourceEndpoint string `json:"sourceEndpoint"`
// SourceBucket is the bucket objects are downloaded from.
SourceBucket string `json:"sourceBucket"`
// SourceAccessKey and SourceSecretKey are the static credentials used to
// presign download requests.
SourceAccessKey string `json:"sourceAccessKey"`
SourceSecretKey string `json:"sourceSecretKey"`
// SourceRegion is the signing region for the source; optional.
SourceRegion string `json:"sourceRegion"`
// DestEngine is written to a record's "engine" field after a successful
// copy. It must differ from SourceEngine.
DestEngine string `json:"destEngine"`
// DestEndpoint is the base endpoint of the S3 service objects are written to.
DestEndpoint string `json:"destEndpoint"`
// DestBucket is the bucket objects are uploaded to.
DestBucket string `json:"destBucket"`
// DestAccessKey and DestSecretKey are the static credentials used to
// presign upload requests.
DestAccessKey string `json:"destAccessKey"`
DestSecretKey string `json:"destSecretKey"`
// DestRegion is the signing region for the destination; optional.
DestRegion string `json:"destRegion"`
}
// objectRecord is the subset of a document in the object collection that the
// migration decodes and uses.
type objectRecord struct {
// Name identifies the record; it is used together with the engine both to
// look for an existing destination record and to select the record to update.
Name string `bson:"name"`
// Engine is the storage engine the record is currently attributed to.
Engine string `bson:"engine"`
// Key is the S3 object key; the same key is used in the source and
// destination buckets.
Key string `bson:"key"`
// Size is the recorded object size. It is only used as the upload
// Content-Length when the download response does not report a positive length.
Size int64 `bson:"size"`
// ContentType is signed into the presigned upload and sent as the upload's
// Content-Type header when non-empty.
ContentType string `bson:"content_type"`
}
// LogFunc receives progress and error messages from Run. Run calls it with a
// format string containing fmt-style verbs followed by the matching arguments.
type LogFunc func(msg string, args ...any)
// Run copies every object whose record has engine == params.SourceEngine from
// the source bucket to the destination bucket and then rewrites the record's
// engine to params.DestEngine.
//
// Records that already have a counterpart (same name) under the destination
// engine are skipped. A failure on a single object (decode, existence check,
// copy, or record update) is reported through logf and counted, and the loop
// moves on to the next object; such failures do not make Run return an error.
//
// Run returns an error when params are invalid, when the initial count or
// query fails, or when cursor iteration stops because of an error (including
// cancellation of ctx). A nil logf is treated as a no-op logger.
func Run(ctx context.Context, db *mongo.Database, params Params, logf LogFunc) error {
	if err := validateParams(params); err != nil {
		return err
	}
	if logf == nil {
		logf = func(string, ...any) {}
	}
	if params.SourceRegion == "" {
		params.SourceRegion = "us-east-1"
	}
	if params.DestRegion == "" {
		params.DestRegion = "us-east-1"
	}

	srcPresign := newPresignClient(params.SourceEndpoint, params.SourceAccessKey, params.SourceSecretKey, params.SourceRegion)
	dstPresign := newPresignClient(params.DestEndpoint, params.DestAccessKey, params.DestSecretKey, params.DestRegion)

	coll := db.Collection(collectionName)
	count, err := coll.CountDocuments(ctx, bson.M{"engine": params.SourceEngine})
	if err != nil {
		return fmt.Errorf("count source objects: %w", err)
	}
	logf("source engine=%s count=%d", params.SourceEngine, count)
	if count == 0 {
		logf("no objects to migrate")
		return nil
	}

	cursor, err := coll.Find(ctx, bson.M{"engine": params.SourceEngine}, options.Find().SetBatchSize(100))
	if err != nil {
		return fmt.Errorf("find source objects: %w", err)
	}
	defer cursor.Close(ctx)

	var migrated, skipped, failed int64
	for cursor.Next(ctx) {
		var obj objectRecord
		if err := cursor.Decode(&obj); err != nil {
			logf("decode error: %v", err)
			failed++
			continue
		}

		exists, err := objectExistsInDest(ctx, coll, params.DestEngine, obj.Name)
		if err != nil {
			logf("check dest error name=%s: %v", obj.Name, err)
			failed++
			continue
		}
		if exists {
			skipped++
			continue
		}

		if err := copyObject(ctx, srcPresign, params.SourceBucket, dstPresign, params.DestBucket, obj); err != nil {
			logf("copy error name=%s: %v", obj.Name, err)
			failed++
			continue
		}

		// Flip the record to the destination engine only after the bytes have
		// been uploaded, so a failed copy leaves the record untouched.
		if _, err := coll.UpdateOne(ctx,
			bson.M{"engine": params.SourceEngine, "name": obj.Name},
			bson.M{"$set": bson.M{"engine": params.DestEngine}},
		); err != nil {
			logf("update engine error name=%s: %v", obj.Name, err)
			failed++
			continue
		}

		migrated++
		if migrated%100 == 0 {
			logf("progress: migrated=%d skipped=%d failed=%d", migrated, skipped, failed)
		}
	}

	// Next returns false both at the end of the result set and on failure
	// (network error, cancelled ctx). Without this check a partial run would
	// be reported as complete.
	if err := cursor.Err(); err != nil {
		logf("aborted: migrated=%d skipped=%d failed=%d total=%d", migrated, skipped, failed, count)
		return fmt.Errorf("iterate source objects: %w", err)
	}

	logf("complete: migrated=%d skipped=%d failed=%d total=%d", migrated, skipped, failed, count)
	return nil
}
// validateParams reports the first problem found in p, or nil when p is
// usable. Checks run in a fixed order: source engine, destination engine,
// engines being distinct, source S3 settings, destination S3 settings.
// Regions are not checked here.
func validateParams(p Params) error {
	switch {
	case p.SourceEngine == "":
		return fmt.Errorf("sourceEngine is required")
	case p.DestEngine == "":
		return fmt.Errorf("destEngine is required")
	case p.SourceEngine == p.DestEngine:
		return fmt.Errorf("sourceEngine and destEngine must be different")
	}

	// A side is complete only when all four of its settings are non-empty.
	sourceComplete := p.SourceEndpoint != "" && p.SourceBucket != "" && p.SourceAccessKey != "" && p.SourceSecretKey != ""
	if !sourceComplete {
		return fmt.Errorf("source S3 config (endpoint, bucket, accessKey, secretKey) is required")
	}

	destComplete := p.DestEndpoint != "" && p.DestBucket != "" && p.DestAccessKey != "" && p.DestSecretKey != ""
	if !destComplete {
		return fmt.Errorf("dest S3 config (endpoint, bucket, accessKey, secretKey) is required")
	}

	return nil
}
// newPresignClient builds an S3 presign client for the given endpoint using
// static credentials (no session token) and path-style bucket addressing.
func newPresignClient(endpoint, accessKey, secretKey, region string) *s3.PresignClient {
	staticCreds := credentials.NewStaticCredentialsProvider(accessKey, secretKey, "")

	// Address buckets as <endpoint>/<bucket>/<key> rather than by subdomain.
	withPathStyle := func(opts *s3.Options) {
		opts.UsePathStyle = true
	}

	s3Client := s3.NewFromConfig(aws.Config{
		Region:       region,
		Credentials:  staticCreds,
		BaseEndpoint: &endpoint,
	}, withPathStyle)

	return s3.NewPresignClient(s3Client)
}
// objectExistsInDest reports whether coll already contains a record with the
// given name under destEngine. The count is capped at one because only
// presence matters.
func objectExistsInDest(ctx context.Context, coll *mongo.Collection, destEngine, name string) (bool, error) {
	filter := bson.M{"engine": destEngine, "name": name}
	atMostOne := options.Count().SetLimit(1)

	matches, err := coll.CountDocuments(ctx, filter, atMostOne)
	if err != nil {
		return false, err
	}

	found := matches > 0
	return found, nil
}
// copyObject streams one object from srcBucket to dstBucket under the same
// key, using presigned GET and PUT URLs.
//
// The download response body is passed directly as the upload request body,
// so the object is not buffered in full. Both HTTP requests are bound to ctx
// and stop when it is cancelled.
//
// It returns an error when presigning fails, when either HTTP request fails,
// when the source responds with anything other than 200 OK, or when the
// destination responds with anything other than 200 OK.
func copyObject(ctx context.Context, srcPresign *s3.PresignClient, srcBucket string, dstPresign *s3.PresignClient, dstBucket string, obj objectRecord) error {
	getReq, err := srcPresign.PresignGetObject(ctx, &s3.GetObjectInput{
		Bucket: &srcBucket,
		Key:    &obj.Key,
	}, s3.WithPresignExpires(presignExpiry))
	if err != nil {
		return fmt.Errorf("presign get: %w", err)
	}

	contentType := obj.ContentType
	putInput := &s3.PutObjectInput{
		Bucket: &dstBucket,
		Key:    &obj.Key,
	}
	// Only sign a Content-Type when one will actually be sent on the upload,
	// so the signed headers and the request headers always agree.
	if contentType != "" {
		putInput.ContentType = &contentType
	}
	putReq, err := dstPresign.PresignPutObject(ctx, putInput, s3.WithPresignExpires(presignExpiry))
	if err != nil {
		return fmt.Errorf("presign put: %w", err)
	}

	// Build the request explicitly so the download honours ctx; http.Get
	// would ignore cancellation.
	downloadReq, err := http.NewRequestWithContext(ctx, http.MethodGet, getReq.URL, nil)
	if err != nil {
		return fmt.Errorf("create download request: %w", err)
	}
	downloadResp, err := http.DefaultClient.Do(downloadReq)
	if err != nil {
		return fmt.Errorf("download: %w", err)
	}
	defer downloadResp.Body.Close()

	if downloadResp.StatusCode == http.StatusNotFound {
		return fmt.Errorf("source object not found: %s", obj.Key)
	}
	if downloadResp.StatusCode != http.StatusOK {
		return fmt.Errorf("download status: %s", downloadResp.Status)
	}

	// net/http treats ContentLength 0 with a non-nil body as "unknown" and
	// sends a PUT chunked. When the source reports an empty object, use
	// http.NoBody so an explicit zero Content-Length is sent instead.
	var uploadBody io.Reader = downloadResp.Body
	if downloadResp.ContentLength == 0 {
		uploadBody = http.NoBody
	}

	uploadReq, err := http.NewRequestWithContext(ctx, http.MethodPut, putReq.URL, uploadBody)
	if err != nil {
		return fmt.Errorf("create upload request: %w", err)
	}
	switch {
	case downloadResp.ContentLength > 0:
		uploadReq.ContentLength = downloadResp.ContentLength
	case downloadResp.ContentLength < 0:
		// Length not reported by the source; fall back to the recorded size.
		uploadReq.ContentLength = obj.Size
	}
	if contentType != "" {
		uploadReq.Header.Set("Content-Type", contentType)
	}

	uploadResp, err := http.DefaultClient.Do(uploadReq)
	if err != nil {
		return fmt.Errorf("upload: %w", err)
	}
	// Drain before closing so the transport can reuse the connection.
	defer func() {
		_, _ = io.Copy(io.Discard, uploadResp.Body)
		uploadResp.Body.Close()
	}()

	if uploadResp.StatusCode != http.StatusOK {
		return fmt.Errorf("upload status: %s", uploadResp.Status)
	}
	return nil
}