Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions cli/context/store/io_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ import (
"io"
)

// LimitedReader is a fork of io.LimitedReader to override Read.
type LimitedReader struct {
// limitedReader is a fork of [io.LimitedReader] to override Read.
type limitedReader struct {
R io.Reader
N int64 // max bytes remaining
}

// Read is a fork of io.LimitedReader.Read that returns an error when limit exceeded.
func (l *LimitedReader) Read(p []byte) (n int, err error) {
// Read is a fork of [io.LimitedReader.Read] that returns an error when limit exceeded.
func (l *limitedReader) Read(p []byte) (n int, err error) {
if l.N < 0 {
return 0, errors.New("read exceeds the defined limit")
}
Expand Down
4 changes: 2 additions & 2 deletions cli/context/store/io_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ func TestLimitReaderReadAll(t *testing.T) {
assert.NilError(t, err)

r = strings.NewReader("Test")
_, err = io.ReadAll(&LimitedReader{R: r, N: 4})
_, err = io.ReadAll(&limitedReader{R: r, N: 4})
assert.NilError(t, err)

r = strings.NewReader("Test")
_, err = io.ReadAll(&LimitedReader{R: r, N: 2})
_, err = io.ReadAll(&limitedReader{R: r, N: 2})
assert.Error(t, err, "read exceeds the defined limit")
}
6 changes: 3 additions & 3 deletions cli/context/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ func isValidFilePath(p string) error {
}

func importTar(name string, s Writer, reader io.Reader) error {
tr := tar.NewReader(&LimitedReader{R: reader, N: maxAllowedFileSizeToImport})
tr := tar.NewReader(&limitedReader{R: reader, N: maxAllowedFileSizeToImport})
tlsData := ContextTLSData{
Endpoints: map[string]EndpointTLSData{},
}
Expand Down Expand Up @@ -406,7 +406,7 @@ func importTar(name string, s Writer, reader io.Reader) error {
}

func importZip(name string, s Writer, reader io.Reader) error {
body, err := io.ReadAll(&LimitedReader{R: reader, N: maxAllowedFileSizeToImport})
body, err := io.ReadAll(&limitedReader{R: reader, N: maxAllowedFileSizeToImport})
if err != nil {
return err
}
Expand Down Expand Up @@ -434,7 +434,7 @@ func importZip(name string, s Writer, reader io.Reader) error {
return err
}

data, err := io.ReadAll(&LimitedReader{R: f, N: maxAllowedFileSizeToImport})
data, err := io.ReadAll(&limitedReader{R: f, N: maxAllowedFileSizeToImport})
defer f.Close()
if err != nil {
return err
Expand Down
Loading