Files
kim.dev.6789 e50142a3b9 复制项目
2026-01-14 22:16:44 +08:00

447 lines
13 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright © 2023 OpenIM. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package third
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
"github.com/aws/aws-sdk-go-v2/credentials"
aws3 "github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/openimsdk/tools/s3"
)
// Multipart upload limits reported by PartLimit and used by PartSize:
// 5 MiB minimum part size, 5 GiB maximum part size, at most 10,000 parts.
const (
	minPartSize int64 = 5 << 20 // 5MB
	maxPartSize int64 = 5 << 30 // 5GB
	maxNumSize  int64 = 10_000
)
// R2Config carries the connection settings consumed by NewR2.
type R2Config struct {
	// Endpoint is the S3 API base URL; NewR2 returns an error when it is empty.
	// Region is used for the SDK configuration; Bucket is the bucket that
	// every R2 operation targets.
	Endpoint, Region, Bucket string
	// Static credentials handed to credentials.NewStaticCredentialsProvider.
	AccessKeyID, SecretAccessKey, SessionToken string
}
// NewR2 creates an S3 API client for a Cloudflare R2 bucket.
//
// conf.Endpoint is required; an empty value yields an error. An empty
// conf.Region defaults to "auto", the region value Cloudflare documents for
// R2 (the AWS SDK needs a non-empty region to sign requests and resolve the
// endpoint). Requests use path-style addressing and a 30-second overall
// per-request timeout.
//
// Before returning, NewR2 performs a best-effort ListObjectsV2 call
// (MaxKeys=1, 5-second timeout) against conf.Bucket and prints the outcome to
// stdout. A failed check is only reported; it does not make NewR2 fail.
func NewR2(conf R2Config) (*R2, error) {
	if conf.Endpoint == "" {
		return nil, errors.New("endpoint is required for R2")
	}
	region := conf.Region
	if region == "" {
		region = "auto"
	}
	httpClient := &http.Client{
		// Whole-request deadline (connect, headers and body). It also bounds
		// long-running server-side calls such as CopyObject on large objects.
		Timeout:   30 * time.Second,
		Transport: newR2HTTPTransport(),
	}
	cfg := aws.Config{
		Region:      region,
		Credentials: credentials.NewStaticCredentialsProvider(conf.AccessKeyID, conf.SecretAccessKey, conf.SessionToken),
		HTTPClient:  httpClient,
	}
	client := aws3.NewFromConfig(cfg, func(o *aws3.Options) {
		o.BaseEndpoint = aws.String(conf.Endpoint)
		// Path-style addressing: the bucket goes in the URL path, not the host.
		o.UsePathStyle = true
	})
	r2 := &R2{
		bucket:  conf.Bucket,
		client:  client,
		presign: aws3.NewPresignClient(client),
	}
	checkR2Bucket(client, conf)
	return r2, nil
}

// newR2HTTPTransport returns the transport used by NewR2. It starts from a
// clone of http.DefaultTransport so proxy environment variables, dial
// timeouts and HTTP/2 negotiation keep working, then applies the pool limits.
func newR2HTTPTransport() *http.Transport {
	var tr *http.Transport
	if def, ok := http.DefaultTransport.(*http.Transport); ok {
		tr = def.Clone()
	} else {
		// DefaultTransport was replaced by a non-*http.Transport; fall back to
		// a fresh transport that still honours proxy environment variables.
		tr = &http.Transport{Proxy: http.ProxyFromEnvironment}
	}
	tr.MaxIdleConns = 100
	tr.MaxIdleConnsPerHost = 10
	tr.IdleConnTimeout = 90 * time.Second
	tr.TLSHandshakeTimeout = 10 * time.Second
	return tr
}

