// Package admin provides HTTP client functionality for communicating
// with the Management API from the Admin UI server.
package admin
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"time"
"github.com/sofatutor/llm-proxy/internal/obfuscate"
)
// context keys used to forward browser metadata from Admin UI → Management API
type forwardedCtxKey string
const (
ctxKeyForwardedUA forwardedCtxKey = "forwarded_user_agent"
ctxKeyForwardedReferer forwardedCtxKey = "forwarded_referer"
ctxKeyForwardedIP forwardedCtxKey = "forwarded_ip"
)
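// These keys are consumed by (*APIClient).newRequest, which translates them
// into X-Forwarded-User-Agent, X-Forwarded-Referer, and X-Forwarded-For headers
// on outgoing Management API requests.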
// APIClient handles communication with the Management API
type APIClient struct {
baseURL string
token string
httpClient *http.Client
}
// NewAPIClient creates a new Management API client
func NewAPIClient(baseURL, token string) *APIClient {
return &APIClient{
baseURL: baseURL,
token: token,
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
}
}
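// Example usage (illustrative; the base URL and token are placeholders):
//
//	client := NewAPIClient("http://localhost:8080", os.Getenv("MANAGEMENT_TOKEN"))
//	projects, pagination, err := client.GetProjects(ctx, 1, 20)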
// ObfuscateAPIKey obfuscates an API key for display purposes by delegating
// to the shared obfuscation helper.
func ObfuscateAPIKey(apiKey string) string { return obfuscate.ObfuscateTokenGeneric(apiKey) }
// ObfuscateToken obfuscates a token for display purposes. It calls the
// centralized helper directly to avoid linter warnings about the deprecated wrapper.
func ObfuscateToken(token string) string { return obfuscate.ObfuscateTokenGeneric(token) }
// Project represents a project from the Management API
type Project struct {
ID string `json:"id"`
Name string `json:"name"`
APIKey string `json:"api_key"`
IsActive bool `json:"is_active"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// Token represents a token from the Management API (sanitized)
type Token struct {
ID string `json:"id"` // Token UUID (for API operations)
Token string `json:"token"` // Obfuscated token string (for display)
ProjectID string `json:"project_id"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
IsActive bool `json:"is_active"`
RequestCount int `json:"request_count"`
MaxRequests *int `json:"max_requests,omitempty"`
CreatedAt time.Time `json:"created_at"`
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
CacheHitCount int `json:"cache_hit_count"`
}
// TokenCreateResponse represents the response when creating a token
type TokenCreateResponse struct {
Token string `json:"token"`
ExpiresAt time.Time `json:"expires_at"`
MaxRequests *int `json:"max_requests,omitempty"`
}
// Pagination represents pagination metadata
type Pagination struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
TotalItems int `json:"total_items"`
TotalPages int `json:"total_pages"`
HasNext bool `json:"has_next"`
HasPrev bool `json:"has_prev"`
}
// DashboardData represents dashboard statistics
type DashboardData struct {
TotalProjects int `json:"total_projects"`
TotalTokens int `json:"total_tokens"`
ActiveTokens int `json:"active_tokens"`
ExpiredTokens int `json:"expired_tokens"`
TotalRequests int `json:"total_requests"`
RequestsToday int `json:"requests_today"`
RequestsThisWeek int `json:"requests_this_week"`
}
// AuditEvent represents an audit log entry for the admin UI
type AuditEvent struct {
ID string `json:"id"`
Timestamp time.Time `json:"timestamp"`
Action string `json:"action"`
Actor string `json:"actor"`
ProjectID *string `json:"project_id,omitempty"`
RequestID *string `json:"request_id,omitempty"`
CorrelationID *string `json:"correlation_id,omitempty"`
ClientIP *string `json:"client_ip,omitempty"`
Method *string `json:"method,omitempty"`
Path *string `json:"path,omitempty"`
UserAgent *string `json:"user_agent,omitempty"`
Outcome string `json:"outcome"`
Reason *string `json:"reason,omitempty"`
TokenID *string `json:"token_id,omitempty"`
Metadata *string `json:"metadata,omitempty"` // JSON string
ParsedMeta *AuditMetadata `json:"-"` // Parsed metadata for template use
}
// AuditMetadata represents parsed metadata for easier template access
type AuditMetadata map[string]interface{}
// AuditEventsResponse represents the API response for audit events listing
type AuditEventsResponse struct {
Events []AuditEvent `json:"events"`
Pagination Pagination `json:"pagination"`
}
// GetDashboardData retrieves dashboard statistics
func (c *APIClient) GetDashboardData(ctx context.Context) (*DashboardData, error) {
// For now, calculate from projects and tokens lists
// In the future, this could be a dedicated dashboard endpoint
projects, _, err := c.GetProjects(ctx, 1, 1000) // fetch up to 1000 projects
if err != nil {
return nil, err
}
tokens, _, err := c.GetTokens(ctx, "", 1, 1000) // fetch up to 1000 tokens
if err != nil {
return nil, err
}
data := &DashboardData{
TotalProjects: len(projects),
TotalTokens: len(tokens),
}
// Calculate active/expired tokens and request counts
now := time.Now()
for _, token := range tokens {
// A token counts as active when it is flagged active and either has no expiry or expires in the future.
if token.IsActive && (token.ExpiresAt == nil || token.ExpiresAt.After(now)) {
data.ActiveTokens++
} else {
data.ExpiredTokens++
}
data.TotalRequests += token.RequestCount
// Approximation: attribute the token's full request count to today if it was used within the last 24 hours.
if token.LastUsedAt != nil && token.LastUsedAt.After(now.AddDate(0, 0, -1)) {
data.RequestsToday += token.RequestCount
}
// Approximation: attribute the token's full request count to this week if it was used within the last 7 days.
if token.LastUsedAt != nil && token.LastUsedAt.After(now.AddDate(0, 0, -7)) {
data.RequestsThisWeek += token.RequestCount
}
}
return data, nil
}
// GetProjects retrieves a paginated list of projects
func (c *APIClient) GetProjects(ctx context.Context, page, pageSize int) ([]Project, *Pagination, error) {
// The Management API does not currently support pagination, so fetch all
// projects and paginate client-side.
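// Example: with 25 projects, page=2 and pageSize=10, the slicing below selects
// items [10:20) and the pagination reports TotalPages=3, HasNext=true, HasPrev=true.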
req, err := c.newRequest(ctx, "GET", "/manage/projects", nil)
if err != nil {
return nil, nil, err
}
var projects []Project
if err := c.doRequest(req, &projects); err != nil {
return nil, nil, err
}
// Simulate pagination
totalItems := len(projects)
if pageSize <= 0 {
pageSize = 10 // guard against division by zero below
}
if page < 1 {
page = 1 // guard against a negative slice offset below
}
totalPages := (totalItems + pageSize - 1) / pageSize
start := (page - 1) * pageSize
end := start + pageSize
if start >= totalItems {
projects = []Project{}
} else {
if end > totalItems {
end = totalItems
}
projects = projects[start:end]
}
pagination := &Pagination{
Page: page,
PageSize: pageSize,
TotalItems: totalItems,
TotalPages: totalPages,
HasNext: page < totalPages,
HasPrev: page > 1,
}
return projects, pagination, nil
}
// GetProject retrieves a single project by ID
func (c *APIClient) GetProject(ctx context.Context, id string) (*Project, error) {
req, err := c.newRequest(ctx, "GET", fmt.Sprintf("/manage/projects/%s", id), nil)
if err != nil {
return nil, err
}
var project Project
if err := c.doRequest(req, &project); err != nil {
return nil, err
}
return &project, nil
}
// CreateProject creates a new project
func (c *APIClient) CreateProject(ctx context.Context, name, apiKey string) (*Project, error) {
payload := map[string]string{
"name": name,
"api_key": apiKey,
}
req, err := c.newRequest(ctx, "POST", "/manage/projects", payload)
if err != nil {
return nil, err
}
var project Project
if err := c.doRequest(req, &project); err != nil {
return nil, err
}
return &project, nil
}
// UpdateProject updates an existing project
func (c *APIClient) UpdateProject(ctx context.Context, id, name, apiKey string, isActive *bool) (*Project, error) {
payload := map[string]interface{}{}
if name != "" {
payload["name"] = name
}
if apiKey != "" {
payload["api_key"] = apiKey
}
if isActive != nil {
payload["is_active"] = *isActive
}
req, err := c.newRequest(ctx, "PATCH", fmt.Sprintf("/manage/projects/%s", id), payload)
if err != nil {
return nil, err
}
var project Project
if err := c.doRequest(req, &project); err != nil {
return nil, err
}
return &project, nil
}
// DeleteProject deletes a project
func (c *APIClient) DeleteProject(ctx context.Context, id string) error {
req, err := c.newRequest(ctx, "DELETE", fmt.Sprintf("/manage/projects/%s", id), nil)
if err != nil {
return err
}
return c.doRequest(req, nil)
}
// GetTokens retrieves a paginated list of tokens
func (c *APIClient) GetTokens(ctx context.Context, projectID string, page, pageSize int) ([]Token, *Pagination, error) {
path := "/manage/tokens"
if projectID != "" {
path += "?projectId=" + url.QueryEscape(projectID)
}
req, err := c.newRequest(ctx, "GET", path, nil)
if err != nil {
return nil, nil, err
}
var tokens []Token
if err := c.doRequest(req, &tokens); err != nil {
return nil, nil, err
}
// Simulate pagination (similar to projects)
totalItems := len(tokens)
if pageSize <= 0 {
pageSize = 10 // guard against division by zero below
}
if page < 1 {
page = 1 // guard against a negative slice offset below
}
totalPages := (totalItems + pageSize - 1) / pageSize
start := (page - 1) * pageSize
end := start + pageSize
if start >= totalItems {
tokens = []Token{}
} else {
if end > totalItems {
end = totalItems
}
tokens = tokens[start:end]
}
pagination := &Pagination{
Page: page,
PageSize: pageSize,
TotalItems: totalItems,
TotalPages: totalPages,
HasNext: page < totalPages,
HasPrev: page > 1,
}
return tokens, pagination, nil
}
// CreateToken creates a new token for a project with a given duration in minutes
func (c *APIClient) CreateToken(ctx context.Context, projectID string, durationMinutes int, maxRequests *int) (*TokenCreateResponse, error) {
payload := map[string]interface{}{
"project_id": projectID,
"duration_minutes": durationMinutes,
}
if maxRequests != nil {
payload["max_requests"] = *maxRequests
}
// Use newRequest and doRequest for consistent error handling
req, err := c.newRequest(ctx, "POST", "/manage/tokens", payload)
if err != nil {
return nil, err
}
var result TokenCreateResponse
if err := c.doRequest(req, &result); err != nil {
return nil, err
}
return &result, nil
}
// newRequest creates a new HTTP request with authentication
func (c *APIClient) newRequest(ctx context.Context, method, path string, body any) (*http.Request, error) {
var reqBody []byte
var err error
if body != nil {
reqBody, err = json.Marshal(body)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
}
fullURL := c.baseURL + path // avoid shadowing the net/url package
req, err := http.NewRequestWithContext(ctx, method, fullURL, bytes.NewBuffer(reqBody))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
// Set headers
req.Header.Set("Authorization", "Bearer "+c.token)
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
// Forward browser context when present
if v := ctx.Value(ctxKeyForwardedUA); v != nil {
if ua, ok := v.(string); ok && ua != "" {
req.Header.Set("X-Forwarded-User-Agent", ua)
req.Header.Set("X-Admin-Origin", "1")
}
}
if v := ctx.Value(ctxKeyForwardedReferer); v != nil {
if ref, ok := v.(string); ok && ref != "" {
req.Header.Set("X-Forwarded-Referer", ref)
req.Header.Set("X-Admin-Origin", "1")
}
}
if v := ctx.Value(ctxKeyForwardedIP); v != nil {
if ip, ok := v.(string); ok && ip != "" {
// Provide original browser IP for backend audit logging
req.Header.Set("X-Forwarded-For", ip)
req.Header.Set("X-Admin-Origin", "1")
}
}
return req, nil
}
// doRequest executes an HTTP request and handles the response
func (c *APIClient) doRequest(req *http.Request, result any) error {
resp, err := c.httpClient.Do(req)
if err != nil {
return fmt.Errorf("request failed: %w", err)
}
defer func() {
if err := resp.Body.Close(); err != nil {
// Best-effort close: log to standard error and continue.
fmt.Fprintf(os.Stderr, "failed to close response body: %v\n", err)
}
}()
if resp.StatusCode >= 400 {
var errorResp map[string]any
if err := json.NewDecoder(resp.Body).Decode(&errorResp); err == nil {
if msg, ok := errorResp["error"].(string); ok {
return fmt.Errorf("API error (%d): %s", resp.StatusCode, msg)
}
}
return fmt.Errorf("API error: %d %s", resp.StatusCode, resp.Status)
}
if result != nil && resp.StatusCode != http.StatusNoContent {
if err := json.NewDecoder(resp.Body).Decode(result); err != nil {
return fmt.Errorf("failed to decode response: %w", err)
}
}
return nil
}
// GetAuditEvents retrieves a paginated list of audit events with optional filters
func (c *APIClient) GetAuditEvents(ctx context.Context, filters map[string]string, page, pageSize int) ([]AuditEvent, *Pagination, error) {
// Build query parameters
params := url.Values{}
for key, value := range filters {
if value != "" {
params.Set(key, value)
}
}
params.Set("page", fmt.Sprintf("%d", page))
params.Set("page_size", fmt.Sprintf("%d", pageSize))
endpoint := "/manage/audit"
if len(params) > 0 {
endpoint += "?" + params.Encode()
}
req, err := c.newRequest(ctx, "GET", endpoint, nil)
if err != nil {
return nil, nil, err
}
var response AuditEventsResponse
if err := c.doRequest(req, &response); err != nil {
return nil, nil, err
}
// Parse metadata for each event
for i := range response.Events {
if response.Events[i].Metadata != nil && *response.Events[i].Metadata != "" {
var meta AuditMetadata
if err := json.Unmarshal([]byte(*response.Events[i].Metadata), &meta); err == nil {
response.Events[i].ParsedMeta = &meta
}
}
}
return response.Events, &response.Pagination, nil
}
// GetAuditEvent retrieves a specific audit event by ID
func (c *APIClient) GetAuditEvent(ctx context.Context, id string) (*AuditEvent, error) {
req, err := c.newRequest(ctx, "GET", "/manage/audit/"+id, nil)
if err != nil {
return nil, err
}
var event AuditEvent
if err := c.doRequest(req, &event); err != nil {
return nil, err
}
// Parse metadata
if event.Metadata != nil && *event.Metadata != "" {
var meta AuditMetadata
if err := json.Unmarshal([]byte(*event.Metadata), &meta); err == nil {
event.ParsedMeta = &meta
}
}
return &event, nil
}
// GetToken retrieves a single token by ID
func (c *APIClient) GetToken(ctx context.Context, tokenID string) (*Token, error) {
req, err := c.newRequest(ctx, "GET", fmt.Sprintf("/manage/tokens/%s", tokenID), nil)
if err != nil {
return nil, err
}
var token Token
if err := c.doRequest(req, &token); err != nil {
return nil, err
}
return &token, nil
}
// UpdateToken updates an existing token
func (c *APIClient) UpdateToken(ctx context.Context, tokenID string, isActive *bool, maxRequests *int) (*Token, error) {
payload := map[string]interface{}{}
if isActive != nil {
payload["is_active"] = *isActive
}
if maxRequests != nil {
payload["max_requests"] = *maxRequests
}
req, err := c.newRequest(ctx, "PATCH", fmt.Sprintf("/manage/tokens/%s", tokenID), payload)
if err != nil {
return nil, err
}
var token Token
if err := c.doRequest(req, &token); err != nil {
return nil, err
}
return &token, nil
}
// RevokeToken revokes a single token by setting is_active to false
func (c *APIClient) RevokeToken(ctx context.Context, tokenID string) error {
req, err := c.newRequest(ctx, "DELETE", fmt.Sprintf("/manage/tokens/%s", tokenID), nil)
if err != nil {
return err
}
return c.doRequest(req, nil)
}
// RevokeProjectTokens revokes all tokens for a project in bulk
func (c *APIClient) RevokeProjectTokens(ctx context.Context, projectID string) error {
req, err := c.newRequest(ctx, "POST", fmt.Sprintf("/manage/projects/%s/tokens/revoke", projectID), nil)
if err != nil {
return err
}
return c.doRequest(req, nil)
}
// Package admin provides the HTTP server for the Admin UI.
// This package implements a separate web interface for managing
// projects and tokens via the Management API.
package admin
import (
"context"
"encoding/json"
"fmt"
"html/template"
"log"
"net/http"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
"github.com/sofatutor/llm-proxy/internal/audit"
"github.com/sofatutor/llm-proxy/internal/config"
"github.com/sofatutor/llm-proxy/internal/logging"
"github.com/sofatutor/llm-proxy/internal/obfuscate"
"go.uber.org/zap"
)
// Session represents a user session
type Session struct {
ID string
Token string
CreatedAt time.Time
ExpiresAt time.Time
RememberMe bool
}
// getSessionSecret derives a secret key from the management token and a salt
func getSessionSecret(cfg *config.Config) []byte {
salt := "llmproxy-cookie-salt" // TODO: move to config/env
return []byte(cfg.AdminUI.ManagementToken + salt)
}
// APIClientInterface abstracts the API client for testability.
// Only the methods needed by the handlers are included.
//
//go:generate mockgen -destination=mock_api_client.go -package=admin . APIClientInterface
type APIClientInterface interface {
GetDashboardData(ctx context.Context) (*DashboardData, error)
GetProjects(ctx context.Context, page, pageSize int) ([]Project, *Pagination, error)
GetTokens(ctx context.Context, projectID string, page, pageSize int) ([]Token, *Pagination, error)
CreateToken(ctx context.Context, projectID string, durationMinutes int, maxRequests *int) (*TokenCreateResponse, error)
GetProject(ctx context.Context, projectID string) (*Project, error)
UpdateProject(ctx context.Context, projectID string, name string, openAIAPIKey string, isActive *bool) (*Project, error)
DeleteProject(ctx context.Context, projectID string) error
CreateProject(ctx context.Context, name string, openAIAPIKey string) (*Project, error)
GetAuditEvents(ctx context.Context, filters map[string]string, page, pageSize int) ([]AuditEvent, *Pagination, error)
GetAuditEvent(ctx context.Context, id string) (*AuditEvent, error)
GetToken(ctx context.Context, tokenID string) (*Token, error)
UpdateToken(ctx context.Context, tokenID string, isActive *bool, maxRequests *int) (*Token, error)
RevokeToken(ctx context.Context, tokenID string) error
RevokeProjectTokens(ctx context.Context, projectID string) error
}
// Server represents the Admin UI HTTP server.
// It provides a web interface for managing projects and tokens
// by communicating with the Management API.
type Server struct {
server *http.Server
config *config.Config
engine *gin.Engine
apiClient *APIClient
logger *zap.Logger
assetVersion string
// For testability: allow injection of token validation logic
ValidateTokenWithAPI func(context.Context, string) bool
// Audit logger for admin actions
auditLogger *audit.Logger
}
// NewServer creates a new Admin UI server with the provided configuration.
// It initializes the Gin engine, sets up routes, and configures the HTTP server.
func NewServer(cfg *config.Config) (*Server, error) {
// Initialize logger
logger, err := logging.NewLogger(cfg.LogLevel, cfg.LogFormat, cfg.LogFile)
if err != nil {
return nil, fmt.Errorf("failed to initialize admin server logger: %w", err)
}
// Set Gin mode based on log level
if cfg.LogLevel == "debug" {
gin.SetMode(gin.DebugMode)
} else {
gin.SetMode(gin.ReleaseMode)
}
engine := gin.New()
// Add middleware
engine.Use(gin.Logger())
engine.Use(gin.Recovery())
// Method override middleware for HTML forms. Gin matches routes before this
// middleware runs, so the POST fallback routes in setupRoutes still dispatch
// on the _method field explicitly.
engine.Use(func(c *gin.Context) {
if c.Request.Method == "POST" && c.Request.FormValue("_method") != "" {
c.Request.Method = c.Request.FormValue("_method")
}
c.Next()
})
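// Example form that triggers the override (illustrative):
//
//	<form method="POST" action="/projects/123">
//	  <input type="hidden" name="_method" value="PUT">
//	  ...
//	</form>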
// Add session middleware
store := cookie.NewStore(getSessionSecret(cfg))
engine.Use(sessions.Sessions("llmproxy_session", store))
// Create API client for communicating with Management API
// Note: API client will be updated when user logs in
var apiClient *APIClient
if cfg.AdminUI.ManagementToken != "" {
apiClient = NewAPIClient(cfg.AdminUI.APIBaseURL, cfg.AdminUI.ManagementToken)
}
// Initialize audit logger
var auditLogger *audit.Logger
if cfg.AuditEnabled && cfg.AuditLogFile != "" {
auditConfig := audit.LoggerConfig{
FilePath: cfg.AuditLogFile,
CreateDir: cfg.AuditCreateDir,
}
var err error
auditLogger, err = audit.NewLogger(auditConfig)
if err != nil {
return nil, fmt.Errorf("failed to initialize admin audit logger: %w", err)
}
} else {
auditLogger = audit.NewNullLogger()
}
s := &Server{
config: cfg,
engine: engine,
apiClient: apiClient,
logger: logger,
assetVersion: strconv.FormatInt(time.Now().Unix(), 10),
auditLogger: auditLogger,
server: &http.Server{
Addr: cfg.AdminUI.ListenAddr,
Handler: engine,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 60 * time.Second,
},
}
// Setup routes
s.setupRoutes()
return s, nil
}
// Start starts the Admin UI server.
// This method blocks until the server is shut down or an error occurs.
func (s *Server) Start() error {
return s.server.ListenAndServe()
}
// Shutdown gracefully shuts down the server without interrupting active connections.
func (s *Server) Shutdown(ctx context.Context) error {
if s.server == nil {
return nil
}
return s.server.Shutdown(ctx)
}
// setupRoutes configures all the routes for the Admin UI
func (s *Server) setupRoutes() {
// Serve static files (CSS, JS, images)
s.engine.Static("/static", "./web/static")
// Load HTML templates with custom functions using glob patterns
s.engine.SetFuncMap(s.templateFuncs())
td := s.config.AdminUI.TemplateDir
// Load all templates - both root level and subdirectories
templGlob := template.Must(template.New("").Funcs(s.templateFuncs()).ParseGlob(filepath.Join(td, "*.html")))
templGlob = template.Must(templGlob.ParseGlob(filepath.Join(td, "*/*.html")))
s.engine.SetHTMLTemplate(templGlob)
// Authentication routes (no middleware)
auth := s.engine.Group("/auth")
{
auth.GET("/login", s.handleLoginForm)
auth.POST("/login", s.handleLogin)
auth.GET("/logout", s.handleLogout) // Allow GET for direct URL access
auth.POST("/logout", s.handleLogout) // Keep POST for form submission
}
// Root route - redirect to dashboard
s.engine.GET("/", func(c *gin.Context) {
c.Redirect(http.StatusMovedPermanently, "/dashboard")
})
// Protected routes with authentication middleware
protected := s.engine.Group("/")
protected.Use(s.authMiddleware())
{
// Dashboard
protected.GET("/dashboard", s.handleDashboard)
// Projects routes
projects := protected.Group("/projects")
{
projects.GET("", s.handleProjectsList)
projects.GET("/new", s.handleProjectsNew)
projects.POST("", s.handleProjectsCreate)
projects.GET("/:id", s.handleProjectsShow)
projects.GET("/:id/edit", s.handleProjectsEdit)
// HTML forms submit via POST with _method override; Gin matches routes before middleware,
// so provide POST fallback routes that dispatch to the correct handlers.
projects.POST("/:id", s.handleProjectsPostOverride)
projects.PUT("/:id", s.handleProjectsUpdate)
projects.DELETE("/:id", s.handleProjectsDelete)
projects.POST("/:id/revoke-tokens", s.handleProjectsBulkRevoke)
}
// Tokens routes
tokens := protected.Group("/tokens")
{
tokens.GET("", s.handleTokensList)
tokens.GET("/new", s.handleTokensNew)
tokens.POST("", s.handleTokensCreate)
tokens.GET("/:token", s.handleTokensShow)
tokens.GET("/:token/edit", s.handleTokensEdit)
// POST fallback for HTML forms with _method override
tokens.POST("/:token", s.handleTokensPostOverride)
tokens.PUT("/:token", s.handleTokensUpdate)
tokens.DELETE("/:token", s.handleTokensRevoke)
}
// Audit routes
auditGroup := protected.Group("/audit") // avoid shadowing the audit package
{
auditGroup.GET("", s.handleAuditList)
auditGroup.GET("/:id", s.handleAuditShow)
}
}
// Health check
s.engine.GET("/health", func(c *gin.Context) {
adminHealth := gin.H{
"status": "ok",
"timestamp": time.Now(),
"service": "admin-ui",
"version": "0.1.0",
}
backendURL := s.config.AdminUI.APIBaseURL
if backendURL == "" {
backendURL = "http://localhost:8080"
}
if !strings.HasSuffix(backendURL, "/health") {
backendURL = strings.TrimRight(backendURL, "/") + "/health"
}
client := &http.Client{Timeout: 2 * time.Second}
backendHealth := gin.H{
"status": "down",
"error": "Backend unavailable",
}
if resp, err := client.Get(backendURL); err == nil {
// Close the body on every path, not only on 200, to avoid leaking connections.
if resp.StatusCode == http.StatusOK {
var backendData map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&backendData); err == nil {
backendHealth = backendData
}
}
if err := resp.Body.Close(); err != nil {
s.logger.Warn("failed to close backend health response body", zap.Error(err))
}
}
c.JSON(http.StatusOK, gin.H{
"admin": adminHealth,
"backend": backendHealth,
})
})
}
// Dashboard handlers
func (s *Server) handleDashboard(c *gin.Context) {
// Get API client from context
apiClientIface := c.MustGet("apiClient").(APIClientInterface)
// Get dashboard data from Management API with forwarded browser metadata
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
dashboardData, err := apiClientIface.GetDashboardData(ctx)
if err != nil {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": fmt.Sprintf("Failed to load dashboard data: %v", err),
})
return
}
c.HTML(http.StatusOK, "dashboard.html", gin.H{
"title": "Dashboard",
"active": "dashboard",
"data": dashboardData,
})
}
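// forwardedRequestContext is a sketch of a helper that could consolidate the
// browser-metadata forwarding that the handlers in this file currently inline
// (hypothetical; the handlers below still build their contexts by hand).
func forwardedRequestContext(c *gin.Context) context.Context {
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
// Use only the first hop of a comma-separated X-Forwarded-For chain.
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
return ctx
}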
// Project handlers
func (s *Server) handleProjectsList(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
page := getPageFromQuery(c, 1)
pageSize := getPageSizeFromQuery(c, 10)
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
// Forward best-effort original client IP from headers
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
projects, pagination, err := apiClient.GetProjects(ctx, page, pageSize)
if err != nil {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": fmt.Sprintf("Failed to load projects: %v", err),
})
return
}
c.HTML(http.StatusOK, "projects/list.html", gin.H{
"title": "Projects",
"active": "projects",
"projects": projects,
"pagination": pagination,
})
}
func (s *Server) handleProjectsNew(c *gin.Context) {
c.HTML(http.StatusOK, "projects/new.html", gin.H{
"title": "Create Project",
"active": "projects",
})
}
func (s *Server) handleProjectsCreate(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
var req struct {
Name string `form:"name" binding:"required"`
APIKey string `form:"api_key" binding:"required"`
}
if err := c.ShouldBind(&req); err != nil {
c.HTML(http.StatusBadRequest, "projects/new.html", gin.H{
"title": "Create Project",
"active": "projects",
"error": "Please fill in all required fields",
})
return
}
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
project, err := apiClient.CreateProject(ctx, req.Name, req.APIKey)
if err != nil {
c.HTML(http.StatusInternalServerError, "projects/new.html", gin.H{
"title": "Create Project",
"active": "projects",
"error": fmt.Sprintf("Failed to create project: %v", err),
})
return
}
c.Redirect(http.StatusSeeOther, fmt.Sprintf("/projects/%s", project.ID))
}
func (s *Server) handleProjectsShow(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
id := c.Param("id")
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
project, err := apiClient.GetProject(ctx, id)
if err != nil {
c.HTML(http.StatusNotFound, "error.html", gin.H{
"error": "Project not found",
})
return
}
c.HTML(http.StatusOK, "projects/show.html", gin.H{
"title": "Project Details",
"active": "projects",
"project": project,
})
}
func (s *Server) handleProjectsEdit(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
id := c.Param("id")
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
project, err := apiClient.GetProject(ctx, id)
if err != nil {
c.HTML(http.StatusNotFound, "error.html", gin.H{
"error": "Project not found",
})
return
}
c.HTML(http.StatusOK, "projects/edit.html", gin.H{
"title": "Edit Project",
"active": "projects",
"project": project,
})
}
func (s *Server) handleProjectsUpdate(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
id := c.Param("id")
var req struct {
Name string `form:"name" binding:"required"`
APIKey string `form:"api_key"` // Optional - empty means keep existing
IsActive *bool `form:"is_active"`
}
if err := c.ShouldBind(&req); err != nil {
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Invalid form data",
})
return
}
// Checkbox handling via helper for clarity and reuse
isActive := parseBoolFormField(c, "is_active")
isActivePtr := &isActive
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
project, err := apiClient.UpdateProject(ctx, id, req.Name, req.APIKey, isActivePtr)
if err != nil {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": fmt.Sprintf("Failed to update project: %v", err),
})
return
}
c.Redirect(http.StatusSeeOther, fmt.Sprintf("/projects/%s", project.ID))
}
// handleProjectsPostOverride routes POST requests with _method overrides to the appropriate handler.
// It ensures form submissions to /projects/:id work even though Gin resolves routes before middleware.
func (s *Server) handleProjectsPostOverride(c *gin.Context) {
// Parse form to access _method
if err := c.Request.ParseForm(); err != nil {
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Failed to parse form data",
})
return
}
method := c.PostForm("_method")
switch strings.ToUpper(method) {
case http.MethodPut:
s.handleProjectsUpdate(c)
return
case http.MethodDelete:
s.handleProjectsDelete(c)
return
default:
// No override provided; treat as bad request
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Unsupported method override for project action",
})
return
}
}
// parseBoolFormField returns the boolean value of a form field that may submit
// both a hidden "false" and a checkbox "true". The last value wins.
func parseBoolFormField(c *gin.Context, name string) bool {
vals := c.PostFormArray(name)
if len(vals) == 0 {
return false
}
return strings.EqualFold(vals[len(vals)-1], "true")
}
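// Typical form pattern this handles (illustrative):
//
//	<input type="hidden" name="is_active" value="false">
//	<input type="checkbox" name="is_active" value="true">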
func (s *Server) handleProjectsDelete(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
id := c.Param("id")
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
err := apiClient.DeleteProject(ctx, id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": fmt.Sprintf("Failed to delete project: %v", err),
})
return
}
c.Redirect(http.StatusSeeOther, "/projects")
}
// Token handlers
func (s *Server) handleTokensList(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
page := getPageFromQuery(c, 1)
pageSize := getPageSizeFromQuery(c, 10)
projectID := c.Query("project_id")
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
tokens, pagination, err := apiClient.GetTokens(ctx, projectID, page, pageSize)
if err != nil {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": fmt.Sprintf("Failed to load tokens: %v", err),
})
return
}
// Fetch all projects to create a lookup map for project names
projects, _, err := apiClient.GetProjects(ctx, 1, 1000) // Get up to 1000 projects
if err != nil {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": fmt.Sprintf("Failed to load projects: %v", err),
})
return
}
// Create project ID to name lookup map
projectNames := make(map[string]string)
for _, project := range projects {
projectNames[project.ID] = project.Name
}
c.HTML(http.StatusOK, "tokens/list.html", gin.H{
"title": "Tokens",
"active": "tokens",
"tokens": tokens,
"pagination": pagination,
"projectId": projectID,
"projectNames": projectNames,
"now": time.Now(),
"currentTime": time.Now(),
})
}
func (s *Server) handleTokensNew(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
projects, _, err := apiClient.GetProjects(ctx, 1, 100)
if err != nil {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": fmt.Sprintf("Failed to load projects: %v", err),
})
return
}
c.HTML(http.StatusOK, "tokens/new.html", gin.H{
"title": "Generate Token",
"active": "tokens",
"projects": projects,
})
}
func (s *Server) handleTokensCreate(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
var req struct {
ProjectID string `form:"project_id" binding:"required"`
DurationMinutes int `form:"duration_minutes" binding:"required,min=1,max=525600"`
MaxRequests string `form:"max_requests"`
}
if err := c.ShouldBind(&req); err != nil {
// forward context as well for consistency in audit logs
projCtx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
projCtx = context.WithValue(projCtx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
projCtx = context.WithValue(projCtx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
projCtx = context.WithValue(projCtx, ctxKeyForwardedReferer, ref)
}
projects, _, _ := apiClient.GetProjects(projCtx, 1, 100)
c.HTML(http.StatusBadRequest, "tokens/new.html", gin.H{
"title": "Generate Token",
"active": "tokens",
"projects": projects,
"project_id": c.PostForm("project_id"),
"duration_minutes": c.PostForm("duration_minutes"),
"max_requests": c.PostForm("max_requests"),
"error": "Please fill in all required fields correctly",
})
return
}
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
renderNewTokenFormError := func(status int, message string) {
projects, _, _ := apiClient.GetProjects(ctx, 1, 100)
c.HTML(status, "tokens/new.html", gin.H{
"title": "Generate Token",
"active": "tokens",
"projects": projects,
"project_id": req.ProjectID,
"duration_minutes": req.DurationMinutes,
"max_requests": req.MaxRequests,
"error": message,
})
}
var maxRequests *int
if strings.TrimSpace(req.MaxRequests) != "" {
parsedMaxRequests, err := strconv.Atoi(strings.TrimSpace(req.MaxRequests))
if err != nil || parsedMaxRequests < 0 {
renderNewTokenFormError(http.StatusBadRequest, "Please fill in all required fields correctly")
return
}
// 0 means unlimited; omit from payload.
if parsedMaxRequests > 0 {
maxRequests = &parsedMaxRequests
}
}
token, err := apiClient.CreateToken(ctx, req.ProjectID, req.DurationMinutes, maxRequests)
if err != nil {
// Re-fetch projects (with the same forwarded context) so the form can be re-rendered.
projects, _, _ := apiClient.GetProjects(ctx, 1, 100)
c.HTML(http.StatusInternalServerError, "tokens/new.html", gin.H{
"title": "Generate Token",
"active": "tokens",
"projects": projects,
"error": fmt.Sprintf("Failed to create token: %v", err),
})
return
}
c.HTML(http.StatusOK, "tokens/created.html", gin.H{
"title": "Token Created",
"active": "tokens",
"token": token,
})
}
func (s *Server) handleTokensShow(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
tokenID := c.Param("token")
if tokenID == "" {
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Token ID is required",
})
return
}
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
token, err := apiClient.GetToken(ctx, tokenID)
if err != nil {
if strings.Contains(err.Error(), "404") {
c.HTML(http.StatusNotFound, "error.html", gin.H{
"error": "Token not found",
"details": fmt.Sprintf("Token %s was not found", tokenID),
})
} else {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": "Failed to load token",
"details": err.Error(),
})
}
return
}
// Get project for the token
project, err := apiClient.GetProject(ctx, token.ProjectID)
if err != nil {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": "Failed to load project",
"details": err.Error(),
})
return
}
c.HTML(http.StatusOK, "tokens/show.html", gin.H{
"title": "Token Details",
"active": "tokens",
"token": token,
"project": project,
"now": time.Now(),
"currentTime": time.Now(),
})
}
// Helper functions
func getPageFromQuery(c *gin.Context, defaultValue int) int {
if page := c.Query("page"); page != "" {
if p, err := parsePositiveInt(page); err == nil && p > 0 {
return p
}
}
return defaultValue
}
func getPageSizeFromQuery(c *gin.Context, defaultValue int) int {
if size := c.Query("size"); size != "" {
if s, err := parsePositiveInt(size); err == nil && s > 0 && s <= 100 {
return s
}
}
return defaultValue
}
// parsePositiveInt parses s as a base-10 integer using strconv.Atoi, which
// (unlike fmt.Sscanf) rejects trailing non-digit characters. Callers enforce
// positivity on the returned value.
func parsePositiveInt(s string) (int, error) {
return strconv.Atoi(s)
}
// templateFuncs returns custom template functions for HTML templates
func (s *Server) templateFuncs() template.FuncMap {
return template.FuncMap{
"stringOr": func(value any, fallback string) string {
// Safely dereference optional strings for templates
switch v := value.(type) {
case *string:
if v != nil && *v != "" {
return *v
}
case string:
if v != "" {
return v
}
}
return fallback
},
"add": func(a, b int) int {
return a + b
},
"sub": func(a, b int) int {
return a - b
},
"safeSub": func(a, b int) int {
// Returns max(0, a-b) to prevent negative values in templates
if a < b {
return 0
}
return a - b
},
"formatMaxRequests": func(max *int) string {
if max == nil {
return "∞"
}
if *max <= 0 {
return "∞"
}
return strconv.Itoa(*max)
},
"inc": func(a int) int {
return a + 1
},
"dec": func(a int) int {
return a - 1
},
"seq": func(start, end int) []int {
if start > end {
return []int{}
}
seq := make([]int, end-start+1)
for i := range seq {
seq[i] = start + i
}
return seq
},
"now": func() time.Time {
return time.Now()
},
"assetVersion": func() string {
return s.assetVersion
},
"eq": func(a, b any) bool {
return a == b
},
"ne": func(a, b any) bool {
return a != b
},
"lt": func(a, b any) bool {
switch v := a.(type) {
case int:
if b2, ok := b.(int); ok {
return v < b2
}
case int64:
if b2, ok := b.(int64); ok {
return v < b2
}
case time.Time:
if b2, ok := b.(time.Time); ok {
return v.Before(b2)
}
}
return false
},
"gt": func(a, b any) bool {
switch v := a.(type) {
case int:
if b2, ok := b.(int); ok {
return v > b2
}
case int64:
if b2, ok := b.(int64); ok {
return v > b2
}
case time.Time:
if b2, ok := b.(time.Time); ok {
return v.After(b2)
}
}
return false
},
"le": func(a, b any) bool {
switch v := a.(type) {
case int:
if b2, ok := b.(int); ok {
return v <= b2
}
case int64:
if b2, ok := b.(int64); ok {
return v <= b2
}
case time.Time:
if b2, ok := b.(time.Time); ok {
return v.Before(b2) || v.Equal(b2)
}
}
return false
},
"ge": func(a, b any) bool {
switch v := a.(type) {
case int:
if b2, ok := b.(int); ok {
return v >= b2
}
case int64:
if b2, ok := b.(int64); ok {
return v >= b2
}
case time.Time:
if b2, ok := b.(time.Time); ok {
return v.After(b2) || v.Equal(b2)
}
}
return false
},
"and": func(a, b bool) bool {
return a && b
},
"or": func(a, b bool) bool {
return a || b
},
"not": func(a bool) bool {
return !a
},
"obfuscateAPIKey": func(apiKey string) string { return obfuscate.ObfuscateTokenGeneric(apiKey) },
"obfuscateToken": func(token string) string { return obfuscate.ObfuscateTokenSimple(token) },
"contains": func(s, substr string) bool {
return strings.Contains(s, substr)
},
"pageRange": func(current, total int) []int {
// Show up to 7 page numbers around current page
start := current - 3
end := current + 3
if start < 1 {
start = 1
}
if end > total {
end = total
}
// Adjust if we have fewer than 7 pages to show
if end-start < 6 && total > 6 {
if start == 1 {
end = start + 6
if end > total {
end = total
}
} else if end == total {
start = end - 6
if start < 1 {
start = 1
}
}
}
pages := make([]int, 0, end-start+1)
for i := start; i <= end; i++ {
pages = append(pages, i)
}
return pages
},
"dict": func(values ...interface{}) map[string]interface{} {
if len(values)%2 != 0 {
return nil
}
dict := make(map[string]interface{}, len(values)/2)
for i := 0; i < len(values); i += 2 {
key, ok := values[i].(string)
if !ok {
return nil
}
dict[key] = values[i+1]
}
return dict
},
"formatRFC3339UTC": func(t time.Time) string {
return t.UTC().Format(time.RFC3339Nano)
},
"formatRFC3339UTCPtr": func(t *time.Time) string {
if t == nil {
return ""
}
return t.UTC().Format(time.RFC3339Nano)
},
}
}
// Authentication handlers
func (s *Server) handleLoginForm(c *gin.Context) {
c.HTML(http.StatusOK, "login.html", gin.H{
"title": "Sign In",
})
}
func (s *Server) handleLogin(c *gin.Context) {
logger := s.logger
if logger == nil {
logger = zap.NewNop()
}
var req struct {
ManagementToken string `form:"management_token" binding:"required"`
RememberMe bool `form:"remember_me"`
}
if err := c.Request.ParseForm(); err == nil {
// Only log field presence, not values
logger.Debug("raw POST form fields", zap.Strings("fields", getFormFieldNames(c.Request.PostForm)))
}
if err := c.ShouldBind(&req); err != nil {
logger.Warn("ShouldBind error", zap.Error(err))
c.HTML(http.StatusBadRequest, "login.html", gin.H{
"title": "Sign In",
"error": "Please enter your management token",
})
return
}
req.ManagementToken = strings.TrimSpace(req.ManagementToken)
logger.Info("login attempt", zap.String("token", obfuscateToken(req.ManagementToken)), zap.Bool("remember_me", req.RememberMe))
// Use injected or default token validation
validate := s.ValidateTokenWithAPI
if validate == nil {
validate = s.validateTokenWithAPI
}
if !validate(c.Request.Context(), req.ManagementToken) {
logger.Error("token validation failed", zap.String("token", obfuscateToken(req.ManagementToken)))
c.HTML(http.StatusUnauthorized, "login.html", gin.H{
"title": "Sign In",
"error": "Invalid management token. Please check your token and try again.",
})
return
}
// Set session cookie
session := sessions.Default(c)
session.Set("token", req.ManagementToken)
if req.RememberMe {
session.Options(sessions.Options{
MaxAge: 30 * 24 * 60 * 60, // 30 days
Path: "/",
HttpOnly: true,
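// Secure is false here, presumably to support plain-HTTP development setups;
// deployments behind HTTPS would want this enabled.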
Secure: false,
})
} else {
session.Options(sessions.Options{
MaxAge: 0, // no Max-Age attribute: session cookie that expires when the browser closes
Path: "/",
HttpOnly: true,
Secure: false,
})
}
if err := session.Save(); err != nil {
logger.Error("session save error", zap.Error(err))
}
logger.Info("session saved", zap.String("token", obfuscateToken(req.ManagementToken)), zap.Bool("remember_me", req.RememberMe))
// Redirect to dashboard
c.Redirect(http.StatusSeeOther, "/dashboard")
}
func (s *Server) handleLogout(c *gin.Context) {
session := sessions.Default(c)
session.Clear()
if err := session.Save(); err != nil {
s.logger.Error("session save error", zap.Error(err))
}
c.Redirect(http.StatusSeeOther, "/auth/login")
}
// authMiddleware checks for valid session
func (s *Server) authMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
session := sessions.Default(c)
token := session.Get("token")
if token == nil {
c.Redirect(http.StatusSeeOther, "/auth/login")
c.Abort()
return
}
// The token is not re-validated here; each request gets a fresh API client
// bound to the session token, and the Management API rejects invalid tokens.
apiClient := NewAPIClient(s.config.AdminUI.APIBaseURL, token.(string))
c.Set("apiClient", apiClient)
c.Next()
}
}
// validateTokenWithAPI validates a token against the Management API
func (s *Server) validateTokenWithAPI(ctx context.Context, token string) bool {
// Create an HTTP client to validate the token
client := &http.Client{Timeout: 10 * time.Second}
// Try to access a simple management endpoint with the token
req, err := http.NewRequestWithContext(ctx, "GET", s.config.AdminUI.APIBaseURL+"/manage/projects", nil)
if err != nil {
return false
}
// Add authorization header
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("Content-Type", "application/json")
// Make the request
resp, err := client.Do(req)
if err != nil {
return false
}
defer func() {
if err := resp.Body.Close(); err != nil {
s.logger.Warn("Error closing response body", zap.Error(err))
}
}()
// Valid token should return 200 OK (or other success status)
return resp.StatusCode >= 200 && resp.StatusCode < 300
}
// getFormFieldNames returns a slice of form field names for logging
func getFormFieldNames(form map[string][]string) []string {
fields := make([]string, 0, len(form))
for k := range form {
fields = append(fields, k)
}
return fields
}
// obfuscateToken returns a partially masked version of a token for logging
func obfuscateToken(token string) string {
return obfuscate.ObfuscateTokenSimple(token)
}
// Audit handlers
func (s *Server) handleAuditList(c *gin.Context) {
// Get API client from context
apiClientIface := c.MustGet("apiClient").(APIClientInterface)
// Parse query parameters for filtering
filters := make(map[string]string)
query := c.Request.URL.Query()
// Filter parameters: copy recognized query parameters into the filter map.
for _, key := range []string{
"action", "outcome", "project_id", "actor", "client_ip",
"request_id", "method", "path", "search", "start_time", "end_time",
} {
if v := query.Get(key); v != "" {
filters[key] = v
}
}
// Parse pagination
page := 1
if pageStr := query.Get("page"); pageStr != "" {
if p, err := strconv.Atoi(pageStr); err == nil && p > 0 {
page = p
}
}
pageSize := 20
if pageSizeStr := query.Get("page_size"); pageSizeStr != "" {
if ps, err := strconv.Atoi(pageSizeStr); err == nil && ps > 0 && ps <= 100 {
pageSize = ps
}
}
// Get audit events with forwarded browser metadata
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
events, pagination, err := apiClientIface.GetAuditEvents(ctx, filters, page, pageSize)
if err != nil {
log.Printf("Failed to get audit events: %v", err)
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": "Failed to load audit events",
"details": err.Error(),
})
return
}
c.HTML(http.StatusOK, "audit/list.html", gin.H{
"title": "Audit Events",
"active": "audit",
"events": events,
"pagination": pagination,
"filters": filters,
"query": query,
})
}
func (s *Server) handleAuditShow(c *gin.Context) {
// Get API client from context
apiClientIface := c.MustGet("apiClient").(APIClientInterface)
// Get audit event ID from URL
id := c.Param("id")
if id == "" {
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Invalid audit event ID",
"details": "Audit event ID is required",
})
return
}
// Get audit event with forwarded browser metadata
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
event, err := apiClientIface.GetAuditEvent(ctx, id)
if err != nil {
log.Printf("Failed to get audit event %s: %v", id, err)
if strings.Contains(err.Error(), "not found") {
c.HTML(http.StatusNotFound, "error.html", gin.H{
"error": "Audit event not found",
"details": fmt.Sprintf("Audit event with ID %s was not found", id),
})
} else {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": "Failed to load audit event",
"details": err.Error(),
})
}
return
}
c.HTML(http.StatusOK, "audit/show.html", gin.H{
"title": "Audit Event",
"active": "audit",
"event": event,
})
}
// Token edit/revoke handlers
func (s *Server) handleTokensEdit(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
tokenID := c.Param("token")
if tokenID == "" {
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Token ID is required",
})
return
}
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
token, err := apiClient.GetToken(ctx, tokenID)
if err != nil {
if strings.Contains(err.Error(), "404") {
c.HTML(http.StatusNotFound, "error.html", gin.H{
"error": "Token not found",
"details": fmt.Sprintf("Token %s was not found", tokenID),
})
} else {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": "Failed to load token",
"details": err.Error(),
})
}
return
}
// Get project for the token
project, err := apiClient.GetProject(ctx, token.ProjectID)
if err != nil {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": "Failed to load project",
"details": err.Error(),
})
return
}
c.HTML(http.StatusOK, "tokens/edit.html", gin.H{
"title": "Edit Token",
"active": "tokens",
"token": token,
"project": project,
"tokenID": tokenID,
})
}
func (s *Server) handleTokensUpdate(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
tokenID := c.Param("token")
if tokenID == "" {
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Token ID is required",
})
return
}
var req struct {
IsActive *bool `form:"is_active"`
MaxRequests string `form:"max_requests"`
}
if err := c.ShouldBind(&req); err != nil {
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Invalid form data",
"details": err.Error(),
})
return
}
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
var maxRequests *int
maxRequestsStr := strings.TrimSpace(req.MaxRequests)
if maxRequestsStr == "" {
// Empty input means "unset" (unlimited). We send 0 and let the management API
// normalize it to nil.
zero := 0
maxRequests = &zero
} else {
parsedMaxRequests, err := strconv.Atoi(maxRequestsStr)
if err != nil {
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Invalid form data",
"details": err.Error(),
})
return
}
if parsedMaxRequests < 0 {
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Invalid form data",
"details": "max_requests must be >= 0",
})
return
}
maxRequests = &parsedMaxRequests
}
_, err := apiClient.UpdateToken(ctx, tokenID, req.IsActive, maxRequests)
if err != nil {
if strings.Contains(err.Error(), "404") {
c.HTML(http.StatusNotFound, "error.html", gin.H{
"error": "Token not found",
"details": fmt.Sprintf("Token %s was not found", tokenID),
})
} else {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": "Failed to update token",
"details": err.Error(),
})
}
return
}
c.Redirect(http.StatusSeeOther, fmt.Sprintf("/tokens/%s", tokenID))
}
func (s *Server) handleTokensRevoke(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
tokenID := c.Param("token")
if tokenID == "" {
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Token ID is required",
})
return
}
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
err := apiClient.RevokeToken(ctx, tokenID)
if err != nil {
if strings.Contains(err.Error(), "404") {
c.HTML(http.StatusNotFound, "error.html", gin.H{
"error": "Token not found",
"details": fmt.Sprintf("Token %s was not found", tokenID),
})
} else {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": "Failed to revoke token",
"details": err.Error(),
})
}
return
}
c.Redirect(http.StatusSeeOther, "/tokens")
}
// handleTokensPostOverride routes POST requests with _method overrides for token actions.
func (s *Server) handleTokensPostOverride(c *gin.Context) {
// Parse form to access _method
if err := c.Request.ParseForm(); err != nil {
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Failed to parse form data",
})
return
}
method := c.PostForm("_method")
switch strings.ToUpper(method) {
case http.MethodPut:
s.handleTokensUpdate(c)
return
case http.MethodDelete:
s.handleTokensRevoke(c)
return
default:
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Unsupported method override for token action",
})
return
}
}
// Project bulk revoke handler
func (s *Server) handleProjectsBulkRevoke(c *gin.Context) {
// Get API client from context
apiClient := c.MustGet("apiClient").(APIClientInterface)
projectID := c.Param("id")
if projectID == "" {
c.HTML(http.StatusBadRequest, "error.html", gin.H{
"error": "Project ID is required",
})
return
}
ctx := context.WithValue(c.Request.Context(), ctxKeyForwardedUA, c.Request.UserAgent())
if ip := c.Request.Header.Get("X-Forwarded-For"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, strings.Split(ip, ",")[0])
} else if ip := c.Request.Header.Get("X-Real-IP"); ip != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedIP, ip)
}
if ref := c.Request.Referer(); ref != "" {
ctx = context.WithValue(ctx, ctxKeyForwardedReferer, ref)
}
err := apiClient.RevokeProjectTokens(ctx, projectID)
if err != nil {
if strings.Contains(err.Error(), "404") {
c.HTML(http.StatusNotFound, "error.html", gin.H{
"error": "Project not found",
"details": fmt.Sprintf("Project %s was not found", projectID),
})
} else {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{
"error": "Failed to revoke project tokens",
"details": err.Error(),
})
}
return
}
c.Redirect(http.StatusSeeOther, fmt.Sprintf("/projects/%s", projectID))
}
// Package api provides shared utility functions for CLI and API clients.
package api
import (
"fmt"
"os"
"strings"
"time"
"github.com/spf13/cobra"
)
// ObfuscateKey obfuscates a sensitive key for display.
func ObfuscateKey(key string) string {
if len(key) <= 8 {
return "****"
}
return key[:4] + strings.Repeat("*", len(key)-8) + key[len(key)-4:]
}
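// For example, a hypothetical 12-character key keeps its first and last four
// characters, while short keys are fully masked:
//
//	api.ObfuscateKey("sk-abcdefghi") // "sk-a****fghi"
//	api.ObfuscateKey("short")        // "****"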
// ParseTimeHeader parses a time header in RFC3339Nano format,
// returning the zero time if the value is empty or malformed.
func ParseTimeHeader(val string) time.Time {
if val == "" {
return time.Time{}
}
t, err := time.Parse(time.RFC3339Nano, val)
if err != nil {
return time.Time{}
}
return t
}
// GetManagementToken gets the management token from a cobra command or environment.
func GetManagementToken(cmd *cobra.Command) (string, error) {
mgmtToken, _ := cmd.Flags().GetString("management-token")
if mgmtToken == "" {
mgmtToken = os.Getenv("MANAGEMENT_TOKEN")
}
if mgmtToken == "" {
return "", fmt.Errorf("management token is required (set MANAGEMENT_TOKEN env or use --management-token)")
}
return mgmtToken, nil
}
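// A minimal sketch of consuming this helper from a cobra command (the command
// wiring below is illustrative; only the "management-token" flag name is
// taken from the code above):
//
//	cmd := &cobra.Command{
//		Use: "projects",
//		RunE: func(cmd *cobra.Command, args []string) error {
//			token, err := api.GetManagementToken(cmd)
//			if err != nil {
//				return err
//			}
//			_ = token // pass to an API client
//			return nil
//		},
//	}
//	cmd.Flags().String("management-token", "", "Management API token")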
package audit
import (
"context"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"sync"
)
// Logger handles writing audit events to multiple backends (file and database)
// with immutable, append-only semantics. It is safe for concurrent use; file
// writes must succeed, while database write failures are reported but do not
// fail the call.
type Logger struct {
file *os.File
writer io.Writer
mutex sync.Mutex
path string
dbStore DatabaseStore
dbEnabled bool
}
// DatabaseStore defines the interface for database audit storage
type DatabaseStore interface {
StoreAuditEvent(ctx context.Context, event *Event) error
}
// LoggerConfig holds configuration for the audit logger
type LoggerConfig struct {
// FilePath is the path to the audit log file
FilePath string
// CreateDir determines whether to create parent directories if they don't exist
CreateDir bool
// DatabaseStore is an optional database backend for audit events
DatabaseStore DatabaseStore
// EnableDatabase determines whether to store events in database
EnableDatabase bool
}
// NewLogger creates a new audit logger that writes to the specified file.
// If createDir is true, it will create parent directories if they don't exist.
// If a database store is provided, events will also be persisted to the database.
func NewLogger(config LoggerConfig) (*Logger, error) {
if config.FilePath == "" {
return nil, fmt.Errorf("audit log file path cannot be empty")
}
// Create parent directories if needed
if config.CreateDir {
dir := filepath.Dir(config.FilePath)
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create audit log directory: %w", err)
}
}
// Open file for appending with appropriate permissions
file, err := os.OpenFile(config.FilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
return nil, fmt.Errorf("failed to open audit log file: %w", err)
}
return &Logger{
file: file,
writer: file,
path: config.FilePath,
dbStore: config.DatabaseStore,
dbEnabled: config.EnableDatabase && config.DatabaseStore != nil,
}, nil
}
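// Usage sketch (file-only logging; the path is illustrative):
//
//	logger, err := audit.NewLogger(audit.LoggerConfig{
//		FilePath:  "./data/audit.log",
//		CreateDir: true,
//	})
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer func() { _ = logger.Close() }()
//	_ = logger.Log(audit.NewEvent(audit.ActionTokenCreate, audit.ActorAdmin, audit.ResultSuccess))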
// Log writes an audit event to the file backend and, when enabled, the database backend.
// Events are written as JSON lines (JSONL format) for easy parsing.
// This method is thread-safe.
func (l *Logger) Log(event *Event) error {
if event == nil {
return fmt.Errorf("audit event cannot be nil")
}
l.mutex.Lock()
defer l.mutex.Unlock()
// Write to file backend
if err := l.logToFile(event); err != nil {
return fmt.Errorf("failed to log to file: %w", err)
}
// Write to database backend if enabled
if l.dbEnabled {
// Use a background context for database operations; no timeout is applied here.
ctx := context.Background()
if err := l.dbStore.StoreAuditEvent(ctx, event); err != nil {
// Avoid hard dependency on a logger; keep stdout warning for now but
// ensure it's concise. In production, this would be wired to a structured logger.
_, _ = io.WriteString(os.Stdout, "audit: failed to store event in database\n")
}
}
return nil
}
// logToFile writes an audit event to the file backend
func (l *Logger) logToFile(event *Event) error {
// Encode event as JSON
data, err := json.Marshal(event)
if err != nil {
return fmt.Errorf("failed to marshal audit event: %w", err)
}
// Write JSON line followed by newline
if _, err := l.writer.Write(data); err != nil {
return fmt.Errorf("failed to write audit event: %w", err)
}
if _, err := l.writer.Write([]byte("\n")); err != nil {
return fmt.Errorf("failed to write newline: %w", err)
}
// Sync to ensure data is written to disk
if syncer, ok := l.writer.(interface{ Sync() error }); ok {
if err := syncer.Sync(); err != nil {
return fmt.Errorf("failed to sync audit log: %w", err)
}
}
return nil
}
// Close closes the audit log file.
// After calling Close, the logger should not be used for logging.
func (l *Logger) Close() error {
l.mutex.Lock()
defer l.mutex.Unlock()
if l.file != nil {
err := l.file.Close()
l.file = nil
return err
}
return nil
}
// GetPath returns the file path of the audit log
func (l *Logger) GetPath() string {
return l.path
}
// NewNullLogger creates a logger that discards all events.
// Useful for testing or when audit logging is disabled.
func NewNullLogger() *Logger {
return &Logger{
writer: io.Discard,
dbEnabled: false,
}
}
// Package audit provides audit logging functionality for security-sensitive events
// in the LLM proxy. It implements a separate audit sink with immutable semantics
// for compliance and security investigations.
package audit
import (
"time"
"github.com/sofatutor/llm-proxy/internal/obfuscate"
)
// Event represents a security audit event with canonical fields.
// All audit events must include these core fields for compliance and investigation purposes.
type Event struct {
// Timestamp when the event occurred (serialized as RFC 3339 in JSON)
Timestamp time.Time `json:"timestamp"`
// Action describes what operation was performed (e.g., "token.create", "project.delete")
Action string `json:"action"`
// Actor identifies who performed the action (user ID, token ID, or system)
Actor string `json:"actor"`
// ProjectID identifies which project was affected (if applicable)
ProjectID string `json:"project_id,omitempty"`
// RequestID for correlation with request logs
RequestID string `json:"request_id,omitempty"`
// CorrelationID for tracing across services
CorrelationID string `json:"correlation_id,omitempty"`
// ClientIP is the IP address of the client making the request
ClientIP string `json:"client_ip,omitempty"`
// Result indicates success or failure of the operation
Result ResultType `json:"result"`
// Details contains additional context about the event (no secrets)
Details map[string]interface{} `json:"details,omitempty"`
}
// ResultType represents the outcome of an audited operation
type ResultType string
const (
// ResultSuccess indicates the operation completed successfully
ResultSuccess ResultType = "success"
// ResultFailure indicates the operation failed
ResultFailure ResultType = "failure"
// ResultDenied indicates the operation was denied (e.g., due to policy)
ResultDenied ResultType = "denied"
// ResultError indicates the operation encountered an error
ResultError ResultType = "error"
)
// Action constants for standardized audit event types
const (
// Token lifecycle actions
ActionTokenCreate = "token.create"
ActionTokenRead = "token.read"
ActionTokenUpdate = "token.update"
ActionTokenRevoke = "token.revoke"
ActionTokenRevokeBatch = "token.revoke_batch"
ActionTokenDelete = "token.delete"
ActionTokenList = "token.list"
ActionTokenValidate = "token.validate"
ActionTokenAccess = "token.access"
// Project lifecycle actions
ActionProjectCreate = "project.create"
ActionProjectRead = "project.read"
ActionProjectUpdate = "project.update"
ActionProjectDelete = "project.delete"
ActionProjectList = "project.list"
// Proxy request actions
ActionProxyRequest = "proxy_request"
// Admin actions
ActionAdminLogin = "admin.login"
ActionAdminLogout = "admin.logout"
ActionAdminAccess = "admin.access"
// Audit actions
ActionAuditList = "audit.list"
ActionAuditShow = "audit.show"
// Cache actions
ActionCachePurge = "cache.purge"
)
// Actor types for common audit actors
const (
ActorSystem = "system"
ActorAnonymous = "anonymous"
ActorAdmin = "admin"
ActorManagement = "management_api"
)
// NewEvent creates a new audit event with the specified action, actor, and result.
// The timestamp is automatically set to the current time in UTC.
func NewEvent(action string, actor string, result ResultType) *Event {
return &Event{
Timestamp: time.Now().UTC(),
Action: action,
Actor: actor,
Result: result,
Details: make(map[string]interface{}),
}
}
// WithProjectID sets the project ID for the audit event
func (e *Event) WithProjectID(projectID string) *Event {
e.ProjectID = projectID
return e
}
// WithRequestID sets the request ID for correlation with request logs
func (e *Event) WithRequestID(requestID string) *Event {
e.RequestID = requestID
return e
}
// WithCorrelationID sets the correlation ID for tracing across services
func (e *Event) WithCorrelationID(correlationID string) *Event {
e.CorrelationID = correlationID
return e
}
// WithClientIP sets the client IP address for the audit event
func (e *Event) WithClientIP(clientIP string) *Event {
e.ClientIP = clientIP
return e
}
// WithDetail adds a detail key-value pair to the audit event.
// Secrets and sensitive information should be obfuscated before calling this method.
func (e *Event) WithDetail(key string, value interface{}) *Event {
if e.Details == nil {
e.Details = make(map[string]interface{})
}
e.Details[key] = value
return e
}
// WithTokenID adds an obfuscated token ID to the audit event details
func (e *Event) WithTokenID(token string) *Event {
return e.WithDetail("token_id", obfuscate.ObfuscateTokenGeneric(token))
}
// WithError adds error information to the audit event details
func (e *Event) WithError(err error) *Event {
if err != nil {
return e.WithDetail("error", err.Error())
}
return e
}
// WithIPAddress adds the client IP address to the audit event details
// Deprecated: Use WithClientIP instead for first-class IP field support
func (e *Event) WithIPAddress(ip string) *Event {
return e.WithDetail("ip_address", ip)
}
// WithUserAgent adds the user agent to the audit event details
func (e *Event) WithUserAgent(userAgent string) *Event {
return e.WithDetail("user_agent", userAgent)
}
// WithHTTPMethod adds the HTTP method to the audit event details
func (e *Event) WithHTTPMethod(method string) *Event {
return e.WithDetail("http_method", method)
}
// WithEndpoint adds the API endpoint to the audit event details
func (e *Event) WithEndpoint(endpoint string) *Event {
return e.WithDetail("endpoint", endpoint)
}
// WithDuration adds the operation duration to the audit event details
func (e *Event) WithDuration(duration time.Duration) *Event {
return e.WithDetail("duration_ms", duration.Milliseconds())
}
// WithReason adds a reason/description to the audit event details
func (e *Event) WithReason(reason string) *Event {
return e.WithDetail("reason", reason)
}
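// Taken together, the builder methods support fluent construction at call
// sites. A sketch (IDs and values are illustrative; rawToken stands in for a
// secret that WithTokenID obfuscates before storing):
//
//	event := audit.NewEvent(audit.ActionTokenRevoke, audit.ActorAdmin, audit.ResultSuccess).
//		WithProjectID("proj-123").
//		WithClientIP("203.0.113.7").
//		WithTokenID(rawToken).
//		WithDuration(42 * time.Millisecond)
//	_ = logger.Log(event)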
// Package client provides HTTP client functionality for communicating with the LLM Proxy API.
package client
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"time"
"github.com/chzyer/readline"
)
// ChatMessage represents a message in the chat
type ChatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
// ChatRequest represents a request to the chat API
type ChatRequest struct {
Model string `json:"model"`
Messages []ChatMessage `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Stream bool `json:"stream"`
}
// ChatResponse represents a response from the chat API
type ChatResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
Model string `json:"model"`
Choices []struct {
Index int `json:"index"`
Message ChatMessage `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
// ChatCompletionStreamResponse represents a chunked response in a stream
type ChatCompletionStreamResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
Model string `json:"model"`
Choices []struct {
Index int `json:"index"`
Delta struct {
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
} `json:"delta"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
// ChatClient handles communication with the LLM Proxy chat API
type ChatClient struct {
BaseURL string
Token string
HTTPClient *http.Client
}
// ChatOptions configures chat request parameters
type ChatOptions struct {
Model string
Temperature float64
MaxTokens int
UseStreaming bool
VerboseMode bool
}
// NewChatClient creates a new chat client
func NewChatClient(baseURL, token string) *ChatClient {
return &ChatClient{
BaseURL: baseURL,
Token: token,
HTTPClient: &http.Client{
Timeout: 60 * time.Second,
},
}
}
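// Usage sketch (base URL, token, and model are placeholders; a nil readline
// instance makes streamed content go to stdout):
//
//	c := client.NewChatClient("http://localhost:8080", "tok-example")
//	resp, err := c.SendChatRequest(
//		[]client.ChatMessage{{Role: "user", Content: "Hello"}},
//		client.ChatOptions{Model: "gpt-4o-mini"},
//		nil,
//	)
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(resp.Choices[0].Message.Content)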
// SendChatRequest sends a chat request and returns the response
func (c *ChatClient) SendChatRequest(messages []ChatMessage, options ChatOptions, readline *readline.Instance) (*ChatResponse, error) {
if c.Token == "" {
return nil, fmt.Errorf("token is required")
}
// Validate proxy URL
if _, err := url.Parse(c.BaseURL); err != nil {
return nil, fmt.Errorf("invalid proxy URL: %w", err)
}
request := ChatRequest{
Model: options.Model,
Messages: messages,
Temperature: options.Temperature,
MaxTokens: options.MaxTokens,
Stream: options.UseStreaming,
}
jsonData, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
if options.VerboseMode {
fmt.Printf("Request: %s\n", string(jsonData))
}
req, err := http.NewRequest("POST", c.BaseURL+"/v1/chat/completions", bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.Token)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer func() {
if err := resp.Body.Close(); err != nil {
fmt.Fprintf(os.Stderr, "failed to close response body: %v\n", err)
}
}()
if resp.StatusCode >= 400 {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, string(body))
}
if options.UseStreaming {
return c.handleStreamingResponse(resp, readline, options.VerboseMode)
}
return c.handleNonStreamingResponse(resp, options.VerboseMode)
}
// handleStreamingResponse processes streaming chat responses
func (c *ChatClient) handleStreamingResponse(resp *http.Response, readline *readline.Instance, verbose bool) (*ChatResponse, error) {
scanner := bufio.NewScanner(resp.Body)
var finalResponse *ChatResponse
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
break
}
var streamResp ChatCompletionStreamResponse
if err := json.Unmarshal([]byte(data), &streamResp); err != nil {
if verbose {
fmt.Printf("Failed to parse stream data: %v\n", err)
}
continue
}
if len(streamResp.Choices) > 0 {
choice := streamResp.Choices[0]
if choice.Delta.Content != "" {
if readline != nil && readline.Config.Stdout != nil {
if _, err := readline.Config.Stdout.Write([]byte(choice.Delta.Content)); err != nil {
return nil, fmt.Errorf("failed to write streaming content: %w", err)
}
} else {
fmt.Print(choice.Delta.Content)
}
}
// Convert to final response format
if finalResponse == nil {
finalResponse = &ChatResponse{
ID: streamResp.ID,
Object: streamResp.Object,
Created: streamResp.Created,
Model: streamResp.Model,
Choices: []struct {
Index int `json:"index"`
Message ChatMessage `json:"message"`
FinishReason string `json:"finish_reason"`
}{
{
Index: choice.Index,
Message: ChatMessage{
Role: "assistant",
Content: "",
},
FinishReason: choice.FinishReason,
},
},
Usage: streamResp.Usage,
}
}
// Accumulate content
finalResponse.Choices[0].Message.Content += choice.Delta.Content
if choice.FinishReason != "" {
finalResponse.Choices[0].FinishReason = choice.FinishReason
}
}
}
if readline != nil && readline.Config.Stdout != nil {
if _, err := readline.Config.Stdout.Write([]byte("\n")); err != nil {
return nil, fmt.Errorf("failed to write newline after streaming: %w", err)
}
} else {
fmt.Println()
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("stream reading error: %w", err)
}
if finalResponse == nil {
return nil, fmt.Errorf("no response received from stream")
}
return finalResponse, nil
}
// handleNonStreamingResponse processes non-streaming chat responses
func (c *ChatClient) handleNonStreamingResponse(resp *http.Response, verbose bool) (*ChatResponse, error) {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if verbose {
fmt.Printf("Response: %s\n", string(body))
}
var chatResp ChatResponse
if err := json.Unmarshal(body, &chatResp); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
return &chatResp, nil
}
// Package config handles application configuration loading and validation
// from environment variables, providing a type-safe configuration structure.
package config
import (
"fmt"
"os"
"strconv"
"strings"
"time"
)
// Config holds all application configuration values loaded from environment variables.
// It provides a centralized, type-safe way to access configuration throughout the application.
type Config struct {
// Server configuration
ListenAddr string // Address to listen on (e.g., ":8080")
RequestTimeout time.Duration // Timeout for upstream API requests
MaxRequestSize int64 // Maximum size of incoming requests in bytes
MaxConcurrentReqs int // Maximum number of concurrent requests
// Environment
APIEnv string // API environment: 'production', 'development', 'test'
// Database configuration
DatabasePath string // Path to the SQLite database file
DatabasePoolSize int // Number of connections in the database pool
// Authentication
ManagementToken string // Token for admin operations, used to access the management API
// API Provider configuration
APIConfigPath string // Path to the API providers configuration file
DefaultAPIProvider string // Default API provider to use
OpenAIAPIURL string // Base URL for OpenAI API (legacy support)
EnableStreaming bool // Whether to enable streaming responses from APIs
// Admin UI settings
AdminUIPath string // Base path for the admin UI
AdminUI AdminUIConfig // Admin UI server configuration
// Logging
LogLevel string // Log level (debug, info, warn, error)
LogFormat string // Log format (json, text)
LogFile string // Path to log file (empty for stdout)
// Audit logging
AuditEnabled bool // Enable audit logging for security events
AuditLogFile string // Path to audit log file (empty to disable)
AuditCreateDir bool // Create parent directories for audit log file
AuditStoreInDB bool // Store audit events in database for analytics
// Observability middleware
ObservabilityEnabled bool // Enable async observability middleware
ObservabilityBufferSize int // Buffer size for in-memory event bus
// ObservabilityMaxRequestBodyBytes caps how many bytes of request bodies are captured for observability events.
// This is only for the async event payload and does not affect the proxied request body.
ObservabilityMaxRequestBodyBytes int64
// ObservabilityMaxResponseBodyBytes caps how many bytes of response bodies are captured for observability events.
// This is only for the async event payload and does not affect the proxied response body.
ObservabilityMaxResponseBodyBytes int64
// CORS settings
CORSAllowedOrigins []string // Allowed origins for CORS
CORSAllowedMethods []string // Allowed methods for CORS
CORSAllowedHeaders []string // Allowed headers for CORS
CORSMaxAge time.Duration // Max age for CORS preflight responses
// Rate limiting
GlobalRateLimit int // Maximum requests per minute globally
IPRateLimit int // Maximum requests per minute per IP
// Distributed rate limiting
DistributedRateLimitEnabled bool // Enable Redis-backed distributed rate limiting
DistributedRateLimitPrefix string // Redis key prefix for rate limit counters
DistributedRateLimitKeySecret string // HMAC secret for hashing token IDs in Redis keys (security)
DistributedRateLimitWindow time.Duration // Sliding window duration for rate limiting
DistributedRateLimitMax int // Maximum requests per window
DistributedRateLimitFallback bool // Enable fallback to in-memory when Redis unavailable
// Monitoring
EnableMetrics bool // Whether to enable a lightweight metrics endpoint (provider-agnostic)
MetricsPath string // Path for metrics endpoint
// Cleanup
TokenCleanupInterval time.Duration // Interval for cleaning up expired tokens
// Project active guard configuration
EnforceProjectActive bool // Whether to enforce project active status (default: true)
ActiveCacheTTL time.Duration // TTL for project active status cache (e.g., 5s)
ActiveCacheMax int // Maximum entries in project active status cache (e.g., 10000)
// API key caching (hot path: per-request upstream auth lookup)
APIKeyCacheTTL time.Duration // TTL for per-project upstream API key cache (e.g., 30s)
APIKeyCacheMax int // Maximum entries for upstream API key cache (e.g., 10000)
// Event bus configuration
EventBusBackend string // Backend for event bus: "redis-streams" or "in-memory"
RedisAddr string // Redis server address (e.g., "localhost:6379")
RedisDB int // Redis database number (default: 0)
// Redis Streams configuration (when EventBusBackend = "redis-streams")
RedisStreamKey string // Redis stream key name (default: "llm-proxy-events")
RedisConsumerGroup string // Consumer group name (default: "llm-proxy-dispatchers")
RedisConsumerName string // Consumer name within the group (should be unique per instance)
RedisStreamMaxLen int64 // Max stream length (0 = unlimited, default: 10000)
RedisStreamBlockTime time.Duration // Block timeout for reading (default: 5s)
RedisStreamClaimTime time.Duration // Min idle time before claiming pending msgs (default: 30s)
RedisStreamBatchSize int64 // Batch size for reading messages (default: 100)
// Cache stats aggregation
CacheStatsBufferSize int // Buffer size for async cache stats aggregation (default: 1000)
// Usage stats aggregation
UsageStatsBufferSize int // Buffer size for async usage stats aggregation (default: 1000)
}
// AdminUIConfig holds configuration for the Admin UI server
type AdminUIConfig struct {
ListenAddr string // Address for admin UI server to listen on
APIBaseURL string // Base URL of the Management API
ManagementToken string // Token for accessing Management API
Enabled bool // Whether Admin UI is enabled
TemplateDir string // Directory for HTML templates (default: "web/templates")
}
// New creates a new configuration with values from environment variables.
// It applies default values where environment variables are not set,
// and validates required configuration settings.
//
// Returns a populated Config struct and nil error on success,
// or nil and an error if validation fails.
func New() (*Config, error) {
config := &Config{
// Server defaults
ListenAddr: getEnvString("LISTEN_ADDR", ":8080"),
RequestTimeout: getEnvDuration("REQUEST_TIMEOUT", 30*time.Second),
MaxRequestSize: getEnvInt64("MAX_REQUEST_SIZE", 10*1024*1024), // 10MB
MaxConcurrentReqs: getEnvInt("MAX_CONCURRENT_REQUESTS", 100),
// Environment
APIEnv: getEnvString("API_ENV", "development"),
// Database defaults
DatabasePath: getEnvString("DATABASE_PATH", "./data/llm-proxy.db"),
DatabasePoolSize: getEnvInt("DATABASE_POOL_SIZE", 10),
// Authentication
ManagementToken: getEnvString("MANAGEMENT_TOKEN", ""),
// API Provider settings
APIConfigPath: getEnvString("API_CONFIG_PATH", "./config/api_providers.yaml"),
DefaultAPIProvider: getEnvString("DEFAULT_API_PROVIDER", "openai"),
OpenAIAPIURL: getEnvString("OPENAI_API_URL", "https://api.openai.com"),
EnableStreaming: getEnvBool("ENABLE_STREAMING", true),
// Admin UI settings
AdminUIPath: getEnvString("ADMIN_UI_PATH", "/admin"),
AdminUI: AdminUIConfig{
ListenAddr: getEnvString("ADMIN_UI_LISTEN_ADDR", ":8081"),
APIBaseURL: getEnvString("ADMIN_UI_API_BASE_URL", "http://localhost:8080"),
ManagementToken: getEnvString("MANAGEMENT_TOKEN", ""),
Enabled: getEnvBool("ADMIN_UI_ENABLED", true),
TemplateDir: getEnvString("ADMIN_UI_TEMPLATE_DIR", "web/templates"),
},
// Logging defaults
LogLevel: getEnvString("LOG_LEVEL", "info"),
LogFormat: getEnvString("LOG_FORMAT", "json"),
LogFile: getEnvString("LOG_FILE", ""),
// Audit logging defaults
AuditEnabled: getEnvBool("AUDIT_ENABLED", true),
AuditLogFile: getEnvString("AUDIT_LOG_FILE", "./data/audit.log"),
AuditCreateDir: getEnvBool("AUDIT_CREATE_DIR", true),
AuditStoreInDB: getEnvBool("AUDIT_STORE_IN_DB", true),
ObservabilityEnabled: getEnvBool("OBSERVABILITY_ENABLED", true),
ObservabilityBufferSize: getEnvInt("OBSERVABILITY_BUFFER_SIZE", 1000),
ObservabilityMaxRequestBodyBytes: getEnvInt64("OBSERVABILITY_MAX_REQUEST_BODY_BYTES", 64*1024), // 64KB
ObservabilityMaxResponseBodyBytes: getEnvInt64("OBSERVABILITY_MAX_RESPONSE_BODY_BYTES", 256*1024), // 256KB
// CORS defaults
CORSAllowedOrigins: getEnvStringSlice("CORS_ALLOWED_ORIGINS", []string{"*"}),
CORSAllowedMethods: getEnvStringSlice("CORS_ALLOWED_METHODS", []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}),
CORSAllowedHeaders: getEnvStringSlice("CORS_ALLOWED_HEADERS", []string{"Authorization", "Content-Type"}),
CORSMaxAge: getEnvDuration("CORS_MAX_AGE", 24*time.Hour),
// Rate limiting defaults
GlobalRateLimit: getEnvInt("GLOBAL_RATE_LIMIT", 100),
IPRateLimit: getEnvInt("IP_RATE_LIMIT", 30),
// Distributed rate limiting defaults
DistributedRateLimitEnabled: getEnvBool("DISTRIBUTED_RATE_LIMIT_ENABLED", false),
DistributedRateLimitPrefix: getEnvString("DISTRIBUTED_RATE_LIMIT_PREFIX", "ratelimit:"),
DistributedRateLimitKeySecret: getEnvString("DISTRIBUTED_RATE_LIMIT_KEY_SECRET", ""),
DistributedRateLimitWindow: getEnvDuration("DISTRIBUTED_RATE_LIMIT_WINDOW", time.Minute),
DistributedRateLimitMax: getEnvInt("DISTRIBUTED_RATE_LIMIT_MAX", 60),
DistributedRateLimitFallback: getEnvBool("DISTRIBUTED_RATE_LIMIT_FALLBACK", true),
// Monitoring defaults
EnableMetrics: getEnvBool("ENABLE_METRICS", true),
MetricsPath: getEnvString("METRICS_PATH", "/metrics"),
// Cleanup defaults
TokenCleanupInterval: getEnvDuration("TOKEN_CLEANUP_INTERVAL", time.Hour),
// Project active guard defaults
EnforceProjectActive: getEnvBool("LLM_PROXY_ENFORCE_PROJECT_ACTIVE", true),
ActiveCacheTTL: getEnvDuration("LLM_PROXY_ACTIVE_CACHE_TTL", 5*time.Second),
ActiveCacheMax: getEnvInt("LLM_PROXY_ACTIVE_CACHE_MAX", 10000),
// API key caching defaults
APIKeyCacheTTL: getEnvDuration("LLM_PROXY_API_KEY_CACHE_TTL", 30*time.Second),
APIKeyCacheMax: getEnvInt("LLM_PROXY_API_KEY_CACHE_MAX", 10000),
// Event bus configuration
EventBusBackend: getEnvString("LLM_PROXY_EVENT_BUS", "redis-streams"),
RedisAddr: getEnvString("REDIS_ADDR", "localhost:6379"),
RedisDB: getEnvInt("REDIS_DB", 0),
// Redis Streams configuration
RedisStreamKey: getEnvString("REDIS_STREAM_KEY", "llm-proxy-events"),
RedisConsumerGroup: getEnvString("REDIS_CONSUMER_GROUP", "llm-proxy-dispatchers"),
RedisConsumerName: getEnvString("REDIS_CONSUMER_NAME", ""),
RedisStreamMaxLen: getEnvInt64("REDIS_STREAM_MAX_LEN", 10000),
RedisStreamBlockTime: getEnvDuration("REDIS_STREAM_BLOCK_TIME", 5*time.Second),
RedisStreamClaimTime: getEnvDuration("REDIS_STREAM_CLAIM_TIME", 30*time.Second),
RedisStreamBatchSize: getEnvInt64("REDIS_STREAM_BATCH_SIZE", 100),
// Cache stats aggregation
CacheStatsBufferSize: getEnvInt("CACHE_STATS_BUFFER_SIZE", 1000),
// Usage stats aggregation
// Backwards-compatible: if USAGE_STATS_BUFFER_SIZE is not set, re-use CACHE_STATS_BUFFER_SIZE.
UsageStatsBufferSize: getEnvInt("USAGE_STATS_BUFFER_SIZE", getEnvInt("CACHE_STATS_BUFFER_SIZE", 1000)),
}
// Validate required settings
if config.ManagementToken == "" {
return nil, fmt.Errorf("MANAGEMENT_TOKEN environment variable is required")
}
return config, nil
}
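// Bootstrapping sketch (MANAGEMENT_TOKEN is the only required variable; the
// values below are placeholders):
//
//	_ = os.Setenv("MANAGEMENT_TOKEN", "example-token")
//	_ = os.Setenv("LISTEN_ADDR", ":9090")
//	cfg, err := config.New()
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(cfg.ListenAddr) // ":9090"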
// getEnvString retrieves a string value from an environment variable,
// falling back to the provided default value if the variable is not set.
func getEnvString(key, defaultValue string) string {
if value, exists := os.LookupEnv(key); exists {
return value
}
return defaultValue
}
// getEnvBool retrieves a boolean value from an environment variable,
// falling back to the provided default value if the variable is not set
// or cannot be parsed as a boolean.
func getEnvBool(key string, defaultValue bool) bool {
if value, exists := os.LookupEnv(key); exists {
parsedValue, err := strconv.ParseBool(value)
if err == nil {
return parsedValue
}
}
return defaultValue
}
// getEnvInt retrieves an integer value from an environment variable,
// falling back to the provided default value if the variable is not set
// or cannot be parsed as an integer.
func getEnvInt(key string, defaultValue int) int {
if value, exists := os.LookupEnv(key); exists {
parsedValue, err := strconv.Atoi(value)
if err == nil {
return parsedValue
}
}
return defaultValue
}
// getEnvInt64 retrieves a 64-bit integer value from an environment variable,
// falling back to the provided default value if the variable is not set
// or cannot be parsed as a 64-bit integer.
func getEnvInt64(key string, defaultValue int64) int64 {
if value, exists := os.LookupEnv(key); exists {
parsedValue, err := strconv.ParseInt(value, 10, 64)
if err == nil {
return parsedValue
}
}
return defaultValue
}
// getEnvDuration retrieves a duration value from an environment variable,
// falling back to the provided default value if the variable is not set
// or cannot be parsed as a duration.
func getEnvDuration(key string, defaultValue time.Duration) time.Duration {
if value, exists := os.LookupEnv(key); exists {
parsedValue, err := time.ParseDuration(value)
if err == nil {
return parsedValue
}
}
return defaultValue
}
// getEnvStringSlice retrieves a comma-separated string value from an environment variable
// and splits it into a slice of strings, falling back to the provided default value
// if the variable is not set or is empty.
func getEnvStringSlice(key string, defaultValue []string) []string {
if value, exists := os.LookupEnv(key); exists && value != "" {
parts := strings.Split(value, ",")
for i := range parts {
parts[i] = strings.TrimSpace(parts[i])
}
return parts
}
return defaultValue
}
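// For example, a comma-separated value with stray whitespace is trimmed per
// element (the variable value is illustrative):
//
//	// CORS_ALLOWED_ORIGINS="https://a.example, https://b.example"
//	// → []string{"https://a.example", "https://b.example"}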
// LoadFromFile loads configuration from a file (placeholder for future
// YAML/JSON support); the path argument is currently ignored.
func LoadFromFile(path string) (*Config, error) {
// For now, return the default config; file parsing can be implemented later.
return DefaultConfig(), nil
}
// DefaultConfig returns a configuration with default values
func DefaultConfig() *Config {
return &Config{
// Server defaults
ListenAddr: ":8080",
RequestTimeout: 30 * time.Second,
MaxRequestSize: 10 * 1024 * 1024, // 10MB
MaxConcurrentReqs: 100,
// Environment
APIEnv: "development",
// Database defaults
DatabasePath: "./data/llm-proxy.db",
DatabasePoolSize: 10,
// API Provider settings
APIConfigPath: "./config/api_providers.yaml",
DefaultAPIProvider: "openai",
OpenAIAPIURL: "https://api.openai.com",
EnableStreaming: true,
// Admin UI settings
AdminUIPath: "/admin",
AdminUI: AdminUIConfig{
ListenAddr: ":8081",
APIBaseURL: "http://localhost:8080",
ManagementToken: "",
Enabled: true,
TemplateDir: "web/templates",
},
// Logging defaults
LogLevel: "info",
LogFormat: "json",
LogFile: "",
ObservabilityEnabled: true,
ObservabilityBufferSize: 1000,
// CORS defaults
CORSAllowedOrigins: []string{"*"},
CORSAllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
CORSAllowedHeaders: []string{"Authorization", "Content-Type"},
CORSMaxAge: 24 * time.Hour,
// Rate limiting defaults
GlobalRateLimit: 100,
IPRateLimit: 30,
// Distributed rate limiting defaults
DistributedRateLimitEnabled: false,
DistributedRateLimitPrefix: "ratelimit:",
DistributedRateLimitKeySecret: "",
DistributedRateLimitWindow: time.Minute,
DistributedRateLimitMax: 60,
DistributedRateLimitFallback: true,
// Monitoring defaults
EnableMetrics: true,
MetricsPath: "/metrics",
// Cleanup defaults
TokenCleanupInterval: time.Hour,
// Project active guard defaults
EnforceProjectActive: true,
ActiveCacheTTL: 5 * time.Second,
ActiveCacheMax: 10000,
// Event bus configuration
EventBusBackend: "redis-streams",
RedisAddr: "localhost:6379",
RedisDB: 0,
// Redis Streams configuration
RedisStreamKey: "llm-proxy-events",
RedisConsumerGroup: "llm-proxy-dispatchers",
RedisConsumerName: "",
RedisStreamMaxLen: 10000,
RedisStreamBlockTime: 5 * time.Second,
RedisStreamClaimTime: 30 * time.Second,
RedisStreamBatchSize: 100,
// Cache stats aggregation
CacheStatsBufferSize: 1000,
// Usage stats aggregation
UsageStatsBufferSize: 1000,
}
}
package config
import (
"os"
"strconv"
)
// EnvOrDefault returns the value of the environment variable if set and non-empty, otherwise the fallback.
func EnvOrDefault(key, fallback string) string {
if v := os.Getenv(key); v != "" {
return v
}
return fallback
}
// EnvIntOrDefault returns the int value of the environment variable if set and valid, otherwise the fallback.
func EnvIntOrDefault(key string, fallback int) int {
if v := os.Getenv(key); v != "" {
if i, err := strconv.Atoi(v); err == nil {
return i
}
}
return fallback
}
// EnvBoolOrDefault returns the bool value of the environment variable if set and valid, otherwise the fallback.
func EnvBoolOrDefault(key string, fallback bool) bool {
if v := os.Getenv(key); v != "" {
if b, err := strconv.ParseBool(v); err == nil {
return b
}
}
return fallback
}
// EnvFloat64OrDefault returns the float64 value of the environment variable if set and valid, otherwise the fallback.
func EnvFloat64OrDefault(key string, fallback float64) float64 {
if v := os.Getenv(key); v != "" {
if f, err := strconv.ParseFloat(v, 64); err == nil {
return f
}
}
return fallback
}
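// Sketch of the exported helpers in use (keys and fallbacks are illustrative):
//
//	addr := config.EnvOrDefault("REDIS_ADDR", "localhost:6379")
//	db := config.EnvIntOrDefault("REDIS_DB", 0)
//	verbose := config.EnvBoolOrDefault("VERBOSE", false)
//	ratio := config.EnvFloat64OrDefault("SAMPLE_RATIO", 1.0)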
package database
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"github.com/google/uuid"
"github.com/sofatutor/llm-proxy/internal/audit"
)
// AuditStore defines the interface for persisting audit events to database
type AuditStore interface {
StoreAuditEvent(ctx context.Context, event *audit.Event) error
ListAuditEvents(ctx context.Context, filters AuditEventFilters) ([]AuditEvent, error)
CountAuditEvents(ctx context.Context, filters AuditEventFilters) (int, error)
GetAuditEventByID(ctx context.Context, id string) (*AuditEvent, error)
}
// AuditEventFilters provides filtering options for audit event queries
type AuditEventFilters struct {
Action string
ClientIP string
ProjectID string
StartTime *string // RFC3339 format
EndTime *string // RFC3339 format
Outcome string
Actor string
RequestID string
CorrelationID string
Method string
Path string
Search string // Full-text search over reason/metadata
Limit int
Offset int
}
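// A sketch of a filtered, paginated query (filter values are illustrative;
// Limit and Offset drive pagination):
//
//	events, err := db.ListAuditEvents(ctx, database.AuditEventFilters{
//		Action:  "token.revoke",
//		Outcome: "success",
//		Limit:   50,
//		Offset:  0,
//	})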
// StoreAuditEvent persists an audit event to the database
func (d *DB) StoreAuditEvent(ctx context.Context, event *audit.Event) error {
if d == nil || d.db == nil {
return fmt.Errorf("database is nil")
}
if event == nil {
return fmt.Errorf("audit event cannot be nil")
}
// Generate UUID for the audit event
id := uuid.New().String()
// Convert metadata to JSON string if present
var metadataJSON *string
if len(event.Details) > 0 {
metadataBytes, err := json.Marshal(event.Details)
if err != nil {
return fmt.Errorf("failed to marshal audit event metadata: %w", err)
}
metadataStr := string(metadataBytes)
metadataJSON = &metadataStr
}
// Extract common fields from details for first-class columns
var method, path, userAgent, reason, tokenID *string
if event.Details != nil {
if v, ok := event.Details["http_method"].(string); ok {
method = &v
}
if v, ok := event.Details["endpoint"].(string); ok {
path = &v
}
if v, ok := event.Details["user_agent"].(string); ok {
userAgent = &v
}
if v, ok := event.Details["error"].(string); ok {
reason = &v
}
if v, ok := event.Details["token_id"].(string); ok {
tokenID = &v
}
}
// Convert optional string fields to pointers
var projectID, requestID, correlationID, clientIP *string
if event.ProjectID != "" {
projectID = &event.ProjectID
}
if event.RequestID != "" {
requestID = &event.RequestID
}
if event.CorrelationID != "" {
correlationID = &event.CorrelationID
}
if event.ClientIP != "" {
clientIP = &event.ClientIP
}
query := `
INSERT INTO audit_events (
id, timestamp, action, actor, project_id, request_id, correlation_id,
client_ip, method, path, user_agent, outcome, reason, token_id, metadata
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
_, err := d.ExecContextRebound(ctx, query,
id,
event.Timestamp,
event.Action,
event.Actor,
projectID,
requestID,
correlationID,
clientIP,
method,
path,
userAgent,
string(event.Result),
reason,
tokenID,
metadataJSON,
)
if err != nil {
return fmt.Errorf("failed to insert audit event: %w", err)
}
return nil
}
// ListAuditEvents retrieves audit events from the database with optional filtering
func (d *DB) ListAuditEvents(ctx context.Context, filters AuditEventFilters) ([]AuditEvent, error) {
query := "SELECT id, timestamp, action, actor, project_id, request_id, correlation_id, client_ip, method, path, user_agent, outcome, reason, token_id, metadata FROM audit_events WHERE 1=1"
args := []interface{}{}
// Apply filters
if filters.Action != "" {
query += " AND action = ?"
args = append(args, filters.Action)
}
if filters.ClientIP != "" {
query += " AND client_ip = ?"
args = append(args, filters.ClientIP)
}
if filters.ProjectID != "" {
query += " AND project_id = ?"
args = append(args, filters.ProjectID)
}
if filters.Outcome != "" {
query += " AND outcome = ?"
args = append(args, filters.Outcome)
}
if filters.Actor != "" {
query += " AND actor = ?"
args = append(args, filters.Actor)
}
if filters.RequestID != "" {
query += " AND request_id = ?"
args = append(args, filters.RequestID)
}
if filters.CorrelationID != "" {
query += " AND correlation_id = ?"
args = append(args, filters.CorrelationID)
}
if filters.Method != "" {
query += " AND method = ?"
args = append(args, filters.Method)
}
if filters.Path != "" {
query += " AND path = ?"
args = append(args, filters.Path)
}
if filters.Search != "" {
query += " AND (request_id LIKE ? OR correlation_id LIKE ? OR client_ip LIKE ? OR action LIKE ? OR actor LIKE ? OR method LIKE ? OR path LIKE ? OR reason LIKE ? OR metadata LIKE ?)"
searchPattern := "%" + filters.Search + "%"
args = append(args, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern)
}
if filters.StartTime != nil {
query += " AND timestamp >= datetime(?)"
args = append(args, *filters.StartTime)
}
if filters.EndTime != nil {
query += " AND timestamp <= datetime(?)"
args = append(args, *filters.EndTime)
}
// Order by timestamp descending
query += " ORDER BY timestamp DESC"
// Apply limit and offset
if filters.Limit > 0 {
query += " LIMIT ?"
args = append(args, filters.Limit)
if filters.Offset > 0 {
query += " OFFSET ?"
args = append(args, filters.Offset)
}
}
rows, err := d.QueryContextRebound(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("failed to query audit events: %w", err)
}
defer func() { _ = rows.Close() }()
var events []AuditEvent
for rows.Next() {
var event AuditEvent
err := rows.Scan(
&event.ID,
&event.Timestamp,
&event.Action,
&event.Actor,
&event.ProjectID,
&event.RequestID,
&event.CorrelationID,
&event.ClientIP,
&event.Method,
&event.Path,
&event.UserAgent,
&event.Outcome,
&event.Reason,
&event.TokenID,
&event.Metadata,
)
if err != nil {
return nil, fmt.Errorf("failed to scan audit event: %w", err)
}
events = append(events, event)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating audit events: %w", err)
}
return events, nil
}
// CountAuditEvents returns the total count of audit events matching the given filters
func (d *DB) CountAuditEvents(ctx context.Context, filters AuditEventFilters) (int, error) {
query := "SELECT COUNT(*) FROM audit_events WHERE 1=1"
args := []interface{}{}
// Apply the same filters as ListAuditEvents (excluding limit/offset)
if filters.Action != "" {
query += " AND action = ?"
args = append(args, filters.Action)
}
if filters.ClientIP != "" {
query += " AND client_ip = ?"
args = append(args, filters.ClientIP)
}
if filters.ProjectID != "" {
query += " AND project_id = ?"
args = append(args, filters.ProjectID)
}
if filters.Outcome != "" {
query += " AND outcome = ?"
args = append(args, filters.Outcome)
}
if filters.Actor != "" {
query += " AND actor = ?"
args = append(args, filters.Actor)
}
if filters.RequestID != "" {
query += " AND request_id = ?"
args = append(args, filters.RequestID)
}
if filters.CorrelationID != "" {
query += " AND correlation_id = ?"
args = append(args, filters.CorrelationID)
}
if filters.Method != "" {
query += " AND method = ?"
args = append(args, filters.Method)
}
if filters.Path != "" {
query += " AND path = ?"
args = append(args, filters.Path)
}
if filters.Search != "" {
query += " AND (request_id LIKE ? OR correlation_id LIKE ? OR client_ip LIKE ? OR action LIKE ? OR actor LIKE ? OR method LIKE ? OR path LIKE ? OR reason LIKE ? OR metadata LIKE ?)"
searchPattern := "%" + filters.Search + "%"
args = append(args, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern, searchPattern)
}
if filters.StartTime != nil {
query += " AND timestamp >= datetime(?)"
args = append(args, *filters.StartTime)
}
if filters.EndTime != nil {
query += " AND timestamp <= datetime(?)"
args = append(args, *filters.EndTime)
}
var count int
err := d.QueryRowContextRebound(ctx, query, args...).Scan(&count)
if err != nil {
return 0, fmt.Errorf("failed to count audit events: %w", err)
}
return count, nil
}
// GetAuditEventByID retrieves a specific audit event by its ID
func (d *DB) GetAuditEventByID(ctx context.Context, id string) (*AuditEvent, error) {
query := "SELECT id, timestamp, action, actor, project_id, request_id, correlation_id, client_ip, method, path, user_agent, outcome, reason, token_id, metadata FROM audit_events WHERE id = ?"
row := d.QueryRowContextRebound(ctx, query, id)
var event AuditEvent
err := row.Scan(
&event.ID,
&event.Timestamp,
&event.Action,
&event.Actor,
&event.ProjectID,
&event.RequestID,
&event.CorrelationID,
&event.ClientIP,
&event.Method,
&event.Path,
&event.UserAgent,
&event.Outcome,
&event.Reason,
&event.TokenID,
&event.Metadata,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, fmt.Errorf("audit event not found")
}
return nil, fmt.Errorf("failed to get audit event: %w", err)
}
return &event, nil
}
// Package database provides SQLite database operations for the LLM Proxy.
package database
import (
"context"
"database/sql"
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"runtime"
"time"
_ "github.com/mattn/go-sqlite3" // SQLite driver
)
// DB represents the database connection.
type DB struct {
db *sql.DB
driver DriverType
}
// Config contains the database configuration.
type Config struct {
// Path is the path to the SQLite database file.
Path string
// MaxOpenConns is the maximum number of open connections.
MaxOpenConns int
// MaxIdleConns is the maximum number of idle connections.
MaxIdleConns int
// ConnMaxLifetime is the maximum amount of time a connection may be reused.
ConnMaxLifetime time.Duration
}
// DefaultConfig returns a default database configuration.
func DefaultConfig() Config {
return Config{
Path: "data/llm-proxy.db",
MaxOpenConns: 10,
MaxIdleConns: 5,
ConnMaxLifetime: time.Hour,
}
}
// New creates a new database connection.
func New(config Config) (*DB, error) {
// Ensure database directory exists
if err := ensureDirExists(filepath.Dir(config.Path)); err != nil {
return nil, fmt.Errorf("failed to create database directory: %w", err)
}
// Open connection
// NOTE: We persist and interpret timestamps in UTC to avoid timezone drift.
// SQLite stores timestamps without timezone info; `_loc=UTC` forces parsing as UTC.
db, err := sql.Open("sqlite3", config.Path+"?_journal=WAL&_foreign_keys=on&_loc=UTC")
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
// Configure connection pool
// Special case: in-memory SQLite databases are per-connection. Use a single connection
// to ensure schema and data are visible across queries within the same *sql.DB handle.
if config.Path == ":memory:" {
db.SetMaxOpenConns(1)
db.SetMaxIdleConns(1)
} else {
db.SetMaxOpenConns(config.MaxOpenConns)
db.SetMaxIdleConns(config.MaxIdleConns)
}
db.SetConnMaxLifetime(config.ConnMaxLifetime)
// Test the connection
if err := db.Ping(); err != nil {
_ = db.Close()
return nil, fmt.Errorf("failed to ping database: %w", err)
}
// Initialize SQLite schema (not migrations - SQLite uses schema.sql directly)
if err := initSQLiteSchema(db); err != nil {
_ = db.Close()
return nil, fmt.Errorf("failed to initialize schema: %w", err)
}
return &DB{db: db, driver: DriverSQLite}, nil
}
// Close closes the database connection.
func (d *DB) Close() error {
if d.db != nil {
_ = d.db.Close()
}
return nil
}
// ensureDirExists creates the directory if it doesn't exist.
func ensureDirExists(dir string) error {
info, err := os.Stat(dir)
if errors.Is(err, fs.ErrNotExist) {
return os.MkdirAll(dir, 0755)
} else if err != nil {
return err
}
if !info.IsDir() {
return fmt.Errorf("path %s exists and is not a directory", dir)
}
return nil
}
// getSchemaPath returns the path to the SQLite schema file.
// SQLite uses a single schema file instead of migrations.
func getSchemaPath() (string, error) {
// Strategy 1: Relative path from current working directory (development)
cwdPath := filepath.Join("scripts", "schema.sql")
if _, err := os.Stat(cwdPath); err == nil {
return cwdPath, nil
}
// Strategy 2: Path relative to this source file (for tests)
_, thisFile, _, ok := runtime.Caller(0)
if ok {
repoRoot := filepath.Dir(filepath.Dir(filepath.Dir(thisFile)))
srcPath := filepath.Join(repoRoot, "scripts", "schema.sql")
if _, err := os.Stat(srcPath); err == nil {
return srcPath, nil
}
}
// Strategy 3: Relative path from executable location (production)
execPath, err := os.Executable()
if err == nil {
execDir := filepath.Dir(execPath)
execSchemaPath := filepath.Join(execDir, "scripts", "schema.sql")
if _, err := os.Stat(execSchemaPath); err == nil {
return execSchemaPath, nil
}
// Also try parent directory (for bin/ structure)
parentSchemaPath := filepath.Join(filepath.Dir(execDir), "scripts", "schema.sql")
if _, err := os.Stat(parentSchemaPath); err == nil {
return parentSchemaPath, nil
}
}
return "", fmt.Errorf("schema.sql not found in any expected location")
}
// initSQLiteSchema initializes the SQLite database from schema.sql.
// SQLite does NOT use migrations - only the current schema file.
func initSQLiteSchema(db *sql.DB) error {
schemaPath, err := getSchemaPath()
if err != nil {
return fmt.Errorf("failed to get schema path: %w", err)
}
schemaSQL, err := os.ReadFile(schemaPath)
if err != nil {
return fmt.Errorf("failed to read schema file: %w", err)
}
_, err = db.Exec(string(schemaSQL))
if err != nil {
return fmt.Errorf("failed to execute schema: %w", err)
}
return nil
}
// DBInitForTests is a helper to ensure schema exists in tests. No-op if db is nil.
func DBInitForTests(d *DB) error {
if d == nil || d.db == nil {
return nil
}
return initSQLiteSchema(d.db)
}
// Transaction executes the given function within a transaction.
func (d *DB) Transaction(ctx context.Context, fn func(*sql.Tx) error) error {
if d == nil || d.db == nil {
return fmt.Errorf("database is nil")
}
tx, err := d.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
// If the function panics, rollback the transaction
defer func() {
if p := recover(); p != nil {
_ = tx.Rollback()
panic(p) // Re-throw the panic after rolling back
}
}()
// Execute the function
if err := fn(tx); err != nil {
_ = tx.Rollback()
return err
}
// Commit the transaction
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}
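// Usage sketch: returning nil from the callback commits, while a non-nil
// error (or a panic) rolls back. The SQL statement is illustrative:
//
//	err := db.Transaction(ctx, func(tx *sql.Tx) error {
//		if _, err := tx.ExecContext(ctx, "UPDATE tokens SET is_active = 0 WHERE project_id = ?", projectID); err != nil {
//			return err // rolls back
//		}
//		return nil // commits
//	})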
// DB returns the underlying sql.DB instance.
func (d *DB) DB() *sql.DB {
return d.db
}
// Driver returns the database driver type.
func (d *DB) Driver() DriverType {
return d.driver
}
// HealthCheck performs a health check on the database connection.
// It verifies that the database is reachable and responsive.
func (d *DB) HealthCheck(ctx context.Context) error {
if d == nil || d.db == nil {
return fmt.Errorf("database is nil")
}
// Test the connection with a simple query
if err := d.db.PingContext(ctx); err != nil {
return fmt.Errorf("database ping failed: %w", err)
}
// Verify we can execute a simple query
var result int
err := d.db.QueryRowContext(ctx, "SELECT 1").Scan(&result)
if err != nil {
return fmt.Errorf("database query failed: %w", err)
}
return nil
}
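// A sketch of wiring HealthCheck into an HTTP readiness handler (route and
// timeout are illustrative):
//
//	http.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
//		ctx, cancel := context.WithTimeout(r.Context(), 2*time.Second)
//		defer cancel()
//		if err := db.HealthCheck(ctx); err != nil {
//			http.Error(w, "unhealthy", http.StatusServiceUnavailable)
//			return
//		}
//		w.WriteHeader(http.StatusOK)
//	})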
// Package database provides database operations for the LLM Proxy.
package database
import (
"database/sql"
"fmt"
"log"
"os"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/sofatutor/llm-proxy/internal/database/migrations"
_ "github.com/lib/pq" // PostgreSQL driver
)
// DriverType represents the database driver type.
type DriverType string
const (
// DriverSQLite represents the SQLite database driver.
DriverSQLite DriverType = "sqlite"
// DriverPostgres represents the PostgreSQL database driver.
DriverPostgres DriverType = "postgres"
// DriverMySQL represents the MySQL database driver.
DriverMySQL DriverType = "mysql"
)
// FullConfig contains the complete database configuration for all drivers.
type FullConfig struct {
// Driver specifies which database driver to use (sqlite, postgres, mysql).
Driver DriverType
// SQLite-specific configuration
// Path is the path to the SQLite database file.
Path string
// PostgreSQL and MySQL-specific configuration
// DatabaseURL is the PostgreSQL or MySQL connection string.
DatabaseURL string
// Connection pool settings (used by all drivers)
// MaxOpenConns is the maximum number of open connections.
MaxOpenConns int
// MaxIdleConns is the maximum number of idle connections.
MaxIdleConns int
// ConnMaxLifetime is the maximum amount of time a connection may be reused.
ConnMaxLifetime time.Duration
}
// DefaultFullConfig returns a default database configuration.
func DefaultFullConfig() FullConfig {
return FullConfig{
Driver: DriverSQLite,
Path: "data/llm-proxy.db",
DatabaseURL: "",
MaxOpenConns: 10,
MaxIdleConns: 5,
ConnMaxLifetime: time.Hour,
}
}
// ConfigFromEnv creates a FullConfig from environment variables.
// Invalid configuration values are logged as warnings and defaults are used.
func ConfigFromEnv() FullConfig {
config := DefaultFullConfig()
if driver := os.Getenv("DB_DRIVER"); driver != "" {
driverType := DriverType(strings.ToLower(driver))
if driverType != DriverSQLite && driverType != DriverPostgres && driverType != DriverMySQL {
log.Printf("Warning: unsupported DB_DRIVER '%s', defaulting to sqlite", driver)
} else {
config.Driver = driverType
}
}
if path := os.Getenv("DATABASE_PATH"); path != "" {
config.Path = path
}
if url := os.Getenv("DATABASE_URL"); url != "" {
config.DatabaseURL = url
}
if poolSize := os.Getenv("DATABASE_POOL_SIZE"); poolSize != "" {
if size, err := parsePositiveInt(poolSize); err == nil {
config.MaxOpenConns = size
} else {
log.Printf("Warning: invalid DATABASE_POOL_SIZE '%s': %v, using default %d", poolSize, err, config.MaxOpenConns)
}
}
if idleConns := os.Getenv("DATABASE_MAX_IDLE_CONNS"); idleConns != "" {
if size, err := parsePositiveInt(idleConns); err == nil {
config.MaxIdleConns = size
} else {
log.Printf("Warning: invalid DATABASE_MAX_IDLE_CONNS '%s': %v, using default %d", idleConns, err, config.MaxIdleConns)
}
}
if lifetime := os.Getenv("DATABASE_CONN_MAX_LIFETIME"); lifetime != "" {
if duration, err := time.ParseDuration(lifetime); err == nil {
config.ConnMaxLifetime = duration
} else {
log.Printf("Warning: invalid DATABASE_CONN_MAX_LIFETIME '%s': %v, using default %v", lifetime, err, config.ConnMaxLifetime)
}
}
return config
}
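// Driver selection sketch (the DSN is a placeholder; PostgreSQL additionally
// requires building with the 'postgres' tag):
//
//	// DB_DRIVER=postgres
//	// DATABASE_URL=postgres://user:pass@localhost:5432/llmproxy?sslmode=disable
//	cfg := database.ConfigFromEnv()
//	db, err := database.NewFromConfig(cfg)
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer func() { _ = db.Close() }()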
// parsePositiveInt parses a string as a positive integer.
func parsePositiveInt(s string) (int, error) {
var i int
_, err := fmt.Sscanf(s, "%d", &i)
if err != nil || i <= 0 {
return 0, fmt.Errorf("invalid positive integer: %s", s)
}
return i, nil
}
// NewFromConfig creates a new database connection based on the configuration.
func NewFromConfig(config FullConfig) (*DB, error) {
switch config.Driver {
case DriverSQLite:
return newSQLiteDB(config)
case DriverPostgres:
return newPostgresDB(config)
case DriverMySQL:
return newMySQLDB(config)
default:
return nil, fmt.Errorf("unsupported database driver: %s", config.Driver)
}
}
// newSQLiteDB creates a new SQLite database connection.
func newSQLiteDB(config FullConfig) (*DB, error) {
// Ensure database directory exists (skip for in-memory databases)
if config.Path != ":memory:" {
if err := ensureDirExists(filepath.Dir(config.Path)); err != nil {
return nil, fmt.Errorf("failed to create database directory: %w", err)
}
}
// Open connection
// NOTE: We persist and interpret timestamps in UTC to avoid timezone drift.
// SQLite stores timestamps without timezone info; `_loc=UTC` forces parsing as UTC.
db, err := sql.Open("sqlite3", config.Path+"?_journal=WAL&_foreign_keys=on&_loc=UTC")
if err != nil {
return nil, fmt.Errorf("failed to open SQLite database: %w", err)
}
// Configure connection pool
// Special case: in-memory SQLite databases are per-connection. Use a single connection
// to ensure schema and data are visible across queries within the same *sql.DB handle.
if config.Path == ":memory:" {
db.SetMaxOpenConns(1)
db.SetMaxIdleConns(1)
} else {
db.SetMaxOpenConns(config.MaxOpenConns)
db.SetMaxIdleConns(config.MaxIdleConns)
}
db.SetConnMaxLifetime(config.ConnMaxLifetime)
// Test the connection
if err := db.Ping(); err != nil {
_ = db.Close()
return nil, fmt.Errorf("failed to ping SQLite database: %w", err)
}
// Initialize SQLite schema (SQLite uses schema.sql, NOT migrations)
if err := initSQLiteSchema(db); err != nil {
_ = db.Close()
return nil, fmt.Errorf("failed to initialize SQLite schema: %w", err)
}
return &DB{db: db, driver: DriverSQLite}, nil
}
// runMigrationsForDriver runs database migrations for the specified driver.
// Note: Only PostgreSQL and MySQL use migrations. SQLite uses schema.sql directly.
func runMigrationsForDriver(db *sql.DB, dialect string) error {
if dialect == "sqlite3" || dialect == "sqlite" {
// SQLite does NOT use migrations - it uses schema.sql directly
return fmt.Errorf("SQLite does not use migrations; use initSQLiteSchema instead")
}
migrationsPath, err := getMigrationsPathForDialect(dialect)
if err != nil {
return fmt.Errorf("failed to get migrations path: %w", err)
}
runner := migrations.NewMigrationRunner(db, migrationsPath)
if err := runner.Up(); err != nil {
return fmt.Errorf("failed to apply migrations: %w", err)
}
return nil
}
// getMigrationsPathForDialect returns the path to migrations for the specified dialect.
// Note: Only PostgreSQL and MySQL use migrations. SQLite uses schema.sql directly.
func getMigrationsPathForDialect(dialect string) (string, error) {
// SQLite does not use migrations
if dialect == "sqlite3" || dialect == "sqlite" {
return "", fmt.Errorf("SQLite does not use migrations; use schema.sql instead")
}
// Common base paths to try
basePaths := []string{
"internal/database/migrations",
}
// Add path relative to this source file (for tests)
_, filename, _, ok := runtime.Caller(0)
if ok {
sourceDir := filepath.Dir(filename)
basePaths = append(basePaths, filepath.Join(sourceDir, "migrations"))
}
// Add paths relative to executable
execPath, err := os.Executable()
if err == nil {
execDir := filepath.Dir(execPath)
basePaths = append(basePaths, filepath.Join(execDir, "internal/database/migrations"))
basePaths = append(basePaths, filepath.Join(filepath.Dir(execDir), "internal/database/migrations"))
}
// Try each base path for the specified dialect
for _, basePath := range basePaths {
// PostgreSQL and MySQL migrations are in sql/{dialect}/
dialectPath := filepath.Join(basePath, "sql", dialect)
if _, err := os.Stat(dialectPath); err == nil {
return dialectPath, nil
}
}
return "", fmt.Errorf("migrations directory not found for dialect: %s", dialect)
}
// MigrationsPathForDriver returns the migrations directory for the given driver type.
// Note: Only PostgreSQL and MySQL use migrations. SQLite uses schema.sql directly.
// This ensures CLI and server code share the same dialect-aware lookup logic.
func MigrationsPathForDriver(driver DriverType) (string, error) {
switch driver {
case DriverSQLite:
return "", fmt.Errorf("SQLite does not use migrations; use schema.sql instead")
case DriverPostgres:
return getMigrationsPathForDialect("postgres")
case DriverMySQL:
return getMigrationsPathForDialect("mysql")
default:
return getMigrationsPathForDialect(string(driver))
}
}
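// Example (hypothetical caller; sqlDB is a *sql.DB opened with a PostgreSQL
// driver): resolving the dialect-specific migrations path and applying all
// pending migrations via the shared runner.
//
//	path, err := MigrationsPathForDriver(DriverPostgres)
//	if err != nil {
//		return err
//	}
//	if err := migrations.NewMigrationRunner(sqlDB, path).Up(); err != nil {
//		return err
//	}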
//go:build !mysql
package database
import "fmt"
// newMySQLDB is a stub that returns an error when MySQL support
// is not compiled in. The real implementation is in factory_mysql.go
// and requires the 'mysql' build tag.
//
// To enable MySQL support, build with: go build -tags mysql ./...
func newMySQLDB(_ FullConfig) (*DB, error) {
return nil, fmt.Errorf("MySQL support not compiled in; build with -tags mysql to enable")
}
//go:build !postgres
package database
import "fmt"
// newPostgresDB is a stub that returns an error when PostgreSQL support
// is not compiled in. The real implementation is in factory_postgres.go
// and requires the 'postgres' build tag.
//
// To enable PostgreSQL support, build with: go build -tags postgres ./...
func newPostgresDB(_ FullConfig) (*DB, error) {
return nil, fmt.Errorf("PostgreSQL support not compiled in; build with -tags postgres to enable")
}
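// Example build invocations (sketch): the stubs above keep the default build
// SQLite-only, while build tags compile in the real driver implementations.
//
//	go build ./...                          // SQLite only; MySQL/Postgres constructors return errors
//	go build -tags "mysql postgres" ./...   // all three drivers compiled in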
//go:build !mysql
package migrations
import "fmt"
// acquireMySQLLock is a stub that returns an error when MySQL support
// is not compiled in. The real implementation is in mysql_lock.go and requires
// the 'mysql' build tag.
//
// MySQL locking will be tested via Docker Compose integration tests.
func (m *MigrationRunner) acquireMySQLLock() (func(), error) {
return nil, fmt.Errorf("MySQL named locking requires the 'mysql' build tag")
}
//go:build !postgres
package migrations
import "fmt"
// acquirePostgresLock is a stub that returns an error when PostgreSQL support
// is not compiled in. The real implementation is in postgres_lock.go and requires
// the 'postgres' build tag.
//
// PostgreSQL locking will be tested via Docker Compose integration tests (issue #139).
func (m *MigrationRunner) acquirePostgresLock() (func(), error) {
return nil, fmt.Errorf("PostgreSQL advisory locking requires the 'postgres' build tag; see issue #139 for Docker Compose integration tests")
}
// Package migrations provides database migration functionality using goose.
package migrations
import (
"database/sql"
"fmt"
"os"
"strings"
"time"
"github.com/pressly/goose/v3"
)
// MigrationRunner manages database migrations using goose.
type MigrationRunner struct {
db *sql.DB
migrationsPath string
}
// NewMigrationRunner creates a new migration runner.
// db is the database connection, migrationsPath is the directory containing SQL migration files.
func NewMigrationRunner(db *sql.DB, migrationsPath string) *MigrationRunner {
return &MigrationRunner{
db: db,
migrationsPath: migrationsPath,
}
}
// Up applies all pending migrations.
// Each migration runs in a transaction and will be rolled back if it fails.
// A cross-process lock (advisory lock on PostgreSQL, named lock on MySQL, lock table on SQLite) prevents concurrent migrations in distributed deployments.
func (m *MigrationRunner) Up() error {
if m.db == nil {
return fmt.Errorf("database connection is nil")
}
if m.migrationsPath == "" {
return fmt.Errorf("migrations path is empty")
}
// Acquire a cross-process migration lock to prevent concurrent runs
release, err := m.acquireMigrationLock()
if err != nil {
return fmt.Errorf("failed to acquire migration lock: %w", err)
}
defer release()
// Detect database driver and set goose dialect
driverName, err := m.detectDriver()
if err != nil {
return fmt.Errorf("failed to detect database driver: %w", err)
}
if err := goose.SetDialect(driverName); err != nil {
return fmt.Errorf("failed to set goose dialect: %w", err)
}
// Run migrations
if err := goose.Up(m.db, m.migrationsPath); err != nil {
return fmt.Errorf("failed to apply migrations: %w", err)
}
return nil
}
// Down rolls back the most recently applied migration.
// The rollback runs in a transaction.
func (m *MigrationRunner) Down() error {
if m.db == nil {
return fmt.Errorf("database connection is nil")
}
if m.migrationsPath == "" {
return fmt.Errorf("migrations path is empty")
}
// Acquire migration lock to prevent concurrent operations
release, err := m.acquireMigrationLock()
if err != nil {
return fmt.Errorf("failed to acquire migration lock: %w", err)
}
defer release()
// Detect database driver and set goose dialect
driverName, err := m.detectDriver()
if err != nil {
return fmt.Errorf("failed to detect database driver: %w", err)
}
if err := goose.SetDialect(driverName); err != nil {
return fmt.Errorf("failed to set goose dialect: %w", err)
}
// Roll back one migration
if err := goose.Down(m.db, m.migrationsPath); err != nil {
return fmt.Errorf("failed to roll back migration: %w", err)
}
return nil
}
// Status returns the current migration version.
// Returns 0 if no migrations have been applied.
func (m *MigrationRunner) Status() (int64, error) {
if m.db == nil {
return 0, fmt.Errorf("database connection is nil")
}
if m.migrationsPath == "" {
return 0, fmt.Errorf("migrations path is empty")
}
// Detect database driver and set goose dialect
driverName, err := m.detectDriver()
if err != nil {
return 0, fmt.Errorf("failed to detect database driver: %w", err)
}
if err := goose.SetDialect(driverName); err != nil {
return 0, fmt.Errorf("failed to set goose dialect: %w", err)
}
// Get current version
version, err := goose.GetDBVersion(m.db)
if err != nil {
return 0, fmt.Errorf("failed to get migration version: %w", err)
}
return version, nil
}
// Version is an alias for Status(). Returns the current migration version.
func (m *MigrationRunner) Version() (int64, error) {
return m.Status()
}
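// Example (usage sketch; db and migrationsPath are assumed to be set up by the
// caller): checking the schema version before and after applying migrations.
//
//	runner := NewMigrationRunner(db, migrationsPath)
//	before, _ := runner.Status()
//	if err := runner.Up(); err != nil {
//		return err
//	}
//	after, _ := runner.Status()
//	fmt.Printf("migrated from version %d to %d\n", before, after)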
// detectDriver detects the database driver from the connection.
// Returns the goose dialect name: "sqlite3", "postgres", or "mysql".
func (m *MigrationRunner) detectDriver() (string, error) {
if m.db == nil {
return "", fmt.Errorf("database connection is nil")
}
// Inspect the concrete type of the registered driver
drv := m.db.Driver()
if drv == nil {
return "", fmt.Errorf("driver is nil")
}
driverType := fmt.Sprintf("%T", drv)
// Detect SQLite
if driverType == "*sqlite3.SQLiteDriver" || driverType == "*sqlite3.SQLiteConn" {
return "sqlite3", nil
}
// Detect PostgreSQL (common drivers: lib/pq, pgx, pgx/stdlib)
// pgx/stdlib registers as *stdlib.Driver when using sql.Open("pgx", ...)
if driverType == "*pq.driver" || driverType == "*pgx.Conn" || driverType == "*pgxpool.Pool" ||
driverType == "*stdlib.Driver" {
return "postgres", nil
}
// Detect MySQL (common driver: go-sql-driver/mysql)
if driverType == "*mysql.MySQLDriver" {
return "mysql", nil
}
// Try to detect via database-specific queries
// First, try PostgreSQL-specific query
var version string
pgErr := m.db.QueryRow("SELECT version()").Scan(&version)
if pgErr == nil && strings.HasPrefix(version, "PostgreSQL") {
// version() is PostgreSQL-specific and returns something like "PostgreSQL 15.x ..."
return "postgres", nil
}
// Try MySQL-specific query
var mysqlVersion string
mysqlErr := m.db.QueryRow("SELECT VERSION()").Scan(&mysqlVersion)
if mysqlErr == nil && (strings.Contains(mysqlVersion, "MySQL") || strings.Contains(mysqlVersion, "MariaDB")) {
// MySQL and MariaDB return version strings containing "MySQL" or "MariaDB"
return "mysql", nil
}
// Try SQLite-specific pragma
var journalMode string
pragmaErr := m.db.QueryRow("PRAGMA journal_mode").Scan(&journalMode)
if pragmaErr == nil {
return "sqlite3", nil
}
// Default to sqlite3 for backward compatibility
return "sqlite3", nil
}
// acquireMigrationLock acquires an advisory lock to prevent concurrent migrations.
// Returns a release function that must be called to release the lock.
func (m *MigrationRunner) acquireMigrationLock() (func(), error) {
driverName, err := m.detectDriver()
if err != nil {
return nil, fmt.Errorf("failed to detect driver for locking: %w", err)
}
switch driverName {
case "sqlite3":
return m.acquireSQLiteLock()
case "postgres":
return m.acquirePostgresLock()
case "mysql":
return m.acquireMySQLLock()
default:
return m.acquireSQLiteLock() // Default to SQLite lock for backward compatibility
}
}
// acquireSQLiteLock acquires a lock using a SQLite lock table.
// This prevents concurrent migrations when multiple instances start simultaneously.
func (m *MigrationRunner) acquireSQLiteLock() (func(), error) {
// Create lock table if it doesn't exist
_, err := m.db.Exec(`
CREATE TABLE IF NOT EXISTS migration_lock (
id INTEGER PRIMARY KEY CHECK (id = 1),
locked BOOLEAN NOT NULL DEFAULT 0,
locked_at DATETIME,
locked_by TEXT,
process_id INTEGER
)
`)
if err != nil {
return nil, fmt.Errorf("failed to create lock table: %w", err)
}
// Initialize lock row if it doesn't exist
_, _ = m.db.Exec(`INSERT OR IGNORE INTO migration_lock (id, locked) VALUES (1, 0)`)
// Try to acquire lock with retries
maxRetries := 10
retryDelay := 100 * time.Millisecond
processID := os.Getpid()
for i := 0; i < maxRetries; i++ {
// Use a transaction to atomically check and acquire lock
tx, err := m.db.Begin()
if err != nil {
return nil, fmt.Errorf("failed to begin transaction: %w", err)
}
var locked bool
err = tx.QueryRow(`SELECT locked FROM migration_lock WHERE id = 1`).Scan(&locked)
if err != nil {
_ = tx.Rollback()
return nil, fmt.Errorf("failed to read lock status: %w", err)
}
if locked {
_ = tx.Rollback()
if i < maxRetries-1 {
time.Sleep(retryDelay)
continue
}
return nil, fmt.Errorf("migration lock is already held by another process (retried %d times)", maxRetries)
}
// Acquire lock
result, err := tx.Exec(`
UPDATE migration_lock
SET locked = 1, locked_at = CURRENT_TIMESTAMP, locked_by = ?, process_id = ?
WHERE id = 1 AND locked = 0
`, fmt.Sprintf("pid-%d", processID), processID)
if err != nil {
_ = tx.Rollback()
return nil, fmt.Errorf("failed to acquire lock: %w", err)
}
// Verify that the update actually affected a row (lock was acquired)
rowsAffected, err := result.RowsAffected()
if err != nil || rowsAffected == 0 {
_ = tx.Rollback()
if i < maxRetries-1 {
time.Sleep(retryDelay)
continue
}
return nil, fmt.Errorf("failed to acquire lock: another process may have acquired it")
}
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("failed to commit lock acquisition: %w", err)
}
// Verify lock was acquired
var isLocked bool
err = m.db.QueryRow(`SELECT locked FROM migration_lock WHERE id = 1`).Scan(&isLocked)
if err != nil || !isLocked {
if i < maxRetries-1 {
time.Sleep(retryDelay)
continue
}
return nil, fmt.Errorf("failed to verify lock acquisition")
}
// Return release function
release := func() {
_, _ = m.db.Exec(`UPDATE migration_lock SET locked = 0 WHERE id = 1`)
}
return release, nil
}
return nil, fmt.Errorf("failed to acquire migration lock after %d retries", maxRetries)
}
// NOTE: acquirePostgresLock is defined in postgres_lock.go (with postgres build tag)
// and postgres_lock_stub.go (without postgres build tag). This allows PostgreSQL-specific
// code to be excluded from default coverage calculations. PostgreSQL integration tests
// will be added via Docker Compose in issue #139.
package database
import (
"context"
"errors"
"sync"
"github.com/sofatutor/llm-proxy/internal/proxy"
)
// MockProjectStore is an in-memory implementation of ProjectStore for testing and development
type MockProjectStore struct {
projects map[string]Project
apiKeys map[string]string // Project ID -> API Key mapping
mutex sync.RWMutex
}
// NewMockProjectStore creates a new MockProjectStore
func NewMockProjectStore() *MockProjectStore {
return &MockProjectStore{
projects: make(map[string]Project),
apiKeys: make(map[string]string),
}
}
// CreateProject creates a new project in the store
func (m *MockProjectStore) DBCreateProject(ctx context.Context, project Project) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if _, exists := m.projects[project.ID]; exists {
return errors.New("project already exists")
}
m.projects[project.ID] = project
m.apiKeys[project.ID] = project.APIKey
return nil
}
// GetProjectByID retrieves a project by ID
func (m *MockProjectStore) DBGetProjectByID(ctx context.Context, projectID string) (Project, error) {
m.mutex.RLock()
defer m.mutex.RUnlock()
project, exists := m.projects[projectID]
if !exists {
return Project{}, errors.New("project not found")
}
return project, nil
}
// UpdateProject updates a project in the store
func (m *MockProjectStore) DBUpdateProject(ctx context.Context, project Project) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if _, exists := m.projects[project.ID]; !exists {
return errors.New("project not found")
}
m.projects[project.ID] = project
m.apiKeys[project.ID] = project.APIKey
return nil
}
// DeleteProject deletes a project from the store
func (m *MockProjectStore) DBDeleteProject(ctx context.Context, projectID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if _, exists := m.projects[projectID]; !exists {
return errors.New("project not found")
}
delete(m.projects, projectID)
delete(m.apiKeys, projectID)
return nil
}
// ListProjects retrieves all projects from the store
func (m *MockProjectStore) DBListProjects(ctx context.Context) ([]Project, error) {
m.mutex.RLock()
defer m.mutex.RUnlock()
projects := make([]Project, 0, len(m.projects))
for _, p := range m.projects {
projects = append(projects, p)
}
return projects, nil
}
// CreateMockProject creates a new project in the mock store with the given parameters
func (m *MockProjectStore) CreateMockProject(projectID, name, apiKey string) (Project, error) {
if projectID == "" {
return Project{}, errors.New("project ID cannot be empty")
}
if name == "" {
return Project{}, errors.New("project name cannot be empty")
}
if apiKey == "" {
return Project{}, errors.New("API key cannot be empty")
}
project := Project{
ID: projectID,
Name: name,
APIKey: apiKey,
}
err := m.DBCreateProject(context.Background(), project)
return project, err
}
// GetAPIKeyForProject retrieves the API key for a project
func (m *MockProjectStore) GetAPIKeyForProject(ctx context.Context, projectID string) (string, error) {
m.mutex.RLock()
defer m.mutex.RUnlock()
apiKey, exists := m.apiKeys[projectID]
if !exists {
return "", errors.New("project not found")
}
return apiKey, nil
}
// --- proxy.ProjectStore interface adapters ---
func (m *MockProjectStore) ListProjects(ctx context.Context) ([]proxy.Project, error) {
dbProjects, err := m.DBListProjects(ctx)
if err != nil {
return nil, err
}
var out []proxy.Project
for _, p := range dbProjects {
out = append(out, ToProxyProject(p))
}
return out, nil
}
func (m *MockProjectStore) CreateProject(ctx context.Context, p proxy.Project) error {
return m.DBCreateProject(ctx, ToDBProject(p))
}
func (m *MockProjectStore) GetProjectByID(ctx context.Context, id string) (proxy.Project, error) {
dbP, err := m.DBGetProjectByID(ctx, id)
if err != nil {
return proxy.Project{}, err
}
return ToProxyProject(dbP), nil
}
func (m *MockProjectStore) UpdateProject(ctx context.Context, p proxy.Project) error {
return m.DBUpdateProject(ctx, ToDBProject(p))
}
func (m *MockProjectStore) DeleteProject(ctx context.Context, id string) error {
return m.DBDeleteProject(ctx, id)
}
// GetProjectActive retrieves the active status for a project by ID
func (m *MockProjectStore) GetProjectActive(ctx context.Context, projectID string) (bool, error) {
m.mutex.RLock()
defer m.mutex.RUnlock()
project, exists := m.projects[projectID]
if !exists {
return false, errors.New("project not found")
}
return project.IsActive, nil
}
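// Example (test usage sketch):
//
//	store := NewMockProjectStore()
//	p, _ := store.CreateMockProject("p1", "demo", "sk-admin-key")
//	got, err := store.DBGetProjectByID(context.Background(), p.ID)
//	// err is nil and got.Name == "demo"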
package database
import (
"context"
"errors"
"sync"
"testing"
"time"
"github.com/sofatutor/llm-proxy/internal/token"
)
// MockTokenStore is an in-memory implementation of TokenStore for testing and development.
// Unlike the SQL-backed store, it keys tokens by their token string, so the tokenID
// arguments below refer to token strings rather than UUIDs.
type MockTokenStore struct {
tokens map[string]Token
mutex sync.RWMutex
projectIDs map[string]string // token -> projectID mapping for quick lookup
}
// NewMockTokenStore creates a new MockTokenStore
func NewMockTokenStore() *MockTokenStore {
return &MockTokenStore{
tokens: make(map[string]Token),
projectIDs: make(map[string]string),
}
}
// CreateToken creates a new token in the store
func (m *MockTokenStore) CreateToken(ctx context.Context, token Token) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if _, exists := m.tokens[token.Token]; exists {
return ErrTokenExists
}
m.tokens[token.Token] = token
m.projectIDs[token.Token] = token.ProjectID
return nil
}
// GetTokenByID retrieves a token by ID
func (m *MockTokenStore) GetTokenByID(ctx context.Context, tokenID string) (Token, error) {
m.mutex.RLock()
defer m.mutex.RUnlock()
token, exists := m.tokens[tokenID]
if !exists {
return Token{}, ErrTokenNotFound
}
return token, nil
}
// UpdateToken updates a token in the store
func (m *MockTokenStore) UpdateToken(ctx context.Context, token Token) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if _, exists := m.tokens[token.Token]; !exists {
return ErrTokenNotFound
}
m.tokens[token.Token] = token
m.projectIDs[token.Token] = token.ProjectID
return nil
}
// DeleteToken deletes a token from the store
func (m *MockTokenStore) DeleteToken(ctx context.Context, tokenID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if _, exists := m.tokens[tokenID]; !exists {
return ErrTokenNotFound
}
delete(m.tokens, tokenID)
delete(m.projectIDs, tokenID)
return nil
}
// ListTokens retrieves all tokens from the store
func (m *MockTokenStore) ListTokens(ctx context.Context) ([]Token, error) {
m.mutex.RLock()
defer m.mutex.RUnlock()
tokens := make([]Token, 0, len(m.tokens))
for _, t := range m.tokens {
tokens = append(tokens, t)
}
return tokens, nil
}
// GetTokensByProjectID retrieves all tokens for a project
func (m *MockTokenStore) GetTokensByProjectID(ctx context.Context, projectID string) ([]Token, error) {
m.mutex.RLock()
defer m.mutex.RUnlock()
var tokens []Token
for _, t := range m.tokens {
if t.ProjectID == projectID {
tokens = append(tokens, t)
}
}
return tokens, nil
}
// IncrementTokenUsage increments the request count and updates the last_used_at timestamp
func (m *MockTokenStore) IncrementTokenUsage(ctx context.Context, tokenID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
t, exists := m.tokens[tokenID]
if !exists {
return ErrTokenNotFound
}
if t.MaxRequests != nil && *t.MaxRequests > 0 && t.RequestCount >= *t.MaxRequests {
return token.ErrTokenRateLimit
}
t.RequestCount++
now := time.Now()
t.LastUsedAt = &now
m.tokens[tokenID] = t
return nil
}
// CleanExpiredTokens deletes expired tokens from the store
func (m *MockTokenStore) CleanExpiredTokens(ctx context.Context) (int64, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
now := time.Now()
var count int64
for id, t := range m.tokens {
if t.ExpiresAt != nil && t.ExpiresAt.Before(now) {
delete(m.tokens, id)
delete(m.projectIDs, id)
count++
}
}
return count, nil
}
// CreateMockToken creates a new token in the mock store with the given parameters
func (m *MockTokenStore) CreateMockToken(tokenID, projectID string, expiresIn time.Duration, isActive bool, maxRequests *int) (Token, error) {
if tokenID == "" {
return Token{}, errors.New("token ID cannot be empty")
}
if projectID == "" {
return Token{}, errors.New("project ID cannot be empty")
}
var expiresAt *time.Time
if expiresIn > 0 {
expiry := time.Now().Add(expiresIn)
expiresAt = &expiry
}
now := time.Now()
token := Token{
Token: tokenID,
ProjectID: projectID,
ExpiresAt: expiresAt,
IsActive: isActive,
RequestCount: 0,
MaxRequests: maxRequests,
CreatedAt: now,
}
err := m.CreateToken(context.Background(), token)
return token, err
}
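// Example (test usage sketch): a token capped at 100 requests, expiring in an hour.
//
//	store := NewMockTokenStore()
//	max := 100
//	tok, err := store.CreateMockToken("sk-test-token", "p1", time.Hour, true, &max)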
// TokenStoreAdapter adapts a MockTokenStore to the token.TokenStore interface
type TokenStoreAdapter struct {
store *MockTokenStore
}
// NewTokenStoreAdapter creates a new TokenStoreAdapter
func NewTokenStoreAdapter(store *MockTokenStore) *TokenStoreAdapter {
return &TokenStoreAdapter{
store: store,
}
}
// GetTokenByID retrieves a token by ID
func (a *TokenStoreAdapter) GetTokenByID(ctx context.Context, tokenID string) (token.TokenData, error) {
t, err := a.store.GetTokenByID(ctx, tokenID)
if err != nil {
if errors.Is(err, ErrTokenNotFound) {
return token.TokenData{}, token.ErrTokenNotFound
}
return token.TokenData{}, err
}
return ExportTokenData(t), nil
}
// IncrementTokenUsage increments the request count and updates the last_used_at timestamp
func (a *TokenStoreAdapter) IncrementTokenUsage(ctx context.Context, tokenID string) error {
err := a.store.IncrementTokenUsage(ctx, tokenID)
if err != nil {
if errors.Is(err, ErrTokenNotFound) {
return token.ErrTokenNotFound
}
return err
}
return nil
}
// CreateToken creates a new token in the store
func (a *TokenStoreAdapter) CreateToken(ctx context.Context, td token.TokenData) error {
dbToken := ImportTokenData(td)
return a.store.CreateToken(ctx, dbToken)
}
// ListTokens retrieves all tokens from the store
func (a *TokenStoreAdapter) ListTokens(ctx context.Context) ([]token.TokenData, error) {
dbTokens, err := a.store.ListTokens(ctx)
if err != nil {
return nil, err
}
tokens := make([]token.TokenData, len(dbTokens))
for i, t := range dbTokens {
tokens[i] = ExportTokenData(t)
}
return tokens, nil
}
// GetTokensByProjectID retrieves all tokens for a project
func (a *TokenStoreAdapter) GetTokensByProjectID(ctx context.Context, projectID string) ([]token.TokenData, error) {
dbTokens, err := a.store.GetTokensByProjectID(ctx, projectID)
if err != nil {
return nil, err
}
tokens := make([]token.TokenData, len(dbTokens))
for i, t := range dbTokens {
tokens[i] = ExportTokenData(t)
}
return tokens, nil
}
func TestMockTokenStore_EdgeCases(t *testing.T) {
store := NewMockTokenStore()
ctx := context.Background()
t.Run("GetTokenByID not found", func(t *testing.T) {
_, err := store.GetTokenByID(ctx, "notfound")
if err == nil {
t.Error("expected error for notfound token")
}
})
t.Run("IncrementTokenUsage not found", func(t *testing.T) {
err := store.IncrementTokenUsage(ctx, "notfound")
if err == nil {
t.Error("expected error for notfound token")
}
})
t.Run("ListTokens empty", func(t *testing.T) {
ts, err := store.ListTokens(ctx)
if err != nil {
t.Error(err)
}
if len(ts) != 0 {
t.Errorf("expected 0 tokens, got %d", len(ts))
}
})
t.Run("GetTokensByProjectID empty", func(t *testing.T) {
ts, err := store.GetTokensByProjectID(ctx, "pid")
if err != nil {
t.Error(err)
}
if len(ts) != 0 {
t.Errorf("expected 0 tokens, got %d", len(ts))
}
})
}
package database
import (
"time"
)
// Project represents a project in the database.
type Project struct {
ID string `json:"id"`
Name string `json:"name"`
APIKey string `json:"-"` // Sensitive data, not included in JSON. Encrypted when ENCRYPTION_KEY is set.
IsActive bool `json:"is_active"`
DeactivatedAt *time.Time `json:"deactivated_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// Token represents a token in the database.
type Token struct {
ID string `json:"id"`
Token string `json:"token"`
ProjectID string `json:"project_id"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
IsActive bool `json:"is_active"`
DeactivatedAt *time.Time `json:"deactivated_at,omitempty"`
RequestCount int `json:"request_count"`
MaxRequests *int `json:"max_requests,omitempty"`
CreatedAt time.Time `json:"created_at"`
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
CacheHitCount int `json:"cache_hit_count"`
}
// AuditEvent represents an audit log entry in the database.
type AuditEvent struct {
ID string `json:"id"`
Timestamp time.Time `json:"timestamp"`
Action string `json:"action"`
Actor string `json:"actor"`
ProjectID *string `json:"project_id,omitempty"`
RequestID *string `json:"request_id,omitempty"`
CorrelationID *string `json:"correlation_id,omitempty"`
ClientIP *string `json:"client_ip,omitempty"`
Method *string `json:"method,omitempty"`
Path *string `json:"path,omitempty"`
UserAgent *string `json:"user_agent,omitempty"`
Outcome string `json:"outcome"`
Reason *string `json:"reason,omitempty"`
TokenID *string `json:"token_id,omitempty"`
Metadata *string `json:"metadata,omitempty"` // JSON string
}
// IsExpired returns true if the token has expired.
func (t *Token) IsExpired() bool {
if t.ExpiresAt == nil {
return false
}
return time.Now().After(*t.ExpiresAt)
}
// IsRateLimited returns true if the token has reached its maximum number of requests.
func (t *Token) IsRateLimited() bool {
if t.MaxRequests == nil {
return false
}
return t.RequestCount >= *t.MaxRequests
}
// IsValid returns true if the token is active, not expired, and not rate limited.
func (t *Token) IsValid() bool {
return t.IsActive && !t.IsExpired() && !t.IsRateLimited()
}
// IsDeactivated returns true if the token has been explicitly deactivated.
func (t *Token) IsDeactivated() bool {
return t.DeactivatedAt != nil
}
// IsDeactivated returns true if the project has been explicitly deactivated.
func (p *Project) IsDeactivated() bool {
return p.DeactivatedAt != nil
}
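// Example (validity-check sketch): callers typically gate on IsValid and use the
// narrower predicates to pick a rejection reason.
//
//	if !tok.IsValid() {
//		switch {
//		case tok.IsExpired():
//			// reject: expired
//		case tok.IsRateLimited():
//			// reject: quota exhausted
//		default:
//			// reject: inactive
//		}
//	}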
package database
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"github.com/sofatutor/llm-proxy/internal/proxy"
)
var (
// ErrProjectNotFound is returned when a project is not found.
ErrProjectNotFound = errors.New("project not found")
// ErrProjectExists is returned when a project already exists.
ErrProjectExists = errors.New("project already exists")
)
// GetProjectByName retrieves a project by name.
func (d *DB) GetProjectByName(ctx context.Context, name string) (Project, error) {
query := `
SELECT id, name, api_key, is_active, deactivated_at, created_at, updated_at
FROM projects
WHERE name = ?
`
var project Project
var deactivatedAt sql.NullTime
err := d.QueryRowContextRebound(ctx, query, name).Scan(
&project.ID,
&project.Name,
&project.APIKey,
&project.IsActive,
&deactivatedAt,
&project.CreatedAt,
&project.UpdatedAt,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return Project{}, ErrProjectNotFound
}
return Project{}, fmt.Errorf("failed to get project: %w", err)
}
if deactivatedAt.Valid {
project.DeactivatedAt = &deactivatedAt.Time
}
return project, nil
}
// ToProxyProject converts a database.Project to a proxy.Project
func ToProxyProject(dbProject Project) proxy.Project {
return proxy.Project{
ID: dbProject.ID,
Name: dbProject.Name,
APIKey: dbProject.APIKey,
IsActive: dbProject.IsActive,
DeactivatedAt: dbProject.DeactivatedAt,
CreatedAt: dbProject.CreatedAt,
UpdatedAt: dbProject.UpdatedAt,
}
}
// ToDBProject converts a proxy.Project to a database.Project
func ToDBProject(proxyProject proxy.Project) Project {
return Project{
ID: proxyProject.ID,
Name: proxyProject.Name,
APIKey: proxyProject.APIKey,
IsActive: proxyProject.IsActive,
DeactivatedAt: proxyProject.DeactivatedAt,
CreatedAt: proxyProject.CreatedAt,
UpdatedAt: proxyProject.UpdatedAt,
}
}
// Database-backed project CRUD. The DB prefix distinguishes these methods from the
// proxy.ProjectStore adapter methods defined below.
func (d *DB) DBListProjects(ctx context.Context) ([]Project, error) {
query := `
SELECT id, name, api_key, is_active, deactivated_at, created_at, updated_at
FROM projects
ORDER BY name ASC
`
rows, err := d.QueryContextRebound(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to list projects: %w", err)
}
defer func() {
_ = rows.Close()
}()
var projects []Project
for rows.Next() {
var project Project
var deactivatedAt sql.NullTime
if err := rows.Scan(
&project.ID,
&project.Name,
&project.APIKey,
&project.IsActive,
&deactivatedAt,
&project.CreatedAt,
&project.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("failed to scan project: %w", err)
}
if deactivatedAt.Valid {
project.DeactivatedAt = &deactivatedAt.Time
}
projects = append(projects, project)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating projects: %w", err)
}
return projects, nil
}
func (d *DB) DBCreateProject(ctx context.Context, project Project) error {
query := `
INSERT INTO projects (id, name, api_key, is_active, deactivated_at, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?)
`
_, err := d.ExecContextRebound(
ctx,
query,
project.ID,
project.Name,
project.APIKey,
project.IsActive,
project.DeactivatedAt,
project.CreatedAt,
project.UpdatedAt,
)
if err != nil {
return fmt.Errorf("failed to create project: %w", err)
}
return nil
}
func (d *DB) DBGetProjectByID(ctx context.Context, projectID string) (Project, error) {
query := `
SELECT id, name, api_key, is_active, deactivated_at, created_at, updated_at
FROM projects
WHERE id = ?
`
var project Project
var deactivatedAt sql.NullTime
err := d.QueryRowContextRebound(ctx, query, projectID).Scan(
&project.ID,
&project.Name,
&project.APIKey,
&project.IsActive,
&deactivatedAt,
&project.CreatedAt,
&project.UpdatedAt,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return Project{}, ErrProjectNotFound
}
return Project{}, fmt.Errorf("failed to get project: %w", err)
}
if deactivatedAt.Valid {
project.DeactivatedAt = &deactivatedAt.Time
}
return project, nil
}
func (d *DB) DBUpdateProject(ctx context.Context, project Project) error {
project.UpdatedAt = time.Now().UTC()
query := `
UPDATE projects
SET name = ?, api_key = ?, is_active = ?, deactivated_at = ?, updated_at = ?
WHERE id = ?
`
result, err := d.ExecContextRebound(
ctx,
query,
project.Name,
project.APIKey,
project.IsActive,
project.DeactivatedAt,
project.UpdatedAt,
project.ID,
)
if err != nil {
return fmt.Errorf("failed to update project: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return ErrProjectNotFound
}
return nil
}
func (d *DB) DBDeleteProject(ctx context.Context, projectID string) error {
query := `
DELETE FROM projects
WHERE id = ?
`
result, err := d.ExecContextRebound(ctx, query, projectID)
if err != nil {
return fmt.Errorf("failed to delete project: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return ErrProjectNotFound
}
return nil
}
// --- proxy.ProjectStore interface adapters ---
func (d *DB) ListProjects(ctx context.Context) ([]proxy.Project, error) {
dbProjects, err := d.DBListProjects(ctx)
if err != nil {
return nil, err
}
var out []proxy.Project
for _, p := range dbProjects {
out = append(out, ToProxyProject(p))
}
return out, nil
}
func (d *DB) CreateProject(ctx context.Context, p proxy.Project) error {
return d.DBCreateProject(ctx, ToDBProject(p))
}
func (d *DB) GetProjectByID(ctx context.Context, id string) (proxy.Project, error) {
dbP, err := d.DBGetProjectByID(ctx, id)
if err != nil {
return proxy.Project{}, err
}
return ToProxyProject(dbP), nil
}
func (d *DB) UpdateProject(ctx context.Context, p proxy.Project) error {
return d.DBUpdateProject(ctx, ToDBProject(p))
}
func (d *DB) DeleteProject(ctx context.Context, id string) error {
return d.DBDeleteProject(ctx, id)
}
// GetAPIKeyForProject retrieves the API key for a project by ID
func (d *DB) GetAPIKeyForProject(ctx context.Context, projectID string) (string, error) {
query := `SELECT api_key FROM projects WHERE id = ?`
var apiKey string
err := d.QueryRowContextRebound(ctx, query, projectID).Scan(&apiKey)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return "", ErrProjectNotFound
}
return "", fmt.Errorf("failed to get API key for project: %w", err)
}
return apiKey, nil
}
// GetProjectActive retrieves the active status for a project by ID
func (d *DB) GetProjectActive(ctx context.Context, projectID string) (bool, error) {
query := `SELECT is_active FROM projects WHERE id = ?`
var isActive bool
err := d.QueryRowContextRebound(ctx, query, projectID).Scan(&isActive)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return false, ErrProjectNotFound
}
return false, fmt.Errorf("failed to get project active status: %w", err)
}
return isActive, nil
}
package database
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"github.com/sofatutor/llm-proxy/internal/obfuscate"
"github.com/sofatutor/llm-proxy/internal/token"
)
var (
// ErrTokenNotFound is returned when a token is not found.
ErrTokenNotFound = errors.New("token not found")
// ErrTokenExists is returned when a token already exists.
ErrTokenExists = errors.New("token already exists")
)
// CreateToken creates a new token in the database.
// If token.ID is empty, a UUID will be generated automatically.
func (d *DB) CreateToken(ctx context.Context, token Token) error {
// Auto-generate ID if not provided
if token.ID == "" {
token.ID = uuid.New().String()
}
query := `
INSERT INTO tokens (id, token, project_id, expires_at, is_active, deactivated_at, request_count, max_requests, created_at, last_used_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
_, err := d.ExecContextRebound(
ctx,
query,
token.ID,
token.Token,
token.ProjectID,
token.ExpiresAt,
token.IsActive,
nil, // deactivated_at is always NULL for newly created tokens
token.RequestCount,
token.MaxRequests,
token.CreatedAt,
token.LastUsedAt,
)
if err != nil {
return fmt.Errorf("failed to create token: %w", err)
}
return nil
}
// GetTokenByID retrieves a token by its UUID.
func (d *DB) GetTokenByID(ctx context.Context, id string) (Token, error) {
query := `
SELECT id, token, project_id, expires_at, is_active, deactivated_at, request_count, max_requests, created_at, last_used_at, cache_hit_count
FROM tokens
WHERE id = ?
`
var token Token
var expiresAt, lastUsedAt, deactivatedAt sql.NullTime
var maxRequests sql.NullInt32
err := d.QueryRowContextRebound(ctx, query, id).Scan(
&token.ID,
&token.Token,
&token.ProjectID,
&expiresAt,
&token.IsActive,
&deactivatedAt,
&token.RequestCount,
&maxRequests,
&token.CreatedAt,
&lastUsedAt,
&token.CacheHitCount,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return Token{}, ErrTokenNotFound
}
return Token{}, fmt.Errorf("failed to get token: %w", err)
}
if expiresAt.Valid {
token.ExpiresAt = &expiresAt.Time
}
if lastUsedAt.Valid {
token.LastUsedAt = &lastUsedAt.Time
}
if deactivatedAt.Valid {
token.DeactivatedAt = &deactivatedAt.Time
}
if maxRequests.Valid {
maxReq := int(maxRequests.Int32)
token.MaxRequests = &maxReq
}
return token, nil
}
// GetTokenByToken retrieves a token by its token string (for authentication).
func (d *DB) GetTokenByToken(ctx context.Context, tokenString string) (Token, error) {
query := `
SELECT id, token, project_id, expires_at, is_active, deactivated_at, request_count, max_requests, created_at, last_used_at, cache_hit_count
FROM tokens
WHERE token = ?
`
var token Token
var expiresAt, lastUsedAt, deactivatedAt sql.NullTime
var maxRequests sql.NullInt32
err := d.QueryRowContextRebound(ctx, query, tokenString).Scan(
&token.ID,
&token.Token,
&token.ProjectID,
&expiresAt,
&token.IsActive,
&deactivatedAt,
&token.RequestCount,
&maxRequests,
&token.CreatedAt,
&lastUsedAt,
&token.CacheHitCount,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return Token{}, ErrTokenNotFound
}
return Token{}, fmt.Errorf("failed to get token: %w", err)
}
if expiresAt.Valid {
token.ExpiresAt = &expiresAt.Time
}
if lastUsedAt.Valid {
token.LastUsedAt = &lastUsedAt.Time
}
if deactivatedAt.Valid {
token.DeactivatedAt = &deactivatedAt.Time
}
if maxRequests.Valid {
maxReq := int(maxRequests.Int32)
token.MaxRequests = &maxReq
}
return token, nil
}
// UpdateToken updates a token in the database.
// It looks up the token by ID (UUID). If ID is empty, it falls back to looking up by token string
// for backward compatibility.
func (d *DB) UpdateToken(ctx context.Context, token Token) error {
if token.ID == "" && token.Token == "" {
return fmt.Errorf("token ID or token string required for update")
}
queryByID := `
UPDATE tokens
SET project_id = ?, expires_at = ?, is_active = ?, request_count = ?, max_requests = ?, last_used_at = ?
WHERE id = ?
`
queryByToken := `
UPDATE tokens
SET project_id = ?, expires_at = ?, is_active = ?, request_count = ?, max_requests = ?, last_used_at = ?
WHERE token = ?
`
query := queryByID
lookupValue := token.ID
if token.ID == "" {
query = queryByToken
lookupValue = token.Token
}
result, err := d.ExecContextRebound(
ctx,
query,
token.ProjectID,
token.ExpiresAt,
token.IsActive,
token.RequestCount,
token.MaxRequests,
token.LastUsedAt,
lookupValue,
)
if err != nil {
return fmt.Errorf("failed to update token: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return ErrTokenNotFound
}
return nil
}
// DeleteToken deletes a token from the database. The tokenID argument is the token
// string (not the UUID), matching the WHERE token = ? lookup below.
func (d *DB) DeleteToken(ctx context.Context, tokenID string) error {
query := `
DELETE FROM tokens
WHERE token = ?
`
result, err := d.ExecContextRebound(ctx, query, tokenID)
if err != nil {
return fmt.Errorf("failed to delete token: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return ErrTokenNotFound
}
return nil
}
// ListTokens retrieves all tokens from the database.
func (d *DB) ListTokens(ctx context.Context) ([]Token, error) {
query := `
SELECT id, token, project_id, expires_at, is_active, deactivated_at, request_count, max_requests, created_at, last_used_at, cache_hit_count
FROM tokens
ORDER BY created_at DESC
`
return d.queryTokens(ctx, query)
}
// GetTokensByProjectID retrieves all tokens for a project.
func (d *DB) GetTokensByProjectID(ctx context.Context, projectID string) ([]Token, error) {
query := `
SELECT id, token, project_id, expires_at, is_active, deactivated_at, request_count, max_requests, created_at, last_used_at, cache_hit_count
FROM tokens
WHERE project_id = ?
ORDER BY created_at DESC
`
return d.queryTokens(ctx, query, projectID)
}
// IncrementTokenUsage increments the request count and updates the last_used_at timestamp.
func (d *DB) IncrementTokenUsage(ctx context.Context, tokenID string) error {
if tokenID == "" {
return fmt.Errorf("token string is required")
}
now := time.Now().UTC()
// Enforce max_requests atomically. max_requests is treated as unlimited when NULL/<=0
// (the API layer should normalize 0 to NULL, but we keep DB logic defensive).
query := `
UPDATE tokens
SET request_count = request_count + 1, last_used_at = ?
WHERE token = ?
AND is_active = TRUE
AND (expires_at IS NULL OR expires_at > ?)
AND (
max_requests IS NULL
OR max_requests <= 0
OR request_count < max_requests
)
`
result, err := d.ExecContextRebound(ctx, query, now, tokenID, now)
if err != nil {
return fmt.Errorf("failed to increment token usage: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
// No rows updated means either:
// - token doesn't exist
// - token is inactive
// - token is expired
// - token is already at quota
//
// We do a follow-up SELECT to return a semantically meaningful error.
var (
isActive bool
expiresAt sql.NullTime
requestCount int
maxRequests sql.NullInt32
)
checkQuery := `SELECT is_active, expires_at, request_count, max_requests FROM tokens WHERE token = ?`
err := d.QueryRowContextRebound(ctx, checkQuery, tokenID).Scan(&isActive, &expiresAt, &requestCount, &maxRequests)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return ErrTokenNotFound
}
return fmt.Errorf("failed to check token usage for %s: %w", obfuscate.ObfuscateTokenGeneric(tokenID), err)
}
if !isActive {
return token.ErrTokenInactive
}
if expiresAt.Valid {
exp := expiresAt.Time.UTC()
if now.After(exp) {
return token.ErrTokenExpired
}
}
if maxRequests.Valid && maxRequests.Int32 > 0 {
if requestCount >= int(maxRequests.Int32) {
return token.ErrTokenRateLimit
}
}
return fmt.Errorf("failed to increment token usage for %s: no rows updated", obfuscate.ObfuscateTokenGeneric(tokenID))
}
return nil
}
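// Example (caller sketch): mapping the sentinel errors returned above to HTTP-style
// outcomes. The status codes are illustrative, not mandated by this package.
//
//	switch err := db.IncrementTokenUsage(ctx, tok); {
//	case err == nil:
//		// proceed with the proxied request
//	case errors.Is(err, token.ErrTokenRateLimit):
//		// 429 Too Many Requests
//	case errors.Is(err, token.ErrTokenExpired), errors.Is(err, token.ErrTokenInactive):
//		// 401 Unauthorized
//	case errors.Is(err, ErrTokenNotFound):
//		// 401 Unauthorized (unknown token)
//	}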
// IncrementTokenUsageBatch increments request_count for multiple tokens and updates last_used_at.
// The token IDs are token strings (sk-...).
func (d *DB) IncrementTokenUsageBatch(ctx context.Context, deltas map[string]int, lastUsedAt time.Time) error {
if len(deltas) == 0 {
return nil
}
tx, err := d.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer func() {
_ = tx.Rollback() // No-op if already committed
}()
query := `UPDATE tokens SET request_count = request_count + ?, last_used_at = ? WHERE token = ?`
stmt, err := tx.PrepareContext(ctx, d.RebindQuery(query))
if err != nil {
return fmt.Errorf("failed to prepare statement: %w", err)
}
defer func() {
_ = stmt.Close()
}()
updated := int64(0)
missing := int64(0)
for tokenID, delta := range deltas {
if delta <= 0 {
continue
}
res, err := stmt.ExecContext(ctx, delta, lastUsedAt.UTC(), tokenID)
if err != nil {
return fmt.Errorf("failed to increment token usage for token %s: %w", obfuscate.ObfuscateTokenGeneric(tokenID), err)
}
rows, err := res.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected for token %s: %w", obfuscate.ObfuscateTokenGeneric(tokenID), err)
}
if rows == 0 {
// Tokens can be deleted/revoked while async events are buffered.
// Skipping missing tokens avoids discarding successful increments for other tokens.
missing++
continue
}
updated += rows
}
// If *all* requested updates were for missing tokens, surface ErrTokenNotFound to catch
// misconfiguration/bugs (e.g. wrong identifier type) without being noisy for partial misses.
if updated == 0 && missing > 0 {
return ErrTokenNotFound
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}
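// Example (flusher sketch): applying buffered per-token request counts in one
// transaction, e.g. from an async event pipeline.
//
//	deltas := map[string]int{"sk-aaaa": 3, "sk-bbbb": 1}
//	if err := db.IncrementTokenUsageBatch(ctx, deltas, time.Now()); err != nil {
//		// ErrTokenNotFound here means every token in the batch was missing
//	}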
// CleanExpiredTokens deletes expired tokens from the database.
func (d *DB) CleanExpiredTokens(ctx context.Context) (int64, error) {
now := time.Now().UTC()
query := `
DELETE FROM tokens
WHERE expires_at IS NOT NULL AND expires_at < ?
`
result, err := d.ExecContextRebound(ctx, query, now)
if err != nil {
return 0, fmt.Errorf("failed to clean expired tokens: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return 0, fmt.Errorf("failed to get rows affected: %w", err)
}
return rowsAffected, nil
}
// queryTokens executes a token SELECT and scans the result rows into Token values.
func (d *DB) queryTokens(ctx context.Context, query string, args ...interface{}) ([]Token, error) {
rows, err := d.QueryContextRebound(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("failed to query tokens: %w", err)
}
defer func() {
_ = rows.Close()
}()
var tokens []Token
for rows.Next() {
var token Token
var expiresAt, lastUsedAt, deactivatedAt sql.NullTime
var maxRequests sql.NullInt32
if err := rows.Scan(
&token.ID,
&token.Token,
&token.ProjectID,
&expiresAt,
&token.IsActive,
&deactivatedAt,
&token.RequestCount,
&maxRequests,
&token.CreatedAt,
&lastUsedAt,
&token.CacheHitCount,
); err != nil {
return nil, fmt.Errorf("failed to scan token: %w", err)
}
if expiresAt.Valid {
token.ExpiresAt = &expiresAt.Time
}
if lastUsedAt.Valid {
token.LastUsedAt = &lastUsedAt.Time
}
if deactivatedAt.Valid {
token.DeactivatedAt = &deactivatedAt.Time
}
if maxRequests.Valid {
maxReq := int(maxRequests.Int32)
token.MaxRequests = &maxReq
}
tokens = append(tokens, token)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating tokens: %w", err)
}
return tokens, nil
}
// --- token.TokenStore interface adapter for *DB ---
// DBTokenStoreAdapter adapts *DB to the token.TokenStore interface.
type DBTokenStoreAdapter struct {
db *DB
}
// NewDBTokenStoreAdapter creates a new DBTokenStoreAdapter backed by db.
func NewDBTokenStoreAdapter(db *DB) *DBTokenStoreAdapter {
return &DBTokenStoreAdapter{db: db}
}
func (a *DBTokenStoreAdapter) GetTokenByID(ctx context.Context, id string) (token.TokenData, error) {
dbToken, err := a.db.GetTokenByID(ctx, id)
if err != nil {
if errors.Is(err, ErrTokenNotFound) {
return token.TokenData{}, token.ErrTokenNotFound
}
return token.TokenData{}, err
}
return ExportTokenData(dbToken), nil
}
func (a *DBTokenStoreAdapter) GetTokenByToken(ctx context.Context, tokenString string) (token.TokenData, error) {
dbToken, err := a.db.GetTokenByToken(ctx, tokenString)
if err != nil {
if errors.Is(err, ErrTokenNotFound) {
return token.TokenData{}, token.ErrTokenNotFound
}
return token.TokenData{}, err
}
return ExportTokenData(dbToken), nil
}
func (a *DBTokenStoreAdapter) IncrementTokenUsage(ctx context.Context, tokenID string) error {
return a.db.IncrementTokenUsage(ctx, tokenID)
}
func (a *DBTokenStoreAdapter) CreateToken(ctx context.Context, td token.TokenData) error {
dbToken := ImportTokenData(td)
return a.db.CreateToken(ctx, dbToken)
}
func (a *DBTokenStoreAdapter) UpdateToken(ctx context.Context, td token.TokenData) error {
dbToken := ImportTokenData(td)
return a.db.UpdateToken(ctx, dbToken)
}
func (a *DBTokenStoreAdapter) ListTokens(ctx context.Context) ([]token.TokenData, error) {
dbTokens, err := a.db.ListTokens(ctx)
if err != nil {
return nil, err
}
tokens := make([]token.TokenData, len(dbTokens))
for i, t := range dbTokens {
tokens[i] = ExportTokenData(t)
}
return tokens, nil
}
func (a *DBTokenStoreAdapter) GetTokensByProjectID(ctx context.Context, projectID string) ([]token.TokenData, error) {
dbTokens, err := a.db.GetTokensByProjectID(ctx, projectID)
if err != nil {
return nil, err
}
tokens := make([]token.TokenData, len(dbTokens))
for i, t := range dbTokens {
tokens[i] = ExportTokenData(t)
}
return tokens, nil
}
// ImportTokenData converts a token.TokenData into a database Token.
func ImportTokenData(td token.TokenData) Token {
return Token{
ID: td.ID,
Token: td.Token,
ProjectID: td.ProjectID,
ExpiresAt: td.ExpiresAt,
IsActive: td.IsActive,
DeactivatedAt: td.DeactivatedAt,
RequestCount: td.RequestCount,
MaxRequests: td.MaxRequests,
CreatedAt: td.CreatedAt,
LastUsedAt: td.LastUsedAt,
CacheHitCount: td.CacheHitCount,
}
}
// ExportTokenData converts a database Token into a token.TokenData.
func ExportTokenData(t Token) token.TokenData {
return token.TokenData{
ID: t.ID,
Token: t.Token,
ProjectID: t.ProjectID,
ExpiresAt: t.ExpiresAt,
IsActive: t.IsActive,
DeactivatedAt: t.DeactivatedAt,
RequestCount: t.RequestCount,
MaxRequests: t.MaxRequests,
CreatedAt: t.CreatedAt,
LastUsedAt: t.LastUsedAt,
CacheHitCount: t.CacheHitCount,
}
}
// --- RevocationStore interface implementation ---
// RevokeToken disables a token by setting is_active to false and deactivated_at to current time
func (a *DBTokenStoreAdapter) RevokeToken(ctx context.Context, tokenID string) error {
if tokenID == "" {
return token.ErrTokenNotFound
}
now := time.Now().UTC()
query := `UPDATE tokens SET is_active = ?, deactivated_at = COALESCE(deactivated_at, ?) WHERE token = ? AND is_active = ?`
result, err := a.db.ExecContextRebound(ctx, query, false, now, tokenID, true)
if err != nil {
return fmt.Errorf("failed to revoke token: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to check rows affected: %w", err)
}
if rowsAffected == 0 {
// Check if token exists at all (could be already inactive)
var exists bool
checkQuery := `SELECT 1 FROM tokens WHERE token = ? LIMIT 1`
err = a.db.QueryRowContextRebound(ctx, checkQuery, tokenID).Scan(&exists)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return token.ErrTokenNotFound
}
return fmt.Errorf("failed to check token existence: %w", err)
}
// Token exists but was already inactive - this is idempotent, no error
}
return nil
}
// DeleteToken completely removes a token from storage
func (a *DBTokenStoreAdapter) DeleteToken(ctx context.Context, tokenID string) error {
if tokenID == "" {
return token.ErrTokenNotFound
}
result, err := a.db.ExecContextRebound(ctx, "DELETE FROM tokens WHERE token = ?", tokenID)
if err != nil {
return fmt.Errorf("failed to delete token: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return token.ErrTokenNotFound
}
return nil
}
// RevokeBatchTokens revokes multiple tokens at once
func (a *DBTokenStoreAdapter) RevokeBatchTokens(ctx context.Context, tokenIDs []string) (int, error) {
if len(tokenIDs) == 0 {
return 0, nil
}
now := time.Now().UTC()
placeholders := make([]string, len(tokenIDs))
args := make([]interface{}, len(tokenIDs)+3)
args[0] = false
args[1] = now
for i, tokenID := range tokenIDs {
placeholders[i] = "?"
args[i+2] = tokenID
}
// Append active-state filter parameter
args[len(args)-1] = true
query := fmt.Sprintf(`UPDATE tokens SET is_active = ?, deactivated_at = COALESCE(deactivated_at, ?) WHERE token IN (%s) AND is_active = ?`, strings.Join(placeholders, ","))
result, err := a.db.ExecContextRebound(ctx, query, args...)
if err != nil {
return 0, fmt.Errorf("failed to revoke batch tokens: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return 0, fmt.Errorf("failed to get rows affected: %w", err)
}
return int(rowsAffected), nil
}
// RevokeProjectTokens revokes all tokens for a project
func (a *DBTokenStoreAdapter) RevokeProjectTokens(ctx context.Context, projectID string) (int, error) {
if projectID == "" {
return 0, nil
}
now := time.Now().UTC()
query := `UPDATE tokens SET is_active = ?, deactivated_at = COALESCE(deactivated_at, ?) WHERE project_id = ? AND is_active = ?`
result, err := a.db.ExecContextRebound(ctx, query, false, now, projectID, true)
if err != nil {
return 0, fmt.Errorf("failed to revoke project tokens: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return 0, fmt.Errorf("failed to get rows affected: %w", err)
}
return int(rowsAffected), nil
}
// RevokeExpiredTokens revokes all tokens that have expired
func (a *DBTokenStoreAdapter) RevokeExpiredTokens(ctx context.Context) (int, error) {
now := time.Now().UTC()
query := `UPDATE tokens SET is_active = ?, deactivated_at = COALESCE(deactivated_at, ?) WHERE expires_at IS NOT NULL AND expires_at < ? AND is_active = ?`
result, err := a.db.ExecContextRebound(ctx, query, false, now, now, true)
if err != nil {
return 0, fmt.Errorf("failed to revoke expired tokens: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return 0, fmt.Errorf("failed to get rows affected: %w", err)
}
return int(rowsAffected), nil
}
// IncrementCacheHitCount increments the cache_hit_count for a single token.
func (d *DB) IncrementCacheHitCount(ctx context.Context, tokenID string, delta int) error {
if delta <= 0 {
return nil
}
query := `UPDATE tokens SET cache_hit_count = cache_hit_count + ? WHERE token = ?`
_, err := d.ExecContextRebound(ctx, query, delta, tokenID)
if err != nil {
return fmt.Errorf("failed to increment cache hit count: %w", err)
}
return nil
}
// IncrementCacheHitCountBatch increments cache_hit_count for multiple tokens in batch.
// The deltas map is keyed by token string (sk-...), with increment values as values.
func (d *DB) IncrementCacheHitCountBatch(ctx context.Context, deltas map[string]int) error {
if len(deltas) == 0 {
return nil
}
// Use a transaction for batch updates
tx, err := d.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer func() {
_ = tx.Rollback() // No-op if already committed
}()
query := `UPDATE tokens SET cache_hit_count = cache_hit_count + ? WHERE token = ?`
stmt, err := tx.PrepareContext(ctx, d.RebindQuery(query))
if err != nil {
return fmt.Errorf("failed to prepare statement: %w", err)
}
defer func() {
_ = stmt.Close()
}()
for tokenID, delta := range deltas {
if delta <= 0 {
continue
}
if _, err := stmt.ExecContext(ctx, delta, tokenID); err != nil {
return fmt.Errorf("failed to increment cache hit count for token %s: %w", obfuscate.ObfuscateTokenGeneric(tokenID), err)
}
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
return nil
}
package database
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
)
// Placeholder returns the appropriate placeholder for the driver.
// For SQLite and MySQL: ?, for PostgreSQL: $1, $2, etc.
func (d *DB) Placeholder(n int) string {
if d.driver == DriverPostgres {
return fmt.Sprintf("$%d", n)
}
return "?"
}
// Placeholders returns a slice of placeholders for the driver.
// For n=3: SQLite/MySQL return ["?", "?", "?"], PostgreSQL returns ["$1", "$2", "$3"].
func (d *DB) Placeholders(n int) []string {
result := make([]string, n)
for i := 0; i < n; i++ {
result[i] = d.Placeholder(i + 1)
}
return result
}
// PlaceholderList returns a comma-separated list of placeholders.
// For n=3: SQLite/MySQL return "?, ?, ?", PostgreSQL returns "$1, $2, $3".
func (d *DB) PlaceholderList(n int) string {
return strings.Join(d.Placeholders(n), ", ")
}
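// Example: building an IN clause that works across drivers.
//
//	query := fmt.Sprintf("SELECT id FROM tokens WHERE token IN (%s)", d.PlaceholderList(3))
//	// SQLite/MySQL: "... IN (?, ?, ?)"   PostgreSQL: "... IN ($1, $2, $3)"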
// RebindQuery converts a query from ? placeholders to the appropriate
// placeholder style for the database driver.
//
// IMPORTANT: This function performs a simple character replacement and does NOT
// handle ? characters inside SQL string literals (e.g., "WHERE name = 'what?'").
// Since this codebase exclusively uses parameterized queries with ? as placeholders,
// this limitation does not affect normal usage. If you need to use literal ? in
// string values, use parameterized queries: "WHERE name = ?" with the value passed
// as an argument.
func (d *DB) RebindQuery(query string) string {
if d.driver != DriverPostgres {
return query
}
// Convert ? to $1, $2, $3, etc. (single pass for better performance)
var builder strings.Builder
builder.Grow(len(query) + 10) // pre-allocate with some buffer
count := 0
for i := 0; i < len(query); i++ {
if query[i] == '?' {
count++
builder.WriteString(fmt.Sprintf("$%d", count))
} else {
builder.WriteByte(query[i])
}
}
return builder.String()
}
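// Example: with a PostgreSQL driver,
//
//	d.RebindQuery("UPDATE tokens SET request_count = request_count + ? WHERE token = ?")
//
// returns "UPDATE tokens SET request_count = request_count + $1 WHERE token = $2";
// for SQLite and MySQL the query is returned unchanged.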
// ExecContextRebound executes a query with automatic placeholder rebinding.
func (d *DB) ExecContextRebound(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
return d.db.ExecContext(ctx, d.RebindQuery(query), args...)
}
// QueryRowContextRebound queries a single row with automatic placeholder rebinding.
func (d *DB) QueryRowContextRebound(ctx context.Context, query string, args ...interface{}) *sql.Row {
return d.db.QueryRowContext(ctx, d.RebindQuery(query), args...)
}
// QueryContextRebound queries multiple rows with automatic placeholder rebinding.
func (d *DB) QueryContextRebound(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
return d.db.QueryContext(ctx, d.RebindQuery(query), args...)
}
// BackupDatabase creates a backup of the database.
// Note: This function is SQLite-specific. For PostgreSQL, use pg_dump. For MySQL, use mysqldump.
func (d *DB) BackupDatabase(ctx context.Context, backupPath string) error {
if d.driver == DriverPostgres {
return fmt.Errorf("backup not supported for PostgreSQL via this method; use pg_dump")
}
if d.driver == DriverMySQL {
return fmt.Errorf("backup not supported for MySQL via this method; use mysqldump")
}
// Validate the backupPath to ensure it is a valid file path
if backupPath == "" {
return fmt.Errorf("backup path cannot be empty")
}
// SQLite does not support parameterized VACUUM INTO, so the path is embedded in the
// SQL string below. Reject anything that could escape the quoted literal or chain
// statements (quotes, semicolons), as well as suspicious leading characters.
if len(backupPath) > 256 || strings.ContainsAny(backupPath, "';") || backupPath[0] == '-' || backupPath[0] == '|' {
return fmt.Errorf("invalid backup path")
}
// For SQLite, we can use the VACUUM INTO statement to create a backup
query := fmt.Sprintf("VACUUM INTO '%s'", backupPath)
_, err := d.db.ExecContext(ctx, query)
if err != nil {
return fmt.Errorf("failed to backup database: %w", err)
}
return nil
}
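// Example (sketch; the destination path is illustrative): a periodic SQLite backup.
//
//	if err := db.BackupDatabase(ctx, "/var/backups/llm-proxy.sqlite"); err != nil {
//		log.Printf("backup failed: %v", err)
//	}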
// MaintainDatabase performs regular maintenance on the database.
// WARNING: VACUUM and ANALYZE can be expensive operations. In production, schedule this function to run periodically (e.g., daily) rather than on every call.
// The caller is responsible for scheduling.
func (d *DB) MaintainDatabase(ctx context.Context) error {
if d.driver == DriverPostgres {
// PostgreSQL uses VACUUM ANALYZE
_, err := d.db.ExecContext(ctx, "VACUUM ANALYZE")
if err != nil {
return fmt.Errorf("failed to vacuum analyze database: %w", err)
}
return nil
}
if d.driver == DriverMySQL {
// MySQL maintenance: Run ANALYZE on all tables in the current schema to update optimizer statistics.
// Query information_schema to get all table names dynamically to avoid schema drift.
rows, err := d.db.QueryContext(ctx, "SELECT table_name FROM information_schema.tables WHERE table_schema = DATABASE() AND table_type = 'BASE TABLE'")
if err != nil {
return fmt.Errorf("failed to query table names: %w", err)
}
defer func() { _ = rows.Close() }()
var tables []string
for rows.Next() {
var tableName string
if err := rows.Scan(&tableName); err != nil {
return fmt.Errorf("failed to scan table name: %w", err)
}
tables = append(tables, tableName)
}
if err := rows.Err(); err != nil {
return fmt.Errorf("error iterating table names: %w", err)
}
// Run ANALYZE TABLE for each table
for _, table := range tables {
// Table names come from information_schema, not user input. MySQL does not support
// placeholders for identifiers, so the name is interpolated directly.
query := fmt.Sprintf("ANALYZE TABLE %s", table)
if _, err := d.db.ExecContext(ctx, query); err != nil {
return fmt.Errorf("failed to analyze table %s: %w", table, err)
}
}
return nil
}
// SQLite-specific maintenance
// Run VACUUM to reclaim space and optimize the database
_, err := d.db.ExecContext(ctx, "VACUUM")
if err != nil {
return fmt.Errorf("failed to vacuum database: %w", err)
}
// Run PRAGMA optimize to optimize the database
_, err = d.db.ExecContext(ctx, "PRAGMA optimize")
if err != nil {
return fmt.Errorf("failed to optimize database: %w", err)
}
// Run ANALYZE to update statistics
_, err = d.db.ExecContext(ctx, "ANALYZE")
if err != nil {
return fmt.Errorf("failed to analyze database: %w", err)
}
return nil
}
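// Scheduling sketch (illustrative only): since the caller owns the schedule,
// a daily ticker in a background goroutine is enough. The names db and ctx
// are assumed to come from the surrounding application.
//
//	go func() {
//	    ticker := time.NewTicker(24 * time.Hour)
//	    defer ticker.Stop()
//	    for {
//	        select {
//	        case <-ticker.C:
//	            if err := db.MaintainDatabase(ctx); err != nil {
//	                log.Printf("db maintenance failed: %v", err)
//	            }
//	        case <-ctx.Done():
//	            return
//	        }
//	    }
//	}()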
// boolValue returns the appropriate boolean representation for the driver.
// SQLite uses 1/0, PostgreSQL and MySQL use true/false.
func (d *DB) boolValue(b bool) interface{} {
if d.driver == DriverPostgres || d.driver == DriverMySQL {
return b
}
if b {
return 1
}
return 0
}
// GetStats returns database statistics.
func (d *DB) GetStats(ctx context.Context) (map[string]interface{}, error) {
stats := make(map[string]interface{})
// Get database size (driver-specific - these queries are fundamentally different)
var dbSize int64
switch d.driver {
case DriverPostgres:
err := d.db.QueryRowContext(ctx, "SELECT pg_database_size(current_database())").Scan(&dbSize)
if err != nil {
return nil, fmt.Errorf("failed to get database size: %w", err)
}
case DriverMySQL:
err := d.db.QueryRowContext(ctx, "SELECT COALESCE(SUM(data_length + index_length), 0) FROM information_schema.tables WHERE table_schema = DATABASE()").Scan(&dbSize)
if err != nil {
return nil, fmt.Errorf("failed to get database size: %w", err)
}
default:
err := d.db.QueryRowContext(ctx, "SELECT (SELECT page_count FROM pragma_page_count) * (SELECT page_size FROM pragma_page_size)").Scan(&dbSize)
if err != nil {
return nil, fmt.Errorf("failed to get database size: %w", err)
}
}
stats["database_size_bytes"] = dbSize
// Count active projects
var projectCount int
err := d.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM projects").Scan(&projectCount)
if err != nil {
return nil, fmt.Errorf("failed to count projects: %w", err)
}
stats["project_count"] = projectCount
// Count active tokens using QueryRowContextRebound for placeholder rebinding
var activeTokens int
err = d.QueryRowContextRebound(ctx,
"SELECT COUNT(*) FROM tokens WHERE is_active = ? AND (expires_at IS NULL OR expires_at > ?)",
d.boolValue(true), time.Now().UTC()).Scan(&activeTokens)
if err != nil {
return nil, fmt.Errorf("failed to count active tokens: %w", err)
}
stats["active_token_count"] = activeTokens
// Count expired tokens using QueryRowContextRebound for placeholder rebinding
var expiredTokens int
err = d.QueryRowContextRebound(ctx,
"SELECT COUNT(*) FROM tokens WHERE expires_at IS NOT NULL AND expires_at <= ?",
time.Now().UTC()).Scan(&expiredTokens)
if err != nil {
return nil, fmt.Errorf("failed to count expired tokens: %w", err)
}
stats["expired_token_count"] = expiredTokens
// Count total request count
var totalRequests sql.NullInt64
err = d.db.QueryRowContext(ctx, "SELECT SUM(request_count) FROM tokens").Scan(&totalRequests)
if err != nil {
return nil, fmt.Errorf("failed to sum request counts: %w", err)
}
if totalRequests.Valid {
stats["total_request_count"] = totalRequests.Int64
} else {
stats["total_request_count"] = int64(0)
}
return stats, nil
}
// IsTokenValid checks if a token is valid (exists, is active, not expired, and not rate limited).
func (d *DB) IsTokenValid(ctx context.Context, tokenID string) (bool, error) {
token, err := d.GetTokenByID(ctx, tokenID)
if err != nil {
if err == ErrTokenNotFound {
return false, nil
}
return false, err
}
return token.IsValid(), nil
}
package dispatcher
// PermanentBackendError is a custom error type for permanent backend errors (e.g., Helicone 500s that should not be retried)
type PermanentBackendError struct {
Msg string
}
func (e *PermanentBackendError) Error() string {
return e.Msg
}
package plugins
import (
"context"
"encoding/json"
"fmt"
"os"
"github.com/sofatutor/llm-proxy/internal/dispatcher"
)
// FilePlugin implements file-based event logging
type FilePlugin struct {
filePath string
file *os.File
}
// NewFilePlugin creates a new file plugin
func NewFilePlugin() *FilePlugin {
return &FilePlugin{}
}
// Init initializes the file plugin with configuration
func (p *FilePlugin) Init(cfg map[string]string) error {
filePath, ok := cfg["endpoint"]
if !ok || filePath == "" {
return fmt.Errorf("file plugin requires 'endpoint' configuration (file path)")
}
p.filePath = filePath
// Open file for writing
file, err := os.OpenFile(filePath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
if err != nil {
return fmt.Errorf("failed to open file %s: %w", filePath, err)
}
p.file = file
return nil
}
// SendEvents writes events to the file as JSONL (JSON Lines)
func (p *FilePlugin) SendEvents(ctx context.Context, events []dispatcher.EventPayload) error {
if p.file == nil {
return fmt.Errorf("file plugin not initialized")
}
for _, event := range events {
line, err := json.Marshal(event)
if err != nil {
return fmt.Errorf("failed to marshal event: %w", err)
}
// Write JSON line with newline
if _, err := p.file.Write(append(line, '\n')); err != nil {
return fmt.Errorf("failed to write to file: %w", err)
}
}
// Ensure data is written to disk
if err := p.file.Sync(); err != nil {
return fmt.Errorf("failed to sync file: %w", err)
}
return nil
}
// Close closes the file
func (p *FilePlugin) Close() error {
if p.file != nil {
return p.file.Close()
}
return nil
}
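// Configuration sketch (illustrative only): the file plugin is wired through
// the generic plugin config map, with "endpoint" holding the output path.
// The path shown is a placeholder, and events/ctx are assumed from the caller.
//
//	p := NewFilePlugin()
//	if err := p.Init(map[string]string{"endpoint": "/var/log/llm-proxy/events.jsonl"}); err != nil {
//	    log.Fatal(err)
//	}
//	defer func() { _ = p.Close() }()
//	// Each call appends one JSON line per event and syncs the file to disk.
//	_ = p.SendEvents(ctx, events)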
package plugins
import (
"bytes"
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"time"
"github.com/sofatutor/llm-proxy/internal/dispatcher"
)
// HeliconePlugin implements Helicone backend integration
type HeliconePlugin struct {
apiKey string
endpoint string
client *http.Client
}
// NewHeliconePlugin creates a new Helicone plugin
func NewHeliconePlugin() *HeliconePlugin {
return &HeliconePlugin{
client: &http.Client{
Timeout: 30 * time.Second,
},
}
}
// Init initializes the Helicone plugin with configuration
func (p *HeliconePlugin) Init(cfg map[string]string) error {
apiKey, ok := cfg["api-key"]
if !ok || apiKey == "" {
return fmt.Errorf("helicone plugin requires 'api-key' configuration")
}
endpoint, ok := cfg["endpoint"]
if !ok || endpoint == "" {
endpoint = "https://api.worker.helicone.ai/custom/v1/log"
}
p.apiKey = apiKey
p.endpoint = endpoint
return nil
}
// SendEvents sends events to Helicone
func (p *HeliconePlugin) SendEvents(ctx context.Context, events []dispatcher.EventPayload) error {
if len(events) == 0 {
return nil
}
for _, event := range events {
// Skip events with empty output
if len(event.Output) == 0 && event.OutputBase64 == "" {
log.Printf("[helicone] Skipping event with empty output: RunID=%s, Path=%v", event.RunID, event.Metadata["path"])
continue
}
payload, err := heliconePayloadFromEvent(event)
if err != nil {
return err
}
if err := p.sendHeliconeEvent(ctx, payload); err != nil {
// Print payload for debugging on error
log.Printf("[helicone] Error sending event. Payload: %s", mustMarshalJSON(payload))
return err
}
}
return nil
}
// heliconePayloadFromEvent maps EventPayload to Helicone manual logger format
func heliconePayloadFromEvent(event dispatcher.EventPayload) (map[string]interface{}, error) {
// Extract request and response bodies
var reqBody map[string]interface{}
var respBody map[string]interface{}
isJSON := false
if len(event.Input) > 0 {
_ = json.Unmarshal(event.Input, &reqBody)
}
if reqBody == nil {
reqBody = map[string]interface{}{}
}
if len(event.Output) > 0 {
if err := json.Unmarshal(event.Output, &respBody); err == nil {
isJSON = true
}
}
// Timing (use event.Timestamp for both if no better info)
timestamp := event.Timestamp
sec := timestamp.Unix()
ms := timestamp.Nanosecond() / 1e6
timing := map[string]interface{}{
"startTime": map[string]int64{"seconds": sec, "milliseconds": int64(ms)},
"endTime": map[string]int64{"seconds": sec, "milliseconds": int64(ms)},
}
// Meta
meta := map[string]string{}
if event.UserID != nil {
meta["Helicone-User-Id"] = *event.UserID
}
if event.Extra != nil {
for k, v := range event.Extra {
if s, ok := v.(string); ok {
meta[k] = s
}
}
}
if event.Metadata != nil {
for k, v := range event.Metadata {
if s, ok := v.(string); ok {
meta[k] = s
}
}
// Ensure request_id is propagated if present in metadata
if rid, ok := event.Metadata["request_id"].(string); ok && rid != "" {
meta["request_id"] = rid
}
}
// Help Helicone infer pricing/provider when using manual logger
if _, ok := meta["provider"]; !ok {
meta["provider"] = "openai"
}
// ProviderRequest
// Prefer actual API path from metadata if present so Helicone can infer provider
urlPath := "custom-model-nopath"
if event.Metadata != nil {
if pth, ok := event.Metadata["path"].(string); ok && pth != "" {
urlPath = pth
}
}
providerRequest := map[string]interface{}{
"url": urlPath,
"json": reqBody,
"meta": meta,
}
// ProviderResponse
status := 200
if s, ok := event.Metadata["status"].(int); ok {
status = s
}
providerResponse := map[string]interface{}{
"status": status,
"headers": map[string]string{}, // Optionally fill from event.Metadata
}
// Ensure Helicone sees the provider as non-CUSTOM when using manual logger
if _, ok := meta["Helicone-Provider"]; !ok {
// Mirror meta["provider"] if set, otherwise default to openai
prov := meta["provider"]
if prov == "" {
prov = "openai"
}
meta["Helicone-Provider"] = prov
}
// Helicone requires providerResponse.json to be present; ensure it's always an object
if isJSON && respBody != nil {
providerResponse["json"] = respBody
} else {
providerResponse["json"] = map[string]interface{}{}
providerResponse["note"] = "response was not JSON; omitted"
}
// Inject usage if we computed it so Helicone can compute cost
if event.TokensUsage != nil {
if m, ok := providerResponse["json"].(map[string]interface{}); ok {
m["usage"] = map[string]int{
"prompt_tokens": event.TokensUsage.Prompt,
"completion_tokens": event.TokensUsage.Completion,
"total_tokens": event.TokensUsage.Prompt + event.TokensUsage.Completion,
}
}
}
// If OutputBase64 is set, include as base64
if event.OutputBase64 != "" {
providerResponse["base64"] = event.OutputBase64
}
return map[string]interface{}{
"providerRequest": providerRequest,
"providerResponse": providerResponse,
"timing": timing,
}, nil
}
// sendHeliconeEvent sends a single event to Helicone manual logger endpoint
func (p *HeliconePlugin) sendHeliconeEvent(ctx context.Context, payload map[string]interface{}) error {
data, err := json.Marshal(payload)
if err != nil {
return err
}
req, err := http.NewRequestWithContext(ctx, "POST", p.endpoint, bytes.NewReader(data))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+p.apiKey)
resp, err := p.client.Do(req)
if err != nil {
return err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode == 500 {
buf := new(bytes.Buffer)
_, _ = buf.ReadFrom(resp.Body)
log.Printf("Helicone API error %d: %s", resp.StatusCode, buf.String())
return &dispatcher.PermanentBackendError{Msg: fmt.Sprintf("helicone API returned status 500: %s", buf.String())}
}
if resp.StatusCode == http.StatusBadRequest { // 400: treat as permanent (payload/schema issue)
buf := new(bytes.Buffer)
_, _ = buf.ReadFrom(resp.Body)
log.Printf("Helicone API error %d: %s", resp.StatusCode, buf.String())
return &dispatcher.PermanentBackendError{Msg: fmt.Sprintf("helicone API returned status 400: %s", buf.String())}
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
buf := new(bytes.Buffer)
_, _ = buf.ReadFrom(resp.Body)
log.Printf("Helicone API error %d: %s", resp.StatusCode, buf.String())
return fmt.Errorf("helicone API returned status %d", resp.StatusCode)
}
return nil
}
// Close cleans up the plugin resources
func (p *HeliconePlugin) Close() error {
// Nothing to clean up for HTTP client
return nil
}
// mustMarshalJSON marshals v to JSON or returns an error string
func mustMarshalJSON(v interface{}) string {
b, err := json.Marshal(v)
if err != nil {
return "<marshal error>"
}
return string(b)
}
package plugins
import (
"bytes"
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"time"
"github.com/sofatutor/llm-proxy/internal/dispatcher"
)
// LunaryPlugin implements Lunary.ai backend integration
type LunaryPlugin struct {
apiKey string
endpoint string
client *http.Client
}
// NewLunaryPlugin creates a new Lunary plugin
func NewLunaryPlugin() *LunaryPlugin {
return &LunaryPlugin{
client: &http.Client{
Timeout: 30 * time.Second,
},
}
}
// Init initializes the Lunary plugin with configuration
func (p *LunaryPlugin) Init(cfg map[string]string) error {
apiKey, ok := cfg["api-key"]
if !ok || apiKey == "" {
return fmt.Errorf("lunary plugin requires 'api-key' configuration")
}
endpoint, ok := cfg["endpoint"]
if !ok || endpoint == "" {
endpoint = "https://api.lunary.ai/v1/runs/ingest"
}
p.apiKey = apiKey
p.endpoint = endpoint
return nil
}
// SendEvents sends events to Lunary.ai
func (p *LunaryPlugin) SendEvents(ctx context.Context, events []dispatcher.EventPayload) error {
if len(events) == 0 {
return nil
}
// Lunary expects an array of events
data, err := json.Marshal(events)
if err != nil {
return fmt.Errorf("failed to marshal events: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", p.endpoint, bytes.NewReader(data))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+p.apiKey)
resp, err := p.client.Do(req)
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
}
defer func() {
if err := resp.Body.Close(); err != nil {
log.Printf("[lunary] failed to close response body: %v", err)
}
}()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return fmt.Errorf("lunary API returned status %d", resp.StatusCode)
}
return nil
}
// Close cleans up the plugin resources
func (p *LunaryPlugin) Close() error {
// Nothing to clean up for HTTP client
return nil
}
package plugins
import (
"fmt"
"github.com/sofatutor/llm-proxy/internal/dispatcher"
)
// PluginFactory is a function that creates a new plugin instance
type PluginFactory func() dispatcher.BackendPlugin
// Registry holds all available plugin factories
var Registry = make(map[string]PluginFactory)
// init registers all built-in plugins
func init() {
Registry["file"] = func() dispatcher.BackendPlugin {
return NewFilePlugin()
}
Registry["lunary"] = func() dispatcher.BackendPlugin {
return NewLunaryPlugin()
}
Registry["helicone"] = func() dispatcher.BackendPlugin {
return NewHeliconePlugin()
}
}
// NewPlugin creates a new plugin instance by name
func NewPlugin(name string) (dispatcher.BackendPlugin, error) {
factory, exists := Registry[name]
if !exists {
return nil, fmt.Errorf("unknown plugin: %s", name)
}
return factory(), nil
}
// ListPlugins returns a list of available plugin names
func ListPlugins() []string {
var names []string
for name := range Registry {
names = append(names, name)
}
return names
}
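// Lookup sketch (illustrative only): resolve a plugin by name, initialize it,
// and hand it to the dispatcher. The config values here are assumptions, not
// required settings.
//
//	plugin, err := NewPlugin("helicone") // see ListPlugins() for valid names
//	if err != nil {
//	    log.Fatal(err)
//	}
//	if err := plugin.Init(map[string]string{"api-key": os.Getenv("HELICONE_API_KEY")}); err != nil {
//	    log.Fatal(err)
//	}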
package dispatcher
import (
"context"
"fmt"
"os"
"os/signal"
"sync"
"syscall"
"time"
"github.com/sofatutor/llm-proxy/internal/eventbus"
"go.uber.org/zap"
)
const (
// maxBackoffDuration is the maximum backoff duration for retries
maxBackoffDuration = 30 * time.Second
// maxHealthyLagCount is the maximum number of pending messages before the dispatcher is considered unhealthy
maxHealthyLagCount = 10000
// maxInactivityDuration is the maximum time without processing activity before the dispatcher is considered unhealthy
maxInactivityDuration = 5 * time.Minute
// metricsUpdateInterval is the interval at which metrics are updated
metricsUpdateInterval = 10 * time.Second
)
// Config holds configuration for the dispatcher service
type Config struct {
BufferSize int
BatchSize int
FlushInterval time.Duration
RetryAttempts int
RetryBackoff time.Duration
Plugin BackendPlugin
EventTransformer EventTransformer
PluginName string
Verbose bool // If true, include response_headers and extra debug info
}
// Service represents the event dispatcher service
type Service struct {
config Config
eventBus eventbus.EventBus
logger *zap.Logger
stopCh chan struct{}
wg sync.WaitGroup
stopOnce sync.Once
// startedCh is closed after the event processing goroutine has been added
// to the WaitGroup. This avoids a data race between Wait() and Add(1).
startedCh chan struct{}
// metrics
mu sync.Mutex
eventsProcessed int64
eventsDropped int64
eventsSent int64
lastProcessedAt time.Time
processingRate float64 // events per second
lagCount int64 // current lag (pending messages)
streamLength int64 // total messages in stream
}
// NewService creates a new dispatcher service
func NewService(cfg Config, logger *zap.Logger) (*Service, error) {
if cfg.Plugin == nil {
return nil, fmt.Errorf("backend plugin is required")
}
if cfg.EventTransformer == nil {
cfg.EventTransformer = NewDefaultEventTransformer(cfg.Verbose)
}
if cfg.BufferSize <= 0 {
cfg.BufferSize = 1000
}
if cfg.BatchSize <= 0 {
cfg.BatchSize = 100
}
if cfg.FlushInterval <= 0 {
cfg.FlushInterval = 5 * time.Second
}
if cfg.RetryAttempts <= 0 {
cfg.RetryAttempts = 3
}
if cfg.RetryBackoff <= 0 {
cfg.RetryBackoff = time.Second
}
if logger == nil {
logger = zap.NewNop()
}
// Create event bus for the dispatcher
bus := eventbus.NewInMemoryEventBus(cfg.BufferSize)
return &Service{
config: cfg,
eventBus: bus,
logger: logger,
stopCh: make(chan struct{}),
startedCh: make(chan struct{}),
}, nil
}
// NewServiceWithBus creates a new dispatcher service with a provided event bus.
func NewServiceWithBus(cfg Config, logger *zap.Logger, bus eventbus.EventBus) (*Service, error) {
if cfg.Plugin == nil {
return nil, fmt.Errorf("backend plugin is required")
}
if cfg.EventTransformer == nil {
cfg.EventTransformer = NewDefaultEventTransformer(cfg.Verbose)
}
if cfg.BufferSize <= 0 {
cfg.BufferSize = 1000
}
if cfg.BatchSize <= 0 {
cfg.BatchSize = 100
}
if cfg.FlushInterval <= 0 {
cfg.FlushInterval = 5 * time.Second
}
if cfg.RetryAttempts <= 0 {
cfg.RetryAttempts = 3
}
if cfg.RetryBackoff <= 0 {
cfg.RetryBackoff = time.Second
}
if logger == nil {
logger = zap.NewNop()
}
if bus == nil {
return nil, fmt.Errorf("event bus must not be nil")
}
return &Service{
config: cfg,
eventBus: bus,
logger: logger,
stopCh: make(chan struct{}),
startedCh: make(chan struct{}),
}, nil
}
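// Wiring sketch (illustrative only): construct a service around a plugin and
// run it until the context is cancelled. Only Plugin is mandatory; every
// other Config field falls back to the defaults applied above. The plugin,
// logger, and ctx values are assumed to come from the caller.
//
//	svc, err := NewService(Config{Plugin: plugin}, logger)
//	if err != nil {
//	    log.Fatal(err)
//	}
//	if err := svc.Run(ctx, false); err != nil {
//	    log.Fatal(err)
//	}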
// Run starts the dispatcher service and blocks until stopped
func (s *Service) Run(ctx context.Context, detach bool) error {
s.logger.Info("Starting event dispatcher service")
if detach {
return s.runDetached(ctx)
}
return s.runForeground(ctx)
}
// runDetached runs the service in background mode
func (s *Service) runDetached(ctx context.Context) error {
s.logger.Info("Running in detached mode")
// For detached mode, we still run in foreground but could be enhanced
// to fork the process or use systemd/supervisor in production
return s.runForeground(ctx)
}
// runForeground runs the service in foreground mode
func (s *Service) runForeground(ctx context.Context) error {
// Handle graceful shutdown
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
// Start background goroutines.
// Add to the WaitGroup before closing startedCh to avoid racing Wait() with Add().
s.wg.Add(2)
close(s.startedCh)
go s.processEvents(ctx)
go func() {
defer s.wg.Done()
metricsTicker := time.NewTicker(metricsUpdateInterval)
defer metricsTicker.Stop()
s.trackMetrics(ctx, metricsTicker.C)
}()
// Wait for shutdown signal
select {
case <-sigs:
s.logger.Info("Received shutdown signal")
case <-ctx.Done():
s.logger.Info("Context cancelled")
}
return s.Stop()
}
// Stop gracefully stops the dispatcher service
func (s *Service) Stop() error {
s.stopOnce.Do(func() {
s.logger.Info("Stopping event dispatcher service")
close(s.stopCh)
// Avoid racing Wait() with a concurrent Add(1) during startup.
select {
case <-s.startedCh:
s.wg.Wait()
default:
// Not started; nothing to wait for
}
if s.eventBus != nil {
s.eventBus.Stop()
}
if s.config.Plugin != nil {
if err := s.config.Plugin.Close(); err != nil {
s.logger.Error("Error closing plugin", zap.Error(err))
}
}
s.logger.Info("Event dispatcher service stopped")
})
return nil
}
// EventBus returns the event bus for this service (for connecting with other components)
func (s *Service) EventBus() eventbus.EventBus {
return s.eventBus
}
// processEvents handles the main event processing loop
func (s *Service) processEvents(ctx context.Context) {
defer s.wg.Done()
sub := s.eventBus.Subscribe()
batch := make([]EventPayload, 0, s.config.BatchSize)
ticker := time.NewTicker(s.config.FlushInterval)
defer ticker.Stop()
// Channel-based event bus (in-memory or Redis Streams)
for {
select {
case evt, ok := <-sub:
if !ok {
// Channel closed, flush remaining batch and exit
if len(batch) > 0 {
s.sendBatch(ctx, batch)
}
return
}
// Transform the event
payload, err := s.config.EventTransformer.Transform(evt)
if err != nil {
s.mu.Lock()
s.eventsDropped++
s.mu.Unlock()
s.logger.Error("Failed to transform event", zap.Error(err))
continue
}
if payload == nil {
// Event was filtered out (e.g., OPTIONS request)
continue
}
batch = append(batch, *payload)
s.mu.Lock()
s.eventsProcessed++
s.lastProcessedAt = time.Now()
s.mu.Unlock()
// Send batch if it's full
if len(batch) >= s.config.BatchSize {
s.sendBatch(ctx, batch)
batch = batch[:0] // Reset slice
}
case <-ticker.C:
// Flush batch on timer
if len(batch) > 0 {
s.sendBatch(ctx, batch)
batch = batch[:0] // Reset slice
}
case <-s.stopCh:
// Flush remaining batch and exit
if len(batch) > 0 {
s.sendBatch(ctx, batch)
}
return
}
}
}
// sendBatch sends a batch of events to the configured backend with exponential backoff retry logic
func (s *Service) sendBatch(ctx context.Context, batch []EventPayload) {
for attempt := 0; attempt <= s.config.RetryAttempts; attempt++ {
err := s.config.Plugin.SendEvents(ctx, batch)
if err == nil {
s.mu.Lock()
s.eventsSent += int64(len(batch))
s.mu.Unlock()
s.logger.Debug("Successfully sent batch",
zap.Int("batch_size", len(batch)),
zap.Int("attempt", attempt+1))
return
}
// Check for permanent errors that should not be retried
if _, ok := err.(*PermanentBackendError); ok {
s.logger.Warn("Permanent backend error, skipping batch", zap.Error(err), zap.Int("batch_size", len(batch)))
s.mu.Lock()
s.eventsDropped += int64(len(batch))
s.mu.Unlock()
return
}
if attempt < s.config.RetryAttempts {
// Exponential backoff: 2^attempt * base backoff
backoff := time.Duration(1<<uint(attempt)) * s.config.RetryBackoff
// Cap at 30 seconds
if backoff > maxBackoffDuration {
backoff = maxBackoffDuration
}
s.logger.Warn("Failed to send batch, retrying with exponential backoff",
zap.Error(err),
zap.Int("attempt", attempt+1),
zap.Duration("backoff", backoff))
select {
case <-time.After(backoff):
case <-ctx.Done():
return
case <-s.stopCh:
return
}
} else {
s.logger.Error("Failed to send batch after all retries",
zap.Error(err),
zap.Int("batch_size", len(batch)))
s.mu.Lock()
s.eventsDropped += int64(len(batch))
s.mu.Unlock()
}
}
}
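// With the defaults applied in NewService (RetryBackoff 1s, RetryAttempts 3),
// the waits between the four attempts above are 1s, 2s, and 4s; longer
// schedules are capped at maxBackoffDuration (30s).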
// sendBatchWithResult sends a batch of events to the configured backend with exponential backoff retry logic and returns the result
func (s *Service) sendBatchWithResult(ctx context.Context, batch []EventPayload) error {
for attempt := 0; attempt <= s.config.RetryAttempts; attempt++ {
err := s.config.Plugin.SendEvents(ctx, batch)
if err == nil {
s.mu.Lock()
s.eventsSent += int64(len(batch))
s.mu.Unlock()
s.logger.Debug("Successfully sent batch",
zap.Int("batch_size", len(batch)),
zap.Int("attempt", attempt+1))
return nil
}
// If PermanentBackendError, treat as delivered and do not retry
if _, ok := err.(*PermanentBackendError); ok {
s.logger.Warn("Permanent backend error, skipping batch", zap.Error(err), zap.Int("batch_size", len(batch)))
s.mu.Lock()
s.eventsDropped += int64(len(batch))
s.mu.Unlock()
return nil // treat as delivered
}
if attempt < s.config.RetryAttempts {
// Exponential backoff: 2^attempt * base backoff
backoff := time.Duration(1<<uint(attempt)) * s.config.RetryBackoff
// Cap at 30 seconds
if backoff > maxBackoffDuration {
backoff = maxBackoffDuration
}
s.logger.Warn("Failed to send batch, retrying with exponential backoff",
zap.Error(err),
zap.Int("attempt", attempt+1),
zap.Duration("backoff", backoff))
select {
case <-time.After(backoff):
case <-ctx.Done():
return ctx.Err()
case <-s.stopCh:
return fmt.Errorf("stopped")
}
} else {
s.logger.Error("Failed to send batch after all retries",
zap.Error(err),
zap.Int("batch_size", len(batch)))
s.mu.Lock()
s.eventsDropped += int64(len(batch))
s.mu.Unlock()
return err
}
}
return fmt.Errorf("unreachable")
}
// trackMetrics periodically updates metrics like processing rate and Redis Streams lag.
//
// It exits when the service is stopped (s.stopCh) or ctx is canceled.
func (s *Service) trackMetrics(ctx context.Context, ticker <-chan time.Time) {
lastProcessed := int64(0)
lastTime := time.Now()
for {
select {
case <-ticker:
now := time.Now()
s.mu.Lock()
currentProcessed := s.eventsProcessed
s.mu.Unlock()
// Calculate processing rate
elapsed := now.Sub(lastTime).Seconds()
if elapsed > 0 {
rate := float64(currentProcessed-lastProcessed) / elapsed
s.mu.Lock()
s.processingRate = rate
s.mu.Unlock()
}
lastProcessed = currentProcessed
lastTime = now
// Update lag metrics for Redis Streams
if streamsBus, ok := s.eventBus.(*eventbus.RedisStreamsEventBus); ok {
if pending, err := streamsBus.PendingCount(ctx); err == nil {
s.mu.Lock()
s.lagCount = pending
s.mu.Unlock()
}
if length, err := streamsBus.StreamLength(ctx); err == nil {
s.mu.Lock()
s.streamLength = length
s.mu.Unlock()
}
}
case <-s.stopCh:
return
case <-ctx.Done():
return
}
}
}
// Stats returns service statistics
func (s *Service) Stats() (processed, dropped, sent int64) {
s.mu.Lock()
defer s.mu.Unlock()
return s.eventsProcessed, s.eventsDropped, s.eventsSent
}
// DetailedStats returns detailed service statistics including lag and rate
func (s *Service) DetailedStats() map[string]interface{} {
s.mu.Lock()
defer s.mu.Unlock()
return map[string]interface{}{
"events_processed": s.eventsProcessed,
"events_dropped": s.eventsDropped,
"events_sent": s.eventsSent,
"processing_rate": s.processingRate,
"lag_count": s.lagCount,
"stream_length": s.streamLength,
"last_processed_at": s.lastProcessedAt,
}
}
// HealthStatus represents the health status of the dispatcher.
type HealthStatus struct {
// Healthy indicates whether the dispatcher is currently healthy.
Healthy bool `json:"healthy"`
// Status is a human-readable string describing the current health status.
Status string `json:"status"`
// EventsProcessed is the total number of events processed by the dispatcher.
EventsProcessed int64 `json:"events_processed"`
// EventsDropped is the total number of events that were dropped and not processed.
EventsDropped int64 `json:"events_dropped"`
// EventsSent is the total number of events successfully sent to the backend.
EventsSent int64 `json:"events_sent"`
// ProcessingRate is the average number of events processed per second.
ProcessingRate float64 `json:"processing_rate"`
// LagCount is the number of pending messages in the event bus that have not yet been processed.
// For Redis Streams, this is the pending entries count (XPENDING).
LagCount int64 `json:"lag_count"`
// StreamLength is the total number of messages currently in the Redis Stream.
// For Redis Streams, this is the stream length (XLEN).
StreamLength int64 `json:"stream_length"`
// LastProcessedAt is the timestamp of the last successfully processed event.
LastProcessedAt time.Time `json:"last_processed_at"`
// Message provides additional information about the health status, if any.
Message string `json:"message,omitempty"`
}
// Health returns the health status of the dispatcher
func (s *Service) Health(ctx context.Context) HealthStatus {
s.mu.Lock()
stats := HealthStatus{
EventsProcessed: s.eventsProcessed,
EventsDropped: s.eventsDropped,
EventsSent: s.eventsSent,
ProcessingRate: s.processingRate,
LagCount: s.lagCount,
StreamLength: s.streamLength,
LastProcessedAt: s.lastProcessedAt,
}
s.mu.Unlock()
// Check if we're using Redis Streams
if streamsBus, ok := s.eventBus.(*eventbus.RedisStreamsEventBus); ok {
// Update lag and stream length from Redis
if pending, err := streamsBus.PendingCount(ctx); err == nil {
stats.LagCount = pending
}
if length, err := streamsBus.StreamLength(ctx); err == nil {
stats.StreamLength = length
}
// Consider unhealthy if lag is very high
if stats.LagCount > maxHealthyLagCount {
stats.Healthy = false
stats.Status = "unhealthy"
stats.Message = fmt.Sprintf("High lag: %d pending messages", stats.LagCount)
return stats
}
// Consider unhealthy if we haven't processed anything recently and there are pending messages
if !stats.LastProcessedAt.IsZero() && time.Since(stats.LastProcessedAt) > maxInactivityDuration && stats.LagCount > 0 {
stats.Healthy = false
stats.Status = "unhealthy"
stats.Message = fmt.Sprintf("No processing activity for %v with %d pending messages", time.Since(stats.LastProcessedAt), stats.LagCount)
return stats
}
}
stats.Healthy = true
stats.Status = "healthy"
return stats
}
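// Endpoint sketch (illustrative only): HealthStatus is JSON-tagged, so a
// health handler can marshal it directly. The route name and the svc
// variable are assumptions of this example.
//
//	http.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
//	    status := svc.Health(r.Context())
//	    if !status.Healthy {
//	        w.WriteHeader(http.StatusServiceUnavailable)
//	    }
//	    _ = json.NewEncoder(w).Encode(status)
//	})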
package dispatcher
import (
"bytes"
"compress/gzip"
"encoding/json"
"io"
"log"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/andybalholm/brotli"
"github.com/google/uuid"
"github.com/sofatutor/llm-proxy/internal/eventbus"
"github.com/sofatutor/llm-proxy/internal/eventtransformer"
)
// DefaultEventTransformer provides a basic transformation from eventbus.Event to EventPayload
// Verbose: if true, includes response_headers in metadata
// Use NewDefaultEventTransformer(verbose) to construct
type DefaultEventTransformer struct {
Verbose bool
}
// NewDefaultEventTransformer creates a transformer with the given verbosity
func NewDefaultEventTransformer(verbose bool) *DefaultEventTransformer {
return &DefaultEventTransformer{Verbose: verbose}
}
// cleanJSONBinary recursively replaces binary fields in a JSON object with a placeholder
func cleanJSONBinary(v interface{}) interface{} {
switch val := v.(type) {
case map[string]interface{}:
for k, v2 := range val {
val[k] = cleanJSONBinary(v2)
}
return val
case []interface{}:
for i, v2 := range val {
val[i] = cleanJSONBinary(v2)
}
return val
case string:
if !utf8.ValidString(val) {
return "<binary omitted>"
}
return val
case []byte:
if !utf8.Valid(val) {
return "<binary omitted>"
}
return string(val)
default:
return val
}
}
// safeRawMessageOrBase64 tries to decode data as JSON, decompressing with
// gzip or brotli first when Content-Encoding indicates it, then falls back
// to treating the bytes as a UTF-8 string. Multipart and binary bodies are
// replaced with quoted placeholders; the second return value is reserved
// for a base64 fallback and is currently always empty.
// If Content-Type is JSON, the result is always cleaned JSON (with binary
// fields replaced).
func safeRawMessageOrBase64(data []byte, headers map[string][]string) (json.RawMessage, string) {
if len(data) == 0 {
return nil, ""
}
var js json.RawMessage
// Check for Content-Encoding
encoding := ""
contentType := ""
if headers != nil {
if v, ok := headers["Content-Encoding"]; ok && len(v) > 0 {
encoding = v[0]
}
if v, ok := headers["Content-Type"]; ok && len(v) > 0 {
contentType = v[0]
}
}
// If Content-Type is multipart, return a placeholder
if strings.Contains(contentType, "multipart") {
return []byte(strconv.Quote("<multipart response omitted>")), ""
}
decompressed := data
var decompressErr error
switch encoding {
case "gzip":
zr, err := gzip.NewReader(bytes.NewReader(data))
if err == nil {
decompressed, decompressErr = io.ReadAll(zr)
_ = zr.Close()
} else {
decompressErr = err
}
case "br":
br := brotli.NewReader(bytes.NewReader(data))
var err error
decompressed, err = io.ReadAll(br)
if err != nil {
decompressErr = err
}
}
// If Content-Type is JSON, always return cleaned JSON
if strings.Contains(contentType, "json") {
var obj interface{}
if json.Unmarshal(decompressed, &obj) == nil {
cleaned := cleanJSONBinary(obj)
if jsBytes, err := json.Marshal(cleaned); err == nil {
return jsBytes, ""
}
}
}
	if decompressErr == nil {
		if err := json.Unmarshal(decompressed, &js); err == nil {
			return js, ""
		} else if strings.Contains(contentType, "json") {
			log.Printf("[transformer] JSON unmarshal after decompress failed: %v. First 64 bytes: %x", err, decompressed[:min(64, len(decompressed))])
		}
	} else {
		log.Printf("[transformer] Decompression failed: %v", decompressErr)
		// Decompression failed; fall back to parsing the raw bytes as JSON.
		if err := json.Unmarshal(data, &js); err == nil {
			return js, ""
		} else {
			log.Printf("[transformer] JSON unmarshal failed: %v. First 64 bytes: %x", err, data[:min(64, len(data))])
		}
	}
// If valid UTF-8, try to parse as JSON string or OpenAI event stream
if utf8.Valid(decompressed) {
str := string(decompressed)
trim := strings.TrimSpace(str)
if (strings.HasPrefix(trim, "{") && strings.HasSuffix(trim, "}")) ||
(strings.HasPrefix(trim, "[") && strings.HasSuffix(trim, "]")) {
// Looks like JSON object/array in a string
if json.Unmarshal([]byte(trim), &js) == nil {
return js, ""
}
}
// OpenAI streaming or event lines
if eventtransformer.IsOpenAIStreaming(str) {
if merged, err := eventtransformer.MergeOpenAIStreamingChunks(str); err == nil {
if js, err := json.Marshal(merged); err == nil {
return js, ""
}
}
}
if strings.Contains(str, "event: ") && strings.Contains(str, "data: ") {
if merged, err := eventtransformer.MergeThreadStreamingChunks(str); err == nil {
if js, err := json.Marshal(merged); err == nil {
return js, ""
}
}
}
// Fallback: log as JSON string
quoted := []byte(strconv.Quote(str))
return quoted, ""
}
// For binary data, return a placeholder string instead of base64
return []byte(strconv.Quote("<binary response omitted>")), ""
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
// Transform converts an eventbus.Event to an EventPayload
func (t *DefaultEventTransformer) Transform(evt eventbus.Event) (*EventPayload, error) {
// Skip non-POST requests (like OPTIONS, GET)
if evt.Method != "POST" {
return nil, nil
}
// Generate a unique run ID for this event
runID := uuid.New().String()
// Basic transformation
payload := &EventPayload{
Type: "llm",
Event: "start", // For now, all events are considered "start" events
RunID: runID,
Timestamp: time.Now(),
LogID: evt.LogID,
Metadata: map[string]any{
"method": evt.Method,
"path": evt.Path,
"status": evt.Status,
"duration_ms": evt.Duration.Milliseconds(),
"request_id": evt.RequestID,
},
}
// Add request body as input (JSON or base64)
if len(evt.RequestBody) > 0 {
if js, b64 := safeRawMessageOrBase64(evt.RequestBody, nil); js != nil {
payload.Input = js
} else {
payload.InputBase64 = b64
}
}
// --- OpenAI-specific output transformation ---
isOpenAI := strings.HasPrefix(evt.Path, "/v1/completions") ||
strings.HasPrefix(evt.Path, "/v1/chat/completions") ||
strings.HasPrefix(evt.Path, "/v1/threads/")
if isOpenAI && len(evt.ResponseBody) > 0 {
		// Only use the OpenAI transformer if the response is valid JSON
		if json.Valid(evt.ResponseBody) {
			openaiTransformer := &eventtransformer.OpenAITransformer{}
			parsed, err := openaiTransformer.TransformEvent(map[string]any{
				"response_body": string(evt.ResponseBody),
				"path":          evt.Path,
			})
			if err == nil && parsed != nil {
				if js, err := json.Marshal(parsed); err == nil {
					payload.Output = js
					// Extract token usage if present; guard the type assertions
					// so a malformed usage object cannot cause a panic.
					if usage, ok := parsed["usage"].(map[string]any); ok {
						prompt, pOK := usage["prompt_tokens"].(float64)
						completion, cOK := usage["completion_tokens"].(float64)
						if pOK && cOK {
							payload.TokensUsage = &TokensUsage{
								Prompt:     int(prompt),
								Completion: int(completion),
							}
						}
					}
					return payload, nil
				}
			}
			// If parsing fails, fall through to the generic logic below
		}
}
// Add response body as output (JSON or base64)
if len(evt.ResponseBody) > 0 {
if js, b64 := safeRawMessageOrBase64(evt.ResponseBody, evt.ResponseHeaders); js != nil {
payload.Output = js
} else {
payload.OutputBase64 = b64
}
}
// Add response headers to metadata only if Verbose is true
if t.Verbose && evt.ResponseHeaders != nil {
headers := make(map[string]any)
for k, v := range evt.ResponseHeaders {
if len(v) == 1 {
headers[k] = v[0]
} else {
headers[k] = v
}
}
payload.Metadata["response_headers"] = headers
}
return payload, nil
}
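// Transformation sketch (illustrative only): a minimal POST event flows
// through Transform into an EventPayload; non-POST events yield (nil, nil)
// and are filtered out by the dispatcher loop. The bodies shown are made up.
//
//	t := NewDefaultEventTransformer(false)
//	payload, err := t.Transform(eventbus.Event{
//	    Method:       "POST",
//	    Path:         "/v1/chat/completions",
//	    Status:       200,
//	    RequestBody:  []byte(`{"model":"gpt-4o-mini"}`),
//	    ResponseBody: []byte(`{"id":"resp-1"}`),
//	})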
// Package encryption provides utilities for encrypting and decrypting sensitive data.
// It uses AES-256-GCM for symmetric encryption of data at rest.
package encryption
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"io"
)
const (
// KeySize is the required size for AES-256 encryption keys (32 bytes).
KeySize = 32
// NonceSize is the size of the GCM nonce (12 bytes).
NonceSize = 12
// EncryptedPrefix is added to encrypted values to identify them.
EncryptedPrefix = "enc:v1:"
)
var (
// ErrInvalidKeySize is returned when the encryption key has an invalid size.
ErrInvalidKeySize = errors.New("encryption key must be exactly 32 bytes")
// ErrDecryptionFailed is returned when decryption fails.
ErrDecryptionFailed = errors.New("decryption failed")
// ErrNoEncryptionKey is returned when no encryption key is configured.
ErrNoEncryptionKey = errors.New("no encryption key configured")
// ErrInvalidCiphertext is returned when the ciphertext is invalid.
ErrInvalidCiphertext = errors.New("invalid ciphertext format")
)
// Encryptor provides encryption and decryption operations.
// It is safe for concurrent use - cipher.AEAD implementations are thread-safe.
type Encryptor struct {
gcm cipher.AEAD
}
// NewEncryptor creates a new Encryptor with the given 32-byte key.
// The key must be exactly 32 bytes for AES-256 encryption.
func NewEncryptor(key []byte) (*Encryptor, error) {
if len(key) != KeySize {
return nil, ErrInvalidKeySize
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("failed to create AES cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("failed to create GCM: %w", err)
}
return &Encryptor{gcm: gcm}, nil
}
// NewEncryptorFromBase64Key creates a new Encryptor from a base64-encoded key.
func NewEncryptorFromBase64Key(base64Key string) (*Encryptor, error) {
if base64Key == "" {
return nil, ErrNoEncryptionKey
}
key, err := base64.StdEncoding.DecodeString(base64Key)
if err != nil {
return nil, fmt.Errorf("failed to decode base64 key: %w", err)
}
return NewEncryptor(key)
}
// Encrypt encrypts plaintext and returns a base64-encoded ciphertext with prefix.
func (e *Encryptor) Encrypt(plaintext string) (string, error) {
if plaintext == "" {
return "", nil // Empty strings are not encrypted
}
// Generate a random nonce
nonce := make([]byte, NonceSize)
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", fmt.Errorf("failed to generate nonce: %w", err)
}
// Encrypt the plaintext
ciphertext := e.gcm.Seal(nonce, nonce, []byte(plaintext), nil)
// Encode to base64 and add prefix
encoded := base64.StdEncoding.EncodeToString(ciphertext)
return EncryptedPrefix + encoded, nil
}
// Decrypt decrypts a base64-encoded ciphertext and returns the plaintext.
// If the value is not encrypted (no prefix), it returns the value as-is.
func (e *Encryptor) Decrypt(ciphertext string) (string, error) {
if ciphertext == "" {
return "", nil // Empty strings are returned as-is
}
// Check if the value is encrypted
if !IsEncrypted(ciphertext) {
return ciphertext, nil // Return unencrypted values as-is (backward compatibility)
}
// Remove the prefix
encoded := ciphertext[len(EncryptedPrefix):]
// Decode from base64
data, err := base64.StdEncoding.DecodeString(encoded)
if err != nil {
return "", fmt.Errorf("failed to decode ciphertext: %w", err)
}
// Validate minimum length (nonce + tag; plaintext can be empty)
if len(data) < NonceSize+e.gcm.Overhead() {
return "", ErrInvalidCiphertext
}
// Extract nonce and ciphertext
nonce := data[:NonceSize]
encryptedData := data[NonceSize:]
// Decrypt
plaintext, err := e.gcm.Open(nil, nonce, encryptedData, nil)
if err != nil {
return "", ErrDecryptionFailed
}
return string(plaintext), nil
}
// IsEncrypted checks if a value has the encryption prefix.
func IsEncrypted(value string) bool {
return len(value) > len(EncryptedPrefix) && value[:len(EncryptedPrefix)] == EncryptedPrefix
}
// GenerateKey generates a new random 32-byte encryption key.
func GenerateKey() ([]byte, error) {
key := make([]byte, KeySize)
if _, err := io.ReadFull(rand.Reader, key); err != nil {
return nil, fmt.Errorf("failed to generate key: %w", err)
}
return key, nil
}
// GenerateKeyBase64 generates a new random encryption key and returns it as base64.
func GenerateKeyBase64() (string, error) {
key, err := GenerateKey()
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(key), nil
}
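// Round-trip sketch (illustrative only): generate a key, build an Encryptor
// from it, and confirm Encrypt/Decrypt are inverses.
//
//	keyB64, _ := GenerateKeyBase64()
//	enc, err := NewEncryptorFromBase64Key(keyB64)
//	if err != nil {
//	    log.Fatal(err)
//	}
//	ct, _ := enc.Encrypt("sk-secret")       // "enc:v1:..." ciphertext
//	pt, _ := enc.Decrypt(ct)                // "sk-secret"
//	same, _ := enc.Decrypt("not-encrypted") // passthrough: "not-encrypted"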
// NullEncryptor is a no-op encryptor for when encryption is disabled.
type NullEncryptor struct{}
// NewNullEncryptor creates a new NullEncryptor.
func NewNullEncryptor() *NullEncryptor {
return &NullEncryptor{}
}
// Encrypt returns the plaintext as-is (no encryption).
func (e *NullEncryptor) Encrypt(plaintext string) (string, error) {
return plaintext, nil
}
// Decrypt returns the ciphertext as-is (no decryption).
func (e *NullEncryptor) Decrypt(ciphertext string) (string, error) {
return ciphertext, nil
}
// FieldEncryptor is an interface for encrypting and decrypting field values.
type FieldEncryptor interface {
Encrypt(plaintext string) (string, error)
Decrypt(ciphertext string) (string, error)
}
// Compile-time interface checks
var (
_ FieldEncryptor = (*Encryptor)(nil)
_ FieldEncryptor = (*NullEncryptor)(nil)
)
// Package encryption provides utilities for hashing sensitive data.
// It uses bcrypt for secure password-like hashing of tokens.
package encryption
import (
"crypto/sha256"
"crypto/subtle"
"encoding/hex"
"errors"
"fmt"
"golang.org/x/crypto/bcrypt"
)
const (
// HashPrefix is added to hashed values to identify them.
HashPrefix = "hash:v1:"
// DefaultBcryptCost is the default cost parameter for bcrypt.
// A cost of 10 is a good balance between security and performance.
DefaultBcryptCost = 10
)
var (
// ErrHashMismatch is returned when a hash comparison fails.
ErrHashMismatch = errors.New("hash does not match")
// ErrInvalidHash is returned when the hash format is invalid.
ErrInvalidHash = errors.New("invalid hash format")
)
// TokenHasher provides secure hashing for authentication tokens.
// It uses SHA-256 for creating lookup keys and bcrypt for secure storage.
type TokenHasher struct {
bcryptCost int
}
// NewTokenHasher creates a new TokenHasher with the default bcrypt cost.
func NewTokenHasher() *TokenHasher {
return &TokenHasher{bcryptCost: DefaultBcryptCost}
}
// NewTokenHasherWithCost creates a new TokenHasher with a custom bcrypt cost.
func NewTokenHasherWithCost(cost int) (*TokenHasher, error) {
if cost < bcrypt.MinCost || cost > bcrypt.MaxCost {
return nil, fmt.Errorf("bcrypt cost must be between %d and %d", bcrypt.MinCost, bcrypt.MaxCost)
}
return &TokenHasher{bcryptCost: cost}, nil
}
// HashToken creates a bcrypt hash of a token for secure storage.
// Returns a hash prefixed with HashPrefix for identification.
// For tokens longer than 72 bytes, a SHA-256 pre-hash is used since
// bcrypt has a 72-byte input limit.
func (h *TokenHasher) HashToken(token string) (string, error) {
if token == "" {
return "", errors.New("token cannot be empty")
}
// Pre-hash if token is too long for bcrypt (72 byte limit)
input := []byte(token)
if len(input) > 72 {
sum := sha256.Sum256(input)
input = sum[:]
}
hash, err := bcrypt.GenerateFromPassword(input, h.bcryptCost)
if err != nil {
return "", fmt.Errorf("failed to hash token: %w", err)
}
return HashPrefix + string(hash), nil
}
// VerifyToken compares a plaintext token against a stored hash.
// It returns nil if the token matches, or ErrHashMismatch if it doesn't.
func (h *TokenHasher) VerifyToken(token, hashedToken string) error {
if token == "" || hashedToken == "" {
return ErrHashMismatch
}
// Check if it's a hashed value
if !IsHashed(hashedToken) {
// For backward compatibility, do a constant-time comparison
// if the stored value is not hashed
if subtle.ConstantTimeCompare([]byte(token), []byte(hashedToken)) == 1 {
return nil
}
return ErrHashMismatch
}
// Remove the prefix
bcryptHash := hashedToken[len(HashPrefix):]
// Pre-hash if token is too long for bcrypt (72 byte limit)
input := []byte(token)
if len(input) > 72 {
sum := sha256.Sum256(input)
input = sum[:]
}
// Verify with bcrypt
err := bcrypt.CompareHashAndPassword([]byte(bcryptHash), input)
if err != nil {
if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
return ErrHashMismatch
}
return fmt.Errorf("failed to verify token: %w", err)
}
return nil
}
// CreateLookupKey creates a deterministic hash for token lookup.
// This is used as an index key in the database for finding tokens.
// Uses SHA-256 which is fast and collision-resistant.
func (h *TokenHasher) CreateLookupKey(token string) string {
if token == "" {
return ""
}
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}
// IsHashed checks if a value has the hash prefix.
func IsHashed(value string) bool {
return len(value) > len(HashPrefix) && value[:len(HashPrefix)] == HashPrefix
}
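// Hashing sketch (illustrative only): the lookup key is deterministic
// (usable as a database index), while the bcrypt hash is salted and must be
// checked with VerifyToken rather than compared directly.
//
//	h := NewTokenHasher()
//	lookup := h.CreateLookupKey("tok-abc") // same input, same 64-char hex key
//	stored, _ := h.HashToken("tok-abc")    // "hash:v1:$2a$..." differs per call
//	if err := h.VerifyToken("tok-abc", stored); err != nil {
//	    log.Fatal("unexpected mismatch")
//	}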
// NullTokenHasher is a no-op hasher for when hashing is disabled.
type NullTokenHasher struct{}
// NewNullTokenHasher creates a new NullTokenHasher.
func NewNullTokenHasher() *NullTokenHasher {
return &NullTokenHasher{}
}
// HashToken returns the token as-is (no hashing).
func (h *NullTokenHasher) HashToken(token string) (string, error) {
return token, nil
}
// VerifyToken performs a constant-time comparison of the tokens.
func (h *NullTokenHasher) VerifyToken(token, storedToken string) error {
if subtle.ConstantTimeCompare([]byte(token), []byte(storedToken)) == 1 {
return nil
}
return ErrHashMismatch
}
// CreateLookupKey returns the token as-is (it's already the lookup key).
func (h *NullTokenHasher) CreateLookupKey(token string) string {
return token
}
// TokenHasherInterface defines the interface for token hashing operations.
type TokenHasherInterface interface {
HashToken(token string) (string, error)
VerifyToken(token, hashedToken string) error
CreateLookupKey(token string) string
}
// Compile-time interface checks
var (
_ TokenHasherInterface = (*TokenHasher)(nil)
_ TokenHasherInterface = (*NullTokenHasher)(nil)
)
// Package encryption provides a secure database wrapper that encrypts/decrypts sensitive fields.
package encryption
import (
"context"
"fmt"
"github.com/sofatutor/llm-proxy/internal/proxy"
)
// SecureProjectStore wraps a ProjectStore and encrypts/decrypts API keys.
type SecureProjectStore struct {
store proxy.ProjectStore
encryptor FieldEncryptor
}
// NewSecureProjectStore creates a new SecureProjectStore.
// The encryptor is used to encrypt API keys before storing and decrypt after retrieval.
// If encryptor is nil, a NullEncryptor is used (no encryption).
func NewSecureProjectStore(store proxy.ProjectStore, encryptor FieldEncryptor) *SecureProjectStore {
if encryptor == nil {
encryptor = NewNullEncryptor()
}
return &SecureProjectStore{
store: store,
encryptor: encryptor,
}
}
// GetAPIKeyForProject retrieves and decrypts the API key for a project.
func (s *SecureProjectStore) GetAPIKeyForProject(ctx context.Context, projectID string) (string, error) {
encryptedKey, err := s.store.GetAPIKeyForProject(ctx, projectID)
if err != nil {
return "", err
}
// Decrypt the API key
decryptedKey, err := s.encryptor.Decrypt(encryptedKey)
if err != nil {
return "", fmt.Errorf("failed to decrypt API key: %w", err)
}
return decryptedKey, nil
}
// GetProjectActive returns whether a project is active.
func (s *SecureProjectStore) GetProjectActive(ctx context.Context, projectID string) (bool, error) {
return s.store.GetProjectActive(ctx, projectID)
}
// ListProjects retrieves all projects and decrypts their API keys.
func (s *SecureProjectStore) ListProjects(ctx context.Context) ([]proxy.Project, error) {
projects, err := s.store.ListProjects(ctx)
if err != nil {
return nil, err
}
// Decrypt API keys for each project
for i := range projects {
decryptedKey, err := s.encryptor.Decrypt(projects[i].APIKey)
if err != nil {
return nil, fmt.Errorf("failed to decrypt API key for project %s: %w", projects[i].ID, err)
}
projects[i].APIKey = decryptedKey
}
return projects, nil
}
// CreateProject encrypts the API key and creates the project.
func (s *SecureProjectStore) CreateProject(ctx context.Context, project proxy.Project) error {
// Encrypt the API key before storing
encryptedKey, err := s.encryptor.Encrypt(project.APIKey)
if err != nil {
return fmt.Errorf("failed to encrypt API key: %w", err)
}
project.APIKey = encryptedKey
return s.store.CreateProject(ctx, project)
}
// GetProjectByID retrieves a project and decrypts its API key.
func (s *SecureProjectStore) GetProjectByID(ctx context.Context, projectID string) (proxy.Project, error) {
project, err := s.store.GetProjectByID(ctx, projectID)
if err != nil {
return proxy.Project{}, err
}
// Decrypt the API key
decryptedKey, err := s.encryptor.Decrypt(project.APIKey)
if err != nil {
return proxy.Project{}, fmt.Errorf("failed to decrypt API key: %w", err)
}
project.APIKey = decryptedKey
return project, nil
}
// UpdateProject encrypts the API key and updates the project.
func (s *SecureProjectStore) UpdateProject(ctx context.Context, project proxy.Project) error {
// Only encrypt if the API key is not already encrypted
if !IsEncrypted(project.APIKey) {
encryptedKey, err := s.encryptor.Encrypt(project.APIKey)
if err != nil {
return fmt.Errorf("failed to encrypt API key: %w", err)
}
project.APIKey = encryptedKey
}
return s.store.UpdateProject(ctx, project)
}
// DeleteProject deletes a project.
func (s *SecureProjectStore) DeleteProject(ctx context.Context, projectID string) error {
return s.store.DeleteProject(ctx, projectID)
}
// Compile-time interface check
var _ proxy.ProjectStore = (*SecureProjectStore)(nil)
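// Wiring sketch (illustrative only): wrap the underlying store with an
// Encryptor when a key is configured; passing nil keeps a NullEncryptor and
// leaves API keys in plaintext for backward compatibility. The env var name
// and baseStore variable are assumptions of this example.
//
//	enc, err := NewEncryptorFromBase64Key(os.Getenv("LLM_PROXY_ENCRYPTION_KEY"))
//	if err != nil {
//	    log.Fatal(err)
//	}
//	secure := NewSecureProjectStore(baseStore, enc)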
// Package encryption provides a secure token store wrapper that hashes tokens.
package encryption
import (
"context"
"time"
"github.com/sofatutor/llm-proxy/internal/token"
)
// SecureTokenStore wraps a TokenStore and hashes tokens before storage.
// This prevents tokens from being exposed if the database is compromised.
type SecureTokenStore struct {
store token.TokenStore
hasher TokenHasherInterface
}
// NewSecureTokenStore creates a new SecureTokenStore.
// If hasher is nil, a NullTokenHasher is used (no hashing).
func NewSecureTokenStore(store token.TokenStore, hasher TokenHasherInterface) *SecureTokenStore {
if hasher == nil {
hasher = NewNullTokenHasher()
}
return &SecureTokenStore{
store: store,
hasher: hasher,
}
}
// GetTokenByID retrieves a token by its UUID.
func (s *SecureTokenStore) GetTokenByID(ctx context.Context, id string) (token.TokenData, error) {
return s.store.GetTokenByID(ctx, id)
}
// GetTokenByToken retrieves a token by its token string (for authentication).
// The token is hashed before lookup, and the returned TokenData will
// have the hashed token value (not the original).
func (s *SecureTokenStore) GetTokenByToken(ctx context.Context, tokenString string) (token.TokenData, error) {
hashedToken := s.hasher.CreateLookupKey(tokenString)
return s.store.GetTokenByToken(ctx, hashedToken)
}
// IncrementTokenUsage increments the usage count for a token by token string.
// The token is hashed before the operation.
func (s *SecureTokenStore) IncrementTokenUsage(ctx context.Context, tokenString string) error {
hashedToken := s.hasher.CreateLookupKey(tokenString)
return s.store.IncrementTokenUsage(ctx, hashedToken)
}
// CreateToken creates a new token in the store.
// The token value is hashed before storage.
func (s *SecureTokenStore) CreateToken(ctx context.Context, td token.TokenData) error {
// Hash the token value
td.Token = s.hasher.CreateLookupKey(td.Token)
return s.store.CreateToken(ctx, td)
}
// UpdateToken updates an existing token.
// The token value is hashed before the operation.
func (s *SecureTokenStore) UpdateToken(ctx context.Context, td token.TokenData) error {
// Hash the token value if it's not already a SHA-256 hex string (64 hex chars)
// Validate both length and hex content to avoid skipping plaintext 64-char tokens
if len(td.Token) != 64 || !IsHexString(td.Token) {
td.Token = s.hasher.CreateLookupKey(td.Token)
}
return s.store.UpdateToken(ctx, td)
}
// IsHexString checks if a string contains only hexadecimal characters.
// Exported for use by migration tools and other packages.
func IsHexString(s string) bool {
for _, c := range s {
isDigit := c >= '0' && c <= '9'
isLowerHex := c >= 'a' && c <= 'f'
isUpperHex := c >= 'A' && c <= 'F'
if !isDigit && !isLowerHex && !isUpperHex {
return false
}
}
return true
}
// ListTokens retrieves all tokens from the store.
// Note: The returned tokens will have hashed token values.
func (s *SecureTokenStore) ListTokens(ctx context.Context) ([]token.TokenData, error) {
return s.store.ListTokens(ctx)
}
// GetTokensByProjectID retrieves all tokens for a project.
// Note: The returned tokens will have hashed token values.
func (s *SecureTokenStore) GetTokensByProjectID(ctx context.Context, projectID string) ([]token.TokenData, error) {
return s.store.GetTokensByProjectID(ctx, projectID)
}
// Compile-time interface check
var _ token.TokenStore = (*SecureTokenStore)(nil)
// SecureRevocationStore wraps a RevocationStore and hashes tokens before operations.
type SecureRevocationStore struct {
store token.RevocationStore
hasher TokenHasherInterface
}
// NewSecureRevocationStore creates a new SecureRevocationStore.
func NewSecureRevocationStore(store token.RevocationStore, hasher TokenHasherInterface) *SecureRevocationStore {
if hasher == nil {
hasher = NewNullTokenHasher()
}
return &SecureRevocationStore{
store: store,
hasher: hasher,
}
}
// RevokeToken revokes a token. The provided token string is hashed into its
// lookup key before the underlying store is called.
func (s *SecureRevocationStore) RevokeToken(ctx context.Context, tokenID string) error {
hashedToken := s.hasher.CreateLookupKey(tokenID)
return s.store.RevokeToken(ctx, hashedToken)
}
// DeleteToken deletes a token. The provided token string is hashed into its
// lookup key before the underlying store is called.
func (s *SecureRevocationStore) DeleteToken(ctx context.Context, tokenID string) error {
hashedToken := s.hasher.CreateLookupKey(tokenID)
return s.store.DeleteToken(ctx, hashedToken)
}
// RevokeBatchTokens revokes multiple tokens at once.
func (s *SecureRevocationStore) RevokeBatchTokens(ctx context.Context, tokenIDs []string) (int, error) {
hashedTokens := make([]string, len(tokenIDs))
for i, t := range tokenIDs {
hashedTokens[i] = s.hasher.CreateLookupKey(t)
}
return s.store.RevokeBatchTokens(ctx, hashedTokens)
}
// RevokeProjectTokens revokes all tokens for a project.
func (s *SecureRevocationStore) RevokeProjectTokens(ctx context.Context, projectID string) (int, error) {
return s.store.RevokeProjectTokens(ctx, projectID)
}
// RevokeExpiredTokens revokes all expired tokens.
func (s *SecureRevocationStore) RevokeExpiredTokens(ctx context.Context) (int, error) {
return s.store.RevokeExpiredTokens(ctx)
}
// Compile-time interface check
var _ token.RevocationStore = (*SecureRevocationStore)(nil)
// SecureRateLimitStore wraps a RateLimitStore and hashes tokens before operations.
type SecureRateLimitStore struct {
store token.RateLimitStore
hasher TokenHasherInterface
}
// NewSecureRateLimitStore creates a new SecureRateLimitStore.
func NewSecureRateLimitStore(store token.RateLimitStore, hasher TokenHasherInterface) *SecureRateLimitStore {
if hasher == nil {
hasher = NewNullTokenHasher()
}
return &SecureRateLimitStore{
store: store,
hasher: hasher,
}
}
// GetTokenByID retrieves a token. The provided token string is hashed into
// its lookup key before the underlying store is queried.
func (s *SecureRateLimitStore) GetTokenByID(ctx context.Context, tokenID string) (token.TokenData, error) {
hashedToken := s.hasher.CreateLookupKey(tokenID)
return s.store.GetTokenByID(ctx, hashedToken)
}
// IncrementTokenUsage increments the usage count for a token.
func (s *SecureRateLimitStore) IncrementTokenUsage(ctx context.Context, tokenID string) error {
hashedToken := s.hasher.CreateLookupKey(tokenID)
return s.store.IncrementTokenUsage(ctx, hashedToken)
}
// ResetTokenUsage resets the usage count for a token to zero.
func (s *SecureRateLimitStore) ResetTokenUsage(ctx context.Context, tokenID string) error {
hashedToken := s.hasher.CreateLookupKey(tokenID)
return s.store.ResetTokenUsage(ctx, hashedToken)
}
// UpdateTokenLimit updates the maximum allowed requests for a token.
func (s *SecureRateLimitStore) UpdateTokenLimit(ctx context.Context, tokenID string, maxRequests *int) error {
hashedToken := s.hasher.CreateLookupKey(tokenID)
return s.store.UpdateTokenLimit(ctx, hashedToken, maxRequests)
}
// Compile-time interface check
var _ token.RateLimitStore = (*SecureRateLimitStore)(nil)
// SecureUsageStatsStore wraps a UsageStatsStore and hashes tokens before batch operations.
type SecureUsageStatsStore struct {
store token.UsageStatsStore
hasher TokenHasherInterface
}
// NewSecureUsageStatsStore creates a new SecureUsageStatsStore.
func NewSecureUsageStatsStore(store token.UsageStatsStore, hasher TokenHasherInterface) *SecureUsageStatsStore {
if hasher == nil {
hasher = NewNullTokenHasher()
}
return &SecureUsageStatsStore{
store: store,
hasher: hasher,
}
}
// IncrementTokenUsageBatch increments request_count for multiple tokens.
// Token strings are hashed before the operation.
func (s *SecureUsageStatsStore) IncrementTokenUsageBatch(ctx context.Context, deltas map[string]int, lastUsedAt time.Time) error {
if len(deltas) == 0 {
return nil
}
hashedDeltas := make(map[string]int, len(deltas))
for tokenString, delta := range deltas {
hashedToken := s.hasher.CreateLookupKey(tokenString)
hashedDeltas[hashedToken] = delta
}
return s.store.IncrementTokenUsageBatch(ctx, hashedDeltas, lastUsedAt)
}
// Compile-time interface check
var _ token.UsageStatsStore = (*SecureUsageStatsStore)(nil)
// CacheStatsStore defines the interface for cache stats storage.
// This mirrors proxy.CacheStatsStore to avoid import cycles.
type CacheStatsStore interface {
IncrementCacheHitCountBatch(ctx context.Context, deltas map[string]int) error
}
// SecureCacheStatsStore wraps a CacheStatsStore and hashes tokens before batch operations.
type SecureCacheStatsStore struct {
store CacheStatsStore
hasher TokenHasherInterface
}
// NewSecureCacheStatsStore creates a new SecureCacheStatsStore.
func NewSecureCacheStatsStore(store CacheStatsStore, hasher TokenHasherInterface) *SecureCacheStatsStore {
if hasher == nil {
hasher = NewNullTokenHasher()
}
return &SecureCacheStatsStore{
store: store,
hasher: hasher,
}
}
// IncrementCacheHitCountBatch increments cache_hit_count for multiple tokens.
// Token strings are hashed before the operation.
func (s *SecureCacheStatsStore) IncrementCacheHitCountBatch(ctx context.Context, deltas map[string]int) error {
if len(deltas) == 0 {
return nil
}
hashedDeltas := make(map[string]int, len(deltas))
for tokenString, delta := range deltas {
hashedToken := s.hasher.CreateLookupKey(tokenString)
hashedDeltas[hashedToken] = delta
}
return s.store.IncrementCacheHitCountBatch(ctx, hashedDeltas)
}
// Compile-time interface check
var _ CacheStatsStore = (*SecureCacheStatsStore)(nil)
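// Illustrative wiring sketch (not from this file): the secure wrappers compose
// around any conforming store so lookups use hashed keys. The underlying store
// value is an assumption here; NewNullTokenHasher is the package's pass-through
// fallback used when no hasher is configured.
//
//	var base token.RevocationStore // e.g. a DB-backed store from this package
//	secure := NewSecureRevocationStore(base, NewNullTokenHasher())
//	_ = secure.RevokeToken(context.Background(), "tok_abc123")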
package eventbus
import (
"context"
"net/http"
"sync"
"sync/atomic"
"time"
)
// Event represents an observability event emitted by the proxy.
type Event struct {
LogID int64 // Monotonic event log ID
RequestID string
Method string
Path string
Status int
Duration time.Duration
ResponseHeaders http.Header
ResponseBody []byte
RequestBody []byte
}
// EventBus is a simple interface for publishing events to subscribers.
type EventBus interface {
Publish(ctx context.Context, evt Event)
Subscribe() <-chan Event
Stop()
}
type busStats struct {
published atomic.Int64
dropped atomic.Int64
}
// InMemoryEventBus is an EventBus implementation backed by a buffered channel and
// fan-out broadcasting to multiple subscribers. Events are dispatched
// asynchronously to avoid blocking the request path.
type InMemoryEventBus struct {
bufferSize int
ch chan Event
subsMu sync.RWMutex
subscribers []chan Event
stopCh chan struct{}
wg sync.WaitGroup
retryInterval time.Duration
maxRetries int
stats busStats
}
// NewInMemoryEventBus creates a new in-memory event bus with the given buffer size.
func NewInMemoryEventBus(bufferSize int) *InMemoryEventBus {
b := &InMemoryEventBus{
bufferSize: bufferSize,
ch: make(chan Event, bufferSize),
stopCh: make(chan struct{}),
retryInterval: 10 * time.Millisecond,
maxRetries: 3,
}
b.wg.Add(1)
go b.loop()
return b
}
// Publish enqueues an event without blocking; when the buffer is full the event
// is dropped and counted in the bus stats. The context is currently unused.
func (b *InMemoryEventBus) Publish(ctx context.Context, evt Event) {
select {
case b.ch <- evt:
b.stats.published.Add(1)
default:
b.stats.dropped.Add(1)
}
}
// Subscribe returns a channel that receives events published to the bus.
// Each subscriber receives all events.
func (b *InMemoryEventBus) Subscribe() <-chan Event {
sub := make(chan Event, b.bufferSize)
b.subsMu.Lock()
b.subscribers = append(b.subscribers, sub)
b.subsMu.Unlock()
return sub
}
func (b *InMemoryEventBus) loop() {
defer b.wg.Done()
for {
select {
case evt := <-b.ch:
b.dispatch(evt)
case <-b.stopCh:
b.subsMu.Lock()
for _, sub := range b.subscribers {
close(sub)
}
b.subscribers = nil
b.subsMu.Unlock()
return
}
}
}
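// dispatch fans the event out to all subscribers, retrying with linear backoff
// when a subscriber's buffer is full; the event is silently dropped for any
// subscriber that remains saturated after maxRetries attempts.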
func (b *InMemoryEventBus) dispatch(evt Event) {
b.subsMu.RLock()
subs := append([]chan Event(nil), b.subscribers...)
b.subsMu.RUnlock()
for _, sub := range subs {
sent := false
for i := 0; i <= b.maxRetries; i++ {
select {
case sub <- evt:
sent = true
default:
time.Sleep(b.retryInterval * time.Duration(i+1))
}
if sent {
break
}
}
}
}
// Stop gracefully stops the event bus and closes all subscriber channels.
func (b *InMemoryEventBus) Stop() {
close(b.stopCh)
b.wg.Wait()
}
// Stats returns the number of published and dropped events.
func (b *InMemoryEventBus) Stats() (published, dropped int) {
return int(b.stats.published.Load()), int(b.stats.dropped.Load())
}
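// Usage sketch (assumes a single consumer goroutine; Stop closes all
// subscriber channels, terminating the range loop):
//
//	bus := NewInMemoryEventBus(1024)
//	sub := bus.Subscribe()
//	go func() {
//		for evt := range sub {
//			_ = evt // process event
//		}
//	}()
//	bus.Publish(context.Background(), Event{RequestID: "req-1", Method: "POST", Path: "/v1/chat/completions"})
//	bus.Stop()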
// Note: Legacy Redis LIST backend has been removed. Use Redis Streams for production
// deployments requiring durability and guaranteed delivery.
package eventbus
import (
"context"
"encoding/json"
"fmt"
"log"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9"
)
// RedisStreamsClient interface for Redis Streams operations.
// This abstraction allows for easy mocking in tests.
type RedisStreamsClient interface {
// XAdd adds an entry to a stream
XAdd(ctx context.Context, args *redis.XAddArgs) (string, error)
// XReadGroup reads entries from a stream using a consumer group
XReadGroup(ctx context.Context, args *redis.XReadGroupArgs) ([]redis.XStream, error)
// XAck acknowledges processed messages
XAck(ctx context.Context, stream, group string, ids ...string) (int64, error)
// XGroupCreateMkStream creates a consumer group (and the stream if needed)
XGroupCreateMkStream(ctx context.Context, stream, group, start string) error
// XPending returns pending entries for a consumer group
XPending(ctx context.Context, stream, group string) (*redis.XPending, error)
// XPendingExt returns detailed pending entries
XPendingExt(ctx context.Context, args *redis.XPendingExtArgs) ([]redis.XPendingExt, error)
// XClaim claims pending messages for a consumer
XClaim(ctx context.Context, args *redis.XClaimArgs) ([]redis.XMessage, error)
// XLen returns the length of a stream
XLen(ctx context.Context, stream string) (int64, error)
// XInfoGroups returns consumer group info for a stream
XInfoGroups(ctx context.Context, stream string) ([]redis.XInfoGroup, error)
}
// RedisStreamsClientAdapter adapts go-redis/v9 Client to the RedisStreamsClient interface.
type RedisStreamsClientAdapter struct {
Client *redis.Client
}
// XAdd adds an entry to a stream
func (a *RedisStreamsClientAdapter) XAdd(ctx context.Context, args *redis.XAddArgs) (string, error) {
return a.Client.XAdd(ctx, args).Result()
}
// XReadGroup reads entries from a stream using a consumer group
func (a *RedisStreamsClientAdapter) XReadGroup(ctx context.Context, args *redis.XReadGroupArgs) ([]redis.XStream, error) {
return a.Client.XReadGroup(ctx, args).Result()
}
// XAck acknowledges processed messages
func (a *RedisStreamsClientAdapter) XAck(ctx context.Context, stream, group string, ids ...string) (int64, error) {
return a.Client.XAck(ctx, stream, group, ids...).Result()
}
// XGroupCreateMkStream creates a consumer group (and the stream if needed)
func (a *RedisStreamsClientAdapter) XGroupCreateMkStream(ctx context.Context, stream, group, start string) error {
return a.Client.XGroupCreateMkStream(ctx, stream, group, start).Err()
}
// XPending returns pending entries for a consumer group
func (a *RedisStreamsClientAdapter) XPending(ctx context.Context, stream, group string) (*redis.XPending, error) {
return a.Client.XPending(ctx, stream, group).Result()
}
// XPendingExt returns detailed pending entries
func (a *RedisStreamsClientAdapter) XPendingExt(ctx context.Context, args *redis.XPendingExtArgs) ([]redis.XPendingExt, error) {
return a.Client.XPendingExt(ctx, args).Result()
}
// XClaim claims pending messages for a consumer
func (a *RedisStreamsClientAdapter) XClaim(ctx context.Context, args *redis.XClaimArgs) ([]redis.XMessage, error) {
return a.Client.XClaim(ctx, args).Result()
}
// XLen returns the length of a stream
func (a *RedisStreamsClientAdapter) XLen(ctx context.Context, stream string) (int64, error) {
return a.Client.XLen(ctx, stream).Result()
}
// XInfoGroups returns consumer group info for a stream
func (a *RedisStreamsClientAdapter) XInfoGroups(ctx context.Context, stream string) ([]redis.XInfoGroup, error) {
return a.Client.XInfoGroups(ctx, stream).Result()
}
// RedisStreamsConfig holds configuration for Redis Streams event bus.
type RedisStreamsConfig struct {
StreamKey string // Redis stream key name
ConsumerGroup string // Consumer group name
ConsumerName string // Unique consumer name within the group
MaxLen int64 // Max stream length (0 = unlimited, uses MAXLEN ~ approximation)
BlockTimeout time.Duration // Block timeout for XREADGROUP (0 = non-blocking)
ClaimMinIdleTime time.Duration // Minimum idle time before claiming pending messages
BatchSize int64 // Number of messages to read at once
}
// DefaultRedisStreamsConfig returns default configuration.
func DefaultRedisStreamsConfig() RedisStreamsConfig {
return RedisStreamsConfig{
StreamKey: "llm-proxy-events",
ConsumerGroup: "llm-proxy-dispatchers",
ConsumerName: "dispatcher-1",
MaxLen: 10000,
BlockTimeout: 5 * time.Second,
ClaimMinIdleTime: 30 * time.Second,
BatchSize: 100,
}
}
// RedisStreamsEventBus implements EventBus using Redis Streams.
// It provides durable, distributed event delivery with consumer groups,
// acknowledgment, and at-least-once delivery semantics.
type RedisStreamsEventBus struct {
client RedisStreamsClient
config RedisStreamsConfig
stats busStats
stopCh chan struct{}
stopOnce sync.Once
wg sync.WaitGroup
subscribers []chan Event
subsMu sync.RWMutex
groupCreated atomic.Bool
}
// NewRedisStreamsEventBus creates a new Redis Streams event bus.
func NewRedisStreamsEventBus(client RedisStreamsClient, config RedisStreamsConfig) *RedisStreamsEventBus {
return &RedisStreamsEventBus{
client: client,
config: config,
stopCh: make(chan struct{}),
}
}
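// Wiring sketch (assumes a reachable Redis instance; the adapter and config
// defaults are defined above):
//
//	rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379"})
//	bus := NewRedisStreamsEventBus(&RedisStreamsClientAdapter{Client: rdb}, DefaultRedisStreamsConfig())
//	sub := bus.Subscribe() // starts a consumer-group reader goroutine
//	defer bus.Stop()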
// EnsureConsumerGroup creates the consumer group if it doesn't exist.
// This should be called before starting to consume messages.
func (b *RedisStreamsEventBus) EnsureConsumerGroup(ctx context.Context) error {
if b.groupCreated.Load() {
return nil
}
// Try to create the group; if it already exists, that's fine
err := b.client.XGroupCreateMkStream(ctx, b.config.StreamKey, b.config.ConsumerGroup, "0")
if err != nil {
// Check if error is because group already exists
if isGroupExistsError(err) {
b.groupCreated.Store(true)
return nil
}
return fmt.Errorf("failed to create consumer group: %w", err)
}
b.groupCreated.Store(true)
return nil
}
// isGroupExistsError checks if the error indicates the group already exists.
func isGroupExistsError(err error) bool {
if err == nil {
return false
}
// Redis reports "BUSYGROUP Consumer Group name already exists"; match on the
// BUSYGROUP code prefix so minor wording changes don't break detection.
return strings.HasPrefix(err.Error(), "BUSYGROUP")
}
// Publish adds an event to the Redis stream using XADD.
func (b *RedisStreamsEventBus) Publish(ctx context.Context, evt Event) {
data, err := json.Marshal(evt)
if err != nil {
log.Printf("[eventbus] Failed to marshal event: %v", err)
b.stats.dropped.Add(1)
return
}
args := &redis.XAddArgs{
Stream: b.config.StreamKey,
Values: map[string]interface{}{
"data": string(data),
},
}
// Apply MaxLen if configured
if b.config.MaxLen > 0 {
args.MaxLen = b.config.MaxLen
args.Approx = true // Use ~ for better performance
}
_, err = b.client.XAdd(ctx, args)
if err != nil {
log.Printf("[eventbus] Failed to publish event to stream %s: %v", b.config.StreamKey, err)
b.stats.dropped.Add(1)
return
}
b.stats.published.Add(1)
}
// Subscribe returns a channel that receives events from the stream.
// This starts a background goroutine that reads from the stream using consumer groups.
func (b *RedisStreamsEventBus) Subscribe() <-chan Event {
ch := make(chan Event, b.config.BatchSize)
b.subsMu.Lock()
b.subscribers = append(b.subscribers, ch)
b.subsMu.Unlock()
b.wg.Add(1)
go b.consumeLoop(ch)
return ch
}
// consumeLoop reads messages from the stream and dispatches to the subscriber channel.
func (b *RedisStreamsEventBus) consumeLoop(ch chan Event) {
defer b.wg.Done()
defer close(ch)
ctx := context.Background()
// Ensure consumer group exists
if err := b.EnsureConsumerGroup(ctx); err != nil {
log.Printf("[eventbus] Failed to ensure consumer group: %v", err)
return
}
// First, process any pending messages (messages we received but didn't acknowledge)
b.processPendingMessages(ctx, ch)
// Then start normal consumption
for {
select {
case <-b.stopCh:
return
default:
}
// Read new messages
streams, err := b.client.XReadGroup(ctx, &redis.XReadGroupArgs{
Group: b.config.ConsumerGroup,
Consumer: b.config.ConsumerName,
Streams: []string{b.config.StreamKey, ">"},
Count: b.config.BatchSize,
Block: b.config.BlockTimeout,
})
if err != nil {
if err == redis.Nil {
// No new messages, check for pending messages to claim
b.claimPendingMessages(ctx, ch)
continue
}
// Check if context was cancelled or we're stopping
select {
case <-b.stopCh:
return
default:
log.Printf("[eventbus] Error reading from stream: %v", err)
time.Sleep(time.Second) // Back off on error
continue
}
}
// Process received messages
for _, stream := range streams {
for _, msg := range stream.Messages {
evt, err := b.parseMessage(msg)
if err != nil {
log.Printf("[eventbus] Failed to parse message %s: %v", msg.ID, err)
// Acknowledge invalid messages so they don't get stuck
_, _ = b.client.XAck(ctx, b.config.StreamKey, b.config.ConsumerGroup, msg.ID)
continue
}
// Try to send to subscriber
select {
case ch <- evt:
// Message delivered, acknowledge it
_, err := b.client.XAck(ctx, b.config.StreamKey, b.config.ConsumerGroup, msg.ID)
if err != nil {
log.Printf("[eventbus] Failed to acknowledge message %s: %v", msg.ID, err)
}
case <-b.stopCh:
return
}
}
}
}
}
// processPendingMessages handles messages that were delivered but not acknowledged.
// This is called on startup to handle messages from a previous crash.
func (b *RedisStreamsEventBus) processPendingMessages(ctx context.Context, ch chan Event) {
// Read our pending messages
streams, err := b.client.XReadGroup(ctx, &redis.XReadGroupArgs{
Group: b.config.ConsumerGroup,
Consumer: b.config.ConsumerName,
Streams: []string{b.config.StreamKey, "0"}, // "0" means pending messages
Count: b.config.BatchSize,
})
if err != nil && err != redis.Nil {
log.Printf("[eventbus] Error reading pending messages: %v", err)
return
}
for _, stream := range streams {
for _, msg := range stream.Messages {
evt, err := b.parseMessage(msg)
if err != nil {
log.Printf("[eventbus] Failed to parse pending message %s: %v", msg.ID, err)
// Acknowledge invalid messages
_, _ = b.client.XAck(ctx, b.config.StreamKey, b.config.ConsumerGroup, msg.ID)
continue
}
select {
case ch <- evt:
_, _ = b.client.XAck(ctx, b.config.StreamKey, b.config.ConsumerGroup, msg.ID)
case <-b.stopCh:
return
}
}
}
}
// claimPendingMessages claims messages from other consumers that have been idle too long.
// This implements the "at-least-once" delivery guarantee for crashed consumers.
func (b *RedisStreamsEventBus) claimPendingMessages(ctx context.Context, ch chan Event) {
// Get pending entries that have exceeded the idle time
pending, err := b.client.XPendingExt(ctx, &redis.XPendingExtArgs{
Stream: b.config.StreamKey,
Group: b.config.ConsumerGroup,
Start: "-",
End: "+",
Count: b.config.BatchSize,
})
if err != nil || len(pending) == 0 {
return
}
// Find messages that are idle long enough to claim
var toClaim []string
for _, p := range pending {
if p.Idle >= b.config.ClaimMinIdleTime {
toClaim = append(toClaim, p.ID)
}
}
if len(toClaim) == 0 {
return
}
// Claim the messages
messages, err := b.client.XClaim(ctx, &redis.XClaimArgs{
Stream: b.config.StreamKey,
Group: b.config.ConsumerGroup,
Consumer: b.config.ConsumerName,
MinIdle: b.config.ClaimMinIdleTime,
Messages: toClaim,
})
if err != nil {
log.Printf("[eventbus] Error claiming messages: %v", err)
return
}
// Process claimed messages
for _, msg := range messages {
evt, err := b.parseMessage(msg)
if err != nil {
log.Printf("[eventbus] Failed to parse claimed message %s: %v", msg.ID, err)
_, _ = b.client.XAck(ctx, b.config.StreamKey, b.config.ConsumerGroup, msg.ID)
continue
}
select {
case ch <- evt:
_, _ = b.client.XAck(ctx, b.config.StreamKey, b.config.ConsumerGroup, msg.ID)
case <-b.stopCh:
return
}
}
}
// parseMessage extracts an Event from a Redis stream message.
func (b *RedisStreamsEventBus) parseMessage(msg redis.XMessage) (Event, error) {
var evt Event
data, ok := msg.Values["data"]
if !ok {
return evt, fmt.Errorf("message missing 'data' field")
}
dataStr, ok := data.(string)
if !ok {
return evt, fmt.Errorf("'data' field is not a string")
}
if err := json.Unmarshal([]byte(dataStr), &evt); err != nil {
return evt, fmt.Errorf("failed to unmarshal event: %w", err)
}
return evt, nil
}
// Stop gracefully stops the event bus and closes all subscriber channels.
func (b *RedisStreamsEventBus) Stop() {
b.stopOnce.Do(func() {
close(b.stopCh)
b.wg.Wait()
})
}
// Stats returns the number of published and dropped events.
func (b *RedisStreamsEventBus) Stats() (published, dropped int) {
return int(b.stats.published.Load()), int(b.stats.dropped.Load())
}
// StreamLength returns the current length of the stream.
func (b *RedisStreamsEventBus) StreamLength(ctx context.Context) (int64, error) {
return b.client.XLen(ctx, b.config.StreamKey)
}
// PendingCount returns the number of pending messages in the consumer group.
func (b *RedisStreamsEventBus) PendingCount(ctx context.Context) (int64, error) {
pending, err := b.client.XPending(ctx, b.config.StreamKey, b.config.ConsumerGroup)
if err != nil {
return 0, err
}
return pending.Count, nil
}
// Client returns the underlying RedisStreamsClient.
func (b *RedisStreamsEventBus) Client() RedisStreamsClient {
return b.client
}
// Acknowledge manually acknowledges a message by ID.
// This is useful when external code handles message processing and acknowledgment.
func (b *RedisStreamsEventBus) Acknowledge(ctx context.Context, messageID string) error {
_, err := b.client.XAck(ctx, b.config.StreamKey, b.config.ConsumerGroup, messageID)
return err
}
package eventtransformer
import (
"bytes"
"compress/gzip"
"encoding/base64"
"encoding/json"
"io"
"strings"
"unicode/utf8"
"github.com/andybalholm/brotli"
)
// DecompressAndDecode attempts to base64-decode the value (standard, then
// URL-safe alphabet) and decompress it (gzip, brotli) according to the
// Content-Encoding header, falling back to decompressing the raw input.
// It returns the decoded string and true when decoding succeeded.
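// Round-trip sketch (gzip + base64, matching the decode order below):
//
//	var buf bytes.Buffer
//	zw := gzip.NewWriter(&buf)
//	_, _ = zw.Write([]byte(`{"ok":true}`))
//	_ = zw.Close()
//	enc := base64.StdEncoding.EncodeToString(buf.Bytes())
//	out, ok := DecompressAndDecode(enc, map[string]interface{}{"Content-Encoding": "gzip"})
//	// ok == true, out == `{"ok":true}`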
func DecompressAndDecode(val string, headers map[string]interface{}) (string, bool) {
// Only log errors, binary skipping, and major state changes
data := []byte(val)
encoding := ""
contentType := ""
for k, v := range headers {
key := strings.ToLower(strings.ReplaceAll(k, "-", "_"))
if key == "content_encoding" {
if arr, ok := v.([]interface{}); ok && len(arr) > 0 {
if s, ok := arr[0].(string); ok {
encoding = strings.ToLower(s)
}
} else if s, ok := v.(string); ok {
encoding = strings.ToLower(s)
}
}
if key == "content_type" {
if arr, ok := v.([]interface{}); ok && len(arr) > 0 {
if s, ok := arr[0].(string); ok {
contentType = strings.ToLower(s)
}
} else if s, ok := v.(string); ok {
contentType = strings.ToLower(s)
}
}
}
// Skip binary content types (audio, image, octet-stream)
if strings.HasPrefix(contentType, "audio/") || strings.HasPrefix(contentType, "image/") || contentType == "application/octet-stream" {
// Skipping decode for binary content-type
return val, false
}
// Helper: decompress according to Content-Encoding, returning the input on failure.
maybeDecompress := func(in []byte) []byte {
switch encoding {
case "gzip":
zr, err := gzip.NewReader(bytes.NewReader(in))
if err != nil {
return in
}
decompressed, err := io.ReadAll(zr)
_ = zr.Close()
if err != nil {
return in
}
return decompressed
case "br":
decompressed, err := io.ReadAll(brotli.NewReader(bytes.NewReader(in)))
if err != nil {
return in
}
return decompressed
}
return in
}
// 1. Try base64 decode (standard alphabet), then decompress.
if decoded, err := base64.StdEncoding.DecodeString(val); err == nil {
data = maybeDecompress(decoded)
if json.Valid(data) || utf8.Valid(data) {
return string(data), true
}
}
// 2. Try base64 decode (URL-safe alphabet) against the original input, then
// decompress. Decoding val (not the possibly mutated data from step 1) keeps
// the two alphabets independent.
if decoded, err := base64.URLEncoding.DecodeString(val); err == nil {
data = maybeDecompress(decoded)
if json.Valid(data) || utf8.Valid(data) {
return string(data), true
}
}
// 3. If base64 decoding fails, try decompressing the raw input (legacy case);
// when no decompression applies, this also covers plain JSON/UTF-8 input.
data = maybeDecompress([]byte(val))
if json.Valid(data) || utf8.Valid(data) {
return string(data), true
}
// Fallback: could not decode; return the original input.
return val, false
}
package eventtransformer
import (
"encoding/base64"
"encoding/json"
"os"
"strings"
"unicode/utf8"
"github.com/google/uuid"
)
// IsOpenAIStreaming detects if the response body is a sequence of OpenAI streaming chunks (data: ... lines).
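// Example (sketch): two or more "data: " chunks mark a streaming body.
//
//	body := "data: {\"id\":\"chatcmpl-1\"}\ndata: {\"id\":\"chatcmpl-1\"}\ndata: [DONE]"
//	_ = IsOpenAIStreaming(body) // true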
func IsOpenAIStreaming(body string) bool {
lines := strings.Split(body, "\n")
count := 0
for _, line := range lines {
if strings.HasPrefix(line, "data: ") && !strings.HasPrefix(line, "data: [DONE]") {
count++
}
}
return count > 1 // at least two chunks
}
// MergeOpenAIStreamingChunks parses and merges OpenAI streaming chunks into a single response object.
func MergeOpenAIStreamingChunks(body string) (map[string]any, error) {
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
var (
id string
object string
created int
model string
content strings.Builder
finish string
usage Usage
)
lines := strings.Split(body, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
data := strings.TrimPrefix(line, "data: ")
var chunk map[string]any
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
continue
}
if v, ok := chunk["id"].(string); ok && id == "" {
id = v
}
if v, ok := chunk["object"].(string); ok && object == "" {
object = v
}
if v, ok := chunk["created"].(float64); ok && created == 0 {
created = int(v)
}
if v, ok := chunk["model"].(string); ok && model == "" {
model = v
}
// Merge choices (guard type assertions so malformed chunks cannot panic)
if choices, ok := chunk["choices"].([]any); ok && len(choices) > 0 {
if choice, ok := choices[0].(map[string]any); ok {
if delta, ok := choice["delta"].(map[string]any); ok {
if c, ok := delta["content"].(string); ok {
content.WriteString(c)
}
}
if fr, ok := choice["finish_reason"].(string); ok && fr != "" {
finish = fr
}
}
}
// Merge usage if present (usually only in last chunk)
if u, ok := chunk["usage"].(map[string]any); ok {
if v, ok := u["prompt_tokens"].(float64); ok {
usage.PromptTokens = int(v)
}
if v, ok := u["completion_tokens"].(float64); ok {
usage.CompletionTokens = int(v)
}
if v, ok := u["total_tokens"].(float64); ok {
usage.TotalTokens = int(v)
}
}
}
merged := map[string]any{
"id": id,
"object": object,
"created": created,
"model": model,
"choices": []map[string]any{{
"index": 0,
"message": map[string]any{
"role": "assistant",
"content": content.String(),
},
"finish_reason": finish,
}},
}
if usage.PromptTokens > 0 || usage.CompletionTokens > 0 || usage.TotalTokens > 0 {
merged["usage"] = map[string]any{
"prompt_tokens": usage.PromptTokens,
"completion_tokens": usage.CompletionTokens,
"total_tokens": usage.TotalTokens,
}
}
// Merged OpenAI chunks successfully
return merged, nil
}
// MergeThreadStreamingChunks parses and merges assistant thread.run streaming events.
func MergeThreadStreamingChunks(body string) (map[string]any, error) {
type Usage struct {
PromptTokens int
CompletionTokens int
TotalTokens int
}
var (
id string
assistantID string
threadID string
status string
created int
model string
contentB strings.Builder
usage Usage
finalContent string // for thread.message.completed
)
lines := strings.Split(body, "\n")
foundCompleted := false
for i := 0; i < len(lines); i++ {
line := strings.TrimSpace(lines[i])
if strings.HasPrefix(line, "event: ") {
eventType := strings.TrimPrefix(line, "event: ")
i++
if i >= len(lines) {
break
}
dataLine := strings.TrimSpace(lines[i])
if !strings.HasPrefix(dataLine, "data: ") || dataLine == "data: [DONE]" {
continue
}
var msg map[string]any
if err := json.Unmarshal([]byte(strings.TrimPrefix(dataLine, "data: ")), &msg); err != nil {
continue
}
switch eventType {
case "thread.run.created", "thread.run.queued":
if v, ok := msg["id"].(string); ok && id == "" {
id = v
}
if v, ok := msg["assistant_id"].(string); ok && assistantID == "" {
assistantID = v
}
if v, ok := msg["thread_id"].(string); ok && threadID == "" {
threadID = v
}
if v, ok := msg["status"].(string); ok && status == "" {
status = v
}
if v, ok := msg["created_at"].(float64); ok && created == 0 {
created = int(v)
}
case "thread.message.delta":
if !foundCompleted {
if delta, ok := msg["delta"].(map[string]any); ok {
if segs, ok := delta["content"].([]any); ok {
for _, seg := range segs {
segMap, ok := seg.(map[string]any)
if !ok {
continue
}
if txt, ok := segMap["text"].(map[string]any); ok {
if v, ok := txt["value"].(string); ok {
contentB.WriteString(v)
}
}
}
}
}
}
case "thread.message.completed":
// If this event is found, ignore all previous deltas and use only this content
if contentArr, ok := msg["content"].([]any); ok {
var sb strings.Builder
for _, seg := range contentArr {
segMap, ok := seg.(map[string]any)
if !ok {
continue
}
if segMap["type"] == "text" {
if txt, ok := segMap["text"].(map[string]any); ok {
if val, ok := txt["value"].(string); ok {
sb.WriteString(val)
}
}
}
}
finalContent = sb.String()
foundCompleted = true
}
case "thread.run.step.completed", "thread.run.completed":
if u, ok := msg["usage"].(map[string]any); ok {
if v, ok := u["prompt_tokens"].(float64); ok {
usage.PromptTokens = int(v)
}
if v, ok := u["completion_tokens"].(float64); ok {
usage.CompletionTokens = int(v)
}
if v, ok := u["total_tokens"].(float64); ok {
usage.TotalTokens = int(v)
}
}
}
}
}
// If thread.message.completed was found, use only its content
if foundCompleted {
contentB.Reset()
contentB.WriteString(finalContent)
}
merged := map[string]any{
"id": id,
"object": "thread.run",
"created_at": created,
"assistant_id": assistantID,
"thread_id": threadID,
"status": status,
"model": model,
"choices": []map[string]any{{
"index": 0,
"message": map[string]any{
"role": "assistant",
"content": contentB.String(),
},
"finish_reason": "",
}},
}
if usage.PromptTokens+usage.CompletionTokens+usage.TotalTokens > 0 {
merged["usage"] = map[string]any{
"prompt_tokens": usage.PromptTokens,
"completion_tokens": usage.CompletionTokens,
"total_tokens": usage.TotalTokens,
}
}
// Merged thread.run chunks successfully
return merged, nil
}
// TransformEvent transforms an OpenAI event for logging/analytics.
// It handles OPTIONS skipping, header filtering, decoding, chunk merging, token counting, and snake_case normalization.
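// Callers should note that a nil map with a nil error means "skip this event"
// (e.g. OPTIONS preflight requests). Input keys follow the eventbus field names
// (Method, RequestBody, ResponseHeaders, ...); the result map is snake_cased.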
func (t *OpenAITransformer) TransformEvent(evt map[string]any) (map[string]any, error) {
// Skip OPTIONS requests
if method, _ := evt["Method"].(string); strings.ToUpper(method) == "OPTIONS" {
return nil, nil
}
// Filter out set-cookie from response_headers
if headers, ok := evt["ResponseHeaders"].(map[string]any); ok {
for k := range headers {
if strings.ToLower(k) == "set-cookie" {
delete(headers, k)
}
}
}
// Set request_id
requestID := ""
if headers, ok := evt["RequestHeaders"].(map[string]any); ok {
for k, v := range headers {
if strings.ToLower(k) == "x-request-id" {
if s, ok := v.(string); ok && s != "" {
requestID = s
break
}
}
}
}
if requestID == "" {
requestID = uuid.NewString()
}
evt["request_id"] = requestID
// Decode request_body to UTF-8 or compact JSON. The labeled break is required:
// a bare break inside the switch would only exit the switch, and the loop would
// then re-decode the freshly written "request_body" value.
requestBodyLoop:
for _, key := range []string{"RequestBody", "request_body"} {
if v, ok := evt[key]; ok {
switch val := v.(type) {
case string:
if val != "" {
decoded := tryBase64DecodeWithLog(val)
if compact, _, ok := normalizeToCompactJSON(decoded); ok {
evt["request_body"] = compact
// promptTokenSource is set below if request_body is present
} else if isValidUTF8(decoded) {
evt["request_body"] = decoded
} else {
evt["request_body"] = "[binary or undecodable data]"
}
break requestBodyLoop
}
case []byte:
if len(val) > 0 {
decoded := tryBase64DecodeWithLog(string(val))
evt["request_body"] = decoded
break requestBodyLoop
}
}
}
}
// Handle ResponseBody
contentType := ""
hdrMap, _ := evt["ResponseHeaders"].(map[string]any)
if hdrMap == nil {
hdrMap = map[string]any{}
}
if len(hdrMap) > 0 {
for k, v := range hdrMap {
if strings.ToLower(strings.ReplaceAll(k, "-", "_")) == "content_type" {
switch arr := v.(type) {
case []any:
if len(arr) > 0 {
if s, ok := arr[0].(string); ok {
contentType = strings.ToLower(s)
}
}
case string:
contentType = strings.ToLower(arr)
}
}
}
}
if respBody, ok := evt["ResponseBody"].(string); ok && respBody != "" {
decoded, okDecoded := DecompressAndDecode(respBody, hdrMap)
if !okDecoded {
decoded = tryBase64DecodeWithLog(respBody)
}
if strings.HasPrefix(contentType, "audio/") || strings.HasPrefix(contentType, "image/") || contentType == "application/octet-stream" {
if os.Getenv("LOG_BINARY_RESPONSES") == "1" {
evt["response_body"] = respBody
} else {
evt["response_body"] = "[binary or undecodable data]"
}
evt["response_body_binary"] = true
} else {
if IsOpenAIStreaming(decoded) {
var merged map[string]any
var err error
if strings.Contains(decoded, "event: thread.run") {
merged, err = MergeThreadStreamingChunks(decoded)
} else {
merged, err = MergeOpenAIStreamingChunks(decoded)
}
if err == nil {
comp, _ := json.Marshal(merged)
evt["response_body"] = string(comp)
if usage, ok := merged["usage"].(map[string]any); ok {
evt["TokenUsage"] = usage
}
} else {
// Merging stream chunks failed; fall back to the decoded response body.
evt["response_body"] = decoded
}
} else {
if compact, _, ok := normalizeToCompactJSON(decoded); ok {
evt["response_body"] = compact
} else if isValidUTF8(decoded) {
evt["response_body"] = decoded
} else {
evt["response_body"] = "[binary or undecodable data]"
}
// Extract usage from non-streaming completion
if resp, ok := evt["response_body"].(string); ok && json.Valid([]byte(resp)) {
var respObj map[string]any
if err := json.Unmarshal([]byte(resp), &respObj); err == nil {
if usage, ok := respObj["usage"].(map[string]any); ok {
evt["TokenUsage"] = usage
delete(respObj, "usage")
b, _ := json.Marshal(respObj)
evt["response_body"] = string(b)
}
}
}
}
}
}
// Token usage fallback
if _, has := evt["TokenUsage"]; !has {
if resp, ok := evt["response_body"].(string); ok && json.Valid([]byte(resp)) {
pt, ct := 0, 0
// Compose prompt token source: messages + instructions if present
var promptTokenSource string
if req, ok := evt["request_body"].(string); ok && req != "" {
var reqObj map[string]any
if err := json.Unmarshal([]byte(req), &reqObj); err == nil {
if msgs, ok := reqObj["messages"]; ok {
b, _ := json.Marshal(msgs)
promptTokenSource = string(b)
}
if instr, ok := reqObj["instructions"].(string); ok && instr != "" {
if promptTokenSource != "" {
promptTokenSource += instr
} else {
promptTokenSource = instr
}
}
}
}
if promptTokenSource != "" {
modelName := ""
if req, ok := evt["request_body"].(string); ok && req != "" {
var reqObj map[string]any
_ = json.Unmarshal([]byte(req), &reqObj)
if m, ok := reqObj["model"].(string); ok {
modelName = m
}
}
t, _ := CountOpenAITokensForModel(promptTokenSource, modelName)
pt = t
}
cnt, _ := extractAssistantReplyContent(resp)
if cnt != "" {
// Try to read model from the parsed response JSON
modelName := ""
var respObj map[string]any
if err := json.Unmarshal([]byte(resp), &respObj); err == nil {
if m, ok := respObj["model"].(string); ok {
modelName = m
}
}
tk, _ := CountOpenAITokensForModel(cnt, modelName)
ct = tk
}
evt["TokenUsage"] = map[string]int{"prompt_tokens": pt, "completion_tokens": ct, "total_tokens": pt + ct}
}
}
// Clean up and normalize
delete(evt, "RequestBody")
delete(evt, "ResponseBody")
delete(evt, "response_body_streamed")
return ToSnakeCaseMap(evt), nil
}
// -- helper functions --
// tryBase64DecodeWithLog returns the input unchanged when it is already valid
// JSON, otherwise attempts standard and then URL-safe base64 decoding, falling
// back to the raw input.
func tryBase64DecodeWithLog(val string) string {
if json.Valid([]byte(val)) {
return val
}
if b, err := base64.StdEncoding.DecodeString(val); err == nil {
return string(b)
}
if b, err := base64.URLEncoding.DecodeString(val); err == nil {
return string(b)
}
return val
}
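// normalizeToCompactJSON compacts a JSON string (including double-encoded JSON
// strings) and, when the value is an object with a "messages" field, also
// returns that field's compact JSON; the bool reports whether the input parsed
// as valid JSON.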
func normalizeToCompactJSON(input string) (string, string, bool) {
var obj any
if err := json.Unmarshal([]byte(input), &obj); err != nil {
return input, "", false
}
if s, ok := obj.(string); ok {
var inner any
if err := json.Unmarshal([]byte(s), &inner); err == nil {
b, _ := json.Marshal(inner)
str := string(b)
if m, ok := inner.(map[string]any); ok {
if msgs, ok := m["messages"]; ok {
mj, _ := json.Marshal(msgs)
return str, string(mj), true
}
}
return str, "", true
}
return s, "", true
}
b, err := json.Marshal(obj)
if err != nil {
return input, "", false
}
str := string(b)
if m, ok := obj.(map[string]any); ok {
if msgs, ok := m["messages"]; ok {
mj, _ := json.Marshal(msgs)
return str, string(mj), true
}
}
return str, "", true
}
func isValidUTF8(s string) bool { return utf8.ValidString(s) }
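// extractAssistantReplyContent returns choices[0].message.content from an
// OpenAI chat completion response, or "" when the shape does not match.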
func extractAssistantReplyContent(resp string) (string, error) {
var obj map[string]any
if err := json.Unmarshal([]byte(resp), &obj); err != nil {
return "", err
}
if ch, ok := obj["choices"].([]any); ok && len(ch) > 0 {
if c0, ok := ch[0].(map[string]any); ok {
if msg, ok := c0["message"].(map[string]any); ok {
if content, ok := msg["content"].(string); ok {
return content, nil
}
}
}
}
return "", nil
}
package eventtransformer
import (
"strings"
"unicode"
)
// ToSnakeCaseMap recursively converts all map keys to snake_case, replacing dashes and handling consecutive uppercase letters
func ToSnakeCaseMap(m map[string]interface{}) map[string]interface{} {
snake := make(map[string]interface{}, len(m))
for k, v := range m {
sk := ToSnakeCase(k)
switch vv := v.(type) {
case map[string]interface{}:
snake[sk] = ToSnakeCaseMap(vv)
case []interface{}:
arr := make([]interface{}, len(vv))
for i, elem := range vv {
if mm, ok := elem.(map[string]interface{}); ok {
arr[i] = ToSnakeCaseMap(mm)
} else {
arr[i] = elem
}
}
snake[sk] = arr
default:
snake[sk] = v
}
}
return snake
}
// ToSnakeCase converts a string from CamelCase, PascalCase, or kebab-case to snake_case
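// Examples (sketch; keys are assumed to be ASCII, as produced by the proxy):
//
//	ToSnakeCase("RequestID")    // "request_id"
//	ToSnakeCase("X-Request-ID") // "x_request_id"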
func ToSnakeCase(s string) string {
// Replace dashes with underscores
s = strings.ReplaceAll(s, "-", "_")
var out []rune
var prevLower, prevUnderscore bool
for i, r := range s {
if r == '_' {
out = append(out, r)
prevLower, prevUnderscore = false, true
continue
}
if unicode.IsUpper(r) {
if (i > 0 && prevLower) || (i > 0 && !prevUnderscore && i+1 < len(s) && unicode.IsLower(rune(s[i+1]))) {
out = append(out, '_')
}
out = append(out, unicode.ToLower(r))
prevLower, prevUnderscore = false, false
} else {
out = append(out, r)
prevLower, prevUnderscore = true, false
}
}
return string(out)
}
package eventtransformer
import (
"strings"
"github.com/pkoukk/tiktoken-go"
)
// CountOpenAITokens counts tokens using a general-purpose encoding.
// Note: Prefer CountOpenAITokensForModel when the model is known.
func CountOpenAITokens(text string) (int, error) {
enc, err := tiktoken.GetEncoding("cl100k_base")
if err != nil {
return 0, err
}
return len(enc.Encode(text, nil, nil)), nil
}
// CountOpenAITokensForModel selects an encoding based on the provided model name.
// Fallback rules:
// - 4o/omni/o1 family → o200k_base
// - otherwise → cl100k_base
// If EncodingForModel succeeds, it is used directly.
func CountOpenAITokensForModel(text, model string) (int, error) {
if model != "" {
if enc, err := tiktoken.EncodingForModel(model); err == nil {
return len(enc.Encode(text, nil, nil)), nil
}
}
// Heuristic fallback by family
var base string
lower := strings.ToLower(model)
switch {
case strings.Contains(lower, "gpt-4o"), strings.HasPrefix(lower, "o1"), strings.Contains(lower, "omni"):
base = "o200k_base"
default:
base = "cl100k_base"
}
enc, err := tiktoken.GetEncoding(base)
if err != nil {
// Last resort: try cl100k_base
enc, err = tiktoken.GetEncoding("cl100k_base")
if err != nil {
return 0, err
}
}
return len(enc.Encode(text, nil, nil)), nil
}
// Package eventtransformer provides event transformation logic for different LLM API providers.
// Each provider (e.g., OpenAI, Anthropic) should have its own transformer implementation.
package eventtransformer
// Transformer is the interface for provider-specific event transformers.
type Transformer interface {
TransformEvent(event map[string]interface{}) (map[string]interface{}, error)
}
// DispatchTransformer returns the appropriate transformer for a given provider.
func DispatchTransformer(provider string) Transformer {
switch provider {
case "openai":
return &OpenAITransformer{}
// case "anthropic":
// return &AnthropicTransformer{}
default:
return nil
}
}
// OpenAITransformer implements Transformer for OpenAI events.
type OpenAITransformer struct{}
package logging
import (
"context"
"os"
"strings"
"github.com/sofatutor/llm-proxy/internal/obfuscate"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// NewLogger creates a zap.Logger with the specified level, format, and optional file output.
// level can be debug, info, warn, or error. format can be json or console.
// If filePath is empty, logs are written to stdout.
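// Example (sketch):
//
//	logger, err := NewLogger("info", "json", "")
//	if err != nil {
//		// handle error
//	}
//	defer func() { _ = logger.Sync() }()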
func NewLogger(level, format, filePath string) (*zap.Logger, error) {
var lvl zapcore.Level
switch strings.ToLower(level) {
case "debug":
lvl = zapcore.DebugLevel
case "info", "":
lvl = zapcore.InfoLevel
case "warn":
lvl = zapcore.WarnLevel
case "error":
lvl = zapcore.ErrorLevel
default:
lvl = zapcore.InfoLevel
}
encCfg := zapcore.EncoderConfig{
TimeKey: "ts",
LevelKey: "level",
NameKey: "logger",
MessageKey: "msg",
CallerKey: "caller",
StacktraceKey: "stacktrace",
EncodeTime: zapcore.ISO8601TimeEncoder,
EncodeDuration: zapcore.StringDurationEncoder,
EncodeLevel: zapcore.LowercaseLevelEncoder,
}
var encoder zapcore.Encoder
if strings.ToLower(format) == "console" {
encoder = zapcore.NewConsoleEncoder(encCfg)
} else {
encoder = zapcore.NewJSONEncoder(encCfg)
}
var ws = zapcore.AddSync(os.Stdout)
if filePath != "" {
f, err := os.OpenFile(filePath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
if err != nil {
return nil, err
}
ws = f
}
core := zapcore.NewCore(encoder, ws, lvl)
return zap.New(core), nil
}
// Context keys for request and correlation IDs
type contextKey string
const (
requestIDKey contextKey = "request_id"
correlationIDKey contextKey = "correlation_id"
)
// Canonical field helpers for structured logging
// RequestFields returns fields for HTTP request logging
func RequestFields(requestID, method, path string, statusCode, durationMs int) []zap.Field {
return []zap.Field{
zap.String("request_id", requestID),
zap.String("method", method),
zap.String("path", path),
zap.Int("status_code", statusCode),
zap.Int("duration_ms", durationMs),
}
}
// CorrelationID returns a field for correlation ID
func CorrelationID(id string) zap.Field {
return zap.String("correlation_id", id)
}
// ProjectID returns a field for project ID
func ProjectID(id string) zap.Field {
return zap.String("project_id", id)
}
// TokenID returns a field for token ID (obfuscated for security)
func TokenID(token string) zap.Field {
return zap.String("token_id", obfuscate.ObfuscateTokenGeneric(token))
}
// ClientIP returns a field for client IP address
func ClientIP(ip string) zap.Field {
return zap.String("client_ip", ip)
}
// Context management for request/correlation IDs
// WithRequestID adds a request ID to the context
func WithRequestID(ctx context.Context, requestID string) context.Context {
return context.WithValue(ctx, requestIDKey, requestID)
}
// WithCorrelationID adds a correlation ID to the context
func WithCorrelationID(ctx context.Context, correlationID string) context.Context {
return context.WithValue(ctx, correlationIDKey, correlationID)
}
// GetRequestID retrieves the request ID from context
func GetRequestID(ctx context.Context) (string, bool) {
id, ok := ctx.Value(requestIDKey).(string)
return id, ok
}
// GetCorrelationID retrieves the correlation ID from context
func GetCorrelationID(ctx context.Context) (string, bool) {
id, ok := ctx.Value(correlationIDKey).(string)
return id, ok
}
// Logger enhancement helpers
// WithRequestContext adds request ID from context to logger if present
func WithRequestContext(ctx context.Context, logger *zap.Logger) *zap.Logger {
if requestID, ok := GetRequestID(ctx); ok {
return logger.With(zap.String("request_id", requestID))
}
return logger
}
// WithCorrelationContext adds correlation ID from context to logger if present
func WithCorrelationContext(ctx context.Context, logger *zap.Logger) *zap.Logger {
if correlationID, ok := GetCorrelationID(ctx); ok {
return logger.With(zap.String("correlation_id", correlationID))
}
return logger
}
// NewChildLogger creates a child logger with a component field
func NewChildLogger(parent *zap.Logger, component string) *zap.Logger {
return parent.With(zap.String("component", component))
}
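// Usage sketch combining the helpers above (logger is assumed to exist):
//
//	ctx := WithRequestID(context.Background(), "req-123")
//	log := WithRequestContext(ctx, NewChildLogger(logger, "proxy"))
//	log.Info("handled request")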
package middleware
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net"
"net/http"
"strings"
"time"
"github.com/sofatutor/llm-proxy/internal/eventbus"
"github.com/sofatutor/llm-proxy/internal/logging"
"go.uber.org/zap"
)
// Middleware defines a function that wraps an http.Handler.
type Middleware func(http.Handler) http.Handler
// ObservabilityConfig controls the behavior of the observability middleware.
type ObservabilityConfig struct {
Enabled bool
EventBus eventbus.EventBus
// MaxRequestBodyBytes limits request body capture for observability events. 0 means "use default".
MaxRequestBodyBytes int64
// MaxResponseBodyBytes limits response body capture for observability events. 0 means "use default".
MaxResponseBodyBytes int64
}
// ObservabilityMiddleware captures request/response data and forwards it to an event bus.
type ObservabilityMiddleware struct {
cfg ObservabilityConfig
logger *zap.Logger
}
// NewObservabilityMiddleware creates a new ObservabilityMiddleware instance.
func NewObservabilityMiddleware(cfg ObservabilityConfig, logger *zap.Logger) *ObservabilityMiddleware {
if logger == nil {
logger = zap.NewNop()
}
return &ObservabilityMiddleware{cfg: cfg, logger: logger}
}
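// Wiring sketch (mux and logger are assumed to exist; the bus is the in-memory
// implementation from this repository). RequestIDMiddleware is placed outermost
// so the request ID is in context before the observability handler reads it:
//
//	bus := eventbus.NewInMemoryEventBus(1024)
//	om := NewObservabilityMiddleware(ObservabilityConfig{Enabled: true, EventBus: bus}, logger)
//	handler := NewRequestIDMiddleware()(om.Middleware()(mux))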
// Middleware returns the http middleware function.
func (m *ObservabilityMiddleware) Middleware() Middleware {
if !m.cfg.Enabled || m.cfg.EventBus == nil {
return func(next http.Handler) http.Handler { return next }
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
maxReq := m.cfg.MaxRequestBodyBytes
if maxReq <= 0 {
maxReq = 64 * 1024 // default 64KB
}
maxResp := m.cfg.MaxResponseBodyBytes
if maxResp <= 0 {
maxResp = 256 * 1024 // default 256KB
}
crw := &captureResponseWriter{ResponseWriter: w, statusCode: http.StatusOK, maxBodyBytes: maxResp}
var reqBody []byte
if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch {
if r.Body != nil {
// Capture only up to maxReq bytes without consuming the downstream body.
// We read maxReq+1 bytes so we can detect truncation.
originalBody := r.Body
limitedReader := io.LimitReader(originalBody, maxReq+1)
bodyBytes, err := io.ReadAll(limitedReader)
if err == nil && len(bodyBytes) > 0 {
if int64(len(bodyBytes)) > maxReq {
reqBody = bodyBytes[:maxReq]
} else {
reqBody = bodyBytes
}
// Restore the body for downstream handlers (including any unread bytes from originalBody).
r.Body = &readerWithCloser{
r: io.MultiReader(bytes.NewReader(bodyBytes), originalBody),
c: originalBody,
}
} else {
// Restore the body even on read errors.
r.Body = originalBody
}
}
}
next.ServeHTTP(crw, r)
// Resolve request ID from header, then context, then response headers
reqID := r.Header.Get("X-Request-ID")
if reqID == "" {
if v, ok := logging.GetRequestID(r.Context()); ok {
reqID = v
}
}
if reqID == "" {
reqID = crw.Header().Get("X-Request-ID")
}
// Skip publishing cache hits (do not incur provider cost)
if v := strings.ToLower(crw.Header().Get("X-PROXY-CACHE")); v == "hit" || v == "conditional-hit" {
return
}
evt := eventbus.Event{
RequestID: reqID,
Method: r.Method,
Path: r.URL.Path,
Status: crw.statusCode,
Duration: time.Since(start),
// Clone once so the event payload is immutable and safe for async consumers.
ResponseHeaders: cloneHeader(crw.Header()),
ResponseBody: crw.body.Bytes(),
RequestBody: reqBody,
}
// Publish is non-blocking; avoid spawning a goroutine per request.
// Any heavy transformations (e.g., OpenAI metadata extraction) should happen in downstream consumers.
m.cfg.EventBus.Publish(context.Background(), evt)
})
}
}
// captureResponseWriter wraps http.ResponseWriter to capture status and body while supporting streaming.
type captureResponseWriter struct {
http.ResponseWriter
statusCode int
body bytes.Buffer
maxBodyBytes int64
capturedBytes int64
}
func (w *captureResponseWriter) WriteHeader(code int) {
w.statusCode = code
w.ResponseWriter.WriteHeader(code)
}
func (w *captureResponseWriter) Write(b []byte) (int, error) {
if w.maxBodyBytes <= 0 || w.capturedBytes < w.maxBodyBytes {
remaining := int64(len(b))
if w.maxBodyBytes > 0 {
remaining = w.maxBodyBytes - w.capturedBytes
}
if remaining > 0 {
toWrite := b
if int64(len(b)) > remaining {
toWrite = b[:remaining]
}
_, _ = w.body.Write(toWrite)
w.capturedBytes += int64(len(toWrite))
}
}
return w.ResponseWriter.Write(b)
}
func (w *captureResponseWriter) Flush() {
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}
func (w *captureResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if h, ok := w.ResponseWriter.(http.Hijacker); ok {
return h.Hijack()
}
return nil, nil, fmt.Errorf("hijack not supported")
}
func (w *captureResponseWriter) Push(target string, opts *http.PushOptions) error {
if p, ok := w.ResponseWriter.(http.Pusher); ok {
return p.Push(target, opts)
}
return http.ErrNotSupported
}
func cloneHeader(h http.Header) http.Header {
cloned := make(http.Header, len(h))
for k, v := range h {
vv := make([]string, len(v))
copy(vv, v)
cloned[k] = vv
}
return cloned
}
type readerWithCloser struct {
r io.Reader
c io.Closer
}
func (rc *readerWithCloser) Read(p []byte) (int, error) { return rc.r.Read(p) }
func (rc *readerWithCloser) Close() error { return rc.c.Close() }
package middleware
import (
"net/http"
"strings"
"github.com/google/uuid"
"github.com/sofatutor/llm-proxy/internal/logging"
)
// RequestIDMiddleware handles request and correlation ID context propagation
type RequestIDMiddleware struct{}
// NewRequestIDMiddleware creates a new middleware instance for request ID management
func NewRequestIDMiddleware() Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get or generate request ID
requestID := getOrGenerateID(r.Header.Get("X-Request-ID"))
// Get or generate correlation ID
correlationID := getOrGenerateID(r.Header.Get("X-Correlation-ID"))
// Add IDs to context
ctx := logging.WithRequestID(r.Context(), requestID)
ctx = logging.WithCorrelationID(ctx, correlationID)
// Set response headers
w.Header().Set("X-Request-ID", requestID)
w.Header().Set("X-Correlation-ID", correlationID)
// Continue with the request using the enriched context
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// getOrGenerateID returns the provided ID if valid, otherwise generates a new UUID
func getOrGenerateID(existingID string) string {
// Trim whitespace
existingID = strings.TrimSpace(existingID)
// If empty, generate new UUID
if existingID == "" {
return uuid.New().String()
}
// For now, accept any non-empty ID (could add validation later if needed)
return existingID
}
// Package obfuscate centralizes redaction/obfuscation helpers used across the codebase.
package obfuscate
import (
"strings"
)
// ObfuscateTokenGeneric obfuscates arbitrary token-like strings for display/logging.
// Behavior (kept for backward-compat with previous utils implementation):
// - length <= 4 → all asterisks of same length
// - 5..12 → keep first 2 characters, replace the rest with asterisks
// - > 12 → keep first 8 characters, then "...", then last 4 characters
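//
// Examples (sketch):
//
//	ObfuscateTokenGeneric("abcd")                 // "****"
//	ObfuscateTokenGeneric("abcdefgh")             // "ab******"
//	ObfuscateTokenGeneric("sk-1234567890abcdef")  // "sk-12345...cdef"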
func ObfuscateTokenGeneric(s string) string {
if len(s) <= 4 {
return strings.Repeat("*", len(s))
}
if len(s) <= 12 {
return s[:2] + strings.Repeat("*", len(s)-2)
}
return s[:8] + "..." + s[len(s)-4:]
}
// ObfuscateTokenSimple obfuscates token-like strings with a fixed pattern suitable for UIs.
// Behavior (kept for backward-compat with Admin template helper):
// - length <= 8 → "****"
// - > 8 → first 4 + "****" + last 4
func ObfuscateTokenSimple(s string) string {
if len(s) <= 8 {
return "****"
}
return s[:4] + "****" + s[len(s)-4:]
}
// ObfuscateTokenByPrefix obfuscates tokens that follow a known prefix convention (e.g., "sk-").
// If the string doesn't have the expected prefix, it falls back to the generic strategy.
// Behavior (kept for backward-compat with token.ObfuscateToken):
// - With prefix: keep the prefix, then show first 4 and last 4 of the remainder, replacing the middle with '*'
// If remainder <= 8, return the input unmodified
// - Without prefix: same result as ObfuscateTokenGeneric
func ObfuscateTokenByPrefix(s string, prefix string) string {
if s == "" {
return s
}
if !strings.HasPrefix(s, prefix) {
// Preserve previous behavior of token.ObfuscateToken: leave non-prefixed strings unchanged
return s
}
rest := s[len(prefix):]
if len(rest) <= 8 {
return s
}
visible := 4
first := rest[:visible]
last := rest[len(rest)-visible:]
middle := strings.Repeat("*", len(rest)-(visible*2))
return prefix + first + middle + last
}
package proxy
import (
"net/http"
"strings"
"sync"
"time"
)
type cachedResponse struct {
statusCode int
headers http.Header
body []byte
expiresAt time.Time
vary string // Vary header from upstream response for per-response cache key generation
}
// httpCache is a minimal cache interface used by the proxy cache layer.
// Implementations must be safe for concurrent use.
type httpCache interface {
Get(key string) (cachedResponse, bool)
Set(key string, value cachedResponse)
Purge(key string) bool // Remove exact key
PurgePrefix(prefix string) int // Remove all keys with prefix, return count
}
type inMemoryCache struct {
mu sync.RWMutex
store map[string]cachedResponse
}
func newInMemoryCache() *inMemoryCache {
return &inMemoryCache{store: make(map[string]cachedResponse)}
}
func (c *inMemoryCache) Get(key string) (cachedResponse, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
v, ok := c.store[key]
if !ok {
return cachedResponse{}, false
}
if time.Now().After(v.expiresAt) {
return cachedResponse{}, false
}
return v, true
}
func (c *inMemoryCache) Set(key string, value cachedResponse) {
c.mu.Lock()
c.store[key] = value
c.mu.Unlock()
}
func (c *inMemoryCache) Purge(key string) bool {
c.mu.Lock()
defer c.mu.Unlock()
_, exists := c.store[key]
if exists {
delete(c.store, key)
}
return exists
}
func (c *inMemoryCache) PurgePrefix(prefix string) int {
c.mu.Lock()
defer c.mu.Unlock()
count := 0
for key := range c.store {
if strings.HasPrefix(key, prefix) {
delete(c.store, key)
count++
}
}
return count
}
package proxy
import (
"crypto/sha256"
"encoding/hex"
"net/http"
"net/textproto"
"sort"
"strconv"
"strings"
"time"
)
// generateCacheKey builds a cache key from method, path, query and an optional
// ordered list of header names to include. When headersToInclude is nil or empty,
// no header values are incorporated. For methods carrying a body (POST/PUT/PATCH)
// this function also incorporates X-Body-Hash and an optional TTL derived from
// Cache-Control (public, max-age/s-maxage) to avoid collisions across different TTLs.
func generateCacheKey(r *http.Request, headersToInclude []string) string {
// Key base: METHOD|PATH|sorted(query)
// Host/scheme are intentionally excluded to keep keys stable across proxy ↔ upstream phases.
b := strings.Builder{}
b.WriteString(r.Method)
b.WriteString("|")
b.WriteString(r.URL.Path)
b.WriteString("|")
// Sorted query to normalize key
q := r.URL.Query()
keys := make([]string, 0, len(q))
for k := range q {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
vals := q[k]
sort.Strings(vals)
b.WriteString(k)
b.WriteString("=")
b.WriteString(strings.Join(vals, ","))
b.WriteString("&")
}
raw := b.String()
sum := sha256.Sum256([]byte(raw))
baseKey := hex.EncodeToString(sum[:])
// Include selected request header values (normalized) deterministically
vb := strings.Builder{}
if len(headersToInclude) > 0 {
for _, hk := range headersToInclude {
if v := r.Header.Get(hk); v != "" {
vb.WriteString(strings.ToLower(hk))
vb.WriteString(":")
vb.WriteString(strings.TrimSpace(v))
vb.WriteString("|")
}
}
}
vraw := baseKey + vb.String()
vsum := sha256.Sum256([]byte(vraw))
varyKey := hex.EncodeToString(vsum[:])
final := strings.Builder{}
final.WriteString(varyKey)
// For methods with body, include X-Body-Hash when present (computed in proxy)
if r.Header.Get("X-Body-Hash") != "" {
final.WriteString("|body=")
final.WriteString(r.Header.Get("X-Body-Hash"))
}
// If client explicitly opts into shared caching with a TTL (public + max-age/s-maxage),
// include the requested TTL in the key so different TTLs do not collide with older entries.
// Only apply this for methods that can carry a body (POST/PUT/PATCH) to avoid splitting
// GET/HEAD cache keys unnecessarily when origin already provides TTLs.
if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch {
cc := parseCacheControl(r.Header.Get("Cache-Control"))
if cc.publicCache && (cc.sMaxAge > 0 || cc.maxAge > 0) {
final.WriteString("|ttl=")
if cc.sMaxAge > 0 {
final.WriteString("smax=")
final.WriteString(strconv.Itoa(cc.sMaxAge))
} else {
final.WriteString("max=")
final.WriteString(strconv.Itoa(cc.maxAge))
}
}
}
return final.String()
}
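// CacheKeyFromRequest generates a cache key for a request using a conservative
// default header subset (Accept, Accept-Encoding, Accept-Language).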
func CacheKeyFromRequest(r *http.Request) string {
// Conservative subset until per-response Vary handling is applied by caller
return generateCacheKey(r, []string{"Accept", "Accept-Encoding", "Accept-Language"})
}
// CacheKeyFromRequestWithVary generates a cache key using the specified Vary header
// from the upstream response. This enables per-response driven cache key generation.
//
// Behavior:
// - When vary is empty or "*", this function returns a key that ignores header values
// (treated as "no vary" for key generation purposes).
// - Header names parsed from vary are normalized to canonical MIME header names.
func CacheKeyFromRequestWithVary(r *http.Request, vary string) string {
var headers []string
if vary != "" && vary != "*" {
headers = parseVaryHeader(vary)
sort.Strings(headers) // ensure consistent ordering
}
return generateCacheKey(r, headers)
}
// parseVaryHeader parses a Vary header value and returns the list of header names.
// It handles comma-separated values and normalizes header names.
func parseVaryHeader(vary string) []string {
if vary == "" || vary == "*" {
return nil
}
var headers []string
parts := strings.Split(vary, ",")
for _, part := range parts {
header := strings.TrimSpace(part)
if header != "" {
headers = append(headers, normalizeHeaderName(header))
}
}
return headers
}
// normalizeHeaderName canonicalizes a header name for consistent lookups.
func normalizeHeaderName(name string) string {
return textproto.CanonicalMIMEHeaderKey(name)
}
func isResponseCacheable(res *http.Response) bool {
if res == nil {
return false
}
cc := parseCacheControl(res.Header.Get("Cache-Control"))
if cc.noStore || cc.privateCache {
return false
}
// Cacheable status codes (basic set)
switch res.StatusCode {
case 200, 203, 301, 308, 404, 410:
// ok
default:
return false
}
// If Authorization was present on request, require explicit shared cache directives
if res.Request != nil {
if res.Request.Header.Get("Authorization") != "" {
if !cc.publicCache && cc.sMaxAge <= 0 {
return false
}
}
}
// Don't cache SSE
if strings.Contains(res.Header.Get("Content-Type"), "text/event-stream") {
return false
}
if res.Header.Get("Vary") == "*" {
return false
}
return true
}
type cacheControl struct {
noStore bool
noCache bool
mustReval bool
maxAge int
sMaxAge int
publicCache bool
privateCache bool
}
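// parseCacheControl parses a Cache-Control header value into the directives
// recognized above; e.g. "public, max-age=60" yields publicCache=true, maxAge=60.
// Unknown directives are ignored.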
func parseCacheControl(v string) cacheControl {
cc := cacheControl{}
parts := strings.Split(v, ",")
for _, p := range parts {
p = strings.TrimSpace(strings.ToLower(p))
switch {
case p == "no-store":
cc.noStore = true
case p == "no-cache":
cc.noCache = true
case p == "must-revalidate":
cc.mustReval = true
case p == "public":
cc.publicCache = true
case p == "private":
cc.privateCache = true
case strings.HasPrefix(p, "s-maxage="):
cc.sMaxAge = atoiSafe(strings.TrimPrefix(p, "s-maxage="))
case strings.HasPrefix(p, "max-age="):
cc.maxAge = atoiSafe(strings.TrimPrefix(p, "max-age="))
}
}
return cc
}
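// For example (illustrative):
// parseCacheControl("public, max-age=60, s-maxage=120") yields
// cacheControl{publicCache: true, maxAge: 60, sMaxAge: 120}.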
func cacheTTLFromHeaders(res *http.Response, defaultTTL time.Duration) time.Duration {
cc := parseCacheControl(res.Header.Get("Cache-Control"))
if cc.noStore {
return 0
}
if cc.sMaxAge > 0 {
return time.Duration(cc.sMaxAge) * time.Second
}
if cc.maxAge > 0 {
return time.Duration(cc.maxAge) * time.Second
}
if defaultTTL > 0 && (cc.publicCache || (!cc.privateCache && !cc.noCache)) {
return defaultTTL
}
return 0
}
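// For example (illustrative): with "Cache-Control: public, max-age=60,
// s-maxage=120" the shared-cache TTL is 120s (s-maxage wins over max-age);
// with no Cache-Control header at all, defaultTTL applies because the
// response is neither private nor no-cache.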
// requestForcedCacheTTL returns a TTL requested by the client via Cache-Control
// when the client explicitly asks for shared caching (public) and provides a TTL.
// This is primarily used for benchmarking when upstream does not send cache hints.
func requestForcedCacheTTL(req *http.Request) time.Duration {
if req == nil {
return 0
}
cc := parseCacheControl(req.Header.Get("Cache-Control"))
if !cc.publicCache {
return 0
}
if cc.sMaxAge > 0 {
return time.Duration(cc.sMaxAge) * time.Second
}
if cc.maxAge > 0 {
return time.Duration(cc.maxAge) * time.Second
}
return 0
}
func atoiSafe(s string) int {
n := 0
for i := 0; i < len(s); i++ {
ch := s[i]
if ch < '0' || ch > '9' {
break
}
n = n*10 + int(ch-'0')
}
return n
}
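// For example (illustrative): atoiSafe("120") == 120, atoiSafe("12x") == 12,
// and atoiSafe("abc") == 0; it parses leading digits only and never fails,
// which is sufficient for Cache-Control directive values.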
func cloneHeadersForCache(h http.Header) http.Header {
// Drop hop-by-hop headers and request-specific timing headers
drop := map[string]struct{}{
// Hop-by-hop headers (RFC 7230, Section 6.1)
"Connection": {},
"Keep-Alive": {},
"Proxy-Authenticate": {},
"Proxy-Authorization": {},
"Te": {}, // canonical MIME form of "TE"; http.Header keys are canonicalized
"Trailers": {},
"Transfer-Encoding": {},
"Upgrade": {},
// Request-specific timing headers (not cacheable)
"X-Upstream-Request-Start": {},
"X-Upstream-Request-Stop": {},
"X-Proxy-Received-At": {},
"X-Proxy-Final-Response-At": {},
"X-Proxy-First-Response-At": {},
"Date": {},
// Cookies are user-specific, not cacheable for shared cache
"Set-Cookie": {},
}
out := http.Header{}
for k, vals := range h {
if _, ok := drop[k]; ok {
continue
}
for _, v := range vals {
out.Add(k, v)
}
}
return out
}
// canServeCachedForRequest decides if a cached response is reusable for the given request.
// In particular, for requests with Authorization, only allow reuse when the cached
// response explicitly allows shared caching (public or s-maxage>0).
func canServeCachedForRequest(r *http.Request, cachedHeaders http.Header) bool {
if r == nil {
return false
}
if r.Header.Get("Authorization") == "" {
return true
}
cc := parseCacheControl(cachedHeaders.Get("Cache-Control"))
if cc.publicCache || cc.sMaxAge > 0 {
return true
}
return false
}
// conditionalRequestMatches returns true if the client's conditional headers
// (If-None-Match or If-Modified-Since) match the cached response headers.
// Simplified RFC 9110 semantics using strong comparison; sufficient for proxy cache use.
func conditionalRequestMatches(r *http.Request, cachedHeaders http.Header) bool {
if r == nil {
return false
}
// If-None-Match takes precedence over If-Modified-Since
if inm := strings.TrimSpace(r.Header.Get("If-None-Match")); inm != "" {
// Header.Get canonicalizes its argument, so a plain "ETag" lookup suffices;
// multiple candidate values are supported via comma separation.
etag := strings.TrimSpace(cachedHeaders.Get("ETag"))
if etag == "" {
return false
}
// Strip surrounding quotes for comparison robustness
ce := strings.Trim(etag, "\"")
for _, part := range strings.Split(inm, ",") {
p := strings.TrimSpace(part)
p = strings.Trim(p, "\"")
if p == "*" || p == ce {
return true
}
}
return false
}
if ims := strings.TrimSpace(r.Header.Get("If-Modified-Since")); ims != "" {
lm := strings.TrimSpace(cachedHeaders.Get("Last-Modified"))
if lm == "" {
return false
}
imsTime, err1 := http.ParseTime(ims)
lmTime, err2 := http.ParseTime(lm)
if err1 != nil || err2 != nil {
return false
}
if !lmTime.After(imsTime) {
return true
}
}
return false
}
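// For example (illustrative): a request carrying If-None-Match: "abc123"
// matches a cached ETag of "abc123" (quotes are stripped on both sides),
// allowing a 304 Not Modified; If-Modified-Since matches when the cached
// Last-Modified time is not newer than the supplied time.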
// hasClientCacheOptIn returns true if the client request explicitly opts into
// shared caching via Cache-Control (public with max-age or s-maxage > 0).
func hasClientCacheOptIn(r *http.Request) bool {
if r == nil {
return false
}
cc := parseCacheControl(r.Header.Get("Cache-Control"))
if !cc.publicCache {
return false
}
if cc.sMaxAge > 0 {
return true
}
if cc.maxAge > 0 {
return true
}
return false
}
// Client conditional header inspection is handled by conditionalRequestMatches.
// wantsRevalidation returns true if the client requests origin revalidation
// (e.g., Cache-Control: no-cache or max-age=0).
func wantsRevalidation(r *http.Request) bool {
if r == nil {
return false
}
ccVal := strings.ToLower(r.Header.Get("Cache-Control"))
if ccVal == "" {
return false
}
if strings.Contains(ccVal, "no-cache") {
return true
}
if strings.Contains(ccVal, "max-age=0") {
return true
}
return false
}
package proxy
import (
"context"
"encoding/json"
"os"
"strconv"
"time"
"github.com/redis/go-redis/v9"
)
// redisCache implements httpCache using Redis.
// It stores cachedResponse as JSON and uses Redis TTL for expiration.
type redisCache struct {
client *redis.Client
prefix string
scanCount int
}
func newRedisCache(client *redis.Client, keyPrefix string) *redisCache {
if keyPrefix == "" {
keyPrefix = "llmproxy:cache:"
}
// Determine SCAN batch size from env (default 2048)
scan := 2048
if v := os.Getenv("REDIS_SCAN_COUNT"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
scan = n
}
}
return &redisCache{client: client, prefix: keyPrefix, scanCount: scan}
}
type redisCachedResponse struct {
StatusCode int `json:"status_code"`
Headers map[string][]string `json:"headers"`
Body []byte `json:"body"`
Vary string `json:"vary"` // Vary header for per-response cache key generation
}
func (r *redisCache) Get(key string) (cachedResponse, bool) {
ctx := context.Background()
data, err := r.client.Get(ctx, r.prefix+key).Bytes()
if err != nil {
return cachedResponse{}, false
}
var rc redisCachedResponse
if err := json.Unmarshal(data, &rc); err != nil {
return cachedResponse{}, false
}
// Copy headers into a fresh map so the returned response does not alias the decoded struct
hdr := make(map[string][]string, len(rc.Headers))
for k, v := range rc.Headers {
hdr[k] = v
}
return cachedResponse{
statusCode: rc.StatusCode,
headers: hdr,
body: rc.Body,
vary: rc.Vary, // Include vary field
// Redis enforces the real TTL; expiresAt is synthetic and only needs to lie
// in the future so callers treat the entry as fresh.
expiresAt: time.Now().Add(time.Second),
}, true
}
func (r *redisCache) Set(key string, value cachedResponse) {
ctx := context.Background()
// Serialize
ser := redisCachedResponse{StatusCode: value.statusCode, Headers: value.headers, Body: value.body, Vary: value.vary}
payload, err := json.Marshal(ser)
if err != nil {
return
}
ttl := time.Until(value.expiresAt)
if ttl <= 0 {
return
}
_ = r.client.Set(ctx, r.prefix+key, payload, ttl).Err()
}
// Purge removes a single cache entry by exact key. Returns true if deleted.
func (r *redisCache) Purge(key string) bool {
ctx := context.Background()
res := r.client.Del(ctx, r.prefix+key)
n, _ := res.Result()
return n > 0
}
// PurgePrefix removes all cache entries whose keys start with the given prefix.
// Returns number of deleted keys. Uses SCAN to avoid blocking Redis.
func (r *redisCache) PurgePrefix(prefix string) int {
ctx := context.Background()
fullPrefix := r.prefix + prefix
var cursor uint64
total := 0
for {
keys, next, err := r.client.Scan(ctx, cursor, fullPrefix+"*", int64(r.scanCount)).Result()
if err != nil {
// Abort on scan error to avoid infinite loop; return what we deleted so far
return total
}
cursor = next
if len(keys) > 0 {
delCount, _ := r.client.Del(ctx, keys...).Result()
total += int(delCount)
}
if cursor == 0 {
break
}
}
return total
}
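// exampleRedisCache is an illustrative sketch (hypothetical wiring, assuming a
// Redis instance at localhost:6379): it shows how an entry round-trips through
// JSON serialization, with the Redis TTL enforcing expiry.
func exampleRedisCache() {
client := redis.NewClient(&redis.Options{Addr: "localhost:6379"})
c := newRedisCache(client, "llmproxy:cache:")
c.Set("example-key", cachedResponse{
statusCode: 200,
headers: map[string][]string{"Content-Type": {"application/json"}},
body: []byte(`{"ok":true}`),
expiresAt: time.Now().Add(30 * time.Second), // time.Until(expiresAt) becomes the Redis TTL
})
if rc, ok := c.Get("example-key"); ok {
_ = rc.body // served until Redis expires the key
}
}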
// Package proxy provides the transparent reverse proxy implementation.
package proxy
import (
"context"
"sync"
"time"
"github.com/sofatutor/llm-proxy/internal/obfuscate"
"go.uber.org/zap"
)
// CacheStatsStore defines the interface for persisting cache hit counts.
//
// Consistency invariant: Under normal operation, CacheHitCount ≤ RequestCount should
// always hold for any token. Cache hits are only recorded when responses are served
// from cache, which requires a prior upstream request that incremented RequestCount.
//
// If CacheHitCount > RequestCount is observed, it indicates a system issue that should
// be investigated (e.g., request count increment failed while cache hit was recorded).
// The Admin UI uses safeSub to display max(0, RequestCount-CacheHitCount) to handle
// this edge case gracefully.
type CacheStatsStore interface {
// IncrementCacheHitCountBatch increments cache_hit_count for multiple tokens.
// The deltas map has token IDs as keys and increment values as values.
IncrementCacheHitCountBatch(ctx context.Context, deltas map[string]int) error
}
// CacheStatsAggregatorConfig holds configuration for the cache stats aggregator.
type CacheStatsAggregatorConfig struct {
BufferSize int // Size of the buffered channel (default: 1000)
FlushInterval time.Duration // How often to flush stats to DB (default: 5s)
BatchSize int // Max events before flush (default: 100)
}
// DefaultCacheStatsAggregatorConfig returns the default configuration.
func DefaultCacheStatsAggregatorConfig() CacheStatsAggregatorConfig {
return CacheStatsAggregatorConfig{
BufferSize: 1000,
FlushInterval: 5 * time.Second,
BatchSize: 100,
}
}
// CacheStatsAggregator aggregates cache hit events and periodically flushes them to the database.
// It uses a buffered channel for non-blocking enqueue and drops events when the buffer is full.
type CacheStatsAggregator struct {
config CacheStatsAggregatorConfig
store CacheStatsStore
logger *zap.Logger
eventsCh chan string // channel of token IDs
stopCh chan struct{}
doneCh chan struct{}
mu sync.RWMutex
stopped bool
}
// NewCacheStatsAggregator creates a new aggregator with the given configuration.
func NewCacheStatsAggregator(config CacheStatsAggregatorConfig, store CacheStatsStore, logger *zap.Logger) *CacheStatsAggregator {
if config.BufferSize <= 0 {
config.BufferSize = 1000
}
if config.FlushInterval <= 0 {
config.FlushInterval = 5 * time.Second
}
if config.BatchSize <= 0 {
config.BatchSize = 100
}
if logger == nil {
logger = zap.NewNop()
}
return &CacheStatsAggregator{
config: config,
store: store,
logger: logger,
eventsCh: make(chan string, config.BufferSize),
stopCh: make(chan struct{}),
doneCh: make(chan struct{}),
}
}
// Start begins the background aggregation worker.
func (a *CacheStatsAggregator) Start() {
go a.run()
}
// Stop gracefully shuts down the aggregator, flushing any pending stats.
func (a *CacheStatsAggregator) Stop(ctx context.Context) error {
a.mu.Lock()
if a.stopped {
a.mu.Unlock()
return nil
}
a.stopped = true
a.mu.Unlock()
close(a.stopCh)
select {
case <-a.doneCh:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// RecordCacheHit enqueues a cache hit event for the given token.
// This is non-blocking; if the buffer is full, the event is dropped.
func (a *CacheStatsAggregator) RecordCacheHit(tokenID string) {
if tokenID == "" {
return
}
a.mu.RLock()
stopped := a.stopped
a.mu.RUnlock()
if stopped {
return
}
select {
case a.eventsCh <- tokenID:
// Event enqueued successfully
default:
// Buffer full, drop the event
a.logger.Debug("cache stats buffer full, dropping event",
zap.String("token_id", obfuscate.ObfuscateTokenGeneric(tokenID)))
}
}
// run is the main loop of the aggregator worker.
func (a *CacheStatsAggregator) run() {
defer close(a.doneCh)
ticker := time.NewTicker(a.config.FlushInterval)
defer ticker.Stop()
deltas := make(map[string]int)
eventCount := 0
flush := func() {
if len(deltas) == 0 {
return
}
// Take ownership of the accumulated map and start a fresh one
toFlush := deltas
deltas = make(map[string]int)
flushedCount := eventCount
eventCount = 0
// Flush with a short timeout
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
if err := a.store.IncrementCacheHitCountBatch(ctx, toFlush); err != nil {
a.logger.Warn("failed to flush cache hit stats",
zap.Error(err),
zap.Int("event_count", flushedCount),
zap.Int("token_count", len(toFlush)))
// On error, drop the stats (the counters are lossy by design)
} else {
a.logger.Debug("flushed cache hit stats",
zap.Int("event_count", flushedCount),
zap.Int("token_count", len(toFlush)))
}
}
for {
select {
case <-a.stopCh:
// Drain remaining events
for {
select {
case tokenID := <-a.eventsCh:
deltas[tokenID]++
eventCount++
default:
// No more events
flush()
return
}
}
case tokenID := <-a.eventsCh:
deltas[tokenID]++
eventCount++
if eventCount >= a.config.BatchSize {
flush()
}
case <-ticker.C:
flush()
}
}
}
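// Illustrative wiring (noopStatsStore is a hypothetical stand-in, not part of
// the original code): any CacheStatsStore implementation can back the aggregator.
type noopStatsStore struct{}

func (noopStatsStore) IncrementCacheHitCountBatch(ctx context.Context, deltas map[string]int) error {
return nil // a real store would apply the per-token deltas in one statement
}

func exampleAggregator() {
agg := NewCacheStatsAggregator(DefaultCacheStatsAggregatorConfig(), noopStatsStore{}, zap.NewNop())
agg.Start()
agg.RecordCacheHit("token-123") // non-blocking; dropped if the buffer is full
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
_ = agg.Stop(ctx) // drains and flushes pending deltas before returning
}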
package proxy
import (
"net/http"
"sync"
"time"
)
// Define a custom context key type for circuit breaker cooldown override
type cbContextKey string
const cbCooldownOverrideKey cbContextKey = "circuitbreaker_cooldown_override"
// CircuitBreakerMiddleware returns a middleware that opens the circuit after N consecutive failures.
// While open, it returns 503 immediately. After a cooldown, it closes and allows requests again.
func CircuitBreakerMiddleware(failureThreshold int, cooldown time.Duration, isTransient func(status int) bool) Middleware {
if isTransient == nil {
// Default: treat server errors (5xx) as transient failures
isTransient = func(status int) bool { return status >= 500 }
}
cb := &circuitBreaker{
failureThreshold: failureThreshold,
cooldown: cooldown,
isTransient: isTransient,
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cb.mu.Lock()
if cb.open {
// Allow test override of cooldown via context
if override, ok := r.Context().Value(cbCooldownOverrideKey).(time.Duration); ok {
cb.cooldown = override
}
if time.Since(cb.openedAt) < cb.cooldown {
cb.mu.Unlock()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusServiceUnavailable)
_, _ = w.Write([]byte(`{"error":"Upstream unavailable (circuit breaker open)"}`)) // Ignore error: nothing we can do if write fails
return
}
// Cooldown expired, close circuit
cb.open = false
cb.failureCount = 0
}
cb.mu.Unlock()
rec := &responseRecorder{ResponseWriter: w, statusCode: http.StatusOK}
next.ServeHTTP(rec, r)
if cb.isTransient(rec.statusCode) {
cb.mu.Lock()
cb.failureCount++
if cb.failureCount >= cb.failureThreshold {
cb.open = true
cb.openedAt = time.Now()
}
cb.mu.Unlock()
} else {
cb.mu.Lock()
cb.failureCount = 0
cb.mu.Unlock()
}
})
}
}
type circuitBreaker struct {
mu sync.Mutex
open bool
failureThreshold int
cooldown time.Duration
isTransient func(status int) bool
openedAt time.Time
failureCount int
}
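// exampleCircuitBreaker is an illustrative sketch (hypothetical values): it
// treats 5xx responses as transient, opening the circuit after 5 consecutive
// failures and rejecting requests with 503 during a 30s cooldown.
func exampleCircuitBreaker(next http.Handler) http.Handler {
mw := CircuitBreakerMiddleware(5, 30*time.Second, func(status int) bool {
return status >= 500
})
return mw(next)
}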
// Package proxy provides the transparent proxy functionality for the LLM API.
package proxy
import (
"fmt"
"os"
"time"
"gopkg.in/yaml.v3"
)
// APIConfig represents the top-level configuration for API proxying
type APIConfig struct {
// APIs contains configurations for different API providers
APIs map[string]*APIProviderConfig `yaml:"apis"`
// DefaultAPI is the default API provider to use if not specified
DefaultAPI string `yaml:"default_api"`
}
// APIProviderConfig represents the configuration for a specific API provider
type APIProviderConfig struct {
// BaseURL is the target API base URL
BaseURL string `yaml:"base_url"`
// AllowedEndpoints is a list of endpoint prefixes that are allowed to be accessed
AllowedEndpoints []string `yaml:"allowed_endpoints"`
// AllowedMethods is a list of HTTP methods that are allowed
AllowedMethods []string `yaml:"allowed_methods"`
// Timeouts for various operations
Timeouts TimeoutConfig `yaml:"timeouts"`
// Connection settings
Connection ConnectionConfig `yaml:"connection"`
// ParamWhitelist is a map of parameter names to allowed values
ParamWhitelist map[string][]string `yaml:"param_whitelist"`
AllowedOrigins []string `yaml:"allowed_origins"`
RequiredHeaders []string `yaml:"required_headers"`
}
// TimeoutConfig contains timeout settings for the proxy
type TimeoutConfig struct {
// Request is the overall request timeout
Request time.Duration `yaml:"request"`
// ResponseHeader is the timeout for receiving response headers
ResponseHeader time.Duration `yaml:"response_header"`
// IdleConnection is the timeout for idle connections
IdleConnection time.Duration `yaml:"idle_connection"`
// FlushInterval controls how often to flush streaming responses
FlushInterval time.Duration `yaml:"flush_interval"`
}
// ConnectionConfig contains connection settings for the proxy
type ConnectionConfig struct {
// MaxIdleConns is the maximum number of idle connections
MaxIdleConns int `yaml:"max_idle_conns"`
// MaxIdleConnsPerHost is the maximum number of idle connections per host
MaxIdleConnsPerHost int `yaml:"max_idle_conns_per_host"`
}
// LoadAPIConfigFromFile loads API configuration from a YAML file
func LoadAPIConfigFromFile(filePath string) (*APIConfig, error) {
data, err := os.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("failed to read config file: %w", err)
}
var config APIConfig
if err := yaml.Unmarshal(data, &config); err != nil {
return nil, fmt.Errorf("failed to parse config file: %w", err)
}
// Validate the configuration
if err := validateAPIConfig(&config); err != nil {
return nil, err
}
return &config, nil
}
// validateAPIConfig checks that the API configuration is valid
func validateAPIConfig(config *APIConfig) error {
if len(config.APIs) == 0 {
return fmt.Errorf("no API providers configured")
}
if config.DefaultAPI != "" {
if _, exists := config.APIs[config.DefaultAPI]; !exists {
return fmt.Errorf("default API '%s' not found in configured APIs", config.DefaultAPI)
}
}
for name, api := range config.APIs {
if api.BaseURL == "" {
return fmt.Errorf("API '%s' has empty base_url", name)
}
if len(api.AllowedEndpoints) == 0 {
return fmt.Errorf("API '%s' has no allowed_endpoints", name)
}
if len(api.AllowedMethods) == 0 {
return fmt.Errorf("API '%s' has no allowed_methods", name)
}
}
return nil
}
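// An illustrative config file (field names follow the yaml tags above; values
// are hypothetical):
//
//	apis:
//	  openai:
//	    base_url: https://api.openai.com
//	    allowed_endpoints:
//	      - /v1/chat/completions
//	    allowed_methods:
//	      - POST
//	default_api: openai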
// GetProxyConfigForAPI returns a ProxyConfig for the specified API provider
func (c *APIConfig) GetProxyConfigForAPI(apiName string) (*ProxyConfig, error) {
// If no API name specified, use the default
if apiName == "" {
apiName = c.DefaultAPI
}
// Find the API configuration
apiConfig, exists := c.APIs[apiName]
if !exists {
return nil, fmt.Errorf("API provider '%s' not found in configuration", apiName)
}
// Create the proxy configuration
proxyConfig := ProxyConfig{
TargetBaseURL: apiConfig.BaseURL,
AllowedEndpoints: apiConfig.AllowedEndpoints,
AllowedMethods: apiConfig.AllowedMethods,
RequestTimeout: apiConfig.Timeouts.Request,
ResponseHeaderTimeout: apiConfig.Timeouts.ResponseHeader,
FlushInterval: apiConfig.Timeouts.FlushInterval,
IdleConnTimeout: apiConfig.Timeouts.IdleConnection,
MaxIdleConns: apiConfig.Connection.MaxIdleConns,
MaxIdleConnsPerHost: apiConfig.Connection.MaxIdleConnsPerHost,
ParamWhitelist: apiConfig.ParamWhitelist,
AllowedOrigins: apiConfig.AllowedOrigins,
RequiredHeaders: apiConfig.RequiredHeaders,
}
return &proxyConfig, nil
}
package proxy
import (
"context"
"errors"
"net/http"
"time"
"github.com/sofatutor/llm-proxy/internal/audit"
)
// TokenValidator defines the interface for token validation
type TokenValidator interface {
// ValidateToken validates a token and returns the associated project ID
ValidateToken(ctx context.Context, token string) (string, error)
// ValidateTokenWithTracking validates a token, increments its usage, and returns the project ID
ValidateTokenWithTracking(ctx context.Context, token string) (string, error)
}
// ProjectStore defines the interface for retrieving and managing project information
// (extended for management API)
type ProjectStore interface {
// GetAPIKeyForProject retrieves the API key for a project
GetAPIKeyForProject(ctx context.Context, projectID string) (string, error)
// GetProjectActive checks if a project is active
GetProjectActive(ctx context.Context, projectID string) (bool, error)
// Management API CRUD
ListProjects(ctx context.Context) ([]Project, error)
CreateProject(ctx context.Context, project Project) error
GetProjectByID(ctx context.Context, projectID string) (Project, error)
UpdateProject(ctx context.Context, project Project) error
DeleteProject(ctx context.Context, projectID string) error
}
// ProjectActiveChecker defines the interface for checking project active status
type ProjectActiveChecker interface {
// GetProjectActive checks if a project is active
GetProjectActive(ctx context.Context, projectID string) (bool, error)
}
// AuditLogger defines the interface for audit event logging
type AuditLogger interface {
// Log records an audit event
Log(event *audit.Event) error
}
// Proxy defines the interface for a transparent HTTP proxy
type Proxy interface {
// Handler returns an http.Handler for the proxy
Handler() http.Handler
// Shutdown gracefully shuts down the proxy
Shutdown(ctx context.Context) error
}
// ProxyConfig contains configuration for the proxy
type ProxyConfig struct {
// TargetBaseURL is the base URL of the API to proxy to
TargetBaseURL string
// AllowedEndpoints is a whitelist of endpoints that can be accessed
AllowedEndpoints []string
// AllowedMethods is a whitelist of HTTP methods that can be used
AllowedMethods []string
// RequestTimeout is the maximum duration for a complete request
RequestTimeout time.Duration
// ResponseHeaderTimeout is the time to wait for response headers
ResponseHeaderTimeout time.Duration
// FlushInterval is how often to flush streaming responses
FlushInterval time.Duration
// MaxIdleConns is the maximum number of idle connections
MaxIdleConns int
// MaxIdleConnsPerHost is the maximum number of idle connections per host
MaxIdleConnsPerHost int
// IdleConnTimeout is how long to keep idle connections alive
IdleConnTimeout time.Duration
// LogLevel controls the verbosity of logging
LogLevel string
// LogFormat controls the log output format (json or console)
LogFormat string
// LogFile specifies a file path for logs (stdout if empty)
LogFile string
// SetXForwardedFor determines whether to set the X-Forwarded-For header
SetXForwardedFor bool
// ParamWhitelist is a map of parameter names to allowed values
ParamWhitelist map[string][]string
// AllowedOrigins is a list of allowed CORS origins for this provider
AllowedOrigins []string
// RequiredHeaders is a list of required request headers (case-insensitive)
RequiredHeaders []string
// Project active guard configuration
EnforceProjectActive bool // Whether to enforce project active status
// --- HTTP cache (global, opt-in; set programmatically, not via YAML) ---
// HTTPCacheEnabled toggles the proxy cache for GET/HEAD based on HTTP semantics
HTTPCacheEnabled bool
// HTTPCacheDefaultTTL is used only when upstream allows caching but omits explicit TTL
HTTPCacheDefaultTTL time.Duration
// HTTPCacheMaxObjectBytes is a guardrail for maximum cacheable response size
HTTPCacheMaxObjectBytes int64
// RedisCacheURL enables Redis-backed cache when non-empty (e.g., redis://localhost:6379/0)
RedisCacheURL string
// RedisCacheKeyPrefix allows namespacing cache keys (default: llmproxy:cache:)
RedisCacheKeyPrefix string
}
// Validate checks that the ProxyConfig is valid and returns an error if not.
func (c *ProxyConfig) Validate() error {
if c.TargetBaseURL == "" {
return errors.New("TargetBaseURL must not be empty")
}
if len(c.AllowedMethods) == 0 {
return errors.New("AllowedMethods must not be empty")
}
if len(c.AllowedEndpoints) == 0 {
return errors.New("AllowedEndpoints must not be empty")
}
return nil
}
// ErrorResponse is the standard format for error responses
type ErrorResponse struct {
Error string `json:"error"`
Description string `json:"description,omitempty"`
Code string `json:"code,omitempty"`
}
// Middleware defines a function that wraps an http.Handler
type Middleware func(http.Handler) http.Handler
// Chain applies a series of middleware to a handler
func Chain(h http.Handler, middleware ...Middleware) http.Handler {
for i := len(middleware) - 1; i >= 0; i-- {
h = middleware[i](h)
}
return h
}
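// For example (illustrative): Chain(h, A, B) serves requests as A(B(h)); the
// first middleware in the list becomes the outermost wrapper:
//
//	handler := Chain(mux,
//		CircuitBreakerMiddleware(5, 30*time.Second, func(s int) bool { return s >= 500 }),
//		ProjectActiveGuardMiddleware(true, checker, auditLogger),
//	)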
// contextKey is a type for context keys
type contextKey string
// version is the current version of the proxy
const version = "0.1.0"
const (
// ctxKeyRequestID is the context key for the request ID
ctxKeyRequestID contextKey = "request_id"
// ctxKeyProjectID is the context key for the project ID
ctxKeyProjectID contextKey = "project_id"
// ctxKeyTokenID is the context key for the token ID (used for cache stats)
ctxKeyTokenID contextKey = "token_id"
// ctxKeyLogger is the context key for a request-scoped logger
ctxKeyLogger contextKey = "logger"
// ctxKeyOriginalPath stores the original request path before proxy rewriting
ctxKeyOriginalPath contextKey = "original_path"
// ctxKeyValidationError carries token validation error (if any)
ctxKeyValidationError contextKey = "validation_error"
// Timing keys for observability
ctxKeyProxyReceivedAt contextKey = "proxy_received_at"
ctxKeyProxySentBackendAt contextKey = "proxy_sent_backend_at"
ctxKeyProxyFirstRespAt contextKey = "proxy_first_resp_at"
ctxKeyProxyFinalRespAt contextKey = "proxy_final_resp_at"
// ctxKeyRequestStart marks the time when a handler started processing
ctxKeyRequestStart contextKey = "request_start"
)
// Project represents a project for the management API and proxy
// (copied from database/models.go)
type Project struct {
ID string `json:"id"`
Name string `json:"name"`
APIKey string `json:"api_key"` // Provider-agnostic API key. Encrypted when ENCRYPTION_KEY is set.
IsActive bool `json:"is_active"`
DeactivatedAt *time.Time `json:"deactivated_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
package proxy
import (
"container/list"
"context"
"sync"
"time"
)
// CachedProjectActiveStore wraps a ProjectStore with an in-memory TTL+LRU cache for GetProjectActive.
//
// Rationale: GetProjectActive is on the hot path when EnforceProjectActive is enabled and can be a DB lookup.
// Caching avoids per-request DB round-trips in steady state.
type CachedProjectActiveStore struct {
underlying ProjectStore
cache *projectActiveCache
}
type CachedProjectActiveStoreConfig struct {
TTL time.Duration
Max int
}
func NewCachedProjectActiveStore(underlying ProjectStore, cfg CachedProjectActiveStoreConfig) *CachedProjectActiveStore {
if cfg.TTL <= 0 {
cfg.TTL = 5 * time.Second
}
if cfg.Max <= 0 {
cfg.Max = 10000
}
return &CachedProjectActiveStore{
underlying: underlying,
cache: newProjectActiveCache(cfg.TTL, cfg.Max),
}
}
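// Illustrative usage (assumed values): wrap the real store once at startup so
// the hot-path GetProjectActive check is served from the TTL+LRU cache:
//
//	store = NewCachedProjectActiveStore(store, CachedProjectActiveStoreConfig{
//		TTL: 5 * time.Second, // bounds staleness for out-of-band changes
//		Max: 10000,           // caps memory across many projects
//	})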
func (s *CachedProjectActiveStore) GetProjectActive(ctx context.Context, projectID string) (bool, error) {
if projectID != "" {
if v, ok := s.cache.Get(projectID); ok {
return v, nil
}
}
active, err := s.underlying.GetProjectActive(ctx, projectID)
if err != nil {
return false, err
}
if projectID != "" {
// Cache both true and false to avoid repeated DB lookups for inactive projects.
// Invalidation happens on Update/Delete/Create; TTL bounds staleness for out-of-band changes.
s.cache.Set(projectID, active)
}
return active, nil
}
func (s *CachedProjectActiveStore) GetAPIKeyForProject(ctx context.Context, projectID string) (string, error) {
return s.underlying.GetAPIKeyForProject(ctx, projectID)
}
func (s *CachedProjectActiveStore) ListProjects(ctx context.Context) ([]Project, error) {
return s.underlying.ListProjects(ctx)
}
func (s *CachedProjectActiveStore) CreateProject(ctx context.Context, project Project) error {
if err := s.underlying.CreateProject(ctx, project); err != nil {
return err
}
if project.ID != "" {
s.cache.Purge(project.ID)
}
return nil
}
func (s *CachedProjectActiveStore) GetProjectByID(ctx context.Context, projectID string) (Project, error) {
return s.underlying.GetProjectByID(ctx, projectID)
}
func (s *CachedProjectActiveStore) UpdateProject(ctx context.Context, project Project) error {
if err := s.underlying.UpdateProject(ctx, project); err != nil {
return err
}
if project.ID != "" {
s.cache.Purge(project.ID)
}
return nil
}
func (s *CachedProjectActiveStore) DeleteProject(ctx context.Context, projectID string) error {
if err := s.underlying.DeleteProject(ctx, projectID); err != nil {
return err
}
if projectID != "" {
s.cache.Purge(projectID)
}
return nil
}
type projectActiveCacheEntry struct {
key string
value bool
expiresAt time.Time
elem *list.Element
}
type projectActiveCache struct {
mu sync.Mutex
ll *list.List
m map[string]*projectActiveCacheEntry
ttl time.Duration
max int
}
func newProjectActiveCache(ttl time.Duration, max int) *projectActiveCache {
return &projectActiveCache{
ll: list.New(),
m: make(map[string]*projectActiveCacheEntry, max),
ttl: ttl,
max: max,
}
}
func (c *projectActiveCache) Get(key string) (bool, bool) {
now := time.Now()
c.mu.Lock()
defer c.mu.Unlock()
ent := c.m[key]
if ent == nil {
return false, false
}
if now.After(ent.expiresAt) {
c.removeLocked(ent)
return false, false
}
c.ll.MoveToFront(ent.elem)
return ent.value, true
}
func (c *projectActiveCache) Set(key string, value bool) {
c.mu.Lock()
defer c.mu.Unlock()
if ent := c.m[key]; ent != nil {
ent.value = value
ent.expiresAt = time.Now().Add(c.ttl)
c.ll.MoveToFront(ent.elem)
return
}
elem := c.ll.PushFront(key)
ent := &projectActiveCacheEntry{key: key, value: value, expiresAt: time.Now().Add(c.ttl), elem: elem}
c.m[key] = ent
if c.max > 0 && c.ll.Len() > c.max {
c.evictOldestLocked()
}
}
func (c *projectActiveCache) Purge(key string) {
c.mu.Lock()
defer c.mu.Unlock()
if ent := c.m[key]; ent != nil {
c.removeLocked(ent)
}
}
func (c *projectActiveCache) evictOldestLocked() {
elem := c.ll.Back()
if elem == nil {
return
}
key, ok := elem.Value.(string)
if !ok {
c.ll.Remove(elem)
return
}
if ent := c.m[key]; ent != nil {
c.removeLocked(ent)
return
}
c.ll.Remove(elem)
}
func (c *projectActiveCache) removeLocked(ent *projectActiveCacheEntry) {
delete(c.m, ent.key)
if ent.elem != nil {
c.ll.Remove(ent.elem)
}
}
package proxy
import (
"context"
"encoding/json"
"net"
"net/http"
"strings"
"github.com/sofatutor/llm-proxy/internal/audit"
"github.com/sofatutor/llm-proxy/internal/logging"
"go.uber.org/zap"
)
// shouldAllowProject determines whether the request should proceed based on project active status.
// It returns (allowed, statusCode, errorResponse). When allowed is true, statusCode and errorResponse are ignored.
// It also emits audit events for denied or error cases.
func shouldAllowProject(ctx context.Context, enforceActive bool, checker ProjectActiveChecker, projectID string, auditLogger AuditLogger, r *http.Request) (bool, int, ErrorResponse) {
if !enforceActive {
return true, 0, ErrorResponse{}
}
isActive, err := checker.GetProjectActive(ctx, projectID)
if err != nil {
// Get request metadata for audit events (only on error)
requestID, _ := logging.GetRequestID(ctx)
clientIP := getClientIP(r)
userAgent := r.UserAgent()
if logger := getLoggerFromContext(ctx); logger != nil {
logger.Error("Failed to check project active status",
zap.String("project_id", projectID),
zap.Error(err))
}
// Emit audit event for service unavailable
if auditLogger != nil {
auditEvent := audit.NewEvent(audit.ActionProxyRequest, audit.ActorSystem, audit.ResultError).
WithProjectID(projectID).
WithRequestID(requestID).
WithClientIP(clientIP).
WithUserAgent(userAgent).
WithHTTPMethod(r.Method).
WithEndpoint(r.URL.Path).
WithReason("service_unavailable").
WithError(err)
_ = auditLogger.Log(auditEvent)
}
return false, http.StatusServiceUnavailable, ErrorResponse{Error: "Service temporarily unavailable", Code: "service_unavailable"}
}
if !isActive {
// Get request metadata for audit events (only on deny)
requestID, _ := logging.GetRequestID(ctx)
clientIP := getClientIP(r)
userAgent := r.UserAgent()
// Emit audit event for project inactive denial
if auditLogger != nil {
auditEvent := audit.NewEvent(audit.ActionProxyRequest, audit.ActorSystem, audit.ResultDenied).
WithProjectID(projectID).
WithRequestID(requestID).
WithClientIP(clientIP).
WithUserAgent(userAgent).
WithHTTPMethod(r.Method).
WithEndpoint(r.URL.Path).
WithReason("project_inactive")
_ = auditLogger.Log(auditEvent)
}
return false, http.StatusForbidden, ErrorResponse{Error: "Project is inactive", Code: "project_inactive"}
}
return true, 0, ErrorResponse{}
}
// ProjectActiveGuardMiddleware creates middleware that enforces project active status
// If enforceActive is false, the middleware passes through all requests without checking
// If enforceActive is true, inactive projects receive a 403 Forbidden response
func ProjectActiveGuardMiddleware(enforceActive bool, checker ProjectActiveChecker, auditLogger AuditLogger) Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get project ID from context (should be set by token validation middleware)
projectIDValue := r.Context().Value(ctxKeyProjectID)
if projectIDValue == nil {
writeErrorResponse(w, http.StatusInternalServerError, ErrorResponse{
Error: "Internal server error",
Code: "internal_error",
Description: "missing project ID in request context",
})
return
}
projectID, ok := projectIDValue.(string)
if !ok || projectID == "" {
writeErrorResponse(w, http.StatusInternalServerError, ErrorResponse{
Error: "Internal server error",
Code: "internal_error",
Description: "invalid project ID in request context",
})
return
}
if allowed, status, er := shouldAllowProject(r.Context(), enforceActive, checker, projectID, auditLogger, r); !allowed {
writeErrorResponse(w, status, er)
return
}
// Project is active, continue with request
next.ServeHTTP(w, r)
})
}
}
// writeErrorResponse writes a JSON error response
func writeErrorResponse(w http.ResponseWriter, statusCode int, errorResp ErrorResponse) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
// Best-effort write: no request-scoped logger is available here, so encoding
// errors are intentionally ignored rather than risking a panic.
_ = json.NewEncoder(w).Encode(errorResp)
}
// getLoggerFromContext tries to get a logger from context
// This is a helper function that works with the existing logging context
func getLoggerFromContext(ctx context.Context) *zap.Logger {
// Try to get logger from context if available
// If not available, return nil and let caller handle gracefully
if loggerValue := ctx.Value(ctxKeyLogger); loggerValue != nil {
if logger, ok := loggerValue.(*zap.Logger); ok {
return logger
}
}
return nil
}
// getClientIP extracts the client IP address from the request
// It checks X-Forwarded-For, X-Real-IP headers and falls back to RemoteAddr
func getClientIP(r *http.Request) string {
// Check X-Forwarded-For header first (comma-separated list)
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Take the first IP from the list
if idx := strings.Index(xff, ","); idx >= 0 {
return strings.TrimSpace(xff[:idx])
}
return strings.TrimSpace(xff)
}
// Check X-Real-IP header
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return strings.TrimSpace(xri)
}
// Fall back to RemoteAddr, stripping the port if present.
// net.SplitHostPort handles both "host:port" and bracketed IPv6 ("[::1]:8080").
if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
return host
}
return r.RemoteAddr
}
package proxy
import (
"container/list"
"context"
"sync"
"time"
)
// CachedProjectStore wraps a ProjectStore with an in-memory TTL+LRU cache for GetAPIKeyForProject.
//
// Rationale: GetAPIKeyForProject is on the hot path for cache misses and currently performs a DB query.
// Caching avoids per-request DB round-trips in steady state.
//
// Security note: The API key is stored in memory in plaintext (same as other in-memory config); this does
// not change persistence characteristics and is scoped to the process lifetime.
type CachedProjectStore struct {
underlying ProjectStore
cache *apiKeyCache
}
type CachedProjectStoreConfig struct {
TTL time.Duration
Max int
}
func NewCachedProjectStore(underlying ProjectStore, cfg CachedProjectStoreConfig) *CachedProjectStore {
if cfg.TTL <= 0 {
cfg.TTL = 30 * time.Second
}
if cfg.Max <= 0 {
cfg.Max = 10000
}
return &CachedProjectStore{
underlying: underlying,
cache: newAPIKeyCache(cfg.TTL, cfg.Max),
}
}
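// Illustrative usage (assumed values): the same wrap-once pattern as
// CachedProjectActiveStore, here keeping API keys for cache misses in memory:
//
//	store = NewCachedProjectStore(store, CachedProjectStoreConfig{TTL: 30 * time.Second, Max: 10000})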
func (s *CachedProjectStore) GetAPIKeyForProject(ctx context.Context, projectID string) (string, error) {
if projectID != "" {
if v, ok := s.cache.Get(projectID); ok {
return v, nil
}
}
apiKey, err := s.underlying.GetAPIKeyForProject(ctx, projectID)
if err != nil {
return "", err
}
if projectID != "" && apiKey != "" {
// Intentionally do not cache empty API keys: an empty key typically indicates misconfiguration
// and we prefer not to "stick" that state in cache while an operator fixes the project.
s.cache.Set(projectID, apiKey)
}
return apiKey, nil
}
func (s *CachedProjectStore) GetProjectActive(ctx context.Context, projectID string) (bool, error) {
return s.underlying.GetProjectActive(ctx, projectID)
}
func (s *CachedProjectStore) ListProjects(ctx context.Context) ([]Project, error) {
return s.underlying.ListProjects(ctx)
}
func (s *CachedProjectStore) CreateProject(ctx context.Context, project Project) error {
if err := s.underlying.CreateProject(ctx, project); err != nil {
return err
}
if project.ID != "" {
// Defensive purge: ensures we never serve a stale API key for a re-created project ID
// (e.g., delete+recreate with same ID, or out-of-band DB changes).
s.cache.Purge(project.ID)
}
return nil
}
func (s *CachedProjectStore) GetProjectByID(ctx context.Context, projectID string) (Project, error) {
return s.underlying.GetProjectByID(ctx, projectID)
}
func (s *CachedProjectStore) UpdateProject(ctx context.Context, project Project) error {
if err := s.underlying.UpdateProject(ctx, project); err != nil {
return err
}
if project.ID != "" {
s.cache.Purge(project.ID)
}
return nil
}
func (s *CachedProjectStore) DeleteProject(ctx context.Context, projectID string) error {
if err := s.underlying.DeleteProject(ctx, projectID); err != nil {
return err
}
if projectID != "" {
s.cache.Purge(projectID)
}
return nil
}
type apiKeyCacheEntry struct {
key string
value string
expiresAt time.Time
elem *list.Element
}
type apiKeyCache struct {
mu sync.Mutex
ll *list.List
m map[string]*apiKeyCacheEntry
ttl time.Duration
max int
}
func newAPIKeyCache(ttl time.Duration, max int) *apiKeyCache {
return &apiKeyCache{
ll: list.New(),
m: make(map[string]*apiKeyCacheEntry, max),
ttl: ttl,
max: max,
}
}
func (c *apiKeyCache) Get(key string) (string, bool) {
now := time.Now()
c.mu.Lock()
defer c.mu.Unlock()
ent := c.m[key]
if ent == nil {
return "", false
}
if now.After(ent.expiresAt) {
c.removeLocked(ent)
return "", false
}
c.ll.MoveToFront(ent.elem)
return ent.value, true
}
func (c *apiKeyCache) Set(key, value string) {
if key == "" {
return
}
now := time.Now()
exp := now.Add(c.ttl)
c.mu.Lock()
defer c.mu.Unlock()
if ent := c.m[key]; ent != nil {
ent.value = value
ent.expiresAt = exp
c.ll.MoveToFront(ent.elem)
return
}
elem := c.ll.PushFront(key)
ent := &apiKeyCacheEntry{key: key, value: value, expiresAt: exp, elem: elem}
c.m[key] = ent
if c.max > 0 && c.ll.Len() > c.max {
c.evictOldestLocked()
}
}
func (c *apiKeyCache) Purge(key string) {
c.mu.Lock()
defer c.mu.Unlock()
if ent := c.m[key]; ent != nil {
c.removeLocked(ent)
}
}
func (c *apiKeyCache) evictOldestLocked() {
elem := c.ll.Back()
if elem == nil {
return
}
key, ok := elem.Value.(string)
if !ok {
// Value type is unexpected; remove the list element defensively.
c.ll.Remove(elem)
return
}
ent := c.m[key]
if ent != nil {
c.removeLocked(ent)
return
}
// Shouldn't happen, but be defensive.
c.ll.Remove(elem)
}
func (c *apiKeyCache) removeLocked(ent *apiKeyCacheEntry) {
delete(c.m, ent.key)
if ent.elem != nil {
c.ll.Remove(ent.elem)
}
}
package proxy
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"path"
"strconv"
"strings"
"sync"
"time"
"crypto/sha256"
"encoding/hex"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
"github.com/sofatutor/llm-proxy/internal/logging"
"github.com/sofatutor/llm-proxy/internal/middleware"
"github.com/sofatutor/llm-proxy/internal/token"
"go.uber.org/zap"
)
// TransparentProxy implements the Proxy interface for transparent proxying
type TransparentProxy struct {
config ProxyConfig
tokenValidator TokenValidator
projectStore ProjectStore
logger *zap.Logger
auditLogger AuditLogger
metrics *ProxyMetrics
proxy *httputil.ReverseProxy
targetURL *url.URL
httpServer *http.Server
shuttingDown bool
mu sync.RWMutex
allowedMethodsHeader string // cached comma-separated allowed methods
obsMiddleware *middleware.ObservabilityMiddleware
cache httpCache
cacheStatsAggregator *CacheStatsAggregator
}
// ProxyMetrics tracks proxy usage statistics
type ProxyMetrics struct {
RequestCount int64
ErrorCount int64
TotalResponseTime time.Duration
// Cache metrics (provider-agnostic counters)
CacheHits int64 // Cache hits (responses served from cache)
CacheMisses int64 // Cache misses (responses fetched from upstream)
CacheBypass int64 // Cache bypassed (e.g., due to authorization)
CacheStores int64 // Cache stores (responses stored in cache)
mu sync.Mutex
}
// CacheMetricType represents the kind of cache metric to increment.
type CacheMetricType int
const (
CacheMetricHit CacheMetricType = iota
CacheMetricMiss
CacheMetricBypass
CacheMetricStore
)
// Metrics returns a copy of the current proxy metrics.
// Returns a value copy to ensure thread-safety when reading metrics.
func (p *TransparentProxy) Metrics() ProxyMetrics {
// Defensive nil guard for p.metrics
if p.metrics == nil {
return ProxyMetrics{}
}
p.metrics.mu.Lock()
defer p.metrics.mu.Unlock()
// Return a copy to avoid race conditions when accessing fields
return ProxyMetrics{
RequestCount: p.metrics.RequestCount,
ErrorCount: p.metrics.ErrorCount,
TotalResponseTime: p.metrics.TotalResponseTime,
CacheHits: p.metrics.CacheHits,
CacheMisses: p.metrics.CacheMisses,
CacheBypass: p.metrics.CacheBypass,
CacheStores: p.metrics.CacheStores,
}
}
// SetMetrics overwrites the current metrics (primarily for testing).
func (p *TransparentProxy) SetMetrics(m *ProxyMetrics) {
p.mu.Lock()
defer p.mu.Unlock()
p.metrics = m
}
// Cache returns the HTTP cache instance for management operations.
// Returns nil if caching is disabled.
func (p *TransparentProxy) Cache() httpCache {
p.mu.RLock()
defer p.mu.RUnlock()
return p.cache
}
// SetCacheStatsAggregator sets the cache stats aggregator for per-token cache hit tracking.
func (p *TransparentProxy) SetCacheStatsAggregator(agg *CacheStatsAggregator) {
p.mu.Lock()
defer p.mu.Unlock()
p.cacheStatsAggregator = agg
}
// isVaryCompatible reports whether a cached response with a given Vary header
// is valid for the current request and lookup key.
func isVaryCompatible(r *http.Request, cr cachedResponse, lookupKey string) bool {
if cr.vary == "" || cr.vary == "*" {
return true
}
varyKey := CacheKeyFromRequestWithVary(r, cr.vary)
return varyKey == lookupKey
}
// storageKeyForResponse returns the cache storage key to use for a response,
// based on the upstream Vary header. Falls back to the lookup key when Vary is empty or '*'.
func storageKeyForResponse(r *http.Request, varyHeader string, lookupKey string) string {
if varyHeader != "" && varyHeader != "*" {
return CacheKeyFromRequestWithVary(r, varyHeader)
}
return lookupKey
}
// incrementCacheMetric safely increments the specified cache metric counter.
func (p *TransparentProxy) incrementCacheMetric(metric CacheMetricType) {
p.metrics.mu.Lock()
defer p.metrics.mu.Unlock()
switch metric {
case CacheMetricHit:
p.metrics.CacheHits++
case CacheMetricMiss:
p.metrics.CacheMisses++
case CacheMetricBypass:
p.metrics.CacheBypass++
case CacheMetricStore:
p.metrics.CacheStores++
}
}
// NewTransparentProxy creates a new proxy instance with an internally
// configured logger based on the provided ProxyConfig.
func NewTransparentProxy(config ProxyConfig, validator TokenValidator, store ProjectStore) (*TransparentProxy, error) {
logger, err := logging.NewLogger(config.LogLevel, config.LogFormat, config.LogFile)
if err != nil {
return nil, fmt.Errorf("failed to initialize logger: %w", err)
}
return NewTransparentProxyWithLogger(config, validator, store, logger)
}
// NewTransparentProxyWithObservability creates a new proxy with observability middleware.
func NewTransparentProxyWithObservability(config ProxyConfig, validator TokenValidator, store ProjectStore, obsCfg middleware.ObservabilityConfig) (*TransparentProxy, error) {
logger, err := logging.NewLogger(config.LogLevel, config.LogFormat, config.LogFile)
if err != nil {
return nil, fmt.Errorf("failed to initialize logger: %w", err)
}
return NewTransparentProxyWithLoggerAndObservability(config, validator, store, logger, obsCfg)
}
// NewTransparentProxyWithLogger allows providing a custom logger. If logger is nil
// a new one is created based on the ProxyConfig.
func NewTransparentProxyWithLogger(config ProxyConfig, validator TokenValidator, store ProjectStore, logger *zap.Logger) (*TransparentProxy, error) {
if logger == nil {
var err error
logger, err = logging.NewLogger(config.LogLevel, config.LogFormat, config.LogFile)
if err != nil {
return nil, fmt.Errorf("failed to initialize logger: %w", err)
}
}
// Precompute allowed methods header
allowedMethodsHeader := "GET, POST, PUT, PATCH, DELETE, OPTIONS"
if len(config.AllowedMethods) > 0 {
allowedMethodsHeader = strings.Join(config.AllowedMethods, ", ")
}
targetURL, err := url.Parse(config.TargetBaseURL)
if err != nil {
return nil, fmt.Errorf("failed to parse target base URL: %w", err)
}
proxy := &TransparentProxy{
config: config,
tokenValidator: validator,
projectStore: store,
logger: logger,
metrics: &ProxyMetrics{},
allowedMethodsHeader: allowedMethodsHeader,
targetURL: targetURL,
}
// Initialize HTTP cache (enabled only when HTTPCacheEnabled is true)
if !config.HTTPCacheEnabled {
logger.Info("HTTP cache disabled")
proxy.cache = nil
} else {
if config.RedisCacheURL != "" {
if opt, err := redis.ParseURL(config.RedisCacheURL); err == nil {
// Tune Redis client defaults for cache workloads.
// go-redis default pool sizing depends on GOMAXPROCS and can be too small
// for bursty cache-hit traffic (especially when an ingress pins many
// connections to a single pod).
//
// Env overrides:
// - REDIS_CACHE_POOL_SIZE: int (e.g., 50, 100)
// - REDIS_CACHE_TIMEOUT: duration (Go duration string, e.g., 1s, 250ms)
// Applies to dial/read/write.
envPoolSize := 0
if v := os.Getenv("REDIS_CACHE_POOL_SIZE"); v != "" {
if n, convErr := strconv.Atoi(v); convErr == nil && n > 0 {
envPoolSize = n
}
}
if envPoolSize > 0 {
// Env explicitly wins over redis:// URL options.
opt.PoolSize = envPoolSize
} else if opt.PoolSize < 50 {
// Minimum default for cache workloads.
opt.PoolSize = 50
}
timeout := 1 * time.Second
if v := os.Getenv("REDIS_CACHE_TIMEOUT"); v != "" {
if d, convErr := time.ParseDuration(v); convErr == nil && d > 0 {
timeout = d
}
}
if opt.DialTimeout == 0 {
opt.DialTimeout = timeout
}
if opt.ReadTimeout == 0 {
opt.ReadTimeout = timeout
}
if opt.WriteTimeout == 0 {
opt.WriteTimeout = timeout
}
client := redis.NewClient(opt)
proxy.cache = newRedisCache(client, config.RedisCacheKeyPrefix)
logger.Info(
"HTTP cache enabled",
zap.String("backend", "redis"),
zap.String("redis_addr", opt.Addr),
zap.Int("redis_pool_size", opt.PoolSize),
zap.Duration("redis_dial_timeout", opt.DialTimeout),
zap.Duration("redis_read_timeout", opt.ReadTimeout),
zap.Duration("redis_write_timeout", opt.WriteTimeout),
)
} else {
proxy.cache = newInMemoryCache()
logger.Warn("Failed to parse RedisCacheURL; falling back to in-memory cache", zap.Error(err))
}
} else {
proxy.cache = newInMemoryCache()
logger.Info("HTTP cache enabled", zap.String("backend", "in-memory"))
}
}
// Initialize the reverse proxy
reverseProxy := &httputil.ReverseProxy{
Director: proxy.director,
ModifyResponse: proxy.modifyResponse,
ErrorHandler: proxy.errorHandler,
Transport: proxy.createTransport(),
FlushInterval: config.FlushInterval,
}
proxy.proxy = reverseProxy
return proxy, nil
}
// NewTransparentProxyWithLoggerAndObservability creates a proxy with observability middleware using an existing logger.
func NewTransparentProxyWithLoggerAndObservability(config ProxyConfig, validator TokenValidator, store ProjectStore, logger *zap.Logger, obsCfg middleware.ObservabilityConfig) (*TransparentProxy, error) {
p, err := NewTransparentProxyWithLogger(config, validator, store, logger)
if err != nil {
return nil, err
}
p.obsMiddleware = middleware.NewObservabilityMiddleware(obsCfg, logger)
return p, nil
}
// NewTransparentProxyWithAudit creates a proxy with audit logging capabilities.
func NewTransparentProxyWithAudit(config ProxyConfig, validator TokenValidator, store ProjectStore, logger *zap.Logger, auditLogger AuditLogger, obsCfg middleware.ObservabilityConfig) (*TransparentProxy, error) {
p, err := NewTransparentProxyWithLoggerAndObservability(config, validator, store, logger, obsCfg)
if err != nil {
return nil, err
}
p.auditLogger = auditLogger
return p, nil
}
// director is the Director function for the reverse proxy
func (p *TransparentProxy) director(req *http.Request) {
// Store original path in context for logging
*req = *req.WithContext(context.WithValue(req.Context(), ctxKeyOriginalPath, req.URL.Path))
// Update request URL
req.URL.Scheme = p.targetURL.Scheme
req.URL.Host = p.targetURL.Host
req.Host = p.targetURL.Host
// Add proxy identification headers
req.Header.Set("X-Proxy", "llm-proxy")
req.Header.Set("X-Proxy-Version", version)
if pid, ok := req.Context().Value(ctxKeyProjectID).(string); ok {
req.Header.Set("X-Proxy-ID", pid)
}
// Preserve or strip certain headers
p.processRequestHeaders(req)
// Record upstream request start time so client-visible timing can be reconstructed
upstreamStart := time.Now().UnixNano()
req.Header.Set("X-UPSTREAM-REQUEST-START", strconv.FormatInt(upstreamStart, 10))
if !p.logger.Core().Enabled(zap.DebugLevel) {
return
}
requestID, _ := req.Context().Value(ctxKeyRequestID).(string)
p.logger.Debug("Proxying request",
zap.String("request_id", requestID),
zap.String("method", req.Method),
zap.String("path", req.URL.Path),
zap.String("project_id", req.Header.Get("X-Proxy-ID")),
)
// Verbose upstream request logging
headers := make(map[string][]string)
for k, v := range req.Header {
headers[k] = v
}
p.logger.Debug("Upstream request",
zap.String("request_id", requestID),
zap.String("method", req.Method),
zap.String("url", req.URL.String()),
zap.Any("headers", headers),
)
}
// processRequestHeaders handles the manipulation of request headers
func (p *TransparentProxy) processRequestHeaders(req *http.Request) {
// Headers to remove for security/privacy reasons
headersToRemove := []string{
"X-Forwarded-For", // We'll set this ourselves if needed
"X-Real-IP", // Remove client IP for privacy
"CF-Connecting-IP", // Cloudflare headers
"CF-IPCountry", // Cloudflare headers
"X-Client-IP", // Other proxies
"X-Original-Forwarded-For", // Chain of proxies
}
// Remove headers that shouldn't be passed to the upstream
for _, header := range headersToRemove {
req.Header.Del(header)
}
// Set X-Forwarded-For if configured to do so
if p.config.SetXForwardedFor {
// Get the client IP, stripping the port if present;
// net.SplitHostPort handles IPv6 addresses correctly.
clientIP := req.RemoteAddr
if host, _, err := net.SplitHostPort(clientIP); err == nil {
clientIP = host
}
req.Header.Set("X-Forwarded-For", clientIP)
}
// If Content-Length is 0 and there's a body, let Go calculate the correct Content-Length
if req.ContentLength == 0 && req.Body != nil {
req.Header.Del("Content-Length")
}
// Ensure proper Accept header for SSE streaming if needed
if isStreamingRequest(req) && req.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
}
}
// calculateCacheTTL determines the effective TTL for caching a response.
// It prefers TTL from response headers (when the response is cacheable),
// otherwise falls back to client-forced TTL from the request. It returns the
// chosen TTL and whether it came from the response headers.
func calculateCacheTTL(res *http.Response, req *http.Request, defaultTTL time.Duration) (time.Duration, bool) {
if res == nil || req == nil {
return 0, false
}
respTTL := cacheTTLFromHeaders(res, defaultTTL)
if respTTL > 0 {
if !isResponseCacheable(res) {
return 0, false
}
return respTTL, true
}
forcedTTL := requestForcedCacheTTL(req)
if forcedTTL > 0 {
return forcedTTL, false
}
return 0, false
}
func (p *TransparentProxy) modifyResponse(res *http.Response) error {
// Set proxy headers (always)
res.Header.Set("X-Proxy", "llm-proxy")
// Attach basic timing + request ID headers before ReverseProxy writes headers to the client.
// Note: ReverseProxy copies headers then calls WriteHeader, so post-WriteHeader mutations won't be visible.
if res.Request != nil {
firstRespAt := time.Now().UTC()
ctx := context.WithValue(res.Request.Context(), ctxKeyProxyFirstRespAt, firstRespAt)
ctx = context.WithValue(ctx, ctxKeyProxyFinalRespAt, firstRespAt)
res.Request = res.Request.WithContext(ctx)
setTimingHeaders(res, res.Request.Context())
if requestID, ok := res.Request.Context().Value(ctxKeyRequestID).(string); ok && requestID != "" {
res.Header.Set("X-Request-ID", requestID)
}
}
// Record when the upstream response was received, for client-visible timing
upstreamStop := time.Now().UnixNano()
res.Header.Set("X-UPSTREAM-REQUEST-STOP", strconv.FormatInt(upstreamStop, 10))
// For streaming responses, skip heavy side effects (metadata extraction, caching) but keep headers.
if isStreaming(res) {
return nil
}
// NOTE: Response metadata extraction (X-OpenAI-*) is intentionally not performed here.
// Reading/parsing response bodies in ModifyResponse can add latency/GC pressure on the hot path.
// Metadata extraction for logging/observability is handled asynchronously in the observability middleware.
// Update metrics
p.metrics.mu.Lock()
p.metrics.RequestCount++
if res.StatusCode >= 400 {
p.metrics.ErrorCount++
}
p.metrics.mu.Unlock()
// Propagate X-UPSTREAM-REQUEST-START from the request to the response headers
if res.Request != nil {
if v := res.Request.Header.Get("X-UPSTREAM-REQUEST-START"); v != "" {
res.Header.Set("X-UPSTREAM-REQUEST-START", v)
}
}
// Store in cache when enabled and request is cacheable
if p.cache != nil && res.Request != nil {
req := res.Request
if req.Method == http.MethodGet || req.Method == http.MethodHead || req.Method == http.MethodPost {
// Only cache successful responses
if res.StatusCode < 200 || res.StatusCode >= 300 {
res.Header.Set("X-CACHE-DEBUG", "status-not-cacheable")
return nil
}
// Calculate effective TTL
ttl, fromResponse := calculateCacheTTL(res, req, p.config.HTTPCacheDefaultTTL)
if ttl <= 0 {
res.Header.Set("X-CACHE-DEBUG", fmt.Sprintf("ttl-zero-ttl=%v-from-resp=%v", ttl, fromResponse))
return nil
}
// Ensure Cache-Status preserves miss set by handler
if res.Header.Get("Cache-Status") == "" {
res.Header.Set("Cache-Status", "llm-proxy; miss")
}
key := CacheKeyFromRequest(req)
// Compute storage key via helper to respect Vary
storageKey := storageKeyForResponse(req, res.Header.Get("Vary"), key)
// Given the early return for streaming responses above, this branch is
// always taken; the streaming-capture arm below is kept as a defensive path.
if !isStreaming(res) {
bodyBytes, err := io.ReadAll(res.Body)
if err == nil {
_ = res.Body.Close()
res.Body = io.NopCloser(bytes.NewReader(bodyBytes))
if p.config.HTTPCacheMaxObjectBytes == 0 || int64(len(bodyBytes)) <= p.config.HTTPCacheMaxObjectBytes {
headers := cloneHeadersForCache(res.Header)
if !fromResponse {
headers.Set("Cache-Control", fmt.Sprintf("public, max-age=%d", int(ttl.Seconds())))
}
// Store the Vary header for per-response cache key generation
varyValue := res.Header.Get("Vary")
cr := cachedResponse{
statusCode: res.StatusCode,
headers: headers,
body: bodyBytes,
expiresAt: time.Now().Add(ttl),
vary: varyValue,
}
p.cache.Set(storageKey, cr)
res.Header.Set("X-PROXY-CACHE", "stored")
res.Header.Set("X-PROXY-CACHE-KEY", storageKey)
p.incrementCacheMetric(CacheMetricStore)
if !fromResponse {
res.Header.Set("Cache-Status", "llm-proxy; stored (forced)")
} else {
res.Header.Set("Cache-Status", "llm-proxy; stored")
}
}
} else {
res.Header.Set("X-CACHE-DEBUG", fmt.Sprintf("read-body-error=%v", err))
}
} else {
res.Header.Set("X-CACHE-DEBUG", "streaming-response")
maxBytes := p.config.HTTPCacheMaxObjectBytes
if maxBytes <= 0 {
maxBytes = 2 * 1024 * 1024 // default 2MB
}
headers := cloneHeadersForCache(res.Header)
if !fromResponse {
headers.Set("Cache-Control", fmt.Sprintf("public, max-age=%d", int(ttl.Seconds())))
}
// Store the Vary header for per-response cache key generation
varyValue := res.Header.Get("Vary")
// Compute storage key via helper
storageKey := storageKeyForResponse(req, varyValue, key)
expiresAt := time.Now().Add(ttl)
orig := res.Body
res.Body = newStreamingCapture(orig, maxBytes, func(buf []byte) {
if len(buf) == 0 {
return
}
if int64(len(buf)) > maxBytes {
return
}
p.cache.Set(storageKey, cachedResponse{
statusCode: res.StatusCode,
headers: headers,
body: append([]byte(nil), buf...),
expiresAt: expiresAt,
vary: varyValue,
})
p.incrementCacheMetric(CacheMetricStore)
})
}
}
}
// Set miss status if no cache status was set
if res.Header.Get("Cache-Status") == "" {
res.Header.Set("Cache-Status", "llm-proxy; miss")
}
return nil
}
// errorHandler handles errors that occur during proxying
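// For example, an upstream timeout surfaces to the client as HTTP 504 with a
// JSON body whose Error is "Request timeout" and Code is "timeout" (field
// names follow the ErrorResponse JSON tags defined elsewhere in this package).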
func (p *TransparentProxy) errorHandler(w http.ResponseWriter, r *http.Request, err error) {
logProxyTimings(p.logger, r.Context())
// Check if there was a validation error
if validationErr, ok := r.Context().Value(ctxKeyValidationError).(error); ok {
p.handleValidationError(w, r, validationErr)
return
}
// Handle different error types
requestID, _ := r.Context().Value(ctxKeyRequestID).(string)
p.logger.Error("Proxy error",
zap.String("request_id", requestID),
zap.Error(err),
zap.String("method", r.Method),
zap.String("path", r.URL.Path))
statusCode := http.StatusBadGateway
errorResponse := ErrorResponse{
Error: "Proxy error",
}
switch {
case errors.Is(err, context.DeadlineExceeded):
statusCode = http.StatusGatewayTimeout
errorResponse.Error = "Request timeout"
errorResponse.Code = "timeout"
case errors.Is(err, context.Canceled):
statusCode = http.StatusRequestTimeout
errorResponse.Error = "Request canceled"
errorResponse.Code = "canceled"
default:
// Use default values
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
if err := json.NewEncoder(w).Encode(errorResponse); err != nil {
p.logger.Error("Failed to encode error response", zap.Error(err))
}
}
// handleValidationError handles errors specific to token validation
func (p *TransparentProxy) handleValidationError(w http.ResponseWriter, r *http.Request, err error) {
// Get request ID and token directly from the request
requestID, _ := r.Context().Value(ctxKeyRequestID).(string)
var obfuscatedToken string
authHeader := r.Header.Get("Authorization")
if len(authHeader) > 7 && strings.HasPrefix(authHeader, "Bearer ") {
tok := strings.TrimSpace(authHeader[7:])
obfuscatedToken = token.ObfuscateToken(tok)
}
statusCode := http.StatusUnauthorized
errorResponse := ErrorResponse{
Error: "Authentication error",
}
// Check for specific token errors
switch {
case errors.Is(err, token.ErrTokenNotFound):
errorResponse.Error = "Token not found"
errorResponse.Code = "token_not_found"
case errors.Is(err, token.ErrTokenInactive):
errorResponse.Error = "Token is inactive"
errorResponse.Code = "token_inactive"
case errors.Is(err, token.ErrTokenExpired):
errorResponse.Error = "Token has expired"
errorResponse.Code = "token_expired"
case errors.Is(err, token.ErrTokenRateLimit):
statusCode = http.StatusTooManyRequests
errorResponse.Error = "Rate limit exceeded"
errorResponse.Code = "rate_limit_exceeded"
default:
errorResponse.Error = "Invalid token"
errorResponse.Description = err.Error()
errorResponse.Code = "invalid_token"
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
if err := json.NewEncoder(w).Encode(errorResponse); err != nil {
p.logger.Error("Failed to encode error response", zap.Error(err))
}
p.logger.Error("Validation error",
zap.String("request_id", requestID),
zap.Error(err),
zap.String("token", obfuscatedToken),
)
}
// createTransport creates an HTTP transport with appropriate settings
func (p *TransparentProxy) createTransport() *http.Transport {
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: p.config.MaxIdleConns,
MaxIdleConnsPerHost: p.config.MaxIdleConnsPerHost,
IdleConnTimeout: p.config.IdleConnTimeout,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ResponseHeaderTimeout: p.config.ResponseHeaderTimeout,
}
}
// getMaxBodyHashBytes returns the configured max body size for hashing, or a 1MB default.
func (p *TransparentProxy) getMaxBodyHashBytes() int64 {
if p.config.HTTPCacheMaxObjectBytes > 0 {
return p.config.HTTPCacheMaxObjectBytes
}
return 1024 * 1024
}
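// readerWithCloser pairs a replacement reader with the original body's closer
// so a restored request body still releases the underlying connection
// resources when closed.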
type readerWithCloser struct {
r io.Reader
c io.Closer
}
func (rc *readerWithCloser) Read(p []byte) (int, error) { return rc.r.Read(p) }
func (rc *readerWithCloser) Close() error { return rc.c.Close() }
// prepareBodyHashForCaching reads the request body up to maxBytes,
// computes a SHA-256 hash, sets X-Body-Hash header, and restores the body.
// Returns true if successful, false if body exceeds limits or read fails.
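// For example, a small JSON body well under maxBytes is read fully,
// X-Body-Hash is set to the hex-encoded SHA-256 of the bytes, and the body is
// restored so the request can still be forwarded upstream unchanged.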
func prepareBodyHashForCaching(r *http.Request, maxBytes int64, logger *zap.Logger) bool {
if r.Body == nil || r.ContentLength < 0 || r.ContentLength > maxBytes {
return false
}
originalBody := r.Body
// Enforce a hard limit on how much we read from the body to avoid unbounded memory usage.
// We read up to maxBytes+1 so we can detect if the body is larger than allowed.
limitedReader := io.LimitReader(originalBody, maxBytes+1)
bodyBytes, readErr := io.ReadAll(limitedReader)
if readErr != nil {
logger.Warn("Failed to read request body for hashing", zap.Error(readErr))
// Restore the body with whatever we have read plus any remaining unread bytes.
// Note: io.LimitReader stops after maxBytes+1 bytes and does not drain the underlying body.
r.Body = &readerWithCloser{r: io.MultiReader(bytes.NewReader(bodyBytes), originalBody), c: originalBody}
return false
}
if int64(len(bodyBytes)) > maxBytes {
logger.Warn("Request body exceeds maxBytes for hashing; skipping body hash",
zap.Int64("max_bytes", maxBytes),
zap.Int64("read_bytes", int64(len(bodyBytes))),
)
// Restore the full body (the bytes we already consumed plus what remains unread).
r.Body = &readerWithCloser{r: io.MultiReader(bytes.NewReader(bodyBytes), originalBody), c: originalBody}
return false
}
// Body is within the allowed size. Restore it from the bytes we read and compute the hash.
r.Body = &readerWithCloser{r: bytes.NewReader(bodyBytes), c: originalBody}
sum := sha256.Sum256(bodyBytes)
r.Header.Set("X-Body-Hash", hex.EncodeToString(sum[:]))
return true
}
// Handler returns the HTTP handler for the proxy
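// Responses carry a Cache-Status header of the form "llm-proxy; <state>",
// where <state> is one of hit, miss, bypass, conditional-hit, stored, or
// stored (forced); definitive outcomes also set X-PROXY-CACHE(-KEY) headers.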
func (p *TransparentProxy) Handler() http.Handler {
baseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Short-circuit OPTIONS requests: no auth required, respond with 204 and CORS headers
if r.Method == http.MethodOptions {
// Set CORS headers for preflight requests
if origin := r.Header.Get("Origin"); origin != "" {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
if reqHeaders := r.Header.Get("Access-Control-Request-Headers"); reqHeaders != "" {
w.Header().Set("Access-Control-Allow-Headers", reqHeaders)
} else {
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Requested-With")
}
w.Header().Set("Access-Control-Expose-Headers", "X-Request-ID, X-Proxy-ID, X-LLM-Proxy-Remote-Duration, X-LLM-Proxy-Remote-Duration-Ms")
w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours
}
w.WriteHeader(http.StatusNoContent)
return
}
// Generate request ID and add to context
requestID := uuid.NewString()
ctx := context.WithValue(r.Context(), ctxKeyRequestID, requestID)
// Also add to shared logging context so middlewares can read it
ctx = logging.WithRequestID(ctx, requestID)
r = r.WithContext(ctx)
// Set request header so observability/file logger can capture it (response header is still set in ModifyResponse only)
r.Header.Set("X-Request-ID", requestID)
// Record when proxy receives the request
receivedAt := time.Now().UTC()
ctx = context.WithValue(ctx, ctxKeyProxyReceivedAt, receivedAt)
r = r.WithContext(ctx)
// Pre-check cache so we can avoid token usage tracking / upstream auth lookup
// on true cache hits. We still enforce auth and project status before serving.
var (
preCacheKey string
preCacheRes cachedResponse
preCacheOK bool
)
if p.cache != nil && (r.Method == http.MethodGet || r.Method == http.MethodHead || r.Method == http.MethodPost) {
// For POST, we need to check cache opt-in and prepare body hash early
if r.Method == http.MethodPost {
if !hasClientCacheOptIn(r) {
goto skipPreCache
}
if !prepareBodyHashForCaching(r, p.getMaxBodyHashBytes(), p.logger) {
goto skipPreCache
}
}
key := CacheKeyFromRequest(r)
if cr, ok := p.cache.Get(key); ok {
// Only treat as a fast-path cache hit if it is actually eligible to serve
// without going upstream.
if isVaryCompatible(r, cr, key) && canServeCachedForRequest(r, cr.headers) && !wantsRevalidation(r) {
preCacheKey = key
preCacheRes = cr
preCacheOK = true
}
}
}
skipPreCache:
// --- Token extraction and validation (moved from director) ---
authHeader := r.Header.Get("Authorization")
tokenStr := extractTokenFromHeader(authHeader)
if tokenStr == "" {
p.handleValidationError(w, r, errors.New("missing or invalid authorization header"))
return
}
var (
projectID string
err error
)
if preCacheOK {
// Avoid per-request DB updates on true cache hits.
projectID, err = p.tokenValidator.ValidateToken(r.Context(), tokenStr)
} else {
projectID, err = p.tokenValidator.ValidateTokenWithTracking(r.Context(), tokenStr)
}
if err != nil {
p.handleValidationError(w, r, err)
return
}
ctx = context.WithValue(r.Context(), ctxKeyProjectID, projectID)
ctx = context.WithValue(ctx, ctxKeyTokenID, tokenStr)
r = r.WithContext(ctx)
// Defer upstream API key lookup until we actually need to proxy upstream.
// This keeps cache-hit latency low under concurrency.
var (
upstreamAPIKey string
upstreamAPIKeyErr error
upstreamAPIKeyOnce sync.Once
)
ensureUpstreamAuthorization := func(reqToAuthorize *http.Request) bool {
upstreamAPIKeyOnce.Do(func() {
upstreamAPIKey, upstreamAPIKeyErr = p.projectStore.GetAPIKeyForProject(reqToAuthorize.Context(), projectID)
if upstreamAPIKeyErr != nil {
upstreamAPIKeyErr = fmt.Errorf("failed to get API key: %w", upstreamAPIKeyErr)
return
}
})
if upstreamAPIKeyErr != nil {
requestID, _ := reqToAuthorize.Context().Value(ctxKeyRequestID).(string)
p.logger.Error(
"Upstream API key lookup failed",
zap.String("request_id", requestID),
zap.String("project_id", projectID),
zap.Error(upstreamAPIKeyErr),
)
writeErrorResponse(w, http.StatusServiceUnavailable, ErrorResponse{
Error: "Upstream authentication error",
Code: "upstream_auth_error",
Description: "failed to load upstream API key",
})
return false
}
reqToAuthorize.Header.Set("Authorization", fmt.Sprintf("Bearer %s", upstreamAPIKey))
return true
}
// Enforce project active status using shared helper (if enabled)
if allowed, status, er := shouldAllowProject(r.Context(), p.config.EnforceProjectActive, p.projectStore, projectID, p.auditLogger, r); !allowed {
writeErrorResponse(w, status, er)
return
}
// Wrap the ResponseWriter to allow us to set headers at first/last byte
rw := &timingResponseWriter{ResponseWriter: w}
// Simple cache lookup with conditional handling (ETag/Last-Modified)
if p.cache != nil && (r.Method == http.MethodGet || r.Method == http.MethodHead || r.Method == http.MethodPost) {
// Allow GET/HEAD lookups by default when cache is enabled, since reuse will still be gated by canServeCachedForRequest.
// Require explicit client opt-in for POST lookups.
optIn := hasClientCacheOptIn(r)
allowedLookup := (r.Method == http.MethodGet || r.Method == http.MethodHead) || (r.Method == http.MethodPost && optIn)
if r.Method == http.MethodPost && allowedLookup {
// Body already read and hashed in pre-cache check if we got here.
// If X-Body-Hash is not set, it means pre-cache was skipped, so read it now.
if r.Header.Get("X-Body-Hash") == "" {
if !prepareBodyHashForCaching(r, p.getMaxBodyHashBytes(), p.logger) {
allowedLookup = false
}
}
}
if !allowedLookup {
// Cache is enabled but this request type/method is not cacheable - count as miss
p.recordCacheMiss()
if !ensureUpstreamAuthorization(r) {
return
}
p.proxy.ServeHTTP(rw, r)
return
}
key := CacheKeyFromRequest(r)
var (
cr cachedResponse
ok bool
)
if preCacheOK {
key = preCacheKey
cr = preCacheRes
ok = true
} else {
cr, ok = p.cache.Get(key)
}
if ok {
// Validate Vary compatibility using helper
if !isVaryCompatible(r, cr, key) {
p.recordCacheMiss()
if !ensureUpstreamAuthorization(r) {
return
}
// Note: don't set miss status here; let modifyResponse handle cache status
p.proxy.ServeHTTP(rw, r)
return
}
if !canServeCachedForRequest(r, cr.headers) {
// Authorization present but cached response not explicitly shared-cacheable
w.Header().Set("Cache-Status", "llm-proxy; bypass")
w.Header().Set("X-PROXY-CACHE", "bypass")
w.Header().Set("X-PROXY-CACHE-KEY", key)
p.incrementCacheMetric(CacheMetricBypass)
if !ensureUpstreamAuthorization(r) {
return
}
p.proxy.ServeHTTP(rw, r)
return
}
// Origin revalidation path: if client requests revalidation (no-cache/max-age=0),
// send conditional request upstream using cached validators (ETag/Last-Modified).
if wantsRevalidation(r) {
condReq := r.Clone(r.Context())
if etag := cr.headers.Get("ETag"); etag != "" {
condReq.Header.Set("If-None-Match", etag)
}
if lm := cr.headers.Get("Last-Modified"); lm != "" {
condReq.Header.Set("If-Modified-Since", lm)
}
if !ensureUpstreamAuthorization(condReq) {
return
}
// Forward conditionally to upstream; let modifyResponse handle store/refresh
// Don't increment miss here since this is a conditional revalidation
p.proxy.ServeHTTP(rw, condReq)
return
}
// If the client provided conditionals, respond 304 when validators match
if r.Method == http.MethodGet || r.Method == http.MethodHead {
if conditionalRequestMatches(r, cr.headers) {
for hk, hv := range cr.headers {
for _, v := range hv {
w.Header().Add(hk, v)
}
}
// Set fresh timing headers for conditional cache hit
setFreshCacheTimingHeaders(w, time.Now())
w.Header().Set("Cache-Status", "llm-proxy; conditional-hit")
w.Header().Set("X-PROXY-CACHE", "conditional-hit")
w.Header().Set("X-PROXY-CACHE-KEY", key)
p.recordCacheHit(r) // Conditional hit counts as cache hit
w.WriteHeader(http.StatusNotModified)
return
}
}
for hk, hv := range cr.headers {
for _, v := range hv {
w.Header().Add(hk, v)
}
}
// Set fresh timing headers for cache hit
setFreshCacheTimingHeaders(w, time.Now())
w.Header().Set("Cache-Status", "llm-proxy; hit")
w.Header().Set("X-PROXY-CACHE", "hit")
w.Header().Set("X-PROXY-CACHE-KEY", key)
p.recordCacheHit(r)
w.WriteHeader(cr.statusCode)
if r.Method != http.MethodHead {
_, _ = w.Write(cr.body)
}
return
}
// Cache miss - no entry found
p.recordCacheMiss()
// Note: don't set miss status here; let modifyResponse handle cache status
// w.Header().Set("Cache-Status", "llm-proxy; miss")
// Do not set X-PROXY-CACHE(-KEY) on miss; only set definitive headers on hit/bypass/conditional-hit or store path
} else if p.cache != nil {
// Cache is enabled but method is not cacheable (e.g., DELETE, OPTIONS, etc.) - count as miss
p.recordCacheMiss()
}
if !ensureUpstreamAuthorization(r) {
return
}
p.proxy.ServeHTTP(rw, r)
})
var handler http.Handler = baseHandler
handler = p.ValidateRequestMiddleware()(handler)
if p.obsMiddleware != nil {
handler = p.obsMiddleware.Middleware()(handler)
}
handler = CircuitBreakerMiddleware(5, 30*time.Second, func(status int) bool {
return status == http.StatusBadGateway || status == http.StatusServiceUnavailable || status == http.StatusGatewayTimeout
})(handler)
return handler
}
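// timingResponseWriter wraps http.ResponseWriter to record when the first and
// last response bytes are written; the first-byte time is also exposed via
// the X-Proxy-First-Response-At header.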
type timingResponseWriter struct {
http.ResponseWriter
firstByteOnce sync.Once
firstByteAt time.Time
finalByteAt time.Time
}
func (w *timingResponseWriter) Write(b []byte) (int, error) {
now := time.Now().UTC()
w.firstByteOnce.Do(func() {
w.firstByteAt = now
w.Header().Set("X-Proxy-First-Response-At", w.firstByteAt.Format(time.RFC3339Nano))
})
w.finalByteAt = now
return w.ResponseWriter.Write(b)
}
func (w *timingResponseWriter) Flush() {
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}
// recordCacheMiss centralizes cache miss accounting to reduce duplication and
// ensure consistent metric semantics across all miss paths.
func (p *TransparentProxy) recordCacheMiss() {
p.incrementCacheMetric(CacheMetricMiss)
}
// recordCacheHit records a cache hit for metrics and per-token tracking.
func (p *TransparentProxy) recordCacheHit(r *http.Request) {
p.incrementCacheMetric(CacheMetricHit)
// Record per-token cache hit if aggregator is configured
if p.cacheStatsAggregator != nil {
if tokenID, ok := r.Context().Value(ctxKeyTokenID).(string); ok && tokenID != "" {
p.cacheStatsAggregator.RecordCacheHit(tokenID)
}
}
}
func setFreshCacheTimingHeaders(w http.ResponseWriter, now time.Time) {
formatted := now.UTC().Format(time.RFC3339Nano)
w.Header().Set("X-Proxy-Received-At", formatted)
// For cache hits/conditional-hits, the full response is served immediately from cache.
w.Header().Set("X-Proxy-First-Response-At", formatted)
w.Header().Set("X-Proxy-Final-Response-At", formatted)
w.Header().Set("Date", now.UTC().Format(http.TimeFormat))
}
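// setTimingHeaders copies any proxy timing timestamps stored in the request
// context onto the response headers so clients can observe per-phase latency.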
func setTimingHeaders(res *http.Response, ctx context.Context) {
if v := ctx.Value(ctxKeyProxyReceivedAt); v != nil {
if t, ok := v.(time.Time); ok {
res.Header.Set("X-Proxy-Received-At", t.Format(time.RFC3339Nano))
}
}
if v := ctx.Value(ctxKeyProxySentBackendAt); v != nil {
if t, ok := v.(time.Time); ok {
res.Header.Set("X-Proxy-Sent-Backend-At", t.Format(time.RFC3339Nano))
}
}
if v := ctx.Value(ctxKeyProxyFirstRespAt); v != nil {
if t, ok := v.(time.Time); ok {
res.Header.Set("X-Proxy-First-Response-At", t.Format(time.RFC3339Nano))
}
}
if v := ctx.Value(ctxKeyProxyFinalRespAt); v != nil {
if t, ok := v.(time.Time); ok {
res.Header.Set("X-Proxy-Final-Response-At", t.Format(time.RFC3339Nano))
}
}
}
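// logProxyTimings emits debug-level durations for each proxy phase
// (pre-backend overhead, backend first-byte latency, streaming time, and
// total duration) when the corresponding timestamps are present.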
func logProxyTimings(logger *zap.Logger, ctx context.Context) {
received, _ := ctx.Value(ctxKeyProxyReceivedAt).(time.Time)
sent, _ := ctx.Value(ctxKeyProxySentBackendAt).(time.Time)
first, _ := ctx.Value(ctxKeyProxyFirstRespAt).(time.Time)
final, _ := ctx.Value(ctxKeyProxyFinalRespAt).(time.Time)
requestID, _ := ctx.Value(ctxKeyRequestID).(string)
if !received.IsZero() && !sent.IsZero() {
logger.Debug("Proxy overhead (pre-backend)", zap.Duration("duration", sent.Sub(received)), zap.String("request_id", requestID))
}
if !sent.IsZero() && !first.IsZero() {
logger.Debug("Backend latency (first byte)", zap.Duration("duration", first.Sub(sent)), zap.String("request_id", requestID))
}
if !first.IsZero() && !final.IsZero() {
logger.Debug("Streaming duration", zap.Duration("duration", final.Sub(first)), zap.String("request_id", requestID))
}
if !received.IsZero() && !final.IsZero() {
logger.Debug("Total proxy duration", zap.Duration("duration", final.Sub(received)), zap.String("request_id", requestID))
}
}
// LoggingMiddleware logs request details
func (p *TransparentProxy) LoggingMiddleware() Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
requestID, _ := r.Context().Value(ctxKeyRequestID).(string)
p.logger.Info("Request started",
zap.String("request_id", requestID),
zap.String("method", r.Method),
zap.String("path", r.URL.Path),
zap.String("remote_addr", r.RemoteAddr))
// Create a response recorder to capture response details
rec := &responseRecorder{
ResponseWriter: w,
statusCode: http.StatusOK,
}
// Process request
next.ServeHTTP(rec, r)
// Log request completion
duration := time.Since(start)
p.logger.Info("Request completed",
zap.String("request_id", requestID),
zap.String("method", r.Method),
zap.String("path", r.URL.Path),
zap.Int("status", rec.statusCode),
zap.Duration("duration", duration))
})
}
}
// ValidateRequestMiddleware validates the incoming request against allowed endpoints and methods
func (p *TransparentProxy) ValidateRequestMiddleware() Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Ensure request_id is set in context as the very first step
requestID, ok := r.Context().Value(ctxKeyRequestID).(string)
if !ok || requestID == "" {
requestID = uuid.New().String()
r = r.WithContext(context.WithValue(r.Context(), ctxKeyRequestID, requestID))
}
// --- Validation Scope: Only token, path, and method are validated here ---
// Do not add API-specific validation or transformation logic here.
// Check if method is allowed
if !p.isMethodAllowed(r.Method) {
p.logger.Warn("Method not allowed",
zap.String("method", r.Method),
zap.String("path", r.URL.Path))
if requestID != "" {
w.Header().Set("X-Request-ID", requestID)
}
w.Header().Set("Content-Type", "application/json")
// Headers must be set before WriteHeader; mutations afterwards are ignored.
w.WriteHeader(http.StatusMethodNotAllowed)
if err := json.NewEncoder(w).Encode(ErrorResponse{
Error: "Method not allowed",
Code: "method_not_allowed",
}); err != nil {
p.logger.Error("Failed to encode error response", zap.Error(err))
}
return
}
// Check if endpoint is allowed
if !p.isEndpointAllowed(r.URL.Path) {
p.logger.Warn("Endpoint not allowed",
zap.String("method", r.Method),
zap.String("path", r.URL.Path))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusNotFound)
if err := json.NewEncoder(w).Encode(ErrorResponse{
Error: "Endpoint not found",
Code: "endpoint_not_found",
}); err != nil {
p.logger.Error("Failed to encode error response", zap.Error(err))
}
return
}
// --- End of validation scope ---
// --- Begin param whitelist validation ---
// Match the media type prefix so parameterized values such as
// "application/json; charset=utf-8" are also validated.
if r.Method == http.MethodPost && len(p.config.ParamWhitelist) > 0 && strings.HasPrefix(r.Header.Get("Content-Type"), "application/json") {
// Read and buffer the body for validation and later proxying
var bodyBytes []byte
if r.Body != nil {
bodyBytes, _ = io.ReadAll(r.Body)
}
if len(bodyBytes) > 0 {
var bodyMap map[string]interface{}
if err := json.Unmarshal(bodyBytes, &bodyMap); err == nil {
for param, allowed := range p.config.ParamWhitelist {
if val, ok := bodyMap[param]; ok {
valStr := ""
switch v := val.(type) {
case string:
valStr = v
case float64:
valStr = fmt.Sprintf("%v", v)
default:
valStr = fmt.Sprintf("%v", v)
}
found := false
// Support glob expressions in allowed values
for _, allowedVal := range allowed {
if ok, _ := path.Match(allowedVal, valStr); ok {
found = true
break
}
}
if !found {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
if err := json.NewEncoder(w).Encode(ErrorResponse{
Error: fmt.Sprintf("Parameter '%s' value '%s' is not allowed. Allowed patterns: %v", param, valStr, allowed),
Code: "param_not_allowed",
}); err != nil {
p.logger.Error("Failed to encode error response", zap.Error(err))
}
return
}
}
}
}
// Restore the body for downstream handlers
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}
}
// --- End param whitelist validation ---
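// Illustrative example (hypothetical config): with ParamWhitelist
// {"model": {"gpt-4*"}}, a JSON body {"model":"gpt-4o"} passes the glob
// match, while {"model":"o1"} is rejected with code param_not_allowed.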
// --- Begin CORS origin validation ---
origin := r.Header.Get("Origin")
originRequired := false
for _, h := range p.config.RequiredHeaders {
if strings.EqualFold(h, "origin") {
originRequired = true
break
}
}
if originRequired {
if origin == "" {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
if err := json.NewEncoder(w).Encode(ErrorResponse{
Error: "Origin header required",
Code: "origin_required",
}); err != nil {
p.logger.Error("Failed to encode error response", zap.Error(err))
}
return
}
if len(p.config.AllowedOrigins) > 0 {
allowed := false
for _, o := range p.config.AllowedOrigins {
if o == origin {
allowed = true
break
}
}
if !allowed {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
if err := json.NewEncoder(w).Encode(ErrorResponse{
Error: "Origin not allowed",
Code: "origin_not_allowed",
}); err != nil {
p.logger.Error("Failed to encode error response", zap.Error(err))
}
return
}
}
} else if origin != "" && len(p.config.AllowedOrigins) > 0 {
allowed := false
for _, o := range p.config.AllowedOrigins {
if o == origin {
allowed = true
break
}
}
if !allowed {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
if err := json.NewEncoder(w).Encode(ErrorResponse{
Error: "Origin not allowed",
Code: "origin_not_allowed",
}); err != nil {
p.logger.Error("Failed to encode error response", zap.Error(err))
}
return
}
}
// --- End CORS origin validation ---
// Continue to next middleware
next.ServeHTTP(w, r)
})
}
}
// TimeoutMiddleware adds a timeout to requests
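// Illustrative composition (matches the chaining style used in Handler):
//
//	handler = p.TimeoutMiddleware(30 * time.Second)(handler)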
func (p *TransparentProxy) TimeoutMiddleware(timeout time.Duration) Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), timeout)
defer cancel()
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// MetricsMiddleware collects metrics about requests
func (p *TransparentProxy) MetricsMiddleware() Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// Store start time in context
ctx := context.WithValue(r.Context(), ctxKeyRequestStart, start)
// Create response recorder to capture status code
rec := &responseRecorder{
ResponseWriter: w,
statusCode: http.StatusOK,
}
// Process request
next.ServeHTTP(rec, r.WithContext(ctx))
// Record metrics
duration := time.Since(start)
p.metrics.mu.Lock()
p.metrics.RequestCount++
// Increment error count for status codes >= 400
if rec.statusCode >= 400 {
p.metrics.ErrorCount++
}
p.metrics.TotalResponseTime += duration
p.metrics.mu.Unlock()
})
}
}
// Shutdown gracefully shuts down the proxy
func (p *TransparentProxy) Shutdown(ctx context.Context) error {
p.mu.Lock()
p.shuttingDown = true
p.mu.Unlock()
p.logger.Info("Shutting down proxy")
// If we have an HTTP server, shut it down
if p.httpServer != nil {
return p.httpServer.Shutdown(ctx)
}
return nil
}
// isMethodAllowed checks if a method is in the allowed list
func (p *TransparentProxy) isMethodAllowed(method string) bool {
// If no allowed methods are specified, allow all methods
if len(p.config.AllowedMethods) == 0 {
return true
}
for _, allowed := range p.config.AllowedMethods {
if strings.EqualFold(method, allowed) {
return true
}
}
return false
}
// isEndpointAllowed checks if an endpoint is in the allowed list.
// The parameter is named reqPath to avoid shadowing the imported path package.
func (p *TransparentProxy) isEndpointAllowed(reqPath string) bool {
// If no allowed endpoints are specified, allow all endpoints
if len(p.config.AllowedEndpoints) == 0 {
return true
}
// Check if the path matches any allowed endpoint prefix
for _, endpoint := range p.config.AllowedEndpoints {
if strings.HasPrefix(reqPath, endpoint) {
return true
}
}
return false
}
// responseRecorder is a wrapper for http.ResponseWriter that records the status code
type responseRecorder struct {
http.ResponseWriter
statusCode int
}
// WriteHeader records the status code and calls the wrapped ResponseWriter's WriteHeader method
func (r *responseRecorder) WriteHeader(statusCode int) {
r.statusCode = statusCode
r.ResponseWriter.WriteHeader(statusCode)
}
// Add Flush forwarding for streaming support
func (r *responseRecorder) Flush() {
if f, ok := r.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}
// extractTokenFromHeader extracts a token from an authorization header
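// For example, "Bearer abc123" yields "abc123", while "Basic abc123" and a
// bare "abc123" both yield "".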
func extractTokenFromHeader(authHeader string) string {
if authHeader == "" {
return ""
}
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
return ""
}
return strings.TrimSpace(parts[1])
}
// isStreaming checks if a response is a streaming response
func isStreaming(res *http.Response) bool {
// Check Content-Type for SSE
if strings.Contains(res.Header.Get("Content-Type"), "text/event-stream") {
return true
}
// Check for chunked transfer encoding
return strings.Contains(
strings.ToLower(res.Header.Get("Transfer-Encoding")),
"chunked",
)
}
// isStreamingRequest checks if a request is intended for streaming
func isStreamingRequest(req *http.Request) bool {
// Check for SSE Accept header
if strings.Contains(req.Header.Get("Accept"), "text/event-stream") {
return true
}
// Check query parameters for stream=true (common in OpenAI APIs)
if req.URL.Query().Get("stream") == "true" {
return true
}
// Check the request body for streaming flag
// This is a heuristic and may need refinement for specific APIs
// For OpenAI, the common pattern is POST with JSON containing "stream": true
// But checking this would require reading the body, which we want to avoid
// We'll just rely on the Accept header and query params for now
return false
}
package proxy
import (
"bytes"
"io"
"sync/atomic"
)
// streamingCaptureReadCloser wraps an io.ReadCloser and captures the bytes
// read into an internal buffer. Once EOF or Close is reached, it invokes the
// provided finalize callback with the captured bytes, unless the stream
// exceeded maxBytes, in which case the incomplete capture is discarded.
type streamingCaptureReadCloser struct {
rc io.ReadCloser
buf *bytes.Buffer
maxBytes int64
written int64
truncated bool
finalized int32
onDone func([]byte)
}
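// newStreamingCapture wraps rc so response bytes are captured as they stream
// through. A minimal usage sketch (mirroring the store path in modifyResponse
// above; 2<<20 matches the 2MB default cap used there):
//
//	res.Body = newStreamingCapture(res.Body, 2<<20, func(buf []byte) {
//		// buf holds the fully captured body; truncated streams never
//		// reach this callback. Persist buf to the cache here.
//	})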
func newStreamingCapture(rc io.ReadCloser, maxBytes int64, onDone func([]byte)) *streamingCaptureReadCloser {
return &streamingCaptureReadCloser{
rc: rc,
buf: &bytes.Buffer{},
maxBytes: maxBytes,
onDone: onDone,
}
}
func (s *streamingCaptureReadCloser) Read(p []byte) (int, error) {
n, err := s.rc.Read(p)
if n > 0 {
if s.maxBytes == 0 || s.written < s.maxBytes {
limit := n
if s.maxBytes > 0 {
if remaining := s.maxBytes - s.written; int64(limit) > remaining {
limit = int(remaining)
// The body exceeds maxBytes: the capture is incomplete and must not be cached.
s.truncated = true
}
}
_, _ = s.buf.Write(p[:limit])
s.written += int64(limit)
} else {
// Buffer already full; any further bytes mean the capture is truncated.
s.truncated = true
}
}
if err == io.EOF {
s.finalize()
}
return n, err
}
func (s *streamingCaptureReadCloser) Close() error {
// Ensure finalize happens once even if Close is called before EOF
s.finalize()
return s.rc.Close()
}
func (s *streamingCaptureReadCloser) finalize() {
if atomic.CompareAndSwapInt32(&s.finalized, 0, 1) {
// Truncated captures are dropped: caching a partial body would serve
// corrupt responses to later cache hits.
if s.onDone != nil && !s.truncated {
s.onDone(s.buf.Bytes())
}
}
}
// Package server implements the HTTP server for the LLM Proxy.
// It handles request routing, lifecycle management, and provides
// health check endpoints and core API functionality.
package server
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"runtime"
"strings"
"time"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
"github.com/sofatutor/llm-proxy/internal/audit"
"github.com/sofatutor/llm-proxy/internal/config"
"github.com/sofatutor/llm-proxy/internal/database"
"github.com/sofatutor/llm-proxy/internal/encryption"
"github.com/sofatutor/llm-proxy/internal/eventbus"
"github.com/sofatutor/llm-proxy/internal/logging"
"github.com/sofatutor/llm-proxy/internal/middleware"
"github.com/sofatutor/llm-proxy/internal/obfuscate"
"github.com/sofatutor/llm-proxy/internal/proxy"
"github.com/sofatutor/llm-proxy/internal/token"
"go.uber.org/zap"
)
// Server represents the HTTP server for the LLM Proxy.
// It encapsulates the underlying http.Server along with application configuration
// and handles request routing and server lifecycle management.
type Server struct {
server *http.Server
config *config.Config
tokenStore token.TokenStore
projectStore proxy.ProjectStore
logger *zap.Logger
proxy *proxy.TransparentProxy
metrics Metrics
eventBus eventbus.EventBus
auditLogger *audit.Logger
db *database.DB
cacheStatsAgg *proxy.CacheStatsAggregator
usageStatsAgg *token.UsageStatsAggregator
tokenHasher encryption.TokenHasherInterface // Optional hasher for encryption support
}
// ServerOption is a functional option for configuring the server.
type ServerOption func(*Server)
// WithTokenHasher sets the token hasher for usage stats encryption.
// When encryption is enabled, this hasher is used to hash tokens before batch updates.
func WithTokenHasher(hasher encryption.TokenHasherInterface) ServerOption {
return func(s *Server) {
s.tokenHasher = hasher
}
}
// HealthResponse is the response body for the health check endpoint.
// It provides basic information about the server status and version.
type HealthResponse struct {
Status string `json:"status"` // Service status, "ok" for a healthy system
Timestamp time.Time `json:"timestamp"` // Current server time
Version string `json:"version"` // Application version number
}
// Metrics holds runtime metrics for the server.
type Metrics struct {
StartTime time.Time
RequestCount int64
ErrorCount int64
}
// Version is the application version, following semantic versioning.
const Version = "0.1.0"
// maxDurationMinutes is the maximum allowed duration for a token (365 days)
const maxDurationMinutes = 525600
// New creates a new HTTP server with the provided configuration and store implementations.
// It initializes the server with appropriate timeouts and registers all necessary route handlers.
// The server is not started until the Start method is called.
func New(cfg *config.Config, tokenStore token.TokenStore, projectStore proxy.ProjectStore) (*Server, error) {
return NewWithDatabase(cfg, tokenStore, projectStore, nil)
}
// NewWithDatabase creates a new HTTP server with database support for audit logging.
// This allows the server to store audit events in both file and database backends.
// Options can be passed to configure additional features like encryption/hashing.
func NewWithDatabase(cfg *config.Config, tokenStore token.TokenStore, projectStore proxy.ProjectStore, db *database.DB, opts ...ServerOption) (*Server, error) {
mux := http.NewServeMux()
logger, err := logging.NewLogger(cfg.LogLevel, cfg.LogFormat, cfg.LogFile)
if err != nil {
return nil, fmt.Errorf("failed to initialize logger: %w", err)
}
// Initialize audit logger with optional database backend
var auditLogger *audit.Logger
if cfg.AuditEnabled && cfg.AuditLogFile != "" {
auditConfig := audit.LoggerConfig{
FilePath: cfg.AuditLogFile,
CreateDir: cfg.AuditCreateDir,
DatabaseStore: db, // Database store for audit events
EnableDatabase: cfg.AuditStoreInDB && db != nil,
}
auditLogger, err = audit.NewLogger(auditConfig)
if err != nil {
return nil, fmt.Errorf("failed to initialize audit logger: %w", err)
}
if cfg.AuditStoreInDB && db != nil {
logger.Info("Audit logging enabled with database storage", zap.String("log_file", cfg.AuditLogFile))
} else {
logger.Info("Audit logging enabled", zap.String("log_file", cfg.AuditLogFile))
}
} else {
auditLogger = audit.NewNullLogger()
logger.Info("Audit logging disabled")
}
metrics := Metrics{StartTime: time.Now()}
var bus eventbus.EventBus
switch cfg.EventBusBackend {
case "redis-streams":
client := redis.NewClient(&redis.Options{
Addr: cfg.RedisAddr,
DB: cfg.RedisDB,
})
// Redis diagnostics
logger.Info("Connecting to Redis for Streams", zap.String("addr", cfg.RedisAddr), zap.Int("db", cfg.RedisDB))
pong, err := client.Ping(context.Background()).Result()
if err != nil {
logger.Fatal("Failed to ping Redis",
zap.String("addr", cfg.RedisAddr),
zap.Int("db", cfg.RedisDB),
zap.Error(err))
}
logger.Info("Successfully pinged Redis",
zap.String("addr", cfg.RedisAddr),
zap.Int("db", cfg.RedisDB),
zap.String("response", pong))
// Generate consumer name if not provided
consumerName := cfg.RedisConsumerName
if consumerName == "" {
consumerName = fmt.Sprintf("proxy-%s", uuid.New().String()[:8])
}
streamsConfig := eventbus.RedisStreamsConfig{
StreamKey: cfg.RedisStreamKey,
ConsumerGroup: cfg.RedisConsumerGroup,
ConsumerName: consumerName,
MaxLen: cfg.RedisStreamMaxLen,
BlockTimeout: cfg.RedisStreamBlockTime,
ClaimMinIdleTime: cfg.RedisStreamClaimTime,
BatchSize: cfg.RedisStreamBatchSize,
}
adapter := &eventbus.RedisStreamsClientAdapter{Client: client}
bus = eventbus.NewRedisStreamsEventBus(adapter, streamsConfig)
logger.Info("Using Redis Streams event bus",
zap.String("addr", cfg.RedisAddr),
zap.Int("db", cfg.RedisDB),
zap.String("stream", cfg.RedisStreamKey),
zap.String("consumer_group", cfg.RedisConsumerGroup),
zap.String("consumer_name", consumerName))
case "in-memory":
logger.Info("Using in-memory event bus", zap.String("mode", "single-process"))
bus = eventbus.NewInMemoryEventBus(cfg.ObservabilityBufferSize)
default:
return nil, fmt.Errorf("unknown event bus backend: %s", cfg.EventBusBackend)
}
s := &Server{
config: cfg,
tokenStore: tokenStore,
projectStore: projectStore,
logger: logger,
metrics: metrics,
eventBus: bus,
auditLogger: auditLogger,
db: db,
server: &http.Server{
Addr: cfg.ListenAddr,
Handler: mux,
ReadTimeout: cfg.RequestTimeout,
WriteTimeout: cfg.RequestTimeout,
IdleTimeout: cfg.RequestTimeout * 2,
},
}
// Apply options
for _, opt := range opts {
opt(s)
}
// Register routes
mux.HandleFunc("/health", s.logRequestMiddleware(s.handleHealth))
mux.HandleFunc("/ready", s.logRequestMiddleware(s.handleReady))
mux.HandleFunc("/live", s.logRequestMiddleware(s.handleLive))
mux.HandleFunc("/manage/projects", s.logRequestMiddleware(s.managementAuthMiddleware(s.handleProjects)))
mux.HandleFunc("/manage/projects/", s.logRequestMiddleware(s.managementAuthMiddleware(s.handleProjectByID)))
mux.HandleFunc("/manage/tokens", s.logRequestMiddleware(s.managementAuthMiddleware(s.handleTokens)))
mux.HandleFunc("/manage/tokens/", s.logRequestMiddleware(s.managementAuthMiddleware(s.handleTokenByID)))
mux.HandleFunc("/manage/audit", s.logRequestMiddleware(s.managementAuthMiddleware(s.handleAuditEvents)))
mux.HandleFunc("/manage/audit/", s.logRequestMiddleware(s.managementAuthMiddleware(s.handleAuditEventByID)))
mux.HandleFunc("/manage/cache/purge", s.logRequestMiddleware(s.managementAuthMiddleware(s.handleCachePurge)))
// Add catch-all handler for unmatched routes to ensure logging
mux.HandleFunc("/", s.logRequestMiddleware(s.handleNotFound))
if cfg.EnableMetrics {
path := cfg.MetricsPath
if path == "" {
path = "/metrics"
}
mux.HandleFunc(path, s.logRequestMiddleware(s.handleMetrics))
mux.HandleFunc(path+"/prometheus", s.logRequestMiddleware(s.handleMetricsPrometheus))
}
return s, nil
}
// Start initializes all required components and starts the HTTP server.
// This method blocks until the server is shut down or an error occurs.
//
// It returns an error if the server fails to start or encounters an
// unrecoverable error during operation.
func (s *Server) Start() error {
// Initialize required components
if err := s.initializeComponents(); err != nil {
return fmt.Errorf("failed to initialize components: %w", err)
}
s.logger.Info("Server starting", zap.String("listen_addr", s.config.ListenAddr))
return s.server.ListenAndServe()
}
// initializeComponents sets up all the required components for the server
func (s *Server) initializeComponents() error {
// Initialize API routes from configuration
if err := s.initializeAPIRoutes(); err != nil {
return fmt.Errorf("failed to initialize API routes: %w", err)
}
// Pending: database, logging, admin, and metrics initialization.
// See server_test.go for test stubs covering these responsibilities.
return nil
}
// initializeAPIRoutes sets up the API proxy routes based on configuration
func (s *Server) initializeAPIRoutes() error {
// Load API providers configuration
apiConfig, err := proxy.LoadAPIConfigFromFile(s.config.APIConfigPath)
if err != nil {
// If the config file doesn't exist or has errors, fall back to a default OpenAI configuration
s.logger.Warn("Failed to load API config, using default OpenAI configuration",
zap.String("config_path", s.config.APIConfigPath),
zap.Error(err))
// Create a default API configuration
apiConfig = &proxy.APIConfig{
DefaultAPI: "openai",
APIs: map[string]*proxy.APIProviderConfig{
"openai": {
BaseURL: s.config.OpenAIAPIURL,
AllowedEndpoints: []string{
"/v1/chat/completions",
"/v1/completions",
"/v1/embeddings",
"/v1/models",
"/v1/edits",
"/v1/fine-tunes",
"/v1/files",
"/v1/images/generations",
"/v1/audio/transcriptions",
"/v1/moderations",
},
AllowedMethods: []string{"GET", "POST", "DELETE"},
Timeouts: proxy.TimeoutConfig{
Request: s.config.RequestTimeout,
ResponseHeader: 30 * time.Second,
IdleConnection: 90 * time.Second,
FlushInterval: 100 * time.Millisecond,
},
Connection: proxy.ConnectionConfig{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 20,
},
},
},
}
}
// Get proxy configuration for the default API provider
proxyConfig, err := apiConfig.GetProxyConfigForAPI(s.config.DefaultAPIProvider)
if err != nil {
// If specified provider doesn't exist, use the default one
s.logger.Warn("Specified API provider not found, using default", zap.Error(err))
proxyConfig, err = apiConfig.GetProxyConfigForAPI(apiConfig.DefaultAPI)
if err != nil {
return fmt.Errorf("failed to get proxy configuration: %w", err)
}
}
// Apply HTTP cache env overrides (simple toggle + backend selection)
if v := os.Getenv("HTTP_CACHE_ENABLED"); v != "" {
// Only "true", "1", or "yes" (case-insensitive) enables the cache; any other value disables it.
proxyConfig.HTTPCacheEnabled = strings.EqualFold(v, "true") || strings.EqualFold(v, "1") || strings.EqualFold(v, "yes")
} else {
// Default: enabled
proxyConfig.HTTPCacheEnabled = true
}
backend := strings.ToLower(os.Getenv("HTTP_CACHE_BACKEND"))
if backend == "redis" {
// Use REDIS_CACHE_URL if explicitly set, otherwise construct it from REDIS_ADDR.
// The local is named cacheURL to avoid shadowing the imported net/url package.
cacheURL := os.Getenv("REDIS_CACHE_URL")
if cacheURL == "" {
// Construct the URL from the unified REDIS_ADDR config (same as the event bus)
addr := os.Getenv("REDIS_ADDR")
if addr == "" {
addr = "localhost:6379"
}
db := os.Getenv("REDIS_DB")
if db == "" {
db = "0"
}
cacheURL = fmt.Sprintf("redis://%s/%s", addr, db)
}
proxyConfig.RedisCacheURL = cacheURL
if kp := os.Getenv("REDIS_CACHE_KEY_PREFIX"); kp != "" {
proxyConfig.RedisCacheKeyPrefix = kp
}
}
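// For example, with HTTP_CACHE_BACKEND=redis and REDIS_ADDR=redis:6379, the
// cache URL resolves to "redis://redis:6379/0" unless REDIS_CACHE_URL
// overrides it.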
// Use the injected tokenStore and projectStore
// (No more creation of mock stores or test data here)
tokenValidator := token.NewValidator(s.tokenStore)
// Async usage tracking for unlimited tokens: keep DB writes off the hot path.
if s.db != nil {
usageCfg := token.DefaultUsageStatsAggregatorConfig()
if s.config.UsageStatsBufferSize > 0 {
usageCfg.BufferSize = s.config.UsageStatsBufferSize
}
// Use the raw DB store, but wrap with secure store if encryption is enabled
var usageStore token.UsageStatsStore = s.db
if s.tokenHasher != nil {
usageStore = encryption.NewSecureUsageStatsStore(s.db, s.tokenHasher)
s.logger.Debug("Using secure usage stats store with token hashing")
}
s.usageStatsAgg = token.NewUsageStatsAggregator(usageCfg, usageStore, s.logger)
s.usageStatsAgg.Start()
tokenValidator.SetUsageStatsAggregator(s.usageStatsAgg)
s.logger.Info("Usage stats aggregator started", zap.Int("buffer_size", usageCfg.BufferSize))
}
cachedValidator := token.NewCachedValidator(tokenValidator)
obsCfg := middleware.ObservabilityConfig{
Enabled: s.config.ObservabilityEnabled,
EventBus: s.eventBus,
MaxRequestBodyBytes: s.config.ObservabilityMaxRequestBodyBytes,
MaxResponseBodyBytes: s.config.ObservabilityMaxResponseBodyBytes,
}
projectStore := s.projectStore
if s.config.APIKeyCacheTTL > 0 && s.config.APIKeyCacheMax > 0 {
projectStore = proxy.NewCachedProjectStore(projectStore, proxy.CachedProjectStoreConfig{
TTL: s.config.APIKeyCacheTTL,
Max: s.config.APIKeyCacheMax,
})
s.logger.Info("Upstream API key cache enabled",
zap.Duration("ttl", s.config.APIKeyCacheTTL),
zap.Int("max", s.config.APIKeyCacheMax),
)
}
if proxyConfig.EnforceProjectActive && s.config.ActiveCacheTTL > 0 && s.config.ActiveCacheMax > 0 {
projectStore = proxy.NewCachedProjectActiveStore(projectStore, proxy.CachedProjectActiveStoreConfig{
TTL: s.config.ActiveCacheTTL,
Max: s.config.ActiveCacheMax,
})
s.logger.Info("Project active status cache enabled",
zap.Duration("ttl", s.config.ActiveCacheTTL),
zap.Int("max", s.config.ActiveCacheMax),
)
}
proxyHandler, err := proxy.NewTransparentProxyWithAudit(*proxyConfig, cachedValidator, projectStore, s.logger, s.auditLogger, obsCfg)
if err != nil {
return fmt.Errorf("failed to initialize proxy: %w", err)
}
s.proxy = proxyHandler
// Initialize cache stats aggregator for per-token cache hit tracking.
// NOTE: Cache stats tracking is only enabled when HTTP caching is enabled (HTTPCacheEnabled=true).
// When caching is disabled, no cache hits occur, so tracking is not needed.
// The Admin UI will show CacheHitCount=0 for all tokens when caching is disabled.
if s.db != nil && proxyConfig.HTTPCacheEnabled {
aggConfig := proxy.CacheStatsAggregatorConfig{
BufferSize: s.config.CacheStatsBufferSize,
FlushInterval: 5 * time.Second,
BatchSize: 100,
}
// Use the raw DB store, but wrap with secure store if encryption is enabled
var cacheStatsStore proxy.CacheStatsStore = s.db
if s.tokenHasher != nil {
cacheStatsStore = encryption.NewSecureCacheStatsStore(s.db, s.tokenHasher)
s.logger.Debug("Using secure cache stats store with token hashing")
}
s.cacheStatsAgg = proxy.NewCacheStatsAggregator(aggConfig, cacheStatsStore, s.logger)
s.cacheStatsAgg.Start()
proxyHandler.SetCacheStatsAggregator(s.cacheStatsAgg)
s.logger.Info("Cache stats aggregator started", zap.Int("buffer_size", aggConfig.BufferSize))
}
// Register proxy routes
s.server.Handler.(*http.ServeMux).Handle("/v1/", proxyHandler.Handler())
s.logger.Info("Initialized proxy",
zap.String("target_base_url", proxyConfig.TargetBaseURL),
zap.Int("allowed_endpoints", len(proxyConfig.AllowedEndpoints)))
return nil
}
// Shutdown gracefully shuts down the server without interrupting
// active connections. It waits for all connections to complete
// or for the provided context to be canceled, whichever comes first.
//
// The context should typically include a timeout to prevent
// the shutdown from blocking indefinitely.
func (s *Server) Shutdown(ctx context.Context) error {
// Stop usage stats aggregator first to flush pending usage updates
if s.usageStatsAgg != nil {
s.logger.Info("Stopping usage stats aggregator")
if err := s.usageStatsAgg.Stop(ctx); err != nil {
s.logger.Error("failed to stop usage stats aggregator during shutdown", zap.Error(err))
}
}
// Stop cache stats aggregator to flush pending stats
if s.cacheStatsAgg != nil {
s.logger.Info("Stopping cache stats aggregator")
if err := s.cacheStatsAgg.Stop(ctx); err != nil {
s.logger.Error("failed to stop cache stats aggregator during shutdown", zap.Error(err))
}
}
// Close audit logger to ensure all events are written
if s.auditLogger != nil {
if err := s.auditLogger.Close(); err != nil {
s.logger.Error("failed to close audit logger during shutdown", zap.Error(err))
}
}
return s.server.Shutdown(ctx)
}
// handleHealth is the HTTP handler for the health check endpoint.
// It responds with a JSON payload containing the server status,
// current timestamp, and application version.
//
// This endpoint can be used by load balancers, monitoring tools,
// and container orchestration systems to verify service health.
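//
// Example response body (timestamp illustrative):
//
//	{"status":"ok","timestamp":"2025-01-01T12:00:00Z","version":"0.1.0"}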
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
response := HealthResponse{
Status: "ok",
Timestamp: time.Now(),
Version: Version,
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
s.logger.Error("Error encoding health response", zap.Error(err))
return
}
// Status code 200 OK is set implicitly when the response is written successfully
}
// handleReady is used for readiness probes.
func (s *Server) handleReady(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ready"))
}
// handleLive is used for liveness probes.
func (s *Server) handleLive(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("alive"))
}
// handleMetrics returns basic runtime metrics in JSON format.
func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request) {
m := struct {
UptimeSeconds float64 `json:"uptime_seconds"`
RequestCount int64 `json:"request_count"`
ErrorCount int64 `json:"error_count"`
// Cache metrics (provider-agnostic counters)
CacheHits int64 `json:"cache_hits"`
CacheMisses int64 `json:"cache_misses"`
CacheBypass int64 `json:"cache_bypass"`
CacheStores int64 `json:"cache_stores"`
}{
UptimeSeconds: time.Since(s.metrics.StartTime).Seconds(),
}
if s.proxy != nil {
pm := s.proxy.Metrics()
m.RequestCount = pm.RequestCount
m.ErrorCount = pm.ErrorCount
m.CacheHits = pm.CacheHits
m.CacheMisses = pm.CacheMisses
m.CacheBypass = pm.CacheBypass
m.CacheStores = pm.CacheStores
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(m); err != nil {
s.logger.Error("Failed to encode metrics", zap.Error(err))
http.Error(w, "Failed to encode metrics", http.StatusInternalServerError)
}
}
// handleMetricsPrometheus returns metrics in Prometheus text exposition format.
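// Example output (abridged; the counter value is illustrative):
//
//	# HELP llm_proxy_requests_total Total number of proxy requests
//	# TYPE llm_proxy_requests_total counter
//	llm_proxy_requests_total 42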
func (s *Server) handleMetricsPrometheus(w http.ResponseWriter, r *http.Request) {
var buf strings.Builder
// Uptime gauge
uptimeSeconds := time.Since(s.metrics.StartTime).Seconds()
buf.WriteString("# HELP llm_proxy_uptime_seconds Time since the server started\n")
buf.WriteString("# TYPE llm_proxy_uptime_seconds gauge\n")
_, _ = fmt.Fprintf(&buf, "llm_proxy_uptime_seconds %g\n", uptimeSeconds)
// Get proxy metrics or use zero values
var requestCount, errorCount, cacheHits, cacheMisses, cacheBypass, cacheStores int64
if s.proxy != nil {
pm := s.proxy.Metrics()
requestCount = pm.RequestCount
errorCount = pm.ErrorCount
cacheHits = pm.CacheHits
cacheMisses = pm.CacheMisses
cacheBypass = pm.CacheBypass
cacheStores = pm.CacheStores
}
// Write metrics in Prometheus format
buf.WriteString("# HELP llm_proxy_requests_total Total number of proxy requests\n")
buf.WriteString("# TYPE llm_proxy_requests_total counter\n")
_, _ = fmt.Fprintf(&buf, "llm_proxy_requests_total %d\n", requestCount)
buf.WriteString("# HELP llm_proxy_errors_total Total number of proxy errors\n")
buf.WriteString("# TYPE llm_proxy_errors_total counter\n")
_, _ = fmt.Fprintf(&buf, "llm_proxy_errors_total %d\n", errorCount)
buf.WriteString("# HELP llm_proxy_cache_hits_total Total number of cache hits\n")
buf.WriteString("# TYPE llm_proxy_cache_hits_total counter\n")
_, _ = fmt.Fprintf(&buf, "llm_proxy_cache_hits_total %d\n", cacheHits)
buf.WriteString("# HELP llm_proxy_cache_misses_total Total number of cache misses\n")
buf.WriteString("# TYPE llm_proxy_cache_misses_total counter\n")
_, _ = fmt.Fprintf(&buf, "llm_proxy_cache_misses_total %d\n", cacheMisses)
buf.WriteString("# HELP llm_proxy_cache_bypass_total Total number of cache bypasses\n")
buf.WriteString("# TYPE llm_proxy_cache_bypass_total counter\n")
_, _ = fmt.Fprintf(&buf, "llm_proxy_cache_bypass_total %d\n", cacheBypass)
buf.WriteString("# HELP llm_proxy_cache_stores_total Total number of cache stores\n")
buf.WriteString("# TYPE llm_proxy_cache_stores_total counter\n")
_, _ = fmt.Fprintf(&buf, "llm_proxy_cache_stores_total %d\n", cacheStores)
// Go runtime metrics
s.writeGoRuntimeMetrics(&buf)
w.Header().Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8")
if _, err := w.Write([]byte(buf.String())); err != nil {
s.logger.Error("Failed to write Prometheus metrics", zap.Error(err))
}
}
// writeGoRuntimeMetrics writes Go runtime metrics to the buffer in Prometheus format.
func (s *Server) writeGoRuntimeMetrics(buf *strings.Builder) {
var memStats runtime.MemStats
runtime.ReadMemStats(&memStats)
// Goroutines
buf.WriteString("# HELP llm_proxy_goroutines Number of goroutines currently running\n")
buf.WriteString("# TYPE llm_proxy_goroutines gauge\n")
_, _ = fmt.Fprintf(buf, "llm_proxy_goroutines %d\n", runtime.NumGoroutine())
// Memory metrics
buf.WriteString("# HELP llm_proxy_memory_heap_alloc_bytes Number of heap bytes allocated and currently in use\n")
buf.WriteString("# TYPE llm_proxy_memory_heap_alloc_bytes gauge\n")
_, _ = fmt.Fprintf(buf, "llm_proxy_memory_heap_alloc_bytes %d\n", memStats.Alloc)
buf.WriteString("# HELP llm_proxy_memory_heap_sys_bytes Number of heap bytes obtained from the OS\n")
buf.WriteString("# TYPE llm_proxy_memory_heap_sys_bytes gauge\n")
_, _ = fmt.Fprintf(buf, "llm_proxy_memory_heap_sys_bytes %d\n", memStats.HeapSys)
buf.WriteString("# HELP llm_proxy_memory_heap_idle_bytes Number of heap bytes waiting to be used\n")
buf.WriteString("# TYPE llm_proxy_memory_heap_idle_bytes gauge\n")
_, _ = fmt.Fprintf(buf, "llm_proxy_memory_heap_idle_bytes %d\n", memStats.HeapIdle)
buf.WriteString("# HELP llm_proxy_memory_heap_inuse_bytes Number of heap bytes that are in use\n")
buf.WriteString("# TYPE llm_proxy_memory_heap_inuse_bytes gauge\n")
_, _ = fmt.Fprintf(buf, "llm_proxy_memory_heap_inuse_bytes %d\n", memStats.HeapInuse)
buf.WriteString("# HELP llm_proxy_memory_heap_released_bytes Number of heap bytes released to the OS\n")
buf.WriteString("# TYPE llm_proxy_memory_heap_released_bytes gauge\n")
_, _ = fmt.Fprintf(buf, "llm_proxy_memory_heap_released_bytes %d\n", memStats.HeapReleased)
buf.WriteString("# HELP llm_proxy_memory_total_alloc_bytes Total number of bytes allocated (cumulative)\n")
buf.WriteString("# TYPE llm_proxy_memory_total_alloc_bytes counter\n")
_, _ = fmt.Fprintf(buf, "llm_proxy_memory_total_alloc_bytes %d\n", memStats.TotalAlloc)
buf.WriteString("# HELP llm_proxy_memory_sys_bytes Total number of bytes obtained from the OS\n")
buf.WriteString("# TYPE llm_proxy_memory_sys_bytes gauge\n")
_, _ = fmt.Fprintf(buf, "llm_proxy_memory_sys_bytes %d\n", memStats.Sys)
buf.WriteString("# HELP llm_proxy_memory_mallocs_total Total number of malloc operations\n")
buf.WriteString("# TYPE llm_proxy_memory_mallocs_total counter\n")
_, _ = fmt.Fprintf(buf, "llm_proxy_memory_mallocs_total %d\n", memStats.Mallocs)
buf.WriteString("# HELP llm_proxy_memory_frees_total Total number of free operations\n")
buf.WriteString("# TYPE llm_proxy_memory_frees_total counter\n")
_, _ = fmt.Fprintf(buf, "llm_proxy_memory_frees_total %d\n", memStats.Frees)
// GC metrics
buf.WriteString("# HELP llm_proxy_gc_runs_total Total number of GC runs\n")
buf.WriteString("# TYPE llm_proxy_gc_runs_total counter\n")
_, _ = fmt.Fprintf(buf, "llm_proxy_gc_runs_total %d\n", memStats.NumGC)
buf.WriteString("# HELP llm_proxy_gc_pause_total_seconds Total GC pause time in seconds\n")
buf.WriteString("# TYPE llm_proxy_gc_pause_total_seconds counter\n")
_, _ = fmt.Fprintf(buf, "llm_proxy_gc_pause_total_seconds %g\n", float64(memStats.PauseTotalNs)/1e9)
buf.WriteString("# HELP llm_proxy_gc_next_bytes Target heap size for next GC cycle\n")
buf.WriteString("# TYPE llm_proxy_gc_next_bytes gauge\n")
_, _ = fmt.Fprintf(buf, "llm_proxy_gc_next_bytes %d\n", memStats.NextGC)
}
// managementAuthMiddleware checks the management token in the Authorization header
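// Illustrative client call (host and port are deployment-specific):
//
//	curl -H "Authorization: Bearer $MANAGEMENT_TOKEN" \
//	  http://localhost:8080/manage/projects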
func (s *Server) managementAuthMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
const prefix = "Bearer "
header := r.Header.Get("Authorization")
if !strings.HasPrefix(header, prefix) || len(header) <= len(prefix) {
http.Error(w, `{"error":"missing or invalid Authorization header"}`, http.StatusUnauthorized)
return
}
token := header[len(prefix):]
if token != s.config.ManagementToken {
http.Error(w, `{"error":"invalid management token"}`, http.StatusUnauthorized)
return
}
next(w, r)
}
}
// GET /manage/projects
//
// Only /manage/projects (no trailing slash) is registered for handleProjects.
// This ensures /manage/projects and /manage/projects/ are handled identically,
// while /manage/projects/{id} is handled by handleProjectByID, avoiding
// ambiguity and double handling in Go's http.ServeMux.
func (s *Server) handleProjects(w http.ResponseWriter, r *http.Request) {
s.logger.Debug("handleProjects: START", zap.String("method", r.Method), zap.String("path", r.URL.Path))
// Normalize path: treat /manage/projects/ as /manage/projects
if r.URL.Path == "/manage/projects/" {
r.URL.Path = "/manage/projects"
}
// DEBUG: Log method and headers
for k, v := range r.Header {
if strings.EqualFold(k, "Authorization") {
s.logger.Debug("handleProjects: header", zap.String("key", k), zap.String("value", "******"))
} else {
s.logger.Debug("handleProjects: header", zap.String("key", k), zap.Any("value", v))
}
}
// Mask management token in logs
maskedToken := "******"
if len(s.config.ManagementToken) > 4 {
maskedToken = s.config.ManagementToken[:4] + "******"
}
s.logger.Debug("handleProjects: config.ManagementToken", zap.String("ManagementToken", maskedToken))
if !s.checkManagementAuth(w, r) {
s.logger.Debug("handleProjects: END (auth failed)")
return
}
ctx := r.Context()
requestID := getRequestID(ctx)
switch r.Method {
case http.MethodGet:
s.logger.Info("listing projects", zap.String("request_id", requestID))
s.handleListProjects(w, r.WithContext(ctx))
case http.MethodPost:
s.logger.Info("creating project", zap.String("request_id", requestID))
s.handleCreateProject(w, r.WithContext(ctx))
default:
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
s.logger.Debug("handleProjects: END", zap.String("method", r.Method), zap.String("path", r.URL.Path))
}
func (s *Server) handleListProjects(w http.ResponseWriter, r *http.Request) {
s.logger.Debug("handleListProjects: START")
ctx := r.Context()
requestID := getRequestID(ctx)
projects, err := s.projectStore.ListProjects(ctx)
if err != nil {
s.logger.Error("failed to list projects", zap.Error(err))
// Audit: project list failure
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectList, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithError(err))
http.Error(w, `{"error":"failed to list projects"}`, http.StatusInternalServerError)
s.logger.Debug("handleListProjects: END (error)")
return
}
// Audit: project list success
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectList, audit.ActorManagement, audit.ResultSuccess, r, requestID).
WithDetail("project_count", len(projects)))
// Create response with obfuscated API keys
sanitizedProjects := make([]ProjectResponse, len(projects))
for i, p := range projects {
sanitizedProjects[i] = ProjectResponse{
ID: p.ID,
Name: p.Name,
APIKey: obfuscate.ObfuscateTokenGeneric(p.APIKey),
IsActive: p.IsActive,
DeactivatedAt: p.DeactivatedAt,
CreatedAt: p.CreatedAt,
UpdatedAt: p.UpdatedAt,
}
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(sanitizedProjects); err != nil {
s.logger.Error("failed to encode projects response", zap.Error(err))
s.logger.Debug("handleListProjects: END (encode error)")
} else {
s.logger.Debug("handleListProjects: END (success)")
}
}
// POST /manage/projects
func (s *Server) handleCreateProject(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
requestID := getRequestID(ctx)
var req struct {
Name string `json:"name"`
APIKey string `json:"api_key"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.logger.Error("invalid request body", zap.Error(err), zap.String("request_id", requestID))
// Audit: project creation failure - invalid request
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectCreate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithError(err).
WithDetail("validation_error", "invalid request body"))
http.Error(w, `{"error":"invalid request body"}`, http.StatusBadRequest)
return
}
if req.Name == "" || req.APIKey == "" {
s.logger.Error(
"missing required fields",
zap.String("name", req.Name),
zap.Bool("api_key_provided", req.APIKey != ""),
zap.String("request_id", requestID),
)
// Audit: project creation failure - missing fields
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectCreate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithDetail("validation_error", "missing required fields").
WithDetail("name_provided", req.Name != "").
WithDetail("api_key_provided", req.APIKey != ""))
http.Error(w, `{"error":"name and api_key are required"}`, http.StatusBadRequest)
return
}
// Reject obfuscated keys to prevent data corruption
if strings.Contains(req.APIKey, "...") || strings.Contains(req.APIKey, "****") {
s.logger.Error("attempted to create project with obfuscated API key", zap.String("request_id", requestID))
// Audit: project creation failure - obfuscated key
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectCreate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithDetail("validation_error", "cannot save obfuscated API key"))
http.Error(w, `{"error":"cannot save obfuscated API key - please provide the full API key"}`, http.StatusBadRequest)
return
}
id := uuid.NewString()
now := time.Now().UTC()
project := proxy.Project{
ID: id,
Name: req.Name,
APIKey: req.APIKey,
IsActive: true, // Projects are active by default
CreatedAt: now,
UpdatedAt: now,
}
if err := s.projectStore.CreateProject(ctx, project); err != nil {
s.logger.Error("failed to create project", zap.Error(err), zap.String("name", req.Name), zap.String("request_id", requestID))
// Audit: project creation failure - store error
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectCreate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(id).
WithError(err).
WithDetail("project_name", req.Name))
http.Error(w, `{"error":"failed to create project"}`, http.StatusInternalServerError)
return
}
s.logger.Info("project created", zap.String("id", id), zap.String("name", req.Name), zap.String("request_id", requestID))
// Audit: project creation success
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectCreate, audit.ActorManagement, audit.ResultSuccess, r, requestID).
WithProjectID(id).
WithDetail("project_name", req.Name))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
if err := json.NewEncoder(w).Encode(project); err != nil {
s.logger.Error("failed to encode project response", zap.Error(err))
}
}
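// Example exchange (illustrative values):
//
//	POST /manage/projects
//	{"name": "demo", "api_key": "sk-real-key"}
//
// responds 201 Created with the stored project, including the raw API key;
// subsequent list/get responses return the key obfuscated.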
// GET /manage/projects/{id}
func (s *Server) handleGetProject(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
id := strings.TrimPrefix(r.URL.Path, "/manage/projects/")
if id == "" || strings.Contains(id, "/") {
s.logger.Error("invalid project id", zap.String("id", id))
http.Error(w, `{"error":"invalid project id"}`, http.StatusBadRequest)
return
}
project, err := s.projectStore.GetProjectByID(ctx, id)
if err != nil {
s.logger.Error("project not found", zap.String("id", id), zap.Error(err))
http.Error(w, `{"error":"project not found"}`, http.StatusNotFound)
return
}
// Create response with obfuscated API key
response := ProjectResponse{
ID: project.ID,
Name: project.Name,
APIKey: obfuscate.ObfuscateTokenGeneric(project.APIKey),
IsActive: project.IsActive,
DeactivatedAt: project.DeactivatedAt,
CreatedAt: project.CreatedAt,
UpdatedAt: project.UpdatedAt,
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
s.logger.Error("failed to encode project response", zap.Error(err))
}
}
// PATCH /manage/projects/{id}
func (s *Server) handleUpdateProject(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
requestID := getRequestID(ctx)
id := strings.TrimPrefix(r.URL.Path, "/manage/projects/")
if id == "" || strings.Contains(id, "/") {
s.logger.Error("invalid project id for update", zap.String("id", id))
// Audit: project update failure - invalid ID
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectUpdate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithDetail("validation_error", "invalid project id").
WithDetail("provided_id", id))
http.Error(w, `{"error":"invalid project id"}`, http.StatusBadRequest)
return
}
var req struct {
Name *string `json:"name,omitempty"`
APIKey *string `json:"api_key,omitempty"`
IsActive *bool `json:"is_active,omitempty"`
RevokeTokens *bool `json:"revoke_tokens,omitempty"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.logger.Error("invalid request body for update", zap.Error(err))
// Audit: project update failure - invalid request body
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectUpdate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(id).
WithError(err).
WithDetail("validation_error", "invalid request body"))
http.Error(w, `{"error":"invalid request body"}`, http.StatusBadRequest)
return
}
// Validate revoke_tokens usage: it is only valid when explicitly deactivating the project
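// Example (illustrative): {"is_active": false, "revoke_tokens": true}
// deactivates the project and revokes all of its active tokens in one call.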
if req.RevokeTokens != nil {
if req.IsActive == nil || *req.IsActive {
// Audit: invalid field combination
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectUpdate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(id).
WithDetail("validation_error", "revoke_tokens requires is_active=false"))
http.Error(w, `{"error":"revoke_tokens requires is_active=false"}`, http.StatusBadRequest)
return
}
}
project, err := s.projectStore.GetProjectByID(ctx, id)
if err != nil {
s.logger.Error("project not found for update", zap.String("id", id), zap.Error(err))
// Audit: project update failure - not found
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectUpdate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(id).
WithError(err).
WithDetail("error_type", "project not found"))
http.Error(w, `{"error":"project not found"}`, http.StatusNotFound)
return
}
// Track what fields are being updated
var updatedFields []string
if req.Name != nil {
project.Name = *req.Name
updatedFields = append(updatedFields, "name")
}
if req.APIKey != nil && *req.APIKey != "" {
// Reject obfuscated keys to prevent data corruption
if strings.Contains(*req.APIKey, "...") || strings.Contains(*req.APIKey, "****") {
s.logger.Error("attempted to save obfuscated API key", zap.String("project_id", id))
// Audit: project update failure - obfuscated key
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectUpdate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(id).
WithDetail("validation_error", "cannot save obfuscated API key"))
http.Error(w, `{"error":"cannot save obfuscated API key - please provide the full API key"}`, http.StatusBadRequest)
return
}
project.APIKey = *req.APIKey
updatedFields = append(updatedFields, "api_key")
}
// Handle project activation/deactivation
var shouldRevokeTokens bool
if req.IsActive != nil {
if *req.IsActive != project.IsActive {
project.IsActive = *req.IsActive
updatedFields = append(updatedFields, "is_active")
// If deactivating project, set deactivated timestamp
if !*req.IsActive {
now := time.Now().UTC()
project.DeactivatedAt = &now
updatedFields = append(updatedFields, "deactivated_at")
// Check if tokens should be revoked when deactivating
if req.RevokeTokens != nil && *req.RevokeTokens {
shouldRevokeTokens = true
}
} else {
// Reactivating project, clear deactivated timestamp
project.DeactivatedAt = nil
}
}
}
project.UpdatedAt = time.Now().UTC()
if err := s.projectStore.UpdateProject(ctx, project); err != nil {
s.logger.Error("failed to update project", zap.String("id", id), zap.Error(err))
// Audit: project update failure - store error
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectUpdate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(id).
WithError(err).
WithDetail("updated_fields", updatedFields))
http.Error(w, `{"error":"failed to update project"}`, http.StatusInternalServerError)
return
}
// Revoke project tokens if requested
var revokedTokensCount int
if shouldRevokeTokens {
tokens, err := s.tokenStore.GetTokensByProjectID(ctx, id)
if err != nil {
s.logger.Warn("failed to get project tokens for revocation", zap.String("project_id", id), zap.Error(err))
} else {
// Revoke all active tokens for this project
for _, token := range tokens {
if token.IsActive {
token.IsActive = false
token.DeactivatedAt = nowPtrUTC()
if err := s.tokenStore.UpdateToken(ctx, token); err != nil {
s.logger.Warn("failed to revoke token during project deactivation",
zap.String("token_id", token.ID),
zap.String("project_id", id),
zap.Error(err))
} else {
revokedTokensCount++
}
}
}
}
}
s.logger.Info("project updated", zap.String("id", id), zap.Strings("updated_fields", updatedFields))
// Audit: project update success
auditEvent := s.auditEvent(audit.ActionProjectUpdate, audit.ActorManagement, audit.ResultSuccess, r, requestID).
WithProjectID(id).
WithDetail("updated_fields", updatedFields).
WithDetail("project_name", project.Name)
if shouldRevokeTokens {
auditEvent.WithDetail("tokens_revoked", revokedTokensCount)
}
_ = s.auditLogger.Log(auditEvent)
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(project); err != nil {
s.logger.Error("failed to encode project response", zap.Error(err))
}
}
// DELETE /manage/projects/{id} - always returns 405 Method Not Allowed (project deletion is not permitted)
func (s *Server) handleDeleteProject(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
requestID := getRequestID(ctx)
id := strings.TrimPrefix(r.URL.Path, "/manage/projects/")
// Audit: project delete attempt - method not allowed
_ = s.auditLogger.Log(s.auditEvent(audit.ActionProjectDelete, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(id).
WithDetail("error_type", "method not allowed").
WithDetail("reason", "project deletion is not permitted"))
w.Header().Set("Allow", "GET, PATCH")
http.Error(w, `{"error":"method not allowed","message":"project deletion is not permitted"}`, http.StatusMethodNotAllowed)
}
// POST /manage/projects/{id}/tokens/revoke - Bulk revoke all tokens for a project
func (s *Server) handleBulkRevokeProjectTokens(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
requestID := getRequestID(ctx)
// Extract project ID from path
pathSuffix := strings.TrimPrefix(r.URL.Path, "/manage/projects/")
if !strings.HasSuffix(pathSuffix, "/tokens/revoke") {
s.logger.Error("invalid bulk revoke path", zap.String("path", r.URL.Path), zap.String("request_id", requestID))
http.Error(w, `{"error":"invalid path"}`, http.StatusBadRequest)
return
}
projectID := strings.TrimSuffix(pathSuffix, "/tokens/revoke")
if projectID == "" {
s.logger.Error("invalid project ID in bulk revoke path", zap.String("path", r.URL.Path), zap.String("request_id", requestID))
http.Error(w, `{"error":"invalid path"}`, http.StatusBadRequest)
return
}
// Verify project exists
_, err := s.projectStore.GetProjectByID(ctx, projectID)
if err != nil {
s.logger.Error("project not found for bulk token revoke", zap.String("project_id", projectID), zap.Error(err), zap.String("request_id", requestID))
// Audit: bulk revoke failure - project not found
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenRevokeBatch, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(projectID).
WithError(err).
WithDetail("error_type", "project not found"))
http.Error(w, `{"error":"project not found"}`, http.StatusNotFound)
return
}
// Get all tokens for the project
tokens, err := s.tokenStore.GetTokensByProjectID(ctx, projectID)
if err != nil {
s.logger.Error("failed to get tokens for bulk revoke", zap.String("project_id", projectID), zap.Error(err), zap.String("request_id", requestID))
// Audit: bulk revoke failure - failed to get tokens
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenRevokeBatch, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(projectID).
WithError(err).
WithDetail("error_type", "failed to get tokens"))
http.Error(w, `{"error":"failed to get project tokens"}`, http.StatusInternalServerError)
return
}
// Count and revoke active tokens
var revokedCount, alreadyRevokedCount int
var failedRevocations []string
for _, token := range tokens {
if !token.IsActive {
alreadyRevokedCount++
continue
}
// Revoke the token
token.IsActive = false
token.DeactivatedAt = nowPtrUTC()
if err := s.tokenStore.UpdateToken(ctx, token); err != nil {
s.logger.Warn("failed to revoke individual token during bulk revoke",
zap.String("token_id", token.ID),
zap.String("project_id", projectID),
zap.Error(err))
failedRevocations = append(failedRevocations, token.ID)
} else {
revokedCount++
}
}
s.logger.Info("bulk token revocation completed",
zap.String("project_id", projectID),
zap.Int("revoked_count", revokedCount),
zap.Int("already_revoked_count", alreadyRevokedCount),
zap.Int("failed_count", len(failedRevocations)),
zap.String("request_id", requestID),
)
// Audit: bulk revoke success
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenRevokeBatch, audit.ActorManagement, audit.ResultSuccess, r, requestID).
WithProjectID(projectID).
WithRequestID(requestID).
WithHTTPMethod(r.Method).
WithEndpoint(r.URL.Path).
WithDetail("total_tokens", len(tokens)).
WithDetail("revoked_count", revokedCount).
WithDetail("already_revoked_count", alreadyRevokedCount).
WithDetail("failed_count", len(failedRevocations)))
// Return summary response
response := map[string]interface{}{
"revoked_count": revokedCount,
"already_revoked_count": alreadyRevokedCount,
"total_tokens": len(tokens),
}
if len(failedRevocations) > 0 {
response["failed_count"] = len(failedRevocations)
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
s.logger.Error("failed to encode bulk revoke response", zap.Error(err))
}
}
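// Example response body (illustrative values):
//
//	{"revoked_count": 3, "already_revoked_count": 1, "total_tokens": 4}
//
// failed_count is included only when at least one revocation failed.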
// checkManagementAuth validates the Bearer management token on the request,
// writing a 401 response and returning false on failure.
func (s *Server) checkManagementAuth(w http.ResponseWriter, r *http.Request) bool {
const prefix = "Bearer "
header := r.Header.Get("Authorization")
s.logger.Debug("checkManagementAuth: header",
zap.Bool("present", header != ""),
zap.Bool("has_bearer_prefix", strings.HasPrefix(header, prefix)),
zap.Int("header_len", len(header)),
)
if !strings.HasPrefix(header, prefix) || len(header) <= len(prefix) {
s.logger.Debug("checkManagementAuth: missing or invalid prefix")
http.Error(w, `{"error":"missing or invalid Authorization header"}`, http.StatusUnauthorized)
return false
}
token := header[len(prefix):]
s.logger.Debug("checkManagementAuth: token compare",
zap.Int("provided_len", len(token)),
zap.Int("expected_len", len(s.config.ManagementToken)),
)
if token != s.config.ManagementToken {
s.logger.Debug("checkManagementAuth: token mismatch")
http.Error(w, `{"error":"invalid management token"}`, http.StatusUnauthorized)
return false
}
s.logger.Debug("checkManagementAuth: token match")
return true
}
// handleProjectByID dispatches /manage/projects/{id} requests by method,
// routing POST .../tokens/revoke to the bulk revoke handler.
func (s *Server) handleProjectByID(w http.ResponseWriter, r *http.Request) {
// Check if this is a bulk token revoke request
if strings.HasSuffix(r.URL.Path, "/tokens/revoke") && r.Method == http.MethodPost {
s.handleBulkRevokeProjectTokens(w, r)
return
}
switch r.Method {
case http.MethodGet:
s.handleGetProject(w, r)
case http.MethodPatch:
s.handleUpdateProject(w, r)
case http.MethodDelete:
s.handleDeleteProject(w, r)
default:
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
}
// Handler for /manage/tokens (POST: create, GET: list)
func (s *Server) handleTokens(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
requestID := getRequestID(ctx)
switch r.Method {
case http.MethodPost:
var req struct {
ProjectID string `json:"project_id"`
DurationMinutes int `json:"duration_minutes"`
MaxRequests *int `json:"max_requests"`
}
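// Example body (illustrative values):
//	{"project_id": "<project-uuid>", "duration_minutes": 60, "max_requests": 100}
// max_requests is optional; 0 (or omitting it) means unlimited.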
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.logger.Error("invalid token create request body", zap.Error(err), zap.String("request_id", requestID))
// Audit: token creation failure - invalid request
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenCreate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithError(err).
WithDetail("validation_error", "invalid request body"))
http.Error(w, `{"error":"invalid request body"}`, http.StatusBadRequest)
return
}
var duration time.Duration
if req.DurationMinutes > 0 {
if req.DurationMinutes > maxDurationMinutes {
s.logger.Error("duration_minutes exceeds maximum allowed", zap.Int("duration_minutes", req.DurationMinutes), zap.String("request_id", requestID))
// Audit: token creation failure - duration too long
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenCreate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(req.ProjectID).
WithDetail("validation_error", "duration exceeds maximum").
WithDetail("requested_duration_minutes", req.DurationMinutes).
WithDetail("max_duration_minutes", maxDurationMinutes))
http.Error(w, `{"error":"duration_minutes exceeds maximum allowed"}`, http.StatusBadRequest)
return
}
duration = time.Duration(req.DurationMinutes) * time.Minute
} else {
s.logger.Error("missing required fields for token create", zap.String("project_id", req.ProjectID), zap.Int("duration_minutes", req.DurationMinutes), zap.String("request_id", requestID))
// Audit: token creation failure - missing duration
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenCreate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(req.ProjectID).
WithDetail("validation_error", "missing duration_minutes"))
http.Error(w, `{"error":"project_id and duration_minutes are required"}`, http.StatusBadRequest)
return
}
if req.ProjectID == "" {
s.logger.Error("missing project_id for token create", zap.String("request_id", requestID))
// Audit: token creation failure - missing project ID
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenCreate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithDetail("validation_error", "missing project_id"))
http.Error(w, `{"error":"project_id is required"}`, http.StatusBadRequest)
return
}
if req.MaxRequests != nil {
if *req.MaxRequests < 0 {
s.logger.Error("max_requests must be >= 0", zap.Int("max_requests", *req.MaxRequests), zap.String("request_id", requestID))
// Audit: token creation failure - invalid max_requests
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenCreate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(req.ProjectID).
WithDetail("validation_error", "max_requests must be >= 0").
WithDetail("requested_max_requests", *req.MaxRequests))
http.Error(w, `{"error":"max_requests must be >= 0"}`, http.StatusBadRequest)
return
}
// 0 means unlimited.
if *req.MaxRequests == 0 {
req.MaxRequests = nil
}
}
// Check project exists and is active
project, err := s.projectStore.GetProjectByID(ctx, req.ProjectID)
if err != nil {
s.logger.Error("project not found for token create", zap.String("project_id", req.ProjectID), zap.Error(err), zap.String("request_id", requestID))
// Audit: token creation failure - project not found
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenCreate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(req.ProjectID).
WithError(err).
WithDetail("error_type", "project not found"))
http.Error(w, `{"error":"project not found"}`, http.StatusNotFound)
return
}
// Check if project is active
if !project.IsActive {
s.logger.Warn("token creation denied for inactive project", zap.String("project_id", req.ProjectID), zap.String("request_id", requestID))
// Audit: token creation failure - project inactive
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenCreate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(req.ProjectID).
WithDetail("error_type", "project_inactive").
WithDetail("reason", "cannot create tokens for inactive projects"))
http.Error(w, `{"error":"cannot create tokens for inactive projects","code":"project_inactive"}`, http.StatusForbidden)
return
}
// Generate token ID (UUID) and token string
tokenID := uuid.New().String()
tokenStr, expiresAt, _, err := token.NewTokenGenerator().GenerateWithOptions(duration, req.MaxRequests)
if err != nil {
s.logger.Error("failed to generate token", zap.Error(err), zap.String("request_id", requestID))
// Audit: token creation failure - generation error
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenCreate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(req.ProjectID).
WithError(err).
WithDetail("error_type", "token generation failed"))
http.Error(w, `{"error":"failed to generate token"}`, http.StatusInternalServerError)
return
}
now := time.Now().UTC()
dbToken := token.TokenData{
ID: tokenID,
Token: tokenStr,
ProjectID: req.ProjectID,
ExpiresAt: expiresAt,
IsActive: true,
RequestCount: 0,
MaxRequests: req.MaxRequests,
CreatedAt: now,
}
if err := s.tokenStore.CreateToken(ctx, dbToken); err != nil {
s.logger.Error("failed to store token", zap.Error(err), zap.String("request_id", requestID))
// Audit: token creation failure - storage error
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenCreate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithProjectID(req.ProjectID).
WithTokenID(tokenID).
WithError(err).
WithDetail("error_type", "storage failed"))
http.Error(w, `{"error":"failed to store token"}`, http.StatusInternalServerError)
return
}
s.logger.Info("token created",
zap.String("token", token.ObfuscateToken(tokenStr)),
zap.String("project_id", req.ProjectID),
zap.String("request_id", requestID),
)
// Audit: token creation success
auditEvent := s.auditEvent(audit.ActionTokenCreate, audit.ActorManagement, audit.ResultSuccess, r, requestID).
WithProjectID(req.ProjectID).
WithRequestID(requestID).
WithHTTPMethod(r.Method).
WithEndpoint(r.URL.Path).
WithTokenID(tokenID).
WithDetail("duration_minutes", req.DurationMinutes).
WithDetail("expires_at", expiresAt.Format(time.RFC3339))
if req.MaxRequests != nil {
auditEvent.WithDetail("max_requests", *req.MaxRequests)
}
_ = s.auditLogger.Log(auditEvent)
w.Header().Set("Content-Type", "application/json")
response := map[string]interface{}{
"id": tokenID,
"token": tokenStr,
"expires_at": expiresAt,
}
if req.MaxRequests != nil {
response["max_requests"] = *req.MaxRequests
}
if err := json.NewEncoder(w).Encode(response); err != nil {
s.logger.Error("failed to encode token response", zap.Error(err))
}
case http.MethodGet:
projectID := r.URL.Query().Get("projectId")
var tokens []token.TokenData
var err error
if projectID != "" {
tokens, err = s.tokenStore.GetTokensByProjectID(ctx, projectID)
} else {
tokens, err = s.tokenStore.ListTokens(ctx)
}
if err != nil {
s.logger.Error("failed to list tokens", zap.Error(err))
// Audit: token list failure
auditEvent := s.auditEvent(audit.ActionTokenList, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithRequestID(requestID).
WithHTTPMethod(r.Method).
WithEndpoint(r.URL.Path).
WithError(err)
if projectID != "" {
auditEvent.WithProjectID(projectID)
}
_ = s.auditLogger.Log(auditEvent)
http.Error(w, `{"error":"failed to list tokens"}`, http.StatusInternalServerError)
return
}
s.logger.Info("tokens listed", zap.Int("count", len(tokens)))
// Audit: token list success
auditEvent := s.auditEvent(audit.ActionTokenList, audit.ActorManagement, audit.ResultSuccess, r, requestID).
WithRequestID(requestID).
WithHTTPMethod(r.Method).
WithEndpoint(r.URL.Path).
WithDetail("token_count", len(tokens))
if projectID != "" {
auditEvent.WithProjectID(projectID).WithDetail("filtered_by_project", true)
}
_ = s.auditLogger.Log(auditEvent)
w.Header().Set("Content-Type", "application/json")
// Create sanitized response with token IDs and obfuscated token strings
sanitizedTokens := make([]TokenListResponse, len(tokens))
for i, t := range tokens {
sanitizedTokens[i] = TokenListResponse{
ID: t.ID,
Token: token.ObfuscateToken(t.Token),
ProjectID: t.ProjectID,
ExpiresAt: t.ExpiresAt,
IsActive: t.IsActive,
RequestCount: t.RequestCount,
MaxRequests: t.MaxRequests,
CreatedAt: t.CreatedAt,
LastUsedAt: t.LastUsedAt,
CacheHitCount: t.CacheHitCount,
}
}
if err := json.NewEncoder(w).Encode(sanitizedTokens); err != nil {
s.logger.Error("failed to encode tokens response", zap.Error(err))
}
default:
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
}
// Handler for /manage/tokens/{id} (GET: retrieve, PATCH: update, DELETE: revoke)
func (s *Server) handleTokenByID(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
requestID := getRequestID(ctx)
// Extract token ID from path
tokenID := strings.TrimPrefix(r.URL.Path, "/manage/tokens/")
if tokenID == "" || tokenID == "/" {
s.logger.Error("invalid token ID in path", zap.String("path", r.URL.Path), zap.String("request_id", requestID))
http.Error(w, `{"error":"token ID is required"}`, http.StatusBadRequest)
return
}
switch r.Method {
case http.MethodGet:
s.handleGetToken(w, r, tokenID)
case http.MethodPatch:
s.handleUpdateToken(w, r, tokenID)
case http.MethodDelete:
s.handleRevokeToken(w, r, tokenID)
default:
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
}
// GET /manage/tokens/{id}
func (s *Server) handleGetToken(w http.ResponseWriter, r *http.Request, tokenID string) {
ctx := r.Context()
requestID := getRequestID(ctx)
// Get token from store
tokenData, err := s.tokenStore.GetTokenByID(ctx, tokenID)
if err != nil {
s.logger.Error("failed to get token", zap.String("token_id", tokenID), zap.Error(err), zap.String("request_id", requestID))
// Audit: token get failure
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenRead, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithTokenID(tokenID).
WithError(err).
WithDetail("error_type", "token not found"))
http.Error(w, `{"error":"token not found"}`, http.StatusNotFound)
return
}
// Audit: token get success
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenRead, audit.ActorManagement, audit.ResultSuccess, r, requestID).
WithTokenID(tokenID).
WithProjectID(tokenData.ProjectID).
WithRequestID(requestID).
WithHTTPMethod(r.Method).
WithEndpoint(r.URL.Path))
// Create sanitized response with ID and obfuscated token string
response := TokenListResponse{
ID: tokenData.ID,
Token: token.ObfuscateToken(tokenData.Token),
ProjectID: tokenData.ProjectID,
ExpiresAt: tokenData.ExpiresAt,
IsActive: tokenData.IsActive,
RequestCount: tokenData.RequestCount,
MaxRequests: tokenData.MaxRequests,
CreatedAt: tokenData.CreatedAt,
LastUsedAt: tokenData.LastUsedAt,
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
s.logger.Error("failed to encode token response", zap.Error(err))
}
}
// PATCH /manage/tokens/{id}
func (s *Server) handleUpdateToken(w http.ResponseWriter, r *http.Request, tokenID string) {
ctx := r.Context()
requestID := getRequestID(ctx)
// Parse request body
var req struct {
IsActive *bool `json:"is_active,omitempty"`
MaxRequests *int `json:"max_requests,omitempty"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.logger.Error("invalid token update request body", zap.Error(err), zap.String("request_id", requestID))
// Audit: token update failure - invalid request
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenUpdate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithTokenID(tokenID).
WithError(err).
WithDetail("validation_error", "invalid request body"))
http.Error(w, `{"error":"invalid request body"}`, http.StatusBadRequest)
return
}
// Get existing token
tokenData, err := s.tokenStore.GetTokenByID(ctx, tokenID)
if err != nil {
s.logger.Error("failed to get token for update", zap.String("token_id", tokenID), zap.Error(err), zap.String("request_id", requestID))
// Audit: token update failure - token not found
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenUpdate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithTokenID(tokenID).
WithError(err).
WithDetail("error_type", "token not found"))
http.Error(w, `{"error":"token not found"}`, http.StatusNotFound)
return
}
// Normalize max_requests semantics:
// - negative: invalid
// - 0: unlimited (stored as nil)
maxRequestsProvided := req.MaxRequests != nil
normalizedMaxRequests := req.MaxRequests
if maxRequestsProvided {
if *req.MaxRequests < 0 {
s.logger.Error("max_requests must be >= 0", zap.Int("max_requests", *req.MaxRequests), zap.String("request_id", requestID))
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenUpdate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithTokenID(tokenID).
WithError(fmt.Errorf("max_requests must be >= 0")))
http.Error(w, `{"error":"max_requests must be >= 0"}`, http.StatusBadRequest)
return
}
if *req.MaxRequests == 0 {
normalizedMaxRequests = nil
}
}
// Update fields if provided
updated := false
if req.IsActive != nil {
tokenData.IsActive = *req.IsActive
updated = true
}
if maxRequestsProvided {
tokenData.MaxRequests = normalizedMaxRequests
updated = true
}
if !updated {
s.logger.Error("no fields to update", zap.String("token_id", tokenID), zap.String("request_id", requestID))
// Audit: token update failure - no fields
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenUpdate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithTokenID(tokenID).
WithDetail("validation_error", "no fields to update"))
http.Error(w, `{"error":"no fields to update"}`, http.StatusBadRequest)
return
}
// Update token in store
if err := s.tokenStore.UpdateToken(ctx, tokenData); err != nil {
s.logger.Error("failed to update token", zap.String("token_id", tokenID), zap.Error(err), zap.String("request_id", requestID))
// Audit: token update failure - storage error
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenUpdate, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithTokenID(tokenID).
WithProjectID(tokenData.ProjectID).
WithError(err).
WithDetail("error_type", "storage failed"))
http.Error(w, `{"error":"failed to update token"}`, http.StatusInternalServerError)
return
}
s.logger.Info("token updated",
zap.String("token_id", tokenID),
zap.String("project_id", tokenData.ProjectID),
zap.String("request_id", requestID),
)
// Audit: token update success
auditEvent := s.auditEvent(audit.ActionTokenUpdate, audit.ActorManagement, audit.ResultSuccess, r, requestID).
WithTokenID(tokenID).
WithProjectID(tokenData.ProjectID).
WithRequestID(requestID).
WithHTTPMethod(r.Method).
WithEndpoint(r.URL.Path)
if req.IsActive != nil {
auditEvent.WithDetail("updated_is_active", *req.IsActive)
}
if maxRequestsProvided && normalizedMaxRequests != nil {
auditEvent.WithDetail("updated_max_requests", *normalizedMaxRequests)
}
if maxRequestsProvided && normalizedMaxRequests == nil {
auditEvent.WithDetail("updated_max_requests", "unlimited")
}
_ = s.auditLogger.Log(auditEvent)
// Return updated token (sanitized with ID and obfuscated token)
response := TokenListResponse{
ID: tokenData.ID,
Token: token.ObfuscateToken(tokenData.Token),
ProjectID: tokenData.ProjectID,
ExpiresAt: tokenData.ExpiresAt,
IsActive: tokenData.IsActive,
RequestCount: tokenData.RequestCount,
MaxRequests: tokenData.MaxRequests,
CreatedAt: tokenData.CreatedAt,
LastUsedAt: tokenData.LastUsedAt,
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
s.logger.Error("failed to encode updated token response", zap.Error(err))
}
}
// DELETE /manage/tokens/{id} (revoke token)
func (s *Server) handleRevokeToken(w http.ResponseWriter, r *http.Request, tokenID string) {
ctx := r.Context()
requestID := getRequestID(ctx)
// Get existing token first to verify it exists and get project ID
tokenData, err := s.tokenStore.GetTokenByID(ctx, tokenID)
if err != nil {
s.logger.Error("failed to get token for revocation", zap.String("token_id", tokenID), zap.Error(err), zap.String("request_id", requestID))
// Audit: token revoke failure - token not found
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenRevoke, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithTokenID(tokenID).
WithError(err).
WithDetail("error_type", "token not found"))
http.Error(w, `{"error":"token not found"}`, http.StatusNotFound)
return
}
// Check if already inactive
if !tokenData.IsActive {
s.logger.Warn("token already revoked", zap.String("token_id", tokenID), zap.String("request_id", requestID))
// Audit: token revoke success (idempotent)
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenRevoke, audit.ActorManagement, audit.ResultSuccess, r, requestID).
WithTokenID(tokenID).
WithProjectID(tokenData.ProjectID).
WithRequestID(requestID).
WithHTTPMethod(r.Method).
WithEndpoint(r.URL.Path).
WithDetail("already_revoked", true))
w.WriteHeader(http.StatusNoContent)
return
}
// Revoke token by setting is_active to false
tokenData.IsActive = false
tokenData.DeactivatedAt = nowPtrUTC()
if err := s.tokenStore.UpdateToken(ctx, tokenData); err != nil {
s.logger.Error("failed to revoke token", zap.String("token_id", tokenID), zap.Error(err), zap.String("request_id", requestID))
// Audit: token revoke failure - storage error
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenRevoke, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithTokenID(tokenID).
WithProjectID(tokenData.ProjectID).
WithError(err).
WithDetail("error_type", "storage failed"))
http.Error(w, `{"error":"failed to revoke token"}`, http.StatusInternalServerError)
return
}
s.logger.Info("token revoked",
zap.String("token_id", tokenID),
zap.String("project_id", tokenData.ProjectID),
zap.String("request_id", requestID),
)
// Audit: token revoke success
_ = s.auditLogger.Log(s.auditEvent(audit.ActionTokenRevoke, audit.ActorManagement, audit.ResultSuccess, r, requestID).
WithTokenID(tokenID).
WithProjectID(tokenData.ProjectID).
WithRequestID(requestID).
WithHTTPMethod(r.Method).
WithEndpoint(r.URL.Path))
w.WriteHeader(http.StatusNoContent)
}
func getRequestID(ctx context.Context) string {
if requestID, ok := logging.GetRequestID(ctx); ok && requestID != "" {
return requestID
}
return uuid.New().String()
}
// nowPtrUTC is a convenience helper that returns the current UTC time as a *time.Time.
func nowPtrUTC() *time.Time {
t := time.Now().UTC()
return &t
}
// parseInt parses s as a base-10 integer, returning defaultValue when s is empty or unparseable.
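// Note: fmt.Sscanf stops at the first non-digit, so "12abc" parses as 12;
// the pagination query parameters parsed here tolerate that leniency.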
func parseInt(s string, defaultValue int) int {
if s == "" {
return defaultValue
}
var result int
if _, err := fmt.Sscanf(s, "%d", &result); err != nil {
return defaultValue
}
return result
}
// logRequestMiddleware logs all incoming requests with timing information using structured logging
func (s *Server) logRequestMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
startTime := time.Now()
// Get or generate request ID from header
requestID := r.Header.Get("X-Request-ID")
if requestID == "" {
requestID = uuid.New().String()
}
// Get or generate correlation ID from header
correlationID := r.Header.Get("X-Correlation-ID")
if correlationID == "" {
correlationID = uuid.New().String()
}
// Add to context using our new context helpers
ctx := logging.WithRequestID(r.Context(), requestID)
ctx = logging.WithCorrelationID(ctx, correlationID)
// Set response headers
w.Header().Set("X-Request-ID", requestID)
w.Header().Set("X-Correlation-ID", correlationID)
// Create a response writer that captures status code
rw := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
// Get client IP
clientIP := s.getClientIP(r)
// Create logger with request context
reqLogger := logging.WithRequestContext(ctx, s.logger)
reqLogger = logging.WithCorrelationContext(ctx, reqLogger)
reqLogger.Info("request started",
logging.ClientIP(clientIP),
zap.String("method", r.Method),
zap.String("path", r.URL.Path),
zap.String("user_agent", r.UserAgent()),
)
// Call the next handler
next(rw, r.WithContext(ctx))
duration := time.Since(startTime)
durationMs := int(duration.Milliseconds())
// Log completion with canonical fields
if rw.statusCode >= 500 {
reqLogger.Error("request completed with server error",
logging.RequestFields(requestID, r.Method, r.URL.Path, rw.statusCode, durationMs)...,
)
} else {
reqLogger.Info("request completed",
logging.RequestFields(requestID, r.Method, r.URL.Path, rw.statusCode, durationMs)...,
)
}
}
}
// getClientIP extracts the client IP address from the request
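// X-Forwarded-For and X-Real-IP are client-controlled headers, so the
// returned value is suitable for logging and audit metadata, not for
// authorization decisions.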
func (s *Server) getClientIP(r *http.Request) string {
// Check for X-Forwarded-For header first (in case of proxy)
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Take the first IP in the list
if idx := strings.Index(xff, ","); idx != -1 {
return strings.TrimSpace(xff[:idx])
}
return strings.TrimSpace(xff)
}
// Check for X-Real-IP header
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return strings.TrimSpace(xri)
}
// Fall back to RemoteAddr
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
return r.RemoteAddr[:idx]
}
return r.RemoteAddr
}
// responseWriter wraps http.ResponseWriter to capture status code
type responseWriter struct {
http.ResponseWriter
statusCode int
}
func (rw *responseWriter) WriteHeader(code int) {
rw.statusCode = code
rw.ResponseWriter.WriteHeader(code)
}
// Add Flush forwarding for streaming support
func (rw *responseWriter) Flush() {
if f, ok := rw.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}
// handleNotFound is a catch-all handler for unmatched routes
func (s *Server) handleNotFound(w http.ResponseWriter, r *http.Request) {
s.logger.Info("route not found",
zap.String("method", r.Method),
zap.String("path", r.URL.Path),
zap.String("remote_addr", r.RemoteAddr),
)
http.NotFound(w, r)
}
// EventBus returns the event bus used by the server (may be nil if observability is disabled)
func (s *Server) EventBus() eventbus.EventBus {
return s.eventBus
}
// auditEvent creates a new audit event with common fields filled from the HTTP request
func (s *Server) auditEvent(action string, actor string, result audit.ResultType, r *http.Request, requestID string) *audit.Event {
clientIP := s.getClientIP(r)
// Prefer forwarded UA and referer from Admin UI if present
forwardedUA := r.Header.Get("X-Forwarded-User-Agent")
userAgent := r.UserAgent()
if forwardedUA != "" {
userAgent = forwardedUA
}
forwardedRef := r.Header.Get("X-Forwarded-Referer")
ev := audit.NewEvent(action, actor, result).
WithRequestID(requestID).
WithHTTPMethod(r.Method).
WithEndpoint(r.URL.Path).
WithClientIP(clientIP).
WithUserAgent(userAgent)
if forwardedRef != "" {
ev = ev.WithDetail("referer", forwardedRef)
}
if r.Header.Get("X-Admin-Origin") == "1" {
ev = ev.WithDetail("origin", "admin-ui")
}
return ev
}
// Handler for /manage/audit (GET: list)
func (s *Server) handleAuditEvents(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
requestID := getRequestID(ctx)
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
// Only proceed if we have database access
if s.db == nil {
s.logger.Error("audit events requested but database not available", zap.String("request_id", requestID))
http.Error(w, `{"error":"audit events not available"}`, http.StatusServiceUnavailable)
return
}
// Parse query parameters for filtering
query := r.URL.Query()
filters := database.AuditEventFilters{
Action: query.Get("action"),
ClientIP: query.Get("client_ip"),
ProjectID: query.Get("project_id"),
Outcome: query.Get("outcome"),
Actor: query.Get("actor"),
RequestID: query.Get("request_id"),
CorrelationID: query.Get("correlation_id"),
Method: query.Get("method"),
Path: query.Get("path"),
Search: query.Get("search"),
}
// Parse time filters
if startTime := query.Get("start_time"); startTime != "" {
filters.StartTime = &startTime
}
if endTime := query.Get("end_time"); endTime != "" {
filters.EndTime = &endTime
}
// Parse pagination
page := parseInt(query.Get("page"), 1)
pageSize := parseInt(query.Get("page_size"), 20)
if pageSize > 100 {
pageSize = 100 // Limit page size
}
filters.Limit = pageSize
filters.Offset = (page - 1) * pageSize
// Get audit events
events, err := s.db.ListAuditEvents(ctx, filters)
if err != nil {
s.logger.Error("failed to list audit events", zap.Error(err), zap.String("request_id", requestID))
http.Error(w, `{"error":"failed to list audit events"}`, http.StatusInternalServerError)
return
}
// Get total count for pagination
totalCount, err := s.db.CountAuditEvents(ctx, filters)
if err != nil {
s.logger.Error("failed to count audit events", zap.Error(err), zap.String("request_id", requestID))
http.Error(w, `{"error":"failed to count audit events"}`, http.StatusInternalServerError)
return
}
// Calculate pagination info
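// (ceiling division: e.g. 45 events with page_size 20 yields 3 pages)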
totalPages := (totalCount + pageSize - 1) / pageSize
hasNext := page < totalPages
hasPrev := page > 1
response := map[string]interface{}{
"events": events,
"pagination": map[string]interface{}{
"page": page,
"page_size": pageSize,
"total_count": totalCount,
"total_pages": totalPages,
"has_next": hasNext,
"has_prev": hasPrev,
},
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
s.logger.Error("failed to encode audit events response", zap.Error(err), zap.String("request_id", requestID))
}
// Audit: successful audit events listing
_ = s.auditLogger.Log(s.auditEvent(audit.ActionAuditList, audit.ActorManagement, audit.ResultSuccess, r, requestID).
WithDetail("events_count", len(events)).
WithDetail("page", page).
WithDetail("page_size", pageSize).
WithDetail("total_count", totalCount))
}
// Handler for /manage/audit/{id} (GET: show)
func (s *Server) handleAuditEventByID(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
requestID := getRequestID(ctx)
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
// Only proceed if we have database access
if s.db == nil {
s.logger.Error("audit event requested but database not available", zap.String("request_id", requestID))
http.Error(w, `{"error":"audit event not available"}`, http.StatusServiceUnavailable)
return
}
// Extract audit event ID from path
id := strings.TrimPrefix(r.URL.Path, "/manage/audit/")
if id == "" {
http.Error(w, `{"error":"audit event ID is required"}`, http.StatusBadRequest)
return
}
// Get specific audit event by ID
event, err := s.db.GetAuditEventByID(ctx, id)
if err != nil {
if strings.Contains(err.Error(), "not found") {
http.Error(w, `{"error":"audit event not found"}`, http.StatusNotFound)
} else {
s.logger.Error("failed to get audit event", zap.Error(err), zap.String("audit_id", id), zap.String("request_id", requestID))
http.Error(w, `{"error":"failed to get audit event"}`, http.StatusInternalServerError)
}
return
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(event); err != nil {
s.logger.Error("failed to encode audit event response", zap.Error(err), zap.String("audit_id", id), zap.String("request_id", requestID))
}
// Audit: successful audit event retrieval
_ = s.auditLogger.Log(s.auditEvent(audit.ActionAuditShow, audit.ActorManagement, audit.ResultSuccess, r, requestID).
WithDetail("audit_event_id", id))
}
// CachePurgeRequest represents the request body for cache purge operations
type CachePurgeRequest struct {
Method string `json:"method" binding:"required"`
URL string `json:"url" binding:"required"`
Prefix string `json:"prefix,omitempty"`
}
// CachePurgeResponse represents the response body for cache purge operations
type CachePurgeResponse struct {
Deleted interface{} `json:"deleted"` // bool for exact purge, int for prefix purge
}
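// Example request bodies (illustrative values):
//
//	{"method": "GET", "url": "https://api.example.com/v1/models"}         // exact purge
//	{"method": "GET", "url": "https://api.example.com/", "prefix": "v1/"} // prefix purge
//
// method and url are required in both cases; when prefix is set, a prefix
// purge is performed and url is not used to compute the cache key.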
// Handler for POST /manage/cache/purge
func (s *Server) handleCachePurge(w http.ResponseWriter, r *http.Request) {
requestID := getRequestID(r.Context())
if r.Method != http.MethodPost {
http.Error(w, `{"error":"method not allowed"}`, http.StatusMethodNotAllowed)
return
}
// Check if proxy and cache are available
if s.proxy == nil {
s.logger.Error("proxy not initialized", zap.String("request_id", requestID))
http.Error(w, `{"error":"proxy not available"}`, http.StatusInternalServerError)
return
}
cache := s.proxy.Cache()
if cache == nil {
s.logger.Warn("cache purge attempted but caching is disabled", zap.String("request_id", requestID))
http.Error(w, `{"error":"caching is disabled"}`, http.StatusBadRequest)
_ = s.auditLogger.Log(s.auditEvent(audit.ActionCachePurge, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithDetail("reason", "caching_disabled"))
return
}
var req CachePurgeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.logger.Warn("invalid JSON in cache purge request", zap.Error(err), zap.String("request_id", requestID))
http.Error(w, `{"error":"invalid JSON"}`, http.StatusBadRequest)
_ = s.auditLogger.Log(s.auditEvent(audit.ActionCachePurge, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithDetail("reason", "invalid_json"))
return
}
// Validate required fields
if req.Method == "" || req.URL == "" {
s.logger.Warn("missing required fields in cache purge request",
zap.String("method", req.Method), zap.String("url", req.URL), zap.String("request_id", requestID))
http.Error(w, `{"error":"method and url are required"}`, http.StatusBadRequest)
_ = s.auditLogger.Log(s.auditEvent(audit.ActionCachePurge, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithDetail("reason", "missing_fields"))
return
}
var response CachePurgeResponse
var auditDetails map[string]interface{}
if req.Prefix != "" {
// Prefix purge
deleted := cache.PurgePrefix(req.Prefix)
response.Deleted = deleted
auditDetails = map[string]interface{}{
"purge_type": "prefix",
"prefix": req.Prefix,
"deleted": deleted,
}
s.logger.Info("cache prefix purge completed",
zap.String("prefix", req.Prefix), zap.Int("deleted", deleted), zap.String("request_id", requestID))
} else {
// Exact key purge - need to compute cache key from method and URL
// Create a mock request to generate the cache key
mockURL, err := url.Parse(req.URL)
if err != nil {
s.logger.Warn("invalid URL in cache purge request", zap.Error(err), zap.String("url", req.URL), zap.String("request_id", requestID))
http.Error(w, `{"error":"invalid URL"}`, http.StatusBadRequest)
_ = s.auditLogger.Log(s.auditEvent(audit.ActionCachePurge, audit.ActorManagement, audit.ResultFailure, r, requestID).
WithDetail("reason", "invalid_url"))
return
}
mockReq := &http.Request{
Method: req.Method,
URL: mockURL,
Header: make(http.Header),
}
// Generate cache key using existing helper
cacheKey := proxy.CacheKeyFromRequest(mockReq)
deleted := cache.Purge(cacheKey)
response.Deleted = deleted
auditDetails = map[string]interface{}{
"purge_type": "exact",
"method": req.Method,
"url": req.URL,
"cache_key": cacheKey,
"deleted": deleted,
}
s.logger.Info("cache exact purge completed",
zap.String("method", req.Method), zap.String("url", req.URL),
zap.String("cache_key", cacheKey), zap.Bool("deleted", deleted), zap.String("request_id", requestID))
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
s.logger.Error("failed to encode cache purge response", zap.Error(err), zap.String("request_id", requestID))
return
}
// Audit: successful cache purge
auditEvent := s.auditEvent(audit.ActionCachePurge, audit.ActorManagement, audit.ResultSuccess, r, requestID)
for k, v := range auditDetails {
auditEvent = auditEvent.WithDetail(k, v)
}
_ = s.auditLogger.Log(auditEvent)
}
// Package setup provides configuration setup and management utilities.
package setup
import (
"fmt"
"os"
"path/filepath"
"github.com/sofatutor/llm-proxy/internal/utils"
)
// SetupConfig holds configuration parameters for setup
type SetupConfig struct {
ConfigPath string
OpenAIAPIKey string
ManagementToken string
DatabasePath string
ListenAddr string
}
// ValidateConfig validates the setup configuration
func (sc *SetupConfig) ValidateConfig() error {
if sc.OpenAIAPIKey == "" {
return fmt.Errorf("OpenAI API key is required")
}
if sc.ConfigPath == "" {
return fmt.Errorf("config path is required")
}
if sc.DatabasePath == "" {
return fmt.Errorf("database path is required")
}
if sc.ListenAddr == "" {
return fmt.Errorf("listen address is required")
}
return nil
}
// GenerateManagementToken generates a management token if not provided
func (sc *SetupConfig) GenerateManagementToken() error {
if sc.ManagementToken == "" {
token, err := utils.GenerateSecureToken(16)
if err != nil {
return fmt.Errorf("failed to generate management token: %w", err)
}
sc.ManagementToken = token
}
return nil
}
// WriteConfigFile writes the configuration to a file
func (sc *SetupConfig) WriteConfigFile() error {
if err := sc.ValidateConfig(); err != nil {
return err
}
// Ensure directory exists
dir := filepath.Dir(sc.ConfigPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("failed to create config directory: %w", err)
}
// Ensure database directory exists
dbDir := filepath.Dir(sc.DatabasePath)
if err := os.MkdirAll(dbDir, 0755); err != nil {
return fmt.Errorf("failed to create database directory: %w", err)
}
content := fmt.Sprintf(`# LLM Proxy Configuration
OPENAI_API_KEY=%s
MANAGEMENT_TOKEN=%s
DATABASE_PATH=%s
LISTEN_ADDR=%s
LOG_LEVEL=info
`, sc.OpenAIAPIKey, sc.ManagementToken, sc.DatabasePath, sc.ListenAddr)
if err := os.WriteFile(sc.ConfigPath, []byte(content), 0600); err != nil {
return fmt.Errorf("failed to write config file: %w", err)
}
return nil
}
// RunNonInteractiveSetup performs non-interactive setup with the given configuration
func RunNonInteractiveSetup(sc *SetupConfig) error {
if err := sc.GenerateManagementToken(); err != nil {
return fmt.Errorf("failed to generate management token: %w", err)
}
if err := sc.WriteConfigFile(); err != nil {
return fmt.Errorf("failed to write configuration: %w", err)
}
return nil
}
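// Example (illustrative; paths are placeholders):
//
//	cfg := &SetupConfig{
//		ConfigPath:   "/etc/llm-proxy/.env",
//		OpenAIAPIKey: os.Getenv("OPENAI_API_KEY"),
//		DatabasePath: "/var/lib/llm-proxy/data.db",
//		ListenAddr:   ":8080",
//	}
//	if err := RunNonInteractiveSetup(cfg); err != nil {
//		// handle error
//	}
//
// A management token is generated automatically when none is provided.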
package token
import (
"container/heap"
"context"
"fmt"
"sync"
"time"
)
// CacheEntry represents cached token data with an expiration time.
type CacheEntry struct {
Data TokenData
ValidUntil time.Time
}
// cacheEntry is a heap entry for eviction, with index for fast updates
// and tokenID for lookup.
type cacheEntry struct {
tokenID string
validUntil time.Time
insertedAt int64 // strictly increasing for FIFO eviction
index int // index in the heap
}
type cacheEntryHeap []*cacheEntry
func (h cacheEntryHeap) Len() int { return len(h) }
func (h cacheEntryHeap) Less(i, j int) bool { return h[i].insertedAt < h[j].insertedAt } // FIFO eviction
func (h cacheEntryHeap) Swap(i, j int) {
h[i], h[j] = h[j], h[i]
h[i].index = i
h[j].index = j
}
func (h *cacheEntryHeap) Push(x interface{}) {
entry := x.(*cacheEntry)
entry.index = len(*h)
*h = append(*h, entry)
}
func (h *cacheEntryHeap) Pop() interface{} {
old := *h
n := len(old)
item := old[n-1]
item.index = -1 // for safety
*h = old[0 : n-1]
return item
}
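// Illustrative FIFO behavior (hypothetical snippet, not part of the API):
//
//	h := &cacheEntryHeap{}
//	heap.Init(h)
//	heap.Push(h, &cacheEntry{tokenID: "a", insertedAt: 1})
//	heap.Push(h, &cacheEntry{tokenID: "b", insertedAt: 2})
//	oldest := heap.Pop(h).(*cacheEntry) // oldest.tokenID == "a"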
// CachedValidator wraps a TokenValidator with caching
type CachedValidator struct {
validator TokenValidator
cache map[string]CacheEntry
cacheMutex sync.RWMutex
cacheTTL time.Duration
maxCacheSize int
// Min-heap for eviction
heap cacheEntryHeap
heapIndex map[string]*cacheEntry // tokenID -> *cacheEntry
insertCounter int64 // strictly increasing counter for insertedAt
// For cache stats
hits int
misses int
evictions int
statsMutex sync.Mutex
}
// CacheOptions defines the options for the token cache
type CacheOptions struct {
// Time-to-live for cache entries (default: 5 minutes)
TTL time.Duration
// Maximum size of the cache (default: 1000)
MaxSize int
// Whether to enable automatic cache cleanup (default: true)
EnableCleanup bool
// Interval for cache cleanup (default: 1 minute)
CleanupInterval time.Duration
}
// DefaultCacheOptions returns the default cache options
func DefaultCacheOptions() CacheOptions {
return CacheOptions{
TTL: 5 * time.Minute,
MaxSize: 1000,
EnableCleanup: true,
CleanupInterval: 1 * time.Minute,
}
}
// NewCachedValidator creates a new validator with caching
func NewCachedValidator(validator TokenValidator, options ...CacheOptions) *CachedValidator {
opts := DefaultCacheOptions()
if len(options) > 0 {
opts = options[0]
}
cv := &CachedValidator{
validator: validator,
cache: make(map[string]CacheEntry),
cacheTTL: opts.TTL,
maxCacheSize: opts.MaxSize,
heap: make(cacheEntryHeap, 0, opts.MaxSize),
heapIndex: make(map[string]*cacheEntry, opts.MaxSize),
}
// Start cache cleanup if enabled
if opts.EnableCleanup {
go cv.startCleanup(opts.CleanupInterval)
}
return cv
}
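// Example (illustrative; baseValidator is any TokenValidator implementation):
//
//	cv := NewCachedValidator(baseValidator, CacheOptions{
//		TTL:             time.Minute,
//		MaxSize:         500,
//		EnableCleanup:   true,
//		CleanupInterval: 30 * time.Second,
//	})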
// ValidateToken validates a token using the cache when possible
func (cv *CachedValidator) ValidateToken(ctx context.Context, tokenID string) (string, error) {
// Check cache first
projectID, found := cv.checkCache(tokenID)
if found {
return projectID, nil
}
// Cache miss, validate using the underlying validator
projectID, err := cv.validator.ValidateToken(ctx, tokenID)
if err != nil {
return "", err
}
// Cache the successful validation
cv.cacheToken(ctx, tokenID)
return projectID, nil
}
// ValidateTokenWithTracking validates a token and tracks usage.
//
// For cached *unlimited* tokens, we keep validation cheap by returning from the cache and enqueueing
// async tracking (when an aggregator is configured).
//
// For cached *limited* tokens, we still want to avoid an extra DB read on every request. We do this
// by using the cached token metadata (active/expiry/project) and performing a synchronous usage
// increment (which is where max_requests is enforced).
func (cv *CachedValidator) ValidateTokenWithTracking(ctx context.Context, tokenID string) (string, error) {
cv.cacheMutex.RLock()
entry, found := cv.cache[tokenID]
cv.cacheMutex.RUnlock()
if !found {
cv.statsMutex.Lock()
cv.misses++
cv.statsMutex.Unlock()
} else {
now := time.Now()
if now.After(entry.ValidUntil) {
cv.invalidateCache(tokenID)
cv.statsMutex.Lock()
cv.misses++
cv.evictions++
cv.statsMutex.Unlock()
} else if isCacheableTokenValid(entry.Data) {
cv.statsMutex.Lock()
cv.hits++
cv.statsMutex.Unlock()
isUnlimited := entry.Data.MaxRequests == nil || *entry.Data.MaxRequests <= 0
if isUnlimited {
if getter, ok := cv.validator.(usageStatsAggregatorGetter); ok {
if agg := getter.usageStatsAggregator(); agg != nil {
agg.RecordTokenUsage(tokenID)
return entry.Data.ProjectID, nil
}
}
} else if sv, ok := cv.validator.(*StandardValidator); ok && sv != nil && sv.store != nil {
// Limited token: enforce max_requests via a synchronous increment, but avoid a DB read.
if err := sv.store.IncrementTokenUsage(ctx, tokenID); err != nil {
// If the token is no longer usable (inactive/expired/quota), invalidate the cache entry
// so we avoid repeatedly hitting the cache and failing the same increment.
if err == ErrTokenRateLimit || err == ErrTokenInactive || err == ErrTokenExpired {
cv.invalidateCache(tokenID)
}
return "", err
}
return entry.Data.ProjectID, nil
}
} else {
// Token became invalid (inactive/expired). Be defensive and drop the cache entry.
cv.invalidateCache(tokenID)
cv.statsMutex.Lock()
cv.misses++
cv.evictions++
cv.statsMutex.Unlock()
}
}
projectID, err := cv.validator.ValidateTokenWithTracking(ctx, tokenID)
if err != nil {
return "", err
}
// Populate the cache after successful tracking so subsequent requests can avoid extra DB reads.
// (Only works for StandardValidator; others are safely ignored.)
cv.cacheToken(ctx, tokenID)
return projectID, nil
}
// checkCache checks if a token is in the cache and still valid
func (cv *CachedValidator) checkCache(tokenID string) (string, bool) {
cv.cacheMutex.RLock()
entry, found := cv.cache[tokenID]
cv.cacheMutex.RUnlock()
// Not in cache
if !found {
cv.statsMutex.Lock()
cv.misses++
cv.statsMutex.Unlock()
return "", false
}
// In cache but expired
now := time.Now()
if now.After(entry.ValidUntil) {
cv.cacheMutex.Lock()
delete(cv.cache, tokenID)
cv.cacheMutex.Unlock()
cv.statsMutex.Lock()
cv.misses++
cv.evictions++
cv.statsMutex.Unlock()
return "", false
}
// In cache but token is no longer valid (inactive/expired)
if !isCacheableTokenValid(entry.Data) {
cv.invalidateCache(tokenID)
cv.statsMutex.Lock()
cv.misses++
cv.evictions++
cv.statsMutex.Unlock()
return "", false
}
// In cache and valid
cv.statsMutex.Lock()
cv.hits++
cv.statsMutex.Unlock()
return entry.Data.ProjectID, true
}
// isCacheableTokenValid determines whether a cached token can be used for authentication.
//
// Important: We intentionally do NOT use RequestCount/MaxRequests here.
// Cache hits are not supposed to count against token quotas (see cache-hit fast path),
// and cached RequestCount is inherently stale under concurrency.
func isCacheableTokenValid(td TokenData) bool {
if !td.IsActive {
return false
}
if IsExpired(td.ExpiresAt) {
return false
}
return true
}
// cacheToken retrieves and caches a token
func (cv *CachedValidator) cacheToken(ctx context.Context, tokenID string) {
standardValidator, ok := cv.validator.(*StandardValidator)
if !ok {
return
}
// TokenValidator receives the token *string* (sk-...) in ValidateToken/ValidateTokenWithTracking.
// Populate cache using token-string lookup.
tokenData, err := standardValidator.store.GetTokenByToken(ctx, tokenID)
if err != nil {
return
}
if !isCacheableTokenValid(tokenData) {
return
}
validUntil := time.Now().Add(cv.cacheTTL)
cv.cacheMutex.Lock()
defer cv.cacheMutex.Unlock()
insertedAt := cv.insertCounter
cv.insertCounter++
cv.cache[tokenID] = CacheEntry{
Data: tokenData,
ValidUntil: validUntil,
}
// Remove old heap entry if present
if oldEntry, ok := cv.heapIndex[tokenID]; ok {
idx := oldEntry.index
heap.Remove(&cv.heap, idx)
delete(cv.heapIndex, tokenID)
}
entry := &cacheEntry{
tokenID: tokenID,
validUntil: validUntil,
insertedAt: insertedAt,
}
heap.Push(&cv.heap, entry)
cv.heapIndex[tokenID] = entry
// Evict if over capacity
if cv.maxCacheSize > 0 && len(cv.cache) > cv.maxCacheSize {
cv.evictOldest()
}
}
// invalidateCache removes a token from the cache
func (cv *CachedValidator) invalidateCache(tokenID string) {
cv.cacheMutex.Lock()
delete(cv.cache, tokenID)
// Remove from heap if present
if entry, ok := cv.heapIndex[tokenID]; ok {
idx := entry.index
heap.Remove(&cv.heap, idx)
delete(cv.heapIndex, tokenID)
}
cv.cacheMutex.Unlock()
// Note: cache, heap, and heapIndex are always mutated together under cacheMutex,
// so their sizes stay consistent.
}
// evictOldest removes the single oldest entry from the cache
func (cv *CachedValidator) evictOldest() {
if cv.heap.Len() == 0 {
return
}
entry := heap.Pop(&cv.heap).(*cacheEntry)
// Remove from heapIndex
delete(cv.heapIndex, entry.tokenID)
// Remove from cache
delete(cv.cache, entry.tokenID)
cv.statsMutex.Lock()
cv.evictions++
cv.statsMutex.Unlock()
}
// startCleanup periodically cleans up expired entries from the cache
func (cv *CachedValidator) startCleanup(interval time.Duration) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for range ticker.C {
cv.cleanup()
}
}
// cleanup removes expired entries from the cache
func (cv *CachedValidator) cleanup() {
now := time.Now()
cv.cacheMutex.Lock()
defer cv.cacheMutex.Unlock()
for k, v := range cv.cache {
if now.After(v.ValidUntil) {
delete(cv.cache, k)
// Keep the eviction heap and index in sync with the cache map.
if entry, ok := cv.heapIndex[k]; ok {
heap.Remove(&cv.heap, entry.index)
delete(cv.heapIndex, k)
}
cv.statsMutex.Lock()
cv.evictions++
cv.statsMutex.Unlock()
}
}
}
// ClearCache removes all entries from the cache
func (cv *CachedValidator) ClearCache() {
cv.cacheMutex.Lock()
cv.cache = make(map[string]CacheEntry)
// Drain the eviction heap and reset the index so they match the empty map.
for cv.heap.Len() > 0 {
heap.Pop(&cv.heap)
}
cv.heapIndex = make(map[string]*cacheEntry)
cv.cacheMutex.Unlock()
}
// GetCacheStats returns statistics about the cache
func (cv *CachedValidator) GetCacheStats() (hits, misses, evictions, size int) {
cv.statsMutex.Lock()
hits = cv.hits
misses = cv.misses
evictions = cv.evictions
cv.statsMutex.Unlock()
cv.cacheMutex.RLock()
size = len(cv.cache)
cv.cacheMutex.RUnlock()
return
}
// GetCacheInfo returns a formatted string with cache statistics
func (cv *CachedValidator) GetCacheInfo() string {
hits, misses, evictions, size := cv.GetCacheStats()
total := hits + misses
hitRate := 0.0
if total > 0 {
hitRate = float64(hits) / float64(total) * 100
}
return fmt.Sprintf(
"Cache Stats:\n"+
" Size: %d (max: %d)\n"+
" Hits: %d (%.1f%%)\n"+
" Misses: %d\n"+
" Evictions: %d\n"+
" TTL: %s",
size, cv.maxCacheSize, hits, hitRate, misses, evictions, cv.cacheTTL,
)
}
package token
import (
"errors"
"time"
)
var (
// ErrInvalidDuration is returned when an expiration duration is invalid
ErrInvalidDuration = errors.New("invalid duration")
// ErrExpirationInPast is returned when an expiration time is in the past
ErrExpirationInPast = errors.New("expiration time is in the past")
)
// Common expiration durations
const (
OneHour = time.Hour
OneDay = 24 * time.Hour
OneWeek = 7 * 24 * time.Hour
ThirtyDays = 30 * 24 * time.Hour
NinetyDays = 90 * 24 * time.Hour
NoExpiration = time.Duration(0)
MaxDuration = time.Duration(1<<63 - 1)
)
// CalculateExpiration returns an expiration time based on the current time and the provided duration.
// If duration is 0 or negative, it returns nil (no expiration).
func CalculateExpiration(duration time.Duration) *time.Time {
if duration <= 0 {
return nil // No expiration
}
expiry := time.Now().Add(duration)
return &expiry
}
// CalculateExpirationFrom returns an expiration time based on the provided start time and duration.
// If duration is 0 or negative, it returns nil (no expiration).
func CalculateExpirationFrom(startTime time.Time, duration time.Duration) *time.Time {
if duration <= 0 {
return nil // No expiration
}
expiry := startTime.Add(duration)
return &expiry
}
// ValidateExpiration checks if the given expiration time is valid (nil or in the future).
func ValidateExpiration(expiresAt *time.Time) error {
if expiresAt == nil {
return nil // No expiration is valid
}
if expiresAt.Before(time.Now()) {
return ErrExpirationInPast
}
return nil
}
// IsExpired checks if a token with the given expiration time is expired.
// If expiresAt is nil, the token never expires.
func IsExpired(expiresAt *time.Time) bool {
if expiresAt == nil {
return false // No expiration
}
return time.Now().After(*expiresAt)
}
// TimeUntilExpiration returns the duration until the token expires.
// If expiresAt is nil, it returns the max possible duration (effectively "never").
func TimeUntilExpiration(expiresAt *time.Time) time.Duration {
if expiresAt == nil {
return MaxDuration // Max duration, practically "never"
}
until := time.Until(*expiresAt)
if until < 0 {
return 0 // Already expired
}
return until
}
// ExpiresWithin checks if the token will expire within the given duration.
// If expiresAt is nil, it returns false (never expires).
func ExpiresWithin(expiresAt *time.Time, duration time.Duration) bool {
if expiresAt == nil {
return false // No expiration
}
expiresIn := time.Until(*expiresAt)
return expiresIn >= 0 && expiresIn <= duration
}
// FormatExpirationTime returns a human-readable string for the expiration time.
// If expiresAt is nil, it returns "Never expires".
func FormatExpirationTime(expiresAt *time.Time) string {
if expiresAt == nil {
return "Never expires"
}
if expiresAt.Before(time.Now()) {
return "Expired"
}
return expiresAt.Format(time.RFC3339)
}
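// Example usage (an illustrative sketch using only the helpers above):
//
//	expiresAt := CalculateExpiration(OneWeek) // one week from now
//	_ = ValidateExpiration(expiresAt)         // nil: expiry is in the future
//	IsExpired(expiresAt)                      // false
//	ExpiresWithin(expiresAt, OneDay)          // false: more than a day remains
//	TimeUntilExpiration(nil)                  // MaxDuration: no expiration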
package token
import (
"context"
"fmt"
"time"
"go.uber.org/zap"
)
// ManagerStore is the composite interface of all store operations required by
// Manager; it embeds TokenStore, RevocationStore, and RateLimitStore.
//
//go:generate mockgen -destination=mock_managerstore.go -package=token . ManagerStore
type ManagerStore interface {
TokenStore
RevocationStore
RateLimitStore
}
// Manager provides a unified interface for all token operations
type Manager struct {
validator TokenValidator
revoker *Revoker
limiter *StandardRateLimiter
generator *TokenGenerator
store ManagerStore // Underlying store (must implement all store interfaces)
useCaching bool
}
// NewManager creates a new token manager with the given store
func NewManager(store ManagerStore, useCaching bool) (*Manager, error) {
// No need to check interfaces here; the type system enforces them.
// Create components
baseValidator := NewValidator(store)
var validator TokenValidator = baseValidator
if useCaching {
validator = NewCachedValidator(baseValidator)
}
revoker := NewRevoker(store)
limiter := NewRateLimiter(store)
generator := NewTokenGenerator()
return &Manager{
validator: validator,
revoker: revoker,
limiter: limiter,
generator: generator,
store: store,
useCaching: useCaching,
}, nil
}
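// Example wiring (a minimal sketch; myStore is a hypothetical ManagerStore
// implementation, e.g. backed by the project's database layer):
//
//	manager, err := NewManager(myStore, true /* enable validation caching */)
//	if err != nil {
//	    return err
//	}
//	tok, err := manager.CreateToken(ctx, projectID, TokenOptions{
//	    Expiration:  OneWeek,
//	    MaxRequests: nil, // unlimited
//	})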
// CreateToken generates a new token with the specified options
func (m *Manager) CreateToken(ctx context.Context, projectID string, options TokenOptions) (TokenData, error) {
// Generate a new token
tokenStr, expiresAt, maxRequests, err := m.generator.GenerateWithOptions(options.Expiration, options.MaxRequests)
if err != nil {
return TokenData{}, fmt.Errorf("failed to generate token: %w", err)
}
// Create token data
now := time.Now()
token := TokenData{
Token: tokenStr,
ProjectID: projectID,
ExpiresAt: expiresAt,
IsActive: true,
RequestCount: 0,
MaxRequests: maxRequests,
CreatedAt: now,
}
// For tests only: if the store is a mock that exposes AddToken, register the
// token there. Note that this method does not persist the token to the store.
if mockStore, ok := any(m.store).(interface{ AddToken(string, TokenData) }); ok {
mockStore.AddToken(token.Token, token)
}
return token, nil
}
// ValidateToken validates a token without incrementing usage
func (m *Manager) ValidateToken(ctx context.Context, tokenID string) (string, error) {
return m.validator.ValidateToken(ctx, tokenID)
}
// ValidateTokenWithTracking validates a token and increments usage count
func (m *Manager) ValidateTokenWithTracking(ctx context.Context, tokenID string) (string, error) {
return m.validator.ValidateTokenWithTracking(ctx, tokenID)
}
// RevokeToken revokes a token
func (m *Manager) RevokeToken(ctx context.Context, tokenID string) error {
return m.revoker.RevokeToken(ctx, tokenID)
}
// DeleteToken completely removes a token
func (m *Manager) DeleteToken(ctx context.Context, tokenID string) error {
return m.revoker.DeleteToken(ctx, tokenID)
}
// RevokeExpiredTokens revokes all expired tokens
func (m *Manager) RevokeExpiredTokens(ctx context.Context) (int, error) {
return m.revoker.RevokeExpiredTokens(ctx)
}
// RevokeProjectTokens revokes all tokens for a project
func (m *Manager) RevokeProjectTokens(ctx context.Context, projectID string) (int, error) {
return m.revoker.RevokeProjectTokens(ctx, projectID)
}
// GetTokenInfo gets detailed information about a token
func (m *Manager) GetTokenInfo(ctx context.Context, tokenID string) (*TokenInfo, error) {
tokenData, err := m.store.GetTokenByID(ctx, tokenID)
if err != nil {
return nil, err
}
info := GetTokenInfo(tokenData)
return &info, nil
}
// GetTokenStats gets statistics about token usage
func (m *Manager) GetTokenStats(ctx context.Context, tokenID string) (*TokenStats, error) {
tokenData, err := m.store.GetTokenByID(ctx, tokenID)
if err != nil {
return nil, err
}
var remaining int
if tokenData.MaxRequests != nil {
remaining = *tokenData.MaxRequests - tokenData.RequestCount
if remaining < 0 {
remaining = 0
}
} else {
remaining = -1 // Unlimited
}
var timeRemaining time.Duration
if tokenData.ExpiresAt != nil {
timeRemaining = time.Until(*tokenData.ExpiresAt)
if timeRemaining < 0 {
timeRemaining = 0
}
} else {
timeRemaining = -1 // No expiration
}
stats := &TokenStats{
Token: tokenData.Token,
RequestCount: tokenData.RequestCount,
RemainingCount: remaining,
LastUsed: tokenData.LastUsedAt,
TimeRemaining: timeRemaining,
IsValid: tokenData.IsValid(),
ObfuscatedToken: ObfuscateToken(tokenData.Token),
}
return stats, nil
}
// UpdateToken updates an existing token in the store
func (m *Manager) UpdateToken(ctx context.Context, token TokenData) error {
// Validate token format first
if err := token.ValidateFormat(); err != nil {
return fmt.Errorf("invalid token format: %w", err)
}
return m.store.UpdateToken(ctx, token)
}
// UpdateTokenLimit updates the maximum allowed requests for a token
func (m *Manager) UpdateTokenLimit(ctx context.Context, tokenID string, maxRequests *int) error {
return m.limiter.UpdateLimit(ctx, tokenID, maxRequests)
}
// ResetTokenUsage resets the usage count for a token
func (m *Manager) ResetTokenUsage(ctx context.Context, tokenID string) error {
return m.limiter.ResetUsage(ctx, tokenID)
}
// IsTokenValid checks if a token is valid
func (m *Manager) IsTokenValid(ctx context.Context, tokenID string) bool {
_, err := m.validator.ValidateToken(ctx, tokenID)
return err == nil
}
// StartAutomaticRevocation starts automatic revocation of expired tokens
func (m *Manager) StartAutomaticRevocation(interval time.Duration, logger *zap.Logger) *AutomaticRevocation {
autoRevoke := NewAutomaticRevocation(m.revoker, interval, logger)
autoRevoke.Start()
return autoRevoke
}
// GetCacheInfo returns information about the token validation cache if caching is enabled
func (m *Manager) GetCacheInfo() (string, bool) {
if !m.useCaching {
return "Caching disabled", false
}
if cachedValidator, ok := m.validator.(*CachedValidator); ok {
return cachedValidator.GetCacheInfo(), true
}
return "Cache info not available", false
}
// WithGeneratorOptions configures the token generator with new options
func (m *Manager) WithGeneratorOptions(expiration time.Duration, maxRequests *int) *Manager {
generator := m.generator.WithExpiration(expiration)
if maxRequests != nil {
generator = generator.WithMaxRequests(*maxRequests)
}
m.generator = generator
return m
}
// TokenOptions contains options for token creation
type TokenOptions struct {
// Expiration duration (0 for no expiration)
Expiration time.Duration
// Maximum requests (nil for no limit)
MaxRequests *int
// Custom metadata (implementation-dependent)
Metadata map[string]string
}
// TokenStats contains statistics about token usage
type TokenStats struct {
Token string
ObfuscatedToken string
RequestCount int
RemainingCount int // -1 means unlimited
LastUsed *time.Time
TimeRemaining time.Duration // -1 means no expiration
IsValid bool
}
package token
import (
"context"
"errors"
"fmt"
"sync"
"time"
)
var (
// ErrRateLimitExceeded is returned when a token exceeds its rate limit
ErrRateLimitExceeded = errors.New("rate limit exceeded")
// ErrLimitOperation is returned when an operation on rate limits fails
ErrLimitOperation = errors.New("limit operation failed")
)
// RateLimiter defines the interface for rate limiting
type RateLimiter interface {
// AllowRequest checks if a token is within its rate limits and updates usage
AllowRequest(ctx context.Context, tokenID string) error
// GetRemainingRequests returns the number of remaining requests for a token
GetRemainingRequests(ctx context.Context, tokenID string) (int, error)
// ResetUsage resets the usage counter for a token
ResetUsage(ctx context.Context, tokenID string) error
// UpdateLimit updates the maximum allowed requests for a token
UpdateLimit(ctx context.Context, tokenID string, maxRequests *int) error
}
// RateLimitStore defines the interface for rate limit persistence
type RateLimitStore interface {
// GetTokenByID retrieves a token by its ID
GetTokenByID(ctx context.Context, tokenID string) (TokenData, error)
// IncrementTokenUsage increments the usage count for a token
IncrementTokenUsage(ctx context.Context, tokenID string) error
// ResetTokenUsage resets the usage count for a token to zero
ResetTokenUsage(ctx context.Context, tokenID string) error
// UpdateTokenLimit updates the maximum allowed requests for a token
UpdateTokenLimit(ctx context.Context, tokenID string, maxRequests *int) error
}
// StandardRateLimiter implements RateLimiter using a persistent store
type StandardRateLimiter struct {
store RateLimitStore
}
// NewRateLimiter creates a new StandardRateLimiter with the given store
func NewRateLimiter(store RateLimitStore) *StandardRateLimiter {
return &StandardRateLimiter{
store: store,
}
}
// AllowRequest checks if a token is within its rate limits and updates usage
func (r *StandardRateLimiter) AllowRequest(ctx context.Context, tokenID string) error {
// Validate token format first
if err := ValidateTokenFormat(tokenID); err != nil {
return fmt.Errorf("invalid token format: %w", err)
}
// Get current token data
token, err := r.store.GetTokenByID(ctx, tokenID)
if err != nil {
if errors.Is(err, ErrTokenNotFound) {
return ErrTokenNotFound
}
return fmt.Errorf("failed to retrieve token: %w", err)
}
// Check if token has a rate limit
if token.MaxRequests != nil {
// Check if token has exceeded its rate limit
if token.RequestCount >= *token.MaxRequests {
return ErrRateLimitExceeded
}
}
// Increment usage count
if err := r.store.IncrementTokenUsage(ctx, tokenID); err != nil {
return fmt.Errorf("failed to update token usage: %w", err)
}
return nil
}
// GetRemainingRequests returns the number of remaining requests for a token
func (r *StandardRateLimiter) GetRemainingRequests(ctx context.Context, tokenID string) (int, error) {
// Validate token format first
if err := ValidateTokenFormat(tokenID); err != nil {
return 0, fmt.Errorf("invalid token format: %w", err)
}
// Get current token data
token, err := r.store.GetTokenByID(ctx, tokenID)
if err != nil {
if errors.Is(err, ErrTokenNotFound) {
return 0, ErrTokenNotFound
}
return 0, fmt.Errorf("failed to retrieve token: %w", err)
}
// If the token has no limit, return a large sentinel value
if token.MaxRequests == nil {
return 1000000000, nil // Effectively unlimited
}
}
// Calculate remaining requests
remaining := *token.MaxRequests - token.RequestCount
if remaining < 0 {
remaining = 0
}
return remaining, nil
}
// ResetUsage resets the usage counter for a token
func (r *StandardRateLimiter) ResetUsage(ctx context.Context, tokenID string) error {
// Validate token format first
if err := ValidateTokenFormat(tokenID); err != nil {
return fmt.Errorf("invalid token format: %w", err)
}
// Reset token usage
if err := r.store.ResetTokenUsage(ctx, tokenID); err != nil {
if errors.Is(err, ErrTokenNotFound) {
return ErrTokenNotFound
}
return fmt.Errorf("failed to reset token usage: %w", err)
}
return nil
}
// UpdateLimit updates the maximum allowed requests for a token
func (r *StandardRateLimiter) UpdateLimit(ctx context.Context, tokenID string, maxRequests *int) error {
// Validate token format first
if err := ValidateTokenFormat(tokenID); err != nil {
return fmt.Errorf("invalid token format: %w", err)
}
// Update token limit
if err := r.store.UpdateTokenLimit(ctx, tokenID, maxRequests); err != nil {
if errors.Is(err, ErrTokenNotFound) {
return ErrTokenNotFound
}
return fmt.Errorf("failed to update token limit: %w", err)
}
return nil
}
// MemoryRateLimiter implements in-memory rate limiting using the token bucket algorithm
type MemoryRateLimiter struct {
// Mapping from token ID to rate limit data
limits map[string]*TokenBucket
limitsMutex sync.RWMutex
// Default rate (tokens per second)
defaultRate float64
// Default capacity
defaultCapacity int
}
// TokenBucket represents a token bucket rate limiter for a specific token
type TokenBucket struct {
// Current number of tokens in the bucket
tokens float64
// Maximum number of tokens the bucket can hold
capacity int
// Rate at which tokens are added to the bucket (tokens per second)
rate float64
// Last time the bucket was refilled
lastRefill time.Time
// Mutex to protect the bucket from concurrent access
mutex sync.Mutex
}
// NewMemoryRateLimiter creates a new in-memory rate limiter using the token bucket algorithm
func NewMemoryRateLimiter(defaultRate float64, defaultCapacity int) *MemoryRateLimiter {
return &MemoryRateLimiter{
limits: make(map[string]*TokenBucket),
defaultRate: defaultRate,
defaultCapacity: defaultCapacity,
}
}
// Allow checks if a request is allowed for a token
func (m *MemoryRateLimiter) Allow(tokenID string) bool {
m.limitsMutex.RLock()
bucket, exists := m.limits[tokenID]
m.limitsMutex.RUnlock()
if !exists {
m.limitsMutex.Lock()
// Re-check under the write lock in case another goroutine created the bucket.
if bucket, exists = m.limits[tokenID]; !exists {
// Create a new bucket, consuming one token for this first request.
m.limits[tokenID] = &TokenBucket{
tokens: float64(m.defaultCapacity) - 1,
capacity: m.defaultCapacity,
rate: m.defaultRate,
lastRefill: time.Now(),
}
m.limitsMutex.Unlock()
return true
}
m.limitsMutex.Unlock()
}
// Try to consume a token from the bucket
bucket.mutex.Lock()
defer bucket.mutex.Unlock()
// Refill the bucket based on elapsed time
now := time.Now()
elapsed := now.Sub(bucket.lastRefill).Seconds()
bucket.lastRefill = now
bucket.tokens += elapsed * bucket.rate
if bucket.tokens > float64(bucket.capacity) {
bucket.tokens = float64(bucket.capacity)
}
// Check if we have a token to consume
if bucket.tokens >= 1.0 {
bucket.tokens -= 1.0
return true
}
return false
}
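// Example (illustrative): with rate 2.0 and capacity 5, a burst of five requests
// drains the bucket, after which roughly one request is admitted every 500ms.
//
//	limiter := NewMemoryRateLimiter(2.0, 5)
//	for i := 0; i < 6; i++ {
//	    fmt.Println(limiter.Allow("sk-example")) // first five true, sixth false
//	}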
// SetLimit sets the rate limit for a specific token
func (m *MemoryRateLimiter) SetLimit(tokenID string, rate float64, capacity int) {
m.limitsMutex.Lock()
defer m.limitsMutex.Unlock()
bucket, exists := m.limits[tokenID]
if !exists {
bucket = &TokenBucket{
tokens: float64(capacity),
capacity: capacity,
rate: rate,
lastRefill: time.Now(),
}
m.limits[tokenID] = bucket
return
}
bucket.mutex.Lock()
defer bucket.mutex.Unlock()
// Refill first to avoid losing tokens
now := time.Now()
elapsed := now.Sub(bucket.lastRefill).Seconds()
bucket.lastRefill = now
bucket.tokens += elapsed * bucket.rate
// Update rate and capacity, clamping tokens to the new capacity
bucket.rate = rate
if bucket.tokens > float64(capacity) {
bucket.tokens = float64(capacity)
}
bucket.capacity = capacity
}
// GetLimit gets the current rate limit and capacity for a token
func (m *MemoryRateLimiter) GetLimit(tokenID string) (float64, int, bool) {
m.limitsMutex.RLock()
defer m.limitsMutex.RUnlock()
bucket, exists := m.limits[tokenID]
if !exists {
return m.defaultRate, m.defaultCapacity, false
}
bucket.mutex.Lock()
defer bucket.mutex.Unlock()
return bucket.rate, bucket.capacity, true
}
// Reset resets the rate limit for a token
func (m *MemoryRateLimiter) Reset(tokenID string) {
m.limitsMutex.Lock()
defer m.limitsMutex.Unlock()
bucket, exists := m.limits[tokenID]
if !exists {
return
}
bucket.mutex.Lock()
defer bucket.mutex.Unlock()
bucket.tokens = float64(bucket.capacity)
bucket.lastRefill = time.Now()
}
// Remove removes rate limit data for a token
func (m *MemoryRateLimiter) Remove(tokenID string) {
m.limitsMutex.Lock()
defer m.limitsMutex.Unlock()
delete(m.limits, tokenID)
}
package token
import (
"context"
"time"
"github.com/redis/go-redis/v9"
)
// RedisGoRateLimitAdapter adapts go-redis/v9 Client to the RedisRateLimitClient interface.
type RedisGoRateLimitAdapter struct {
Client *redis.Client
}
// NewRedisGoRateLimitAdapter creates a new adapter for rate limiting operations
func NewRedisGoRateLimitAdapter(client *redis.Client) *RedisGoRateLimitAdapter {
return &RedisGoRateLimitAdapter{Client: client}
}
// Incr atomically increments a key and returns the new value
func (a *RedisGoRateLimitAdapter) Incr(ctx context.Context, key string) (int64, error) {
return a.Client.Incr(ctx, key).Result()
}
// Get retrieves the value of a key
func (a *RedisGoRateLimitAdapter) Get(ctx context.Context, key string) (string, error) {
result, err := a.Client.Get(ctx, key).Result()
if err == redis.Nil {
return "", nil
}
return result, err
}
// Set sets the value of a key
func (a *RedisGoRateLimitAdapter) Set(ctx context.Context, key, value string) error {
return a.Client.Set(ctx, key, value, 0).Err()
}
// Expire sets a TTL on a key
func (a *RedisGoRateLimitAdapter) Expire(ctx context.Context, key string, expiration time.Duration) error {
return a.Client.Expire(ctx, key, expiration).Err()
}
// SetNX sets a key only if it doesn't exist (for distributed locking)
func (a *RedisGoRateLimitAdapter) SetNX(ctx context.Context, key string, value string, expiration time.Duration) (bool, error) {
return a.Client.SetNX(ctx, key, value, expiration).Result()
}
// Del deletes a key
func (a *RedisGoRateLimitAdapter) Del(ctx context.Context, key string) error {
return a.Client.Del(ctx, key).Err()
}
package token
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"strconv"
"sync"
"time"
)
// ErrRedisUnavailable is returned when Redis is unavailable and fallback is disabled
var ErrRedisUnavailable = errors.New("redis unavailable for rate limiting")
// RedisRateLimitClient defines the Redis operations needed for distributed rate limiting.
// This is a subset of the eventbus.RedisClient interface focused on rate limiting operations.
type RedisRateLimitClient interface {
// Incr atomically increments a key and returns the new value
Incr(ctx context.Context, key string) (int64, error)
// Get retrieves the value of a key
Get(ctx context.Context, key string) (string, error)
// Set sets the value of a key
Set(ctx context.Context, key, value string) error
// Expire sets a TTL on a key
Expire(ctx context.Context, key string, expiration time.Duration) error
// SetNX sets a key only if it doesn't exist.
// Included for interface compatibility with eventbus.RedisClient and potential future enhancements.
SetNX(ctx context.Context, key string, value string, expiration time.Duration) (bool, error)
// Del deletes a key
Del(ctx context.Context, key string) error
}
// RedisRateLimiterConfig contains configuration for the Redis rate limiter
type RedisRateLimiterConfig struct {
// KeyPrefix is the prefix for all Redis keys used by the rate limiter
KeyPrefix string
// KeyHashSecret is the HMAC secret for hashing token IDs in Redis keys.
// When set, token IDs are hashed using HMAC-SHA256 to prevent cleartext exposure.
// This is recommended for production deployments to enhance security.
KeyHashSecret []byte
// DefaultWindowDuration is the default sliding window duration for rate limiting
DefaultWindowDuration time.Duration
// DefaultMaxRequests is the default maximum requests per window
DefaultMaxRequests int
// EnableFallback enables fallback to in-memory rate limiting when Redis is unavailable
EnableFallback bool
// FallbackRate is the rate for fallback in-memory token bucket (tokens per second)
FallbackRate float64
// FallbackCapacity is the capacity for fallback in-memory token bucket
FallbackCapacity int
}
// DefaultRedisRateLimiterConfig returns default configuration
func DefaultRedisRateLimiterConfig() RedisRateLimiterConfig {
return RedisRateLimiterConfig{
KeyPrefix: "ratelimit:",
DefaultWindowDuration: time.Minute,
DefaultMaxRequests: 60,
EnableFallback: true,
FallbackRate: 1.0, // 1 token per second
FallbackCapacity: 10,
}
}
// RedisRateLimiter implements distributed rate limiting using Redis.
// It uses a sliding window counter algorithm with Redis INCR for atomic operations.
type RedisRateLimiter struct {
client RedisRateLimitClient
config RedisRateLimiterConfig
fallback *MemoryRateLimiter
// Track Redis availability
redisAvailable bool
redisAvailableMu sync.RWMutex
// Per-token limits (optional override of defaults)
tokenLimits map[string]*TokenRateLimit
tokenLimitsMu sync.RWMutex
}
// TokenRateLimit holds rate limit configuration for a specific token
type TokenRateLimit struct {
MaxRequests int
WindowDuration time.Duration
}
// NewRedisRateLimiter creates a new distributed rate limiter using Redis
func NewRedisRateLimiter(client RedisRateLimitClient, config RedisRateLimiterConfig) *RedisRateLimiter {
limiter := &RedisRateLimiter{
client: client,
config: config,
redisAvailable: true,
tokenLimits: make(map[string]*TokenRateLimit),
}
// Create fallback in-memory limiter if enabled
if config.EnableFallback {
limiter.fallback = NewMemoryRateLimiter(config.FallbackRate, config.FallbackCapacity)
}
return limiter
}
// buildKey constructs the Redis key for a token's rate limit counter.
// If KeyHashSecret is configured, the token ID is hashed using HMAC-SHA256
// to prevent cleartext exposure in Redis keys.
func (r *RedisRateLimiter) buildKey(tokenID string, windowStart int64) string {
keyID := tokenID
if len(r.config.KeyHashSecret) > 0 {
keyID = hashTokenID(tokenID, r.config.KeyHashSecret)
}
return fmt.Sprintf("%s%s:%d", r.config.KeyPrefix, keyID, windowStart)
}
// hashTokenID generates a non-reversible identifier from a token ID using HMAC-SHA256.
// Returns the first 16 hex characters of the HMAC for brevity while maintaining uniqueness.
func hashTokenID(tokenID string, secret []byte) string {
h := hmac.New(sha256.New, secret)
h.Write([]byte(tokenID))
return hex.EncodeToString(h.Sum(nil))[:16]
}
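// Example key shapes (illustrative): with KeyPrefix "ratelimit:" and a window
// starting at Unix time 1700000000, a token's counter key looks like
//
//	ratelimit:sk-abc...xyz:1700000000        // no KeyHashSecret configured
//	ratelimit:3f6c2a9e41d0b7c5:1700000000    // with KeyHashSecret (16 hex chars)
//
// The hex digest shown is a made-up placeholder, not a real HMAC value.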
// getWindowStart returns the start timestamp for the current window
func (r *RedisRateLimiter) getWindowStart(windowDuration time.Duration) int64 {
now := time.Now()
return now.Truncate(windowDuration).Unix()
}
// getTokenLimit returns the rate limit for a token, using defaults if not set
func (r *RedisRateLimiter) getTokenLimit(tokenID string) (int, time.Duration) {
r.tokenLimitsMu.RLock()
limit, exists := r.tokenLimits[tokenID]
r.tokenLimitsMu.RUnlock()
if exists && limit != nil {
return limit.MaxRequests, limit.WindowDuration
}
return r.config.DefaultMaxRequests, r.config.DefaultWindowDuration
}
// Allow checks if a request from the given token should be allowed.
// Returns true if the request is within rate limits, false otherwise.
func (r *RedisRateLimiter) Allow(ctx context.Context, tokenID string) (bool, error) {
maxRequests, windowDuration := r.getTokenLimit(tokenID)
windowStart := r.getWindowStart(windowDuration)
key := r.buildKey(tokenID, windowStart)
// Check Redis availability
r.redisAvailableMu.RLock()
available := r.redisAvailable
r.redisAvailableMu.RUnlock()
if !available {
return r.handleFallback(tokenID)
}
// Try to increment counter atomically in Redis
count, err := r.client.Incr(ctx, key)
if err != nil {
// Redis operation failed
r.markRedisUnavailable()
return r.handleFallback(tokenID)
}
// Mark Redis as available (successful operation)
r.markRedisAvailable()
// Set expiration on first increment (count == 1)
if count == 1 {
// Set TTL slightly longer than window to handle edge cases
ttl := windowDuration + time.Second
// Ignore expire errors for graceful degradation; orphaned keys will be cleaned up by Redis eventually.
_ = r.client.Expire(ctx, key, ttl)
}
// Check if request is within limit
return count <= int64(maxRequests), nil
}
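// Example usage (a sketch; assumes an existing *redis.Client named rdb):
//
//	limiter := NewRedisRateLimiter(NewRedisGoRateLimitAdapter(rdb), DefaultRedisRateLimiterConfig())
//	ok, err := limiter.Allow(ctx, tokenID)
//	if err != nil {
//	    // Redis unavailable and fallback disabled
//	}
//	if !ok {
//	    // over the per-window limit: reject, e.g. with HTTP 429
//	}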
// handleFallback handles rate limiting when Redis is unavailable
func (r *RedisRateLimiter) handleFallback(tokenID string) (bool, error) {
if !r.config.EnableFallback || r.fallback == nil {
return false, ErrRedisUnavailable
}
return r.fallback.Allow(tokenID), nil
}
// markRedisUnavailable marks Redis as unavailable
func (r *RedisRateLimiter) markRedisUnavailable() {
r.redisAvailableMu.Lock()
r.redisAvailable = false
r.redisAvailableMu.Unlock()
}
// markRedisAvailable marks Redis as available
func (r *RedisRateLimiter) markRedisAvailable() {
r.redisAvailableMu.Lock()
r.redisAvailable = true
r.redisAvailableMu.Unlock()
}
// GetRemainingRequests returns the number of remaining requests for a token in the current window
func (r *RedisRateLimiter) GetRemainingRequests(ctx context.Context, tokenID string) (int, error) {
maxRequests, windowDuration := r.getTokenLimit(tokenID)
windowStart := r.getWindowStart(windowDuration)
key := r.buildKey(tokenID, windowStart)
// Check Redis availability
r.redisAvailableMu.RLock()
available := r.redisAvailable
r.redisAvailableMu.RUnlock()
if !available {
if r.config.EnableFallback {
// Return a reasonable default when in fallback mode
return r.config.FallbackCapacity, nil
}
return 0, ErrRedisUnavailable
}
// Get current count from Redis
countStr, err := r.client.Get(ctx, key)
if err != nil {
// Redis error - mark unavailable and use fallback
r.markRedisUnavailable()
if r.config.EnableFallback {
return r.config.FallbackCapacity, nil
}
return 0, fmt.Errorf("failed to get rate limit counter: %w", err)
}
// Key doesn't exist (empty string returned, no error) - all requests remain
if countStr == "" {
return maxRequests, nil
}
// Parse the counter; treat an unparsable value defensively as zero
count, parseErr := strconv.Atoi(countStr)
if parseErr != nil {
count = 0
}
remaining := maxRequests - count
if remaining < 0 {
remaining = 0
}
return remaining, nil
}
// SetTokenLimit sets a custom rate limit for a specific token
func (r *RedisRateLimiter) SetTokenLimit(tokenID string, maxRequests int, windowDuration time.Duration) {
r.tokenLimitsMu.Lock()
r.tokenLimits[tokenID] = &TokenRateLimit{
MaxRequests: maxRequests,
WindowDuration: windowDuration,
}
r.tokenLimitsMu.Unlock()
}
// RemoveTokenLimit removes the custom rate limit for a token (falls back to defaults)
func (r *RedisRateLimiter) RemoveTokenLimit(tokenID string) {
r.tokenLimitsMu.Lock()
delete(r.tokenLimits, tokenID)
r.tokenLimitsMu.Unlock()
}
// ResetTokenUsage resets the rate limit counter for a token
func (r *RedisRateLimiter) ResetTokenUsage(ctx context.Context, tokenID string) error {
_, windowDuration := r.getTokenLimit(tokenID)
windowStart := r.getWindowStart(windowDuration)
key := r.buildKey(tokenID, windowStart)
// Check Redis availability
r.redisAvailableMu.RLock()
available := r.redisAvailable
r.redisAvailableMu.RUnlock()
if !available {
if r.config.EnableFallback && r.fallback != nil {
r.fallback.Reset(tokenID)
return nil
}
return ErrRedisUnavailable
}
if err := r.client.Del(ctx, key); err != nil {
r.markRedisUnavailable()
if r.config.EnableFallback && r.fallback != nil {
r.fallback.Reset(tokenID)
return nil
}
return fmt.Errorf("failed to reset rate limit counter: %w", err)
}
r.markRedisAvailable()
return nil
}
// IsRedisAvailable returns whether Redis is currently available
func (r *RedisRateLimiter) IsRedisAvailable() bool {
r.redisAvailableMu.RLock()
defer r.redisAvailableMu.RUnlock()
return r.redisAvailable
}
// CheckRedisHealth performs a health check on the Redis connection
func (r *RedisRateLimiter) CheckRedisHealth(ctx context.Context) error {
// Try a simple operation to verify Redis is working
testKey := r.config.KeyPrefix + "_healthcheck"
if err := r.client.Set(ctx, testKey, "1"); err != nil {
r.markRedisUnavailable()
return fmt.Errorf("redis health check failed: %w", err)
}
// Cleanup
_ = r.client.Del(ctx, testKey)
r.markRedisAvailable()
return nil
}
package token
import (
"context"
"errors"
"fmt"
"time"
"go.uber.org/zap"
)
var (
// ErrTokenAlreadyRevoked is returned when trying to revoke an already revoked token
ErrTokenAlreadyRevoked = errors.New("token is already revoked")
)
// RevocationStore defines the interface for token revocation
type RevocationStore interface {
// RevokeToken disables a token by setting is_active to false
RevokeToken(ctx context.Context, tokenID string) error
// DeleteToken completely removes a token from storage
DeleteToken(ctx context.Context, tokenID string) error
// RevokeBatchTokens revokes multiple tokens at once
RevokeBatchTokens(ctx context.Context, tokenIDs []string) (int, error)
// RevokeProjectTokens revokes all tokens for a project
RevokeProjectTokens(ctx context.Context, projectID string) (int, error)
// RevokeExpiredTokens revokes all tokens that have expired
RevokeExpiredTokens(ctx context.Context) (int, error)
}
// Revoker provides methods for token revocation
type Revoker struct {
store RevocationStore
}
// NewRevoker creates a new token revoker with the given store
func NewRevoker(store RevocationStore) *Revoker {
return &Revoker{
store: store,
}
}
// RevokeToken soft revokes a token by setting is_active to false
func (r *Revoker) RevokeToken(ctx context.Context, tokenID string) error {
// Validate token format first
if err := ValidateTokenFormat(tokenID); err != nil {
return fmt.Errorf("invalid token format: %w", err)
}
// Attempt to revoke the token
err := r.store.RevokeToken(ctx, tokenID)
if err != nil {
if errors.Is(err, ErrTokenNotFound) {
return fmt.Errorf("cannot revoke: %w", err)
}
if errors.Is(err, ErrTokenAlreadyRevoked) {
return err
}
return fmt.Errorf("failed to revoke token: %w", err)
}
return nil
}
// DeleteToken completely removes a token from storage (hard delete)
func (r *Revoker) DeleteToken(ctx context.Context, tokenID string) error {
// Validate token format first
if err := ValidateTokenFormat(tokenID); err != nil {
return fmt.Errorf("invalid token format: %w", err)
}
// Attempt to delete the token
err := r.store.DeleteToken(ctx, tokenID)
if err != nil {
if errors.Is(err, ErrTokenNotFound) {
return fmt.Errorf("cannot delete: %w", err)
}
return fmt.Errorf("failed to delete token: %w", err)
}
return nil
}
// RevokeBatchTokens revokes multiple tokens in a single operation
func (r *Revoker) RevokeBatchTokens(ctx context.Context, tokenIDs []string) (int, error) {
if len(tokenIDs) == 0 {
return 0, nil
}
// Validate all token formats first
for _, tokenID := range tokenIDs {
if err := ValidateTokenFormat(tokenID); err != nil {
return 0, fmt.Errorf("invalid token format for %s: %w", tokenID, err)
}
}
// Revoke the tokens
count, err := r.store.RevokeBatchTokens(ctx, tokenIDs)
if err != nil {
return 0, fmt.Errorf("failed to revoke tokens in batch: %w", err)
}
return count, nil
}
// RevokeProjectTokens revokes all tokens for a project
func (r *Revoker) RevokeProjectTokens(ctx context.Context, projectID string) (int, error) {
if projectID == "" {
return 0, errors.New("project ID cannot be empty")
}
count, err := r.store.RevokeProjectTokens(ctx, projectID)
if err != nil {
return 0, fmt.Errorf("failed to revoke project tokens: %w", err)
}
return count, nil
}
// RevokeExpiredTokens revokes all tokens that have expired
func (r *Revoker) RevokeExpiredTokens(ctx context.Context) (int, error) {
count, err := r.store.RevokeExpiredTokens(ctx)
if err != nil {
return 0, fmt.Errorf("failed to revoke expired tokens: %w", err)
}
return count, nil
}
// AutomaticRevocation sets up periodic revocation of expired tokens
type AutomaticRevocation struct {
revoker *Revoker
interval time.Duration
stopChan chan struct{}
stoppedChan chan struct{}
logger *zap.Logger
}
// NewAutomaticRevocation creates a new automatic token revocation
func NewAutomaticRevocation(revoker *Revoker, interval time.Duration, logger *zap.Logger) *AutomaticRevocation {
if logger == nil {
logger = zap.NewNop() // guard against nil to avoid panics in the background goroutine
}
return &AutomaticRevocation{
revoker: revoker,
interval: interval,
stopChan: make(chan struct{}),
stoppedChan: make(chan struct{}),
logger: logger,
}
}
// Start begins the automatic revocation of expired tokens
func (a *AutomaticRevocation) Start() {
go func() {
ticker := time.NewTicker(a.interval)
defer ticker.Stop()
defer close(a.stoppedChan)
for {
select {
case <-ticker.C:
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
count, err := a.revoker.RevokeExpiredTokens(ctx)
if err != nil {
a.logger.Error("Failed to automatically revoke expired tokens", zap.Error(err))
} else if count > 0 {
a.logger.Info("Automatically revoked expired tokens", zap.Int("count", count))
}
cancel()
case <-a.stopChan:
return
}
}
}()
}
// Stop halts the automatic revocation and waits for the background goroutine to exit.
// It must be called at most once; a second call panics on the already-closed channel.
func (a *AutomaticRevocation) Stop() {
close(a.stopChan)
<-a.stoppedChan
}
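// Example lifecycle (illustrative sketch; manager and logger assumed to exist):
//
//	autoRevoke := manager.StartAutomaticRevocation(time.Hour, logger)
//	defer autoRevoke.Stop() // blocks until the background goroutine exits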
package token
import (
"encoding/base64"
"errors"
"fmt"
"regexp"
"strings"
"github.com/google/uuid"
)
const (
// TokenPrefix is the prefix for all tokens
TokenPrefix = "sk-"
// TokenRegexPattern is the regular expression pattern for validating token format
TokenRegexPattern = `^sk-[A-Za-z0-9_-]{22}$`
)
var (
// TokenRegex is the compiled regular expression for token format validation
TokenRegex = regexp.MustCompile(TokenRegexPattern)
// ErrInvalidTokenFormat is returned when the token format is invalid
ErrInvalidTokenFormat = errors.New("invalid token format")
// ErrTokenDecodingFailed is returned when the token cannot be decoded
ErrTokenDecodingFailed = errors.New("token decoding failed")
)
// GenerateToken generates a new token consisting of TokenPrefix and a base64url-encoded UUIDv7.
// The UUIDv7 embeds the current timestamp, making tokens time-ordered.
func GenerateToken() (string, error) {
// Generate a UUIDv7 which includes a timestamp component
id, err := uuid.NewV7()
if err != nil {
return "", fmt.Errorf("failed to generate UUID: %w", err)
}
// Convert UUID to Base64 URL-safe encoding
uuidBytes, err := id.MarshalBinary()
if err != nil {
return "", fmt.Errorf("failed to marshal UUID: %w", err)
}
// Use URL-safe base64 encoding without padding
encoded := base64.RawURLEncoding.EncodeToString(uuidBytes)
// Combine prefix with encoded UUID
token := TokenPrefix + encoded
return token, nil
}
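// Example round trip (illustrative): a generated token is the "sk-" prefix plus
// 22 base64url characters, and DecodeToken recovers the embedded UUIDv7.
//
//	tok, _ := GenerateToken() // e.g. "sk-AYkx..." (25 chars total)
//	id, _ := DecodeToken(tok) // the original uuid.UUID
//	_ = id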
// ValidateTokenFormat checks if the given token string follows the expected format.
// It does not check if the token exists or is valid in the database.
func ValidateTokenFormat(token string) error {
// Fast-path structural validation (avoid regex + decode double work).
// Tokens are: "sk-" + base64url(UUID bytes) where UUID is 16 bytes => 22 base64url chars.
if len(token) != len(TokenPrefix)+22 {
return ErrInvalidTokenFormat
}
if !strings.HasPrefix(token, TokenPrefix) {
return ErrInvalidTokenFormat
}
// Attempt to decode the token to ensure it was properly generated.
// This implicitly validates the base64url charset and length.
if _, err := DecodeToken(token); err != nil {
if errors.Is(err, ErrInvalidTokenFormat) {
return ErrInvalidTokenFormat
}
return fmt.Errorf("%w: %v", ErrTokenDecodingFailed, err)
}
return nil
}
// DecodeToken extracts the UUID from a token string.
func DecodeToken(token string) (uuid.UUID, error) {
// Check if the token has the correct prefix
if !strings.HasPrefix(token, TokenPrefix) {
return uuid.UUID{}, ErrInvalidTokenFormat
}
// Remove the prefix
encodedPart := strings.TrimPrefix(token, TokenPrefix)
// Decode from base64
uuidBytes, err := base64.RawURLEncoding.DecodeString(encodedPart)
if err != nil {
return uuid.UUID{}, fmt.Errorf("failed to decode token: %w", err)
}
// Parse the UUID
var id uuid.UUID
if err := id.UnmarshalBinary(uuidBytes); err != nil {
return uuid.UUID{}, fmt.Errorf("failed to unmarshal UUID: %w", err)
}
return id, nil
}
package token
import (
"context"
"sync"
"time"
"github.com/sofatutor/llm-proxy/internal/obfuscate"
"go.uber.org/zap"
)
// UsageStatsStore defines the interface for persisting per-token usage stats.
//
// Implementations are expected to treat tokenID as the token string (sk-...).
type UsageStatsStore interface {
// IncrementTokenUsageBatch increments request_count for multiple tokens and updates last_used_at.
// The deltas map has token strings as keys and increment values as values.
IncrementTokenUsageBatch(ctx context.Context, deltas map[string]int, lastUsedAt time.Time) error
}
// UsageStatsAggregatorConfig holds configuration for the usage stats aggregator.
type UsageStatsAggregatorConfig struct {
BufferSize int // Size of buffered channel (default: 1000)
FlushInterval time.Duration // How often to flush stats (default: 5s)
BatchSize int // Max events before flush (default: 100)
}
// DefaultUsageStatsAggregatorConfig returns default config.
func DefaultUsageStatsAggregatorConfig() UsageStatsAggregatorConfig {
return UsageStatsAggregatorConfig{
BufferSize: 1000,
FlushInterval: 5 * time.Second,
BatchSize: 100,
}
}
// UsageStatsAggregator aggregates token usage events and periodically flushes them.
// It uses a buffered channel for non-blocking enqueue and drops events when the buffer is full.
type UsageStatsAggregator struct {
config UsageStatsAggregatorConfig
store UsageStatsStore
logger *zap.Logger
eventsCh chan string // token strings
stopCh chan struct{}
doneCh chan struct{}
mu sync.RWMutex
stopped bool
}
// NewUsageStatsAggregator creates a new usage stats aggregator.
func NewUsageStatsAggregator(config UsageStatsAggregatorConfig, store UsageStatsStore, logger *zap.Logger) *UsageStatsAggregator {
cfg := config
if cfg.BufferSize <= 0 {
cfg.BufferSize = 1000
}
if cfg.FlushInterval <= 0 {
cfg.FlushInterval = 5 * time.Second
}
if cfg.BatchSize <= 0 {
cfg.BatchSize = 100
}
if logger == nil {
logger = zap.NewNop()
}
return &UsageStatsAggregator{
config: cfg,
store: store,
logger: logger,
eventsCh: make(chan string, cfg.BufferSize),
stopCh: make(chan struct{}),
doneCh: make(chan struct{}),
}
}
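// Example lifecycle (a sketch; store is a hypothetical UsageStatsStore and
// logger an existing *zap.Logger):
//
//	agg := NewUsageStatsAggregator(DefaultUsageStatsAggregatorConfig(), store, logger)
//	agg.Start()
//	agg.RecordTokenUsage("sk-...") // non-blocking; dropped if the buffer is full
//	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
//	defer cancel()
//	_ = agg.Stop(ctx) // drains the queue and performs a final flush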
// Start begins the background aggregation worker.
func (a *UsageStatsAggregator) Start() {
go a.run()
}
// Stop gracefully shuts down the aggregator, flushing any pending stats.
func (a *UsageStatsAggregator) Stop(ctx context.Context) error {
a.mu.Lock()
if a.stopped {
a.mu.Unlock()
return nil
}
a.stopped = true
a.mu.Unlock()
close(a.stopCh)
select {
case <-a.doneCh:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// RecordTokenUsage enqueues a token usage event for the given token string.
// This is non-blocking; if the buffer is full, the event is dropped.
func (a *UsageStatsAggregator) RecordTokenUsage(tokenString string) {
if tokenString == "" {
return
}
a.mu.RLock()
stopped := a.stopped
a.mu.RUnlock()
if stopped {
return
}
select {
case <-a.stopCh:
return
case a.eventsCh <- tokenString:
// enqueued
default:
a.logger.Debug("usage stats buffer full, dropping event",
zap.String("token", obfuscate.ObfuscateTokenGeneric(tokenString)))
}
}
func (a *UsageStatsAggregator) run() {
defer close(a.doneCh)
ticker := time.NewTicker(a.config.FlushInterval)
defer ticker.Stop()
deltas := make(map[string]int)
eventCount := 0
flush := func() {
if eventCount == 0 {
return
}
snapshot := make(map[string]int, len(deltas))
for tokenID, delta := range deltas {
snapshot[tokenID] = delta
}
deltas = make(map[string]int)
eventCount = 0
now := time.Now().UTC()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
err := a.store.IncrementTokenUsageBatch(ctx, snapshot, now)
cancel()
if err != nil {
a.logger.Error("failed to flush usage stats batch", zap.Error(err))
}
}
for {
select {
case <-a.stopCh:
// Drain any queued events before the final flush so we don't lose
// events that were successfully enqueued but not yet processed.
for {
select {
case tokenID := <-a.eventsCh:
deltas[tokenID]++
eventCount++
default:
flush()
return
}
}
case <-ticker.C:
flush()
case tokenID := <-a.eventsCh:
deltas[tokenID]++
eventCount++
if eventCount >= a.config.BatchSize {
flush()
}
}
}
}
package token
import (
"crypto/rand"
"fmt"
"math/big"
"net/http"
"strings"
"time"
"github.com/sofatutor/llm-proxy/internal/obfuscate"
)
// Constants for token options and generation
const (
// MinTokenLength is the minimum acceptable token length
MinTokenLength = 20
// DefaultTokenExpiration is the default expiration for tokens (30 days)
DefaultTokenExpiration = 30 * 24 * time.Hour
// DefaultMaxRequests is the default maximum requests per token (unlimited)
DefaultMaxRequests = 0 // 0 means unlimited
)
// TokenGenerator is a utility for generating tokens with specific options
type TokenGenerator struct {
// Default expiration time for new tokens
DefaultExpiration time.Duration
// Default maximum requests for new tokens (nil means unlimited)
DefaultMaxRequests *int
}
// NewTokenGenerator creates a new TokenGenerator with default options
func NewTokenGenerator() *TokenGenerator {
return &TokenGenerator{
DefaultExpiration: DefaultTokenExpiration,
DefaultMaxRequests: nil, // Unlimited by default
}
}
// WithExpiration sets the default expiration for new tokens
func (g *TokenGenerator) WithExpiration(expiration time.Duration) *TokenGenerator {
g.DefaultExpiration = expiration
return g
}
// WithMaxRequests sets the default maximum requests for new tokens
func (g *TokenGenerator) WithMaxRequests(maxRequests int) *TokenGenerator {
g.DefaultMaxRequests = &maxRequests
return g
}
// Generate generates a new token with default options
func (g *TokenGenerator) Generate() (string, error) {
return GenerateToken()
}
// GenerateWithOptions generates a new token with specific options
func (g *TokenGenerator) GenerateWithOptions(expiration time.Duration, maxRequests *int) (string, *time.Time, *int, error) {
// Generate token
token, err := GenerateToken()
if err != nil {
return "", nil, nil, err
}
// Calculate expiration
var expiresAt *time.Time
if expiration > 0 {
exp := CalculateExpiration(expiration)
expiresAt = exp
} else if g.DefaultExpiration > 0 {
exp := CalculateExpiration(g.DefaultExpiration)
expiresAt = exp
}
// Determine max requests
var maxReq *int
if maxRequests != nil {
maxReq = maxRequests
} else {
maxReq = g.DefaultMaxRequests
}
return token, expiresAt, maxReq, nil
}
// ExtractTokenFromHeader extracts a token from an HTTP Authorization header
func ExtractTokenFromHeader(header string) (string, bool) {
if header == "" {
return "", false
}
// Check for "Bearer" auth scheme
parts := strings.Split(header, " ")
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
return "", false
}
token := parts[1]
if token == "" {
return "", false
}
// Validate the token format
if err := ValidateTokenFormat(token); err != nil {
return "", false
}
return token, true
}
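// Example (illustrative):
//
//	ExtractTokenFromHeader("Bearer sk-AYkx...")  // ("sk-AYkx...", true) if format-valid
//	ExtractTokenFromHeader("Basic dXNlcjpwdw==") // ("", false): wrong scheme
//	ExtractTokenFromHeader("")                   // ("", false)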
// ExtractTokenFromRequest extracts a token from an HTTP request
func ExtractTokenFromRequest(r *http.Request) (string, bool) {
// Try Authorization header first
token, ok := ExtractTokenFromHeader(r.Header.Get("Authorization"))
if ok {
return token, true
}
// Try X-API-Key header next
apiKey := r.Header.Get("X-API-Key")
if apiKey != "" {
if err := ValidateTokenFormat(apiKey); err == nil {
return apiKey, true
}
}
// Try query parameter last
queryToken := r.URL.Query().Get("token")
if queryToken != "" {
if err := ValidateTokenFormat(queryToken); err == nil {
return queryToken, true
}
}
return "", false
}
// GenerateRandomKey generates a random string suitable for use as an API key
func GenerateRandomKey(length int) (string, error) {
if length < MinTokenLength {
length = MinTokenLength
}
// Character set for random key
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
charsetLen := big.NewInt(int64(len(charset)))
// Build the random string
b := make([]byte, length)
for i := 0; i < length; i++ {
n, err := rand.Int(rand.Reader, charsetLen)
if err != nil {
return "", err
}
b[i] = charset[n.Int64()]
}
return string(b), nil
}
// TruncateToken truncates a token string for display, preserving the prefix and suffix
func TruncateToken(token string, showChars int) string {
if showChars <= 0 || token == "" || len(token) <= showChars*2 {
return token
}
prefix := token[:showChars]
suffix := token[len(token)-showChars:]
return prefix + "..." + suffix
}
// ObfuscateToken partially obfuscates a token for display purposes
func ObfuscateToken(token string) string { return obfuscate.ObfuscateTokenByPrefix(token, TokenPrefix) }
// TokenInfo represents information about a token for display purposes
type TokenInfo struct {
Token string `json:"token"`
ObfuscatedToken string `json:"obfuscated_token"`
CreationTime time.Time `json:"creation_time"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
IsActive bool `json:"is_active"`
RequestCount int `json:"request_count"`
MaxRequests *int `json:"max_requests,omitempty"`
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
TimeRemaining string `json:"time_remaining,omitempty"`
IsValid bool `json:"is_valid"`
}
// GetTokenInfo creates a TokenInfo struct with token details
func GetTokenInfo(token TokenData) TokenInfo {
// Use the canonical creation time from the database
// Note: UUIDv7 embeds a timestamp, but the library does not expose it; CreatedAt is the source of truth
creationTime := token.CreatedAt
info := TokenInfo{
Token: token.Token,
ObfuscatedToken: ObfuscateToken(token.Token),
CreationTime: creationTime,
ExpiresAt: token.ExpiresAt,
IsActive: token.IsActive,
RequestCount: token.RequestCount,
MaxRequests: token.MaxRequests,
LastUsedAt: token.LastUsedAt,
IsValid: token.IsActive && !IsExpired(token.ExpiresAt) && !token.IsRateLimited(),
}
// Calculate time remaining
if token.ExpiresAt != nil && !IsExpired(token.ExpiresAt) {
duration := time.Until(*token.ExpiresAt)
info.TimeRemaining = formatDuration(duration)
}
return info
}
// FormatTokenInfo formats token information as a human-readable string
func FormatTokenInfo(token TokenData) string {
info := GetTokenInfo(token)
var sb strings.Builder
sb.WriteString("Token: " + info.ObfuscatedToken + "\n")
sb.WriteString("Created: " + info.CreationTime.Format(time.RFC3339) + "\n")
if info.ExpiresAt != nil {
sb.WriteString("Expires: " + info.ExpiresAt.Format(time.RFC3339))
if info.TimeRemaining != "" {
sb.WriteString(" (" + info.TimeRemaining + " remaining)")
}
sb.WriteString("\n")
} else {
sb.WriteString("Expires: Never\n")
}
sb.WriteString("Active: " + fmt.Sprintf("%t", info.IsActive) + "\n")
if info.MaxRequests != nil {
sb.WriteString(fmt.Sprintf("Requests: %d / %d\n", info.RequestCount, *info.MaxRequests))
} else {
sb.WriteString(fmt.Sprintf("Requests: %d / ∞\n", info.RequestCount))
}
if info.LastUsedAt != nil {
sb.WriteString("Last Used: " + info.LastUsedAt.Format(time.RFC3339) + "\n")
} else {
sb.WriteString("Last Used: Never\n")
}
sb.WriteString("Valid: " + fmt.Sprintf("%t", info.IsValid) + "\n")
return sb.String()
}
// formatDuration formats a duration in a human-readable way
func formatDuration(d time.Duration) string {
if d < time.Minute {
return fmt.Sprintf("%d seconds", int(d.Seconds()))
} else if d < time.Hour {
return fmt.Sprintf("%d minutes", int(d.Minutes()))
} else if d < 24*time.Hour {
return fmt.Sprintf("%d hours", int(d.Hours()))
} else if d < 30*24*time.Hour {
return fmt.Sprintf("%d days", int(d.Hours()/24))
} else if d < 365*24*time.Hour {
return fmt.Sprintf("%d months", int(d.Hours()/(24*30)))
}
return fmt.Sprintf("%d years", int(d.Hours()/(24*365)))
}
package token
import (
"context"
"errors"
"fmt"
"sync/atomic"
"time"
)
// Errors related to token validation
var (
ErrTokenNotFound = errors.New("token not found")
ErrTokenInactive = errors.New("token is inactive")
ErrTokenExpired = errors.New("token has expired")
ErrTokenRateLimit = errors.New("token has reached rate limit")
)
// TokenValidator defines the interface for token validation
type TokenValidator interface {
// ValidateToken validates a token and returns the associated project ID
ValidateToken(ctx context.Context, token string) (string, error)
// ValidateTokenWithTracking validates a token, returns the project ID, and tracks usage
ValidateTokenWithTracking(ctx context.Context, token string) (string, error)
}
// TokenStore defines the interface for token storage and retrieval
type TokenStore interface {
// GetTokenByID retrieves a token by its UUID (for management operations)
GetTokenByID(ctx context.Context, id string) (TokenData, error)
// GetTokenByToken retrieves a token by its token string (for authentication)
GetTokenByToken(ctx context.Context, tokenString string) (TokenData, error)
// IncrementTokenUsage increments the usage count for a token by its token string
IncrementTokenUsage(ctx context.Context, tokenString string) error
// CreateToken creates a new token in the store
CreateToken(ctx context.Context, token TokenData) error
// UpdateToken updates an existing token
UpdateToken(ctx context.Context, token TokenData) error
// ListTokens retrieves all tokens from the store
ListTokens(ctx context.Context) ([]TokenData, error)
// GetTokensByProjectID retrieves all tokens for a project
GetTokensByProjectID(ctx context.Context, projectID string) ([]TokenData, error)
}
// TokenData represents the data associated with a token
type TokenData struct {
ID string // The token ID (UUID) - used for management operations
Token string // The token string (sk-...) - used for authentication
ProjectID string // The associated project ID
ExpiresAt *time.Time // When the token expires (nil for no expiration)
IsActive bool // Whether the token is active
DeactivatedAt *time.Time // When the token was deactivated (nil if not deactivated)
RequestCount int // Number of requests made with this token
MaxRequests *int // Maximum number of requests allowed (nil for unlimited)
CreatedAt time.Time // When the token was created
LastUsedAt *time.Time // When the token was last used (nil if never used)
CacheHitCount int // Number of cache hits for this token
}
// IsValid returns true if the token is active, not expired, and not rate limited
func (t *TokenData) IsValid() bool {
return t.IsActive && !IsExpired(t.ExpiresAt) && !t.IsRateLimited()
}
// IsRateLimited returns true if the token has reached its maximum number of requests
func (t *TokenData) IsRateLimited() bool {
if t.MaxRequests == nil {
return false
}
if *t.MaxRequests <= 0 {
return false
}
return t.RequestCount >= *t.MaxRequests
}
// ValidateFormat checks whether the token string has the correct format
func (t *TokenData) ValidateFormat() error {
return ValidateTokenFormat(t.Token)
}
// StandardValidator is a validator that uses a TokenStore for validation
type StandardValidator struct {
store TokenStore
usageStatsAgg atomic.Pointer[UsageStatsAggregator]
}
type usageStatsAggregatorGetter interface {
usageStatsAggregator() *UsageStatsAggregator
}
// NewValidator creates a new StandardValidator with the given TokenStore
func NewValidator(store TokenStore) *StandardValidator {
return &StandardValidator{
store: store,
}
}
// SetUsageStatsAggregator configures an async usage stats aggregator.
//
// When set, ValidateTokenWithTracking will enqueue request_count/last_used_at updates
// for unlimited tokens (MaxRequests == nil or <= 0) instead of doing a synchronous DB write.
func (v *StandardValidator) SetUsageStatsAggregator(agg *UsageStatsAggregator) {
v.usageStatsAgg.Store(agg)
}
func (v *StandardValidator) usageStatsAggregator() *UsageStatsAggregator {
return v.usageStatsAgg.Load()
}
func (v *StandardValidator) validateTokenData(ctx context.Context, tokenString string) (TokenData, error) {
// First validate the token format
if err := ValidateTokenFormat(tokenString); err != nil {
return TokenData{}, fmt.Errorf("invalid token format: %w", err)
}
// Retrieve the token from the store by token string
tokenData, err := v.store.GetTokenByToken(ctx, tokenString)
if err != nil {
if errors.Is(err, ErrTokenNotFound) {
return TokenData{}, ErrTokenNotFound
}
return TokenData{}, fmt.Errorf("failed to retrieve token: %w", err)
}
// Check if the token is active
if !tokenData.IsActive {
return TokenData{}, ErrTokenInactive
}
// Check if the token has expired
if IsExpired(tokenData.ExpiresAt) {
return TokenData{}, ErrTokenExpired
}
return tokenData, nil
}
// ValidateToken validates a token without incrementing usage
func (v *StandardValidator) ValidateToken(ctx context.Context, tokenString string) (string, error) {
tokenData, err := v.validateTokenData(ctx, tokenString)
if err != nil {
return "", err
}
// Check if the token has reached its rate limit
if tokenData.IsRateLimited() {
return "", ErrTokenRateLimit
}
// Token is valid, return the project ID
return tokenData.ProjectID, nil
}
// ValidateTokenWithTracking validates a token and increments its usage count
func (v *StandardValidator) ValidateTokenWithTracking(ctx context.Context, tokenString string) (string, error) {
// Validate token and load token data once.
tokenData, err := v.validateTokenData(ctx, tokenString)
if err != nil {
return "", err
}
// Enforce rate limit before tracking.
if tokenData.IsRateLimited() {
return "", ErrTokenRateLimit
}
// For unlimited tokens, move request_count/last_used_at updates off the hot path.
if tokenData.MaxRequests == nil || *tokenData.MaxRequests <= 0 {
if agg := v.usageStatsAgg.Load(); agg != nil {
agg.RecordTokenUsage(tokenString)
return tokenData.ProjectID, nil
}
}
// Limited tokens (or no async aggregator configured): do synchronous tracking.
if err := v.store.IncrementTokenUsage(ctx, tokenString); err != nil {
if errors.Is(err, ErrTokenRateLimit) {
return "", ErrTokenRateLimit
}
return "", fmt.Errorf("failed to track token usage: %w", err)
}
return tokenData.ProjectID, nil
}
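// Example wiring (a sketch; store is assumed to implement both TokenStore and
// UsageStatsStore): attaching an aggregator moves usage tracking for unlimited
// tokens off the request path.
//
//	v := NewValidator(store)
//	agg := NewUsageStatsAggregator(DefaultUsageStatsAggregatorConfig(), store, logger)
//	agg.Start()
//	v.SetUsageStatsAggregator(agg)
//	projectID, err := v.ValidateTokenWithTracking(ctx, tokenString)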
// ValidateToken validates a token using the provided validator and returns the project ID
func ValidateToken(ctx context.Context, validator TokenValidator, tokenID string) (string, error) {
return validator.ValidateToken(ctx, tokenID)
}
// ValidateTokenWithTracking validates a token and tracks its usage
func ValidateTokenWithTracking(ctx context.Context, validator TokenValidator, tokenID string) (string, error) {
return validator.ValidateTokenWithTracking(ctx, tokenID)
}
// Package utils provides common utility functions.
package utils
import (
"crypto/rand"
"encoding/hex"
"fmt"
"github.com/sofatutor/llm-proxy/internal/obfuscate"
)
// GenerateSecureToken generates a secure random token of the given length
func GenerateSecureToken(length int) (string, error) {
if length <= 0 {
return "", fmt.Errorf("length must be positive")
}
b := make([]byte, length)
_, err := rand.Read(b)
if err != nil {
return "", fmt.Errorf("failed to generate secure token: %w", err)
}
return hex.EncodeToString(b), nil
}
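// Example (illustrative): the length argument counts random bytes, so the
// returned hex string is twice as long.
//
//	s, _ := GenerateSecureToken(32) // 64 hex characters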
// GenerateSecureTokenMustSucceed generates a secure random token or panics
// This is useful for initialization code where failure is unrecoverable
func GenerateSecureTokenMustSucceed(length int) string {
token, err := GenerateSecureToken(length)
if err != nil {
panic(err)
}
return token
}
// Deprecated: use obfuscate.ObfuscateTokenGeneric directly at call sites.
func ObfuscateToken(token string) string { return obfuscate.ObfuscateTokenGeneric(token) }