Files
workorders/internal/api/middleware.go
T

191 lines
4.4 KiB
Go

package api
import (
"context"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"fmt"
"log"
"math/big"
"net/http"
"strings"
"sync"
"time"
"github.com/golang-jwt/jwt/v5"
"workorders/internal/model"
)
// jwksCache holds the RSA public keys published by Keycloak's JWKS endpoint,
// indexed by key ID, together with the time they were last fetched.
type jwksCache struct {
	// mu guards keys and fetchAt.
	mu sync.RWMutex
	// keys maps a JWK "kid" to its parsed RSA public key.
	keys map[string]*rsa.PublicKey
	// fetchAt is the time of the last successful refresh; the zero value
	// means no refresh has succeeded yet.
	fetchAt time.Time
	// url is the JWKS endpoint the keys are downloaded from.
	url string
}

// cache is the package-wide key cache used by the auth middleware.
var cache = new(jwksCache)
// InitJWKS stores url as the JWKS endpoint of the package-level cache and
// performs a first fetch. A failed fetch is not fatal: it is logged and the
// function returns normally.
func InitJWKS(url string) {
	cache.url = url

	err := cache.refresh()
	if err == nil {
		return
	}
	log.Printf("JWKS initial fetch warning: %v (will retry per-request)", err)
}
// jwksHTTPClient is the client used to download the JWKS document. Its timeout
// keeps a slow or unreachable identity provider from blocking the caller
// indefinitely; http.Get uses http.DefaultClient, which has no timeout.
var jwksHTTPClient = &http.Client{Timeout: 10 * time.Second}

