mirror of
https://github.com/go-vikunja/vikunja.git
synced 2026-04-25 14:45:20 +00:00
268 lines
6.8 KiB
Go
268 lines
6.8 KiB
Go
// Vikunja is a to-do list application to facilitate your life.
|
|
// Copyright 2018-present Vikunja and contributors. All rights reserved.
|
|
//
|
|
// This program is free software: you can redistribute it and/or modify
|
|
// it under the terms of the GNU Affero General Public License as published by
|
|
// the Free Software Foundation, either version 3 of the License, or
|
|
// (at your option) any later version.
|
|
//
|
|
// This program is distributed in the hope that it will be useful,
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
// GNU Affero General Public License for more details.
|
|
//
|
|
// You should have received a copy of the GNU Affero General Public License
|
|
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
package openid
|
|
|
|
import (
|
|
"fmt"
|
|
"strconv"
|
|
|
|
"code.vikunja.io/api/pkg/config"
|
|
"code.vikunja.io/api/pkg/log"
|
|
"code.vikunja.io/api/pkg/modules/keyvalue"
|
|
|
|
"github.com/coreos/go-oidc/v3/oidc"
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
// GetAllProviders returns all configured providers
|
|
func GetAllProviders() (providers []*Provider, err error) {
|
|
if !config.AuthOpenIDEnabled.GetBool() {
|
|
return nil, nil
|
|
}
|
|
|
|
providers = []*Provider{}
|
|
exists, err := keyvalue.GetWithValue("openid_providers", &providers)
|
|
if !exists {
|
|
rawProviders := config.AuthOpenIDProviders.Get()
|
|
if rawProviders == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
rawProvider, is := rawProviders.(map[string]interface{})
|
|
if !is {
|
|
// Try to convert from map[interface{}]interface{} (YAML format)
|
|
if rawProviderInterface, ok := rawProviders.(map[interface{}]interface{}); ok {
|
|
rawProvider = make(map[string]interface{}, len(rawProviderInterface))
|
|
for k, v := range rawProviderInterface {
|
|
if key, keyOK := k.(string); keyOK {
|
|
rawProvider[key] = v
|
|
}
|
|
}
|
|
} else {
|
|
log.Criticalf("It looks like your openid configuration is in the wrong format. Please check the docs for the correct format.")
|
|
return
|
|
}
|
|
}
|
|
|
|
for key, p := range rawProvider {
|
|
var pi map[string]interface{}
|
|
var is bool
|
|
pi, is = p.(map[string]interface{})
|
|
// JSON config is a map[string]interface{}, other providers are not. Under the hood they are all strings so
|
|
// it is safe to cast.
|
|
if !is {
|
|
if pis, pisOK := p.(map[interface{}]interface{}); pisOK {
|
|
pi = make(map[string]interface{}, len(pis))
|
|
for i, s := range pis {
|
|
if key, keyOK := i.(string); keyOK {
|
|
pi[key] = s
|
|
}
|
|
}
|
|
} else {
|
|
log.Errorf("Provider %s has invalid configuration format, skipping", key)
|
|
continue
|
|
}
|
|
}
|
|
|
|
provider, err := getProviderFromMap(pi, key)
|
|
|
|
if err != nil {
|
|
log.Errorf("Error while getting openid provider %s: %s", key, err)
|
|
continue
|
|
}
|
|
|
|
if provider == nil {
|
|
log.Errorf("Could not openid provider %s, please check your config", key)
|
|
continue
|
|
}
|
|
|
|
providers = append(providers, provider)
|
|
|
|
err = keyvalue.Put("openid_provider_"+key, provider)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
err = keyvalue.Put("openid_providers", providers)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// GetProvider retrieves a provider from keyvalue
|
|
func GetProvider(key string) (provider *Provider, err error) {
|
|
provider = &Provider{}
|
|
exists, err := keyvalue.GetWithValue("openid_provider_"+key, provider)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if !exists {
|
|
_, err = GetAllProviders() // This will put all providers in cache
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
_, err = keyvalue.GetWithValue("openid_provider_"+key, provider)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
err = provider.setOicdProvider()
|
|
return
|
|
}
|
|
|
|
func getProviderFromMap(pi map[string]interface{}, key string) (provider *Provider, err error) {
|
|
|
|
requiredKeys := []string{
|
|
"name",
|
|
"authurl",
|
|
"clientsecret",
|
|
"clientid",
|
|
}
|
|
|
|
allKeys := append(
|
|
[]string{
|
|
"logouturl",
|
|
"scope",
|
|
"emailfallback",
|
|
"usernamefallback",
|
|
"forceuserinfo",
|
|
"requireavailability",
|
|
},
|
|
requiredKeys...,
|
|
)
|
|
|
|
for _, configKey := range allKeys {
|
|
valueFromFile := config.GetConfigValueFromFile("auth.openid.providers." + key + "." + configKey)
|
|
if valueFromFile != "" {
|
|
pi[configKey] = valueFromFile
|
|
}
|
|
}
|
|
|
|
for _, key := range requiredKeys {
|
|
if _, exists := pi[key]; !exists {
|
|
return nil, fmt.Errorf("required key '%s' is missing in the provider configuration", key)
|
|
}
|
|
}
|
|
|
|
name, is := pi["name"].(string)
|
|
if !is {
|
|
return nil, nil
|
|
}
|
|
|
|
var logoutURL string
|
|
logoutValue, exists := pi["logouturl"]
|
|
if exists {
|
|
url, ok := logoutValue.(string)
|
|
if ok {
|
|
logoutURL = url
|
|
}
|
|
}
|
|
|
|
var scope string
|
|
if scopeValue, exists := pi["scope"]; exists {
|
|
scope = scopeValue.(string)
|
|
}
|
|
if scope == "" {
|
|
scope = "openid profile email"
|
|
}
|
|
|
|
var emailFallback = false
|
|
emailFallbackValue, exists := pi["emailfallback"]
|
|
if exists {
|
|
emailFallbackTypedValue, ok := emailFallbackValue.(bool)
|
|
if ok {
|
|
emailFallback = emailFallbackTypedValue
|
|
}
|
|
}
|
|
var usernameFallback = false
|
|
usernameFallbackValue, exists := pi["usernamefallback"]
|
|
if exists {
|
|
usernameFallbackTypedValue, ok := usernameFallbackValue.(bool)
|
|
if ok {
|
|
usernameFallback = usernameFallbackTypedValue
|
|
}
|
|
}
|
|
|
|
var forceUserInfo = false
|
|
forceUserInfoValue, exists := pi["forceuserinfo"]
|
|
if exists {
|
|
forceUserInfoTypedValue, ok := forceUserInfoValue.(bool)
|
|
if ok {
|
|
forceUserInfo = forceUserInfoTypedValue
|
|
} else {
|
|
log.Errorf("forceuserinfo is not a boolean for provider %s, value: %v", key, forceUserInfoValue)
|
|
}
|
|
}
|
|
|
|
var requireAvailability = false
|
|
requireAvailabilityValue, exists := pi["requireavailability"]
|
|
if exists {
|
|
requireAvailabilityTypedValue, ok := requireAvailabilityValue.(bool)
|
|
if ok {
|
|
requireAvailability = requireAvailabilityTypedValue
|
|
} else {
|
|
log.Errorf("requireavailability is not a boolean for provider %s, value: %v", key, requireAvailabilityValue)
|
|
}
|
|
}
|
|
|
|
provider = &Provider{
|
|
Name: name,
|
|
Key: key,
|
|
AuthURL: pi["authurl"].(string),
|
|
OriginalAuthURL: pi["authurl"].(string),
|
|
ClientSecret: pi["clientsecret"].(string),
|
|
LogoutURL: logoutURL,
|
|
Scope: scope,
|
|
EmailFallback: emailFallback,
|
|
UsernameFallback: usernameFallback,
|
|
ForceUserInfo: forceUserInfo,
|
|
RequireAvailability: requireAvailability,
|
|
}
|
|
|
|
cl, is := pi["clientid"].(int)
|
|
if is {
|
|
provider.ClientID = strconv.Itoa(cl)
|
|
} else {
|
|
provider.ClientID = pi["clientid"].(string)
|
|
}
|
|
|
|
err = provider.setOicdProvider()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
provider.Oauth2Config = &oauth2.Config{
|
|
ClientID: provider.ClientID,
|
|
ClientSecret: provider.ClientSecret,
|
|
// Discovery returns the OAuth2 endpoints.
|
|
Endpoint: provider.openIDProvider.Endpoint(),
|
|
|
|
// "openid" is a required scope for OpenID Connect flows.
|
|
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
|
}
|
|
|
|
provider.AuthURL = provider.Oauth2Config.Endpoint.AuthURL
|
|
|
|
return
|
|
}
|
|
|
|
func CleanupSavedOpenIDProviders() {
|
|
_ = keyvalue.Del("openid_providers")
|
|
}
|