// checkR2Bucket lists at most one object in conf.Bucket (5-second timeout)
// and prints whether the bucket was reachable. It never returns an error.
func checkR2Bucket(client *aws3.Client, conf R2Config) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	fmt.Printf("[R2] Testing connection to bucket '%s' at endpoint '%s'...\n", conf.Bucket, conf.Endpoint)
	_, err := client.ListObjectsV2(ctx, &aws3.ListObjectsV2Input{
		Bucket:  aws.String(conf.Bucket),
		MaxKeys: aws.Int32(1),
	})
	if err == nil {
		fmt.Printf("[R2] Successfully connected to bucket '%s'\n", conf.Bucket)
		return
	}
	var respErr *awshttp.ResponseError
	// Guard every embedded pointer: a nil response must not panic an error path.
	if errors.As(err, &respErr) && respErr.ResponseError != nil &&
		respErr.Response != nil && respErr.Response.Response != nil {
		fmt.Printf("[R2] Bucket verification HTTP error:\n")
		fmt.Printf(" Status Code: %d\n", respErr.Response.StatusCode)
		fmt.Printf(" Status: %s\n", respErr.Response.Status)
	}
	fmt.Printf("[R2] Warning: failed to verify R2 bucket '%s' at endpoint '%s': %v\n", conf.Bucket, conf.Endpoint, err)
	fmt.Printf("[R2] Please ensure:\n")
	fmt.Printf(" 1. Bucket '%s' exists in your R2 account\n", conf.Bucket)
	fmt.Printf(" 2. API credentials have correct permissions (Object Read & Write)\n")
	fmt.Printf(" 3. Account ID in endpoint matches your R2 account\n")
}
// R2 is the object-storage backend for a Cloudflare R2 bucket, driven through
// the AWS S3 SDK. Values are built by NewR2.
type R2 struct {
	bucket  string              // bucket name used by every operation
	client  *aws3.Client        // S3 API client for direct requests
	presign *aws3.PresignClient // presigner created from client, for signed URLs
}
// Engine returns the engine identifier of this backend, which is "aws".
// NOTE(review): presumably "aws" because R2 is accessed via the AWS S3 SDK;
// callers may store or compare this value, so confirm before changing it.
func (r *R2) Engine() string {
	const engine = "aws"
	return engine
}
// PartLimit reports the multipart limits defined by the package constants
// minPartSize, maxPartSize and maxNumSize. It never returns an error.
func (r *R2) PartLimit() (*s3.PartLimit, error) {
	limit := s3.PartLimit{
		MinPartSize: minPartSize,
		MaxNumSize:  maxNumSize,
		MaxPartSize: maxPartSize,
	}
	return &limit, nil
}
// formatETag removes any leading and trailing double-quote characters from
// an ETag value as returned by the S3 API.
func (r *R2) formatETag(etag string) string {
	const quote = "\""
	return strings.Trim(etag, quote)
}
// PartSize picks the part size for a multipart upload of size bytes.
//
// Objects that fit into maxNumSize parts of minPartSize use minPartSize.
// Larger objects use the smallest size that still fits in maxNumSize parts.
// An error is returned for non-positive sizes and for sizes above
// maxPartSize*maxNumSize.
func (r *R2) PartSize(ctx context.Context, size int64) (int64, error) {
	switch {
	case size <= 0:
		return 0, errors.New("size must be greater than 0")
	case size > maxPartSize*maxNumSize:
		return 0, errors.New("size must be less than the maximum allowed limit")
	case size <= minPartSize*maxNumSize:
		return minPartSize, nil
	}
	// Ceiling division; size is bounded above, so this cannot overflow.
	return (size + maxNumSize - 1) / maxNumSize, nil
}
// IsNotFound reports whether err (possibly wrapped) is an SDK HTTP response
// error whose status code is 404. Any other error, including nil, yields false.
func (r *R2) IsNotFound(err error) bool {
	var respErr *awshttp.ResponseError
	if errors.As(err, &respErr) && respErr != nil && respErr.Response != nil {
		return respErr.Response.StatusCode == http.StatusNotFound
	}
	return false
}
// PresignedPutObject returns a presigned PUT URL for object name that is
// valid for expire. opt is currently ignored.
func (r *R2) PresignedPutObject(ctx context.Context, name string, expire time.Duration, opt *s3.PutOption) (*s3.PresignedPutResult, error) {
	input := &aws3.PutObjectInput{Bucket: aws.String(r.bucket), Key: aws.String(name)}
	presigned, err := r.presign.PresignPutObject(ctx, input,
		aws3.WithPresignExpires(expire),
		withDisableHTTPPresignerHeaderV4(nil), // no response-* query overrides
	)
	if err != nil {
		return nil, err
	}
	return &s3.PresignedPutResult{URL: presigned.URL}, nil
}
// DeleteObject removes object name from the configured bucket and returns
// the SDK error, if any.
func (r *R2) DeleteObject(ctx context.Context, name string) error {
	input := &aws3.DeleteObjectInput{Bucket: aws.String(r.bucket), Key: aws.String(name)}
	if _, err := r.client.DeleteObject(ctx, input); err != nil {
		return err
	}
	return nil
}
// CopyObject copies object src to dst within the configured bucket.
//
// The S3 API requires the x-amz-copy-source value to be URL-encoded, so the
// source key is percent-encoded segment by segment. Otherwise keys with
// spaces, '+', '%', '?', '#' or non-ASCII characters would address the wrong
// object or be rejected. An error is returned if the copy fails or the
// response carries no ETag. The returned ETag has its quotes stripped.
func (r *R2) CopyObject(ctx context.Context, src string, dst string) (*s3.CopyObjectInfo, error) {
	res, err := r.client.CopyObject(ctx, &aws3.CopyObjectInput{
		Bucket:     aws.String(r.bucket),
		CopySource: aws.String(r.bucket + "/" + escapeR2CopySourceKey(src)),
		Key:        aws.String(dst),
	})
	if err != nil {
		return nil, err
	}
	if res.CopyObjectResult == nil || res.CopyObjectResult.ETag == nil || *res.CopyObjectResult.ETag == "" {
		return nil, errors.New("CopyObject etag is nil")
	}
	return &s3.CopyObjectInfo{
		Key:  dst,
		ETag: r.formatETag(*res.CopyObjectResult.ETag),
	}, nil
}