// refresh downloads the JWKS document from c.url and replaces the cached key
// set with the RSA keys found in it.
//
// Entries that are not of type "RSA", or whose modulus/exponent cannot be
// parsed, are skipped. The cache is replaced only when the endpoint answers
// with 200 OK and the document yields at least one usable key, so an error
// page or an empty document never wipes keys fetched earlier. The request is
// made without holding c.mu; the lock is taken only to swap in the new map and
// update c.fetchAt.
//
// It returns an error when the request fails, the status is not 200, the body
// is not valid JSON, or no usable RSA key is present.
func (c *jwksCache) refresh() error {
	resp, err := jwksHTTPClient.Get(c.url) //nolint:gosec
	if err != nil {
		return fmt.Errorf("fetch JWKS: %w", err)
	}
	defer resp.Body.Close()

	// A non-200 body may still be valid JSON (e.g. an error object) that
	// decodes into an empty key list, so reject it before decoding.
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("fetch JWKS: unexpected status %s", resp.Status)
	}

	var jwks struct {
		Keys []struct {
			Kid string `json:"kid"`
			Kty string `json:"kty"`
			Alg string `json:"alg"`
			N   string `json:"n"`
			E   string `json:"e"`
		} `json:"keys"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
		return fmt.Errorf("decode JWKS: %w", err)
	}

	keys := make(map[string]*rsa.PublicKey, len(jwks.Keys))
	for _, k := range jwks.Keys {
		if k.Kty != "RSA" {
			continue
		}
		pub, err := rsaPublicKey(k.N, k.E)
		if err != nil {
			// Best effort: one malformed entry must not discard the rest.
			continue
		}
		keys[k.Kid] = pub
	}
	if len(keys) == 0 {
		return fmt.Errorf("decode JWKS: no usable RSA keys in response")
	}

	c.mu.Lock()
	c.keys = keys
	c.fetchAt = time.Now()
	c.mu.Unlock()
	return nil
}
const (
	// jwksMaxAge is how long a fetched key set is served before get tries
	// to refresh it.
	jwksMaxAge = 10 * time.Minute
	// jwksMinRefreshInterval is the minimum time after a successful fetch
	// before an unknown key ID may trigger another one. Without it, every
	// request carrying a made-up "kid" would cause an outbound JWKS request.
	jwksMinRefreshInterval = 10 * time.Second
)

// get returns the RSA public key registered under kid.
//
// A cached key no older than jwksMaxAge is returned directly. Otherwise the
// key set is refreshed first, with two exceptions:
//   - if kid is unknown and the last successful fetch is younger than
//     jwksMinRefreshInterval, the lookup fails without a network request;
//   - if the refresh fails but a (stale) key for kid is cached, that key is
//     returned so a temporary outage of the JWKS endpoint does not reject
//     tokens signed with a known key.
//
// It returns an error when the key is unknown after these steps or when the
// refresh fails and no cached key exists for kid.
func (c *jwksCache) get(kid string) (*rsa.PublicKey, error) {
	c.mu.RLock()
	key, ok := c.keys[kid]
	// fetchAt is the zero time until the first successful refresh, which
	// makes age very large and therefore always forces a fetch.
	age := time.Since(c.fetchAt)
	c.mu.RUnlock()

	if ok && age <= jwksMaxAge {
		return key, nil
	}
	if !ok && age < jwksMinRefreshInterval {
		return nil, fmt.Errorf("key %q not found", kid)
	}

	if err := c.refresh(); err != nil {
		if ok {
			return key, nil // use stale key if refresh fails
		}
		return nil, err
	}

	c.mu.RLock()
	key, ok = c.keys[kid]
	c.mu.RUnlock()
	if !ok {
		return nil, fmt.Errorf("key %q not found", kid)
	}
	return key, nil
}
// rsaPublicKey builds an RSA public key from the unpadded base64url encoded
// modulus (nStr) and public exponent (eStr), as found in the "n" and "e"
// members of a JWK.
//
// It returns an error when either value is not valid unpadded base64url, when
// the modulus is empty or zero, or when the exponent lies outside
// [2, 2^31-1]. The exponent range matches what crypto/rsa accepts for
// signature verification; rejecting it here avoids silently truncating an
// oversized value when converting to int.
func rsaPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
	nBytes, err := base64.RawURLEncoding.DecodeString(nStr)
	if err != nil {
		return nil, fmt.Errorf("decode modulus: %w", err)
	}
	eBytes, err := base64.RawURLEncoding.DecodeString(eStr)
	if err != nil {
		return nil, fmt.Errorf("decode exponent: %w", err)
	}

	n := new(big.Int).SetBytes(nBytes)
	if n.Sign() == 0 {
		return nil, fmt.Errorf("modulus is empty or zero")
	}

	e := new(big.Int).SetBytes(eBytes)
	// IsInt64 must be checked first: Int64 is undefined for larger values.
	if !e.IsInt64() || e.Int64() < 2 || e.Int64() > 1<<31-1 {
		return nil, fmt.Errorf("exponent %s out of range", e)
	}
	return &rsa.PublicKey{N: n, E: int(e.Int64())}, nil
}
// OIDCAuth validates a Keycloak-issued JWT in the Authorization header.
//
// The header must use the Bearer scheme (matched case-insensitively, as HTTP
// auth schemes are). The token must be signed with an RSA method, carry an
// expiration claim, and verify against the key that cache.get returns for the
// token's "kid" header. On success the sub, email and name claims plus the
// string entries of realm_access.roles are stored as a model.UserClaims value
// in the request context under model.CtxUserKey and next is invoked.
//
// On failure the request is answered with 401, a JSON error body and a
// WWW-Authenticate challenge; next is not invoked.
func OIDCAuth(next http.Handler) http.Handler {
	// deny sends the 401 response together with the Bearer challenge that
	// must accompany it.
	deny := func(w http.ResponseWriter, msg, challenge string) {
		w.Header().Set("WWW-Authenticate", challenge)
		jsonError(w, msg, http.StatusUnauthorized)
	}

	// keyFor selects the verification key, refusing any non-RSA algorithm
	// so a token cannot choose a weaker or symmetric method.
	keyFor := func(t *jwt.Token) (any, error) {
		if _, ok := t.Method.(*jwt.SigningMethodRSA); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
		}
		// A missing or non-string kid yields "", which the cache lookup
		// treats like any other key ID.
		kid, _ := t.Header["kid"].(string)
		return cache.get(kid)
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		scheme, raw, found := strings.Cut(r.Header.Get("Authorization"), " ")
		if !found || !strings.EqualFold(scheme, "Bearer") {
			deny(w, "unauthorized", "Bearer")
			return
		}

		token, err := jwt.Parse(strings.TrimSpace(raw), keyFor, jwt.WithExpirationRequired())
		if err != nil || !token.Valid {
			deny(w, "invalid token", `Bearer error="invalid_token"`)
			return
		}
		claims, ok := token.Claims.(jwt.MapClaims)
		if !ok {
			deny(w, "invalid claims", `Bearer error="invalid_token"`)
			return
		}

		user := model.UserClaims{
			Sub:   stringClaim(claims, "sub"),
			Email: stringClaim(claims, "email"),
			Name:  stringClaim(claims, "name"),
		}
		// Roles are read from realm_access.roles; non-string entries are
		// ignored.
		if ra, ok := claims["realm_access"].(map[string]any); ok {
			if roles, ok := ra["roles"].([]any); ok {
				for _, role := range roles {
					if s, ok := role.(string); ok {
						user.Roles = append(user.Roles, s)
					}
				}
			}
		}

		ctx := context.WithValue(r.Context(), model.CtxUserKey, user)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// CORS wraps next with permissive cross-origin headers. Every response gets
// the Allow-Origin, Allow-Methods and Allow-Headers fields; OPTIONS requests
// are answered directly with 204 No Content and never reach next.
//
// NOTE(review): the origin is the wildcard "*" — confirm this is intended for
// all deployments.
func CORS(next http.Handler) http.Handler {
	handle := func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		hdr.Set("Access-Control-Allow-Origin", "*")
		hdr.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		hdr.Set("Access-Control-Allow-Headers", "Authorization, Content-Type")

		if r.Method != http.MethodOptions {
			next.ServeHTTP(w, r)
			return
		}
		// Preflight: headers only, no body.
		w.WriteHeader(http.StatusNoContent)
	}
	return http.HandlerFunc(handle)
}
// stringClaim returns the claim stored under key when it is a string, and the
// empty string when the claim is absent or has any other type.
func stringClaim(c jwt.MapClaims, key string) string {
	if s, isString := c[key].(string); isString {
		return s
	}
	return ""
}
// jsonError writes an error response with status code and the JSON body
// {"error": msg}, with Content-Type set to application/json.
//
// The body is produced by encoding/json rather than fmt's %q verb: %q emits
// Go string syntax (for example \x01 or \a), which is not valid JSON for
// messages containing control characters. json.Marshal additionally escapes
// '<', '>' and '&' as \u00XX sequences, which decode to the same string.
func jsonError(w http.ResponseWriter, msg string, code int) {
	body, err := json.Marshal(struct {
		Error string `json:"error"`
	}{Error: msg})
	if err != nil {
		// Not expected for a plain string field; keep the response well formed.
		body = []byte(`{"error":"internal error"}`)
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	// A failed write means the client is gone; there is nothing left to report.
	_, _ = w.Write(body)
}
// UserFromCtx returns the model.UserClaims stored in the request context
// under model.CtxUserKey. When no such value is present, or it has a different
// type, the zero model.UserClaims is returned.
func UserFromCtx(r *http.Request) model.UserClaims {
	if claims, found := r.Context().Value(model.CtxUserKey).(model.UserClaims); found {
		return claims
	}
	return model.UserClaims{}
}