refactor: use helper function to check user local

This commit is contained in:
kolaente
2025-09-04 17:27:54 +02:00
parent b8afdcf62d
commit 1b5a9dbdea
3 changed files with 34 additions and 52 deletions

View File

@@ -30,8 +30,27 @@ import (
"code.vikunja.io/api/pkg/web/handler"
"github.com/labstack/echo/v4"
"xorm.io/xorm"
)
// getLocalUserFromContext is a helper function to get the current local user and database session
func getLocalUserFromContext(c echo.Context) (*user.User, *xorm.Session, error) {
s := db.NewSession()
u, err := user.GetCurrentUserFromDB(s, c)
if err != nil {
s.Close()
return nil, nil, err
}
if !u.IsLocalUser() {
s.Close()
return nil, nil, &user.ErrAccountIsNotLocal{UserID: u.ID}
}
return u, s, nil
}
// UserTOTPEnroll is the handler to enroll a user into totp
// @Summary Enroll a user into totp
// @Description Creates an initial setup for the user in the db. After this step, the user needs to verify they have a working totp setup with the "enable totp" endpoint.
@@ -45,18 +64,11 @@ import (
// @Failure 500 {object} models.Message "Internal server error."
// @Router /user/settings/totp/enroll [post]
func UserTOTPEnroll(c echo.Context) error {
s := db.NewSession()
defer s.Close()
u, err := user.GetCurrentUserFromDB(s, c)
u, s, err := getLocalUserFromContext(c)
if err != nil {
return handler.HandleHTTPError(err)
}
// Check if the user is a local user
if !u.IsLocalUser() {
return handler.HandleHTTPError(&user.ErrAccountIsNotLocal{UserID: u.ID})
}
defer s.Close()
t, err := user.EnrollTOTP(s, u)
if err != nil {
@@ -87,18 +99,11 @@ func UserTOTPEnroll(c echo.Context) error {
// @Failure 500 {object} models.Message "Internal server error."
// @Router /user/settings/totp/enable [post]
func UserTOTPEnable(c echo.Context) error {
s := db.NewSession()
defer s.Close()
u, err := user.GetCurrentUserFromDB(s, c)
u, s, err := getLocalUserFromContext(c)
if err != nil {
return handler.HandleHTTPError(err)
}
// Check if the user is a local user
if !u.IsLocalUser() {
return handler.HandleHTTPError(&user.ErrAccountIsNotLocal{UserID: u.ID})
}
defer s.Close()
passcode := &user.TOTPPasscode{
User: u,
@@ -150,23 +155,12 @@ func UserTOTPDisable(c echo.Context) error {
return echo.NewHTTPError(http.StatusBadRequest, "Invalid model provided.").SetInternal(err)
}
s := db.NewSession()
u, s, err := getLocalUserFromContext(c)
if err != nil {
return handler.HandleHTTPError(err)
}
defer s.Close()
u, err := user.GetCurrentUserFromDB(s, c)
if err != nil {
return handler.HandleHTTPError(err)
}
// Check if the user is a local user
if !u.IsLocalUser() {
return handler.HandleHTTPError(&user.ErrAccountIsNotLocal{UserID: u.ID})
}
if err != nil {
_ = s.Rollback()
return handler.HandleHTTPError(err)
}
err = user.CheckUserPassword(u, login.Password)
if err != nil {
_ = s.Rollback()
@@ -198,18 +192,11 @@ func UserTOTPDisable(c echo.Context) error {
// @Failure 500 {object} models.Message "Internal server error."
// @Router /user/settings/totp/qrcode [get]
func UserTOTPQrCode(c echo.Context) error {
s := db.NewSession()
defer s.Close()
u, err := user.GetCurrentUserFromDB(s, c)
u, s, err := getLocalUserFromContext(c)
if err != nil {
return handler.HandleHTTPError(err)
}
// Check if the user is a local user
if !u.IsLocalUser() {
return handler.HandleHTTPError(&user.ErrAccountIsNotLocal{UserID: u.ID})
}
defer s.Close()
qrcode, err := user.GetTOTPQrCodeForUser(s, u)
if err != nil {
@@ -243,18 +230,11 @@ func UserTOTPQrCode(c echo.Context) error {
// @Failure 500 {object} models.Message "Internal server error."
// @Router /user/settings/totp [get]
func UserTOTP(c echo.Context) error {
s := db.NewSession()
defer s.Close()
u, err := user.GetCurrentUserFromDB(s, c)
u, s, err := getLocalUserFromContext(c)
if err != nil {
return handler.HandleHTTPError(err)
}
// Check if the user is a local user
if !u.IsLocalUser() {
return handler.HandleHTTPError(&user.ErrAccountIsNotLocal{UserID: u.ID})
}
defer s.Close()
t, err := user.GetTOTPForUser(s, u)
if err != nil {

View File

@@ -52,12 +52,14 @@ var (
Username: "user1",
Password: "$2a$14$dcadBoMBL9jQoOcZK8Fju.cy0Ptx2oZECkKLnaa8ekRoTFe1w7To.",
Email: "user1@example.com",
Issuer: "local",
}
testuser15 = user.User{
ID: 15,
Username: "user15",
Password: "$2a$14$dcadBoMBL9jQoOcZK8Fju.cy0Ptx2oZECkKLnaa8ekRoTFe1w7To.",
Email: "user15@example.com",
Issuer: "local",
}
)

View File

@@ -50,4 +50,4 @@ func TestUserTOTPLocalUser(t *testing.T) {
assert.Contains(t, rec.Body.String(), `"secret"`)
assert.Contains(t, rec.Body.String(), `"enabled":false`)
})
}
}