package api import ( "context" "crypto/rsa" "encoding/base64" "encoding/json" "fmt" "log" "math/big" "net/http" "strings" "sync" "time" "github.com/golang-jwt/jwt/v5" "workorders/internal/model" ) // jwksCache caches the public keys from Keycloak type jwksCache struct { mu sync.RWMutex keys map[string]*rsa.PublicKey fetchAt time.Time url string } var cache = &jwksCache{} func InitJWKS(url string) { cache.url = url if err := cache.refresh(); err != nil { log.Printf("JWKS initial fetch warning: %v (will retry per-request)", err) } } func (c *jwksCache) refresh() error { resp, err := http.Get(c.url) //nolint:gosec if err != nil { return fmt.Errorf("fetch JWKS: %w", err) } defer resp.Body.Close() var jwks struct { Keys []struct { Kid string `json:"kid"` Kty string `json:"kty"` Alg string `json:"alg"` N string `json:"n"` E string `json:"e"` } `json:"keys"` } if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil { return fmt.Errorf("decode JWKS: %w", err) } keys := make(map[string]*rsa.PublicKey, len(jwks.Keys)) for _, k := range jwks.Keys { if k.Kty != "RSA" { continue } pub, err := rsaPublicKey(k.N, k.E) if err != nil { continue } keys[k.Kid] = pub } c.mu.Lock() c.keys = keys c.fetchAt = time.Now() c.mu.Unlock() return nil } func (c *jwksCache) get(kid string) (*rsa.PublicKey, error) { c.mu.RLock() key, ok := c.keys[kid] stale := time.Since(c.fetchAt) > 10*time.Minute c.mu.RUnlock() if ok && !stale { return key, nil } if err := c.refresh(); err != nil { if ok { return key, nil // use stale key if refresh fails } return nil, err } c.mu.RLock() key, ok = c.keys[kid] c.mu.RUnlock() if !ok { return nil, fmt.Errorf("key %q not found", kid) } return key, nil } func rsaPublicKey(nStr, eStr string) (*rsa.PublicKey, error) { nBytes, err := base64.RawURLEncoding.DecodeString(nStr) if err != nil { return nil, err } eBytes, err := base64.RawURLEncoding.DecodeString(eStr) if err != nil { return nil, err } n := new(big.Int).SetBytes(nBytes) e := new(big.Int).SetBytes(eBytes) return &rsa.PublicKey{N: n, E: int(e.Int64())}, nil } // OIDCAuth validates a Keycloak-issued JWT in the Authorization header. func OIDCAuth(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := r.Header.Get("Authorization") if !strings.HasPrefix(auth, "Bearer ") { jsonError(w, "unauthorized", http.StatusUnauthorized) return } raw := strings.TrimPrefix(auth, "Bearer ") token, err := jwt.Parse(raw, func(t *jwt.Token) (any, error) { if _, ok := t.Method.(*jwt.SigningMethodRSA); !ok { return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) } kid, _ := t.Header["kid"].(string) return cache.get(kid) }, jwt.WithExpirationRequired()) if err != nil || !token.Valid { jsonError(w, "invalid token", http.StatusUnauthorized) return } claims, ok := token.Claims.(jwt.MapClaims) if !ok { jsonError(w, "invalid claims", http.StatusUnauthorized) return } user := model.UserClaims{ Sub: stringClaim(claims, "sub"), Email: stringClaim(claims, "email"), Name: stringClaim(claims, "name"), } if ra, ok := claims["realm_access"].(map[string]any); ok { if roles, ok := ra["roles"].([]any); ok { for _, r := range roles { if s, ok := r.(string); ok { user.Roles = append(user.Roles, s) } } } } ctx := context.WithValue(r.Context(), model.CtxUserKey, user) next.ServeHTTP(w, r.WithContext(ctx)) }) } func CORS(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") if r.Method == http.MethodOptions { w.WriteHeader(http.StatusNoContent) return } next.ServeHTTP(w, r) }) } func stringClaim(c jwt.MapClaims, key string) string { v, _ := c[key].(string) return v } func jsonError(w http.ResponseWriter, msg string, code int) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(code) fmt.Fprintf(w, `{"error":%q}`, msg) } func UserFromCtx(r *http.Request) model.UserClaims { u, _ := r.Context().Value(model.CtxUserKey).(model.UserClaims) return u }