// escapeR2CopySourceKey percent-encodes every "/"-separated segment of key,
// keeping the slashes, for use in the copy-source header. Spaces become %20
// and a literal '+' becomes %2B, so neither is ambiguous for the server.
func escapeR2CopySourceKey(key string) string {
	segments := strings.Split(key, "/")
	for i, seg := range segments {
		// QueryEscape encodes everything but unreserved characters; it emits
		// '+' only for spaces (a real '+' is already %2B), so map '+' to %20.
		segments[i] = strings.ReplaceAll(url.QueryEscape(seg), "+", "%20")
	}
	return strings.Join(segments, "/")
}
// StatObject returns metadata for object name using a HeadObject request.
//
// The ETag in the result has its quotes stripped. An error is returned when
// the request fails or when the response lacks an ETag or a Content-Length.
// A missing Last-Modified is reported as the Unix epoch.
func (r *R2) StatObject(ctx context.Context, name string) (*s3.ObjectInfo, error) {
	res, err := r.client.HeadObject(ctx, &aws3.HeadObjectInput{
		Bucket: aws.String(r.bucket),
		Key:    aws.String(name),
	})
	if err != nil {
		return nil, err
	}
	// The messages name the call actually made (HeadObject).
	if res.ETag == nil || *res.ETag == "" {
		return nil, errors.New("HeadObject etag is nil")
	}
	if res.ContentLength == nil {
		return nil, errors.New("HeadObject object size is nil")
	}
	info := &s3.ObjectInfo{
		ETag: r.formatETag(*res.ETag),
		Key:  name,
		Size: *res.ContentLength,
	}
	if res.LastModified == nil {
		info.LastModified = time.Unix(0, 0)
	} else {
		info.LastModified = *res.LastModified
	}
	return info, nil
}
// InitiateMultipartUpload starts a multipart upload for object name in the
// configured bucket and returns the upload ID assigned by the server.
//
// When opt is non-nil and opt.ContentType is set, it is sent as the object's
// Content-Type. Timing and failures are printed to stdout. Errors from
// CreateMultipartUpload are wrapped with the bucket and key, and a response
// without an upload ID is reported as an error.
func (r *R2) InitiateMultipartUpload(ctx context.Context, name string, opt *s3.PutOption) (*s3.InitiateMultipartUploadResult, error) {
	startTime := time.Now()
	fmt.Printf("[R2] InitiateMultipartUpload start: bucket=%s, key=%s\n", r.bucket, name)
	input := &aws3.CreateMultipartUploadInput{
		Bucket: aws.String(r.bucket),
		Key:    aws.String(name),
	}
	if opt != nil && opt.ContentType != "" {
		input.ContentType = aws.String(opt.ContentType)
		fmt.Printf("[R2] ContentType: %s\n", opt.ContentType)
	}
	res, err := r.client.CreateMultipartUpload(ctx, input)
	duration := time.Since(startTime)
	if err != nil {
		var respErr *awshttp.ResponseError
		// Guard the embedded pointers so diagnostics cannot panic.
		if errors.As(err, &respErr) && respErr.ResponseError != nil &&
			respErr.Response != nil && respErr.Response.Response != nil {
			fmt.Printf("[R2] HTTP Response Error after %v:\n", duration)
			fmt.Printf(" Status Code: %d\n", respErr.Response.StatusCode)
			fmt.Printf(" Status: %s\n", respErr.Response.Status)
			// The SDK has already read and closed the response body by the
			// time the error is returned; the service error code/message are
			// part of err itself. The request ID is what helps trace the call.
			fmt.Printf(" Request ID: %s\n", respErr.RequestID)
		}
		fmt.Printf("[R2] InitiateMultipartUpload failed after %v: %v\n", duration, err)
		return nil, fmt.Errorf("CreateMultipartUpload failed (bucket=%s, key=%s): %w", r.bucket, name, err)
	}
	if res.UploadId == nil || *res.UploadId == "" {
		return nil, errors.New("CreateMultipartUpload upload id is nil")
	}
	fmt.Printf("[R2] InitiateMultipartUpload success after %v: uploadID=%s\n", duration, *res.UploadId)
	return &s3.InitiateMultipartUploadResult{
		Key:      name,
		Bucket:   r.bucket,
		UploadID: *res.UploadId,
	}, nil
}
// CompleteMultipartUpload finishes upload uploadID of object name using the
// given parts, sent in the order supplied by the caller.
//
// An error is returned if the request fails or the response has no ETag.
// The result's ETag has its quotes stripped; Location is set when present.
func (r *R2) CompleteMultipartUpload(ctx context.Context, uploadID string, name string, parts []s3.Part) (*s3.CompleteMultipartUploadResult, error) {
	completed := make([]types.CompletedPart, len(parts))
	for i, p := range parts {
		completed[i] = types.CompletedPart{
			ETag:       aws.String(p.ETag),
			PartNumber: aws.Int32(int32(p.PartNumber)),
		}
	}
	out, err := r.client.CompleteMultipartUpload(ctx, &aws3.CompleteMultipartUploadInput{
		Bucket:          aws.String(r.bucket),
		Key:             aws.String(name),
		UploadId:        aws.String(uploadID),
		MultipartUpload: &types.CompletedMultipartUpload{Parts: completed},
	})
	if err != nil {
		return nil, err
	}
	if out.ETag == nil || *out.ETag == "" {
		return nil, errors.New("CompleteMultipartUpload etag is nil")
	}
	result := &s3.CompleteMultipartUploadResult{
		Key:    name,
		Bucket: r.bucket,
		ETag:   r.formatETag(*out.ETag),
	}
	if loc := out.Location; loc != nil {
		result.Location = *loc
	}
	return result, nil
}
// AbortMultipartUpload cancels multipart upload uploadID of object name and
// returns the SDK error, if any.
func (r *R2) AbortMultipartUpload(ctx context.Context, uploadID string, name string) error {
	input := &aws3.AbortMultipartUploadInput{
		Bucket:   aws.String(r.bucket),
		Key:      aws.String(name),
		UploadId: aws.String(uploadID),
	}
	if _, err := r.client.AbortMultipartUpload(ctx, input); err != nil {
		return err
	}
	return nil
}
// ListUploadedParts lists up to maxParts parts of upload uploadID for object
// name, starting after part number partNumberMarker.
//
// Each returned part carries its number, size, last-modified time and ETag
// (quotes stripped, matching the other ETags returned by this backend).
// Fields missing from the response are left at their zero values.
func (r *R2) ListUploadedParts(ctx context.Context, uploadID string, name string, partNumberMarker int, maxParts int) (*s3.ListUploadedPartsResult, error) {
	params := &aws3.ListPartsInput{
		Bucket:           aws.String(r.bucket),
		Key:              aws.String(name),
		UploadId:         aws.String(uploadID),
		PartNumberMarker: aws.String(strconv.Itoa(partNumberMarker)),
		MaxParts:         aws.Int32(int32(maxParts)),
	}
	res, err := r.client.ListParts(ctx, params)
	if err != nil {
		return nil, err
	}
	info := &s3.ListUploadedPartsResult{
		Key:           name,
		UploadID:      uploadID,
		UploadedParts: make([]s3.UploadedPart, 0, len(res.Parts)),
	}
	if res.MaxParts != nil {
		info.MaxParts = int(*res.MaxParts)
	}
	if res.NextPartNumberMarker != nil {
		// The marker is a decimal string in the API; a value that does not
		// parse leaves NextPartNumberMarker at 0 (best effort, by design).
		info.NextPartNumberMarker, _ = strconv.Atoi(*res.NextPartNumberMarker)
	}
	for _, part := range res.Parts {
		var val s3.UploadedPart
		if part.PartNumber != nil {
			val.PartNumber = int(*part.PartNumber)
		}
		if part.LastModified != nil {
			val.LastModified = *part.LastModified
		}
		if part.ETag != nil {
			val.ETag = r.formatETag(*part.ETag)
		}
		if part.Size != nil {
			val.Size = *part.Size
		}
		info.UploadedParts = append(info.UploadedParts, val)
	}
	return info, nil
}
// AuthSign presigns an UploadPart request for each entry of partNumbers in
// upload uploadID of object name, each valid for expire.
//
// Every presigned URL is split into its query (returned per part) and its
// query-less URL. The first query-less URL becomes the result's shared URL.
// A part's own URL is left empty when it matches the shared one and is set
// only when it differs. The signed headers of each part are returned as well.
func (r *R2) AuthSign(ctx context.Context, uploadID string, name string, expire time.Duration, partNumbers []int) (*s3.AuthSignResult, error) {
	input := &aws3.UploadPartInput{
		Bucket:   aws.String(r.bucket),
		Key:      aws.String(name),
		UploadId: aws.String(uploadID),
	}
	expiry := aws3.WithPresignExpires(expire)
	out := &s3.AuthSignResult{Parts: make([]s3.SignPart, 0, len(partNumbers))}
	for _, partNumber := range partNumbers {
		input.PartNumber = aws.Int32(int32(partNumber))
		signed, err := r.presign.PresignUploadPart(ctx, input, expiry)
		if err != nil {
			return nil, err
		}
		parsed, err := url.Parse(signed.URL)
		if err != nil {
			return nil, err
		}
		query := parsed.Query()
		parsed.RawQuery = ""
		base := parsed.String()
		partURL := ""
		switch {
		case out.URL == "":
			out.URL = base
		case out.URL != base:
			partURL = base
		}
		out.Parts = append(out.Parts, s3.SignPart{
			PartNumber: partNumber,
			URL:        partURL,
			Query:      query,
			Header:     signed.SignedHeader,
		})
	}
	return out, nil
}
// AccessURL returns a presigned GET URL for object name, valid for expire.
// A non-nil opt adds response-content-type / response-content-disposition
// query overrides to the signed URL.
func (r *R2) AccessURL(ctx context.Context, name string, expire time.Duration, opt *s3.AccessURLOption) (string, error) {
	signed, err := r.presign.PresignGetObject(ctx,
		&aws3.GetObjectInput{Bucket: aws.String(r.bucket), Key: aws.String(name)},
		aws3.WithPresignExpires(expire),
		withDisableHTTPPresignerHeaderV4(opt),
	)
	if err != nil {
		return "", err
	}
	return signed.URL, nil
}
// FormData is not implemented for R2: it ignores its arguments and always
// returns a nil *s3.FormData with an error.
func (r *R2) FormData(ctx context.Context, name string, size int64, contentType string, duration time.Duration) (*s3.FormData, error) {
	unsupported := errors.New("R2 does not currently support form data file uploads")
	return nil, unsupported
}
// withDisableHTTPPresignerHeaderV4 returns a presign option that wraps the
// configured presigner in a disableHTTPPresignerHeaderV4, carrying opt along
// so its response-* overrides can be added to the URL being signed.
func withDisableHTTPPresignerHeaderV4(opt *s3.AccessURLOption) func(options *aws3.PresignOptions) {
	apply := func(options *aws3.PresignOptions) {
		inner := options.Presigner
		options.Presigner = &disableHTTPPresignerHeaderV4{opt: opt, presigner: inner}
	}
	return apply
}
// disableHTTPPresignerHeaderV4 wraps an HTTP presigner. Before delegating, its
// PresignHTTP disables SigV4 header hoisting, drops the Amz-Sdk-Request header
// and applies opt's query overrides.
type disableHTTPPresignerHeaderV4 struct {
	opt       *s3.AccessURLOption   // optional response-* overrides; may be nil
	presigner aws3.HTTPPresignerV4  // presigner that performs the actual signing
}
// PresignHTTP prepares r and then delegates to the wrapped presigner. It
// removes the Amz-Sdk-Request header, adds d.opt's query overrides to r.URL
// before signing, and appends a signer option that disables header hoisting.
func (d *disableHTTPPresignerHeaderV4) PresignHTTP(ctx context.Context, credentials aws.Credentials, r *http.Request, payloadHash string, service string, region string, signingTime time.Time, optFns ...func(*v4.SignerOptions)) (url string, signedHeader http.Header, err error) {
	noHoisting := func(o *v4.SignerOptions) {
		o.DisableHeaderHoisting = true
	}
	r.Header.Del("Amz-Sdk-Request")
	d.setOption(r.URL)
	allOpts := append(optFns, noHoisting)
	return d.presigner.PresignHTTP(ctx, credentials, r, payloadHash, service, region, signingTime, allOpts...)
}
// setOption writes d.opt's overrides into u's query string. A non-empty
// ContentType sets response-content-type, and a non-empty Filename sets an
// RFC 5987-style attachment response-content-disposition. When d.opt is
// non-nil the query is always re-encoded; when it is nil, u is left untouched.
func (d *disableHTTPPresignerHeaderV4) setOption(u *url.URL) {
	opt := d.opt
	if opt == nil {
		return
	}
	values := u.Query()
	if ct := opt.ContentType; ct != "" {
		values.Set("response-content-type", ct)
	}
	if fn := opt.Filename; fn != "" {
		values.Set("response-content-disposition", `attachment; filename*=UTF-8''`+url.PathEscape(fn))
	}
	u.RawQuery = values.Encode()
}