diff --git a/pkg/config/config.go b/pkg/config/config.go index e4ece5694..8fd624ec6 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -31,6 +31,7 @@ import ( "code.vikunja.io/api/pkg/log" + "github.com/c2h5oh/datasize" "github.com/spf13/viper" ) @@ -213,6 +214,8 @@ const ( PluginsDir Key = `plugins.dir` ) +var maxFileSizeInBytes uint64 + // GetString returns a string config value func (k Key) GetString() string { return viper.GetString(string(k)) @@ -622,6 +625,11 @@ func InitConfig() { publicURL := strings.TrimSuffix(ServicePublicURL.GetString(), "/") CorsOrigins.Set(append(CorsOrigins.GetStringSlice(), publicURL)) + + err = SetMaxFileSizeMBytesFromString(FilesMaxSize.GetString()) + if err != nil { + log.Fatalf("Could not parse files.maxsize: %s", err) + } } func random(length int) (string, error) { @@ -632,3 +640,21 @@ func random(length int) (string, error) { return fmt.Sprintf("%X", b), nil } + +func SetMaxFileSizeMBytesFromString(size string) error { + var maxSize datasize.ByteSize + err := maxSize.UnmarshalText([]byte(size)) + if err != nil { + return err + } + + maxFileSizeInBytes = uint64(maxSize.MBytes()) + return nil +} + +func GetMaxFileSizeInMBytes() uint64 { + if maxFileSizeInBytes == 0 { + return 20 + } + return maxFileSizeInBytes +} diff --git a/pkg/db/db.go b/pkg/db/db.go index 0aefac4bf..40a908bf5 100644 --- a/pkg/db/db.go +++ b/pkg/db/db.go @@ -241,12 +241,14 @@ func Type() schemas.DBType { } func GetDialect() string { - dialect := config.DatabaseType.GetString() - if dialect == "sqlite" { - dialect = builder.SQLITE + switch config.DatabaseType.GetString() { + case "mysql": + return builder.MYSQL + case "postgres": + return builder.POSTGRES + default: + return builder.SQLITE } - - return dialect } func checkParadeDB(engine *xorm.Engine) { diff --git a/pkg/db/test_fixtures.go b/pkg/db/test_fixtures.go index 870b0193e..a18d51df9 100644 --- a/pkg/db/test_fixtures.go +++ b/pkg/db/test_fixtures.go @@ -49,13 +49,13 @@ func InitFixtures(tablenames ...string) (err error) { loaderOptions := []func(loader *testfixtures.Loader) error{ testfixtures.Database(x.DB().DB), - testfixtures.Dialect(config.DatabaseType.GetString()), + testfixtures.Dialect(GetDialect()), testfixtures.DangerousSkipTestDatabaseCheck(), testfixtures.Location(config.GetTimeZone()), testfiles, } - if config.DatabaseType.GetString() == "postgres" { + if GetDialect() == "postgres" { loaderOptions = append(loaderOptions, testfixtures.SkipResetSequences()) } diff --git a/pkg/files/error.go b/pkg/files/error.go index 2523bc261..23a84c5be 100644 --- a/pkg/files/error.go +++ b/pkg/files/error.go @@ -16,7 +16,12 @@ package files -import "fmt" +import ( + "fmt" + "net/http" + + "code.vikunja.io/api/pkg/web" +) // ErrFileDoesNotExist defines an error where a file does not exist in the db type ErrFileDoesNotExist struct { @@ -50,6 +55,18 @@ func IsErrFileIsTooLarge(err error) bool { return ok } +// ErrCodeFileIsTooLarge holds the unique world-error code of this error +const ErrCodeFileIsTooLarge = 4013 + +// HTTPError holds the http error description +func (err ErrFileIsTooLarge) HTTPError() web.HTTPError { + return web.HTTPError{ + HTTPCode: http.StatusRequestEntityTooLarge, + Code: ErrCodeFileIsTooLarge, + Message: "The uploaded file exceeds the maximum configured file size", + } +} + // ErrFileIsNotUnsplashFile defines an error where a file is not downloaded from unsplash. // Used in cases whenever unsplash information about a file is requested, but the file was not downloaded from unsplash. type ErrFileIsNotUnsplashFile struct { diff --git a/pkg/files/filehandling.go b/pkg/files/filehandling.go index d520f1d33..73a057d28 100644 --- a/pkg/files/filehandling.go +++ b/pkg/files/filehandling.go @@ -63,6 +63,8 @@ func initFixtures(t *testing.T) { db.LoadAndAssertFixtures(t) // File fixtures InitTestFileFixtures(t) + err := config.SetMaxFileSizeMBytesFromString("20MB") + require.NoError(t, err) } // InitTestFileFixtures initializes file fixtures diff --git a/pkg/files/files.go b/pkg/files/files.go index e56d78393..ef814ed3e 100644 --- a/pkg/files/files.go +++ b/pkg/files/files.go @@ -98,13 +98,7 @@ func CreateWithMime(f io.Reader, realname string, realsize uint64, a web.Auth, m } func CreateWithMimeAndSession(s *xorm.Session, f io.Reader, realname string, realsize uint64, a web.Auth, mime string, checkFileSizeLimit bool) (file *File, err error) { - // Get and parse the configured file size - var maxSize datasize.ByteSize - err = maxSize.UnmarshalText([]byte(config.FilesMaxSize.GetString())) - if err != nil { - return nil, err - } - if realsize > maxSize.Bytes() && checkFileSizeLimit { + if realsize > config.GetMaxFileSizeInMBytes()*uint64(datasize.MB) && checkFileSizeLimit { return nil, ErrFileIsTooLarge{Size: realsize} } diff --git a/pkg/routes/routes.go b/pkg/routes/routes.go index 65b0505a1..831367b90 100644 --- a/pkg/routes/routes.go +++ b/pkg/routes/routes.go @@ -53,12 +53,15 @@ package routes import ( "errors" + "fmt" "log/slog" + "net/http" "net/url" "strings" "time" "code.vikunja.io/api/pkg/config" + "code.vikunja.io/api/pkg/files" "code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/models" "code.vikunja.io/api/pkg/modules/auth/openid" @@ -135,6 +138,22 @@ func NewEcho() *echo.Echo { // Validation e.Validator = &CustomValidator{} + // Set body limit to allow file uploads up to the configured size + // Add some overhead for multipart form data (headers, boundaries, etc.) + e.Use(middleware.BodyLimit(fmt.Sprintf("%dM", config.GetMaxFileSizeInMBytes()+2))) + + // Set up custom error handler for body limit exceeded when Sentry is not enabled + if !config.SentryEnabled.GetBool() { + e.HTTPErrorHandler = func(err error, c echo.Context) { + // Convert HTTP 413 errors to custom ErrFileIsTooLarge error + var herr *echo.HTTPError + if errors.As(err, &herr) && herr.Code == http.StatusRequestEntityTooLarge { + err = handler.HandleHTTPError(files.ErrFileIsTooLarge{}) + } + e.DefaultHTTPErrorHandler(err, c) + } + } + return e } @@ -157,8 +176,13 @@ func setupSentry(e *echo.Echo) { })) e.HTTPErrorHandler = func(err error, c echo.Context) { - // Only capture errors not already handled by echo + // Convert HTTP 413 errors to custom ErrFileIsTooLarge error var herr *echo.HTTPError + if errors.As(err, &herr) && herr.Code == http.StatusRequestEntityTooLarge { + err = handler.HandleHTTPError(files.ErrFileIsTooLarge{}) + } + + // Only capture errors not already handled by echo if errors.As(err, &herr) && herr.Code > 499 { var errToReport = err if herr.Internal == nil { @@ -177,6 +201,7 @@ func setupSentry(e *echo.Echo) { } log.Debugf("Error '%s' sent to sentry", err.Error()) } + e.DefaultHTTPErrorHandler(err, c) } } diff --git a/pkg/routes/validation.go b/pkg/routes/validation.go index 54c8ae936..64a8235ca 100644 --- a/pkg/routes/validation.go +++ b/pkg/routes/validation.go @@ -17,9 +17,7 @@ package routes import ( - "strings" - - "code.vikunja.io/api/pkg/config" + "code.vikunja.io/api/pkg/db" "code.vikunja.io/api/pkg/models" "github.com/asaskevich/govalidator" @@ -36,13 +34,13 @@ func init() { // Custom validator for database TEXT fields that adapts to the database being used govalidator.TagMap["dbtext"] = func(str string) bool { // Get the current database dialect - dialect := strings.ToLower(config.DatabaseType.GetString()) + dialect := db.GetDialect() // Default limit for MySQL and unknown databases (65KB safely under TEXT limit) maxLength := 65000 // For databases that support larger text fields - if dialect == "postgres" || dialect == "sqlite" || dialect == "sqlite3" { + if dialect == "postgres" || dialect == "sqlite3" { maxLength = 1048576 // ~1MB limit for PostgreSQL and SQLite } diff --git a/pkg/webtests/task_attachment_upload_test.go b/pkg/webtests/task_attachment_upload_test.go new file mode 100644 index 000000000..12d2edabd --- /dev/null +++ b/pkg/webtests/task_attachment_upload_test.go @@ -0,0 +1,118 @@ +// Vikunja is a to-do list application to facilitate your life. +// Copyright 2018-present Vikunja and contributors. All rights reserved. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package webtests + +import ( + "bytes" + "io" + "mime/multipart" + "net/http" + "net/http/httptest" + "testing" + + "code.vikunja.io/api/pkg/config" + "code.vikunja.io/api/pkg/modules/auth" + "code.vikunja.io/api/pkg/routes" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTaskAttachmentUploadSize(t *testing.T) { + tests := []struct { + name string + fileSize int64 + expectedStatus int + configMaxSize string + }{ + { + name: "Upload file within 32MB boundary", + fileSize: 30 * 1024 * 1024, // 30MB + expectedStatus: http.StatusOK, + configMaxSize: "50MB", + }, + { + name: "Upload file above old 32MB limit", + fileSize: 35 * 1024 * 1024, // 35MB + expectedStatus: http.StatusOK, + configMaxSize: "50MB", + }, + { + name: "Upload file exceeding configured limit", + fileSize: 55 * 1024 * 1024, // 55MB + expectedStatus: http.StatusRequestEntityTooLarge, + configMaxSize: "50MB", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup test environment first (this calls InitDefaultConfig) + _, err := setupTestEnv() + require.NoError(t, err) + + // Now set the config AFTER setupTestEnv + oldMaxSize := config.FilesMaxSize.GetString() + config.FilesMaxSize.Set(tt.configMaxSize) + defer config.FilesMaxSize.Set(oldMaxSize) + + // Re-initialize config to update maxFileSizeInBytes + config.InitConfig() + + // Create Echo instance with the updated config + e := routes.NewEcho() + routes.RegisterRoutes(e) + + // Create multipart form data + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + part, err := writer.CreateFormFile("files", "test.pdf") + require.NoError(t, err) + + // Write dummy data of specified size + _, err = io.CopyN(part, bytes.NewReader(make([]byte, tt.fileSize)), tt.fileSize) + require.NoError(t, err) + + err = writer.Close() + require.NoError(t, err) + + // Create request + req := httptest.NewRequest(http.MethodPut, "/api/v1/tasks/1/attachments", body) + req.Header.Set(echo.HeaderContentType, writer.FormDataContentType()) + + // Add JWT token to request header for authentication + token, err := auth.NewUserJWTAuthtoken(&testuser1, false) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+token) + + rec := httptest.NewRecorder() + + // Execute request + e.ServeHTTP(rec, req) + + // Verify status code + assert.Equal(t, tt.expectedStatus, rec.Code) + + // If we expect an error, verify the error response includes code and message + if tt.expectedStatus == http.StatusRequestEntityTooLarge { + assert.Contains(t, rec.Body.String(), "4013") // Error code + assert.Contains(t, rec.Body.String(), "uploaded file exceeds") + } + }) + } +}