From ac8e48940437f6b7f7760127513b52f53c46915d Mon Sep 17 00:00:00 2001 From: cyk Date: Fri, 2 Jan 2026 23:08:45 +0800 Subject: [PATCH 01/20] fix(FileTransferTask): skip copying if destination directory does not exist --- internal/fs/copy_move.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/internal/fs/copy_move.go b/internal/fs/copy_move.go index 3b6d91aec..3395584b4 100644 --- a/internal/fs/copy_move.go +++ b/internal/fs/copy_move.go @@ -191,16 +191,16 @@ func (t *FileTransferTask) RunWithNextTaskCallback(f func(nextTask *FileTransfer existedObjs := make(map[string]bool) if t.TaskType == merge { + // 检查目标目录是否存在,如果不存在则跳过(existedObjs保持为空map,所有文件都会被复制) dstObjs, err := op.List(t.Ctx(), t.DstStorage, dstActualPath, model.ListArgs{}) - if err != nil { - return errors.WithMessagef(err, "failed list dst [%s] objs", dstActualPath) - } - for _, obj := range dstObjs { - if err := t.Ctx().Err(); err != nil { - return err - } - if !obj.IsDir() { - existedObjs[obj.GetName()] = true + if err == nil { + for _, obj := range dstObjs { + if err := t.Ctx().Err(); err != nil { + return err + } + if !obj.IsDir() { + existedObjs[obj.GetName()] = true + } } } } From c670c0dcfe6e46aa9c74f56a3d35924b2a3a10dc Mon Sep 17 00:00:00 2001 From: cyk Date: Sun, 28 Dec 2025 18:17:05 +0800 Subject: [PATCH 02/20] fix(driver): fix file copy failure to 123pan due to incorrect etag --- drivers/123_open/driver.go | 41 ++++++++++++++++++++++++++++---------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/drivers/123_open/driver.go b/drivers/123_open/driver.go index d9984e57f..1529b6f15 100644 --- a/drivers/123_open/driver.go +++ b/drivers/123_open/driver.go @@ -2,7 +2,9 @@ package _123_open import ( "context" + "encoding/hex" "fmt" + "io" "strconv" "time" @@ -10,7 +12,7 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/op" - "github.com/OpenListTeam/OpenList/v4/internal/stream" + "github.com/OpenListTeam/OpenList/v4/pkg/http_range" "github.com/OpenListTeam/OpenList/v4/pkg/utils" ) @@ -175,20 +177,36 @@ func (d *Open123) Remove(ctx context.Context, obj model.Obj) error { } func (d *Open123) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { - // 1. 创建文件 + // 1. 准备参数 // parentFileID 父目录id,上传到根目录时填写 0 parentFileId, err := strconv.ParseInt(dstDir.GetID(), 10, 64) if err != nil { return nil, fmt.Errorf("parse parentFileID error: %v", err) } - // etag 文件md5 - etag := file.GetHash().GetHash(utils.MD5) - if len(etag) < utils.MD5.Width { - _, etag, err = stream.CacheFullAndHash(file, &up, utils.MD5) + + // 1. 流式计算MD5 + md5Hash := utils.MD5.NewFunc() + size := file.GetSize() + chunkSize := int64(10 * 1024 * 1024) // 10MB per chunk for MD5 calculation + var offset int64 = 0 + for offset < size { + readSize := min(chunkSize, size-offset) + reader, err := file.RangeRead(http_range.Range{Start: offset, Length: readSize}) if err != nil { - return nil, err + return nil, fmt.Errorf("range read for MD5 calculation failed: %w", err) + } + if _, err := io.Copy(md5Hash, reader); err != nil { + return nil, fmt.Errorf("calculate MD5 failed: %w", err) } + offset += readSize + + progress := 40 * float64(offset) / float64(size) + up(progress) } + + etag := hex.EncodeToString(md5Hash.Sum(nil)) + + // 2. 
创建上传任务 createResp, err := d.create(parentFileId, file.GetName(), etag, file.GetSize(), 2, false) if err != nil { return nil, err @@ -207,13 +225,16 @@ func (d *Open123) Put(ctx context.Context, dstDir model.Obj, file model.FileStre } } - // 2. 上传分片 - err = d.Upload(ctx, file, createResp, up) + // 3. 上传分片 + uploadProgress := func(p float64) { + up(40 + p*0.6) + } + err = d.Upload(ctx, file, createResp, uploadProgress) if err != nil { return nil, err } - // 3. 上传完毕 + // 4. 合并分片/完成上传 for range 60 { uploadCompleteResp, err := d.complete(createResp.Data.PreuploadID) // 返回错误代码未知,如:20103,文档也没有具体说 From 4aec42c8f9096cd52e4de851659a086c419e4df2 Mon Sep 17 00:00:00 2001 From: cyk Date: Mon, 29 Dec 2025 00:24:08 +0800 Subject: [PATCH 03/20] fix(driver): improve etag handling for file uploads --- drivers/123_open/driver.go | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/drivers/123_open/driver.go b/drivers/123_open/driver.go index 1529b6f15..730d091ce 100644 --- a/drivers/123_open/driver.go +++ b/drivers/123_open/driver.go @@ -184,7 +184,31 @@ func (d *Open123) Put(ctx context.Context, dstDir model.Obj, file model.FileStre return nil, fmt.Errorf("parse parentFileID error: %v", err) } - // 1. 流式计算MD5 + // etag 文件md5 + etag := file.GetHash().GetHash(utils.MD5) + if len(etag) >= utils.MD5.Width { + // 有etag时,先尝试秒传 + createResp, err := d.create(parentFileId, file.GetName(), etag, file.GetSize(), 2, false) + if err != nil { + return nil, err + } + // 是否秒传 + if createResp.Data.Reuse { + // 秒传成功才会返回正确的 FileID,否则为 0 + if createResp.Data.FileID != 0 { + return File{ + FileName: file.GetName(), + Size: file.GetSize(), + FileId: createResp.Data.FileID, + Type: 2, + Etag: etag, + }, nil + } + } + // 秒传失败,etag可能不可靠,继续流式计算真实MD5 + } + + // 流式计算MD5 md5Hash := utils.MD5.NewFunc() size := file.GetSize() chunkSize := int64(10 * 1024 * 1024) // 10MB per chunk for MD5 calculation @@ -204,7 +228,7 @@ func (d *Open123) Put(ctx context.Context, dstDir model.Obj, file model.FileStre up(progress) } - etag := hex.EncodeToString(md5Hash.Sum(nil)) + etag = hex.EncodeToString(md5Hash.Sum(nil)) // 2. 
创建上传任务 createResp, err := d.create(parentFileId, file.GetName(), etag, file.GetSize(), 2, false) From d3f7c812f90857cc75c9b299c1c4778c0194dadf Mon Sep 17 00:00:00 2001 From: cyk Date: Mon, 29 Dec 2025 02:46:00 +0800 Subject: [PATCH 04/20] fix(driver): optimize SHA1 calculation for file uploads using chunked reading --- drivers/115_open/driver.go | 3 +- drivers/123_open/driver.go | 28 +++----------- internal/stream/stream.go | 32 ++++----------- internal/stream/util.go | 79 +++++++++++++++++++++++--------------- 4 files changed, 64 insertions(+), 78 deletions(-) diff --git a/drivers/115_open/driver.go b/drivers/115_open/driver.go index 909bf4a99..f9d5027bf 100644 --- a/drivers/115_open/driver.go +++ b/drivers/115_open/driver.go @@ -228,7 +228,8 @@ func (d *Open115) Put(ctx context.Context, dstDir model.Obj, file model.FileStre } sha1 := file.GetHash().GetHash(utils.SHA1) if len(sha1) != utils.SHA1.Width { - _, sha1, err = stream.CacheFullAndHash(file, &up, utils.SHA1) + // 流式计算SHA1 + sha1, err = stream.StreamHashFile(file, utils.SHA1, 100, &up) if err != nil { return err } diff --git a/drivers/123_open/driver.go b/drivers/123_open/driver.go index 730d091ce..7cf3cfe46 100644 --- a/drivers/123_open/driver.go +++ b/drivers/123_open/driver.go @@ -2,9 +2,7 @@ package _123_open import ( "context" - "encoding/hex" "fmt" - "io" "strconv" "time" @@ -12,7 +10,7 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/model" "github.com/OpenListTeam/OpenList/v4/internal/op" - "github.com/OpenListTeam/OpenList/v4/pkg/http_range" + "github.com/OpenListTeam/OpenList/v4/internal/stream" "github.com/OpenListTeam/OpenList/v4/pkg/utils" ) @@ -208,28 +206,12 @@ func (d *Open123) Put(ctx context.Context, dstDir model.Obj, file model.FileStre // 秒传失败,etag可能不可靠,继续流式计算真实MD5 } - // 流式计算MD5 - md5Hash := utils.MD5.NewFunc() - size := file.GetSize() - chunkSize := int64(10 * 1024 * 1024) // 10MB per chunk for MD5 calculation - var offset int64 = 0 - for offset < size { - readSize := min(chunkSize, size-offset) - reader, err := file.RangeRead(http_range.Range{Start: offset, Length: readSize}) - if err != nil { - return nil, fmt.Errorf("range read for MD5 calculation failed: %w", err) - } - if _, err := io.Copy(md5Hash, reader); err != nil { - return nil, fmt.Errorf("calculate MD5 failed: %w", err) - } - offset += readSize - - progress := 40 * float64(offset) / float64(size) - up(progress) + // 流式MD5计算 + etag, err = stream.StreamHashFile(file, utils.MD5, 40, &up) + if err != nil { + return nil, err } - etag = hex.EncodeToString(md5Hash.Sum(nil)) - // 2. 
创建上传任务 createResp, err := d.create(parentFileId, file.GetName(), etag, file.GetSize(), 2, false) if err != nil { diff --git a/internal/stream/stream.go b/internal/stream/stream.go index 4c8238100..c29dbbec3 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -211,7 +211,9 @@ func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) { return io.NewSectionReader(f.GetFile(), httpRange.Start, httpRange.Length), nil } - cache, err := f.cache(httpRange.Start + httpRange.Length) + // 限制缓存大小,避免累积缓存整个文件 + maxCache := min(httpRange.Start+httpRange.Length, int64(conf.MaxBufferLimit)) + cache, err := f.cache(maxCache) if err != nil { return nil, err } @@ -224,31 +226,13 @@ func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) { // 即使被写入的数据量与Buffer.Cap一致,Buffer也会扩大 // 确保指定大小的数据被缓存 +// 注意:此方法只缓存到 maxCacheSize,不会缓存整个文件 func (f *FileStream) cache(maxCacheSize int64) (model.File, error) { + // 限制缓存大小,避免超大文件占用过多资源 + // 如果需要缓存整个文件,应该显式调用 CacheFullAndWriter if maxCacheSize > int64(conf.MaxBufferLimit) { - size := f.GetSize() - reader := f.Reader - if f.peekBuff != nil { - size -= f.peekBuff.Size() - reader = f.oriReader - } - tmpF, err := utils.CreateTempFile(reader, size) - if err != nil { - return nil, err - } - f.Add(utils.CloseFunc(func() error { - return errors.Join(tmpF.Close(), os.RemoveAll(tmpF.Name())) - })) - if f.peekBuff != nil { - peekF, err := buffer.NewPeekFile(f.peekBuff, tmpF) - if err != nil { - return nil, err - } - f.Reader = peekF - return peekF, nil - } - f.Reader = tmpF - return tmpF, nil + // 不再创建整个文件的临时文件,只缓存到 MaxBufferLimit + maxCacheSize = int64(conf.MaxBufferLimit) } if f.peekBuff == nil { diff --git a/internal/stream/util.go b/internal/stream/util.go index 6aa3dda5d..d444200d3 100644 --- a/internal/stream/util.go +++ b/internal/stream/util.go @@ -174,6 +174,53 @@ func CacheFullAndHash(stream model.FileStreamer, up *model.UpdateProgress, hashT return tmpF, hex.EncodeToString(h.Sum(nil)), nil } +// StreamHashFile 流式计算文件哈希值,避免将整个文件加载到内存 +// file: 文件流 +// hashType: 哈希算法类型 +// progressWeight: 进度权重(0-100),用于计算整体进度 +// up: 进度回调函数 +func StreamHashFile(file model.FileStreamer, hashType *utils.HashType, progressWeight float64, up *model.UpdateProgress) (string, error) { + // 如果已经有完整缓存文件,直接使用 + if cache := file.GetFile(); cache != nil { + hashFunc := hashType.NewFunc() + cache.Seek(0, io.SeekStart) + _, err := io.Copy(hashFunc, cache) + if err != nil { + return "", err + } + if up != nil && progressWeight > 0 { + (*up)(progressWeight) + } + return hex.EncodeToString(hashFunc.Sum(nil)), nil + } + + hashFunc := hashType.NewFunc() + size := file.GetSize() + chunkSize := int64(10 * 1024 * 1024) // 10MB per chunk + var offset int64 = 0 + for offset < size { + readSize := chunkSize + if size-offset < chunkSize { + readSize = size - offset + } + reader, err := file.RangeRead(http_range.Range{Start: offset, Length: readSize}) + if err != nil { + return "", fmt.Errorf("range read for hash calculation failed: %w", err) + } + if _, err := io.Copy(hashFunc, reader); err != nil { + return "", fmt.Errorf("calculate hash failed: %w", err) + } + offset += readSize + + if up != nil && progressWeight > 0 { + progress := progressWeight * float64(offset) / float64(size) + (*up)(progress) + } + } + + return hex.EncodeToString(hashFunc.Sum(nil)), nil +} + type StreamSectionReaderIF interface { // 线程不安全 GetSectionReader(off, length int64) (io.ReadSeeker, error) @@ -188,37 +235,9 @@ func NewStreamSectionReader(file model.FileStreamer, maxBufferSize 
int, up *mode } maxBufferSize = min(maxBufferSize, int(file.GetSize())) - if maxBufferSize > conf.MaxBufferLimit { - f, err := os.CreateTemp(conf.Conf.TempDir, "file-*") - if err != nil { - return nil, err - } - - if f.Truncate(file.GetSize()) != nil { - // fallback to full cache - _, _ = f.Close(), os.Remove(f.Name()) - cache, err := file.CacheFullAndWriter(up, nil) - if err != nil { - return nil, err - } - return &cachedSectionReader{cache}, nil - } - - ss := &fileSectionReader{file: file, temp: f} - ss.bufPool = &pool.Pool[*offsetWriterWithBase]{ - New: func() *offsetWriterWithBase { - base := ss.tempOffset - ss.tempOffset += int64(maxBufferSize) - return &offsetWriterWithBase{io.NewOffsetWriter(ss.temp, base), base} - }, - } - file.Add(utils.CloseFunc(func() error { - ss.bufPool.Reset() - return errors.Join(ss.temp.Close(), os.Remove(ss.temp.Name())) - })) - return ss, nil - } + // 始终使用 directSectionReader,只在内存中缓存当前分片 + // 避免创建临时文件导致中间文件增长到整个文件大小 ss := &directSectionReader{file: file} if conf.MmapThreshold > 0 && maxBufferSize >= conf.MmapThreshold { ss.bufPool = &pool.Pool[[]byte]{ From c04406a8e767119b3efbd922842e101dfddf4092 Mon Sep 17 00:00:00 2001 From: cyk Date: Sun, 28 Dec 2025 23:16:46 +0800 Subject: [PATCH 05/20] refactor(build): restrict builds to x64 architecture and simplify Docker workflow --- .github/workflows/test_docker.yml | 48 +++++++++++-------------------- build.sh | 20 ++++++------- 2 files changed, 27 insertions(+), 41 deletions(-) diff --git a/.github/workflows/test_docker.yml b/.github/workflows/test_docker.yml index aa6fe8966..1110599aa 100644 --- a/.github/workflows/test_docker.yml +++ b/.github/workflows/test_docker.yml @@ -1,5 +1,4 @@ name: Beta Release (Docker) - on: workflow_dispatch: push: @@ -7,51 +6,51 @@ on: - main pull_request: branches: - - main + - fix # 👈 允许你的 fix 分支触发 concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: true env: - DOCKERHUB_ORG_NAME: ${{ vars.DOCKERHUB_ORG_NAME || 'openlistteam' }} - GHCR_ORG_NAME: ${{ vars.GHCR_ORG_NAME || 'openlistteam' }} - IMAGE_NAME: openlist-git - IMAGE_NAME_DOCKERHUB: openlist + GHCR_ORG_NAME: ${{ vars.GHCR_ORG_NAME || 'ironboxplus' }} # 👈 最好改成你的用户名,防止推错地方 + IMAGE_NAME: openlist REGISTRY: ghcr.io ARTIFACT_NAME: 'binaries_docker_release' - RELEASE_PLATFORMS: 'linux/amd64,linux/arm64,linux/arm/v7,linux/386,linux/arm/v6,linux/ppc64le,linux/riscv64,linux/loong64' ### Temporarily disable Docker builds for linux/s390x architectures for unknown reasons. 
- IMAGE_PUSH: ${{ github.event_name == 'push' }} + # 👇 关键修改:只保留 linux/amd64,删掉后面一长串 + RELEASE_PLATFORMS: 'linux/amd64' + # 👇 关键修改:强制允许推送,不用管是不是 push 事件 + IMAGE_PUSH: 'true' IMAGE_TAGS_BETA: | type=ref,event=pr - type=raw,value=beta,enable={{is_default_branch}} + type=raw,value=beta jobs: build_binary: - name: Build Binaries for Docker Release (Beta) + name: Build Binaries (x64 Only) runs-on: ubuntu-latest steps: - name: Checkout uses: actions/checkout@v4 - - uses: actions/setup-go@v5 with: go-version: '1.25.0' + # 即使只构建 x64,我们也需要 musl 工具链(因为 BuildDockerMultiplatform 默认会检查它) - name: Cache Musl id: cache-musl uses: actions/cache@v4 with: path: build/musl-libs key: docker-musl-libs-v2 - - name: Download Musl Library if: steps.cache-musl.outputs.cache-hit != 'true' run: bash build.sh prepare docker-multiplatform env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - name: Build go binary (beta) + - name: Build go binary + # 这里还是跑 docker-multiplatform,虽然会多编译一些架构,但这是兼容 Dockerfile 路径最稳妥的方法 run: bash build.sh beta docker-multiplatform env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} @@ -69,12 +68,13 @@ jobs: release_docker: needs: build_binary - name: Release Docker image (Beta) + name: Release Docker (x64) runs-on: ubuntu-latest permissions: packages: write strategy: matrix: + # 你可以选择只构建 latest,或者保留全部变体 image: ["latest", "ffmpeg", "aria2", "aio"] include: - image: "latest" @@ -102,46 +102,32 @@ jobs: with: name: ${{ env.ARTIFACT_NAME }} path: 'build/' - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 + # 👇 只保留 GitHub 登录,删除了 DockerHub 登录 - name: Login to GitHub Container Registry - if: env.IMAGE_PUSH == 'true' uses: docker/login-action@v3 with: registry: ${{ env.REGISTRY }} username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} - - name: Login to DockerHub Container Registry - if: env.IMAGE_PUSH == 'true' - uses: docker/login-action@v3 - with: - username: ${{ vars.DOCKERHUB_ORG_NAME_BACKUP || env.DOCKERHUB_ORG_NAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - - name: Docker meta id: meta uses: docker/metadata-action@v5 with: images: | ${{ env.REGISTRY }}/${{ env.GHCR_ORG_NAME }}/${{ env.IMAGE_NAME }} - ${{ env.DOCKERHUB_ORG_NAME }}/${{ env.IMAGE_NAME_DOCKERHUB }} tags: ${{ env.IMAGE_TAGS_BETA }} - flavor: | - ${{ matrix.tag_favor }} + flavor: ${{ matrix.tag_favor }} - name: Build and push - id: docker_build uses: docker/build-push-action@v6 with: context: . 
file: Dockerfile.ci - push: ${{ env.IMAGE_PUSH == 'true' }} + push: true build-args: | BASE_IMAGE_TAG=${{ matrix.base_image_tag }} ${{ matrix.build_arg }} diff --git a/build.sh b/build.sh index 26e5a301b..0e8f4b85d 100644 --- a/build.sh +++ b/build.sh @@ -186,8 +186,8 @@ BuildDockerMultiplatform() { docker_lflags="--extldflags '-static -fpic' $ldflags" export CGO_ENABLED=1 - OS_ARCHES=(linux-amd64 linux-arm64 linux-386 linux-riscv64 linux-ppc64le linux-loong64) ## Disable linux-s390x builds - CGO_ARGS=(x86_64-linux-musl-gcc aarch64-linux-musl-gcc i486-linux-musl-gcc riscv64-linux-musl-gcc powerpc64le-linux-musl-gcc loongarch64-linux-musl-gcc) ## Disable s390x-linux-musl-gcc builds + OS_ARCHES=(linux-amd64) ## Disable linux-s390x builds + CGO_ARGS=(x86_64-linux-musl-gcc) ## Disable s390x-linux-musl-gcc builds for i in "${!OS_ARCHES[@]}"; do os_arch=${OS_ARCHES[$i]} cgo_cc=${CGO_ARGS[$i]} @@ -205,14 +205,14 @@ BuildDockerMultiplatform() { GO_ARM=(6 7) export GOOS=linux export GOARCH=arm - for i in "${!DOCKER_ARM_ARCHES[@]}"; do - docker_arch=${DOCKER_ARM_ARCHES[$i]} - cgo_cc=${CGO_ARGS[$i]} - export GOARM=${GO_ARM[$i]} - export CC=${cgo_cc} - echo "building for $docker_arch" - go build -o build/${docker_arch%%-*}/${docker_arch##*-}/"$appName" -ldflags="$docker_lflags" -tags=jsoniter . - done + # for i in "${!DOCKER_ARM_ARCHES[@]}"; do + # docker_arch=${DOCKER_ARM_ARCHES[$i]} + # cgo_cc=${CGO_ARGS[$i]} + # export GOARM=${GO_ARM[$i]} + # export CC=${cgo_cc} + # echo "building for $docker_arch" + # go build -o build/${docker_arch%%-*}/${docker_arch##*-}/"$appName" -ldflags="$docker_lflags" -tags=jsoniter . + # done } BuildRelease() { From ce1e9d8a3ae0d8722da5cbdb8e0b97c049327e41 Mon Sep 17 00:00:00 2001 From: cyk Date: Thu, 1 Jan 2026 19:03:17 +0800 Subject: [PATCH 06/20] feat: Implement streaming upload for Baidu Netdisk - Added `upload.go` to handle streaming uploads without temporary file caching. - Introduced `calculateHashesStream` for efficient MD5 hash calculation during upload. - Implemented `uploadChunksStream` for concurrent chunk uploads using `StreamSectionReader`. - Refactored `uploadSliceStream` to accept `io.ReadSeeker` for better flexibility. - Enhanced error handling for upload ID expiration with retry logic. - Updated documentation to reflect changes in upload process and architecture. 
--- drivers/baidu_netdisk/driver.go | 216 ++++---------------------- drivers/baidu_netdisk/upload.go | 262 ++++++++++++++++++++++++++++++++ 2 files changed, 293 insertions(+), 185 deletions(-) create mode 100644 drivers/baidu_netdisk/upload.go diff --git a/drivers/baidu_netdisk/driver.go b/drivers/baidu_netdisk/driver.go index fe77aca38..7021f21c9 100644 --- a/drivers/baidu_netdisk/driver.go +++ b/drivers/baidu_netdisk/driver.go @@ -1,30 +1,18 @@ package baidu_netdisk import ( - "bytes" "context" - "crypto/md5" - "encoding/hex" "errors" - "io" - "mime/multipart" - "net/http" "net/url" - "os" stdpath "path" "strconv" - "strings" "time" "github.com/OpenListTeam/OpenList/v4/drivers/base" - "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/driver" - "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/model" - "github.com/OpenListTeam/OpenList/v4/internal/net" - "github.com/OpenListTeam/OpenList/v4/pkg/errgroup" + streamPkg "github.com/OpenListTeam/OpenList/v4/internal/stream" "github.com/OpenListTeam/OpenList/v4/pkg/utils" - "github.com/avast/retry-go" log "github.com/sirupsen/logrus" ) @@ -199,80 +187,26 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F return newObj, nil } - var ( - cache = stream.GetFile() - tmpF *os.File - err error - ) - if cache == nil { - tmpF, err = os.CreateTemp(conf.Conf.TempDir, "file-*") - if err != nil { - return nil, err - } - defer func() { - _ = tmpF.Close() - _ = os.Remove(tmpF.Name()) - }() - cache = tmpF - } - streamSize := stream.GetSize() sliceSize := d.getSliceSize(streamSize) count := 1 if streamSize > sliceSize { count = int((streamSize + sliceSize - 1) / sliceSize) } - lastBlockSize := streamSize % sliceSize - if lastBlockSize == 0 { - lastBlockSize = sliceSize - } - - // cal md5 for first 256k data - const SliceSize int64 = 256 * utils.KB - blockList := make([]string, 0, count) - byteSize := sliceSize - fileMd5H := md5.New() - sliceMd5H := md5.New() - sliceMd5H2 := md5.New() - slicemd5H2Write := utils.LimitWriter(sliceMd5H2, SliceSize) - writers := []io.Writer{fileMd5H, sliceMd5H, slicemd5H2Write} - if tmpF != nil { - writers = append(writers, tmpF) - } - written := int64(0) - for i := 1; i <= count; i++ { - if utils.IsCanceled(ctx) { - return nil, ctx.Err() - } - if i == count { - byteSize = lastBlockSize - } - n, err := utils.CopyWithBufferN(io.MultiWriter(writers...), stream, byteSize) - written += n - if err != nil && err != io.EOF { - return nil, err - } - blockList = append(blockList, hex.EncodeToString(sliceMd5H.Sum(nil))) - sliceMd5H.Reset() - } - if tmpF != nil { - if written != streamSize { - return nil, errs.NewErr(err, "CreateTempFile failed, size mismatch: %d != %d ", written, streamSize) - } - _, err = tmpF.Seek(0, io.SeekStart) - if err != nil { - return nil, errs.NewErr(err, "CreateTempFile failed, can't seek to 0 ") - } - } - contentMd5 := hex.EncodeToString(fileMd5H.Sum(nil)) - sliceMd5 := hex.EncodeToString(sliceMd5H2.Sum(nil)) - blockListStr, _ := utils.Json.MarshalToString(blockList) path := stdpath.Join(dstDir.GetPath(), stream.GetName()) mtime := stream.ModTime().Unix() ctime := stream.CreateTime().Unix() - // step.1 尝试读取已保存进度 + // step.1 流式计算MD5哈希值 + contentMd5, sliceMd5, blockList, ss, err := d.calculateHashesStream(ctx, stream, sliceSize, &up) + if err != nil { + return nil, err + } + + blockListStr, _ := utils.Json.MarshalToString(blockList) + + // step.2 尝试读取已保存进度或执行预上传 precreateResp, ok := 
base.GetUploadProgress[*PrecreateResp](d, d.AccessToken, contentMd5) if !ok { // 没有进度,走预上传 @@ -288,6 +222,7 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F return fileToObj(precreateResp.File), nil } } + ensureUploadURL := func() { if precreateResp.UploadURL != "" { return @@ -295,58 +230,24 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F precreateResp.UploadURL = d.getUploadUrl(path, precreateResp.Uploadid) } - // step.2 上传分片 + // step.3 流式上传分片 + // 由于流式上传已经消耗了流,需要重新创建 StreamSectionReader + // 如果有缓存文件,可以直接使用;否则需要通过 RangeRead 重新获取 + if ss == nil || stream.GetFile() == nil { + // 重新创建 StreamSectionReader + ss, err = streamPkg.NewStreamSectionReader(stream, int(sliceSize), &up) + if err != nil { + return nil, err + } + } + uploadLoop: for range 2 { // 获取上传域名 ensureUploadURL() - // 并发上传 - threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread, - retry.Attempts(UPLOAD_RETRY_COUNT), - retry.Delay(UPLOAD_RETRY_WAIT_TIME), - retry.MaxDelay(UPLOAD_RETRY_MAX_WAIT_TIME), - retry.DelayType(retry.BackOffDelay), - retry.RetryIf(func(err error) bool { - return !errors.Is(err, ErrUploadIDExpired) - }), - retry.LastErrorOnly(true)) - - totalParts := len(precreateResp.BlockList) - - for i, partseq := range precreateResp.BlockList { - if utils.IsCanceled(upCtx) { - break - } - if partseq < 0 { - continue - } - i, partseq := i, partseq - offset, size := int64(partseq)*sliceSize, sliceSize - if partseq+1 == count { - size = lastBlockSize - } - threadG.Go(func(ctx context.Context) error { - params := map[string]string{ - "method": "upload", - "access_token": d.AccessToken, - "type": "tmpfile", - "path": path, - "uploadid": precreateResp.Uploadid, - "partseq": strconv.Itoa(partseq), - } - section := io.NewSectionReader(cache, offset, size) - err := d.uploadSlice(ctx, precreateResp.UploadURL, params, stream.GetName(), section) - if err != nil { - return err - } - precreateResp.BlockList[i] = -1 - progress := float64(threadG.Success()+1) * 100 / float64(totalParts+1) - up(progress) - return nil - }) - } - err = threadG.Wait() + // 流式并发上传 + err = d.uploadChunksStream(ctx, ss, stream, precreateResp, path, sliceSize, count, up) if err == nil { break uploadLoop } @@ -372,13 +273,19 @@ uploadLoop: precreateResp.UploadURL = "" // 覆盖掉旧的进度 base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5) + + // 尝试重新创建 StreamSectionReader(如果流支持重新读取) + ss, err = streamPkg.NewStreamSectionReader(stream, int(sliceSize), &up) + if err != nil { + return nil, err + } continue uploadLoop } return nil, err } defer up(100) - // step.3 创建文件 + // step.4 创建文件 var newFile File _, err = d.create(path, streamSize, 0, precreateResp.Uploadid, blockListStr, &newFile, mtime, ctime) if err != nil { @@ -427,67 +334,6 @@ func (d *BaiduNetdisk) precreate(ctx context.Context, path string, streamSize in return &precreateResp, nil } -func (d *BaiduNetdisk) uploadSlice(ctx context.Context, uploadUrl string, params map[string]string, fileName string, file *io.SectionReader) error { - b := bytes.NewBuffer(make([]byte, 0, bytes.MinRead)) - mw := multipart.NewWriter(b) - _, err := mw.CreateFormFile("file", fileName) - if err != nil { - return err - } - headSize := b.Len() - err = mw.Close() - if err != nil { - return err - } - head := bytes.NewReader(b.Bytes()[:headSize]) - tail := bytes.NewReader(b.Bytes()[headSize:]) - rateLimitedRd := driver.NewLimitedUploadStream(ctx, io.MultiReader(head, file, tail)) - - req, err := http.NewRequestWithContext(ctx, 
http.MethodPost, uploadUrl+"/rest/2.0/pcs/superfile2", rateLimitedRd) - if err != nil { - return err - } - query := req.URL.Query() - for k, v := range params { - query.Set(k, v) - } - req.URL.RawQuery = query.Encode() - req.Header.Set("Content-Type", mw.FormDataContentType()) - req.ContentLength = int64(b.Len()) + file.Size() - - client := net.NewHttpClient() - if d.UploadSliceTimeout > 0 { - client.Timeout = time.Second * time.Duration(d.UploadSliceTimeout) - } else { - client.Timeout = DEFAULT_UPLOAD_SLICE_TIMEOUT - } - resp, err := client.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - b.Reset() - _, err = b.ReadFrom(resp.Body) - if err != nil { - return err - } - body := b.Bytes() - respStr := string(body) - log.Debugln(respStr) - lower := strings.ToLower(respStr) - // 合并 uploadid 过期检测逻辑 - if strings.Contains(lower, "uploadid") && - (strings.Contains(lower, "invalid") || strings.Contains(lower, "expired") || strings.Contains(lower, "not found")) { - return ErrUploadIDExpired - } - - errCode := utils.Json.Get(body, "error_code").ToInt() - errNo := utils.Json.Get(body, "errno").ToInt() - if errCode != 0 || errNo != 0 { - return errs.NewErr(errs.StreamIncomplete, "error uploading to baidu, response=%s", respStr) - } - return nil -} func (d *BaiduNetdisk) GetDetails(ctx context.Context) (*model.StorageDetails, error) { du, err := d.quota(ctx) diff --git a/drivers/baidu_netdisk/upload.go b/drivers/baidu_netdisk/upload.go new file mode 100644 index 000000000..cc94d838e --- /dev/null +++ b/drivers/baidu_netdisk/upload.go @@ -0,0 +1,262 @@ +package baidu_netdisk + +import ( + "bytes" + "context" + "crypto/md5" + "encoding/hex" + "errors" + "io" + "mime/multipart" + "net/http" + "strconv" + "strings" + "time" + + "github.com/OpenListTeam/OpenList/v4/internal/driver" + "github.com/OpenListTeam/OpenList/v4/internal/errs" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/net" + streamPkg "github.com/OpenListTeam/OpenList/v4/internal/stream" + "github.com/OpenListTeam/OpenList/v4/pkg/errgroup" + "github.com/OpenListTeam/OpenList/v4/pkg/utils" + "github.com/avast/retry-go" + log "github.com/sirupsen/logrus" +) + +// calculateHashesStream 流式计算文件的MD5哈希值 +// 返回:文件MD5、前256KB的MD5、每个分片的MD5列表、StreamSectionReader +func (d *BaiduNetdisk) calculateHashesStream( + ctx context.Context, + stream model.FileStreamer, + sliceSize int64, + up *driver.UpdateProgress, +) (contentMd5 string, sliceMd5 string, blockList []string, ss streamPkg.StreamSectionReaderIF, err error) { + streamSize := stream.GetSize() + count := 1 + if streamSize > sliceSize { + count = int((streamSize + sliceSize - 1) / sliceSize) + } + lastBlockSize := streamSize % sliceSize + if lastBlockSize == 0 { + lastBlockSize = sliceSize + } + + // 创建 StreamSectionReader 用于流式读取 + ss, err = streamPkg.NewStreamSectionReader(stream, int(sliceSize), nil) + if err != nil { + return "", "", nil, nil, err + } + + // 前256KB的MD5 + const SliceSize int64 = 256 * utils.KB + blockList = make([]string, 0, count) + fileMd5H := md5.New() + sliceMd5H2 := md5.New() + sliceWritten := int64(0) + + for i := 0; i < count; i++ { + if utils.IsCanceled(ctx) { + return "", "", nil, nil, ctx.Err() + } + + offset := int64(i) * sliceSize + length := sliceSize + if i == count-1 { + length = lastBlockSize + } + + reader, err := ss.GetSectionReader(offset, length) + if err != nil { + return "", "", nil, nil, err + } + + // 计算分片MD5 + sliceMd5Calc := md5.New() + + // 同时写入多个哈希计算器 + writers := 
[]io.Writer{fileMd5H, sliceMd5Calc} + if sliceWritten < SliceSize { + remaining := SliceSize - sliceWritten + writers = append(writers, utils.LimitWriter(sliceMd5H2, remaining)) + } + + reader.Seek(0, io.SeekStart) + n, err := io.Copy(io.MultiWriter(writers...), reader) + if err != nil { + ss.FreeSectionReader(reader) + return "", "", nil, nil, err + } + sliceWritten += n + + blockList = append(blockList, hex.EncodeToString(sliceMd5Calc.Sum(nil))) + ss.FreeSectionReader(reader) + + // 更新进度(哈希计算占总进度的一小部分) + if up != nil { + progress := float64(i+1) * 10 / float64(count) + (*up)(progress) + } + } + + return hex.EncodeToString(fileMd5H.Sum(nil)), + hex.EncodeToString(sliceMd5H2.Sum(nil)), + blockList, ss, nil +} + +// uploadChunksStream 流式上传所有分片 +func (d *BaiduNetdisk) uploadChunksStream( + ctx context.Context, + ss streamPkg.StreamSectionReaderIF, + stream model.FileStreamer, + precreateResp *PrecreateResp, + path string, + sliceSize int64, + count int, + up driver.UpdateProgress, +) error { + streamSize := stream.GetSize() + lastBlockSize := streamSize % sliceSize + if lastBlockSize == 0 { + lastBlockSize = sliceSize + } + + // 使用 OrderedGroup 保证 Before 阶段有序 + thread := min(d.uploadThread, len(precreateResp.BlockList)) + threadG, upCtx := errgroup.NewOrderedGroupWithContext(ctx, thread, + retry.Attempts(UPLOAD_RETRY_COUNT), + retry.Delay(UPLOAD_RETRY_WAIT_TIME), + retry.MaxDelay(UPLOAD_RETRY_MAX_WAIT_TIME), + retry.DelayType(retry.BackOffDelay), + retry.RetryIf(func(err error) bool { + return !errors.Is(err, ErrUploadIDExpired) + }), + retry.LastErrorOnly(true)) + + totalParts := len(precreateResp.BlockList) + + for i, partseq := range precreateResp.BlockList { + if utils.IsCanceled(upCtx) { + break + } + if partseq < 0 { + continue + } + + i, partseq := i, partseq + offset := int64(partseq) * sliceSize + size := sliceSize + if partseq+1 == count { + size = lastBlockSize + } + + var reader io.ReadSeeker + + threadG.GoWithLifecycle(errgroup.Lifecycle{ + Before: func(ctx context.Context) error { + var err error + reader, err = ss.GetSectionReader(offset, size) + return err + }, + Do: func(ctx context.Context) error { + reader.Seek(0, io.SeekStart) + err := d.uploadSliceStream(ctx, precreateResp.UploadURL, path, + precreateResp.Uploadid, partseq, stream.GetName(), reader, size) + if err != nil { + return err + } + precreateResp.BlockList[i] = -1 + // 进度从10%开始(前10%是哈希计算) + progress := 10 + float64(threadG.Success()+1)*90/float64(totalParts+1) + up(progress) + return nil + }, + After: func(err error) { + ss.FreeSectionReader(reader) + }, + }) + } + + return threadG.Wait() +} + +// uploadSliceStream 上传单个分片(接受io.ReadSeeker) +func (d *BaiduNetdisk) uploadSliceStream( + ctx context.Context, + uploadUrl string, + path string, + uploadid string, + partseq int, + fileName string, + reader io.ReadSeeker, + size int64, +) error { + params := map[string]string{ + "method": "upload", + "access_token": d.AccessToken, + "type": "tmpfile", + "path": path, + "uploadid": uploadid, + "partseq": strconv.Itoa(partseq), + } + + b := bytes.NewBuffer(make([]byte, 0, bytes.MinRead)) + mw := multipart.NewWriter(b) + _, err := mw.CreateFormFile("file", fileName) + if err != nil { + return err + } + headSize := b.Len() + err = mw.Close() + if err != nil { + return err + } + head := bytes.NewReader(b.Bytes()[:headSize]) + tail := bytes.NewReader(b.Bytes()[headSize:]) + rateLimitedRd := driver.NewLimitedUploadStream(ctx, io.MultiReader(head, reader, tail)) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, 
uploadUrl+"/rest/2.0/pcs/superfile2", rateLimitedRd) + if err != nil { + return err + } + query := req.URL.Query() + for k, v := range params { + query.Set(k, v) + } + req.URL.RawQuery = query.Encode() + req.Header.Set("Content-Type", mw.FormDataContentType()) + req.ContentLength = int64(b.Len()) + size + + client := net.NewHttpClient() + if d.UploadSliceTimeout > 0 { + client.Timeout = time.Second * time.Duration(d.UploadSliceTimeout) + } else { + client.Timeout = DEFAULT_UPLOAD_SLICE_TIMEOUT + } + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + b.Reset() + _, err = b.ReadFrom(resp.Body) + if err != nil { + return err + } + body := b.Bytes() + respStr := string(body) + log.Debugln(respStr) + lower := strings.ToLower(respStr) + // 合并 uploadid 过期检测逻辑 + if strings.Contains(lower, "uploadid") && + (strings.Contains(lower, "invalid") || strings.Contains(lower, "expired") || strings.Contains(lower, "not found")) { + return ErrUploadIDExpired + } + + errCode := utils.Json.Get(body, "error_code").ToInt() + errNo := utils.Json.Get(body, "errno").ToInt() + if errCode != 0 || errNo != 0 { + return errs.NewErr(errs.StreamIncomplete, "error uploading to baidu, response=%s", respStr) + } + return nil +} \ No newline at end of file From 8402100084c7496aa17e00b5c5f1645700d9dd57 Mon Sep 17 00:00:00 2001 From: cyk Date: Thu, 1 Jan 2026 19:47:12 +0800 Subject: [PATCH 07/20] fix(driver): optimize MD5 hash calculation and stream handling for uploads --- drivers/baidu_netdisk/driver.go | 17 ++++++----------- drivers/baidu_netdisk/upload.go | 32 +++++++++++++++----------------- 2 files changed, 21 insertions(+), 28 deletions(-) diff --git a/drivers/baidu_netdisk/driver.go b/drivers/baidu_netdisk/driver.go index 7021f21c9..1290279da 100644 --- a/drivers/baidu_netdisk/driver.go +++ b/drivers/baidu_netdisk/driver.go @@ -198,8 +198,8 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F mtime := stream.ModTime().Unix() ctime := stream.CreateTime().Unix() - // step.1 流式计算MD5哈希值 - contentMd5, sliceMd5, blockList, ss, err := d.calculateHashesStream(ctx, stream, sliceSize, &up) + // step.1 流式计算MD5哈希值(使用 RangeRead,不会消耗流) + contentMd5, sliceMd5, blockList, err := d.calculateHashesStream(ctx, stream, sliceSize, &up) if err != nil { return nil, err } @@ -231,14 +231,10 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F } // step.3 流式上传分片 - // 由于流式上传已经消耗了流,需要重新创建 StreamSectionReader - // 如果有缓存文件,可以直接使用;否则需要通过 RangeRead 重新获取 - if ss == nil || stream.GetFile() == nil { - // 重新创建 StreamSectionReader - ss, err = streamPkg.NewStreamSectionReader(stream, int(sliceSize), &up) - if err != nil { - return nil, err - } + // 创建 StreamSectionReader 用于上传 + ss, err := streamPkg.NewStreamSectionReader(stream, int(sliceSize), &up) + if err != nil { + return nil, err } uploadLoop: @@ -334,7 +330,6 @@ func (d *BaiduNetdisk) precreate(ctx context.Context, path string, streamSize in return &precreateResp, nil } - func (d *BaiduNetdisk) GetDetails(ctx context.Context) (*model.StorageDetails, error) { du, err := d.quota(ctx) if err != nil { diff --git a/drivers/baidu_netdisk/upload.go b/drivers/baidu_netdisk/upload.go index cc94d838e..54b1f29bb 100644 --- a/drivers/baidu_netdisk/upload.go +++ b/drivers/baidu_netdisk/upload.go @@ -19,19 +19,21 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/net" streamPkg "github.com/OpenListTeam/OpenList/v4/internal/stream" "github.com/OpenListTeam/OpenList/v4/pkg/errgroup" + 
"github.com/OpenListTeam/OpenList/v4/pkg/http_range" "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/avast/retry-go" log "github.com/sirupsen/logrus" ) // calculateHashesStream 流式计算文件的MD5哈希值 -// 返回:文件MD5、前256KB的MD5、每个分片的MD5列表、StreamSectionReader +// 返回:文件MD5、前256KB的MD5、每个分片的MD5列表 +// 注意:此函数使用 RangeRead 读取数据,不会消耗流 func (d *BaiduNetdisk) calculateHashesStream( ctx context.Context, stream model.FileStreamer, sliceSize int64, up *driver.UpdateProgress, -) (contentMd5 string, sliceMd5 string, blockList []string, ss streamPkg.StreamSectionReaderIF, err error) { +) (contentMd5 string, sliceMd5 string, blockList []string, err error) { streamSize := stream.GetSize() count := 1 if streamSize > sliceSize { @@ -42,12 +44,6 @@ func (d *BaiduNetdisk) calculateHashesStream( lastBlockSize = sliceSize } - // 创建 StreamSectionReader 用于流式读取 - ss, err = streamPkg.NewStreamSectionReader(stream, int(sliceSize), nil) - if err != nil { - return "", "", nil, nil, err - } - // 前256KB的MD5 const SliceSize int64 = 256 * utils.KB blockList = make([]string, 0, count) @@ -57,7 +53,7 @@ func (d *BaiduNetdisk) calculateHashesStream( for i := 0; i < count; i++ { if utils.IsCanceled(ctx) { - return "", "", nil, nil, ctx.Err() + return "", "", nil, ctx.Err() } offset := int64(i) * sliceSize @@ -66,9 +62,10 @@ func (d *BaiduNetdisk) calculateHashesStream( length = lastBlockSize } - reader, err := ss.GetSectionReader(offset, length) + // 使用 RangeRead 读取数据,不会消耗流 + reader, err := stream.RangeRead(http_range.Range{Start: offset, Length: length}) if err != nil { - return "", "", nil, nil, err + return "", "", nil, err } // 计算分片MD5 @@ -81,16 +78,17 @@ func (d *BaiduNetdisk) calculateHashesStream( writers = append(writers, utils.LimitWriter(sliceMd5H2, remaining)) } - reader.Seek(0, io.SeekStart) n, err := io.Copy(io.MultiWriter(writers...), reader) + // 关闭 reader(如果是 ReadCloser) + if rc, ok := reader.(io.Closer); ok { + rc.Close() + } if err != nil { - ss.FreeSectionReader(reader) - return "", "", nil, nil, err + return "", "", nil, err } sliceWritten += n blockList = append(blockList, hex.EncodeToString(sliceMd5Calc.Sum(nil))) - ss.FreeSectionReader(reader) // 更新进度(哈希计算占总进度的一小部分) if up != nil { @@ -101,7 +99,7 @@ func (d *BaiduNetdisk) calculateHashesStream( return hex.EncodeToString(fileMd5H.Sum(nil)), hex.EncodeToString(sliceMd5H2.Sum(nil)), - blockList, ss, nil + blockList, nil } // uploadChunksStream 流式上传所有分片 @@ -259,4 +257,4 @@ func (d *BaiduNetdisk) uploadSliceStream( return errs.NewErr(errs.StreamIncomplete, "error uploading to baidu, response=%s", respStr) } return nil -} \ No newline at end of file +} From 2b2bd18fb637d8edb66cfc77da7a1ee36c0b3e51 Mon Sep 17 00:00:00 2001 From: cyk Date: Fri, 2 Jan 2026 01:43:03 +0800 Subject: [PATCH 08/20] fix(workflow): update beta image tag to remove unnecessary suffix --- .github/workflows/test_docker.yml | 2 +- internal/stream/util.go | 26 +++++++++++++++++++++----- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test_docker.yml b/.github/workflows/test_docker.yml index 1110599aa..a3ca52258 100644 --- a/.github/workflows/test_docker.yml +++ b/.github/workflows/test_docker.yml @@ -23,7 +23,7 @@ env: IMAGE_PUSH: 'true' IMAGE_TAGS_BETA: | type=ref,event=pr - type=raw,value=beta + type=raw,value=beta-retry jobs: build_binary: diff --git a/internal/stream/util.go b/internal/stream/util.go index d444200d3..83b20da71 100644 --- a/internal/stream/util.go +++ b/internal/stream/util.go @@ -198,18 +198,34 @@ func StreamHashFile(file model.FileStreamer, 
hashType *utils.HashType, progressW size := file.GetSize() chunkSize := int64(10 * 1024 * 1024) // 10MB per chunk var offset int64 = 0 + const maxRetries = 3 for offset < size { readSize := chunkSize if size-offset < chunkSize { readSize = size - offset } - reader, err := file.RangeRead(http_range.Range{Start: offset, Length: readSize}) - if err != nil { - return "", fmt.Errorf("range read for hash calculation failed: %w", err) + + var lastErr error + for retry := 0; retry < maxRetries; retry++ { + reader, err := file.RangeRead(http_range.Range{Start: offset, Length: readSize}) + if err != nil { + lastErr = fmt.Errorf("range read for hash calculation failed: %w", err) + continue + } + _, err = io.Copy(hashFunc, reader) + if closer, ok := reader.(io.Closer); ok { + closer.Close() + } + if err == nil { + lastErr = nil + break + } + lastErr = fmt.Errorf("calculate hash failed at offset %d: %w", offset, err) } - if _, err := io.Copy(hashFunc, reader); err != nil { - return "", fmt.Errorf("calculate hash failed: %w", err) + if lastErr != nil { + return "", lastErr } + offset += readSize if up != nil && progressWeight > 0 { From 122bd1d6281435a7292dfae6a0ff900fef1abe92 Mon Sep 17 00:00:00 2001 From: cyk Date: Fri, 2 Jan 2026 02:16:01 +0800 Subject: [PATCH 09/20] feat(stream): add readFullWithRangeReadFallback function for improved data reading --- internal/stream/util.go | 92 ++++++++++++++++++++++++++++++----------- 1 file changed, 69 insertions(+), 23 deletions(-) diff --git a/internal/stream/util.go b/internal/stream/util.go index 83b20da71..ddd215a43 100644 --- a/internal/stream/util.go +++ b/internal/stream/util.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "os" + "time" "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/errs" @@ -174,6 +175,45 @@ func CacheFullAndHash(stream model.FileStreamer, up *model.UpdateProgress, hashT return tmpF, hex.EncodeToString(h.Sum(nil)), nil } +// readFullWithRangeRead 使用 RangeRead 从文件流中读取数据到 buf +// file: 文件流 +// buf: 目标缓冲区 +// off: 读取的起始偏移量 +// 返回值: 实际读取的字节数和错误 +// 支持自动重试(最多3次),每次重试之间有递增延迟(3秒、6秒、9秒) +func readFullWithRangeRead(file model.FileStreamer, buf []byte, off int64) (int, error) { + length := int64(len(buf)) + var lastErr error + + // 重试最多3次 + for retry := 0; retry < 3; retry++ { + reader, err := file.RangeRead(http_range.Range{Start: off, Length: length}) + if err != nil { + lastErr = fmt.Errorf("RangeRead failed at offset %d: %w", off, err) + log.Debugf("RangeRead retry %d failed: %v", retry+1, lastErr) + // 递增延迟:3秒、6秒、9秒,等待代理恢复 + time.Sleep(time.Duration(retry+1) * 3 * time.Second) + continue + } + + n, err := io.ReadFull(reader, buf) + if closer, ok := reader.(io.Closer); ok { + closer.Close() + } + + if err == nil { + return n, nil + } + + lastErr = fmt.Errorf("failed to read all data via RangeRead at offset %d: (expect=%d, actual=%d) %w", off, length, n, err) + log.Debugf("RangeRead retry %d read failed: %v", retry+1, lastErr) + // 递增延迟:3秒、6秒、9秒,等待网络恢复 + time.Sleep(time.Duration(retry+1) * 3 * time.Second) + } + + return 0, lastErr +} + // StreamHashFile 流式计算文件哈希值,避免将整个文件加载到内存 // file: 文件流 // hashType: 哈希算法类型 @@ -197,36 +237,28 @@ func StreamHashFile(file model.FileStreamer, hashType *utils.HashType, progressW hashFunc := hashType.NewFunc() size := file.GetSize() chunkSize := int64(10 * 1024 * 1024) // 10MB per chunk + buf := make([]byte, chunkSize) var offset int64 = 0 - const maxRetries = 3 + for offset < size { readSize := chunkSize if size-offset < chunkSize { readSize = size - offset } - var 
lastErr error - for retry := 0; retry < maxRetries; retry++ { - reader, err := file.RangeRead(http_range.Range{Start: offset, Length: readSize}) + // 首先尝试顺序流读取(不消耗额外资源,适用于所有流类型) + n, err := io.ReadFull(file, buf[:readSize]) + if err != nil { + // 顺序流读取失败,尝试使用 RangeRead 重试(适用于 SeekableStream) + log.Warnf("StreamHashFile: sequential read failed at offset %d, retrying with RangeRead: %v", offset, err) + n, err = readFullWithRangeRead(file, buf[:readSize], offset) if err != nil { - lastErr = fmt.Errorf("range read for hash calculation failed: %w", err) - continue - } - _, err = io.Copy(hashFunc, reader) - if closer, ok := reader.(io.Closer); ok { - closer.Close() + return "", fmt.Errorf("calculate hash failed at offset %d: sequential read and RangeRead both failed: %w", offset, err) } - if err == nil { - lastErr = nil - break - } - lastErr = fmt.Errorf("calculate hash failed at offset %d: %w", offset, err) - } - if lastErr != nil { - return "", lastErr } - offset += readSize + hashFunc.Write(buf[:n]) + offset += int64(n) if up != nil && progressWeight > 0 { progress := progressWeight * float64(offset) / float64(size) @@ -381,12 +413,26 @@ func (ss *directSectionReader) GetSectionReader(off, length int64) (io.ReadSeeke } tempBuf := ss.bufPool.Get() buf := tempBuf[:length] + + // 首先尝试顺序流读取(不消耗额外资源,适用于所有流类型) + // 对于 FileStream,RangeRead 会消耗底层 oriReader,所以必须先尝试顺序流读取 n, err := io.ReadFull(ss.file, buf) - ss.fileOffset += int64(n) - if int64(n) != length { - return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", length, n, err) + if err == nil { + ss.fileOffset = off + int64(n) + return &bufferSectionReader{bytes.NewReader(buf), tempBuf}, nil } - return &bufferSectionReader{bytes.NewReader(buf), buf}, nil + + // 顺序流读取失败,尝试使用 RangeRead 重试(适用于 SeekableStream) + log.Debugf("Sequential read failed at offset %d, retrying with RangeRead: %v", off, err) + n, err = readFullWithRangeRead(ss.file, buf, off) + if err != nil { + ss.bufPool.Put(tempBuf) + return nil, fmt.Errorf("both sequential read and RangeRead failed at offset %d: (expect=%d, actual=%d) %w", off, length, n, err) + } + + // 更新 fileOffset + ss.fileOffset = off + int64(n) + return &bufferSectionReader{bytes.NewReader(buf), tempBuf}, nil } func (ss *directSectionReader) FreeSectionReader(rs io.ReadSeeker) { if sr, ok := rs.(*bufferSectionReader); ok { From c16759aa9d18be4482034c5141089dd990e5bbbb Mon Sep 17 00:00:00 2001 From: cyk Date: Sat, 3 Jan 2026 00:27:29 +0800 Subject: [PATCH 10/20] feat(upload): enhance token handling and bucket creation for OSS uploads --- drivers/115_open/upload.go | 44 +++++++++++++++++++++++++++++++++----- 1 file changed, 39 insertions(+), 5 deletions(-) diff --git a/drivers/115_open/upload.go b/drivers/115_open/upload.go index 3575678c2..292c8371b 100644 --- a/drivers/115_open/upload.go +++ b/drivers/115_open/upload.go @@ -4,6 +4,7 @@ import ( "context" "encoding/base64" "io" + "strings" "time" sdk "github.com/OpenListTeam/115-sdk-go" @@ -13,8 +14,19 @@ import ( "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/aliyun/aliyun-oss-go-sdk/oss" "github.com/avast/retry-go" + log "github.com/sirupsen/logrus" ) +// isTokenExpiredError 检测是否为OSS凭证过期错误 +func isTokenExpiredError(err error) bool { + if err == nil { + return false + } + errStr := err.Error() + return strings.Contains(errStr, "SecurityTokenExpired") || + strings.Contains(errStr, "InvalidAccessKeyId") +} + func calPartSize(fileSize int64) int64 { var partSize int64 = 20 * utils.MB if fileSize > partSize { @@ -70,11 +82,16 @@ func 
(d *Open115) singleUpload(ctx context.Context, tempF model.File, tokenResp // } func (d *Open115) multpartUpload(ctx context.Context, stream model.FileStreamer, up driver.UpdateProgress, tokenResp *sdk.UploadGetTokenResp, initResp *sdk.UploadInitResp) error { - ossClient, err := oss.New(tokenResp.Endpoint, tokenResp.AccessKeyId, tokenResp.AccessKeySecret, oss.SecurityToken(tokenResp.SecurityToken)) - if err != nil { - return err + // 创建OSS客户端的辅助函数 + createBucket := func(token *sdk.UploadGetTokenResp) (*oss.Bucket, error) { + ossClient, err := oss.New(token.Endpoint, token.AccessKeyId, token.AccessKeySecret, oss.SecurityToken(token.SecurityToken)) + if err != nil { + return nil, err + } + return ossClient.Bucket(initResp.Bucket) } - bucket, err := ossClient.Bucket(initResp.Bucket) + + bucket, err := createBucket(tokenResp) if err != nil { return err } @@ -119,7 +136,24 @@ func (d *Open115) multpartUpload(ctx context.Context, stream model.FileStreamer, retry.Context(ctx), retry.Attempts(3), retry.DelayType(retry.BackOffDelay), - retry.Delay(time.Second)) + retry.Delay(time.Second), + retry.OnRetry(func(n uint, err error) { + // 如果是凭证过期错误,在重试前刷新凭证并重建bucket + if isTokenExpiredError(err) { + log.Warnf("115 OSS token expired, refreshing token...") + if newToken, refreshErr := d.client.UploadGetToken(ctx); refreshErr == nil { + tokenResp = newToken + if newBucket, bucketErr := createBucket(tokenResp); bucketErr == nil { + bucket = newBucket + log.Infof("115 OSS token refreshed successfully") + } else { + log.Errorf("Failed to create new bucket with refreshed token: %v", bucketErr) + } + } else { + log.Errorf("Failed to refresh 115 OSS token: %v", refreshErr) + } + } + })) ss.FreeSectionReader(rd) if err != nil { return err From 399b43c9ce733e45f79608b33263d96ef148b813 Mon Sep 17 00:00:00 2001 From: cyk Date: Sun, 4 Jan 2026 00:30:41 +0800 Subject: [PATCH 11/20] feat(link): add retry logic with timeout for HEAD requests in linkOfficial function --- drivers/baidu_netdisk/util.go | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/drivers/baidu_netdisk/util.go b/drivers/baidu_netdisk/util.go index 8f55911d1..09341e9f6 100644 --- a/drivers/baidu_netdisk/util.go +++ b/drivers/baidu_netdisk/util.go @@ -204,7 +204,24 @@ func (d *BaiduNetdisk) linkOfficial(file model.Obj, _ model.LinkArgs) (*model.Li return nil, err } u := fmt.Sprintf("%s&access_token=%s", resp.List[0].Dlink, d.AccessToken) - res, err := base.NoRedirectClient.R().SetHeader("User-Agent", "pan.baidu.com").Head(u) + + // Retry HEAD request with longer timeout to avoid client-side errors + // Create a client with longer timeout (base.NoRedirectClient doesn't have timeout set) + client := base.NoRedirectClient.SetTimeout(60 * time.Second) + var res *resty.Response + maxRetries := 5 + for i := 0; i < maxRetries; i++ { + res, err = client.R(). + SetHeader("User-Agent", "pan.baidu.com"). 
+ Head(u) + if err == nil { + break + } + if i < maxRetries-1 { + log.Warnf("HEAD request failed (attempt %d/%d): %v, retrying...", i+1, maxRetries, err) + time.Sleep(time.Duration(i+1) * 2 * time.Second) // Exponential backoff: 2s, 4s, 6s, 8s + } + } if err != nil { return nil, err } From b1b83ffa0e34cb0865c5d98f128bbc1897fefba0 Mon Sep 17 00:00:00 2001 From: cyk Date: Sun, 4 Jan 2026 12:17:22 +0800 Subject: [PATCH 12/20] fix(alias): update storage retrieval method in listRoot function --- drivers/alias/util.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/drivers/alias/util.go b/drivers/alias/util.go index 23b10e994..54d10939d 100644 --- a/drivers/alias/util.go +++ b/drivers/alias/util.go @@ -40,7 +40,7 @@ func (d *Alias) listRoot(ctx context.Context, withDetails, refresh bool) []model if !withDetails || len(v) != 1 { continue } - remoteDriver, err := op.GetStorageByMountPath(v[0]) + remoteDriver, err := fs.GetStorage(v[0], &fs.GetStoragesArgs{}) if err != nil { continue } From 4acab1bc111d19b15812fd1066d9753f8b147a58 Mon Sep 17 00:00:00 2001 From: cyk Date: Sun, 4 Jan 2026 12:37:51 +0800 Subject: [PATCH 13/20] feat(docs): add CLAUDE.md for project guidance and development instructions --- CLAUDE.md | 297 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 297 insertions(+) create mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..6a1e1461c --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,297 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Build and Development Commands + +```bash +# Development +go run main.go # Run backend server (default port 5244) +air # Hot reload during development (uses .air.toml) +./build.sh dev # Build development version with frontend +./build.sh release # Build release version + +# Testing +go test ./... # Run all tests + +# Docker +docker-compose up # Run with docker-compose +docker build -f Dockerfile . # Build docker image +``` + +**Build Script Details** (`build.sh`): +- Fetches frontend from OpenListTeam/OpenList-Frontend releases +- Injects version info via ldflags: `-X "github.com/OpenListTeam/OpenList/v4/internal/conf.BuiltAt=$(date +'%F %T %z')"` +- Supports `dev`, `beta`, and release builds +- Downloads prebuilt frontend distribution automatically + +**Go Version**: Requires Go 1.23.4+ + +## Architecture Overview + +### Driver System (Storage Abstraction) + +OpenList uses a **driver pattern** to support 70+ cloud storage providers. Each driver implements the core `Driver` interface. + +**Location**: `drivers/*/` + +**Core Interfaces** (`internal/driver/driver.go`): +- `Reader`: List directories, generate download links (REQUIRED) +- `Writer`: Upload, delete, move files (optional) +- `ArchiveDriver`: Extract archives (optional) +- `LinkCacheModeResolver`: Custom cache TTL strategies (optional) + +**Driver Registration Pattern**: +```go +// In drivers/your_driver/meta.go +var config = driver.Config{ + Name: "YourDriver", + LocalSort: false, + NoCache: false, + DefaultRoot: "/", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &YourDriver{} + }) +} +``` + +**Adding a New Driver**: +1. Copy `drivers/template/` to `drivers/your_driver/` +2. Implement `List()` and `Link()` methods (required) +3. 
Define `Addition` struct with configuration fields using struct tags: + - `json:"field_name"` - JSON field name + - `type:"select"` - Input type (select, string, text, bool, number) + - `required:"true"` - Required field + - `options:"a,b,c"` - Dropdown options + - `default:"value"` - Default value +4. Register driver in `init()` function + +**Example Driver Structure**: +```go +type YourDriver struct { + model.Storage + Addition + client *YourClient +} + +func (d *YourDriver) Init(ctx context.Context) error { + // Initialize client, login, etc. +} + +func (d *YourDriver) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + // Return list of files/folders +} + +func (d *YourDriver) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + // Return download URL or RangeReader +} +``` + +### Request Flow + +``` +HTTP Request (Gin Router) + ↓ +Middleware (Auth, CORS, Logging) + ↓ +Handler (server/handles/) + ↓ +fs.List/Get/Link (mount path → storage path conversion) + ↓ +op.List/Get/Link (caching, driver lookup) + ↓ +Driver.List/Link (storage-specific API calls) + ↓ +Response (JSON / Proxy / Redirect) +``` + +### Internal Package Structure + +| Package | Purpose | +|---------|---------| +| `bootstrap/` | Initialization sequence: config, DB, storages, servers | +| `conf/` | Configuration management | +| `db/` | Database models (SQLite/MySQL/Postgres) | +| `driver/` | Driver interface definitions | +| `fs/` | Mount path abstraction (converts `/mount/path` to storage + path) | +| `op/` | Core operations with caching and driver management | +| `stream/` | Streaming, range readers, link refresh, rate limiting | +| `model/` | Data models (Obj, Link, Storage, User) | +| `cache/` | Multi-level caching (directories, links, users, settings) | +| `net/` | HTTP utilities, proxy config, download manager | + +### Link Generation and Caching + +**Link Types**: +1. **Direct URL** (`link.URL`): Simple redirect to storage provider +2. **RangeReader** (`link.RangeReader`): Custom streaming implementation +3. **Refreshable Link** (`link.Refresher`): Auto-refresh on expiration + +**Cache System** (`internal/op/cache.go`): +- **Directory Cache**: Stores file listings with configurable TTL +- **Link Cache**: Stores download URLs (30min default) +- **User Cache**: Authentication data (1hr default) +- **Custom Policies**: Pattern-based TTL via `pattern:ttl` format + +**Cache Key Pattern**: `{storageMountPath}/{relativePath}` + +**Invalidation**: Recursive tree deletion for directory operations + +### Range Reader and Streaming + +**Location**: `internal/stream/` + +**Purpose**: Handle partial content requests (HTTP 206), multi-threaded downloads, and link refresh during streaming. + +**Key Components**: + +1. **RangeReaderIF**: Core interface for range-based reading + ```go + type RangeReaderIF interface { + RangeRead(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) + } + ``` + +2. **RefreshableRangeReader**: Wraps RangeReader with automatic link refresh + - Detects expired links via error strings or HTTP status codes (401, 403, 410, 500) + - Calls `link.Refresher(ctx)` to get new link + - Resumes download from current byte position + - Max 3 refresh attempts to prevent infinite loops + +3. 
**Multi-threaded Downloader** (`internal/net/downloader.go`): + - Splits file into parts based on `Concurrency` and `PartSize` + - Downloads parts in parallel + - Assembles final stream + +**Link Refresh Pattern**: +```go +// In op.Link(), a refresher is automatically attached +link.Refresher = func(refreshCtx context.Context) (*model.Link, model.Obj, error) { + // Get fresh link from storage driver + file, err := GetUnwrap(refreshCtx, storage, path) + newLink, err := storage.Link(refreshCtx, file, args) + return newLink, file, nil +} + +// RefreshableRangeReader uses this during streaming +if IsLinkExpiredError(err) && r.link.Refresher != nil { + newLink, _, err := r.link.Refresher(ctx) + // Resume from current position +} +``` + +**Proxy Function** (`server/common/proxy.go`): + +Handles multiple scenarios: +1. Multi-threaded download (`link.Concurrency > 0`) +2. Direct RangeReader (`link.RangeReader != nil`) +3. Refreshable link (`link.Refresher != nil`) ← Wraps with RefreshableRangeReader +4. Transparent proxy (forwards to `link.URL`) + +### Startup Sequence + +**Location**: `internal/bootstrap/run.go` + +Order of initialization: +1. `InitConfig()` - Load config, environment variables +2. `Log()` - Initialize logging +3. `InitDB()` - Connect to database +4. `data.InitData()` - Initialize default data +5. `LoadStorages()` - Load and initialize all storage drivers +6. `InitTaskManager()` - Start background tasks +7. `Start()` - Start HTTP/HTTPS/WebDAV/FTP/SFTP servers + +## Common Patterns + +### Error Handling + +Use custom errors from `internal/errs/`: +- `errs.NotImplement` - Feature not implemented +- `errs.ObjectNotFound` - File/folder not found +- `errs.NotFolder` - Path is not a directory +- `errs.StorageNotInit` - Storage driver not initialized + +**Link Expiry Detection**: +```go +// Checks error string for keywords: "expired", "invalid signature", "token expired" +// Also checks HTTP status: 401, 403, 410, 500 +if stream.IsLinkExpiredError(err) { + // Refresh link +} +``` + +### Saving Driver State + +When updating tokens or credentials: +```go +d.AccessToken = newToken +op.MustSaveDriverStorage(d) // Persists to database +``` + +### Rate Limiting + +Use `rate.Limiter` for API rate limits: +```go +type YourDriver struct { + limiter *rate.Limiter +} + +func (d *YourDriver) Init(ctx context.Context) error { + d.limiter = rate.NewLimiter(rate.Every(time.Second), 1) // 1 req/sec +} + +func (d *YourDriver) List(...) 
{ + d.limiter.Wait(ctx) + // Make API call +} +``` + +### Context Cancellation + +Always respect context cancellation in long operations: +```go +select { +case <-ctx.Done(): + return nil, ctx.Err() +default: + // Continue operation +} +``` + +## Important Conventions + +**Naming**: +- Drivers: lowercase with underscores (e.g., `baidu_netdisk`, `aliyundrive_open`) +- Packages: lowercase (e.g., `internal/op`) +- Interfaces: PascalCase with suffix (e.g., `Reader`, `Writer`) + +**Driver Configuration Fields**: +- Use `driver.RootPath` or `driver.RootID` for root folder +- Add `omitempty` to optional JSON fields +- Use descriptive help text in struct tags + +**Retries and Timeouts**: +- Use `github.com/avast/retry-go` for retry logic +- Set reasonable timeouts on HTTP clients (default 30s in `base.RestyClient`) +- For unstable APIs, implement exponential backoff + +**Logging**: +- Use `logrus` via `log` package +- Levels: `log.Debugf`, `log.Infof`, `log.Warnf`, `log.Errorf` +- Include driver name in logs: `log.Infof("[driver_name] message")` + +## Project Context + +OpenList is a community-driven fork of AList, focused on: +- Long-term governance and trust +- Support for 70+ cloud storage providers +- Web UI for file management +- Multi-protocol support (HTTP, WebDAV, FTP, SFTP, S3) +- Offline downloads (Aria2, Transmission) +- Full-text search +- Archive extraction + +**License**: AGPL-3.0 From 31c4dad788b2e8632e590f4a4aabd514ee7b0bab Mon Sep 17 00:00:00 2001 From: cyk Date: Sun, 4 Jan 2026 16:11:15 +0800 Subject: [PATCH 14/20] feat(upload): enhance hash calculation and upload logic for various stream types --- CLAUDE.md | 49 +++++++++++++ drivers/115_open/driver.go | 105 +++++++++++++++++++++++----- drivers/123_open/driver.go | 49 +++++++------ drivers/aliyundrive_open/upload.go | 41 +++++++---- drivers/openlist/driver.go | 106 +++++++++++++++++++++++++++-- internal/stream/stream.go | 12 +--- internal/stream/util.go | 22 ++++-- pkg/utils/hash.go | 6 ++ server/handles/fsup.go | 12 ++++ 9 files changed, 332 insertions(+), 70 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 6a1e1461c..c3ffdded0 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -2,6 +2,12 @@ This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. +## Core Development Principles + +1. **最小代码改动原则** (Minimum code changes): Make the smallest change necessary to achieve the goal +2. **不缓存整个文件原则** (No full file caching for seekable streams): For SeekableStream, use RangeRead instead of caching entire file +3. **必要情况下可以多遍上传原则** (Multi-pass upload when necessary): If rapid upload fails, fall back to normal upload + ## Build and Development Commands ```bash @@ -166,6 +172,49 @@ Response (JSON / Proxy / Redirect) - Downloads parts in parallel - Assembles final stream +**Stream Types and Reader Management**: + +⚠️ **CRITICAL**: SeekableStream.Reader must NEVER be created early! 
+ +- **FileStream**: One-time sequential stream (e.g., HTTP body) + - `Reader` is set at creation and consumed sequentially + - Cannot be rewound or re-read + +- **SeekableStream**: Reusable stream with RangeRead capability + - Has `rangeReader` for creating new readers on-demand + - `Reader` should ONLY be created when actually needed for sequential reading + - **DO NOT create Reader early** - use lazy initialization via `generateReader()` + +**Common Pitfall - Early Reader Creation**: +```go +// ❌ WRONG: Creating Reader early +if _, ok := rr.(*model.FileRangeReader); ok { + rc, _ := rr.RangeRead(ctx, http_range.Range{Length: -1}) + fs.Reader = rc // This will be consumed by intermediate operations! +} + +// ✅ CORRECT: Let generateReader() create it on-demand +// Reader will be created only when Read() is called +return &SeekableStream{FileStream: fs, rangeReader: rr}, nil +``` + +**Why This Matters**: +- Hash calculation uses `StreamHashFile()` which reads the file via RangeRead +- If Reader is created early, it may be at EOF when HTTP upload actually needs it +- Result: `http: ContentLength=X with Body length 0` error + +**Hash Calculation for Uploads**: +```go +// For SeekableStream: Use RangeRead to avoid consuming Reader +if _, ok := file.(*SeekableStream); ok { + hash, err = stream.StreamHashFile(file, utils.MD5, 40, &up) + // StreamHashFile uses RangeRead internally, Reader remains unused +} + +// For FileStream: Must cache first, then calculate hash +_, hash, err = stream.CacheFullAndHash(file, &up, utils.MD5) +``` + **Link Refresh Pattern**: ```go // In op.Link(), a refresher is automatically attached diff --git a/drivers/115_open/driver.go b/drivers/115_open/driver.go index f9d5027bf..dbccb23bc 100644 --- a/drivers/115_open/driver.go +++ b/drivers/115_open/driver.go @@ -226,28 +226,97 @@ func (d *Open115) Put(ctx context.Context, dstDir model.Obj, file model.FileStre if err != nil { return err } + sha1 := file.GetHash().GetHash(utils.SHA1) - if len(sha1) != utils.SHA1.Width { - // 流式计算SHA1 - sha1, err = stream.StreamHashFile(file, utils.SHA1, 100, &up) + sha1128k := file.GetHash().GetHash(utils.SHA1_128K) + + // 检查是否是可重复读取的流 + _, isSeekable := file.(*stream.SeekableStream) + + // 如果有预计算的 hash,先尝试秒传 + if len(sha1) == utils.SHA1.Width && len(sha1128k) == utils.SHA1_128K.Width { + resp, err := d.client.UploadInit(ctx, &sdk.UploadInitReq{ + FileName: file.GetName(), + FileSize: file.GetSize(), + Target: dstDir.GetID(), + FileID: strings.ToUpper(sha1), + PreID: strings.ToUpper(sha1128k), + }) if err != nil { return err } + if resp.Status == 2 { + up(100) + return nil + } + // 秒传失败,继续后续流程 } - const PreHashSize int64 = 128 * utils.KB - hashSize := PreHashSize - if file.GetSize() < PreHashSize { - hashSize = file.GetSize() - } - reader, err := file.RangeRead(http_range.Range{Start: 0, Length: hashSize}) - if err != nil { - return err - } - sha1128k, err := utils.HashReader(utils.SHA1, reader) - if err != nil { - return err + + if isSeekable { + // 可重复读取的流,使用 RangeRead 计算 hash,不缓存 + if len(sha1) != utils.SHA1.Width { + sha1, err = stream.StreamHashFile(file, utils.SHA1, 100, &up) + if err != nil { + return err + } + } + // 计算 sha1_128k(如果没有预计算) + if len(sha1128k) != utils.SHA1_128K.Width { + const PreHashSize int64 = 128 * utils.KB + hashSize := PreHashSize + if file.GetSize() < PreHashSize { + hashSize = file.GetSize() + } + reader, err := file.RangeRead(http_range.Range{Start: 0, Length: hashSize}) + if err != nil { + return err + } + sha1128k, err = utils.HashReader(utils.SHA1, 
reader) + if err != nil { + return err + } + } + } else { + // 不可重复读取的流(如 HTTP body) + // 如果有预计算的 hash,上面已经尝试过秒传了 + if len(sha1) == utils.SHA1.Width && len(sha1128k) == utils.SHA1_128K.Width { + // 秒传失败,需要缓存文件进行实际上传 + _, err = file.CacheFullAndWriter(&up, nil) + if err != nil { + return err + } + } else { + // 没有预计算的 hash,缓存整个文件并计算 + if len(sha1) != utils.SHA1.Width { + _, sha1, err = stream.CacheFullAndHash(file, &up, utils.SHA1) + if err != nil { + return err + } + } else if file.GetFile() == nil { + // 有 SHA1 但没有缓存,需要缓存以支持后续 RangeRead + _, err = file.CacheFullAndWriter(&up, nil) + if err != nil { + return err + } + } + // 计算 sha1_128k + const PreHashSize int64 = 128 * utils.KB + hashSize := PreHashSize + if file.GetSize() < PreHashSize { + hashSize = file.GetSize() + } + reader, err := file.RangeRead(http_range.Range{Start: 0, Length: hashSize}) + if err != nil { + return err + } + sha1128k, err = utils.HashReader(utils.SHA1, reader) + if err != nil { + return err + } + } } - // 1. Init + + // 1. Init(SeekableStream 或已缓存的 FileStream) resp, err := d.client.UploadInit(ctx, &sdk.UploadInitReq{ FileName: file.GetName(), FileSize: file.GetSize(), @@ -273,11 +342,11 @@ func (d *Open115) Put(ctx context.Context, dstDir model.Obj, file model.FileStre if err != nil { return err } - reader, err = file.RangeRead(http_range.Range{Start: start, Length: end - start + 1}) + signReader, err := file.RangeRead(http_range.Range{Start: start, Length: end - start + 1}) if err != nil { return err } - signVal, err := utils.HashReader(utils.SHA1, reader) + signVal, err := utils.HashReader(utils.SHA1, signReader) if err != nil { return err } diff --git a/drivers/123_open/driver.go b/drivers/123_open/driver.go index 7cf3cfe46..0a31bc284 100644 --- a/drivers/123_open/driver.go +++ b/drivers/123_open/driver.go @@ -184,35 +184,46 @@ func (d *Open123) Put(ctx context.Context, dstDir model.Obj, file model.FileStre // etag 文件md5 etag := file.GetHash().GetHash(utils.MD5) + + // 检查是否是可重复读取的流 + _, isSeekable := file.(*stream.SeekableStream) + + // 如果有预计算的 hash,先尝试秒传 if len(etag) >= utils.MD5.Width { - // 有etag时,先尝试秒传 createResp, err := d.create(parentFileId, file.GetName(), etag, file.GetSize(), 2, false) if err != nil { return nil, err } - // 是否秒传 - if createResp.Data.Reuse { - // 秒传成功才会返回正确的 FileID,否则为 0 - if createResp.Data.FileID != 0 { - return File{ - FileName: file.GetName(), - Size: file.GetSize(), - FileId: createResp.Data.FileID, - Type: 2, - Etag: etag, - }, nil - } + if createResp.Data.Reuse && createResp.Data.FileID != 0 { + return File{ + FileName: file.GetName(), + Size: file.GetSize(), + FileId: createResp.Data.FileID, + Type: 2, + Etag: etag, + }, nil } - // 秒传失败,etag可能不可靠,继续流式计算真实MD5 + // 秒传失败,继续后续流程 } - // 流式MD5计算 - etag, err = stream.StreamHashFile(file, utils.MD5, 40, &up) - if err != nil { - return nil, err + if isSeekable { + // 可重复读取的流,使用 RangeRead 计算 hash,不缓存 + if len(etag) < utils.MD5.Width { + etag, err = stream.StreamHashFile(file, utils.MD5, 40, &up) + if err != nil { + return nil, err + } + } + } else { + // 不可重复读取的流(如 HTTP body) + // 秒传失败或没有 hash,缓存整个文件并计算 MD5 + _, etag, err = stream.CacheFullAndHash(file, &up, utils.MD5) + if err != nil { + return nil, err + } } - // 2. 创建上传任务 + // 2. 
创建上传任务(或再次尝试秒传) createResp, err := d.create(parentFileId, file.GetName(), etag, file.GetSize(), 2, false) if err != nil { return nil, err diff --git a/drivers/aliyundrive_open/upload.go b/drivers/aliyundrive_open/upload.go index a4a6c1de1..00c806e5f 100644 --- a/drivers/aliyundrive_open/upload.go +++ b/drivers/aliyundrive_open/upload.go @@ -163,21 +163,29 @@ func (d *AliyundriveOpen) upload(ctx context.Context, dstDir model.Obj, stream m } count := int(math.Ceil(float64(stream.GetSize()) / float64(partSize))) createData["part_info_list"] = makePartInfos(count) + + // 检查是否是可重复读取的流 + _, isSeekable := stream.(*streamPkg.SeekableStream) + // rapid upload rapidUpload := !stream.IsForceStreamUpload() && stream.GetSize() > 100*utils.KB && d.RapidUpload if rapidUpload { log.Debugf("[aliyundrive_open] start cal pre_hash") - // read 1024 bytes to calculate pre hash - reader, err := stream.RangeRead(http_range.Range{Start: 0, Length: 1024}) - if err != nil { - return nil, err - } - hash, err := utils.HashReader(utils.SHA1, reader) - if err != nil { - return nil, err + // 优先使用预计算的 pre_hash + preHash := stream.GetHash().GetHash(utils.PRE_HASH) + if len(preHash) != utils.PRE_HASH.Width { + // 没有预计算的 pre_hash,使用 RangeRead 计算 + reader, err := stream.RangeRead(http_range.Range{Start: 0, Length: 1024}) + if err != nil { + return nil, err + } + preHash, err = utils.HashReader(utils.SHA1, reader) + if err != nil { + return nil, err + } } createData["size"] = stream.GetSize() - createData["pre_hash"] = hash + createData["pre_hash"] = preHash } var createResp CreateResp _, err, e := d.requestReturnErrResp(ctx, limiterOther, "/adrive/v1.0/openFile/create", http.MethodPost, func(req *resty.Request) { @@ -191,9 +199,18 @@ func (d *AliyundriveOpen) upload(ctx context.Context, dstDir model.Obj, stream m hash := stream.GetHash().GetHash(utils.SHA1) if len(hash) != utils.SHA1.Width { - _, hash, err = streamPkg.CacheFullAndHash(stream, &up, utils.SHA1) - if err != nil { - return nil, err + if isSeekable { + // 可重复读取的流,使用 StreamHashFile(RangeRead),不缓存 + hash, err = streamPkg.StreamHashFile(stream, utils.SHA1, 50, &up) + if err != nil { + return nil, err + } + } else { + // 不可重复读取的流,缓存并计算 + _, hash, err = streamPkg.CacheFullAndHash(stream, &up, utils.SHA1) + if err != nil { + return nil, err + } } } diff --git a/drivers/openlist/driver.go b/drivers/openlist/driver.go index 2ca60ff61..9b69bbeb5 100644 --- a/drivers/openlist/driver.go +++ b/drivers/openlist/driver.go @@ -14,6 +14,8 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/driver" "github.com/OpenListTeam/OpenList/v4/internal/errs" "github.com/OpenListTeam/OpenList/v4/internal/model" + "github.com/OpenListTeam/OpenList/v4/internal/stream" + "github.com/OpenListTeam/OpenList/v4/pkg/http_range" "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/OpenListTeam/OpenList/v4/server/common" "github.com/go-resty/resty/v2" @@ -195,6 +197,92 @@ func (d *OpenList) Remove(ctx context.Context, obj model.Obj) error { } func (d *OpenList) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) error { + // 预计算 hash(如果不存在),使用 RangeRead 不消耗 Reader + // 这样远端驱动不需要再计算,避免 HTTP body 被重复读取 + md5Hash := s.GetHash().GetHash(utils.MD5) + sha1Hash := s.GetHash().GetHash(utils.SHA1) + sha256Hash := s.GetHash().GetHash(utils.SHA256) + sha1_128kHash := s.GetHash().GetHash(utils.SHA1_128K) + preHash := s.GetHash().GetHash(utils.PRE_HASH) + + // 计算所有缺失的 hash,确保最大兼容性 + if len(md5Hash) != utils.MD5.Width { + var err error + md5Hash, err = 
stream.StreamHashFile(s, utils.MD5, 33, &up) + if err != nil { + log.Warnf("[openlist] failed to pre-calculate MD5: %v", err) + md5Hash = "" + } + } + if len(sha1Hash) != utils.SHA1.Width { + var err error + sha1Hash, err = stream.StreamHashFile(s, utils.SHA1, 33, &up) + if err != nil { + log.Warnf("[openlist] failed to pre-calculate SHA1: %v", err) + sha1Hash = "" + } + } + if len(sha256Hash) != utils.SHA256.Width { + var err error + sha256Hash, err = stream.StreamHashFile(s, utils.SHA256, 34, &up) + if err != nil { + log.Warnf("[openlist] failed to pre-calculate SHA256: %v", err) + sha256Hash = "" + } + } + + // 计算特殊 hash(用于秒传验证) + // SHA1_128K: 前128KB的SHA1,115网盘使用 + if len(sha1_128kHash) != utils.SHA1_128K.Width { + const PreHashSize int64 = 128 * 1024 // 128KB + hashSize := PreHashSize + if s.GetSize() < PreHashSize { + hashSize = s.GetSize() + } + reader, err := s.RangeRead(http_range.Range{Start: 0, Length: hashSize}) + if err == nil { + sha1_128kHash, err = utils.HashReader(utils.SHA1, reader) + if closer, ok := reader.(io.Closer); ok { + _ = closer.Close() + } + if err != nil { + log.Warnf("[openlist] failed to pre-calculate SHA1_128K: %v", err) + sha1_128kHash = "" + } + } else { + log.Warnf("[openlist] failed to RangeRead for SHA1_128K: %v", err) + } + } + + // PRE_HASH: 前1024字节的SHA1,阿里云盘使用 + if len(preHash) != utils.PRE_HASH.Width { + const PreHashSize int64 = 1024 // 1KB + hashSize := PreHashSize + if s.GetSize() < PreHashSize { + hashSize = s.GetSize() + } + reader, err := s.RangeRead(http_range.Range{Start: 0, Length: hashSize}) + if err == nil { + preHash, err = utils.HashReader(utils.SHA1, reader) + if closer, ok := reader.(io.Closer); ok { + _ = closer.Close() + } + if err != nil { + log.Warnf("[openlist] failed to pre-calculate PRE_HASH: %v", err) + preHash = "" + } + } else { + log.Warnf("[openlist] failed to RangeRead for PRE_HASH: %v", err) + } + } + + // 诊断日志:检查流的状态 + if ss, ok := s.(*stream.SeekableStream); ok { + if ss.Reader != nil { + log.Warnf("[openlist] WARNING: SeekableStream.Reader is not nil for file %s, stream may have been consumed!", s.GetName()) + } + } + reader := driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ Reader: s, UpdateProgress: up, @@ -206,14 +294,20 @@ func (d *OpenList) Put(ctx context.Context, dstDir model.Obj, s model.FileStream req.Header.Set("Authorization", d.Token) req.Header.Set("File-Path", path.Join(dstDir.GetPath(), s.GetName())) req.Header.Set("Password", d.MetaPassword) - if md5 := s.GetHash().GetHash(utils.MD5); len(md5) > 0 { - req.Header.Set("X-File-Md5", md5) + if len(md5Hash) > 0 { + req.Header.Set("X-File-Md5", md5Hash) + } + if len(sha1Hash) > 0 { + req.Header.Set("X-File-Sha1", sha1Hash) + } + if len(sha256Hash) > 0 { + req.Header.Set("X-File-Sha256", sha256Hash) } - if sha1 := s.GetHash().GetHash(utils.SHA1); len(sha1) > 0 { - req.Header.Set("X-File-Sha1", sha1) + if len(sha1_128kHash) > 0 { + req.Header.Set("X-File-Sha1-128k", sha1_128kHash) } - if sha256 := s.GetHash().GetHash(utils.SHA256); len(sha256) > 0 { - req.Header.Set("X-File-Sha256", sha256) + if len(preHash) > 0 { + req.Header.Set("X-File-Pre-Hash", preHash) } req.ContentLength = s.GetSize() diff --git a/internal/stream/stream.go b/internal/stream/stream.go index c29dbbec3..7eec75dd9 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -299,15 +299,9 @@ func NewSeekableStream(fs *FileStream, link *model.Link) (*SeekableStream, error if err != nil { return nil, err } - if _, ok := rr.(*model.FileRangeReader); ok { - var 
rc io.ReadCloser - rc, err = rr.RangeRead(fs.Ctx, http_range.Range{Length: -1}) - if err != nil { - return nil, err - } - fs.Reader = rc - fs.Add(rc) - } + // IMPORTANT: Do NOT create Reader early for FileRangeReader! + // Let generateReader() create it on-demand when actually needed for reading + // This prevents the Reader from being consumed by intermediate operations like hash calculation fs.size = size fs.Add(link) return &SeekableStream{FileStream: fs, rangeReader: rr}, nil diff --git a/internal/stream/util.go b/internal/stream/util.go index ddd215a43..1ee9f7d99 100644 --- a/internal/stream/util.go +++ b/internal/stream/util.go @@ -246,17 +246,27 @@ func StreamHashFile(file model.FileStreamer, hashType *utils.HashType, progressW readSize = size - offset } - // 首先尝试顺序流读取(不消耗额外资源,适用于所有流类型) - n, err := io.ReadFull(file, buf[:readSize]) - if err != nil { - // 顺序流读取失败,尝试使用 RangeRead 重试(适用于 SeekableStream) - log.Warnf("StreamHashFile: sequential read failed at offset %d, retrying with RangeRead: %v", offset, err) + var n int + var err error + + // 对于 SeekableStream,优先使用 RangeRead 避免消耗 Reader + // 这样后续发送时 Reader 还能正常工作 + if _, ok := file.(*SeekableStream); ok { n, err = readFullWithRangeRead(file, buf[:readSize], offset) + } else { + // 对于 FileStream,首先尝试顺序流读取(不消耗额外资源,适用于所有流类型) + n, err = io.ReadFull(file, buf[:readSize]) if err != nil { - return "", fmt.Errorf("calculate hash failed at offset %d: sequential read and RangeRead both failed: %w", offset, err) + // 顺序流读取失败,尝试使用 RangeRead 重试(适用于 SeekableStream) + log.Warnf("StreamHashFile: sequential read failed at offset %d, retrying with RangeRead: %v", offset, err) + n, err = readFullWithRangeRead(file, buf[:readSize], offset) } } + if err != nil { + return "", fmt.Errorf("calculate hash failed at offset %d: %w", offset, err) + } + hashFunc.Write(buf[:n]) offset += int64(n) diff --git a/pkg/utils/hash.go b/pkg/utils/hash.go index 596e61e54..c4b4e735f 100644 --- a/pkg/utils/hash.go +++ b/pkg/utils/hash.go @@ -90,6 +90,12 @@ var ( // SHA256 indicates SHA-256 support SHA256 = RegisterHash("sha256", "SHA-256", 64, sha256.New) + + // SHA1_128K is SHA1 of first 128KB, used by 115 driver for rapid upload + SHA1_128K = RegisterHash("sha1_128k", "SHA1-128K", 40, sha1.New) + + // PRE_HASH is SHA1 of first 1024 bytes, used by Aliyundrive for rapid upload + PRE_HASH = RegisterHash("pre_hash", "PRE-HASH", 40, sha1.New) ) // HashData get hash of one hashType diff --git a/server/handles/fsup.go b/server/handles/fsup.go index 0f46398cd..54cdb4fee 100644 --- a/server/handles/fsup.go +++ b/server/handles/fsup.go @@ -93,6 +93,12 @@ func FsStream(c *gin.Context) { if sha256 := c.GetHeader("X-File-Sha256"); sha256 != "" { h[utils.SHA256] = sha256 } + if sha1_128k := c.GetHeader("X-File-Sha1-128k"); sha1_128k != "" { + h[utils.SHA1_128K] = sha1_128k + } + if preHash := c.GetHeader("X-File-Pre-Hash"); preHash != "" { + h[utils.PRE_HASH] = preHash + } mimetype := c.GetHeader("Content-Type") if len(mimetype) == 0 { mimetype = utils.GetMimeType(name) @@ -190,6 +196,12 @@ func FsForm(c *gin.Context) { if sha256 := c.GetHeader("X-File-Sha256"); sha256 != "" { h[utils.SHA256] = sha256 } + if sha1_128k := c.GetHeader("X-File-Sha1-128k"); sha1_128k != "" { + h[utils.SHA1_128K] = sha1_128k + } + if preHash := c.GetHeader("X-File-Pre-Hash"); preHash != "" { + h[utils.PRE_HASH] = preHash + } mimetype := file.Header.Get("Content-Type") if len(mimetype) == 0 { mimetype = utils.GetMimeType(name) From 3d2e1d9537873d6fad223c721e135603e2a29972 Mon Sep 17 00:00:00 2001 From: cyk 
Date: Sun, 4 Jan 2026 18:53:38 +0800 Subject: [PATCH 15/20] fix(stream): improve thread safety and handling for SeekableStream and FileStream in directSectionReader --- internal/stream/util.go | 42 ++++++++++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/internal/stream/util.go b/internal/stream/util.go index 1ee9f7d99..a72fb7990 100644 --- a/internal/stream/util.go +++ b/internal/stream/util.go @@ -398,8 +398,16 @@ type directSectionReader struct { bufPool *pool.Pool[[]byte] } -// 线程不安全 +// 线程不安全(依赖调用方保证串行调用) +// 对于 SeekableStream:直接跳过(无需实际读取) +// 对于 FileStream:必须顺序读取并丢弃 func (ss *directSectionReader) DiscardSection(off int64, length int64) error { + // 对于 SeekableStream,直接跳过(RangeRead 支持随机访问,不需要实际读取) + if _, ok := ss.file.(*SeekableStream); ok { + return nil + } + + // 对于 FileStream,必须顺序读取并丢弃 if off != ss.fileOffset { return fmt.Errorf("stream not cached: request offset %d != current offset %d", off, ss.fileOffset) } @@ -416,31 +424,35 @@ type bufferSectionReader struct { buf []byte } -// 线程不安全 +// 线程不安全(依赖调用方保证串行调用) +// 对于 SeekableStream:使用 RangeRead,支持随机访问(续传场景可跳过已上传分片) +// 对于 FileStream:必须顺序读取 func (ss *directSectionReader) GetSectionReader(off, length int64) (io.ReadSeeker, error) { - if off != ss.fileOffset { - return nil, fmt.Errorf("stream not cached: request offset %d != current offset %d", off, ss.fileOffset) - } tempBuf := ss.bufPool.Get() buf := tempBuf[:length] - // 首先尝试顺序流读取(不消耗额外资源,适用于所有流类型) - // 对于 FileStream,RangeRead 会消耗底层 oriReader,所以必须先尝试顺序流读取 - n, err := io.ReadFull(ss.file, buf) - if err == nil { - ss.fileOffset = off + int64(n) + // 对于 SeekableStream,直接使用 RangeRead(支持随机访问,适用于续传场景) + if _, ok := ss.file.(*SeekableStream); ok { + n, err := readFullWithRangeRead(ss.file, buf, off) + if err != nil { + ss.bufPool.Put(tempBuf) + return nil, fmt.Errorf("RangeRead failed at offset %d: (expect=%d, actual=%d) %w", off, length, n, err) + } return &bufferSectionReader{bytes.NewReader(buf), tempBuf}, nil } - // 顺序流读取失败,尝试使用 RangeRead 重试(适用于 SeekableStream) - log.Debugf("Sequential read failed at offset %d, retrying with RangeRead: %v", off, err) - n, err = readFullWithRangeRead(ss.file, buf, off) + // 对于 FileStream,必须顺序读取 + if off != ss.fileOffset { + ss.bufPool.Put(tempBuf) + return nil, fmt.Errorf("stream not cached: request offset %d != current offset %d", off, ss.fileOffset) + } + + n, err := io.ReadFull(ss.file, buf) if err != nil { ss.bufPool.Put(tempBuf) - return nil, fmt.Errorf("both sequential read and RangeRead failed at offset %d: (expect=%d, actual=%d) %w", off, length, n, err) + return nil, fmt.Errorf("sequential read failed at offset %d: (expect=%d, actual=%d) %w", off, length, n, err) } - // 更新 fileOffset ss.fileOffset = off + int64(n) return &bufferSectionReader{bytes.NewReader(buf), tempBuf}, nil } From fe62e35073aa6958e0f3f5ba368a5c07309cd116 Mon Sep 17 00:00:00 2001 From: cyk Date: Sun, 4 Jan 2026 19:37:12 +0800 Subject: [PATCH 16/20] feat(link): add link refresh capability for expired download links --- internal/model/args.go | 8 +++ internal/op/fs.go | 22 +++++++ internal/stream/util.go | 124 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 154 insertions(+) diff --git a/internal/model/args.go b/internal/model/args.go index 073c94a63..d165908fb 100644 --- a/internal/model/args.go +++ b/internal/model/args.go @@ -25,6 +25,10 @@ type LinkArgs struct { Redirect bool } +// LinkRefresher is a callback function type for refreshing download links +// It returns a new Link and the associated object, or an 
error +type LinkRefresher func(ctx context.Context) (*Link, Obj, error) + type Link struct { URL string `json:"url"` // most common way Header http.Header `json:"header"` // needed header (for url) @@ -37,6 +41,10 @@ type Link struct { PartSize int `json:"part_size"` ContentLength int64 `json:"content_length"` // 转码视频、缩略图 + // Refresher is a callback to refresh the link when it expires during long downloads + // This field is not serialized and is optional - if nil, no refresh will be attempted + Refresher LinkRefresher `json:"-"` + utils.SyncClosers `json:"-"` // 如果SyncClosers中的资源被关闭后Link将不可用,则此值应为 true RequireReference bool `json:"-"` diff --git a/internal/op/fs.go b/internal/op/fs.go index 5116bbef5..2c91e6cf3 100644 --- a/internal/op/fs.go +++ b/internal/op/fs.go @@ -262,6 +262,28 @@ func Link(ctx context.Context, storage driver.Driver, path string, args model.Li if err != nil { return nil, errors.Wrapf(err, "failed get link") } + + // Set up link refresher for automatic refresh on expiry during long downloads + // This enables all download scenarios to handle link expiration gracefully + if link.Refresher == nil { + storageCopy := storage + pathCopy := path + argsCopy := args + link.Refresher = func(refreshCtx context.Context) (*model.Link, model.Obj, error) { + log.Infof("Refreshing download link for: %s", pathCopy) + // Get fresh link directly from storage, bypassing cache + file, err := GetUnwrap(refreshCtx, storageCopy, pathCopy) + if err != nil { + return nil, nil, errors.WithMessage(err, "failed to get file for refresh") + } + newLink, err := storageCopy.Link(refreshCtx, file, argsCopy) + if err != nil { + return nil, nil, errors.Wrapf(err, "failed to refresh link") + } + return newLink, file, nil + } + } + ol := &objWithLink{link: link, obj: file} if link.Expiration != nil { Cache.linkCache.SetTypeWithTTL(key, typeKey, ol, *link.Expiration) diff --git a/internal/stream/util.go b/internal/stream/util.go index a72fb7990..db84badeb 100644 --- a/internal/stream/util.go +++ b/internal/stream/util.go @@ -9,6 +9,8 @@ import ( "io" "net/http" "os" + "strings" + "sync" "time" "github.com/OpenListTeam/OpenList/v4/internal/conf" @@ -28,7 +30,129 @@ func (f RangeReaderFunc) RangeRead(ctx context.Context, httpRange http_range.Ran return f(ctx, httpRange) } +// IsLinkExpiredError checks if the error indicates an expired download link +func IsLinkExpiredError(err error) bool { + if err == nil { + return false + } + errStr := strings.ToLower(err.Error()) + + // Common expired link error keywords + expiredKeywords := []string{ + "expired", "invalid signature", "token expired", + "access denied", "forbidden", "unauthorized", + "link has expired", "url expired", "request has expired", + "signature expired", "accessdenied", "invalidtoken", + } + for _, keyword := range expiredKeywords { + if strings.Contains(errStr, keyword) { + return true + } + } + + // Check for HTTP status codes that typically indicate expired links + if statusErr, ok := errs.UnwrapOrSelf(err).(net.HttpStatusCodeError); ok { + code := int(statusErr) + // 401 Unauthorized, 403 Forbidden, 410 Gone are common for expired links + // 500 Internal Server Error - some providers (e.g., Baidu) return 500 when link expires + if code == 401 || code == 403 || code == 410 || code == 500 { + return true + } + } + + return false +} + +// RefreshableRangeReader wraps a RangeReader with link refresh capability +type RefreshableRangeReader struct { + link *model.Link + size int64 + innerReader model.RangeReaderIF + mu sync.Mutex + refreshCount 
int // track refresh count to avoid infinite loops +} + +// NewRefreshableRangeReader creates a new RefreshableRangeReader +func NewRefreshableRangeReader(link *model.Link, size int64) *RefreshableRangeReader { + return &RefreshableRangeReader{ + link: link, + size: size, + } +} + +func (r *RefreshableRangeReader) getInnerReader() (model.RangeReaderIF, error) { + if r.innerReader != nil { + return r.innerReader, nil + } + + // Create inner reader without Refresher to avoid recursion + linkCopy := *r.link + linkCopy.Refresher = nil + + reader, err := GetRangeReaderFromLink(r.size, &linkCopy) + if err != nil { + return nil, err + } + r.innerReader = reader + return reader, nil +} + +func (r *RefreshableRangeReader) RangeRead(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { + r.mu.Lock() + reader, err := r.getInnerReader() + r.mu.Unlock() + if err != nil { + return nil, err + } + + rc, err := reader.RangeRead(ctx, httpRange) + if err != nil { + // Check if we should try to refresh on initial connection error + if IsLinkExpiredError(err) && r.link.Refresher != nil { + rc, err = r.refreshAndRetry(ctx, httpRange) + } + if err != nil { + return nil, err + } + } + + return rc, nil +} + +func (r *RefreshableRangeReader) refreshAndRetry(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.refreshCount >= 3 { + return nil, fmt.Errorf("max refresh attempts reached") + } + + log.Infof("Link expired, attempting to refresh...") + newLink, _, refreshErr := r.link.Refresher(ctx) + if refreshErr != nil { + return nil, fmt.Errorf("failed to refresh link: %w", refreshErr) + } + + newLink.Refresher = r.link.Refresher + r.link = newLink + r.innerReader = nil + r.refreshCount++ + + log.Infof("Link refreshed successfully, retrying request...") + + reader, err := r.getInnerReader() + if err != nil { + return nil, err + } + return reader.RangeRead(ctx, httpRange) +} + func GetRangeReaderFromLink(size int64, link *model.Link) (model.RangeReaderIF, error) { + // If link has a Refresher, wrap with RefreshableRangeReader for automatic refresh on expiry + if link.Refresher != nil { + return NewRefreshableRangeReader(link, size), nil + } + if link.RangeReader != nil { if link.Concurrency < 1 && link.PartSize < 1 { return link.RangeReader, nil From 5e39acd6d966ba2b4acb85d527c8b8d12812b99b Mon Sep 17 00:00:00 2001 From: cyk Date: Mon, 5 Jan 2026 17:08:58 +0800 Subject: [PATCH 17/20] feat(upload): add error handling for upload URL refresh on network errors --- drivers/baidu_netdisk/driver.go | 1 + drivers/baidu_netdisk/upload.go | 39 +++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/drivers/baidu_netdisk/driver.go b/drivers/baidu_netdisk/driver.go index 1290279da..474dd2b98 100644 --- a/drivers/baidu_netdisk/driver.go +++ b/drivers/baidu_netdisk/driver.go @@ -25,6 +25,7 @@ type BaiduNetdisk struct { } var ErrUploadIDExpired = errors.New("uploadid expired") +var ErrUploadURLExpired = errors.New("upload url expired or unavailable") func (d *BaiduNetdisk) Config() driver.Config { return config diff --git a/drivers/baidu_netdisk/upload.go b/drivers/baidu_netdisk/upload.go index 54b1f29bb..d3edec528 100644 --- a/drivers/baidu_netdisk/upload.go +++ b/drivers/baidu_netdisk/upload.go @@ -129,6 +129,13 @@ func (d *BaiduNetdisk) uploadChunksStream( retry.RetryIf(func(err error) bool { return !errors.Is(err, ErrUploadIDExpired) }), + retry.OnRetry(func(n uint, err error) { + // 重试前检测是否需要刷新上传 URL + if errors.Is(err, 
ErrUploadURLExpired) { + log.Infof("[baidu_netdisk] refreshing upload URL due to error: %v", err) + precreateResp.UploadURL = d.getUploadUrl(path, precreateResp.Uploadid) + } + }), retry.LastErrorOnly(true)) totalParts := len(precreateResp.BlockList) @@ -233,6 +240,11 @@ func (d *BaiduNetdisk) uploadSliceStream( } resp, err := client.Do(req) if err != nil { + // 检测超时或网络错误,标记需要刷新上传 URL + if isUploadURLError(err) { + log.Warnf("[baidu_netdisk] upload slice failed with network error: %v, will refresh upload URL", err) + return errors.Join(err, ErrUploadURLExpired) + } return err } defer resp.Body.Close() @@ -258,3 +270,30 @@ func (d *BaiduNetdisk) uploadSliceStream( } return nil } + +// isUploadURLError 判断是否为需要刷新上传 URL 的错误 +// 包括:超时、连接被拒绝、连接重置、DNS 解析失败等网络错误 +func isUploadURLError(err error) bool { + if err == nil { + return false + } + errStr := strings.ToLower(err.Error()) + // 超时错误 + if strings.Contains(errStr, "timeout") || + strings.Contains(errStr, "deadline exceeded") { + return true + } + // 连接错误 + if strings.Contains(errStr, "connection refused") || + strings.Contains(errStr, "connection reset") || + strings.Contains(errStr, "no such host") || + strings.Contains(errStr, "network is unreachable") { + return true + } + // EOF 错误(连接被服务器关闭) + if strings.Contains(errStr, "eof") || + strings.Contains(errStr, "broken pipe") { + return true + } + return false +} From 0425d4fb8bbf9a76739437821669e53efc77577e Mon Sep 17 00:00:00 2001 From: cyk Date: Mon, 5 Jan 2026 17:09:03 +0800 Subject: [PATCH 18/20] feat(link): implement ForceRefreshLink method for refreshing download links on read failure --- drivers/baidu_netdisk/upload.go | 54 +++++++++++++--------- internal/stream/stream.go | 8 ++++ internal/stream/util.go | 79 +++++++++++++++++++++++++++------ 3 files changed, 106 insertions(+), 35 deletions(-) diff --git a/drivers/baidu_netdisk/upload.go b/drivers/baidu_netdisk/upload.go index d3edec528..c160c3a9e 100644 --- a/drivers/baidu_netdisk/upload.go +++ b/drivers/baidu_netdisk/upload.go @@ -19,7 +19,6 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/net" streamPkg "github.com/OpenListTeam/OpenList/v4/internal/stream" "github.com/OpenListTeam/OpenList/v4/pkg/errgroup" - "github.com/OpenListTeam/OpenList/v4/pkg/http_range" "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/avast/retry-go" log "github.com/sirupsen/logrus" @@ -51,6 +50,11 @@ func (d *BaiduNetdisk) calculateHashesStream( sliceMd5H2 := md5.New() sliceWritten := int64(0) + // 使用固定大小的缓冲区进行流式哈希计算 + // 这样可以利用 readFullWithRangeRead 的链接刷新逻辑 + const chunkSize = 10 * 1024 * 1024 // 10MB per chunk + buf := make([]byte, chunkSize) + for i := 0; i < count; i++ { if utils.IsCanceled(ctx) { return "", "", nil, ctx.Err() @@ -62,31 +66,39 @@ func (d *BaiduNetdisk) calculateHashesStream( length = lastBlockSize } - // 使用 RangeRead 读取数据,不会消耗流 - reader, err := stream.RangeRead(http_range.Range{Start: offset, Length: length}) - if err != nil { - return "", "", nil, err - } - // 计算分片MD5 sliceMd5Calc := md5.New() - // 同时写入多个哈希计算器 - writers := []io.Writer{fileMd5H, sliceMd5Calc} - if sliceWritten < SliceSize { - remaining := SliceSize - sliceWritten - writers = append(writers, utils.LimitWriter(sliceMd5H2, remaining)) - } + // 分块读取并计算哈希 + var sliceOffset int64 = 0 + for sliceOffset < length { + readSize := chunkSize + if length-sliceOffset < int64(chunkSize) { + readSize = int(length - sliceOffset) + } - n, err := io.Copy(io.MultiWriter(writers...), reader) - // 关闭 reader(如果是 ReadCloser) - if rc, ok := reader.(io.Closer); ok { - 
rc.Close() - } - if err != nil { - return "", "", nil, err + // 使用 readFullWithRangeRead 读取数据,自动处理链接刷新 + n, err := streamPkg.ReadFullWithRangeRead(stream, buf[:readSize], offset+sliceOffset) + if err != nil { + return "", "", nil, err + } + + // 同时写入多个哈希计算器 + fileMd5H.Write(buf[:n]) + sliceMd5Calc.Write(buf[:n]) + if sliceWritten < SliceSize { + remaining := SliceSize - sliceWritten + if int64(n) > remaining { + sliceMd5H2.Write(buf[:remaining]) + sliceWritten += remaining + } else { + sliceMd5H2.Write(buf[:n]) + sliceWritten += int64(n) + } + } + + sliceOffset += int64(n) } - sliceWritten += n blockList = append(blockList, hex.EncodeToString(sliceMd5Calc.Sum(nil))) diff --git a/internal/stream/stream.go b/internal/stream/stream.go index 7eec75dd9..aaf310487 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -346,6 +346,14 @@ func (ss *SeekableStream) generateReader() error { return nil } +// ForceRefreshLink 实现 LinkRefresher 接口,用于在读取失败时刷新链接 +func (ss *SeekableStream) ForceRefreshLink(ctx context.Context) bool { + if rr, ok := ss.rangeReader.(*RefreshableRangeReader); ok { + return rr.ForceRefresh(ctx) + } + return false +} + func (ss *SeekableStream) CacheFullAndWriter(up *model.UpdateProgress, writer io.Writer) (model.File, error) { if err := ss.generateReader(); err != nil { return nil, err diff --git a/internal/stream/util.go b/internal/stream/util.go index db84badeb..00c3bde52 100644 --- a/internal/stream/util.go +++ b/internal/stream/util.go @@ -30,6 +30,13 @@ func (f RangeReaderFunc) RangeRead(ctx context.Context, httpRange http_range.Ran return f(ctx, httpRange) } +// LinkRefresher 接口用于在读取数据失败时强制刷新链接 +type LinkRefresher interface { + // ForceRefreshLink 强制刷新下载链接 + // 返回 true 表示刷新成功,false 表示无法刷新 + ForceRefreshLink(ctx context.Context) bool +} + // IsLinkExpiredError checks if the error indicates an expired download link func IsLinkExpiredError(err error) bool { if err == nil { @@ -123,14 +130,40 @@ func (r *RefreshableRangeReader) refreshAndRetry(ctx context.Context, httpRange r.mu.Lock() defer r.mu.Unlock() + if err := r.doRefreshLocked(ctx); err != nil { + return nil, err + } + + reader, err := r.getInnerReader() + if err != nil { + return nil, err + } + return reader.RangeRead(ctx, httpRange) +} + +// ForceRefresh 强制刷新链接,用于读取数据失败(如读取 0 字节)的情况 +// 返回 true 表示刷新成功,false 表示无法刷新(没有 Refresher 或达到最大刷新次数) +func (r *RefreshableRangeReader) ForceRefresh(ctx context.Context) bool { + if r.link.Refresher == nil { + return false + } + + r.mu.Lock() + defer r.mu.Unlock() + + return r.doRefreshLocked(ctx) == nil +} + +// doRefreshLocked 执行实际的刷新逻辑(需要持有锁) +func (r *RefreshableRangeReader) doRefreshLocked(ctx context.Context) error { if r.refreshCount >= 3 { - return nil, fmt.Errorf("max refresh attempts reached") + return fmt.Errorf("max refresh attempts reached") } log.Infof("Link expired, attempting to refresh...") newLink, _, refreshErr := r.link.Refresher(ctx) if refreshErr != nil { - return nil, fmt.Errorf("failed to refresh link: %w", refreshErr) + return fmt.Errorf("failed to refresh link: %w", refreshErr) } newLink.Refresher = r.link.Refresher @@ -138,13 +171,8 @@ func (r *RefreshableRangeReader) refreshAndRetry(ctx context.Context, httpRange r.innerReader = nil r.refreshCount++ - log.Infof("Link refreshed successfully, retrying request...") - - reader, err := r.getInnerReader() - if err != nil { - return nil, err - } - return reader.RangeRead(ctx, httpRange) + log.Infof("Link refreshed successfully") + return nil } func GetRangeReaderFromLink(size int64, link 
*model.Link) (model.RangeReaderIF, error) { @@ -299,13 +327,14 @@ func CacheFullAndHash(stream model.FileStreamer, up *model.UpdateProgress, hashT return tmpF, hex.EncodeToString(h.Sum(nil)), nil } -// readFullWithRangeRead 使用 RangeRead 从文件流中读取数据到 buf +// ReadFullWithRangeRead 使用 RangeRead 从文件流中读取数据到 buf // file: 文件流 // buf: 目标缓冲区 // off: 读取的起始偏移量 // 返回值: 实际读取的字节数和错误 // 支持自动重试(最多3次),每次重试之间有递增延迟(3秒、6秒、9秒) -func readFullWithRangeRead(file model.FileStreamer, buf []byte, off int64) (int, error) { +// 支持链接刷新:当检测到 0 字节读取时,会自动刷新下载链接 +func ReadFullWithRangeRead(file model.FileStreamer, buf []byte, off int64) (int, error) { length := int64(len(buf)) var lastErr error @@ -331,6 +360,28 @@ func readFullWithRangeRead(file model.FileStreamer, buf []byte, off int64) (int, lastErr = fmt.Errorf("failed to read all data via RangeRead at offset %d: (expect=%d, actual=%d) %w", off, length, n, err) log.Debugf("RangeRead retry %d read failed: %v", retry+1, lastErr) + + // 检测是否可能是链接过期(读取 0 字节或 EOF) + if n == 0 && (err == io.EOF || err == io.ErrUnexpectedEOF) { + // 尝试刷新链接 + if refresher, ok := file.(LinkRefresher); ok { + // 获取 context - 从 FileStream 或 SeekableStream 中获取 + var ctx context.Context + if fs, ok := file.(*FileStream); ok { + ctx = fs.Ctx + } else if ss, ok := file.(*SeekableStream); ok { + ctx = ss.Ctx + } else { + ctx = context.Background() + } + + if refresher.ForceRefreshLink(ctx) { + log.Infof("Link refreshed after 0-byte read, retrying immediately...") + continue // 立即重试,不延迟 + } + } + } + // 递增延迟:3秒、6秒、9秒,等待网络恢复 time.Sleep(time.Duration(retry+1) * 3 * time.Second) } @@ -376,14 +427,14 @@ func StreamHashFile(file model.FileStreamer, hashType *utils.HashType, progressW // 对于 SeekableStream,优先使用 RangeRead 避免消耗 Reader // 这样后续发送时 Reader 还能正常工作 if _, ok := file.(*SeekableStream); ok { - n, err = readFullWithRangeRead(file, buf[:readSize], offset) + n, err = ReadFullWithRangeRead(file, buf[:readSize], offset) } else { // 对于 FileStream,首先尝试顺序流读取(不消耗额外资源,适用于所有流类型) n, err = io.ReadFull(file, buf[:readSize]) if err != nil { // 顺序流读取失败,尝试使用 RangeRead 重试(适用于 SeekableStream) log.Warnf("StreamHashFile: sequential read failed at offset %d, retrying with RangeRead: %v", offset, err) - n, err = readFullWithRangeRead(file, buf[:readSize], offset) + n, err = ReadFullWithRangeRead(file, buf[:readSize], offset) } } @@ -557,7 +608,7 @@ func (ss *directSectionReader) GetSectionReader(off, length int64) (io.ReadSeeke // 对于 SeekableStream,直接使用 RangeRead(支持随机访问,适用于续传场景) if _, ok := ss.file.(*SeekableStream); ok { - n, err := readFullWithRangeRead(ss.file, buf, off) + n, err := ReadFullWithRangeRead(ss.file, buf, off) if err != nil { ss.bufPool.Put(tempBuf) return nil, fmt.Errorf("RangeRead failed at offset %d: (expect=%d, actual=%d) %w", off, length, n, err) From 8a8a708eb8facedc8a27a92bd8b7407cc4f409e8 Mon Sep 17 00:00:00 2001 From: cyk Date: Tue, 6 Jan 2026 02:43:16 +0800 Subject: [PATCH 19/20] =?UTF-8?q?feat(network):=20=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E5=AF=B9=E6=85=A2=E9=80=9F=E7=BD=91=E7=BB=9C=E7=9A=84=E6=94=AF?= =?UTF-8?q?=E6=8C=81=EF=BC=8C=E8=B0=83=E6=95=B4=E8=B6=85=E6=97=B6=E5=92=8C?= =?UTF-8?q?=E9=87=8D=E8=AF=95=E6=9C=BA=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 ++- drivers/115_open/driver.go | 31 +++++++++++++++++++++++++++++++ drivers/baidu_netdisk/meta.go | 4 ++-- internal/net/serve.go | 11 ++++++++++- internal/stream/util.go | 26 +++++++++++++++++--------- 5 files changed, 62 insertions(+), 13 deletions(-) diff --git a/.gitignore 
b/.gitignore index 1d71f0d60..add6d56bb 100644 --- a/.gitignore +++ b/.gitignore @@ -31,4 +31,5 @@ output/ /public/dist/* /!public/dist/README.md -.VSCodeCounter \ No newline at end of file +.VSCodeCounter +nul diff --git a/drivers/115_open/driver.go b/drivers/115_open/driver.go index dbccb23bc..52bde0cd7 100644 --- a/drivers/115_open/driver.go +++ b/drivers/115_open/driver.go @@ -17,6 +17,7 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/stream" "github.com/OpenListTeam/OpenList/v4/pkg/http_range" "github.com/OpenListTeam/OpenList/v4/pkg/utils" + log "github.com/sirupsen/logrus" "golang.org/x/time/rate" ) @@ -74,13 +75,20 @@ func (d *Open115) Drop(ctx context.Context) error { } func (d *Open115) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + start := time.Now() + log.Infof("[115] List request started for dir: %s (ID: %s)", dir.GetName(), dir.GetID()) + var res []model.Obj pageSize := int64(d.PageSize) offset := int64(0) + pageCount := 0 + for { if err := d.WaitLimit(ctx); err != nil { return nil, err } + + pageStart := time.Now() resp, err := d.client.GetFiles(ctx, &sdk.GetFilesReq{ CID: dir.GetID(), Limit: pageSize, @@ -90,7 +98,12 @@ func (d *Open115) List(ctx context.Context, dir model.Obj, args model.ListArgs) // Cur: 1, ShowDir: true, }) + pageDuration := time.Since(pageStart) + pageCount++ + log.Infof("[115] GetFiles page %d took: %v (offset=%d, limit=%d)", pageCount, pageDuration, offset, pageSize) + if err != nil { + log.Errorf("[115] GetFiles page %d failed after %v: %v", pageCount, pageDuration, err) return nil, err } res = append(res, utils.MustSliceConvert(resp.Data, func(src sdk.GetFilesResp_File) model.Obj { @@ -102,10 +115,17 @@ func (d *Open115) List(ctx context.Context, dir model.Obj, args model.ListArgs) } offset += pageSize } + + totalDuration := time.Since(start) + log.Infof("[115] List request completed in %v (%d pages, %d files)", totalDuration, pageCount, len(res)) + return res, nil } func (d *Open115) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + start := time.Now() + log.Infof("[115] Link request started for file: %s", file.GetName()) + if err := d.WaitLimit(ctx); err != nil { return nil, err } @@ -121,14 +141,25 @@ func (d *Open115) Link(ctx context.Context, file model.Obj, args model.LinkArgs) return nil, fmt.Errorf("can't convert obj") } pc := obj.Pc + + apiStart := time.Now() + log.Infof("[115] Calling DownURL API...") resp, err := d.client.DownURL(ctx, pc, ua) + apiDuration := time.Since(apiStart) + log.Infof("[115] DownURL API took: %v", apiDuration) + if err != nil { + log.Errorf("[115] DownURL API failed after %v: %v", apiDuration, err) return nil, err } u, ok := resp[obj.GetID()] if !ok { return nil, fmt.Errorf("can't get link") } + + totalDuration := time.Since(start) + log.Infof("[115] Link request completed in %v (API: %v)", totalDuration, apiDuration) + return &model.Link{ URL: u.URL.URL, Header: http.Header{ diff --git a/drivers/baidu_netdisk/meta.go b/drivers/baidu_netdisk/meta.go index f75f1c774..62ad89d85 100644 --- a/drivers/baidu_netdisk/meta.go +++ b/drivers/baidu_netdisk/meta.go @@ -30,8 +30,8 @@ type Addition struct { const ( UPLOAD_FALLBACK_API = "https://d.pcs.baidu.com" // 备用上传地址 UPLOAD_URL_EXPIRE_TIME = time.Minute * 60 // 上传地址有效期(分钟) - DEFAULT_UPLOAD_SLICE_TIMEOUT = time.Second * 60 // 上传分片请求默认超时时间 - UPLOAD_RETRY_COUNT = 3 + DEFAULT_UPLOAD_SLICE_TIMEOUT = time.Second * 180 // 上传分片请求默认超时时间(增加到3分钟以应对慢速网络) + UPLOAD_RETRY_COUNT = 5 // 增加重试次数以提高成功率 
UPLOAD_RETRY_WAIT_TIME = time.Second * 1 UPLOAD_RETRY_MAX_WAIT_TIME = time.Second * 5 ) diff --git a/internal/net/serve.go b/internal/net/serve.go index 6a20460b1..ee288b86a 100644 --- a/internal/net/serve.go +++ b/internal/net/serve.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "mime/multipart" + stdnet "net" // 标准库net包,用于Dialer "net/http" "strconv" "strings" @@ -286,12 +287,20 @@ func NewHttpClient() *http.Client { transport := &http.Transport{ Proxy: http.ProxyFromEnvironment, TLSClientConfig: &tls.Config{InsecureSkipVerify: conf.Conf.TlsInsecureSkipVerify}, + // 快速连接超时:10秒建立连接,失败快速重试 + DialContext: (&stdnet.Dialer{ + Timeout: 10 * time.Second, // TCP握手超时 + KeepAlive: 30 * time.Second, // TCP keep-alive + }).DialContext, + // 响应头超时:15秒等待服务器响应头(平衡API调用与下载检测) + ResponseHeaderTimeout: 15 * time.Second, + // 允许长时间读取数据(无 IdleConnTimeout 限制) } SetProxyIfConfigured(transport) return &http.Client{ - Timeout: time.Hour * 48, + Timeout: time.Hour * 48, // 总超时保持48小时(允许大文件慢速下载) Transport: transport, } } diff --git a/internal/stream/util.go b/internal/stream/util.go index 00c3bde52..b24ad2417 100644 --- a/internal/stream/util.go +++ b/internal/stream/util.go @@ -24,6 +24,14 @@ import ( log "github.com/sirupsen/logrus" ) +const ( + // 链接刷新相关常量 + MAX_LINK_REFRESH_COUNT = 50 // 下载链接最大刷新次数(支持长时间传输) + + // RangeRead 重试相关常量 + MAX_RANGE_READ_RETRY_COUNT = 5 // RangeRead 最大重试次数(从3增加到5) +) + type RangeReaderFunc func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) func (f RangeReaderFunc) RangeRead(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { @@ -156,8 +164,8 @@ func (r *RefreshableRangeReader) ForceRefresh(ctx context.Context) bool { // doRefreshLocked 执行实际的刷新逻辑(需要持有锁) func (r *RefreshableRangeReader) doRefreshLocked(ctx context.Context) error { - if r.refreshCount >= 3 { - return fmt.Errorf("max refresh attempts reached") + if r.refreshCount >= MAX_LINK_REFRESH_COUNT { + return fmt.Errorf("max refresh attempts (%d) reached", MAX_LINK_REFRESH_COUNT) } log.Infof("Link expired, attempting to refresh...") @@ -332,20 +340,20 @@ func CacheFullAndHash(stream model.FileStreamer, up *model.UpdateProgress, hashT // buf: 目标缓冲区 // off: 读取的起始偏移量 // 返回值: 实际读取的字节数和错误 -// 支持自动重试(最多3次),每次重试之间有递增延迟(3秒、6秒、9秒) +// 支持自动重试(最多5次),快速重试策略(1秒、2秒、3秒、4秒、5秒) // 支持链接刷新:当检测到 0 字节读取时,会自动刷新下载链接 func ReadFullWithRangeRead(file model.FileStreamer, buf []byte, off int64) (int, error) { length := int64(len(buf)) var lastErr error - // 重试最多3次 - for retry := 0; retry < 3; retry++ { + // 重试最多 MAX_RANGE_READ_RETRY_COUNT 次 + for retry := 0; retry < MAX_RANGE_READ_RETRY_COUNT; retry++ { reader, err := file.RangeRead(http_range.Range{Start: off, Length: length}) if err != nil { lastErr = fmt.Errorf("RangeRead failed at offset %d: %w", off, err) log.Debugf("RangeRead retry %d failed: %v", retry+1, lastErr) - // 递增延迟:3秒、6秒、9秒,等待代理恢复 - time.Sleep(time.Duration(retry+1) * 3 * time.Second) + // 快速重试:1秒、2秒、3秒、4秒、5秒(连接失败快速重试) + time.Sleep(time.Duration(retry+1) * time.Second) continue } @@ -382,8 +390,8 @@ func ReadFullWithRangeRead(file model.FileStreamer, buf []byte, off int64) (int, } } - // 递增延迟:3秒、6秒、9秒,等待网络恢复 - time.Sleep(time.Duration(retry+1) * 3 * time.Second) + // 快速重试:1秒、2秒、3秒、4秒、5秒(读取失败快速重试) + time.Sleep(time.Duration(retry+1) * time.Second) } return 0, lastErr From 44df52acb33b5998dbc9af815f6c9f394742196c Mon Sep 17 00:00:00 2001 From: cyk Date: Wed, 7 Jan 2026 21:36:59 +0800 Subject: [PATCH 20/20] =?UTF-8?q?feat(upload):=20=E5=A2=9E=E5=BC=BA?= 
=?UTF-8?q?=E5=88=86=E7=89=87=E4=B8=8A=E4=BC=A0=E6=94=AF=E6=8C=81=EF=BC=8C?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E8=B6=85=E6=97=B6=E5=92=8CETag=E6=8F=90?= =?UTF-8?q?=E5=8F=96=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- drivers/quark_open/driver.go | 131 +++++++++++++++++++++++++++-------- drivers/quark_open/meta.go | 4 +- drivers/quark_open/util.go | 94 ++++++++++++++++++++++--- 3 files changed, 189 insertions(+), 40 deletions(-) diff --git a/drivers/quark_open/driver.go b/drivers/quark_open/driver.go index f0b8baf09..cf1ff3cb0 100644 --- a/drivers/quark_open/driver.go +++ b/drivers/quark_open/driver.go @@ -8,6 +8,7 @@ import ( "hash" "io" "net/http" + "strings" "time" "github.com/OpenListTeam/OpenList/v4/drivers/base" @@ -18,6 +19,7 @@ import ( "github.com/OpenListTeam/OpenList/v4/pkg/utils" "github.com/avast/retry-go" "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" ) type QuarkOpen struct { @@ -144,30 +146,84 @@ func (d *QuarkOpen) Remove(ctx context.Context, obj model.Obj) error { func (d *QuarkOpen) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { md5Str, sha1Str := stream.GetHash().GetHash(utils.MD5), stream.GetHash().GetHash(utils.SHA1) - var ( - md5 hash.Hash - sha1 hash.Hash - ) - writers := []io.Writer{} - if len(md5Str) != utils.MD5.Width { - md5 = utils.MD5.NewFunc() - writers = append(writers, md5) - } - if len(sha1Str) != utils.SHA1.Width { - sha1 = utils.SHA1.NewFunc() - writers = append(writers, sha1) - } - if len(writers) > 0 { - _, err := stream.CacheFullAndWriter(&up, io.MultiWriter(writers...)) - if err != nil { - return err - } - if md5 != nil { - md5Str = hex.EncodeToString(md5.Sum(nil)) - } - if sha1 != nil { - sha1Str = hex.EncodeToString(sha1.Sum(nil)) + // 检查是否需要计算hash + needMD5 := len(md5Str) != utils.MD5.Width + needSHA1 := len(sha1Str) != utils.SHA1.Width + + if needMD5 || needSHA1 { + // 检查是否为可重复读取的流 + _, isSeekable := stream.(*streamPkg.SeekableStream) + + if isSeekable { + // 可重复读取的流,使用 RangeRead 一次性计算所有hash,避免重复读取 + var md5 hash.Hash + var sha1 hash.Hash + writers := []io.Writer{} + + if needMD5 { + md5 = utils.MD5.NewFunc() + writers = append(writers, md5) + } + if needSHA1 { + sha1 = utils.SHA1.NewFunc() + writers = append(writers, sha1) + } + + // 使用 RangeRead 分块读取文件,同时计算多个hash + multiWriter := io.MultiWriter(writers...) 
+ size := stream.GetSize() + chunkSize := int64(10 * utils.MB) // 10MB per chunk + buf := make([]byte, chunkSize) + var offset int64 = 0 + + for offset < size { + readSize := min(chunkSize, size-offset) + + n, err := streamPkg.ReadFullWithRangeRead(stream, buf[:readSize], offset) + if err != nil { + return fmt.Errorf("calculate hash failed at offset %d: %w", offset, err) + } + + multiWriter.Write(buf[:n]) + offset += int64(n) + + // 更新进度(hash计算占用40%的进度) + up(40 * float64(offset) / float64(size)) + } + + if md5 != nil { + md5Str = hex.EncodeToString(md5.Sum(nil)) + } + if sha1 != nil { + sha1Str = hex.EncodeToString(sha1.Sum(nil)) + } + } else { + // 不可重复读取的流(如网络流),需要缓存并计算hash + var md5 hash.Hash + var sha1 hash.Hash + writers := []io.Writer{} + + if needMD5 { + md5 = utils.MD5.NewFunc() + writers = append(writers, md5) + } + if needSHA1 { + sha1 = utils.SHA1.NewFunc() + writers = append(writers, sha1) + } + + _, err := stream.CacheFullAndWriter(&up, io.MultiWriter(writers...)) + if err != nil { + return err + } + + if md5 != nil { + md5Str = hex.EncodeToString(md5.Sum(nil)) + } + if sha1 != nil { + sha1Str = hex.EncodeToString(sha1.Sum(nil)) + } } } // pre @@ -210,24 +266,43 @@ func (d *QuarkOpen) Put(ctx context.Context, dstDir model.Obj, stream model.File if err != nil { return err } + + // 上传重试逻辑,包含URL刷新 + var etag string err = retry.Do(func() error { rd.Seek(0, io.SeekStart) - etag, err := d.upPart(ctx, upUrlInfo, i, driver.NewLimitedUploadStream(ctx, rd)) - if err != nil { - return err + var uploadErr error + etag, uploadErr = d.upPart(ctx, upUrlInfo, i, driver.NewLimitedUploadStream(ctx, rd)) + + // 检查是否为URL过期错误 + if uploadErr != nil && strings.Contains(uploadErr.Error(), "expire") { + log.Warnf("[quark_open] Upload URL expired for part %d, refreshing...", i) + // 刷新上传URL + newUpUrlInfo, refreshErr := d.upUrl(ctx, pre, partInfo) + if refreshErr != nil { + return fmt.Errorf("failed to refresh upload url: %w", refreshErr) + } + upUrlInfo = newUpUrlInfo + log.Infof("[quark_open] Upload URL refreshed successfully") + + // 使用新URL重试上传 + rd.Seek(0, io.SeekStart) + etag, uploadErr = d.upPart(ctx, upUrlInfo, i, driver.NewLimitedUploadStream(ctx, rd)) } - etags = append(etags, etag) - return nil + + return uploadErr }, retry.Context(ctx), retry.Attempts(3), retry.DelayType(retry.BackOffDelay), retry.Delay(time.Second)) + ss.FreeSectionReader(rd) if err != nil { return fmt.Errorf("failed to upload part %d: %w", i, err) } + etags = append(etags, etag) up(95 * float64(offset+size) / float64(total)) } diff --git a/drivers/quark_open/meta.go b/drivers/quark_open/meta.go index 3527b52e9..ee1903939 100644 --- a/drivers/quark_open/meta.go +++ b/drivers/quark_open/meta.go @@ -13,8 +13,8 @@ type Addition struct { APIAddress string `json:"api_url_address" default:"https://api.oplist.org/quarkyun/renewapi"` AccessToken string `json:"access_token" required:"false" default:""` RefreshToken string `json:"refresh_token" required:"true"` - AppID string `json:"app_id" required:"true" help:"Keep it empty if you don't have one"` - SignKey string `json:"sign_key" required:"true" help:"Keep it empty if you don't have one"` + AppID string `json:"app_id" required:"false" default:"" help:"Optional - Auto-filled from online API, or use your own"` + SignKey string `json:"sign_key" required:"false" default:"" help:"Optional - Auto-filled from online API, or use your own"` } type Conf struct { diff --git a/drivers/quark_open/util.go b/drivers/quark_open/util.go index 788ca0e99..1a3058375 100644 --- 
+++ b/drivers/quark_open/util.go
@@ -20,6 +20,7 @@ import (
     "github.com/OpenListTeam/OpenList/v4/drivers/base"
     "github.com/OpenListTeam/OpenList/v4/internal/model"
     "github.com/OpenListTeam/OpenList/v4/internal/op"
+    "github.com/OpenListTeam/OpenList/v4/pkg/utils"
     "github.com/go-resty/resty/v2"
     log "github.com/sirupsen/logrus"
 )
@@ -283,8 +284,15 @@ func (d *QuarkOpen) getProofRange(proofSeed string, fileSize int64) (*ProofRange
 
 func (d *QuarkOpen) _getPartInfo(stream model.FileStreamer, partSize int64) []base.Json {
     // 计算分片信息
-    partInfo := make([]base.Json, 0)
     total := stream.GetSize()
+
+    // Keep partSize sane: at least 4MB, to avoid creating too many parts
+    const minPartSize int64 = 4 * utils.MB
+    if partSize < minPartSize {
+        partSize = minPartSize
+    }
+
+    partInfo := make([]base.Json, 0)
     left := total
     partNumber := 1
 
@@ -304,6 +312,7 @@ func (d *QuarkOpen) _getPartInfo(stream model.FileStreamer, partSize int64) []ba
         partNumber++
     }
 
+    log.Infof("[quark_open] Upload plan: file_size=%d, part_size=%d, part_count=%d", total, partSize, len(partInfo))
     return partInfo
 }
 
@@ -315,11 +324,17 @@ func (d *QuarkOpen) upUrl(ctx context.Context, pre UpPreResp, partInfo []base.Js
     }
 
     var resp UpUrlResp
+    log.Infof("[quark_open] Requesting upload URLs for %d parts (task_id: %s)", len(partInfo), pre.Data.TaskID)
+
     _, err = d.request(ctx, "/open/v1/file/get_upload_urls", http.MethodPost, func(req *resty.Request) {
         req.SetBody(data)
     }, &resp)
 
     if err != nil {
+        // If the part count exceeded the limit, log the details
+        if strings.Contains(err.Error(), "part list exceed") {
+            log.Errorf("[quark_open] Part list exceeded limit! Requested %d parts. Please check Quark API documentation for actual limit.", len(partInfo))
+        }
         return upUrlInfo, err
     }
 
@@ -340,13 +355,43 @@ func (d *QuarkOpen) upPart(ctx context.Context, upUrlInfo UpUrlInfo, partNumber
     req.Header.Set("Accept-Encoding", "gzip")
     req.Header.Set("User-Agent", "Go-http-client/1.1")
 
+    // ✅ Key fix: use a longer timeout (10 minutes)
+    // Uploading a large part over a slow network can take a long time
+    client := &http.Client{
+        Timeout:   10 * time.Minute,
+        Transport: base.HttpClient.Transport,
+    }
+
     // 发送请求
-    resp, err := base.HttpClient.Do(req)
+    resp, err := client.Do(req)
     if err != nil {
         return "", err
     }
     defer resp.Body.Close()
 
+    // Check for an expired-URL error (status codes such as 403 or 410)
+    if resp.StatusCode == 403 || resp.StatusCode == 410 {
+        body, _ := io.ReadAll(resp.Body)
+        return "", fmt.Errorf("upload url expired (status: %d): %s", resp.StatusCode, string(body))
+    }
+
+    // ✅ Key fix: a 409 PartAlreadyExist response is not an error!
+    // Quark uses sequential upload mode; if the part already exists on a timeout retry, the first attempt actually succeeded
+    if resp.StatusCode == 409 {
+        body, _ := io.ReadAll(resp.Body)
+        // Extract the ETag of the already-uploaded part from the response body
+        if strings.Contains(string(body), "PartAlreadyExist") {
+            // Try to extract the ETag from the XML response
+            if etag := extractEtagFromXML(string(body)); etag != "" {
+                log.Infof("[quark_open] Part %d already exists (409), using existing ETag: %s", partNumber+1, etag)
+                return etag, nil
+            }
+            // If the ETag cannot be extracted, return an error
+            log.Warnf("[quark_open] Part %d already exists but cannot extract ETag from response: %s", partNumber+1, string(body))
+            return "", fmt.Errorf("part already exists but ETag not found in response")
+        }
+    }
+
     if resp.StatusCode != 200 {
         body, _ := io.ReadAll(resp.Body)
         return "", fmt.Errorf("up status: %d, error: %s", resp.StatusCode, string(body))
@@ -355,6 +400,23 @@ func (d *QuarkOpen) upPart(ctx context.Context, upUrlInfo UpUrlInfo, partNumber
     return resp.Header.Get("Etag"), nil
 }
 
+// extractEtagFromXML extracts the ETag from an OSS XML error response
+// Example: <ETag>"2F796AC486BB2891E3237D8BFDE020B5"</ETag>
+func extractEtagFromXML(xmlBody string) string {
+    start := strings.Index(xmlBody, "<ETag>")
+    if start == -1 {
+        return ""
+    }
+    start += len("<ETag>")
+    end := strings.Index(xmlBody[start:], "</ETag>")
+    if end == -1 {
+        return ""
+    }
+    etag := xmlBody[start : start+end]
+    // Strip the surrounding quotes
+    return strings.Trim(etag, "\"")
+}
+
 func (d *QuarkOpen) upFinish(ctx context.Context, pre UpPreResp, partInfo []base.Json, etags []string) error {
     // 创建 part_info_list
     partInfoList := make([]base.Json, len(partInfo))
@@ -417,25 +479,36 @@ func (d *QuarkOpen) generateReqSign(method string, pathname string, signKey stri
 }
 
 func (d *QuarkOpen) refreshToken() error {
-    refresh, access, err := d._refreshToken()
+    refresh, access, appID, signKey, err := d._refreshToken()
     for i := 0; i < 3; i++ {
         if err == nil {
             break
         } else {
            log.Errorf("[quark_open] failed to refresh token: %s", err)
         }
-        refresh, access, err = d._refreshToken()
+        refresh, access, appID, signKey, err = d._refreshToken()
     }
     if err != nil {
         return err
     }
     log.Infof("[quark_open] token exchange: %s -> %s", d.RefreshToken, refresh)
     d.RefreshToken, d.AccessToken = refresh, access
+
+    // If the online API returned an AppID and SignKey, persist them (only update when non-empty)
+    if appID != "" && appID != d.AppID {
+        d.AppID = appID
+        log.Infof("[quark_open] AppID updated from online API: %s", appID)
+    }
+    if signKey != "" && signKey != d.SignKey {
+        d.SignKey = signKey
+        log.Infof("[quark_open] SignKey updated from online API")
+    }
+
     op.MustSaveDriverStorage(d)
     return nil
 }
 
-func (d *QuarkOpen) _refreshToken() (string, string, error) {
+func (d *QuarkOpen) _refreshToken() (string, string, string, string, error) {
     if d.UseOnlineAPI && d.APIAddress != "" {
         u := d.APIAddress
         var resp RefreshTokenOnlineAPIResp
@@ -448,19 +521,20 @@ func (d *QuarkOpen) _refreshToken() (string, string, error) {
             }).
             Get(u)
         if err != nil {
-            return "", "", err
+            return "", "", "", "", err
         }
         if resp.RefreshToken == "" || resp.AccessToken == "" {
             if resp.ErrorMessage != "" {
-                return "", "", fmt.Errorf("failed to refresh token: %s", resp.ErrorMessage)
+                return "", "", "", "", fmt.Errorf("failed to refresh token: %s", resp.ErrorMessage)
             }
-            return "", "", fmt.Errorf("empty token returned from official API, a wrong refresh token may have been used")
+            return "", "", "", "", fmt.Errorf("empty token returned from official API, a wrong refresh token may have been used")
         }
-        return resp.RefreshToken, resp.AccessToken, nil
+        // Return every field, including the AppID and SignKey
+        return resp.RefreshToken, resp.AccessToken, resp.AppID, resp.SignKey, nil
     }
 
     // TODO 本地刷新逻辑
-    return "", "", fmt.Errorf("local refresh token logic is not implemented yet, please use online API or contact the developer")
+    return "", "", "", "", fmt.Errorf("local refresh token logic is not implemented yet, please use online API or contact the developer")
 }
 
 // 生成认证 Cookie
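
Note for reviewers: below is a minimal, self-contained sketch of the 409 PartAlreadyExist handling introduced in this patch. extractEtagFromXML is the helper added above; the sample response body and the main wrapper are illustrative assumptions only (they assume the OSS error document carries an <ETag> element, as the patch expects) and are not a captured response.

package main

import (
    "fmt"
    "strings"
)

// extractEtagFromXML mirrors the helper added in this patch: it pulls the ETag
// value out of an OSS-style XML error body and strips the surrounding quotes.
func extractEtagFromXML(xmlBody string) string {
    start := strings.Index(xmlBody, "<ETag>")
    if start == -1 {
        return ""
    }
    start += len("<ETag>")
    end := strings.Index(xmlBody[start:], "</ETag>")
    if end == -1 {
        return ""
    }
    return strings.Trim(xmlBody[start:start+end], "\"")
}

func main() {
    // Illustrative 409 body only; the real OSS response may contain more fields.
    body := `<?xml version="1.0" encoding="UTF-8"?>
<Error>
  <Code>PartAlreadyExist</Code>
  <ETag>"2F796AC486BB2891E3237D8BFDE020B5"</ETag>
</Error>`

    if strings.Contains(body, "PartAlreadyExist") {
        fmt.Println(extractEtagFromXML(body)) // prints 2F796AC486BB2891E3237D8BFDE020B5
    }
}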