diff --git a/pkg/doctor/services.go b/pkg/doctor/services.go index 463cc4ab2..2689337b8 100644 --- a/pkg/doctor/services.go +++ b/pkg/doctor/services.go @@ -18,13 +18,16 @@ package doctor import ( "context" + "encoding/json" "fmt" + "io" "net" "net/http" "time" "code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/modules/auth/ldap" + "code.vikunja.io/api/pkg/modules/auth/openid" "code.vikunja.io/api/pkg/red" ) @@ -243,9 +246,22 @@ func checkOpenID() CheckGroup { } var results []CheckResult + providerIssuers := make(map[string]string) for key, p := range providerMap { - result := checkOpenIDProvider(key, p) + result, issuer := checkOpenIDProvider(key, p) results = append(results, result) + if issuer != "" { + providerIssuers[key] = issuer + } + } + + // Check for duplicate issuers among successful providers + for _, dup := range openid.FindDuplicateIssuers(providerIssuers) { + results = append(results, CheckResult{ + Name: "Duplicate Issuer", + Passed: false, + Error: dup.Error(), + }) } return CheckGroup{ @@ -254,7 +270,7 @@ func checkOpenID() CheckGroup { } } -func checkOpenIDProvider(key string, rawProvider interface{}) CheckResult { +func checkOpenIDProvider(key string, rawProvider interface{}) (CheckResult, string) { // Extract provider config var pi map[string]interface{} switch p := rawProvider.(type) { @@ -272,7 +288,7 @@ func checkOpenIDProvider(key string, rawProvider interface{}) CheckResult { Name: fmt.Sprintf("Provider: %s", key), Passed: false, Error: "invalid configuration format", - } + }, "" } // Get provider name @@ -288,7 +304,7 @@ func checkOpenIDProvider(key string, rawProvider interface{}) CheckResult { Name: fmt.Sprintf("Provider: %s", name), Passed: false, Error: "authurl not configured", - } + }, "" } // Check if the provider's discovery endpoint is reachable @@ -308,7 +324,7 @@ func checkOpenIDProvider(key string, rawProvider interface{}) CheckResult { Name: fmt.Sprintf("Provider: %s", name), Passed: false, Error: err.Error(), - } + }, "" } client := &http.Client{Timeout: 5 * time.Second} @@ -318,7 +334,7 @@ func checkOpenIDProvider(key string, rawProvider interface{}) CheckResult { Name: fmt.Sprintf("Provider: %s", name), Passed: false, Error: err.Error(), - } + }, "" } defer resp.Body.Close() @@ -327,6 +343,18 @@ func checkOpenIDProvider(key string, rawProvider interface{}) CheckResult { Name: fmt.Sprintf("Provider: %s", name), Passed: false, Error: fmt.Sprintf("discovery endpoint returned status %d", resp.StatusCode), + }, "" + } + + // Parse the issuer from the discovery response for duplicate detection + var issuer string + body, err := io.ReadAll(resp.Body) + if err == nil { + var discovery struct { + Issuer string `json:"issuer"` + } + if json.Unmarshal(body, &discovery) == nil { + issuer = discovery.Issuer } } @@ -334,5 +362,5 @@ func checkOpenIDProvider(key string, rawProvider interface{}) CheckResult { Name: fmt.Sprintf("Provider: %s", name), Passed: true, Value: "OK", - } + }, issuer } diff --git a/pkg/initialize/init.go b/pkg/initialize/init.go index 7c70609ce..7145b06d8 100644 --- a/pkg/initialize/init.go +++ b/pkg/initialize/init.go @@ -101,6 +101,9 @@ func FullInitWithoutAsync() { // Check all OpenID Connect providers at startup _, err = openid.GetAllProviders() if err != nil { + if openid.IsErrDuplicateOIDCIssuer(err) { + log.Fatalf("OpenID Connect configuration error: %s", err) + } log.Errorf("Error initializing OpenID Connect providers: %s", err) } diff --git a/pkg/modules/auth/openid/providers.go b/pkg/modules/auth/openid/providers.go index 5638d0316..aaafab82f 100644 --- a/pkg/modules/auth/openid/providers.go +++ b/pkg/modules/auth/openid/providers.go @@ -17,6 +17,7 @@ package openid import ( + "errors" "fmt" "strconv" @@ -28,6 +29,44 @@ import ( "golang.org/x/oauth2" ) +// ErrDuplicateOIDCIssuer is returned when two configured providers resolve to the same issuer URL. +type ErrDuplicateOIDCIssuer struct { + Issuer string + Provider1 string + Provider2 string +} + +func (e *ErrDuplicateOIDCIssuer) Error() string { + return fmt.Sprintf( + "duplicate OpenID Connect issuer %q: providers %q and %q resolve to the same issuer, which will cause team sync conflicts", + e.Issuer, e.Provider1, e.Provider2, + ) +} + +// IsErrDuplicateOIDCIssuer checks if an error is a duplicate issuer error. +func IsErrDuplicateOIDCIssuer(err error) bool { + var target *ErrDuplicateOIDCIssuer + return errors.As(err, &target) +} + +// FindDuplicateIssuers checks a map of provider key → issuer URL for duplicates. +// It returns a list of all duplicate pairs found. +func FindDuplicateIssuers(providerIssuers map[string]string) []ErrDuplicateOIDCIssuer { + issuerToKey := make(map[string]string) + var duplicates []ErrDuplicateOIDCIssuer + for key, issuer := range providerIssuers { + if existingKey, exists := issuerToKey[issuer]; exists { + duplicates = append(duplicates, ErrDuplicateOIDCIssuer{ + Issuer: issuer, + Provider1: existingKey, + Provider2: key, + }) + } + issuerToKey[issuer] = key + } + return duplicates +} + // GetAllProviders returns all configured providers func GetAllProviders() (providers []*Provider, err error) { if !config.AuthOpenIDEnabled.GetBool() { @@ -97,6 +136,21 @@ func GetAllProviders() (providers []*Provider, err error) { return nil, err } } + + // Check for duplicate issuers across providers + providerIssuers := make(map[string]string) + for _, p := range providers { + issuer, issuerErr := p.Issuer() + if issuerErr != nil { + log.Errorf("Error getting issuer for openid provider %s: %s", p.Key, issuerErr) + continue + } + providerIssuers[p.Key] = issuer + } + if duplicates := FindDuplicateIssuers(providerIssuers); len(duplicates) > 0 { + return nil, &duplicates[0] + } + err = keyvalue.Put("openid_providers", providers) } diff --git a/pkg/modules/auth/openid/providers_test.go b/pkg/modules/auth/openid/providers_test.go index 377ae548c..902f140c9 100644 --- a/pkg/modules/auth/openid/providers_test.go +++ b/pkg/modules/auth/openid/providers_test.go @@ -17,10 +17,16 @@ package openid import ( + "encoding/json" + "net/http" + "net/http/httptest" "testing" "code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/modules/keyvalue" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestGetAllProvidersTypeSafety(t *testing.T) { @@ -84,3 +90,119 @@ func TestGetAllProvidersTypeSafety(t *testing.T) { } }) } + +// newMockOIDCServer creates a test HTTP server that serves a valid OIDC discovery document. +// The issuer in the discovery document matches the server's URL. +func newMockOIDCServer() *httptest.Server { + var server *httptest.Server + mux := http.NewServeMux() + mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, _ *http.Request) { + discovery := map[string]interface{}{ + "issuer": server.URL, + "authorization_endpoint": server.URL + "/auth", + "token_endpoint": server.URL + "/token", + "jwks_uri": server.URL + "/jwks", + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(discovery) + }) + server = httptest.NewServer(mux) + return server +} + +func TestDuplicateIssuersDetected(t *testing.T) { + defer CleanupSavedOpenIDProviders() + + // Create a single mock server — both providers will use the same issuer + server := newMockOIDCServer() + defer server.Close() + + config.AuthOpenIDEnabled.Set(true) + config.AuthOpenIDProviders.Set(map[string]interface{}{ + "provider1": map[string]interface{}{ + "name": "Provider One", + "authurl": server.URL, + "clientid": "client1", + "clientsecret": "secret1", + }, + "provider2": map[string]interface{}{ + "name": "Provider Two", + "authurl": server.URL, + "clientid": "client2", + "clientsecret": "secret2", + }, + }) + _ = keyvalue.Del("openid_providers") + + providers, err := GetAllProviders() + require.Error(t, err) + assert.Nil(t, providers) + assert.True(t, IsErrDuplicateOIDCIssuer(err)) + + var dupErr *ErrDuplicateOIDCIssuer + require.ErrorAs(t, err, &dupErr) + assert.Equal(t, server.URL, dupErr.Issuer) +} + +func TestUniqueIssuersAllowed(t *testing.T) { + defer CleanupSavedOpenIDProviders() + + // Create two separate mock servers — different issuers + server1 := newMockOIDCServer() + defer server1.Close() + server2 := newMockOIDCServer() + defer server2.Close() + + config.AuthOpenIDEnabled.Set(true) + config.AuthOpenIDProviders.Set(map[string]interface{}{ + "provider1": map[string]interface{}{ + "name": "Provider One", + "authurl": server1.URL, + "clientid": "client1", + "clientsecret": "secret1", + }, + "provider2": map[string]interface{}{ + "name": "Provider Two", + "authurl": server2.URL, + "clientid": "client2", + "clientsecret": "secret2", + }, + }) + _ = keyvalue.Del("openid_providers") + + providers, err := GetAllProviders() + require.NoError(t, err) + assert.Len(t, providers, 2) +} + +func TestFailedDiscoverySkippedInIssuerCheck(t *testing.T) { + defer CleanupSavedOpenIDProviders() + + // One valid server, one unreachable + server := newMockOIDCServer() + defer server.Close() + + config.AuthOpenIDEnabled.Set(true) + config.AuthOpenIDProviders.Set(map[string]interface{}{ + "valid": map[string]interface{}{ + "name": "Valid Provider", + "authurl": server.URL, + "clientid": "client1", + "clientsecret": "secret1", + }, + "broken": map[string]interface{}{ + "name": "Broken Provider", + "authurl": "http://127.0.0.1:1", + "clientid": "client2", + "clientsecret": "secret2", + }, + }) + _ = keyvalue.Del("openid_providers") + + // The broken provider will fail discovery and be skipped. + // The valid provider should load successfully. + providers, err := GetAllProviders() + require.NoError(t, err) + assert.Len(t, providers, 1) + assert.Equal(t, "Valid Provider", providers[0].Name) +}