191 lines
4.4 KiB
Go
191 lines
4.4 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rsa"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"math/big"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"workorders/internal/model"
|
|
)
|
|
|
|
// jwksCache holds the RSA public keys most recently parsed from the JWKS
// document served at url, indexed by key ID.
type jwksCache struct {
	// mu guards keys and fetchAt.
	mu sync.RWMutex
	// keys maps a JWK "kid" to its decoded RSA public key.
	keys map[string]*rsa.PublicKey
	// fetchAt is the time of the last successful refresh.
	fetchAt time.Time
	// url is the JWKS endpoint that refresh downloads from.
	url string
}

// cache is the package-wide key cache configured by InitJWKS and consulted by
// OIDCAuth.
var cache = new(jwksCache)
|
|
|
|
func InitJWKS(url string) {
|
|
cache.url = url
|
|
if err := cache.refresh(); err != nil {
|
|
log.Printf("JWKS initial fetch warning: %v (will retry per-request)", err)
|
|
}
|
|
}
|
|
|
|
func (c *jwksCache) refresh() error {
|
|
resp, err := http.Get(c.url) //nolint:gosec
|
|
if err != nil {
|
|
return fmt.Errorf("fetch JWKS: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
var jwks struct {
|
|
Keys []struct {
|
|
Kid string `json:"kid"`
|
|
Kty string `json:"kty"`
|
|
Alg string `json:"alg"`
|
|
N string `json:"n"`
|
|
E string `json:"e"`
|
|
} `json:"keys"`
|
|
}
|
|
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
|
|
return fmt.Errorf("decode JWKS: %w", err)
|
|
}
|
|
|
|
keys := make(map[string]*rsa.PublicKey, len(jwks.Keys))
|
|
for _, k := range jwks.Keys {
|
|
if k.Kty != "RSA" {
|
|
continue
|
|
}
|
|
pub, err := rsaPublicKey(k.N, k.E)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
keys[k.Kid] = pub
|
|
}
|
|
|
|
c.mu.Lock()
|
|
c.keys = keys
|
|
c.fetchAt = time.Now()
|
|
c.mu.Unlock()
|
|
return nil
|
|
}
|
|
|
|
func (c *jwksCache) get(kid string) (*rsa.PublicKey, error) {
|
|
c.mu.RLock()
|
|
key, ok := c.keys[kid]
|
|
stale := time.Since(c.fetchAt) > 10*time.Minute
|
|
c.mu.RUnlock()
|
|
|
|
if ok && !stale {
|
|
return key, nil
|
|
}
|
|
if err := c.refresh(); err != nil {
|
|
if ok {
|
|
return key, nil // use stale key if refresh fails
|
|
}
|
|
return nil, err
|
|
}
|
|
c.mu.RLock()
|
|
key, ok = c.keys[kid]
|
|
c.mu.RUnlock()
|
|
if !ok {
|
|
return nil, fmt.Errorf("key %q not found", kid)
|
|
}
|
|
return key, nil
|
|
}
|
|
|
|
// rsaPublicKey builds an RSA public key from the "n" (modulus) and "e"
// (public exponent) members of a JWK, both given as unpadded base64url
// big-endian integers.
//
// It returns an error when either value is not valid base64url, when the
// modulus is empty or zero, or when the exponent is outside [2, 2^31-1].
// Without the range check an oversized exponent would be silently truncated
// by the integer conversion and an empty one would yield E == 0.
func rsaPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
	// maxExponent is the largest exponent accepted; it fits in an int on
	// every platform.
	const maxExponent = 1<<31 - 1

	nBytes, err := base64.RawURLEncoding.DecodeString(nStr)
	if err != nil {
		return nil, fmt.Errorf("decode modulus: %w", err)
	}
	eBytes, err := base64.RawURLEncoding.DecodeString(eStr)
	if err != nil {
		return nil, fmt.Errorf("decode exponent: %w", err)
	}

	n := new(big.Int).SetBytes(nBytes)
	if n.Sign() <= 0 {
		return nil, fmt.Errorf("invalid modulus: must be positive")
	}

	e := new(big.Int).SetBytes(eBytes)
	// IsInt64 must be checked first: Int64 is undefined for larger values.
	if !e.IsInt64() || e.Int64() < 2 || e.Int64() > maxExponent {
		return nil, fmt.Errorf("invalid exponent: must be in [2, %d]", maxExponent)
	}

	return &rsa.PublicKey{N: n, E: int(e.Int64())}, nil
}
|
|
|
|
// OIDCAuth validates a Keycloak-issued JWT in the Authorization header.
|
|
func OIDCAuth(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
auth := r.Header.Get("Authorization")
|
|
if !strings.HasPrefix(auth, "Bearer ") {
|
|
jsonError(w, "unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
raw := strings.TrimPrefix(auth, "Bearer ")
|
|
|
|
token, err := jwt.Parse(raw, func(t *jwt.Token) (any, error) {
|
|
if _, ok := t.Method.(*jwt.SigningMethodRSA); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
|
}
|
|
kid, _ := t.Header["kid"].(string)
|
|
return cache.get(kid)
|
|
}, jwt.WithExpirationRequired())
|
|
|
|
if err != nil || !token.Valid {
|
|
jsonError(w, "invalid token", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
claims, ok := token.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
jsonError(w, "invalid claims", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
user := model.UserClaims{
|
|
Sub: stringClaim(claims, "sub"),
|
|
Email: stringClaim(claims, "email"),
|
|
Name: stringClaim(claims, "name"),
|
|
}
|
|
if ra, ok := claims["realm_access"].(map[string]any); ok {
|
|
if roles, ok := ra["roles"].([]any); ok {
|
|
for _, r := range roles {
|
|
if s, ok := r.(string); ok {
|
|
user.Roles = append(user.Roles, s)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), model.CtxUserKey, user)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
// CORS sets permissive cross-origin headers on every response. OPTIONS
// requests are answered directly with 204 No Content; every other request is
// forwarded to next.
func CORS(next http.Handler) http.Handler {
	handle := func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		hdr.Set("Access-Control-Allow-Origin", "*")
		hdr.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		hdr.Set("Access-Control-Allow-Headers", "Authorization, Content-Type")

		switch r.Method {
		case http.MethodOptions:
			// Preflight: the headers above are the whole answer.
			w.WriteHeader(http.StatusNoContent)
		default:
			next.ServeHTTP(w, r)
		}
	}
	return http.HandlerFunc(handle)
}
|
|
|
|
func stringClaim(c jwt.MapClaims, key string) string {
|
|
v, _ := c[key].(string)
|
|
return v
|
|
}
|
|
|
|
// jsonError writes the status code and a JSON body of the form
// {"error":"<msg>"} with Content-Type application/json.
//
// msg is encoded with encoding/json rather than fmt's %q: %q produces Go
// string syntax, whose escapes such as \x01 or \a are not valid JSON.
func jsonError(w http.ResponseWriter, msg string, code int) {
	encoded, err := json.Marshal(msg)
	if err != nil {
		// Not expected for a string; fall back to a fixed, valid payload.
		encoded = []byte(`"internal error"`)
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	// A failed write means the client is gone; nothing useful can be done.
	_, _ = fmt.Fprintf(w, `{"error":%s}`, encoded)
}
|
|
|
|
func UserFromCtx(r *http.Request) model.UserClaims {
|
|
u, _ := r.Context().Value(model.CtxUserKey).(model.UserClaims)
|
|
return u
|
|
}
|