// Copyright © 2023 OpenIM. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package third import ( "context" "errors" "fmt" "net/http" "net/url" "strconv" "strings" "time" "github.com/aws/aws-sdk-go-v2/aws" v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" "github.com/aws/aws-sdk-go-v2/credentials" aws3 "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/openimsdk/tools/s3" ) const ( minPartSize int64 = 1024 * 1024 * 5 // 5MB maxPartSize int64 = 1024 * 1024 * 1024 * 5 // 5GB maxNumSize int64 = 10000 ) type R2Config struct { Endpoint string Region string Bucket string AccessKeyID string SecretAccessKey string SessionToken string } // NewR2 创建支持 Cloudflare R2 的 S3 客户端 func NewR2(conf R2Config) (*R2, error) { if conf.Endpoint == "" { return nil, errors.New("endpoint is required for R2") } // 创建 HTTP 客户端,设置合理的超时 httpClient := &http.Client{ Timeout: 30 * time.Second, Transport: &http.Transport{ MaxIdleConns: 100, MaxIdleConnsPerHost: 10, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, }, } cfg := aws.Config{ Region: conf.Region, Credentials: credentials.NewStaticCredentialsProvider(conf.AccessKeyID, conf.SecretAccessKey, conf.SessionToken), HTTPClient: httpClient, } // 创建 S3 客户端,启用路径风格访问(R2 要求)并设置自定义 endpoint client := aws3.NewFromConfig(cfg, func(o *aws3.Options) { o.BaseEndpoint = aws.String(conf.Endpoint) o.UsePathStyle = true }) r2 := &R2{ bucket: conf.Bucket, client: client, presign: aws3.NewPresignClient(client), } // 测试连接:尝试列出 bucket(验证 bucket 存在且有权限),设置 5 秒超时 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() fmt.Printf("[R2] Testing connection to bucket '%s' at endpoint '%s'...\n", conf.Bucket, conf.Endpoint) _, err := client.ListObjectsV2(ctx, &aws3.ListObjectsV2Input{ Bucket: aws.String(conf.Bucket), MaxKeys: aws.Int32(1), }) if err != nil { // 详细的错误信息 var respErr *awshttp.ResponseError if errors.As(err, &respErr) { fmt.Printf("[R2] Bucket verification HTTP error:\n") fmt.Printf(" Status Code: %d\n", respErr.Response.StatusCode) fmt.Printf(" Status: %s\n", respErr.Response.Status) } fmt.Printf("[R2] Warning: failed to verify R2 bucket '%s' at endpoint '%s': %v\n", conf.Bucket, conf.Endpoint, err) fmt.Printf("[R2] Please ensure:\n") fmt.Printf(" 1. Bucket '%s' exists in your R2 account\n", conf.Bucket) fmt.Printf(" 2. API credentials have correct permissions (Object Read & Write)\n") fmt.Printf(" 3. Account ID in endpoint matches your R2 account\n") } else { fmt.Printf("[R2] Successfully connected to bucket '%s'\n", conf.Bucket) } return r2, nil } type R2 struct { bucket string client *aws3.Client presign *aws3.PresignClient } func (r *R2) Engine() string { return "aws" } func (r *R2) PartLimit() (*s3.PartLimit, error) { return &s3.PartLimit{ MinPartSize: minPartSize, MaxPartSize: maxPartSize, MaxNumSize: maxNumSize, }, nil } func (r *R2) formatETag(etag string) string { return strings.Trim(etag, `"`) } func (r *R2) PartSize(ctx context.Context, size int64) (int64, error) { if size <= 0 { return 0, errors.New("size must be greater than 0") } if size > maxPartSize*maxNumSize { return 0, fmt.Errorf("size must be less than the maximum allowed limit") } if size <= minPartSize*maxNumSize { return minPartSize, nil } partSize := size / maxNumSize if size%maxNumSize != 0 { partSize++ } return partSize, nil } func (r *R2) IsNotFound(err error) bool { var respErr *awshttp.ResponseError if !errors.As(err, &respErr) { return false } if respErr == nil || respErr.Response == nil { return false } return respErr.Response.StatusCode == http.StatusNotFound } func (r *R2) PresignedPutObject(ctx context.Context, name string, expire time.Duration, opt *s3.PutOption) (*s3.PresignedPutResult, error) { res, err := r.presign.PresignPutObject(ctx, &aws3.PutObjectInput{ Bucket: aws.String(r.bucket), Key: aws.String(name), }, aws3.WithPresignExpires(expire), withDisableHTTPPresignerHeaderV4(nil)) if err != nil { return nil, err } return &s3.PresignedPutResult{URL: res.URL}, nil } func (r *R2) DeleteObject(ctx context.Context, name string) error { _, err := r.client.DeleteObject(ctx, &aws3.DeleteObjectInput{ Bucket: aws.String(r.bucket), Key: aws.String(name), }) return err } func (r *R2) CopyObject(ctx context.Context, src string, dst string) (*s3.CopyObjectInfo, error) { res, err := r.client.CopyObject(ctx, &aws3.CopyObjectInput{ Bucket: aws.String(r.bucket), CopySource: aws.String(r.bucket + "/" + src), Key: aws.String(dst), }) if err != nil { return nil, err } if res.CopyObjectResult == nil || res.CopyObjectResult.ETag == nil || *res.CopyObjectResult.ETag == "" { return nil, errors.New("CopyObject etag is nil") } return &s3.CopyObjectInfo{ Key: dst, ETag: r.formatETag(*res.CopyObjectResult.ETag), }, nil } func (r *R2) StatObject(ctx context.Context, name string) (*s3.ObjectInfo, error) { res, err := r.client.HeadObject(ctx, &aws3.HeadObjectInput{ Bucket: aws.String(r.bucket), Key: aws.String(name), }) if err != nil { return nil, err } if res.ETag == nil || *res.ETag == "" { return nil, errors.New("GetObjectAttributes etag is nil") } if res.ContentLength == nil { return nil, errors.New("GetObjectAttributes object size is nil") } info := &s3.ObjectInfo{ ETag: r.formatETag(*res.ETag), Key: name, Size: *res.ContentLength, } if res.LastModified == nil { info.LastModified = time.Unix(0, 0) } else { info.LastModified = *res.LastModified } return info, nil } func (r *R2) InitiateMultipartUpload(ctx context.Context, name string, opt *s3.PutOption) (*s3.InitiateMultipartUploadResult, error) { startTime := time.Now() fmt.Printf("[R2] InitiateMultipartUpload start: bucket=%s, key=%s\n", r.bucket, name) input := &aws3.CreateMultipartUploadInput{ Bucket: aws.String(r.bucket), Key: aws.String(name), } // 如果提供了 ContentType,添加到请求中 if opt != nil && opt.ContentType != "" { input.ContentType = aws.String(opt.ContentType) fmt.Printf("[R2] ContentType: %s\n", opt.ContentType) } res, err := r.client.CreateMultipartUpload(ctx, input) duration := time.Since(startTime) if err != nil { // 详细错误信息 var respErr *awshttp.ResponseError if errors.As(err, &respErr) { fmt.Printf("[R2] HTTP Response Error after %v:\n", duration) fmt.Printf(" Status Code: %d\n", respErr.Response.StatusCode) fmt.Printf(" Status: %s\n", respErr.Response.Status) if respErr.Response.Body != nil { body := make([]byte, 1024) n, _ := respErr.Response.Body.Read(body) fmt.Printf(" Body: %s\n", string(body[:n])) } } fmt.Printf("[R2] InitiateMultipartUpload failed after %v: %v\n", duration, err) return nil, fmt.Errorf("CreateMultipartUpload failed (bucket=%s, key=%s): %w", r.bucket, name, err) } if res.UploadId == nil || *res.UploadId == "" { return nil, errors.New("CreateMultipartUpload upload id is nil") } fmt.Printf("[R2] InitiateMultipartUpload success after %v: uploadID=%s\n", duration, *res.UploadId) return &s3.InitiateMultipartUploadResult{ Key: name, Bucket: r.bucket, UploadID: *res.UploadId, }, nil } func (r *R2) CompleteMultipartUpload(ctx context.Context, uploadID string, name string, parts []s3.Part) (*s3.CompleteMultipartUploadResult, error) { params := &aws3.CompleteMultipartUploadInput{ Bucket: aws.String(r.bucket), Key: aws.String(name), UploadId: aws.String(uploadID), MultipartUpload: &types.CompletedMultipartUpload{ Parts: make([]types.CompletedPart, 0, len(parts)), }, } for _, part := range parts { params.MultipartUpload.Parts = append(params.MultipartUpload.Parts, types.CompletedPart{ ETag: aws.String(part.ETag), PartNumber: aws.Int32(int32(part.PartNumber)), }) } res, err := r.client.CompleteMultipartUpload(ctx, params) if err != nil { return nil, err } if res.ETag == nil || *res.ETag == "" { return nil, errors.New("CompleteMultipartUpload etag is nil") } info := &s3.CompleteMultipartUploadResult{ Key: name, Bucket: r.bucket, ETag: r.formatETag(*res.ETag), } if res.Location != nil { info.Location = *res.Location } return info, nil } func (r *R2) AbortMultipartUpload(ctx context.Context, uploadID string, name string) error { _, err := r.client.AbortMultipartUpload(ctx, &aws3.AbortMultipartUploadInput{ Bucket: aws.String(r.bucket), Key: aws.String(name), UploadId: aws.String(uploadID), }) return err } func (r *R2) ListUploadedParts(ctx context.Context, uploadID string, name string, partNumberMarker int, maxParts int) (*s3.ListUploadedPartsResult, error) { params := &aws3.ListPartsInput{ Bucket: aws.String(r.bucket), Key: aws.String(name), UploadId: aws.String(uploadID), PartNumberMarker: aws.String(strconv.Itoa(partNumberMarker)), MaxParts: aws.Int32(int32(maxParts)), } res, err := r.client.ListParts(ctx, params) if err != nil { return nil, err } info := &s3.ListUploadedPartsResult{ Key: name, UploadID: uploadID, UploadedParts: make([]s3.UploadedPart, 0, len(res.Parts)), } if res.MaxParts != nil { info.MaxParts = int(*res.MaxParts) } if res.NextPartNumberMarker != nil { info.NextPartNumberMarker, _ = strconv.Atoi(*res.NextPartNumberMarker) } for _, part := range res.Parts { var val s3.UploadedPart if part.PartNumber != nil { val.PartNumber = int(*part.PartNumber) } if part.LastModified != nil { val.LastModified = *part.LastModified } if part.Size != nil { val.Size = *part.Size } info.UploadedParts = append(info.UploadedParts, val) } return info, nil } func (r *R2) AuthSign(ctx context.Context, uploadID string, name string, expire time.Duration, partNumbers []int) (*s3.AuthSignResult, error) { res := &s3.AuthSignResult{ Parts: make([]s3.SignPart, 0, len(partNumbers)), } params := &aws3.UploadPartInput{ Bucket: aws.String(r.bucket), Key: aws.String(name), UploadId: aws.String(uploadID), } opt := aws3.WithPresignExpires(expire) for _, number := range partNumbers { params.PartNumber = aws.Int32(int32(number)) val, err := r.presign.PresignUploadPart(ctx, params, opt) if err != nil { return nil, err } u, err := url.Parse(val.URL) if err != nil { return nil, err } query := u.Query() u.RawQuery = "" urlstr := u.String() if res.URL == "" { res.URL = urlstr } if res.URL == urlstr { urlstr = "" } res.Parts = append(res.Parts, s3.SignPart{ PartNumber: number, URL: urlstr, Query: query, Header: val.SignedHeader, }) } return res, nil } func (r *R2) AccessURL(ctx context.Context, name string, expire time.Duration, opt *s3.AccessURLOption) (string, error) { params := &aws3.GetObjectInput{ Bucket: aws.String(r.bucket), Key: aws.String(name), } res, err := r.presign.PresignGetObject(ctx, params, aws3.WithPresignExpires(expire), withDisableHTTPPresignerHeaderV4(opt)) if err != nil { return "", err } return res.URL, nil } func (r *R2) FormData(ctx context.Context, name string, size int64, contentType string, duration time.Duration) (*s3.FormData, error) { return nil, errors.New("R2 does not currently support form data file uploads") } func withDisableHTTPPresignerHeaderV4(opt *s3.AccessURLOption) func(options *aws3.PresignOptions) { return func(options *aws3.PresignOptions) { options.Presigner = &disableHTTPPresignerHeaderV4{ opt: opt, presigner: options.Presigner, } } } type disableHTTPPresignerHeaderV4 struct { opt *s3.AccessURLOption presigner aws3.HTTPPresignerV4 } func (d *disableHTTPPresignerHeaderV4) PresignHTTP(ctx context.Context, credentials aws.Credentials, r *http.Request, payloadHash string, service string, region string, signingTime time.Time, optFns ...func(*v4.SignerOptions)) (url string, signedHeader http.Header, err error) { optFns = append(optFns, func(options *v4.SignerOptions) { options.DisableHeaderHoisting = true }) r.Header.Del("Amz-Sdk-Request") d.setOption(r.URL) return d.presigner.PresignHTTP(ctx, credentials, r, payloadHash, service, region, signingTime, optFns...) } func (d *disableHTTPPresignerHeaderV4) setOption(u *url.URL) { if d.opt == nil { return } query := u.Query() if d.opt.ContentType != "" { query.Set("response-content-type", d.opt.ContentType) } if d.opt.Filename != "" { query.Set("response-content-disposition", `attachment; filename*=UTF-8''`+url.PathEscape(d.opt.Filename)) } u.RawQuery = query.Encode() }